summaryrefslogtreecommitdiff
path: root/src/silx
diff options
context:
space:
mode:
authorPicca Frédéric-Emmanuel <picca@debian.org>2024-02-05 16:30:07 +0100
committerPicca Frédéric-Emmanuel <picca@debian.org>2024-02-05 16:30:07 +0100
commit04095a69f18767d222b16fae5b40f2b712cd6f7e (patch)
treed20abd3ee2f237319443e9dfd7500ad55d29a33d /src/silx
parent3427caf0e96690e56aac6231a91df8f0f7a64fc2 (diff)
New upstream version 2.0.0+dfsg
Diffstat (limited to 'src/silx')
-rw-r--r--src/silx/__main__.py28
-rw-r--r--src/silx/_config.py75
-rw-r--r--src/silx/_version.py48
-rw-r--r--src/silx/app/compare/CompareImagesWindow.py254
-rw-r--r--src/silx/app/compare/__init__.py (renamed from src/silx/gui/plot/matplotlib/__init__.py)19
-rw-r--r--src/silx/app/compare/main.py105
-rw-r--r--src/silx/app/compare/test/__init__.py23
-rw-r--r--src/silx/app/compare/test/test_compare.py (renamed from src/silx/third_party/scipy_spatial.py)39
-rw-r--r--src/silx/app/compare/test/test_launcher.py142
-rw-r--r--src/silx/app/convert.py362
-rw-r--r--src/silx/app/test/test_convert.py18
-rw-r--r--src/silx/app/utils/__init__.py (renamed from src/silx/app/view/utils.py)23
-rw-r--r--src/silx/app/utils/parseutils.py133
-rw-r--r--src/silx/app/utils/test/__init__.py23
-rw-r--r--src/silx/app/utils/test/test_parseutils.py68
-rw-r--r--src/silx/app/view/About.py24
-rw-r--r--src/silx/app/view/ApplicationContext.py39
-rw-r--r--src/silx/app/view/CustomNxdataWidget.py12
-rw-r--r--src/silx/app/view/DataPanel.py2
-rw-r--r--src/silx/app/view/Viewer.py183
-rw-r--r--src/silx/app/view/main.py108
-rw-r--r--src/silx/app/view/test/test_launcher.py13
-rw-r--r--src/silx/app/view/test/test_view.py22
-rw-r--r--src/silx/conftest.py69
-rw-r--r--src/silx/gui/_glutils/Context.py2
-rw-r--r--src/silx/gui/_glutils/FramebufferTexture.py68
-rw-r--r--src/silx/gui/_glutils/OpenGLWidget.py142
-rw-r--r--src/silx/gui/_glutils/Program.py33
-rw-r--r--src/silx/gui/_glutils/Texture.py106
-rw-r--r--src/silx/gui/_glutils/VertexBuffer.py82
-rw-r--r--src/silx/gui/_glutils/__init__.py2
-rw-r--r--src/silx/gui/_glutils/font.py158
-rw-r--r--src/silx/gui/_glutils/gl.py60
-rw-r--r--src/silx/gui/_glutils/test/__init__.py2
-rw-r--r--src/silx/gui/_glutils/test/test_gl.py4
-rw-r--r--src/silx/gui/_glutils/utils.py7
-rwxr-xr-xsrc/silx/gui/colors.py687
-rw-r--r--src/silx/gui/conftest.py42
-rw-r--r--src/silx/gui/console.py34
-rw-r--r--src/silx/gui/constants.py27
-rw-r--r--src/silx/gui/data/ArrayTableModel.py106
-rw-r--r--src/silx/gui/data/ArrayTableWidget.py11
-rw-r--r--src/silx/gui/data/DataViewer.py43
-rw-r--r--src/silx/gui/data/DataViewerSelector.py4
-rw-r--r--src/silx/gui/data/DataViews.py420
-rw-r--r--src/silx/gui/data/Hdf5TableView.py119
-rw-r--r--src/silx/gui/data/HexaTableView.py14
-rw-r--r--src/silx/gui/data/NXdataWidgets.py284
-rw-r--r--src/silx/gui/data/NumpyAxesSelector.py49
-rw-r--r--src/silx/gui/data/RecordTableView.py28
-rw-r--r--src/silx/gui/data/TextFormatter.py44
-rw-r--r--src/silx/gui/data/_RecordPlot.py39
-rw-r--r--src/silx/gui/data/_VolumeWindow.py24
-rw-r--r--src/silx/gui/data/test/test_arraywidget.py66
-rw-r--r--src/silx/gui/data/test/test_dataviewer.py33
-rw-r--r--src/silx/gui/data/test/test_numpyaxesselector.py2
-rw-r--r--src/silx/gui/data/test/test_textformatter.py46
-rw-r--r--src/silx/gui/dialog/AbstractDataFileDialog.py165
-rw-r--r--src/silx/gui/dialog/ColormapDialog.py358
-rw-r--r--src/silx/gui/dialog/DataFileDialog.py4
-rw-r--r--src/silx/gui/dialog/DatasetDialog.py27
-rw-r--r--src/silx/gui/dialog/FileTypeComboBox.py24
-rw-r--r--src/silx/gui/dialog/GroupDialog.py33
-rw-r--r--src/silx/gui/dialog/ImageFileDialog.py10
-rw-r--r--src/silx/gui/dialog/SafeFileIconProvider.py1
-rw-r--r--src/silx/gui/dialog/SafeFileSystemModel.py25
-rw-r--r--src/silx/gui/dialog/test/test_colormapdialog.py704
-rw-r--r--src/silx/gui/dialog/test/test_datafiledialog.py203
-rw-r--r--src/silx/gui/dialog/test/test_imagefiledialog.py172
-rw-r--r--src/silx/gui/dialog/utils.py2
-rw-r--r--src/silx/gui/fit/BackgroundWidget.py149
-rw-r--r--src/silx/gui/fit/FitConfig.py185
-rw-r--r--src/silx/gui/fit/FitWidget.py286
-rw-r--r--src/silx/gui/fit/FitWidgets.py101
-rw-r--r--src/silx/gui/fit/Parameters.py475
-rw-r--r--src/silx/gui/fit/test/testBackgroundWidget.py19
-rw-r--r--src/silx/gui/fit/test/testFitConfig.py36
-rw-r--r--src/silx/gui/fit/test/testFitWidget.py18
-rw-r--r--src/silx/gui/hdf5/Hdf5Formatter.py9
-rw-r--r--src/silx/gui/hdf5/Hdf5HeaderView.py60
-rwxr-xr-xsrc/silx/gui/hdf5/Hdf5Item.py157
-rw-r--r--src/silx/gui/hdf5/Hdf5Node.py3
-rw-r--r--src/silx/gui/hdf5/Hdf5TreeModel.py100
-rw-r--r--src/silx/gui/hdf5/Hdf5TreeView.py17
-rw-r--r--src/silx/gui/hdf5/NexusSortFilterProxyModel.py7
-rw-r--r--src/silx/gui/hdf5/__init__.py8
-rw-r--r--src/silx/gui/hdf5/_utils.py15
-rwxr-xr-xsrc/silx/gui/hdf5/test/test_hdf5.py282
-rw-r--r--src/silx/gui/icons.py35
-rw-r--r--src/silx/gui/plot/AlphaSlider.py30
-rw-r--r--src/silx/gui/plot/ColorBar.py216
-rw-r--r--src/silx/gui/plot/Colormap.py41
-rw-r--r--src/silx/gui/plot/ColormapDialog.py40
-rw-r--r--src/silx/gui/plot/Colors.py87
-rw-r--r--src/silx/gui/plot/CompareImages.py1040
-rw-r--r--src/silx/gui/plot/ComplexImageView.py117
-rw-r--r--src/silx/gui/plot/CurvesROIWidget.py474
-rw-r--r--src/silx/gui/plot/ImageStack.py348
-rw-r--r--src/silx/gui/plot/ImageView.py290
-rw-r--r--src/silx/gui/plot/Interaction.py51
-rw-r--r--src/silx/gui/plot/ItemsSelectionDialog.py46
-rwxr-xr-xsrc/silx/gui/plot/LegendSelector.py505
-rw-r--r--src/silx/gui/plot/LimitsHistory.py4
-rw-r--r--src/silx/gui/plot/MaskToolsWidget.py267
-rw-r--r--src/silx/gui/plot/PlotActions.py66
-rw-r--r--src/silx/gui/plot/PlotEvents.py185
-rw-r--r--src/silx/gui/plot/PlotInteraction.py1034
-rw-r--r--src/silx/gui/plot/PlotToolButtons.py163
-rwxr-xr-xsrc/silx/gui/plot/PlotWidget.py1609
-rw-r--r--src/silx/gui/plot/PlotWindow.py278
-rw-r--r--src/silx/gui/plot/PrintPreviewToolButton.py106
-rw-r--r--src/silx/gui/plot/Profile.py175
-rw-r--r--src/silx/gui/plot/ProfileMainWindow.py109
-rw-r--r--src/silx/gui/plot/ROIStatsWidget.py182
-rw-r--r--src/silx/gui/plot/ScatterMaskToolsWidget.py207
-rw-r--r--src/silx/gui/plot/ScatterView.py77
-rw-r--r--src/silx/gui/plot/StackView.py389
-rw-r--r--src/silx/gui/plot/StatsWidget.py319
-rw-r--r--src/silx/gui/plot/_BaseMaskToolsWidget.py340
-rw-r--r--src/silx/gui/plot/__init__.py12
-rw-r--r--src/silx/gui/plot/_utils/__init__.py27
-rw-r--r--src/silx/gui/plot/_utils/delaunay.py60
-rw-r--r--src/silx/gui/plot/_utils/dtime_ticklayout.py162
-rw-r--r--src/silx/gui/plot/_utils/panzoom.py136
-rw-r--r--src/silx/gui/plot/_utils/test/test_dtime_ticklayout.py1
-rw-r--r--src/silx/gui/plot/_utils/test/test_ticklayout.py19
-rw-r--r--src/silx/gui/plot/_utils/ticklayout.py16
-rw-r--r--src/silx/gui/plot/actions/PlotAction.py24
-rw-r--r--src/silx/gui/plot/actions/PlotToolAction.py30
-rwxr-xr-xsrc/silx/gui/plot/actions/control.py321
-rw-r--r--src/silx/gui/plot/actions/fit.py94
-rw-r--r--src/silx/gui/plot/actions/histogram.py131
-rw-r--r--src/silx/gui/plot/actions/io.py426
-rw-r--r--src/silx/gui/plot/actions/medfilt.py46
-rw-r--r--src/silx/gui/plot/actions/mode.py80
-rwxr-xr-xsrc/silx/gui/plot/backends/BackendBase.py137
-rwxr-xr-xsrc/silx/gui/plot/backends/BackendMatplotlib.py862
-rwxr-xr-xsrc/silx/gui/plot/backends/BackendOpenGL.py964
-rw-r--r--src/silx/gui/plot/backends/glutils/GLPlotCurve.py806
-rw-r--r--src/silx/gui/plot/backends/glutils/GLPlotFrame.py681
-rw-r--r--src/silx/gui/plot/backends/glutils/GLPlotImage.py357
-rw-r--r--src/silx/gui/plot/backends/glutils/GLPlotItem.py9
-rw-r--r--src/silx/gui/plot/backends/glutils/GLPlotTriangles.py45
-rw-r--r--src/silx/gui/plot/backends/glutils/GLSupport.py109
-rw-r--r--src/silx/gui/plot/backends/glutils/GLText.py207
-rw-r--r--src/silx/gui/plot/backends/glutils/GLTexture.py209
-rw-r--r--src/silx/gui/plot/backends/glutils/PlotImageFile.py99
-rw-r--r--src/silx/gui/plot/items/__init__.py44
-rw-r--r--src/silx/gui/plot/items/_arc_roi.py256
-rw-r--r--src/silx/gui/plot/items/_band_roi.py18
-rw-r--r--src/silx/gui/plot/items/_roi_base.py168
-rw-r--r--src/silx/gui/plot/items/axis.py88
-rw-r--r--src/silx/gui/plot/items/complex.py65
-rw-r--r--src/silx/gui/plot/items/core.py409
-rw-r--r--src/silx/gui/plot/items/curve.py209
-rw-r--r--src/silx/gui/plot/items/histogram.py139
-rw-r--r--src/silx/gui/plot/items/image.py165
-rw-r--r--src/silx/gui/plot/items/image_aggregated.py30
-rwxr-xr-xsrc/silx/gui/plot/items/marker.py95
-rw-r--r--src/silx/gui/plot/items/roi.py320
-rw-r--r--src/silx/gui/plot/items/scatter.py464
-rw-r--r--src/silx/gui/plot/items/shape.py99
-rw-r--r--src/silx/gui/plot/matplotlib/Colormap.py248
-rw-r--r--src/silx/gui/plot/stats/stats.py242
-rw-r--r--src/silx/gui/plot/stats/statshandler.py51
-rw-r--r--src/silx/gui/plot/test/conftest.py (renamed from src/silx/gui/plot/PlotTools.py)25
-rw-r--r--src/silx/gui/plot/test/testAlphaSlider.py40
-rw-r--r--src/silx/gui/plot/test/testAxis.py147
-rw-r--r--src/silx/gui/plot/test/testColorBar.py146
-rw-r--r--src/silx/gui/plot/test/testCompareImages.py271
-rw-r--r--src/silx/gui/plot/test/testComplexImageView.py3
-rw-r--r--src/silx/gui/plot/test/testCurvesROIWidget.py254
-rw-r--r--src/silx/gui/plot/test/testImageStack.py113
-rw-r--r--src/silx/gui/plot/test/testImageView.py56
-rw-r--r--src/silx/gui/plot/test/testInteraction.py32
-rw-r--r--src/silx/gui/plot/test/testItem.py382
-rw-r--r--src/silx/gui/plot/test/testLegendSelector.py50
-rw-r--r--src/silx/gui/plot/test/testMaskToolsWidget.py122
-rw-r--r--src/silx/gui/plot/test/testPixelIntensityHistoAction.py27
-rw-r--r--src/silx/gui/plot/test/testPlotActions.py41
-rw-r--r--src/silx/gui/plot/test/testPlotInteraction.py163
-rwxr-xr-xsrc/silx/gui/plot/test/testPlotWidget.py1579
-rwxr-xr-xsrc/silx/gui/plot/test/testPlotWidgetActiveItem.py416
-rw-r--r--src/silx/gui/plot/test/testPlotWidgetDataMargins.py135
-rw-r--r--src/silx/gui/plot/test/testPlotWidgetNoBackend.py528
-rw-r--r--src/silx/gui/plot/test/testPlotWindow.py38
-rw-r--r--src/silx/gui/plot/test/testRoiStatsWidget.py178
-rw-r--r--src/silx/gui/plot/test/testSaveAction.py60
-rw-r--r--src/silx/gui/plot/test/testScatterMaskToolsWidget.py124
-rw-r--r--src/silx/gui/plot/test/testScatterView.py20
-rw-r--r--src/silx/gui/plot/test/testStackView.py159
-rw-r--r--src/silx/gui/plot/test/testStats.py701
-rw-r--r--src/silx/gui/plot/test/testUtilsAxis.py77
-rw-r--r--src/silx/gui/plot/test/utils.py2
-rw-r--r--src/silx/gui/plot/tools/CurveLegendsWidget.py21
-rw-r--r--src/silx/gui/plot/tools/LimitsToolBar.py22
-rw-r--r--src/silx/gui/plot/tools/PlotToolButton.py92
-rw-r--r--src/silx/gui/plot/tools/PositionInfo.py88
-rw-r--r--src/silx/gui/plot/tools/RadarView.py53
-rw-r--r--src/silx/gui/plot/tools/RulerToolButton.py183
-rw-r--r--src/silx/gui/plot/tools/compare/__init__.py (renamed from src/silx/utils/html.py)15
-rw-r--r--src/silx/gui/plot/tools/compare/core.py198
-rw-r--r--src/silx/gui/plot/tools/compare/profile.py173
-rw-r--r--src/silx/gui/plot/tools/compare/statusbar.py218
-rw-r--r--src/silx/gui/plot/tools/compare/toolbar.py390
-rw-r--r--src/silx/gui/plot/tools/menus.py93
-rw-r--r--src/silx/gui/plot/tools/profile/ScatterProfileToolBar.py13
-rw-r--r--src/silx/gui/plot/tools/profile/core.py313
-rw-r--r--src/silx/gui/plot/tools/profile/editors.py28
-rw-r--r--src/silx/gui/plot/tools/profile/manager.py175
-rw-r--r--src/silx/gui/plot/tools/profile/rois.py246
-rw-r--r--src/silx/gui/plot/tools/profile/toolbar.py3
-rw-r--r--src/silx/gui/plot/tools/roi.py380
-rw-r--r--src/silx/gui/plot/tools/test/testCurveLegendsWidget.py31
-rw-r--r--src/silx/gui/plot/tools/test/testProfile.py155
-rw-r--r--src/silx/gui/plot/tools/test/testRoiCore.py (renamed from src/silx/gui/plot/tools/test/testROI.py)333
-rw-r--r--src/silx/gui/plot/tools/test/testRoiItems.py313
-rw-r--r--src/silx/gui/plot/tools/test/testScatterProfileToolBar.py18
-rw-r--r--src/silx/gui/plot/tools/test/testTools.py37
-rw-r--r--src/silx/gui/plot/tools/toolbars.py80
-rw-r--r--src/silx/gui/plot/utils/axis.py31
-rw-r--r--src/silx/gui/plot/utils/intersections.py26
-rw-r--r--src/silx/gui/plot3d/ParamTreeView.py368
-rw-r--r--src/silx/gui/plot3d/Plot3DWidget.py134
-rw-r--r--src/silx/gui/plot3d/SFViewParamTree.py368
-rw-r--r--src/silx/gui/plot3d/ScalarFieldView.py353
-rw-r--r--src/silx/gui/plot3d/SceneWidget.py102
-rw-r--r--src/silx/gui/plot3d/SceneWindow.py53
-rw-r--r--src/silx/gui/plot3d/__init__.py2
-rw-r--r--src/silx/gui/plot3d/_model/core.py28
-rw-r--r--src/silx/gui/plot3d/_model/items.py568
-rw-r--r--src/silx/gui/plot3d/_model/model.py2
-rw-r--r--src/silx/gui/plot3d/actions/io.py102
-rw-r--r--src/silx/gui/plot3d/actions/mode.py35
-rw-r--r--src/silx/gui/plot3d/actions/viewpoint.py87
-rw-r--r--src/silx/gui/plot3d/conftest.py1
-rw-r--r--src/silx/gui/plot3d/items/__init__.py9
-rw-r--r--src/silx/gui/plot3d/items/_pick.py50
-rw-r--r--src/silx/gui/plot3d/items/clipplane.py28
-rw-r--r--src/silx/gui/plot3d/items/core.py124
-rw-r--r--src/silx/gui/plot3d/items/image.py92
-rw-r--r--src/silx/gui/plot3d/items/mesh.py367
-rw-r--r--src/silx/gui/plot3d/items/mixins.py80
-rw-r--r--src/silx/gui/plot3d/items/scatter.py224
-rw-r--r--src/silx/gui/plot3d/items/volume.py135
-rw-r--r--src/silx/gui/plot3d/scene/axes.py70
-rw-r--r--src/silx/gui/plot3d/scene/camera.py117
-rw-r--r--src/silx/gui/plot3d/scene/core.py36
-rw-r--r--src/silx/gui/plot3d/scene/cutplane.py137
-rw-r--r--src/silx/gui/plot3d/scene/event.py41
-rw-r--r--src/silx/gui/plot3d/scene/function.py152
-rw-r--r--src/silx/gui/plot3d/scene/interaction.py340
-rw-r--r--src/silx/gui/plot3d/scene/primitives.py991
-rw-r--r--src/silx/gui/plot3d/scene/test/test_transform.py39
-rw-r--r--src/silx/gui/plot3d/scene/test/test_utils.py168
-rw-r--r--src/silx/gui/plot3d/scene/text.py241
-rw-r--r--src/silx/gui/plot3d/scene/transform.py304
-rw-r--r--src/silx/gui/plot3d/scene/utils.py112
-rw-r--r--src/silx/gui/plot3d/scene/viewport.py147
-rw-r--r--src/silx/gui/plot3d/scene/window.py107
-rw-r--r--src/silx/gui/plot3d/test/testGL.py18
-rw-r--r--src/silx/gui/plot3d/test/testScalarFieldView.py14
-rw-r--r--src/silx/gui/plot3d/test/testSceneWidget.py2
-rw-r--r--src/silx/gui/plot3d/test/testSceneWidgetPicking.py102
-rw-r--r--src/silx/gui/plot3d/test/testSceneWindow.py68
-rw-r--r--src/silx/gui/plot3d/test/testStatsWidget.py44
-rw-r--r--src/silx/gui/plot3d/tools/GroupPropertiesWidget.py25
-rw-r--r--src/silx/gui/plot3d/tools/PositionInfoWidget.py77
-rw-r--r--src/silx/gui/plot3d/tools/ViewpointTools.py4
-rw-r--r--src/silx/gui/plot3d/tools/test/testPositionInfoWidget.py7
-rw-r--r--src/silx/gui/plot3d/tools/toolbars.py8
-rw-r--r--src/silx/gui/plot3d/utils/mng.py55
-rw-r--r--src/silx/gui/qt/__init__.py12
-rw-r--r--src/silx/gui/qt/_pyqt6.py47
-rw-r--r--src/silx/gui/qt/_pyside_dynamic.py290
-rw-r--r--src/silx/gui/qt/_qt.py195
-rw-r--r--src/silx/gui/qt/_utils.py15
-rw-r--r--src/silx/gui/qt/inspect.py21
-rwxr-xr-xsrc/silx/gui/test/test_colors.py470
-rw-r--r--src/silx/gui/test/test_console.py6
-rw-r--r--src/silx/gui/test/test_icons.py12
-rw-r--r--src/silx/gui/test/test_qt.py18
-rw-r--r--src/silx/gui/test/utils.py40
-rwxr-xr-xsrc/silx/gui/utils/__init__.py7
-rw-r--r--src/silx/gui/utils/glutils/__init__.py165
-rw-r--r--src/silx/gui/utils/image.py78
-rw-r--r--src/silx/gui/utils/matplotlib.py198
-rw-r--r--src/silx/gui/utils/projecturl.py3
-rw-r--r--src/silx/gui/utils/signal.py13
-rw-r--r--src/silx/gui/utils/test/test.py1
-rw-r--r--src/silx/gui/utils/test/test_async.py9
-rw-r--r--src/silx/gui/utils/test/test_glutils.py4
-rw-r--r--src/silx/gui/utils/test/test_image.py83
-rwxr-xr-xsrc/silx/gui/utils/test/test_qtutils.py1
-rw-r--r--src/silx/gui/utils/test/test_testutils.py5
-rw-r--r--src/silx/gui/utils/testutils.py95
-rw-r--r--src/silx/gui/widgets/ElidedLabel.py6
-rw-r--r--src/silx/gui/widgets/FloatEdit.py102
-rw-r--r--src/silx/gui/widgets/FlowLayout.py8
-rw-r--r--src/silx/gui/widgets/FormGridLayout.py7
-rw-r--r--src/silx/gui/widgets/FrameBrowser.py25
-rwxr-xr-xsrc/silx/gui/widgets/LegendIconWidget.py146
-rw-r--r--src/silx/gui/widgets/MedianFilterDialog.py16
-rw-r--r--src/silx/gui/widgets/PeriodicTable.py297
-rw-r--r--src/silx/gui/widgets/PrintGeometryDialog.py43
-rw-r--r--src/silx/gui/widgets/PrintPreview.py173
-rw-r--r--src/silx/gui/widgets/RangeSlider.py150
-rw-r--r--src/silx/gui/widgets/StackedProgressBar.py314
-rw-r--r--src/silx/gui/widgets/TableWidget.py68
-rw-r--r--src/silx/gui/widgets/ThreadPoolPushButton.py4
-rw-r--r--src/silx/gui/widgets/UrlList.py139
-rw-r--r--src/silx/gui/widgets/UrlSelectionTable.py312
-rw-r--r--src/silx/gui/widgets/WaitingOverlay.py111
-rw-r--r--src/silx/gui/widgets/WaitingPushButton.py10
-rw-r--r--src/silx/gui/widgets/test/test_boxlayoutdockwidget.py6
-rw-r--r--src/silx/gui/widgets/test/test_elidedlabel.py1
-rw-r--r--src/silx/gui/widgets/test/test_floatedit.py82
-rw-r--r--src/silx/gui/widgets/test/test_flowlayout.py6
-rw-r--r--src/silx/gui/widgets/test/test_framebrowser.py2
-rw-r--r--src/silx/gui/widgets/test/test_hierarchicaltableview.py3
-rw-r--r--src/silx/gui/widgets/test/test_legendiconwidget.py2
-rw-r--r--src/silx/gui/widgets/test/test_periodictable.py20
-rw-r--r--src/silx/gui/widgets/test/test_printpreview.py11
-rw-r--r--src/silx/gui/widgets/test/test_rangeslider.py28
-rw-r--r--src/silx/gui/widgets/test/test_stackedprogressbar.py60
-rw-r--r--src/silx/gui/widgets/test/test_tablewidget.py1
-rw-r--r--src/silx/gui/widgets/test/test_threadpoolpushbutton.py4
-rw-r--r--src/silx/gui/widgets/test/test_urlselectiontable.py72
-rw-r--r--src/silx/gui/widgets/test/test_waitingoverlay.py31
-rw-r--r--src/silx/image/_boundingbox.py11
-rw-r--r--src/silx/image/backprojection.py2
-rw-r--r--src/silx/image/bilinear.pyx16
-rw-r--r--src/silx/image/marchingsquares/__init__.py13
-rw-r--r--src/silx/image/marchingsquares/_mergeimpl.pyx58
-rw-r--r--src/silx/image/marchingsquares/_skimage.py2
-rw-r--r--src/silx/image/marchingsquares/test/test_funcapi.py1
-rw-r--r--src/silx/image/marchingsquares/test/test_mergeimpl.py14
-rw-r--r--src/silx/image/medianfilter.py57
-rw-r--r--src/silx/image/phantomgenerator.py48
-rw-r--r--src/silx/image/projection.py2
-rw-r--r--src/silx/image/reconstruction.py2
-rw-r--r--src/silx/image/sift.py2
-rw-r--r--src/silx/image/test/test_bb.py21
-rw-r--r--src/silx/image/test/test_bilinear.py74
-rw-r--r--src/silx/image/test/test_medianfilter.py9
-rw-r--r--src/silx/image/test/test_shapes.py382
-rw-r--r--src/silx/image/test/test_tomography.py11
-rw-r--r--src/silx/image/tomography.py106
-rw-r--r--src/silx/image/utils.py3
-rw-r--r--src/silx/io/_sliceh5.py221
-rw-r--r--src/silx/io/commonh5.py65
-rw-r--r--src/silx/io/configdict.py115
-rw-r--r--src/silx/io/convert.py127
-rw-r--r--src/silx/io/dictdump.py194
-rwxr-xr-xsrc/silx/io/fabioh5.py132
-rw-r--r--src/silx/io/fioh5.py200
-rw-r--r--src/silx/io/h5link_utils.py77
-rw-r--r--src/silx/io/h5py_utils.py22
-rw-r--r--src/silx/io/nxdata/__init__.py15
-rw-r--r--src/silx/io/nxdata/_utils.py75
-rw-r--r--src/silx/io/nxdata/parse.py232
-rw-r--r--src/silx/io/nxdata/write.py72
-rw-r--r--src/silx/io/octaveh5.py52
-rw-r--r--src/silx/io/rawh5.py2
-rw-r--r--src/silx/io/specfile.pyx5
-rw-r--r--src/silx/io/specfilewrapper.py7
-rw-r--r--src/silx/io/spech5.py428
-rw-r--r--src/silx/io/spectoh5.py80
-rw-r--r--src/silx/io/test/test_commonh5.py15
-rw-r--r--src/silx/io/test/test_dictdump.py364
-rwxr-xr-xsrc/silx/io/test/test_fabioh5.py111
-rw-r--r--src/silx/io/test/test_fioh5.py146
-rw-r--r--src/silx/io/test/test_h5link_utils.py116
-rw-r--r--src/silx/io/test/test_nxdata.py399
-rw-r--r--src/silx/io/test/test_octaveh5.py152
-rw-r--r--src/silx/io/test/test_rawh5.py11
-rw-r--r--src/silx/io/test/test_sliceh5.py104
-rw-r--r--src/silx/io/test/test_specfile.py171
-rw-r--r--src/silx/io/test/test_specfilewrapper.py71
-rw-r--r--src/silx/io/test/test_spech5.py492
-rw-r--r--src/silx/io/test/test_spectoh5.py48
-rw-r--r--src/silx/io/test/test_url.py452
-rw-r--r--src/silx/io/test/test_utils.py507
-rw-r--r--src/silx/io/test/test_write_to_h5.py43
-rw-r--r--src/silx/io/url.py188
-rw-r--r--src/silx/io/utils.py515
-rw-r--r--src/silx/math/_colormap.pyx46
-rw-r--r--src/silx/math/calibration.py59
-rw-r--r--src/silx/math/colormap.py180
-rw-r--r--src/silx/math/fft/basefft.py37
-rw-r--r--src/silx/math/fft/clfft.py51
-rw-r--r--src/silx/math/fft/cufft.py44
-rw-r--r--src/silx/math/fft/fft.py4
-rw-r--r--src/silx/math/fft/fftw.py35
-rw-r--r--src/silx/math/fft/npfft.py4
-rw-r--r--src/silx/math/fft/test/test_fft.py110
-rw-r--r--src/silx/math/fit/__init__.py4
-rw-r--r--src/silx/math/fit/bgtheories.py216
-rw-r--r--src/silx/math/fit/fitmanager.py318
-rw-r--r--src/silx/math/fit/fittheories.py931
-rw-r--r--src/silx/math/fit/fittheory.py37
-rw-r--r--src/silx/math/fit/functions.pyx162
-rw-r--r--src/silx/math/fit/functions/include/functions.h1
-rw-r--r--src/silx/math/fit/functions/src/funs.c98
-rw-r--r--src/silx/math/fit/functions_wrapper.pxd6
-rw-r--r--src/silx/math/fit/leastsq.py289
-rw-r--r--src/silx/math/fit/test/test_bgtheories.py60
-rw-r--r--src/silx/math/fit/test/test_filters.py82
-rw-r--r--src/silx/math/fit/test/test_fit.py264
-rw-r--r--src/silx/math/fit/test/test_fitmanager.py183
-rw-r--r--src/silx/math/fit/test/test_functions.py125
-rw-r--r--src/silx/math/fit/test/test_peaks.py492
-rw-r--r--src/silx/math/histogram.py133
-rw-r--r--src/silx/math/histogramnd/include/histogramnd_c.h10
-rw-r--r--src/silx/math/histogramnd/include/msvc/stdint.h247
-rw-r--r--src/silx/math/medianfilter/__init__.py2
-rw-r--r--src/silx/math/medianfilter/test/benchmark.py19
-rw-r--r--src/silx/math/medianfilter/test/test_medianfilter.py559
-rw-r--r--src/silx/math/test/benchmark_combo.py96
-rw-r--r--src/silx/math/test/histo_benchmarks.py290
-rw-r--r--src/silx/math/test/test_HistogramndLut_nominal.py242
-rw-r--r--src/silx/math/test/test_calibration.py60
-rw-r--r--src/silx/math/test/test_colormap.py154
-rw-r--r--src/silx/math/test/test_combo.py81
-rw-r--r--src/silx/math/test/test_histogramnd_error.py441
-rw-r--r--src/silx/math/test/test_histogramnd_nominal.py578
-rw-r--r--src/silx/math/test/test_histogramnd_vs_np.py629
-rw-r--r--src/silx/math/test/test_interpolate.py39
-rw-r--r--src/silx/math/test/test_marchingcubes.py63
-rw-r--r--src/silx/opencl/atomic.py93
-rw-r--r--src/silx/opencl/backprojection.py150
-rw-r--r--src/silx/opencl/codec/bitshuffle_lz4.py214
-rw-r--r--src/silx/opencl/codec/byte_offset.py332
-rw-r--r--src/silx/opencl/codec/test/test_bitshuffle_lz4.py126
-rw-r--r--src/silx/opencl/codec/test/test_byte_offset.py94
-rw-r--r--src/silx/opencl/common.py331
-rw-r--r--src/silx/opencl/conftest.py1
-rw-r--r--src/silx/opencl/convolution.py95
-rw-r--r--src/silx/opencl/image.py333
-rw-r--r--src/silx/opencl/linalg.py75
-rw-r--r--src/silx/opencl/medfilt.py141
-rw-r--r--src/silx/opencl/processing.py147
-rw-r--r--src/silx/opencl/projection.py201
-rw-r--r--src/silx/opencl/reconstruction.py202
-rw-r--r--src/silx/opencl/sinofilter.py196
-rw-r--r--src/silx/opencl/sparse.py103
-rw-r--r--src/silx/opencl/statistics.py176
-rw-r--r--src/silx/opencl/test/test_addition.py72
-rw-r--r--src/silx/opencl/test/test_array_utils.py62
-rw-r--r--src/silx/opencl/test/test_backprojection.py49
-rw-r--r--src/silx/opencl/test/test_convolution.py16
-rw-r--r--src/silx/opencl/test/test_doubleword.py198
-rw-r--r--src/silx/opencl/test/test_image.py25
-rw-r--r--src/silx/opencl/test/test_kahan.py86
-rw-r--r--src/silx/opencl/test/test_linalg.py82
-rw-r--r--src/silx/opencl/test/test_medfilt.py35
-rw-r--r--src/silx/opencl/test/test_projection.py13
-rw-r--r--src/silx/opencl/test/test_sparse.py33
-rw-r--r--src/silx/opencl/test/test_stats.py49
-rw-r--r--src/silx/opencl/utils.py24
-rw-r--r--src/silx/resources/__init__.py170
-rw-r--r--src/silx/resources/gui/icons/ruler.pngbin0 -> 1416 bytes
-rw-r--r--src/silx/resources/gui/icons/ruler.svg216
-rw-r--r--src/silx/resources/opencl/codec/bitshuffle_lz4.cl625
-rw-r--r--src/silx/resources/opencl/doubleword.cl7
-rw-r--r--src/silx/sx/__init__.py40
-rw-r--r--src/silx/sx/_plot.py143
-rw-r--r--src/silx/sx/_plot3d.py59
-rw-r--r--src/silx/test/__init__.py37
-rw-r--r--src/silx/test/test_resources.py67
-rw-r--r--src/silx/test/test_sx.py76
-rw-r--r--src/silx/test/utils.py32
-rw-r--r--src/silx/third_party/EdfFile.py499
-rw-r--r--src/silx/third_party/TiffIO.py1274
-rw-r--r--src/silx/third_party/__init__.py11
-rw-r--r--src/silx/utils/ExternalResources.py114
-rw-r--r--src/silx/utils/array_like.py84
-rw-r--r--src/silx/utils/debug.py9
-rw-r--r--src/silx/utils/deprecation.py46
-rw-r--r--src/silx/utils/files.py1
-rw-r--r--src/silx/utils/launcher.py33
-rwxr-xr-xsrc/silx/utils/number.py20
-rw-r--r--src/silx/utils/property.py1
-rw-r--r--src/silx/utils/proxy.py3
-rw-r--r--src/silx/utils/retry.py7
-rw-r--r--src/silx/utils/test/test_array_like.py215
-rw-r--r--src/silx/utils/test/test_debug.py1
-rw-r--r--src/silx/utils/test/test_deprecation.py9
-rw-r--r--src/silx/utils/test/test_enum.py61
-rw-r--r--src/silx/utils/test/test_external_resources.py6
-rw-r--r--src/silx/utils/test/test_launcher.py4
-rw-r--r--src/silx/utils/test/test_launcher_command.py1
-rw-r--r--src/silx/utils/test/test_number.py32
-rw-r--r--src/silx/utils/test/test_proxy.py11
-rw-r--r--src/silx/utils/test/test_weakref.py6
-rwxr-xr-xsrc/silx/utils/testutils.py65
-rw-r--r--src/silx/utils/weakref.py6
497 files changed, 39660 insertions, 29114 deletions
diff --git a/src/silx/__main__.py b/src/silx/__main__.py
index 0f8727c..cbd5d34 100644
--- a/src/silx/__main__.py
+++ b/src/silx/__main__.py
@@ -35,6 +35,7 @@ __date__ = "07/06/2018"
import logging
+
logging.basicConfig()
import multiprocessing
@@ -55,15 +56,24 @@ def main():
multiprocessing.freeze_support()
launcher = Launcher(prog="silx", version=silx._version.version)
- launcher.add_command("view",
- module_name="silx.app.view.main",
- description="Browse a data file with a GUI")
- launcher.add_command("convert",
- module_name="silx.app.convert",
- description="Convert and concatenate files into a HDF5 file")
- launcher.add_command("test",
- module_name="silx.app.test_",
- description="Launch silx unittest")
+ launcher.add_command(
+ "view",
+ module_name="silx.app.view.main",
+ description="Browse a data file with a GUI",
+ )
+ launcher.add_command(
+ "convert",
+ module_name="silx.app.convert",
+ description="Convert and concatenate files into a HDF5 file",
+ )
+ launcher.add_command(
+ "compare",
+ module_name="silx.app.compare.main",
+ description="Compare images with a GUI",
+ )
+ launcher.add_command(
+ "test", module_name="silx.app.test_", description="Launch silx unittest"
+ )
status = launcher.execute(sys.argv)
return status
diff --git a/src/silx/_config.py b/src/silx/_config.py
index 5d7b445..f48c783 100644
--- a/src/silx/_config.py
+++ b/src/silx/_config.py
@@ -1,7 +1,7 @@
#!/usr/bin/env python
# /*##########################################################################
#
-# Copyright (c) 2017-2019 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -55,7 +55,7 @@ class Config(object):
.. versionadded:: 0.8
"""
- DEFAULT_COLORMAP_NAME = 'gray'
+ DEFAULT_COLORMAP_NAME = "gray"
"""Default LUT for the plot widgets.
The available list of names are available in the module
@@ -64,7 +64,7 @@ class Config(object):
.. versionadded:: 0.8
"""
- DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION = 'upward'
+ DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION = "upward"
"""Default Y-axis orientation for plot widget displaying images.
This attribute can be set with:
@@ -83,27 +83,23 @@ class Config(object):
.. versionadded:: 0.8
"""
- DEFAULT_PLOT_CURVE_COLORS = ['#000000', # black
- '#0000ff', # blue
- '#ff0000', # red
- '#00ff00', # green
- '#ff66ff', # pink
- '#ffff00', # yellow
- '#a52a2a', # brown
- '#00ffff', # cyan
- '#ff00ff', # magenta
- '#ff9900', # orange
- '#6600ff', # violet
- '#a0a0a4', # grey
- '#000080', # darkBlue
- '#800000', # darkRed
- '#008000', # darkGreen
- '#008080', # darkCyan
- '#800080', # darkMagenta
- '#808000', # darkYellow
- '#660000'] # darkBrown
+ DEFAULT_PLOT_CURVE_COLORS = [
+ "#1f77b4", # tab:blue
+ "#ff7f0e", # tab:orange
+ "#2ca02c", # tab:green
+ "#d62728", # tab:red
+ "#9467bd", # tab:purple
+ "#8c564b", # tab:brown
+ "#e377c2", # tab:pink
+ "#7f7f7f", # tab:gray
+ "#bcbd22", # tab:olive
+ "#17becf", # tab:cyan
+ ]
"""Default list of colors for plot widget displaying curves.
+ It is based on the color cycle of matplotlib 2.0.
+ See https://github.com/matplotlib/matplotlib/blob/main/lib/matplotlib/_color_data.py#L14
+
It will have an influence on:
- :class:`silx.gui.plot.PlotWidget`
@@ -119,13 +115,13 @@ class Config(object):
.. versionadded:: 0.10
"""
- DEFAULT_PLOT_SYMBOL = 'o'
+ DEFAULT_PLOT_SYMBOL = "o"
"""Default marker of the item.
It will have an influence on PlotWidget items
Supported symbols:
-
+
- 'o', 'Circle'
- 'd', 'Diamond'
- 's', 'Square'
@@ -145,3 +141,34 @@ class Config(object):
.. versionadded:: 0.10
"""
+
+ DEFAULT_PLOT_ACTIVE_CURVE_COLOR = None
+ """Default color for the active curve.
+
+ It will have an influence on PlotWidget curve items
+
+ .. versionadded:: 2.0
+ """
+
+ DEFAULT_PLOT_ACTIVE_CURVE_LINEWIDTH = 2
+ """Default line width for the active curve.
+
+ It will have an influence on PlotWidget curve items
+
+ .. versionadded:: 2.0
+ """
+
+ DEFAULT_PLOT_MARKER_TEXT_FONT_SIZE = None
+ """Default font size for marker text.
+
+ It will have an influence on PlotWidget marker items
+
+ .. versionadded:: 2.0
+ """
+
+ _MPL_TIGHT_LAYOUT = False
+ """If true the matplotlib backend will use the
+ experimental tight layout.
+
+ .. versionadded:: 2.0
+ """
diff --git a/src/silx/_version.py b/src/silx/_version.py
index 5688237..1edc955 100644
--- a/src/silx/_version.py
+++ b/src/silx/_version.py
@@ -51,33 +51,35 @@ Thus 2.1.0a3 is hexversion 0x020100a3.
__authors__ = ["Jérôme Kieffer"]
__license__ = "MIT"
__copyright__ = "European Synchrotron Radiation Facility, Grenoble, France"
-__date__ = "30/09/2020"
+__date__ = "12/12/2023"
__status__ = "production"
-__docformat__ = 'restructuredtext'
-__all__ = ["date", "version_info", "strictversion", "hexversion", "debianversion",
- "calc_hexversion"]
-
-RELEASE_LEVEL_VALUE = {"dev": 0,
- "alpha": 10,
- "beta": 11,
- "candidate": 12,
- "final": 15}
-
-PRERELEASE_NORMALIZED_NAME = {"dev": "a",
- "alpha": "a",
- "beta": "b",
- "candidate": "rc"}
-
-MAJOR = 1
-MINOR = 1
-MICRO = 2
+__docformat__ = "restructuredtext"
+__all__ = [
+ "date",
+ "version_info",
+ "strictversion",
+ "hexversion",
+ "debianversion",
+ "calc_hexversion",
+]
+
+RELEASE_LEVEL_VALUE = {"dev": 0, "alpha": 10, "beta": 11, "candidate": 12, "final": 15}
+
+PRERELEASE_NORMALIZED_NAME = {"dev": "a", "alpha": "a", "beta": "b", "candidate": "rc"}
+
+MAJOR = 2
+MINOR = 0
+MICRO = 0
RELEV = "final" # <16
SERIAL = 0 # <16
date = __date__
from collections import namedtuple
-_version_info = namedtuple("version_info", ["major", "minor", "micro", "releaselevel", "serial"])
+
+_version_info = namedtuple(
+ "version_info", ["major", "minor", "micro", "releaselevel", "serial"]
+)
version_info = _version_info(MAJOR, MINOR, MICRO, RELEV, SERIAL)
@@ -85,7 +87,11 @@ strictversion = version = debianversion = "%d.%d.%d" % version_info[:3]
if version_info.releaselevel != "final":
_prerelease = PRERELEASE_NORMALIZED_NAME[version_info[3]]
version += "-%s%s" % (_prerelease, version_info[-1])
- debianversion += "~adev%i" % version_info[-1] if RELEV == "dev" else "~%s%i" % (_prerelease, version_info[-1])
+ debianversion += (
+ "~adev%i" % version_info[-1]
+ if RELEV == "dev"
+ else "~%s%i" % (_prerelease, version_info[-1])
+ )
strictversion += _prerelease + str(version_info[-1])
diff --git a/src/silx/app/compare/CompareImagesWindow.py b/src/silx/app/compare/CompareImagesWindow.py
new file mode 100644
index 0000000..7a509ae
--- /dev/null
+++ b/src/silx/app/compare/CompareImagesWindow.py
@@ -0,0 +1,254 @@
+#!/usr/bin/env python
+# /*##########################################################################
+#
+# Copyright (c) 2016-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Main window used to compare images
+"""
+
+import logging
+import numpy
+import typing
+import os.path
+
+import silx.io
+from silx.gui import icons
+from silx.gui import qt
+from silx.gui.plot.CompareImages import CompareImages
+from silx.gui.widgets.UrlSelectionTable import UrlSelectionTable
+from ..utils import parseutils
+from silx.gui.plot.tools.profile.manager import ProfileManager
+from silx.gui.plot.tools.compare.profile import ProfileImageDirectedLineROI
+
+try:
+ import PIL
+except ImportError:
+ PIL = None
+
+
+_logger = logging.getLogger(__name__)
+
+
+def _get_image_from_file(urlPath: str) -> typing.Optional[numpy.ndarray]:
+ """Returns a dataset from an image file.
+
+ The returned layout shape is supposed to be `rows, columns, channels (rgb[a])`.
+ """
+ if PIL is None:
+ return None
+ return numpy.asarray(PIL.Image.open(urlPath))
+
+
+class CompareImagesWindow(qt.QMainWindow):
+ def __init__(self, backend=None, settings=None):
+ qt.QMainWindow.__init__(self, parent=None)
+ self.setWindowTitle("Silx compare")
+
+ silxIcon = icons.getQIcon("silx")
+ self.setWindowIcon(silxIcon)
+
+ self._plot = CompareImages(parent=self, backend=backend)
+ self._plot.setAutoResetZoom(False)
+
+ self.__manager = ProfileManager(self, self._plot.getPlot())
+ virtualItem = self._plot._getVirtualPlotItem()
+ self.__manager.setPlotItem(virtualItem)
+
+ directedLineAction = self.__manager.createProfileAction(
+ ProfileImageDirectedLineROI, self
+ )
+
+ profileToolBar = qt.QToolBar(self)
+ profileToolBar.setWindowTitle("Profile")
+ profileToolBar.addAction(directedLineAction)
+ self.__profileToolBar = profileToolBar
+ self._plot.addToolBar(profileToolBar)
+
+ self._selectionTable = UrlSelectionTable(parent=self)
+ self._selectionTable.setAcceptDrops(True)
+
+ self.__settings = settings
+ if settings:
+ self.restoreSettings(settings)
+
+ spliter = qt.QSplitter(self)
+ spliter.addWidget(self._selectionTable)
+ spliter.addWidget(self._plot)
+ spliter.setStretchFactor(1, 1)
+ spliter.setCollapsible(0, False)
+ spliter.setCollapsible(1, False)
+ self.__splitter = spliter
+
+ self.setCentralWidget(spliter)
+
+ self._selectionTable.sigImageAChanged.connect(self._updateImageA)
+ self._selectionTable.sigImageBChanged.connect(self._updateImageB)
+
+ def setUrls(self, urls):
+ self.clear()
+ for url in urls:
+ self._selectionTable.addUrl(url)
+ url1 = urls[0].path() if len(urls) >= 1 else None
+ url2 = urls[1].path() if len(urls) >= 2 else None
+ self._selectionTable.setUrlSelection(url_img_a=url1, url_img_b=url2)
+ self._plot.resetZoom()
+ self._plot.centerLines()
+
+ def clear(self):
+ self._plot.clear()
+ self._selectionTable.clear()
+
+ def _updateImageA(self, urlPath):
+ try:
+ data = self.readData(urlPath)
+ except Exception as e:
+ _logger.error("Error while loading URL %s", urlPath, exc_info=True)
+ self._selectionTable.setError(urlPath, e.args[0])
+ data = None
+ self._plot.setImage1(data)
+
+ def _updateImageB(self, urlPath):
+ try:
+ data = self.readData(urlPath)
+ except Exception as e:
+ _logger.error("Error while loading URL %s", urlPath, exc_info=True)
+ self._selectionTable.setError(urlPath, e.args[0])
+ data = None
+ self._plot.setImage2(data)
+
+ def readData(self, urlPath: str):
+ """Read an URL as an image"""
+ if urlPath in ("", None):
+ return None
+
+ data = None
+ _, ext = os.path.splitext(urlPath)
+ if ext in {".jpg", ".jpeg", ".png"}:
+ try:
+ data = _get_image_from_file(urlPath)
+ except Exception:
+ _logger.debug("Error while loading image with PIL", exc_info=True)
+
+ if data is None:
+ try:
+ data = silx.io.utils.get_data(urlPath)
+ except Exception:
+ raise ValueError(f"Data from '{urlPath}' is not readable")
+
+ if not isinstance(data, numpy.ndarray):
+ raise ValueError(f"URL '{urlPath}' does not link to a numpy array")
+ if data.dtype.kind not in set(["f", "u", "i", "b"]):
+ raise ValueError(f"URL '{urlPath}' does not link to a numeric numpy array")
+
+ if data.ndim == 2:
+ return data
+ if data.ndim == 3 and data.shape[2] in {3, 4}:
+ return data
+
+ raise ValueError(f"URL '{urlPath}' does not link to an numpy image")
+
+ def closeEvent(self, event):
+ settings = self.__settings
+ if settings:
+ self.saveSettings(self.__settings)
+
+ def saveSettings(self, settings):
+ """Save the window settings to this settings object
+
+ :param qt.QSettings settings: Initialized settings
+ """
+ isFullScreen = bool(self.windowState() & qt.Qt.WindowFullScreen)
+ if isFullScreen:
+ # show in normal to catch the normal geometry
+ self.showNormal()
+
+ settings.beginGroup("comparewindow")
+ settings.setValue("size", self.size())
+ settings.setValue("pos", self.pos())
+ settings.setValue("full-screen", isFullScreen)
+ settings.setValue("spliter", self.__splitter.sizes())
+ # NOTE: isInverted returns a numpy bool
+ settings.setValue(
+ "y-axis-inverted", bool(self._plot.getPlot().getYAxis().isInverted())
+ )
+
+ settings.setValue("visualization-mode", self._plot.getVisualizationMode().name)
+ settings.setValue("alignment-mode", self._plot.getAlignmentMode().name)
+ settings.setValue("display-keypoints", self._plot.getKeypointsVisible())
+
+ displayKeypoints = settings.value("display-keypoints", False)
+ displayKeypoints = parseutils.to_bool(displayKeypoints, False)
+
+ # self._plot.getAlignmentMode()
+ # self._plot.getVisualizationMode()
+ # self._plot.getKeypointsVisible()
+ settings.endGroup()
+
+ if isFullScreen:
+ self.showFullScreen()
+
+ def restoreSettings(self, settings):
+ """Restore the window settings using this settings object
+
+ :param qt.QSettings settings: Initialized settings
+ """
+ settings.beginGroup("comparewindow")
+ size = settings.value("size", qt.QSize(640, 480))
+ pos = settings.value("pos", qt.QPoint())
+ isFullScreen = settings.value("full-screen", False)
+ isFullScreen = parseutils.to_bool(isFullScreen, False)
+ yAxisInverted = settings.value("y-axis-inverted", False)
+ yAxisInverted = parseutils.to_bool(yAxisInverted, False)
+
+ visualizationMode = settings.value("visualization-mode", "")
+ visualizationMode = parseutils.to_enum(
+ visualizationMode,
+ CompareImages.VisualizationMode,
+ CompareImages.VisualizationMode.VERTICAL_LINE,
+ )
+ alignmentMode = settings.value("alignment-mode", "")
+ alignmentMode = parseutils.to_enum(
+ alignmentMode,
+ CompareImages.AlignmentMode,
+ CompareImages.AlignmentMode.ORIGIN,
+ )
+ displayKeypoints = settings.value("display-keypoints", False)
+ displayKeypoints = parseutils.to_bool(displayKeypoints, False)
+
+ try:
+ data = settings.value("spliter")
+ data = [int(d) for d in data]
+ self.__splitter.setSizes(data)
+ except Exception:
+ _logger.debug("Backtrace", exc_info=True)
+ settings.endGroup()
+
+ if not pos.isNull():
+ self.move(pos)
+ if not size.isNull():
+ self.resize(size)
+ if isFullScreen:
+ self.showFullScreen()
+ self._plot.setVisualizationMode(visualizationMode)
+ self._plot.setAlignmentMode(alignmentMode)
+ self._plot.setKeypointsVisible(displayKeypoints)
+ self._plot.getPlot().getYAxis().setInverted(yAxisInverted)
diff --git a/src/silx/gui/plot/matplotlib/__init__.py b/src/silx/app/compare/__init__.py
index 155ffd4..e5ec4c6 100644
--- a/src/silx/gui/plot/matplotlib/__init__.py
+++ b/src/silx/app/compare/__init__.py
@@ -1,6 +1,5 @@
# /*##########################################################################
-#
-# Copyright (c) 2020 European Synchrotron Radiation Facility
+# Copyright (C) 2022-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -20,17 +19,9 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
-# ###########################################################################*/
+# ############################################################################*/
+"""Package containing source code of the `silx compare` application"""
-__authors__ = ["T. Vincent"]
+__authors__ = ["V. Valls"]
__license__ = "MIT"
-__date__ = "15/07/2020"
-
-from silx.utils.deprecation import deprecated_warning
-
-deprecated_warning(type_='module',
- name=__file__,
- replacement='silx.gui.utils.matplotlib',
- since_version='0.14.0')
-
-from silx.gui.utils.matplotlib import FigureCanvasQTAgg # noqa
+__date__ = "04/13/2023"
diff --git a/src/silx/app/compare/main.py b/src/silx/app/compare/main.py
new file mode 100644
index 0000000..79c33f1
--- /dev/null
+++ b/src/silx/app/compare/main.py
@@ -0,0 +1,105 @@
+#!/usr/bin/env python
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""GUI to compare images"""
+
+import sys
+import logging
+import argparse
+import silx
+from silx.gui import qt
+from silx.app.utils import parseutils
+from silx.app.compare.CompareImagesWindow import CompareImagesWindow
+
+_logger = logging.getLogger(__name__)
+
+
+file_description = """
+Image data to compare (HDF5 file with path, EDF files, JPEG/PNG image files).
+Data from HDF5 files can be accessed using dataset path and slicing as an URL: silx:../my_file.h5?path=/entry/data&slice=10
+EDF file frames also can can be accessed using URL: fabio:../my_file.edf?slice=10
+Using URL in command like usually have to be quoted: "URL".
+"""
+
+
+def createParser():
+ parser = argparse.ArgumentParser(description=__doc__)
+ parser.add_argument("files", nargs=argparse.ZERO_OR_MORE, help=file_description)
+ parser.add_argument(
+ "--debug",
+ dest="debug",
+ action="store_true",
+ default=False,
+ help="Set logging system in debug mode",
+ )
+ parser.add_argument(
+ "--use-opengl-plot",
+ dest="use_opengl_plot",
+ action="store_true",
+ default=False,
+ help="Use OpenGL for plots (instead of matplotlib)",
+ )
+ return parser
+
+
+def mainQt(options):
+ """Part of the main depending on Qt"""
+ if options.debug:
+ logging.root.setLevel(logging.DEBUG)
+
+ if options.use_opengl_plot:
+ backend = "gl"
+ else:
+ backend = "mpl"
+
+ settings = qt.QSettings(
+ qt.QSettings.IniFormat, qt.QSettings.UserScope, "silx", "silx-compare", None
+ )
+
+ urls = list(parseutils.filenames_to_dataurls(options.files))
+
+ if options.use_opengl_plot:
+ # It have to be done after the settings (after the Viewer creation)
+ silx.config.DEFAULT_PLOT_BACKEND = "opengl"
+
+ app = qt.QApplication([])
+ window = CompareImagesWindow(backend=backend, settings=settings)
+ window.setAttribute(qt.Qt.WA_DeleteOnClose, True)
+
+ # Note: Have to be before setUrls to have a proper resetZoom
+ window.setVisible(True)
+
+ window.setUrls(urls)
+
+ app.exec()
+
+
+def main(argv):
+ parser = createParser()
+ options = parser.parse_args(argv[1:])
+ mainQt(options)
+
+
+if __name__ == "__main__":
+ main(sys.argv)
diff --git a/src/silx/app/compare/test/__init__.py b/src/silx/app/compare/test/__init__.py
new file mode 100644
index 0000000..1d8207b
--- /dev/null
+++ b/src/silx/app/compare/test/__init__.py
@@ -0,0 +1,23 @@
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
diff --git a/src/silx/third_party/scipy_spatial.py b/src/silx/app/compare/test/test_compare.py
index 13069b3..45c6838 100644
--- a/src/silx/third_party/scipy_spatial.py
+++ b/src/silx/app/compare/test/test_compare.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2017 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -21,28 +21,29 @@
# THE SOFTWARE.
#
# ###########################################################################*/
-"""Wrapper module for `scipy.spatial.Delaunay` class.
+"""Module testing silx.app.view"""
-Uses a local copy of `scipy.spatial.Delaunay` if available,
-else it loads it from `scipy`.
-
-It should be used like that:
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "07/06/2023"
-.. code-block::
- from silx.third_party.scipy_spatial import Delaunay
+import weakref
+import pytest
+from silx.app.compare.CompareImagesWindow import CompareImagesWindow
+from silx.gui.utils.testutils import TestCaseQt
-"""
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "07/11/2017"
+@pytest.mark.usefixtures("qapp")
+class TestCompare(TestCaseQt):
+ """Test for Viewer class"""
-try:
- # try to import silx local copy of Delaunay
- from ._local.scipy_spatial import Delaunay # noqa
-except ImportError:
- # else import it from the python path
- from scipy.spatial import Delaunay # noqa
+ def testConstruct(self):
+ widget = CompareImagesWindow()
+ self.qWaitForWindowExposed(widget)
-__all__ = ['Delaunay']
+ def testDestroy(self):
+ widget = CompareImagesWindow()
+ ref = weakref.ref(widget)
+ widget = None
+ self.qWaitForDestroy(ref)
diff --git a/src/silx/app/compare/test/test_launcher.py b/src/silx/app/compare/test/test_launcher.py
new file mode 100644
index 0000000..a42b762
--- /dev/null
+++ b/src/silx/app/compare/test/test_launcher.py
@@ -0,0 +1,142 @@
+# /*##########################################################################
+#
+# Copyright (c) 2016-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Module testing silx.app.view"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "07/06/2023"
+
+
+import os
+import sys
+import shutil
+import logging
+import subprocess
+import pytest
+
+from .. import main
+from silx import __main__ as silx_main
+
+_logger = logging.getLogger(__name__)
+
+
+def test_help(qapp):
+ # option -h must cause a raise SystemExit or a return 0
+ try:
+ parser = main.createParser()
+ parser.parse_args(["compare", "--help"])
+ result = 0
+ except SystemExit as e:
+ result = e.args[0]
+ assert result == 0
+
+
+def test_wrong_option(qapp):
+ try:
+ parser = main.createParser()
+ parser.parse_args(["compare", "--foo"])
+ assert False
+ except SystemExit as e:
+ result = e.args[0]
+ assert result != 0
+
+
+def test_wrong_file(qapp):
+ try:
+ parser = main.createParser()
+ result = parser.parse_args(["compare", "__file.not.found__"])
+ result = 0
+ except SystemExit as e:
+ result = e.args[0]
+ assert result == 0
+
+
+def _create_test_env():
+ """
+ Returns an associated environment with a working project.
+ """
+ env = dict((str(k), str(v)) for k, v in os.environ.items())
+ env["PYTHONPATH"] = os.pathsep.join(sys.path)
+ return env
+
+
+@pytest.fixture
+def execute_as_script(tmp_path):
+ """Execute a command line.
+
+ Log output as debug in case of bad return code.
+ """
+
+ def execute_as_script(filename, *args):
+ env = _create_test_env()
+
+ # Copy file to temporary dir to avoid import from current dir.
+ script = os.path.join(tmp_path, "launcher.py")
+ shutil.copyfile(filename, script)
+ command_line = [sys.executable, script] + list(args)
+
+ _logger.info("Execute: %s", " ".join(command_line))
+ p = subprocess.Popen(
+ command_line, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env
+ )
+ out, err = p.communicate()
+ _logger.info("Return code: %d", p.returncode)
+ try:
+ out = out.decode("utf-8")
+ except UnicodeError:
+ pass
+ try:
+ err = err.decode("utf-8")
+ except UnicodeError:
+ pass
+
+ if p.returncode != 0:
+ _logger.error("stdout:")
+ _logger.error("%s", out)
+ _logger.error("stderr:")
+ _logger.error("%s", err)
+ else:
+ _logger.debug("stdout:")
+ _logger.debug("%s", out)
+ _logger.debug("stderr:")
+ _logger.debug("%s", err)
+ assert p.returncode == 0
+
+ return execute_as_script
+
+
+def test_execute_compare_help(qapp, execute_as_script):
+ """Test if the main module is well connected.
+
+ Uses subprocess to avoid to parasite the current environment.
+ """
+ execute_as_script(main.__file__, "--help")
+
+
+def test_execute_silx_compare_help(qapp, execute_as_script):
+ """Test if the main module is well connected.
+
+ Uses subprocess to avoid to parasite the current environment.
+ """
+ execute_as_script(silx_main.__file__, "view", "--help")
diff --git a/src/silx/app/convert.py b/src/silx/app/convert.py
index 78c1ebf..e20a448 100644
--- a/src/silx/app/convert.py
+++ b/src/silx/app/convert.py
@@ -85,8 +85,10 @@ def drop_indices_before_begin(filenames, regex, begin):
m = re.match(regex, fname)
file_indices = list(map(int, m.groups()))
if len(file_indices) != len(begin_indices):
- raise IOError("Number of indices found in filename "
- "does not match number of parsed end indices.")
+ raise IOError(
+ "Number of indices found in filename "
+ "does not match number of parsed end indices."
+ )
good_indices = True
for i, fidx in enumerate(file_indices):
if fidx < begin_indices[i]:
@@ -110,8 +112,10 @@ def drop_indices_after_end(filenames, regex, end):
m = re.match(regex, fname)
file_indices = list(map(int, m.groups()))
if len(file_indices) != len(end_indices):
- raise IOError("Number of indices found in filename "
- "does not match number of parsed end indices.")
+ raise IOError(
+ "Number of indices found in filename "
+ "does not match number of parsed end indices."
+ )
good_indices = True
for i, fidx in enumerate(file_indices):
if fidx > end_indices[i]:
@@ -133,15 +137,17 @@ def are_files_missing_in_series(filenames, regex):
previous_indices = None
for fname in filenames:
m = re.match(regex, fname)
- assert m is not None, \
- "regex %s does not match filename %s" % (fname, regex)
+ assert m is not None, "regex %s does not match filename %s" % (fname, regex)
new_indices = list(map(int, m.groups()))
if previous_indices is not None:
for old_idx, new_idx in zip(previous_indices, new_indices):
if (new_idx - old_idx) > 1:
- _logger.error("Index increment > 1 in file series: "
- "previous idx %d, next idx %d",
- old_idx, new_idx)
+ _logger.error(
+ "Index increment > 1 in file series: "
+ "previous idx %d, next idx %d",
+ old_idx,
+ new_idx,
+ )
return True
previous_indices = new_indices
return False
@@ -196,116 +202,134 @@ def main(argv):
"""
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
- 'input_files',
+ "input_files",
nargs="*",
- help='Input files (EDF, TIFF, FIO, SPEC...). When specifying '
- 'multiple files, you cannot specify both fabio images '
- 'and SPEC (or FIO) files. Multiple SPEC or FIO files will '
- 'simply be concatenated, with one entry per scan. '
- 'Multiple image files will be merged into a single '
- 'entry with a stack of images.')
+ help="Input files (EDF, TIFF, FIO, SPEC...). When specifying "
+ "multiple files, you cannot specify both fabio images "
+ "and SPEC (or FIO) files. Multiple SPEC or FIO files will "
+ "simply be concatenated, with one entry per scan. "
+ "Multiple image files will be merged into a single "
+ "entry with a stack of images.",
+ )
# input_files and --filepattern are mutually exclusive
parser.add_argument(
- '--file-pattern',
- help='File name pattern for loading a series of indexed image files '
- '(toto_%%04d.edf). This argument is incompatible with argument '
- 'input_files. If an output URI with a HDF5 path is provided, '
- 'only the content of the NXdetector group will be copied there. '
- 'If no HDF5 path, or just "/", is given, a complete NXdata '
- 'structure will be created.')
+ "--file-pattern",
+ help="File name pattern for loading a series of indexed image files "
+ "(toto_%%04d.edf). This argument is incompatible with argument "
+ "input_files. If an output URI with a HDF5 path is provided, "
+ "only the content of the NXdetector group will be copied there. "
+ 'If no HDF5 path, or just "/", is given, a complete NXdata '
+ "structure will be created.",
+ )
parser.add_argument(
- '-o', '--output-uri',
- default=time.strftime("%Y%m%d-%H%M%S") + '.h5',
- help='Output file name (HDF5). An URI can be provided to write'
- ' the data into a specific group in the output file: '
- '/path/to/file::/path/to/group. '
- 'If not provided, the filename defaults to a timestamp:'
- ' YYYYmmdd-HHMMSS.h5')
+ "-o",
+ "--output-uri",
+ default=time.strftime("%Y%m%d-%H%M%S") + ".h5",
+ help="Output file name (HDF5). An URI can be provided to write"
+ " the data into a specific group in the output file: "
+ "/path/to/file::/path/to/group. "
+ "If not provided, the filename defaults to a timestamp:"
+ " YYYYmmdd-HHMMSS.h5",
+ )
parser.add_argument(
- '-m', '--mode',
+ "-m",
+ "--mode",
default="w-",
help='Write mode: "r+" (read/write, file must exist), '
- '"w" (write, existing file is lost), '
- '"w-" (write, fail if file exists) or '
- '"a" (read/write if exists, create otherwise)')
+ '"w" (write, existing file is lost), '
+ '"w-" (write, fail if file exists) or '
+ '"a" (read/write if exists, create otherwise)',
+ )
parser.add_argument(
- '--begin',
- help='First file index, or first file indices to be considered. '
- 'This argument only makes sense when used together with '
- '--file-pattern. Provide as many start indices as there '
- 'are indices in the file pattern, separated by commas. '
- 'Examples: "--filepattern toto_%%d.edf --begin 100", '
- ' "--filepattern toto_%%d_%%04d_%%02d.edf --begin 100,2000,5".')
+ "--begin",
+ help="First file index, or first file indices to be considered. "
+ "This argument only makes sense when used together with "
+ "--file-pattern. Provide as many start indices as there "
+ "are indices in the file pattern, separated by commas. "
+ 'Examples: "--filepattern toto_%%d.edf --begin 100", '
+ ' "--filepattern toto_%%d_%%04d_%%02d.edf --begin 100,2000,5".',
+ )
parser.add_argument(
- '--end',
- help='Last file index, or last file indices to be considered. '
- 'The same rules as with argument --begin apply. '
- 'Example: "--filepattern toto_%%d_%%d.edf --end 199,1999"')
+ "--end",
+ help="Last file index, or last file indices to be considered. "
+ "The same rules as with argument --begin apply. "
+ 'Example: "--filepattern toto_%%d_%%d.edf --end 199,1999"',
+ )
parser.add_argument(
- '--add-root-group',
+ "--add-root-group",
action="store_true",
- help='This option causes each input file to be written to a '
- 'specific root group with the same name as the file. When '
- 'merging multiple input files, this can help preventing conflicts'
- ' when datasets have the same name (see --overwrite-data). '
- 'This option is ignored when using --file-pattern.')
+ help="This option causes each input file to be written to a "
+ "specific root group with the same name as the file. When "
+ "merging multiple input files, this can help preventing conflicts"
+ " when datasets have the same name (see --overwrite-data). "
+ "This option is ignored when using --file-pattern.",
+ )
parser.add_argument(
- '--overwrite-data',
+ "--overwrite-data",
action="store_true",
- help='If the output path exists and an input dataset has the same'
- ' name as an existing output dataset, overwrite the output '
- 'dataset (in modes "r+" or "a").')
+ help="If the output path exists and an input dataset has the same"
+ " name as an existing output dataset, overwrite the output "
+ 'dataset (in modes "r+" or "a").',
+ )
parser.add_argument(
- '--min-size',
+ "--min-size",
type=int,
default=500,
- help='Minimum number of elements required to be in a dataset to '
- 'apply compression or chunking (default 500).')
+ help="Minimum number of elements required to be in a dataset to "
+ "apply compression or chunking (default 500).",
+ )
parser.add_argument(
- '--chunks',
+ "--chunks",
nargs="?",
const="auto",
- help='Chunk shape. Provide an argument that evaluates as a python '
- 'tuple (e.g. "(1024, 768)"). If this option is provided without '
- 'specifying an argument, the h5py library will guess a chunk for '
- 'you. Note that if you specify an explicit chunking shape, it '
- 'will be applied identically to all datasets with a large enough '
- 'size (see --min-size). ')
+ help="Chunk shape. Provide an argument that evaluates as a python "
+ 'tuple (e.g. "(1024, 768)"). If this option is provided without '
+ "specifying an argument, the h5py library will guess a chunk for "
+ "you. Note that if you specify an explicit chunking shape, it "
+ "will be applied identically to all datasets with a large enough "
+ "size (see --min-size). ",
+ )
parser.add_argument(
- '--compression',
+ "--compression",
nargs="?",
const="gzip",
- help='Compression filter. By default, the datasets in the output '
- 'file are not compressed. If this option is specified without '
- 'argument, the GZIP compression is used. Additional compression '
- 'filters may be available, depending on your HDF5 installation.')
+ help="Compression filter. By default, the datasets in the output "
+ "file are not compressed. If this option is specified without "
+ "argument, the GZIP compression is used. Additional compression "
+ "filters may be available, depending on your HDF5 installation.",
+ )
def check_gzip_compression_opts(value):
ivalue = int(value)
if ivalue < 0 or ivalue > 9:
raise argparse.ArgumentTypeError(
- "--compression-opts must be an int from 0 to 9")
+ "--compression-opts must be an int from 0 to 9"
+ )
return ivalue
parser.add_argument(
- '--compression-opts',
+ "--compression-opts",
type=check_gzip_compression_opts,
help='Compression options. For "gzip", this may be an integer from '
- '0 to 9, with a default of 4. This is only supported for GZIP.')
+ "0 to 9, with a default of 4. This is only supported for GZIP.",
+ )
parser.add_argument(
- '--shuffle',
+ "--shuffle",
action="store_true",
- help='Enables the byte shuffle filter. This may improve the compression '
- 'ratio for block oriented compressors like GZIP or LZF.')
+ help="Enables the byte shuffle filter. This may improve the compression "
+ "ratio for block oriented compressors like GZIP or LZF.",
+ )
parser.add_argument(
- '--fletcher32',
+ "--fletcher32",
action="store_true",
- help='Adds a checksum to each chunk to detect data corruption.')
+ help="Adds a checksum to each chunk to detect data corruption.",
+ )
parser.add_argument(
- '--debug',
+ "--debug",
action="store_true",
default=False,
- help='Set logging system in debug mode')
+ help="Set logging system in debug mode",
+ )
options = parser.parse_args(argv[1:])
@@ -329,8 +353,10 @@ def main(argv):
write_to_h5 = None
if hdf5plugin is None:
- message = "Module 'hdf5plugin' is not installed. It supports additional hdf5"\
- + " compressions. You can install it using \"pip install hdf5plugin\"."
+ message = (
+ "Module 'hdf5plugin' is not installed. It supports additional hdf5"
+ + ' compressions. You can install it using "pip install hdf5plugin".'
+ )
_logger.debug(message)
# Process input arguments (mutually exclusive arguments)
@@ -360,33 +386,40 @@ def main(argv):
dirname = os.path.dirname(options.file_pattern)
file_pattern_re = c_format_string_to_re(options.file_pattern) + "$"
files_in_dir = glob(os.path.join(dirname, "*"))
- _logger.debug("""
+ _logger.debug(
+ """
Processing file_pattern
dirname: %s
file_pattern_re: %s
files_in_dir: %s
- """, dirname, file_pattern_re, files_in_dir)
-
- options.input_files = sorted(list(filter(lambda name: re.match(file_pattern_re, name),
- files_in_dir)))
+ """,
+ dirname,
+ file_pattern_re,
+ files_in_dir,
+ )
+
+ options.input_files = sorted(
+ list(filter(lambda name: re.match(file_pattern_re, name), files_in_dir))
+ )
_logger.debug("options.input_files: %s", options.input_files)
if options.begin is not None:
- options.input_files = drop_indices_before_begin(options.input_files,
- file_pattern_re,
- options.begin)
- _logger.debug("options.input_files after applying --begin: %s",
- options.input_files)
+ options.input_files = drop_indices_before_begin(
+ options.input_files, file_pattern_re, options.begin
+ )
+ _logger.debug(
+ "options.input_files after applying --begin: %s", options.input_files
+ )
if options.end is not None:
- options.input_files = drop_indices_after_end(options.input_files,
- file_pattern_re,
- options.end)
- _logger.debug("options.input_files after applying --end: %s",
- options.input_files)
-
- if are_files_missing_in_series(options.input_files,
- file_pattern_re):
+ options.input_files = drop_indices_after_end(
+ options.input_files, file_pattern_re, options.end
+ )
+ _logger.debug(
+ "options.input_files after applying --end: %s", options.input_files
+ )
+
+ if are_files_missing_in_series(options.input_files, file_pattern_re):
_logger.error("File missing in the file series. Aborting.")
return -1
@@ -402,37 +435,39 @@ def main(argv):
if os.path.isfile(output_name):
if options.mode == "w-":
- _logger.error("Output file %s exists and mode is 'w-' (default)."
- " Aborting. To append data to an existing file, "
- "use 'a' or 'r+'.",
- output_name)
+ _logger.error(
+ "Output file %s exists and mode is 'w-' (default)."
+ " Aborting. To append data to an existing file, "
+ "use 'a' or 'r+'.",
+ output_name,
+ )
return -1
elif not os.access(output_name, os.W_OK):
- _logger.error("Output file %s exists and is not writeable.",
- output_name)
+ _logger.error("Output file %s exists and is not writeable.", output_name)
return -1
elif options.mode == "w":
- _logger.info("Output file %s exists and mode is 'w'. "
- "Overwriting existing file.", output_name)
+ _logger.info(
+ "Output file %s exists and mode is 'w'. " "Overwriting existing file.",
+ output_name,
+ )
elif options.mode in ["a", "r+"]:
- _logger.info("Appending data to existing file %s.",
- output_name)
+ _logger.info("Appending data to existing file %s.", output_name)
else:
if options.mode == "r+":
- _logger.error("Output file %s does not exist and mode is 'r+'"
- " (append, file must exist). Aborting.",
- output_name)
+ _logger.error(
+ "Output file %s does not exist and mode is 'r+'"
+ " (append, file must exist). Aborting.",
+ output_name,
+ )
return -1
else:
- _logger.info("Creating new output file %s.",
- output_name)
+ _logger.info("Creating new output file %s.", output_name)
# Test that all input files exist and are readable
bad_input = False
for fname in options.input_files:
if not os.access(fname, os.R_OK):
- _logger.error("Cannot read input file %s.",
- fname)
+ _logger.error("Cannot read input file %s.", fname)
bad_input = True
if bad_input:
_logger.error("Aborting.")
@@ -456,10 +491,12 @@ def main(argv):
nitems = numpy.prod(chunks)
nbytes = nitems * 8
if nbytes > 10**6:
- _logger.warning("Requested chunk size might be larger than"
- " the default 1MB chunk cache, for float64"
- " data. This can dramatically affect I/O "
- "performances.")
+ _logger.warning(
+ "Requested chunk size might be larger than"
+ " the default 1MB chunk cache, for float64"
+ " data. This can dramatically affect I/O "
+ "performances."
+ )
create_dataset_args["chunks"] = chunks
if options.compression is not None:
@@ -478,61 +515,78 @@ def main(argv):
if options.fletcher32:
create_dataset_args["fletcher32"] = True
- if (len(options.input_files) > 1 and
- not contains_specfile(options.input_files) and
- not contains_fiofile(options.input_files) and
- not options.add_root_group) or options.file_pattern is not None:
+ if (
+ len(options.input_files) > 1
+ and not contains_specfile(options.input_files)
+ and not contains_fiofile(options.input_files)
+ and not options.add_root_group
+ ) or options.file_pattern is not None:
# File series -> stack of images
input_group = fabioh5.File(file_series=options.input_files)
if hdf5_path != "/":
# we want to append only data and headers to an existing file
input_group = input_group["/scan_0/instrument/detector_0"]
with h5py.File(output_name, mode=options.mode) as h5f:
- write_to_h5(input_group, h5f,
- h5path=hdf5_path,
- overwrite_data=options.overwrite_data,
- create_dataset_args=create_dataset_args,
- min_size=options.min_size)
-
- elif len(options.input_files) == 1 or \
- are_all_specfile(options.input_files) or\
- are_all_fiofile(options.input_files) or\
- options.add_root_group:
+ write_to_h5(
+ input_group,
+ h5f,
+ h5path=hdf5_path,
+ overwrite_data=options.overwrite_data,
+ create_dataset_args=create_dataset_args,
+ min_size=options.min_size,
+ )
+
+ elif (
+ len(options.input_files) == 1
+ or are_all_specfile(options.input_files)
+ or are_all_fiofile(options.input_files)
+ or options.add_root_group
+ ):
# single file, or spec files
h5paths_and_groups = []
for input_name in options.input_files:
hdf5_path_for_file = hdf5_path
if options.add_root_group:
- hdf5_path_for_file = hdf5_path.rstrip("/") + "/" + os.path.basename(input_name)
+ hdf5_path_for_file = (
+ hdf5_path.rstrip("/") + "/" + os.path.basename(input_name)
+ )
try:
- h5paths_and_groups.append((hdf5_path_for_file,
- silx.io.open(input_name)))
+ h5paths_and_groups.append(
+ (hdf5_path_for_file, silx.io.open(input_name))
+ )
except IOError:
- _logger.error("Cannot read file %s. If this is a file format "
- "supported by the fabio library, you can try to"
- " install fabio (`pip install fabio`)."
- " Aborting conversion.",
- input_name)
+ _logger.error(
+ "Cannot read file %s. If this is a file format "
+ "supported by the fabio library, you can try to"
+ " install fabio (`pip install fabio`)."
+ " Aborting conversion.",
+ input_name,
+ )
return -1
with h5py.File(output_name, mode=options.mode) as h5f:
for hdf5_path_for_file, input_group in h5paths_and_groups:
- write_to_h5(input_group, h5f,
- h5path=hdf5_path_for_file,
- overwrite_data=options.overwrite_data,
- create_dataset_args=create_dataset_args,
- min_size=options.min_size)
+ write_to_h5(
+ input_group,
+ h5f,
+ h5path=hdf5_path_for_file,
+ overwrite_data=options.overwrite_data,
+ create_dataset_args=create_dataset_args,
+ min_size=options.min_size,
+ )
else:
# multiple file, SPEC and fabio images mixed
- _logger.error("Multiple files with incompatible formats specified. "
- "You can provide multiple SPEC files or multiple image "
- "files, but not both.")
+ _logger.error(
+ "Multiple files with incompatible formats specified. "
+ "You can provide multiple SPEC files or multiple image "
+ "files, but not both."
+ )
return -1
with h5py.File(output_name, mode="r+") as h5f:
# append "silx convert" to the creator attribute, for NeXus files
- previous_creator = h5f.attrs.get("creator", u"")
+ previous_creator = h5f.attrs.get("creator", "")
creator = "silx convert (v%s)" % silx.version
# only if it not already there
if creator not in previous_creator:
@@ -541,7 +595,7 @@ def main(argv):
else:
new_creator = previous_creator + "; " + creator
h5f.attrs["creator"] = numpy.array(
- new_creator,
- dtype=h5py.special_dtype(vlen=str))
+ new_creator, dtype=h5py.special_dtype(vlen=str)
+ )
return 0
diff --git a/src/silx/app/test/test_convert.py b/src/silx/app/test/test_convert.py
index f3ca269..7ff94a3 100644
--- a/src/silx/app/test/test_convert.py
+++ b/src/silx/app/test/test_convert.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -29,7 +29,6 @@ __date__ = "17/01/2018"
import os
-import sys
import tempfile
import unittest
import io
@@ -120,16 +119,12 @@ class TestConvertCommand(unittest.TestCase):
# write a temporary SPEC file
specname = os.path.join(tempdir, "input.dat")
with io.open(specname, "wb") as fd:
- if sys.version_info < (3, ):
- fd.write(sftext)
- else:
- fd.write(bytes(sftext, 'ascii'))
+ fd.write(bytes(sftext, "ascii"))
# convert it
h5name = os.path.join(tempdir, "output.h5")
assert not os.path.isfile(h5name)
- command_list = ["convert", "-m", "w",
- specname, "-o", h5name]
+ command_list = ["convert", "-m", "w", specname, "-o", h5name]
result = convert.main(command_list)
self.assertEqual(result, 0)
@@ -137,15 +132,10 @@ class TestConvertCommand(unittest.TestCase):
with h5py.File(h5name, "r") as h5f:
title12 = h5py_read_dataset(h5f["/1.2/title"])
- if sys.version_info < (3, ):
- title12 = title12.encode("utf-8")
- self.assertEqual(title12,
- "aaaaaa")
+ self.assertEqual(title12, "aaaaaa")
creator = h5f.attrs.get("creator")
self.assertIsNotNone(creator, "No creator attribute in NXroot group")
- if sys.version_info < (3, ):
- creator = creator.encode("utf-8")
self.assertIn("silx convert (v%s)" % silx.version, creator)
# delete input file
diff --git a/src/silx/app/view/utils.py b/src/silx/app/utils/__init__.py
index 6a980e9..97ef4a5 100644
--- a/src/silx/app/view/utils.py
+++ b/src/silx/app/utils/__init__.py
@@ -1,5 +1,5 @@
# /*##########################################################################
-# Copyright (C) 2018 European Synchrotron Radiation Facility
+# Copyright (C) 2016-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -20,25 +20,8 @@
# THE SOFTWARE.
#
# ############################################################################*/
-"""Browse a data file with a GUI"""
+"""Package containing utils related to applications"""
__authors__ = ["V. Valls"]
__license__ = "MIT"
-__date__ = "28/05/2018"
-
-
-_trueStrings = set(["yes", "true", "1"])
-_falseStrings = set(["no", "false", "0"])
-
-
-def stringToBool(string):
- """Returns a boolean from a string.
-
- :raise ValueError: If the string do not contains a boolean information.
- """
- lower = string.lower()
- if lower in _trueStrings:
- return True
- if lower in _falseStrings:
- return False
- raise ValueError("'%s' is not a valid boolean" % string)
+__date__ = "07/06/2023"
diff --git a/src/silx/app/utils/parseutils.py b/src/silx/app/utils/parseutils.py
new file mode 100644
index 0000000..4135290
--- /dev/null
+++ b/src/silx/app/utils/parseutils.py
@@ -0,0 +1,133 @@
+# /*##########################################################################
+# Copyright (C) 2018-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""Utils related to parsing"""
+
+from __future__ import annotations
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "28/05/2018"
+
+from collections.abc import Sequence
+import glob
+import logging
+from typing import Generator, Iterable, Union, Any, Optional
+from pathlib import Path
+
+
+_logger = logging.getLogger(__name__)
+"""Module logger"""
+
+
+_trueStrings = {"yes", "true", "1"}
+_falseStrings = {"no", "false", "0"}
+
+
+def _string_to_bool(string: str) -> bool:
+ """Returns a boolean from a string.
+
+ :raise ValueError: If the string do not contains a boolean information.
+ """
+ lower = string.lower()
+ if lower in _trueStrings:
+ return True
+ if lower in _falseStrings:
+ return False
+ raise ValueError("'%s' is not a valid boolean" % string)
+
+
+def to_bool(thing: Any, default: Optional[bool] = None) -> bool:
+ """Returns a boolean from an object.
+
+ :raise ValueError: If the thing can't be interpreted as a boolean and
+ no default is set
+ """
+ if isinstance(thing, bool):
+ return thing
+ try:
+ return _string_to_bool(thing)
+ except ValueError:
+ if default is not None:
+ return default
+ raise
+
+
+def filenames_to_dataurls(
+ filenames: Iterable[Union[str, Path]],
+ slices: Sequence[int] = tuple(),
+) -> Generator[object, None, None]:
+ """Expand filenames and HDF5 data path in files input argument"""
+ # Imports here so they are performed after setting HDF5_USE_FILE_LOCKING and logging level
+ import silx.io
+ from silx.io.utils import match
+ from silx.io.url import DataUrl
+ import silx.utils.files
+
+ extra_slices = tuple(slices)
+
+ for filename in filenames:
+ url = DataUrl(filename)
+
+ for file_path in sorted(silx.utils.files.expand_filenames([url.file_path()])):
+ if url.data_path() is not None and glob.has_magic(url.data_path()):
+ try:
+ with silx.io.open(file_path) as f:
+ data_paths = list(match(f, url.data_path()))
+ except BaseException as e:
+ _logger.error(
+ f"Error searching HDF5 path pattern '{url.data_path()}' in file '{file_path}': Ignored"
+ )
+ _logger.error(e.args[0])
+ _logger.debug("Backtrace", exc_info=True)
+ continue
+ else:
+ data_paths = [url.data_path()]
+
+ if not extra_slices:
+ data_slices = (url.data_slice(),)
+ elif not url.data_slice():
+ data_slices = extra_slices
+ else:
+ data_slices = [tuple(url.data_slice()) + (s,) for s in extra_slices]
+
+ for data_path in data_paths:
+ for data_slice in data_slices:
+ yield DataUrl(
+ file_path=file_path,
+ data_path=data_path,
+ data_slice=data_slice,
+ scheme=url.scheme(),
+ )
+
+
+def to_enum(thing: Any, enum_type, default: Optional[object] = None):
+ """Parse this string as this enum_type."""
+ try:
+ v = getattr(enum_type, str(thing))
+ if isinstance(v, enum_type):
+ return v
+ raise ValueError(f"{thing} is not a {enum_type.__name__}")
+ except (AttributeError, ValueError) as e:
+ if default is not None:
+ return default
+ raise
diff --git a/src/silx/app/utils/test/__init__.py b/src/silx/app/utils/test/__init__.py
new file mode 100644
index 0000000..f94d0a3
--- /dev/null
+++ b/src/silx/app/utils/test/__init__.py
@@ -0,0 +1,23 @@
+# /*##########################################################################
+#
+# Copyright (c) 2016-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
diff --git a/src/silx/app/utils/test/test_parseutils.py b/src/silx/app/utils/test/test_parseutils.py
new file mode 100644
index 0000000..9570bb7
--- /dev/null
+++ b/src/silx/app/utils/test/test_parseutils.py
@@ -0,0 +1,68 @@
+# /*##########################################################################
+# Copyright (C) 2018-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+
+import pytest
+import h5py
+from ..parseutils import filenames_to_dataurls
+
+
+@pytest.fixture(scope="module")
+def data_path(tmp_path_factory):
+ tmp_path = tmp_path_factory.mktemp("silx_app_utils")
+ with h5py.File(tmp_path / "test1.h5", "w") as h5:
+ h5["g1/sub1/data1"] = 1
+ h5["g1/sub1/data2"] = 2
+ h5["g1/sub2/data1"] = 3
+ return tmp_path
+
+
+def test_h5__datapath(data_path):
+ urls = filenames_to_dataurls([data_path / "test1.h5::g1/sub1/data1"])
+ urls = list(urls)
+ assert len(urls) == 1
+ assert urls[0].data_path().replace("\\", "/") == "g1/sub1/data1"
+
+
+def test_h5__datapath_not_existing(data_path):
+ urls = filenames_to_dataurls([data_path / "test1.h5::g1/sub0/data1"])
+ urls = list(urls)
+ assert len(urls) == 1
+ assert urls[0].data_path().replace("\\", "/") == "g1/sub0/data1"
+
+
+def test_h5__datapath_with_magic(data_path):
+ urls = filenames_to_dataurls([data_path / "test1.h5::g1/sub*/data*"])
+ urls = list(urls)
+ assert len(urls) == 3
+
+
+def test_h5__datapath_with_magic_not_existing(data_path):
+ urls = filenames_to_dataurls([data_path / "test1.h5::g1/sub0/data*"])
+ urls = list(urls)
+ assert len(urls) == 0
+
+
+def test_h5__datapath_with_recursive_magic(data_path):
+ urls = filenames_to_dataurls([data_path / "test1.h5::**/data1"])
+ urls = list(urls)
+ assert len(urls) == 2
diff --git a/src/silx/app/view/About.py b/src/silx/app/view/About.py
index 2af7ed4..76e0cf2 100644
--- a/src/silx/app/view/About.py
+++ b/src/silx/app/view/About.py
@@ -115,10 +115,9 @@ class About(qt.QDialog):
:rtype: str
"""
from silx._version import __date__ as date
+
year = date.split("/")[2]
- info = dict(
- year=year
- )
+ info = dict(year=year)
textLicense = _LICENSE_TEMPLATE.format(**info)
return textLicense
@@ -191,6 +190,7 @@ class About(qt.QDialog):
# Previous versions only return True if the filter was first used
# to decode a dataset
import h5py.h5z
+
FILTER_LZ4 = 32004
FILTER_BITSHUFFLE = 32008
filters = [
@@ -201,7 +201,11 @@ class About(qt.QDialog):
isAvailable = h5py.h5z.filter_avail(filterId)
optionals.append(self.__formatOptionalFilters(name, isAvailable))
else:
- optionals.append(self.__formatOptionalLibraries("hdf5plugin", "hdf5plugin" in sys.modules))
+ optionals.append(
+ self.__formatOptionalLibraries(
+ "hdf5plugin", "hdf5plugin" in sys.modules
+ )
+ )
# Access to the logo in SVG or PNG
logo = icons.getQFile("silx:" + os.path.join("gui", "logo", "silx"))
@@ -217,7 +221,7 @@ class About(qt.QDialog):
qt_version=qt.qVersion(),
python_version=sys.version.replace("\n", "<br />"),
optional_lib="<br />".join(optionals),
- silx_image_path=logo.fileName()
+ silx_image_path=logo.fileName(),
)
self.__label.setText(message.format(**info))
@@ -225,10 +229,14 @@ class About(qt.QDialog):
def __updateSize(self):
"""Force the size to a QMessageBox like size."""
- if qt.BINDING in ("PySide2", "PyQt5"):
- screenSize = qt.QApplication.desktop().availableGeometry(qt.QCursor.pos()).size()
+ if qt.BINDING == "PyQt5":
+ screenSize = (
+ qt.QApplication.desktop().availableGeometry(qt.QCursor.pos()).size()
+ )
else: # Qt6
- screenSize = qt.QApplication.instance().primaryScreen().availableGeometry().size()
+ screenSize = (
+ qt.QApplication.instance().primaryScreen().availableGeometry().size()
+ )
hardLimit = min(screenSize.width() - 480, 1000)
if screenSize.width() <= 1024:
hardLimit = screenSize.width()
diff --git a/src/silx/app/view/ApplicationContext.py b/src/silx/app/view/ApplicationContext.py
index 30dad7d..157b8cc 100644
--- a/src/silx/app/view/ApplicationContext.py
+++ b/src/silx/app/view/ApplicationContext.py
@@ -1,5 +1,5 @@
# /*##########################################################################
-# Copyright (C) 2016-2018 European Synchrotron Radiation Facility
+# Copyright (C) 2016-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -28,6 +28,7 @@ __date__ = "23/05/2018"
import weakref
import logging
+from collections.abc import Sequence
import silx
from silx.gui.data.DataViews import DataViewHooks
@@ -69,15 +70,20 @@ class ApplicationContext(DataViewHooks):
if settings is None:
return
settings.beginGroup("library")
+ mplTightLayout = settings.value("mpl.tight_layout", False, bool)
plotBackend = settings.value("plot.backend", "")
plotImageYAxisOrientation = settings.value("plot-image.y-axis-orientation", "")
settings.endGroup()
# Use matplotlib backend by default
- silx.config.DEFAULT_PLOT_BACKEND = \
- "opengl" if plotBackend == "opengl" else "matplotlib"
+ silx.config.DEFAULT_PLOT_BACKEND = (
+ ("opengl", "matplotlib") if plotBackend == "opengl" else "matplotlib"
+ )
if plotImageYAxisOrientation != "":
- silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION = plotImageYAxisOrientation
+ silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION = (
+ plotImageYAxisOrientation
+ )
+ silx.config._MPL_TIGHT_LAYOUT = mplTightLayout
def restoreSettings(self):
"""Restore the settings of all the application"""
@@ -121,8 +127,12 @@ class ApplicationContext(DataViewHooks):
settings.endGroup()
settings.beginGroup("library")
- settings.setValue("plot.backend", silx.config.DEFAULT_PLOT_BACKEND)
- settings.setValue("plot-image.y-axis-orientation", silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION)
+ settings.setValue("plot.backend", self.getDefaultPlotBackend())
+ settings.setValue(
+ "plot-image.y-axis-orientation",
+ silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION,
+ )
+ settings.setValue("mpl.tight_layout", silx.config._MPL_TIGHT_LAYOUT)
settings.endGroup()
settings.beginGroup("recent-files")
@@ -162,8 +172,7 @@ class ApplicationContext(DataViewHooks):
self.__recentFiles.pop()
def clearRencentFiles(self):
- """Clear the history of the rencent files.
- """
+ """Clear the history of the rencent files."""
self.__recentFiles[:] = []
def getColormap(self, view):
@@ -192,3 +201,17 @@ class ApplicationContext(DataViewHooks):
dialog.setModal(False)
self.__defaultColormapDialog = dialog
return self.__defaultColormapDialog
+
+ @staticmethod
+ def getDefaultPlotBackend() -> str:
+ """Returns default plot backend as a str from current config"""
+ backend = silx.config.DEFAULT_PLOT_BACKEND
+ if isinstance(backend, str):
+ return backend
+ if (
+ isinstance(backend, Sequence)
+ and len(backend)
+ and isinstance(backend[0], str)
+ ):
+ return backend[0]
+ return "matplotlib" # fallback
diff --git a/src/silx/app/view/CustomNxdataWidget.py b/src/silx/app/view/CustomNxdataWidget.py
index 3c79c0d..3ec62c0 100644
--- a/src/silx/app/view/CustomNxdataWidget.py
+++ b/src/silx/app/view/CustomNxdataWidget.py
@@ -568,7 +568,7 @@ class _Model(qt.QStandardItemModel):
"""
if isinstance(item, _NxDataItem):
parent = item.parent()
- assert(parent is None)
+ assert parent is None
model = item.model()
model.removeRow(item.row())
else:
@@ -693,7 +693,7 @@ class CustomNxDataToolBar(qt.QToolBar):
def setCustomNxDataWidget(self, widget):
"""Set the linked CustomNxdataWidget to this toolbar."""
- assert(isinstance(widget, CustomNxdataWidget))
+ assert isinstance(widget, CustomNxdataWidget)
if self.__nxdataWidget is not None:
selectionModel = self.__nxdataWidget.selectionModel()
selectionModel.currentChanged.disconnect(self.__currentSelectionChanged)
@@ -713,7 +713,9 @@ class CustomNxDataToolBar(qt.QToolBar):
item = model.itemFromIndex(index)
self.__removeNxDataAction.setEnabled(isinstance(item, _NxDataItem))
self.__removeNxDataAxisAction.setEnabled(isinstance(item, _DatasetAxisItemRow))
- self.__addNxDataAxisAction.setEnabled(isinstance(item, _NxDataItem) or isinstance(item, _DatasetItemRow))
+ self.__addNxDataAxisAction.setEnabled(
+ isinstance(item, _NxDataItem) or isinstance(item, _DatasetItemRow)
+ )
class _HashDropZones(qt.QStyledItemDelegate):
@@ -847,7 +849,9 @@ class CustomNxdataWidget(qt.QTreeView):
if isinstance(item, _NxDataItem):
action = qt.QAction("Add a new axis", menu)
- action.triggered.connect(lambda: weakself.model().appendAxisToNxdataItem(item))
+ action.triggered.connect(
+ lambda: weakself.model().appendAxisToNxdataItem(item)
+ )
action.setIcon(icons.getQIcon("nxdata-axis-add"))
action.setIconVisibleInMenu(True)
menu.addAction(action)
diff --git a/src/silx/app/view/DataPanel.py b/src/silx/app/view/DataPanel.py
index d4a0e63..592a520 100644
--- a/src/silx/app/view/DataPanel.py
+++ b/src/silx/app/view/DataPanel.py
@@ -37,7 +37,6 @@ _logger = logging.getLogger(__name__)
class _HeaderLabel(qt.QLabel):
-
def __init__(self, parent=None):
qt.QLabel.__init__(self, parent=parent)
self.setFrameShape(qt.QFrame.StyledPanel)
@@ -89,7 +88,6 @@ class _HeaderLabel(qt.QLabel):
class DataPanel(qt.QWidget):
-
def __init__(self, parent=None, context=None):
qt.QWidget.__init__(self, parent=parent)
diff --git a/src/silx/app/view/Viewer.py b/src/silx/app/view/Viewer.py
index d9ecf6a..12426a1 100644
--- a/src/silx/app/view/Viewer.py
+++ b/src/silx/app/view/Viewer.py
@@ -1,5 +1,5 @@
# /*##########################################################################
-# Copyright (C) 2016-2022 European Synchrotron Radiation Facility
+# Copyright (C) 2016-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -22,15 +22,18 @@
# ############################################################################*/
"""Browse a data file with a GUI"""
+from __future__ import annotations
+
__authors__ = ["V. Valls"]
__license__ = "MIT"
__date__ = "15/01/2019"
import os
-import collections
import logging
import functools
+import traceback
+from types import TracebackType
from typing import Optional
import silx.io.nxdata
@@ -40,7 +43,7 @@ import silx.gui.hdf5
from .ApplicationContext import ApplicationContext
from .CustomNxdataWidget import CustomNxdataWidget
from .CustomNxdataWidget import CustomNxDataToolBar
-from . import utils
+from ..utils import parseutils
from silx.gui.utils import projecturl
from .DataPanel import DataPanel
@@ -65,6 +68,8 @@ class Viewer(qt.QMainWindow):
silxIcon = icons.getQIcon("silx")
self.setWindowIcon(silxIcon)
+ self.__error = ""
+
self.__context = self.createApplicationContext(settings)
self.__context.restoreLibrarySettings()
@@ -87,7 +92,9 @@ class Viewer(qt.QMainWindow):
treeModel.sigH5pyObjectRemoved.connect(self.__h5FileRemoved)
treeModel.sigH5pyObjectSynchronized.connect(self.__h5FileSynchonized)
treeModel.setDatasetDragEnabled(True)
- self.__treeModelSorted = silx.gui.hdf5.NexusSortFilterProxyModel(self.__treeview)
+ self.__treeModelSorted = silx.gui.hdf5.NexusSortFilterProxyModel(
+ self.__treeview
+ )
self.__treeModelSorted.setSourceModel(treeModel)
self.__treeModelSorted.sort(0, qt.Qt.AscendingOrder)
self.__treeModelSorted.setSortCaseSensitivity(qt.Qt.CaseInsensitive)
@@ -142,8 +149,8 @@ class Viewer(qt.QMainWindow):
columns.insert(1, treeModel.DESCRIPTION_COLUMN)
self.__treeview.header().setSections(columns)
- self._iconUpward = icons.getQIcon('plot-yup')
- self._iconDownward = icons.getQIcon('plot-ydown')
+ self._iconUpward = icons.getQIcon("plot-yup")
+ self._iconDownward = icons.getQIcon("plot-ydown")
self.createActions()
self.createMenus()
@@ -162,23 +169,22 @@ class Viewer(qt.QMainWindow):
action.setText("Refresh")
action.setToolTip("Refresh all selected items")
action.triggered.connect(self.__refreshSelected)
- action.setShortcut(qt.QKeySequence(qt.Qt.Key_F5))
+ action.setShortcuts(
+ [
+ qt.QKeySequence(qt.Qt.Key_F5),
+ qt.QKeySequence(qt.Qt.CTRL | qt.Qt.Key_R),
+ ]
+ )
toolbar.addAction(action)
treeView.addAction(action)
self.__refreshAction = action
- # Another shortcut for refresh
- action = qt.QAction(toolbar)
- action.setShortcut(qt.QKeySequence(qt.Qt.CTRL | qt.Qt.Key_R))
- treeView.addAction(action)
- action.triggered.connect(self.__refreshSelected)
-
action = qt.QAction(toolbar)
# action.setIcon(icons.getQIcon("view-refresh"))
action.setText("Close")
action.setToolTip("Close selected item")
action.triggered.connect(self.__removeSelected)
- action.setShortcut(qt.QKeySequence(qt.Qt.Key_Delete))
+ action.setShortcut(qt.QKeySequence.Delete)
treeView.addAction(action)
self.__closeAction = action
@@ -254,8 +260,7 @@ class Viewer(qt.QMainWindow):
qt.QApplication.restoreOverrideCursor()
def __refreshSelected(self):
- """Refresh all selected items
- """
+ """Refresh all selected items"""
qt.QApplication.setOverrideCursor(qt.Qt.WaitCursor)
selection = self.__treeview.selectionModel()
@@ -274,8 +279,12 @@ class Viewer(qt.QMainWindow):
rootRow = rootIndex.row()
relativePath = self.__getRelativePath(model, rootIndex, index)
selectedItems.append((rootRow, relativePath))
- h5 = model.data(rootIndex, role=silx.gui.hdf5.Hdf5TreeModel.H5PY_OBJECT_ROLE)
- item = model.data(rootIndex, role=silx.gui.hdf5.Hdf5TreeModel.H5PY_ITEM_ROLE)
+ h5 = model.data(
+ rootIndex, role=silx.gui.hdf5.Hdf5TreeModel.H5PY_OBJECT_ROLE
+ )
+ item = model.data(
+ rootIndex, role=silx.gui.hdf5.Hdf5TreeModel.H5PY_ITEM_ROLE
+ )
h5files.append((h5, item._openedPath))
if len(h5files) == 0:
@@ -350,7 +359,7 @@ class Viewer(qt.QMainWindow):
path = node._getCanonicalName()
if rootPath is None:
rootPath = path
- path = path[len(rootPath):]
+ path = path[len(rootPath) :]
paths.append(path)
for child in range(model.rowCount(index)):
@@ -455,9 +464,9 @@ class Viewer(qt.QMainWindow):
layout.addWidget(customNxdataWidget)
return widget
- def __h5FileLoaded(self, loadedH5):
+ def __h5FileLoaded(self, loadedH5, filename):
self.__context.pushRecentFile(loadedH5.file.filename)
- if loadedH5.file.filename == self.__displayIt:
+ if filename == self.__displayIt:
self.__displayIt = None
self.displayData(loadedH5)
@@ -521,11 +530,7 @@ class Viewer(qt.QMainWindow):
size = settings.value("size", qt.QSize(640, 480))
pos = settings.value("pos", qt.QPoint())
isFullScreen = settings.value("full-screen", False)
- try:
- if not isinstance(isFullScreen, bool):
- isFullScreen = utils.stringToBool(isFullScreen)
- except ValueError:
- isFullScreen = False
+ isFullScreen = parseutils.to_bool(isFullScreen, False)
settings.endGroup()
settings.beginGroup("mainlayout")
@@ -542,23 +547,14 @@ class Viewer(qt.QMainWindow):
except Exception:
_logger.debug("Backtrace", exc_info=True)
isVisible = settings.value("custom-nxdata-window-visible", False)
- try:
- if not isinstance(isVisible, bool):
- isVisible = utils.stringToBool(isVisible)
- except ValueError:
- isVisible = False
+ isVisible = parseutils.to_bool(isVisible, False)
self.__customNxdataWindow.setVisible(isVisible)
self._displayCustomNxdataWindow.setChecked(isVisible)
-
settings.endGroup()
settings.beginGroup("content")
isSorted = settings.value("is-sorted", True)
- try:
- if not isinstance(isSorted, bool):
- isSorted = utils.stringToBool(isSorted)
- except ValueError:
- isSorted = True
+ isSorted = parseutils.to_bool(isSorted, True)
self.setContentSorted(isSorted)
settings.endGroup()
@@ -571,12 +567,13 @@ class Viewer(qt.QMainWindow):
def createActions(self):
action = qt.QAction("E&xit", self)
- action.setShortcuts(qt.QKeySequence.Quit)
+ action.setShortcut(qt.QKeySequence.Quit)
action.setStatusTip("Exit the application")
action.triggered.connect(self.close)
self._exitAction = action
action = qt.QAction("&Open...", self)
+ action.setShortcut(qt.QKeySequence.Open)
action.setStatusTip("Open a file")
action.triggered.connect(self.open)
self._openAction = action
@@ -586,6 +583,7 @@ class Viewer(qt.QMainWindow):
self._openRecentMenu = menu
action = qt.QAction("Close All", self)
+ action.setShortcut(qt.QKeySequence.Close)
action.setStatusTip("Close all opened files")
action.triggered.connect(self.closeAll)
self._closeAllAction = action
@@ -627,9 +625,11 @@ class Viewer(qt.QMainWindow):
# Plot image orientation
self._plotImageOrientationMenu = qt.QMenu(
- "Default plot image y-axis orientation", self)
+ "Default plot image y-axis orientation", self
+ )
self._plotImageOrientationMenu.setStatusTip(
- "Select the default y-axis orientation used by plot displaying images")
+ "Select the default y-axis orientation used by plot displaying images"
+ )
group = qt.QActionGroup(self)
group.setExclusive(True)
@@ -652,10 +652,19 @@ class Viewer(qt.QMainWindow):
self._plotImageOrientationMenu.addAction(action)
self._useYAxisOrientationUpward = action
+ # mpl layout
+
+ action = qt.QAction("Use MPL tight layout", self)
+ action.setCheckable(True)
+ action.triggered.connect(self.__forceMplTightLayout)
+ self._useMplTightLayout = action
+
# Windows
action = qt.QAction("Show custom NXdata selector", self)
- action.setStatusTip("Show a widget which allow to create plot by selecting data and axes")
+ action.setStatusTip(
+ "Show a widget which allow to create plot by selecting data and axes"
+ )
action.setCheckable(True)
action.setShortcut(qt.QKeySequence(qt.Qt.Key_F6))
action.toggled.connect(self.__toggleCustomNxdataWindow)
@@ -674,7 +683,9 @@ class Viewer(qt.QMainWindow):
baseName = os.path.basename(filePath)
action = qt.QAction(baseName, self)
action.setToolTip(filePath)
- action.triggered.connect(functools.partial(self.__openRecentFile, filePath))
+ action.triggered.connect(
+ functools.partial(self.__openRecentFile, filePath)
+ )
self._openRecentMenu.addAction(action)
self._openRecentMenu.addSeparator()
baseName = os.path.basename(filePath)
@@ -696,17 +707,18 @@ class Viewer(qt.QMainWindow):
# plot backend
title = self._plotBackendMenu.title().split(": ", 1)[0]
- self._plotBackendMenu.setTitle("%s: %s" % (title, silx.config.DEFAULT_PLOT_BACKEND))
+ backend = self.__context.getDefaultPlotBackend()
+ self._plotBackendMenu.setTitle(f"{title}: {backend}")
action = self._usePlotWithMatplotlib
- action.setChecked(silx.config.DEFAULT_PLOT_BACKEND in ["matplotlib", "mpl"])
+ action.setChecked(backend in ["matplotlib", "mpl"])
title = action.text().split(" (", 1)[0]
if not action.isChecked():
title += " (applied after application restart)"
action.setText(title)
action = self._usePlotWithOpengl
- action.setChecked(silx.config.DEFAULT_PLOT_BACKEND in ["opengl", "gl"])
+ action.setChecked(backend in ["opengl", "gl"])
title = action.text().split(" (", 1)[0]
if not action.isChecked():
title += " (applied after application restart)"
@@ -721,19 +733,28 @@ class Viewer(qt.QMainWindow):
menu.setIcon(self._iconUpward)
action = self._useYAxisOrientationDownward
- action.setChecked(silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION == "downward")
+ action.setChecked(
+ silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION == "downward"
+ )
title = action.text().split(" (", 1)[0]
if not action.isChecked():
title += " (applied after application restart)"
action.setText(title)
action = self._useYAxisOrientationUpward
- action.setChecked(silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION != "downward")
+ action.setChecked(
+ silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION != "downward"
+ )
title = action.text().split(" (", 1)[0]
if not action.isChecked():
title += " (applied after application restart)"
action.setText(title)
+ # mpl
+
+ action = self._useMplTightLayout
+ action.setChecked(silx.config._MPL_TIGHT_LAYOUT)
+
def createMenus(self):
fileMenu = self.menuBar().addMenu("&File")
fileMenu.addAction(self._openAction)
@@ -746,6 +767,7 @@ class Viewer(qt.QMainWindow):
optionMenu = self.menuBar().addMenu("&Options")
optionMenu.addMenu(self._plotImageOrientationMenu)
optionMenu.addMenu(self._plotBackendMenu)
+ optionMenu.addAction(self._useMplTightLayout)
optionMenu.aboutToShow.connect(self.__updateOptionMenu)
viewMenu = self.menuBar().addMenu("&Views")
@@ -755,6 +777,17 @@ class Viewer(qt.QMainWindow):
helpMenu.addAction(self._aboutAction)
helpMenu.addAction(self._documentationAction)
+ self.__errorButton = qt.QToolButton(self)
+ self.__errorButton.setIcon(
+ self.style().standardIcon(qt.QStyle.SP_MessageBoxWarning)
+ )
+ self.__errorButton.setToolTip(
+ "An error occured!\nClick to display last error\nor check messages in the console"
+ )
+ self.__errorButton.setVisible(False)
+ self.__errorButton.clicked.connect(self.__errorButtonClicked)
+ self.menuBar().setCornerWidget(self.__errorButton)
+
def open(self):
dialog = self.createFileDialog()
if self.__dialogState is None:
@@ -784,7 +817,7 @@ class Viewer(qt.QMainWindow):
dialog.setModal(True)
# NOTE: hdf5plugin have to be loaded before
- extensions = collections.OrderedDict()
+ extensions = {}
for description, ext in silx.io.supported_extensions().items():
extensions[description] = " ".join(sorted(list(ext)))
@@ -812,6 +845,7 @@ class Viewer(qt.QMainWindow):
def about(self):
from .About import About
+
About.about(self, "Silx viewer")
def showDocumentation(self):
@@ -826,7 +860,6 @@ class Viewer(qt.QMainWindow):
"""
sort = bool(sort)
if sort != self.isContentSorted():
-
# save expanded nodes
pathss = []
root = qt.QModelIndex()
@@ -837,7 +870,8 @@ class Viewer(qt.QMainWindow):
pathss.append(paths)
self.__treeview.setModel(
- self.__treeModelSorted if sort else self.__treeModelSorted.sourceModel())
+ self.__treeModelSorted if sort else self.__treeModelSorted.sourceModel()
+ )
self._sortContentAction.setChecked(self.isContentSorted())
# restore expanded nodes
@@ -864,7 +898,10 @@ class Viewer(qt.QMainWindow):
silx.config.DEFAULT_PLOT_BACKEND = "matplotlib"
def __forceOpenglBackend(self):
- silx.config.DEFAULT_PLOT_BACKEND = "opengl"
+ silx.config.DEFAULT_PLOT_BACKEND = "opengl", "matplotlib"
+
+ def __forceMplTightLayout(self):
+ silx.config._MPL_TIGHT_LAYOUT = self._useMplTightLayout.isChecked()
def appendFile(self, filename):
if self.__displayIt is None:
@@ -873,8 +910,7 @@ class Viewer(qt.QMainWindow):
self.__treeview.findHdf5TreeModel().appendFile(filename)
def displaySelectedData(self):
- """Called to update the dataviewer with the selected data.
- """
+ """Called to update the dataviewer with the selected data."""
selected = list(self.__treeview.selectedH5Nodes(ignoreBrokenLinks=False))
if len(selected) == 1:
# Update the viewer for a single selection
@@ -884,8 +920,7 @@ class Viewer(qt.QMainWindow):
_logger.debug("Too many data selected")
def displayData(self, data):
- """Called to update the dataviewer with a secific data.
- """
+ """Called to update the dataviewer with a secific data."""
self.__dataPanel.setData(data)
def displaySelectedCustomData(self):
@@ -957,8 +992,42 @@ class Viewer(qt.QMainWindow):
if silx.io.is_file(h5):
action = qt.QAction("Close %s" % obj.local_filename, event.source())
- action.triggered.connect(lambda: self.__treeview.findHdf5TreeModel().removeH5pyObject(h5))
+ action.triggered.connect(
+ lambda: self.__treeview.findHdf5TreeModel().removeH5pyObject(h5)
+ )
menu.addAction(action)
- action = qt.QAction("Synchronize %s" % obj.local_filename, event.source())
+ action = qt.QAction(
+ "Synchronize %s" % obj.local_filename, event.source()
+ )
action.triggered.connect(lambda: self.__synchronizeH5pyObject(h5))
menu.addAction(action)
+
+ def __errorButtonClicked(self):
+ button = qt.QMessageBox.warning(
+ self,
+ "Error",
+ self.getError(),
+ qt.QMessageBox.Reset | qt.QMessageBox.Close,
+ qt.QMessageBox.Close,
+ )
+ if button == qt.QMessageBox.Reset:
+ self.setError("")
+
+ def getError(self) -> str:
+ """Returns error information string"""
+ return self.__error
+
+ def setError(self, error: str):
+ """Set error information string"""
+ if error == self.__error:
+ return
+
+ self.__error = error
+ self.__errorButton.setVisible(error != "")
+
+ def setErrorFromException(
+ self, type_: type[BaseException], value: BaseException, trace: TracebackType
+ ):
+ """Set information about the last exception that occured"""
+ formattedTrace = "\n".join(traceback.format_tb(trace))
+ self.setError(f"{type_.__name__}:\n{value}\n\n{formattedTrace}")
diff --git a/src/silx/app/view/main.py b/src/silx/app/view/main.py
index c37b8aa..f6c5274 100644
--- a/src/silx/app/view/main.py
+++ b/src/silx/app/view/main.py
@@ -1,5 +1,5 @@
# /*##########################################################################
-# Copyright (C) 2016-2022 European Synchrotron Radiation Facility
+# Copyright (C) 2016-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -27,12 +27,12 @@ __license__ = "MIT"
__date__ = "17/01/2019"
import argparse
-import glob
import logging
import os
import signal
import sys
-from typing import Generator, Iterable
+import traceback
+from silx.app.utils import parseutils
_logger = logging.getLogger(__name__)
@@ -42,72 +42,53 @@ _logger = logging.getLogger(__name__)
def createParser():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
- 'files',
+ "files",
nargs=argparse.ZERO_OR_MORE,
- help='Data file to show (h5 file, edf files, spec files)')
+ help="Data file to show (h5 file, edf files, spec files)",
+ )
parser.add_argument(
- '--debug',
+ "--slices",
+ dest="slices",
+ default=tuple(),
+ type=int,
+ nargs="+",
+ help="List of slice indices to open (Only for dataset)",
+ )
+ parser.add_argument(
+ "--debug",
dest="debug",
action="store_true",
default=False,
- help='Set logging system in debug mode')
+ help="Set logging system in debug mode",
+ )
parser.add_argument(
- '--use-opengl-plot',
+ "--use-opengl-plot",
dest="use_opengl_plot",
action="store_true",
default=False,
- help='Use OpenGL for plots (instead of matplotlib)')
+ help="Use OpenGL for plots (instead of matplotlib)",
+ )
parser.add_argument(
- '-f', '--fresh',
+ "-f",
+ "--fresh",
dest="fresh_preferences",
action="store_true",
default=False,
- help='Start the application using new fresh user preferences')
+ help="Start the application using new fresh user preferences",
+ )
parser.add_argument(
- '--hdf5-file-locking',
+ "--hdf5-file-locking",
dest="hdf5_file_locking",
action="store_true",
default=False,
- help='Start the application with HDF5 file locking enabled (it is disabled by default)')
+ help="Start the application with HDF5 file locking enabled (it is disabled by default)",
+ )
return parser
-def filesArgToUrls(filenames: Iterable[str]) -> Generator[object, None, None]:
- """Expand filenames and HDF5 data path in files input argument"""
- # Imports here so they are performed after setting HDF5_USE_FILE_LOCKING and logging level
- import silx.io
- from silx.io.utils import match
- from silx.io.url import DataUrl
- import silx.utils.files
-
- for filename in filenames:
- url = DataUrl(filename)
-
- for file_path in sorted(silx.utils.files.expand_filenames([url.file_path()])):
- if url.data_path() is not None and glob.has_magic(url.data_path()):
- try:
- with silx.io.open(file_path) as f:
- data_paths = list(match(f, url.data_path()))
- except BaseException as e:
- _logger.error(
- f"Error searching HDF5 path pattern '{url.data_path()}' in file '{file_path}': Ignored")
- _logger.error(e.args[0])
- _logger.debug("Backtrace", exc_info=True)
- continue
- else:
- data_paths = [url.data_path()]
-
- for data_path in data_paths:
- yield DataUrl(
- file_path=file_path,
- data_path=data_path,
- data_slice=url.data_slice(),
- scheme=url.scheme(),
- )
-
-
def createWindow(parent, settings):
from .Viewer import Viewer
+
window = Viewer(parent=None, settings=settings)
return window
@@ -127,7 +108,7 @@ def mainQt(options):
except ImportError:
_logger.debug("No resource module available")
else:
- if hasattr(resource, 'RLIMIT_NOFILE'):
+ if hasattr(resource, "RLIMIT_NOFILE"):
try:
hard_nofile = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
resource.setrlimit(resource.RLIMIT_NOFILE, (hard_nofile, hard_nofile))
@@ -137,9 +118,9 @@ def mainQt(options):
_logger.debug("Set max opened files to %d", hard_nofile)
# This needs to be done prior to load HDF5
- hdf5_file_locking = 'TRUE' if options.hdf5_file_locking else 'FALSE'
- _logger.info('Set HDF5_USE_FILE_LOCKING=%s', hdf5_file_locking)
- os.environ['HDF5_USE_FILE_LOCKING'] = hdf5_file_locking
+ hdf5_file_locking = "TRUE" if options.hdf5_file_locking else "FALSE"
+ _logger.info("Set HDF5_USE_FILE_LOCKING=%s", hdf5_file_locking)
+ os.environ["HDF5_USE_FILE_LOCKING"] = hdf5_file_locking
try:
# it should be loaded before h5py
@@ -151,6 +132,7 @@ def mainQt(options):
import silx
from silx.gui import qt
+
# Make sure matplotlib is configured
# Needed for Debian 8: compatibility between Qt4/Qt5 and old matplotlib
import silx.gui.utils.matplotlib # noqa
@@ -163,7 +145,6 @@ def mainQt(options):
qt.QApplication.quit()
signal.signal(signal.SIGINT, sigintHandler)
- sys.excepthook = qt.exceptionHandler
timer = qt.QTimer()
timer.start(500)
@@ -171,23 +152,30 @@ def mainQt(options):
# catched
timer.timeout.connect(lambda: None)
- settings = qt.QSettings(qt.QSettings.IniFormat,
- qt.QSettings.UserScope,
- "silx",
- "silx-view",
- None)
+ settings = qt.QSettings(
+ qt.QSettings.IniFormat, qt.QSettings.UserScope, "silx", "silx-view", None
+ )
if options.fresh_preferences:
settings.clear()
window = createWindow(parent=None, settings=settings)
window.setAttribute(qt.Qt.WA_DeleteOnClose, True)
+ def exceptHook(type_, value, trace):
+ _logger.error("An error occured in silx view:")
+ _logger.error("%s %s %s", type_, value, "".join(traceback.format_tb(trace)))
+ try:
+ window.setErrorFromException(type_, value, trace)
+ except Exception:
+ pass
+
+ sys.excepthook = exceptHook
+
if options.use_opengl_plot:
# It have to be done after the settings (after the Viewer creation)
silx.config.DEFAULT_PLOT_BACKEND = "opengl"
-
- for url in filesArgToUrls(options.files):
+ for url in parseutils.filenames_to_dataurls(options.files, options.slices):
# TODO: Would be nice to add a process widget and a cancel button
try:
window.appendFile(url.path())
@@ -214,5 +202,5 @@ def main(argv):
mainQt(options)
-if __name__ == '__main__':
+if __name__ == "__main__":
main(sys.argv)
diff --git a/src/silx/app/view/test/test_launcher.py b/src/silx/app/view/test/test_launcher.py
index 8ccf4af..49b1032 100644
--- a/src/silx/app/view/test/test_launcher.py
+++ b/src/silx/app/view/test/test_launcher.py
@@ -84,23 +84,22 @@ class TestLauncher(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdir:
# Copy file to temporary dir to avoid import from current dir.
- script = os.path.join(tmpdir, 'launcher.py')
+ script = os.path.join(tmpdir, "launcher.py")
shutil.copyfile(filename, script)
command_line = [sys.executable, script] + list(args)
_logger.info("Execute: %s", " ".join(command_line))
- p = subprocess.Popen(command_line,
- stdout=subprocess.PIPE,
- stderr=subprocess.PIPE,
- env=env)
+ p = subprocess.Popen(
+ command_line, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env
+ )
out, err = p.communicate()
_logger.info("Return code: %d", p.returncode)
try:
- out = out.decode('utf-8')
+ out = out.decode("utf-8")
except UnicodeError:
pass
try:
- err = err.decode('utf-8')
+ err = err.decode("utf-8")
except UnicodeError:
pass
diff --git a/src/silx/app/view/test/test_view.py b/src/silx/app/view/test/test_view.py
index 362995a..1eb588b 100644
--- a/src/silx/app/view/test/test_view.py
+++ b/src/silx/app/view/test/test_view.py
@@ -115,7 +115,6 @@ class TestAbout(TestCaseQt):
@pytest.mark.usefixtures("qapp")
@pytest.mark.usefixtures("data_class_attr")
class TestDataPanel(TestCaseQt):
-
def testConstruct(self):
widget = DataPanel()
self.qWaitForWindowExposed(widget)
@@ -169,7 +168,7 @@ class TestDataPanel(TestCaseQt):
self.assertIs(widget.getCustomNxdataItem(), data)
def testRemoveDatasetsFrom(self):
- f = h5py.File(self.data_h5, mode='r')
+ f = h5py.File(self.data_h5, mode="r")
try:
widget = DataPanel()
widget.setData(f["arrays/scalar"])
@@ -180,8 +179,8 @@ class TestDataPanel(TestCaseQt):
f.close()
def testReplaceDatasetsFrom(self):
- f = h5py.File(self.data_h5, mode='r')
- f2 = h5py.File(self.data2_h5, mode='r')
+ f = h5py.File(self.data_h5, mode="r")
+ f2 = h5py.File(self.data2_h5, mode="r")
try:
widget = DataPanel()
widget.setData(f["arrays/scalar"])
@@ -197,7 +196,6 @@ class TestDataPanel(TestCaseQt):
@pytest.mark.usefixtures("qapp")
@pytest.mark.usefixtures("data_class_attr")
class TestCustomNxdataWidget(TestCaseQt):
-
def testConstruct(self):
widget = CustomNxdataWidget()
self.qWaitForWindowExposed(widget)
@@ -250,7 +248,7 @@ class TestCustomNxdataWidget(TestCaseQt):
self.assertFalse(item.isValid())
def testRemoveDatasetsFrom(self):
- f = h5py.File(self.data_h5, mode='r')
+ f = h5py.File(self.data_h5, mode="r")
try:
widget = CustomNxdataWidget()
model = widget.model()
@@ -262,8 +260,8 @@ class TestCustomNxdataWidget(TestCaseQt):
f.close()
def testReplaceDatasetsFrom(self):
- f = h5py.File(self.data_h5, mode='r')
- f2 = h5py.File(self.data2_h5, mode='r')
+ f = h5py.File(self.data_h5, mode="r")
+ f2 = h5py.File(self.data2_h5, mode="r")
try:
widget = CustomNxdataWidget()
model = widget.model()
@@ -299,14 +297,18 @@ class TestCustomNxdataWidgetInteraction(TestCaseQt):
def testSelectedNxdata(self):
index = self.model.index(0, 0)
- self.selectionModel.setCurrentIndex(index, qt.QItemSelectionModel.ClearAndSelect)
+ self.selectionModel.setCurrentIndex(
+ index, qt.QItemSelectionModel.ClearAndSelect
+ )
nxdata = self.widget.selectedNxdata()
self.assertEqual(len(nxdata), 1)
self.assertIsNot(nxdata[0], None)
def testSelectedItems(self):
index = self.model.index(0, 0)
- self.selectionModel.setCurrentIndex(index, qt.QItemSelectionModel.ClearAndSelect)
+ self.selectionModel.setCurrentIndex(
+ index, qt.QItemSelectionModel.ClearAndSelect
+ )
items = self.widget.selectedItems()
self.assertEqual(len(items), 1)
self.assertIsNot(items[0], None)
diff --git a/src/silx/conftest.py b/src/silx/conftest.py
index bec67c0..9b43f5d 100644
--- a/src/silx/conftest.py
+++ b/src/silx/conftest.py
@@ -1,6 +1,9 @@
import pytest
import logging
import os
+from io import BytesIO
+
+import h5py
logger = logging.getLogger(__name__)
@@ -12,9 +15,6 @@ def _set_qt_binding(binding):
if binding == "pyqt5":
logger.info("Force using PyQt5")
import PyQt5.QtCore # noqa
- elif binding == "pyside2":
- logger.info("Force using PySide2")
- import PySide2.QtCore # noqa
elif binding == "pyside6":
logger.info("Force using PySide6")
import PySide6.QtCore # noqa
@@ -26,25 +26,46 @@ def _set_qt_binding(binding):
def pytest_addoption(parser):
- parser.addoption("--qt-binding", type=str, default=None, dest="qt_binding",
- help="Force using a Qt binding: 'PyQt5', 'PySide2', 'PySide6', 'PyQt6'")
- parser.addoption("--no-gui", dest="gui", default=True,
- action="store_false",
- help="Disable the test of the graphical use interface")
- parser.addoption("--no-opengl", dest="opengl", default=True,
- action="store_false",
- help="Disable tests using OpenGL")
- parser.addoption("--no-opencl", dest="opencl", default=True,
- action="store_false",
- help="Disable the test of the OpenCL part")
- parser.addoption("--low-mem", dest="low_mem", default=False,
- action="store_true",
- help="Disable test with large memory consumption (>100Mbyte")
+ parser.addoption(
+ "--qt-binding",
+ type=str,
+ default=None,
+ dest="qt_binding",
+ help="Force using a Qt binding: 'PyQt5', 'PySide6', 'PyQt6'",
+ )
+ parser.addoption(
+ "--no-gui",
+ dest="gui",
+ default=True,
+ action="store_false",
+ help="Disable the test of the graphical use interface",
+ )
+ parser.addoption(
+ "--no-opengl",
+ dest="opengl",
+ default=True,
+ action="store_false",
+ help="Disable tests using OpenGL",
+ )
+ parser.addoption(
+ "--no-opencl",
+ dest="opencl",
+ default=True,
+ action="store_false",
+ help="Disable the test of the OpenCL part",
+ )
+ parser.addoption(
+ "--low-mem",
+ dest="low_mem",
+ default=False,
+ action="store_true",
+ help="Disable test with large memory consumption (>100Mbyte",
+ )
def pytest_configure(config):
- if not config.getoption('opencl', True):
- os.environ['SILX_OPENCL'] = 'False' # Disable OpenCL support in silx
+ if not config.getoption("opencl", True):
+ os.environ["SILX_OPENCL"] = "False" # Disable OpenCL support in silx
_set_qt_binding(config.option.qt_binding)
@@ -52,6 +73,7 @@ def pytest_configure(config):
@pytest.fixture(scope="session")
def test_options(request):
from .test import utils
+
options = utils._TestOptions()
options.configure(request.config.option)
yield options
@@ -111,6 +133,7 @@ def qapp(use_gui, xvfb, request):
_set_qt_binding(request.config.option.qt_binding)
from silx.gui import qt
+
app = qt.QApplication.instance()
if app is None:
app = qt.QApplication([])
@@ -125,9 +148,17 @@ def qapp(use_gui, xvfb, request):
def qapp_utils(qapp):
"""Helper containing method to deal with QApplication and widget"""
from silx.gui.utils.testutils import TestCaseQt
+
utils = TestCaseQt()
utils.setUpClass()
utils.setUp()
yield utils
utils.tearDown()
utils.tearDownClass()
+
+
+@pytest.fixture
+def tmp_h5py_file():
+ with BytesIO() as buffer:
+ with h5py.File(buffer, mode="w") as h5file:
+ yield h5file
diff --git a/src/silx/gui/_glutils/Context.py b/src/silx/gui/_glutils/Context.py
index d2ddaa3..c0def5c 100644
--- a/src/silx/gui/_glutils/Context.py
+++ b/src/silx/gui/_glutils/Context.py
@@ -36,8 +36,10 @@ import contextlib
class _DEFAULT_CONTEXT(object):
"""The default value for OpenGL context"""
+
pass
+
_context = _DEFAULT_CONTEXT
"""The current OpenGL context"""
diff --git a/src/silx/gui/_glutils/FramebufferTexture.py b/src/silx/gui/_glutils/FramebufferTexture.py
index 75db264..6d1a8d9 100644
--- a/src/silx/gui/_glutils/FramebufferTexture.py
+++ b/src/silx/gui/_glutils/FramebufferTexture.py
@@ -53,13 +53,14 @@ class FramebufferTexture(object):
_PACKED_FORMAT = gl.GL_DEPTH24_STENCIL8, gl.GL_DEPTH_STENCIL
- def __init__(self,
- internalFormat,
- shape,
- stencilFormat=gl.GL_DEPTH24_STENCIL8,
- depthFormat=gl.GL_DEPTH24_STENCIL8,
- **kwargs):
-
+ def __init__(
+ self,
+ internalFormat,
+ shape,
+ stencilFormat=gl.GL_DEPTH24_STENCIL8,
+ depthFormat=gl.GL_DEPTH24_STENCIL8,
+ **kwargs,
+ ):
self._texture = Texture(internalFormat, shape=shape, **kwargs)
self._texture.prepare()
@@ -69,24 +70,28 @@ class FramebufferTexture(object):
with self: # Bind FBO
# Attachments
- gl.glFramebufferTexture2D(gl.GL_FRAMEBUFFER,
- gl.GL_COLOR_ATTACHMENT0,
- gl.GL_TEXTURE_2D,
- self._texture.name,
- 0)
+ gl.glFramebufferTexture2D(
+ gl.GL_FRAMEBUFFER,
+ gl.GL_COLOR_ATTACHMENT0,
+ gl.GL_TEXTURE_2D,
+ self._texture.name,
+ 0,
+ )
height, width = self._texture.shape
if stencilFormat is not None:
self._stencilId = gl.glGenRenderbuffers(1)
gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, self._stencilId)
- gl.glRenderbufferStorage(gl.GL_RENDERBUFFER,
- stencilFormat,
- width, height)
- gl.glFramebufferRenderbuffer(gl.GL_FRAMEBUFFER,
- gl.GL_STENCIL_ATTACHMENT,
- gl.GL_RENDERBUFFER,
- self._stencilId)
+ gl.glRenderbufferStorage(
+ gl.GL_RENDERBUFFER, stencilFormat, width, height
+ )
+ gl.glFramebufferRenderbuffer(
+ gl.GL_FRAMEBUFFER,
+ gl.GL_STENCIL_ATTACHMENT,
+ gl.GL_RENDERBUFFER,
+ self._stencilId,
+ )
else:
self._stencilId = None
@@ -96,13 +101,15 @@ class FramebufferTexture(object):
else:
self._depthId = gl.glGenRenderbuffers(1)
gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, self._depthId)
- gl.glRenderbufferStorage(gl.GL_RENDERBUFFER,
- depthFormat,
- width, height)
- gl.glFramebufferRenderbuffer(gl.GL_FRAMEBUFFER,
- gl.GL_DEPTH_ATTACHMENT,
- gl.GL_RENDERBUFFER,
- self._depthId)
+ gl.glRenderbufferStorage(
+ gl.GL_RENDERBUFFER, depthFormat, width, height
+ )
+ gl.glFramebufferRenderbuffer(
+ gl.GL_FRAMEBUFFER,
+ gl.GL_DEPTH_ATTACHMENT,
+ gl.GL_RENDERBUFFER,
+ self._depthId,
+ )
else:
self._depthId = None
@@ -110,7 +117,8 @@ class FramebufferTexture(object):
if status != gl.GL_FRAMEBUFFER_COMPLETE:
_logger.error(
"OpenGL framebuffer initialization not complete, display may fail (error %d)",
- status)
+ status,
+ )
@property
def shape(self):
@@ -130,8 +138,10 @@ class FramebufferTexture(object):
if self._name is not None:
return self._name
else:
- raise RuntimeError("No OpenGL framebuffer resource, \
- discard has already been called")
+ raise RuntimeError(
+ "No OpenGL framebuffer resource, \
+ discard has already been called"
+ )
def bind(self):
"""Bind this framebuffer for rendering"""
diff --git a/src/silx/gui/_glutils/OpenGLWidget.py b/src/silx/gui/_glutils/OpenGLWidget.py
index d35bb73..59fa4f0 100644
--- a/src/silx/gui/_glutils/OpenGLWidget.py
+++ b/src/silx/gui/_glutils/OpenGLWidget.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -43,16 +43,16 @@ from .._glutils import gl
_logger = logging.getLogger(__name__)
-if not hasattr(qt, 'QOpenGLWidget') and not hasattr(qt, 'QGLWidget'):
+if not hasattr(qt, "QOpenGLWidget") and not hasattr(qt, "QGLWidget"):
_OpenGLWidget = None
else:
- if hasattr(qt, 'QOpenGLWidget'): # PyQt>=5.4
- _logger.info('Using QOpenGLWidget')
+ if hasattr(qt, "QOpenGLWidget"): # PyQt>=5.4
+ _logger.info("Using QOpenGLWidget")
_BaseOpenGLWidget = qt.QOpenGLWidget
else:
- _logger.info('Using QGLWidget')
+ _logger.info("Using QGLWidget")
_BaseOpenGLWidget = qt.QGLWidget
class _OpenGLWidget(_BaseOpenGLWidget):
@@ -64,14 +64,17 @@ else:
It provides the error reason as a str.
"""
- def __init__(self, parent,
- alphaBufferSize=0,
- depthBufferSize=24,
- stencilBufferSize=8,
- version=(2, 0),
- f=qt.Qt.Widget):
+ def __init__(
+ self,
+ parent,
+ alphaBufferSize=0,
+ depthBufferSize=24,
+ stencilBufferSize=8,
+ version=(2, 0),
+ f=qt.Qt.Widget,
+ ):
# True if using QGLWidget, False if using QOpenGLWidget
- self.__legacy = not hasattr(qt, 'QOpenGLWidget')
+ self.__legacy = not hasattr(qt, "QOpenGLWidget")
self.__devicePixelRatio = 1.0
self.__requestedOpenGLVersion = int(version[0]), int(version[1])
@@ -131,12 +134,23 @@ else:
# Go through all OpenGL version flags checking support
flags = self.format().openGLVersionFlags()
- for version in ((1, 1), (1, 2), (1, 3), (1, 4), (1, 5),
- (2, 0), (2, 1),
- (3, 0), (3, 1), (3, 2), (3, 3),
- (4, 0)):
- versionFlag = getattr(qt.QGLFormat,
- 'OpenGL_Version_%d_%d' % version)
+ for version in (
+ (1, 1),
+ (1, 2),
+ (1, 3),
+ (1, 4),
+ (1, 5),
+ (2, 0),
+ (2, 1),
+ (3, 0),
+ (3, 1),
+ (3, 2),
+ (3, 3),
+ (4, 0),
+ ):
+ versionFlag = getattr(
+ qt.QGLFormat, "OpenGL_Version_%d_%d" % version
+ )
if not versionFlag & flags:
break
supportedVersion = version
@@ -171,13 +185,13 @@ else:
def initializeGL(self):
parent = self.parent()
if parent is None:
- _logger.error('_OpenGLWidget has no parent')
+ _logger.error("_OpenGLWidget has no parent")
return
# Check OpenGL version
if self.getOpenGLVersion() >= self.getRequestedOpenGLVersion():
try:
- gl.glGetError() # clear any previous error (if any)
+ gl.glGetError() # clear any previous error (if any)
version = gl.glGetString(gl.GL_VERSION)
except:
version = None
@@ -185,18 +199,19 @@ else:
if version:
self.__isValid = True
else:
- errMsg = 'OpenGL not available'
- if sys.platform.startswith('linux'):
- errMsg += ': If connected remotely, ' \
- 'GLX forwarding might be disabled.'
+ errMsg = "OpenGL not available"
+ if sys.platform.startswith("linux"):
+ errMsg += (
+ ": If connected remotely, "
+ "GLX forwarding might be disabled."
+ )
_logger.error(errMsg)
self.sigOpenGLContextError.emit(errMsg)
self.__isValid = False
else:
- errMsg = 'OpenGL %d.%d not available' % \
- self.getRequestedOpenGLVersion()
- _logger.error('OpenGL widget disabled: %s', errMsg)
+ errMsg = "OpenGL %d.%d not available" % self.getRequestedOpenGLVersion()
+ _logger.error("OpenGL widget disabled: %s", errMsg)
self.sigOpenGLContextError.emit(errMsg)
self.__isValid = False
@@ -206,7 +221,7 @@ else:
def paintGL(self):
parent = self.parent()
if parent is None:
- _logger.error('_OpenGLWidget has no parent')
+ _logger.error("_OpenGLWidget has no parent")
return
devicePixelRatio = self.window().windowHandle().devicePixelRatio()
@@ -224,7 +239,7 @@ else:
def resizeGL(self, width, height):
parent = self.parent()
if parent is None:
- _logger.error('_OpenGLWidget has no parent')
+ _logger.error("_OpenGLWidget has no parent")
return
if self.isValid():
@@ -256,12 +271,15 @@ class OpenGLWidget(qt.QWidget):
:param f: see :class:`QWidget`
"""
- def __init__(self, parent=None,
- alphaBufferSize=0,
- depthBufferSize=24,
- stencilBufferSize=8,
- version=(2, 0),
- f=qt.Qt.Widget):
+ def __init__(
+ self,
+ parent=None,
+ alphaBufferSize=0,
+ depthBufferSize=24,
+ stencilBufferSize=8,
+ version=(2, 0),
+ f=qt.Qt.Widget,
+ ):
super(OpenGLWidget, self).__init__(parent, f)
layout = qt.QHBoxLayout(self)
@@ -272,24 +290,26 @@ class OpenGLWidget(qt.QWidget):
_check = isOpenGLAvailable(version=version, runtimeCheck=False)
if _OpenGLWidget is None or not _check:
- _logger.error('OpenGL-based widget disabled: %s', _check.error)
+ _logger.error("OpenGL-based widget disabled: %s", _check.error)
self.__openGLWidget = None
label = self._createErrorQLabel(_check.error)
self.layout().addWidget(label)
-
- else:
- self.__openGLWidget = _OpenGLWidget(
- parent=self,
- alphaBufferSize=alphaBufferSize,
- depthBufferSize=depthBufferSize,
- stencilBufferSize=stencilBufferSize,
- version=version,
- f=f)
- # Async connection need, otherwise issue when hiding OpenGL
- # widget while doing the rendering..
- self.__openGLWidget.sigOpenGLContextError.connect(
- self._handleOpenGLInitError, qt.Qt.QueuedConnection)
- self.layout().addWidget(self.__openGLWidget)
+ return
+
+ self.__openGLWidget = _OpenGLWidget(
+ parent=self,
+ alphaBufferSize=alphaBufferSize,
+ depthBufferSize=depthBufferSize,
+ stencilBufferSize=stencilBufferSize,
+ version=version,
+ f=f,
+ )
+ # Async connection need, otherwise issue when hiding OpenGL
+ # widget while doing the rendering..
+ self.__openGLWidget.sigOpenGLContextError.connect(
+ self._handleOpenGLInitError, qt.Qt.QueuedConnection
+ )
+ self.layout().addWidget(self.__openGLWidget)
@staticmethod
def _createErrorQLabel(error):
@@ -297,7 +317,7 @@ class OpenGLWidget(qt.QWidget):
:param str error: The error message to display"""
label = qt.QLabel()
- label.setText('OpenGL-based widget disabled:\n%s' % error)
+ label.setText("OpenGL-based widget disabled:\n%s" % error)
label.setAlignment(qt.Qt.AlignCenter)
label.setWordWrap(True)
return label
@@ -323,7 +343,7 @@ class OpenGLWidget(qt.QWidget):
:rtype: float
"""
if self.__openGLWidget is None:
- return 1.
+ return 1.0
else:
return self.__openGLWidget.getDevicePixelRatio()
@@ -333,13 +353,17 @@ class OpenGLWidget(qt.QWidget):
:rtype: float
"""
screen = self.window().windowHandle().screen()
- if screen is not None:
- # TODO check if this is correct on different OS/screen
- # OK on macOS10.12/qt5.13.2
- dpi = screen.physicalDotsPerInch() * self.getDevicePixelRatio()
- else: # Fallback
- dpi = 96. * self.getDevicePixelRatio()
- return dpi
+ if screen is None:
+ return 96.0 * self.getDevicePixelRatio()
+
+ physicalDPI = screen.physicalDotsPerInch()
+ if physicalDPI > 1000.0:
+ _logger.error(
+ "Reported screen DPI too high: %f, using default value instead",
+ physicalDPI,
+ )
+ physicalDPI = 96.0
+ return physicalDPI * self.getDevicePixelRatio()
def getOpenGLVersion(self):
"""Returns the available OpenGL version.
diff --git a/src/silx/gui/_glutils/Program.py b/src/silx/gui/_glutils/Program.py
index d61c07d..b2adacf 100644
--- a/src/silx/gui/_glutils/Program.py
+++ b/src/silx/gui/_glutils/Program.py
@@ -55,8 +55,7 @@ class Program(object):
array attached to it in order for the rendering to occur....
"""
- def __init__(self, vertexShader, fragmentShader,
- attrib0='position'):
+ def __init__(self, vertexShader, fragmentShader, attrib0="position"):
self._vertexShader = vertexShader
self._fragmentShader = fragmentShader
self._attrib0 = attrib0
@@ -66,7 +65,7 @@ class Program(object):
def _compileGL(vertexShader, fragmentShader, attrib0):
program = gl.glCreateProgram()
- gl.glBindAttribLocation(program, 0, attrib0.encode('ascii'))
+ gl.glBindAttribLocation(program, 0, attrib0.encode("ascii"))
vertex = gl.glCreateShader(gl.GL_VERTEX_SHADER)
gl.glShaderSource(vertex, vertexShader)
@@ -79,8 +78,7 @@ class Program(object):
fragment = gl.glCreateShader(gl.GL_FRAGMENT_SHADER)
gl.glShaderSource(fragment, fragmentShader)
gl.glCompileShader(fragment)
- if gl.glGetShaderiv(fragment,
- gl.GL_COMPILE_STATUS) != gl.GL_TRUE:
+ if gl.glGetShaderiv(fragment, gl.GL_COMPILE_STATUS) != gl.GL_TRUE:
raise RuntimeError(gl.glGetShaderInfoLog(fragment))
gl.glAttachShader(program, fragment)
gl.glDeleteShader(fragment)
@@ -90,16 +88,15 @@ class Program(object):
raise RuntimeError(gl.glGetProgramInfoLog(program))
attributes = {}
- for index in range(gl.glGetProgramiv(program,
- gl.GL_ACTIVE_ATTRIBUTES)):
+ for index in range(gl.glGetProgramiv(program, gl.GL_ACTIVE_ATTRIBUTES)):
name = gl.glGetActiveAttrib(program, index)[0]
- namestr = name.decode('ascii')
+ namestr = name.decode("ascii")
attributes[namestr] = gl.glGetAttribLocation(program, name)
uniforms = {}
for index in range(gl.glGetProgramiv(program, gl.GL_ACTIVE_UNIFORMS)):
name = gl.glGetActiveUniform(program, index)[0]
- namestr = name.decode('ascii')
+ namestr = name.decode("ascii")
uniforms[namestr] = gl.glGetUniformLocation(program, name)
return program, attributes, uniforms
@@ -107,8 +104,7 @@ class Program(object):
def _getProgramInfo(self):
glcontext = Context.getCurrent()
if glcontext not in self._programs:
- raise RuntimeError(
- "Program was not compiled for current OpenGL context.")
+ raise RuntimeError("Program was not compiled for current OpenGL context.")
return self._programs[glcontext]
@property
@@ -152,16 +148,15 @@ class Program(object):
if glcontext not in self._programs:
self._programs[glcontext] = self._compileGL(
- self._vertexShader,
- self._fragmentShader,
- self._attrib0)
+ self._vertexShader, self._fragmentShader, self._attrib0
+ )
if _logger.getEffectiveLevel() <= logging.DEBUG:
gl.glValidateProgram(self.program)
- if gl.glGetProgramiv(
- self.program, gl.GL_VALIDATE_STATUS) != gl.GL_TRUE:
- _logger.debug('Cannot validate program: %s',
- gl.glGetProgramInfoLog(self.program))
+ if gl.glGetProgramiv(self.program, gl.GL_VALIDATE_STATUS) != gl.GL_TRUE:
+ _logger.debug(
+ "Cannot validate program: %s", gl.glGetProgramInfoLog(self.program)
+ )
gl.glUseProgram(self.program)
@@ -198,4 +193,4 @@ class Program(object):
gl.glUniformMatrix4fv(location, count, transpose, value)
elif not safe:
- raise KeyError('No uniform: %s' % name)
+ raise KeyError("No uniform: %s" % name)
diff --git a/src/silx/gui/_glutils/Texture.py b/src/silx/gui/_glutils/Texture.py
index 76bdcd8..aac380d 100644
--- a/src/silx/gui/_glutils/Texture.py
+++ b/src/silx/gui/_glutils/Texture.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2014-2020 European Synchrotron Radiation Facility
+# Copyright (c) 2014-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -28,11 +28,7 @@ __license__ = "MIT"
__date__ = "04/10/2016"
-try:
- from collections import abc
-except ImportError: # Python2 support
- import collections as abc
-
+from collections import abc
from ctypes import c_void_p
import logging
@@ -62,10 +58,17 @@ class Texture(object):
:type wrap: OpenGL wrap mode or 2 or 3-tuple of wrap mode
"""
- def __init__(self, internalFormat, data=None, format_=None,
- shape=None, texUnit=0,
- minFilter=None, magFilter=None, wrap=None):
-
+ def __init__(
+ self,
+ internalFormat,
+ data=None,
+ format_=None,
+ shape=None,
+ texUnit=0,
+ minFilter=None,
+ magFilter=None,
+ wrap=None,
+ ):
self._internalFormat = internalFormat
if format_ is None:
format_ = self.internalFormat
@@ -74,7 +77,7 @@ class Texture(object):
assert shape is not None
else:
assert shape is None
- data = numpy.array(data, copy=False, order='C')
+ data = numpy.array(data, copy=False, order="C")
if format_ != gl.GL_RED:
shape = data.shape[:-1] # Last dimension is channels
else:
@@ -164,9 +167,7 @@ class Texture(object):
:rtype: bool
"""
- return (self._name is None or
- self._texParameterUpdates or
- self._deferredUpdates)
+ return self._name is None or self._texParameterUpdates or self._deferredUpdates
def _prepareAndBind(self, texUnit=None):
"""Synchronizes the OpenGL texture"""
@@ -200,10 +201,14 @@ class Texture(object):
if offset is None: # Initialize texture
if self.ndim == 2:
_logger.debug(
- 'Creating 2D texture shape: (%d, %d),'
- ' internal format: %s, format: %s, type: %s',
- self.shape[0], self.shape[1],
- str(self.internalFormat), str(format_), str(type_))
+ "Creating 2D texture shape: (%d, %d),"
+ " internal format: %s, format: %s, type: %s",
+ self.shape[0],
+ self.shape[1],
+ str(self.internalFormat),
+ str(format_),
+ str(type_),
+ )
gl.glTexImage2D(
gl.GL_TEXTURE_2D,
@@ -214,14 +219,20 @@ class Texture(object):
0,
format_,
type_,
- data)
+ data,
+ )
else:
_logger.debug(
- 'Creating 3D texture shape: (%d, %d, %d),'
- ' internal format: %s, format: %s, type: %s',
- self.shape[0], self.shape[1], self.shape[2],
- str(self.internalFormat), str(format_), str(type_))
+ "Creating 3D texture shape: (%d, %d, %d),"
+ " internal format: %s, format: %s, type: %s",
+ self.shape[0],
+ self.shape[1],
+ self.shape[2],
+ str(self.internalFormat),
+ str(format_),
+ str(type_),
+ )
gl.glTexImage3D(
gl.GL_TEXTURE_3D,
@@ -233,32 +244,37 @@ class Texture(object):
0,
format_,
type_,
- data)
+ data,
+ )
else: # Update already existing texture
if self.ndim == 2:
- gl.glTexSubImage2D(gl.GL_TEXTURE_2D,
- 0,
- offset[1],
- offset[0],
- data.shape[1],
- data.shape[0],
- format_,
- type_,
- data)
+ gl.glTexSubImage2D(
+ gl.GL_TEXTURE_2D,
+ 0,
+ offset[1],
+ offset[0],
+ data.shape[1],
+ data.shape[0],
+ format_,
+ type_,
+ data,
+ )
else:
- gl.glTexSubImage3D(gl.GL_TEXTURE_3D,
- 0,
- offset[2],
- offset[1],
- offset[0],
- data.shape[2],
- data.shape[1],
- data.shape[0],
- format_,
- type_,
- data)
+ gl.glTexSubImage3D(
+ gl.GL_TEXTURE_3D,
+ 0,
+ offset[2],
+ offset[1],
+ offset[0],
+ data.shape[2],
+ data.shape[1],
+ data.shape[0],
+ format_,
+ type_,
+ data,
+ )
self._deferredUpdates = []
@@ -340,7 +356,7 @@ class Texture(object):
:param bool copy:
True (default) to copy data, False to use as is (do not modify)
"""
- data = numpy.array(data, copy=copy, order='C')
+ data = numpy.array(data, copy=copy, order="C")
offset = tuple(offset)
assert data.ndim == self.ndim
diff --git a/src/silx/gui/_glutils/VertexBuffer.py b/src/silx/gui/_glutils/VertexBuffer.py
index 65fff86..d71bbeb 100644
--- a/src/silx/gui/_glutils/VertexBuffer.py
+++ b/src/silx/gui/_glutils/VertexBuffer.py
@@ -50,15 +50,12 @@ class VertexBuffer(object):
:param target: Target buffer:
GL_ARRAY_BUFFER (default) or GL_ELEMENT_ARRAY_BUFFER
"""
+
# OpenGL|ES 2.0 subset:
_USAGES = gl.GL_STREAM_DRAW, gl.GL_STATIC_DRAW, gl.GL_DYNAMIC_DRAW
_TARGETS = gl.GL_ARRAY_BUFFER, gl.GL_ELEMENT_ARRAY_BUFFER
- def __init__(self,
- data=None,
- size=None,
- usage=None,
- target=None):
+ def __init__(self, data=None, size=None, usage=None, target=None):
if usage is None:
usage = gl.GL_STATIC_DRAW
assert usage in self._USAGES
@@ -76,20 +73,14 @@ class VertexBuffer(object):
if data is None:
assert size is not None
self._size = size
- gl.glBufferData(self._target,
- self._size,
- c_void_p(0),
- self._usage)
+ gl.glBufferData(self._target, self._size, c_void_p(0), self._usage)
else:
- data = numpy.array(data, copy=False, order='C')
+ data = numpy.array(data, copy=False, order="C")
if size is not None:
assert size <= data.nbytes
self._size = size or data.nbytes
- gl.glBufferData(self._target,
- self._size,
- data,
- self._usage)
+ gl.glBufferData(self._target, self._size, data, self._usage)
gl.glBindBuffer(self._target, 0)
@@ -109,8 +100,10 @@ class VertexBuffer(object):
if self._name is not None:
return self._name
else:
- raise RuntimeError("No OpenGL buffer resource, \
- discard has already been called")
+ raise RuntimeError(
+ "No OpenGL buffer resource, \
+ discard has already been called"
+ )
@property
def size(self):
@@ -118,8 +111,10 @@ class VertexBuffer(object):
if self._size is not None:
return self._size
else:
- raise RuntimeError("No OpenGL buffer resource, \
- discard has already been called")
+ raise RuntimeError(
+ "No OpenGL buffer resource, \
+ discard has already been called"
+ )
def bind(self):
"""Bind the vertex buffer"""
@@ -132,7 +127,7 @@ class VertexBuffer(object):
:param int offset: Offset in bytes in the buffer where to put the data
:param int size: If provided, size of data to copy
"""
- data = numpy.array(data, copy=False, order='C')
+ data = numpy.array(data, copy=False, order="C")
if size is None:
size = data.nbytes
assert offset + size <= self.size
@@ -172,14 +167,9 @@ class VertexBufferAttrib(object):
_GL_TYPES = gl.GL_UNSIGNED_BYTE, gl.GL_FLOAT, gl.GL_INT
- def __init__(self,
- vbo,
- type_,
- size,
- dimension=1,
- offset=0,
- stride=0,
- normalization=False):
+ def __init__(
+ self, vbo, type_, size, dimension=1, offset=0, stride=0, normalization=False
+ ):
self.vbo = vbo
assert type_ in self._GL_TYPES
self.type_ = type_
@@ -201,21 +191,25 @@ class VertexBufferAttrib(object):
"""Call glVertexAttribPointer with objects information"""
normalization = gl.GL_TRUE if self.normalization else gl.GL_FALSE
with self.vbo:
- gl.glVertexAttribPointer(attribute,
- self.dimension,
- self.type_,
- normalization,
- self.stride,
- c_void_p(self.offset))
+ gl.glVertexAttribPointer(
+ attribute,
+ self.dimension,
+ self.type_,
+ normalization,
+ self.stride,
+ c_void_p(self.offset),
+ )
def copy(self):
- return VertexBufferAttrib(self.vbo,
- self.type_,
- self.size,
- self.dimension,
- self.offset,
- self.stride,
- self.normalization)
+ return VertexBufferAttrib(
+ self.vbo,
+ self.type_,
+ self.size,
+ self.dimension,
+ self.offset,
+ self.stride,
+ self.normalization,
+ )
def vertexBuffer(arrays, prefix=None, suffix=None, usage=None):
@@ -241,7 +235,7 @@ def vertexBuffer(arrays, prefix=None, suffix=None, usage=None):
suffix = (0,) * len(arrays)
for data, pre, post in zip(arrays, prefix, suffix):
- data = numpy.array(data, copy=False, order='C')
+ data = numpy.array(data, copy=False, order="C")
shape = data.shape
assert len(shape) <= 2
type_ = numpyToGLType(data.dtype)
@@ -250,8 +244,7 @@ def vertexBuffer(arrays, prefix=None, suffix=None, usage=None):
sizeinbytes = size * dimension * sizeofGLType(type_)
sizeinbytes = 4 * ((sizeinbytes + 3) >> 2) # 4 bytes alignment
copyoffset = vbosize + pre * dimension * sizeofGLType(type_)
- info.append((data, type_, size, dimension,
- vbosize, sizeinbytes, copyoffset))
+ info.append((data, type_, size, dimension, vbosize, sizeinbytes, copyoffset))
vbosize += sizeinbytes
vbo = VertexBuffer(size=vbosize, usage=usage)
@@ -260,6 +253,5 @@ def vertexBuffer(arrays, prefix=None, suffix=None, usage=None):
for data, type_, size, dimension, offset, sizeinbytes, copyoffset in info:
copysize = data.shape[0] * dimension * sizeofGLType(type_)
vbo.update(data, offset=copyoffset, size=copysize)
- result.append(
- VertexBufferAttrib(vbo, type_, size, dimension, offset, 0))
+ result.append(VertexBufferAttrib(vbo, type_, size, dimension, offset, 0))
return result
diff --git a/src/silx/gui/_glutils/__init__.py b/src/silx/gui/_glutils/__init__.py
index a7a4bee..9526ba4 100644
--- a/src/silx/gui/_glutils/__init__.py
+++ b/src/silx/gui/_glutils/__init__.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2014-2019 European Synchrotron Radiation Facility
+# Copyright (c) 2014-2022 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
diff --git a/src/silx/gui/_glutils/font.py b/src/silx/gui/_glutils/font.py
index bee9745..4c0268e 100644
--- a/src/silx/gui/_glutils/font.py
+++ b/src/silx/gui/_glutils/font.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2016-2022 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -28,162 +28,12 @@ __license__ = "MIT"
__date__ = "13/10/2016"
-import logging
-import numpy
-
from .. import qt
-from ..utils.image import convertQImageToArray
-
-try:
- from ..utils.matplotlib import rasterMathText
-except ImportError:
- rasterMathText = None
-_logger = logging.getLogger(__name__)
+# Expose rasterMathText as part of this module
+from ..utils.matplotlib import rasterMathText as rasterText # noqa
-def getDefaultFontFamily():
+def getDefaultFontFamily() -> str:
"""Returns the default font family of the application"""
return qt.QApplication.instance().font().family()
-
-
-# Font weights
-ULTRA_LIGHT = 0
-"""Lightest characters: Minimum font weight"""
-
-LIGHT = 25
-"""Light characters"""
-
-NORMAL = 50
-"""Normal characters"""
-
-SEMI_BOLD = 63
-"""Between normal and bold characters"""
-
-BOLD = 74
-"""Thicker characters"""
-
-BLACK = 87
-"""Really thick characters"""
-
-ULTRA_BLACK = 99
-"""Thickest characters: Maximum font weight"""
-
-
-def rasterTextQt(text, font, size=-1, weight=-1, italic=False, devicePixelRatio=1.0):
- """Raster text using Qt.
-
- It supports multiple lines.
-
- :param str text: The text to raster
- :param font: Font name or QFont to use
- :type font: str or :class:`QFont`
- :param int size:
- Font size in points
- Used only if font is given as name.
- :param int weight:
- Font weight in [0, 99], see QFont.Weight.
- Used only if font is given as name.
- :param bool italic:
- True for italic font (default: False).
- Used only if font is given as name.
- :param float devicePixelRatio:
- The current ratio between device and device-independent pixel
- (default: 1.0)
- :return: Corresponding image in gray scale and baseline offset from top
- :rtype: (HxW numpy.ndarray of uint8, int)
- """
- if not text:
- _logger.info("Trying to raster empty text, replaced by white space")
- text = " " # Replace empty text by white space to produce an image
-
- if not isinstance(font, qt.QFont):
- font = qt.QFont(font, size, weight, italic)
-
- # get text size
- image = qt.QImage(1, 1, qt.QImage.Format_RGB888)
- painter = qt.QPainter()
- painter.begin(image)
- painter.setPen(qt.Qt.white)
- painter.setFont(font)
- bounds = painter.boundingRect(
- qt.QRect(0, 0, 4096, 4096), qt.Qt.TextExpandTabs, text
- )
- painter.end()
-
- metrics = qt.QFontMetrics(font)
-
- # This does not provide the correct text bbox on macOS
- # size = metrics.size(qt.Qt.TextExpandTabs, text)
- # bounds = metrics.boundingRect(
- # qt.QRect(0, 0, size.width(), size.height()),
- # qt.Qt.TextExpandTabs,
- # text)
-
- # Add extra border and handle devicePixelRatio
- width = bounds.width() * devicePixelRatio + 2
- # align line size to 32 bits to ease conversion to numpy array
- width = 4 * ((width + 3) // 4)
- image = qt.QImage(
- int(width), int(bounds.height() * devicePixelRatio + 2), qt.QImage.Format_RGB888
- )
- image.setDevicePixelRatio(devicePixelRatio)
-
- # TODO if Qt5 use Format_Grayscale8 instead
- image.fill(0)
-
- # Raster text
- painter = qt.QPainter()
- painter.begin(image)
- painter.setPen(qt.Qt.white)
- painter.setFont(font)
- painter.drawText(bounds, qt.Qt.TextExpandTabs, text)
- painter.end()
-
- array = convertQImageToArray(image)
-
- # RGB to R
- array = numpy.ascontiguousarray(array[:, :, 0])
-
- # Remove leading and trailing empty columns/rows but one on each side
- filled_rows = numpy.nonzero(numpy.sum(array, axis=1))[0]
- filled_columns = numpy.nonzero(numpy.sum(array, axis=0))[0]
- if len(filled_rows) == 0 or len(filled_columns) == 0:
- return array, metrics.ascent()
-
- min_row = max(0, filled_rows[0] - 1)
- array = array[
- min_row : filled_rows[-1] + 2,
- max(0, filled_columns[0] - 1) : filled_columns[-1] + 2,
- ]
-
- return array, metrics.ascent() - min_row
-
-
-def rasterText(text, font, size=-1, weight=-1, italic=False, devicePixelRatio=1.0):
- """Raster text using Qt or matplotlib if there may be math syntax.
-
- It supports multiple lines.
-
- :param str text: The text to raster
- :param font: Font name or QFont to use
- :type font: str or :class:`QFont`
- :param int size:
- Font size in points
- Used only if font is given as name.
- :param int weight:
- Font weight in [0, 99], see QFont.Weight.
- Used only if font is given as name.
- :param bool italic:
- True for italic font (default: False).
- Used only if font is given as name.
- :param float devicePixelRatio:
- The current ratio between device and device-independent pixel
- (default: 1.0)
- :return: Corresponding image in gray scale and baseline offset from top
- :rtype: (HxW numpy.ndarray of uint8, int)
- """
- if rasterMathText is not None and text.count("$") >= 2:
- return rasterMathText(text, font, size, weight, italic, devicePixelRatio)
- else:
- return rasterTextQt(text, font, size, weight, italic, devicePixelRatio)
diff --git a/src/silx/gui/_glutils/gl.py b/src/silx/gui/_glutils/gl.py
index fb0a3fa..aff7617 100644
--- a/src/silx/gui/_glutils/gl.py
+++ b/src/silx/gui/_glutils/gl.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2014-2017 European Synchrotron Radiation Facility
+# Copyright (c) 2014-2022 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -31,13 +31,19 @@ __date__ = "25/07/2016"
from contextlib import contextmanager as _contextmanager
from ctypes import c_uint
import logging
+import sys
+from typing import Optional
+
+from packaging.version import Version
+
_logger = logging.getLogger(__name__)
import OpenGL
+
# Set the following to true for debugging
if _logger.getEffectiveLevel() <= logging.DEBUG:
- _logger.debug('Enabling PyOpenGL debug flags')
+ _logger.debug("Enabling PyOpenGL debug flags")
OpenGL.ERROR_LOGGING = True
OpenGL.ERROR_CHECKING = True
OpenGL.ERROR_ON_COPY = True
@@ -46,8 +52,14 @@ else:
OpenGL.ERROR_CHECKING = False
OpenGL.ERROR_ON_COPY = False
+if sys.version_info >= (3, 12) and Version(OpenGL.__version__) <= Version("3.1.7"):
+ # Python3.12 patch: see https://github.com/mcfletch/pyopengl/pull/100
+ OpenGL.FormatHandler.by_name("ctypesparameter").check.append("_ctypes.CArgObject")
+
+
import OpenGL.GL as _GL
from OpenGL.GL import * # noqa
+import OpenGL.platform
# Extentions core in OpenGL 3
from OpenGL.GL.ARB import framebuffer_object as _FBO
@@ -60,9 +72,22 @@ try:
GLchar
except NameError:
from ctypes import c_char
+
GLchar = c_char
+def getPlatform() -> Optional[str]:
+ """Returns the name of the PyOpenGL class handling the platform.
+
+ E.g., GLXPlatform, EGLPlatform
+ """
+ try:
+ platform = OpenGL.platform.PLATFORM
+ except AttributeError:
+ return None
+ return platform.__class__.__name__
+
+
def getVersion() -> tuple:
"""Returns the GL version as tuple of integers.
@@ -74,7 +99,7 @@ def getVersion() -> tuple:
if isinstance(desc, bytes):
desc = desc.decode("ascii")
version = desc.split(" ", 1)[0]
- return tuple([int(i) for i in version.split('.')])
+ return tuple([int(i) for i in version.split(".")])
except Exception as e:
raise ValueError("GL version not properly formatted") from e
@@ -90,21 +115,23 @@ def testGL() -> bool:
_logger.error("OpenGL version >=2.1 required, running with %s" % version)
return False
- from OpenGL.GL.ARB.framebuffer_object import glInitFramebufferObjectARB
- from OpenGL.GL.ARB.texture_rg import glInitTextureRgARB
+ if major == 2:
+ from OpenGL.GL.ARB.framebuffer_object import glInitFramebufferObjectARB
+ from OpenGL.GL.ARB.texture_rg import glInitTextureRgARB
- if not glInitFramebufferObjectARB():
- _logger.error("OpenGL GL_ARB_framebuffer_object extension required!")
- return False
+ if not glInitFramebufferObjectARB():
+ _logger.error("OpenGL GL_ARB_framebuffer_object extension required!")
+ return False
+
+ if not glInitTextureRgARB():
+ _logger.error("OpenGL GL_ARB_texture_rg extension required!")
+ return False
- if not glInitTextureRgARB():
- _logger.error("OpenGL GL_ARB_texture_rg extension required!")
- return False
return True
# Additional setup
-if hasattr(glget, 'addGLGetConstant'):
+if hasattr(glget, "addGLGetConstant"):
glget.addGLGetConstant(GL_FRAMEBUFFER_BINDING, (1,))
@@ -145,6 +172,7 @@ def disabled(capacity, disable=True):
# Additional OpenGL wrapping
+
def glGetActiveAttrib(program, index):
"""Wrap PyOpenGL glGetActiveAttrib"""
bufsize = glGetProgramiv(program, GL_ACTIVE_ATTRIBUTE_MAX_LENGTH)
@@ -158,28 +186,28 @@ def glGetActiveAttrib(program, index):
def glDeleteRenderbuffers(buffers):
- if not hasattr(buffers, '__len__'): # Support single int argument
+ if not hasattr(buffers, "__len__"): # Support single int argument
buffers = [buffers]
length = len(buffers)
_FBO.glDeleteRenderbuffers(length, (c_uint * length)(*buffers))
def glDeleteFramebuffers(buffers):
- if not hasattr(buffers, '__len__'): # Support single int argument
+ if not hasattr(buffers, "__len__"): # Support single int argument
buffers = [buffers]
length = len(buffers)
_FBO.glDeleteFramebuffers(length, (c_uint * length)(*buffers))
def glDeleteBuffers(buffers):
- if not hasattr(buffers, '__len__'): # Support single int argument
+ if not hasattr(buffers, "__len__"): # Support single int argument
buffers = [buffers]
length = len(buffers)
_GL.glDeleteBuffers(length, (c_uint * length)(*buffers))
def glDeleteTextures(textures):
- if not hasattr(textures, '__len__'): # Support single int argument
+ if not hasattr(textures, "__len__"): # Support single int argument
textures = [textures]
length = len(textures)
_GL.glDeleteTextures((c_uint * length)(*textures))
diff --git a/src/silx/gui/_glutils/test/__init__.py b/src/silx/gui/_glutils/test/__init__.py
index 5ad4c28..e9dd44d 100644
--- a/src/silx/gui/_glutils/test/__init__.py
+++ b/src/silx/gui/_glutils/test/__init__.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2015-2022 European Synchrotron Radiation Facility
+# Copyright (c) 2015-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
diff --git a/src/silx/gui/_glutils/test/test_gl.py b/src/silx/gui/_glutils/test/test_gl.py
index be9332b..d719c08 100644
--- a/src/silx/gui/_glutils/test/test_gl.py
+++ b/src/silx/gui/_glutils/test/test_gl.py
@@ -26,11 +26,11 @@ from .. import gl
def test_version_bytes(mocker):
- mocker.patch('silx.gui._glutils.gl.glGetString', return_value=b"3.0 Mock")
+ mocker.patch("silx.gui._glutils.gl.glGetString", return_value=b"3.0 Mock")
assert gl.getVersion() == (3, 0)
def test_version_str(mocker):
"""In case glGetString returns str"""
- mocker.patch('silx.gui._glutils.gl.glGetString', return_value="3.0 Mock")
+ mocker.patch("silx.gui._glutils.gl.glGetString", return_value="3.0 Mock")
assert gl.getVersion() == (3, 0)
diff --git a/src/silx/gui/_glutils/utils.py b/src/silx/gui/_glutils/utils.py
index 49b431a..56ac935 100644
--- a/src/silx/gui/_glutils/utils.py
+++ b/src/silx/gui/_glutils/utils.py
@@ -94,8 +94,9 @@ def segmentTrianglesIntersection(segment, triangles):
subVolumes[:, 2] = numpy.sum(t0s0CrossEdge01 * d, axis=1)
subVolumes[:, 0] = volume - subVolumes[:, 1] - subVolumes[:, 2]
intersect = numpy.logical_or(
- numpy.all(subVolumes >= 0., axis=1), # All positive
- numpy.all(subVolumes <= 0., axis=1)) # All negative
+ numpy.all(subVolumes >= 0.0, axis=1), # All positive
+ numpy.all(subVolumes <= 0.0, axis=1),
+ ) # All negative
intersect = numpy.where(intersect)[0] # Indices of intersected triangles
# Get barycentric coordinates
@@ -112,7 +113,7 @@ def segmentTrianglesIntersection(segment, triangles):
del volAlpha
del volume
- inSegmentMask = numpy.logical_and(t >= 0., t <= 1.)
+ inSegmentMask = numpy.logical_and(t >= 0.0, t <= 1.0)
intersect = intersect[inSegmentMask]
t = t[inSegmentMask]
barycentric = barycentric[inSegmentMask]
diff --git a/src/silx/gui/colors.py b/src/silx/gui/colors.py
index 4a5f278..b47fa85 100755
--- a/src/silx/gui/colors.py
+++ b/src/silx/gui/colors.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2015-2021 European Synchrotron Radiation Facility
+# Copyright (c) 2015-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -24,26 +24,39 @@
"""This module provides API to manage colors.
"""
+from __future__ import annotations
+
__authors__ = ["T. Vincent", "H.Payno"]
__license__ = "MIT"
__date__ = "29/01/2019"
+
import numpy
import logging
+import numbers
+import re
+from collections.abc import Iterable
+from typing import Any, Sequence, Tuple, Union
+import silx
from silx.gui import qt
from silx.gui.utils import blockSignals
from silx.math import colormap as _colormap
from silx.utils.exceptions import NotEditableError
-from silx.utils import deprecation
_logger = logging.getLogger(__name__)
try:
import silx.gui.utils.matplotlib # noqa Initalize matplotlib
- from matplotlib import cm as _matplotlib_cm
- from matplotlib.pyplot import colormaps as _matplotlib_colormaps
+
+ try:
+ from matplotlib import colormaps as _matplotlib_colormaps
+ except ImportError: # For matplotlib < 3.5
+ from matplotlib import cm as _matplotlib_cm
+ from matplotlib.pyplot import colormaps as _matplotlib_colormaps
+ else:
+ _matplotlib_cm = None
except ImportError:
_logger.info("matplotlib not available, only embedded colormaps available")
_matplotlib_cm = None
@@ -53,29 +66,29 @@ except ImportError:
_COLORDICT = {}
"""Dictionary of common colors."""
-_COLORDICT['b'] = _COLORDICT['blue'] = '#0000ff'
-_COLORDICT['r'] = _COLORDICT['red'] = '#ff0000'
-_COLORDICT['g'] = _COLORDICT['green'] = '#00ff00'
-_COLORDICT['k'] = _COLORDICT['black'] = '#000000'
-_COLORDICT['w'] = _COLORDICT['white'] = '#ffffff'
-_COLORDICT['pink'] = '#ff66ff'
-_COLORDICT['brown'] = '#a52a2a'
-_COLORDICT['orange'] = '#ff9900'
-_COLORDICT['violet'] = '#6600ff'
-_COLORDICT['gray'] = _COLORDICT['grey'] = '#a0a0a4'
+_COLORDICT["b"] = _COLORDICT["blue"] = "#0000ff"
+_COLORDICT["r"] = _COLORDICT["red"] = "#ff0000"
+_COLORDICT["g"] = _COLORDICT["green"] = "#00ff00"
+_COLORDICT["k"] = _COLORDICT["black"] = "#000000"
+_COLORDICT["w"] = _COLORDICT["white"] = "#ffffff"
+_COLORDICT["pink"] = "#ff66ff"
+_COLORDICT["brown"] = "#a52a2a"
+_COLORDICT["orange"] = "#ff9900"
+_COLORDICT["violet"] = "#6600ff"
+_COLORDICT["gray"] = _COLORDICT["grey"] = "#a0a0a4"
# _COLORDICT['darkGray'] = _COLORDICT['darkGrey'] = '#808080'
# _COLORDICT['lightGray'] = _COLORDICT['lightGrey'] = '#c0c0c0'
-_COLORDICT['y'] = _COLORDICT['yellow'] = '#ffff00'
-_COLORDICT['m'] = _COLORDICT['magenta'] = '#ff00ff'
-_COLORDICT['c'] = _COLORDICT['cyan'] = '#00ffff'
-_COLORDICT['darkBlue'] = '#000080'
-_COLORDICT['darkRed'] = '#800000'
-_COLORDICT['darkGreen'] = '#008000'
-_COLORDICT['darkBrown'] = '#660000'
-_COLORDICT['darkCyan'] = '#008080'
-_COLORDICT['darkYellow'] = '#808000'
-_COLORDICT['darkMagenta'] = '#800080'
-_COLORDICT['transparent'] = '#00000000'
+_COLORDICT["y"] = _COLORDICT["yellow"] = "#ffff00"
+_COLORDICT["m"] = _COLORDICT["magenta"] = "#ff00ff"
+_COLORDICT["c"] = _COLORDICT["cyan"] = "#00ffff"
+_COLORDICT["darkBlue"] = "#000080"
+_COLORDICT["darkRed"] = "#800000"
+_COLORDICT["darkGreen"] = "#008000"
+_COLORDICT["darkBrown"] = "#660000"
+_COLORDICT["darkCyan"] = "#008080"
+_COLORDICT["darkYellow"] = "#808000"
+_COLORDICT["darkMagenta"] = "#800080"
+_COLORDICT["transparent"] = "#00000000"
# FIXME: It could be nice to expose a functional API instead of that attribute
@@ -88,109 +101,149 @@ DEFAULT_MAX_LIN = 1
"""Default max value if in linear normalization"""
-def rgba(color, colorDict=None):
- """Convert color code '#RRGGBB' and '#RRGGBBAA' to a tuple (R, G, B, A)
- of floats.
+_INDEXED_COLOR_PATTERN = re.compile(r"C(?P<index>[0-9]+)")
- It also supports RGB(A) from uint8 in [0, 255], float in [0, 1], and
- QColor as color argument.
- :param str color: The color to convert
- :param dict colorDict: A dictionary of color name conversion to color code
- :returns: RGBA colors as floats in [0., 1.]
- :rtype: tuple
- """
- if colorDict is None:
- colorDict = _COLORDICT
+ColorType = Union[str, Sequence[numbers.Real], qt.QColor]
+"""Type of :func:`rgba`'s color argument"""
- if hasattr(color, 'getRgb'): # QColor support
- color = color.getRgb()
- values = numpy.asarray(color).ravel()
+RGBAColorType = Tuple[float, float, float, float]
+"""Type of :func:`rgba` return value"""
- if values.dtype.kind in 'iuf': # integer or float
- # Color is an array
- assert len(values) in (3, 4)
- # Convert from integers in [0, 255] to float in [0, 1]
- if values.dtype.kind in 'iu':
- values = values / 255.
+def rgba(
+ color: ColorType,
+ colorDict: dict[str, str] | None = None,
+ colors: Sequence[str] | None = None,
+) -> RGBAColorType:
+ """Convert different kind of color definition to a tuple (R, G, B, A) of floats.
- # Clip to [0, 1]
- values[values < 0.] = 0.
- values[values > 1.] = 1.
+ It supports:
+ - color names: e.g., 'green'
+ - color codes: '#RRGGBB' and '#RRGGBBAA'
+ - indexed color names: e.g., 'C0'
+ - RGB(A) sequence of uint8 in [0, 255] or float in [0, 1]
+ - QColor
- if len(values) == 3:
- return values[0], values[1], values[2], 1.
- else:
- return tuple(values)
+ :param color: The color to convert
+ :param colorDict: A dictionary of color name conversion to color code
+ :param colors: Sequence of colors to use for `
+ :returns: RGBA colors as floats in [0., 1.]
+ :raises ValueError: if the input is not a valid color
+ """
+ if isinstance(color, str):
+ # From name
+ colorFromDict = (_COLORDICT if colorDict is None else colorDict).get(color)
+ if colorFromDict is not None:
+ return rgba(colorFromDict, colorDict, colors)
+
+ # From indexed color name: color{index}
+ match = _INDEXED_COLOR_PATTERN.fullmatch(color)
+ if match is not None:
+ if colors is None:
+ colors = silx.config.DEFAULT_PLOT_CURVE_COLORS
+ index = int(match["index"]) % len(colors)
+ return rgba(colors[index], colorDict, colors)
+
+ # From #code
+ if len(color) in (7, 9) and color[0] == "#":
+ r = int(color[1:3], 16) / 255.0
+ g = int(color[3:5], 16) / 255.0
+ b = int(color[5:7], 16) / 255.0
+ a = int(color[7:9], 16) / 255.0 if len(color) == 9 else 1.0
+ return r, g, b, a
+
+ raise ValueError(f"The string '{color}' is not a valid color")
+
+ # From QColor
+ if isinstance(color, qt.QColor):
+ return rgba(color.getRgb(), colorDict, colors)
+
+ # From array
+ values = numpy.asarray(color).ravel()
+
+ if values.dtype.kind not in "iuf":
+ raise ValueError(
+ f"The array color must be integer/unsigned or float. Found '{values.dtype.kind}'"
+ )
+ if len(values) not in (3, 4):
+ raise ValueError(
+ f"The array color must have 3 or 4 compound. Found '{len(values)}'"
+ )
- # We assume color is a string
- if not color.startswith('#'):
- color = colorDict[color]
+ # Convert from integers in [0, 255] to float in [0, 1]
+ if values.dtype.kind in "iu":
+ values = values / 255.0
- assert len(color) in (7, 9) and color[0] == '#'
- r = int(color[1:3], 16) / 255.
- g = int(color[3:5], 16) / 255.
- b = int(color[5:7], 16) / 255.
- a = int(color[7:9], 16) / 255. if len(color) == 9 else 1.
- return r, g, b, a
+ values = numpy.clip(values, 0.0, 1.0)
+ if len(values) == 3:
+ return values[0], values[1], values[2], 1.0
+ return tuple(values)
-def greyed(color, colorDict=None):
+
+def greyed(
+ color: ColorType,
+ colorDict: dict[str, str] | None = None,
+) -> RGBAColorType:
"""Convert color code '#RRGGBB' and '#RRGGBBAA' to a grey color
(R, G, B, A).
It also supports RGB(A) from uint8 in [0, 255], float in [0, 1], and
QColor as color argument.
- :param str color: The color to convert
- :param dict colorDict: A dictionary of color name conversion to color code
+ :param color: The color to convert
+ :param colorDict: A dictionary of color name conversion to color code
:returns: RGBA colors as floats in [0., 1.]
- :rtype: tuple
"""
r, g, b, a = rgba(color=color, colorDict=colorDict)
g = 0.21 * r + 0.72 * g + 0.07 * b
return g, g, g, a
-def asQColor(color):
+def asQColor(color: ColorType) -> qt.QColor:
"""Convert color code '#RRGGBB' and '#RRGGBBAA' to a `qt.QColor`.
It also supports RGB(A) from uint8 in [0, 255], float in [0, 1], and
QColor as color argument.
- :param str color: The color to convert
- :rtype: qt.QColor
+ :param color: The color to convert
"""
color = rgba(color)
return qt.QColor.fromRgbF(*color)
-def cursorColorForColormap(colormapName):
+def cursorColorForColormap(colormapName: str) -> str:
"""Get a color suitable for overlay over a colormap.
- :param str colormapName: The name of the colormap.
+ :param colormapName: The name of the colormap.
:return: Name of the color.
- :rtype: str
"""
return _colormap.get_colormap_cursor_color(colormapName)
# Colormap loader
-def _registerColormapFromMatplotlib(name, cursor_color='black', preferred=False):
- colormap = _matplotlib_cm.get_cmap(name)
+
+def _registerColormapFromMatplotlib(
+ name: str,
+ cursor_color: str = "black",
+ preferred: bool = False,
+):
+ if _matplotlib_cm is not None:
+ colormap = _matplotlib_cm.get_cmap(name)
+ else: # matplotlib >= 3.5
+ colormap = _matplotlib_colormaps[name]
lut = colormap(numpy.linspace(0, 1, colormap.N, endpoint=True))
colors = _colormap.array_to_rgba8888(lut)
registerLUT(name, colors, cursor_color, preferred)
-def _getColormap(name):
+def _getColormap(name: str) -> numpy.ndarray:
"""Returns the color LUT corresponding to a colormap name
- :param str name: Name of the colormap to load
+ :param name: Name of the colormap to load
:returns: Corresponding table of colors
- :rtype: numpy.ndarray
:raise ValueError: If no colormap corresponds to name
"""
name = str(name)
@@ -198,40 +251,65 @@ def _getColormap(name):
return _colormap.get_colormap_lut(name)
except ValueError:
# Colormap is not available, try to load it from matplotlib
- _registerColormapFromMatplotlib(name, 'black', False)
+ _registerColormapFromMatplotlib(name, "black", False)
return _colormap.get_colormap_lut(name)
+class _Colormappable:
+ """Class for objects that can be colormapped by a :class:`Colormap`
+
+ Used by silx.gui.plot.items.core.ColormapMixIn
+ """
+
+ def _getColormapAutoscaleRange(
+ self,
+ colormap: Colormap | None,
+ ) -> tuple[float | None, float | None]:
+ """Returns the autoscale range for given colormap.
+
+ :param colormap:
+ The colormap for which to compute the autoscale range.
+ If None, the default, the colormap of the item is used
+ :return: (vmin, vmax) range
+ """
+ raise NotImplementedError("This method must be implemented in subclass")
+
+ def getColormappedData(copy: bool = False) -> numpy.ndarray | None:
+ """Returns the data used to compute the displayed colors
+
+ :param copy: True to get a copy, False to get internal data (do not modify!).
+ """
+ raise NotImplementedError("This method must be implemented in subclass")
+
+
class Colormap(qt.QObject):
"""Description of a colormap
If no `name` nor `colors` are provided, a default gray LUT is used.
- :param str name: Name of the colormap
- :param tuple colors: optional, custom colormap.
+ :param name: Name of the colormap
+ :param colors: optional, custom colormap.
Nx3 or Nx4 numpy array of RGB(A) colors,
either uint8 or float in [0, 1].
If 'name' is None, then this array is used as the colormap.
- :param str normalization: Normalization: 'linear' (default) or 'log'
+ :param normalization: Normalization: 'linear' (default) or 'log'
:param vmin: Lower bound of the colormap or None for autoscale (default)
- :type vmin: Union[None, float]
:param vmax: Upper bounds of the colormap or None for autoscale (default)
- :type vmax: Union[None, float]
"""
- LINEAR = 'linear'
+ LINEAR = "linear"
"""constant for linear normalization"""
- LOGARITHM = 'log'
+ LOGARITHM = "log"
"""constant for logarithmic normalization"""
- SQRT = 'sqrt'
+ SQRT = "sqrt"
"""constant for square root normalization"""
- GAMMA = 'gamma'
+ GAMMA = "gamma"
"""Constant for gamma correction normalization"""
- ARCSINH = 'arcsinh'
+ ARCSINH = "arcsinh"
"""constant for inverse hyperbolic sine normalization"""
_BASIC_NORMALIZATIONS = {
@@ -239,16 +317,16 @@ class Colormap(qt.QObject):
LOGARITHM: _colormap.LogarithmicNormalization(),
SQRT: _colormap.SqrtNormalization(),
ARCSINH: _colormap.ArcsinhNormalization(),
- }
+ }
"""Normalizations without parameters"""
NORMALIZATIONS = LINEAR, LOGARITHM, SQRT, GAMMA, ARCSINH
"""Tuple of managed normalizations"""
- MINMAX = 'minmax'
+ MINMAX = "minmax"
"""constant for autoscale using min/max data range"""
- STDDEV3 = 'stddev3'
+ STDDEV3 = "stddev3"
"""constant for autoscale using mean +/- 3*std(data)
with a clamp on min/max of the data"""
@@ -260,7 +338,15 @@ class Colormap(qt.QObject):
_DEFAULT_NAN_COLOR = 255, 255, 255, 0
- def __init__(self, name=None, colors=None, normalization=LINEAR, vmin=None, vmax=None, autoscaleMode=MINMAX):
+ def __init__(
+ self,
+ name: str | None = None,
+ colors: numpy.ndarray | None = None,
+ normalization: str = LINEAR,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ autoscaleMode: str = MINMAX,
+ ):
qt.QObject.__init__(self)
self._editable = True
self.__gamma = 2.0
@@ -273,7 +359,7 @@ class Colormap(qt.QObject):
if normalization is Colormap.LOGARITHM:
if (vmin is not None and vmin < 0) or (vmax is not None and vmax < 0):
m = "Unsuported vmin (%s) and/or vmax (%s) given for a log scale."
- m += ' Autoscale will be performed.'
+ m += " Autoscale will be performed."
m = m % (vmin, vmax)
_logger.warning(m)
vmin = None
@@ -283,13 +369,7 @@ class Colormap(qt.QObject):
self._colors = None
if colors is not None and name is not None:
- deprecation.deprecated_warning("Argument",
- name="silx.gui.plot.Colors",
- reason="name and colors can't be used at the same time",
- since_version="0.10.0",
- skip_backtrace_count=1)
-
- colors = None
+ raise ValueError("name and colors arguments can't be set at the same time")
if name is not None:
self.setName(name) # And resets colormap LUT
@@ -306,13 +386,13 @@ class Colormap(qt.QObject):
self.__warnBadVmin = True
self.__warnBadVmax = True
- def setFromColormap(self, other):
+ def setFromColormap(self, other: Colormap):
"""Set this colormap using information from the `other` colormap.
- :param ~silx.gui.colors.Colormap other: Colormap to use as reference.
+ :param other: Colormap to use as reference.
"""
if not self.isEditable():
- raise NotEditableError('Colormap is not editable')
+ raise NotEditableError("Colormap is not editable")
if self == other:
return
with blockSignals(self):
@@ -323,22 +403,19 @@ class Colormap(qt.QObject):
self.setColormapLUT(other.getColormapLUT())
self.setNaNColor(other.getNaNColor())
self.setNormalization(other.getNormalization())
- self.setGammaNormalizationParameter(
- other.getGammaNormalizationParameter())
+ self.setGammaNormalizationParameter(other.getGammaNormalizationParameter())
self.setAutoscaleMode(other.getAutoscaleMode())
self.setVRange(*other.getVRange())
self.setEditable(other.isEditable())
self.sigChanged.emit()
- def getNColors(self, nbColors=None):
+ def getNColors(self, nbColors: int | None = None) -> numpy.ndarray:
"""Returns N colors computed by sampling the colormap regularly.
:param nbColors:
The number of colors in the returned array or None for the default value.
The default value is the size of the colormap LUT.
- :type nbColors: int or None
:return: 2D array of uint8 of shape (nbColors, 4)
- :rtype: numpy.ndarray
"""
# Handle default value for nbColors
if nbColors is None:
@@ -348,20 +425,17 @@ class Colormap(qt.QObject):
colormap = self.copy()
colormap.setNormalization(Colormap.LINEAR)
colormap.setVRange(vmin=0, vmax=nbColors - 1)
- colors = colormap.applyToData(
- numpy.arange(nbColors, dtype=numpy.int32))
+ colors = colormap.applyToData(numpy.arange(nbColors, dtype=numpy.int32))
return colors
- def getName(self):
- """Return the name of the colormap
- :rtype: str
- """
+ def getName(self) -> str | None:
+ """Return the name of the colormap"""
return self._name
- def setName(self, name):
+ def setName(self, name: str):
"""Set the name of the colormap to use.
- :param str name: The name of the colormap.
+ :param name: The name of the colormap.
At least the following names are supported: 'gray',
'reversed gray', 'temperature', 'red', 'green', 'blue', 'jet',
'viridis', 'magma', 'inferno', 'plasma'.
@@ -370,44 +444,45 @@ class Colormap(qt.QObject):
if self._name == name:
return
if self.isEditable() is False:
- raise NotEditableError('Colormap is not editable')
+ raise NotEditableError("Colormap is not editable")
if name not in self.getSupportedColormaps():
raise ValueError("Colormap name '%s' is not supported" % name)
self._name = name
self._colors = _getColormap(self._name)
self.sigChanged.emit()
- def getColormapLUT(self, copy=True):
+ def getColormapLUT(self, copy: bool = True) -> numpy.ndarray | None:
"""Return the list of colors for the colormap or None if not set.
This returns None if the colormap was set with :meth:`setName`.
Use :meth:`getNColors` to get the colormap LUT for any colormap.
- :param bool copy: If true a copy of the numpy array is provided
+ :param copy: If true a copy of the numpy array is provided
:return: the list of colors for the colormap or None if not set
- :rtype: numpy.ndarray or None
"""
if self._name is None:
return numpy.array(self._colors, copy=copy)
- else:
- return None
+ return None
- def setColormapLUT(self, colors):
+ def setColormapLUT(self, colors: numpy.ndarray):
"""Set the colors of the colormap.
- :param numpy.ndarray colors: the colors of the LUT.
+ :param colors: the colors of the LUT.
If float, it is converted from [0, 1] to uint8 range.
Otherwise it is casted to uint8.
.. warning: this will set the value of name to None
"""
if self.isEditable() is False:
- raise NotEditableError('Colormap is not editable')
+ raise NotEditableError("Colormap is not editable")
assert colors is not None
colors = numpy.array(colors, copy=False)
if colors.shape == ():
- raise TypeError("An array is expected for 'colors' argument. '%s' was found." % type(colors))
+ raise TypeError(
+ "An array is expected for 'colors' argument. '%s' was found."
+ % type(colors)
+ )
assert len(colors) != 0
assert colors.ndim >= 2
colors.shape = -1, colors.shape[-1]
@@ -415,44 +490,39 @@ class Colormap(qt.QObject):
self._name = None
self.sigChanged.emit()
- def getNaNColor(self):
- """Returns the color to use for Not-A-Number floating point value.
-
- :rtype: QColor
- """
+ def getNaNColor(self) -> qt.QColor:
+ """Returns the color to use for Not-A-Number floating point value."""
return qt.QColor(*self.__nanColor)
- def setNaNColor(self, color):
+ def setNaNColor(self, color: ColorType):
"""Set the color to use for Not-A-Number floating point value.
:param color: RGB(A) color to use for NaN values
- :type color: QColor, str, tuple of uint8 or float in [0., 1.]
"""
color = (numpy.array(rgba(color)) * 255).astype(numpy.uint8)
if not numpy.array_equal(self.__nanColor, color):
self.__nanColor = color
self.sigChanged.emit()
- def getNormalization(self):
+ def getNormalization(self) -> str:
"""Return the normalization of the colormap.
See :meth:`setNormalization` for returned values.
:return: the normalization of the colormap
- :rtype: str
"""
return self._normalization
- def setNormalization(self, norm):
+ def setNormalization(self, norm: str):
"""Set the colormap normalization.
Accepted normalizations: 'log', 'linear', 'sqrt'
- :param str norm: the norm to set
+ :param norm: the norm to set
"""
assert norm in self.NORMALIZATIONS
if self.isEditable() is False:
- raise NotEditableError('Colormap is not editable')
+ raise NotEditableError("Colormap is not editable")
norm = str(norm)
if norm != self._normalization:
self._normalization = norm
@@ -460,71 +530,63 @@ class Colormap(qt.QObject):
self.__warnBadVmax = True
self.sigChanged.emit()
- def setGammaNormalizationParameter(self, gamma: float) -> None:
+ def setGammaNormalizationParameter(self, gamma: float):
"""Set the gamma correction parameter.
Only used for gamma correction normalization.
- :param float gamma:
:raise ValueError: If gamma is not valid
"""
- if gamma < 0. or not numpy.isfinite(gamma):
+ if gamma < 0.0 or not numpy.isfinite(gamma):
raise ValueError("Gamma value not supported")
if gamma != self.__gamma:
self.__gamma = gamma
self.sigChanged.emit()
def getGammaNormalizationParameter(self) -> float:
- """Returns the gamma correction parameter value.
-
- :rtype: float
- """
+ """Returns the gamma correction parameter value."""
return self.__gamma
- def getAutoscaleMode(self):
- """Return the autoscale mode of the colormap ('minmax' or 'stddev3')
-
- :rtype: str
- """
+ def getAutoscaleMode(self) -> str:
+ """Return the autoscale mode of the colormap ('minmax' or 'stddev3')"""
return self._autoscaleMode
- def setAutoscaleMode(self, mode):
+ def setAutoscaleMode(self, mode: str):
"""Set the autoscale mode: either 'minmax' or 'stddev3'
- :param str mode: the mode to set
+ :param mode: the mode to set
"""
if self.isEditable() is False:
- raise NotEditableError('Colormap is not editable')
+ raise NotEditableError("Colormap is not editable")
assert mode in self.AUTOSCALE_MODES
if mode != self._autoscaleMode:
self._autoscaleMode = mode
self.sigChanged.emit()
- def isAutoscale(self):
+ def isAutoscale(self) -> bool:
"""Return True if both min and max are in autoscale mode"""
return self._vmin is None and self._vmax is None
- def getVMin(self):
+ def getVMin(self) -> float | None:
"""Return the lower bound of the colormap
- :return: the lower bound of the colormap
- :rtype: float or None
- """
+ :return: the lower bound of the colormap
+ """
return self._vmin
- def setVMin(self, vmin):
+ def setVMin(self, vmin: float | None):
"""Set the minimal value of the colormap
- :param float vmin: Lower bound of the colormap or None for autoscale
- (default)
- value)
+ :param vmin: Lower bound of the colormap or None for autoscale (initial value)
"""
if self.isEditable() is False:
- raise NotEditableError('Colormap is not editable')
+ raise NotEditableError("Colormap is not editable")
if vmin is not None:
if self._vmax is not None and vmin > self._vmax:
- err = "Can't set vmin because vmin >= vmax. " \
- "vmin = %s, vmax = %s" % (vmin, self._vmax)
+ err = "Can't set vmin because vmin >= vmax. " "vmin = %s, vmax = %s" % (
+ vmin,
+ self._vmax,
+ )
raise ValueError(err)
if vmin != self._vmin:
@@ -532,26 +594,26 @@ class Colormap(qt.QObject):
self.__warnBadVmin = True
self.sigChanged.emit()
- def getVMax(self):
+ def getVMax(self) -> float | None:
"""Return the upper bounds of the colormap or None
:return: the upper bounds of the colormap or None
- :rtype: float or None
"""
return self._vmax
- def setVMax(self, vmax):
+ def setVMax(self, vmax: float | None):
"""Set the maximal value of the colormap
- :param float vmax: Upper bounds of the colormap or None for autoscale
- (default)
+ :param vmax: Upper bounds of the colormap or None for autoscale (initial value)
"""
if self.isEditable() is False:
- raise NotEditableError('Colormap is not editable')
+ raise NotEditableError("Colormap is not editable")
if vmax is not None:
if self._vmin is not None and vmax < self._vmin:
- err = "Can't set vmax because vmax <= vmin. " \
- "vmin = %s, vmax = %s" % (self._vmin, vmax)
+ err = "Can't set vmax because vmax <= vmin. " "vmin = %s, vmax = %s" % (
+ self._vmin,
+ vmax,
+ )
raise ValueError(err)
if vmax != self._vmax:
@@ -559,25 +621,24 @@ class Colormap(qt.QObject):
self.__warnBadVmax = True
self.sigChanged.emit()
- def isEditable(self):
- """ Return if the colormap is editable or not
+ def isEditable(self) -> bool:
+ """Return if the colormap is editable or not
:return: editable state of the colormap
- :rtype: bool
"""
return self._editable
- def setEditable(self, editable):
+ def setEditable(self, editable: bool):
"""
Set the editable state of the colormap
- :param bool editable: is the colormap editable
+ :param editable: is the colormap editable
"""
assert type(editable) is bool
self._editable = editable
self.sigChanged.emit()
- def _getNormalizer(self):
+ def _getNormalizer(self): # TODO
"""Returns normalizer object"""
normalization = self.getNormalization()
if normalization == self.GAMMA:
@@ -585,26 +646,28 @@ class Colormap(qt.QObject):
else:
return self._BASIC_NORMALIZATIONS[normalization]
- def _computeAutoscaleRange(self, data):
+ def _computeAutoscaleRange(self, data: numpy.ndarray):
"""Compute the data range which will be used in autoscale mode.
- :param numpy.ndarray data: The data for which to compute the range
+ :param data: The data for which to compute the range
:return: (vmin, vmax) range
"""
- return self._getNormalizer().autoscale(
- data, mode=self.getAutoscaleMode())
+ return self._getNormalizer().autoscale(data, mode=self.getAutoscaleMode())
- def getColormapRange(self, data=None):
+ def getColormapRange(
+ self,
+ data: numpy.ndarray | _Colormappable | None = None,
+ ) -> tuple[float, float]:
"""Return (vmin, vmax) the range of the colormap for the given data or item.
- :param Union[numpy.ndarray,~silx.gui.plot.items.ColormapMixIn] data:
- The data or item to use for autoscale bounds.
+ :param data: The data or item to use for autoscale bounds.
:return: (vmin, vmax) corresponding to the colormap applied to data if provided.
- :rtype: tuple
"""
vmin = self._vmin
vmax = self._vmax
- assert vmin is None or vmax is None or vmin <= vmax # TODO handle this in setters
+ assert (
+ vmin is None or vmax is None or vmin <= vmax
+ ) # TODO handle this in setters
normalizer = self._getNormalizer()
@@ -612,26 +675,22 @@ class Colormap(qt.QObject):
if vmin is not None and not normalizer.is_valid(vmin):
if self.__warnBadVmin:
self.__warnBadVmin = False
- _logger.info(
- 'Invalid vmin, switching to autoscale for lower bound')
+ _logger.info("Invalid vmin, switching to autoscale for lower bound")
vmin = None
if vmax is not None and not normalizer.is_valid(vmax):
if self.__warnBadVmax:
self.__warnBadVmax = False
- _logger.info(
- 'Invalid vmax, switching to autoscale for upper bound')
+ _logger.info("Invalid vmax, switching to autoscale for upper bound")
vmax = None
if vmin is None or vmax is None: # Handle autoscale
- from .plot.items.core import ColormapMixIn # avoid cyclic import
- if isinstance(data, ColormapMixIn):
+ if isinstance(data, _Colormappable):
min_, max_ = data._getColormapAutoscaleRange(self)
# Make sure min_, max_ are not None
min_ = normalizer.DEFAULT_RANGE[0] if min_ is None else min_
max_ = normalizer.DEFAULT_RANGE[1] if max_ is None else max_
else:
- min_, max_ = normalizer.autoscale(
- data, mode=self.getAutoscaleMode())
+ min_, max_ = normalizer.autoscale(data, mode=self.getAutoscaleMode())
if vmin is None: # Set vmin respecting provided vmax
vmin = min_ if vmax is None else min(min_, vmax)
@@ -641,16 +700,15 @@ class Colormap(qt.QObject):
return vmin, vmax
- def getVRange(self):
+ def getVRange(self) -> tuple[float | None, float | None]:
"""Get the bounds of the colormap
- :rtype: Tuple(Union[float,None],Union[float,None])
:returns: A tuple of 2 values for min and max. Or None instead of float
for autoscale
"""
return self.getVMin(), self.getVMax()
- def setVRange(self, vmin, vmax):
+ def setVRange(self, vmin: float | None, vmax: float | None):
"""Set the bounds of the colormap
:param vmin: Lower bound of the colormap or None for autoscale
@@ -659,11 +717,23 @@ class Colormap(qt.QObject):
(default)
"""
if self.isEditable() is False:
- raise NotEditableError('Colormap is not editable')
+ raise NotEditableError("Colormap is not editable")
+
+ if (vmin is not None and not numpy.isfinite(vmin)) or (
+ vmax is not None and not numpy.isfinite(vmax)
+ ):
+ err = (
+ "Can't set vmin and vmax because vmin or vmax are not finite "
+ "vmin = %s, vmax = %s" % (vmin, vmax)
+ )
+ raise ValueError(err)
+
if vmin is not None and vmax is not None:
if vmin > vmax:
- err = "Can't set vmin and vmax because vmin >= vmax " \
- "vmin = %s, vmax = %s" % (vmin, vmax)
+ err = (
+ "Can't set vmin and vmax because vmin >= vmax "
+ "vmin = %s, vmax = %s" % (vmin, vmax)
+ )
raise ValueError(err)
if self._vmin == vmin and self._vmax == vmax:
@@ -677,79 +747,78 @@ class Colormap(qt.QObject):
self._vmax = vmax
self.sigChanged.emit()
- def __getitem__(self, item):
- if item == 'autoscale':
+ def __getitem__(self, item: str):
+ if item == "autoscale":
return self.isAutoscale()
- elif item == 'name':
+ elif item == "name":
return self.getName()
- elif item == 'normalization':
+ elif item == "normalization":
return self.getNormalization()
- elif item == 'vmin':
+ elif item == "vmin":
return self.getVMin()
- elif item == 'vmax':
+ elif item == "vmax":
return self.getVMax()
- elif item == 'colors':
+ elif item == "colors":
return self.getColormapLUT()
- elif item == 'autoscaleMode':
+ elif item == "autoscaleMode":
return self.getAutoscaleMode()
else:
raise KeyError(item)
- def _toDict(self):
+ def _toDict(self) -> dict:
"""Return the equivalent colormap as a dictionary
(old colormap representation)
:return: the representation of the Colormap as a dictionary
- :rtype: dict
"""
return {
- 'name': self._name,
- 'colors': self.getColormapLUT(),
- 'vmin': self._vmin,
- 'vmax': self._vmax,
- 'autoscale': self.isAutoscale(),
- 'normalization': self.getNormalization(),
- 'autoscaleMode': self.getAutoscaleMode(),
- }
-
- def _setFromDict(self, dic):
+ "name": self._name,
+ "colors": self.getColormapLUT(),
+ "vmin": self._vmin,
+ "vmax": self._vmax,
+ "autoscale": self.isAutoscale(),
+ "normalization": self.getNormalization(),
+ "autoscaleMode": self.getAutoscaleMode(),
+ }
+
+ def _setFromDict(self, dic: dict):
"""Set values to the colormap from a dictionary
- :param dict dic: the colormap as a dictionary
+ :param dic: the colormap as a dictionary
"""
if self.isEditable() is False:
- raise NotEditableError('Colormap is not editable')
- name = dic['name'] if 'name' in dic else None
- colors = dic['colors'] if 'colors' in dic else None
+ raise NotEditableError("Colormap is not editable")
+ name = dic["name"] if "name" in dic else None
+ colors = dic["colors"] if "colors" in dic else None
if name is not None and colors is not None:
if isinstance(colors, int):
# Filter out argument which was supported but never used
_logger.info("Unused 'colors' from colormap dictionary filterer.")
colors = None
- vmin = dic['vmin'] if 'vmin' in dic else None
- vmax = dic['vmax'] if 'vmax' in dic else None
- if 'normalization' in dic:
- normalization = dic['normalization']
+ vmin = dic["vmin"] if "vmin" in dic else None
+ vmax = dic["vmax"] if "vmax" in dic else None
+ if "normalization" in dic:
+ normalization = dic["normalization"]
else:
- warn = 'Normalization not given in the dictionary, '
- warn += 'set by default to ' + Colormap.LINEAR
+ warn = "Normalization not given in the dictionary, "
+ warn += "set by default to " + Colormap.LINEAR
_logger.warning(warn)
normalization = Colormap.LINEAR
if name is None and colors is None:
- err = 'The colormap should have a name defined or a tuple of colors'
+ err = "The colormap should have a name defined or a tuple of colors"
raise ValueError(err)
if normalization not in Colormap.NORMALIZATIONS:
- err = 'Given normalization is not recognized (%s)' % normalization
+ err = "Given normalization is not recognized (%s)" % normalization
raise ValueError(err)
- autoscaleMode = dic.get('autoscaleMode', Colormap.MINMAX)
+ autoscaleMode = dic.get("autoscaleMode", Colormap.MINMAX)
if autoscaleMode not in Colormap.AUTOSCALE_MODES:
- err = 'Given autoscale mode is not recognized (%s)' % autoscaleMode
+ err = "Given autoscale mode is not recognized (%s)" % autoscaleMode
raise ValueError(err)
# If autoscale, then set boundaries to None
- if dic.get('autoscale', False):
+ if dic.get("autoscale", False):
vmin, vmax = None, None
if name is not None:
@@ -767,61 +836,57 @@ class Colormap(qt.QObject):
self.sigChanged.emit()
@staticmethod
- def _fromDict(dic):
+ def _fromDict(dic: dict):
colormap = Colormap()
colormap._setFromDict(dic)
return colormap
- def copy(self):
- """Return a copy of the Colormap.
-
- :rtype: silx.gui.colors.Colormap
- """
- colormap = Colormap(name=self._name,
- colors=self.getColormapLUT(),
- vmin=self._vmin,
- vmax=self._vmax,
- normalization=self.getNormalization(),
- autoscaleMode=self.getAutoscaleMode())
+ def copy(self) -> Colormap:
+ """Return a copy of the Colormap."""
+ colormap = Colormap(
+ name=self._name,
+ colors=self.getColormapLUT(),
+ vmin=self._vmin,
+ vmax=self._vmax,
+ normalization=self.getNormalization(),
+ autoscaleMode=self.getAutoscaleMode(),
+ )
colormap.setNaNColor(self.getNaNColor())
- colormap.setGammaNormalizationParameter(
- self.getGammaNormalizationParameter())
+ colormap.setGammaNormalizationParameter(self.getGammaNormalizationParameter())
colormap.setEditable(self.isEditable())
return colormap
- def applyToData(self, data, reference=None):
+ def applyToData(
+ self,
+ data: numpy.ndarray | _Colormappable,
+ reference: numpy.ndarray | _Colormappable | None = None,
+ ) -> numpy.ndarray:
"""Apply the colormap to the data
- :param Union[numpy.ndarray,~silx.gui.plot.item.ColormapMixIn] data:
+ :param data:
The data to convert or the item for which to apply the colormap.
- :param Union[numpy.ndarray,~silx.gui.plot.item.ColormapMixIn,None] reference:
+ :param reference:
The data or item to use as reference to compute autoscale
"""
if reference is None:
reference = data
vmin, vmax = self.getColormapRange(reference)
- if hasattr(data, "getColormappedData"): # Use item's data
+ if isinstance(data, _Colormappable): # Use item's data
data = data.getColormappedData(copy=False)
return _colormap.cmap(
- data,
- self._colors,
- vmin,
- vmax,
- self._getNormalizer(),
- self.__nanColor)
+ data, self._colors, vmin, vmax, self._getNormalizer(), self.__nanColor
+ )
@staticmethod
- def getSupportedColormaps():
+ def getSupportedColormaps() -> tuple[str, ...]:
"""Get the supported colormap names as a tuple of str.
The list should at least contain and start by:
('gray', 'reversed gray', 'temperature', 'red', 'green', 'blue',
'viridis', 'magma', 'inferno', 'plasma')
-
- :rtype: tuple
"""
registered_colormaps = _colormap.get_registered_colormaps()
colormaps = set(registered_colormaps)
@@ -829,14 +894,15 @@ class Colormap(qt.QObject):
colormaps.update(_matplotlib_colormaps())
# Put registered_colormaps first
- colormaps = tuple(cmap for cmap in sorted(colormaps)
- if cmap not in registered_colormaps)
+ colormaps = tuple(
+ cmap for cmap in sorted(colormaps) if cmap not in registered_colormaps
+ )
return registered_colormaps + colormaps
- def __str__(self):
+ def __str__(self) -> str:
return str(self._toDict())
- def __eq__(self, other):
+ def __eq__(self, other: Any):
"""Compare colormap values and not pointers"""
if other is None:
return False
@@ -845,28 +911,31 @@ class Colormap(qt.QObject):
if self.getNormalization() != other.getNormalization():
return False
if self.getNormalization() == self.GAMMA:
- delta = self.getGammaNormalizationParameter() - other.getGammaNormalizationParameter()
+ delta = (
+ self.getGammaNormalizationParameter()
+ - other.getGammaNormalizationParameter()
+ )
if abs(delta) > 0.001:
return False
- return (self.getName() == other.getName() and
- self.getAutoscaleMode() == other.getAutoscaleMode() and
- self.getVMin() == other.getVMin() and
- self.getVMax() == other.getVMax() and
- numpy.array_equal(self.getColormapLUT(), other.getColormapLUT())
- )
+ return (
+ self.getName() == other.getName()
+ and self.getAutoscaleMode() == other.getAutoscaleMode()
+ and self.getVMin() == other.getVMin()
+ and self.getVMax() == other.getVMax()
+ and numpy.array_equal(self.getColormapLUT(), other.getColormapLUT())
+ )
_SERIAL_VERSION = 3
- def restoreState(self, byteArray):
+ def restoreState(self, byteArray: qt.QByteArray) -> bool:
"""
Read the colormap state from a QByteArray.
- :param qt.QByteArray byteArray: Stream containing the state
+ :param byteArray: Stream containing the state
:return: True if the restoration sussseed
- :rtype: bool
"""
if self.isEditable() is False:
- raise NotEditableError('Colormap is not editable')
+ raise NotEditableError("Colormap is not editable")
stream = qt.QDataStream(byteArray, qt.QIODevice.ReadOnly)
className = stream.readQString()
@@ -875,7 +944,7 @@ class Colormap(qt.QObject):
return False
version = stream.readUInt32()
- if version not in numpy.arange(1, self._SERIAL_VERSION+1):
+ if version not in numpy.arange(1, self._SERIAL_VERSION + 1):
_logger.warning("Serial version mismatch. Found %d." % version)
return False
@@ -905,7 +974,12 @@ class Colormap(qt.QObject):
if version <= 2:
nanColor = self._DEFAULT_NAN_COLOR
else:
- nanColor = stream.readInt32(), stream.readInt32(), stream.readInt32(), stream.readInt32()
+ nanColor = (
+ stream.readInt32(),
+ stream.readInt32(),
+ stream.readInt32(),
+ stream.readInt32(),
+ )
# emit change event only once
old = self.blockSignals(True)
@@ -922,12 +996,8 @@ class Colormap(qt.QObject):
self.sigChanged.emit()
return True
- def saveState(self):
- """
- Save state of the colomap into a QDataStream.
-
- :rtype: qt.QByteArray
- """
+ def saveState(self) -> qt.QByteArray:
+ """Save state of the colomap into a QDataStream."""
data = qt.QByteArray()
stream = qt.QDataStream(data, qt.QIODevice.WriteOnly)
@@ -959,20 +1029,27 @@ Tuple of preferred colormap names accessed with :meth:`preferredColormaps`.
"""
_DEFAULT_PREFERRED_COLORMAPS = (
- 'gray', 'reversed gray', 'red', 'green', 'blue',
- 'viridis', 'cividis', 'magma', 'inferno', 'plasma',
- 'temperature',
- 'jet', 'hsv'
+ "gray",
+ "reversed gray",
+ "red",
+ "green",
+ "blue",
+ "viridis",
+ "cividis",
+ "magma",
+ "inferno",
+ "plasma",
+ "temperature",
+ "jet",
+ "hsv",
)
-def preferredColormaps():
+def preferredColormaps() -> tuple[str, ...]:
"""Returns the name of the preferred colormaps.
This list is used by widgets allowing to change the colormap
like the :class:`ColormapDialog` as a subset of colormap choices.
-
- :rtype: tuple of str
"""
global _PREFERRED_COLORMAPS
if _PREFERRED_COLORMAPS is None:
@@ -981,14 +1058,13 @@ def preferredColormaps():
return tuple(_PREFERRED_COLORMAPS)
-def setPreferredColormaps(colormaps):
+def setPreferredColormaps(colormaps: Iterable[str]):
"""Set the list of preferred colormap names.
Warning: If a colormap name is not available
it will be removed from the list.
:param colormaps: Not empty list of colormap names
- :type colormaps: iterable of str
:raise ValueError: if the list of available preferred colormaps is empty.
"""
supportedColormaps = Colormap.getSupportedColormaps()
@@ -1000,18 +1076,23 @@ def setPreferredColormaps(colormaps):
_PREFERRED_COLORMAPS = colormaps
-def registerLUT(name, colors, cursor_color='black', preferred=True):
+def registerLUT(
+ name: str,
+ colors: numpy.ndarray,
+ cursor_color: str = "black",
+ preferred: bool = True,
+):
"""Register a custom LUT to be used with `Colormap` objects.
It can override existing LUT names.
- :param str name: Name of the LUT as defined to configure colormaps
- :param numpy.ndarray colors: The custom LUT to register.
+ :param name: Name of the LUT as defined to configure colormaps
+ :param colors: The custom LUT to register.
Nx3 or Nx4 numpy array of RGB(A) colors,
either uint8 or float in [0, 1].
- :param bool preferred: If true, this LUT will be displayed as part of the
+ :param preferred: If true, this LUT will be displayed as part of the
preferred colormaps in dialogs.
- :param str cursor_color: Color used to display overlay over images using
+ :param cursor_color: Color used to display overlay over images using
colormap with this LUT.
"""
_colormap.register_colormap(name, colors, cursor_color)
@@ -1029,5 +1110,5 @@ def registerLUT(name, colors, cursor_color='black', preferred=True):
# Load some colormaps from matplotlib by default
if _matplotlib_cm is not None:
- _registerColormapFromMatplotlib('jet', cursor_color='pink', preferred=True)
- _registerColormapFromMatplotlib('hsv', cursor_color='black', preferred=True)
+ _registerColormapFromMatplotlib("jet", cursor_color="pink", preferred=True)
+ _registerColormapFromMatplotlib("hsv", cursor_color="black", preferred=True)
diff --git a/src/silx/gui/conftest.py b/src/silx/gui/conftest.py
index 74b5c19..2e9cf0d 100644
--- a/src/silx/gui/conftest.py
+++ b/src/silx/gui/conftest.py
@@ -1,5 +1,47 @@
import pytest
+from silx.gui import qt
+from silx.gui.qt.inspect import isValid
+
+
@pytest.fixture(autouse=True)
def auto_qapp(qapp):
pass
+
+
+@pytest.fixture
+def qWidgetFactory(qapp, qapp_utils):
+ """QWidget factory as fixture
+
+ This fixture provides a function taking a QWidget subclass as argument
+ which returns an instance of this QWidget making sure it is shown first
+ and destroyed once the test is done.
+ """
+ widgets = []
+
+ def createWidget(cls, *args, **kwargs):
+ widget = cls(*args, **kwargs)
+ widget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ widget.show()
+ qapp_utils.qWaitForWindowExposed(widget)
+ widgets.append(widget)
+
+ return widget
+
+ yield createWidget
+
+ qapp.processEvents()
+
+ for widget in widgets:
+ if isValid(widget):
+ widget.close()
+ qapp.processEvents()
+
+ # Wait some time for all widgets to be deleted
+ for _ in range(10):
+ validWidgets = [widget for widget in widgets if isValid(widget)]
+ if validWidgets:
+ qapp_utils.qWait(10)
+
+ validWidgets = [widget for widget in widgets if isValid(widget)]
+ assert not validWidgets, f"Some widgets were not destroyed: {validWidgets}"
diff --git a/src/silx/gui/console.py b/src/silx/gui/console.py
index c66d44a..df0e36c 100644
--- a/src/silx/gui/console.py
+++ b/src/silx/gui/console.py
@@ -87,15 +87,16 @@ else:
raise ImportError(msg)
try:
- from qtconsole.rich_jupyter_widget import RichJupyterWidget as \
- _RichJupyterWidget
+ from qtconsole.rich_jupyter_widget import RichJupyterWidget as _RichJupyterWidget
except ImportError:
try:
- from qtconsole.rich_ipython_widget import RichJupyterWidget as \
- _RichJupyterWidget
+ from qtconsole.rich_ipython_widget import (
+ RichJupyterWidget as _RichJupyterWidget,
+ )
except ImportError:
- from qtconsole.rich_ipython_widget import RichIPythonWidget as \
- _RichJupyterWidget
+ from qtconsole.rich_ipython_widget import (
+ RichIPythonWidget as _RichJupyterWidget,
+ )
from qtconsole.inprocess import QtInProcessKernelManager
@@ -126,11 +127,15 @@ class IPythonWidget(_RichJupyterWidget):
# Monkey-patch to workaround issue:
# https://github.com/ipython/ipykernel/issues/370
- if (_ipykernel_version_info is not None and
- _ipykernel_version_info[0] > 4 and
- _ipykernel_version_info[:3] <= (5, 1, 0)):
+ if (
+ _ipykernel_version_info is not None
+ and _ipykernel_version_info[0] > 4
+ and _ipykernel_version_info[:3] <= (5, 1, 0)
+ ):
+
def _abort_queues(*args, **kwargs):
pass
+
kernel_manager.kernel._abort_queues = _abort_queues
self.kernel_client = kernel_client = self._kernel_manager.client()
@@ -139,6 +144,7 @@ class IPythonWidget(_RichJupyterWidget):
def stop():
kernel_client.stop_channels()
kernel_manager.shutdown_kernel()
+
self.exit_requested.connect(stop)
def sizeHint(self):
@@ -146,7 +152,7 @@ class IPythonWidget(_RichJupyterWidget):
return qt.QSize(500, 300)
def pushVariables(self, variable_dict):
- """ Given a dictionary containing name / value pairs, push those
+ """Given a dictionary containing name / value pairs, push those
variables to the IPython console widget.
:param variable_dict: Dictionary of variables to be pushed to the
@@ -169,8 +175,10 @@ class IPythonDockWidget(qt.QDockWidget):
:param parent: Parent :class:`qt.QMainWindow` containing this
:class:`qt.QDockWidget`
"""
- def __init__(self, parent=None, available_vars=None, custom_banner=None,
- title="Console"):
+
+ def __init__(
+ self, parent=None, available_vars=None, custom_banner=None, title="Console"
+ ):
super(IPythonDockWidget, self).__init__(title, parent)
self.ipyconsole = IPythonWidget(custom_banner=custom_banner)
@@ -190,5 +198,5 @@ def main():
app.exec()
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/src/silx/gui/constants.py b/src/silx/gui/constants.py
new file mode 100644
index 0000000..cc8b45e
--- /dev/null
+++ b/src/silx/gui/constants.py
@@ -0,0 +1,27 @@
+# /*##########################################################################
+#
+# Copyright (c) 2004-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Constants related to silx GUI"""
+
+SILX_URI_MIMETYPE = "application/x-silx-uri"
+"""Used by silx to share data URL between application"""
diff --git a/src/silx/gui/data/ArrayTableModel.py b/src/silx/gui/data/ArrayTableModel.py
index 00cc235..2de0f05 100644
--- a/src/silx/gui/data/ArrayTableModel.py
+++ b/src/silx/gui/data/ArrayTableModel.py
@@ -192,8 +192,7 @@ class ArrayTableModel(qt.QAbstractTableModel):
dim = self._getRowDim()
else:
dim = self._getColumnDim()
- return (dim is not None and
- self._array.shape[dim] > self.MAX_NUMBER_OF_SECTIONS)
+ return dim is not None and self._array.shape[dim] > self.MAX_NUMBER_OF_SECTIONS
def __isClippedIndex(self, index) -> bool:
"""Returns whether or not index's cell represents clipped data."""
@@ -224,17 +223,23 @@ class ArrayTableModel(qt.QAbstractTableModel):
row, column = index.row(), index.column()
# When clipped, display last data of the array in last column of the table
- if (self.__isClipped(qt.Qt.Vertical) and
- row == self.MAX_NUMBER_OF_SECTIONS - 1):
+ if (
+ self.__isClipped(qt.Qt.Vertical)
+ and row == self.MAX_NUMBER_OF_SECTIONS - 1
+ ):
row = self._array.shape[self._getRowDim()] - 1
- if (self.__isClipped(qt.Qt.Horizontal) and
- column == self.MAX_NUMBER_OF_SECTIONS - 1):
+ if (
+ self.__isClipped(qt.Qt.Horizontal)
+ and column == self.MAX_NUMBER_OF_SECTIONS - 1
+ ):
column = self._array.shape[self._getColumnDim()] - 1
selection = self._getIndexTuple(row, column)
if role == qt.Qt.DisplayRole or role == qt.Qt.EditRole:
- return self._formatter.toString(self._array[selection], self._array.dtype)
+ return self._formatter.toString(
+ self._array[selection], self._array.dtype
+ )
if role == qt.Qt.BackgroundRole and self._bgcolors is not None:
r, g, b = self._bgcolors[selection][0:3]
@@ -305,8 +310,7 @@ class ArrayTableModel(qt.QAbstractTableModel):
except ValueError:
return False
- selection = self._getIndexTuple(index.row(),
- index.column())
+ selection = self._getIndexTuple(index.row(), index.column())
self._array[selection] = v
self.dataChanged.emit(index, index)
return True
@@ -314,8 +318,7 @@ class ArrayTableModel(qt.QAbstractTableModel):
return False
# Public methods
- def setArrayData(self, data, copy=True,
- perspective=None, editable=False):
+ def setArrayData(self, data, copy=True, perspective=None, editable=False):
"""Set the data array and the viewing perspective.
You can set ``copy=False`` if you need more performances, when dealing
@@ -352,9 +355,11 @@ class ArrayTableModel(qt.QAbstractTableModel):
# Avoid to lose the monkey-patched h5py dtype
self._array.dtype = data.dtype
elif not _is_array(data):
- raise TypeError("data is not a proper array. Try setting" +
- " copy=True to convert it into a numpy array" +
- " (this will cause the data to be copied!)")
+ raise TypeError(
+ "data is not a proper array. Try setting"
+ + " copy=True to convert it into a numpy array"
+ + " (this will cause the data to be copied!)"
+ )
# # copy not requested, but necessary
# _logger.warning(
# "data is not an array-like object. " +
@@ -366,8 +371,7 @@ class ArrayTableModel(qt.QAbstractTableModel):
self._array = data
# reset colors to None if new data shape is inconsistent
- valid_color_shapes = (self._array.shape + (3,),
- self._array.shape + (4,))
+ valid_color_shapes = (self._array.shape + (3,), self._array.shape + (4,))
if self._bgcolors is not None:
if self._bgcolors.shape not in valid_color_shapes:
self._bgcolors = None
@@ -378,8 +382,11 @@ class ArrayTableModel(qt.QAbstractTableModel):
self.setEditable(editable)
self._index = [0 for _i in range((len(self._array.shape) - 2))]
- self._perspective = tuple(perspective) if perspective is not None else\
- tuple(range(0, len(self._array.shape) - 2))
+ self._perspective = (
+ tuple(perspective)
+ if perspective is not None
+ else tuple(range(0, len(self._array.shape) - 2))
+ )
self.endResetModel()
@@ -442,8 +449,9 @@ class ArrayTableModel(qt.QAbstractTableModel):
if hasattr(self._array.file, "mode"):
if editable and self._array.file.mode == "r":
_logger.warning(
- "Data is a HDF5 dataset open in read-only " +
- "mode. Editing must be disabled.")
+ "Data is a HDF5 dataset open in read-only "
+ + "mode. Editing must be disabled."
+ )
self._editable = False
return False
return True
@@ -489,14 +497,17 @@ class ArrayTableModel(qt.QAbstractTableModel):
else:
self._index = index
if not 0 <= self._index[0] < len_:
- raise ValueError("Index must be a positive integer " +
- "lower than %d" % len_)
+ raise ValueError(
+ "Index must be a positive integer " + "lower than %d" % len_
+ )
else:
# general n-D case
for i_, idx in enumerate(index):
if not 0 <= idx < shape[self._perspective[i_]]:
- raise IndexError("Invalid index %d " % idx +
- "not in range 0-%d" % (shape[i_] - 1))
+ raise IndexError(
+ "Invalid index %d " % idx
+ + "not in range 0-%d" % (shape[i_] - 1)
+ )
self._index = index
self.endResetModel()
@@ -528,8 +539,7 @@ class ArrayTableModel(qt.QAbstractTableModel):
return self._formatter
def __formatChanged(self):
- """Called when the format changed.
- """
+ """Called when the format changed."""
self.reset()
def setPerspective(self, perspective):
@@ -562,8 +572,7 @@ class ArrayTableModel(qt.QAbstractTableModel):
"""
n_dimensions = len(self._array.shape)
if n_dimensions < 3:
- _logger.warning(
- "perspective is not relevant for 1D and 2D arrays")
+ _logger.warning("perspective is not relevant for 1D and 2D arrays")
return
if not hasattr(perspective, "__len__"):
@@ -576,12 +585,18 @@ class ArrayTableModel(qt.QAbstractTableModel):
# ensure unicity of dimensions in perspective
perspective = tuple(set(perspective))
- if len(perspective) != n_dimensions - 2 or\
- min(perspective) < 0 or max(perspective) >= n_dimensions:
+ if (
+ len(perspective) != n_dimensions - 2
+ or min(perspective) < 0
+ or max(perspective) >= n_dimensions
+ ):
raise IndexError(
- "Invalid perspective " + str(perspective) +
- " for %d-D array " % n_dimensions +
- "with shape " + str(self._array.shape))
+ "Invalid perspective "
+ + str(perspective)
+ + " for %d-D array " % n_dimensions
+ + "with shape "
+ + str(self._array.shape)
+ )
self.beginResetModel()
@@ -606,24 +621,31 @@ class ArrayTableModel(qt.QAbstractTableModel):
:raise: IndexError if axes are invalid
"""
if row_axis > col_axis:
- _logger.warning("The dimension of the row axis must be lower " +
- "than the dimension of the column axis. Swapping.")
+ _logger.warning(
+ "The dimension of the row axis must be lower "
+ + "than the dimension of the column axis. Swapping."
+ )
row_axis, col_axis = min(row_axis, col_axis), max(row_axis, col_axis)
n_dimensions = len(self._array.shape)
if n_dimensions < 3:
- _logger.warning(
- "Frame axes cannot be changed for 1D and 2D arrays")
+ _logger.warning("Frame axes cannot be changed for 1D and 2D arrays")
return
perspective = tuple(set(range(0, n_dimensions)) - {row_axis, col_axis})
- if len(perspective) != n_dimensions - 2 or\
- min(perspective) < 0 or max(perspective) >= n_dimensions:
+ if (
+ len(perspective) != n_dimensions - 2
+ or min(perspective) < 0
+ or max(perspective) >= n_dimensions
+ ):
raise IndexError(
- "Invalid perspective " + str(perspective) +
- " for %d-D array " % n_dimensions +
- "with shape " + str(self._array.shape))
+ "Invalid perspective "
+ + str(perspective)
+ + " for %d-D array " % n_dimensions
+ + "with shape "
+ + str(self._array.shape)
+ )
self.beginResetModel()
diff --git a/src/silx/gui/data/ArrayTableWidget.py b/src/silx/gui/data/ArrayTableWidget.py
index 2f7762d..882c730 100644
--- a/src/silx/gui/data/ArrayTableWidget.py
+++ b/src/silx/gui/data/ArrayTableWidget.py
@@ -54,6 +54,7 @@ class AxesSelector(qt.QWidget):
The two axes can be used to select the row axis and the column axis t
display a slice of the array data in a table view.
"""
+
sigDimensionsChanged = qt.Signal(int, int)
"""Signal emitted whenever one of the comboboxes is changed.
The signal carries the two selected dimensions."""
@@ -126,7 +127,9 @@ class AxesSelector(qt.QWidget):
if not (0 <= row_dim < self.n - 1):
raise IndexError("Row dimension must be between 0 and %d" % (self.n - 2))
if not (row_dim < col_dim <= self.n - 1):
- raise IndexError("Col dimension must be between %d and %d" % (row_dim + 1, self.n - 1))
+ raise IndexError(
+ "Col dimension must be between %d and %d" % (row_dim + 1, self.n - 1)
+ )
# set the rows dimension; this triggers an update of columnsCB
self.rowsCB.setCurrentIndex(row_dim)
@@ -147,8 +150,7 @@ class AxesSelector(qt.QWidget):
self.columnsCB.clear()
def _getRowDim(self):
- """Get rows dimension, selected in :attr:`rowsCB`
- """
+ """Get rows dimension, selected in :attr:`rowsCB`"""
# rows combobox contains elements "0", ..."n-2",
# so the selected dim is always equal to the index
return self.rowsCB.currentIndex()
@@ -231,6 +233,7 @@ class ArrayTableWidget(qt.QWidget):
.. image:: img/ArrayTableWidget.png
"""
+
def __init__(self, parent=None):
"""
@@ -468,6 +471,7 @@ class ArrayTableWidget(qt.QWidget):
def main():
import numpy
+
a = qt.QApplication([])
d = numpy.random.normal(0, 1, (4, 5, 1000, 1000))
for j in range(4):
@@ -486,5 +490,6 @@ def main():
w.show()
a.exec()
+
if __name__ == "__main__":
main()
diff --git a/src/silx/gui/data/DataViewer.py b/src/silx/gui/data/DataViewer.py
index 2c93c65..aa522ec 100644
--- a/src/silx/gui/data/DataViewer.py
+++ b/src/silx/gui/data/DataViewer.py
@@ -43,9 +43,9 @@ __date__ = "12/02/2019"
_logger = logging.getLogger(__name__)
-DataSelection = collections.namedtuple("DataSelection",
- ["filename", "datapath",
- "slice", "permutation"])
+DataSelection = collections.namedtuple(
+ "DataSelection", ["filename", "datapath", "slice", "permutation"]
+)
class DataViewer(qt.QFrame):
@@ -172,7 +172,9 @@ class DataViewer(qt.QFrame):
view = viewClass(parent)
views.append(view)
except Exception:
- _logger.warning("%s instantiation failed. View is ignored" % viewClass.__name__)
+ _logger.warning(
+ "%s instantiation failed. View is ignored" % viewClass.__name__
+ )
_logger.debug("Backtrace", exc_info=True)
return views
@@ -222,19 +224,25 @@ class DataViewer(qt.QFrame):
info = self._getInfo()
axisNames = self.__currentView.axesNames(self.__data, info)
- if (info.isArray and info.size != 0 and
- self.__data is not None and axisNames is not None):
+ if (
+ info.isArray
+ and info.size != 0
+ and self.__data is not None
+ and axisNames is not None
+ ):
self.__useAxisSelection = True
self.__numpySelection.setAxisNames(axisNames)
self.__numpySelection.setCustomAxis(
- self.__currentView.customAxisNames())
+ self.__currentView.customAxisNames()
+ )
data = self.normalizeData(self.__data)
self.__numpySelection.setData(data)
# Try to restore previous permutation and selection
try:
self.__numpySelection.setSelection(
- previousSelection, previousPermutation)
+ previousSelection, previousPermutation
+ )
except ValueError as e:
_logger.info("Not restoring selection because: %s", e)
@@ -277,8 +285,10 @@ class DataViewer(qt.QFrame):
except:
datapath = None
- # FIXME: maybe use DataUrl, with added support of permutation
- self.__displayedSelection = DataSelection(filename, datapath, slicing, permutation)
+ # FIXME: maybe use DataUrl, with added support of permutation
+ self.__displayedSelection = DataSelection(
+ filename, datapath, slicing, permutation
+ )
# TODO: would be good to avoid that, it should be synchonous
qt.QTimer.singleShot(10, self.__setDataInView)
@@ -286,12 +296,19 @@ class DataViewer(qt.QFrame):
def __setDataInView(self):
self.__currentView.setData(self.__displayedData)
self.__currentView.setDataSelection(self.__displayedSelection)
+
+ if self.__displayedSelection is None:
+ return
+
# Emit signal only when selection has changed
- if (self.__previousSelection.slice != self.__displayedSelection.slice or
- self.__previousSelection.permutation != self.__displayedSelection.permutation
+ if (
+ self.__previousSelection.slice != self.__displayedSelection.slice
+ or self.__previousSelection.permutation
+ != self.__displayedSelection.permutation
):
self.selectionChanged.emit(
- self.__displayedSelection.slice, self.__displayedSelection.permutation)
+ self.__displayedSelection.slice, self.__displayedSelection.permutation
+ )
self.__previousSelection = self.__displayedSelection
def setDisplayedView(self, view):
diff --git a/src/silx/gui/data/DataViewerSelector.py b/src/silx/gui/data/DataViewerSelector.py
index d67908e..61a4077 100644
--- a/src/silx/gui/data/DataViewerSelector.py
+++ b/src/silx/gui/data/DataViewerSelector.py
@@ -118,7 +118,9 @@ class DataViewerSelector(qt.QWidget):
return
if self.__dataViewer is not None:
self.__dataViewer.dataChanged.disconnect(self.__updateButtonsVisibility)
- self.__dataViewer.displayedViewChanged.disconnect(self.__displayedViewChanged)
+ self.__dataViewer.displayedViewChanged.disconnect(
+ self.__displayedViewChanged
+ )
self.__dataViewer = dataViewer
if self.__dataViewer is not None:
self.__dataViewer.dataChanged.connect(self.__updateButtonsVisibility)
diff --git a/src/silx/gui/data/DataViews.py b/src/silx/gui/data/DataViews.py
index 0a4569f..ed688b8 100644
--- a/src/silx/gui/data/DataViews.py
+++ b/src/silx/gui/data/DataViews.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2016-2022 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -24,14 +24,12 @@
"""This module defines a views used by :class:`silx.gui.data.DataViewer`.
"""
-from collections import OrderedDict
import logging
import numbers
import numpy
import os
import silx.io
-from silx.utils import deprecation
from silx.gui import qt, icons
from silx.gui.data.TextFormatter import TextFormatter
from silx.io import nxdata
@@ -129,9 +127,17 @@ class DataInfo(object):
if nxd is not None:
self.hasNXdata = True
# can we plot it?
- is_scalar = nxd.signal_is_0d or nxd.interpretation in ["scalar", "scaler"]
- if not (is_scalar or nxd.is_curve or nxd.is_x_y_value_scatter or
- nxd.is_image or nxd.is_stack):
+ is_scalar = nxd.signal_is_0d or nxd.interpretation in [
+ "scalar",
+ "scaler",
+ ]
+ if not (
+ is_scalar
+ or nxd.is_curve
+ or nxd.is_x_y_value_scatter
+ or nxd.is_image
+ or nxd.is_stack
+ ):
# invalid: cannot be plotted by any widget
self.isInvalidNXdata = True
elif nx_class == "NXdata":
@@ -174,8 +180,7 @@ class DataInfo(object):
self.isComplex = numpy.issubdtype(data.dtype, numpy.complexfloating)
self.isBoolean = numpy.issubdtype(data.dtype, numpy.bool_)
elif self.hasNXdata:
- self.isNumeric = numpy.issubdtype(nxd.signal.dtype,
- numpy.number)
+ self.isNumeric = numpy.issubdtype(nxd.signal.dtype, numpy.number)
self.isComplex = numpy.issubdtype(nxd.signal.dtype, numpy.complexfloating)
self.isBoolean = numpy.issubdtype(nxd.signal.dtype, numpy.bool_)
else:
@@ -235,6 +240,7 @@ class DataViewHooks(object):
"""Called when the widget of the view was created"""
return
+
class DataView(object):
"""Holder for the data view."""
@@ -335,13 +341,11 @@ class DataView(object):
pass
def isWidgetInitialized(self):
- """Returns true if the widget is already initialized.
- """
+ """Returns true if the widget is already initialized."""
return self.__widget is not None
def select(self):
- """Called when the view is selected to display the data.
- """
+ """Called when the view is selected to display the data."""
return
def getWidget(self):
@@ -384,20 +388,25 @@ class DataView(object):
:rtype: str
"""
if indices is None:
- return ''
+ return ""
def formatSlice(slice_):
start, stop, step = slice_.start, slice_.stop, slice_.step
- string = ('' if start is None else str(start)) + ':'
+ string = ("" if start is None else str(start)) + ":"
if stop is not None:
string += str(stop)
if step not in (None, 1):
- string += ':' + step
+ string += ":" + step
return string
- return '[' + ', '.join(
- formatSlice(index) if isinstance(index, slice) else str(index)
- for index in indices) + ']'
+ return (
+ "["
+ + ", ".join(
+ formatSlice(index) if isinstance(index, slice) else str(index)
+ for index in indices
+ )
+ + "]"
+ )
def titleForSelection(self, selection):
"""Build title from given selection information.
@@ -413,9 +422,9 @@ class DataView(object):
slicing = self.__formatSlices(selection.slice)
except Exception:
_logger.debug("Error while formatting slices", exc_info=True)
- slicing = '[sliced]'
+ slicing = "[sliced]"
- permuted = '(permuted)' if selection.permutation is not None else ''
+ permuted = "(permuted)" if selection.permutation is not None else ""
try:
title = self.TITLE_PATTERN.format(
@@ -423,7 +432,8 @@ class DataView(object):
filename=filename,
datapath=selection.datapath,
slicing=slicing,
- permuted=permuted)
+ permuted=permuted,
+ )
except Exception:
_logger.debug("Error while formatting title", exc_info=True)
title = selection.datapath + slicing
@@ -530,10 +540,6 @@ class _CompositeDataView(DataView):
"""
raise NotImplementedError()
- @deprecation.deprecated(replacement="getReachableViews", since_version="0.10")
- def availableViews(self):
- return self.getViews()
-
def isSupportedData(self, data, info):
"""If true, the composite view allow sub views to access to this data.
Else this this data is considered as not supported by any of sub views
@@ -556,7 +562,7 @@ class SelectOneDataView(_CompositeDataView):
:param qt.QWidget parent: Parent of the hold widget
"""
super(SelectOneDataView, self).__init__(parent, modeId, icon, label)
- self.__views = OrderedDict()
+ self.__views = {}
self.__currentView = None
def setHooks(self, hooks):
@@ -711,9 +717,10 @@ class SelectOneDataView(_CompositeDataView):
return False
# replace oldView with new view in dict
- self.__views = OrderedDict(
- (newView, None) if view is oldView else (view, idx) for
- view, idx in self.__views.items())
+ self.__views = dict(
+ (newView, None) if view is oldView else (view, idx)
+ for view, idx in self.__views.items()
+ )
return True
@@ -733,7 +740,9 @@ class SelectManyDataView(_CompositeDataView):
:param qt.QWidget parent: Parent of the hold widget
"""
- super(SelectManyDataView, self).__init__(parent, modeId=None, icon=None, label=None)
+ super(SelectManyDataView, self).__init__(
+ parent, modeId=None, icon=None, label=None
+ )
if views is None:
views = []
self.__views = views
@@ -776,7 +785,11 @@ class SelectManyDataView(_CompositeDataView):
"""
if not self.isSupportedData(data, info):
return []
- views = [v for v in self.__views if v.getCachedDataPriority(data, info) != DataView.UNSUPPORTED]
+ views = [
+ v
+ for v in self.__views
+ if v.getCachedDataPriority(data, info) != DataView.UNSUPPORTED
+ ]
return views
def customAxisNames(self):
@@ -870,11 +883,13 @@ class _Plot1dView(DataView):
parent=parent,
modeId=PLOT1D_MODE,
label="Curve",
- icon=icons.getQIcon("view-1d"))
+ icon=icons.getQIcon("view-1d"),
+ )
self.__resetZoomNextTime = True
def createWidget(self, parent):
from silx.gui import plot
+
widget = plot.Plot1D(parent=parent)
widget.setGraphGrid(True)
return widget
@@ -892,10 +907,12 @@ class _Plot1dView(DataView):
data = self.normalizeData(data)
plotWidget = self.getWidget()
legend = "data"
- plotWidget.addCurve(legend=legend,
- x=range(len(data)),
- y=data,
- resetzoom=self.__resetZoomNextTime)
+ plotWidget.addCurve(
+ legend=legend,
+ x=range(len(data)),
+ y=data,
+ resetzoom=self.__resetZoomNextTime,
+ )
plotWidget.setActiveCurve(legend)
self.__resetZoomNextTime = True
@@ -928,7 +945,8 @@ class _Plot2dRecordView(DataView):
parent=parent,
modeId=RECORD_PLOT_MODE,
label="Curve",
- icon=icons.getQIcon("view-1d"))
+ icon=icons.getQIcon("view-1d"),
+ )
self.__resetZoomNextTime = True
self._data = None
self._xAxisDropDown = None
@@ -937,6 +955,7 @@ class _Plot2dRecordView(DataView):
def createWidget(self, parent):
from ._RecordPlot import RecordPlot
+
return RecordPlot(parent=parent)
def clear(self):
@@ -952,7 +971,9 @@ class _Plot2dRecordView(DataView):
self._data = self.normalizeData(data)
all_fields = sorted(self._data.dtype.fields.items(), key=lambda e: e[1][1])
- numeric_fields = [f[0] for f in all_fields if numpy.issubdtype(f[1][0], numpy.number)]
+ numeric_fields = [
+ f[0] for f in all_fields if numpy.issubdtype(f[1][0], numpy.number)
+ ]
if numeric_fields == self.__fields: # Reuse previously selected fields
fieldNameX = self.getWidget().getXAxisFieldName()
fieldNameY = self.getWidget().getYAxisFieldName()
@@ -974,8 +995,12 @@ class _Plot2dRecordView(DataView):
self._plotData(fieldNameX, fieldNameY)
if not self._xAxisDropDown:
- self._xAxisDropDown = self.getWidget().getAxesSelectionToolBar().getXAxisDropDown()
- self._yAxisDropDown = self.getWidget().getAxesSelectionToolBar().getYAxisDropDown()
+ self._xAxisDropDown = (
+ self.getWidget().getAxesSelectionToolBar().getXAxisDropDown()
+ )
+ self._yAxisDropDown = (
+ self.getWidget().getAxesSelectionToolBar().getYAxisDropDown()
+ )
self._xAxisDropDown.activated.connect(self._onAxesSelectionChaned)
self._yAxisDropDown.activated.connect(self._onAxesSelectionChaned)
@@ -993,10 +1018,9 @@ class _Plot2dRecordView(DataView):
xdata = numpy.arange(len(ydata))
else:
xdata = self._data[fieldNameX]
- self.getWidget().addCurve(legend="data",
- x=xdata,
- y=ydata,
- resetzoom=self.__resetZoomNextTime)
+ self.getWidget().addCurve(
+ legend="data", x=xdata, y=ydata, resetzoom=self.__resetZoomNextTime
+ )
self.getWidget().setXAxisFieldName(fieldNameX)
self.getWidget().setYAxisFieldName(fieldNameY)
self.__resetZoomNextTime = True
@@ -1031,18 +1055,20 @@ class _Plot2dView(DataView):
parent=parent,
modeId=PLOT2D_MODE,
label="Image",
- icon=icons.getQIcon("view-2d"))
+ icon=icons.getQIcon("view-2d"),
+ )
self.__resetZoomNextTime = True
def createWidget(self, parent):
from silx.gui import plot
+
widget = plot.Plot2D(parent=parent)
widget.setDefaultColormap(self.defaultColormap())
- widget.getColormapAction().setColorDialog(self.defaultColorDialog())
+ widget.getColormapAction().setColormapDialog(self.defaultColorDialog())
widget.getIntensityHistogramAction().setVisible(True)
widget.setKeepDataAspectRatio(True)
- widget.getXAxis().setLabel('X')
- widget.getYAxis().setLabel('Y')
+ widget.getXAxis().setLabel("X")
+ widget.getYAxis().setLabel("Y")
maskToolsWidget = widget.getMaskToolsDockWidget().widget()
maskToolsWidget.setItemMaskUpdated(True)
return widget
@@ -1058,9 +1084,9 @@ class _Plot2dView(DataView):
def setData(self, data):
data = self.normalizeData(data)
- self.getWidget().addImage(legend="data",
- data=data,
- resetzoom=self.__resetZoomNextTime)
+ self.getWidget().addImage(
+ legend="data", data=data, resetzoom=self.__resetZoomNextTime
+ )
self.__resetZoomNextTime = False
def setDataSelection(self, selection):
@@ -1072,9 +1098,7 @@ class _Plot2dView(DataView):
def getDataPriority(self, data, info):
if info.size <= 0:
return DataView.UNSUPPORTED
- if (data is None or
- not info.isArray or
- not (info.isNumeric or info.isBoolean)):
+ if data is None or not info.isArray or not (info.isNumeric or info.isBoolean):
return DataView.UNSUPPORTED
if info.dim < 2:
return DataView.UNSUPPORTED
@@ -1094,7 +1118,8 @@ class _Plot3dView(DataView):
parent=parent,
modeId=PLOT3D_MODE,
label="Cube",
- icon=icons.getQIcon("view-3d"))
+ icon=icons.getQIcon("view-3d"),
+ )
try:
from ._VolumeWindow import VolumeWindow # noqa
except ImportError:
@@ -1145,20 +1170,32 @@ class _ComplexImageView(DataView):
parent=parent,
modeId=COMPLEX_IMAGE_MODE,
label="Complex Image",
- icon=icons.getQIcon("view-2d"))
+ icon=icons.getQIcon("view-2d"),
+ )
def createWidget(self, parent):
from silx.gui.plot.ComplexImageView import ComplexImageView
+
widget = ComplexImageView(parent=parent)
- widget.setColormap(self.defaultColormap(), mode=ComplexImageView.ComplexMode.ABSOLUTE)
- widget.setColormap(self.defaultColormap(), mode=ComplexImageView.ComplexMode.SQUARE_AMPLITUDE)
- widget.setColormap(self.defaultColormap(), mode=ComplexImageView.ComplexMode.REAL)
- widget.setColormap(self.defaultColormap(), mode=ComplexImageView.ComplexMode.IMAGINARY)
- widget.getPlot().getColormapAction().setColorDialog(self.defaultColorDialog())
+ widget.setColormap(
+ self.defaultColormap(), mode=ComplexImageView.ComplexMode.ABSOLUTE
+ )
+ widget.setColormap(
+ self.defaultColormap(), mode=ComplexImageView.ComplexMode.SQUARE_AMPLITUDE
+ )
+ widget.setColormap(
+ self.defaultColormap(), mode=ComplexImageView.ComplexMode.REAL
+ )
+ widget.setColormap(
+ self.defaultColormap(), mode=ComplexImageView.ComplexMode.IMAGINARY
+ )
+ widget.getPlot().getColormapAction().setColormapDialog(
+ self.defaultColorDialog()
+ )
widget.getPlot().getIntensityHistogramAction().setVisible(True)
widget.getPlot().setKeepDataAspectRatio(True)
- widget.getXAxis().setLabel('X')
- widget.getYAxis().setLabel('Y')
+ widget.getXAxis().setLabel("X")
+ widget.getYAxis().setLabel("Y")
maskToolsWidget = widget.getPlot().getMaskToolsDockWidget().widget()
maskToolsWidget.setItemMaskUpdated(True)
return widget
@@ -1175,8 +1212,7 @@ class _ComplexImageView(DataView):
self.getWidget().setData(data)
def setDataSelection(self, selection):
- self.getWidget().getPlot().setGraphTitle(
- self.titleForSelection(selection))
+ self.getWidget().getPlot().setGraphTitle(self.titleForSelection(selection))
def axesNames(self, data, info):
return ["y", "x"]
@@ -1204,6 +1240,7 @@ class _ArrayView(DataView):
def createWidget(self, parent):
from silx.gui.data.ArrayTableWidget import ArrayTableWidget
+
widget = ArrayTableWidget(parent)
widget.displayAxesSelector(False)
return widget
@@ -1238,7 +1275,8 @@ class _StackView(DataView):
parent=parent,
modeId=STACK_MODE,
label="Image stack",
- icon=icons.getQIcon("view-2d-stack"))
+ icon=icons.getQIcon("view-2d-stack"),
+ )
self.__resetZoomNextTime = True
def customAxisNames(self):
@@ -1252,9 +1290,12 @@ class _StackView(DataView):
def createWidget(self, parent):
from silx.gui import plot
+
widget = plot.StackView(parent=parent)
widget.setColormap(self.defaultColormap())
- widget.getPlotWidget().getColormapAction().setColorDialog(self.defaultColorDialog())
+ widget.getPlotWidget().getColormapAction().setColormapDialog(
+ self.defaultColorDialog()
+ )
widget.setKeepDataAspectRatio(True)
widget.setLabels(self.axesNames(None, None))
# hide default option panel
@@ -1281,8 +1322,7 @@ class _StackView(DataView):
def setDataSelection(self, selection):
title = self.titleForSelection(selection)
- self.getWidget().setTitleCallback(
- lambda idx: "%s z=%d" % (title, idx))
+ self.getWidget().setTitleCallback(lambda idx: "%s z=%d" % (title, idx))
def axesNames(self, data, info):
return ["depth", "y", "x"]
@@ -1350,6 +1390,7 @@ class _RecordView(DataView):
def createWidget(self, parent):
from .RecordTableView import RecordTableView
+
widget = RecordTableView(parent)
widget.setWordWrap(False)
return widget
@@ -1394,6 +1435,7 @@ class _HexaView(DataView):
def createWidget(self, parent):
from .HexaTableView import HexaTableView
+
widget = HexaTableView(parent)
return widget
@@ -1424,10 +1466,12 @@ class _Hdf5View(DataView):
parent=parent,
modeId=HDF5_MODE,
label="HDF5",
- icon=icons.getQIcon("view-hdf5"))
+ icon=icons.getQIcon("view-hdf5"),
+ )
def createWidget(self, parent):
from .Hdf5TableView import Hdf5TableView
+
widget = Hdf5TableView(parent)
return widget
@@ -1459,10 +1503,8 @@ class _RawView(CompositeDataView):
def __init__(self, parent):
super(_RawView, self).__init__(
- parent=parent,
- modeId=RAW_MODE,
- label="Raw",
- icon=icons.getQIcon("view-raw"))
+ parent=parent, modeId=RAW_MODE, label="Raw", icon=icons.getQIcon("view-raw")
+ )
self.addView(_HexaView(parent))
self.addView(_ScalarView(parent))
self.addView(_ArrayView(parent))
@@ -1480,7 +1522,8 @@ class _ImageView(CompositeDataView):
parent=parent,
modeId=IMAGE_MODE,
label="Image",
- icon=icons.getQIcon("view-2d"))
+ icon=icons.getQIcon("view-2d"),
+ )
self.addView(_ComplexImageView(parent))
self.addView(_Plot2dView(parent))
@@ -1489,9 +1532,9 @@ class _InvalidNXdataView(DataView):
"""DataView showing a simple label with an error message
to inform that a group with @NX_class=NXdata cannot be
interpreted by any NXDataview."""
+
def __init__(self, parent):
- DataView.__init__(self, parent,
- modeId=NXDATA_INVALID_MODE)
+ DataView.__init__(self, parent, modeId=NXDATA_INVALID_MODE)
self._msg = ""
def createWidget(self, parent):
@@ -1532,8 +1575,10 @@ class _InvalidNXdataView(DataView):
self._msg += "@default attribute, "
if default_nxdata_name not in default_entry:
self._msg += " but no corresponding NXdata group exists."
- elif get_attr_as_unicode(default_entry[default_nxdata_name],
- "NX_class") != "NXdata":
+ elif (
+ get_attr_as_unicode(default_entry[default_nxdata_name], "NX_class")
+ != "NXdata"
+ ):
self._msg += " but the corresponding item is not a "
self._msg += "NXdata group."
else:
@@ -1544,7 +1589,10 @@ class _InvalidNXdataView(DataView):
default_nxdata_name = data.attrs["default"]
if default_nxdata_name not in data:
self._msg += " but no corresponding NXdata group exists."
- elif get_attr_as_unicode(data[default_nxdata_name], "NX_class") != "NXdata":
+ elif (
+ get_attr_as_unicode(data[default_nxdata_name], "NX_class")
+ != "NXdata"
+ ):
self._msg += " but the corresponding item is not a "
self._msg += "NXdata group."
else:
@@ -1564,18 +1612,20 @@ class _NXdataBaseDataView(DataView):
cmap_norm = nxdata.plot_style.signal_scale_type
if cmap_norm is not None:
self.defaultColormap().setNormalization(
- 'log' if cmap_norm == 'log' else 'linear')
+ "log" if cmap_norm == "log" else "linear"
+ )
class _NXdataScalarView(_NXdataBaseDataView):
"""DataView using a table view for displaying NXdata scalars:
0-D signal or n-D signal with *@interpretation=scalar*"""
+
def __init__(self, parent):
- _NXdataBaseDataView.__init__(
- self, parent, modeId=NXDATA_SCALAR_MODE)
+ _NXdataBaseDataView.__init__(self, parent, modeId=NXDATA_SCALAR_MODE)
def createWidget(self, parent):
from silx.gui.data.ArrayTableWidget import ArrayTableWidget
+
widget = ArrayTableWidget(parent)
# widget.displayAxesSelector(False)
return widget
@@ -1584,16 +1634,14 @@ class _NXdataScalarView(_NXdataBaseDataView):
return ["col", "row"]
def clear(self):
- self.getWidget().setArrayData(numpy.array([[]]),
- labels=True)
+ self.getWidget().setArrayData(numpy.array([[]]), labels=True)
def setData(self, data):
data = self.normalizeData(data)
# data could be a NXdata or an NXentry
nxd = nxdata.get_default(data, validate=False)
signal = nxd.signal
- self.getWidget().setArrayData(signal,
- labels=True)
+ self.getWidget().setArrayData(signal, labels=True)
def getDataPriority(self, data, info):
data = self.normalizeData(data)
@@ -1612,12 +1660,13 @@ class _NXdataCurveView(_NXdataBaseDataView):
It also handles basic scatter plots:
a 1-D signal with one axis whose values are not monotonically increasing.
"""
+
def __init__(self, parent):
- _NXdataBaseDataView.__init__(
- self, parent, modeId=NXDATA_CURVE_MODE)
+ _NXdataBaseDataView.__init__(self, parent, modeId=NXDATA_CURVE_MODE)
def createWidget(self, parent):
from silx.gui.data.NXdataWidgets import ArrayCurvePlot
+
widget = ArrayCurvePlot(parent)
return widget
@@ -1637,24 +1686,17 @@ class _NXdataCurveView(_NXdataBaseDataView):
else:
x_errors = None
- # this fix is necessary until the next release of PyMca (5.2.3 or 5.3.0)
- # see https://github.com/vasole/pymca/issues/144 and https://github.com/vasole/pymca/pull/145
- if not hasattr(self.getWidget(), "setCurvesData") and \
- hasattr(self.getWidget(), "setCurveData"):
- _logger.warning("Using deprecated ArrayCurvePlot API, "
- "without support of auxiliary signals")
- self.getWidget().setCurveData(nxd.signal, nxd.axes[-1],
- yerror=nxd.errors, xerror=x_errors,
- ylabel=nxd.signal_name, xlabel=nxd.axes_names[-1],
- title=nxd.title or nxd.signal_name)
- return
-
- self.getWidget().setCurvesData([nxd.signal] + nxd.auxiliary_signals, nxd.axes[-1],
- yerror=nxd.errors, xerror=x_errors,
- ylabels=signals_names, xlabel=nxd.axes_names[-1],
- title=nxd.title or signals_names[0],
- xscale=nxd.plot_style.axes_scale_types[-1],
- yscale=nxd.plot_style.signal_scale_type)
+ self.getWidget().setCurvesData(
+ [nxd.signal] + nxd.auxiliary_signals,
+ nxd.axes[-1],
+ yerror=nxd.errors,
+ xerror=x_errors,
+ ylabels=signals_names,
+ xlabel=nxd.axes_names[-1],
+ title=nxd.title or signals_names[0],
+ xscale=nxd.plot_style.axes_scale_types[-1],
+ yscale=nxd.plot_style.signal_scale_type,
+ )
def getDataPriority(self, data, info):
data = self.normalizeData(data)
@@ -1667,16 +1709,18 @@ class _NXdataCurveView(_NXdataBaseDataView):
class _NXdataXYVScatterView(_NXdataBaseDataView):
"""DataView using a Plot1D for displaying NXdata 3D scatters as
a scatter of coloured points (1-D signal with 2 axes)"""
+
def __init__(self, parent):
- _NXdataBaseDataView.__init__(
- self, parent, modeId=NXDATA_XYVSCATTER_MODE)
+ _NXdataBaseDataView.__init__(self, parent, modeId=NXDATA_XYVSCATTER_MODE)
def createWidget(self, parent):
from silx.gui.data.NXdataWidgets import XYVScatterPlot
+
widget = XYVScatterPlot(parent)
widget.getScatterView().setColormap(self.defaultColormap())
- widget.getScatterView().getScatterToolBar().getColormapAction().setColorDialog(
- self.defaultColorDialog())
+ widget.getScatterView().getScatterToolBar().getColormapAction().setColormapDialog(
+ self.defaultColorDialog()
+ )
return widget
def axesNames(self, data, info):
@@ -1709,13 +1753,19 @@ class _NXdataXYVScatterView(_NXdataBaseDataView):
self._updateColormap(nxd)
- self.getWidget().setScattersData(y_axis, x_axis, values=[nxd.signal] + nxd.auxiliary_signals,
- yerror=y_errors, xerror=x_errors,
- ylabel=y_label, xlabel=x_label,
- title=nxd.title,
- scatter_titles=[nxd.signal_name] + nxd.auxiliary_signals_names,
- xscale=nxd.plot_style.axes_scale_types[-2],
- yscale=nxd.plot_style.axes_scale_types[-1])
+ self.getWidget().setScattersData(
+ y_axis,
+ x_axis,
+ values=[nxd.signal] + nxd.auxiliary_signals,
+ yerror=y_errors,
+ xerror=x_errors,
+ ylabel=y_label,
+ xlabel=x_label,
+ title=nxd.title,
+ scatter_titles=[nxd.signal_name] + nxd.auxiliary_signals_names,
+ xscale=nxd.plot_style.axes_scale_types[-2],
+ yscale=nxd.plot_style.axes_scale_types[-1],
+ )
def getDataPriority(self, data, info):
data = self.normalizeData(data)
@@ -1730,15 +1780,18 @@ class _NXdataXYVScatterView(_NXdataBaseDataView):
class _NXdataImageView(_NXdataBaseDataView):
"""DataView using a Plot2D for displaying NXdata images:
2-D signal or n-D signals with *@interpretation=image*."""
+
def __init__(self, parent):
- _NXdataBaseDataView.__init__(
- self, parent, modeId=NXDATA_IMAGE_MODE)
+ _NXdataBaseDataView.__init__(self, parent, modeId=NXDATA_IMAGE_MODE)
def createWidget(self, parent):
from silx.gui.data.NXdataWidgets import ArrayImagePlot
+
widget = ArrayImagePlot(parent)
widget.getPlot().setDefaultColormap(self.defaultColormap())
- widget.getPlot().getColormapAction().setColorDialog(self.defaultColorDialog())
+ widget.getPlot().getColormapAction().setColormapDialog(
+ self.defaultColorDialog()
+ )
return widget
def axesNames(self, data, info):
@@ -1760,16 +1813,20 @@ class _NXdataImageView(_NXdataBaseDataView):
y_axis, x_axis = nxd.axes[img_slicing]
y_label, x_label = nxd.axes_names[img_slicing]
y_scale, x_scale = nxd.plot_style.axes_scale_types[img_slicing]
- x_units = get_attr_as_unicode(x_axis, 'units') if x_axis else None
- y_units = get_attr_as_unicode(y_axis, 'units') if y_axis else None
+ x_units = get_attr_as_unicode(x_axis, "units") if x_axis else None
+ y_units = get_attr_as_unicode(y_axis, "units") if y_axis else None
self.getWidget().setImageData(
[nxd.signal] + nxd.auxiliary_signals,
- x_axis=x_axis, y_axis=y_axis,
+ x_axis=x_axis,
+ y_axis=y_axis,
signals_names=[nxd.signal_name] + nxd.auxiliary_signals_names,
- xlabel=x_label, ylabel=y_label,
- title=nxd.title, isRgba=isRgba,
- xscale=x_scale, yscale=y_scale,
+ xlabel=x_label,
+ ylabel=y_label,
+ title=nxd.title,
+ isRgba=isRgba,
+ xscale=x_scale,
+ yscale=y_scale,
keep_ratio=(x_units == y_units),
)
@@ -1786,14 +1843,17 @@ class _NXdataImageView(_NXdataBaseDataView):
class _NXdataComplexImageView(_NXdataBaseDataView):
"""DataView using a ComplexImageView for displaying NXdata complex images:
2-D signal or n-D signals with *@interpretation=image*."""
+
def __init__(self, parent):
- _NXdataBaseDataView.__init__(
- self, parent, modeId=NXDATA_IMAGE_MODE)
+ _NXdataBaseDataView.__init__(self, parent, modeId=NXDATA_IMAGE_MODE)
def createWidget(self, parent):
from silx.gui.data.NXdataWidgets import ArrayComplexImagePlot
+
widget = ArrayComplexImagePlot(parent, colormap=self.defaultColormap())
- widget.getPlot().getColormapAction().setColorDialog(self.defaultColorDialog())
+ widget.getPlot().getColormapAction().setColormapDialog(
+ self.defaultColorDialog()
+ )
return widget
def clear(self):
@@ -1809,14 +1869,16 @@ class _NXdataComplexImageView(_NXdataBaseDataView):
img_slicing = slice(-2, None)
y_axis, x_axis = nxd.axes[img_slicing]
y_label, x_label = nxd.axes_names[img_slicing]
- x_units = get_attr_as_unicode(x_axis, 'units') if x_axis else None
- y_units = get_attr_as_unicode(y_axis, 'units') if y_axis else None
+ x_units = get_attr_as_unicode(x_axis, "units") if x_axis else None
+ y_units = get_attr_as_unicode(y_axis, "units") if y_axis else None
self.getWidget().setImageData(
[nxd.signal] + nxd.auxiliary_signals,
- x_axis=x_axis, y_axis=y_axis,
+ x_axis=x_axis,
+ y_axis=y_axis,
signals_names=[nxd.signal_name] + nxd.auxiliary_signals_names,
- xlabel=x_label, ylabel=y_label,
+ xlabel=x_label,
+ ylabel=y_label,
title=nxd.title,
keep_ratio=(x_units == y_units),
)
@@ -1838,14 +1900,16 @@ class _NXdataComplexImageView(_NXdataBaseDataView):
class _NXdataStackView(_NXdataBaseDataView):
def __init__(self, parent):
- _NXdataBaseDataView.__init__(
- self, parent, modeId=NXDATA_STACK_MODE)
+ _NXdataBaseDataView.__init__(self, parent, modeId=NXDATA_STACK_MODE)
def createWidget(self, parent):
from silx.gui.data.NXdataWidgets import ArrayStackPlot
+
widget = ArrayStackPlot(parent)
widget.getStackView().setColormap(self.defaultColormap())
- widget.getStackView().getPlotWidget().getColormapAction().setColorDialog(self.defaultColorDialog())
+ widget.getStackView().getPlotWidget().getColormapAction().setColormapDialog(
+ self.defaultColorDialog()
+ )
return widget
def axesNames(self, data, info):
@@ -1867,10 +1931,16 @@ class _NXdataStackView(_NXdataBaseDataView):
widget = self.getWidget()
widget.setStackData(
- nxd.signal, x_axis=x_axis, y_axis=y_axis, z_axis=z_axis,
- signal_name=signal_name,
- xlabel=x_label, ylabel=y_label, zlabel=z_label,
- title=title)
+ nxd.signal,
+ x_axis=x_axis,
+ y_axis=y_axis,
+ z_axis=z_axis,
+ signal_name=signal_name,
+ xlabel=x_label,
+ ylabel=y_label,
+ zlabel=z_label,
+ title=title,
+ )
# Override the colormap, while setStack overwrite it
widget.getStackView().setColormap(self.defaultColormap())
@@ -1886,10 +1956,12 @@ class _NXdataStackView(_NXdataBaseDataView):
class _NXdataVolumeView(_NXdataBaseDataView):
def __init__(self, parent):
_NXdataBaseDataView.__init__(
- self, parent,
+ self,
+ parent,
label="NXdata (3D)",
icon=icons.getQIcon("view-nexus"),
- modeId=NXDATA_VOLUME_MODE)
+ modeId=NXDATA_VOLUME_MODE,
+ )
try:
import silx.gui.plot3d # noqa
except ImportError:
@@ -1904,6 +1976,7 @@ class _NXdataVolumeView(_NXdataBaseDataView):
def createWidget(self, parent):
from silx.gui.data.NXdataWidgets import ArrayVolumePlot
+
widget = ArrayVolumePlot(parent)
return widget
@@ -1924,10 +1997,16 @@ class _NXdataVolumeView(_NXdataBaseDataView):
widget = self.getWidget()
widget.setData(
- nxd.signal, x_axis=x_axis, y_axis=y_axis, z_axis=z_axis,
+ nxd.signal,
+ x_axis=x_axis,
+ y_axis=y_axis,
+ z_axis=z_axis,
signal_name=signal_name,
- xlabel=x_label, ylabel=y_label, zlabel=z_label,
- title=title)
+ xlabel=x_label,
+ ylabel=y_label,
+ zlabel=z_label,
+ title=title,
+ )
def getDataPriority(self, data, info):
data = self.normalizeData(data)
@@ -1941,16 +2020,21 @@ class _NXdataVolumeView(_NXdataBaseDataView):
class _NXdataVolumeAsStackView(_NXdataBaseDataView):
def __init__(self, parent):
_NXdataBaseDataView.__init__(
- self, parent,
+ self,
+ parent,
label="NXdata (2D)",
icon=icons.getQIcon("view-nexus"),
- modeId=NXDATA_VOLUME_AS_STACK_MODE)
+ modeId=NXDATA_VOLUME_AS_STACK_MODE,
+ )
def createWidget(self, parent):
from silx.gui.data.NXdataWidgets import ArrayStackPlot
+
widget = ArrayStackPlot(parent)
widget.getStackView().setColormap(self.defaultColormap())
- widget.getStackView().getPlotWidget().getColormapAction().setColorDialog(self.defaultColorDialog())
+ widget.getStackView().getPlotWidget().getColormapAction().setColormapDialog(
+ self.defaultColorDialog()
+ )
return widget
def axesNames(self, data, info):
@@ -1972,10 +2056,16 @@ class _NXdataVolumeAsStackView(_NXdataBaseDataView):
widget = self.getWidget()
widget.setStackData(
- nxd.signal, x_axis=x_axis, y_axis=y_axis, z_axis=z_axis,
- signal_name=signal_name,
- xlabel=x_label, ylabel=y_label, zlabel=z_label,
- title=title)
+ nxd.signal,
+ x_axis=x_axis,
+ y_axis=y_axis,
+ z_axis=z_axis,
+ signal_name=signal_name,
+ xlabel=x_label,
+ ylabel=y_label,
+ zlabel=z_label,
+ title=title,
+ )
# Override the colormap, while setStack overwrite it
widget.getStackView().setColormap(self.defaultColormap())
@@ -1989,19 +2079,25 @@ class _NXdataVolumeAsStackView(_NXdataBaseDataView):
return DataView.UNSUPPORTED
+
class _NXdataComplexVolumeAsStackView(_NXdataBaseDataView):
def __init__(self, parent):
_NXdataBaseDataView.__init__(
- self, parent,
+ self,
+ parent,
label="NXdata (2D)",
icon=icons.getQIcon("view-nexus"),
- modeId=NXDATA_VOLUME_AS_STACK_MODE)
+ modeId=NXDATA_VOLUME_AS_STACK_MODE,
+ )
self._is_complex_data = False
def createWidget(self, parent):
from silx.gui.data.NXdataWidgets import ArrayComplexImagePlot
+
widget = ArrayComplexImagePlot(parent, colormap=self.defaultColormap())
- widget.getPlot().getColormapAction().setColorDialog(self.defaultColorDialog())
+ widget.getPlot().getColormapAction().setColormapDialog(
+ self.defaultColorDialog()
+ )
return widget
def axesNames(self, data, info):
@@ -2023,9 +2119,13 @@ class _NXdataComplexVolumeAsStackView(_NXdataBaseDataView):
self.getWidget().setImageData(
[nxd.signal] + nxd.auxiliary_signals,
- x_axis=x_axis, y_axis=y_axis,
+ x_axis=x_axis,
+ y_axis=y_axis,
signals_names=[nxd.signal_name] + nxd.auxiliary_signals_names,
- xlabel=x_label, ylabel=y_label, title=nxd.title)
+ xlabel=x_label,
+ ylabel=y_label,
+ title=nxd.title,
+ )
def getDataPriority(self, data, info):
data = self.normalizeData(data)
@@ -2041,12 +2141,14 @@ class _NXdataComplexVolumeAsStackView(_NXdataBaseDataView):
class _NXdataView(CompositeDataView):
"""Composite view displaying NXdata groups using the most adequate
widget depending on the dimensionality."""
+
def __init__(self, parent):
super(_NXdataView, self).__init__(
parent=parent,
label="NXdata",
modeId=NXDATA_MODE,
- icon=icons.getQIcon("view-nexus"))
+ icon=icons.getQIcon("view-nexus"),
+ )
self.addView(_InvalidNXdataView(parent))
self.addView(_NXdataScalarView(parent))
diff --git a/src/silx/gui/data/Hdf5TableView.py b/src/silx/gui/data/Hdf5TableView.py
index f3fbb69..bb14768 100644
--- a/src/silx/gui/data/Hdf5TableView.py
+++ b/src/silx/gui/data/Hdf5TableView.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -30,7 +30,6 @@ __authors__ = ["V. Valls"]
__license__ = "MIT"
__date__ = "12/02/2019"
-import collections
import functools
import os.path
import logging
@@ -39,6 +38,7 @@ import numpy
from silx.gui import qt
import silx.io
+from silx.io import h5link_utils
from .TextFormatter import TextFormatter
import silx.gui.hdf5
from silx.gui.widgets import HierarchicalTableView
@@ -50,8 +50,8 @@ _logger = logging.getLogger(__name__)
class _CellData(object):
- """Store a table item
- """
+ """Store a table item"""
+
def __init__(self, value=None, isHeader=False, span=None, tooltip=None):
"""
Constructor
@@ -73,8 +73,7 @@ class _CellData(object):
return self.__isHeader
def value(self):
- """Returns the value of the item.
- """
+ """Returns the value of the item."""
return self.__value
def span(self):
@@ -187,10 +186,18 @@ class _CellFilterAvailableData(_CellData):
_states = {
True: ("Available", qt.QColor(0x000000), None, None),
- False: ("Not available", qt.QColor(0xFFFFFF), qt.QColor(0xFF0000),
- "You have to install this filter on your system to be able to read this dataset"),
- "na": ("n.a.", qt.QColor(0x000000), None,
- "This version of h5py/hdf5 is not able to display the information"),
+ False: (
+ "Not available",
+ qt.QColor(0xFFFFFF),
+ qt.QColor(0xFF0000),
+ "You have to install this filter on your system to be able to read this dataset",
+ ),
+ "na": (
+ "n.a.",
+ qt.QColor(0x000000),
+ None,
+ "This version of h5py/hdf5 is not able to display the information",
+ ),
}
def __init__(self, filterId):
@@ -309,7 +316,9 @@ class Hdf5TableModel(HierarchicalTableView.HierarchicalTableModel):
if h5pyObject is None or self.isSupportedObject(h5pyObject):
self.__obj = h5pyObject
else:
- _logger.warning("Object class %s unsupported. Object ignored.", type(h5pyObject))
+ _logger.warning(
+ "Object class %s unsupported. Object ignored.", type(h5pyObject)
+ )
self.__initProperties()
self.endResetModel()
@@ -319,10 +328,12 @@ class Hdf5TableModel(HierarchicalTableView.HierarchicalTableModel):
return self.__hdf5Formatter.humanReadableHdf5Type(dataset)
def __attributeTooltip(self, attribute):
- attributeDict = collections.OrderedDict()
+ attributeDict = {}
if hasattr(attribute, "shape"):
attributeDict["Shape"] = self.__hdf5Formatter.humanReadableShape(attribute)
- attributeDict["Data type"] = self.__hdf5Formatter.humanReadableType(attribute, full=True)
+ attributeDict["Data type"] = self.__hdf5Formatter.humanReadableType(
+ attribute, full=True
+ )
html = htmlFromDict(attributeDict, title="HDF5 Attribute")
return html
@@ -336,7 +347,7 @@ class Hdf5TableModel(HierarchicalTableView.HierarchicalTableModel):
return self.__hdf5Formatter.humanReadableShape(dataset)
size = dataset.size
shape = self.__hdf5Formatter.humanReadableShape(dataset)
- return u"%s = %s" % (shape, size)
+ return "%s = %s" % (shape, size)
def __formatChunks(self, dataset):
"""Format the shape"""
@@ -344,7 +355,7 @@ class Hdf5TableModel(HierarchicalTableView.HierarchicalTableModel):
if chunks is None:
return ""
shape = " \u00D7 ".join([str(i) for i in chunks])
- sizes = numpy.product(chunks)
+ sizes = numpy.prod(chunks)
text = "%s = %s" % (shape, sizes)
return text
@@ -383,7 +394,9 @@ class Hdf5TableModel(HierarchicalTableView.HierarchicalTableModel):
self.__data.addHeaderValueRow("Local", local)
else:
# it's a real H5py object
- self.__data.addHeaderValueRow("Basename", lambda x: os.path.basename(x.name))
+ self.__data.addHeaderValueRow(
+ "Basename", lambda x: os.path.basename(x.name)
+ )
self.__data.addHeaderValueRow("Name", lambda x: x.name)
if obj.file is not None:
self.__data.addHeaderValueRow("File", lambda x: x.file.filename)
@@ -399,25 +412,10 @@ class Hdf5TableModel(HierarchicalTableView.HierarchicalTableModel):
showPhysicalLocation = False
# External data (nothing to do with external links)
- nExtSources = 0
- firstExtSource = None
- extType = None
- if silx.io.is_dataset(hdf5obj):
- if hasattr(hdf5obj, "is_virtual"):
- if hdf5obj.is_virtual:
- extSources = hdf5obj.virtual_sources()
- if extSources:
- firstExtSource = extSources[0].file_name + SEPARATOR + extSources[0].dset_name
- extType = "Virtual"
- nExtSources = len(extSources)
- if hasattr(hdf5obj, "external"):
- extSources = hdf5obj.external
- if extSources:
- firstExtSource = extSources[0][0]
- extType = "Raw"
- nExtSources = len(extSources)
+ external_dataset_info = h5link_utils.external_dataset_info(hdf5obj)
if showPhysicalLocation:
+
def _physical_location(x):
if isinstance(obj, silx.gui.hdf5.H5Node):
return x.physical_filename + SEPARATOR + x.physical_name
@@ -431,34 +429,15 @@ class Hdf5TableModel(HierarchicalTableView.HierarchicalTableModel):
self.__data.addHeaderValueRow("Physical", _physical_location)
- if extType:
- def _first_source(x):
- # Absolute path
- if os.path.isabs(firstExtSource):
- return firstExtSource
-
- # Relative path with respect to the file directory
- if isinstance(obj, silx.gui.hdf5.H5Node):
- filename = x.physical_filename
- elif silx.io.is_file(obj):
- filename = x.filename
- elif obj.file is not None:
- filename = x.file.filename
- else:
- return firstExtSource
-
- if firstExtSource[0] == ".":
- return filename + firstExtSource[1:]
- else:
- return os.path.join(os.path.dirname(filename), firstExtSource)
-
+ if external_dataset_info is not None:
self.__data.addHeaderRow(headerLabel="External sources")
- self.__data.addHeaderValueRow("Type", extType)
- self.__data.addHeaderValueRow("Count", str(nExtSources))
- self.__data.addHeaderValueRow("First", _first_source)
+ self.__data.addHeaderValueRow("Type", external_dataset_info.type)
+ self.__data.addHeaderValueRow("Count", external_dataset_info.nfiles)
+ self.__data.addHeaderValueRow(
+ "First", external_dataset_info.first_source_url
+ )
if hasattr(obj, "dtype"):
-
self.__data.addHeaderRow(headerLabel="Data info")
if hasattr(obj, "id") and hasattr(obj.id, "get_type"):
@@ -500,10 +479,14 @@ class Hdf5TableModel(HierarchicalTableView.HierarchicalTableModel):
self.__data.addHeaderRow(headerLabel="Attributes")
for key in sorted(obj.attrs.keys()):
callback = lambda key, x: self.__formatter.toString(x.attrs[key])
- callbackTooltip = lambda key, x: self.__attributeTooltip(x.attrs[key])
- self.__data.addHeaderValueRow(headerLabel=key,
- value=functools.partial(callback, key),
- tooltip=functools.partial(callbackTooltip, key))
+ callbackTooltip = lambda key, x: self.__attributeTooltip(
+ x.attrs[key]
+ )
+ self.__data.addHeaderValueRow(
+ headerLabel=key,
+ value=functools.partial(callback, key),
+ tooltip=functools.partial(callbackTooltip, key),
+ )
def __getFilterInfo(self, dataset, filterIndex):
"""Get a tuple of readable info from dataset filters
@@ -558,8 +541,7 @@ class Hdf5TableModel(HierarchicalTableView.HierarchicalTableModel):
return self.__formatter
def __formatChanged(self):
- """Called when the format changed.
- """
+ """Called when the format changed."""
self.reset()
@@ -628,6 +610,11 @@ class Hdf5TableView(HierarchicalTableView.HierarchicalTableView):
for row in range(model.rowCount()):
for column in range(model.columnCount()):
index = model.index(row, column)
- if (index.isValid() and index.data(
- HierarchicalTableView.HierarchicalTableModel.IsHeaderRole) is False):
+ if (
+ index.isValid()
+ and index.data(
+ HierarchicalTableView.HierarchicalTableModel.IsHeaderRole
+ )
+ is False
+ ):
self.openPersistentEditor(index)
diff --git a/src/silx/gui/data/HexaTableView.py b/src/silx/gui/data/HexaTableView.py
index 30f62f0..f50bf88 100644
--- a/src/silx/gui/data/HexaTableView.py
+++ b/src/silx/gui/data/HexaTableView.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -26,8 +26,6 @@ This module defines model and widget to display raw data using an
hexadecimal viewer.
"""
-import collections
-
import numpy
from silx.gui import qt
@@ -46,7 +44,7 @@ class _VoidConnector(object):
"""
def __init__(self, data):
- self.__cache = collections.OrderedDict()
+ self.__cache = {}
self.__len = data.itemsize
self.__data = data
@@ -55,10 +53,10 @@ class _VoidConnector(object):
pos = bufferId << 10
data = self.__data
if hasattr(data, "tobytes"):
- data = data.tobytes()[pos:pos + 1024]
+ data = data.tobytes()[pos : pos + 1024]
else:
# Old fashion
- data = data.data[pos:pos + 1024]
+ data = data.data[pos : pos + 1024]
self.__cache[bufferId] = data
if len(self.__cache) > 32:
@@ -98,6 +96,7 @@ class HexaTableModel(qt.QAbstractTableModel):
:param qt.QObject parent: Parent object
:param data: A numpy array or a h5py dataset
"""
+
def __init__(self, parent=None, data=None):
qt.QAbstractTableModel.__init__(self, parent)
@@ -136,7 +135,7 @@ class HexaTableModel(qt.QAbstractTableModel):
if role == qt.Qt.DisplayRole:
if column == 0x10:
- start = (row << 4)
+ start = row << 4
text = ""
for i in range(0x10):
pos = start + i
@@ -236,6 +235,7 @@ class HexaTableView(qt.QTableView):
It customs the column size to provide a better layout.
"""
+
def __init__(self, parent=None):
"""
Constructor
diff --git a/src/silx/gui/data/NXdataWidgets.py b/src/silx/gui/data/NXdataWidgets.py
index b9e34d2..a2bab7a 100644
--- a/src/silx/gui/data/NXdataWidgets.py
+++ b/src/silx/gui/data/NXdataWidgets.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2017-2022 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -59,6 +59,7 @@ class ArrayCurvePlot(qt.QWidget):
This widget also handles simple 2D or 3D scatter plots (third dimension
displayed as colour of points).
"""
+
def __init__(self, parent=None):
"""
@@ -97,10 +98,18 @@ class ArrayCurvePlot(qt.QWidget):
"""
return self._plot
- def setCurvesData(self, ys, x=None,
- yerror=None, xerror=None,
- ylabels=None, xlabel=None, title=None,
- xscale=None, yscale=None):
+ def setCurvesData(
+ self,
+ ys,
+ x=None,
+ yerror=None,
+ xerror=None,
+ ylabels=None,
+ xlabel=None,
+ title=None,
+ xscale=None,
+ yscale=None,
+ ):
"""
:param List[ndarray] ys: List of arrays to be represented by the y (vertical) axis.
@@ -139,11 +148,9 @@ class ArrayCurvePlot(qt.QWidget):
self._plot.setGraphTitle(title or "")
if xscale is not None:
- self._plot.getXAxis().setScale(
- 'log' if xscale == 'log' else 'linear')
+ self._plot.getXAxis().setScale("log" if xscale == "log" else "linear")
if yscale is not None:
- self._plot.getYAxis().setScale(
- 'log' if yscale == 'log' else 'linear')
+ self._plot.getYAxis().setScale("log" if yscale == "log" else "linear")
self._updateCurve()
if not self.__selector_is_connected:
@@ -168,8 +175,10 @@ class ArrayCurvePlot(qt.QWidget):
# Only remove curves that will no longer belong to the plot
# So remaining curves keep their settings
for item in self._plot.getItems():
- if (isinstance(item, items.Curve) and
- item.getName() not in self.__signals_names):
+ if (
+ isinstance(item, items.Curve)
+ and item.getName() not in self.__signals_names
+ ):
self._plot.remove(item)
for i in range(len(self.__signals)):
@@ -179,9 +188,9 @@ class ArrayCurvePlot(qt.QWidget):
y_errors = None
if i == 0 and self.__signal_errors is not None:
y_errors = self.__signal_errors[self._selector.selection()]
- self._plot.addCurve(x, ys[i], legend=legend,
- xerror=self.__x_axis_errors,
- yerror=y_errors)
+ self._plot.addCurve(
+ x, ys[i], legend=legend, xerror=self.__x_axis_errors, yerror=y_errors
+ )
if i == 0:
self._plot.setActiveCurve(legend)
@@ -207,6 +216,7 @@ class XYVScatterPlot(qt.QWidget):
Widget for plotting one or more scatters
(with identical x, y coordinates).
"""
+
def __init__(self, parent=None):
"""
@@ -229,9 +239,11 @@ class XYVScatterPlot(qt.QWidget):
self.__y_axis_errors = None
self._plot = ScatterView(self)
- self._plot.setColormap(Colormap(name="viridis",
- vmin=None, vmax=None,
- normalization=Colormap.LINEAR))
+ self._plot.setColormap(
+ Colormap(
+ name="viridis", vmin=None, vmax=None, normalization=Colormap.LINEAR
+ )
+ )
self._slider = HorizontalSliderWithBrowser(parent=self)
self._slider.setMinimum(0)
@@ -263,11 +275,20 @@ class XYVScatterPlot(qt.QWidget):
"""
return self._plot.getPlotWidget()
- def setScattersData(self, y, x, values,
- yerror=None, xerror=None,
- ylabel=None, xlabel=None,
- title="", scatter_titles=None,
- xscale=None, yscale=None):
+ def setScattersData(
+ self,
+ y,
+ x,
+ values,
+ yerror=None,
+ xerror=None,
+ ylabel=None,
+ xlabel=None,
+ title="",
+ scatter_titles=None,
+ xscale=None,
+ yscale=None,
+ ):
"""
:param ndarray y: 1D array for y (vertical) coordinates.
@@ -306,11 +327,9 @@ class XYVScatterPlot(qt.QWidget):
self._slider.valueChanged[int].connect(self._sliderIdxChanged)
if xscale is not None:
- self._plot.getXAxis().setScale(
- 'log' if xscale == 'log' else 'linear')
+ self._plot.getXAxis().setScale("log" if xscale == "log" else "linear")
if yscale is not None:
- self._plot.getYAxis().setScale(
- 'log' if yscale == 'log' else 'linear')
+ self._plot.getYAxis().setScale("log" if yscale == "log" else "linear")
self._updateScatter()
@@ -324,14 +343,18 @@ class XYVScatterPlot(qt.QWidget):
title = self.__graph_title # main NXdata @title
if len(self.__scatter_titles) > 1:
# Append dataset name only when there is many datasets
- title += '\n' + self.__scatter_titles[idx]
+ title += "\n" + self.__scatter_titles[idx]
else:
title = self.__scatter_titles[idx] # scatter dataset name
self._plot.setGraphTitle(title)
- self._plot.setData(x, y, self.__values[idx],
- xerror=self.__x_axis_errors,
- yerror=self.__y_axis_errors)
+ self._plot.setData(
+ x,
+ y,
+ self.__values[idx],
+ xerror=self.__x_axis_errors,
+ yerror=self.__y_axis_errors,
+ )
self._plot.resetZoom()
self._plot.getXAxis().setLabel(self.__x_axis_name)
self._plot.getYAxis().setLabel(self.__y_axis_name)
@@ -356,6 +379,7 @@ class ArrayImagePlot(qt.QWidget):
If one or both of the axes does not have regularly spaced values, the
the image is plotted as a coloured scatter plot.
"""
+
def __init__(self, parent=None):
"""
@@ -371,9 +395,11 @@ class ArrayImagePlot(qt.QWidget):
self.__y_axis_name = None
self._plot = Plot2D(self)
- self._plot.setDefaultColormap(Colormap(name="viridis",
- vmin=None, vmax=None,
- normalization=Colormap.LINEAR))
+ self._plot.setDefaultColormap(
+ Colormap(
+ name="viridis", vmin=None, vmax=None, normalization=Colormap.LINEAR
+ )
+ )
self._plot.getIntensityHistogramAction().setVisible(True)
self._plot.setKeepDataAspectRatio(True)
maskToolWidget = self._plot.getMaskToolsDockWidget().widget()
@@ -407,13 +433,20 @@ class ArrayImagePlot(qt.QWidget):
"""
return self._plot
- def setImageData(self, signals,
- x_axis=None, y_axis=None,
- signals_names=None,
- xlabel=None, ylabel=None,
- title=None, isRgba=False,
- xscale=None, yscale=None,
- keep_ratio: bool=True):
+ def setImageData(
+ self,
+ signals,
+ x_axis=None,
+ y_axis=None,
+ signals_names=None,
+ xlabel=None,
+ ylabel=None,
+ title=None,
+ isRgba=False,
+ xscale=None,
+ yscale=None,
+ keep_ratio: bool = True,
+ ):
"""
:param signals: list of n-D datasets, whose last 2 dimensions are used as the
@@ -466,13 +499,14 @@ class ArrayImagePlot(qt.QWidget):
self._auxSigSlider.setValue(0)
self._axis_scales = xscale, yscale
- self._updateImage()
- self._plot.setKeepDataAspectRatio(keep_ratio)
- self._plot.resetZoom()
self._selector.selectionChanged.connect(self._updateImage)
self._auxSigSlider.valueChanged.connect(self._sliderIdxChanged)
+ self._updateImage()
+ self._plot.setKeepDataAspectRatio(keep_ratio)
+ self._plot.resetZoom()
+
def _updateImage(self):
selection = self._selector.selection()
auxSigIdx = self._auxSigSlider.value()
@@ -494,7 +528,7 @@ class ArrayImagePlot(qt.QWidget):
x_axis = numpy.arange(image.shape[1])
elif numpy.isscalar(x_axis) or len(x_axis) == 1:
# constant axis
- x_axis = x_axis * numpy.ones((image.shape[1], ))
+ x_axis = x_axis * numpy.ones((image.shape[1],))
elif len(x_axis) == 2:
# linear calibration
x_axis = x_axis[0] * numpy.arange(image.shape[1]) + x_axis[1]
@@ -502,14 +536,25 @@ class ArrayImagePlot(qt.QWidget):
if y_axis is None:
y_axis = numpy.arange(image.shape[0])
elif numpy.isscalar(y_axis) or len(y_axis) == 1:
- y_axis = y_axis * numpy.ones((image.shape[0], ))
+ y_axis = y_axis * numpy.ones((image.shape[0],))
elif len(y_axis) == 2:
y_axis = y_axis[0] * numpy.arange(image.shape[0]) + y_axis[1]
- xcalib = ArrayCalibration(x_axis)
- ycalib = ArrayCalibration(y_axis)
-
- self._plot.remove(kind=("scatter", "image",))
+ try:
+ xcalib = ArrayCalibration(x_axis)
+ except ValueError:
+ xcalib = NoCalibration()
+ try:
+ ycalib = ArrayCalibration(y_axis)
+ except ValueError:
+ ycalib = NoCalibration()
+
+ self._plot.remove(
+ kind=(
+ "scatter",
+ "image",
+ )
+ )
if xcalib.is_affine() and ycalib.is_affine():
# regular image
xorigin, xscale = xcalib(0), xcalib.get_slope()
@@ -517,33 +562,42 @@ class ArrayImagePlot(qt.QWidget):
origin = (xorigin, yorigin)
scale = (xscale, yscale)
- self._plot.getXAxis().setScale('linear')
- self._plot.getYAxis().setScale('linear')
- self._plot.addImage(image, legend=legend,
- origin=origin, scale=scale,
- replace=True, resetzoom=False)
+ self._plot.getXAxis().setScale("linear")
+ self._plot.getYAxis().setScale("linear")
+ self._plot.addImage(
+ image,
+ legend=legend,
+ origin=origin,
+ scale=scale,
+ replace=True,
+ resetzoom=False,
+ )
else:
xaxisscale, yaxisscale = self._axis_scales
if xaxisscale is not None:
self._plot.getXAxis().setScale(
- 'log' if xaxisscale == 'log' else 'linear')
+ "log" if xaxisscale == "log" else "linear"
+ )
if yaxisscale is not None:
self._plot.getYAxis().setScale(
- 'log' if yaxisscale == 'log' else 'linear')
+ "log" if yaxisscale == "log" else "linear"
+ )
scatterx, scattery = numpy.meshgrid(x_axis, y_axis)
# fixme: i don't think this can handle "irregular" RGBA images
- self._plot.addScatter(numpy.ravel(scatterx),
- numpy.ravel(scattery),
- numpy.ravel(image),
- legend=legend)
+ self._plot.addScatter(
+ numpy.ravel(scatterx),
+ numpy.ravel(scattery),
+ numpy.ravel(image),
+ legend=legend,
+ )
if self.__title:
title = self.__title
if len(self.__signals_names) > 1:
# Append dataset name only when there is many datasets
- title += '\n' + self.__signals_names[auxSigIdx]
+ title += "\n" + self.__signals_names[auxSigIdx]
else:
title = self.__signals_names[auxSigIdx]
self._plot.setGraphTitle(title)
@@ -573,6 +627,7 @@ class ArrayComplexImagePlot(qt.QWidget):
If one or both of the axes does not have regularly spaced values, the
the image is plotted as a coloured scatter plot.
"""
+
def __init__(self, parent=None, colormap=None):
"""
@@ -589,10 +644,12 @@ class ArrayComplexImagePlot(qt.QWidget):
self._plot = ComplexImageView(self)
if colormap is not None:
- for mode in (ComplexImageView.ComplexMode.ABSOLUTE,
- ComplexImageView.ComplexMode.SQUARE_AMPLITUDE,
- ComplexImageView.ComplexMode.REAL,
- ComplexImageView.ComplexMode.IMAGINARY):
+ for mode in (
+ ComplexImageView.ComplexMode.ABSOLUTE,
+ ComplexImageView.ComplexMode.SQUARE_AMPLITUDE,
+ ComplexImageView.ComplexMode.REAL,
+ ComplexImageView.ComplexMode.IMAGINARY,
+ ):
self._plot.setColormap(colormap, mode)
self._plot.getPlot().getIntensityHistogramAction().setVisible(True)
@@ -628,12 +685,17 @@ class ArrayComplexImagePlot(qt.QWidget):
"""
return self._plot.getPlot()
- def setImageData(self, signals,
- x_axis=None, y_axis=None,
- signals_names=None,
- xlabel=None, ylabel=None,
- title=None,
- keep_ratio: bool=True):
+ def setImageData(
+ self,
+ signals,
+ x_axis=None,
+ y_axis=None,
+ signals_names=None,
+ xlabel=None,
+ ylabel=None,
+ title=None,
+ keep_ratio: bool = True,
+ ):
"""
:param signals: list of n-D datasets, whose last 2 dimensions are used as the
@@ -703,7 +765,7 @@ class ArrayComplexImagePlot(qt.QWidget):
x_axis = numpy.arange(image.shape[1])
elif numpy.isscalar(x_axis) or len(x_axis) == 1:
# constant axis
- x_axis = x_axis * numpy.ones((image.shape[1], ))
+ x_axis = x_axis * numpy.ones((image.shape[1],))
elif len(x_axis) == 2:
# linear calibration
x_axis = x_axis[0] * numpy.arange(image.shape[1]) + x_axis[1]
@@ -711,25 +773,31 @@ class ArrayComplexImagePlot(qt.QWidget):
if y_axis is None:
y_axis = numpy.arange(image.shape[0])
elif numpy.isscalar(y_axis) or len(y_axis) == 1:
- y_axis = y_axis * numpy.ones((image.shape[0], ))
+ y_axis = y_axis * numpy.ones((image.shape[0],))
elif len(y_axis) == 2:
y_axis = y_axis[0] * numpy.arange(image.shape[0]) + y_axis[1]
- xcalib = ArrayCalibration(x_axis)
- ycalib = ArrayCalibration(y_axis)
+ try:
+ xcalib = ArrayCalibration(x_axis)
+ except ValueError:
+ xcalib = NoCalibration()
+ try:
+ ycalib = ArrayCalibration(y_axis)
+ except ValueError:
+ ycalib = NoCalibration()
self._plot.setData(image)
if xcalib.is_affine():
xorigin, xscale = xcalib(0), xcalib.get_slope()
else:
_logger.warning("Unsupported complex image X axis calibration")
- xorigin, xscale = 0., 1.
+ xorigin, xscale = 0.0, 1.0
if ycalib.is_affine():
yorigin, yscale = ycalib(0), ycalib.get_slope()
else:
_logger.warning("Unsupported complex image Y axis calibration")
- yorigin, yscale = 0., 1.
+ yorigin, yscale = 0.0, 1.0
self._plot.setOrigin((xorigin, yorigin))
self._plot.setScale((xscale, yscale))
@@ -738,7 +806,7 @@ class ArrayComplexImagePlot(qt.QWidget):
title = self.__title
if len(self.__signals_names) > 1:
# Append dataset name only when there is many datasets
- title += '\n' + self.__signals_names[auxSigIdx]
+ title += "\n" + self.__signals_names[auxSigIdx]
else:
title = self.__signals_names[auxSigIdx]
self._plot.setGraphTitle(title)
@@ -765,6 +833,7 @@ class ArrayStackPlot(qt.QWidget):
the signal array, and the plot is updated to load the stack corresponding
to the selection.
"""
+
def __init__(self, parent=None):
"""
@@ -784,7 +853,9 @@ class ArrayStackPlot(qt.QWidget):
self.__x_axis_name = None
self._stack_view = StackView(self)
- maskToolWidget = self._stack_view.getPlotWidget().getMaskToolsDockWidget().widget()
+ maskToolWidget = (
+ self._stack_view.getPlotWidget().getMaskToolsDockWidget().widget()
+ )
maskToolWidget.setItemMaskUpdated(True)
self._hline = qt.QFrame(self)
@@ -810,11 +881,18 @@ class ArrayStackPlot(qt.QWidget):
"""
return self._stack_view
- def setStackData(self, signal,
- x_axis=None, y_axis=None, z_axis=None,
- signal_name=None,
- xlabel=None, ylabel=None, zlabel=None,
- title=None):
+ def setStackData(
+ self,
+ signal,
+ x_axis=None,
+ y_axis=None,
+ z_axis=None,
+ signal_name=None,
+ xlabel=None,
+ ylabel=None,
+ zlabel=None,
+ title=None,
+ ):
"""
:param signal: n-D dataset, whose last 3 dimensions are used as the
@@ -896,13 +974,12 @@ class ArrayStackPlot(qt.QWidget):
calibrations = []
for axis in [z_axis, y_axis, x_axis]:
-
if axis is None:
calibrations.append(NoCalibration())
elif len(axis) == 2:
calibrations.append(
- LinearCalibration(y_intercept=axis[0],
- slope=axis[1]))
+ LinearCalibration(y_intercept=axis[0], slope=axis[1])
+ )
else:
calibrations.append(ArrayCalibration(axis))
@@ -917,9 +994,8 @@ class ArrayStackPlot(qt.QWidget):
self._stack_view.setStack(stk, calibrations=calibrations)
self._stack_view.setLabels(
- labels=[self.__z_axis_name,
- self.__y_axis_name,
- self.__x_axis_name])
+ labels=[self.__z_axis_name, self.__y_axis_name, self.__x_axis_name]
+ )
def clear(self):
old = self._selector.blockSignals(True)
@@ -941,6 +1017,7 @@ class ArrayVolumePlot(qt.QWidget):
the signal array, and the plot is updated to load the stack corresponding
to the selection.
"""
+
def __init__(self, parent=None):
"""
@@ -986,11 +1063,18 @@ class ArrayVolumePlot(qt.QWidget):
"""
return self._view
- def setData(self, signal,
- x_axis=None, y_axis=None, z_axis=None,
- signal_name=None,
- xlabel=None, ylabel=None, zlabel=None,
- title=None):
+ def setData(
+ self,
+ signal,
+ x_axis=None,
+ y_axis=None,
+ z_axis=None,
+ signal_name=None,
+ xlabel=None,
+ ylabel=None,
+ zlabel=None,
+ title=None,
+ ):
"""
:param signal: n-D dataset, whose last 3 dimensions are used as the
@@ -1056,14 +1140,13 @@ class ArrayVolumePlot(qt.QWidget):
if axis is None:
calibration = NoCalibration()
elif len(axis) == 2:
- calibration = LinearCalibration(
- y_intercept=axis[0], slope=axis[1])
+ calibration = LinearCalibration(y_intercept=axis[0], slope=axis[1])
else:
calibration = ArrayCalibration(axis)
if not calibration.is_affine():
_logger.warning("Axis has not linear values, ignored")
- offset.append(0.)
- scale.append(1.)
+ offset.append(0.0)
+ scale.append(1.0)
else:
offset.append(calibration(0))
scale.append(calibration.get_slope())
@@ -1083,7 +1166,8 @@ class ArrayVolumePlot(qt.QWidget):
volumeView = self.getVolumeView()
volumeView.setData(data, offset=offset, scale=scale)
volumeView.setAxesLabels(
- self.__x_axis_name, self.__y_axis_name, self.__z_axis_name)
+ self.__x_axis_name, self.__y_axis_name, self.__z_axis_name
+ )
def clear(self):
old = self._selector.blockSignals(True)
diff --git a/src/silx/gui/data/NumpyAxesSelector.py b/src/silx/gui/data/NumpyAxesSelector.py
index 50b8dcd..9b62c29 100644
--- a/src/silx/gui/data/NumpyAxesSelector.py
+++ b/src/silx/gui/data/NumpyAxesSelector.py
@@ -268,8 +268,9 @@ class NumpyAxesSelector(qt.QWidget):
:param List[str] axesNames: List of distinct strings identifying axis names
"""
self.__axisNames = list(axesNames)
- assert len(set(self.__axisNames)) == len(self.__axisNames),\
+ assert len(set(self.__axisNames)) == len(self.__axisNames), (
"Non-unique axes names: %s" % self.__axisNames
+ )
delta = len(self.__axis) - len(self.__axisNames)
if delta < 0:
@@ -318,10 +319,14 @@ class NumpyAxesSelector(qt.QWidget):
if index >= delta and index - delta < len(self.__axisNames):
axis.setAxisName(self.__axisNames[index - delta])
# this weak method was expected to be able to delete sub widget
- callback = functools.partial(silx.utils.weakref.WeakMethodProxy(self.__axisValueChanged), axis)
+ callback = functools.partial(
+ silx.utils.weakref.WeakMethodProxy(self.__axisValueChanged), axis
+ )
axis.valueChanged.connect(callback)
# this weak method was expected to be able to delete sub widget
- callback = functools.partial(silx.utils.weakref.WeakMethodProxy(self.__axisNameChanged), axis)
+ callback = functools.partial(
+ silx.utils.weakref.WeakMethodProxy(self.__axisNameChanged), axis
+ )
axis.axisNameChanged.connect(callback)
axis.setNamedAxisSelectorVisibility(self.__namedAxesVisibility)
self.layout().addWidget(axis)
@@ -335,8 +340,12 @@ class NumpyAxesSelector(qt.QWidget):
"""Update axes geometry to align all axes components together."""
if len(self.__axis) <= 0:
return
- lineEditWidth = max([a.slider().lineEdit().minimumSize().width() for a in self.__axis])
- limitWidth = max([a.slider().limitWidget().minimumSizeHint().width() for a in self.__axis])
+ lineEditWidth = max(
+ [a.slider().lineEdit().minimumSize().width() for a in self.__axis]
+ )
+ limitWidth = max(
+ [a.slider().limitWidget().minimumSizeHint().width() for a in self.__axis]
+ )
for a in self.__axis:
a.slider().lineEdit().setFixedWidth(lineEditWidth)
a.slider().limitWidget().setFixedWidth(limitWidth)
@@ -418,7 +427,9 @@ class NumpyAxesSelector(qt.QWidget):
# get a view with few fixed dimensions
# with a h5py dataset, it create a copy
# TODO we can reuse the same memory in case of a copy
- self.__selectedData = numpy.transpose(self.__data[self.selection()], permutation)
+ self.__selectedData = numpy.transpose(
+ self.__data[self.selection()], permutation
+ )
self.selectionChanged.emit()
def data(self):
@@ -477,8 +488,12 @@ class NumpyAxesSelector(qt.QWidget):
if self.__data is None:
return tuple()
else:
- return tuple([axis.value() if axis.axisName() == "" else slice(None)
- for axis in self.__axis])
+ return tuple(
+ [
+ axis.value() if axis.axisName() == "" else slice(None)
+ for axis in self.__axis
+ ]
+ )
def setSelection(self, selection, permutation=None):
"""Set the selection along each dimension.
@@ -501,8 +516,9 @@ class NumpyAxesSelector(qt.QWidget):
# Check selection
if len(selection) != len(data_shape):
raise ValueError(
- "Selection length (%d) and data ndim (%d) mismatch" %
- (len(selection), len(data_shape)))
+ "Selection length (%d) and data ndim (%d) mismatch"
+ % (len(selection), len(data_shape))
+ )
# Check selection type
selectedDataNDim = 0
@@ -510,8 +526,9 @@ class NumpyAxesSelector(qt.QWidget):
if isinstance(element, int):
if not 0 <= element < size:
raise ValueError(
- "Selected index (%d) outside data dimension range [0-%d]" %
- (element, size))
+ "Selected index (%d) outside data dimension range [0-%d]"
+ % (element, size)
+ )
elif element is None or element == slice(None):
selectedDataNDim += 1
else:
@@ -520,8 +537,9 @@ class NumpyAxesSelector(qt.QWidget):
ndim = len(self.__axisNames)
if selectedDataNDim != ndim:
raise ValueError(
- "Selection dimensions (%d) and number of axes (%d) mismatch" %
- (selectedDataNDim, ndim))
+ "Selection dimensions (%d) and number of axes (%d) mismatch"
+ % (selectedDataNDim, ndim)
+ )
# check permutation
if permutation is None:
@@ -530,7 +548,8 @@ class NumpyAxesSelector(qt.QWidget):
if set(permutation) != set(range(ndim)):
raise ValueError(
"Error in provided permutation: "
- "Wrong size, elements out of range or duplicates")
+ "Wrong size, elements out of range or duplicates"
+ )
inversePermutation = numpy.argsort(permutation)
diff --git a/src/silx/gui/data/RecordTableView.py b/src/silx/gui/data/RecordTableView.py
index 9079ba6..8bf1683 100644
--- a/src/silx/gui/data/RecordTableView.py
+++ b/src/silx/gui/data/RecordTableView.py
@@ -53,8 +53,9 @@ class _MultiLineItem(qt.QItemDelegate):
"""
qt.QItemDelegate.__init__(self, parent)
self.__textOptions = qt.QTextOption()
- self.__textOptions.setFlags(qt.QTextOption.IncludeTrailingSpaces |
- qt.QTextOption.ShowTabsAndSpaces)
+ self.__textOptions.setFlags(
+ qt.QTextOption.IncludeTrailingSpaces | qt.QTextOption.ShowTabsAndSpaces
+ )
self.__textOptions.setWrapMode(qt.QTextOption.NoWrap)
self.__textOptions.setAlignment(qt.Qt.AlignTop | qt.Qt.AlignLeft)
@@ -148,7 +149,7 @@ class RecordTableModel(qt.QAbstractTableModel):
:param numpy.ndarray data: A numpy array or a h5py dataset
"""
- MAX_NUMBER_OF_ROWS = 10e6
+ MAX_NUMBER_OF_ROWS = int(10e6)
"""Maximum number of display values of the dataset"""
def __init__(self, parent=None, data=None):
@@ -242,9 +243,11 @@ class RecordTableModel(qt.QAbstractTableModel):
return None
# Handle clipping of huge tables
- if (self.__isClipped() and
- orientation == qt.Qt.Vertical and
- section == self.rowCount() - 2):
+ if (
+ self.__isClipped()
+ and orientation == qt.Qt.Vertical
+ and section == self.rowCount() - 2
+ ):
return self.__clippedData(role)
if role == qt.Qt.DisplayRole:
@@ -276,7 +279,11 @@ class RecordTableModel(qt.QAbstractTableModel):
def __isClipped(self) -> bool:
"""Returns whether the displayed array is clipped or not"""
- return self.__data is not None and self.__is_array and len(self.__data) > self.MAX_NUMBER_OF_ROWS
+ return (
+ self.__data is not None
+ and self.__is_array
+ and len(self.__data) > self.MAX_NUMBER_OF_ROWS
+ )
def setArrayData(self, data):
"""Set the data array and the viewing perspective.
@@ -359,8 +366,7 @@ class RecordTableModel(qt.QAbstractTableModel):
return self.__formatter
def __formatChanged(self):
- """Called when the format changed.
- """
+ """Called when the format changed."""
self.__editFormatter = TextFormatter(self, self.getFormatter())
self.__editFormatter.setUseQuoteForText(False)
self.reset()
@@ -398,8 +404,8 @@ class _ShowEditorProxyModel(qt.QIdentityProxyModel):
class RecordTableView(qt.QTableView):
- """TableView using DatabaseTableModel as default model.
- """
+ """TableView using DatabaseTableModel as default model."""
+
def __init__(self, parent=None):
"""
Constructor
diff --git a/src/silx/gui/data/TextFormatter.py b/src/silx/gui/data/TextFormatter.py
index d409381..aee2427 100644
--- a/src/silx/gui/data/TextFormatter.py
+++ b/src/silx/gui/data/TextFormatter.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -83,8 +83,8 @@ class TextFormatter(qt.QObject):
self.__integerFormat = "%d"
self.__floatFormat = "%g"
self.__useQuoteForText = True
- self.__imaginaryUnit = u"j"
- self.__enumFormat = u"%(name)s(%(value)d)"
+ self.__imaginaryUnit = "j"
+ self.__enumFormat = "%(name)s(%(value)d)"
def integerFormat(self):
"""Returns the format string controlling how the integer data
@@ -195,7 +195,7 @@ class TextFormatter(qt.QObject):
def __formatText(self, text):
if self.__useQuoteForText:
- text = "\"%s\"" % text.replace("\\", "\\\\").replace("\"", "\\\"")
+ text = '"%s"' % text.replace("\\", "\\\\").replace('"', '\\"')
return text
def __formatBinary(self, data):
@@ -209,7 +209,7 @@ class TextFormatter(qt.QObject):
pass
data = ["\\x%02X" % d for d in data]
if self.__useQuoteForText:
- return "b\"%s\"" % "".join(data)
+ return 'b"%s"' % "".join(data)
else:
return "".join(data)
@@ -217,7 +217,7 @@ class TextFormatter(qt.QObject):
data = [chr(d) if (d > 0x20 and d < 0x7F) else "\\x%02X" % d for d in data]
if self.__useQuoteForText:
data = [c if c != '"' else "\\" + c for c in data]
- return "b\"%s\"" % "".join(data)
+ return 'b"%s"' % "".join(data)
else:
return "".join(data)
@@ -233,6 +233,8 @@ class TextFormatter(qt.QObject):
:param data: A binary string of char expected in ASCII
:rtype: str
"""
+ if isinstance(data, str):
+ return self.__formatText(data)
try:
text = "%s" % data.decode("ascii")
return self.__formatText(text)
@@ -242,7 +244,7 @@ class TextFormatter(qt.QObject):
_logger.error("Invalid ASCII string %s.", data)
if data == b"\xB0":
_logger.error("Fallback using cp1252 encoding")
- return self.__formatText(u"\u00B0")
+ return self.__formatText("\u00B0")
return self.__formatSafeAscii(data)
def __formatH5pyObject(self, data, dtype):
@@ -294,7 +296,7 @@ class TextFormatter(qt.QObject):
else:
text = [self.toString(d, dtype) for d in data]
return "[" + " ".join(text) + "]"
- if dtype is not None and dtype.kind == 'O':
+ if dtype is not None and dtype.kind == "O":
text = self.__formatH5pyObject(data, dtype)
if text is not None:
return text
@@ -304,7 +306,9 @@ class TextFormatter(qt.QObject):
if dtype.fields is not None:
text = []
for index, field in enumerate(dtype.fields.items()):
- text.append(field[0] + ":" + self.toString(data[index], field[1][0]))
+ text.append(
+ field[0] + ":" + self.toString(data[index], field[1][0])
+ )
return "(" + " ".join(text) + ")"
return self.__formatBinary(data)
elif isinstance(data, (numpy.unicode_, str)):
@@ -314,9 +318,9 @@ class TextFormatter(qt.QObject):
dtype = data.dtype
if dtype is not None:
# Maybe a sub item from HDF5
- if dtype.kind == 'S':
+ if dtype.kind == "S":
return self.__formatCharString(data)
- elif dtype.kind == 'O':
+ elif dtype.kind == "O":
text = self.__formatH5pyObject(data, dtype)
if text is not None:
return text
@@ -353,18 +357,28 @@ class TextFormatter(qt.QObject):
text += self.__floatFormat % data.real
if data.real != 0 and data.imag != 0:
if data.imag < 0:
- template = self.__floatFormat + " - " + self.__floatFormat + self.__imaginaryUnit
+ template = (
+ self.__floatFormat
+ + " - "
+ + self.__floatFormat
+ + self.__imaginaryUnit
+ )
params = (data.real, -data.imag)
else:
- template = self.__floatFormat + " + " + self.__floatFormat + self.__imaginaryUnit
+ template = (
+ self.__floatFormat
+ + " + "
+ + self.__floatFormat
+ + self.__imaginaryUnit
+ )
params = (data.real, data.imag)
else:
if data.imag != 0:
template = self.__floatFormat + self.__imaginaryUnit
- params = (data.imag)
+ params = data.imag
else:
template = self.__floatFormat
- params = (data.real)
+ params = data.real
return template % params
elif isinstance(data, h5py.h5r.Reference):
dtype = h5py.special_dtype(ref=h5py.Reference)
diff --git a/src/silx/gui/data/_RecordPlot.py b/src/silx/gui/data/_RecordPlot.py
index 5be792f..b994a6e 100644
--- a/src/silx/gui/data/_RecordPlot.py
+++ b/src/silx/gui/data/_RecordPlot.py
@@ -5,16 +5,28 @@ from .. import qt
class RecordPlot(PlotWindow):
def __init__(self, parent=None, backend=None):
- super(RecordPlot, self).__init__(parent=parent, backend=backend,
- resetzoom=True, autoScale=True,
- logScale=True, grid=True,
- curveStyle=True, colormap=False,
- aspectRatio=False, yInverted=False,
- copy=True, save=True, print_=True,
- control=True, position=True,
- roi=True, mask=False, fit=True)
+ super(RecordPlot, self).__init__(
+ parent=parent,
+ backend=backend,
+ resetzoom=True,
+ autoScale=True,
+ logScale=True,
+ grid=True,
+ curveStyle=True,
+ colormap=False,
+ aspectRatio=False,
+ yInverted=False,
+ copy=True,
+ save=True,
+ print_=True,
+ control=True,
+ position=True,
+ roi=True,
+ mask=False,
+ fit=True,
+ )
if parent is None:
- self.setWindowTitle('RecordPlot')
+ self.setWindowTitle("RecordPlot")
self._axesSelectionToolBar = AxesSelectionToolBar(parent=self, plot=self)
self.addToolBar(qt.Qt.BottomToolBarArea, self._axesSelectionToolBar)
@@ -23,7 +35,7 @@ class RecordPlot(PlotWindow):
:param Union[str,None] value:
"""
- label = '' if value is None else value
+ label = "" if value is None else value
index = self._axesSelectionToolBar.getXAxisDropDown().findData(value)
if index >= 0:
@@ -53,7 +65,7 @@ class RecordPlot(PlotWindow):
"""
comboBox = self._axesSelectionToolBar.getXAxisDropDown()
comboBox.clear()
- comboBox.addItem('-', None)
+ comboBox.addItem("-", None)
comboBox.insertSeparator(1)
for name in fieldNames:
comboBox.addItem(name, name)
@@ -65,8 +77,9 @@ class RecordPlot(PlotWindow):
def getAxesSelectionToolBar(self):
return self._axesSelectionToolBar
+
class AxesSelectionToolBar(qt.QToolBar):
- def __init__(self, parent=None, plot=None, title='Plot Axes Selection'):
+ def __init__(self, parent=None, plot=None, title="Plot Axes Selection"):
super(AxesSelectionToolBar, self).__init__(title, parent)
assert isinstance(plot, PlotWidget)
@@ -89,4 +102,4 @@ class AxesSelectionToolBar(qt.QToolBar):
return self._selectXAxisDropDown
def getYAxisDropDown(self):
- return self._selectYAxisDropDown \ No newline at end of file
+ return self._selectYAxisDropDown
diff --git a/src/silx/gui/data/_VolumeWindow.py b/src/silx/gui/data/_VolumeWindow.py
index fa2730c..49b18d5 100644
--- a/src/silx/gui/data/_VolumeWindow.py
+++ b/src/silx/gui/data/_VolumeWindow.py
@@ -56,16 +56,16 @@ class VolumeWindow(SceneWindow):
"""
sceneWidget = self.getSceneWidget()
sceneWidget.getSceneGroup().setAxesLabels(
- 'X' if xlabel is None else xlabel,
- 'Y' if ylabel is None else ylabel,
- 'Z' if zlabel is None else zlabel)
+ "X" if xlabel is None else xlabel,
+ "Y" if ylabel is None else ylabel,
+ "Z" if zlabel is None else zlabel,
+ )
def clear(self):
"""Clear any currently displayed data"""
sceneWidget = self.getSceneWidget()
items = sceneWidget.getItems()
- if (len(items) == 1 and
- isinstance(items[0], (ScalarField3D, ComplexField3D))):
+ if len(items) == 1 and isinstance(items[0], (ScalarField3D, ComplexField3D)):
items[0].setData(None)
else: # Safety net
sceneWidget.clearItems()
@@ -83,7 +83,7 @@ class VolumeWindow(SceneWindow):
else:
return numpy.mean(data) + numpy.std(data)
- def setData(self, data, offset=(0., 0., 0.), scale=(1., 1., 1.)):
+ def setData(self, data, offset=(0.0, 0.0, 0.0), scale=(1.0, 1.0, 1.0)):
"""Set the 3D array data to display.
:param numpy.ndarray data: 3D array of float or complex
@@ -94,9 +94,11 @@ class VolumeWindow(SceneWindow):
dataMaxCoords = numpy.array(list(reversed(data.shape))) - 1
previousItems = sceneWidget.getItems()
- if (len(previousItems) == 1 and
- isinstance(previousItems[0], (ScalarField3D, ComplexField3D)) and
- numpy.iscomplexobj(data) == isinstance(previousItems[0], ComplexField3D)):
+ if (
+ len(previousItems) == 1
+ and isinstance(previousItems[0], (ScalarField3D, ComplexField3D))
+ and numpy.iscomplexobj(data) == isinstance(previousItems[0], ComplexField3D)
+ ):
# Reuse existing volume item
volume = sceneWidget.getItems()[0]
volume.setData(data, copy=False)
@@ -109,13 +111,13 @@ class VolumeWindow(SceneWindow):
# Add a new volume
sceneWidget.clearItems()
volume = sceneWidget.addVolume(data, copy=False)
- volume.setLabel('Volume')
+ volume.setLabel("Volume")
for plane in volume.getCutPlanes():
# Make plane going through the center of the data
plane.setPoint(dataMaxCoords // 2)
plane.setVisible(False)
plane.sigItemChanged.connect(self.__cutPlaneUpdated)
- volume.addIsosurface(self.__computeIsolevel, '#FF0000FF')
+ volume.addIsosurface(self.__computeIsolevel, "#FF0000FF")
# Expand the parameter tree
model = self.getParamTreeView().model()
diff --git a/src/silx/gui/data/test/test_arraywidget.py b/src/silx/gui/data/test/test_arraywidget.py
index 024383d..faca333 100644
--- a/src/silx/gui/data/test/test_arraywidget.py
+++ b/src/silx/gui/data/test/test_arraywidget.py
@@ -27,7 +27,6 @@ __date__ = "05/12/2016"
import os
import tempfile
-import unittest
import numpy
@@ -41,6 +40,7 @@ import h5py
class TestArrayWidget(TestCaseQt):
"""Basic test for ArrayTableWidget with a numpy array"""
+
def setUp(self):
super(TestArrayWidget, self).setUp()
self.aw = ArrayTableWidget.ArrayTableWidget()
@@ -79,16 +79,13 @@ class TestArrayWidget(TestCaseQt):
self.assertEqual(len(self.aw.model._perspective), 0)
def testSetData4D(self):
- a = numpy.reshape(numpy.linspace(0.213, 1.234, 1250),
- (5, 5, 5, 10))
+ a = numpy.reshape(numpy.linspace(0.213, 1.234, 1250), (5, 5, 5, 10))
self.aw.setArrayData(a)
# default perspective (0, 1)
- self.assertEqual(list(self.aw.model._perspective),
- [0, 1])
+ self.assertEqual(list(self.aw.model._perspective), [0, 1])
self.aw.setPerspective((1, 3))
- self.assertEqual(list(self.aw.model._perspective),
- [1, 3])
+ self.assertEqual(list(self.aw.model._perspective), [1, 3])
b = self.aw.getData(copy=True)
self.assertTrue(numpy.array_equal(a, b))
@@ -96,12 +93,10 @@ class TestArrayWidget(TestCaseQt):
# 4D data has a 2-tuple as frame index
self.assertEqual(len(self.aw.model._index), 2)
# default index is (0, 0)
- self.assertEqual(list(self.aw.model._index),
- [0, 0])
+ self.assertEqual(list(self.aw.model._index), [0, 0])
self.aw.setFrameIndex((3, 1))
- self.assertEqual(list(self.aw.model._index),
- [3, 1])
+ self.assertEqual(list(self.aw.model._index), [3, 1])
def testColors(self):
a = numpy.arange(256, dtype=numpy.uint8)
@@ -121,18 +116,20 @@ class TestArrayWidget(TestCaseQt):
for i in range(256):
# all RGB channels for BG equal to data value
self.assertEqual(
- self.aw.model.data(self.aw.model.index(0, i),
- role=qt.Qt.BackgroundRole),
+ self.aw.model.data(
+ self.aw.model.index(0, i), role=qt.Qt.BackgroundRole
+ ),
qt.QColor(i, i, i),
- "Unexpected background color"
+ "Unexpected background color",
)
# all RGB channels for FG equal to XOR(data value, 255)
self.assertEqual(
- self.aw.model.data(self.aw.model.index(0, i),
- role=qt.Qt.ForegroundRole),
+ self.aw.model.data(
+ self.aw.model.index(0, i), role=qt.Qt.ForegroundRole
+ ),
qt.QColor(i ^ 255, i ^ 255, i ^ 255),
- "Unexpected text color"
+ "Unexpected text color",
)
# test colors are reset to None when a new data array is loaded
@@ -142,30 +139,27 @@ class TestArrayWidget(TestCaseQt):
for i in range(300):
# all RGB channels for BG equal to data value
self.assertIsNone(
- self.aw.model.data(self.aw.model.index(0, i),
- role=qt.Qt.BackgroundRole))
+ self.aw.model.data(self.aw.model.index(0, i), role=qt.Qt.BackgroundRole)
+ )
def testDefaultFlagNotEditable(self):
"""editable should be False by default, in setArrayData"""
self.aw.setArrayData([[0]])
idx = self.aw.model.createIndex(0, 0)
# model is editable
- self.assertFalse(
- self.aw.model.flags(idx) & qt.Qt.ItemIsEditable)
+ self.assertFalse(self.aw.model.flags(idx) & qt.Qt.ItemIsEditable)
def testFlagEditable(self):
self.aw.setArrayData([[0]], editable=True)
idx = self.aw.model.createIndex(0, 0)
# model is editable
- self.assertTrue(
- self.aw.model.flags(idx) & qt.Qt.ItemIsEditable)
+ self.assertTrue(self.aw.model.flags(idx) & qt.Qt.ItemIsEditable)
def testFlagNotEditable(self):
self.aw.setArrayData([[0]], editable=False)
idx = self.aw.model.createIndex(0, 0)
# model is editable
- self.assertFalse(
- self.aw.model.flags(idx) & qt.Qt.ItemIsEditable)
+ self.assertFalse(self.aw.model.flags(idx) & qt.Qt.ItemIsEditable)
def testReferenceReturned(self):
"""when setting the data with copy=False and
@@ -173,8 +167,7 @@ class TestArrayWidget(TestCaseQt):
the same original object.
"""
# n-D (n >=2)
- a0 = numpy.reshape(numpy.linspace(0.213, 1.234, 1000),
- (10, 10, 10))
+ a0 = numpy.reshape(numpy.linspace(0.213, 1.234, 1000), (10, 10, 10))
self.aw.setArrayData(a0, copy=False)
a1 = self.aw.getData(copy=False)
@@ -203,15 +196,15 @@ class TestH5pyArrayWidget(TestCaseQt):
"""Basic test for ArrayTableWidget with a dataset.
Test flags, for dataset open in read-only or read-write modes"""
+
def setUp(self):
super(TestH5pyArrayWidget, self).setUp()
self.aw = ArrayTableWidget.ArrayTableWidget()
- self.data = numpy.reshape(numpy.linspace(0.213, 1.234, 1000),
- (10, 10, 10))
+ self.data = numpy.reshape(numpy.linspace(0.213, 1.234, 1000), (10, 10, 10))
# create an h5py file with a dataset
self.tempdir = tempfile.mkdtemp()
self.h5_fname = os.path.join(self.tempdir, "array.h5")
- h5f = h5py.File(self.h5_fname, mode='w')
+ h5f = h5py.File(self.h5_fname, mode="w")
h5f["my_array"] = self.data
h5f["my_scalar"] = 3.14
h5f["my_1D_array"] = numpy.array(numpy.arange(1000))
@@ -236,7 +229,7 @@ class TestH5pyArrayWidget(TestCaseQt):
self.aw.setArrayData(a, copy=False, editable=True)
- self.assertIsInstance(a, h5py.Dataset) # simple sanity check
+ self.assertIsInstance(a, h5py.Dataset) # simple sanity check
# internal representation must be a reference to original data (copy=False)
self.assertIsInstance(self.aw.model._array, h5py.Dataset)
self.assertTrue(self.aw.model._array.file.mode == "r")
@@ -247,12 +240,12 @@ class TestH5pyArrayWidget(TestCaseQt):
# model must have detected read-only dataset and disabled editing
self.assertFalse(self.aw.model._editable)
idx = self.aw.model.createIndex(0, 0)
- self.assertFalse(
- self.aw.model.flags(idx) & qt.Qt.ItemIsEditable)
+ self.assertFalse(self.aw.model.flags(idx) & qt.Qt.ItemIsEditable)
# force editing read-only datasets raises IOError
- self.assertRaises(IOError, self.aw.model.setData,
- idx, 123.4, role=qt.Qt.EditRole)
+ self.assertRaises(
+ IOError, self.aw.model.setData, idx, 123.4, role=qt.Qt.EditRole
+ )
h5f.close()
def testReadWrite(self):
@@ -266,8 +259,7 @@ class TestH5pyArrayWidget(TestCaseQt):
idx = self.aw.model.createIndex(0, 0)
# model is editable
- self.assertTrue(
- self.aw.model.flags(idx) & qt.Qt.ItemIsEditable)
+ self.assertTrue(self.aw.model.flags(idx) & qt.Qt.ItemIsEditable)
h5f.close()
def testSetData0D(self):
diff --git a/src/silx/gui/data/test/test_dataviewer.py b/src/silx/gui/data/test/test_dataviewer.py
index 80f47b7..85bbf7a 100644
--- a/src/silx/gui/data/test/test_dataviewer.py
+++ b/src/silx/gui/data/test/test_dataviewer.py
@@ -90,7 +90,7 @@ class _TestAbstractDataViewer(TestCaseQt):
self.assertEqual(DataViews.RAW_MODE, widget.displayMode())
def test_plot_1d_data(self):
- data = numpy.arange(3 ** 1)
+ data = numpy.arange(3**1)
data.shape = [3] * 1
widget = self.create_widget()
widget.setData(data)
@@ -99,7 +99,7 @@ class _TestAbstractDataViewer(TestCaseQt):
self.assertIn(DataViews.PLOT1D_MODE, availableModes)
def test_image_data(self):
- data = numpy.arange(3 ** 2)
+ data = numpy.arange(3**2)
data.shape = [3] * 2
widget = self.create_widget()
widget.setData(data)
@@ -117,7 +117,7 @@ class _TestAbstractDataViewer(TestCaseQt):
self.assertIn(DataViews.IMAGE_MODE, availableModes)
def test_image_complex_data(self):
- data = numpy.arange(3 ** 2, dtype=numpy.complex64)
+ data = numpy.arange(3**2, dtype=numpy.complex64)
data.shape = [3] * 2
widget = self.create_widget()
widget.setData(data)
@@ -126,41 +126,42 @@ class _TestAbstractDataViewer(TestCaseQt):
self.assertIn(DataViews.IMAGE_MODE, availableModes)
def test_plot_3d_data(self):
- data = numpy.arange(3 ** 3)
+ data = numpy.arange(3**3)
data.shape = [3] * 3
widget = self.create_widget()
widget.setData(data)
availableModes = set([v.modeId() for v in widget.currentAvailableViews()])
try:
import silx.gui.plot3d # noqa
+
self.assertIn(DataViews.PLOT3D_MODE, availableModes)
except ImportError:
self.assertIn(DataViews.STACK_MODE, availableModes)
self.assertEqual(DataViews.RAW_MODE, widget.displayMode())
def test_array_1d_data(self):
- data = numpy.array(["aaa"] * (3 ** 1))
+ data = numpy.array(["aaa"] * (3**1))
data.shape = [3] * 1
widget = self.create_widget()
widget.setData(data)
self.assertEqual(DataViews.RAW_MODE, widget.displayedView().modeId())
def test_array_2d_data(self):
- data = numpy.array(["aaa"] * (3 ** 2))
+ data = numpy.array(["aaa"] * (3**2))
data.shape = [3] * 2
widget = self.create_widget()
widget.setData(data)
self.assertEqual(DataViews.RAW_MODE, widget.displayedView().modeId())
def test_array_4d_data(self):
- data = numpy.array(["aaa"] * (3 ** 4))
+ data = numpy.array(["aaa"] * (3**4))
data.shape = [3] * 4
widget = self.create_widget()
widget.setData(data)
self.assertEqual(DataViews.RAW_MODE, widget.displayedView().modeId())
def test_record_4d_data(self):
- data = numpy.zeros(3 ** 4, dtype='3int8, float32, (2,3)float64')
+ data = numpy.zeros(3**4, dtype="3int8, float32, (2,3)float64")
data.shape = [3] * 4
widget = self.create_widget()
widget.setData(data)
@@ -192,7 +193,7 @@ class _TestAbstractDataViewer(TestCaseQt):
def test_change_display_mode(self):
listener = SignalListener()
- data = numpy.arange(10 ** 4)
+ data = numpy.arange(10**4)
data.shape = [10] * 4
widget = self.create_widget()
widget.selectionChanged.connect(listener)
@@ -245,8 +246,7 @@ class _TestAbstractDataViewer(TestCaseQt):
def test_replace_view(self):
widget = self.create_widget()
view = _DataViewMock(widget)
- widget.replaceView(DataViews.RAW_MODE,
- view)
+ widget.replaceView(DataViews.RAW_MODE, view)
self.assertIsNone(widget.getViewFromModeId(DataViews.RAW_MODE))
self.assertTrue(view in widget.availableViews())
self.assertTrue(view in widget.currentAvailableViews())
@@ -255,29 +255,30 @@ class _TestAbstractDataViewer(TestCaseQt):
# replace a view that is a child of a composite view
widget = self.create_widget()
view = _DataViewMock(widget)
- replaced = widget.replaceView(DataViews.NXDATA_INVALID_MODE,
- view)
+ replaced = widget.replaceView(DataViews.NXDATA_INVALID_MODE, view)
self.assertTrue(replaced)
nxdata_view = widget.getViewFromModeId(DataViews.NXDATA_MODE)
- self.assertNotIn(DataViews.NXDATA_INVALID_MODE,
- [v.modeId() for v in nxdata_view.getViews()])
+ self.assertNotIn(
+ DataViews.NXDATA_INVALID_MODE, [v.modeId() for v in nxdata_view.getViews()]
+ )
self.assertTrue(view in nxdata_view.getViews())
class TestDataViewer(_TestAbstractDataViewer):
__test__ = True # because _TestAbstractDataViewer is ignored
+
def create_widget(self):
return DataViewer()
class TestDataViewerFrame(_TestAbstractDataViewer):
__test__ = True # because _TestAbstractDataViewer is ignored
+
def create_widget(self):
return DataViewerFrame()
class TestDataView(TestCaseQt):
-
def createComplexData(self):
line = [1, 2j, 3 + 3j, 4]
image = [line, line, line, line]
diff --git a/src/silx/gui/data/test/test_numpyaxesselector.py b/src/silx/gui/data/test/test_numpyaxesselector.py
index 4a53149..450b89d 100644
--- a/src/silx/gui/data/test/test_numpyaxesselector.py
+++ b/src/silx/gui/data/test/test_numpyaxesselector.py
@@ -27,7 +27,6 @@ __date__ = "29/01/2018"
import os
import tempfile
-import unittest
from contextlib import contextmanager
import numpy
@@ -40,7 +39,6 @@ import h5py
class TestNumpyAxesSelector(TestCaseQt):
-
def test_creation(self):
data = numpy.arange(3 * 3 * 3)
data.shape = 3, 3, 3
diff --git a/src/silx/gui/data/test/test_textformatter.py b/src/silx/gui/data/test/test_textformatter.py
index b82cc7a..49b8283 100644
--- a/src/silx/gui/data/test/test_textformatter.py
+++ b/src/silx/gui/data/test/test_textformatter.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -25,7 +25,6 @@ __authors__ = ["V. Valls"]
__license__ = "MIT"
__date__ = "12/12/2017"
-import unittest
import shutil
import tempfile
@@ -34,13 +33,12 @@ import numpy
from silx.gui.utils.testutils import TestCaseQt
from silx.gui.utils.testutils import SignalListener
from ..TextFormatter import TextFormatter
-from silx.io.utils import h5py_read_dataset
import h5py
+import pytest
class TestTextFormatter(TestCaseQt):
-
def test_copy(self):
formatter = TextFormatter()
copy = TextFormatter(formatter=formatter)
@@ -97,11 +95,10 @@ class TestTextFormatter(TestCaseQt):
# degree character in cp1252
formatter = TextFormatter()
result = formatter.toString(numpy.bytes_(b"\xB0"))
- self.assertEqual(result, u'"\u00B0"')
+ self.assertEqual(result, '"\u00B0"')
class TestTextFormatterWithH5py(TestCaseQt):
-
@classmethod
def setUpClass(cls):
super(TestTextFormatterWithH5py, cls).setUpClass()
@@ -131,10 +128,10 @@ class TestTextFormatterWithH5py(TestCaseQt):
self.assertEqual(result, '"abc"')
def testUnicode(self):
- d = self.create_dataset(data=u"i\u2661cookies")
+ d = self.create_dataset(data="i\u2661cookies")
result = self.read_dataset(d)
self.assertEqual(len(result), 11)
- self.assertEqual(result, u'"i\u2661cookies"')
+ self.assertEqual(result, '"i\u2661cookies"')
def testBadAscii(self):
d = self.create_dataset(data=b"\xF0\x9F\x92\x94")
@@ -147,18 +144,18 @@ class TestTextFormatterWithH5py(TestCaseQt):
self.assertEqual(result, 'b"\\x61\\x62\\x63\\xF0"')
def testEnum(self):
- dtype = h5py.special_dtype(enum=('i', {"RED": 0, "GREEN": 1, "BLUE": 42}))
+ dtype = h5py.special_dtype(enum=("i", {"RED": 0, "GREEN": 1, "BLUE": 42}))
d = numpy.array(42, dtype=dtype)
d = self.create_dataset(data=d)
result = self.read_dataset(d)
- self.assertEqual(result, 'BLUE(42)')
+ self.assertEqual(result, "BLUE(42)")
def testRef(self):
dtype = h5py.special_dtype(ref=h5py.Reference)
d = numpy.array(self.h5File.ref, dtype=dtype)
d = self.create_dataset(data=d)
result = self.read_dataset(d)
- self.assertEqual(result, 'REF')
+ self.assertEqual(result, "REF")
def testArrayAscii(self):
d = self.create_dataset(data=[b"abc"])
@@ -167,11 +164,11 @@ class TestTextFormatterWithH5py(TestCaseQt):
def testArrayUnicode(self):
dtype = h5py.special_dtype(vlen=str)
- d = numpy.array([u"i\u2661cookies"], dtype=dtype)
+ d = numpy.array(["i\u2661cookies"], dtype=dtype)
d = self.create_dataset(data=d)
result = self.read_dataset(d)
self.assertEqual(len(result), 13)
- self.assertEqual(result, u'["i\u2661cookies"]')
+ self.assertEqual(result, '["i\u2661cookies"]')
def testArrayBadAscii(self):
d = self.create_dataset(data=[b"\xF0\x9F\x92\x94"])
@@ -184,15 +181,32 @@ class TestTextFormatterWithH5py(TestCaseQt):
self.assertEqual(result, '[b"\\x61\\x62\\x63\\xF0"]')
def testArrayEnum(self):
- dtype = h5py.special_dtype(enum=('i', {"RED": 0, "GREEN": 1, "BLUE": 42}))
+ dtype = h5py.special_dtype(enum=("i", {"RED": 0, "GREEN": 1, "BLUE": 42}))
d = numpy.array([42, 1, 100], dtype=dtype)
d = self.create_dataset(data=d)
result = self.read_dataset(d)
- self.assertEqual(result, '[BLUE(42) GREEN(1) 100]')
+ self.assertEqual(result, "[BLUE(42) GREEN(1) 100]")
def testArrayRef(self):
dtype = h5py.special_dtype(ref=h5py.Reference)
d = numpy.array([self.h5File.ref, None], dtype=dtype)
d = self.create_dataset(data=d)
result = self.read_dataset(d)
- self.assertEqual(result, '[REF NULL_REF]')
+ self.assertEqual(result, "[REF NULL_REF]")
+
+
+@pytest.mark.parametrize(
+ "data, expected",
+ [
+ (b"bytes", '"bytes"'),
+ ("unicode", '"unicode"'),
+ ((b"elem0", b"elem1"), '["elem0" "elem1"]'),
+ (("elem0", "elem1"), '["elem0" "elem1"]'),
+ ],
+)
+def test_formatter_h5py_attr(tmp_h5py_file, data, expected):
+ """Test formatter with h5py attributes"""
+ tmp_h5py_file.attrs["attr"] = data
+ formatter = TextFormatter()
+ result = formatter.toString(tmp_h5py_file.attrs["attr"])
+ assert result == expected
diff --git a/src/silx/gui/dialog/AbstractDataFileDialog.py b/src/silx/gui/dialog/AbstractDataFileDialog.py
index f656bb2..00db275 100644
--- a/src/silx/gui/dialog/AbstractDataFileDialog.py
+++ b/src/silx/gui/dialog/AbstractDataFileDialog.py
@@ -56,7 +56,6 @@ some version of PyQt."""
class _IconProvider(object):
-
FileDialogToParentDir = qt.QStyle.SP_CustomBase + 1
FileDialogToParentFile = qt.QStyle.SP_CustomBase + 2
@@ -92,7 +91,9 @@ class _IconProvider(object):
pixmap.fill(qt.Qt.transparent)
painter = qt.QPainter(pixmap)
painter.drawPixmap(0, 0, backgroundIcon.pixmap(baseSize, mode=mode))
- painter.drawPixmap(0, size.height() // 3, baseIcon.pixmap(baseSize, mode=mode))
+ painter.drawPixmap(
+ 0, size.height() // 3, baseIcon.pixmap(baseSize, mode=mode)
+ )
painter.end()
icon.addPixmap(pixmap, mode=mode)
@@ -100,12 +101,16 @@ class _IconProvider(object):
def getFileDialogToParentDir(self):
if self.__iconFileDialogToParentDir is None:
- self.__iconFileDialogToParentDir = self._createIconToParent(qt.QStyle.SP_DirIcon)
+ self.__iconFileDialogToParentDir = self._createIconToParent(
+ qt.QStyle.SP_DirIcon
+ )
return self.__iconFileDialogToParentDir
def getFileDialogToParentFile(self):
if self.__iconFileDialogToParentFile is None:
- self.__iconFileDialogToParentFile = self._createIconToParent(qt.QStyle.SP_FileIcon)
+ self.__iconFileDialogToParentFile = self._createIconToParent(
+ qt.QStyle.SP_FileIcon
+ )
return self.__iconFileDialogToParentFile
def icon(self, kind):
@@ -147,13 +152,17 @@ class _SideBar(qt.QListView):
:rtype: List[str]
"""
urls = []
- version = tuple(map(int, qt.qVersion().split('.')[:3]))
+ version = tuple(map(int, qt.qVersion().split(".")[:3]))
feed_sidebar = True
if not DEFAULT_SIDEBAR_URL:
_logger.debug("Skip default sidebar URLs (from setted variable)")
feed_sidebar = False
- elif version < (5, 11, 2) and qt.BINDING == "PyQt5" and sys.platform in ["linux", "linux2"]:
+ elif (
+ version < (5, 11, 2)
+ and qt.BINDING == "PyQt5"
+ and sys.platform in ["linux", "linux2"]
+ ):
# Avoid segfault on PyQt5 + gtk
_logger.debug("Skip default sidebar URLs (avoid PyQt5 segfault)")
feed_sidebar = False
@@ -186,7 +195,9 @@ class _SideBar(qt.QListView):
selectionModel = self.selectionModel()
if selected is not None:
- selectionModel.setCurrentIndex(selected, qt.QItemSelectionModel.ClearAndSelect)
+ selectionModel.setCurrentIndex(
+ selected, qt.QItemSelectionModel.ClearAndSelect
+ )
else:
selectionModel.clear()
@@ -232,11 +243,12 @@ class _SideBar(qt.QListView):
def sizeHint(self):
index = self.model().index(0, 0)
- return self.sizeHintForIndex(index) + qt.QSize(2 * self.frameWidth(), 2 * self.frameWidth())
+ return self.sizeHintForIndex(index) + qt.QSize(
+ 2 * self.frameWidth(), 2 * self.frameWidth()
+ )
class _Browser(qt.QStackedWidget):
-
activated = qt.Signal(qt.QModelIndex)
selected = qt.Signal(qt.QModelIndex)
rootIndexChanged = qt.Signal(qt.QModelIndex)
@@ -302,7 +314,7 @@ class _Browser(qt.QStackedWidget):
elif self.currentIndex() == 1:
return qt.QFileDialog.Detail
else:
- assert(False)
+ assert False
def setViewMode(self, mode):
"""Set the current view mode.
@@ -314,7 +326,7 @@ class _Browser(qt.QStackedWidget):
elif mode == qt.QFileDialog.List:
self.showList()
else:
- assert(False)
+ assert False
def showList(self):
self.__listView.show()
@@ -342,11 +354,10 @@ class _Browser(qt.QStackedWidget):
self.__detailView.setModel(None)
def setRootIndex(self, index, model=None):
- """Sets the root item to the item at the given index.
- """
+ """Sets the root item to the item at the given index."""
rootIndex = self.__listView.rootIndex()
newModel = model or index.model()
- assert(newModel is not None)
+ assert newModel is not None
if rootIndex is None or rootIndex.model() is not newModel:
# update the model
@@ -415,12 +426,16 @@ class _Browser(qt.QStackedWidget):
nameId = stream.readQString()
if nameId != "Browser":
- _logger.warning("Stored state contains an invalid name id. Browser restoration cancelled.")
+ _logger.warning(
+ "Stored state contains an invalid name id. Browser restoration cancelled."
+ )
return False
version = stream.readInt32()
if version != self.__serialVersion:
- _logger.warning("Stored state contains an invalid version. Browser restoration cancelled.")
+ _logger.warning(
+ "Stored state contains an invalid version. Browser restoration cancelled."
+ )
return False
headerData = stream.readQVariant()
@@ -438,12 +453,12 @@ class _Browser(qt.QStackedWidget):
data = qt.QByteArray()
stream = qt.QDataStream(data, qt.QIODevice.WriteOnly)
- nameId = u"Browser"
+ nameId = "Browser"
stream.writeQString(nameId)
stream.writeInt32(self.__serialVersion)
stream.writeQVariant(self.__detailView.header().saveState())
viewMode = self.viewMode()
- if qt.BINDING == 'PyQt6': # No auto conversion to int
+ if qt.BINDING in ("PyQt6", "PySide6"): # No auto conversion to int
viewMode = viewMode.value
stream.writeInt32(viewMode)
@@ -451,7 +466,6 @@ class _Browser(qt.QStackedWidget):
class _FabioData(object):
-
def __init__(self, fabioFile):
self.__fabioFile = fabioFile
@@ -492,7 +506,6 @@ class _PathEdit(qt.QLineEdit):
class _CatchResizeEvent(qt.QObject):
-
resized = qt.Signal(qt.QResizeEvent)
def __init__(self, parent, target):
@@ -565,6 +578,7 @@ class AbstractDataFileDialog(qt.QDialog):
_logger.debug("Uses default QFileSystemModel with a SafeFileIconProvider")
self.__fileModel = qt.QFileSystemModel(self)
from .SafeFileIconProvider import SafeFileIconProvider
+
iconProvider = SafeFileIconProvider()
self.__fileModel.setIconProvider(iconProvider)
@@ -677,8 +691,12 @@ class AbstractDataFileDialog(qt.QDialog):
self.__fileTypeCombo = FileTypeComboBox(self)
self.__fileTypeCombo.setObjectName("fileTypeCombo")
self.__fileTypeCombo.setDuplicatesEnabled(False)
- self.__fileTypeCombo.setSizeAdjustPolicy(qt.QComboBox.AdjustToMinimumContentsLengthWithIcon)
- self.__fileTypeCombo.setSizePolicy(qt.QSizePolicy.Expanding, qt.QSizePolicy.Fixed)
+ self.__fileTypeCombo.setSizeAdjustPolicy(
+ qt.QComboBox.AdjustToMinimumContentsLengthWithIcon
+ )
+ self.__fileTypeCombo.setSizePolicy(
+ qt.QSizePolicy.Expanding, qt.QSizePolicy.Fixed
+ )
self.__fileTypeCombo.activated[int].connect(self.__filterSelected)
self.__fileTypeCombo.setFabioUrlSupproted(self._isFabioFilesSupported())
@@ -708,7 +726,9 @@ class AbstractDataFileDialog(qt.QDialog):
if self.__selectorWidget is not None:
self.__selectorWidget.selectionChanged.connect(self.__selectorWidgetChanged)
- self.__previewToolBar = self._createPreviewToolbar(self, self.__previewWidget, self.__selectorWidget)
+ self.__previewToolBar = self._createPreviewToolbar(
+ self, self.__previewWidget, self.__selectorWidget
+ )
self.__dataIcon = qt.QLabel(self)
self.__dataIcon.setSizePolicy(qt.QSizePolicy.Fixed, qt.QSizePolicy.Fixed)
@@ -767,7 +787,9 @@ class AbstractDataFileDialog(qt.QDialog):
parentFileDirectory = qt.QAction(toolbar)
parentFileDirectory.setText("Parent directory of the file")
parentFileDirectory.setObjectName("toDirectoryAction")
- parentFileDirectory.setIcon(iconProvider.icon(iconProvider.FileDialogToParentDir))
+ parentFileDirectory.setIcon(
+ iconProvider.icon(iconProvider.FileDialogToParentDir)
+ )
parentFileDirectory.triggered.connect(self.__navigateToParentDir)
self.__parentFileDirectoryAction = parentFileDirectory
@@ -818,11 +840,15 @@ class AbstractDataFileDialog(qt.QDialog):
dummyCombo.setFixedHeight(self.__fileTypeCombo.height())
self.__resizeCombo = _CatchResizeEvent(self, self.__fileTypeCombo)
- self.__resizeCombo.resized.connect(lambda e: dummyCombo.setFixedHeight(e.size().height()))
+ self.__resizeCombo.resized.connect(
+ lambda e: dummyCombo.setFixedHeight(e.size().height())
+ )
dummyToolBar.setFixedHeight(self.__browseToolBar.height())
self.__resizeToolbar = _CatchResizeEvent(self, self.__browseToolBar)
- self.__resizeToolbar.resized.connect(lambda e: dummyToolBar.setFixedHeight(e.size().height()))
+ self.__resizeToolbar.resized.connect(
+ lambda e: dummyToolBar.setFixedHeight(e.size().height())
+ )
datasetSelection = qt.QWidget(self)
layoutLeft = qt.QVBoxLayout()
@@ -831,7 +857,9 @@ class AbstractDataFileDialog(qt.QDialog):
layoutLeft.addWidget(self.__browser)
layoutLeft.addWidget(self.__fileTypeCombo)
datasetSelection.setLayout(layoutLeft)
- datasetSelection.setSizePolicy(qt.QSizePolicy.MinimumExpanding, qt.QSizePolicy.Expanding)
+ datasetSelection.setSizePolicy(
+ qt.QSizePolicy.MinimumExpanding, qt.QSizePolicy.Expanding
+ )
infoLayout = qt.QHBoxLayout()
infoLayout.setContentsMargins(0, 0, 0, 0)
@@ -858,7 +886,9 @@ class AbstractDataFileDialog(qt.QDialog):
dummyToolbar2.setSizePolicy(qt.QSizePolicy.Expanding, qt.QSizePolicy.Fixed)
dummyToolbar2.setFixedHeight(self.__browseToolBar.height())
self.__resizeToolbar = _CatchResizeEvent(self, self.__browseToolBar)
- self.__resizeToolbar.resized.connect(lambda e: dummyToolbar2.setFixedHeight(e.size().height()))
+ self.__resizeToolbar.resized.connect(
+ lambda e: dummyToolbar2.setFixedHeight(e.size().height())
+ )
dataLayout.addWidget(dummyToolbar2)
dataLayout.addWidget(dataFrame)
@@ -870,7 +900,9 @@ class AbstractDataFileDialog(qt.QDialog):
dummyCombo2.setSizePolicy(qt.QSizePolicy.Expanding, qt.QSizePolicy.Fixed)
dummyCombo2.setFixedHeight(self.__fileTypeCombo.height())
self.__resizeToolbar = _CatchResizeEvent(self, self.__fileTypeCombo)
- self.__resizeToolbar.resized.connect(lambda e: dummyCombo2.setFixedHeight(e.size().height()))
+ self.__resizeToolbar.resized.connect(
+ lambda e: dummyCombo2.setFixedHeight(e.size().height())
+ )
dataLayout.addWidget(dummyCombo2)
dataSelection.setLayout(dataLayout)
@@ -904,7 +936,10 @@ class AbstractDataFileDialog(qt.QDialog):
def __navigateForward(self):
"""Navigate through the history one step forward."""
- if len(self.__currentHistory) > 0 and self.__currentHistoryLocation < len(self.__currentHistory) - 1:
+ if (
+ len(self.__currentHistory) > 0
+ and self.__currentHistoryLocation < len(self.__currentHistory) - 1
+ ):
self.__currentHistoryLocation += 1
url = self.__currentHistory[self.__currentHistoryLocation]
self.selectUrl(url)
@@ -971,7 +1006,7 @@ class AbstractDataFileDialog(qt.QDialog):
self.__listViewAction.setChecked(True)
self.__detailViewAction.setChecked(False)
else:
- assert(False)
+ assert False
def __showAsListView(self):
self.setViewMode(qt.QFileDialog.List)
@@ -1005,7 +1040,7 @@ class AbstractDataFileDialog(qt.QDialog):
if silx.io.is_group(obj):
self.__browser.setRootIndex(index)
else:
- assert(False)
+ assert False
def __browsedItemSelected(self, index):
self.__dataSelected(index)
@@ -1020,7 +1055,7 @@ class AbstractDataFileDialog(qt.QDialog):
:param str path: Path to load
"""
- assert(path is not None)
+ assert path is not None
if path != "" and not os.path.exists(path):
return
if self.hasPendingEvents():
@@ -1102,8 +1137,7 @@ class AbstractDataFileDialog(qt.QDialog):
return True
def __isSilxHavePriority(self, filename):
- """Silx have priority when there is a specific decoder
- """
+ """Silx have priority when there is a specific decoder"""
_, ext = os.path.splitext(filename)
ext = "*%s" % ext
formats = silx.io.supported_extensions(flat_formats=False)
@@ -1166,14 +1200,17 @@ class AbstractDataFileDialog(qt.QDialog):
if os.path.isfile(path):
codec = self.__fileTypeCombo.currentCodec()
is_fabio_decoder = codec.is_fabio_codec()
- is_fabio_have_priority = not codec.is_silx_codec() and not self.__isSilxHavePriority(path)
+ is_fabio_have_priority = (
+ not codec.is_silx_codec()
+ and not self.__isSilxHavePriority(path)
+ )
if is_fabio_decoder or is_fabio_have_priority:
# Then it's flat frame container
self.__openFabioFile(path)
if self.__fabio is not None:
selectedData = _FabioData(self.__fabio)
else:
- assert(False)
+ assert False
self.__setData(selectedData)
@@ -1193,7 +1230,9 @@ class AbstractDataFileDialog(qt.QDialog):
self.__setSelectedData(data)
self.__selectorWidget.hide()
else:
- self.__selectorWidget.setVisible(self.__selectorWidget.hasVisibleSelectors())
+ self.__selectorWidget.setVisible(
+ self.__selectorWidget.hasVisibleSelectors()
+ )
# Needed to fake the fact we have to reset the zoom in preview
self.__selectedData = None
self.__selectorWidget.selectionChanged.emit()
@@ -1266,7 +1305,10 @@ class AbstractDataFileDialog(qt.QDialog):
def __updateDataInfo(self):
if self.__errorWhileLoadingFile is not None:
filename, message = self.__errorWhileLoadingFile
- message = "<b>Error while loading file '%s'</b><hr/>%s" % (filename, message)
+ message = "<b>Error while loading file '%s'</b><hr/>%s" % (
+ filename,
+ message,
+ )
size = self.__dataInfo.height()
icon = self.style().standardIcon(qt.QStyle.SP_MessageBoxCritical)
pixmap = icon.pixmap(size, size)
@@ -1317,7 +1359,11 @@ class AbstractDataFileDialog(qt.QDialog):
filename = ""
dataPath = None
- if useSelectorWidget and self.__selectorWidget is not None and self.__selectorWidget.isUsed():
+ if (
+ useSelectorWidget
+ and self.__selectorWidget is not None
+ and self.__selectorWidget.isUsed()
+ ):
slicing = self.__selectorWidget.slicing()
if slicing == tuple():
slicing = None
@@ -1340,7 +1386,9 @@ class AbstractDataFileDialog(qt.QDialog):
else:
scheme = None
- url = silx.io.url.DataUrl(file_path=filename, data_path=dataPath, data_slice=slicing, scheme=scheme)
+ url = silx.io.url.DataUrl(
+ file_path=filename, data_path=dataPath, data_slice=slicing, scheme=scheme
+ )
return url
def __updatePath(self):
@@ -1362,7 +1410,9 @@ class AbstractDataFileDialog(qt.QDialog):
if currentUrl is None or currentUrl != url.path():
# clean up the forward history
- self.__currentHistory = self.__currentHistory[0:self.__currentHistoryLocation + 1]
+ self.__currentHistory = self.__currentHistory[
+ 0 : self.__currentHistoryLocation + 1
+ ]
self.__currentHistory.append(url.path())
self.__currentHistoryLocation += 1
@@ -1400,15 +1450,16 @@ class AbstractDataFileDialog(qt.QDialog):
selectionModel.selectionChanged.connect(self.__shortcutSelected)
def __updateActionHistory(self):
- self.__forwardAction.setEnabled(len(self.__currentHistory) - 1 > self.__currentHistoryLocation)
+ self.__forwardAction.setEnabled(
+ len(self.__currentHistory) - 1 > self.__currentHistoryLocation
+ )
self.__backwardAction.setEnabled(self.__currentHistoryLocation > 0)
def __textChanged(self, text):
self.__pathChanged()
def _isFabioFilesSupported(self):
- """Returns true fabio files can be loaded.
- """
+ """Returns true fabio files can be loaded."""
return True
def _isLoadableUrl(self, url):
@@ -1479,7 +1530,7 @@ class AbstractDataFileDialog(qt.QDialog):
# data = _FabioData(self.__fabio)
# self.__setData(data)
else:
- assert(False)
+ assert False
else:
self.__browser.setRootIndex(index, model=self.__fileModel)
self.__clearData()
@@ -1615,7 +1666,7 @@ class AbstractDataFileDialog(qt.QDialog):
"""
if len(self.__currentHistory) <= 1:
return []
- history = self.__currentHistory[0:self.__currentHistoryLocation]
+ history = self.__currentHistory[0 : self.__currentHistoryLocation]
return list(history)
def setHistory(self, history):
@@ -1670,12 +1721,18 @@ class AbstractDataFileDialog(qt.QDialog):
qualifiedName = stream.readQString()
if qualifiedName != self.qualifiedName():
- _logger.warning("Stored state contains an invalid qualified name. %s restoration cancelled.", self.__class__.__name__)
+ _logger.warning(
+ "Stored state contains an invalid qualified name. %s restoration cancelled.",
+ self.__class__.__name__,
+ )
return False
version = stream.readInt32()
if version != self.__serialVersion:
- _logger.warning("Stored state contains an invalid version. %s restoration cancelled.", self.__class__.__name__)
+ _logger.warning(
+ "Stored state contains an invalid version. %s restoration cancelled.",
+ self.__class__.__name__,
+ )
return False
result = True
@@ -1713,17 +1770,17 @@ class AbstractDataFileDialog(qt.QDialog):
stream = qt.QDataStream(data, qt.QIODevice.WriteOnly)
s = self.qualifiedName()
- stream.writeQString(u"%s" % s)
+ stream.writeQString("%s" % s)
stream.writeInt32(self.__serialVersion)
stream.writeQVariant(self.__splitter.saveState())
- strings = [u"%s" % s.toString() for s in self.sidebarUrls()]
+ strings = ["%s" % s.toString() for s in self.sidebarUrls()]
stream.writeQStringList(strings)
- strings = [u"%s" % s for s in self.history()]
+ strings = ["%s" % s for s in self.history()]
stream.writeQStringList(strings)
- stream.writeQString(u"%s" % self.directory())
+ stream.writeQString("%s" % self.directory())
stream.writeQVariant(self.__browser.saveState())
viewMode = self.viewMode()
- if qt.BINDING == 'PyQt6': # No auto conversion to int
+ if qt.BINDING in ("PyQt6", "PySide6"): # No auto conversion to int
viewMode = viewMode.value
stream.writeInt32(viewMode)
colormap = self.colormap()
diff --git a/src/silx/gui/dialog/ColormapDialog.py b/src/silx/gui/dialog/ColormapDialog.py
index f3f38b5..75ab39e 100644
--- a/src/silx/gui/dialog/ColormapDialog.py
+++ b/src/silx/gui/dialog/ColormapDialog.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2004-2022 European Synchrotron Radiation Facility
+# Copyright (c) 2004-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -58,6 +58,8 @@ The updates of the colormap description are also available through the signal:
:attr:`ColormapDialog.sigColormapChanged`.
""" # noqa
+from __future__ import annotations
+
__authors__ = ["V.A. Sole", "T. Vincent", "H. Payno"]
__license__ = "MIT"
__date__ = "08/12/2020"
@@ -82,9 +84,9 @@ from silx.gui.qt import inspect as qtinspect
from silx.gui.widgets.ColormapNameComboBox import ColormapNameComboBox
from silx.gui.widgets.FormGridLayout import FormGridLayout
from silx.math.histogram import Histogramnd
-from silx.utils import deprecation
from silx.gui.plot.items.roi import RectangleROI
from silx.gui.plot.tools.roi import RegionOfInterestManager
+from silx.utils.enum import Enum as _Enum
_logger = logging.getLogger(__name__)
@@ -128,8 +130,8 @@ class _BoundaryWidget(qt.QWidget):
self.layout().setContentsMargins(0, 0, 0, 0)
self._numVal = FloatEdit(parent=self, value=value)
- self._iconAuto = icons.getQIcon('scale-auto')
- self._iconFixed = icons.getQIcon('scale-fixed')
+ self._iconAuto = icons.getQIcon("scale-auto")
+ self._iconFixed = icons.getQIcon("scale-fixed")
self._autoToggleAction = qt.QAction(self)
self._autoToggleAction.setText("Auto scale")
@@ -142,7 +144,7 @@ class _BoundaryWidget(qt.QWidget):
self._numVal.addAction(self._autoToggleAction, qt.QLineEdit.LeadingPosition)
self.layout().addWidget(self._numVal)
- self._autoCB = qt.QCheckBox('auto', parent=self)
+ self._autoCB = qt.QCheckBox("auto", parent=self)
self.layout().addWidget(self._autoCB)
self._autoCB.setChecked(False)
self._autoCB.setVisible(False)
@@ -220,7 +222,7 @@ class _BoundaryWidget(qt.QWidget):
color = palette.color(qt.QPalette.Disabled, qt.QPalette.Base)
icon = self._iconAuto
else:
- color = palette.color(qt.QPalette.Normal, qt.QPalette.Base)
+ color = palette.color(qt.QPalette.Active, qt.QPalette.Base)
icon = self._iconFixed
palette.setColor(qt.QPalette.Base, color)
self._numVal.setPalette(palette)
@@ -228,7 +230,6 @@ class _BoundaryWidget(qt.QWidget):
class _AutoscaleModeComboBox(qt.QComboBox):
-
DATA = {
Colormap.MINMAX: ("Min/max", "Use the data min/max"),
Colormap.STDDEV3: ("Mean±3std", "Use the data mean ± 3 × standard deviation"),
@@ -275,7 +276,6 @@ class _AutoscaleModeComboBox(qt.QComboBox):
class _AutoScaleButton(qt.QPushButton):
-
autoRangeChanged = qt.Signal(object)
def __init__(self, parent=None):
@@ -300,11 +300,13 @@ class _AutoScaleButton(qt.QPushButton):
with utils.blockSignals(self):
self.setChecked(autoRange[0] if autoRange[0] == autoRange[1] else False)
+
@enum.unique
-class _DataInPlotMode(enum.Enum):
+class DisplayMode(_Enum):
"""Enum for each mode of display of the data in the plot."""
- RANGE = 'range'
- HISTOGRAM = 'histogram'
+
+ RANGE = "range"
+ HISTOGRAM = "histogram"
class _ColormapHistogram(qt.QWidget):
@@ -333,7 +335,7 @@ class _ColormapHistogram(qt.QWidget):
def __init__(self, parent):
qt.QWidget.__init__(self, parent=parent)
- self._dataInPlotMode = _DataInPlotMode.RANGE
+ self._displayMode = DisplayMode.RANGE
self._finiteRange = None, None
self._initPlot()
@@ -350,7 +352,7 @@ class _ColormapHistogram(qt.QWidget):
def paintEvent(self, event):
if self._invalidated:
- self._updateDataInPlot()
+ self._updateDisplayMode()
self._invalidated = False
self._updateMarkerPosition()
return super(_ColormapHistogram, self).paintEvent(event)
@@ -419,7 +421,9 @@ class _ColormapHistogram(qt.QWidget):
dataRange = self._getNormalizedDataRange()
if dataRange[0] is None or dataRange[1] is None:
return None, None
- counts, edges = self.parent().computeHistogram(data, scale=norm, dataRange=dataRange)
+ counts, edges = self.parent().computeHistogram(
+ data, scale=norm, dataRange=dataRange
+ )
return counts, edges
def _getNormalizedDataRange(self):
@@ -511,22 +515,22 @@ class _ColormapHistogram(qt.QWidget):
self._plot.setDataMargins(0.125, 0.125, 0.01, 0.01)
self._plot.getXAxis().setLabel("Data Values")
self._plot.getYAxis().setLabel("")
- self._plot.setInteractiveMode('select', zoomOnWheel=False)
+ self._plot.setInteractiveMode("select", zoomOnWheel=False)
self._plot.setActiveCurveHandling(False)
self._plot.setMinimumSize(qt.QSize(250, 200))
self._plot.sigPlotSignal.connect(self._plotEventReceived)
palette = self.palette()
- color = palette.color(qt.QPalette.Normal, qt.QPalette.Window)
+ color = palette.color(qt.QPalette.Active, qt.QPalette.Window)
self._plot.setBackgroundColor(color)
self._plot.setDataBackgroundColor("white")
lut = numpy.arange(256)
lut.shape = 1, -1
- self._plot.addImage(lut, legend='lut')
+ self._plot.addImage(lut, legend="lut")
self._lutItem = self._plot._getItem("image", "lut")
self._lutItem.setVisible(False)
- self._plot.addScatter(x=[], y=[], value=[], legend='lut2')
+ self._plot.addScatter(x=[], y=[], value=[], legend="lut2")
self._lutItem2 = self._plot._getItem("scatter", "lut2")
self._lutItem2.setVisible(False)
self.__lutY = numpy.array([-0.05] * 256)
@@ -546,24 +550,32 @@ class _ColormapHistogram(qt.QWidget):
group = qt.QActionGroup(self._plotToolbar)
group.setExclusive(True)
-
+ # data range mode
action = qt.QAction("Data range", self)
- action.setToolTip("Display the data range within the colormap range. A fast data processing have to be done.")
- action.setIcon(icons.getQIcon('colormap-range'))
+ action.setToolTip(
+ "Display the data range within the colormap range. A fast data processing have to be done."
+ )
+ action.setIcon(icons.getQIcon("colormap-range"))
action.setCheckable(True)
- action.setData(_DataInPlotMode.RANGE)
- action.setChecked(action.data() == self._dataInPlotMode)
+ action.setData(DisplayMode.RANGE)
+ action.setChecked(action.data() == self._displayMode)
self._plotToolbar.addAction(action)
group.addAction(action)
+ self._dataRangeAction = action
+ # histogram mode
action = qt.QAction("Histogram", self)
- action.setToolTip("Display the data histogram within the colormap range. A slow data processing have to be done. ")
- action.setIcon(icons.getQIcon('colormap-histogram'))
+ action.setToolTip(
+ "Display the data histogram within the colormap range. A slow data processing have to be done. "
+ )
+ action.setIcon(icons.getQIcon("colormap-histogram"))
action.setCheckable(True)
- action.setData(_DataInPlotMode.HISTOGRAM)
- action.setChecked(action.data() == self._dataInPlotMode)
+ action.setData(DisplayMode.HISTOGRAM)
+ action.setChecked(action.data() == self._displayMode)
self._plotToolbar.addAction(action)
group.addAction(action)
- group.triggered.connect(self._displayDataInPlotModeChanged)
+ self._dataHistogramAction = action
+ group.setExclusive(True)
+ group.triggered.connect(self._displayModeChanged)
plotBoxLayout = qt.QHBoxLayout()
plotBoxLayout.setContentsMargins(0, 0, 0, 0)
@@ -575,28 +587,28 @@ class _ColormapHistogram(qt.QWidget):
def _plotEventReceived(self, event):
"""Handle events from the plot"""
- kind = event['event']
+ kind = event["event"]
- if kind == 'markerMoving':
- value = event['xdata']
- if event['label'] == 'Min':
+ if kind == "markerMoving":
+ value = event["xdata"]
+ if event["label"] == "Min":
self._dragging = True, False, False
self._finiteRange = value, self._finiteRange[1]
self._last = value, None, None
self._updateGammaPosition()
self.sigRangeMoving.emit(*self._last)
- elif event['label'] == 'Max':
+ elif event["label"] == "Max":
self._dragging = False, True, False
self._finiteRange = self._finiteRange[0], value
self._last = None, value, None
self._updateGammaPosition()
self.sigRangeMoving.emit(*self._last)
- elif event['label'] == 'Gamma':
+ elif event["label"] == "Gamma":
self._dragging = False, False, True
self._last = None, None, value
self.sigRangeMoving.emit(*self._last)
self._updateLutItem(self._finiteRange)
- elif kind == 'markerMoved':
+ elif kind == "markerMoved":
self.sigRangeMoved.emit(*self._last)
self._plot.resetZoom()
self._dragging = False, False, False
@@ -616,20 +628,22 @@ class _ColormapHistogram(qt.QWidget):
if posMin is not None and not self._dragging[0]:
self._plot.addXMarker(
posMin,
- legend='Min',
- text='Min',
+ legend="Min",
+ text="Min",
draggable=isDraggable,
color="blue",
- constraint=self._plotMinMarkerConstraint)
+ constraint=self._plotMinMarkerConstraint,
+ )
self._updateGammaPosition()
if posMax is not None and not self._dragging[1]:
self._plot.addXMarker(
posMax,
- legend='Max',
- text='\n\nMax',
+ legend="Max",
+ text="\n\nMax",
draggable=isDraggable,
color="blue",
- constraint=self._plotMaxMarkerConstraint)
+ constraint=self._plotMaxMarkerConstraint,
+ )
self._updateLutItem((posMin, posMax))
self._plot.resetZoom()
@@ -650,14 +664,14 @@ class _ColormapHistogram(qt.QWidget):
if not self._dragging[2]:
posRange = posMax - posMin
if posRange > 0:
- gammaPos = posMin + posRange * 0.5**(1/gamma)
+ gammaPos = posMin + posRange * 0.5 ** (1 / gamma)
else:
gammaPos = posMin
marker = self._plot._getMarker(
self._plot.addXMarker(
gammaPos,
- legend='Gamma',
- text='\nGamma',
+ legend="Gamma",
+ text="\nGamma",
draggable=True,
color="blue",
constraint=self._plotGammaMarkerConstraint,
@@ -666,7 +680,7 @@ class _ColormapHistogram(qt.QWidget):
marker.setZValue(2)
else:
try:
- self._plot.removeMarker('Gamma')
+ self._plot.removeMarker("Gamma")
except Exception:
pass
@@ -703,10 +717,9 @@ class _ColormapHistogram(qt.QWidget):
self._lutItem2.setVisible(False)
self._lutItem2.setColormap(normColormap)
xx = numpy.geomspace(posMin, posMax, 256)
- self._lutItem2.setData(x=xx,
- y=self.__lutY,
- value=self.__lutV,
- copy=False)
+ self._lutItem2.setData(
+ x=xx, y=self.__lutY, value=self.__lutV, copy=False
+ )
self._lutItem2.setSymbol("|")
self._lutItem2.setVisible(True)
self._lutItem.setVisible(False)
@@ -717,10 +730,8 @@ class _ColormapHistogram(qt.QWidget):
self._lutItem2.setColormap(normColormap)
xx = numpy.linspace(posMin, posMax, 256, endpoint=True)
self._lutItem2.setData(
- x=xx,
- y=self.__lutY,
- value=self.__lutV,
- copy=False)
+ x=xx, y=self.__lutY, value=self.__lutV, copy=False
+ )
self._lutItem2.setSymbol("|")
self._lutItem2.setVisible(True)
self._lutItem.setVisible(False)
@@ -750,15 +761,29 @@ class _ColormapHistogram(qt.QWidget):
x = min(x, vmax)
return x, y
- def _setDataInPlotMode(self, mode):
- if self._dataInPlotMode == mode:
+ def setDisplayMode(self, mode: str | DisplayMode):
+ mode = DisplayMode.from_value(mode)
+ if mode is DisplayMode.HISTOGRAM:
+ action = self._dataHistogramAction
+ elif mode is DisplayMode.RANGE:
+ action = self._dataRangeAction
+ else:
+ raise ValueError("Mode not supported")
+ action.setChecked(True)
+ self._displayModeChanged(action)
+
+ def _setDisplayMode(self, mode):
+ if self._displayMode == mode:
return
- self._dataInPlotMode = mode
- self._updateDataInPlot()
+ self._displayMode = mode
+ self._updateDisplayMode()
- def _displayDataInPlotModeChanged(self, action):
+ def getDsiplayMode(self) -> DisplayMode:
+ return self._displayMode
+
+ def _displayModeChanged(self, action):
mode = action.data()
- self._setDataInPlotMode(mode)
+ self._setDisplayMode(mode)
def invalidateData(self):
self._histogramData = {}
@@ -766,8 +791,8 @@ class _ColormapHistogram(qt.QWidget):
self._invalidated = True
self.update()
- def _updateDataInPlot(self):
- mode = self._dataInPlotMode
+ def _updateDisplayMode(self):
+ mode = self._displayMode
norm = self._getNorm()
if norm == Colormap.LINEAR:
@@ -780,38 +805,42 @@ class _ColormapHistogram(qt.QWidget):
axis = self._plot.getXAxis()
axis.setScale(scale)
- if mode == _DataInPlotMode.RANGE:
+ if mode == DisplayMode.RANGE:
dataRange = self._getNormalizedDataRange()
xmin, xmax = dataRange
if xmax is None or xmin is None:
- self._plot.remove(legend='Data', kind='histogram')
+ self._plot.remove(legend="Data", kind="histogram")
else:
histogram = numpy.array([1])
bin_edges = numpy.array([xmin, xmax])
- self._plot.addHistogram(histogram,
- bin_edges,
- legend="Data",
- color='gray',
- align='center',
- fill=True,
- z=1)
-
- elif mode == _DataInPlotMode.HISTOGRAM:
+ self._plot.addHistogram(
+ histogram,
+ bin_edges,
+ legend="Data",
+ color="gray",
+ align="center",
+ fill=True,
+ z=1,
+ )
+
+ elif mode == DisplayMode.HISTOGRAM:
histogram, bin_edges = self._getNormalizedHistogram()
if histogram is None or bin_edges is None:
- self._plot.remove(legend='Data', kind='histogram')
+ self._plot.remove(legend="Data", kind="histogram")
else:
histogram = numpy.array(histogram, copy=True)
bin_edges = numpy.array(bin_edges, copy=True)
- with numpy.errstate(invalid='ignore'):
+ with numpy.errstate(invalid="ignore"):
norm_histogram = histogram / numpy.nanmax(histogram)
- self._plot.addHistogram(norm_histogram,
- bin_edges,
- legend="Data",
- color='gray',
- align='center',
- fill=True,
- z=1)
+ self._plot.addHistogram(
+ norm_histogram,
+ bin_edges,
+ legend="Data",
+ color="gray",
+ align="center",
+ fill=True,
+ z=1,
+ )
else:
_logger.error("Mode unsupported")
@@ -830,7 +859,7 @@ class _ColormapHistogram(qt.QWidget):
return norm
def updateNormalization(self):
- self._updateDataInPlot()
+ self._updateDisplayMode()
self.update()
@@ -887,16 +916,19 @@ class ColormapDialog(qt.QDialog):
# Colormap row
self._comboBoxColormap = ColormapNameComboBox(parent=self)
- self._comboBoxColormap.currentIndexChanged[int].connect(self._comboBoxColormapUpdated)
+ self._comboBoxColormap.currentIndexChanged[int].connect(
+ self._comboBoxColormapUpdated
+ )
# Normalization row
self._comboBoxNormalization = qt.QComboBox(parent=self)
normalizations = [
- ('Linear', Colormap.LINEAR),
- ('Gamma correction', Colormap.GAMMA),
- ('Arcsinh', Colormap.ARCSINH),
- ('Logarithmic', Colormap.LOGARITHM),
- ('Square root', Colormap.SQRT)]
+ ("Linear", Colormap.LINEAR),
+ ("Gamma correction", Colormap.GAMMA),
+ ("Arcsinh", Colormap.ARCSINH),
+ ("Logarithmic", Colormap.LOGARITHM),
+ ("Square root", Colormap.SQRT),
+ ]
for name, userData in normalizations:
try:
icon = icons.getQIcon("colormap-norm-%s" % userData)
@@ -904,11 +936,12 @@ class ColormapDialog(qt.QDialog):
icon = qt.QIcon()
self._comboBoxNormalization.addItem(icon, name, userData)
self._comboBoxNormalization.currentIndexChanged[int].connect(
- self._normalizationUpdated)
+ self._normalizationUpdated
+ )
self._gammaSpinBox = qt.QDoubleSpinBox(parent=self)
self._gammaSpinBox.setEnabled(False)
- self._gammaSpinBox.setRange(0.01, 100.)
+ self._gammaSpinBox.setRange(0.01, 100.0)
self._gammaSpinBox.setDecimals(4)
if hasattr(qt.QDoubleSpinBox, "setStepType"):
# Introduced in Qt 5.12
@@ -916,7 +949,7 @@ class ColormapDialog(qt.QDialog):
else:
self._gammaSpinBox.setSingleStep(0.1)
self._gammaSpinBox.valueChanged.connect(self._gammaUpdated)
- self._gammaSpinBox.setValue(2.)
+ self._gammaSpinBox.setValue(2.0)
autoScaleCombo = _AutoscaleModeComboBox(self)
autoScaleCombo.currentIndexChanged.connect(self._autoscaleModeUpdated)
@@ -959,15 +992,17 @@ class ColormapDialog(qt.QDialog):
self._histoWidget = _ColormapHistogram(self)
self._histoWidget.sigRangeMoving.connect(self._histogramRangeMoving)
self._histoWidget.sigRangeMoved.connect(self._histogramRangeMoved)
- self._histoWidget.setSizePolicy(qt.QSizePolicy.Expanding, qt.QSizePolicy.Expanding)
+ self._histoWidget.setSizePolicy(
+ qt.QSizePolicy.Expanding, qt.QSizePolicy.Expanding
+ )
# Scale to buttons
self._visibleAreaButton = qt.QPushButton(self)
self._visibleAreaButton.setEnabled(False)
self._visibleAreaButton.setText("Visible Area")
self._visibleAreaButton.clicked.connect(
- self._handleScaleToVisibleAreaClicked,
- type=qt.Qt.QueuedConnection)
+ self._handleScaleToVisibleAreaClicked, type=qt.Qt.QueuedConnection
+ )
# Place-holder for selected area ROI manager
self._roiForColormapManager = None
@@ -979,8 +1014,8 @@ class ColormapDialog(qt.QDialog):
self._selectedAreaButton.setIcon(icons.getQIcon("add-shape-rectangle"))
self._selectedAreaButton.setCheckable(True)
self._selectedAreaButton.toggled.connect(
- self._handleScaleToSelectionToggled,
- type=qt.Qt.QueuedConnection)
+ self._handleScaleToSelectionToggled, type=qt.Qt.QueuedConnection
+ )
# define modal buttons
types = qt.QDialogButtonBox.Ok | qt.QDialogButtonBox.Cancel
@@ -1003,8 +1038,9 @@ class ColormapDialog(qt.QDialog):
self._buttonsNonModal.setFocus(qt.Qt.OtherFocusReason)
# Set the colormap to default values
- self.setColormap(Colormap(name='gray', normalization='linear',
- vmin=None, vmax=None))
+ self.setColormap(
+ Colormap(name="gray", normalization="linear", vmin=None, vmax=None)
+ )
self.setModal(self.isModal())
@@ -1023,19 +1059,23 @@ class ColormapDialog(qt.QDialog):
layoutScale.addWidget(self._autoScaleCombo)
layoutScale.addStretch()
-
formLayout = FormGridLayout(self)
formLayout.setContentsMargins(10, 10, 10, 10)
- formLayout.addRow('Colormap:', self._comboBoxColormap)
- formLayout.addRow('Normalization:', self._comboBoxNormalization)
- formLayout.addRow('Gamma:', self._gammaSpinBox)
+ formLayout.addRow("Colormap:", self._comboBoxColormap)
+ formLayout.addRow("Normalization:", self._comboBoxNormalization)
+ formLayout.addRow("Gamma:", self._gammaSpinBox)
- formLayout.addItem(qt.QSpacerItem(1, 1, qt.QSizePolicy.Fixed, qt.QSizePolicy.Fixed))
+ formLayout.addItem(
+ qt.QSpacerItem(1, 1, qt.QSizePolicy.Fixed, qt.QSizePolicy.Fixed)
+ )
formLayout.addRow(self._histoWidget)
+ formLayout.setRowStretch(formLayout.rowCount() - 1, 1)
formLayout.addRow(rangeLayout)
- formLayout.addItem(qt.QSpacerItem(1, 1, qt.QSizePolicy.Fixed, qt.QSizePolicy.Fixed))
- formLayout.addRow('Scale:', layoutScale)
+ formLayout.addItem(
+ qt.QSpacerItem(1, 1, qt.QSizePolicy.Fixed, qt.QSizePolicy.Fixed)
+ )
+ formLayout.addRow("Scale:", layoutScale)
formLayout.addRow("Fixed scale on:", self._scaleToAreaGroup)
formLayout.addRow(self._buttonsModal)
formLayout.addRow(self._buttonsNonModal)
@@ -1054,6 +1094,9 @@ class ColormapDialog(qt.QDialog):
self._applyColormap()
+ def getHistogramWidget(self):
+ return self._histoWidget
+
def _invalidateColormap(self):
if self.isVisible():
self._applyColormap()
@@ -1156,9 +1199,9 @@ class ColormapDialog(qt.QDialog):
if dataRange is None or len(dataRange) != 3:
qt.QMessageBox.warning(
- None, "No Data",
- "Image data does not contain any real value")
- dataRange = 1., 1., 10.
+ None, "No Data", "Image data does not contain any real value"
+ )
+ dataRange = 1.0, 1.0, 10.0
return dataRange
@@ -1182,11 +1225,11 @@ class ColormapDialog(qt.QDialog):
return None, None
if data.ndim == 3: # RGB(A) images
- _logger.info('Converting current image from RGB(A) to grayscale\
- in order to compute the intensity distribution')
- data = (data[:,:, 0] * 0.299 +
- data[:,:, 1] * 0.587 +
- data[:,:, 2] * 0.114)
+ _logger.info(
+ "Converting current image from RGB(A) to grayscale\
+ in order to compute the intensity distribution"
+ )
+ data = data[:, :, 0] * 0.299 + data[:, :, 1] * 0.587 + data[:, :, 2] * 0.114
# bad hack: get 256 continuous bins in the case we have a B&W
normalizeData = True
@@ -1200,7 +1243,7 @@ class ColormapDialog(qt.QDialog):
if normalizeData:
if scale == Colormap.LOGARITHM:
- with numpy.errstate(divide='ignore', invalid='ignore'):
+ with numpy.errstate(divide="ignore", invalid="ignore"):
data = numpy.log10(data)
if dataRange is not None:
@@ -1231,7 +1274,7 @@ class ColormapDialog(qt.QDialog):
bins = histogram.edges[0]
if normalizeData:
if scale == Colormap.LOGARITHM:
- bins = 10 ** bins
+ bins = 10**bins
return histogram.histo, bins
def _getItem(self):
@@ -1263,8 +1306,7 @@ class ColormapDialog(qt.QDialog):
if oldArray is array:
return
- self._data = None
- self._itemHolder = None
+ self.__resetItem()
try:
if item is None:
self._item = None
@@ -1272,6 +1314,7 @@ class ColormapDialog(qt.QDialog):
if not isinstance(item, items.ColormapMixIn):
self._item = None
raise ValueError("Item %s is not supported" % item)
+ item.sigItemChanged.connect(self.__itemChanged)
self._item = weakref.ref(item, self._itemAboutToFinalize)
finally:
self._syncScaleToButtonsEnabled()
@@ -1279,6 +1322,20 @@ class ColormapDialog(qt.QDialog):
self._histogramData = None
self._invalidateData()
+ def __resetItem(self):
+ """Reset item and data used by the dialog"""
+ self._data = None
+ self._itemHolder = None
+ if self._item is not None:
+ item = self._item()
+ self._item = None
+ if item is not None:
+ item.sigItemChanged.disconnect(self.__itemChanged)
+
+ def __itemChanged(self, event):
+ if event == items.ItemChangedType.DATA:
+ self._invalidateData()
+
def _getData(self):
if self._data is None:
return None
@@ -1295,12 +1352,9 @@ class ColormapDialog(qt.QDialog):
if oldData is data:
return
- self._item = None
+ self.__resetItem()
self._syncScaleToButtonsEnabled()
- if data is None:
- self._data = None
- self._itemHolder = None
- else:
+ if data is not None:
self._data = weakref.ref(data, self._dataAboutToFinalize)
self._itemHolder = _DataRefHolder(self._data)
@@ -1338,14 +1392,6 @@ class ColormapDialog(qt.QDialog):
if self._item is weakref and qtinspect.isValid(self):
self.setItem(None)
- @deprecation.deprecated(reason="It is private data", since_version="0.13")
- def getHistogram(self):
- histo = self._getHistogram()
- if histo is None:
- return None
- counts, bin_edges = histo
- return numpy.array(counts, copy=True), numpy.array(bin_edges, copy=True)
-
def _getHistogram(self):
"""Returns the histogram defined by the dialog as metadata
to describe the data in order to speed up the dialog.
@@ -1429,7 +1475,7 @@ class ColormapDialog(qt.QDialog):
(xmin, xmax, ymin, ymax) Rectangular region in data space
"""
if bounds is None:
- return None # no-op
+ return # no-op
colormap = self.getColormap()
if colormap is None:
@@ -1437,13 +1483,15 @@ class ColormapDialog(qt.QDialog):
item = self._getItem()
if not isinstance(item, items.ColormapMixIn):
- return None # no-op
+ return # no-op
data = item.getColormappedData(copy=False)
-
xmin, xmax, ymin, ymax = bounds
if isinstance(item, items.ImageBase):
+ if data.ndim != 2:
+ return # no-op
+
ox, oy = item.getOrigin()
sx, sy = item.getScale()
@@ -1460,7 +1508,9 @@ class ColormapDialog(qt.QDialog):
subset = data[
numpy.logical_and(
numpy.logical_and(xmin <= x, x <= xmax),
- numpy.logical_and(ymin <= y, y <= ymax))]
+ numpy.logical_and(ymin <= y, y <= ymax),
+ )
+ ]
if subset.size == 0:
return # no-op
@@ -1563,19 +1613,21 @@ class ColormapDialog(qt.QDialog):
self._comboBoxColormap.setEnabled(colormap.isEditable())
with utils.blockSignals(self._comboBoxNormalization):
index = self._comboBoxNormalization.findData(
- colormap.getNormalization())
+ colormap.getNormalization()
+ )
if index < 0:
- _logger.error('Unsupported normalization: %s' %
- colormap.getNormalization())
+ _logger.error(
+ "Unsupported normalization: %s" % colormap.getNormalization()
+ )
else:
self._comboBoxNormalization.setCurrentIndex(index)
self._comboBoxNormalization.setEnabled(colormap.isEditable())
with utils.blockSignals(self._gammaSpinBox):
- self._gammaSpinBox.setValue(
- colormap.getGammaNormalizationParameter())
+ self._gammaSpinBox.setValue(colormap.getGammaNormalizationParameter())
self._gammaSpinBox.setEnabled(
- colormap.getNormalization() == Colormap.GAMMA and
- colormap.isEditable())
+ colormap.getNormalization() == Colormap.GAMMA
+ and colormap.isEditable()
+ )
with utils.blockSignals(self._autoScaleCombo):
self._autoScaleCombo.setCurrentMode(colormap.getAutoscaleMode())
self._autoScaleCombo.setEnabled(colormap.isEditable())
@@ -1624,8 +1676,8 @@ class ColormapDialog(qt.QDialog):
dataRange = self._getFiniteColormapRange()
# Final colormap range
- vmin = (dataRange[0] if not autoRange[0] else None)
- vmax = (dataRange[1] if not autoRange[1] else None)
+ vmin = dataRange[0] if not autoRange[0] else None
+ vmax = dataRange[1] if not autoRange[1] else None
with self._colormapChange:
colormap = self.getColormap()
@@ -1645,7 +1697,7 @@ class ColormapDialog(qt.QDialog):
colormap = self.getColormap()
if colormap is not None:
normalization = self._comboBoxNormalization.itemData(index)
- self._gammaSpinBox.setEnabled(normalization == 'gamma')
+ self._gammaSpinBox.setEnabled(normalization == "gamma")
with self._colormapChange:
colormap.setNormalization(normalization)
@@ -1745,7 +1797,7 @@ class ColormapDialog(qt.QDialog):
gamma = self._gammaSpinBox.minimum()
else:
gamma = numpy.clip(
- numpy.log(0.5)/numpy.log((gammaPos - vmin) / (vmax - vmin)),
+ numpy.log(0.5) / numpy.log((gammaPos - vmin) / (vmax - vmin)),
self._gammaSpinBox.minimum(),
self._gammaSpinBox.maximum(),
)
@@ -1771,7 +1823,9 @@ class ColormapDialog(qt.QDialog):
def _syncScaleToButtonsEnabled(self):
"""Set the state of scale to buttons according to current item and colormap"""
colormap = self.getColormap()
- enabled = self._item is not None and colormap is not None and colormap.isEditable()
+ enabled = (
+ self._item is not None and colormap is not None and colormap.isEditable()
+ )
self._scaleToAreaGroup.setVisible(enabled)
self._visibleAreaButton.setEnabled(enabled)
if not enabled:
@@ -1817,10 +1871,14 @@ class ColormapDialog(qt.QDialog):
self._roiForColormapManager = RegionOfInterestManager(parent=plotWidget)
cmap = self.getColormap()
self._roiForColormapManager.setColor(
- 'black' if cmap is None else cursorColorForColormap(cmap.getName()))
+ "black" if cmap is None else cursorColorForColormap(cmap.getName())
+ )
self._roiForColormapManager.sigInteractiveModeFinished.connect(
- self.__roiInteractiveModeFinished)
- self._roiForColormapManager.sigInteractiveRoiFinalized.connect(self.__roiFinalized)
+ self.__roiInteractiveModeFinished
+ )
+ self._roiForColormapManager.sigInteractiveRoiFinalized.connect(
+ self.__roiFinalized
+ )
self._roiForColormapManager.start(RectangleROI)
def __roiInteractiveModeFinished(self):
@@ -1830,7 +1888,7 @@ class ColormapDialog(qt.QDialog):
if roi is not None:
ox, oy = roi.getOrigin()
width, height = roi.getSize()
- self.setColormapRangeFromDataBounds((ox, ox+width, oy, oy+height))
+ self.setColormapRangeFromDataBounds((ox, ox + width, oy, oy + height))
# clear ROI
self._roiForColormapManager.removeRoi(roi)
diff --git a/src/silx/gui/dialog/DataFileDialog.py b/src/silx/gui/dialog/DataFileDialog.py
index 75b1721..4c6891e 100644
--- a/src/silx/gui/dialog/DataFileDialog.py
+++ b/src/silx/gui/dialog/DataFileDialog.py
@@ -36,8 +36,6 @@ from silx.gui.hdf5.Hdf5Formatter import Hdf5Formatter
import silx.io
from .AbstractDataFileDialog import AbstractDataFileDialog
-import fabio
-
_logger = logging.getLogger(__name__)
@@ -336,4 +334,4 @@ class DataFileDialog(AbstractDataFileDialog):
selection widget (basically the data from the browsing widget)
:rtype: bool
"""
- return u""
+ return ""
diff --git a/src/silx/gui/dialog/DatasetDialog.py b/src/silx/gui/dialog/DatasetDialog.py
index 5d8af0d..1bc2722 100644
--- a/src/silx/gui/dialog/DatasetDialog.py
+++ b/src/silx/gui/dialog/DatasetDialog.py
@@ -60,17 +60,22 @@ class DatasetDialog(_Hdf5ItemSelectionDialog):
print("Operation cancelled :(")
"""
+
def __init__(self, parent=None):
_Hdf5ItemSelectionDialog.__init__(self, parent)
# customization for groups
self.setWindowTitle("HDF5 dataset selection")
- self._header.setSections([self._model.NAME_COLUMN,
- self._model.NODE_COLUMN,
- self._model.LINK_COLUMN,
- self._model.TYPE_COLUMN,
- self._model.SHAPE_COLUMN])
+ self._header.setSections(
+ [
+ self._model.NAME_COLUMN,
+ self._model.NODE_COLUMN,
+ self._model.LINK_COLUMN,
+ self._model.TYPE_COLUMN,
+ self._model.SHAPE_COLUMN,
+ ]
+ )
self._selectDatasetStatusText = "Select a dataset or type a new dataset name"
def setMode(self, mode):
@@ -80,7 +85,9 @@ class DatasetDialog(_Hdf5ItemSelectionDialog):
"""
_Hdf5ItemSelectionDialog.setMode(self, mode)
if mode == DatasetDialog.SaveMode:
- self._selectDatasetStatusText = "Select a dataset or type a new dataset name"
+ self._selectDatasetStatusText = (
+ "Select a dataset or type a new dataset name"
+ )
elif mode == DatasetDialog.LoadMode:
self._selectDatasetStatusText = "Select a dataset"
@@ -110,11 +117,11 @@ class DatasetDialog(_Hdf5ItemSelectionDialog):
isDatasetSelected = True
if isDatasetSelected:
- self._selectedUrl = DataUrl(file_path=node.local_filename,
- data_path=data_path)
+ self._selectedUrl = DataUrl(
+ file_path=node.local_filename, data_path=data_path
+ )
self._okButton.setEnabled(True)
- self._labelSelection.setText(
- self._selectedUrl.path())
+ self._labelSelection.setText(self._selectedUrl.path())
else:
self._selectedUrl = None
self._okButton.setEnabled(False)
diff --git a/src/silx/gui/dialog/FileTypeComboBox.py b/src/silx/gui/dialog/FileTypeComboBox.py
index 0ffc3a5..85ad3b1 100644
--- a/src/silx/gui/dialog/FileTypeComboBox.py
+++ b/src/silx/gui/dialog/FileTypeComboBox.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2016 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -30,12 +30,13 @@ __license__ = "MIT"
__date__ = "17/01/2019"
import fabio
+from fabio import fabioutils
+
import silx.io
from silx.gui import qt
class Codec(object):
-
def __init__(self, any_fabio=False, any_silx=False, fabio_codec=None, auto=False):
self.__any_fabio = any_fabio
self.__any_silx = any_silx
@@ -63,7 +64,7 @@ class FileTypeComboBox(qt.QComboBox):
CODEC_ROLE = qt.Qt.UserRole + 2
- INDENTATION = u"\u2022 "
+ INDENTATION = "\u2022 "
def __init__(self, parent=None):
qt.QComboBox.__init__(self, parent)
@@ -134,20 +135,13 @@ class FileTypeComboBox(qt.QComboBox):
def __insertFabioFormats(self):
formats = fabio.fabioformats.get_classes(reader=True)
- from fabio import fabioutils
- if hasattr(fabioutils, "COMPRESSED_EXTENSIONS"):
- compressedExtensions = fabioutils.COMPRESSED_EXTENSIONS
- else:
- # Support for fabio < 0.9
- compressedExtensions = set(["gz", "bz2"])
-
extensions = []
allExtensions = set([])
def extensionsIterator(reader):
for extension in reader.DEFAULT_EXTENSIONS:
yield "*.%s" % extension
- for compressedExtension in compressedExtensions:
+ for compressedExtension in fabioutils.COMPRESSED_EXTENSIONS:
for extension in reader.DEFAULT_EXTENSIONS:
yield "*.%s.%s" % (extension, compressedExtension)
@@ -163,7 +157,9 @@ class FileTypeComboBox(qt.QComboBox):
allExtensions.update(ext)
if ext == []:
ext = ["*"]
- extensions.append((reader.DESCRIPTION, displayext, ext, reader.codec_name()))
+ extensions.append(
+ (reader.DESCRIPTION, displayext, ext, reader.codec_name())
+ )
extensions = list(sorted(extensions))
allExtensions = list(sorted(list(allExtensions)))
@@ -176,7 +172,9 @@ class FileTypeComboBox(qt.QComboBox):
description, displayExt, allExt, _codecName = e
index = self.count()
if len(e[1]) < 10:
- self.addItem("%s%s (%s)" % (self.INDENTATION, description, " ".join(displayExt)))
+ self.addItem(
+ "%s%s (%s)" % (self.INDENTATION, description, " ".join(displayExt))
+ )
else:
self.addItem("%s%s" % (self.INDENTATION, description))
codec = Codec(fabio_codec=_codecName)
diff --git a/src/silx/gui/dialog/GroupDialog.py b/src/silx/gui/dialog/GroupDialog.py
index fb85d83..ca669f2 100644
--- a/src/silx/gui/dialog/GroupDialog.py
+++ b/src/silx/gui/dialog/GroupDialog.py
@@ -54,8 +54,7 @@ class _Hdf5ItemSelectionDialog(qt.QDialog):
self._tree = Hdf5TreeView(self)
self._tree.setSelectionMode(qt.QAbstractItemView.SingleSelection)
self._tree.activated.connect(self._onActivation)
- self._tree.selectionModel().selectionChanged.connect(
- self._onSelectionChange)
+ self._tree.selectionModel().selectionChanged.connect(self._onSelectionChange)
self._model = self._tree.findHdf5TreeModel()
@@ -67,10 +66,9 @@ class _Hdf5ItemSelectionDialog(qt.QDialog):
self._labelNewItem.setText("Create new item in selected group (optional):")
self._lineEditNewItem = qt.QLineEdit(self._newItemWidget)
self._lineEditNewItem.setToolTip(
- "Specify the name of a new item "
- "to be created in the selected group.")
- self._lineEditNewItem.textChanged.connect(
- self._onNewItemNameChange)
+ "Specify the name of a new item " "to be created in the selected group."
+ )
+ self._lineEditNewItem.textChanged.connect(self._onNewItemNameChange)
newItemLayout.addWidget(self._labelNewItem)
newItemLayout.addWidget(self._lineEditNewItem)
@@ -151,11 +149,11 @@ class _Hdf5ItemSelectionDialog(qt.QDialog):
if not data_path.endswith("/"):
data_path += "/"
data_path += subgroupName.lstrip("/")
- self._selectedUrl = DataUrl(file_path=node.local_filename,
- data_path=data_path)
+ self._selectedUrl = DataUrl(
+ file_path=node.local_filename, data_path=data_path
+ )
self._okButton.setEnabled(True)
- self._labelSelection.setText(
- self._selectedUrl.path())
+ self._labelSelection.setText(self._selectedUrl.path())
def getSelectedDataUrl(self):
"""Return a :class:`DataUrl` with a file path and a data path.
@@ -189,15 +187,16 @@ class GroupDialog(_Hdf5ItemSelectionDialog):
print("Operation cancelled :(")
"""
+
def __init__(self, parent=None):
_Hdf5ItemSelectionDialog.__init__(self, parent)
# customization for groups
self.setWindowTitle("HDF5 group selection")
- self._header.setSections([self._model.NAME_COLUMN,
- self._model.NODE_COLUMN,
- self._model.LINK_COLUMN])
+ self._header.setSections(
+ [self._model.NAME_COLUMN, self._model.NODE_COLUMN, self._model.LINK_COLUMN]
+ )
def _onActivation(self, idx):
# double-click or enter press: filter for groups
@@ -218,11 +217,11 @@ class GroupDialog(_Hdf5ItemSelectionDialog):
if not data_path.endswith("/"):
data_path += "/"
data_path += subgroupName.lstrip("/")
- self._selectedUrl = DataUrl(file_path=node.local_filename,
- data_path=data_path)
+ self._selectedUrl = DataUrl(
+ file_path=node.local_filename, data_path=data_path
+ )
self._okButton.setEnabled(True)
- self._labelSelection.setText(
- self._selectedUrl.path())
+ self._labelSelection.setText(self._selectedUrl.path())
else:
self._selectedUrl = None
self._okButton.setEnabled(False)
diff --git a/src/silx/gui/dialog/ImageFileDialog.py b/src/silx/gui/dialog/ImageFileDialog.py
index ed455f3..e7ce38f 100644
--- a/src/silx/gui/dialog/ImageFileDialog.py
+++ b/src/silx/gui/dialog/ImageFileDialog.py
@@ -198,7 +198,9 @@ class _ImagePreview(qt.QWidget):
axis = self.__plot.getXAxis()
axis.setLimitsConstraints(midWidth - widthContraint, midWidth + widthContraint)
axis = self.__plot.getYAxis()
- axis.setLimitsConstraints(midHeight - heightContraint, midHeight + heightContraint)
+ axis.setLimitsConstraints(
+ midHeight - heightContraint, midHeight + heightContraint
+ )
def __imageItem(self):
image = self.__plot.getImage("data")
@@ -340,14 +342,14 @@ class ImageFileDialog(AbstractDataFileDialog):
"""
destination = self.__formatShape(dataAfterSelection.shape)
source = self.__formatShape(dataBeforeSelection.shape)
- return u"%s \u2192 %s" % (source, destination)
+ return "%s \u2192 %s" % (source, destination)
def __formatShape(self, shape):
result = []
for s in shape:
if isinstance(s, slice):
- v = u"\u2026"
+ v = "\u2026"
else:
v = str(s)
result.append(v)
- return u" \u00D7 ".join(result)
+ return " \u00D7 ".join(result)
diff --git a/src/silx/gui/dialog/SafeFileIconProvider.py b/src/silx/gui/dialog/SafeFileIconProvider.py
index 141bedf..7022876 100644
--- a/src/silx/gui/dialog/SafeFileIconProvider.py
+++ b/src/silx/gui/dialog/SafeFileIconProvider.py
@@ -91,6 +91,7 @@ class SafeFileIconProvider(qt.QFileIconProvider):
def __windowsDriveTypeId(self, info):
try:
import ctypes
+
path = info.filePath()
dtype = ctypes.cdll.kernel32.GetDriveTypeW(path)
except Exception:
diff --git a/src/silx/gui/dialog/SafeFileSystemModel.py b/src/silx/gui/dialog/SafeFileSystemModel.py
index b9f3913..7cacc1e 100644
--- a/src/silx/gui/dialog/SafeFileSystemModel.py
+++ b/src/silx/gui/dialog/SafeFileSystemModel.py
@@ -41,7 +41,6 @@ _logger = logging.getLogger(__name__)
class _Item(object):
-
def __init__(self, fileInfo):
self.__fileInfo = fileInfo
self.__parent = None
@@ -101,7 +100,9 @@ class _Item(object):
elif self.isDrive():
path = self.__fileInfo.filePath()
else:
- path = os.path.join(self.parent().absoluteFilePath(), self.__fileInfo.fileName())
+ path = os.path.join(
+ self.parent().absoluteFilePath(), self.__fileInfo.fileName()
+ )
if path == "":
return "/"
self.__absolutePath = path
@@ -236,7 +237,9 @@ class _RawFileSystemModel(qt.QAbstractItemModel):
self.__header = "Name", "Size", "Type", "Last modification"
self.__currentPath = ""
self.__iconProvider = SafeFileIconProvider()
- self.__directoryLoadedSync.connect(self.__emitDirectoryLoaded, qt.Qt.QueuedConnection)
+ self.__directoryLoadedSync.connect(
+ self.__emitDirectoryLoaded, qt.Qt.QueuedConnection
+ )
def headerData(self, section, orientation, role=qt.Qt.DisplayRole):
if orientation == qt.Qt.Horizontal:
@@ -496,7 +499,7 @@ class _RawFileSystemModel(qt.QAbstractItemModel):
return
def setReadOnly(self, enable):
- assert(enable is True)
+ assert enable is True
def isReadOnly(self):
return False
@@ -612,20 +615,20 @@ class SafeFileSystemModel(qt.QSortFilterProxyModel):
filterPermissions = (filters & qt.QDir.PermissionMask) != 0
if filterPermissions and (filters & (qt.QDir.Dirs | qt.QDir.Files)):
- if (filters & qt.QDir.Readable):
+ if filters & qt.QDir.Readable:
# Hide unreadable
if not fileInfo.isReadable():
return False
- if (filters & qt.QDir.Writable):
+ if filters & qt.QDir.Writable:
# Hide unwritable
if not fileInfo.isWritable():
return False
- if (filters & qt.QDir.Executable):
+ if filters & qt.QDir.Executable:
# Hide unexecutable
if not fileInfo.isExecutable():
return False
- if (filters & qt.QDir.NoSymLinks):
+ if filters & qt.QDir.NoSymLinks:
# Hide sym links
if fileInfo.isSymLink():
return False
@@ -711,7 +714,9 @@ class SafeFileSystemModel(qt.QSortFilterProxyModel):
def setNameFilters(self, filters):
self.__nameFilters = []
isCaseSensitive = self.__filters & qt.QDir.CaseSensitive
- caseSensitive = qt.Qt.CaseSensitive if isCaseSensitive else qt.Qt.CaseInsensitive
+ caseSensitive = (
+ qt.Qt.CaseSensitive if isCaseSensitive else qt.Qt.CaseInsensitive
+ )
for f in filters:
reg = qt.QRegExp(f, caseSensitive, qt.QRegExp.Wildcard)
self.__nameFilters.append(reg)
@@ -730,7 +735,7 @@ class SafeFileSystemModel(qt.QSortFilterProxyModel):
self.invalidate()
def setReadOnly(self, enable):
- assert(enable is True)
+ assert enable is True
def isReadOnly(self):
return False
diff --git a/src/silx/gui/dialog/test/test_colormapdialog.py b/src/silx/gui/dialog/test/test_colormapdialog.py
index 1bfd584..1afafc0 100644
--- a/src/silx/gui/dialog/test/test_colormapdialog.py
+++ b/src/silx/gui/dialog/test/test_colormapdialog.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2016-2022 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2024 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -29,371 +29,369 @@ __date__ = "09/11/2018"
import pytest
-import weakref
from silx.gui import qt
from silx.gui.dialog import ColormapDialog
-from silx.gui.utils.testutils import TestCaseQt
from silx.gui.colors import Colormap, preferredColormaps
-from silx.utils.testutils import ParametricTestCase
from silx.gui.plot.items.image import ImageData
import numpy
-@pytest.fixture
-def colormap():
- colormap = Colormap(name='gray',
- vmin=10.0, vmax=20.0,
- normalization='linear')
- yield colormap
+def testGUIEdition(qWidgetFactory):
+ """Make sure the colormap is correctly edited and also that the
+ modification are correctly updated if an other colormapdialog is
+ editing the same colormap"""
+ colormap = Colormap(name="gray", vmin=10.0, vmax=20.0, normalization="linear")
+ dialog = qWidgetFactory(ColormapDialog.ColormapDialog)
+ dialog.setColormap(colormap)
+ dialog2 = qWidgetFactory(ColormapDialog.ColormapDialog)
+ dialog2.setColormap(colormap)
+
+ dialog._comboBoxColormap._setCurrentName("red")
+ dialog._comboBoxNormalization.setCurrentIndex(
+ dialog._comboBoxNormalization.findData(Colormap.LOGARITHM)
+ )
+ assert colormap.getName() == "red"
+ assert dialog.getColormap().getName() == "red"
+ assert colormap.getNormalization() == "log"
+ assert colormap.getVMin() == 10
+ assert colormap.getVMax() == 20
+ # checked second colormap dialog
+ assert dialog2._comboBoxColormap.getCurrentName() == "red"
+ assert dialog2._comboBoxNormalization.currentData() == Colormap.LOGARITHM
+ assert int(dialog2._minValue.getValue()) == 10
+ assert int(dialog2._maxValue.getValue()) == 20
+
+
+def testGUIModalOk(qapp, qapp_utils, qWidgetFactory):
+ """Make sure the colormap is modified if gone through accept"""
+ colormap = Colormap(name="gray", vmin=10.0, vmax=20.0, normalization="linear")
+ assert colormap.isAutoscale() is False
+ dialog = qWidgetFactory(ColormapDialog.ColormapDialog)
+ dialog.setModal(True)
+ qapp.processEvents()
+
+ dialog.setColormap(colormap)
+ assert colormap.getVMin() is not None
+ dialog._minValue.sigAutoScaleChanged.emit(True)
+ assert colormap.getVMin() is None
+ dialog._maxValue.sigAutoScaleChanged.emit(True)
+ qapp_utils.mouseClick(
+ widget=dialog._buttonsModal.button(qt.QDialogButtonBox.Ok),
+ button=qt.Qt.LeftButton,
+ )
+ assert colormap.getVMin() is None
+ assert colormap.getVMax() is None
+ assert colormap.isAutoscale() is True
+
+
+def testGUIModalCancel(qapp, qapp_utils, qWidgetFactory):
+ """Make sure the colormap is not modified if gone through reject"""
+ colormap = Colormap(name="gray", vmin=10.0, vmax=20.0, normalization="linear")
+ assert colormap.isAutoscale() is False
+ dialog = qWidgetFactory(ColormapDialog.ColormapDialog)
+ dialog.setModal(True)
+ qapp.processEvents()
+
+ dialog.setColormap(colormap)
+ assert colormap.getVMin() is not None
+ dialog._minValue.sigAutoScaleChanged.emit(True)
+ assert colormap.getVMin() is None
+ qapp_utils.mouseClick(
+ widget=dialog._buttonsModal.button(qt.QDialogButtonBox.Cancel),
+ button=qt.Qt.LeftButton,
+ )
+ assert colormap.getVMin() is not None
+
+
+def testGUIModalClose(qapp, qapp_utils, qWidgetFactory):
+ colormap = Colormap(name="gray", vmin=10.0, vmax=20.0, normalization="linear")
+ assert colormap.isAutoscale() is False
+ dialog = qWidgetFactory(ColormapDialog.ColormapDialog)
+ dialog.setModal(False)
+ qapp.processEvents()
+ dialog.setColormap(colormap)
+ assert colormap.getVMin() is not None
+ dialog._minValue.sigAutoScaleChanged.emit(True)
+ assert colormap.getVMin() is None
+ qapp_utils.mouseClick(
+ widget=dialog._buttonsNonModal.button(qt.QDialogButtonBox.Close),
+ button=qt.Qt.LeftButton,
+ )
+ assert colormap.getVMin() is None
+
+
+def testGUIModalReset(qapp, qapp_utils, qWidgetFactory):
+ colormap = Colormap(name="gray", vmin=10.0, vmax=20.0, normalization="linear")
+ assert colormap.isAutoscale() is False
+ dialog = qWidgetFactory(ColormapDialog.ColormapDialog)
+ dialog.setModal(False)
+ dialog.show()
+ qapp.processEvents()
+ dialog.setColormap(colormap)
+ assert colormap.getVMin() is not None
+ dialog._minValue.sigAutoScaleChanged.emit(True)
+ assert colormap.getVMin() is None
+ qapp_utils.mouseClick(
+ widget=dialog._buttonsNonModal.button(qt.QDialogButtonBox.Reset),
+ button=qt.Qt.LeftButton,
+ )
+ assert colormap.getVMin() is not None
+ dialog.close()
+
+
+def testGUIClose(qapp, qWidgetFactory):
+ """Make sure the colormap is modify if go through reject"""
+ colormap = Colormap(name="gray", vmin=10.0, vmax=20.0, normalization="linear")
+ dialog = qWidgetFactory(ColormapDialog.ColormapDialog)
+ assert colormap.isAutoscale() is False
+ qapp.processEvents()
+
+ dialog.setColormap(colormap)
+ assert colormap.getVMin() is not None
+ dialog._minValue.sigAutoScaleChanged.emit(True)
+ assert colormap.getVMin() is None
+ dialog.close()
+ qapp.processEvents()
+ assert colormap.getVMin() is None
+
+
+@pytest.mark.parametrize("norm", Colormap.NORMALIZATIONS)
+@pytest.mark.parametrize("autoscale", (True, False))
+def testSetColormapIsCorrect(norm, autoscale, qapp, qWidgetFactory):
+ """Make sure the interface fir the colormap when set a new colormap"""
+ dialog = qWidgetFactory(ColormapDialog.ColormapDialog)
+
+ colormap = Colormap(name="gray", vmin=10.0, vmax=20.0, normalization="linear")
+ colormap.setName("red")
+ if autoscale is True:
+ colormap.setVRange(None, None)
+ else:
+ colormap.setVRange(11, 101)
+ colormap.setNormalization(norm)
+ dialog.setColormap(colormap)
+ qapp.processEvents()
-@pytest.fixture
-def colormapDialog(qapp):
- dialog = ColormapDialog.ColormapDialog()
- dialog.setAttribute(qt.Qt.WA_DeleteOnClose)
- yield weakref.proxy(dialog)
+ assert dialog._comboBoxNormalization.currentData() == norm
+ assert dialog._comboBoxColormap.getCurrentName() == "red"
+ assert dialog._minValue.isAutoChecked() == autoscale
+ assert dialog._maxValue.isAutoChecked() == autoscale
+ if autoscale is False:
+ assert dialog._minValue.getValue() == 11
+ assert dialog._maxValue.getValue() == 101
+ assert dialog._minValue.isEnabled()
+ assert dialog._maxValue.isEnabled()
+ else:
+ assert dialog._minValue._numVal.isReadOnly()
+ assert dialog._maxValue._numVal.isReadOnly()
+
+
+def testColormapDel(qapp, qWidgetFactory):
+ """Check behavior if the colormap has been deleted outside. For now
+ we make sure the colormap is still running and nothing more"""
+ dialog = qWidgetFactory(ColormapDialog.ColormapDialog)
+ colormap = Colormap(name="gray")
+ dialog.setColormap(colormap)
qapp.processEvents()
- from silx.gui.qt import inspect
- if inspect.isValid(dialog):
- dialog.close()
- del dialog
- qapp.processEvents()
+ colormap = None
+ assert dialog.getColormap() is None
+ dialog._comboBoxColormap._setCurrentName("blue")
-@pytest.fixture
-def colormap_class_attr(request, qapp_utils, colormap, colormapDialog):
- """Provides few fixtures to a class as class attribute
- Used as transition from TestCase to pytest
+def testColormapEditedOutside(qapp, qWidgetFactory):
+ """Make sure the GUI is still up to date if the colormap is modified
+ outside"""
+ colormap = Colormap(name="gray", vmin=10.0, vmax=20.0, normalization="linear")
+ dialog = qWidgetFactory(ColormapDialog.ColormapDialog)
+ dialog.setColormap(colormap)
+ qapp.processEvents()
+
+ colormap.setName("red")
+ assert dialog._comboBoxColormap.getCurrentName() == "red"
+ colormap.setNormalization(Colormap.LOGARITHM)
+ assert dialog._comboBoxNormalization.currentData() == Colormap.LOGARITHM
+ colormap.setVRange(11, 201)
+ assert dialog._minValue.getValue() == 11
+ assert dialog._maxValue.getValue() == 201
+ assert not (dialog._minValue._numVal.isReadOnly())
+ assert not (dialog._maxValue._numVal.isReadOnly())
+ assert not (dialog._minValue.isAutoChecked())
+ assert not (dialog._maxValue.isAutoChecked())
+ colormap.setVRange(None, None)
+ qapp.processEvents()
+
+ assert dialog._minValue._numVal.isReadOnly()
+ assert dialog._maxValue._numVal.isReadOnly()
+ assert dialog._minValue.isAutoChecked()
+ assert dialog._maxValue.isAutoChecked()
+
+
+def testSetColormapScenario(qWidgetFactory):
+ """Test of a simple scenario of a colormap dialog editing several
+ colormap"""
+ dialog = qWidgetFactory(ColormapDialog.ColormapDialog)
+ colormap = Colormap(name="gray", vmin=10.0, vmax=20.0, normalization="linear")
+ colormap1 = Colormap(name="gray", vmin=10.0, vmax=20.0, normalization="linear")
+ colormap2 = Colormap(name="red", vmin=10.0, vmax=20.0, normalization="log")
+ colormap3 = Colormap(name="blue", vmin=None, vmax=None, normalization="linear")
+
+ dialog.setColormap(colormap)
+ dialog.setColormap(colormap1)
+ del colormap1
+ dialog.setColormap(colormap2)
+ del colormap2
+ dialog.setColormap(colormap3)
+ del colormap3
+
+
+def testNotPreferredColormap(qapp, qWidgetFactory):
+ """Test that the colormapEditor is able to edit a colormap which is not
+ part of the 'prefered colormap'
"""
- request.cls.qapp_utils = qapp_utils
- request.cls.colormap = colormap
- request.cls.colormapDiag = colormapDialog
- yield
- request.cls.qapp_utils = None
- request.cls.colormap = None
- request.cls.colormapDiag = None
-
-
-@pytest.mark.usefixtures("colormap_class_attr")
-class TestColormapDialog(TestCaseQt, ParametricTestCase):
-
- def testGUIEdition(self):
- """Make sure the colormap is correctly edited and also that the
- modification are correctly updated if an other colormapdialog is
- editing the same colormap"""
- colormapDiag2 = ColormapDialog.ColormapDialog()
- colormapDiag2.setAttribute(qt.Qt.WA_DeleteOnClose)
- colormapDiag2.setColormap(self.colormap)
- colormapDiag2.show()
- self.colormapDiag.setColormap(self.colormap)
- self.colormapDiag.show()
- self.qapp.processEvents()
-
- self.colormapDiag._comboBoxColormap._setCurrentName('red')
- self.colormapDiag._comboBoxNormalization.setCurrentIndex(
- self.colormapDiag._comboBoxNormalization.findData(Colormap.LOGARITHM))
- self.assertTrue(self.colormap.getName() == 'red')
- self.assertTrue(self.colormapDiag.getColormap().getName() == 'red')
- self.assertTrue(self.colormap.getNormalization() == 'log')
- self.assertTrue(self.colormap.getVMin() == 10)
- self.assertTrue(self.colormap.getVMax() == 20)
- # checked second colormap dialog
- self.assertTrue(colormapDiag2._comboBoxColormap.getCurrentName() == 'red')
- self.assertEqual(colormapDiag2._comboBoxNormalization.currentData(),
- Colormap.LOGARITHM)
- self.assertTrue(int(colormapDiag2._minValue.getValue()) == 10)
- self.assertTrue(int(colormapDiag2._maxValue.getValue()) == 20)
- colormapDiag2.close()
- del colormapDiag2
- self.qapp.processEvents()
-
- def testGUIModalOk(self):
- """Make sure the colormap is modified if gone through accept"""
- assert self.colormap.isAutoscale() is False
- self.colormapDiag.setModal(True)
- self.colormapDiag.show()
- self.qapp.processEvents()
- self.colormapDiag.setColormap(self.colormap)
- self.assertTrue(self.colormap.getVMin() is not None)
- self.colormapDiag._minValue.sigAutoScaleChanged.emit(True)
- self.assertTrue(self.colormap.getVMin() is None)
- self.colormapDiag._maxValue.sigAutoScaleChanged.emit(True)
- self.mouseClick(
- widget=self.colormapDiag._buttonsModal.button(qt.QDialogButtonBox.Ok),
- button=qt.Qt.LeftButton
- )
- self.assertTrue(self.colormap.getVMin() is None)
- self.assertTrue(self.colormap.getVMax() is None)
- self.assertTrue(self.colormap.isAutoscale() is True)
-
- def testGUIModalCancel(self):
- """Make sure the colormap is not modified if gone through reject"""
- assert self.colormap.isAutoscale() is False
- self.colormapDiag.setModal(True)
- self.colormapDiag.show()
- self.qapp.processEvents()
- self.colormapDiag.setColormap(self.colormap)
- self.assertTrue(self.colormap.getVMin() is not None)
- self.colormapDiag._minValue.sigAutoScaleChanged.emit(True)
- self.assertTrue(self.colormap.getVMin() is None)
- self.mouseClick(
- widget=self.colormapDiag._buttonsModal.button(qt.QDialogButtonBox.Cancel),
- button=qt.Qt.LeftButton
- )
- self.assertTrue(self.colormap.getVMin() is not None)
-
- def testGUIModalClose(self):
- assert self.colormap.isAutoscale() is False
- self.colormapDiag.setModal(False)
- self.colormapDiag.show()
- self.qapp.processEvents()
- self.colormapDiag.setColormap(self.colormap)
- self.assertTrue(self.colormap.getVMin() is not None)
- self.colormapDiag._minValue.sigAutoScaleChanged.emit(True)
- self.assertTrue(self.colormap.getVMin() is None)
- self.mouseClick(
- widget=self.colormapDiag._buttonsNonModal.button(qt.QDialogButtonBox.Close),
- button=qt.Qt.LeftButton
- )
- self.assertTrue(self.colormap.getVMin() is None)
-
- def testGUIModalReset(self):
- assert self.colormap.isAutoscale() is False
- self.colormapDiag.setModal(False)
- self.colormapDiag.show()
- self.qapp.processEvents()
- self.colormapDiag.setColormap(self.colormap)
- self.assertTrue(self.colormap.getVMin() is not None)
- self.colormapDiag._minValue.sigAutoScaleChanged.emit(True)
- self.assertTrue(self.colormap.getVMin() is None)
- self.mouseClick(
- widget=self.colormapDiag._buttonsNonModal.button(qt.QDialogButtonBox.Reset),
- button=qt.Qt.LeftButton
- )
- self.assertTrue(self.colormap.getVMin() is not None)
- self.colormapDiag.close()
-
- def testGUIClose(self):
- """Make sure the colormap is modify if go through reject"""
- assert self.colormap.isAutoscale() is False
- self.colormapDiag.show()
- self.qapp.processEvents()
- self.colormapDiag.setColormap(self.colormap)
- self.assertTrue(self.colormap.getVMin() is not None)
- self.colormapDiag._minValue.sigAutoScaleChanged.emit(True)
- self.assertTrue(self.colormap.getVMin() is None)
- self.colormapDiag.close()
- self.qapp.processEvents()
- self.assertTrue(self.colormap.getVMin() is None)
-
- def testSetColormapIsCorrect(self):
- """Make sure the interface fir the colormap when set a new colormap"""
- self.colormap.setName('red')
- self.colormapDiag.show()
- self.qapp.processEvents()
- for norm in (Colormap.NORMALIZATIONS):
- for autoscale in (True, False):
- if autoscale is True:
- self.colormap.setVRange(None, None)
- else:
- self.colormap.setVRange(11, 101)
- self.colormap.setNormalization(norm)
- with self.subTest(colormap=self.colormap):
- self.colormapDiag.setColormap(self.colormap)
- self.assertEqual(
- self.colormapDiag._comboBoxNormalization.currentData(), norm)
- self.assertTrue(
- self.colormapDiag._comboBoxColormap.getCurrentName() == 'red')
- self.assertTrue(
- self.colormapDiag._minValue.isAutoChecked() == autoscale)
- self.assertTrue(
- self.colormapDiag._maxValue.isAutoChecked() == autoscale)
- if autoscale is False:
- self.assertTrue(self.colormapDiag._minValue.getValue() == 11)
- self.assertTrue(self.colormapDiag._maxValue.getValue() == 101)
- self.assertTrue(self.colormapDiag._minValue.isEnabled())
- self.assertTrue(self.colormapDiag._maxValue.isEnabled())
- else:
- self.assertTrue(self.colormapDiag._minValue._numVal.isReadOnly())
- self.assertTrue(self.colormapDiag._maxValue._numVal.isReadOnly())
-
- def testColormapDel(self):
- """Check behavior if the colormap has been deleted outside. For now
- we make sure the colormap is still running and nothing more"""
- colormap = Colormap(name='gray')
- self.colormapDiag.setColormap(colormap)
- self.colormapDiag.show()
- self.qapp.processEvents()
- colormap = None
- self.assertTrue(self.colormapDiag.getColormap() is None)
- self.colormapDiag._comboBoxColormap._setCurrentName('blue')
-
- def testColormapEditedOutside(self):
- """Make sure the GUI is still up to date if the colormap is modified
- outside"""
- self.colormapDiag.setColormap(self.colormap)
- self.colormapDiag.show()
- self.qapp.processEvents()
-
- self.colormap.setName('red')
- self.assertTrue(
- self.colormapDiag._comboBoxColormap.getCurrentName() == 'red')
- self.colormap.setNormalization(Colormap.LOGARITHM)
- self.assertEqual(self.colormapDiag._comboBoxNormalization.currentData(),
- Colormap.LOGARITHM)
- self.colormap.setVRange(11, 201)
- self.assertTrue(self.colormapDiag._minValue.getValue() == 11)
- self.assertTrue(self.colormapDiag._maxValue.getValue() == 201)
- self.assertFalse(self.colormapDiag._minValue._numVal.isReadOnly())
- self.assertFalse(self.colormapDiag._maxValue._numVal.isReadOnly())
- self.assertFalse(self.colormapDiag._minValue.isAutoChecked())
- self.assertFalse(self.colormapDiag._maxValue.isAutoChecked())
- self.colormap.setVRange(None, None)
- self.qapp.processEvents()
- self.assertTrue(self.colormapDiag._minValue._numVal.isReadOnly())
- self.assertTrue(self.colormapDiag._maxValue._numVal.isReadOnly())
- self.assertTrue(self.colormapDiag._minValue.isAutoChecked())
- self.assertTrue(self.colormapDiag._maxValue.isAutoChecked())
-
- def testSetColormapScenario(self):
- """Test of a simple scenario of a colormap dialog editing several
- colormap"""
- colormap1 = Colormap(name='gray', vmin=10.0, vmax=20.0,
- normalization='linear')
- colormap2 = Colormap(name='red', vmin=10.0, vmax=20.0,
- normalization='log')
- colormap3 = Colormap(name='blue', vmin=None, vmax=None,
- normalization='linear')
- self.colormapDiag.setColormap(self.colormap)
- self.colormapDiag.setColormap(colormap1)
- del colormap1
- self.colormapDiag.setColormap(colormap2)
- del colormap2
- self.colormapDiag.setColormap(colormap3)
- del colormap3
-
- def testNotPreferredColormap(self):
- """Test that the colormapEditor is able to edit a colormap which is not
- part of the 'prefered colormap'
- """
- def getFirstNotPreferredColormap():
- cms = Colormap.getSupportedColormaps()
- preferred = preferredColormaps()
- for cm in cms:
- if cm not in preferred:
- return cm
- return None
-
- colormapName = getFirstNotPreferredColormap()
- assert colormapName is not None
- colormap = Colormap(name=colormapName)
- self.colormapDiag.setColormap(colormap)
- self.colormapDiag.show()
- self.qapp.processEvents()
- cb = self.colormapDiag._comboBoxColormap
- self.assertTrue(cb.getCurrentName() == colormapName)
- cb.setCurrentIndex(0)
- index = cb.findLutName(colormapName)
- assert index != 0 # if 0 then the rest of the test has no sense
- cb.setCurrentIndex(index)
- self.assertTrue(cb.getCurrentName() == colormapName)
-
- def testColormapEditableMode(self):
- """Test that the colormapDialog is correctly updated when changing the
- colormap editable status"""
- colormap = Colormap(normalization='linear', vmin=1.0, vmax=10.0)
- self.colormapDiag.show()
- self.qapp.processEvents()
- self.colormapDiag.setColormap(colormap)
- for editable in (True, False):
- with self.subTest(editable=editable):
- colormap.setEditable(editable)
- self.assertTrue(
- self.colormapDiag._comboBoxColormap.isEnabled() is editable)
- self.assertTrue(
- self.colormapDiag._minValue.isEnabled() is editable)
- self.assertTrue(
- self.colormapDiag._maxValue.isEnabled() is editable)
- self.assertTrue(
- self.colormapDiag._comboBoxNormalization.isEnabled() is editable)
-
- # Make sure the reset button is also set to enable when edition mode is
- # False
- self.colormapDiag.setModal(False)
- colormap.setEditable(True)
- self.colormapDiag._comboBoxNormalization.setCurrentIndex(
- self.colormapDiag._comboBoxNormalization.findData(Colormap.LOGARITHM))
- resetButton = self.colormapDiag._buttonsNonModal.button(qt.QDialogButtonBox.Reset)
- self.assertTrue(resetButton.isEnabled())
- colormap.setEditable(False)
- self.assertFalse(resetButton.isEnabled())
-
- def testImageData(self):
- data = numpy.random.rand(5, 5)
- self.colormapDiag.setData(data)
-
- def testEmptyData(self):
- data = numpy.empty((10, 0))
- self.colormapDiag.setData(data)
-
- def testNoneData(self):
- data = numpy.random.rand(5, 5)
- self.colormapDiag.setData(data)
- self.colormapDiag.setData(None)
-
- def testImageItem(self):
- """Check that an ImageData plot item can be used"""
- dialog = self.colormapDiag
- colormap = Colormap(name='gray', vmin=None, vmax=None)
- data = numpy.arange(3**2).reshape(3, 3)
- item = ImageData()
- item.setData(data, copy=False)
-
- dialog.setColormap(colormap)
- dialog.show()
- self.qapp.processEvents()
- dialog.setItem(item)
- vrange = dialog._getFiniteColormapRange()
- self.assertEqual(vrange, (0, 8))
-
- def testItemDel(self):
- """Check that the plot items are not hard linked to the dialog"""
- dialog = self.colormapDiag
- colormap = Colormap(name='gray', vmin=None, vmax=None)
- data = numpy.arange(3**2).reshape(3, 3)
- item = ImageData()
- item.setData(data, copy=False)
-
- dialog.setColormap(colormap)
- dialog.show()
- self.qapp.processEvents()
- dialog.setItem(item)
- previousRange = dialog._getFiniteColormapRange()
- del item
- vrange = dialog._getFiniteColormapRange()
- self.assertNotEqual(vrange, previousRange)
-
- def testDataDel(self):
- """Check that the data are not hard linked to the dialog"""
- dialog = self.colormapDiag
- colormap = Colormap(name='gray', vmin=None, vmax=None)
- data = numpy.arange(5)
-
- dialog.setColormap(colormap)
- dialog.show()
- self.qapp.processEvents()
- dialog.setData(data)
- previousRange = dialog._getFiniteColormapRange()
- del data
- vrange = dialog._getFiniteColormapRange()
- self.assertNotEqual(vrange, previousRange)
-
- def testDeleteWhileExec(self):
- colormapDiag = self.colormapDiag
- self.colormapDiag = None
- qt.QTimer.singleShot(1000, colormapDiag.deleteLater)
- result = colormapDiag.exec()
- self.assertEqual(result, 0)
+
+ def getFirstNotPreferredColormap():
+ cms = Colormap.getSupportedColormaps()
+ preferred = preferredColormaps()
+ for cm in cms:
+ if cm not in preferred:
+ return cm
+ return None
+
+ dialog = qWidgetFactory(ColormapDialog.ColormapDialog)
+ colormapName = getFirstNotPreferredColormap()
+ assert colormapName is not None
+ colormap = Colormap(name=colormapName)
+ dialog.setColormap(colormap)
+ qapp.processEvents()
+
+ cb = dialog._comboBoxColormap
+ assert cb.getCurrentName() == colormapName
+ cb.setCurrentIndex(0)
+ index = cb.findLutName(colormapName)
+ assert index != 0 # if 0 then the rest of the test has no sense
+ cb.setCurrentIndex(index)
+ assert cb.getCurrentName() == colormapName
+
+
+def testColormapEditableMode(qWidgetFactory):
+ """Test that the colormapDialog is correctly updated when changing the
+ colormap editable status"""
+ dialog = qWidgetFactory(ColormapDialog.ColormapDialog)
+ colormap = Colormap(normalization="linear", vmin=1.0, vmax=10.0)
+
+ dialog.setColormap(colormap)
+
+ for editable in (True, False):
+ colormap.setEditable(editable)
+ assert dialog._comboBoxColormap.isEnabled() is editable
+ assert dialog._minValue.isEnabled() is editable
+ assert dialog._maxValue.isEnabled() is editable
+ assert dialog._comboBoxNormalization.isEnabled() is editable
+
+ # Make sure the reset button is also set to enable when edition mode is
+ # False
+ dialog.setModal(False)
+ colormap.setEditable(True)
+ dialog._comboBoxNormalization.setCurrentIndex(
+ dialog._comboBoxNormalization.findData(Colormap.LOGARITHM)
+ )
+ resetButton = dialog._buttonsNonModal.button(qt.QDialogButtonBox.Reset)
+ assert resetButton.isEnabled()
+ colormap.setEditable(False)
+ assert not (resetButton.isEnabled())
+
+
+def testImageData(qWidgetFactory):
+ dialog = qWidgetFactory(ColormapDialog.ColormapDialog)
+ data = numpy.random.rand(5, 5)
+ dialog.setData(data)
+
+
+def testEmptyData(qWidgetFactory):
+ dialog = qWidgetFactory(ColormapDialog.ColormapDialog)
+ data = numpy.empty((10, 0))
+ dialog.setData(data)
+
+
+def testNoneData(qWidgetFactory):
+ dialog = qWidgetFactory(ColormapDialog.ColormapDialog)
+ data = numpy.random.rand(5, 5)
+ dialog.setData(data)
+ dialog.setData(None)
+
+
+def testImageItem(qapp, qWidgetFactory):
+ """Check that an ImageData plot item can be used"""
+ dialog = qWidgetFactory(ColormapDialog.ColormapDialog)
+ colormap = Colormap(name="gray", vmin=None, vmax=None)
+ data = numpy.arange(3**2).reshape(3, 3)
+ item = ImageData()
+ item.setData(data, copy=False)
+
+ dialog.setColormap(colormap)
+ qapp.processEvents()
+
+ dialog.setItem(item)
+ vrange = dialog._getFiniteColormapRange()
+ assert vrange == (0, 8)
+
+
+def testItemDel(qapp, qWidgetFactory):
+ """Check that the plot items are not hard linked to the dialog"""
+ dialog = qWidgetFactory(ColormapDialog.ColormapDialog)
+ colormap = Colormap(name="gray", vmin=None, vmax=None)
+ data = numpy.arange(3**2).reshape(3, 3)
+ item = ImageData()
+ item.setData(data, copy=False)
+
+ dialog.setColormap(colormap)
+ dialog.show()
+ qapp.processEvents()
+ dialog.setItem(item)
+ previousRange = dialog._getFiniteColormapRange()
+ del item
+ vrange = dialog._getFiniteColormapRange()
+ assert vrange != previousRange
+
+
+def testDataDel(qapp, qWidgetFactory):
+ """Check that the data are not hard linked to the dialog"""
+ dialog = qWidgetFactory(ColormapDialog.ColormapDialog)
+ colormap = Colormap(name="gray", vmin=None, vmax=None)
+ data = numpy.arange(5)
+
+ dialog.setColormap(colormap)
+ qapp.processEvents()
+
+ dialog.setData(data)
+ previousRange = dialog._getFiniteColormapRange()
+ del data
+ vrange = dialog._getFiniteColormapRange()
+ assert vrange != previousRange
+
+
+def testDeleteWhileExec(qWidgetFactory):
+ dialog = qWidgetFactory(ColormapDialog.ColormapDialog)
+ qt.QTimer.singleShot(1000, dialog.deleteLater)
+ result = dialog.exec()
+ assert result == 0
+
+
+def testUpdateImageData(qapp, qWidgetFactory):
+ """Test that range/histogram takes into account item updates"""
+ dialog = qWidgetFactory(ColormapDialog.ColormapDialog)
+
+ item = ImageData()
+ item.setColormap(Colormap())
+ dialog.setItem(item)
+ dialog.setColormap(item.getColormap())
+ qapp.processEvents()
+
+ assert dialog._histoWidget.getFiniteRange() == (0, 1)
+
+ item.setData([(1, 2), (3, 4)])
+
+ assert dialog._histoWidget.getFiniteRange() == (1, 4)
diff --git a/src/silx/gui/dialog/test/test_datafiledialog.py b/src/silx/gui/dialog/test/test_datafiledialog.py
index 32d75c2..887ff1c 100644
--- a/src/silx/gui/dialog/test/test_datafiledialog.py
+++ b/src/silx/gui/dialog/test/test_datafiledialog.py
@@ -28,7 +28,6 @@ __license__ = "MIT"
__date__ = "08/03/2019"
-import unittest
import tempfile
import numpy
import shutil
@@ -65,7 +64,7 @@ def setUpModule():
f["complex_image"] = data * 1j
f["group/image"] = data
f["nxdata/foo"] = 10
- f["nxdata"].attrs["NX_class"] = u"NXdata"
+ f["nxdata"].attrs["NX_class"] = "NXdata"
f.close()
directory = os.path.join(_tmpDirectory, "data")
@@ -78,7 +77,7 @@ def setUpModule():
f["complex_image"] = data * 1j
f["group/image"] = data
f["nxdata/foo"] = 10
- f["nxdata"].attrs["NX_class"] = u"NXdata"
+ f["nxdata"].attrs["NX_class"] = "NXdata"
f.close()
filename = _tmpDirectory + "/badformat.h5"
@@ -99,7 +98,6 @@ def tearDownModule():
class _UtilsMixin(object):
-
def createDialog(self):
self._deleteDialog()
self._dialog = self._createDialog()
@@ -139,7 +137,6 @@ class _UtilsMixin(object):
class TestDataFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
-
def tearDown(self):
self._deleteDialog()
testutils.TestCaseQt.tearDown(self)
@@ -219,7 +216,11 @@ class TestDataFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
self.qWaitForPendingActions(dialog)
# select, then double click on the file
- index = browser.rootIndex().model().indexFromH5Object(dialog._AbstractDataFileDialog__h5["/group"])
+ index = (
+ browser.rootIndex()
+ .model()
+ .indexFromH5Object(dialog._AbstractDataFileDialog__h5["/group"])
+ )
browser.selectIndex(index)
browser.activated.emit(index)
self.qWaitForPendingActions(dialog)
@@ -249,7 +250,11 @@ class TestDataFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
self.qWaitForPendingActions(dialog)
# select, then double click on the file
- index = browser.rootIndex().model().indexFromH5Object(dialog._AbstractDataFileDialog__h5["/scalar"])
+ index = (
+ browser.rootIndex()
+ .model()
+ .indexFromH5Object(dialog._AbstractDataFileDialog__h5["/scalar"])
+ )
browser.selectIndex(index)
browser.activated.emit(index)
self.qWaitForPendingActions(dialog)
@@ -276,12 +281,16 @@ class TestDataFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
path = silx.io.url.DataUrl(file_path=filename, data_path="/group/image").path()
dialog.selectUrl(path)
self.qWaitForPendingActions(dialog)
- path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/group/image").path()
+ path = silx.io.url.DataUrl(
+ scheme="silx", file_path=filename, data_path="/group/image"
+ ).path()
self.assertSamePath(url.text(), path)
# test
self.mouseClick(toParentButton, qt.Qt.LeftButton)
self.qWaitForPendingActions(dialog)
- path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/").path()
+ path = silx.io.url.DataUrl(
+ scheme="silx", file_path=filename, data_path="/"
+ ).path()
self.assertSamePath(url.text(), path)
self.mouseClick(toParentButton, qt.Qt.LeftButton)
@@ -303,7 +312,9 @@ class TestDataFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
filename = _tmpDirectory + "/data.h5"
# init state
- path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/group/image").path()
+ path = silx.io.url.DataUrl(
+ scheme="silx", file_path=filename, data_path="/group/image"
+ ).path()
dialog.selectUrl(path)
self.qWaitForPendingActions(dialog)
self.assertSamePath(url.text(), path)
@@ -311,7 +322,9 @@ class TestDataFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
# test
self.mouseClick(button, qt.Qt.LeftButton)
self.qWaitForPendingActions(dialog)
- path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/").path()
+ path = silx.io.url.DataUrl(
+ scheme="silx", file_path=filename, data_path="/"
+ ).path()
self.assertSamePath(url.text(), path)
# self.assertFalse(button.isEnabled())
@@ -329,7 +342,9 @@ class TestDataFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
path = silx.io.url.DataUrl(file_path=filename, data_path="/group/image").path()
dialog.selectUrl(path)
self.qWaitForPendingActions(dialog)
- path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/group/image").path()
+ path = silx.io.url.DataUrl(
+ scheme="silx", file_path=filename, data_path="/group/image"
+ ).path()
self.assertSamePath(url.text(), path)
self.assertTrue(button.isEnabled())
# test
@@ -348,8 +363,12 @@ class TestDataFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
self.qWaitForWindowExposed(dialog)
url = testutils.findChildren(dialog, qt.QLineEdit, name="url")[0]
- forwardAction = testutils.findChildren(dialog, qt.QAction, name="forwardAction")[0]
- backwardAction = testutils.findChildren(dialog, qt.QAction, name="backwardAction")[0]
+ forwardAction = testutils.findChildren(
+ dialog, qt.QAction, name="forwardAction"
+ )[0]
+ backwardAction = testutils.findChildren(
+ dialog, qt.QAction, name="backwardAction"
+ )[0]
filename = _tmpDirectory + "/data.h5"
dialog.setDirectory(_tmpDirectory)
@@ -358,10 +377,14 @@ class TestDataFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
# Then we feed the history using selectPath
dialog.selectUrl(filename)
self.qWaitForPendingActions(dialog)
- path2 = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/").path()
+ path2 = silx.io.url.DataUrl(
+ scheme="silx", file_path=filename, data_path="/"
+ ).path()
dialog.selectUrl(path2)
self.qWaitForPendingActions(dialog)
- path3 = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/group").path()
+ path3 = silx.io.url.DataUrl(
+ scheme="silx", file_path=filename, data_path="/group"
+ ).path()
dialog.selectUrl(path3)
self.qWaitForPendingActions(dialog)
self.assertFalse(forwardAction.isEnabled())
@@ -388,7 +411,11 @@ class TestDataFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
# init state
filename = _tmpDirectory + "/singleimage.edf"
- url = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/scan_0/instrument/detector_0/data")
+ url = silx.io.url.DataUrl(
+ scheme="silx",
+ file_path=filename,
+ data_path="/scan_0/instrument/detector_0/data",
+ )
dialog.selectUrl(url.path())
self.assertEqual(dialog._selectedData().shape, (100, 100))
self.assertSamePath(dialog.selectedFile(), filename)
@@ -401,7 +428,9 @@ class TestDataFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
# init state
filename = _tmpDirectory + "/data.h5"
- path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/image").path()
+ path = silx.io.url.DataUrl(
+ scheme="silx", file_path=filename, data_path="/image"
+ ).path()
dialog.selectUrl(path)
# test
self.assertEqual(dialog._selectedData().shape, (100, 100))
@@ -415,7 +444,9 @@ class TestDataFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
# init state
filename = _tmpDirectory + "/data.h5"
- path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/scalar").path()
+ path = silx.io.url.DataUrl(
+ scheme="silx", file_path=filename, data_path="/scalar"
+ ).path()
dialog.selectUrl(path)
# test
self.assertEqual(dialog._selectedData()[()], 10)
@@ -464,7 +495,9 @@ class TestDataFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
self.qWaitForPendingActions(dialog)
browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
filename = _tmpDirectory + "/data.h5"
- path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/").path()
+ path = silx.io.url.DataUrl(
+ scheme="silx", file_path=filename, data_path="/"
+ ).path()
index = browser.rootIndex().model().index(filename)
# click
browser.selectIndex(index)
@@ -508,11 +541,12 @@ class TestDataFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
self.qWaitForWindowExposed(dialog)
dialog.selectUrl(_tmpDirectory)
self.qWaitForPendingActions(dialog)
- self.assertEqual(self._countSelectableItems(browser.model(), browser.rootIndex()), 4)
+ self.assertEqual(
+ self._countSelectableItems(browser.model(), browser.rootIndex()), 4
+ )
class TestDataFileDialog_FilterDataset(testutils.TestCaseQt, _UtilsMixin):
-
def tearDown(self):
self._deleteDialog()
testutils.TestCaseQt.tearDown(self)
@@ -539,7 +573,11 @@ class TestDataFileDialog_FilterDataset(testutils.TestCaseQt, _UtilsMixin):
self.qWaitForPendingActions(dialog)
# select, then double click on the file
- index = browser.rootIndex().model().indexFromH5Object(dialog._AbstractDataFileDialog__h5["/group"])
+ index = (
+ browser.rootIndex()
+ .model()
+ .indexFromH5Object(dialog._AbstractDataFileDialog__h5["/group"])
+ )
browser.selectIndex(index)
browser.activated.emit(index)
self.qWaitForPendingActions(dialog)
@@ -564,7 +602,11 @@ class TestDataFileDialog_FilterDataset(testutils.TestCaseQt, _UtilsMixin):
self.qWaitForPendingActions(dialog)
# select, then double click on the file
- index = browser.rootIndex().model().indexFromH5Object(dialog._AbstractDataFileDialog__h5["/scalar"])
+ index = (
+ browser.rootIndex()
+ .model()
+ .indexFromH5Object(dialog._AbstractDataFileDialog__h5["/scalar"])
+ )
browser.selectIndex(index)
browser.activated.emit(index)
self.qWaitForPendingActions(dialog)
@@ -582,7 +624,6 @@ class TestDataFileDialog_FilterDataset(testutils.TestCaseQt, _UtilsMixin):
class TestDataFileDialog_FilterGroup(testutils.TestCaseQt, _UtilsMixin):
-
def tearDown(self):
self._deleteDialog()
testutils.TestCaseQt.tearDown(self)
@@ -609,7 +650,11 @@ class TestDataFileDialog_FilterGroup(testutils.TestCaseQt, _UtilsMixin):
self.qWaitForPendingActions(dialog)
# select, then double click on the file
- index = browser.rootIndex().model().indexFromH5Object(dialog._AbstractDataFileDialog__h5["/group"])
+ index = (
+ browser.rootIndex()
+ .model()
+ .indexFromH5Object(dialog._AbstractDataFileDialog__h5["/group"])
+ )
browser.selectIndex(index)
browser.activated.emit(index)
self.qWaitForPendingActions(dialog)
@@ -641,7 +686,11 @@ class TestDataFileDialog_FilterGroup(testutils.TestCaseQt, _UtilsMixin):
self.qWaitForPendingActions(dialog)
# select, then double click on the file
- index = browser.rootIndex().model().indexFromH5Object(dialog._AbstractDataFileDialog__h5["/scalar"])
+ index = (
+ browser.rootIndex()
+ .model()
+ .indexFromH5Object(dialog._AbstractDataFileDialog__h5["/scalar"])
+ )
browser.selectIndex(index)
browser.activated.emit(index)
self.qWaitForPendingActions(dialog)
@@ -651,7 +700,6 @@ class TestDataFileDialog_FilterGroup(testutils.TestCaseQt, _UtilsMixin):
class TestDataFileDialog_FilterNXdata(testutils.TestCaseQt, _UtilsMixin):
-
def tearDown(self):
self._deleteDialog()
testutils.TestCaseQt.tearDown(self)
@@ -659,7 +707,7 @@ class TestDataFileDialog_FilterNXdata(testutils.TestCaseQt, _UtilsMixin):
def _createDialog(self):
def customFilter(obj):
if "NX_class" in obj.attrs:
- return obj.attrs["NX_class"] == u"NXdata"
+ return obj.attrs["NX_class"] == "NXdata"
return False
dialog = DataFileDialog()
@@ -684,7 +732,11 @@ class TestDataFileDialog_FilterNXdata(testutils.TestCaseQt, _UtilsMixin):
self.qWaitForPendingActions(dialog)
# select, then double click on the file
- index = browser.rootIndex().model().indexFromH5Object(dialog._AbstractDataFileDialog__h5["/group"])
+ index = (
+ browser.rootIndex()
+ .model()
+ .indexFromH5Object(dialog._AbstractDataFileDialog__h5["/group"])
+ )
browser.selectIndex(index)
browser.activated.emit(index)
self.qWaitForPendingActions(dialog)
@@ -711,7 +763,11 @@ class TestDataFileDialog_FilterNXdata(testutils.TestCaseQt, _UtilsMixin):
self.qWaitForPendingActions(dialog)
# select, then double click on the file
- index = browser.rootIndex().model().indexFromH5Object(dialog._AbstractDataFileDialog__h5["/nxdata"])
+ index = (
+ browser.rootIndex()
+ .model()
+ .indexFromH5Object(dialog._AbstractDataFileDialog__h5["/nxdata"])
+ )
browser.selectIndex(index)
browser.activated.emit(index)
self.qWaitForPendingActions(dialog)
@@ -726,7 +782,6 @@ class TestDataFileDialog_FilterNXdata(testutils.TestCaseQt, _UtilsMixin):
class TestDataFileDialogApi(testutils.TestCaseQt, _UtilsMixin):
-
def tearDown(self):
self._deleteDialog()
testutils.TestCaseQt.tearDown(self)
@@ -779,46 +834,50 @@ class TestDataFileDialogApi(testutils.TestCaseQt, _UtilsMixin):
print()
print("\\\n".join(strings))
- STATE_VERSION1_QT4 = b''\
- b'\x00\x00\x00Z\x00s\x00i\x00l\x00x\x00.\x00g\x00u\x00i\x00.\x00'\
- b'd\x00i\x00a\x00l\x00o\x00g\x00.\x00D\x00a\x00t\x00a\x00F\x00i'\
- b'\x00l\x00e\x00D\x00i\x00a\x00l\x00o\x00g\x00.\x00D\x00a\x00t\x00'\
- b'a\x00F\x00i\x00l\x00e\x00D\x00i\x00a\x00l\x00o\x00g\x00\x00\x00'\
- b'\x01\x00\x00\x00\x0C\x00\x00\x00\x00"\x00\x00\x00\xFF\x00\x00'\
- b'\x00\x00\x00\x00\x00\x03\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF'\
- b'\xFF\xFF\x01\x00\x00\x00\x06\x01\x00\x00\x00\x01\x00\x00\x00\x00'\
- b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0C\x00\x00\x00\x00'\
- b'}\x00\x00\x00\x0E\x00B\x00r\x00o\x00w\x00s\x00e\x00r\x00\x00\x00'\
- b'\x01\x00\x00\x00\x0C\x00\x00\x00\x00Z\x00\x00\x00\xFF\x00\x00'\
- b'\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00'\
- b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'\
- b'\x00\x01\x90\x00\x00\x00\x04\x01\x01\x00\x00\x00\x00\x00\x00\x00'\
- b'\x00\x00\x00\x00\x00\x00\x00d\xFF\xFF\xFF\xFF\x00\x00\x00\x81'\
- b'\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x01\x90\x00\x00\x00\x04'\
- b'\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x00'\
- b'\x01\xFF\xFF\xFF\xFF'
+ STATE_VERSION1_QT4 = (
+ b""
+ b"\x00\x00\x00Z\x00s\x00i\x00l\x00x\x00.\x00g\x00u\x00i\x00.\x00"
+ b"d\x00i\x00a\x00l\x00o\x00g\x00.\x00D\x00a\x00t\x00a\x00F\x00i"
+ b"\x00l\x00e\x00D\x00i\x00a\x00l\x00o\x00g\x00.\x00D\x00a\x00t\x00"
+ b"a\x00F\x00i\x00l\x00e\x00D\x00i\x00a\x00l\x00o\x00g\x00\x00\x00"
+ b'\x01\x00\x00\x00\x0C\x00\x00\x00\x00"\x00\x00\x00\xFF\x00\x00'
+ b"\x00\x00\x00\x00\x00\x03\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF"
+ b"\xFF\xFF\x01\x00\x00\x00\x06\x01\x00\x00\x00\x01\x00\x00\x00\x00"
+ b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0C\x00\x00\x00\x00"
+ b"}\x00\x00\x00\x0E\x00B\x00r\x00o\x00w\x00s\x00e\x00r\x00\x00\x00"
+ b"\x01\x00\x00\x00\x0C\x00\x00\x00\x00Z\x00\x00\x00\xFF\x00\x00"
+ b"\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00"
+ b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
+ b"\x00\x01\x90\x00\x00\x00\x04\x01\x01\x00\x00\x00\x00\x00\x00\x00"
+ b"\x00\x00\x00\x00\x00\x00\x00d\xFF\xFF\xFF\xFF\x00\x00\x00\x81"
+ b"\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x01\x90\x00\x00\x00\x04"
+ b"\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x00"
+ b"\x01\xFF\xFF\xFF\xFF"
+ )
"""Serialized state on Qt4. Generated using :meth:`printState`"""
- STATE_VERSION1_QT5 = b''\
- b'\x00\x00\x00Z\x00s\x00i\x00l\x00x\x00.\x00g\x00u\x00i\x00.\x00'\
- b'd\x00i\x00a\x00l\x00o\x00g\x00.\x00D\x00a\x00t\x00a\x00F\x00i'\
- b'\x00l\x00e\x00D\x00i\x00a\x00l\x00o\x00g\x00.\x00D\x00a\x00t\x00'\
- b'a\x00F\x00i\x00l\x00e\x00D\x00i\x00a\x00l\x00o\x00g\x00\x00\x00'\
- b'\x01\x00\x00\x00\x0C\x00\x00\x00\x00#\x00\x00\x00\xFF\x00\x00'\
- b'\x00\x01\x00\x00\x00\x03\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF'\
- b'\xFF\xFF\x01\xFF\xFF\xFF\xFF\x01\x00\x00\x00\x01\x00\x00\x00\x00'\
- b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0C\x00\x00\x00'\
- b'\x00\xAA\x00\x00\x00\x0E\x00B\x00r\x00o\x00w\x00s\x00e\x00r\x00'\
- b'\x00\x00\x01\x00\x00\x00\x0C\x00\x00\x00\x00\x87\x00\x00\x00\xFF'\
- b'\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00'\
- b'\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'\
- b'\x00\x00\x00\x01\x90\x00\x00\x00\x04\x01\x01\x00\x00\x00\x00\x00'\
- b'\x00\x00\x00\x00\x00\x00\x00\x00\x00d\xFF\xFF\xFF\xFF\x00\x00'\
- b'\x00\x81\x00\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00d\x00\x00'\
- b'\x00\x01\x00\x00\x00\x00\x00\x00\x00d\x00\x00\x00\x01\x00\x00'\
- b'\x00\x00\x00\x00\x00d\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00'\
- b'\x00d\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x03\xE8\x00\xFF'\
- b'\xFF\xFF\xFF\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x00\x01'
+ STATE_VERSION1_QT5 = (
+ b""
+ b"\x00\x00\x00Z\x00s\x00i\x00l\x00x\x00.\x00g\x00u\x00i\x00.\x00"
+ b"d\x00i\x00a\x00l\x00o\x00g\x00.\x00D\x00a\x00t\x00a\x00F\x00i"
+ b"\x00l\x00e\x00D\x00i\x00a\x00l\x00o\x00g\x00.\x00D\x00a\x00t\x00"
+ b"a\x00F\x00i\x00l\x00e\x00D\x00i\x00a\x00l\x00o\x00g\x00\x00\x00"
+ b"\x01\x00\x00\x00\x0C\x00\x00\x00\x00#\x00\x00\x00\xFF\x00\x00"
+ b"\x00\x01\x00\x00\x00\x03\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF"
+ b"\xFF\xFF\x01\xFF\xFF\xFF\xFF\x01\x00\x00\x00\x01\x00\x00\x00\x00"
+ b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0C\x00\x00\x00"
+ b"\x00\xAA\x00\x00\x00\x0E\x00B\x00r\x00o\x00w\x00s\x00e\x00r\x00"
+ b"\x00\x00\x01\x00\x00\x00\x0C\x00\x00\x00\x00\x87\x00\x00\x00\xFF"
+ b"\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00"
+ b"\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
+ b"\x00\x00\x00\x01\x90\x00\x00\x00\x04\x01\x01\x00\x00\x00\x00\x00"
+ b"\x00\x00\x00\x00\x00\x00\x00\x00\x00d\xFF\xFF\xFF\xFF\x00\x00"
+ b"\x00\x81\x00\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00d\x00\x00"
+ b"\x00\x01\x00\x00\x00\x00\x00\x00\x00d\x00\x00\x00\x01\x00\x00"
+ b"\x00\x00\x00\x00\x00d\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00"
+ b"\x00d\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x03\xE8\x00\xFF"
+ b"\xFF\xFF\xFF\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x00\x01"
+ )
"""Serialized state on Qt5. Generated using :meth:`printState`"""
def testAvoidRestoreRegression_Version1(self):
@@ -903,7 +962,9 @@ class TestDataFileDialogApi(testutils.TestCaseQt, _UtilsMixin):
browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
filename = _tmpDirectory + "/data.h5"
- url = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/group/foobar")
+ url = silx.io.url.DataUrl(
+ scheme="silx", file_path=filename, data_path="/group/foobar"
+ )
dialog.selectUrl(url.path())
self.qWaitForPendingActions(dialog)
self.assertIsNotNone(dialog._selectedData())
diff --git a/src/silx/gui/dialog/test/test_imagefiledialog.py b/src/silx/gui/dialog/test/test_imagefiledialog.py
index 79c12ed..9d2c414 100644
--- a/src/silx/gui/dialog/test/test_imagefiledialog.py
+++ b/src/silx/gui/dialog/test/test_imagefiledialog.py
@@ -28,7 +28,6 @@ __license__ = "MIT"
__date__ = "08/03/2019"
-import unittest
import tempfile
import numpy
import shutil
@@ -106,7 +105,6 @@ def tearDownModule():
class _UtilsMixin(object):
-
def createDialog(self):
self._deleteDialog()
self._dialog = self._createDialog()
@@ -146,7 +144,6 @@ class _UtilsMixin(object):
class TestImageFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
-
def tearDown(self):
self._deleteDialog()
testutils.TestCaseQt.tearDown(self)
@@ -201,6 +198,9 @@ class TestImageFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
self.assertEqual(dialog.result(), qt.QDialog.Accepted)
def testClickOnShortcut(self):
+ if qt.BINDING == "PySide6":
+ self.skipTest("Avoid segmentation fault with PySide6")
+
dialog = self.createDialog()
dialog.show()
self.qWaitForWindowExposed(dialog)
@@ -264,12 +264,16 @@ class TestImageFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
path = silx.io.url.DataUrl(file_path=filename, data_path="/group/image").path()
dialog.selectUrl(path)
self.qWaitForPendingActions(dialog)
- path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/group/image").path()
+ path = silx.io.url.DataUrl(
+ scheme="silx", file_path=filename, data_path="/group/image"
+ ).path()
self.assertSamePath(url.text(), path)
# test
self.mouseClick(toParentButton, qt.Qt.LeftButton)
self.qWaitForPendingActions(dialog)
- path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/").path()
+ path = silx.io.url.DataUrl(
+ scheme="silx", file_path=filename, data_path="/"
+ ).path()
self.assertSamePath(url.text(), path)
self.mouseClick(toParentButton, qt.Qt.LeftButton)
@@ -291,7 +295,9 @@ class TestImageFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
filename = _tmpDirectory + "/data.h5"
# init state
- path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/group/image").path()
+ path = silx.io.url.DataUrl(
+ scheme="silx", file_path=filename, data_path="/group/image"
+ ).path()
dialog.selectUrl(path)
self.qWaitForPendingActions(dialog)
self.assertSamePath(url.text(), path)
@@ -299,7 +305,9 @@ class TestImageFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
# test
self.mouseClick(button, qt.Qt.LeftButton)
self.qWaitForPendingActions(dialog)
- path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/").path()
+ path = silx.io.url.DataUrl(
+ scheme="silx", file_path=filename, data_path="/"
+ ).path()
self.assertSamePath(url.text(), path)
# self.assertFalse(button.isEnabled())
@@ -317,7 +325,9 @@ class TestImageFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
path = silx.io.url.DataUrl(file_path=filename, data_path="/group/image").path()
dialog.selectUrl(path)
self.qWaitForPendingActions(dialog)
- path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/group/image").path()
+ path = silx.io.url.DataUrl(
+ scheme="silx", file_path=filename, data_path="/group/image"
+ ).path()
self.assertSamePath(url.text(), path)
self.assertTrue(button.isEnabled())
# test
@@ -336,8 +346,12 @@ class TestImageFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
self.qWaitForWindowExposed(dialog)
url = testutils.findChildren(dialog, qt.QLineEdit, name="url")[0]
- forwardAction = testutils.findChildren(dialog, qt.QAction, name="forwardAction")[0]
- backwardAction = testutils.findChildren(dialog, qt.QAction, name="backwardAction")[0]
+ forwardAction = testutils.findChildren(
+ dialog, qt.QAction, name="forwardAction"
+ )[0]
+ backwardAction = testutils.findChildren(
+ dialog, qt.QAction, name="backwardAction"
+ )[0]
filename = _tmpDirectory + "/data.h5"
dialog.setDirectory(_tmpDirectory)
@@ -346,10 +360,14 @@ class TestImageFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
# Then we feed the history using selectPath
dialog.selectUrl(filename)
self.qWaitForPendingActions(dialog)
- path2 = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/").path()
+ path2 = silx.io.url.DataUrl(
+ scheme="silx", file_path=filename, data_path="/"
+ ).path()
dialog.selectUrl(path2)
self.qWaitForPendingActions(dialog)
- path3 = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/group").path()
+ path3 = silx.io.url.DataUrl(
+ scheme="silx", file_path=filename, data_path="/group"
+ ).path()
dialog.selectUrl(path3)
self.qWaitForPendingActions(dialog)
self.assertFalse(forwardAction.isEnabled())
@@ -412,7 +430,9 @@ class TestImageFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
# init state
filename = _tmpDirectory + "/multiframe.edf"
- path = silx.io.url.DataUrl(scheme="fabio", file_path=filename, data_slice=(1,)).path()
+ path = silx.io.url.DataUrl(
+ scheme="fabio", file_path=filename, data_slice=(1,)
+ ).path()
dialog.selectUrl(path)
# test
image = dialog.selectedImage()
@@ -442,7 +462,9 @@ class TestImageFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
# init state
filename = _tmpDirectory + "/data.h5"
- path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/image").path()
+ path = silx.io.url.DataUrl(
+ scheme="silx", file_path=filename, data_path="/image"
+ ).path()
dialog.selectUrl(path)
# test
self.assertEqual(dialog.selectedImage().shape, (100, 100))
@@ -459,7 +481,9 @@ class TestImageFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
self.qWaitForPendingActions(dialog)
browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
filename = _tmpDirectory + "/data.h5"
- path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/").path()
+ path = silx.io.url.DataUrl(
+ scheme="silx", file_path=filename, data_path="/"
+ ).path()
index = browser.rootIndex().model().index(filename)
# click
browser.selectIndex(index)
@@ -476,7 +500,9 @@ class TestImageFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
# init state
filename = _tmpDirectory + "/data.h5"
- path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/cube", data_slice=(1, )).path()
+ path = silx.io.url.DataUrl(
+ scheme="silx", file_path=filename, data_path="/cube", data_slice=(1,)
+ ).path()
dialog.selectUrl(path)
# test
self.assertEqual(dialog.selectedImage().shape, (100, 100))
@@ -491,7 +517,12 @@ class TestImageFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
# init state
filename = _tmpDirectory + "/data.h5"
- path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/single_frame", data_slice=(0, )).path()
+ path = silx.io.url.DataUrl(
+ scheme="silx",
+ file_path=filename,
+ data_path="/single_frame",
+ data_slice=(0,),
+ ).path()
dialog.selectUrl(path)
# test
self.assertEqual(dialog.selectedImage().shape, (100, 100))
@@ -534,25 +565,30 @@ class TestImageFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
self.qWaitForWindowExposed(dialog)
dialog.selectUrl(_tmpDirectory)
self.qWaitForPendingActions(dialog)
- self.assertEqual(self._countSelectableItems(browser.model(), browser.rootIndex()), 6)
+ self.assertEqual(
+ self._countSelectableItems(browser.model(), browser.rootIndex()), 6
+ )
codecName = fabio.edfimage.EdfImage.codec_name()
index = filters.indexFromCodec(codecName)
filters.setCurrentIndex(index)
filters.activated[int].emit(index)
self.qWait(50)
- self.assertEqual(self._countSelectableItems(browser.model(), browser.rootIndex()), 4)
+ self.assertEqual(
+ self._countSelectableItems(browser.model(), browser.rootIndex()), 4
+ )
codecName = fabio.fit2dmaskimage.Fit2dMaskImage.codec_name()
index = filters.indexFromCodec(codecName)
filters.setCurrentIndex(index)
filters.activated[int].emit(index)
self.qWait(50)
- self.assertEqual(self._countSelectableItems(browser.model(), browser.rootIndex()), 2)
+ self.assertEqual(
+ self._countSelectableItems(browser.model(), browser.rootIndex()), 2
+ )
class TestImageFileDialogApi(testutils.TestCaseQt, _UtilsMixin):
-
def tearDown(self):
self._deleteDialog()
testutils.TestCaseQt.tearDown(self)
@@ -606,51 +642,55 @@ class TestImageFileDialogApi(testutils.TestCaseQt, _UtilsMixin):
print()
print("\\\n".join(strings))
- STATE_VERSION1_QT4 = b''\
- b'\x00\x00\x00^\x00s\x00i\x00l\x00x\x00.\x00g\x00u\x00i\x00.\x00'\
- b'd\x00i\x00a\x00l\x00o\x00g\x00.\x00I\x00m\x00a\x00g\x00e\x00F'\
- b'\x00i\x00l\x00e\x00D\x00i\x00a\x00l\x00o\x00g\x00.\x00I\x00m\x00'\
- b'a\x00g\x00e\x00F\x00i\x00l\x00e\x00D\x00i\x00a\x00l\x00o\x00g'\
- b'\x00\x00\x00\x01\x00\x00\x00\x0C\x00\x00\x00\x00"\x00\x00\x00'\
- b'\xFF\x00\x00\x00\x00\x00\x00\x00\x03\xFF\xFF\xFF\xFF\xFF\xFF\xFF'\
- b'\xFF\xFF\xFF\xFF\xFF\x01\x00\x00\x00\x06\x01\x00\x00\x00\x01\x00'\
- b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0C\x00'\
- b'\x00\x00\x00}\x00\x00\x00\x0E\x00B\x00r\x00o\x00w\x00s\x00e\x00'\
- b'r\x00\x00\x00\x01\x00\x00\x00\x0C\x00\x00\x00\x00Z\x00\x00\x00'\
- b'\xFF\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00'\
- b'\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'\
- b'\x00\x00\x00\x00\x01\x90\x00\x00\x00\x04\x01\x01\x00\x00\x00\x00'\
- b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00d\xFF\xFF\xFF\xFF\x00'\
- b'\x00\x00\x81\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x01\x90\x00'\
- b'\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x01\x00'\
- b'\x00\x00\x0C\x00\x00\x00\x000\x00\x00\x00\x10\x00C\x00o\x00l\x00'\
- b'o\x00r\x00m\x00a\x00p\x00\x00\x00\x01\x00\x00\x00\x08\x00g\x00'\
- b'r\x00a\x00y\x01\x01\x00\x00\x00\x06\x00l\x00o\x00g'
+ STATE_VERSION1_QT4 = (
+ b""
+ b"\x00\x00\x00^\x00s\x00i\x00l\x00x\x00.\x00g\x00u\x00i\x00.\x00"
+ b"d\x00i\x00a\x00l\x00o\x00g\x00.\x00I\x00m\x00a\x00g\x00e\x00F"
+ b"\x00i\x00l\x00e\x00D\x00i\x00a\x00l\x00o\x00g\x00.\x00I\x00m\x00"
+ b"a\x00g\x00e\x00F\x00i\x00l\x00e\x00D\x00i\x00a\x00l\x00o\x00g"
+ b'\x00\x00\x00\x01\x00\x00\x00\x0C\x00\x00\x00\x00"\x00\x00\x00'
+ b"\xFF\x00\x00\x00\x00\x00\x00\x00\x03\xFF\xFF\xFF\xFF\xFF\xFF\xFF"
+ b"\xFF\xFF\xFF\xFF\xFF\x01\x00\x00\x00\x06\x01\x00\x00\x00\x01\x00"
+ b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0C\x00"
+ b"\x00\x00\x00}\x00\x00\x00\x0E\x00B\x00r\x00o\x00w\x00s\x00e\x00"
+ b"r\x00\x00\x00\x01\x00\x00\x00\x0C\x00\x00\x00\x00Z\x00\x00\x00"
+ b"\xFF\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00"
+ b"\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
+ b"\x00\x00\x00\x00\x01\x90\x00\x00\x00\x04\x01\x01\x00\x00\x00\x00"
+ b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00d\xFF\xFF\xFF\xFF\x00"
+ b"\x00\x00\x81\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x01\x90\x00"
+ b"\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x01\x00"
+ b"\x00\x00\x0C\x00\x00\x00\x000\x00\x00\x00\x10\x00C\x00o\x00l\x00"
+ b"o\x00r\x00m\x00a\x00p\x00\x00\x00\x01\x00\x00\x00\x08\x00g\x00"
+ b"r\x00a\x00y\x01\x01\x00\x00\x00\x06\x00l\x00o\x00g"
+ )
"""Serialized state on Qt4. Generated using :meth:`printState`"""
- STATE_VERSION1_QT5 = b''\
- b'\x00\x00\x00^\x00s\x00i\x00l\x00x\x00.\x00g\x00u\x00i\x00.\x00'\
- b'd\x00i\x00a\x00l\x00o\x00g\x00.\x00I\x00m\x00a\x00g\x00e\x00F'\
- b'\x00i\x00l\x00e\x00D\x00i\x00a\x00l\x00o\x00g\x00.\x00I\x00m\x00'\
- b'a\x00g\x00e\x00F\x00i\x00l\x00e\x00D\x00i\x00a\x00l\x00o\x00g'\
- b'\x00\x00\x00\x01\x00\x00\x00\x0C\x00\x00\x00\x00#\x00\x00\x00'\
- b'\xFF\x00\x00\x00\x01\x00\x00\x00\x03\xFF\xFF\xFF\xFF\xFF\xFF\xFF'\
- b'\xFF\xFF\xFF\xFF\xFF\x01\xFF\xFF\xFF\xFF\x01\x00\x00\x00\x01\x00'\
- b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0C'\
- b'\x00\x00\x00\x00\xAA\x00\x00\x00\x0E\x00B\x00r\x00o\x00w\x00s'\
- b'\x00e\x00r\x00\x00\x00\x01\x00\x00\x00\x0C\x00\x00\x00\x00\x87'\
- b'\x00\x00\x00\xFF\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00'\
- b'\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'\
- b'\x00\x00\x00\x00\x00\x00\x00\x01\x90\x00\x00\x00\x04\x01\x01\x00'\
- b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00d\xFF\xFF'\
- b'\xFF\xFF\x00\x00\x00\x81\x00\x00\x00\x00\x00\x00\x00\x04\x00\x00'\
- b'\x00d\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00d\x00\x00\x00'\
- b'\x01\x00\x00\x00\x00\x00\x00\x00d\x00\x00\x00\x01\x00\x00\x00'\
- b'\x00\x00\x00\x00d\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x03'\
- b'\xE8\x00\xFF\xFF\xFF\xFF\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00'\
- b'\x00\x0C\x00\x00\x00\x000\x00\x00\x00\x10\x00C\x00o\x00l\x00o'\
- b'\x00r\x00m\x00a\x00p\x00\x00\x00\x01\x00\x00\x00\x08\x00g\x00'\
- b'r\x00a\x00y\x01\x01\x00\x00\x00\x06\x00l\x00o\x00g'
+ STATE_VERSION1_QT5 = (
+ b""
+ b"\x00\x00\x00^\x00s\x00i\x00l\x00x\x00.\x00g\x00u\x00i\x00.\x00"
+ b"d\x00i\x00a\x00l\x00o\x00g\x00.\x00I\x00m\x00a\x00g\x00e\x00F"
+ b"\x00i\x00l\x00e\x00D\x00i\x00a\x00l\x00o\x00g\x00.\x00I\x00m\x00"
+ b"a\x00g\x00e\x00F\x00i\x00l\x00e\x00D\x00i\x00a\x00l\x00o\x00g"
+ b"\x00\x00\x00\x01\x00\x00\x00\x0C\x00\x00\x00\x00#\x00\x00\x00"
+ b"\xFF\x00\x00\x00\x01\x00\x00\x00\x03\xFF\xFF\xFF\xFF\xFF\xFF\xFF"
+ b"\xFF\xFF\xFF\xFF\xFF\x01\xFF\xFF\xFF\xFF\x01\x00\x00\x00\x01\x00"
+ b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0C"
+ b"\x00\x00\x00\x00\xAA\x00\x00\x00\x0E\x00B\x00r\x00o\x00w\x00s"
+ b"\x00e\x00r\x00\x00\x00\x01\x00\x00\x00\x0C\x00\x00\x00\x00\x87"
+ b"\x00\x00\x00\xFF\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00"
+ b"\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
+ b"\x00\x00\x00\x00\x00\x00\x00\x01\x90\x00\x00\x00\x04\x01\x01\x00"
+ b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00d\xFF\xFF"
+ b"\xFF\xFF\x00\x00\x00\x81\x00\x00\x00\x00\x00\x00\x00\x04\x00\x00"
+ b"\x00d\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00d\x00\x00\x00"
+ b"\x01\x00\x00\x00\x00\x00\x00\x00d\x00\x00\x00\x01\x00\x00\x00"
+ b"\x00\x00\x00\x00d\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x03"
+ b"\xE8\x00\xFF\xFF\xFF\xFF\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00"
+ b"\x00\x0C\x00\x00\x00\x000\x00\x00\x00\x10\x00C\x00o\x00l\x00o"
+ b"\x00r\x00m\x00a\x00p\x00\x00\x00\x01\x00\x00\x00\x08\x00g\x00"
+ b"r\x00a\x00y\x01\x01\x00\x00\x00\x06\x00l\x00o\x00g"
+ )
"""Serialized state on Qt5. Generated using :meth:`printState`"""
def testAvoidRestoreRegression_Version1(self):
@@ -757,7 +797,9 @@ class TestImageFileDialogApi(testutils.TestCaseQt, _UtilsMixin):
browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
filename = _tmpDirectory + "/data.h5"
- url = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/group/foobar")
+ url = silx.io.url.DataUrl(
+ scheme="silx", file_path=filename, data_path="/group/foobar"
+ )
dialog.selectUrl(url.path())
self.qWaitForPendingActions(dialog)
self.assertIsNone(dialog._selectedData())
diff --git a/src/silx/gui/dialog/utils.py b/src/silx/gui/dialog/utils.py
index e07cf9f..1697bcf 100644
--- a/src/silx/gui/dialog/utils.py
+++ b/src/silx/gui/dialog/utils.py
@@ -85,7 +85,7 @@ def patchToConsumeReturnKey(widget):
Monkey-patch a widget to consume the return key instead of propagating it
to the dialog.
"""
- assert(not hasattr(widget, "_oldKeyPressEvent"))
+ assert not hasattr(widget, "_oldKeyPressEvent")
def keyPressEvent(self, event):
k = event.key()
diff --git a/src/silx/gui/fit/BackgroundWidget.py b/src/silx/gui/fit/BackgroundWidget.py
index 9ab63e4..d9cfcc8 100644
--- a/src/silx/gui/fit/BackgroundWidget.py
+++ b/src/silx/gui/fit/BackgroundWidget.py
@@ -1,4 +1,4 @@
-#/*##########################################################################
+# /*##########################################################################
# Copyright (C) 2004-2021 V.A. Sole, European Synchrotron Radiation Facility
#
# This file is part of the PyMca X-ray Fluorescence Toolkit developed at
@@ -44,8 +44,9 @@ __date__ = "28/06/2017"
class HorizontalSpacer(qt.QWidget):
def __init__(self, *args):
qt.QWidget.__init__(self, *args)
- self.setSizePolicy(qt.QSizePolicy(qt.QSizePolicy.Expanding,
- qt.QSizePolicy.Fixed))
+ self.setSizePolicy(
+ qt.QSizePolicy(qt.QSizePolicy.Expanding, qt.QSizePolicy.Fixed)
+ )
class BackgroundParamWidget(qt.QWidget):
@@ -56,6 +57,7 @@ class BackgroundParamWidget(qt.QWidget):
Updating the widgets causes :attr:`sigBackgroundParamWidgetSignal` to
be emitted.
"""
+
sigBackgroundParamWidgetSignal = qt.pyqtSignal(object)
def __init__(self, parent=None):
@@ -70,8 +72,7 @@ class BackgroundParamWidget(qt.QWidget):
self.algorithmCombo = qt.QComboBox(self)
self.algorithmCombo.addItem("Strip")
self.algorithmCombo.addItem("Snip")
- self.algorithmCombo.activated[int].connect(
- self._algorithmComboActivated)
+ self.algorithmCombo.activated[int].connect(self._algorithmComboActivated)
# Strip parameters ---------------------------------------------------
self.stripWidthLabel = qt.QLabel(self)
@@ -90,9 +91,10 @@ class BackgroundParamWidget(qt.QWidget):
self.stripIterValue.setText("0")
self.stripIterValue.editingFinished[()].connect(self._emitSignal)
self.stripIterValue.setToolTip(
- "Number of iterations for strip algorithm.\n" +
- "If greater than 999, an 2nd pass of strip filter is " +
- "applied to remove artifacts created by first pass.")
+ "Number of iterations for strip algorithm.\n"
+ + "If greater than 999, an 2nd pass of strip filter is "
+ + "applied to remove artifacts created by first pass."
+ )
# Snip parameters ----------------------------------------------------
self.snipWidthLabel = qt.QLabel(self)
@@ -103,7 +105,6 @@ class BackgroundParamWidget(qt.QWidget):
self.snipWidthSpin.setMinimum(0)
self.snipWidthSpin.valueChanged[int].connect(self._emitSignal)
-
# Smoothing parameters -----------------------------------------------
self.smoothingFlagCheck = qt.QCheckBox(self)
self.smoothingFlagCheck.setText("Smoothing Width (Savitsky-Golay)")
@@ -111,7 +112,7 @@ class BackgroundParamWidget(qt.QWidget):
self.smoothingSpin = qt.QSpinBox(self)
self.smoothingSpin.setMinimum(3)
- #self.smoothingSpin.setMaximum(40)
+ # self.smoothingSpin.setMaximum(40)
self.smoothingSpin.setSingleStep(2)
self.smoothingSpin.valueChanged[int].connect(self._emitSignal)
@@ -125,12 +126,12 @@ class BackgroundParamWidget(qt.QWidget):
self.anchorsFlagCheck = qt.QCheckBox(self.anchorsGroup)
self.anchorsFlagCheck.setText("Use anchors")
self.anchorsFlagCheck.setToolTip(
- "Define X coordinates of points that must remain fixed")
- self.anchorsFlagCheck.stateChanged[int].connect(
- self._anchorsToggled)
+ "Define X coordinates of points that must remain fixed"
+ )
+ self.anchorsFlagCheck.stateChanged[int].connect(self._anchorsToggled)
anchorsLayout.addWidget(self.anchorsFlagCheck)
- maxnchannel = 16384 * 4 # Fixme ?
+ maxnchannel = 16384 * 4 # Fixme ?
self.anchorsList = []
num_anchors = 4
for i in range(num_anchors):
@@ -170,8 +171,7 @@ class BackgroundParamWidget(qt.QWidget):
:param algorithm: "snip" or "strip"
"""
if algorithm not in ["strip", "snip"]:
- raise ValueError(
- "Unknown background filter algorithm %s" % algorithm)
+ raise ValueError("Unknown background filter algorithm %s" % algorithm)
self.algorithm = algorithm
self.stripWidthSpin.setEnabled(algorithm == "strip")
@@ -220,7 +220,7 @@ class BackgroundParamWidget(qt.QWidget):
if "AnchorsList" in ddict:
anchorslist = ddict["AnchorsList"]
- if anchorslist in [None, 'None']:
+ if anchorslist in [None, "None"]:
anchorslist = []
for spin in self.anchorsList:
spin.setValue(0)
@@ -248,20 +248,22 @@ class BackgroundParamWidget(qt.QWidget):
stripitertext = self.stripIterValue.text()
stripiter = int(stripitertext) if len(stripitertext) else 0
- return {"algorithm": self.algorithm,
- "StripThreshold": 1.0,
- "SnipWidth": self.snipWidthSpin.value(),
- "StripIterations": stripiter,
- "StripWidth": self.stripWidthSpin.value(),
- "SmoothingFlag": self.smoothingFlagCheck.isChecked(),
- "SmoothingWidth": self.smoothingSpin.value(),
- "AnchorsFlag": self.anchorsFlagCheck.isChecked(),
- "AnchorsList": [spin.value() for spin in self.anchorsList]}
+ return {
+ "algorithm": self.algorithm,
+ "StripThreshold": 1.0,
+ "SnipWidth": self.snipWidthSpin.value(),
+ "StripIterations": stripiter,
+ "StripWidth": self.stripWidthSpin.value(),
+ "SmoothingFlag": self.smoothingFlagCheck.isChecked(),
+ "SmoothingWidth": self.smoothingSpin.value(),
+ "AnchorsFlag": self.anchorsFlagCheck.isChecked(),
+ "AnchorsList": [spin.value() for spin in self.anchorsList],
+ }
def _emitSignal(self, dummy=None):
self.sigBackgroundParamWidgetSignal.emit(
- {'event': 'ParametersChanged',
- 'parameters': self.getParameters()})
+ {"event": "ParametersChanged", "parameters": self.getParameters()}
+ )
class BackgroundWidget(qt.QWidget):
@@ -270,6 +272,7 @@ class BackgroundWidget(qt.QWidget):
Strip and snip filters parameters can be adjusted using input widgets,
and the computed backgrounds are plotted next to the original data to
show the result."""
+
def __init__(self, parent=None):
qt.QWidget.__init__(self, parent)
self.setWindowTitle("Strip and SNIP Configuration Window")
@@ -328,8 +331,7 @@ class BackgroundWidget(qt.QWidget):
self._update()
def _update(self, resetzoom=False):
- """Compute strip and snip backgrounds, update the curves
- """
+ """Compute strip and snip backgrounds, update the curves"""
if self._y is None:
return
@@ -338,7 +340,7 @@ class BackgroundWidget(qt.QWidget):
# smoothed data
y = numpy.ravel(numpy.array(self._y)).astype(numpy.float64)
if pars["SmoothingFlag"]:
- ysmooth = filters.savitsky_golay(y, pars['SmoothingWidth'])
+ ysmooth = filters.savitsky_golay(y, pars["SmoothingWidth"])
f = [0.25, 0.5, 0.25]
ysmooth[1:-1] = numpy.convolve(ysmooth, f, mode=0)
ysmooth[0] = 0.5 * (ysmooth[0] + ysmooth[1])
@@ -346,14 +348,13 @@ class BackgroundWidget(qt.QWidget):
else:
ysmooth = y
-
# loop for anchors
x = self._x
- niter = pars['StripIterations']
+ niter = pars["StripIterations"]
anchors_indices = []
- if pars['AnchorsFlag'] and pars['AnchorsList'] is not None:
+ if pars["AnchorsFlag"] and pars["AnchorsList"] is not None:
ravelled = x
- for channel in pars['AnchorsList']:
+ for channel in pars["AnchorsList"]:
if channel <= ravelled[0]:
continue
index = numpy.nonzero(ravelled >= channel)[0]
@@ -362,52 +363,56 @@ class BackgroundWidget(qt.QWidget):
if index > 0:
anchors_indices.append(index)
- stripBackground = filters.strip(ysmooth,
- w=pars['StripWidth'],
- niterations=niter,
- factor=pars['StripThreshold'],
- anchors=anchors_indices)
+ stripBackground = filters.strip(
+ ysmooth,
+ w=pars["StripWidth"],
+ niterations=niter,
+ factor=pars["StripThreshold"],
+ anchors=anchors_indices,
+ )
if niter >= 1000:
# final smoothing
- stripBackground = filters.strip(stripBackground,
- w=1,
- niterations=50*pars['StripWidth'],
- factor=pars['StripThreshold'],
- anchors=anchors_indices)
+ stripBackground = filters.strip(
+ stripBackground,
+ w=1,
+ niterations=50 * pars["StripWidth"],
+ factor=pars["StripThreshold"],
+ anchors=anchors_indices,
+ )
if len(anchors_indices) == 0:
- anchors_indices = [0, len(ysmooth)-1]
+ anchors_indices = [0, len(ysmooth) - 1]
anchors_indices.sort()
snipBackground = 0.0 * ysmooth
lastAnchor = 0
for anchor in anchors_indices:
if (anchor > lastAnchor) and (anchor < len(ysmooth)):
- snipBackground[lastAnchor:anchor] =\
- filters.snip1d(ysmooth[lastAnchor:anchor],
- pars['SnipWidth'])
+ snipBackground[lastAnchor:anchor] = filters.snip1d(
+ ysmooth[lastAnchor:anchor], pars["SnipWidth"]
+ )
lastAnchor = anchor
if lastAnchor < len(ysmooth):
- snipBackground[lastAnchor:] =\
- filters.snip1d(ysmooth[lastAnchor:],
- pars['SnipWidth'])
-
- self.graphWidget.addCurve(x, y,
- legend='Input Data',
- replace=True,
- resetzoom=resetzoom)
- self.graphWidget.addCurve(x, stripBackground,
- legend='Strip Background',
- resetzoom=False)
- self.graphWidget.addCurve(x, snipBackground,
- legend='SNIP Background',
- resetzoom=False)
+ snipBackground[lastAnchor:] = filters.snip1d(
+ ysmooth[lastAnchor:], pars["SnipWidth"]
+ )
+
+ self.graphWidget.addCurve(
+ x, y, legend="Input Data", replace=True, resetzoom=resetzoom
+ )
+ self.graphWidget.addCurve(
+ x, stripBackground, legend="Strip Background", resetzoom=False
+ )
+ self.graphWidget.addCurve(
+ x, snipBackground, legend="SNIP Background", resetzoom=False
+ )
if self._xmin is not None and self._xmax is not None:
self.graphWidget.getXAxis().setLimits(self._xmin, self._xmax)
class BackgroundDialog(qt.QDialog):
"""QDialog window featuring a :class:`BackgroundWidget`"""
+
def __init__(self, parent=None):
qt.QDialog.__init__(self, parent)
self.setWindowTitle("Strip and Snip Configuration Window")
@@ -452,14 +457,15 @@ class BackgroundDialog(qt.QDialog):
# self.output = ddict
def accept(self):
- """Update :attr:`output`, then call :meth:`QDialog.accept`
- """
+ """Update :attr:`output`, then call :meth:`QDialog.accept`"""
self.output = self.getParameters()
super(BackgroundDialog, self).accept()
def sizeHint(self):
- return qt.QSize(int(1.5*qt.QDialog.sizeHint(self).width()),
- qt.QDialog.sizeHint(self).height())
+ return qt.QSize(
+ int(1.5 * qt.QDialog.sizeHint(self).width()),
+ qt.QDialog.sizeHint(self).height(),
+ )
def setData(self, x, y, xmin=None, xmax=None):
"""See :meth:`BackgroundWidget.setData`"""
@@ -498,11 +504,7 @@ def main():
x = numpy.arange(5000)
# (height1, center1, fwhm1, ...) 5 peaks
- params1 = (50, 500, 100,
- 20, 2000, 200,
- 50, 2250, 100,
- 40, 3000, 75,
- 23, 4000, 150)
+ params1 = (50, 500, 100, 20, 2000, 200, 50, 2250, 100, 40, 3000, 75, 23, 4000, 150)
y0 = sum_gauss(x, *params1)
# random values between [-1;1]
@@ -527,7 +529,8 @@ def main():
w.parametersWidget.parametersWidget.sigBackgroundParamWidgetSignal.connect(mySlot)
w.setData(x, y)
w.exec()
- #a.exec()
+ # a.exec()
+
if __name__ == "__main__":
main()
diff --git a/src/silx/gui/fit/FitConfig.py b/src/silx/gui/fit/FitConfig.py
index 09dbfaa..5887b4a 100644
--- a/src/silx/gui/fit/FitConfig.py
+++ b/src/silx/gui/fit/FitConfig.py
@@ -45,6 +45,7 @@ class TabsDialog(qt.QDialog):
This dialog defines a __len__ returning the number of tabs,
and an __iter__ method yielding the tab widgets.
"""
+
def __init__(self, parent=None):
qt.QDialog.__init__(self, parent)
self.tabWidget = qt.QTabWidget(self)
@@ -62,9 +63,9 @@ class TabsDialog(qt.QDialog):
self.buttonDefault.setText("Undo changes")
layout2.addWidget(self.buttonDefault)
- spacer = qt.QSpacerItem(20, 20,
- qt.QSizePolicy.Expanding,
- qt.QSizePolicy.Minimum)
+ spacer = qt.QSpacerItem(
+ 20, 20, qt.QSizePolicy.Expanding, qt.QSizePolicy.Minimum
+ )
layout2.addItem(spacer)
self.buttonOk = qt.QPushButton(self)
@@ -119,6 +120,7 @@ class TabsDialogData(TabsDialog):
A default dictionary can be supplied when this dialog is initialized, to
be used as default data for :attr:`output`.
"""
+
def __init__(self, parent=None, modal=True, default=None):
"""
@@ -197,11 +199,14 @@ class ConstraintsPage(qt.QGroupBox):
"""Checkable QGroupBox widget filled with QCheckBox widgets,
to configure the fit estimation for standard fit theories.
"""
+
def __init__(self, parent=None, title="Set constraints"):
super(ConstraintsPage, self).__init__(parent)
self.setTitle(title)
- self.setToolTip("Disable 'Set constraints' to remove all " +
- "constraints on all fit parameters")
+ self.setToolTip(
+ "Disable 'Set constraints' to remove all "
+ + "constraints on all fit parameters"
+ )
self.setCheckable(True)
layout = qt.QVBoxLayout(self)
@@ -212,8 +217,7 @@ class ConstraintsPage(qt.QGroupBox):
layout.addWidget(self.positiveHeightCB)
self.positionInIntervalCB = qt.QCheckBox("Force position in interval", self)
- self.positionInIntervalCB.setToolTip(
- "Fit must position peak within X limits")
+ self.positionInIntervalCB.setToolTip("Fit must position peak within X limits")
layout.addWidget(self.positionInIntervalCB)
self.positiveFwhmCB = qt.QCheckBox("Force positive FWHM", self)
@@ -226,7 +230,8 @@ class ConstraintsPage(qt.QGroupBox):
self.quotedEtaCB = qt.QCheckBox("Force Eta between 0 and 1", self)
self.quotedEtaCB.setToolTip(
- "Fit must find Eta between 0 and 1 for pseudo-Voigt function")
+ "Fit must find Eta between 0 and 1 for pseudo-Voigt function"
+ )
layout.addWidget(self.quotedEtaCB)
layout.addStretch()
@@ -241,29 +246,27 @@ class ConstraintsPage(qt.QGroupBox):
if default_dict is None:
default_dict = {}
# this one uses reverse logic: if checked, NoConstraintsFlag must be False
- self.setChecked(
- not default_dict.get('NoConstraintsFlag', False))
+ self.setChecked(not default_dict.get("NoConstraintsFlag", False))
self.positiveHeightCB.setChecked(
- default_dict.get('PositiveHeightAreaFlag', True))
+ default_dict.get("PositiveHeightAreaFlag", True)
+ )
self.positionInIntervalCB.setChecked(
- default_dict.get('QuotedPositionFlag', False))
- self.positiveFwhmCB.setChecked(
- default_dict.get('PositiveFwhmFlag', True))
- self.sameFwhmCB.setChecked(
- default_dict.get('SameFwhmFlag', False))
- self.quotedEtaCB.setChecked(
- default_dict.get('QuotedEtaFlag', False))
+ default_dict.get("QuotedPositionFlag", False)
+ )
+ self.positiveFwhmCB.setChecked(default_dict.get("PositiveFwhmFlag", True))
+ self.sameFwhmCB.setChecked(default_dict.get("SameFwhmFlag", False))
+ self.quotedEtaCB.setChecked(default_dict.get("QuotedEtaFlag", False))
def get(self):
"""Return a dictionary of constraint flags, to be processed by the
:meth:`configure` method of the selected fit theory."""
ddict = {
- 'NoConstraintsFlag': not self.isChecked(),
- 'PositiveHeightAreaFlag': self.positiveHeightCB.isChecked(),
- 'QuotedPositionFlag': self.positionInIntervalCB.isChecked(),
- 'PositiveFwhmFlag': self.positiveFwhmCB.isChecked(),
- 'SameFwhmFlag': self.sameFwhmCB.isChecked(),
- 'QuotedEtaFlag': self.quotedEtaCB.isChecked(),
+ "NoConstraintsFlag": not self.isChecked(),
+ "PositiveHeightAreaFlag": self.positiveHeightCB.isChecked(),
+ "QuotedPositionFlag": self.positionInIntervalCB.isChecked(),
+ "PositiveFwhmFlag": self.positiveFwhmCB.isChecked(),
+ "SameFwhmFlag": self.sameFwhmCB.isChecked(),
+ "QuotedEtaFlag": self.quotedEtaCB.isChecked(),
}
return ddict
@@ -276,8 +279,9 @@ class SearchPage(qt.QWidget):
self.manualFwhmGB = qt.QGroupBox("Define FWHM manually", self)
self.manualFwhmGB.setCheckable(True)
self.manualFwhmGB.setToolTip(
- "If disabled, the FWHM parameter used for peak search is " +
- "estimated based on the highest peak in the data")
+ "If disabled, the FWHM parameter used for peak search is "
+ + "estimated based on the highest peak in the data"
+ )
layout.addWidget(self.manualFwhmGB)
# ------------ GroupBox fwhm--------------------------
layout2 = qt.QHBoxLayout(self.manualFwhmGB)
@@ -295,8 +299,9 @@ class SearchPage(qt.QWidget):
self.manualScalingGB = qt.QGroupBox("Define scaling manually", self)
self.manualScalingGB.setCheckable(True)
self.manualScalingGB.setToolTip(
- "If disabled, the Y scaling used for peak search is " +
- "estimated automatically")
+ "If disabled, the Y scaling used for peak search is "
+ + "estimated automatically"
+ )
layout.addWidget(self.manualScalingGB)
# ------------ GroupBox scaling-----------------------
layout3 = qt.QHBoxLayout(self.manualScalingGB)
@@ -307,8 +312,8 @@ class SearchPage(qt.QWidget):
self.yScalingEntry = qt.QLineEdit(self.manualScalingGB)
self.yScalingEntry.setToolTip(
- "Data values will be multiplied by this value prior to peak" +
- " search")
+ "Data values will be multiplied by this value prior to peak" + " search"
+ )
self.yScalingEntry.setValidator(qt.QDoubleValidator(self))
layout3.addWidget(self.yScalingEntry)
# ----------------------------------------------------
@@ -323,9 +328,10 @@ class SearchPage(qt.QWidget):
self.sensitivityEntry = qt.QLineEdit(containerWidget)
self.sensitivityEntry.setToolTip(
- "Peak search sensitivity threshold, expressed as a multiple " +
- "of the standard deviation of the noise.\nMinimum value is 1 " +
- "(to be detected, peak must be higher than the estimated noise)")
+ "Peak search sensitivity threshold, expressed as a multiple "
+ + "of the standard deviation of the noise.\nMinimum value is 1 "
+ + "(to be detected, peak must be higher than the estimated noise)"
+ )
sensivalidator = qt.QDoubleValidator(self)
sensivalidator.setBottom(1.0)
self.sensitivityEntry.setValidator(sensivalidator)
@@ -335,8 +341,9 @@ class SearchPage(qt.QWidget):
self.forcePeakPresenceCB = qt.QCheckBox("Force peak presence", self)
self.forcePeakPresenceCB.setToolTip(
- "If peak search algorithm is unsuccessful, place one peak " +
- "at the maximum of the curve")
+ "If peak search algorithm is unsuccessful, place one peak "
+ + "at the maximum of the curve"
+ )
layout.addWidget(self.forcePeakPresenceCB)
layout.addStretch()
@@ -350,29 +357,25 @@ class SearchPage(qt.QWidget):
a parameter, its values are used as default values."""
if default_dict is None:
default_dict = {}
- self.manualFwhmGB.setChecked(
- not default_dict.get('AutoFwhm', True))
- self.fwhmPointsSpin.setValue(
- default_dict.get('FwhmPoints', 8))
- self.sensitivityEntry.setText(
- str(default_dict.get('Sensitivity', 1.0)))
- self.manualScalingGB.setChecked(
- not default_dict.get('AutoScaling', False))
- self.yScalingEntry.setText(
- str(default_dict.get('Yscaling', 1.0)))
+ self.manualFwhmGB.setChecked(not default_dict.get("AutoFwhm", True))
+ self.fwhmPointsSpin.setValue(default_dict.get("FwhmPoints", 8))
+ self.sensitivityEntry.setText(str(default_dict.get("Sensitivity", 1.0)))
+ self.manualScalingGB.setChecked(not default_dict.get("AutoScaling", False))
+ self.yScalingEntry.setText(str(default_dict.get("Yscaling", 1.0)))
self.forcePeakPresenceCB.setChecked(
- default_dict.get('ForcePeakPresence', False))
+ default_dict.get("ForcePeakPresence", False)
+ )
def get(self):
"""Return a dictionary of peak search parameters, to be processed by
the :meth:`configure` method of the selected fit theory."""
ddict = {
- 'AutoFwhm': not self.manualFwhmGB.isChecked(),
- 'FwhmPoints': self.fwhmPointsSpin.value(),
- 'Sensitivity': safe_float(self.sensitivityEntry.text()),
- 'AutoScaling': not self.manualScalingGB.isChecked(),
- 'Yscaling': safe_float(self.yScalingEntry.text()),
- 'ForcePeakPresence': self.forcePeakPresenceCB.isChecked()
+ "AutoFwhm": not self.manualFwhmGB.isChecked(),
+ "FwhmPoints": self.fwhmPointsSpin.value(),
+ "Sensitivity": safe_float(self.sensitivityEntry.text()),
+ "AutoScaling": not self.manualScalingGB.isChecked(),
+ "Yscaling": safe_float(self.yScalingEntry.text()),
+ "ForcePeakPresence": self.forcePeakPresenceCB.isChecked(),
}
return ddict
@@ -380,60 +383,69 @@ class SearchPage(qt.QWidget):
class BackgroundPage(qt.QGroupBox):
"""Background subtraction configuration, specific to fittheories
estimation functions."""
- def __init__(self, parent=None,
- title="Subtract strip background prior to estimation"):
+
+ def __init__(
+ self, parent=None, title="Subtract strip background prior to estimation"
+ ):
super(BackgroundPage, self).__init__(parent)
self.setTitle(title)
self.setCheckable(True)
self.setToolTip(
- "The strip algorithm strips away peaks to compute the " +
- "background signal.\nAt each iteration, a sample is compared " +
- "to the average of the two samples at a given distance in both" +
- " directions,\n and if its value is higher than the average,"
- "it is replaced by the average.")
+ "The strip algorithm strips away peaks to compute the "
+ + "background signal.\nAt each iteration, a sample is compared "
+ + "to the average of the two samples at a given distance in both"
+ + " directions,\n and if its value is higher than the average,"
+ "it is replaced by the average."
+ )
layout = qt.QGridLayout(self)
self.setLayout(layout)
for i, label_text in enumerate(
- ["Strip width (in samples)",
- "Number of iterations",
- "Strip threshold factor"]):
+ [
+ "Strip width (in samples)",
+ "Number of iterations",
+ "Strip threshold factor",
+ ]
+ ):
label = qt.QLabel(label_text)
layout.addWidget(label, i, 0)
self.stripWidthSpin = qt.QSpinBox(self)
self.stripWidthSpin.setToolTip(
- "Width, in number of samples, of the strip operator")
+ "Width, in number of samples, of the strip operator"
+ )
self.stripWidthSpin.setRange(1, 999999)
layout.addWidget(self.stripWidthSpin, 0, 1)
self.numIterationsSpin = qt.QSpinBox(self)
- self.numIterationsSpin.setToolTip(
- "Number of iterations of the strip algorithm")
+ self.numIterationsSpin.setToolTip("Number of iterations of the strip algorithm")
self.numIterationsSpin.setRange(1, 999999)
layout.addWidget(self.numIterationsSpin, 1, 1)
self.thresholdFactorEntry = qt.QLineEdit(self)
self.thresholdFactorEntry.setToolTip(
- "Factor used by the strip algorithm to decide whether a sample" +
- "value should be stripped.\nThe value must be higher than the " +
- "average of the 2 samples at +- w times this factor.\n")
+ "Factor used by the strip algorithm to decide whether a sample"
+ + "value should be stripped.\nThe value must be higher than the "
+ + "average of the 2 samples at +- w times this factor.\n"
+ )
self.thresholdFactorEntry.setValidator(qt.QDoubleValidator(self))
layout.addWidget(self.thresholdFactorEntry, 2, 1)
self.smoothStripGB = qt.QGroupBox("Apply smoothing prior to strip", self)
self.smoothStripGB.setCheckable(True)
self.smoothStripGB.setToolTip(
- "Apply a smoothing before subtracting strip background" +
- " in fit and estimate processes")
+ "Apply a smoothing before subtracting strip background"
+ + " in fit and estimate processes"
+ )
smoothlayout = qt.QHBoxLayout(self.smoothStripGB)
label = qt.QLabel("Smoothing width (Savitsky-Golay)")
smoothlayout.addWidget(label)
self.smoothingWidthSpin = qt.QSpinBox(self)
self.smoothingWidthSpin.setToolTip(
- "Width parameter for Savitsky-Golay smoothing (number of samples, must be odd)")
+ "Width parameter for Savitsky-Golay smoothing (number of samples, must be odd)"
+ )
self.smoothingWidthSpin.setRange(3, 101)
self.smoothingWidthSpin.setSingleStep(2)
smoothlayout.addWidget(self.smoothingWidthSpin)
@@ -452,31 +464,25 @@ class BackgroundPage(qt.QGroupBox):
if default_dict is None:
default_dict = {}
- self.setChecked(
- default_dict.get('StripBackgroundFlag', True))
+ self.setChecked(default_dict.get("StripBackgroundFlag", True))
- self.stripWidthSpin.setValue(
- default_dict.get('StripWidth', 2))
- self.numIterationsSpin.setValue(
- default_dict.get('StripIterations', 5000))
- self.thresholdFactorEntry.setText(
- str(default_dict.get('StripThreshold', 1.0)))
- self.smoothStripGB.setChecked(
- default_dict.get('SmoothingFlag', False))
- self.smoothingWidthSpin.setValue(
- default_dict.get('SmoothingWidth', 3))
+ self.stripWidthSpin.setValue(default_dict.get("StripWidth", 2))
+ self.numIterationsSpin.setValue(default_dict.get("StripIterations", 5000))
+ self.thresholdFactorEntry.setText(str(default_dict.get("StripThreshold", 1.0)))
+ self.smoothStripGB.setChecked(default_dict.get("SmoothingFlag", False))
+ self.smoothingWidthSpin.setValue(default_dict.get("SmoothingWidth", 3))
def get(self):
"""Return a dictionary of background subtraction parameters, to be
processed by the :meth:`configure` method of the selected fit theory.
"""
ddict = {
- 'StripBackgroundFlag': self.isChecked(),
- 'StripWidth': self.stripWidthSpin.value(),
- 'StripIterations': self.numIterationsSpin.value(),
- 'StripThreshold': safe_float(self.thresholdFactorEntry.text()),
- 'SmoothingFlag': self.smoothStripGB.isChecked(),
- 'SmoothingWidth': self.smoothingWidthSpin.value()
+ "StripBackgroundFlag": self.isChecked(),
+ "StripWidth": self.stripWidthSpin.value(),
+ "StripIterations": self.numIterationsSpin.value(),
+ "StripThreshold": safe_float(self.thresholdFactorEntry.text()),
+ "SmoothingFlag": self.smoothStripGB.isChecked(),
+ "SmoothingWidth": self.smoothingWidthSpin.value(),
}
return ddict
@@ -538,5 +544,6 @@ def main():
a.exec()
+
if __name__ == "__main__":
main()
diff --git a/src/silx/gui/fit/FitWidget.py b/src/silx/gui/fit/FitWidget.py
index 88f95cf..2487c23 100644
--- a/src/silx/gui/fit/FitWidget.py
+++ b/src/silx/gui/fit/FitWidget.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+# Copyright (c) 2004-2023 European Synchrotron Radiation Facility
#
# This file is part of the PyMca X-ray Fluorescence Toolkit developed at
# the ESRF by the Software group.
@@ -46,11 +46,14 @@ import traceback
from silx.math.fit import fittheories
from silx.math.fit import fitmanager, functions
from silx.gui import qt
-from .FitWidgets import (FitActionsButtons, FitStatusLines,
- FitConfigWidget, ParametersTab)
+from .FitWidgets import (
+ FitActionsButtons,
+ FitStatusLines,
+ FitConfigWidget,
+ ParametersTab,
+)
from .FitConfig import getFitConfigDialog
from .BackgroundWidget import getBgDialog, BackgroundDialog
-from ...utils.deprecation import deprecated
DEBUG = 0
_logger = logging.getLogger(__name__)
@@ -88,6 +91,7 @@ class FitWidget(qt.QWidget):
.. image:: img/FitWidget.png
"""
+
sigFitWidgetSignal = qt.Signal(object)
"""This signal is emitted by the estimation and fit methods.
It carries a dictionary with two items:
@@ -105,8 +109,15 @@ class FitWidget(qt.QWidget):
:attr:`silx.math.fit.fitmanager.FitManager.fit_results`)
"""
- def __init__(self, parent=None, title=None, fitmngr=None,
- enableconfig=True, enablestatus=True, enablebuttons=True):
+ def __init__(
+ self,
+ parent=None,
+ title=None,
+ fitmngr=None,
+ enableconfig=True,
+ enablestatus=True,
+ enablebuttons=True,
+ ):
"""
:param parent: Parent widget
@@ -199,15 +210,16 @@ class FitWidget(qt.QWidget):
"""Function selector and configuration widget"""
self.guiConfig.FunConfigureButton.clicked.connect(
- self.__funConfigureGuiSlot)
- self.guiConfig.BgConfigureButton.clicked.connect(
- self.__bgConfigureGuiSlot)
+ self.__funConfigureGuiSlot
+ )
+ self.guiConfig.BgConfigureButton.clicked.connect(self.__bgConfigureGuiSlot)
self.guiConfig.WeightCheckBox.setChecked(
- self.fitconfig.get("WeightFlag", False))
+ self.fitconfig.get("WeightFlag", False)
+ )
self.guiConfig.WeightCheckBox.stateChanged[int].connect(self.weightEvent)
- if qt.BINDING in ('PySide2', 'PyQt5'):
+ if qt.BINDING == "PyQt5":
self.guiConfig.BkgComBox.activated[str].connect(self.bkgEvent)
self.guiConfig.FunComBox.activated[str].connect(self.funEvent)
else: # Qt6
@@ -262,21 +274,21 @@ class FitWidget(qt.QWidget):
# associate silx.gui.fit.FitConfig with all theories
# Users can later associate their own custom dialogs to
# replace the default.
- configdialog = getFitConfigDialog(parent=self,
- default=self.fitconfig)
+ configdialog = getFitConfigDialog(parent=self, default=self.fitconfig)
for theory in self.fitmanager.theories:
self.associateConfigDialog(theory, configdialog)
for bgtheory in self.fitmanager.bgtheories:
- self.associateConfigDialog(bgtheory, configdialog,
- theory_is_background=True)
+ self.associateConfigDialog(
+ bgtheory, configdialog, theory_is_background=True
+ )
# associate silx.gui.fit.BackgroundWidget with Strip and Snip
- bgdialog = getBgDialog(parent=self,
- default=self.fitconfig)
+ bgdialog = getBgDialog(parent=self, default=self.fitconfig)
for bgtheory in ["Strip", "Snip"]:
if bgtheory in self.fitmanager.bgtheories:
- self.associateConfigDialog(bgtheory, bgdialog,
- theory_is_background=True)
+ self.associateConfigDialog(
+ bgtheory, bgdialog, theory_is_background=True
+ )
def _populateFunctions(self):
"""Fill combo-boxes with fit theories and background theories
@@ -286,16 +298,18 @@ class FitWidget(qt.QWidget):
for theory_name in self.fitmanager.bgtheories:
self.guiConfig.BkgComBox.addItem(theory_name)
self.guiConfig.BkgComBox.setItemData(
- self.guiConfig.BkgComBox.findText(theory_name),
- self.fitmanager.bgtheories[theory_name].description,
- qt.Qt.ToolTipRole)
+ self.guiConfig.BkgComBox.findText(theory_name),
+ self.fitmanager.bgtheories[theory_name].description,
+ qt.Qt.ToolTipRole,
+ )
for theory_name in self.fitmanager.theories:
self.guiConfig.FunComBox.addItem(theory_name)
self.guiConfig.FunComBox.setItemData(
- self.guiConfig.FunComBox.findText(theory_name),
- self.fitmanager.theories[theory_name].description,
- qt.Qt.ToolTipRole)
+ self.guiConfig.FunComBox.findText(theory_name),
+ self.fitmanager.theories[theory_name].description,
+ qt.Qt.ToolTipRole,
+ )
# - activate selected fit theory (if any)
# - activate selected bg theory (if any)
@@ -319,10 +333,6 @@ class FitWidget(qt.QWidget):
configuration.update(self.configure())
- @deprecated(replacement='setData', since_version='0.3.0')
- def setdata(self, x, y, sigmay=None, xmin=None, xmax=None):
- self.setData(x, y, sigmay, xmin, xmax)
-
def setData(self, x=None, y=None, sigmay=None, xmin=None, xmax=None):
"""Set data to be fitted.
@@ -345,14 +355,14 @@ class FitWidget(qt.QWidget):
else:
self.guibuttons.EstimateButton.setEnabled(True)
self.guibuttons.StartFitButton.setEnabled(True)
- self.fitmanager.setdata(x=x, y=y, sigmay=sigmay,
- xmin=xmin, xmax=xmax)
+ self.fitmanager.setdata(x=x, y=y, sigmay=sigmay, xmin=xmin, xmax=xmax)
for config_dialog in self.bgconfigdialogs.values():
if isinstance(config_dialog, BackgroundDialog):
config_dialog.setData(x, y, xmin=xmin, xmax=xmax)
- def associateConfigDialog(self, theory_name, config_widget,
- theory_is_background=False):
+ def associateConfigDialog(
+ self, theory_name, config_widget, theory_is_background=False
+ ):
"""Associate an instance of custom configuration dialog widget to
a fit theory or to a background theory.
@@ -372,23 +382,30 @@ class FitWidget(qt.QWidget):
methods (*show*, *exec*, *result*, *setDefault*) or the mandatory
attribute (*output*).
"""
- theories = self.fitmanager.bgtheories if theory_is_background else\
- self.fitmanager.theories
+ theories = (
+ self.fitmanager.bgtheories
+ if theory_is_background
+ else self.fitmanager.theories
+ )
if theory_name not in theories:
raise KeyError("%s does not match an existing fitmanager theory")
if config_widget is not None:
- if (not hasattr(config_widget, "exec") and
- not hasattr(config_widget, "exec_")):
+ if not hasattr(config_widget, "exec") and not hasattr(
+ config_widget, "exec_"
+ ):
raise AttributeError(
- "Custom configuration widget must define exec or exec_")
+ "Custom configuration widget must define exec or exec_"
+ )
for mandatory_attr in ["show", "result", "output"]:
if not hasattr(config_widget, mandatory_attr):
raise AttributeError(
- "Custom configuration widget must define " +
- "attribute or method " + mandatory_attr)
+ "Custom configuration widget must define "
+ + "attribute or method "
+ + mandatory_attr
+ )
if theory_is_background:
self.bgconfigdialogs[theory_name] = config_widget
@@ -427,25 +444,23 @@ class FitWidget(qt.QWidget):
configuration.update(self.configure(**newconfiguration))
# set fit function theory
try:
- i = 1 + \
- list(self.fitmanager.theories.keys()).index(
- self.fitmanager.selectedtheory)
+ i = 1 + list(self.fitmanager.theories.keys()).index(
+ self.fitmanager.selectedtheory
+ )
self.guiConfig.FunComBox.setCurrentIndex(i)
self.funEvent(self.fitmanager.selectedtheory)
except ValueError:
- _logger.error("Function not in list %s",
- self.fitmanager.selectedtheory)
+ _logger.error("Function not in list %s", self.fitmanager.selectedtheory)
self.funEvent(list(self.fitmanager.theories.keys())[0])
# current background
try:
- i = 1 + \
- list(self.fitmanager.bgtheories.keys()).index(
- self.fitmanager.selectedbg)
+ i = 1 + list(self.fitmanager.bgtheories.keys()).index(
+ self.fitmanager.selectedbg
+ )
self.guiConfig.BkgComBox.setCurrentIndex(i)
self.bkgEvent(self.fitmanager.selectedbg)
except ValueError:
- _logger.error("Background not in list %s",
- self.fitmanager.selectedbg)
+ _logger.error("Background not in list %s", self.fitmanager.selectedbg)
self.bkgEvent(list(self.fitmanager.bgtheories.keys())[0])
# update the Gui
@@ -509,8 +524,7 @@ class FitWidget(qt.QWidget):
theory_name = self.fitmanager.selectedtheory
estimation_function = self.fitmanager.theories[theory_name].estimate
if estimation_function is not None:
- ddict = {'event': 'EstimateStarted',
- 'data': None}
+ ddict = {"event": "EstimateStarted", "data": None}
self._emitSignal(ddict)
self.fitmanager.estimate(callback=self.fitStatus)
else:
@@ -520,34 +534,25 @@ class FitWidget(qt.QWidget):
text += "the initial parameters. Please, fill them\n"
text += "yourself in the table and press Start Fit\n"
msg.setText(text)
- msg.setWindowTitle('FitWidget Message')
+ msg.setWindowTitle("FitWidget Message")
msg.exec()
return
- except Exception as e: # noqa (we want to catch and report all errors)
- _logger.warning('Estimate error: %s', traceback.format_exc())
+ except Exception as e: # noqa (we want to catch and report all errors)
+ _logger.warning("Estimate error: %s", traceback.format_exc())
msg = qt.QMessageBox(self)
msg.setIcon(qt.QMessageBox.Critical)
msg.setWindowTitle("Estimate Error")
msg.setText("Error on estimate: %s" % e)
msg.exec()
- ddict = {
- 'event': 'EstimateFailed',
- 'data': None}
+ ddict = {"event": "EstimateFailed", "data": None}
self._emitSignal(ddict)
return
- self.guiParameters.fillFromFit(
- self.fitmanager.fit_results, view='Fit')
- self.guiParameters.removeAllViews(keep='Fit')
- ddict = {
- 'event': 'EstimateFinished',
- 'data': self.fitmanager.fit_results}
+ self.guiParameters.fillFromFit(self.fitmanager.fit_results, view="Fit")
+ self.guiParameters.removeAllViews(keep="Fit")
+ ddict = {"event": "EstimateFinished", "data": self.fitmanager.fit_results}
self._emitSignal(ddict)
- @deprecated(replacement='startFit', since_version='0.3.0')
- def startfit(self):
- self.startFit()
-
def startFit(self):
"""Run fit, then emit :attr:`sigFitWidgetSignal` with a dictionary
containing a status message and a list of fit
@@ -563,31 +568,23 @@ class FitWidget(qt.QWidget):
"""
self.fitmanager.fit_results = self.guiParameters.getFitResults()
try:
- ddict = {'event': 'FitStarted',
- 'data': None}
+ ddict = {"event": "FitStarted", "data": None}
self._emitSignal(ddict)
self.fitmanager.runfit(callback=self.fitStatus)
except Exception as e: # noqa (we want to catch and report all errors)
- _logger.warning('Estimate error: %s', traceback.format_exc())
+ _logger.warning("Estimate error: %s", traceback.format_exc())
msg = qt.QMessageBox(self)
msg.setIcon(qt.QMessageBox.Critical)
msg.setWindowTitle("Fit Error")
msg.setText("Error on Fit: %s" % e)
msg.exec()
- ddict = {
- 'event': 'FitFailed',
- 'data': None
- }
+ ddict = {"event": "FitFailed", "data": None}
self._emitSignal(ddict)
return
- self.guiParameters.fillFromFit(
- self.fitmanager.fit_results, view='Fit')
- self.guiParameters.removeAllViews(keep='Fit')
- ddict = {
- 'event': 'FitFinished',
- 'data': self.fitmanager.fit_results
- }
+ self.guiParameters.fillFromFit(self.fitmanager.fit_results, view="Fit")
+ self.guiParameters.removeAllViews(keep="Fit")
+ ddict = {"event": "FitFinished", "data": self.fitmanager.fit_results}
self._emitSignal(ddict)
return
@@ -598,15 +595,17 @@ class FitWidget(qt.QWidget):
self.fitmanager.setbackground(bgtheory)
else:
functionsfile = qt.QFileDialog.getOpenFileName(
- self, "Select python module with your function(s)", "",
- "Python Files (*.py);;All Files (*)")
+ self,
+ "Select python module with your function(s)",
+ "",
+ "Python Files (*.py);;All Files (*)",
+ )
if len(functionsfile):
try:
self.fitmanager.loadbgtheories(functionsfile)
except ImportError:
- qt.QMessageBox.critical(self, "ERROR",
- "Function not imported")
+ qt.QMessageBox.critical(self, "ERROR", "Function not imported")
return
else:
# empty the ComboBox
@@ -616,9 +615,9 @@ class FitWidget(qt.QWidget):
for key in self.fitmanager.bgtheories:
self.guiConfig.BkgComBox.addItem(str(key))
- i = 1 + \
- list(self.fitmanager.bgtheories.keys()).index(
- self.fitmanager.selectedbg)
+ i = 1 + list(self.fitmanager.bgtheories.keys()).index(
+ self.fitmanager.selectedbg
+ )
self.guiConfig.BkgComBox.setCurrentIndex(i)
self.__initialParameters()
@@ -637,15 +636,17 @@ class FitWidget(qt.QWidget):
else:
# open a load file dialog
functionsfile = qt.QFileDialog.getOpenFileName(
- self, "Select python module with your function(s)", "",
- "Python Files (*.py);;All Files (*)")
+ self,
+ "Select python module with your function(s)",
+ "",
+ "Python Files (*.py);;All Files (*)",
+ )
if len(functionsfile):
try:
self.fitmanager.loadtheories(functionsfile)
except ImportError:
- qt.QMessageBox.critical(self, "ERROR",
- "Function not imported")
+ qt.QMessageBox.critical(self, "ERROR", "Function not imported")
return
else:
# empty the ComboBox
@@ -655,9 +656,9 @@ class FitWidget(qt.QWidget):
for key in self.fitmanager.theories:
self.guiConfig.FunComBox.addItem(str(key))
- i = 1 + \
- list(self.fitmanager.theories.keys()).index(
- self.fitmanager.selectedtheory)
+ i = 1 + list(self.fitmanager.theories.keys()).index(
+ self.fitmanager.selectedtheory
+ )
self.guiConfig.FunComBox.setCurrentIndex(i)
self.__initialParameters()
@@ -682,45 +683,52 @@ class FitWidget(qt.QWidget):
self.fitmanager.fit_results = []
for pname in self.fitmanager.bgtheories[self.fitmanager.selectedbg].parameters:
self.fitmanager.parameter_names.append(pname)
- self.fitmanager.fit_results.append({'name': pname,
- 'estimation': 0,
- 'group': 0,
- 'code': 'FREE',
- 'cons1': 0,
- 'cons2': 0,
- 'fitresult': 0.0,
- 'sigma': 0.0,
- 'xmin': None,
- 'xmax': None})
+ self.fitmanager.fit_results.append(
+ {
+ "name": pname,
+ "estimation": 0,
+ "group": 0,
+ "code": "FREE",
+ "cons1": 0,
+ "cons2": 0,
+ "fitresult": 0.0,
+ "sigma": 0.0,
+ "xmin": None,
+ "xmax": None,
+ }
+ )
if self.fitmanager.selectedtheory is not None:
theory = self.fitmanager.selectedtheory
for pname in self.fitmanager.theories[theory].parameters:
self.fitmanager.parameter_names.append(pname + "1")
- self.fitmanager.fit_results.append({'name': pname + "1",
- 'estimation': 0,
- 'group': 1,
- 'code': 'FREE',
- 'cons1': 0,
- 'cons2': 0,
- 'fitresult': 0.0,
- 'sigma': 0.0,
- 'xmin': None,
- 'xmax': None})
-
- self.guiParameters.fillFromFit(
- self.fitmanager.fit_results, view='Fit')
+ self.fitmanager.fit_results.append(
+ {
+ "name": pname + "1",
+ "estimation": 0,
+ "group": 1,
+ "code": "FREE",
+ "cons1": 0,
+ "cons2": 0,
+ "fitresult": 0.0,
+ "sigma": 0.0,
+ "xmin": None,
+ "xmax": None,
+ }
+ )
+
+ self.guiParameters.fillFromFit(self.fitmanager.fit_results, view="Fit")
def fitStatus(self, data):
"""Set *status* and *chisq* in status bar"""
- if 'chisq' in data:
- if data['chisq'] is None:
+ if "chisq" in data:
+ if data["chisq"] is None:
self.guistatus.ChisqLine.setText(" ")
else:
- chisq = data['chisq']
+ chisq = data["chisq"]
self.guistatus.ChisqLine.setText("%6.2f" % chisq)
- if 'status' in data:
- status = data['status']
+ if "status" in data:
+ status = data["status"]
self.guistatus.StatusLine.setText(str(status))
def dismiss(self):
@@ -734,13 +742,29 @@ if __name__ == "__main__":
x = numpy.arange(1500).astype(numpy.float64)
constant_bg = 3.14
- p = [1000, 100., 30.0,
- 500, 300., 25.,
- 1700, 500., 35.,
- 750, 700., 30.0,
- 1234, 900., 29.5,
- 302, 1100., 30.5,
- 75, 1300., 21.]
+ p = [
+ 1000,
+ 100.0,
+ 30.0,
+ 500,
+ 300.0,
+ 25.0,
+ 1700,
+ 500.0,
+ 35.0,
+ 750,
+ 700.0,
+ 30.0,
+ 1234,
+ 900.0,
+ 29.5,
+ 302,
+ 1100.0,
+ 30.5,
+ 75,
+ 1300.0,
+ 21.0,
+ ]
y = functions.sum_gauss(x, *p) + constant_bg
a = qt.QApplication(sys.argv)
diff --git a/src/silx/gui/fit/FitWidgets.py b/src/silx/gui/fit/FitWidgets.py
index 7bcf28c..b7aef07 100644
--- a/src/silx/gui/fit/FitWidgets.py
+++ b/src/silx/gui/fit/FitWidgets.py
@@ -1,5 +1,5 @@
# /*##########################################################################
-# Copyright (C) 2004-2021 European Synchrotron Radiation Facility
+# Copyright (C) 2004-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -23,8 +23,6 @@
"""Collection of widgets used to build
:class:`silx.gui.fit.FitWidget.FitWidget`"""
-from collections import OrderedDict
-
from silx.gui import qt
from silx.gui.fit.Parameters import Parameters
@@ -69,17 +67,17 @@ class FitActionsButtons(qt.QWidget):
self.EstimateButton = qt.QPushButton(self)
self.EstimateButton.setText("Estimate")
layout.addWidget(self.EstimateButton)
- spacer = qt.QSpacerItem(20, 20,
- qt.QSizePolicy.Expanding,
- qt.QSizePolicy.Minimum)
+ spacer = qt.QSpacerItem(
+ 20, 20, qt.QSizePolicy.Expanding, qt.QSizePolicy.Minimum
+ )
layout.addItem(spacer)
self.StartFitButton = qt.QPushButton(self)
self.StartFitButton.setText("Start Fit")
layout.addWidget(self.StartFitButton)
- spacer_2 = qt.QSpacerItem(20, 20,
- qt.QSizePolicy.Expanding,
- qt.QSizePolicy.Minimum)
+ spacer_2 = qt.QSpacerItem(
+ 20, 20, qt.QSizePolicy.Expanding, qt.QSizePolicy.Minimum
+ )
layout.addItem(spacer_2)
self.DismissButton = qt.QPushButton(self)
@@ -148,6 +146,7 @@ class FitConfigWidget(qt.QWidget):
- open a dialog for modifying advanced parameters through
:attr:`FunConfigureButton`
"""
+
def __init__(self, parent=None):
qt.QWidget.__init__(self, parent)
@@ -163,9 +162,11 @@ class FitConfigWidget(qt.QWidget):
self.FunComBox = qt.QComboBox(self)
self.FunComBox.addItem("Add Function(s)")
- self.FunComBox.setItemData(self.FunComBox.findText("Add Function(s)"),
- "Load fit theories from a file",
- qt.Qt.ToolTipRole)
+ self.FunComBox.setItemData(
+ self.FunComBox.findText("Add Function(s)"),
+ "Load fit theories from a file",
+ qt.Qt.ToolTipRole,
+ )
layout.addWidget(self.FunComBox, 0, 1)
self.BkgLabel = qt.QLabel(self)
@@ -174,28 +175,33 @@ class FitConfigWidget(qt.QWidget):
self.BkgComBox = qt.QComboBox(self)
self.BkgComBox.addItem("Add Background(s)")
- self.BkgComBox.setItemData(self.BkgComBox.findText("Add Background(s)"),
- "Load background theories from a file",
- qt.Qt.ToolTipRole)
+ self.BkgComBox.setItemData(
+ self.BkgComBox.findText("Add Background(s)"),
+ "Load background theories from a file",
+ qt.Qt.ToolTipRole,
+ )
layout.addWidget(self.BkgComBox, 1, 1)
self.FunConfigureButton = qt.QPushButton(self)
self.FunConfigureButton.setText("Configure")
self.FunConfigureButton.setToolTip(
- "Open a configuration dialog for the selected function")
+ "Open a configuration dialog for the selected function"
+ )
layout.addWidget(self.FunConfigureButton, 0, 2)
self.BgConfigureButton = qt.QPushButton(self)
self.BgConfigureButton.setText("Configure")
self.BgConfigureButton.setToolTip(
- "Open a configuration dialog for the selected background")
+ "Open a configuration dialog for the selected background"
+ )
layout.addWidget(self.BgConfigureButton, 1, 2)
self.WeightCheckBox = qt.QCheckBox(self)
self.WeightCheckBox.setText("Weighted fit")
self.WeightCheckBox.setToolTip(
- "Enable usage of weights in the least-square problem.\n Use" +
- " the uncertainties (sigma) if provided, else use sqrt(y).")
+ "Enable usage of weights in the least-square problem.\n Use"
+ + " the uncertainties (sigma) if provided, else use sqrt(y)."
+ )
layout.addWidget(self.WeightCheckBox, 0, 3, 2, 1)
@@ -281,7 +287,7 @@ class ParametersTab(qt.QTabWidget):
self.setWindowTitle(name)
self.setContentsMargins(0, 0, 0, 0)
- self.views = OrderedDict()
+ self.views = {}
"""Dictionary of views. Keys are view names,
items are :class:`Parameters` widgets"""
@@ -310,8 +316,8 @@ class ParametersTab(qt.QTabWidget):
view = self.latest_view
else:
raise KeyError(
- "No view available. You must specify a view" +
- " name the first time you call this method."
+ "No view available. You must specify a view"
+ + " name the first time you call this method."
)
if view in self.tables.keys():
@@ -403,7 +409,7 @@ class ParametersTab(qt.QTabWidget):
text += "<tr>"
ncols = table.columnCount()
for l in range(ncols):
- text += ('<td align="left" bgcolor="%s"><b>' % hcolor)
+ text += '<td align="left" bgcolor="%s"><b>' % hcolor
text += str(table.horizontalHeaderItem(l).text())
text += "</b></td>"
text += "</tr>"
@@ -437,11 +443,9 @@ class ParametersTab(qt.QTabWidget):
else:
finalcolor = "white"
if c < 2:
- text += ('<td align="left" bgcolor="%s">%s' %
- (finalcolor, b))
+ text += '<td align="left" bgcolor="%s">%s' % (finalcolor, b)
else:
- text += ('<td align="right" bgcolor="%s">%s' %
- (finalcolor, b))
+ text += '<td align="right" bgcolor="%s">%s' % (finalcolor, b)
text += newtext
if len(b):
text += "</td>"
@@ -505,14 +509,18 @@ def test():
fit = fitmanager.FitManager(x=x, y=y1)
fitfuns = fittheories.FitTheories()
- fit.addtheory(name="Gaussian",
- function=functions.sum_gauss,
- parameters=("height", "peak center", "fwhm"),
- estimate=fitfuns.estimate_height_position_fwhm)
- fit.settheory('Gaussian')
- fit.configure(PositiveFwhmFlag=True,
- PositiveHeightAreaFlag=True,
- AutoFwhm=True,)
+ fit.addtheory(
+ name="Gaussian",
+ function=functions.sum_gauss,
+ parameters=("height", "peak center", "fwhm"),
+ estimate=fitfuns.estimate_height_position_fwhm,
+ )
+ fit.settheory("Gaussian")
+ fit.configure(
+ PositiveFwhmFlag=True,
+ PositiveHeightAreaFlag=True,
+ AutoFwhm=True,
+ )
# Fit
fit.estimate()
@@ -520,26 +528,27 @@ def test():
w = ParametersTab()
w.show()
- w.fillFromFit(fit.fit_results, view='Gaussians')
+ w.fillFromFit(fit.fit_results, view="Gaussians")
- y2 = functions.sum_splitgauss(x,
- 100, 400, 100, 40,
- 10, 600, 50, 500,
- 80, 850, 10, 50)
+ y2 = functions.sum_splitgauss(
+ x, 100, 400, 100, 40, 10, 600, 50, 500, 80, 850, 10, 50
+ )
fit.setdata(x=x, y=y2)
# Define new theory
- fit.addtheory(name="Asymetric gaussian",
- function=functions.sum_splitgauss,
- parameters=("height", "peak center", "left fwhm", "right fwhm"),
- estimate=fitfuns.estimate_splitgauss)
- fit.settheory('Asymetric gaussian')
+ fit.addtheory(
+ name="Asymetric gaussian",
+ function=functions.sum_splitgauss,
+ parameters=("height", "peak center", "left fwhm", "right fwhm"),
+ estimate=fitfuns.estimate_splitgauss,
+ )
+ fit.settheory("Asymetric gaussian")
# Fit
fit.estimate()
fit.runfit()
- w.fillFromFit(fit.fit_results, view='Asymetric gaussians')
+ w.fillFromFit(fit.fit_results, view="Asymetric gaussians")
# Plot
pw = PlotWindow(control=True)
diff --git a/src/silx/gui/fit/Parameters.py b/src/silx/gui/fit/Parameters.py
index e9601a8..bd2605e 100644
--- a/src/silx/gui/fit/Parameters.py
+++ b/src/silx/gui/fit/Parameters.py
@@ -1,5 +1,5 @@
# /*##########################################################################
-# Copyright (C) 2004-2021 European Synchrotron Radiation Facility
+# Copyright (C) 2004-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -27,7 +27,6 @@ __license__ = "MIT"
__date__ = "25/11/2016"
import sys
-from collections import OrderedDict
from silx.gui import qt
from silx.gui.widgets.TableWidget import TableWidget
@@ -55,6 +54,7 @@ class QComboTableItem(qt.QComboBox):
:param row: Row number of the table cell containing this widget
:param col: Column number of the table cell containing this widget"""
+
sigCellChanged = qt.Signal(int, int)
"""Signal emitted when this ``QComboBox`` is activated.
A ``(row, column)`` tuple is passed."""
@@ -78,6 +78,7 @@ class QCheckBoxItem(qt.QCheckBox):
:param row: Row number of the table cell containing this widget
:param col: Column number of the table cell containing this widget"""
+
sigCellChanged = qt.Signal(int, int)
"""Signal emitted when this ``QCheckBox`` is clicked.
A ``(row, column)`` tuple is passed."""
@@ -106,22 +107,39 @@ class Parameters(TableWidget):
peak.
:type paramlist: list[str] or None
"""
+
def __init__(self, parent=None, paramlist=None):
TableWidget.__init__(self, parent)
self.setContentsMargins(0, 0, 0, 0)
- labels = ['Parameter', 'Estimation', 'Fit Value', 'Sigma',
- 'Constraints', 'Min/Parame', 'Max/Factor/Delta']
- tooltips = ["Fit parameter name",
- "Estimated value for fit parameter. You can edit this column.",
- "Actual value for parameter, after fit",
- "Uncertainty (same unit as the parameter)",
- "Constraint to be applied to the parameter for fit",
- "First parameter for constraint (name of another param or min value)",
- "Second parameter for constraint (max value, or factor/delta)"]
-
- self.columnKeys = ['name', 'estimation', 'fitresult',
- 'sigma', 'code', 'val1', 'val2']
+ labels = [
+ "Parameter",
+ "Estimation",
+ "Fit Value",
+ "Sigma",
+ "Constraints",
+ "Min/Parame",
+ "Max/Factor/Delta",
+ ]
+ tooltips = [
+ "Fit parameter name",
+ "Estimated value for fit parameter. You can edit this column.",
+ "Actual value for parameter, after fit",
+ "Uncertainty (same unit as the parameter)",
+ "Constraint to be applied to the parameter for fit",
+ "First parameter for constraint (name of another param or min value)",
+ "Second parameter for constraint (max value, or factor/delta)",
+ ]
+
+ self.columnKeys = [
+ "name",
+ "estimation",
+ "fitresult",
+ "sigma",
+ "code",
+ "val1",
+ "val2",
+ ]
"""This list assigns shorter keys to refer to columns than the
displayed labels."""
@@ -133,8 +151,7 @@ class Parameters(TableWidget):
for i, label in enumerate(labels):
item = self.horizontalHeaderItem(i)
if item is None:
- item = qt.QTableWidgetItem(label,
- qt.QTableWidgetItem.Type)
+ item = qt.QTableWidgetItem(label, qt.QTableWidgetItem.Type)
self.setHorizontalHeaderItem(i, item)
item.setText(label)
@@ -148,7 +165,7 @@ class Parameters(TableWidget):
# Initialize the table with one line per supplied parameter
paramlist = paramlist if paramlist is not None else []
- self.parameters = OrderedDict()
+ self.parameters = {}
"""This attribute stores all the data in an ordered dictionary.
New data can be added using :meth:`newParameterLine`.
Existing data can be modified using :meth:`configureLine`
@@ -184,8 +201,17 @@ class Parameters(TableWidget):
for line, param in enumerate(paramlist):
self.newParameterLine(param, line)
- self.code_options = ["FREE", "POSITIVE", "QUOTED", "FIXED",
- "FACTOR", "DELTA", "SUM", "IGNORE", "ADD"]
+ self.code_options = [
+ "FREE",
+ "POSITIVE",
+ "QUOTED",
+ "FIXED",
+ "FACTOR",
+ "DELTA",
+ "SUM",
+ "IGNORE",
+ "ADD",
+ ]
"""Possible values in the combo boxes in the 'Constraints' column.
"""
@@ -210,43 +236,46 @@ class Parameters(TableWidget):
self.setRowCount(line + 1)
# default configuration for fit parameters
- self.parameters[param] = OrderedDict((('line', line),
- ('estimation', '0'),
- ('fitresult', ''),
- ('sigma', ''),
- ('code', 'FREE'),
- ('val1', ''),
- ('val2', ''),
- ('cons1', 0),
- ('cons2', 0),
- ('vmin', '0'),
- ('vmax', '1'),
- ('relatedto', ''),
- ('factor', '1.0'),
- ('delta', '0.0'),
- ('sum', '0.0'),
- ('group', ''),
- ('name', param),
- ('xmin', None),
- ('xmax', None)))
- self.setReadWrite(param, 'estimation')
- self.setReadOnly(param, ['name', 'fitresult', 'sigma', 'val1', 'val2'])
+ self.parameters[param] = dict(
+ (
+ ("line", line),
+ ("estimation", "0"),
+ ("fitresult", ""),
+ ("sigma", ""),
+ ("code", "FREE"),
+ ("val1", ""),
+ ("val2", ""),
+ ("cons1", 0),
+ ("cons2", 0),
+ ("vmin", "0"),
+ ("vmax", "1"),
+ ("relatedto", ""),
+ ("factor", "1.0"),
+ ("delta", "0.0"),
+ ("sum", "0.0"),
+ ("group", ""),
+ ("name", param),
+ ("xmin", None),
+ ("xmax", None),
+ )
+ )
+ self.setReadWrite(param, "estimation")
+ self.setReadOnly(param, ["name", "fitresult", "sigma", "val1", "val2"])
# Constraint codes
a = []
for option in self.code_options:
a.append(option)
- code_column_index = self.columnIndexByField('code')
+ code_column_index = self.columnIndexByField("code")
cellWidget = self.cellWidget(line, code_column_index)
if cellWidget is None:
- cellWidget = QComboTableItem(self, row=line,
- col=code_column_index)
+ cellWidget = QComboTableItem(self, row=line, col=code_column_index)
cellWidget.addItems(a)
self.setCellWidget(line, code_column_index, cellWidget)
cellWidget.sigCellChanged[int, int].connect(self.onCellChanged)
- self.parameters[param]['code_item'] = cellWidget
- self.parameters[param]['relatedto_item'] = None
+ self.parameters[param]["code_item"] = cellWidget
+ self.parameters[param]["relatedto_item"] = None
self.__configuring = False
def columnIndexByField(self, field):
@@ -268,44 +297,48 @@ class Parameters(TableWidget):
self.setRowCount(len(fitresults))
# Reinitialize and fill self.parameters
- self.parameters = OrderedDict()
- for (line, param) in enumerate(fitresults):
- self.newParameterLine(param['name'], line)
+ self.parameters = {}
+ for line, param in enumerate(fitresults):
+ self.newParameterLine(param["name"], line)
for param in fitresults:
- name = param['name']
- code = str(param['code'])
+ name = param["name"]
+ code = str(param["code"])
if code not in self.code_options:
# convert code from int to descriptive string
code = self.code_options[int(code)]
- val1 = param['cons1']
- val2 = param['cons2']
- estimation = param['estimation']
- group = param['group']
- sigma = param['sigma']
- fitresult = param['fitresult']
-
- xmin = param.get('xmin')
- xmax = param.get('xmax')
-
- self.configureLine(name=name,
- code=code,
- val1=val1, val2=val2,
- estimation=estimation,
- fitresult=fitresult,
- sigma=sigma,
- group=group,
- xmin=xmin, xmax=xmax)
+ val1 = param["cons1"]
+ val2 = param["cons2"]
+ estimation = param["estimation"]
+ group = param["group"]
+ sigma = param["sigma"]
+ fitresult = param["fitresult"]
+
+ xmin = param.get("xmin")
+ xmax = param.get("xmax")
+
+ self.configureLine(
+ name=name,
+ code=code,
+ val1=val1,
+ val2=val2,
+ estimation=estimation,
+ fitresult=fitresult,
+ sigma=sigma,
+ group=group,
+ xmin=xmin,
+ xmax=xmax,
+ )
def getConfiguration(self):
"""Return ``FitManager.paramlist`` dictionary
encapsulated in another dictionary"""
- return {'parameters': self.getFitResults()}
+ return {"parameters": self.getFitResults()}
def setConfiguration(self, ddict):
"""Fill table with values from a ``FitManager.paramlist`` dictionary
encapsulated in another dictionary"""
- self.fillFromFit(ddict['parameters'])
+ self.fillFromFit(ddict["parameters"])
def getFitResults(self):
"""Return fit parameters as a list of dictionaries in the format used
@@ -316,33 +349,33 @@ class Parameters(TableWidget):
fitparam = {}
name = param
estimation, [code, cons1, cons2] = self.getEstimationConstraints(name)
- buf = str(self.parameters[param]['fitresult'])
- xmin = self.parameters[param]['xmin']
- xmax = self.parameters[param]['xmax']
+ buf = str(self.parameters[param]["fitresult"])
+ xmin = self.parameters[param]["xmin"]
+ xmax = self.parameters[param]["xmax"]
if len(buf):
fitresult = float(buf)
else:
fitresult = 0.0
- buf = str(self.parameters[param]['sigma'])
+ buf = str(self.parameters[param]["sigma"])
if len(buf):
sigma = float(buf)
else:
sigma = 0.0
- buf = str(self.parameters[param]['group'])
+ buf = str(self.parameters[param]["group"])
if len(buf):
group = float(buf)
else:
group = 0
- fitparam['name'] = name
- fitparam['estimation'] = estimation
- fitparam['fitresult'] = fitresult
- fitparam['sigma'] = sigma
- fitparam['group'] = group
- fitparam['code'] = code
- fitparam['cons1'] = cons1
- fitparam['cons2'] = cons2
- fitparam['xmin'] = xmin
- fitparam['xmax'] = xmax
+ fitparam["name"] = name
+ fitparam["estimation"] = estimation
+ fitparam["fitresult"] = fitresult
+ fitparam["sigma"] = sigma
+ fitparam["group"] = group
+ fitparam["code"] = code
+ fitparam["cons1"] = cons1
+ fitparam["cons2"] = cons2
+ fitparam["xmin"] = xmin
+ fitparam["xmax"] = xmax
fitparameterslist.append(fitparam)
return fitparameterslist
@@ -370,7 +403,7 @@ class Parameters(TableWidget):
if item is not None:
newvalue = item.text()
else:
- newvalue = ''
+ newvalue = ""
else:
# this is the combobox
widget = self.cellWidget(row, col)
@@ -379,12 +412,12 @@ class Parameters(TableWidget):
paramdict = {"name": param, field: newvalue}
self.configureLine(**paramdict)
else:
- if field == 'code':
+ if field == "code":
# New code not valid, try restoring the old one
index = self.code_options.index(oldvalue)
self.__configuring = True
try:
- self.parameters[param]['code_item'].setCurrentIndex(index)
+ self.parameters[param]["code_item"].setCurrentIndex(index)
finally:
self.__configuring = False
else:
@@ -400,10 +433,14 @@ class Parameters(TableWidget):
:param newvalue: New value to be validated
:return: True if new cell value is valid, else False
"""
- if field == 'code':
+ if field == "code":
return self.setCodeValue(param, oldvalue, newvalue)
# FIXME: validate() shouldn't have side effects. Move this bit to configureLine()?
- if field == 'val1' and str(self.parameters[param]['code']) in ['DELTA', 'FACTOR', 'SUM']:
+ if field == "val1" and str(self.parameters[param]["code"]) in [
+ "DELTA",
+ "FACTOR",
+ "SUM",
+ ]:
_, candidates = self.getRelatedCandidates(param)
# We expect val1 to be a fit parameter name
if str(newvalue) in candidates:
@@ -429,52 +466,48 @@ class Parameters(TableWidget):
:return: ``True`` if code was successfully updated
"""
- if str(newvalue) in ['FREE', 'POSITIVE', 'QUOTED', 'FIXED']:
- self.configureLine(name=param,
- code=newvalue)
- if str(oldvalue) == 'IGNORE':
+ if str(newvalue) in ["FREE", "POSITIVE", "QUOTED", "FIXED"]:
+ self.configureLine(name=param, code=newvalue)
+ if str(oldvalue) == "IGNORE":
self.freeRestOfGroup(param)
return True
- elif str(newvalue) in ['FACTOR', 'DELTA', 'SUM']:
+ elif str(newvalue) in ["FACTOR", "DELTA", "SUM"]:
# I should check here that some parameter is set
best, candidates = self.getRelatedCandidates(param)
if len(candidates) == 0:
return False
- self.configureLine(name=param,
- code=newvalue,
- relatedto=best)
- if str(oldvalue) == 'IGNORE':
+ self.configureLine(name=param, code=newvalue, relatedto=best)
+ if str(oldvalue) == "IGNORE":
self.freeRestOfGroup(param)
return True
- elif str(newvalue) == 'IGNORE':
+ elif str(newvalue) == "IGNORE":
# I should check if the group can be ignored
# for the time being I just fix all of them to ignore
- group = int(float(str(self.parameters[param]['group'])))
+ group = int(float(str(self.parameters[param]["group"])))
candidates = []
for param in self.parameters.keys():
- if group == int(float(str(self.parameters[param]['group']))):
+ if group == int(float(str(self.parameters[param]["group"]))):
candidates.append(param)
# print candidates
# I should check here if there is any relation to them
for param in candidates:
- self.configureLine(name=param,
- code=newvalue)
+ self.configureLine(name=param, code=newvalue)
return True
- elif str(newvalue) == 'ADD':
- group = int(float(str(self.parameters[param]['group'])))
+ elif str(newvalue) == "ADD":
+ group = int(float(str(self.parameters[param]["group"])))
if group == 0:
# One cannot add a background group
return False
i = 0
for param in self.parameters:
- if i <= int(float(str(self.parameters[param]['group']))):
+ if i <= int(float(str(self.parameters[param]["group"]))):
i += 1
- if (group == 0) and (i == 1): # FIXME: why +1?
+ if (group == 0) and (i == 1): # FIXME: why +1?
i += 1
self.addGroup(i, group)
return False
- elif str(newvalue) == 'SHOW':
+ elif str(newvalue) == "SHOW":
print(self.getEstimationConstraints(param))
return False
@@ -492,14 +525,14 @@ class Parameters(TableWidget):
newparam = []
# loop through parameters until we encounter group number `gtype`
for param in list(self.parameters):
- paramgroup = int(float(str(self.parameters[param]['group'])))
+ paramgroup = int(float(str(self.parameters[param]["group"])))
# copy parameter names in group number `gtype`
if paramgroup == gtype:
# but replace `gtype` with `newg`
newparam.append(param.rstrip("0123456789") + "%d" % newg)
- xmin = self.parameters[param]['xmin']
- xmax = self.parameters[param]['xmax']
+ xmin = self.parameters[param]["xmin"]
+ xmax = self.parameters[param]["xmax"]
# Add new parameters (one table line per parameter) and configureLine each
# one by updating xmin and xmax to the same values as group `gtype`
@@ -519,16 +552,14 @@ class Parameters(TableWidget):
:param workparam: Fit parameter name
"""
if workparam in self.parameters.keys():
- group = int(float(str(self.parameters[workparam]['group'])))
+ group = int(float(str(self.parameters[workparam]["group"])))
for param in self.parameters:
- if param != workparam and\
- group == int(float(str(self.parameters[param]['group']))):
- self.configureLine(name=param,
- code='FREE',
- cons1=0,
- cons2=0,
- val1='',
- val2='')
+ if param != workparam and group == int(
+ float(str(self.parameters[param]["group"]))
+ ):
+ self.configureLine(
+ name=param, code="FREE", cons1=0, cons2=0, val1="", val2=""
+ )
def getRelatedCandidates(self, workparam):
"""If fit parameter ``workparam`` has a constraint that involves other
@@ -543,12 +574,16 @@ class Parameters(TableWidget):
for param_name in self.parameters:
if param_name != workparam:
# ignore parameters that are fixed by a constraint
- if str(self.parameters[param_name]['code']) not in\
- ['IGNORE', 'FACTOR', 'DELTA', 'SUM']:
+ if str(self.parameters[param_name]["code"]) not in [
+ "IGNORE",
+ "FACTOR",
+ "DELTA",
+ "SUM",
+ ]:
candidates.append(param_name)
# take the previous one (before code cell changed) if possible
- if str(self.parameters[workparam]['relatedto']) in candidates:
- best = str(self.parameters[workparam]['relatedto'])
+ if str(self.parameters[workparam]["relatedto"]) in candidates:
+ best = str(self.parameters[workparam]["relatedto"])
return best, candidates
# take the first with same base name (after removing numbers)
for param_name in candidates:
@@ -584,9 +619,7 @@ class Parameters(TableWidget):
:param fields: Field names identifying the columns
:type fields: str or list[str]
"""
- editflags = qt.Qt.ItemIsSelectable |\
- qt.Qt.ItemIsEnabled |\
- qt.Qt.ItemIsEditable
+ editflags = qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled | qt.Qt.ItemIsEditable
self.setField(parameter, fields, editflags)
def setField(self, parameter, fields, edit_flags):
@@ -601,13 +634,11 @@ class Parameters(TableWidget):
qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled |
qt.Qt.ItemIsEditable
"""
- if isinstance(parameter, list) or \
- isinstance(parameter, tuple):
+ if isinstance(parameter, list) or isinstance(parameter, tuple):
paramlist = parameter
else:
paramlist = [parameter]
- if isinstance(fields, list) or \
- isinstance(fields, tuple):
+ if isinstance(fields, list) or isinstance(fields, tuple):
fieldlist = fields
else:
fieldlist = [fields]
@@ -623,7 +654,7 @@ class Parameters(TableWidget):
row = list(self.parameters.keys()).index(param)
for field in fieldlist:
col = self.columnIndexByField(field)
- if field != 'code':
+ if field != "code":
key = field + "_item"
item = self.item(row, col)
if item is None:
@@ -638,10 +669,22 @@ class Parameters(TableWidget):
# Restore previous _configuring flag
self.__configuring = _oldvalue
- def configureLine(self, name, code=None, val1=None, val2=None,
- sigma=None, estimation=None, fitresult=None,
- group=None, xmin=None, xmax=None, relatedto=None,
- cons1=None, cons2=None):
+ def configureLine(
+ self,
+ name,
+ code=None,
+ val1=None,
+ val2=None,
+ sigma=None,
+ estimation=None,
+ fitresult=None,
+ group=None,
+ xmin=None,
+ xmax=None,
+ relatedto=None,
+ cons1=None,
+ cons2=None,
+ ):
"""This function updates values in a line of the table
:param name: Name of the parameter (serves as unique identifier for
@@ -675,73 +718,88 @@ class Parameters(TableWidget):
# update code first, if specified
if code is not None:
code = str(code)
- self.parameters[name]['code'] = code
+ self.parameters[name]["code"] = code
# update combobox
- index = self.parameters[name]['code_item'].findText(code)
- self.parameters[name]['code_item'].setCurrentIndex(index)
+ index = self.parameters[name]["code_item"].findText(code)
+ self.parameters[name]["code_item"].setCurrentIndex(index)
else:
# set code to previous value, used later for setting val1 val2
- code = self.parameters[name]['code']
+ code = self.parameters[name]["code"]
# val1 and sigma have special formats
if val1 is not None:
- fmt = None if self.parameters[name]['code'] in\
- ['DELTA', 'FACTOR', 'SUM'] else "%8g"
+ fmt = (
+ None
+ if self.parameters[name]["code"] in ["DELTA", "FACTOR", "SUM"]
+ else "%8g"
+ )
self._updateField(name, "val1", val1, fmat=fmt)
if sigma is not None:
self._updateField(name, "sigma", sigma, fmat="%6.3g")
# other fields are formatted as "%8g"
- keys_params = (("val2", val2), ("estimation", estimation),
- ("fitresult", fitresult))
+ keys_params = (
+ ("val2", val2),
+ ("estimation", estimation),
+ ("fitresult", fitresult),
+ )
for key, value in keys_params:
if value is not None:
self._updateField(name, key, value, fmat="%8g")
# the rest of the parameters are treated as strings and don't need
# validation
- keys_params = (("group", group), ("xmin", xmin),
- ("xmax", xmax), ("relatedto", relatedto),
- ("cons1", cons1), ("cons2", cons2))
+ keys_params = (
+ ("group", group),
+ ("xmin", xmin),
+ ("xmax", xmax),
+ ("relatedto", relatedto),
+ ("cons1", cons1),
+ ("cons2", cons2),
+ )
for key, value in keys_params:
if value is not None:
self.parameters[name][key] = str(value)
# val1 and val2 have different meanings depending on the code
- if code == 'QUOTED':
+ if code == "QUOTED":
if val1 is not None:
- self.parameters[name]['vmin'] = self.parameters[name]['val1']
+ self.parameters[name]["vmin"] = self.parameters[name]["val1"]
else:
- self.parameters[name]['val1'] = self.parameters[name]['vmin']
+ self.parameters[name]["val1"] = self.parameters[name]["vmin"]
if val2 is not None:
- self.parameters[name]['vmax'] = self.parameters[name]['val2']
+ self.parameters[name]["vmax"] = self.parameters[name]["val2"]
else:
- self.parameters[name]['val2'] = self.parameters[name]['vmax']
+ self.parameters[name]["val2"] = self.parameters[name]["vmax"]
# cons1 and cons2 are scalar representations of val1 and val2
- self.parameters[name]['cons1'] =\
- float_else_zero(self.parameters[name]['val1'])
- self.parameters[name]['cons2'] =\
- float_else_zero(self.parameters[name]['val2'])
+ self.parameters[name]["cons1"] = float_else_zero(
+ self.parameters[name]["val1"]
+ )
+ self.parameters[name]["cons2"] = float_else_zero(
+ self.parameters[name]["val2"]
+ )
# cons1, cons2 = min(val1, val2), max(val1, val2)
- if self.parameters[name]['cons1'] > self.parameters[name]['cons2']:
- self.parameters[name]['cons1'], self.parameters[name]['cons2'] =\
- self.parameters[name]['cons2'], self.parameters[name]['cons1']
+ if self.parameters[name]["cons1"] > self.parameters[name]["cons2"]:
+ self.parameters[name]["cons1"], self.parameters[name]["cons2"] = (
+ self.parameters[name]["cons2"],
+ self.parameters[name]["cons1"],
+ )
- elif code in ['DELTA', 'SUM', 'FACTOR']:
+ elif code in ["DELTA", "SUM", "FACTOR"]:
# For these codes, val1 is the fit parameter name on which the
# constraint depends
if val1 is not None and val1 in paramlist:
- self.parameters[name]['relatedto'] = self.parameters[name]["val1"]
+ self.parameters[name]["relatedto"] = self.parameters[name]["val1"]
elif val1 is not None:
# val1 could be the index of the fit parameter
try:
- self.parameters[name]['relatedto'] = paramlist[int(val1)]
+ self.parameters[name]["relatedto"] = paramlist[int(val1)]
except ValueError:
- self.parameters[name]['relatedto'] = self.parameters[name]["val1"]
+ self.parameters[name]["relatedto"] = self.parameters[name]["val1"]
elif relatedto is not None:
# code changed, val1 not specified but relatedto specified:
@@ -753,25 +811,27 @@ class Parameters(TableWidget):
self.parameters[name][key] = self.parameters[name]["val2"]
# FIXME: val1 is sometimes specified as an index rather than a param name
- self.parameters[name]['val1'] = self.parameters[name]['relatedto']
+ self.parameters[name]["val1"] = self.parameters[name]["relatedto"]
# cons1 is the index of the fit parameter in the ordered dictionary
- if self.parameters[name]['val1'] in paramlist:
- self.parameters[name]['cons1'] =\
- paramlist.index(self.parameters[name]['val1'])
+ if self.parameters[name]["val1"] in paramlist:
+ self.parameters[name]["cons1"] = paramlist.index(
+ self.parameters[name]["val1"]
+ )
# cons2 is the constraint value (factor, delta or sum)
try:
- self.parameters[name]['cons2'] =\
- float(str(self.parameters[name]['val2']))
+ self.parameters[name]["cons2"] = float(
+ str(self.parameters[name]["val2"])
+ )
except ValueError:
- self.parameters[name]['cons2'] = 1.0 if code == "FACTOR" else 0.0
+ self.parameters[name]["cons2"] = 1.0 if code == "FACTOR" else 0.0
- elif code in ['FREE', 'POSITIVE', 'IGNORE', 'FIXED']:
- self.parameters[name]['val1'] = ""
- self.parameters[name]['val2'] = ""
- self.parameters[name]['cons1'] = 0
- self.parameters[name]['cons2'] = 0
+ elif code in ["FREE", "POSITIVE", "IGNORE", "FIXED"]:
+ self.parameters[name]["val1"] = ""
+ self.parameters[name]["val2"] = ""
+ self.parameters[name]["cons1"] = 0
+ self.parameters[name]["cons2"] = 0
self._updateCellRWFlags(name, code)
@@ -793,9 +853,9 @@ class Parameters(TableWidget):
newvalue = fmat % float(value) if value != "" else ""
else:
newvalue = value
- self.parameters[name][field] = newvalue if\
- self.validate(name, field, oldvalue, newvalue) else\
- oldvalue
+ self.parameters[name][field] = (
+ newvalue if self.validate(name, field, oldvalue, newvalue) else oldvalue
+ )
def _updateCellRWFlags(self, name, code=None):
"""Set read-only or read-write flags in a row,
@@ -806,12 +866,12 @@ class Parameters(TableWidget):
`'FIXED', 'FACTOR', 'DELTA', 'SUM', 'ADD'`
:return:
"""
- if code in ['FREE', 'POSITIVE', 'IGNORE', 'FIXED']:
- self.setReadWrite(name, 'estimation')
- self.setReadOnly(name, ['fitresult', 'sigma', 'val1', 'val2'])
+ if code in ["FREE", "POSITIVE", "IGNORE", "FIXED"]:
+ self.setReadWrite(name, "estimation")
+ self.setReadOnly(name, ["fitresult", "sigma", "val1", "val2"])
else:
- self.setReadWrite(name, ['estimation', 'val1', 'val2'])
- self.setReadOnly(name, ['fitresult', 'sigma'])
+ self.setReadWrite(name, ["estimation", "val1", "val2"])
+ self.setReadOnly(name, ["fitresult", "sigma"])
def getEstimationConstraints(self, param):
"""
@@ -822,18 +882,17 @@ class Parameters(TableWidget):
estimation = None
constraints = None
if param in self.parameters.keys():
- buf = str(self.parameters[param]['estimation'])
+ buf = str(self.parameters[param]["estimation"])
if len(buf):
estimation = float(buf)
else:
estimation = 0
- if str(self.parameters[param]['code']) in self.code_options:
- code = self.code_options.index(
- str(self.parameters[param]['code']))
+ if str(self.parameters[param]["code"]) in self.code_options:
+ code = self.code_options.index(str(self.parameters[param]["code"]))
else:
- code = str(self.parameters[param]['code'])
- cons1 = self.parameters[param]['cons1']
- cons2 = self.parameters[param]['cons2']
+ code = str(self.parameters[param]["code"])
+ cons1 = self.parameters[param]["cons1"]
+ cons2 = self.parameters[param]["cons2"]
constraints = [code, cons1, cons2]
return estimation, constraints
@@ -841,21 +900,24 @@ class Parameters(TableWidget):
def main(args):
from silx.math.fit import fittheories
from silx.math.fit import fitmanager
+
try:
from PyMca5 import PyMcaDataDir
except ImportError:
raise ImportError("This demo requires PyMca data. Install PyMca5.")
import numpy
import os
+
app = qt.QApplication(args)
- tab = Parameters(paramlist=['Height', 'Position', 'FWHM'])
+ tab = Parameters(paramlist=["Height", "Position", "FWHM"])
tab.showGrid()
- tab.configureLine(name='Height', estimation='1234', group=0)
- tab.configureLine(name='Position', code='FIXED', group=1)
- tab.configureLine(name='FWHM', group=1)
+ tab.configureLine(name="Height", estimation="1234", group=0)
+ tab.configureLine(name="Position", code="FIXED", group=1)
+ tab.configureLine(name="FWHM", group=1)
- y = numpy.loadtxt(os.path.join(PyMcaDataDir.PYMCA_DATA_DIR,
- "XRFSpectrum.mca")) # FIXME
+ y = numpy.loadtxt(
+ os.path.join(PyMcaDataDir.PYMCA_DATA_DIR, "XRFSpectrum.mca")
+ ) # FIXME
x = numpy.arange(len(y)) * 0.0502883 - 0.492773
fit = fitmanager.FitManager()
@@ -863,19 +925,22 @@ def main(args):
fit.loadtheories(fittheories)
- fit.settheory('ahypermet')
- fit.configure(Yscaling=1.,
- PositiveFwhmFlag=True,
- PositiveHeightAreaFlag=True,
- FwhmPoints=16,
- QuotedPositionFlag=1,
- HypermetTails=1)
- fit.setbackground('Linear')
+ fit.settheory("ahypermet")
+ fit.configure(
+ Yscaling=1.0,
+ PositiveFwhmFlag=True,
+ PositiveHeightAreaFlag=True,
+ FwhmPoints=16,
+ QuotedPositionFlag=1,
+ HypermetTails=1,
+ )
+ fit.setbackground("Linear")
fit.estimate()
fit.runfit()
tab.fillFromFit(fit.fit_results)
tab.show()
app.exec()
+
if __name__ == "__main__":
main(sys.argv)
diff --git a/src/silx/gui/fit/test/testBackgroundWidget.py b/src/silx/gui/fit/test/testBackgroundWidget.py
index 353d3d5..73e3fba 100644
--- a/src/silx/gui/fit/test/testBackgroundWidget.py
+++ b/src/silx/gui/fit/test/testBackgroundWidget.py
@@ -21,8 +21,6 @@
# THE SOFTWARE.
#
# ###########################################################################*/
-import unittest
-
from silx.gui.utils.testutils import TestCaseQt
from .. import BackgroundWidget
@@ -36,8 +34,7 @@ class TestBackgroundWidget(TestCaseQt):
def setUp(self):
super(TestBackgroundWidget, self).setUp()
self.bgdialog = BackgroundWidget.BackgroundDialog()
- self.bgdialog.setData(list([0, 1, 2, 3]),
- list([0, 1, 4, 8]))
+ self.bgdialog.setData(list([0, 1, 2, 3]), list([0, 1, 4, 8]))
self.qWaitForWindowExposed(self.bgdialog)
def tearDown(self):
@@ -60,9 +57,17 @@ class TestBackgroundWidget(TestCaseQt):
self.bgdialog.accept()
output = self.bgdialog.output
- for key in ["algorithm", "StripThreshold", "SnipWidth",
- "StripIterations", "StripWidth", "SmoothingFlag",
- "SmoothingWidth", "AnchorsFlag", "AnchorsList"]:
+ for key in [
+ "algorithm",
+ "StripThreshold",
+ "SnipWidth",
+ "StripIterations",
+ "StripWidth",
+ "SmoothingFlag",
+ "SmoothingWidth",
+ "AnchorsFlag",
+ "AnchorsList",
+ ]:
self.assertIn(key, output)
self.assertFalse(output["AnchorsFlag"])
diff --git a/src/silx/gui/fit/test/testFitConfig.py b/src/silx/gui/fit/test/testFitConfig.py
index 114ff62..d59562c 100644
--- a/src/silx/gui/fit/test/testFitConfig.py
+++ b/src/silx/gui/fit/test/testFitConfig.py
@@ -27,8 +27,6 @@ __authors__ = ["P. Knobel"]
__license__ = "MIT"
__date__ = "05/12/2016"
-import unittest
-
from silx.gui.utils.testutils import TestCaseQt
from .. import FitConfig
@@ -61,22 +59,24 @@ class TestFitConfig(TestCaseQt):
self.fit_config.accept()
output = self.fit_config.output
- for key in ["AutoFwhm",
- "PositiveHeightAreaFlag",
- "QuotedPositionFlag",
- "PositiveFwhmFlag",
- "SameFwhmFlag",
- "QuotedEtaFlag",
- "NoConstraintsFlag",
- "FwhmPoints",
- "Sensitivity",
- "Yscaling",
- "ForcePeakPresence",
- "StripBackgroundFlag",
- "StripWidth",
- "StripIterations",
- "StripThreshold",
- "SmoothingFlag"]:
+ for key in [
+ "AutoFwhm",
+ "PositiveHeightAreaFlag",
+ "QuotedPositionFlag",
+ "PositiveFwhmFlag",
+ "SameFwhmFlag",
+ "QuotedEtaFlag",
+ "NoConstraintsFlag",
+ "FwhmPoints",
+ "Sensitivity",
+ "Yscaling",
+ "ForcePeakPresence",
+ "StripBackgroundFlag",
+ "StripWidth",
+ "StripIterations",
+ "StripThreshold",
+ "SmoothingFlag",
+ ]:
self.assertIn(key, output)
self.assertTrue(output["AutoFwhm"])
diff --git a/src/silx/gui/fit/test/testFitWidget.py b/src/silx/gui/fit/test/testFitWidget.py
index fe61268..e59fa92 100644
--- a/src/silx/gui/fit/test/testFitWidget.py
+++ b/src/silx/gui/fit/test/testFitWidget.py
@@ -23,8 +23,6 @@
# ###########################################################################*/
"""Basic tests for :class:`FitWidget`"""
-import unittest
-
from silx.gui.utils.testutils import TestCaseQt
from ... import qt
@@ -82,13 +80,9 @@ class TestFitWidget(TestCaseQt):
y = [fitfun(x_, 2, 3) for x_ in x]
def conf(**kw):
- return {"spam": "eggs",
- "hello": "world!"}
+ return {"spam": "eggs", "hello": "world!"}
- theory = FitTheory(
- function=fitfun,
- parameters=["a", "b"],
- configure=conf)
+ theory = FitTheory(function=fitfun, parameters=["a", "b"], configure=conf)
fitmngr = FitManager()
fitmngr.setdata(x, y)
@@ -97,8 +91,9 @@ class TestFitWidget(TestCaseQt):
fitmngr.addbgtheory("spam", theory)
fw = FitWidget(fitmngr=fitmngr)
- fw.associateConfigDialog("spam", CustomConfigWidget(),
- theory_is_background=True)
+ fw.associateConfigDialog(
+ "spam", CustomConfigWidget(), theory_is_background=True
+ )
fw.associateConfigDialog("foo", CustomConfigWidget())
fw.show()
self.qWaitForWindowExposed(fw)
@@ -106,8 +101,7 @@ class TestFitWidget(TestCaseQt):
fw.bgconfigdialogs["spam"].accept()
self.assertTrue(fw.bgconfigdialogs["spam"].result())
- self.assertEqual(fw.bgconfigdialogs["spam"].output,
- {"hello": "world"})
+ self.assertEqual(fw.bgconfigdialogs["spam"].output, {"hello": "world"})
fw.bgconfigdialogs["spam"].reject()
self.assertFalse(fw.bgconfigdialogs["spam"].result())
diff --git a/src/silx/gui/hdf5/Hdf5Formatter.py b/src/silx/gui/hdf5/Hdf5Formatter.py
index 4dbb0fc..99e0bb6 100644
--- a/src/silx/gui/hdf5/Hdf5Formatter.py
+++ b/src/silx/gui/hdf5/Hdf5Formatter.py
@@ -37,8 +37,7 @@ import h5py
class Hdf5Formatter(qt.QObject):
- """Formatter to convert HDF5 data to string.
- """
+ """Formatter to convert HDF5 data to string."""
formatChanged = qt.Signal()
"""Emitted when properties of the formatter change."""
@@ -87,7 +86,7 @@ class Hdf5Formatter(qt.QObject):
if dataset.shape == tuple():
return "scalar"
shape = [str(i) for i in dataset.shape]
- text = u" \u00D7 ".join(shape)
+ text = " \u00D7 ".join(shape)
return text
def humanReadableValue(self, dataset):
@@ -162,7 +161,7 @@ class Hdf5Formatter(qt.QObject):
if enumType is not None:
return "enum"
- text = str(dtype.newbyteorder('N'))
+ text = str(dtype.newbyteorder("N"))
if numpy.issubdtype(dtype, numpy.floating):
if hasattr(numpy, "float128") and dtype == numpy.float128:
text = "float80"
@@ -181,7 +180,7 @@ class Hdf5Formatter(qt.QObject):
elif dtype.byteorder == "=":
text = "Native " + text
- dtype = dtype.newbyteorder('N')
+ dtype = dtype.newbyteorder("N")
return text
def humanReadableHdf5Type(self, dataset):
diff --git a/src/silx/gui/hdf5/Hdf5HeaderView.py b/src/silx/gui/hdf5/Hdf5HeaderView.py
index 6d306e5..16323dd 100644
--- a/src/silx/gui/hdf5/Hdf5HeaderView.py
+++ b/src/silx/gui/hdf5/Hdf5HeaderView.py
@@ -72,21 +72,49 @@ class Hdf5HeaderView(qt.QHeaderView):
def __updateAutoResize(self):
"""Update the view according to the state of the auto-resize"""
if self.__auto_resize:
- self.setSectionResizeMode(Hdf5TreeModel.NAME_COLUMN, qt.QHeaderView.ResizeToContents)
- self.setSectionResizeMode(Hdf5TreeModel.TYPE_COLUMN, qt.QHeaderView.ResizeToContents)
- self.setSectionResizeMode(Hdf5TreeModel.SHAPE_COLUMN, qt.QHeaderView.ResizeToContents)
- self.setSectionResizeMode(Hdf5TreeModel.VALUE_COLUMN, qt.QHeaderView.Interactive)
- self.setSectionResizeMode(Hdf5TreeModel.DESCRIPTION_COLUMN, qt.QHeaderView.Interactive)
- self.setSectionResizeMode(Hdf5TreeModel.NODE_COLUMN, qt.QHeaderView.ResizeToContents)
- self.setSectionResizeMode(Hdf5TreeModel.LINK_COLUMN, qt.QHeaderView.ResizeToContents)
+ self.setSectionResizeMode(
+ Hdf5TreeModel.NAME_COLUMN, qt.QHeaderView.ResizeToContents
+ )
+ self.setSectionResizeMode(
+ Hdf5TreeModel.TYPE_COLUMN, qt.QHeaderView.ResizeToContents
+ )
+ self.setSectionResizeMode(
+ Hdf5TreeModel.SHAPE_COLUMN, qt.QHeaderView.ResizeToContents
+ )
+ self.setSectionResizeMode(
+ Hdf5TreeModel.VALUE_COLUMN, qt.QHeaderView.Interactive
+ )
+ self.setSectionResizeMode(
+ Hdf5TreeModel.DESCRIPTION_COLUMN, qt.QHeaderView.Interactive
+ )
+ self.setSectionResizeMode(
+ Hdf5TreeModel.NODE_COLUMN, qt.QHeaderView.ResizeToContents
+ )
+ self.setSectionResizeMode(
+ Hdf5TreeModel.LINK_COLUMN, qt.QHeaderView.ResizeToContents
+ )
else:
- self.setSectionResizeMode(Hdf5TreeModel.NAME_COLUMN, qt.QHeaderView.Interactive)
- self.setSectionResizeMode(Hdf5TreeModel.TYPE_COLUMN, qt.QHeaderView.Interactive)
- self.setSectionResizeMode(Hdf5TreeModel.SHAPE_COLUMN, qt.QHeaderView.Interactive)
- self.setSectionResizeMode(Hdf5TreeModel.VALUE_COLUMN, qt.QHeaderView.Interactive)
- self.setSectionResizeMode(Hdf5TreeModel.DESCRIPTION_COLUMN, qt.QHeaderView.Interactive)
- self.setSectionResizeMode(Hdf5TreeModel.NODE_COLUMN, qt.QHeaderView.Interactive)
- self.setSectionResizeMode(Hdf5TreeModel.LINK_COLUMN, qt.QHeaderView.Interactive)
+ self.setSectionResizeMode(
+ Hdf5TreeModel.NAME_COLUMN, qt.QHeaderView.Interactive
+ )
+ self.setSectionResizeMode(
+ Hdf5TreeModel.TYPE_COLUMN, qt.QHeaderView.Interactive
+ )
+ self.setSectionResizeMode(
+ Hdf5TreeModel.SHAPE_COLUMN, qt.QHeaderView.Interactive
+ )
+ self.setSectionResizeMode(
+ Hdf5TreeModel.VALUE_COLUMN, qt.QHeaderView.Interactive
+ )
+ self.setSectionResizeMode(
+ Hdf5TreeModel.DESCRIPTION_COLUMN, qt.QHeaderView.Interactive
+ )
+ self.setSectionResizeMode(
+ Hdf5TreeModel.NODE_COLUMN, qt.QHeaderView.Interactive
+ )
+ self.setSectionResizeMode(
+ Hdf5TreeModel.LINK_COLUMN, qt.QHeaderView.Interactive
+ )
def setAutoResizeColumns(self, autoResize):
"""Enable/disable auto-resize. When auto-resized, the header take care
@@ -125,7 +153,9 @@ class Hdf5HeaderView(qt.QHeaderView):
"""
return self.__hide_columns_popup
- enableHideColumnsPopup = qt.Property(bool, hasHideColumnsPopup, setAutoResizeColumns)
+ enableHideColumnsPopup = qt.Property(
+ bool, hasHideColumnsPopup, setAutoResizeColumns
+ )
"""Property to enable/disable popup allowing to hide/show columns."""
def __genHideSectionEvent(self, column):
diff --git a/src/silx/gui/hdf5/Hdf5Item.py b/src/silx/gui/hdf5/Hdf5Item.py
index 8f20649..2777a94 100755
--- a/src/silx/gui/hdf5/Hdf5Item.py
+++ b/src/silx/gui/hdf5/Hdf5Item.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2016-2019 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -28,7 +28,6 @@ __date__ = "17/01/2019"
import logging
-import collections
import enum
from typing import Optional
@@ -39,6 +38,7 @@ from .Hdf5Node import Hdf5Node
import silx.io.utils
from silx.gui.data.TextFormatter import TextFormatter
from ..hdf5.Hdf5Formatter import Hdf5Formatter
+
_logger = logging.getLogger(__name__)
_formatter = TextFormatter()
_hdf5Formatter = Hdf5Formatter(textFormatter=_formatter)
@@ -46,8 +46,8 @@ _hdf5Formatter = Hdf5Formatter(textFormatter=_formatter)
class DescriptionType(enum.Enum):
- """List of available kind of description.
- """
+ """List of available kind of description."""
+
ERROR = "error"
DESCRIPTION = "description"
TITLE = "title"
@@ -210,9 +210,14 @@ class Hdf5Item(Hdf5Node):
class_ = silx.io.utils.get_h5_class(self.__obj)
if class_ == silx.io.utils.H5Type.EXTERNAL_LINK:
- message = "External link broken. Path %s::%s does not exist" % (self.__obj.filename, self.__obj.path)
+ message = "External link broken. Path %s::%s does not exist" % (
+ self.__obj.filename,
+ self.__obj.path,
+ )
elif class_ == silx.io.utils.H5Type.SOFT_LINK:
- message = "Soft link broken. Path %s does not exist" % (self.__obj.path)
+ message = "Soft link broken. Path %s does not exist" % (
+ self.__obj.path
+ )
else:
name = self.__obj.__class__.__name__.split(".")[-1].capitalize()
message = "%s broken" % (name)
@@ -220,7 +225,10 @@ class Hdf5Item(Hdf5Node):
self.__isBroken = True
else:
self.__obj = obj
- if silx.io.utils.get_h5_class(obj) not in [silx.io.utils.H5Type.GROUP, silx.io.utils.H5Type.FILE]:
+ if silx.io.utils.get_h5_class(obj) not in [
+ silx.io.utils.H5Type.GROUP,
+ silx.io.utils.H5Type.FILE,
+ ]:
try:
# pre-fetch of the data
if obj.shape is None:
@@ -257,7 +265,10 @@ class Hdf5Item(Hdf5Node):
keys.append(name)
except Exception:
lib_name = self.obj.__class__.__module__.split(".")[0]
- _logger.error("Internal %s error (second time). The file is corrupted.", lib_name)
+ _logger.error(
+ "Internal %s error (second time). The file is corrupted.",
+ lib_name,
+ )
_logger.debug("Backtrace", exc_info=True)
for name in keys:
try:
@@ -281,7 +292,14 @@ class Hdf5Item(Hdf5Node):
h5class = silx.io.utils.get_h5_class(class_=class_)
if h5class is None:
_logger.error("Class %s unsupported", class_)
- item = Hdf5Item(text=name, obj=None, parent=self, key=name, h5Class=h5class, linkClass=link)
+ item = Hdf5Item(
+ text=name,
+ obj=None,
+ parent=self,
+ key=name,
+ h5Class=h5class,
+ linkClass=link,
+ )
self.appendChild(item)
def hasChildren(self):
@@ -330,7 +348,7 @@ class Hdf5Item(Hdf5Node):
:param Dict[str,str] attributeDict: Key/value attributes
"""
- attributeDict = collections.OrderedDict()
+ attributeDict = {}
if self.h5Class == silx.io.utils.H5Type.DATASET:
attributeDict["#Title"] = "HDF5 Dataset"
@@ -338,7 +356,9 @@ class Hdf5Item(Hdf5Node):
attributeDict["Path"] = self.obj.name
attributeDict["Shape"] = self._getFormatter().humanReadableShape(self.obj)
attributeDict["Value"] = self._getFormatter().humanReadableValue(self.obj)
- attributeDict["Data type"] = self._getFormatter().humanReadableType(self.obj, full=True)
+ attributeDict["Data type"] = self._getFormatter().humanReadableType(
+ self.obj, full=True
+ )
elif self.h5Class == silx.io.utils.H5Type.GROUP:
attributeDict["#Title"] = "HDF5 Group"
if self.nexusClassName:
@@ -395,14 +415,18 @@ class Hdf5Item(Hdf5Node):
# Check NX_class formatting
lower = text.lower()
formatedNX_class = ""
- if lower.startswith('nx'):
- formatedNX_class = 'NX' + lower[2:]
- if lower == 'nxcansas':
- formatedNX_class = 'NXcanSAS' # That's the only class with capital letters...
+ if lower.startswith("nx"):
+ formatedNX_class = "NX" + lower[2:]
+ if lower == "nxcansas":
+ formatedNX_class = (
+ "NXcanSAS" # That's the only class with capital letters...
+ )
if text != formatedNX_class:
- _logger.error("NX_class: '%s' is malformed (should be '%s')",
- text,
- formatedNX_class)
+ _logger.error(
+ "NX_class: '%s' is malformed (should be '%s')",
+ text,
+ formatedNX_class,
+ )
text = formatedNX_class
self.__nx_class = text
@@ -469,59 +493,44 @@ class Hdf5Item(Hdf5Node):
return None
_NEXUS_CLASS_TO_VALUE_CHILDREN = {
- 'NXaperture': (
- (DescriptionType.DESCRIPTION, 'description'),
- ),
- 'NXbeam_stop': (
- (DescriptionType.DESCRIPTION, 'description'),
- ),
- 'NXdetector': (
- (DescriptionType.NAME, 'local_name'),
- (DescriptionType.DESCRIPTION, 'description')
- ),
- 'NXentry': (
- (DescriptionType.TITLE, 'title'),
- ),
- 'NXenvironment': (
- (DescriptionType.NAME, 'short_name'),
- (DescriptionType.NAME, 'name'),
- (DescriptionType.DESCRIPTION, 'description')
- ),
- 'NXinstrument': (
- (DescriptionType.NAME, 'name'),
- ),
- 'NXlog': (
- (DescriptionType.DESCRIPTION, 'description'),
- ),
- 'NXmirror': (
- (DescriptionType.DESCRIPTION, 'description'),
- ),
- 'NXpositioner': (
- (DescriptionType.NAME, 'name'),
+ "NXaperture": ((DescriptionType.DESCRIPTION, "description"),),
+ "NXbeam_stop": ((DescriptionType.DESCRIPTION, "description"),),
+ "NXdetector": (
+ (DescriptionType.NAME, "local_name"),
+ (DescriptionType.DESCRIPTION, "description"),
),
- 'NXprocess': (
- (DescriptionType.PROGRAM, 'program'),
+ "NXentry": ((DescriptionType.TITLE, "title"),),
+ "NXenvironment": (
+ (DescriptionType.NAME, "short_name"),
+ (DescriptionType.NAME, "name"),
+ (DescriptionType.DESCRIPTION, "description"),
),
- 'NXsample': (
- (DescriptionType.TITLE, 'short_title'),
- (DescriptionType.NAME, 'name'),
- (DescriptionType.DESCRIPTION, 'description')
+ "NXinstrument": ((DescriptionType.NAME, "name"),),
+ "NXlog": ((DescriptionType.DESCRIPTION, "description"),),
+ "NXmirror": ((DescriptionType.DESCRIPTION, "description"),),
+ "NXnote": ((DescriptionType.DESCRIPTION, "description"),),
+ "NXpositioner": ((DescriptionType.NAME, "name"),),
+ "NXprocess": ((DescriptionType.PROGRAM, "program"),),
+ "NXsample": (
+ (DescriptionType.TITLE, "short_title"),
+ (DescriptionType.NAME, "name"),
+ (DescriptionType.DESCRIPTION, "description"),
),
- 'NXsample_component': (
- (DescriptionType.NAME, 'name'),
- (DescriptionType.DESCRIPTION, 'description')
+ "NXsample_component": (
+ (DescriptionType.NAME, "name"),
+ (DescriptionType.DESCRIPTION, "description"),
),
- 'NXsensor': (
- (DescriptionType.NAME, 'short_name'),
- (DescriptionType.NAME, 'name')
+ "NXsensor": (
+ (DescriptionType.NAME, "short_name"),
+ (DescriptionType.NAME, "name"),
),
- 'NXsource': (
- (DescriptionType.NAME, 'name'),
+ "NXsource": (
+ (DescriptionType.NAME, "name"),
), # or its 'short_name' attribute... This is not supported
- 'NXsubentry': (
- (DescriptionType.DESCRIPTION, 'definition'),
- (DescriptionType.PROGRAM, 'program_name'),
- (DescriptionType.TITLE, 'title'),
+ "NXsubentry": (
+ (DescriptionType.DESCRIPTION, "definition"),
+ (DescriptionType.PROGRAM, "program_name"),
+ (DescriptionType.TITLE, "title"),
),
}
"""Mapping from NeXus class to child names containing data to use as value"""
@@ -536,19 +545,25 @@ class Hdf5Item(Hdf5Node):
return DescriptionType.ERROR, self.__error
if self.h5Class == silx.io.utils.H5Type.DATASET:
- return DescriptionType.VALUE, self._getFormatter().humanReadableValue(self.obj)
+ return DescriptionType.VALUE, self._getFormatter().humanReadableValue(
+ self.obj
+ )
elif self.isGroupObj() and self.nexusClassName:
# For NeXus groups, try to find a title or name
# By default, look for a title (most application definitions should have one)
- defaultSequence = ((DescriptionType.TITLE, 'title'),)
- sequence = self._NEXUS_CLASS_TO_VALUE_CHILDREN.get(self.nexusClassName, defaultSequence)
+ defaultSequence = ((DescriptionType.TITLE, "title"),)
+ sequence = self._NEXUS_CLASS_TO_VALUE_CHILDREN.get(
+ self.nexusClassName, defaultSequence
+ )
for kind, child_name in sequence:
for index in range(self.childCount()):
child = self.child(index)
- if (isinstance(child, Hdf5Item) and
- child.h5Class == silx.io.utils.H5Type.DATASET and
- child.basename == child_name):
+ if (
+ isinstance(child, Hdf5Item)
+ and child.h5Class == silx.io.utils.H5Type.DATASET
+ and child.basename == child_name
+ ):
return kind, self._getFormatter().humanReadableValue(child.obj)
description = self.obj.attrs.get("desc", None)
diff --git a/src/silx/gui/hdf5/Hdf5Node.py b/src/silx/gui/hdf5/Hdf5Node.py
index 0d58748..db49594 100644
--- a/src/silx/gui/hdf5/Hdf5Node.py
+++ b/src/silx/gui/hdf5/Hdf5Node.py
@@ -36,11 +36,12 @@ class Hdf5Node(object):
It provides link to the childs and to the parents, and a link to an
external object.
"""
+
def __init__(
self,
parent=None,
populateAll=False,
- openedPath: Optional[str]=None,
+ openedPath: Optional[str] = None,
):
"""
Constructor
diff --git a/src/silx/gui/hdf5/Hdf5TreeModel.py b/src/silx/gui/hdf5/Hdf5TreeModel.py
index 8ac800a..3353ab3 100644
--- a/src/silx/gui/hdf5/Hdf5TreeModel.py
+++ b/src/silx/gui/hdf5/Hdf5TreeModel.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -38,6 +38,7 @@ from .Hdf5Item import Hdf5Item
from .Hdf5LoadingItem import Hdf5LoadingItem
from . import _utils
from ... import io as silx_io
+from ...io._sliceh5 import DatasetSlice
import h5py
@@ -61,6 +62,8 @@ def _createRootLabel(h5obj):
if path.startswith("/"):
path = path[1:]
label = "%s::%s" % (filename, path)
+ if isinstance(h5obj, DatasetSlice):
+ label += str(list(h5obj.indices))
return label
@@ -69,7 +72,8 @@ class LoadingItemRunnable(qt.QRunnable):
class __Signals(qt.QObject):
"""Signal holder"""
- itemReady = qt.Signal(object, object, object)
+
+ itemReady = qt.Signal(object, object, object, str)
runnerFinished = qt.Signal(object)
def __init__(self, filename, item):
@@ -126,7 +130,7 @@ class LoadingItemRunnable(qt.QRunnable):
if h5file is not None:
h5file.close()
- self.itemReady.emit(self.oldItem, newItem, error)
+ self.itemReady.emit(self.oldItem, newItem, error, self.filename)
self.runnerFinished.emit(self)
def autoDelete(self):
@@ -181,7 +185,7 @@ class Hdf5TreeModel(qt.QAbstractItemModel):
]
"""List of logical columns available"""
- sigH5pyObjectLoaded = qt.Signal(object)
+ sigH5pyObjectLoaded = qt.Signal(object, str)
"""Emitted when a new root item was loaded and inserted to the model."""
sigH5pyObjectRemoved = qt.Signal(object)
@@ -201,13 +205,13 @@ class Hdf5TreeModel(qt.QAbstractItemModel):
super(Hdf5TreeModel, self).__init__(parent)
self.header_labels = [None] * len(self.COLUMN_IDS)
- self.header_labels[self.NAME_COLUMN] = 'Name'
- self.header_labels[self.TYPE_COLUMN] = 'Type'
- self.header_labels[self.SHAPE_COLUMN] = 'Shape'
- self.header_labels[self.VALUE_COLUMN] = 'Value'
- self.header_labels[self.DESCRIPTION_COLUMN] = 'Description'
- self.header_labels[self.NODE_COLUMN] = 'Node'
- self.header_labels[self.LINK_COLUMN] = 'Link'
+ self.header_labels[self.NAME_COLUMN] = "Name"
+ self.header_labels[self.TYPE_COLUMN] = "Type"
+ self.header_labels[self.SHAPE_COLUMN] = "Shape"
+ self.header_labels[self.VALUE_COLUMN] = "Value"
+ self.header_labels[self.DESCRIPTION_COLUMN] = "Description"
+ self.header_labels[self.NODE_COLUMN] = "Node"
+ self.header_labels[self.LINK_COLUMN] = "Link"
# Create items
self.__root = Hdf5Node()
@@ -247,7 +251,6 @@ class Hdf5TreeModel(qt.QAbstractItemModel):
"""Static method to close explicit references to internal objects."""
_logger.debug("Clear Hdf5TreeModel")
for obj in fileList:
- _logger.debug("Close file %s", obj.filename)
obj.close()
fileList[:] = []
@@ -266,14 +269,21 @@ class Hdf5TreeModel(qt.QAbstractItemModel):
index2 = self.index(i, self.columnCount() - 1, qt.QModelIndex())
self.dataChanged.emit(index1, index2)
- def __itemReady(self, oldItem, newItem, error):
+ def __itemReady(
+ self,
+ oldItem: Hdf5Node,
+ newItem: Optional[Hdf5Node],
+ error: Optional[Exception],
+ filename: str,
+ ):
"""Called at the end of a concurent file loading, when the loading
item is ready. AN error is defined if an exception occured when
loading the newItem .
- :param Hdf5Node oldItem: current displayed item
- :param Hdf5Node newItem: item loaded, or None if error is defined
- :param Exception error: An exception, or None if newItem is defined
+ :param oldItem: current displayed item
+ :param newItem: item loaded, or None if error is defined
+ :param error: An exception, or None if newItem is defined
+ :param filename: The filename used to load the new item
"""
row = self.__root.indexOfChild(oldItem)
@@ -291,7 +301,7 @@ class Hdf5TreeModel(qt.QAbstractItemModel):
self.endInsertRows()
if isinstance(oldItem, Hdf5LoadingItem):
- self.sigH5pyObjectLoaded.emit(newItem.obj)
+ self.sigH5pyObjectLoaded.emit(newItem.obj, filename)
else:
self.sigH5pyObjectSynchronized.emit(oldItem.obj, newItem.obj)
@@ -384,7 +394,9 @@ class Hdf5TreeModel(qt.QAbstractItemModel):
if action == qt.Qt.IgnoreAction:
return True
- if self.__fileMoveEnabled and mimedata.hasFormat(_utils.Hdf5DatasetMimeData.MIME_TYPE):
+ if self.__fileMoveEnabled and mimedata.hasFormat(
+ _utils.Hdf5DatasetMimeData.MIME_TYPE
+ ):
if mimedata.isRoot():
dragNode = mimedata.node()
parentNode = self.nodeFromIndex(parentIndex)
@@ -404,10 +416,9 @@ class Hdf5TreeModel(qt.QAbstractItemModel):
return True
if self.__fileDropEnabled and mimedata.hasFormat("text/uri-list"):
-
parentNode = self.nodeFromIndex(parentIndex)
if parentNode is not self.__root:
- while(parentNode is not self.__root):
+ while parentNode is not self.__root:
node = parentNode
parentNode = node.parent
row = parentNode.indexOfChild(node)
@@ -424,7 +435,10 @@ class Hdf5TreeModel(qt.QAbstractItemModel):
messages.append(e.args[0])
if len(messages) > 0:
title = "Error occurred when loading files"
- message = "<html>%s:<ul><li>%s</li><ul></html>" % (title, "</li><li>".join(messages))
+ message = "<html>%s:<ul><li>%s</li><ul></html>" % (
+ title,
+ "</li><li>".join(messages),
+ )
qt.QMessageBox.critical(None, title, message)
return True
@@ -443,14 +457,31 @@ class Hdf5TreeModel(qt.QAbstractItemModel):
self.__root.insertChild(row, node)
self.endInsertRows()
- def moveRow(self, sourceParentIndex, sourceRow, destinationParentIndex, destinationRow):
+ def moveRow(
+ self, sourceParentIndex, sourceRow, destinationParentIndex, destinationRow
+ ):
if sourceRow == destinationRow or sourceRow == destinationRow - 1:
# abort move, same place
return
- return self.moveRows(sourceParentIndex, sourceRow, 1, destinationParentIndex, destinationRow)
+ return self.moveRows(
+ sourceParentIndex, sourceRow, 1, destinationParentIndex, destinationRow
+ )
- def moveRows(self, sourceParentIndex, sourceRow, count, destinationParentIndex, destinationRow):
- self.beginMoveRows(sourceParentIndex, sourceRow, sourceRow, destinationParentIndex, destinationRow)
+ def moveRows(
+ self,
+ sourceParentIndex,
+ sourceRow,
+ count,
+ destinationParentIndex,
+ destinationRow,
+ ):
+ self.beginMoveRows(
+ sourceParentIndex,
+ sourceRow,
+ sourceRow,
+ destinationParentIndex,
+ destinationRow,
+ )
sourceNode = self.nodeFromIndex(sourceParentIndex)
destinationNode = self.nodeFromIndex(destinationParentIndex)
@@ -532,14 +563,14 @@ class Hdf5TreeModel(qt.QAbstractItemModel):
return qt.QModelIndex()
row = grandparent.indexOfChild(parent)
- assert row != - 1
+ assert row != -1
return self.createIndex(row, 0, parent)
def nodeFromIndex(self, index):
return index.internalPointer() if index.isValid() else self.__root
def _closeFileIfOwned(self, node):
- """"Close the file if it was loaded from a filename or a
+ """Close the file if it was loaded from a filename or a
drag-and-drop"""
obj = node.obj
for f in self.__openedFiles:
@@ -572,10 +603,15 @@ class Hdf5TreeModel(qt.QAbstractItemModel):
# else compare commonh5 objects
if not isinstance(obj2, type(obj1)):
return False
+
def key(item):
- if item.file is None:
- return item.name
- return item.file.filename, item.file.mode, item.name
+ info = [item.name]
+ if item.file is not None:
+ info += [item.file.filename, item.file.mode]
+ if isinstance(item, DatasetSlice):
+ info.append(item.indices)
+ return tuple(info)
+
return key(obj1) == key(obj2)
def h5pyObjectRow(self, h5pyObject):
@@ -655,7 +691,7 @@ class Hdf5TreeModel(qt.QAbstractItemModel):
obj=h5pyObject,
parent=self.__root,
openedPath=filename,
- )
+ ),
)
def hasPendingOperations(self):
@@ -697,7 +733,7 @@ class Hdf5TreeModel(qt.QAbstractItemModel):
h5file = silx_io.open(filename)
if self.__ownFiles:
self.__openedFiles.append(h5file)
- self.sigH5pyObjectLoaded.emit(h5file)
+ self.sigH5pyObjectLoaded.emit(h5file, filename)
self.insertH5pyObject(h5file, row=row, filename=filename)
except IOError:
_logger.debug("File '%s' can't be read.", filename, exc_info=True)
diff --git a/src/silx/gui/hdf5/Hdf5TreeView.py b/src/silx/gui/hdf5/Hdf5TreeView.py
index da35d15..a477fc3 100644
--- a/src/silx/gui/hdf5/Hdf5TreeView.py
+++ b/src/silx/gui/hdf5/Hdf5TreeView.py
@@ -57,6 +57,7 @@ class Hdf5TreeView(qt.QTreeView):
:meth:`removeContextMenuCallback` to add your custum actions according
to the selected objects.
"""
+
def __init__(self, parent=None):
"""
Constructor
@@ -167,7 +168,11 @@ class Hdf5TreeView(qt.QTreeView):
def dragEnterEvent(self, event):
model = self.findHdf5TreeModel()
- if model is not None and model.isFileDropEnabled() and event.mimeData().hasFormat("text/uri-list"):
+ if (
+ model is not None
+ and model.isFileDropEnabled()
+ and event.mimeData().hasFormat("text/uri-list")
+ ):
self.setState(qt.QAbstractItemView.DraggingState)
event.accept()
else:
@@ -175,7 +180,11 @@ class Hdf5TreeView(qt.QTreeView):
def dragMoveEvent(self, event):
model = self.findHdf5TreeModel()
- if model is not None and model.isFileDropEnabled() and event.mimeData().hasFormat("text/uri-list"):
+ if (
+ model is not None
+ and model.isFileDropEnabled()
+ and event.mimeData().hasFormat("text/uri-list")
+ ):
event.setDropAction(qt.Qt.CopyAction)
event.accept()
else:
@@ -215,7 +224,9 @@ class Hdf5TreeView(qt.QTreeView):
model = model.sourceModel()
else:
break
- raise RuntimeError("Model from the requested index is not reachable from this view")
+ raise RuntimeError(
+ "Model from the requested index is not reachable from this view"
+ )
def mapToModel(self, index):
"""Map an index from any model reachable by the view to an index from
diff --git a/src/silx/gui/hdf5/NexusSortFilterProxyModel.py b/src/silx/gui/hdf5/NexusSortFilterProxyModel.py
index 1b80c3e..0bc7352 100644
--- a/src/silx/gui/hdf5/NexusSortFilterProxyModel.py
+++ b/src/silx/gui/hdf5/NexusSortFilterProxyModel.py
@@ -76,7 +76,8 @@ class NexusSortFilterProxyModel(qt.QSortFilterProxyModel):
"""
if sourceLeft.column() != Hdf5TreeModel.NAME_COLUMN:
return super(NexusSortFilterProxyModel, self).lessThan(
- sourceLeft, sourceRight)
+ sourceLeft, sourceRight
+ )
# Do not sort child of root (files)
if sourceLeft.parent() == qt.QModelIndex():
@@ -217,7 +218,9 @@ class NexusSortFilterProxyModel(qt.QSortFilterProxyModel):
if index.column() == Hdf5TreeModel.NAME_COLUMN:
if role == qt.Qt.DecorationRole:
sourceIndex = self.mapToSource(index)
- item = self.sourceModel().data(sourceIndex, Hdf5TreeModel.H5PY_ITEM_ROLE)
+ item = self.sourceModel().data(
+ sourceIndex, Hdf5TreeModel.H5PY_ITEM_ROLE
+ )
if self.__isNXnode(item):
result = self.__getNxIcon(result)
return result
diff --git a/src/silx/gui/hdf5/__init__.py b/src/silx/gui/hdf5/__init__.py
index 2243484..8e07407 100644
--- a/src/silx/gui/hdf5/__init__.py
+++ b/src/silx/gui/hdf5/__init__.py
@@ -40,4 +40,10 @@ from ._utils import Hdf5ContextMenuEvent # noqa
from .NexusSortFilterProxyModel import NexusSortFilterProxyModel # noqa
from .Hdf5TreeModel import Hdf5TreeModel # noqa
-__all__ = ['Hdf5TreeView', 'H5Node', 'Hdf5ContextMenuEvent', 'NexusSortFilterProxyModel', 'Hdf5TreeModel']
+__all__ = [
+ "Hdf5TreeView",
+ "H5Node",
+ "Hdf5ContextMenuEvent",
+ "NexusSortFilterProxyModel",
+ "Hdf5TreeModel",
+]
diff --git a/src/silx/gui/hdf5/_utils.py b/src/silx/gui/hdf5/_utils.py
index 1d1b4cb..7232bfe 100644
--- a/src/silx/gui/hdf5/_utils.py
+++ b/src/silx/gui/hdf5/_utils.py
@@ -33,7 +33,7 @@ __date__ = "17/01/2019"
from html import escape
import logging
import os.path
-
+from silx.gui import constants
import silx.io.utils
import silx.io.url
from .. import qt
@@ -109,19 +109,20 @@ class Hdf5DatasetMimeData(qt.QMimeData):
MIME_TYPE = "application/x-internal-h5py-dataset"
- SILX_URI_TYPE = "application/x-silx-uri"
+ SILX_URI_TYPE = constants.SILX_URI_MIMETYPE
+ """For compatibility with silx <= 1.1"""
def __init__(self, node=None, dataset=None, isRoot=False):
qt.QMimeData.__init__(self)
self.__dataset = dataset
self.__node = node
self.__isRoot = isRoot
- self.setData(self.MIME_TYPE, "".encode(encoding='utf-8'))
+ self.setData(self.MIME_TYPE, "".encode(encoding="utf-8"))
if node is not None:
h5Node = H5Node(node)
silxUrl = h5Node.url
self.setText(silxUrl)
- self.setData(self.SILX_URI_TYPE, silxUrl.encode(encoding='utf-8'))
+ self.setData(constants.SILX_URI_MIMETYPE, silxUrl.encode(encoding="utf-8"))
def isRoot(self):
return self.__isRoot
@@ -427,9 +428,9 @@ class H5Node(object):
:rtype: ~silx.io.url.DataUrl
"""
absolute_filename = os.path.abspath(self.local_filename)
- return silx.io.url.DataUrl(scheme="silx",
- file_path=absolute_filename,
- data_path=self.local_name)
+ return silx.io.url.DataUrl(
+ scheme="silx", file_path=absolute_filename, data_path=self.local_name
+ )
@property
def url(self):
diff --git a/src/silx/gui/hdf5/test/test_hdf5.py b/src/silx/gui/hdf5/test/test_hdf5.py
index 6e77e1d..cb08436 100755
--- a/src/silx/gui/hdf5/test/test_hdf5.py
+++ b/src/silx/gui/hdf5/test/test_hdf5.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -30,23 +30,24 @@ __date__ = "12/03/2019"
import time
import os
-import unittest
import tempfile
import numpy
-from pkg_resources import parse_version
+from packaging.version import Version
from contextlib import contextmanager
from silx.gui import qt
from silx.gui.utils.testutils import TestCaseQt
from silx.gui import hdf5
from silx.gui.utils.testutils import SignalListener
from silx.io import commonh5
+from silx.io import h5py_utils
+from silx.io.url import DataUrl
import weakref
import h5py
import pytest
-h5py2_9 = parse_version(h5py.version.version) >= parse_version('2.9.0')
+h5py2_9 = Version(h5py.version.version) >= Version("2.9.0")
@pytest.fixture(scope="class")
@@ -54,7 +55,7 @@ def useH5File(request, tmpdir_factory):
tmp = tmpdir_factory.mktemp("test_hdf5")
request.cls.filename = os.path.join(tmp, "data.h5")
# create h5 data
- with h5py.File(request.cls.filename, "w") as f:
+ with h5py_utils.File(request.cls.filename, "w") as f:
g = f.create_group("arrays")
g.create_dataset("scalar", data=10)
yield
@@ -69,7 +70,6 @@ def create_NXentry(group, name):
@pytest.mark.usefixtures("useH5File")
class TestHdf5TreeModel(TestCaseQt):
-
def setUp(self):
super(TestHdf5TreeModel, self).setUp()
@@ -87,7 +87,7 @@ class TestHdf5TreeModel(TestCaseQt):
fd, tmp_name = tempfile.mkstemp(suffix=".h5")
os.close(fd)
# create h5 data
- h5file = h5py.File(tmp_name, "w")
+ h5file = h5py_utils.File(tmp_name, "w")
g = h5file.create_group("arrays")
g.create_dataset("scalar", data=10)
h5file.close()
@@ -134,7 +134,9 @@ class TestHdf5TreeModel(TestCaseQt):
self.assertEqual(model.rowCount(qt.QModelIndex()), 0)
model.insertFileAsync(self.filename)
index = model.index(0, 0, qt.QModelIndex())
- self.assertIsInstance(model.nodeFromIndex(index), hdf5.Hdf5LoadingItem.Hdf5LoadingItem)
+ self.assertIsInstance(
+ model.nodeFromIndex(index), hdf5.Hdf5LoadingItem.Hdf5LoadingItem
+ )
self.waitForPendingOperations(model)
index = model.index(0, 0, qt.QModelIndex())
self.assertIsInstance(model.nodeFromIndex(index), hdf5.Hdf5Item.Hdf5Item)
@@ -160,7 +162,7 @@ class TestHdf5TreeModel(TestCaseQt):
self.assertEqual(model.rowCount(qt.QModelIndex()), 0)
def testSynchronizeObject(self):
- h5 = h5py.File(self.filename, mode="r")
+ h5 = h5py_utils.File(self.filename, mode="r")
model = hdf5.Hdf5TreeModel()
model.insertH5pyObject(h5)
self.assertEqual(model.rowCount(qt.QModelIndex()), 1)
@@ -235,7 +237,7 @@ class TestHdf5TreeModel(TestCaseQt):
"""A file inserted as an h5py object is not open (then not closed)
internally."""
try:
- h5File = h5py.File(self.filename, mode="r")
+ h5File = h5py_utils.File(self.filename, mode="r")
model = hdf5.Hdf5TreeModel()
self.assertEqual(model.rowCount(qt.QModelIndex()), 0)
model.insertH5pyObject(h5File)
@@ -244,7 +246,9 @@ class TestHdf5TreeModel(TestCaseQt):
h5File = model.data(index, role=hdf5.Hdf5TreeModel.H5PY_OBJECT_ROLE)
model.removeIndex(index)
self.assertEqual(model.rowCount(qt.QModelIndex()), 0)
- self.assertTrue(bool(h5File.id.valid), "The HDF5 file was unexpetedly closed")
+ self.assertTrue(
+ bool(h5File.id.valid), "The HDF5 file was unexpetedly closed"
+ )
finally:
h5File.close()
@@ -269,7 +273,12 @@ class TestHdf5TreeModel(TestCaseQt):
def getRowDataAsDict(self, model, row):
displayed = {}
- roles = [qt.Qt.DisplayRole, qt.Qt.DecorationRole, qt.Qt.ToolTipRole, qt.Qt.TextAlignmentRole]
+ roles = [
+ qt.Qt.DisplayRole,
+ qt.Qt.DecorationRole,
+ qt.Qt.ToolTipRole,
+ qt.Qt.TextAlignmentRole,
+ ]
for column in range(0, model.columnCount(qt.QModelIndex())):
index = model.index(0, column, qt.QModelIndex())
for role in roles:
@@ -286,13 +295,27 @@ class TestHdf5TreeModel(TestCaseQt):
model = hdf5.Hdf5TreeModel()
model.insertH5pyObject(h5)
displayed = self.getRowDataAsDict(model, row=0)
- self.assertEqual(displayed[hdf5.Hdf5TreeModel.NAME_COLUMN, qt.Qt.DisplayRole], "1.mock")
- self.assertIsInstance(displayed[hdf5.Hdf5TreeModel.NAME_COLUMN, qt.Qt.DecorationRole], qt.QIcon)
- self.assertEqual(displayed[hdf5.Hdf5TreeModel.TYPE_COLUMN, qt.Qt.DisplayRole], "")
- self.assertEqual(displayed[hdf5.Hdf5TreeModel.SHAPE_COLUMN, qt.Qt.DisplayRole], "")
- self.assertEqual(displayed[hdf5.Hdf5TreeModel.VALUE_COLUMN, qt.Qt.DisplayRole], "")
- self.assertEqual(displayed[hdf5.Hdf5TreeModel.DESCRIPTION_COLUMN, qt.Qt.DisplayRole], None)
- self.assertEqual(displayed[hdf5.Hdf5TreeModel.NODE_COLUMN, qt.Qt.DisplayRole], "File")
+ self.assertEqual(
+ displayed[hdf5.Hdf5TreeModel.NAME_COLUMN, qt.Qt.DisplayRole], "1.mock"
+ )
+ self.assertIsInstance(
+ displayed[hdf5.Hdf5TreeModel.NAME_COLUMN, qt.Qt.DecorationRole], qt.QIcon
+ )
+ self.assertEqual(
+ displayed[hdf5.Hdf5TreeModel.TYPE_COLUMN, qt.Qt.DisplayRole], ""
+ )
+ self.assertEqual(
+ displayed[hdf5.Hdf5TreeModel.SHAPE_COLUMN, qt.Qt.DisplayRole], ""
+ )
+ self.assertEqual(
+ displayed[hdf5.Hdf5TreeModel.VALUE_COLUMN, qt.Qt.DisplayRole], ""
+ )
+ self.assertEqual(
+ displayed[hdf5.Hdf5TreeModel.DESCRIPTION_COLUMN, qt.Qt.DisplayRole], None
+ )
+ self.assertEqual(
+ displayed[hdf5.Hdf5TreeModel.NODE_COLUMN, qt.Qt.DisplayRole], "File"
+ )
def testGroupData(self):
h5 = commonh5.File("/foo/bar/1.mock", "w")
@@ -302,13 +325,27 @@ class TestHdf5TreeModel(TestCaseQt):
model = hdf5.Hdf5TreeModel()
model.insertH5pyObject(d)
displayed = self.getRowDataAsDict(model, row=0)
- self.assertEqual(displayed[hdf5.Hdf5TreeModel.NAME_COLUMN, qt.Qt.DisplayRole], "1.mock::foo")
- self.assertIsInstance(displayed[hdf5.Hdf5TreeModel.NAME_COLUMN, qt.Qt.DecorationRole], qt.QIcon)
- self.assertEqual(displayed[hdf5.Hdf5TreeModel.TYPE_COLUMN, qt.Qt.DisplayRole], "")
- self.assertEqual(displayed[hdf5.Hdf5TreeModel.SHAPE_COLUMN, qt.Qt.DisplayRole], "")
- self.assertEqual(displayed[hdf5.Hdf5TreeModel.VALUE_COLUMN, qt.Qt.DisplayRole], "")
- self.assertEqual(displayed[hdf5.Hdf5TreeModel.DESCRIPTION_COLUMN, qt.Qt.DisplayRole], "fooo")
- self.assertEqual(displayed[hdf5.Hdf5TreeModel.NODE_COLUMN, qt.Qt.DisplayRole], "Group")
+ self.assertEqual(
+ displayed[hdf5.Hdf5TreeModel.NAME_COLUMN, qt.Qt.DisplayRole], "1.mock::foo"
+ )
+ self.assertIsInstance(
+ displayed[hdf5.Hdf5TreeModel.NAME_COLUMN, qt.Qt.DecorationRole], qt.QIcon
+ )
+ self.assertEqual(
+ displayed[hdf5.Hdf5TreeModel.TYPE_COLUMN, qt.Qt.DisplayRole], ""
+ )
+ self.assertEqual(
+ displayed[hdf5.Hdf5TreeModel.SHAPE_COLUMN, qt.Qt.DisplayRole], ""
+ )
+ self.assertEqual(
+ displayed[hdf5.Hdf5TreeModel.VALUE_COLUMN, qt.Qt.DisplayRole], ""
+ )
+ self.assertEqual(
+ displayed[hdf5.Hdf5TreeModel.DESCRIPTION_COLUMN, qt.Qt.DisplayRole], "fooo"
+ )
+ self.assertEqual(
+ displayed[hdf5.Hdf5TreeModel.NODE_COLUMN, qt.Qt.DisplayRole], "Group"
+ )
def testDatasetData(self):
h5 = commonh5.File("/foo/bar/1.mock", "w")
@@ -318,13 +355,29 @@ class TestHdf5TreeModel(TestCaseQt):
model = hdf5.Hdf5TreeModel()
model.insertH5pyObject(d)
displayed = self.getRowDataAsDict(model, row=0)
- self.assertEqual(displayed[hdf5.Hdf5TreeModel.NAME_COLUMN, qt.Qt.DisplayRole], "1.mock::foo")
- self.assertIsInstance(displayed[hdf5.Hdf5TreeModel.NAME_COLUMN, qt.Qt.DecorationRole], qt.QIcon)
- self.assertEqual(displayed[hdf5.Hdf5TreeModel.TYPE_COLUMN, qt.Qt.DisplayRole], value.dtype.name)
- self.assertEqual(displayed[hdf5.Hdf5TreeModel.SHAPE_COLUMN, qt.Qt.DisplayRole], "3")
- self.assertEqual(displayed[hdf5.Hdf5TreeModel.VALUE_COLUMN, qt.Qt.DisplayRole], "[1 2 3]")
- self.assertEqual(displayed[hdf5.Hdf5TreeModel.DESCRIPTION_COLUMN, qt.Qt.DisplayRole], "[1 2 3]")
- self.assertEqual(displayed[hdf5.Hdf5TreeModel.NODE_COLUMN, qt.Qt.DisplayRole], "Dataset")
+ self.assertEqual(
+ displayed[hdf5.Hdf5TreeModel.NAME_COLUMN, qt.Qt.DisplayRole], "1.mock::foo"
+ )
+ self.assertIsInstance(
+ displayed[hdf5.Hdf5TreeModel.NAME_COLUMN, qt.Qt.DecorationRole], qt.QIcon
+ )
+ self.assertEqual(
+ displayed[hdf5.Hdf5TreeModel.TYPE_COLUMN, qt.Qt.DisplayRole],
+ value.dtype.name,
+ )
+ self.assertEqual(
+ displayed[hdf5.Hdf5TreeModel.SHAPE_COLUMN, qt.Qt.DisplayRole], "3"
+ )
+ self.assertEqual(
+ displayed[hdf5.Hdf5TreeModel.VALUE_COLUMN, qt.Qt.DisplayRole], "[1 2 3]"
+ )
+ self.assertEqual(
+ displayed[hdf5.Hdf5TreeModel.DESCRIPTION_COLUMN, qt.Qt.DisplayRole],
+ "[1 2 3]",
+ )
+ self.assertEqual(
+ displayed[hdf5.Hdf5TreeModel.NODE_COLUMN, qt.Qt.DisplayRole], "Dataset"
+ )
def testDropLastAsFirst(self):
model = hdf5.Hdf5TreeModel()
@@ -365,17 +418,18 @@ class TestHdf5TreeModel(TestCaseQt):
@pytest.mark.usefixtures("useH5File")
class TestHdf5TreeModelSignals(TestCaseQt):
-
def setUp(self):
TestCaseQt.setUp(self)
self.model = hdf5.Hdf5TreeModel()
- self.h5 = h5py.File(self.filename, mode='r')
+ self.h5 = h5py_utils.File(self.filename, mode="r")
self.model.insertH5pyObject(self.h5)
self.listener = SignalListener()
self.model.sigH5pyObjectLoaded.connect(self.listener.partial(signal="loaded"))
self.model.sigH5pyObjectRemoved.connect(self.listener.partial(signal="removed"))
- self.model.sigH5pyObjectSynchronized.connect(self.listener.partial(signal="synchronized"))
+ self.model.sigH5pyObjectSynchronized.connect(
+ self.listener.partial(signal="synchronized")
+ )
def tearDown(self):
self.signals = None
@@ -395,16 +449,28 @@ class TestHdf5TreeModelSignals(TestCaseQt):
raise RuntimeError("Still waiting for a pending operation")
def testInsert(self):
- h5 = h5py.File(self.filename, mode='r')
+ h5 = h5py_utils.File(self.filename, mode="r")
self.model.insertH5pyObject(h5)
self.assertEqual(self.listener.callCount(), 0)
def testLoaded(self):
- self.model.insertFile(self.filename)
- self.assertEqual(self.listener.callCount(), 1)
- self.assertEqual(self.listener.karguments(argumentName="signal")[0], "loaded")
- self.assertIsNot(self.listener.arguments(callIndex=0)[0], self.h5)
- self.assertEqual(self.listener.arguments(callIndex=0)[0].filename, self.filename)
+ for data_path in [None, "/arrays/scalar"]:
+ with self.subTest(data_path=data_path):
+ url = DataUrl(file_path=self.filename, data_path=data_path)
+ insertedFilename = url.path()
+ self.model.insertFile(insertedFilename)
+ self.assertEqual(self.listener.callCount(), 1)
+ self.assertEqual(
+ self.listener.karguments(argumentName="signal")[0], "loaded"
+ )
+ self.assertIsNot(self.listener.arguments(callIndex=0)[0], self.h5)
+ self.assertEqual(
+ self.listener.arguments(callIndex=0)[0].file.filename, self.filename
+ )
+ self.assertEqual(
+ self.listener.arguments(callIndex=0)[1], insertedFilename
+ )
+ self.listener.clear()
def testRemoved(self):
self.model.removeH5pyObject(self.h5)
@@ -416,13 +482,14 @@ class TestHdf5TreeModelSignals(TestCaseQt):
self.model.synchronizeH5pyObject(self.h5)
self.waitForPendingOperations(self.model)
self.assertEqual(self.listener.callCount(), 1)
- self.assertEqual(self.listener.karguments(argumentName="signal")[0], "synchronized")
+ self.assertEqual(
+ self.listener.karguments(argumentName="signal")[0], "synchronized"
+ )
self.assertIs(self.listener.arguments(callIndex=0)[0], self.h5)
self.assertIsNot(self.listener.arguments(callIndex=0)[1], self.h5)
class TestNexusSortFilterProxyModel(TestCaseQt):
-
def getChildNames(self, model, index):
count = model.rowCount(index)
result = []
@@ -451,9 +518,15 @@ class TestNexusSortFilterProxyModel(TestCaseQt):
"""Test NXentry with start_time"""
model = hdf5.Hdf5TreeModel()
h5 = commonh5.File("/foo/bar/1.mock", "w")
- create_NXentry(h5, "a").create_dataset("start_time", data=numpy.array([numpy.string_("2015")]))
- create_NXentry(h5, "b").create_dataset("start_time", data=numpy.array([numpy.string_("2013")]))
- create_NXentry(h5, "c").create_dataset("start_time", data=numpy.array([numpy.string_("2014")]))
+ create_NXentry(h5, "a").create_dataset(
+ "start_time", data=numpy.array([numpy.string_("2015")])
+ )
+ create_NXentry(h5, "b").create_dataset(
+ "start_time", data=numpy.array([numpy.string_("2013")])
+ )
+ create_NXentry(h5, "c").create_dataset(
+ "start_time", data=numpy.array([numpy.string_("2014")])
+ )
model.insertH5pyObject(h5)
proxy = hdf5.NexusSortFilterProxyModel()
@@ -466,9 +539,15 @@ class TestNexusSortFilterProxyModel(TestCaseQt):
"""Test NXentry with end_time"""
model = hdf5.Hdf5TreeModel()
h5 = commonh5.File("/foo/bar/1.mock", "w")
- create_NXentry(h5, "a").create_dataset("end_time", data=numpy.array([numpy.string_("2015")]))
- create_NXentry(h5, "b").create_dataset("end_time", data=numpy.array([numpy.string_("2013")]))
- create_NXentry(h5, "c").create_dataset("end_time", data=numpy.array([numpy.string_("2014")]))
+ create_NXentry(h5, "a").create_dataset(
+ "end_time", data=numpy.array([numpy.string_("2015")])
+ )
+ create_NXentry(h5, "b").create_dataset(
+ "end_time", data=numpy.array([numpy.string_("2013")])
+ )
+ create_NXentry(h5, "c").create_dataset(
+ "end_time", data=numpy.array([numpy.string_("2014")])
+ )
model.insertH5pyObject(h5)
proxy = hdf5.NexusSortFilterProxyModel()
@@ -565,7 +644,7 @@ class TestNexusSortFilterProxyModel(TestCaseQt):
self.assertListEqual(names, ["100aaa", "aaa100"])
-@pytest.fixture(scope='class')
+@pytest.fixture(scope="class")
def useH5Model(request, tmpdir_factory):
# Create HDF5 files
tmp = tmpdir_factory.mktemp("test_hdf5")
@@ -573,39 +652,53 @@ def useH5Model(request, tmpdir_factory):
extH5FileName = os.path.join(tmp, "base__external.h5")
extDatFileName = os.path.join(tmp, "base__external.dat")
- externalh5 = h5py.File(extH5FileName, mode="w")
+ externalh5 = h5py_utils.File(extH5FileName, mode="w")
externalh5["target/dataset"] = 50
externalh5["target/link"] = h5py.SoftLink("/target/dataset")
externalh5["/ext/vds0"] = [0, 1]
externalh5["/ext/vds1"] = [2, 3]
externalh5.close()
- numpy.array([0,1,10,10,2,3]).tofile(extDatFileName)
+ numpy.array([0, 1, 10, 10, 2, 3]).tofile(extDatFileName)
- h5 = h5py.File(filename, mode="w")
+ h5 = h5py_utils.File(filename, mode="w")
h5["group/dataset"] = 50
h5["link/soft_link"] = h5py.SoftLink("/group/dataset")
h5["link/soft_link_to_group"] = h5py.SoftLink("/group")
- h5["link/soft_link_to_link"] = h5py.SoftLink("/link/soft_link")
+ h5["link/soft_link_to_soft_link"] = h5py.SoftLink("/link/soft_link")
+ h5["link/soft_link_to_external_link"] = h5py.SoftLink("/link/external_link")
h5["link/soft_link_to_file"] = h5py.SoftLink("/")
h5["group/soft_link_relative"] = h5py.SoftLink("dataset")
h5["link/external_link"] = h5py.ExternalLink(extH5FileName, "/target/dataset")
- h5["link/external_link_to_link"] = h5py.ExternalLink(extH5FileName, "/target/link")
- h5["broken_link/external_broken_file"] = h5py.ExternalLink(extH5FileName + "_not_exists", "/target/link")
- h5["broken_link/external_broken_link"] = h5py.ExternalLink(extH5FileName, "/target/not_exists")
+ h5["link/external_link_to_soft_link"] = h5py.ExternalLink(
+ extH5FileName, "/target/link"
+ )
+ h5["broken_link/external_broken_file"] = h5py.ExternalLink(
+ extH5FileName + "_not_exists", "/target/link"
+ )
+ h5["broken_link/external_broken_link"] = h5py.ExternalLink(
+ extH5FileName, "/target/not_exists"
+ )
h5["broken_link/soft_broken_link"] = h5py.SoftLink("/group/not_exists")
h5["broken_link/soft_link_to_broken_link"] = h5py.SoftLink("/group/not_exists")
if h5py2_9:
- layout = h5py.VirtualLayout((2,2), dtype=int)
- layout[0] = h5py.VirtualSource("base__external.h5", name="/ext/vds0", shape=(2,), dtype=int)
- layout[1] = h5py.VirtualSource("base__external.h5", name="/ext/vds1", shape=(2,), dtype=int)
+ layout = h5py.VirtualLayout((2, 2), dtype=int)
+ layout[0] = h5py.VirtualSource(
+ "base__external.h5", name="/ext/vds0", shape=(2,), dtype=int
+ )
+ layout[1] = h5py.VirtualSource(
+ "base__external.h5", name="/ext/vds1", shape=(2,), dtype=int
+ )
h5.create_group("/ext")
h5["/ext"].create_virtual_dataset("virtual", layout)
- external = [("base__external.dat", 0, 2*8), ("base__external.dat", 4*8, 2*8)]
- h5["/ext"].create_dataset("raw", shape=(2,2), dtype=int, external=external)
+ external = [
+ ("base__external.dat", 0, 2 * 8),
+ ("base__external.dat", 4 * 8, 2 * 8),
+ ]
+ h5["/ext"].create_dataset("raw", shape=(2, 2), dtype=int, external=external)
h5.close()
- with h5py.File(filename, mode="r") as h5File:
+ with h5py_utils.File(filename, mode="r") as h5File:
# Create model
request.cls.model = hdf5.Hdf5TreeModel()
request.cls.model.insertH5pyObject(h5File)
@@ -615,7 +708,7 @@ def useH5Model(request, tmpdir_factory):
TestCaseQt.qWaitForDestroy(ref)
-@pytest.mark.usefixtures('useH5Model')
+@pytest.mark.usefixtures("useH5Model")
class _TestModelBase(TestCaseQt):
def getIndexFromPath(self, model, path):
"""
@@ -639,7 +732,6 @@ class _TestModelBase(TestCaseQt):
class TestH5Item(_TestModelBase):
-
def testFile(self):
path = ["base.h5"]
h5item = self.getH5ItemFromPath(self.model, path)
@@ -665,7 +757,7 @@ class TestH5Item(_TestModelBase):
self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "Soft")
def testSoftLinkToLink(self):
- path = ["base.h5", "link", "soft_link_to_link"]
+ path = ["base.h5", "link", "soft_link_to_soft_link"]
h5item = self.getH5ItemFromPath(self.model, path)
self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "Soft")
@@ -683,7 +775,7 @@ class TestH5Item(_TestModelBase):
self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "External")
def testExternalLinkToLink(self):
- path = ["base.h5", "link", "external_link_to_link"]
+ path = ["base.h5", "link", "external_link_to_soft_link"]
h5item = self.getH5ItemFromPath(self.model, path)
self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "External")
@@ -719,7 +811,14 @@ class TestH5Item(_TestModelBase):
self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "")
def testDatasetFromSoftLinkToFile(self):
- path = ["base.h5", "link", "soft_link_to_file", "link", "soft_link_to_group", "dataset"]
+ path = [
+ "base.h5",
+ "link",
+ "soft_link_to_file",
+ "link",
+ "soft_link_to_group",
+ "dataset",
+ ]
h5item = self.getH5ItemFromPath(self.model, path)
self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "")
@@ -740,7 +839,6 @@ class TestH5Item(_TestModelBase):
class TestH5Node(_TestModelBase):
-
def getH5NodeFromPath(self, model, path):
item = self.getH5ItemFromPath(model, path)
h5node = hdf5.H5Node(item)
@@ -790,16 +888,30 @@ class TestH5Node(_TestModelBase):
self.assertEqual(h5node.local_basename, "soft_link")
self.assertEqual(h5node.local_name, "/link/soft_link")
- def testSoftLinkToLink(self):
- path = ["base.h5", "link", "soft_link_to_link"]
+ def testSoftLinkToSoftLink(self):
+ path = ["base.h5", "link", "soft_link_to_soft_link"]
h5node = self.getH5NodeFromPath(self.model, path)
self.assertEqual(h5node.physical_filename, h5node.local_filename)
self.assertIn("base.h5", h5node.physical_filename)
self.assertEqual(h5node.physical_basename, "dataset")
self.assertEqual(h5node.physical_name, "/group/dataset")
- self.assertEqual(h5node.local_basename, "soft_link_to_link")
- self.assertEqual(h5node.local_name, "/link/soft_link_to_link")
+ self.assertEqual(h5node.local_basename, "soft_link_to_soft_link")
+ self.assertEqual(h5node.local_name, "/link/soft_link_to_soft_link")
+
+ def testSoftLinkToExternalLink(self):
+ path = ["base.h5", "link", "soft_link_to_external_link"]
+ h5node = self.getH5NodeFromPath(self.model, path)
+
+ with self.assertRaises(KeyError):
+ # h5py bug: #1706
+ self.assertNotEqual(h5node.physical_filename, h5node.local_filename)
+ self.assertIn("base.h5", h5node.local_filename)
+ self.assertIn("base__external.h5", h5node.physical_filename)
+ self.assertEqual(h5node.physical_basename, "dataset")
+ self.assertEqual(h5node.physical_name, "/target/dataset")
+ self.assertEqual(h5node.local_basename, "soft_link_to_external_link")
+ self.assertEqual(h5node.local_name, "/link/soft_link_to_external_link")
def testSoftLinkRelative(self):
path = ["base.h5", "group", "soft_link_relative"]
@@ -824,19 +936,18 @@ class TestH5Node(_TestModelBase):
self.assertEqual(h5node.local_basename, "external_link")
self.assertEqual(h5node.local_name, "/link/external_link")
- def testExternalLinkToLink(self):
- path = ["base.h5", "link", "external_link_to_link"]
+ def testExternalLinkToSoftLink(self):
+ path = ["base.h5", "link", "external_link_to_soft_link"]
h5node = self.getH5NodeFromPath(self.model, path)
self.assertNotEqual(h5node.physical_filename, h5node.local_filename)
self.assertIn("base.h5", h5node.local_filename)
self.assertIn("base__external.h5", h5node.physical_filename)
-
self.assertNotEqual(h5node.physical_filename, h5node.local_filename)
self.assertEqual(h5node.physical_basename, "dataset")
self.assertEqual(h5node.physical_name, "/target/dataset")
- self.assertEqual(h5node.local_basename, "external_link_to_link")
- self.assertEqual(h5node.local_name, "/link/external_link_to_link")
+ self.assertEqual(h5node.local_basename, "external_link_to_soft_link")
+ self.assertEqual(h5node.local_name, "/link/external_link_to_soft_link")
def testExternalBrokenFile(self):
path = ["base.h5", "broken_link", "external_broken_file"]
@@ -896,7 +1007,14 @@ class TestH5Node(_TestModelBase):
self.assertEqual(h5node.local_name, "/link/soft_link_to_group/dataset")
def testDatasetFromSoftLinkToFile(self):
- path = ["base.h5", "link", "soft_link_to_file", "link", "soft_link_to_group", "dataset"]
+ path = [
+ "base.h5",
+ "link",
+ "soft_link_to_file",
+ "link",
+ "soft_link_to_group",
+ "dataset",
+ ]
h5node = self.getH5NodeFromPath(self.model, path)
self.assertEqual(h5node.physical_filename, h5node.local_filename)
@@ -904,7 +1022,9 @@ class TestH5Node(_TestModelBase):
self.assertEqual(h5node.physical_basename, "dataset")
self.assertEqual(h5node.physical_name, "/group/dataset")
self.assertEqual(h5node.local_basename, "dataset")
- self.assertEqual(h5node.local_name, "/link/soft_link_to_file/link/soft_link_to_group/dataset")
+ self.assertEqual(
+ h5node.local_name, "/link/soft_link_to_file/link/soft_link_to_group/dataset"
+ )
@pytest.mark.skipif(not h5py2_9, reason="requires h5py>=2.9")
def testExternalVirtual(self):
diff --git a/src/silx/gui/icons.py b/src/silx/gui/icons.py
index b7a9000..3e2501b 100644
--- a/src/silx/gui/icons.py
+++ b/src/silx/gui/icons.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2016 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -37,7 +37,6 @@ import weakref
from . import qt
import silx.resources
from silx.utils import weakref as silxweakref
-from silx.utils.deprecation import deprecated
_logger = logging.getLogger(__name__)
@@ -92,7 +91,7 @@ class AbstractAnimatedIcon(qt.QObject):
"""Signal sent with a QIcon everytime the animation changed."""
def register(self, obj):
- """Register an object to the AnimatedIcon.
+ """Register an object to the AbstractAnimatedIcon.
If no object are registred, the animation is paused.
Object are stored in a weaked list.
@@ -120,7 +119,7 @@ class AbstractAnimatedIcon(qt.QObject):
return len(self.__targets)
def isRegistered(self, obj):
- """Returns true if the object is registred in the AnimatedIcon.
+ """Returns true if the object is registred in the AbstractAnimatedIcon.
:param object obj: An object
:rtype: bool
@@ -191,7 +190,7 @@ class MovieAnimatedIcon(AbstractAnimatedIcon):
def _updateState(self):
"""Update the movie play according to internal stat of the
- AnimatedIcon."""
+ MovieAnimatedIcon."""
self.__movie.setPaused(not self.hasRegistredObjects())
@@ -212,7 +211,7 @@ class MultiImageAnimatedIcon(AbstractAnimatedIcon):
self.__frames = []
for i in range(100):
try:
- frame_filename = os.sep.join((filename, ("%02d" %i)))
+ frame_filename = os.sep.join((filename, ("%02d" % i)))
frame_file = getQFile(frame_filename)
except ValueError:
break
@@ -257,22 +256,6 @@ class MultiImageAnimatedIcon(AbstractAnimatedIcon):
self.__timer.stop()
-class AnimatedIcon(MovieAnimatedIcon):
- """Store a looping QMovie to provide icons for each frames.
- Provides an event with the new icon everytime the movie frame
- is updated.
-
- It may not be available anymore for the silx release 0.6.
-
- .. deprecated:: 0.5
- Use :class:`MovieAnimatedIcon` instead.
- """
-
- @deprecated
- def __init__(self, filename, parent=None):
- MovieAnimatedIcon.__init__(self, filename, parent=parent)
-
-
def getWaitIcon():
"""Returns a cached version of the waiting AbstractAnimatedIcon.
@@ -307,7 +290,6 @@ def getAnimatedIcon(name):
key = name + "__anim"
cached_icons = getIconCache()
if key not in cached_icons:
-
qtMajorVersion = int(qt.qVersion().split(".")[0])
icon = None
@@ -415,10 +397,11 @@ def getQFile(name):
for format_ in _supported_formats:
format_ = str(format_)
- filename = silx.resources._resource_filename('%s.%s' % (name, format_),
- default_directory=os.path.join('gui', 'icons'))
+ filename = silx.resources._resource_filename(
+ "%s.%s" % (name, format_), default_directory="gui/icons"
+ )
qfile = qt.QFile(filename)
if qfile.exists():
return qfile
_logger.debug("File '%s' not found.", filename)
- raise ValueError('Not an icon name: %s' % name)
+ raise ValueError("Not an icon name: %s" % name)
diff --git a/src/silx/gui/plot/AlphaSlider.py b/src/silx/gui/plot/AlphaSlider.py
index 486ca6f..8a0a711 100644
--- a/src/silx/gui/plot/AlphaSlider.py
+++ b/src/silx/gui/plot/AlphaSlider.py
@@ -96,6 +96,7 @@ class BaseAlphaSlider(qt.QSlider):
You must subclass this class and implement :meth:`getItem`.
"""
+
sigAlphaChanged = qt.Signal(float)
"""Emits the alpha value when the slider's value changes,
as a float between 0. and 1."""
@@ -119,7 +120,7 @@ class BaseAlphaSlider(qt.QSlider):
self.setEnabled(False)
else:
alpha = self.getItem().getAlpha()
- self.setValue(round(255*alpha))
+ self.setValue(round(255 * alpha))
self.valueChanged.connect(self._valueChanged)
@@ -132,8 +133,8 @@ class BaseAlphaSlider(qt.QSlider):
:rtype: :class:`silx.plot.items.Item`
"""
raise NotImplementedError(
- "BaseAlphaSlider must be subclassed to " +
- "implement getItem()")
+ "BaseAlphaSlider must be subclassed to " + "implement getItem()"
+ )
def getAlpha(self):
"""Get the opacity, as a float between 0. and 1.
@@ -141,15 +142,14 @@ class BaseAlphaSlider(qt.QSlider):
:return: Alpha value in [0., 1.]
:rtype: float
"""
- return self.value() / 255.
+ return self.value() / 255.0
def _valueChanged(self, value):
self._updateItem()
- self.sigAlphaChanged.emit(value / 255.)
+ self.sigAlphaChanged.emit(value / 255.0)
def _updateItem(self):
- """Update the item's alpha channel.
- """
+ """Update the item's alpha channel."""
item = self.getItem()
if item is not None:
item.setAlpha(self.getAlpha())
@@ -164,6 +164,7 @@ class ActiveImageAlphaSlider(BaseAlphaSlider):
See documentation of :class:`BaseAlphaSlider`
"""
+
def __init__(self, parent=None, plot=None):
"""
@@ -203,8 +204,8 @@ class NamedItemAlphaSlider(BaseAlphaSlider):
:param str legend: Legend of item whose transparency is to be
controlled.
"""
- def __init__(self, parent=None, plot=None,
- kind=None, legend=None):
+
+ def __init__(self, parent=None, plot=None, kind=None, legend=None):
self._item_legend = legend
self._item_kind = kind
@@ -234,8 +235,7 @@ class NamedItemAlphaSlider(BaseAlphaSlider):
:rtype: subclass of :class:`silx.gui.plot.items.Item`"""
if self._item_legend is None or self._item_kind is None:
return None
- return self.plot._getItem(kind=self._item_kind,
- legend=self._item_legend)
+ return self.plot._getItem(kind=self._item_kind, legend=self._item_legend)
def setLegend(self, legend):
"""Associate a different item (of the same kind) to the slider.
@@ -280,9 +280,9 @@ class NamedImageAlphaSlider(NamedItemAlphaSlider):
:param str legend: Legend of image whose transparency is to be
controlled.
"""
+
def __init__(self, parent=None, plot=None, legend=None):
- NamedItemAlphaSlider.__init__(self, parent, plot,
- kind="image", legend=legend)
+ NamedItemAlphaSlider.__init__(self, parent, plot, kind="image", legend=legend)
class NamedScatterAlphaSlider(NamedItemAlphaSlider):
@@ -294,6 +294,6 @@ class NamedScatterAlphaSlider(NamedItemAlphaSlider):
:param str legend: Legend of scatter whose transparency is to be
controlled.
"""
+
def __init__(self, parent=None, plot=None, legend=None):
- NamedItemAlphaSlider.__init__(self, parent, plot,
- kind="scatter", legend=legend)
+ NamedItemAlphaSlider.__init__(self, parent, plot, kind="scatter", legend=legend)
diff --git a/src/silx/gui/plot/ColorBar.py b/src/silx/gui/plot/ColorBar.py
index 247da07..ee31f25 100644
--- a/src/silx/gui/plot/ColorBar.py
+++ b/src/silx/gui/plot/ColorBar.py
@@ -69,6 +69,7 @@ class ColorBarWidget(qt.QWidget):
:param plot: PlotWidget the colorbar is attached to (optional)
:param str legend: the label to set to the colorbar
"""
+
sigVisibleChanged = qt.Signal(bool)
"""Emitted when the property `visible` have changed."""
@@ -88,12 +89,11 @@ class ColorBarWidget(qt.QWidget):
self.setLayout(qt.QHBoxLayout())
# create color scale widget
- self._colorScale = ColorScaleBar(parent=self,
- colormap=None)
+ self._colorScale = ColorScaleBar(parent=self, colormap=None)
self.layout().addWidget(self._colorScale)
# legend (is the right group)
- self.legend = _VerticalLegend('', self)
+ self.legend = _VerticalLegend("", self)
self.layout().addWidget(self.legend)
self.layout().setSizeConstraint(qt.QLayout.SetMinAndMaxSize)
@@ -118,10 +118,8 @@ class ColorBarWidget(qt.QWidget):
self._isConnected = False
plot = self.getPlot()
if plot is not None and qt_inspect.isValid(plot):
- plot.sigActiveImageChanged.disconnect(
- self._activeImageChanged)
- plot.sigActiveScatterChanged.disconnect(
- self._activeScatterChanged)
+ plot.sigActiveImageChanged.disconnect(self._activeImageChanged)
+ plot.sigActiveScatterChanged.disconnect(self._activeScatterChanged)
plot.sigPlotSignal.disconnect(self._defaultColormapChanged)
def _connectPlot(self):
@@ -129,8 +127,7 @@ class ColorBarWidget(qt.QWidget):
plot = self.getPlot()
if plot is not None and not self._isConnected:
activeImageLegend = plot.getActiveImage(just_legend=True)
- activeScatterLegend = plot._getActiveItem(
- kind='scatter', just_legend=True)
+ activeScatterLegend = plot.getActiveScatter(just_legend=True)
if activeImageLegend is None and activeScatterLegend is None:
# Show plot default colormap
self._syncWithDefaultColormap()
@@ -170,8 +167,7 @@ class ColorBarWidget(qt.QWidget):
The data to display or item, needed if the colormap require an autoscale
"""
self._data = data
- self.getColorScaleBar().setColormap(colormap=colormap,
- data=data)
+ self.getColorScaleBar().setColormap(colormap=colormap, data=data)
if self._colormap is not None:
self._colormap.sigChanged.disconnect(self._colormapHasChanged)
self._colormap = colormap
@@ -179,11 +175,9 @@ class ColorBarWidget(qt.QWidget):
self._colormap.sigChanged.connect(self._colormapHasChanged)
def _colormapHasChanged(self):
- """handler of the Colormap.sigChanged signal
- """
+ """handler of the Colormap.sigChanged signal"""
assert self._colormap is not None
- self.setColormap(colormap=self._colormap,
- data=self._data)
+ self.setColormap(colormap=self._colormap, data=self._data)
def setLegend(self, legend):
"""Set the legend displayed along the colorbar
@@ -220,18 +214,16 @@ class ColorBarWidget(qt.QWidget):
return
# Sync with active scatter
- scatter = plot._getActiveItem(kind='scatter')
+ scatter = plot.getActiveScatter()
- self.setColormap(colormap=scatter.getColormap(),
- data=scatter)
+ self.setColormap(colormap=scatter.getColormap(), data=scatter)
def _activeImageChanged(self, previous, legend):
"""Handle plot active image changed"""
plot = self.getPlot()
if legend is None: # No active image, try with active scatter
- activeScatterLegend = plot._getActiveItem(
- kind='scatter', just_legend=True)
+ activeScatterLegend = plot.getActiveScatter(just_legend=True)
# No more active image, use active scatter if any
self._activeScatterChanged(None, activeScatterLegend)
else:
@@ -251,11 +243,13 @@ class ColorBarWidget(qt.QWidget):
def _defaultColormapChanged(self, event):
"""Handle plot default colormap changed"""
- if event['event'] == 'defaultColormapChanged':
+ if event["event"] == "defaultColormapChanged":
plot = self.getPlot()
- if (plot is not None and
- plot.getActiveImage() is None and
- plot._getActiveItem(kind='scatter') is None):
+ if (
+ plot is not None
+ and plot.getActiveImage() is None
+ and plot.getActiveScatter() is None
+ ):
# No active item, take default colormap update into account
self._syncWithDefaultColormap()
@@ -272,8 +266,8 @@ class ColorBarWidget(qt.QWidget):
class _VerticalLegend(qt.QLabel):
- """Display vertically the given text
- """
+ """Display vertically the given text"""
+
def __init__(self, text, parent=None):
"""
@@ -333,8 +327,7 @@ class ColorScaleBar(qt.QWidget):
"""The tick bar need a margin to display all labels at the correct place.
So the ColorScale should have the same margin in order for both to fit"""
- def __init__(self, parent=None, colormap=None, data=None,
- displayTicksValues=True):
+ def __init__(self, parent=None, colormap=None, data=None, displayTicksValues=True):
super(ColorScaleBar, self).__init__(parent)
self.minVal = None
@@ -345,10 +338,9 @@ class ColorScaleBar(qt.QWidget):
self.setLayout(qt.QGridLayout())
# create the left side group (ColorScale)
- self.colorScale = _ColorScale(colormap=colormap,
- data=data,
- parent=self,
- margin=ColorScaleBar._TEXT_MARGIN)
+ self.colorScale = _ColorScale(
+ colormap=colormap, data=data, parent=self, margin=ColorScaleBar._TEXT_MARGIN
+ )
if colormap:
vmin, vmax = colormap.getColormapRange(data)
normalizer = colormap._getNormalizer()
@@ -356,12 +348,14 @@ class ColorScaleBar(qt.QWidget):
vmin, vmax = colors.DEFAULT_MIN_LIN, colors.DEFAULT_MAX_LIN
normalizer = None
- self.tickbar = _TickBar(vmin=vmin,
- vmax=vmax,
- normalizer=normalizer,
- parent=self,
- displayValues=displayTicksValues,
- margin=ColorScaleBar._TEXT_MARGIN)
+ self.tickbar = _TickBar(
+ vmin=vmin,
+ vmax=vmax,
+ normalizer=normalizer,
+ parent=self,
+ displayValues=displayTicksValues,
+ margin=ColorScaleBar._TEXT_MARGIN,
+ )
self.layout().addWidget(self.tickbar, 1, 0, 1, 1, qt.Qt.AlignRight)
self.layout().addWidget(self.colorScale, 1, 1, qt.Qt.AlignLeft)
@@ -421,9 +415,7 @@ class ColorScaleBar(qt.QWidget):
vmin, vmax = None, None
normalizer = None
- self.tickbar.update(vmin=vmin,
- vmax=vmax,
- normalizer=normalizer)
+ self.tickbar.update(vmin=vmin, vmax=vmax, normalizer=normalizer)
self._setMinMaxLabels(vmin, vmax)
def setMinMaxVisible(self, val=True):
@@ -438,24 +430,24 @@ class ColorScaleBar(qt.QWidget):
"""Update the min and max label if we are in the case of the
configuration 'minMaxValueOnly'"""
if self.minVal is None:
- text, tooltip = '', ''
+ text, tooltip = "", ""
else:
if self.minVal == 0 or 0 <= numpy.log10(abs(self.minVal)) < 7:
- text = '%.7g' % self.minVal
+ text = "%.7g" % self.minVal
else:
- text = '%.2e' % self.minVal
+ text = "%.2e" % self.minVal
tooltip = repr(self.minVal)
self._minLabel.setText(text)
self._minLabel.setToolTip(tooltip)
if self.maxVal is None:
- text, tooltip = '', ''
+ text, tooltip = "", ""
else:
if self.maxVal == 0 or 0 <= numpy.log10(abs(self.maxVal)) < 7:
- text = '%.7g' % self.maxVal
+ text = "%.7g" % self.maxVal
else:
- text = '%.2e' % self.maxVal
+ text = "%.2e" % self.maxVal
tooltip = repr(self.maxVal)
self._maxLabel.setText(text)
@@ -561,7 +553,7 @@ class _ColorScale(qt.QWidget):
if colormap is None:
return
- indices = numpy.linspace(0., 1., self._NB_CONTROL_POINTS)
+ indices = numpy.linspace(0.0, 1.0, self._NB_CONTROL_POINTS)
colors = colormap.getNColors(nbColors=self._NB_CONTROL_POINTS)
self._gradient = qt.QLinearGradient(0, 1, 0, 0)
self._gradient.setCoordinateMode(qt.QGradient.StretchToDeviceMode)
@@ -574,30 +566,39 @@ class _ColorScale(qt.QWidget):
painter = qt.QPainter(self)
if self.getColormap() is not None:
painter.setBrush(self._gradient)
- penColor = self.palette().color(qt.QPalette.Active,
- qt.QPalette.WindowText)
+ penColor = self.palette().color(qt.QPalette.Active, qt.QPalette.WindowText)
else:
- penColor = self.palette().color(qt.QPalette.Disabled,
- qt.QPalette.WindowText)
+ penColor = self.palette().color(
+ qt.QPalette.Disabled, qt.QPalette.WindowText
+ )
painter.setPen(penColor)
- painter.drawRect(qt.QRect(
- 0,
- self.margin,
- self.width() - 1,
- self.height() - 2 * self.margin - 1))
+ painter.drawRect(
+ qt.QRect(
+ 0, self.margin, self.width() - 1, self.height() - 2 * self.margin - 1
+ )
+ )
def mouseMoveEvent(self, event):
- tooltip = str(self.getValueFromRelativePosition(
- self._getRelativePosition(qt.getMouseEventPosition(event)[1])))
- qt.QToolTip.showText(event.globalPos(), tooltip, self)
+ tooltip = str(
+ self.getValueFromRelativePosition(
+ self._getRelativePosition(qt.getMouseEventPosition(event)[1])
+ )
+ )
+ if qt.BINDING == "PyQt5":
+ position = event.globalPos()
+ else: # Qt6
+ position = event.globalPosition().toPoint()
+ qt.QToolTip.showText(position, tooltip, self)
super(_ColorScale, self).mouseMoveEvent(event)
def _getRelativePosition(self, yPixel):
- """yPixel : pixel position into _ColorScale widget reference
- """
+ """yPixel : pixel position into _ColorScale widget reference"""
# widgets are bottom-top referencial but we display in top-bottom referential
- return 1. - (yPixel - self.margin) / float(self.height() - 2 * self.margin)
+ height = float(self.height() - 2 * self.margin)
+ if height == 0:
+ return 0.0
+ return 1.0 - (yPixel - self.margin) / height
def getValueFromRelativePosition(self, value):
"""Return the value in the colorMap from a relative position in the
@@ -610,12 +611,15 @@ class _ColorScale(qt.QWidget):
if colormap is None:
return
- value = numpy.clip(value, 0., 1.)
+ value = numpy.clip(value, 0.0, 1.0)
normalizer = colormap._getNormalizer()
- normMin, normMax = normalizer.apply([self.vmin, self.vmax], self.vmin, self.vmax)
+ normMin, normMax = normalizer.apply(
+ [self.vmin, self.vmax], self.vmin, self.vmax
+ )
return normalizer.revert(
- normMin + (normMax - normMin) * value, self.vmin, self.vmax)
+ normMin + (normMax - normMin) * value, self.vmin, self.vmax
+ )
def setMargin(self, margin):
"""Define the margin to fit with a TickBar object.
@@ -651,6 +655,7 @@ class _TickBar(qt.QWidget):
number of ticks from the tick density.
:param int margin: margin to set on the top and bottom
"""
+
_WIDTH_DISP_VAL = 45
"""widget width when displayed with ticks labels"""
_WIDTH_NO_DISP_VAL = 10
@@ -662,8 +667,16 @@ class _TickBar(qt.QWidget):
DEFAULT_TICK_DENSITY = 0.015
- def __init__(self, vmin, vmax, normalizer, parent=None, displayValues=True,
- nticks=None, margin=5):
+ def __init__(
+ self,
+ vmin,
+ vmax,
+ normalizer,
+ parent=None,
+ displayValues=True,
+ nticks=None,
+ margin=5,
+ ):
super(_TickBar, self).__init__(parent)
self.margin = margin
self._nticks = None
@@ -722,7 +735,7 @@ class _TickBar(qt.QWidget):
(nticks=None) then you can specify a ticks density to be displayed.
"""
if density < 0.0:
- raise ValueError('Density should be a positive value')
+ raise ValueError("Density should be a positive value")
self.ticksDensity = density
def computeTicks(self):
@@ -752,14 +765,16 @@ class _TickBar(qt.QWidget):
def _computeTicksLog(self, nticks):
logMin = numpy.log10(self._vmin)
logMax = numpy.log10(self._vmax)
- lowBound, highBound, spacing, self._nfrac = ticklayout.niceNumbersForLog10(logMin,
- logMax,
- nticks)
- self.ticks = numpy.power(10., numpy.arange(lowBound, highBound, spacing))
+ lowBound, highBound, spacing, self._nfrac = ticklayout.niceNumbersForLog10(
+ logMin, logMax, nticks
+ )
+ self.ticks = numpy.power(10.0, numpy.arange(lowBound, highBound, spacing))
if spacing == 1:
- self.subTicks = ticklayout.computeLogSubTicks(ticks=self.ticks,
- lowBound=numpy.power(10., lowBound),
- highBound=numpy.power(10., highBound))
+ self.subTicks = ticklayout.computeLogSubTicks(
+ ticks=self.ticks,
+ lowBound=numpy.power(10.0, lowBound),
+ highBound=numpy.power(10.0, highBound),
+ )
else:
self.subTicks = []
@@ -768,9 +783,9 @@ class _TickBar(qt.QWidget):
self.computeTicks()
def _computeTicksLin(self, nticks):
- _min, _max, _spacing, self._nfrac = ticklayout.niceNumbers(self._vmin,
- self._vmax,
- nticks)
+ _min, _max, _spacing, self._nfrac = ticklayout.niceNumbers(
+ self._vmin, self._vmax, nticks
+ )
self.ticks = numpy.arange(_min, _max, _spacing)
self.subTicks = []
@@ -793,19 +808,18 @@ class _TickBar(qt.QWidget):
self._paintTick(val, painter, majorTick=False)
def _getRelativePosition(self, val):
- """Return the relative position of val according to min and max value
- """
+ """Return the relative position of val according to min and max value"""
if self._normalizer is None:
- return 0.
+ return 0.0
normMin, normMax, normVal = self._normalizer.apply(
- [self._vmin, self._vmax, val],
- self._vmin,
- self._vmax)
+ [self._vmin, self._vmax, val], self._vmin, self._vmax
+ )
if normMin == normMax:
- return 0.
- else:
- return 1. - (normVal - normMin) / (normMax - normMin)
+ return 0.0
+ if not numpy.isfinite(normVal):
+ return 0.0
+ return 1.0 - (normVal - normMin) / (normMax - normMin)
def _paintTick(self, val, painter, majorTick=True):
"""
@@ -821,14 +835,14 @@ class _TickBar(qt.QWidget):
if majorTick is False:
lineWidth /= 2
- painter.drawLine(qt.QLine(int(self.width() - lineWidth),
- height,
- self.width(),
- height))
+ painter.drawLine(
+ qt.QLine(int(self.width() - lineWidth), height, self.width(), height)
+ )
if self.displayValues and majorTick is True:
- painter.drawText(qt.QPoint(0, int(height + fm.height() / 2)),
- self.form.format(val))
+ painter.drawText(
+ qt.QPoint(0, int(height + fm.height() / 2)), self.form.format(val)
+ )
def setDisplayType(self, disType):
"""Set the type of display we want to set for ticks labels
@@ -841,8 +855,10 @@ class _TickBar(qt.QWidget):
- 'e' for scientific display
- None to let the _TickBar guess the best display for this kind of data.
"""
- if disType not in (None, 'std', 'e'):
- raise ValueError("display type not recognized, value should be in (None, 'std', 'e'")
+ if disType not in (None, "std", "e"):
+ raise ValueError(
+ "display type not recognized, value should be in (None, 'std', 'e'"
+ )
self._forcedDisplayType = disType
def _getStandardFormat(self):
@@ -851,12 +867,14 @@ class _TickBar(qt.QWidget):
def _getFormat(self, font):
if self._forcedDisplayType is None:
return self._guessType(font)
- elif self._forcedDisplayType == 'std':
+ elif self._forcedDisplayType == "std":
return self._getStandardFormat()
- elif self._forcedDisplayType == 'e':
+ elif self._forcedDisplayType == "e":
return self._getScientificForm()
else:
- err = 'Forced type for display %s is not recognized' % self._forcedDisplayType
+ err = (
+ "Forced type for display %s is not recognized" % self._forcedDisplayType
+ )
raise ValueError(err)
def _getScientificForm(self):
diff --git a/src/silx/gui/plot/Colormap.py b/src/silx/gui/plot/Colormap.py
deleted file mode 100644
index 8eaee84..0000000
--- a/src/silx/gui/plot/Colormap.py
+++ /dev/null
@@ -1,41 +0,0 @@
-# /*##########################################################################
-#
-# Copyright (c) 2015-2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Deprecated module providing the Colormap object
-"""
-
-__authors__ = ["T. Vincent", "H.Payno"]
-__license__ = "MIT"
-__date__ = "27/11/2020"
-
-import silx.utils.deprecation
-
-silx.utils.deprecation.deprecated_warning("Module",
- name="silx.gui.plot.Colormap",
- reason="moved",
- replacement="silx.gui.colors.Colormap",
- since_version="0.8.0",
- only_once=True,
- skip_backtrace_count=1)
-
-from ..colors import * # noqa
diff --git a/src/silx/gui/plot/ColormapDialog.py b/src/silx/gui/plot/ColormapDialog.py
deleted file mode 100644
index 0c0df2c..0000000
--- a/src/silx/gui/plot/ColormapDialog.py
+++ /dev/null
@@ -1,40 +0,0 @@
-# /*##########################################################################
-#
-# Copyright (c) 2004-2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Deprecated module providing ColormapDialog."""
-
-__authors__ = ["T. Vincent", "H.Payno"]
-__license__ = "MIT"
-__date__ = "24/04/2018"
-
-import silx.utils.deprecation
-
-silx.utils.deprecation.deprecated_warning("Module",
- name="silx.gui.plot.ColormapDialog",
- reason="moved",
- replacement="silx.gui.dialog.ColormapDialog",
- since_version="0.8.0",
- only_once=True,
- skip_backtrace_count=1)
-
-from ..dialog.ColormapDialog import * # noqa
diff --git a/src/silx/gui/plot/Colors.py b/src/silx/gui/plot/Colors.py
deleted file mode 100644
index 34ee815..0000000
--- a/src/silx/gui/plot/Colors.py
+++ /dev/null
@@ -1,87 +0,0 @@
-# /*##########################################################################
-#
-# Copyright (c) 2004-2017 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Color conversion function, color dictionary and colormap tools."""
-
-__authors__ = ["V.A. Sole", "T. Vincent"]
-__license__ = "MIT"
-__date__ = "14/06/2018"
-
-import silx.utils.deprecation
-
-silx.utils.deprecation.deprecated_warning("Module",
- name="silx.gui.plot.Colors",
- reason="moved",
- replacement="silx.gui.colors",
- since_version="0.8.0",
- only_once=True,
- skip_backtrace_count=1)
-
-from ..colors import * # noqa
-
-
-@silx.utils.deprecation.deprecated(replacement='silx.gui.colors.Colormap.applyColormap')
-def applyColormapToData(data,
- name='gray',
- normalization='linear',
- autoscale=True,
- vmin=0.,
- vmax=1.,
- colors=None):
- """Apply a colormap to the data and returns the RGBA image
-
- This supports data of any dimensions (not only of dimension 2).
- The returned array will have one more dimension (with 4 entries)
- than the input data to store the RGBA channels
- corresponding to each bin in the array.
-
- :param numpy.ndarray data: The data to convert.
- :param str name: Name of the colormap (default: 'gray').
- :param str normalization: Colormap mapping: 'linear' or 'log'.
- :param bool autoscale: Whether to use data min/max (True, default)
- or [vmin, vmax] range (False).
- :param float vmin: The minimum value of the range to use if
- 'autoscale' is False.
- :param float vmax: The maximum value of the range to use if
- 'autoscale' is False.
- :param numpy.ndarray colors: Only used if name is None.
- Custom colormap colors as Nx3 or Nx4 RGB or RGBA arrays
- :return: The computed RGBA image
- :rtype: numpy.ndarray of uint8
- """
- colormap = Colormap(name=name,
- normalization=normalization,
- vmin=vmin,
- vmax=vmax,
- colors=colors)
- return colormap.applyToData(data)
-
-
-@silx.utils.deprecation.deprecated(replacement='silx.gui.colors.Colormap.getSupportedColormaps')
-def getSupportedColormaps():
- """Get the supported colormap names as a tuple of str.
-
- The list should at least contain and start by:
- ('gray', 'reversed gray', 'temperature', 'red', 'green', 'blue')
- """
- return Colormap.getSupportedColormaps()
diff --git a/src/silx/gui/plot/CompareImages.py b/src/silx/gui/plot/CompareImages.py
index 80e0db3..3823ae2 100644
--- a/src/silx/gui/plot/CompareImages.py
+++ b/src/silx/gui/plot/CompareImages.py
@@ -29,505 +29,30 @@ __license__ = "MIT"
__date__ = "23/07/2018"
-import enum
import logging
import numpy
-import weakref
-import collections
import math
import silx.image.bilinear
from silx.gui import qt
from silx.gui import plot
-from silx.gui import icons
from silx.gui.colors import Colormap
from silx.gui.plot import tools
+from silx.utils.deprecation import deprecated_warning
from silx.utils.weakref import WeakMethodProxy
+from silx.gui.plot.items import Scatter
+from silx.math.colormap import normalize
-_logger = logging.getLogger(__name__)
-
-from silx.opencl import ocl
-if ocl is not None:
- try:
- from silx.opencl import sift
- except ImportError:
- # sift module is not available (e.g., in official Debian packages)
- sift = None
-else: # No OpenCL device or no pyopencl
- sift = None
-
-
-@enum.unique
-class VisualizationMode(enum.Enum):
- """Enum for each visualization mode available."""
- ONLY_A = 'a'
- ONLY_B = 'b'
- VERTICAL_LINE = 'vline'
- HORIZONTAL_LINE = 'hline'
- COMPOSITE_RED_BLUE_GRAY = "rbgchannel"
- COMPOSITE_RED_BLUE_GRAY_NEG = "rbgnegchannel"
- COMPOSITE_A_MINUS_B = "aminusb"
-
-
-@enum.unique
-class AlignmentMode(enum.Enum):
- """Enum for each alignment mode available."""
- ORIGIN = 'origin'
- CENTER = 'center'
- STRETCH = 'stretch'
- AUTO = 'auto'
-
-
-AffineTransformation = collections.namedtuple("AffineTransformation",
- ["tx", "ty", "sx", "sy", "rot"])
-"""Contains a 2D affine transformation: translation, scale and rotation"""
-
-
-class CompareImagesToolBar(qt.QToolBar):
- """ToolBar containing specific tools to custom the configuration of a
- :class:`CompareImages` widget
-
- Use :meth:`setCompareWidget` to connect this toolbar to a specific
- :class:`CompareImages` widget.
-
- :param Union[qt.QWidget,None] parent: Parent of this widget.
- """
- def __init__(self, parent=None):
- qt.QToolBar.__init__(self, parent)
-
- self.__compareWidget = None
-
- menu = qt.QMenu(self)
- self.__visualizationToolButton = qt.QToolButton(self)
- self.__visualizationToolButton.setMenu(menu)
- self.__visualizationToolButton.setPopupMode(qt.QToolButton.InstantPopup)
- self.addWidget(self.__visualizationToolButton)
- self.__visualizationGroup = qt.QActionGroup(self)
- self.__visualizationGroup.setExclusive(True)
- self.__visualizationGroup.triggered.connect(self.__visualizationModeChanged)
-
- icon = icons.getQIcon("compare-mode-a")
- action = qt.QAction(icon, "Display the first image only", self)
- action.setIconVisibleInMenu(True)
- action.setCheckable(True)
- action.setShortcut(qt.QKeySequence(qt.Qt.Key_A))
- action.setProperty("mode", VisualizationMode.ONLY_A)
- menu.addAction(action)
- self.__aModeAction = action
- self.__visualizationGroup.addAction(action)
-
- icon = icons.getQIcon("compare-mode-b")
- action = qt.QAction(icon, "Display the second image only", self)
- action.setIconVisibleInMenu(True)
- action.setCheckable(True)
- action.setShortcut(qt.QKeySequence(qt.Qt.Key_B))
- action.setProperty("mode", VisualizationMode.ONLY_B)
- menu.addAction(action)
- self.__bModeAction = action
- self.__visualizationGroup.addAction(action)
-
- icon = icons.getQIcon("compare-mode-vline")
- action = qt.QAction(icon, "Vertical compare mode", self)
- action.setIconVisibleInMenu(True)
- action.setCheckable(True)
- action.setShortcut(qt.QKeySequence(qt.Qt.Key_V))
- action.setProperty("mode", VisualizationMode.VERTICAL_LINE)
- menu.addAction(action)
- self.__vlineModeAction = action
- self.__visualizationGroup.addAction(action)
-
- icon = icons.getQIcon("compare-mode-hline")
- action = qt.QAction(icon, "Horizontal compare mode", self)
- action.setIconVisibleInMenu(True)
- action.setCheckable(True)
- action.setShortcut(qt.QKeySequence(qt.Qt.Key_H))
- action.setProperty("mode", VisualizationMode.HORIZONTAL_LINE)
- menu.addAction(action)
- self.__hlineModeAction = action
- self.__visualizationGroup.addAction(action)
-
- icon = icons.getQIcon("compare-mode-rb-channel")
- action = qt.QAction(icon, "Blue/red compare mode (additive mode)", self)
- action.setIconVisibleInMenu(True)
- action.setCheckable(True)
- action.setShortcut(qt.QKeySequence(qt.Qt.Key_C))
- action.setProperty("mode", VisualizationMode.COMPOSITE_RED_BLUE_GRAY)
- menu.addAction(action)
- self.__brChannelModeAction = action
- self.__visualizationGroup.addAction(action)
-
- icon = icons.getQIcon("compare-mode-rbneg-channel")
- action = qt.QAction(icon, "Yellow/cyan compare mode (subtractive mode)", self)
- action.setIconVisibleInMenu(True)
- action.setCheckable(True)
- action.setShortcut(qt.QKeySequence(qt.Qt.Key_W))
- action.setProperty("mode", VisualizationMode.COMPOSITE_RED_BLUE_GRAY_NEG)
- menu.addAction(action)
- self.__ycChannelModeAction = action
- self.__visualizationGroup.addAction(action)
-
- icon = icons.getQIcon("compare-mode-a-minus-b")
- action = qt.QAction(icon, "Raw A minus B compare mode", self)
- action.setIconVisibleInMenu(True)
- action.setCheckable(True)
- action.setShortcut(qt.QKeySequence(qt.Qt.Key_W))
- action.setProperty("mode", VisualizationMode.COMPOSITE_A_MINUS_B)
- menu.addAction(action)
- self.__ycChannelModeAction = action
- self.__visualizationGroup.addAction(action)
-
- menu = qt.QMenu(self)
- self.__alignmentToolButton = qt.QToolButton(self)
- self.__alignmentToolButton.setMenu(menu)
- self.__alignmentToolButton.setPopupMode(qt.QToolButton.InstantPopup)
- self.addWidget(self.__alignmentToolButton)
- self.__alignmentGroup = qt.QActionGroup(self)
- self.__alignmentGroup.setExclusive(True)
- self.__alignmentGroup.triggered.connect(self.__alignmentModeChanged)
-
- icon = icons.getQIcon("compare-align-origin")
- action = qt.QAction(icon, "Align images on their upper-left pixel", self)
- action.setProperty("mode", AlignmentMode.ORIGIN)
- action.setIconVisibleInMenu(True)
- action.setCheckable(True)
- self.__originAlignAction = action
- menu.addAction(action)
- self.__alignmentGroup.addAction(action)
-
- icon = icons.getQIcon("compare-align-center")
- action = qt.QAction(icon, "Center images", self)
- action.setProperty("mode", AlignmentMode.CENTER)
- action.setIconVisibleInMenu(True)
- action.setCheckable(True)
- self.__centerAlignAction = action
- menu.addAction(action)
- self.__alignmentGroup.addAction(action)
-
- icon = icons.getQIcon("compare-align-stretch")
- action = qt.QAction(icon, "Stretch the second image on the first one", self)
- action.setProperty("mode", AlignmentMode.STRETCH)
- action.setIconVisibleInMenu(True)
- action.setCheckable(True)
- self.__stretchAlignAction = action
- menu.addAction(action)
- self.__alignmentGroup.addAction(action)
-
- icon = icons.getQIcon("compare-align-auto")
- action = qt.QAction(icon, "Auto-alignment of the second image", self)
- action.setProperty("mode", AlignmentMode.AUTO)
- action.setIconVisibleInMenu(True)
- action.setCheckable(True)
- self.__autoAlignAction = action
- menu.addAction(action)
- if sift is None:
- action.setEnabled(False)
- action.setToolTip("Sift module is not available")
- self.__alignmentGroup.addAction(action)
-
- icon = icons.getQIcon("compare-keypoints")
- action = qt.QAction(icon, "Display/hide alignment keypoints", self)
- action.setCheckable(True)
- action.triggered.connect(self.__keypointVisibilityChanged)
- self.addAction(action)
- self.__displayKeypoints = action
-
- def setCompareWidget(self, widget):
- """
- Connect this tool bar to a specific :class:`CompareImages` widget.
-
- :param Union[None,CompareImages] widget: The widget to connect with.
- """
- compareWidget = self.getCompareWidget()
- if compareWidget is not None:
- compareWidget.sigConfigurationChanged.disconnect(self.__updateSelectedActions)
- compareWidget = widget
- if compareWidget is None:
- self.__compareWidget = None
- else:
- self.__compareWidget = weakref.ref(compareWidget)
- if compareWidget is not None:
- widget.sigConfigurationChanged.connect(self.__updateSelectedActions)
- self.__updateSelectedActions()
-
- def getCompareWidget(self):
- """Returns the connected widget.
-
- :rtype: CompareImages
- """
- if self.__compareWidget is None:
- return None
- else:
- return self.__compareWidget()
-
- def __updateSelectedActions(self):
- """
- Update the state of this tool bar according to the state of the
- connected :class:`CompareImages` widget.
- """
- widget = self.getCompareWidget()
- if widget is None:
- return
-
- mode = widget.getVisualizationMode()
- action = None
- for a in self.__visualizationGroup.actions():
- actionMode = a.property("mode")
- if mode == actionMode:
- action = a
- break
- old = self.__visualizationGroup.blockSignals(True)
- if action is not None:
- # Check this action
- action.setChecked(True)
- else:
- action = self.__visualizationGroup.checkedAction()
- if action is not None:
- # Uncheck this action
- action.setChecked(False)
- self.__updateVisualizationMenu()
- self.__visualizationGroup.blockSignals(old)
-
- mode = widget.getAlignmentMode()
- action = None
- for a in self.__alignmentGroup.actions():
- actionMode = a.property("mode")
- if mode == actionMode:
- action = a
- break
- old = self.__alignmentGroup.blockSignals(True)
- if action is not None:
- # Check this action
- action.setChecked(True)
- else:
- action = self.__alignmentGroup.checkedAction()
- if action is not None:
- # Uncheck this action
- action.setChecked(False)
- self.__updateAlignmentMenu()
- self.__alignmentGroup.blockSignals(old)
-
- def __visualizationModeChanged(self, selectedAction):
- """Called when user requesting changes of the visualization mode.
- """
- self.__updateVisualizationMenu()
- widget = self.getCompareWidget()
- if widget is not None:
- mode = selectedAction.property("mode")
- widget.setVisualizationMode(mode)
-
- def __updateVisualizationMenu(self):
- """Update the state of the action containing visualization menu.
- """
- selectedAction = self.__visualizationGroup.checkedAction()
- if selectedAction is not None:
- self.__visualizationToolButton.setText(selectedAction.text())
- self.__visualizationToolButton.setIcon(selectedAction.icon())
- self.__visualizationToolButton.setToolTip(selectedAction.toolTip())
- else:
- self.__visualizationToolButton.setText("")
- self.__visualizationToolButton.setIcon(qt.QIcon())
- self.__visualizationToolButton.setToolTip("")
-
- def __alignmentModeChanged(self, selectedAction):
- """Called when user requesting changes of the alignment mode.
- """
- self.__updateAlignmentMenu()
- widget = self.getCompareWidget()
- if widget is not None:
- mode = selectedAction.property("mode")
- widget.setAlignmentMode(mode)
-
- def __updateAlignmentMenu(self):
- """Update the state of the action containing alignment menu.
- """
- selectedAction = self.__alignmentGroup.checkedAction()
- if selectedAction is not None:
- self.__alignmentToolButton.setText(selectedAction.text())
- self.__alignmentToolButton.setIcon(selectedAction.icon())
- self.__alignmentToolButton.setToolTip(selectedAction.toolTip())
- else:
- self.__alignmentToolButton.setText("")
- self.__alignmentToolButton.setIcon(qt.QIcon())
- self.__alignmentToolButton.setToolTip("")
-
- def __keypointVisibilityChanged(self):
- """Called when action managing keypoints visibility changes"""
- widget = self.getCompareWidget()
- if widget is not None:
- keypointsVisible = self.__displayKeypoints.isChecked()
- widget.setKeypointsVisible(keypointsVisible)
+from .tools.compare.core import sift
+from .tools.compare.core import VisualizationMode
+from .tools.compare.core import AlignmentMode
+from .tools.compare.core import AffineTransformation
+from .tools.compare.toolbar import CompareImagesToolBar
+from .tools.compare.statusbar import CompareImagesStatusBar
+from .tools.compare.core import _CompareImageItem
-class CompareImagesStatusBar(qt.QStatusBar):
- """StatusBar containing specific information contained in a
- :class:`CompareImages` widget
-
- Use :meth:`setCompareWidget` to connect this toolbar to a specific
- :class:`CompareImages` widget.
-
- :param Union[qt.QWidget,None] parent: Parent of this widget.
- """
- def __init__(self, parent=None):
- qt.QStatusBar.__init__(self, parent)
- self.setSizeGripEnabled(False)
- self.layout().setSpacing(0)
- self.__compareWidget = None
- self._label1 = qt.QLabel(self)
- self._label1.setFrameShape(qt.QFrame.WinPanel)
- self._label1.setFrameShadow(qt.QFrame.Sunken)
- self._label2 = qt.QLabel(self)
- self._label2.setFrameShape(qt.QFrame.WinPanel)
- self._label2.setFrameShadow(qt.QFrame.Sunken)
- self._transform = qt.QLabel(self)
- self._transform.setFrameShape(qt.QFrame.WinPanel)
- self._transform.setFrameShadow(qt.QFrame.Sunken)
- self.addWidget(self._label1)
- self.addWidget(self._label2)
- self.addWidget(self._transform)
- self._pos = None
- self._updateStatusBar()
-
- def setCompareWidget(self, widget):
- """
- Connect this tool bar to a specific :class:`CompareImages` widget.
-
- :param Union[None,CompareImages] widget: The widget to connect with.
- """
- compareWidget = self.getCompareWidget()
- if compareWidget is not None:
- compareWidget.getPlot().sigPlotSignal.disconnect(self.__plotSignalReceived)
- compareWidget.sigConfigurationChanged.disconnect(self.__dataChanged)
- compareWidget = widget
- if compareWidget is None:
- self.__compareWidget = None
- else:
- self.__compareWidget = weakref.ref(compareWidget)
- if compareWidget is not None:
- compareWidget.getPlot().sigPlotSignal.connect(self.__plotSignalReceived)
- compareWidget.sigConfigurationChanged.connect(self.__dataChanged)
-
- def getCompareWidget(self):
- """Returns the connected widget.
-
- :rtype: CompareImages
- """
- if self.__compareWidget is None:
- return None
- else:
- return self.__compareWidget()
-
- def __plotSignalReceived(self, event):
- """Called when old style signals at emmited from the plot."""
- if event["event"] == "mouseMoved":
- x, y = event["x"], event["y"]
- self.__mouseMoved(x, y)
-
- def __mouseMoved(self, x, y):
- """Called when mouse move over the plot."""
- self._pos = x, y
- self._updateStatusBar()
-
- def __dataChanged(self):
- """Called when internal data from the connected widget changes."""
- self._updateStatusBar()
-
- def _formatData(self, data):
- """Format pixel of an image.
-
- It supports intensity, RGB, and RGBA.
-
- :param Union[int,float,numpy.ndarray,str]: Value of a pixel
- :rtype: str
- """
- if data is None:
- return "No data"
- if isinstance(data, (int, numpy.integer)):
- return "%d" % data
- if isinstance(data, (float, numpy.floating)):
- return "%f" % data
- if isinstance(data, numpy.ndarray):
- # RGBA value
- if data.shape == (3,):
- return "R:%d G:%d B:%d" % (data[0], data[1], data[2])
- elif data.shape == (4,):
- return "R:%d G:%d B:%d A:%d" % (data[0], data[1], data[2], data[3])
- _logger.debug("Unsupported data format %s. Cast it to string.", type(data))
- return str(data)
-
- def _updateStatusBar(self):
- """Update the content of the status bar"""
- widget = self.getCompareWidget()
- if widget is None:
- self._label1.setText("Image1: NA")
- self._label2.setText("Image2: NA")
- self._transform.setVisible(False)
- else:
- transform = widget.getTransformation()
- self._transform.setVisible(transform is not None)
- if transform is not None:
- has_notable_translation = not numpy.isclose(transform.tx, 0.0, atol=0.01) \
- or not numpy.isclose(transform.ty, 0.0, atol=0.01)
- has_notable_scale = not numpy.isclose(transform.sx, 1.0, atol=0.01) \
- or not numpy.isclose(transform.sy, 1.0, atol=0.01)
- has_notable_rotation = not numpy.isclose(transform.rot, 0.0, atol=0.01)
-
- strings = []
- if has_notable_translation:
- strings.append("Translation")
- if has_notable_scale:
- strings.append("Scale")
- if has_notable_rotation:
- strings.append("Rotation")
- if strings == []:
- has_translation = not numpy.isclose(transform.tx, 0.0) \
- or not numpy.isclose(transform.ty, 0.0)
- has_scale = not numpy.isclose(transform.sx, 1.0) \
- or not numpy.isclose(transform.sy, 1.0)
- has_rotation = not numpy.isclose(transform.rot, 0.0)
- if has_translation or has_scale or has_rotation:
- text = "No big changes"
- else:
- text = "No changes"
- else:
- text = "+".join(strings)
- self._transform.setText("Align: " + text)
-
- strings = []
- if not numpy.isclose(transform.ty, 0.0):
- strings.append("Translation x: %0.3fpx" % transform.tx)
- if not numpy.isclose(transform.ty, 0.0):
- strings.append("Translation y: %0.3fpx" % transform.ty)
- if not numpy.isclose(transform.sx, 1.0):
- strings.append("Scale x: %0.3f" % transform.sx)
- if not numpy.isclose(transform.sy, 1.0):
- strings.append("Scale y: %0.3f" % transform.sy)
- if not numpy.isclose(transform.rot, 0.0):
- strings.append("Rotation: %0.3fdeg" % (transform.rot * 180 / numpy.pi))
- if strings == []:
- text = "No transformation"
- else:
- text = "\n".join(strings)
- self._transform.setToolTip(text)
-
- if self._pos is None:
- self._label1.setText("Image1: NA")
- self._label2.setText("Image2: NA")
- else:
- data1, data2 = widget.getRawPixelData(self._pos[0], self._pos[1])
- if isinstance(data1, str):
- self._label1.setToolTip(data1)
- text1 = "NA"
- else:
- self._label1.setToolTip("")
- text1 = self._formatData(data1)
- if isinstance(data2, str):
- self._label2.setToolTip(data2)
- text2 = "NA"
- else:
- self._label2.setToolTip("")
- text2 = self._formatData(data2)
- self._label1.setText("Image1: %s" % text1)
- self._label2.setText("Image2: %s" % text2)
+_logger = logging.getLogger(__name__)
class CompareImages(qt.QMainWindow):
@@ -550,22 +75,28 @@ class CompareImages(qt.QMainWindow):
sigConfigurationChanged = qt.Signal()
"""Emitted when the configuration of the widget (visualization mode,
- alignement mode...) have changed."""
+ alignment mode...) have changed."""
def __init__(self, parent=None, backend=None):
qt.QMainWindow.__init__(self, parent)
self._resetZoomActive = True
self._colormap = Colormap()
"""Colormap shared by all modes, except the compose images (rgb image)"""
- self._colormapKeyPoints = Colormap('spring')
+ self._colormapKeyPoints = Colormap("spring")
"""Colormap used for sift keypoints"""
+ self._colormap.sigChanged.connect(self.__colormapChanged)
+
if parent is None:
- self.setWindowTitle('Compare images')
+ self.setWindowTitle("Compare images")
else:
self.setWindowFlags(qt.Qt.Widget)
self.__transformation = None
+ self.__item = _CompareImageItem()
+ self.__item.setName("_virtual")
+ self.__item.setColormap(self._colormap)
+
self.__raw1 = None
self.__raw2 = None
self.__data1 = None
@@ -574,35 +105,44 @@ class CompareImages(qt.QMainWindow):
self.__plot = plot.PlotWidget(parent=self, backend=backend)
self.__plot.setDefaultColormap(self._colormap)
- self.__plot.getXAxis().setLabel('Columns')
- self.__plot.getYAxis().setLabel('Rows')
- if silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION == 'downward':
+ self.__plot.getXAxis().setLabel("Columns")
+ self.__plot.getYAxis().setLabel("Rows")
+ if silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION == "downward":
self.__plot.getYAxis().setInverted(True)
+ self.__plot.addItem(self.__item)
+ self.__plot.setActiveImage(self.__item)
self.__plot.setKeepDataAspectRatio(True)
self.__plot.sigPlotSignal.connect(self.__plotSlot)
self.__plot.setAxesDisplayed(False)
+ self.__scatter = Scatter()
+ self.__scatter.setZValue(1)
+ self.__scatter.setColormap(self._colormapKeyPoints)
+ self.__plot.addItem(self.__scatter)
+
self.setCentralWidget(self.__plot)
legend = VisualizationMode.VERTICAL_LINE.name
self.__plot.addXMarker(
- 0,
- legend=legend,
- text='',
- draggable=True,
- color='blue',
- constraint=WeakMethodProxy(self.__separatorConstraint))
+ 0,
+ legend=legend,
+ text="",
+ draggable=True,
+ color="blue",
+ constraint=WeakMethodProxy(self.__separatorConstraint),
+ )
self.__vline = self.__plot._getMarker(legend)
legend = VisualizationMode.HORIZONTAL_LINE.name
self.__plot.addYMarker(
- 0,
- legend=legend,
- text='',
- draggable=True,
- color='blue',
- constraint=WeakMethodProxy(self.__separatorConstraint))
+ 0,
+ legend=legend,
+ text="",
+ draggable=True,
+ color="blue",
+ constraint=WeakMethodProxy(self.__separatorConstraint),
+ )
self.__hline = self.__plot._getMarker(legend)
# default values
@@ -630,6 +170,26 @@ class CompareImages(qt.QMainWindow):
if self._statusBar is not None:
self.setStatusBar(self._statusBar)
+ def __getSealedColormap(self):
+ vrange = self._colormap.getColormapRange(
+ self.__item.getColormappedData(copy=False)
+ )
+ sealed = self._colormap.copy()
+ sealed.setVRange(*vrange)
+ return sealed
+
+ def __colormapChanged(self):
+ sealed = self.__getSealedColormap()
+ if self.__image1 is not None:
+ if self.__getImageMode(self.__image1.getData(copy=False)) == "intensity":
+ self.__image1.setColormap(sealed)
+ if self.__image2 is not None:
+ if self.__getImageMode(self.__image2.getData(copy=False)) == "intensity":
+ self.__image2.setColormap(sealed)
+
+ if "COMPOSITE" in self.__visualizationMode.name:
+ self.__updateData()
+
def _createStatusBar(self, plot):
self._statusBar = CompareImagesStatusBar(self)
self._statusBar.setCompareWidget(self)
@@ -644,6 +204,9 @@ class CompareImages(qt.QMainWindow):
toolBar.setCompareWidget(self)
self._compareToolBar = toolBar
+ def _getVirtualPlotItem(self):
+ return self.__item
+
def getPlot(self):
"""Returns the plot which is used to display the images.
@@ -676,10 +239,15 @@ class CompareImages(qt.QMainWindow):
It also could be a string containing information is some cases.
:rtype: Tuple(Union[int,float,numpy.ndarray,str],Union[int,float,numpy.ndarray,str])
"""
- data2 = None
alignmentMode = self.__alignmentMode
raw1, raw2 = self.__raw1, self.__raw2
- if alignmentMode == AlignmentMode.ORIGIN:
+
+ if raw1 is None or raw2 is None:
+ x1 = x
+ y1 = y
+ x2 = x
+ y2 = y
+ elif alignmentMode == AlignmentMode.ORIGIN:
x1 = x
y1 = y
x2 = x
@@ -700,22 +268,29 @@ class CompareImages(qt.QMainWindow):
x1 = x
y1 = y
# Not implemented
- data2 = "Not implemented with sift"
+ x2 = -1
+ y2 = -1
else:
- assert(False)
+ assert False
x1, y1 = int(x1), int(y1)
- if raw1 is None or y1 < 0 or y1 >= raw1.shape[0] or x1 < 0 or x1 >= raw1.shape[1]:
- data1 = None
+ x2, y2 = int(x2), int(y2)
+
+ if raw1 is None:
+ data1 = "No image A"
+ elif y1 < 0 or y1 >= raw1.shape[0] or x1 < 0 or x1 >= raw1.shape[1]:
+ data1 = ""
else:
data1 = raw1[y1, x1]
- if data2 is None:
- x2, y2 = int(x2), int(y2)
- if raw2 is None or y2 < 0 or y2 >= raw2.shape[0] or x2 < 0 or x2 >= raw2.shape[1]:
- data2 = None
- else:
- data2 = raw2[y2, x2]
+ if raw2 is None:
+ data2 = "No image B"
+ elif alignmentMode == AlignmentMode.AUTO:
+ data2 = "Not implemented with sift"
+ elif y2 < 0 or y2 >= raw2.shape[0] or x2 < 0 or x2 >= raw2.shape[1]:
+ data2 = None
+ else:
+ data2 = raw2[y2, x2]
return data1, data2
@@ -726,20 +301,31 @@ class CompareImages(qt.QMainWindow):
"""
if self.__visualizationMode == mode:
return
- previousMode = self.getVisualizationMode()
self.__visualizationMode = mode
- mode = self.getVisualizationMode()
+ self.__item.setVizualisationMode(mode)
self.__vline.setVisible(mode == VisualizationMode.VERTICAL_LINE)
self.__hline.setVisible(mode == VisualizationMode.HORIZONTAL_LINE)
- visModeRawDisplay = (VisualizationMode.ONLY_A,
- VisualizationMode.ONLY_B,
- VisualizationMode.VERTICAL_LINE,
- VisualizationMode.HORIZONTAL_LINE)
- updateColormap = not(previousMode in visModeRawDisplay and
- mode in visModeRawDisplay)
- self.__updateData(updateColormap=updateColormap)
+ self.__updateData()
self.sigConfigurationChanged.emit()
+ def centerLines(self):
+ """Center the line used to compare the 2 images."""
+ if self.__image1 is None:
+ return
+ data_range = self.__plot.getDataRange()
+
+ if data_range[0] is not None:
+ cx = (data_range[0][0] + data_range[0][1]) * 0.5
+ else:
+ cx = 0
+ if data_range[1] is not None:
+ cy = (data_range[1][0] + data_range[1][1]) * 0.5
+ else:
+ cy = 0
+ self.__vline.setPosition(cx, cy)
+ self.__hline.setPosition(cx, cy)
+ self.__updateSeparators()
+
def getVisualizationMode(self):
"""Returns the current interaction mode."""
return self.__visualizationMode
@@ -752,13 +338,17 @@ class CompareImages(qt.QMainWindow):
if self.__alignmentMode == mode:
return
self.__alignmentMode = mode
- self.__updateData(updateColormap=False)
+ self.__updateData()
self.sigConfigurationChanged.emit()
def getAlignmentMode(self):
"""Returns the current selected alignemnt mode."""
return self.__alignmentMode
+ def getKeypointsVisible(self):
+ """Returns true if the keypoints are displayed"""
+ return self.__keypointsVisible
+
def setKeypointsVisible(self, isVisible):
"""Set keypoints visibility.
@@ -776,16 +366,16 @@ class CompareImages(qt.QMainWindow):
def __plotSlot(self, event):
"""Handle events from the plot"""
- if event['event'] in ('markerMoving', 'markerMoved'):
+ if event["event"] in ("markerMoving", "markerMoved"):
mode = self.getVisualizationMode()
legend = mode.name
- if event['label'] == legend:
+ if event["label"] == legend:
if mode == VisualizationMode.VERTICAL_LINE:
- value = int(float(str(event['xdata'])))
+ value = int(float(str(event["xdata"])))
elif mode == VisualizationMode.HORIZONTAL_LINE:
- value = int(float(str(event['ydata'])))
+ value = int(float(str(event["ydata"])))
else:
- assert(False)
+ assert False
if self.__previousSeparatorPosition != value:
self.__separatorMoved(value)
self.__previousSeparatorPosition = value
@@ -807,8 +397,7 @@ class CompareImages(qt.QMainWindow):
return x, y
def __updateSeparators(self):
- """Redraw images according to the current state of the separators.
- """
+ """Redraw images according to the current state of the separators."""
mode = self.getVisualizationMode()
if mode == VisualizationMode.VERTICAL_LINE:
pos = self.__vline.getXPosition()
@@ -820,7 +409,8 @@ class CompareImages(qt.QMainWindow):
self.__previousSeparatorPosition = pos
else:
self.__image1.setOrigin((0, 0))
- self.__image2.setOrigin((0, 0))
+ if self.__image2 is not None:
+ self.__image2.setOrigin((0, 0))
def __separatorMoved(self, pos):
"""Called when vertical or horizontal separators have moved.
@@ -840,8 +430,9 @@ class CompareImages(qt.QMainWindow):
data1 = self.__data1[:, 0:pos]
data2 = self.__data2[:, pos:]
self.__image1.setData(data1, copy=False)
- self.__image2.setData(data2, copy=False)
- self.__image2.setOrigin((pos, 0))
+ if self.__image2 is not None:
+ self.__image2.setData(data2, copy=False)
+ self.__image2.setOrigin((pos, 0))
elif mode == VisualizationMode.HORIZONTAL_LINE:
pos = int(pos)
if pos <= 0:
@@ -851,150 +442,209 @@ class CompareImages(qt.QMainWindow):
data1 = self.__data1[0:pos, :]
data2 = self.__data2[pos:, :]
self.__image1.setData(data1, copy=False)
- self.__image2.setData(data2, copy=False)
- self.__image2.setOrigin((0, pos))
+ if self.__image2 is not None:
+ self.__image2.setData(data2, copy=False)
+ self.__image2.setOrigin((0, pos))
else:
- assert(False)
+ assert False
- def setData(self, image1, image2, updateColormap=True):
+ def clear(self):
+ self.setData(None, None)
+
+ def setData(self, image1, image2, updateColormap="deprecated"):
"""Set images to compare.
Images can contains floating-point or integer values, or RGB and RGBA
values, but should have comparable intensities.
RGB and RGBA images are provided as an array as `[width,height,channels]`
- of usigned integer 8-bits or floating-points between 0.0 to 1.0.
+ of unsigned integer 8-bits or floating-points between 0.0 to 1.0.
:param numpy.ndarray image1: The first image
:param numpy.ndarray image2: The second image
"""
+ if updateColormap != "deprecated":
+ deprecated_warning(
+ "Argument", "setData's updateColormap argument", since_version="2.0.0"
+ )
+
self.__raw1 = image1
self.__raw2 = image2
- self.__updateData(updateColormap=updateColormap)
+ self.__updateData()
if self.isAutoResetZoom():
self.__plot.resetZoom()
- def setImage1(self, image1, updateColormap=True):
+ def setImage1(self, image1, updateColormap="deprecated"):
"""Set image1 to be compared.
Images can contains floating-point or integer values, or RGB and RGBA
values, but should have comparable intensities.
RGB and RGBA images are provided as an array as `[width,height,channels]`
- of usigned integer 8-bits or floating-points between 0.0 to 1.0.
+ of unsigned integer 8-bits or floating-points between 0.0 to 1.0.
:param numpy.ndarray image1: The first image
"""
+ if updateColormap != "deprecated":
+ deprecated_warning(
+ "Argument", "setImage1's updateColormap argument", since_version="2.0.0"
+ )
+
self.__raw1 = image1
- self.__updateData(updateColormap=updateColormap)
+ self.__updateData()
if self.isAutoResetZoom():
self.__plot.resetZoom()
- def setImage2(self, image2, updateColormap=True):
+ def setImage2(self, image2, updateColormap="deprecated"):
"""Set image2 to be compared.
Images can contains floating-point or integer values, or RGB and RGBA
values, but should have comparable intensities.
RGB and RGBA images are provided as an array as `[width,height,channels]`
- of usigned integer 8-bits or floating-points between 0.0 to 1.0.
+ of unsigned integer 8-bits or floating-points between 0.0 to 1.0.
:param numpy.ndarray image2: The second image
"""
+ if updateColormap != "deprecated":
+ deprecated_warning(
+ "Argument", "setImage2's updateColormap argument", since_version="2.0.0"
+ )
+
self.__raw2 = image2
- self.__updateData(updateColormap=updateColormap)
+ self.__updateData()
if self.isAutoResetZoom():
self.__plot.resetZoom()
def __updateKeyPoints(self):
- """Update the displayed keypoints using cached keypoints.
- """
- if self.__keypointsVisible:
+ """Update the displayed keypoints using cached keypoints."""
+ if self.__keypointsVisible and self.__matching_keypoints:
data = self.__matching_keypoints
else:
data = [], [], []
- self.__plot.addScatter(x=data[0],
- y=data[1],
- z=1,
- value=data[2],
- colormap=self._colormapKeyPoints,
- legend="keypoints")
-
- def __updateData(self, updateColormap):
+ self.__scatter.setData(x=data[0], y=data[1], value=data[2])
+
+ def __updateData(self):
"""Compute aligned image when the alignment mode changes.
This function cache input images which are used when
vertical/horizontal separators moves.
"""
raw1, raw2 = self.__raw1, self.__raw2
- if raw1 is None or raw2 is None:
- return
alignmentMode = self.getAlignmentMode()
self.__transformation = None
- if alignmentMode == AlignmentMode.ORIGIN:
- yy = max(raw1.shape[0], raw2.shape[0])
- xx = max(raw1.shape[1], raw2.shape[1])
- size = yy, xx
- data1 = self.__createMarginImage(raw1, size, transparent=True)
- data2 = self.__createMarginImage(raw2, size, transparent=True)
- self.__matching_keypoints = [0.0], [0.0], [1.0]
- elif alignmentMode == AlignmentMode.CENTER:
- yy = max(raw1.shape[0], raw2.shape[0])
- xx = max(raw1.shape[1], raw2.shape[1])
- size = yy, xx
- data1 = self.__createMarginImage(raw1, size, transparent=True, center=True)
- data2 = self.__createMarginImage(raw2, size, transparent=True, center=True)
- self.__matching_keypoints = ([data1.shape[1] // 2],
- [data1.shape[0] // 2],
- [1.0])
- elif alignmentMode == AlignmentMode.STRETCH:
- data1 = raw1
- data2 = self.__rescaleImage(raw2, data1.shape)
- self.__matching_keypoints = ([0, data1.shape[1], data1.shape[1], 0],
- [0, 0, data1.shape[0], data1.shape[0]],
- [1.0, 1.0, 1.0, 1.0])
- elif alignmentMode == AlignmentMode.AUTO:
- # TODO: sift implementation do not support RGBA images
- yy = max(raw1.shape[0], raw2.shape[0])
- xx = max(raw1.shape[1], raw2.shape[1])
- size = yy, xx
- data1 = self.__createMarginImage(raw1, size)
- data2 = self.__createMarginImage(raw2, size)
- self.__matching_keypoints = [0.0], [0.0], [1.0]
- try:
- data1, data2 = self.__createSiftData(data1, data2)
- if data2 is None:
- raise ValueError("Unexpected None value")
- except Exception as e:
- # TODO: Display it on the GUI
- _logger.error(e)
- self.__setDefaultAlignmentMode()
- return
+ if raw1 is None or raw2 is None:
+ # No need to realign the 2 images
+ # But create a dummy image when there is None for simplification
+ if raw1 is None:
+ data1 = numpy.empty((0, 0))
+ else:
+ data1 = raw1
+ if raw2 is None:
+ data2 = numpy.empty((0, 0))
+ else:
+ data2 = raw2
+ self.__matching_keypoints = None
else:
- assert(False)
+ if alignmentMode == AlignmentMode.ORIGIN:
+ yy = max(raw1.shape[0], raw2.shape[0])
+ xx = max(raw1.shape[1], raw2.shape[1])
+ size = yy, xx
+ data1 = self.__createMarginImage(raw1, size, transparent=True)
+ data2 = self.__createMarginImage(raw2, size, transparent=True)
+ self.__matching_keypoints = [0.0], [0.0], [1.0]
+ elif alignmentMode == AlignmentMode.CENTER:
+ yy = max(raw1.shape[0], raw2.shape[0])
+ xx = max(raw1.shape[1], raw2.shape[1])
+ size = yy, xx
+ data1 = self.__createMarginImage(
+ raw1, size, transparent=True, center=True
+ )
+ data2 = self.__createMarginImage(
+ raw2, size, transparent=True, center=True
+ )
+ self.__matching_keypoints = (
+ [data1.shape[1] // 2],
+ [data1.shape[0] // 2],
+ [1.0],
+ )
+ elif alignmentMode == AlignmentMode.STRETCH:
+ data1 = raw1
+ data2 = self.__rescaleImage(raw2, data1.shape)
+ self.__matching_keypoints = (
+ [0, data1.shape[1], data1.shape[1], 0],
+ [0, 0, data1.shape[0], data1.shape[0]],
+ [1.0, 1.0, 1.0, 1.0],
+ )
+ elif alignmentMode == AlignmentMode.AUTO:
+ # TODO: sift implementation do not support RGBA images
+ yy = max(raw1.shape[0], raw2.shape[0])
+ xx = max(raw1.shape[1], raw2.shape[1])
+ size = yy, xx
+ data1 = self.__createMarginImage(raw1, size)
+ data2 = self.__createMarginImage(raw2, size)
+ self.__matching_keypoints = [0.0], [0.0], [1.0]
+ try:
+ data1, data2 = self.__createSiftData(data1, data2)
+ if data2 is None:
+ raise ValueError("Unexpected None value")
+ except Exception as e:
+ # TODO: Display it on the GUI
+ _logger.error(e)
+ self.__setDefaultAlignmentMode()
+ return
+ else:
+ assert False
+
+ self.__item.setImageData1(data1)
+ self.__item.setImageData2(data2)
mode = self.getVisualizationMode()
if mode == VisualizationMode.COMPOSITE_RED_BLUE_GRAY_NEG:
- data1 = self.__composeImage(data1, data2, mode)
- data2 = numpy.empty((0, 0))
+ data1 = self.__composeRgbImage(data1, data2, mode)
+ data2 = None
elif mode == VisualizationMode.COMPOSITE_RED_BLUE_GRAY:
- data1 = self.__composeImage(data1, data2, mode)
- data2 = numpy.empty((0, 0))
+ data1 = self.__composeRgbImage(data1, data2, mode)
+ data2 = None
elif mode == VisualizationMode.COMPOSITE_A_MINUS_B:
- data1 = self.__composeImage(data1, data2, mode)
- data2 = numpy.empty((0, 0))
+ data1 = self.__composeAMinusBImage(data1, data2)
+ data2 = None
elif mode == VisualizationMode.ONLY_A:
- data2 = numpy.empty((0, 0))
+ data2 = None
elif mode == VisualizationMode.ONLY_B:
data1 = numpy.empty((0, 0))
self.__data1, self.__data2 = data1, data2
- self.__plot.addImage(data1, z=0, legend="image1", resetzoom=False)
- self.__plot.addImage(data2, z=0, legend="image2", resetzoom=False)
+
+ colormap = self.__getSealedColormap()
+ mode1 = self.__getImageMode(self.__data1)
+ if mode1 == "intensity":
+ colormap1 = colormap
+ else:
+ colormap1 = None
+ self.__plot.addImage(
+ data1, z=0, legend="image1", resetzoom=False, colormap=colormap1
+ )
self.__image1 = self.__plot.getImage("image1")
- self.__image2 = self.__plot.getImage("image2")
+
+ if data2 is not None:
+ mode2 = self.__getImageMode(data2)
+ if mode2 == "intensity":
+ colormap2 = colormap
+ else:
+ colormap2 = None
+ self.__plot.addImage(
+ data2, z=0, legend="image2", resetzoom=False, colormap=colormap2
+ )
+ self.__image2 = self.__plot.getImage("image2")
+ self.__image2.setVisible(True)
+ else:
+ if self.__image2 is not None:
+ self.__image2.setVisible(False)
+ self.__image2 = None
+ self.__data2 = numpy.empty((0, 0))
self.__updateKeyPoints()
# Set the separator into the middle
@@ -1004,27 +654,6 @@ class CompareImages(qt.QMainWindow):
value = self.__data1.shape[0] // 2
self.__hline.setPosition(0, value)
self.__updateSeparators()
- if updateColormap:
- self.__updateColormap()
-
- def __updateColormap(self):
- # TODO: The colormap histogram will still be wrong
- mode1 = self.__getImageMode(self.__data1)
- mode2 = self.__getImageMode(self.__data2)
- if mode1 == "intensity" and mode1 == mode2:
- if self.__data1.size == 0:
- vmin = self.__data2.min()
- vmax = self.__data2.max()
- elif self.__data2.size == 0:
- vmin = self.__data1.min()
- vmax = self.__data1.max()
- else:
- vmin = min(self.__data1.min(), self.__data2.min())
- vmax = max(self.__data1.max(), self.__data2.max())
- colormap = self.getColormap()
- colormap.setVRange(vmin=vmin, vmax=vmax)
- self.__image1.setColormap(colormap)
- self.__image2.setColormap(colormap)
def __getImageMode(self, image):
"""Returns a value identifying the way the image is stored in the
@@ -1060,62 +689,117 @@ class CompareImages(qt.QMainWindow):
data[:, :, c] = self.__rescaleArray(image[:, :, c], shape)
return data
- def __composeImage(self, data1, data2, mode):
+ def __composeRgbImage(self, data1, data2, mode):
"""Returns an RBG image containing composition of data1 and data2 in 2
different channels
+ A data image of a size of 0 is considered as missing. This does not
+ interrupt the processing.
+
:param numpy.ndarray data1: First image
:param numpy.ndarray data1: Second image
:param VisualizationMode mode: Composition mode.
:rtype: numpy.ndarray
"""
- assert(data1.shape[0:2] == data2.shape[0:2])
- if mode == VisualizationMode.COMPOSITE_A_MINUS_B:
- # TODO: this calculation has no interest of generating a 'composed'
- # rgb image, this could be moved in an other function or doc
- # should be modified
- _type = data1.dtype
- result = data1.astype(numpy.float64) - data2.astype(numpy.float64)
- return result
- mode1 = self.__getImageMode(data1)
- if mode1 in ["rgb", "rgba"]:
- intensity1 = self.__luminosityImage(data1)
- vmin1, vmax1 = 0.0, 1.0
+ if data1.size != 0 and data2.size != 0:
+ assert data1.shape[0:2] == data2.shape[0:2]
+
+ sealed = self.__getSealedColormap()
+ vmin, vmax = sealed.getVRange()
+
+ if data1.size == 0:
+ intensity1 = numpy.zeros(data2.shape[0:2])
else:
- intensity1 = data1
- vmin1, vmax1 = data1.min(), data1.max()
+ mode1 = self.__getImageMode(data1)
+ if mode1 in ["rgb", "rgba"]:
+ intensity1 = self.__luminosityImage(data1)
+ else:
+ intensity1 = data1
- mode2 = self.__getImageMode(data2)
- if mode2 in ["rgb", "rgba"]:
- intensity2 = self.__luminosityImage(data2)
- vmin2, vmax2 = 0.0, 1.0
+ if data2.size == 0:
+ intensity2 = numpy.zeros(data1.shape[0:2])
else:
- intensity2 = data2
- vmin2, vmax2 = data2.min(), data2.max()
+ mode2 = self.__getImageMode(data2)
+ if mode2 in ["rgb", "rgba"]:
+ intensity2 = self.__luminosityImage(data2)
+ else:
+ intensity2 = data2
- vmin, vmax = min(vmin1, vmin2) * 1.0, max(vmax1, vmax2) * 1.0
- shape = data1.shape
+ shape = intensity1.shape
result = numpy.empty((shape[0], shape[1], 3), dtype=numpy.uint8)
- a = (intensity1 - vmin) * (1.0 / (vmax - vmin)) * 255.0
- b = (intensity2 - vmin) * (1.0 / (vmax - vmin)) * 255.0
+ a, _, _ = normalize(
+ intensity1,
+ norm=sealed.getNormalization(),
+ autoscale=sealed.getAutoscaleMode(),
+ vmin=sealed.getVMin(),
+ vmax=sealed.getVMax(),
+ gamma=sealed.getGammaNormalizationParameter(),
+ )
+ b, _, _ = normalize(
+ intensity2,
+ norm=sealed.getNormalization(),
+ autoscale=sealed.getAutoscaleMode(),
+ vmin=sealed.getVMin(),
+ vmax=sealed.getVMax(),
+ gamma=sealed.getGammaNormalizationParameter(),
+ )
if mode == VisualizationMode.COMPOSITE_RED_BLUE_GRAY:
result[:, :, 0] = a
- result[:, :, 1] = (a + b) / 2
+ result[:, :, 1] = a // 2 + b // 2
result[:, :, 2] = b
elif mode == VisualizationMode.COMPOSITE_RED_BLUE_GRAY_NEG:
result[:, :, 0] = 255 - b
- result[:, :, 1] = 255 - (a + b) / 2
+ result[:, :, 1] = 255 - (a // 2 + b // 2)
result[:, :, 2] = 255 - a
return result
- def __luminosityImage(self, image):
+ def __composeAMinusBImage(self, data1, data2):
+ """Returns an intensity image containing the composition of `A-B`.
+
+ A data image of a size of 0 is considered as missing. This does not
+ interrupt the processing.
+
+ :param numpy.ndarray data1: First image
+ :param numpy.ndarray data1: Second image
+ :rtype: numpy.ndarray
+ """
+ if data1.size != 0 and data2.size != 0:
+ assert data1.shape[0:2] == data2.shape[0:2]
+
+ data1 = self.__asIntensityImage(data1)
+ data2 = self.__asIntensityImage(data2)
+ if data1.size == 0:
+ result = data2
+ elif data2.size == 0:
+ result = data1
+ else:
+ result = data1.astype(numpy.float32) - data2.astype(numpy.float32)
+ return result
+
+ def __asIntensityImage(self, image: numpy.ndarray):
+ """Returns an intensity image.
+
+ If the image use a single channel, it will be returned as it is.
+
+ If the image is an RBG(A) image, the luminosity (0..1) is extracted and
+ returned. The alpha channel is ignored.
+
+ :rtype: numpy.ndarray
+ """
+ mode = self.__getImageMode(image)
+ if mode in ["rgb", "rgba"]:
+ return self.__luminosityImage(image)
+ return image
+
+ def __luminosityImage(self, image: numpy.ndarray):
"""Returns the luminosity channel from an RBG(A) image.
+
The alpha channel is ignored.
:rtype: numpy.ndarray
"""
mode = self.__getImageMode(image)
- assert(mode in ["rgb", "rgba"])
+ assert mode in ["rgb", "rgba"]
is_uint8 = image.dtype.type == numpy.uint8
# luminosity
image = 0.21 * image[..., 0] + 0.72 * image[..., 1] + 0.07 * image[..., 2]
@@ -1128,8 +812,10 @@ class CompareImages(qt.QMainWindow):
:rtype: numpy.ndarray
"""
- y, x = numpy.ogrid[:shape[0], :shape[1]]
- y, x = y * 1.0 * (image.shape[0] - 1) / (shape[0] - 1), x * 1.0 * (image.shape[1] - 1) / (shape[1] - 1)
+ y, x = numpy.ogrid[: shape[0], : shape[1]]
+ y, x = y * 1.0 * (image.shape[0] - 1) / (shape[0] - 1), x * 1.0 * (
+ image.shape[1] - 1
+ ) / (shape[1] - 1)
b = silx.image.bilinear.BilinearImage(image)
# TODO: could be optimized using strides
x2d = numpy.zeros_like(y) + x
@@ -1142,8 +828,8 @@ class CompareImages(qt.QMainWindow):
:rtype: numpy.ndarray
"""
- assert(image.shape[0] <= size[0])
- assert(image.shape[1] <= size[1])
+ assert image.shape[0] <= size[0]
+ assert image.shape[1] <= size[1]
if image.shape == size:
return image
mode = self.__getImageMode(image)
@@ -1156,7 +842,7 @@ class CompareImages(qt.QMainWindow):
if mode == "intensity":
data = numpy.zeros(size, dtype=image.dtype)
- data[pos0:pos0 + image.shape[0], pos1:pos1 + image.shape[1]] = image
+ data[pos0 : pos0 + image.shape[0], pos1 : pos1 + image.shape[1]] = image
# TODO: It is maybe possible to put NaN on the margin
else:
if transparent:
@@ -1164,9 +850,13 @@ class CompareImages(qt.QMainWindow):
else:
data = numpy.zeros((size[0], size[1], 3), dtype=numpy.uint8)
depth = min(data.shape[2], image.shape[2])
- data[pos0:pos0 + image.shape[0], pos1:pos1 + image.shape[1], 0:depth] = image[:, :, 0:depth]
+ data[
+ pos0 : pos0 + image.shape[0], pos1 : pos1 + image.shape[1], 0:depth
+ ] = image[:, :, 0:depth]
if transparent and depth == 3:
- data[pos0:pos0 + image.shape[0], pos1:pos1 + image.shape[1], 3] = 255
+ data[
+ pos0 : pos0 + image.shape[0], pos1 : pos1 + image.shape[1], 3
+ ] = 255
return data
def __toAffineTransformation(self, sift_result):
@@ -1190,7 +880,7 @@ class CompareImages(qt.QMainWindow):
return AffineTransformation(tx, ty, sx, sy, rot)
def getTransformation(self):
- """Retuns the affine transformation applied to the second image to align
+ """Returns the affine transformation applied to the second image to align
it to the first image.
This result is only valid for sift alignment.
@@ -1219,9 +909,11 @@ class CompareImages(qt.QMainWindow):
_logger.info("Number of Keypoints within image 1: %i" % keypoints.size)
_logger.info(" within image 2: %i" % second_keypoints.size)
- self.__matching_keypoints = (match[:].x[:, 0],
- match[:].y[:, 0],
- match[:].scale[:, 0])
+ self.__matching_keypoints = (
+ match[:].x[:, 0],
+ match[:].y[:, 0],
+ match[:].scale[:, 0],
+ )
matching_keypoints = match.shape[0]
_logger.info("Matching keypoints: %i" % matching_keypoints)
if matching_keypoints == 0:
@@ -1241,6 +933,10 @@ class CompareImages(qt.QMainWindow):
self.__transformation = self.__toAffineTransformation(result)
return data1, data2
+ def resetZoom(self, dataMargins=None):
+ """Reset the plot limits to the bounds of the data and redraw the plot."""
+ self.__plot.resetZoom(dataMargins)
+
def setAutoResetZoom(self, activate=True):
"""
diff --git a/src/silx/gui/plot/ComplexImageView.py b/src/silx/gui/plot/ComplexImageView.py
index 7febd19..654a1c1 100644
--- a/src/silx/gui/plot/ComplexImageView.py
+++ b/src/silx/gui/plot/ComplexImageView.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -33,10 +33,8 @@ __date__ = "24/04/2018"
import logging
-import collections
import numpy
-from ...utils.deprecation import deprecated
from .. import qt, icons
from .PlotWindow import Plot2D
from . import items
@@ -48,6 +46,7 @@ _logger = logging.getLogger(__name__)
# Widgets
+
class _AmplitudeRangeDialog(qt.QDialog):
"""QDialog asking for the amplitude range to display."""
@@ -57,12 +56,9 @@ class _AmplitudeRangeDialog(qt.QDialog):
It provides the new range as a 2-tuple: (max, delta)
"""
- def __init__(self,
- parent=None,
- amplitudeRange=None,
- displayedRange=(None, 2)):
+ def __init__(self, parent=None, amplitudeRange=None, displayedRange=(None, 2)):
super(_AmplitudeRangeDialog, self).__init__(parent)
- self.setWindowTitle('Set Displayed Amplitude Range')
+ self.setWindowTitle("Set Displayed Amplitude Range")
if amplitudeRange is not None:
amplitudeRange = min(amplitudeRange), max(amplitudeRange)
@@ -74,25 +70,24 @@ class _AmplitudeRangeDialog(qt.QDialog):
if self._amplitudeRange is not None:
min_, max_ = self._amplitudeRange
- layout.addRow(
- qt.QLabel('Data Amplitude Range: [%g, %g]' % (min_, max_)))
+ layout.addRow(qt.QLabel("Data Amplitude Range: [%g, %g]" % (min_, max_)))
self._maxLineEdit = FloatEdit(parent=self)
- self._maxLineEdit.validator().setBottom(0.)
+ self._maxLineEdit.validator().setBottom(0.0)
self._maxLineEdit.setAlignment(qt.Qt.AlignRight)
self._maxLineEdit.editingFinished.connect(self._rangeUpdated)
- layout.addRow('Displayed Max.:', self._maxLineEdit)
+ layout.addRow("Displayed Max.:", self._maxLineEdit)
- self._autoscale = qt.QCheckBox('autoscale')
+ self._autoscale = qt.QCheckBox("autoscale")
self._autoscale.toggled.connect(self._autoscaleCheckBoxToggled)
- layout.addRow('', self._autoscale)
+ layout.addRow("", self._autoscale)
self._deltaLineEdit = FloatEdit(parent=self)
- self._deltaLineEdit.validator().setBottom(1.)
+ self._deltaLineEdit.validator().setBottom(1.0)
self._deltaLineEdit.setAlignment(qt.Qt.AlignRight)
self._deltaLineEdit.editingFinished.connect(self._rangeUpdated)
- layout.addRow('Displayed delta (log10 unit):', self._deltaLineEdit)
+ layout.addRow("Displayed delta (log10 unit):", self._deltaLineEdit)
buttons = qt.QDialogButtonBox(self)
buttons.addButton(qt.QDialogButtonBox.Ok)
@@ -107,8 +102,7 @@ class _AmplitudeRangeDialog(qt.QDialog):
self.rejected.connect(self._handleRejected)
def _resetDialogToDefault(self):
- """Set Widgets of the dialog from range information
- """
+ """Set Widgets of the dialog from range information"""
max_, delta = self._defaultDisplayedRange
if max_ is not None: # Not in autoscale
@@ -116,7 +110,7 @@ class _AmplitudeRangeDialog(qt.QDialog):
elif self._amplitudeRange is not None: # Autoscale with data
displayedMax = self._amplitudeRange[1]
else: # Autoscale without data
- displayedMax = ''
+ displayedMax = ""
if displayedMax == "":
self._maxLineEdit.setText("")
else:
@@ -149,7 +143,7 @@ class _AmplitudeRangeDialog(qt.QDialog):
"""Handle autoscale checkbox state changes"""
if checked: # Use default values
if self._amplitudeRange is None:
- max_ = ''
+ max_ = ""
else:
max_ = self._amplitudeRange[1]
if max_ == "":
@@ -167,21 +161,31 @@ class _ComplexDataToolButton(qt.QToolButton):
:param plot: The :class:`ComplexImageView` to control
"""
- _MODES = collections.OrderedDict([
- (ImageComplexData.ComplexMode.ABSOLUTE, ('math-amplitude', 'Amplitude')),
- (ImageComplexData.ComplexMode.SQUARE_AMPLITUDE,
- ('math-square-amplitude', 'Square amplitude')),
- (ImageComplexData.ComplexMode.PHASE, ('math-phase', 'Phase')),
- (ImageComplexData.ComplexMode.REAL, ('math-real', 'Real part')),
- (ImageComplexData.ComplexMode.IMAGINARY,
- ('math-imaginary', 'Imaginary part')),
- (ImageComplexData.ComplexMode.AMPLITUDE_PHASE,
- ('math-phase-color', 'Amplitude and Phase')),
- (ImageComplexData.ComplexMode.LOG10_AMPLITUDE_PHASE,
- ('math-phase-color-log', 'Log10(Amp.) and Phase'))
- ])
-
- _RANGE_DIALOG_TEXT = 'Set Amplitude Range...'
+ _MODES = dict(
+ [
+ (ImageComplexData.ComplexMode.ABSOLUTE, ("math-amplitude", "Amplitude")),
+ (
+ ImageComplexData.ComplexMode.SQUARE_AMPLITUDE,
+ ("math-square-amplitude", "Square amplitude"),
+ ),
+ (ImageComplexData.ComplexMode.PHASE, ("math-phase", "Phase")),
+ (ImageComplexData.ComplexMode.REAL, ("math-real", "Real part")),
+ (
+ ImageComplexData.ComplexMode.IMAGINARY,
+ ("math-imaginary", "Imaginary part"),
+ ),
+ (
+ ImageComplexData.ComplexMode.AMPLITUDE_PHASE,
+ ("math-phase-color", "Amplitude and Phase"),
+ ),
+ (
+ ImageComplexData.ComplexMode.LOG10_AMPLITUDE_PHASE,
+ ("math-phase-color-log", "Log10(Amp.) and Phase"),
+ ),
+ ]
+ )
+
+ _RANGE_DIALOG_TEXT = "Set Amplitude Range..."
def __init__(self, parent=None, plot=None):
super(_ComplexDataToolButton, self).__init__(parent=parent)
@@ -207,16 +211,16 @@ class _ComplexDataToolButton(qt.QToolButton):
self.setPopupMode(qt.QToolButton.InstantPopup)
self._modeChanged(self._plot2DComplex.getComplexMode())
- self._plot2DComplex.sigVisualizationModeChanged.connect(
- self._modeChanged)
+ self._plot2DComplex.sigVisualizationModeChanged.connect(self._modeChanged)
def _modeChanged(self, mode):
"""Handle change of visualization modes"""
icon, text = self._MODES[mode]
self.setIcon(icons.getQIcon(icon))
- self.setToolTip('Display the ' + text.lower())
+ self.setToolTip("Display the " + text.lower())
self._rangeDialogAction.setEnabled(
- mode == ImageComplexData.ComplexMode.LOG10_AMPLITUDE_PHASE)
+ mode == ImageComplexData.ComplexMode.LOG10_AMPLITUDE_PHASE
+ )
def _triggered(self, action):
"""Handle triggering of menu actions"""
@@ -236,7 +240,8 @@ class _ComplexDataToolButton(qt.QToolButton):
dialog = _AmplitudeRangeDialog(
parent=self,
amplitudeRange=dataRange,
- displayedRange=self._plot2DComplex._getAmplitudeRangeInfo())
+ displayedRange=self._plot2DComplex._getAmplitudeRangeInfo(),
+ )
dialog.sigRangeChanged.connect(self._rangeChanged)
dialog.exec()
dialog.sigRangeChanged.disconnect(self._rangeChanged)
@@ -272,7 +277,7 @@ class ComplexImageView(qt.QWidget):
def __init__(self, parent=None):
super(ComplexImageView, self).__init__(parent)
if parent is None:
- self.setWindowTitle('ComplexImageView')
+ self.setWindowTitle("ComplexImageView")
self._plot2D = Plot2D(self)
@@ -284,14 +289,13 @@ class ComplexImageView(qt.QWidget):
# Create and add image to the plot
self._plotImage = ImageComplexData()
- self._plotImage.setName('__ComplexImageView__complex_image__')
+ self._plotImage.setName("__ComplexImageView__complex_image__")
self._plotImage.sigItemChanged.connect(self._itemChanged)
self._plot2D.addItem(self._plotImage)
- self._plot2D.setActiveImage(self._plotImage.getName())
+ self._plot2D.setActiveImage(self._plotImage)
- toolBar = qt.QToolBar('Complex', self)
- toolBar.addWidget(
- _ComplexDataToolButton(parent=self, plot=self))
+ toolBar = qt.QToolBar("Complex", self)
+ toolBar.addWidget(_ComplexDataToolButton(parent=self, plot=self))
self._plot2D.insertToolBar(self._plot2D.getProfileToolbar(), toolBar)
@@ -344,8 +348,10 @@ class ComplexImageView(qt.QWidget):
:rtype: numpy.ndarray of float with 2 dims or RGBA image (uint8).
"""
mode = self.getComplexMode()
- if mode in (self.ComplexMode.AMPLITUDE_PHASE,
- self.ComplexMode.LOG10_AMPLITUDE_PHASE):
+ if mode in (
+ self.ComplexMode.AMPLITUDE_PHASE,
+ self.ComplexMode.LOG10_AMPLITUDE_PHASE,
+ ):
return self._plotImage.getRgbaImageData(copy=copy)
else:
return self._plotImage.getData(copy=copy)
@@ -354,19 +360,6 @@ class ComplexImageView(qt.QWidget):
Mode = ComplexMode
- @classmethod
- @deprecated(replacement='supportedComplexModes', since_version='0.11.0')
- def getSupportedVisualizationModes(cls):
- return cls.supportedComplexModes()
-
- @deprecated(replacement='setComplexMode', since_version='0.11.0')
- def setVisualizationMode(self, mode):
- return self.setComplexMode(mode)
-
- @deprecated(replacement='getComplexMode', since_version='0.11.0')
- def getVisualizationMode(self):
- return self.getComplexMode()
-
# Image item proxy
@staticmethod
@@ -490,7 +483,7 @@ class ComplexImageView(qt.QWidget):
:rtype: :class:`.items.Axis`
"""
- return self.getPlot().getYAxis(axis='left')
+ return self.getPlot().getYAxis(axis="left")
def getGraphTitle(self):
"""Return the plot main title as a str."""
diff --git a/src/silx/gui/plot/CurvesROIWidget.py b/src/silx/gui/plot/CurvesROIWidget.py
index f0cc7f3..bd47da0 100644
--- a/src/silx/gui/plot/CurvesROIWidget.py
+++ b/src/silx/gui/plot/CurvesROIWidget.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2004-2022 European Synchrotron Radiation Facility
+# Copyright (c) 2004-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -32,14 +32,12 @@ __authors__ = ["V.A. Sole", "T. Vincent", "H. Payno"]
__license__ = "MIT"
__date__ = "13/03/2018"
-from collections import OrderedDict
import logging
import os
import sys
import functools
import numpy
from silx.io import dictdump
-from silx.utils import deprecation
from silx.utils.weakref import WeakMethodProxy
from silx.utils.proxy import docstring
from .. import icons, qt
@@ -107,8 +105,7 @@ class CurvesROIWidget(qt.QWidget):
layout.addWidget(self.headerLabel)
widgetAllCheckbox = qt.QWidget(parent=self)
- self._showAllCheckBox = qt.QCheckBox("show all ROI",
- parent=widgetAllCheckbox)
+ self._showAllCheckBox = qt.QCheckBox("show all ROI", parent=widgetAllCheckbox)
widgetAllCheckbox.setLayout(qt.QHBoxLayout())
spacer = qt.QWidget(parent=widgetAllCheckbox)
spacer.setSizePolicy(qt.QSizePolicy.Expanding, qt.QSizePolicy.Fixed)
@@ -132,14 +129,15 @@ class CurvesROIWidget(qt.QWidget):
self.addButton = qt.QPushButton(hbox)
self.addButton.setText("Add ROI")
- self.addButton.setToolTip('Create a new ROI')
+ self.addButton.setToolTip("Create a new ROI")
self.delButton = qt.QPushButton(hbox)
self.delButton.setText("Delete ROI")
- self.addButton.setToolTip('Remove the selected ROI')
+ self.addButton.setToolTip("Remove the selected ROI")
self.resetButton = qt.QPushButton(hbox)
self.resetButton.setText("Reset")
- self.addButton.setToolTip('Clear all created ROIs. We only let the '
- 'default ROI')
+ self.addButton.setToolTip(
+ "Clear all created ROIs. We only let the " "default ROI"
+ )
hboxlayout.addWidget(self.addButton)
hboxlayout.addWidget(self.delButton)
@@ -149,10 +147,10 @@ class CurvesROIWidget(qt.QWidget):
self.loadButton = qt.QPushButton(hbox)
self.loadButton.setText("Load")
- self.loadButton.setToolTip('Load ROIs from a .ini file')
+ self.loadButton.setToolTip("Load ROIs from a .ini file")
self.saveButton = qt.QPushButton(hbox)
self.saveButton.setText("Save")
- self.loadButton.setToolTip('Save ROIs to a .ini file')
+ self.loadButton.setToolTip("Save ROIs to a .ini file")
hboxlayout.addWidget(self.loadButton)
hboxlayout.addWidget(self.saveButton)
layout.setStretchFactor(self.headerLabel, 0)
@@ -210,6 +208,7 @@ class CurvesROIWidget(qt.QWidget):
def _add(self):
"""Add button clicked handler"""
+
def getNextRoiName():
rois = self.roiTable.getRois(order=None)
roisNames = []
@@ -224,6 +223,7 @@ class CurvesROIWidget(qt.QWidget):
i += 1
newroi = "newroi %d" % i
return newroi
+
roi = ROI(name=getNextRoiName())
if roi.getName() == "ICR":
@@ -242,9 +242,9 @@ class CurvesROIWidget(qt.QWidget):
# back compatibility pymca roi signals
ddict = {}
- ddict['event'] = "AddROI"
- ddict['roilist'] = self.roiTable.roidict.values()
- ddict['roidict'] = self.roiTable.roidict
+ ddict["event"] = "AddROI"
+ ddict["roilist"] = self.roiTable.roidict.values()
+ ddict["roidict"] = self.roiTable.roidict
self.sigROIWidgetSignal.emit(ddict)
# end back compatibility pymca roi signals
@@ -254,9 +254,9 @@ class CurvesROIWidget(qt.QWidget):
# back compatibility pymca roi signals
ddict = {}
- ddict['event'] = "DelROI"
- ddict['roilist'] = self.roiTable.roidict.values()
- ddict['roidict'] = self.roiTable.roidict
+ ddict["event"] = "DelROI"
+ ddict["roilist"] = self.roiTable.roidict.values()
+ ddict["roidict"] = self.roiTable.roidict
self.sigROIWidgetSignal.emit(ddict)
# end back compatibility pymca roi signals
@@ -269,17 +269,16 @@ class CurvesROIWidget(qt.QWidget):
# back compatibility pymca roi signals
ddict = {}
- ddict['event'] = "ResetROI"
- ddict['roilist'] = self.roiTable.roidict.values()
- ddict['roidict'] = self.roiTable.roidict
+ ddict["event"] = "ResetROI"
+ ddict["roilist"] = self.roiTable.roidict.values()
+ ddict["roidict"] = self.roiTable.roidict
self.sigROIWidgetSignal.emit(ddict)
# end back compatibility pymca roi signals
def _load(self):
"""Load button clicked handler"""
dialog = qt.QFileDialog(self)
- dialog.setNameFilters(
- ['INI File *.ini', 'JSON File *.json', 'All *.*'])
+ dialog.setNameFilters(["INI File *.ini", "JSON File *.json", "All *.*"])
dialog.setFileMode(qt.QFileDialog.ExistingFile)
dialog.setDirectory(self.roiFileDir)
if not dialog.exec():
@@ -295,9 +294,9 @@ class CurvesROIWidget(qt.QWidget):
# back compatibility pymca roi signals
ddict = {}
- ddict['event'] = "LoadROI"
- ddict['roilist'] = self.roiTable.roidict.values()
- ddict['roidict'] = self.roiTable.roidict
+ ddict["event"] = "LoadROI"
+ ddict["roilist"] = self.roiTable.roidict.values()
+ ddict["roidict"] = self.roiTable.roidict
self.sigROIWidgetSignal.emit(ddict)
# end back compatibility pymca roi signals
@@ -311,7 +310,7 @@ class CurvesROIWidget(qt.QWidget):
def _save(self):
"""Save button clicked handler"""
dialog = qt.QFileDialog(self)
- dialog.setNameFilters(['INI File *.ini', 'JSON File *.json'])
+ dialog.setNameFilters(["INI File *.ini", "JSON File *.json"])
dialog.setFileMode(qt.QFileDialog.AnyFile)
dialog.setAcceptMode(qt.QFileDialog.AcceptSave)
dialog.setDirectory(self.roiFileDir)
@@ -320,7 +319,7 @@ class CurvesROIWidget(qt.QWidget):
return
outputFile = dialog.selectedFiles()[0]
- extension = '.' + dialog.selectedNameFilter().split('.')[-1]
+ extension = "." + dialog.selectedNameFilter().split(".")[-1]
dialog.close()
if not outputFile.endswith(extension):
@@ -345,16 +344,10 @@ class CurvesROIWidget(qt.QWidget):
"""
self.roiTable.save(filename)
- def setHeader(self, text='ROIs'):
+ def setHeader(self, text="ROIs"):
"""Set the header text of this widget"""
self.headerLabel.setText("<b>%s<\b>" % text)
- @deprecation.deprecated(replacement="calculateRois",
- reason="CamelCase convention",
- since_version="0.7")
- def calculateROIs(self, *args, **kw):
- self.calculateRois(*args, **kw)
-
def calculateRois(self, roiList=None, roiDict=None):
"""Compute ROI information"""
return self.roiTable.calculateRois()
@@ -367,7 +360,7 @@ class CurvesROIWidget(qt.QWidget):
plot = self.getPlotWidget()
curves = () if plot is None else plot.getAllCurves()
if not curves:
- return 1.0, 1.0, 100., 100.
+ return 1.0, 1.0, 100.0, 100.0
xmin, ymin = None, None
xmax, ymax = None, None
@@ -420,12 +413,12 @@ class CurvesROIWidget(qt.QWidget):
def _emitCurrentROISignal(self):
ddict = {}
- ddict['event'] = "currentROISignal"
+ ddict["event"] = "currentROISignal"
if self.roiTable.activeRoi is not None:
- ddict['ROI'] = self.roiTable.activeRoi.toDict()
- ddict['current'] = self.roiTable.activeRoi.getName()
+ ddict["ROI"] = self.roiTable.activeRoi.toDict()
+ ddict["current"] = self.roiTable.activeRoi.getName()
else:
- ddict['current'] = None
+ ddict["current"] = None
if self.__lastSigROISignal != ddict:
self.__lastSigROISignal = ddict
@@ -440,13 +433,14 @@ class _FloatItem(qt.QTableWidgetItem):
"""
Simple QTableWidgetItem overloading the < operator to deal with ordering
"""
+
def __init__(self):
qt.QTableWidgetItem.__init__(self, type=qt.QTableWidgetItem.Type)
def __lt__(self, other):
- if self.text() in ('', ROITable.INFO_NOT_FOUND):
+ if self.text() in ("", ROITable.INFO_NOT_FOUND):
return False
- if other.text() in ('', ROITable.INFO_NOT_FOUND):
+ if other.text() in ("", ROITable.INFO_NOT_FOUND):
return True
return float(self.text()) < float(other.text())
@@ -464,21 +458,23 @@ class ROITable(TableWidget):
"""Signal emitted when the active roi changed or when the value of the
active roi are changing"""
- COLUMNS_INDEX = OrderedDict([
- ('ID', 0),
- ('ROI', 1),
- ('Type', 2),
- ('From', 3),
- ('To', 4),
- ('Raw Counts', 5),
- ('Net Counts', 6),
- ('Raw Area', 7),
- ('Net Area', 8),
- ])
+ COLUMNS_INDEX = dict(
+ [
+ ("ID", 0),
+ ("ROI", 1),
+ ("Type", 2),
+ ("From", 3),
+ ("To", 4),
+ ("Raw Counts", 5),
+ ("Net Counts", 6),
+ ("Raw Area", 7),
+ ("Net Area", 8),
+ ]
+ )
COLUMNS = list(COLUMNS_INDEX.keys())
- INFO_NOT_FOUND = '????????'
+ INFO_NOT_FOUND = "????????"
def __init__(self, parent=None, plot=None, rois=None):
super(ROITable, self).__init__(parent)
@@ -529,26 +525,32 @@ class ROITable(TableWidget):
header = self.horizontalHeader()
header.setSectionResizeMode(qt.QHeaderView.ResizeToContents)
self.sortByColumn(0, qt.Qt.AscendingOrder)
- self.hideColumn(self.COLUMNS_INDEX['ID'])
+ self.hideColumn(self.COLUMNS_INDEX["ID"])
def setPlot(self, plot):
self.clear()
self.plot = plot
def __setTooltip(self):
- self.horizontalHeaderItem(self.COLUMNS_INDEX['ROI']).setToolTip(
- 'Region of interest identifier')
- self.horizontalHeaderItem(self.COLUMNS_INDEX['Type']).setToolTip(
- 'Type of the ROI')
- self.horizontalHeaderItem(self.COLUMNS_INDEX['From']).setToolTip(
- 'X-value of the min point')
- self.horizontalHeaderItem(self.COLUMNS_INDEX['To']).setToolTip(
- 'X-value of the max point')
- self.horizontalHeaderItem(self.COLUMNS_INDEX['Raw Counts']).setToolTip(
- 'Estimation of the integral between y=0 and the selected curve')
- self.horizontalHeaderItem(self.COLUMNS_INDEX['Net Counts']).setToolTip(
- 'Estimation of the integral between the segment [maxPt, minPt] '
- 'and the selected curve')
+ self.horizontalHeaderItem(self.COLUMNS_INDEX["ROI"]).setToolTip(
+ "Region of interest identifier"
+ )
+ self.horizontalHeaderItem(self.COLUMNS_INDEX["Type"]).setToolTip(
+ "Type of the ROI"
+ )
+ self.horizontalHeaderItem(self.COLUMNS_INDEX["From"]).setToolTip(
+ "X-value of the min point"
+ )
+ self.horizontalHeaderItem(self.COLUMNS_INDEX["To"]).setToolTip(
+ "X-value of the max point"
+ )
+ self.horizontalHeaderItem(self.COLUMNS_INDEX["Raw Counts"]).setToolTip(
+ "Estimation of the integral between y=0 and the selected curve"
+ )
+ self.horizontalHeaderItem(self.COLUMNS_INDEX["Net Counts"]).setToolTip(
+ "Estimation of the integral between the segment [maxPt, minPt] "
+ "and the selected curve"
+ )
def setRois(self, rois, order=None):
"""Set the ROIs by providing a dictionary of ROI information.
@@ -565,7 +567,7 @@ class ROITable(TableWidget):
:param str order: Field used for ordering the ROIs.
One of "from", "to", "type".
None (default) for no ordering, or same order as specified
- in parameter ``roidict`` if provided as an OrderedDict.
+ in parameter ``rois`` if provided as a dict.
"""
assert order in [None, "from", "to", "type"]
self.clear()
@@ -576,7 +578,7 @@ class ROITable(TableWidget):
if isinstance(roi, ROI):
_roi = roi
else:
- roi['name'] = roiName
+ roi["name"] = roiName
_roi = ROI._fromDict(roi)
self.addRoi(_roi)
else:
@@ -591,12 +593,11 @@ class ROITable(TableWidget):
:param :class:`ROI` roi: roi to add to the table
"""
assert isinstance(roi, ROI)
- self._getItem(name='ID', row=None, roi=roi)
+ self._getItem(name="ID", row=None, roi=roi)
self._roiDict[roi.getID()] = roi
self._markersHandler.add(roi, _RoiMarkerHandler(roi, self.plot))
self._updateRoiInfo(roi.getID())
- callback = functools.partial(WeakMethodProxy(self._updateRoiInfo),
- roi.getID())
+ callback = functools.partial(WeakMethodProxy(self._updateRoiInfo), roi.getID())
roi.sigChanged.connect(callback)
# set it as the active one
self.setActiveRoi(roi)
@@ -609,7 +610,7 @@ class ROITable(TableWidget):
if item:
return item
else:
- if name == 'ID':
+ if name == "ID":
assert roi
if roi.getID() in self._roiToItems:
return self._roiToItems[roi.getID()]
@@ -617,41 +618,47 @@ class ROITable(TableWidget):
# create a new row
row = self.rowCount()
self.setRowCount(self.rowCount() + 1)
- item = qt.QTableWidgetItem(str(roi.getID()),
- type=qt.QTableWidgetItem.Type)
+ item = qt.QTableWidgetItem(
+ str(roi.getID()), type=qt.QTableWidgetItem.Type
+ )
self._roiToItems[roi.getID()] = item
- elif name == 'ROI':
- item = qt.QTableWidgetItem(roi.getName() if roi else '',
- type=qt.QTableWidgetItem.Type)
- if roi.getName().upper() in ('ICR', 'DEFAULT'):
+ elif name == "ROI":
+ item = qt.QTableWidgetItem(
+ roi.getName() if roi else "", type=qt.QTableWidgetItem.Type
+ )
+ if roi.getName().upper() in ("ICR", "DEFAULT"):
item.setFlags(qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled)
else:
- item.setFlags(qt.Qt.ItemIsSelectable |
- qt.Qt.ItemIsEnabled |
- qt.Qt.ItemIsEditable)
- elif name == 'Type':
+ item.setFlags(
+ qt.Qt.ItemIsSelectable
+ | qt.Qt.ItemIsEnabled
+ | qt.Qt.ItemIsEditable
+ )
+ elif name == "Type":
item = qt.QTableWidgetItem(type=qt.QTableWidgetItem.Type)
item.setFlags((qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled))
- elif name in ('To', 'From'):
+ elif name in ("To", "From"):
item = _FloatItem()
- if roi.getName().upper() in ('ICR', 'DEFAULT'):
+ if roi.getName().upper() in ("ICR", "DEFAULT"):
item.setFlags(qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled)
else:
- item.setFlags(qt.Qt.ItemIsSelectable |
- qt.Qt.ItemIsEnabled |
- qt.Qt.ItemIsEditable)
- elif name in ('Raw Counts', 'Net Counts', 'Raw Area', 'Net Area'):
+ item.setFlags(
+ qt.Qt.ItemIsSelectable
+ | qt.Qt.ItemIsEnabled
+ | qt.Qt.ItemIsEditable
+ )
+ elif name in ("Raw Counts", "Net Counts", "Raw Area", "Net Area"):
item = _FloatItem()
item.setFlags((qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled))
else:
- raise ValueError('item type not recognized')
+ raise ValueError("item type not recognized")
self.setItem(row, self.COLUMNS_INDEX[name], item)
return item
def _itemChanged(self, item):
def getRoi():
- IDItem = self.item(item.row(), self.COLUMNS_INDEX['ID'])
+ IDItem = self.item(item.row(), self.COLUMNS_INDEX["ID"])
assert IDItem
id = int(IDItem.text())
assert id in self._roiDict
@@ -663,21 +670,21 @@ class ROITable(TableWidget):
self.activeROIChanged.emit()
self._userIsEditingRoi = True
- if item.column() in (self.COLUMNS_INDEX['To'], self.COLUMNS_INDEX['From']):
+ if item.column() in (self.COLUMNS_INDEX["To"], self.COLUMNS_INDEX["From"]):
roi = getRoi()
- if item.text() not in ('', self.INFO_NOT_FOUND):
+ if item.text() not in ("", self.INFO_NOT_FOUND):
try:
value = float(item.text())
except ValueError:
value = 0
changed = False
- if item.column() == self.COLUMNS_INDEX['To']:
+ if item.column() == self.COLUMNS_INDEX["To"]:
if value != roi.getTo():
roi.setTo(value)
changed = True
else:
- assert(item.column() == self.COLUMNS_INDEX['From'])
+ assert item.column() == self.COLUMNS_INDEX["From"]
if value != roi.getFrom():
roi.setFrom(value)
changed = True
@@ -685,7 +692,7 @@ class ROITable(TableWidget):
self._updateMarker(roi.getName())
signalChanged(roi)
- if item.column() is self.COLUMNS_INDEX['ROI']:
+ if item.column() is self.COLUMNS_INDEX["ROI"]:
roi = getRoi()
if roi.getName() != item.text():
roi.setName(item.text())
@@ -705,7 +712,7 @@ class ROITable(TableWidget):
roiToRm = set()
for item in activeItems:
row = item.row()
- itemID = self.item(row, self.COLUMNS_INDEX['ID'])
+ itemID = self.item(row, self.COLUMNS_INDEX["ID"])
roiToRm.add(self._roiDict[int(itemID.text())])
[self.removeROI(roi) for roi in roiToRm]
self.blockSignals(old)
@@ -726,8 +733,9 @@ class ROITable(TableWidget):
del self._roiDict[roi.getID()]
self._markersHandler.remove(roi)
- callback = functools.partial(WeakMethodProxy(self._updateRoiInfo),
- roi.getID())
+ callback = functools.partial(
+ WeakMethodProxy(self._updateRoiInfo), roi.getID()
+ )
roi.sigChanged.connect(callback)
def setActiveRoi(self, roi):
@@ -769,42 +777,42 @@ class ROITable(TableWidget):
roi.setTo(max)
roi.blockSignals(False)
- itemID = self._getItem(name='ID', roi=roi, row=None)
- itemName = self._getItem(name='ROI', row=itemID.row(), roi=roi)
+ itemID = self._getItem(name="ID", roi=roi, row=None)
+ itemName = self._getItem(name="ROI", row=itemID.row(), roi=roi)
itemName.setText(roi.getName())
- itemType = self._getItem(name='Type', row=itemID.row(), roi=roi)
+ itemType = self._getItem(name="Type", row=itemID.row(), roi=roi)
itemType.setText(roi.getType() or self.INFO_NOT_FOUND)
- itemFrom = self._getItem(name='From', row=itemID.row(), roi=roi)
- fromdata = str(roi.getFrom()) if roi.getFrom() is not None else self.INFO_NOT_FOUND
+ itemFrom = self._getItem(name="From", row=itemID.row(), roi=roi)
+ fromdata = (
+ str(roi.getFrom()) if roi.getFrom() is not None else self.INFO_NOT_FOUND
+ )
itemFrom.setText(fromdata)
- itemTo = self._getItem(name='To', row=itemID.row(), roi=roi)
+ itemTo = self._getItem(name="To", row=itemID.row(), roi=roi)
todata = str(roi.getTo()) if roi.getTo() is not None else self.INFO_NOT_FOUND
itemTo.setText(todata)
rawCounts, netCounts = roi.computeRawAndNetCounts(
- curve=self.plot.getActiveCurve(just_legend=False))
- itemRawCounts = self._getItem(name='Raw Counts', row=itemID.row(),
- roi=roi)
+ curve=self.plot.getActiveCurve(just_legend=False)
+ )
+ itemRawCounts = self._getItem(name="Raw Counts", row=itemID.row(), roi=roi)
rawCounts = str(rawCounts) if rawCounts is not None else self.INFO_NOT_FOUND
itemRawCounts.setText(rawCounts)
- itemNetCounts = self._getItem(name='Net Counts', row=itemID.row(),
- roi=roi)
+ itemNetCounts = self._getItem(name="Net Counts", row=itemID.row(), roi=roi)
netCounts = str(netCounts) if netCounts is not None else self.INFO_NOT_FOUND
itemNetCounts.setText(netCounts)
rawArea, netArea = roi.computeRawAndNetArea(
- curve=self.plot.getActiveCurve(just_legend=False))
- itemRawArea = self._getItem(name='Raw Area', row=itemID.row(),
- roi=roi)
+ curve=self.plot.getActiveCurve(just_legend=False)
+ )
+ itemRawArea = self._getItem(name="Raw Area", row=itemID.row(), roi=roi)
rawArea = str(rawArea) if rawArea is not None else self.INFO_NOT_FOUND
itemRawArea.setText(rawArea)
- itemNetArea = self._getItem(name='Net Area', row=itemID.row(),
- roi=roi)
+ itemNetArea = self._getItem(name="Net Area", row=itemID.row(), roi=roi)
netArea = str(netArea) if netArea is not None else self.INFO_NOT_FOUND
itemNetArea.setText(netArea)
@@ -813,49 +821,23 @@ class ROITable(TableWidget):
def currentChanged(self, current, previous):
if previous and current.row() != previous.row() and current.row() >= 0:
- roiItem = self.item(current.row(),
- self.COLUMNS_INDEX['ID'])
+ roiItem = self.item(current.row(), self.COLUMNS_INDEX["ID"])
assert roiItem
self.setActiveRoi(self._roiDict[int(roiItem.text())])
self._markersHandler.updateAllMarkers()
qt.QTableWidget.currentChanged(self, current, previous)
- @deprecation.deprecated(reason="Removed",
- replacement="roidict and roidict.values()",
- since_version="0.10.0")
- def getROIListAndDict(self):
- """
-
- :return: the list of roi objects and the dictionary of roi name to roi
- object.
- """
- roidict = self._roiDict
- return list(roidict.values()), roidict
-
- def calculateRois(self, roiList=None, roiDict=None):
- """
- Update values of all registred rois (raw and net counts in particular)
-
- :param roiList: deprecated parameter
- :param roiDict: deprecated parameter
- """
- if roiDict:
- deprecation.deprecated_warning(name='roiDict', type_='Parameter',
- reason='Unused parameter',
- since_version="0.10.0")
- if roiList:
- deprecation.deprecated_warning(name='roiList', type_='Parameter',
- reason='Unused parameter',
- since_version="0.10.0")
-
+ def calculateRois(self):
+ """Update values of all registred rois (raw and net counts in particular)"""
for roiID in self._roiDict:
self._updateRoiInfo(roiID)
def _updateMarker(self, roiID):
"""Make sure the marker of the given roi name is updated"""
- if self._showAllMarkers or (self.activeRoi
- and self.activeRoi.getName() == roiID):
+ if self._showAllMarkers or (
+ self.activeRoi and self.activeRoi.getName() == roiID
+ ):
self._updateMarkers()
def _updateMarkers(self):
@@ -865,7 +847,9 @@ class ROITable(TableWidget):
if not self.activeRoi or not self.plot:
return
assert isinstance(self.activeRoi, ROI)
- markerHandler = self._markersHandler.getMarkerHandler(self.activeRoi.getID())
+ markerHandler = self._markersHandler.getMarkerHandler(
+ self.activeRoi.getID()
+ )
if markerHandler is not None:
markerHandler.updateMarkers()
@@ -884,12 +868,16 @@ class ROITable(TableWidget):
if order is None or order.lower() == "none":
ordered_roilist = list(self._roiDict.values())
- res = OrderedDict([(roi.getName(), self._roiDict[roi.getID()]) for roi in ordered_roilist])
+ res = dict(
+ [(roi.getName(), self._roiDict[roi.getID()]) for roi in ordered_roilist]
+ )
else:
assert order in ["from", "to", "type", "netcounts", "rawcounts"]
- ordered_roilist = sorted(self._roiDict.keys(),
- key=lambda roi_id: self._roiDict[roi_id].get(order))
- res = OrderedDict([(roi.getName(), self._roiDict[id]) for id in ordered_roilist])
+ ordered_roilist = sorted(
+ self._roiDict.keys(),
+ key=lambda roi_id: self._roiDict[roi_id].get(order),
+ )
+ res = dict([(roi.getName(), self._roiDict[id]) for id in ordered_roilist])
return res
@@ -904,7 +892,7 @@ class ROITable(TableWidget):
for roiID, roi in self._roiDict.items():
roilist.append(roi.toDict())
roidict[roi.getName()] = roi.toDict()
- datadict = {'ROI': {'roilist': roilist, 'roidict': roidict}}
+ datadict = {"ROI": {"roilist": roilist, "roidict": roidict}}
dictdump.dump(datadict, filename)
def load(self, filename):
@@ -917,9 +905,9 @@ class ROITable(TableWidget):
rois = []
# Remove rawcounts and netcounts from ROIs
- for roiDict in roisDict['ROI']['roidict'].values():
- roiDict.pop('rawcounts', None)
- roiDict.pop('netcounts', None)
+ for roiDict in roisDict["ROI"]["roidict"].values():
+ roiDict.pop("rawcounts", None)
+ roiDict.pop("netcounts", None)
rois.append(ROI._fromDict(roiDict))
self.setRois(rois)
@@ -946,14 +934,13 @@ class ROITable(TableWidget):
def _handleROIMarkerEvent(self, ddict):
"""Handle plot signals related to marker events."""
- if ddict['event'] == 'markerMoved':
- label = ddict['label']
+ if ddict["event"] == "markerMoved":
+ label = ddict["label"]
roiID = self._markersHandler.getRoiID(markerID=label)
if roiID is not None:
# avoid several emission of sigROISignal
old = self.blockSignals(True)
- self._markersHandler.changePosition(markerID=label,
- x=ddict['x'])
+ self._markersHandler.changePosition(markerID=label, x=ddict["x"])
self.blockSignals(old)
self._updateRoiInfo(roiID)
@@ -994,11 +981,11 @@ class ROITable(TableWidget):
should be visible.
"""
if visible is True:
- self.showColumn(self.COLUMNS_INDEX['Raw Counts'])
- self.showColumn(self.COLUMNS_INDEX['Net Counts'])
+ self.showColumn(self.COLUMNS_INDEX["Raw Counts"])
+ self.showColumn(self.COLUMNS_INDEX["Net Counts"])
else:
- self.hideColumn(self.COLUMNS_INDEX['Raw Counts'])
- self.hideColumn(self.COLUMNS_INDEX['Net Counts'])
+ self.hideColumn(self.COLUMNS_INDEX["Raw Counts"])
+ self.hideColumn(self.COLUMNS_INDEX["Net Counts"])
def setAreaVisible(self, visible):
"""
@@ -1008,11 +995,11 @@ class ROITable(TableWidget):
should be visible.
"""
if visible is True:
- self.showColumn(self.COLUMNS_INDEX['Raw Area'])
- self.showColumn(self.COLUMNS_INDEX['Net Area'])
+ self.showColumn(self.COLUMNS_INDEX["Raw Area"])
+ self.showColumn(self.COLUMNS_INDEX["Net Area"])
else:
- self.hideColumn(self.COLUMNS_INDEX['Raw Area'])
- self.hideColumn(self.COLUMNS_INDEX['Net Area'])
+ self.hideColumn(self.COLUMNS_INDEX["Raw Area"])
+ self.hideColumn(self.COLUMNS_INDEX["Net Area"])
def fillFromROIDict(self, roilist=(), roidict=None, currentroi=None):
"""
@@ -1073,7 +1060,7 @@ class ROI(_RegionOfInterestBase):
self._fromdata = fromdata
self._todata = todata
- self._type = type_ or 'Default'
+ self._type = type_ or "Default"
self.sigItemChanged.connect(self.__itemChanged)
@@ -1150,27 +1137,27 @@ class ROI(_RegionOfInterestBase):
:return: dict containing the roi parameters
"""
ddict = {
- 'type': self._type,
- 'name': self.getName(),
- 'from': self._fromdata,
- 'to': self._todata,
+ "type": self._type,
+ "name": self.getName(),
+ "from": self._fromdata,
+ "to": self._todata,
}
- if hasattr(self, '_extraInfo'):
+ if hasattr(self, "_extraInfo"):
ddict.update(self._extraInfo)
return ddict
@staticmethod
def _fromDict(dic):
- assert 'name' in dic
- roi = ROI(name=dic['name'])
+ assert "name" in dic
+ roi = ROI(name=dic["name"])
roi._extraInfo = {}
for key in dic:
- if key == 'from':
- roi.setFrom(dic['from'])
- elif key == 'to':
- roi.setTo(dic['to'])
- elif key == 'type':
- roi.setType(dic['type'])
+ if key == "from":
+ roi.setFrom(dic["from"])
+ elif key == "to":
+ roi.setTo(dic["to"])
+ elif key == "type":
+ roi.setType(dic["type"])
else:
roi._extraInfo[key] = dic[key]
@@ -1181,7 +1168,7 @@ class ROI(_RegionOfInterestBase):
:return: True if the ROI is the `ICR`
"""
- return self.getName() == 'ICR'
+ return self.getName() == "ICR"
def computeRawAndNetCounts(self, curve):
"""Compute the Raw and net counts in the ROI for the given curve.
@@ -1206,8 +1193,7 @@ class ROI(_RegionOfInterestBase):
x = curve.getXData(copy=False)
y = curve.getYData(copy=False)
- idx = numpy.nonzero((self._fromdata <= x) &
- (x <= self._todata))[0]
+ idx = numpy.nonzero((self._fromdata <= x) & (x <= self._todata))[0]
if len(idx):
xw = x[idx]
yw = y[idx]
@@ -1215,10 +1201,9 @@ class ROI(_RegionOfInterestBase):
deltaX = xw[-1] - xw[0]
deltaY = yw[-1] - yw[0]
if deltaX > 0.0:
- slope = (deltaY / deltaX)
+ slope = deltaY / deltaX
background = yw[0] + slope * (xw - xw[0])
- netCounts = (rawCounts -
- background.sum(dtype=numpy.float64))
+ netCounts = rawCounts - background.sum(dtype=numpy.float64)
else:
netCounts = 0.0
else:
@@ -1274,6 +1259,7 @@ class _RoiMarkerManager(object):
"""
Deal with all the ROI markers
"""
+
def __init__(self):
self._roiMarkerHandlers = {}
self._middleROIMarkerFlag = False
@@ -1293,7 +1279,7 @@ class _RoiMarkerManager(object):
assert isinstance(roi, ROI)
assert isinstance(markersHandler, _RoiMarkerHandler)
if roi.getID() in self._roiMarkerHandlers:
- raise ValueError('roi with the same ID already existing')
+ raise ValueError("roi with the same ID already existing")
else:
self._roiMarkerHandlers[roi.getID()] = markersHandler
@@ -1323,25 +1309,30 @@ class _RoiMarkerManager(object):
def changePosition(self, markerID, x):
markerHandler = self.getMarker(markerID)
if markerHandler is None:
- raise ValueError('Marker %s not register' % markerID)
+ raise ValueError("Marker %s not register" % markerID)
markerHandler.changePosition(markerID=markerID, x=x)
def updateMarker(self, markerID):
markerHandler = self.getMarker(markerID)
if markerHandler is None:
- raise ValueError('Marker %s not register' % markerID)
+ raise ValueError("Marker %s not register" % markerID)
roiID = self.getRoiID(markerID)
- visible = (self._activeRoi and self._activeRoi.getID() == roiID) or self._showAllMarkers is True
+ visible = (
+ self._activeRoi and self._activeRoi.getID() == roiID
+ ) or self._showAllMarkers is True
markerHandler.setVisible(visible)
markerHandler.updateAllMarkers()
def updateRoiMarkers(self, roiID):
if roiID in self._roiMarkerHandlers:
- visible = ((self._activeRoi and self._activeRoi.getID() == roiID)
- or self._showAllMarkers is True)
+ visible = (
+ self._activeRoi and self._activeRoi.getID() == roiID
+ ) or self._showAllMarkers is True
_roi = self._roiMarkerHandlers[roiID]._roi()
if _roi and not _roi.isICR():
- self._roiMarkerHandlers[roiID].showMiddleMarker(self._middleROIMarkerFlag)
+ self._roiMarkerHandlers[roiID].showMiddleMarker(
+ self._middleROIMarkerFlag
+ )
self._roiMarkerHandlers[roiID].setVisible(visible)
self._roiMarkerHandlers[roiID].updateMarkers()
@@ -1372,8 +1363,11 @@ class _RoiMarkerManager(object):
def getVisibleRois(self):
res = {}
for roiID, roiHandler in self._roiMarkerHandlers.items():
- markers = (roiHandler.getMarker('min'), roiHandler.getMarker('max'),
- roiHandler.getMarker('middle'))
+ markers = (
+ roiHandler.getMarker("min"),
+ roiHandler.getMarker("max"),
+ roiHandler.getMarker("middle"),
+ )
for marker in markers:
if marker.isVisible():
if roiID not in res:
@@ -1384,6 +1378,7 @@ class _RoiMarkerManager(object):
class _RoiMarkerHandler(object):
"""Used to deal with ROI markers used in ROITable"""
+
def __init__(self, roi, plot):
assert roi and isinstance(roi, ROI)
assert plot
@@ -1391,7 +1386,7 @@ class _RoiMarkerHandler(object):
self._roi = weakref.ref(roi)
self._plot = weakref.ref(plot)
self._draggable = False if roi.isICR() else True
- self._color = 'black' if roi.isICR() else 'blue'
+ self._color = "black" if roi.isICR() else "blue"
self._displayMidMarker = False
self._visible = True
@@ -1405,9 +1400,9 @@ class _RoiMarkerHandler(object):
def clear(self):
if self.plot and self.roi:
- self.plot.removeMarker(self._markerID('min'))
- self.plot.removeMarker(self._markerID('max'))
- self.plot.removeMarker(self._markerID('middle'))
+ self.plot.removeMarker(self._markerID("min"))
+ self.plot.removeMarker(self._markerID("max"))
+ self.plot.removeMarker(self._markerID("middle"))
@property
def roi(self):
@@ -1423,7 +1418,7 @@ class _RoiMarkerHandler(object):
_logger.warning("ROI is not draggable. Won't display middle marker")
return
self._displayMidMarker = visible
- self.getMarker('middle').setVisible(self._displayMidMarker)
+ self.getMarker("middle").setVisible(self._displayMidMarker)
def updateMarkers(self):
if self.roi is None:
@@ -1433,54 +1428,56 @@ class _RoiMarkerHandler(object):
self._updateMiddleMarkerPos()
def _updateMinMarkerPos(self):
- self.getMarker('min').setPosition(x=self.roi.getFrom(), y=None)
- self.getMarker('min').setVisible(self._visible)
+ self.getMarker("min").setPosition(x=self.roi.getFrom(), y=None)
+ self.getMarker("min").setVisible(self._visible)
def _updateMaxMarkerPos(self):
- self.getMarker('max').setPosition(x=self.roi.getTo(), y=None)
- self.getMarker('max').setVisible(self._visible)
+ self.getMarker("max").setPosition(x=self.roi.getTo(), y=None)
+ self.getMarker("max").setVisible(self._visible)
def _updateMiddleMarkerPos(self):
- self.getMarker('middle').setPosition(x=self.roi.getMiddle(), y=None)
- self.getMarker('middle').setVisible(self._displayMidMarker and self._visible)
+ self.getMarker("middle").setPosition(x=self.roi.getMiddle(), y=None)
+ self.getMarker("middle").setVisible(self._displayMidMarker and self._visible)
def getMarker(self, markerType):
if self.plot is None:
return None
- assert markerType in ('min', 'max', 'middle')
+ assert markerType in ("min", "max", "middle")
if self.plot._getMarker(self._markerID(markerType)) is None:
assert self.roi
- if markerType == 'min':
+ if markerType == "min":
val = self.roi.getFrom()
- elif markerType == 'max':
+ elif markerType == "max":
val = self.roi.getTo()
else:
val = self.roi.getMiddle()
_color = self._color
- if markerType == 'middle':
- _color = 'yellow'
- self.plot.addXMarker(val,
- legend=self._markerID(markerType),
- text=self.getMarkerName(markerType),
- color=_color,
- draggable=self.draggable)
+ if markerType == "middle":
+ _color = "yellow"
+ self.plot.addXMarker(
+ val,
+ legend=self._markerID(markerType),
+ text=self.getMarkerName(markerType),
+ color=_color,
+ draggable=self.draggable,
+ )
return self.plot._getMarker(self._markerID(markerType))
def _markerID(self, markerType):
- assert markerType in ('min', 'max', 'middle')
+ assert markerType in ("min", "max", "middle")
assert self.roi
- return '_'.join((str(self.roi.getID()), markerType))
+ return "_".join((str(self.roi.getID()), markerType))
def getMarkerName(self, markerType):
- assert markerType in ('min', 'max', 'middle')
+ assert markerType in ("min", "max", "middle")
assert self.roi
- return ' '.join((self.roi.getName(), markerType))
+ return " ".join((self.roi.getName(), markerType))
def updateTexts(self):
- self.getMarker('min').setText(self.getMarkerName('min'))
- self.getMarker('max').setText(self.getMarkerName('max'))
- self.getMarker('middle').setText(self.getMarkerName('middle'))
+ self.getMarker("min").setText(self.getMarkerName("min"))
+ self.getMarker("max").setText(self.getMarkerName("max"))
+ self.getMarker("middle").setText(self.getMarkerName("middle"))
def changePosition(self, markerID, x):
assert self.hasMarker(markerID)
@@ -1488,10 +1485,10 @@ class _RoiMarkerHandler(object):
assert markerType is not None
if self.roi is None:
return
- if markerType == 'min':
+ if markerType == "min":
self.roi.setFrom(x)
self._updateMiddleMarkerPos()
- elif markerType == 'max':
+ elif markerType == "max":
self.roi.setTo(x)
self._updateMiddleMarkerPos()
else:
@@ -1502,17 +1499,19 @@ class _RoiMarkerHandler(object):
self._updateMaxMarkerPos()
def hasMarker(self, marker):
- return marker in (self._markerID('min'),
- self._markerID('max'),
- self._markerID('middle'))
+ return marker in (
+ self._markerID("min"),
+ self._markerID("max"),
+ self._markerID("middle"),
+ )
def _getMarkerType(self, markerID):
- if markerID.endswith('_min'):
- return 'min'
- elif markerID.endswith('_max'):
- return 'max'
- elif markerID.endswith('_middle'):
- return 'middle'
+ if markerID.endswith("_min"):
+ return "min"
+ elif markerID.endswith("_max"):
+ return "max"
+ elif markerID.endswith("_middle"):
+ return "middle"
else:
return None
@@ -1526,6 +1525,7 @@ class CurvesROIDockWidget(qt.QDockWidget):
:param plot: :class:`.PlotWindow` instance on which to operate
:param name: See :class:`QDockWidget`
"""
+
sigROISignal = qt.Signal(object)
"""Deprecated signal for backward compatibility with silx < 0.7.
Prefer connecting directly to :attr:`CurvesRoiWidget.sigRoiSignal`
@@ -1564,7 +1564,7 @@ class CurvesROIDockWidget(qt.QDockWidget):
See :class:`QMainWindow`.
"""
action = super(CurvesROIDockWidget, self).toggleViewAction()
- action.setIcon(icons.getQIcon('plot-roi'))
+ action.setIcon(icons.getQIcon("plot-roi"))
return action
@property
diff --git a/src/silx/gui/plot/ImageStack.py b/src/silx/gui/plot/ImageStack.py
index e2bed9d..175d6e4 100644
--- a/src/silx/gui/plot/ImageStack.py
+++ b/src/silx/gui/plot/ImageStack.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2020-2021 European Synchrotron Radiation Facility
+# Copyright (c) 2020-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -23,118 +23,35 @@
# ###########################################################################*/
"""Image stack view with data prefetch capabilty."""
+from __future__ import annotations
+
__authors__ = ["H. Payno"]
__license__ = "MIT"
__date__ = "04/03/2019"
-from silx.gui import icons, qt
+from silx.gui import qt
from silx.gui.plot import Plot2D
-from silx.gui.utils import concurrent
from silx.io.url import DataUrl
from silx.io.utils import get_data
-from collections import OrderedDict
from silx.gui.widgets.FrameBrowser import HorizontalSliderWithBrowser
-import time
-import threading
+from silx.gui.widgets.UrlList import UrlList
+from silx.gui.utils import blockSignals
+from silx.utils.deprecation import deprecated
+
import typing
import logging
+from silx.gui.widgets.WaitingOverlay import WaitingOverlay
+from collections.abc import Iterable
_logger = logging.getLogger(__name__)
-class _PlotWithWaitingLabel(qt.QWidget):
- """Image plot widget with an overlay 'waiting' status.
- """
-
- class AnimationThread(threading.Thread):
- def __init__(self, label):
- self.running = True
- self._label = label
- self.animated_icon = icons.getWaitIcon()
- self.animated_icon.register(self._label)
- super(_PlotWithWaitingLabel.AnimationThread, self).__init__()
-
- def run(self):
- while self.running:
- time.sleep(0.05)
- icon = self.animated_icon.currentIcon()
- self.future_result = concurrent.submitToQtMainThread(
- self._label.setPixmap, icon.pixmap(30, state=qt.QIcon.On))
-
- def stop(self):
- """Stop the update thread"""
- if self.running:
- self.animated_icon.unregister(self._label)
- self.running = False
- self.join(2)
-
- def __init__(self, parent):
- super(_PlotWithWaitingLabel, self).__init__(parent=parent)
- self._autoResetZoom = True
- layout = qt.QStackedLayout(self)
- layout.setStackingMode(qt.QStackedLayout.StackAll)
-
- self._waiting_label = qt.QLabel(parent=self)
- self._waiting_label.setAlignment(qt.Qt.AlignHCenter | qt.Qt.AlignVCenter)
- layout.addWidget(self._waiting_label)
-
- self._plot = Plot2D(parent=self)
- layout.addWidget(self._plot)
-
- self.updateThread = _PlotWithWaitingLabel.AnimationThread(self._waiting_label)
- self.updateThread.start()
-
- def close(self) -> bool:
- super(_PlotWithWaitingLabel, self).close()
- self.stopUpdateThread()
-
- def stopUpdateThread(self):
- self.updateThread.stop()
-
- def setAutoResetZoom(self, reset):
- """
- Should we reset the zoom when adding an image (eq. when browsing)
-
- :param bool reset:
- """
- self._autoResetZoom = reset
- if self._autoResetZoom:
- self._plot.resetZoom()
-
- def isAutoResetZoom(self):
- """
-
- :return: True if a reset is done when the image change
- :rtype: bool
- """
- return self._autoResetZoom
-
- def setWaiting(self, activate=True):
- if activate is True:
- self._plot.clear()
- self._waiting_label.show()
- else:
- self._waiting_label.hide()
-
- def setData(self, data):
- self.setWaiting(activate=False)
- self._plot.addImage(data=data, resetzoom=self._autoResetZoom)
-
- def clear(self):
- self._plot.clear()
- self.setWaiting(False)
-
- def getPlotWidget(self):
- return self._plot
-
-
class _HorizontalSlider(HorizontalSliderWithBrowser):
-
sigCurrentUrlIndexChanged = qt.Signal(int)
def __init__(self, parent):
- super(_HorizontalSlider, self).__init__(parent=parent)
+ super().__init__(parent=parent)
# connect signal / slot
self.valueChanged.connect(self._urlChanged)
@@ -146,67 +63,23 @@ class _HorizontalSlider(HorizontalSliderWithBrowser):
self.sigCurrentUrlIndexChanged.emit(value)
-class UrlList(qt.QWidget):
- """List of URLs the user to select an URL"""
-
- sigCurrentUrlChanged = qt.Signal(str)
- """Signal emitted when the active/current url change"""
-
- def __init__(self, parent=None):
- super(UrlList, self).__init__(parent)
- self.setLayout(qt.QVBoxLayout())
- self.layout().setSpacing(0)
- self.layout().setContentsMargins(0, 0, 0, 0)
- self._listWidget = qt.QListWidget(parent=self)
- self.layout().addWidget(self._listWidget)
-
- # connect signal / Slot
- self._listWidget.currentItemChanged.connect(self._notifyCurrentUrlChanged)
-
- # expose API
- self.currentItem = self._listWidget.currentItem
-
- def setUrls(self, urls: list) -> None:
- url_names = []
- [url_names.append(url.path()) for url in urls]
- self._listWidget.addItems(url_names)
-
- def _notifyCurrentUrlChanged(self, current, previous):
- if current is None:
- pass
- else:
- self.sigCurrentUrlChanged.emit(current.text())
-
- def setUrl(self, url: DataUrl) -> None:
- assert isinstance(url, DataUrl)
- sel_items = self._listWidget.findItems(url.path(), qt.Qt.MatchExactly)
- if sel_items is None:
- _logger.warning(url.path(), ' is not registered in the list.')
- elif len(sel_items) > 0:
- item = sel_items[0]
- self._listWidget.setCurrentItem(item)
- self.sigCurrentUrlChanged.emit(item.text())
-
- def clear(self):
- self._listWidget.clear()
-
-
class _ToggleableUrlSelectionTable(qt.QWidget):
-
_BUTTON_ICON = qt.QStyle.SP_ToolBarHorizontalExtensionButton # noqa
sigCurrentUrlChanged = qt.Signal(str)
"""Signal emitted when the active/current url change"""
+ sigUrlRemoved = qt.Signal(str)
+
def __init__(self, parent=None) -> None:
- qt.QWidget.__init__(self, parent)
+ super().__init__(parent)
self.setLayout(qt.QGridLayout())
self._toggleButton = qt.QPushButton(parent=self)
self.layout().addWidget(self._toggleButton, 0, 2, 1, 1)
- self._toggleButton.setSizePolicy(qt.QSizePolicy.Fixed,
- qt.QSizePolicy.Fixed)
+ self._toggleButton.setSizePolicy(qt.QSizePolicy.Fixed, qt.QSizePolicy.Fixed)
self._urlsTable = UrlList(parent=self)
+
self.layout().addWidget(self._urlsTable, 1, 1, 1, 2)
# set up
@@ -214,12 +87,8 @@ class _ToggleableUrlSelectionTable(qt.QWidget):
# Signal / slot connection
self._toggleButton.clicked.connect(self.toggleUrlSelectionTable)
- self._urlsTable.sigCurrentUrlChanged.connect(self._propagateSignal)
-
- # expose API
- self.setUrls = self._urlsTable.setUrls
- self.setUrl = self._urlsTable.setUrl
- self.currentItem = self._urlsTable.currentItem
+ self._urlsTable.sigCurrentUrlChanged.connect(self.sigCurrentUrlChanged)
+ self._urlsTable.sigUrlRemoved.connect(self.sigUrlRemoved)
def toggleUrlSelectionTable(self):
visible = not self.urlSelectionTableIsVisible()
@@ -236,21 +105,36 @@ class _ToggleableUrlSelectionTable(qt.QWidget):
self._toggleButton.setIcon(icon)
def urlSelectionTableIsVisible(self):
- return self._urlsTable.isVisible()
-
- def _propagateSignal(self, url):
- self.sigCurrentUrlChanged.emit(url)
+ return self._urlsTable.isVisibleTo(self)
def clear(self):
self._urlsTable.clear()
+ # expose UrlList API
+ @deprecated(replacement="addUrls", since_version="2.0")
+ def setUrls(self, urls: Iterable[DataUrl]):
+ self._urlsTable.addUrls(urls=urls)
+
+ def addUrls(self, urls: Iterable[DataUrl]):
+ self._urlsTable.addUrls(urls=urls)
+
+ def setUrl(self, url: typing.Optional[DataUrl]):
+ self._urlsTable.setUrl(url=url)
+
+ def removeUrl(self, url: str):
+ self._urlsTable.removeUrl(url)
+
+ def currentItem(self):
+ return self._urlsTable.currentItem()
+
class UrlLoader(qt.QThread):
"""
Thread use to load DataUrl
"""
+
def __init__(self, parent, url):
- super(UrlLoader, self).__init__(parent=parent)
+ super().__init__(parent=parent)
assert isinstance(url, DataUrl)
self.url = url
self.data = None
@@ -277,17 +161,21 @@ class ImageStack(qt.QMainWindow):
"""Signal emitted when the current url change"""
def __init__(self, parent=None) -> None:
- super(ImageStack, self).__init__(parent)
+ super().__init__(parent)
self.__n_prefetch = ImageStack.N_PRELOAD
self._loadingThreads = []
self.setWindowFlags(qt.Qt.Widget)
self._current_url = None
self._url_loader = UrlLoader
"class to instantiate for loading urls"
+ self._autoResetZoom = True
# main widget
- self._plot = _PlotWithWaitingLabel(parent=self)
+ self._plot = Plot2D(parent=self)
self._plot.setAttribute(qt.Qt.WA_DeleteOnClose, True)
+ self._waitingOverlay = WaitingOverlay(self._plot)
+ self._waitingOverlay.setIconSize(qt.QSize(30, 30))
+ self._waitingOverlay.hide()
self.setWindowTitle("Image stack")
self.setCentralWidget(self._plot)
@@ -308,12 +196,14 @@ class ImageStack(qt.QMainWindow):
# connect signal / slot
self._urlsTable.sigCurrentUrlChanged.connect(self.setCurrentUrl)
+ self._urlsTable.sigUrlRemoved.connect(self.removeUrl)
self._slider.sigCurrentUrlIndexChanged.connect(self.setCurrentUrlIndex)
def close(self) -> bool:
self._freeLoadingThreads()
+ self._waitingOverlay.close()
self._plot.close()
- super(ImageStack, self).close()
+ super().close()
def setUrlLoaderClass(self, urlLoader: typing.Type[UrlLoader]) -> None:
"""
@@ -346,14 +236,14 @@ class ImageStack(qt.QMainWindow):
:return: PlotWidget contained in this window
:rtype: Plot2D
"""
- return self._plot.getPlotWidget()
+ return self._plot
def reset(self) -> None:
"""Clear the plot and remove any link to url"""
self._freeLoadingThreads()
self._urls = None
self._urlIndexes = None
- self._urlData = OrderedDict({})
+ self._urlData = {}
self._current_url = None
self._plot.clear()
self._urlsTable.clear()
@@ -396,7 +286,8 @@ class ImageStack(qt.QMainWindow):
if url in self._urlIndexes:
self._urlData[url] = sender.data
if self.getCurrentUrl().path() == url:
- self._plot.setData(self._urlData[url])
+ self._waitingOverlay.setVisible(False)
+ self._plot.addImage(self._urlData[url], resetzoom=self._autoResetZoom)
if sender in self._loadingThreads:
self._loadingThreads.remove(sender)
self.sigLoaded.emit(url)
@@ -421,6 +312,29 @@ class ImageStack(qt.QMainWindow):
"""
return self.__n_prefetch
+ def setUrlsEditable(self, editable: bool):
+ self._urlsTable._urlsTable.setEditable(editable)
+ if editable:
+ selection_mode = qt.QAbstractItemView.ExtendedSelection
+ else:
+ selection_mode = qt.QAbstractItemView.SingleSelection
+ self._urlsTable._urlsTable.setSelectionMode(selection_mode)
+
+ @staticmethod
+ def createUrlIndexes(urls: tuple):
+ indexes = {}
+ for index, url in enumerate(urls):
+ assert isinstance(
+ url, DataUrl
+ ), f"url is expected to be a DataUrl. Get {type(url)}"
+ indexes[index] = url
+ return indexes
+
+ def _resetSlider(self):
+ with blockSignals(self._slider):
+ self._slider.setMinimum(0)
+ self._slider.setMaximum(len(self._urls) - 1)
+
def setUrls(self, urls: list) -> None:
"""list of urls within an index. Warning: urls should contain an image
compatible with the silx.gui.plot.Plot class
@@ -429,26 +343,16 @@ class ImageStack(qt.QMainWindow):
(position in the stack), value is the DataUrl
:type: list
"""
- def createUrlIndexes():
- indexes = OrderedDict()
- for index, url in enumerate(urls):
- indexes[index] = url
- return indexes
-
- urls_with_indexes = createUrlIndexes()
+ urls_with_indexes = self.createUrlIndexes(urls=urls)
urlsToIndex = self._urlsToIndex(urls_with_indexes)
self.reset()
self._urls = urls_with_indexes
self._urlIndexes = urlsToIndex
- old_url_table = self._urlsTable.blockSignals(True)
- self._urlsTable.setUrls(urls=list(self._urls.values()))
- self._urlsTable.blockSignals(old_url_table)
+ with blockSignals(self._urlsTable):
+ self._urlsTable.addUrls(urls=list(self._urls.values()))
- old_slider = self._slider.blockSignals(True)
- self._slider.setMinimum(0)
- self._slider.setMaximum(len(self._urls) - 1)
- self._slider.blockSignals(old_slider)
+ self._resetSlider()
if self.getCurrentUrl() in self._urls:
self.setCurrentUrl(self.getCurrentUrl())
@@ -457,6 +361,35 @@ class ImageStack(qt.QMainWindow):
first_url = self._urls[list(self._urls.keys())[0]]
self.setCurrentUrl(first_url)
+ def removeUrl(self, url: str) -> None:
+ """
+ Remove provided URL from the table
+
+ :param url: URL as str
+ """
+ # remove the given urls from self._urls and self._urlIndexes
+ if not isinstance(url, str):
+ raise TypeError("url is expected to be the str representation of the url")
+
+ # try to get reset the url displayed
+ current_url = self.getCurrentUrl()
+ with blockSignals(self._urlsTable):
+ self._urlsTable.removeUrl(url)
+ # update urls
+ urls_with_indexes = self.createUrlIndexes(
+ filter(
+ lambda a: a.path() != url,
+ self._urls.values(),
+ )
+ )
+ urlsToIndex = self._urlsToIndex(urls_with_indexes)
+ self._urls = urls_with_indexes
+ self._urlIndexes = urlsToIndex
+ self._resetSlider()
+
+ if current_url != url:
+ self.setCurrentUrl(current_url)
+
def getUrls(self) -> tuple:
"""
@@ -555,41 +488,46 @@ class ImageStack(qt.QMainWindow):
if self._urls is None:
return
elif index >= len(self._urls):
- raise ValueError('requested index out of bounds')
+ raise ValueError("requested index out of bounds")
else:
return self.setCurrentUrl(self._urls[index])
- def setCurrentUrl(self, url: typing.Union[DataUrl, str]) -> None:
+ def setCurrentUrl(self, url: typing.Optional[typing.Union[DataUrl, str]]) -> None:
"""
Define the url to be displayed
:param url: url to be displayed
:type: DataUrl
+ :raises KeyError: raised if the url is not know
"""
- assert isinstance(url, (DataUrl, str))
- if isinstance(url, str):
+ assert isinstance(url, (DataUrl, str, type(None)))
+ if url == "":
+ url = None
+ elif isinstance(url, str):
url = DataUrl(path=url)
- if url != self._current_url:
+ if url is not None and url != self._current_url:
self._current_url = url
self.sigCurrentUrlChanged.emit(url.path())
- old_url_table = self._urlsTable.blockSignals(True)
- old_slider = self._slider.blockSignals(True)
-
- self._urlsTable.setUrl(url)
- self._slider.setUrlIndex(self._urlIndexes[url.path()])
- if self._current_url is None:
- self._plot.clear()
- else:
- if self._current_url.path() in self._urlData:
- self._plot.setData(self._urlData[url.path()])
- else:
- self._load(url)
- self._notifyLoading()
- self._preFetch(self._getNNextUrls(self.__n_prefetch, url))
- self._preFetch(self._getNPreviousUrls(self.__n_prefetch, url))
- self._urlsTable.blockSignals(old_url_table)
- self._slider.blockSignals(old_slider)
+ with blockSignals(self._urlsTable):
+ with blockSignals(self._slider):
+ self._urlsTable.setUrl(url)
+ if url is not None:
+ self._slider.setUrlIndex(self._urlIndexes[url.path()])
+ if self._current_url is None:
+ self._plot.clear()
+ else:
+ if self._current_url.path() in self._urlData:
+ self._waitingOverlay.setVisible(False)
+ self._plot.addImage(
+ self._urlData[url.path()], resetzoom=self._autoResetZoom
+ )
+ else:
+ self._plot.clear()
+ self._load(url)
+ self._waitingOverlay.setVisible(True)
+ self._preFetch(self._getNNextUrls(self.__n_prefetch, url))
+ self._preFetch(self._getNPreviousUrls(self.__n_prefetch, url))
def getCurrentUrl(self) -> typing.Union[None, DataUrl]:
"""
@@ -618,17 +556,15 @@ class ImageStack(qt.QMainWindow):
res[url.path()] = index
return res
- def _notifyLoading(self):
- """display a simple image of loading..."""
- self._plot.setWaiting(activate=True)
-
def setAutoResetZoom(self, reset):
"""
Should we reset the zoom when adding an image (eq. when browsing)
:param bool reset:
"""
- self._plot.setAutoResetZoom(reset)
+ self._autoResetZoom = reset
+ if self._autoResetZoom:
+ self._plot.resetZoom()
def isAutoResetZoom(self) -> bool:
"""
@@ -636,4 +572,12 @@ class ImageStack(qt.QMainWindow):
:return: True if a reset is done when the image change
:rtype: bool
"""
- return self._plot.isAutoResetZoom()
+ return self._autoResetZoom
+
+ def getWaiterOverlay(self):
+ """
+
+ :return: Return the instance of `WaitingOverlay` used to display if processing or not
+ :rtype: WaitingOverlay
+ """
+ return self._waitingOverlay
diff --git a/src/silx/gui/plot/ImageView.py b/src/silx/gui/plot/ImageView.py
index a451b2d..eaca42b 100644
--- a/src/silx/gui/plot/ImageView.py
+++ b/src/silx/gui/plot/ImageView.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2015-2021 European Synchrotron Radiation Facility
+# Copyright (c) 2015-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -63,19 +63,27 @@ from .tools.RadarView import RadarView
from .utils.axis import SyncAxes
from ..utils import blockSignals
from . import _utils
-from .tools.profile import manager
from .tools.profile import rois
from .actions import PlotAction
_logger = logging.getLogger(__name__)
-ProfileSumResult = collections.namedtuple("ProfileResult",
- ["dataXRange", "dataYRange",
- 'histoH', 'histoHRange',
- 'histoV', 'histoVRange',
- "xCoords", "xData",
- "yCoords", "yData"])
+ProfileSumResult = collections.namedtuple(
+ "ProfileResult",
+ [
+ "dataXRange",
+ "dataYRange",
+ "histoH",
+ "histoHRange",
+ "histoV",
+ "histoVRange",
+ "xCoords",
+ "xData",
+ "yCoords",
+ "yData",
+ ],
+)
def computeProfileSumOnRange(imageItem, xRange, yRange, cache=None):
@@ -103,8 +111,7 @@ def computeProfileSumOnRange(imageItem, xRange, yRange, cache=None):
yMin = int((yMin - origin[1]) / scale[1])
yMax = int((yMax - origin[1]) / scale[1])
- if (xMin >= width or xMax < 0 or
- yMin >= height or yMax < 0):
+ if xMin >= width or xMax < 0 or yMin >= height or yMax < 0:
return None
# The image is at least partly in the plot area
@@ -115,14 +122,15 @@ def computeProfileSumOnRange(imageItem, xRange, yRange, cache=None):
subsetYMax = (height if yMax >= height else yMax) + 1
if cache is not None:
- if ((subsetXMin, subsetXMax) == cache.dataXRange and
- (subsetYMin, subsetYMax) == cache.dataYRange):
+ if (subsetXMin, subsetXMax) == cache.dataXRange and (
+ subsetYMin,
+ subsetYMax,
+ ) == cache.dataYRange:
# The visible area of data is the same
return cache
# Rebuild histograms for visible area
- visibleData = data[subsetYMin:subsetYMax,
- subsetXMin:subsetXMax]
+ visibleData = data[subsetYMin:subsetYMax, subsetXMin:subsetXMax]
histoHVisibleData = numpy.nansum(visibleData, axis=0)
histoVVisibleData = numpy.nansum(visibleData, axis=1)
histoHMin = numpy.nanmin(histoHVisibleData)
@@ -151,7 +159,8 @@ def computeProfileSumOnRange(imageItem, xRange, yRange, cache=None):
xCoords=xCoords,
xData=xData,
yCoords=yCoords,
- yData=yData)
+ yData=yData,
+ )
return result
@@ -177,8 +186,8 @@ class _SideHistogram(PlotWidget):
def _plotEvents(self, eventDict):
"""Callback for horizontal histogram plot events."""
- if eventDict['event'] == 'mouseMoved':
- self.sigMouseMoved.emit(eventDict['x'], eventDict['y'])
+ if eventDict["event"] == "mouseMoved":
+ self.sigMouseMoved.emit(eventDict["x"], eventDict["y"])
def setProfileColor(self, color):
self._color = color
@@ -218,13 +227,13 @@ class _SideHistogram(PlotWidget):
profileSum = self.__profileSum
try:
- self.removeCurve('profile')
+ self.removeCurve("profile")
except Exception:
pass
if profileSum is None:
try:
- self.removeCurve('profilesum')
+ self.removeCurve("profilesum")
except Exception:
pass
return
@@ -236,13 +245,17 @@ class _SideHistogram(PlotWidget):
else:
assert False
- self.addCurve(xx, yy,
- xlabel='', ylabel='',
- legend="profilesum",
- color=self._color,
- linestyle='-',
- selectable=False,
- resetzoom=False)
+ self.addCurve(
+ xx,
+ yy,
+ xlabel="",
+ ylabel="",
+ legend="profilesum",
+ color=self._color,
+ linestyle="-",
+ selectable=False,
+ resetzoom=False,
+ )
self.__updateLimits()
@@ -254,13 +267,13 @@ class _SideHistogram(PlotWidget):
profile = self.__profile
try:
- self.removeCurve('profilesum')
+ self.removeCurve("profilesum")
except Exception:
pass
if profile is None:
try:
- self.removeCurve('profile')
+ self.removeCurve("profile")
except Exception:
pass
self.setProfileSum(self.__profileSum)
@@ -273,11 +286,7 @@ class _SideHistogram(PlotWidget):
else:
assert False
- self.addCurve(xx,
- yy,
- legend="profile",
- color=self._roiColor,
- resetzoom=False)
+ self.addCurve(xx, yy, legend="profile", color=self._roiColor, resetzoom=False)
self.__updateLimits()
@@ -299,9 +308,13 @@ class _SideHistogram(PlotWidget):
# Tune the result using the data margins
margins = self.getDataMargins()
if self._direction == qt.Qt.Horizontal:
- _, _, vMin, vMax = _utils.addMarginsToLimits(margins, False, False, 0, 0, vMin, vMax)
+ _, _, vMin, vMax = _utils.addMarginsToLimits(
+ margins, False, False, 0, 0, vMin, vMax
+ )
elif self._direction == qt.Qt.Vertical:
- vMin, vMax, _, _ = _utils.addMarginsToLimits(margins, False, False, vMin, vMax, 0, 0)
+ vMin, vMax, _, _ = _utils.addMarginsToLimits(
+ margins, False, False, vMin, vMax, 0, 0
+ )
else:
assert False
@@ -325,10 +338,14 @@ class ShowSideHistogramsAction(PlotAction):
def __init__(self, plot, parent=None):
super(ShowSideHistogramsAction, self).__init__(
- plot, icon='side-histograms', text='Show/hide side histograms',
- tooltip='Show/hide side histogram',
+ plot,
+ icon="side-histograms",
+ text="Show/hide side histograms",
+ tooltip="Show/hide side histogram",
triggered=self._actionTriggered,
- checkable=True, parent=parent)
+ checkable=True,
+ parent=parent,
+ )
def _actionTriggered(self, checked=False):
if self.plot.isSideHistogramDisplayed() != checked:
@@ -349,25 +366,33 @@ class AggregationModeAction(qt.QWidgetAction):
filterAction.setText("No filter")
filterAction.setCheckable(True)
filterAction.setChecked(True)
- filterAction.setProperty("aggregation", items.ImageDataAggregated.Aggregation.NONE)
+ filterAction.setProperty(
+ "aggregation", items.ImageDataAggregated.Aggregation.NONE
+ )
densityNoFilterAction = filterAction
filterAction = qt.QAction(self)
filterAction.setText("Max filter")
filterAction.setCheckable(True)
- filterAction.setProperty("aggregation", items.ImageDataAggregated.Aggregation.MAX)
+ filterAction.setProperty(
+ "aggregation", items.ImageDataAggregated.Aggregation.MAX
+ )
densityMaxFilterAction = filterAction
filterAction = qt.QAction(self)
filterAction.setText("Mean filter")
filterAction.setCheckable(True)
- filterAction.setProperty("aggregation", items.ImageDataAggregated.Aggregation.MEAN)
+ filterAction.setProperty(
+ "aggregation", items.ImageDataAggregated.Aggregation.MEAN
+ )
densityMeanFilterAction = filterAction
filterAction = qt.QAction(self)
filterAction.setText("Min filter")
filterAction.setCheckable(True)
- filterAction.setProperty("aggregation", items.ImageDataAggregated.Aggregation.MIN)
+ filterAction.setProperty(
+ "aggregation", items.ImageDataAggregated.Aggregation.MIN
+ )
densityMinFilterAction = filterAction
densityGroup = qt.QActionGroup(self)
@@ -428,7 +453,7 @@ class ImageView(PlotWindow):
:type backend: str or :class:`BackendBase.BackendBase`
"""
- HISTOGRAMS_COLOR = 'blue'
+ HISTOGRAMS_COLOR = "blue"
"""Color to use for the side histograms."""
HISTOGRAMS_HEIGHT = 200
@@ -452,26 +477,37 @@ class ImageView(PlotWindow):
class ProfileWindowBehavior(Enum):
"""ImageView's profile window behavior options"""
- POPUP = 'popup'
+ POPUP = "popup"
"""All profiles are displayed in pop-up windows"""
- EMBEDDED = 'embedded'
+ EMBEDDED = "embedded"
"""Horizontal, vertical and cross profiles are displayed in
sides widgets, others are displayed in pop-up windows.
"""
def __init__(self, parent=None, backend=None):
- self._imageLegend = '__ImageView__image' + str(id(self))
+ self._imageLegend = "__ImageView__image" + str(id(self))
self._cache = None # Store currently visible data information
- super(ImageView, self).__init__(parent=parent, backend=backend,
- resetzoom=True, autoScale=False,
- logScale=False, grid=False,
- curveStyle=False, colormap=True,
- aspectRatio=True, yInverted=True,
- copy=True, save=True, print_=True,
- control=False, position=False,
- roi=False, mask=True)
+ super(ImageView, self).__init__(
+ parent=parent,
+ backend=backend,
+ resetzoom=True,
+ autoScale=False,
+ logScale=False,
+ grid=False,
+ curveStyle=False,
+ colormap=True,
+ aspectRatio=True,
+ yInverted=True,
+ copy=True,
+ save=True,
+ print_=True,
+ control=False,
+ position=False,
+ roi=False,
+ mask=True,
+ )
# Enable mask synchronisation to use it in profiles
maskToolsWidget = self.getMaskToolsDockWidget().widget()
@@ -481,12 +517,14 @@ class ImageView(PlotWindow):
self.__showSideHistogramsAction.setChecked(True)
self.__aggregationModeAction = AggregationModeAction(self)
- self.__aggregationModeAction.sigAggregationModeChanged.connect(self._aggregationModeChanged)
+ self.__aggregationModeAction.sigAggregationModeChanged.connect(
+ self._aggregationModeChanged
+ )
if parent is None:
- self.setWindowTitle('ImageView')
+ self.setWindowTitle("ImageView")
- if silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION == 'downward':
+ if silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION == "downward":
self.getYAxis().setInverted(True)
self._initWidgets(backend)
@@ -501,26 +539,32 @@ class ImageView(PlotWindow):
def _initWidgets(self, backend):
"""Set-up layout and plots."""
- self._histoHPlot = _SideHistogram(backend=backend, parent=self, direction=qt.Qt.Horizontal)
+ self._histoHPlot = _SideHistogram(
+ backend=backend, parent=self, direction=qt.Qt.Horizontal
+ )
widgetHandle = self._histoHPlot.getWidgetHandle()
widgetHandle.setMinimumHeight(self.HISTOGRAMS_HEIGHT)
widgetHandle.setMaximumHeight(self.HISTOGRAMS_HEIGHT)
- self._histoHPlot.setInteractiveMode('zoom')
- self._histoHPlot.setDataMargins(0., 0., 0.1, 0.1)
+ self._histoHPlot.setInteractiveMode("zoom")
+ self._histoHPlot.setDataMargins(0.0, 0.0, 0.1, 0.1)
self._histoHPlot.sigMouseMoved.connect(self._mouseMovedOnHistoH)
self._histoHPlot.setProfileColor(self.HISTOGRAMS_COLOR)
- self._histoVPlot = _SideHistogram(backend=backend, parent=self, direction=qt.Qt.Vertical)
+ self._histoVPlot = _SideHistogram(
+ backend=backend, parent=self, direction=qt.Qt.Vertical
+ )
widgetHandle = self._histoVPlot.getWidgetHandle()
widgetHandle.setMinimumWidth(self.HISTOGRAMS_HEIGHT)
widgetHandle.setMaximumWidth(self.HISTOGRAMS_HEIGHT)
- self._histoVPlot.setInteractiveMode('zoom')
- self._histoVPlot.setDataMargins(0.1, 0.1, 0., 0.)
+ # Trick to align the histogram to the main plot
+ self._histoVPlot.setGraphTitle(" ")
+ self._histoVPlot.setInteractiveMode("zoom")
+ self._histoVPlot.setDataMargins(0.1, 0.1, 0.0, 0.0)
self._histoVPlot.sigMouseMoved.connect(self._mouseMovedOnHistoV)
self._histoVPlot.setProfileColor(self.HISTOGRAMS_COLOR)
self.setPanWithArrowKeys(True)
- self.setInteractiveMode('zoom') # Color set in setColormap
+ self.setInteractiveMode("zoom") # Color set in setColormap
self.sigPlotSignal.connect(self._imagePlotCB)
self.sigActiveImageChanged.connect(self._activeImageChangedSlot)
@@ -604,7 +648,7 @@ class ImageView(PlotWindow):
def isSideHistogramDisplayed(self):
"""True if the side histograms are displayed"""
- return self._histoHPlot.isVisible()
+ return self._histoHPlot.isVisibleTo(self)
def _updateHistograms(self):
"""Update histograms content using current active image."""
@@ -625,7 +669,7 @@ class ImageView(PlotWindow):
def _imagePlotCB(self, eventDict):
"""Callback for imageView plot events."""
- if eventDict['event'] == 'mouseMoved':
+ if eventDict["event"] == "mouseMoved":
activeImage = self.getActiveImage()
if activeImage is not None:
data = activeImage.getData(copy=False)
@@ -634,16 +678,14 @@ class ImageView(PlotWindow):
# Get corresponding coordinate in image
origin = activeImage.getOrigin()
scale = activeImage.getScale()
- if (eventDict['x'] >= origin[0] and
- eventDict['y'] >= origin[1]):
- x = int((eventDict['x'] - origin[0]) / scale[0])
- y = int((eventDict['y'] - origin[1]) / scale[1])
+ if eventDict["x"] >= origin[0] and eventDict["y"] >= origin[1]:
+ x = int((eventDict["x"] - origin[0]) / scale[0])
+ y = int((eventDict["y"] - origin[1]) / scale[1])
if x >= 0 and x < width and y >= 0 and y < height:
- self.valueChanged.emit(float(x), float(y),
- data[y][x])
+ self.valueChanged.emit(float(x), float(y), data[y][x])
- elif eventDict['event'] == 'limitsChanged':
+ elif eventDict["event"] == "limitsChanged":
self._updateHistograms()
def _mouseMovedOnHistoH(self, x, y):
@@ -663,9 +705,10 @@ class ImageView(PlotWindow):
column = int((x - minValue) / xScale)
if column >= 0 and column < data.shape[0]:
self.valueChanged.emit(
- float('nan'),
+ float("nan"),
float(column + self._cache.dataXRange[0]),
- data[column])
+ data[column],
+ )
def _mouseMovedOnHistoV(self, x, y):
if self._cache is None:
@@ -684,9 +727,8 @@ class ImageView(PlotWindow):
row = int((y - minValue) / yScale)
if row >= 0 and row < data.shape[0]:
self.valueChanged.emit(
- float(row + self._cache.dataYRange[0]),
- float('nan'),
- data[row])
+ float(row + self._cache.dataYRange[0]), float("nan"), data[row]
+ )
def _activeImageChangedSlot(self, previous, legend):
"""Handle Plot active image change.
@@ -733,7 +775,7 @@ class ImageView(PlotWindow):
return self.__profileWindowBehavior
def getProfileToolBar(self):
- """"Returns profile tools attached to this plot.
+ """Returns profile tools attached to this plot.
:rtype: silx.gui.plot.PlotTools.ProfileToolBar
"""
@@ -757,18 +799,20 @@ class ImageView(PlotWindow):
:return: The histogram and its extent as a dict or None.
:rtype: dict
"""
- assert axis in ('x', 'y')
+ assert axis in ("x", "y")
if self._cache is None:
return None
else:
- if axis == 'x':
+ if axis == "x":
return dict(
data=numpy.array(self._cache.histoH, copy=True),
- extent=self._cache.dataXRange)
+ extent=self._cache.dataXRange,
+ )
else:
return dict(
data=numpy.array(self._cache.histoV, copy=True),
- extent=(self._cache.dataYRange))
+ extent=(self._cache.dataYRange),
+ )
def radarView(self):
"""Get the lower right radarView widget."""
@@ -795,8 +839,15 @@ class ImageView(PlotWindow):
"""
return self.getDefaultColormap()
- def setColormap(self, colormap=None, normalization=None,
- autoscale=None, vmin=None, vmax=None, colors=None):
+ def setColormap(
+ self,
+ colormap=None,
+ normalization=None,
+ autoscale=None,
+ vmin=None,
+ vmax=None,
+ colors=None,
+ ):
"""Set the default colormap and update active image.
Parameters that are not provided are taken from the current colormap.
@@ -868,10 +919,17 @@ class ImageView(PlotWindow):
cmap.setColormapLUT(colors)
cursorColor = cursorColorForColormap(cmap.getName())
- self.setInteractiveMode('zoom', color=cursorColor)
-
- def setImage(self, image, origin=(0, 0), scale=(1., 1.),
- copy=True, reset=None, resetzoom=True):
+ self.setInteractiveMode("zoom", color=cursorColor)
+
+ def setImage(
+ self,
+ image,
+ origin=(0, 0),
+ scale=(1.0, 1.0),
+ copy=True,
+ reset=None,
+ resetzoom=True,
+ ):
"""Set the image to display.
:param image: A 2D array representing the image or None to empty plot.
@@ -901,12 +959,12 @@ class ImageView(PlotWindow):
assert scale[1] > 0
if image is None:
- self.remove(self._imageLegend, kind='image')
+ self.remove(self._imageLegend, kind="image")
return
- data = numpy.array(image, order='C', copy=copy)
+ data = numpy.array(image, order="C", copy=copy)
if data.size == 0:
- self.remove(self._imageLegend, kind='image')
+ self.remove(self._imageLegend, kind="image")
return
assert data.ndim == 2 or (data.ndim == 3 and data.shape[2] in (3, 4))
@@ -917,11 +975,14 @@ class ImageView(PlotWindow):
aggregation = items.ImageDataAggregated.Aggregation.NONE
if aggregation is items.ImageDataAggregated.Aggregation.NONE:
- self.addImage(data,
- legend=self._imageLegend,
- origin=origin, scale=scale,
- colormap=self.getColormap(),
- resetzoom=False)
+ self.addImage(
+ data,
+ legend=self._imageLegend,
+ origin=origin,
+ scale=scale,
+ colormap=self.getColormap(),
+ resetzoom=False,
+ )
else:
item = self._getItem("image", self._imageLegend)
if isinstance(item, items.ImageDataAggregated):
@@ -954,31 +1015,33 @@ class ImageView(PlotWindow):
# ImageViewMainWindow #########################################################
+
class ImageViewMainWindow(ImageView):
""":class:`ImageView` with additional toolbars
Adds extra toolbar and a status bar to :class:`ImageView`.
"""
+
def __init__(self, parent=None, backend=None):
self._dataInfo = None
super(ImageViewMainWindow, self).__init__(parent, backend)
self.setWindowFlags(qt.Qt.Window)
- self.getXAxis().setLabel('X')
- self.getYAxis().setLabel('Y')
- self.setGraphTitle('Image')
+ self.getXAxis().setLabel("X")
+ self.getYAxis().setLabel("Y")
+ self.setGraphTitle("Image")
# Add toolbars and status bar
self.addToolBar(qt.Qt.BottomToolBarArea, LimitsToolBar(plot=self))
- menu = self.menuBar().addMenu('File')
+ menu = self.menuBar().addMenu("File")
menu.addAction(self.getOutputToolBar().getSaveAction())
menu.addAction(self.getOutputToolBar().getPrintAction())
menu.addSeparator()
- action = menu.addAction('Quit')
+ action = menu.addAction("Quit")
action.triggered[bool].connect(qt.QApplication.instance().quit)
- menu = self.menuBar().addMenu('Edit')
+ menu = self.menuBar().addMenu("Edit")
menu.addAction(self.getOutputToolBar().getCopyAction())
menu.addSeparator()
menu.addAction(self.getResetZoomAction())
@@ -987,7 +1050,7 @@ class ImageViewMainWindow(ImageView):
menu.addAction(actions.control.YAxisInvertedAction(self, self))
menu.addAction(self.getShowSideHistogramsAction())
- self.__profileMenu = self.menuBar().addMenu('Profile')
+ self.__profileMenu = self.menuBar().addMenu("Profile")
self.__updateProfileMenu()
# Connect to ImageView's signal
@@ -1007,7 +1070,12 @@ class ImageViewMainWindow(ImageView):
try:
if isinstance(value, numpy.ndarray):
if len(value) == 4:
- return "RGBA: %.3g, %.3g, %.3g, %.3g" % (value[0], value[1], value[2], value[3])
+ return "RGBA: %.3g, %.3g, %.3g, %.3g" % (
+ value[0],
+ value[1],
+ value[2],
+ value[3],
+ )
elif len(value) == 3:
return "RGB: %.3g, %.3g, %.3g" % (value[0], value[1], value[2])
else:
@@ -1020,14 +1088,14 @@ class ImageViewMainWindow(ImageView):
def _statusBarSlot(self, row, column, value):
"""Update status bar with coordinates/value from plots."""
if numpy.isnan(row):
- msg = 'Column: %d, Sum: %g' % (int(column), value)
+ msg = "Column: %d, Sum: %g" % (int(column), value)
elif numpy.isnan(column):
- msg = 'Row: %d, Sum: %g' % (int(row), value)
+ msg = "Row: %d, Sum: %g" % (int(row), value)
else:
msg_value = self._formatValueToString(value)
- msg = 'Position: (%d, %d), %s' % (int(row), int(column), msg_value)
+ msg = "Position: (%d, %d), %s" % (int(row), int(column), msg_value)
if self._dataInfo is not None:
- msg = self._dataInfo + ', ' + msg
+ msg = self._dataInfo + ", " + msg
self.statusBar().showMessage(msg)
@@ -1038,10 +1106,10 @@ class ImageViewMainWindow(ImageView):
@docstring(ImageView)
def setImage(self, image, *args, **kwargs):
- if hasattr(image, 'dtype') and hasattr(image, 'shape'):
+ if hasattr(image, "dtype") and hasattr(image, "shape"):
assert image.ndim == 2 or (image.ndim == 3 and image.shape[2] in (3, 4))
height, width = image.shape[0:2]
- dataInfo = 'Data: %dx%d (%s)' % (width, height, str(image.dtype))
+ dataInfo = "Data: %dx%d (%s)" % (width, height, str(image.dtype))
else:
dataInfo = None
diff --git a/src/silx/gui/plot/Interaction.py b/src/silx/gui/plot/Interaction.py
index 053fbe5..2d8bf63 100644
--- a/src/silx/gui/plot/Interaction.py
+++ b/src/silx/gui/plot/Interaction.py
@@ -84,6 +84,7 @@ import weakref
# state machine ###############################################################
+
class State(object):
"""Base class for the states of a state machine.
@@ -142,6 +143,7 @@ class State(object):
"""
pass
+
class StateMachine(object):
"""State machine controller.
@@ -184,7 +186,7 @@ class StateMachine(object):
:param str eventName: Name of the event to handle
:returns: The return value of the handler or None
"""
- handlerName = 'on' + eventName[0].upper() + eventName[1:]
+ handlerName = "on" + eventName[0].upper() + eventName[1:]
try:
handler = getattr(self.state, handlerName)
except AttributeError:
@@ -204,13 +206,13 @@ class StateMachine(object):
# clickOrDrag #################################################################
-LEFT_BTN = 'left'
+LEFT_BTN = "left"
"""Left mouse button."""
-RIGHT_BTN = 'right'
+RIGHT_BTN = "right"
"""Right mouse button."""
-MIDDLE_BTN = 'middle'
+MIDDLE_BTN = "middle"
"""Middle mouse button."""
@@ -224,15 +226,15 @@ class ClickOrDrag(StateMachine):
:param Set[str] dragButtons: Set of buttons that provides drag interaction
"""
- DRAG_THRESHOLD_SQUARE_DIST = 5 ** 2
+ DRAG_THRESHOLD_SQUARE_DIST = 5**2
class Idle(State):
def onPress(self, x, y, btn):
if btn in self.machine.dragButtons:
- self.goto('clickOrDrag', x, y, btn)
+ self.goto("clickOrDrag", x, y, btn)
return True
elif btn in self.machine.clickButtons:
- self.goto('click', x, y, btn)
+ self.goto("click", x, y, btn)
return True
class Click(State):
@@ -244,12 +246,12 @@ class ClickOrDrag(StateMachine):
dx2 = (x - self.initPos[0]) ** 2
dy2 = (y - self.initPos[1]) ** 2
if (dx2 + dy2) >= self.machine.DRAG_THRESHOLD_SQUARE_DIST:
- self.goto('idle')
+ self.goto("idle")
def onRelease(self, x, y, btn):
if btn == self.button:
self.machine.click(x, y, btn)
- self.goto('idle')
+ self.goto("idle")
class ClickOrDrag(State):
def enterState(self, x, y, btn):
@@ -260,13 +262,13 @@ class ClickOrDrag(StateMachine):
dx2 = (x - self.initPos[0]) ** 2
dy2 = (y - self.initPos[1]) ** 2
if (dx2 + dy2) >= self.machine.DRAG_THRESHOLD_SQUARE_DIST:
- self.goto('drag', self.initPos, (x, y), self.button)
+ self.goto("drag", self.initPos, (x, y), self.button)
def onRelease(self, x, y, btn):
if btn == self.button:
if btn in self.machine.clickButtons:
self.machine.click(x, y, btn)
- self.goto('idle')
+ self.goto("idle")
class Drag(State):
def enterState(self, initPos, curPos, btn):
@@ -281,26 +283,27 @@ class ClickOrDrag(StateMachine):
def onRelease(self, x, y, btn):
if btn == self.button:
self.machine.endDrag(self.initPos, (x, y), btn)
- self.goto('idle')
+ self.goto("idle")
- def __init__(self,
- clickButtons=(LEFT_BTN, RIGHT_BTN),
- dragButtons=(LEFT_BTN,)):
+ def __init__(self, clickButtons=(LEFT_BTN, RIGHT_BTN), dragButtons=(LEFT_BTN,)):
states = {
- 'idle': self.Idle,
- 'click': self.Click,
- 'clickOrDrag': self.ClickOrDrag,
- 'drag': self.Drag
+ "idle": self.Idle,
+ "click": self.Click,
+ "clickOrDrag": self.ClickOrDrag,
+ "drag": self.Drag,
}
self.__clickButtons = set(clickButtons)
self.__dragButtons = set(dragButtons)
- super(ClickOrDrag, self).__init__(states, 'idle')
+ super(ClickOrDrag, self).__init__(states, "idle")
- clickButtons = property(lambda self: self.__clickButtons,
- doc="Buttons with click interaction (Set[int])")
+ clickButtons = property(
+ lambda self: self.__clickButtons,
+ doc="Buttons with click interaction (Set[int])",
+ )
- dragButtons = property(lambda self: self.__dragButtons,
- doc="Buttons with drag interaction (Set[int])")
+ dragButtons = property(
+ lambda self: self.__dragButtons, doc="Buttons with drag interaction (Set[int])"
+ )
def click(self, x, y, btn):
"""Called upon a button supporting click.
diff --git a/src/silx/gui/plot/ItemsSelectionDialog.py b/src/silx/gui/plot/ItemsSelectionDialog.py
index c303c6b..b4e4f9e 100644
--- a/src/silx/gui/plot/ItemsSelectionDialog.py
+++ b/src/silx/gui/plot/ItemsSelectionDialog.py
@@ -43,6 +43,7 @@ class KindsSelector(qt.QListWidget):
"""List widget allowing to select plot item kinds
("curve", "scatter", "image"...)
"""
+
sigSelectedKindsChanged = qt.Signal(list)
def __init__(self, parent=None, kinds=None):
@@ -87,8 +88,10 @@ class KindsSelector(qt.QListWidget):
def selectAll(self):
"""Select all available kinds."""
- if self.selectionMode() in [qt.QAbstractItemView.SingleSelection,
- qt.QAbstractItemView.NoSelection]:
+ if self.selectionMode() in [
+ qt.QAbstractItemView.SingleSelection,
+ qt.QAbstractItemView.NoSelection,
+ ]:
raise RuntimeError("selectAll requires a multiple selection mode")
for i in range(self.count()):
self.item(i).setSelected(True)
@@ -102,6 +105,7 @@ class PlotItemsSelector(qt.QTableWidget):
You can be warned of selection changes by listening to signal
:attr:`itemSelectionChanged`.
"""
+
def __init__(self, parent=None, plot=None):
if plot is None or not isinstance(plot, PlotWidget):
raise AttributeError("parameter plot is required")
@@ -131,8 +135,9 @@ class PlotItemsSelector(qt.QTableWidget):
:param list[str] kinds: Sequence of kinds
"""
if not set(kinds) <= set(PlotWidget.ITEM_KINDS):
- raise KeyError("Illegal plot item kinds: %s" %
- set(kinds) - set(PlotWidget.ITEM_KINDS))
+ raise KeyError(
+ "Illegal plot item kinds: %s" % set(kinds) - set(PlotWidget.ITEM_KINDS)
+ )
self.plot_item_kinds = kinds
self.updatePlotItems()
@@ -199,6 +204,7 @@ class ItemsSelectionDialog(qt.QDialog):
else:
print("Selection cancelled")
"""
+
def __init__(self, parent=None, plot=None):
if plot is None or not isinstance(plot, PlotWidget):
raise AttributeError("parameter plot is required")
@@ -211,7 +217,8 @@ class ItemsSelectionDialog(qt.QDialog):
self.kind_selector = KindsSelector(self)
self.kind_selector.setToolTip(
- "select one or more item kinds to show them in the item list")
+ "select one or more item kinds to show them in the item list"
+ )
self.item_selector = PlotItemsSelector(self, plot)
self.item_selector.setToolTip("select items")
@@ -261,25 +268,26 @@ class ItemsSelectionDialog(qt.QDialog):
:param mode: One of :class:`QTableWidget` selection modes
"""
if mode == self.item_selector.SingleSelection:
- self.item_selector.setToolTip(
- "Select one item by clicking on it.")
+ self.item_selector.setToolTip("Select one item by clicking on it.")
elif mode == self.item_selector.MultiSelection:
self.item_selector.setToolTip(
- "Select one or more items by clicking with the left mouse"
- " button.\nYou can unselect items by clicking them again.\n"
- "Multiple items can be toggled by dragging the mouse over them.")
+ "Select one or more items by clicking with the left mouse"
+ " button.\nYou can unselect items by clicking them again.\n"
+ "Multiple items can be toggled by dragging the mouse over them."
+ )
elif mode == self.item_selector.ExtendedSelection:
self.item_selector.setToolTip(
- "Select one or more items. You can select multiple items "
- "by keeping the Ctrl key pushed when clicking.\nYou can "
- "select a range of items by clicking on the first and "
- "last while keeping the Shift key pushed.")
+ "Select one or more items. You can select multiple items "
+ "by keeping the Ctrl key pushed when clicking.\nYou can "
+ "select a range of items by clicking on the first and "
+ "last while keeping the Shift key pushed."
+ )
elif mode == self.item_selector.ContiguousSelection:
self.item_selector.setToolTip(
- "Select one item by clicking on it. If you press the Shift"
- " key while clicking on a second item,\nall items between "
- "the two will be selected.")
+ "Select one item by clicking on it. If you press the Shift"
+ " key while clicking on a second item,\nall items between "
+ "the two will be selected."
+ )
elif mode == self.item_selector.NoSelection:
- raise ValueError("The NoSelection mode is not allowed "
- "in this context.")
+ raise ValueError("The NoSelection mode is not allowed " "in this context.")
self.item_selector.setSelectionMode(mode)
diff --git a/src/silx/gui/plot/LegendSelector.py b/src/silx/gui/plot/LegendSelector.py
index 4d8ebe9..22348fb 100755
--- a/src/silx/gui/plot/LegendSelector.py
+++ b/src/silx/gui/plot/LegendSelector.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2004-2022 European Synchrotron Radiation Facility
+# Copyright (c) 2004-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -39,6 +39,7 @@ import numpy
from .. import qt, colors
from ..widgets.LegendIconWidget import LegendIconWidget
from . import items
+from ...utils.deprecation import deprecated
_logger = logging.getLogger(__name__)
@@ -86,11 +87,10 @@ class LegendIcon(LegendIconWidget):
self._update()
def _update(self):
- """Update widget according to current curve state.
- """
+ """Update widget according to current curve state."""
curve = self.getCurve()
if curve is None:
- _logger.error('Curve no more exists')
+ _logger.error("Curve no more exists")
self.setEnabled(False)
return
@@ -104,11 +104,10 @@ class LegendIcon(LegendIconWidget):
color = style.getColor()
if numpy.array(color, copy=False).ndim != 1:
# array of colors, use transparent black
- color = 0., 0., 0., 0.
+ color = 0.0, 0.0, 0.0, 0.0
color = colors.rgba(color) # Make sure it is float in [0, 1]
alpha = curve.getAlpha()
- color = qt.QColor.fromRgbF(
- color[0], color[1], color[2], color[3] * alpha)
+ color = qt.QColor.fromRgbF(color[0], color[1], color[2], color[3] * alpha)
self.setLineColor(color)
self.setSymbolColor(color)
self.update() # TODO this should not be needed
@@ -118,15 +117,17 @@ class LegendIcon(LegendIconWidget):
:param event: Kind of change
"""
- if event in (items.ItemChangedType.VISIBLE,
- items.ItemChangedType.SYMBOL,
- items.ItemChangedType.SYMBOL_SIZE,
- items.ItemChangedType.LINE_WIDTH,
- items.ItemChangedType.LINE_STYLE,
- items.ItemChangedType.COLOR,
- items.ItemChangedType.ALPHA,
- items.ItemChangedType.HIGHLIGHTED,
- items.ItemChangedType.HIGHLIGHTED_STYLE):
+ if event in (
+ items.ItemChangedType.VISIBLE,
+ items.ItemChangedType.SYMBOL,
+ items.ItemChangedType.SYMBOL_SIZE,
+ items.ItemChangedType.LINE_WIDTH,
+ items.ItemChangedType.LINE_STYLE,
+ items.ItemChangedType.COLOR,
+ items.ItemChangedType.ALPHA,
+ items.ItemChangedType.HIGHLIGHTED,
+ items.ItemChangedType.HIGHLIGHTED_STYLE,
+ ):
self._update()
@@ -142,12 +143,14 @@ class LegendModel(qt.QAbstractListModel):
- symbol
- visibility of the symbols
"""
+
iconColorRole = qt.Qt.UserRole + 0
iconLineWidthRole = qt.Qt.UserRole + 1
iconLineStyleRole = qt.Qt.UserRole + 2
showLineRole = qt.Qt.UserRole + 3
iconSymbolRole = qt.Qt.UserRole + 4
showSymbolRole = qt.Qt.UserRole + 5
+ itemRole = qt.Qt.UserRole + 6
def __init__(self, legendList=None, parent=None):
super(LegendModel, self).__init__(parent)
@@ -159,16 +162,14 @@ class LegendModel(qt.QAbstractListModel):
def __getitem__(self, idx):
if idx >= len(self.legendList):
- raise IndexError('list index out of range')
+ raise IndexError("list index out of range")
return self.legendList[idx]
def rowCount(self, modelIndex=None):
return len(self.legendList)
def flags(self, index):
- return (qt.Qt.ItemIsEditable |
- qt.Qt.ItemIsEnabled |
- qt.Qt.ItemIsSelectable)
+ return qt.Qt.ItemIsEditable | qt.Qt.ItemIsEnabled | qt.Qt.ItemIsSelectable
def data(self, modelIndex, role):
if modelIndex.isValid:
@@ -176,7 +177,7 @@ class LegendModel(qt.QAbstractListModel):
else:
return None
if idx >= len(self.legendList):
- raise IndexError('list index out of range')
+ raise IndexError("list index out of range")
item = self.legendList[idx]
isActive = item[1].get("active", False)
@@ -186,7 +187,7 @@ class LegendModel(qt.QAbstractListModel):
return legend
elif role == qt.Qt.SizeHintRole:
# size = qt.QSize(200,50)
- _logger.warning('LegendModel -- size hint role not implemented')
+ _logger.warning("LegendModel -- size hint role not implemented")
return qt.QSize()
elif role == qt.Qt.TextAlignmentRole:
alignment = qt.Qt.AlignVCenter | qt.Qt.AlignLeft
@@ -194,7 +195,7 @@ class LegendModel(qt.QAbstractListModel):
elif role == qt.Qt.BackgroundRole:
# Background color, must be QBrush
if isActive:
- brush = self._palette.brush(qt.QPalette.Normal, qt.QPalette.Highlight)
+ brush = self._palette.brush(qt.QPalette.Active, qt.QPalette.Highlight)
elif idx % 2:
brush = qt.QBrush(qt.QColor(240, 240, 240))
else:
@@ -203,28 +204,32 @@ class LegendModel(qt.QAbstractListModel):
elif role == qt.Qt.ForegroundRole:
# ForegroundRole color, must be QBrush
if isActive:
- brush = self._palette.brush(qt.QPalette.Normal, qt.QPalette.HighlightedText)
+ brush = self._palette.brush(
+ qt.QPalette.Active, qt.QPalette.HighlightedText
+ )
else:
- brush = self._palette.brush(qt.QPalette.Normal, qt.QPalette.WindowText)
+ brush = self._palette.brush(qt.QPalette.Active, qt.QPalette.WindowText)
return brush
elif role == qt.Qt.CheckStateRole:
return bool(item[2]) # item[2] == True
elif role == qt.Qt.ToolTipRole or role == qt.Qt.StatusTipRole:
- return ''
+ return ""
elif role == self.iconColorRole:
- return item[1]['color']
+ return item[1]["color"]
elif role == self.iconLineWidthRole:
- return item[1]['linewidth']
+ return item[1]["linewidth"]
elif role == self.iconLineStyleRole:
- return item[1]['linestyle']
+ return item[1]["linestyle"]
elif role == self.iconSymbolRole:
- return item[1]['symbol']
+ return item[1]["symbol"]
elif role == self.showLineRole:
return item[3]
elif role == self.showSymbolRole:
return item[4]
+ elif role == self.itemRole:
+ return item[5]
else:
- _logger.info('Unkown role requested: %s', str(role))
+ _logger.info("Unkown role requested: %s", str(role))
return None
def setData(self, modelIndex, value, role):
@@ -234,8 +239,7 @@ class LegendModel(qt.QAbstractListModel):
return None
if idx >= len(self.legendList):
# raise IndexError('list index out of range')
- _logger.warning(
- 'setData -- List index out of range, idx: %d', idx)
+ _logger.warning("setData -- List index out of range, idx: %d", idx)
return None
item = self.legendList[idx]
@@ -244,22 +248,25 @@ class LegendModel(qt.QAbstractListModel):
# Set legend
item[0] = str(value)
elif role == self.iconColorRole:
- item[1]['color'] = qt.QColor(value)
+ item[1]["color"] = qt.QColor(value)
elif role == self.iconLineWidthRole:
- item[1]['linewidth'] = int(value)
+ item[1]["linewidth"] = int(value)
elif role == self.iconLineStyleRole:
- item[1]['linestyle'] = str(value)
+ item[1]["linestyle"] = value
elif role == self.iconSymbolRole:
- item[1]['symbol'] = str(value)
+ item[1]["symbol"] = str(value)
elif role == qt.Qt.CheckStateRole:
item[2] = value
elif role == self.showLineRole:
item[3] = value
elif role == self.showSymbolRole:
item[4] = value
+ elif role == self.itemRole:
+ item[5] = value
except ValueError:
- _logger.warning('Conversion failed:\n\tvalue: %s\n\trole: %s',
- str(value), str(role))
+ _logger.warning(
+ "Conversion failed:\n\tvalue: %s\n\trole: %s", str(value), str(role)
+ )
# Can that be right? Read docs again..
self.dataChanged.emit(modelIndex, modelIndex)
return True
@@ -272,44 +279,45 @@ class LegendModel(qt.QAbstractListModel):
"""
modelIndex = self.createIndex(row, 0)
count = len(llist)
- super(LegendModel, self).beginInsertRows(modelIndex,
- row,
- row + count)
+ super(LegendModel, self).beginInsertRows(modelIndex, row, row + count)
head = self.legendList[0:row]
tail = self.legendList[row:]
new = []
- for (legend, icon) in llist:
- linestyle = icon.get('linestyle', None)
+ for legend, icon in llist:
+ linestyle = icon.get("linestyle", None)
if LegendIconWidget.isEmptyLineStyle(linestyle):
# Curve had no line, give it one and hide it
# So when toggle line, it will display a solid line
showLine = False
- icon['linestyle'] = '-'
+ icon["linestyle"] = "-"
else:
showLine = True
- symbol = icon.get('symbol', None)
+ symbol = icon.get("symbol", None)
if LegendIconWidget.isEmptySymbol(symbol):
# Curve had no symbol, give it one and hide it
# So when toggle symbol, it will display 'o'
showSymbol = False
- icon['symbol'] = 'o'
+ icon["symbol"] = "o"
else:
showSymbol = True
- selected = icon.get('selected', True)
- item = [legend,
- icon,
- selected,
- showLine,
- showSymbol]
+ selected = icon.get("selected", True)
+ item = [
+ legend,
+ icon,
+ selected,
+ showLine,
+ showSymbol,
+ icon.get("item", None),
+ ]
new.append(item)
self.legendList = head + new + tail
super(LegendModel, self).endInsertRows()
return True
def insertRows(self, row, count, modelIndex=qt.QModelIndex()):
- raise NotImplementedError('Use LegendModel.insertLegendList instead')
+ raise NotImplementedError("Use LegendModel.insertLegendList instead")
def removeRow(self, row):
return self.removeRows(row, 1)
@@ -320,14 +328,13 @@ class LegendModel(qt.QAbstractListModel):
# Nothing to do..
return True
if row < 0 or row >= length:
- raise IndexError('Index out of range -- ' +
- 'idx: %d, len: %d' % (row, length))
+ raise IndexError(
+ "Index out of range -- " + "idx: %d, len: %d" % (row, length)
+ )
if count == 0:
return False
- super(LegendModel, self).beginRemoveRows(modelIndex,
- row,
- row + count)
- del(self.legendList[row:row + count])
+ super(LegendModel, self).beginRemoveRows(modelIndex, row, row + count)
+ del self.legendList[row : row + count]
super(LegendModel, self).endRemoveRows()
return True
@@ -338,8 +345,7 @@ class LegendModel(qt.QAbstractListModel):
:type editor: QWidget
"""
if event not in self.eventList:
- raise ValueError('setEditor -- Event must be in %s' %
- str(self.eventList))
+ raise ValueError("setEditor -- Event must be in %s" % str(self.eventList))
self.editorDict[event] = editor
@@ -380,12 +386,11 @@ class LegendListItemWidget(qt.QItemDelegate):
iconSize = self.icon.sizeHint()
# Calculate icon position
x = rect.left() + 2
- y = rect.top() + int(.5 * (rect.height() - iconSize.height()))
+ y = rect.top() + int(0.5 * (rect.height() - iconSize.height()))
iconRect = qt.QRect(qt.QPoint(x, y), iconSize)
# Calculate label rectangle
- legendSize = qt.QSize(rect.width() - iconSize.width() - 30,
- rect.height())
+ legendSize = qt.QSize(rect.width() - iconSize.width() - 30, rect.height())
# Calculate label position
x = rect.left() + iconRect.width()
y = rect.top()
@@ -443,8 +448,7 @@ class LegendListItemWidget(qt.QItemDelegate):
else:
checkState = qt.Qt.Unchecked
- self.drawCheck(
- painter, qt.QStyleOptionViewItem(), chBoxRect, checkState)
+ self.drawCheck(painter, qt.QStyleOptionViewItem(), chBoxRect, checkState)
painter.restore()
@@ -453,7 +457,11 @@ class LegendListItemWidget(qt.QItemDelegate):
# Mouse events are sent to editorEvent()
# even if they don't start editing of the item.
if event.button() == qt.Qt.RightButton and self.contextMenu:
- self.contextMenu.exec(event.globalPos(), modelIndex)
+ if qt.BINDING == "PyQt5":
+ position = event.globalPos()
+ else: # Qt6
+ position = event.globalPosition().toPoint()
+ self.contextMenu.exec(position, modelIndex)
return True
elif event.button() == qt.Qt.LeftButton:
# Check if checkbox was clicked
@@ -461,26 +469,29 @@ class LegendListItemWidget(qt.QItemDelegate):
cbRect = self.cbDict[idx]
if cbRect.contains(event.pos()):
# Toggle checkbox
- model.setData(modelIndex,
- not modelIndex.data(qt.Qt.CheckStateRole),
- qt.Qt.CheckStateRole)
+ model.setData(
+ modelIndex,
+ not modelIndex.data(qt.Qt.CheckStateRole),
+ qt.Qt.CheckStateRole,
+ )
event.ignore()
return True
else:
return super(LegendListItemWidget, self).editorEvent(
- event, model, option, modelIndex)
+ event, model, option, modelIndex
+ )
def createEditor(self, parent, option, idx):
- _logger.info('### Editor request ###')
+ _logger.info("### Editor request ###")
def sizeHint(self, option, idx):
# return qt.QSize(68,24)
iconSize = self.icon.sizeHint()
legendSize = self.legend.sizeHint()
checkboxSize = self.checkbox.sizeHint()
- height = max([iconSize.height(),
- legendSize.height(),
- checkboxSize.height()]) + 4
+ height = (
+ max([iconSize.height(), legendSize.height(), checkboxSize.height()]) + 4
+ )
width = iconSize.width() + legendSize.width() + checkboxSize.width()
return qt.QSize(width, height)
@@ -491,9 +502,9 @@ class LegendListView(qt.QListView):
sigLegendSignal = qt.Signal(object)
"""Signal emitting a dict when an action is triggered by the user."""
- __mouseClickedEvent = 'mouseClicked'
- __checkBoxClickedEvent = 'checkBoxClicked'
- __legendClickedEvent = 'legendClicked'
+ __mouseClickedEvent = "mouseClicked"
+ __checkBoxClickedEvent = "checkBoxClicked"
+ __legendClickedEvent = "legendClicked"
def __init__(self, parent=None, model=None, contextMenu=None):
super(LegendListView, self).__init__(parent)
@@ -539,47 +550,55 @@ class LegendListView(qt.QListView):
model.setData(modelIndex, new_legend, qt.Qt.DisplayRole)
color = modelIndex.data(LegendModel.iconColorRole)
- new_color = icon.get('color', None)
+ new_color = icon.get("color", None)
if new_color != color:
model.setData(modelIndex, new_color, LegendModel.iconColorRole)
linewidth = modelIndex.data(LegendModel.iconLineWidthRole)
- new_linewidth = icon.get('linewidth', 1.0)
+ new_linewidth = icon.get("linewidth", 1.0)
if new_linewidth != linewidth:
- model.setData(modelIndex, new_linewidth, LegendModel.iconLineWidthRole)
+ model.setData(
+ modelIndex, new_linewidth, LegendModel.iconLineWidthRole
+ )
linestyle = modelIndex.data(LegendModel.iconLineStyleRole)
- new_linestyle = icon.get('linestyle', None)
+ new_linestyle = icon.get("linestyle", None)
visible = not LegendIconWidget.isEmptyLineStyle(new_linestyle)
model.setData(modelIndex, visible, LegendModel.showLineRole)
if new_linestyle != linestyle:
- model.setData(modelIndex, new_linestyle, LegendModel.iconLineStyleRole)
+ model.setData(
+ modelIndex, new_linestyle, LegendModel.iconLineStyleRole
+ )
symbol = modelIndex.data(LegendModel.iconSymbolRole)
- new_symbol = icon.get('symbol', None)
+ new_symbol = icon.get("symbol", None)
visible = not LegendIconWidget.isEmptySymbol(new_symbol)
model.setData(modelIndex, visible, LegendModel.showSymbolRole)
if new_symbol != symbol:
model.setData(modelIndex, new_symbol, LegendModel.iconSymbolRole)
selected = modelIndex.data(qt.Qt.CheckStateRole)
- new_selected = icon.get('selected', True)
+ new_selected = icon.get("selected", True)
if new_selected != selected:
model.setData(modelIndex, new_selected, qt.Qt.CheckStateRole)
- _logger.debug('LegendListView.setLegendList(legendList) finished')
+
+ item = modelIndex.data(LegendModel.itemRole)
+ newItem = icon.get("item", None)
+ if item is not newItem:
+ model.setData(modelIndex, newItem, LegendModel.itemRole)
+ _logger.debug("LegendListView.setLegendList(legendList) finished")
def clear(self):
model = self.model()
model.removeRows(0, model.rowCount())
- _logger.debug('LegendListView.clear() finished')
+ _logger.debug("LegendListView.clear() finished")
def setContextMenu(self, contextMenu=None):
delegate = self.itemDelegate()
if isinstance(delegate, LegendListItemWidget) and self.model():
if contextMenu is None:
delegate.contextMenu = LegendListContextMenu(self.model())
- delegate.contextMenu.sigContextMenu.connect(
- self._contextMenuSlot)
+ delegate.contextMenu.sigContextMenu.connect(self._contextMenuSlot)
else:
delegate.contextMenu = contextMenu
@@ -632,12 +651,11 @@ class LegendListView(qt.QListView):
:param QModelIndex modelIndex: index of the clicked item
"""
- _logger.debug('self._handleMouseClick called')
- if self.__lastButton not in [qt.Qt.LeftButton,
- qt.Qt.RightButton]:
+ _logger.debug("self._handleMouseClick called")
+ if self.__lastButton not in [qt.Qt.LeftButton, qt.Qt.RightButton]:
return
if not modelIndex.isValid():
- _logger.debug('_handleMouseClick -- Invalid QModelIndex')
+ _logger.debug("_handleMouseClick -- Invalid QModelIndex")
return
# model = self.model()
idx = modelIndex.row()
@@ -653,30 +671,29 @@ class LegendListView(qt.QListView):
# TODO: Check for doubleclicks on legend/icon and spawn editors
ddict = {
- 'legend': str(modelIndex.data(qt.Qt.DisplayRole)),
- 'icon': {
- 'linewidth': str(modelIndex.data(
- LegendModel.iconLineWidthRole)),
- 'linestyle': str(modelIndex.data(
- LegendModel.iconLineStyleRole)),
- 'symbol': str(modelIndex.data(LegendModel.iconSymbolRole))
+ "legend": str(modelIndex.data(qt.Qt.DisplayRole)),
+ "icon": {
+ "linewidth": str(modelIndex.data(LegendModel.iconLineWidthRole)),
+ "linestyle": modelIndex.data(LegendModel.iconLineStyleRole),
+ "symbol": str(modelIndex.data(LegendModel.iconSymbolRole)),
},
- 'selected': modelIndex.data(qt.Qt.CheckStateRole),
- 'type': str(modelIndex.data())
+ "selected": modelIndex.data(qt.Qt.CheckStateRole),
+ "type": str(modelIndex.data()),
+ "item": modelIndex.data(LegendModel.itemRole),
}
if self.__lastButton == qt.Qt.RightButton:
- _logger.debug('Right clicked')
- ddict['button'] = "right"
- ddict['event'] = self.__mouseClickedEvent
+ _logger.debug("Right clicked")
+ ddict["button"] = "right"
+ ddict["event"] = self.__mouseClickedEvent
elif cbClicked:
- _logger.debug('CheckBox clicked')
- ddict['button'] = "left"
- ddict['event'] = self.__checkBoxClickedEvent
+ _logger.debug("CheckBox clicked")
+ ddict["button"] = "left"
+ ddict["event"] = self.__checkBoxClickedEvent
else:
- _logger.debug('Legend clicked')
- ddict['button'] = "left"
- ddict['event'] = self.__legendClickedEvent
- _logger.debug(' idx: %d\n ddict: %s', idx, str(ddict))
+ _logger.debug("Legend clicked")
+ ddict["button"] = "left"
+ ddict["event"] = self.__legendClickedEvent
+ _logger.debug(" idx: %d\n ddict: %s", idx, str(ddict))
self.sigLegendSignal.emit(ddict)
@@ -690,29 +707,26 @@ class LegendListContextMenu(qt.QMenu):
super(LegendListContextMenu, self).__init__(parent=None)
self.model = model
- self.addAction('Set Active', self.setActiveAction)
- self.addAction('Map to left', self.mapToLeftAction)
- self.addAction('Map to right', self.mapToRightAction)
+ self.addAction("Set Active", self.setActiveAction)
+ self.addAction("Map to left", self.mapToLeftAction)
+ self.addAction("Map to right", self.mapToRightAction)
- self._pointsAction = self.addAction(
- 'Points', self.togglePointsAction)
+ self._pointsAction = self.addAction("Points", self.togglePointsAction)
self._pointsAction.setCheckable(True)
- self._linesAction = self.addAction('Lines', self.toggleLinesAction)
+ self._linesAction = self.addAction("Lines", self.toggleLinesAction)
self._linesAction.setCheckable(True)
- self.addAction('Remove curve', self.removeItemAction)
- self.addAction('Rename curve', self.renameItemAction)
+ self.addAction("Remove curve", self.removeItemAction)
+ self.addAction("Rename curve", self.renameItemAction)
def exec(self, pos, idx):
self.__currentIdx = idx
# Set checkable action state
modelIndex = self.currentIdx()
- self._pointsAction.setChecked(
- modelIndex.data(LegendModel.showSymbolRole))
- self._linesAction.setChecked(
- modelIndex.data(LegendModel.showLineRole))
+ self._pointsAction.setChecked(modelIndex.data(LegendModel.showSymbolRole))
+ self._linesAction.setChecked(modelIndex.data(LegendModel.showLineRole))
super(LegendListContextMenu, self).popup(pos)
@@ -723,55 +737,59 @@ class LegendListContextMenu(qt.QMenu):
return self.__currentIdx
def mapToLeftAction(self):
- _logger.debug('LegendListContextMenu.mapToLeftAction called')
+ _logger.debug("LegendListContextMenu.mapToLeftAction called")
modelIndex = self.currentIdx()
legend = str(modelIndex.data(qt.Qt.DisplayRole))
ddict = {
- 'legend': legend,
- 'label': legend,
- 'selected': modelIndex.data(qt.Qt.CheckStateRole),
- 'type': str(modelIndex.data()),
- 'event': "mapToLeft"
+ "legend": legend,
+ "label": legend,
+ "selected": modelIndex.data(qt.Qt.CheckStateRole),
+ "type": str(modelIndex.data()),
+ "item": modelIndex.data(LegendModel.itemRole),
+ "event": "mapToLeft",
}
self.sigContextMenu.emit(ddict)
def mapToRightAction(self):
- _logger.debug('LegendListContextMenu.mapToRightAction called')
+ _logger.debug("LegendListContextMenu.mapToRightAction called")
modelIndex = self.currentIdx()
legend = str(modelIndex.data(qt.Qt.DisplayRole))
ddict = {
- 'legend': legend,
- 'label': legend,
- 'selected': modelIndex.data(qt.Qt.CheckStateRole),
- 'type': str(modelIndex.data()),
- 'event': "mapToRight"
+ "legend": legend,
+ "label": legend,
+ "selected": modelIndex.data(qt.Qt.CheckStateRole),
+ "type": str(modelIndex.data()),
+ "item": modelIndex.data(LegendModel.itemRole),
+ "event": "mapToRight",
}
self.sigContextMenu.emit(ddict)
def removeItemAction(self):
- _logger.debug('LegendListContextMenu.removeCurveAction called')
+ _logger.debug("LegendListContextMenu.removeCurveAction called")
modelIndex = self.currentIdx()
legend = str(modelIndex.data(qt.Qt.DisplayRole))
ddict = {
- 'legend': legend,
- 'label': legend,
- 'selected': modelIndex.data(qt.Qt.CheckStateRole),
- 'type': str(modelIndex.data()),
- 'event': "removeCurve"
+ "legend": legend,
+ "label": legend,
+ "selected": modelIndex.data(qt.Qt.CheckStateRole),
+ "type": str(modelIndex.data()),
+ "item": modelIndex.data(LegendModel.itemRole),
+ "event": "removeCurve",
}
self.model.removeRow(modelIndex.row())
self.sigContextMenu.emit(ddict)
def renameItemAction(self):
- _logger.debug('LegendListContextMenu.renameCurveAction called')
+ _logger.debug("LegendListContextMenu.renameCurveAction called")
modelIndex = self.currentIdx()
legend = str(modelIndex.data(qt.Qt.DisplayRole))
ddict = {
- 'legend': legend,
- 'label': legend,
- 'selected': modelIndex.data(qt.Qt.CheckStateRole),
- 'type': str(modelIndex.data()),
- 'event': "renameCurve"
+ "legend": legend,
+ "label": legend,
+ "selected": modelIndex.data(qt.Qt.CheckStateRole),
+ "type": str(modelIndex.data()),
+ "item": modelIndex.data(LegendModel.itemRole),
+ "event": "renameCurve",
}
self.sigContextMenu.emit(ddict)
@@ -779,17 +797,18 @@ class LegendListContextMenu(qt.QMenu):
modelIndex = self.currentIdx()
legend = str(modelIndex.data(qt.Qt.DisplayRole))
ddict = {
- 'legend': legend,
- 'label': legend,
- 'selected': modelIndex.data(qt.Qt.CheckStateRole),
- 'type': str(modelIndex.data()),
+ "legend": legend,
+ "label": legend,
+ "selected": modelIndex.data(qt.Qt.CheckStateRole),
+ "item": modelIndex.data(LegendModel.itemRole),
+ "type": str(modelIndex.data()),
}
linestyle = modelIndex.data(LegendModel.iconLineStyleRole)
visible = not modelIndex.data(LegendModel.showLineRole)
- _logger.debug('toggleLinesAction -- lines visible: %s', str(visible))
- ddict['event'] = "toggleLine"
- ddict['line'] = visible
- ddict['linestyle'] = linestyle if visible else ''
+ _logger.debug("toggleLinesAction -- lines visible: %s", str(visible))
+ ddict["event"] = "toggleLine"
+ ddict["line"] = visible
+ ddict["linestyle"] = linestyle if visible else ""
self.model.setData(modelIndex, visible, LegendModel.showLineRole)
self.sigContextMenu.emit(ddict)
@@ -797,33 +816,34 @@ class LegendListContextMenu(qt.QMenu):
modelIndex = self.currentIdx()
legend = str(modelIndex.data(qt.Qt.DisplayRole))
ddict = {
- 'legend': legend,
- 'label': legend,
- 'selected': modelIndex.data(qt.Qt.CheckStateRole),
- 'type': str(modelIndex.data()),
+ "legend": legend,
+ "label": legend,
+ "selected": modelIndex.data(qt.Qt.CheckStateRole),
+ "type": str(modelIndex.data()),
+ "item": modelIndex.data(LegendModel.itemRole),
}
flag = modelIndex.data(LegendModel.showSymbolRole)
symbol = modelIndex.data(LegendModel.iconSymbolRole)
visible = not flag or LegendIconWidget.isEmptySymbol(symbol)
- _logger.debug(
- 'togglePointsAction -- Symbols visible: %s', str(visible))
+ _logger.debug("togglePointsAction -- Symbols visible: %s", str(visible))
- ddict['event'] = "togglePoints"
- ddict['points'] = visible
- ddict['symbol'] = symbol if visible else ''
+ ddict["event"] = "togglePoints"
+ ddict["points"] = visible
+ ddict["symbol"] = symbol if visible else ""
self.model.setData(modelIndex, visible, LegendModel.showSymbolRole)
self.sigContextMenu.emit(ddict)
def setActiveAction(self):
modelIndex = self.currentIdx()
legend = str(modelIndex.data(qt.Qt.DisplayRole))
- _logger.debug('setActiveAction -- active curve: %s', legend)
+ _logger.debug("setActiveAction -- active curve: %s", legend)
ddict = {
- 'legend': legend,
- 'label': legend,
- 'selected': modelIndex.data(qt.Qt.CheckStateRole),
- 'type': str(modelIndex.data()),
- 'event': "setActiveCurve",
+ "legend": legend,
+ "label": legend,
+ "selected": modelIndex.data(qt.Qt.CheckStateRole),
+ "type": str(modelIndex.data()),
+ "item": modelIndex.data(LegendModel.itemRole),
+ "event": "setActiveCurve",
}
self.sigContextMenu.emit(ddict)
@@ -842,10 +862,10 @@ class RenameCurveDialog(qt.QDialog):
self.hboxLayout = qt.QHBoxLayout(self.hbox)
self.hboxLayout.addStretch(1)
self.okButton = qt.QPushButton(self.hbox)
- self.okButton.setText('OK')
+ self.okButton.setText("OK")
self.hboxLayout.addWidget(self.okButton)
self.cancelButton = qt.QPushButton(self.hbox)
- self.cancelButton.setText('Cancel')
+ self.cancelButton.setText("Cancel")
self.hboxLayout.addWidget(self.cancelButton)
self.hboxLayout.addStretch(1)
layout.addWidget(self.lineEdit)
@@ -895,8 +915,7 @@ class LegendsDockWidget(qt.QDockWidget):
self.layout().setContentsMargins(0, 0, 0, 0)
self.setWidget(self._legendWidget)
- self.visibilityChanged.connect(
- self._visibilityChangedHandler)
+ self.visibilityChanged.connect(self._visibilityChangedHandler)
self._legendWidget.sigLegendSignal.connect(self._legendSignalHandler)
@@ -905,6 +924,7 @@ class LegendsDockWidget(qt.QDockWidget):
"""The :class:`.PlotWindow` this widget is attached to."""
return self._plotRef()
+ @deprecated(reason="No longer needed", since_version="2.0.0")
def renameCurve(self, oldLegend, newLegend):
"""Change the name of a curve using remove and addCurve
@@ -913,88 +933,77 @@ class LegendsDockWidget(qt.QDockWidget):
"""
is_active = self.plot.getActiveCurve(just_legend=True) == oldLegend
curve = self.plot.getCurve(oldLegend)
- self.plot.remove(oldLegend, kind='curve')
- self.plot.addCurve(curve.getXData(copy=False),
- curve.getYData(copy=False),
- legend=newLegend,
- info=curve.getInfo(),
- color=curve.getColor(),
- symbol=curve.getSymbol(),
- linewidth=curve.getLineWidth(),
- linestyle=curve.getLineStyle(),
- xlabel=curve.getXLabel(),
- ylabel=curve.getYLabel(),
- xerror=curve.getXErrorData(copy=False),
- yerror=curve.getYErrorData(copy=False),
- z=curve.getZValue(),
- selectable=curve.isSelectable(),
- fill=curve.isFill(),
- resetzoom=False)
+ self.plot.remove(oldLegend, kind="curve")
+ self.plot.addCurve(
+ curve.getXData(copy=False),
+ curve.getYData(copy=False),
+ legend=newLegend,
+ info=curve.getInfo(),
+ color=curve.getColor(),
+ symbol=curve.getSymbol(),
+ linewidth=curve.getLineWidth(),
+ linestyle=curve.getLineStyle(),
+ xlabel=curve.getXLabel(),
+ ylabel=curve.getYLabel(),
+ xerror=curve.getXErrorData(copy=False),
+ yerror=curve.getYErrorData(copy=False),
+ z=curve.getZValue(),
+ selectable=curve.isSelectable(),
+ fill=curve.isFill(),
+ resetzoom=False,
+ )
if is_active:
self.plot.setActiveCurve(newLegend)
def _legendSignalHandler(self, ddict):
"""Handles events from the LegendListView signal"""
_logger.debug("Legend signal ddict = %s", str(ddict))
+ # If item is not provided, retrieve it from its legend
+ curve = ddict.get("item", None)
+ if curve is None:
+ curve = self.plot.getCurve(ddict["legend"])
- if ddict['event'] == "legendClicked":
- if ddict['button'] == "left":
- self.plot.setActiveCurve(ddict['legend'])
+ if ddict["event"] == "legendClicked":
+ if ddict["button"] == "left":
+ self.plot.setActiveCurve(curve)
- elif ddict['event'] == "removeCurve":
- self.plot.removeCurve(ddict['legend'])
+ elif ddict["event"] == "removeCurve":
+ self.plot.removeItem(curve)
- elif ddict['event'] == "renameCurve":
+ elif ddict["event"] == "renameCurve":
curveList = self.plot.getAllCurves(just_legend=True)
- oldLegend = ddict['legend']
+ oldLegend = ddict["legend"]
dialog = RenameCurveDialog(self.plot, oldLegend, curveList)
ret = dialog.exec()
if ret:
newLegend = dialog.getText()
- self.renameCurve(oldLegend, newLegend)
-
- elif ddict['event'] == "setActiveCurve":
- self.plot.setActiveCurve(ddict['legend'])
-
- elif ddict['event'] == "checkBoxClicked":
- self.plot.hideCurve(ddict['legend'], not ddict['selected'])
-
- elif ddict['event'] in ["mapToRight", "mapToLeft"]:
- legend = ddict['legend']
- curve = self.plot.getCurve(legend)
- yaxis = 'right' if ddict['event'] == 'mapToRight' else 'left'
- self.plot.addCurve(x=curve.getXData(copy=False),
- y=curve.getYData(copy=False),
- legend=curve.getName(),
- info=curve.getInfo(),
- yaxis=yaxis)
-
- elif ddict['event'] == "togglePoints":
- legend = ddict['legend']
- curve = self.plot.getCurve(legend)
- symbol = ddict['symbol'] if ddict['points'] else ''
- self.plot.addCurve(x=curve.getXData(copy=False),
- y=curve.getYData(copy=False),
- legend=curve.getName(),
- info=curve.getInfo(),
- symbol=symbol)
-
- elif ddict['event'] == "toggleLine":
- legend = ddict['legend']
- curve = self.plot.getCurve(legend)
- linestyle = ddict['linestyle'] if ddict['line'] else ''
- self.plot.addCurve(x=curve.getXData(copy=False),
- y=curve.getYData(copy=False),
- legend=curve.getName(),
- info=curve.getInfo(),
- linestyle=linestyle)
+ wasActive = self.plot.getActiveCurve() is curve
+ self.plot.removeItem(curve)
+ curve.setName(newLegend)
+ self.plot.addItem(curve)
+ if wasActive:
+ self.plot.setActiveCurve(curve)
+
+ elif ddict["event"] == "setActiveCurve":
+ self.plot.setActiveCurve(curve)
+
+ elif ddict["event"] == "checkBoxClicked":
+ curve.setVisible(ddict["selected"])
+
+ elif ddict["event"] in ["mapToRight", "mapToLeft"]:
+ curve.setYAxis("right" if ddict["event"] == "mapToRight" else "left")
+
+ elif ddict["event"] == "togglePoints":
+ curve.setSymbol(ddict["symbol"] if ddict["points"] else "")
+
+ elif ddict["event"] == "toggleLine":
+ curve.setLineStyle(ddict["linestyle"] if ddict["line"] else "")
else:
- _logger.debug("unhandled event %s", str(ddict['event']))
+ _logger.debug("unhandled event %s", str(ddict["event"]))
def updateLegends(self, *args):
- """Sync the LegendSelector widget displayed info with the plot.
- """
+ """Sync the LegendSelector widget displayed info with the plot."""
legendList = []
for curve in self.plot.getAllCurves(withhidden=True):
legend = curve.getName()
@@ -1004,15 +1013,17 @@ class LegendsDockWidget(qt.QDockWidget):
color = style.getColor()
if numpy.array(color, copy=False).ndim != 1:
# array of colors, use transparent black
- color = 0., 0., 0., 0.
+ color = 0.0, 0.0, 0.0, 0.0
curveInfo = {
- 'color': qt.QColor.fromRgbF(*color),
- 'linewidth': style.getLineWidth(),
- 'linestyle': style.getLineStyle(),
- 'symbol': style.getSymbol(),
- 'selected': not self.plot.isCurveHidden(legend),
- 'active': isActive}
+ "color": qt.QColor.fromRgbF(*color),
+ "linewidth": style.getLineWidth(),
+ "linestyle": style.getLineStyle(),
+ "symbol": style.getSymbol(),
+ "selected": not self.plot.isCurveHidden(legend),
+ "active": isActive,
+ "item": curve,
+ }
legendList.append((legend, curveInfo))
self._legendWidget.setLegendList(legendList)
diff --git a/src/silx/gui/plot/LimitsHistory.py b/src/silx/gui/plot/LimitsHistory.py
index 7215e37..f4e0afc 100644
--- a/src/silx/gui/plot/LimitsHistory.py
+++ b/src/silx/gui/plot/LimitsHistory.py
@@ -55,8 +55,8 @@ class LimitsHistory(qt.QObject):
"""Append current limits to the history."""
plot = self.parent()
xmin, xmax = plot.getXAxis().getLimits()
- ymin, ymax = plot.getYAxis(axis='left').getLimits()
- y2min, y2max = plot.getYAxis(axis='right').getLimits()
+ ymin, ymax = plot.getYAxis(axis="left").getLimits()
+ y2min, y2max = plot.getYAxis(axis="right").getLimits()
self._history.append((xmin, xmax, ymin, ymax, y2min, y2max))
def pop(self):
diff --git a/src/silx/gui/plot/MaskToolsWidget.py b/src/silx/gui/plot/MaskToolsWidget.py
index 327cdd6..40b2717 100644
--- a/src/silx/gui/plot/MaskToolsWidget.py
+++ b/src/silx/gui/plot/MaskToolsWidget.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2017-2022 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -38,9 +38,12 @@ import os
import sys
import numpy
import logging
-import collections
import h5py
+import fabio
+from fabio.edfimage import EdfImage
+from fabio.TiffIO import TiffIO
+
from silx.image import shapes
from silx.io.utils import NEXUS_HDF5_EXT, is_dataset
from silx.gui.dialog.DatasetDialog import DatasetDialog
@@ -51,14 +54,10 @@ from ..colors import cursorColorForColormap, rgba
from .. import qt
from ..utils import LockReentrant
-from silx.third_party.EdfFile import EdfFile
-from silx.third_party.TiffIO import TiffIO
-
-import fabio
_logger = logging.getLogger(__name__)
-_HDF5_EXT_STR = ' '.join(['*' + ext for ext in NEXUS_HDF5_EXT])
+_HDF5_EXT_STR = " ".join(["*" + ext for ext in NEXUS_HDF5_EXT])
def _selectDataset(filename, mode=DatasetDialog.SaveMode):
@@ -110,16 +109,17 @@ class ImageMask(BaseMask):
or 'msk' (if FabIO is installed)
:raise Exception: Raised if the file writing fail
"""
- if kind == 'edf':
- edfFile = EdfFile(filename, access="w+")
- header = {"program_name": "silx-mask", "masked_value": "nonzero"}
- edfFile.WriteImage(header, self.getMask(copy=False), Append=0)
+ if kind == "edf":
+ EdfImage(
+ data=self.getMask(),
+ header={"program_name": "silx-mask", "masked_value": "nonzero"},
+ ).write(filename)
- elif kind == 'tif':
- tiffFile = TiffIO(filename, mode='w')
- tiffFile.writeImage(self.getMask(copy=False), software='silx')
+ elif kind == "tif":
+ tiffFile = TiffIO(filename, mode="w")
+ tiffFile.writeImage(self.getMask(copy=False), software="silx")
- elif kind == 'npy':
+ elif kind == "npy":
try:
numpy.save(filename, self.getMask(copy=False))
except IOError:
@@ -128,7 +128,7 @@ class ImageMask(BaseMask):
elif ("." + kind) in NEXUS_HDF5_EXT:
self._saveToHdf5(filename, self.getMask(copy=False))
- elif kind == 'msk':
+ elif kind == "msk":
try:
data = self.getMask(copy=False)
image = fabio.fabioimage.FabioImage(data=data)
@@ -159,10 +159,11 @@ class ImageMask(BaseMask):
existing_ds = h5f.get(dataPath)
if existing_ds is not None:
reply = qt.QMessageBox.question(
- None,
- "Confirm overwrite",
- "Do you want to overwrite an existing dataset?",
- qt.QMessageBox.Yes | qt.QMessageBox.No)
+ None,
+ "Confirm overwrite",
+ "Do you want to overwrite an existing dataset?",
+ qt.QMessageBox.Yes | qt.QMessageBox.No,
+ )
if reply != qt.QMessageBox.Yes:
return False
del h5f[dataPath]
@@ -186,10 +187,11 @@ class ImageMask(BaseMask):
assert 0 < level < 256
if row + height <= 0 or col + width <= 0:
return # Rectangle outside image, avoid negative indices
- selection = self._mask[max(0, row):row + height + 1,
- max(0, col):col + width + 1]
+ selection = self._mask[
+ max(0, row) : row + height + 1, max(0, col) : col + width + 1
+ ]
if mask:
- selection[:,:] = level
+ selection[:, :] = level
else:
selection[selection == level] = 0
self._notify()
@@ -205,8 +207,7 @@ class ImageMask(BaseMask):
if mask:
self._mask[fill != 0] = level
else:
- self._mask[numpy.logical_and(fill != 0,
- self._mask == level)] = 0
+ self._mask[numpy.logical_and(fill != 0, self._mask == level)] = 0
self._notify()
def updatePoints(self, level, rows, cols, mask=True):
@@ -221,8 +222,8 @@ class ImageMask(BaseMask):
"""
valid = numpy.logical_and(
numpy.logical_and(rows >= 0, cols >= 0),
- numpy.logical_and(rows < self._mask.shape[0],
- cols < self._mask.shape[1]))
+ numpy.logical_and(rows < self._mask.shape[0], cols < self._mask.shape[1]),
+ )
rows, cols = rows[valid], cols[valid]
if mask:
@@ -278,10 +279,9 @@ class MaskToolsWidget(BaseMaskToolsWidget):
_maxLevelNumber = 255
def __init__(self, parent=None, plot=None):
- super(MaskToolsWidget, self).__init__(parent, plot,
- mask=ImageMask())
- self._origin = (0., 0.) # Mask origin in plot
- self._scale = (1., 1.) # Mask scale in plot
+ super(MaskToolsWidget, self).__init__(parent, plot, mask=ImageMask())
+ self._origin = (0.0, 0.0) # Mask origin in plot
+ self._scale = (1.0, 1.0) # Mask scale in plot
self._z = 1 # Mask layer in plot
self._data = numpy.zeros((0, 0), dtype=numpy.uint8) # Store image
@@ -336,11 +336,11 @@ class MaskToolsWidget(BaseMaskToolsWidget):
mask = numpy.array(mask, copy=False, dtype=numpy.uint8)
if len(mask.shape) != 2:
- _logger.error('Not an image, shape: %d', len(mask.shape))
+ _logger.error("Not an image, shape: %d", len(mask.shape))
return None
# Handle mask with single level
- if self.multipleMasks() == 'single':
+ if self.multipleMasks() == "single":
mask = numpy.array(mask != 0, dtype=numpy.uint8)
# if mask has not changed, do nothing
@@ -352,15 +352,17 @@ class MaskToolsWidget(BaseMaskToolsWidget):
self._mask.commit()
return mask.shape
else:
- _logger.warning('Mask has not the same size as current image.'
- ' Mask will be cropped or padded to fit image'
- ' dimensions. %s != %s',
- str(mask.shape), str(self._data.shape))
- resizedMask = numpy.zeros(self._data.shape[0:2],
- dtype=numpy.uint8)
+ _logger.warning(
+ "Mask has not the same size as current image."
+ " Mask will be cropped or padded to fit image"
+ " dimensions. %s != %s",
+ str(mask.shape),
+ str(self._data.shape),
+ )
+ resizedMask = numpy.zeros(self._data.shape[0:2], dtype=numpy.uint8)
height = min(self._data.shape[0], mask.shape[0])
width = min(self._data.shape[1], mask.shape[1])
- resizedMask[:height,:width] = mask[:height,:width]
+ resizedMask[:height, :width] = mask[:height, :width]
self._mask.setMask(resizedMask, copy=False)
self._mask.commit()
return resizedMask.shape
@@ -387,12 +389,13 @@ class MaskToolsWidget(BaseMaskToolsWidget):
self.plot.addItem(maskItem)
elif self.plot.getImage(self._maskName):
- self.plot.remove(self._maskName, kind='image')
+ self.plot.remove(self._maskName, kind="image")
def showEvent(self, event):
try:
self.plot.sigActiveImageChanged.disconnect(
- self._activeImageChangedAfterCare)
+ self._activeImageChangedAfterCare
+ )
except (RuntimeError, TypeError):
pass
@@ -402,8 +405,7 @@ class MaskToolsWidget(BaseMaskToolsWidget):
def hideEvent(self, event):
try:
- self.plot.sigActiveImageChanged.disconnect(
- self._activeImageChanged)
+ self.plot.sigActiveImageChanged.disconnect(self._activeImageChanged)
except (RuntimeError, TypeError):
pass
@@ -424,11 +426,10 @@ class MaskToolsWidget(BaseMaskToolsWidget):
self._mask.reset()
if self.plot.getImage(self._maskName):
- self.plot.remove(self._maskName, kind='image')
+ self.plot.remove(self._maskName, kind="image")
elif self.getSelectionMask(copy=False) is not None:
- self.plot.sigActiveImageChanged.connect(
- self._activeImageChangedAfterCare)
+ self.plot.sigActiveImageChanged.connect(self._activeImageChangedAfterCare)
def _activeImageChanged(self, previous, current):
"""Reacts upon active image change.
@@ -448,10 +449,9 @@ class MaskToolsWidget(BaseMaskToolsWidget):
"""
if isinstance(image, items.ColormapMixIn):
colormap = image.getColormap()
- self._defaultOverlayColor = rgba(
- cursorColorForColormap(colormap['name']))
+ self._defaultOverlayColor = rgba(cursorColorForColormap(colormap["name"]))
else:
- self._defaultOverlayColor = rgba('black')
+ self._defaultOverlayColor = rgba("black")
def _activeImageChangedAfterCare(self, *args):
"""Check synchro of active image and mask when mask widget is hidden.
@@ -467,15 +467,17 @@ class MaskToolsWidget(BaseMaskToolsWidget):
self._mask.reset()
if self.plot.getImage(self._maskName):
- self.plot.remove(self._maskName, kind='image')
+ self.plot.remove(self._maskName, kind="image")
self.plot.sigActiveImageChanged.disconnect(
- self._activeImageChangedAfterCare)
+ self._activeImageChangedAfterCare
+ )
else:
self._setOverlayColorForImage(activeImage)
- self._setMaskColors(self.levelSpinBox.value(),
- self.transparencySlider.value() /
- self.transparencySlider.maximum())
+ self._setMaskColors(
+ self.levelSpinBox.value(),
+ self.transparencySlider.value() / self.transparencySlider.maximum(),
+ )
self._origin = activeImage.getOrigin()
self._scale = activeImage.getScale()
@@ -484,10 +486,11 @@ class MaskToolsWidget(BaseMaskToolsWidget):
if self._data.shape[:2] != self._mask.getMask(copy=False).shape:
# Image has not the same size, remove mask and stop listening
if self.plot.getImage(self._maskName):
- self.plot.remove(self._maskName, kind='image')
+ self.plot.remove(self._maskName, kind="image")
self.plot.sigActiveImageChanged.disconnect(
- self._activeImageChangedAfterCare)
+ self._activeImageChangedAfterCare
+ )
else:
# Refresh in case origin, scale, z changed
self._mask.setDataItem(activeImage)
@@ -519,11 +522,9 @@ class MaskToolsWidget(BaseMaskToolsWidget):
if self.isItemMaskUpdated():
if image.getMaskData(copy=False) is None:
# Image item has no mask: use current mask from the tool
- image.setMaskData(
- self.getSelectionMask(copy=False), copy=True)
+ image.setMaskData(self.getSelectionMask(copy=False), copy=True)
else: # Image item has a mask: set it in tool
- self.setSelectionMask(
- image.getMaskData(copy=False), copy=True)
+ self.setSelectionMask(image.getMaskData(copy=False), copy=True)
self._mask.resetHistory()
self.__imageUpdated()
if self.isVisible():
@@ -536,17 +537,21 @@ class MaskToolsWidget(BaseMaskToolsWidget):
_logger.error("Mask is not attached to an image")
return
- if event in (items.ItemChangedType.COLORMAP,
- items.ItemChangedType.DATA,
- items.ItemChangedType.POSITION,
- items.ItemChangedType.SCALE,
- items.ItemChangedType.VISIBLE,
- items.ItemChangedType.ZVALUE):
+ if event in (
+ items.ItemChangedType.COLORMAP,
+ items.ItemChangedType.DATA,
+ items.ItemChangedType.POSITION,
+ items.ItemChangedType.SCALE,
+ items.ItemChangedType.VISIBLE,
+ items.ItemChangedType.ZVALUE,
+ ):
self.__imageUpdated()
- elif (event == items.ItemChangedType.MASK and
- self.isItemMaskUpdated() and
- not self.__itemMaskUpdatedLock.locked()):
+ elif (
+ event == items.ItemChangedType.MASK
+ and self.isItemMaskUpdated()
+ and not self.__itemMaskUpdatedLock.locked()
+ ):
# Update mask from the image item unless mask tool is updating it
self.setSelectionMask(image.getMaskData(copy=False), copy=True)
@@ -559,9 +564,10 @@ class MaskToolsWidget(BaseMaskToolsWidget):
self._setOverlayColorForImage(image)
- self._setMaskColors(self.levelSpinBox.value(),
- self.transparencySlider.value() /
- self.transparencySlider.maximum())
+ self._setMaskColors(
+ self.levelSpinBox.value(),
+ self.transparencySlider.value() / self.transparencySlider.maximum(),
+ )
self._origin = image.getOrigin()
self._scale = image.getScale()
@@ -602,26 +608,11 @@ class MaskToolsWidget(BaseMaskToolsWidget):
_logger.error("Can't load filename '%s'", filename)
_logger.debug("Backtrace", exc_info=True)
raise RuntimeError('File "%s" is not a numpy file.', filename)
- elif extension in ["tif", "tiff"]:
- try:
- image = TiffIO(filename, mode="r")
- mask = image.getImage(0)
- except Exception as e:
- _logger.error("Can't load filename %s", filename)
- _logger.debug("Backtrace", exc_info=True)
- raise e
- elif extension == "edf":
- try:
- mask = EdfFile(filename, access='r').GetData(0)
- except Exception as e:
- _logger.error("Can't load filename %s", filename)
- _logger.debug("Backtrace", exc_info=True)
- raise e
- elif extension == "msk":
+ elif extension in ("edf", "msk", "tif", "tiff"):
try:
mask = fabio.open(filename).data
except Exception as e:
- _logger.error("Can't load fit2d mask file")
+ _logger.error(f"Can't load filename {filename}")
_logger.debug("Backtrace", exc_info=True)
raise e
elif ("." + extension) in NEXUS_HDF5_EXT:
@@ -636,7 +627,7 @@ class MaskToolsWidget(BaseMaskToolsWidget):
if effectiveMaskShape is None:
return
if mask.shape != effectiveMaskShape:
- msg = 'Mask was resized from %s to %s'
+ msg = "Mask was resized from %s to %s"
msg = msg % (str(mask.shape), str(effectiveMaskShape))
raise RuntimeWarning(msg)
@@ -646,7 +637,7 @@ class MaskToolsWidget(BaseMaskToolsWidget):
dialog.setWindowTitle("Load Mask")
dialog.setModal(1)
- extensions = collections.OrderedDict()
+ extensions = {}
extensions["EDF files"] = "*.edf"
extensions["TIFF files"] = "*.tif *.tiff"
extensions["NumPy binary files"] = "*.npy"
@@ -714,15 +705,15 @@ class MaskToolsWidget(BaseMaskToolsWidget):
dialog.setWindowTitle("Save Mask")
dialog.setOption(qt.QFileDialog.DontUseNativeDialog)
dialog.setModal(1)
- hdf5Filter = 'HDF5 (%s)' % _HDF5_EXT_STR
+ hdf5Filter = "HDF5 (%s)" % _HDF5_EXT_STR
filters = [
- 'EDF (*.edf)',
- 'TIFF (*.tif)',
- 'NumPy binary file (*.npy)',
+ "EDF (*.edf)",
+ "TIFF (*.tif)",
+ "NumPy binary file (*.npy)",
hdf5Filter,
# Fit2D mask is displayed anyway fabio is here or not
# to show to the user that the option exists
- 'Fit2D mask (*.msk)',
+ "Fit2D mask (*.msk)",
]
dialog.setNameFilters(filters)
dialog.setFileMode(qt.QFileDialog.AnyFile)
@@ -749,8 +740,10 @@ class MaskToolsWidget(BaseMaskToolsWidget):
if "HDF5" in nameFilter:
has_allowed_ext = False
for ext in NEXUS_HDF5_EXT:
- if (len(filename) > len(ext) and
- filename[-len(ext):].lower() == ext.lower()):
+ if (
+ len(filename) > len(ext)
+ and filename[-len(ext) :].lower() == ext.lower()
+ ):
has_allowed_ext = True
extension = ext
if not has_allowed_ext:
@@ -774,8 +767,7 @@ class MaskToolsWidget(BaseMaskToolsWidget):
strerror = e.strerror
else:
strerror = sys.exc_info()[1]
- msg.setText("Cannot save.\n"
- "Input Output Error: %s" % strerror)
+ msg.setText("Cannot save.\n" "Input Output Error: %s" % strerror)
msg.exec()
return
@@ -803,8 +795,10 @@ class MaskToolsWidget(BaseMaskToolsWidget):
def _plotDrawEvent(self, event):
"""Handle draw events from the plot"""
- if (self._drawingMode is None or
- event['event'] not in ('drawingProgress', 'drawingFinished')):
+ if self._drawingMode is None or event["event"] not in (
+ "drawingProgress",
+ "drawingFinished",
+ ):
return
if not len(self._data):
@@ -812,56 +806,54 @@ class MaskToolsWidget(BaseMaskToolsWidget):
level = self.levelSpinBox.value()
- if self._drawingMode == 'rectangle':
- if event['event'] == 'drawingFinished':
+ if self._drawingMode == "rectangle":
+ if event["event"] == "drawingFinished":
# Convert from plot to array coords
doMask = self._isMasking()
ox, oy = self._origin
sx, sy = self._scale
- height = int(abs(event['height'] / sy))
- width = int(abs(event['width'] / sx))
+ height = int(abs(event["height"] / sy))
+ width = int(abs(event["width"] / sx))
- row = int((event['y'] - oy) / sy)
+ row = int((event["y"] - oy) / sy)
if sy < 0:
row -= height
- col = int((event['x'] - ox) / sx)
+ col = int((event["x"] - ox) / sx)
if sx < 0:
col -= width
self._mask.updateRectangle(
- level,
- row=row,
- col=col,
- height=height,
- width=width,
- mask=doMask)
+ level, row=row, col=col, height=height, width=width, mask=doMask
+ )
self._mask.commit()
- elif self._drawingMode == 'ellipse':
- if event['event'] == 'drawingFinished':
+ elif self._drawingMode == "ellipse":
+ if event["event"] == "drawingFinished":
doMask = self._isMasking()
# Convert from plot to array coords
- center = (event['points'][0] - self._origin) / self._scale
- size = event['points'][1] / self._scale
+ center = (event["points"][0] - self._origin) / self._scale
+ size = event["points"][1] / self._scale
center = center.astype(numpy.int64) # (row, col)
- self._mask.updateEllipse(level, center[1], center[0], size[1], size[0], doMask)
+ self._mask.updateEllipse(
+ level, center[1], center[0], size[1], size[0], doMask
+ )
self._mask.commit()
- elif self._drawingMode == 'polygon':
- if event['event'] == 'drawingFinished':
+ elif self._drawingMode == "polygon":
+ if event["event"] == "drawingFinished":
doMask = self._isMasking()
# Convert from plot to array coords
- vertices = (event['points'] - self._origin) / self._scale
+ vertices = (event["points"] - self._origin) / self._scale
vertices = vertices.astype(numpy.int64)[:, (1, 0)] # (row, col)
self._mask.updatePolygon(level, vertices, doMask)
self._mask.commit()
- elif self._drawingMode == 'pencil':
+ elif self._drawingMode == "pencil":
doMask = self._isMasking()
# convert from plot to array coords
- col, row = (event['points'][-1] - self._origin) / self._scale
+ col, row = (event["points"][-1] - self._origin) / self._scale
col, row = int(col), int(row)
brushSize = self._getPencilWidth()
@@ -870,15 +862,18 @@ class MaskToolsWidget(BaseMaskToolsWidget):
# Draw the line
self._mask.updateLine(
level,
- self._lastPencilPos[0], self._lastPencilPos[1],
- row, col,
+ self._lastPencilPos[0],
+ self._lastPencilPos[1],
+ row,
+ col,
brushSize,
- doMask)
+ doMask,
+ )
# Draw the very first, or last point
- self._mask.updateDisk(level, row, col, brushSize / 2., doMask)
+ self._mask.updateDisk(level, row, col, brushSize / 2.0, doMask)
- if event['event'] == 'drawingFinished':
+ if event["event"] == "drawingFinished":
self._mask.commit()
self._lastPencilPos = None
else:
@@ -889,15 +884,17 @@ class MaskToolsWidget(BaseMaskToolsWidget):
def _loadRangeFromColormapTriggered(self):
"""Set range from active image colormap range"""
activeImage = self.plot.getActiveImage()
- if (isinstance(activeImage, items.ColormapMixIn) and
- activeImage.getName() != self._maskName):
+ if (
+ isinstance(activeImage, items.ColormapMixIn)
+ and activeImage.getName() != self._maskName
+ ):
# Update thresholds according to colormap
colormap = activeImage.getColormap()
- if colormap['autoscale']:
+ if colormap["autoscale"]:
min_ = numpy.nanmin(activeImage.getData(copy=False))
max_ = numpy.nanmax(activeImage.getData(copy=False))
else:
- min_, max_ = colormap['vmin'], colormap['vmax']
+ min_, max_ = colormap["vmin"], colormap["vmax"]
self.minLineEdit.setText(str(min_))
self.maxLineEdit.setText(str(max_))
@@ -912,6 +909,6 @@ class MaskToolsDockWidget(BaseMaskToolsDockWidget):
:paran str name: The title of this widget
"""
- def __init__(self, parent=None, plot=None, name='Mask'):
+ def __init__(self, parent=None, plot=None, name="Mask"):
widget = MaskToolsWidget(plot=plot)
super(MaskToolsDockWidget, self).__init__(parent, name, widget)
diff --git a/src/silx/gui/plot/PlotActions.py b/src/silx/gui/plot/PlotActions.py
deleted file mode 100644
index f32be3c..0000000
--- a/src/silx/gui/plot/PlotActions.py
+++ /dev/null
@@ -1,66 +0,0 @@
-# /*##########################################################################
-#
-# Copyright (c) 2004-2017 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Depracted module linking old PlotAction with the actions.xxx"""
-
-
-__author__ = ["V.A. Sole", "T. Vincent"]
-__license__ = "MIT"
-__date__ = "01/06/2017"
-
-from silx.utils.deprecation import deprecated_warning
-
-deprecated_warning(type_='module',
- name=__file__,
- reason='PlotActions refactoring',
- replacement='plot.actions',
- since_version='0.6')
-
-from .actions import PlotAction
-
-from .actions.io import CopyAction
-from .actions.io import PrintAction
-from .actions.io import SaveAction
-
-from .actions.control import ColormapAction
-from .actions.control import CrosshairAction
-from .actions.control import CurveStyleAction
-from .actions.control import GridAction
-from .actions.control import KeepAspectRatioAction
-from .actions.control import PanWithArrowKeysAction
-from .actions.control import ResetZoomAction
-from .actions.control import XAxisAutoScaleAction
-from .actions.control import XAxisLogarithmicAction
-from .actions.control import YAxisAutoScaleAction
-from .actions.control import YAxisLogarithmicAction
-from .actions.control import YAxisInvertedAction
-from .actions.control import ZoomInAction
-from .actions.control import ZoomOutAction
-
-from .actions.medfilt import MedianFilter1DAction
-from .actions.medfilt import MedianFilter2DAction
-from .actions.medfilt import MedianFilterAction
-
-from .actions.histogram import PixelIntensitiesHistoAction
-
-from .actions.fit import FitAction
diff --git a/src/silx/gui/plot/PlotEvents.py b/src/silx/gui/plot/PlotEvents.py
index be875d7..b4cbe30 100644
--- a/src/silx/gui/plot/PlotEvents.py
+++ b/src/silx/gui/plot/PlotEvents.py
@@ -33,60 +33,71 @@ import numpy as np
def prepareDrawingSignal(event, type_, points, parameters=None):
"""See Plot documentation for content of events"""
- assert event in ('drawingProgress', 'drawingFinished')
+ assert event in ("drawingProgress", "drawingFinished")
if parameters is None:
parameters = {}
eventDict = {}
- eventDict['event'] = event
- eventDict['type'] = type_
+ eventDict["event"] = event
+ eventDict["type"] = type_
points = np.array(points, dtype=np.float32)
points.shape = -1, 2
- eventDict['points'] = points
- eventDict['xdata'] = points[:, 0]
- eventDict['ydata'] = points[:, 1]
- if type_ in ('rectangle',):
- eventDict['x'] = eventDict['xdata'].min()
- eventDict['y'] = eventDict['ydata'].min()
- eventDict['width'] = eventDict['xdata'].max() - eventDict['x']
- eventDict['height'] = eventDict['ydata'].max() - eventDict['y']
- eventDict['parameters'] = parameters.copy()
+ eventDict["points"] = points
+ eventDict["xdata"] = points[:, 0]
+ eventDict["ydata"] = points[:, 1]
+ if type_ in ("rectangle",):
+ eventDict["x"] = eventDict["xdata"].min()
+ eventDict["y"] = eventDict["ydata"].min()
+ eventDict["width"] = eventDict["xdata"].max() - eventDict["x"]
+ eventDict["height"] = eventDict["ydata"].max() - eventDict["y"]
+ eventDict["parameters"] = parameters.copy()
return eventDict
def prepareMouseSignal(eventType, button, xData, yData, xPixel, yPixel):
"""See Plot documentation for content of events"""
- assert eventType in ('mouseMoved', 'mouseClicked', 'mouseDoubleClicked')
- assert button in (None, 'left', 'middle', 'right')
+ assert eventType in ("mouseMoved", "mouseClicked", "mouseDoubleClicked")
+ assert button in (None, "left", "middle", "right")
- return {'event': eventType,
- 'x': xData,
- 'y': yData,
- 'xpixel': xPixel,
- 'ypixel': yPixel,
- 'button': button}
+ return {
+ "event": eventType,
+ "x": xData,
+ "y": yData,
+ "xpixel": xPixel,
+ "ypixel": yPixel,
+ "button": button,
+ }
def prepareHoverSignal(label, type_, posData, posPixel, draggable, selectable):
"""See Plot documentation for content of events"""
- return {'event': 'hover',
- 'label': label,
- 'type': type_,
- 'x': posData[0],
- 'y': posData[1],
- 'xpixel': posPixel[0],
- 'ypixel': posPixel[1],
- 'draggable': draggable,
- 'selectable': selectable}
-
-
-def prepareMarkerSignal(eventType, button, label, type_,
- draggable, selectable,
- posDataMarker,
- posPixelCursor=None, posDataCursor=None):
+ return {
+ "event": "hover",
+ "label": label,
+ "type": type_,
+ "x": posData[0],
+ "y": posData[1],
+ "xpixel": posPixel[0],
+ "ypixel": posPixel[1],
+ "draggable": draggable,
+ "selectable": selectable,
+ }
+
+
+def prepareMarkerSignal(
+ eventType,
+ button,
+ label,
+ type_,
+ draggable,
+ selectable,
+ posDataMarker,
+ posPixelCursor=None,
+ posDataCursor=None,
+):
"""See Plot documentation for content of events"""
- if eventType == 'markerClicked':
+ if eventType == "markerClicked":
assert posPixelCursor is not None
assert posDataCursor is None
@@ -96,11 +107,11 @@ def prepareMarkerSignal(eventType, button, label, type_,
if hasattr(posDataCursor[1], "__len__"):
posDataCursor[1] = posDataCursor[1][-1]
- elif eventType == 'markerMoving':
+ elif eventType == "markerMoving":
assert posPixelCursor is not None
assert posDataCursor is not None
- elif eventType == 'markerMoved':
+ elif eventType == "markerMoved":
assert posPixelCursor is None
assert posDataCursor is None
@@ -108,58 +119,66 @@ def prepareMarkerSignal(eventType, button, label, type_,
else:
raise NotImplementedError("Unknown event type {0}".format(eventType))
- eventDict = {'event': eventType,
- 'button': button,
- 'label': label,
- 'type': type_,
- 'x': posDataCursor[0],
- 'y': posDataCursor[1],
- 'xdata': posDataMarker[0],
- 'ydata': posDataMarker[1],
- 'draggable': draggable,
- 'selectable': selectable}
-
- if eventType in ('markerMoving', 'markerClicked'):
- eventDict['xpixel'] = posPixelCursor[0]
- eventDict['ypixel'] = posPixelCursor[1]
+ eventDict = {
+ "event": eventType,
+ "button": button,
+ "label": label,
+ "type": type_,
+ "x": posDataCursor[0],
+ "y": posDataCursor[1],
+ "xdata": posDataMarker[0],
+ "ydata": posDataMarker[1],
+ "draggable": draggable,
+ "selectable": selectable,
+ }
+
+ if eventType in ("markerMoving", "markerClicked"):
+ eventDict["xpixel"] = posPixelCursor[0]
+ eventDict["ypixel"] = posPixelCursor[1]
return eventDict
-def prepareImageSignal(button, label, type_, col, row,
- x, y, xPixel, yPixel):
+def prepareImageSignal(button, item, col, row, x, y, xPixel, yPixel):
"""See Plot documentation for content of events"""
- return {'event': 'imageClicked',
- 'button': button,
- 'label': label,
- 'type': type_,
- 'col': col,
- 'row': row,
- 'x': x,
- 'y': y,
- 'xpixel': xPixel,
- 'ypixel': yPixel}
-
-
-def prepareCurveSignal(button, label, type_, xData, yData,
- x, y, xPixel, yPixel):
+ return {
+ "event": "imageClicked",
+ "button": button,
+ "item": item,
+ "label": item.getName(),
+ "type": "image",
+ "col": col,
+ "row": row,
+ "x": x,
+ "y": y,
+ "xpixel": xPixel,
+ "ypixel": yPixel,
+ }
+
+
+def prepareCurveSignal(button, item, xData, yData, x, y, xPixel, yPixel):
"""See Plot documentation for content of events"""
- return {'event': 'curveClicked',
- 'button': button,
- 'label': label,
- 'type': type_,
- 'xdata': xData,
- 'ydata': yData,
- 'x': x,
- 'y': y,
- 'xpixel': xPixel,
- 'ypixel': yPixel}
+ return {
+ "event": "curveClicked",
+ "button": button,
+ "item": item,
+ "label": item.getName(),
+ "type": "curve",
+ "xdata": xData,
+ "ydata": yData,
+ "x": x,
+ "y": y,
+ "xpixel": xPixel,
+ "ypixel": yPixel,
+ }
def prepareLimitsChangedSignal(sourceObj, xRange, yRange, y2Range):
"""See Plot documentation for content of events"""
- return {'event': 'limitsChanged',
- 'source': id(sourceObj),
- 'xdata': xRange,
- 'ydata': yRange,
- 'y2data': y2Range}
+ return {
+ "event": "limitsChanged",
+ "source": id(sourceObj),
+ "xdata": xRange,
+ "ydata": yRange,
+ "y2data": y2Range,
+ }
diff --git a/src/silx/gui/plot/PlotInteraction.py b/src/silx/gui/plot/PlotInteraction.py
index c4d64a5..d19bb6d 100644
--- a/src/silx/gui/plot/PlotInteraction.py
+++ b/src/silx/gui/plot/PlotInteraction.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2014-2021 European Synchrotron Radiation Facility
+# Copyright (c) 2014-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -23,6 +23,8 @@
# ###########################################################################*/
"""Implementation of the interaction for the :class:`Plot`."""
+from __future__ import annotations
+
__authors__ = ["T. Vincent"]
__license__ = "MIT"
__date__ = "15/02/2019"
@@ -32,30 +34,53 @@ import math
import numpy
import time
import weakref
+from typing import NamedTuple, Optional
+from silx.gui import qt
from .. import colors
-from .. import qt
from . import items
-from .Interaction import (ClickOrDrag, LEFT_BTN, RIGHT_BTN, MIDDLE_BTN,
- State, StateMachine)
-from .PlotEvents import (prepareCurveSignal, prepareDrawingSignal,
- prepareHoverSignal, prepareImageSignal,
- prepareMarkerSignal, prepareMouseSignal)
-
-from .backends.BackendBase import (CURSOR_POINTING, CURSOR_SIZE_HOR,
- CURSOR_SIZE_VER, CURSOR_SIZE_ALL)
-
-from ._utils import (FLOAT32_SAFE_MIN, FLOAT32_MINPOS, FLOAT32_SAFE_MAX,
- applyZoomToPlot)
+from .Interaction import (
+ ClickOrDrag,
+ LEFT_BTN,
+ RIGHT_BTN,
+ MIDDLE_BTN,
+ State,
+ StateMachine,
+)
+from .PlotEvents import (
+ prepareCurveSignal,
+ prepareDrawingSignal,
+ prepareHoverSignal,
+ prepareImageSignal,
+ prepareMarkerSignal,
+ prepareMouseSignal,
+)
+
+from .backends.BackendBase import (
+ CURSOR_POINTING,
+ CURSOR_SIZE_HOR,
+ CURSOR_SIZE_VER,
+ CURSOR_SIZE_ALL,
+)
+
+from ._utils import (
+ FLOAT32_SAFE_MIN,
+ FLOAT32_MINPOS,
+ FLOAT32_SAFE_MAX,
+ applyZoomToPlot,
+ EnabledAxes,
+)
# Base class ##################################################################
+
class _PlotInteraction(object):
"""Base class for interaction handler.
It provides a weakref to the plot and methods to set/reset overlay.
"""
+
def __init__(self, plot):
"""Init.
@@ -71,7 +96,7 @@ class _PlotInteraction(object):
assert plot is not None
return plot
- def setSelectionArea(self, points, fill, color, name='', shape='polygon'):
+ def setSelectionArea(self, points, fill, color, name="", shape="polygon"):
"""Set a polygon selection area overlaid on the plot.
Multiple simultaneous areas are supported through the name parameter.
@@ -83,7 +108,7 @@ class _PlotInteraction(object):
:param name: The key associated with this selection area
:param str shape: Shape of the area in 'polygon', 'polylines'
"""
- assert shape in ('polygon', 'polylines')
+ assert shape in ("polygon", "polylines")
if color is None:
return
@@ -91,9 +116,9 @@ class _PlotInteraction(object):
points = numpy.asarray(points)
# TODO Not very nice, but as is for now
- legend = '__SELECTION_AREA__' + name
+ legend = "__SELECTION_AREA__" + name
- fill = fill != 'none' # TODO not very nice either
+ fill = fill != "none" # TODO not very nice either
greyed = colors.greyed(color)[0]
if greyed < 0.5:
@@ -101,36 +126,39 @@ class _PlotInteraction(object):
else:
color2 = "black"
- self.plot.addShape(points[:, 0], points[:, 1], legend=legend,
- replace=False,
- shape=shape, fill=fill,
- color=color, linebgcolor=color2, linestyle="--",
- overlay=True)
+ self.plot.addShape(
+ points[:, 0],
+ points[:, 1],
+ legend=legend,
+ replace=False,
+ shape=shape,
+ fill=fill,
+ color=color,
+ gapcolor=color2,
+ linestyle="--",
+ overlay=True,
+ )
self._selectionAreas.add(legend)
def resetSelectionArea(self):
"""Remove all selection areas set by setSelectionArea."""
for legend in self._selectionAreas:
- self.plot.remove(legend, kind='item')
+ self.plot.remove(legend, kind="item")
self._selectionAreas = set()
# Zoom/Pan ####################################################################
-class _ZoomOnWheel(ClickOrDrag, _PlotInteraction):
- """:class:`ClickOrDrag` state machine with zooming on mouse wheel.
+
+class _PlotInteractionWithClickEvents(ClickOrDrag, _PlotInteraction):
+ """:class:`ClickOrDrag` state machine emitting click and double click events.
Base class for :class:`Pan` and :class:`Zoom`
"""
_DOUBLE_CLICK_TIMEOUT = 0.4
- class Idle(ClickOrDrag.Idle):
- def onWheel(self, x, y, angle):
- scaleF = 1.1 if angle > 0 else 1. / 1.1
- applyZoomToPlot(self.machine.plot, scaleF, (x, y))
-
def click(self, x, y, btn):
"""Handle clicks by sending events
@@ -144,18 +172,19 @@ class _ZoomOnWheel(ClickOrDrag, _PlotInteraction):
# Signal mouse double clicked event first
if (time.time() - lastClickTime) <= self._DOUBLE_CLICK_TIMEOUT:
# Use position of first click
- eventDict = prepareMouseSignal('mouseDoubleClicked', 'left',
- *lastClickPos)
+ eventDict = prepareMouseSignal(
+ "mouseDoubleClicked", "left", *lastClickPos
+ )
self.plot.notify(**eventDict)
- self._lastClick = 0., None
+ self._lastClick = 0.0, None
else:
# Signal mouse clicked event
dataPos = self.plot.pixelToData(x, y)
assert dataPos is not None
- eventDict = prepareMouseSignal('mouseClicked', 'left',
- dataPos[0], dataPos[1],
- x, y)
+ eventDict = prepareMouseSignal(
+ "mouseClicked", "left", dataPos[0], dataPos[1], x, y
+ )
self.plot.notify(**eventDict)
self._lastClick = time.time(), (dataPos[0], dataPos[1], x, y)
@@ -164,9 +193,9 @@ class _ZoomOnWheel(ClickOrDrag, _PlotInteraction):
# Signal mouse clicked event
dataPos = self.plot.pixelToData(x, y)
assert dataPos is not None
- eventDict = prepareMouseSignal('mouseClicked', 'right',
- dataPos[0], dataPos[1],
- x, y)
+ eventDict = prepareMouseSignal(
+ "mouseClicked", "right", dataPos[0], dataPos[1], x, y
+ )
self.plot.notify(**eventDict)
def __init__(self, plot, **kwargs):
@@ -174,7 +203,7 @@ class _ZoomOnWheel(ClickOrDrag, _PlotInteraction):
:param plot: The plot to apply modifications to.
"""
- self._lastClick = 0., None
+ self._lastClick = 0.0, None
_PlotInteraction.__init__(self, plot)
ClickOrDrag.__init__(self, **kwargs)
@@ -182,12 +211,13 @@ class _ZoomOnWheel(ClickOrDrag, _PlotInteraction):
# Pan #########################################################################
-class Pan(_ZoomOnWheel):
+
+class Pan(_PlotInteractionWithClickEvents):
"""Pan plot content and zoom on wheel state machine."""
def _pixelToData(self, x, y):
xData, yData = self.plot.pixelToData(x, y)
- _, y2Data = self.plot.pixelToData(x, y, axis='right')
+ _, y2Data = self.plot.pixelToData(x, y, axis="right")
return xData, yData, y2Data
def beginDrag(self, x, y, btn):
@@ -199,13 +229,13 @@ class Pan(_ZoomOnWheel):
xMin, xMax = self.plot.getXAxis().getLimits()
yMin, yMax = self.plot.getYAxis().getLimits()
- y2Min, y2Max = self.plot.getYAxis(axis='right').getLimits()
+ y2Min, y2Max = self.plot.getYAxis(axis="right").getLimits()
if self.plot.getXAxis()._isLogarithmic():
try:
dx = math.log10(xData) - math.log10(lastX)
- newXMin = pow(10., (math.log10(xMin) - dx))
- newXMax = pow(10., (math.log10(xMax) - dx))
+ newXMin = pow(10.0, (math.log10(xMin) - dx))
+ newXMax = pow(10.0, (math.log10(xMax) - dx))
except (ValueError, OverflowError):
newXMin, newXMax = xMin, xMax
@@ -223,19 +253,23 @@ class Pan(_ZoomOnWheel):
if self.plot.getYAxis()._isLogarithmic():
try:
dy = math.log10(yData) - math.log10(lastY)
- newYMin = pow(10., math.log10(yMin) - dy)
- newYMax = pow(10., math.log10(yMax) - dy)
+ newYMin = pow(10.0, math.log10(yMin) - dy)
+ newYMax = pow(10.0, math.log10(yMax) - dy)
dy2 = math.log10(y2Data) - math.log10(lastY2)
- newY2Min = pow(10., math.log10(y2Min) - dy2)
- newY2Max = pow(10., math.log10(y2Max) - dy2)
+ newY2Min = pow(10.0, math.log10(y2Min) - dy2)
+ newY2Max = pow(10.0, math.log10(y2Max) - dy2)
except (ValueError, OverflowError):
newYMin, newYMax = yMin, yMax
newY2Min, newY2Max = y2Min, y2Max
# Makes sure y and y2 stays in positive float32 range
- if (newYMin < FLOAT32_MINPOS or newYMax > FLOAT32_SAFE_MAX or
- newY2Min < FLOAT32_MINPOS or newY2Max > FLOAT32_SAFE_MAX):
+ if (
+ newYMin < FLOAT32_MINPOS
+ or newYMax > FLOAT32_SAFE_MAX
+ or newY2Min < FLOAT32_MINPOS
+ or newY2Max > FLOAT32_SAFE_MAX
+ ):
newYMin, newYMax = yMin, yMax
newY2Min, newY2Max = y2Min, y2Max
else:
@@ -245,16 +279,16 @@ class Pan(_ZoomOnWheel):
newY2Min, newY2Max = y2Min - dy2, y2Max - dy2
# Makes sure y and y2 stays in float32 range
- if (newYMin < FLOAT32_SAFE_MIN or
- newYMax > FLOAT32_SAFE_MAX or
- newY2Min < FLOAT32_SAFE_MIN or
- newY2Max > FLOAT32_SAFE_MAX):
+ if (
+ newYMin < FLOAT32_SAFE_MIN
+ or newYMax > FLOAT32_SAFE_MAX
+ or newY2Min < FLOAT32_SAFE_MIN
+ or newY2Max > FLOAT32_SAFE_MAX
+ ):
newYMin, newYMax = yMin, yMax
newY2Min, newY2Max = y2Min, y2Max
- self.plot.setLimits(newXMin, newXMax,
- newYMin, newYMax,
- newY2Min, newY2Max)
+ self.plot.setLimits(newXMin, newXMax, newYMin, newYMax, newY2Min, newY2Max)
self._previousDataPos = self._pixelToData(x, y)
@@ -267,7 +301,17 @@ class Pan(_ZoomOnWheel):
# Zoom ########################################################################
-class Zoom(_ZoomOnWheel):
+
+class AxesExtent(NamedTuple):
+ xmin: float
+ xmax: float
+ ymin: float
+ ymax: float
+ y2min: float
+ y2max: float
+
+
+class Zoom(_PlotInteractionWithClickEvents):
"""Zoom-in/out state machine.
Zoom-in on selected area, zoom-out on right click,
@@ -278,34 +322,67 @@ class Zoom(_ZoomOnWheel):
def __init__(self, plot, color):
self.color = color
+ self.enabledAxes = EnabledAxes()
super(Zoom, self).__init__(plot)
self.plot.getLimitsHistory().clear()
- def _areaWithAspectRatio(self, x0, y0, x1, y1):
- _plotLeft, _plotTop, plotW, plotH = self.plot.getPlotBoundsInPixels()
-
- areaX0, areaY0, areaX1, areaY1 = x0, y0, x1, y1
-
- if plotH != 0.:
- plotRatio = plotW / float(plotH)
- width, height = math.fabs(x1 - x0), math.fabs(y1 - y0)
-
- if height != 0. and width != 0.:
- if width / height > plotRatio:
- areaHeight = width / plotRatio
- areaX0, areaX1 = x0, x1
+ def _getAxesExtent(
+ self,
+ x0: float,
+ y0: float,
+ x1: float,
+ y1: float,
+ enabledAxes: Optional[EnabledAxes] = None,
+ ) -> AxesExtent:
+ """Convert selection coordinates (pixels) to axes coordinates (data)
+
+ This takes into account axes selected for zoom and aspect ratio.
+ """
+ if enabledAxes is None:
+ enabledAxes = self.enabledAxes
+
+ y2_0, y2_1 = y0, y1
+ left, top, width, height = self.plot.getPlotBoundsInPixels()
+
+ if not all(enabledAxes) and not self.plot.isKeepDataAspectRatio():
+ # Handle axes disabled for zoom if plot is not keeping aspec ratio
+ if not enabledAxes.xaxis:
+ x0, x1 = left, left + width
+ if not enabledAxes.yaxis:
+ y0, y1 = top, top + height
+ if not enabledAxes.y2axis:
+ y2_0, y2_1 = top, top + height
+
+ if self.plot.isKeepDataAspectRatio() and height != 0 and width != 0:
+ ratio = width / height
+ xextent, yextent = math.fabs(x1 - x0), math.fabs(y1 - y0)
+ if xextent != 0 and yextent != 0:
+ if xextent / yextent > ratio:
+ areaHeight = xextent / ratio
center = 0.5 * (y0 + y1)
- areaY0 = center - numpy.sign(y1 - y0) * 0.5 * areaHeight
- areaY1 = center + numpy.sign(y1 - y0) * 0.5 * areaHeight
+ y0 = center - numpy.sign(y1 - y0) * 0.5 * areaHeight
+ y1 = center + numpy.sign(y1 - y0) * 0.5 * areaHeight
else:
- areaWidth = height * plotRatio
- areaY0, areaY1 = y0, y1
+ areaWidth = yextent * ratio
center = 0.5 * (x0 + x1)
- areaX0 = center - numpy.sign(x1 - x0) * 0.5 * areaWidth
- areaX1 = center + numpy.sign(x1 - x0) * 0.5 * areaWidth
+ x0 = center - numpy.sign(x1 - x0) * 0.5 * areaWidth
+ x1 = center + numpy.sign(x1 - x0) * 0.5 * areaWidth
- return areaX0, areaY0, areaX1, areaY1
+ # Convert to data space
+ x0, y0 = self.plot.pixelToData(x0, y0, check=False)
+ x1, y1 = self.plot.pixelToData(x1, y1, check=False)
+ y2_0 = self.plot.pixelToData(None, y2_0, axis="right", check=False)[1]
+ y2_1 = self.plot.pixelToData(None, y2_1, axis="right", check=False)[1]
+
+ return AxesExtent(
+ min(x0, x1),
+ max(x0, x1),
+ min(y0, y1),
+ max(y0, y1),
+ min(y2_0, y2_1),
+ max(y2_0, y2_1),
+ )
def beginDrag(self, x, y, btn):
dataPos = self.plot.pixelToData(x, y)
@@ -319,66 +396,54 @@ class Zoom(_ZoomOnWheel):
dataPos = self.plot.pixelToData(x1, y1)
assert dataPos is not None
- if self.plot.isKeepDataAspectRatio():
- area = self._areaWithAspectRatio(self.x0, self.y0, x1, y1)
- areaX0, areaY0, areaX1, areaY1 = area
- areaPoints = ((areaX0, areaY0),
- (areaX1, areaY0),
- (areaX1, areaY1),
- (areaX0, areaY1))
- areaPoints = numpy.array([self.plot.pixelToData(
- x, y, check=False) for (x, y) in areaPoints])
-
- if self.color != 'video inverted':
+ if self.plot.isKeepDataAspectRatio() or not all(self.enabledAxes):
+ # Patch enabledAxes to display the right Y axis area on the left Y axis
+ # since the selection area is always displayed on the left Y axis
+ isY2Visible = self.plot.getYAxis("right").isVisible()
+ areaZoomEnabledAxes = EnabledAxes(
+ self.enabledAxes.xaxis,
+ self.enabledAxes.yaxis and (not isY2Visible or self.enabledAxes.y2axis),
+ self.enabledAxes.y2axis,
+ )
+ extents = self._getAxesExtent(self.x0, self.y0, x1, y1, areaZoomEnabledAxes)
+ areaCorners = (
+ (extents.xmin, extents.ymin),
+ (extents.xmax, extents.ymin),
+ (extents.xmax, extents.ymax),
+ (extents.xmin, extents.ymax),
+ )
+
+ if self.color != "video inverted":
areaColor = list(self.color)
areaColor[3] *= 0.25
else:
- areaColor = [1., 1., 1., 1.]
+ areaColor = [1.0, 1.0, 1.0, 1.0]
- self.setSelectionArea(areaPoints,
- fill='none',
- color=areaColor,
- name="zoomedArea")
+ self.setSelectionArea(
+ areaCorners, fill="none", color=areaColor, name="zoomedArea"
+ )
- corners = ((self.x0, self.y0),
- (self.x0, y1),
- (x1, y1),
- (x1, self.y0))
- corners = numpy.array([self.plot.pixelToData(x, y, check=False)
- for (x, y) in corners])
+ corners = ((self.x0, self.y0), (self.x0, y1), (x1, y1), (x1, self.y0))
+ corners = numpy.array(
+ [self.plot.pixelToData(x, y, check=False) for (x, y) in corners]
+ )
- self.setSelectionArea(corners, fill='none', color=self.color)
+ self.setSelectionArea(corners, fill="none", color=self.color)
def _zoom(self, x0, y0, x1, y1):
- """Zoom to the rectangle view x0,y0 x1,y1.
- """
- startPos = x0, y0
- endPos = x1, y1
-
+ """Zoom to the rectangle view x0,y0 x1,y1."""
# Store current zoom state in stack
self.plot.getLimitsHistory().push()
- if self.plot.isKeepDataAspectRatio():
- x0, y0, x1, y1 = self._areaWithAspectRatio(x0, y0, x1, y1)
-
- # Convert to data space and set limits
- x0, y0 = self.plot.pixelToData(x0, y0, check=False)
-
- dataPos = self.plot.pixelToData(
- startPos[0], startPos[1], axis="right", check=False)
- y2_0 = dataPos[1]
-
- x1, y1 = self.plot.pixelToData(x1, y1, check=False)
-
- dataPos = self.plot.pixelToData(
- endPos[0], endPos[1], axis="right", check=False)
- y2_1 = dataPos[1]
-
- xMin, xMax = min(x0, x1), max(x0, x1)
- yMin, yMax = min(y0, y1), max(y0, y1)
- y2Min, y2Max = min(y2_0, y2_1), max(y2_0, y2_1)
-
- self.plot.setLimits(xMin, xMax, yMin, yMax, y2Min, y2Max)
+ extents = self._getAxesExtent(x0, y0, x1, y1)
+ self.plot.setLimits(
+ extents.xmin,
+ extents.xmax,
+ extents.ymin,
+ extents.ymax,
+ extents.y2min,
+ extents.y2max,
+ )
def endDrag(self, startPos, endPos, btn):
x0, y0 = startPos
@@ -391,12 +456,13 @@ class Zoom(_ZoomOnWheel):
self.resetSelectionArea()
def cancel(self):
- if isinstance(self.state, self.states['drag']):
+ if isinstance(self.state, self.states["drag"]):
self.resetSelectionArea()
# Select ######################################################################
+
class Select(StateMachine, _PlotInteraction):
"""Base class for drawing selection areas."""
@@ -412,13 +478,9 @@ class Select(StateMachine, _PlotInteraction):
self.parameters = parameters
StateMachine.__init__(self, states, state)
- def onWheel(self, x, y, angle):
- scaleF = 1.1 if angle > 0 else 1. / 1.1
- applyZoomToPlot(self.plot, scaleF, (x, y))
-
@property
def color(self):
- return self.parameters.get('color', None)
+ return self.parameters.get("color", None)
class SelectPolygon(Select):
@@ -429,7 +491,7 @@ class SelectPolygon(Select):
class Idle(State):
def onPress(self, x, y, btn):
if btn == LEFT_BTN:
- self.goto('select', x, y)
+ self.goto("select", x, y)
return True
class Select(State):
@@ -446,25 +508,28 @@ class SelectPolygon(Select):
x, y = self.machine.plot.dataToPixel(*self._firstPos, check=False)
offset = self.machine.getDragThreshold()
- points = [(x - offset, y - offset),
- (x - offset, y + offset),
- (x + offset, y + offset),
- (x + offset, y - offset)]
- points = [self.machine.plot.pixelToData(xpix, ypix, check=False)
- for xpix, ypix in points]
- self.machine.setSelectionArea(points, fill=None,
- color=self.machine.color,
- name='first_point')
+ points = [
+ (x - offset, y - offset),
+ (x - offset, y + offset),
+ (x + offset, y + offset),
+ (x + offset, y - offset),
+ ]
+ points = [
+ self.machine.plot.pixelToData(xpix, ypix, check=False)
+ for xpix, ypix in points
+ ]
+ self.machine.setSelectionArea(
+ points, fill=None, color=self.machine.color, name="first_point"
+ )
def updateSelectionArea(self):
"""Update drawing selection area using self.points"""
- self.machine.setSelectionArea(self.points,
- fill='hatch',
- color=self.machine.color)
- eventDict = prepareDrawingSignal('drawingProgress',
- 'polygon',
- self.points,
- self.machine.parameters)
+ self.machine.setSelectionArea(
+ self.points, fill="hatch", color=self.machine.color
+ )
+ eventDict = prepareDrawingSignal(
+ "drawingProgress", "polygon", self.points, self.machine.parameters
+ )
self.machine.plot.notify(**eventDict)
def validate(self):
@@ -478,12 +543,11 @@ class SelectPolygon(Select):
def closePolygon(self):
self.machine.resetSelectionArea()
self.points[-1] = self.points[0]
- eventDict = prepareDrawingSignal('drawingFinished',
- 'polygon',
- self.points,
- self.machine.parameters)
+ eventDict = prepareDrawingSignal(
+ "drawingFinished", "polygon", self.points, self.machine.parameters
+ )
self.machine.plot.notify(**eventDict)
- self.goto('idle')
+ self.goto("idle")
def onWheel(self, x, y, angle):
self.machine.onWheel(x, y, angle)
@@ -493,8 +557,7 @@ class SelectPolygon(Select):
if btn == LEFT_BTN:
# checking if the position is close to the first point
# if yes : closing the "loop"
- firstPos = self.machine.plot.dataToPixel(*self._firstPos,
- check=False)
+ firstPos = self.machine.plot.dataToPixel(*self._firstPos, check=False)
dx, dy = abs(firstPos[0] - x), abs(firstPos[1] - y)
threshold = self.machine.getDragThreshold()
@@ -516,8 +579,9 @@ class SelectPolygon(Select):
# in Idle state, but with a slightly different position that
# the mouse press. So we had the two first vertices that were
# almost identical.
- previousPos = self.machine.plot.dataToPixel(*self.points[-2],
- check=False)
+ previousPos = self.machine.plot.dataToPixel(
+ *self.points[-2], check=False
+ )
dx, dy = abs(previousPos[0] - x), abs(previousPos[1] - y)
if dx >= threshold or dy >= threshold:
self.points.append(dataPos)
@@ -528,8 +592,7 @@ class SelectPolygon(Select):
return False
def onMove(self, x, y):
- firstPos = self.machine.plot.dataToPixel(*self._firstPos,
- check=False)
+ firstPos = self.machine.plot.dataToPixel(*self._firstPos, check=False)
dx, dy = abs(firstPos[0] - x), abs(firstPos[1] - y)
threshold = self.machine.getDragThreshold()
@@ -542,15 +605,11 @@ class SelectPolygon(Select):
self.updateSelectionArea()
def __init__(self, plot, parameters):
- states = {
- 'idle': SelectPolygon.Idle,
- 'select': SelectPolygon.Select
- }
- super(SelectPolygon, self).__init__(plot, parameters,
- states, 'idle')
+ states = {"idle": SelectPolygon.Idle, "select": SelectPolygon.Select}
+ super(SelectPolygon, self).__init__(plot, parameters, states, "idle")
def cancel(self):
- if isinstance(self.state, self.states['select']):
+ if isinstance(self.state, self.states["select"]):
self.resetSelectionArea()
def getDragThreshold(self):
@@ -564,10 +623,11 @@ class SelectPolygon(Select):
class Select2Points(Select):
"""Base class for drawing selection based on 2 input points."""
+
class Idle(State):
def onPress(self, x, y, btn):
if btn == LEFT_BTN:
- self.goto('start', x, y)
+ self.goto("start", x, y)
return True
class Start(State):
@@ -575,11 +635,11 @@ class Select2Points(Select):
self.machine.beginSelect(x, y)
def onMove(self, x, y):
- self.goto('select', x, y)
+ self.goto("select", x, y)
def onRelease(self, x, y, btn):
if btn == LEFT_BTN:
- self.goto('select', x, y)
+ self.goto("select", x, y)
return True
class Select(State):
@@ -592,16 +652,15 @@ class Select2Points(Select):
def onRelease(self, x, y, btn):
if btn == LEFT_BTN:
self.machine.endSelect(x, y)
- self.goto('idle')
+ self.goto("idle")
def __init__(self, plot, parameters):
states = {
- 'idle': Select2Points.Idle,
- 'start': Select2Points.Start,
- 'select': Select2Points.Select
+ "idle": Select2Points.Idle,
+ "start": Select2Points.Start,
+ "select": Select2Points.Select,
}
- super(Select2Points, self).__init__(plot, parameters,
- states, 'idle')
+ super(Select2Points, self).__init__(plot, parameters, states, "idle")
def beginSelect(self, x, y):
pass
@@ -616,12 +675,13 @@ class Select2Points(Select):
pass
def cancel(self):
- if isinstance(self.state, self.states['select']):
+ if isinstance(self.state, self.states["select"]):
self.cancelSelect()
class SelectEllipse(Select2Points):
"""Drawing ellipse selection area state machine."""
+
def beginSelect(self, x, y):
self.center = self.plot.pixelToData(x, y)
assert self.center is not None
@@ -667,21 +727,23 @@ class SelectEllipse(Select2Points):
width, height = self._getEllipseSize(dataPos)
# Circle used for circle preview
- nbpoints = 27.
+ nbpoints = 27.0
angles = numpy.arange(nbpoints) * numpy.pi * 2.0 / nbpoints
- circleShape = numpy.array((numpy.cos(angles) * width,
- numpy.sin(angles) * height)).T
+ circleShape = numpy.array(
+ (numpy.cos(angles) * width, numpy.sin(angles) * height)
+ ).T
circleShape += numpy.array(self.center)
- self.setSelectionArea(circleShape,
- shape="polygon",
- fill='hatch',
- color=self.color)
+ self.setSelectionArea(
+ circleShape, shape="polygon", fill="hatch", color=self.color
+ )
- eventDict = prepareDrawingSignal('drawingProgress',
- 'ellipse',
- (self.center, (width, height)),
- self.parameters)
+ eventDict = prepareDrawingSignal(
+ "drawingProgress",
+ "ellipse",
+ (self.center, (width, height)),
+ self.parameters,
+ )
self.plot.notify(**eventDict)
def endSelect(self, x, y):
@@ -691,10 +753,12 @@ class SelectEllipse(Select2Points):
assert dataPos is not None
width, height = self._getEllipseSize(dataPos)
- eventDict = prepareDrawingSignal('drawingFinished',
- 'ellipse',
- (self.center, (width, height)),
- self.parameters)
+ eventDict = prepareDrawingSignal(
+ "drawingFinished",
+ "ellipse",
+ (self.center, (width, height)),
+ self.parameters,
+ )
self.plot.notify(**eventDict)
def cancelSelect(self):
@@ -703,6 +767,7 @@ class SelectEllipse(Select2Points):
class SelectRectangle(Select2Points):
"""Drawing rectangle selection area state machine."""
+
def beginSelect(self, x, y):
self.startPt = self.plot.pixelToData(x, y)
assert self.startPt is not None
@@ -711,17 +776,20 @@ class SelectRectangle(Select2Points):
dataPos = self.plot.pixelToData(x, y)
assert dataPos is not None
- self.setSelectionArea((self.startPt,
- (self.startPt[0], dataPos[1]),
- dataPos,
- (dataPos[0], self.startPt[1])),
- fill='hatch',
- color=self.color)
-
- eventDict = prepareDrawingSignal('drawingProgress',
- 'rectangle',
- (self.startPt, dataPos),
- self.parameters)
+ self.setSelectionArea(
+ (
+ self.startPt,
+ (self.startPt[0], dataPos[1]),
+ dataPos,
+ (dataPos[0], self.startPt[1]),
+ ),
+ fill="hatch",
+ color=self.color,
+ )
+
+ eventDict = prepareDrawingSignal(
+ "drawingProgress", "rectangle", (self.startPt, dataPos), self.parameters
+ )
self.plot.notify(**eventDict)
def endSelect(self, x, y):
@@ -730,10 +798,9 @@ class SelectRectangle(Select2Points):
dataPos = self.plot.pixelToData(x, y)
assert dataPos is not None
- eventDict = prepareDrawingSignal('drawingFinished',
- 'rectangle',
- (self.startPt, dataPos),
- self.parameters)
+ eventDict = prepareDrawingSignal(
+ "drawingFinished", "rectangle", (self.startPt, dataPos), self.parameters
+ )
self.plot.notify(**eventDict)
def cancelSelect(self):
@@ -742,6 +809,7 @@ class SelectRectangle(Select2Points):
class SelectLine(Select2Points):
"""Drawing line selection area state machine."""
+
def beginSelect(self, x, y):
self.startPt = self.plot.pixelToData(x, y)
assert self.startPt is not None
@@ -750,14 +818,11 @@ class SelectLine(Select2Points):
dataPos = self.plot.pixelToData(x, y)
assert dataPos is not None
- self.setSelectionArea((self.startPt, dataPos),
- fill='hatch',
- color=self.color)
+ self.setSelectionArea((self.startPt, dataPos), fill="hatch", color=self.color)
- eventDict = prepareDrawingSignal('drawingProgress',
- 'line',
- (self.startPt, dataPos),
- self.parameters)
+ eventDict = prepareDrawingSignal(
+ "drawingProgress", "line", (self.startPt, dataPos), self.parameters
+ )
self.plot.notify(**eventDict)
def endSelect(self, x, y):
@@ -766,10 +831,9 @@ class SelectLine(Select2Points):
dataPos = self.plot.pixelToData(x, y)
assert dataPos is not None
- eventDict = prepareDrawingSignal('drawingFinished',
- 'line',
- (self.startPt, dataPos),
- self.parameters)
+ eventDict = prepareDrawingSignal(
+ "drawingFinished", "line", (self.startPt, dataPos), self.parameters
+ )
self.plot.notify(**eventDict)
def cancelSelect(self):
@@ -778,10 +842,11 @@ class SelectLine(Select2Points):
class Select1Point(Select):
"""Base class for drawing selection area based on one input point."""
+
class Idle(State):
def onPress(self, x, y, btn):
if btn == LEFT_BTN:
- self.goto('select', x, y)
+ self.goto("select", x, y)
return True
class Select(State):
@@ -794,18 +859,15 @@ class Select1Point(Select):
def onRelease(self, x, y, btn):
if btn == LEFT_BTN:
self.machine.endSelect(x, y)
- self.goto('idle')
+ self.goto("idle")
def onWheel(self, x, y, angle):
self.machine.onWheel(x, y, angle) # Call select default wheel
self.machine.select(x, y)
def __init__(self, plot, parameters):
- states = {
- 'idle': Select1Point.Idle,
- 'select': Select1Point.Select
- }
- super(Select1Point, self).__init__(plot, parameters, states, 'idle')
+ states = {"idle": Select1Point.Idle, "select": Select1Point.Select}
+ super(Select1Point, self).__init__(plot, parameters, states, "idle")
def select(self, x, y):
pass
@@ -817,12 +879,13 @@ class Select1Point(Select):
pass
def cancel(self):
- if isinstance(self.state, self.states['select']):
+ if isinstance(self.state, self.states["select"]):
self.cancelSelect()
class SelectHLine(Select1Point):
"""Drawing a horizontal line selection area state machine."""
+
def _hLine(self, y):
"""Return points in data coords of the segment visible in the plot.
@@ -836,21 +899,19 @@ class SelectHLine(Select1Point):
def select(self, x, y):
points = self._hLine(y)
- self.setSelectionArea(points, fill='hatch', color=self.color)
+ self.setSelectionArea(points, fill="hatch", color=self.color)
- eventDict = prepareDrawingSignal('drawingProgress',
- 'hline',
- points,
- self.parameters)
+ eventDict = prepareDrawingSignal(
+ "drawingProgress", "hline", points, self.parameters
+ )
self.plot.notify(**eventDict)
def endSelect(self, x, y):
self.resetSelectionArea()
- eventDict = prepareDrawingSignal('drawingFinished',
- 'hline',
- self._hLine(y),
- self.parameters)
+ eventDict = prepareDrawingSignal(
+ "drawingFinished", "hline", self._hLine(y), self.parameters
+ )
self.plot.notify(**eventDict)
def cancelSelect(self):
@@ -859,6 +920,7 @@ class SelectHLine(Select1Point):
class SelectVLine(Select1Point):
"""Drawing a vertical line selection area state machine."""
+
def _vLine(self, x):
"""Return points in data coords of the segment visible in the plot.
@@ -872,21 +934,19 @@ class SelectVLine(Select1Point):
def select(self, x, y):
points = self._vLine(x)
- self.setSelectionArea(points, fill='hatch', color=self.color)
+ self.setSelectionArea(points, fill="hatch", color=self.color)
- eventDict = prepareDrawingSignal('drawingProgress',
- 'vline',
- points,
- self.parameters)
+ eventDict = prepareDrawingSignal(
+ "drawingProgress", "vline", points, self.parameters
+ )
self.plot.notify(**eventDict)
def endSelect(self, x, y):
self.resetSelectionArea()
- eventDict = prepareDrawingSignal('drawingFinished',
- 'vline',
- self._vLine(x),
- self.parameters)
+ eventDict = prepareDrawingSignal(
+ "drawingFinished", "vline", self._vLine(x), self.parameters
+ )
self.plot.notify(**eventDict)
def cancelSelect(self):
@@ -901,7 +961,7 @@ class DrawFreeHand(Select):
class Idle(State):
def onPress(self, x, y, btn):
if btn == LEFT_BTN:
- self.goto('select', x, y)
+ self.goto("select", x, y)
return True
def onMove(self, x, y):
@@ -924,7 +984,7 @@ class DrawFreeHand(Select):
if self.__isOut:
self.machine.resetSelectionArea()
self.machine.endSelect(x, y)
- self.goto('idle')
+ self.goto("idle")
def onEnter(self):
self.__isOut = False
@@ -934,20 +994,16 @@ class DrawFreeHand(Select):
def __init__(self, plot, parameters):
# Circle used for pencil preview
- angle = numpy.arange(13.) * numpy.pi * 2.0 / 13.
- size = parameters.get('width', 1.) * 0.5
- self._circle = size * numpy.array((numpy.cos(angle),
- numpy.sin(angle))).T
+ angle = numpy.arange(13.0) * numpy.pi * 2.0 / 13.0
+ size = parameters.get("width", 1.0) * 0.5
+ self._circle = size * numpy.array((numpy.cos(angle), numpy.sin(angle))).T
- states = {
- 'idle': DrawFreeHand.Idle,
- 'select': DrawFreeHand.Select
- }
- super(DrawFreeHand, self).__init__(plot, parameters, states, 'idle')
+ states = {"idle": DrawFreeHand.Idle, "select": DrawFreeHand.Select}
+ super(DrawFreeHand, self).__init__(plot, parameters, states, "idle")
@property
def width(self):
- return self.parameters.get('width', None)
+ return self.parameters.get("width", None)
def setFirstPoint(self, x, y):
self._points = []
@@ -959,7 +1015,7 @@ class DrawFreeHand(Select):
polygon = center + self._circle
- self.setSelectionArea(polygon, fill='none', color=self.color)
+ self.setSelectionArea(polygon, fill="none", color=self.color)
def select(self, x, y):
pos = self.plot.pixelToData(x, y, check=False)
@@ -968,10 +1024,9 @@ class DrawFreeHand(Select):
# Skip same points
return
self._points.append(pos)
- eventDict = prepareDrawingSignal('drawingProgress',
- 'polylines',
- self._points,
- self.parameters)
+ eventDict = prepareDrawingSignal(
+ "drawingProgress", "polylines", self._points, self.parameters
+ )
self.plot.notify(**eventDict)
def endSelect(self, x, y):
@@ -981,10 +1036,9 @@ class DrawFreeHand(Select):
# Append if different
self._points.append(pos)
- eventDict = prepareDrawingSignal('drawingFinished',
- 'polylines',
- self._points,
- self.parameters)
+ eventDict = prepareDrawingSignal(
+ "drawingFinished", "polylines", self._points, self.parameters
+ )
self.plot.notify(**eventDict)
self._points = None
@@ -1010,13 +1064,9 @@ class SelectFreeLine(ClickOrDrag, _PlotInteraction):
_PlotInteraction.__init__(self, plot)
self.parameters = parameters
- def onWheel(self, x, y, angle):
- scaleF = 1.1 if angle > 0 else 1. / 1.1
- applyZoomToPlot(self.plot, scaleF, (x, y))
-
@property
def color(self):
- return self.parameters.get('color', None)
+ return self.parameters.get("color", None)
def click(self, x, y, btn):
if btn == LEFT_BTN:
@@ -1045,21 +1095,24 @@ class SelectFreeLine(ClickOrDrag, _PlotInteraction):
if isNewPoint or isLast:
eventDict = prepareDrawingSignal(
- 'drawingFinished' if isLast else 'drawingProgress',
- 'polylines',
+ "drawingFinished" if isLast else "drawingProgress",
+ "polylines",
self._points,
- self.parameters)
+ self.parameters,
+ )
self.plot.notify(**eventDict)
if not isLast:
- self.setSelectionArea(self._points, fill='none', color=self.color,
- shape='polylines')
+ self.setSelectionArea(
+ self._points, fill="none", color=self.color, shape="polylines"
+ )
else:
self.cancel()
# ItemInteraction #############################################################
+
class ItemsInteraction(ClickOrDrag, _PlotInteraction):
"""Interaction with items (markers, curves and images).
@@ -1073,9 +1126,12 @@ class ItemsInteraction(ClickOrDrag, _PlotInteraction):
super(ItemsInteraction.Idle, self).__init__(*args, **kw)
self._hoverMarker = None
- def onWheel(self, x, y, angle):
- scaleF = 1.1 if angle > 0 else 1. / 1.1
- applyZoomToPlot(self.machine.plot, scaleF, (x, y))
+ def enterState(self):
+ widget = self.machine.plot.getWidgetHandle()
+ if widget is None or not widget.isVisible():
+ return
+ position = widget.mapFromGlobal(qt.QCursor.pos())
+ self.onMove(position.x(), position.y())
def onMove(self, x, y):
marker = self.machine.plot._getMarkerAt(x, y)
@@ -1084,30 +1140,18 @@ class ItemsInteraction(ClickOrDrag, _PlotInteraction):
dataPos = self.machine.plot.pixelToData(x, y)
assert dataPos is not None
eventDict = prepareHoverSignal(
- marker.getName(), 'marker',
- dataPos, (x, y),
+ marker.getName(),
+ "marker",
+ dataPos,
+ (x, y),
marker.isDraggable(),
- marker.isSelectable())
+ marker.isSelectable(),
+ )
self.machine.plot.notify(**eventDict)
if marker != self._hoverMarker:
self._hoverMarker = marker
-
- if marker is None:
- self.machine.plot.setGraphCursorShape()
-
- elif marker.isDraggable():
- if isinstance(marker, items.YMarker):
- self.machine.plot.setGraphCursorShape(CURSOR_SIZE_VER)
- elif isinstance(marker, items.XMarker):
- self.machine.plot.setGraphCursorShape(CURSOR_SIZE_HOR)
- else:
- self.machine.plot.setGraphCursorShape(CURSOR_SIZE_ALL)
-
- elif marker.isSelectable():
- self.machine.plot.setGraphCursorShape(CURSOR_POINTING)
- else:
- self.machine.plot.setGraphCursorShape()
+ self.machine._setCursorForMarker(marker)
return True
@@ -1115,9 +1159,30 @@ class ItemsInteraction(ClickOrDrag, _PlotInteraction):
self._pan = Pan(plot)
_PlotInteraction.__init__(self, plot)
- ClickOrDrag.__init__(self,
- clickButtons=(LEFT_BTN, RIGHT_BTN),
- dragButtons=(LEFT_BTN, MIDDLE_BTN))
+ ClickOrDrag.__init__(
+ self, clickButtons=(LEFT_BTN, RIGHT_BTN), dragButtons=(LEFT_BTN, MIDDLE_BTN)
+ )
+
+ def _setCursorForMarker(self, marker: Optional[items.MarkerBase] = None):
+ """Set mouse cursor for given marker"""
+ if marker is None:
+ cursor = None
+
+ elif marker.isDraggable():
+ if isinstance(marker, items.YMarker):
+ cursor = CURSOR_SIZE_VER
+ elif isinstance(marker, items.XMarker):
+ cursor = CURSOR_SIZE_HOR
+ else:
+ cursor = CURSOR_SIZE_ALL
+
+ elif marker.isSelectable():
+ cursor = CURSOR_POINTING
+
+ else:
+ cursor = None
+
+ self.plot.setGraphCursorShape(cursor)
def click(self, x, y, btn):
"""Handle mouse click
@@ -1130,9 +1195,9 @@ class ItemsInteraction(ClickOrDrag, _PlotInteraction):
# Signal mouse clicked event
dataPos = self.plot.pixelToData(x, y)
assert dataPos is not None
- eventDict = prepareMouseSignal('mouseClicked', btn,
- dataPos[0], dataPos[1],
- x, y)
+ eventDict = prepareMouseSignal(
+ "mouseClicked", btn, dataPos[0], dataPos[1], x, y
+ )
self.plot.notify(**eventDict)
eventDict = self._handleClick(x, y, btn)
@@ -1163,14 +1228,17 @@ class ItemsInteraction(ClickOrDrag, _PlotInteraction):
if yData is None:
yData = [0, 1]
- eventDict = prepareMarkerSignal('markerClicked',
- 'left',
- item.getName(),
- 'marker',
- item.isDraggable(),
- item.isSelectable(),
- (xData, yData),
- (x, y), None)
+ eventDict = prepareMarkerSignal(
+ "markerClicked",
+ "left",
+ item.getName(),
+ "marker",
+ item.isDraggable(),
+ item.isSelectable(),
+ (xData, yData),
+ (x, y),
+ None,
+ )
return eventDict
elif isinstance(item, items.Curve):
@@ -1181,13 +1249,16 @@ class ItemsInteraction(ClickOrDrag, _PlotInteraction):
yData = item.getYData(copy=False)
indices = result.getIndices(copy=False)
- eventDict = prepareCurveSignal('left',
- item.getName(),
- 'curve',
- xData[indices],
- yData[indices],
- dataPos[0], dataPos[1],
- x, y)
+ eventDict = prepareCurveSignal(
+ "left",
+ item,
+ xData[indices],
+ yData[indices],
+ dataPos[0],
+ dataPos[1],
+ x,
+ y,
+ )
return eventDict
elif isinstance(item, items.ImageBase):
@@ -1196,12 +1267,9 @@ class ItemsInteraction(ClickOrDrag, _PlotInteraction):
indices = result.getIndices(copy=False)
row, column = indices[0][0], indices[1][0]
- eventDict = prepareImageSignal('left',
- item.getName(),
- 'image',
- column, row,
- dataPos[0], dataPos[1],
- x, y)
+ eventDict = prepareImageSignal(
+ "left", item, column, row, dataPos[0], dataPos[1], x, y
+ )
return eventDict
return None
@@ -1218,24 +1286,26 @@ class ItemsInteraction(ClickOrDrag, _PlotInteraction):
posDataCursor = self.plot.pixelToData(x, y)
assert posDataCursor is not None
- eventDict = prepareMarkerSignal(eventType,
- 'left',
- marker.getName(),
- 'marker',
- marker.isDraggable(),
- marker.isSelectable(),
- (xData, yData),
- (x, y),
- posDataCursor)
+ eventDict = prepareMarkerSignal(
+ eventType,
+ "left",
+ marker.getName(),
+ "marker",
+ marker.isDraggable(),
+ marker.isSelectable(),
+ (xData, yData),
+ (x, y),
+ posDataCursor,
+ )
self.plot.notify(**eventDict)
@staticmethod
def __isDraggableItem(item):
return isinstance(item, items.DraggableMixIn) and item.isDraggable()
- def __terminateDrag(self):
+ def __terminateDrag(self, x, y):
"""Finalize a drag operation by reseting to initial state"""
- self.plot.setGraphCursorShape()
+ self._setCursorForMarker(self.plot._getMarkerAt(x, y))
self.draggedItemRef = None
def beginDrag(self, x, y, btn):
@@ -1256,11 +1326,11 @@ class ItemsInteraction(ClickOrDrag, _PlotInteraction):
self.draggedItemRef = None if item is None else weakref.ref(item)
if item is None:
- self.__terminateDrag()
+ self.__terminateDrag(x, y)
return False
if isinstance(item, items.MarkerBase):
- self._signalMarkerMovingEvent('markerMoving', item, x, y)
+ self._signalMarkerMovingEvent("markerMoving", item, x, y)
item._startDrag()
return True
@@ -1278,7 +1348,7 @@ class ItemsInteraction(ClickOrDrag, _PlotInteraction):
item.drag(self._lastPos, dataPos)
if isinstance(item, items.MarkerBase):
- self._signalMarkerMovingEvent('markerMoving', item, x, y)
+ self._signalMarkerMovingEvent("markerMoving", item, x, y)
self._lastPos = dataPos
elif btn == MIDDLE_BTN:
@@ -1290,46 +1360,52 @@ class ItemsInteraction(ClickOrDrag, _PlotInteraction):
if isinstance(item, items.MarkerBase):
posData = list(item.getPosition())
if posData[0] is None:
- posData[0] = 1.
+ posData[0] = 1.0
if posData[1] is None:
- posData[1] = 1.
+ posData[1] = 1.0
eventDict = prepareMarkerSignal(
- 'markerMoved',
- 'left',
+ "markerMoved",
+ "left",
item.getLegend(),
- 'marker',
+ "marker",
item.isDraggable(),
item.isSelectable(),
- posData)
+ posData,
+ )
self.plot.notify(**eventDict)
item._endDrag()
- self.__terminateDrag()
+ self.__terminateDrag(*endPos)
elif btn == MIDDLE_BTN:
self._pan.endDrag(startPos, endPos, btn)
def cancel(self):
self._pan.cancel()
- self.__terminateDrag()
+ widget = self.plot.getWidgetHandle()
+ if widget is None or not widget.isVisible():
+ return
+ position = widget.mapFromGlobal(qt.QCursor.pos())
+ self.__terminateDrag(position.x(), position.y())
class ItemsInteractionForCombo(ItemsInteraction):
- """Interaction with items to combine through :class:`FocusManager`.
- """
+ """Interaction with items to combine through :class:`FocusManager`."""
class Idle(ItemsInteraction.Idle):
@staticmethod
def __isItemSelectableOrDraggable(item):
- return (item.isSelectable() or (
- isinstance(item, items.DraggableMixIn) and item.isDraggable()))
+ return item.isSelectable() or (
+ isinstance(item, items.DraggableMixIn) and item.isDraggable()
+ )
def onPress(self, x, y, btn):
if btn == LEFT_BTN:
result = self.machine.plot._pickTopMost(
- x, y, self.__isItemSelectableOrDraggable)
+ x, y, self.__isItemSelectableOrDraggable
+ )
if result is not None: # Request focus and handle interaction
- self.goto('clickOrDrag', x, y, btn)
+ self.goto("clickOrDrag", x, y, btn)
return True
else: # Do not request focus
return False
@@ -1339,19 +1415,21 @@ class ItemsInteractionForCombo(ItemsInteraction):
# FocusManager ################################################################
+
class FocusManager(StateMachine):
"""Manages focus across multiple event handlers
On press an event handler can acquire focus.
By default it looses focus when all buttons are released.
"""
+
class Idle(State):
def onPress(self, x, y, btn):
if btn == LEFT_BTN:
for eventHandler in self.machine.eventHandlers:
- requestFocus = eventHandler.handleEvent('press', x, y, btn)
+ requestFocus = eventHandler.handleEvent("press", x, y, btn)
if requestFocus:
- self.goto('focus', eventHandler, btn)
+ self.goto("focus", eventHandler, btn)
break
def _processEvent(self, *args):
@@ -1361,14 +1439,14 @@ class FocusManager(StateMachine):
break
def onMove(self, x, y):
- self._processEvent('move', x, y)
+ self._processEvent("move", x, y)
def onRelease(self, x, y, btn):
if btn == LEFT_BTN:
- self._processEvent('release', x, y, btn)
+ self._processEvent("release", x, y, btn)
def onWheel(self, x, y, angle):
- self._processEvent('wheel', x, y, angle)
+ self._processEvent("wheel", x, y, angle)
class Focus(State):
def enterState(self, eventHandler, btn):
@@ -1377,34 +1455,31 @@ class FocusManager(StateMachine):
def validate(self):
self.eventHandler.validate()
- self.goto('idle')
+ self.goto("idle")
def onPress(self, x, y, btn):
if btn == LEFT_BTN:
self.focusBtns.add(btn)
- self.eventHandler.handleEvent('press', x, y, btn)
+ self.eventHandler.handleEvent("press", x, y, btn)
def onMove(self, x, y):
- self.eventHandler.handleEvent('move', x, y)
+ self.eventHandler.handleEvent("move", x, y)
def onRelease(self, x, y, btn):
if btn == LEFT_BTN:
self.focusBtns.discard(btn)
- requestFocus = self.eventHandler.handleEvent('release', x, y, btn)
+ requestFocus = self.eventHandler.handleEvent("release", x, y, btn)
if len(self.focusBtns) == 0 and not requestFocus:
- self.goto('idle')
+ self.goto("idle")
def onWheel(self, x, y, angleInDegrees):
- self.eventHandler.handleEvent('wheel', x, y, angleInDegrees)
+ self.eventHandler.handleEvent("wheel", x, y, angleInDegrees)
def __init__(self, eventHandlers=()):
self.eventHandlers = list(eventHandlers)
- states = {
- 'idle': FocusManager.Idle,
- 'focus': FocusManager.Focus
- }
- super(FocusManager, self).__init__(states, 'idle')
+ states = {"idle": FocusManager.Idle, "focus": FocusManager.Focus}
+ super(FocusManager, self).__init__(states, "idle")
def cancel(self):
for handler in self.eventHandlers:
@@ -1428,6 +1503,15 @@ class ZoomAndSelect(ItemsInteraction):
"""Color of the zoom area"""
return self._zoom.color
+ @property
+ def zoomEnabledAxes(self) -> EnabledAxes:
+ """Whether or not to apply zoom for each axis"""
+ return self._zoom.enabledAxes
+
+ @zoomEnabledAxes.setter
+ def zoomEnabledAxes(self, enabledAxes: EnabledAxes):
+ self._zoom.enabledAxes = enabledAxes
+
def click(self, x, y, btn):
"""Handle mouse click
@@ -1442,9 +1526,9 @@ class ZoomAndSelect(ItemsInteraction):
# Signal mouse clicked event
dataPos = self.plot.pixelToData(x, y)
assert dataPos is not None
- clickedEventDict = prepareMouseSignal('mouseClicked', btn,
- dataPos[0], dataPos[1],
- x, y)
+ clickedEventDict = prepareMouseSignal(
+ "mouseClicked", btn, dataPos[0], dataPos[1], x, y
+ )
self.plot.notify(**clickedEventDict)
self.plot.notify(**eventDict)
@@ -1513,9 +1597,9 @@ class PanAndSelect(ItemsInteraction):
# Signal mouse clicked event
dataPos = self.plot.pixelToData(x, y)
assert dataPos is not None
- clickedEventDict = prepareMouseSignal('mouseClicked', btn,
- dataPos[0], dataPos[1],
- x, y)
+ clickedEventDict = prepareMouseSignal(
+ "mouseClicked", btn, dataPos[0], dataPos[1], x, y
+ )
self.plot.notify(**clickedEventDict)
self.plot.notify(**eventDict)
@@ -1563,15 +1647,15 @@ class PanAndSelect(ItemsInteraction):
# Mapping of draw modes: event handler
_DRAW_MODES = {
- 'polygon': SelectPolygon,
- 'rectangle': SelectRectangle,
- 'ellipse': SelectEllipse,
- 'line': SelectLine,
- 'vline': SelectVLine,
- 'hline': SelectHLine,
- 'polylines': SelectFreeLine,
- 'pencil': DrawFreeHand,
- }
+ "polygon": SelectPolygon,
+ "rectangle": SelectRectangle,
+ "ellipse": SelectEllipse,
+ "line": SelectLine,
+ "vline": SelectVLine,
+ "hline": SelectHLine,
+ "polylines": SelectFreeLine,
+ "pencil": DrawFreeHand,
+}
class DrawMode(FocusManager):
@@ -1580,19 +1664,22 @@ class DrawMode(FocusManager):
def __init__(self, plot, shape, label, color, width):
eventHandlerClass = _DRAW_MODES[shape]
parameters = {
- 'shape': shape,
- 'label': label,
- 'color': color,
- 'width': width,
- }
- super().__init__((
- Pan(plot, clickButtons=(), dragButtons=(MIDDLE_BTN,)),
- eventHandlerClass(plot, parameters)))
+ "shape": shape,
+ "label": label,
+ "color": color,
+ "width": width,
+ }
+ super().__init__(
+ (
+ Pan(plot, clickButtons=(), dragButtons=(MIDDLE_BTN,)),
+ eventHandlerClass(plot, parameters),
+ )
+ )
def getDescription(self):
"""Returns the dict describing this interactive mode"""
params = self.eventHandlers[1].parameters.copy()
- params['mode'] = 'draw'
+ params["mode"] = "draw"
return params
@@ -1604,27 +1691,27 @@ class DrawSelectMode(FocusManager):
self._pan = Pan(plot)
self._panStart = None
parameters = {
- 'shape': shape,
- 'label': label,
- 'color': color,
- 'width': width,
- }
- super().__init__((
- ItemsInteractionForCombo(plot),
- eventHandlerClass(plot, parameters)))
+ "shape": shape,
+ "label": label,
+ "color": color,
+ "width": width,
+ }
+ super().__init__(
+ (ItemsInteractionForCombo(plot), eventHandlerClass(plot, parameters))
+ )
def handleEvent(self, eventName, *args, **kwargs):
# Hack to add pan interaction to select-draw
# See issue Refactor PlotWidget interaction #3292
- if eventName == 'press' and args[2] == MIDDLE_BTN:
+ if eventName == "press" and args[2] == MIDDLE_BTN:
self._panStart = args[:2]
self._pan.beginDrag(*args)
return # Consume middle click events
- elif eventName == 'release' and args[2] == MIDDLE_BTN:
+ elif eventName == "release" and args[2] == MIDDLE_BTN:
self._panStart = None
self._pan.endDrag(self._panStart, args[:2], MIDDLE_BTN)
return # Consume middle click events
- elif self._panStart is not None and eventName == 'move':
+ elif self._panStart is not None and eventName == "move":
x, y = args[:2]
self._pan.drag(x, y, MIDDLE_BTN)
@@ -1633,67 +1720,94 @@ class DrawSelectMode(FocusManager):
def getDescription(self):
"""Returns the dict describing this interactive mode"""
params = self.eventHandlers[1].parameters.copy()
- params['mode'] = 'select-draw'
+ params["mode"] = "select-draw"
return params
-class PlotInteraction(object):
- """Proxy to currently use state machine for interaction.
-
- This allows to switch interactive mode.
+class PlotInteraction(qt.QObject):
+ """PlotWidget user interaction handler.
- :param plot: The :class:`Plot` to apply interaction to
+ :param plot: The :class:`PlotWidget` to apply interaction to
"""
+ sigChanged = qt.Signal()
+ """Signal emitted when the interaction configuration has changed"""
+
_DRAW_MODES = {
- 'polygon': SelectPolygon,
- 'rectangle': SelectRectangle,
- 'ellipse': SelectEllipse,
- 'line': SelectLine,
- 'vline': SelectVLine,
- 'hline': SelectHLine,
- 'polylines': SelectFreeLine,
- 'pencil': DrawFreeHand,
+ "polygon": SelectPolygon,
+ "rectangle": SelectRectangle,
+ "ellipse": SelectEllipse,
+ "line": SelectLine,
+ "vline": SelectVLine,
+ "hline": SelectHLine,
+ "polylines": SelectFreeLine,
+ "pencil": DrawFreeHand,
}
- def __init__(self, plot):
- self._plot = weakref.ref(plot) # Avoid cyclic-ref
-
- self.zoomOnWheel = True
- """True to enable zoom on wheel, False otherwise."""
+ def __init__(self, parent):
+ super().__init__(parent)
+ self.__zoomOnWheel = True
+ self.__zoomEnabledAxes = EnabledAxes()
# Default event handler
- self._eventHandler = ItemsInteraction(plot)
+ self._eventHandler = ItemsInteraction(parent)
+
+ def isZoomOnWheelEnabled(self) -> bool:
+ """Returns whether or not wheel interaction triggers zoom"""
+ return self.__zoomOnWheel
+
+ def setZoomOnWheelEnabled(self, enabled: bool):
+ """Toggle zoom on wheel interaction"""
+ if enabled != self.__zoomOnWheel:
+ self.__zoomOnWheel = enabled
+ self.sigChanged.emit()
- def getInteractiveMode(self):
+ def setZoomEnabledAxes(self, xaxis: bool, yaxis: bool, y2axis: bool):
+ """Toggle zoom interaction for each axis
+
+ This is taken into account only if the plot does not keep aspect ratio.
+ """
+ zoomEnabledAxes = EnabledAxes(xaxis, yaxis, y2axis)
+ if zoomEnabledAxes != self.__zoomEnabledAxes:
+ self.__zoomEnabledAxes = zoomEnabledAxes
+ if isinstance(self._eventHandler, ZoomAndSelect):
+ self._eventHandler.zoomEnabledAxes = zoomEnabledAxes
+ self.sigChanged.emit()
+
+ def getZoomEnabledAxes(self) -> EnabledAxes:
+ """Returns axes for which zoom is enabled"""
+ return self.__zoomEnabledAxes
+
+ def _getInteractiveMode(self):
"""Returns the current interactive mode as a dict.
The returned dict contains at least the key 'mode'.
Mode can be: 'draw', 'pan', 'select', 'select-draw', 'zoom'.
It can also contains extra keys (e.g., 'color') specific to a mode
- as provided to :meth:`setInteractiveMode`.
+ as provided to :meth:`_setInteractiveMode`.
"""
if isinstance(self._eventHandler, ZoomAndSelect):
- return {'mode': 'zoom', 'color': self._eventHandler.color}
+ return {"mode": "zoom", "color": self._eventHandler.color}
elif isinstance(self._eventHandler, (DrawMode, DrawSelectMode)):
return self._eventHandler.getDescription()
elif isinstance(self._eventHandler, PanAndSelect):
- return {'mode': 'pan'}
+ return {"mode": "pan"}
else:
- return {'mode': 'select'}
+ return {"mode": "select"}
- def validate(self):
+ def _validate(self):
"""Validate the current interaction if possible
If was designed to close the polygon interaction.
"""
self._eventHandler.validate()
- def setInteractiveMode(self, mode, color='black',
- shape='polygon', label=None, width=None):
+ def _setInteractiveMode(
+ self, mode, color="black", shape="polygon", label=None, width=None
+ ):
"""Switch the interactive mode.
:param str mode: The name of the interactive mode.
@@ -1710,36 +1824,62 @@ class PlotInteraction(object):
:param str label: Only for 'draw' mode.
:param float width: Width of the pencil. Only for draw pencil mode.
"""
- assert mode in ('draw', 'pan', 'select', 'select-draw', 'zoom')
+ assert mode in ("draw", "pan", "select", "select-draw", "zoom")
- plot = self._plot()
- assert plot is not None
+ plotWidget = self.parent()
+ assert plotWidget is not None
- if isinstance(color, numpy.ndarray) or color not in (None, 'video inverted'):
+ if isinstance(color, numpy.ndarray) or color not in (None, "video inverted"):
color = colors.rgba(color)
- if mode in ('draw', 'select-draw'):
+ if mode in ("draw", "select-draw"):
self._eventHandler.cancel()
- handlerClass = DrawMode if mode == 'draw' else DrawSelectMode
- self._eventHandler = handlerClass(plot, shape, label, color, width)
+ handlerClass = DrawMode if mode == "draw" else DrawSelectMode
+ self._eventHandler = handlerClass(plotWidget, shape, label, color, width)
- elif mode == 'pan':
+ elif mode == "pan":
# Ignores color, shape and label
self._eventHandler.cancel()
- self._eventHandler = PanAndSelect(plot)
+ self._eventHandler = PanAndSelect(plotWidget)
- elif mode == 'zoom':
+ elif mode == "zoom":
# Ignores shape and label
self._eventHandler.cancel()
- self._eventHandler = ZoomAndSelect(plot, color)
+ self._eventHandler = ZoomAndSelect(plotWidget, color)
+ self._eventHandler.zoomEnabledAxes = self.getZoomEnabledAxes()
else: # Default mode: interaction with plot objects
# Ignores color, shape and label
self._eventHandler.cancel()
- self._eventHandler = ItemsInteraction(plot)
+ self._eventHandler = ItemsInteraction(plotWidget)
+
+ self.sigChanged.emit()
def handleEvent(self, event, *args, **kwargs):
"""Forward event to current interactive mode state machine."""
- if not self.zoomOnWheel and event == 'wheel':
- return # Discard wheel events
+ if event == "wheel": # Handle wheel events directly
+ self._onWheel(*args, **kwargs)
+ return
+
self._eventHandler.handleEvent(event, *args, **kwargs)
+
+ def _onWheel(self, x: float, y: float, angle: float):
+ """Handle wheel events"""
+ if not self.isZoomOnWheelEnabled():
+ return
+
+ plotWidget = self.parent()
+ if plotWidget is None:
+ return
+
+ # All axes are enabled if keep aspect ratio is on
+ enabledAxes = (
+ EnabledAxes()
+ if plotWidget.isKeepDataAspectRatio()
+ else self.getZoomEnabledAxes()
+ )
+ if enabledAxes.isDisabled():
+ return
+
+ scale = 1.1 if angle > 0 else 1.0 / 1.1
+ applyZoomToPlot(plotWidget, scale, (x, y), enabledAxes)
diff --git a/src/silx/gui/plot/PlotToolButtons.py b/src/silx/gui/plot/PlotToolButtons.py
index a810ce1..e132877 100644
--- a/src/silx/gui/plot/PlotToolButtons.py
+++ b/src/silx/gui/plot/PlotToolButtons.py
@@ -29,6 +29,7 @@ The following QToolButton are available:
- :class:`.AspectToolButton`
- :class:`.YAxisOriginToolButton`
- :class:`.ProfileToolButton`
+- :class:`.RulerToolButton`
- :class:`.SymbolToolButton`
"""
@@ -40,11 +41,11 @@ __date__ = "27/06/2017"
import functools
import logging
-import weakref
from .. import icons
from .. import qt
from ... import config
+from .tools.PlotToolButton import PlotToolButton
from .items import SymbolMixIn, Scatter
@@ -52,58 +53,6 @@ from .items import SymbolMixIn, Scatter
_logger = logging.getLogger(__name__)
-class PlotToolButton(qt.QToolButton):
- """A QToolButton connected to a :class:`~silx.gui.plot.PlotWidget`.
- """
-
- def __init__(self, parent=None, plot=None):
- super(PlotToolButton, self).__init__(parent)
- self._plotRef = None
- if plot is not None:
- self.setPlot(plot)
-
- def plot(self):
- """
- Returns the plot connected to the widget.
- """
- return None if self._plotRef is None else self._plotRef()
-
- def setPlot(self, plot):
- """
- Set the plot connected to the widget
-
- :param plot: :class:`.PlotWidget` instance on which to operate.
- """
- previousPlot = self.plot()
-
- if previousPlot is plot:
- return
- if previousPlot is not None:
- self._disconnectPlot(previousPlot)
-
- if plot is None:
- self._plotRef = None
- else:
- self._plotRef = weakref.ref(plot)
- self._connectPlot(plot)
-
- def _connectPlot(self, plot):
- """
- Called when the plot is connected to the widget
-
- :param plot: :class:`.PlotWidget` instance
- """
- pass
-
- def _disconnectPlot(self, plot):
- """
- Called when the plot is disconnected from the widget
-
- :param plot: :class:`.PlotWidget` instance
- """
- pass
-
-
class AspectToolButton(PlotToolButton):
"""Tool button to switch keep aspect ratio of a plot"""
@@ -114,11 +63,11 @@ class AspectToolButton(PlotToolButton):
if self.STATE is None:
self.STATE = {}
# dont keep ratio
- self.STATE[False, "icon"] = icons.getQIcon('shape-ellipse-solid')
+ self.STATE[False, "icon"] = icons.getQIcon("shape-ellipse-solid")
self.STATE[False, "state"] = "Aspect ratio is not kept"
self.STATE[False, "action"] = "Do no keep data aspect ratio"
# keep ratio
- self.STATE[True, "icon"] = icons.getQIcon('shape-circle-solid')
+ self.STATE[True, "icon"] = icons.getQIcon("shape-circle-solid")
self.STATE[True, "state"] = "Aspect ratio is kept"
self.STATE[True, "action"] = "Keep data aspect ratio"
@@ -166,7 +115,10 @@ class AspectToolButton(PlotToolButton):
def _keepDataAspectRatioChanged(self, aspectRatio):
"""Handle Plot set keep aspect ratio signal"""
- icon, toolTip = self.STATE[aspectRatio, "icon"], self.STATE[aspectRatio, "state"]
+ icon, toolTip = (
+ self.STATE[aspectRatio, "icon"],
+ self.STATE[aspectRatio, "state"],
+ )
self.setIcon(icon)
self.setToolTip(toolTip)
@@ -181,11 +133,11 @@ class YAxisOriginToolButton(PlotToolButton):
if self.STATE is None:
self.STATE = {}
# is down
- self.STATE[False, "icon"] = icons.getQIcon('plot-ydown')
+ self.STATE[False, "icon"] = icons.getQIcon("plot-ydown")
self.STATE[False, "state"] = "Y-axis is oriented downward"
self.STATE[False, "action"] = "Orient Y-axis downward"
# keep ration
- self.STATE[True, "icon"] = icons.getQIcon('plot-yup')
+ self.STATE[True, "icon"] = icons.getQIcon("plot-yup")
self.STATE[True, "state"] = "Y-axis is oriented upward"
self.STATE[True, "action"] = "Orient Y-axis upward"
@@ -242,28 +194,29 @@ class YAxisOriginToolButton(PlotToolButton):
class ProfileOptionToolButton(PlotToolButton):
"""Button to define option on the profile"""
+
sigMethodChanged = qt.Signal(str)
-
+
def __init__(self, parent=None, plot=None):
PlotToolButton.__init__(self, parent=parent, plot=plot)
self.STATE = {}
# is down
- self.STATE['sum', "icon"] = icons.getQIcon('math-sigma')
- self.STATE['sum', "state"] = "Compute profile sum"
- self.STATE['sum', "action"] = "Compute profile sum"
+ self.STATE["sum", "icon"] = icons.getQIcon("math-sigma")
+ self.STATE["sum", "state"] = "Compute profile sum"
+ self.STATE["sum", "action"] = "Compute profile sum"
# keep ration
- self.STATE['mean', "icon"] = icons.getQIcon('math-mean')
- self.STATE['mean', "state"] = "Compute profile mean"
- self.STATE['mean', "action"] = "Compute profile mean"
+ self.STATE["mean", "icon"] = icons.getQIcon("math-mean")
+ self.STATE["mean", "state"] = "Compute profile mean"
+ self.STATE["mean", "action"] = "Compute profile mean"
- self.sumAction = self._createAction('sum')
+ self.sumAction = self._createAction("sum")
self.sumAction.triggered.connect(self.setSum)
self.sumAction.setIconVisibleInMenu(True)
self.sumAction.setCheckable(True)
self.sumAction.setChecked(True)
- self.meanAction = self._createAction('mean')
+ self.meanAction = self._createAction("mean")
self.meanAction.triggered.connect(self.setMean)
self.meanAction.setIconVisibleInMenu(True)
self.meanAction.setCheckable(True)
@@ -273,7 +226,7 @@ class ProfileOptionToolButton(PlotToolButton):
menu.addAction(self.meanAction)
self.setMenu(menu)
self.setPopupMode(qt.QToolButton.InstantPopup)
- self._method = 'mean'
+ self._method = "mean"
self._update()
def _createAction(self, method):
@@ -282,7 +235,7 @@ class ProfileOptionToolButton(PlotToolButton):
return qt.QAction(icon, text, self)
def setSum(self):
- self.setMethod('sum')
+ self.setMethod("sum")
def _update(self):
icon = self.STATE[self._method, "icon"]
@@ -293,7 +246,7 @@ class ProfileOptionToolButton(PlotToolButton):
self.meanAction.setChecked(self._method == "mean")
def setMean(self):
- self.setMethod('mean')
+ self.setMethod("mean")
def setMethod(self, method):
"""Set the method to use.
@@ -301,13 +254,12 @@ class ProfileOptionToolButton(PlotToolButton):
:param str method: Either 'sum' or 'mean'
"""
if method != self._method:
- if method in ('sum', 'mean'):
+ if method in ("sum", "mean"):
self._method = method
self.sigMethodChanged.emit(self._method)
self._update()
else:
- _logger.warning(
- "Unsupported method '%s'. Setting ignored.", method)
+ _logger.warning("Unsupported method '%s'. Setting ignored.", method)
def getMethod(self):
"""Returns the current method in use (See :meth:`setMethod`).
@@ -320,6 +272,7 @@ class ProfileOptionToolButton(PlotToolButton):
class ProfileToolButton(PlotToolButton):
"""Button used in Profile3DToolbar to switch between 2D profile
and 1D profile."""
+
STATE = None
"""Lazy loaded states used to feed ProfileToolButton"""
@@ -328,12 +281,16 @@ class ProfileToolButton(PlotToolButton):
def __init__(self, parent=None, plot=None):
if self.STATE is None:
self.STATE = {
- (1, "icon"): icons.getQIcon('profile1D'),
+ (1, "icon"): icons.getQIcon("profile1D"),
(1, "state"): "1D profile is computed on visible image",
(1, "action"): "1D profile on visible image",
- (2, "icon"): icons.getQIcon('profile2D'),
- (2, "state"): "2D profile is computed, one 1D profile for each image in the stack",
- (2, "action"): "2D profile on image stack"}
+ (2, "icon"): icons.getQIcon("profile2D"),
+ (
+ 2,
+ "state",
+ ): "2D profile is computed, one 1D profile for each image in the stack",
+ (2, "action"): "2D profile on image stack",
+ }
# Compute 1D profile
# Compute 2D profile
@@ -359,7 +316,7 @@ class ProfileToolButton(PlotToolButton):
menu.addAction(profile2DAction)
self.setMenu(menu)
self.setPopupMode(qt.QToolButton.InstantPopup)
- menu.setTitle('Select profile dimension')
+ menu.setTitle("Select profile dimension")
self.computeProfileIn1D()
def _createAction(self, profileDimension):
@@ -431,12 +388,12 @@ class _SymbolToolButtonBase(PlotToolButton):
:param QMenu menu:
"""
- for marker, name in zip(SymbolMixIn.getSupportedSymbols(),
- SymbolMixIn.getSupportedSymbolNames()):
+ for marker, name in zip(
+ SymbolMixIn.getSupportedSymbols(), SymbolMixIn.getSupportedSymbolNames()
+ ):
action = qt.QAction(name, menu)
action.setCheckable(False)
- action.triggered.connect(
- functools.partial(self._markerChanged, marker))
+ action.triggered.connect(functools.partial(self._markerChanged, marker))
menu.addAction(action)
def _sizeChanged(self, value):
@@ -476,8 +433,8 @@ class SymbolToolButton(_SymbolToolButtonBase):
def __init__(self, parent=None, plot=None):
super(SymbolToolButton, self).__init__(parent=parent, plot=plot)
- self.setToolTip('Set symbol size and marker')
- self.setIcon(icons.getQIcon('plot-symbols'))
+ self.setToolTip("Set symbol size and marker")
+ self.setIcon(icons.getQIcon("plot-symbols"))
menu = qt.QMenu(self)
self._addSizeSliderToMenu(menu)
@@ -496,12 +453,10 @@ class ScatterVisualizationToolButton(_SymbolToolButtonBase):
"""
def __init__(self, parent=None, plot=None):
- super(ScatterVisualizationToolButton, self).__init__(
- parent=parent, plot=plot)
+ super(ScatterVisualizationToolButton, self).__init__(parent=parent, plot=plot)
- self.setToolTip(
- 'Set scatter visualization mode, symbol marker and size')
- self.setIcon(icons.getQIcon('eye'))
+ self.setToolTip("Set scatter visualization mode, symbol marker and size")
+ self.setIcon(icons.getQIcon("eye"))
menu = qt.QMenu(self)
@@ -513,26 +468,33 @@ class ScatterVisualizationToolButton(_SymbolToolButtonBase):
action = qt.QAction(name, menu)
action.setCheckable(False)
action.triggered.connect(
- functools.partial(self._visualizationChanged, mode, None))
+ functools.partial(self._visualizationChanged, mode, None)
+ )
menu.addAction(action)
if Scatter.Visualization.BINNED_STATISTIC in Scatter.supportedVisualizations():
reductions = Scatter.supportedVisualizationParameterValues(
- Scatter.VisualizationParameter.BINNED_STATISTIC_FUNCTION)
+ Scatter.VisualizationParameter.BINNED_STATISTIC_FUNCTION
+ )
if reductions:
- submenu = menu.addMenu('Binned Statistic')
+ submenu = menu.addMenu("Binned Statistic")
for reduction in reductions:
name = reduction.capitalize()
action = qt.QAction(name, menu)
action.setCheckable(False)
- action.triggered.connect(functools.partial(
- self._visualizationChanged,
- Scatter.Visualization.BINNED_STATISTIC,
- {Scatter.VisualizationParameter.BINNED_STATISTIC_FUNCTION: reduction}))
+ action.triggered.connect(
+ functools.partial(
+ self._visualizationChanged,
+ Scatter.Visualization.BINNED_STATISTIC,
+ {
+ Scatter.VisualizationParameter.BINNED_STATISTIC_FUNCTION: reduction
+ },
+ )
+ )
submenu.addAction(action)
submenu.addSeparator()
- binsmenu = submenu.addMenu('N Bins')
+ binsmenu = submenu.addMenu("N Bins")
slider = qt.QSlider(qt.Qt.Horizontal)
slider.setRange(10, 1000)
@@ -545,10 +507,10 @@ class ScatterVisualizationToolButton(_SymbolToolButtonBase):
menu.addSeparator()
- submenu = menu.addMenu(icons.getQIcon('plot-symbols'), "Symbol")
+ submenu = menu.addMenu(icons.getQIcon("plot-symbols"), "Symbol")
self._addSymbolsToMenu(submenu)
- submenu = menu.addMenu(icons.getQIcon('plot-symbols'), "Symbol Size")
+ submenu = menu.addMenu(icons.getQIcon("plot-symbols"), "Symbol Size")
self._addSizeSliderToMenu(submenu)
self.setMenu(menu)
@@ -587,5 +549,6 @@ class ScatterVisualizationToolButton(_SymbolToolButtonBase):
if isinstance(item, Scatter):
item.setVisualizationParameter(
Scatter.VisualizationParameter.BINNED_STATISTIC_SHAPE,
- (value, value))
+ (value, value),
+ )
item.setVisualization(Scatter.Visualization.BINNED_STATISTIC)
diff --git a/src/silx/gui/plot/PlotWidget.py b/src/silx/gui/plot/PlotWidget.py
index f07ef30..a01ca48 100755
--- a/src/silx/gui/plot/PlotWidget.py
+++ b/src/silx/gui/plot/PlotWidget.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2004-2022 European Synchrotron Radiation Facility
+# Copyright (c) 2004-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -25,6 +25,8 @@
The :class:`PlotWidget` implements the plot API initially provided in PyMca.
"""
+from __future__ import annotations
+
__authors__ = ["V.A. Sole", "T. Vincent"]
__license__ = "MIT"
__date__ = "21/12/2018"
@@ -34,20 +36,20 @@ import logging
_logger = logging.getLogger(__name__)
-from collections import OrderedDict, namedtuple
+from collections import namedtuple
+from collections.abc import Sequence
from contextlib import contextmanager
+from typing import Optional, Union
import datetime as dt
import itertools
import numbers
-import typing
import warnings
import numpy
import silx
from silx.utils.weakref import WeakMethodProxy
-from silx.utils.property import classproperty
-from silx.utils.deprecation import deprecated, deprecated_warning
+
try:
# Import matplotlib now to init matplotlib our way
import silx.gui.utils.matplotlib # noqa
@@ -68,17 +70,13 @@ from .items.axis import TickMode # noqa
from .. import qt
from ._utils.panzoom import ViewConstraints
from ...gui.plot._utils.dtime_ticklayout import timestamp
+from ...utils.deprecation import deprecated_warning
-
-_COLORDICT = colors.COLORDICT
-_COLORLIST = silx.config.DEFAULT_PLOT_CURVE_COLORS
-
"""
Object returned when requesting the data range.
"""
-_PlotDataRange = namedtuple('PlotDataRange',
- ['x', 'y', 'yright'])
+_PlotDataRange = namedtuple("PlotDataRange", ["x", "y", "yright"])
class _PlotWidgetSelection(qt.QObject):
@@ -104,10 +102,14 @@ class _PlotWidgetSelection(qt.QObject):
# Init history
self.__history = [ # Store active items from most recent to oldest
- item for item in (parent.getActiveCurve(),
- parent.getActiveImage(),
- parent.getActiveScatter())
- if item is not None]
+ item
+ for item in (
+ parent.getActiveCurve(),
+ parent.getActiveImage(),
+ parent.getActiveScatter(),
+ )
+ if item is not None
+ ]
self.__current = self.__mostRecentActiveItem()
@@ -115,11 +117,11 @@ class _PlotWidgetSelection(qt.QObject):
parent.sigActiveCurveChanged.connect(self._activeCurveChanged)
parent.sigActiveScatterChanged.connect(self._activeScatterChanged)
- def __mostRecentActiveItem(self) -> typing.Optional[items.Item]:
+ def __mostRecentActiveItem(self) -> Optional[items.Item]:
"""Returns most recent active item."""
return self.__history[0] if len(self.__history) >= 1 else None
- def getSelectedItems(self) -> typing.Tuple[items.Item]:
+ def getSelectedItems(self) -> tuple[items.Item]:
"""Returns the list of currently selected items in the :class:`PlotWidget`.
The list is given from most recently current item to oldest one."""
@@ -136,11 +138,11 @@ class _PlotWidgetSelection(qt.QObject):
return active
- def getCurrentItem(self) -> typing.Optional[items.Item]:
- """Returns the current item in the :class:`PlotWidget` or None. """
+ def getCurrentItem(self) -> Optional[items.Item]:
+ """Returns the current item in the :class:`PlotWidget` or None."""
return self.__current
- def setCurrentItem(self, item: typing.Optional[items.Item]):
+ def setCurrentItem(self, item: Optional[items.Item]):
"""Set the current item in the :class:`PlotWidget`.
:param item:
@@ -166,20 +168,21 @@ class _PlotWidgetSelection(qt.QObject):
elif isinstance(item, items.Item):
plot = self.parent()
if plot is None or item.getPlot() is not plot:
- raise ValueError(
- "Item is not in the PlotWidget: %s" % str(item))
+ raise ValueError("Item is not in the PlotWidget: %s" % str(item))
self.__current = item
kind = plot._itemKind(item)
# Clean-up history to be safe
- self.__history = [item for item in self.__history
- if PlotWidget._itemKind(item) != kind]
+ self.__history = [
+ item for item in self.__history if PlotWidget._itemKind(item) != kind
+ ]
# Sync active item if needed
- if (kind in plot._ACTIVE_ITEM_KINDS and
- item is not plot._getActiveItem(kind)):
- plot._setActiveItem(kind, item.getName())
+ if kind in plot._ACTIVE_ITEM_KINDS and item is not plot._getActiveItem(
+ kind
+ ):
+ plot._setActiveItem(kind, item)
else:
raise ValueError("Not an Item: %s" % str(item))
@@ -188,10 +191,9 @@ class _PlotWidgetSelection(qt.QObject):
if previousSelected != self.getSelectedItems():
self.sigSelectedItemsChanged.emit()
- def __activeItemChanged(self,
- kind: str,
- previous: typing.Optional[str],
- legend: typing.Optional[str]):
+ def __activeItemChanged(
+ self, kind: str, previous: Optional[str], legend: Optional[str]
+ ):
"""Set current item from kind and legend"""
if previous == legend:
return # No-op for update of item
@@ -203,8 +205,9 @@ class _PlotWidgetSelection(qt.QObject):
previousSelected = self.getSelectedItems()
# Remove items of this kind from the history
- self.__history = [item for item in self.__history
- if PlotWidget._itemKind(item) != kind]
+ self.__history = [
+ item for item in self.__history if PlotWidget._itemKind(item) != kind
+ ]
# Retrieve current item
if legend is None: # Use most recent active item
@@ -230,15 +233,15 @@ class _PlotWidgetSelection(qt.QObject):
def _activeImageChanged(self, previous, current):
"""Handle active image change"""
- self.__activeItemChanged('image', previous, current)
+ self.__activeItemChanged("image", previous, current)
def _activeCurveChanged(self, previous, current):
"""Handle active curve change"""
- self.__activeItemChanged('curve', previous, current)
+ self.__activeItemChanged("curve", previous, current)
def _activeScatterChanged(self, previous, current):
"""Handle active scatter change"""
- self.__activeItemChanged('scatter', previous, current)
+ self.__activeItemChanged("scatter", previous, current)
class PlotWidget(qt.QMainWindow):
@@ -260,15 +263,10 @@ class PlotWidget(qt.QMainWindow):
:type backend: str or :class:`BackendBase.BackendBase`
"""
- # TODO: Can be removed for silx 0.10
- @classproperty
- @deprecated(replacement="silx.config.DEFAULT_PLOT_BACKEND", since_version="0.8", skip_backtrace_count=2)
- def DEFAULT_BACKEND(self):
- """Class attribute setting the default backend for all instances."""
- return silx.config.DEFAULT_PLOT_BACKEND
-
- colorList = _COLORLIST
- colorDict = _COLORDICT
+ # The following 2 class attributes are no longer used
+ # but there is no way to warn about deprecation
+ colorList = silx.config.DEFAULT_PLOT_CURVE_COLORS
+ colorDict = colors.COLORDICT
sigPlotSignal = qt.Signal(object)
"""Signal for all events of the plot.
@@ -368,6 +366,9 @@ class PlotWidget(qt.QMainWindow):
It provides the menu which will be displayed.
"""
+ sigBackendChanged = qt.Signal()
+ """Signal emitted when the backend have changed."""
+
def __init__(self, parent=None, backend=None):
self._autoreplot = False
self._dirty = False
@@ -382,7 +383,7 @@ class PlotWidget(qt.QMainWindow):
# behave as a widget
self.setWindowFlags(qt.Qt.Widget)
else:
- self.setWindowTitle('PlotWidget')
+ self.setWindowTitle("PlotWidget")
# Init the backend
self._backend = self.__getBackendClass(backend)(self, self)
@@ -390,25 +391,28 @@ class PlotWidget(qt.QMainWindow):
self.setCallback() # set _callback
# Items handling
- self._content = OrderedDict()
- self._contentToUpdate = [] # Used as an OrderedSet
+ self.__items = []
+ self.__itemsToUpdate = [] # Used as an OrderedSet
+ self.__activeItems = {"curve": None, "image": None, "scatter": None}
self._dataRange = None
# line types
- self._styleList = ['-', '--', '-.', ':']
+ self._defaultColors = None
+ self._styleList = ["-", "--", "-.", ":"]
self._colorIndex = 0
self._styleIndex = 0
self._activeCurveSelectionMode = "atmostone"
- self._activeCurveStyle = CurveStyle(color='#000000')
- self._activeLegend = {'curve': None, 'image': None,
- 'scatter': None}
+ self._activeCurveStyle = CurveStyle(
+ color=silx.config.DEFAULT_PLOT_ACTIVE_CURVE_COLOR,
+ linewidth=silx.config.DEFAULT_PLOT_ACTIVE_CURVE_LINEWIDTH,
+ )
# plot colors (updated later to sync backend)
- self._foregroundColor = 0., 0., 0., 1.
- self._gridColor = .7, .7, .7, 1.
- self._backgroundColor = 1., 1., 1., 1.
+ self._foregroundColor = 0.0, 0.0, 0.0, 1.0
+ self._gridColor = 0.7, 0.7, 0.7, 1.0
+ self._backgroundColor = 1.0, 1.0, 1.0, 1.0
self._dataBackgroundColor = None
# default properties
@@ -419,18 +423,18 @@ class PlotWidget(qt.QMainWindow):
self._yRightAxis = items.YRightAxis(self, self._yAxis)
self._grid = None
- self._graphTitle = ''
- self.__graphCursorShape = 'default'
+ self._graphTitle = ""
+ self.__graphCursorShape = "default"
# Set axes margins
self.__axesDisplayed = True
- self.__axesMargins = 0., 0., 0., 0.
- self.setAxesMargins(.15, .1, .1, .15)
+ self.__axesMargins = 0.0, 0.0, 0.0, 0.0
+ self.setAxesMargins(0.15, 0.1, 0.1, 0.15)
self.setGraphTitle()
self.setGraphXLabel()
self.setGraphYLabel()
- self.setGraphYLabel('', axis='right')
+ self.setGraphYLabel("", axis="right")
self.setDefaultColormap() # Init default colormap
@@ -440,12 +444,14 @@ class PlotWidget(qt.QMainWindow):
self._limitsHistory = LimitsHistory(self)
self._eventHandler = PlotInteraction.PlotInteraction(self)
- self._eventHandler.setInteractiveMode('zoom', color=(0., 0., 0., 1.))
+ self._eventHandler._setInteractiveMode("zoom", color=(0.0, 0.0, 0.0, 1.0))
+ self._eventHandler.sigChanged.connect(self.__interactionChanged)
+ self.__isInteractionSignalForwarded = True
self._previousDefaultMode = "zoom", True
self._pressedButtons = [] # Currently pressed mouse buttons
- self._defaultDataMargins = (0., 0., 0., 0.)
+ self._defaultDataMargins = (0.0, 0.0, 0.0, 0.0)
# Only activate autoreplot at the end
# This avoids errors when loaded in Qt designer
@@ -462,9 +468,9 @@ class PlotWidget(qt.QMainWindow):
self.setFocus(qt.Qt.OtherFocusReason)
# Set default limits
- self.setGraphXLimits(0., 100.)
- self.setGraphYLimits(0., 100., axis='right')
- self.setGraphYLimits(0., 100., axis='left')
+ self.setGraphXLimits(0.0, 100.0)
+ self.setGraphYLimits(0.0, 100.0, axis="right")
+ self.setGraphYLimits(0.0, 100.0, axis="left")
# Sync backend colors with default ones
self._foregroundColorsUpdated()
@@ -492,30 +498,32 @@ class PlotWidget(qt.QMainWindow):
elif isinstance(backend, str):
backend = backend.lower()
- if backend in ('matplotlib', 'mpl'):
+ if backend in ("matplotlib", "mpl"):
try:
- from .backends.BackendMatplotlib import \
- BackendMatplotlibQt as backendClass
+ from .backends.BackendMatplotlib import (
+ BackendMatplotlibQt as backendClass,
+ )
except ImportError:
_logger.debug("Backtrace", exc_info=True)
raise RuntimeError("matplotlib backend is not available")
- elif backend in ('gl', 'opengl'):
+ elif backend in ("gl", "opengl"):
from ..utils.glutils import isOpenGLAvailable
+
checkOpenGL = isOpenGLAvailable(version=(2, 1), runtimeCheck=False)
if not checkOpenGL:
_logger.debug("OpenGL check failed")
raise RuntimeError(
- "OpenGL backend is not available: %s" % checkOpenGL.error)
+ "OpenGL backend is not available: %s" % checkOpenGL.error
+ )
try:
- from .backends.BackendOpenGL import \
- BackendOpenGL as backendClass
+ from .backends.BackendOpenGL import BackendOpenGL as backendClass
except ImportError:
_logger.debug("Backtrace", exc_info=True)
raise RuntimeError("OpenGL backend is not available")
- elif backend == 'none':
+ elif backend == "none":
from .backends.BackendBase import BackendBase as backendClass
else:
@@ -540,20 +548,6 @@ class PlotWidget(qt.QMainWindow):
self.__selection = _PlotWidgetSelection(parent=self)
return self.__selection
- # TODO: Can be removed for silx 0.10
- @staticmethod
- @deprecated(replacement="silx.config.DEFAULT_PLOT_BACKEND", since_version="0.8", skip_backtrace_count=2)
- def setDefaultBackend(backend):
- """Set system wide default plot backend.
-
- .. versionadded:: 0.6
-
- :param backend: The backend to use, in:
- 'matplotlib' (default), 'mpl', 'opengl', 'gl', 'none'
- or a :class:`BackendBase.BackendBase` class
- """
- silx.config.DEFAULT_PLOT_BACKEND = backend
-
def setBackend(self, backend):
"""Set the backend to use for rendering.
@@ -576,8 +570,8 @@ class PlotWidget(qt.QMainWindow):
# First save state that is stored in the backend
xaxis = self.getXAxis()
xmin, xmax = xaxis.getLimits()
- ymin, ymax = self.getYAxis(axis='left').getLimits()
- y2min, y2max = self.getYAxis(axis='right').getLimits()
+ ymin, ymax = self.getYAxis(axis="left").getLimits()
+ y2min, y2max = self.getYAxis(axis="right").getLimits()
isKeepDataAspectRatio = self.isKeepDataAspectRatio()
xTimeZone = xaxis.getTimeZone()
isXAxisTimeSeries = xaxis.getTickMode() == TickMode.TIME_SERIES
@@ -606,7 +600,7 @@ class PlotWidget(qt.QMainWindow):
self._backend.setGraphCursorShape(self.getGraphCursorShape())
crosshairConfig = self.getGraphCursor()
if crosshairConfig is None:
- self._backend.setGraphCursor(False, 'black', 1, '-')
+ self._backend.setGraphCursor(False, "black", 1, "-")
else:
self._backend.setGraphCursor(True, *crosshairConfig)
@@ -615,21 +609,21 @@ class PlotWidget(qt.QMainWindow):
if self.isAxesDisplayed():
self._backend.setAxesMargins(*self.getAxesMargins())
else:
- self._backend.setAxesMargins(0., 0., 0., 0.)
+ self._backend.setAxesMargins(0.0, 0.0, 0.0, 0.0)
# Set axes
xaxis = self.getXAxis()
self._backend.setGraphXLabel(xaxis.getLabel())
self._backend.setXAxisTimeZone(xTimeZone)
self._backend.setXAxisTimeSeries(isXAxisTimeSeries)
- self._backend.setXAxisLogarithmic(
- xaxis.getScale() == items.Axis.LOGARITHMIC)
+ self._backend.setXAxisLogarithmic(xaxis.getScale() == items.Axis.LOGARITHMIC)
- for axis in ('left', 'right'):
+ for axis in ("left", "right"):
self._backend.setGraphYLabel(self.getYAxis(axis).getLabel(), axis)
self._backend.setYAxisInverted(isYAxisInverted)
self._backend.setYAxisLogarithmic(
- self.getYAxis().getScale() == items.Axis.LOGARITHMIC)
+ self.getYAxis().getScale() == items.Axis.LOGARITHMIC
+ )
# Finally restore aspect ratio and limits
self._backend.setKeepDataAspectRatio(isKeepDataAspectRatio)
@@ -639,6 +633,8 @@ class PlotWidget(qt.QMainWindow):
for item in self.getItems():
item._updated()
+ self.sigBackendChanged.emit()
+
def getBackend(self):
"""Returns the backend currently used by :class:`PlotWidget`.
@@ -665,12 +661,16 @@ class PlotWidget(qt.QMainWindow):
"""Override QWidget.contextMenuEvent to implement the context menu"""
menu = qt.QMenu(self)
from .actions.control import ZoomBackAction # Avoid cyclic import
+
zoomBackAction = ZoomBackAction(plot=self, parent=menu)
menu.addAction(zoomBackAction)
mode = self.getInteractiveMode()
if "shape" in mode and mode["shape"] == "polygon":
- from .actions.control import ClosePolygonInteractionAction # Avoid cyclic import
+ from .actions.control import (
+ ClosePolygonInteractionAction,
+ ) # Avoid cyclic import
+
action = ClosePolygonInteractionAction(plot=self, parent=menu)
menu.addAction(action)
@@ -691,7 +691,7 @@ class PlotWidget(qt.QMainWindow):
wasDirty = self._dirty
if not self._dirty and overlayOnly:
- self._dirty = 'overlay'
+ self._dirty = "overlay"
else:
self._dirty = True
@@ -704,8 +704,7 @@ class PlotWidget(qt.QMainWindow):
gridColor = self._foregroundColor
else:
gridColor = self._gridColor
- self._backend.setForegroundColors(
- self._foregroundColor, gridColor)
+ self._backend.setForegroundColors(self._foregroundColor, gridColor)
self._setDirtyPlot()
def getForegroundColor(self):
@@ -759,8 +758,7 @@ class PlotWidget(qt.QMainWindow):
dataBGColor = self._backgroundColor
else:
dataBGColor = self._dataBackgroundColor
- self._backend.setBackgroundColors(
- self._backgroundColor, dataBGColor)
+ self._backend.setBackgroundColors(self._backgroundColor, dataBGColor)
self._setDirtyPlot()
def getBackgroundColor(self):
@@ -829,7 +827,14 @@ class PlotWidget(qt.QMainWindow):
def hideEvent(self, event):
super(PlotWidget, self).hideEvent(event)
- self.sigVisibilityChanged.emit(False)
+ if qt.BINDING == "PySide6":
+ # Workaround RuntimeError: The SignalInstance object was already deleted
+ try:
+ self.sigVisibilityChanged.emit(False)
+ except RuntimeError as e:
+ _logger.error(f"Exception occured: {e}")
+ else:
+ self.sigVisibilityChanged.emit(False)
def _invalidateDataRange(self):
"""
@@ -842,42 +847,43 @@ class PlotWidget(qt.QMainWindow):
"""
Recomputes the range of the data displayed on this PlotWidget.
"""
- xMin = yMinLeft = yMinRight = float('nan')
- xMax = yMaxLeft = yMaxRight = float('nan')
+ xMin = yMinLeft = yMinRight = float("nan")
+ xMax = yMaxLeft = yMaxRight = float("nan")
for item in self.getItems():
if item.isVisible():
bounds = item.getBounds()
if bounds is not None:
with warnings.catch_warnings():
- warnings.simplefilter('ignore', category=RuntimeWarning)
+ warnings.simplefilter("ignore", category=RuntimeWarning)
# Ignore All-NaN slice encountered
xMin = numpy.nanmin([xMin, bounds[0]])
xMax = numpy.nanmax([xMax, bounds[1]])
# Take care of right axis
- if (isinstance(item, items.YAxisMixIn) and
- item.getYAxis() == 'right'):
+ if (
+ isinstance(item, items.YAxisMixIn)
+ and item.getYAxis() == "right"
+ ):
with warnings.catch_warnings():
- warnings.simplefilter('ignore', category=RuntimeWarning)
+ warnings.simplefilter("ignore", category=RuntimeWarning)
# Ignore All-NaN slice encountered
yMinRight = numpy.nanmin([yMinRight, bounds[2]])
yMaxRight = numpy.nanmax([yMaxRight, bounds[3]])
else:
with warnings.catch_warnings():
- warnings.simplefilter('ignore', category=RuntimeWarning)
+ warnings.simplefilter("ignore", category=RuntimeWarning)
# Ignore All-NaN slice encountered
yMinLeft = numpy.nanmin([yMinLeft, bounds[2]])
yMaxLeft = numpy.nanmax([yMaxLeft, bounds[3]])
def lGetRange(x, y):
return None if numpy.isnan(x) and numpy.isnan(y) else (x, y)
+
xRange = lGetRange(xMin, xMax)
yLeftRange = lGetRange(yMinLeft, yMaxLeft)
yRightRange = lGetRange(yMinRight, yMaxRight)
- self._dataRange = _PlotDataRange(x=xRange,
- y=yLeftRange,
- yright=yRightRange)
+ self._dataRange = _PlotDataRange(x=xRange, y=yLeftRange, yright=yRightRange)
def getDataRange(self):
"""
@@ -895,19 +901,19 @@ class PlotWidget(qt.QMainWindow):
# Content management
_KIND_TO_CLASSES = {
- 'curve': (items.Curve,),
- 'image': (items.ImageBase,),
- 'scatter': (items.Scatter,),
- 'marker': (items.MarkerBase,),
- 'item': (
+ "curve": (items.Curve,),
+ "image": (items.ImageBase,),
+ "scatter": (items.Scatter,),
+ "marker": (items.MarkerBase,),
+ "item": (
items.Line,
items.Shape,
items.BoundingRect,
items.XAxisExtent,
items.YAxisExtent,
),
- 'histogram': (items.Histogram,),
- }
+ "histogram": (items.Histogram,),
+ }
"""Mapping kind to item classes of this kind"""
@classmethod
@@ -920,11 +926,15 @@ class PlotWidget(qt.QMainWindow):
for kind, itemClasses in cls._KIND_TO_CLASSES.items():
if isinstance(item, itemClasses):
return kind
- raise ValueError('Unsupported item type %s' % type(item))
+ return "other"
def _notifyContentChanged(self, item):
- self.notify('contentChanged', action='add',
- kind=self._itemKind(item), legend=item.getName())
+ self.notify(
+ "contentChanged",
+ action="add",
+ kind=self._itemKind(item),
+ legend=item.getName(),
+ )
def _itemRequiresUpdate(self, item):
"""Called by items in the plot for asynchronous update
@@ -933,34 +943,25 @@ class PlotWidget(qt.QMainWindow):
"""
assert item.getPlot() == self
# Put item at the end of the list
- if item in self._contentToUpdate:
- self._contentToUpdate.remove(item)
- self._contentToUpdate.append(item)
+ if item in self.__itemsToUpdate:
+ self.__itemsToUpdate.remove(item)
+ self.__itemsToUpdate.append(item)
self._setDirtyPlot(overlayOnly=item.isOverlay())
- def addItem(self, item=None, *args, **kwargs):
+ def addItem(self, item):
"""Add an item to the plot content.
:param ~silx.gui.plot.items.Item item: The item to add.
:raises ValueError: If item is already in the plot.
"""
if not isinstance(item, items.Item):
- deprecated_warning(
- 'Function',
- 'addItem',
- replacement='addShape',
- since_version='0.13')
- if item is None and not args: # Only kwargs
- return self.addShape(**kwargs)
- else:
- return self.addShape(item, *args, **kwargs)
+ raise ValueError(f"argument must be a subclass of Item")
- assert not args and not kwargs
if item in self.getItems():
- raise ValueError('Item already in the plot')
+ raise ValueError("Item already in the plot")
# Add item to plot
- self._content[(item.getName(), self._itemKind(item))] = item
+ self.__items.append(item)
item._setPlot(self)
self._itemRequiresUpdate(item)
if isinstance(item, items.DATA_ITEMS):
@@ -975,19 +976,11 @@ class PlotWidget(qt.QMainWindow):
:param ~silx.gui.plot.items.Item item: Item to remove from the plot.
:raises ValueError: If item is not in the plot.
"""
- if not isinstance(item, items.Item): # Previous method usage
- deprecated_warning(
- 'Function',
- 'removeItem',
- replacement='remove(legend, kind="item")',
- since_version='0.13')
- if item is None:
- return
- self.remove(item, kind='item')
- return
+ if not isinstance(item, items.Item):
+ raise ValueError("argument must be an Item")
if item not in self.getItems():
- raise ValueError('Item not in the plot')
+ raise ValueError("Item not in the plot")
self.sigItemAboutToBeRemoved.emit(item)
@@ -999,9 +992,9 @@ class PlotWidget(qt.QMainWindow):
self._setActiveItem(kind, None)
# Remove item from plot
- self._content.pop((item.getName(), kind))
- if item in self._contentToUpdate:
- self._contentToUpdate.remove(item)
+ self.__items.remove(item)
+ if item in self.__itemsToUpdate:
+ self.__itemsToUpdate.remove(item)
if item.isVisible():
self._setDirtyPlot(overlayOnly=item.isOverlay())
if item.getBounds() is not None:
@@ -1009,14 +1002,12 @@ class PlotWidget(qt.QMainWindow):
item._removeBackendRenderer(self._backend)
item._setPlot(None)
- if (kind == 'curve' and not self.getAllCurves(just_legend=True,
- withhidden=True)):
+ if kind == "curve" and not self.getAllCurves(just_legend=True, withhidden=True):
self._resetColorAndStyle()
self.sigItemRemoved.emit(item)
- self.notify('contentChanged', action='remove',
- kind=kind, legend=item.getName())
+ self.notify("contentChanged", action="remove", kind=kind, legend=item.getName())
def discardItem(self, item) -> bool:
"""Remove the item from the plot.
@@ -1033,20 +1024,12 @@ class PlotWidget(qt.QMainWindow):
else:
return True
- @deprecated(replacement='addItem', since_version='0.13')
- def _add(self, item):
- return self.addItem(item)
-
- @deprecated(replacement='removeItem', since_version='0.13')
- def _remove(self, item):
- return self.removeItem(item)
-
def getItems(self):
"""Returns the list of items in the plot
:rtype: List[silx.gui.plot.items.Item]
"""
- return tuple(self._content.values())
+ return tuple(self.__items)
@contextmanager
def _muteActiveItemChangedSignal(self):
@@ -1064,15 +1047,30 @@ class PlotWidget(qt.QMainWindow):
# Store used value.
# This value is used when curve is updated either internally or by user.
- def addCurve(self, x, y, legend=None, info=None,
- replace=False,
- color=None, symbol=None,
- linewidth=None, linestyle=None,
- xlabel=None, ylabel=None, yaxis=None,
- xerror=None, yerror=None, z=None, selectable=None,
- fill=None, resetzoom=True,
- histogram=None, copy=True,
- baseline=None):
+ def addCurve(
+ self,
+ x,
+ y,
+ legend=None,
+ info=None,
+ replace=False,
+ color=None,
+ symbol=None,
+ linewidth=None,
+ linestyle=None,
+ xlabel=None,
+ ylabel=None,
+ yaxis=None,
+ xerror=None,
+ yerror=None,
+ z=None,
+ selectable=None,
+ fill=None,
+ resetzoom=True,
+ histogram=None,
+ copy=True,
+ baseline=None,
+ ):
"""Add a 1D curve given by x an y to the graph.
Curves are uniquely identified by their legend.
@@ -1155,18 +1153,19 @@ class PlotWidget(qt.QMainWindow):
False to use provided arrays.
:param baseline: curve baseline
:type: Union[None,float,numpy.ndarray]
- :returns: The key string identify this curve
+ :returns: The curve item
"""
# This is an histogram, use addHistogram
if histogram is not None:
- histoLegend = self.addHistogram(histogram=y,
- edges=x,
- legend=legend,
- color=color,
- fill=fill,
- align=histogram,
- copy=copy)
- histo = self.getHistogram(histoLegend)
+ histo = self.addHistogram(
+ histogram=y,
+ edges=x,
+ legend=legend,
+ color=color,
+ fill=fill,
+ align=histogram,
+ copy=copy,
+ )
histo.setInfo(info)
if linewidth is not None:
@@ -1174,25 +1173,21 @@ class PlotWidget(qt.QMainWindow):
if linestyle is not None:
histo.setLineStyle(linestyle)
if xlabel is not None:
- _logger.warning(
- 'addCurve: Histogram does not support xlabel argument')
+ _logger.warning("addCurve: Histogram does not support xlabel argument")
if ylabel is not None:
- _logger.warning(
- 'addCurve: Histogram does not support ylabel argument')
+ _logger.warning("addCurve: Histogram does not support ylabel argument")
if yaxis is not None:
histo.setYAxis(yaxis)
if z is not None:
histo.setZValue(z)
if selectable is not None:
_logger.warning(
- 'addCurve: Histogram does not support selectable argument')
-
- return
+ "addCurve: Histogram does not support selectable argument"
+ )
- legend = 'Unnamed curve 1.1' if legend is None else str(legend)
+ return histo
- # Check if curve was previously active
- wasActive = self.getActiveCurve(just_legend=True) == legend
+ legend = "Unnamed curve 1.1" if legend is None else str(legend)
if replace:
self._resetColorAndStyle()
@@ -1217,7 +1212,11 @@ class PlotWidget(qt.QMainWindow):
# Override previous/default values with provided ones
curve.setInfo(info)
if color is not None:
- curve.setColor(color)
+ curve.setColor(
+ colors.rgba(color, colors=self.getDefaultColors())
+ if isinstance(color, str)
+ else color
+ )
if symbol is not None:
curve.setSymbol(symbol)
if linewidth is not None:
@@ -1264,14 +1263,13 @@ class PlotWidget(qt.QMainWindow):
else:
self._notifyContentChanged(curve)
- if wasActive:
- self.setActiveCurve(curve.getName())
- elif self.getActiveCurveSelectionMode() == "legacy":
- if self.getActiveCurve(just_legend=True) is None:
- if len(self.getAllCurves(just_legend=True,
- withhidden=False)) == 1:
- if curve.isVisible():
- self.setActiveCurve(curve.getName())
+ if curve is self.getActiveCurve() or (
+ self.getActiveCurveSelectionMode() == "legacy"
+ and self.getActiveCurve() is None
+ and len(self.getAllCurves(just_legend=True, withhidden=False)) == 1
+ and curve.isVisible()
+ ):
+ self.setActiveCurve(curve)
if resetzoom:
# We ask for a zoom reset in order to handle the plot scaling
@@ -1279,19 +1277,21 @@ class PlotWidget(qt.QMainWindow):
# axes has to be set to off.
self.resetZoom()
- return legend
-
- def addHistogram(self,
- histogram,
- edges,
- legend=None,
- color=None,
- fill=None,
- align='center',
- resetzoom=True,
- copy=True,
- z=None,
- baseline=None):
+ return curve
+
+ def addHistogram(
+ self,
+ histogram,
+ edges,
+ legend=None,
+ color=None,
+ fill=None,
+ align="center",
+ resetzoom=True,
+ copy=True,
+ z=None,
+ baseline=None,
+ ):
"""Add an histogram to the graph.
This is NOT computing the histogram, this method takes as parameter
@@ -1325,9 +1325,9 @@ class PlotWidget(qt.QMainWindow):
:param int z: Layer on which to draw the histogram
:param baseline: histogram baseline
:type: Union[None,float,numpy.ndarray]
- :returns: The key string identify this histogram
+ :returns: The histogram item
"""
- legend = 'Unnamed histogram' if legend is None else str(legend)
+ legend = "Unnamed histogram" if legend is None else str(legend)
# Create/Update histogram object
histo = self.getHistogram(legend)
@@ -1341,15 +1341,20 @@ class PlotWidget(qt.QMainWindow):
# Override previous/default values with provided ones
if color is not None:
- histo.setColor(color)
+ histo.setColor(
+ colors.rgba(color, colors=self.getDefaultColors())
+ if isinstance(color, str)
+ else color
+ )
if fill is not None:
histo.setFill(fill)
if z is not None:
histo.setZValue(z=z)
# Set histogram data
- histo.setData(histogram=histogram, edges=edges, baseline=baseline,
- align=align, copy=copy)
+ histo.setData(
+ histogram=histogram, edges=edges, baseline=baseline, align=align, copy=copy
+ )
if mustBeAdded:
self.addItem(histo)
@@ -1362,16 +1367,26 @@ class PlotWidget(qt.QMainWindow):
# axes has to be set to off.
self.resetZoom()
- return legend
-
- def addImage(self, data, legend=None, info=None,
- replace=False,
- z=None,
- selectable=None, draggable=None,
- colormap=None, pixmap=None,
- xlabel=None, ylabel=None,
- origin=None, scale=None,
- resetzoom=True, copy=True):
+ return histo
+
+ def addImage(
+ self,
+ data,
+ legend=None,
+ info=None,
+ replace=False,
+ z=None,
+ selectable=None,
+ draggable=None,
+ colormap=None,
+ pixmap=None,
+ xlabel=None,
+ ylabel=None,
+ origin=None,
+ scale=None,
+ resetzoom=True,
+ copy=True,
+ ):
"""Add a 2D dataset or an image to the plot.
It displays either an array of data using a colormap or a RGB(A) image.
@@ -1421,13 +1436,10 @@ class PlotWidget(qt.QMainWindow):
:param bool resetzoom: True (the default) to reset the zoom.
:param bool copy: True make a copy of the data (default),
False to use provided arrays.
- :returns: The key string identify this image
+ :returns: The image item
"""
legend = "Unnamed Image 1.1" if legend is None else str(legend)
- # Check if image was previously active
- wasActive = self.getActiveImage(just_legend=True) == legend
-
data = numpy.array(data, copy=False)
assert data.ndim in (2, 3)
@@ -1480,7 +1492,8 @@ class PlotWidget(qt.QMainWindow):
else: # RGB(A) image
if pixmap is not None:
_logger.warning(
- 'addImage: pixmap argument ignored when data is RGB(A)')
+ "addImage: pixmap argument ignored when data is RGB(A)"
+ )
image.setData(data, copy=copy)
if replace:
@@ -1493,8 +1506,8 @@ class PlotWidget(qt.QMainWindow):
else:
self._notifyContentChanged(image)
- if len(self.getAllImages()) == 1 or wasActive:
- self.setActiveImage(legend)
+ if len(self.getAllImages()) == 1 or image is self.getActiveImage():
+ self.setActiveImage(image)
if resetzoom:
# We ask for a zoom reset in order to handle the plot scaling
@@ -1502,11 +1515,22 @@ class PlotWidget(qt.QMainWindow):
# axes has to be set to off.
self.resetZoom()
- return legend
-
- def addScatter(self, x, y, value, legend=None, colormap=None,
- info=None, symbol=None, xerror=None, yerror=None,
- z=None, copy=True):
+ return image
+
+ def addScatter(
+ self,
+ x,
+ y,
+ value,
+ legend=None,
+ colormap=None,
+ info=None,
+ symbol=None,
+ xerror=None,
+ yerror=None,
+ z=None,
+ copy=True,
+ ):
"""Add a (x, y, value) scatter to the graph.
Scatters are uniquely identified by their legend.
@@ -1549,16 +1573,12 @@ class PlotWidget(qt.QMainWindow):
:param bool copy: True make a copy of the data (default),
False to use provided arrays.
- :returns: The key string identify this scatter
+ :returns: The scatter item
"""
- legend = 'Unnamed scatter 1.1' if legend is None else str(legend)
-
- # Check if scatter was previously active
- wasActive = self._getActiveItem(kind='scatter',
- just_legend=True) == legend
+ legend = "Unnamed scatter 1.1" if legend is None else str(legend)
# Create/Update curve object
- scatter = self._getItem(kind='scatter', legend=legend)
+ scatter = self._getItem(kind="scatter", legend=legend)
mustBeAdded = scatter is None
if scatter is None:
# No previous scatter, create a default one and add it to the plot
@@ -1600,18 +1620,33 @@ class PlotWidget(qt.QMainWindow):
else:
self._notifyContentChanged(scatter)
- scatters = [item for item in self.getItems()
- if isinstance(item, items.Scatter) and item.isVisible()]
- if len(scatters) == 1 or wasActive:
- self._setActiveItem('scatter', scatter.getName())
-
- return legend
-
- def addShape(self, xdata, ydata, legend=None, info=None,
- replace=False,
- shape="polygon", color='black', fill=True,
- overlay=False, z=None, linestyle="-", linewidth=1.0,
- linebgcolor=None):
+ scatters = [
+ item
+ for item in self.getItems()
+ if isinstance(item, items.Scatter) and item.isVisible()
+ ]
+ if len(scatters) == 1 or scatter is self.getActiveScatter():
+ self.setActiveScatter(scatter)
+
+ return scatter
+
+ def addShape(
+ self,
+ xdata,
+ ydata,
+ legend=None,
+ info=None,
+ replace=False,
+ shape="polygon",
+ color="black",
+ fill=True,
+ overlay=False,
+ z=None,
+ linestyle="-",
+ linewidth=1.0,
+ linebgcolor="deprecated",
+ gapcolor=None,
+ ):
"""Add an item (i.e. a shape) to the plot.
Items are uniquely identified by their legend.
@@ -1624,7 +1659,8 @@ class PlotWidget(qt.QMainWindow):
:param numpy.ndarray ydata: The Y coords of the points of the shape
:param str legend: The legend to be associated to the item
:param info: User-defined information associated to the item
- :param bool replace: True (default) to delete already existing images
+ :param bool replace: True to delete already existing items
+ (the default is False)
:param str shape: Type of item to be drawn in
hline, polygon (the default), rectangle, vline,
polylines
@@ -1646,9 +1682,9 @@ class PlotWidget(qt.QMainWindow):
- ':' dotted line
:param float linewidth: Width of the line.
Only relevant for line markers where X or Y is None.
- :param str linebgcolor: Background color of the line, e.g., 'blue', 'b',
+ :param str gapcolor: Gap color of the line, e.g., 'blue', 'b',
'#FF0000'. It is used to draw dotted line using a second color.
- :returns: The key string identify this item
+ :returns: The shape item
"""
# expected to receive the same parameters as the signal
@@ -1657,9 +1693,9 @@ class PlotWidget(qt.QMainWindow):
z = int(z) if z is not None else 2
if replace:
- self.remove(kind='item')
+ self.remove(kind="item")
else:
- self.remove(legend, kind='item')
+ self.remove(legend, kind="item")
item = items.Shape(shape)
item.setName(legend)
@@ -1671,19 +1707,31 @@ class PlotWidget(qt.QMainWindow):
item.setPoints(numpy.array((xdata, ydata)).T)
item.setLineStyle(linestyle)
item.setLineWidth(linewidth)
- item.setLineBgColor(linebgcolor)
+ if linebgcolor != "deprecated":
+ deprecated_warning(
+ type_="Argument",
+ name="linebgcolor",
+ replacement="gapcolor",
+ since_version="2.0.0",
+ )
+ gapcolor = linebgcolor if gapcolor is None else gapcolor
+ item.setLineGapColor(gapcolor)
self.addItem(item)
- return legend
-
- def addXMarker(self, x, legend=None,
- text=None,
- color=None,
- selectable=False,
- draggable=False,
- constraint=None,
- yaxis='left'):
+ return item
+
+ def addXMarker(
+ self,
+ x,
+ legend=None,
+ text=None,
+ color=None,
+ selectable=False,
+ draggable=False,
+ constraint=None,
+ yaxis="left",
+ ):
"""Add a vertical line marker to the plot.
Markers are uniquely identified by their legend.
@@ -1710,22 +1758,32 @@ class PlotWidget(qt.QMainWindow):
the current cursor position in the plot as input
and that returns the filtered coordinates.
:param str yaxis: The Y axis this marker belongs to in: 'left', 'right'
- :return: The key string identify this marker
- """
- return self._addMarker(x=x, y=None, legend=legend,
- text=text, color=color,
- selectable=selectable, draggable=draggable,
- symbol=None, constraint=constraint,
- yaxis=yaxis)
-
- def addYMarker(self, y,
- legend=None,
- text=None,
- color=None,
- selectable=False,
- draggable=False,
- constraint=None,
- yaxis='left'):
+ :return: The marker item
+ """
+ return self._addMarker(
+ x=x,
+ y=None,
+ legend=legend,
+ text=text,
+ color=color,
+ selectable=selectable,
+ draggable=draggable,
+ symbol=None,
+ constraint=constraint,
+ yaxis=yaxis,
+ )
+
+ def addYMarker(
+ self,
+ y,
+ legend=None,
+ text=None,
+ color=None,
+ selectable=False,
+ draggable=False,
+ constraint=None,
+ yaxis="left",
+ ):
"""Add a horizontal line marker to the plot.
Markers are uniquely identified by their legend.
@@ -1752,22 +1810,34 @@ class PlotWidget(qt.QMainWindow):
the current cursor position in the plot as input
and that returns the filtered coordinates.
:param str yaxis: The Y axis this marker belongs to in: 'left', 'right'
- :return: The key string identify this marker
- """
- return self._addMarker(x=None, y=y, legend=legend,
- text=text, color=color,
- selectable=selectable, draggable=draggable,
- symbol=None, constraint=constraint,
- yaxis=yaxis)
-
- def addMarker(self, x, y, legend=None,
- text=None,
- color=None,
- selectable=False,
- draggable=False,
- symbol='+',
- constraint=None,
- yaxis='left'):
+ :return: The marker item
+ """
+ return self._addMarker(
+ x=None,
+ y=y,
+ legend=legend,
+ text=text,
+ color=color,
+ selectable=selectable,
+ draggable=draggable,
+ symbol=None,
+ constraint=constraint,
+ yaxis=yaxis,
+ )
+
+ def addMarker(
+ self,
+ x,
+ y,
+ legend=None,
+ text=None,
+ color=None,
+ selectable=False,
+ draggable=False,
+ symbol="+",
+ constraint=None,
+ yaxis="left",
+ ):
"""Add a point marker to the plot.
Markers are uniquely identified by their legend.
@@ -1806,7 +1876,7 @@ class PlotWidget(qt.QMainWindow):
the current cursor position in the plot as input
and that returns the filtered coordinates.
:param str yaxis: The Y axis this marker belongs to in: 'left', 'right'
- :return: The key string identify this marker
+ :return: The marker item
"""
if x is None:
xmin, xmax = self._xAxis.getLimits()
@@ -1816,17 +1886,32 @@ class PlotWidget(qt.QMainWindow):
ymin, ymax = self._yAxis.getLimits()
y = 0.5 * (ymax + ymin)
- return self._addMarker(x=x, y=y, legend=legend,
- text=text, color=color,
- selectable=selectable, draggable=draggable,
- symbol=symbol, constraint=constraint,
- yaxis=yaxis)
-
- def _addMarker(self, x, y, legend,
- text, color,
- selectable, draggable,
- symbol, constraint,
- yaxis=None):
+ return self._addMarker(
+ x=x,
+ y=y,
+ legend=legend,
+ text=text,
+ color=color,
+ selectable=selectable,
+ draggable=draggable,
+ symbol=symbol,
+ constraint=constraint,
+ yaxis=yaxis,
+ )
+
+ def _addMarker(
+ self,
+ x,
+ y,
+ legend,
+ text,
+ color,
+ selectable,
+ draggable,
+ symbol,
+ constraint,
+ yaxis=None,
+ ):
"""Common method for adding point, vline and hline marker.
See :meth:`addMarker` for argument documentation.
@@ -1834,8 +1919,11 @@ class PlotWidget(qt.QMainWindow):
assert (x, y) != (None, None)
if legend is None: # Find an unused legend
- markerLegends = [item.getName() for item in self.getItems()
- if isinstance(item, items.MarkerBase)]
+ markerLegends = [
+ item.getName()
+ for item in self.getItems()
+ if isinstance(item, items.MarkerBase)
+ ]
for index in itertools.count():
legend = "Unnamed Marker %d" % index
if legend not in markerLegends:
@@ -1852,8 +1940,9 @@ class PlotWidget(qt.QMainWindow):
# Create/Update marker object
marker = self._getMarker(legend)
if marker is not None and not isinstance(marker, markerClass):
- _logger.warning('Adding marker with same legend'
- ' but different type replaces it')
+ _logger.warning(
+ "Adding marker with same legend" " but different type replaces it"
+ )
self.removeItem(marker)
marker = None
@@ -1886,7 +1975,7 @@ class PlotWidget(qt.QMainWindow):
else:
self._notifyContentChanged(marker)
- return legend
+ return marker
# Hide
@@ -1896,7 +1985,7 @@ class PlotWidget(qt.QMainWindow):
:param str legend: The legend key identifying the curve
:return: True if the associated curve is hidden, False otherwise
"""
- curve = self._getItem('curve', legend)
+ curve = self._getItem("curve", legend)
return curve is not None and not curve.isVisible()
def hideCurve(self, legend, flag=True):
@@ -1907,9 +1996,9 @@ class PlotWidget(qt.QMainWindow):
:param str legend: The legend associated to the curve to be hidden
:param bool flag: True (default) to hide the curve, False to show it
"""
- curve = self._getItem('curve', legend)
+ curve = self._getItem("curve", legend)
if curve is None:
- _logger.warning('Curve not in plot: %s', legend)
+ _logger.warning("Curve not in plot: %s", legend)
return
isVisible = not flag
@@ -1918,13 +2007,17 @@ class PlotWidget(qt.QMainWindow):
# Remove
- ITEM_KINDS = 'curve', 'image', 'scatter', 'item', 'marker', 'histogram'
+ ITEM_KINDS = "curve", "image", "scatter", "item", "marker", "histogram"
"""List of supported kind of items in the plot."""
- _ACTIVE_ITEM_KINDS = 'curve', 'scatter', 'image'
+ _ACTIVE_ITEM_KINDS = "curve", "scatter", "image"
"""List of item's kind which have a active item."""
- def remove(self, legend=None, kind=ITEM_KINDS):
+ def remove(
+ self,
+ legend: str | items.Item | None = None,
+ kind: str | Sequence[str] = ITEM_KINDS,
+ ):
"""Remove one or all element(s) of the given legend and kind.
Examples:
@@ -1938,14 +2031,17 @@ class PlotWidget(qt.QMainWindow):
- ``remove('myImage')`` removes elements (for instance curve, image,
item and marker) with legend 'myImage'.
- :param str legend: The legend associated to the element to remove,
- or None to remove
- :param kind: The kind of elements to remove from the plot.
+ :param legend:
+ The legend of the item to remove or the item itself.
+ If None all items of given kind are removed.
+ :param kind: The kind of items to remove from the plot.
See :attr:`ITEM_KINDS`.
By default, it removes all kind of elements.
- :type kind: str or tuple of str to specify multiple kinds.
"""
- if kind == 'all': # Replace all by tuple of all kinds
+ if isinstance(legend, items.Item):
+ return self.removeItem(legend)
+
+ if kind == "all": # Replace all by tuple of all kinds
kind = self.ITEM_KINDS
if kind in self.ITEM_KINDS: # Kind is a str, make it a tuple
@@ -1958,8 +2054,10 @@ class PlotWidget(qt.QMainWindow):
# Clear each given kind
for aKind in kind:
for item in self.getItems():
- if (isinstance(item, self._KIND_TO_CLASSES[aKind]) and
- item.getPlot() is self): # Make sure item is still in the plot
+ if (
+ isinstance(item, self._KIND_TO_CLASSES[aKind])
+ and item.getPlot() is self
+ ): # Make sure item is still in the plot
self.removeItem(item)
else: # This is removing a single element
@@ -1969,32 +2067,41 @@ class PlotWidget(qt.QMainWindow):
if item is not None:
self.removeItem(item)
- def removeCurve(self, legend):
+ def removeCurve(self, legend: str | items.Curve | None):
"""Remove the curve associated to legend from the graph.
- :param str legend: The legend associated to the curve to be deleted
+ :param legend:
+ The legend of the curve to be deleted or the curve item
"""
if legend is None:
return
- self.remove(legend, kind='curve')
+ if isinstance(legend, items.Item):
+ return self.removeItem(legend)
+ self.remove(legend, kind="curve")
- def removeImage(self, legend):
+ def removeImage(self, legend: str | items.ImageBase | None):
"""Remove the image associated to legend from the graph.
- :param str legend: The legend associated to the image to be deleted
+ :param legend:
+ The legend of the image to be deleted or the image item
"""
if legend is None:
return
- self.remove(legend, kind='image')
+ if isinstance(legend, items.Item):
+ return self.removeItem(legend)
+ self.remove(legend, kind="image")
- def removeMarker(self, legend):
+ def removeMarker(self, legend: str | items.Marker | None):
"""Remove the marker associated to legend from the graph.
- :param str legend: The legend associated to the marker to be deleted
+ :param legend:
+ The legend of the marker to be deleted or the marker item
"""
if legend is None:
return
- self.remove(legend, kind='marker')
+ if isinstance(legend, items.Item):
+ return self.removeItem(legend)
+ self.remove(legend, kind="marker")
# Clear
@@ -2006,19 +2113,19 @@ class PlotWidget(qt.QMainWindow):
def clearCurves(self):
"""Remove all the curves from the plot."""
- self.remove(kind='curve')
+ self.remove(kind="curve")
def clearImages(self):
"""Remove all the images from the plot."""
- self.remove(kind='image')
+ self.remove(kind="image")
def clearItems(self):
- """Remove all the items from the plot. """
- self.remove(kind='item')
+ """Remove all the items from the plot."""
+ self.remove(kind="item")
def clearMarkers(self):
"""Remove all the markers from the plot."""
- self.remove(kind='marker')
+ self.remove(kind="marker")
# Interaction
@@ -2032,8 +2139,7 @@ class PlotWidget(qt.QMainWindow):
"""
return self._cursorConfiguration
- def setGraphCursor(self, flag=False, color='black',
- linewidth=1, linestyle='-'):
+ def setGraphCursor(self, flag=False, color="black", linewidth=1, linestyle="-"):
"""Toggle the display of a crosshair cursor and set its attributes.
:param bool flag: Toggle the display of a crosshair cursor.
@@ -2057,11 +2163,11 @@ class PlotWidget(qt.QMainWindow):
else:
self._cursorConfiguration = None
- self._backend.setGraphCursor(flag=flag, color=color,
- linewidth=linewidth, linestyle=linestyle)
+ self._backend.setGraphCursor(
+ flag=flag, color=color, linewidth=linewidth, linestyle=linestyle
+ )
self._setDirtyPlot()
- self.notify('setGraphCursor',
- state=self._cursorConfiguration is not None)
+ self.notify("setGraphCursor", state=self._cursorConfiguration is not None)
def pan(self, direction, factor=0.1):
"""Pan the graph in the given direction by the given factor.
@@ -2072,20 +2178,21 @@ class PlotWidget(qt.QMainWindow):
:param float factor: Proportion of the range used to pan the graph.
Must be strictly positive.
"""
- assert direction in ('up', 'down', 'left', 'right')
- assert factor > 0.
+ assert direction in ("up", "down", "left", "right")
+ assert factor > 0.0
- if direction in ('left', 'right'):
- xFactor = factor if direction == 'right' else - factor
+ if direction in ("left", "right"):
+ xFactor = factor if direction == "right" else -factor
xMin, xMax = self._xAxis.getLimits()
- xMin, xMax = _utils.applyPan(xMin, xMax, xFactor,
- self._xAxis.getScale() == self._xAxis.LOGARITHMIC)
+ xMin, xMax = _utils.applyPan(
+ xMin, xMax, xFactor, self._xAxis.getScale() == self._xAxis.LOGARITHMIC
+ )
self._xAxis.setLimits(xMin, xMax)
else: # direction in ('up', 'down')
- sign = -1. if self._yAxis.isInverted() else 1.
- yFactor = sign * (factor if direction == 'up' else -factor)
+ sign = -1.0 if self._yAxis.isInverted() else 1.0
+ yFactor = sign * (factor if direction == "up" else -factor)
yMin, yMax = self._yAxis.getLimits()
yIsLog = self._yAxis.getScale() == self._yAxis.LOGARITHMIC
@@ -2104,7 +2211,7 @@ class PlotWidget(qt.QMainWindow):
:rtype: bool
"""
- return self.getActiveCurveSelectionMode() != 'none'
+ return self.getActiveCurveSelectionMode() != "none"
def setActiveCurveHandling(self, flag=True):
"""Enable/Disable active curve selection.
@@ -2112,7 +2219,7 @@ class PlotWidget(qt.QMainWindow):
:param bool flag: True to enable 'atmostone' active curve selection,
False to disable active curve selection.
"""
- self.setActiveCurveSelectionMode('atmostone' if flag else 'none')
+ self.setActiveCurveSelectionMode("atmostone" if flag else "none")
def getActiveCurveStyle(self):
"""Returns the current style applied to active curve
@@ -2121,12 +2228,9 @@ class PlotWidget(qt.QMainWindow):
"""
return self._activeCurveStyle
- def setActiveCurveStyle(self,
- color=None,
- linewidth=None,
- linestyle=None,
- symbol=None,
- symbolsize=None):
+ def setActiveCurveStyle(
+ self, color=None, linewidth=None, linestyle=None, symbol=None, symbolsize=None
+ ):
"""Set the style of active curve
:param color: Color
@@ -2135,36 +2239,17 @@ class PlotWidget(qt.QMainWindow):
:param Union[str,None] symbol: Symbol of the markers
:param Union[float,None] symbolsize: Size of the symbols
"""
- self._activeCurveStyle = CurveStyle(color=color,
- linewidth=linewidth,
- linestyle=linestyle,
- symbol=symbol,
- symbolsize=symbolsize)
+ self._activeCurveStyle = CurveStyle(
+ color=color,
+ linewidth=linewidth,
+ linestyle=linestyle,
+ symbol=symbol,
+ symbolsize=symbolsize,
+ )
curve = self.getActiveCurve()
if curve is not None:
curve.setHighlightedStyle(self.getActiveCurveStyle())
- @deprecated(replacement="getActiveCurveStyle", since_version="0.9")
- def getActiveCurveColor(self):
- """Get the color used to display the currently active curve.
-
- See :meth:`setActiveCurveColor`.
- """
- return self._activeCurveStyle.getColor()
-
- @deprecated(replacement="setActiveCurveStyle", since_version="0.9")
- def setActiveCurveColor(self, color="#000000"):
- """Set the color to use to display the currently active curve.
-
- :param str color: Color of the active curve,
- e.g., 'blue', 'b', '#FF0000' (Default: 'black')
- """
- if color is None:
- color = "black"
- if color in self.colorDict:
- color = self.colorDict[color]
- self.setActiveCurveStyle(color=color)
-
def getActiveCurve(self, just_legend=False):
"""Return the currently active curve.
@@ -2180,7 +2265,7 @@ class PlotWidget(qt.QMainWindow):
if not self.isActiveCurveHandling():
return None
- return self._getActiveItem(kind='curve', just_legend=just_legend)
+ return self._getActiveItem(kind="curve", just_legend=just_legend)
def setActiveCurve(self, legend):
"""Make the curve associated to legend the active curve.
@@ -2193,10 +2278,11 @@ class PlotWidget(qt.QMainWindow):
return
if legend is None and self.getActiveCurveSelectionMode() == "legacy":
_logger.info(
- 'setActiveCurve(None) ignored due to active curve selection mode')
+ "setActiveCurve(None) ignored due to active curve selection mode"
+ )
return
- return self._setActiveItem(kind='curve', legend=legend)
+ return self._setActiveItem(kind="curve", item=legend)
def setActiveCurveSelectionMode(self, mode):
"""Sets the current selection mode.
@@ -2204,17 +2290,16 @@ class PlotWidget(qt.QMainWindow):
:param str mode: The active curve selection mode to use.
It can be: 'legacy', 'atmostone' or 'none'.
"""
- assert mode in ('legacy', 'atmostone', 'none')
+ assert mode in ("legacy", "atmostone", "none")
if mode != self._activeCurveSelectionMode:
self._activeCurveSelectionMode = mode
- if mode == 'none': # reset active curve
- self._setActiveItem(kind='curve', legend=None)
+ if mode == "none": # reset active curve
+ self._setActiveItem(kind="curve", item=None)
- elif mode == 'legacy' and self.getActiveCurve() is None:
+ elif mode == "legacy" and self.getActiveCurve() is None:
# Select an active curve
- curves = self.getAllCurves(just_legend=False,
- withhidden=False)
+ curves = self.getAllCurves(just_legend=False, withhidden=False)
if len(curves) == 1:
if curves[0].isVisible():
self.setActiveCurve(curves[0].getName())
@@ -2240,7 +2325,7 @@ class PlotWidget(qt.QMainWindow):
:rtype: str, :class:`.items.ImageData`, :class:`.items.ImageRgba`
or None
"""
- return self._getActiveItem(kind='image', just_legend=just_legend)
+ return self._getActiveItem(kind="image", just_legend=just_legend)
def setActiveImage(self, legend):
"""Make the image associated to legend the active image.
@@ -2248,7 +2333,7 @@ class PlotWidget(qt.QMainWindow):
:param str legend: The legend associated to the image
or None to have no active image.
"""
- return self._setActiveItem(kind='image', legend=legend)
+ return self._setActiveItem(kind="image", item=legend)
def getActiveScatter(self, just_legend=False):
"""Returns the currently active scatter.
@@ -2261,7 +2346,7 @@ class PlotWidget(qt.QMainWindow):
:return: Active scatter's legend or corresponding scatter object
:rtype: str, :class:`.items.Scatter` or None
"""
- return self._getActiveItem(kind='scatter', just_legend=just_legend)
+ return self._getActiveItem(kind="scatter", just_legend=just_legend)
def setActiveScatter(self, legend):
"""Make the scatter associated to legend the active scatter.
@@ -2269,78 +2354,79 @@ class PlotWidget(qt.QMainWindow):
:param str legend: The legend associated to the scatter
or None to have no active scatter.
"""
- return self._setActiveItem(kind='scatter', legend=legend)
+ return self._setActiveItem(kind="scatter", item=legend)
- def _getActiveItem(self, kind, just_legend=False):
- """Return the currently active item of that kind if any
+ def _getActiveItem(
+ self,
+ kind: str | None,
+ just_legend: bool = False,
+ ) -> items.Curve | items.Scatter | items.ImageBase | None:
+ """Return the currently active item of given kind if any.
- :param str kind: Type of item: 'curve', 'scatter' or 'image'
- :param bool just_legend: True to get the legend,
- False (default) to get the item
- :return: legend or item or None if no active item
+ :param kind: Type of item: 'curve', 'scatter' or 'image'
+ :param just_legend:
+ True to get the item's legend, False (the default) to get the item
"""
assert kind in self._ACTIVE_ITEM_KINDS
+ item = self.__activeItems[kind]
+ if item is not None and just_legend:
+ return item.getName()
+ return item
- if self._activeLegend[kind] is None:
- return None
+ def _setActiveItem(
+ self,
+ kind: str,
+ item: items.Curve | items.ImageBase | items.Scatter | str | None,
+ ) -> str | None:
+ """Make the given item active.
+
+ Note: There is one active item per "kind" of item.
+ """
+ assert kind in self._ACTIVE_ITEM_KINDS
- item = self._getItem(kind, self._activeLegend[kind])
if item is None:
- return None
+ legend = None
+ elif isinstance(item, items.Item):
+ legend = item.getName()
+ else:
+ legend = str(item)
+ item = self._getItem(kind, legend)
+ if item is None:
+ _logger.warning("This %s does not exist: %s", kind, legend)
- return item.getName() if just_legend else item
+ oldActiveItem = self._getActiveItem(kind=kind)
- def _setActiveItem(self, kind, legend):
- """Make the curve associated to legend the active curve.
+ if oldActiveItem is None and item is None:
+ return None
- :param str kind: Type of item: 'curve' or 'image'
- :param legend: The legend associated to the curve
- or None to have no active curve.
- :type legend: str or None
- """
- assert kind in self._ACTIVE_ITEM_KINDS
+ if oldActiveItem is not None:
+ # Stop listening previous active item
+ oldActiveItem.sigItemChanged.disconnect(self._activeItemChanged)
+ # Curve specific: Reset highlight of previous active curve
+ if kind == "curve":
+ oldActiveItem.setHighlighted(False)
+
+ self.__activeItems[kind] = item
xLabel = None
yLabel = None
yRightLabel = None
- oldActiveItem = self._getActiveItem(kind=kind)
-
- if oldActiveItem is not None: # Stop listening previous active image
- oldActiveItem.sigItemChanged.disconnect(self._activeItemChanged)
+ if item is not None:
+ # Curve specific: handle highlight
+ if kind == "curve":
+ item.setHighlightedStyle(self.getActiveCurveStyle())
+ item.setHighlighted(True)
- # Curve specific: Reset highlight of previous active curve
- if kind == 'curve' and oldActiveItem is not None:
- oldActiveItem.setHighlighted(False)
+ if isinstance(item, items.LabelsMixIn):
+ xLabel = item.getXLabel()
+ if isinstance(item, items.YAxisMixIn) and item.getYAxis() == "right":
+ yRightLabel = item.getYLabel()
+ else:
+ yLabel = item.getYLabel()
- if legend is None:
- self._activeLegend[kind] = None
- else:
- legend = str(legend)
- item = self._getItem(kind, legend)
- if item is None:
- _logger.warning("This %s does not exist: %s", kind, legend)
- self._activeLegend[kind] = None
- else:
- self._activeLegend[kind] = legend
-
- # Curve specific: handle highlight
- if kind == 'curve':
- item.setHighlightedStyle(self.getActiveCurveStyle())
- item.setHighlighted(True)
-
- if isinstance(item, items.LabelsMixIn):
- if item.getXLabel() is not None:
- xLabel = item.getXLabel()
- if item.getYLabel() is not None:
- if (isinstance(item, items.YAxisMixIn) and
- item.getYAxis() == 'right'):
- yRightLabel = item.getYLabel()
- else:
- yLabel = item.getYLabel()
-
- # Start listening new active item
- item.sigItemChanged.connect(self._activeItemChanged)
+ # Start listening new active item
+ item.sigItemChanged.connect(self._activeItemChanged)
# Store current labels and update plot
self._xAxis._setCurrentLabel(xLabel)
@@ -2349,19 +2435,13 @@ class PlotWidget(qt.QMainWindow):
self._setDirtyPlot()
- activeLegend = self._activeLegend[kind]
- if oldActiveItem is not None or activeLegend is not None:
- if oldActiveItem is None:
- oldActiveLegend = None
- else:
- oldActiveLegend = oldActiveItem.getName()
- self.notify(
- 'active' + kind[0].upper() + kind[1:] + 'Changed',
- updated=oldActiveLegend != activeLegend,
- previous=oldActiveLegend,
- legend=activeLegend)
-
- return activeLegend
+ self.notify(
+ f"active{kind.capitalize()}Changed",
+ updated=oldActiveItem is not item,
+ previous=None if oldActiveItem is None else oldActiveItem.getName(),
+ legend=legend,
+ )
+ return legend
def _activeItemChanged(self, type_):
"""Listen for active item changed signal and broadcast signal
@@ -2373,10 +2453,11 @@ class PlotWidget(qt.QMainWindow):
if item is not None:
kind = self._itemKind(item)
self.notify(
- 'active' + kind[0].upper() + kind[1:] + 'Changed',
+ "active" + kind[0].upper() + kind[1:] + "Changed",
updated=False,
previous=item.getName(),
- legend=item.getName())
+ legend=item.getName(),
+ )
# Getters
@@ -2396,24 +2477,29 @@ class PlotWidget(qt.QMainWindow):
:return: list of curves' legend or :class:`.items.Curve`
:rtype: list of str or list of :class:`.items.Curve`
"""
- curves = [item for item in self.getItems() if
- isinstance(item, items.Curve) and
- (withhidden or item.isVisible())]
+ curves = [
+ item
+ for item in self.getItems()
+ if isinstance(item, items.Curve) and (withhidden or item.isVisible())
+ ]
return [curve.getName() for curve in curves] if just_legend else curves
- def getCurve(self, legend=None):
+ def getCurve(self, legend: str | items.Curve | None = None) -> items.Curve:
"""Get the object describing a specific curve.
It returns None in case no matching curve is found.
- :param str legend:
+ :param legend:
The legend identifying the curve.
If not provided or None (the default), the active curve is returned
or if there is no active curve, the latest updated curve that is
not hidden is returned if there are curves in the plot.
:return: None or :class:`.items.Curve` object
"""
- return self._getItem(kind='curve', legend=legend)
+ if isinstance(legend, items.Curve):
+ _logger.warning("getCurve call not needed: legend is already an item")
+ return legend
+ return self._getItem(kind="curve", legend=legend)
def getAllImages(self, just_legend=False):
"""Returns all images legend or objects.
@@ -2430,83 +2516,62 @@ class PlotWidget(qt.QMainWindow):
:return: list of images' legend or :class:`.items.ImageBase`
:rtype: list of str or list of :class:`.items.ImageBase`
"""
- images = [item for item in self.getItems()
- if isinstance(item, items.ImageBase)]
+ images = [item for item in self.getItems() if isinstance(item, items.ImageBase)]
return [image.getName() for image in images] if just_legend else images
- def getImage(self, legend=None):
+ def getImage(self, legend: str | items.ImageBase | None = None) -> items.ImageBase:
"""Get the object describing a specific image.
It returns None in case no matching image is found.
- :param str legend:
+ :param legend:
The legend identifying the image.
If not provided or None (the default), the active image is returned
or if there is no active image, the latest updated image
is returned if there are images in the plot.
:return: None or :class:`.items.ImageBase` object
"""
- return self._getItem(kind='image', legend=legend)
+ if isinstance(legend, items.ImageBase):
+ _logger.warning("getImage call not needed: legend is already an item")
+ return legend
+ return self._getItem(kind="image", legend=legend)
- def getScatter(self, legend=None):
+ def getScatter(self, legend: str | items.Scatter | None = None) -> items.Scatter:
"""Get the object describing a specific scatter.
It returns None in case no matching scatter is found.
- :param str legend:
+ :param legend:
The legend identifying the scatter.
If not provided or None (the default), the active scatter is
returned or if there is no active scatter, the latest updated
scatter is returned if there are scatters in the plot.
:return: None or :class:`.items.Scatter` object
"""
- return self._getItem(kind='scatter', legend=legend)
+ if isinstance(legend, items.Scatter):
+ _logger.warning("getScatter call not needed: legend is already an item")
+ return legend
+ return self._getItem(kind="scatter", legend=legend)
- def getHistogram(self, legend=None):
+ def getHistogram(
+ self, legend: str | items.Histogram | None = None
+ ) -> items.Histogram:
"""Get the object describing a specific histogram.
It returns None in case no matching histogram is found.
- :param str legend:
+ :param legend:
The legend identifying the histogram.
If not provided or None (the default), the latest updated scatter
is returned if there are histograms in the plot.
:return: None or :class:`.items.Histogram` object
"""
- return self._getItem(kind='histogram', legend=legend)
-
- @deprecated(replacement='getItems', since_version='0.13')
- def _getItems(self, kind=ITEM_KINDS, just_legend=False, withhidden=False):
- """Retrieve all items of a kind in the plot
-
- :param kind: The kind of elements to retrieve from the plot.
- See :attr:`ITEM_KINDS`.
- By default, it removes all kind of elements.
- :type kind: str or tuple of str to specify multiple kinds.
- :param str kind: Type of item: 'curve' or 'image'
- :param bool just_legend: True to get the legend of the curves,
- False (the default) to get the curves' data
- and info.
- :param bool withhidden: False (default) to skip hidden curves.
- :return: list of legends or item objects
- """
- if kind == 'all': # Replace all by tuple of all kinds
- kind = self.ITEM_KINDS
-
- if kind in self.ITEM_KINDS: # Kind is a str, make it a tuple
- kind = (kind,)
-
- for aKind in kind:
- assert aKind in self.ITEM_KINDS
-
- output = []
- for item in self.getItems():
- type_ = self._itemKind(item)
- if type_ in kind and (withhidden or item.isVisible()):
- output.append(item.getName() if just_legend else item)
- return output
+ if isinstance(legend, items.Histogram):
+ _logger.warning("getHistogram call not needed: legend is already an item")
+ return legend
+ return self._getItem(kind="histogram", legend=legend)
- def _getItem(self, kind, legend=None):
+ def _getItem(self, kind, legend=None) -> items.Item:
"""Get an item from the plot: either an image or a curve.
Returns None if no match found.
@@ -2517,20 +2582,30 @@ class PlotWidget(qt.QMainWindow):
None to get active or last item
:return: Object describing the item or None
"""
+ if isinstance(legend, items.Item):
+ _logger.warning("_getItem call not needed: legend is already an item")
+ return legend
+
assert kind in self.ITEM_KINDS
if legend is not None:
- return self._content.get((legend, kind), None)
- else:
- if kind in self._ACTIVE_ITEM_KINDS:
- item = self._getActiveItem(kind=kind)
- if item is not None: # Return active item if available
+ for item in self.getItems():
+ if item.getName() == legend and kind == self._itemKind(item):
return item
- # Return last visible item if any
- itemClasses = self._KIND_TO_CLASSES[kind]
- allItems = [item for item in self.getItems()
- if isinstance(item, itemClasses) and item.isVisible()]
- return allItems[-1] if allItems else None
+ return None # No item found
+
+ if kind in self._ACTIVE_ITEM_KINDS:
+ item = self._getActiveItem(kind=kind)
+ if item is not None: # Return active item if available
+ return item
+ # Return last visible item if any
+ itemClasses = self._KIND_TO_CLASSES[kind]
+ allItems = [
+ item
+ for item in self.getItems()
+ if isinstance(item, itemClasses) and item.isVisible()
+ ]
+ return allItems[-1] if allItems else None
# Limits
@@ -2545,7 +2620,8 @@ class PlotWidget(qt.QMainWindow):
for axis, limits in zip(axes, ranges):
axis.sigLimitsChanged.emit(*limits)
event = PlotEvents.prepareLimitsChangedSignal(
- id(self.getWidgetHandle()), xRange, yRange, y2Range)
+ id(self.getWidgetHandle()), xRange, yRange, y2Range
+ )
self.notify(**event)
def getLimitsHistory(self):
@@ -2567,18 +2643,18 @@ class PlotWidget(qt.QMainWindow):
"""
self._xAxis.setLimits(xmin, xmax)
- def getGraphYLimits(self, axis='left'):
+ def getGraphYLimits(self, axis="left"):
"""Get the graph Y limits.
:param str axis: The axis for which to get the limits:
Either 'left' or 'right'
:return: Minimum and maximum values of the X axis
"""
- assert axis in ('left', 'right')
- yAxis = self._yAxis if axis == 'left' else self._yRightAxis
+ assert axis in ("left", "right")
+ yAxis = self._yAxis if axis == "left" else self._yRightAxis
return yAxis.getLimits()
- def setGraphYLimits(self, ymin, ymax, axis='left'):
+ def setGraphYLimits(self, ymin, ymax, axis="left"):
"""Set the graph Y limits.
:param float ymin: minimum bottom axis value
@@ -2586,40 +2662,80 @@ class PlotWidget(qt.QMainWindow):
:param str axis: The axis for which to get the limits:
Either 'left' or 'right'
"""
- assert axis in ('left', 'right')
- yAxis = self._yAxis if axis == 'left' else self._yRightAxis
+ assert axis in ("left", "right")
+ yAxis = self._yAxis if axis == "left" else self._yRightAxis
return yAxis.setLimits(ymin, ymax)
- def setLimits(self, xmin, xmax, ymin, ymax, y2min=None, y2max=None):
+ def setLimits(
+ self,
+ xmin: float,
+ xmax: float,
+ ymin: float,
+ ymax: float,
+ y2min: Optional[float] = None,
+ y2max: Optional[float] = None,
+ margins: Union[bool, tuple[float, float, float, float]] = False,
+ ):
"""Set the limits of the X and Y axes at once.
If y2min or y2max is None, the right Y axis limits are not updated.
- :param float xmin: minimum bottom axis value
- :param float xmax: maximum bottom axis value
- :param float ymin: minimum left axis value
- :param float ymax: maximum left axis value
- :param float y2min: minimum right axis value or None (the default)
- :param float y2max: maximum right axis value or None (the default)
- """
- # Deal with incorrect values
- axis = self.getXAxis()
- xmin, xmax = axis._checkLimits(xmin, xmax)
- axis = self.getYAxis()
- ymin, ymax = axis._checkLimits(ymin, ymax)
-
- if y2min is None or y2max is None:
- # if one limit is None, both are ignored
- y2min, y2max = None, None
- else:
- axis = self.getYAxis(axis="right")
- y2min, y2max = axis._checkLimits(y2min, y2max)
+ :param xmin: minimum bottom axis value
+ :param xmax: maximum bottom axis value
+ :param ymin: minimum left axis value
+ :param ymax: maximum left axis value
+ :param y2min: minimum right axis value or None (the default)
+ :param y2max: maximum right axis value or None (the default)
+ :param margins:
+ Data margins to add to the limits or a boolean telling
+ whether or not to add margins from :meth:`getDataMargins`.
+ """
+ limits = [
+ *self.getXAxis()._checkLimits(xmin, xmax),
+ *self.getYAxis()._checkLimits(ymin, ymax),
+ ]
+
+ # Only consider y2 axis if both limits are not None
+ if None not in (y2min, y2max):
+ limits.extend(self.getYAxis(axis="right")._checkLimits(y2min, y2max))
+
+ if margins: # Add margins around limits inside the plot area
+ limits = list(
+ _utils.addMarginsToLimits(
+ self.getDataMargins() if margins is True else margins,
+ self.getXAxis()._isLogarithmic(),
+ self.getYAxis()._isLogarithmic(),
+ *limits,
+ )
+ )
+
+ if self.isKeepDataAspectRatio():
+ # Use limits with margins to keep ratio
+ xmin, xmax, ymin, ymax = limits[:4]
+
+ # Compute bbox wth figure aspect ratio
+ plotWidth, plotHeight = self.getPlotBoundsInPixels()[2:]
+ if plotWidth > 0 and plotHeight > 0:
+ plotRatio = plotHeight / plotWidth
+ dataRatio = (ymax - ymin) / (xmax - xmin)
+ if dataRatio < plotRatio:
+ # Increase y range
+ ycenter = 0.5 * (ymax + ymin)
+ yrange = (xmax - xmin) * plotRatio
+ limits[2] = ycenter - 0.5 * yrange
+ limits[3] = ycenter + 0.5 * yrange
+
+ elif dataRatio > plotRatio:
+ # Increase x range
+ xcenter = 0.5 * (xmax + xmin)
+ xrange_ = (ymax - ymin) / plotRatio
+ limits[0] = xcenter - 0.5 * xrange_
+ limits[1] = xcenter + 0.5 * xrange_
if self._viewConstrains:
- view = self._viewConstrains.normalize(xmin, xmax, ymin, ymax)
- xmin, xmax, ymin, ymax = view
+ limits[:4] = self._viewConstrains.normalize(*limits[:4])
- self._backend.setLimits(xmin, xmax, ymin, ymax, y2min, y2max)
+ self._backend.setLimits(*limits)
self._setDirtyPlot()
self._notifyLimitsChanged()
@@ -2661,16 +2777,16 @@ class PlotWidget(qt.QMainWindow):
"""
self._xAxis.setLabel(label)
- def getGraphYLabel(self, axis='left'):
+ def getGraphYLabel(self, axis="left"):
"""Return the current Y axis label as a str.
:param str axis: The Y axis for which to get the label (left or right)
"""
- assert axis in ('left', 'right')
- yAxis = self._yAxis if axis == 'left' else self._yRightAxis
+ assert axis in ("left", "right")
+ yAxis = self._yAxis if axis == "left" else self._yRightAxis
return yAxis.getLabel()
- def setGraphYLabel(self, label="Y", axis='left'):
+ def setGraphYLabel(self, label="Y", axis="left"):
"""Set the plot Y axis label.
The provided label can be temporarily replaced by the Y label of the
@@ -2679,8 +2795,8 @@ class PlotWidget(qt.QMainWindow):
:param str label: The Y axis label (default: 'Y')
:param str axis: The Y axis for which to set the label (left or right)
"""
- assert axis in ('left', 'right')
- yAxis = self._yAxis if axis == 'left' else self._yRightAxis
+ assert axis in ("left", "right")
+ yAxis = self._yAxis if axis == "left" else self._yRightAxis
return yAxis.setLabel(label)
# Axes
@@ -2703,7 +2819,7 @@ class PlotWidget(qt.QMainWindow):
('left' or 'right').
:rtype: :class:`.items.Axis`
"""
- assert(axis in ["left", "right"])
+ assert axis in ["left", "right"]
return self._yAxis if axis == "left" else self._yRightAxis
def setAxesDisplayed(self, displayed: bool):
@@ -2717,7 +2833,7 @@ class PlotWidget(qt.QMainWindow):
if displayed:
self._backend.setAxesMargins(*self.__axesMargins)
else:
- self._backend.setAxesMargins(0., 0., 0., 0.)
+ self._backend.setAxesMargins(0.0, 0.0, 0.0, 0.0)
self._setDirtyPlot()
self._sigAxesVisibilityChanged.emit(displayed)
@@ -2728,8 +2844,7 @@ class PlotWidget(qt.QMainWindow):
"""
return self.__axesDisplayed
- def setAxesMargins(
- self, left: float, top: float, right: float, bottom: float):
+ def setAxesMargins(self, left: float, top: float, right: float, bottom: float):
"""Set ratios of margins surrounding data plot area.
All ratios must be within [0., 1.].
@@ -2742,9 +2857,9 @@ class PlotWidget(qt.QMainWindow):
:raises ValueError:
"""
for value in (left, top, right, bottom):
- if value < 0. or value > 1.:
+ if value < 0.0 or value > 1.0:
raise ValueError("Margin ratios must be within [0., 1.]")
- if left + right >= 1. or top + bottom >= 1.:
+ if left + right >= 1.0 or top + bottom >= 1.0:
raise ValueError("Sum of ratios of opposed sides >= 1")
margins = left, top, right, bottom
@@ -2835,7 +2950,7 @@ class PlotWidget(qt.QMainWindow):
self._backend.setKeepDataAspectRatio(flag=flag)
self._setDirtyPlot()
self._forceResetZoom()
- self.notify('setKeepDataAspectRatio', state=flag)
+ self.notify("setKeepDataAspectRatio", state=flag)
def getGraphGrid(self):
"""Return the current grid mode, either None, 'major' or 'both'.
@@ -2852,15 +2967,15 @@ class PlotWidget(qt.QMainWindow):
'both' for grid on both major and minor ticks.
:type which: str of bool
"""
- assert which in (None, True, False, 'both', 'major')
+ assert which in (None, True, False, "both", "major")
if not which:
which = None
elif which is True:
- which = 'major'
+ which = "major"
self._grid = which
self._backend.setGraphGrid(which)
self._setDirtyPlot()
- self.notify('setGraphGrid', which=str(which))
+ self.notify("setGraphGrid", which=str(which))
# Defaults
@@ -2876,7 +2991,7 @@ class PlotWidget(qt.QMainWindow):
:param bool flag: True to use 'o' as the default curve symbol,
False to use no symbol.
"""
- self._defaultPlotPoints = silx.config.DEFAULT_PLOT_SYMBOL if flag else ''
+ self._defaultPlotPoints = silx.config.DEFAULT_PLOT_SYMBOL if flag else ""
# Reset symbol of all curves
curves = self.getAllCurves(just_legend=False, withhidden=True)
@@ -2897,7 +3012,7 @@ class PlotWidget(qt.QMainWindow):
"""
self._plotLines = bool(flag)
- linestyle = '-' if self._plotLines else ' '
+ linestyle = "-" if self._plotLines else " "
# Reset linestyle of all curves
curves = self.getAllCurves(withhidden=True)
@@ -2927,16 +3042,18 @@ class PlotWidget(qt.QMainWindow):
autoscale gray colormap.
"""
if colormap is None:
- colormap = Colormap(name=silx.config.DEFAULT_COLORMAP_NAME,
- normalization='linear',
- vmin=None,
- vmax=None)
+ colormap = Colormap(
+ name=silx.config.DEFAULT_COLORMAP_NAME,
+ normalization="linear",
+ vmin=None,
+ vmax=None,
+ )
if isinstance(colormap, dict):
self._defaultColormap = Colormap._fromDict(colormap)
else:
assert isinstance(colormap, Colormap)
self._defaultColormap = colormap
- self.notify('defaultColormapChanged')
+ self.notify("defaultColormapChanged")
@staticmethod
def getSupportedColormaps():
@@ -2948,17 +3065,35 @@ class PlotWidget(qt.QMainWindow):
"""
return Colormap.getSupportedColormaps()
+ def setDefaultColors(self, colors: Optional[Tuple[str, ...]]):
+ """Set the list of colors to use as default for curves and histograms.
+
+ Set to None to use `silx.config.DEFAULT_PLOT_CURVE_COLORS`.
+ """
+ self._defaultColors = None if colors is None else tuple(colors)
+ self._resetColorAndStyle()
+
+ def getDefaultColors(self) -> Tuple[str, ...]:
+ """Returns the list of default colors for curves and histograms"""
+ if self._defaultColors is None:
+ return tuple(silx.config.DEFAULT_PLOT_CURVE_COLORS)
+ return self._defaultColors
+
def _resetColorAndStyle(self):
self._colorIndex = 0
self._styleIndex = 0
- def _getColorAndStyle(self):
- color = self.colorList[self._colorIndex]
+ def _getColorAndStyle(self) -> Tuple[str, str]:
+ defaultColors = self.getDefaultColors()
+ if self._colorIndex >= len(defaultColors): # Handle list length updated
+ self._colorIndex = 0
+
+ color = defaultColors[self._colorIndex]
style = self._styleList[self._styleIndex]
# Loop over color and then styles
self._colorIndex += 1
- if self._colorIndex >= len(self.colorList):
+ if self._colorIndex >= len(defaultColors):
self._colorIndex = 0
self._styleIndex = (self._styleIndex + 1) % len(self._styleList)
@@ -2967,7 +3102,7 @@ class PlotWidget(qt.QMainWindow):
color, style = self._getColorAndStyle()
if not self._plotLines:
- style = ' '
+ style = " "
return color, style
@@ -2990,32 +3125,30 @@ class PlotWidget(qt.QMainWindow):
:param kwargs: The information of the event.
"""
eventDict = kwargs.copy()
- eventDict['event'] = event
+ eventDict["event"] = event
self.sigPlotSignal.emit(eventDict)
- if event == 'setKeepDataAspectRatio':
- self.sigSetKeepDataAspectRatio.emit(kwargs['state'])
- elif event == 'setGraphGrid':
- self.sigSetGraphGrid.emit(kwargs['which'])
- elif event == 'setGraphCursor':
- self.sigSetGraphCursor.emit(kwargs['state'])
- elif event == 'contentChanged':
+ if event == "setKeepDataAspectRatio":
+ self.sigSetKeepDataAspectRatio.emit(kwargs["state"])
+ elif event == "setGraphGrid":
+ self.sigSetGraphGrid.emit(kwargs["which"])
+ elif event == "setGraphCursor":
+ self.sigSetGraphCursor.emit(kwargs["state"])
+ elif event == "contentChanged":
self.sigContentChanged.emit(
- kwargs['action'], kwargs['kind'], kwargs['legend'])
- elif event == 'activeCurveChanged':
- self.sigActiveCurveChanged.emit(
- kwargs['previous'], kwargs['legend'])
- elif event == 'activeImageChanged':
- self.sigActiveImageChanged.emit(
- kwargs['previous'], kwargs['legend'])
- elif event == 'activeScatterChanged':
- self.sigActiveScatterChanged.emit(
- kwargs['previous'], kwargs['legend'])
- elif event == 'interactiveModeChanged':
- self.sigInteractiveModeChanged.emit(kwargs['source'])
+ kwargs["action"], kwargs["kind"], kwargs["legend"]
+ )
+ elif event == "activeCurveChanged":
+ self.sigActiveCurveChanged.emit(kwargs["previous"], kwargs["legend"])
+ elif event == "activeImageChanged":
+ self.sigActiveImageChanged.emit(kwargs["previous"], kwargs["legend"])
+ elif event == "activeScatterChanged":
+ self.sigActiveScatterChanged.emit(kwargs["previous"], kwargs["legend"])
+ elif event == "interactiveModeChanged":
+ self.sigInteractiveModeChanged.emit(kwargs["source"])
eventDict = kwargs.copy()
- eventDict['event'] = event
+ eventDict["event"] = event
self._callback(eventDict)
def setCallback(self, callbackFunction=None):
@@ -3045,11 +3178,11 @@ class PlotWidget(qt.QMainWindow):
ddict = {}
_logger.debug("Received dict keys = %s", str(ddict.keys()))
_logger.debug(str(ddict))
- if ddict['event'] in ["legendClicked", "curveClicked"]:
- if ddict['button'] == "left":
- self.setActiveCurve(ddict['label'])
- qt.QToolTip.showText(self.cursor().pos(), ddict['label'])
- elif ddict['event'] == 'mouseClicked' and ddict['button'] == 'left':
+ if ddict["event"] == "curveClicked":
+ if ddict["button"] == "left":
+ self.setActiveCurve(ddict["item"])
+ qt.QToolTip.showText(self.cursor().pos(), ddict["label"])
+ elif ddict["event"] == "mouseClicked" and ddict["button"] == "left":
self.setActiveCurve(None)
def saveGraph(self, filename, fileFormat=None, dpi=None):
@@ -3066,42 +3199,51 @@ class PlotWidget(qt.QMainWindow):
:return: False if cannot save the plot, True otherwise
"""
if fileFormat is None:
- if not hasattr(filename, 'lower'):
- _logger.warning(
- 'saveGraph cancelled, cannot define file format.')
+ if not hasattr(filename, "lower"):
+ _logger.warning("saveGraph cancelled, cannot define file format.")
return False
else:
fileFormat = (filename.split(".")[-1]).lower()
- supportedFormats = ("png", "svg", "pdf", "ps", "eps",
- "tif", "tiff", "jpeg", "jpg")
+ supportedFormats = (
+ "png",
+ "svg",
+ "pdf",
+ "ps",
+ "eps",
+ "tif",
+ "tiff",
+ "jpeg",
+ "jpg",
+ )
if fileFormat not in supportedFormats:
- _logger.warning('Unsupported format %s', fileFormat)
+ _logger.warning("Unsupported format %s", fileFormat)
return False
else:
- self._backend.saveGraph(filename,
- fileFormat=fileFormat,
- dpi=dpi)
+ self._backend.saveGraph(filename, fileFormat=fileFormat, dpi=dpi)
return True
- def getDataMargins(self):
+ def getDataMargins(self) -> tuple[float, float, float, float]:
"""Get the default data margin ratios, see :meth:`setDataMargins`.
:return: The margin ratios for each side (xMin, xMax, yMin, yMax).
- :rtype: A 4-tuple of floats.
"""
return self._defaultDataMargins
- def setDataMargins(self, xMinMargin=0., xMaxMargin=0.,
- yMinMargin=0., yMaxMargin=0.):
+ def setDataMargins(
+ self,
+ xMinMargin: float = 0.0,
+ xMaxMargin: float = 0.0,
+ yMinMargin: float = 0.0,
+ yMaxMargin: float = 0.0,
+ ):
"""Set the default data margins to use in :meth:`resetZoom`.
- Set the default ratios of margins (as floats) to add around the data
+ Set the default ratios of margins to add around the data
inside the plot area for each side.
"""
- self._defaultDataMargins = (xMinMargin, xMaxMargin,
- yMinMargin, yMaxMargin)
+ self._defaultDataMargins = (xMinMargin, xMaxMargin, yMinMargin, yMaxMargin)
def getAutoReplot(self):
"""Return True if replot is automatically handled, False otherwise.
@@ -3133,10 +3275,10 @@ class PlotWidget(qt.QMainWindow):
It is in charge of performing required PlotWidget operations
"""
- for item in self._contentToUpdate:
+ for item in self.__itemsToUpdate:
item._update(self._backend)
- self._contentToUpdate = []
+ self.__itemsToUpdate = []
yield
self._dirty = False # reset dirty flag
@@ -3144,7 +3286,10 @@ class PlotWidget(qt.QMainWindow):
"""Request to draw the plot."""
self._backend.replot()
- def _forceResetZoom(self, dataMargins=None):
+ def _forceResetZoom(
+ self,
+ dataMargins: Optional[tuple[float, float, float, float]] = None,
+ ):
"""Reset the plot limits to the bounds of the data and redraw the plot.
This method forces a reset zoom and does not check axis autoscale.
@@ -3155,55 +3300,30 @@ class PlotWidget(qt.QMainWindow):
data (xMin, xMax, yMin and yMax limits).
For log scale, extra margins are applied in log10 of the data.
- :param dataMargins: Ratios of margins to add around the data inside
- the plot area for each side (default: no margins).
- :type dataMargins: A 4-tuple of float as (xMin, xMax, yMin, yMax).
+ :param dataMargins:
+ Ratios of margins to add around the data inside the plot area for each side.
+ If None (the default), use margins from :meth:`getDataMargins`.
"""
- if dataMargins is None:
- dataMargins = self._defaultDataMargins
-
# Get data range
ranges = self.getDataRange()
- xmin, xmax = (1., 100.) if ranges.x is None else ranges.x
- ymin, ymax = (1., 100.) if ranges.y is None else ranges.y
+ xmin, xmax = (1.0, 100.0) if ranges.x is None else ranges.x
+ ymin, ymax = (1.0, 100.0) if ranges.y is None else ranges.y
if ranges.yright is None:
- ymin2, ymax2 = ymin, ymax
+ y2min, y2max = ymin, ymax
else:
- ymin2, ymax2 = ranges.yright
+ y2min, y2max = ranges.yright
if ranges.y is None:
ymin, ymax = ranges.yright
- # Add margins around data inside the plot area
- newLimits = list(_utils.addMarginsToLimits(
- dataMargins,
- self._xAxis._isLogarithmic(),
- self._yAxis._isLogarithmic(),
- xmin, xmax, ymin, ymax, ymin2, ymax2))
-
- if self.isKeepDataAspectRatio():
- # Use limits with margins to keep ratio
- xmin, xmax, ymin, ymax = newLimits[:4]
-
- # Compute bbox wth figure aspect ratio
- plotWidth, plotHeight = self.getPlotBoundsInPixels()[2:]
- if plotWidth > 0 and plotHeight > 0:
- plotRatio = plotHeight / plotWidth
- dataRatio = (ymax - ymin) / (xmax - xmin)
- if dataRatio < plotRatio:
- # Increase y range
- ycenter = 0.5 * (ymax + ymin)
- yrange = (xmax - xmin) * plotRatio
- newLimits[2] = ycenter - 0.5 * yrange
- newLimits[3] = ycenter + 0.5 * yrange
-
- elif dataRatio > plotRatio:
- # Increase x range
- xcenter = 0.5 * (xmax + xmin)
- xrange_ = (ymax - ymin) / plotRatio
- newLimits[0] = xcenter - 0.5 * xrange_
- newLimits[1] = xcenter + 0.5 * xrange_
-
- self.setLimits(*newLimits)
+ self.setLimits(
+ xmin,
+ xmax,
+ ymin,
+ ymax,
+ y2min,
+ y2max,
+ margins=dataMargins if dataMargins is not None else True,
+ )
def resetZoom(self, dataMargins=None):
"""Reset the plot limits to the bounds of the data and redraw the plot.
@@ -3233,7 +3353,9 @@ class PlotWidget(qt.QMainWindow):
# This avoids issues with toggling log scale with matplotlib 2.1.0
if self._xAxis.getScale() == self._xAxis.LOGARITHMIC and xLimits[0] <= 0:
xAuto = True
- if self._yAxis.getScale() == self._yAxis.LOGARITHMIC and (yLimits[0] <= 0 or y2Limits[0] <= 0):
+ if self._yAxis.getScale() == self._yAxis.LOGARITHMIC and (
+ yLimits[0] <= 0 or y2Limits[0] <= 0
+ ):
yAuto = True
if not xAuto and not yAuto:
@@ -3246,14 +3368,15 @@ class PlotWidget(qt.QMainWindow):
self.setGraphXLimits(*xLimits)
elif xAuto and not yAuto:
if y2Limits is not None:
- self.setGraphYLimits(
- y2Limits[0], y2Limits[1], axis='right')
+ self.setGraphYLimits(y2Limits[0], y2Limits[1], axis="right")
if yLimits is not None:
- self.setGraphYLimits(yLimits[0], yLimits[1], axis='left')
+ self.setGraphYLimits(yLimits[0], yLimits[1], axis="left")
- if (xLimits != self._xAxis.getLimits() or
- yLimits != self._yAxis.getLimits() or
- y2Limits != self._yRightAxis.getLimits()):
+ if (
+ xLimits != self._xAxis.getLimits()
+ or yLimits != self._yAxis.getLimits()
+ or y2Limits != self._yRightAxis.getLimits()
+ ):
self._notifyLimitsChanged()
# Coord conversion
@@ -3296,7 +3419,7 @@ class PlotWidget(qt.QMainWindow):
if check:
isOutside = numpy.logical_or(
numpy.logical_or(x > xmax, x < xmin),
- numpy.logical_or(y > ymax, y < ymin)
+ numpy.logical_or(y > ymax, y < ymin),
)
if numpy.any(isOutside):
@@ -3337,7 +3460,8 @@ class PlotWidget(qt.QMainWindow):
left, top, width, height = self.getPlotBoundsInPixels()
isOutside = numpy.logical_or(
numpy.logical_or(x < left, x > left + width),
- numpy.logical_or(y < top, y > top + height))
+ numpy.logical_or(y < top, y > top + height),
+ )
if numpy.any(isOutside):
return None
@@ -3367,14 +3491,6 @@ class PlotWidget(qt.QMainWindow):
self.__graphCursorShape = cursor
self._backend.setGraphCursorShape(cursor)
- @deprecated(replacement='getItems', since_version='0.13')
- def _getAllMarkers(self, just_legend=False):
- markers = [item for item in self.getItems() if isinstance(item, items.MarkerBase)]
- if just_legend:
- return [marker.getName() for marker in markers]
- else:
- return markers
-
def _getMarkerAt(self, x, y):
"""Return the most interactive marker at a location, else None
@@ -3382,10 +3498,13 @@ class PlotWidget(qt.QMainWindow):
:param float y: Y position in pixels
:rtype: None of marker object
"""
+
def checkDraggable(item):
return isinstance(item, items.MarkerBase) and item.isDraggable()
+
def checkSelectable(item):
return isinstance(item, items.MarkerBase) and item.isSelectable()
+
def check(item):
return isinstance(item, items.MarkerBase)
@@ -3405,7 +3524,7 @@ class PlotWidget(qt.QMainWindow):
:param str legend: The legend of the marker to retrieve
:rtype: None of marker object
"""
- return self._getItem(kind='marker', legend=legend)
+ return self._getItem(kind="marker", legend=legend)
def pickItems(self, x, y, condition=None):
"""Generator of picked items in the plot at given position.
@@ -3420,7 +3539,9 @@ class PlotWidget(qt.QMainWindow):
:return: Iterable of :class:`PickingResult` objects at picked position.
Items are ordered from front to back.
"""
- for item in reversed(self._backend.getItemsFromBackToFront(condition=condition)):
+ for item in reversed(
+ self._backend.getItemsFromBackToFront(condition=condition)
+ ):
result = item.pick(x, y)
if result is not None:
yield result
@@ -3466,7 +3587,7 @@ class PlotWidget(qt.QMainWindow):
"""
if self._isPositionInPlotArea(xPixel, yPixel) == (xPixel, yPixel):
self._pressedButtons.append(btn)
- self._eventHandler.handleEvent('press', xPixel, yPixel, btn)
+ self._eventHandler.handleEvent("press", xPixel, yPixel, btn)
def onMouseMove(self, xPixel, yPixel):
"""Handle mouse move event.
@@ -3479,8 +3600,7 @@ class PlotWidget(qt.QMainWindow):
if self._cursorInPlot != isCursorInPlot:
self._cursorInPlot = isCursorInPlot
- self._eventHandler.handleEvent(
- 'enter' if self._cursorInPlot else 'leave')
+ self._eventHandler.handleEvent("enter" if self._cursorInPlot else "leave")
if isCursorInPlot:
# Signal mouse move event
@@ -3489,12 +3609,13 @@ class PlotWidget(qt.QMainWindow):
btn = self._pressedButtons[-1] if self._pressedButtons else None
event = PlotEvents.prepareMouseSignal(
- 'mouseMoved', btn, dataPos[0], dataPos[1], xPixel, yPixel)
+ "mouseMoved", btn, dataPos[0], dataPos[1], xPixel, yPixel
+ )
self.notify(**event)
# Either button was pressed in the plot or cursor is in the plot
if isCursorInPlot or self._pressedButtons:
- self._eventHandler.handleEvent('move', inXPixel, inYPixel)
+ self._eventHandler.handleEvent("move", inXPixel, inYPixel)
def onMouseRelease(self, xPixel, yPixel, btn):
"""Handle mouse release event.
@@ -3509,7 +3630,7 @@ class PlotWidget(qt.QMainWindow):
pass
else:
xPixel, yPixel = self._isPositionInPlotArea(xPixel, yPixel)
- self._eventHandler.handleEvent('release', xPixel, yPixel, btn)
+ self._eventHandler.handleEvent("release", xPixel, yPixel, btn)
def onMouseWheel(self, xPixel, yPixel, angleInDegrees):
"""Handle mouse wheel event.
@@ -3521,17 +3642,25 @@ class PlotWidget(qt.QMainWindow):
negative for movement toward the user.
"""
if self._isPositionInPlotArea(xPixel, yPixel) == (xPixel, yPixel):
- self._eventHandler.handleEvent(
- 'wheel', xPixel, yPixel, angleInDegrees)
+ self._eventHandler.handleEvent("wheel", xPixel, yPixel, angleInDegrees)
def onMouseLeaveWidget(self):
"""Handle mouse leave widget event."""
if self._cursorInPlot:
self._cursorInPlot = False
- self._eventHandler.handleEvent('leave')
+ self._eventHandler.handleEvent("leave")
# Interaction modes #
+ def interaction(self) -> PlotInteraction:
+ """Returns the interaction handler for this PlotWidget"""
+ return self._eventHandler
+
+ def __interactionChanged(self):
+ """Handle PlotInteraction updates"""
+ if self.__isInteractionSignalForwarded:
+ self.sigInteractiveModeChanged.emit(None)
+
def getInteractiveMode(self):
"""Returns the current interactive mode as a dict.
@@ -3540,7 +3669,7 @@ class PlotWidget(qt.QMainWindow):
It can also contains extra keys (e.g., 'color') specific to a mode
as provided to :meth:`setInteractiveMode`.
"""
- return self._eventHandler.getInteractiveMode()
+ return self.interaction()._getInteractiveMode()
def resetInteractiveMode(self):
"""Reset the interactive mode to use the previous basic interactive
@@ -3551,36 +3680,47 @@ class PlotWidget(qt.QMainWindow):
mode, zoomOnWheel = self._previousDefaultMode
self.setInteractiveMode(mode=mode, zoomOnWheel=zoomOnWheel)
- def setInteractiveMode(self, mode, color='black',
- shape='polygon', label=None,
- zoomOnWheel=True, source=None, width=None):
+ def setInteractiveMode(
+ self,
+ mode: str,
+ color: Union[str, Sequence[numbers.Real]] = "black",
+ shape: str = "polygon",
+ label: Optional[str] = None,
+ zoomOnWheel: bool = True,
+ source=None,
+ width: Optional[float] = None,
+ ):
"""Switch the interactive mode.
- :param str mode: The name of the interactive mode.
- In 'draw', 'pan', 'select', 'select-draw', 'zoom'.
+ :param mode: The name of the interactive mode.
+ In 'draw', 'pan', 'select', 'select-draw', 'zoom'.
:param color: Only for 'draw' and 'zoom' modes.
Color to use for drawing selection area. Default black.
:type color: Color description: The name as a str or
a tuple of 4 floats.
- :param str shape: Only for 'draw' mode. The kind of shape to draw.
- In 'polygon', 'rectangle', 'line', 'vline', 'hline',
- 'freeline'.
- Default is 'polygon'.
- :param str label: Only for 'draw' mode, sent in drawing events.
- :param bool zoomOnWheel: Toggle zoom on wheel support
+ :param shape: Only for 'draw' mode. The kind of shape to draw.
+ In 'polygon', 'rectangle', 'line', 'vline', 'hline',
+ 'freeline'.
+ Default is 'polygon'.
+ :param label: Only for 'draw' mode, sent in drawing events.
+ :param zoomOnWheel: Toggle zoom on wheel support
:param source: A user-defined object (typically the caller object)
that will be send in the interactiveModeChanged event,
to identify which object required a mode change.
Default: None
- :param float width: Width of the pencil. Only for draw pencil mode.
+ :param width: Width of the pencil. Only for draw pencil mode.
"""
- self._eventHandler.setInteractiveMode(mode, color, shape, label, width)
- self._eventHandler.zoomOnWheel = zoomOnWheel
+ self.__isInteractionSignalForwarded = False
+ try:
+ self._eventHandler._setInteractiveMode(mode, color, shape, label, width)
+ self._eventHandler.setZoomOnWheelEnabled(zoomOnWheel)
+ finally:
+ self.__isInteractionSignalForwarded = True
+
if mode in ["pan", "zoom"]:
self._previousDefaultMode = mode, zoomOnWheel
- self.notify(
- 'interactiveModeChanged', source=source)
+ self.notify("interactiveModeChanged", source=source)
# Panning with arrow keys
@@ -3613,10 +3753,10 @@ class PlotWidget(qt.QMainWindow):
# Dict to convert Qt arrow key code to direction str.
_ARROWS_TO_PAN_DIRECTION = {
- qt.Qt.Key_Left: 'left',
- qt.Qt.Key_Right: 'right',
- qt.Qt.Key_Up: 'up',
- qt.Qt.Key_Down: 'down'
+ qt.Qt.Key_Left: "left",
+ qt.Qt.Key_Right: "right",
+ qt.Qt.Key_Up: "up",
+ qt.Qt.Key_Down: "down",
}
def __simulateMouseMove(self):
@@ -3626,7 +3766,8 @@ class PlotWidget(qt.QMainWindow):
qt.QPointF(self.getWidgetHandle().mapFromGlobal(qt.QCursor.pos())),
qt.Qt.NoButton,
qapp.mouseButtons(),
- qapp.keyboardModifiers())
+ qapp.keyboardModifiers(),
+ )
qapp.sendEvent(self.getWidgetHandle(), event)
def keyPressEvent(self, event):
diff --git a/src/silx/gui/plot/PlotWindow.py b/src/silx/gui/plot/PlotWindow.py
index e8da174..9aa8c78 100644
--- a/src/silx/gui/plot/PlotWindow.py
+++ b/src/silx/gui/plot/PlotWindow.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2004-2022 European Synchrotron Radiation Facility
+# Copyright (c) 2004-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -30,16 +30,12 @@ __authors__ = ["V.A. Sole", "T. Vincent"]
__license__ = "MIT"
__date__ = "12/04/2019"
-try:
- from collections import abc
-except ImportError: # Python2 support
- import collections as abc
+from collections import abc
import logging
import weakref
import silx
from silx.utils.weakref import WeakMethodProxy
-from silx.utils.deprecation import deprecated
from silx.utils.proxy import docstring
from . import PlotWidget
@@ -57,6 +53,7 @@ from .CurvesROIWidget import CurvesROIDockWidget
from .MaskToolsWidget import MaskToolsDockWidget
from .StatsWidget import BasicStatsWidget
from .ColorBar import ColorBarWidget
+
try:
from ..console import IPythonDockWidget
except ImportError:
@@ -103,16 +100,30 @@ class PlotWindow(PlotWidget):
:param bool fit: Toggle visibilty of fit action.
"""
- def __init__(self, parent=None, backend=None,
- resetzoom=True, autoScale=True, logScale=True, grid=True,
- curveStyle=True, colormap=True,
- aspectRatio=True, yInverted=True,
- copy=True, save=True, print_=True,
- control=False, position=False,
- roi=True, mask=True, fit=False):
+ def __init__(
+ self,
+ parent=None,
+ backend=None,
+ resetzoom=True,
+ autoScale=True,
+ logScale=True,
+ grid=True,
+ curveStyle=True,
+ colormap=True,
+ aspectRatio=True,
+ yInverted=True,
+ copy=True,
+ save=True,
+ print_=True,
+ control=False,
+ position=False,
+ roi=True,
+ mask=True,
+ fit=False,
+ ):
super(PlotWindow, self).__init__(parent=parent, backend=backend)
if parent is None:
- self.setWindowTitle('PlotWindow')
+ self.setWindowTitle("PlotWindow")
self._dockWidgets = []
@@ -131,63 +142,80 @@ class PlotWindow(PlotWidget):
self.group.setExclusive(False)
self.resetZoomAction = self.group.addAction(
- actions.control.ResetZoomAction(self, parent=self))
+ actions.control.ResetZoomAction(self, parent=self)
+ )
self.resetZoomAction.setVisible(resetzoom)
self.addAction(self.resetZoomAction)
- self.zoomInAction = actions.control.ZoomInAction(self, parent=self)
+ self.zoomInAction = self.group.addAction(
+ actions.control.ZoomInAction(self, parent=self)
+ )
+ self.zoomInAction.setVisible(False)
self.addAction(self.zoomInAction)
- self.zoomOutAction = actions.control.ZoomOutAction(self, parent=self)
+ self.zoomOutAction = self.group.addAction(
+ actions.control.ZoomOutAction(self, parent=self)
+ )
+ self.zoomOutAction.setVisible(False)
self.addAction(self.zoomOutAction)
self.xAxisAutoScaleAction = self.group.addAction(
- actions.control.XAxisAutoScaleAction(self, parent=self))
+ actions.control.XAxisAutoScaleAction(self, parent=self)
+ )
self.xAxisAutoScaleAction.setVisible(autoScale)
self.addAction(self.xAxisAutoScaleAction)
self.yAxisAutoScaleAction = self.group.addAction(
- actions.control.YAxisAutoScaleAction(self, parent=self))
+ actions.control.YAxisAutoScaleAction(self, parent=self)
+ )
self.yAxisAutoScaleAction.setVisible(autoScale)
self.addAction(self.yAxisAutoScaleAction)
self.xAxisLogarithmicAction = self.group.addAction(
- actions.control.XAxisLogarithmicAction(self, parent=self))
+ actions.control.XAxisLogarithmicAction(self, parent=self)
+ )
self.xAxisLogarithmicAction.setVisible(logScale)
self.addAction(self.xAxisLogarithmicAction)
self.yAxisLogarithmicAction = self.group.addAction(
- actions.control.YAxisLogarithmicAction(self, parent=self))
+ actions.control.YAxisLogarithmicAction(self, parent=self)
+ )
self.yAxisLogarithmicAction.setVisible(logScale)
self.addAction(self.yAxisLogarithmicAction)
self.gridAction = self.group.addAction(
- actions.control.GridAction(self, gridMode='both', parent=self))
+ actions.control.GridAction(self, gridMode="both", parent=self)
+ )
self.gridAction.setVisible(grid)
self.addAction(self.gridAction)
self.curveStyleAction = self.group.addAction(
- actions.control.CurveStyleAction(self, parent=self))
+ actions.control.CurveStyleAction(self, parent=self)
+ )
self.curveStyleAction.setVisible(curveStyle)
self.addAction(self.curveStyleAction)
self.colormapAction = self.group.addAction(
- actions.control.ColormapAction(self, parent=self))
+ actions.control.ColormapAction(self, parent=self)
+ )
self.colormapAction.setVisible(colormap)
self.addAction(self.colormapAction)
self.colorbarAction = self.group.addAction(
- actions_control.ColorBarAction(self, parent=self))
+ actions_control.ColorBarAction(self, parent=self)
+ )
self.colorbarAction.setVisible(False)
self.addAction(self.colorbarAction)
self._colorbar.setVisible(False)
self.keepDataAspectRatioButton = PlotToolButtons.AspectToolButton(
- parent=self, plot=self)
+ parent=self, plot=self
+ )
self.keepDataAspectRatioButton.setVisible(aspectRatio)
self.yAxisInvertedButton = PlotToolButtons.YAxisOriginToolButton(
- parent=self, plot=self)
+ parent=self, plot=self
+ )
self.yAxisInvertedButton.setVisible(yInverted)
self.group.addAction(self.getRoiAction())
@@ -197,15 +225,18 @@ class PlotWindow(PlotWidget):
self.getMaskAction().setVisible(mask)
self._intensityHistoAction = self.group.addAction(
- actions_histogram.PixelIntensitiesHistoAction(self, parent=self))
+ actions_histogram.PixelIntensitiesHistoAction(self, parent=self)
+ )
self._intensityHistoAction.setVisible(False)
self._medianFilter2DAction = self.group.addAction(
- actions_medfilt.MedianFilter2DAction(self, parent=self))
+ actions_medfilt.MedianFilter2DAction(self, parent=self)
+ )
self._medianFilter2DAction.setVisible(False)
self._medianFilter1DAction = self.group.addAction(
- actions_medfilt.MedianFilter1DAction(self, parent=self))
+ actions_medfilt.MedianFilter1DAction(self, parent=self)
+ )
self._medianFilter1DAction.setVisible(False)
self.fitAction = self.group.addAction(actions_fit.FitAction(self, parent=self))
@@ -239,24 +270,25 @@ class PlotWindow(PlotWidget):
converters = position
else:
converters = None
- self._positionWidget = tools.PositionInfo(
- plot=self, converters=converters)
+ self._positionWidget = tools.PositionInfo(plot=self, converters=converters)
# Set a snapping mode that is consistent with legacy one
self._positionWidget.setSnappingMode(
- tools.PositionInfo.SNAPPING_CROSSHAIR |
- tools.PositionInfo.SNAPPING_ACTIVE_ONLY |
- tools.PositionInfo.SNAPPING_SYMBOLS_ONLY |
- tools.PositionInfo.SNAPPING_CURVE |
- tools.PositionInfo.SNAPPING_SCATTER)
+ tools.PositionInfo.SNAPPING_CROSSHAIR
+ | tools.PositionInfo.SNAPPING_ACTIVE_ONLY
+ | tools.PositionInfo.SNAPPING_SYMBOLS_ONLY
+ | tools.PositionInfo.SNAPPING_CURVE
+ | tools.PositionInfo.SNAPPING_SCATTER
+ )
self.__setCentralWidget()
# Creating the toolbar also create actions for toolbuttons
self._interactiveModeToolBar = tools.InteractiveModeToolBar(
- parent=self, plot=self)
+ parent=self, plot=self
+ )
self.addToolBar(self._interactiveModeToolBar)
- self._toolbar = self._createToolBar(title='Plot', parent=self)
+ self._toolbar = self._createToolBar(title="Plot", parent=self)
self.addToolBar(self._toolbar)
self._outputToolBar = tools.OutputToolBar(parent=self, plot=self)
@@ -352,11 +384,6 @@ class PlotWindow(PlotWidget):
"""
return self._outputToolBar
- @property
- @deprecated(replacement="getPositionInfoWidget()", since_version="0.8.0")
- def positionWidget(self):
- return self.getPositionInfoWidget()
-
def getPositionInfoWidget(self):
"""Returns the widget displaying current cursor position information
@@ -395,12 +422,12 @@ class PlotWindow(PlotWidget):
banner = "The variable 'plt' is available. Use the 'whos' "
banner += "and 'help(plt)' commands for more information.\n\n"
self._consoleDockWidget = IPythonDockWidget(
- available_vars=available_vars,
- custom_banner=banner,
- parent=self)
+ available_vars=available_vars, custom_banner=banner, parent=self
+ )
self.addTabbedDockWidget(self._consoleDockWidget)
self._consoleDockWidget.toggleViewAction().toggled.connect(
- self._consoleDockWidgetToggled)
+ self._consoleDockWidgetToggled
+ )
self._consoleDockWidget.setVisible(isChecked)
@@ -440,15 +467,14 @@ class PlotWindow(PlotWidget):
elif obj is self.yAxisInvertedButton:
self.yAxisInvertedAction = toolbar.addWidget(obj)
else:
- raise RuntimeError()
+ raise RuntimeError("unknow action to be defined")
return toolbar
def toolBar(self):
- """Return a QToolBar from the QAction of the PlotWindow.
- """
+ """Return a QToolBar from the QAction of the PlotWindow."""
return self._toolbar
- def menu(self, title='Plot', parent=None):
+ def menu(self, title="Plot", parent=None):
"""Return a QMenu from the QAction of the PlotWindow.
:param str title: The title of the QMenu
@@ -495,8 +521,7 @@ class PlotWindow(PlotWidget):
self.addDockWidget(area, dock_widget)
else:
# Other dock widgets are added as tabs to the same widget area
- self.tabifyDockWidget(self._dockWidgets[0],
- dock_widget)
+ self.tabifyDockWidget(self._dockWidgets[0], dock_widget)
def removeDockWidget(self, dockwidget):
"""Removes the *dockwidget* from the main window layout and hides it.
@@ -521,8 +546,7 @@ class PlotWindow(PlotWidget):
"""
if visible:
dockWidget = self.sender()
- dockWidget.visibilityChanged.disconnect(
- self._handleFirstDockWidgetShow)
+ dockWidget.visibilityChanged.disconnect(self._handleFirstDockWidgetShow)
self.addTabbedDockWidget(dockWidget)
def _handleDockWidgetViewActionTriggered(self, checked):
@@ -551,9 +575,11 @@ class PlotWindow(PlotWidget):
self._legendsDockWidget = LegendsDockWidget(plot=self)
self._legendsDockWidget.hide()
self._legendsDockWidget.toggleViewAction().triggered.connect(
- self._handleDockWidgetViewActionTriggered)
+ self._handleDockWidgetViewActionTriggered
+ )
self._legendsDockWidget.visibilityChanged.connect(
- self._handleFirstDockWidgetShow)
+ self._handleFirstDockWidgetShow
+ )
return self._legendsDockWidget
def getCurvesRoiDockWidget(self):
@@ -561,12 +587,15 @@ class PlotWindow(PlotWidget):
# (still used internally for lazy loading)
if self._curvesROIDockWidget is None:
self._curvesROIDockWidget = CurvesROIDockWidget(
- plot=self, name='Regions Of Interest')
+ plot=self, name="Regions Of Interest"
+ )
self._curvesROIDockWidget.hide()
self._curvesROIDockWidget.toggleViewAction().triggered.connect(
- self._handleDockWidgetViewActionTriggered)
+ self._handleDockWidgetViewActionTriggered
+ )
self._curvesROIDockWidget.visibilityChanged.connect(
- self._handleFirstDockWidgetShow)
+ self._handleFirstDockWidgetShow
+ )
return self._curvesROIDockWidget
def getCurvesRoiWidget(self):
@@ -583,13 +612,14 @@ class PlotWindow(PlotWidget):
def getMaskToolsDockWidget(self):
"""DockWidget with image mask panel (lazy-loaded)."""
if self._maskToolsDockWidget is None:
- self._maskToolsDockWidget = MaskToolsDockWidget(
- plot=self, name='Mask')
+ self._maskToolsDockWidget = MaskToolsDockWidget(plot=self, name="Mask")
self._maskToolsDockWidget.hide()
self._maskToolsDockWidget.toggleViewAction().triggered.connect(
- self._handleDockWidgetViewActionTriggered)
+ self._handleDockWidgetViewActionTriggered
+ )
self._maskToolsDockWidget.visibilityChanged.connect(
- self._handleFirstDockWidgetShow)
+ self._handleFirstDockWidgetShow
+ )
return self._maskToolsDockWidget
def getStatsWidget(self):
@@ -605,23 +635,14 @@ class PlotWindow(PlotWidget):
self._statsDockWidget.setWidget(statsWidget)
self._statsDockWidget.hide()
self._statsDockWidget.toggleViewAction().triggered.connect(
- self._handleDockWidgetViewActionTriggered)
+ self._handleDockWidgetViewActionTriggered
+ )
self._statsDockWidget.visibilityChanged.connect(
- self._handleFirstDockWidgetShow)
+ self._handleFirstDockWidgetShow
+ )
return self._statsDockWidget.widget()
# getters for actions
- @property
- @deprecated(replacement="getInteractiveModeToolBar().getZoomModeAction()",
- since_version="0.8.0")
- def zoomModeAction(self):
- return self.getInteractiveModeToolBar().getZoomModeAction()
-
- @property
- @deprecated(replacement="getInteractiveModeToolBar().getPanModeAction()",
- since_version="0.8.0")
- def panModeAction(self):
- return self.getInteractiveModeToolBar().getPanModeAction()
def getConsoleAction(self):
"""QAction handling the IPython console activation.
@@ -634,7 +655,7 @@ class PlotWindow(PlotWidget):
:rtype: QAction
"""
if self._consoleAction is None:
- self._consoleAction = qt.QAction('Console', self)
+ self._consoleAction = qt.QAction("Console", self)
self._consoleAction.setCheckable(True)
if IPythonDockWidget is not None:
self._consoleAction.toggled.connect(self._toggleConsoleVisibility)
@@ -650,7 +671,7 @@ class PlotWindow(PlotWidget):
:rtype: actions.PlotAction
"""
if self._crosshairAction is None:
- self._crosshairAction = actions.control.CrosshairAction(self, color='red')
+ self._crosshairAction = actions.control.CrosshairAction(self, color="red")
return self._crosshairAction
def getMaskAction(self):
@@ -854,22 +875,36 @@ class Plot1D(PlotWindow):
"""
def __init__(self, parent=None, backend=None):
- super(Plot1D, self).__init__(parent=parent, backend=backend,
- resetzoom=True, autoScale=True,
- logScale=True, grid=True,
- curveStyle=True, colormap=False,
- aspectRatio=False, yInverted=False,
- copy=True, save=True, print_=True,
- control=True, position=True,
- roi=True, mask=False, fit=True)
+ super(Plot1D, self).__init__(
+ parent=parent,
+ backend=backend,
+ resetzoom=True,
+ autoScale=True,
+ logScale=True,
+ grid=True,
+ curveStyle=True,
+ colormap=False,
+ aspectRatio=False,
+ yInverted=False,
+ copy=True,
+ save=True,
+ print_=True,
+ control=True,
+ position=True,
+ roi=True,
+ mask=False,
+ fit=True,
+ )
if parent is None:
- self.setWindowTitle('Plot1D')
- self.getXAxis().setLabel('X')
- self.getYAxis().setLabel('Y')
+ self.setWindowTitle("Plot1D")
+ self.getXAxis().setLabel("X")
+ self.getYAxis().setLabel("Y")
action = self.getFitAction()
action.setXRangeUpdatedOnZoom(True)
action.setFittedItemUpdatedFromActiveCurve(True)
+ self.getInteractiveModeToolBar().getZoomModeAction().setAxesMenuEnabled(True)
+
class Plot2D(PlotWindow):
"""PlotWindow with a toolbar specific for images.
@@ -885,26 +920,37 @@ class Plot2D(PlotWindow):
def __init__(self, parent=None, backend=None):
# List of information to display at the bottom of the plot
posInfo = [
- ('X', lambda x, y: x),
- ('Y', lambda x, y: y),
- ('Data', WeakMethodProxy(self._getImageValue)),
- ('Dims', WeakMethodProxy(self._getImageDims)),
+ ("X", lambda x, y: x),
+ ("Y", lambda x, y: y),
+ ("Data", WeakMethodProxy(self._getImageValue)),
+ ("Dims", WeakMethodProxy(self._getImageDims)),
]
- super(Plot2D, self).__init__(parent=parent, backend=backend,
- resetzoom=True, autoScale=False,
- logScale=False, grid=False,
- curveStyle=False, colormap=True,
- aspectRatio=True, yInverted=True,
- copy=True, save=True, print_=True,
- control=False, position=posInfo,
- roi=False, mask=True)
+ super(Plot2D, self).__init__(
+ parent=parent,
+ backend=backend,
+ resetzoom=True,
+ autoScale=False,
+ logScale=False,
+ grid=False,
+ curveStyle=False,
+ colormap=True,
+ aspectRatio=True,
+ yInverted=True,
+ copy=True,
+ save=True,
+ print_=True,
+ control=False,
+ position=posInfo,
+ roi=False,
+ mask=True,
+ )
if parent is None:
- self.setWindowTitle('Plot2D')
- self.getXAxis().setLabel('Columns')
- self.getYAxis().setLabel('Rows')
+ self.setWindowTitle("Plot2D")
+ self.getXAxis().setLabel("Columns")
+ self.getYAxis().setLabel("Rows")
- if silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION == 'downward':
+ if silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION == "downward":
self.getYAxis().setInverted(True)
self.profile = ProfileToolBar(plot=self)
@@ -959,8 +1005,9 @@ class Plot2D(PlotWindow):
"""
pickedMask = None
for picked in self.pickItems(
- *self.dataToPixel(x, y, check=False),
- lambda item: isinstance(item, items.ImageBase)):
+ *self.dataToPixel(x, y, check=False),
+ lambda item: isinstance(item, items.ImageBase),
+ ):
if isinstance(picked.getItem(), items.MaskImageData):
if pickedMask is None: # Use top-most if many masks
pickedMask = picked
@@ -980,16 +1027,15 @@ class Plot2D(PlotWindow):
return value, "Masked"
return value
- return '-' # No image picked
+ return "-" # No image picked
def _getImageDims(self, *args):
activeImage = self.getActiveImage()
- if (activeImage is not None and
- activeImage.getData(copy=False) is not None):
+ if activeImage is not None and activeImage.getData(copy=False) is not None:
dims = activeImage.getData(copy=False).shape[1::-1]
- return 'x'.join(str(dim) for dim in dims)
+ return "x".join(str(dim) for dim in dims)
else:
- return '-'
+ return "-"
def getProfileToolbar(self):
"""Profile tools attached to this plot
@@ -998,10 +1044,6 @@ class Plot2D(PlotWindow):
"""
return self.profile
- @deprecated(replacement="getProfilePlot", since_version="0.5.0")
- def getProfileWindow(self):
- return self.getProfilePlot()
-
def getProfilePlot(self):
"""Return plot window used to display profile curve.
diff --git a/src/silx/gui/plot/PrintPreviewToolButton.py b/src/silx/gui/plot/PrintPreviewToolButton.py
index 9069ac3..0812420 100644
--- a/src/silx/gui/plot/PrintPreviewToolButton.py
+++ b/src/silx/gui/plot/PrintPreviewToolButton.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -109,7 +109,6 @@ from .. import icons
from . import PlotWidget
from ..widgets.PrintPreview import PrintPreviewDialog, SingletonPrintPreviewDialog
from ..widgets.PrintGeometryDialog import PrintGeometryDialog
-from silx.utils.deprecation import deprecated
__authors__ = ["P. Knobel"]
__license__ = "MIT"
@@ -126,6 +125,7 @@ class PrintPreviewToolButton(qt.QToolButton):
:param parent: See :class:`QAction`
:param plot: :class:`.PlotWidget` instance on which to operate
"""
+
def __init__(self, parent=None, plot=None):
super(PrintPreviewToolButton, self).__init__(parent)
@@ -133,17 +133,19 @@ class PrintPreviewToolButton(qt.QToolButton):
raise TypeError("plot parameter must be a PlotWidget")
self._plot = plot
- self.setIcon(icons.getQIcon('document-print'))
+ self.setIcon(icons.getQIcon("document-print"))
printGeomAction = qt.QAction("Print geometry", self)
- printGeomAction.setToolTip("Define a print geometry prior to sending "
- "the plot to the print preview dialog")
- printGeomAction.setIcon(icons.getQIcon('shape-rectangle'))
+ printGeomAction.setToolTip(
+ "Define a print geometry prior to sending "
+ "the plot to the print preview dialog"
+ )
+ printGeomAction.setIcon(icons.getQIcon("shape-rectangle"))
printGeomAction.triggered.connect(self._setPrintConfiguration)
printPreviewAction = qt.QAction("Print preview", self)
printPreviewAction.setToolTip("Send plot to the print preview dialog")
- printPreviewAction.setIcon(icons.getQIcon('document-print'))
+ printPreviewAction.setIcon(icons.getQIcon("document-print"))
printPreviewAction.triggered.connect(self._plotToPrintPreview)
menu = qt.QMenu(self)
@@ -155,12 +157,14 @@ class PrintPreviewToolButton(qt.QToolButton):
self._printPreviewDialog = None
self._printConfigurationDialog = None
- self._printGeometry = {"xOffset": 0.1,
- "yOffset": 0.1,
- "width": 0.9,
- "height": 0.9,
- "units": "page",
- "keepAspectRatio": True}
+ self._printGeometry = {
+ "xOffset": 0.1,
+ "yOffset": 0.1,
+ "width": 0.9,
+ "height": 0.9,
+ "units": "page",
+ "keepAspectRatio": True,
+ }
@property
def printPreviewDialog(self):
@@ -189,12 +193,6 @@ class PrintPreviewToolButton(qt.QToolButton):
"""
return None, None
- @property
- @deprecated(since_version="0.10",
- replacement="getPlot()")
- def plot(self):
- return self._plot
-
def getPlot(self):
"""Return the :class:`.PlotWidget` associated with this tool button.
@@ -212,19 +210,23 @@ class PrintPreviewToolButton(qt.QToolButton):
if qt.HAS_SVG:
svgRenderer, viewBox = self._getSvgRendererAndViewbox()
- self.printPreviewDialog.addSvgItem(svgRenderer,
- title=self.getTitle(),
- comment=comment,
- commentPosition=commentPosition,
- viewBox=viewBox,
- keepRatio=self._printGeometry["keepAspectRatio"])
+ self.printPreviewDialog.addSvgItem(
+ svgRenderer,
+ title=self.getTitle(),
+ comment=comment,
+ commentPosition=commentPosition,
+ viewBox=viewBox,
+ keepRatio=self._printGeometry["keepAspectRatio"],
+ )
else:
_logger.warning("Missing QtSvg library, using a raster image")
pixmap = self._plot.centralWidget().grab()
- self.printPreviewDialog.addPixmap(pixmap,
- title=self.getTitle(),
- comment=comment,
- commentPosition=commentPosition)
+ self.printPreviewDialog.addPixmap(
+ pixmap,
+ title=self.getTitle(),
+ comment=comment,
+ commentPosition=commentPosition,
+ )
self.printPreviewDialog.show()
self.printPreviewDialog.raise_()
@@ -236,8 +238,7 @@ class PrintPreviewToolButton(qt.QToolButton):
and to the geometry configuration (width, height, ratio) specified
by the user."""
imgData = StringIO()
- assert self._plot.saveGraph(imgData, fileFormat="svg"), \
- "Unable to save graph"
+ assert self._plot.saveGraph(imgData, fileFormat="svg"), "Unable to save graph"
imgData.flush()
imgData.seek(0)
svgData = imgData.read()
@@ -261,8 +262,7 @@ class PrintPreviewToolButton(qt.QToolButton):
return svgRenderer, viewbox
def _getViewBox(self):
- """
- """
+ """ """
printer = self.printPreviewDialog.printer
dpix = printer.logicalDpiX()
dpiy = printer.logicalDpiY()
@@ -270,23 +270,23 @@ class PrintPreviewToolButton(qt.QToolButton):
availableHeight = printer.height()
config = self._printGeometry
- width = config['width']
- height = config['height']
- xOffset = config['xOffset']
- yOffset = config['yOffset']
- units = config['units']
- keepAspectRatio = config['keepAspectRatio']
+ width = config["width"]
+ height = config["height"]
+ xOffset = config["xOffset"]
+ yOffset = config["yOffset"]
+ units = config["units"]
+ keepAspectRatio = config["keepAspectRatio"]
aspectRatio = self._getPlotAspectRatio()
# convert the offsets to dots
- if units.lower() in ['inch', 'inches']:
+ if units.lower() in ["inch", "inches"]:
xOffset = xOffset * dpix
yOffset = yOffset * dpiy
if width is not None:
width = width * dpix
if height is not None:
height = height * dpiy
- elif units.lower() in ['cm', 'centimeters']:
+ elif units.lower() in ["cm", "centimeters"]:
xOffset = (xOffset / 2.54) * dpix
yOffset = (yOffset / 2.54) * dpiy
if width is not None:
@@ -307,13 +307,17 @@ class PrintPreviewToolButton(qt.QToolButton):
if width is not None:
if (availableWidth + 0.1) < width:
- txt = "Available width %f is less than requested width %f" % \
- (availableWidth, width)
+ txt = "Available width %f is less than requested width %f" % (
+ availableWidth,
+ width,
+ )
raise ValueError(txt)
if height is not None:
if (availableHeight + 0.1) < height:
- txt = "Available height %f is less than requested height %f" % \
- (availableHeight, height)
+ txt = "Available height %f is less than requested height %f" % (
+ availableHeight,
+ height,
+ )
raise ValueError(txt)
if keepAspectRatio:
@@ -328,10 +332,7 @@ class PrintPreviewToolButton(qt.QToolButton):
bodyWidth = width or availableWidth
bodyHeight = height or availableHeight
- return qt.QRectF(xOffset,
- yOffset,
- bodyWidth,
- bodyHeight)
+ return qt.QRectF(xOffset, yOffset, bodyWidth, bodyHeight)
def _setPrintConfiguration(self):
"""Open a dialog to prompt the user to adjust print
@@ -357,6 +358,7 @@ class SingletonPrintPreviewToolButton(PrintPreviewToolButton):
This allows for several plots to send their content to the
same print page, and for users to arrange them."""
+
def __init__(self, parent=None, plot=None):
PrintPreviewToolButton.__init__(self, parent, plot)
@@ -367,14 +369,14 @@ class SingletonPrintPreviewToolButton(PrintPreviewToolButton):
return self._printPreviewDialog
-if __name__ == '__main__':
+if __name__ == "__main__":
import numpy
+
app = qt.QApplication([])
pw = PlotWidget()
toolbar = qt.QToolBar(pw)
- toolbutton = PrintPreviewToolButton(parent=toolbar,
- plot=pw)
+ toolbutton = PrintPreviewToolButton(parent=toolbar, plot=pw)
pw.addToolBar(toolbar)
toolbar.addWidget(toolbutton)
pw.show()
diff --git a/src/silx/gui/plot/Profile.py b/src/silx/gui/plot/Profile.py
index bf793c8..f89f780 100644
--- a/src/silx/gui/plot/Profile.py
+++ b/src/silx/gui/plot/Profile.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+# Copyright (c) 2004-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -34,27 +34,18 @@ import weakref
from .. import qt
from . import actions
-from .tools.profile import core
from .tools.profile import manager
from .tools.profile import rois
from silx.gui.widgets.MultiModeAction import MultiModeAction
-from silx.utils.deprecation import deprecated
-from silx.utils.deprecation import deprecated_warning
from .tools import roi as roi_mdl
from silx.gui.plot import items
-@deprecated(replacement="silx.gui.plot.tools.profile.createProfile", since_version="0.13.0")
-def createProfile(roiInfo, currentData, origin, scale, lineWidth, method):
- return core.createProfile(roiInfo, currentData, origin,
- scale, lineWidth, method)
-
-
class _CustomProfileManager(manager.ProfileManager):
"""This custom profile manager uses a single predefined profile window
if it is specified. Else the behavior is the same as the default
- ProfileManager """
+ ProfileManager"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -78,7 +69,10 @@ class _CustomProfileManager(manager.ProfileManager):
self.__profileWindow = profileWindow
def createProfileWindow(self, plot, roi):
- for roiClass, specializedProfileWindow in self.__specializedProfileWindows.items():
+ for (
+ roiClass,
+ specializedProfileWindow,
+ ) in self.__specializedProfileWindows.items():
if isinstance(roi, roiClass):
return specializedProfileWindow
@@ -121,23 +115,13 @@ class ProfileToolBar(qt.QToolBar):
:param plot: :class:`PlotWindow` instance on which to operate.
:param profileWindow: Plot widget instance where to
display the profile curve or None to create one.
- :param str title: See :class:`QToolBar`.
:param parent: See :class:`QToolBar`.
"""
- def __init__(self, parent=None, plot=None, profileWindow=None,
- title=None):
- super(ProfileToolBar, self).__init__(title, parent)
+ def __init__(self, parent=None, plot=None, profileWindow=None):
+ super(ProfileToolBar, self).__init__(parent)
assert plot is not None
- if title is not None:
- deprecated_warning("Attribute",
- name="title",
- reason="removed",
- since_version="0.13.0",
- only_once=True,
- skip_backtrace_count=1)
-
self._plotRef = weakref.ref(plot)
# If a profileWindow is defined,
@@ -185,22 +169,27 @@ class ProfileToolBar(qt.QToolBar):
return _CustomProfileManager(parent, plot)
def _createProfileActions(self):
- self.hLineAction = self._manager.createProfileAction(rois.ProfileImageHorizontalLineROI, self)
- self.vLineAction = self._manager.createProfileAction(rois.ProfileImageVerticalLineROI, self)
- self.lineAction = self._manager.createProfileAction(rois.ProfileImageLineROI, self)
- self.freeLineAction = self._manager.createProfileAction(rois.ProfileImageDirectedLineROI, self)
- self.crossAction = self._manager.createProfileAction(rois.ProfileImageCrossROI, self)
+ self.hLineAction = self._manager.createProfileAction(
+ rois.ProfileImageHorizontalLineROI, self
+ )
+ self.vLineAction = self._manager.createProfileAction(
+ rois.ProfileImageVerticalLineROI, self
+ )
+ self.lineAction = self._manager.createProfileAction(
+ rois.ProfileImageLineROI, self
+ )
+ self.freeLineAction = self._manager.createProfileAction(
+ rois.ProfileImageDirectedLineROI, self
+ )
+ self.crossAction = self._manager.createProfileAction(
+ rois.ProfileImageCrossROI, self
+ )
self.clearAction = self._manager.createClearAction(self)
def getPlotWidget(self):
"""The :class:`.PlotWidget` associated to the toolbar."""
return self._plotRef()
- @property
- @deprecated(since_version="0.13.0", replacement="getPlotWidget()")
- def plot(self):
- return self.getPlotWidget()
-
def _setRoiActionEnabled(self, itemKind, enabled):
for action in self.__multiAction.getMenu().actions():
if not isinstance(action, roi_mdl.CreateRoiModeAction):
@@ -221,16 +210,6 @@ class ProfileToolBar(qt.QToolBar):
enabled = image.getData(copy=False).size > 0
self._setRoiActionEnabled(type(image), enabled)
- @property
- @deprecated(since_version="0.6.0")
- def browseAction(self):
- return self._browseAction
-
- @property
- @deprecated(replacement="getProfilePlot", since_version="0.5.0")
- def profileWindow(self):
- return self.getProfilePlot()
-
def getProfileManager(self):
"""Return the manager of the profiles.
@@ -238,114 +217,38 @@ class ProfileToolBar(qt.QToolBar):
"""
return self._manager
- @deprecated(since_version="0.13.0")
- def getProfilePlot(self):
- """Return plot widget in which the profile curve or the
- profile image is plotted.
- """
- window = self.getProfileMainWindow()
- if window is None:
- return None
- return window.getCurrentPlotWidget()
-
- @deprecated(replacement="getProfileManager().getCurrentRoi().getProfileWindow()", since_version="0.13.0")
- def getProfileMainWindow(self):
- """Return window containing the profile curve widget.
-
- This can return None if no profile was computed.
- """
- roi = self._manager.getCurrentRoi()
- if roi is None:
- return None
- return roi.getProfileWindow()
-
- @property
- @deprecated(since_version="0.13.0")
- def overlayColor(self):
- """This method does nothing anymore. But could be implemented if needed.
-
- It was used to set color to use for the ROI.
-
- If set to None (the default), the overlay color is adapted to the
- active image colormap and changes if the active image colormap changes.
- """
- pass
-
- @overlayColor.setter
- @deprecated(since_version="0.13.0")
- def overlayColor(self, color):
- """This method does nothing anymore. But could be implemented if needed.
- """
- pass
-
def clearProfile(self):
"""Remove profile curve and profile area."""
self._manager.clearProfile()
- @deprecated(since_version="0.13.0")
- def updateProfile(self):
- """This method does nothing anymore. But could be implemented if needed.
-
- It was used to update the displayed profile and profile ROI.
-
- This uses the current active image of the plot and the current ROI.
- """
- pass
-
- @deprecated(replacement="clearProfile()", since_version="0.13.0")
- def hideProfileWindow(self):
- """Hide profile window.
- """
- self.clearProfile()
-
- @deprecated(since_version="0.13.0")
- def setProfileMethod(self, method):
- assert method in ('sum', 'mean')
- roi = self._manager.getCurrentRoi()
- if roi is None:
- raise RuntimeError("No profile ROI selected")
- roi.setProfileMethod(method)
-
- @deprecated(since_version="0.13.0")
- def getProfileMethod(self):
- roi = self._manager.getCurrentRoi()
- if roi is None:
- raise RuntimeError("No profile ROI selected")
- return roi.getProfileMethod()
-
- @deprecated(since_version="0.13.0")
- def getProfileOptionToolAction(self):
- return self._editor
-
class Profile3DToolBar(ProfileToolBar):
- def __init__(self, parent=None, stackview=None,
- title=None):
+ def __init__(self, parent=None, stackview=None):
"""QToolBar providing profile tools for an image or a stack of images.
:param parent: the parent QWidget
:param stackview: :class:`StackView` instance on which to operate.
- :param str title: See :class:`QToolBar`.
:param parent: See :class:`QToolBar`.
"""
# TODO: add param profileWindow (specify the plot used for profiles)
- super(Profile3DToolBar, self).__init__(parent=parent,
- plot=stackview.getPlotWidget())
-
- if title is not None:
- deprecated_warning("Attribute",
- name="title",
- reason="removed",
- since_version="0.13.0",
- only_once=True,
- skip_backtrace_count=1)
+ super(Profile3DToolBar, self).__init__(
+ parent=parent, plot=stackview.getPlotWidget()
+ )
self.stackView = stackview
""":class:`StackView` instance"""
def _createProfileActions(self):
- self.hLineAction = self._manager.createProfileAction(rois.ProfileImageStackHorizontalLineROI, self)
- self.vLineAction = self._manager.createProfileAction(rois.ProfileImageStackVerticalLineROI, self)
- self.lineAction = self._manager.createProfileAction(rois.ProfileImageStackLineROI, self)
- self.crossAction = self._manager.createProfileAction(rois.ProfileImageStackCrossROI, self)
+ self.hLineAction = self._manager.createProfileAction(
+ rois.ProfileImageStackHorizontalLineROI, self
+ )
+ self.vLineAction = self._manager.createProfileAction(
+ rois.ProfileImageStackVerticalLineROI, self
+ )
+ self.lineAction = self._manager.createProfileAction(
+ rois.ProfileImageStackLineROI, self
+ )
+ self.crossAction = self._manager.createProfileAction(
+ rois.ProfileImageStackCrossROI, self
+ )
self.clearAction = self._manager.createClearAction(self)
diff --git a/src/silx/gui/plot/ProfileMainWindow.py b/src/silx/gui/plot/ProfileMainWindow.py
deleted file mode 100644
index 09a5b41..0000000
--- a/src/silx/gui/plot/ProfileMainWindow.py
+++ /dev/null
@@ -1,109 +0,0 @@
-# /*##########################################################################
-#
-# Copyright (c) 2017-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""This module contains a QMainWindow class used to display profile plots.
-"""
-
-__authors__ = ["P. Knobel"]
-__license__ = "MIT"
-__date__ = "21/02/2017"
-
-import silx.utils.deprecation
-from silx.gui import qt
-from .tools.profile.manager import ProfileWindow
-
-silx.utils.deprecation.deprecated_warning("Module",
- name="silx.gui.plot.ProfileMainWindow",
- reason="moved",
- replacement="silx.gui.plot.tools.profile.manager.ProfileWindow",
- since_version="0.13.0",
- only_once=True,
- skip_backtrace_count=1)
-
-class ProfileMainWindow(ProfileWindow):
- """QMainWindow providing 2 plot widgets specialized in
- 1D and 2D plotting, with different toolbars.
-
- Only one of the plots is visible at any given time.
-
- :param qt.QWidget parent: The parent of this widget or None (default).
- :param Union[str,Class] backend: The backend to use, in:
- 'matplotlib' (default), 'mpl', 'opengl', 'gl', 'none'
- or a :class:`BackendBase.BackendBase` class
- """
-
- sigProfileDimensionsChanged = qt.Signal(int)
- """This signal is emitted when :meth:`setProfileDimensions` is called.
- It carries the number of dimensions for the profile data (1 or 2).
- It can be used to be notified that the profile plot widget has changed.
-
- Note: This signal should be removed.
- """
-
- sigProfileMethodChanged = qt.Signal(str)
- """Emitted when the method to compute the profile changed (for now can be
- sum or mean)
-
- Note: This signal should be removed.
- """
-
- def __init__(self, parent=None, backend=None):
- ProfileWindow.__init__(self, parent=parent, backend=backend)
- # by default, profile is assumed to be a 1D curve
- self._profileType = None
-
- def setProfileType(self, profileType):
- """Set which profile plot widget (1D or 2D) is to be used
-
- Note: This method should be removed.
-
- :param str profileType: Type of profile data,
- "1D" for a curve or "2D" for an image
- """
- self._profileType = profileType
- if self._profileType == "1D":
- self._showPlot1D()
- elif self._profileType == "2D":
- self._showPlot2D()
- else:
- raise ValueError("Profile type must be '1D' or '2D'")
- self.sigProfileDimensionsChanged.emit(profileType)
-
- def getPlot(self):
- """Return the profile plot widget which is currently in use.
- This can be the 2D profile plot or the 1D profile plot.
-
- Note: This method should be removed.
- """
- return self.getCurrentPlotWidget()
-
- def setProfileMethod(self, method):
- """
- Note: This method should be removed.
-
- :param str method: method to manage the 'width' in the profile
- (computing mean or sum).
- """
- assert method in ('sum', 'mean')
- self._method = method
- self.sigProfileMethodChanged.emit(self._method)
diff --git a/src/silx/gui/plot/ROIStatsWidget.py b/src/silx/gui/plot/ROIStatsWidget.py
index 732c60f..36f3391 100644
--- a/src/silx/gui/plot/ROIStatsWidget.py
+++ b/src/silx/gui/plot/ROIStatsWidget.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -34,8 +34,8 @@ __date__ = "22/07/2019"
from contextlib import contextmanager
from silx.gui import qt
from silx.gui import icons
-from silx.gui.plot.StatsWidget import _StatsWidgetBase, StatsTable, _Container
-from silx.gui.plot.StatsWidget import UpdateModeWidget, UpdateMode
+from silx.gui.plot.StatsWidget import _StatsWidgetBase, _Container
+from silx.gui.plot.StatsWidget import UpdateMode
from silx.gui.widgets.TableWidget import TableWidget
from silx.gui.plot.items.roi import RegionOfInterest
from silx.gui.plot import items as plotitems
@@ -43,7 +43,6 @@ from silx.gui.plot.items.core import ItemChangedType
from silx.gui.plot3d import items as plot3ditems
from silx.gui.plot.CurvesROIWidget import ROI
from silx.gui.plot import stats as statsmdl
-from collections import OrderedDict
from silx.utils.proxy import docstring
import silx.gui.plot.items.marker
import silx.gui.plot.items.shape
@@ -57,7 +56,8 @@ class _GetROIItemCoupleDialog(qt.QDialog):
"""
Dialog used to know which plot item and which roi he wants
"""
- _COMPATIBLE_KINDS = ('curve', 'image', 'scatter', 'histogram')
+
+ _COMPATIBLE_KINDS = ("curve", "image", "scatter", "histogram")
def __init__(self, parent=None, plot=None, rois=None):
qt.QDialog.__init__(self, parent=parent)
@@ -92,13 +92,15 @@ class _GetROIItemCoupleDialog(qt.QDialog):
def _getCompatibleRois(self, kind):
"""Return compatible rois for the given item kind"""
+
def is_compatible(roi, kind):
if isinstance(roi, RegionOfInterest):
- return kind in ('image', 'scatter')
+ return kind in ("image", "scatter")
elif isinstance(roi, ROI):
- return kind in ('curve', 'histogram')
+ return kind in ("curve", "histogram")
else:
- raise ValueError('kind not managed')
+ raise ValueError("kind not managed")
+
return list(filter(lambda x: is_compatible(x, kind), self._rois))
def exec(self):
@@ -114,6 +116,7 @@ class _GetROIItemCoupleDialog(qt.QDialog):
self._kind_name_to_item = {}
# key is (kind, legend name) value is item
for kind in _GetROIItemCoupleDialog._COMPATIBLE_KINDS:
+
def getItems(kind):
output = []
for item in self._plot.getItems():
@@ -135,7 +138,7 @@ class _GetROIItemCoupleDialog(qt.QDialog):
# filter roi according to kinds
if len(self._valid_kinds) == 0:
- _logger.warning('no couple item/roi detected for displaying stats')
+ _logger.warning("no couple item/roi detected for displaying stats")
return self.reject()
for kind in self._valid_kinds:
@@ -173,10 +176,11 @@ class ROIStatsItemHelper(object):
Display on one row statistics regarding the couple
(Item (plot item) / roi).
- :param Item plot_item: item for which we want statistics
+ :param Item plot_item: item for which we want statistics
:param Union[ROI,RegionOfInterest]: region of interest to use for
statistics.
"""
+
def __init__(self, plot_item, roi):
self._plot_item = plot_item
self._roi = roi
@@ -192,7 +196,7 @@ class ROIStatsItemHelper(object):
elif isinstance(self._roi, RegionOfInterest):
return self._roi.getName()
else:
- raise TypeError('Unmanaged roi type')
+ raise TypeError("Unmanaged roi type")
@property
def roi_kind(self):
@@ -203,19 +207,21 @@ class ROIStatsItemHelper(object):
def item_kind(self):
"""item kind"""
if isinstance(self._plot_item, plotitems.Curve):
- return 'curve'
+ return "curve"
elif isinstance(self._plot_item, plotitems.ImageData):
- return 'image'
+ return "image"
elif isinstance(self._plot_item, plotitems.Scatter):
- return 'scatter'
+ return "scatter"
elif isinstance(self._plot_item, plotitems.Histogram):
- return 'histogram'
- elif isinstance(self._plot_item, (plot3ditems.ImageData,
- plot3ditems.ScalarField3D)):
- return 'image'
- elif isinstance(self._plot_item, (plot3ditems.Scatter2D,
- plot3ditems.Scatter3D)):
- return 'scatter'
+ return "histogram"
+ elif isinstance(
+ self._plot_item, (plot3ditems.ImageData, plot3ditems.ScalarField3D)
+ ):
+ return "image"
+ elif isinstance(
+ self._plot_item, (plot3ditems.Scatter2D, plot3ditems.Scatter3D)
+ ):
+ return "scatter"
@property
def item_legend(self):
@@ -224,27 +230,28 @@ class ROIStatsItemHelper(object):
def id_key(self):
"""unique key to represent the couple (item, roi)"""
- return (self.item_kind(), self.item_legend, self.roi_kind,
- self.roi_name())
+ return (self.item_kind(), self.item_legend, self.roi_kind, self.roi_name())
class _StatsROITable(_StatsWidgetBase, TableWidget):
"""
Table sued to display some statistics regarding a couple (item/roi)
"""
- _LEGEND_HEADER_DATA = 'legend'
- _KIND_HEADER_DATA = 'kind'
+ _LEGEND_HEADER_DATA = "legend"
+
+ _KIND_HEADER_DATA = "kind"
- _ROI_HEADER_DATA = 'roi'
+ _ROI_HEADER_DATA = "roi"
sigUpdateModeChanged = qt.Signal(object)
"""Signal emitted when the update mode changed"""
def __init__(self, parent, plot):
TableWidget.__init__(self, parent)
- _StatsWidgetBase.__init__(self, statsOnVisibleData=False,
- displayOnlyActItem=False)
+ _StatsWidgetBase.__init__(
+ self, statsOnVisibleData=False, displayOnlyActItem=False
+ )
self.__region_edition_callback = {}
"""We need to keep trace of the roi signals connection because
the roi emits the sigChanged during roi edition"""
@@ -284,8 +291,8 @@ class _StatsROITable(_StatsWidgetBase, TableWidget):
def _addItem(self, item):
"""
Add a _RoiStatsItemWidget item to the table.
-
- :param item:
+
+ :param item:
:return: True if successfully added.
"""
if not isinstance(item, ROIStatsItemHelper):
@@ -307,7 +314,8 @@ class _StatsROITable(_StatsWidgetBase, TableWidget):
tableItems = [
qt.QTableWidgetItem(), # Legend
qt.QTableWidgetItem(), # Kind
- qt.QTableWidgetItem()] # roi
+ qt.QTableWidgetItem(),
+ ] # roi
for column in range(3, self.columnCount()):
header = self.horizontalHeaderItem(column)
@@ -334,8 +342,7 @@ class _StatsROITable(_StatsWidgetBase, TableWidget):
row = self.rowCount() - 1
for column, tableItem in enumerate(tableItems):
tableItem.setData(qt.Qt.UserRole, _Container(item))
- tableItem.setFlags(
- qt.Qt.ItemIsEnabled | qt.Qt.ItemIsSelectable)
+ tableItem.setFlags(qt.Qt.ItemIsEnabled | qt.Qt.ItemIsSelectable)
self.setItem(row, column, tableItem)
# Update table items content
@@ -344,8 +351,9 @@ class _StatsROITable(_StatsWidgetBase, TableWidget):
# Listen for item changes
# Using queued connection to avoid issue with sender
# being that of the signal calling the signal
- item._plot_item.sigItemChanged.connect(self._plotItemChanged,
- qt.Qt.QueuedConnection)
+ item._plot_item.sigItemChanged.connect(
+ self._plotItemChanged, qt.Qt.QueuedConnection
+ )
return True
def _removeAllItems(self):
@@ -369,7 +377,9 @@ class _StatsROITable(_StatsWidgetBase, TableWidget):
_StatsWidgetBase.setStats(self, statsHandler)
self.setRowCount(0)
- self.setColumnCount(len(self._statsHandler.stats) + 3) # + legend, kind and roi # noqa
+ self.setColumnCount(
+ len(self._statsHandler.stats) + 3
+ ) # + legend, kind and roi # noqa
for index, stat in enumerate(self._statsHandler.stats.values()):
headerItem = qt.QTableWidgetItem(stat.name.capitalize())
@@ -407,10 +417,14 @@ class _StatsROITable(_StatsWidgetBase, TableWidget):
statsHandler = self.getStatsHandler()
if statsHandler is not None:
- stats = statsHandler.calculate(plotItem, plot,
- onlimits=self._statsOnVisibleData,
- roi=roi, data_changed=data_changed,
- roi_changed=roi_changed)
+ stats = statsHandler.calculate(
+ plotItem,
+ plot,
+ onlimits=self._statsOnVisibleData,
+ roi=roi,
+ data_changed=data_changed,
+ roi_changed=roi_changed,
+ )
else:
stats = {}
@@ -428,7 +442,7 @@ class _StatsROITable(_StatsWidgetBase, TableWidget):
value = stats.get(name)
if value is None:
_logger.error("Value not found for: %s", name)
- tableItem.setText('-')
+ tableItem.setText("-")
else:
tableItem.setText(str(value))
@@ -473,9 +487,9 @@ class _StatsROITable(_StatsWidgetBase, TableWidget):
:param item: The plot item
:return: An ordered dict of column name to QTableWidgetItem mapping
for the given plot item.
- :rtype: OrderedDict
+ :rtype: dict
"""
- result = OrderedDict()
+ result = {}
row = self._itemToRow(item)
if row is not None:
for column in range(self.columnCount()):
@@ -519,15 +533,21 @@ class _StatsROITable(_StatsWidgetBase, TableWidget):
# item connection within sigRegionChanged should only be
# stopped during the region edition
self.__region_edition_callback[item._roi] = functools.partial(
- self._updateAllStats, False, True)
- item._roi.sigRegionChanged.connect(self.__region_edition_callback[item._roi])
- item._roi.sigEditingStarted.connect(functools.partial(
- self._startFiltering, item._roi))
- item._roi.sigEditingFinished.connect(functools.partial(
- self._endFiltering, item._roi))
+ self._updateAllStats, False, True
+ )
+ item._roi.sigRegionChanged.connect(
+ self.__region_edition_callback[item._roi]
+ )
+ item._roi.sigEditingStarted.connect(
+ functools.partial(self._startFiltering, item._roi)
+ )
+ item._roi.sigEditingFinished.connect(
+ functools.partial(self._endFiltering, item._roi)
+ )
else:
- item._roi.sigChanged.connect(functools.partial(
- self._updateAllStats, False, True))
+ item._roi.sigChanged.connect(
+ functools.partial(self._updateAllStats, False, True)
+ )
self.__roiToItems[item._roi].add(item)
def _startFiltering(self, roi):
@@ -541,10 +561,12 @@ class _StatsROITable(_StatsWidgetBase, TableWidget):
if roi in self.__roiToItems:
del self.__roiToItems[roi]
if isinstance(roi, RegionOfInterest):
- roi.sigRegionEditionStarted.disconnect(functools.partial(
- self._startFiltering, roi))
- roi.sigRegionEditionFinished.disconnect(functools.partial(
- self._startFiltering, roi))
+ roi.sigRegionEditionStarted.disconnect(
+ functools.partial(self._startFiltering, roi)
+ )
+ roi.sigRegionEditionFinished.disconnect(
+ functools.partial(self._startFiltering, roi)
+ )
try:
roi.sigRegionChanged.disconnect(self._updateAllStats)
except:
@@ -575,11 +597,13 @@ class _StatsROITable(_StatsWidgetBase, TableWidget):
self.setRowHidden(row_index, not item.isVisible())
def _removeItem(self, itemKey):
- if isinstance(itemKey, (silx.gui.plot.items.marker.Marker,
- silx.gui.plot.items.shape.Shape)):
+ if isinstance(
+ itemKey,
+ (silx.gui.plot.items.marker.Marker, silx.gui.plot.items.shape.Shape),
+ ):
return
if itemKey not in self._items:
- _logger.warning('key not recognized. Won\'t remove any item')
+ _logger.warning("key not recognized. Won't remove any item")
return
item = self._items[itemKey]
row = self._itemToRow(item)
@@ -597,16 +621,20 @@ class _StatsROITable(_StatsWidgetBase, TableWidget):
:param bool is_request: True if come from a manual request
"""
- if (self.getUpdateMode() is UpdateMode.MANUAL and
- not is_request and not roi_changed):
+ if (
+ self.getUpdateMode() is UpdateMode.MANUAL
+ and not is_request
+ and not roi_changed
+ ):
return
with self._disableSorting():
for row in range(self.rowCount()):
tableItem = self.item(row, 0)
item = self._tableItemToItem(tableItem)
- self._updateStats(item, roi_changed=roi_changed,
- data_changed=is_request)
+ self._updateStats(
+ item, roi_changed=roi_changed, data_changed=is_request
+ )
def _plotCurrentChanged(self, *args):
pass
@@ -624,7 +652,10 @@ class _StatsROITable(_StatsWidgetBase, TableWidget):
"""return the plotItem fitting the requirement kind, legend.
This information is enough to be sure it is unique (in the widget)"""
for plotItem in self.__plotItemToItems:
- if legend == plotItem.getLegend() and self._plotWrapper.getKind(plotItem) == kind:
+ if (
+ legend == plotItem.getLegend()
+ and self._plotWrapper.getKind(plotItem) == kind
+ ):
return plotItem
return None
@@ -668,12 +699,12 @@ class ROIStatsWidget(qt.QMainWindow):
qt.QMainWindow.__init__(self, parent)
toolbar = qt.QToolBar(self)
- icon = icons.getQIcon('add')
+ icon = icons.getQIcon("add")
self._rois = list(rois) if rois is not None else []
- self._addAction = qt.QAction(icon, 'add item/roi', toolbar)
+ self._addAction = qt.QAction(icon, "add item/roi", toolbar)
self._addAction.triggered.connect(self._addRoiStatsItem)
- icon = icons.getQIcon('rm')
- self._removeAction = qt.QAction(icon, 'remove item/roi', toolbar)
+ icon = icons.getQIcon("rm")
+ self._removeAction = qt.QAction(icon, "remove item/roi", toolbar)
self._removeAction.triggered.connect(self._removeCurrentRow)
toolbar.addAction(self._addAction)
@@ -717,15 +748,14 @@ class ROIStatsWidget(qt.QMainWindow):
@docstring(_StatsROITable)
def getStatsHandler(self):
"""
-
- :return:
+
+ :return:
"""
return self._statsROITable.getStatsHandler()
def _addRoiStatsItem(self):
"""Ask the user what couple ROI / item he want to display"""
- dialog = _GetROIItemCoupleDialog(parent=self, plot=self._plot,
- rois=self._rois)
+ dialog = _GetROIItemCoupleDialog(parent=self, plot=self._plot, rois=self._rois)
if dialog.exec():
self.addItem(roi=dialog.getROI(), plotItem=dialog.getItem())
@@ -755,7 +785,7 @@ class ROIStatsWidget(qt.QMainWindow):
def _removeCurrentRow(self):
def is1DKind(kind):
- if kind in ('curve', 'histogram', 'scatter'):
+ if kind in ("curve", "histogram", "scatter"):
return True
else:
return False
@@ -768,12 +798,10 @@ class ROIStatsWidget(qt.QMainWindow):
roi_kind = ROI if is1DKind(item_kind) else RegionOfInterest
roi = self._statsROITable._getRoi(kind=roi_kind, name=roi_name)
if roi is None:
- _logger.warning('failed to retrieve the roi you want to remove')
+ _logger.warning("failed to retrieve the roi you want to remove")
return False
- plot_item = self._statsROITable._getPlotItem(kind=item_kind,
- legend=item_legend)
+ plot_item = self._statsROITable._getPlotItem(kind=item_kind, legend=item_legend)
if plot_item is None:
- _logger.warning('failed to retrieve the plot item you want to'
- 'remove')
+ _logger.warning("failed to retrieve the plot item you want to" "remove")
return False
return self.removeItem(plotItem=plot_item, roi=roi)
diff --git a/src/silx/gui/plot/ScatterMaskToolsWidget.py b/src/silx/gui/plot/ScatterMaskToolsWidget.py
index 5c82fcf..300f3a6 100644
--- a/src/silx/gui/plot/ScatterMaskToolsWidget.py
+++ b/src/silx/gui/plot/ScatterMaskToolsWidget.py
@@ -54,8 +54,8 @@ _logger = logging.getLogger(__name__)
class ScatterMask(BaseMask):
- """A 1D mask for scatter data.
- """
+ """A 1D mask for scatter data."""
+
def __init__(self, scatter=None):
"""
@@ -76,7 +76,7 @@ class ScatterMask(BaseMask):
return self._dataItem.getValueData(copy=False)
def save(self, filename, kind):
- if kind == 'npy':
+ if kind == "npy":
try:
numpy.save(filename, self.getMask(copy=False))
except IOError:
@@ -116,8 +116,9 @@ class ScatterMask(BaseMask):
x, y = self._getXY()
# TODO: this could be optimized if necessary
- indices_in_polygon = [idx for idx in range(len(x)) if
- polygon.is_inside(y[idx], x[idx])]
+ indices_in_polygon = [
+ idx for idx in range(len(x)) if polygon.is_inside(y[idx], x[idx])
+ ]
self.updatePoints(level, indices_in_polygon, mask)
@@ -131,10 +132,7 @@ class ScatterMask(BaseMask):
:param float width:
:param bool mask: True to mask (default), False to unmask.
"""
- vertices = [(y, x),
- (y + height, x),
- (y + height, x + width),
- (y, x + width)]
+ vertices = [(y, x), (y + height, x), (y + height, x + width), (y, x + width)]
self.updatePolygon(level, vertices, mask)
def updateDisk(self, level, cy, cx, radius, mask=True):
@@ -147,7 +145,7 @@ class ScatterMask(BaseMask):
:param bool mask: True to mask (default), False to unmask.
"""
x, y = self._getXY()
- stencil = (y - cy)**2 + (x - cx)**2 < radius**2
+ stencil = (y - cy) ** 2 + (x - cx) ** 2 < radius**2
self.updateStencil(level, stencil, mask)
def updateEllipse(self, level, crow, ccol, radius_r, radius_c, mask=True):
@@ -160,8 +158,12 @@ class ScatterMask(BaseMask):
:param float radius_c: Radius of the ellipse in the column
:param bool mask: True to mask (default), False to unmask.
"""
+
def is_inside(px, py):
- return (px - ccol)**2 / radius_c**2 + (py - crow)**2 / radius_r**2 <= 1.0
+ return (px - ccol) ** 2 / radius_c**2 + (
+ py - crow
+ ) ** 2 / radius_r**2 <= 1.0
+
x, y = self._getXY()
indices_inside = [idx for idx in range(len(x)) if is_inside(x[idx], y[idx])]
self.updatePoints(level, indices_inside, mask)
@@ -180,13 +182,15 @@ class ScatterMask(BaseMask):
"""
# theta is the angle between the horizontal and the line
theta = math.atan((y1 - y0) / (x1 - x0)) if x1 - x0 else 0
- w_over_2_sin_theta = width / 2. * math.sin(theta)
- w_over_2_cos_theta = width / 2. * math.cos(theta)
-
- vertices = [(y0 - w_over_2_cos_theta, x0 + w_over_2_sin_theta),
- (y0 + w_over_2_cos_theta, x0 - w_over_2_sin_theta),
- (y1 + w_over_2_cos_theta, x1 - w_over_2_sin_theta),
- (y1 - w_over_2_cos_theta, x1 + w_over_2_sin_theta)]
+ w_over_2_sin_theta = width / 2.0 * math.sin(theta)
+ w_over_2_cos_theta = width / 2.0 * math.cos(theta)
+
+ vertices = [
+ (y0 - w_over_2_cos_theta, x0 + w_over_2_sin_theta),
+ (y0 + w_over_2_cos_theta, x0 - w_over_2_sin_theta),
+ (y1 + w_over_2_cos_theta, x1 - w_over_2_sin_theta),
+ (y1 - w_over_2_cos_theta, x1 + w_over_2_sin_theta),
+ ]
self.updatePolygon(level, vertices, mask)
@@ -196,8 +200,7 @@ class ScatterMaskToolsWidget(BaseMaskToolsWidget):
:class:`PlotWidget`."""
def __init__(self, parent=None, plot=None):
- super(ScatterMaskToolsWidget, self).__init__(parent, plot,
- mask=ScatterMask())
+ super(ScatterMaskToolsWidget, self).__init__(parent, plot, mask=ScatterMask())
self._z = 2 # Mask layer in plot
self._data_scatter = None
"""plot Scatter item for data"""
@@ -223,7 +226,7 @@ class ScatterMaskToolsWidget(BaseMaskToolsWidget):
"""
if self._data_scatter is None:
# this can happen if the mask tools widget has never been shown
- self._data_scatter = self.plot._getActiveItem(kind="scatter")
+ self._data_scatter = self.plot.getActiveScatter()
if self._data_scatter is None:
return None
self._adjustColorAndBrushSize(self._data_scatter)
@@ -234,8 +237,10 @@ class ScatterMaskToolsWidget(BaseMaskToolsWidget):
mask = numpy.array(mask, copy=False, dtype=numpy.uint8)
- if self._data_scatter.getXData(copy=False).shape == (0,) \
- or mask.shape == self._data_scatter.getXData(copy=False).shape:
+ if (
+ self._data_scatter.getXData(copy=False).shape == (0,)
+ or mask.shape == self._data_scatter.getXData(copy=False).shape
+ ):
self._mask.setMask(mask, copy=copy)
self._mask.commit()
return mask.shape
@@ -248,25 +253,28 @@ class ScatterMaskToolsWidget(BaseMaskToolsWidget):
"""Update mask image in plot"""
mask = self.getSelectionMask(copy=False)
if mask is not None:
- self.plot.addScatter(self._data_scatter.getXData(),
- self._data_scatter.getYData(),
- mask,
- legend=self._maskName,
- colormap=self._colormap,
- z=self._z)
- self._mask_scatter = self.plot._getItem(kind="scatter",
- legend=self._maskName)
- self._mask_scatter.setSymbolSize(
- self._data_scatter.getSymbolSize() + 2.0)
+ self.plot.addScatter(
+ self._data_scatter.getXData(),
+ self._data_scatter.getYData(),
+ mask,
+ legend=self._maskName,
+ colormap=self._colormap,
+ z=self._z,
+ )
+ self._mask_scatter = self.plot._getItem(
+ kind="scatter", legend=self._maskName
+ )
+ self._mask_scatter.setSymbolSize(self._data_scatter.getSymbolSize() + 2.0)
self._mask_scatter.sigItemChanged.connect(self.__maskScatterChanged)
- elif self.plot._getItem(kind="scatter",
- legend=self._maskName) is not None:
- self.plot.remove(self._maskName, kind='scatter')
+ elif self.plot._getItem(kind="scatter", legend=self._maskName) is not None:
+ self.plot.remove(self._maskName, kind="scatter")
def __maskScatterChanged(self, event):
"""Handles update of mask scatter"""
- if (event is ItemChangedType.VISUALIZATION_MODE and
- self._mask_scatter is not None):
+ if (
+ event is ItemChangedType.VISUALIZATION_MODE
+ and self._mask_scatter is not None
+ ):
self._mask_scatter.setVisualization(Scatter.Visualization.POINTS)
# track widget visibility and plot active image changes
@@ -274,10 +282,11 @@ class ScatterMaskToolsWidget(BaseMaskToolsWidget):
def showEvent(self, event):
try:
self.plot.sigActiveScatterChanged.disconnect(
- self._activeScatterChangedAfterCare)
+ self._activeScatterChangedAfterCare
+ )
except (RuntimeError, TypeError):
pass
- self._activeScatterChanged(None, None) # Init mask + enable/disable widget
+ self._activeScatterChanged(None, None) # Init mask + enable/disable widget
self.plot.sigActiveScatterChanged.connect(self._activeScatterChanged)
def hideEvent(self, event):
@@ -294,14 +303,16 @@ class ScatterMaskToolsWidget(BaseMaskToolsWidget):
if self.getSelectionMask(copy=False) is not None:
self.plot.sigActiveScatterChanged.connect(
- self._activeScatterChangedAfterCare)
+ self._activeScatterChangedAfterCare
+ )
def _adjustColorAndBrushSize(self, activeScatter):
colormap = activeScatter.getColormap()
- self._defaultOverlayColor = rgba(cursorColorForColormap(colormap['name']))
- self._setMaskColors(self.levelSpinBox.value(),
- self.transparencySlider.value() /
- self.transparencySlider.maximum())
+ self._defaultOverlayColor = rgba(cursorColorForColormap(colormap["name"]))
+ self._setMaskColors(
+ self.levelSpinBox.value(),
+ self.transparencySlider.value() / self.transparencySlider.maximum(),
+ )
self._z = activeScatter.getZValue() + 1
self._data_scatter = activeScatter
@@ -323,25 +334,30 @@ class ScatterMaskToolsWidget(BaseMaskToolsWidget):
removed, otherwise it is adjusted to z.
"""
# check that content changed was the active scatter
- activeScatter = self.plot._getActiveItem(kind="scatter")
+ activeScatter = self.plot.getActiveScatter()
if activeScatter is None or activeScatter.getName() == self._maskName:
# No active scatter or active scatter is the mask...
self.plot.sigActiveScatterChanged.disconnect(
- self._activeScatterChangedAfterCare)
+ self._activeScatterChangedAfterCare
+ )
self._data_extent = None
self._data_scatter = None
else:
self._adjustColorAndBrushSize(activeScatter)
- if self._data_scatter.getXData(copy=False).shape != self._mask.getMask(copy=False).shape:
+ if (
+ self._data_scatter.getXData(copy=False).shape
+ != self._mask.getMask(copy=False).shape
+ ):
# scatter has not the same size, remove mask and stop listening
if self.plot._getItem(kind="scatter", legend=self._maskName):
- self.plot.remove(self._maskName, kind='scatter')
+ self.plot.remove(self._maskName, kind="scatter")
self.plot.sigActiveScatterChanged.disconnect(
- self._activeScatterChangedAfterCare)
+ self._activeScatterChangedAfterCare
+ )
self._data_extent = None
self._data_scatter = None
@@ -352,7 +368,7 @@ class ScatterMaskToolsWidget(BaseMaskToolsWidget):
def _activeScatterChanged(self, previous, next):
"""Update widget and mask according to active scatter changes"""
- activeScatter = self.plot._getActiveItem(kind="scatter")
+ activeScatter = self.plot.getActiveScatter()
if activeScatter is None or activeScatter.getName() == self._maskName:
# No active scatter or active scatter is the mask...
@@ -368,7 +384,10 @@ class ScatterMaskToolsWidget(BaseMaskToolsWidget):
self._adjustColorAndBrushSize(activeScatter)
self._mask.setDataItem(self._data_scatter)
- if self._data_scatter.getXData(copy=False).shape != self._mask.getMask(copy=False).shape:
+ if (
+ self._data_scatter.getXData(copy=False).shape
+ != self._mask.getMask(copy=False).shape
+ ):
self._mask.reset(self._data_scatter.getXData(copy=False).shape)
self._mask.commit()
else:
@@ -395,16 +414,14 @@ class ScatterMaskToolsWidget(BaseMaskToolsWidget):
except IOError:
_logger.error("Can't load filename '%s'", filename)
_logger.debug("Backtrace", exc_info=True)
- raise RuntimeError('File "%s" is not a numpy file.',
- filename)
+ raise RuntimeError('File "%s" is not a numpy file.', filename)
elif extension in ["txt", "csv"]:
try:
mask = numpy.loadtxt(filename)
except IOError:
_logger.error("Can't load filename '%s'", filename)
_logger.debug("Backtrace", exc_info=True)
- raise RuntimeError('File "%s" is not a numpy txt file.',
- filename)
+ raise RuntimeError('File "%s" is not a numpy txt file.', filename)
else:
msg = "Extension '%s' is not supported."
raise RuntimeError(msg % extension)
@@ -417,8 +434,8 @@ class ScatterMaskToolsWidget(BaseMaskToolsWidget):
dialog.setWindowTitle("Load Mask")
dialog.setModal(1)
filters = [
- 'NumPy binary file (*.npy)',
- 'CSV text file (*.csv)',
+ "NumPy binary file (*.npy)",
+ "CSV text file (*.csv)",
]
dialog.setNameFilters(filters)
dialog.setFileMode(qt.QFileDialog.ExistingFile)
@@ -454,8 +471,8 @@ class ScatterMaskToolsWidget(BaseMaskToolsWidget):
dialog.setWindowTitle("Save Mask")
dialog.setModal(1)
filters = [
- 'NumPy binary file (*.npy)',
- 'CSV text file (*.csv)',
+ "NumPy binary file (*.npy)",
+ "CSV text file (*.csv)",
]
dialog.setNameFilters(filters)
dialog.setFileMode(qt.QFileDialog.AnyFile)
@@ -485,8 +502,7 @@ class ScatterMaskToolsWidget(BaseMaskToolsWidget):
strerror = e.strerror
else:
strerror = sys.exc_info()[1]
- msg.setText("Cannot save.\n"
- "Input Output Error: %s" % strerror)
+ msg.setText("Cannot save.\n" "Input Output Error: %s" % strerror)
msg.exec()
return
@@ -509,8 +525,7 @@ class ScatterMaskToolsWidget(BaseMaskToolsWidget):
def resetSelectionMask(self):
"""Reset the mask"""
- self._mask.reset(
- shape=self._data_scatter.getXData(copy=False).shape)
+ self._mask.reset(shape=self._data_scatter.getXData(copy=False).shape)
self._mask.commit()
def _getPencilWidth(self):
@@ -525,8 +540,10 @@ class ScatterMaskToolsWidget(BaseMaskToolsWidget):
def _plotDrawEvent(self, event):
"""Handle draw events from the plot"""
- if (self._drawingMode is None or
- event['event'] not in ('drawingProgress', 'drawingFinished')):
+ if self._drawingMode is None or event["event"] not in (
+ "drawingProgress",
+ "drawingFinished",
+ ):
return
if not len(self._data_scatter.getXData(copy=False)):
@@ -534,40 +551,42 @@ class ScatterMaskToolsWidget(BaseMaskToolsWidget):
level = self.levelSpinBox.value()
- if self._drawingMode == 'rectangle':
- if event['event'] == 'drawingFinished':
+ if self._drawingMode == "rectangle":
+ if event["event"] == "drawingFinished":
doMask = self._isMasking()
self._mask.updateRectangle(
level,
- y=event['y'],
- x=event['x'],
- height=abs(event['height']),
- width=abs(event['width']),
- mask=doMask)
+ y=event["y"],
+ x=event["x"],
+ height=abs(event["height"]),
+ width=abs(event["width"]),
+ mask=doMask,
+ )
self._mask.commit()
- elif self._drawingMode == 'ellipse':
- if event['event'] == 'drawingFinished':
+ elif self._drawingMode == "ellipse":
+ if event["event"] == "drawingFinished":
doMask = self._isMasking()
- center = event['points'][0]
- size = event['points'][1]
- self._mask.updateEllipse(level, center[1], center[0],
- size[1], size[0], doMask)
+ center = event["points"][0]
+ size = event["points"][1]
+ self._mask.updateEllipse(
+ level, center[1], center[0], size[1], size[0], doMask
+ )
self._mask.commit()
- elif self._drawingMode == 'polygon':
- if event['event'] == 'drawingFinished':
+ elif self._drawingMode == "polygon":
+ if event["event"] == "drawingFinished":
doMask = self._isMasking()
- vertices = event['points']
+ vertices = event["points"]
vertices = vertices[:, (1, 0)] # (y, x)
self._mask.updatePolygon(level, vertices, doMask)
self._mask.commit()
- elif self._drawingMode == 'pencil':
+ elif self._drawingMode == "pencil":
doMask = self._isMasking()
# convert from plot to array coords
- x, y = event['points'][-1]
+ x, y = event["points"][-1]
brushSize = self._getPencilWidth()
@@ -576,15 +595,18 @@ class ScatterMaskToolsWidget(BaseMaskToolsWidget):
# Draw the line
self._mask.updateLine(
level,
- self._lastPencilPos[0], self._lastPencilPos[1],
- y, x,
+ self._lastPencilPos[0],
+ self._lastPencilPos[1],
+ y,
+ x,
brushSize,
- doMask)
+ doMask,
+ )
# Draw the very first, or last point
- self._mask.updateDisk(level, y, x, brushSize / 2., doMask)
+ self._mask.updateDisk(level, y, x, brushSize / 2.0, doMask)
- if event['event'] == 'drawingFinished':
+ if event["event"] == "drawingFinished":
self._mask.commit()
self._lastPencilPos = None
else:
@@ -597,11 +619,11 @@ class ScatterMaskToolsWidget(BaseMaskToolsWidget):
if self._data_scatter is not None:
# Update thresholds according to colormap
colormap = self._data_scatter.getColormap()
- if colormap['autoscale']:
+ if colormap["autoscale"]:
min_ = numpy.nanmin(self._data_scatter.getValueData(copy=False))
max_ = numpy.nanmax(self._data_scatter.getValueData(copy=False))
else:
- min_, max_ = colormap['vmin'], colormap['vmax']
+ min_, max_ = colormap["vmin"], colormap["vmax"]
self.minLineEdit.setText(str(min_))
self.maxLineEdit.setText(str(max_))
@@ -615,6 +637,7 @@ class ScatterMaskToolsDockWidget(BaseMaskToolsDockWidget):
:param plot: The PlotWidget this widget is operating on
:paran str name: The title of this widget
"""
- def __init__(self, parent=None, plot=None, name='Mask'):
+
+ def __init__(self, parent=None, plot=None, name="Mask"):
widget = ScatterMaskToolsWidget(plot=plot)
super(ScatterMaskToolsDockWidget, self).__init__(parent, name, widget)
diff --git a/src/silx/gui/plot/ScatterView.py b/src/silx/gui/plot/ScatterView.py
index abacbef..06475e3 100644
--- a/src/silx/gui/plot/ScatterView.py
+++ b/src/silx/gui/plot/ScatterView.py
@@ -63,7 +63,7 @@ class ScatterView(qt.QMainWindow):
:type backend: Union[str,~silx.gui.plot.backends.BackendBase.BackendBase]
"""
- _SCATTER_LEGEND = ' '
+ _SCATTER_LEGEND = " "
"""Legend used for the scatter item"""
def __init__(self, parent=None, backend=None):
@@ -72,7 +72,7 @@ class ScatterView(qt.QMainWindow):
# behave as a widget
self.setWindowFlags(qt.Qt.Widget)
else:
- self.setWindowTitle('ScatterView')
+ self.setWindowTitle("ScatterView")
# Create plot widget
plot = PlotWidget(parent=self, backend=backend)
@@ -93,10 +93,13 @@ class ScatterView(qt.QMainWindow):
self.__pickingCache = None
self._positionInfo = tools.PositionInfo(
plot=plot,
- converters=(('X', WeakMethodProxy(self._getPickedX)),
- ('Y', WeakMethodProxy(self._getPickedY)),
- ('Data', WeakMethodProxy(self._getPickedValue)),
- ('Index', WeakMethodProxy(self._getPickedIndex))))
+ converters=(
+ ("X", WeakMethodProxy(self._getPickedX)),
+ ("Y", WeakMethodProxy(self._getPickedY)),
+ ("Data", WeakMethodProxy(self._getPickedValue)),
+ ("Index", WeakMethodProxy(self._getPickedIndex)),
+ ),
+ )
# Combine plot, position info and colorbar into central widget
gridLayout = qt.QGridLayout()
@@ -114,23 +117,25 @@ class ScatterView(qt.QMainWindow):
# Create mask tool dock widget
self._maskToolsWidget = ScatterMaskToolsWidget(parent=self, plot=plot)
self._maskDock = BoxLayoutDockWidget()
- self._maskDock.setWindowTitle('Scatter Mask')
+ self._maskDock.setWindowTitle("Scatter Mask")
self._maskDock.setWidget(self._maskToolsWidget)
self._maskDock.setVisible(False)
self.addDockWidget(qt.Qt.BottomDockWidgetArea, self._maskDock)
self._maskAction = self._maskDock.toggleViewAction()
- self._maskAction.setIcon(icons.getQIcon('image-mask'))
+ self._maskAction.setIcon(icons.getQIcon("image-mask"))
self._maskAction.setToolTip("Display/hide mask tools")
- self._intensityHistoAction = actions_histogram.PixelIntensitiesHistoAction(plot=plot, parent=self)
+ self._intensityHistoAction = actions_histogram.PixelIntensitiesHistoAction(
+ plot=plot, parent=self
+ )
# Create toolbars
self._interactiveModeToolBar = tools.InteractiveModeToolBar(
- parent=self, plot=plot)
+ parent=self, plot=plot
+ )
- self._scatterToolBar = tools.ScatterToolBar(
- parent=self, plot=plot)
+ self._scatterToolBar = tools.ScatterToolBar(parent=self, plot=plot)
self._scatterToolBar.addAction(self._maskAction)
self._scatterToolBar.addAction(self._intensityHistoAction)
@@ -139,15 +144,16 @@ class ScatterView(qt.QMainWindow):
self._outputToolBar = tools.OutputToolBar(parent=self, plot=plot)
# Activate shortcuts in PlotWindow widget:
- for toolbar in (self._interactiveModeToolBar,
- self._scatterToolBar,
- self._profileToolBar,
- self._outputToolBar):
+ for toolbar in (
+ self._interactiveModeToolBar,
+ self._scatterToolBar,
+ self._profileToolBar,
+ self._outputToolBar,
+ ):
self.addToolBar(toolbar)
for action in toolbar.actions():
self.addAction(action)
-
def __createEmptyScatter(self):
"""Create an empty scatter item that is used to display the data
@@ -155,8 +161,7 @@ class ScatterView(qt.QMainWindow):
"""
plot = self.getPlotWidget()
plot.addScatter(x=(), y=(), value=(), legend=self._SCATTER_LEGEND)
- scatter = plot._getItem(
- kind='scatter', legend=self._SCATTER_LEGEND)
+ scatter = plot._getItem(kind="scatter", legend=self._SCATTER_LEGEND)
# Profile is not selectable,
# so it does not interfere with profile interaction
scatter._setSelectable(False)
@@ -180,16 +185,24 @@ class ScatterView(qt.QMainWindow):
if pixelPos is not None:
# Start from top-most item
result = plot._pickTopMost(
- pixelPos[0], pixelPos[1],
- lambda item: isinstance(item, items.Scatter))
+ pixelPos[0],
+ pixelPos[1],
+ lambda item: isinstance(item, items.Scatter),
+ )
if result is not None:
item = result.getItem()
- if item.getVisualization() is items.Scatter.Visualization.BINNED_STATISTIC:
+ if (
+ item.getVisualization()
+ is items.Scatter.Visualization.BINNED_STATISTIC
+ ):
# Get highest index of closest points
selected = result.getIndices(copy=False)[::-1]
- dataIndex = selected[numpy.argmin(
- (item.getXData(copy=False)[selected] - x)**2 +
- (item.getYData(copy=False)[selected] - y)**2)]
+ dataIndex = selected[
+ numpy.argmin(
+ (item.getXData(copy=False)[selected] - x) ** 2
+ + (item.getYData(copy=False)[selected] - y) ** 2
+ )
+ ]
else:
# Get last index
# with matplotlib it should be the top-most point
@@ -198,7 +211,8 @@ class ScatterView(qt.QMainWindow):
dataIndex,
item.getXData(copy=False)[dataIndex],
item.getYData(copy=False)[dataIndex],
- item.getValueData(copy=False)[dataIndex])
+ item.getValueData(copy=False)[dataIndex],
+ )
return self.__pickingCache
@@ -210,7 +224,7 @@ class ScatterView(qt.QMainWindow):
:return: The data index at that point or '-'
"""
picking = self._pickScatterData(x, y)
- return '-' if picking is None else picking[0]
+ return "-" if picking is None else picking[0]
def _getPickedX(self, x, y):
"""Returns X position snapped to scatter plot when close enough
@@ -240,7 +254,7 @@ class ScatterView(qt.QMainWindow):
:return: The data value at that point or '-'
"""
picking = self._pickScatterData(x, y)
- return '-' if picking is None else picking[3]
+ return "-" if picking is None else picking[3]
def _mouseInPlotArea(self, x, y):
"""Clip mouse coordinates to plot area coordinates
@@ -344,7 +358,7 @@ class ScatterView(qt.QMainWindow):
:param yerror: Values with the uncertainties on the y values
:type yerror: A float, or a numpy.ndarray of float32. See xerror.
:param alpha: Values with the transparency (between 0 and 1)
- :type alpha: A float, or a numpy.ndarray of float32
+ :type alpha: A float, or a numpy.ndarray of float32
:param bool copy: True make a copy of the data (default),
False to use provided arrays.
"""
@@ -353,7 +367,8 @@ class ScatterView(qt.QMainWindow):
value = () if value is None else value
self.getScatterItem().setData(
- x=x, y=y, value=value, xerror=xerror, yerror=yerror, alpha=alpha, copy=copy)
+ x=x, y=y, value=value, xerror=xerror, yerror=yerror, alpha=alpha, copy=copy
+ )
@docstring(items.Scatter)
def getData(self, *args, **kwargs):
@@ -367,7 +382,7 @@ class ScatterView(qt.QMainWindow):
:rtype: ~silx.gui.plot.items.Scatter
"""
plot = self.getPlotWidget()
- scatter = plot._getItem(kind='scatter', legend=self._SCATTER_LEGEND)
+ scatter = plot._getItem(kind="scatter", legend=self._SCATTER_LEGEND)
if scatter is None: # Resilient to call to PlotWidget API (e.g., clear)
scatter = self.__createEmptyScatter()
return scatter
diff --git a/src/silx/gui/plot/StackView.py b/src/silx/gui/plot/StackView.py
index 5101f87..36560fd 100644
--- a/src/silx/gui/plot/StackView.py
+++ b/src/silx/gui/plot/StackView.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2016-2022 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -56,7 +56,7 @@ Example::
sv = StackViewMainWindow()
- sv.setColormap("jet", autoscale=True)
+ sv.setColormap("viridis", vmin=-4, vmax=4)
sv.setStack(mystack)
sv.setLabels(["1st dim (0-99)", "2nd dim (0-199)",
"3rd dim (0-299)"])
@@ -84,15 +84,11 @@ from .tools import LimitsToolBar
from .Profile import Profile3DToolBar
from ..widgets.FrameBrowser import HorizontalSliderWithBrowser
-from silx.gui.plot.actions import control as actions_control
from silx.gui.plot.actions import io as silx_io
from silx.io.nxdata import save_NXdata
from silx.utils.array_like import DatasetView, ListOfImages
from silx.math import calibration
-from silx.utils.deprecation import deprecated_warning
-from silx.utils.deprecation import deprecated
-import h5py
from silx.io.utils import is_dataset
_logger = logging.getLogger(__name__)
@@ -130,6 +126,7 @@ class StackView(qt.QMainWindow):
See :class:`silx.gui.plot.PlotTools.PositionInfo`.
:param bool mask: Toggle visibilty of mask action.
"""
+
# Qt signals
valueChanged = qt.Signal(object, object, object)
"""Signals that the data value under the cursor has changed.
@@ -163,20 +160,34 @@ class StackView(qt.QMainWindow):
This signal provides the current frame number.
"""
- IMAGE_STACK_FILTER_NXDATA = 'Stack of images as NXdata (%s)' % silx_io._NEXUS_HDF5_EXT_STR
-
+ IMAGE_STACK_FILTER_NXDATA = (
+ "Stack of images as NXdata (%s)" % silx_io._NEXUS_HDF5_EXT_STR
+ )
- def __init__(self, parent=None, resetzoom=True, backend=None,
- autoScale=False, logScale=False, grid=False,
- colormap=True, aspectRatio=True, yinverted=True,
- copy=True, save=True, print_=True, control=False,
- position=None, mask=True):
+ def __init__(
+ self,
+ parent=None,
+ resetzoom=True,
+ backend=None,
+ autoScale=False,
+ logScale=False,
+ grid=False,
+ colormap=True,
+ aspectRatio=True,
+ yinverted=True,
+ copy=True,
+ save=True,
+ print_=True,
+ control=False,
+ position=None,
+ mask=True,
+ ):
qt.QMainWindow.__init__(self, parent)
if parent is not None:
# behave as a widget
self.setWindowFlags(qt.Qt.Widget)
else:
- self.setWindowTitle('StackView')
+ self.setWindowTitle("StackView")
self._stack = None
"""Loaded stack, as a 3D array, a 3D dataset or a list of 2D arrays."""
@@ -188,14 +199,10 @@ class StackView(qt.QMainWindow):
self._stackItem = ImageStack()
"""Hold the item displaying the stack"""
- imageLegend = '__StackView__image' + str(id(self))
+ imageLegend = "__StackView__image" + str(id(self))
self._stackItem.setName(imageLegend)
- self.__autoscaleCmap = False
- """Flag to disable/enable colormap auto-scaling
- based on the min/max values of the entire 3D volume"""
- self.__dimensionsLabels = ["Dimension 0", "Dimension 1",
- "Dimension 2"]
+ self.__dimensionsLabels = ["Dimension 0", "Dimension 1", "Dimension 2"]
"""These labels are displayed on the X and Y axes.
:meth:`setLabels` updates this attribute."""
@@ -206,39 +213,56 @@ class StackView(qt.QMainWindow):
"""Function returning the plot title based on the frame index.
It can be set to a custom function using :meth:`setTitleCallback`"""
- self.calibrations3D = (calibration.NoCalibration(),
- calibration.NoCalibration(),
- calibration.NoCalibration())
+ self.calibrations3D = (
+ calibration.NoCalibration(),
+ calibration.NoCalibration(),
+ calibration.NoCalibration(),
+ )
central_widget = qt.QWidget(self)
- self._plot = PlotWindow(parent=central_widget, backend=backend,
- resetzoom=resetzoom, autoScale=autoScale,
- logScale=logScale, grid=grid,
- curveStyle=False, colormap=colormap,
- aspectRatio=aspectRatio, yInverted=yinverted,
- copy=copy, save=save, print_=print_,
- control=control, position=position,
- roi=False, mask=mask)
+ self._plot = PlotWindow(
+ parent=central_widget,
+ backend=backend,
+ resetzoom=resetzoom,
+ autoScale=autoScale,
+ logScale=logScale,
+ grid=grid,
+ curveStyle=False,
+ colormap=colormap,
+ aspectRatio=aspectRatio,
+ yInverted=yinverted,
+ copy=copy,
+ save=save,
+ print_=print_,
+ control=control,
+ position=position,
+ roi=False,
+ mask=mask,
+ )
self._plot.addItem(self._stackItem)
self._plot.getIntensityHistogramAction().setVisible(True)
self.sigInteractiveModeChanged = self._plot.sigInteractiveModeChanged
self.sigActiveImageChanged = self._plot.sigActiveImageChanged
self.sigPlotSignal = self._plot.sigPlotSignal
- if silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION == 'downward':
+ if silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION == "downward":
self._plot.getYAxis().setInverted(True)
self._plot.getColorBarAction().setVisible(True)
self._plot.getColorBarWidget().setVisible(True)
- self._profileToolBar = Profile3DToolBar(parent=self._plot,
- stackview=self)
+ self._profileToolBar = Profile3DToolBar(parent=self._plot, stackview=self)
self._plot.addToolBar(self._profileToolBar)
- self._plot.getXAxis().setLabel('Columns')
- self._plot.getYAxis().setLabel('Rows')
+ self._plot.getXAxis().setLabel("Columns")
+ self._plot.getYAxis().setLabel("Rows")
self._plot.sigPlotSignal.connect(self._plotCallback)
- self._plot.getSaveAction().setFileFilter('image', self.IMAGE_STACK_FILTER_NXDATA, func=self._saveImageStack, appendToFile=True)
+ self._plot.getSaveAction().setFileFilter(
+ "image",
+ self.IMAGE_STACK_FILTER_NXDATA,
+ func=self._saveImageStack,
+ appendToFile=True,
+ )
self.__planeSelection = PlanesWidget(self._plot)
self.__planeSelection.sigPlaneSelectionChanged.connect(self.setPerspective)
@@ -262,7 +286,8 @@ class StackView(qt.QMainWindow):
# clear profile lines when the perspective changes (plane browsed changed)
self.__planeSelection.sigPlaneSelectionChanged.connect(
- self._profileToolBar.clearProfile)
+ self._profileToolBar.clearProfile
+ )
def _saveImageStack(self, plot, filename, nameFilter):
"""Save all images from the stack into a volume.
@@ -274,21 +299,25 @@ class StackView(qt.QMainWindow):
:raises: ValueError if nameFilter is invalid
"""
if not nameFilter == self.IMAGE_STACK_FILTER_NXDATA:
- raise ValueError('Wrong callback')
- entryPath = silx_io.SaveAction._selectWriteableOutputGroup(filename, parent=self)
+ raise ValueError("Wrong callback")
+ entryPath = silx_io.SaveAction._selectWriteableOutputGroup(
+ filename, parent=self
+ )
if entryPath is None:
return False
- return save_NXdata(filename,
- nxentry_name=entryPath,
- signal=self.getStack(copy=False, returnNumpyArray=True)[0],
- signal_name="image_stack")
+ return save_NXdata(
+ filename,
+ nxentry_name=entryPath,
+ signal=self.getStack(copy=False, returnNumpyArray=True)[0],
+ signal_name="image_stack",
+ )
def _plotCallback(self, eventDict):
"""Callback for plot events.
Emit :attr:`valueChanged` signal, with (x, y, value) tuple of the
cursor location in the plot."""
- if eventDict['event'] == 'mouseMoved':
+ if eventDict["event"] == "mouseMoved":
activeImage = self.getActiveImage()
if activeImage is not None:
data = activeImage.getData()
@@ -297,15 +326,13 @@ class StackView(qt.QMainWindow):
# Get corresponding coordinate in image
origin = activeImage.getOrigin()
scale = activeImage.getScale()
- x = int((eventDict['x'] - origin[0]) / scale[0])
- y = int((eventDict['y'] - origin[1]) / scale[1])
+ x = int((eventDict["x"] - origin[0]) / scale[0])
+ y = int((eventDict["y"] - origin[1]) / scale[1])
if 0 <= x < width and 0 <= y < height:
- self.valueChanged.emit(float(x), float(y),
- data[y][x])
+ self.valueChanged.emit(float(x), float(y), data[y][x])
else:
- self.valueChanged.emit(float(x), float(y),
- None)
+ self.valueChanged.emit(float(x), float(y), None)
def getPerspective(self):
"""Returns the index of the dimension the stack is browsed with
@@ -329,8 +356,7 @@ class StackView(qt.QMainWindow):
return
else:
if perspective > 2 or perspective < 0:
- raise ValueError(
- "Perspective must be 0, 1 or 2, not %s" % perspective)
+ raise ValueError("Perspective must be 0, 1 or 2, not %s" % perspective)
self._perspective = int(perspective)
self.__createTransposedView()
@@ -338,20 +364,29 @@ class StackView(qt.QMainWindow):
self._plot.resetZoom()
self.__updatePlotLabels()
self._updateTitle()
- self._browser_label.setText("Image index (Dim%d):" %
- (self._first_stack_dimension + perspective))
+ self._browser_label.setText(
+ "Image index (Dim%d):" % (self._first_stack_dimension + perspective)
+ )
self.sigPlaneSelectionChanged.emit(perspective)
- self.sigStackChanged.emit(self._stack.size if
- self._stack is not None else 0)
- self.__planeSelection.sigPlaneSelectionChanged.disconnect(self.setPerspective)
+ self.sigStackChanged.emit(
+ self._stack.size if self._stack is not None else 0
+ )
+ self.__planeSelection.sigPlaneSelectionChanged.disconnect(
+ self.setPerspective
+ )
self.__planeSelection.setPerspective(self._perspective)
self.__planeSelection.sigPlaneSelectionChanged.connect(self.setPerspective)
def __updatePlotLabels(self):
"""Update plot axes labels depending on perspective"""
- y, x = (1, 2) if self._perspective == 0 else \
- (0, 2) if self._perspective == 1 else (0, 1)
+ y, x = (
+ (1, 2)
+ if self._perspective == 0
+ else (0, 2)
+ if self._perspective == 1
+ else (0, 1)
+ )
self.setGraphXLabel(self.__dimensionsLabels[x])
self.setGraphYLabel(self.__dimensionsLabels[y])
@@ -409,9 +444,11 @@ class StackView(qt.QMainWindow):
See setStack for parameter documentation
"""
if calibrations is None:
- self.calibrations3D = (calibration.NoCalibration(),
- calibration.NoCalibration(),
- calibration.NoCalibration())
+ self.calibrations3D = (
+ calibration.NoCalibration(),
+ calibration.NoCalibration(),
+ calibration.NoCalibration(),
+ )
else:
self.calibrations3D = []
for i, calib in enumerate(calibrations):
@@ -420,17 +457,20 @@ class StackView(qt.QMainWindow):
elif calib is None:
calib = calibration.NoCalibration()
elif not isinstance(calib, calibration.AbstractCalibration):
- raise TypeError("calibration must be a 2-tuple, None or" +
- " an instance of an AbstractCalibration " +
- "subclass")
+ raise TypeError(
+ "calibration must be a 2-tuple, None or"
+ + " an instance of an AbstractCalibration "
+ + "subclass"
+ )
elif not calib.is_affine():
_logger.warning(
- "Calibration for dimension %d is not linear, "
- "it will be ignored for scaling the graph axes.",
- i)
+ "Calibration for dimension %d is not linear, "
+ "it will be ignored for scaling the graph axes.",
+ i,
+ )
self.calibrations3D.append(calib)
- def getCalibrations(self, order='array'):
+ def getCalibrations(self, order="array"):
"""Returns currently used calibrations for each axis
Returned calibrations might differ from the ones that were set as
@@ -442,7 +482,7 @@ class StackView(qt.QMainWindow):
:return: Calibrations ordered depending on order
:rtype: List[~silx.math.calibration.AbstractCalibration]
"""
- assert order in ('array', 'axes')
+ assert order in ("array", "axes")
calibs = []
# filter out non-linear calibration for graph axes
@@ -451,11 +491,13 @@ class StackView(qt.QMainWindow):
calib = calibration.NoCalibration()
calibs.append(calib)
- if order == 'axes': # Move 'z' axis to the end
+ if order == "axes": # Move 'z' axis to the end
xy_dims = [d for d in (0, 1, 2) if d != self._perspective]
- calibs = [calibs[max(xy_dims)],
- calibs[min(xy_dims)],
- calibs[self._perspective]]
+ calibs = [
+ calibs[max(xy_dims)],
+ calibs[min(xy_dims)],
+ calibs[self._perspective],
+ ]
return tuple(calibs)
@@ -463,14 +505,14 @@ class StackView(qt.QMainWindow):
"""
:return: 2-tuple (XScale, YScale) for current image view
"""
- xcalib, ycalib, _zcalib = self.getCalibrations(order='axes')
+ xcalib, ycalib, _zcalib = self.getCalibrations(order="axes")
return xcalib.get_slope(), ycalib.get_slope()
def _getImageOrigin(self):
"""
:return: 2-tuple (XOrigin, YOrigin) for current image view
"""
- xcalib, ycalib, _zcalib = self.getCalibrations(order='axes')
+ xcalib, ycalib, _zcalib = self.getCalibrations(order="axes")
return xcalib(0), ycalib(0)
def _getImageZ(self, index):
@@ -478,7 +520,7 @@ class StackView(qt.QMainWindow):
:param idx: 0-based image index in the stack
:return: calibrated Z value corresponding to the image idx
"""
- _xcalib, _ycalib, zcalib = self.getCalibrations(order='axes')
+ _xcalib, _ycalib, zcalib = self.getCalibrations(order="axes")
return zcalib(index)
def _updateTitle(self):
@@ -525,8 +567,8 @@ class StackView(qt.QMainWindow):
assert len(img.shape) == 2
except AssertionError:
raise ValueError(
- "Stack must be a 3D array/dataset or a list of " +
- "2D arrays.")
+ "Stack must be a 3D array/dataset or a list of " + "2D arrays."
+ )
stack = ListOfImages(stack)
assert len(stack.shape) == 3, "data must be 3D"
@@ -539,9 +581,6 @@ class StackView(qt.QMainWindow):
perspective_changed = True
self.setPerspective(perspective)
- if self.__autoscaleCmap:
- self.scaleColormapRangeToStack()
-
# init plot
self._stackItem.setStackData(self.__transposed_view, 0, copy=False)
self._stackItem.setColormap(self.getColormap())
@@ -554,7 +593,7 @@ class StackView(qt.QMainWindow):
if exists is None:
self._plot.addItem(self._stackItem)
- self._plot.setActiveImage(self._stackItem.getName())
+ self._plot.setActiveImage(self._stackItem)
self.__updatePlotLabels()
self._updateTitle()
@@ -564,7 +603,7 @@ class StackView(qt.QMainWindow):
# enable and init browser
self._browser.setEnabled(True)
- if not perspective_changed: # avoid double signal (see self.setPerspective)
+ if not perspective_changed: # avoid double signal (see self.setPerspective)
self.sigStackChanged.emit(stack.size)
def getStack(self, copy=True, returnNumpyArray=False):
@@ -590,15 +629,15 @@ class StackView(qt.QMainWindow):
colormap = image.getColormap()
params = {
- 'info': image.getInfo(),
- 'origin': image.getOrigin(),
- 'scale': image.getScale(),
- 'z': image.getZValue(),
- 'selectable': image.isSelectable(),
- 'draggable': image.isDraggable(),
- 'colormap': colormap,
- 'xlabel': image.getXLabel(),
- 'ylabel': image.getYLabel(),
+ "info": image.getInfo(),
+ "origin": image.getOrigin(),
+ "scale": image.getScale(),
+ "z": image.getZValue(),
+ "selectable": image.isSelectable(),
+ "draggable": image.isDraggable(),
+ "colormap": colormap,
+ "xlabel": image.getXLabel(),
+ "ylabel": image.getYLabel(),
}
if returnNumpyArray or copy:
return numpy.array(self._stack, copy=copy), params
@@ -641,15 +680,15 @@ class StackView(qt.QMainWindow):
colormap = None
params = {
- 'info': image.getInfo(),
- 'origin': image.getOrigin(),
- 'scale': image.getScale(),
- 'z': image.getZValue(),
- 'selectable': image.isSelectable(),
- 'draggable': image.isDraggable(),
- 'colormap': colormap,
- 'xlabel': image.getXLabel(),
- 'ylabel': image.getYLabel(),
+ "info": image.getInfo(),
+ "origin": image.getOrigin(),
+ "scale": image.getScale(),
+ "z": image.getZValue(),
+ "selectable": image.isSelectable(),
+ "draggable": image.isDraggable(),
+ "colormap": colormap,
+ "xlabel": image.getXLabel(),
+ "ylabel": image.getYLabel(),
}
if returnNumpyArray or copy:
return numpy.array(self.__transposed_view, copy=copy), params
@@ -718,8 +757,8 @@ class StackView(qt.QMainWindow):
def clear(self):
"""Clear the widget:
- - clear the plot
- - clear the loaded data volume
+ - clear the plot
+ - clear the loaded data volume
"""
self._stack = None
self.__transposed_view = None
@@ -742,9 +781,11 @@ class StackView(qt.QMainWindow):
of the data volumes.
"""
- default_labels = ["Dimension %d" % self._first_stack_dimension,
- "Dimension %d" % (self._first_stack_dimension + 1),
- "Dimension %d" % (self._first_stack_dimension + 2)]
+ default_labels = [
+ "Dimension %d" % self._first_stack_dimension,
+ "Dimension %d" % (self._first_stack_dimension + 1),
+ "Dimension %d" % (self._first_stack_dimension + 2),
+ ]
if labels is None:
new_labels = default_labels
else:
@@ -791,8 +832,9 @@ class StackView(qt.QMainWindow):
vmin, vmax = colormap.getColormapRange(data=stack[0])
colormap.setVRange(vmin=vmin, vmax=vmax)
- def setColormap(self, colormap=None, normalization=None,
- autoscale=None, vmin=None, vmax=None, colors=None):
+ def setColormap(
+ self, colormap=None, normalization=None, vmin=None, vmax=None, colors=None
+ ):
"""Set the colormap and update active image.
Parameters that are not provided are taken from the current colormap.
@@ -818,59 +860,33 @@ class StackView(qt.QMainWindow):
Or a :class`.Colormap` object.
:type colormap: dict or str.
:param str normalization: Colormap mapping: 'linear' or 'log'.
- :param bool autoscale: Whether to use autoscale or [vmin, vmax] range.
- Default value of autoscale is False. This option is not compatible
- with h5py datasets.
- :param float vmin: The minimum value of the range to use if
- 'autoscale' is False.
- :param float vmax: The maximum value of the range to use if
- 'autoscale' is False.
+ :param float vmin: The minimum value of the range to use.
+ :param float vmax: The maximum value of the range to use.
:param numpy.ndarray colors: Only used if name is None.
Custom colormap colors as Nx3 or Nx4 RGB or RGBA arrays
"""
# if is a colormap object or a dictionary
if isinstance(colormap, Colormap) or isinstance(colormap, dict):
# Support colormap parameter as a dict
- errmsg = "If colormap is provided as a Colormap object, all other parameters"
+ errmsg = (
+ "If colormap is provided as a Colormap object, all other parameters"
+ )
errmsg += " must not be specified when calling setColormap"
assert normalization is None, errmsg
- assert autoscale is None, errmsg
assert vmin is None, errmsg
assert vmax is None, errmsg
assert colors is None, errmsg
- if isinstance(colormap, dict):
- reason = 'colormap parameter should now be an object'
- replacement = 'Colormap()'
- since_version = '0.6'
- deprecated_warning(type_='function',
- name='setColormap',
- reason=reason,
- replacement=replacement,
- since_version=since_version)
- _colormap = Colormap._fromDict(colormap)
- else:
- _colormap = colormap
+ _colormap = colormap
else:
- norm = normalization if normalization is not None else 'linear'
- name = colormap if colormap is not None else 'gray'
- _colormap = Colormap(name=name,
- normalization=norm,
- vmin=vmin,
- vmax=vmax,
- colors=colors)
-
- if autoscale is not None:
- deprecated_warning(
- type_='function',
- name='setColormap',
- reason='autoscale argument is replaced by a method',
- replacement='scaleColormapRangeToStack',
- since_version='0.14')
- self.__autoscaleCmap = bool(autoscale)
+ norm = normalization if normalization is not None else "linear"
+ name = colormap if colormap is not None else "gray"
+ _colormap = Colormap(
+ name=name, normalization=norm, vmin=vmin, vmax=vmax, colors=colors
+ )
cursorColor = cursorColorForColormap(_colormap.getName())
- self._plot.setInteractiveMode('zoom', color=cursorColor)
+ self._plot.setInteractiveMode("zoom", color=cursorColor)
self._plot.setDefaultColormap(_colormap)
@@ -879,16 +895,6 @@ class StackView(qt.QMainWindow):
if isinstance(activeImage, items.ColormapMixIn):
activeImage.setColormap(self.getColormap())
- if self.__autoscaleCmap:
- # scaleColormapRangeToStack needs to be called **after**
- # setDefaultColormap so getColormap returns the right colormap
- self.scaleColormapRangeToStack()
-
-
- @deprecated(replacement="getPlotWidget", since_version="0.13")
- def getPlot(self):
- return self.getPlotWidget()
-
def getPlotWidget(self):
"""Return the :class:`PlotWidget`.
@@ -912,13 +918,11 @@ class StackView(qt.QMainWindow):
# proxies to PlotWidget or PlotWindow methods
def getProfileToolbar(self):
- """Profile tools attached to this plot
- """
+ """Profile tools attached to this plot"""
return self._profileToolBar
def getGraphTitle(self):
- """Return the plot main title as a str.
- """
+ """Return the plot main title as a str."""
return self._plot.getGraphTitle()
def setGraphTitle(self, title=""):
@@ -929,8 +933,7 @@ class StackView(qt.QMainWindow):
return self._plot.setGraphTitle(title)
def getGraphXLabel(self):
- """Return the current horizontal axis label as a str.
- """
+ """Return the current horizontal axis label as a str."""
return self._plot.getXAxis().getLabel()
def setGraphXLabel(self, label=None):
@@ -942,14 +945,14 @@ class StackView(qt.QMainWindow):
label = self.__dimensionsLabels[1 if self._perspective == 2 else 2]
self._plot.getXAxis().setLabel(label)
- def getGraphYLabel(self, axis='left'):
+ def getGraphYLabel(self, axis="left"):
"""Return the current vertical axis label as a str.
:param str axis: The Y axis for which to get the label (left or right)
"""
return self._plot.getYAxis().getLabel(axis)
- def setGraphYLabel(self, label=None, axis='left'):
+ def setGraphYLabel(self, label=None, axis="left"):
"""Set the vertical axis label on the plot.
:param str label: The Y axis label
@@ -1033,8 +1036,7 @@ class StackView(qt.QMainWindow):
# kind of private methods, but needed by Profile
def getActiveImage(self, just_legend=False):
- """Returns the stack image object.
- """
+ """Returns the stack image object."""
if just_legend:
return self._stackItem.getName()
return self._stackItem
@@ -1049,8 +1051,7 @@ class StackView(qt.QMainWindow):
"""
return self._plot.getColorBarAction()
- def remove(self, legend=None,
- kind=('curve', 'image', 'item', 'marker')):
+ def remove(self, legend=None, kind=("curve", "image", "item", "marker")):
"""See :meth:`Plot.Plot.remove`"""
self._plot.remove(legend, kind)
@@ -1060,10 +1061,6 @@ class StackView(qt.QMainWindow):
"""
self._plot.setInteractiveMode(*args, **kwargs)
- @deprecated(replacement="addShape", since_version="0.13")
- def addItem(self, *args, **kwargs):
- self.addShape(*args, **kwargs)
-
def addShape(self, *args, **kwargs):
"""
See :meth:`Plot.Plot.addShape`
@@ -1076,6 +1073,7 @@ class PlanesWidget(qt.QWidget):
:param parent: the parent QWidget
"""
+
sigPlaneSelectionChanged = qt.Signal(int)
def __init__(self, parent):
@@ -1098,7 +1096,8 @@ class PlanesWidget(qt.QWidget):
self.qcbAxisSelection = qt.QComboBox(self)
self._setCBChoices(first_stack_dimension=0)
self.qcbAxisSelection.currentIndexChanged[int].connect(
- self.__planeSelectionChanged)
+ self.__planeSelectionChanged
+ )
layout0.addWidget(self.qcbAxisSelection)
@@ -1117,12 +1116,12 @@ class PlanesWidget(qt.QWidget):
def _setCBChoices(self, first_stack_dimension):
self.qcbAxisSelection.clear()
- dim1dim2 = 'Dim%d-Dim%d' % (first_stack_dimension + 1,
- first_stack_dimension + 2)
- dim0dim2 = 'Dim%d-Dim%d' % (first_stack_dimension,
- first_stack_dimension + 2)
- dim0dim1 = 'Dim%d-Dim%d' % (first_stack_dimension,
- first_stack_dimension + 1)
+ dim1dim2 = "Dim%d-Dim%d" % (
+ first_stack_dimension + 1,
+ first_stack_dimension + 2,
+ )
+ dim0dim2 = "Dim%d-Dim%d" % (first_stack_dimension, first_stack_dimension + 2)
+ dim0dim1 = "Dim%d-Dim%d" % (first_stack_dimension, first_stack_dimension + 1)
self.qcbAxisSelection.addItem(icons.getQIcon("cube-front"), dim1dim2)
self.qcbAxisSelection.addItem(icons.getQIcon("cube-bottom"), dim0dim2)
@@ -1160,25 +1159,25 @@ class StackViewMainWindow(StackView):
:param QWidget parent: Parent widget, or None
"""
+
def __init__(self, parent=None):
self._dataInfo = None
super(StackViewMainWindow, self).__init__(parent)
self.setWindowFlags(qt.Qt.Window)
# Add toolbars and status bar
- self.addToolBar(qt.Qt.BottomToolBarArea,
- LimitsToolBar(plot=self._plot))
+ self.addToolBar(qt.Qt.BottomToolBarArea, LimitsToolBar(plot=self._plot))
self.statusBar()
- menu = self.menuBar().addMenu('File')
+ menu = self.menuBar().addMenu("File")
menu.addAction(self._plot.getOutputToolBar().getSaveAction())
menu.addAction(self._plot.getOutputToolBar().getPrintAction())
menu.addSeparator()
- action = menu.addAction('Quit')
+ action = menu.addAction("Quit")
action.triggered[bool].connect(qt.QApplication.instance().quit)
- menu = self.menuBar().addMenu('Edit')
+ menu = self.menuBar().addMenu("Edit")
menu.addAction(self._plot.getOutputToolBar().getCopyAction())
menu.addSeparator()
menu.addAction(self._plot.getResetZoomAction())
@@ -1188,7 +1187,7 @@ class StackViewMainWindow(StackView):
menu.addAction(actions.control.KeepAspectRatioAction(self._plot, self))
menu.addAction(actions.control.YAxisInvertedAction(self._plot, self))
- menu = self.menuBar().addMenu('Profile')
+ menu = self.menuBar().addMenu("Profile")
profileToolBar = self._profileToolBar
menu.addAction(profileToolBar.hLineAction)
menu.addAction(profileToolBar.vLineAction)
@@ -1218,11 +1217,11 @@ class StackViewMainWindow(StackView):
elif self._perspective == 2:
dim0, dim1, dim2 = int(y), int(x), img_idx
- msg = 'Position: (%d, %d, %d)' % (dim0, dim1, dim2)
+ msg = "Position: (%d, %d, %d)" % (dim0, dim1, dim2)
if value is not None:
- msg += ', Value: %g' % value
+ msg += ", Value: %g" % value
if self._dataInfo is not None:
- msg = self._dataInfo + ', ' + msg
+ msg = self._dataInfo + ", " + msg
self.statusBar().showMessage(msg)
@@ -1231,11 +1230,15 @@ class StackViewMainWindow(StackView):
See :meth:`StackView.setStack` for details.
"""
- if hasattr(stack, 'dtype') and hasattr(stack, 'shape'):
+ if hasattr(stack, "dtype") and hasattr(stack, "shape"):
assert len(stack.shape) == 3
nframes, height, width = stack.shape
- self._dataInfo = 'Data: %dx%dx%d (%s)' % (nframes, height, width,
- str(stack.dtype))
+ self._dataInfo = "Data: %dx%dx%d (%s)" % (
+ nframes,
+ height,
+ width,
+ str(stack.dtype),
+ )
self.statusBar().showMessage(self._dataInfo)
else:
self._dataInfo = None
diff --git a/src/silx/gui/plot/StatsWidget.py b/src/silx/gui/plot/StatsWidget.py
index b23946f..0c37f52 100644
--- a/src/silx/gui/plot/StatsWidget.py
+++ b/src/silx/gui/plot/StatsWidget.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -30,7 +30,6 @@ __license__ = "MIT"
__date__ = "24/07/2018"
-from collections import OrderedDict
from contextlib import contextmanager
import logging
import weakref
@@ -55,8 +54,8 @@ _logger = logging.getLogger(__name__)
@enum.unique
class UpdateMode(_Enum):
- AUTO = 'auto'
- MANUAL = 'manual'
+ AUTO = "auto"
+ MANUAL = "manual"
# Helper class to handle specific calls to PlotWidget and SceneWidget
@@ -126,7 +125,7 @@ class _Wrapper(qt.QObject):
:param item:
:rtype: str
"""
- return ''
+ return ""
def getKind(self, item):
"""Returns the kind of an item or None if not supported
@@ -164,18 +163,18 @@ class _PlotWidgetWrapper(_Wrapper):
self.sigCurrentChanged.emit(item)
def _activeCurveChanged(self, previous, current):
- self._activeChanged(kind='curve')
+ self._activeChanged(kind="curve")
def _activeImageChanged(self, previous, current):
- self._activeChanged(kind='image')
+ self._activeChanged(kind="image")
def _activeScatterChanged(self, previous, current):
- self._activeChanged(kind='scatter')
+ self._activeChanged(kind="scatter")
def _limitsChanged(self, event):
"""Handle change of plot area limits."""
- if event['event'] == 'limitsChanged':
- self.sigVisibleDataChanged.emit()
+ if event["event"] == "limitsChanged":
+ self.sigVisibleDataChanged.emit()
def getItems(self):
plot = self.getPlot()
@@ -200,20 +199,20 @@ class _PlotWidgetWrapper(_Wrapper):
kind = self.getKind(item)
if kind in plot._ACTIVE_ITEM_KINDS:
if plot._getActiveItem(kind) != item:
- plot._setActiveItem(kind, item.getName())
+ plot._setActiveItem(kind, item)
def getLabel(self, item):
return item.getName()
def getKind(self, item):
if isinstance(item, plotitems.Curve):
- return 'curve'
+ return "curve"
elif isinstance(item, plotitems.ImageData):
- return 'image'
+ return "image"
elif isinstance(item, plotitems.Scatter):
- return 'scatter'
+ return "scatter"
elif isinstance(item, plotitems.Histogram):
- return 'histogram'
+ return "histogram"
else:
return None
@@ -259,12 +258,10 @@ class _SceneWidgetWrapper(_Wrapper):
def getKind(self, item):
from ..plot3d import items as plot3ditems
- if isinstance(item, (plot3ditems.ImageData,
- plot3ditems.ScalarField3D)):
- return 'image'
- elif isinstance(item, (plot3ditems.Scatter2D,
- plot3ditems.Scatter3D)):
- return 'scatter'
+ if isinstance(item, (plot3ditems.ImageData, plot3ditems.ScalarField3D)):
+ return "image"
+ elif isinstance(item, (plot3ditems.Scatter2D, plot3ditems.Scatter3D)):
+ return "scatter"
else:
return None
@@ -306,10 +303,10 @@ class _ScalarFieldViewWrapper(_Wrapper):
pass
def getLabel(self, item):
- return 'Data'
+ return "Data"
def getKind(self, item):
- return 'image'
+ return "image"
class _Container(object):
@@ -319,6 +316,7 @@ class _Container(object):
:param QObject obj:
"""
+
def __init__(self, obj):
self._obj = obj
@@ -383,7 +381,10 @@ class _StatsWidgetBase(object):
else: # Expect a ScalarFieldView
self._plotWrapper = _ScalarFieldViewWrapper(plot)
else:
- _logger.warning('OpenGL not installed, %s not managed' % ('SceneWidget qnd ScalarFieldView'))
+ _logger.warning(
+ "OpenGL not installed, %s not managed"
+ % ("SceneWidget qnd ScalarFieldView")
+ )
self._dealWithPlotConnection(create=True)
def setStats(self, statsHandler):
@@ -422,16 +423,19 @@ class _StatsWidgetBase(object):
connections = [] # List of (signal, slot) to connect/disconnect
if self._statsOnVisibleData:
connections.append(
- (self._plotWrapper.sigVisibleDataChanged, self._updateAllStats))
+ (self._plotWrapper.sigVisibleDataChanged, self._updateAllStats)
+ )
if self._displayOnlyActItem:
connections.append(
- (self._plotWrapper.sigCurrentChanged, self._updateCurrentItem))
+ (self._plotWrapper.sigCurrentChanged, self._updateCurrentItem)
+ )
else:
connections += [
(self._plotWrapper.sigItemAdded, self._addItem),
(self._plotWrapper.sigItemRemoved, self._removeItem),
- (self._plotWrapper.sigCurrentChanged, self._plotCurrentChanged)]
+ (self._plotWrapper.sigCurrentChanged, self._plotCurrentChanged),
+ ]
for signal, slot in connections:
if create:
@@ -441,12 +445,12 @@ class _StatsWidgetBase(object):
def _updateItemObserve(self, *args):
"""Reload table depending on mode"""
- raise NotImplementedError('Base class')
+ raise NotImplementedError("Base class")
def _updateCurrentItem(self, *args):
"""specific callback for the sigCurrentChanged and with the
_displayOnlyActItem option."""
- raise NotImplementedError('Base class')
+ raise NotImplementedError("Base class")
def _updateStats(self, item, data_changed=False, roi_changed=False):
"""Update displayed information for given plot item
@@ -455,11 +459,11 @@ class _StatsWidgetBase(object):
:param bool data_changed: is the item data changed.
:param bool roi_changed: is the associated roi changed.
"""
- raise NotImplementedError('Base class')
+ raise NotImplementedError("Base class")
def _updateAllStats(self):
"""Update stats for all rows in the table"""
- raise NotImplementedError('Base class')
+ raise NotImplementedError("Base class")
def setDisplayOnlyActiveItem(self, displayOnlyActItem):
"""Toggle display off all items or only the active/selected one
@@ -494,21 +498,21 @@ class _StatsWidgetBase(object):
:returns: True if the item is added to the widget.
:rtype: bool
"""
- raise NotImplementedError('Base class')
+ raise NotImplementedError("Base class")
def _removeItem(self, item):
"""Remove table items corresponding to given plot item from the table.
:param item: The plot item
"""
- raise NotImplementedError('Base class')
+ raise NotImplementedError("Base class")
def _plotCurrentChanged(self, current):
"""Handle change of current item and update selection in table
:param current:
"""
- raise NotImplementedError('Base class')
+ raise NotImplementedError("Base class")
def clear(self):
"""clear GUI"""
@@ -562,16 +566,17 @@ class StatsTable(_StatsWidgetBase, TableWidget):
:class:`PlotWidget` or :class:`SceneWidget` instance on which to operate
"""
- _LEGEND_HEADER_DATA = 'legend'
- _KIND_HEADER_DATA = 'kind'
+ _LEGEND_HEADER_DATA = "legend"
+ _KIND_HEADER_DATA = "kind"
sigUpdateModeChanged = qt.Signal(object)
"""Signal emitted when the update mode changed"""
def __init__(self, parent=None, plot=None):
TableWidget.__init__(self, parent)
- _StatsWidgetBase.__init__(self, statsOnVisibleData=False,
- displayOnlyActItem=False)
+ _StatsWidgetBase.__init__(
+ self, statsOnVisibleData=False, displayOnlyActItem=False
+ )
# Init for _displayOnlyActItem == False
assert self._displayOnlyActItem is False
@@ -669,7 +674,15 @@ class StatsTable(_StatsWidgetBase, TableWidget):
If exists, update it only when we are in 'auto' mode"""
if self.getUpdateMode() is UpdateMode.MANUAL:
# when sigCurrentChanged is giving the current item
- if len(args) > 0 and isinstance(args[0], (plotitems.Curve, plotitems.Histogram, plotitems.ImageData, plotitems.Scatter)):
+ if len(args) > 0 and isinstance(
+ args[0],
+ (
+ plotitems.Curve,
+ plotitems.Histogram,
+ plotitems.ImageData,
+ plotitems.Scatter,
+ ),
+ ):
item = args[0]
tableItems = self._itemToTableItems(item)
# if the table does not exists yet
@@ -722,9 +735,9 @@ class StatsTable(_StatsWidgetBase, TableWidget):
:param item: The plot item
:return: An ordered dict of column name to QTableWidgetItem mapping
for the given plot item.
- :rtype: OrderedDict
+ :rtype: dict
"""
- result = OrderedDict()
+ result = {}
row = self._itemToRow(item)
if row is not None:
for column in range(self.columnCount()):
@@ -776,9 +789,7 @@ class StatsTable(_StatsWidgetBase, TableWidget):
return False
# Prepare table items
- tableItems = [
- qt.QTableWidgetItem(), # Legend
- qt.QTableWidgetItem()] # Kind
+ tableItems = [qt.QTableWidgetItem(), qt.QTableWidgetItem()] # Legend # Kind
for column in range(2, self.columnCount()):
header = self.horizontalHeaderItem(column)
@@ -805,8 +816,7 @@ class StatsTable(_StatsWidgetBase, TableWidget):
row = self.rowCount() - 1
for column, tableItem in enumerate(tableItems):
tableItem.setData(qt.Qt.UserRole, _Container(item))
- tableItem.setFlags(
- qt.Qt.ItemIsEnabled | qt.Qt.ItemIsSelectable)
+ tableItem.setFlags(qt.Qt.ItemIsEnabled | qt.Qt.ItemIsSelectable)
self.setItem(row, column, tableItem)
# Update table items content
@@ -815,8 +825,7 @@ class StatsTable(_StatsWidgetBase, TableWidget):
# Listen for item changes
# Using queued connection to avoid issue with sender
# being that of the signal calling the signal
- item.sigItemChanged.connect(self._plotItemChanged,
- qt.Qt.QueuedConnection)
+ item.sigItemChanged.connect(self._plotItemChanged, qt.Qt.QueuedConnection)
return True
@@ -871,8 +880,12 @@ class StatsTable(_StatsWidgetBase, TableWidget):
else:
roi_changed = False
stats = statsHandler.calculate(
- item, plot, self._statsOnVisibleData,
- data_changed=data_changed, roi_changed=roi_changed)
+ item,
+ plot,
+ self._statsOnVisibleData,
+ data_changed=data_changed,
+ roi_changed=roi_changed,
+ )
else:
stats = {}
@@ -887,7 +900,7 @@ class StatsTable(_StatsWidgetBase, TableWidget):
value = stats.get(name)
if value is None:
_logger.error("Value not found for: %s", name)
- tableItem.setText('-')
+ tableItem.setText("-")
else:
tableItem.setText(str(value))
@@ -943,6 +956,7 @@ class StatsTable(_StatsWidgetBase, TableWidget):
class UpdateModeWidget(qt.QWidget):
"""Widget used to select the mode of update"""
+
sigUpdateModeChanged = qt.Signal(object)
"""signal emitted when the mode for update changed"""
sigUpdateRequested = qt.Signal()
@@ -954,22 +968,22 @@ class UpdateModeWidget(qt.QWidget):
self._buttonGrp = qt.QButtonGroup(parent=self)
self._buttonGrp.setExclusive(True)
- spacer = qt.QSpacerItem(20, 20,
- qt.QSizePolicy.Expanding,
- qt.QSizePolicy.Minimum)
+ spacer = qt.QSpacerItem(
+ 20, 20, qt.QSizePolicy.Expanding, qt.QSizePolicy.Minimum
+ )
self.layout().addItem(spacer)
- self._autoRB = qt.QRadioButton('auto', parent=self)
+ self._autoRB = qt.QRadioButton("auto", parent=self)
self.layout().addWidget(self._autoRB)
self._buttonGrp.addButton(self._autoRB)
- self._manualRB = qt.QRadioButton('manual', parent=self)
+ self._manualRB = qt.QRadioButton("manual", parent=self)
self.layout().addWidget(self._manualRB)
self._buttonGrp.addButton(self._manualRB)
self._manualRB.setChecked(True)
- refresh_icon = icons.getQIcon('view-refresh')
- self._updatePB = qt.QPushButton(refresh_icon, '', parent=self)
+ refresh_icon = icons.getQIcon("view-refresh")
+ self._updatePB = qt.QPushButton(refresh_icon, "", parent=self)
self.layout().addWidget(self._updatePB)
# connect signal / SLOT
@@ -1006,7 +1020,7 @@ class UpdateModeWidget(qt.QWidget):
if not self._manualRB.isChecked():
self._manualRB.setChecked(True)
else:
- raise ValueError('mode', mode, 'is not recognized')
+ raise ValueError("mode", mode, "is not recognized")
def getUpdateMode(self):
"""Returns update mode (See :meth:`setUpdateMode`).
@@ -1031,7 +1045,6 @@ class UpdateModeWidget(qt.QWidget):
class _OptionsWidget(qt.QToolBar):
-
def __init__(self, parent=None, updateMode=None, displayOnlyActItem=False):
assert updateMode is not None
qt.QToolBar.__init__(self, parent)
@@ -1055,12 +1068,14 @@ class _OptionsWidget(qt.QToolBar):
action = qt.QAction(self)
action.setIcon(icons.getQIcon("stats-visible-data"))
action.setText("Use the visible data range")
- action.setToolTip("Use the visible data range.<br/>"
- "If activated the data is filtered to only use"
- "visible data of the plot."
- "The filtering is a data sub-sampling."
- "No interpolation is made to fit data to"
- "boundaries.")
+ action.setToolTip(
+ "Use the visible data range.<br/>"
+ "If activated the data is filtered to only use"
+ "visible data of the plot."
+ "The filtering is a data sub-sampling."
+ "No interpolation is made to fit data to"
+ "boundaries."
+ )
action.setCheckable(True)
self.__useVisibleData = action
@@ -1156,7 +1171,7 @@ class StatsWidget(qt.QWidget):
It Provides the visibility of the widget.
"""
- NUMBER_FORMAT = '{0:.3f}'
+ NUMBER_FORMAT = "{0:.3f}"
def __init__(self, parent=None, plot=None, stats=None):
qt.QWidget.__init__(self, parent)
@@ -1172,15 +1187,15 @@ class StatsWidget(qt.QWidget):
self.layout().addWidget(self._statsTable)
old = self._statsTable.blockSignals(True)
- self._options.itemSelection.triggered.connect(
- self._optSelectionChanged)
- self._options.dataRangeSelection.triggered.connect(
- self._optDataRangeChanged)
+ self._options.itemSelection.triggered.connect(self._optSelectionChanged)
+ self._options.dataRangeSelection.triggered.connect(self._optDataRangeChanged)
self._optDataRangeChanged()
self._statsTable.blockSignals(old)
self._statsTable.sigUpdateModeChanged.connect(self._options._setUpdateMode)
- callback = functools.partial(self._getStatsTable()._updateAllStats, is_request=True)
+ callback = functools.partial(
+ self._getStatsTable()._updateAllStats, is_request=True
+ )
self._options.sigUpdateStats.connect(callback)
def _getStatsTable(self):
@@ -1199,12 +1214,12 @@ class StatsWidget(qt.QWidget):
qt.QWidget.hideEvent(self, event)
def _optSelectionChanged(self, action=None):
- self._getStatsTable().setDisplayOnlyActiveItem(
- self._options.isActiveItemMode())
+ self._getStatsTable().setDisplayOnlyActiveItem(self._options.isActiveItemMode())
def _optDataRangeChanged(self, action=None):
self._getStatsTable().setStatsOnVisibleData(
- self._options.isVisibleDataRangeMode())
+ self._options.isVisibleDataRangeMode()
+ )
# Proxy methods
@@ -1215,7 +1230,8 @@ class StatsWidget(qt.QWidget):
@docstring(StatsTable)
def setPlot(self, plot):
self._options.setVisibleDataRangeModeEnabled(
- plot is None or isinstance(plot, PlotWidget))
+ plot is None or isinstance(plot, PlotWidget)
+ )
return self._getStatsTable().setPlot(plot=plot)
@docstring(StatsTable)
@@ -1229,7 +1245,8 @@ class StatsWidget(qt.QWidget):
self._options.setDisplayActiveItems(displayOnlyActItem)
self._options.blockSignals(old)
return self._getStatsTable().setDisplayOnlyActiveItem(
- displayOnlyActItem=displayOnlyActItem)
+ displayOnlyActItem=displayOnlyActItem
+ )
@docstring(StatsTable)
def setStatsOnVisibleData(self, b):
@@ -1244,15 +1261,17 @@ class StatsWidget(qt.QWidget):
self._statsTable.setUpdateMode(mode)
-DEFAULT_STATS = StatsHandler((
- (statsmdl.StatMin(), StatFormatter()),
- statsmdl.StatCoordMin(),
- (statsmdl.StatMax(), StatFormatter()),
- statsmdl.StatCoordMax(),
- statsmdl.StatCOM(),
- (('mean', numpy.mean), StatFormatter()),
- (('std', numpy.std), StatFormatter()),
-))
+DEFAULT_STATS = StatsHandler(
+ (
+ (statsmdl.StatMin(), StatFormatter()),
+ statsmdl.StatCoordMin(),
+ (statsmdl.StatMax(), StatFormatter()),
+ statsmdl.StatCoordMax(),
+ statsmdl.StatCOM(),
+ (("mean", numpy.mean), StatFormatter()),
+ (("std", numpy.std), StatFormatter()),
+ )
+)
class BasicStatsWidget(StatsWidget):
@@ -1282,9 +1301,9 @@ class BasicStatsWidget(StatsWidget):
widget = BasicStatsWidget(plot=plot)
widget.show()
"""
+
def __init__(self, parent=None, plot=None):
- StatsWidget.__init__(self, parent=parent, plot=plot,
- stats=DEFAULT_STATS)
+ StatsWidget.__init__(self, parent=parent, plot=plot, stats=DEFAULT_STATS)
class _BaseLineStatsWidget(_StatsWidgetBase, qt.QWidget):
@@ -1306,8 +1325,9 @@ class _BaseLineStatsWidget(_StatsWidgetBase, qt.QWidget):
sigUpdateModeChanged = qt.Signal(object)
"""Signal emitted when the update mode changed"""
- def __init__(self, parent=None, plot=None, kind='curve', stats=None,
- statsOnVisibleData=False):
+ def __init__(
+ self, parent=None, plot=None, kind="curve", stats=None, statsOnVisibleData=False
+ ):
self._item_kind = kind
"""The item displayed"""
self._statQlineEdit = {}
@@ -1315,9 +1335,9 @@ class _BaseLineStatsWidget(_StatsWidgetBase, qt.QWidget):
self._n_statistics_per_line = 4
"""number of statistics displayed per line in the grid layout"""
qt.QWidget.__init__(self, parent)
- _StatsWidgetBase.__init__(self,
- statsOnVisibleData=statsOnVisibleData,
- displayOnlyActItem=True)
+ _StatsWidgetBase.__init__(
+ self, statsOnVisibleData=statsOnVisibleData, displayOnlyActItem=True
+ )
self.setLayout(self._createLayout())
self.setPlot(plot)
if stats is not None:
@@ -1336,8 +1356,8 @@ class _BaseLineStatsWidget(_StatsWidgetBase, qt.QWidget):
widget = qt.QWidget(parent=self)
parent = widget
- qLabel = qt.QLabel(statistic.name + ':', parent=parent)
- qLineEdit = qt.QLineEdit('', parent=parent)
+ qLabel = qt.QLabel(statistic.name + ":", parent=parent)
+ qLineEdit = qt.QLineEdit("", parent=parent)
qLineEdit.setReadOnly(True)
self._addStatsWidgetsToLayout(qLabel=qLabel, qLineEdit=qLineEdit)
@@ -1353,7 +1373,7 @@ class _BaseLineStatsWidget(_StatsWidgetBase, qt.QWidget):
self._updateAllStats()
def _addStatsWidgetsToLayout(self, qLabel, qLineEdit):
- raise NotImplementedError('Base class')
+ raise NotImplementedError("Base class")
def setStats(self, statsHandler):
"""Set which stats to display and the associated formatting.
@@ -1379,6 +1399,7 @@ class _BaseLineStatsWidget(_StatsWidgetBase, qt.QWidget):
def kind_filter(_item):
return self._plotWrapper.getKind(_item) == self.getKind()
+
items = list(filter(kind_filter, _items))
assert len(items) in (0, 1)
if len(items) == 1:
@@ -1402,15 +1423,13 @@ class _BaseLineStatsWidget(_StatsWidgetBase, qt.QWidget):
def _setItem(self, item, data_changed=True):
if item is None:
for stat_name, stat_widget in self._statQlineEdit.items():
- stat_widget.setText('')
- elif (self._statsHandler is not None and len(
- self._statsHandler.stats) > 0):
+ stat_widget.setText("")
+ elif self._statsHandler is not None and len(self._statsHandler.stats) > 0:
plot = self.getPlot()
if plot is not None:
- statsValDict = self._statsHandler.calculate(item,
- plot,
- self._statsOnVisibleData,
- data_changed=data_changed)
+ statsValDict = self._statsHandler.calculate(
+ item, plot, self._statsOnVisibleData, data_changed=data_changed
+ )
for statName, statVal in list(statsValDict.items()):
self._statQlineEdit[statName].setText(statVal)
@@ -1422,6 +1441,7 @@ class _BaseLineStatsWidget(_StatsWidgetBase, qt.QWidget):
def kind_filter(_item):
return self._plotWrapper.getKind(_item) == self.getKind()
+
items = list(filter(kind_filter, _items))
assert len(items) in (0, 1)
_item = items[0] if len(items) == 1 else None
@@ -1432,27 +1452,38 @@ class _BaseLineStatsWidget(_StatsWidgetBase, qt.QWidget):
def _createLayout(self):
"""create an instance of the main QLayout"""
- raise NotImplementedError('Base class')
+ raise NotImplementedError("Base class")
def _addItem(self, item):
- raise NotImplementedError('Display only the active item')
+ raise NotImplementedError("Display only the active item")
def _removeItem(self, item):
- raise NotImplementedError('Display only the active item')
+ raise NotImplementedError("Display only the active item")
def _plotCurrentChanged(self, current):
- raise NotImplementedError('Display only the active item')
+ raise NotImplementedError("Display only the active item")
def _updateModeHasChanged(self):
self.sigUpdateModeChanged.emit(self._updateMode)
class _BasicLineStatsWidget(_BaseLineStatsWidget):
- def __init__(self, parent=None, plot=None, kind='curve',
- stats=DEFAULT_STATS, statsOnVisibleData=False):
- _BaseLineStatsWidget.__init__(self, parent=parent, kind=kind,
- plot=plot, stats=stats,
- statsOnVisibleData=statsOnVisibleData)
+ def __init__(
+ self,
+ parent=None,
+ plot=None,
+ kind="curve",
+ stats=DEFAULT_STATS,
+ statsOnVisibleData=False,
+ ):
+ _BaseLineStatsWidget.__init__(
+ self,
+ parent=parent,
+ kind=kind,
+ plot=plot,
+ stats=stats,
+ statsOnVisibleData=statsOnVisibleData,
+ )
def _createLayout(self):
return FlowLayout()
@@ -1488,15 +1519,26 @@ class BasicLineStatsWidget(qt.QWidget):
:param bool statsOnVisibleData: compute statistics for the whole data or
only visible ones.
"""
- def __init__(self, parent=None, plot=None, kind='curve',
- stats=DEFAULT_STATS, statsOnVisibleData=False):
+
+ def __init__(
+ self,
+ parent=None,
+ plot=None,
+ kind="curve",
+ stats=DEFAULT_STATS,
+ statsOnVisibleData=False,
+ ):
qt.QWidget.__init__(self, parent)
self.setLayout(qt.QHBoxLayout())
self.layout().setSpacing(0)
self.layout().setContentsMargins(0, 0, 0, 0)
- self._lineStatsWidget = _BasicLineStatsWidget(parent=self, plot=plot,
- kind=kind, stats=stats,
- statsOnVisibleData=statsOnVisibleData)
+ self._lineStatsWidget = _BasicLineStatsWidget(
+ parent=self,
+ plot=plot,
+ kind=kind,
+ stats=stats,
+ statsOnVisibleData=statsOnVisibleData,
+ )
self.layout().addWidget(self._lineStatsWidget)
self._options = UpdateModeWidget()
@@ -1548,12 +1590,23 @@ class BasicLineStatsWidget(qt.QWidget):
class _BasicGridStatsWidget(_BaseLineStatsWidget):
- def __init__(self, parent=None, plot=None, kind='curve',
- stats=DEFAULT_STATS, statsOnVisibleData=False,
- statsPerLine=4):
- _BaseLineStatsWidget.__init__(self, parent=parent, kind=kind,
- plot=plot, stats=stats,
- statsOnVisibleData=statsOnVisibleData)
+ def __init__(
+ self,
+ parent=None,
+ plot=None,
+ kind="curve",
+ stats=DEFAULT_STATS,
+ statsOnVisibleData=False,
+ statsPerLine=4,
+ ):
+ _BaseLineStatsWidget.__init__(
+ self,
+ parent=parent,
+ kind=kind,
+ plot=plot,
+ stats=stats,
+ statsOnVisibleData=statsOnVisibleData,
+ )
self._n_statistics_per_line = statsPerLine
def _addStatsWidgetsToLayout(self, qLabel, qLineEdit):
@@ -1597,8 +1650,14 @@ class BasicGridStatsWidget(qt.QWidget):
widget.show()
"""
- def __init__(self, parent=None, plot=None, kind='curve',
- stats=DEFAULT_STATS, statsOnVisibleData=False):
+ def __init__(
+ self,
+ parent=None,
+ plot=None,
+ kind="curve",
+ stats=DEFAULT_STATS,
+ statsOnVisibleData=False,
+ ):
qt.QWidget.__init__(self, parent)
self.setLayout(qt.QVBoxLayout())
self.layout().setSpacing(0)
@@ -1608,9 +1667,13 @@ class BasicGridStatsWidget(qt.QWidget):
self._options.showRadioButtons(False)
self.layout().addWidget(self._options)
- self._lineStatsWidget = _BasicGridStatsWidget(parent=self, plot=plot,
- kind=kind, stats=stats,
- statsOnVisibleData=statsOnVisibleData)
+ self._lineStatsWidget = _BasicGridStatsWidget(
+ parent=self,
+ plot=plot,
+ kind=kind,
+ stats=stats,
+ statsOnVisibleData=statsOnVisibleData,
+ )
self.layout().addWidget(self._lineStatsWidget)
# tune options
diff --git a/src/silx/gui/plot/_BaseMaskToolsWidget.py b/src/silx/gui/plot/_BaseMaskToolsWidget.py
index 1673137..6b98289 100644
--- a/src/silx/gui/plot/_BaseMaskToolsWidget.py
+++ b/src/silx/gui/plot/_BaseMaskToolsWidget.py
@@ -134,7 +134,7 @@ class BaseMask(qt.QObject):
:param bool copy: True (the default) to copy the array,
False to use it as is if possible.
"""
- self._mask = numpy.array(mask, copy=copy, order='C', dtype=numpy.uint8)
+ self._mask = numpy.array(mask, copy=copy, order="C", dtype=numpy.uint8)
self._notify()
# History control
@@ -147,8 +147,11 @@ class BaseMask(qt.QObject):
def commit(self):
"""Append the current mask to history if changed"""
- if (not self._history or self._redo or
- not numpy.array_equal(self._mask, self._history[-1])):
+ if (
+ not self._history
+ or self._redo
+ or not numpy.array_equal(self._mask, self._history[-1])
+ ):
if self._redo:
self._redo = [] # Reset redo as a new action as been performed
self.sigRedoable[bool].emit(False)
@@ -222,7 +225,7 @@ class BaseMask(qt.QObject):
if shape is None:
# assume dimensionality never changes
shape = (0,) * len(self._mask.shape) # empty array
- shapeChanged = (shape != self._mask.shape)
+ shapeChanged = shape != self._mask.shape
self._mask = numpy.zeros(shape, dtype=numpy.uint8)
if shapeChanged:
self.resetHistory()
@@ -263,9 +266,7 @@ class BaseMask(qt.QObject):
:param float threshold: Threshold
:param bool mask: True to mask (default), False to unmask.
"""
- self.updateStencil(level,
- self.getDataValues() < threshold,
- mask)
+ self.updateStencil(level, self.getDataValues() < threshold, mask)
def updateBetweenThresholds(self, level, min_, max_, mask=True):
"""Mask/unmask all points whose values are in a range.
@@ -275,8 +276,9 @@ class BaseMask(qt.QObject):
:param float max_: Upper threshold
:param bool mask: True to mask (default), False to unmask.
"""
- stencil = numpy.logical_and(min_ <= self.getDataValues(),
- self.getDataValues() <= max_)
+ stencil = numpy.logical_and(
+ min_ <= self.getDataValues(), self.getDataValues() <= max_
+ )
self.updateStencil(level, stencil, mask)
def updateAboveThreshold(self, level, threshold, mask=True):
@@ -286,9 +288,7 @@ class BaseMask(qt.QObject):
:param float threshold: Threshold.
:param bool mask: True to mask (default), False to unmask.
"""
- self.updateStencil(level,
- self.getDataValues() > threshold,
- mask)
+ self.updateStencil(level, self.getDataValues() > threshold, mask)
def updateNotFinite(self, level, mask=True):
"""Mask/unmask all points whose values are not finite.
@@ -296,9 +296,9 @@ class BaseMask(qt.QObject):
:param int level: Mask level to update.
:param bool mask: True to mask (default), False to unmask.
"""
- self.updateStencil(level,
- numpy.logical_not(numpy.isfinite(self.getDataValues())),
- mask)
+ self.updateStencil(
+ level, numpy.logical_not(numpy.isfinite(self.getDataValues())), mask
+ )
# Drawing operations:
def updateRectangle(self, level, row, col, height, width, mask=True):
@@ -390,18 +390,20 @@ class BaseMaskToolsWidget(qt.QWidget):
# register if the user as force a color for the corresponding mask level
self._defaultColors = numpy.ones((self._maxLevelNumber + 1), dtype=bool)
# overlays colors set by the user
- self._overlayColors = numpy.zeros((self._maxLevelNumber + 1, 3), dtype=numpy.float32)
+ self._overlayColors = numpy.zeros(
+ (self._maxLevelNumber + 1, 3), dtype=numpy.float32
+ )
# as parent have to be the first argument of the widget to fit
# QtDesigner need but here plot can't be None by default.
assert plot is not None
self._plotRef = weakref.ref(plot)
- self._maskName = '__MASK_TOOLS_%d' % id(self) # Legend of the mask
+ self._maskName = "__MASK_TOOLS_%d" % id(self) # Legend of the mask
- self._colormap = Colormap(normalization='linear',
- vmin=0,
- vmax=self._maxLevelNumber)
- self._defaultOverlayColor = rgba('gray') # Color of the mask
+ self._colormap = Colormap(
+ normalization="linear", vmin=0, vmax=self._maxLevelNumber
+ )
+ self._defaultOverlayColor = rgba("gray") # Color of the mask
self._setMaskColors(1, 0.5) # Set the colormap LUT
if not isinstance(mask, BaseMask):
@@ -413,11 +415,10 @@ class BaseMaskToolsWidget(qt.QWidget):
self._drawingMode = None # Store current drawing mode
self._lastPencilPos = None
- self._multipleMasks = 'exclusive'
+ self._multipleMasks = "exclusive"
self._maskFileDir = qt.QDir.current().absolutePath()
- self.plot.sigInteractiveModeChanged.connect(
- self._interactiveModeChanged)
+ self.plot.sigInteractiveModeChanged.connect(self._interactiveModeChanged)
self._initWidgets()
@@ -470,11 +471,11 @@ class BaseMaskToolsWidget(qt.QWidget):
:param str mode: The mode to use
"""
- assert mode in ('exclusive', 'single')
+ assert mode in ("exclusive", "single")
if mode != self._multipleMasks:
self._multipleMasks = mode
- self._levelWidget.setVisible(self._multipleMasks != 'single')
- self._clearAllBtn.setVisible(self._multipleMasks != 'single')
+ self._levelWidget.setVisible(self._multipleMasks != "single")
+ self._clearAllBtn.setVisible(self._multipleMasks != "single")
def setMaskFileDirectory(self, path):
"""Set the default directory to use by load/save GUI tools
@@ -505,7 +506,8 @@ class BaseMaskToolsWidget(qt.QWidget):
plot = self._plotRef()
if plot is None:
raise RuntimeError(
- 'Mask widget attached to a PlotWidget that no longer exists')
+ "Mask widget attached to a PlotWidget that no longer exists"
+ )
return plot
def setDirection(self, direction=qt.QBoxLayout.LeftToRight):
@@ -534,7 +536,7 @@ class BaseMaskToolsWidget(qt.QWidget):
False for no trailing stretch
:return: A QWidget with a QHBoxLayout
"""
- stretch = kwargs.get('stretch', True)
+ stretch = kwargs.get("stretch", True)
layout = qt.QHBoxLayout()
layout.setContentsMargins(0, 0, 0, 0)
@@ -547,20 +549,27 @@ class BaseMaskToolsWidget(qt.QWidget):
return widget
def _initTransparencyWidget(self):
- """ Init the mask transparency widget """
+ """Init the mask transparency widget"""
transparencyWidget = qt.QWidget(parent=self)
grid = qt.QGridLayout()
grid.setContentsMargins(0, 0, 0, 0)
- self.transparencySlider = qt.QSlider(qt.Qt.Horizontal, parent=transparencyWidget)
+ self.transparencySlider = qt.QSlider(
+ qt.Qt.Horizontal, parent=transparencyWidget
+ )
self.transparencySlider.setRange(3, 10)
self.transparencySlider.setValue(8)
- self.transparencySlider.setToolTip(
- 'Set the transparency of the mask display')
+ self.transparencySlider.setToolTip("Set the transparency of the mask display")
self.transparencySlider.valueChanged.connect(self._updateColors)
- grid.addWidget(qt.QLabel('Display:', parent=transparencyWidget), 0, 0)
+ grid.addWidget(qt.QLabel("Display:", parent=transparencyWidget), 0, 0)
grid.addWidget(self.transparencySlider, 0, 1, 1, 3)
- grid.addWidget(qt.QLabel('<small><b>Transparent</b></small>', parent=transparencyWidget), 1, 1)
- grid.addWidget(qt.QLabel('<small><b>Opaque</b></small>', parent=transparencyWidget), 1, 3)
+ grid.addWidget(
+ qt.QLabel("<small><b>Transparent</b></small>", parent=transparencyWidget),
+ 1,
+ 1,
+ )
+ grid.addWidget(
+ qt.QLabel("<small><b>Opaque</b></small>", parent=transparencyWidget), 1, 3
+ )
transparencyWidget.setLayout(grid)
return transparencyWidget
@@ -571,11 +580,13 @@ class BaseMaskToolsWidget(qt.QWidget):
self.levelSpinBox = qt.QSpinBox()
self.levelSpinBox.setRange(1, self._maxLevelNumber)
self.levelSpinBox.setToolTip(
- 'Choose which mask level is edited.\n'
- 'A mask can have up to 255 non-overlapping levels.')
+ "Choose which mask level is edited.\n"
+ "A mask can have up to 255 non-overlapping levels."
+ )
self.levelSpinBox.valueChanged[int].connect(self._updateColors)
- self._levelWidget = self._hboxWidget(qt.QLabel('Mask level:'),
- self.levelSpinBox)
+ self._levelWidget = self._hboxWidget(
+ qt.QLabel("Mask level:"), self.levelSpinBox
+ )
# Transparency
self._transparencyWidget = self._initTransparencyWidget()
@@ -593,62 +604,66 @@ class BaseMaskToolsWidget(qt.QWidget):
return qt.QIcon()
undoAction = qt.QAction(self)
- undoAction.setText('Undo')
+ undoAction.setText("Undo")
icon = getIcon("edit-undo", qt.QStyle.SP_ArrowBack)
undoAction.setIcon(icon)
undoAction.setShortcut(qt.QKeySequence.Undo)
- undoAction.setToolTip('Undo last mask change <b>%s</b>' %
- undoAction.shortcut().toString())
+ undoAction.setToolTip(
+ "Undo last mask change <b>%s</b>" % undoAction.shortcut().toString()
+ )
self._mask.sigUndoable.connect(undoAction.setEnabled)
undoAction.triggered.connect(self._mask.undo)
redoAction = qt.QAction(self)
- redoAction.setText('Redo')
+ redoAction.setText("Redo")
icon = getIcon("edit-redo", qt.QStyle.SP_ArrowForward)
redoAction.setIcon(icon)
redoAction.setShortcut(qt.QKeySequence.Redo)
- redoAction.setToolTip('Redo last undone mask change <b>%s</b>' %
- redoAction.shortcut().toString())
+ redoAction.setToolTip(
+ "Redo last undone mask change <b>%s</b>" % redoAction.shortcut().toString()
+ )
self._mask.sigRedoable.connect(redoAction.setEnabled)
redoAction.triggered.connect(self._mask.redo)
loadAction = qt.QAction(self)
- loadAction.setText('Load...')
+ loadAction.setText("Load...")
icon = icons.getQIcon("document-open")
loadAction.setIcon(icon)
- loadAction.setToolTip('Load mask from file')
+ loadAction.setToolTip("Load mask from file")
loadAction.triggered.connect(self._loadMask)
saveAction = qt.QAction(self)
- saveAction.setText('Save...')
+ saveAction.setText("Save...")
icon = icons.getQIcon("document-save")
saveAction.setIcon(icon)
- saveAction.setToolTip('Save mask to file')
+ saveAction.setToolTip("Save mask to file")
saveAction.triggered.connect(self._saveMask)
invertAction = qt.QAction(self)
- invertAction.setText('Invert')
+ invertAction.setText("Invert")
icon = icons.getQIcon("mask-invert")
invertAction.setIcon(icon)
invertAction.setShortcut(qt.QKeySequence(qt.Qt.CTRL | qt.Qt.Key_I))
- invertAction.setToolTip('Invert current mask <b>%s</b>' %
- invertAction.shortcut().toString())
+ invertAction.setToolTip(
+ "Invert current mask <b>%s</b>" % invertAction.shortcut().toString()
+ )
invertAction.triggered.connect(self._handleInvertMask)
clearAction = qt.QAction(self)
- clearAction.setText('Clear')
+ clearAction.setText("Clear")
icon = icons.getQIcon("mask-clear")
clearAction.setIcon(icon)
clearAction.setShortcut(qt.QKeySequence.Delete)
- clearAction.setToolTip('Clear current mask level <b>%s</b>' %
- clearAction.shortcut().toString())
+ clearAction.setToolTip(
+ "Clear current mask level <b>%s</b>" % clearAction.shortcut().toString()
+ )
clearAction.triggered.connect(self._handleClearMask)
clearAllAction = qt.QAction(self)
- clearAllAction.setText('Clear all')
+ clearAllAction.setText("Clear all")
icon = icons.getQIcon("mask-clear-all")
clearAllAction.setIcon(icon)
- clearAllAction.setToolTip('Clear all mask levels')
+ clearAllAction.setToolTip("Clear all mask levels")
clearAllAction.triggered.connect(self.resetSelectionMask)
# Buttons group
@@ -657,9 +672,17 @@ class BaseMaskToolsWidget(qt.QWidget):
margin2 = qt.QWidget(self)
margin2.setMinimumWidth(6)
- actions = (loadAction, saveAction, margin1,
- undoAction, redoAction, margin2,
- invertAction, clearAction, clearAllAction)
+ actions = (
+ loadAction,
+ saveAction,
+ margin1,
+ undoAction,
+ redoAction,
+ margin2,
+ invertAction,
+ clearAction,
+ clearAllAction,
+ )
widgets = []
for action in actions:
if isinstance(action, qt.QWidget):
@@ -679,7 +702,7 @@ class BaseMaskToolsWidget(qt.QWidget):
layout.addWidget(self._transparencyWidget)
layout.addStretch(1)
- maskGroup = qt.QGroupBox('Mask')
+ maskGroup = qt.QGroupBox("Mask")
maskGroup.setLayout(layout)
return maskGroup
@@ -695,44 +718,46 @@ class BaseMaskToolsWidget(qt.QWidget):
self.addAction(self.browseAction)
# Draw tools
- self.rectAction = qt.QAction(icons.getQIcon('shape-rectangle'),
- 'Rectangle selection',
- self)
+ self.rectAction = qt.QAction(
+ icons.getQIcon("shape-rectangle"), "Rectangle selection", self
+ )
self.rectAction.setToolTip(
- 'Rectangle selection tool: (Un)Mask a rectangular region <b>R</b>')
+ "Rectangle selection tool: (Un)Mask a rectangular region <b>R</b>"
+ )
self.rectAction.setShortcut(qt.QKeySequence(qt.Qt.Key_R))
self.rectAction.setCheckable(True)
self.rectAction.triggered.connect(self._activeRectMode)
self.addAction(self.rectAction)
- self.ellipseAction = qt.QAction(icons.getQIcon('shape-ellipse'),
- 'Circle selection',
- self)
+ self.ellipseAction = qt.QAction(
+ icons.getQIcon("shape-ellipse"), "Circle selection", self
+ )
self.ellipseAction.setToolTip(
- 'Rectangle selection tool: (Un)Mask a circle region <b>R</b>')
+ "Rectangle selection tool: (Un)Mask a circle region <b>R</b>"
+ )
self.ellipseAction.setShortcut(qt.QKeySequence(qt.Qt.Key_R))
self.ellipseAction.setCheckable(True)
self.ellipseAction.triggered.connect(self._activeEllipseMode)
self.addAction(self.ellipseAction)
- self.polygonAction = qt.QAction(icons.getQIcon('shape-polygon'),
- 'Polygon selection',
- self)
+ self.polygonAction = qt.QAction(
+ icons.getQIcon("shape-polygon"), "Polygon selection", self
+ )
self.polygonAction.setShortcut(qt.QKeySequence(qt.Qt.Key_S))
self.polygonAction.setToolTip(
- 'Polygon selection tool: (Un)Mask a polygonal region <b>S</b><br>'
- 'Left-click to place new polygon corners<br>'
- 'Left-click on first corner to close the polygon')
+ "Polygon selection tool: (Un)Mask a polygonal region <b>S</b><br>"
+ "Left-click to place new polygon corners<br>"
+ "Left-click on first corner to close the polygon"
+ )
self.polygonAction.setCheckable(True)
self.polygonAction.triggered.connect(self._activePolygonMode)
self.addAction(self.polygonAction)
- self.pencilAction = qt.QAction(icons.getQIcon('draw-pencil'),
- 'Pencil tool',
- self)
+ self.pencilAction = qt.QAction(
+ icons.getQIcon("draw-pencil"), "Pencil tool", self
+ )
self.pencilAction.setShortcut(qt.QKeySequence(qt.Qt.Key_P))
- self.pencilAction.setToolTip(
- 'Pencil tool: (Un)Mask using a pencil <b>P</b>')
+ self.pencilAction.setToolTip("Pencil tool: (Un)Mask using a pencil <b>P</b>")
self.pencilAction.setCheckable(True)
self.pencilAction.triggered.connect(self._activePencilMode)
self.addAction(self.pencilAction)
@@ -744,8 +769,13 @@ class BaseMaskToolsWidget(qt.QWidget):
self.drawActionGroup.addAction(self.polygonAction)
self.drawActionGroup.addAction(self.pencilAction)
- actions = (self.browseAction, self.rectAction, self.ellipseAction,
- self.polygonAction, self.pencilAction)
+ actions = (
+ self.browseAction,
+ self.rectAction,
+ self.ellipseAction,
+ self.polygonAction,
+ self.pencilAction,
+ )
drawButtons = []
for action in actions:
btn = qt.QToolButton()
@@ -755,14 +785,16 @@ class BaseMaskToolsWidget(qt.QWidget):
layout.addWidget(container)
# Mask/Unmask radio buttons
- maskRadioBtn = qt.QRadioButton('Mask')
+ maskRadioBtn = qt.QRadioButton("Mask")
maskRadioBtn.setToolTip(
- 'Drawing masks with current level. Press <b>Ctrl</b> to unmask')
+ "Drawing masks with current level. Press <b>Ctrl</b> to unmask"
+ )
maskRadioBtn.setChecked(True)
- unmaskRadioBtn = qt.QRadioButton('Unmask')
+ unmaskRadioBtn = qt.QRadioButton("Unmask")
unmaskRadioBtn.setToolTip(
- 'Drawing unmasks with current level. Press <b>Ctrl</b> to mask')
+ "Drawing unmasks with current level. Press <b>Ctrl</b> to mask"
+ )
self.maskStateGroup = qt.QButtonGroup()
self.maskStateGroup.addButton(maskRadioBtn, 1)
@@ -780,7 +812,7 @@ class BaseMaskToolsWidget(qt.QWidget):
layout.addStretch(1)
- drawGroup = qt.QGroupBox('Draw tools')
+ drawGroup = qt.QGroupBox("Draw tools")
drawGroup.setLayout(layout)
return drawGroup
@@ -797,7 +829,7 @@ class BaseMaskToolsWidget(qt.QWidget):
self.pencilSlider.setRange(1, 50)
self.pencilSlider.setToolTip(pencilToolTip)
- pencilLabel = qt.QLabel('Pencil size:', parent=pencilSetting)
+ pencilLabel = qt.QLabel("Pencil size:", parent=pencilSetting)
layout = qt.QGridLayout()
layout.addWidget(pencilLabel, 0, 0)
@@ -813,26 +845,29 @@ class BaseMaskToolsWidget(qt.QWidget):
def _initThresholdGroupBox(self):
"""Init thresholding widgets"""
- self.belowThresholdAction = qt.QAction(icons.getQIcon('plot-roi-below'),
- 'Mask below threshold',
- self)
+ self.belowThresholdAction = qt.QAction(
+ icons.getQIcon("plot-roi-below"), "Mask below threshold", self
+ )
self.belowThresholdAction.setToolTip(
- 'Mask image where values are below given threshold')
+ "Mask image where values are below given threshold"
+ )
self.belowThresholdAction.setCheckable(True)
self.belowThresholdAction.setChecked(True)
- self.betweenThresholdAction = qt.QAction(icons.getQIcon('plot-roi-between'),
- 'Mask within range',
- self)
+ self.betweenThresholdAction = qt.QAction(
+ icons.getQIcon("plot-roi-between"), "Mask within range", self
+ )
self.betweenThresholdAction.setToolTip(
- 'Mask image where values are within given range')
+ "Mask image where values are within given range"
+ )
self.betweenThresholdAction.setCheckable(True)
- self.aboveThresholdAction = qt.QAction(icons.getQIcon('plot-roi-above'),
- 'Mask above threshold',
- self)
+ self.aboveThresholdAction = qt.QAction(
+ icons.getQIcon("plot-roi-above"), "Mask above threshold", self
+ )
self.aboveThresholdAction.setToolTip(
- 'Mask image where values are above given threshold')
+ "Mask image where values are above given threshold"
+ )
self.aboveThresholdAction.setCheckable(True)
self.thresholdActionGroup = qt.QActionGroup(self)
@@ -840,17 +875,18 @@ class BaseMaskToolsWidget(qt.QWidget):
self.thresholdActionGroup.addAction(self.belowThresholdAction)
self.thresholdActionGroup.addAction(self.betweenThresholdAction)
self.thresholdActionGroup.addAction(self.aboveThresholdAction)
- self.thresholdActionGroup.triggered.connect(
- self._thresholdActionGroupTriggered)
+ self.thresholdActionGroup.triggered.connect(self._thresholdActionGroupTriggered)
- self.loadColormapRangeAction = qt.QAction(icons.getQIcon('view-refresh'),
- 'Set min-max from colormap',
- self)
+ self.loadColormapRangeAction = qt.QAction(
+ icons.getQIcon("view-refresh"), "Set min-max from colormap", self
+ )
self.loadColormapRangeAction.setToolTip(
- 'Set min and max values from current colormap range')
+ "Set min and max values from current colormap range"
+ )
self.loadColormapRangeAction.setCheckable(False)
self.loadColormapRangeAction.triggered.connect(
- self._loadRangeFromColormapTriggered)
+ self._loadRangeFromColormapTriggered
+ )
widgets = []
for action in self.thresholdActionGroup.actions():
@@ -859,8 +895,7 @@ class BaseMaskToolsWidget(qt.QWidget):
widgets.append(btn)
spacer = qt.QWidget(parent=self)
- spacer.setSizePolicy(qt.QSizePolicy.Expanding,
- qt.QSizePolicy.Preferred)
+ spacer.setSizePolicy(qt.QSizePolicy.Expanding, qt.QSizePolicy.Preferred)
widgets.append(spacer)
loadColormapRangeBtn = qt.QToolButton()
@@ -882,7 +917,7 @@ class BaseMaskToolsWidget(qt.QWidget):
config.addWidget(self.maxLineLabel, 1, 0)
config.addWidget(self.maxLineEdit, 1, 1)
- self.applyMaskBtn = qt.QPushButton('Apply mask')
+ self.applyMaskBtn = qt.QPushButton("Apply mask")
self.applyMaskBtn.clicked.connect(self._maskBtnClicked)
layout = qt.QVBoxLayout()
@@ -891,7 +926,7 @@ class BaseMaskToolsWidget(qt.QWidget):
layout.addWidget(self.applyMaskBtn)
layout.addStretch(1)
- self.thresholdGroup = qt.QGroupBox('Threshold')
+ self.thresholdGroup = qt.QGroupBox("Threshold")
self.thresholdGroup.setLayout(layout)
# Init widget state
@@ -903,21 +938,23 @@ class BaseMaskToolsWidget(qt.QWidget):
def _initOtherToolsGroupBox(self):
layout = qt.QVBoxLayout()
- self.maskNanBtn = qt.QPushButton('Mask not finite values')
- self.maskNanBtn.setToolTip('Mask Not a Number and infinite values')
+ self.maskNanBtn = qt.QPushButton("Mask not finite values")
+ self.maskNanBtn.setToolTip("Mask Not a Number and infinite values")
self.maskNanBtn.clicked.connect(self._maskNotFiniteBtnClicked)
layout.addWidget(self.maskNanBtn)
layout.addStretch(1)
- self.otherToolGroup = qt.QGroupBox('Other tools')
+ self.otherToolGroup = qt.QGroupBox("Other tools")
self.otherToolGroup.setLayout(layout)
return self.otherToolGroup
def changeEvent(self, event):
"""Reset drawing action when disabling widget"""
- if (event.type() == qt.QEvent.EnabledChange and
- not self.isEnabled() and
- self.drawActionGroup.checkedAction()):
+ if (
+ event.type() == qt.QEvent.EnabledChange
+ and not self.isEnabled()
+ and self.drawActionGroup.checkedAction()
+ ):
# Disable drawing tool by reseting interaction to pan or zoom
self.plot.resetInteractiveMode()
@@ -952,20 +989,20 @@ class BaseMaskToolsWidget(qt.QWidget):
colors = numpy.empty((self._maxLevelNumber + 1, 4), dtype=numpy.float32)
# Set color
- colors[:,:3] = self._defaultOverlayColor[:3]
+ colors[:, :3] = self._defaultOverlayColor[:3]
# check if some colors has been directly set by the user
mask = numpy.equal(self._defaultColors, False)
- colors[mask,:3] = self._overlayColors[mask,:3]
+ colors[mask, :3] = self._overlayColors[mask, :3]
# Set alpha
- colors[:, -1] = alpha / 2.
+ colors[:, -1] = alpha / 2.0
# Set highlighted level color
colors[level, 3] = alpha
# Set no mask level
- colors[0] = (0., 0., 0., 0.)
+ colors[0] = (0.0, 0.0, 0.0, 0.0)
self._colormap.setColormapLUT(colors)
@@ -1007,14 +1044,14 @@ class BaseMaskToolsWidget(qt.QWidget):
def _updateColors(self, *args):
"""Rebuild mask colormap when selected level or transparency change"""
- self._setMaskColors(self.levelSpinBox.value(),
- self.transparencySlider.value() /
- self.transparencySlider.maximum())
+ self._setMaskColors(
+ self.levelSpinBox.value(),
+ self.transparencySlider.value() / self.transparencySlider.maximum(),
+ )
self._updatePlotMask()
self._updateInteractiveMode()
def _pencilWidthChanged(self, width):
-
old = self.pencilSpinBox.blockSignals(True)
try:
self.pencilSpinBox.setValue(width)
@@ -1032,13 +1069,13 @@ class BaseMaskToolsWidget(qt.QWidget):
"""Update the current mode to the same if some cached data have to be
updated. It is the case for the color for example.
"""
- if self._drawingMode == 'rectangle':
+ if self._drawingMode == "rectangle":
self._activeRectMode()
- elif self._drawingMode == 'ellipse':
+ elif self._drawingMode == "ellipse":
self._activeEllipseMode()
- elif self._drawingMode == 'polygon':
+ elif self._drawingMode == "polygon":
self._activePolygonMode()
- elif self._drawingMode == 'pencil':
+ elif self._drawingMode == "pencil":
self._activePencilMode()
def _handleClearMask(self):
@@ -1075,30 +1112,30 @@ class BaseMaskToolsWidget(qt.QWidget):
def _activeRectMode(self):
"""Handle rect action mode triggering"""
self._releaseDrawingMode()
- self._drawingMode = 'rectangle'
+ self._drawingMode = "rectangle"
self.plot.sigPlotSignal.connect(self._plotDrawEvent)
color = self.getCurrentMaskColor()
self.plot.setInteractiveMode(
- 'draw', shape='rectangle', source=self, color=color)
+ "draw", shape="rectangle", source=self, color=color
+ )
self._updateDrawingModeWidgets()
def _activeEllipseMode(self):
"""Handle circle action mode triggering"""
self._releaseDrawingMode()
- self._drawingMode = 'ellipse'
+ self._drawingMode = "ellipse"
self.plot.sigPlotSignal.connect(self._plotDrawEvent)
color = self.getCurrentMaskColor()
- self.plot.setInteractiveMode(
- 'draw', shape='ellipse', source=self, color=color)
+ self.plot.setInteractiveMode("draw", shape="ellipse", source=self, color=color)
self._updateDrawingModeWidgets()
def _activePolygonMode(self):
"""Handle polygon action mode triggering"""
self._releaseDrawingMode()
- self._drawingMode = 'polygon'
+ self._drawingMode = "polygon"
self.plot.sigPlotSignal.connect(self._plotDrawEvent)
color = self.getCurrentMaskColor()
- self.plot.setInteractiveMode('draw', shape='polygon', source=self, color=color)
+ self.plot.setInteractiveMode("draw", shape="polygon", source=self, color=color)
self._updateDrawingModeWidgets()
def _getPencilWidth(self):
@@ -1111,17 +1148,18 @@ class BaseMaskToolsWidget(qt.QWidget):
def _activePencilMode(self):
"""Handle pencil action mode triggering"""
self._releaseDrawingMode()
- self._drawingMode = 'pencil'
+ self._drawingMode = "pencil"
self.plot.sigPlotSignal.connect(self._plotDrawEvent)
color = self.getCurrentMaskColor()
width = self._getPencilWidth()
self.plot.setInteractiveMode(
- 'draw', shape='pencil', source=self, color=color, width=width)
+ "draw", shape="pencil", source=self, color=color, width=width
+ )
self._updateDrawingModeWidgets()
def _updateDrawingModeWidgets(self):
self.maskStateWidget.setVisible(self._drawingMode is not None)
- self.pencilSetting.setVisible(self._drawingMode == 'pencil')
+ self.pencilSetting.setVisible(self._drawingMode == "pencil")
# Handle plot drawing events
@@ -1131,7 +1169,7 @@ class BaseMaskToolsWidget(qt.QWidget):
:rtype: bool"""
# First draw event, use current modifiers for all draw sequence
- doMask = (self.maskStateGroup.checkedId() == 1)
+ doMask = self.maskStateGroup.checkedId() == 1
if qt.QApplication.keyboardModifiers() & qt.Qt.ControlModifier:
doMask = not doMask
return doMask
@@ -1163,29 +1201,29 @@ class BaseMaskToolsWidget(qt.QWidget):
def _maskBtnClicked(self):
if self.belowThresholdAction.isChecked():
if self.minLineEdit.text():
- self._mask.updateBelowThreshold(self.levelSpinBox.value(),
- self.minLineEdit.value())
+ self._mask.updateBelowThreshold(
+ self.levelSpinBox.value(), self.minLineEdit.value()
+ )
self._mask.commit()
elif self.betweenThresholdAction.isChecked():
if self.minLineEdit.text() and self.maxLineEdit.text():
min_ = self.minLineEdit.value()
max_ = self.maxLineEdit.value()
- self._mask.updateBetweenThresholds(self.levelSpinBox.value(),
- min_, max_)
+ self._mask.updateBetweenThresholds(
+ self.levelSpinBox.value(), min_, max_
+ )
self._mask.commit()
elif self.aboveThresholdAction.isChecked():
if self.maxLineEdit.text():
max_ = float(self.maxLineEdit.value())
- self._mask.updateAboveThreshold(self.levelSpinBox.value(),
- max_)
+ self._mask.updateAboveThreshold(self.levelSpinBox.value(), max_)
self._mask.commit()
def _maskNotFiniteBtnClicked(self):
"""Handle not finite mask button clicked: mask NaNs and inf"""
- self._mask.updateNotFinite(
- self.levelSpinBox.value())
+ self._mask.updateNotFinite(self.levelSpinBox.value())
self._mask.commit()
@@ -1201,7 +1239,7 @@ class BaseMaskToolsDockWidget(qt.QDockWidget):
sigMaskChanged = qt.Signal()
- def __init__(self, parent=None, name='Mask', widget=None):
+ def __init__(self, parent=None, name="Mask", widget=None):
super(BaseMaskToolsDockWidget, self).__init__(parent)
self.setWindowTitle(name)
@@ -1255,7 +1293,7 @@ class BaseMaskToolsDockWidget(qt.QDockWidget):
See :class:`QMainWindow`.
"""
action = super(BaseMaskToolsDockWidget, self).toggleViewAction()
- action.setIcon(icons.getQIcon('image-mask'))
+ action.setIcon(icons.getQIcon("image-mask"))
action.setToolTip("Display/hide mask tools")
return action
diff --git a/src/silx/gui/plot/__init__.py b/src/silx/gui/plot/__init__.py
index 129c4de..2a1587f 100644
--- a/src/silx/gui/plot/__init__.py
+++ b/src/silx/gui/plot/__init__.py
@@ -66,5 +66,13 @@ from .ImageView import ImageView # noqa
from .StackView import StackView # noqa
from .ScatterView import ScatterView # noqa
-__all__ = ['ImageView', 'PlotWidget', 'PlotWindow', 'Plot1D', 'Plot2D',
- 'StackView', 'ScatterView', 'TickMode']
+__all__ = [
+ "ImageView",
+ "PlotWidget",
+ "PlotWindow",
+ "Plot1D",
+ "Plot2D",
+ "StackView",
+ "ScatterView",
+ "TickMode",
+]
diff --git a/src/silx/gui/plot/_utils/__init__.py b/src/silx/gui/plot/_utils/__init__.py
index 39fa7e4..3075007 100644
--- a/src/silx/gui/plot/_utils/__init__.py
+++ b/src/silx/gui/plot/_utils/__init__.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+# Copyright (c) 2004-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -31,11 +31,12 @@ __date__ = "21/03/2017"
import numpy
from .panzoom import FLOAT32_SAFE_MIN, FLOAT32_MINPOS, FLOAT32_SAFE_MAX
-from .panzoom import applyZoomToPlot, applyPan, checkAxisLimits
+from .panzoom import applyZoomToPlot, applyPan, checkAxisLimits, EnabledAxes
-def addMarginsToLimits(margins, isXLog, isYLog,
- xMin, xMax, yMin, yMax, y2Min=None, y2Max=None):
+def addMarginsToLimits(
+ margins, isXLog, isYLog, xMin, xMax, yMin, yMax, y2Min=None, y2Max=None
+):
"""Returns updated limits by extending them with margins.
:param margins: The ratio of the margins to add or None for no margins.
@@ -55,35 +56,35 @@ def addMarginsToLimits(margins, isXLog, isYLog,
xMin -= xMinMargin * xRange
xMax += xMaxMargin * xRange
- elif xMin > 0. and xMax > 0.: # Log scale
+ elif xMin > 0.0 and xMax > 0.0: # Log scale
# Do not apply margins if limits < 0
xMinLog, xMaxLog = numpy.log10(xMin), numpy.log10(xMax)
xRangeLog = xMaxLog - xMinLog
- xMin = pow(10., xMinLog - xMinMargin * xRangeLog)
- xMax = pow(10., xMaxLog + xMaxMargin * xRangeLog)
+ xMin = pow(10.0, xMinLog - xMinMargin * xRangeLog)
+ xMax = pow(10.0, xMaxLog + xMaxMargin * xRangeLog)
if not isYLog:
yRange = yMax - yMin
yMin -= yMinMargin * yRange
yMax += yMaxMargin * yRange
- elif yMin > 0. and yMax > 0.: # Log scale
+ elif yMin > 0.0 and yMax > 0.0: # Log scale
# Do not apply margins if limits < 0
yMinLog, yMaxLog = numpy.log10(yMin), numpy.log10(yMax)
yRangeLog = yMaxLog - yMinLog
- yMin = pow(10., yMinLog - yMinMargin * yRangeLog)
- yMax = pow(10., yMaxLog + yMaxMargin * yRangeLog)
+ yMin = pow(10.0, yMinLog - yMinMargin * yRangeLog)
+ yMax = pow(10.0, yMaxLog + yMaxMargin * yRangeLog)
if y2Min is not None and y2Max is not None:
if not isYLog:
yRange = y2Max - y2Min
y2Min -= yMinMargin * yRange
y2Max += yMaxMargin * yRange
- elif y2Min > 0. and y2Max > 0.: # Log scale
+ elif y2Min > 0.0 and y2Max > 0.0: # Log scale
# Do not apply margins if limits < 0
yMinLog, yMaxLog = numpy.log10(y2Min), numpy.log10(y2Max)
yRangeLog = yMaxLog - yMinLog
- y2Min = pow(10., yMinLog - yMinMargin * yRangeLog)
- y2Max = pow(10., yMaxLog + yMaxMargin * yRangeLog)
+ y2Min = pow(10.0, yMinLog - yMinMargin * yRangeLog)
+ y2Max = pow(10.0, yMaxLog + yMaxMargin * yRangeLog)
if y2Min is None or y2Max is None:
return xMin, xMax, yMin, yMax
diff --git a/src/silx/gui/plot/_utils/delaunay.py b/src/silx/gui/plot/_utils/delaunay.py
deleted file mode 100644
index 48b0db7..0000000
--- a/src/silx/gui/plot/_utils/delaunay.py
+++ /dev/null
@@ -1,60 +0,0 @@
-# /*##########################################################################
-#
-# Copyright (c) 2019 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Wrapper over Delaunay implementation"""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "02/05/2019"
-
-
-import logging
-import sys
-
-import numpy
-
-
-_logger = logging.getLogger(__name__)
-
-
-def delaunay(x, y):
- """Returns Delaunay instance for x, y points
-
- :param numpy.ndarray x:
- :param numpy.ndarray y:
- :rtype: Union[None,scipy.spatial.Delaunay]
- """
- # Lazy-loading of Delaunay
- try:
- from scipy.spatial import Delaunay as _Delaunay
- except ImportError: # Fallback using local Delaunay
- from silx.third_party.scipy_spatial import Delaunay as _Delaunay
-
- points = numpy.array((x, y)).T
- try:
- delaunay = _Delaunay(points)
- except (RuntimeError, ValueError):
- _logger.debug("Delaunay tesselation failed: %s", sys.exc_info()[1])
- delaunay = None
-
- return delaunay
diff --git a/src/silx/gui/plot/_utils/dtime_ticklayout.py b/src/silx/gui/plot/_utils/dtime_ticklayout.py
index 3c355d7..ba0fda7 100644
--- a/src/silx/gui/plot/_utils/dtime_ticklayout.py
+++ b/src/silx/gui/plot/_utils/dtime_ticklayout.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2014-2022 European Synchrotron Radiation Facility
+# Copyright (c) 2014-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -21,6 +21,8 @@
# THE SOFTWARE.
#
# ###########################################################################*/
+from __future__ import annotations
+
"""This module implements date-time labels layout on graph axes."""
__authors__ = ["P. Kenter"]
@@ -28,6 +30,7 @@ __license__ = "MIT"
__date__ = "04/04/2018"
+from collections.abc import Sequence
import datetime as dt
import enum
import logging
@@ -48,14 +51,15 @@ SECONDS_PER_MINUTE = 60
SECONDS_PER_HOUR = 60 * SECONDS_PER_MINUTE
SECONDS_PER_DAY = 24 * SECONDS_PER_HOUR
SECONDS_PER_YEAR = 365.25 * SECONDS_PER_DAY
-SECONDS_PER_MONTH_AVERAGE = SECONDS_PER_YEAR / 12 # Seconds per average month
+SECONDS_PER_MONTH_AVERAGE = SECONDS_PER_YEAR / 12 # Seconds per average month
# No dt.timezone in Python 2.7 so we use dateutil.tz.tzutc
_EPOCH = dt.datetime(1970, 1, 1, tzinfo=dateutil.tz.tzutc())
+
def timestamp(dtObj):
- """ Returns POSIX timestamp of a datetime objects.
+ """Returns POSIX timestamp of a datetime objects.
If the dtObj object has a timestamp() method (python 3.3), this is
used. Otherwise (e.g. python 2.7) it is calculated here.
@@ -73,9 +77,22 @@ def timestamp(dtObj):
else:
# Back ported from Python 3.5
if dtObj.tzinfo is None:
- return time.mktime((dtObj.year, dtObj.month, dtObj.day,
- dtObj.hour, dtObj.minute, dtObj.second,
- -1, -1, -1)) + dtObj.microsecond / 1e6
+ return (
+ time.mktime(
+ (
+ dtObj.year,
+ dtObj.month,
+ dtObj.day,
+ dtObj.hour,
+ dtObj.minute,
+ dtObj.second,
+ -1,
+ -1,
+ -1,
+ )
+ )
+ + dtObj.microsecond / 1e6
+ )
else:
return (dtObj - _EPOCH).total_seconds()
@@ -92,7 +109,7 @@ class DtUnit(enum.Enum):
def getDateElement(dateTime, unit):
- """ Picks the date element with the unit from the dateTime
+ """Picks the date element with the unit from the dateTime
E.g. getDateElement(datetime(1970, 5, 6), DtUnit.Day) will return 6
@@ -118,7 +135,7 @@ def getDateElement(dateTime, unit):
def setDateElement(dateTime, value, unit):
- """ Returns a copy of dateTime with the tickStep unit set to value
+ """Returns a copy of dateTime with the tickStep unit set to value
:param datetime.datetime: date time object
:param int value: value to set
@@ -126,8 +143,9 @@ def setDateElement(dateTime, value, unit):
:return: datetime.datetime
"""
intValue = int(value)
- _logger.debug("setDateElement({}, {} (int={}), {})"
- .format(dateTime, value, intValue, unit))
+ _logger.debug(
+ "setDateElement({}, {} (int={}), {})".format(dateTime, value, intValue, unit)
+ )
year = dateTime.year
month = dateTime.month
@@ -154,16 +172,19 @@ def setDateElement(dateTime, value, unit):
else:
raise ValueError("Unexpected DtUnit: {}".format(unit))
- _logger.debug("creating date time {}"
- .format((year, month, day, hour, minute, second, microsecond)))
-
- return dt.datetime(year, month, day, hour, minute, second, microsecond,
- tzinfo=dateTime.tzinfo)
+ _logger.debug(
+ "creating date time {}".format(
+ (year, month, day, hour, minute, second, microsecond)
+ )
+ )
+ return dt.datetime(
+ year, month, day, hour, minute, second, microsecond, tzinfo=dateTime.tzinfo
+ )
def roundToElement(dateTime, unit):
- """ Returns a copy of dateTime rounded to given unit
+ """Returns a copy of dateTime rounded to given unit
:param datetime.datetime: date time object
:param DtUnit unit: unit
@@ -178,7 +199,7 @@ def roundToElement(dateTime, unit):
microsecond = dateTime.microsecond
if unit.value < DtUnit.YEARS.value:
- pass # Never round years
+ pass # Never round years
if unit.value < DtUnit.MONTHS.value:
month = 1
if unit.value < DtUnit.DAYS.value:
@@ -192,14 +213,15 @@ def roundToElement(dateTime, unit):
if unit.value < DtUnit.MICRO_SECONDS.value:
microsecond = 0
- result = dt.datetime(year, month, day, hour, minute, second, microsecond,
- tzinfo=dateTime.tzinfo)
+ result = dt.datetime(
+ year, month, day, hour, minute, second, microsecond, tzinfo=dateTime.tzinfo
+ )
return result
def addValueToDate(dateTime, value, unit):
- """ Adds a value with unit to a dateTime.
+ """Adds a value with unit to a dateTime.
Uses dateutil.relativedelta.relativedelta from the standard library to do
the actual math. This function doesn't allow for fractional month or years,
@@ -211,13 +233,13 @@ def addValueToDate(dateTime, value, unit):
:return:
:raises ValueError: unit is unsupported or result is out of datetime bounds
"""
- #logger.debug("addValueToDate({}, {}, {})".format(dateTime, value, unit))
+ # logger.debug("addValueToDate({}, {}, {})".format(dateTime, value, unit))
if unit == DtUnit.YEARS:
- intValue = int(value) # floats not implemented in relativeDelta(years)
+ intValue = int(value) # floats not implemented in relativeDelta(years)
return dateTime + relativedelta(years=intValue)
elif unit == DtUnit.MONTHS:
- intValue = int(value) # floats not implemented in relativeDelta(mohths)
+ intValue = int(value) # floats not implemented in relativeDelta(mohths)
return dateTime + relativedelta(months=intValue)
elif unit == DtUnit.DAYS:
return dateTime + relativedelta(days=value)
@@ -234,7 +256,7 @@ def addValueToDate(dateTime, value, unit):
def bestUnit(durationInSeconds):
- """ Gets the best tick spacing given a duration in seconds.
+ """Gets the best tick spacing given a duration in seconds.
:param durationInSeconds: time span duration in seconds
:return: DtUnit enumeration.
@@ -264,8 +286,7 @@ def bestUnit(durationInSeconds):
elif durationInSeconds > 1 * 2:
return (durationInSeconds, DtUnit.SECONDS)
else:
- return (durationInSeconds * MICROSECONDS_PER_SECOND,
- DtUnit.MICRO_SECONDS)
+ return (durationInSeconds * MICROSECONDS_PER_SECOND, DtUnit.MICRO_SECONDS)
NICE_DATE_VALUES = {
@@ -275,12 +296,12 @@ NICE_DATE_VALUES = {
DtUnit.HOURS: [1, 2, 3, 4, 6, 12],
DtUnit.MINUTES: [1, 2, 3, 5, 10, 15, 30],
DtUnit.SECONDS: [1, 2, 3, 5, 10, 15, 30],
- DtUnit.MICRO_SECONDS : [1.0, 2.0, 5.0, 10.0], # floats for microsec
+ DtUnit.MICRO_SECONDS: [1.0, 2.0, 3.0, 4.0, 5.0, 10.0], # floats for microsec
}
def bestFormatString(spacing, unit):
- """ Finds the best format string given the spacing and DtUnit.
+ """Finds the best format string given the spacing and DtUnit.
If the spacing is a fractional number < 1 the format string will take this
into account
@@ -310,8 +331,31 @@ def bestFormatString(spacing, unit):
raise ValueError("Unexpected DtUnit: {}".format(unit))
+def formatDatetimes(
+ datetimes: Sequence[dt.datetime], spacing: int | None, unit: DtUnit | None
+) -> dict[dt.datetime, str]:
+ """Returns formatted string for each datetime according to tick spacing and time unit"""
+ if spacing is None or unit is None:
+ # Locator has no spacing or units yet: Use elaborate fmtString
+ return {
+ datetime: datetime.strftime("Y-%m-%d %H:%M:%S") for datetime in datetimes
+ }
+
+ formatString = bestFormatString(spacing, unit)
+ if unit != DtUnit.MICRO_SECONDS:
+ return {datetime: datetime.strftime(formatString) for datetime in datetimes}
+
+ # For microseconds: Strip leading/trailing zeros
+ texts = tuple(datetime.strftime(formatString) for datetime in datetimes)
+ nzeros = min(len(text) - len(text.rstrip("0")) for text in texts)
+ return {
+ datetime: text[0 if text[0] != "0" else 1 : -min(nzeros, 5)]
+ for datetime, text in zip(datetimes, texts)
+ }
+
+
def niceDateTimeElement(value, unit, isRound=False):
- """ Uses the Nice Numbers algorithm to determine a nice value.
+ """Uses the Nice Numbers algorithm to determine a nice value.
The fractions are optimized for the unit of the date element.
"""
@@ -326,10 +370,8 @@ def niceDateTimeElement(value, unit, isRound=False):
def findStartDate(dMin, dMax, nTicks):
- """ Rounds a date down to the nearest nice number of ticks
- """
- assert dMax >= dMin, \
- "dMin ({}) should come before dMax ({})".format(dMin, dMax)
+ """Rounds a date down to the nearest nice number of ticks"""
+ assert dMax >= dMin, "dMin ({}) should come before dMax ({})".format(dMin, dMax)
if dMin == dMax:
# Fallback when range is smaller than microsecond resolution
@@ -337,34 +379,42 @@ def findStartDate(dMin, dMax, nTicks):
delta = dMax - dMin
lengthSec = delta.total_seconds()
- _logger.debug("findStartDate: {}, {} (duration = {} sec, {} days)"
- .format(dMin, dMax, lengthSec, lengthSec / SECONDS_PER_DAY))
+ _logger.debug(
+ "findStartDate: {}, {} (duration = {} sec, {} days)".format(
+ dMin, dMax, lengthSec, lengthSec / SECONDS_PER_DAY
+ )
+ )
length, unit = bestUnit(lengthSec)
niceLength = niceDateTimeElement(length, unit)
- _logger.debug("Length: {:8.3f} {} (nice = {})"
- .format(length, unit.name, niceLength))
+ _logger.debug(
+ "Length: {:8.3f} {} (nice = {})".format(length, unit.name, niceLength)
+ )
niceSpacing = niceDateTimeElement(niceLength / nTicks, unit, isRound=True)
- _logger.debug("Spacing: {:8.3f} {} (nice = {})"
- .format(niceLength / nTicks, unit.name, niceSpacing))
+ _logger.debug(
+ "Spacing: {:8.3f} {} (nice = {})".format(
+ niceLength / nTicks, unit.name, niceSpacing
+ )
+ )
dVal = getDateElement(dMin, unit)
- if unit == DtUnit.MONTHS: # TODO: better rounding?
- niceVal = math.floor((dVal-1) / niceSpacing) * niceSpacing + 1
+ if unit == DtUnit.MONTHS: # TODO: better rounding?
+ niceVal = math.floor((dVal - 1) / niceSpacing) * niceSpacing + 1
elif unit == DtUnit.DAYS:
- niceVal = math.floor((dVal-1) / niceSpacing) * niceSpacing + 1
+ niceVal = math.floor((dVal - 1) / niceSpacing) * niceSpacing + 1
else:
niceVal = math.floor(dVal / niceSpacing) * niceSpacing
if unit == DtUnit.YEARS and niceVal <= dt.MINYEAR:
niceVal = max(1, niceSpacing)
- _logger.debug("StartValue: dVal = {}, niceVal: {} ({})"
- .format(dVal, niceVal, unit.name))
+ _logger.debug(
+ "StartValue: dVal = {}, niceVal: {} ({})".format(dVal, niceVal, unit.name)
+ )
startDate = roundToElement(dMin, unit)
startDate = setDateElement(startDate, niceVal, unit)
@@ -372,8 +422,8 @@ def findStartDate(dMin, dMax, nTicks):
return startDate, niceSpacing, unit
-def dateRange(dMin, dMax, step, unit, includeFirstBeyond = False):
- """ Generates a range of dates
+def dateRange(dMin, dMax, step, unit, includeFirstBeyond=False):
+ """Generates a range of dates
:param datetime dMin: start date
:param datetime dMax: end date
@@ -384,8 +434,7 @@ def dateRange(dMin, dMax, step, unit, includeFirstBeyond = False):
datetime will always be smaller than dMax.
:return:
"""
- if (unit == DtUnit.YEARS or unit == DtUnit.MONTHS or
- unit == DtUnit.MICRO_SECONDS):
+ if unit == DtUnit.YEARS or unit == DtUnit.MONTHS or unit == DtUnit.MICRO_SECONDS:
# No support for fractional month or year and resolution is microsecond
# In those cases, make sure the step is at least 1
step = max(1, step)
@@ -404,7 +453,6 @@ def dateRange(dMin, dMax, step, unit, includeFirstBeyond = False):
yield dateTime
-
def calcTicks(dMin, dMax, nTicks):
"""Returns tick positions.
@@ -414,27 +462,19 @@ def calcTicks(dMin, dMax, nTicks):
ticks may differ.
:returns: (list of datetimes, DtUnit) tuple
"""
- _logger.debug("Calc calcTicks({}, {}, nTicks={})"
- .format(dMin, dMax, nTicks))
+ _logger.debug("Calc calcTicks({}, {}, nTicks={})".format(dMin, dMax, nTicks))
startDate, niceSpacing, unit = findStartDate(dMin, dMax, nTicks)
result = []
- for d in dateRange(startDate, dMax, niceSpacing, unit,
- includeFirstBeyond=True):
+ for d in dateRange(startDate, dMax, niceSpacing, unit, includeFirstBeyond=True):
result.append(d)
return result, niceSpacing, unit
def calcTicksAdaptive(dMin, dMax, axisLength, tickDensity):
- """ Calls calcTicks with a variable number of ticks, depending on axisLength
- """
+ """Calls calcTicks with a variable number of ticks, depending on axisLength"""
# At least 2 ticks
nticks = max(2, int(round(tickDensity * axisLength)))
- return calcTicks(dMin, dMax, nticks)
-
-
-
-
-
+ return calcTicks(dMin, dMax, nticks)
diff --git a/src/silx/gui/plot/_utils/panzoom.py b/src/silx/gui/plot/_utils/panzoom.py
index 8592ad0..cac591d 100644
--- a/src/silx/gui/plot/_utils/panzoom.py
+++ b/src/silx/gui/plot/_utils/panzoom.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+# Copyright (c) 2004-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -23,6 +23,8 @@
# ###########################################################################*/
"""Functions to apply pan and zoom on a Plot"""
+from __future__ import annotations
+
__authors__ = ["T. Vincent", "V. Valls"]
__license__ = "MIT"
__date__ = "08/08/2017"
@@ -30,6 +32,7 @@ __date__ = "08/08/2017"
import logging
import math
+from typing import NamedTuple
import numpy
@@ -46,11 +49,11 @@ FLOAT32_SAFE_MAX = 1e37
# TODO double support
-def checkAxisLimits(vmin, vmax, isLog: bool=False, name: str=""):
+def checkAxisLimits(vmin: float, vmax: float, isLog: bool = False, name: str = ""):
"""Makes sure axis range is not empty and within supported range.
- :param float vmin: Min axis value
- :param float vmax: Max axis value
+ :param vmin: Min axis value
+ :param vmax: Max axis value
:return: (min, max) making sure min < max
:rtype: 2-tuple of float
"""
@@ -59,11 +62,11 @@ def checkAxisLimits(vmin, vmax, isLog: bool=False, name: str=""):
vmin = numpy.clip(vmin, min_, FLOAT32_SAFE_MAX)
if vmax < vmin:
- _logger.debug('%s axis: max < min, inverting limits.', name)
+ _logger.debug("%s axis: max < min, inverting limits.", name)
vmin, vmax = vmax, vmin
elif vmax == vmin:
- _logger.debug('%s axis: max == min, expanding limits.', name)
- if vmin == 0.:
+ _logger.debug("%s axis: max == min, expanding limits.", name)
+ if vmin == 0.0:
vmin, vmax = -0.1, 0.1
elif vmin < 0:
vmax *= 0.9
@@ -75,26 +78,27 @@ def checkAxisLimits(vmin, vmax, isLog: bool=False, name: str=""):
return vmin, vmax
-def scale1DRange(min_, max_, center, scale, isLog):
+def scale1DRange(
+ min_: float, max_: float, center: float, scale: float, isLog: bool
+) -> tuple[float, float]:
"""Scale a 1D range given a scale factor and an center point.
Keeps the values in a smaller range than float32.
- :param float min_: The current min value of the range.
- :param float max_: The current max value of the range.
- :param float center: The center of the zoom (i.e., invariant point).
- :param float scale: The scale to use for zoom
- :param bool isLog: Whether using log scale or not.
- :return: The zoomed range.
- :rtype: tuple of 2 floats: (min, max)
+ :param min_: The current min value of the range.
+ :param max_: The current max value of the range.
+ :param center: The center of the zoom (i.e., invariant point).
+ :param scale: The scale to use for zoom
+ :param isLog: Whether using log scale or not.
+ :return: The zoomed range (min, max)
"""
if isLog:
# Min and center can be < 0 when
# autoscale is off and switch to log scale
# max_ < 0 should not happen
- min_ = numpy.log10(min_) if min_ > 0. else FLOAT32_MINPOS
- center = numpy.log10(center) if center > 0. else FLOAT32_MINPOS
- max_ = numpy.log10(max_) if max_ > 0. else FLOAT32_MINPOS
+ min_ = numpy.log10(min_) if min_ > 0.0 else FLOAT32_MINPOS
+ center = numpy.log10(center) if center > 0.0 else FLOAT32_MINPOS
+ max_ = numpy.log10(max_) if max_ > 0.0 else FLOAT32_MINPOS
if min_ == max_:
return min_, max_
@@ -102,12 +106,12 @@ def scale1DRange(min_, max_, center, scale, isLog):
offset = (center - min_) / (max_ - min_)
range_ = (max_ - min_) / scale
newMin = center - offset * range_
- newMax = center + (1. - offset) * range_
+ newMax = center + (1.0 - offset) * range_
if isLog:
# No overflow as exponent is log10 of a float32
- newMin = pow(10., newMin)
- newMax = pow(10., newMax)
+ newMin = pow(10.0, newMin)
+ newMax = pow(10.0, newMax)
newMin = numpy.clip(newMin, FLOAT32_MINPOS, FLOAT32_SAFE_MAX)
newMax = numpy.clip(newMax, FLOAT32_MINPOS, FLOAT32_SAFE_MAX)
else:
@@ -116,16 +120,34 @@ def scale1DRange(min_, max_, center, scale, isLog):
return newMin, newMax
-def applyZoomToPlot(plot, scaleF, center=None):
+class EnabledAxes(NamedTuple):
+ """Toggle zoom for each axis"""
+
+ xaxis: bool = True
+ yaxis: bool = True
+ y2axis: bool = True
+
+ def isDisabled(self) -> bool:
+ """True only if all axes are disabled"""
+ return not (self.xaxis or self.yaxis or self.y2axis)
+
+
+def applyZoomToPlot(
+ plot,
+ scale: float,
+ center: tuple[float, float] = None,
+ enabled: EnabledAxes = EnabledAxes(),
+):
"""Zoom in/out plot given a scale and a center point.
:param plot: The plot on which to apply zoom.
- :param float scaleF: Scale factor of zoom.
+ :param scale: Scale factor of zoom.
:param center: (x, y) coords in pixel coordinates of the zoom center.
- :type center: 2-tuple of float
+ :param enabled: Toggle zoom for each axis independently
"""
xMin, xMax = plot.getXAxis().getLimits()
yMin, yMax = plot.getYAxis().getLimits()
+ y2Min, y2Max = plot.getYAxis(axis="right").getLimits()
if center is None:
left, top, width, height = plot.getPlotBoundsInPixels()
@@ -136,18 +158,23 @@ def applyZoomToPlot(plot, scaleF, center=None):
dataCenterPos = plot.pixelToData(cx, cy)
assert dataCenterPos is not None
- xMin, xMax = scale1DRange(xMin, xMax, dataCenterPos[0], scaleF,
- plot.getXAxis()._isLogarithmic())
+ if enabled.xaxis:
+ xMin, xMax = scale1DRange(
+ xMin, xMax, dataCenterPos[0], scale, plot.getXAxis()._isLogarithmic()
+ )
- yMin, yMax = scale1DRange(yMin, yMax, dataCenterPos[1], scaleF,
- plot.getYAxis()._isLogarithmic())
+ if enabled.yaxis:
+ yMin, yMax = scale1DRange(
+ yMin, yMax, dataCenterPos[1], scale, plot.getYAxis()._isLogarithmic()
+ )
- dataPos = plot.pixelToData(cx, cy, axis="right")
- assert dataPos is not None
- y2Center = dataPos[1]
- y2Min, y2Max = plot.getYAxis(axis="right").getLimits()
- y2Min, y2Max = scale1DRange(y2Min, y2Max, y2Center, scaleF,
- plot.getYAxis()._isLogarithmic())
+ if enabled.y2axis:
+ dataPos = plot.pixelToData(cx, cy, axis="right")
+ assert dataPos is not None
+ y2Center = dataPos[1]
+ y2Min, y2Max = scale1DRange(
+ y2Min, y2Max, y2Center, scale, plot.getYAxis()._isLogarithmic()
+ )
plot.setLimits(xMin, xMax, yMin, yMax, y2Min, y2Max)
@@ -166,15 +193,15 @@ def applyPan(min_, max_, panFactor, isLog10):
:return: New min and max value with pan applied.
:rtype: 2-tuple of float.
"""
- if isLog10 and min_ > 0.:
+ if isLog10 and min_ > 0.0:
# Negative range and log scale can happen with matplotlib
logMin, logMax = math.log10(min_), math.log10(max_)
logOffset = panFactor * (logMax - logMin)
- newMin = pow(10., logMin + logOffset)
- newMax = pow(10., logMax + logOffset)
+ newMin = pow(10.0, logMin + logOffset)
+ newMax = pow(10.0, logMax + logOffset)
# Takes care of out-of-range values
- if newMin > 0. and newMax < float('inf'):
+ if newMin > 0.0 and newMax < float("inf"):
min_, max_ = newMin, newMax
else:
@@ -182,13 +209,14 @@ def applyPan(min_, max_, panFactor, isLog10):
newMin, newMax = min_ + offset, max_ + offset
# Takes care of out-of-range values
- if newMin > - float('inf') and newMax < float('inf'):
+ if newMin > -float("inf") and newMax < float("inf"):
min_, max_ = newMin, newMax
return min_, max_
class _Unset(object):
"""To be able to have distinction between None and unset"""
+
pass
@@ -203,10 +231,17 @@ class ViewConstraints(object):
self._minRange = [None, None]
self._maxRange = [None, None]
- def update(self, xMin=_Unset, xMax=_Unset,
- yMin=_Unset, yMax=_Unset,
- minXRange=_Unset, maxXRange=_Unset,
- minYRange=_Unset, maxYRange=_Unset):
+ def update(
+ self,
+ xMin=_Unset,
+ xMax=_Unset,
+ yMin=_Unset,
+ yMax=_Unset,
+ minXRange=_Unset,
+ maxXRange=_Unset,
+ minYRange=_Unset,
+ maxYRange=_Unset,
+ ):
"""
Update the constraints managed by the object
@@ -238,7 +273,6 @@ class ViewConstraints(object):
maxPos = [xMax, yMax]
for axis in range(2):
-
value = minPos[axis]
if value is not _Unset and value != self._min[axis]:
self._min[axis] = value
@@ -262,7 +296,11 @@ class ViewConstraints(object):
# Sanity checks
for axis in range(2):
- if self._maxRange[axis] is not None and self._min[axis] is not None and self._max[axis] is not None:
+ if (
+ self._maxRange[axis] is not None
+ and self._min[axis] is not None
+ and self._max[axis] is not None
+ ):
# max range cannot be larger than bounds
diff = self._max[axis] - self._min[axis]
self._maxRange[axis] = min(self._maxRange[axis], diff)
@@ -298,8 +336,12 @@ class ViewConstraints(object):
viewRange[axis][1] += delta * 0.5
# clamp min and max positions
- outMin = self._min[axis] is not None and viewRange[axis][0] < self._min[axis]
- outMax = self._max[axis] is not None and viewRange[axis][1] > self._max[axis]
+ outMin = (
+ self._min[axis] is not None and viewRange[axis][0] < self._min[axis]
+ )
+ outMax = (
+ self._max[axis] is not None and viewRange[axis][1] > self._max[axis]
+ )
if outMin and outMax:
if allow_scaling:
diff --git a/src/silx/gui/plot/_utils/test/test_dtime_ticklayout.py b/src/silx/gui/plot/_utils/test/test_dtime_ticklayout.py
index 87c0742..adcb9c9 100644
--- a/src/silx/gui/plot/_utils/test/test_dtime_ticklayout.py
+++ b/src/silx/gui/plot/_utils/test/test_dtime_ticklayout.py
@@ -57,7 +57,6 @@ def testNoCrash():
value = 100e-6 # Start at 100 micro sec range.
while value <= 200 * SECONDS_PER_YEAR:
-
d2 = d1 + dt.timedelta(microseconds=value * 1e6) # end date range
for numTicks in range(2, 12):
diff --git a/src/silx/gui/plot/_utils/test/test_ticklayout.py b/src/silx/gui/plot/_utils/test/test_ticklayout.py
index 8388c7e..1413563 100644
--- a/src/silx/gui/plot/_utils/test/test_ticklayout.py
+++ b/src/silx/gui/plot/_utils/test/test_ticklayout.py
@@ -27,7 +27,6 @@ __license__ = "MIT"
__date__ = "17/01/2018"
-import unittest
import numpy
from silx.utils.testutils import ParametricTestCase
@@ -41,10 +40,10 @@ class TestTickLayout(ParametricTestCase):
def testTicks(self):
"""Test of :func:`ticks`"""
tests = { # (vmin, vmax): ref_ticks
- (1., 1.): (1.,),
+ (1.0, 1.0): (1.0,),
(0.5, 10.5): (2.0, 4.0, 6.0, 8.0, 10.0),
- (0.001, 0.005): (0.001, 0.002, 0.003, 0.004, 0.005)
- }
+ (0.001, 0.005): (0.001, 0.002, 0.003, 0.004, 0.005),
+ }
for (vmin, vmax), ref_ticks in tests.items():
with self.subTest(vmin=vmin, vmax=vmax):
@@ -55,9 +54,9 @@ class TestTickLayout(ParametricTestCase):
"""Minimalistic tests of :func:`niceNumbers`"""
tests = { # (vmin, vmax): ref_ticks
(0.5, 10.5): (0.0, 12.0, 2.0, 0),
- (10000., 10000.5): (10000.0, 10000.5, 0.1, 1),
- (0.001, 0.005): (0.001, 0.005, 0.001, 3)
- }
+ (10000.0, 10000.5): (10000.0, 10000.5, 0.1, 1),
+ (0.001, 0.005): (0.001, 0.005, 0.001, 3),
+ }
for (vmin, vmax), ref_ticks in tests.items():
with self.subTest(vmin=vmin, vmax=vmax):
@@ -67,9 +66,9 @@ class TestTickLayout(ParametricTestCase):
def testNiceNumbersLog(self):
"""Minimalistic tests of :func:`niceNumbersForLog10`"""
tests = { # (log10(min), log10(max): ref_ticks
- (0., 3.): (0, 3, 1, 0),
- (-3., 3): (-3, 3, 1, 0),
- (-32., 0.): (-36, 0, 6, 0)
+ (0.0, 3.0): (0, 3, 1, 0),
+ (-3.0, 3): (-3, 3, 1, 0),
+ (-32.0, 0.0): (-36, 0, 6, 0),
}
for (vmin, vmax), ref_ticks in tests.items():
diff --git a/src/silx/gui/plot/_utils/ticklayout.py b/src/silx/gui/plot/_utils/ticklayout.py
index 4266be0..3678270 100644
--- a/src/silx/gui/plot/_utils/ticklayout.py
+++ b/src/silx/gui/plot/_utils/ticklayout.py
@@ -33,6 +33,7 @@ import math
# utils #######################################################################
+
def numberOfDigits(tickSpacing):
"""Returns the number of digits to display for text label.
@@ -76,7 +77,7 @@ def numberOfDigits(tickSpacing):
def niceNumGeneric(value, niceFractions=None, isRound=False):
- """ A more generic implementation of the _niceNum function
+ """A more generic implementation of the _niceNum function
Allows the user to specify the fractions instead of using a hardcoded
list of [1, 2, 5, 10.0].
@@ -85,15 +86,15 @@ def niceNumGeneric(value, niceFractions=None, isRound=False):
return value
if niceFractions is None: # Use default values
- niceFractions = 1., 2., 5., 10.
- roundFractions = (1.5, 3., 7., 10.) if isRound else niceFractions
+ niceFractions = 1.0, 2.0, 5.0, 10.0
+ roundFractions = (1.5, 3.0, 7.0, 10.0) if isRound else niceFractions
else:
roundFractions = list(niceFractions)
if isRound:
# Take the average with the next element. The last remains the same.
for i in range(len(roundFractions) - 1):
- roundFractions[i] = (niceFractions[i] + niceFractions[i+1]) / 2
+ roundFractions[i] = (niceFractions[i] + niceFractions[i + 1]) / 2
highest = niceFractions[-1]
value = float(value)
@@ -133,7 +134,7 @@ def niceNumbers(vMin, vMax, nTicks=5):
def _frange(start, stop, step):
"""range for float (including stop)."""
- assert step >= 0.
+ assert step >= 0.0
while start <= stop:
yield start
start += step
@@ -166,7 +167,7 @@ def ticks(vMin, vMax, nbTicks=5):
nfrac = numberOfDigits(vMax - vMin)
# Generate labels
- format_ = '%g' if nfrac == 0 else '%.{}f'.format(nfrac)
+ format_ = "%g" if nfrac == 0 else "%.{}f".format(nfrac)
labels = [format_ % tick for tick in positions]
return positions, labels
@@ -194,6 +195,7 @@ def niceNumbersAdaptative(vMin, vMax, axisLength, tickDensity):
# Nice Numbers for log scale ##################################################
+
def niceNumbersForLog10(minLog, maxLog, nTicks=5):
"""Return tick positions for logarithmic scale
@@ -209,7 +211,7 @@ def niceNumbersForLog10(minLog, maxLog, nTicks=5):
rangelog = graphmaxlog - graphminlog
if rangelog <= nTicks:
- spacing = 1.
+ spacing = 1.0
else:
spacing = math.floor(rangelog / nTicks)
diff --git a/src/silx/gui/plot/actions/PlotAction.py b/src/silx/gui/plot/actions/PlotAction.py
index de041dc..9341bdd 100644
--- a/src/silx/gui/plot/actions/PlotAction.py
+++ b/src/silx/gui/plot/actions/PlotAction.py
@@ -31,26 +31,36 @@ __license__ = "MIT"
__date__ = "03/01/2018"
+from typing import Callable, Optional, Union
import weakref
from silx.gui import icons
from silx.gui import qt
+from silx.gui.plot import PlotWidget
class PlotAction(qt.QAction):
"""Base class for QAction that operates on a PlotWidget.
:param plot: :class:`.PlotWidget` instance on which to operate.
- :param icon: QIcon or str name of icon to use
- :param str text: The name of this action to be used for menu label
- :param str tooltip: The text of the tooltip
+ :param icon: QIcon or name of icon to use
+ :param text: The name of this action to be used for menu label
+ :param tooltip: The text of the tooltip
:param triggered: The callback to connect to the action's triggered
- signal or None for no callback.
- :param bool checkable: True for checkable action, False otherwise (default)
+ signal. None for no callback (default)
+ :param checkable: True for checkable action, False otherwise (default)
:param parent: See :class:`QAction`.
"""
- def __init__(self, plot, icon, text, tooltip=None,
- triggered=None, checkable=False, parent=None):
+ def __init__(
+ self,
+ plot: PlotWidget,
+ icon: Union[str, qt.QIcon],
+ text: str,
+ tooltip: Optional[str] = None,
+ triggered: Optional[Callable] = None,
+ checkable: bool = False,
+ parent: Optional[qt.QObject] = None,
+ ):
assert plot is not None
self._plotRef = weakref.ref(plot)
diff --git a/src/silx/gui/plot/actions/PlotToolAction.py b/src/silx/gui/plot/actions/PlotToolAction.py
index 8c3b3c2..479d7c2 100644
--- a/src/silx/gui/plot/actions/PlotToolAction.py
+++ b/src/silx/gui/plot/actions/PlotToolAction.py
@@ -41,16 +41,26 @@ class PlotToolAction(PlotAction):
"""Base class for QAction that maintain a tool window operating on a
PlotWidget."""
- def __init__(self, plot, icon, text, tooltip=None,
- triggered=None, checkable=False, parent=None):
- PlotAction.__init__(self,
- plot=plot,
- icon=icon,
- text=text,
- tooltip=tooltip,
- triggered=self._triggered,
- parent=parent,
- checkable=True)
+ def __init__(
+ self,
+ plot,
+ icon,
+ text,
+ tooltip=None,
+ triggered=None,
+ checkable=False,
+ parent=None,
+ ):
+ PlotAction.__init__(
+ self,
+ plot=plot,
+ icon=icon,
+ text=text,
+ tooltip=tooltip,
+ triggered=self._triggered,
+ parent=parent,
+ checkable=True,
+ )
self._previousGeometry = None
self._toolWindow = None
diff --git a/src/silx/gui/plot/actions/control.py b/src/silx/gui/plot/actions/control.py
index e75048a..c21d235 100755
--- a/src/silx/gui/plot/actions/control.py
+++ b/src/silx/gui/plot/actions/control.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2004-2019 European Synchrotron Radiation Facility
+# Copyright (c) 2004-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -55,6 +55,7 @@ from silx.gui.plot import items
from silx.gui.plot._utils import applyZoomToPlot as _applyZoomToPlot
from silx.gui import qt
from silx.gui import icons
+from silx.utils.deprecation import deprecated
_logger = logging.getLogger(__name__)
@@ -68,10 +69,14 @@ class ResetZoomAction(PlotAction):
def __init__(self, plot, parent=None):
super(ResetZoomAction, self).__init__(
- plot, icon='zoom-original', text='Reset Zoom',
- tooltip='Auto-scale the graph',
+ plot,
+ icon="zoom-original",
+ text="Reset Zoom",
+ tooltip="Auto-scale the graph",
triggered=self._actionTriggered,
- checkable=False, parent=parent)
+ checkable=False,
+ parent=parent,
+ )
self._autoscaleChanged(True)
plot.getXAxis().sigAutoScaleChanged.connect(self._autoscaleChanged)
plot.getYAxis().sigAutoScaleChanged.connect(self._autoscaleChanged)
@@ -82,13 +87,13 @@ class ResetZoomAction(PlotAction):
self.setEnabled(xAxis.isAutoScale() or yAxis.isAutoScale())
if xAxis.isAutoScale() and yAxis.isAutoScale():
- tooltip = 'Auto-scale the graph'
+ tooltip = "Auto-scale the graph"
elif xAxis.isAutoScale(): # And not Y axis
- tooltip = 'Auto-scale the x-axis of the graph only'
+ tooltip = "Auto-scale the x-axis of the graph only"
elif yAxis.isAutoScale(): # And not X axis
- tooltip = 'Auto-scale the y-axis of the graph only'
+ tooltip = "Auto-scale the y-axis of the graph only"
else: # no axis in autoscale
- tooltip = 'Auto-scale the graph'
+ tooltip = "Auto-scale the graph"
self.setToolTip(tooltip)
def _actionTriggered(self, checked=False):
@@ -104,10 +109,14 @@ class ZoomBackAction(PlotAction):
def __init__(self, plot, parent=None):
super(ZoomBackAction, self).__init__(
- plot, icon='zoom-back', text='Zoom Back',
- tooltip='Zoom back the plot',
+ plot,
+ icon="zoom-back",
+ text="Zoom Back",
+ tooltip="Zoom back the plot",
triggered=self._actionTriggered,
- checkable=False, parent=parent)
+ checkable=False,
+ parent=parent,
+ )
self.setShortcutContext(qt.Qt.WidgetShortcut)
def _actionTriggered(self, checked=False):
@@ -123,10 +132,14 @@ class ZoomInAction(PlotAction):
def __init__(self, plot, parent=None):
super(ZoomInAction, self).__init__(
- plot, icon='zoom-in', text='Zoom In',
- tooltip='Zoom in the plot',
+ plot,
+ icon="zoom-in",
+ text="Zoom In",
+ tooltip="Zoom in the plot",
triggered=self._actionTriggered,
- checkable=False, parent=parent)
+ checkable=False,
+ parent=parent,
+ )
self.setShortcut(qt.QKeySequence.ZoomIn)
self.setShortcutContext(qt.Qt.WidgetShortcut)
@@ -143,15 +156,19 @@ class ZoomOutAction(PlotAction):
def __init__(self, plot, parent=None):
super(ZoomOutAction, self).__init__(
- plot, icon='zoom-out', text='Zoom Out',
- tooltip='Zoom out the plot',
+ plot,
+ icon="zoom-out",
+ text="Zoom Out",
+ tooltip="Zoom out the plot",
triggered=self._actionTriggered,
- checkable=False, parent=parent)
+ checkable=False,
+ parent=parent,
+ )
self.setShortcut(qt.QKeySequence.ZoomOut)
self.setShortcutContext(qt.Qt.WidgetShortcut)
def _actionTriggered(self, checked=False):
- _applyZoomToPlot(self.plot, 1. / 1.1)
+ _applyZoomToPlot(self.plot, 1.0 / 1.1)
class XAxisAutoScaleAction(PlotAction):
@@ -163,11 +180,15 @@ class XAxisAutoScaleAction(PlotAction):
def __init__(self, plot, parent=None):
super(XAxisAutoScaleAction, self).__init__(
- plot, icon='plot-xauto', text='X Autoscale',
- tooltip='Enable x-axis auto-scale when checked.\n'
- 'If unchecked, x-axis does not change when reseting zoom.',
+ plot,
+ icon="plot-xauto",
+ text="X Autoscale",
+ tooltip="Enable x-axis auto-scale when checked.\n"
+ "If unchecked, x-axis does not change when reseting zoom.",
triggered=self._actionTriggered,
- checkable=True, parent=parent)
+ checkable=True,
+ parent=parent,
+ )
self.setChecked(plot.getXAxis().isAutoScale())
plot.getXAxis().sigAutoScaleChanged.connect(self.setChecked)
@@ -186,11 +207,15 @@ class YAxisAutoScaleAction(PlotAction):
def __init__(self, plot, parent=None):
super(YAxisAutoScaleAction, self).__init__(
- plot, icon='plot-yauto', text='Y Autoscale',
- tooltip='Enable y-axis auto-scale when checked.\n'
- 'If unchecked, y-axis does not change when reseting zoom.',
+ plot,
+ icon="plot-yauto",
+ text="Y Autoscale",
+ tooltip="Enable y-axis auto-scale when checked.\n"
+ "If unchecked, y-axis does not change when reseting zoom.",
triggered=self._actionTriggered,
- checkable=True, parent=parent)
+ checkable=True,
+ parent=parent,
+ )
self.setChecked(plot.getYAxis().isAutoScale())
plot.getYAxis().sigAutoScaleChanged.connect(self.setChecked)
@@ -209,10 +234,14 @@ class XAxisLogarithmicAction(PlotAction):
def __init__(self, plot, parent=None):
super(XAxisLogarithmicAction, self).__init__(
- plot, icon='plot-xlog', text='X Log. scale',
- tooltip='Logarithmic x-axis when checked',
+ plot,
+ icon="plot-xlog",
+ text="X Log. scale",
+ tooltip="Logarithmic x-axis when checked",
triggered=self._actionTriggered,
- checkable=True, parent=parent)
+ checkable=True,
+ parent=parent,
+ )
self.axis = plot.getXAxis()
self.setChecked(self.axis.getScale() == self.axis.LOGARITHMIC)
self.axis.sigScaleChanged.connect(self._setCheckedIfLogScale)
@@ -234,10 +263,14 @@ class YAxisLogarithmicAction(PlotAction):
def __init__(self, plot, parent=None):
super(YAxisLogarithmicAction, self).__init__(
- plot, icon='plot-ylog', text='Y Log. scale',
- tooltip='Logarithmic y-axis when checked',
+ plot,
+ icon="plot-ylog",
+ text="Y Log. scale",
+ tooltip="Logarithmic y-axis when checked",
triggered=self._actionTriggered,
- checkable=True, parent=parent)
+ checkable=True,
+ parent=parent,
+ )
self.axis = plot.getYAxis()
self.setChecked(self.axis.getScale() == self.axis.LOGARITHMIC)
self.axis.sigScaleChanged.connect(self._setCheckedIfLogScale)
@@ -259,21 +292,25 @@ class GridAction(PlotAction):
:param parent: See :class:`QAction`
"""
- def __init__(self, plot, gridMode='both', parent=None):
- assert gridMode in ('both', 'major')
+ def __init__(self, plot, gridMode="both", parent=None):
+ assert gridMode in ("both", "major")
self._gridMode = gridMode
super(GridAction, self).__init__(
- plot, icon='plot-grid', text='Grid',
- tooltip='Toggle grid (on/off)',
+ plot,
+ icon="plot-grid",
+ text="Grid",
+ tooltip="Toggle grid (on/off)",
triggered=self._actionTriggered,
- checkable=True, parent=parent)
+ checkable=True,
+ parent=parent,
+ )
self.setChecked(plot.getGraphGrid() is not None)
plot.sigSetGraphGrid.connect(self._gridChanged)
def _gridChanged(self, which):
"""Slot listening for PlotWidget grid mode change."""
- self.setChecked(which != 'None')
+ self.setChecked(which != "None")
def _actionTriggered(self, checked=False):
self.plot.setGraphGrid(self._gridMode if checked else None)
@@ -291,14 +328,17 @@ class CurveStyleAction(PlotAction):
def __init__(self, plot, parent=None):
super(CurveStyleAction, self).__init__(
- plot, icon='plot-toggle-points', text='Curve style',
- tooltip='Change curve line and markers style',
+ plot,
+ icon="plot-toggle-points",
+ text="Curve style",
+ tooltip="Change curve line and markers style",
triggered=self._actionTriggered,
- checkable=False, parent=parent)
+ checkable=False,
+ parent=parent,
+ )
def _actionTriggered(self, checked=False):
- currentState = (self.plot.isDefaultPlotLines(),
- self.plot.isDefaultPlotPoints())
+ currentState = (self.plot.isDefaultPlotLines(), self.plot.isDefaultPlotPoints())
if currentState == (False, False):
newState = True, False
@@ -323,21 +363,39 @@ class ColormapAction(PlotAction):
def __init__(self, plot, parent=None):
self._dialog = None # To store an instance of ColormapDialog
super(ColormapAction, self).__init__(
- plot, icon='colormap', text='Colormap',
+ plot,
+ icon="colormap",
+ text="Colormap",
tooltip="Change colormap",
triggered=self._actionTriggered,
- checkable=True, parent=parent)
+ checkable=True,
+ parent=parent,
+ )
self.plot.sigActiveImageChanged.connect(self._updateColormap)
self.plot.sigActiveScatterChanged.connect(self._updateColormap)
- def setColorDialog(self, colorDialog):
- """Set a specific color dialog instead of using the default dialog."""
- assert(colorDialog is not None)
- assert(self._dialog is None)
- self._dialog = colorDialog
- self._dialog.visibleChanged.connect(self._dialogVisibleChanged)
+ def setColormapDialog(self, dialog):
+ """Set a specific colormap dialog instead of using the default one."""
+ assert dialog is not None
+ if self._dialog is not None:
+ self._dialog.visibleChanged.disconnect(self._dialogVisibleChanged)
+
+ self._dialog = dialog
+ self._dialog.visibleChanged.connect(
+ self._dialogVisibleChanged, qt.Qt.UniqueConnection
+ )
self.setChecked(self._dialog.isVisible())
+ @deprecated(replacement="setColormapDialog", since_version="2.0")
+ def setColorDialog(self, colorDialog):
+ self.setColormapDialog(colorDialog)
+
+ def getColormapDialog(self):
+ if self._dialog is None:
+ self._dialog = self._createDialog(self.plot)
+ self._dialog.visibleChanged.connect(self._dialogVisibleChanged)
+ return self._dialog
+
@staticmethod
def _createDialog(parent):
"""Create the dialog if not already existing
@@ -346,22 +404,20 @@ class ColormapAction(PlotAction):
:rtype: ColormapDialog
"""
from silx.gui.dialog.ColormapDialog import ColormapDialog
+
dialog = ColormapDialog(parent=parent)
dialog.setModal(False)
return dialog
def _actionTriggered(self, checked=False):
"""Create a cmap dialog and update active image and default cmap."""
- if self._dialog is None:
- self._dialog = self._createDialog(self.plot)
- self._dialog.visibleChanged.connect(self._dialogVisibleChanged)
-
+ dialog = self.getColormapDialog()
# Run the dialog listening to colormap change
if checked is True:
self._updateColormap()
- self._dialog.show()
+ dialog.show()
else:
- self._dialog.hide()
+ dialog.hide()
def _dialogVisibleChanged(self, isVisible):
self.setChecked(isVisible)
@@ -380,7 +436,7 @@ class ColormapAction(PlotAction):
else:
# No active image or active image is RGBA,
# Check for active scatter plot
- scatter = self.plot._getActiveItem(kind='scatter')
+ scatter = self.plot.getActiveScatter()
if scatter is not None:
colormap = scatter.getColormap()
self._dialog.setItem(scatter)
@@ -405,10 +461,14 @@ class ColorBarAction(PlotAction):
def __init__(self, plot, parent=None):
self._dialog = None # To store an instance of ColorBar
super(ColorBarAction, self).__init__(
- plot, icon='colorbar', text='Colorbar',
+ plot,
+ icon="colorbar",
+ text="Colorbar",
tooltip="Show/Hide the colorbar",
triggered=self._actionTriggered,
- checkable=True, parent=parent)
+ checkable=True,
+ parent=parent,
+ )
colorBarWidget = self.plot.getColorBarWidget()
old = self.blockSignals(True)
self.setChecked(colorBarWidget.isVisibleTo(self.plot))
@@ -439,23 +499,24 @@ class KeepAspectRatioAction(PlotAction):
def __init__(self, plot, parent=None):
# Uses two images for checked/unchecked states
self._states = {
- False: (icons.getQIcon('shape-circle-solid'),
- "Keep data aspect ratio"),
- True: (icons.getQIcon('shape-ellipse-solid'),
- "Do no keep data aspect ratio")
+ False: (icons.getQIcon("shape-circle-solid"), "Keep data aspect ratio"),
+ True: (
+ icons.getQIcon("shape-ellipse-solid"),
+ "Do no keep data aspect ratio",
+ ),
}
icon, tooltip = self._states[plot.isKeepDataAspectRatio()]
super(KeepAspectRatioAction, self).__init__(
plot,
icon=icon,
- text='Toggle keep aspect ratio',
+ text="Toggle keep aspect ratio",
tooltip=tooltip,
triggered=self._actionTriggered,
checkable=False,
- parent=parent)
- plot.sigSetKeepDataAspectRatio.connect(
- self._keepDataAspectRatioChanged)
+ parent=parent,
+ )
+ plot.sigSetKeepDataAspectRatio.connect(self._keepDataAspectRatioChanged)
def _keepDataAspectRatioChanged(self, aspectRatio):
"""Handle Plot set keep aspect ratio signal"""
@@ -478,21 +539,20 @@ class YAxisInvertedAction(PlotAction):
def __init__(self, plot, parent=None):
# Uses two images for checked/unchecked states
self._states = {
- False: (icons.getQIcon('plot-ydown'),
- "Orient Y axis downward"),
- True: (icons.getQIcon('plot-yup'),
- "Orient Y axis upward"),
+ False: (icons.getQIcon("plot-ydown"), "Orient Y axis downward"),
+ True: (icons.getQIcon("plot-yup"), "Orient Y axis upward"),
}
icon, tooltip = self._states[plot.getYAxis().isInverted()]
super(YAxisInvertedAction, self).__init__(
plot,
icon=icon,
- text='Invert Y Axis',
+ text="Invert Y Axis",
tooltip=tooltip,
triggered=self._actionTriggered,
checkable=False,
- parent=parent)
+ parent=parent,
+ )
plot.getYAxis().sigInvertedChanged.connect(self._yAxisInvertedChanged)
def _yAxisInvertedChanged(self, inverted):
@@ -517,8 +577,7 @@ class CrosshairAction(PlotAction):
:param parent: See :class:`QAction`
"""
- def __init__(self, plot, color='black', linewidth=1, linestyle='-',
- parent=None):
+ def __init__(self, plot, color="black", linewidth=1, linestyle="-", parent=None):
self.color = color
"""Color used to draw the crosshair (str)."""
@@ -529,18 +588,24 @@ class CrosshairAction(PlotAction):
"""Style of line of the cursor (str)."""
super(CrosshairAction, self).__init__(
- plot, icon='crosshair', text='Crosshair Cursor',
- tooltip='Enable crosshair cursor when checked',
+ plot,
+ icon="crosshair",
+ text="Crosshair Cursor",
+ tooltip="Enable crosshair cursor when checked",
triggered=self._actionTriggered,
- checkable=True, parent=parent)
+ checkable=True,
+ parent=parent,
+ )
self.setChecked(plot.getGraphCursor() is not None)
plot.sigSetGraphCursor.connect(self.setChecked)
def _actionTriggered(self, checked=False):
- self.plot.setGraphCursor(checked,
- color=self.color,
- linestyle=self.linestyle,
- linewidth=self.linewidth)
+ self.plot.setGraphCursor(
+ checked,
+ color=self.color,
+ linestyle=self.linestyle,
+ linewidth=self.linewidth,
+ )
class PanWithArrowKeysAction(PlotAction):
@@ -551,12 +616,15 @@ class PanWithArrowKeysAction(PlotAction):
"""
def __init__(self, plot, parent=None):
-
super(PanWithArrowKeysAction, self).__init__(
- plot, icon='arrow-keys', text='Pan with arrow keys',
- tooltip='Enable pan with arrow keys when checked',
+ plot,
+ icon="arrow-keys",
+ text="Pan with arrow keys",
+ tooltip="Enable pan with arrow keys when checked",
triggered=self._actionTriggered,
- checkable=True, parent=parent)
+ checkable=True,
+ parent=parent,
+ )
self.setChecked(plot.isPanWithArrowKeys())
plot.sigSetPanWithArrowKeys.connect(self.setChecked)
@@ -572,15 +640,17 @@ class ShowAxisAction(PlotAction):
"""
def __init__(self, plot, parent=None):
- tooltip = 'Show plot axis when checked, otherwise hide them'
- PlotAction.__init__(self,
- plot,
- icon='axis',
- text='show axis',
- tooltip=tooltip,
- triggered=self._actionTriggered,
- checkable=True,
- parent=parent)
+ tooltip = "Show plot axis when checked, otherwise hide them"
+ PlotAction.__init__(
+ self,
+ plot,
+ icon="axis",
+ text="show axis",
+ tooltip=tooltip,
+ triggered=self._actionTriggered,
+ checkable=True,
+ parent=parent,
+ )
self.setChecked(self.plot.isAxesDisplayed())
plot._sigAxesVisibilityChanged.connect(self.setChecked)
@@ -597,15 +667,17 @@ class ClosePolygonInteractionAction(PlotAction):
"""
def __init__(self, plot, parent=None):
- tooltip = 'Close the current polygon drawn'
- PlotAction.__init__(self,
- plot,
- icon='add-shape-polygon',
- text='Close the polygon',
- tooltip=tooltip,
- triggered=self._actionTriggered,
- checkable=True,
- parent=parent)
+ tooltip = "Close the current polygon drawn"
+ PlotAction.__init__(
+ self,
+ plot,
+ icon="add-shape-polygon",
+ text="Close the polygon",
+ tooltip=tooltip,
+ triggered=self._actionTriggered,
+ checkable=True,
+ parent=parent,
+ )
self.plot.sigInteractiveModeChanged.connect(self._modeChanged)
self._modeChanged(None)
@@ -615,7 +687,7 @@ class ClosePolygonInteractionAction(PlotAction):
self.setEnabled(enabled)
def _actionTriggered(self, checked=False):
- self.plot._eventHandler.validate()
+ self.plot.interaction()._validate()
class OpenGLAction(PlotAction):
@@ -630,29 +702,32 @@ class OpenGLAction(PlotAction):
def __init__(self, plot, parent=None):
# Uses two images for checked/unchecked states
self._states = {
- "opengl": (icons.getQIcon('backend-opengl'),
- "OpenGL rendering (fast)\nClick to disable OpenGL"),
- "matplotlib": (icons.getQIcon('backend-opengl'),
- "Matplotlib rendering (safe)\nClick to enable OpenGL"),
- "unknown": (icons.getQIcon('backend-opengl'),
- "Custom rendering")
+ "opengl": (
+ icons.getQIcon("backend-opengl"),
+ "OpenGL rendering (fast)\nClick to disable OpenGL",
+ ),
+ "matplotlib": (
+ icons.getQIcon("backend-opengl"),
+ "Matplotlib rendering (safe)\nClick to enable OpenGL",
+ ),
+ "unknown": (icons.getQIcon("backend-opengl"), "Custom rendering"),
}
name = self._getBackendName(plot)
- self.__state = name
icon, tooltip = self._states[name]
super(OpenGLAction, self).__init__(
plot,
icon=icon,
- text='Enable/disable OpenGL rendering',
+ text="Enable/disable OpenGL rendering",
tooltip=tooltip,
triggered=self._actionTriggered,
checkable=True,
- parent=parent)
+ parent=parent,
+ )
+ plot.sigBackendChanged.connect(self._backendUpdated)
def _backendUpdated(self):
name = self._getBackendName(self.plot)
- self.__state = name
icon, tooltip = self._states[name]
self.setIcon(icon)
self.setToolTip(tooltip)
@@ -671,21 +746,15 @@ class OpenGLAction(PlotAction):
def _actionTriggered(self, checked=False):
plot = self.plot
name = self._getBackendName(self.plot)
- if self.__state != name:
- # THere is no event to know the backend was updated
- # So here we check if there is a mismatch between the displayed state
- # and the real state of the widget
- self._backendUpdated()
- return
if name != "opengl":
from silx.gui.utils import glutils
+
result = glutils.isOpenGLAvailable()
if not result:
- qt.QMessageBox.critical(plot, "OpenGL rendering not available", result.error)
- # Uncheck if needed
- self._backendUpdated()
+ qt.QMessageBox.critical(
+ plot, "OpenGL rendering is not available", result.error
+ )
return
plot.setBackend("opengl")
else:
plot.setBackend("matplotlib")
- self._backendUpdated()
diff --git a/src/silx/gui/plot/actions/fit.py b/src/silx/gui/plot/actions/fit.py
index 3489f70..ae8835a 100644
--- a/src/silx/gui/plot/actions/fit.py
+++ b/src/silx/gui/plot/actions/fit.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+# Copyright (c) 2004-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -42,7 +42,6 @@ import numpy
from .PlotToolAction import PlotToolAction
from .. import items
-from ....utils.deprecation import deprecated
from silx.gui import qt
from silx.gui.plot.ItemsSelectionDialog import ItemsSelectionDialog
@@ -63,10 +62,8 @@ def _getUniqueCurveOrHistogram(plot):
return curve
visibleItems = [item for item in plot.getItems() if item.isVisible()]
- histograms = [item for item in visibleItems
- if isinstance(item, items.Histogram)]
- curves = [item for item in visibleItems
- if isinstance(item, items.Curve)]
+ histograms = [item for item in visibleItems if isinstance(item, items.Histogram)]
+ curves = [item for item in visibleItems if isinstance(item, items.Curve)]
if len(histograms) == 1 and len(curves) == 0:
return histograms[0]
@@ -114,12 +111,11 @@ class _FitItemSelector(qt.QObject):
# disconnect from previous plot
previousPlotWidget = self.getPlotWidget()
if previousPlotWidget is not None:
- previousPlotWidget.sigItemAdded.disconnect(
- self.__plotWidgetUpdated)
- previousPlotWidget.sigItemRemoved.disconnect(
- self.__plotWidgetUpdated)
+ previousPlotWidget.sigItemAdded.disconnect(self.__plotWidgetUpdated)
+ previousPlotWidget.sigItemRemoved.disconnect(self.__plotWidgetUpdated)
previousPlotWidget.sigActiveCurveChanged.disconnect(
- self.__plotWidgetUpdated)
+ self.__plotWidgetUpdated
+ )
if plotWidget is None:
self.__plotWidgetRef = None
@@ -184,49 +180,15 @@ class FitAction(PlotToolAction):
self.__legend = None
super(FitAction, self).__init__(
- plot, icon='math-fit', text='Fit curve',
- tooltip='Open a fit dialog',
- parent=parent)
+ plot,
+ icon="math-fit",
+ text="Fit curve",
+ tooltip="Open a fit dialog",
+ parent=parent,
+ )
self.__fitItemSelector = _FitItemSelector()
- self.__fitItemSelector.sigCurrentItemChanged.connect(
- self._setFittedItem)
-
-
- @property
- @deprecated(replacement='getXRange()[0]', since_version='0.13.0')
- def xmin(self):
- return self.getXRange()[0]
-
- @property
- @deprecated(replacement='getXRange()[1]', since_version='0.13.0')
- def xmax(self):
- return self.getXRange()[1]
-
- @property
- @deprecated(replacement='getXData()', since_version='0.13.0')
- def x(self):
- return self.getXData()
-
- @property
- @deprecated(replacement='getYData()', since_version='0.13.0')
- def y(self):
- return self.getYData()
-
- @property
- @deprecated(since_version='0.13.0')
- def xlabel(self):
- return self.__curveParams.get('xlabel', None)
-
- @property
- @deprecated(since_version='0.13.0')
- def ylabel(self):
- return self.__curveParams.get('ylabel', None)
-
- @property
- @deprecated(since_version='0.13.0')
- def legend(self):
- return self.__legend
+ self.__fitItemSelector.sigCurrentItemChanged.connect(self._setFittedItem)
def _createToolWindow(self):
# import done here rather than at module level to avoid circular import
@@ -299,11 +261,10 @@ class FitAction(PlotToolAction):
else:
xmin, xmax = self.getXRange()
- fitWidget.setData(
- xdata, ydata, xmin=xmin, xmax=xmax)
+ fitWidget.setData(xdata, ydata, xmin=xmin, xmax=xmax)
fitWidget.setWindowTitle(
- "Fitting " + item.getName() +
- " on x range %f-%f" % (xmin, xmax))
+ "Fitting " + item.getName() + " on x range %f-%f" % (xmin, xmax)
+ )
# X Range management
@@ -397,12 +358,12 @@ class FitAction(PlotToolAction):
self.__updateFitWidget()
return
- axis = item.getYAxis() if isinstance(item, items.YAxisMixIn) else 'left'
+ axis = item.getYAxis() if isinstance(item, items.YAxisMixIn) else "left"
self.__curveParams = {
- 'yaxis': axis,
- 'xlabel': plot.getXAxis().getLabel(),
- 'ylabel': plot.getYAxis(axis).getLabel(),
- }
+ "yaxis": axis,
+ "xlabel": plot.getXAxis().getLabel(),
+ "ylabel": plot.getYAxis(axis).getLabel(),
+ }
self.__legend = item.getName()
if isinstance(item, items.Histogram):
@@ -415,7 +376,7 @@ class FitAction(PlotToolAction):
self.__x = item.getXData(copy=False)
self.__y = item.getYData(copy=False)
- self.__item = item
+ self.__item = item
self.__updateFitWidget()
def __setFittedItemAutoUpdateEnabled(self, enabled):
@@ -468,14 +429,13 @@ class FitAction(PlotToolAction):
return
y_fit = fit_widget.fitmanager.gendata()
if fit_curve is None:
- self.plot.addCurve(x_fit, y_fit,
- fit_legend,
- resetzoom=False,
- **self.__curveParams)
+ self.plot.addCurve(
+ x_fit, y_fit, fit_legend, resetzoom=False, **self.__curveParams
+ )
else:
fit_curve.setData(x_fit, y_fit)
fit_curve.setVisible(True)
- fit_curve.setYAxis(self.__curveParams.get('yaxis', 'left'))
+ fit_curve.setYAxis(self.__curveParams.get("yaxis", "left"))
if ddict["event"] in ["FitStarted", "FitFailed"]:
if fit_curve is not None:
diff --git a/src/silx/gui/plot/actions/histogram.py b/src/silx/gui/plot/actions/histogram.py
index 448dd55..39c669b 100644
--- a/src/silx/gui/plot/actions/histogram.py
+++ b/src/silx/gui/plot/actions/histogram.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+# Copyright (c) 2004-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -31,7 +31,7 @@ The following QAction are available:
"""
__authors__ = ["V.A. Sole", "T. Vincent", "P. Knobel"]
-__date__ = "01/12/2020"
+__date__ = "07/11/2023"
__license__ = "MIT"
from typing import Optional, Tuple
@@ -47,7 +47,6 @@ from silx.gui import qt
from silx.gui.plot import items
from silx.gui.widgets.ElidedLabel import ElidedLabel
from silx.gui.widgets.RangeSlider import RangeSlider
-from silx.utils.deprecation import deprecated
_logger = logging.getLogger(__name__)
@@ -62,7 +61,7 @@ class _ElidedLabel(ElidedLabel):
def sizeHint(self):
hint = super().sizeHint()
nbchar = max(len(self.text()), 12)
- width = self.fontMetrics().boundingRect('#' * nbchar).width()
+ width = self.fontMetrics().boundingRect("#" * nbchar).width()
return qt.QSize(max(hint.width(), width), hint.height())
@@ -73,7 +72,7 @@ class _StatWidget(qt.QWidget):
:param name:
"""
- def __init__(self, parent=None, name: str=''):
+ def __init__(self, parent=None, name: str = ""):
super().__init__(parent)
layout = qt.QHBoxLayout(self)
layout.setContentsMargins(0, 0, 0, 0)
@@ -84,7 +83,8 @@ class _StatWidget(qt.QWidget):
self.__valueWidget = _ElidedLabel(parent=self)
self.__valueWidget.setText("-")
self.__valueWidget.setTextInteractionFlags(
- qt.Qt.TextSelectableByMouse | qt.Qt.TextSelectableByKeyboard)
+ qt.Qt.TextSelectableByMouse | qt.Qt.TextSelectableByKeyboard
+ )
layout.addWidget(self.__valueWidget)
def setValue(self, value: Optional[float]):
@@ -92,8 +92,7 @@ class _StatWidget(qt.QWidget):
:param value:
"""
- self.__valueWidget.setText(
- "-" if value is None else "{:.5g}".format(value))
+ self.__valueWidget.setText("-" if value is None else "{:.5g}".format(value))
class _IntEdit(qt.QLineEdit):
@@ -124,9 +123,7 @@ class _IntEdit(qt.QLineEdit):
font = self.font()
font.setStyle(qt.QFont.StyleItalic)
fontMetrics = qt.QFontMetrics(font)
- self.setMaximumWidth(
- fontMetrics.boundingRect('0' * (nbchar + 1)).width()
- )
+ self.setMaximumWidth(fontMetrics.boundingRect("0" * (nbchar + 1)).width())
self.setMaxLength(nbchar)
def __textEdited(self, _):
@@ -191,7 +188,7 @@ class _IntEdit(qt.QLineEdit):
self.setRange(min(value, bottom), max(value, top))
return numpy.clip(value, *self.getRange())
- def setDefaultValue(self, value: int, extend_range: bool=False):
+ def setDefaultValue(self, value: int, extend_range: bool = False):
"""Set default value when QLineEdit is empty
:param int value:
@@ -210,7 +207,7 @@ class _IntEdit(qt.QLineEdit):
except ValueError:
return None
- def setCurrentValue(self, value: int, extend_range: bool=False):
+ def setCurrentValue(self, value: int, extend_range: bool = False):
"""Set the currently displayed value
:param int value:
@@ -236,7 +233,7 @@ class HistogramWidget(qt.QWidget):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- self.setWindowTitle('Histogram')
+ self.setWindowTitle("Histogram")
self.__itemRef = None # weakref on the item to track
@@ -247,6 +244,7 @@ class HistogramWidget(qt.QWidget):
# Plot
# Lazy import to avoid circular dependencies
from silx.gui.plot.PlotWindow import Plot1D
+
self.__plot = Plot1D(self)
layout.addWidget(self.__plot)
@@ -266,16 +264,18 @@ class HistogramWidget(qt.QWidget):
controlsLayout.addWidget(qt.QLabel("N. bins:"))
self.__nbinsLineEdit = _IntEdit(self)
self.__nbinsLineEdit.setRange(2, 9999)
- self.__nbinsLineEdit.sigValueChanged.connect(
- self.__updateHistogramFromControls)
+ self.__nbinsLineEdit.sigValueChanged.connect(self.__updateHistogramFromControls)
controlsLayout.addWidget(self.__nbinsLineEdit)
self.__rangeLabel = qt.QLabel("Range:")
controlsLayout.addWidget(self.__rangeLabel)
self.__rangeSlider = RangeSlider(parent=self)
- self.__rangeSlider.sigValueChanged.connect(
- self.__updateHistogramFromControls)
+ self.__rangeSlider.sigValueChanged.connect(self.__updateHistogramFromControls)
self.__rangeSlider.sigValueChanged.connect(self.__rangeChanged)
controlsLayout.addWidget(self.__rangeSlider)
+ self.__weightCheckBox = qt.QCheckBox(self)
+ self.__weightCheckBox.setText("Use weights")
+ self.__weightCheckBox.clicked.connect(self.__weightChanged)
+ controlsLayout.addWidget(self.__weightCheckBox)
controlsLayout.addStretch(1)
# Stats display
@@ -286,7 +286,8 @@ class HistogramWidget(qt.QWidget):
self.__statsWidgets = dict(
(name, _StatWidget(parent=statsWidget, name=name))
- for name in ("min", "max", "mean", "std", "sum"))
+ for name in ("min", "max", "mean", "std", "sum")
+ )
for widget in self.__statsWidgets.values():
statsLayout.addWidget(widget)
@@ -336,8 +337,10 @@ class HistogramWidget(qt.QWidget):
hist = self.getHistogram(copy=False)
if hist is not None:
count, edges = hist
- if (len(count) == self.__nbinsLineEdit.getValue() and
- (edges[0], edges[-1]) == self.__rangeSlider.getValues()):
+ if (
+ len(count) == self.__nbinsLineEdit.getValue()
+ and (edges[0], edges[-1]) == self.__rangeSlider.getValues()
+ ):
return # Nothing has changed
self._updateFromItem()
@@ -348,6 +351,9 @@ class HistogramWidget(qt.QWidget):
self.__rangeSlider.setToolTip(tooltip)
self.__rangeLabel.setToolTip(tooltip)
+ def __weightChanged(self, value):
+ self._updateFromItem()
+
def _updateFromItem(self):
"""Update histogram and stats from the item"""
item = self.getItem()
@@ -388,31 +394,39 @@ class HistogramWidget(qt.QWidget):
if xmin == 0:
range_ = -0.01, 0.01
else:
- range_ = sorted((xmin * .99, xmin * 1.01))
+ range_ = sorted((xmin * 0.99, xmin * 1.01))
else:
range_ = xmin, xmax
self.__rangeSlider.setRange(*range_)
self.__rangeSlider.setPositions(*previousPositions)
+ data = array.ravel().astype(numpy.float32)
histogram = Histogramnd(
- array.ravel().astype(numpy.float32),
+ data,
n_bins=max(2, self.__nbinsLineEdit.getValue()),
histo_range=self.__rangeSlider.getValues(),
+ weights=data,
)
if len(histogram.edges) != 1:
_logger.error("Error while computing the histogram")
self.reset()
return
- self.setHistogram(histogram.histo, histogram.edges[0])
+ if self.__weightCheckBox.isChecked():
+ self.setHistogram(histogram.weighted_histo, histogram.edges[0])
+ self.__plot.getYAxis().setLabel("Count * Value")
+ else:
+ self.setHistogram(histogram.histo, histogram.edges[0])
+ self.__plot.getYAxis().setLabel("Count")
self.resetZoom()
self.setStatistics(
min_=xmin,
max_=xmax,
mean=numpy.nanmean(array),
std=numpy.nanstd(array),
- sum_=numpy.nansum(array))
+ sum_=numpy.nansum(array),
+ )
def setHistogram(self, histogram, edges):
"""Set displayed histogram
@@ -422,20 +436,21 @@ class HistogramWidget(qt.QWidget):
"""
# Only useful if setHistogram is called directly
# TODO
- #nbins = len(histogram)
- #if nbins != self.__nbinsLineEdit.getDefaultValue():
+ # nbins = len(histogram)
+ # if nbins != self.__nbinsLineEdit.getDefaultValue():
# self.__nbinsLineEdit.setValue(nbins, extend_range=True)
- #self.__rangeSlider.setValues(edges[0], edges[-1])
+ # self.__rangeSlider.setValues(edges[0], edges[-1])
self.getPlotWidget().addHistogram(
histogram=histogram,
edges=edges,
- legend='histogram',
+ legend="histogram",
fill=True,
- color='#66aad7',
- resetzoom=False)
+ color="#66aad7",
+ resetzoom=False,
+ )
- def getHistogram(self, copy: bool=True):
+ def getHistogram(self, copy: bool = True):
"""Returns currently displayed histogram.
:param copy: True to get a copy,
@@ -443,24 +458,25 @@ class HistogramWidget(qt.QWidget):
:return: (histogram, edges) or None
"""
for item in self.getPlotWidget().getItems():
- if item.getName() == 'histogram':
- return (item.getValueData(copy=copy),
- item.getBinEdgesData(copy=copy))
+ if item.getName() == "histogram":
+ return (item.getValueData(copy=copy), item.getBinEdgesData(copy=copy))
else:
return None
- def setStatistics(self,
- min_: Optional[float] = None,
- max_: Optional[float] = None,
- mean: Optional[float] = None,
- std: Optional[float] = None,
- sum_: Optional[float] = None):
+ def setStatistics(
+ self,
+ min_: Optional[float] = None,
+ max_: Optional[float] = None,
+ mean: Optional[float] = None,
+ std: Optional[float] = None,
+ sum_: Optional[float] = None,
+ ):
"""Set displayed statistic indicators."""
- self.__statsWidgets['min'].setValue(min_)
- self.__statsWidgets['max'].setValue(max_)
- self.__statsWidgets['mean'].setValue(mean)
- self.__statsWidgets['std'].setValue(std)
- self.__statsWidgets['sum'].setValue(sum_)
+ self.__statsWidgets["min"].setValue(min_)
+ self.__statsWidgets["max"].setValue(max_)
+ self.__statsWidgets["mean"].setValue(mean)
+ self.__statsWidgets["std"].setValue(std)
+ self.__statsWidgets["sum"].setValue(sum_)
class PixelIntensitiesHistoAction(PlotToolAction):
@@ -471,12 +487,14 @@ class PixelIntensitiesHistoAction(PlotToolAction):
"""
def __init__(self, plot, parent=None):
- PlotToolAction.__init__(self,
- plot,
- icon='pixel-intensities',
- text='pixels intensity',
- tooltip='Compute image intensity distribution',
- parent=parent)
+ PlotToolAction.__init__(
+ self,
+ plot,
+ icon="pixel-intensities",
+ text="pixels intensity",
+ tooltip="Compute image intensity distribution",
+ parent=parent,
+ )
def _connectPlot(self, window):
plot = self.plot
@@ -514,19 +532,10 @@ class PixelIntensitiesHistoAction(PlotToolAction):
if self._isWindowInUse():
self._updateSelectedItem()
- @deprecated(since_version='0.15.0')
- def computeIntensityDistribution(self):
- self.getHistogramWidget()._updateFromItem()
-
def getHistogramWidget(self):
"""Returns the widget displaying the histogram"""
return self._getToolWindow()
- @deprecated(since_version='0.15.0',
- replacement='getHistogramWidget().getPlotWidget()')
- def getHistogramPlotWidget(self):
- return self._getToolWindow().getPlotWidget()
-
def _createToolWindow(self):
return HistogramWidget(self.plot, qt.Qt.Window)
diff --git a/src/silx/gui/plot/actions/io.py b/src/silx/gui/plot/actions/io.py
index 1ed9649..1ff95f3 100644
--- a/src/silx/gui/plot/actions/io.py
+++ b/src/silx/gui/plot/actions/io.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2004-2022 European Synchrotron Radiation Facility
+# Copyright (c) 2004-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -36,30 +36,27 @@ __authors__ = ["V.A. Sole", "T. Vincent", "P. Knobel"]
__license__ = "MIT"
__date__ = "25/09/2020"
-from . import PlotAction
-from silx.io.utils import save1D, savespec, NEXUS_HDF5_EXT
-from silx.io.nxdata import save_NXdata
+from io import BytesIO
import logging
import sys
import os.path
-from collections import OrderedDict
import traceback
import numpy
-from silx.utils.deprecation import deprecated
+from fabio.TiffIO import TiffIO
+from fabio.edfimage import EdfImage
+
from silx.gui import qt, printer
from silx.gui.dialog.GroupDialog import GroupDialog
-from silx.third_party.EdfFile import EdfFile
-from silx.third_party.TiffIO import TiffIO
+from silx.io.utils import save1D, savespec, NEXUS_HDF5_EXT
+from silx.io.nxdata import save_NXdata
+
+from . import PlotAction
from ...utils.image import convertArrayToQImage
-if sys.version_info[0] == 3:
- from io import BytesIO
-else:
- import cStringIO as _StringIO
- BytesIO = _StringIO.StringIO
+
_logger = logging.getLogger(__name__)
-_NEXUS_HDF5_EXT_STR = ' '.join(['*' + ext for ext in NEXUS_HDF5_EXT])
+_NEXUS_HDF5_EXT_STR = " ".join(["*" + ext for ext in NEXUS_HDF5_EXT])
def selectOutputGroup(h5filename):
@@ -87,111 +84,142 @@ class SaveAction(PlotAction):
:param parent: See :class:`QAction`.
"""
- SNAPSHOT_FILTER_SVG = 'Plot Snapshot as SVG (*.svg)'
- SNAPSHOT_FILTER_PNG = 'Plot Snapshot as PNG (*.png)'
+ SNAPSHOT_FILTER_SVG = "Plot Snapshot as SVG (*.svg)"
+ SNAPSHOT_FILTER_PNG = "Plot Snapshot as PNG (*.png)"
DEFAULT_ALL_FILTERS = (SNAPSHOT_FILTER_PNG, SNAPSHOT_FILTER_SVG)
# Dict of curve filters with CSV-like format
# Using ordered dict to guarantee filters order
# Note: '%.18e' is numpy.savetxt default format
- CURVE_FILTERS_TXT = OrderedDict((
- ('Curve as Raw ASCII (*.txt)',
- {'fmt': '%.18e', 'delimiter': ' ', 'header': False}),
- ('Curve as ";"-separated CSV (*.csv)',
- {'fmt': '%.18e', 'delimiter': ';', 'header': True}),
- ('Curve as ","-separated CSV (*.csv)',
- {'fmt': '%.18e', 'delimiter': ',', 'header': True}),
- ('Curve as tab-separated CSV (*.csv)',
- {'fmt': '%.18e', 'delimiter': '\t', 'header': True}),
- ('Curve as OMNIC CSV (*.csv)',
- {'fmt': '%.7E', 'delimiter': ',', 'header': False}),
- ('Curve as SpecFile (*.dat)',
- {'fmt': '%.10g', 'delimiter': '', 'header': False})
- ))
-
- CURVE_FILTER_NPY = 'Curve as NumPy binary file (*.npy)'
-
- CURVE_FILTER_NXDATA = 'Curve as NXdata (%s)' % _NEXUS_HDF5_EXT_STR
+ CURVE_FILTERS_TXT = dict(
+ (
+ (
+ "Curve as Raw ASCII (*.txt)",
+ {"fmt": "%.18e", "delimiter": " ", "header": False},
+ ),
+ (
+ 'Curve as ";"-separated CSV (*.csv)',
+ {"fmt": "%.18e", "delimiter": ";", "header": True},
+ ),
+ (
+ 'Curve as ","-separated CSV (*.csv)',
+ {"fmt": "%.18e", "delimiter": ",", "header": True},
+ ),
+ (
+ "Curve as tab-separated CSV (*.csv)",
+ {"fmt": "%.18e", "delimiter": "\t", "header": True},
+ ),
+ (
+ "Curve as OMNIC CSV (*.csv)",
+ {"fmt": "%.7E", "delimiter": ",", "header": False},
+ ),
+ (
+ "Curve as SpecFile (*.dat)",
+ {"fmt": "%.10g", "delimiter": "", "header": False},
+ ),
+ )
+ )
+
+ CURVE_FILTER_NPY = "Curve as NumPy binary file (*.npy)"
+
+ CURVE_FILTER_NXDATA = "Curve as NXdata (%s)" % _NEXUS_HDF5_EXT_STR
DEFAULT_CURVE_FILTERS = list(CURVE_FILTERS_TXT.keys()) + [
- CURVE_FILTER_NPY, CURVE_FILTER_NXDATA]
+ CURVE_FILTER_NPY,
+ CURVE_FILTER_NXDATA,
+ ]
DEFAULT_ALL_CURVES_FILTERS = ("All curves as SpecFile (*.dat)",)
- IMAGE_FILTER_EDF = 'Image data as EDF (*.edf)'
- IMAGE_FILTER_TIFF = 'Image data as TIFF (*.tif)'
- IMAGE_FILTER_NUMPY = 'Image data as NumPy binary file (*.npy)'
- IMAGE_FILTER_ASCII = 'Image data as ASCII (*.dat)'
- IMAGE_FILTER_CSV_COMMA = 'Image data as ,-separated CSV (*.csv)'
- IMAGE_FILTER_CSV_SEMICOLON = 'Image data as ;-separated CSV (*.csv)'
- IMAGE_FILTER_CSV_TAB = 'Image data as tab-separated CSV (*.csv)'
- IMAGE_FILTER_RGB_PNG = 'Image as PNG (*.png)'
- IMAGE_FILTER_NXDATA = 'Image as NXdata (%s)' % _NEXUS_HDF5_EXT_STR
-
- DEFAULT_IMAGE_FILTERS = (IMAGE_FILTER_EDF,
- IMAGE_FILTER_TIFF,
- IMAGE_FILTER_NUMPY,
- IMAGE_FILTER_ASCII,
- IMAGE_FILTER_CSV_COMMA,
- IMAGE_FILTER_CSV_SEMICOLON,
- IMAGE_FILTER_CSV_TAB,
- IMAGE_FILTER_RGB_PNG,
- IMAGE_FILTER_NXDATA)
-
- SCATTER_FILTER_NXDATA = 'Scatter as NXdata (%s)' % _NEXUS_HDF5_EXT_STR
+ IMAGE_FILTER_EDF = "Image data as EDF (*.edf)"
+ IMAGE_FILTER_TIFF = "Image data as TIFF (*.tif)"
+ IMAGE_FILTER_NUMPY = "Image data as NumPy binary file (*.npy)"
+ IMAGE_FILTER_ASCII = "Image data as ASCII (*.dat)"
+ IMAGE_FILTER_CSV_COMMA = "Image data as ,-separated CSV (*.csv)"
+ IMAGE_FILTER_CSV_SEMICOLON = "Image data as ;-separated CSV (*.csv)"
+ IMAGE_FILTER_CSV_TAB = "Image data as tab-separated CSV (*.csv)"
+ IMAGE_FILTER_RGB_PNG = "Image as PNG (*.png)"
+ IMAGE_FILTER_NXDATA = "Image as NXdata (%s)" % _NEXUS_HDF5_EXT_STR
+
+ DEFAULT_IMAGE_FILTERS = (
+ IMAGE_FILTER_EDF,
+ IMAGE_FILTER_TIFF,
+ IMAGE_FILTER_NUMPY,
+ IMAGE_FILTER_ASCII,
+ IMAGE_FILTER_CSV_COMMA,
+ IMAGE_FILTER_CSV_SEMICOLON,
+ IMAGE_FILTER_CSV_TAB,
+ IMAGE_FILTER_RGB_PNG,
+ IMAGE_FILTER_NXDATA,
+ )
+
+ SCATTER_FILTER_NXDATA = "Scatter as NXdata (%s)" % _NEXUS_HDF5_EXT_STR
DEFAULT_SCATTER_FILTERS = (SCATTER_FILTER_NXDATA,)
# filters for which we don't want an "overwrite existing file" warning
- DEFAULT_APPEND_FILTERS = (CURVE_FILTER_NXDATA, IMAGE_FILTER_NXDATA,
- SCATTER_FILTER_NXDATA)
+ DEFAULT_APPEND_FILTERS = (
+ CURVE_FILTER_NXDATA,
+ IMAGE_FILTER_NXDATA,
+ SCATTER_FILTER_NXDATA,
+ )
def __init__(self, plot, parent=None):
self._filters = {
- 'all': OrderedDict(),
- 'curve': OrderedDict(),
- 'curves': OrderedDict(),
- 'image': OrderedDict(),
- 'scatter': OrderedDict()}
+ "all": {},
+ "curve": {},
+ "curves": {},
+ "image": {},
+ "scatter": {},
+ }
self._appendFilters = list(self.DEFAULT_APPEND_FILTERS)
# Initialize filters
for nameFilter in self.DEFAULT_ALL_FILTERS:
self.setFileFilter(
- dataKind='all', nameFilter=nameFilter, func=self._saveSnapshot)
+ dataKind="all", nameFilter=nameFilter, func=self._saveSnapshot
+ )
for nameFilter in self.DEFAULT_CURVE_FILTERS:
self.setFileFilter(
- dataKind='curve', nameFilter=nameFilter, func=self._saveCurve)
+ dataKind="curve", nameFilter=nameFilter, func=self._saveCurve
+ )
for nameFilter in self.DEFAULT_ALL_CURVES_FILTERS:
self.setFileFilter(
- dataKind='curves', nameFilter=nameFilter, func=self._saveCurves)
+ dataKind="curves", nameFilter=nameFilter, func=self._saveCurves
+ )
for nameFilter in self.DEFAULT_IMAGE_FILTERS:
self.setFileFilter(
- dataKind='image', nameFilter=nameFilter, func=self._saveImage)
+ dataKind="image", nameFilter=nameFilter, func=self._saveImage
+ )
for nameFilter in self.DEFAULT_SCATTER_FILTERS:
self.setFileFilter(
- dataKind='scatter', nameFilter=nameFilter, func=self._saveScatter)
+ dataKind="scatter", nameFilter=nameFilter, func=self._saveScatter
+ )
super(SaveAction, self).__init__(
- plot, icon='document-save', text='Save as...',
- tooltip='Save curve/image/plot snapshot dialog',
+ plot,
+ icon="document-save",
+ text="Save as...",
+ tooltip="Save curve/image/plot snapshot dialog",
triggered=self._actionTriggered,
- checkable=False, parent=parent)
+ checkable=False,
+ parent=parent,
+ )
self.setShortcut(qt.QKeySequence.Save)
self.setShortcutContext(qt.Qt.WidgetShortcut)
@staticmethod
- def _errorMessage(informativeText='', parent=None):
+ def _errorMessage(informativeText="", parent=None):
"""Display an error message."""
# TODO issue with QMessageBox size fixed and too small
msg = qt.QMessageBox(parent)
msg.setIcon(qt.QMessageBox.Critical)
- msg.setInformativeText(informativeText + ' ' + str(sys.exc_info()[1]))
+ msg.setInformativeText(informativeText + " " + str(sys.exc_info()[1]))
msg.setDetailedText(traceback.format_exc())
msg.exec()
@@ -204,12 +232,11 @@ class SaveAction(PlotAction):
True otherwise.
"""
if nameFilter == self.SNAPSHOT_FILTER_PNG:
- fileFormat = 'png'
+ fileFormat = "png"
elif nameFilter == self.SNAPSHOT_FILTER_SVG:
- fileFormat = 'svg'
+ fileFormat = "svg"
else: # Format not supported
- _logger.error(
- 'Saving plot snapshot failed: format not supported')
+ _logger.error("Saving plot snapshot failed: format not supported")
return False
plot.saveGraph(filename, fileFormat=fileFormat)
@@ -260,8 +287,11 @@ class SaveAction(PlotAction):
@staticmethod
def _selectWriteableOutputGroup(filename, parent):
- if os.path.exists(filename) and os.path.isfile(filename) \
- and os.access(filename, os.W_OK):
+ if (
+ os.path.exists(filename)
+ and os.path.isfile(filename)
+ and os.access(filename, os.W_OK)
+ ):
entryPath = selectOutputGroup(filename)
if entryPath is None:
_logger.info("Save operation cancelled")
@@ -271,7 +301,7 @@ class SaveAction(PlotAction):
# create new entry in new file
return "/entry"
else:
- SaveAction._errorMessage('Save failed (file access issue)\n', parent=parent)
+ SaveAction._errorMessage("Save failed (file access issue)\n", parent=parent)
return None
def _saveCurveAsNXdata(self, curve, filename):
@@ -292,7 +322,8 @@ class SaveAction(PlotAction):
axes_long_names=[xlabel],
signal_errors=curve.getYErrorData(copy=False),
axes_errors=[curve.getXErrorData(copy=True)],
- title=self.plot.getGraphTitle())
+ title=self.plot.getGraphTitle(),
+ )
def _saveCurve(self, plot, filename, nameFilter):
"""Save a curve from the plot.
@@ -318,9 +349,9 @@ class SaveAction(PlotAction):
if nameFilter in self.CURVE_FILTERS_TXT:
filter_ = self.CURVE_FILTERS_TXT[nameFilter]
- fmt = filter_['fmt']
- csvdelim = filter_['delimiter']
- autoheader = filter_['header']
+ fmt = filter_["fmt"]
+ csvdelim = filter_["delimiter"]
+ autoheader = filter_["header"]
else:
# .npy or nxdata
fmt, csvdelim, autoheader = ("", "", False)
@@ -331,13 +362,18 @@ class SaveAction(PlotAction):
xdata, data, xlabel, labels = self._get1dData(curve)
try:
- save1D(filename,
- xdata, data,
- xlabel, labels,
- fmt=fmt, csvdelim=csvdelim,
- autoheader=autoheader)
+ save1D(
+ filename,
+ xdata,
+ data,
+ xlabel,
+ labels,
+ fmt=fmt,
+ csvdelim=csvdelim,
+ autoheader=autoheader,
+ )
except IOError:
- self._errorMessage('Save failed\n', parent=self.plot)
+ self._errorMessage("Save failed\n", parent=self.plot)
return False
return True
@@ -363,28 +399,39 @@ class SaveAction(PlotAction):
try:
xdata, data, xlabel, labels = self._get1dData(curve)
- specfile = savespec(filename,
- xdata, data,
- xlabel, labels,
- fmt="%.7g", scan_number=1, mode="w",
- write_file_header=True,
- close_file=False)
+ specfile = savespec(
+ filename,
+ xdata,
+ data,
+ xlabel,
+ labels,
+ fmt="%.7g",
+ scan_number=1,
+ mode="w",
+ write_file_header=True,
+ close_file=False,
+ )
except IOError:
- self._errorMessage('Save failed\n', parent=self.plot)
+ self._errorMessage("Save failed\n", parent=self.plot)
return False
for curve in curves[1:]:
try:
scanno += 1
xdata, data, xlabel, labels = self._get1dData(curve)
- specfile = savespec(specfile,
- xdata, data,
- xlabel, labels,
- fmt="%.7g", scan_number=scanno,
- write_file_header=False,
- close_file=False)
+ specfile = savespec(
+ specfile,
+ xdata,
+ data,
+ xlabel,
+ labels,
+ fmt="%.7g",
+ scan_number=scanno,
+ write_file_header=False,
+ close_file=False,
+ )
except IOError:
- self._errorMessage('Save failed\n', parent=self.plot)
+ self._errorMessage("Save failed\n", parent=self.plot)
return False
specfile.close()
@@ -403,28 +450,26 @@ class SaveAction(PlotAction):
image = plot.getActiveImage()
if image is None:
- qt.QMessageBox.warning(
- plot, "No Data", "No image to be saved")
+ qt.QMessageBox.warning(plot, "No Data", "No image to be saved")
return False
data = image.getData(copy=False)
# TODO Use silx.io for writing files
if nameFilter == self.IMAGE_FILTER_EDF:
- edfFile = EdfFile(filename, access="w+")
- edfFile.WriteImage({}, data, Append=0)
+ EdfImage(data=data, header={}).write(filename)
return True
elif nameFilter == self.IMAGE_FILTER_TIFF:
- tiffFile = TiffIO(filename, mode='w')
- tiffFile.writeImage(data, software='silx')
+ tiffFile = TiffIO(filename, mode="w")
+ tiffFile.writeImage(data, software="silx")
return True
elif nameFilter == self.IMAGE_FILTER_NUMPY:
try:
numpy.save(filename, data)
except IOError:
- self._errorMessage('Save failed\n', parent=self.plot)
+ self._errorMessage("Save failed\n", parent=self.plot)
return False
return True
@@ -439,39 +484,47 @@ class SaveAction(PlotAction):
xlabel, ylabel = self._getAxesLabels(image)
interpretation = "image" if len(data.shape) == 2 else "rgba-image"
- return save_NXdata(filename,
- nxentry_name=entryPath,
- signal=data,
- axes=[yaxis, xaxis],
- signal_name="image",
- axes_names=["y", "x"],
- axes_long_names=[ylabel, xlabel],
- title=plot.getGraphTitle(),
- interpretation=interpretation)
-
- elif nameFilter in (self.IMAGE_FILTER_ASCII,
- self.IMAGE_FILTER_CSV_COMMA,
- self.IMAGE_FILTER_CSV_SEMICOLON,
- self.IMAGE_FILTER_CSV_TAB):
+ return save_NXdata(
+ filename,
+ nxentry_name=entryPath,
+ signal=data,
+ axes=[yaxis, xaxis],
+ signal_name="image",
+ axes_names=["y", "x"],
+ axes_long_names=[ylabel, xlabel],
+ title=plot.getGraphTitle(),
+ interpretation=interpretation,
+ )
+
+ elif nameFilter in (
+ self.IMAGE_FILTER_ASCII,
+ self.IMAGE_FILTER_CSV_COMMA,
+ self.IMAGE_FILTER_CSV_SEMICOLON,
+ self.IMAGE_FILTER_CSV_TAB,
+ ):
csvdelim, filetype = {
- self.IMAGE_FILTER_ASCII: (' ', 'txt'),
- self.IMAGE_FILTER_CSV_COMMA: (',', 'csv'),
- self.IMAGE_FILTER_CSV_SEMICOLON: (';', 'csv'),
- self.IMAGE_FILTER_CSV_TAB: ('\t', 'csv'),
- }[nameFilter]
+ self.IMAGE_FILTER_ASCII: (" ", "txt"),
+ self.IMAGE_FILTER_CSV_COMMA: (",", "csv"),
+ self.IMAGE_FILTER_CSV_SEMICOLON: (";", "csv"),
+ self.IMAGE_FILTER_CSV_TAB: ("\t", "csv"),
+ }[nameFilter]
height, width = data.shape
rows, cols = numpy.mgrid[0:height, 0:width]
try:
- save1D(filename, rows.ravel(), (cols.ravel(), data.ravel()),
- filetype=filetype,
- xlabel='row',
- ylabels=['column', 'value'],
- csvdelim=csvdelim,
- autoheader=True)
+ save1D(
+ filename,
+ rows.ravel(),
+ (cols.ravel(), data.ravel()),
+ filetype=filetype,
+ xlabel="row",
+ ylabels=["column", "value"],
+ csvdelim=csvdelim,
+ autoheader=True,
+ )
except IOError:
- self._errorMessage('Save failed\n', parent=self.plot)
+ self._errorMessage("Save failed\n", parent=self.plot)
return False
return True
@@ -481,14 +534,13 @@ class SaveAction(PlotAction):
# Convert RGB QImage
qimage = convertArrayToQImage(rgbaImage[:, :, :3])
- if qimage.save(filename, 'PNG'):
+ if qimage.save(filename, "PNG"):
return True
else:
- _logger.error('Failed to save image as %s', filename)
+ _logger.error("Failed to save image as %s", filename)
qt.QMessageBox.critical(
- self.parent(),
- 'Save image as',
- 'Failed to save image')
+ self.parent(), "Save image as", "Failed to save image"
+ )
return False
@@ -533,7 +585,8 @@ class SaveAction(PlotAction):
axes_names=["x", "y"],
axes_long_names=[xlabel, ylabel],
axes_errors=[xerror, yerror],
- title=plot.getGraphTitle())
+ title=plot.getGraphTitle(),
+ )
def setFileFilter(self, dataKind, nameFilter, func, index=None, appendToFile=False):
"""Set a name filter to add/replace a file format support
@@ -550,7 +603,7 @@ class SaveAction(PlotAction):
file.
:param integer index: Index of the filter in the final list (or None)
"""
- assert dataKind in ('all', 'curve', 'curves', 'image', 'scatter')
+ assert dataKind in ("all", "curve", "curves", "image", "scatter")
if appendToFile:
self._appendFilters.append(nameFilter)
@@ -572,7 +625,7 @@ class SaveAction(PlotAction):
if index >= len(keyList):
# nothing to be done, already at the end
- txt = 'Requested index %d impossible, already at the end' % index
+ txt = "Requested index %d impossible, already at the end" % index
_logger.info(txt)
return
@@ -582,7 +635,7 @@ class SaveAction(PlotAction):
keyList.insert(index, nameFilter)
# build the new filters
- newFilters = OrderedDict()
+ newFilters = {}
for key in keyList:
newFilters[key] = self._filters[dataKind][key]
@@ -597,34 +650,33 @@ class SaveAction(PlotAction):
The kind of data for which the provided filter is valid.
On of: 'all', 'curve', 'curves', 'image', 'scatter'
:return: {nameFilter: function} associations.
- :rtype: collections.OrderedDict
+ :rtype: dict
"""
- assert dataKind in ('all', 'curve', 'curves', 'image', 'scatter')
+ assert dataKind in ("all", "curve", "curves", "image", "scatter")
return self._filters[dataKind].copy()
def _actionTriggered(self, checked=False):
"""Handle save action."""
# Set-up filters
- filters = OrderedDict()
+ filters = {}
# Add image filters if there is an active image
if self.plot.getActiveImage() is not None:
- filters.update(self._filters['image'].items())
+ filters.update(self._filters["image"].items())
# Add curve filters if there is a curve to save
- if (self.plot.getActiveCurve() is not None or
- len(self.plot.getAllCurves()) == 1):
- filters.update(self._filters['curve'].items())
+ if self.plot.getActiveCurve() is not None or len(self.plot.getAllCurves()) == 1:
+ filters.update(self._filters["curve"].items())
if len(self.plot.getAllCurves()) >= 1:
- filters.update(self._filters['curves'].items())
+ filters.update(self._filters["curves"].items())
# Add scatter filters if there is a scatter
# todo: CSV
if self.plot.getScatter() is not None:
- filters.update(self._filters['scatter'].items())
+ filters.update(self._filters["scatter"].items())
- filters.update(self._filters['all'].items())
+ filters.update(self._filters["all"].items())
# Create and run File dialog
dialog = qt.QFileDialog(self.plot)
@@ -653,14 +705,18 @@ class SaveAction(PlotAction):
filename = dialog.selectedFiles()[0]
dialog.close()
- if '(' in nameFilter and ')' == nameFilter.strip()[-1]:
+ if "(" in nameFilter and ")" == nameFilter.strip()[-1]:
# Check for correct file extension
# Extract file extensions as .something
- extensions = [ext[ext.find('.'):] for ext in
- nameFilter[nameFilter.find('(') + 1:-1].split()]
+ extensions = [
+ ext[ext.find(".") :]
+ for ext in nameFilter[nameFilter.find("(") + 1 : -1].split()
+ ]
for ext in extensions:
- if (len(filename) > len(ext) and
- filename[-len(ext):].lower() == ext.lower()):
+ if (
+ len(filename) > len(ext)
+ and filename[-len(ext) :].lower() == ext.lower()
+ ):
break
else: # filename has no extension supported in nameFilter, add one
if len(extensions) >= 1:
@@ -671,7 +727,7 @@ class SaveAction(PlotAction):
if func is not None:
return func(self.plot, filename, nameFilter)
else:
- _logger.error('Unsupported file filter: %s', nameFilter)
+ _logger.error("Unsupported file filter: %s", nameFilter)
return False
@@ -681,7 +737,7 @@ def _plotAsPNG(plot):
:param plot: The :class:`Plot` to save
"""
pngFile = BytesIO()
- plot.saveGraph(pngFile, fileFormat='png')
+ plot.saveGraph(pngFile, fileFormat="png")
pngFile.flush()
pngFile.seek(0)
data = pngFile.read()
@@ -703,10 +759,14 @@ class PrintAction(PlotAction):
def __init__(self, plot, parent=None):
super(PrintAction, self).__init__(
- plot, icon='document-print', text='Print...',
- tooltip='Open print dialog',
+ plot,
+ icon="document-print",
+ text="Print...",
+ tooltip="Open print dialog",
triggered=self.printPlot,
- checkable=False, parent=parent)
+ checkable=False,
+ parent=parent,
+ )
self.setShortcut(qt.QKeySequence.Print)
self.setShortcutContext(qt.Qt.WidgetShortcut)
@@ -717,11 +777,6 @@ class PrintAction(PlotAction):
"""
return printer.getDefaultPrinter()
- @property
- @deprecated(replacement="getPrinter()", since_version="0.8.0")
- def printer(self):
- return self.getPrinter()
-
def printPlotAsWidget(self):
"""Open the print dialog and print the plot.
@@ -730,7 +785,7 @@ class PrintAction(PlotAction):
:return: True if successful
"""
dialog = qt.QPrintDialog(self.getPrinter(), self.plot)
- dialog.setWindowTitle('Print Plot')
+ dialog.setWindowTitle("Print Plot")
if not dialog.exec():
return False
@@ -746,9 +801,9 @@ class PrintAction(PlotAction):
yScale = pageRect.height() / widget.height()
scale = min(xScale, yScale)
- painter.translate(pageRect.width() / 2., 0.)
+ painter.translate(pageRect.width() / 2.0, 0.0)
painter.scale(scale, scale)
- painter.translate(-widget.width() / 2., 0.)
+ painter.translate(-widget.width() / 2.0, 0.0)
widget.render(painter)
painter.end()
@@ -763,7 +818,7 @@ class PrintAction(PlotAction):
"""
# Init printer and start printer dialog
dialog = qt.QPrintDialog(self.getPrinter(), self.plot)
- dialog.setWindowTitle('Print Plot')
+ dialog.setWindowTitle("Print Plot")
if not dialog.exec():
return False
@@ -771,7 +826,7 @@ class PrintAction(PlotAction):
pngData = _plotAsPNG(self.plot)
pixmap = qt.QPixmap()
- pixmap.loadFromData(pngData, 'png')
+ pixmap.loadFromData(pngData, "png")
pageRect = self.getPrinter().pageRect(qt.QPrinter.DevicePixel)
xScale = pageRect.width() / pixmap.width()
@@ -783,10 +838,9 @@ class PrintAction(PlotAction):
if not painter.begin(self.getPrinter()):
return False
- painter.drawPixmap(0, 0,
- pixmap.width() * scale,
- pixmap.height() * scale,
- pixmap)
+ painter.drawPixmap(
+ 0, 0, pixmap.width() * scale, pixmap.height() * scale, pixmap
+ )
painter.end()
return True
@@ -801,10 +855,14 @@ class CopyAction(PlotAction):
def __init__(self, plot, parent=None):
super(CopyAction, self).__init__(
- plot, icon='edit-copy', text='Copy plot',
- tooltip='Copy a snapshot of the plot into the clipboard',
+ plot,
+ icon="edit-copy",
+ text="Copy plot",
+ tooltip="Copy a snapshot of the plot into the clipboard",
triggered=self.copyPlot,
- checkable=False, parent=parent)
+ checkable=False,
+ parent=parent,
+ )
self.setShortcut(qt.QKeySequence.Copy)
self.setShortcutContext(qt.Qt.WidgetShortcut)
@@ -812,5 +870,5 @@ class CopyAction(PlotAction):
"""Copy plot content to the clipboard as a bitmap."""
# Save Plot as PNG and make a QImage from it with default dpi
pngData = _plotAsPNG(self.plot)
- image = qt.QImage.fromData(pngData, 'png')
+ image = qt.QImage.fromData(pngData, "png")
qt.QApplication.clipboard().setImage(image)
diff --git a/src/silx/gui/plot/actions/medfilt.py b/src/silx/gui/plot/actions/medfilt.py
index 25fcdb2..a335499 100644
--- a/src/silx/gui/plot/actions/medfilt.py
+++ b/src/silx/gui/plot/actions/medfilt.py
@@ -54,12 +54,14 @@ class MedianFilterAction(PlotToolAction):
"""
def __init__(self, plot, parent=None):
- PlotToolAction.__init__(self,
- plot,
- icon='median-filter',
- text='median filter',
- tooltip='Apply a median filter on the image',
- parent=parent)
+ PlotToolAction.__init__(
+ self,
+ plot,
+ icon="median-filter",
+ text="median filter",
+ tooltip="Apply a median filter on the image",
+ parent=parent,
+ )
self._originalImage = None
self._legend = None
self._filteredImage = None
@@ -85,7 +87,9 @@ class MedianFilterAction(PlotToolAction):
self._originalImage = None
self._legend = None
else:
- self._originalImage = self.plot.getImage(self._activeImageLegend).getData(copy=False)
+ self._originalImage = self.plot.getImage(self._activeImageLegend).getData(
+ copy=False
+ )
self._legend = self.plot.getImage(self._activeImageLegend).getName()
def _updateFilter(self, kernelWidth, conditional=False):
@@ -94,13 +98,11 @@ class MedianFilterAction(PlotToolAction):
self.plot.sigActiveImageChanged.disconnect(self._updateActiveImage)
filteredImage = self._computeFilteredImage(kernelWidth, conditional)
- self.plot.addImage(data=filteredImage,
- legend=self._legend,
- replace=True)
+ self.plot.addImage(data=filteredImage, legend=self._legend, replace=True)
self.plot.sigActiveImageChanged.connect(self._updateActiveImage)
def _computeFilteredImage(self, kernelWidth, conditional):
- raise NotImplementedError('MedianFilterAction is a an abstract class')
+ raise NotImplementedError("MedianFilterAction is a an abstract class")
def getFilteredImage(self):
"""
@@ -114,16 +116,13 @@ class MedianFilter1DAction(MedianFilterAction):
:param plot: :class:`.PlotWidget` instance on which to operate
:param parent: See :class:`QAction`
"""
+
def __init__(self, plot, parent=None):
- MedianFilterAction.__init__(self,
- plot,
- parent=parent)
+ MedianFilterAction.__init__(self, plot, parent=parent)
def _computeFilteredImage(self, kernelWidth, conditional):
- assert(self.plot is not None)
- return medfilt2d(self._originalImage,
- (kernelWidth, 1),
- conditional)
+ assert self.plot is not None
+ return medfilt2d(self._originalImage, (kernelWidth, 1), conditional)
class MedianFilter2DAction(MedianFilterAction):
@@ -132,13 +131,10 @@ class MedianFilter2DAction(MedianFilterAction):
:param plot: :class:`.PlotWidget` instance on which to operate
:param parent: See :class:`QAction`
"""
+
def __init__(self, plot, parent=None):
- MedianFilterAction.__init__(self,
- plot,
- parent=parent)
+ MedianFilterAction.__init__(self, plot, parent=parent)
def _computeFilteredImage(self, kernelWidth, conditional):
- assert(self.plot is not None)
- return medfilt2d(self._originalImage,
- (kernelWidth, kernelWidth),
- conditional)
+ assert self.plot is not None
+ return medfilt2d(self._originalImage, (kernelWidth, kernelWidth), conditional)
diff --git a/src/silx/gui/plot/actions/mode.py b/src/silx/gui/plot/actions/mode.py
index 7edc8bb..511a8df 100644
--- a/src/silx/gui/plot/actions/mode.py
+++ b/src/silx/gui/plot/actions/mode.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2004-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2004-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -35,10 +35,11 @@ __authors__ = ["V. Valls"]
__license__ = "MIT"
__date__ = "16/08/2017"
-from . import PlotAction
-import logging
-_logger = logging.getLogger(__name__)
+from silx.gui import qt
+
+from ..tools.menus import ZoomEnabledAxesMenu
+from . import PlotAction
class ZoomModeAction(PlotAction):
@@ -50,25 +51,58 @@ class ZoomModeAction(PlotAction):
def __init__(self, plot, parent=None):
super(ZoomModeAction, self).__init__(
- plot, icon='zoom', text='Zoom mode',
- tooltip='Zoom in or out',
+ plot,
+ icon="zoom",
+ text="Zoom mode",
+ tooltip="Zoom-in on mouse selection",
triggered=self._actionTriggered,
- checkable=True, parent=parent)
- # Listen to mode change
- self.plot.sigInteractiveModeChanged.connect(self._modeChanged)
+ checkable=True,
+ parent=parent,
+ )
+
+ self.__menu = ZoomEnabledAxesMenu(self.plot, self.plot)
+
+ # Listen to interaction configuration change
+ self.plot.interaction().sigChanged.connect(self._interactionChanged)
# Init the state
- self._modeChanged(None)
+ self._interactionChanged()
+
+ def isAxesMenuEnabled(self) -> bool:
+ """Returns whether the axes selection menu is enabled or not (default: False)"""
+ return self.menu() is self.__menu
+
+ def setAxesMenuEnabled(self, enabled: bool):
+ """Toggle the availability of the axes selection menu (default: False)"""
+ if enabled == self.isAxesMenuEnabled():
+ return
+
+ self.setMenu(self.__menu if enabled else None)
+
+ # Update associated QToolButton's popupMode if any, this is not done at least with Qt5
+ parent = self.parent()
+ if not isinstance(parent, qt.QToolBar):
+ return
+ widget = parent.widgetForAction(self)
+ if not isinstance(widget, qt.QToolButton):
+ return
+ widget.setPopupMode(
+ qt.QToolButton.MenuButtonPopup if enabled else qt.QToolButton.DelayedPopup
+ )
+ widget.update()
+
+ def _interactionChanged(self):
+ plot = self.plot
+ if plot is None:
+ return
- def _modeChanged(self, source):
- modeDict = self.plot.getInteractiveMode()
- old = self.blockSignals(True)
- self.setChecked(modeDict["mode"] == "zoom")
- self.blockSignals(old)
+ self.setChecked(plot.getInteractiveMode()["mode"] == "zoom")
def _actionTriggered(self, checked=False):
plot = self.plot
- if plot is not None:
- plot.setInteractiveMode('zoom', source=self)
+ if plot is None:
+ return
+
+ plot.setInteractiveMode("zoom", source=self)
class PanModeAction(PlotAction):
@@ -80,10 +114,14 @@ class PanModeAction(PlotAction):
def __init__(self, plot, parent=None):
super(PanModeAction, self).__init__(
- plot, icon='pan', text='Pan mode',
- tooltip='Pan the view',
+ plot,
+ icon="pan",
+ text="Pan mode",
+ tooltip="Pan the view",
triggered=self._actionTriggered,
- checkable=True, parent=parent)
+ checkable=True,
+ parent=parent,
+ )
# Listen to mode change
self.plot.sigInteractiveModeChanged.connect(self._modeChanged)
# Init the state
@@ -98,4 +136,4 @@ class PanModeAction(PlotAction):
def _actionTriggered(self, checked=False):
plot = self.plot
if plot is not None:
- plot.setInteractiveMode('pan', source=self)
+ plot.setInteractiveMode("pan", source=self)
diff --git a/src/silx/gui/plot/backends/BackendBase.py b/src/silx/gui/plot/backends/BackendBase.py
index d7653f3..8d70286 100755
--- a/src/silx/gui/plot/backends/BackendBase.py
+++ b/src/silx/gui/plot/backends/BackendBase.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+# Copyright (c) 2004-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -28,20 +28,26 @@ It documents the Plot backend API.
This API is a simplified version of PyMca PlotBackend API.
"""
+from __future__ import annotations
+
+
__authors__ = ["V.A. Sole", "T. Vincent"]
__license__ = "MIT"
__date__ = "21/12/2018"
+from collections.abc import Callable
import weakref
+from silx.gui.colors import RGBAColorType
+
from ... import qt
# Names for setCursor
-CURSOR_DEFAULT = 'default'
-CURSOR_POINTING = 'pointing'
-CURSOR_SIZE_HOR = 'size horizontal'
-CURSOR_SIZE_VER = 'size vertical'
-CURSOR_SIZE_ALL = 'size all'
+CURSOR_DEFAULT = "default"
+CURSOR_POINTING = "pointing"
+CURSOR_SIZE_HOR = "size horizontal"
+CURSOR_SIZE_VER = "size vertical"
+CURSOR_SIZE_ALL = "size all"
class BackendBase(object):
@@ -53,8 +59,8 @@ class BackendBase(object):
:param Plot plot: The Plot this backend is attached to
:param parent: The parent widget of the plot widget.
"""
- self.__xLimits = 1., 100.
- self.__yLimits = {'left': (1., 100.), 'right': (1., 100.)}
+ self.__xLimits = 1.0, 100.0
+ self.__yLimits = {"left": (1.0, 100.0), "right": (1.0, 100.0)}
self.__yAxisInverted = False
self.__keepDataAspectRatio = False
self.__xAxisTimeSeries = False
@@ -66,11 +72,11 @@ class BackendBase(object):
def _plot(self):
"""The plot this backend is attached to."""
if self._plotRef is None:
- raise RuntimeError('This backend is not attached to a Plot')
+ raise RuntimeError("This backend is not attached to a Plot")
plot = self._plotRef()
if plot is None:
- raise RuntimeError('This backend is no more attached to a Plot')
+ raise RuntimeError("This backend is no more attached to a Plot")
return plot
def _setPlot(self, plot):
@@ -82,11 +88,23 @@ class BackendBase(object):
# Add methods
- def addCurve(self, x, y,
- color, symbol, linewidth, linestyle,
- yaxis,
- xerror, yerror,
- fill, alpha, symbolsize, baseline):
+ def addCurve(
+ self,
+ x,
+ y,
+ color,
+ gapcolor,
+ symbol,
+ linewidth,
+ linestyle,
+ yaxis,
+ xerror,
+ yerror,
+ fill,
+ alpha,
+ symbolsize,
+ baseline,
+ ):
"""Add a 1D curve given by x an y to the graph.
:param numpy.ndarray x: The data corresponding to the x axis
@@ -94,6 +112,8 @@ class BackendBase(object):
:param color: color(s) to be used
:type color: string ("#RRGGBB") or (npoints, 4) unsigned byte array or
one of the predefined color names defined in colors.py
+ :param Union[str, None] gapcolor:
+ color used to fill dashed line gaps.
:param str symbol: Symbol to be drawn at each (x, y) position::
- ' ' or '' no symbol
@@ -106,13 +126,14 @@ class BackendBase(object):
- 's' square
:param float linewidth: The width of the curve in pixels
- :param str linestyle: Type of line::
+ :param linestyle: Type of line::
- ' ' or '' no line
- '-' solid line
- '--' dashed line
- '-.' dash-dot line
- ':' dotted line
+ - (offset, (dash pattern))
:param str yaxis: The Y axis this curve belongs to in: 'left', 'right'
:param xerror: Values with the uncertainties on the x values
@@ -127,9 +148,7 @@ class BackendBase(object):
"""
return object()
- def addImage(self, data,
- origin, scale,
- colormap, alpha):
+ def addImage(self, data, origin, scale, colormap, alpha):
"""Add an image to the plot.
:param numpy.ndarray data: (nrows, ncolumns) data or
@@ -147,8 +166,7 @@ class BackendBase(object):
"""
return object()
- def addTriangles(self, x, y, triangles,
- color, alpha):
+ def addTriangles(self, x, y, triangles, color, alpha):
"""Add a set of triangles.
:param numpy.ndarray x: The data corresponding to the x axis
@@ -161,8 +179,9 @@ class BackendBase(object):
"""
return object()
- def addShape(self, x, y, shape, color, fill, overlay,
- linestyle, linewidth, linebgcolor):
+ def addShape(
+ self, x, y, shape, color, fill, overlay, linestyle, linewidth, gapcolor
+ ):
"""Add an item (i.e. a shape) to the plot.
:param numpy.ndarray x: The X coords of the points of the shape
@@ -172,7 +191,7 @@ class BackendBase(object):
:param str color: Color of the item
:param bool fill: True to fill the shape
:param bool overlay: True if item is an overlay, False otherwise
- :param str linestyle: Style of the line.
+ :param linestyle: Style of the line.
Only relevant for line markers where X or Y is None.
Value in:
@@ -181,25 +200,39 @@ class BackendBase(object):
- '--' dashed line
- '-.' dash-dot line
- ':' dotted line
+ - (offset, (dash pattern))
:param float linewidth: Width of the line.
Only relevant for line markers where X or Y is None.
- :param str linebgcolor: Background color of the line, e.g., 'blue', 'b',
+ :param str gapcolor: Background color of the line, e.g., 'blue', 'b',
'#FF0000'. It is used to draw dotted line using a second color.
:returns: The handle used by the backend to univocally access the item
"""
return object()
- def addMarker(self, x, y, text, color,
- symbol, linestyle, linewidth, constraint, yaxis):
+ def addMarker(
+ self,
+ x: float | None,
+ y: float | None,
+ text: str | None,
+ color: str,
+ symbol: str | None,
+ linestyle: str | tuple[float, tuple[float, ...] | None],
+ linewidth: float,
+ constraint: Callable[[float, float], tuple[float, float]] | None,
+ yaxis: str,
+ font: qt.QFont,
+ bgcolor: RGBAColorType | None,
+ ) -> object:
"""Add a point, vertical line or horizontal line marker to the plot.
- :param float x: Horizontal position of the marker in graph coordinates.
- If None, the marker is a horizontal line.
- :param float y: Vertical position of the marker in graph coordinates.
- If None, the marker is a vertical line.
- :param str text: Text associated to the marker (or None for no text)
- :param str color: Color to be used for instance 'blue', 'b', '#FF0000'
- :param str symbol: Symbol representing the marker.
+ :param x: Horizontal position of the marker in graph coordinates.
+ If None, the marker is a horizontal line.
+ :param y: Vertical position of the marker in graph coordinates.
+ If None, the marker is a vertical line.
+ :param text: Text associated to the marker (or None for no text)
+ :param color: Color to be used for instance 'blue', 'b', '#FF0000'
+ :param bgcolor: Text background color to be used for instance 'blue', 'b', '#FF0000'
+ :param symbol: Symbol representing the marker.
Only relevant for point markers where X and Y are not None.
Value in:
@@ -210,7 +243,7 @@ class BackendBase(object):
- 'x' x-cross
- 'd' diamond
- 's' square
- :param str linestyle: Style of the line.
+ :param linestyle: Style of the line.
Only relevant for line markers where X or Y is None.
Value in:
@@ -219,16 +252,16 @@ class BackendBase(object):
- '--' dashed line
- '-.' dash-dot line
- ':' dotted line
- :param float linewidth: Width of the line.
+ - (offset, (dash pattern))
+ :param linewidth: Width of the line.
Only relevant for line markers where X or Y is None.
:param constraint: A function filtering marker displacement by
- dragging operations or None for no filter.
- This function is called each time a marker is
- moved.
- :type constraint: None or a callable that takes the coordinates of
- the current cursor position in the plot as input
- and that returns the filtered coordinates.
- :param str yaxis: The Y axis this marker belongs to in: 'left', 'right'
+ dragging operations or None for no filter.
+ This function is called each time a marker is moved.
+ It takes the coordinates of the current cursor position in the plot
+ as input and that returns the filtered coordinates.
+ :param yaxis: The Y axis this marker belongs to in: 'left', 'right'
+ :param font: QFont to use to render text
:return: Handle used by the backend to univocally access the marker
"""
return object()
@@ -270,8 +303,9 @@ class BackendBase(object):
- '--' dashed line
- '-.' dash-dot line
- ':' dotted line
+ - (offset, (dash pattern))
- :type linestyle: None or one of the predefined styles.
+ :type linestyle: None, one of the predefined styles or (offset, (dash pattern)).
"""
pass
@@ -295,8 +329,8 @@ class BackendBase(object):
content = [item for item in content if condition(item)]
return sorted(
- content,
- key=lambda i: ((1 if i.isOverlay() else 0), i.getZValue()))
+ content, key=lambda i: ((1 if i.isOverlay() else 0), i.getZValue())
+ )
def pickItem(self, x, y, item):
"""Return picked indices if any, or None.
@@ -384,9 +418,9 @@ class BackendBase(object):
:param float y2max: maximum right axis value
"""
self.__xLimits = xmin, xmax
- self.__yLimits['left'] = ymin, ymax
+ self.__yLimits["left"] = ymin, ymax
if y2min is not None and y2max is not None:
- self.__yLimits['right'] = y2min, y2max
+ self.__yLimits["right"] = y2min, y2max
def getGraphXLimits(self):
"""Get the graph X (bottom) limits.
@@ -422,7 +456,6 @@ class BackendBase(object):
# Graph axes
-
def getXAxisTimeZone(self):
"""Returns tzinfo that is used if the X-Axis plots date-times.
@@ -480,6 +513,10 @@ class BackendBase(object):
"""Return True if left Y axis is inverted, False otherwise."""
return self.__yAxisInverted
+ def isYRightAxisVisible(self) -> bool:
+ """Return True if the Y axis on the right side of the plot is visible"""
+ return False
+
def isKeepDataAspectRatio(self):
"""Returns whether the plot is keeping data aspect ratio or not."""
return self.__keepDataAspectRatio
@@ -553,7 +590,7 @@ class BackendBase(object):
def setForegroundColors(self, foregroundColor, gridColor):
"""Set foreground and grid colors used to display this widget.
-
+
:param List[float] foregroundColor: RGBA foreground color of the widget
:param List[float] gridColor: RGBA grid color of the data view
"""
diff --git a/src/silx/gui/plot/backends/BackendMatplotlib.py b/src/silx/gui/plot/backends/BackendMatplotlib.py
index 1b31582..facb63c 100755
--- a/src/silx/gui/plot/backends/BackendMatplotlib.py
+++ b/src/silx/gui/plot/backends/BackendMatplotlib.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2004-2022 European Synchrotron Radiation Facility
+# Copyright (c) 2004-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -23,6 +23,8 @@
# ###########################################################################*/
"""Matplotlib Plot backend."""
+from __future__ import annotations
+
__authors__ = ["V.A. Sole", "T. Vincent, H. Payno"]
__license__ = "MIT"
__date__ = "21/12/2018"
@@ -33,7 +35,7 @@ import datetime as dt
from typing import Tuple, Union
import numpy
-from pkg_resources import parse_version as _parse_version
+from packaging.version import Version
_logger = logging.getLogger(__name__)
@@ -42,7 +44,11 @@ _logger = logging.getLogger(__name__)
from ... import qt
# First of all init matplotlib and set its backend
-from ...utils.matplotlib import FigureCanvasQTAgg
+from ...utils.matplotlib import (
+ DefaultTickFormatter,
+ FigureCanvasQTAgg,
+ qFontToFontProperties,
+)
import matplotlib
from matplotlib.container import Container
from matplotlib.figure import Figure
@@ -52,7 +58,7 @@ from matplotlib.backend_bases import MouseEvent
from matplotlib.lines import Line2D
from matplotlib.text import Text
from matplotlib.collections import PathCollection, LineCollection
-from matplotlib.ticker import Formatter, ScalarFormatter, Locator
+from matplotlib.ticker import Formatter, Locator
from matplotlib.tri import Triangulation
from matplotlib.collections import TriMesh
from matplotlib import path as mpath
@@ -60,15 +66,21 @@ from matplotlib import path as mpath
from . import BackendBase
from .. import items
from .._utils import FLOAT32_MINPOS
-from .._utils.dtime_ticklayout import calcTicks, bestFormatString, timestamp
+from .._utils.dtime_ticklayout import (
+ calcTicks,
+ formatDatetimes,
+ timestamp,
+)
from ...qt import inspect as qt_inspect
+from .... import config
+from silx.gui.colors import RGBAColorType
_PATCH_LINESTYLE = {
- "-": 'solid',
- "--": 'dashed',
- '-.': 'dashdot',
- ':': 'dotted',
- '': "solid",
+ "-": "solid",
+ "--": "dashed",
+ "-.": "dashdot",
+ ":": "dotted",
+ "": "solid",
None: "solid",
}
"""Patches do not uses the same matplotlib syntax"""
@@ -77,14 +89,14 @@ _MARKER_PATHS = {}
"""Store cached extra marker paths"""
_SPECIAL_MARKERS = {
- 'tickleft': 0,
- 'tickright': 1,
- 'tickup': 2,
- 'tickdown': 3,
- 'caretleft': 4,
- 'caretright': 5,
- 'caretup': 6,
- 'caretdown': 7,
+ "tickleft": 0,
+ "tickright": 1,
+ "tickup": 2,
+ "tickdown": 3,
+ "caretleft": 4,
+ "caretright": 5,
+ "caretup": 6,
+ "caretdown": 7,
}
@@ -92,6 +104,7 @@ def normalize_linestyle(linestyle):
"""Normalize known old-style linestyle, else return the provided value."""
return _PATCH_LINESTYLE.get(linestyle, linestyle)
+
def get_path_from_symbol(symbol):
"""Get the path representation of a symbol, else None if
it is not provided.
@@ -99,21 +112,40 @@ def get_path_from_symbol(symbol):
:param str symbol: Symbol description used by silx
:rtype: Union[None,matplotlib.path.Path]
"""
- if symbol == u'\u2665':
+ if symbol == "\u2665":
path = _MARKER_PATHS.get(symbol, None)
if path is not None:
return path
- vertices = numpy.array([
- [0,-99],
- [31,-73], [47,-55], [55,-46],
- [63,-37], [94,-2], [94,33],
- [94,69], [71,89], [47,89],
- [24,89], [8,74], [0,58],
- [-8,74], [-24,89], [-47,89],
- [-71,89], [-94,69], [-94,33],
- [-94,-2], [-63,-37], [-55,-46],
- [-47,-55], [-31,-73], [0,-99],
- [0,-99]])
+ vertices = numpy.array(
+ [
+ [0, -99],
+ [31, -73],
+ [47, -55],
+ [55, -46],
+ [63, -37],
+ [94, -2],
+ [94, 33],
+ [94, 69],
+ [71, 89],
+ [47, 89],
+ [24, 89],
+ [8, 74],
+ [0, 58],
+ [-8, 74],
+ [-24, 89],
+ [-47, 89],
+ [-71, 89],
+ [-94, 69],
+ [-94, 33],
+ [-94, -2],
+ [-63, -37],
+ [-55, -46],
+ [-47, -55],
+ [-31, -73],
+ [0, -99],
+ [0, -99],
+ ]
+ )
codes = [mpath.Path.CURVE4] * len(vertices)
codes[0] = mpath.Path.MOVETO
codes[-1] = mpath.Path.CLOSEPOLY
@@ -122,6 +154,7 @@ def get_path_from_symbol(symbol):
return path
return None
+
class NiceDateLocator(Locator):
"""
Matplotlib Locator that uses Nice Numbers algorithm (adapted to dates)
@@ -130,6 +163,7 @@ class NiceDateLocator(Locator):
Expects the data to be posix timestampes (i.e. seconds since 1970)
"""
+
def __init__(self, numTicks=5, tz=None):
"""
:param numTicks: target number of ticks
@@ -144,12 +178,12 @@ class NiceDateLocator(Locator):
@property
def spacing(self):
- """ The current spacing. Will be updated when new tick value are made"""
+ """The current spacing. Will be updated when new tick value are made"""
return self._spacing
@property
def unit(self):
- """ The current DtUnit. Will be updated when new tick value are made"""
+ """The current DtUnit. Will be updated when new tick value are made"""
return self._unit
def __call__(self):
@@ -158,8 +192,7 @@ class NiceDateLocator(Locator):
return self.tick_values(vmin, vmax)
def tick_values(self, vmin, vmax):
- """ Calculates tick values
- """
+ """Calculates tick values"""
if vmax < vmin:
vmin, vmax = vmax, vmin
@@ -171,8 +204,7 @@ class NiceDateLocator(Locator):
_logger.warning("Data range cannot be displayed with time axis")
return []
- dtTicks, self._spacing, self._unit = \
- calcTicks(dtMin, dtMax, self.numTicks)
+ dtTicks, self._spacing, self._unit = calcTicks(dtMin, dtMax, self.numTicks)
# Convert datetime back to time stamps.
ticks = [timestamp(dtTick) for dtTick in dtTicks]
@@ -194,21 +226,25 @@ class NiceAutoDateFormatter(Formatter):
self.locator = locator
self.tz = tz
- @property
- def formatString(self):
- if self.locator.spacing is None or self.locator.unit is None:
- # Locator has no spacing or units yet. Return elaborate fmtString
- return "Y-%m-%d %H:%M:%S"
- else:
- return bestFormatString(self.locator.spacing, self.locator.unit)
-
def __call__(self, x, pos=None):
"""Return the format for tick val *x* at position *pos*
- Expects x to be a POSIX timestamp (seconds since 1 Jan 1970)
+ Expects x to be a POSIX timestamp (seconds since 1 Jan 1970)
"""
- dateTime = dt.datetime.fromtimestamp(x, tz=self.tz)
- tickStr = dateTime.strftime(self.formatString)
- return tickStr
+ datetime = dt.datetime.fromtimestamp(x, tz=self.tz)
+ return formatDatetimes(
+ [datetime],
+ self.locator.spacing,
+ self.locator.unit,
+ )[datetime]
+
+ def format_ticks(self, values):
+ return tuple(
+ formatDatetimes(
+ [dt.datetime.fromtimestamp(value, tz=self.tz) for value in values],
+ self.locator.spacing,
+ self.locator.unit,
+ ).values()
+ )
class _PickableContainer(Container):
@@ -222,7 +258,7 @@ class _PickableContainer(Container):
def axes(self):
"""Mimin Artist.axes"""
for child in self.get_children():
- if hasattr(child, 'axes'):
+ if hasattr(child, "axes"):
return child.axes
return None
@@ -354,18 +390,19 @@ class _MarkerContainer(_PickableContainer):
:param yinverted: True if the y axis is inverted
"""
if self.text is not None:
- visible = ((self.x is None or xmin <= self.x <= xmax) and
- (self.y is None or ymin <= self.y <= ymax))
+ visible = (self.x is None or xmin <= self.x <= xmax) and (
+ self.y is None or ymin <= self.y <= ymax
+ )
self.text.set_visible(visible)
if self.x is not None and self.y is not None:
if self.symbol is None:
- valign = 'baseline'
+ valign = "baseline"
else:
if yinverted:
- valign = 'bottom'
+ valign = "bottom"
else:
- valign = 'top'
+ valign = "top"
self.text.set_verticalalignment(valign)
elif self.y is None: # vertical line
@@ -393,42 +430,47 @@ class _MarkerContainer(_PickableContainer):
return self.line.contains(mouseevent)
-class _DoubleColoredLinePatch(matplotlib.patches.Patch):
- """Matplotlib patch to display any patch using double color."""
+class SecondEdgeColorPatchMixIn:
+ """Mix-in class to add a second color for patches with dashed lines"""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._second_edgecolor = None
- def __init__(self, patch):
- super(_DoubleColoredLinePatch, self).__init__()
- self.__patch = patch
- self.linebgcolor = None
+ def set_second_edgecolor(self, color):
+ """Set the second color used to fill dashed edges"""
+ self._second_edgecolor = color
- def __getattr__(self, name):
- return getattr(self.__patch, name)
+ def get_second_edgecolor(self):
+ """Returns the second color used to fill dashed edges"""
+ return self._second_edgecolor
def draw(self, renderer):
- oldLineStype = self.__patch.get_linestyle()
- if self.linebgcolor is not None and oldLineStype != "solid":
- oldLineColor = self.__patch.get_edgecolor()
- oldHatch = self.__patch.get_hatch()
- self.__patch.set_linestyle("solid")
- self.__patch.set_edgecolor(self.linebgcolor)
- self.__patch.set_hatch(None)
- self.__patch.draw(renderer)
- self.__patch.set_linestyle(oldLineStype)
- self.__patch.set_edgecolor(oldLineColor)
- self.__patch.set_hatch(oldHatch)
- self.__patch.draw(renderer)
+ linestyle = self.get_linestyle()
+ if linestyle == "solid" or self.get_second_edgecolor() is None:
+ super().draw(renderer)
+ return
- def set_transform(self, transform):
- self.__patch.set_transform(transform)
+ edgecolor = self.get_edgecolor()
+ hatch = self.get_hatch()
- def get_path(self):
- return self.__patch.get_path()
+ self.set_linestyle("solid")
+ self.set_edgecolor(self.get_second_edgecolor())
+ self.set_hatch(None)
+ super().draw(renderer)
- def contains(self, mouseevent, radius=None):
- return self.__patch.contains(mouseevent, radius)
+ self.set_linestyle(linestyle)
+ self.set_edgecolor(edgecolor)
+ self.set_hatch(hatch)
+ super().draw(renderer)
- def contains_point(self, point, radius=None):
- return self.__patch.contains_point(point, radius)
+
+class Rectangle2EdgeColor(SecondEdgeColorPatchMixIn, Rectangle):
+ """Rectangle patch with a second edge color for dashed line"""
+
+
+class Polygon2EdgeColor(SecondEdgeColorPatchMixIn, Polygon):
+ """Polygon patch with a second edge color for dashed line"""
class Image(AxesImage):
@@ -438,10 +480,7 @@ class Image(AxesImage):
:param List[float] silx_scale: (sx, sy) Scale of the image.
"""
- def __init__(self, *args,
- silx_origin=(0., 0.),
- silx_scale=(1., 1.),
- **kwargs):
+ def __init__(self, *args, silx_origin=(0.0, 0.0), silx_scale=(1.0, 1.0), **kwargs):
super().__init__(*args, **kwargs)
self.__silx_origin = silx_origin
self.__silx_scale = silx_scale
@@ -456,7 +495,7 @@ class Image(AxesImage):
height, width = self.get_size()
column = numpy.clip(int((x - ox) / sx), 0, width - 1)
row = numpy.clip(int((y - oy) / sy), 0, height - 1)
- info['ind'] = (row,), (column,)
+ info["ind"] = (row,), (column,)
return inside, info
def set_data(self, A):
@@ -489,12 +528,17 @@ class BackendMatplotlib(BackendBase.BackendBase):
# when getting the limits at the expense of a replot
self._dirtyLimits = True
self._axesDisplayed = True
- self._matplotlibVersion = _parse_version(matplotlib.__version__)
+ self._matplotlibVersion = Version(matplotlib.__version__)
- self.fig = Figure()
+ self.fig = Figure(
+ tight_layout=config._MPL_TIGHT_LAYOUT,
+ )
self.fig.set_facecolor("w")
- self.ax = self.fig.add_axes([.15, .15, .75, .75], label="left")
+ if config._MPL_TIGHT_LAYOUT:
+ self.ax = self.fig.add_subplot(label="left")
+ else:
+ self.ax = self.fig.add_axes([0.15, 0.15, 0.75, 0.75], label="left")
self.ax2 = self.ax.twinx()
self.ax2.set_label("right")
# Make sure background of Axes is displayed
@@ -504,28 +548,17 @@ class BackendMatplotlib(BackendBase.BackendBase):
# Set axis zorder=0.5 so grid is displayed at 0.5
self.ax.set_axisbelow(True)
- # disable the use of offsets
- try:
- axes = [
- self.ax.get_yaxis().get_major_formatter(),
- self.ax.get_xaxis().get_major_formatter(),
- self.ax2.get_yaxis().get_major_formatter(),
- self.ax2.get_xaxis().get_major_formatter(),
- ]
- for axis in axes:
- axis.set_useOffset(False)
- axis.set_scientific(False)
- except:
- _logger.warning('Cannot disabled axes offsets in %s '
- % matplotlib.__version__)
+ # Configure axes tick label formatter
+ for axis in (self.ax.yaxis, self.ax.xaxis, self.ax2.yaxis, self.ax2.xaxis):
+ axis.set_major_formatter(DefaultTickFormatter())
self.ax2.set_autoscaley_on(True)
# this works but the figure color is left
- if self._matplotlibVersion < _parse_version('2'):
- self.ax.set_axis_bgcolor('none')
+ if self._matplotlibVersion < Version("2"):
+ self.ax.set_axis_bgcolor("none")
else:
- self.ax.set_facecolor('none')
+ self.ax.set_facecolor("none")
self.fig.sca(self.ax)
self._background = None
@@ -534,30 +567,33 @@ class BackendMatplotlib(BackendBase.BackendBase):
self._graphCursor = tuple()
- self._enableAxis('right', False)
+ self._enableAxis("right", False)
self._isXAxisTimeSeries = False
def getItemsFromBackToFront(self, condition=None):
"""Order as BackendBase + take into account matplotlib Axes structure"""
+
def axesOrder(item):
if item.isOverlay():
return 2
- elif isinstance(item, items.YAxisMixIn) and item.getYAxis() == 'right':
+ elif isinstance(item, items.YAxisMixIn) and item.getYAxis() == "right":
return 1
else:
return 0
return sorted(
- BackendBase.BackendBase.getItemsFromBackToFront(
- self, condition=condition),
- key=axesOrder)
+ BackendBase.BackendBase.getItemsFromBackToFront(self, condition=condition),
+ key=axesOrder,
+ )
def _overlayItems(self):
"""Generator of backend renderer for overlay items"""
for item in self._plot.getItems():
- if (item.isOverlay() and
- item.isVisible() and
- item._backendRenderer is not None):
+ if (
+ item.isOverlay()
+ and item.isVisible()
+ and item._backendRenderer is not None
+ ):
yield item._backendRenderer
def _hasOverlays(self):
@@ -591,19 +627,40 @@ class BackendMatplotlib(BackendBase.BackendBase):
# This symbol must be supported by matplotlib
return symbol
- def addCurve(self, x, y,
- color, symbol, linewidth, linestyle,
- yaxis,
- xerror, yerror,
- fill, alpha, symbolsize, baseline):
- for parameter in (x, y, color, symbol, linewidth, linestyle,
- yaxis, fill, alpha, symbolsize):
+ def addCurve(
+ self,
+ x,
+ y,
+ color,
+ gapcolor,
+ symbol,
+ linewidth,
+ linestyle,
+ yaxis,
+ xerror,
+ yerror,
+ fill,
+ alpha,
+ symbolsize,
+ baseline,
+ ):
+ for parameter in (
+ x,
+ y,
+ color,
+ symbol,
+ linewidth,
+ linestyle,
+ yaxis,
+ fill,
+ alpha,
+ symbolsize,
+ ):
assert parameter is not None
- assert yaxis in ('left', 'right')
+ assert yaxis in ("left", "right")
- if (len(color) == 4 and
- type(color[3]) in [type(1), numpy.uint8, numpy.int8]):
- color = numpy.array(color, dtype=numpy.float64) / 255.
+ if len(color) == 4 and type(color[3]) in [type(1), numpy.uint8, numpy.int8]:
+ color = numpy.array(color, dtype=numpy.float64) / 255.0
if yaxis == "right":
axes = self.ax2
@@ -617,50 +674,62 @@ class BackendMatplotlib(BackendBase.BackendBase):
# First add errorbars if any so they are behind the curve
if xerror is not None or yerror is not None:
- if hasattr(color, 'dtype') and len(color) == len(x):
- errorbarColor = 'k'
+ if hasattr(color, "dtype") and len(color) == len(x):
+ errorbarColor = "k"
else:
errorbarColor = color
# Nx1 error array deprecated in matplotlib >=3.1 (removed in 3.3)
- if (isinstance(xerror, numpy.ndarray) and xerror.ndim == 2 and
- xerror.shape[1] == 1):
+ if (
+ isinstance(xerror, numpy.ndarray)
+ and xerror.ndim == 2
+ and xerror.shape[1] == 1
+ ):
xerror = numpy.ravel(xerror)
- if (isinstance(yerror, numpy.ndarray) and yerror.ndim == 2 and
- yerror.shape[1] == 1):
+ if (
+ isinstance(yerror, numpy.ndarray)
+ and yerror.ndim == 2
+ and yerror.shape[1] == 1
+ ):
yerror = numpy.ravel(yerror)
- errorbars = axes.errorbar(x, y,
- xerr=xerror, yerr=yerror,
- linestyle=' ', color=errorbarColor)
+ errorbars = axes.errorbar(
+ x, y, xerr=xerror, yerr=yerror, linestyle=" ", color=errorbarColor
+ )
artists += list(errorbars.get_children())
- if hasattr(color, 'dtype') and len(color) == len(x):
+ if hasattr(color, "dtype") and len(color) == len(x):
# scatter plot
if color.dtype not in [numpy.float32, numpy.float64]:
- actualColor = color / 255.
+ actualColor = color / 255.0
else:
actualColor = color
if linestyle not in ["", " ", None]:
# scatter plot with an actual line ...
# we need to assign a color ...
- curveList = axes.plot(x, y,
- linestyle=linestyle,
- color=actualColor[0],
- linewidth=linewidth,
- picker=True,
- pickradius=pickradius,
- marker=None)
+ curveList = axes.plot(
+ x,
+ y,
+ linestyle=linestyle,
+ color=actualColor[0],
+ linewidth=linewidth,
+ picker=True,
+ pickradius=pickradius,
+ marker=None,
+ )
artists += list(curveList)
marker = self._getMarkerFromSymbol(symbol)
- scatter = axes.scatter(x, y,
- color=actualColor,
- marker=marker,
- picker=True,
- pickradius=pickradius,
- s=symbolsize**2)
+ scatter = axes.scatter(
+ x,
+ y,
+ color=actualColor,
+ marker=marker,
+ picker=True,
+ pickradius=pickradius,
+ s=symbolsize**2,
+ )
artists.append(scatter)
if fill:
@@ -668,18 +737,28 @@ class BackendMatplotlib(BackendBase.BackendBase):
_baseline = FLOAT32_MINPOS
else:
_baseline = baseline
- artists.append(axes.fill_between(
- x, _baseline, y, facecolor=actualColor[0], linestyle=''))
+ artists.append(
+ axes.fill_between(
+ x, _baseline, y, facecolor=actualColor[0], linestyle=""
+ )
+ )
else: # Curve
- curveList = axes.plot(x, y,
- linestyle=linestyle,
- color=color,
- linewidth=linewidth,
- marker=symbol,
- picker=True,
- pickradius=pickradius,
- markersize=symbolsize)
+ curveList = axes.plot(
+ x,
+ y,
+ linestyle=linestyle,
+ color=color,
+ linewidth=linewidth,
+ marker=symbol,
+ picker=True,
+ pickradius=pickradius,
+ markersize=symbolsize,
+ )
+
+ if gapcolor is not None and self._matplotlibVersion >= Version("3.6.0"):
+ for line2d in curveList:
+ line2d.set_gapcolor(gapcolor)
artists += list(curveList)
if fill:
@@ -687,8 +766,7 @@ class BackendMatplotlib(BackendBase.BackendBase):
_baseline = FLOAT32_MINPOS
else:
_baseline = baseline
- artists.append(
- axes.fill_between(x, _baseline, y, facecolor=color))
+ artists.append(axes.fill_between(x, _baseline, y, facecolor=color))
for artist in artists:
if alpha < 1:
@@ -709,12 +787,14 @@ class BackendMatplotlib(BackendBase.BackendBase):
height, width = data.shape[0:2]
# All image are shown as RGBA image
- image = Image(self.ax,
- interpolation='nearest',
- picker=True,
- origin='lower',
- silx_origin=origin,
- silx_scale=scale)
+ image = Image(
+ self.ax,
+ interpolation="nearest",
+ picker=True,
+ origin="lower",
+ silx_origin=origin,
+ silx_scale=scale,
+ )
if alpha < 1:
image.set_alpha(alpha)
@@ -722,21 +802,21 @@ class BackendMatplotlib(BackendBase.BackendBase):
# Set image extent
xmin = origin[0]
xmax = xmin + scale[0] * width
- if scale[0] < 0.:
+ if scale[0] < 0.0:
xmin, xmax = xmax, xmin
ymin = origin[1]
ymax = ymin + scale[1] * height
- if scale[1] < 0.:
+ if scale[1] < 0.0:
ymin, ymax = ymax, ymin
image.set_extent((xmin, xmax, ymin, ymax))
# Set image data
- if scale[0] < 0. or scale[1] < 0.:
+ if scale[0] < 0.0 or scale[1] < 0.0:
# For negative scale, step by -1
- xstep = 1 if scale[0] >= 0. else -1
- ystep = 1 if scale[1] >= 0. else -1
+ xstep = 1 if scale[0] >= 0.0 else -1
+ ystep = 1 if scale[1] >= 0.0 else -1
data = data[::ystep, ::xstep]
if data.ndim == 2: # Data image, convert to RGBA image
@@ -745,7 +825,7 @@ class BackendMatplotlib(BackendBase.BackendBase):
# Normalize uint16 data to have a similar behavior as opengl backend
data = data.astype(numpy.float32)
data /= 65535
-
+
image.set_data(data)
self.ax.add_artist(image)
return image
@@ -758,87 +838,92 @@ class BackendMatplotlib(BackendBase.BackendBase):
assert color.ndim == 2 and len(color) == len(x)
if color.dtype not in [numpy.float32, numpy.float64]:
- color = color.astype(numpy.float32) / 255.
+ color = color.astype(numpy.float32) / 255.0
collection = TriMesh(
- Triangulation(x, y, triangles),
- alpha=alpha,
- pickradius=0) # 0 enables picking on filled triangle
+ Triangulation(x, y, triangles), alpha=alpha, pickradius=0
+ ) # 0 enables picking on filled triangle
collection.set_color(color)
self.ax.add_collection(collection)
return collection
- def addShape(self, x, y, shape, color, fill, overlay,
- linestyle, linewidth, linebgcolor):
- if (linebgcolor is not None and
- shape not in ('rectangle', 'polygon', 'polylines')):
+ def addShape(
+ self, x, y, shape, color, fill, overlay, linestyle, linewidth, gapcolor
+ ):
+ if gapcolor is not None and shape not in (
+ "rectangle",
+ "polygon",
+ "polylines",
+ ):
_logger.warning(
- 'linebgcolor not implemented for %s with matplotlib backend',
- shape)
+ "gapcolor not implemented for %s with matplotlib backend", shape
+ )
xView = numpy.array(x, copy=False)
yView = numpy.array(y, copy=False)
linestyle = normalize_linestyle(linestyle)
if shape == "line":
- item = self.ax.plot(x, y, color=color,
- linestyle=linestyle, linewidth=linewidth,
- marker=None)[0]
+ item = self.ax.plot(
+ x, y, color=color, linestyle=linestyle, linewidth=linewidth, marker=None
+ )[0]
elif shape == "hline":
if hasattr(y, "__len__"):
y = y[-1]
- item = self.ax.axhline(y, color=color,
- linestyle=linestyle, linewidth=linewidth)
+ item = self.ax.axhline(
+ y, color=color, linestyle=linestyle, linewidth=linewidth
+ )
elif shape == "vline":
if hasattr(x, "__len__"):
x = x[-1]
- item = self.ax.axvline(x, color=color,
- linestyle=linestyle, linewidth=linewidth)
+ item = self.ax.axvline(
+ x, color=color, linestyle=linestyle, linewidth=linewidth
+ )
- elif shape == 'rectangle':
+ elif shape == "rectangle":
xMin = numpy.nanmin(xView)
xMax = numpy.nanmax(xView)
yMin = numpy.nanmin(yView)
yMax = numpy.nanmax(yView)
w = xMax - xMin
h = yMax - yMin
- item = Rectangle(xy=(xMin, yMin),
- width=w,
- height=h,
- fill=False,
- color=color,
- linestyle=linestyle,
- linewidth=linewidth)
- if fill:
- item.set_hatch('.')
+ item = Rectangle2EdgeColor(
+ xy=(xMin, yMin),
+ width=w,
+ height=h,
+ fill=False,
+ color=color,
+ linestyle=linestyle,
+ linewidth=linewidth,
+ )
+ item.set_second_edgecolor(gapcolor)
- if linestyle != "solid" and linebgcolor is not None:
- item = _DoubleColoredLinePatch(item)
- item.linebgcolor = linebgcolor
+ if fill:
+ item.set_hatch(".")
self.ax.add_patch(item)
- elif shape in ('polygon', 'polylines'):
+ elif shape in ("polygon", "polylines"):
points = numpy.array((xView, yView)).T
- if shape == 'polygon':
+ if shape == "polygon":
closed = True
else: # shape == 'polylines'
closed = numpy.all(numpy.equal(points[0], points[-1]))
- item = Polygon(points,
- closed=closed,
- fill=False,
- color=color,
- linestyle=linestyle,
- linewidth=linewidth)
- if fill and shape == 'polygon':
- item.set_hatch('/')
-
- if linestyle != "solid" and linebgcolor is not None:
- item = _DoubleColoredLinePatch(item)
- item.linebgcolor = linebgcolor
+ item = Polygon2EdgeColor(
+ points,
+ closed=closed,
+ fill=False,
+ color=color,
+ linestyle=linestyle,
+ linewidth=linewidth,
+ )
+ item.set_second_edgecolor(gapcolor)
+
+ if fill and shape == "polygon":
+ item.set_hatch("/")
self.ax.add_patch(item)
@@ -850,61 +935,87 @@ class BackendMatplotlib(BackendBase.BackendBase):
return item
- def addMarker(self, x, y, text, color,
- symbol, linestyle, linewidth, constraint, yaxis):
+ def addMarker(
+ self,
+ x,
+ y,
+ text,
+ color,
+ symbol,
+ linestyle,
+ linewidth,
+ constraint,
+ yaxis,
+ font,
+ bgcolor: RGBAColorType | None,
+ ):
textArtist = None
+ fontProperties = None if font is None else qFontToFontProperties(font)
xmin, xmax = self.getGraphXLimits()
ymin, ymax = self.getGraphYLimits(axis=yaxis)
- if yaxis == 'left':
+ if yaxis == "left":
ax = self.ax
- elif yaxis == 'right':
+ elif yaxis == "right":
ax = self.ax2
else:
- assert(False)
+ assert False
+
+ if bgcolor is None:
+ bgcolor = "none"
marker = self._getMarkerFromSymbol(symbol)
if x is not None and y is not None:
- line = ax.plot(x, y,
- linestyle=" ",
- color=color,
- marker=marker,
- markersize=10.)[-1]
+ line = ax.plot(
+ x, y, linestyle=" ", color=color, marker=marker, markersize=10.0
+ )[-1]
if text is not None:
- textArtist = _TextWithOffset(x, y, text,
- color=color,
- horizontalalignment='left')
+ textArtist = _TextWithOffset(
+ x,
+ y,
+ text,
+ color=color,
+ backgroundcolor=bgcolor,
+ horizontalalignment="left",
+ fontproperties=fontProperties,
+ )
if symbol is not None:
textArtist.pixel_offset = 10, 3
elif x is not None:
- line = ax.axvline(x,
- color=color,
- linewidth=linewidth,
- linestyle=linestyle)
+ line = ax.axvline(x, color=color, linewidth=linewidth, linestyle=linestyle)
if text is not None:
# Y position will be updated in updateMarkerText call
- textArtist = _TextWithOffset(x, 1., text,
- color=color,
- horizontalalignment='left',
- verticalalignment='top')
+ textArtist = _TextWithOffset(
+ x,
+ 1.0,
+ text,
+ color=color,
+ backgroundcolor=bgcolor,
+ horizontalalignment="left",
+ verticalalignment="top",
+ fontproperties=fontProperties,
+ )
textArtist.pixel_offset = 5, 3
elif y is not None:
- line = ax.axhline(y,
- color=color,
- linewidth=linewidth,
- linestyle=linestyle)
+ line = ax.axhline(y, color=color, linewidth=linewidth, linestyle=linestyle)
if text is not None:
# X position will be updated in updateMarkerText call
- textArtist = _TextWithOffset(1., y, text,
- color=color,
- horizontalalignment='right',
- verticalalignment='top')
+ textArtist = _TextWithOffset(
+ 1.0,
+ y,
+ text,
+ color=color,
+ backgroundcolor=bgcolor,
+ horizontalalignment="right",
+ verticalalignment="top",
+ fontproperties=fontProperties,
+ )
textArtist.pixel_offset = 5, 3
else:
- raise RuntimeError('A marker must at least have one coordinate')
+ raise RuntimeError("A marker must at least have one coordinate")
line.set_picker(True)
line.set_pickradius(5)
@@ -928,7 +1039,7 @@ class BackendMatplotlib(BackendBase.BackendBase):
yinverted = self.isYAxisInverted()
for item in self._overlayItems():
if isinstance(item, _MarkerContainer):
- if item.yAxis == 'left':
+ if item.yAxis == "left":
item.updateMarkerText(xmin, xmax, ymin1, ymax1, yinverted)
else:
item.updateMarkerText(xmin, xmax, ymin2, ymax2, yinverted)
@@ -946,13 +1057,21 @@ class BackendMatplotlib(BackendBase.BackendBase):
def setGraphCursor(self, flag, color, linewidth, linestyle):
if flag:
lineh = self.ax.axhline(
- self.ax.get_ybound()[0], visible=False, color=color,
- linewidth=linewidth, linestyle=linestyle)
+ self.ax.get_ybound()[0],
+ visible=False,
+ color=color,
+ linewidth=linewidth,
+ linestyle=linestyle,
+ )
lineh.set_animated(True)
linev = self.ax.axvline(
- self.ax.get_xbound()[0], visible=False, color=color,
- linewidth=linewidth, linestyle=linestyle)
+ self.ax.get_xbound()[0],
+ visible=False,
+ color=color,
+ linewidth=linewidth,
+ linestyle=linestyle,
+ )
linev.set_animated(True)
self._graphCursor = lineh, linev
@@ -974,8 +1093,7 @@ class BackendMatplotlib(BackendBase.BackendBase):
artist.set_facecolors(color)
artist.set_edgecolors(color)
else:
- _logger.warning(
- 'setActiveCurve ignoring artist %s', str(artist))
+ _logger.warning("setActiveCurve ignoring artist %s", str(artist))
# Misc.
@@ -988,8 +1106,8 @@ class BackendMatplotlib(BackendBase.BackendBase):
:param str axis: Axis name: 'left' or 'right'
:param bool flag: Default, True
"""
- assert axis in ('right', 'left')
- axes = self.ax2 if axis == 'right' else self.ax
+ assert axis in ("right", "left")
+ axes = self.ax2 if axis == "right" else self.ax
axes.get_yaxis().set_visible(flag)
def replot(self):
@@ -1007,18 +1125,20 @@ class BackendMatplotlib(BackendBase.BackendBase):
# Hide right Y axis if no line is present
self._dirtyLimits = False
if not self.ax2.lines:
- self._enableAxis('right', False)
+ self._enableAxis("right", False)
def _drawOverlays(self):
"""Draw overlays if any."""
+
def condition(item):
- return (item.isVisible() and
- item._backendRenderer is not None and
- item.isOverlay())
+ return (
+ item.isVisible()
+ and item._backendRenderer is not None
+ and item.isOverlay()
+ )
for item in self.getItemsFromBackToFront(condition=condition):
- if (isinstance(item, items.YAxisMixIn) and
- item.getYAxis() == 'right'):
+ if isinstance(item, items.YAxisMixIn) and item.getYAxis() == "right":
axes = self.ax2
else:
axes = self.ax
@@ -1030,14 +1150,15 @@ class BackendMatplotlib(BackendBase.BackendBase):
def updateZOrder(self):
"""Reorder all items with z order from 0 to 1"""
items = self.getItemsFromBackToFront(
- lambda item: item.isVisible() and item._backendRenderer is not None)
+ lambda item: item.isVisible() and item._backendRenderer is not None
+ )
count = len(items)
for index, item in enumerate(items):
if item.getZValue() < 0.5:
# Make sure matplotlib z order is below the grid (with z=0.5)
zorder = 0.5 * index / count
else: # Make sure matplotlib z order is above the grid (> 0.5)
- zorder = 1. + index / count
+ zorder = 1.0 + index / count
if zorder != item._backendRenderer.get_zorder():
item._backendRenderer.set_zorder(zorder)
@@ -1060,7 +1181,7 @@ class BackendMatplotlib(BackendBase.BackendBase):
self.ax.set_xlabel(label)
def setGraphYLabel(self, label, axis):
- axes = self.ax if axis == 'left' else self.ax2
+ axes = self.ax if axis == "left" else self.ax2
axes.set_ylabel(label)
# Graph limits
@@ -1096,8 +1217,8 @@ class BackendMatplotlib(BackendBase.BackendBase):
self._updateMarkers()
def getGraphYLimits(self, axis):
- assert axis in ('left', 'right')
- ax = self.ax2 if axis == 'right' else self.ax
+ assert axis in ("left", "right")
+ ax = self.ax2 if axis == "right" else self.ax
if not ax.get_visible():
return None
@@ -1110,7 +1231,7 @@ class BackendMatplotlib(BackendBase.BackendBase):
return ax.get_ybound()
def setGraphYLimits(self, ymin, ymax, axis):
- ax = self.ax2 if axis == 'right' else self.ax
+ ax = self.ax2 if axis == "right" else self.ax
if ymax < ymin:
ymin, ymax = ymax, ymin
self._dirtyLimits = True
@@ -1137,6 +1258,23 @@ class BackendMatplotlib(BackendBase.BackendBase):
# Graph axes
+ def __initXAxisFormatterAndLocator(self):
+ if self.ax.xaxis.get_scale() != "linear":
+ return # Do not override formatter and locator
+
+ if not self.isXAxisTimeSeries():
+ self.ax.xaxis.set_major_formatter(DefaultTickFormatter())
+ return
+
+ # We can't use a matplotlib.dates.DateFormatter because it expects
+ # the data to be in datetimes. Silx works internally with
+ # timestamps (floats).
+ locator = NiceDateLocator(tz=self.getXAxisTimeZone())
+ self.ax.xaxis.set_major_locator(locator)
+ self.ax.xaxis.set_major_formatter(
+ NiceAutoDateFormatter(locator, tz=self.getXAxisTimeZone())
+ )
+
def setXAxisTimeZone(self, tz):
super(BackendMatplotlib, self).setXAxisTimeZone(tz)
@@ -1148,40 +1286,27 @@ class BackendMatplotlib(BackendBase.BackendBase):
def setXAxisTimeSeries(self, isTimeSeries):
self._isXAxisTimeSeries = isTimeSeries
- if self._isXAxisTimeSeries:
- # We can't use a matplotlib.dates.DateFormatter because it expects
- # the data to be in datetimes. Silx works internally with
- # timestamps (floats).
- locator = NiceDateLocator(tz=self.getXAxisTimeZone())
- self.ax.xaxis.set_major_locator(locator)
- self.ax.xaxis.set_major_formatter(
- NiceAutoDateFormatter(locator, tz=self.getXAxisTimeZone()))
- else:
- try:
- scalarFormatter = ScalarFormatter(useOffset=False)
- except:
- _logger.warning('Cannot disabled axes offsets in %s ' %
- matplotlib.__version__)
- scalarFormatter = ScalarFormatter()
- self.ax.xaxis.set_major_formatter(scalarFormatter)
+ self.__initXAxisFormatterAndLocator()
def setXAxisLogarithmic(self, flag):
# Workaround for matplotlib 2.1.0 when one tries to set an axis
# to log scale with both limits <= 0
# In this case a draw with positive limits is needed first
- if flag and self._matplotlibVersion >= _parse_version('2.1.0'):
+ if flag and self._matplotlibVersion >= Version("2.1.0"):
xlim = self.ax.get_xlim()
if xlim[0] <= 0 and xlim[1] <= 0:
self.ax.set_xlim(1, 10)
self.draw()
- self.ax2.set_xscale('log' if flag else 'linear')
- self.ax.set_xscale('log' if flag else 'linear')
+ xscale = "log" if flag else "linear"
+ self.ax2.set_xscale(xscale)
+ self.ax.set_xscale(xscale)
+ self.__initXAxisFormatterAndLocator()
def setYAxisLogarithmic(self, flag):
# Workaround for matplotlib 2.0 issue with negative bounds
# before switching to log scale
- if flag and self._matplotlibVersion >= _parse_version('2.0.0'):
+ if flag and self._matplotlibVersion >= Version("2.0.0"):
redraw = False
for axis, dataRangeIndex in ((self.ax, 1), (self.ax2, 2)):
ylim = axis.get_ylim()
@@ -1194,8 +1319,15 @@ class BackendMatplotlib(BackendBase.BackendBase):
if redraw:
self.draw()
- self.ax2.set_yscale('log' if flag else 'linear')
- self.ax.set_yscale('log' if flag else 'linear')
+ if flag:
+ self.ax2.set_yscale("log")
+ self.ax.set_yscale("log")
+ return
+
+ self.ax2.set_yscale("linear")
+ self.ax2.yaxis.set_major_formatter(DefaultTickFormatter())
+ self.ax.set_yscale("linear")
+ self.ax.yaxis.set_major_formatter(DefaultTickFormatter())
def setYAxisInverted(self, flag):
if self.ax.yaxis_inverted() != bool(flag):
@@ -1205,15 +1337,18 @@ class BackendMatplotlib(BackendBase.BackendBase):
def isYAxisInverted(self):
return self.ax.yaxis_inverted()
+ def isYRightAxisVisible(self):
+ return self.ax2.yaxis.get_visible()
+
def isKeepDataAspectRatio(self):
- return self.ax.get_aspect() in (1.0, 'equal')
+ return self.ax.get_aspect() in (1.0, "equal")
def setKeepDataAspectRatio(self, flag):
- self.ax.set_aspect(1.0 if flag else 'auto')
- self.ax2.set_aspect(1.0 if flag else 'auto')
+ self.ax.set_aspect(1.0 if flag else "auto")
+ self.ax2.set_aspect(1.0 if flag else "auto")
def setGraphGrid(self, which):
- self.ax.grid(False, which='both') # Disable all grid first
+ self.ax.grid(False, which="both") # Disable all grid first
if which is not None:
self.ax.grid(True, which=which)
@@ -1221,23 +1356,19 @@ class BackendMatplotlib(BackendBase.BackendBase):
def _getDevicePixelRatio(self) -> float:
"""Compatibility wrapper for devicePixelRatioF"""
- return 1.
+ return 1.0
def _mplToQtPosition(
- self,
- x: Union[float,numpy.ndarray],
- y: Union[float,numpy.ndarray]
- ) -> Tuple[Union[float,numpy.ndarray], Union[float,numpy.ndarray]]:
- """Convert matplotlib "display" space coord to Qt widget logical pixel
- """
+ self, x: Union[float, numpy.ndarray], y: Union[float, numpy.ndarray]
+ ) -> Tuple[Union[float, numpy.ndarray], Union[float, numpy.ndarray]]:
+ """Convert matplotlib "display" space coord to Qt widget logical pixel"""
ratio = self._getDevicePixelRatio()
# Convert from matplotlib origin (bottom) to Qt origin (top)
# and apply device pixel ratio
return x / ratio, (self.fig.get_window_extent().height - y) / ratio
def _qtToMplPosition(self, x: float, y: float) -> Tuple[float, float]:
- """Convert Qt widget logical pixel to matplotlib "display" space coord
- """
+ """Convert Qt widget logical pixel to matplotlib "display" space coord"""
ratio = self._getDevicePixelRatio()
# Apply device pixel ration and
# convert from Qt origin (top) to matplotlib origin (bottom)
@@ -1258,18 +1389,33 @@ class BackendMatplotlib(BackendBase.BackendBase):
bbox = self.ax.get_window_extent()
# Warning this is not returning int...
ratio = self._getDevicePixelRatio()
- return tuple(int(value / ratio) for value in (
- bbox.xmin,
- self.fig.get_window_extent().height - bbox.ymax,
- bbox.width,
- bbox.height))
+ return tuple(
+ int(value / ratio)
+ for value in (
+ bbox.xmin,
+ self.fig.get_window_extent().height - bbox.ymax,
+ bbox.width,
+ bbox.height,
+ )
+ )
def setAxesMargins(self, left: float, top: float, right: float, bottom: float):
- width, height = 1. - left - right, 1. - top - bottom
+ width, height = 1.0 - left - right, 1.0 - top - bottom
position = left, bottom, width, height
+ istight = config._MPL_TIGHT_LAYOUT and (left, top, right, bottom) != (
+ 0,
+ 0,
+ 0,
+ 0,
+ )
+ if self._matplotlibVersion >= Version("3.6"):
+ self.fig.set_layout_engine("tight" if istight else None)
+ else:
+ self.fig.set_tight_layout(True if istight else None)
+
# Toggle display of axes and viewbox rect
- isFrameOn = position != (0., 0., 1., 1.)
+ isFrameOn = position != (0.0, 0.0, 1.0, 1.0)
self.ax.set_frame_on(isFrameOn)
self.ax2.set_frame_on(isFrameOn)
@@ -1291,7 +1437,7 @@ class BackendMatplotlib(BackendBase.BackendBase):
if self.ax.get_frame_on():
self.fig.patch.set_facecolor(backgroundColor)
- if self._matplotlibVersion < _parse_version('2'):
+ if self._matplotlibVersion < Version("2"):
self.ax.set_axis_bgcolor(dataBackgroundColor)
else:
self.ax.set_facecolor(dataBackgroundColor)
@@ -1309,12 +1455,12 @@ class BackendMatplotlib(BackendBase.BackendBase):
for axes in (self.ax, self.ax2):
if axes.get_frame_on():
- axes.spines['bottom'].set_color(foregroundColor)
- axes.spines['top'].set_color(foregroundColor)
- axes.spines['right'].set_color(foregroundColor)
- axes.spines['left'].set_color(foregroundColor)
- axes.tick_params(axis='x', colors=foregroundColor)
- axes.tick_params(axis='y', colors=foregroundColor)
+ axes.spines["bottom"].set_color(foregroundColor)
+ axes.spines["top"].set_color(foregroundColor)
+ axes.spines["right"].set_color(foregroundColor)
+ axes.spines["left"].set_color(foregroundColor)
+ axes.tick_params(axis="x", colors=foregroundColor)
+ axes.tick_params(axis="y", colors=foregroundColor)
axes.yaxis.label.set_color(foregroundColor)
axes.xaxis.label.set_color(foregroundColor)
axes.title.set_color(foregroundColor)
@@ -1350,19 +1496,19 @@ class BackendMatplotlibQt(BackendMatplotlib, FigureCanvasQTAgg):
self._limitsBeforeResize = None
FigureCanvasQTAgg.setSizePolicy(
- self, qt.QSizePolicy.Expanding, qt.QSizePolicy.Expanding)
+ self, qt.QSizePolicy.Expanding, qt.QSizePolicy.Expanding
+ )
FigureCanvasQTAgg.updateGeometry(self)
# Make postRedisplay asynchronous using Qt signal
- self._sigPostRedisplay.connect(
- self.__deferredReplot, qt.Qt.QueuedConnection)
+ self._sigPostRedisplay.connect(self.__deferredReplot, qt.Qt.QueuedConnection)
self._picked = None
- self.mpl_connect('button_press_event', self._onMousePress)
- self.mpl_connect('button_release_event', self._onMouseRelease)
- self.mpl_connect('motion_notify_event', self._onMouseMove)
- self.mpl_connect('scroll_event', self._onMouseWheel)
+ self.mpl_connect("button_press_event", self._onMousePress)
+ self.mpl_connect("button_release_event", self._onMouseRelease)
+ self.mpl_connect("motion_notify_event", self._onMouseMove)
+ self.mpl_connect("scroll_event", self._onMouseWheel)
def postRedisplay(self):
self._sigPostRedisplay.emit()
@@ -1370,23 +1516,21 @@ class BackendMatplotlibQt(BackendMatplotlib, FigureCanvasQTAgg):
def __deferredReplot(self):
# Since this is deferred, makes sure it is still needed
plot = self._plotRef()
- if (plot is not None and
- plot._getDirtyPlot() and
- plot.getBackend() is self):
+ if plot is not None and plot._getDirtyPlot() and plot.getBackend() is self:
self.replot()
def _getDevicePixelRatio(self) -> float:
"""Compatibility wrapper for devicePixelRatioF"""
- if hasattr(self, 'devicePixelRatioF'):
+ if hasattr(self, "devicePixelRatioF"):
ratio = self.devicePixelRatioF()
else: # Qt < 5.6 compatibility
ratio = float(self.devicePixelRatio())
# Safety net: avoid returning 0
- return ratio if ratio != 0. else 1.
+ return ratio if ratio != 0.0 else 1.0
# Mouse event forwarding
- _MPL_TO_PLOT_BUTTONS = {1: 'left', 2: 'middle', 3: 'right'}
+ _MPL_TO_PLOT_BUTTONS = {1: "left", 2: "middle", 3: "right"}
def _onMousePress(self, event):
button = self._MPL_TO_PLOT_BUTTONS.get(event.button, None)
@@ -1397,8 +1541,7 @@ class BackendMatplotlibQt(BackendMatplotlib, FigureCanvasQTAgg):
def _onMouseMove(self, event):
x, y = self._mplToQtPosition(event.x, event.y)
if self._graphCursor:
- position = self._plot.pixelToData(
- x, y, axis='left', check=True)
+ position = self._plot.pixelToData(x, y, axis="left", check=True)
lineh, linev = self._graphCursor
if position is not None:
linev.set_visible(True)
@@ -1407,9 +1550,9 @@ class BackendMatplotlibQt(BackendMatplotlib, FigureCanvasQTAgg):
lineh.set_ydata((position[1], position[1]))
self._plot._setDirtyPlot(overlayOnly=True)
elif lineh.get_visible():
- lineh.set_visible(False)
- linev.set_visible(False)
- self._plot._setDirtyPlot(overlayOnly=True)
+ lineh.set_visible(False)
+ linev.set_visible(False)
+ self._plot._setDirtyPlot(overlayOnly=True)
# onMouseMove must trigger replot if dirty flag is raised
self._plot.onMouseMove(int(x), int(y))
@@ -1438,11 +1581,13 @@ class BackendMatplotlibQt(BackendMatplotlib, FigureCanvasQTAgg):
def pickItem(self, x, y, item):
xDisplay, yDisplay = self._qtToMplPosition(x, y)
mouseEvent = MouseEvent(
- 'button_press_event', self, int(xDisplay), int(yDisplay))
+ "button_press_event", self, int(xDisplay), int(yDisplay)
+ )
# Override axes and data position with the axes
mouseEvent.inaxes = item.axes
mouseEvent.xdata, mouseEvent.ydata = self.pixelToData(
- x, y, axis='left' if item.axes is self.ax else 'right')
+ x, y, axis="left" if item.axes is self.ax else "right"
+ )
picked, info = item.contains(mouseEvent)
if not picked:
@@ -1451,26 +1596,30 @@ class BackendMatplotlibQt(BackendMatplotlib, FigureCanvasQTAgg):
elif isinstance(item, TriMesh):
# Convert selected triangle to data point indices
triangulation = item._triangulation
- indices = triangulation.get_masked_triangles()[info['ind'][0]]
+ indices = triangulation.get_masked_triangles()[info["ind"][0]]
# Sort picked triangle points by distance to mouse
# from furthest to closest to put closest point last
# This is to be somewhat consistent with last scatter point
# being the top one.
- xdata, ydata = self.pixelToData(x, y, axis='left')
- dists = ((triangulation.x[indices] - xdata) ** 2 +
- (triangulation.y[indices] - ydata) ** 2)
+ xdata, ydata = self.pixelToData(x, y, axis="left")
+ dists = (triangulation.x[indices] - xdata) ** 2 + (
+ triangulation.y[indices] - ydata
+ ) ** 2
return indices[numpy.flip(numpy.argsort(dists), axis=0)]
else: # Returns indices if any
- return info.get('ind', ())
+ return info.get("ind", ())
# replot control
def resizeEvent(self, event):
# Store current limits
self._limitsBeforeResize = (
- self.ax.get_xbound(), self.ax.get_ybound(), self.ax2.get_ybound())
+ self.ax.get_xbound(),
+ self.ax.get_ybound(),
+ self.ax2.get_ybound(),
+ )
FigureCanvasQTAgg.resizeEvent(self, event)
if self.isKeepDataAspectRatio() or self._hasOverlays():
@@ -1490,15 +1639,20 @@ class BackendMatplotlibQt(BackendMatplotlib, FigureCanvasQTAgg):
self.updateZOrder()
+ if not qt_inspect.isValid(self):
+ _logger.info("draw requested but widget no longer exists")
+ return
+
# Starting with mpl 2.1.0, toggling autoscale raises a ValueError
# in some situations. See #1081, #1136, #1163,
- if self._matplotlibVersion >= _parse_version("2.0.0"):
+ if self._matplotlibVersion >= Version("2.0.0"):
try:
FigureCanvasQTAgg.draw(self)
except ValueError as err:
_logger.debug(
- "ValueError caught while calling FigureCanvasQTAgg.draw: "
- "'%s'", err)
+ "ValueError caught while calling FigureCanvasQTAgg.draw: " "'%s'",
+ err,
+ )
else:
FigureCanvasQTAgg.draw(self)
@@ -1513,16 +1667,15 @@ class BackendMatplotlibQt(BackendMatplotlib, FigureCanvasQTAgg):
xLimits, yLimits, yRightLimits = self._limitsBeforeResize
self._limitsBeforeResize = None
- if (xLimits != self.ax.get_xbound() or
- yLimits != self.ax.get_ybound()):
+ if xLimits != self.ax.get_xbound() or yLimits != self.ax.get_ybound():
self._updateMarkers()
if xLimits != self.ax.get_xbound():
self._plot.getXAxis()._emitLimitsChanged()
if yLimits != self.ax.get_ybound():
- self._plot.getYAxis(axis='left')._emitLimitsChanged()
+ self._plot.getYAxis(axis="left")._emitLimitsChanged()
if yRightLimits != self.ax2.get_ybound():
- self._plot.getYAxis(axis='right')._emitLimitsChanged()
+ self._plot.getYAxis(axis="right")._emitLimitsChanged()
self._drawOverlays()
@@ -1536,7 +1689,7 @@ class BackendMatplotlibQt(BackendMatplotlib, FigureCanvasQTAgg):
dirtyFlag = self._plot._getDirtyPlot()
- if dirtyFlag == 'overlay':
+ if dirtyFlag == "overlay":
# Only redraw overlays using fast rendering path
if self._background is None:
self._background = self.copy_from_bbox(self.fig.bbox)
@@ -1548,8 +1701,9 @@ class BackendMatplotlibQt(BackendMatplotlib, FigureCanvasQTAgg):
self.draw()
# Workaround issue of rendering overlays with some matplotlib versions
- if (_parse_version('1.5') <= self._matplotlibVersion < _parse_version('2.1') and
- not hasattr(self, '_firstReplot')):
+ if Version("1.5") <= self._matplotlibVersion < Version(
+ "2.1"
+ ) and not hasattr(self, "_firstReplot"):
self._firstReplot = False
if self._hasOverlays():
qt.QTimer.singleShot(0, self.draw) # Request async draw
diff --git a/src/silx/gui/plot/backends/BackendOpenGL.py b/src/silx/gui/plot/backends/BackendOpenGL.py
index d7e8346..370f14b 100755
--- a/src/silx/gui/plot/backends/BackendOpenGL.py
+++ b/src/silx/gui/plot/backends/BackendOpenGL.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2014-2022 European Synchrotron Radiation Facility
+# Copyright (c) 2014-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -23,6 +23,8 @@
# ############################################################################*/
"""OpenGL Plot backend."""
+from __future__ import annotations
+
__authors__ = ["T. Vincent"]
__license__ = "MIT"
__date__ = "21/12/2018"
@@ -42,6 +44,7 @@ from ..._glutils import gl
from ... import _glutils as glu
from . import glutils
from .glutils.PlotImageFile import saveImageToFile
+from silx.gui.colors import RGBAColorType
_logger = logging.getLogger(__name__)
@@ -52,64 +55,95 @@ _logger = logging.getLogger(__name__)
# Content #####################################################################
+
class _ShapeItem(dict):
- def __init__(self, x, y, shape, color, fill, overlay,
- linestyle, linewidth, linebgcolor):
+ def __init__(
+ self,
+ x,
+ y,
+ shape,
+ color,
+ fill,
+ overlay,
+ linewidth,
+ dashoffset,
+ dashpattern,
+ gapcolor,
+ ):
super(_ShapeItem, self).__init__()
- if shape not in ('polygon', 'rectangle', 'line',
- 'vline', 'hline', 'polylines'):
+ if shape not in ("polygon", "rectangle", "line", "vline", "hline", "polylines"):
raise NotImplementedError("Unsupported shape {0}".format(shape))
x = numpy.array(x, copy=False)
y = numpy.array(y, copy=False)
- if shape == 'rectangle':
+ if shape == "rectangle":
xMin, xMax = x
x = numpy.array((xMin, xMin, xMax, xMax))
yMin, yMax = y
y = numpy.array((yMin, yMax, yMax, yMin))
# Ignore fill for polylines to mimic matplotlib
- fill = fill if shape != 'polylines' else False
-
- self.update({
- 'shape': shape,
- 'color': colors.rgba(color),
- 'fill': 'hatch' if fill else None,
- 'x': x,
- 'y': y,
- 'linestyle': linestyle,
- 'linewidth': linewidth,
- 'linebgcolor': linebgcolor,
- })
+ fill = fill if shape != "polylines" else False
+
+ self.update(
+ {
+ "shape": shape,
+ "color": colors.rgba(color),
+ "fill": "hatch" if fill else None,
+ "x": x,
+ "y": y,
+ "linewidth": linewidth,
+ "dashoffset": dashoffset,
+ "dashpattern": dashpattern,
+ "gapcolor": gapcolor,
+ }
+ )
class _MarkerItem(dict):
- def __init__(self, x, y, text, color,
- symbol, linestyle, linewidth, constraint, yaxis):
+ def __init__(
+ self,
+ x,
+ y,
+ text,
+ color,
+ symbol,
+ linewidth,
+ dashoffset,
+ dashpattern,
+ constraint,
+ yaxis,
+ font,
+ bgcolor,
+ ):
super(_MarkerItem, self).__init__()
if symbol is None:
- symbol = '+'
+ symbol = "+"
# Apply constraint to provided position
- isConstraint = (constraint is not None and
- x is not None and y is not None)
+ isConstraint = constraint is not None and x is not None and y is not None
if isConstraint:
x, y = constraint(x, y)
- self.update({
- 'x': x,
- 'y': y,
- 'text': text,
- 'color': colors.rgba(color),
- 'constraint': constraint if isConstraint else None,
- 'symbol': symbol,
- 'linestyle': linestyle,
- 'linewidth': linewidth,
- 'yaxis': yaxis,
- })
+ self.update(
+ {
+ "x": x,
+ "y": y,
+ "text": text,
+ "color": colors.rgba(color),
+ "constraint": constraint if isConstraint else None,
+ "symbol": symbol,
+ "linewidth": linewidth,
+ "dashoffset": dashoffset,
+ "dashpattern": dashpattern,
+ "yaxis": yaxis,
+ "font": font,
+ "bgcolor": bgcolor,
+ }
+ )
# shaders #####################################################################
@@ -193,26 +227,30 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
So, the caller should not modify these arrays afterwards.
"""
+ _TEXT_MARKER_PADDING = 4
+
def __init__(self, plot, parent=None, f=qt.Qt.Widget):
- glu.OpenGLWidget.__init__(self, parent,
- alphaBufferSize=8,
- depthBufferSize=0,
- stencilBufferSize=0,
- version=(2, 1),
- f=f)
+ glu.OpenGLWidget.__init__(
+ self,
+ parent,
+ alphaBufferSize=8,
+ depthBufferSize=0,
+ stencilBufferSize=0,
+ version=(2, 1),
+ f=f,
+ )
BackendBase.BackendBase.__init__(self, plot, parent)
+ self._defaultFont: qt.QFont = None
self.__isOpenGLValid = False
- self._backgroundColor = 1., 1., 1., 1.
- self._dataBackgroundColor = 1., 1., 1., 1.
+ self._backgroundColor = 1.0, 1.0, 1.0, 1.0
+ self._dataBackgroundColor = 1.0, 1.0, 1.0, 1.0
self.matScreenProj = glutils.mat4Identity()
- self._progBase = glu.Program(
- _baseVertShd, _baseFragShd, attrib0='position')
- self._progTex = glu.Program(
- _texVertShd, _texFragShd, attrib0='position')
+ self._progBase = glu.Program(_baseVertShd, _baseFragShd, attrib0="position")
+ self._progTex = glu.Program(_texVertShd, _texFragShd, attrib0="position")
self._plotFBOs = weakref.WeakKeyDictionary()
self._keepDataAspectRatio = False
@@ -223,12 +261,15 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
self._glGarbageCollector = []
self._plotFrame = glutils.GLPlotFrame2D(
- foregroundColor=(0., 0., 0., 1.),
- gridColor=(.7, .7, .7, 1.),
- marginRatios=(.15, .1, .1, .15))
+ foregroundColor=(0.0, 0.0, 0.0, 1.0),
+ gridColor=(0.7, 0.7, 0.7, 1.0),
+ marginRatios=(0.15, 0.1, 0.1, 0.15),
+ font=self.getDefaultFont(),
+ )
self._plotFrame.size = ( # Init size with size int
int(self.getDevicePixelRatio() * 640),
- int(self.getDevicePixelRatio() * 480))
+ int(self.getDevicePixelRatio() * 480),
+ )
self.setAutoFillBackground(False)
self.setMouseTracking(True)
@@ -236,9 +277,9 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
# QWidget
_MOUSE_BTNS = {
- qt.Qt.LeftButton: 'left',
- qt.Qt.RightButton: 'right',
- qt.Qt.MiddleButton: 'middle',
+ qt.Qt.LeftButton: "left",
+ qt.Qt.RightButton: "right",
+ qt.Qt.MiddleButton: "middle",
}
def sizeHint(self):
@@ -262,8 +303,10 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
else:
self._mousePosInPixels = None # Mouse outside plot area
- if (self._crosshairCursor is not None and
- previousMousePosInPixels != self._mousePosInPixels):
+ if (
+ self._crosshairCursor is not None
+ and previousMousePosInPixels != self._mousePosInPixels
+ ):
# Avoid replot when cursor remains outside plot area
self._plot._setDirtyPlot(overlayOnly=True)
@@ -279,7 +322,7 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
def wheelEvent(self, event):
delta = event.angleDelta().y()
- angleInDegrees = delta / 8.
+ angleInDegrees = delta / 8.0
x, y = qt.getMouseEventPosition(event)
self._plot.onMouseWheel(x, y, angleInDegrees)
event.accept()
@@ -298,10 +341,9 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
gl.glEnable(gl.GL_BLEND)
# gl.glBlendFunc(gl.GL_SRC_ALPHA, gl.GL_ONE_MINUS_SRC_ALPHA)
- gl.glBlendFuncSeparate(gl.GL_SRC_ALPHA,
- gl.GL_ONE_MINUS_SRC_ALPHA,
- gl.GL_ONE,
- gl.GL_ONE)
+ gl.glBlendFuncSeparate(
+ gl.GL_SRC_ALPHA, gl.GL_ONE_MINUS_SRC_ALPHA, gl.GL_ONE, gl.GL_ONE
+ )
# For lines
gl.glHint(gl.GL_LINE_SMOOTH_HINT, gl.GL_NICEST)
@@ -319,28 +361,33 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
def _paintFBOGL(self):
context = glu.Context.getCurrent()
plotFBOTex = self._plotFBOs.get(context)
- if (self._plot._getDirtyPlot() or self._plotFrame.isDirty or
- plotFBOTex is None):
+ if self._plot._getDirtyPlot() or self._plotFrame.isDirty or plotFBOTex is None:
self._plotVertices = (
# Vertex coordinates
- numpy.array(((-1., -1.), (1., -1.), (-1., 1.), (1., 1.)),
- dtype=numpy.float32),
- # Texture coordinates
- numpy.array(((0., 0.), (1., 0.), (0., 1.), (1., 1.)),
- dtype=numpy.float32))
- if plotFBOTex is None or \
- plotFBOTex.shape[1] != self._plotFrame.size[0] or \
- plotFBOTex.shape[0] != self._plotFrame.size[1]:
+ numpy.array(
+ ((-1.0, -1.0), (1.0, -1.0), (-1.0, 1.0), (1.0, 1.0)),
+ dtype=numpy.float32,
+ ),
+ # Texture coordinates
+ numpy.array(
+ ((0.0, 0.0), (1.0, 0.0), (0.0, 1.0), (1.0, 1.0)),
+ dtype=numpy.float32,
+ ),
+ )
+ if (
+ plotFBOTex is None
+ or plotFBOTex.shape[1] != self._plotFrame.size[0]
+ or plotFBOTex.shape[0] != self._plotFrame.size[1]
+ ):
if plotFBOTex is not None:
plotFBOTex.discard()
plotFBOTex = glu.FramebufferTexture(
gl.GL_RGBA,
- shape=(self._plotFrame.size[1],
- self._plotFrame.size[0]),
+ shape=(self._plotFrame.size[1], self._plotFrame.size[0]),
minFilter=gl.GL_NEAREST,
magFilter=gl.GL_NEAREST,
- wrap=(gl.GL_CLAMP_TO_EDGE,
- gl.GL_CLAMP_TO_EDGE))
+ wrap=(gl.GL_CLAMP_TO_EDGE, gl.GL_CLAMP_TO_EDGE),
+ )
self._plotFBOs[context] = plotFBOTex
with plotFBOTex:
@@ -355,25 +402,33 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
self._progTex.use()
texUnit = 0
- gl.glUniform1i(self._progTex.uniforms['tex'], texUnit)
- gl.glUniformMatrix4fv(self._progTex.uniforms['matrix'], 1, gl.GL_TRUE,
- glutils.mat4Identity().astype(numpy.float32))
-
- gl.glEnableVertexAttribArray(self._progTex.attributes['position'])
- gl.glVertexAttribPointer(self._progTex.attributes['position'],
- 2,
- gl.GL_FLOAT,
- gl.GL_FALSE,
- 0,
- self._plotVertices[0])
-
- gl.glEnableVertexAttribArray(self._progTex.attributes['texCoords'])
- gl.glVertexAttribPointer(self._progTex.attributes['texCoords'],
- 2,
- gl.GL_FLOAT,
- gl.GL_FALSE,
- 0,
- self._plotVertices[1])
+ gl.glUniform1i(self._progTex.uniforms["tex"], texUnit)
+ gl.glUniformMatrix4fv(
+ self._progTex.uniforms["matrix"],
+ 1,
+ gl.GL_TRUE,
+ glutils.mat4Identity().astype(numpy.float32),
+ )
+
+ gl.glEnableVertexAttribArray(self._progTex.attributes["position"])
+ gl.glVertexAttribPointer(
+ self._progTex.attributes["position"],
+ 2,
+ gl.GL_FLOAT,
+ gl.GL_FALSE,
+ 0,
+ self._plotVertices[0],
+ )
+
+ gl.glEnableVertexAttribArray(self._progTex.attributes["texCoords"])
+ gl.glVertexAttribPointer(
+ self._progTex.attributes["texCoords"],
+ 2,
+ gl.GL_FLOAT,
+ gl.GL_FALSE,
+ 0,
+ self._plotVertices[1],
+ )
with plotFBOTex.texture:
gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, len(self._plotVertices[0]))
@@ -404,6 +459,7 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
# Sync plot frame with window
self._plotFrame.devicePixelRatio = self.getDevicePixelRatio()
+ self._plotFrame.dotsPerInch = self.getDotsPerInch()
# self._paintDirectGL()
self._paintFBOGL()
@@ -434,18 +490,22 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
)
for plotItem in self.getItemsFromBackToFront(
- condition=lambda i: i.isVisible() and i.isOverlay() == overlay):
+ condition=lambda i: i.isVisible() and i.isOverlay() == overlay
+ ):
if plotItem._backendRenderer is None:
continue
item = plotItem._backendRenderer
if isinstance(item, glutils.GLPlotItem): # Render data items
- gl.glViewport(self._plotFrame.margins.left,
- self._plotFrame.margins.bottom,
- plotWidth, plotHeight)
+ gl.glViewport(
+ self._plotFrame.margins.left,
+ self._plotFrame.margins.bottom,
+ plotWidth,
+ plotHeight,
+ )
# Set matrix
- if item.yaxis == 'right':
+ if item.yaxis == "right":
context.matrix = self._plotFrame.transformedDataY2ProjMat
else:
context.matrix = self._plotFrame.transformedDataProjMat
@@ -454,140 +514,187 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
elif isinstance(item, _ShapeItem): # Render shape items
gl.glViewport(0, 0, self._plotFrame.size[0], self._plotFrame.size[1])
- if ((isXLog and numpy.min(item['x']) < FLOAT32_MINPOS) or
- (isYLog and numpy.min(item['y']) < FLOAT32_MINPOS)):
+ if (isXLog and numpy.min(item["x"]) < FLOAT32_MINPOS) or (
+ isYLog and numpy.min(item["y"]) < FLOAT32_MINPOS
+ ):
# Ignore items <= 0. on log axes
continue
- if item['shape'] == 'hline':
+ if item["shape"] == "hline":
width = self._plotFrame.size[0]
_, yPixel = self._plotFrame.dataToPixel(
- 0.5 * sum(self._plotFrame.dataRanges[0]),
- item['y'],
- axis='left')
- subShapes = [numpy.array(((0., yPixel), (width, yPixel)),
- dtype=numpy.float32)]
+ 0.5 * sum(self._plotFrame.dataRanges[0]), item["y"], axis="left"
+ )
+ subShapes = [
+ numpy.array(
+ ((0.0, yPixel), (width, yPixel)), dtype=numpy.float32
+ )
+ ]
- elif item['shape'] == 'vline':
+ elif item["shape"] == "vline":
xPixel, _ = self._plotFrame.dataToPixel(
- item['x'],
- 0.5 * sum(self._plotFrame.dataRanges[1]),
- axis='left')
+ item["x"], 0.5 * sum(self._plotFrame.dataRanges[1]), axis="left"
+ )
height = self._plotFrame.size[1]
- subShapes = [numpy.array(((xPixel, 0), (xPixel, height)),
- dtype=numpy.float32)]
+ subShapes = [
+ numpy.array(
+ ((xPixel, 0), (xPixel, height)), dtype=numpy.float32
+ )
+ ]
else:
# Split sub-shapes at not finite values
- splits = numpy.nonzero(numpy.logical_not(numpy.logical_and(
- numpy.isfinite(item['x']), numpy.isfinite(item['y']))))[0]
- splits = numpy.concatenate(([-1], splits, [len(item['x'])]))
+ splits = numpy.nonzero(
+ numpy.logical_not(
+ numpy.logical_and(
+ numpy.isfinite(item["x"]), numpy.isfinite(item["y"])
+ )
+ )
+ )[0]
+ splits = numpy.concatenate(([-1], splits, [len(item["x"])]))
subShapes = []
for begin, end in zip(splits[:-1] + 1, splits[1:]):
if end > begin:
- subShapes.append(numpy.array([
- self._plotFrame.dataToPixel(x, y, axis='left')
- for (x, y) in zip(item['x'][begin:end], item['y'][begin:end])]))
+ subShapes.append(
+ numpy.array(
+ [
+ self._plotFrame.dataToPixel(x, y, axis="left")
+ for (x, y) in zip(
+ item["x"][begin:end], item["y"][begin:end]
+ )
+ ]
+ )
+ )
for points in subShapes: # Draw each sub-shape
# Draw the fill
- if (item['fill'] is not None and
- item['shape'] not in ('hline', 'vline')):
+ if item["fill"] is not None and item["shape"] not in (
+ "hline",
+ "vline",
+ ):
self._progBase.use()
gl.glUniformMatrix4fv(
- self._progBase.uniforms['matrix'], 1, gl.GL_TRUE,
- self.matScreenProj.astype(numpy.float32))
- gl.glUniform2i(self._progBase.uniforms['isLog'], False, False)
- gl.glUniform1f(self._progBase.uniforms['tickLen'], 0.)
+ self._progBase.uniforms["matrix"],
+ 1,
+ gl.GL_TRUE,
+ self.matScreenProj.astype(numpy.float32),
+ )
+ gl.glUniform2i(self._progBase.uniforms["isLog"], False, False)
+ gl.glUniform1f(self._progBase.uniforms["tickLen"], 0.0)
shape2D = glutils.FilledShape2D(
- points, style=item['fill'], color=item['color'])
+ points, style=item["fill"], color=item["color"]
+ )
shape2D.render(
- posAttrib=self._progBase.attributes['position'],
- colorUnif=self._progBase.uniforms['color'],
- hatchStepUnif=self._progBase.uniforms['hatchStep'])
+ posAttrib=self._progBase.attributes["position"],
+ colorUnif=self._progBase.uniforms["color"],
+ hatchStepUnif=self._progBase.uniforms["hatchStep"],
+ )
# Draw the stroke
- if item['linestyle'] not in ('', ' ', None):
- if item['shape'] != 'polylines':
+ if item["dashpattern"] is not None:
+ if item["shape"] != "polylines":
# close the polyline
- points = numpy.append(points,
- numpy.atleast_2d(points[0]), axis=0)
+ points = numpy.append(
+ points, numpy.atleast_2d(points[0]), axis=0
+ )
lines = glutils.GLLines2D(
- points[:, 0], points[:, 1],
- style=item['linestyle'],
- color=item['color'],
- dash2ndColor=item['linebgcolor'],
- width=item['linewidth'])
+ points[:, 0],
+ points[:, 1],
+ color=item["color"],
+ gapColor=item["gapcolor"],
+ width=item["linewidth"],
+ dashOffset=item["dashoffset"],
+ dashPattern=item["dashpattern"],
+ )
context.matrix = self.matScreenProj
lines.render(context)
elif isinstance(item, _MarkerItem):
gl.glViewport(0, 0, self._plotFrame.size[0], self._plotFrame.size[1])
- xCoord, yCoord, yAxis = item['x'], item['y'], item['yaxis']
+ xCoord, yCoord, yAxis = item["x"], item["y"], item["yaxis"]
- if ((isXLog and xCoord is not None and xCoord <= 0) or
- (isYLog and yCoord is not None and yCoord <= 0)):
+ if (isXLog and xCoord is not None and xCoord <= 0) or (
+ isYLog and yCoord is not None and yCoord <= 0
+ ):
# Do not render markers with negative coords on log axis
continue
- color = item['color']
- intensity = color[0] * 0.299 + color[1] * 0.587 + color[2] * 0.114
- bgColor = (1., 1., 1., 0.75) if intensity <= 0.5 else (0., 0., 0., 0.75)
+ color = item["color"]
+ bgColor = item["bgcolor"]
if xCoord is None or yCoord is None:
if xCoord is None: # Horizontal line in data space
pixelPos = self._plotFrame.dataToPixel(
- 0.5 * sum(self._plotFrame.dataRanges[0]),
- yCoord,
- axis=yAxis)
-
- if item['text'] is not None:
- x = self._plotFrame.size[0] - \
- self._plotFrame.margins.right - pixelOffset
+ 0.5 * sum(self._plotFrame.dataRanges[0]), yCoord, axis=yAxis
+ )
+
+ if item["text"] is not None:
+ x = (
+ self._plotFrame.size[0]
+ - self._plotFrame.margins.right
+ - pixelOffset
+ )
y = pixelPos[1] - pixelOffset
label = glutils.Text2D(
- item['text'], x, y,
- color=item['color'],
+ item["text"],
+ item["font"],
+ x,
+ y,
+ color=color,
bgColor=bgColor,
align=glutils.RIGHT,
valign=glutils.BOTTOM,
- devicePixelRatio=self.getDevicePixelRatio())
+ devicePixelRatio=self.getDevicePixelRatio(),
+ padding=self._TEXT_MARKER_PADDING,
+ )
labels.append(label)
width = self._plotFrame.size[0]
lines = glutils.GLLines2D(
- (0, width), (pixelPos[1], pixelPos[1]),
- style=item['linestyle'],
- color=item['color'],
- width=item['linewidth'])
+ (0, width),
+ (pixelPos[1], pixelPos[1]),
+ color=color,
+ width=item["linewidth"],
+ dashOffset=item["dashoffset"],
+ dashPattern=item["dashpattern"],
+ )
context.matrix = self.matScreenProj
lines.render(context)
else: # yCoord is None: vertical line in data space
- yRange = self._plotFrame.dataRanges[1 if yAxis == 'left' else 2]
+ yRange = self._plotFrame.dataRanges[1 if yAxis == "left" else 2]
pixelPos = self._plotFrame.dataToPixel(
- xCoord, 0.5 * sum(yRange), axis=yAxis)
+ xCoord, 0.5 * sum(yRange), axis=yAxis
+ )
- if item['text'] is not None:
+ if item["text"] is not None:
x = pixelPos[0] + pixelOffset
y = self._plotFrame.margins.top + pixelOffset
label = glutils.Text2D(
- item['text'], x, y,
- color=item['color'],
+ item["text"],
+ item["font"],
+ x,
+ y,
+ color=color,
bgColor=bgColor,
align=glutils.LEFT,
valign=glutils.TOP,
- devicePixelRatio=self.getDevicePixelRatio())
+ devicePixelRatio=self.getDevicePixelRatio(),
+ padding=self._TEXT_MARKER_PADDING,
+ )
labels.append(label)
height = self._plotFrame.size[1]
lines = glutils.GLLines2D(
- (pixelPos[0], pixelPos[0]), (0, height),
- style=item['linestyle'],
- color=item['color'],
- width=item['linewidth'])
+ (pixelPos[0], pixelPos[0]),
+ (0, height),
+ color=color,
+ width=item["linewidth"],
+ dashOffset=item["dashoffset"],
+ dashPattern=item["dashpattern"],
+ )
context.matrix = self.matScreenProj
lines.render(context)
@@ -597,8 +704,7 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
if not xmin < xCoord < xmax or not ymin < yCoord < ymax:
# Do not render markers outside visible plot area
continue
- pixelPos = self._plotFrame.dataToPixel(
- xCoord, yCoord, axis=yAxis)
+ pixelPos = self._plotFrame.dataToPixel(xCoord, yCoord, axis=yAxis)
if isYInverted:
valign = glutils.BOTTOM
@@ -607,16 +713,21 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
valign = glutils.TOP
vPixelOffset = pixelOffset
- if item['text'] is not None:
+ if item["text"] is not None:
x = pixelPos[0] + pixelOffset
y = pixelPos[1] + vPixelOffset
label = glutils.Text2D(
- item['text'], x, y,
- color=item['color'],
+ item["text"],
+ item["font"],
+ x,
+ y,
+ color=color,
bgColor=bgColor,
align=glutils.LEFT,
valign=valign,
- devicePixelRatio=self.getDevicePixelRatio())
+ devicePixelRatio=self.getDevicePixelRatio(),
+ padding=self._TEXT_MARKER_PADDING,
+ )
labels.append(label)
# For now simple implementation: using a curve for each marker
@@ -624,30 +735,33 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
marker = glutils.Points2D(
(pixelPos[0],),
(pixelPos[1],),
- marker=item['symbol'],
- color=item['color'],
+ marker=item["symbol"],
+ color=color,
size=11,
)
context.matrix = self.matScreenProj
marker.render(context)
else:
- _logger.error('Unsupported item: %s', str(item))
+ _logger.error("Unsupported item: %s", str(item))
continue
# Render marker labels
gl.glViewport(0, 0, self._plotFrame.size[0], self._plotFrame.size[1])
for label in labels:
- label.render(self.matScreenProj)
+ label.render(self.matScreenProj, self._plotFrame.dotsPerInch)
def _renderOverlayGL(self):
"""Render overlay layer: overlay items and crosshair."""
plotWidth, plotHeight = self._plotFrame.plotSize
# Scissor to plot area
- gl.glScissor(self._plotFrame.margins.left,
- self._plotFrame.margins.bottom,
- plotWidth, plotHeight)
+ gl.glScissor(
+ self._plotFrame.margins.left,
+ self._plotFrame.margins.bottom,
+ plotWidth,
+ plotHeight,
+ )
gl.glEnable(gl.GL_SCISSOR_TEST)
self._renderItems(overlay=True)
@@ -655,17 +769,18 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
# Render crosshair cursor
if self._crosshairCursor is not None and self._mousePosInPixels is not None:
self._progBase.use()
- gl.glUniform2i(self._progBase.uniforms['isLog'], False, False)
- gl.glUniform1f(self._progBase.uniforms['tickLen'], 0.)
- posAttrib = self._progBase.attributes['position']
- matrixUnif = self._progBase.uniforms['matrix']
- colorUnif = self._progBase.uniforms['color']
- hatchStepUnif = self._progBase.uniforms['hatchStep']
+ gl.glUniform2i(self._progBase.uniforms["isLog"], False, False)
+ gl.glUniform1f(self._progBase.uniforms["tickLen"], 0.0)
+ posAttrib = self._progBase.attributes["position"]
+ matrixUnif = self._progBase.uniforms["matrix"]
+ colorUnif = self._progBase.uniforms["color"]
+ hatchStepUnif = self._progBase.uniforms["hatchStep"]
gl.glViewport(0, 0, self._plotFrame.size[0], self._plotFrame.size[1])
- gl.glUniformMatrix4fv(matrixUnif, 1, gl.GL_TRUE,
- self.matScreenProj.astype(numpy.float32))
+ gl.glUniformMatrix4fv(
+ matrixUnif, 1, gl.GL_TRUE, self.matScreenProj.astype(numpy.float32)
+ )
color, lineWidth = self._crosshairCursor
gl.glUniform4f(colorUnif, *color)
@@ -673,18 +788,20 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
xPixel, yPixel = self._mousePosInPixels
xPixel, yPixel = xPixel + 0.5, yPixel + 0.5
- vertices = numpy.array(((0., yPixel),
- (self._plotFrame.size[0], yPixel),
- (xPixel, 0.),
- (xPixel, self._plotFrame.size[1])),
- dtype=numpy.float32)
+ vertices = numpy.array(
+ (
+ (0.0, yPixel),
+ (self._plotFrame.size[0], yPixel),
+ (xPixel, 0.0),
+ (xPixel, self._plotFrame.size[1]),
+ ),
+ dtype=numpy.float32,
+ )
gl.glEnableVertexAttribArray(posAttrib)
- gl.glVertexAttribPointer(posAttrib,
- 2,
- gl.GL_FLOAT,
- gl.GL_FALSE,
- 0, vertices)
+ gl.glVertexAttribPointer(
+ posAttrib, 2, gl.GL_FLOAT, gl.GL_FALSE, 0, vertices
+ )
gl.glLineWidth(lineWidth)
gl.glDrawArrays(gl.GL_LINES, 0, len(vertices))
@@ -697,9 +814,12 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
"""
plotWidth, plotHeight = self._plotFrame.plotSize
- gl.glScissor(self._plotFrame.margins.left,
- self._plotFrame.margins.bottom,
- plotWidth, plotHeight)
+ gl.glScissor(
+ self._plotFrame.margins.left,
+ self._plotFrame.margins.bottom,
+ plotWidth,
+ plotHeight,
+ )
gl.glEnable(gl.GL_SCISSOR_TEST)
if self._dataBackgroundColor != self._backgroundColor:
@@ -722,29 +842,28 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
self._plotFrame.size = (
int(self.getDevicePixelRatio() * width),
- int(self.getDevicePixelRatio() * height))
+ int(self.getDevicePixelRatio() * height),
+ )
self.matScreenProj = glutils.mat4Ortho(
- 0, self._plotFrame.size[0],
- self._plotFrame.size[1], 0,
- 1, -1)
+ 0, self._plotFrame.size[0], self._plotFrame.size[1], 0, 1, -1
+ )
# Store current ranges
previousXRange = self.getGraphXLimits()
- previousYRange = self.getGraphYLimits(axis='left')
- previousYRightRange = self.getGraphYLimits(axis='right')
+ previousYRange = self.getGraphYLimits(axis="left")
+ previousYRightRange = self.getGraphYLimits(axis="right")
- (xMin, xMax), (yMin, yMax), (y2Min, y2Max) = \
- self._plotFrame.dataRanges
+ (xMin, xMax), (yMin, yMax), (y2Min, y2Max) = self._plotFrame.dataRanges
self.setLimits(xMin, xMax, yMin, yMax, y2Min, y2Max)
# If plot range has changed, then emit signal
if previousXRange != self.getGraphXLimits():
self._plot.getXAxis()._emitLimitsChanged()
- if previousYRange != self.getGraphYLimits(axis='left'):
- self._plot.getYAxis(axis='left')._emitLimitsChanged()
- if previousYRightRange != self.getGraphYLimits(axis='right'):
- self._plot.getYAxis(axis='right')._emitLimitsChanged()
+ if previousYRange != self.getGraphYLimits(axis="left"):
+ self._plot.getYAxis(axis="left")._emitLimitsChanged()
+ if previousYRightRange != self.getGraphYLimits(axis="right"):
+ self._plot.getYAxis(axis="right")._emitLimitsChanged()
# Add methods
@@ -761,39 +880,92 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
elif numpy.issubdtype(v.dtype, numpy.integer):
return numpy.float32 if v.itemsize <= 2 else numpy.float64
else:
- raise ValueError('Unsupported data type')
-
- def addCurve(self, x, y,
- color, symbol, linewidth, linestyle,
- yaxis,
- xerror, yerror,
- fill, alpha, symbolsize, baseline):
- for parameter in (x, y, color, symbol, linewidth, linestyle,
- yaxis, fill, symbolsize):
+ raise ValueError("Unsupported data type")
+
+ _DASH_PATTERNS = {
+ "": (0.0, None),
+ " ": (0.0, None),
+ "-": (0.0, ()),
+ "--": (0.0, (3.7, 1.6, 3.7, 1.6)),
+ "-.": (0.0, (6.4, 1.6, 1, 1.6)),
+ ":": (0.0, (1, 1.65, 1, 1.65)),
+ None: (0.0, None),
+ }
+ """Convert from linestyle to (offset, (dash pattern))
+
+ Note: dash pattern internal convention differs from matplotlib:
+ - None: no line at all
+ - (): "solid" line
+ """
+
+ def _lineStyleToDashOffsetPattern(
+ self, style
+ ) -> tuple[float, tuple[float, float, float, float] | tuple[()] | None]:
+ """Convert a linestyle to its corresponding offset and dash pattern"""
+ if style is None or isinstance(style, str):
+ return self._DASH_PATTERNS[style]
+
+ # (offset, (dash pattern)) case
+ offset, pattern = style
+ if pattern is None:
+ # Convert from matplotlib to internal representation of solid
+ pattern = ()
+ if len(pattern) == 2:
+ pattern = pattern * 2
+ return float(offset), tuple(float(v) for v in pattern)
+
+ def addCurve(
+ self,
+ x,
+ y,
+ color,
+ gapcolor,
+ symbol,
+ linewidth,
+ linestyle,
+ yaxis,
+ xerror,
+ yerror,
+ fill,
+ alpha,
+ symbolsize,
+ baseline,
+ ):
+ for parameter in (
+ x,
+ y,
+ color,
+ symbol,
+ linewidth,
+ linestyle,
+ yaxis,
+ fill,
+ symbolsize,
+ ):
assert parameter is not None
- assert yaxis in ('left', 'right')
+ assert yaxis in ("left", "right")
# Convert input data
x = numpy.array(x, copy=False)
y = numpy.array(y, copy=False)
# Check if float32 is enough
- if (self._castArrayTo(x) is numpy.float32 and
- self._castArrayTo(y) is numpy.float32):
+ if (
+ self._castArrayTo(x) is numpy.float32
+ and self._castArrayTo(y) is numpy.float32
+ ):
dtype = numpy.float32
else:
dtype = numpy.float64
- x = numpy.array(x, dtype=dtype, copy=False, order='C')
- y = numpy.array(y, dtype=dtype, copy=False, order='C')
+ x = numpy.array(x, dtype=dtype, copy=False, order="C")
+ y = numpy.array(y, dtype=dtype, copy=False, order="C")
# Convert errors to float32
if xerror is not None:
- xerror = numpy.array(
- xerror, dtype=numpy.float32, copy=False, order='C')
+ xerror = numpy.array(xerror, dtype=numpy.float32, copy=False, order="C")
if yerror is not None:
- yerror = numpy.array(
- yerror, dtype=numpy.float32, copy=False, order='C')
+ yerror = numpy.array(yerror, dtype=numpy.float32, copy=False, order="C")
# Handle axes log scale: convert data
@@ -803,21 +975,21 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
if xerror is not None:
# Transform xerror so that
# log10(x) +/- xerror' = log10(x +/- xerror)
- if hasattr(xerror, 'shape') and len(xerror.shape) == 2:
+ if hasattr(xerror, "shape") and len(xerror.shape) == 2:
xErrorMinus, xErrorPlus = xerror[0], xerror[1]
else:
xErrorMinus, xErrorPlus = xerror, xerror
- with numpy.errstate(divide='ignore', invalid='ignore'):
+ with numpy.errstate(divide="ignore", invalid="ignore"):
# Ignore divide by zero, invalid value encountered in log10
xErrorMinus = logX - numpy.log10(x - xErrorMinus)
xErrorPlus = numpy.log10(x + xErrorPlus) - logX
- xerror = numpy.array((xErrorMinus, xErrorPlus),
- dtype=numpy.float32)
+ xerror = numpy.array((xErrorMinus, xErrorPlus), dtype=numpy.float32)
x = logX
- isYLog = (yaxis == 'left' and self._plotFrame.yAxis.isLog) or (
- yaxis == 'right' and self._plotFrame.y2Axis.isLog)
+ isYLog = (yaxis == "left" and self._plotFrame.yAxis.isLog) or (
+ yaxis == "right" and self._plotFrame.y2Axis.isLog
+ )
if isYLog:
logY = numpy.log10(y)
@@ -825,25 +997,23 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
if yerror is not None:
# Transform yerror so that
# log10(y) +/- yerror' = log10(y +/- yerror)
- if hasattr(yerror, 'shape') and len(yerror.shape) == 2:
+ if hasattr(yerror, "shape") and len(yerror.shape) == 2:
yErrorMinus, yErrorPlus = yerror[0], yerror[1]
else:
yErrorMinus, yErrorPlus = yerror, yerror
- with numpy.errstate(divide='ignore', invalid='ignore'):
+ with numpy.errstate(divide="ignore", invalid="ignore"):
# Ignore divide by zero, invalid value encountered in log10
yErrorMinus = logY - numpy.log10(y - yErrorMinus)
yErrorPlus = numpy.log10(y + yErrorPlus) - logY
- yerror = numpy.array((yErrorMinus, yErrorPlus),
- dtype=numpy.float32)
+ yerror = numpy.array((yErrorMinus, yErrorPlus), dtype=numpy.float32)
y = logY
# TODO check if need more filtering of error (e.g., clip to positive)
# TODO check and improve this
- if (len(color) == 4 and
- type(color[3]) in [type(1), numpy.uint8, numpy.int8]):
- color = numpy.array(color, dtype=numpy.float32) / 255.
+ if len(color) == 4 and type(color[3]) in [type(1), numpy.uint8, numpy.int8]:
+ color = numpy.array(color, dtype=numpy.float32) / 255.0
if isinstance(color, numpy.ndarray) and color.ndim == 2:
colorArray = color
@@ -852,7 +1022,7 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
colorArray = None
color = colors.rgba(color)
- if alpha < 1.: # Apply image transparency
+ if alpha < 1.0: # Apply image transparency
if colorArray is not None and colorArray.shape[1] == 4:
# multiply alpha channel
colorArray[:, 3] = colorArray[:, 3] * alpha
@@ -862,43 +1032,49 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
fillColor = None
if fill is True:
fillColor = color
+
+ dashoffset, dashpattern = self._lineStyleToDashOffsetPattern(linestyle)
curve = glutils.GLPlotCurve2D(
- x, y, colorArray,
+ x,
+ y,
+ colorArray,
xError=xerror,
yError=yerror,
- lineStyle=linestyle,
lineColor=color,
+ lineGapColor=gapcolor,
lineWidth=linewidth,
+ lineDashOffset=dashoffset,
+ lineDashPattern=dashpattern,
marker=symbol,
markerColor=color,
markerSize=symbolsize,
fillColor=fillColor,
baseline=baseline,
- isYLog=isYLog)
- curve.yaxis = 'left' if yaxis is None else yaxis
+ isYLog=isYLog,
+ )
+ curve.yaxis = "left" if yaxis is None else yaxis
if yaxis == "right":
self._plotFrame.isY2Axis = True
return curve
- def addImage(self, data,
- origin, scale,
- colormap, alpha):
+ def addImage(self, data, origin, scale, colormap, alpha):
for parameter in (data, origin, scale):
assert parameter is not None
if data.ndim == 2:
# Ensure array is contiguous and eventually convert its type
- dtypes = [dtype for dtype in (
- numpy.float32, numpy.float16, numpy.uint8, numpy.uint16)
- if glu.isSupportedGLType(dtype)]
+ dtypes = [
+ dtype
+ for dtype in (numpy.float32, numpy.float16, numpy.uint8, numpy.uint16)
+ if glu.isSupportedGLType(dtype)
+ ]
if data.dtype in dtypes:
- data = numpy.array(data, copy=False, order='C')
+ data = numpy.array(data, copy=False, order="C")
else:
- _logger.info(
- 'addImage: Convert %s data to float32', str(data.dtype))
- data = numpy.array(data, dtype=numpy.float32, order='C')
+ _logger.info("addImage: Convert %s data to float32", str(data.dtype))
+ data = numpy.array(data, dtype=numpy.float32, order="C")
normalization = colormap.getNormalization()
if normalization in glutils.GLPlotColormap.SUPPORTED_NORMALIZATIONS:
@@ -917,7 +1093,8 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
gamma,
cmapRange,
alpha,
- nanColor)
+ nanColor,
+ )
else: # Fallback applying colormap on CPU
rgba = colormap.applyToData(data)
@@ -934,7 +1111,7 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
elif numpy.issubdtype(data.dtype, numpy.integer):
data = numpy.array(data, dtype=numpy.uint8, copy=False)
else:
- raise ValueError('Unsupported data type')
+ raise ValueError("Unsupported data type")
image = glutils.GLPlotRGBAImage(data, origin, scale, alpha)
@@ -942,17 +1119,14 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
raise RuntimeError("Unsupported data shape {0}".format(data.shape))
# TODO is this needed?
- if self._plotFrame.xAxis.isLog and image.xMin <= 0.:
- raise RuntimeError(
- 'Cannot add image with X <= 0 with X axis log scale')
- if self._plotFrame.yAxis.isLog and image.yMin <= 0.:
- raise RuntimeError(
- 'Cannot add image with Y <= 0 with Y axis log scale')
+ if self._plotFrame.xAxis.isLog and image.xMin <= 0.0:
+ raise RuntimeError("Cannot add image with X <= 0 with X axis log scale")
+ if self._plotFrame.yAxis.isLog and image.yMin <= 0.0:
+ raise RuntimeError("Cannot add image with Y <= 0 with Y axis log scale")
return image
- def addTriangles(self, x, y, triangles,
- color, alpha):
+ def addTriangles(self, x, y, triangles, color, alpha):
# Handle axes log scale: convert data
if self._plotFrame.xAxis.isLog:
x = numpy.log10(x)
@@ -963,36 +1137,90 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
return triangles
- def addShape(self, x, y, shape, color, fill, overlay,
- linestyle, linewidth, linebgcolor):
+ def addShape(
+ self, x, y, shape, color, fill, overlay, linestyle, linewidth, gapcolor
+ ):
x = numpy.array(x, copy=False)
y = numpy.array(y, copy=False)
# TODO is this needed?
- if self._plotFrame.xAxis.isLog and x.min() <= 0.:
- raise RuntimeError(
- 'Cannot add item with X <= 0 with X axis log scale')
- if self._plotFrame.yAxis.isLog and y.min() <= 0.:
- raise RuntimeError(
- 'Cannot add item with Y <= 0 with Y axis log scale')
-
- return _ShapeItem(x, y, shape, color, fill, overlay,
- linestyle, linewidth, linebgcolor)
+ if self._plotFrame.xAxis.isLog and x.min() <= 0.0:
+ raise RuntimeError("Cannot add item with X <= 0 with X axis log scale")
+ if self._plotFrame.yAxis.isLog and y.min() <= 0.0:
+ raise RuntimeError("Cannot add item with Y <= 0 with Y axis log scale")
+
+ dashoffset, dashpattern = self._lineStyleToDashOffsetPattern(linestyle)
+ return _ShapeItem(
+ x,
+ y,
+ shape,
+ color,
+ fill,
+ overlay,
+ linewidth,
+ dashoffset,
+ dashpattern,
+ gapcolor,
+ )
- def addMarker(self, x, y, text, color,
- symbol, linestyle, linewidth, constraint, yaxis):
- return _MarkerItem(x, y, text, color,
- symbol, linestyle, linewidth, constraint, yaxis)
+ def getDefaultFont(self):
+ """Returns the default font, used by raw markers and axes labels"""
+ if self._defaultFont is None:
+ from matplotlib.font_manager import findfont, FontProperties
+
+ font_filename = findfont(FontProperties(family=["sans-serif"]))
+ _logger.debug("Load font from mpl: %s", font_filename)
+ id = qt.QFontDatabase.addApplicationFont(font_filename)
+ family = qt.QFontDatabase.applicationFontFamilies(id)[0]
+ font = qt.QFont(family, 10, qt.QFont.Normal, False)
+ font.setStyleStrategy(qt.QFont.PreferAntialias)
+ self._defaultFont = font
+ return self._defaultFont
+
+ def addMarker(
+ self,
+ x,
+ y,
+ text,
+ color,
+ symbol,
+ linestyle,
+ linewidth,
+ constraint,
+ yaxis,
+ font,
+ bgcolor: RGBAColorType | None,
+ ):
+ if font is None:
+ font = self.getDefaultFont()
+
+ dashoffset, dashpattern = self._lineStyleToDashOffsetPattern(linestyle)
+ return _MarkerItem(
+ x,
+ y,
+ text,
+ color,
+ symbol,
+ linewidth,
+ dashoffset,
+ dashpattern,
+ constraint,
+ yaxis,
+ font,
+ bgcolor,
+ )
# Remove methods
def remove(self, item):
if isinstance(item, glutils.GLPlotItem):
- if item.yaxis == 'right':
+ if item.yaxis == "right":
# Check if some curves remains on the right Y axis
- y2AxisItems = (item for item in self._plot.getItems()
- if isinstance(item, items.YAxisMixIn) and
- item.getYAxis() == 'right')
+ y2AxisItems = (
+ item
+ for item in self._plot.getItems()
+ if isinstance(item, items.YAxisMixIn) and item.getYAxis() == "right"
+ )
self._plotFrame.isY2Axis = next(y2AxisItems, None) is not None
if item.isInitialized():
@@ -1002,7 +1230,7 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
pass # No-op
else:
- _logger.error('Unsupported item: %s', str(item))
+ _logger.error("Unsupported item: %s", str(item))
# Interaction methods
@@ -1022,9 +1250,8 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
super(BackendOpenGL, self).setCursor(qt.QCursor(cursor))
def setGraphCursor(self, flag, color, linewidth, linestyle):
- if linestyle != '-':
- _logger.warning(
- "BackendOpenGL.setGraphCursor linestyle parameter ignored")
+ if linestyle != "-":
+ _logger.warning("BackendOpenGL.setGraphCursor linestyle parameter ignored")
if flag:
color = colors.rgba(color)
@@ -1048,8 +1275,10 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
:rtype: List[float]
"""
left, top, width, height = self.getPlotBoundsInPixels()
- return (numpy.clip(x, left, left + width - 1), # TODO -1?
- numpy.clip(y, top, top + height - 1))
+ return (
+ numpy.clip(x, left, left + width - 1), # TODO -1?
+ numpy.clip(y, top, top + height - 1),
+ )
def __pickCurves(self, item, x, y):
"""Perform picking on a curve item.
@@ -1064,24 +1293,26 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
if item.marker is not None:
# Convert markerSize from points to qt pixels
qtDpi = self.getDotsPerInch() / self.getDevicePixelRatio()
- size = item.markerSize / 72. * qtDpi
- offset = max(size / 2., offset)
- if item.lineStyle is not None:
+ size = item.markerSize / 72.0 * qtDpi
+ offset = max(size / 2.0, offset)
+ if item.lineDashPattern is not None:
# Convert line width from points to qt pixels
qtDpi = self.getDotsPerInch() / self.getDevicePixelRatio()
- lineWidth = item.lineWidth / 72. * qtDpi
- offset = max(lineWidth / 2., offset)
+ lineWidth = item.lineWidth / 72.0 * qtDpi
+ offset = max(lineWidth / 2.0, offset)
inAreaPos = self._mouseInPlotArea(x - offset, y - offset)
- dataPos = self._plot.pixelToData(inAreaPos[0], inAreaPos[1],
- axis=item.yaxis, check=True)
+ dataPos = self._plot.pixelToData(
+ inAreaPos[0], inAreaPos[1], axis=item.yaxis, check=True
+ )
if dataPos is None:
return None
xPick0, yPick0 = dataPos
inAreaPos = self._mouseInPlotArea(x + offset, y + offset)
- dataPos = self._plot.pixelToData(inAreaPos[0], inAreaPos[1],
- axis=item.yaxis, check=True)
+ dataPos = self._plot.pixelToData(
+ inAreaPos[0], inAreaPos[1], axis=item.yaxis, check=True
+ )
if dataPos is None:
return None
xPick1, yPick1 = dataPos
@@ -1101,17 +1332,17 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
xPickMin = numpy.log10(xPickMin)
xPickMax = numpy.log10(xPickMax)
- if (item.yaxis == 'left' and self._plotFrame.yAxis.isLog) or (
- item.yaxis == 'right' and self._plotFrame.y2Axis.isLog):
+ if (item.yaxis == "left" and self._plotFrame.yAxis.isLog) or (
+ item.yaxis == "right" and self._plotFrame.y2Axis.isLog
+ ):
yPickMin = numpy.log10(yPickMin)
yPickMax = numpy.log10(yPickMax)
- return item.pick(xPickMin, yPickMin,
- xPickMax, yPickMax)
+ return item.pick(xPickMin, yPickMin, xPickMax, yPickMax)
def pickItem(self, x, y, item):
# Picking is performed in Qt widget pixels not device pixels
- dataPos = self._plot.pixelToData(x, y, axis='left', check=True)
+ dataPos = self._plot.pixelToData(x, y, axis="left", check=True)
if dataPos is None:
return None # Outside plot area
@@ -1121,32 +1352,36 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
# Pick markers
if isinstance(item, _MarkerItem):
- yaxis = item['yaxis']
+ yaxis = item["yaxis"]
pixelPos = self._plot.dataToPixel(
- item['x'], item['y'], axis=yaxis, check=False)
+ item["x"], item["y"], axis=yaxis, check=False
+ )
if pixelPos is None:
return None # negative coord on a log axis
- if item['x'] is None: # Horizontal line
+ if item["x"] is None: # Horizontal line
pt1 = self._plot.pixelToData(
- x, y - self._PICK_OFFSET, axis=yaxis, check=False)
+ x, y - self._PICK_OFFSET, axis=yaxis, check=False
+ )
pt2 = self._plot.pixelToData(
- x, y + self._PICK_OFFSET, axis=yaxis, check=False)
- isPicked = (min(pt1[1], pt2[1]) <= item['y'] <=
- max(pt1[1], pt2[1]))
+ x, y + self._PICK_OFFSET, axis=yaxis, check=False
+ )
+ isPicked = min(pt1[1], pt2[1]) <= item["y"] <= max(pt1[1], pt2[1])
- elif item['y'] is None: # Vertical line
+ elif item["y"] is None: # Vertical line
pt1 = self._plot.pixelToData(
- x - self._PICK_OFFSET, y, axis=yaxis, check=False)
+ x - self._PICK_OFFSET, y, axis=yaxis, check=False
+ )
pt2 = self._plot.pixelToData(
- x + self._PICK_OFFSET, y, axis=yaxis, check=False)
- isPicked = (min(pt1[0], pt2[0]) <= item['x'] <=
- max(pt1[0], pt2[0]))
+ x + self._PICK_OFFSET, y, axis=yaxis, check=False
+ )
+ isPicked = min(pt1[0], pt2[0]) <= item["x"] <= max(pt1[0], pt2[0])
else:
isPicked = (
- numpy.fabs(x - pixelPos[0]) <= self._PICK_OFFSET and
- numpy.fabs(y - pixelPos[1]) <= self._PICK_OFFSET)
+ numpy.fabs(x - pixelPos[0]) <= self._PICK_OFFSET
+ and numpy.fabs(y - pixelPos[1]) <= self._PICK_OFFSET
+ )
return (0,) if isPicked else None
@@ -1177,11 +1412,11 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
if dpi is not None:
_logger.warning("saveGraph ignores dpi parameter")
- if fileFormat not in ['png', 'ppm', 'svg', 'tiff']:
- raise NotImplementedError('Unsupported format: %s' % fileFormat)
+ if fileFormat not in ["png", "ppm", "svg", "tif", "tiff"]:
+ raise NotImplementedError("Unsupported format: %s" % fileFormat)
if not self.isValid():
- _logger.error('OpenGL 2.1 not available, cannot save OpenGL image')
+ _logger.error("OpenGL 2.1 not available, cannot save OpenGL image")
width, height = self._plotFrame.size
data = numpy.zeros((height, width, 3), dtype=numpy.uint8)
else:
@@ -1189,7 +1424,9 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
data = numpy.empty(
(self._plotFrame.size[1], self._plotFrame.size[0], 3),
- dtype=numpy.uint8, order='C')
+ dtype=numpy.uint8,
+ order="C",
+ )
context = self.context()
framebufferTexture = self._plotFBOs.get(context)
@@ -1205,8 +1442,7 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
previousFramebuffer = gl.glGetInteger(gl.GL_FRAMEBUFFER_BINDING)
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fboName)
gl.glPixelStorei(gl.GL_PACK_ALIGNMENT, 1)
- gl.glReadPixels(0, 0, width, height,
- gl.GL_RGB, gl.GL_UNSIGNED_BYTE, data)
+ gl.glReadPixels(0, 0, width, height, gl.GL_RGB, gl.GL_UNSIGNED_BYTE, data)
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, previousFramebuffer)
# glReadPixels gives bottom to top,
@@ -1225,7 +1461,7 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
self._plotFrame.xAxis.title = label
def setGraphYLabel(self, label, axis):
- if axis == 'left':
+ if axis == "left":
self._plotFrame.yAxis.title = label
else: # right axis
self._plotFrame.y2Axis.title = label
@@ -1258,24 +1494,27 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
if keepDim is None:
ranges = self._plot.getDataRange()
- if (ranges.y is not None and
- ranges.x is not None and
- (ranges.y[1] - ranges.y[0]) != 0.):
- dataRatio = (ranges.x[1] - ranges.x[0]) / float(ranges.y[1] - ranges.y[0])
+ if (
+ ranges.y is not None
+ and ranges.x is not None
+ and (ranges.y[1] - ranges.y[0]) != 0.0
+ ):
+ dataRatio = (ranges.x[1] - ranges.x[0]) / float(
+ ranges.y[1] - ranges.y[0]
+ )
plotRatio = plotWidth / float(plotHeight) # Test != 0 before
- keepDim = 'x' if dataRatio > plotRatio else 'y'
+ keepDim = "x" if dataRatio > plotRatio else "y"
else: # Limit case
- keepDim = 'x'
+ keepDim = "x"
- (xMin, xMax), (yMin, yMax), (y2Min, y2Max) = \
- self._plotFrame.dataRanges
- if keepDim == 'y':
+ (xMin, xMax), (yMin, yMax), (y2Min, y2Max) = self._plotFrame.dataRanges
+ if keepDim == "y":
dataW = (yMax - yMin) * plotWidth / float(plotHeight)
xCenter = 0.5 * (xMin + xMax)
xMin = xCenter - 0.5 * dataW
xMax = xCenter + 0.5 * dataW
- elif keepDim == 'x':
+ elif keepDim == "x":
dataH = (xMax - xMin) * plotHeight / float(plotWidth)
yCenter = 0.5 * (yMin + yMax)
yMin = yCenter - 0.5 * dataH
@@ -1284,19 +1523,14 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
y2Min = y2Center - 0.5 * dataH
y2Max = y2Center + 0.5 * dataH
else:
- raise RuntimeError('Unsupported dimension to keep: %s' % keepDim)
+ raise RuntimeError("Unsupported dimension to keep: %s" % keepDim)
# Update plot frame bounds
- self._setDataRanges(xlim=(xMin, xMax),
- ylim=(yMin, yMax),
- y2lim=(y2Min, y2Max))
+ self._setDataRanges(xlim=(xMin, xMax), ylim=(yMin, yMax), y2lim=(y2Min, y2Max))
- def _setPlotBounds(self, xRange=None, yRange=None, y2Range=None,
- keepDim=None):
+ def _setPlotBounds(self, xRange=None, yRange=None, y2Range=None, keepDim=None):
# Update axes range with a clipped range if too wide
- self._setDataRanges(xlim=xRange,
- ylim=yRange,
- y2lim=y2Range)
+ self._setDataRanges(xlim=xRange, ylim=yRange, y2lim=y2Range)
# Keep data aspect ratio
if self.isKeepDataAspectRatio():
@@ -1318,7 +1552,7 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
def setGraphXLimits(self, xmin, xmax):
assert xmin < xmax
- self._setPlotBounds(xRange=(xmin, xmax), keepDim='x')
+ self._setPlotBounds(xRange=(xmin, xmax), keepDim="x")
def getGraphYLimits(self, axis):
assert axis in ("left", "right")
@@ -1332,9 +1566,9 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
assert axis in ("left", "right")
if axis == "left":
- self._setPlotBounds(yRange=(ymin, ymax), keepDim='y')
+ self._setPlotBounds(yRange=(ymin, ymax), keepDim="y")
else:
- self._setPlotBounds(y2Range=(ymin, ymax), keepDim='y')
+ self._setPlotBounds(y2Range=(ymin, ymax), keepDim="y")
# Graph axes
@@ -1353,17 +1587,14 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
def setXAxisLogarithmic(self, flag):
if flag != self._plotFrame.xAxis.isLog:
if flag and self._keepDataAspectRatio:
- _logger.warning(
- "KeepDataAspectRatio is ignored with log axes")
+ _logger.warning("KeepDataAspectRatio is ignored with log axes")
self._plotFrame.xAxis.isLog = flag
def setYAxisLogarithmic(self, flag):
- if (flag != self._plotFrame.yAxis.isLog or
- flag != self._plotFrame.y2Axis.isLog):
+ if flag != self._plotFrame.yAxis.isLog or flag != self._plotFrame.y2Axis.isLog:
if flag and self._keepDataAspectRatio:
- _logger.warning(
- "KeepDataAspectRatio is ignored with log axes")
+ _logger.warning("KeepDataAspectRatio is ignored with log axes")
self._plotFrame.yAxis.isLog = flag
self._plotFrame.y2Axis.isLog = flag
@@ -1375,6 +1606,9 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
def isYAxisInverted(self):
return self._plotFrame.isYAxisInverted
+ def isYRightAxisVisible(self):
+ return self._plotFrame.isY2Axis
+
def isKeepDataAspectRatio(self):
if self._plotFrame.xAxis.isLog or self._plotFrame.yAxis.isLog:
return False
@@ -1382,14 +1616,13 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
return self._keepDataAspectRatio
def setKeepDataAspectRatio(self, flag):
- if flag and (self._plotFrame.xAxis.isLog or
- self._plotFrame.yAxis.isLog):
+ if flag and (self._plotFrame.xAxis.isLog or self._plotFrame.yAxis.isLog):
_logger.warning("KeepDataAspectRatio is ignored with log axes")
self._keepDataAspectRatio = flag
def setGraphGrid(self, which):
- assert which in (None, 'major', 'both')
+ assert which in (None, "major", "both")
self._plotFrame.grid = which is not None # TODO True grid support
# Data <-> Pixel coordinates conversion
@@ -1400,17 +1633,20 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
return None
else:
devicePixelRatio = self.getDevicePixelRatio()
- return tuple(value/devicePixelRatio for value in result)
+ return tuple(value / devicePixelRatio for value in result)
def pixelToData(self, x, y, axis):
devicePixelRatio = self.getDevicePixelRatio()
return self._plotFrame.pixelToData(
- x * devicePixelRatio, y * devicePixelRatio, axis)
+ x * devicePixelRatio, y * devicePixelRatio, axis
+ )
def getPlotBoundsInPixels(self):
devicePixelRatio = self.getDevicePixelRatio()
- return tuple(int(value / devicePixelRatio)
- for value in self._plotFrame.plotOrigin + self._plotFrame.plotSize)
+ return tuple(
+ int(value / devicePixelRatio)
+ for value in self._plotFrame.plotOrigin + self._plotFrame.plotSize
+ )
def setAxesMargins(self, left: float, top: float, right: float, bottom: float):
self._plotFrame.marginRatios = left, top, right, bottom
diff --git a/src/silx/gui/plot/backends/glutils/GLPlotCurve.py b/src/silx/gui/plot/backends/glutils/GLPlotCurve.py
index 4825479..26442d7 100644
--- a/src/silx/gui/plot/backends/glutils/GLPlotCurve.py
+++ b/src/silx/gui/plot/backends/glutils/GLPlotCurve.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2014-2022 European Synchrotron Radiation Facility
+# Copyright (c) 2014-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -46,7 +46,7 @@ from .GLPlotImage import GLPlotItem
_logger = logging.getLogger(__name__)
-_MPL_NONES = None, 'None', '', ' '
+_MPL_NONES = None, "None", "", " "
"""Possible values for None"""
@@ -75,6 +75,7 @@ def _notNaNSlices(array, length=1):
# fill ########################################################################
+
class _Fill2D(object):
"""Object rendering curve filling as polygons
@@ -107,12 +108,17 @@ class _Fill2D(object):
gl_FragColor = color;
}
""",
- attrib0='xPos')
-
- def __init__(self, xData=None, yData=None,
- baseline=0,
- color=(0., 0., 0., 1.),
- offset=(0., 0.)):
+ attrib0="xPos",
+ )
+
+ def __init__(
+ self,
+ xData=None,
+ yData=None,
+ baseline=0,
+ color=(0.0, 0.0, 0.0, 1.0),
+ offset=(0.0, 0.0),
+ ):
self.xData = xData
self.yData = yData
self._xFillVboData = None
@@ -125,9 +131,11 @@ class _Fill2D(object):
def prepare(self):
"""Rendering preparation: build indices and bounding box vertices"""
- if (self._xFillVboData is None and
- self.xData is not None and self.yData is not None):
-
+ if (
+ self._xFillVboData is None
+ and self.xData is not None
+ and self.yData is not None
+ ):
# Get slices of not NaN values longer than 1 element
isnan = numpy.logical_or(numpy.isnan(self.xData), numpy.isnan(self.yData))
notnan = numpy.logical_not(isnan)
@@ -151,20 +159,28 @@ class _Fill2D(object):
new_y_data = numpy.append(self.yData, self.baseline)
for start, end in slices:
# Duplicate first point for connecting degenerated triangle
- points[offset:offset+2] = self.xData[start], new_y_data[start]
+ points[offset : offset + 2] = self.xData[start], new_y_data[start]
# 2nd point of the polygon is last point
- points[offset+2] = self.xData[start], self.baseline[start]
-
- indices = numpy.append(numpy.arange(start, end),
- numpy.arange(len(self.xData) + end-1, len(self.xData) + start-1, -1))
+ points[offset + 2] = self.xData[start], self.baseline[start]
+
+ indices = numpy.append(
+ numpy.arange(start, end),
+ numpy.arange(
+ len(self.xData) + end - 1, len(self.xData) + start - 1, -1
+ ),
+ )
indices = indices[buildFillMaskIndices(len(indices))]
- points[offset+3:offset+3+len(indices), 0] = self.xData[indices % len(self.xData)]
- points[offset+3:offset+3+len(indices), 1] = new_y_data[indices]
+ points[offset + 3 : offset + 3 + len(indices), 0] = self.xData[
+ indices % len(self.xData)
+ ]
+ points[offset + 3 : offset + 3 + len(indices), 1] = new_y_data[indices]
# Duplicate last point for connecting degenerated triangle
- points[offset+3+len(indices)] = points[offset+3+len(indices)-1]
+ points[offset + 3 + len(indices)] = points[
+ offset + 3 + len(indices) - 1
+ ]
offset += len(indices) + 4
@@ -183,14 +199,18 @@ class _Fill2D(object):
self._PROGRAM.use()
gl.glUniformMatrix4fv(
- self._PROGRAM.uniforms['matrix'], 1, gl.GL_TRUE,
- numpy.dot(context.matrix,
- mat4Translate(*self.offset)).astype(numpy.float32))
+ self._PROGRAM.uniforms["matrix"],
+ 1,
+ gl.GL_TRUE,
+ numpy.dot(context.matrix, mat4Translate(*self.offset)).astype(
+ numpy.float32
+ ),
+ )
- gl.glUniform4f(self._PROGRAM.uniforms['color'], *self.color)
+ gl.glUniform4f(self._PROGRAM.uniforms["color"], *self.color)
- xPosAttrib = self._PROGRAM.attributes['xPos']
- yPosAttrib = self._PROGRAM.attributes['yPos']
+ xPosAttrib = self._PROGRAM.attributes["xPos"]
+ yPosAttrib = self._PROGRAM.attributes["yPos"]
gl.glEnableVertexAttribArray(xPosAttrib)
self._xFillVboData.setVertexAttrib(xPosAttrib)
@@ -215,16 +235,30 @@ class _Fill2D(object):
gl.glDepthMask(gl.GL_TRUE)
# Draw directly in NDC
- gl.glUniformMatrix4fv(self._PROGRAM.uniforms['matrix'], 1, gl.GL_TRUE,
- mat4Identity().astype(numpy.float32))
+ gl.glUniformMatrix4fv(
+ self._PROGRAM.uniforms["matrix"],
+ 1,
+ gl.GL_TRUE,
+ mat4Identity().astype(numpy.float32),
+ )
# NDC vertices
gl.glVertexAttribPointer(
- xPosAttrib, 1, gl.GL_FLOAT, gl.GL_FALSE, 0,
- numpy.array((-1., -1., 1., 1.), dtype=numpy.float32))
+ xPosAttrib,
+ 1,
+ gl.GL_FLOAT,
+ gl.GL_FALSE,
+ 0,
+ numpy.array((-1.0, -1.0, 1.0, 1.0), dtype=numpy.float32),
+ )
gl.glVertexAttribPointer(
- yPosAttrib, 1, gl.GL_FLOAT, gl.GL_FALSE, 0,
- numpy.array((-1., 1., -1., 1.), dtype=numpy.float32))
+ yPosAttrib,
+ 1,
+ gl.GL_FLOAT,
+ gl.GL_FALSE,
+ 0,
+ numpy.array((-1.0, 1.0, -1.0, 1.0), dtype=numpy.float32),
+ )
gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, 4)
@@ -244,8 +278,6 @@ class _Fill2D(object):
# line ########################################################################
-SOLID, DASHED, DASHDOT, DOTTED = '-', '--', '-.', ':'
-
class GLLines2D(object):
"""Object rendering curve as a polyline
@@ -254,17 +286,18 @@ class GLLines2D(object):
:param yVboData: Y coordinates VBO
:param colorVboData: VBO of colors
:param distVboData: VBO of distance along the polyline
- :param str style: Line style in: '-', '--', '-.', ':'
:param List[float] color: RGBA color as 4 float in [0, 1]
:param float width: Line width
- :param float dashPeriod: Period of dashes
+ :param List[float] dashPattern:
+ "unscaled" dash pattern as 4 lengths in points (dash1, gap1, dash2, gap2).
+ This pattern is scaled with the line width.
+ Set to () to draw solid lines (default), and to None to disable rendering.
+ :param float dashOffset: The offset in points the patterns starts at.
+ The offset is scaled with the line width.
:param drawMode: OpenGL drawing mode
:param List[float] offset: Translation of coordinates (ox, oy)
"""
- STYLES = SOLID, DASHED, DASHDOT, DOTTED
- """Supported line styles"""
-
_SOLID_PROGRAM = Program(
vertexShader="""
#version 120
@@ -290,7 +323,8 @@ class GLLines2D(object):
gl_FragColor = vColor;
}
""",
- attrib0='xPos')
+ attrib0="xPos",
+ )
# Limitation: Dash using an estimate of distance in screen coord
# to avoid computing distance when viewport is resized
@@ -321,51 +355,60 @@ class GLLines2D(object):
/* Dashes: [0, x], [y, z]
Dash period: w */
uniform vec4 dash;
- uniform vec4 dash2ndColor;
+ uniform float dashOffset;
+ uniform vec4 gapColor;
varying float vDist;
varying vec4 vColor;
void main(void) {
- float dist = mod(vDist, dash.w);
+ float dist = mod(vDist + dashOffset, dash.w);
if ((dist > dash.x && dist < dash.y) || dist > dash.z) {
- if (dash2ndColor.a == 0.) {
+ if (gapColor.a == 0.) {
discard; // Discard full transparent bg color
} else {
- gl_FragColor = dash2ndColor;
+ gl_FragColor = gapColor;
}
} else {
gl_FragColor = vColor;
}
}
""",
- attrib0='xPos')
-
- def __init__(self, xVboData=None, yVboData=None,
- colorVboData=None, distVboData=None,
- style=SOLID, color=(0., 0., 0., 1.), dash2ndColor=None,
- width=1, dashPeriod=10., drawMode=None,
- offset=(0., 0.)):
- if (xVboData is not None and
- not isinstance(xVboData, VertexBufferAttrib)):
+ attrib0="xPos",
+ )
+
+ def __init__(
+ self,
+ xVboData=None,
+ yVboData=None,
+ colorVboData=None,
+ distVboData=None,
+ color=(0.0, 0.0, 0.0, 1.0),
+ gapColor=None,
+ width=1,
+ dashOffset=0.0,
+ dashPattern=(),
+ drawMode=None,
+ offset=(0.0, 0.0),
+ ):
+ if xVboData is not None and not isinstance(xVboData, VertexBufferAttrib):
xVboData = numpy.array(xVboData, copy=False, dtype=numpy.float32)
self.xVboData = xVboData
- if (yVboData is not None and
- not isinstance(yVboData, VertexBufferAttrib)):
+ if yVboData is not None and not isinstance(yVboData, VertexBufferAttrib):
yVboData = numpy.array(yVboData, copy=False, dtype=numpy.float32)
self.yVboData = yVboData
# Compute distances if not given while providing numpy array coordinates
- if (isinstance(self.xVboData, numpy.ndarray) and
- isinstance(self.yVboData, numpy.ndarray) and
- distVboData is None):
+ if (
+ isinstance(self.xVboData, numpy.ndarray)
+ and isinstance(self.yVboData, numpy.ndarray)
+ and distVboData is None
+ ):
distVboData = distancesFromArrays(self.xVboData, self.yVboData)
- if (distVboData is not None and
- not isinstance(distVboData, VertexBufferAttrib)):
- distVboData = numpy.array(
- distVboData, copy=False, dtype=numpy.float32)
+ if distVboData is not None and not isinstance(distVboData, VertexBufferAttrib):
+ distVboData = numpy.array(distVboData, copy=False, dtype=numpy.float32)
self.distVboData = distVboData
if colorVboData is not None:
@@ -374,28 +417,14 @@ class GLLines2D(object):
self.useColorVboData = colorVboData is not None
self.color = color
- self.dash2ndColor = dash2ndColor
+ self.gapColor = gapColor
self.width = width
- self._style = None
- self.style = style
- self.dashPeriod = dashPeriod
+ self.dashPattern = dashPattern
+ self.dashOffset = dashOffset
self.offset = offset
self._drawMode = drawMode if drawMode is not None else gl.GL_LINE_STRIP
- @property
- def style(self):
- """Line style (Union[str,None])"""
- return self._style
-
- @style.setter
- def style(self, style):
- if style in _MPL_NONES:
- self._style = None
- else:
- assert style in self.STYLES
- self._style = style
-
@classmethod
def init(cls):
"""OpenGL context initialization"""
@@ -406,74 +435,57 @@ class GLLines2D(object):
:param RenderContext context:
"""
- width = self.width / 72. * context.dpi
-
- style = self.style
- if style is None:
+ if self.dashPattern is None: # Nothing to display
return
- elif style == SOLID:
+ if self.dashPattern == (): # No dash: solid line
program = self._SOLID_PROGRAM
program.use()
- else: # DASHED, DASHDOT, DOTTED
+ else: # Dashed line defined by 4 control points
program = self._DASH_PROGRAM
program.use()
- dashPeriod = self.dashPeriod * width
- if self.style == DOTTED:
- dash = (0.2 * dashPeriod,
- 0.5 * dashPeriod,
- 0.7 * dashPeriod,
- dashPeriod)
- elif self.style == DASHDOT:
- dash = (0.3 * dashPeriod,
- 0.5 * dashPeriod,
- 0.6 * dashPeriod,
- dashPeriod)
- else:
- dash = (0.5 * dashPeriod,
- dashPeriod,
- dashPeriod,
- dashPeriod)
-
- gl.glUniform4f(program.uniforms['dash'], *dash)
+ # Scale pattern by width, convert from lengths in points to offsets in pixels
+ scale = self.width / 72.0 * context.dpi
+ dashOffsets = tuple(
+ offset * scale for offset in numpy.cumsum(self.dashPattern)
+ )
+ gl.glUniform4f(program.uniforms["dash"], *dashOffsets)
+ gl.glUniform1f(program.uniforms["dashOffset"], self.dashOffset * scale)
- if self.dash2ndColor is None:
+ if self.gapColor is None:
# Use fully transparent color which gets discarded in shader
- dash2ndColor = (0., 0., 0., 0.)
+ gapColor = (0.0, 0.0, 0.0, 0.0)
else:
- dash2ndColor = self.dash2ndColor
- gl.glUniform4f(program.uniforms['dash2ndColor'], *dash2ndColor)
+ gapColor = self.gapColor
+ gl.glUniform4f(program.uniforms["gapColor"], *gapColor)
viewWidth = gl.glGetFloatv(gl.GL_VIEWPORT)[2]
xNDCPerData = (
- numpy.dot(context.matrix, [1., 0., 0., 1.])[0] -
- numpy.dot(context.matrix, [0., 0., 0., 1.])[0])
+ numpy.dot(context.matrix, [1.0, 0.0, 0.0, 1.0])[0]
+ - numpy.dot(context.matrix, [0.0, 0.0, 0.0, 1.0])[0]
+ )
xPixelPerData = 0.5 * viewWidth * xNDCPerData
- gl.glUniform1f(program.uniforms['distanceScale'], xPixelPerData)
+ gl.glUniform1f(program.uniforms["distanceScale"], xPixelPerData)
- distAttrib = program.attributes['distance']
+ distAttrib = program.attributes["distance"]
gl.glEnableVertexAttribArray(distAttrib)
if isinstance(self.distVboData, VertexBufferAttrib):
self.distVboData.setVertexAttrib(distAttrib)
else:
- gl.glVertexAttribPointer(distAttrib,
- 1,
- gl.GL_FLOAT,
- False,
- 0,
- self.distVboData)
-
- if width != 1:
- gl.glEnable(gl.GL_LINE_SMOOTH)
-
- matrix = numpy.dot(context.matrix,
- mat4Translate(*self.offset)).astype(numpy.float32)
- gl.glUniformMatrix4fv(program.uniforms['matrix'],
- 1, gl.GL_TRUE, matrix)
-
- colorAttrib = program.attributes['color']
+ gl.glVertexAttribPointer(
+ distAttrib, 1, gl.GL_FLOAT, False, 0, self.distVboData
+ )
+
+ gl.glEnable(gl.GL_LINE_SMOOTH)
+
+ matrix = numpy.dot(context.matrix, mat4Translate(*self.offset)).astype(
+ numpy.float32
+ )
+ gl.glUniformMatrix4fv(program.uniforms["matrix"], 1, gl.GL_TRUE, matrix)
+
+ colorAttrib = program.attributes["color"]
if self.useColorVboData and self.colorVboData is not None:
gl.glEnableVertexAttribArray(colorAttrib)
self.colorVboData.setVertexAttrib(colorAttrib)
@@ -481,37 +493,31 @@ class GLLines2D(object):
gl.glDisableVertexAttribArray(colorAttrib)
gl.glVertexAttrib4f(colorAttrib, *self.color)
- xPosAttrib = program.attributes['xPos']
+ xPosAttrib = program.attributes["xPos"]
gl.glEnableVertexAttribArray(xPosAttrib)
if isinstance(self.xVboData, VertexBufferAttrib):
self.xVboData.setVertexAttrib(xPosAttrib)
else:
- gl.glVertexAttribPointer(xPosAttrib,
- 1,
- gl.GL_FLOAT,
- False,
- 0,
- self.xVboData)
-
- yPosAttrib = program.attributes['yPos']
+ gl.glVertexAttribPointer(
+ xPosAttrib, 1, gl.GL_FLOAT, False, 0, self.xVboData
+ )
+
+ yPosAttrib = program.attributes["yPos"]
gl.glEnableVertexAttribArray(yPosAttrib)
if isinstance(self.yVboData, VertexBufferAttrib):
self.yVboData.setVertexAttrib(yPosAttrib)
else:
- gl.glVertexAttribPointer(yPosAttrib,
- 1,
- gl.GL_FLOAT,
- False,
- 0,
- self.yVboData)
-
- gl.glLineWidth(width)
+ gl.glVertexAttribPointer(
+ yPosAttrib, 1, gl.GL_FLOAT, False, 0, self.yVboData
+ )
+
+ gl.glLineWidth(self.width / 72.0 * context.dpi)
gl.glDrawArrays(self._drawMode, 0, self.xVboData.size)
gl.glDisable(gl.GL_LINE_SMOOTH)
-def distancesFromArrays(xData, yData, ratio: float=1.):
+def distancesFromArrays(xData, yData, ratio: float = 1.0):
"""Returns distances between each points
:param numpy.ndarray xData: X coordinate of points
@@ -520,8 +526,11 @@ def distancesFromArrays(xData, yData, ratio: float=1.):
:rtype: numpy.ndarray
"""
# Split array into sub-shapes at not finite points
- splits = numpy.nonzero(numpy.logical_not(numpy.logical_and(
- numpy.isfinite(xData), numpy.isfinite(yData))))[0]
+ splits = numpy.nonzero(
+ numpy.logical_not(
+ numpy.logical_and(numpy.isfinite(xData), numpy.isfinite(yData))
+ )
+ )[0]
splits = numpy.concatenate(([-1], splits, [len(xData) - 1]))
# Compute distance independently for each sub-shapes,
@@ -530,23 +539,35 @@ def distancesFromArrays(xData, yData, ratio: float=1.):
for begin, end in zip(splits[:-1] + 1, splits[1:] + 1):
if begin == end: # Empty shape
continue
- elif end - begin == 1: # Single element
+ elif end - begin == 1: # Single element
distances.append(numpy.array([0], dtype=numpy.float32))
else:
- deltas = numpy.dstack((
- numpy.ediff1d(xData[begin:end], to_begin=numpy.float32(0.)),
- numpy.ediff1d(yData[begin:end] * ratio, to_begin=numpy.float32(0.))))[0]
- distances.append(
- numpy.cumsum(numpy.sqrt(numpy.sum(deltas ** 2, axis=1))))
+ deltas = numpy.dstack(
+ (
+ numpy.ediff1d(xData[begin:end], to_begin=numpy.float32(0.0)),
+ numpy.ediff1d(
+ yData[begin:end] * ratio, to_begin=numpy.float32(0.0)
+ ),
+ )
+ )[0]
+ distances.append(numpy.cumsum(numpy.sqrt(numpy.sum(deltas**2, axis=1))))
return numpy.concatenate(distances)
# points ######################################################################
-DIAMOND, CIRCLE, SQUARE, PLUS, X_MARKER, POINT, PIXEL, ASTERISK = \
- 'd', 'o', 's', '+', 'x', '.', ',', '*'
+DIAMOND, CIRCLE, SQUARE, PLUS, X_MARKER, POINT, PIXEL, ASTERISK = (
+ "d",
+ "o",
+ "s",
+ "+",
+ "x",
+ ".",
+ ",",
+ "*",
+)
-H_LINE, V_LINE, HEART = '_', '|', u'\u2665'
+H_LINE, V_LINE, HEART = "_", "|", "\u2665"
TICK_LEFT = "tickleft"
TICK_RIGHT = "tickright"
@@ -570,9 +591,27 @@ class Points2D(object):
:param List[float] offset: Translation of coordinates (ox, oy)
"""
- MARKERS = (DIAMOND, CIRCLE, SQUARE, PLUS, X_MARKER, POINT, PIXEL, ASTERISK,
- H_LINE, V_LINE, HEART, TICK_LEFT, TICK_RIGHT, TICK_UP, TICK_DOWN,
- CARET_LEFT, CARET_RIGHT, CARET_UP, CARET_DOWN)
+ MARKERS = (
+ DIAMOND,
+ CIRCLE,
+ SQUARE,
+ PLUS,
+ X_MARKER,
+ POINT,
+ PIXEL,
+ ASTERISK,
+ H_LINE,
+ V_LINE,
+ HEART,
+ TICK_LEFT,
+ TICK_RIGHT,
+ TICK_UP,
+ TICK_DOWN,
+ CARET_LEFT,
+ CARET_RIGHT,
+ CARET_UP,
+ CARET_DOWN,
+ )
"""List of supported markers"""
_VERTEX_SHADER = """
@@ -595,47 +634,39 @@ class Points2D(object):
"""
_FRAGMENT_SHADER_SYMBOLS = {
- DIAMOND: """
+ DIAMOND: """
float alphaSymbol(vec2 coord, float size) {
vec2 centerCoord = abs(coord - vec2(0.5, 0.5));
float f = centerCoord.x + centerCoord.y;
return clamp(size * (0.5 - f), 0.0, 1.0);
}
""",
- CIRCLE: """
+ CIRCLE: """
float alphaSymbol(vec2 coord, float size) {
float radius = 0.5;
float r = distance(coord, vec2(0.5, 0.5));
return clamp(size * (radius - r), 0.0, 1.0);
}
""",
- SQUARE: """
+ SQUARE: """
float alphaSymbol(vec2 coord, float size) {
return 1.0;
}
""",
- PLUS: """
+ PLUS: """
float alphaSymbol(vec2 coord, float size) {
vec2 d = abs(size * (coord - vec2(0.5, 0.5)));
- if (min(d.x, d.y) < 0.5) {
- return 1.0;
- } else {
- return 0.0;
- }
+ return local_smoothstep(1.5, 0.5, min(d.x, d.y));
}
""",
- X_MARKER: """
+ X_MARKER: """
float alphaSymbol(vec2 coord, float size) {
vec2 pos = floor(size * coord) + 0.5;
vec2 d_x = abs(pos.x + vec2(- pos.y, pos.y - size));
- if (min(d_x.x, d_x.y) <= 0.5) {
- return 1.0;
- } else {
- return 0.0;
- }
+ return local_smoothstep(1.5, 0.5, min(d_x.x, d_x.y));
}
""",
- ASTERISK: """
+ ASTERISK: """
float alphaSymbol(vec2 coord, float size) {
/* Combining +, x and circle */
vec2 d_plus = abs(size * (coord - vec2(0.5, 0.5)));
@@ -651,27 +682,19 @@ class Points2D(object):
}
}
""",
- H_LINE: """
+ H_LINE: """
float alphaSymbol(vec2 coord, float size) {
- float dy = abs(size * (coord.y - 0.5));
- if (dy < 0.5) {
- return 1.0;
- } else {
- return 0.0;
- }
+ float d = abs(size * (coord.y - 0.5));
+ return local_smoothstep(1.5, 0.5, d);
}
""",
- V_LINE: """
+ V_LINE: """
float alphaSymbol(vec2 coord, float size) {
- float dx = abs(size * (coord.x - 0.5));
- if (dx < 0.5) {
- return 1.0;
- } else {
- return 0.0;
- }
+ float d = abs(size * (coord.x - 0.5));
+ return local_smoothstep(1.5, 0.5, d);
}
""",
- HEART: """
+ HEART: """
float alphaSymbol(vec2 coord, float size) {
coord = (coord - 0.5) * 2.;
coord *= 0.75;
@@ -682,93 +705,89 @@ class Points2D(object):
float d = (13.0*h - 22.0*h*h + 10.0*h*h*h)/(6.0-5.0*h);
float res = clamp(r-d, 0., 1.);
// antialiasing
- res = smoothstep(0.1, 0.001, res);
+ res = local_smoothstep(0.1, 0.001, res);
return res;
}
""",
- TICK_LEFT: """
+ TICK_LEFT: """
float alphaSymbol(vec2 coord, float size) {
coord = size * (coord - 0.5);
float dy = abs(coord.y);
- if (dy < 0.5 && coord.x < 0.5) {
- return 1.0;
- } else {
+ if (coord.x > 0.5) {
return 0.0;
}
+ return local_smoothstep(1.5, 0.5, dy);
}
""",
- TICK_RIGHT: """
+ TICK_RIGHT: """
float alphaSymbol(vec2 coord, float size) {
coord = size * (coord - 0.5);
float dy = abs(coord.y);
- if (dy < 0.5 && coord.x > -0.5) {
- return 1.0;
- } else {
+ if (coord.x < -0.5) {
return 0.0;
}
+ return local_smoothstep(1.5, 0.5, dy);
}
""",
- TICK_UP: """
+ TICK_UP: """
float alphaSymbol(vec2 coord, float size) {
- coord = size * (coord - 0.5);
+ coord = size * (coord - 0.5);
float dx = abs(coord.x);
- if (dx < 0.5 && coord.y < 0.5) {
- return 1.0;
- } else {
+ if (coord.y > 0.5) {
return 0.0;
}
+ return local_smoothstep(1.5, 0.5, dx);
}
""",
- TICK_DOWN: """
+ TICK_DOWN: """
float alphaSymbol(vec2 coord, float size) {
coord = size * (coord - 0.5);
float dx = abs(coord.x);
- if (dx < 0.5 && coord.y > -0.5) {
- return 1.0;
- } else {
+ if (coord.y < -0.5) {
return 0.0;
}
+ return local_smoothstep(1.5, 0.5, dx);
}
""",
- CARET_LEFT: """
+ CARET_LEFT: """
float alphaSymbol(vec2 coord, float size) {
coord = size * (coord - 0.5);
float d = abs(coord.x) - abs(coord.y);
if (d >= -0.1 && coord.x > 0.5) {
- return smoothstep(-0.1, 0.1, d);
+ return local_smoothstep(-0.1, 0.1, d);
} else {
return 0.0;
}
}
""",
- CARET_RIGHT: """
+ CARET_RIGHT: """
float alphaSymbol(vec2 coord, float size) {
coord = size * (coord - 0.5);
float d = abs(coord.x) - abs(coord.y);
if (d >= -0.1 && coord.x < 0.5) {
- return smoothstep(-0.1, 0.1, d);
+ return local_smoothstep(-0.1, 0.1, d);
} else {
return 0.0;
}
}
""",
- CARET_UP: """
+ CARET_UP: """
float alphaSymbol(vec2 coord, float size) {
coord = size * (coord - 0.5);
float d = abs(coord.y) - abs(coord.x);
if (d >= -0.1 && coord.y > 0.5) {
- return smoothstep(-0.1, 0.1, d);
+ return local_smoothstep(-0.1, 0.1, d);
} else {
return 0.0;
}
}
""",
- CARET_DOWN: """
+ CARET_DOWN: """
float alphaSymbol(vec2 coord, float size) {
coord = size * (coord - 0.5);
float d = abs(coord.y) - abs(coord.x);
if (d >= -0.1 && coord.y < 0.5) {
- return smoothstep(-0.1, 0.1, d);
+ return local_smoothstep(-0.1, 0.1, d);
} else {
return 0.0;
}
@@ -783,6 +802,13 @@ class Points2D(object):
varying vec4 vColor;
+ /* smoothstep function implementation to support GLSL 1.20 */
+ float local_smoothstep(float edge0, float edge1, float x) {
+ float t;
+ t = clamp((x - edge0) / (edge1 - edge0), 0.0, 1.0);
+ return t * t * (3.0 - 2.0 * t);
+ }
+
%s
void main(void) {
@@ -797,22 +823,27 @@ class Points2D(object):
_PROGRAMS = {}
- def __init__(self, xVboData=None, yVboData=None, colorVboData=None,
- marker=SQUARE, color=(0., 0., 0., 1.), size=7,
- offset=(0., 0.)):
+ def __init__(
+ self,
+ xVboData=None,
+ yVboData=None,
+ colorVboData=None,
+ marker=SQUARE,
+ color=(0.0, 0.0, 0.0, 1.0),
+ size=7,
+ offset=(0.0, 0.0),
+ ):
self.color = color
self._marker = None
self.marker = marker
self.size = size
self.offset = offset
- if (xVboData is not None and
- not isinstance(xVboData, VertexBufferAttrib)):
+ if xVboData is not None and not isinstance(xVboData, VertexBufferAttrib):
xVboData = numpy.array(xVboData, copy=False, dtype=numpy.float32)
self.xVboData = xVboData
- if (yVboData is not None and
- not isinstance(yVboData, VertexBufferAttrib)):
+ if yVboData is not None and not isinstance(yVboData, VertexBufferAttrib):
yVboData = numpy.array(yVboData, copy=False, dtype=numpy.float32)
self.yVboData = yVboData
@@ -845,9 +876,11 @@ class Points2D(object):
if marker not in cls._PROGRAMS:
cls._PROGRAMS[marker] = Program(
vertexShader=cls._VERTEX_SHADER,
- fragmentShader=(cls._FRAGMENT_SHADER_TEMPLATE %
- cls._FRAGMENT_SHADER_SYMBOLS[marker]),
- attrib0='xPos')
+ fragmentShader=(
+ cls._FRAGMENT_SHADER_TEMPLATE % cls._FRAGMENT_SHADER_SYMBOLS[marker]
+ ),
+ attrib0="xPos",
+ )
return cls._PROGRAMS[marker]
@@ -873,9 +906,10 @@ class Points2D(object):
program = self._getProgram(self.marker)
program.use()
- matrix = numpy.dot(context.matrix,
- mat4Translate(*self.offset)).astype(numpy.float32)
- gl.glUniformMatrix4fv(program.uniforms['matrix'], 1, gl.GL_TRUE, matrix)
+ matrix = numpy.dot(context.matrix, mat4Translate(*self.offset)).astype(
+ numpy.float32
+ )
+ gl.glUniformMatrix4fv(program.uniforms["matrix"], 1, gl.GL_TRUE, matrix)
if self.marker == PIXEL:
size = 1
@@ -883,17 +917,24 @@ class Points2D(object):
size = math.ceil(0.5 * self.size) + 1 # Mimic Matplotlib point
else:
size = self.size
- size = size / 72. * context.dpi
-
- if self.marker in (PLUS, H_LINE, V_LINE,
- TICK_LEFT, TICK_RIGHT, TICK_UP, TICK_DOWN):
+ size = size / 72.0 * context.dpi
+
+ if self.marker in (
+ PLUS,
+ H_LINE,
+ V_LINE,
+ TICK_LEFT,
+ TICK_RIGHT,
+ TICK_UP,
+ TICK_DOWN,
+ ):
# Convert to nearest odd number
- size = size // 2 * 2 + 1.
+ size = size // 2 * 2 + 1.0
- gl.glUniform1f(program.uniforms['size'], size)
+ gl.glUniform1f(program.uniforms["size"], size)
# gl.glPointSize(self.size)
- cAttrib = program.attributes['color']
+ cAttrib = program.attributes["color"]
if self.useColorVboData and self.colorVboData is not None:
gl.glEnableVertexAttribArray(cAttrib)
self.colorVboData.setVertexAttrib(cAttrib)
@@ -901,36 +942,30 @@ class Points2D(object):
gl.glDisableVertexAttribArray(cAttrib)
gl.glVertexAttrib4f(cAttrib, *self.color)
- xPosAttrib = program.attributes['xPos']
+ xPosAttrib = program.attributes["xPos"]
gl.glEnableVertexAttribArray(xPosAttrib)
if isinstance(self.xVboData, VertexBufferAttrib):
self.xVboData.setVertexAttrib(xPosAttrib)
else:
- gl.glVertexAttribPointer(xPosAttrib,
- 1,
- gl.GL_FLOAT,
- False,
- 0,
- self.xVboData)
-
- yPosAttrib = program.attributes['yPos']
+ gl.glVertexAttribPointer(
+ xPosAttrib, 1, gl.GL_FLOAT, False, 0, self.xVboData
+ )
+
+ yPosAttrib = program.attributes["yPos"]
gl.glEnableVertexAttribArray(yPosAttrib)
if isinstance(self.yVboData, VertexBufferAttrib):
self.yVboData.setVertexAttrib(yPosAttrib)
else:
- gl.glVertexAttribPointer(yPosAttrib,
- 1,
- gl.GL_FLOAT,
- False,
- 0,
- self.yVboData)
-
+ gl.glVertexAttribPointer(
+ yPosAttrib, 1, gl.GL_FLOAT, False, 0, self.yVboData
+ )
gl.glDrawArrays(gl.GL_POINTS, 0, self.xVboData.size)
# error bars ##################################################################
+
class _ErrorBars(object):
"""Display errors bars.
@@ -956,49 +991,58 @@ class _ErrorBars(object):
:param List[float] offset: Translation of coordinates (ox, oy)
"""
- def __init__(self, xData, yData, xError, yError,
- xMin, yMin,
- color=(0., 0., 0., 1.),
- offset=(0., 0.)):
+ def __init__(
+ self,
+ xData,
+ yData,
+ xError,
+ yError,
+ xMin,
+ yMin,
+ color=(0.0, 0.0, 0.0, 1.0),
+ offset=(0.0, 0.0),
+ ):
self._attribs = None
self._xMin, self._yMin = xMin, yMin
self.offset = offset
if xError is not None or yError is not None:
- self._xData = numpy.array(
- xData, order='C', dtype=numpy.float32, copy=False)
- self._yData = numpy.array(
- yData, order='C', dtype=numpy.float32, copy=False)
+ self._xData = numpy.array(xData, order="C", dtype=numpy.float32, copy=False)
+ self._yData = numpy.array(yData, order="C", dtype=numpy.float32, copy=False)
# This also works if xError, yError is a float/int
self._xError = numpy.array(
- xError, order='C', dtype=numpy.float32, copy=False)
+ xError, order="C", dtype=numpy.float32, copy=False
+ )
self._yError = numpy.array(
- yError, order='C', dtype=numpy.float32, copy=False)
+ yError, order="C", dtype=numpy.float32, copy=False
+ )
else:
self._xData, self._yData = None, None
self._xError, self._yError = None, None
self._lines = GLLines2D(
- None, None, color=color, drawMode=gl.GL_LINES, offset=offset)
+ None, None, color=color, drawMode=gl.GL_LINES, offset=offset
+ )
self._xErrPoints = Points2D(
- None, None, color=color, marker=V_LINE, offset=offset)
+ None, None, color=color, marker=V_LINE, offset=offset
+ )
self._yErrPoints = Points2D(
- None, None, color=color, marker=H_LINE, offset=offset)
+ None, None, color=color, marker=H_LINE, offset=offset
+ )
def _buildVertices(self):
"""Generates error bars vertices"""
- nbLinesPerDataPts = (0 if self._xError is None else 2) + \
- (0 if self._yError is None else 2)
+ nbLinesPerDataPts = (0 if self._xError is None else 2) + (
+ 0 if self._yError is None else 2
+ )
nbDataPts = len(self._xData)
# interleave coord+error, coord-error.
# xError vertices first if any, then yError vertices if any.
- xCoords = numpy.empty(nbDataPts * nbLinesPerDataPts * 2,
- dtype=numpy.float32)
- yCoords = numpy.empty(nbDataPts * nbLinesPerDataPts * 2,
- dtype=numpy.float32)
+ xCoords = numpy.empty(nbDataPts * nbLinesPerDataPts * 2, dtype=numpy.float32)
+ yCoords = numpy.empty(nbDataPts * nbLinesPerDataPts * 2, dtype=numpy.float32)
if self._xError is not None: # errors on the X axis
if len(self._xError.shape) == 2:
@@ -1010,15 +1054,15 @@ class _ErrorBars(object):
# Interleave vertices for xError
endXError = 4 * nbDataPts
with numpy.errstate(invalid="ignore"):
- xCoords[0:endXError-3:4] = self._xData + xErrorPlus
- xCoords[1:endXError-2:4] = self._xData
- xCoords[2:endXError-1:4] = self._xData
+ xCoords[0 : endXError - 3 : 4] = self._xData + xErrorPlus
+ xCoords[1 : endXError - 2 : 4] = self._xData
+ xCoords[2 : endXError - 1 : 4] = self._xData
with numpy.errstate(invalid="ignore"):
xCoords[3:endXError:4] = self._xData - xErrorMinus
- yCoords[0:endXError-3:4] = self._yData
- yCoords[1:endXError-2:4] = self._yData
- yCoords[2:endXError-1:4] = self._yData
+ yCoords[0 : endXError - 3 : 4] = self._yData
+ yCoords[1 : endXError - 2 : 4] = self._yData
+ yCoords[2 : endXError - 1 : 4] = self._yData
yCoords[3:endXError:4] = self._yData
else:
@@ -1033,16 +1077,16 @@ class _ErrorBars(object):
# Interleave vertices for yError
xCoords[endXError::4] = self._xData
- xCoords[endXError+1::4] = self._xData
- xCoords[endXError+2::4] = self._xData
- xCoords[endXError+3::4] = self._xData
+ xCoords[endXError + 1 :: 4] = self._xData
+ xCoords[endXError + 2 :: 4] = self._xData
+ xCoords[endXError + 3 :: 4] = self._xData
with numpy.errstate(invalid="ignore"):
yCoords[endXError::4] = self._yData + yErrorPlus
- yCoords[endXError+1::4] = self._yData
- yCoords[endXError+2::4] = self._yData
+ yCoords[endXError + 1 :: 4] = self._yData
+ yCoords[endXError + 2 :: 4] = self._yData
with numpy.errstate(invalid="ignore"):
- yCoords[endXError+3::4] = self._yData - yErrorMinus
+ yCoords[endXError + 3 :: 4] = self._yData - yErrorMinus
return xCoords, yCoords
@@ -1069,12 +1113,10 @@ class _ErrorBars(object):
# Set yError points using the same VBO as lines
self._yErrPoints.xVboData = xAttrib.copy()
self._yErrPoints.xVboData.size //= 2
- self._yErrPoints.xVboData.offset += (xAttrib.itemsize *
- xAttrib.size // 2)
+ self._yErrPoints.xVboData.offset += xAttrib.itemsize * xAttrib.size // 2
self._yErrPoints.yVboData = yAttrib.copy()
self._yErrPoints.yVboData.size //= 2
- self._yErrPoints.yVboData.offset += (yAttrib.itemsize *
- yAttrib.size // 2)
+ self._yErrPoints.yVboData.offset += yAttrib.itemsize * yAttrib.size // 2
def render(self, context):
"""Perform rendering
@@ -1103,12 +1145,14 @@ class _ErrorBars(object):
# curves ######################################################################
+
def _proxyProperty(*componentsAttributes):
"""Create a property to access an attribute of attribute(s).
Useful for composition.
Supports multiple components this way:
getter returns the first found, setter sets all
"""
+
def getter(self):
for compName, attrName in componentsAttributes:
try:
@@ -1122,22 +1166,30 @@ def _proxyProperty(*componentsAttributes):
for compName, attrName in componentsAttributes:
component = getattr(self, compName)
setattr(component, attrName, value)
+
return property(getter, setter)
class GLPlotCurve2D(GLPlotItem):
- def __init__(self, xData, yData, colorData=None,
- xError=None, yError=None,
- lineStyle=SOLID,
- lineColor=(0., 0., 0., 1.),
- lineWidth=1,
- lineDashPeriod=20,
- marker=SQUARE,
- markerColor=(0., 0., 0., 1.),
- markerSize=7,
- fillColor=None,
- baseline=None,
- isYLog=False):
+ def __init__(
+ self,
+ xData,
+ yData,
+ colorData=None,
+ xError=None,
+ yError=None,
+ lineColor=(0.0, 0.0, 0.0, 1.0),
+ lineGapColor=None,
+ lineWidth=1,
+ lineDashOffset=0.0,
+ lineDashPattern=(),
+ marker=SQUARE,
+ markerColor=(0.0, 0.0, 0.0, 1.0),
+ markerSize=7,
+ fillColor=None,
+ baseline=None,
+ isYLog=False,
+ ):
super().__init__()
self._ratio = None
self.colorData = colorData
@@ -1147,7 +1199,7 @@ class GLPlotCurve2D(GLPlotItem):
self.xMin, self.xMax = min_max(xData, min_positive=False)
else:
# Takes the error into account
- if hasattr(xError, 'shape') and len(xError.shape) == 2:
+ if hasattr(xError, "shape") and len(xError.shape) == 2:
xErrorMinus, xErrorPlus = xError[0], xError[1]
else:
xErrorMinus, xErrorPlus = xError, xError
@@ -1159,7 +1211,7 @@ class GLPlotCurve2D(GLPlotItem):
self.yMin, self.yMax = min_max(yData, min_positive=False)
else:
# Takes the error into account
- if hasattr(yError, 'shape') and len(yError.shape) == 2:
+ if hasattr(yError, "shape") and len(yError.shape) == 2:
yErrorMinus, yErrorPlus = yError[0], yError[1]
else:
yErrorMinus, yErrorPlus = yError, yError
@@ -1175,44 +1227,53 @@ class GLPlotCurve2D(GLPlotItem):
self.yData = (yData - self.offset[1]).astype(numpy.float32)
else: # float32
- self.offset = 0., 0.
+ self.offset = 0.0, 0.0
self.xData = xData
self.yData = yData
if fillColor is not None:
+
def deduce_baseline(baseline):
if baseline is None:
_baseline = 0
else:
_baseline = baseline
if not isinstance(_baseline, numpy.ndarray):
- _baseline = numpy.repeat(_baseline,
- len(self.xData))
+ _baseline = numpy.repeat(_baseline, len(self.xData))
if isYLog is True:
- with numpy.errstate(divide='ignore', invalid='ignore'):
+ with numpy.errstate(divide="ignore", invalid="ignore"):
log_val = numpy.log10(_baseline)
- _baseline = numpy.where(_baseline>0.0, log_val, -38)
+ _baseline = numpy.where(_baseline > 0.0, log_val, -38)
return _baseline
_baseline = deduce_baseline(baseline)
# Use different baseline depending of Y log scale
- self.fill = _Fill2D(self.xData, self.yData,
- baseline=_baseline,
- color=fillColor,
- offset=self.offset)
+ self.fill = _Fill2D(
+ self.xData,
+ self.yData,
+ baseline=_baseline,
+ color=fillColor,
+ offset=self.offset,
+ )
else:
self.fill = None
- self._errorBars = _ErrorBars(self.xData, self.yData,
- xError, yError,
- self.xMin, self.yMin,
- offset=self.offset)
+ self._errorBars = _ErrorBars(
+ self.xData,
+ self.yData,
+ xError,
+ yError,
+ self.xMin,
+ self.yMin,
+ offset=self.offset,
+ )
self.lines = GLLines2D()
- self.lines.style = lineStyle
self.lines.color = lineColor
+ self.lines.gapColor = lineGapColor
self.lines.width = lineWidth
- self.lines.dashPeriod = lineDashPeriod
+ self.lines.dashOffset = lineDashOffset
+ self.lines.dashPattern = lineDashPattern
self.lines.offset = self.offset
self.points = Points2D()
@@ -1221,31 +1282,33 @@ class GLPlotCurve2D(GLPlotItem):
self.points.size = markerSize
self.points.offset = self.offset
- xVboData = _proxyProperty(('lines', 'xVboData'), ('points', 'xVboData'))
+ xVboData = _proxyProperty(("lines", "xVboData"), ("points", "xVboData"))
+
+ yVboData = _proxyProperty(("lines", "yVboData"), ("points", "yVboData"))
- yVboData = _proxyProperty(('lines', 'yVboData'), ('points', 'yVboData'))
+ colorVboData = _proxyProperty(("lines", "colorVboData"), ("points", "colorVboData"))
- colorVboData = _proxyProperty(('lines', 'colorVboData'),
- ('points', 'colorVboData'))
+ useColorVboData = _proxyProperty(
+ ("lines", "useColorVboData"), ("points", "useColorVboData")
+ )
- useColorVboData = _proxyProperty(('lines', 'useColorVboData'),
- ('points', 'useColorVboData'))
+ distVboData = _proxyProperty(("lines", "distVboData"))
- distVboData = _proxyProperty(('lines', 'distVboData'))
+ lineColor = _proxyProperty(("lines", "color"))
- lineStyle = _proxyProperty(('lines', 'style'))
+ lineGapColor = _proxyProperty(("lines", "gapColor"))
- lineColor = _proxyProperty(('lines', 'color'))
+ lineWidth = _proxyProperty(("lines", "width"))
- lineWidth = _proxyProperty(('lines', 'width'))
+ lineDashOffset = _proxyProperty(("lines", "dashOffset"))
- lineDashPeriod = _proxyProperty(('lines', 'dashPeriod'))
+ lineDashPattern = _proxyProperty(("lines", "dashPattern"))
- marker = _proxyProperty(('points', 'marker'))
+ marker = _proxyProperty(("points", "marker"))
- markerColor = _proxyProperty(('points', 'color'))
+ markerColor = _proxyProperty(("points", "color"))
- markerSize = _proxyProperty(('points', 'size'))
+ markerSize = _proxyProperty(("points", "size"))
@classmethod
def init(cls):
@@ -1257,25 +1320,28 @@ class GLPlotCurve2D(GLPlotItem):
"""Rendering preparation: build indices and bounding box vertices"""
if self.xVboData is None:
xAttrib, yAttrib, cAttrib, dAttrib = None, None, None, None
- if self.lineStyle in (DASHED, DASHDOT, DOTTED):
+ if self.lineDashPattern:
dists = distancesFromArrays(self.xData, self.yData, self._ratio)
if self.colorData is None:
xAttrib, yAttrib, dAttrib = vertexBuffer(
- (self.xData, self.yData, dists))
+ (self.xData, self.yData, dists)
+ )
else:
xAttrib, yAttrib, cAttrib, dAttrib = vertexBuffer(
- (self.xData, self.yData, self.colorData, dists))
+ (self.xData, self.yData, self.colorData, dists)
+ )
elif self.colorData is None:
xAttrib, yAttrib = vertexBuffer((self.xData, self.yData))
else:
xAttrib, yAttrib, cAttrib = vertexBuffer(
- (self.xData, self.yData, self.colorData))
+ (self.xData, self.yData, self.colorData)
+ )
self.xVboData = xAttrib
self.yVboData = yAttrib
self.distVboData = dAttrib
- if cAttrib is not None and self.colorData.dtype.kind == 'u':
+ if cAttrib is not None and self.colorData.dtype.kind == "u":
cAttrib.normalization = True # Normalize uint to [0, 1]
self.colorVboData = cAttrib
self.useColorVboData = cAttrib is not None
@@ -1285,13 +1351,17 @@ class GLPlotCurve2D(GLPlotItem):
:param RenderContext context: Rendering information
"""
- if self.lineStyle in (DASHED, DASHDOT, DOTTED):
+ if self.lineDashPattern:
visibleRanges = context.plotFrame.transformedDataRanges
xLimits = visibleRanges.x
- yLimits = visibleRanges.y if self.yaxis == 'left' else visibleRanges.y2
+ yLimits = visibleRanges.y if self.yaxis == "left" else visibleRanges.y2
width, height = context.plotFrame.plotSize
- ratio = (height * (xLimits[1] - xLimits[0])) / (width * (yLimits[1] - yLimits[0]))
- if self._ratio is None or abs(1. - ratio/self._ratio) > 0.05: # Tolerate 5% difference
+ ratio = (height * (xLimits[1] - xLimits[0])) / (
+ width * (yLimits[1] - yLimits[0])
+ )
+ if (
+ self._ratio is None or abs(1.0 - ratio / self._ratio) > 0.05
+ ): # Tolerate 5% difference
# Rebuild curve buffers to update distances
self._ratio = ratio
self.discard()
@@ -1318,9 +1388,11 @@ class GLPlotCurve2D(GLPlotItem):
self.fill.discard()
def isInitialized(self):
- return (self.xVboData is not None or
- self._errorBars.isInitialized() or
- (self.fill is not None and self.fill.isInitialized()))
+ return (
+ self.xVboData is not None
+ or self._errorBars.isInitialized()
+ or (self.fill is not None and self.fill.isInitialized())
+ )
def pick(self, xPickMin, yPickMin, xPickMax, yPickMax):
"""Perform picking on the curve according to its rendering.
@@ -1335,9 +1407,13 @@ class GLPlotCurve2D(GLPlotItem):
:return: The indices of the picked data
:rtype: Union[List[int],None]
"""
- if (self.marker is None and self.lineStyle is None) or \
- self.xMin > xPickMax or xPickMin > self.xMax or \
- self.yMin > yPickMax or yPickMin > self.yMax:
+ if (
+ (self.marker is None and self.lineDashPattern is None)
+ or self.xMin > xPickMax
+ or xPickMin > self.xMax
+ or self.yMin > yPickMax
+ or yPickMin > self.yMax
+ ):
return None
# offset picking bounds
@@ -1346,25 +1422,27 @@ class GLPlotCurve2D(GLPlotItem):
yPickMin = yPickMin - self.offset[1]
yPickMax = yPickMax - self.offset[1]
- if self.lineStyle is not None:
+ if self.lineDashPattern is not None:
# Using Cohen-Sutherland algorithm for line clipping
- with numpy.errstate(invalid='ignore'): # Ignore NaN comparison warnings
- codes = ((self.yData > yPickMax) << 3) | \
- ((self.yData < yPickMin) << 2) | \
- ((self.xData > xPickMax) << 1) | \
- (self.xData < xPickMin)
-
- notNaN = numpy.logical_not(numpy.logical_or(
- numpy.isnan(self.xData), numpy.isnan(self.yData)))
+ with numpy.errstate(invalid="ignore"): # Ignore NaN comparison warnings
+ codes = (
+ ((self.yData > yPickMax) << 3)
+ | ((self.yData < yPickMin) << 2)
+ | ((self.xData > xPickMax) << 1)
+ | (self.xData < xPickMin)
+ )
+
+ notNaN = numpy.logical_not(
+ numpy.logical_or(numpy.isnan(self.xData), numpy.isnan(self.yData))
+ )
# Add all points that are inside the picking area
- indices = numpy.nonzero(
- numpy.logical_and(codes == 0, notNaN))[0].tolist()
+ indices = numpy.nonzero(numpy.logical_and(codes == 0, notNaN))[0].tolist()
# Segment that might cross the area with no end point inside it
- segToTestIdx = numpy.nonzero((codes[:-1] != 0) &
- (codes[1:] != 0) &
- ((codes[:-1] & codes[1:]) == 0))[0]
+ segToTestIdx = numpy.nonzero(
+ (codes[:-1] != 0) & (codes[1:] != 0) & ((codes[:-1] & codes[1:]) == 0)
+ )[0]
TOP, BOTTOM, RIGHT, LEFT = (1 << 3), (1 << 2), (1 << 1), (1 << 0)
@@ -1405,10 +1483,12 @@ class GLPlotCurve2D(GLPlotItem):
indices.sort()
else:
- with numpy.errstate(invalid='ignore'): # Ignore NaN comparison warnings
- indices = numpy.nonzero((self.xData >= xPickMin) &
- (self.xData <= xPickMax) &
- (self.yData >= yPickMin) &
- (self.yData <= yPickMax))[0].tolist()
+ with numpy.errstate(invalid="ignore"): # Ignore NaN comparison warnings
+ indices = numpy.nonzero(
+ (self.xData >= xPickMin)
+ & (self.xData <= xPickMax)
+ & (self.yData >= yPickMin)
+ & (self.yData <= yPickMax)
+ )[0].tolist()
return tuple(indices) if len(indices) > 0 else None
diff --git a/src/silx/gui/plot/backends/glutils/GLPlotFrame.py b/src/silx/gui/plot/backends/glutils/GLPlotFrame.py
index e5fabf2..42cfa50 100644
--- a/src/silx/gui/plot/backends/glutils/GLPlotFrame.py
+++ b/src/silx/gui/plot/backends/glutils/GLPlotFrame.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2014-2022 European Synchrotron Radiation Facility
+# Copyright (c) 2014-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -25,6 +25,8 @@
This modules provides the rendering of plot titles, axes and grid.
"""
+from __future__ import annotations
+
__authors__ = ["T. Vincent"]
__license__ = "MIT"
__date__ = "03/04/2017"
@@ -44,12 +46,19 @@ from collections import namedtuple
import numpy
+from .... import qt
from ...._glutils import gl, Program
+from ....utils.matplotlib import DefaultTickFormatter
from ..._utils import checkAxisLimits, FLOAT32_MINPOS
from .GLSupport import mat4Ortho
from .GLText import Text2D, CENTER, BOTTOM, TOP, LEFT, RIGHT, ROTATE_270
from ..._utils.ticklayout import niceNumbersAdaptative, niceNumbersForLog10
-from ..._utils.dtime_ticklayout import calcTicksAdaptive, bestFormatString
+from ..._utils.dtime_ticklayout import (
+ DtUnit,
+ bestUnit,
+ calcTicksAdaptive,
+ formatDatetimes,
+)
from ..._utils.dtime_ticklayout import timestamp
_logger = logging.getLogger(__name__)
@@ -57,36 +66,52 @@ _logger = logging.getLogger(__name__)
# PlotAxis ####################################################################
+
class PlotAxis(object):
"""Represents a 1D axis of the plot.
This class is intended to be used with :class:`GLPlotFrame`.
"""
- def __init__(self, plotFrame,
- tickLength=(0., 0.),
- foregroundColor=(0., 0., 0., 1.0),
- labelAlign=CENTER, labelVAlign=CENTER,
- titleAlign=CENTER, titleVAlign=CENTER,
- titleRotate=0, titleOffset=(0., 0.)):
+ def __init__(
+ self,
+ plotFrame,
+ tickLength=(0.0, 0.0),
+ foregroundColor=(0.0, 0.0, 0.0, 1.0),
+ labelAlign=CENTER,
+ labelVAlign=CENTER,
+ titleAlign=CENTER,
+ titleVAlign=CENTER,
+ orderOffsetAlign=CENTER,
+ orderOffsetVAlign=CENTER,
+ titleRotate=0,
+ titleOffset=(0.0, 0.0),
+ font: qt.QFont | None = None,
+ ):
+ self._tickFormatter = DefaultTickFormatter()
self._ticks = None
+ self._orderAndOffsetText = ""
self._plotFrameRef = weakref.ref(plotFrame)
self._isDateTime = False
self._timeZone = None
self._isLog = False
- self._dataRange = 1., 100.
- self._displayCoords = (0., 0.), (1., 0.)
- self._title = ''
+ self._dataRange = 1.0, 100.0
+ self._displayCoords = (0.0, 0.0), (1.0, 0.0)
+ self._title = ""
self._tickLength = tickLength
self._foregroundColor = foregroundColor
self._labelAlign = labelAlign
self._labelVAlign = labelVAlign
+ self._orderOffetAnchor = (1.0, 0.0)
+ self._orderOffsetAlign = orderOffsetAlign
+ self._orderOffsetVAlign = orderOffsetVAlign
self._titleAlign = titleAlign
self._titleVAlign = titleVAlign
self._titleRotate = titleRotate
self._titleOffset = titleOffset
+ self._font = font
@property
def dataRange(self):
@@ -94,6 +119,12 @@ class PlotAxis(object):
of 2 floats: (min, max)."""
return self._dataRange
+ @property
+ def font(self) -> qt.QFont:
+ if self._font is None:
+ return qt.QApplication.instance().font()
+ return self._font
+
@dataRange.setter
def dataRange(self, dataRange):
assert len(dataRange) == 2
@@ -161,7 +192,13 @@ class PlotAxis(object):
def devicePixelRatio(self):
"""Returns the ratio between qt pixels and device pixels."""
plotFrame = self._plotFrameRef()
- return plotFrame.devicePixelRatio if plotFrame is not None else 1.
+ return plotFrame.devicePixelRatio if plotFrame is not None else 1.0
+
+ @property
+ def dotsPerInch(self):
+ """Returns the screen DPI"""
+ plotFrame = self._plotFrameRef()
+ return plotFrame.dotsPerInch if plotFrame is not None else 92
@property
def title(self):
@@ -175,6 +212,17 @@ class PlotAxis(object):
self._dirtyPlotFrame()
@property
+ def orderOffetAnchor(self) -> tuple[float, float]:
+ """Anchor position for the tick order&offset text"""
+ return self._orderOffetAnchor
+
+ @orderOffetAnchor.setter
+ def orderOffetAnchor(self, position: tuple[float, float]):
+ if position != self._orderOffetAnchor:
+ self._orderOffetAnchor = position
+ self._dirtyTicks()
+
+ @property
def titleOffset(self):
"""Title offset in pixels (x: int, y: int)"""
return self._titleOffset
@@ -193,8 +241,9 @@ class PlotAxis(object):
@foregroundColor.setter
def foregroundColor(self, color):
"""Color used for frame and labels"""
- assert len(color) == 4, \
- "foregroundColor must have length 4, got {}".format(len(self._foregroundColor))
+ assert len(color) == 4, "foregroundColor must have length 4, got {}".format(
+ len(self._foregroundColor)
+ )
if self._foregroundColor != color:
self._foregroundColor = color
self._dirtyTicks()
@@ -213,7 +262,6 @@ class PlotAxis(object):
"""
vertices = list(self.displayCoords) # Add start and end points
labels = []
- tickLabelsSize = [0., 0.]
xTickLength, yTickLength = self._tickLength
xTickLength *= self.devicePixelRatio
@@ -222,27 +270,24 @@ class PlotAxis(object):
if text is None:
tickScale = 0.5
else:
- tickScale = 1.
-
- label = Text2D(text=text,
- color=self._foregroundColor,
- x=xPixel - xTickLength,
- y=yPixel - yTickLength,
- align=self._labelAlign,
- valign=self._labelVAlign,
- devicePixelRatio=self.devicePixelRatio)
-
- width, height = label.size
- if width > tickLabelsSize[0]:
- tickLabelsSize[0] = width
- if height > tickLabelsSize[1]:
- tickLabelsSize[1] = height
-
+ tickScale = 1.0
+
+ label = Text2D(
+ text=text,
+ font=self.font,
+ color=self._foregroundColor,
+ x=xPixel - xTickLength,
+ y=yPixel - yTickLength,
+ align=self._labelAlign,
+ valign=self._labelVAlign,
+ devicePixelRatio=self.devicePixelRatio,
+ )
labels.append(label)
vertices.append((xPixel, yPixel))
- vertices.append((xPixel + tickScale * xTickLength,
- yPixel + tickScale * yTickLength))
+ vertices.append(
+ (xPixel + tickScale * xTickLength, yPixel + tickScale * yTickLength)
+ )
(x0, y0), (x1, y1) = self.displayCoords
xAxisCenter = 0.5 * (x0 + x1)
@@ -257,16 +302,33 @@ class PlotAxis(object):
# yOffset = -tickLabelsSize[1] * yTickLength / tickNorm
# yOffset -= 3 * yTickLength
- axisTitle = Text2D(text=self.title,
- color=self._foregroundColor,
- x=xAxisCenter + xOffset,
- y=yAxisCenter + yOffset,
- align=self._titleAlign,
- valign=self._titleVAlign,
- rotate=self._titleRotate,
- devicePixelRatio=self.devicePixelRatio)
+ axisTitle = Text2D(
+ text=self.title,
+ font=self.font,
+ color=self._foregroundColor,
+ x=xAxisCenter + xOffset,
+ y=yAxisCenter + yOffset,
+ align=self._titleAlign,
+ valign=self._titleVAlign,
+ rotate=self._titleRotate,
+ devicePixelRatio=self.devicePixelRatio,
+ )
labels.append(axisTitle)
+ if self._orderAndOffsetText:
+ xOrderOffset, yOrderOffet = self.orderOffetAnchor
+ labels.append(
+ Text2D(
+ text=self._orderAndOffsetText,
+ font=self.font,
+ color=self._foregroundColor,
+ x=xOrderOffset,
+ y=yOrderOffet,
+ align=self._orderOffsetAlign,
+ valign=self._orderOffsetVAlign,
+ devicePixelRatio=self.devicePixelRatio,
+ )
+ )
return vertices, labels
def _dirtyPlotFrame(self):
@@ -291,19 +353,19 @@ class PlotAxis(object):
"""Generator of ticks as tuples:
((x, y) in display, dataPos, textLabel).
"""
+ self._orderAndOffsetText = ""
+
dataMin, dataMax = self.dataRange
- if self.isLog and dataMin <= 0.:
- _logger.warning(
- 'Getting ticks while isLog=True and dataRange[0]<=0.')
- dataMin = 1.
+ if self.isLog and dataMin <= 0.0:
+ _logger.warning("Getting ticks while isLog=True and dataRange[0]<=0.")
+ dataMin = 1.0
if dataMax < dataMin:
- dataMax = 1.
+ dataMax = 1.0
if dataMin != dataMax: # data range is not null
(x0, y0), (x1, y1) = self.displayCoords
if self.isLog:
-
if self.isTimeSeries:
_logger.warning("Time series not implemented for log-scale")
@@ -315,16 +377,16 @@ class PlotAxis(object):
for logPos in self._frange(tickMin, tickMax, step):
if logMin <= logPos <= logMax:
- dataPos = 10 ** logPos
+ dataPos = 10**logPos
xPixel = x0 + (logPos - logMin) * xScale
yPixel = y0 + (logPos - logMin) * yScale
- text = '1e%+03d' % logPos
+ text = "1e%+03d" % logPos
yield ((xPixel, yPixel), dataPos, text)
if step == 1:
ticks = list(self._frange(tickMin, tickMax, step))[:-1]
for logPos in ticks:
- dataOrigPos = 10 ** logPos
+ dataOrigPos = 10**logPos
for index in range(2, 10):
dataPos = dataOrigPos * index
if dataMin <= dataPos <= dataMax:
@@ -337,26 +399,34 @@ class PlotAxis(object):
xScale = (x1 - x0) / (dataMax - dataMin)
yScale = (y1 - y0) / (dataMax - dataMin)
- nbPixels = math.sqrt(pow(x1 - x0, 2) + pow(y1 - y0, 2)) / self.devicePixelRatio
+ nbPixels = (
+ math.sqrt(pow(x1 - x0, 2) + pow(y1 - y0, 2)) / self.devicePixelRatio
+ )
# Density of 1.3 label per 92 pixels
# i.e., 1.3 label per inch on a 92 dpi screen
- tickDensity = 1.3 / 92
+ tickDensity = 1.3 * self.devicePixelRatio / self.dotsPerInch
if not self.isTimeSeries:
- tickMin, tickMax, step, nbFrac = niceNumbersAdaptative(
- dataMin, dataMax, nbPixels, tickDensity)
-
- for dataPos in self._frange(tickMin, tickMax, step):
- if dataMin <= dataPos <= dataMax:
- xPixel = x0 + (dataPos - dataMin) * xScale
- yPixel = y0 + (dataPos - dataMin) * yScale
-
- if nbFrac == 0:
- text = '%g' % dataPos
- else:
- text = ('%.' + str(nbFrac) + 'f') % dataPos
- yield ((xPixel, yPixel), dataPos, text)
+ tickMin, tickMax, step, _ = niceNumbersAdaptative(
+ dataMin, dataMax, nbPixels, tickDensity
+ )
+
+ visibleTickPositions = [
+ pos
+ for pos in self._frange(tickMin, tickMax, step)
+ if dataMin <= pos <= dataMax
+ ]
+ self._tickFormatter.axis.set_view_interval(dataMin, dataMax)
+ self._tickFormatter.axis.set_data_interval(dataMin, dataMax)
+ texts = self._tickFormatter.format_ticks(visibleTickPositions)
+ self._orderAndOffsetText = self._tickFormatter.get_offset()
+
+ for dataPos, text in zip(visibleTickPositions, texts):
+ xPixel = x0 + (dataPos - dataMin) * xScale
+ yPixel = y0 + (dataPos - dataMin) * yScale
+ yield ((xPixel, yPixel), dataPos, text)
+
else:
# Time series
try:
@@ -366,24 +436,30 @@ class PlotAxis(object):
_logger.warning("Data range cannot be displayed with time axis")
return # Range is out of bound of the datetime
- tickDateTimes, spacing, unit = calcTicksAdaptive(
- dtMin, dtMax, nbPixels, tickDensity)
+ if bestUnit(
+ (dtMax - dtMin).total_seconds() == DtUnit.MICRO_SECONDS
+ ):
+ # Special case for micro seconds: Reduce tick density
+ tickDensity = 1.0 * self.devicePixelRatio / self.dotsPerInch
- for tickDateTime in tickDateTimes:
- if dtMin <= tickDateTime <= dtMax:
-
- dataPos = timestamp(tickDateTime)
- xPixel = x0 + (dataPos - dataMin) * xScale
- yPixel = y0 + (dataPos - dataMin) * yScale
-
- fmtStr = bestFormatString(spacing, unit)
- text = tickDateTime.strftime(fmtStr)
-
- yield ((xPixel, yPixel), dataPos, text)
+ tickDateTimes, spacing, unit = calcTicksAdaptive(
+ dtMin, dtMax, nbPixels, tickDensity
+ )
+ visibleDatetimes = tuple(
+ dt for dt in tickDateTimes if dtMin <= dt <= dtMax
+ )
+ ticks = formatDatetimes(visibleDatetimes, spacing, unit)
+
+ for tickDateTime, text in ticks.items():
+ dataPos = timestamp(tickDateTime)
+ xPixel = x0 + (dataPos - dataMin) * xScale
+ yPixel = y0 + (dataPos - dataMin) * yScale
+ yield ((xPixel, yPixel), dataPos, text)
# GLPlotFrame #################################################################
+
class GLPlotFrame(object):
"""Base class for rendering a 2D frame surrounded by axes."""
@@ -391,7 +467,7 @@ class GLPlotFrame(object):
_LINE_WIDTH = 1
_SHADERS = {
- 'vertex': """
+ "vertex": """
attribute vec2 position;
uniform mat4 matrix;
@@ -399,7 +475,7 @@ class GLPlotFrame(object):
gl_Position = matrix * vec4(position, 0.0, 1.0);
}
""",
- 'fragment': """
+ "fragment": """
uniform vec4 color;
uniform float tickFactor; /* = 1./tickLength or 0. for solid line */
@@ -410,15 +486,15 @@ class GLPlotFrame(object):
discard;
}
}
- """
+ """,
}
- _Margins = namedtuple('Margins', ('left', 'right', 'top', 'bottom'))
+ _Margins = namedtuple("Margins", ("left", "right", "top", "bottom"))
# Margins used when plot frame is not displayed
_NoDisplayMargins = _Margins(0, 0, 0, 0)
- def __init__(self, marginRatios, foregroundColor, gridColor):
+ def __init__(self, marginRatios, foregroundColor, gridColor, font: qt.QFont):
"""
:param List[float] marginRatios:
The ratios of margins around plot area for axis and labels.
@@ -427,6 +503,7 @@ class GLPlotFrame(object):
:type foregroundColor: tuple with RGBA values ranging from 0.0 to 1.0
:param gridColor: color used for grid lines.
:type gridColor: tuple RGBA with RGBA values ranging from 0.0 to 1.0
+ :param font: Font used by the axes label
"""
self._renderResources = None
@@ -439,10 +516,12 @@ class GLPlotFrame(object):
self.axes = [] # List of PlotAxis to be updated by subclasses
self._grid = False
- self._size = 0., 0.
- self._title = ''
+ self._size = 0.0, 0.0
+ self._title = ""
+ self._font: qt.QFont = font
- self._devicePixelRatio = 1.
+ self._devicePixelRatio = 1.0
+ self._dpi = 92
@property
def isDirty(self):
@@ -452,18 +531,19 @@ class GLPlotFrame(object):
GRID_NONE = 0
GRID_MAIN_TICKS = 1
GRID_SUB_TICKS = 2
- GRID_ALL_TICKS = (GRID_MAIN_TICKS + GRID_SUB_TICKS)
+ GRID_ALL_TICKS = GRID_MAIN_TICKS + GRID_SUB_TICKS
@property
def foregroundColor(self):
"""Color used for frame and labels"""
return self._foregroundColor
-
+
@foregroundColor.setter
def foregroundColor(self, color):
"""Color used for frame and labels"""
- assert len(color) == 4, \
- "foregroundColor must have length 4, got {}".format(len(self._foregroundColor))
+ assert len(color) == 4, "foregroundColor must have length 4, got {}".format(
+ len(self._foregroundColor)
+ )
if self._foregroundColor != color:
self._foregroundColor = color
for axis in self.axes:
@@ -474,20 +554,20 @@ class GLPlotFrame(object):
def gridColor(self):
"""Color used for frame and labels"""
return self._gridColor
-
+
@gridColor.setter
def gridColor(self, color):
"""Color used for frame and labels"""
- assert len(color) == 4, \
- "gridColor must have length 4, got {}".format(len(self._gridColor))
+ assert len(color) == 4, "gridColor must have length 4, got {}".format(
+ len(self._gridColor)
+ )
if self._gridColor != color:
self._gridColor = color
self._dirty()
@property
def marginRatios(self):
- """Plot margin ratios: (left, top, right, bottom) as 4 float in [0, 1].
- """
+ """Plot margin ratios: (left, top, right, bottom) as 4 float in [0, 1]."""
return self.__marginRatios
@marginRatios.setter
@@ -495,9 +575,9 @@ class GLPlotFrame(object):
ratios = tuple(float(v) for v in ratios)
assert len(ratios) == 4
for value in ratios:
- assert 0. <= value <= 1.
- assert ratios[0] + ratios[2] < 1.
- assert ratios[1] + ratios[3] < 1.
+ assert 0.0 <= value <= 1.0
+ assert ratios[0] + ratios[2] < 1.0
+ assert ratios[1] + ratios[3] < 1.0
if self.__marginRatios != ratios:
self.__marginRatios = ratios
@@ -511,10 +591,11 @@ class GLPlotFrame(object):
width, height = self.size
left, top, right, bottom = self.marginRatios
self.__marginsCache = self._Margins(
- left=int(left*width),
- right=int(right*width),
- top=int(top*height),
- bottom=int(bottom*height))
+ left=int(left * width),
+ right=int(right * width),
+ top=int(top * height),
+ bottom=int(bottom * height),
+ )
return self.__marginsCache
@property
@@ -528,6 +609,16 @@ class GLPlotFrame(object):
self._dirty()
@property
+ def dotsPerInch(self):
+ return self._dpi
+
+ @dotsPerInch.setter
+ def dotsPerInch(self, dpi):
+ if dpi != self._dpi:
+ self._dpi = dpi
+ self._dirty()
+
+ @property
def grid(self):
"""Grid display mode:
- 0: No grid.
@@ -538,8 +629,12 @@ class GLPlotFrame(object):
@grid.setter
def grid(self, grid):
- assert grid in (self.GRID_NONE, self.GRID_MAIN_TICKS,
- self.GRID_SUB_TICKS, self.GRID_ALL_TICKS)
+ assert grid in (
+ self.GRID_NONE,
+ self.GRID_MAIN_TICKS,
+ self.GRID_SUB_TICKS,
+ self.GRID_ALL_TICKS,
+ )
if grid != self._grid:
self._grid = grid
self._dirty()
@@ -595,16 +690,22 @@ class GLPlotFrame(object):
return []
elif self._grid == self.GRID_MAIN_TICKS:
+
def test(text):
return text is not None
+
elif self._grid == self.GRID_SUB_TICKS:
+
def test(text):
return text is None
+
elif self._grid == self.GRID_ALL_TICKS:
+
def test(_):
return True
+
else:
- logging.warning('Wrong grid mode: %d' % self._grid)
+ logging.warning("Wrong grid mode: %d" % self._grid)
return []
return self._buildGridVerticesWithTest(test)
@@ -626,25 +727,27 @@ class GLPlotFrame(object):
vertices = numpy.array(vertices, dtype=numpy.float32)
# Add main title
- xTitle = (self.size[0] + self.margins.left -
- self.margins.right) // 2
+ xTitle = (self.size[0] + self.margins.left - self.margins.right) // 2
yTitle = self.margins.top - self._TICK_LENGTH_IN_PIXELS
- labels.append(Text2D(text=self.title,
- color=self._foregroundColor,
- x=xTitle,
- y=yTitle,
- align=CENTER,
- valign=BOTTOM,
- devicePixelRatio=self.devicePixelRatio))
+ labels.append(
+ Text2D(
+ text=self.title,
+ font=self._font,
+ color=self._foregroundColor,
+ x=xTitle,
+ y=yTitle,
+ align=CENTER,
+ valign=BOTTOM,
+ devicePixelRatio=self.devicePixelRatio,
+ )
+ )
# grid
- gridVertices = numpy.array(self._buildGridVertices(),
- dtype=numpy.float32)
+ gridVertices = numpy.array(self._buildGridVertices(), dtype=numpy.float32)
self._renderResources = (vertices, gridVertices, labels)
- _program = Program(
- _SHADERS['vertex'], _SHADERS['fragment'], attrib0='position')
+ _program = Program(_SHADERS["vertex"], _SHADERS["fragment"], attrib0="position")
def render(self):
if self.margins == self._NoDisplayMargins:
@@ -664,22 +767,21 @@ class GLPlotFrame(object):
gl.glLineWidth(self._LINE_WIDTH)
- gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE,
- matProj.astype(numpy.float32))
- gl.glUniform4f(prog.uniforms['color'], *self._foregroundColor)
- gl.glUniform1f(prog.uniforms['tickFactor'], 0.)
+ gl.glUniformMatrix4fv(
+ prog.uniforms["matrix"], 1, gl.GL_TRUE, matProj.astype(numpy.float32)
+ )
+ gl.glUniform4f(prog.uniforms["color"], *self._foregroundColor)
+ gl.glUniform1f(prog.uniforms["tickFactor"], 0.0)
- gl.glEnableVertexAttribArray(prog.attributes['position'])
- gl.glVertexAttribPointer(prog.attributes['position'],
- 2,
- gl.GL_FLOAT,
- gl.GL_FALSE,
- 0, vertices)
+ gl.glEnableVertexAttribArray(prog.attributes["position"])
+ gl.glVertexAttribPointer(
+ prog.attributes["position"], 2, gl.GL_FLOAT, gl.GL_FALSE, 0, vertices
+ )
gl.glDrawArrays(gl.GL_LINES, 0, len(vertices))
for label in labels:
- label.render(matProj)
+ label.render(matProj, self.dotsPerInch)
def renderGrid(self):
if self._grid == self.GRID_NONE:
@@ -698,25 +800,25 @@ class GLPlotFrame(object):
prog.use()
gl.glLineWidth(self._LINE_WIDTH)
- gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE,
- matProj.astype(numpy.float32))
- gl.glUniform4f(prog.uniforms['color'], *self._gridColor)
- gl.glUniform1f(prog.uniforms['tickFactor'], 0.) # 1/2.) # 1/tickLen
-
- gl.glEnableVertexAttribArray(prog.attributes['position'])
- gl.glVertexAttribPointer(prog.attributes['position'],
- 2,
- gl.GL_FLOAT,
- gl.GL_FALSE,
- 0, gridVertices)
+ gl.glUniformMatrix4fv(
+ prog.uniforms["matrix"], 1, gl.GL_TRUE, matProj.astype(numpy.float32)
+ )
+ gl.glUniform4f(prog.uniforms["color"], *self._gridColor)
+ gl.glUniform1f(prog.uniforms["tickFactor"], 0.0) # 1/2.) # 1/tickLen
+
+ gl.glEnableVertexAttribArray(prog.attributes["position"])
+ gl.glVertexAttribPointer(
+ prog.attributes["position"], 2, gl.GL_FLOAT, gl.GL_FALSE, 0, gridVertices
+ )
gl.glDrawArrays(gl.GL_LINES, 0, len(gridVertices))
# GLPlotFrame2D ###############################################################
+
class GLPlotFrame2D(GLPlotFrame):
- def __init__(self, marginRatios, foregroundColor, gridColor):
+ def __init__(self, marginRatios, foregroundColor, gridColor, font: qt.QFont):
"""
:param List[float] marginRatios:
The ratios of margins around plot area for axis and labels.
@@ -725,38 +827,66 @@ class GLPlotFrame2D(GLPlotFrame):
:type foregroundColor: tuple with RGBA values ranging from 0.0 to 1.0
:param gridColor: color used for grid lines.
:type gridColor: tuple RGBA with RGBA values ranging from 0.0 to 1.0
-
+ :param font: Font used by the axes label
"""
- super(GLPlotFrame2D, self).__init__(marginRatios, foregroundColor, gridColor)
- self.axes.append(PlotAxis(self,
- tickLength=(0., -5.),
- foregroundColor=self._foregroundColor,
- labelAlign=CENTER, labelVAlign=TOP,
- titleAlign=CENTER, titleVAlign=TOP,
- titleRotate=0))
+ super(GLPlotFrame2D, self).__init__(
+ marginRatios, foregroundColor, gridColor, font
+ )
+ self._font = font
+
+ self.axes.append(
+ PlotAxis(
+ self,
+ tickLength=(0.0, -5.0),
+ foregroundColor=self._foregroundColor,
+ labelAlign=CENTER,
+ labelVAlign=TOP,
+ orderOffsetAlign=RIGHT,
+ orderOffsetVAlign=TOP,
+ titleAlign=CENTER,
+ titleVAlign=TOP,
+ titleRotate=0,
+ font=self._font,
+ )
+ )
self._x2AxisCoords = ()
- self.axes.append(PlotAxis(self,
- tickLength=(5., 0.),
- foregroundColor=self._foregroundColor,
- labelAlign=RIGHT, labelVAlign=CENTER,
- titleAlign=CENTER, titleVAlign=BOTTOM,
- titleRotate=ROTATE_270))
+ self.axes.append(
+ PlotAxis(
+ self,
+ tickLength=(5.0, 0.0),
+ foregroundColor=self._foregroundColor,
+ labelAlign=RIGHT,
+ labelVAlign=CENTER,
+ orderOffsetAlign=LEFT,
+ orderOffsetVAlign=BOTTOM,
+ titleAlign=CENTER,
+ titleVAlign=BOTTOM,
+ titleRotate=ROTATE_270,
+ font=self._font,
+ )
+ )
- self._y2Axis = PlotAxis(self,
- tickLength=(-5., 0.),
- foregroundColor=self._foregroundColor,
- labelAlign=LEFT, labelVAlign=CENTER,
- titleAlign=CENTER, titleVAlign=TOP,
- titleRotate=ROTATE_270)
+ self._y2Axis = PlotAxis(
+ self,
+ tickLength=(-5.0, 0.0),
+ foregroundColor=self._foregroundColor,
+ labelAlign=LEFT,
+ labelVAlign=CENTER,
+ orderOffsetAlign=RIGHT,
+ orderOffsetVAlign=BOTTOM,
+ titleAlign=CENTER,
+ titleVAlign=TOP,
+ titleRotate=ROTATE_270,
+ font=self._font,
+ )
self._isYAxisInverted = False
- self._dataRanges = {
- 'x': (1., 100.), 'y': (1., 100.), 'y2': (1., 100.)}
+ self._dataRanges = {"x": (1.0, 100.0), "y": (1.0, 100.0), "y2": (1.0, 100.0)}
- self._baseVectors = (1., 0.), (0., 1.)
+ self._baseVectors = (1.0, 0.0), (0.0, 1.0)
self._transformedDataRanges = None
self._transformedDataProjMat = None
@@ -771,10 +901,12 @@ class GLPlotFrame2D(GLPlotFrame):
@property
def isDirty(self):
"""True if it need to refresh graphic rendering, False otherwise."""
- return (super(GLPlotFrame2D, self).isDirty or
- self._transformedDataRanges is None or
- self._transformedDataProjMat is None or
- self._transformedDataY2ProjMat is None)
+ return (
+ super(GLPlotFrame2D, self).isDirty
+ or self._transformedDataRanges is None
+ or self._transformedDataProjMat is None
+ or self._transformedDataY2ProjMat is None
+ )
@property
def xAxis(self):
@@ -815,7 +947,7 @@ class GLPlotFrame2D(GLPlotFrame):
self._isYAxisInverted = value
self._dirty()
- DEFAULT_BASE_VECTORS = (1., 0.), (0., 1.)
+ DEFAULT_BASE_VECTORS = (1.0, 0.0), (0.0, 1.0)
"""Values of baseVectors for orthogonal axes."""
@property
@@ -835,10 +967,9 @@ class GLPlotFrame2D(GLPlotFrame):
(xx, xy), (yx, yy) = baseVectors
vectors = (float(xx), float(xy)), (float(yx), float(yy))
- det = (vectors[0][0] * vectors[1][1] - vectors[1][0] * vectors[0][1])
- if det == 0.:
- raise ValueError("Singular matrix for base vectors: " +
- str(vectors))
+ det = vectors[0][0] * vectors[1][1] - vectors[1][0] * vectors[0][1]
+ if det == 0.0:
+ raise ValueError("Singular matrix for base vectors: " + str(vectors))
if vectors != self._baseVectors:
self._baseVectors = vectors
@@ -870,9 +1001,9 @@ class GLPlotFrame2D(GLPlotFrame):
Type: ((xMin, xMax), (yMin, yMax), (y2Min, y2Max))
"""
- return self._DataRanges(self._dataRanges['x'],
- self._dataRanges['y'],
- self._dataRanges['y2'])
+ return self._DataRanges(
+ self._dataRanges["x"], self._dataRanges["y"], self._dataRanges["y2"]
+ )
def setDataRanges(self, x=None, y=None, y2=None):
"""Set data range over each axes.
@@ -885,22 +1016,25 @@ class GLPlotFrame2D(GLPlotFrame):
:param y2: (min, max) data range over Y2 axis
"""
if x is not None:
- self._dataRanges['x'] = checkAxisLimits(
- x[0], x[1], self.xAxis.isLog, name='x')
+ self._dataRanges["x"] = checkAxisLimits(
+ x[0], x[1], self.xAxis.isLog, name="x"
+ )
if y is not None:
- self._dataRanges['y'] = checkAxisLimits(
- y[0], y[1], self.yAxis.isLog, name='y')
+ self._dataRanges["y"] = checkAxisLimits(
+ y[0], y[1], self.yAxis.isLog, name="y"
+ )
if y2 is not None:
- self._dataRanges['y2'] = checkAxisLimits(
- y2[0], y2[1], self.y2Axis.isLog, name='y2')
+ self._dataRanges["y2"] = checkAxisLimits(
+ y2[0], y2[1], self.y2Axis.isLog, name="y2"
+ )
- self.xAxis.dataRange = self._dataRanges['x']
- self.yAxis.dataRange = self._dataRanges['y']
- self.y2Axis.dataRange = self._dataRanges['y2']
+ self.xAxis.dataRange = self._dataRanges["x"]
+ self.yAxis.dataRange = self._dataRanges["y"]
+ self.y2Axis.dataRange = self._dataRanges["y2"]
- _DataRanges = namedtuple('dataRanges', ('x', 'y', 'y2'))
+ _DataRanges = namedtuple("dataRanges", ("x", "y", "y2"))
@property
def transformedDataRanges(self):
@@ -916,39 +1050,40 @@ class GLPlotFrame2D(GLPlotFrame):
try:
xMin = math.log10(xMin)
except ValueError:
- _logger.info('xMin: warning log10(%f)', xMin)
- xMin = 0.
+ _logger.info("xMin: warning log10(%f)", xMin)
+ xMin = 0.0
try:
xMax = math.log10(xMax)
except ValueError:
- _logger.info('xMax: warning log10(%f)', xMax)
- xMax = 0.
+ _logger.info("xMax: warning log10(%f)", xMax)
+ xMax = 0.0
if self.yAxis.isLog:
try:
yMin = math.log10(yMin)
except ValueError:
- _logger.info('yMin: warning log10(%f)', yMin)
- yMin = 0.
+ _logger.info("yMin: warning log10(%f)", yMin)
+ yMin = 0.0
try:
yMax = math.log10(yMax)
except ValueError:
- _logger.info('yMax: warning log10(%f)', yMax)
- yMax = 0.
+ _logger.info("yMax: warning log10(%f)", yMax)
+ yMax = 0.0
try:
y2Min = math.log10(y2Min)
except ValueError:
- _logger.info('yMin: warning log10(%f)', y2Min)
- y2Min = 0.
+ _logger.info("yMin: warning log10(%f)", y2Min)
+ y2Min = 0.0
try:
y2Max = math.log10(y2Max)
except ValueError:
- _logger.info('yMax: warning log10(%f)', y2Max)
- y2Max = 0.
+ _logger.info("yMax: warning log10(%f)", y2Max)
+ y2Max = 0.0
self._transformedDataRanges = self._DataRanges(
- (xMin, xMax), (yMin, yMax), (y2Min, y2Max))
+ (xMin, xMax), (yMin, yMax), (y2Min, y2Max)
+ )
return self._transformedDataRanges
@@ -991,10 +1126,9 @@ class GLPlotFrame2D(GLPlotFrame):
@staticmethod
def __applyLog(
- data: Union[float, numpy.ndarray],
- isLog: bool
+ data: Union[float, numpy.ndarray], isLog: bool
) -> Optional[Union[float, numpy.ndarray]]:
- """Apply log to data filtering out """
+ """Apply log to data filtering out"""
if not isLog:
return data
@@ -1006,13 +1140,12 @@ class GLPlotFrame2D(GLPlotFrame):
data = numpy.array(data, copy=True, dtype=numpy.float64)
data[isBelowMin] = numpy.nan
- with numpy.errstate(divide='ignore'):
+ with numpy.errstate(divide="ignore"):
return numpy.log10(data)
- def dataToPixel(self, x, y, axis='left'):
- """Convert data coordinate to widget pixel coordinate.
- """
- assert axis in ('left', 'right')
+ def dataToPixel(self, x, y, axis="left"):
+ """Convert data coordinate to widget pixel coordinate."""
+ assert axis in ("left", "right")
trBounds = self.transformedDataRanges
@@ -1034,13 +1167,12 @@ class GLPlotFrame2D(GLPlotFrame):
plotWidth, plotHeight = self.plotSize
- xPixel = (self.margins.left +
- plotWidth * (xDataTr - trBounds.x[0]) /
- (trBounds.x[1] - trBounds.x[0]))
+ xPixel = self.margins.left + plotWidth * (xDataTr - trBounds.x[0]) / (
+ trBounds.x[1] - trBounds.x[0]
+ )
usedAxis = trBounds.y if axis == "left" else trBounds.y2
- yOffset = (plotHeight * (yDataTr - usedAxis[0]) /
- (usedAxis[1] - usedAxis[0]))
+ yOffset = plotHeight * (yDataTr - usedAxis[0]) / (usedAxis[1] - usedAxis[0])
if self.isYAxisInverted:
yPixel = self.margins.top + yOffset
@@ -1048,8 +1180,12 @@ class GLPlotFrame2D(GLPlotFrame):
yPixel = self.size[1] - self.margins.bottom - yOffset
return (
- int(xPixel) if isinstance(xPixel, numbers.Real) else xPixel.astype(numpy.int64),
- int(yPixel) if isinstance(yPixel, numbers.Real) else yPixel.astype(numpy.int64),
+ int(xPixel)
+ if isinstance(xPixel, numbers.Real)
+ else xPixel.astype(numpy.int64),
+ int(yPixel)
+ if isinstance(yPixel, numbers.Real)
+ else yPixel.astype(numpy.int64),
)
def pixelToData(self, x, y, axis="left"):
@@ -1105,8 +1241,7 @@ class GLPlotFrame2D(GLPlotFrame):
if axis == self.xAxis:
vertices.append((xPixel, self.margins.top))
elif axis == self.yAxis:
- vertices.append((self.size[0] - self.margins.right,
- yPixel))
+ vertices.append((self.size[0] - self.margins.right, yPixel))
else: # axis == self.y2Axis
vertices.append((self.margins.left, yPixel))
@@ -1115,28 +1250,33 @@ class GLPlotFrame2D(GLPlotFrame):
plotLeft, plotTop = self.plotOrigin
plotWidth, plotHeight = self.plotSize
- corners = [(plotLeft, plotTop),
- (plotLeft, plotTop + plotHeight),
- (plotLeft + plotWidth, plotTop + plotHeight),
- (plotLeft + plotWidth, plotTop)]
+ corners = [
+ (plotLeft, plotTop),
+ (plotLeft, plotTop + plotHeight),
+ (plotLeft + plotWidth, plotTop + plotHeight),
+ (plotLeft + plotWidth, plotTop),
+ ]
for axis in self.axes:
if axis == self.xAxis:
- cornersInData = numpy.array([
- self.pixelToData(x, y) for (x, y) in corners])
- borders = ((cornersInData[0], cornersInData[3]), # top
- (cornersInData[1], cornersInData[0]), # left
- (cornersInData[3], cornersInData[2])) # right
+ cornersInData = numpy.array(
+ [self.pixelToData(x, y) for (x, y) in corners]
+ )
+ borders = (
+ (cornersInData[0], cornersInData[3]), # top
+ (cornersInData[1], cornersInData[0]), # left
+ (cornersInData[3], cornersInData[2]),
+ ) # right
for (xPixel, yPixel), data, text in axis.ticks:
if test(text):
for (x0, y0), (x1, y1) in borders:
if min(x0, x1) <= data < max(x0, x1):
- yIntersect = (data - x0) * \
- (y1 - y0) / (x1 - x0) + y0
+ yIntersect = (data - x0) * (y1 - y0) / (
+ x1 - x0
+ ) + y0
- pixelPos = self.dataToPixel(
- data, yIntersect)
+ pixelPos = self.dataToPixel(data, yIntersect)
if pixelPos is not None:
vertices.append((xPixel, yPixel))
vertices.append(pixelPos)
@@ -1144,32 +1284,38 @@ class GLPlotFrame2D(GLPlotFrame):
else: # y or y2 axes
if axis == self.yAxis:
- axis_name = 'left'
- cornersInData = numpy.array([
- self.pixelToData(x, y) for (x, y) in corners])
+ axis_name = "left"
+ cornersInData = numpy.array(
+ [self.pixelToData(x, y) for (x, y) in corners]
+ )
borders = (
(cornersInData[3], cornersInData[2]), # right
(cornersInData[0], cornersInData[3]), # top
- (cornersInData[2], cornersInData[1])) # bottom
+ (cornersInData[2], cornersInData[1]),
+ ) # bottom
else: # axis == self.y2Axis
- axis_name = 'right'
- corners = numpy.array([self.pixelToData(
- x, y, axis='right') for (x, y) in corners])
+ axis_name = "right"
+ corners = numpy.array(
+ [self.pixelToData(x, y, axis="right") for (x, y) in corners]
+ )
borders = (
(cornersInData[1], cornersInData[0]), # left
(cornersInData[0], cornersInData[3]), # top
- (cornersInData[2], cornersInData[1])) # bottom
+ (cornersInData[2], cornersInData[1]),
+ ) # bottom
for (xPixel, yPixel), data, text in axis.ticks:
if test(text):
for (x0, y0), (x1, y1) in borders:
if min(y0, y1) <= data < max(y0, y1):
- xIntersect = (data - y0) * \
- (x1 - x0) / (y1 - y0) + x0
+ xIntersect = (data - y0) * (x1 - x0) / (
+ y1 - y0
+ ) + x0
pixelPos = self.dataToPixel(
- xIntersect, data, axis=axis_name)
+ xIntersect, data, axis=axis_name
+ )
if pixelPos is not None:
vertices.append((xPixel, yPixel))
vertices.append(pixelPos)
@@ -1180,26 +1326,47 @@ class GLPlotFrame2D(GLPlotFrame):
def _buildVerticesAndLabels(self):
width, height = self.size
- xCoords = (self.margins.left - 0.5,
- width - self.margins.right + 0.5)
- yCoords = (height - self.margins.bottom + 0.5,
- self.margins.top - 0.5)
+ xCoords = (self.margins.left - 0.5, width - self.margins.right + 0.5)
+ yCoords = (height - self.margins.bottom + 0.5, self.margins.top - 0.5)
- self.axes[0].displayCoords = ((xCoords[0], yCoords[0]),
- (xCoords[1], yCoords[0]))
+ self.axes[0].displayCoords = (
+ (xCoords[0], yCoords[0]),
+ (xCoords[1], yCoords[0]),
+ )
- self._x2AxisCoords = ((xCoords[0], yCoords[1]),
- (xCoords[1], yCoords[1]))
+ self._x2AxisCoords = ((xCoords[0], yCoords[1]), (xCoords[1], yCoords[1]))
+
+ # Set order&offset anchor **before** handling Y axis inversion
+ fontPixelSize = self._font.pixelSize()
+ if fontPixelSize == -1:
+ fontPixelSize = self._font.pointSizeF() / 72.0 * self.dotsPerInch
+
+ self.axes[0].orderOffetAnchor = (
+ xCoords[1],
+ yCoords[0] + fontPixelSize * 1.2,
+ )
+ self.axes[1].orderOffetAnchor = (
+ xCoords[0],
+ yCoords[1] - 4 * self.devicePixelRatio,
+ )
+ self._y2Axis.orderOffetAnchor = (
+ xCoords[1],
+ yCoords[1] - 4 * self.devicePixelRatio,
+ )
if self.isYAxisInverted:
# Y axes are inverted, axes coordinates are inverted
yCoords = yCoords[1], yCoords[0]
- self.axes[1].displayCoords = ((xCoords[0], yCoords[0]),
- (xCoords[0], yCoords[1]))
+ self.axes[1].displayCoords = (
+ (xCoords[0], yCoords[0]),
+ (xCoords[0], yCoords[1]),
+ )
- self._y2Axis.displayCoords = ((xCoords[1], yCoords[0]),
- (xCoords[1], yCoords[1]))
+ self._y2Axis.displayCoords = (
+ (xCoords[1], yCoords[0]),
+ (xCoords[1], yCoords[1]),
+ )
super(GLPlotFrame2D, self)._buildVerticesAndLabels()
@@ -1211,8 +1378,7 @@ class GLPlotFrame2D(GLPlotFrame):
if not self.isY2Axis:
extraVertices += self._y2Axis.displayCoords
- extraVertices = numpy.array(
- extraVertices, copy=False, dtype=numpy.float32)
+ extraVertices = numpy.array(extraVertices, copy=False, dtype=numpy.float32)
vertices = numpy.append(vertices, extraVertices, axis=0)
self._renderResources = (vertices, gridVertices, labels)
@@ -1225,8 +1391,9 @@ class GLPlotFrame2D(GLPlotFrame):
@foregroundColor.setter
def foregroundColor(self, color):
"""Color used for frame and labels"""
- assert len(color) == 4, \
- "foregroundColor must have length 4, got {}".format(len(self._foregroundColor))
+ assert len(color) == 4, "foregroundColor must have length 4, got {}".format(
+ len(self._foregroundColor)
+ )
if self._foregroundColor != color:
self._y2Axis.foregroundColor = color
- GLPlotFrame.foregroundColor.fset(self, color) # call parent property
+ GLPlotFrame.foregroundColor.fset(self, color) # call parent property
diff --git a/src/silx/gui/plot/backends/glutils/GLPlotImage.py b/src/silx/gui/plot/backends/glutils/GLPlotImage.py
index 8353911..0973c47 100644
--- a/src/silx/gui/plot/backends/glutils/GLPlotImage.py
+++ b/src/silx/gui/plot/backends/glutils/GLPlotImage.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2014-2021 European Synchrotron Radiation Facility
+# Copyright (c) 2014-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -33,8 +33,6 @@ __date__ = "03/04/2017"
import math
import numpy
-from silx.math.combo import min_max
-
from ...._glutils import gl, Program, Texture
from ..._utils import FLOAT32_MINPOS
from .GLSupport import mat4Translate, mat4Scale
@@ -64,29 +62,28 @@ class _GLPlotData2D(GLPlotItem):
@property
def xMin(self):
ox, sx = self.origin[0], self.scale[0]
- return ox if sx >= 0. else ox + sx * self.data.shape[1]
+ return ox if sx >= 0.0 else ox + sx * self.data.shape[1]
@property
def yMin(self):
oy, sy = self.origin[1], self.scale[1]
- return oy if sy >= 0. else oy + sy * self.data.shape[0]
+ return oy if sy >= 0.0 else oy + sy * self.data.shape[0]
@property
def xMax(self):
ox, sx = self.origin[0], self.scale[0]
- return ox + sx * self.data.shape[1] if sx >= 0. else ox
+ return ox + sx * self.data.shape[1] if sx >= 0.0 else ox
@property
def yMax(self):
oy, sy = self.origin[1], self.scale[1]
- return oy + sy * self.data.shape[0] if sy >= 0. else oy
+ return oy + sy * self.data.shape[0] if sy >= 0.0 else oy
class GLPlotColormap(_GLPlotData2D):
-
_SHADERS = {
- 'linear': {
- 'vertex': """
+ "linear": {
+ "vertex": """
#version 120
uniform mat4 matrix;
@@ -100,14 +97,14 @@ class GLPlotColormap(_GLPlotData2D):
gl_Position = matrix * vec4(position, 0.0, 1.0);
}
""",
- 'fragTransform': """
+ "fragTransform": """
vec2 textureCoords(void) {
return coords;
}
- """},
-
- 'log': {
- 'vertex': """
+ """,
+ },
+ "log": {
+ "vertex": """
#version 120
attribute vec2 position;
@@ -131,7 +128,7 @@ class GLPlotColormap(_GLPlotData2D):
gl_Position = matrix * dataPos;
}
""",
- 'fragTransform': """
+ "fragTransform": """
uniform bvec2 isLog;
uniform vec2 bounds_oneOverRange;
uniform vec2 bounds_originOverRange;
@@ -147,9 +144,9 @@ class GLPlotColormap(_GLPlotData2D):
return pos * bounds_oneOverRange - bounds_originOverRange;
// TODO texture coords in range different from [0, 1]
}
- """},
-
- 'fragment': """
+ """,
+ },
+ "fragment": """
#version 120
/* isnan declaration for compatibility with GLSL 1.20 */
@@ -209,7 +206,7 @@ class GLPlotColormap(_GLPlotData2D):
}
gl_FragColor.a *= alpha;
}
- """
+ """,
}
_DATA_TEX_UNIT = 0
@@ -223,21 +220,32 @@ class GLPlotColormap(_GLPlotData2D):
numpy.dtype(numpy.uint8): gl.GL_R8,
}
- _linearProgram = Program(_SHADERS['linear']['vertex'],
- _SHADERS['fragment'] %
- _SHADERS['linear']['fragTransform'],
- attrib0='position')
-
- _logProgram = Program(_SHADERS['log']['vertex'],
- _SHADERS['fragment'] %
- _SHADERS['log']['fragTransform'],
- attrib0='position')
-
- SUPPORTED_NORMALIZATIONS = 'linear', 'log', 'sqrt', 'gamma', 'arcsinh'
-
- def __init__(self, data, origin, scale,
- colormap, normalization='linear', gamma=0., cmapRange=None,
- alpha=1.0, nancolor=(1., 1., 1., 0.)):
+ _linearProgram = Program(
+ _SHADERS["linear"]["vertex"],
+ _SHADERS["fragment"] % _SHADERS["linear"]["fragTransform"],
+ attrib0="position",
+ )
+
+ _logProgram = Program(
+ _SHADERS["log"]["vertex"],
+ _SHADERS["fragment"] % _SHADERS["log"]["fragTransform"],
+ attrib0="position",
+ )
+
+ SUPPORTED_NORMALIZATIONS = "linear", "log", "sqrt", "gamma", "arcsinh"
+
+ def __init__(
+ self,
+ data,
+ origin,
+ scale,
+ colormap,
+ normalization="linear",
+ gamma=0.0,
+ cmapRange=None,
+ alpha=1.0,
+ nancolor=(1.0, 1.0, 1.0, 0.0),
+ ):
"""Create a 2D colormap
:param data: The 2D scalar data array to display
@@ -267,10 +275,10 @@ class GLPlotColormap(_GLPlotData2D):
self.colormap = numpy.array(colormap, copy=False)
self.normalization = normalization
self.gamma = gamma
- self._cmapRange = (1., 10.) # Colormap range
+ self._cmapRange = (1.0, 10.0) # Colormap range
self.cmapRange = cmapRange # Update _cmapRange
- self._alpha = numpy.clip(alpha, 0., 1.)
- self._nancolor = numpy.clip(nancolor, 0., 1.)
+ self._alpha = numpy.clip(alpha, 0.0, 1.0)
+ self._nancolor = numpy.clip(nancolor, 0.0, 1.0)
self._cmap_texture = None
self._texture = None
@@ -287,15 +295,14 @@ class GLPlotColormap(_GLPlotData2D):
self._textureIsDirty = False
def isInitialized(self):
- return (self._cmap_texture is not None or
- self._texture is not None)
+ return self._cmap_texture is not None or self._texture is not None
@property
def cmapRange(self):
- if self.normalization == 'log':
- assert self._cmapRange[0] > 0. and self._cmapRange[1] > 0.
- elif self.normalization == 'sqrt':
- assert self._cmapRange[0] >= 0. and self._cmapRange[1] >= 0.
+ if self.normalization == "log":
+ assert self._cmapRange[0] > 0.0 and self._cmapRange[1] > 0.0
+ elif self.normalization == "sqrt":
+ assert self._cmapRange[0] >= 0.0 and self._cmapRange[1] >= 0.0
return self._cmapRange
@cmapRange.setter
@@ -314,8 +321,7 @@ class GLPlotColormap(_GLPlotData2D):
self.data = data
if self._texture is not None:
- if (self.data.shape != oldData.shape or
- self.data.dtype != oldData.dtype):
+ if self.data.shape != oldData.shape or self.data.dtype != oldData.dtype:
self.discard()
else:
self._textureIsDirty = True
@@ -324,74 +330,77 @@ class GLPlotColormap(_GLPlotData2D):
if self._cmap_texture is None:
# TODO share cmap texture accross Images
# put all cmaps in one texture
- colormap = numpy.empty((16, 256, self.colormap.shape[1]),
- dtype=self.colormap.dtype)
+ colormap = numpy.empty(
+ (16, 256, self.colormap.shape[1]), dtype=self.colormap.dtype
+ )
colormap[:] = self.colormap
format_ = gl.GL_RGBA if colormap.shape[-1] == 4 else gl.GL_RGB
- self._cmap_texture = Texture(internalFormat=format_,
- data=colormap,
- format_=format_,
- texUnit=self._CMAP_TEX_UNIT,
- minFilter=gl.GL_NEAREST,
- magFilter=gl.GL_NEAREST,
- wrap=(gl.GL_CLAMP_TO_EDGE,
- gl.GL_CLAMP_TO_EDGE))
+ self._cmap_texture = Texture(
+ internalFormat=format_,
+ data=colormap,
+ format_=format_,
+ texUnit=self._CMAP_TEX_UNIT,
+ minFilter=gl.GL_NEAREST,
+ magFilter=gl.GL_NEAREST,
+ wrap=(gl.GL_CLAMP_TO_EDGE, gl.GL_CLAMP_TO_EDGE),
+ )
self._cmap_texture.prepare()
if self._texture is None:
internalFormat = self._INTERNAL_FORMATS[self.data.dtype]
- self._texture = Image(internalFormat,
- self.data,
- format_=gl.GL_RED,
- texUnit=self._DATA_TEX_UNIT)
+ self._texture = Image(
+ internalFormat,
+ self.data,
+ format_=gl.GL_RED,
+ texUnit=self._DATA_TEX_UNIT,
+ )
elif self._textureIsDirty:
self._textureIsDirty = True
self._texture.updateAll(format_=gl.GL_RED, data=self.data)
def _setCMap(self, prog):
dataMin, dataMax = self.cmapRange # If log, it is stricly positive
- param = 0.
+ param = 0.0
if self.data.dtype in (numpy.uint16, numpy.uint8):
# Using unsigned int as normalized integer in OpenGL
# So revert normalization in the shader
dataScale = float(numpy.iinfo(self.data.dtype).max)
else:
- dataScale = 1.
+ dataScale = 1.0
- if self.normalization == 'log':
+ if self.normalization == "log":
dataMin = math.log10(dataMin)
dataMax = math.log10(dataMax)
normID = 1
- elif self.normalization == 'sqrt':
+ elif self.normalization == "sqrt":
dataMin = math.sqrt(dataMin)
dataMax = math.sqrt(dataMax)
normID = 2
- elif self.normalization == 'gamma':
+ elif self.normalization == "gamma":
# Keep dataMin, dataMax as is
param = self.gamma
normID = 3
- elif self.normalization == 'arcsinh':
+ elif self.normalization == "arcsinh":
dataMin = numpy.arcsinh(dataMin)
dataMax = numpy.arcsinh(dataMax)
normID = 4
else: # Linear and fallback
normID = 0
- gl.glUniform1f(prog.uniforms['data_scale'], dataScale)
- gl.glUniform1i(prog.uniforms['cmap_texture'],
- self._cmap_texture.texUnit)
- gl.glUniform1i(prog.uniforms['cmap_normalization'], normID)
- gl.glUniform1f(prog.uniforms['cmap_parameter'], param)
- gl.glUniform1f(prog.uniforms['cmap_min'], dataMin)
+ gl.glUniform1f(prog.uniforms["data_scale"], dataScale)
+ gl.glUniform1i(prog.uniforms["cmap_texture"], self._cmap_texture.texUnit)
+ gl.glUniform1i(prog.uniforms["cmap_normalization"], normID)
+ gl.glUniform1f(prog.uniforms["cmap_parameter"], param)
+ gl.glUniform1f(prog.uniforms["cmap_min"], dataMin)
if dataMax > dataMin:
- oneOverRange = 1. / (dataMax - dataMin)
+ oneOverRange = 1.0 / (dataMax - dataMin)
else:
- oneOverRange = 0. # Fall-back
- gl.glUniform1f(prog.uniforms['cmap_oneOverRange'], oneOverRange)
+ oneOverRange = 0.0 # Fall-back
+ gl.glUniform1f(prog.uniforms["cmap_oneOverRange"], oneOverRange)
- gl.glUniform4f(prog.uniforms['nancolor'], *self._nancolor)
+ gl.glUniform4f(prog.uniforms["nancolor"], *self._nancolor)
self._cmap_texture.bind()
@@ -405,21 +414,25 @@ class GLPlotColormap(_GLPlotData2D):
prog = self._linearProgram
prog.use()
- gl.glUniform1i(prog.uniforms['data'], self._DATA_TEX_UNIT)
+ gl.glUniform1i(prog.uniforms["data"], self._DATA_TEX_UNIT)
- mat = numpy.dot(numpy.dot(context.matrix,
- mat4Translate(*self.origin)),
- mat4Scale(*self.scale))
- gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE,
- mat.astype(numpy.float32))
+ mat = numpy.dot(
+ numpy.dot(context.matrix, mat4Translate(*self.origin)),
+ mat4Scale(*self.scale),
+ )
+ gl.glUniformMatrix4fv(
+ prog.uniforms["matrix"], 1, gl.GL_TRUE, mat.astype(numpy.float32)
+ )
- gl.glUniform1f(prog.uniforms['alpha'], self.alpha)
+ gl.glUniform1f(prog.uniforms["alpha"], self.alpha)
self._setCMap(prog)
- self._texture.render(prog.attributes['position'],
- prog.attributes['texCoords'],
- self._DATA_TEX_UNIT)
+ self._texture.render(
+ prog.attributes["position"],
+ prog.attributes["texCoords"],
+ self._DATA_TEX_UNIT,
+ )
def _renderLog10(self, context):
"""Perform rendering when one axis has log scale
@@ -427,8 +440,9 @@ class GLPlotColormap(_GLPlotData2D):
:param RenderContext context: Rendering information
"""
xMin, yMin = self.xMin, self.yMin
- if ((context.isXLog and xMin < FLOAT32_MINPOS) or
- (context.isYLog and yMin < FLOAT32_MINPOS)):
+ if (context.isXLog and xMin < FLOAT32_MINPOS) or (
+ context.isYLog and yMin < FLOAT32_MINPOS
+ ):
# Do not render images that are partly or totally <= 0
return
@@ -439,27 +453,33 @@ class GLPlotColormap(_GLPlotData2D):
ox, oy = self.origin
- gl.glUniform1i(prog.uniforms['data'], self._DATA_TEX_UNIT)
+ gl.glUniform1i(prog.uniforms["data"], self._DATA_TEX_UNIT)
- gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE,
- context.matrix.astype(numpy.float32))
+ gl.glUniformMatrix4fv(
+ prog.uniforms["matrix"], 1, gl.GL_TRUE, context.matrix.astype(numpy.float32)
+ )
mat = numpy.dot(mat4Translate(ox, oy), mat4Scale(*self.scale))
- gl.glUniformMatrix4fv(prog.uniforms['matOffset'], 1, gl.GL_TRUE,
- mat.astype(numpy.float32))
+ gl.glUniformMatrix4fv(
+ prog.uniforms["matOffset"], 1, gl.GL_TRUE, mat.astype(numpy.float32)
+ )
- gl.glUniform2i(prog.uniforms['isLog'], context.isXLog, context.isYLog)
+ gl.glUniform2i(prog.uniforms["isLog"], context.isXLog, context.isYLog)
ex = ox + self.scale[0] * self.data.shape[1]
ey = oy + self.scale[1] * self.data.shape[0]
- xOneOverRange = 1. / (ex - ox)
- yOneOverRange = 1. / (ey - oy)
- gl.glUniform2f(prog.uniforms['bounds_originOverRange'],
- ox * xOneOverRange, oy * yOneOverRange)
- gl.glUniform2f(prog.uniforms['bounds_oneOverRange'],
- xOneOverRange, yOneOverRange)
+ xOneOverRange = 1.0 / (ex - ox)
+ yOneOverRange = 1.0 / (ey - oy)
+ gl.glUniform2f(
+ prog.uniforms["bounds_originOverRange"],
+ ox * xOneOverRange,
+ oy * yOneOverRange,
+ )
+ gl.glUniform2f(
+ prog.uniforms["bounds_oneOverRange"], xOneOverRange, yOneOverRange
+ )
- gl.glUniform1f(prog.uniforms['alpha'], self.alpha)
+ gl.glUniform1f(prog.uniforms["alpha"], self.alpha)
self._setCMap(prog)
@@ -469,20 +489,19 @@ class GLPlotColormap(_GLPlotData2D):
raise RuntimeError("No texture, discard has already been called")
if len(tiles) > 1:
raise NotImplementedError(
- "Image over multiple textures not supported with log scale")
+ "Image over multiple textures not supported with log scale"
+ )
texture, vertices, info = tiles[0]
texture.bind(self._DATA_TEX_UNIT)
- posAttrib = prog.attributes['position']
+ posAttrib = prog.attributes["position"]
stride = vertices.shape[-1] * vertices.itemsize
gl.glEnableVertexAttribArray(posAttrib)
- gl.glVertexAttribPointer(posAttrib,
- 2,
- gl.GL_FLOAT,
- gl.GL_FALSE,
- stride, vertices)
+ gl.glVertexAttribPointer(
+ posAttrib, 2, gl.GL_FLOAT, gl.GL_FALSE, stride, vertices
+ )
gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, len(vertices))
@@ -503,11 +522,11 @@ class GLPlotColormap(_GLPlotData2D):
# image #######################################################################
-class GLPlotRGBAImage(_GLPlotData2D):
+class GLPlotRGBAImage(_GLPlotData2D):
_SHADERS = {
- 'linear': {
- 'vertex': """
+ "linear": {
+ "vertex": """
#version 120
attribute vec2 position;
@@ -521,7 +540,7 @@ class GLPlotRGBAImage(_GLPlotData2D):
coords = texCoords;
}
""",
- 'fragment': """
+ "fragment": """
#version 120
uniform sampler2D tex;
@@ -533,10 +552,10 @@ class GLPlotRGBAImage(_GLPlotData2D):
gl_FragColor = texture2D(tex, coords);
gl_FragColor.a *= alpha;
}
- """},
-
- 'log': {
- 'vertex': """
+ """,
+ },
+ "log": {
+ "vertex": """
#version 120
attribute vec2 position;
@@ -560,7 +579,7 @@ class GLPlotRGBAImage(_GLPlotData2D):
gl_Position = matrix * dataPos;
}
""",
- 'fragment': """
+ "fragment": """
#version 120
uniform sampler2D tex;
@@ -587,22 +606,25 @@ class GLPlotRGBAImage(_GLPlotData2D):
gl_FragColor = texture2D(tex, textureCoords());
gl_FragColor.a *= alpha;
}
- """}
+ """,
+ },
}
_DATA_TEX_UNIT = 0
- _SUPPORTED_DTYPES = (numpy.dtype(numpy.float32),
- numpy.dtype(numpy.uint8),
- numpy.dtype(numpy.uint16))
+ _SUPPORTED_DTYPES = (
+ numpy.dtype(numpy.float32),
+ numpy.dtype(numpy.uint8),
+ numpy.dtype(numpy.uint16),
+ )
- _linearProgram = Program(_SHADERS['linear']['vertex'],
- _SHADERS['linear']['fragment'],
- attrib0='position')
+ _linearProgram = Program(
+ _SHADERS["linear"]["vertex"], _SHADERS["linear"]["fragment"], attrib0="position"
+ )
- _logProgram = Program(_SHADERS['log']['vertex'],
- _SHADERS['log']['fragment'],
- attrib0='position')
+ _logProgram = Program(
+ _SHADERS["log"]["vertex"], _SHADERS["log"]["fragment"], attrib0="position"
+ )
def __init__(self, data, origin, scale, alpha):
"""Create a 2D RGB(A) image from data
@@ -621,7 +643,7 @@ class GLPlotRGBAImage(_GLPlotData2D):
super(GLPlotRGBAImage, self).__init__(data, origin, scale)
self._texture = None
self._textureIsDirty = False
- self._alpha = numpy.clip(alpha, 0., 1.)
+ self._alpha = numpy.clip(alpha, 0.0, 1.0)
@property
def alpha(self):
@@ -649,17 +671,16 @@ class GLPlotRGBAImage(_GLPlotData2D):
def prepare(self):
if self._texture is None:
- formatName = 'GL_RGBA' if self.data.shape[2] == 4 else 'GL_RGB'
+ formatName = "GL_RGBA" if self.data.shape[2] == 4 else "GL_RGB"
format_ = getattr(gl, formatName)
if self.data.dtype == numpy.uint16:
- formatName += '16' # Use sized internal format for uint16
+ formatName += "16" # Use sized internal format for uint16
internalFormat = getattr(gl, formatName)
- self._texture = Image(internalFormat,
- self.data,
- format_=format_,
- texUnit=self._DATA_TEX_UNIT)
+ self._texture = Image(
+ internalFormat, self.data, format_=format_, texUnit=self._DATA_TEX_UNIT
+ )
elif self._textureIsDirty:
self._textureIsDirty = False
@@ -677,18 +698,23 @@ class GLPlotRGBAImage(_GLPlotData2D):
prog = self._linearProgram
prog.use()
- gl.glUniform1i(prog.uniforms['tex'], self._DATA_TEX_UNIT)
+ gl.glUniform1i(prog.uniforms["tex"], self._DATA_TEX_UNIT)
- mat = numpy.dot(numpy.dot(context.matrix, mat4Translate(*self.origin)),
- mat4Scale(*self.scale))
- gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE,
- mat.astype(numpy.float32))
+ mat = numpy.dot(
+ numpy.dot(context.matrix, mat4Translate(*self.origin)),
+ mat4Scale(*self.scale),
+ )
+ gl.glUniformMatrix4fv(
+ prog.uniforms["matrix"], 1, gl.GL_TRUE, mat.astype(numpy.float32)
+ )
- gl.glUniform1f(prog.uniforms['alpha'], self.alpha)
+ gl.glUniform1f(prog.uniforms["alpha"], self.alpha)
- self._texture.render(prog.attributes['position'],
- prog.attributes['texCoords'],
- self._DATA_TEX_UNIT)
+ self._texture.render(
+ prog.attributes["position"],
+ prog.attributes["texCoords"],
+ self._DATA_TEX_UNIT,
+ )
def _renderLog(self, context):
"""Perform rendering with axes having log scale
@@ -702,27 +728,33 @@ class GLPlotRGBAImage(_GLPlotData2D):
ox, oy = self.origin
- gl.glUniform1i(prog.uniforms['tex'], self._DATA_TEX_UNIT)
+ gl.glUniform1i(prog.uniforms["tex"], self._DATA_TEX_UNIT)
- gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE,
- context.matrix.astype(numpy.float32))
+ gl.glUniformMatrix4fv(
+ prog.uniforms["matrix"], 1, gl.GL_TRUE, context.matrix.astype(numpy.float32)
+ )
mat = numpy.dot(mat4Translate(ox, oy), mat4Scale(*self.scale))
- gl.glUniformMatrix4fv(prog.uniforms['matOffset'], 1, gl.GL_TRUE,
- mat.astype(numpy.float32))
+ gl.glUniformMatrix4fv(
+ prog.uniforms["matOffset"], 1, gl.GL_TRUE, mat.astype(numpy.float32)
+ )
- gl.glUniform2i(prog.uniforms['isLog'], context.isXLog, context.isYLog)
+ gl.glUniform2i(prog.uniforms["isLog"], context.isXLog, context.isYLog)
- gl.glUniform1f(prog.uniforms['alpha'], self.alpha)
+ gl.glUniform1f(prog.uniforms["alpha"], self.alpha)
ex = ox + self.scale[0] * self.data.shape[1]
ey = oy + self.scale[1] * self.data.shape[0]
- xOneOverRange = 1. / (ex - ox)
- yOneOverRange = 1. / (ey - oy)
- gl.glUniform2f(prog.uniforms['bounds_originOverRange'],
- ox * xOneOverRange, oy * yOneOverRange)
- gl.glUniform2f(prog.uniforms['bounds_oneOverRange'],
- xOneOverRange, yOneOverRange)
+ xOneOverRange = 1.0 / (ex - ox)
+ yOneOverRange = 1.0 / (ey - oy)
+ gl.glUniform2f(
+ prog.uniforms["bounds_originOverRange"],
+ ox * xOneOverRange,
+ oy * yOneOverRange,
+ )
+ gl.glUniform2f(
+ prog.uniforms["bounds_oneOverRange"], xOneOverRange, yOneOverRange
+ )
try:
tiles = self._texture.tiles
@@ -730,20 +762,19 @@ class GLPlotRGBAImage(_GLPlotData2D):
raise RuntimeError("No texture, discard has already been called")
if len(tiles) > 1:
raise NotImplementedError(
- "Image over multiple textures not supported with log scale")
+ "Image over multiple textures not supported with log scale"
+ )
texture, vertices, info = tiles[0]
texture.bind(self._DATA_TEX_UNIT)
- posAttrib = prog.attributes['position']
+ posAttrib = prog.attributes["position"]
stride = vertices.shape[-1] * vertices.itemsize
gl.glEnableVertexAttribArray(posAttrib)
- gl.glVertexAttribPointer(posAttrib,
- 2,
- gl.GL_FLOAT,
- gl.GL_FALSE,
- stride, vertices)
+ gl.glVertexAttribPointer(
+ posAttrib, 2, gl.GL_FLOAT, gl.GL_FALSE, stride, vertices
+ )
gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, len(vertices))
diff --git a/src/silx/gui/plot/backends/glutils/GLPlotItem.py b/src/silx/gui/plot/backends/glutils/GLPlotItem.py
index 58f5f41..0287ad5 100644
--- a/src/silx/gui/plot/backends/glutils/GLPlotItem.py
+++ b/src/silx/gui/plot/backends/glutils/GLPlotItem.py
@@ -39,7 +39,9 @@ class RenderContext:
:param float dpi: Number of device pixels per inch
"""
- def __init__(self, matrix=None, isXLog=False, isYLog=False, dpi=96., plotFrame=None):
+ def __init__(
+ self, matrix=None, isXLog=False, isYLog=False, dpi=96.0, plotFrame=None
+ ):
self.matrix = matrix
"""Current transformation matrix"""
@@ -73,7 +75,7 @@ class GLPlotItem:
"""Base class for primitives used in the PlotWidget OpenGL backend"""
def __init__(self):
- self.yaxis = 'left'
+ self.yaxis = "left"
"YAxis this item is attached to (either 'left' or 'right')"
def pick(self, x, y):
@@ -99,6 +101,5 @@ class GLPlotItem:
pass
def isInitialized(self) -> bool:
- """Returns True if resources where initialized and requires `discard`.
- """
+ """Returns True if resources where initialized and requires `discard`."""
return True
diff --git a/src/silx/gui/plot/backends/glutils/GLPlotTriangles.py b/src/silx/gui/plot/backends/glutils/GLPlotTriangles.py
index a67afd9..e8a8e4a 100644
--- a/src/silx/gui/plot/backends/glutils/GLPlotTriangles.py
+++ b/src/silx/gui/plot/backends/glutils/GLPlotTriangles.py
@@ -70,9 +70,10 @@ class GLPlotTriangles(GLPlotItem):
gl_FragColor.a *= alpha;
}
""",
- attrib0='xPos')
+ attrib0="xPos",
+ )
- def __init__(self, x, y, color, triangles, alpha=1.):
+ def __init__(self, x, y, color, triangles, alpha=1.0):
"""
:param numpy.ndarray x: X coordinates of triangle corners
@@ -97,14 +98,14 @@ class GLPlotTriangles(GLPlotItem):
elif numpy.issubdtype(color.dtype, numpy.integer):
color = numpy.array(color, dtype=numpy.uint8, copy=False)
else:
- raise ValueError('Unsupported color type')
+ raise ValueError("Unsupported color type")
assert triangles.ndim == 2 and triangles.shape[1] == 3
self.__x_y_color = x, y, color
self.xMin, self.xMax = min_max(x, finite=True)
self.yMin, self.yMax = min_max(y, finite=True)
self.__triangles = triangles
- self.__alpha = numpy.clip(float(alpha), 0., 1.)
+ self.__alpha = numpy.clip(float(alpha), 0.0, 1.0)
self.__vbos = None
self.__indicesVbo = None
self.__picking_triangles = None
@@ -117,21 +118,22 @@ class GLPlotTriangles(GLPlotItem):
:return: List of picked data point indices
:rtype: Union[List[int],None]
"""
- if (x < self.xMin or x > self.xMax or
- y < self.yMin or y > self.yMax):
+ if x < self.xMin or x > self.xMax or y < self.yMin or y > self.yMax:
return None
xPts, yPts = self.__x_y_color[:2]
if self.__picking_triangles is None:
self.__picking_triangles = numpy.zeros(
- self.__triangles.shape + (3,), dtype=numpy.float32)
+ self.__triangles.shape + (3,), dtype=numpy.float32
+ )
self.__picking_triangles[:, :, 0] = xPts[self.__triangles]
self.__picking_triangles[:, :, 1] = yPts[self.__triangles]
segment = numpy.array(((x, y, -1), (x, y, 1)), dtype=numpy.float32)
# Picked triangle indices
indices = glutils.segmentTrianglesIntersection(
- segment, self.__picking_triangles)[0]
+ segment, self.__picking_triangles
+ )[0]
# Point indices
indices = numpy.unique(numpy.ravel(self.__triangles[indices]))
@@ -163,7 +165,8 @@ class GLPlotTriangles(GLPlotItem):
self.__indicesVbo = glutils.VertexBuffer(
numpy.ravel(self.__triangles),
usage=gl.GL_STATIC_DRAW,
- target=gl.GL_ELEMENT_ARRAY_BUFFER)
+ target=gl.GL_ELEMENT_ARRAY_BUFFER,
+ )
def render(self, context):
"""Perform rendering
@@ -177,20 +180,24 @@ class GLPlotTriangles(GLPlotItem):
self._PROGRAM.use()
- gl.glUniformMatrix4fv(self._PROGRAM.uniforms['matrix'],
- 1,
- gl.GL_TRUE,
- context.matrix.astype(numpy.float32))
+ gl.glUniformMatrix4fv(
+ self._PROGRAM.uniforms["matrix"],
+ 1,
+ gl.GL_TRUE,
+ context.matrix.astype(numpy.float32),
+ )
- gl.glUniform1f(self._PROGRAM.uniforms['alpha'], self.__alpha)
+ gl.glUniform1f(self._PROGRAM.uniforms["alpha"], self.__alpha)
- for index, name in enumerate(('xPos', 'yPos', 'color')):
+ for index, name in enumerate(("xPos", "yPos", "color")):
attr = self._PROGRAM.attributes[name]
gl.glEnableVertexAttribArray(attr)
self.__vbos[index].setVertexAttrib(attr)
with self.__indicesVbo:
- gl.glDrawElements(gl.GL_TRIANGLES,
- self.__triangles.size,
- glutils.numpyToGLType(self.__triangles.dtype),
- ctypes.c_void_p(0))
+ gl.glDrawElements(
+ gl.GL_TRIANGLES,
+ self.__triangles.size,
+ glutils.numpyToGLType(self.__triangles.dtype),
+ ctypes.c_void_p(0),
+ )
diff --git a/src/silx/gui/plot/backends/glutils/GLSupport.py b/src/silx/gui/plot/backends/glutils/GLSupport.py
index f5357e2..c9afda0 100644
--- a/src/silx/gui/plot/backends/glutils/GLSupport.py
+++ b/src/silx/gui/plot/backends/glutils/GLSupport.py
@@ -54,8 +54,7 @@ def buildFillMaskIndices(nIndices, dtype=None):
splitIndex = lastIndex // 2 + 1
indices = numpy.empty(nIndices, dtype=dtype)
indices[::2] = numpy.arange(0, splitIndex, step=1, dtype=dtype)
- indices[1::2] = numpy.arange(lastIndex, splitIndex - 1, step=-1,
- dtype=dtype)
+ indices[1::2] = numpy.arange(lastIndex, splitIndex - 1, step=-1, dtype=dtype)
return indices
@@ -63,16 +62,17 @@ class FilledShape2D(object):
_NO_HATCH = 0
_HATCH_STEP = 20
- def __init__(self, points, style='solid', color=(0., 0., 0., 1.)):
+ def __init__(self, points, style="solid", color=(0.0, 0.0, 0.0, 1.0)):
self.vertices = numpy.array(points, dtype=numpy.float32, copy=False)
self._indices = buildFillMaskIndices(len(self.vertices))
tVertex = numpy.transpose(self.vertices)
xMin, xMax = min(tVertex[0]), max(tVertex[0])
yMin, yMax = min(tVertex[1]), max(tVertex[1])
- self.bboxVertices = numpy.array(((xMin, yMin), (xMin, yMax),
- (xMax, yMin), (xMax, yMax)),
- dtype=numpy.float32)
+ self.bboxVertices = numpy.array(
+ ((xMin, yMin), (xMin, yMax), (xMax, yMin), (xMax, yMax)),
+ dtype=numpy.float32,
+ )
self._xMin, self._xMax = xMin, xMax
self._yMin, self._yMax = yMin, yMax
@@ -80,18 +80,16 @@ class FilledShape2D(object):
self.color = color
def render(self, posAttrib, colorUnif, hatchStepUnif):
- assert self.style in ('hatch', 'solid')
+ assert self.style in ("hatch", "solid")
gl.glUniform4f(colorUnif, *self.color)
- step = self._HATCH_STEP if self.style == 'hatch' else self._NO_HATCH
+ step = self._HATCH_STEP if self.style == "hatch" else self._NO_HATCH
gl.glUniform1i(hatchStepUnif, step)
# Prepare fill mask
gl.glEnableVertexAttribArray(posAttrib)
- gl.glVertexAttribPointer(posAttrib,
- 2,
- gl.GL_FLOAT,
- gl.GL_FALSE,
- 0, self.vertices)
+ gl.glVertexAttribPointer(
+ posAttrib, 2, gl.GL_FLOAT, gl.GL_FALSE, 0, self.vertices
+ )
gl.glEnable(gl.GL_STENCIL_TEST)
gl.glStencilMask(1)
@@ -100,8 +98,12 @@ class FilledShape2D(object):
gl.glColorMask(gl.GL_FALSE, gl.GL_FALSE, gl.GL_FALSE, gl.GL_FALSE)
gl.glDepthMask(gl.GL_FALSE)
- gl.glDrawElements(gl.GL_TRIANGLE_STRIP, len(self._indices),
- gl.GL_UNSIGNED_SHORT, self._indices)
+ gl.glDrawElements(
+ gl.GL_TRIANGLE_STRIP,
+ len(self._indices),
+ gl.GL_UNSIGNED_SHORT,
+ self._indices,
+ )
gl.glStencilFunc(gl.GL_EQUAL, 1, 1)
# Reset stencil while drawing
@@ -109,11 +111,9 @@ class FilledShape2D(object):
gl.glColorMask(gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE)
gl.glDepthMask(gl.GL_TRUE)
- gl.glVertexAttribPointer(posAttrib,
- 2,
- gl.GL_FLOAT,
- gl.GL_FALSE,
- 0, self.bboxVertices)
+ gl.glVertexAttribPointer(
+ posAttrib, 2, gl.GL_FLOAT, gl.GL_FALSE, 0, self.bboxVertices
+ )
gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, len(self.bboxVertices))
gl.glDisable(gl.GL_STENCIL_TEST)
@@ -121,37 +121,54 @@ class FilledShape2D(object):
# matrix ######################################################################
+
def mat4Ortho(left, right, bottom, top, near, far):
"""Orthographic projection matrix (row-major)"""
- return numpy.array((
- (2./(right - left), 0., 0., -(right+left)/float(right-left)),
- (0., 2./(top - bottom), 0., -(top+bottom)/float(top-bottom)),
- (0., 0., -2./(far-near), -(far+near)/float(far-near)),
- (0., 0., 0., 1.)), dtype=numpy.float64)
-
-
-def mat4Translate(x=0., y=0., z=0.):
+ return numpy.array(
+ (
+ (2.0 / (right - left), 0.0, 0.0, -(right + left) / float(right - left)),
+ (0.0, 2.0 / (top - bottom), 0.0, -(top + bottom) / float(top - bottom)),
+ (0.0, 0.0, -2.0 / (far - near), -(far + near) / float(far - near)),
+ (0.0, 0.0, 0.0, 1.0),
+ ),
+ dtype=numpy.float64,
+ )
+
+
+def mat4Translate(x=0.0, y=0.0, z=0.0):
"""Translation matrix (row-major)"""
- return numpy.array((
- (1., 0., 0., x),
- (0., 1., 0., y),
- (0., 0., 1., z),
- (0., 0., 0., 1.)), dtype=numpy.float64)
-
-
-def mat4Scale(sx=1., sy=1., sz=1.):
+ return numpy.array(
+ (
+ (1.0, 0.0, 0.0, x),
+ (0.0, 1.0, 0.0, y),
+ (0.0, 0.0, 1.0, z),
+ (0.0, 0.0, 0.0, 1.0),
+ ),
+ dtype=numpy.float64,
+ )
+
+
+def mat4Scale(sx=1.0, sy=1.0, sz=1.0):
"""Scale matrix (row-major)"""
- return numpy.array((
- (sx, 0., 0., 0.),
- (0., sy, 0., 0.),
- (0., 0., sz, 0.),
- (0., 0., 0., 1.)), dtype=numpy.float64)
+ return numpy.array(
+ (
+ (sx, 0.0, 0.0, 0.0),
+ (0.0, sy, 0.0, 0.0),
+ (0.0, 0.0, sz, 0.0),
+ (0.0, 0.0, 0.0, 1.0),
+ ),
+ dtype=numpy.float64,
+ )
def mat4Identity():
"""Identity matrix"""
- return numpy.array((
- (1., 0., 0., 0.),
- (0., 1., 0., 0.),
- (0., 0., 1., 0.),
- (0., 0., 0., 1.)), dtype=numpy.float64)
+ return numpy.array(
+ (
+ (1.0, 0.0, 0.0, 0.0),
+ (0.0, 1.0, 0.0, 0.0),
+ (0.0, 0.0, 1.0, 0.0),
+ (0.0, 0.0, 0.0, 1.0),
+ ),
+ dtype=numpy.float64,
+ )
diff --git a/src/silx/gui/plot/backends/glutils/GLText.py b/src/silx/gui/plot/backends/glutils/GLText.py
index 4862bff..15d7a70 100644
--- a/src/silx/gui/plot/backends/glutils/GLText.py
+++ b/src/silx/gui/plot/backends/glutils/GLText.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2014-2020 European Synchrotron Radiation Facility
+# Copyright (c) 2014-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -26,6 +26,8 @@ This module provides minimalistic text support for OpenGL.
It provides Latin-1 (ISO8859-1) characters for one monospace font at one size.
"""
+from __future__ import annotations
+
__authors__ = ["T. Vincent"]
__license__ = "MIT"
__date__ = "03/04/2017"
@@ -36,14 +38,13 @@ import weakref
import numpy
+from .... import qt
from ...._glutils import font, gl, Context, Program, Texture
from .GLSupport import mat4Translate
+from silx.gui.colors import RGBAColorType
-# TODO: Font should be configurable by the main program: using mpl.rcParams?
-
-
-class _Cache(object):
+class _Cache:
"""LRU (Least Recent Used) cache.
:param int maxsize: Maximum number of (key, value) pairs in the cache
@@ -55,7 +56,7 @@ class _Cache(object):
def __init__(self, maxsize=128, callback=None):
self._maxsize = int(maxsize)
self._callback = callback
- self._cache = OrderedDict()
+ self._cache = OrderedDict() # Needed for popitem(last=False)
def __contains__(self, item):
return item in self._cache
@@ -84,15 +85,14 @@ class _Cache(object):
# Text2D ######################################################################
-LEFT, CENTER, RIGHT = 'left', 'center', 'right'
-TOP, BASELINE, BOTTOM = 'top', 'baseline', 'bottom'
+LEFT, CENTER, RIGHT = "left", "center", "right"
+TOP, BASELINE, BOTTOM = "top", "baseline", "bottom"
ROTATE_90, ROTATE_180, ROTATE_270 = 90, 180, 270
-class Text2D(object):
-
+class Text2D:
_SHADERS = {
- 'vertex': """
+ "vertex": """
#version 120
attribute vec2 position;
@@ -106,7 +106,7 @@ class Text2D(object):
vCoords = texCoords;
}
""",
- 'fragment': """
+ "fragment": """
#version 120
uniform sampler2D texText;
@@ -116,130 +116,134 @@ class Text2D(object):
varying vec2 vCoords;
void main(void) {
- gl_FragColor = mix(bgColor, color, texture2D(texText, vCoords).r);
+ if (vCoords.x < 0.0 || vCoords.x > 1.0 || vCoords.y < 0.0 || vCoords.y > 1.0) {
+ gl_FragColor = bgColor;
+ } else {
+ gl_FragColor = mix(bgColor, color, texture2D(texText, vCoords).r);
+ }
}
- """
+ """,
}
- _TEX_COORDS = numpy.array(((0., 0.), (1., 0.), (0., 1.), (1., 1.)),
- dtype=numpy.float32).ravel()
-
- _program = Program(_SHADERS['vertex'],
- _SHADERS['fragment'],
- attrib0='position')
+ _program = Program(_SHADERS["vertex"], _SHADERS["fragment"], attrib0="position")
# Discard texture objects when removed from the cache
_textures = weakref.WeakKeyDictionary()
"""Cache already created textures"""
- _sizes = _Cache()
- """Cache already computed sizes"""
-
- def __init__(self, text, x=0, y=0,
- color=(0., 0., 0., 1.),
- bgColor=None,
- align=LEFT, valign=BASELINE,
- rotate=0,
- devicePixelRatio= 1.):
+ def __init__(
+ self,
+ text: str,
+ font: qt.QFont,
+ x: float = 0.0,
+ y: float = 0.0,
+ color: tuple[float, float, float, float] = (0.0, 0.0, 0.0, 1.0),
+ bgColor: RGBAColorType | None = None,
+ align: str = LEFT,
+ valign: str = BASELINE,
+ rotate: float = 0.0,
+ devicePixelRatio: float = 1.0,
+ padding: int = 0,
+ ):
self.devicePixelRatio = devicePixelRatio
+ self.font = font
self._vertices = None
self._text = text
+ self._padding = padding
self.x = x
self.y = y
self.color = color
self.bgColor = bgColor
if align not in (LEFT, CENTER, RIGHT):
- raise ValueError(
- "Horizontal alignment not supported: {0}".format(align))
+ raise ValueError("Horizontal alignment not supported: {0}".format(align))
self._align = align
if valign not in (TOP, CENTER, BASELINE, BOTTOM):
- raise ValueError(
- "Vertical alignment not supported: {0}".format(valign))
+ raise ValueError("Vertical alignment not supported: {0}".format(valign))
self._valign = valign
self._rotate = numpy.radians(rotate)
- def _getTexture(self, text, devicePixelRatio):
+ def _getTexture(self, dotsPerInch: float) -> tuple[Texture, int]:
# Retrieve/initialize texture cache for current context
- textureKey = text, devicePixelRatio
+ key = self.text, self.font.key(), dotsPerInch
context = Context.getCurrent()
if context not in self._textures:
self._textures[context] = _Cache(
- callback=lambda key, value: value[0].discard())
+ callback=lambda key, value: value[0].discard()
+ )
textures = self._textures[context]
- if textureKey not in textures:
- image, offset = font.rasterText(
- text,
- font.getDefaultFontFamily(),
- devicePixelRatio=self.devicePixelRatio)
- if textureKey not in self._sizes:
- self._sizes[textureKey] = image.shape[1], image.shape[0]
+ if key not in textures:
+ image, offset = font.rasterText(self.text, self.font, dotsPerInch)
texture = Texture(
gl.GL_RED,
data=image,
minFilter=gl.GL_NEAREST,
magFilter=gl.GL_NEAREST,
- wrap=(gl.GL_CLAMP_TO_EDGE,
- gl.GL_CLAMP_TO_EDGE))
+ wrap=(gl.GL_CLAMP_TO_EDGE, gl.GL_CLAMP_TO_EDGE),
+ )
texture.prepare()
- textures[textureKey] = texture, offset
+ textures[key] = texture, offset
- return textures[textureKey]
+ return textures[key]
@property
- def text(self):
+ def text(self) -> str:
return self._text
@property
- def size(self):
- textureKey = self.text, self.devicePixelRatio
- if textureKey not in self._sizes:
- image, offset = font.rasterText(
- self.text,
- font.getDefaultFontFamily(),
- devicePixelRatio=self.devicePixelRatio)
- self._sizes[textureKey] = image.shape[1], image.shape[0]
- return self._sizes[textureKey]
-
- def getVertices(self, offset, shape):
+ def padding(self) -> int:
+ return self._padding
+
+ def getVertices(self, offset: int, shape: tuple[int, int]) -> numpy.ndarray:
height, width = shape
if self._align == LEFT:
xOrig = 0
elif self._align == RIGHT:
- xOrig = - width
+ xOrig = -width
else: # CENTER
- xOrig = - width // 2
+ xOrig = -width // 2
if self._valign == BASELINE:
- yOrig = - offset
+ yOrig = -offset
elif self._valign == TOP:
yOrig = 0
elif self._valign == BOTTOM:
- yOrig = - height
+ yOrig = -height
else: # CENTER
- yOrig = - height // 2
-
- vertices = numpy.array((
- (xOrig, yOrig),
- (xOrig + width, yOrig),
- (xOrig, yOrig + height),
- (xOrig + width, yOrig + height)), dtype=numpy.float32)
+ yOrig = -height // 2
+
+ vertices = numpy.array(
+ (
+ (xOrig, yOrig),
+ (xOrig + width, yOrig),
+ (xOrig, yOrig + height),
+ (xOrig + width, yOrig + height),
+ ),
+ dtype=numpy.float32,
+ )
cos, sin = numpy.cos(self._rotate), numpy.sin(self._rotate)
- vertices = numpy.ascontiguousarray(numpy.transpose(numpy.array((
- cos * vertices[:, 0] - sin * vertices[:, 1],
- sin * vertices[:, 0] + cos * vertices[:, 1]),
- dtype=numpy.float32)))
+ vertices = numpy.ascontiguousarray(
+ numpy.transpose(
+ numpy.array(
+ (
+ cos * vertices[:, 0] - sin * vertices[:, 1],
+ sin * vertices[:, 0] + cos * vertices[:, 1],
+ ),
+ dtype=numpy.float32,
+ )
+ )
+ )
return vertices
- def render(self, matrix):
+ def render(self, matrix: numpy.ndarray, dotsPerInch: float):
if not self.text.strip():
return
@@ -247,40 +251,47 @@ class Text2D(object):
prog.use()
texUnit = 0
- texture, offset = self._getTexture(self.text, self.devicePixelRatio)
+ texture, offset = self._getTexture(dotsPerInch)
- gl.glUniform1i(prog.uniforms['texText'], texUnit)
+ gl.glUniform1i(prog.uniforms["texText"], texUnit)
mat = numpy.dot(matrix, mat4Translate(int(self.x), int(self.y)))
- gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE,
- mat.astype(numpy.float32))
+ gl.glUniformMatrix4fv(
+ prog.uniforms["matrix"], 1, gl.GL_TRUE, mat.astype(numpy.float32)
+ )
- gl.glUniform4f(prog.uniforms['color'], *self.color)
+ gl.glUniform4f(prog.uniforms["color"], *self.color)
if self.bgColor is not None:
bgColor = self.bgColor
else:
- bgColor = self.color[0], self.color[1], self.color[2], 0.
- gl.glUniform4f(prog.uniforms['bgColor'], *bgColor)
+ bgColor = self.color[0], self.color[1], self.color[2], 0.0
+ gl.glUniform4f(prog.uniforms["bgColor"], *bgColor)
- vertices = self.getVertices(offset, texture.shape)
+ paddingOffset = max(0, int(self.padding * self.devicePixelRatio))
+ height, width = texture.shape
+ vertices = self.getVertices(
+ offset, (height + 2 * paddingOffset, width + 2 * paddingOffset)
+ )
- posAttrib = prog.attributes['position']
+ posAttrib = prog.attributes["position"]
gl.glEnableVertexAttribArray(posAttrib)
- gl.glVertexAttribPointer(posAttrib,
- 2,
- gl.GL_FLOAT,
- gl.GL_FALSE,
- 0,
- vertices)
-
- texAttrib = prog.attributes['texCoords']
+ gl.glVertexAttribPointer(posAttrib, 2, gl.GL_FLOAT, gl.GL_FALSE, 0, vertices)
+
+ xoffset = paddingOffset / width
+ yoffset = paddingOffset / height
+ texCoords = numpy.array(
+ (
+ (-xoffset, -yoffset),
+ (1.0 + xoffset, -yoffset),
+ (-xoffset, 1.0 + yoffset),
+ (1.0 + xoffset, 1.0 + yoffset),
+ ),
+ dtype=numpy.float32,
+ ).ravel()
+
+ texAttrib = prog.attributes["texCoords"]
gl.glEnableVertexAttribArray(texAttrib)
- gl.glVertexAttribPointer(texAttrib,
- 2,
- gl.GL_FLOAT,
- gl.GL_FALSE,
- 0,
- self._TEX_COORDS)
+ gl.glVertexAttribPointer(texAttrib, 2, gl.GL_FLOAT, gl.GL_FALSE, 0, texCoords)
with texture:
gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, 4)
diff --git a/src/silx/gui/plot/backends/glutils/GLTexture.py b/src/silx/gui/plot/backends/glutils/GLTexture.py
index caca111..cbbe7ac 100644
--- a/src/silx/gui/plot/backends/glutils/GLTexture.py
+++ b/src/silx/gui/plot/backends/glutils/GLTexture.py
@@ -39,29 +39,33 @@ from ...._glutils import gl, Texture, numpyToGLType
_logger = logging.getLogger(__name__)
-def _checkTexture2D(internalFormat, shape,
- format_=None, type_=gl.GL_FLOAT, border=0):
+def _checkTexture2D(internalFormat, shape, format_=None, type_=gl.GL_FLOAT, border=0):
"""Check if texture size with provided parameters is supported
:rtype: bool
"""
height, width = shape
- gl.glTexImage2D(gl.GL_PROXY_TEXTURE_2D, 0, internalFormat,
- width, height, border,
- format_ or internalFormat,
- type_, c_void_p(0))
- width = gl.glGetTexLevelParameteriv(
- gl.GL_PROXY_TEXTURE_2D, 0, gl.GL_TEXTURE_WIDTH)
+ gl.glTexImage2D(
+ gl.GL_PROXY_TEXTURE_2D,
+ 0,
+ internalFormat,
+ width,
+ height,
+ border,
+ format_ or internalFormat,
+ type_,
+ c_void_p(0),
+ )
+ width = gl.glGetTexLevelParameteriv(gl.GL_PROXY_TEXTURE_2D, 0, gl.GL_TEXTURE_WIDTH)
return bool(width)
MIN_TEXTURE_SIZE = 64
-def _getMaxSquareTexture2DSize(internalFormat=gl.GL_RGBA,
- format_=None,
- type_=gl.GL_FLOAT,
- border=0):
+def _getMaxSquareTexture2DSize(
+ internalFormat=gl.GL_RGBA, format_=None, type_=gl.GL_FLOAT, border=0
+):
"""Returns a supported size for a corresponding square texture
:returns: GL_MAX_TEXTURE_SIZE or a smaller supported size (not optimal)
@@ -69,16 +73,15 @@ def _getMaxSquareTexture2DSize(internalFormat=gl.GL_RGBA,
"""
# Is this useful?
maxTexSize = gl.glGetIntegerv(gl.GL_MAX_TEXTURE_SIZE)
- while maxTexSize > MIN_TEXTURE_SIZE and \
- not _checkTexture2D(internalFormat, (maxTexSize, maxTexSize),
- format_, type_, border):
+ while maxTexSize > MIN_TEXTURE_SIZE and not _checkTexture2D(
+ internalFormat, (maxTexSize, maxTexSize), format_, type_, border
+ ):
maxTexSize //= 2
return max(MIN_TEXTURE_SIZE, maxTexSize)
class Image(object):
- """Image of any size eventually using multiple textures or larger texture
- """
+ """Image of any size eventually using multiple textures or larger texture"""
_WRAP = (gl.GL_CLAMP_TO_EDGE, gl.GL_CLAMP_TO_EDGE)
_MIN_FILTER = gl.GL_NEAREST
@@ -90,34 +93,48 @@ class Image(object):
type_ = numpyToGLType(data.dtype)
if _checkTexture2D(internalFormat, data.shape[0:2], format_, type_):
- texture = Texture(internalFormat,
- data,
- format_,
- texUnit=texUnit,
- minFilter=self._MIN_FILTER,
- magFilter=self._MAG_FILTER,
- wrap=self._WRAP)
+ texture = Texture(
+ internalFormat,
+ data,
+ format_,
+ texUnit=texUnit,
+ minFilter=self._MIN_FILTER,
+ magFilter=self._MAG_FILTER,
+ wrap=self._WRAP,
+ )
texture.prepare()
- vertices = numpy.array((
- (0., 0., 0., 0.),
- (self.width, 0., 1., 0.),
- (0., self.height, 0., 1.),
- (self.width, self.height, 1., 1.)), dtype=numpy.float32)
- self.tiles = ((texture, vertices,
- {'xOrigData': 0, 'yOrigData': 0,
- 'wData': self.width, 'hData': self.height}),)
+ vertices = numpy.array(
+ (
+ (0.0, 0.0, 0.0, 0.0),
+ (self.width, 0.0, 1.0, 0.0),
+ (0.0, self.height, 0.0, 1.0),
+ (self.width, self.height, 1.0, 1.0),
+ ),
+ dtype=numpy.float32,
+ )
+ self.tiles = (
+ (
+ texture,
+ vertices,
+ {
+ "xOrigData": 0,
+ "yOrigData": 0,
+ "wData": self.width,
+ "hData": self.height,
+ },
+ ),
+ )
else:
# Handle dimension too large: make tiles
- maxTexSize = _getMaxSquareTexture2DSize(internalFormat,
- format_, type_)
+ maxTexSize = _getMaxSquareTexture2DSize(internalFormat, format_, type_)
- nCols = (self.width+maxTexSize-1) // maxTexSize
+ nCols = (self.width + maxTexSize - 1) // maxTexSize
colWidths = [self.width // nCols] * nCols
colWidths[-1] += self.width % nCols
- nRows = (self.height+maxTexSize-1) // maxTexSize
- rowHeights = [self.height//nRows] * nRows
+ nRows = (self.height + maxTexSize - 1) // maxTexSize
+ rowHeights = [self.height // nRows] * nRows
rowHeights[-1] += self.height % nRows
tiles = []
@@ -125,30 +142,32 @@ class Image(object):
for hData in rowHeights:
xOrig = 0
for wData in colWidths:
- if (hData < MIN_TEXTURE_SIZE or wData < MIN_TEXTURE_SIZE) \
- and not _checkTexture2D(internalFormat,
- (hData, wData),
- format_,
- type_):
+ if (
+ hData < MIN_TEXTURE_SIZE or wData < MIN_TEXTURE_SIZE
+ ) and not _checkTexture2D(
+ internalFormat, (hData, wData), format_, type_
+ ):
# Ensure texture size is at least MIN_TEXTURE_SIZE
tH = max(hData, MIN_TEXTURE_SIZE)
tW = max(wData, MIN_TEXTURE_SIZE)
- uMax, vMax = float(wData)/tW, float(hData)/tH
+ uMax, vMax = float(wData) / tW, float(hData) / tH
# TODO issue with type_ and alignment
- texture = Texture(internalFormat,
- data=None,
- format_=format_,
- shape=(tH, tW),
- texUnit=texUnit,
- minFilter=self._MIN_FILTER,
- magFilter=self._MAG_FILTER,
- wrap=self._WRAP)
+ texture = Texture(
+ internalFormat,
+ data=None,
+ format_=format_,
+ shape=(tH, tW),
+ texUnit=texUnit,
+ minFilter=self._MIN_FILTER,
+ magFilter=self._MAG_FILTER,
+ wrap=self._WRAP,
+ )
# TODO handle unpack
- texture.update(format_,
- data[yOrig:yOrig+hData,
- xOrig:xOrig+wData])
+ texture.update(
+ format_, data[yOrig : yOrig + hData, xOrig : xOrig + wData]
+ )
# texture.update(format_, type_, data,
# width=wData, height=hData,
# unpackRowLength=width,
@@ -159,28 +178,41 @@ class Image(object):
# TODO issue with type_ and unpacking tiles
# TODO idea to handle unpack: use array strides
# As it is now, it will make a copy
- texture = Texture(internalFormat,
- data[yOrig:yOrig+hData,
- xOrig:xOrig+wData],
- format_,
- texUnit=texUnit,
- minFilter=self._MIN_FILTER,
- magFilter=self._MAG_FILTER,
- wrap=self._WRAP)
+ texture = Texture(
+ internalFormat,
+ data[yOrig : yOrig + hData, xOrig : xOrig + wData],
+ format_,
+ texUnit=texUnit,
+ minFilter=self._MIN_FILTER,
+ magFilter=self._MAG_FILTER,
+ wrap=self._WRAP,
+ )
# TODO
# unpackRowLength=width,
# unpackSkipPixels=xOrig,
# unpackSkipRows=yOrig)
- vertices = numpy.array((
- (xOrig, yOrig, 0., 0.),
- (xOrig + wData, yOrig, uMax, 0.),
- (xOrig, yOrig + hData, 0., vMax),
- (xOrig + wData, yOrig + hData, uMax, vMax)),
- dtype=numpy.float32)
+ vertices = numpy.array(
+ (
+ (xOrig, yOrig, 0.0, 0.0),
+ (xOrig + wData, yOrig, uMax, 0.0),
+ (xOrig, yOrig + hData, 0.0, vMax),
+ (xOrig + wData, yOrig + hData, uMax, vMax),
+ ),
+ dtype=numpy.float32,
+ )
texture.prepare()
- tiles.append((texture, vertices,
- {'xOrigData': xOrig, 'yOrigData': yOrig,
- 'wData': wData, 'hData': hData}))
+ tiles.append(
+ (
+ texture,
+ vertices,
+ {
+ "xOrigData": xOrig,
+ "yOrigData": yOrig,
+ "wData": wData,
+ "hData": hData,
+ },
+ )
+ )
xOrig += wData
yOrig += hData
self.tiles = tuple(tiles)
@@ -191,7 +223,7 @@ class Image(object):
del self.tiles
def updateAll(self, format_, data, texUnit=0):
- if not hasattr(self, 'tiles'):
+ if not hasattr(self, "tiles"):
raise RuntimeError("No texture, discard has already been called")
assert data.shape[:2] == (self.height, self.width)
@@ -199,11 +231,13 @@ class Image(object):
self.tiles[0][0].update(format_, data, texUnit=texUnit)
else:
for texture, _, info in self.tiles:
- yOrig, xOrig = info['yOrigData'], info['xOrigData']
- height, width = info['hData'], info['wData']
- texture.update(format_,
- data[yOrig:yOrig+height, xOrig:xOrig+width],
- texUnit=texUnit)
+ yOrig, xOrig = info["yOrigData"], info["xOrigData"]
+ height, width = info["hData"], info["wData"]
+ texture.update(
+ format_,
+ data[yOrig : yOrig + height, xOrig : xOrig + width],
+ texUnit=texUnit,
+ )
texture.prepare()
# TODO check
# width=info['wData'], height=info['hData'],
@@ -223,18 +257,13 @@ class Image(object):
stride = vertices.shape[-1] * vertices.itemsize
gl.glEnableVertexAttribArray(posAttrib)
- gl.glVertexAttribPointer(posAttrib,
- 2,
- gl.GL_FLOAT,
- gl.GL_FALSE,
- stride, vertices)
-
- texCoordsPtr = c_void_p(vertices.ctypes.data +
- 2 * vertices.itemsize)
+ gl.glVertexAttribPointer(
+ posAttrib, 2, gl.GL_FLOAT, gl.GL_FALSE, stride, vertices
+ )
+
+ texCoordsPtr = c_void_p(vertices.ctypes.data + 2 * vertices.itemsize)
gl.glEnableVertexAttribArray(texAttrib)
- gl.glVertexAttribPointer(texAttrib,
- 2,
- gl.GL_FLOAT,
- gl.GL_FALSE,
- stride, texCoordsPtr)
+ gl.glVertexAttribPointer(
+ texAttrib, 2, gl.GL_FLOAT, gl.GL_FALSE, stride, texCoordsPtr
+ )
gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, len(vertices))
diff --git a/src/silx/gui/plot/backends/glutils/PlotImageFile.py b/src/silx/gui/plot/backends/glutils/PlotImageFile.py
index 75ee50b..1622122 100644
--- a/src/silx/gui/plot/backends/glutils/PlotImageFile.py
+++ b/src/silx/gui/plot/backends/glutils/PlotImageFile.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2014-2020 European Synchrotron Radiation Facility
+# Copyright (c) 2014-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -30,12 +30,14 @@ __date__ = "03/04/2017"
import base64
import struct
-import sys
import zlib
+from fabio.TiffIO import TiffIO
+
# Image writer ################################################################
+
def convertRGBDataToPNG(data):
"""Convert a RGB bitmap to PNG.
@@ -53,29 +55,42 @@ def convertRGBDataToPNG(data):
colorType = 2 # 'truecolor' = RGB
interlace = 0 # No
- IHDRdata = struct.pack(">ccccIIBBBBB", b'I', b'H', b'D', b'R',
- width, height, depth, colorType,
- 0, 0, interlace)
+ IHDRdata = struct.pack(
+ ">ccccIIBBBBB",
+ b"I",
+ b"H",
+ b"D",
+ b"R",
+ width,
+ height,
+ depth,
+ colorType,
+ 0,
+ 0,
+ interlace,
+ )
# Add filter 'None' before each scanline
- preparedData = b'\x00' + b'\x00'.join(line.tobytes() for line in data)
+ preparedData = b"\x00" + b"\x00".join(line.tobytes() for line in data)
compressedData = zlib.compress(preparedData, 8)
- IDATdata = struct.pack("cccc", b'I', b'D', b'A', b'T')
+ IDATdata = struct.pack("cccc", b"I", b"D", b"A", b"T")
IDATdata += compressedData
- return b''.join([
- b'\x89PNG\r\n\x1a\n', # PNG signature
- # IHDR chunk: Image Header
- struct.pack(">I", 13), # length
- IHDRdata,
- struct.pack(">I", zlib.crc32(IHDRdata) & 0xffffffff), # CRC
- # IDAT chunk: Payload
- struct.pack(">I", len(compressedData)),
- IDATdata,
- struct.pack(">I", zlib.crc32(IDATdata) & 0xffffffff), # CRC
- b'\x00\x00\x00\x00IEND\xaeB`\x82' # IEND chunk: footer
- ])
+ return b"".join(
+ [
+ b"\x89PNG\r\n\x1a\n", # PNG signature
+ # IHDR chunk: Image Header
+ struct.pack(">I", 13), # length
+ IHDRdata,
+ struct.pack(">I", zlib.crc32(IHDRdata) & 0xFFFFFFFF), # CRC
+ # IDAT chunk: Payload
+ struct.pack(">I", len(compressedData)),
+ IDATdata,
+ struct.pack(">I", zlib.crc32(IDATdata) & 0xFFFFFFFF), # CRC
+ b"\x00\x00\x00\x00IEND\xaeB`\x82", # IEND chunk: footer
+ ]
+ )
def saveImageToFile(data, fileNameOrObj, fileFormat):
@@ -89,64 +104,56 @@ def saveImageToFile(data, fileNameOrObj, fileFormat):
"""
assert len(data.shape) == 3
assert data.shape[2] == 3
- assert fileFormat in ('png', 'ppm', 'svg', 'tiff')
+ assert fileFormat in ("png", "ppm", "svg", "tif", "tiff")
- if not hasattr(fileNameOrObj, 'write'):
- if sys.version_info < (3, ):
+ if not hasattr(fileNameOrObj, "write"):
+ if fileFormat in ("png", "ppm", "tiff"):
+ # Open in binary mode
fileObj = open(fileNameOrObj, "wb")
else:
- if fileFormat in ('png', 'ppm', 'tiff'):
- # Open in binary mode
- fileObj = open(fileNameOrObj, 'wb')
- else:
- fileObj = open(fileNameOrObj, 'w', newline='')
+ fileObj = open(fileNameOrObj, "w", newline="")
else: # Use as a file-like object
fileObj = fileNameOrObj
- if fileFormat == 'svg':
+ if fileFormat == "svg":
height, width = data.shape[:2]
base64Data = base64.b64encode(convertRGBDataToPNG(data))
- fileObj.write(
- '<?xml version="1.0" encoding="UTF-8" standalone="no"?>\n')
+ fileObj.write('<?xml version="1.0" encoding="UTF-8" standalone="no"?>\n')
fileObj.write('<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"\n')
- fileObj.write(
- ' "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">\n')
+ fileObj.write(' "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">\n')
fileObj.write('<svg xmlns:xlink="http://www.w3.org/1999/xlink"\n')
fileObj.write(' xmlns="http://www.w3.org/2000/svg"\n')
fileObj.write(' version="1.1"\n')
fileObj.write(' width="%d"\n' % width)
fileObj.write(' height="%d">\n' % height)
fileObj.write(' <image xlink:href="data:image/png;base64,')
- fileObj.write(base64Data.decode('ascii'))
+ fileObj.write(base64Data.decode("ascii"))
fileObj.write('"\n')
fileObj.write(' x="0"\n')
fileObj.write(' y="0"\n')
fileObj.write(' width="%d"\n' % width)
fileObj.write(' height="%d"\n' % height)
fileObj.write(' id="image" />\n')
- fileObj.write('</svg>')
+ fileObj.write("</svg>")
- elif fileFormat == 'ppm':
+ elif fileFormat == "ppm":
height, width = data.shape[:2]
- fileObj.write(b'P6\n')
- fileObj.write(b'%d %d\n' % (width, height))
- fileObj.write(b'255\n')
+ fileObj.write(b"P6\n")
+ fileObj.write(b"%d %d\n" % (width, height))
+ fileObj.write(b"255\n")
fileObj.write(data.tobytes())
- elif fileFormat == 'png':
+ elif fileFormat == "png":
fileObj.write(convertRGBDataToPNG(data))
- elif fileFormat == 'tiff':
+ elif fileFormat in ("tif", "tiff"):
if fileObj == fileNameOrObj:
- raise NotImplementedError(
- 'Save TIFF to a file-like object not implemented')
-
- from silx.third_party.TiffIO import TiffIO
+ raise NotImplementedError("Save TIFF to a file-like object not implemented")
- tif = TiffIO(fileNameOrObj, mode='wb+')
- tif.writeImage(data, info={'Title': 'OpenGL Plot Snapshot'})
+ tif = TiffIO(fileNameOrObj, mode="wb+")
+ tif.writeImage(data, info={"Title": "OpenGL Plot Snapshot"})
if fileObj != fileNameOrObj:
fileObj.close()
diff --git a/src/silx/gui/plot/items/__init__.py b/src/silx/gui/plot/items/__init__.py
index 6e26c64..bbb4220 100644
--- a/src/silx/gui/plot/items/__init__.py
+++ b/src/silx/gui/plot/items/__init__.py
@@ -31,22 +31,50 @@ __authors__ = ["T. Vincent"]
__license__ = "MIT"
__date__ = "22/06/2017"
-from .core import (Item, DataItem, # noqa
- LabelsMixIn, DraggableMixIn, ColormapMixIn, # noqa
- SymbolMixIn, ColorMixIn, YAxisMixIn, FillMixIn, # noqa
- AlphaMixIn, LineMixIn, ScatterVisualizationMixIn, # noqa
- ComplexMixIn, ItemChangedType, PointsBase) # noqa
+from .core import (
+ Item,
+ DataItem, # noqa
+ LabelsMixIn,
+ DraggableMixIn,
+ ColormapMixIn,
+ LineGapColorMixIn, # noqa
+ SymbolMixIn,
+ ColorMixIn,
+ YAxisMixIn,
+ FillMixIn, # noqa
+ AlphaMixIn,
+ LineMixIn,
+ ScatterVisualizationMixIn, # noqa
+ ComplexMixIn,
+ ItemChangedType,
+ PointsBase,
+) # noqa
from .complex import ImageComplexData # noqa
from .curve import Curve, CurveStyle # noqa
from .histogram import Histogram # noqa
-from .image import ImageBase, ImageData, ImageDataBase, ImageRgba, ImageStack, MaskImageData # noqa
+from .image import (
+ ImageBase,
+ ImageData,
+ ImageDataBase,
+ ImageRgba,
+ ImageStack,
+ MaskImageData,
+) # noqa
from .image_aggregated import ImageDataAggregated # noqa
from .shape import Line, Shape, BoundingRect, XAxisExtent, YAxisExtent # noqa
from .scatter import Scatter # noqa
from .marker import MarkerBase, Marker, XMarker, YMarker # noqa
from .axis import Axis, XAxis, YAxis, YRightAxis
-DATA_ITEMS = (ImageComplexData, Curve, Histogram, ImageBase, Scatter,
- BoundingRect, XAxisExtent, YAxisExtent)
+DATA_ITEMS = (
+ ImageComplexData,
+ Curve,
+ Histogram,
+ ImageBase,
+ Scatter,
+ BoundingRect,
+ XAxisExtent,
+ YAxisExtent,
+)
"""Classes of items representing data and to consider to compute data bounds.
"""
diff --git a/src/silx/gui/plot/items/_arc_roi.py b/src/silx/gui/plot/items/_arc_roi.py
index 40711b7..658573a 100644
--- a/src/silx/gui/plot/items/_arc_roi.py
+++ b/src/silx/gui/plot/items/_arc_roi.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2018-2022 European Synchrotron Radiation Facility
+# Copyright (c) 2018-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -30,6 +30,8 @@ __date__ = "28/06/2018"
import logging
import numpy
+import enum
+from typing import Tuple
from ... import utils
from .. import items
@@ -50,8 +52,18 @@ class _ArcGeometry:
The aim is is to switch between consistent state without dealing with
intermediate values.
"""
- def __init__(self, center, startPoint, endPoint, radius,
- weight, startAngle, endAngle, closed=False):
+
+ def __init__(
+ self,
+ center,
+ startPoint,
+ endPoint,
+ radius,
+ weight,
+ startAngle,
+ endAngle,
+ closed=False,
+ ):
"""Constructor for a consistent arc geometry.
There is also specific class method to create different kind of arc
@@ -68,46 +80,59 @@ class _ArcGeometry:
@classmethod
def createEmpty(cls):
- """Create an arc geometry from an empty shape
- """
+ """Create an arc geometry from an empty shape"""
zero = numpy.array([0, 0])
return cls(zero, zero.copy(), zero.copy(), 0, 0, 0, 0)
@classmethod
def createRect(cls, startPoint, endPoint, weight):
- """Create an arc geometry from a definition of a rectangle
- """
+ """Create an arc geometry from a definition of a rectangle"""
return cls(None, startPoint, endPoint, None, weight, None, None, False)
@classmethod
- def createCircle(cls, center, startPoint, endPoint, radius,
- weight, startAngle, endAngle):
- """Create an arc geometry from a definition of a circle
- """
- return cls(center, startPoint, endPoint, radius,
- weight, startAngle, endAngle, True)
+ def createCircle(
+ cls, center, startPoint, endPoint, radius, weight, startAngle, endAngle
+ ):
+ """Create an arc geometry from a definition of a circle"""
+ return cls(
+ center, startPoint, endPoint, radius, weight, startAngle, endAngle, True
+ )
def withWeight(self, weight):
- """Return a new geometry based on this object, with a specific weight
- """
- return _ArcGeometry(self.center, self.startPoint, self.endPoint,
- self.radius, weight,
- self.startAngle, self.endAngle, self._closed)
+ """Return a new geometry based on this object, with a specific weight"""
+ return _ArcGeometry(
+ self.center,
+ self.startPoint,
+ self.endPoint,
+ self.radius,
+ weight,
+ self.startAngle,
+ self.endAngle,
+ self._closed,
+ )
def withRadius(self, radius):
"""Return a new geometry based on this object, with a specific radius.
The weight and the center is conserved.
"""
- startPoint = self.center + (self.startPoint - self.center) / self.radius * radius
+ startPoint = (
+ self.center + (self.startPoint - self.center) / self.radius * radius
+ )
endPoint = self.center + (self.endPoint - self.center) / self.radius * radius
- return _ArcGeometry(self.center, startPoint, endPoint,
- radius, self.weight,
- self.startAngle, self.endAngle, self._closed)
+ return _ArcGeometry(
+ self.center,
+ startPoint,
+ endPoint,
+ radius,
+ self.weight,
+ self.startAngle,
+ self.endAngle,
+ self._closed,
+ )
def withStartAngle(self, startAngle):
- """Return a new geometry based on this object, with a specific start angle
- """
+ """Return a new geometry based on this object, with a specific start angle"""
vector = numpy.array([numpy.cos(startAngle), numpy.sin(startAngle)])
startPoint = self.center + vector * self.radius
@@ -131,8 +156,7 @@ class _ArcGeometry:
)
def withEndAngle(self, endAngle):
- """Return a new geometry based on this object, with a specific end angle
- """
+ """Return a new geometry based on this object, with a specific end angle"""
vector = numpy.array([numpy.cos(endAngle), numpy.sin(endAngle)])
endPoint = self.center + vector * self.radius
@@ -161,9 +185,16 @@ class _ArcGeometry:
center = None if self.center is None else self.center + delta
startPoint = None if self.startPoint is None else self.startPoint + delta
endPoint = None if self.endPoint is None else self.endPoint + delta
- return _ArcGeometry(center, startPoint, endPoint,
- self.radius, self.weight,
- self.startAngle, self.endAngle, self._closed)
+ return _ArcGeometry(
+ center,
+ startPoint,
+ endPoint,
+ self.radius,
+ self.weight,
+ self.startAngle,
+ self.endAngle,
+ self._closed,
+ )
def getKind(self):
"""Returns the kind of shape defined"""
@@ -191,14 +222,18 @@ class _ArcGeometry:
return self._closed
def __str__(self):
- return str((self.center,
- self.startPoint,
- self.endPoint,
- self.radius,
- self.weight,
- self.startAngle,
- self.endAngle,
- self._closed))
+ return str(
+ (
+ self.center,
+ self.startPoint,
+ self.endPoint,
+ self.radius,
+ self.weight,
+ self.startAngle,
+ self.endAngle,
+ self._closed,
+ )
+ )
class ArcROI(HandleBasedROI, items.LineMixIn, InteractionModeMixIn):
@@ -210,19 +245,37 @@ class ArcROI(HandleBasedROI, items.LineMixIn, InteractionModeMixIn):
- 1 anchor to translate the shape.
"""
- ICON = 'add-shape-arc'
- NAME = 'arc ROI'
+ ICON = "add-shape-arc"
+ NAME = "arc ROI"
SHORT_NAME = "arc"
"""Metadata for this kind of ROI"""
_plotShape = "line"
"""Plot shape which is used for the first interaction"""
- ThreePointMode = RoiInteractionMode("3 points", "Provides 3 points to define the main radius circle")
- PolarMode = RoiInteractionMode("Polar", "Provides anchors to edit the ROI in polar coords")
+ ThreePointMode = RoiInteractionMode(
+ "3 points", "Provides 3 points to define the main radius circle"
+ )
+ PolarMode = RoiInteractionMode(
+ "Polar", "Provides anchors to edit the ROI in polar coords"
+ )
# FIXME: MoveMode was designed cause there is too much anchors
# FIXME: It would be good replace it by a dnd on the shape
- MoveMode = RoiInteractionMode("Translation", "Provides anchors to only move the ROI")
+ MoveMode = RoiInteractionMode(
+ "Translation", "Provides anchors to only move the ROI"
+ )
+
+ class Role(enum.Enum):
+ """Identify a set of roles which can be used for now to reach some positions"""
+
+ START = 0
+ """Location of the anchor at the start of the arc"""
+ STOP = 1
+ """Location of the anchor at the stop of the arc"""
+ MIDDLE = 2
+ """Location of the anchor at the middle of the arc"""
+ CENTER = 3
+ """Location of the center of the circle"""
def __init__(self, parent=None):
HandleBasedROI.__init__(self, parent=parent)
@@ -265,22 +318,28 @@ class ArcROI(HandleBasedROI, items.LineMixIn, InteractionModeMixIn):
:param RoiInteractionMode modeId:
"""
if modeId is self.ThreePointMode:
+ self._handleStart.setVisible(True)
+ self._handleEnd.setVisible(True)
+ self._handleWeight.setVisible(True)
self._handleStart.setSymbol("s")
self._handleMid.setSymbol("s")
self._handleEnd.setSymbol("s")
self._handleWeight.setSymbol("d")
self._handleMove.setSymbol("+")
elif modeId is self.PolarMode:
+ self._handleStart.setVisible(True)
+ self._handleEnd.setVisible(True)
+ self._handleWeight.setVisible(True)
self._handleStart.setSymbol("o")
self._handleMid.setSymbol("o")
self._handleEnd.setSymbol("o")
self._handleWeight.setSymbol("d")
self._handleMove.setSymbol("+")
elif modeId is self.MoveMode:
- self._handleStart.setSymbol("")
+ self._handleStart.setVisible(False)
+ self._handleEnd.setVisible(False)
+ self._handleWeight.setVisible(False)
self._handleMid.setSymbol("+")
- self._handleEnd.setSymbol("")
- self._handleWeight.setSymbol("")
self._handleMove.setSymbol("+")
else:
assert False
@@ -302,7 +361,7 @@ class ArcROI(HandleBasedROI, items.LineMixIn, InteractionModeMixIn):
self.__shape.setLineWidth(style.getLineWidth())
def setFirstShapePoints(self, points):
- """"Initialize the ROI using the points from the first interaction.
+ """Initialize the ROI using the points from the first interaction.
This interaction is constrained by the plot API and only supports few
shapes.
@@ -367,7 +426,9 @@ class ArcROI(HandleBasedROI, items.LineMixIn, InteractionModeMixIn):
elif geometry.center is not None:
midAngle = (geometry.startAngle + geometry.endAngle) * 0.5
vector = numpy.array([numpy.cos(midAngle), numpy.sin(midAngle)])
- weightPos = geometry.center + (geometry.radius + geometry.weight * 0.5) * vector
+ weightPos = (
+ geometry.center + (geometry.radius + geometry.weight * 0.5) * vector
+ )
with utils.blockSignals(self._handleWeight):
self._handleWeight.setPosition(*weightPos)
@@ -393,7 +454,9 @@ class ArcROI(HandleBasedROI, items.LineMixIn, InteractionModeMixIn):
self._updateWeightHandle()
self._updateShape()
- def _updateCurvature(self, start, mid, end, updateCurveHandles, checkClosed=False, updateStart=False):
+ def _updateCurvature(
+ self, start, mid, end, updateCurveHandles, checkClosed=False, updateStart=False
+ ):
"""Update the curvature using 3 control points in the curve
:param bool updateCurveHandles: If False curve handles are already at
@@ -418,7 +481,9 @@ class ArcROI(HandleBasedROI, items.LineMixIn, InteractionModeMixIn):
self._handleEnd.setPosition(*end)
weight = self._geometry.weight
- geometry = self._createGeometryFromControlPoints(start, mid, end, weight, closed=closed)
+ geometry = self._createGeometryFromControlPoints(
+ start, mid, end, weight, closed=closed
+ )
self._geometry = geometry
self._updateWeightHandle()
@@ -433,10 +498,10 @@ class ArcROI(HandleBasedROI, items.LineMixIn, InteractionModeMixIn):
sign = 1 if geometry.startAngle < geometry.endAngle else -1
if updateStart:
geometry.startPoint = geometry.endPoint
- geometry.startAngle = geometry.endAngle - sign * 2*numpy.pi
+ geometry.startAngle = geometry.endAngle - sign * 2 * numpy.pi
else:
geometry.endPoint = geometry.startPoint
- geometry.endAngle = geometry.startAngle + sign * 2*numpy.pi
+ geometry.endAngle = geometry.startAngle + sign * 2 * numpy.pi
def handleDragUpdated(self, handle, origin, previous, current):
modeId = self.getInteractionMode()
@@ -445,8 +510,12 @@ class ArcROI(HandleBasedROI, items.LineMixIn, InteractionModeMixIn):
mid = numpy.array(self._handleMid.getPosition())
end = numpy.array(self._handleEnd.getPosition())
self._updateCurvature(
- current, mid, end, checkClosed=True, updateStart=True,
- updateCurveHandles=False
+ current,
+ mid,
+ end,
+ checkClosed=True,
+ updateStart=True,
+ updateCurveHandles=False,
)
elif modeId is self.PolarMode:
v = current - self._geometry.center
@@ -477,8 +546,12 @@ class ArcROI(HandleBasedROI, items.LineMixIn, InteractionModeMixIn):
start = numpy.array(self._handleStart.getPosition())
mid = numpy.array(self._handleMid.getPosition())
self._updateCurvature(
- start, mid, current, checkClosed=True, updateStart=False,
- updateCurveHandles=False
+ start,
+ mid,
+ current,
+ checkClosed=True,
+ updateStart=False,
+ updateCurveHandles=False,
)
elif modeId is self.PolarMode:
v = current - self._geometry.center
@@ -511,8 +584,7 @@ class ArcROI(HandleBasedROI, items.LineMixIn, InteractionModeMixIn):
return abs(point1[0] - point2[0]) + abs(point1[1] - point2[1]) < 15
def _normalizeGeometry(self):
- """Keep the same phisical geometry, but with normalized parameters.
- """
+ """Keep the same phisical geometry, but with normalized parameters."""
geometry = self._geometry
if geometry.weight * 0.5 >= geometry.radius:
radius = (geometry.weight * 0.5 + geometry.radius) * 0.5
@@ -582,8 +654,9 @@ class ArcROI(HandleBasedROI, items.LineMixIn, InteractionModeMixIn):
if endAngle > startAngle:
endAngle -= 2 * numpy.pi
- return _ArcGeometry(center, start, end,
- radius, weight, startAngle, endAngle)
+ return _ArcGeometry(
+ center, start, end, radius, weight, startAngle, endAngle
+ )
def _createShapeFromGeometry(self, geometry):
kind = geometry.getKind()
@@ -595,11 +668,14 @@ class ArcROI(HandleBasedROI, items.LineMixIn, InteractionModeMixIn):
distance = numpy.linalg.norm(normal)
if distance != 0:
normal /= distance
- points = numpy.array([
- geometry.startPoint + normal * geometry.weight * 0.5,
- geometry.endPoint + normal * geometry.weight * 0.5,
- geometry.endPoint - normal * geometry.weight * 0.5,
- geometry.startPoint - normal * geometry.weight * 0.5])
+ points = numpy.array(
+ [
+ geometry.startPoint + normal * geometry.weight * 0.5,
+ geometry.endPoint + normal * geometry.weight * 0.5,
+ geometry.endPoint - normal * geometry.weight * 0.5,
+ geometry.startPoint - normal * geometry.weight * 0.5,
+ ]
+ )
elif kind == "point":
# It is not an arc
# but we can display it as an intermediate shape
@@ -712,7 +788,29 @@ class ArcROI(HandleBasedROI, items.LineMixIn, InteractionModeMixIn):
geometry = self._geometry
if geometry.center is None:
raise ValueError("This ROI can't be represented as a section of circle")
- return geometry.center, self.getInnerRadius(), self.getOuterRadius(), geometry.startAngle, geometry.endAngle
+ return (
+ geometry.center,
+ self.getInnerRadius(),
+ self.getOuterRadius(),
+ geometry.startAngle,
+ geometry.endAngle,
+ )
+
+ def getPosition(self, role: Role = Role.CENTER) -> Tuple[float, float]:
+ """Returns a position by it's role.
+
+ By default returns the center of the circle of the arc ROI.
+ """
+ if role == self.Role.START:
+ return self._handleStart.getPosition()
+ if role == self.Role.STOP:
+ return self._handleEnd.getPosition()
+ if role == self.Role.MIDDLE:
+ return self._handleMid.getPosition()
+ if role == self.Role.CENTER:
+ p = self.getCenter()
+ return p[0], p[1]
+ raise ValueError(f"{role} is not supported")
def isClosed(self):
"""Returns true if the arc is a closed shape, like a circle or a donut.
@@ -795,9 +893,16 @@ class ArcROI(HandleBasedROI, items.LineMixIn, InteractionModeMixIn):
vector = numpy.array([numpy.cos(endAngle), numpy.sin(endAngle)])
endPoint = center + vector * radius
- geometry = _ArcGeometry(center, startPoint, endPoint,
- radius, weight,
- startAngle, endAngle, closed=None)
+ geometry = _ArcGeometry(
+ center,
+ startPoint,
+ endPoint,
+ radius,
+ weight,
+ startAngle,
+ endAngle,
+ closed=None,
+ )
self._geometry = geometry
self._updateHandles()
@@ -805,7 +910,9 @@ class ArcROI(HandleBasedROI, items.LineMixIn, InteractionModeMixIn):
def contains(self, position):
# first check distance, fastest
center = self.getCenter()
- distance = numpy.sqrt((position[1] - center[1]) ** 2 + ((position[0] - center[0])) ** 2)
+ distance = numpy.sqrt(
+ (position[1] - center[1]) ** 2 + ((position[0] - center[0])) ** 2
+ )
is_in_distance = self.getInnerRadius() <= distance <= self.getOuterRadius()
if not is_in_distance:
return False
@@ -871,8 +978,15 @@ class ArcROI(HandleBasedROI, items.LineMixIn, InteractionModeMixIn):
def __str__(self):
try:
center, innerRadius, outerRadius, startAngle, endAngle = self.getGeometry()
- params = center[0], center[1], innerRadius, outerRadius, startAngle, endAngle
- params = 'center: %f %f; radius: %f %f; angles: %f %f' % params
+ params = (
+ center[0],
+ center[1],
+ innerRadius,
+ outerRadius,
+ startAngle,
+ endAngle,
+ )
+ params = "center: %f %f; radius: %f %f; angles: %f %f" % params
except ValueError:
params = "invalid"
return "%s(%s)" % (self.__class__.__name__, params)
diff --git a/src/silx/gui/plot/items/_band_roi.py b/src/silx/gui/plot/items/_band_roi.py
index a60a177..0d2ad4e 100644
--- a/src/silx/gui/plot/items/_band_roi.py
+++ b/src/silx/gui/plot/items/_band_roi.py
@@ -100,7 +100,7 @@ class BandGeometry(NamedTuple):
def slope(self) -> float:
"""Slope of the line (begin, end), infinity for a vertical line"""
if self.begin.x == self.end.x:
- return float('inf')
+ return float("inf")
return (self.end.y - self.begin.y) / (self.end.x - self.begin.x)
@property
@@ -309,18 +309,20 @@ class BandROI(HandleBasedROI, items.LineMixIn, InteractionModeMixIn):
)
@staticmethod
- def __snap(point: Tuple[float, float], fixed: Tuple[float, float]) -> Tuple[float, float]:
+ def __snap(
+ point: Tuple[float, float], fixed: Tuple[float, float]
+ ) -> Tuple[float, float]:
"""Snap point so that vector [point, fixed] snap to direction 0, 45 or 90 degrees
:return: the snapped point position.
"""
vector = point[0] - fixed[0], point[1] - fixed[1]
angle = numpy.arctan2(vector[1], vector[0])
- snapAngle = numpy.pi/4 * numpy.round(angle / (numpy.pi/4))
+ snapAngle = numpy.pi / 4 * numpy.round(angle / (numpy.pi / 4))
length = numpy.linalg.norm(vector)
return (
fixed[0] + length * numpy.cos(snapAngle),
- fixed[1] + length * numpy.sin(snapAngle)
+ fixed[1] + length * numpy.sin(snapAngle),
)
def handleDragUpdated(self, handle, origin, previous, current):
@@ -353,12 +355,16 @@ class BandROI(HandleBasedROI, items.LineMixIn, InteractionModeMixIn):
def __handleWidthUpConstraint(self, x: float, y: float) -> Tuple[float, float]:
geometry = self.getGeometry()
- offset = max(0, numpy.dot(geometry.normal, numpy.array((x, y)) - geometry.center))
+ offset = max(
+ 0, numpy.dot(geometry.normal, numpy.array((x, y)) - geometry.center)
+ )
return tuple(geometry.center + offset * numpy.array(geometry.normal))
def __handleWidthDownConstraint(self, x: float, y: float) -> Tuple[float, float]:
geometry = self.getGeometry()
- offset = max(0, -numpy.dot(geometry.normal, numpy.array((x, y)) - geometry.center))
+ offset = max(
+ 0, -numpy.dot(geometry.normal, numpy.array((x, y)) - geometry.center)
+ )
return tuple(geometry.center - offset * numpy.array(geometry.normal))
@docstring(_RegionOfInterestBase)
diff --git a/src/silx/gui/plot/items/_roi_base.py b/src/silx/gui/plot/items/_roi_base.py
index 765a538..43c5381 100644
--- a/src/silx/gui/plot/items/_roi_base.py
+++ b/src/silx/gui/plot/items/_roi_base.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2018-2020 European Synchrotron Radiation Facility
+# Copyright (c) 2018-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -37,14 +37,14 @@ __date__ = "28/06/2018"
import logging
import numpy
import weakref
+import functools
+from typing import Optional
from ....utils.weakref import WeakList
from ... import qt
from .. import items
from ..items import core
from ...colors import rgba
-import silx.utils.deprecation
-from ....utils.proxy import docstring
logger = logging.getLogger(__name__)
@@ -68,8 +68,10 @@ class _RegionOfInterestBase(qt.QObject):
"""
def __init__(self, parent=None):
- qt.QObject.__init__(self, parent=parent)
- self.__name = ''
+ qt.QObject.__init__(self)
+ if parent is not None:
+ self.setParent(parent)
+ self.__name = ""
def getName(self):
"""Returns the name of the ROI
@@ -120,10 +122,12 @@ class RoiInteractionMode(object):
@property
def label(self):
+ """Short name"""
return self._label
@property
def description(self):
+ """Longer description of the interaction mode"""
return self._description
@@ -188,6 +192,28 @@ class InteractionModeMixIn(object):
"""
return self.__modeId
+ def createMenuForInteractionMode(self, parent: qt.QWidget) -> qt.QMenu:
+ """Create a menu providing access to the different interaction modes"""
+ availableModes = self.availableInteractionModes()
+ currentMode = self.getInteractionMode()
+ submenu = qt.QMenu(parent)
+ modeGroup = qt.QActionGroup(parent)
+ modeGroup.setExclusive(True)
+ for mode in availableModes:
+ action = qt.QAction(parent)
+ action.setText(mode.label)
+ action.setToolTip(mode.description)
+ action.setCheckable(True)
+ if mode is currentMode:
+ action.setChecked(True)
+ else:
+ callback = functools.partial(self.setInteractionMode, mode)
+ action.triggered.connect(callback)
+ modeGroup.addAction(action)
+ submenu.addAction(action)
+ submenu.setTitle("Interaction mode")
+ return submenu
+
class RegionOfInterest(_RegionOfInterestBase, core.HighlightedMixIn):
"""Object describing a region of interest in a plot.
@@ -196,10 +222,10 @@ class RegionOfInterest(_RegionOfInterestBase, core.HighlightedMixIn):
The RegionOfInterestManager that created this object
"""
- _DEFAULT_LINEWIDTH = 1.
+ _DEFAULT_LINEWIDTH = 1.0
"""Default line width of the curve"""
- _DEFAULT_LINESTYLE = '-'
+ _DEFAULT_LINESTYLE = "-"
"""Default line style of the curve"""
_DEFAULT_HIGHLIGHT_STYLE = items.CurveStyle(linewidth=2)
@@ -225,15 +251,18 @@ class RegionOfInterest(_RegionOfInterestBase, core.HighlightedMixIn):
def __init__(self, parent=None):
# Avoid circular dependency
from ..tools import roi as roi_tools
+
assert parent is None or isinstance(parent, roi_tools.RegionOfInterestManager)
+ # Must be done before _RegionOfInterestBase.__init__
+ self._child = WeakList()
_RegionOfInterestBase.__init__(self, parent)
core.HighlightedMixIn.__init__(self)
- self._color = rgba('red')
+ self.__text = None
+ self._color = rgba("red")
self._editable = False
self._selectable = False
self._focusProxy = None
self._visible = True
- self._child = WeakList()
def _connectToPlot(self, plot):
"""Called after connection to a plot"""
@@ -263,8 +292,11 @@ class RegionOfInterest(_RegionOfInterestBase, core.HighlightedMixIn):
"""
# Avoid circular dependency
from ..tools import roi as roi_tools
- if (parent is not None and not isinstance(parent, roi_tools.RegionOfInterestManager)):
- raise ValueError('Unsupported parent')
+
+ if parent is not None and not isinstance(
+ parent, roi_tools.RegionOfInterestManager
+ ):
+ raise ValueError("Unsupported parent")
previousParent = self.parent()
if previousParent is not None:
@@ -292,7 +324,7 @@ class RegionOfInterest(_RegionOfInterestBase, core.HighlightedMixIn):
"""
assert item is not None
self._child.append(item)
- if item.getName() == '':
+ if item.getName() == "":
self._setItemName(item)
manager = self.parent()
if manager is not None:
@@ -352,26 +384,6 @@ class RegionOfInterest(_RegionOfInterestBase, core.HighlightedMixIn):
self._color = color
self._updated(items.ItemChangedType.COLOR)
- @silx.utils.deprecation.deprecated(reason='API modification',
- replacement='getName()',
- since_version=0.12)
- def getLabel(self):
- """Returns the label displayed for this ROI.
-
- :rtype: str
- """
- return self.getName()
-
- @silx.utils.deprecation.deprecated(reason='API modification',
- replacement='setName(name)',
- since_version=0.12)
- def setLabel(self, label):
- """Set the label displayed with this ROI.
-
- :param str label: The text label to display
- """
- self.setName(name=label)
-
def isEditable(self):
"""Returns whether the ROI is editable by the user or not.
@@ -457,6 +469,26 @@ class RegionOfInterest(_RegionOfInterestBase, core.HighlightedMixIn):
self._visible = visible
self._updated(items.ItemChangedType.VISIBLE)
+ def getText(self) -> str:
+ """Returns the currently displayed text for this ROI"""
+ return self.getName() if self.__text is None else self.__text
+
+ def setText(self, text: Optional[str] = None) -> None:
+ """Set the displayed text for this ROI.
+
+ If None (the default), the ROI name is used.
+ """
+ if self.__text != text:
+ self.__text = text
+ self._updated(items.ItemChangedType.TEXT)
+
+ def _updateText(self, text: str) -> None:
+ """Update the text displayed by this ROI
+
+ Override in subclass to custom text display
+ """
+ pass
+
@classmethod
def showFirstInteractionShape(cls):
"""Returns True if the shape created by the first interaction and
@@ -478,7 +510,7 @@ class RegionOfInterest(_RegionOfInterestBase, core.HighlightedMixIn):
return cls._plotShape
def setFirstShapePoints(self, points):
- """"Initialize the ROI using the points from the first interaction.
+ """Initialize the ROI using the points from the first interaction.
This interaction is constrained by the plot API and only supports few
shapes.
@@ -486,13 +518,11 @@ class RegionOfInterest(_RegionOfInterestBase, core.HighlightedMixIn):
raise NotImplementedError()
def creationStarted(self):
- """"Called when the ROI creation interaction was started.
- """
+ """Called when the ROI creation interaction was started."""
pass
def creationFinalized(self):
- """"Called when the ROI creation interaction was finalized.
- """
+ """Called when the ROI creation interaction was finalized."""
pass
def _updateItemProperty(self, event, source, destination):
@@ -544,15 +574,23 @@ class RegionOfInterest(_RegionOfInterestBase, core.HighlightedMixIn):
assert False
def _updated(self, event=None, checkVisibility=True):
- if event == items.ItemChangedType.HIGHLIGHTED:
+ if event == items.ItemChangedType.TEXT:
+ self._updateText(self.getText())
+ elif event == items.ItemChangedType.HIGHLIGHTED:
+ for item in self.getItems():
+ zoffset = 1000 if self.isHighlighted() else 0
+ item.setZValue(item._DEFAULT_Z_LAYER + zoffset)
+
style = self.getCurrentStyle()
self._updatedStyle(event, style)
else:
- styleEvents = [items.ItemChangedType.COLOR,
- items.ItemChangedType.LINE_STYLE,
- items.ItemChangedType.LINE_WIDTH,
- items.ItemChangedType.SYMBOL,
- items.ItemChangedType.SYMBOL_SIZE]
+ styleEvents = [
+ items.ItemChangedType.COLOR,
+ items.ItemChangedType.LINE_STYLE,
+ items.ItemChangedType.LINE_WIDTH,
+ items.ItemChangedType.SYMBOL,
+ items.ItemChangedType.SYMBOL_SIZE,
+ ]
if self.isHighlighted():
styleEvents.append(items.ItemChangedType.HIGHLIGHTED_STYLE)
@@ -562,7 +600,11 @@ class RegionOfInterest(_RegionOfInterestBase, core.HighlightedMixIn):
super(RegionOfInterest, self)._updated(event, checkVisibility)
- def _updatedStyle(self, event, style):
+ # Displayed text has changed, send a text event
+ if event == items.ItemChangedType.NAME and self.__text is None:
+ self._updated(items.ItemChangedType.TEXT, checkVisibility)
+
+ def _updatedStyle(self, event, style: items.CurveStyle):
"""Called when the current displayed style of the ROI was changed.
:param event: The event responsible of the change of the style
@@ -570,7 +612,7 @@ class RegionOfInterest(_RegionOfInterestBase, core.HighlightedMixIn):
"""
pass
- def getCurrentStyle(self):
+ def getCurrentStyle(self) -> items.CurveStyle:
"""Returns the current curve style.
Curve style depends on curve highlighting
@@ -588,7 +630,7 @@ class RegionOfInterest(_RegionOfInterestBase, core.HighlightedMixIn):
baseSymbol = self.getSymbol()
baseSymbolsize = self.getSymbolSize()
else:
- baseSymbol = 'o'
+ baseSymbol = "o"
baseSymbolsize = 1
if self.isHighlighted():
@@ -604,13 +646,16 @@ class RegionOfInterest(_RegionOfInterestBase, core.HighlightedMixIn):
linestyle=baseLinestyle if linestyle is None else linestyle,
linewidth=baseLinewidth if linewidth is None else linewidth,
symbol=baseSymbol if symbol is None else symbol,
- symbolsize=baseSymbolsize if symbolsize is None else symbolsize)
+ symbolsize=baseSymbolsize if symbolsize is None else symbolsize,
+ )
else:
- return items.CurveStyle(color=baseColor,
- linestyle=baseLinestyle,
- linewidth=baseLinewidth,
- symbol=baseSymbol,
- symbolsize=baseSymbolsize)
+ return items.CurveStyle(
+ color=baseColor,
+ linestyle=baseLinestyle,
+ linewidth=baseLinewidth,
+ symbol=baseSymbol,
+ symbolsize=baseSymbolsize,
+ )
def _editingStarted(self):
assert self._editable is True
@@ -619,6 +664,10 @@ class RegionOfInterest(_RegionOfInterestBase, core.HighlightedMixIn):
def _editingFinished(self):
self.sigEditingFinished.emit()
+ def populateContextMenu(self, menu: qt.QMenu):
+ """Populate a menu used as a context menu"""
+ pass
+
class HandleBasedROI(RegionOfInterest):
"""Manage a ROI based on a set of handles"""
@@ -730,9 +779,7 @@ class HandleBasedROI(RegionOfInterest):
See :class:`~silx.gui.plot.items.Item._updated`
"""
- if event == items.ItemChangedType.NAME:
- self._updateText(self.getName())
- elif event == items.ItemChangedType.VISIBLE:
+ if event == items.ItemChangedType.VISIBLE:
for item, role in self._handles:
visible = self.isVisible()
editionVisible = visible and self.isEditable()
@@ -754,9 +801,9 @@ class HandleBasedROI(RegionOfInterest):
color = rgba(self.getColor())
handleColor = self._computeHandleColor(color)
for item, role in self._handles:
- if role == 'user':
+ if role == "user":
pass
- elif role == 'label':
+ elif role == "label":
item.setColor(color)
else:
item.setColor(handleColor)
@@ -825,10 +872,3 @@ class HandleBasedROI(RegionOfInterest):
:rtype: Union[numpy.array,Tuple,List]
"""
return color[:3] + (0.5,)
-
- def _updateText(self, text):
- """Update the text displayed by this ROI
-
- :param str text: A text
- """
- pass
diff --git a/src/silx/gui/plot/items/axis.py b/src/silx/gui/plot/items/axis.py
index fa3f6d7..1ae1ef1 100644
--- a/src/silx/gui/plot/items/axis.py
+++ b/src/silx/gui/plot/items/axis.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2017-2022 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -24,28 +24,28 @@
"""This module provides the class for axes of the :class:`PlotWidget`.
"""
+from __future__ import annotations
+
__authors__ = ["V. Valls"]
__license__ = "MIT"
__date__ = "22/11/2018"
import datetime as dt
import enum
-import logging
+from typing import Optional
import dateutil.tz
-import numpy
+from ....utils.proxy import docstring
from ... import qt
from .. import _utils
-_logger = logging.getLogger(__name__)
-
-
class TickMode(enum.Enum):
"""Determines if ticks are regular number or datetimes."""
- DEFAULT = 0 # Ticks are regular numbers
- TIME_SERIES = 1 # Ticks are datetime objects
+
+ DEFAULT = 0 # Ticks are regular numbers
+ TIME_SERIES = 1 # Ticks are datetime objects
class Axis(qt.QObject):
@@ -53,6 +53,7 @@ class Axis(qt.QObject):
Note: This is an abstract class.
"""
+
# States are half-stored on the backend of the plot, and half-stored on this
# object.
# TODO It would be good to store all the states of an axis in this object.
@@ -91,10 +92,10 @@ class Axis(qt.QObject):
self._scale = self.LINEAR
self._isAutoScale = True
# Store default labels provided to setGraph[X|Y]Label
- self._defaultLabel = ''
+ self._defaultLabel = ""
# Store currently displayed labels
# Current label can differ from input one with active curve handling
- self._currentLabel = ''
+ self._currentLabel = ""
def _getPlot(self):
"""Returns the PlotWidget this Axis belongs to.
@@ -150,7 +151,12 @@ class Axis(qt.QObject):
:rtype: 2-tuple of float
"""
return _utils.checkAxisLimits(
- vmin, vmax, isLog=self._isLogarithmic(), name=self._defaultLabel)
+ vmin, vmax, isLog=self._isLogarithmic(), name=self._defaultLabel
+ )
+
+ def _getDataRange(self) -> Optional[tuple[float, float]]:
+ """Returns the range of data items over this axis as (vmin, vmax)"""
+ raise NotImplementedError()
def isInverted(self):
"""Return True if the axis is inverted (top to bottom for the y-axis),
@@ -172,6 +178,10 @@ class Axis(qt.QObject):
return
raise NotImplementedError()
+ def isVisible(self) -> bool:
+ """Returns whether the axis is displayed or not"""
+ return True
+
def getLabel(self):
"""Return the current displayed label of this axis.
@@ -199,10 +209,10 @@ class Axis(qt.QObject):
:param str label: Currently displayed label
"""
- if label is None or label == '':
+ if label is None or label == "":
label = self._defaultLabel
if label is None:
- label = ''
+ label = ""
self._currentLabel = label
self._internalSetCurrentLabel(label)
@@ -218,7 +228,7 @@ class Axis(qt.QObject):
:param str scale: Name of the scale ("log", or "linear")
"""
- assert(scale in self._SCALES)
+ assert scale in self._SCALES
if self._scale == scale:
return
@@ -227,6 +237,8 @@ class Axis(qt.QObject):
self._scale = scale
+ vmin, vmax = self.getLimits()
+
# TODO hackish way of forcing update of curves and images
plot = self._getPlot()
for item in plot.getItems():
@@ -235,13 +247,20 @@ class Axis(qt.QObject):
if scale == self.LOGARITHMIC:
self._internalSetLogarithmic(True)
+ if vmin <= 0:
+ dataRange = self._getDataRange()
+ if dataRange is None:
+ self.setLimits(1.0, 100.0)
+ else:
+ if vmax > 0 and dataRange[0] < vmax:
+ self.setLimits(dataRange[0], vmax)
+ else:
+ self.setLimits(*dataRange)
elif scale == self.LINEAR:
self._internalSetLogarithmic(False)
else:
raise ValueError("Scale %s unsupported" % scale)
- plot._forceResetZoom()
-
self.sigScaleChanged.emit(self._scale)
if emitLog:
self._sigLogarithmicChanged.emit(self._scale == self.LOGARITHMIC)
@@ -328,7 +347,7 @@ class Axis(qt.QObject):
plot = self._getPlot()
xMin, xMax = plot.getXAxis().getLimits()
yMin, yMax = plot.getYAxis().getLimits()
- y2Min, y2Max = plot.getYAxis('right').getLimits()
+ y2Min, y2Max = plot.getYAxis("right").getLimits()
plot.setLimits(xMin, xMax, yMin, yMax, y2Min, y2Max)
return updated
@@ -351,7 +370,7 @@ class Axis(qt.QObject):
plot = self._getPlot()
xMin, xMax = plot.getXAxis().getLimits()
yMin, yMax = plot.getYAxis().getLimits()
- y2Min, y2Max = plot.getYAxis('right').getLimits()
+ y2Min, y2Max = plot.getYAxis("right").getLimits()
plot.setLimits(xMin, xMax, yMin, yMax, y2Min, y2Max)
return updated
@@ -368,7 +387,7 @@ class XAxis(Axis):
def setTimeZone(self, tz):
if isinstance(tz, str) and tz.upper() == "UTC":
tz = dateutil.tz.tzutc()
- elif not(tz is None or isinstance(tz, dt.tzinfo)):
+ elif not (tz is None or isinstance(tz, dt.tzinfo)):
raise TypeError("tz must be a dt.tzinfo object, None or 'UTC'.")
self._getBackend().setXAxisTimeZone(tz)
@@ -410,6 +429,11 @@ class XAxis(Axis):
updated = constrains.update(minXRange=minRange, maxXRange=maxRange)
return updated
+ @docstring(Axis)
+ def _getDataRange(self) -> Optional[tuple[float, float]]:
+ ranges = self._getPlot().getDataRange()
+ return ranges.x
+
class YAxis(Axis):
"""Axis class defining primitives for the Y axis"""
@@ -418,13 +442,13 @@ class YAxis(Axis):
# specialised implementations (prefixel by '_internal')
def _internalSetCurrentLabel(self, label):
- self._getBackend().setGraphYLabel(label, axis='left')
+ self._getBackend().setGraphYLabel(label, axis="left")
def _internalGetLimits(self):
- return self._getBackend().getGraphYLimits(axis='left')
+ return self._getBackend().getGraphYLimits(axis="left")
def _internalSetLimits(self, ymin, ymax):
- self._getBackend().setGraphYLimits(ymin, ymax, axis='left')
+ self._getBackend().setGraphYLimits(ymin, ymax, axis="left")
def _internalSetLogarithmic(self, flag):
self._getBackend().setYAxisLogarithmic(flag)
@@ -462,6 +486,11 @@ class YAxis(Axis):
updated = constrains.update(minYRange=minRange, maxYRange=maxRange)
return updated
+ @docstring(Axis)
+ def _getDataRange(self) -> Optional[tuple[float, float]]:
+ ranges = self._getPlot().getDataRange()
+ return ranges.y
+
class YRightAxis(Axis):
"""Proxy axis for the secondary Y axes. It manages it own label and limit
@@ -485,13 +514,13 @@ class YRightAxis(Axis):
self.__mainAxis.sigAutoScaleChanged.connect(self.sigAutoScaleChanged.emit)
def _internalSetCurrentLabel(self, label):
- self._getBackend().setGraphYLabel(label, axis='right')
+ self._getBackend().setGraphYLabel(label, axis="right")
def _internalGetLimits(self):
- return self._getBackend().getGraphYLimits(axis='right')
+ return self._getBackend().getGraphYLimits(axis="right")
def _internalSetLimits(self, ymin, ymax):
- self._getBackend().setGraphYLimits(ymin, ymax, axis='right')
+ self._getBackend().setGraphYLimits(ymin, ymax, axis="right")
def setInverted(self, flag=True):
"""Set the Y axis orientation.
@@ -505,6 +534,10 @@ class YRightAxis(Axis):
"""Return True if Y axis goes from top to bottom, False otherwise."""
return self.__mainAxis.isInverted()
+ def isVisible(self) -> bool:
+ """Returns whether the axis is displayed or not"""
+ return self._getBackend().isYRightAxisVisible()
+
def getScale(self):
"""Return the name of the scale used by this axis.
@@ -541,3 +574,8 @@ class YRightAxis(Axis):
False to disable it.
"""
return self.__mainAxis.setAutoScale(flag)
+
+ @docstring(Axis)
+ def _getDataRange(self) -> Optional[tuple[float, float]]:
+ ranges = self._getPlot().getDataRange()
+ return ranges.y2
diff --git a/src/silx/gui/plot/items/complex.py b/src/silx/gui/plot/items/complex.py
index 82d821f..d10767f 100644
--- a/src/silx/gui/plot/items/complex.py
+++ b/src/silx/gui/plot/items/complex.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -34,7 +34,6 @@ import logging
import numpy
from ....utils.proxy import docstring
-from ....utils.deprecation import deprecated
from ...colors import Colormap
from .core import ColormapMixIn, ComplexMixIn, ItemChangedType
from .image import ImageBase
@@ -45,6 +44,7 @@ _logger = logging.getLogger(__name__)
# Complex colormap functions
+
def _phase2rgb(colormap, data):
"""Creates RGBA image with colour-coded phase.
@@ -60,7 +60,7 @@ def _phase2rgb(colormap, data):
return colormap.applyToData(phase)
-def _complex2rgbalog(phaseColormap, data, amin=0., dlogs=2, smax=None):
+def _complex2rgbalog(phaseColormap, data, amin=0.0, dlogs=2, smax=None):
"""Returns RGBA colors: colour-coded phases and log10(amplitude) in alpha.
:param Colormap phaseColormap: Colormap to use for the phase
@@ -117,7 +117,8 @@ class ImageComplexData(ImageBase, ColormapMixIn, ComplexMixIn):
ComplexMixIn.ComplexMode.IMAGINARY,
ComplexMixIn.ComplexMode.AMPLITUDE_PHASE,
ComplexMixIn.ComplexMode.LOG10_AMPLITUDE_PHASE,
- ComplexMixIn.ComplexMode.SQUARE_AMPLITUDE)
+ ComplexMixIn.ComplexMode.SQUARE_AMPLITUDE,
+ )
"""Overrides supported ComplexMode"""
def __init__(self):
@@ -130,10 +131,7 @@ class ImageComplexData(ImageBase, ColormapMixIn, ComplexMixIn):
# Use default from ColormapMixIn
colormap = super(ImageComplexData, self).getColormap()
- phaseColormap = Colormap(
- name='hsv',
- vmin=-numpy.pi,
- vmax=numpy.pi)
+ phaseColormap = Colormap(name="hsv", vmin=-numpy.pi, vmax=numpy.pi)
self._colormaps = { # Default colormaps for all modes
self.ComplexMode.ABSOLUTE: colormap,
@@ -154,8 +152,10 @@ class ImageComplexData(ImageBase, ColormapMixIn, ComplexMixIn):
return None
mode = self.getComplexMode()
- if mode in (self.ComplexMode.AMPLITUDE_PHASE,
- self.ComplexMode.LOG10_AMPLITUDE_PHASE):
+ if mode in (
+ self.ComplexMode.AMPLITUDE_PHASE,
+ self.ComplexMode.LOG10_AMPLITUDE_PHASE,
+ ):
# For those modes, compute RGBA image here
colormap = None
data = self.getRgbaImageData(copy=False)
@@ -171,11 +171,13 @@ class ImageComplexData(ImageBase, ColormapMixIn, ComplexMixIn):
if data.size == 0:
return None # No data to display
- return backend.addImage(data,
- origin=self.getOrigin(),
- scale=self.getScale(),
- colormap=colormap,
- alpha=self.getAlpha())
+ return backend.addImage(
+ data,
+ origin=self.getOrigin(),
+ scale=self.getScale(),
+ colormap=colormap,
+ alpha=self.getAlpha(),
+ )
@docstring(ComplexMixIn)
def setComplexMode(self, mode):
@@ -247,7 +249,7 @@ class ImageComplexData(ImageBase, ColormapMixIn, ComplexMixIn):
return self._colormaps[mode]
def setData(self, data, copy=True):
- """"Set the image complex data
+ """Set the image complex data
:param numpy.ndarray data: 2D array of complex with 2 dimensions (h, w)
:param bool copy: True (Default) to get a copy,
@@ -257,7 +259,8 @@ class ImageComplexData(ImageBase, ColormapMixIn, ComplexMixIn):
assert data.ndim == 2
if not numpy.issubdtype(data.dtype, numpy.complexfloating):
_logger.warning(
- 'Image is not complex, converting it to complex to plot it.')
+ "Image is not complex, converting it to complex to plot it."
+ )
data = numpy.array(data, dtype=numpy.complex64)
# Compute current mode data and set colormap data
@@ -274,8 +277,9 @@ class ImageComplexData(ImageBase, ColormapMixIn, ComplexMixIn):
if event in (ItemChangedType.DATA, ItemChangedType.MASK):
# Color-mapped data is NOT the `getValueData` for some modes
if self.getComplexMode() in (
- self.ComplexMode.AMPLITUDE_PHASE,
- self.ComplexMode.LOG10_AMPLITUDE_PHASE):
+ self.ComplexMode.AMPLITUDE_PHASE,
+ self.ComplexMode.LOG10_AMPLITUDE_PHASE,
+ ):
data = self.getData(copy=False, mode=self.ComplexMode.PHASE)
mask = self.getMaskData(copy=False)
if mask is not None:
@@ -308,16 +312,18 @@ class ImageComplexData(ImageBase, ColormapMixIn, ComplexMixIn):
return numpy.real(data)
elif mode is self.ComplexMode.IMAGINARY:
return numpy.imag(data)
- elif mode in (self.ComplexMode.ABSOLUTE,
- self.ComplexMode.LOG10_AMPLITUDE_PHASE,
- self.ComplexMode.AMPLITUDE_PHASE):
+ elif mode in (
+ self.ComplexMode.ABSOLUTE,
+ self.ComplexMode.LOG10_AMPLITUDE_PHASE,
+ self.ComplexMode.AMPLITUDE_PHASE,
+ ):
return numpy.absolute(data)
elif mode is self.ComplexMode.SQUARE_AMPLITUDE:
return numpy.absolute(data) ** 2
else:
_logger.error(
- 'Unsupported conversion mode: %s, fallback to absolute',
- str(mode))
+ "Unsupported conversion mode: %s, fallback to absolute", str(mode)
+ )
return numpy.absolute(data)
def getData(self, copy=True, mode=None):
@@ -340,7 +346,8 @@ class ImageComplexData(ImageBase, ColormapMixIn, ComplexMixIn):
if mode not in self._dataByModesCache:
self._dataByModesCache[mode] = self.__convertComplexData(
- self.getComplexData(copy=False), mode)
+ self.getComplexData(copy=False), mode
+ )
return numpy.array(self._dataByModesCache[mode], copy=copy)
@@ -373,11 +380,3 @@ class ImageComplexData(ImageBase, ColormapMixIn, ComplexMixIn):
# Backward compatibility
Mode = ComplexMixIn.ComplexMode
-
- @deprecated(replacement='setComplexMode', since_version='0.11.0')
- def setVisualizationMode(self, mode):
- return self.setComplexMode(mode)
-
- @deprecated(replacement='getComplexMode', since_version='0.11.0')
- def getVisualizationMode(self):
- return self.getComplexMode()
diff --git a/src/silx/gui/plot/items/core.py b/src/silx/gui/plot/items/core.py
index 074c168..7d754a7 100644
--- a/src/silx/gui/plot/items/core.py
+++ b/src/silx/gui/plot/items/core.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2017-2022 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -23,16 +23,14 @@
# ###########################################################################*/
"""This module provides the base class for items of the :class:`Plot`.
"""
+from __future__ import annotations
+
__authors__ = ["T. Vincent"]
__license__ = "MIT"
__date__ = "08/12/2020"
-import collections
-try:
- from collections import abc
-except ImportError: # Python2 support
- import collections as abc
+from collections import abc
from copy import deepcopy
import logging
import enum
@@ -41,13 +39,12 @@ import weakref
import numpy
-from ....utils.deprecation import deprecated
from ....utils.proxy import docstring
from ....utils.enum import Enum as _Enum
from ....math.combo import min_max
from ... import qt
from ... import colors
-from ...colors import Colormap
+from ...colors import Colormap, _Colormappable
from ._pick import PickingResult
from silx import config
@@ -58,98 +55,109 @@ _logger = logging.getLogger(__name__)
@enum.unique
class ItemChangedType(enum.Enum):
"""Type of modification provided by :attr:`Item.sigItemChanged` signal."""
+
# Private setters and setInfo are not emitting sigItemChanged signal.
# Signals to consider:
# COLORMAP_SET emitted when setColormap is called but not forward colormap object signal
# CURRENT_COLOR_CHANGED emitted current color changed because highlight changed,
# highlighted color changed or color changed depending on hightlight state.
- VISIBLE = 'visibleChanged'
+ VISIBLE = "visibleChanged"
"""Item's visibility changed flag."""
- ZVALUE = 'zValueChanged'
+ ZVALUE = "zValueChanged"
"""Item's Z value changed flag."""
- COLORMAP = 'colormapChanged' # Emitted when set + forward events from the colormap object
+ COLORMAP = (
+ "colormapChanged" # Emitted when set + forward events from the colormap object
+ )
"""Item's colormap changed flag.
This is emitted both when setting a new colormap and
when the current colormap object is updated.
"""
- SYMBOL = 'symbolChanged'
+ SYMBOL = "symbolChanged"
"""Item's symbol changed flag."""
- SYMBOL_SIZE = 'symbolSizeChanged'
+ SYMBOL_SIZE = "symbolSizeChanged"
"""Item's symbol size changed flag."""
- LINE_WIDTH = 'lineWidthChanged'
+ LINE_WIDTH = "lineWidthChanged"
"""Item's line width changed flag."""
- LINE_STYLE = 'lineStyleChanged'
+ LINE_STYLE = "lineStyleChanged"
"""Item's line style changed flag."""
- COLOR = 'colorChanged'
+ COLOR = "colorChanged"
"""Item's color changed flag."""
- LINE_BG_COLOR = 'lineBgColorChanged'
- """Item's line background color changed flag."""
+ LINE_BG_COLOR = "lineBgColorChanged" # Deprecated, use LINE_GAP_COLOR
+
+ LINE_GAP_COLOR = "lineGapColorChanged"
+ """Item's dashed line gap color changed flag."""
- YAXIS = 'yAxisChanged'
+ YAXIS = "yAxisChanged"
"""Item's Y axis binding changed flag."""
- FILL = 'fillChanged'
+ FILL = "fillChanged"
"""Item's fill changed flag."""
- ALPHA = 'alphaChanged'
+ ALPHA = "alphaChanged"
"""Item's transparency alpha changed flag."""
- DATA = 'dataChanged'
+ DATA = "dataChanged"
"""Item's data changed flag"""
- MASK = 'maskChanged'
+ MASK = "maskChanged"
"""Item's mask changed flag"""
- HIGHLIGHTED = 'highlightedChanged'
+ HIGHLIGHTED = "highlightedChanged"
"""Item's highlight state changed flag."""
- HIGHLIGHTED_COLOR = 'highlightedColorChanged'
+ HIGHLIGHTED_COLOR = "highlightedColorChanged"
"""Deprecated, use HIGHLIGHTED_STYLE instead."""
- HIGHLIGHTED_STYLE = 'highlightedStyleChanged'
+ HIGHLIGHTED_STYLE = "highlightedStyleChanged"
"""Item's highlighted style changed flag."""
- SCALE = 'scaleChanged'
+ SCALE = "scaleChanged"
"""Item's scale changed flag."""
- TEXT = 'textChanged'
+ TEXT = "textChanged"
"""Item's text changed flag."""
- POSITION = 'positionChanged'
+ POSITION = "positionChanged"
"""Item's position changed flag.
This is emitted when a marker position changed and
when an image origin changed.
"""
- OVERLAY = 'overlayChanged'
+ OVERLAY = "overlayChanged"
"""Item's overlay state changed flag."""
- VISUALIZATION_MODE = 'visualizationModeChanged'
+ VISUALIZATION_MODE = "visualizationModeChanged"
"""Item's visualization mode changed flag."""
- COMPLEX_MODE = 'complexModeChanged'
+ COMPLEX_MODE = "complexModeChanged"
"""Item's complex data visualization mode changed flag."""
- NAME = 'nameChanged'
+ NAME = "nameChanged"
"""Item's name changed flag."""
- EDITABLE = 'editableChanged'
+ EDITABLE = "editableChanged"
"""Item's editable state changed flags."""
- SELECTABLE = 'selectableChanged'
+ SELECTABLE = "selectableChanged"
"""Item's selectable state changed flags."""
+ FONT = "fontChanged"
+ """Item's text font changed flag."""
+
+ BACKGROUND_COLOR = "backgroundColorChanged"
+ """Item's text background color changed flag."""
+
class Item(qt.QObject):
"""Description of an item of the plot"""
@@ -184,7 +192,7 @@ class Item(qt.QObject):
self._info = None
self._xlabel = None
self._ylabel = None
- self.__name = ''
+ self.__name = ""
self.__visibleBoundsTracking = False
self.__previousVisibleBounds = None
@@ -206,7 +214,7 @@ class Item(qt.QObject):
:param Union[~silx.gui.plot.PlotWidget,None] plot: The Plot instance.
"""
if plot is not None and self._plotRef is not None:
- raise RuntimeError('Trying to add a node at two places.')
+ raise RuntimeError("Trying to add a node at two places.")
self.__disconnectFromPlotWidget()
self._plotRef = None if plot is None else weakref.ref(plot)
self.__connectToPlotWidget()
@@ -240,8 +248,7 @@ class Item(qt.QObject):
if visible != self._visible:
self._visible = visible
# When visibility has changed, always mark as dirty
- self._updated(ItemChangedType.VISIBLE,
- checkVisibility=False)
+ self._updated(ItemChangedType.VISIBLE, checkVisibility=False)
if visible:
self._visibleBoundsChanged()
@@ -268,8 +275,7 @@ class Item(qt.QObject):
name = str(name)
if self.__name != name:
if self.getPlot() is not None:
- raise RuntimeError(
- "Cannot change name while item is in a PlotWidget")
+ raise RuntimeError("Cannot change name while item is in a PlotWidget")
self.__name = name
self._updated(ItemChangedType.NAME)
@@ -277,11 +283,6 @@ class Item(qt.QObject):
def getLegend(self): # Replaced by getName for API consistency
return self.getName()
- @deprecated(replacement='setName', since_version='0.13')
- def _setLegend(self, legend):
- legend = str(legend) if legend is not None else ''
- self.setName(legend)
-
def isSelectable(self):
"""Returns true if item is selectable (bool)"""
return self._selectable
@@ -332,7 +333,8 @@ class Item(qt.QObject):
xmin, xmax = numpy.clip(bounds[:2], *plot.getXAxis().getLimits())
ymin, ymax = numpy.clip(
- bounds[2:], *plot.getYAxis(self.__getYAxis()).getLimits())
+ bounds[2:], *plot.getYAxis(self.__getYAxis()).getLimits()
+ )
if xmin == xmax or ymin == ymax: # Outside the plot area
return None
@@ -360,7 +362,7 @@ class Item(qt.QObject):
def __getYAxis(self) -> str:
"""Returns current Y axis ('left' or 'right')"""
- return self.getYAxis() if isinstance(self, YAxisMixIn) else 'left'
+ return self.getYAxis() if isinstance(self, YAxisMixIn) else "left"
def __connectToPlotWidget(self) -> None:
"""Connect to PlotWidget signals and install event filter"""
@@ -486,7 +488,7 @@ class Item(qt.QObject):
class DataItem(Item):
"""Item with a data extent in the plot"""
- def _boundsChanged(self, checkVisibility: bool=True) -> None:
+ def _boundsChanged(self, checkVisibility: bool = True) -> None:
"""Call this method in subclass when data bounds has changed.
:param bool checkVisibility:
@@ -506,6 +508,7 @@ class DataItem(Item):
self._boundsChanged(checkVisibility=False)
super().setVisible(visible)
+
# Mix-in classes ##############################################################
@@ -522,8 +525,7 @@ class ItemMixInBase(object):
:param bool checkVisibility: True to only mark as dirty if visible,
False to always mark as dirty.
"""
- raise RuntimeError(
- "Issue with Mix-In class inheritance order")
+ raise RuntimeError("Issue with Mix-In class inheritance order")
class LabelsMixIn(ItemMixInBase):
@@ -597,7 +599,7 @@ class DraggableMixIn(ItemMixInBase):
raise NotImplementedError("Must be implemented in subclass")
-class ColormapMixIn(ItemMixInBase):
+class ColormapMixIn(_Colormappable, ItemMixInBase):
"""Mix-in class for items with colormap"""
def __init__(self):
@@ -631,8 +633,9 @@ class ColormapMixIn(ItemMixInBase):
"""Handle updates of the colormap"""
self._updated(ItemChangedType.COLORMAP)
- def _setColormappedData(self, data, copy=True,
- min_=None, minPositive=None, max_=None):
+ def _setColormappedData(
+ self, data, copy=True, min_=None, minPositive=None, max_=None
+ ):
"""Set the data used to compute the colormapped display.
It also resets the cache of data ranges.
@@ -653,7 +656,10 @@ class ColormapMixIn(ItemMixInBase):
if min_ is not None and numpy.isfinite(min_):
self.__cacheColormapRange[Colormap.LINEAR, Colormap.MINMAX] = min_, max_
if minPositive is not None and numpy.isfinite(minPositive):
- self.__cacheColormapRange[Colormap.LOGARITHM, Colormap.MINMAX] = minPositive, max_
+ self.__cacheColormapRange[Colormap.LOGARITHM, Colormap.MINMAX] = (
+ minPositive,
+ max_,
+ )
colormap = self.getColormap()
if None in (colormap.getVMin(), colormap.getVMax()):
@@ -705,26 +711,29 @@ class SymbolMixIn(ItemMixInBase):
_DEFAULT_SYMBOL_SIZE = config.DEFAULT_PLOT_SYMBOL_SIZE
"""Default marker size of the item"""
- _SUPPORTED_SYMBOLS = collections.OrderedDict((
- ('o', 'Circle'),
- ('d', 'Diamond'),
- ('s', 'Square'),
- ('+', 'Plus'),
- ('x', 'Cross'),
- ('.', 'Point'),
- (',', 'Pixel'),
- ('|', 'Vertical line'),
- ('_', 'Horizontal line'),
- ('tickleft', 'Tick left'),
- ('tickright', 'Tick right'),
- ('tickup', 'Tick up'),
- ('tickdown', 'Tick down'),
- ('caretleft', 'Caret left'),
- ('caretright', 'Caret right'),
- ('caretup', 'Caret up'),
- ('caretdown', 'Caret down'),
- (u'\u2665', 'Heart'),
- ('', 'None')))
+ _SUPPORTED_SYMBOLS = dict(
+ (
+ ("o", "Circle"),
+ ("d", "Diamond"),
+ ("s", "Square"),
+ ("+", "Plus"),
+ ("x", "Cross"),
+ (".", "Point"),
+ (",", "Pixel"),
+ ("|", "Vertical line"),
+ ("_", "Horizontal line"),
+ ("tickleft", "Tick left"),
+ ("tickright", "Tick right"),
+ ("tickup", "Tick up"),
+ ("tickdown", "Tick down"),
+ ("caretleft", "Caret left"),
+ ("caretright", "Caret right"),
+ ("caretup", "Caret up"),
+ ("caretdown", "Caret down"),
+ ("\u2665", "Heart"),
+ ("", "None"),
+ )
+ )
"""Dict of supported symbols"""
def __init__(self):
@@ -799,7 +808,7 @@ class SymbolMixIn(ItemMixInBase):
symbol = symbolCode
break
else:
- raise ValueError('Unsupported symbol %s' % str(symbol))
+ raise ValueError("Unsupported symbol %s" % str(symbol))
if symbol != self._symbol:
self._symbol = symbol
@@ -826,50 +835,74 @@ class SymbolMixIn(ItemMixInBase):
self._updated(ItemChangedType.SYMBOL_SIZE)
+LineStyleType = Union[
+ str,
+ Tuple[Union[float, int], None],
+ Tuple[Union[float, int], Tuple[Union[float, int], Union[float, int]]],
+ Tuple[Union[float, int], Tuple[Union[float, int], Union[float, int], Union[float, int], Union[float, int]]],
+]
+"""Type for :class:`LineMixIn`'s line style"""
+
+
class LineMixIn(ItemMixInBase):
"""Mix-in class for item with line"""
- _DEFAULT_LINEWIDTH = 1.
+ _DEFAULT_LINEWIDTH: float = 1.0
"""Default line width"""
- _DEFAULT_LINESTYLE = '-'
+ _DEFAULT_LINESTYLE: LineStyleType = "-"
"""Default line style"""
- _SUPPORTED_LINESTYLE = '', ' ', '-', '--', '-.', ':', None
+ _SUPPORTED_LINESTYLE = "", " ", "-", "--", "-.", ":", None
"""Supported line styles"""
def __init__(self):
- self._linewidth = self._DEFAULT_LINEWIDTH
- self._linestyle = self._DEFAULT_LINESTYLE
+ self._linewidth: float = self._DEFAULT_LINEWIDTH
+ self._linestyle: LineStyleType = self._DEFAULT_LINESTYLE
@classmethod
- def getSupportedLineStyles(cls):
- """Returns list of supported line styles.
-
- :rtype: List[str,None]
- """
+ def getSupportedLineStyles(cls) -> tuple[str | None]:
+ """Returns list of supported constant line styles."""
return cls._SUPPORTED_LINESTYLE
- def getLineWidth(self):
- """Return the curve line width in pixels
-
- :rtype: float
- """
+ def getLineWidth(self) -> float:
+ """Return the curve line width in pixels"""
return self._linewidth
- def setLineWidth(self, width):
+ def setLineWidth(self, width: float):
"""Set the width in pixel of the curve line
See :meth:`getLineWidth`.
-
- :param float width: Width in pixels
"""
width = float(width)
if width != self._linewidth:
self._linewidth = width
self._updated(ItemChangedType.LINE_WIDTH)
- def getLineStyle(self):
+ @classmethod
+ def isValidLineStyle(cls, style: LineStyleType | None) -> bool:
+ """Returns True for valid styles"""
+ if style is None or style in cls.getSupportedLineStyles():
+ return True
+ if not isinstance(style, tuple):
+ return False
+ if (
+ len(style) == 2
+ and isinstance(style[0], (float, int))
+ and (
+ style[1] is None
+ or style[1] == ()
+ or (
+ isinstance(style[1], tuple)
+ and len(style[1]) in (2, 4)
+ and all(map(lambda item: isinstance(item, (float, int)), style[1]))
+ )
+ )
+ ):
+ return True
+ return False
+
+ def getLineStyle(self) -> LineStyleType:
"""Return the type of the line
Type of line::
@@ -879,20 +912,19 @@ class LineMixIn(ItemMixInBase):
- '--' dashed line
- '-.' dash-dot line
- ':' dotted line
-
- :rtype: str
+ - (offset, (dash pattern))
"""
return self._linestyle
- def setLineStyle(self, style):
+ def setLineStyle(self, style: LineStyleType | None):
"""Set the style of the curve line.
See :meth:`getLineStyle`.
- :param str style: Line style
+ :param style: Line style
"""
- style = str(style)
- assert style in self.getSupportedLineStyles()
+ if not self.isValidLineStyle(style):
+ raise ValueError(f"No a valid line style: {style}")
if style is None:
style = self._DEFAULT_LINESTYLE
if style != self._linestyle:
@@ -903,7 +935,7 @@ class LineMixIn(ItemMixInBase):
class ColorMixIn(ItemMixInBase):
"""Mix-in class for item with color"""
- _DEFAULT_COLOR = (0., 0., 0., 1.)
+ _DEFAULT_COLOR = (0.0, 0.0, 0.0, 1.0)
"""Default color of the item"""
def __init__(self):
@@ -941,10 +973,43 @@ class ColorMixIn(ItemMixInBase):
self._updated(ItemChangedType.COLOR)
+class LineGapColorMixIn(ItemMixInBase):
+ """Mix-in class for dashed line gap color"""
+
+ _DEFAULT_LINE_GAP_COLOR = None
+ """Default dashed line gap color of the item"""
+
+ def __init__(self):
+ self.__lineGapColor = self._DEFAULT_LINE_GAP_COLOR
+
+ def getLineGapColor(self):
+ """Returns the RGBA color of dashed line gap of the item
+
+ :rtype: 4-tuple of float in [0, 1] or None
+ """
+ return self.__lineGapColor
+
+ def setLineGapColor(self, color):
+ """Set dashed line gap color
+
+ It supports:
+ - color names: e.g., 'green'
+ - color codes: '#RRGGBB' and '#RRGGBBAA'
+ - indexed color names: e.g., 'C0'
+ - RGB(A) sequence of uint8 in [0, 255] or float in [0, 1]
+ - QColor
+
+ :param color: line background color to be used
+ :type color: Union[str, List[int], List[float], QColor, None]
+ """
+ self.__lineGapColor = None if color is None else colors.rgba(color)
+ self._updated(ItemChangedType.LINE_GAP_COLOR)
+
+
class YAxisMixIn(ItemMixInBase):
"""Mix-in class for item with yaxis"""
- _DEFAULT_YAXIS = 'left'
+ _DEFAULT_YAXIS = "left"
"""Default Y axis the item belongs to"""
def __init__(self):
@@ -965,7 +1030,7 @@ class YAxisMixIn(ItemMixInBase):
:param str yaxis: 'left' or 'right'
"""
yaxis = str(yaxis)
- assert yaxis in ('left', 'right')
+ assert yaxis in ("left", "right")
if yaxis != self._yaxis:
self._yaxis = yaxis
# Handle data extent changed for DataItem
@@ -977,11 +1042,13 @@ class YAxisMixIn(ItemMixInBase):
# Switch Y axis signal connection
plot = self.getPlot()
if plot is not None:
- previousYAxis = 'left' if self.getXAxis() == 'right' else 'right'
+ previousYAxis = "left" if self.getXAxis() == "right" else "right"
plot.getYAxis(previousYAxis).sigLimitsChanged.disconnect(
- self._visibleBoundsChanged)
+ self._visibleBoundsChanged
+ )
plot.getYAxis(self.getYAxis()).sigLimitsChanged.connect(
- self._visibleBoundsChanged)
+ self._visibleBoundsChanged
+ )
self._visibleBoundsChanged()
self._updated(ItemChangedType.YAXIS)
@@ -1015,7 +1082,7 @@ class AlphaMixIn(ItemMixInBase):
"""Mix-in class for item with opacity"""
def __init__(self):
- self._alpha = 1.
+ self._alpha = 1.0
def getAlpha(self):
"""Returns the opacity of the item
@@ -1038,7 +1105,7 @@ class AlphaMixIn(ItemMixInBase):
:type alpha: float
"""
alpha = float(alpha)
- alpha = max(0., min(alpha, 1.)) # Clip alpha to [0., 1.] range
+ alpha = max(0.0, min(alpha, 1.0)) # Clip alpha to [0., 1.] range
if alpha != self._alpha:
self._alpha = alpha
self._updated(ItemChangedType.ALPHA)
@@ -1052,14 +1119,15 @@ class ComplexMixIn(ItemMixInBase):
class ComplexMode(_Enum):
"""Identify available display mode for complex"""
- NONE = 'none'
- ABSOLUTE = 'amplitude'
- PHASE = 'phase'
- REAL = 'real'
- IMAGINARY = 'imaginary'
- AMPLITUDE_PHASE = 'amplitude_phase'
- LOG10_AMPLITUDE_PHASE = 'log10_amplitude_phase'
- SQUARE_AMPLITUDE = 'square_amplitude'
+
+ NONE = "none"
+ ABSOLUTE = "amplitude"
+ PHASE = "phase"
+ REAL = "real"
+ IMAGINARY = "imaginary"
+ AMPLITUDE_PHASE = "amplitude_phase"
+ LOG10_AMPLITUDE_PHASE = "log10_amplitude_phase"
+ SQUARE_AMPLITUDE = "square_amplitude"
def __init__(self):
self.__complex_mode = self.ComplexMode.ABSOLUTE
@@ -1115,7 +1183,7 @@ class ComplexMixIn(ItemMixInBase):
elif mode is self.ComplexMode.SQUARE_AMPLITUDE:
return numpy.absolute(data) ** 2
else:
- raise ValueError('Unsupported conversion mode: %s', str(mode))
+ raise ValueError("Unsupported conversion mode: %s", str(mode))
@classmethod
def supportedComplexModes(cls):
@@ -1141,22 +1209,22 @@ class ScatterVisualizationMixIn(ItemMixInBase):
class Visualization(_Enum):
"""Different modes of scatter plot visualizations"""
- POINTS = 'points'
+ POINTS = "points"
"""Display scatter plot as a point cloud"""
- LINES = 'lines'
+ LINES = "lines"
"""Display scatter plot as a wireframe.
This is based on Delaunay triangulation
"""
- SOLID = 'solid'
+ SOLID = "solid"
"""Display scatter plot as a set of filled triangles.
This is based on Delaunay triangulation
"""
- REGULAR_GRID = 'regular_grid'
+ REGULAR_GRID = "regular_grid"
"""Display scatter plot as an image.
It expects the points to be the intersection of a regular grid,
@@ -1165,7 +1233,7 @@ class ScatterVisualizationMixIn(ItemMixInBase):
(either all lines from left to right or all from right to left).
"""
- IRREGULAR_GRID = 'irregular_grid'
+ IRREGULAR_GRID = "irregular_grid"
"""Display scatter plot as contiguous quadrilaterals.
It expects the points to be the intersection of an irregular grid,
@@ -1174,7 +1242,7 @@ class ScatterVisualizationMixIn(ItemMixInBase):
(either all lines from left to right or all from right to left).
"""
- BINNED_STATISTIC = 'binned_statistic'
+ BINNED_STATISTIC = "binned_statistic"
"""Display scatter plot as 2D binned statistic (i.e., generalized histogram).
"""
@@ -1182,13 +1250,13 @@ class ScatterVisualizationMixIn(ItemMixInBase):
class VisualizationParameter(_Enum):
"""Different parameter names for scatter plot visualizations"""
- GRID_MAJOR_ORDER = 'grid_major_order'
+ GRID_MAJOR_ORDER = "grid_major_order"
"""The major order of points in the regular grid.
Either 'row' (row-major, fast X) or 'column' (column-major, fast Y).
"""
- GRID_BOUNDS = 'grid_bounds'
+ GRID_BOUNDS = "grid_bounds"
"""The expected range in data coordinates of the regular grid.
A 2-tuple of 2-tuple: (begin (x, y), end (x, y)).
@@ -1197,24 +1265,24 @@ class ScatterVisualizationMixIn(ItemMixInBase):
As for `GRID_SHAPE`, this can be wider than the current data.
"""
- GRID_SHAPE = 'grid_shape'
+ GRID_SHAPE = "grid_shape"
"""The expected size of the regular grid (height, width).
The given shape can be wider than the number of points,
in which case the grid is not fully filled.
"""
- BINNED_STATISTIC_SHAPE = 'binned_statistic_shape'
+ BINNED_STATISTIC_SHAPE = "binned_statistic_shape"
"""The number of bins in each dimension (height, width).
"""
- BINNED_STATISTIC_FUNCTION = 'binned_statistic_function'
+ BINNED_STATISTIC_FUNCTION = "binned_statistic_function"
"""The reduction function to apply to each bin (str).
Available reduction functions are: 'mean' (default), 'count', 'sum'.
"""
- DATA_BOUNDS_HINT = 'data_bounds_hint'
+ DATA_BOUNDS_HINT = "data_bounds_hint"
"""The expected bounds of the data in data coordinates.
A 2-tuple of 2-tuple: ((ymin, ymax), (xmin, xmax)).
@@ -1225,8 +1293,8 @@ class ScatterVisualizationMixIn(ItemMixInBase):
"""
_SUPPORTED_VISUALIZATION_PARAMETER_VALUES = {
- VisualizationParameter.GRID_MAJOR_ORDER: ('row', 'column'),
- VisualizationParameter.BINNED_STATISTIC_FUNCTION: ('mean', 'count', 'sum'),
+ VisualizationParameter.GRID_MAJOR_ORDER: ("row", "column"),
+ VisualizationParameter.BINNED_STATISTIC_FUNCTION: ("mean", "count", "sum"),
}
"""Supported visualization parameter values.
@@ -1235,9 +1303,12 @@ class ScatterVisualizationMixIn(ItemMixInBase):
def __init__(self):
self.__visualization = self.Visualization.POINTS
- self.__parameters = dict(# Init parameters to None
- (parameter, None) for parameter in self.VisualizationParameter)
- self.__parameters[self.VisualizationParameter.BINNED_STATISTIC_FUNCTION] = 'mean'
+ self.__parameters = dict( # Init parameters to None
+ (parameter, None) for parameter in self.VisualizationParameter
+ )
+ self.__parameters[
+ self.VisualizationParameter.BINNED_STATISTIC_FUNCTION
+ ] = "mean"
@classmethod
def supportedVisualizations(cls):
@@ -1263,8 +1334,7 @@ class ScatterVisualizationMixIn(ItemMixInBase):
:returns: tuple of supported of values or None if not defined.
"""
parameter = cls.VisualizationParameter(parameter)
- return cls._SUPPORTED_VISUALIZATION_PARAMETER_VALUES.get(
- parameter, None)
+ return cls._SUPPORTED_VISUALIZATION_PARAMETER_VALUES.get(parameter, None)
def setVisualization(self, mode):
"""Set the scatter plot visualization mode to use.
@@ -1351,6 +1421,7 @@ class ScatterVisualizationMixIn(ItemMixInBase):
class PointsBase(DataItem, SymbolMixIn, AlphaMixIn):
"""Base class for :class:`Curve` and :class:`Scatter`"""
+
# note: _logFilterData must be overloaded if you overload
# getData to change its signature
@@ -1398,22 +1469,18 @@ class PointsBase(DataItem, SymbolMixIn, AlphaMixIn):
errorClipped[mask] = valueMinusError[mask] <= 0
if numpy.any(errorClipped): # Need filtering
-
# expand errorbars to 2xN
if error.size == 1: # Scalar
- error = numpy.full(
- (2, len(value)), error, dtype=numpy.float64)
+ error = numpy.full((2, len(value)), error, dtype=numpy.float64)
elif error.ndim == 1: # N array
- newError = numpy.empty((2, len(value)),
- dtype=numpy.float64)
- newError[0,:] = error
- newError[1,:] = error
+ newError = numpy.empty((2, len(value)), dtype=numpy.float64)
+ newError[0, :] = error
+ newError[1, :] = error
error = newError
elif error.size == 2 * len(value): # 2xN array
- error = numpy.array(
- error, copy=True, dtype=numpy.float64)
+ error = numpy.array(error, copy=True, dtype=numpy.float64)
else:
_logger.error("Unhandled error array")
@@ -1437,16 +1504,17 @@ class PointsBase(DataItem, SymbolMixIn, AlphaMixIn):
if xPositive:
x = self.getXData(copy=False)
- with numpy.errstate(invalid='ignore'): # Ignore NaN warnings
+ with numpy.errstate(invalid="ignore"): # Ignore NaN warnings
xclipped = x <= 0
if yPositive:
y = self.getYData(copy=False)
- with numpy.errstate(invalid='ignore'): # Ignore NaN warnings
+ with numpy.errstate(invalid="ignore"): # Ignore NaN warnings
yclipped = y <= 0
- self._clippedCache[(xPositive, yPositive)] = \
- numpy.logical_or(xclipped, yclipped)
+ self._clippedCache[(xPositive, yPositive)] = numpy.logical_or(
+ xclipped, yclipped
+ )
return self._clippedCache[(xPositive, yPositive)]
def _logFilterData(self, xPositive, yPositive):
@@ -1484,7 +1552,7 @@ class PointsBase(DataItem, SymbolMixIn, AlphaMixIn):
def __minMaxDataWithError(
data: numpy.ndarray,
error: Optional[Union[float, numpy.ndarray]],
- positiveOnly: bool
+ positiveOnly: bool,
) -> Tuple[float]:
if error is None:
min_, max_ = min_max(data, finite=True)
@@ -1532,9 +1600,12 @@ class PointsBase(DataItem, SymbolMixIn, AlphaMixIn):
xmin, xmax = self.__minMaxDataWithError(x, xerror, xPositive)
ymin, ymax = self.__minMaxDataWithError(y, yerror, yPositive)
- self._boundsCache[(xPositive, yPositive)] = tuple([
- (bound if bound is not None else numpy.nan)
- for bound in (xmin, xmax, ymin, ymax)])
+ self._boundsCache[(xPositive, yPositive)] = tuple(
+ [
+ (bound if bound is not None else numpy.nan)
+ for bound in (xmin, xmax, ymin, ymax)
+ ]
+ )
return self._boundsCache[(xPositive, yPositive)]
def _getCachedData(self):
@@ -1548,8 +1619,9 @@ class PointsBase(DataItem, SymbolMixIn, AlphaMixIn):
if xPositive or yPositive:
# At least one axis has log scale, filter data
if (xPositive, yPositive) not in self._filteredCache:
- self._filteredCache[(xPositive, yPositive)] = \
- self._logFilterData(xPositive, yPositive)
+ self._filteredCache[(xPositive, yPositive)] = self._logFilterData(
+ xPositive, yPositive
+ )
return self._filteredCache[(xPositive, yPositive)]
return None
@@ -1570,10 +1642,12 @@ class PointsBase(DataItem, SymbolMixIn, AlphaMixIn):
if cached_data is not None:
return cached_data
- return (self.getXData(copy),
- self.getYData(copy),
- self.getXErrorData(copy),
- self.getYErrorData(copy))
+ return (
+ self.getXData(copy),
+ self.getYData(copy),
+ self.getXErrorData(copy),
+ self.getYErrorData(copy),
+ )
def getXData(self, copy=True):
"""Returns the x coordinates of the data points
@@ -1640,12 +1714,10 @@ class PointsBase(DataItem, SymbolMixIn, AlphaMixIn):
# Convert complex data
if numpy.iscomplexobj(x):
- _logger.warning(
- 'Converting x data to absolute value to plot it.')
+ _logger.warning("Converting x data to absolute value to plot it.")
x = numpy.absolute(x)
if numpy.iscomplexobj(y):
- _logger.warning(
- 'Converting y data to absolute value to plot it.')
+ _logger.warning("Converting y data to absolute value to plot it.")
y = numpy.absolute(y)
if xerror is not None:
@@ -1653,7 +1725,8 @@ class PointsBase(DataItem, SymbolMixIn, AlphaMixIn):
xerror = numpy.array(xerror, copy=copy)
if numpy.iscomplexobj(xerror):
_logger.warning(
- 'Converting xerror data to absolute value to plot it.')
+ "Converting xerror data to absolute value to plot it."
+ )
xerror = numpy.absolute(xerror)
else:
xerror = float(xerror)
@@ -1662,7 +1735,8 @@ class PointsBase(DataItem, SymbolMixIn, AlphaMixIn):
yerror = numpy.array(yerror, copy=copy)
if numpy.iscomplexobj(yerror):
_logger.warning(
- 'Converting yerror data to absolute value to plot it.')
+ "Converting yerror data to absolute value to plot it."
+ )
yerror = numpy.absolute(yerror)
else:
yerror = float(yerror)
@@ -1691,7 +1765,7 @@ class BaselineMixIn(object):
:param baseline: baseline value(s)
:type: Union[None,float,numpy.ndarray]
"""
- if (isinstance(baseline, abc.Iterable)):
+ if isinstance(baseline, abc.Iterable):
baseline = numpy.array(baseline)
self._baseline = baseline
@@ -1713,7 +1787,6 @@ class _Style:
class HighlightedMixIn(ItemMixInBase):
-
def __init__(self):
self._highlightStyle = self._DEFAULT_HIGHLIGHT_STYLE
self._highlighted = False
diff --git a/src/silx/gui/plot/items/curve.py b/src/silx/gui/plot/items/curve.py
index 93e4719..e8d0d52 100644
--- a/src/silx/gui/plot/items/curve.py
+++ b/src/silx/gui/plot/items/curve.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -23,6 +23,7 @@
# ###########################################################################*/
"""This module provides the :class:`Curve` item of the :class:`Plot`.
"""
+from __future__ import annotations
__authors__ = ["T. Vincent"]
__license__ = "MIT"
@@ -33,11 +34,22 @@ import logging
import numpy
-from ....utils.deprecation import deprecated
+from ....utils.deprecation import deprecated_warning
from ... import colors
-from .core import (PointsBase, LabelsMixIn, ColorMixIn, YAxisMixIn,
- FillMixIn, LineMixIn, SymbolMixIn, ItemChangedType,
- BaselineMixIn, HighlightedMixIn, _Style)
+from .core import (
+ PointsBase,
+ LabelsMixIn,
+ ColorMixIn,
+ YAxisMixIn,
+ FillMixIn,
+ LineMixIn,
+ LineGapColorMixIn,
+ LineStyleType,
+ SymbolMixIn,
+ BaselineMixIn,
+ HighlightedMixIn,
+ _Style,
+)
_logger = logging.getLogger(__name__)
@@ -49,14 +61,22 @@ class CurveStyle(_Style):
Set a value to None to use the default
:param color: Color
- :param Union[str,None] linestyle: Style of the line
- :param Union[float,None] linewidth: Width of the line
- :param Union[str,None] symbol: Symbol for markers
- :param Union[float,None] symbolsize: Size of the markers
+ :param linestyle: Style of the line
+ :param linewidth: Width of the line
+ :param symbol: Symbol for markers
+ :param symbolsize: Size of the markers
+ :param gapcolor: Color of gaps of dashed line
"""
- def __init__(self, color=None, linestyle=None, linewidth=None,
- symbol=None, symbolsize=None):
+ def __init__(
+ self,
+ color: colors.RGBAColorType | None = None,
+ linestyle: LineStyleType | None = None,
+ linewidth: float | None = None,
+ symbol: str | None = None,
+ symbolsize: float | None = None,
+ gapcolor: colors.RGBAColorType | None = None,
+ ):
if color is None:
self._color = None
else:
@@ -68,8 +88,8 @@ class CurveStyle(_Style):
color = colors.rgba(color)
self._color = color
- if linestyle is not None:
- assert linestyle in LineMixIn.getSupportedLineStyles()
+ if not LineMixIn.isValidLineStyle(linestyle):
+ raise ValueError(f"Not a valid line style: {linestyle}")
self._linestyle = linestyle
self._linewidth = None if linewidth is None else float(linewidth)
@@ -80,6 +100,8 @@ class CurveStyle(_Style):
self._symbolsize = None if symbolsize is None else float(symbolsize)
+ self._gapcolor = None if gapcolor is None else colors.rgba(gapcolor)
+
def getColor(self, copy=True):
"""Returns the color or None if not set.
@@ -93,7 +115,14 @@ class CurveStyle(_Style):
else:
return self._color
- def getLineStyle(self):
+ def getLineGapColor(self):
+ """Returns the color of dashed line gaps or None if not set.
+
+ :rtype: Union[List[float],None]
+ """
+ return self._gapcolor
+
+ def getLineStyle(self) -> LineStyleType | None:
"""Return the type of the line or None if not set.
Type of line::
@@ -103,8 +132,7 @@ class CurveStyle(_Style):
- '--' dashed line
- '-.' dash-dot line
- ':' dotted line
-
- :rtype: Union[str,None]
+ - (offset, (dash pattern))
"""
return self._linestyle
@@ -141,17 +169,29 @@ class CurveStyle(_Style):
def __eq__(self, other):
if isinstance(other, CurveStyle):
- return (numpy.array_equal(self.getColor(), other.getColor()) and
- self.getLineStyle() == other.getLineStyle() and
- self.getLineWidth() == other.getLineWidth() and
- self.getSymbol() == other.getSymbol() and
- self.getSymbolSize() == other.getSymbolSize())
+ return (
+ numpy.array_equal(self.getColor(), other.getColor())
+ and self.getLineStyle() == other.getLineStyle()
+ and self.getLineWidth() == other.getLineWidth()
+ and self.getSymbol() == other.getSymbol()
+ and self.getSymbolSize() == other.getSymbolSize()
+ and self.getLineGapColor() == other.getLineGapColor()
+ )
else:
return False
-class Curve(PointsBase, ColorMixIn, YAxisMixIn, FillMixIn, LabelsMixIn,
- LineMixIn, BaselineMixIn, HighlightedMixIn):
+class Curve(
+ PointsBase,
+ ColorMixIn,
+ YAxisMixIn,
+ FillMixIn,
+ LabelsMixIn,
+ LineMixIn,
+ LineGapColorMixIn,
+ BaselineMixIn,
+ HighlightedMixIn,
+):
"""Description of a curve"""
_DEFAULT_Z_LAYER = 1
@@ -160,13 +200,13 @@ class Curve(PointsBase, ColorMixIn, YAxisMixIn, FillMixIn, LabelsMixIn,
_DEFAULT_SELECTABLE = True
"""Default selectable state for curves"""
- _DEFAULT_LINEWIDTH = 1.
+ _DEFAULT_LINEWIDTH = 1.0
"""Default line width of the curve"""
- _DEFAULT_LINESTYLE = '-'
+ _DEFAULT_LINESTYLE = "-"
"""Default line style of the curve"""
- _DEFAULT_HIGHLIGHT_STYLE = CurveStyle(color='black')
+ _DEFAULT_HIGHLIGHT_STYLE = CurveStyle(color="black")
"""Default highlight style of the item"""
_DEFAULT_BASELINE = None
@@ -178,6 +218,7 @@ class Curve(PointsBase, ColorMixIn, YAxisMixIn, FillMixIn, LabelsMixIn,
FillMixIn.__init__(self)
LabelsMixIn.__init__(self)
LineMixIn.__init__(self)
+ LineGapColorMixIn.__init__(self)
BaselineMixIn.__init__(self)
HighlightedMixIn.__init__(self)
@@ -186,29 +227,38 @@ class Curve(PointsBase, ColorMixIn, YAxisMixIn, FillMixIn, LabelsMixIn,
def _addBackendRenderer(self, backend):
"""Update backend renderer"""
# Filter-out values <= 0
- xFiltered, yFiltered, xerror, yerror = self.getData(
- copy=False, displayed=True)
+ xFiltered, yFiltered, xerror, yerror = self.getData(copy=False, displayed=True)
if len(xFiltered) == 0 or not numpy.any(numpy.isfinite(xFiltered)):
return None # No data to display, do not add renderer to backend
style = self.getCurrentStyle()
- return backend.addCurve(xFiltered, yFiltered,
- color=style.getColor(),
- symbol=style.getSymbol(),
- linestyle=style.getLineStyle(),
- linewidth=style.getLineWidth(),
- yaxis=self.getYAxis(),
- xerror=xerror,
- yerror=yerror,
- fill=self.isFill(),
- alpha=self.getAlpha(),
- symbolsize=style.getSymbolSize(),
- baseline=self.getBaseline(copy=False))
+ return backend.addCurve(
+ xFiltered,
+ yFiltered,
+ color=style.getColor(),
+ gapcolor=style.getLineGapColor(),
+ symbol=style.getSymbol(),
+ linestyle=style.getLineStyle(),
+ linewidth=style.getLineWidth(),
+ yaxis=self.getYAxis(),
+ xerror=xerror,
+ yerror=yerror,
+ fill=self.isFill(),
+ alpha=self.getAlpha(),
+ symbolsize=style.getSymbolSize(),
+ baseline=self.getBaseline(copy=False),
+ )
def __getitem__(self, item):
"""Compatibility with PyMca and silx <= 0.4.0"""
+ deprecated_warning(
+ "Attributes",
+ "__getitem__",
+ since_version="2.0.0",
+ replacement="Use Curve methods",
+ )
if isinstance(item, slice):
return [self[index] for index in range(*item.indices(5))]
elif item == 0:
@@ -222,44 +272,24 @@ class Curve(PointsBase, ColorMixIn, YAxisMixIn, FillMixIn, LabelsMixIn,
return {} if info is None else info
elif item == 4:
params = {
- 'info': self.getInfo(),
- 'color': self.getColor(),
- 'symbol': self.getSymbol(),
- 'linewidth': self.getLineWidth(),
- 'linestyle': self.getLineStyle(),
- 'xlabel': self.getXLabel(),
- 'ylabel': self.getYLabel(),
- 'yaxis': self.getYAxis(),
- 'xerror': self.getXErrorData(copy=False),
- 'yerror': self.getYErrorData(copy=False),
- 'z': self.getZValue(),
- 'selectable': self.isSelectable(),
- 'fill': self.isFill(),
+ "info": self.getInfo(),
+ "color": self.getColor(),
+ "symbol": self.getSymbol(),
+ "linewidth": self.getLineWidth(),
+ "linestyle": self.getLineStyle(),
+ "xlabel": self.getXLabel(),
+ "ylabel": self.getYLabel(),
+ "yaxis": self.getYAxis(),
+ "xerror": self.getXErrorData(copy=False),
+ "yerror": self.getYErrorData(copy=False),
+ "z": self.getZValue(),
+ "selectable": self.isSelectable(),
+ "fill": self.isFill(),
}
return params
else:
raise IndexError("Index out of range: %s", str(item))
- @deprecated(replacement='Curve.getHighlightedStyle().getColor()',
- since_version='0.9.0')
- def getHighlightedColor(self):
- """Returns the RGBA highlight color of the item
-
- :rtype: 4-tuple of float in [0, 1]
- """
- return self.getHighlightedStyle().getColor()
-
- @deprecated(replacement='Curve.setHighlightedStyle()',
- since_version='0.9.0')
- def setHighlightedColor(self, color):
- """Set the color to use when highlighted
-
- :param color: color(s) to be used for highlight
- :type color: str ("#RRGGBB") or (npoints, 4) unsigned byte array or
- one of the predefined color names defined in colors.py
- """
- self.setHighlightedStyle(CurveStyle(color))
-
def getCurrentStyle(self):
"""Returns the current curve style.
@@ -274,32 +304,26 @@ class Curve(PointsBase, ColorMixIn, YAxisMixIn, FillMixIn, LabelsMixIn,
linewidth = style.getLineWidth()
symbol = style.getSymbol()
symbolsize = style.getSymbolSize()
+ gapcolor = style.getLineGapColor()
return CurveStyle(
color=self.getColor() if color is None else color,
linestyle=self.getLineStyle() if linestyle is None else linestyle,
linewidth=self.getLineWidth() if linewidth is None else linewidth,
symbol=self.getSymbol() if symbol is None else symbol,
- symbolsize=self.getSymbolSize() if symbolsize is None else symbolsize)
+ symbolsize=self.getSymbolSize() if symbolsize is None else symbolsize,
+ gapcolor=self.getLineGapColor() if gapcolor is None else gapcolor,
+ )
else:
- return CurveStyle(color=self.getColor(),
- linestyle=self.getLineStyle(),
- linewidth=self.getLineWidth(),
- symbol=self.getSymbol(),
- symbolsize=self.getSymbolSize())
-
- @deprecated(replacement='Curve.getCurrentStyle()',
- since_version='0.9.0')
- def getCurrentColor(self):
- """Returns the current color of the curve.
-
- This color is either the color of the curve or the highlighted color,
- depending on the highlight state.
-
- :rtype: 4-tuple of float in [0, 1]
- """
- return self.getCurrentStyle().getColor()
+ return CurveStyle(
+ color=self.getColor(),
+ linestyle=self.getLineStyle(),
+ linewidth=self.getLineWidth(),
+ symbol=self.getSymbol(),
+ symbolsize=self.getSymbolSize(),
+ gapcolor=self.getLineGapColor(),
+ )
def setData(self, x, y, xerror=None, yerror=None, baseline=None, copy=True):
"""Set the data of the curve.
@@ -319,6 +343,5 @@ class Curve(PointsBase, ColorMixIn, YAxisMixIn, FillMixIn, LabelsMixIn,
:param bool copy: True make a copy of the data (default),
False to use provided arrays.
"""
- PointsBase.setData(self, x=x, y=y, xerror=xerror, yerror=yerror,
- copy=copy)
+ PointsBase.setData(self, x=x, y=y, xerror=xerror, yerror=yerror, copy=copy)
self._setBaseline(baseline=baseline)
diff --git a/src/silx/gui/plot/items/histogram.py b/src/silx/gui/plot/items/histogram.py
index 007f0c7..1dc851b 100644
--- a/src/silx/gui/plot/items/histogram.py
+++ b/src/silx/gui/plot/items/histogram.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -32,15 +32,20 @@ import logging
import typing
import numpy
-from collections import OrderedDict, namedtuple
-try:
- from collections import abc
-except ImportError: # Python2 support
- import collections as abc
+from collections import abc
from ....utils.proxy import docstring
-from .core import (DataItem, AlphaMixIn, BaselineMixIn, ColorMixIn, FillMixIn,
- LineMixIn, YAxisMixIn, ItemChangedType, Item)
+from .core import (
+ DataItem,
+ AlphaMixIn,
+ BaselineMixIn,
+ ColorMixIn,
+ FillMixIn,
+ LineMixIn,
+ LineGapColorMixIn,
+ YAxisMixIn,
+ ItemChangedType,
+)
from ._pick import PickingResult
_logger = logging.getLogger(__name__)
@@ -62,17 +67,17 @@ def _computeEdges(x, histogramType):
"""
# for now we consider that the spaces between xs are constant
edges = x.copy()
- if histogramType == 'left':
+ if histogramType == "left":
width = 1
if len(x) > 1:
width = x[1] - x[0]
edges = numpy.append(x[0] - width, edges)
- if histogramType == 'center':
- edges = _computeEdges(edges, 'right')
+ if histogramType == "center":
+ edges = _computeEdges(edges, "right")
widths = (edges[1:] - edges[0:-1]) / 2.0
widths = numpy.append(widths, widths[-1])
edges = edges - widths
- if histogramType == 'right':
+ if histogramType == "right":
width = 1
if len(x) > 1:
width = x[-1] - x[-2]
@@ -102,8 +107,16 @@ def _getHistogramCurve(histogram, edges):
# TODO: Yerror, test log scale
-class Histogram(DataItem, AlphaMixIn, ColorMixIn, FillMixIn,
- LineMixIn, YAxisMixIn, BaselineMixIn):
+class Histogram(
+ DataItem,
+ AlphaMixIn,
+ ColorMixIn,
+ FillMixIn,
+ LineMixIn,
+ LineGapColorMixIn,
+ YAxisMixIn,
+ BaselineMixIn,
+):
"""Description of an histogram"""
_DEFAULT_Z_LAYER = 1
@@ -112,10 +125,10 @@ class Histogram(DataItem, AlphaMixIn, ColorMixIn, FillMixIn,
_DEFAULT_SELECTABLE = False
"""Default selectable state for histograms"""
- _DEFAULT_LINEWIDTH = 1.
+ _DEFAULT_LINEWIDTH = 1.0
"""Default line width of the histogram"""
- _DEFAULT_LINESTYLE = '-'
+ _DEFAULT_LINESTYLE = "-"
"""Default line style of the histogram"""
_DEFAULT_BASELINE = None
@@ -127,6 +140,7 @@ class Histogram(DataItem, AlphaMixIn, ColorMixIn, FillMixIn,
ColorMixIn.__init__(self)
FillMixIn.__init__(self)
LineMixIn.__init__(self)
+ LineGapColorMixIn.__init__(self)
YAxisMixIn.__init__(self)
self._histogram = ()
@@ -156,26 +170,30 @@ class Histogram(DataItem, AlphaMixIn, ColorMixIn, FillMixIn,
if xPositive or yPositive:
clipped = numpy.logical_or(
- (x <= 0) if xPositive else False,
- (y <= 0) if yPositive else False)
+ (x <= 0) if xPositive else False, (y <= 0) if yPositive else False
+ )
# Make a copy and replace negative points by NaN
x = numpy.array(x, dtype=numpy.float64)
y = numpy.array(y, dtype=numpy.float64)
x[clipped] = numpy.nan
y[clipped] = numpy.nan
- return backend.addCurve(x, y,
- color=self.getColor(),
- symbol='',
- linestyle=self.getLineStyle(),
- linewidth=self.getLineWidth(),
- yaxis=self.getYAxis(),
- xerror=None,
- yerror=None,
- fill=self.isFill(),
- alpha=self.getAlpha(),
- baseline=baseline,
- symbolsize=1)
+ return backend.addCurve(
+ x,
+ y,
+ color=self.getColor(),
+ gapcolor=self.getLineGapColor(),
+ symbol="",
+ linestyle=self.getLineStyle(),
+ linewidth=self.getLineWidth(),
+ yaxis=self.getYAxis(),
+ xerror=None,
+ yerror=None,
+ fill=self.isFill(),
+ alpha=self.getAlpha(),
+ baseline=baseline,
+ symbolsize=1,
+ )
def _getBounds(self):
values, edges, baseline = self.getData(copy=False)
@@ -193,11 +211,10 @@ class Histogram(DataItem, AlphaMixIn, ColorMixIn, FillMixIn,
if xPositive:
# Replace edges <= 0 by NaN and corresponding values by NaN
- clipped_edges = (edges <= 0)
+ clipped_edges = edges <= 0
edges = numpy.array(edges, copy=True, dtype=numpy.float64)
edges[clipped_edges] = numpy.nan
- clipped_values = numpy.logical_or(clipped_edges[:-1],
- clipped_edges[1:])
+ clipped_values = numpy.logical_or(clipped_edges[:-1], clipped_edges[1:])
else:
clipped_values = numpy.zeros_like(values, dtype=bool)
@@ -208,20 +225,26 @@ class Histogram(DataItem, AlphaMixIn, ColorMixIn, FillMixIn,
values[clipped_values] = numpy.nan
if yPositive:
- return (numpy.nanmin(edges),
- numpy.nanmax(edges),
- numpy.nanmin(values),
- numpy.nanmax(values))
+ return (
+ numpy.nanmin(edges),
+ numpy.nanmax(edges),
+ numpy.nanmin(values),
+ numpy.nanmax(values),
+ )
else: # No log scale on y axis, include 0 in bounds
if numpy.all(numpy.isnan(values)):
return None
- return (numpy.nanmin(edges),
- numpy.nanmax(edges),
- min(0, numpy.nanmin(values)),
- max(0, numpy.nanmax(values)))
-
- def __pickFilledHistogram(self, x: float, y: float) -> typing.Optional[PickingResult]:
+ return (
+ numpy.nanmin(edges),
+ numpy.nanmax(edges),
+ min(0, numpy.nanmin(values)),
+ max(0, numpy.nanmax(values)),
+ )
+
+ def __pickFilledHistogram(
+ self, x: float, y: float
+ ) -> typing.Optional[PickingResult]:
"""Picking implementation for filled histogram
:param x: X position in pixels
@@ -241,7 +264,7 @@ class Histogram(DataItem, AlphaMixIn, ColorMixIn, FillMixIn,
# Check x
edges = self.getBinEdgesData(copy=False)
- index = numpy.searchsorted(edges, (xData,), side='left')[0] - 1
+ index = numpy.searchsorted(edges, (xData,), side="left")[0] - 1
# Safe indexing in histogram values
index = numpy.clip(index, 0, len(edges) - 2)
@@ -251,8 +274,9 @@ class Histogram(DataItem, AlphaMixIn, ColorMixIn, FillMixIn,
baseline = 0 # Default value
value = self.getValueData(copy=False)[index]
- if ((baseline <= value and baseline <= yData <= value) or
- (value < baseline and value <= yData <= baseline)):
+ if (baseline <= value and baseline <= yData <= value) or (
+ value < baseline and value <= yData <= baseline
+ ):
return PickingResult(self, numpy.array([index]))
else:
return None
@@ -296,12 +320,13 @@ class Histogram(DataItem, AlphaMixIn, ColorMixIn, FillMixIn,
:returns: (N histogram value, N+1 bin edges)
:rtype: 2-tuple of numpy.nadarray
"""
- return (self.getValueData(copy),
- self.getBinEdgesData(copy),
- self.getBaseline(copy))
+ return (
+ self.getValueData(copy),
+ self.getBinEdgesData(copy),
+ self.getBaseline(copy),
+ )
- def setData(self, histogram, edges, align='center', baseline=None,
- copy=True):
+ def setData(self, histogram, edges, align="center", baseline=None, copy=True):
"""Set the histogram values and bin edges.
:param numpy.ndarray histogram: The values of the histogram.
@@ -324,7 +349,7 @@ class Histogram(DataItem, AlphaMixIn, ColorMixIn, FillMixIn,
assert histogram.ndim == 1
assert edges.ndim == 1
assert edges.size in (histogram.size, histogram.size + 1)
- assert align in ('center', 'left', 'right')
+ assert align in ("center", "left", "right")
if histogram.size == 0: # No data
self._histogram = ()
@@ -338,12 +363,12 @@ class Histogram(DataItem, AlphaMixIn, ColorMixIn, FillMixIn,
edgesDiff = edgesDiff[numpy.logical_not(numpy.isnan(edgesDiff))]
assert numpy.all(edgesDiff >= 0) or numpy.all(edgesDiff <= 0)
# manage baseline
- if (isinstance(baseline, abc.Iterable)):
+ if isinstance(baseline, abc.Iterable):
baseline = numpy.array(baseline)
if baseline.size == histogram.size:
new_baseline = numpy.empty(baseline.shape[0] * 2)
for i_value, value in enumerate(baseline):
- new_baseline[i_value*2:i_value*2+2] = value
+ new_baseline[i_value * 2 : i_value * 2 + 2] = value
baseline = new_baseline
self._histogram = histogram
self._edges = edges
@@ -376,11 +401,11 @@ class Histogram(DataItem, AlphaMixIn, ColorMixIn, FillMixIn,
"""
# for now we consider that the spaces between xs are constant
edges = x.copy()
- if histogramType == 'left':
+ if histogramType == "left":
return edges[1:]
- if histogramType == 'center':
+ if histogramType == "center":
edges = (edges[1:] + edges[:-1]) / 2.0
- if histogramType == 'right':
+ if histogramType == "right":
width = 1
if len(x) > 1:
width = x[-1] + x[-2]
diff --git a/src/silx/gui/plot/items/image.py b/src/silx/gui/plot/items/image.py
index eaee05a..18310d9 100644
--- a/src/silx/gui/plot/items/image.py
+++ b/src/silx/gui/plot/items/image.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -29,17 +29,21 @@ __authors__ = ["T. Vincent"]
__license__ = "MIT"
__date__ = "08/12/2020"
-try:
- from collections import abc
-except ImportError: # Python2 support
- import collections as abc
+from collections import abc
import logging
import numpy
from ....utils.proxy import docstring
-from .core import (DataItem, LabelsMixIn, DraggableMixIn, ColormapMixIn,
- AlphaMixIn, ItemChangedType)
+from ....utils.deprecation import deprecated_warning
+from .core import (
+ DataItem,
+ LabelsMixIn,
+ DraggableMixIn,
+ ColormapMixIn,
+ AlphaMixIn,
+ ItemChangedType,
+)
_logger = logging.getLogger(__name__)
@@ -62,23 +66,22 @@ def _convertImageToRgba32(image, copy=True):
assert image.shape[-1] in (3, 4)
# Convert type to uint8
- if image.dtype.name != 'uint8':
- if image.dtype.kind == 'f': # Float in [0, 1]
- image = (numpy.clip(image, 0., 1.) * 255).astype(numpy.uint8)
- elif image.dtype.kind == 'b': # boolean
+ if image.dtype.name != "uint8":
+ if image.dtype.kind == "f": # Float in [0, 1]
+ image = (numpy.clip(image, 0.0, 1.0) * 255).astype(numpy.uint8)
+ elif image.dtype.kind == "b": # boolean
image = image.astype(numpy.uint8) * 255
- elif image.dtype.kind in ('i', 'u'): # int, uint
+ elif image.dtype.kind in ("i", "u"): # int, uint
image = numpy.clip(image, 0, 255).astype(numpy.uint8)
else:
- raise ValueError('Unsupported image dtype: %s', image.dtype.name)
+ raise ValueError("Unsupported image dtype: %s", image.dtype.name)
copy = False # A copy as already been done, avoid next one
# Convert RGB to RGBA
if image.shape[-1] == 3:
- new_image = numpy.empty((image.shape[0], image.shape[1], 4),
- dtype=numpy.uint8)
- new_image[:,:,:3] = image
- new_image[:,:, 3] = 255
+ new_image = numpy.empty((image.shape[0], image.shape[1], 4), dtype=numpy.uint8)
+ new_image[:, :, :3] = image
+ new_image[:, :, 3] = 255
return new_image # This is a copy anyway
else:
return numpy.array(image, copy=copy)
@@ -100,11 +103,17 @@ class ImageBase(DataItem, LabelsMixIn, DraggableMixIn, AlphaMixIn):
self._data = data
self._mask = mask
self.__valueDataCache = None # Store default data
- self._origin = (0., 0.)
- self._scale = (1., 1.)
+ self._origin = (0.0, 0.0)
+ self._scale = (1.0, 1.0)
def __getitem__(self, item):
"""Compatibility with PyMca and silx <= 0.4.0"""
+ deprecated_warning(
+ "Attributes",
+ "__getitem__",
+ since_version="2.0.0",
+ replacement="Use ImageBase methods",
+ )
if isinstance(item, slice):
return [self[index] for index in range(*item.indices(5))]
elif item == 0:
@@ -118,15 +127,15 @@ class ImageBase(DataItem, LabelsMixIn, DraggableMixIn, AlphaMixIn):
return None
elif item == 4:
params = {
- 'info': self.getInfo(),
- 'origin': self.getOrigin(),
- 'scale': self.getScale(),
- 'z': self.getZValue(),
- 'selectable': self.isSelectable(),
- 'draggable': self.isDraggable(),
- 'colormap': None,
- 'xlabel': self.getXLabel(),
- 'ylabel': self.getYLabel(),
+ "info": self.getInfo(),
+ "origin": self.getOrigin(),
+ "scale": self.getScale(),
+ "z": self.getZValue(),
+ "selectable": self.isSelectable(),
+ "draggable": self.isDraggable(),
+ "colormap": None,
+ "xlabel": self.getXLabel(),
+ "ylabel": self.getYLabel(),
}
return params
else:
@@ -167,8 +176,7 @@ class ImageBase(DataItem, LabelsMixIn, DraggableMixIn, AlphaMixIn):
@docstring(DraggableMixIn)
def drag(self, from_, to):
origin = self.getOrigin()
- self.setOrigin((origin[0] + to[0] - from_[0],
- origin[1] + to[1] - from_[1]))
+ self.setOrigin((origin[0] + to[0] - from_[0], origin[1] + to[1] - from_[1]))
def getData(self, copy=True):
"""Returns the image data
@@ -190,8 +198,10 @@ class ImageBase(DataItem, LabelsMixIn, DraggableMixIn, AlphaMixIn):
self._boundsChanged()
self._updated(ItemChangedType.DATA)
- if (self.getMaskData(copy=False) is not None and
- previousShape != self._data.shape):
+ if (
+ self.getMaskData(copy=False) is not None
+ and previousShape != self._data.shape
+ ):
# Data shape changed, so mask shape changes.
# Send event, mask is lazily updated in getMaskData
self._updated(ItemChangedType.MASK)
@@ -211,7 +221,9 @@ class ImageBase(DataItem, LabelsMixIn, DraggableMixIn, AlphaMixIn):
if self._mask.shape != shape:
# Clip/extend mask to match data
newMask = numpy.zeros(shape, dtype=self._mask.dtype)
- newMask[:self._mask.shape[0], :self._mask.shape[1]] = self._mask[:shape[0], :shape[1]]
+ newMask[: self._mask.shape[0], : self._mask.shape[1]] = self._mask[
+ : shape[0], : shape[1]
+ ]
self._mask = newMask
return numpy.array(self._mask, copy=copy)
@@ -228,7 +240,9 @@ class ImageBase(DataItem, LabelsMixIn, DraggableMixIn, AlphaMixIn):
shape = self.getData(copy=False).shape[:2]
if mask.shape != shape:
- _logger.warning("Inconsistent shape between mask and data %s, %s", mask.shape, shape)
+ _logger.warning(
+ "Inconsistent shape between mask and data %s, %s", mask.shape, shape
+ )
# Clip/extent is done lazily in getMaskData
elif self._mask is None:
return # No update
@@ -278,7 +292,7 @@ class ImageBase(DataItem, LabelsMixIn, DraggableMixIn, AlphaMixIn):
False to use internal representation (do not modify!)
:returns: numpy.ndarray of uint8 of shape (height, width, 4)
"""
- raise NotImplementedError('This MUST be implemented in sub-class')
+ raise NotImplementedError("This MUST be implemented in sub-class")
def getOrigin(self):
"""Returns the offset from origin at which to display the image.
@@ -336,9 +350,11 @@ class ImageDataBase(ImageBase, ColormapMixIn):
def _getColormapForRendering(self):
colormap = self.getColormap()
if colormap.isAutoscale():
+ # NOTE: Make sure getColormapRange comes from the original object
+ vrange = colormap.getColormapRange(self)
# Avoid backend to compute autoscale: use item cache
colormap = colormap.copy()
- colormap.setVRange(*colormap.getColormapRange(self))
+ colormap.setVRange(*vrange)
return colormap
def getRgbaImageData(self, copy=True):
@@ -350,7 +366,7 @@ class ImageDataBase(ImageBase, ColormapMixIn):
return self.getColormap().applyToData(self)
def setData(self, data, copy=True):
- """"Set the image data
+ """Set the image data
:param numpy.ndarray data: Data array with 2 dimensions (h, w)
:param bool copy: True (Default) to get a copy,
@@ -358,13 +374,11 @@ class ImageDataBase(ImageBase, ColormapMixIn):
"""
data = numpy.array(data, copy=copy)
assert data.ndim == 2
- if data.dtype.kind == 'b':
- _logger.warning(
- 'Converting boolean image to int8 to plot it.')
+ if data.dtype.kind == "b":
+ _logger.warning("Converting boolean image to int8 to plot it.")
data = numpy.array(data, copy=False, dtype=numpy.int8)
elif numpy.iscomplexobj(data):
- _logger.warning(
- 'Converting complex image to absolute value to plot it.')
+ _logger.warning("Converting complex image to absolute value to plot it.")
data = numpy.absolute(data)
super().setData(data)
@@ -391,8 +405,10 @@ class ImageData(ImageDataBase):
# Do not render with non linear scales
return None
- if (self.getAlternativeImageData(copy=False) is not None or
- self.getAlphaData(copy=False) is not None):
+ if (
+ self.getAlternativeImageData(copy=False) is not None
+ or self.getAlphaData(copy=False) is not None
+ ):
dataToUse = self.getRgbaImageData(copy=False)
else:
dataToUse = self.getData(copy=False)
@@ -400,20 +416,28 @@ class ImageData(ImageDataBase):
if dataToUse.size == 0:
return None # No data to display
- return backend.addImage(dataToUse,
- origin=self.getOrigin(),
- scale=self.getScale(),
- colormap=self._getColormapForRendering(),
- alpha=self.getAlpha())
+ return backend.addImage(
+ dataToUse,
+ origin=self.getOrigin(),
+ scale=self.getScale(),
+ colormap=self._getColormapForRendering(),
+ alpha=self.getAlpha(),
+ )
def __getitem__(self, item):
"""Compatibility with PyMca and silx <= 0.4.0"""
+ deprecated_warning(
+ "Attributes",
+ "__getitem__",
+ since_version="2.0.0",
+ replacement="Use ImageData methods",
+ )
if item == 3:
return self.getAlternativeImageData(copy=False)
params = ImageBase.__getitem__(self, item)
if item == 4:
- params['colormap'] = self.getColormap()
+ params["colormap"] = self.getColormap()
return params
@@ -431,7 +455,7 @@ class ImageData(ImageDataBase):
alphaImage = self.getAlphaData(copy=False)
if alphaImage is not None:
# Apply transparency
- image[:,:, 3] = image[:,:, 3] * alphaImage
+ image[:, :, 3] = image[:, :, 3] * alphaImage
return image
def getAlternativeImageData(self, copy=True):
@@ -459,7 +483,7 @@ class ImageData(ImageDataBase):
return numpy.array(self.__alpha, copy=copy)
def setData(self, data, alternative=None, alpha=None, copy=True):
- """"Set the image data and optionally an alternative RGB(A) representation
+ """Set the image data and optionally an alternative RGB(A) representation
:param numpy.ndarray data: Data array with 2 dimensions (h, w)
:param alternative: RGB(A) image to display instead of data,
@@ -484,10 +508,10 @@ class ImageData(ImageDataBase):
if alpha is not None:
alpha = numpy.array(alpha, copy=copy)
assert alpha.shape == data.shape
- if alpha.dtype.kind != 'f':
+ if alpha.dtype.kind != "f":
alpha = alpha.astype(numpy.float32)
- if numpy.any(numpy.logical_or(alpha < 0., alpha > 1.)):
- alpha = numpy.clip(alpha, 0., 1.)
+ if numpy.any(numpy.logical_or(alpha < 0.0, alpha > 1.0)):
+ alpha = numpy.clip(alpha, 0.0, 1.0)
self.__alpha = alpha
super().setData(data)
@@ -512,11 +536,13 @@ class ImageRgba(ImageBase):
if data.size == 0:
return None # No data to display
- return backend.addImage(data,
- origin=self.getOrigin(),
- scale=self.getScale(),
- colormap=None,
- alpha=self.getAlpha())
+ return backend.addImage(
+ data,
+ origin=self.getOrigin(),
+ scale=self.getScale(),
+ colormap=None,
+ alpha=self.getAlpha(),
+ )
def getRgbaImageData(self, copy=True):
"""Get the displayed RGB(A) image
@@ -533,8 +559,14 @@ class ImageRgba(ImageBase):
False to use internal representation (do not modify!)
"""
data = numpy.array(data, copy=copy)
- assert data.ndim == 3
- assert data.shape[-1] in (3, 4)
+ if data.ndim != 3:
+ raise ValueError(
+ f"RGB(A) image is expected to be a 3D dataset. Got {data.ndim} dimensions"
+ )
+ if data.shape[-1] not in (3, 4):
+ raise ValueError(
+ f"RGB(A) image is expected to have 3 or 4 elements as last dimension. Got {data.shape[-1]}"
+ )
super().setData(data)
def _getValueData(self, copy=True):
@@ -545,10 +577,10 @@ class ImageRgba(ImageBase):
:param bool copy:
"""
rgba = self.getRgbaImageData(copy=False).astype(numpy.float32)
- intensity = (rgba[:, :, 0] * 0.299 +
- rgba[:, :, 1] * 0.587 +
- rgba[:, :, 2] * 0.114)
- intensity *= rgba[:, :, 3] / 255.
+ intensity = (
+ rgba[:, :, 0] * 0.299 + rgba[:, :, 1] * 0.587 + rgba[:, :, 2] * 0.114
+ )
+ intensity *= rgba[:, :, 3] / 255.0
return intensity
@@ -558,6 +590,7 @@ class MaskImageData(ImageData):
This class is used to flag mask items. This information is used to improve
internal silx widgets.
"""
+
pass
diff --git a/src/silx/gui/plot/items/image_aggregated.py b/src/silx/gui/plot/items/image_aggregated.py
index ffd41b2..b35e00a 100644
--- a/src/silx/gui/plot/items/image_aggregated.py
+++ b/src/silx/gui/plot/items/image_aggregated.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2021 European Synchrotron Radiation Facility
+# Copyright (c) 2021-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -31,6 +31,7 @@ __date__ = "07/07/2021"
import enum
import logging
from typing import Tuple, Union
+import warnings
import numpy
@@ -68,7 +69,7 @@ class ImageDataAggregated(ImageDataBase):
self.__currentLOD = 0, 0
self.__aggregationMode = self.Aggregation.NONE
- def setAggregationMode(self, mode: Union[str,Aggregation]):
+ def setAggregationMode(self, mode: Union[str, Aggregation]):
"""Set the aggregation method used to reduce the data to screen resolution.
:param Aggregation mode: The aggregation method
@@ -115,12 +116,14 @@ class ImageDataAggregated(ImageDataBase):
if (lodx, lody) not in self.__cacheLODData:
height, width = data.shape
- self.__cacheLODData[(lodx, lody)] = aggregator(
- data[: (height // lody) * lody, : (width // lodx) * lodx].reshape(
- height // lody, lody, width // lodx, lodx
- ),
- axis=(1, 3),
- )
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", category=RuntimeWarning)
+ self.__cacheLODData[(lodx, lody)] = aggregator(
+ data[
+ : (height // lody) * lody, : (width // lodx) * lodx
+ ].reshape(height // lody, lody, width // lodx, lodx),
+ axis=(1, 3),
+ )
self.__currentLOD = lodx, lody
displayedData = self.__cacheLODData[self.__currentLOD]
@@ -153,10 +156,7 @@ class ImageDataAggregated(ImageDataBase):
xaxis = plot.getXAxis()
yaxis = plot.getYAxis(axis)
- if (
- xaxis.getScale() != Axis.LINEAR
- or yaxis.getScale() != Axis.LINEAR
- ):
+ if xaxis.getScale() != Axis.LINEAR or yaxis.getScale() != Axis.LINEAR:
raise RuntimeError("Only available with linear axes")
xmin, xmax = xaxis.getLimits()
@@ -200,8 +200,10 @@ class ImageDataAggregated(ImageDataBase):
def __plotLimitsChanged(self):
"""Trigger update if level of details has changed"""
- if (self.getAggregationMode() != self.Aggregation.NONE and
- self.__currentLOD != self._getLevelOfDetails()):
+ if (
+ self.getAggregationMode() != self.Aggregation.NONE
+ and self.__currentLOD != self._getLevelOfDetails()
+ ):
self._updated()
@docstring(ImageDataBase)
diff --git a/src/silx/gui/plot/items/marker.py b/src/silx/gui/plot/items/marker.py
index 7596eb0..b3da451 100755
--- a/src/silx/gui/plot/items/marker.py
+++ b/src/silx/gui/plot/items/marker.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2017-2020 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -23,6 +23,7 @@
# ###########################################################################*/
"""This module provides markers item of the :class:`Plot`.
"""
+from __future__ import annotations
__authors__ = ["T. Vincent"]
__license__ = "MIT"
@@ -30,11 +31,22 @@ __date__ = "06/03/2017"
import logging
+import numpy
from ....utils.proxy import docstring
-from .core import (Item, DraggableMixIn, ColorMixIn, LineMixIn, SymbolMixIn,
- ItemChangedType, YAxisMixIn)
+from .core import (
+ Item,
+ DraggableMixIn,
+ ColorMixIn,
+ LineMixIn,
+ SymbolMixIn,
+ ItemChangedType,
+ YAxisMixIn,
+)
+from silx import config
from silx.gui import qt
+from silx.gui import colors
+
_logger = logging.getLogger(__name__)
@@ -47,7 +59,7 @@ class MarkerBase(Item, DraggableMixIn, ColorMixIn, YAxisMixIn):
sigDragFinished = qt.Signal()
"""Signal emitted when the marker is released"""
- _DEFAULT_COLOR = (0., 0., 0., 1.)
+ _DEFAULT_COLOR = (0.0, 0.0, 0.0, 1.0)
"""Default color of the markers"""
def __init__(self):
@@ -56,14 +68,21 @@ class MarkerBase(Item, DraggableMixIn, ColorMixIn, YAxisMixIn):
ColorMixIn.__init__(self)
YAxisMixIn.__init__(self)
- self._text = ''
+ self._text = ""
+ self._font = None
+ if config.DEFAULT_PLOT_MARKER_TEXT_FONT_SIZE is not None:
+ self._font = qt.QFont(
+ qt.QApplication.instance().font().family(),
+ config.DEFAULT_PLOT_MARKER_TEXT_FONT_SIZE,
+ )
+
self._x = None
self._y = None
+ self._bgColor: colors.RGBAColorType | None = None
self._constraint = self._defaultConstraint
self.__isBeingDragged = False
- def _addRendererCall(self, backend,
- symbol=None, linestyle='-', linewidth=1):
+ def _addRendererCall(self, backend, symbol=None, linestyle="-", linewidth=1):
"""Perform the update of the backend renderer"""
return backend.addMarker(
x=self.getXPosition(),
@@ -74,7 +93,10 @@ class MarkerBase(Item, DraggableMixIn, ColorMixIn, YAxisMixIn):
linestyle=linestyle,
linewidth=linewidth,
constraint=self.getConstraint(),
- yaxis=self.getYAxis())
+ yaxis=self.getYAxis(),
+ font=self._font, # Do not use getFont to spare creating a new QFont
+ bgcolor=self.getBackgroundColor(),
+ )
def _addBackendRenderer(self, backend):
"""Update backend renderer"""
@@ -108,6 +130,39 @@ class MarkerBase(Item, DraggableMixIn, ColorMixIn, YAxisMixIn):
self._text = text
self._updated(ItemChangedType.TEXT)
+ def getFont(self) -> qt.QFont | None:
+ """Returns a copy of the QFont used to render text.
+
+ To modify the text font, use :meth:`setFont`.
+ """
+ return None if self._font is None else qt.QFont(self._font)
+
+ def setFont(self, font: qt.QFont | None):
+ """Set the QFont used to render text, use None for default.
+
+ A copy is stored, so further modification of the provided font are not taken into account.
+ """
+ if font != self._font:
+ self._font = None if font is None else qt.QFont(font)
+ self._updated(ItemChangedType.FONT)
+
+ def getBackgroundColor(self) -> colors.RGBAColorType | None:
+ """Returns the RGBA background color of the item"""
+ return self._bgColor
+
+ def setBackgroundColor(self, color):
+ """Set item text background color
+
+ :param color: color(s) to be used as a str ("#RRGGBB") or (npoints, 4)
+ unsigned byte array or one of the predefined color names
+ defined in colors.py
+ """
+ if color is not None:
+ color = colors.rgba(color)
+ if self._bgColor != color:
+ self._bgColor = color
+ self._updated(ItemChangedType.BACKGROUND_COLOR)
+
def getXPosition(self):
"""Returns the X position of the marker line in data coordinates
@@ -122,14 +177,14 @@ class MarkerBase(Item, DraggableMixIn, ColorMixIn, YAxisMixIn):
"""
return self._y
- def getPosition(self):
+ def getPosition(self) -> tuple[float | None, float | None]:
"""Returns the (x, y) position of the marker in data coordinates
:rtype: 2-tuple of float or None
"""
return self._x, self._y
- def setPosition(self, x, y):
+ def setPosition(self, x: float, y: float):
"""Set marker position in data coordinates
Constraint are applied if any.
@@ -188,15 +243,15 @@ class MarkerBase(Item, DraggableMixIn, ColorMixIn, YAxisMixIn):
class Marker(MarkerBase, SymbolMixIn):
"""Description of a marker"""
- _DEFAULT_SYMBOL = '+'
+ _DEFAULT_SYMBOL = "+"
"""Default symbol of the marker"""
def __init__(self):
MarkerBase.__init__(self)
SymbolMixIn.__init__(self)
- self._x = 0.
- self._y = 0.
+ self._x = 0.0
+ self._y = 0.0
def _addBackendRenderer(self, backend):
return self._addRendererCall(backend, symbol=self.getSymbol())
@@ -209,9 +264,9 @@ class Marker(MarkerBase, SymbolMixIn):
:param constraint: The constraint of the dragging of this marker
:type: constraint: callable or str
"""
- if constraint == 'horizontal':
+ if constraint == "horizontal":
constraint = self._horizontalConstraint
- elif constraint == 'vertical':
+ elif constraint == "vertical":
constraint = self._verticalConstraint
super(Marker, self)._setConstraint(constraint)
@@ -231,9 +286,9 @@ class _LineMarker(MarkerBase, LineMixIn):
LineMixIn.__init__(self)
def _addBackendRenderer(self, backend):
- return self._addRendererCall(backend,
- linestyle=self.getLineStyle(),
- linewidth=self.getLineWidth())
+ return self._addRendererCall(
+ backend, linestyle=self.getLineStyle(), linewidth=self.getLineWidth()
+ )
class XMarker(_LineMarker):
@@ -241,7 +296,7 @@ class XMarker(_LineMarker):
def __init__(self):
_LineMarker.__init__(self)
- self._x = 0.
+ self._x = 0.0
def setPosition(self, x, y):
"""Set marker line position in data coordinates
@@ -263,7 +318,7 @@ class YMarker(_LineMarker):
def __init__(self):
_LineMarker.__init__(self)
- self._y = 0.
+ self._y = 0.0
def setPosition(self, x, y):
"""Set marker line position in data coordinates
diff --git a/src/silx/gui/plot/items/roi.py b/src/silx/gui/plot/items/roi.py
index 559e7e0..7390b88 100644
--- a/src/silx/gui/plot/items/roi.py
+++ b/src/silx/gui/plot/items/roi.py
@@ -35,6 +35,7 @@ __date__ = "28/06/2018"
import logging
import numpy
+from typing import Tuple
from ... import utils
from .. import items
@@ -60,15 +61,15 @@ logger = logging.getLogger(__name__)
class PointROI(RegionOfInterest, items.SymbolMixIn):
"""A ROI identifying a point in a 2D plot."""
- ICON = 'add-shape-point'
- NAME = 'point markers'
+ ICON = "add-shape-point"
+ NAME = "point markers"
SHORT_NAME = "point"
"""Metadata for this kind of ROI"""
_plotShape = "point"
"""Plot shape which is used for the first interaction"""
- _DEFAULT_SYMBOL = '+'
+ _DEFAULT_SYMBOL = "+"
"""Default symbol of the PointROI
It overwrite the `SymbolMixIn` class attribte.
@@ -88,30 +89,26 @@ class PointROI(RegionOfInterest, items.SymbolMixIn):
self.setPosition(points[0])
def _updated(self, event=None, checkVisibility=True):
- if event == items.ItemChangedType.NAME:
- label = self.getName()
- self._marker.setText(label)
- elif event == items.ItemChangedType.EDITABLE:
+ if event == items.ItemChangedType.EDITABLE:
self._marker._setDraggable(self.isEditable())
- elif event in [items.ItemChangedType.VISIBLE,
- items.ItemChangedType.SELECTABLE]:
+ elif event in [items.ItemChangedType.VISIBLE, items.ItemChangedType.SELECTABLE]:
self._updateItemProperty(event, self, self._marker)
super(PointROI, self)._updated(event, checkVisibility)
+ def _updateText(self, text: str):
+ self._marker.setText(text)
+
def _updatedStyle(self, event, style):
self._marker.setColor(style.getColor())
- def getPosition(self):
- """Returns the position of this ROI
-
- :rtype: numpy.ndarray
- """
+ def getPosition(self) -> Tuple[float, float]:
+ """Returns the position of this ROI"""
return self._marker.getPosition()
def setPosition(self, pos):
"""Set the position of this ROI
- :param numpy.ndarray pos: 2d-coordinate of this point
+ :param pos: 2d-coordinate of this point
"""
self._marker.setPosition(*pos)
@@ -126,16 +123,15 @@ class PointROI(RegionOfInterest, items.SymbolMixIn):
self.sigRegionChanged.emit()
def __str__(self):
- params = '%f %f' % self.getPosition()
+ params = "%f %f" % self.getPosition()
return "%s(%s)" % (self.__class__.__name__, params)
class CrossROI(HandleBasedROI, items.LineMixIn):
- """A ROI identifying a point in a 2D plot and displayed as a cross
- """
+ """A ROI identifying a point in a 2D plot and displayed as a cross"""
- ICON = 'add-shape-cross'
- NAME = 'cross marker'
+ ICON = "add-shape-cross"
+ NAME = "cross marker"
SHORT_NAME = "cross"
"""Metadata for this kind of ROI"""
@@ -177,17 +173,14 @@ class CrossROI(HandleBasedROI, items.LineMixIn):
pos = points[0]
self.setPosition(pos)
- def getPosition(self):
- """Returns the position of this ROI
-
- :rtype: numpy.ndarray
- """
+ def getPosition(self) -> Tuple[float, float]:
+ """Returns the position of this ROI"""
return self._handle.getPosition()
- def setPosition(self, pos):
+ def setPosition(self, pos: Tuple[float, float]):
"""Set the position of this ROI
- :param numpy.ndarray pos: 2d-coordinate of this point
+ :param pos: 2d-coordinate of this point
"""
self._handle.setPosition(*pos)
@@ -213,8 +206,8 @@ class LineROI(HandleBasedROI, items.LineMixIn):
in the center to translate the full ROI.
"""
- ICON = 'add-shape-diagonal'
- NAME = 'line ROI'
+ ICON = "add-shape-diagonal"
+ NAME = "line ROI"
SHORT_NAME = "line"
"""Metadata for this kind of ROI"""
@@ -244,11 +237,12 @@ class LineROI(HandleBasedROI, items.LineMixIn):
self._updateItemProperty(event, self, self.__shape)
super(LineROI, self)._updated(event, checkVisibility)
- def _updatedStyle(self, event, style):
+ def _updatedStyle(self, event, style: items.CurveStyle):
super(LineROI, self)._updatedStyle(event, style)
self.__shape.setColor(style.getColor())
self.__shape.setLineStyle(style.getLineStyle())
self.__shape.setLineWidth(style.getLineWidth())
+ self.__shape.setLineGapColor(style.getLineGapColor())
def setFirstShapePoints(self, points):
assert len(points) == 2
@@ -257,7 +251,7 @@ class LineROI(HandleBasedROI, items.LineMixIn):
def _updateText(self, text):
self._handleLabel.setText(text)
- def setEndPoints(self, startPoint, endPoint):
+ def setEndPoints(self, startPoint: numpy.ndarray, endPoint: numpy.ndarray):
"""Set this line location using the ending points
:param numpy.ndarray startPoint: Staring bounding point of the line
@@ -266,7 +260,7 @@ class LineROI(HandleBasedROI, items.LineMixIn):
if not numpy.array_equal((startPoint, endPoint), self.getEndPoints()):
self.__updateEndPoints(startPoint, endPoint)
- def __updateEndPoints(self, startPoint, endPoint):
+ def __updateEndPoints(self, startPoint: numpy.ndarray, endPoint: numpy.ndarray):
"""Update marker and shape to match given end points
:param numpy.ndarray startPoint: Staring bounding point of the line
@@ -328,28 +322,44 @@ class LineROI(HandleBasedROI, items.LineMixIn):
return False
return (
- segments_intersection(seg1_start_pt=line_pt1, seg1_end_pt=line_pt2,
- seg2_start_pt=bottom_left, seg2_end_pt=bottom_right) or
- segments_intersection(seg1_start_pt=line_pt1, seg1_end_pt=line_pt2,
- seg2_start_pt=bottom_right, seg2_end_pt=top_right) or
- segments_intersection(seg1_start_pt=line_pt1, seg1_end_pt=line_pt2,
- seg2_start_pt=top_right, seg2_end_pt=top_left) or
- segments_intersection(seg1_start_pt=line_pt1, seg1_end_pt=line_pt2,
- seg2_start_pt=top_left, seg2_end_pt=bottom_left)
+ segments_intersection(
+ seg1_start_pt=line_pt1,
+ seg1_end_pt=line_pt2,
+ seg2_start_pt=bottom_left,
+ seg2_end_pt=bottom_right,
+ )
+ or segments_intersection(
+ seg1_start_pt=line_pt1,
+ seg1_end_pt=line_pt2,
+ seg2_start_pt=bottom_right,
+ seg2_end_pt=top_right,
+ )
+ or segments_intersection(
+ seg1_start_pt=line_pt1,
+ seg1_end_pt=line_pt2,
+ seg2_start_pt=top_right,
+ seg2_end_pt=top_left,
+ )
+ or segments_intersection(
+ seg1_start_pt=line_pt1,
+ seg1_end_pt=line_pt2,
+ seg2_start_pt=top_left,
+ seg2_end_pt=bottom_left,
+ )
) is not None
def __str__(self):
start, end = self.getEndPoints()
params = start[0], start[1], end[0], end[1]
- params = 'start: %f %f; end: %f %f' % params
+ params = "start: %f %f; end: %f %f" % params
return "%s(%s)" % (self.__class__.__name__, params)
class HorizontalLineROI(RegionOfInterest, items.LineMixIn):
"""A ROI identifying an horizontal line in a 2D plot."""
- ICON = 'add-shape-horizontal'
- NAME = 'horizontal line ROI'
+ ICON = "add-shape-horizontal"
+ NAME = "horizontal line ROI"
SHORT_NAME = "hline"
"""Metadata for this kind of ROI"""
@@ -366,16 +376,15 @@ class HorizontalLineROI(RegionOfInterest, items.LineMixIn):
self.addItem(self._marker)
def _updated(self, event=None, checkVisibility=True):
- if event == items.ItemChangedType.NAME:
- label = self.getName()
- self._marker.setText(label)
- elif event == items.ItemChangedType.EDITABLE:
+ if event == items.ItemChangedType.EDITABLE:
self._marker._setDraggable(self.isEditable())
- elif event in [items.ItemChangedType.VISIBLE,
- items.ItemChangedType.SELECTABLE]:
+ elif event in [items.ItemChangedType.VISIBLE, items.ItemChangedType.SELECTABLE]:
self._updateItemProperty(event, self, self._marker)
super(HorizontalLineROI, self)._updated(event, checkVisibility)
+ def _updateText(self, text: str):
+ self._marker.setText(text)
+
def _updatedStyle(self, event, style):
self._marker.setColor(style.getColor())
self._marker.setLineStyle(style.getLineStyle())
@@ -387,18 +396,15 @@ class HorizontalLineROI(RegionOfInterest, items.LineMixIn):
return
self.setPosition(pos)
- def getPosition(self):
- """Returns the position of this line if the horizontal axis
-
- :rtype: float
- """
+ def getPosition(self) -> float:
+ """Returns the position of this line if the horizontal axis"""
pos = self._marker.getPosition()
return pos[1]
- def setPosition(self, pos):
+ def setPosition(self, pos: float):
"""Set the position of this ROI
- :param float pos: Horizontal position of this line
+ :param pos: Horizontal position of this line
"""
self._marker.setPosition(0, pos)
@@ -412,15 +418,15 @@ class HorizontalLineROI(RegionOfInterest, items.LineMixIn):
self.sigRegionChanged.emit()
def __str__(self):
- params = 'y: %f' % self.getPosition()
+ params = "y: %f" % self.getPosition()
return "%s(%s)" % (self.__class__.__name__, params)
class VerticalLineROI(RegionOfInterest, items.LineMixIn):
"""A ROI identifying a vertical line in a 2D plot."""
- ICON = 'add-shape-vertical'
- NAME = 'vertical line ROI'
+ ICON = "add-shape-vertical"
+ NAME = "vertical line ROI"
SHORT_NAME = "vline"
"""Metadata for this kind of ROI"""
@@ -437,16 +443,15 @@ class VerticalLineROI(RegionOfInterest, items.LineMixIn):
self.addItem(self._marker)
def _updated(self, event=None, checkVisibility=True):
- if event == items.ItemChangedType.NAME:
- label = self.getName()
- self._marker.setText(label)
- elif event == items.ItemChangedType.EDITABLE:
+ if event == items.ItemChangedType.EDITABLE:
self._marker._setDraggable(self.isEditable())
- elif event in [items.ItemChangedType.VISIBLE,
- items.ItemChangedType.SELECTABLE]:
+ elif event in [items.ItemChangedType.VISIBLE, items.ItemChangedType.SELECTABLE]:
self._updateItemProperty(event, self, self._marker)
super(VerticalLineROI, self)._updated(event, checkVisibility)
+ def _updateText(self, text: str):
+ self._marker.setText(text)
+
def _updatedStyle(self, event, style):
self._marker.setColor(style.getColor())
self._marker.setLineStyle(style.getLineStyle())
@@ -456,15 +461,12 @@ class VerticalLineROI(RegionOfInterest, items.LineMixIn):
pos = points[0, 0]
self.setPosition(pos)
- def getPosition(self):
- """Returns the position of this line if the horizontal axis
-
- :rtype: float
- """
+ def getPosition(self) -> float:
+ """Returns the position of this line if the horizontal axis"""
pos = self._marker.getPosition()
return pos[0]
- def setPosition(self, pos):
+ def setPosition(self, pos: float):
"""Set the position of this ROI
:param float pos: Horizontal position of this line
@@ -481,7 +483,7 @@ class VerticalLineROI(RegionOfInterest, items.LineMixIn):
self.sigRegionChanged.emit()
def __str__(self):
- params = 'x: %f' % self.getPosition()
+ params = "x: %f" % self.getPosition()
return "%s(%s)" % (self.__class__.__name__, params)
@@ -492,8 +494,8 @@ class RectangleROI(HandleBasedROI, items.LineMixIn):
center to translate the full ROI.
"""
- ICON = 'add-shape-rectangle'
- NAME = 'rectangle ROI'
+ ICON = "add-shape-rectangle"
+ NAME = "rectangle ROI"
SHORT_NAME = "rectangle"
"""Metadata for this kind of ROI"""
@@ -530,6 +532,7 @@ class RectangleROI(HandleBasedROI, items.LineMixIn):
self.__shape.setColor(style.getColor())
self.__shape.setLineStyle(style.getLineStyle())
self.__shape.setLineWidth(style.getLineWidth())
+ self.__shape.setLineGapColor(style.getLineGapColor())
def setFirstShapePoints(self, points):
assert len(points) == 2
@@ -598,11 +601,12 @@ class RectangleROI(HandleBasedROI, items.LineMixIn):
self.setGeometry(center=position, size=size)
def setGeometry(self, origin=None, size=None, center=None):
- """Set the geometry of the ROI
- """
- if ((origin is None or numpy.array_equal(origin, self.getOrigin())) and
- (center is None or numpy.array_equal(center, self.getCenter())) and
- numpy.array_equal(size, self.getSize())):
+ """Set the geometry of the ROI"""
+ if (
+ (origin is None or numpy.array_equal(origin, self.getOrigin()))
+ and (center is None or numpy.array_equal(center, self.getCenter()))
+ and numpy.array_equal(size, self.getSize())
+ ):
return # Nothing has changed
self._updateGeometry(origin, size, center)
@@ -661,17 +665,38 @@ class RectangleROI(HandleBasedROI, items.LineMixIn):
points = numpy.array([current, current2])
# Switch handles if they were crossed by interaction
- if self._handleBottomLeft.getXPosition() > self._handleBottomRight.getXPosition():
- self._handleBottomLeft, self._handleBottomRight = self._handleBottomRight, self._handleBottomLeft
+ if (
+ self._handleBottomLeft.getXPosition()
+ > self._handleBottomRight.getXPosition()
+ ):
+ self._handleBottomLeft, self._handleBottomRight = (
+ self._handleBottomRight,
+ self._handleBottomLeft,
+ )
if self._handleTopLeft.getXPosition() > self._handleTopRight.getXPosition():
- self._handleTopLeft, self._handleTopRight = self._handleTopRight, self._handleTopLeft
-
- if self._handleBottomLeft.getYPosition() > self._handleTopLeft.getYPosition():
- self._handleBottomLeft, self._handleTopLeft = self._handleTopLeft, self._handleBottomLeft
-
- if self._handleBottomRight.getYPosition() > self._handleTopRight.getYPosition():
- self._handleBottomRight, self._handleTopRight = self._handleTopRight, self._handleBottomRight
+ self._handleTopLeft, self._handleTopRight = (
+ self._handleTopRight,
+ self._handleTopLeft,
+ )
+
+ if (
+ self._handleBottomLeft.getYPosition()
+ > self._handleTopLeft.getYPosition()
+ ):
+ self._handleBottomLeft, self._handleTopLeft = (
+ self._handleTopLeft,
+ self._handleBottomLeft,
+ )
+
+ if (
+ self._handleBottomRight.getYPosition()
+ > self._handleTopRight.getYPosition()
+ ):
+ self._handleBottomRight, self._handleTopRight = (
+ self._handleTopRight,
+ self._handleBottomRight,
+ )
self._setBound(points)
@@ -679,7 +704,7 @@ class RectangleROI(HandleBasedROI, items.LineMixIn):
origin = self.getOrigin()
w, h = self.getSize()
params = origin[0], origin[1], w, h
- params = 'origin: %f %f; width: %f; height: %f' % params
+ params = "origin: %f %f; width: %f; height: %f" % params
return "%s(%s)" % (self.__class__.__name__, params)
@@ -690,8 +715,8 @@ class CircleROI(HandleBasedROI, items.LineMixIn):
and one anchor on the perimeter to change the radius.
"""
- ICON = 'add-shape-circle'
- NAME = 'circle ROI'
+ ICON = "add-shape-circle"
+ NAME = "circle ROI"
SHORT_NAME = "circle"
"""Metadata for this kind of ROI"""
@@ -731,6 +756,7 @@ class CircleROI(HandleBasedROI, items.LineMixIn):
self.__shape.setColor(style.getColor())
self.__shape.setLineStyle(style.getLineStyle())
self.__shape.setLineWidth(style.getLineWidth())
+ self.__shape.setLineGapColor(style.getLineGapColor())
def setFirstShapePoints(self, points):
assert len(points) == 2
@@ -779,8 +805,7 @@ class CircleROI(HandleBasedROI, items.LineMixIn):
self._updateGeometry()
def setGeometry(self, center, radius):
- """Set the geometry of the ROI
- """
+ """Set the geometry of the ROI"""
if numpy.array_equal(center, self.getCenter()):
self.setRadius(radius)
else:
@@ -797,8 +822,9 @@ class CircleROI(HandleBasedROI, items.LineMixIn):
nbpoints = 27
angles = numpy.arange(nbpoints) * 2.0 * numpy.pi / nbpoints
- circleShape = numpy.array((numpy.cos(angles) * self.__radius,
- numpy.sin(angles) * self.__radius)).T
+ circleShape = numpy.array(
+ (numpy.cos(angles) * self.__radius, numpy.sin(angles) * self.__radius)
+ ).T
circleShape += center
self.__shape.setPoints(circleShape)
self.sigRegionChanged.emit()
@@ -821,7 +847,7 @@ class CircleROI(HandleBasedROI, items.LineMixIn):
center = self.getCenter()
radius = self.getRadius()
params = center[0], center[1], radius
- params = 'center: %f %f; radius: %f;' % params
+ params = "center: %f %f; radius: %f;" % params
return "%s(%s)" % (self.__class__.__name__, params)
@@ -833,8 +859,8 @@ class EllipseROI(HandleBasedROI, items.LineMixIn):
minor-radius. These two anchors also allow to change the orientation.
"""
- ICON = 'add-shape-ellipse'
- NAME = 'ellipse ROI'
+ ICON = "add-shape-ellipse"
+ NAME = "ellipse ROI"
SHORT_NAME = "ellipse"
"""Metadata for this kind of ROI"""
@@ -860,8 +886,10 @@ class EllipseROI(HandleBasedROI, items.LineMixIn):
self.__shape = shape
self.addItem(shape)
- self._radius = 0., 0.
- self._orientation = 0. # angle in radians between the X-axis and the _handleAxis0
+ self._radius = 0.0, 0.0
+ self._orientation = (
+ 0.0 # angle in radians between the X-axis and the _handleAxis0
+ )
def _updated(self, event=None, checkVisibility=True):
if event == items.ItemChangedType.VISIBLE:
@@ -873,6 +901,7 @@ class EllipseROI(HandleBasedROI, items.LineMixIn):
self.__shape.setColor(style.getColor())
self.__shape.setLineStyle(style.getLineStyle())
self.__shape.setLineWidth(style.getLineWidth())
+ self.__shape.setLineGapColor(style.getLineGapColor())
def setFirstShapePoints(self, points):
assert len(points) == 2
@@ -905,9 +934,9 @@ class EllipseROI(HandleBasedROI, items.LineMixIn):
center = points[0]
radius = numpy.linalg.norm(points[0] - points[1])
orientation = self._calculateOrientation(points[0], points[1])
- self.setGeometry(center=center,
- radius=(radius, radius),
- orientation=orientation)
+ self.setGeometry(
+ center=center, radius=(radius, radius), orientation=orientation
+ )
def _updateText(self, text):
self._handleLabel.setText(text)
@@ -1007,10 +1036,11 @@ class EllipseROI(HandleBasedROI, items.LineMixIn):
# ensure that we store the orientation in range [0, 2*pi
orientation = numpy.mod(orientation, 2 * numpy.pi)
- if (numpy.array_equal(center, self.getCenter()) or
- radius != self._radius or
- orientation != self._orientation):
-
+ if (
+ numpy.array_equal(center, self.getCenter())
+ or radius != self._radius
+ or orientation != self._orientation
+ ):
# Update parameters directly
self._radius = radius
self._orientation = orientation
@@ -1030,10 +1060,18 @@ class EllipseROI(HandleBasedROI, items.LineMixIn):
# _handleAxis1 is the major axis
orientation -= numpy.pi / 2
- point0 = numpy.array([center[0] + self._radius[0] * numpy.cos(orientation),
- center[1] + self._radius[0] * numpy.sin(orientation)])
- point1 = numpy.array([center[0] - self._radius[1] * numpy.sin(orientation),
- center[1] + self._radius[1] * numpy.cos(orientation)])
+ point0 = numpy.array(
+ [
+ center[0] + self._radius[0] * numpy.cos(orientation),
+ center[1] + self._radius[0] * numpy.sin(orientation),
+ ]
+ )
+ point1 = numpy.array(
+ [
+ center[0] - self._radius[1] * numpy.sin(orientation),
+ center[1] + self._radius[1] * numpy.cos(orientation),
+ ]
+ )
with utils.blockSignals(self._handleAxis0):
self._handleAxis0.setPosition(*point0)
with utils.blockSignals(self._handleAxis1):
@@ -1043,10 +1081,12 @@ class EllipseROI(HandleBasedROI, items.LineMixIn):
nbpoints = 27
angles = numpy.arange(nbpoints) * 2.0 * numpy.pi / nbpoints
- X = (self._radius[0] * numpy.cos(angles) * numpy.cos(orientation)
- - self._radius[1] * numpy.sin(angles) * numpy.sin(orientation))
- Y = (self._radius[0] * numpy.cos(angles) * numpy.sin(orientation)
- + self._radius[1] * numpy.sin(angles) * numpy.cos(orientation))
+ X = self._radius[0] * numpy.cos(angles) * numpy.cos(orientation) - self._radius[
+ 1
+ ] * numpy.sin(angles) * numpy.sin(orientation)
+ Y = self._radius[0] * numpy.cos(angles) * numpy.sin(orientation) + self._radius[
+ 1
+ ] * numpy.sin(angles) * numpy.cos(orientation)
ellipseShape = numpy.array((X, Y)).T
ellipseShape += center
@@ -1083,8 +1123,10 @@ class EllipseROI(HandleBasedROI, items.LineMixIn):
major, minor = self.getMajorRadius(), self.getMinorRadius()
delta = self.getOrientation()
x, y = position - self.getCenter()
- return ((x*numpy.cos(delta) + y*numpy.sin(delta))**2/major**2 +
- (x*numpy.sin(delta) - y*numpy.cos(delta))**2/minor**2) <= 1
+ return (
+ (x * numpy.cos(delta) + y * numpy.sin(delta)) ** 2 / major**2
+ + (x * numpy.sin(delta) - y * numpy.cos(delta)) ** 2 / minor**2
+ ) <= 1
def __str__(self):
center = self.getCenter()
@@ -1092,7 +1134,10 @@ class EllipseROI(HandleBasedROI, items.LineMixIn):
minor = self.getMinorRadius()
orientation = self.getOrientation()
params = center[0], center[1], major, minor, orientation
- params = 'center: %f %f; major radius: %f: minor radius: %f; orientation: %f' % params
+ params = (
+ "center: %f %f; major radius: %f: minor radius: %f; orientation: %f"
+ % params
+ )
return "%s(%s)" % (self.__class__.__name__, params)
@@ -1102,8 +1147,8 @@ class PolygonROI(HandleBasedROI, items.LineMixIn):
This ROI provides 1 anchor for each point of the polygon.
"""
- ICON = 'add-shape-polygon'
- NAME = 'polygon ROI'
+ ICON = "add-shape-polygon"
+ NAME = "polygon ROI"
SHORT_NAME = "polygon"
"""Metadata for this kind of ROI"""
@@ -1134,6 +1179,7 @@ class PolygonROI(HandleBasedROI, items.LineMixIn):
self.__shape.setColor(style.getColor())
self.__shape.setLineStyle(style.getLineStyle())
self.__shape.setLineWidth(style.getLineWidth())
+ self.__shape.setLineGapColor(style.getLineGapColor())
if self._handleClose is not None:
color = self._computeHandleColor(style.getColor())
self._handleClose.setColor(color)
@@ -1156,8 +1202,7 @@ class PolygonROI(HandleBasedROI, items.LineMixIn):
self.setPoints(points)
def creationStarted(self):
- """"Called when the ROI creation interaction was started.
- """
+ """Called when the ROI creation interaction was started."""
# Handle to see where to close the polygon
self._handleClose = self.addUserHandle()
self._handleClose.setSymbol("o")
@@ -1178,8 +1223,7 @@ class PolygonROI(HandleBasedROI, items.LineMixIn):
return self._handleClose is not None
def creationFinalized(self):
- """"Called when the ROI creation interaction was finalized.
- """
+ """Called when the ROI creation interaction was finalized."""
self.removeHandle(self._handleClose)
self._handleClose = None
self.removeItem(self.__shape)
@@ -1206,7 +1250,7 @@ class PolygonROI(HandleBasedROI, items.LineMixIn):
:param numpy.ndarray pos: 2d-coordinate of this point
"""
- assert(len(points.shape) == 2 and points.shape[1] == 2)
+ assert len(points.shape) == 2 and points.shape[1] == 2
if numpy.array_equal(points, self._points):
return # Nothing has changed
@@ -1277,7 +1321,7 @@ class PolygonROI(HandleBasedROI, items.LineMixIn):
def __str__(self):
points = self._points
- params = '; '.join('%f %f' % (pt[0], pt[1]) for pt in points)
+ params = "; ".join("%f %f" % (pt[0], pt[1]) for pt in points)
return "%s(%s)" % (self.__class__.__name__, params)
@docstring(HandleBasedROI)
@@ -1300,8 +1344,8 @@ class PolygonROI(HandleBasedROI, items.LineMixIn):
class HorizontalRangeROI(RegionOfInterest, items.LineMixIn):
"""A ROI identifying an horizontal range in a 1D plot."""
- ICON = 'add-range-horizontal'
- NAME = 'horizontal range ROI'
+ ICON = "add-range-horizontal"
+ NAME = "horizontal range ROI"
SHORT_NAME = "hrange"
_plotShape = "line"
@@ -1333,16 +1377,13 @@ class HorizontalRangeROI(RegionOfInterest, items.LineMixIn):
self._updatePos(vmin, vmax)
def _updated(self, event=None, checkVisibility=True):
- if event == items.ItemChangedType.NAME:
- self._updateText()
- elif event == items.ItemChangedType.EDITABLE:
+ if event == items.ItemChangedType.EDITABLE:
self._updateEditable()
- self._updateText()
+ self._updateText(self.getText())
elif event == items.ItemChangedType.LINE_STYLE:
markers = [self._markerMin, self._markerMax]
self._updateItemProperty(event, self, markers)
- elif event in [items.ItemChangedType.VISIBLE,
- items.ItemChangedType.SELECTABLE]:
+ elif event in [items.ItemChangedType.VISIBLE, items.ItemChangedType.SELECTABLE]:
markers = [self._markerMin, self._markerMax, self._markerCen]
self._updateItemProperty(event, self, markers)
super(HorizontalRangeROI, self)._updated(event, checkVisibility)
@@ -1353,8 +1394,7 @@ class HorizontalRangeROI(RegionOfInterest, items.LineMixIn):
m.setColor(style.getColor())
m.setLineWidth(style.getLineWidth())
- def _updateText(self):
- text = self.getName()
+ def _updateText(self, text: str):
if self.isEditable():
self._markerMin.setText("")
self._markerCen.setText(text)
@@ -1409,8 +1449,10 @@ class HorizontalRangeROI(RegionOfInterest, items.LineMixIn):
err = "Can't set vmin or vmax to None"
raise ValueError(err)
if vmin > vmax:
- err = "Can't set vmin and vmax because vmin >= vmax " \
- "vmin = %s, vmax = %s" % (vmin, vmax)
+ err = (
+ "Can't set vmin and vmax because vmin >= vmax "
+ "vmin = %s, vmax = %s" % (vmin, vmax)
+ )
raise ValueError(err)
self._updatePos(vmin, vmax)
@@ -1515,5 +1557,5 @@ class HorizontalRangeROI(RegionOfInterest, items.LineMixIn):
def __str__(self):
vrange = self.getRange()
- params = 'min: %f; max: %f' % vrange
+ params = "min: %f; max: %f" % vrange
return "%s(%s)" % (self.__class__.__name__, params)
diff --git a/src/silx/gui/plot/items/scatter.py b/src/silx/gui/plot/items/scatter.py
index 96fb311..c46b60c 100644
--- a/src/silx/gui/plot/items/scatter.py
+++ b/src/silx/gui/plot/items/scatter.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2017-2022 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -33,6 +33,7 @@ from collections import namedtuple
import logging
import threading
import numpy
+from matplotlib.tri import LinearTriInterpolator, Triangulation
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, CancelledError
@@ -41,7 +42,6 @@ from ....utils.proxy import docstring
from ....math.combo import min_max
from ....math.histogram import Histogramnd
from ....utils.weakref import WeakList
-from .._utils.delaunay import delaunay
from .core import PointsBase, ColormapMixIn, ScatterVisualizationMixIn
from .axis import Axis
from ._pick import PickingResult
@@ -51,8 +51,7 @@ _logger = logging.getLogger(__name__)
class _GreedyThreadPoolExecutor(ThreadPoolExecutor):
- """:class:`ThreadPoolExecutor` with an extra :meth:`submit_greedy` method.
- """
+ """:class:`ThreadPoolExecutor` with an extra :meth:`submit_greedy` method."""
def __init__(self, *args, **kwargs):
super(_GreedyThreadPoolExecutor, self).__init__(*args, **kwargs)
@@ -76,8 +75,7 @@ class _GreedyThreadPoolExecutor(ThreadPoolExecutor):
if not future.done():
future.cancel()
- future = super(_GreedyThreadPoolExecutor, self).submit(
- fn, *args, **kwargs)
+ future = super(_GreedyThreadPoolExecutor, self).submit(fn, *args, **kwargs)
self.__futures[queue].append(future)
return future
@@ -85,6 +83,7 @@ class _GreedyThreadPoolExecutor(ThreadPoolExecutor):
# Functions to guess grid shape from coordinates
+
def _get_z_line_length(array):
"""Return length of line if array is a Z-like 2D regular grid.
@@ -97,7 +96,7 @@ def _get_z_line_length(array):
if len(sign) == 0 or sign[0] == 0: # We don't handle that
return 0
# Check this way to account for 0 sign (i.e., diff == 0)
- beginnings = numpy.where(sign == - sign[0])[0] + 1
+ beginnings = numpy.where(sign == -sign[0])[0] + 1
if len(beginnings) == 0:
return 0
length = beginnings[0]
@@ -121,11 +120,11 @@ def _guess_z_grid_shape(x, y):
"""
width = _get_z_line_length(x)
if width != 0:
- return 'row', (int(numpy.ceil(len(x) / width)), width)
+ return "row", (int(numpy.ceil(len(x) / width)), width)
else:
height = _get_z_line_length(y)
if height != 0:
- return 'column', (height, int(numpy.ceil(len(y) / height)))
+ return "column", (height, int(numpy.ceil(len(y) / height)))
return None
@@ -139,7 +138,7 @@ def is_monotonic(array):
:rtype: int
"""
diff = numpy.diff(numpy.ravel(array))
- with numpy.errstate(invalid='ignore'):
+ with numpy.errstate(invalid="ignore"):
if numpy.all(diff >= 0):
return 1
elif numpy.all(diff <= 0):
@@ -168,7 +167,7 @@ def _guess_grid(x, y):
else:
# Cannot guess a regular grid
# Let's assume it's a single line
- order = 'row' # or 'column' doesn't matter for a single line
+ order = "row" # or 'column' doesn't matter for a single line
y_monotonic = is_monotonic(y)
if is_monotonic(x) or y_monotonic: # we can guess a line
x_min, x_max = min_max(x)
@@ -211,18 +210,24 @@ def _quadrilateral_grid_coords(points):
neighbour_view = numpy.lib.stride_tricks.as_strided(
points,
shape=(dim0 - 1, dim1 - 1, 2, 2, points.shape[2]),
- strides=points.strides[:2] + points.strides[:2] + points.strides[-1:], writeable=False)
+ strides=points.strides[:2] + points.strides[:2] + points.strides[-1:],
+ writeable=False,
+ )
inner_points = numpy.mean(neighbour_view, axis=(2, 3))
grid_points[1:-1, 1:-1] = inner_points
# Compute 'vertical' sides
# Alternative: grid_points[1:-1, [0, -1]] = points[:-1, [0, -1]] + points[1:, [0, -1]] - inner_points[:, [0, -1]]
- grid_points[1:-1, [0, -1], 0] = points[:-1, [0, -1], 0] + points[1:, [0, -1], 0] - inner_points[:, [0, -1], 0]
+ grid_points[1:-1, [0, -1], 0] = (
+ points[:-1, [0, -1], 0] + points[1:, [0, -1], 0] - inner_points[:, [0, -1], 0]
+ )
grid_points[1:-1, [0, -1], 1] = inner_points[:, [0, -1], 1]
# Compute 'horizontal' sides
grid_points[[0, -1], 1:-1, 0] = inner_points[[0, -1], :, 0]
- grid_points[[0, -1], 1:-1, 1] = points[[0, -1], :-1, 1] + points[[0, -1], 1:, 1] - inner_points[[0, -1], :, 1]
+ grid_points[[0, -1], 1:-1, 1] = (
+ points[[0, -1], :-1, 1] + points[[0, -1], 1:, 1] - inner_points[[0, -1], :, 1]
+ )
# Compute corners
d0, d1 = [0, 0, -1, -1], [0, -1, -1, 0]
@@ -259,11 +264,13 @@ def _quadrilateral_grid_as_triangles(points):
_RegularGridInfo = namedtuple(
- '_RegularGridInfo', ['bounds', 'origin', 'scale', 'shape', 'order'])
+ "_RegularGridInfo", ["bounds", "origin", "scale", "shape", "order"]
+)
_HistogramInfo = namedtuple(
- '_HistogramInfo', ['mean', 'count', 'sum', 'origin', 'scale', 'shape'])
+ "_HistogramInfo", ["mean", "count", "sum", "origin", "scale", "shape"]
+)
class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
@@ -278,7 +285,7 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
ScatterVisualizationMixIn.Visualization.REGULAR_GRID,
ScatterVisualizationMixIn.Visualization.IRREGULAR_GRID,
ScatterVisualizationMixIn.Visualization.BINNED_STATISTIC,
- )
+ )
"""Overrides supported Visualizations"""
def __init__(self):
@@ -288,7 +295,7 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
self._value = ()
self.__alpha = None
# Cache Delaunay triangulation future object
- self.__delaunayFuture = None
+ self.__triangulationFuture = None
# Cache interpolator future object
self.__interpolatorFuture = None
self.__executor = None
@@ -310,7 +317,9 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
data = getattr(
histoInfo,
self.getVisualizationParameter(
- self.VisualizationParameter.BINNED_STATISTIC_FUNCTION))
+ self.VisualizationParameter.BINNED_STATISTIC_FUNCTION
+ ),
+ )
else:
data = self.getValueData(copy=False)
self._setColormappedData(data, copy=False)
@@ -319,8 +328,9 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
def setVisualization(self, mode):
previous = self.getVisualization()
if super().setVisualization(mode):
- if (bool(mode is self.Visualization.BINNED_STATISTIC) ^
- bool(previous is self.Visualization.BINNED_STATISTIC)):
+ if bool(mode is self.Visualization.BINNED_STATISTIC) ^ bool(
+ previous is self.Visualization.BINNED_STATISTIC
+ ):
self._updateColormappedData()
return True
else:
@@ -331,16 +341,22 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
parameter = self.VisualizationParameter.from_value(parameter)
if super(Scatter, self).setVisualizationParameter(parameter, value):
- if parameter in (self.VisualizationParameter.GRID_BOUNDS,
- self.VisualizationParameter.GRID_MAJOR_ORDER,
- self.VisualizationParameter.GRID_SHAPE):
+ if parameter in (
+ self.VisualizationParameter.GRID_BOUNDS,
+ self.VisualizationParameter.GRID_MAJOR_ORDER,
+ self.VisualizationParameter.GRID_SHAPE,
+ ):
self.__cacheRegularGridInfo = None
- if parameter in (self.VisualizationParameter.BINNED_STATISTIC_SHAPE,
- self.VisualizationParameter.BINNED_STATISTIC_FUNCTION,
- self.VisualizationParameter.DATA_BOUNDS_HINT):
- if parameter in (self.VisualizationParameter.BINNED_STATISTIC_SHAPE,
- self.VisualizationParameter.DATA_BOUNDS_HINT):
+ if parameter in (
+ self.VisualizationParameter.BINNED_STATISTIC_SHAPE,
+ self.VisualizationParameter.BINNED_STATISTIC_FUNCTION,
+ self.VisualizationParameter.DATA_BOUNDS_HINT,
+ ):
+ if parameter in (
+ self.VisualizationParameter.BINNED_STATISTIC_SHAPE,
+ self.VisualizationParameter.DATA_BOUNDS_HINT,
+ ):
self.__cacheHistogramInfo = None # Clean-up cache
if self.getVisualization() is self.Visualization.BINNED_STATISTIC:
self._updateColormappedData()
@@ -351,14 +367,16 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
@docstring(ScatterVisualizationMixIn)
def getCurrentVisualizationParameter(self, parameter):
value = self.getVisualizationParameter(parameter)
- if (parameter is self.VisualizationParameter.DATA_BOUNDS_HINT or
- value is not None):
+ if (
+ parameter is self.VisualizationParameter.DATA_BOUNDS_HINT
+ or value is not None
+ ):
return value # Value has been set, return it
elif parameter is self.VisualizationParameter.GRID_BOUNDS:
grid = self.__getRegularGridInfo()
return None if grid is None else grid.bounds
-
+
elif parameter is self.VisualizationParameter.GRID_MAJOR_ORDER:
grid = self.__getRegularGridInfo()
return None if grid is None else grid.order
@@ -378,15 +396,19 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
"""Get grid info"""
if self.__cacheRegularGridInfo is None:
shape = self.getVisualizationParameter(
- self.VisualizationParameter.GRID_SHAPE)
+ self.VisualizationParameter.GRID_SHAPE
+ )
order = self.getVisualizationParameter(
- self.VisualizationParameter.GRID_MAJOR_ORDER)
+ self.VisualizationParameter.GRID_MAJOR_ORDER
+ )
if shape is None or order is None:
- guess = _guess_grid(self.getXData(copy=False),
- self.getYData(copy=False))
+ guess = _guess_grid(
+ self.getXData(copy=False), self.getYData(copy=False)
+ )
if guess is None:
_logger.warning(
- 'Cannot guess a grid: Cannot display as regular grid image')
+ "Cannot guess a grid: Cannot display as regular grid image"
+ )
return None
if shape is None:
shape = guess[1]
@@ -397,16 +419,18 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
if nbpoints > shape[0] * shape[1]:
# More data points that provided grid shape: enlarge grid
_logger.warning(
- "More data points than provided grid shape size: extends grid")
+ "More data points than provided grid shape size: extends grid"
+ )
dim0, dim1 = shape
- if order == 'row': # keep dim1, enlarge dim0
+ if order == "row": # keep dim1, enlarge dim0
dim0 = nbpoints // dim1 + (1 if nbpoints % dim1 else 0)
else: # keep dim0, enlarge dim1
dim1 = nbpoints // dim0 + (1 if nbpoints % dim0 else 0)
shape = dim0, dim1
bounds = self.getVisualizationParameter(
- self.VisualizationParameter.GRID_BOUNDS)
+ self.VisualizationParameter.GRID_BOUNDS
+ )
if bounds is None:
x, y = self.getXData(copy=False), self.getYData(copy=False)
min_, max_ = min_max(x)
@@ -416,10 +440,12 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
bounds = (xRange[0], yRange[0]), (xRange[1], yRange[1])
begin, end = bounds
- scale = ((end[0] - begin[0]) / max(1, shape[1] - 1),
- (end[1] - begin[1]) / max(1, shape[0] - 1))
+ scale = (
+ (end[0] - begin[0]) / max(1, shape[1] - 1),
+ (end[1] - begin[1]) / max(1, shape[0] - 1),
+ )
if scale[0] == 0 and scale[1] == 0:
- scale = 1., 1.
+ scale = 1.0, 1.0
elif scale[0] == 0:
scale = scale[1], scale[1]
elif scale[1] == 0:
@@ -428,7 +454,8 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
origin = begin[0] - 0.5 * scale[0], begin[1] - 0.5 * scale[1]
self.__cacheRegularGridInfo = _RegularGridInfo(
- bounds=bounds, origin=origin, scale=scale, shape=shape, order=order)
+ bounds=bounds, origin=origin, scale=scale, shape=shape, order=order
+ )
return self.__cacheRegularGridInfo
@@ -436,9 +463,10 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
"""Get histogram info"""
if self.__cacheHistogramInfo is None:
shape = self.getVisualizationParameter(
- self.VisualizationParameter.BINNED_STATISTIC_SHAPE)
+ self.VisualizationParameter.BINNED_STATISTIC_SHAPE
+ )
if shape is None:
- shape = 100, 100 # TODO compute auto shape
+ shape = 100, 100 # TODO compute auto shape
x, y, values = self.getData(copy=False)[:3]
if len(x) == 0: # No histogram
@@ -451,31 +479,40 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
if not numpy.issubdtype(values.dtype, numpy.floating):
values = values.astype(numpy.float64)
- ranges = (tuple(min_max(y, finite=True)),
- tuple(min_max(x, finite=True)))
+ ranges = (tuple(min_max(y, finite=True)), tuple(min_max(x, finite=True)))
rangesHint = self.getVisualizationParameter(
- self.VisualizationParameter.DATA_BOUNDS_HINT)
+ self.VisualizationParameter.DATA_BOUNDS_HINT
+ )
if rangesHint is not None:
- ranges = tuple((min(dataMin, hintMin), max(dataMax, hintMax))
- for (dataMin, dataMax), (hintMin, hintMax) in zip(ranges, rangesHint))
+ ranges = tuple(
+ (min(dataMin, hintMin), max(dataMax, hintMax))
+ for (dataMin, dataMax), (hintMin, hintMax) in zip(
+ ranges, rangesHint
+ )
+ )
points = numpy.transpose(numpy.array((y, x)))
counts, sums, bin_edges = Histogramnd(
- points,
- histo_range=ranges,
- n_bins=shape,
- weights=values)
+ points, histo_range=ranges, n_bins=shape, weights=values
+ )
yEdges, xEdges = bin_edges
origin = xEdges[0], yEdges[0]
- scale = ((xEdges[-1] - xEdges[0]) / (len(xEdges) - 1),
- (yEdges[-1] - yEdges[0]) / (len(yEdges) - 1))
+ scale = (
+ (xEdges[-1] - xEdges[0]) / (len(xEdges) - 1),
+ (yEdges[-1] - yEdges[0]) / (len(yEdges) - 1),
+ )
- with numpy.errstate(divide='ignore', invalid='ignore'):
+ with numpy.errstate(divide="ignore", invalid="ignore"):
histo = sums / counts
self.__cacheHistogramInfo = _HistogramInfo(
- mean=histo, count=counts, sum=sums,
- origin=origin, scale=scale, shape=shape)
+ mean=histo,
+ count=counts,
+ sum=sums,
+ origin=origin,
+ scale=scale,
+ shape=shape,
+ )
return self.__cacheHistogramInfo
@@ -495,7 +532,8 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
"""Update backend renderer"""
# Filter-out values <= 0
xFiltered, yFiltered, valueFiltered, xerror, yerror = self.getData(
- copy=False, displayed=True)
+ copy=False, displayed=True
+ )
# Remove not finite numbers (this includes filtered out x, y <= 0)
mask = numpy.logical_and(numpy.isfinite(xFiltered), numpy.isfinite(yFiltered))
@@ -509,62 +547,79 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
if visualization is self.Visualization.BINNED_STATISTIC:
plot = self.getPlot()
- if (plot is None or
- plot.getXAxis().getScale() != Axis.LINEAR or
- plot.getYAxis().getScale() != Axis.LINEAR):
+ if (
+ plot is None
+ or plot.getXAxis().getScale() != Axis.LINEAR
+ or plot.getYAxis().getScale() != Axis.LINEAR
+ ):
# Those visualizations are not available with log scaled axes
return None
histoInfo = self.__getHistogramInfo()
if histoInfo is None:
return None
- data = getattr(histoInfo, self.getVisualizationParameter(
- self.VisualizationParameter.BINNED_STATISTIC_FUNCTION))
+ data = getattr(
+ histoInfo,
+ self.getVisualizationParameter(
+ self.VisualizationParameter.BINNED_STATISTIC_FUNCTION
+ ),
+ )
return backend.addImage(
data=data,
origin=histoInfo.origin,
scale=histoInfo.scale,
colormap=self.getColormap(),
- alpha=self.getAlpha())
+ alpha=self.getAlpha(),
+ )
elif visualization is self.Visualization.POINTS:
rgbacolors = self.__applyColormapToData()
- return backend.addCurve(xFiltered, yFiltered,
- color=rgbacolors[mask],
- symbol=self.getSymbol(),
- linewidth=0,
- linestyle="",
- yaxis='left',
- xerror=xerror,
- yerror=yerror,
- fill=False,
- alpha=self.getAlpha(),
- symbolsize=self.getSymbolSize(),
- baseline=None)
+ return backend.addCurve(
+ xFiltered,
+ yFiltered,
+ color=rgbacolors[mask],
+ gapcolor=None,
+ symbol=self.getSymbol(),
+ linewidth=0,
+ linestyle="",
+ yaxis="left",
+ xerror=xerror,
+ yerror=yerror,
+ fill=False,
+ alpha=self.getAlpha(),
+ symbolsize=self.getSymbolSize(),
+ baseline=None,
+ )
else:
plot = self.getPlot()
- if (plot is None or
- plot.getXAxis().getScale() != Axis.LINEAR or
- plot.getYAxis().getScale() != Axis.LINEAR):
+ if (
+ plot is None
+ or plot.getXAxis().getScale() != Axis.LINEAR
+ or plot.getYAxis().getScale() != Axis.LINEAR
+ ):
# Those visualizations are not available with log scaled axes
return None
if visualization is self.Visualization.SOLID:
- triangulation = self._getDelaunay().result()
- if triangulation is None:
+ try:
+ triangulation = self._getTriangulationFuture().result()
+ except (RuntimeError, ValueError):
_logger.warning(
- 'Cannot get a triangulation: Cannot display as solid surface')
+ "Cannot get a triangulation: Cannot display as solid surface"
+ )
return None
else:
rgbacolors = self.__applyColormapToData()
- triangles = triangulation.simplices.astype(numpy.int32)
- return backend.addTriangles(xFiltered,
- yFiltered,
- triangles,
- color=rgbacolors[mask],
- alpha=self.getAlpha())
+ triangles = triangulation.triangles.astype(numpy.int32)
+ return backend.addTriangles(
+ xFiltered,
+ yFiltered,
+ triangles,
+ color=rgbacolors[mask],
+ alpha=self.getAlpha(),
+ )
elif visualization is self.Visualization.REGULAR_GRID:
gridInfo = self.__getRegularGridInfo()
@@ -572,7 +627,7 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
return None
dim0, dim1 = gridInfo.shape
- if gridInfo.order == 'column': # transposition needed
+ if gridInfo.order == "column": # transposition needed
dim0, dim1 = dim1, dim0
values = self.getValueData(copy=False)
@@ -580,20 +635,21 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
image = values.reshape(dim0, dim1)
else:
# The points do not fill the whole image
- if (self.__alpha is None and
- numpy.issubdtype(values.dtype, numpy.floating)):
+ if self.__alpha is None and numpy.issubdtype(
+ values.dtype, numpy.floating
+ ):
image = numpy.empty(dim0 * dim1, dtype=values.dtype)
- image[:len(values)] = values
- image[len(values):] = float('nan') # Transparent pixels
+ image[: len(values)] = values
+ image[len(values) :] = float("nan") # Transparent pixels
image.shape = dim0, dim1
else: # Per value alpha or no NaN, so convert to RGBA
rgbacolors = self.__applyColormapToData()
image = numpy.empty((dim0 * dim1, 4), dtype=numpy.uint8)
- image[:len(rgbacolors)] = rgbacolors
- image[len(rgbacolors):] = (0, 0, 0, 0) # Transparent pixels
+ image[: len(rgbacolors)] = rgbacolors
+ image[len(rgbacolors) :] = (0, 0, 0, 0) # Transparent pixels
image.shape = dim0, dim1, 4
- if gridInfo.order == 'column':
+ if gridInfo.order == "column":
if image.ndim == 2:
image = numpy.transpose(image)
else:
@@ -613,7 +669,8 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
origin=gridInfo.origin,
scale=gridInfo.scale,
colormap=colormap,
- alpha=self.getAlpha())
+ alpha=self.getAlpha(),
+ )
elif visualization is self.Visualization.IRREGULAR_GRID:
gridInfo = self.__getRegularGridInfo()
@@ -629,33 +686,37 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
nbpoints = len(xFiltered)
if nbpoints == 1:
# single point, render as a square points
- return backend.addCurve(xFiltered, yFiltered,
- color=rgbacolors[mask],
- symbol='s',
- linewidth=0,
- linestyle="",
- yaxis='left',
- xerror=None,
- yerror=None,
- fill=False,
- alpha=self.getAlpha(),
- symbolsize=7,
- baseline=None)
+ return backend.addCurve(
+ xFiltered,
+ yFiltered,
+ color=rgbacolors[mask],
+ gapcolor=None,
+ symbol="s",
+ linewidth=0,
+ linestyle="",
+ yaxis="left",
+ xerror=None,
+ yerror=None,
+ fill=False,
+ alpha=self.getAlpha(),
+ symbolsize=7,
+ baseline=None,
+ )
# Make shape include all points
gridOrder = gridInfo.order
if nbpoints != numpy.prod(shape):
- if gridOrder == 'row':
+ if gridOrder == "row":
shape = int(numpy.ceil(nbpoints / shape[1])), shape[1]
- else: # column-major order
+ else: # column-major order
shape = shape[0], int(numpy.ceil(nbpoints / shape[0]))
if shape[0] < 2 or shape[1] < 2: # Single line, at least 2 points
points = numpy.ones((2, nbpoints, 2), dtype=numpy.float64)
# Use row/column major depending on shape, not on info value
- gridOrder = 'row' if shape[0] == 1 else 'column'
+ gridOrder = "row" if shape[0] == 1 else "column"
- if gridOrder == 'row':
+ if gridOrder == "row":
points[0, :, 0] = xFiltered
points[0, :, 1] = yFiltered
else: # column-major order
@@ -663,35 +724,51 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
points[0, :, 1] = xFiltered
# Add a second line that will be clipped in the end
- points[1, :-1] = points[0, :-1] + numpy.cross(
- points[0, 1:] - points[0, :-1], (0., 0., 1.))[:, :2]
- points[1, -1] = points[0, -1] + numpy.cross(
- points[0, -1] - points[0, -2], (0., 0., 1.))[:2]
+ points[1, :-1] = (
+ points[0, :-1]
+ + numpy.cross(points[0, 1:] - points[0, :-1], (0.0, 0.0, 1.0))[
+ :, :2
+ ]
+ )
+ points[1, -1] = (
+ points[0, -1]
+ + numpy.cross(points[0, -1] - points[0, -2], (0.0, 0.0, 1.0))[
+ :2
+ ]
+ )
points.shape = 2, nbpoints, 2 # Use same shape for both orders
coords, indices = _quadrilateral_grid_as_triangles(points)
- elif gridOrder == 'row': # row-major order
+ elif gridOrder == "row": # row-major order
if nbpoints != numpy.prod(shape):
- points = numpy.empty((numpy.prod(shape), 2), dtype=numpy.float64)
+ points = numpy.empty(
+ (numpy.prod(shape), 2), dtype=numpy.float64
+ )
points[:nbpoints, 0] = xFiltered
points[:nbpoints, 1] = yFiltered
# Index of last element of last fully filled row
index = (nbpoints // shape[1]) * shape[1]
- points[nbpoints:, 0] = xFiltered[index - (numpy.prod(shape) - nbpoints):index]
+ points[nbpoints:, 0] = xFiltered[
+ index - (numpy.prod(shape) - nbpoints) : index
+ ]
points[nbpoints:, 1] = yFiltered[-1]
else:
points = numpy.transpose((xFiltered, yFiltered))
points.shape = shape[0], shape[1], 2
- else: # column-major order
+ else: # column-major order
if nbpoints != numpy.prod(shape):
- points = numpy.empty((numpy.prod(shape), 2), dtype=numpy.float64)
+ points = numpy.empty(
+ (numpy.prod(shape), 2), dtype=numpy.float64
+ )
points[:nbpoints, 0] = yFiltered
points[:nbpoints, 1] = xFiltered
# Index of last element of last fully filled column
index = (nbpoints // shape[0]) * shape[0]
- points[nbpoints:, 0] = yFiltered[index - (numpy.prod(shape) - nbpoints):index]
+ points[nbpoints:, 0] = yFiltered[
+ index - (numpy.prod(shape) - nbpoints) : index
+ ]
points[nbpoints:, 1] = xFiltered[-1]
else:
points = numpy.transpose((yFiltered, xFiltered))
@@ -700,25 +777,24 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
coords, indices = _quadrilateral_grid_as_triangles(points)
# Remove unused extra triangles
- coords = coords[:4*nbpoints]
- indices = indices[:2*nbpoints]
+ coords = coords[: 4 * nbpoints]
+ indices = indices[: 2 * nbpoints]
- if gridOrder == 'row':
+ if gridOrder == "row":
x, y = coords[:, 0], coords[:, 1]
else: # column-major order
y, x = coords[:, 0], coords[:, 1]
rgbacolors = rgbacolors[mask] # Filter-out not finite points
gridcolors = numpy.empty(
- (4 * nbpoints, rgbacolors.shape[-1]), dtype=rgbacolors.dtype)
+ (4 * nbpoints, rgbacolors.shape[-1]), dtype=rgbacolors.dtype
+ )
for first in range(4):
gridcolors[first::4] = rgbacolors[:nbpoints]
- return backend.addTriangles(x,
- y,
- indices,
- color=gridcolors,
- alpha=self.getAlpha())
+ return backend.addTriangles(
+ x, y, indices, color=gridcolors, alpha=self.getAlpha()
+ )
else:
_logger.error("Unhandled visualization %s", visualization)
@@ -747,11 +823,13 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
if gridInfo is None:
return None
- if gridInfo.order == 'row':
+ if gridInfo.order == "row":
index = row * gridInfo.shape[1] + column
else:
index = row + column * gridInfo.shape[0]
- if index >= len(self.getXData(copy=False)): # OK as long as not log scale
+ if index >= len(
+ self.getXData(copy=False)
+ ): # OK as long as not log scale
return None # Image can be larger than scatter
result = PickingResult(self, (index,))
@@ -768,9 +846,16 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
ox, oy = histoInfo.origin
xdata = self.getXData(copy=False)
ydata = self.getYData(copy=False)
- indices = numpy.nonzero(numpy.logical_and(
- numpy.logical_and(xdata >= ox + sx * col, xdata < ox + sx * (col + 1)),
- numpy.logical_and(ydata >= oy + sy * row, ydata < oy + sy * (row + 1))))[0]
+ indices = numpy.nonzero(
+ numpy.logical_and(
+ numpy.logical_and(
+ xdata >= ox + sx * col, xdata < ox + sx * (col + 1)
+ ),
+ numpy.logical_and(
+ ydata >= oy + sy * row, ydata < oy + sy * (row + 1)
+ ),
+ )
+ )[0]
result = None if len(indices) == 0 else PickingResult(self, indices)
return result
@@ -784,69 +869,43 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
self.__executor = _GreedyThreadPoolExecutor(max_workers=2)
return self.__executor
- def _getDelaunay(self):
- """Returns a :class:`Future` which result is the Delaunay object.
+ def _getTriangulationFuture(self):
+ """Returns a :class:`Future` which result is the Triangulation object.
:rtype: concurrent.futures.Future
"""
- if self.__delaunayFuture is None or self.__delaunayFuture.cancelled():
+ if self.__triangulationFuture is None or self.__triangulationFuture.cancelled():
# Need to init a new delaunay
x, y = self.getData(copy=False)[:2]
# Remove not finite points
mask = numpy.logical_and(numpy.isfinite(x), numpy.isfinite(y))
- self.__delaunayFuture = self.__getExecutor().submit_greedy(
- 'delaunay', delaunay, x[mask], y[mask])
+ self.__triangulationFuture = self.__getExecutor().submit_greedy(
+ "Triangulation", Triangulation, x[mask], y[mask]
+ )
- return self.__delaunayFuture
+ return self.__triangulationFuture
@staticmethod
- def __initInterpolator(delaunayFuture, values):
+ def __initInterpolator(triangulationFuture, values):
"""Returns an interpolator for the given data points
- :param concurrent.futures.Future delaunayFuture:
- Future object which result is a Delaunay object
+ :param concurrent.futures.Future triangulationFuture:
+ Future object which result is a Triangulation object
:param numpy.ndarray values: The data value of valid points.
:rtype: Union[callable,None]
"""
- # Wait for Delaunay to complete
+ # Wait for Triangulation to complete
try:
- triangulation = delaunayFuture.result()
+ triangulation = triangulationFuture.result()
+ except (RuntimeError, ValueError):
+ return None # triangulation failed
except CancelledError:
- triangulation = None
-
- if triangulation is None:
- interpolator = None # Error case
- else:
- # Lazy-loading of interpolator
- try:
- from scipy.interpolate import LinearNDInterpolator
- except ImportError:
- LinearNDInterpolator = None
-
- if LinearNDInterpolator is not None:
- interpolator = LinearNDInterpolator(triangulation, values)
-
- # First call takes a while, do it here
- interpolator([(0., 0.)])
-
- else:
- # Fallback using matplotlib interpolator
- import matplotlib.tri
-
- x, y = triangulation.points.T
- tri = matplotlib.tri.Triangulation(
- x, y, triangles=triangulation.simplices)
- mplInterpolator = matplotlib.tri.LinearTriInterpolator(
- tri, values)
-
- # Wrap interpolator to have same API as scipy's one
- def interpolator(points):
- return mplInterpolator(*points.T)
+ return None
- return interpolator
+ return LinearTriInterpolator(triangulation, values)
- def _getInterpolator(self):
+ def _getInterpolatorFuture(self):
"""Returns a :class:`Future` which result is the interpolator.
The interpolator is a callable taking an array Nx2 of points
@@ -856,8 +915,7 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
:rtype: concurrent.futures.Future
"""
- if (self.__interpolatorFuture is None or
- self.__interpolatorFuture.cancelled()):
+ if self.__interpolatorFuture is None or self.__interpolatorFuture.cancelled():
# Need to init a new interpolator
x, y, values = self.getData(copy=False)[:3]
# Remove not finite points
@@ -865,8 +923,11 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
x, y, values = x[mask], y[mask], values[mask]
self.__interpolatorFuture = self.__getExecutor().submit_greedy(
- 'interpolator',
- self.__initInterpolator, self._getDelaunay(), values)
+ "interpolator",
+ self.__initInterpolator,
+ self._getTriangulationFuture(),
+ values,
+ )
return self.__interpolatorFuture
def _logFilterData(self, xPositive, yPositive):
@@ -928,11 +989,13 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
assert len(data) == 5
return data
- return (self.getXData(copy),
- self.getYData(copy),
- self.getValueData(copy),
- self.getXErrorData(copy),
- self.getYErrorData(copy))
+ return (
+ self.getXData(copy),
+ self.getYData(copy),
+ self.getValueData(copy),
+ self.getXErrorData(copy),
+ self.getYErrorData(copy),
+ )
# reimplemented from PointsBase to handle `value`
def setData(self, x, y, value, xerror=None, yerror=None, alpha=None, copy=True):
@@ -951,7 +1014,7 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
:param yerror: Values with the uncertainties on the y values
:type yerror: A float, or a numpy.ndarray of float32. See xerror.
:param alpha: Values with the transparency (between 0 and 1)
- :type alpha: A float, or a numpy.ndarray of float32
+ :type alpha: A float, or a numpy.ndarray of float32
:param bool copy: True make a copy of the data (default),
False to use provided arrays.
"""
@@ -961,14 +1024,13 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
# Convert complex data
if numpy.iscomplexobj(value):
- _logger.warning(
- 'Converting value data to absolute value to plot it.')
+ _logger.warning("Converting value data to absolute value to plot it.")
value = numpy.absolute(value)
# Reset triangulation and interpolator
- if self.__delaunayFuture is not None:
- self.__delaunayFuture.cancel()
- self.__delaunayFuture = None
+ if self.__triangulationFuture is not None:
+ self.__triangulationFuture.cancel()
+ self.__triangulationFuture = None
if self.__interpolatorFuture is not None:
self.__interpolatorFuture.cancel()
self.__interpolatorFuture = None
@@ -984,10 +1046,10 @@ class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
alpha = numpy.array(alpha, copy=copy)
assert alpha.ndim == 1
assert len(x) == len(alpha)
- if alpha.dtype.kind != 'f':
+ if alpha.dtype.kind != "f":
alpha = alpha.astype(numpy.float32)
- if numpy.any(numpy.logical_or(alpha < 0., alpha > 1.)):
- alpha = numpy.clip(alpha, 0., 1.)
+ if numpy.any(numpy.logical_or(alpha < 0.0, alpha > 1.0)):
+ alpha = numpy.clip(alpha, 0.0, 1.0)
self.__alpha = alpha
# set x, y, xerror, yerror
diff --git a/src/silx/gui/plot/items/shape.py b/src/silx/gui/plot/items/shape.py
index dc35864..c911924 100644
--- a/src/silx/gui/plot/items/shape.py
+++ b/src/silx/gui/plot/items/shape.py
@@ -33,11 +33,18 @@ import logging
import numpy
-from ... import colors
-from ..utils.intersections import lines_intersection
from .core import (
- Item, DataItem,
- AlphaMixIn, ColorMixIn, FillMixIn, ItemChangedType, ItemMixInBase, LineMixIn, YAxisMixIn)
+ Item,
+ DataItem,
+ AlphaMixIn,
+ ColorMixIn,
+ FillMixIn,
+ ItemChangedType,
+ LineMixIn,
+ LineGapColorMixIn,
+ YAxisMixIn,
+)
+from ....utils.deprecation import deprecated
_logger = logging.getLogger(__name__)
@@ -65,41 +72,20 @@ class _OverlayItem(Item):
self._updated(ItemChangedType.OVERLAY)
-class _TwoColorsLineMixIn(LineMixIn):
+class _TwoColorsLineMixIn(LineMixIn, LineGapColorMixIn):
"""Mix-in class for items with a background color for dashes"""
def __init__(self):
LineMixIn.__init__(self)
- self.__backgroundColor = None
+ LineGapColorMixIn.__init__(self)
+ @deprecated(replacement="getLineGapColor", since_version="2.0.0")
def getLineBgColor(self):
- """Returns the RGBA background color of dash line
+ return self.getLineGapColor()
- :rtype: 4-tuple of float in [0, 1] or array of colors
- """
- return self.__backgroundColor
-
- def setLineBgColor(self, color, copy: bool=True):
- """Set dash line background color
-
- :param color: color(s) to be used
- :type color: str ("#RRGGBB") or (npoints, 4) unsigned byte array or
- one of the predefined color names defined in colors.py
- :param copy: True (Default) to get a copy,
- False to use internal representation (do not modify!)
- """
- if color is not None:
- if isinstance(color, str):
- color = colors.rgba(color)
- else:
- color = numpy.array(color, copy=copy)
- # TODO more checks + improve color array support
- if color.ndim == 1: # Single RGBA color
- color = colors.rgba(color)
- else: # Array of colors
- assert color.ndim == 2
-
- self.__backgroundColor = color
+ @deprecated(replacement="setLineGapColor", since_version="2.0.0")
+ def setLineBgColor(self, color, copy: bool = True):
+ self.setLineGapColor(color)
self._updated(ItemChangedType.LINE_BG_COLOR)
@@ -117,7 +103,7 @@ class Shape(_OverlayItem, ColorMixIn, FillMixIn, _TwoColorsLineMixIn):
ColorMixIn.__init__(self)
FillMixIn.__init__(self)
_TwoColorsLineMixIn.__init__(self)
- assert type_ in ('hline', 'polygon', 'rectangle', 'vline', 'polylines')
+ assert type_ in ("hline", "polygon", "rectangle", "vline", "polylines")
self._type = type_
self._points = ()
self._handle = None
@@ -126,15 +112,17 @@ class Shape(_OverlayItem, ColorMixIn, FillMixIn, _TwoColorsLineMixIn):
"""Update backend renderer"""
points = self.getPoints(copy=False)
x, y = points.T[0], points.T[1]
- return backend.addShape(x,
- y,
- shape=self.getType(),
- color=self.getColor(),
- fill=self.isFill(),
- overlay=self.isOverlay(),
- linestyle=self.getLineStyle(),
- linewidth=self.getLineWidth(),
- linebgcolor=self.getLineBgColor())
+ return backend.addShape(
+ x,
+ y,
+ shape=self.getType(),
+ color=self.getColor(),
+ fill=self.isFill(),
+ overlay=self.isOverlay(),
+ linestyle=self.getLineStyle(),
+ linewidth=self.getLineWidth(),
+ gapcolor=self.getLineGapColor(),
+ )
def getType(self):
"""Returns the type of shape to draw.
@@ -226,11 +214,11 @@ class _BaseExtent(DataItem):
:param str axis: Either 'x' or 'y'.
"""
- def __init__(self, axis='x'):
- assert axis in ('x', 'y')
+ def __init__(self, axis="x"):
+ assert axis in ("x", "y")
DataItem.__init__(self)
self.__axis = axis
- self.__range = 1., 100.
+ self.__range = 1.0, 100.0
def setRange(self, min_, max_):
"""Set the range of the extent of this item in data coordinates.
@@ -262,17 +250,17 @@ class _BaseExtent(DataItem):
plot = self.getPlot()
if plot is not None:
- axis = plot.getXAxis() if self.__axis == 'x' else plot.getYAxis()
+ axis = plot.getXAxis() if self.__axis == "x" else plot.getYAxis()
if axis._isLogarithmic():
if max_ <= 0:
return None
if min_ <= 0:
min_ = max_
- if self.__axis == 'x':
- return min_, max_, float('nan'), float('nan')
+ if self.__axis == "x":
+ return min_, max_, float("nan"), float("nan")
else:
- return float('nan'), float('nan'), min_, max_
+ return float("nan"), float("nan"), min_, max_
class XAxisExtent(_BaseExtent):
@@ -282,8 +270,9 @@ class XAxisExtent(_BaseExtent):
item with a horizontal extent regarding plot data bounds, i.e.,
:meth:`PlotWidget.resetZoom` will take this horizontal extent into account.
"""
+
def __init__(self):
- _BaseExtent.__init__(self, axis='x')
+ _BaseExtent.__init__(self, axis="x")
class YAxisExtent(_BaseExtent, YAxisMixIn):
@@ -295,7 +284,7 @@ class YAxisExtent(_BaseExtent, YAxisMixIn):
"""
def __init__(self):
- _BaseExtent.__init__(self, axis='y')
+ _BaseExtent.__init__(self, axis="y")
YAxisMixIn.__init__(self)
@@ -305,7 +294,7 @@ class Line(_OverlayItem, AlphaMixIn, ColorMixIn, _TwoColorsLineMixIn):
Warning: If slope is not finite, then the line is x = intercept.
"""
- def __init__(self, slope: float=0, intercept: float=0):
+ def __init__(self, slope: float = 0, intercept: float = 0):
assert numpy.isfinite(intercept)
_OverlayItem.__init__(self)
@@ -378,7 +367,7 @@ class Line(_OverlayItem, AlphaMixIn, ColorMixIn, _TwoColorsLineMixIn):
"""Set slope and intercept from 2 (x, y) points"""
x0, y0 = point0
x1, y1 = point1
- if x0 == x1: # Special case: vertical line
+ if x0 == x1: # Special case: vertical line
self.setSlope(float("inf"))
self.setIntercept(x0)
return
@@ -394,11 +383,11 @@ class Line(_OverlayItem, AlphaMixIn, ColorMixIn, _TwoColorsLineMixIn):
return backend.addShape(
*self.__coordinates,
- shape='polylines',
+ shape="polylines",
color=self.getColor(),
fill=False,
overlay=self.isOverlay(),
linestyle=self.getLineStyle(),
linewidth=self.getLineWidth(),
- linebgcolor=self.getLineBgColor(),
+ gapcolor=self.getLineGapColor(),
)
diff --git a/src/silx/gui/plot/matplotlib/Colormap.py b/src/silx/gui/plot/matplotlib/Colormap.py
deleted file mode 100644
index 1131df8..0000000
--- a/src/silx/gui/plot/matplotlib/Colormap.py
+++ /dev/null
@@ -1,248 +0,0 @@
-# /*##########################################################################
-# Copyright (C) 2017-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-"""Matplotlib's new colormaps"""
-
-import numpy
-import logging
-from matplotlib.colors import ListedColormap
-import matplotlib.colors
-import matplotlib.cm
-import silx.resources
-from silx.utils.deprecation import deprecated, deprecated_warning
-
-
-deprecated_warning(type_='module',
- name=__file__,
- replacement='silx.gui.colors.Colormap',
- since_version='0.10.0')
-
-
-_logger = logging.getLogger(__name__)
-
-_AVAILABLE_AS_RESOURCE = ('magma', 'inferno', 'plasma', 'viridis')
-"""List available colormap name as resources"""
-
-_AVAILABLE_AS_BUILTINS = ('gray', 'reversed gray',
- 'temperature', 'red', 'green', 'blue')
-"""List of colormaps available through built-in declarations"""
-
-_CMAPS = {}
-"""Cache colormaps"""
-
-
-@property
-@deprecated(since_version='0.10.0')
-def magma():
- return getColormap('magma')
-
-
-@property
-@deprecated(since_version='0.10.0')
-def inferno():
- return getColormap('inferno')
-
-
-@property
-@deprecated(since_version='0.10.0')
-def plasma():
- return getColormap('plasma')
-
-
-@property
-@deprecated(since_version='0.10.0')
-def viridis():
- return getColormap('viridis')
-
-
-@deprecated(since_version='0.10.0')
-def getColormap(name):
- """Returns matplotlib colormap corresponding to given name
-
- :param str name: The name of the colormap
- :return: The corresponding colormap
- :rtype: matplolib.colors.Colormap
- """
- if not _CMAPS: # Lazy initialization of own colormaps
- cdict = {'red': ((0.0, 0.0, 0.0),
- (1.0, 1.0, 1.0)),
- 'green': ((0.0, 0.0, 0.0),
- (1.0, 0.0, 0.0)),
- 'blue': ((0.0, 0.0, 0.0),
- (1.0, 0.0, 0.0))}
- _CMAPS['red'] = matplotlib.colors.LinearSegmentedColormap(
- 'red', cdict, 256)
-
- cdict = {'red': ((0.0, 0.0, 0.0),
- (1.0, 0.0, 0.0)),
- 'green': ((0.0, 0.0, 0.0),
- (1.0, 1.0, 1.0)),
- 'blue': ((0.0, 0.0, 0.0),
- (1.0, 0.0, 0.0))}
- _CMAPS['green'] = matplotlib.colors.LinearSegmentedColormap(
- 'green', cdict, 256)
-
- cdict = {'red': ((0.0, 0.0, 0.0),
- (1.0, 0.0, 0.0)),
- 'green': ((0.0, 0.0, 0.0),
- (1.0, 0.0, 0.0)),
- 'blue': ((0.0, 0.0, 0.0),
- (1.0, 1.0, 1.0))}
- _CMAPS['blue'] = matplotlib.colors.LinearSegmentedColormap(
- 'blue', cdict, 256)
-
- # Temperature as defined in spslut
- cdict = {'red': ((0.0, 0.0, 0.0),
- (0.5, 0.0, 0.0),
- (0.75, 1.0, 1.0),
- (1.0, 1.0, 1.0)),
- 'green': ((0.0, 0.0, 0.0),
- (0.25, 1.0, 1.0),
- (0.75, 1.0, 1.0),
- (1.0, 0.0, 0.0)),
- 'blue': ((0.0, 1.0, 1.0),
- (0.25, 1.0, 1.0),
- (0.5, 0.0, 0.0),
- (1.0, 0.0, 0.0))}
- # but limited to 256 colors for a faster display (of the colorbar)
- _CMAPS['temperature'] = \
- matplotlib.colors.LinearSegmentedColormap(
- 'temperature', cdict, 256)
-
- # reversed gray
- cdict = {'red': ((0.0, 1.0, 1.0),
- (1.0, 0.0, 0.0)),
- 'green': ((0.0, 1.0, 1.0),
- (1.0, 0.0, 0.0)),
- 'blue': ((0.0, 1.0, 1.0),
- (1.0, 0.0, 0.0))}
-
- _CMAPS['reversed gray'] = \
- matplotlib.colors.LinearSegmentedColormap(
- 'yerg', cdict, 256)
-
- if name in _CMAPS:
- return _CMAPS[name]
- elif name in _AVAILABLE_AS_RESOURCE:
- filename = silx.resources.resource_filename("gui/colormaps/%s.npy" % name)
- data = numpy.load(filename)
- lut = ListedColormap(data, name=name)
- _CMAPS[name] = lut
- return lut
- else:
- # matplotlib built-in
- return matplotlib.cm.get_cmap(name)
-
-
-@deprecated(since_version='0.10.0')
-def getScalarMappable(colormap, data=None):
- """Returns matplotlib ScalarMappable corresponding to colormap
-
- :param :class:`.Colormap` colormap: The colormap to convert
- :param numpy.ndarray data:
- The data on which the colormap is applied.
- If provided, it is used to compute autoscale.
- :return: matplotlib object corresponding to colormap
- :rtype: matplotlib.cm.ScalarMappable
- """
- assert colormap is not None
-
- if colormap.getName() is not None:
- cmap = getColormap(colormap.getName())
-
- else: # No name, use custom colors
- if colormap.getColormapLUT() is None:
- raise ValueError(
- 'addImage: colormap no name nor list of colors.')
- colors = colormap.getColormapLUT()
- assert len(colors.shape) == 2
- assert colors.shape[-1] in (3, 4)
- if colors.dtype == numpy.uint8:
- # Convert to float in [0., 1.]
- colors = colors.astype(numpy.float32) / 255.
- cmap = matplotlib.colors.ListedColormap(colors)
-
- vmin, vmax = colormap.getColormapRange(data)
- normalization = colormap.getNormalization()
- if normalization == colormap.LOGARITHM:
- norm = matplotlib.colors.LogNorm(vmin, vmax)
- elif normalization == colormap.LINEAR:
- norm = matplotlib.colors.Normalize(vmin, vmax)
- else:
- raise RuntimeError("Unsupported normalization: %s" % normalization)
-
- return matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap)
-
-
-@deprecated(replacement='silx.colors.Colormap.applyToData',
- since_version='0.8.0')
-def applyColormapToData(data, colormap):
- """Apply a colormap to the data and returns the RGBA image
-
- This supports data of any dimensions (not only of dimension 2).
- The returned array will have one more dimension (with 4 entries)
- than the input data to store the RGBA channels
- corresponding to each bin in the array.
-
- :param numpy.ndarray data: The data to convert.
- :param :class:`.Colormap`: The colormap to apply
- """
- # Debian 7 specific support
- # No transparent colormap with matplotlib < 1.2.0
- # Add support for transparent colormap for uint8 data with
- # colormap with 256 colors, linear norm, [0, 255] range
- if matplotlib.__version__ < '1.2.0':
- if (colormap.getName() is None and
- colormap.getColormapLUT() is not None):
- colors = colormap.getColormapLUT()
- if (colors.shape[-1] == 4 and
- not numpy.all(numpy.equal(colors[3], 255))):
- # This is a transparent colormap
- if (colors.shape == (256, 4) and
- colormap.getNormalization() == 'linear' and
- not colormap.isAutoscale() and
- colormap.getVMin() == 0 and
- colormap.getVMax() == 255 and
- data.dtype == numpy.uint8):
- # Supported case, convert data to RGBA
- return colors[data.reshape(-1)].reshape(
- data.shape + (4,))
- else:
- _logger.warning(
- 'matplotlib %s does not support transparent '
- 'colormap.', matplotlib.__version__)
-
- scalarMappable = getScalarMappable(colormap, data)
- rgbaImage = scalarMappable.to_rgba(data, bytes=True)
-
- return rgbaImage
-
-
-@deprecated(replacement='silx.colors.Colormap.getSupportedColormaps',
- since_version='0.10.0')
-def getSupportedColormaps():
- """Get the supported colormap names as a tuple of str.
- """
- colormaps = set(matplotlib.cm.datad.keys())
- colormaps.update(_AVAILABLE_AS_BUILTINS)
- colormaps.update(_AVAILABLE_AS_RESOURCE)
- return tuple(sorted(colormaps))
diff --git a/src/silx/gui/plot/stats/stats.py b/src/silx/gui/plot/stats/stats.py
index d266d5c..d575e3f 100644
--- a/src/silx/gui/plot/stats/stats.py
+++ b/src/silx/gui/plot/stats/stats.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2017-2022 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -31,7 +31,6 @@ __license__ = "MIT"
__date__ = "06/06/2018"
-from collections import OrderedDict
from functools import lru_cache
import logging
@@ -44,12 +43,11 @@ from ..items.roi import RegionOfInterest
from ....math.combo import min_max
from silx.utils.proxy import docstring
-from ....utils.deprecation import deprecated
logger = logging.getLogger(__name__)
-class Stats(OrderedDict):
+class Stats(dict):
"""Class to define a set of statistic relative to a dataset
(image, curve...).
@@ -60,15 +58,17 @@ class Stats(OrderedDict):
:param List statslist: List of the :class:`Stat` object to be computed.
"""
+
def __init__(self, statslist=None):
- OrderedDict.__init__(self)
+ super().__init__()
_statslist = statslist if not None else []
if statslist is not None:
for stat in _statslist:
self.add(stat)
- def calculate(self, item, plot, onlimits, roi, data_changed=False,
- roi_changed=False):
+ def calculate(
+ self, item, plot, onlimits, roi, data_changed=False, roi_changed=False
+ ):
"""
Call all :class:`Stat` object registered and return the result of the
computation.
@@ -87,27 +87,26 @@ class Stats(OrderedDict):
of the calculation as value
"""
res = {}
- context = self._getContext(item=item, plot=plot, onlimits=onlimits,
- roi=roi)
+ context = self._getContext(item=item, plot=plot, onlimits=onlimits, roi=roi)
for statName, stat in list(self.items()):
if context.kind not in stat.compatibleKinds:
- logger.debug('kind %s not managed by statistic %s'
- % (context.kind, stat.name))
+ logger.debug(
+ "kind %s not managed by statistic %s" % (context.kind, stat.name)
+ )
res[statName] = None
else:
if roi_changed is True:
context.clear_mask()
if data_changed is True or roi_changed is True:
# if data changed or mask changed
- context.clipData(item=item, plot=plot, onlimits=onlimits,
- roi=roi)
+ context.clipData(item=item, plot=plot, onlimits=onlimits, roi=roi)
# init roi and data
res[statName] = stat.calculate(context)
return res
def __setitem__(self, key, value):
assert isinstance(value, StatBase)
- OrderedDict.__setitem__(self, key, value)
+ super().__setitem__(key, value)
def add(self, stat):
"""Add a :class:`Stat` to the set
@@ -134,14 +133,11 @@ class Stats(OrderedDict):
from ...plot3d import items as items3d # Lazy import
if isinstance(item, (items3d.Scatter2D, items3d.Scatter3D)):
- context = _plot3DScatterContext(item, plot, onlimits,
- roi=roi)
- elif isinstance(item,
- (items3d.ImageData, items3d.ScalarField3D)):
- context = _plot3DArrayContext(item, plot, onlimits,
- roi=roi)
+ context = _plot3DScatterContext(item, plot, onlimits, roi=roi)
+ elif isinstance(item, (items3d.ImageData, items3d.ScalarField3D)):
+ context = _plot3DArrayContext(item, plot, onlimits, roi=roi)
if context is None:
- raise ValueError('Item type not managed')
+ raise ValueError("Item type not managed")
return context
@@ -164,6 +160,7 @@ class _StatsContext(object):
For now, incompatible with `onlimits` calculation
:type roi: Union[None,:class:`_RegionOfInterestBase`]
"""
+
def __init__(self, item, kind, plot, onlimits, roi):
assert item
assert plot
@@ -234,13 +231,6 @@ class _StatsContext(object):
"""
raise NotImplementedError("Base class")
- @deprecated(reason="context are now stored and keep during stats life."
- "So this function will be called only once",
- replacement="clipData", since_version="0.13.0")
- def createContext(self, item, plot, onlimits, roi):
- return self.clipData(item=item, plot=plot, onlimits=onlimits,
- roi=roi)
-
def isStructuredData(self):
"""Returns True if data as an array-like structure.
@@ -271,15 +261,18 @@ class _StatsContext(object):
def _checkContextInputs(self, item, plot, onlimits, roi):
if roi is not None and onlimits is True:
- raise ValueError('Stats context is unable to manage both a ROI'
- 'and the `onlimits` option')
+ raise ValueError(
+ "Stats context is unable to manage both a ROI"
+ "and the `onlimits` option"
+ )
class _ScatterCurveHistoMixInContext(_StatsContext):
def __init__(self, kind, item, plot, onlimits, roi):
self.clear_mask()
- _StatsContext.__init__(self, item=item, kind=kind,
- plot=plot, onlimits=onlimits, roi=roi)
+ _StatsContext.__init__(
+ self, item=item, kind=kind, plot=plot, onlimits=onlimits, roi=roi
+ )
def _set_mask_validity(self, onlimits, from_, to_):
self._onlimits = onlimits
@@ -292,8 +285,7 @@ class _ScatterCurveHistoMixInContext(_StatsContext):
self._to_ = None
def is_mask_valid(self, onlimits, from_, to_):
- return (onlimits == self.onlimits and from_ == self._from_ and
- to_ == self._to_)
+ return onlimits == self.onlimits and from_ == self._from_ and to_ == self._to_
class _CurveContext(_ScatterCurveHistoMixInContext):
@@ -308,15 +300,15 @@ class _CurveContext(_ScatterCurveHistoMixInContext):
For now, incompatible with `onlinits` calculation
:type roi: Union[None, :class:`ROI`]
"""
+
def __init__(self, item, plot, onlimits, roi):
- _ScatterCurveHistoMixInContext.__init__(self, kind='curve', item=item,
- plot=plot, onlimits=onlimits,
- roi=roi)
+ _ScatterCurveHistoMixInContext.__init__(
+ self, kind="curve", item=item, plot=plot, onlimits=onlimits, roi=roi
+ )
@docstring(_StatsContext)
def clipData(self, item, plot, onlimits, roi):
- self._checkContextInputs(item=item, plot=plot, onlimits=onlimits,
- roi=roi)
+ self._checkContextInputs(item=item, plot=plot, onlimits=onlimits, roi=roi)
self.roi = roi
self.onlimits = onlimits
xData, yData = item.getData(copy=True)[0:2]
@@ -353,10 +345,11 @@ class _CurveContext(_ScatterCurveHistoMixInContext):
self.axes = (xData,)
def _checkContextInputs(self, item, plot, onlimits, roi):
- _StatsContext._checkContextInputs(self, item=item, plot=plot,
- onlimits=onlimits, roi=roi)
+ _StatsContext._checkContextInputs(
+ self, item=item, plot=plot, onlimits=onlimits, roi=roi
+ )
if roi is not None and not isinstance(roi, ROI):
- raise TypeError('curve `context` can ony manage 1D roi')
+ raise TypeError("curve `context` can ony manage 1D roi")
class _HistogramContext(_ScatterCurveHistoMixInContext):
@@ -371,15 +364,15 @@ class _HistogramContext(_ScatterCurveHistoMixInContext):
For now, incompatible with `onlinits` calculation
:type roi: Union[None, :class:`ROI`]
"""
+
def __init__(self, item, plot, onlimits, roi):
- _ScatterCurveHistoMixInContext.__init__(self, kind='histogram',
- item=item, plot=plot,
- onlimits=onlimits, roi=roi)
+ _ScatterCurveHistoMixInContext.__init__(
+ self, kind="histogram", item=item, plot=plot, onlimits=onlimits, roi=roi
+ )
@docstring(_StatsContext)
def clipData(self, item, plot, onlimits, roi):
- self._checkContextInputs(item=item, plot=plot, onlimits=onlimits,
- roi=roi)
+ self._checkContextInputs(item=item, plot=plot, onlimits=onlimits, roi=roi)
yData, edges = item.getData(copy=True)[0:2]
xData = item._revertComputeEdges(x=edges, histogramType=item.getAlignment())
@@ -392,13 +385,16 @@ class _HistogramContext(_ScatterCurveHistoMixInContext):
mask = mask == 0
self._set_mask_validity(onlimits=onlimits, from_=minX, to_=maxX)
elif roi:
- if self.is_mask_valid(onlimits=onlimits, from_=roi._fromdata, to_=roi._todata):
+ if self.is_mask_valid(
+ onlimits=onlimits, from_=roi._fromdata, to_=roi._todata
+ ):
mask = self.mask
else:
mask = (roi._fromdata <= xData) & (xData <= roi._todata)
mask = mask == 0
- self._set_mask_validity(onlimits=onlimits, from_=roi._fromdata,
- to_=roi._todata)
+ self._set_mask_validity(
+ onlimits=onlimits, from_=roi._fromdata, to_=roi._todata
+ )
else:
mask = numpy.zeros_like(yData)
mask = mask.astype(numpy.uint32)
@@ -414,11 +410,12 @@ class _HistogramContext(_ScatterCurveHistoMixInContext):
self.axes = (self.xData,)
def _checkContextInputs(self, item, plot, onlimits, roi):
- _StatsContext._checkContextInputs(self, item=item, plot=plot,
- onlimits=onlimits, roi=roi)
+ _StatsContext._checkContextInputs(
+ self, item=item, plot=plot, onlimits=onlimits, roi=roi
+ )
if roi is not None and not isinstance(roi, ROI):
- raise TypeError('curve `context` can ony manage 1D roi')
+ raise TypeError("curve `context` can ony manage 1D roi")
class _ScatterContext(_ScatterCurveHistoMixInContext):
@@ -434,15 +431,15 @@ class _ScatterContext(_ScatterCurveHistoMixInContext):
For now, incompatible with `onlinits` calculation
:type roi: Union[None, :class:`ROI`]
"""
+
def __init__(self, item, plot, onlimits, roi):
- _ScatterCurveHistoMixInContext.__init__(self, kind='scatter',
- item=item, plot=plot,
- onlimits=onlimits, roi=roi)
+ _ScatterCurveHistoMixInContext.__init__(
+ self, kind="scatter", item=item, plot=plot, onlimits=onlimits, roi=roi
+ )
@docstring(_ScatterCurveHistoMixInContext)
def clipData(self, item, plot, onlimits, roi):
- self._checkContextInputs(item=item, plot=plot, onlimits=onlimits,
- roi=roi)
+ self._checkContextInputs(item=item, plot=plot, onlimits=onlimits, roi=roi)
valueData = item.getValueData(copy=True)
xData = item.getXData(copy=True)
yData = item.getYData(copy=True)
@@ -461,8 +458,9 @@ class _ScatterContext(_ScatterCurveHistoMixInContext):
yData = yData[(minY <= yData) & (yData <= maxY)]
if roi:
- if self.is_mask_valid(onlimits=onlimits, from_=roi.getFrom(),
- to_=roi.getTo()):
+ if self.is_mask_valid(
+ onlimits=onlimits, from_=roi.getFrom(), to_=roi.getTo()
+ ):
mask = self.mask
else:
mask = (xData < roi.getFrom()) | (xData > roi.getTo())
@@ -480,11 +478,12 @@ class _ScatterContext(_ScatterCurveHistoMixInContext):
self.min, self.max = None, None
def _checkContextInputs(self, item, plot, onlimits, roi):
- _StatsContext._checkContextInputs(self, item=item, plot=plot,
- onlimits=onlimits, roi=roi)
+ _StatsContext._checkContextInputs(
+ self, item=item, plot=plot, onlimits=onlimits, roi=roi
+ )
if roi is not None and not isinstance(roi, ROI):
- raise TypeError('curve `context` can ony manage 1D roi')
+ raise TypeError("curve `context` can ony manage 1D roi")
class _ImageContext(_StatsContext):
@@ -511,13 +510,14 @@ class _ImageContext(_StatsContext):
For now, incompatible with `onlinits` calculation
:type roi: Union[None, :class:`ROI`]
"""
+
def __init__(self, item, plot, onlimits, roi):
self.clear_mask()
- _StatsContext.__init__(self, kind='image', item=item,
- plot=plot, onlimits=onlimits, roi=roi)
+ _StatsContext.__init__(
+ self, kind="image", item=item, plot=plot, onlimits=onlimits, roi=roi
+ )
- def _set_mask_validity(self, xmin: float, xmax: float, ymin: float, ymax
- : float):
+ def _set_mask_validity(self, xmin: float, xmax: float, ymin: float, ymax: float):
self._mask_x_min = xmin
self._mask_x_max = xmax
self._mask_y_min = ymin
@@ -530,13 +530,16 @@ class _ImageContext(_StatsContext):
self._mask_y_max = None
def is_mask_valid(self, xmin, xmax, ymin, ymax):
- return (xmin == self._mask_x_min and xmax == self._mask_x_max and
- ymin == self._mask_y_min and ymax == self._mask_y_max)
+ return (
+ xmin == self._mask_x_min
+ and xmax == self._mask_x_max
+ and ymin == self._mask_y_min
+ and ymax == self._mask_y_max
+ )
@docstring(_StatsContext)
def clipData(self, item, plot, onlimits, roi):
- self._checkContextInputs(item=item, plot=plot, onlimits=onlimits,
- roi=roi)
+ self._checkContextInputs(item=item, plot=plot, onlimits=onlimits, roi=roi)
self.origin = item.getOrigin()
self.scale = item.getScale()
@@ -560,8 +563,9 @@ class _ImageContext(_StatsContext):
if XMaxBound <= XMinBound or YMaxBound <= YMinBound:
self.data = None
else:
- self.data = self.data[YMinBound:YMaxBound + 1,
- XMinBound:XMaxBound + 1]
+ self.data = self.data[
+ YMinBound : YMaxBound + 1, XMinBound : XMaxBound + 1
+ ]
mask = numpy.zeros_like(self.data)
elif roi:
minX, maxX = 0, self.data.shape[1]
@@ -572,8 +576,9 @@ class _ImageContext(_StatsContext):
XMaxBound = min(maxX, self.data.shape[1])
YMaxBound = min(maxY, self.data.shape[0])
- if self.is_mask_valid(xmin=XMinBound, xmax=XMaxBound,
- ymin=YMinBound, ymax=YMaxBound):
+ if self.is_mask_valid(
+ xmin=XMinBound, xmax=XMaxBound, ymin=YMinBound, ymax=YMaxBound
+ ):
mask = self.mask
else:
for x in range(XMinBound, XMaxBound):
@@ -581,8 +586,9 @@ class _ImageContext(_StatsContext):
_x = (x * self.scale[0]) + self.origin[0]
_y = (y * self.scale[1]) + self.origin[1]
mask[y, x] = not roi.contains((_x, _y))
- self._set_mask_validity(xmin=XMinBound, xmax=XMaxBound,
- ymin=YMinBound, ymax=YMaxBound)
+ self._set_mask_validity(
+ xmin=XMinBound, xmax=XMaxBound, ymin=YMinBound, ymax=YMaxBound
+ )
self.values = numpy.ma.array(self.data, mask=mask)
if self.values.compressed().size > 0:
self.min, self.max = min_max(self.values.compressed())
@@ -590,15 +596,18 @@ class _ImageContext(_StatsContext):
self.min, self.max = None, None
if self.values is not None:
- self.axes = (self.origin[1] + self.scale[1] * numpy.arange(self.data.shape[0]),
- self.origin[0] + self.scale[0] * numpy.arange(self.data.shape[1]))
+ self.axes = (
+ self.origin[1] + self.scale[1] * numpy.arange(self.data.shape[0]),
+ self.origin[0] + self.scale[0] * numpy.arange(self.data.shape[1]),
+ )
def _checkContextInputs(self, item, plot, onlimits, roi):
- _StatsContext._checkContextInputs(self, item=item, plot=plot,
- onlimits=onlimits, roi=roi)
+ _StatsContext._checkContextInputs(
+ self, item=item, plot=plot, onlimits=onlimits, roi=roi
+ )
if roi is not None and not isinstance(roi, RegionOfInterest):
- raise TypeError('curve `context` can ony manage 2D roi')
+ raise TypeError("curve `context` can ony manage 2D roi")
class _plot3DScatterContext(_StatsContext):
@@ -615,14 +624,15 @@ class _plot3DScatterContext(_StatsContext):
For now, incompatible with `onlinits` calculation
:type roi: Union[None, :class:`ROI`]
"""
+
def __init__(self, item, plot, onlimits, roi):
- _StatsContext.__init__(self, kind='scatter', item=item, plot=plot,
- onlimits=onlimits, roi=roi)
+ _StatsContext.__init__(
+ self, kind="scatter", item=item, plot=plot, onlimits=onlimits, roi=roi
+ )
@docstring(_StatsContext)
def clipData(self, item, plot, onlimits, roi):
- self._checkContextInputs(item=item, plot=plot, onlimits=onlimits,
- roi=roi)
+ self._checkContextInputs(item=item, plot=plot, onlimits=onlimits, roi=roi)
if onlimits:
raise RuntimeError("Unsupported plot %s" % str(plot))
values = item.getValueData(copy=False)
@@ -646,11 +656,12 @@ class _plot3DScatterContext(_StatsContext):
self.min, self.max = None, None
def _checkContextInputs(self, item, plot, onlimits, roi):
- _StatsContext._checkContextInputs(self, item=item, plot=plot,
- onlimits=onlimits, roi=roi)
+ _StatsContext._checkContextInputs(
+ self, item=item, plot=plot, onlimits=onlimits, roi=roi
+ )
if roi is not None and not isinstance(roi, RegionOfInterest):
- raise TypeError('curve `context` can ony manage 2D roi')
+ raise TypeError("curve `context` can ony manage 2D roi")
class _plot3DArrayContext(_StatsContext):
@@ -667,14 +678,15 @@ class _plot3DArrayContext(_StatsContext):
For now, incompatible with `onlinits` calculation
:type roi: Union[None, :class:`ROI`]
"""
+
def __init__(self, item, plot, onlimits, roi):
- _StatsContext.__init__(self, kind='image', item=item, plot=plot,
- onlimits=onlimits, roi=roi)
+ _StatsContext.__init__(
+ self, kind="image", item=item, plot=plot, onlimits=onlimits, roi=roi
+ )
@docstring(_StatsContext)
def clipData(self, item, plot, onlimits, roi):
- self._checkContextInputs(item=item, plot=plot, onlimits=onlimits,
- roi=roi)
+ self._checkContextInputs(item=item, plot=plot, onlimits=onlimits, roi=roi)
if onlimits:
raise RuntimeError("Unsupported plot %s" % str(plot))
@@ -696,14 +708,15 @@ class _plot3DArrayContext(_StatsContext):
self.min, self.max = None, None
def _checkContextInputs(self, item, plot, onlimits, roi):
- _StatsContext._checkContextInputs(self, item=item, plot=plot,
- onlimits=onlimits, roi=roi)
+ _StatsContext._checkContextInputs(
+ self, item=item, plot=plot, onlimits=onlimits, roi=roi
+ )
if roi is not None and not isinstance(roi, RegionOfInterest):
- raise TypeError('curve `context` can ony manage 2D roi')
+ raise TypeError("curve `context` can ony manage 2D roi")
-BASIC_COMPATIBLE_KINDS = 'curve', 'image', 'scatter', 'histogram'
+BASIC_COMPATIBLE_KINDS = "curve", "image", "scatter", "histogram"
class StatBase(object):
@@ -714,6 +727,7 @@ class StatBase(object):
:param List[str] compatibleKinds:
The kind of items (curve, scatter...) for which the statistic apply.
"""
+
def __init__(self, name, compatibleKinds=BASIC_COMPATIBLE_KINDS, description=None):
self.name = name
self.compatibleKinds = compatibleKinds
@@ -726,7 +740,7 @@ class StatBase(object):
:param _StatsContext context:
:return dict: key is stat name, statistic computed is the dict value
"""
- raise NotImplementedError('Base class')
+ raise NotImplementedError("Base class")
def getToolTip(self, kind):
"""
@@ -749,6 +763,7 @@ class Stat(StatBase):
:param tuple kinds: the compatible item kinds of the function (curve,
image...)
"""
+
def __init__(self, name, fct, kinds=BASIC_COMPATIBLE_KINDS):
StatBase.__init__(self, name, kinds)
self._fct = fct
@@ -759,16 +774,18 @@ class Stat(StatBase):
if context.kind in self.compatibleKinds:
return self._fct(context.values)
else:
- raise ValueError('Kind %s not managed by %s'
- '' % (context.kind, self.name))
+ raise ValueError(
+ "Kind %s not managed by %s" "" % (context.kind, self.name)
+ )
else:
return None
class StatMin(StatBase):
"""Compute the minimal value on data"""
+
def __init__(self):
- StatBase.__init__(self, name='min')
+ StatBase.__init__(self, name="min")
@docstring(StatBase)
def calculate(self, context):
@@ -777,8 +794,9 @@ class StatMin(StatBase):
class StatMax(StatBase):
"""Compute the maximal value on data"""
+
def __init__(self):
- StatBase.__init__(self, name='max')
+ StatBase.__init__(self, name="max")
@docstring(StatBase)
def calculate(self, context):
@@ -787,8 +805,9 @@ class StatMax(StatBase):
class StatDelta(StatBase):
"""Compute the delta between minimal and maximal on data"""
+
def __init__(self):
- StatBase.__init__(self, name='delta')
+ StatBase.__init__(self, name="delta")
@docstring(StatBase)
def calculate(self, context):
@@ -822,8 +841,9 @@ class _StatCoord(StatBase):
class StatCoordMin(_StatCoord):
"""Compute the coordinates of the first minimum value of the data"""
+
def __init__(self):
- _StatCoord.__init__(self, name='coords min')
+ _StatCoord.__init__(self, name="coords min")
@docstring(StatBase)
def calculate(self, context):
@@ -840,8 +860,9 @@ class StatCoordMin(_StatCoord):
class StatCoordMax(_StatCoord):
"""Compute the coordinates of the first maximum value of the data"""
+
def __init__(self):
- _StatCoord.__init__(self, name='coords max')
+ _StatCoord.__init__(self, name="coords max")
@docstring(StatBase)
def calculate(self, context):
@@ -860,8 +881,9 @@ class StatCoordMax(_StatCoord):
class StatCOM(StatBase):
"""Compute data center of mass"""
+
def __init__(self):
- StatBase.__init__(self, name='COM', description='Center of mass')
+ StatBase.__init__(self, name="COM", description="Center of mass")
@docstring(StatBase)
def calculate(self, context):
@@ -870,7 +892,7 @@ class StatCOM(StatBase):
values = numpy.ma.array(context.values, mask=context.mask, dtype=numpy.float64)
sum_ = numpy.sum(values)
- if sum_ == 0. or numpy.ma.is_masked(sum_):
+ if sum_ == 0.0 or numpy.ma.is_masked(sum_):
return (numpy.nan,) * len(context.axes)
if context.isStructuredData():
@@ -878,11 +900,11 @@ class StatCOM(StatBase):
for index, axis in enumerate(context.axes):
axes = tuple([i for i in range(len(context.axes)) if i != index])
centerofmass.append(
- numpy.sum(axis * numpy.sum(values, axis=axes)) / sum_)
+ numpy.sum(axis * numpy.sum(values, axis=axes)) / sum_
+ )
return tuple(reversed(centerofmass))
else:
- return tuple(
- numpy.sum(axis * values) / sum_ for axis in context.axes)
+ return tuple(numpy.sum(axis * values) / sum_ for axis in context.axes)
@docstring(StatBase)
def getToolTip(self, kind):
diff --git a/src/silx/gui/plot/stats/statshandler.py b/src/silx/gui/plot/stats/statshandler.py
index 1531ba2..8e7e08b 100644
--- a/src/silx/gui/plot/stats/statshandler.py
+++ b/src/silx/gui/plot/stats/statshandler.py
@@ -48,8 +48,8 @@ class _FloatItem(qt.QTableWidgetItem):
qt.QTableWidgetItem.__init__(self, type=type)
def __lt__(self, other):
- self_values = self.text().lstrip('(').rstrip(')').split(',')
- other_values = other.text().lstrip('(').rstrip(')').split(',')
+ self_values = self.text().lstrip("(").rstrip(")").split(",")
+ other_values = other.text().lstrip("(").rstrip(")").split(",")
for self_value, other_value in zip(self_values, other_values):
f_self_value = float(self_value)
f_other_value = float(other_value)
@@ -67,7 +67,8 @@ class StatFormatter(object):
which will be used to display the result of the
statistic computation.
"""
- DEFAULT_FORMATTER = '{0:.3f}'
+
+ DEFAULT_FORMATTER = "{0:.3f}"
def __init__(self, formatter=DEFAULT_FORMATTER, qItemClass=_FloatItem):
self.formatter = formatter
@@ -121,9 +122,11 @@ class StatsHandler(object):
if isinstance(arg[0], statsmdl.StatBase):
stat = arg[0]
if len(arg) > 2:
- raise ValueError('To many argument with %s. At most one '
- 'argument can be associated with the '
- 'BaseStat (the `StatFormatter`')
+ raise ValueError(
+ "To many argument with %s. At most one "
+ "argument can be associated with the "
+ "BaseStat (the `StatFormatter`"
+ )
if len(arg) == 2:
assert arg[1] is None or isinstance(arg[1], (StatFormatter, str))
formatter = arg[1]
@@ -134,15 +137,20 @@ class StatsHandler(object):
arg = arg[0]
if type(arg[0]) is not str:
- raise ValueError('first element of the tuple should be a string'
- ' or a StatBase instance')
+ raise ValueError(
+ "first element of the tuple should be a string"
+ " or a StatBase instance"
+ )
if len(arg) == 1:
- raise ValueError('A function should be associated with the'
- 'stat name')
+ raise ValueError(
+ "A function should be associated with the" "stat name"
+ )
if len(arg) > 3:
- raise ValueError('Two much argument given for defining statistic.'
- 'Take at most three arguments (name, function, '
- 'kinds)')
+ raise ValueError(
+ "Two much argument given for defining statistic."
+ "Take at most three arguments (name, function, "
+ "kinds)"
+ )
if len(arg) == 2:
stat = statsmdl.Stat(name=arg[0], fct=arg[1])
else:
@@ -180,12 +188,13 @@ class StatsHandler(object):
if isinstance(val, (tuple, list)):
res = []
[res.append(self.formatters[name].format(_val)) for _val in val]
- return ', '.join(res)
+ return ", ".join(res)
else:
return self.formatters[name].format(val)
- def calculate(self, item, plot, onlimits, roi=None, data_changed=False,
- roi_changed=False):
+ def calculate(
+ self, item, plot, onlimits, roi=None, data_changed=False, roi_changed=False
+ ):
"""
compute all statistic registered and return the list of formatted
statistics result.
@@ -200,8 +209,14 @@ class StatsHandler(object):
:return: list of formatted statistics (as str)
:rtype: dict
"""
- res = self.stats.calculate(item, plot, onlimits, roi,
- data_changed=data_changed, roi_changed=roi_changed)
+ res = self.stats.calculate(
+ item,
+ plot,
+ onlimits,
+ roi,
+ data_changed=data_changed,
+ roi_changed=roi_changed,
+ )
for resName, resValue in list(res.items()):
res[resName] = self.format(resName, res[resName])
return res
diff --git a/src/silx/gui/plot/PlotTools.py b/src/silx/gui/plot/test/conftest.py
index 35d0f48..78475fb 100644
--- a/src/silx/gui/plot/PlotTools.py
+++ b/src/silx/gui/plot/test/conftest.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -21,20 +21,23 @@
# THE SOFTWARE.
#
# ###########################################################################*/
-"""Set of widgets to associate with a :class:'PlotWidget'.
-"""
+"""Test PlotWidget active item"""
__authors__ = ["T. Vincent"]
__license__ = "MIT"
-__date__ = "01/03/2018"
+__date__ = "13/12/2023"
-from ...utils.deprecation import deprecated_warning
+import pytest
+from silx.gui.plot import PlotWidget
-deprecated_warning(type_='module',
- name=__file__,
- reason='Plot tools refactoring',
- replacement='silx.gui.plot.tools',
- since_version='0.8')
-from .tools import PositionInfo, LimitsToolBar # noqa
+@pytest.fixture
+def plotWidget(qWidgetFactory, request):
+ try:
+ backend = request.param
+ except AttributeError:
+ backend = "mpl" # Backend was not defined
+ if backend == "gl":
+ request.getfixturevalue("use_opengl") # Skip test if OpenGL test disabled
+ yield qWidgetFactory(PlotWidget, backend=backend)
diff --git a/src/silx/gui/plot/test/testAlphaSlider.py b/src/silx/gui/plot/test/testAlphaSlider.py
index 8641da7..e9ccb45 100644
--- a/src/silx/gui/plot/test/testAlphaSlider.py
+++ b/src/silx/gui/plot/test/testAlphaSlider.py
@@ -29,7 +29,6 @@ __license__ = "MIT"
__date__ = "28/03/2017"
import numpy
-import unittest
from silx.gui import qt
from silx.gui.utils.testutils import TestCaseQt
@@ -76,19 +75,16 @@ class TestActiveImageAlphaSlider(TestCaseQt):
def testGetImage(self):
self.plot.addImage(numpy.array([[0, 1, 2], [3, 4, 5]]))
- self.assertEqual(self.plot.getActiveImage(),
- self.aslider.getItem())
+ self.assertEqual(self.plot.getActiveImage(), self.aslider.getItem())
self.plot.addImage(numpy.array([[0, 1, 3], [2, 4, 6]]), legend="2")
self.plot.setActiveImage("2")
- self.assertEqual(self.plot.getImage("2"),
- self.aslider.getItem())
+ self.assertEqual(self.plot.getImage("2"), self.aslider.getItem())
def testGetAlpha(self):
self.plot.addImage(numpy.array([[0, 1, 2], [3, 4, 5]]), legend="1")
self.aslider.setValue(137)
- self.assertAlmostEqual(self.aslider.getAlpha(),
- 137. / 255)
+ self.assertAlmostEqual(self.aslider.getAlpha(), 137.0 / 255)
class TestNamedImageAlphaSlider(TestCaseQt):
@@ -130,19 +126,16 @@ class TestNamedImageAlphaSlider(TestCaseQt):
self.plot.addImage(numpy.array([[0, 1, 2], [3, 4, 5]]), legend="1")
self.plot.addImage(numpy.array([[0, 1, 3], [2, 4, 6]]), legend="2")
self.aslider.setLegend("1")
- self.assertEqual(self.plot.getImage("1"),
- self.aslider.getItem())
+ self.assertEqual(self.plot.getImage("1"), self.aslider.getItem())
self.aslider.setLegend("2")
- self.assertEqual(self.plot.getImage("2"),
- self.aslider.getItem())
+ self.assertEqual(self.plot.getImage("2"), self.aslider.getItem())
def testGetAlpha(self):
self.plot.addImage(numpy.array([[0, 1, 2], [3, 4, 5]]), legend="1")
self.aslider.setLegend("1")
self.aslider.setValue(128)
- self.assertAlmostEqual(self.aslider.getAlpha(),
- 128. / 255)
+ self.assertAlmostEqual(self.aslider.getAlpha(), 128.0 / 255)
class TestNamedScatterAlphaSlider(TestCaseQt):
@@ -175,29 +168,22 @@ class TestNamedScatterAlphaSlider(TestCaseQt):
# no Scatter set initially, slider must be deactivate
self.assertFalse(self.aslider.isEnabled())
- self.plot.addScatter([0, 1, 2], [2, 3, 4], [5, 6, 7],
- legend="1")
+ self.plot.addScatter([0, 1, 2], [2, 3, 4], [5, 6, 7], legend="1")
self.aslider.setLegend("1")
# now we have an image set
self.assertTrue(self.aslider.isEnabled())
def testGetScatter(self):
- self.plot.addScatter([0, 1, 2], [2, 3, 4], [5, 6, 7],
- legend="1")
- self.plot.addScatter([0, 10, 20], [20, 30, 40], [50, 60, 70],
- legend="2")
+ self.plot.addScatter([0, 1, 2], [2, 3, 4], [5, 6, 7], legend="1")
+ self.plot.addScatter([0, 10, 20], [20, 30, 40], [50, 60, 70], legend="2")
self.aslider.setLegend("1")
- self.assertEqual(self.plot.getScatter("1"),
- self.aslider.getItem())
+ self.assertEqual(self.plot.getScatter("1"), self.aslider.getItem())
self.aslider.setLegend("2")
- self.assertEqual(self.plot.getScatter("2"),
- self.aslider.getItem())
+ self.assertEqual(self.plot.getScatter("2"), self.aslider.getItem())
def testGetAlpha(self):
- self.plot.addScatter([0, 10, 20], [20, 30, 40], [50, 60, 70],
- legend="1")
+ self.plot.addScatter([0, 10, 20], [20, 30, 40], [50, 60, 70], legend="1")
self.aslider.setLegend("1")
self.aslider.setValue(128)
- self.assertAlmostEqual(self.aslider.getAlpha(),
- 128. / 255)
+ self.assertAlmostEqual(self.aslider.getAlpha(), 128.0 / 255)
diff --git a/src/silx/gui/plot/test/testAxis.py b/src/silx/gui/plot/test/testAxis.py
new file mode 100644
index 0000000..dcf2f06
--- /dev/null
+++ b/src/silx/gui/plot/test/testAxis.py
@@ -0,0 +1,147 @@
+# /*##########################################################################
+#
+# Copyright (c) 2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Tests of PlotWidget Axis items"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "15/06/2023"
+
+
+from silx.gui.plot import PlotWidget
+
+
+def testAxisIsVisible(qapp, qWidgetFactory):
+ """Test Axis.isVisible method"""
+ plotWidget = qWidgetFactory(PlotWidget)
+
+ assert plotWidget.getXAxis().isVisible()
+ assert plotWidget.getYAxis().isVisible()
+ assert not plotWidget.getYAxis("right").isVisible()
+
+ # Add curve on right axis
+ plotWidget.addCurve((0, 1, 2), (1, 2, 3), yaxis="right")
+ qapp.processEvents()
+
+ assert plotWidget.getYAxis("right").isVisible()
+
+ # hide curve on right axis
+ curve = plotWidget.getItems()[0]
+ curve.setVisible(False)
+ qapp.processEvents()
+
+ assert not plotWidget.getYAxis("right").isVisible()
+
+ # show curve on right axis
+ curve.setVisible(True)
+ qapp.processEvents()
+
+ assert plotWidget.getYAxis("right").isVisible()
+
+ # Move curve to left axis
+ curve.setYAxis("left")
+ qapp.processEvents()
+
+ assert not plotWidget.getYAxis("right").isVisible()
+
+
+def testAxisSetScaleLogNoData(qapp, qWidgetFactory):
+ """Test Axis.setScale('log') method with an empty plot
+
+ Limits are reset only when negative
+ """
+ plotWidget = qWidgetFactory(PlotWidget)
+ xaxis = plotWidget.getXAxis()
+ yaxis = plotWidget.getYAxis()
+ y2axis = plotWidget.getYAxis("right")
+
+ xaxis.setLimits(-1.0, 1.0)
+ yaxis.setLimits(2.0, 3.0)
+ y2axis.setLimits(-2.0, -1.0)
+
+ xaxis.setScale("log")
+ qapp.processEvents()
+
+ assert xaxis.getLimits() == (1.0, 100.0)
+ assert yaxis.getLimits() == (2.0, 3.0)
+ assert y2axis.getLimits() == (-2.0, -1.0)
+
+ xaxis.setLimits(10.0, 20.0)
+
+ yaxis.setScale("log")
+ qapp.processEvents()
+
+ assert xaxis.getLimits() == (10.0, 20.0)
+ assert yaxis.getLimits() == (2.0, 3.0) # Positive range is preserved
+ assert y2axis.getLimits() == (1.0, 100.0) # Negative min is reset
+
+
+def testAxisSetScaleLogWithData(qapp, qWidgetFactory):
+ """Test Axis.setScale('log') method with data
+
+ Limits are reset only when negative and takes the data range into account
+ """
+ plotWidget = qWidgetFactory(PlotWidget)
+ xaxis = plotWidget.getXAxis()
+ yaxis = plotWidget.getYAxis()
+ plotWidget.addCurve((-1, 1, 2, 3), (-1, 1, 2, 3))
+
+ xaxis.setLimits(-1.0, 0.5) # Limits contains no positive data
+ yaxis.setLimits(-1.0, 2.0) # Limits contains positive data
+
+ xaxis.setScale("log")
+ yaxis.setScale("log")
+ qapp.processEvents()
+
+ assert xaxis.getLimits() == (1.0, 3.0) # Reset to positive data range
+ assert yaxis.getLimits() == (1.0, 2.0) # Keep max limit
+
+
+def testAxisSetScaleLinear(qapp, qWidgetFactory):
+ """Test Axis.setScale('linear') method: Limits are not changed"""
+ plotWidget = qWidgetFactory(PlotWidget)
+ xaxis = plotWidget.getXAxis()
+ yaxis = plotWidget.getYAxis()
+ y2axis = plotWidget.getYAxis("right")
+ xaxis.setScale("log")
+ yaxis.setScale("log")
+ plotWidget.resetZoom()
+ qapp.processEvents()
+
+ xaxis.setLimits(10.0, 1000.0)
+ yaxis.setLimits(20.0, 2000.0)
+ y2axis.setLimits(30.0, 3000.0)
+
+ xaxis.setScale("linear")
+ qapp.processEvents()
+
+ assert xaxis.getLimits() == (10.0, 1000.0)
+ assert yaxis.getLimits() == (20.0, 2000.0)
+ assert y2axis.getLimits() == (30.0, 3000.0)
+
+ yaxis.setScale("linear")
+ qapp.processEvents()
+
+ assert xaxis.getLimits() == (10.0, 1000.0)
+ assert yaxis.getLimits() == (20.0, 2000.0)
+ assert y2axis.getLimits() == (30.0, 3000.0)
diff --git a/src/silx/gui/plot/test/testColorBar.py b/src/silx/gui/plot/test/testColorBar.py
index 199726b..7202bc2 100644
--- a/src/silx/gui/plot/test/testColorBar.py
+++ b/src/silx/gui/plot/test/testColorBar.py
@@ -27,7 +27,6 @@ __authors__ = ["H. Payno"]
__license__ = "MIT"
__date__ = "24/04/2018"
-import unittest
from silx.gui.utils.testutils import TestCaseQt
from silx.gui.plot.ColorBar import _ColorScale
from silx.gui.plot.ColorBar import ColorBarWidget
@@ -40,6 +39,7 @@ import numpy
class TestColorScale(TestCaseQt):
"""Test that interaction with the colorScale is correct"""
+
def setUp(self):
super(TestColorScale, self).setUp()
self.colorScaleWidget = _ColorScale(colormap=None, parent=None)
@@ -59,37 +59,32 @@ class TestColorScale(TestCaseQt):
self.assertIsNone(colormap)
def testRelativePositionLinear(self):
- self.colorMapLin1 = Colormap(name='gray',
- normalization=Colormap.LINEAR,
- vmin=0.0,
- vmax=1.0)
+ self.colorMapLin1 = Colormap(
+ name="gray", normalization=Colormap.LINEAR, vmin=0.0, vmax=1.0
+ )
self.colorScaleWidget.setColormap(self.colorMapLin1)
self.assertTrue(
- self.colorScaleWidget.getValueFromRelativePosition(0.25) == 0.25)
- self.assertTrue(
- self.colorScaleWidget.getValueFromRelativePosition(0.5) == 0.5)
- self.assertTrue(
- self.colorScaleWidget.getValueFromRelativePosition(1.0) == 1.0)
-
- self.colorMapLin2 = Colormap(name='viridis',
- normalization=Colormap.LINEAR,
- vmin=-10,
- vmax=0)
+ self.colorScaleWidget.getValueFromRelativePosition(0.25) == 0.25
+ )
+ self.assertTrue(self.colorScaleWidget.getValueFromRelativePosition(0.5) == 0.5)
+ self.assertTrue(self.colorScaleWidget.getValueFromRelativePosition(1.0) == 1.0)
+
+ self.colorMapLin2 = Colormap(
+ name="viridis", normalization=Colormap.LINEAR, vmin=-10, vmax=0
+ )
self.colorScaleWidget.setColormap(self.colorMapLin2)
self.assertTrue(
- self.colorScaleWidget.getValueFromRelativePosition(0.25) == -7.5)
- self.assertTrue(
- self.colorScaleWidget.getValueFromRelativePosition(0.5) == -5.0)
- self.assertTrue(
- self.colorScaleWidget.getValueFromRelativePosition(1.0) == 0.0)
+ self.colorScaleWidget.getValueFromRelativePosition(0.25) == -7.5
+ )
+ self.assertTrue(self.colorScaleWidget.getValueFromRelativePosition(0.5) == -5.0)
+ self.assertTrue(self.colorScaleWidget.getValueFromRelativePosition(1.0) == 0.0)
def testRelativePositionLog(self):
- self.colorMapLog1 = Colormap(name='temperature',
- normalization=Colormap.LOGARITHM,
- vmin=1.0,
- vmax=100.0)
+ self.colorMapLog1 = Colormap(
+ name="temperature", normalization=Colormap.LOGARITHM, vmin=1.0, vmax=100.0
+ )
self.colorScaleWidget.setColormap(self.colorMapLog1)
@@ -130,14 +125,13 @@ class TestNoAutoscale(TestCaseQt):
super(TestNoAutoscale, self).tearDown()
def testLogNormNoAutoscale(self):
- colormapLog = Colormap(name='gray',
- normalization=Colormap.LOGARITHM,
- vmin=1.0,
- vmax=100.0)
+ colormapLog = Colormap(
+ name="gray", normalization=Colormap.LOGARITHM, vmin=1.0, vmax=100.0
+ )
data = numpy.linspace(10, 1e10, 9).reshape(3, 3)
- self.plot.addImage(data=data, colormap=colormapLog, legend='toto')
- self.plot.setActiveImage('toto')
+ self.plot.addImage(data=data, colormap=colormapLog, legend="toto")
+ self.plot.setActiveImage("toto")
# test Ticks
self.tickBar.setTicksNumber(10)
@@ -155,14 +149,13 @@ class TestNoAutoscale(TestCaseQt):
self.assertTrue(val == 1.0)
def testLinearNormNoAutoscale(self):
- colormapLog = Colormap(name='gray',
- normalization=Colormap.LINEAR,
- vmin=-4,
- vmax=5)
+ colormapLog = Colormap(
+ name="gray", normalization=Colormap.LINEAR, vmin=-4, vmax=5
+ )
data = numpy.linspace(1, 9, 9).reshape(3, 3)
- self.plot.addImage(data=data, colormap=colormapLog, legend='toto')
- self.plot.setActiveImage('toto')
+ self.plot.addImage(data=data, colormap=colormapLog, legend="toto")
+ self.plot.setActiveImage("toto")
# test Ticks
self.tickBar.setTicksNumber(10)
@@ -209,15 +202,14 @@ class TestColorBarWidget(TestCaseQt):
Note : colorbar is modified by the Plot directly not ColorBarWidget
"""
- colormapLog = Colormap(name='gray',
- normalization=Colormap.LOGARITHM,
- vmin=None,
- vmax=None)
+ colormapLog = Colormap(
+ name="gray", normalization=Colormap.LOGARITHM, vmin=None, vmax=None
+ )
data = numpy.array([-5, -4, 0, 2, 3, 5, 10, 20, 30])
data = data.reshape(3, 3)
- self.plot.addImage(data=data, colormap=colormapLog, legend='toto')
- self.plot.setActiveImage('toto')
+ self.plot.addImage(data=data, colormap=colormapLog, legend="toto")
+ self.plot.setActiveImage("toto")
# default behavior when with log and negative values: should set vmin
# to 1 and vmax to 10
@@ -226,52 +218,43 @@ class TestColorBarWidget(TestCaseQt):
# if data is positive
data[data < 1] = data.max()
- self.plot.addImage(data=data,
- colormap=colormapLog,
- legend='toto',
- replace=True)
- self.plot.setActiveImage('toto')
+ self.plot.addImage(data=data, colormap=colormapLog, legend="toto", replace=True)
+ self.plot.setActiveImage("toto")
self.assertTrue(self.colorBar.getColorScaleBar().minVal == data.min())
self.assertTrue(self.colorBar.getColorScaleBar().maxVal == data.max())
def testPlotAssocation(self):
"""Make sure the ColorBarWidget is properly connected with the plot"""
- colormap = Colormap(name='gray',
- normalization=Colormap.LINEAR,
- vmin=None,
- vmax=None)
+ colormap = Colormap(
+ name="gray", normalization=Colormap.LINEAR, vmin=None, vmax=None
+ )
# make sure that default settings are the same (but a copy of the
self.colorBar.setPlot(self.plot)
- self.assertTrue(
- self.colorBar.getColormap() is self.plot.getDefaultColormap())
+ self.assertTrue(self.colorBar.getColormap() is self.plot.getDefaultColormap())
data = numpy.linspace(0, 10, 100).reshape(10, 10)
- self.plot.addImage(data=data, colormap=colormap, legend='toto')
- self.plot.setActiveImage('toto')
+ self.plot.addImage(data=data, colormap=colormap, legend="toto")
+ self.plot.setActiveImage("toto")
# make sure the modification of the colormap has been done
- self.assertFalse(
- self.colorBar.getColormap() is self.plot.getDefaultColormap())
- self.assertTrue(
- self.colorBar.getColormap() is colormap)
+ self.assertFalse(self.colorBar.getColormap() is self.plot.getDefaultColormap())
+ self.assertTrue(self.colorBar.getColormap() is colormap)
# test that colorbar is updated when default plot colormap changes
self.plot.clear()
- plotColormap = Colormap(name='gray',
- normalization=Colormap.LOGARITHM,
- vmin=None,
- vmax=None)
+ plotColormap = Colormap(
+ name="gray", normalization=Colormap.LOGARITHM, vmin=None, vmax=None
+ )
self.plot.setDefaultColormap(plotColormap)
self.assertTrue(self.colorBar.getColormap() is plotColormap)
def testColormapWithoutRange(self):
"""Test with a colormap with vmin==vmax"""
- colormap = Colormap(name='gray',
- normalization=Colormap.LINEAR,
- vmin=1.0,
- vmax=1.0)
+ colormap = Colormap(
+ name="gray", normalization=Colormap.LINEAR, vmin=1.0, vmax=1.0
+ )
self.colorBar.setColormap(colormap)
@@ -300,40 +283,35 @@ class TestColorBarUpdate(TestCaseQt):
super(TestColorBarUpdate, self).tearDown()
def testUpdateColorMap(self):
- colormap = Colormap(name='gray',
- normalization='linear',
- vmin=0,
- vmax=1)
+ colormap = Colormap(name="gray", normalization="linear", vmin=0, vmax=1)
# check inital state
- self.plot.addImage(data=self.data, colormap=colormap, legend='toto')
- self.plot.setActiveImage('toto')
+ self.plot.addImage(data=self.data, colormap=colormap, legend="toto")
+ self.plot.setActiveImage("toto")
self.assertTrue(self.colorBar.getColorScaleBar().minVal == 0)
self.assertTrue(self.colorBar.getColorScaleBar().maxVal == 1)
- self.assertTrue(
- self.colorBar.getColorScaleBar().getTickBar()._vmin == 0)
- self.assertTrue(
- self.colorBar.getColorScaleBar().getTickBar()._vmax == 1)
+ self.assertTrue(self.colorBar.getColorScaleBar().getTickBar()._vmin == 0)
+ self.assertTrue(self.colorBar.getColorScaleBar().getTickBar()._vmax == 1)
self.assertIsInstance(
self.colorBar.getColorScaleBar().getTickBar()._normalizer,
- LinearNormalization)
+ LinearNormalization,
+ )
# update colormap
colormap.setVMin(0.5)
self.assertTrue(self.colorBar.getColorScaleBar().minVal == 0.5)
- self.assertTrue(
- self.colorBar.getColorScaleBar().getTickBar()._vmin == 0.5)
+ self.assertTrue(self.colorBar.getColorScaleBar().getTickBar()._vmin == 0.5)
colormap.setVMax(0.8)
self.assertTrue(self.colorBar.getColorScaleBar().maxVal == 0.8)
- self.assertTrue(
- self.colorBar.getColorScaleBar().getTickBar()._vmax == 0.8)
+ self.assertTrue(self.colorBar.getColorScaleBar().getTickBar()._vmax == 0.8)
- colormap.setNormalization('log')
+ colormap.setNormalization("log")
self.assertIsInstance(
self.colorBar.getColorScaleBar().getTickBar()._normalizer,
- LogarithmicNormalization)
+ LogarithmicNormalization,
+ )
# TODO : should also check that if the colormap is changing then values (especially in log scale)
# should be coherent if in autoscale
diff --git a/src/silx/gui/plot/test/testCompareImages.py b/src/silx/gui/plot/test/testCompareImages.py
index 9b5065d..4bc52b4 100644
--- a/src/silx/gui/plot/test/testCompareImages.py
+++ b/src/silx/gui/plot/test/testCompareImages.py
@@ -27,79 +27,210 @@ __authors__ = ["H. Payno"]
__license__ = "MIT"
__date__ = "23/07/2018"
-import unittest
+import pytest
import numpy
import weakref
-from silx.gui.utils.testutils import TestCaseQt
+from silx.gui import qt
from silx.gui.plot.CompareImages import CompareImages
-class TestCompareImages(TestCaseQt):
- """Test that CompareImages widget is working in some cases"""
-
- def setUp(self):
- super(TestCompareImages, self).setUp()
- self.widget = CompareImages()
-
- def tearDown(self):
- ref = weakref.ref(self.widget)
- self.widget = None
- self.qWaitForDestroy(ref)
- super(TestCompareImages, self).tearDown()
-
- def testIntensityImage(self):
- image1 = numpy.random.rand(10, 10)
- image2 = numpy.random.rand(10, 10)
- self.widget.setData(image1, image2)
-
- def testRgbImage(self):
- image1 = numpy.random.randint(0, 255, size=(10, 10, 3))
- image2 = numpy.random.randint(0, 255, size=(10, 10, 3))
- self.widget.setData(image1, image2)
-
- def testRgbaImage(self):
- image1 = numpy.random.randint(0, 255, size=(10, 10, 4))
- image2 = numpy.random.randint(0, 255, size=(10, 10, 4))
- self.widget.setData(image1, image2)
-
- def testVizualisations(self):
- image1 = numpy.random.rand(10, 10)
- image2 = numpy.random.rand(10, 10)
- self.widget.setData(image1, image2)
- for mode in CompareImages.VisualizationMode:
- self.widget.setVisualizationMode(mode)
-
- def testAlignemnt(self):
- image1 = numpy.random.rand(10, 10)
- image2 = numpy.random.rand(5, 5)
- self.widget.setData(image1, image2)
- for mode in CompareImages.AlignmentMode:
- self.widget.setAlignmentMode(mode)
-
- def testGetPixel(self):
- image1 = numpy.random.rand(11, 11)
- image2 = numpy.random.rand(5, 5)
- image1[5, 5] = 111.111
- image2[2, 2] = 222.222
- self.widget.setData(image1, image2)
- expectedValue = {}
- expectedValue[CompareImages.AlignmentMode.CENTER] = 222.222
- expectedValue[CompareImages.AlignmentMode.STRETCH] = 222.222
- expectedValue[CompareImages.AlignmentMode.ORIGIN] = None
- for mode in expectedValue.keys():
- self.widget.setAlignmentMode(mode)
- data = self.widget.getRawPixelData(11 / 2.0, 11 / 2.0)
- data1, data2 = data
- self.assertEqual(data1, 111.111)
- self.assertEqual(data2, expectedValue[mode])
-
- def testImageEmpty(self):
- self.widget.setData(image1=None, image2=None)
- self.assertTrue(self.widget.getRawPixelData(11 / 2.0, 11 / 2.0) == (None, None))
-
- def testSetImageSeparately(self):
- self.widget.setImage1(numpy.random.rand(10, 10))
- self.widget.setImage2(numpy.random.rand(10, 10))
- for mode in CompareImages.VisualizationMode:
- self.widget.setVisualizationMode(mode)
+@pytest.fixture
+def compareImages(qapp, qapp_utils):
+ widget = CompareImages()
+ widget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ yield widget
+ widget.close()
+ ref = weakref.ref(widget)
+ widget = None
+ qapp_utils.qWaitForDestroy(ref)
+
+
+def testIntensityImage(compareImages):
+ image1 = numpy.random.rand(10, 10)
+ image2 = numpy.random.rand(10, 10)
+ compareImages.setData(image1, image2)
+
+
+def testRgbImage(compareImages):
+ image1 = numpy.random.randint(0, 255, size=(10, 10, 3))
+ image2 = numpy.random.randint(0, 255, size=(10, 10, 3))
+ compareImages.setData(image1, image2)
+
+
+def testRgbaImage(compareImages):
+ image1 = numpy.random.randint(0, 255, size=(10, 10, 4))
+ image2 = numpy.random.randint(0, 255, size=(10, 10, 4))
+ compareImages.setData(image1, image2)
+
+
+def testAlignemnt(compareImages):
+ image1 = numpy.random.rand(10, 10)
+ image2 = numpy.random.rand(5, 5)
+ compareImages.setData(image1, image2)
+ for mode in CompareImages.AlignmentMode:
+ compareImages.setAlignmentMode(mode)
+
+
+def testGetPixel(compareImages):
+ image1 = numpy.random.rand(11, 11)
+ image2 = numpy.random.rand(5, 5)
+ image1[5, 5] = 111.111
+ image2[2, 2] = 222.222
+ compareImages.setData(image1, image2)
+ expectedValue = {}
+ expectedValue[CompareImages.AlignmentMode.CENTER] = 222.222
+ expectedValue[CompareImages.AlignmentMode.STRETCH] = 222.222
+ expectedValue[CompareImages.AlignmentMode.ORIGIN] = None
+ for mode in expectedValue.keys():
+ compareImages.setAlignmentMode(mode)
+ data = compareImages.getRawPixelData(11 / 2.0, 11 / 2.0)
+ data1, data2 = data
+ assert data1 == 111.111
+ assert data2 == expectedValue[mode]
+
+
+def testImageEmpty(compareImages):
+ compareImages.setData(image1=None, image2=None)
+
+
+def testSetImageSeparately(compareImages):
+ compareImages.setImage1(numpy.random.rand(10, 10))
+ compareImages.setImage2(numpy.random.rand(10, 10))
+
+
+@pytest.mark.parametrize(
+ "data",
+ [
+ (CompareImages.VisualizationMode.COMPOSITE_A_MINUS_B,),
+ (CompareImages.VisualizationMode.COMPOSITE_RED_BLUE_GRAY,),
+ (CompareImages.VisualizationMode.HORIZONTAL_LINE,),
+ (CompareImages.VisualizationMode.VERTICAL_LINE,),
+ (CompareImages.VisualizationMode.ONLY_A,),
+ (CompareImages.VisualizationMode.ONLY_B,),
+ ],
+)
+def testVisualizationMode(compareImages, data):
+ (visualizationMode,) = data
+ compareImages.setImage1(numpy.random.rand(10, 10))
+ compareImages.setImage2(numpy.random.rand(10, 10))
+ compareImages.setVisualizationMode(visualizationMode)
+
+
+@pytest.mark.parametrize(
+ "data",
+ [
+ (CompareImages.VisualizationMode.COMPOSITE_A_MINUS_B,),
+ (CompareImages.VisualizationMode.COMPOSITE_RED_BLUE_GRAY,),
+ (CompareImages.VisualizationMode.HORIZONTAL_LINE,),
+ (CompareImages.VisualizationMode.VERTICAL_LINE,),
+ (CompareImages.VisualizationMode.ONLY_A,),
+ (CompareImages.VisualizationMode.ONLY_B,),
+ ],
+)
+def testVisualizationModeWithoutImage(compareImages, data):
+ (visualizationMode,) = data
+ compareImages.setImage1(None)
+ compareImages.setImage2(None)
+ compareImages.setVisualizationMode(visualizationMode)
+
+
+@pytest.mark.parametrize(
+ "data",
+ [
+ (CompareImages.VisualizationMode.COMPOSITE_A_MINUS_B,),
+ (CompareImages.VisualizationMode.COMPOSITE_RED_BLUE_GRAY,),
+ (CompareImages.VisualizationMode.HORIZONTAL_LINE,),
+ (CompareImages.VisualizationMode.VERTICAL_LINE,),
+ (CompareImages.VisualizationMode.ONLY_A,),
+ (CompareImages.VisualizationMode.ONLY_B,),
+ ],
+)
+def testVisualizationModeWithOnlyImage1(compareImages, data):
+ (visualizationMode,) = data
+ compareImages.setImage1(numpy.random.rand(10, 10))
+ compareImages.setImage2(None)
+ compareImages.setVisualizationMode(visualizationMode)
+
+
+@pytest.mark.parametrize(
+ "data",
+ [
+ (CompareImages.VisualizationMode.COMPOSITE_A_MINUS_B,),
+ (CompareImages.VisualizationMode.COMPOSITE_RED_BLUE_GRAY,),
+ (CompareImages.VisualizationMode.HORIZONTAL_LINE,),
+ (CompareImages.VisualizationMode.VERTICAL_LINE,),
+ (CompareImages.VisualizationMode.ONLY_A,),
+ (CompareImages.VisualizationMode.ONLY_B,),
+ ],
+)
+def testVisualizationModeWithOnlyImage2(compareImages, data):
+ (visualizationMode,) = data
+ compareImages.setImage1(None)
+ compareImages.setImage2(numpy.random.rand(10, 10))
+ compareImages.setVisualizationMode(visualizationMode)
+
+
+@pytest.mark.parametrize(
+ "data",
+ [
+ (CompareImages.VisualizationMode.COMPOSITE_A_MINUS_B,),
+ (CompareImages.VisualizationMode.COMPOSITE_RED_BLUE_GRAY,),
+ (CompareImages.VisualizationMode.HORIZONTAL_LINE,),
+ (CompareImages.VisualizationMode.VERTICAL_LINE,),
+ (CompareImages.VisualizationMode.ONLY_A,),
+ (CompareImages.VisualizationMode.ONLY_B,),
+ ],
+)
+def testVisualizationModeWithRGBImage(compareImages, data):
+ (visualizationMode,) = data
+ image1 = numpy.random.rand(10, 10)
+ image2 = numpy.random.randint(0, 255, size=(10, 10, 3))
+ compareImages.setData(image1, image2)
+ compareImages.setVisualizationMode(visualizationMode)
+
+
+@pytest.mark.parametrize(
+ "data",
+ [
+ (CompareImages.AlignmentMode.STRETCH,),
+ (CompareImages.AlignmentMode.AUTO,),
+ (CompareImages.AlignmentMode.CENTER,),
+ (CompareImages.AlignmentMode.ORIGIN,),
+ ],
+)
+def testAlignemntModeWithoutImages(compareImages, data):
+ (alignmentMode,) = data
+ compareImages.setAlignmentMode(alignmentMode)
+
+
+@pytest.mark.parametrize(
+ "data",
+ [
+ (CompareImages.AlignmentMode.STRETCH,),
+ (CompareImages.AlignmentMode.AUTO,),
+ (CompareImages.AlignmentMode.CENTER,),
+ (CompareImages.AlignmentMode.ORIGIN,),
+ ],
+)
+def testAlignemntModeWithSingleImage(compareImages, data):
+ (alignmentMode,) = data
+ compareImages.setImage1(numpy.arange(9).reshape(3, 3))
+ compareImages.setAlignmentMode(alignmentMode)
+
+
+def testTooltip(compareImages):
+ compareImages.setImage1(numpy.arange(9).reshape(3, 3))
+ compareImages.setImage2(numpy.arange(9).reshape(3, 3))
+ compareImages.getRawPixelData(1.5, 1.5)
+
+
+def testTooltipWithoutImage(compareImages):
+ compareImages.setImage1(numpy.arange(9).reshape(3, 3))
+ compareImages.setImage2(numpy.arange(9).reshape(3, 3))
+ compareImages.getRawPixelData(1.5, 1.5)
+
+
+def testTooltipWithSingleImage(compareImages):
+ compareImages.setImage1(numpy.arange(9).reshape(3, 3))
+ compareImages.getRawPixelData(1.5, 1.5)
diff --git a/src/silx/gui/plot/test/testComplexImageView.py b/src/silx/gui/plot/test/testComplexImageView.py
index c26df25..f8b331b 100644
--- a/src/silx/gui/plot/test/testComplexImageView.py
+++ b/src/silx/gui/plot/test/testComplexImageView.py
@@ -28,7 +28,6 @@ __license__ = "MIT"
__date__ = "17/01/2018"
-import unittest
import logging
import numpy
@@ -57,7 +56,7 @@ class TestComplexImageView(PlotWidgetTestCase, ParametricTestCase):
# Test colormap API
colormap = self.plot.getColormap().copy()
- colormap.setName('magma')
+ colormap.setName("magma")
self.plot.setColormap(colormap)
self.qWait(100)
diff --git a/src/silx/gui/plot/test/testCurvesROIWidget.py b/src/silx/gui/plot/test/testCurvesROIWidget.py
index 32ac057..05acd36 100644
--- a/src/silx/gui/plot/test/testCurvesROIWidget.py
+++ b/src/silx/gui/plot/test/testCurvesROIWidget.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -30,8 +30,6 @@ __date__ = "16/11/2017"
import logging
import os.path
-import pytest
-from collections import OrderedDict
import numpy
from silx.gui import qt
@@ -40,9 +38,7 @@ from silx.gui.plot import Plot1D
from silx.test.utils import temp_dir
from silx.gui.utils.testutils import TestCaseQt, SignalListener
from silx.gui.plot import PlotWindow, CurvesROIWidget
-from silx.gui.plot.CurvesROIWidget import ROITable
from silx.gui.utils.testutils import getQToolButtonFromAction
-from silx.gui.plot.PlotInteraction import ItemsInteraction
_logger = logging.getLogger(__name__)
@@ -74,10 +70,12 @@ class TestCurvesROIWidget(TestCaseQt):
def testDummyAPI(self):
"""Simple test of the getRois and setRois API"""
- roi_neg = CurvesROIWidget.ROI(name='negative', fromdata=-20,
- todata=-10, type_='X')
- roi_pos = CurvesROIWidget.ROI(name='positive', fromdata=10,
- todata=20, type_='X')
+ roi_neg = CurvesROIWidget.ROI(
+ name="negative", fromdata=-20, todata=-10, type_="X"
+ )
+ roi_pos = CurvesROIWidget.ROI(
+ name="positive", fromdata=10, todata=20, type_="X"
+ )
self.widget.roiWidget.setRois((roi_pos, roi_neg))
@@ -87,9 +85,11 @@ class TestCurvesROIWidget(TestCaseQt):
def testWithCurves(self):
"""Plot with curves: test all ROI widget buttons"""
for offset in range(2):
- self.plot.addCurve(numpy.arange(1000),
- offset + numpy.random.random(1000),
- legend=str(offset))
+ self.plot.addCurve(
+ numpy.arange(1000),
+ offset + numpy.random.random(1000),
+ legend=str(offset),
+ )
# Add two ROI
self.mouseClick(self.widget.roiWidget.addButton, qt.Qt.LeftButton)
@@ -105,7 +105,7 @@ class TestCurvesROIWidget(TestCaseQt):
self.qWait(200)
with temp_dir() as tmpDir:
- self.tmpFile = os.path.join(tmpDir, 'test.ini')
+ self.tmpFile = os.path.join(tmpDir, "test.ini")
# Save ROIs
self.widget.roiWidget.save(self.tmpFile)
@@ -113,13 +113,12 @@ class TestCurvesROIWidget(TestCaseQt):
self.assertEqual(len(self.widget.getRois()), 2)
# Reset ROIs
- self.mouseClick(self.widget.roiWidget.resetButton,
- qt.Qt.LeftButton)
+ self.mouseClick(self.widget.roiWidget.resetButton, qt.Qt.LeftButton)
self.qWait(200)
rois = self.widget.getRois()
self.assertEqual(len(rois), 1)
roiID = list(rois.keys())[0]
- self.assertEqual(rois[roiID].getName(), 'ICR')
+ self.assertEqual(rois[roiID].getName(), "ICR")
# Load ROIs
self.widget.roiWidget.load(self.tmpFile)
@@ -135,18 +134,20 @@ class TestCurvesROIWidget(TestCaseQt):
self.mouseClick(self.widget.roiWidget.addButton, qt.Qt.LeftButton)
for roiID in self.widget.roiWidget.roiTable._markersHandler._roiMarkerHandlers:
- handler = self.widget.roiWidget.roiTable._markersHandler._roiMarkerHandlers[roiID]
- assert handler.getMarker('min')
- xleftMarker = handler.getMarker('min').getXPosition()
- xMiddleMarker = handler.getMarker('middle').getXPosition()
- xRightMarker = handler.getMarker('max').getXPosition()
- thValue = xleftMarker + (xRightMarker - xleftMarker) / 2.
+ handler = self.widget.roiWidget.roiTable._markersHandler._roiMarkerHandlers[
+ roiID
+ ]
+ assert handler.getMarker("min")
+ xleftMarker = handler.getMarker("min").getXPosition()
+ xMiddleMarker = handler.getMarker("middle").getXPosition()
+ xRightMarker = handler.getMarker("max").getXPosition()
+ thValue = xleftMarker + (xRightMarker - xleftMarker) / 2.0
self.assertAlmostEqual(xMiddleMarker, thValue)
def testAreaCalculation(self):
"""Test result of area calculation"""
- x = numpy.arange(100.)
- y = numpy.arange(100.)
+ x = numpy.arange(100.0)
+ y = numpy.arange(100.0)
# Add two curves
self.plot.addCurve(x, y, legend="positive")
@@ -156,30 +157,30 @@ class TestCurvesROIWidget(TestCaseQt):
self.plot.setActiveCurve("positive")
# Add two ROIs
- roi_neg = CurvesROIWidget.ROI(name='negative', fromdata=-20,
- todata=-10, type_='X')
- roi_pos = CurvesROIWidget.ROI(name='positive', fromdata=10,
- todata=20, type_='X')
+ roi_neg = CurvesROIWidget.ROI(
+ name="negative", fromdata=-20, todata=-10, type_="X"
+ )
+ roi_pos = CurvesROIWidget.ROI(
+ name="positive", fromdata=10, todata=20, type_="X"
+ )
self.widget.roiWidget.setRois((roi_pos, roi_neg))
- posCurve = self.plot.getCurve('positive')
- negCurve = self.plot.getCurve('negative')
+ posCurve = self.plot.getCurve("positive")
+ negCurve = self.plot.getCurve("negative")
- self.assertEqual(roi_pos.computeRawAndNetArea(posCurve),
- (numpy.trapz(y=[10, 20], x=[10, 20]),
- 0.0))
- self.assertEqual(roi_pos.computeRawAndNetArea(negCurve),
- (0.0, 0.0))
- self.assertEqual(roi_neg.computeRawAndNetArea(posCurve),
- ((0.0), 0.0))
- self.assertEqual(roi_neg.computeRawAndNetArea(negCurve),
- ((-150.0), 0.0))
+ self.assertEqual(
+ roi_pos.computeRawAndNetArea(posCurve),
+ (numpy.trapz(y=[10, 20], x=[10, 20]), 0.0),
+ )
+ self.assertEqual(roi_pos.computeRawAndNetArea(negCurve), (0.0, 0.0))
+ self.assertEqual(roi_neg.computeRawAndNetArea(posCurve), ((0.0), 0.0))
+ self.assertEqual(roi_neg.computeRawAndNetArea(negCurve), ((-150.0), 0.0))
def testCountsCalculation(self):
"""Test result of count calculation"""
- x = numpy.arange(100.)
- y = numpy.arange(100.)
+ x = numpy.arange(100.0)
+ y = numpy.arange(100.0)
# Add two curves
self.plot.addCurve(x, y, legend="positive")
@@ -189,36 +190,38 @@ class TestCurvesROIWidget(TestCaseQt):
self.plot.setActiveCurve("positive")
# Add two ROIs
- roi_neg = CurvesROIWidget.ROI(name='negative', fromdata=-20,
- todata=-10, type_='X')
- roi_pos = CurvesROIWidget.ROI(name='positive', fromdata=10,
- todata=20, type_='X')
+ roi_neg = CurvesROIWidget.ROI(
+ name="negative", fromdata=-20, todata=-10, type_="X"
+ )
+ roi_pos = CurvesROIWidget.ROI(
+ name="positive", fromdata=10, todata=20, type_="X"
+ )
self.widget.roiWidget.setRois((roi_pos, roi_neg))
- posCurve = self.plot.getCurve('positive')
- negCurve = self.plot.getCurve('negative')
+ posCurve = self.plot.getCurve("positive")
+ negCurve = self.plot.getCurve("negative")
- self.assertEqual(roi_pos.computeRawAndNetCounts(posCurve),
- (y[10:21].sum(), 0.0))
- self.assertEqual(roi_pos.computeRawAndNetCounts(negCurve),
- (0.0, 0.0))
- self.assertEqual(roi_neg.computeRawAndNetCounts(posCurve),
- ((0.0), 0.0))
- self.assertEqual(roi_neg.computeRawAndNetCounts(negCurve),
- (y[10:21].sum(), 0.0))
+ self.assertEqual(
+ roi_pos.computeRawAndNetCounts(posCurve), (y[10:21].sum(), 0.0)
+ )
+ self.assertEqual(roi_pos.computeRawAndNetCounts(negCurve), (0.0, 0.0))
+ self.assertEqual(roi_neg.computeRawAndNetCounts(posCurve), ((0.0), 0.0))
+ self.assertEqual(
+ roi_neg.computeRawAndNetCounts(negCurve), (y[10:21].sum(), 0.0)
+ )
def testDeferedInit(self):
"""Test behavior of the deferedInit"""
- x = numpy.arange(100.)
- y = numpy.arange(100.)
+ x = numpy.arange(100.0)
+ y = numpy.arange(100.0)
self.plot.addCurve(x=x, y=y, legend="name", replace="True")
- roisDefs = OrderedDict([
- ["range1",
- OrderedDict([["from", 20], ["to", 200], ["type", "energy"]])],
- ["range2",
- OrderedDict([["from", 300], ["to", 500], ["type", "energy"]])]
- ])
+ roisDefs = dict(
+ [
+ ["range1", dict([["from", 20], ["to", 200], ["type", "energy"]])],
+ ["range2", dict([["from", 300], ["to", 500], ["type", "energy"]])],
+ ]
+ )
roiWidget = self.plot.getCurvesRoiDockWidget().roiWidget
self.plot.getCurvesRoiDockWidget().setRois(roisDefs)
@@ -228,34 +231,41 @@ class TestCurvesROIWidget(TestCaseQt):
def testDictCompatibility(self):
"""Test that ROI api is valid with dict and not information is lost"""
- roiDict = {'from': 20, 'to': 200, 'type': 'energy', 'comment': 'no',
- 'name': 'myROI', 'calibration': [1, 2, 3]}
+ roiDict = {
+ "from": 20,
+ "to": 200,
+ "type": "energy",
+ "comment": "no",
+ "name": "myROI",
+ "calibration": [1, 2, 3],
+ }
roi = CurvesROIWidget.ROI._fromDict(roiDict)
self.assertEqual(roi.toDict(), roiDict)
def testShowAllROI(self):
"""Test the show allROI action"""
- x = numpy.arange(100.)
- y = numpy.arange(100.)
+ x = numpy.arange(100.0)
+ y = numpy.arange(100.0)
self.plot.addCurve(x=x, y=y, legend="name", replace="True")
roisDefsDict = {
- "range1": {"from": 20, "to": 200,"type": "energy"},
- "range2": {"from": 300, "to": 500, "type": "energy"}
+ "range1": {"from": 20, "to": 200, "type": "energy"},
+ "range2": {"from": 300, "to": 500, "type": "energy"},
}
roisDefsObj = (
- CurvesROIWidget.ROI(name='range3', fromdata=20, todata=200,
- type_='energy'),
- CurvesROIWidget.ROI(name='range4', fromdata=300, todata=500,
- type_='energy')
+ CurvesROIWidget.ROI(name="range3", fromdata=20, todata=200, type_="energy"),
+ CurvesROIWidget.ROI(
+ name="range4", fromdata=300, todata=500, type_="energy"
+ ),
)
self.widget.roiWidget.showAllMarkers(True)
roiWidget = self.plot.getCurvesRoiDockWidget().roiWidget
roiWidget.setRois(roisDefsDict)
- markers = [item for item in self.plot.getItems()
- if isinstance(item, items.MarkerBase)]
- self.assertEqual(len(markers), 2*3)
+ markers = [
+ item for item in self.plot.getItems() if isinstance(item, items.MarkerBase)
+ ]
+ self.assertEqual(len(markers), 2 * 3)
markersHandler = self.widget.roiWidget.roiTable._markersHandler
roiWidget.showAllMarkers(True)
@@ -268,9 +278,10 @@ class TestCurvesROIWidget(TestCaseQt):
roiWidget.setRois(roisDefsObj)
self.qapp.processEvents()
- markers = [item for item in self.plot.getItems()
- if isinstance(item, items.MarkerBase)]
- self.assertEqual(len(markers), 2*3)
+ markers = [
+ item for item in self.plot.getItems() if isinstance(item, items.MarkerBase)
+ ]
+ self.assertEqual(len(markers), 2 * 3)
markersHandler = self.widget.roiWidget.roiTable._markersHandler
roiWidget.showAllMarkers(True)
@@ -282,52 +293,51 @@ class TestCurvesROIWidget(TestCaseQt):
self.assertEqual(len(ICRROI), 1)
def testRoiEdition(self):
- """Make sure if the ROI object is edited the ROITable will be updated
- """
- roi = CurvesROIWidget.ROI(name='linear', fromdata=0, todata=5)
- self.widget.roiWidget.setRois((roi, ))
+ """Make sure if the ROI object is edited the ROITable will be updated"""
+ roi = CurvesROIWidget.ROI(name="linear", fromdata=0, todata=5)
+ self.widget.roiWidget.setRois((roi,))
x = (0, 1, 1, 2, 2, 3)
y = (1, 1, 2, 2, 1, 1)
- self.plot.addCurve(x=x, y=y, legend='linearCurve')
- self.plot.setActiveCurve(legend='linearCurve')
+ self.plot.addCurve(x=x, y=y, legend="linearCurve")
+ self.plot.setActiveCurve(legend="linearCurve")
self.widget.calculateROIs()
roiTable = self.widget.roiWidget.roiTable
indexesColumns = CurvesROIWidget.ROITable.COLUMNS_INDEX
- itemRawCounts = roiTable.item(0, indexesColumns['Raw Counts'])
- itemNetCounts = roiTable.item(0, indexesColumns['Net Counts'])
+ itemRawCounts = roiTable.item(0, indexesColumns["Raw Counts"])
+ itemNetCounts = roiTable.item(0, indexesColumns["Net Counts"])
- self.assertTrue(itemRawCounts.text() == '8.0')
- self.assertTrue(itemNetCounts.text() == '2.0')
+ self.assertTrue(itemRawCounts.text() == "8.0")
+ self.assertTrue(itemNetCounts.text() == "2.0")
- itemRawArea = roiTable.item(0, indexesColumns['Raw Area'])
- itemNetArea = roiTable.item(0, indexesColumns['Net Area'])
+ itemRawArea = roiTable.item(0, indexesColumns["Raw Area"])
+ itemNetArea = roiTable.item(0, indexesColumns["Net Area"])
- self.assertTrue(itemRawArea.text() == '4.0')
- self.assertTrue(itemNetArea.text() == '1.0')
+ self.assertTrue(itemRawArea.text() == "4.0")
+ self.assertTrue(itemNetArea.text() == "1.0")
roi.setTo(2)
- itemRawArea = roiTable.item(0, indexesColumns['Raw Area'])
- self.assertTrue(itemRawArea.text() == '3.0')
+ itemRawArea = roiTable.item(0, indexesColumns["Raw Area"])
+ self.assertTrue(itemRawArea.text() == "3.0")
roi.setFrom(1)
- itemRawArea = roiTable.item(0, indexesColumns['Raw Area'])
- self.assertTrue(itemRawArea.text() == '2.0')
+ itemRawArea = roiTable.item(0, indexesColumns["Raw Area"])
+ self.assertTrue(itemRawArea.text() == "2.0")
def testRemoveActiveROI(self):
"""Test widget behavior when removing the active ROI"""
- roi = CurvesROIWidget.ROI(name='linear', fromdata=0, todata=5)
+ roi = CurvesROIWidget.ROI(name="linear", fromdata=0, todata=5)
self.widget.roiWidget.setRois((roi,))
self.widget.roiWidget.roiTable.setActiveRoi(None)
self.assertEqual(len(self.widget.roiWidget.roiTable.selectedItems()), 0)
self.widget.roiWidget.setRois((roi,))
- self.plot.setActiveCurve(legend='linearCurve')
+ self.plot.setActiveCurve(legend="linearCurve")
self.widget.calculateROIs()
def testEmitCurrentROI(self):
"""Test behavior of the CurvesROIWidget.sigROISignal"""
- roi = CurvesROIWidget.ROI(name='linear', fromdata=0, todata=5)
+ roi = CurvesROIWidget.ROI(name="linear", fromdata=0, todata=5)
self.widget.roiWidget.setRois((roi,))
signalListener = SignalListener()
self.widget.roiWidget.sigROISignal.connect(signalListener.partial())
@@ -352,7 +362,7 @@ class TestRoiWidgetSignals(TestCaseQt):
self.plot = Plot1D()
x = range(20)
y = range(20)
- self.plot.addCurve(x, y, legend='curve0')
+ self.plot.addCurve(x, y, legend="curve0")
self.listener = SignalListener()
self.curves_roi_widget = self.plot.getCurvesRoiWidget()
self.curves_roi_widget.sigROISignal.connect(self.listener)
@@ -383,33 +393,33 @@ class TestRoiWidgetSignals(TestCaseQt):
"""Test SigROISignal when adding and removing ROIS"""
self.listener.clear()
- roi1 = CurvesROIWidget.ROI(name='linear', fromdata=0, todata=5)
+ roi1 = CurvesROIWidget.ROI(name="linear", fromdata=0, todata=5)
self.curves_roi_widget.roiTable.addRoi(roi1)
self.assertEqual(self.listener.callCount(), 1)
- self.assertTrue(self.listener.arguments()[0][0]['current'] == 'linear')
+ self.assertTrue(self.listener.arguments()[0][0]["current"] == "linear")
self.listener.clear()
- roi2 = CurvesROIWidget.ROI(name='linear2', fromdata=0, todata=5)
+ roi2 = CurvesROIWidget.ROI(name="linear2", fromdata=0, todata=5)
self.curves_roi_widget.roiTable.addRoi(roi2)
self.assertEqual(self.listener.callCount(), 1)
- self.assertTrue(self.listener.arguments()[0][0]['current'] == 'linear2')
+ self.assertTrue(self.listener.arguments()[0][0]["current"] == "linear2")
self.listener.clear()
self.curves_roi_widget.roiTable.removeROI(roi2)
self.assertEqual(self.listener.callCount(), 1)
self.assertTrue(self.curves_roi_widget.roiTable.activeRoi == roi1)
- self.assertTrue(self.listener.arguments()[0][0]['current'] == 'linear')
+ self.assertTrue(self.listener.arguments()[0][0]["current"] == "linear")
self.listener.clear()
self.curves_roi_widget.roiTable.deleteActiveRoi()
self.assertEqual(self.listener.callCount(), 1)
self.assertTrue(self.curves_roi_widget.roiTable.activeRoi is None)
- self.assertTrue(self.listener.arguments()[0][0]['current'] is None)
+ self.assertTrue(self.listener.arguments()[0][0]["current"] is None)
self.listener.clear()
self.curves_roi_widget.roiTable.addRoi(roi1)
self.assertEqual(self.listener.callCount(), 1)
- self.assertTrue(self.listener.arguments()[0][0]['current'] == 'linear')
+ self.assertTrue(self.listener.arguments()[0][0]["current"] == "linear")
self.assertTrue(self.curves_roi_widget.roiTable.activeRoi == roi1)
self.listener.clear()
self.qapp.processEvents()
@@ -417,13 +427,13 @@ class TestRoiWidgetSignals(TestCaseQt):
self.curves_roi_widget.roiTable.removeROI(roi1)
self.qapp.processEvents()
self.assertEqual(self.listener.callCount(), 1)
- self.assertTrue(self.listener.arguments()[0][0]['current'] == 'ICR')
+ self.assertTrue(self.listener.arguments()[0][0]["current"] == "ICR")
self.listener.clear()
def testSigROISignalModifyROI(self):
"""Test SigROISignal when modifying it"""
self.curves_roi_widget.roiTable.setMiddleROIMarkerFlag(True)
- roi1 = CurvesROIWidget.ROI(name='linear', fromdata=2, todata=5)
+ roi1 = CurvesROIWidget.ROI(name="linear", fromdata=2, todata=5)
self.curves_roi_widget.roiTable.addRoi(roi1)
self.curves_roi_widget.roiTable.setActiveRoi(roi1)
@@ -435,10 +445,10 @@ class TestRoiWidgetSignals(TestCaseQt):
roi1.setTo(2.56)
self.assertEqual(self.listener.callCount(), 1)
self.listener.clear()
- roi1.setName('linear2')
+ roi1.setName("linear2")
self.assertEqual(self.listener.callCount(), 1)
self.listener.clear()
- roi1.setType('new type')
+ roi1.setType("new type")
self.assertEqual(self.listener.callCount(), 1)
widget = self.plot.getWidgetHandle()
@@ -447,18 +457,24 @@ class TestRoiWidgetSignals(TestCaseQt):
self.qapp.processEvents()
# modify roi limits (from the gui)
- roi_marker_handler = self.curves_roi_widget.roiTable._markersHandler.getMarkerHandler(roi1.getID())
- for marker_type in ('min', 'max', 'middle'):
+ roi_marker_handler = (
+ self.curves_roi_widget.roiTable._markersHandler.getMarkerHandler(
+ roi1.getID()
+ )
+ )
+ for marker_type in ("min", "max", "middle"):
with self.subTest(marker_type=marker_type):
self.listener.clear()
marker = roi_marker_handler.getMarker(marker_type)
- x_pix, y_pix = self.plot.dataToPixel(marker.getXPosition(), marker.getYPosition())
+ x_pix, y_pix = self.plot.dataToPixel(
+ marker.getXPosition(), marker.getYPosition()
+ )
self.mouseMove(widget, pos=(x_pix, y_pix))
self.qWait(100)
self.mousePress(widget, qt.Qt.LeftButton, pos=(x_pix, y_pix))
- self.mouseMove(widget, pos=(x_pix+20, y_pix))
+ self.mouseMove(widget, pos=(x_pix + 20, y_pix))
self.qWait(100)
- self.mouseRelease(widget, qt.Qt.LeftButton, pos=(x_pix+20, y_pix))
+ self.mouseRelease(widget, qt.Qt.LeftButton, pos=(x_pix + 20, y_pix))
self.qWait(100)
self.mouseMove(widget, pos=(x_pix, y_pix))
self.qapp.processEvents()
@@ -466,8 +482,8 @@ class TestRoiWidgetSignals(TestCaseQt):
def testSetActiveCurve(self):
"""Test sigRoiSignal when set an active curve"""
- roi1 = CurvesROIWidget.ROI(name='linear', fromdata=2, todata=5)
+ roi1 = CurvesROIWidget.ROI(name="linear", fromdata=2, todata=5)
self.curves_roi_widget.roiTable.setActiveRoi(roi1)
self.listener.clear()
- self.plot.setActiveCurve('curve0')
+ self.plot.setActiveCurve("curve0")
self.assertEqual(self.listener.callCount(), 0)
diff --git a/src/silx/gui/plot/test/testImageStack.py b/src/silx/gui/plot/test/testImageStack.py
index 702f0fe..482cdfd 100644
--- a/src/silx/gui/plot/test/testImageStack.py
+++ b/src/silx/gui/plot/test/testImageStack.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2020 European Synchrotron Radiation Facility
+# Copyright (c) 2020-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -28,7 +28,6 @@ __license__ = "MIT"
__date__ = "15/01/2020"
-import unittest
import tempfile
import numpy
import h5py
@@ -38,7 +37,6 @@ from silx.gui.utils.testutils import TestCaseQt
from silx.io.url import DataUrl
from silx.gui.plot.ImageStack import ImageStack
from silx.gui.utils.testutils import SignalListener
-from collections import OrderedDict
import os
import time
import shutil
@@ -49,21 +47,21 @@ class TestImageStack(TestCaseQt):
def setUp(self):
TestCaseQt.setUp(self)
- self.urls = OrderedDict()
+ self.urls = {}
self._raw_data = {}
self._folder = tempfile.mkdtemp()
self._n_urls = 10
- file_name = os.path.join(self._folder, 'test_inage_stack_file.h5')
- with h5py.File(file_name, 'w') as h5f:
+ file_name = os.path.join(self._folder, "test_inage_stack_file.h5")
+ with h5py.File(file_name, "w") as h5f:
for i in range(self._n_urls):
width = numpy.random.randint(10, 40)
height = numpy.random.randint(10, 40)
raw_data = numpy.random.random((width, height))
self._raw_data[i] = raw_data
h5f[str(i)] = raw_data
- self.urls[i] = DataUrl(file_path=file_name,
- data_path=str(i),
- scheme='silx')
+ self.urls[i] = DataUrl(
+ file_path=file_name, data_path=str(i), scheme="silx"
+ )
self.widget = ImageStack()
self.urlLoadedListener = SignalListener()
@@ -79,8 +77,7 @@ class TestImageStack(TestCaseQt):
TestCaseQt.setUp(self)
def testControls(self):
- """Test that selection using the url table and the slider are working
- """
+ """Test that selection using the url table and the slider are working"""
self.widget.show()
self.assertEqual(self.widget.getCurrentUrl(), None)
self.assertEqual(self.widget.getCurrentUrlIndex(), None)
@@ -95,13 +92,15 @@ class TestImageStack(TestCaseQt):
self.assertEqual(self.urlLoadedListener.callCount(), self._n_urls)
numpy.testing.assert_array_equal(
self.widget.getPlotWidget().getActiveImage(just_legend=False).getData(),
- self._raw_data[0])
+ self._raw_data[0],
+ )
self.assertEqual(self.widget._slider.value(), 0)
self.widget._urlsTable.setUrl(self.urls[4])
numpy.testing.assert_array_equal(
self.widget.getPlotWidget().getActiveImage(just_legend=False).getData(),
- self._raw_data[4])
+ self._raw_data[4],
+ )
self.assertEqual(self.widget._slider.value(), 4)
self.assertEqual(self.widget.getCurrentUrl(), self.urls[4])
self.assertEqual(self.widget.getCurrentUrlIndex(), 4)
@@ -109,9 +108,11 @@ class TestImageStack(TestCaseQt):
self.widget._slider.setUrlIndex(6)
numpy.testing.assert_array_equal(
self.widget.getPlotWidget().getActiveImage(just_legend=False).getData(),
- self._raw_data[6])
- self.assertEqual(self.widget._urlsTable.currentItem().text(),
- self.urls[6].path())
+ self._raw_data[6],
+ )
+ self.assertEqual(
+ self.widget._urlsTable.currentItem().text(), self.urls[6].path()
+ )
def testCurrentUrlSignals(self):
"""Test emission of 'currentUrlChangedListener'"""
@@ -151,26 +152,72 @@ class TestImageStack(TestCaseQt):
self.assertEqual(urls_values[0], self.urls[0])
self.assertEqual(urls_values[7], self.urls[7])
- self.assertEqual(self.widget._getNextUrl(urls_values[2]).path(),
- urls_values[3].path())
+ self.assertEqual(
+ self.widget._getNextUrl(urls_values[2]).path(), urls_values[3].path()
+ )
self.assertEqual(self.widget._getPreviousUrl(urls_values[0]), None)
- self.assertEqual(self.widget._getPreviousUrl(urls_values[6]).path(),
- urls_values[5].path())
-
- self.assertEqual(self.widget._getNNextUrls(2, urls_values[0]),
- urls_values[1:3])
- self.assertEqual(self.widget._getNNextUrls(5, urls_values[7]),
- urls_values[8:])
- self.assertEqual(self.widget._getNPreviousUrls(3, urls_values[2]),
- urls_values[:2])
- self.assertEqual(self.widget._getNPreviousUrls(5, urls_values[8]),
- urls_values[3:8])
+ self.assertEqual(
+ self.widget._getPreviousUrl(urls_values[6]).path(), urls_values[5].path()
+ )
+
+ self.assertEqual(self.widget._getNNextUrls(2, urls_values[0]), urls_values[1:3])
+ self.assertEqual(self.widget._getNNextUrls(5, urls_values[7]), urls_values[8:])
+ self.assertEqual(
+ self.widget._getNPreviousUrls(3, urls_values[2]), urls_values[:2]
+ )
+ self.assertEqual(
+ self.widget._getNPreviousUrls(5, urls_values[8]), urls_values[3:8]
+ )
+
+ def testRemoveUrlFromList(self):
+ """
+ Test behavior when some item (url) are removed from the list
+ """
+ self.widget.setUrlsEditable(True)
+ self.widget.show()
+ self.widget.setUrls(list(self.urls.values()))
+ self.assertEqual(len(self.widget.getUrls()), len(self.urls))
+
+ # wait for image to be loaded
+ self._waitUntilUrlLoaded()
+ ll_slider = self.widget._slider._slider
+ assert ll_slider.maximum() - ll_slider.minimum() + 1 == len(self.urls)
+
+ # remove some urls from the list (~ simulating behavior with a right click)
+ urlsTable = self.widget._urlsTable._urlsTable
+ urlsTable.clearSelection()
+ urlsTable.item(1).setSelected(True)
+ urlsTable.item(2).setSelected(True)
+ urlsTable._removeSelectedItems()
+ self.qapp.processEvents()
+
+ # make sure slider has been updated
+ assert ll_slider.maximum() - ll_slider.minimum() + 1 == len(self.urls) - 2
+ # as the ImageStack widget
+ assert len(self.widget._urls) == len(self.urls) - 2
+ removed_urls = list(self.urls.values())[1:3]
+
+ existing_urls_as_str = [url.path() for url in self.widget._urls.values()]
+ for removed_url in removed_urls:
+ assert type(removed_url) == type(tuple(self.widget._urls.values())[0])
+ assert removed_url.path() not in existing_urls_as_str
+ # make sure we have some data plot
+ self.widget.getPlotWidget().getActiveImage() is not None
+
+ # test removing remaining urls
+ urlsTable.selectAll()
+ urlsTable._removeSelectedItems()
+ self.qapp.processEvents()
+ assert len(self.widget._urls) == 0
+ assert ll_slider.maximum() - ll_slider.minimum() == 0
+ # make sure if all urls are removed nothing is plot anymore
+ self.widget.getPlotWidget().getActiveImage() is None
def _waitUntilUrlLoaded(self, timeout=2.0):
"""Wait until all image urls are loaded"""
loop_duration = 0.2
remaining_duration = timeout
- while(len(self.widget._loadingThreads) > 0 and remaining_duration > 0):
+ while len(self.widget._loadingThreads) > 0 and remaining_duration > 0:
remaining_duration -= loop_duration
time.sleep(loop_duration)
self.qapp.processEvents()
@@ -179,7 +226,9 @@ class TestImageStack(TestCaseQt):
remaining_urls = []
for thread_ in self.widget._loadingThreads:
remaining_urls.append(thread_.url.path())
- mess = 'All images are not loaded after the time out. ' \
- 'Remaining urls are: ' + str(remaining_urls)
+ mess = (
+ "All images are not loaded after the time out. "
+ "Remaining urls are: " + str(remaining_urls)
+ )
raise TimeoutError(mess)
return True
diff --git a/src/silx/gui/plot/test/testImageView.py b/src/silx/gui/plot/test/testImageView.py
index 9fb6a5d..df19ab7 100644
--- a/src/silx/gui/plot/test/testImageView.py
+++ b/src/silx/gui/plot/test/testImageView.py
@@ -92,31 +92,33 @@ class TestImageView(TestCaseQt):
self.plot.setImage(image)
# Colormap as dict
- self.plot.setColormap({'name': 'viridis',
- 'normalization': 'log',
- 'autoscale': False,
- 'vmin': 0,
- 'vmax': 1})
+ self.plot.setColormap(
+ {
+ "name": "viridis",
+ "normalization": "log",
+ "autoscale": False,
+ "vmin": 0,
+ "vmax": 1,
+ }
+ )
colormap = self.plot.getColormap()
- self.assertEqual(colormap.getName(), 'viridis')
- self.assertEqual(colormap.getNormalization(), 'log')
+ self.assertEqual(colormap.getName(), "viridis")
+ self.assertEqual(colormap.getNormalization(), "log")
self.assertEqual(colormap.getVMin(), 0)
self.assertEqual(colormap.getVMax(), 1)
# Colormap as keyword arguments
- self.plot.setColormap(colormap='magma',
- normalization='linear',
- autoscale=True,
- vmin=1,
- vmax=2)
- self.assertEqual(colormap.getName(), 'magma')
- self.assertEqual(colormap.getNormalization(), 'linear')
+ self.plot.setColormap(
+ colormap="magma", normalization="linear", autoscale=True, vmin=1, vmax=2
+ )
+ self.assertEqual(colormap.getName(), "magma")
+ self.assertEqual(colormap.getNormalization(), "linear")
self.assertEqual(colormap.getVMin(), None)
self.assertEqual(colormap.getVMax(), None)
# Update colormap with keyword argument
- self.plot.setColormap(normalization='log')
- self.assertEqual(colormap.getNormalization(), 'log')
+ self.plot.setColormap(normalization="log")
+ self.assertEqual(colormap.getNormalization(), "log")
# Colormap as Colormap object
cmap = Colormap()
@@ -130,7 +132,7 @@ class TestImageView(TestCaseQt):
ImageView.ProfileWindowBehavior.POPUP,
)
- self.plot.setProfileWindowBehavior('embedded')
+ self.plot.setProfileWindowBehavior("embedded")
self.assertIs(
self.plot.getProfileWindowBehavior(),
ImageView.ProfileWindowBehavior.EMBEDDED,
@@ -139,9 +141,7 @@ class TestImageView(TestCaseQt):
image = numpy.arange(100).reshape(10, 10)
self.plot.setImage(image)
- self.plot.setProfileWindowBehavior(
- ImageView.ProfileWindowBehavior.POPUP
- )
+ self.plot.setProfileWindowBehavior(ImageView.ProfileWindowBehavior.POPUP)
self.assertIs(
self.plot.getProfileWindowBehavior(),
ImageView.ProfileWindowBehavior.POPUP,
@@ -170,7 +170,9 @@ class TestImageView(TestCaseQt):
image = numpy.arange(100).reshape(10, 10)
self.plot.setImage(image, reset=True)
self.qWait(100)
- self.plot.getAggregationModeAction().setAggregationMode(items.ImageDataAggregated.Aggregation.MAX)
+ self.plot.getAggregationModeAction().setAggregationMode(
+ items.ImageDataAggregated.Aggregation.MAX
+ )
self.qWait(100)
def testImageAggregationModeBackToNormalMode(self):
@@ -178,9 +180,13 @@ class TestImageView(TestCaseQt):
image = numpy.arange(100).reshape(10, 10)
self.plot.setImage(image, reset=True)
self.qWait(100)
- self.plot.getAggregationModeAction().setAggregationMode(items.ImageDataAggregated.Aggregation.MAX)
+ self.plot.getAggregationModeAction().setAggregationMode(
+ items.ImageDataAggregated.Aggregation.MAX
+ )
self.qWait(100)
- self.plot.getAggregationModeAction().setAggregationMode(items.ImageDataAggregated.Aggregation.NONE)
+ self.plot.getAggregationModeAction().setAggregationMode(
+ items.ImageDataAggregated.Aggregation.NONE
+ )
self.qWait(100)
def testRGBAInAggregationMode(self):
@@ -189,5 +195,7 @@ class TestImageView(TestCaseQt):
self.plot.setImage(image, reset=True)
self.qWait(100)
- self.plot.getAggregationModeAction().setAggregationMode(items.ImageDataAggregated.Aggregation.MAX)
+ self.plot.getAggregationModeAction().setAggregationMode(
+ items.ImageDataAggregated.Aggregation.MAX
+ )
self.qWait(100)
diff --git a/src/silx/gui/plot/test/testInteraction.py b/src/silx/gui/plot/test/testInteraction.py
index 459b132..b031454 100644
--- a/src/silx/gui/plot/test/testInteraction.py
+++ b/src/silx/gui/plot/test/testInteraction.py
@@ -40,38 +40,40 @@ class TestInteraction(unittest.TestCase):
class TestClickOrDrag(Interaction.ClickOrDrag):
def click(self, x, y, btn):
- events.append(('click', x, y, btn))
+ events.append(("click", x, y, btn))
def beginDrag(self, x, y, btn):
- events.append(('beginDrag', x, y, btn))
+ events.append(("beginDrag", x, y, btn))
def drag(self, x, y, btn):
- events.append(('drag', x, y, btn))
+ events.append(("drag", x, y, btn))
def endDrag(self, start, end, btn):
- events.append(('endDrag', start, end, btn))
+ events.append(("endDrag", start, end, btn))
clickOrDrag = TestClickOrDrag()
# click
- clickOrDrag.handleEvent('press', 10, 10, Interaction.LEFT_BTN)
+ clickOrDrag.handleEvent("press", 10, 10, Interaction.LEFT_BTN)
self.assertEqual(len(events), 0)
- clickOrDrag.handleEvent('release', 10, 10, Interaction.LEFT_BTN)
+ clickOrDrag.handleEvent("release", 10, 10, Interaction.LEFT_BTN)
self.assertEqual(len(events), 1)
- self.assertEqual(events[0], ('click', 10, 10, Interaction.LEFT_BTN))
+ self.assertEqual(events[0], ("click", 10, 10, Interaction.LEFT_BTN))
# drag
events = []
- clickOrDrag.handleEvent('press', 10, 10, Interaction.LEFT_BTN)
+ clickOrDrag.handleEvent("press", 10, 10, Interaction.LEFT_BTN)
self.assertEqual(len(events), 0)
- clickOrDrag.handleEvent('move', 15, 10)
+ clickOrDrag.handleEvent("move", 15, 10)
self.assertEqual(len(events), 2) # Received beginDrag and drag
- self.assertEqual(events[0], ('beginDrag', 10, 10, Interaction.LEFT_BTN))
- self.assertEqual(events[1], ('drag', 15, 10, Interaction.LEFT_BTN))
- clickOrDrag.handleEvent('move', 20, 10)
+ self.assertEqual(events[0], ("beginDrag", 10, 10, Interaction.LEFT_BTN))
+ self.assertEqual(events[1], ("drag", 15, 10, Interaction.LEFT_BTN))
+ clickOrDrag.handleEvent("move", 20, 10)
self.assertEqual(len(events), 3)
- self.assertEqual(events[-1], ('drag', 20, 10, Interaction.LEFT_BTN))
- clickOrDrag.handleEvent('release', 20, 10, Interaction.LEFT_BTN)
+ self.assertEqual(events[-1], ("drag", 20, 10, Interaction.LEFT_BTN))
+ clickOrDrag.handleEvent("release", 20, 10, Interaction.LEFT_BTN)
self.assertEqual(len(events), 4)
- self.assertEqual(events[-1], ('endDrag', (10, 10), (20, 10), Interaction.LEFT_BTN))
+ self.assertEqual(
+ events[-1], ("endDrag", (10, 10), (20, 10), Interaction.LEFT_BTN)
+ )
diff --git a/src/silx/gui/plot/test/testItem.py b/src/silx/gui/plot/test/testItem.py
index 7b4f636..8a6db40 100644
--- a/src/silx/gui/plot/test/testItem.py
+++ b/src/silx/gui/plot/test/testItem.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -28,11 +28,11 @@ __license__ = "MIT"
__date__ = "01/09/2017"
-import unittest
-
import numpy
+import pytest
from silx.gui.utils.testutils import SignalListener
+from silx.gui.plot.items.roi import RegionOfInterest
from silx.gui.plot.items import ItemChangedType
from silx.gui.plot import items
from .utils import PlotWidgetTestCase
@@ -43,8 +43,8 @@ class TestSigItemChangedSignal(PlotWidgetTestCase):
def testCurveChanged(self):
"""Test sigItemChanged for curve"""
- self.plot.addCurve(numpy.arange(10), numpy.arange(10), legend='test')
- curve = self.plot.getCurve('test')
+ self.plot.addCurve(numpy.arange(10), numpy.arange(10), legend="test")
+ curve = self.plot.getCurve("test")
listener = SignalListener()
curve.sigItemChanged.connect(listener)
@@ -58,8 +58,8 @@ class TestSigItemChangedSignal(PlotWidgetTestCase):
curve.setData(numpy.arange(100), numpy.arange(100))
# SymbolMixIn
- curve.setSymbol('Circle')
- curve.setSymbol('d')
+ curve.setSymbol("Circle")
+ curve.setSymbol("d")
curve.setSymbolSize(20)
# AlphaMixIn
@@ -67,49 +67,51 @@ class TestSigItemChangedSignal(PlotWidgetTestCase):
# Test for signals in Curve class
# ColorMixIn
- curve.setColor('yellow')
+ curve.setColor("yellow")
# YAxisMixIn
- curve.setYAxis('right')
+ curve.setYAxis("right")
# FillMixIn
curve.setFill(True)
# LineMixIn
- curve.setLineStyle(':')
- curve.setLineStyle(':') # Not sending event
+ curve.setLineStyle(":")
+ curve.setLineStyle(":") # Not sending event
curve.setLineWidth(2)
- self.assertEqual(listener.arguments(argumentIndex=0),
- [ItemChangedType.VISIBLE,
- ItemChangedType.VISIBLE,
- ItemChangedType.ZVALUE,
- ItemChangedType.DATA,
- ItemChangedType.SYMBOL,
- ItemChangedType.SYMBOL,
- ItemChangedType.SYMBOL_SIZE,
- ItemChangedType.ALPHA,
- ItemChangedType.COLOR,
- ItemChangedType.YAXIS,
- ItemChangedType.FILL,
- ItemChangedType.LINE_STYLE,
- ItemChangedType.LINE_WIDTH])
+ self.assertEqual(
+ listener.arguments(argumentIndex=0),
+ [
+ ItemChangedType.VISIBLE,
+ ItemChangedType.VISIBLE,
+ ItemChangedType.ZVALUE,
+ ItemChangedType.DATA,
+ ItemChangedType.SYMBOL,
+ ItemChangedType.SYMBOL,
+ ItemChangedType.SYMBOL_SIZE,
+ ItemChangedType.ALPHA,
+ ItemChangedType.COLOR,
+ ItemChangedType.YAXIS,
+ ItemChangedType.FILL,
+ ItemChangedType.LINE_STYLE,
+ ItemChangedType.LINE_WIDTH,
+ ],
+ )
def testHistogramChanged(self):
"""Test sigItemChanged for Histogram"""
- self.plot.addHistogram(
- numpy.arange(10), edges=numpy.arange(11), legend='test')
- histogram = self.plot.getHistogram('test')
+ self.plot.addHistogram(numpy.arange(10), edges=numpy.arange(11), legend="test")
+ histogram = self.plot.getHistogram("test")
listener = SignalListener()
histogram.sigItemChanged.connect(listener)
# Test signals in Histogram class
histogram.setData(numpy.zeros(10), numpy.arange(11))
- self.assertEqual(listener.arguments(argumentIndex=0),
- [ItemChangedType.DATA])
+ self.assertEqual(listener.arguments(argumentIndex=0), [ItemChangedType.DATA])
def testImageDataChanged(self):
"""Test sigItemChanged for ImageData"""
- self.plot.addImage(numpy.arange(100).reshape(10, 10), legend='test')
- image = self.plot.getImage('test')
+ self.plot.addImage(numpy.arange(100).reshape(10, 10), legend="test")
+ image = self.plot.getImage("test")
listener = SignalListener()
image.sigItemChanged.connect(listener)
@@ -117,7 +119,7 @@ class TestSigItemChangedSignal(PlotWidgetTestCase):
# ColormapMixIn
colormap = self.plot.getDefaultColormap().copy()
image.setColormap(colormap)
- image.getColormap().setName('viridis')
+ image.getColormap().setName("viridis")
# Test of signals in ImageBase class
image.setOrigin(10)
@@ -126,18 +128,22 @@ class TestSigItemChangedSignal(PlotWidgetTestCase):
# Test of signals in ImageData class
image.setData(numpy.ones((10, 10)))
- self.assertEqual(listener.arguments(argumentIndex=0),
- [ItemChangedType.COLORMAP,
- ItemChangedType.COLORMAP,
- ItemChangedType.POSITION,
- ItemChangedType.SCALE,
- ItemChangedType.COLORMAP,
- ItemChangedType.DATA])
+ self.assertEqual(
+ listener.arguments(argumentIndex=0),
+ [
+ ItemChangedType.COLORMAP,
+ ItemChangedType.COLORMAP,
+ ItemChangedType.POSITION,
+ ItemChangedType.SCALE,
+ ItemChangedType.COLORMAP,
+ ItemChangedType.DATA,
+ ],
+ )
def testImageRgbaChanged(self):
"""Test sigItemChanged for ImageRgba"""
- self.plot.addImage(numpy.ones((10, 10, 3)), legend='rgb')
- image = self.plot.getImage('rgb')
+ self.plot.addImage(numpy.ones((10, 10, 3)), legend="rgb")
+ image = self.plot.getImage("rgb")
listener = SignalListener()
image.sigItemChanged.connect(listener)
@@ -145,13 +151,12 @@ class TestSigItemChangedSignal(PlotWidgetTestCase):
# Test of signals in ImageRgba class
image.setData(numpy.zeros((10, 10, 3)))
- self.assertEqual(listener.arguments(argumentIndex=0),
- [ItemChangedType.DATA])
+ self.assertEqual(listener.arguments(argumentIndex=0), [ItemChangedType.DATA])
def testMarkerChanged(self):
"""Test sigItemChanged for markers"""
- self.plot.addMarker(10, 20, legend='test')
- marker = self.plot._getMarker('test')
+ self.plot.addMarker(10, 20, legend="test")
+ marker = self.plot._getMarker("test")
listener = SignalListener()
marker.sigItemChanged.connect(listener)
@@ -159,42 +164,45 @@ class TestSigItemChangedSignal(PlotWidgetTestCase):
# Test signals in _BaseMarker
marker.setPosition(10, 10)
marker.setPosition(10, 10) # Not sending event
- marker.setText('toto')
- self.assertEqual(listener.arguments(argumentIndex=0),
- [ItemChangedType.POSITION,
- ItemChangedType.TEXT])
+ marker.setText("toto")
+ self.assertEqual(
+ listener.arguments(argumentIndex=0),
+ [ItemChangedType.POSITION, ItemChangedType.TEXT],
+ )
# XMarker
- self.plot.addXMarker(10, legend='x')
- marker = self.plot._getMarker('x')
+ self.plot.addXMarker(10, legend="x")
+ marker = self.plot._getMarker("x")
listener = SignalListener()
marker.sigItemChanged.connect(listener)
marker.setPosition(20, 20)
- self.assertEqual(listener.arguments(argumentIndex=0),
- [ItemChangedType.POSITION])
+ self.assertEqual(
+ listener.arguments(argumentIndex=0), [ItemChangedType.POSITION]
+ )
# YMarker
- self.plot.addYMarker(10, legend='x')
- marker = self.plot._getMarker('x')
+ self.plot.addYMarker(10, legend="x")
+ marker = self.plot._getMarker("x")
listener = SignalListener()
marker.sigItemChanged.connect(listener)
marker.setPosition(20, 20)
- self.assertEqual(listener.arguments(argumentIndex=0),
- [ItemChangedType.POSITION])
+ self.assertEqual(
+ listener.arguments(argumentIndex=0), [ItemChangedType.POSITION]
+ )
def testScatterChanged(self):
"""Test sigItemChanged for scatter"""
data = numpy.arange(10)
- self.plot.addScatter(data, data, data, legend='test')
- scatter = self.plot.getScatter('test')
+ self.plot.addScatter(data, data, data, legend="test")
+ scatter = self.plot.getScatter("test")
listener = SignalListener()
scatter.sigItemChanged.connect(listener)
# ColormapMixIn
- scatter.getColormap().setName('viridis')
+ scatter.getColormap().setName("viridis")
# Test of signals in Scatter class
scatter.setData((0, 1, 2), (1, 0, 2), (0, 1, 2))
@@ -202,44 +210,48 @@ class TestSigItemChangedSignal(PlotWidgetTestCase):
# Visualization mode changed
scatter.setVisualization(scatter.Visualization.SOLID)
- self.assertEqual(listener.arguments(),
- [(ItemChangedType.COLORMAP,),
- (ItemChangedType.DATA,),
- (ItemChangedType.COLORMAP,),
- (ItemChangedType.VISUALIZATION_MODE,)])
+ self.assertEqual(
+ listener.arguments(),
+ [
+ (ItemChangedType.COLORMAP,),
+ (ItemChangedType.DATA,),
+ (ItemChangedType.COLORMAP,),
+ (ItemChangedType.VISUALIZATION_MODE,),
+ ],
+ )
def testShapeChanged(self):
"""Test sigItemChanged for shape"""
- data = numpy.array((1., 10.))
- self.plot.addShape(data, data, legend='test', shape='rectangle')
- shape = self.plot._getItem(kind='item', legend='test')
+ data = numpy.array((1.0, 10.0))
+ self.plot.addShape(data, data, legend="test", shape="rectangle")
+ shape = self.plot._getItem(kind="item", legend="test")
listener = SignalListener()
shape.sigItemChanged.connect(listener)
shape.setOverlay(True)
- shape.setPoints(((2., 2.), (3., 3.)))
+ shape.setPoints(((2.0, 2.0), (3.0, 3.0)))
- self.assertEqual(listener.arguments(),
- [(ItemChangedType.OVERLAY,),
- (ItemChangedType.DATA,)])
+ self.assertEqual(
+ listener.arguments(), [(ItemChangedType.OVERLAY,), (ItemChangedType.DATA,)]
+ )
class TestSymbol(PlotWidgetTestCase):
- """Test item's symbol """
+ """Test item's symbol"""
def test(self):
"""Test sigItemChanged for curve"""
- self.plot.addCurve(numpy.arange(10), numpy.arange(10), legend='test')
- curve = self.plot.getCurve('test')
+ self.plot.addCurve(numpy.arange(10), numpy.arange(10), legend="test")
+ curve = self.plot.getCurve("test")
# SymbolMixIn
- curve.setSymbol('o')
+ curve.setSymbol("o")
name = curve.getSymbolName()
- self.assertEqual('Circle', name)
+ self.assertEqual("Circle", name)
- name = curve.getSymbolName('d')
- self.assertEqual('Diamond', name)
+ name = curve.getSymbolName("d")
+ self.assertEqual("Diamond", name)
class TestVisibleExtent(PlotWidgetTestCase):
@@ -253,7 +265,7 @@ class TestVisibleExtent(PlotWidgetTestCase):
curve.setData((1, 2, 3), (0, 1, 2))
histogram = items.Histogram()
- histogram.setData((0, 1, 2), (1, 5/3, 7/3, 3))
+ histogram.setData((0, 1, 2), (1, 5 / 3, 7 / 3, 3))
image = items.ImageData()
image.setOrigin((1, 0))
@@ -271,10 +283,10 @@ class TestVisibleExtent(PlotWidgetTestCase):
xaxis.setLimits(0, 100)
yaxis.setLimits(0, 100)
self.plot.addItem(item)
- self.assertEqual(item.getVisibleBounds(), (1., 3., 0., 2.))
+ self.assertEqual(item.getVisibleBounds(), (1.0, 3.0, 0.0, 2.0))
xaxis.setLimits(0.5, 2.5)
- self.assertEqual(item.getVisibleBounds(), (1, 2.5, 0., 2.))
+ self.assertEqual(item.getVisibleBounds(), (1, 2.5, 0.0, 2.0))
yaxis.setLimits(0.5, 1.5)
self.assertEqual(item.getVisibleBounds(), (1, 2.5, 0.5, 1.5))
@@ -349,11 +361,205 @@ class TestImageDataAggregated(PlotWidgetTestCase):
# Zoom-out
for i in range(4):
xmin, xmax = self.plot.getXAxis().getLimits()
- ymin, ymax = self.plot.getYAxis().getLimits()
+ ymin, ymax = self.plot.getYAxis().getLimits()
self.plot.setLimits(
- xmin - (xmax - xmin)/2,
- xmax + (xmax - xmin)/2,
- ymin - (ymax - ymin)/2,
- ymax + (ymax - ymin)/2,
+ xmin - (xmax - xmin) / 2,
+ xmax + (xmax - xmin) / 2,
+ ymin - (ymax - ymin) / 2,
+ ymax + (ymax - ymin) / 2,
)
self.qapp.processEvents()
+
+
+def testRegionOfInterestText():
+ roi = RegionOfInterest()
+
+ listener = SignalListener()
+ roi.sigItemChanged.connect(listener)
+
+ assert roi.getName() == roi.getText()
+
+ roi.setText("some text")
+ assert listener.arguments(argumentIndex=0) == [ItemChangedType.TEXT]
+ listener.clear()
+ assert roi.getText() == "some text"
+
+ roi.setName("new_name")
+ assert listener.arguments(argumentIndex=0) == [ItemChangedType.NAME]
+ listener.clear()
+ assert roi.getText() == "some text"
+
+ roi.setText(None)
+ assert listener.arguments(argumentIndex=0) == [ItemChangedType.TEXT]
+ listener.clear()
+ assert roi.getText() == "new_name"
+
+ roi.setName("even_newer_name")
+ assert listener.arguments(argumentIndex=0) == [
+ ItemChangedType.NAME,
+ ItemChangedType.TEXT,
+ ]
+ assert roi.getText() == "even_newer_name"
+
+
+def testPlotAddItemsWithoutLegend(plotWidget):
+ curve1 = items.Curve()
+ curve1.setData([0, 10], [0, 20])
+ plotWidget.addItem(curve1)
+
+ curve2 = items.Curve()
+ curve2.setData([0, -10], [0, -20])
+ plotWidget.addItem(curve2)
+
+ assert plotWidget.getItems() == (curve1, curve2)
+
+ datarange = plotWidget.getDataRange()
+ assert datarange.x == (-10, 10)
+ assert datarange.y == (-20, 20)
+
+ plotWidget.resetZoom()
+ assert plotWidget.getXAxis().getLimits() == (-10, 10)
+ assert plotWidget.getYAxis().getLimits() == (-20, 20)
+
+
+def testPlotWidgetAddCurve(plotWidget):
+ curve = plotWidget.addCurve(x=(0, 1), y=(1, 0), legend="test", symbol="s")
+ assert isinstance(curve, items.Curve)
+ assert numpy.array_equal(curve.getXData(copy=False), (0, 1))
+ assert numpy.array_equal(curve.getYData(copy=False), (1, 0))
+ assert curve.getName() == "test"
+ assert curve.getSymbol() == "s"
+
+ curveUpdated = plotWidget.addCurve(
+ x=(0, 1, 2), y=(1, 0, 1), legend="test", symbol="o"
+ )
+ assert curveUpdated is curve
+ assert numpy.array_equal(curveUpdated.getXData(copy=False), (0, 1, 2))
+ assert numpy.array_equal(curveUpdated.getYData(copy=False), (1, 0, 1))
+ assert curveUpdated.getName() == "test"
+ assert curveUpdated.getSymbol() == "o"
+
+
+def testPlotWidgetAddImage(plotWidget):
+ image = plotWidget.addImage(((0, 1), (2, 3)), legend="test")
+ assert isinstance(image, items.ImageData)
+ assert numpy.array_equal(image.getData(copy=False), ((0, 1), (2, 3)))
+ assert image.getName() == "test"
+
+ imageUpdated = plotWidget.addImage([(0, 1)], legend="test")
+ assert imageUpdated is image
+ assert numpy.array_equal(image.getData(copy=False), [(0, 1)])
+ assert image.getName() == "test"
+
+ # Update with a 1pixel RGB image
+ imageRgb = plotWidget.addImage([[(0.0, 0.0, 1.0)]], legend="test")
+ assert isinstance(imageRgb, items.ImageRgba)
+ assert numpy.array_equal(imageRgb.getData(copy=False), [[(0.0, 0.0, 1.0)]])
+ assert imageRgb.getName() == "test"
+
+ # Update with a 1pixel RGB image
+ imageRgbUpdated = plotWidget.addImage([[(1.0, 0.0, 0.0)]], legend="test")
+ assert imageRgbUpdated is imageRgb
+ assert numpy.array_equal(imageRgbUpdated.getData(copy=False), [[(1.0, 0.0, 0.0)]])
+ assert imageRgbUpdated.getName() == "test"
+
+
+def testPlotWidgetAddScatter(plotWidget):
+ scatter = plotWidget.addScatter(
+ x=(0, 1), y=(0, 1), value=(0, 1), legend="test", symbol="s"
+ )
+ assert isinstance(scatter, items.Scatter)
+ assert numpy.array_equal(scatter.getXData(copy=False), (0, 1))
+ assert numpy.array_equal(scatter.getYData(copy=False), (0, 1))
+ assert numpy.array_equal(scatter.getValueData(copy=False), (0, 1))
+ assert scatter.getName() == "test"
+ assert scatter.getSymbol() == "s"
+
+
+def testPlotWidgetAddHistogram(plotWidget):
+ histogram = plotWidget.addHistogram(
+ histogram=[1], edges=(0, 1), legend="test", fill=True
+ )
+ assert isinstance(histogram, items.Histogram)
+ assert numpy.array_equal(histogram.getBinEdgesData(copy=False), (0, 1))
+ assert numpy.array_equal(histogram.getValueData(copy=False), [1])
+ assert histogram.getName() == "test"
+ assert histogram.isFill()
+
+
+def testPlotWidgetAddMarker(plotWidget):
+ marker = plotWidget.addMarker(x=0, y=1, legend="test")
+ assert isinstance(marker, items.Marker)
+ assert marker.getPosition() == (0, 1)
+ assert marker.getName() == "test"
+ assert plotWidget.getItems() == (marker,)
+
+ xmarker = plotWidget.addXMarker(1, legend="test")
+ assert isinstance(xmarker, items.XMarker)
+ assert xmarker.getPosition() == (1, None)
+ assert xmarker.getName() == "test"
+ assert plotWidget.getItems() == (xmarker,)
+
+ ymarker = plotWidget.addYMarker(2, legend="test")
+ assert isinstance(ymarker, items.YMarker)
+ assert ymarker.getPosition() == (None, 2)
+ assert ymarker.getName() == "test"
+ assert plotWidget.getItems() == (ymarker,)
+
+
+def testPlotWidgetAddShape(plotWidget):
+ shape = plotWidget.addShape(
+ xdata=(0, 1), ydata=(0, 1), legend="test", shape="polygon"
+ )
+ assert isinstance(shape, items.Shape)
+ assert numpy.array_equal(shape.getPoints(copy=False), ((0, 0), (1, 1)))
+ assert shape.getName() == "test"
+ assert shape.getType() == "polygon"
+
+
+@pytest.mark.parametrize(
+ "linestyle",
+ (
+ "",
+ "-",
+ "--",
+ "-.",
+ ":",
+ (0.0, None),
+ (0.5, ()),
+ (0.0, (5.0, 5.0)),
+ (4.0, (8.0, 4.0, 4.0, 4.0)),
+ ),
+)
+@pytest.mark.parametrize("plotWidget", ("mpl", "gl"), indirect=True)
+def testLineStyle(qapp_utils, plotWidget, linestyle):
+ """Test different line styles for LineMixIn items"""
+ plotWidget.setGraphTitle(f"Line style: {linestyle}")
+
+ curve = plotWidget.addCurve((0, 1), (0, 1), linestyle=linestyle)
+ assert curve.getLineStyle() == linestyle
+
+ histogram = plotWidget.addHistogram((0.25, 0.75, 0.25), (0.0, 0.33, 0.66, 1.0))
+ histogram.setLineStyle(linestyle)
+ assert histogram.getLineStyle() == linestyle
+
+ polylines = plotWidget.addShape(
+ (0, 1), (1, 0), shape="polylines", linestyle=linestyle
+ )
+ assert polylines.getLineStyle() == linestyle
+
+ rectangle = plotWidget.addShape(
+ (0.4, 0.6), (0.4, 0.6), shape="rectangle", linestyle=linestyle
+ )
+ assert rectangle.getLineStyle() == linestyle
+
+ xmarker = plotWidget.addXMarker(0.5)
+ xmarker.setLineStyle(linestyle)
+ assert xmarker.getLineStyle() == linestyle
+
+ ymarker = plotWidget.addYMarker(0.5)
+ ymarker.setLineStyle(linestyle)
+ assert ymarker.getLineStyle() == linestyle
+
+ plotWidget.replot()
+ qapp_utils.qWait(100)
diff --git a/src/silx/gui/plot/test/testLegendSelector.py b/src/silx/gui/plot/test/testLegendSelector.py
index 3a596ac..a1f000a 100644
--- a/src/silx/gui/plot/test/testLegendSelector.py
+++ b/src/silx/gui/plot/test/testLegendSelector.py
@@ -29,7 +29,6 @@ __date__ = "15/05/2017"
import logging
-import unittest
from silx.gui import qt
from silx.gui.utils.testutils import TestCaseQt
@@ -44,6 +43,7 @@ class TestLegendSelector(TestCaseQt):
def testLegendSelector(self):
"""Test copied from __main__ of LegendSelector in PyMca"""
+
class Notifier(qt.QObject):
def __init__(self):
qt.QObject.__init__(self)
@@ -51,22 +51,31 @@ class TestLegendSelector(TestCaseQt):
def signalReceived(self, **kw):
obj = self.sender()
- _logger.info('NOTIFIER -- signal received\n\tsender: %s',
- str(obj))
+ _logger.info("NOTIFIER -- signal received\n\tsender: %s", str(obj))
notifier = Notifier()
- legends = ['Legend0',
- 'Legend1',
- 'Long Legend 2',
- 'Foo Legend 3',
- 'Even Longer Legend 4',
- 'Short Leg 5',
- 'Dot symbol 6',
- 'Comma symbol 7']
- colors = [qt.Qt.darkRed, qt.Qt.green, qt.Qt.yellow, qt.Qt.darkCyan,
- qt.Qt.blue, qt.Qt.darkBlue, qt.Qt.red, qt.Qt.darkYellow]
- symbols = ['o', 't', '+', 'x', 's', 'd', '.', ',']
+ legends = [
+ "Legend0",
+ "Legend1",
+ "Long Legend 2",
+ "Foo Legend 3",
+ "Even Longer Legend 4",
+ "Short Leg 5",
+ "Dot symbol 6",
+ "Comma symbol 7",
+ ]
+ colors = [
+ qt.Qt.darkRed,
+ qt.Qt.green,
+ qt.Qt.yellow,
+ qt.Qt.darkCyan,
+ qt.Qt.blue,
+ qt.Qt.darkBlue,
+ qt.Qt.red,
+ qt.Qt.darkYellow,
+ ]
+ symbols = ["o", "t", "+", "x", "s", "d", ".", ","]
win = LegendSelector.LegendListView()
# win = LegendListContextMenu()
@@ -77,9 +86,9 @@ class TestLegendSelector(TestCaseQt):
for _idx, (l, c, s) in enumerate(zip(legends, colors, symbols)):
ddict = {
- 'color': qt.QColor(c),
- 'linewidth': 4,
- 'symbol': s,
+ "color": qt.QColor(c),
+ "linewidth": 4,
+ "symbol": s,
}
legend = l
llist.append((legend, ddict))
@@ -116,14 +125,15 @@ class TestRenameCurveDialog(TestCaseQt):
def testDialog(self):
"""Create dialog, change name and press OK"""
self.dialog = LegendSelector.RenameCurveDialog(
- None, 'curve1', ['curve1', 'curve2', 'curve3'])
+ None, "curve1", ["curve1", "curve2", "curve3"]
+ )
self.dialog.open()
self.qWaitForWindowExposed(self.dialog)
- self.keyClicks(self.dialog.lineEdit, 'changed')
+ self.keyClicks(self.dialog.lineEdit, "changed")
self.mouseClick(self.dialog.okButton, qt.Qt.LeftButton)
self.qapp.processEvents()
ret = self.dialog.result()
self.assertEqual(ret, qt.QDialog.Accepted)
newName = self.dialog.getText()
- self.assertEqual(newName, 'curve1changed')
+ self.assertEqual(newName, "curve1changed")
del self.dialog
diff --git a/src/silx/gui/plot/test/testMaskToolsWidget.py b/src/silx/gui/plot/test/testMaskToolsWidget.py
index 5f36ec2..1428687 100644
--- a/src/silx/gui/plot/test/testMaskToolsWidget.py
+++ b/src/silx/gui/plot/test/testMaskToolsWidget.py
@@ -30,7 +30,6 @@ __date__ = "17/01/2018"
import logging
import os.path
-import unittest
import numpy
@@ -41,8 +40,6 @@ from silx.gui.utils.testutils import getQToolButtonFromAction
from silx.gui.plot import PlotWindow, MaskToolsWidget
from .utils import PlotWidgetTestCase
-import fabio
-
_logger = logging.getLogger(__name__)
@@ -55,7 +52,7 @@ class TestMaskToolsWidget(PlotWidgetTestCase, ParametricTestCase):
def setUp(self):
super(TestMaskToolsWidget, self).setUp()
- self.widget = MaskToolsWidget.MaskToolsDockWidget(plot=self.plot, name='TEST')
+ self.widget = MaskToolsWidget.MaskToolsDockWidget(plot=self.plot, name="TEST")
self.plot.addDockWidget(qt.Qt.BottomDockWidgetArea, self.widget)
self.maskWidget = self.widget.widget()
@@ -66,10 +63,10 @@ class TestMaskToolsWidget(PlotWidgetTestCase, ParametricTestCase):
def testEmptyPlot(self):
"""Empty plot, display MaskToolsDockWidget, toggle multiple masks"""
- self.maskWidget.setMultipleMasks('single')
+ self.maskWidget.setMultipleMasks("single")
self.qapp.processEvents()
- self.maskWidget.setMultipleMasks('exclusive')
+ self.maskWidget.setMultipleMasks("exclusive")
self.qapp.processEvents()
def _drag(self):
@@ -99,12 +96,14 @@ class TestMaskToolsWidget(PlotWidgetTestCase, ParametricTestCase):
x, y = plot.width() // 2, plot.height() // 2
offset = min(plot.width(), plot.height()) // 10
- star = [(x, y + offset),
- (x - offset, y - offset),
- (x + offset, y),
- (x - offset, y),
- (x + offset, y - offset),
- (x, y + offset)] # Close polygon
+ star = [
+ (x, y + offset),
+ (x - offset, y - offset),
+ (x + offset, y),
+ (x - offset, y),
+ (x + offset, y - offset),
+ (x, y + offset),
+ ] # Close polygon
self.mouseMove(plot, pos=(0, 0))
for pos in star:
@@ -121,28 +120,33 @@ class TestMaskToolsWidget(PlotWidgetTestCase, ParametricTestCase):
x, y = plot.width() // 2, plot.height() // 2
offset = min(plot.width(), plot.height()) // 10
- star = [(x, y + offset),
- (x - offset, y - offset),
- (x + offset, y),
- (x - offset, y),
- (x + offset, y - offset)]
+ star = [
+ (x, y + offset),
+ (x - offset, y - offset),
+ (x + offset, y),
+ (x - offset, y),
+ (x + offset, y - offset),
+ ]
self.mouseMove(plot, pos=(0, 0))
for start, end in zip(star[:-1], star[1:]):
- self.mouseMove(plot, pos=start)
- self.mousePress(plot, qt.Qt.LeftButton, pos=start)
- self.qapp.processEvents()
- self.mouseMove(plot, pos=end)
- self.qapp.processEvents()
- self.mouseRelease(plot, qt.Qt.LeftButton, pos=end)
- self.qapp.processEvents()
+ self.mouseMove(plot, pos=start)
+ self.mousePress(plot, qt.Qt.LeftButton, pos=start)
+ self.qapp.processEvents()
+ self.mouseMove(plot, pos=end)
+ self.qapp.processEvents()
+ self.mouseRelease(plot, qt.Qt.LeftButton, pos=end)
+ self.qapp.processEvents()
def _isMaskItemSync(self):
"""Check if masks from item and tools are sync or not"""
if self.maskWidget.isItemMaskUpdated():
- return numpy.all(numpy.equal(
- self.maskWidget.getSelectionMask(),
- self.plot.getActiveImage().getMaskData(copy=False)))
+ return numpy.all(
+ numpy.equal(
+ self.maskWidget.getSelectionMask(),
+ self.plot.getActiveImage().getMaskData(copy=False),
+ )
+ )
else:
return True
@@ -150,30 +154,36 @@ class TestMaskToolsWidget(PlotWidgetTestCase, ParametricTestCase):
"""Plot with an image: test MaskToolsWidget interactions"""
# Add and remove a image (this should enable/disable GUI + change mask)
- self.plot.addImage(numpy.random.random(1024**2).reshape(1024, 1024),
- legend='test')
+ self.plot.addImage(
+ numpy.random.random(1024**2).reshape(1024, 1024), legend="test"
+ )
self.qapp.processEvents()
- self.plot.remove('test', kind='image')
+ self.plot.remove("test", kind="image")
self.qapp.processEvents()
- tests = [((0, 0), (1, 1)),
- ((1000, 1000), (1, 1)),
- ((0, 0), (-1, -1)),
- ((1000, 1000), (-1, -1))]
+ tests = [
+ ((0, 0), (1, 1)),
+ ((1000, 1000), (1, 1)),
+ ((0, 0), (-1, -1)),
+ ((1000, 1000), (-1, -1)),
+ ]
for itemMaskUpdated in (False, True):
for origin, scale in tests:
with self.subTest(origin=origin, scale=scale):
self.maskWidget.setItemMaskUpdated(itemMaskUpdated)
- self.plot.addImage(numpy.arange(1024**2).reshape(1024, 1024),
- legend='test',
- origin=origin,
- scale=scale)
+ self.plot.addImage(
+ numpy.arange(1024**2).reshape(1024, 1024),
+ legend="test",
+ origin=origin,
+ scale=scale,
+ )
self.qapp.processEvents()
self.assertEqual(
- self.maskWidget.isItemMaskUpdated(), itemMaskUpdated)
+ self.maskWidget.isItemMaskUpdated(), itemMaskUpdated
+ )
# Test draw rectangle #
toolButton = getQToolButtonFromAction(self.maskWidget.rectAction)
@@ -185,7 +195,8 @@ class TestMaskToolsWidget(PlotWidgetTestCase, ParametricTestCase):
self.qapp.processEvents()
self._drag()
self.assertFalse(
- numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0))
+ )
self.assertTrue(self._isMaskItemSync())
# unmask same region
@@ -193,7 +204,8 @@ class TestMaskToolsWidget(PlotWidgetTestCase, ParametricTestCase):
self.qapp.processEvents()
self._drag()
self.assertTrue(
- numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0))
+ )
self.assertTrue(self._isMaskItemSync())
# Test draw polygon #
@@ -206,7 +218,8 @@ class TestMaskToolsWidget(PlotWidgetTestCase, ParametricTestCase):
self.qapp.processEvents()
self._drawPolygon()
self.assertFalse(
- numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0))
+ )
self.assertTrue(self._isMaskItemSync())
# unmask same region
@@ -214,7 +227,8 @@ class TestMaskToolsWidget(PlotWidgetTestCase, ParametricTestCase):
self.qapp.processEvents()
self._drawPolygon()
self.assertTrue(
- numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0))
+ )
self.assertTrue(self._isMaskItemSync())
# Test draw pencil #
@@ -230,7 +244,8 @@ class TestMaskToolsWidget(PlotWidgetTestCase, ParametricTestCase):
self.qapp.processEvents()
self._drawPencil()
self.assertFalse(
- numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0))
+ )
self.assertTrue(self._isMaskItemSync())
# unmask same region
@@ -238,7 +253,8 @@ class TestMaskToolsWidget(PlotWidgetTestCase, ParametricTestCase):
self.qapp.processEvents()
self._drawPencil()
self.assertTrue(
- numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0))
+ )
self.assertTrue(self._isMaskItemSync())
# Test no draw tool #
@@ -250,8 +266,7 @@ class TestMaskToolsWidget(PlotWidgetTestCase, ParametricTestCase):
def __loadSave(self, file_format):
"""Plot with an image: test MaskToolsWidget operations"""
- self.plot.addImage(numpy.arange(1024**2).reshape(1024, 1024),
- legend='test')
+ self.plot.addImage(numpy.arange(1024**2).reshape(1024, 1024), legend="test")
self.qapp.processEvents()
# Draw a polygon mask
@@ -264,16 +279,18 @@ class TestMaskToolsWidget(PlotWidgetTestCase, ParametricTestCase):
self.assertFalse(numpy.all(numpy.equal(ref_mask, 0)))
with temp_dir() as tmp:
- mask_filename = os.path.join(tmp, 'mask.' + file_format)
+ mask_filename = os.path.join(tmp, "mask." + file_format)
self.maskWidget.save(mask_filename, file_format)
self.maskWidget.resetSelectionMask()
self.assertTrue(
- numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0))
+ )
self.maskWidget.load(mask_filename)
- self.assertTrue(numpy.all(numpy.equal(
- self.maskWidget.getSelectionMask(), ref_mask)))
+ self.assertTrue(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), ref_mask))
+ )
def testLoadSaveNpy(self):
self.__loadSave("npy")
@@ -282,8 +299,7 @@ class TestMaskToolsWidget(PlotWidgetTestCase, ParametricTestCase):
self.__loadSave("msk")
def testSigMaskChangedEmitted(self):
- self.plot.addImage(numpy.arange(512**2).reshape(512, 512),
- legend='test')
+ self.plot.addImage(numpy.arange(512**2).reshape(512, 512), legend="test")
self.plot.resetZoom()
self.qapp.processEvents()
diff --git a/src/silx/gui/plot/test/testPixelIntensityHistoAction.py b/src/silx/gui/plot/test/testPixelIntensityHistoAction.py
index 43d7588..7fd87e8 100644
--- a/src/silx/gui/plot/test/testPixelIntensityHistoAction.py
+++ b/src/silx/gui/plot/test/testPixelIntensityHistoAction.py
@@ -29,7 +29,6 @@ __date__ = "02/03/2018"
import numpy
-import unittest
from silx.utils.testutils import ParametricTestCase
from silx.gui.utils.testutils import TestCaseQt, getQToolButtonFromAction
@@ -53,7 +52,7 @@ class TestPixelIntensitiesHisto(TestCaseQt, ParametricTestCase):
def testShowAndHide(self):
"""Simple test that the plot is showing and hiding when activating the
action"""
- self.plotImage.addImage(self.image, origin=(0, 0), legend='sino')
+ self.plotImage.addImage(self.image, origin=(0, 0), legend="sino")
self.plotImage.show()
histoAction = self.plotImage.getIntensityHistogramAction()
@@ -67,7 +66,7 @@ class TestPixelIntensitiesHisto(TestCaseQt, ParametricTestCase):
self.assertTrue(histoAction.getHistogramWidget().isVisible())
# test the pixel intensity diagram is hiding
- self.qapp.setActiveWindow(self.plotImage)
+ self.plotImage.activateWindow()
self.qapp.processEvents()
self.mouseMove(button)
self.mouseClick(button, qt.Qt.LeftButton)
@@ -76,19 +75,25 @@ class TestPixelIntensitiesHisto(TestCaseQt, ParametricTestCase):
def testImageFormatInput(self):
"""Test multiple type as image input"""
- typesToTest = [numpy.uint8, numpy.int8, numpy.int16, numpy.int32,
- numpy.float32, numpy.float64]
- self.plotImage.addImage(self.image, origin=(0, 0), legend='sino')
+ typesToTest = [
+ numpy.uint8,
+ numpy.int8,
+ numpy.int16,
+ numpy.int32,
+ numpy.float32,
+ numpy.float64,
+ ]
+ self.plotImage.addImage(self.image, origin=(0, 0), legend="sino")
self.plotImage.show()
- button = getQToolButtonFromAction(
- self.plotImage.getIntensityHistogramAction())
+ button = getQToolButtonFromAction(self.plotImage.getIntensityHistogramAction())
self.mouseMove(button)
self.mouseClick(button, qt.Qt.LeftButton)
self.qapp.processEvents()
for typeToTest in typesToTest:
with self.subTest(typeToTest=typeToTest):
- self.plotImage.addImage(self.image.astype(typeToTest),
- origin=(0, 0), legend='sino')
+ self.plotImage.addImage(
+ self.image.astype(typeToTest), origin=(0, 0), legend="sino"
+ )
def testScatter(self):
"""Test that an histogram from a scatter is displayed"""
@@ -136,7 +141,7 @@ class TestPixelIntensitiesHisto(TestCaseQt, ParametricTestCase):
data1 = items[0].getValueData(copy=False)
# Set another item to the plot
- self.plotImage.addImage(self.image, origin=(0, 0), legend='sino')
+ self.plotImage.addImage(self.image, origin=(0, 0), legend="sino")
self.qapp.processEvents()
data2 = items[0].getValueData(copy=False)
diff --git a/src/silx/gui/plot/test/testPlotActions.py b/src/silx/gui/plot/test/testPlotActions.py
index 4006ab9..9f56aad 100644
--- a/src/silx/gui/plot/test/testPlotActions.py
+++ b/src/silx/gui/plot/test/testPlotActions.py
@@ -40,17 +40,13 @@ import numpy
@pytest.fixture
def colormap1():
- colormap = Colormap(name='gray',
- vmin=10.0, vmax=20.0,
- normalization='linear')
+ colormap = Colormap(name="gray", vmin=10.0, vmax=20.0, normalization="linear")
yield colormap
@pytest.fixture
def colormap2():
- colormap = Colormap(name='red',
- vmin=10.0, vmax=20.0,
- normalization='linear')
+ colormap = Colormap(name="red", vmin=10.0, vmax=20.0, normalization="linear")
yield colormap
@@ -70,25 +66,25 @@ def test_action_active_colormap(qapp_utils, plot, colormap1, colormap2):
defaultColormap = plot.getDefaultColormap()
assert colormapDialog.getColormap() is defaultColormap
- plot.addImage(data=numpy.random.rand(10, 10), legend='img1',
- origin=(0, 0),
- colormap=colormap1)
- plot.setActiveImage('img1')
+ plot.addImage(
+ data=numpy.random.rand(10, 10), legend="img1", origin=(0, 0), colormap=colormap1
+ )
+ plot.setActiveImage("img1")
assert colormapDialog.getColormap() is colormap1
- plot.addImage(data=numpy.random.rand(10, 10), legend='img2',
- origin=(0, 0), colormap=colormap2)
- plot.addImage(data=numpy.random.rand(10, 10), legend='img3',
- origin=(0, 0))
+ plot.addImage(
+ data=numpy.random.rand(10, 10), legend="img2", origin=(0, 0), colormap=colormap2
+ )
+ plot.addImage(data=numpy.random.rand(10, 10), legend="img3", origin=(0, 0))
- plot.setActiveImage('img3')
+ plot.setActiveImage("img3")
assert colormapDialog.getColormap() is defaultColormap
plot.getActiveImage().setColormap(colormap2)
assert colormapDialog.getColormap() is colormap2
- plot.remove('img2')
- plot.remove('img3')
- plot.remove('img1')
+ plot.remove("img2")
+ plot.remove("img3")
+ plot.remove("img1")
assert colormapDialog.getColormap() is defaultColormap
@@ -100,10 +96,11 @@ def test_action_show_hide_colormap_dialog(qapp_utils, plot, colormap1):
assert not plot.getColormapAction().isChecked()
plot.getColormapAction()._actionTriggered(checked=True)
assert plot.getColormapAction().isChecked()
- plot.addImage(data=numpy.random.rand(10, 10), legend='img1',
- origin=(0, 0), colormap=colormap1)
- colormap1.setName('red')
+ plot.addImage(
+ data=numpy.random.rand(10, 10), legend="img1", origin=(0, 0), colormap=colormap1
+ )
+ colormap1.setName("red")
plot.getColormapAction()._actionTriggered()
- colormap1.setName('blue')
+ colormap1.setName("blue")
colormapDialog.close()
assert not plot.getColormapAction().isChecked()
diff --git a/src/silx/gui/plot/test/testPlotInteraction.py b/src/silx/gui/plot/test/testPlotInteraction.py
index 17aad97..a97a694 100644
--- a/src/silx/gui/plot/test/testPlotInteraction.py
+++ b/src/silx/gui/plot/test/testPlotInteraction.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2016=2017 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -27,9 +27,10 @@ __authors__ = ["T. Vincent"]
__license__ = "MIT"
__date__ = "01/09/2017"
+import pytest
-import unittest
from silx.gui import qt
+from silx.gui.plot import PlotWidget
from .utils import PlotWidgetTestCase
@@ -78,82 +79,154 @@ class TestSelectPolygon(PlotWidgetTestCase):
def test(self):
"""Test draw polygons + events"""
- self.plot.sigInteractiveModeChanged.connect(
- self._interactionModeChanged)
+ self.plot.sigInteractiveModeChanged.connect(self._interactionModeChanged)
- self.plot.setInteractiveMode(
- 'draw', shape='polygon', label='test', source=self)
+ self.plot.setInteractiveMode("draw", shape="polygon", label="test", source=self)
interaction = self.plot.getInteractiveMode()
- self.assertEqual(interaction['mode'], 'draw')
- self.assertEqual(interaction['shape'], 'polygon')
+ self.assertEqual(interaction["mode"], "draw")
+ self.assertEqual(interaction["shape"], "polygon")
- self.plot.sigInteractiveModeChanged.disconnect(
- self._interactionModeChanged)
+ self.plot.sigInteractiveModeChanged.disconnect(self._interactionModeChanged)
plot = self.plot.getWidgetHandle()
xCenter, yCenter = plot.width() // 2, plot.height() // 2
offset = min(plot.width(), plot.height()) // 10
# Star polygon
- star = [(xCenter, yCenter + offset),
- (xCenter - offset, yCenter - offset),
- (xCenter + offset, yCenter),
- (xCenter - offset, yCenter),
- (xCenter + offset, yCenter - offset),
- (xCenter, yCenter + offset)] # Close polygon
+ star = [
+ (xCenter, yCenter + offset),
+ (xCenter - offset, yCenter - offset),
+ (xCenter + offset, yCenter),
+ (xCenter - offset, yCenter),
+ (xCenter + offset, yCenter - offset),
+ (xCenter, yCenter + offset),
+ ] # Close polygon
# Draw while dumping signals
events = self._draw(star)
# Test last event
- drawEvents = [event for event in events
- if event['event'].startswith('drawing')]
- self.assertEqual(drawEvents[-1]['event'], 'drawingFinished')
- self.assertEqual(len(drawEvents[-1]['points']), 6)
+ drawEvents = [event for event in events if event["event"].startswith("drawing")]
+ self.assertEqual(drawEvents[-1]["event"], "drawingFinished")
+ self.assertEqual(len(drawEvents[-1]["points"]), 6)
# Large square
- largeSquare = [(xCenter - offset, yCenter - offset),
- (xCenter + offset, yCenter - offset),
- (xCenter + offset, yCenter + offset),
- (xCenter - offset, yCenter + offset),
- (xCenter - offset, yCenter - offset)] # Close polygon
+ largeSquare = [
+ (xCenter - offset, yCenter - offset),
+ (xCenter + offset, yCenter - offset),
+ (xCenter + offset, yCenter + offset),
+ (xCenter - offset, yCenter + offset),
+ (xCenter - offset, yCenter - offset),
+ ] # Close polygon
# Draw while dumping signals
events = self._draw(largeSquare)
# Test last event
- drawEvents = [event for event in events
- if event['event'].startswith('drawing')]
- self.assertEqual(drawEvents[-1]['event'], 'drawingFinished')
- self.assertEqual(len(drawEvents[-1]['points']), 5)
+ drawEvents = [event for event in events if event["event"].startswith("drawing")]
+ self.assertEqual(drawEvents[-1]["event"], "drawingFinished")
+ self.assertEqual(len(drawEvents[-1]["points"]), 5)
# Rectangle too thin along X: Some points are ignored
- thinRectX = [(xCenter, yCenter - offset),
- (xCenter, yCenter + offset),
- (xCenter + 1, yCenter + offset),
- (xCenter + 1, yCenter - offset)] # Close polygon
+ thinRectX = [
+ (xCenter, yCenter - offset),
+ (xCenter, yCenter + offset),
+ (xCenter + 1, yCenter + offset),
+ (xCenter + 1, yCenter - offset),
+ ] # Close polygon
# Draw while dumping signals
events = self._draw(thinRectX)
# Test last event
- drawEvents = [event for event in events
- if event['event'].startswith('drawing')]
- self.assertEqual(drawEvents[-1]['event'], 'drawingFinished')
- self.assertEqual(len(drawEvents[-1]['points']), 3)
+ drawEvents = [event for event in events if event["event"].startswith("drawing")]
+ self.assertEqual(drawEvents[-1]["event"], "drawingFinished")
+ self.assertEqual(len(drawEvents[-1]["points"]), 3)
# Rectangle too thin along Y: Some points are ignored
- thinRectY = [(xCenter - offset, yCenter),
- (xCenter + offset, yCenter),
- (xCenter + offset, yCenter + 1),
- (xCenter - offset, yCenter + 1)] # Close polygon
+ thinRectY = [
+ (xCenter - offset, yCenter),
+ (xCenter + offset, yCenter),
+ (xCenter + offset, yCenter + 1),
+ (xCenter - offset, yCenter + 1),
+ ] # Close polygon
# Draw while dumping signals
events = self._draw(thinRectY)
# Test last event
- drawEvents = [event for event in events
- if event['event'].startswith('drawing')]
- self.assertEqual(drawEvents[-1]['event'], 'drawingFinished')
- self.assertEqual(len(drawEvents[-1]['points']), 3)
+ drawEvents = [event for event in events if event["event"].startswith("drawing")]
+ self.assertEqual(drawEvents[-1]["event"], "drawingFinished")
+ self.assertEqual(len(drawEvents[-1]["points"]), 3)
+
+
+@pytest.mark.parametrize("scale", ["linear", "log"])
+@pytest.mark.parametrize("xaxis", [True, False])
+@pytest.mark.parametrize("yaxis", [True, False])
+@pytest.mark.parametrize("y2axis", [True, False])
+def testZoomEnabledAxes(qapp, qWidgetFactory, scale, xaxis, yaxis, y2axis):
+ """Test PlotInteraction.setZoomEnabledAxes effect on zoom interaction"""
+ plotWidget = qWidgetFactory(PlotWidget)
+ plotWidget.getXAxis().setScale(scale)
+ plotWidget.getYAxis("left").setScale(scale)
+ plotWidget.getYAxis("right").setScale(scale)
+ qapp.processEvents()
+
+ xLimits = plotWidget.getXAxis().getLimits()
+ yLimits = plotWidget.getYAxis("left").getLimits()
+ y2Limits = plotWidget.getYAxis("right").getLimits()
+
+ interaction = plotWidget.interaction()
+
+ assert interaction.getZoomEnabledAxes() == (True, True, True)
+
+ enabledAxes = xaxis, yaxis, y2axis
+ interaction.setZoomEnabledAxes(*enabledAxes)
+ assert interaction.getZoomEnabledAxes() == enabledAxes
+
+ cx, cy = plotWidget.width() // 2, plotWidget.height() // 2
+ plotWidget.onMouseWheel(cx, cy, 10)
+ qapp.processEvents()
+
+ xZoomed = plotWidget.getXAxis().getLimits() != xLimits
+ yZoomed = plotWidget.getYAxis("left").getLimits() != yLimits
+ y2Zoomed = plotWidget.getYAxis("right").getLimits() != y2Limits
+
+ assert xZoomed == enabledAxes[0]
+ assert yZoomed == enabledAxes[1]
+ assert y2Zoomed == enabledAxes[2]
+
+
+@pytest.mark.parametrize("scale", ["linear", "log"])
+@pytest.mark.parametrize("zoomOnWheel", [True, False])
+def testZoomOnWheelEnabled(qapp, qWidgetFactory, zoomOnWheel, scale):
+ """Test PlotInteraction.setZoomOnWheelEnabled"""
+ plotWidget = qWidgetFactory(PlotWidget)
+ plotWidget.getXAxis().setScale(scale)
+ plotWidget.getYAxis("left").setScale(scale)
+ plotWidget.getYAxis("right").setScale(scale)
+ qapp.processEvents()
+
+ xLimits = plotWidget.getXAxis().getLimits()
+ yLimits = plotWidget.getYAxis("left").getLimits()
+ y2Limits = plotWidget.getYAxis("right").getLimits()
+
+ interaction = plotWidget.interaction()
+
+ assert interaction.isZoomOnWheelEnabled()
+
+ interaction.setZoomOnWheelEnabled(zoomOnWheel)
+ assert interaction.isZoomOnWheelEnabled() == zoomOnWheel
+
+ cx, cy = plotWidget.width() // 2, plotWidget.height() // 2
+ plotWidget.onMouseWheel(cx, cy, 10)
+ qapp.processEvents()
+
+ xZoomed = plotWidget.getXAxis().getLimits() != xLimits
+ yZoomed = plotWidget.getYAxis("left").getLimits() != yLimits
+ y2Zoomed = plotWidget.getYAxis("right").getLimits() != y2Limits
+
+ assert xZoomed == zoomOnWheel
+ assert yZoomed == zoomOnWheel
+ assert y2Zoomed == zoomOnWheel
diff --git a/src/silx/gui/plot/test/testPlotWidget.py b/src/silx/gui/plot/test/testPlotWidget.py
index 19a34a9..842e880 100755
--- a/src/silx/gui/plot/test/testPlotWidget.py
+++ b/src/silx/gui/plot/test/testPlotWidget.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2016-2022 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -29,7 +29,6 @@ __date__ = "03/01/2019"
import unittest
-import logging
import numpy
import pytest
@@ -39,7 +38,6 @@ from silx.gui.utils.testutils import TestCaseQt
from silx.gui import qt
from silx.gui.plot import PlotWidget
-from silx.gui.plot.items.curve import CurveStyle
from silx.gui.plot.items import BoundingRect, XAxisExtent, YAxisExtent, Axis
from silx.gui.colors import Colormap
@@ -49,16 +47,12 @@ from .utils import PlotWidgetTestCase
SIZE = 1024
"""Size of the test image"""
-DATA_2D = numpy.arange(SIZE ** 2).reshape(SIZE, SIZE)
+DATA_2D = numpy.arange(SIZE**2).reshape(SIZE, SIZE)
"""Image data set"""
-logger = logging.getLogger(__name__)
-
-
class TestSpecialBackend(PlotWidgetTestCase, ParametricTestCase):
-
- def __init__(self, methodName='runTest', backend=None):
+ def __init__(self, methodName="runTest", backend=None):
TestCaseQt.__init__(self, methodName=methodName)
self.__backend = backend
@@ -79,7 +73,7 @@ class TestPlotWidget(PlotWidgetTestCase, ParametricTestCase):
def testSetTitleLabels(self):
"""Set title and axes labels"""
- title, xlabel, ylabel = 'the title', 'x label', 'y label'
+ title, xlabel, ylabel = "the title", "x label", "y label"
self.plot.setGraphTitle(title)
self.plot.getXAxis().setLabel(xlabel)
self.plot.getYAxis().setLabel(ylabel)
@@ -89,10 +83,7 @@ class TestPlotWidget(PlotWidgetTestCase, ParametricTestCase):
self.assertEqual(self.plot.getXAxis().getLabel(), xlabel)
self.assertEqual(self.plot.getYAxis().getLabel(), ylabel)
- def _checkLimits(self,
- expectedXLim=None,
- expectedYLim=None,
- expectedRatio=None):
+ def _checkLimits(self, expectedXLim=None, expectedYLim=None, expectedRatio=None):
"""Assert that limits are as expected"""
xlim = self.plot.getXAxis().getLimits()
ylim = self.plot.getYAxis().getLimits()
@@ -105,8 +96,7 @@ class TestPlotWidget(PlotWidgetTestCase, ParametricTestCase):
self.assertEqual(expectedYLim, ylim)
if expectedRatio is not None:
- self.assertTrue(
- numpy.allclose(expectedRatio, ratio, atol=0.01))
+ self.assertTrue(numpy.allclose(expectedRatio, ratio, atol=0.01))
def testChangeLimitsWithAspectRatio(self):
self.plot.setKeepDataAspectRatio()
@@ -115,15 +105,15 @@ class TestPlotWidget(PlotWidgetTestCase, ParametricTestCase):
ylim = self.plot.getYAxis().getLimits()
defaultRatio = abs(xlim[1] - xlim[0]) / abs(ylim[1] - ylim[0])
- self.plot.getXAxis().setLimits(1., 10.)
- self._checkLimits(expectedXLim=(1., 10.), expectedRatio=defaultRatio)
+ self.plot.getXAxis().setLimits(1.0, 10.0)
+ self._checkLimits(expectedXLim=(1.0, 10.0), expectedRatio=defaultRatio)
self.qapp.processEvents()
- self._checkLimits(expectedXLim=(1., 10.), expectedRatio=defaultRatio)
+ self._checkLimits(expectedXLim=(1.0, 10.0), expectedRatio=defaultRatio)
- self.plot.getYAxis().setLimits(1., 10.)
- self._checkLimits(expectedYLim=(1., 10.), expectedRatio=defaultRatio)
+ self.plot.getYAxis().setLimits(1.0, 10.0)
+ self._checkLimits(expectedYLim=(1.0, 10.0), expectedRatio=defaultRatio)
self.qapp.processEvents()
- self._checkLimits(expectedYLim=(1., 10.), expectedRatio=defaultRatio)
+ self._checkLimits(expectedYLim=(1.0, 10.0), expectedRatio=defaultRatio)
def testResizeWidget(self):
"""Test resizing the widget and receiving limitsChanged events"""
@@ -135,8 +125,8 @@ class TestPlotWidget(PlotWidgetTestCase, ParametricTestCase):
ylim = self.plot.getYAxis().getLimits()
listener = SignalListener()
- self.plot.getXAxis().sigLimitsChanged.connect(listener.partial('x'))
- self.plot.getYAxis().sigLimitsChanged.connect(listener.partial('y'))
+ self.plot.getXAxis().sigLimitsChanged.connect(listener.partial("x"))
+ self.plot.getYAxis().sigLimitsChanged.connect(listener.partial("y"))
# Resize without aspect ratio
self.plot.resize(200, 300)
@@ -159,17 +149,17 @@ class TestPlotWidget(PlotWidgetTestCase, ParametricTestCase):
def testAddRemoveItemSignals(self):
"""Test sigItemAdded and sigItemAboutToBeRemoved"""
listener = SignalListener()
- self.plot.sigItemAdded.connect(listener.partial('add'))
- self.plot.sigItemAboutToBeRemoved.connect(listener.partial('remove'))
+ self.plot.sigItemAdded.connect(listener.partial("add"))
+ self.plot.sigItemAboutToBeRemoved.connect(listener.partial("remove"))
- self.plot.addCurve((1, 2, 3), (3, 2, 1), legend='curve')
+ self.plot.addCurve((1, 2, 3), (3, 2, 1), legend="curve")
self.assertEqual(listener.callCount(), 1)
- curve = self.plot.getCurve('curve')
- self.plot.remove('curve')
+ curve = self.plot.getCurve("curve")
+ self.plot.remove("curve")
self.assertEqual(listener.callCount(), 2)
- self.assertEqual(listener.arguments(callIndex=0), ('add', curve))
- self.assertEqual(listener.arguments(callIndex=1), ('remove', curve))
+ self.assertEqual(listener.arguments(callIndex=0), ("add", curve))
+ self.assertEqual(listener.arguments(callIndex=1), ("remove", curve))
def testGetItems(self):
"""Test getItems method"""
@@ -183,7 +173,7 @@ class TestPlotWidget(PlotWidgetTestCase, ParametricTestCase):
self.plot.addMarker(*marker_pos)
marker_x = 6
self.plot.addXMarker(marker_x)
- self.plot.addShape((0, 5), (2, 10), shape='rectangle')
+ self.plot.addShape((0, 5), (2, 10), shape="rectangle")
items = self.plot.getItems()
self.assertEqual(len(items), 6)
@@ -192,7 +182,7 @@ class TestPlotWidget(PlotWidgetTestCase, ParametricTestCase):
self.assertTrue(numpy.all(numpy.equal(items[2].getXData(), scatter_x)))
self.assertTrue(numpy.all(numpy.equal(items[3].getPosition(), marker_pos)))
self.assertTrue(numpy.all(numpy.equal(items[4].getPosition()[0], marker_x)))
- self.assertEqual(items[5].getType(), 'rectangle')
+ self.assertEqual(items[5].getType(), "rectangle")
def testRemoveDiscardItem(self):
"""Test removeItem and discardItem"""
@@ -232,7 +222,7 @@ class TestPlotWidget(PlotWidgetTestCase, ParametricTestCase):
self.qapp.processEvents()
# Back to default
- self.plot.setBackgroundColor('white')
+ self.plot.setBackgroundColor("white")
self.plot.setDataBackgroundColor(None)
color = self.plot.getBackgroundColor()
self.assertTrue(color.isValid())
@@ -248,116 +238,132 @@ class TestPlotImage(PlotWidgetTestCase, ParametricTestCase):
def setUp(self):
super(TestPlotImage, self).setUp()
- self.plot.getYAxis().setLabel('Rows')
- self.plot.getXAxis().setLabel('Columns')
+ self.plot.getYAxis().setLabel("Rows")
+ self.plot.getXAxis().setLabel("Columns")
def testPlotColormapTemperature(self):
- self.plot.setGraphTitle('Temp. Linear')
+ self.plot.setGraphTitle("Temp. Linear")
- colormap = Colormap(name='temperature',
- normalization='linear',
- vmin=None,
- vmax=None)
+ colormap = Colormap(
+ name="temperature", normalization="linear", vmin=None, vmax=None
+ )
self.plot.addImage(DATA_2D, legend="image 1", colormap=colormap)
def testPlotColormapGray(self):
self.plot.setKeepDataAspectRatio(False)
- self.plot.setGraphTitle('Gray Linear')
+ self.plot.setGraphTitle("Gray Linear")
- colormap = Colormap(name='gray',
- normalization='linear',
- vmin=None,
- vmax=None)
+ colormap = Colormap(name="gray", normalization="linear", vmin=None, vmax=None)
self.plot.addImage(DATA_2D, legend="image 1", colormap=colormap)
def testPlotColormapTemperatureLog(self):
- self.plot.setGraphTitle('Temp. Log')
+ self.plot.setGraphTitle("Temp. Log")
- colormap = Colormap(name='temperature',
- normalization=Colormap.LOGARITHM,
- vmin=None,
- vmax=None)
+ colormap = Colormap(
+ name="temperature", normalization=Colormap.LOGARITHM, vmin=None, vmax=None
+ )
self.plot.addImage(DATA_2D, legend="image 1", colormap=colormap)
def testPlotRgbRgba(self):
self.plot.setKeepDataAspectRatio(False)
- self.plot.setGraphTitle('RGB + RGBA')
+ self.plot.setGraphTitle("RGB + RGBA")
rgb = numpy.array(
- (((0, 0, 0), (128, 0, 0), (255, 0, 0)),
- ((0, 128, 0), (0, 128, 128), (0, 128, 255))),
- dtype=numpy.uint8)
+ (
+ ((0, 0, 0), (128, 0, 0), (255, 0, 0)),
+ ((0, 128, 0), (0, 128, 128), (0, 128, 255)),
+ ),
+ dtype=numpy.uint8,
+ )
- self.plot.addImage(rgb, legend="rgb_uint8",
- origin=(0, 0), scale=(1, 1),
- resetzoom=False)
+ self.plot.addImage(
+ rgb, legend="rgb_uint8", origin=(0, 0), scale=(1, 1), resetzoom=False
+ )
rgb = numpy.array(
- (((0, 0, 0), (32768, 0, 0), (65535, 0, 0)),
- ((0, 32768, 0), (0, 32768, 32768), (0, 32768, 65535))),
- dtype=numpy.uint16)
+ (
+ ((0, 0, 0), (32768, 0, 0), (65535, 0, 0)),
+ ((0, 32768, 0), (0, 32768, 32768), (0, 32768, 65535)),
+ ),
+ dtype=numpy.uint16,
+ )
- self.plot.addImage(rgb, legend="rgb_uint16",
- origin=(3, 2), scale=(2, 2),
- resetzoom=False)
+ self.plot.addImage(
+ rgb, legend="rgb_uint16", origin=(3, 2), scale=(2, 2), resetzoom=False
+ )
rgba = numpy.array(
- (((0, 0, 0, .5), (.5, 0, 0, 1), (1, 0, 0, .5)),
- ((0, .5, 0, 1), (0, .5, .5, 1), (0, 1, 1, .5))),
- dtype=numpy.float32)
+ (
+ ((0, 0, 0, 0.5), (0.5, 0, 0, 1), (1, 0, 0, 0.5)),
+ ((0, 0.5, 0, 1), (0, 0.5, 0.5, 1), (0, 1, 1, 0.5)),
+ ),
+ dtype=numpy.float32,
+ )
- self.plot.addImage(rgba, legend="rgba_float32",
- origin=(9, 6), scale=(1, 1),
- resetzoom=False)
+ self.plot.addImage(
+ rgba, legend="rgba_float32", origin=(9, 6), scale=(1, 1), resetzoom=False
+ )
self.plot.resetZoom()
def testPlotColormapCustom(self):
self.plot.setKeepDataAspectRatio(False)
- self.plot.setGraphTitle('Custom colormap')
-
- colormap = Colormap(name=None,
- normalization=Colormap.LINEAR,
- vmin=None,
- vmax=None,
- colors=((0., 0., 0.), (1., 0., 0.),
- (0., 1., 0.), (0., 0., 1.)))
- self.plot.addImage(DATA_2D, legend="image 1", colormap=colormap,
- resetzoom=False)
-
- colormap = Colormap(name=None,
- normalization=Colormap.LINEAR,
- vmin=None,
- vmax=None,
- colors=numpy.array(
- ((0, 0, 0, 0), (0, 0, 0, 128),
- (128, 128, 128, 128), (255, 255, 255, 255)),
- dtype=numpy.uint8))
- self.plot.addImage(DATA_2D, legend="image 2", colormap=colormap,
- origin=(DATA_2D.shape[0], 0),
- resetzoom=False)
+ self.plot.setGraphTitle("Custom colormap")
+
+ colormap = Colormap(
+ name=None,
+ normalization=Colormap.LINEAR,
+ vmin=None,
+ vmax=None,
+ colors=((0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (0.0, 0.0, 1.0)),
+ )
+ self.plot.addImage(
+ DATA_2D, legend="image 1", colormap=colormap, resetzoom=False
+ )
+
+ colormap = Colormap(
+ name=None,
+ normalization=Colormap.LINEAR,
+ vmin=None,
+ vmax=None,
+ colors=numpy.array(
+ (
+ (0, 0, 0, 0),
+ (0, 0, 0, 128),
+ (128, 128, 128, 128),
+ (255, 255, 255, 255),
+ ),
+ dtype=numpy.uint8,
+ ),
+ )
+ self.plot.addImage(
+ DATA_2D,
+ legend="image 2",
+ colormap=colormap,
+ origin=(DATA_2D.shape[0], 0),
+ resetzoom=False,
+ )
self.plot.resetZoom()
def testPlotColormapNaNColor(self):
self.plot.setKeepDataAspectRatio(False)
- self.plot.setGraphTitle('Colormap with NaN color')
+ self.plot.setGraphTitle("Colormap with NaN color")
colormap = Colormap()
- colormap.setNaNColor('red')
+ colormap.setNaNColor("red")
self.assertEqual(colormap.getNaNColor(), qt.QColor(255, 0, 0))
data = DATA_2D.astype(numpy.float32)
- data[len(data)//2:] = numpy.nan
- self.plot.addImage(data, legend="image 1", colormap=colormap,
- resetzoom=False)
+ data[len(data) // 2 :] = numpy.nan
+ self.plot.addImage(data, legend="image 1", colormap=colormap, resetzoom=False)
self.plot.resetZoom()
- colormap.setNaNColor((0., 1., 0., 1.))
+ colormap.setNaNColor((0.0, 1.0, 0.0, 1.0))
self.assertEqual(colormap.getNaNColor(), qt.QColor(0, 255, 0))
self.qapp.processEvents()
def testImageOriginScale(self):
"""Test of image with different origin and scale"""
- self.plot.setGraphTitle('origin and scale')
+ self.plot.setGraphTitle("origin and scale")
tests = [ # (origin, scale)
((10, 20), (1, 1)),
@@ -367,7 +373,7 @@ class TestPlotImage(PlotWidgetTestCase, ParametricTestCase):
(100, 2),
(-100, (1, 1)),
((10, 20), 2),
- ]
+ ]
for origin, scale in tests:
with self.subTest(origin=origin, scale=scale):
@@ -408,31 +414,30 @@ class TestPlotImage(PlotWidgetTestCase, ParametricTestCase):
def testPlotColormapDictAPI(self):
"""Test that the addImage API using a colormap dictionary is still
working"""
- self.plot.setGraphTitle('Temp. Log')
+ self.plot.setGraphTitle("Temp. Log")
colormap = {
- 'name': 'temperature',
- 'normalization': 'log',
- 'vmin': None,
- 'vmax': None
+ "name": "temperature",
+ "normalization": "log",
+ "vmin": None,
+ "vmax": None,
}
self.plot.addImage(DATA_2D, legend="image 1", colormap=colormap)
def testPlotComplexImage(self):
"""Test that a complex image is displayed as its absolute value."""
data = numpy.linspace(1, 1j, 100).reshape(10, 10)
- self.plot.addImage(data, legend='complex')
+ self.plot.addImage(data, legend="complex")
image = self.plot.getActiveImage()
retrievedData = image.getData(copy=False)
- self.assertTrue(
- numpy.all(numpy.equal(retrievedData, numpy.absolute(data))))
+ self.assertTrue(numpy.all(numpy.equal(retrievedData, numpy.absolute(data))))
def testPlotBooleanImage(self):
"""Test that a boolean image is displayed and converted to int8."""
data = numpy.zeros((10, 10), dtype=bool)
data[::2, ::2] = True
- self.plot.addImage(data, legend='boolean')
+ self.plot.addImage(data, legend="boolean")
image = self.plot.getActiveImage()
retrievedData = image.getData(copy=False)
@@ -443,7 +448,7 @@ class TestPlotImage(PlotWidgetTestCase, ParametricTestCase):
"""Test with an alpha image layer"""
data = numpy.random.random((10, 10))
alpha = numpy.linspace(0, 1, 100).reshape(10, 10)
- self.plot.addImage(data, legend='image')
+ self.plot.addImage(data, legend="image")
image = self.plot.getActiveImage()
image.setData(data, alpha=alpha)
self.qapp.processEvents()
@@ -461,19 +466,19 @@ class TestPlotCurve(PlotWidgetTestCase):
def setUp(self):
super(TestPlotCurve, self).setUp()
- self.plot.setGraphTitle('Curve')
- self.plot.getYAxis().setLabel('Rows')
- self.plot.getXAxis().setLabel('Columns')
+ self.plot.setGraphTitle("Curve")
+ self.plot.getYAxis().setLabel("Rows")
+ self.plot.getXAxis().setLabel("Columns")
self.plot.setActiveCurveHandling(False)
def testPlotCurveInfinite(self):
"""Test plot curves with not finite data"""
tests = {
- 'y all not finite': ([0, 1, 2], [numpy.inf, numpy.nan, -numpy.inf]),
- 'x all not finite': ([numpy.inf, numpy.nan, -numpy.inf], [0, 1, 2]),
- 'x some inf': ([0, numpy.inf, 2], [0, 1, 2]),
- 'y some inf': ([0, 1, 2], [0, numpy.inf, 2])
+ "y all not finite": ([0, 1, 2], [numpy.inf, numpy.nan, -numpy.inf]),
+ "x all not finite": ([numpy.inf, numpy.nan, -numpy.inf], [0, 1, 2]),
+ "x some inf": ([0, numpy.inf, 2], [0, 1, 2]),
+ "y some inf": ([0, 1, 2], [0, numpy.inf, 2]),
}
for name, args in tests.items():
with self.subTest(name):
@@ -483,65 +488,111 @@ class TestPlotCurve(PlotWidgetTestCase):
self.plot.clear()
def testPlotCurveColorFloat(self):
- color = numpy.array(numpy.random.random(3 * 1000),
- dtype=numpy.float32).reshape(1000, 3)
-
- self.plot.addCurve(self.xData, self.yData,
- legend="curve 1",
- replace=False, resetzoom=False,
- color=color,
- linestyle="", symbol="s")
- self.plot.addCurve(self.xData2, self.yData2,
- legend="curve 2",
- replace=False, resetzoom=False,
- color='green', linestyle="-", symbol='o')
+ color = numpy.array(numpy.random.random(3 * 1000), dtype=numpy.float32).reshape(
+ 1000, 3
+ )
+
+ self.plot.addCurve(
+ self.xData,
+ self.yData,
+ legend="curve 1",
+ replace=False,
+ resetzoom=False,
+ color=color,
+ linestyle="",
+ symbol="s",
+ )
+ self.plot.addCurve(
+ self.xData2,
+ self.yData2,
+ legend="curve 2",
+ replace=False,
+ resetzoom=False,
+ color="green",
+ linestyle="-",
+ symbol="o",
+ )
self.plot.resetZoom()
def testPlotCurveColorByte(self):
- color = numpy.array(255 * numpy.random.random(3 * 1000),
- dtype=numpy.uint8).reshape(1000, 3)
-
- self.plot.addCurve(self.xData, self.yData,
- legend="curve 1",
- replace=False, resetzoom=False,
- color=color,
- linestyle="", symbol="s")
- self.plot.addCurve(self.xData2, self.yData2,
- legend="curve 2",
- replace=False, resetzoom=False,
- color='green', linestyle="-", symbol='o')
+ color = numpy.array(
+ 255 * numpy.random.random(3 * 1000), dtype=numpy.uint8
+ ).reshape(1000, 3)
+
+ self.plot.addCurve(
+ self.xData,
+ self.yData,
+ legend="curve 1",
+ replace=False,
+ resetzoom=False,
+ color=color,
+ linestyle="",
+ symbol="s",
+ )
+ self.plot.addCurve(
+ self.xData2,
+ self.yData2,
+ legend="curve 2",
+ replace=False,
+ resetzoom=False,
+ color="green",
+ linestyle="-",
+ symbol="o",
+ )
self.plot.resetZoom()
def testPlotCurveColors(self):
- color = numpy.array(numpy.random.random(3 * 1000),
- dtype=numpy.float32).reshape(1000, 3)
-
- self.plot.addCurve(self.xData, self.yData,
- legend="curve 2",
- replace=False, resetzoom=False,
- color=color, linestyle="-", symbol='o')
+ color = numpy.array(numpy.random.random(3 * 1000), dtype=numpy.float32).reshape(
+ 1000, 3
+ )
+
+ self.plot.addCurve(
+ self.xData,
+ self.yData,
+ legend="curve 2",
+ replace=False,
+ resetzoom=False,
+ color=color,
+ linestyle="-",
+ symbol="o",
+ )
self.plot.resetZoom()
# Test updating color array
# From array to array
newColors = numpy.ones((len(self.xData), 3), dtype=numpy.float32)
- self.plot.addCurve(self.xData, self.yData,
- legend="curve 2",
- replace=False, resetzoom=False,
- color=newColors, symbol='o')
+ self.plot.addCurve(
+ self.xData,
+ self.yData,
+ legend="curve 2",
+ replace=False,
+ resetzoom=False,
+ color=newColors,
+ symbol="o",
+ )
# Array to single color
- self.plot.addCurve(self.xData, self.yData,
- legend="curve 2",
- replace=False, resetzoom=False,
- color='green', symbol='o')
+ self.plot.addCurve(
+ self.xData,
+ self.yData,
+ legend="curve 2",
+ replace=False,
+ resetzoom=False,
+ color="green",
+ symbol="o",
+ )
# single color to array
- self.plot.addCurve(self.xData, self.yData,
- legend="curve 2",
- replace=False, resetzoom=False,
- color=color, symbol='o')
+ self.plot.addCurve(
+ self.xData,
+ self.yData,
+ legend="curve 2",
+ replace=False,
+ resetzoom=False,
+ color=color,
+ symbol="o",
+ )
def testPlotBaselineNumpyArray(self):
"""simple test of the API with baseline as a numpy array"""
@@ -550,8 +601,9 @@ class TestPlotCurve(PlotWidgetTestCase):
y = numpy.arange(-4, 6, step=0.1) + my_sin
baseline = y - 1.0
- self.plot.addCurve(x=x, y=y, color='grey', legend='curve1', fill=True,
- baseline=baseline)
+ self.plot.addCurve(
+ x=x, y=y, color="grey", legend="curve1", fill=True, baseline=baseline
+ )
def testPlotBaselineScalar(self):
"""simple test of the API with baseline as an int"""
@@ -559,8 +611,9 @@ class TestPlotCurve(PlotWidgetTestCase):
my_sin = numpy.sin(x)
y = numpy.arange(-4, 6, step=0.1) + my_sin
- self.plot.addCurve(x=x, y=y, color='grey', legend='curve1', fill=True,
- baseline=0)
+ self.plot.addCurve(
+ x=x, y=y, color="grey", legend="curve1", fill=True, baseline=0
+ )
def testPlotBaselineList(self):
"""simple test of the API with baseline as an int"""
@@ -568,35 +621,70 @@ class TestPlotCurve(PlotWidgetTestCase):
my_sin = numpy.sin(x)
y = numpy.arange(-4, 6, step=0.1) + my_sin
- self.plot.addCurve(x=x, y=y, color='grey', legend='curve1', fill=True,
- baseline=list(range(0, 100, 1)))
+ self.plot.addCurve(
+ x=x,
+ y=y,
+ color="grey",
+ legend="curve1",
+ fill=True,
+ baseline=list(range(0, 100, 1)),
+ )
def testPlotCurveComplexData(self):
"""Test curve with complex data"""
- data = numpy.arange(100.) + 1j
+ data = numpy.arange(100.0) + 1j
self.plot.addCurve(x=data, y=data, xerror=data, yerror=data)
+ def testPlotCurveGapColor(self):
+ """Test dashed curve with gap color"""
+ data = numpy.arange(100)
+ self.plot.addCurve(
+ x=data, y=data, legend="curve1", linestyle="--", color="blue"
+ )
+ curve = self.plot.getCurve("curve1")
+ assert curve.getLineGapColor() is None
+ curve.setLineGapColor("red")
+ assert curve.getLineGapColor() == (1.0, 0.0, 0.0, 1.0)
+
class TestPlotHistogram(PlotWidgetTestCase):
"""Basic tests for add Histogram"""
+
def setUp(self):
super(TestPlotHistogram, self).setUp()
self.edges = numpy.arange(0, 10, step=1)
self.histogram = numpy.random.random(len(self.edges))
def testPlot(self):
- self.plot.addHistogram(histogram=self.histogram,
- edges=self.edges,
- legend='histogram1')
+ self.plot.addHistogram(
+ histogram=self.histogram, edges=self.edges, legend="histogram1"
+ )
def testPlotBaseline(self):
- self.plot.addHistogram(histogram=self.histogram,
- edges=self.edges,
- legend='histogram1',
- color='blue',
- baseline=-2,
- z=2,
- fill=True)
+ self.plot.addHistogram(
+ histogram=self.histogram,
+ edges=self.edges,
+ legend="histogram1",
+ color="blue",
+ baseline=-2,
+ z=2,
+ fill=True,
+ )
+
+ def testPlotGapColor(self):
+ """Test dashed histogram with gap color"""
+ data = numpy.arange(100)
+ self.plot.addHistogram(
+ histogram=self.histogram,
+ edges=self.edges,
+ legend="histogram1",
+ color="blue",
+ )
+ histogram = self.plot.getItems()[0]
+ assert histogram.getLineGapColor() is None
+ histogram.setLineGapColor("red")
+ assert histogram.getLineGapColor() == (1.0, 0.0, 0.0, 1.0)
+ histogram.setLineStyle(":")
class TestPlotScatter(PlotWidgetTestCase, ParametricTestCase):
@@ -611,9 +699,8 @@ class TestPlotScatter(PlotWidgetTestCase, ParametricTestCase):
def testScatterComplexData(self):
"""Test scatter item with complex data"""
- data = numpy.arange(100.) + 1j
- self.plot.addScatter(
- x=data, y=data, value=data, xerror=data, yerror=data)
+ data = numpy.arange(100.0) + 1j
+ self.plot.addScatter(x=data, y=data, value=data, xerror=data, yerror=data)
self.plot.resetZoom()
def testScatterVisualization(self):
@@ -623,16 +710,18 @@ class TestPlotScatter(PlotWidgetTestCase, ParametricTestCase):
scatter = self.plot.getItems()[0]
- for visualization in ('solid',
- 'points',
- 'regular_grid',
- 'irregular_grid',
- 'binned_statistic',
- scatter.Visualization.SOLID,
- scatter.Visualization.POINTS,
- scatter.Visualization.REGULAR_GRID,
- scatter.Visualization.IRREGULAR_GRID,
- scatter.Visualization.BINNED_STATISTIC):
+ for visualization in (
+ "solid",
+ "points",
+ "regular_grid",
+ "irregular_grid",
+ "binned_statistic",
+ scatter.Visualization.SOLID,
+ scatter.Visualization.POINTS,
+ scatter.Visualization.REGULAR_GRID,
+ scatter.Visualization.IRREGULAR_GRID,
+ scatter.Visualization.BINNED_STATISTIC,
+ ):
with self.subTest(visualization=visualization):
scatter.setVisualization(visualization)
self.qapp.processEvents()
@@ -640,28 +729,30 @@ class TestPlotScatter(PlotWidgetTestCase, ParametricTestCase):
def testGridVisualization(self):
"""Test regular and irregular grid mode with different points"""
points = { # name: (x, y, order)
- 'single point': ((1.,), (1.,), 'row'),
- 'horizontal line': ((0, 1, 2), (0, 0, 0), 'row'),
- 'horizontal line backward': ((2, 1, 0), (0, 0, 0), 'row'),
- 'vertical line': ((0, 0, 0), (0, 1, 2), 'row'),
- 'vertical line backward': ((0, 0, 0), (2, 1, 0), 'row'),
- 'grid fast x, +x +y': ((0, 1, 2, 0, 1, 2), (0, 0, 0, 1, 1, 1), 'row'),
- 'grid fast x, +x -y': ((0, 1, 2, 0, 1, 2), (1, 1, 1, 0, 0, 0), 'row'),
- 'grid fast x, -x -y': ((2, 1, 0, 2, 1, 0), (1, 1, 1, 0, 0, 0), 'row'),
- 'grid fast x, -x +y': ((2, 1, 0, 2, 1, 0), (0, 0, 0, 1, 1, 1), 'row'),
- 'grid fast y, +x +y': ((0, 0, 0, 1, 1, 1), (0, 1, 2, 0, 1, 2), 'column'),
- 'grid fast y, +x -y': ((0, 0, 0, 1, 1, 1), (2, 1, 0, 2, 1, 0), 'column'),
- 'grid fast y, -x -y': ((1, 1, 1, 0, 0, 0), (2, 1, 0, 2, 1, 0), 'column'),
- 'grid fast y, -x +y': ((1, 1, 1, 0, 0, 0), (0, 1, 2, 0, 1, 2), 'column'),
- }
+ "single point": ((1.0,), (1.0,), "row"),
+ "horizontal line": ((0, 1, 2), (0, 0, 0), "row"),
+ "horizontal line backward": ((2, 1, 0), (0, 0, 0), "row"),
+ "vertical line": ((0, 0, 0), (0, 1, 2), "row"),
+ "vertical line backward": ((0, 0, 0), (2, 1, 0), "row"),
+ "grid fast x, +x +y": ((0, 1, 2, 0, 1, 2), (0, 0, 0, 1, 1, 1), "row"),
+ "grid fast x, +x -y": ((0, 1, 2, 0, 1, 2), (1, 1, 1, 0, 0, 0), "row"),
+ "grid fast x, -x -y": ((2, 1, 0, 2, 1, 0), (1, 1, 1, 0, 0, 0), "row"),
+ "grid fast x, -x +y": ((2, 1, 0, 2, 1, 0), (0, 0, 0, 1, 1, 1), "row"),
+ "grid fast y, +x +y": ((0, 0, 0, 1, 1, 1), (0, 1, 2, 0, 1, 2), "column"),
+ "grid fast y, +x -y": ((0, 0, 0, 1, 1, 1), (2, 1, 0, 2, 1, 0), "column"),
+ "grid fast y, -x -y": ((1, 1, 1, 0, 0, 0), (2, 1, 0, 2, 1, 0), "column"),
+ "grid fast y, -x +y": ((1, 1, 1, 0, 0, 0), (0, 1, 2, 0, 1, 2), "column"),
+ }
self.plot.addScatter((), (), ())
scatter = self.plot.getItems()[0]
self.qapp.processEvents()
- for visualization in (scatter.Visualization.REGULAR_GRID,
- scatter.Visualization.IRREGULAR_GRID):
+ for visualization in (
+ scatter.Visualization.REGULAR_GRID,
+ scatter.Visualization.IRREGULAR_GRID,
+ ):
scatter.setVisualization(visualization)
self.assertIs(scatter.getVisualization(), visualization)
@@ -673,16 +764,19 @@ class TestPlotScatter(PlotWidgetTestCase, ParametricTestCase):
self.qapp.processEvents()
order = scatter.getCurrentVisualizationParameter(
- scatter.VisualizationParameter.GRID_MAJOR_ORDER)
+ scatter.VisualizationParameter.GRID_MAJOR_ORDER
+ )
self.assertEqual(ref_order, order)
ref_bounds = (x[0], y[0]), (x[-1], y[-1])
bounds = scatter.getCurrentVisualizationParameter(
- scatter.VisualizationParameter.GRID_BOUNDS)
+ scatter.VisualizationParameter.GRID_BOUNDS
+ )
self.assertEqual(ref_bounds, bounds)
shape = scatter.getCurrentVisualizationParameter(
- scatter.VisualizationParameter.GRID_SHAPE)
+ scatter.VisualizationParameter.GRID_SHAPE
+ )
self.plot.getXAxis().setLimits(numpy.min(x) - 1, numpy.max(x) + 1)
self.plot.getYAxis().setLimits(numpy.min(y) - 1, numpy.max(y) + 1)
@@ -700,12 +794,15 @@ class TestPlotScatter(PlotWidgetTestCase, ParametricTestCase):
self.plot.addScatter((), (), ())
scatter = self.plot.getItems()[0]
scatter.setVisualization(scatter.Visualization.BINNED_STATISTIC)
- self.assertIs(scatter.getVisualization(),
- scatter.Visualization.BINNED_STATISTIC)
+ self.assertIs(
+ scatter.getVisualization(), scatter.Visualization.BINNED_STATISTIC
+ )
self.assertEqual(
scatter.getVisualizationParameter(
- scatter.VisualizationParameter.BINNED_STATISTIC_FUNCTION),
- 'mean')
+ scatter.VisualizationParameter.BINNED_STATISTIC_FUNCTION
+ ),
+ "mean",
+ )
self.qapp.processEvents()
@@ -716,15 +813,17 @@ class TestPlotScatter(PlotWidgetTestCase, ParametricTestCase):
scatter.setData(*numpy.random.random(3000).reshape(3, -1))
self.qapp.processEvents()
- for reduction in ('count', 'sum', 'mean'):
+ for reduction in ("count", "sum", "mean"):
with self.subTest(reduction=reduction):
scatter.setVisualizationParameter(
- scatter.VisualizationParameter.BINNED_STATISTIC_FUNCTION,
- reduction)
+ scatter.VisualizationParameter.BINNED_STATISTIC_FUNCTION, reduction
+ )
self.assertEqual(
scatter.getVisualizationParameter(
- scatter.VisualizationParameter.BINNED_STATISTIC_FUNCTION),
- reduction)
+ scatter.VisualizationParameter.BINNED_STATISTIC_FUNCTION
+ ),
+ reduction,
+ )
self.qapp.processEvents()
@@ -734,23 +833,23 @@ class TestPlotMarker(PlotWidgetTestCase):
def setUp(self):
super(TestPlotMarker, self).setUp()
- self.plot.getYAxis().setLabel('Rows')
- self.plot.getXAxis().setLabel('Columns')
+ self.plot.getYAxis().setLabel("Rows")
+ self.plot.getXAxis().setLabel("Columns")
self.plot.getXAxis().setAutoScale(False)
self.plot.getYAxis().setAutoScale(False)
self.plot.setKeepDataAspectRatio(False)
- self.plot.setLimits(0., 100., -100., 100.)
+ self.plot.setLimits(0.0, 100.0, -100.0, 100.0)
def testPlotMarkerX(self):
- self.plot.setGraphTitle('Markers X')
+ self.plot.setGraphTitle("Markers X")
markers = [
- (10., 'blue', False, False),
- (20., 'red', False, False),
- (40., 'green', True, False),
- (60., 'gray', True, True),
- (80., 'black', False, True),
+ (10.0, "blue", False, False),
+ (20.0, "red", False, False),
+ (40.0, "green", True, False),
+ (60.0, "gray", True, True),
+ (80.0, "black", False, True),
]
for x, color, select, drag in markers:
@@ -763,14 +862,14 @@ class TestPlotMarker(PlotWidgetTestCase):
self.plot.resetZoom()
def testPlotMarkerY(self):
- self.plot.setGraphTitle('Markers Y')
+ self.plot.setGraphTitle("Markers Y")
markers = [
- (-50., 'blue', False, False),
- (-30., 'red', False, False),
- (0., 'green', True, False),
- (10., 'gray', True, True),
- (80., 'black', False, True),
+ (-50.0, "blue", False, False),
+ (-30.0, "red", False, False),
+ (0.0, "green", True, False),
+ (10.0, "gray", True, True),
+ (80.0, "black", False, True),
]
for y, color, select, drag in markers:
@@ -783,14 +882,14 @@ class TestPlotMarker(PlotWidgetTestCase):
self.plot.resetZoom()
def testPlotMarkerPt(self):
- self.plot.setGraphTitle('Markers Pt')
+ self.plot.setGraphTitle("Markers Pt")
markers = [
- (10., -50., 'blue', False, False),
- (40., -30., 'red', False, False),
- (50., 0., 'green', True, False),
- (50., 20., 'gray', True, True),
- (70., 50., 'black', False, True),
+ (10.0, -50.0, "blue", False, False),
+ (40.0, -30.0, "red", False, False),
+ (50.0, 0.0, "green", True, False),
+ (50.0, 20.0, "gray", True, True),
+ (70.0, 50.0, "black", False, True),
]
for x, y, color, select, drag in markers:
name = "{0},{1}".format(x, y)
@@ -803,52 +902,45 @@ class TestPlotMarker(PlotWidgetTestCase):
self.plot.resetZoom()
def testPlotMarkerWithoutLegend(self):
- self.plot.setGraphTitle('Markers without legend')
+ self.plot.setGraphTitle("Markers without legend")
self.plot.getYAxis().setInverted(True)
# Markers without legend
self.plot.addMarker(10, 10)
self.plot.addMarker(10, 20)
- self.plot.addMarker(40, 50, text='test', symbol=None)
- self.plot.addMarker(40, 50, text='test', symbol='+')
+ self.plot.addMarker(40, 50, text="test", symbol=None)
+ self.plot.addMarker(40, 50, text="test", symbol="+")
self.plot.addXMarker(25)
self.plot.addXMarker(35)
- self.plot.addXMarker(45, text='test')
+ self.plot.addXMarker(45, text="test")
self.plot.addYMarker(55)
self.plot.addYMarker(65)
- self.plot.addYMarker(75, text='test')
+ self.plot.addYMarker(75, text="test")
self.plot.resetZoom()
def testPlotMarkerYAxis(self):
# Check only the API
- legend = self.plot.addMarker(10, 10)
- item = self.plot._getMarker(legend)
+ item = self.plot.addMarker(10, 10)
self.assertEqual(item.getYAxis(), "left")
- legend = self.plot.addMarker(10, 10, yaxis="right")
- item = self.plot._getMarker(legend)
+ item = self.plot.addMarker(10, 10, yaxis="right")
self.assertEqual(item.getYAxis(), "right")
- legend = self.plot.addMarker(10, 10, yaxis="left")
- item = self.plot._getMarker(legend)
+ item = self.plot.addMarker(10, 10, yaxis="left")
self.assertEqual(item.getYAxis(), "left")
- legend = self.plot.addXMarker(10, yaxis="right")
- item = self.plot._getMarker(legend)
+ item = self.plot.addXMarker(10, yaxis="right")
self.assertEqual(item.getYAxis(), "right")
- legend = self.plot.addXMarker(10, yaxis="left")
- item = self.plot._getMarker(legend)
+ item = self.plot.addXMarker(10, yaxis="left")
self.assertEqual(item.getYAxis(), "left")
- legend = self.plot.addYMarker(10, yaxis="right")
- item = self.plot._getMarker(legend)
+ item = self.plot.addYMarker(10, yaxis="right")
self.assertEqual(item.getYAxis(), "right")
- legend = self.plot.addYMarker(10, yaxis="left")
- item = self.plot._getMarker(legend)
+ item = self.plot.addYMarker(10, yaxis="left")
self.assertEqual(item.getYAxis(), "left")
self.plot.resetZoom()
@@ -856,39 +948,72 @@ class TestPlotMarker(PlotWidgetTestCase):
# TestPlotItem ################################################################
+
class TestPlotItem(PlotWidgetTestCase):
"""Basic tests for addItem."""
# Polygon coordinates and color
POLYGONS = [ # legend, x coords, y coords, color
- ('triangle', numpy.array((10, 30, 50)),
- numpy.array((55, 70, 55)), 'red'),
- ('square', numpy.array((10, 10, 50, 50)),
- numpy.array((10, 50, 50, 10)), 'green'),
- ('star', numpy.array((60, 70, 80, 60, 80)),
- numpy.array((25, 50, 25, 40, 40)), 'blue'),
- ('2 triangles-simple',
- numpy.array((90., 95., 100., numpy.nan, 90., 95., 100.)),
- numpy.array((25., 5., 25., numpy.nan, 30., 50., 30.)),
- 'pink'),
- ('2 triangles-extra NaN',
- numpy.array((numpy.nan, 90., 95., 100., numpy.nan, 0., 90., 95., 100., numpy.nan)),
- numpy.array((0., 55., 70., 55., numpy.nan, numpy.nan, 75., 90., 75., numpy.nan)),
- 'black'),
+ ("triangle", numpy.array((10, 30, 50)), numpy.array((55, 70, 55)), "red"),
+ (
+ "square",
+ numpy.array((10, 10, 50, 50)),
+ numpy.array((10, 50, 50, 10)),
+ "green",
+ ),
+ (
+ "star",
+ numpy.array((60, 70, 80, 60, 80)),
+ numpy.array((25, 50, 25, 40, 40)),
+ "blue",
+ ),
+ (
+ "2 triangles-simple",
+ numpy.array((90.0, 95.0, 100.0, numpy.nan, 90.0, 95.0, 100.0)),
+ numpy.array((25.0, 5.0, 25.0, numpy.nan, 30.0, 50.0, 30.0)),
+ "pink",
+ ),
+ (
+ "2 triangles-extra NaN",
+ numpy.array(
+ (
+ numpy.nan,
+ 90.0,
+ 95.0,
+ 100.0,
+ numpy.nan,
+ 0.0,
+ 90.0,
+ 95.0,
+ 100.0,
+ numpy.nan,
+ )
+ ),
+ numpy.array(
+ (
+ 0.0,
+ 55.0,
+ 70.0,
+ 55.0,
+ numpy.nan,
+ numpy.nan,
+ 75.0,
+ 90.0,
+ 75.0,
+ numpy.nan,
+ )
+ ),
+ "black",
+ ),
]
# Rectangle coordinantes and color
RECTANGLES = [ # legend, x coords, y coords, color
- ('square 1', numpy.array((1., 10.)),
- numpy.array((1., 10.)), 'red'),
- ('square 2', numpy.array((10., 20.)),
- numpy.array((10., 20.)), 'green'),
- ('square 3', numpy.array((20., 30.)),
- numpy.array((20., 30.)), 'blue'),
- ('rect 1', numpy.array((1., 30.)),
- numpy.array((35., 40.)), 'black'),
- ('line h', numpy.array((1., 30.)),
- numpy.array((45., 45.)), 'darkRed'),
+ ("square 1", numpy.array((1.0, 10.0)), numpy.array((1.0, 10.0)), "red"),
+ ("square 2", numpy.array((10.0, 20.0)), numpy.array((10.0, 20.0)), "green"),
+ ("square 3", numpy.array((20.0, 30.0)), numpy.array((20.0, 30.0)), "blue"),
+ ("rect 1", numpy.array((1.0, 30.0)), numpy.array((35.0, 40.0)), "black"),
+ ("line h", numpy.array((1.0, 30.0)), numpy.array((45.0, 45.0)), "darkRed"),
]
SCALES = Axis.LINEAR, Axis.LOGARITHMIC
@@ -896,12 +1021,12 @@ class TestPlotItem(PlotWidgetTestCase):
def setUp(self):
super(TestPlotItem, self).setUp()
- self.plot.getYAxis().setLabel('Rows')
- self.plot.getXAxis().setLabel('Columns')
+ self.plot.getYAxis().setLabel("Rows")
+ self.plot.getXAxis().setLabel("Columns")
self.plot.getXAxis().setAutoScale(False)
self.plot.getYAxis().setAutoScale(False)
self.plot.setKeepDataAspectRatio(False)
- self.plot.setLimits(0., 100., -100., 100.)
+ self.plot.setLimits(0.0, 100.0, -100.0, 100.0)
def testPlotItemPolygonFill(self):
for scale in self.SCALES:
@@ -909,12 +1034,19 @@ class TestPlotItem(PlotWidgetTestCase):
self.plot.clear()
self.plot.getXAxis().setScale(scale)
self.plot.getYAxis().setScale(scale)
- self.plot.setGraphTitle('Item Fill %s' % scale)
+ self.plot.setGraphTitle("Item Fill %s" % scale)
for legend, xList, yList, color in self.POLYGONS:
- self.plot.addShape(xList, yList, legend=legend,
- replace=False, linestyle='--',
- shape="polygon", fill=True, color=color)
+ self.plot.addShape(
+ xList,
+ yList,
+ legend=legend,
+ replace=False,
+ linestyle="--",
+ shape="polygon",
+ fill=True,
+ color=color,
+ )
self.plot.resetZoom()
def testPlotItemPolygonNoFill(self):
@@ -923,12 +1055,19 @@ class TestPlotItem(PlotWidgetTestCase):
self.plot.clear()
self.plot.getXAxis().setScale(scale)
self.plot.getYAxis().setScale(scale)
- self.plot.setGraphTitle('Item No Fill %s' % scale)
+ self.plot.setGraphTitle("Item No Fill %s" % scale)
for legend, xList, yList, color in self.POLYGONS:
- self.plot.addShape(xList, yList, legend=legend,
- replace=False, linestyle='--',
- shape="polygon", fill=False, color=color)
+ self.plot.addShape(
+ xList,
+ yList,
+ legend=legend,
+ replace=False,
+ linestyle="--",
+ shape="polygon",
+ fill=False,
+ color=color,
+ )
self.plot.resetZoom()
def testPlotItemRectangleFill(self):
@@ -937,12 +1076,18 @@ class TestPlotItem(PlotWidgetTestCase):
self.plot.clear()
self.plot.getXAxis().setScale(scale)
self.plot.getYAxis().setScale(scale)
- self.plot.setGraphTitle('Rectangle Fill %s' % scale)
+ self.plot.setGraphTitle("Rectangle Fill %s" % scale)
for legend, xList, yList, color in self.RECTANGLES:
- self.plot.addShape(xList, yList, legend=legend,
- replace=False,
- shape="rectangle", fill=True, color=color)
+ self.plot.addShape(
+ xList,
+ yList,
+ legend=legend,
+ replace=False,
+ shape="rectangle",
+ fill=True,
+ color=color,
+ )
self.plot.resetZoom()
def testPlotItemRectangleNoFill(self):
@@ -951,230 +1096,44 @@ class TestPlotItem(PlotWidgetTestCase):
self.plot.clear()
self.plot.getXAxis().setScale(scale)
self.plot.getYAxis().setScale(scale)
- self.plot.setGraphTitle('Rectangle No Fill %s' % scale)
+ self.plot.setGraphTitle("Rectangle No Fill %s" % scale)
for legend, xList, yList, color in self.RECTANGLES:
- self.plot.addShape(xList, yList, legend=legend,
- replace=False,
- shape="rectangle", fill=False, color=color)
+ self.plot.addShape(
+ xList,
+ yList,
+ legend=legend,
+ replace=False,
+ shape="rectangle",
+ fill=False,
+ color=color,
+ )
self.plot.resetZoom()
-class TestPlotActiveCurveImage(PlotWidgetTestCase):
- """Basic tests for active curve and image handling"""
- xData = numpy.arange(1000)
- yData = -500 + 100 * numpy.sin(xData)
- xData2 = xData + 1000
- yData2 = xData - 1000 + 200 * numpy.random.random(1000)
-
- def tearDown(self):
- self.plot.setActiveCurveHandling(False)
- super(TestPlotActiveCurveImage, self).tearDown()
-
- def testActiveCurveAndLabels(self):
- # Active curve handling off, no label change
- self.plot.setActiveCurveHandling(False)
- self.plot.getXAxis().setLabel('XLabel')
- self.plot.getYAxis().setLabel('YLabel')
- self.plot.addCurve((1, 2), (1, 2))
- self.assertEqual(self.plot.getXAxis().getLabel(), 'XLabel')
- self.assertEqual(self.plot.getYAxis().getLabel(), 'YLabel')
-
- self.plot.addCurve((1, 2), (2, 3), xlabel='x1', ylabel='y1')
- self.assertEqual(self.plot.getXAxis().getLabel(), 'XLabel')
- self.assertEqual(self.plot.getYAxis().getLabel(), 'YLabel')
-
- self.plot.clear()
- self.assertEqual(self.plot.getXAxis().getLabel(), 'XLabel')
- self.assertEqual(self.plot.getYAxis().getLabel(), 'YLabel')
-
- # Active curve handling on, label changes
- self.plot.setActiveCurveHandling(True)
- self.plot.getXAxis().setLabel('XLabel')
- self.plot.getYAxis().setLabel('YLabel')
-
- # labels changed as active curve
- self.plot.addCurve((1, 2), (1, 2), legend='1',
- xlabel='x1', ylabel='y1')
- self.plot.setActiveCurve('1')
- self.assertEqual(self.plot.getXAxis().getLabel(), 'x1')
- self.assertEqual(self.plot.getYAxis().getLabel(), 'y1')
-
- # labels not changed as not active curve
- self.plot.addCurve((1, 2), (2, 3), legend='2')
- self.assertEqual(self.plot.getXAxis().getLabel(), 'x1')
- self.assertEqual(self.plot.getYAxis().getLabel(), 'y1')
-
- # labels changed
- self.plot.setActiveCurve('2')
- self.assertEqual(self.plot.getXAxis().getLabel(), 'XLabel')
- self.assertEqual(self.plot.getYAxis().getLabel(), 'YLabel')
-
- self.plot.setActiveCurve('1')
- self.assertEqual(self.plot.getXAxis().getLabel(), 'x1')
- self.assertEqual(self.plot.getYAxis().getLabel(), 'y1')
-
- self.plot.clear()
- self.assertEqual(self.plot.getXAxis().getLabel(), 'XLabel')
- self.assertEqual(self.plot.getYAxis().getLabel(), 'YLabel')
-
- def testPlotActiveCurveSelectionMode(self):
- self.plot.clear()
- self.plot.setActiveCurveHandling(True)
- legend = "curve 1"
- self.plot.addCurve(self.xData, self.yData,
- legend=legend,
- color="green")
-
- # active curve should be None
- self.assertEqual(self.plot.getActiveCurve(just_legend=True), None)
-
- # active curve should be None when None is set as active curve
- self.plot.setActiveCurve(legend)
- current = self.plot.getActiveCurve(just_legend=True)
- self.assertEqual(current, legend)
- self.plot.setActiveCurve(None)
- current = self.plot.getActiveCurve(just_legend=True)
- self.assertEqual(current, None)
-
- # testing it automatically toggles if there is only one
- self.plot.setActiveCurveSelectionMode("legacy")
- current = self.plot.getActiveCurve(just_legend=True)
- self.assertEqual(current, legend)
-
- # active curve should not change when None set as active curve
- self.assertEqual(self.plot.getActiveCurveSelectionMode(), "legacy")
- self.plot.setActiveCurve(None)
- current = self.plot.getActiveCurve(just_legend=True)
- self.assertEqual(current, legend)
-
- # situation where no curve is active
- self.plot.clear()
- self.plot.setActiveCurveHandling(True)
- self.assertEqual(self.plot.getActiveCurveSelectionMode(), "atmostone")
- self.plot.addCurve(self.xData, self.yData,
- legend=legend,
- color="green")
- self.assertEqual(self.plot.getActiveCurve(just_legend=True), None)
- self.plot.addCurve(self.xData2, self.yData2,
- legend="curve 2",
- color="red")
- self.assertEqual(self.plot.getActiveCurve(just_legend=True), None)
- self.plot.setActiveCurveSelectionMode("legacy")
- self.assertEqual(self.plot.getActiveCurve(just_legend=True), None)
-
- # the first curve added should be active
- self.plot.clear()
- self.plot.addCurve(self.xData, self.yData,
- legend=legend,
- color="green")
- self.assertEqual(self.plot.getActiveCurve(just_legend=True), legend)
- self.plot.addCurve(self.xData2, self.yData2,
- legend="curve 2",
- color="red")
- self.assertEqual(self.plot.getActiveCurve(just_legend=True), legend)
-
- def testActiveCurveStyle(self):
- """Test change of active curve style"""
- self.plot.setActiveCurveHandling(True)
- self.plot.setActiveCurveStyle(color='black')
- style = self.plot.getActiveCurveStyle()
- self.assertEqual(style.getColor(), (0., 0., 0., 1.))
- self.assertIsNone(style.getLineStyle())
- self.assertIsNone(style.getLineWidth())
- self.assertIsNone(style.getSymbol())
- self.assertIsNone(style.getSymbolSize())
-
- self.plot.addCurve(x=self.xData, y=self.yData, legend="curve1")
- curve = self.plot.getCurve("curve1")
- curve.setColor('blue')
- curve.setLineStyle('-')
- curve.setLineWidth(1)
- curve.setSymbol('o')
- curve.setSymbolSize(5)
-
- # Check default current style
- defaultStyle = curve.getCurrentStyle()
- self.assertEqual(defaultStyle, CurveStyle(color='blue',
- linestyle='-',
- linewidth=1,
- symbol='o',
- symbolsize=5))
-
- # Activate curve with highlight color=black
- self.plot.setActiveCurve("curve1")
- style = curve.getCurrentStyle()
- self.assertEqual(style.getColor(), (0., 0., 0., 1.))
- self.assertEqual(style.getLineStyle(), '-')
- self.assertEqual(style.getLineWidth(), 1)
- self.assertEqual(style.getSymbol(), 'o')
- self.assertEqual(style.getSymbolSize(), 5)
-
- # Change highlight to linewidth=2
- self.plot.setActiveCurveStyle(linewidth=2)
- style = curve.getCurrentStyle()
- self.assertEqual(style.getColor(), (0., 0., 1., 1.))
- self.assertEqual(style.getLineStyle(), '-')
- self.assertEqual(style.getLineWidth(), 2)
- self.assertEqual(style.getSymbol(), 'o')
- self.assertEqual(style.getSymbolSize(), 5)
-
- self.plot.setActiveCurve(None)
- self.assertEqual(curve.getCurrentStyle(), defaultStyle)
-
- def testActiveImageAndLabels(self):
- # Active image handling always on, no API for toggling it
- self.plot.getXAxis().setLabel('XLabel')
- self.plot.getYAxis().setLabel('YLabel')
-
- # labels changed as active curve
- self.plot.addImage(numpy.arange(100).reshape(10, 10),
- legend='1', xlabel='x1', ylabel='y1')
- self.assertEqual(self.plot.getXAxis().getLabel(), 'x1')
- self.assertEqual(self.plot.getYAxis().getLabel(), 'y1')
-
- # labels not changed as not active curve
- self.plot.addImage(numpy.arange(100).reshape(10, 10),
- legend='2')
- self.assertEqual(self.plot.getXAxis().getLabel(), 'x1')
- self.assertEqual(self.plot.getYAxis().getLabel(), 'y1')
-
- # labels changed
- self.plot.setActiveImage('2')
- self.assertEqual(self.plot.getXAxis().getLabel(), 'XLabel')
- self.assertEqual(self.plot.getYAxis().getLabel(), 'YLabel')
-
- self.plot.setActiveImage('1')
- self.assertEqual(self.plot.getXAxis().getLabel(), 'x1')
- self.assertEqual(self.plot.getYAxis().getLabel(), 'y1')
-
- self.plot.clear()
- self.assertEqual(self.plot.getXAxis().getLabel(), 'XLabel')
- self.assertEqual(self.plot.getYAxis().getLabel(), 'YLabel')
-
-
##############################################################################
# Log
##############################################################################
+
class TestPlotEmptyLog(PlotWidgetTestCase):
"""Basic tests for log plot"""
+
def testEmptyPlotTitleLabelsLog(self):
- self.plot.setGraphTitle('Empty Log Log')
- self.plot.getXAxis().setLabel('X')
- self.plot.getYAxis().setLabel('Y')
+ self.plot.setGraphTitle("Empty Log Log")
+ self.plot.getXAxis().setLabel("X")
+ self.plot.getYAxis().setLabel("Y")
self.plot.getXAxis()._setLogarithmic(True)
self.plot.getYAxis()._setLogarithmic(True)
self.plot.resetZoom()
class TestPlotAxes(TestCaseQt, ParametricTestCase):
-
# Test data
xData = numpy.arange(1, 10)
- yData = xData ** 2
+ yData = xData**2
- def __init__(self, methodName='runTest', backend=None):
+ def __init__(self, methodName="runTest", backend=None):
unittest.TestCase.__init__(self, methodName)
self.__backend = backend
@@ -1234,7 +1193,7 @@ class TestPlotAxes(TestCaseQt, ParametricTestCase):
with self.subTest():
if setter is not None:
if not isinstance(value, tuple):
- value = (value, )
+ value = (value,)
setter(*value)
if getter is not None:
self.assertEqual(getter(), expected)
@@ -1324,22 +1283,34 @@ class TestPlotAxes(TestCaseQt, ParametricTestCase):
self.assertEqual(self.plot.isYAxisInverted(), False)
def testLogXWithData(self):
- self.plot.setGraphTitle('Curve X: Log Y: Linear')
- self.plot.addCurve(self.xData, self.yData,
- legend="curve",
- replace=False, resetzoom=True,
- color='green', linestyle="-", symbol='o')
+ self.plot.setGraphTitle("Curve X: Log Y: Linear")
+ self.plot.addCurve(
+ self.xData,
+ self.yData,
+ legend="curve",
+ replace=False,
+ resetzoom=True,
+ color="green",
+ linestyle="-",
+ symbol="o",
+ )
axis = self.plot.getXAxis()
axis.setScale(axis.LOGARITHMIC)
self.assertEqual(axis.getScale(), axis.LOGARITHMIC)
def testLogYWithData(self):
- self.plot.setGraphTitle('Curve X: Linear Y: Log')
- self.plot.addCurve(self.xData, self.yData,
- legend="curve",
- replace=False, resetzoom=True,
- color='green', linestyle="-", symbol='o')
+ self.plot.setGraphTitle("Curve X: Linear Y: Log")
+ self.plot.addCurve(
+ self.xData,
+ self.yData,
+ legend="curve",
+ replace=False,
+ resetzoom=True,
+ color="green",
+ linestyle="-",
+ symbol="o",
+ )
axis = self.plot.getYAxis()
axis.setScale(axis.LOGARITHMIC)
@@ -1348,11 +1319,17 @@ class TestPlotAxes(TestCaseQt, ParametricTestCase):
self.assertEqual(axis.getScale(), axis.LOGARITHMIC)
def testLogYRightWithData(self):
- self.plot.setGraphTitle('Curve X: Linear Y: Log')
- self.plot.addCurve(self.xData, self.yData,
- legend="curve",
- replace=False, resetzoom=True,
- color='green', linestyle="-", symbol='o')
+ self.plot.setGraphTitle("Curve X: Linear Y: Log")
+ self.plot.addCurve(
+ self.xData,
+ self.yData,
+ legend="curve",
+ replace=False,
+ resetzoom=True,
+ color="green",
+ linestyle="-",
+ symbol="o",
+ )
axis = self.plot.getYAxis(axis="right")
axis.setScale(axis.LOGARITHMIC)
@@ -1361,36 +1338,58 @@ class TestPlotAxes(TestCaseQt, ParametricTestCase):
self.assertEqual(axis.getScale(), axis.LOGARITHMIC)
def testLimitsChanged_setLimits(self):
- self.plot.addCurve(self.xData, self.yData,
- legend="curve",
- replace=False, resetzoom=False,
- color='green', linestyle="-", symbol='o')
+ self.plot.addCurve(
+ self.xData,
+ self.yData,
+ legend="curve",
+ replace=False,
+ resetzoom=False,
+ color="green",
+ linestyle="-",
+ symbol="o",
+ )
listener = SignalListener()
self.plot.getXAxis().sigLimitsChanged.connect(listener.partial(axis="x"))
self.plot.getYAxis().sigLimitsChanged.connect(listener.partial(axis="y"))
- self.plot.getYAxis(axis="right").sigLimitsChanged.connect(listener.partial(axis="y2"))
+ self.plot.getYAxis(axis="right").sigLimitsChanged.connect(
+ listener.partial(axis="y2")
+ )
self.plot.setLimits(0, 1, 0, 1, 0, 1)
# at least one event per axis
self.assertEqual(len(set(listener.karguments(argumentName="axis"))), 3)
def testLimitsChanged_resetZoom(self):
- self.plot.addCurve(self.xData, self.yData,
- legend="curve",
- replace=False, resetzoom=False,
- color='green', linestyle="-", symbol='o')
+ self.plot.addCurve(
+ self.xData,
+ self.yData,
+ legend="curve",
+ replace=False,
+ resetzoom=False,
+ color="green",
+ linestyle="-",
+ symbol="o",
+ )
listener = SignalListener()
self.plot.getXAxis().sigLimitsChanged.connect(listener.partial(axis="x"))
self.plot.getYAxis().sigLimitsChanged.connect(listener.partial(axis="y"))
- self.plot.getYAxis(axis="right").sigLimitsChanged.connect(listener.partial(axis="y2"))
+ self.plot.getYAxis(axis="right").sigLimitsChanged.connect(
+ listener.partial(axis="y2")
+ )
self.plot.resetZoom()
# at least one event per axis
self.assertEqual(len(set(listener.karguments(argumentName="axis"))), 3)
def testLimitsChanged_setXLimit(self):
- self.plot.addCurve(self.xData, self.yData,
- legend="curve",
- replace=False, resetzoom=False,
- color='green', linestyle="-", symbol='o')
+ self.plot.addCurve(
+ self.xData,
+ self.yData,
+ legend="curve",
+ replace=False,
+ resetzoom=False,
+ color="green",
+ linestyle="-",
+ symbol="o",
+ )
listener = SignalListener()
axis = self.plot.getXAxis()
axis.sigLimitsChanged.connect(listener)
@@ -1400,10 +1399,16 @@ class TestPlotAxes(TestCaseQt, ParametricTestCase):
self.assertEqual(axis.getLimits(), (20.0, 30.0))
def testLimitsChanged_setYLimit(self):
- self.plot.addCurve(self.xData, self.yData,
- legend="curve",
- replace=False, resetzoom=False,
- color='green', linestyle="-", symbol='o')
+ self.plot.addCurve(
+ self.xData,
+ self.yData,
+ legend="curve",
+ replace=False,
+ resetzoom=False,
+ color="green",
+ linestyle="-",
+ symbol="o",
+ )
listener = SignalListener()
axis = self.plot.getYAxis()
axis.sigLimitsChanged.connect(listener)
@@ -1413,10 +1418,16 @@ class TestPlotAxes(TestCaseQt, ParametricTestCase):
self.assertEqual(axis.getLimits(), (20.0, 30.0))
def testLimitsChanged_setYRightLimit(self):
- self.plot.addCurve(self.xData, self.yData,
- legend="curve",
- replace=False, resetzoom=False,
- color='green', linestyle="-", symbol='o')
+ self.plot.addCurve(
+ self.xData,
+ self.yData,
+ legend="curve",
+ replace=False,
+ resetzoom=False,
+ color="green",
+ linestyle="-",
+ symbol="o",
+ )
listener = SignalListener()
axis = self.plot.getYAxis(axis="right")
axis.sigLimitsChanged.connect(listener)
@@ -1481,9 +1492,9 @@ class TestPlotAxes(TestCaseQt, ParametricTestCase):
self.qWaitForWindowExposed(self.plot)
margins = self.plot.getAxesMargins()
- self.assertEqual(margins, (.15, .1, .1, .15))
+ self.assertEqual(margins, (0.15, 0.1, 0.1, 0.15))
- for margins in ((0., 0., 0., 0.), (.15, .1, .1, .15)):
+ for margins in ((0.0, 0.0, 0.0, 0.0), (0.15, 0.1, 0.1, 0.15)):
with self.subTest(margins=margins):
self.plot.setAxesMargins(*margins)
self.qapp.processEvents()
@@ -1538,18 +1549,21 @@ class TestPlotAxes(TestCaseQt, ParametricTestCase):
def testAxisExtent(self):
"""Test XAxisExtent and yAxisExtent"""
- for cls, axis in ((XAxisExtent, self.plot.getXAxis()),
- (YAxisExtent, self.plot.getYAxis())):
- for range_, logRange in (((2, 3), (2, 3)),
- ((-2, -1), (1, 100)),
- ((-1, 3), (3. * 0.9, 3. * 1.1))):
+ for cls, axis in (
+ (XAxisExtent, self.plot.getXAxis()),
+ (YAxisExtent, self.plot.getYAxis()),
+ ):
+ for range_, logRange in (
+ ((2, 3), (2, 3)),
+ ((-2, -1), (1, 100)),
+ ((-1, 3), (3.0 * 0.9, 3.0 * 1.1)),
+ ):
extent = cls()
extent.setRange(*range_)
self.plot.addItem(extent)
for isLog, plotRange in ((False, range_), (True, logRange)):
- with self.subTest(
- cls=cls.__name__, range=range_, isLog=isLog):
+ with self.subTest(cls=cls.__name__, range=range_, isLog=isLog):
axis._setLogarithmic(isLog)
self.plot.resetZoom()
self.qapp.processEvents()
@@ -1564,9 +1578,7 @@ class TestPlotAxes(TestCaseQt, ParametricTestCase):
for scale in ("linear", "log"):
xaxis.setScale(scale)
yaxis.setScale(scale)
- for limits in ((1e300, 1e308),
- (-1e308, 1e308),
- (1e-300, 2e-300)):
+ for limits in ((1e300, 1e308), (-1e308, 1e308), (1e-300, 2e-300)):
with self.subTest(scale=scale, limits=limits):
xaxis.setLimits(*limits)
self.qapp.processEvents()
@@ -1581,44 +1593,62 @@ class TestPlotCurveLog(PlotWidgetTestCase, ParametricTestCase):
# Test data
xData = numpy.arange(1000) + 1
- yData = xData ** 2
+ yData = xData**2
def _setLabels(self):
- self.plot.getXAxis().setLabel('X')
- self.plot.getYAxis().setLabel('X * X')
+ self.plot.getXAxis().setLabel("X")
+ self.plot.getYAxis().setLabel("X * X")
def testPlotCurveLogX(self):
self._setLabels()
self.plot.getXAxis()._setLogarithmic(True)
- self.plot.setGraphTitle('Curve X: Log Y: Linear')
-
- self.plot.addCurve(self.xData, self.yData,
- legend="curve",
- replace=False, resetzoom=True,
- color='green', linestyle="-", symbol='o')
+ self.plot.setGraphTitle("Curve X: Log Y: Linear")
+
+ self.plot.addCurve(
+ self.xData,
+ self.yData,
+ legend="curve",
+ replace=False,
+ resetzoom=True,
+ color="green",
+ linestyle="-",
+ symbol="o",
+ )
def testPlotCurveLogY(self):
self._setLabels()
self.plot.getYAxis()._setLogarithmic(True)
- self.plot.setGraphTitle('Curve X: Linear Y: Log')
+ self.plot.setGraphTitle("Curve X: Linear Y: Log")
- self.plot.addCurve(self.xData, self.yData,
- legend="curve",
- replace=False, resetzoom=True,
- color='green', linestyle="-", symbol='o')
+ self.plot.addCurve(
+ self.xData,
+ self.yData,
+ legend="curve",
+ replace=False,
+ resetzoom=True,
+ color="green",
+ linestyle="-",
+ symbol="o",
+ )
def testPlotCurveLogXY(self):
self._setLabels()
self.plot.getXAxis()._setLogarithmic(True)
self.plot.getYAxis()._setLogarithmic(True)
- self.plot.setGraphTitle('Curve X: Log Y: Log')
+ self.plot.setGraphTitle("Curve X: Log Y: Log")
- self.plot.addCurve(self.xData, self.yData,
- legend="curve",
- replace=False, resetzoom=True,
- color='green', linestyle="-", symbol='o')
+ self.plot.addCurve(
+ self.xData,
+ self.yData,
+ legend="curve",
+ replace=False,
+ resetzoom=True,
+ color="green",
+ linestyle="-",
+ symbol="o",
+ )
def testPlotCurveErrorLogXY(self):
self.plot.getXAxis()._setLogarithmic(True)
@@ -1629,24 +1659,31 @@ class TestPlotCurveLog(PlotWidgetTestCase, ParametricTestCase):
errors[::2] = self.xData[::2] + 1
tests = [ # name, xerror, yerror
- ('xerror=3', 3, None),
- ('xerror=N array', errors, None),
- ('xerror=Nx1 array', errors.reshape(len(errors), 1), None),
- ('xerror=2xN array', numpy.array((errors, errors)), None),
- ('yerror=6', None, 6),
- ('yerror=N array', None, errors ** 2),
- ('yerror=Nx1 array', None, (errors ** 2).reshape(len(errors), 1)),
- ('yerror=2xN array', None, numpy.array((errors, errors)) ** 2),
+ ("xerror=3", 3, None),
+ ("xerror=N array", errors, None),
+ ("xerror=Nx1 array", errors.reshape(len(errors), 1), None),
+ ("xerror=2xN array", numpy.array((errors, errors)), None),
+ ("yerror=6", None, 6),
+ ("yerror=N array", None, errors**2),
+ ("yerror=Nx1 array", None, (errors**2).reshape(len(errors), 1)),
+ ("yerror=2xN array", None, numpy.array((errors, errors)) ** 2),
]
for name, xError, yError in tests:
with self.subTest(name):
self.plot.setGraphTitle(name)
- self.plot.addCurve(self.xData, self.yData,
- legend=name,
- xerror=xError, yerror=yError,
- replace=False, resetzoom=True,
- color='green', linestyle="-", symbol='o')
+ self.plot.addCurve(
+ self.xData,
+ self.yData,
+ legend=name,
+ xerror=xError,
+ yerror=yError,
+ replace=False,
+ resetzoom=True,
+ color="green",
+ linestyle="-",
+ symbol="o",
+ )
self.qapp.processEvents()
@@ -1678,12 +1715,12 @@ class TestPlotCurveLog(PlotWidgetTestCase, ParametricTestCase):
"""Add a curve with negative data and toggle log axis"""
arange = numpy.arange(1000) + 1
tests = [ # name, xData, yData
- ('x>0, some negative y', arange, arange - 500),
- ('x>0, y<0', arange, -arange),
- ('some negative x, y>0', arange - 500, arange),
- ('x<0, y>0', -arange, arange),
- ('some negative x and y', arange - 500, arange - 500),
- ('x<0, y<0', -arange, -arange),
+ ("x>0, some negative y", arange, arange - 500),
+ ("x>0, y<0", arange, -arange),
+ ("some negative x, y>0", arange - 500, arange),
+ ("x<0, y>0", -arange, arange),
+ ("some negative x and y", arange - 500, arange - 500),
+ ("x<0, y<0", -arange, -arange),
]
for name, xData, yData in tests:
@@ -1705,54 +1742,65 @@ class TestPlotCurveLog(PlotWidgetTestCase, ParametricTestCase):
yLim = self.plot.getYAxis().getLimits()
positives = xData > 0
if numpy.any(positives):
- self.assertTrue(numpy.allclose(
- xLim, (min(xData[positives]), max(xData[positives]))))
- self.assertEqual(
- yLim, (min(yData[positives]), max(yData[positives])))
+ self.assertTrue(
+ numpy.allclose(
+ xLim, (min(xData[positives]), max(xData[positives]))
+ )
+ )
else: # No positive x in the curve
- self.assertEqual(xLim, (1., 100.))
- self.assertEqual(yLim, (1., 100.))
+ self.assertEqual(xLim, (1.0, 100.0))
+ self.assertEqual(yLim, (min(yData), max(yData)))
# x axis and y axis log
+ previousXLim = self.plot.getXAxis().getLimits()
+ previousYLim = self.plot.getYAxis().getLimits()
self.plot.getYAxis()._setLogarithmic(True)
self.qapp.processEvents()
xLim = self.plot.getXAxis().getLimits()
yLim = self.plot.getYAxis().getLimits()
+
+ self.assertEqual(xLim, previousXLim)
positives = numpy.logical_and(xData > 0, yData > 0)
- if numpy.any(positives):
- self.assertTrue(numpy.allclose(
- xLim, (min(xData[positives]), max(xData[positives]))))
- self.assertTrue(numpy.allclose(
- yLim, (min(yData[positives]), max(yData[positives]))))
+ if previousYLim[0] > 0:
+ self.assertEqual(yLim, previousYLim)
+ elif numpy.any(positives):
+ expectedLimits = min(yData[positives]), max(yData[positives])
+ self.assertTrue(
+ numpy.allclose(yLim, expectedLimits),
+ f"{yLim} != {expectedLimits}",
+ )
else: # No positive x and y in the curve
- self.assertEqual(xLim, (1., 100.))
- self.assertEqual(yLim, (1., 100.))
+ self.assertEqual(yLim, (1.0, 100.0))
# y axis log
+ previousXLim = self.plot.getXAxis().getLimits()
self.plot.getXAxis()._setLogarithmic(False)
self.qapp.processEvents()
xLim = self.plot.getXAxis().getLimits()
yLim = self.plot.getYAxis().getLimits()
+ self.assertEqual(xLim, previousXLim)
positives = yData > 0
if numpy.any(positives):
- self.assertEqual(
- xLim, (min(xData[positives]), max(xData[positives])))
- self.assertTrue(numpy.allclose(
- yLim, (min(yData[positives]), max(yData[positives]))))
+ self.assertTrue(
+ numpy.allclose(
+ yLim, (min(yData[positives]), max(yData[positives]))
+ )
+ )
else: # No positive y in the curve
- self.assertEqual(xLim, (1., 100.))
- self.assertEqual(yLim, (1., 100.))
+ self.assertEqual(yLim, (1.0, 100.0))
# no log axis
+ previousXLim = self.plot.getXAxis().getLimits()
+ previousYLim = self.plot.getYAxis().getLimits()
self.plot.getYAxis()._setLogarithmic(False)
self.qapp.processEvents()
xLim = self.plot.getXAxis().getLimits()
- self.assertEqual(xLim, (min(xData), max(xData)))
+ self.assertEqual(xLim, previousXLim)
yLim = self.plot.getYAxis().getLimits()
- self.assertEqual(yLim, (min(yData), max(yData)))
+ self.assertEqual(yLim, previousYLim)
self.plot.clear()
self.plot.resetZoom()
@@ -1765,71 +1813,83 @@ class TestPlotImageLog(PlotWidgetTestCase):
def setUp(self):
super(TestPlotImageLog, self).setUp()
- self.plot.getXAxis().setLabel('Columns')
- self.plot.getYAxis().setLabel('Rows')
+ self.plot.getXAxis().setLabel("Columns")
+ self.plot.getYAxis().setLabel("Rows")
def testPlotColormapGrayLogX(self):
self.plot.getXAxis()._setLogarithmic(True)
- self.plot.setGraphTitle('CMap X: Log Y: Linear')
-
- colormap = Colormap(name='gray',
- normalization='linear',
- vmin=None,
- vmax=None)
- self.plot.addImage(DATA_2D, legend="image 1",
- origin=(1., 1.), scale=(1., 1.),
- resetzoom=False, colormap=colormap)
+ self.plot.setGraphTitle("CMap X: Log Y: Linear")
+
+ colormap = Colormap(name="gray", normalization="linear", vmin=None, vmax=None)
+ self.plot.addImage(
+ DATA_2D,
+ legend="image 1",
+ origin=(1.0, 1.0),
+ scale=(1.0, 1.0),
+ resetzoom=False,
+ colormap=colormap,
+ )
self.plot.resetZoom()
def testPlotColormapGrayLogY(self):
self.plot.getYAxis()._setLogarithmic(True)
- self.plot.setGraphTitle('CMap X: Linear Y: Log')
-
- colormap = Colormap(name='gray',
- normalization='linear',
- vmin=None,
- vmax=None)
- self.plot.addImage(DATA_2D, legend="image 1",
- origin=(1., 1.), scale=(1., 1.),
- resetzoom=False, colormap=colormap)
+ self.plot.setGraphTitle("CMap X: Linear Y: Log")
+
+ colormap = Colormap(name="gray", normalization="linear", vmin=None, vmax=None)
+ self.plot.addImage(
+ DATA_2D,
+ legend="image 1",
+ origin=(1.0, 1.0),
+ scale=(1.0, 1.0),
+ resetzoom=False,
+ colormap=colormap,
+ )
self.plot.resetZoom()
def testPlotColormapGrayLogXY(self):
self.plot.getXAxis()._setLogarithmic(True)
self.plot.getYAxis()._setLogarithmic(True)
- self.plot.setGraphTitle('CMap X: Log Y: Log')
-
- colormap = Colormap(name='gray',
- normalization='linear',
- vmin=None,
- vmax=None)
- self.plot.addImage(DATA_2D, legend="image 1",
- origin=(1., 1.), scale=(1., 1.),
- resetzoom=False, colormap=colormap)
+ self.plot.setGraphTitle("CMap X: Log Y: Log")
+
+ colormap = Colormap(name="gray", normalization="linear", vmin=None, vmax=None)
+ self.plot.addImage(
+ DATA_2D,
+ legend="image 1",
+ origin=(1.0, 1.0),
+ scale=(1.0, 1.0),
+ resetzoom=False,
+ colormap=colormap,
+ )
self.plot.resetZoom()
def testPlotRgbRgbaLogXY(self):
self.plot.getXAxis()._setLogarithmic(True)
self.plot.getYAxis()._setLogarithmic(True)
- self.plot.setGraphTitle('RGB + RGBA X: Log Y: Log')
+ self.plot.setGraphTitle("RGB + RGBA X: Log Y: Log")
rgb = numpy.array(
- (((0, 0, 0), (128, 0, 0), (255, 0, 0)),
- ((0, 128, 0), (0, 128, 128), (0, 128, 256))),
- dtype=numpy.uint8)
+ (
+ ((0, 0, 0), (128, 0, 0), (255, 0, 0)),
+ ((0, 128, 0), (0, 128, 128), (0, 128, 255)),
+ ),
+ dtype=numpy.uint8,
+ )
- self.plot.addImage(rgb, legend="rgb",
- origin=(1, 1), scale=(10, 10),
- resetzoom=False)
+ self.plot.addImage(
+ rgb, legend="rgb", origin=(1, 1), scale=(10, 10), resetzoom=False
+ )
rgba = numpy.array(
- (((0, 0, 0, .5), (.5, 0, 0, 1), (1, 0, 0, .5)),
- ((0, .5, 0, 1), (0, .5, .5, 1), (0, 1, 1, .5))),
- dtype=numpy.float32)
-
- self.plot.addImage(rgba, legend="rgba",
- origin=(5., 5.), scale=(10., 10.),
- resetzoom=False)
+ (
+ ((0, 0, 0, 0.5), (0.5, 0, 0, 1), (1, 0, 0, 0.5)),
+ ((0, 0.5, 0, 1), (0, 0.5, 0.5, 1), (0, 1, 1, 0.5)),
+ ),
+ dtype=numpy.float32,
+ )
+
+ self.plot.addImage(
+ rgba, legend="rgba", origin=(5.0, 5.0), scale=(10.0, 10.0), resetzoom=False
+ )
self.plot.resetZoom()
@@ -1838,27 +1898,27 @@ class TestPlotMarkerLog(PlotWidgetTestCase):
# Test marker parameters
markers = [ # x, y, color, selectable, draggable
- (10., 10., 'blue', False, False),
- (20., 20., 'red', False, False),
- (40., 100., 'green', True, False),
- (40., 500., 'gray', True, True),
- (60., 800., 'black', False, True),
+ (10.0, 10.0, "blue", False, False),
+ (20.0, 20.0, "red", False, False),
+ (40.0, 100.0, "green", True, False),
+ (40.0, 500.0, "gray", True, True),
+ (60.0, 800.0, "black", False, True),
]
def setUp(self):
super(TestPlotMarkerLog, self).setUp()
- self.plot.getYAxis().setLabel('Rows')
- self.plot.getXAxis().setLabel('Columns')
+ self.plot.getYAxis().setLabel("Rows")
+ self.plot.getXAxis().setLabel("Columns")
self.plot.getXAxis().setAutoScale(False)
self.plot.getYAxis().setAutoScale(False)
self.plot.setKeepDataAspectRatio(False)
- self.plot.setLimits(1., 100., 1., 1000.)
+ self.plot.setLimits(1.0, 100.0, 1.0, 1000.0)
self.plot.getXAxis()._setLogarithmic(True)
self.plot.getYAxis()._setLogarithmic(True)
def testPlotMarkerXLog(self):
- self.plot.setGraphTitle('Markers X, Log axes')
+ self.plot.setGraphTitle("Markers X, Log axes")
for x, _, color, select, drag in self.markers:
name = str(x)
@@ -1870,7 +1930,7 @@ class TestPlotMarkerLog(PlotWidgetTestCase):
self.plot.resetZoom()
def testPlotMarkerYLog(self):
- self.plot.setGraphTitle('Markers Y, Log axes')
+ self.plot.setGraphTitle("Markers Y, Log axes")
for _, y, color, select, drag in self.markers:
name = str(y)
@@ -1882,7 +1942,7 @@ class TestPlotMarkerLog(PlotWidgetTestCase):
self.plot.resetZoom()
def testPlotMarkerPtLog(self):
- self.plot.setGraphTitle('Markers Pt, Log axes')
+ self.plot.setGraphTitle("Markers Pt, Log axes")
for x, y, color, select, drag in self.markers:
name = "{0},{1}".format(x, y)
@@ -1901,9 +1961,9 @@ class TestPlotWidgetSwitchBackend(PlotWidgetTestCase):
@pytest.mark.usefixtures("test_options")
def testSwitchBackend(self):
"""Test switching a plot with a few items"""
- backends = {'none': 'BackendBase', 'mpl': 'BackendMatplotlibQt'}
+ backends = {"none": "BackendBase", "mpl": "BackendMatplotlibQt"}
if self.test_options.WITH_GL_TEST:
- backends['gl'] = 'BackendOpenGL'
+ backends["gl"] = "BackendOpenGL"
self.plot.addImage(numpy.arange(100).reshape(10, 10))
self.plot.addCurve((-3, -2, -1), (1, 2, 3))
@@ -1925,208 +1985,65 @@ class TestPlotWidgetSwitchBackend(PlotWidgetTestCase):
self.assertEqual(self.plot.getItems(), items)
-class TestPlotWidgetSelection(PlotWidgetTestCase):
- """Test PlotWidget.selection and active items handling"""
-
- def _checkSelection(self, selection, current=None, selected=()):
- """Check current item and selected items."""
- self.assertIs(selection.getCurrentItem(), current)
- self.assertEqual(selection.getSelectedItems(), selected)
-
- def testSyncWithActiveItems(self):
- """Test update of PlotWidgetSelection according to active items"""
- listener = SignalListener()
-
- selection = self.plot.selection()
- selection.sigCurrentItemChanged.connect(listener)
- self._checkSelection(selection)
-
- # Active item is current
- self.plot.addImage(((0, 1), (2, 3)), legend='image')
- image = self.plot.getActiveImage()
- self.assertEqual(listener.callCount(), 1)
- self._checkSelection(selection, image, (image,))
-
- # No active = no current
- self.plot.setActiveImage(None)
- self.assertEqual(listener.callCount(), 2)
- self._checkSelection(selection)
-
- # Active item is current
- self.plot.setActiveImage('image')
- self.assertEqual(listener.callCount(), 3)
- self._checkSelection(selection, image, (image,))
-
- # Mosted recently "actived" item is current
- self.plot.addScatter((3, 2, 1), (0, 1, 2), (0, 1, 2), legend='scatter')
- scatter = self.plot.getActiveScatter()
- self.assertEqual(listener.callCount(), 4)
- self._checkSelection(selection, scatter, (scatter, image))
-
- # Previously mosted recently "actived" item is current
- self.plot.setActiveScatter(None)
- self.assertEqual(listener.callCount(), 5)
- self._checkSelection(selection, image, (image,))
-
- # Mosted recently "actived" item is current
- self.plot.setActiveScatter('scatter')
- self.assertEqual(listener.callCount(), 6)
- self._checkSelection(selection, scatter, (scatter, image))
-
- # No active = no current
- self.plot.setActiveImage(None)
- self.plot.setActiveScatter(None)
- self.assertEqual(listener.callCount(), 7)
- self._checkSelection(selection)
-
- # Mosted recently "actived" item is current
- self.plot.setActiveScatter('scatter')
- self.assertEqual(listener.callCount(), 8)
- self.plot.setActiveImage('image')
- self.assertEqual(listener.callCount(), 9)
- self._checkSelection(selection, image, (image, scatter))
-
- # Add a curve which is not active by default
- self.plot.addCurve((0, 1, 2), (0, 1, 2), legend='curve')
- curve = self.plot.getCurve('curve')
- self.assertEqual(listener.callCount(), 9)
- self._checkSelection(selection, image, (image, scatter))
-
- # Mosted recently "actived" item is current
- self.plot.setActiveCurve('curve')
- self.assertEqual(listener.callCount(), 10)
- self._checkSelection(selection, curve, (curve, image, scatter))
-
- # Add a curve which is not active by default
- self.plot.addCurve((0, 1, 2), (0, 1, 2), legend='curve2')
- curve2 = self.plot.getCurve('curve2')
- self.assertEqual(listener.callCount(), 10)
- self._checkSelection(selection, curve, (curve, image, scatter))
-
- # Mosted recently "actived" item is current, previous curve is removed
- self.plot.setActiveCurve('curve2')
- self.assertEqual(listener.callCount(), 11)
- self._checkSelection(selection, curve2, (curve2, image, scatter))
-
- # No items = no current
- self.plot.clear()
- self.assertEqual(listener.callCount(), 12)
- self._checkSelection(selection)
-
- def testPlotWidgetWithItems(self):
- """Test init of selection on a plot with items"""
- self.plot.addImage(((0, 1), (2, 3)), legend='image')
- self.plot.addScatter((3, 2, 1), (0, 1, 2), (0, 1, 2), legend='scatter')
- self.plot.addCurve((0, 1, 2), (0, 1, 2), legend='curve')
- self.plot.setActiveCurve('curve')
-
- selection = self.plot.selection()
- self.assertIsNotNone(selection.getCurrentItem())
- selected = selection.getSelectedItems()
- self.assertEqual(len(selected), 3)
- self.assertIn(self.plot.getActiveCurve(), selected)
- self.assertIn(self.plot.getActiveImage(), selected)
- self.assertIn(self.plot.getActiveScatter(), selected)
-
- def testSetCurrentItem(self):
- """Test setCurrentItem"""
- # Add items to the plot
- self.plot.addImage(((0, 1), (2, 3)), legend='image')
- image = self.plot.getActiveImage()
- self.plot.addScatter((3, 2, 1), (0, 1, 2), (0, 1, 2), legend='scatter')
- scatter = self.plot.getActiveScatter()
- self.plot.addCurve((0, 1, 2), (0, 1, 2), legend='curve')
- self.plot.setActiveCurve('curve')
- curve = self.plot.getActiveCurve()
-
- selection = self.plot.selection()
- self.assertIsNotNone(selection.getCurrentItem())
- self.assertEqual(len(selection.getSelectedItems()), 3)
-
- # Set current to None reset all active items
- selection.setCurrentItem(None)
- self._checkSelection(selection)
- self.assertIsNone(self.plot.getActiveCurve())
- self.assertIsNone(self.plot.getActiveImage())
- self.assertIsNone(self.plot.getActiveScatter())
-
- # Set current to an item makes it active
- selection.setCurrentItem(image)
- self._checkSelection(selection, image, (image,))
- self.assertIsNone(self.plot.getActiveCurve())
- self.assertIs(self.plot.getActiveImage(), image)
- self.assertIsNone(self.plot.getActiveScatter())
-
- # Set current to an item makes it active and keeps other active
- selection.setCurrentItem(curve)
- self._checkSelection(selection, curve, (curve, image))
- self.assertIs(self.plot.getActiveCurve(), curve)
- self.assertIs(self.plot.getActiveImage(), image)
- self.assertIsNone(self.plot.getActiveScatter())
-
- # Set current to an item makes it active and keeps other active
- selection.setCurrentItem(scatter)
- self._checkSelection(selection, scatter, (scatter, curve, image))
- self.assertIs(self.plot.getActiveCurve(), curve)
- self.assertIs(self.plot.getActiveImage(), image)
- self.assertIs(self.plot.getActiveScatter(), scatter)
-
-
@pytest.mark.usefixtures("use_opengl")
class TestPlotWidget_Gl(TestPlotWidget):
- backend="gl"
+ backend = "gl"
+
@pytest.mark.usefixtures("use_opengl")
class TestPlotImage_Gl(TestPlotImage):
- backend="gl"
+ backend = "gl"
+
@pytest.mark.usefixtures("use_opengl")
class TestPlotCurve_Gl(TestPlotCurve):
- backend="gl"
+ backend = "gl"
+
@pytest.mark.usefixtures("use_opengl")
class TestPlotHistogram_Gl(TestPlotHistogram):
- backend="gl"
+ backend = "gl"
+
@pytest.mark.usefixtures("use_opengl")
class TestPlotScatter_Gl(TestPlotScatter):
- backend="gl"
+ backend = "gl"
+
@pytest.mark.usefixtures("use_opengl")
class TestPlotMarker_Gl(TestPlotMarker):
- backend="gl"
+ backend = "gl"
+
@pytest.mark.usefixtures("use_opengl")
class TestPlotItem_Gl(TestPlotItem):
- backend="gl"
+ backend = "gl"
+
@pytest.mark.usefixtures("use_opengl")
class TestPlotAxes_Gl(TestPlotAxes):
- backend="gl"
+ backend = "gl"
-@pytest.mark.usefixtures("use_opengl")
-class TestPlotActiveCurveImage_Gl(TestPlotActiveCurveImage):
- backend="gl"
@pytest.mark.usefixtures("use_opengl")
class TestPlotEmptyLog_Gl(TestPlotEmptyLog):
- backend="gl"
+ backend = "gl"
+
@pytest.mark.usefixtures("use_opengl")
class TestPlotCurveLog_Gl(TestPlotCurveLog):
- backend="gl"
+ backend = "gl"
+
@pytest.mark.usefixtures("use_opengl")
class TestPlotImageLog_Gl(TestPlotImageLog):
- backend="gl"
+ backend = "gl"
+
@pytest.mark.usefixtures("use_opengl")
class TestPlotMarkerLog_Gl(TestPlotMarkerLog):
- backend="gl"
+ backend = "gl"
-@pytest.mark.usefixtures("use_opengl")
-class TestPlotWidgetSelection_Gl(TestPlotWidgetSelection):
- backend="gl"
class TestSpecial_ExplicitMplBackend(TestSpecialBackend):
- backend="mpl"
+ backend = "mpl"
diff --git a/src/silx/gui/plot/test/testPlotWidgetActiveItem.py b/src/silx/gui/plot/test/testPlotWidgetActiveItem.py
new file mode 100755
index 0000000..99285a8
--- /dev/null
+++ b/src/silx/gui/plot/test/testPlotWidgetActiveItem.py
@@ -0,0 +1,416 @@
+# /*##########################################################################
+#
+# Copyright (c) 2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test PlotWidget active item"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "11/12/2023"
+
+
+import numpy
+import pytest
+
+from silx.gui.utils.testutils import SignalListener
+from silx.gui.plot.items.curve import CurveStyle
+
+
+@pytest.mark.parametrize("plotWidget", ("mpl", "gl"), indirect=True)
+def testActiveCurveAndLabels(plotWidget):
+ # Active curve handling off, no label change
+ plotWidget.setActiveCurveHandling(False)
+ plotWidget.getXAxis().setLabel("XLabel")
+ plotWidget.getYAxis().setLabel("YLabel")
+ plotWidget.addCurve((1, 2), (1, 2))
+ assert plotWidget.getXAxis().getLabel() == "XLabel"
+ assert plotWidget.getYAxis().getLabel() == "YLabel"
+
+ plotWidget.addCurve((1, 2), (2, 3), xlabel="x1", ylabel="y1")
+ assert plotWidget.getXAxis().getLabel() == "XLabel"
+ assert plotWidget.getYAxis().getLabel() == "YLabel"
+
+ plotWidget.clear()
+ assert plotWidget.getXAxis().getLabel() == "XLabel"
+ assert plotWidget.getYAxis().getLabel() == "YLabel"
+
+ # Active curve handling on, label changes
+ plotWidget.setActiveCurveHandling(True)
+ plotWidget.getXAxis().setLabel("XLabel")
+ plotWidget.getYAxis().setLabel("YLabel")
+
+ # labels changed as active curve
+ plotWidget.addCurve((1, 2), (1, 2), legend="1", xlabel="x1", ylabel="y1")
+ plotWidget.setActiveCurve("1")
+ assert plotWidget.getXAxis().getLabel() == "x1"
+ assert plotWidget.getYAxis().getLabel() == "y1"
+
+ # labels not changed as not active curve
+ plotWidget.addCurve((1, 2), (2, 3), legend="2")
+ assert plotWidget.getXAxis().getLabel() == "x1"
+ assert plotWidget.getYAxis().getLabel() == "y1"
+
+ # labels changed
+ plotWidget.setActiveCurve("2")
+ assert plotWidget.getXAxis().getLabel() == "XLabel"
+ assert plotWidget.getYAxis().getLabel() == "YLabel"
+
+ plotWidget.setActiveCurve("1")
+ assert plotWidget.getXAxis().getLabel() == "x1"
+ assert plotWidget.getYAxis().getLabel() == "y1"
+
+ plotWidget.clear()
+ assert plotWidget.getXAxis().getLabel() == "XLabel"
+ assert plotWidget.getYAxis().getLabel() == "YLabel"
+
+ plotWidget.setActiveCurveHandling(False)
+
+
+@pytest.mark.parametrize("plotWidget", ("mpl", "gl"), indirect=True)
+def testPlotActiveCurveSelectionMode(plotWidget):
+ xData = numpy.arange(1000)
+ yData = -500 + 100 * numpy.sin(xData)
+ xData2 = xData + 1000
+ yData2 = xData - 1000 + 200 * numpy.random.random(1000)
+
+ plotWidget.clear()
+ plotWidget.setActiveCurveHandling(True)
+ legend = "curve 1"
+ plotWidget.addCurve(xData, yData, legend=legend, color="green")
+
+ # active curve should be None
+ assert plotWidget.getActiveCurve(just_legend=True) is None
+
+ # active curve should be None when None is set as active curve
+ plotWidget.setActiveCurve(legend)
+ current = plotWidget.getActiveCurve(just_legend=True)
+ assert current == legend
+ plotWidget.setActiveCurve(None)
+ current = plotWidget.getActiveCurve(just_legend=True)
+ assert current is None
+
+ # testing it automatically toggles if there is only one
+ plotWidget.setActiveCurveSelectionMode("legacy")
+ current = plotWidget.getActiveCurve(just_legend=True)
+ assert current == legend
+
+ # active curve should not change when None set as active curve
+ assert plotWidget.getActiveCurveSelectionMode() == "legacy"
+ plotWidget.setActiveCurve(None)
+ current = plotWidget.getActiveCurve(just_legend=True)
+ assert current == legend
+
+ # situation where no curve is active
+ plotWidget.clear()
+ plotWidget.setActiveCurveHandling(True)
+ assert plotWidget.getActiveCurveSelectionMode() == "atmostone"
+ plotWidget.addCurve(xData, yData, legend=legend, color="green")
+ assert plotWidget.getActiveCurve(just_legend=True) is None
+ plotWidget.addCurve(xData2, yData2, legend="curve 2", color="red")
+ assert plotWidget.getActiveCurve(just_legend=True) is None
+ plotWidget.setActiveCurveSelectionMode("legacy")
+ assert plotWidget.getActiveCurve(just_legend=True) is None
+
+ # the first curve added should be active
+ plotWidget.clear()
+ plotWidget.addCurve(xData, yData, legend=legend, color="green")
+ assert plotWidget.getActiveCurve(just_legend=True) == legend
+ plotWidget.addCurve(xData2, yData2, legend="curve 2", color="red")
+ assert plotWidget.getActiveCurve(just_legend=True) == legend
+
+ plotWidget.setActiveCurveHandling(False)
+
+
+@pytest.mark.parametrize("plotWidget", ("mpl", "gl"), indirect=True)
+def testActiveCurveStyle(plotWidget):
+ """Test change of active curve style"""
+ plotWidget.setActiveCurveHandling(True)
+ plotWidget.setActiveCurveStyle(color="black")
+ style = plotWidget.getActiveCurveStyle()
+ assert style.getColor() == (0.0, 0.0, 0.0, 1.0)
+ assert style.getLineStyle() is None
+ assert style.getLineWidth() is None
+ assert style.getSymbol() is None
+ assert style.getSymbolSize() is None
+
+ xData = numpy.arange(1000)
+ yData = -500 + 100 * numpy.sin(xData)
+ plotWidget.addCurve(x=xData, y=yData, legend="curve1")
+ curve = plotWidget.getCurve("curve1")
+ curve.setColor("blue")
+ curve.setLineStyle("-")
+ curve.setLineWidth(1)
+ curve.setSymbol("o")
+ curve.setSymbolSize(5)
+
+ # Check default current style
+ defaultStyle = curve.getCurrentStyle()
+ assert defaultStyle == CurveStyle(
+ color="blue", linestyle="-", linewidth=1, symbol="o", symbolsize=5
+ )
+
+ # Activate curve with highlight color=black
+ plotWidget.setActiveCurve("curve1")
+ style = curve.getCurrentStyle()
+ assert style.getColor() == (0.0, 0.0, 0.0, 1.0)
+ assert style.getLineStyle() == "-"
+ assert style.getLineWidth() == 1
+ assert style.getSymbol() == "o"
+ assert style.getSymbolSize() == 5
+
+ # Change highlight to linewidth=2
+ plotWidget.setActiveCurveStyle(linewidth=2)
+ style = curve.getCurrentStyle()
+ assert style.getColor() == (0.0, 0.0, 1.0, 1.0)
+ assert style.getLineStyle() == "-"
+ assert style.getLineWidth() == 2
+ assert style.getSymbol() == "o"
+ assert style.getSymbolSize() == 5
+
+ plotWidget.setActiveCurve(None)
+ assert curve.getCurrentStyle() == defaultStyle
+
+ plotWidget.setActiveCurveHandling(False)
+
+
+@pytest.mark.parametrize("plotWidget", ("mpl", "gl"), indirect=True)
+def testActiveImageAndLabels(plotWidget):
+ # Active image handling always on, no API for toggling it
+ plotWidget.getXAxis().setLabel("XLabel")
+ plotWidget.getYAxis().setLabel("YLabel")
+
+ # labels changed as active curve
+ plotWidget.addImage(
+ numpy.arange(100).reshape(10, 10), legend="1", xlabel="x1", ylabel="y1"
+ )
+ assert plotWidget.getXAxis().getLabel() == "x1"
+ assert plotWidget.getYAxis().getLabel() == "y1"
+
+ # labels not changed as not active curve
+ plotWidget.addImage(numpy.arange(100).reshape(10, 10), legend="2")
+ assert plotWidget.getXAxis().getLabel() == "x1"
+ assert plotWidget.getYAxis().getLabel() == "y1"
+
+ # labels changed
+ plotWidget.setActiveImage("2")
+ assert plotWidget.getXAxis().getLabel() == "XLabel"
+ assert plotWidget.getYAxis().getLabel() == "YLabel"
+
+ plotWidget.setActiveImage("1")
+ assert plotWidget.getXAxis().getLabel() == "x1"
+ assert plotWidget.getYAxis().getLabel() == "y1"
+
+ plotWidget.clear()
+ assert plotWidget.getXAxis().getLabel() == "XLabel"
+ assert plotWidget.getYAxis().getLabel() == "YLabel"
+
+ plotWidget.setActiveCurveHandling(False)
+
+
+def _checkSelection(selection, current=None, selected=()):
+ """Check current item and selected items."""
+ assert selection.getCurrentItem() is current
+ assert selection.getSelectedItems() == selected
+
+
+@pytest.mark.parametrize("plotWidget", ("mpl", "gl"), indirect=True)
+def testSelectionSyncWithActiveItems(plotWidget):
+ """Test update of PlotWidgetSelection according to active items"""
+ listener = SignalListener()
+
+ selection = plotWidget.selection()
+ selection.sigCurrentItemChanged.connect(listener)
+ _checkSelection(selection)
+
+ # Active item is current
+ plotWidget.addImage(((0, 1), (2, 3)), legend="image")
+ image = plotWidget.getActiveImage()
+ assert listener.callCount() == 1
+ _checkSelection(selection, image, (image,))
+
+ # No active = no current
+ plotWidget.setActiveImage(None)
+ assert listener.callCount() == 2
+ _checkSelection(selection)
+
+ # Active item is current
+ plotWidget.setActiveImage("image")
+ assert listener.callCount() == 3
+ _checkSelection(selection, image, (image,))
+
+ # Mosted recently "actived" item is current
+ plotWidget.addScatter((3, 2, 1), (0, 1, 2), (0, 1, 2), legend="scatter")
+ scatter = plotWidget.getActiveScatter()
+ assert listener.callCount() == 4
+ _checkSelection(selection, scatter, (scatter, image))
+
+ # Previously mosted recently "actived" item is current
+ plotWidget.setActiveScatter(None)
+ assert listener.callCount() == 5
+ _checkSelection(selection, image, (image,))
+
+ # Mosted recently "actived" item is current
+ plotWidget.setActiveScatter("scatter")
+ assert listener.callCount() == 6
+ _checkSelection(selection, scatter, (scatter, image))
+
+ # No active = no current
+ plotWidget.setActiveImage(None)
+ plotWidget.setActiveScatter(None)
+ assert listener.callCount() == 7
+ _checkSelection(selection)
+
+ # Mosted recently "actived" item is current
+ plotWidget.setActiveScatter("scatter")
+ assert listener.callCount() == 8
+ plotWidget.setActiveImage("image")
+ assert listener.callCount() == 9
+ _checkSelection(selection, image, (image, scatter))
+
+ # Add a curve which is not active by default
+ plotWidget.addCurve((0, 1, 2), (0, 1, 2), legend="curve")
+ curve = plotWidget.getCurve("curve")
+ assert listener.callCount() == 9
+ _checkSelection(selection, image, (image, scatter))
+
+ # Mosted recently "actived" item is current
+ plotWidget.setActiveCurve("curve")
+ assert listener.callCount() == 10
+ _checkSelection(selection, curve, (curve, image, scatter))
+
+ # Add a curve which is not active by default
+ plotWidget.addCurve((0, 1, 2), (0, 1, 2), legend="curve2")
+ curve2 = plotWidget.getCurve("curve2")
+ assert listener.callCount() == 10
+ _checkSelection(selection, curve, (curve, image, scatter))
+
+ # Mosted recently "actived" item is current, previous curve is removed
+ plotWidget.setActiveCurve("curve2")
+ assert listener.callCount() == 11
+ _checkSelection(selection, curve2, (curve2, image, scatter))
+
+ # No items = no current
+ plotWidget.clear()
+ assert listener.callCount() == 12
+ _checkSelection(selection)
+
+
+@pytest.mark.parametrize("plotWidget", ("mpl", "gl"), indirect=True)
+def testSelectionWithItems(plotWidget):
+ """Test init of selection on a plot with items"""
+ plotWidget.addImage(((0, 1), (2, 3)), legend="image")
+ plotWidget.addScatter((3, 2, 1), (0, 1, 2), (0, 1, 2), legend="scatter")
+ plotWidget.addCurve((0, 1, 2), (0, 1, 2), legend="curve")
+ plotWidget.setActiveCurve("curve")
+
+ selection = plotWidget.selection()
+ assert selection.getCurrentItem() is not None
+ selected = selection.getSelectedItems()
+ assert len(selected) == 3
+ assert plotWidget.getActiveCurve() in selected
+ assert plotWidget.getActiveImage() in selected
+ assert plotWidget.getActiveScatter() in selected
+
+
+@pytest.mark.parametrize("plotWidget", ("mpl", "gl"), indirect=True)
+def testSelectionSetCurrentItem(plotWidget):
+ """Test setCurrentItem"""
+ # Add items to the plot
+ plotWidget.addImage(((0, 1), (2, 3)), legend="image")
+ image = plotWidget.getActiveImage()
+ plotWidget.addScatter((3, 2, 1), (0, 1, 2), (0, 1, 2), legend="scatter")
+ scatter = plotWidget.getActiveScatter()
+ plotWidget.addCurve((0, 1, 2), (0, 1, 2), legend="curve")
+ plotWidget.setActiveCurve("curve")
+ curve = plotWidget.getActiveCurve()
+
+ selection = plotWidget.selection()
+ assert selection.getCurrentItem() is not None
+ assert len(selection.getSelectedItems()) == 3
+
+ # Set current to None reset all active items
+ selection.setCurrentItem(None)
+ _checkSelection(selection)
+ assert plotWidget.getActiveCurve() is None
+ assert plotWidget.getActiveImage() is None
+ assert plotWidget.getActiveScatter() is None
+
+ # Set current to an item makes it active
+ selection.setCurrentItem(image)
+ _checkSelection(selection, image, (image,))
+ assert plotWidget.getActiveCurve() is None
+ assert plotWidget.getActiveImage() is image
+ assert plotWidget.getActiveScatter() is None
+
+ # Set current to an item makes it active and keeps other active
+ selection.setCurrentItem(curve)
+ _checkSelection(selection, curve, (curve, image))
+ assert plotWidget.getActiveCurve() is curve
+ assert plotWidget.getActiveImage() is image
+ assert plotWidget.getActiveScatter() is None
+
+ # Set current to an item makes it active and keeps other active
+ selection.setCurrentItem(scatter)
+ _checkSelection(selection, scatter, (scatter, curve, image))
+ assert plotWidget.getActiveCurve() is curve
+ assert plotWidget.getActiveImage() is image
+ assert plotWidget.getActiveScatter() is scatter
+
+
+def testSetActiveCurveWithInstance(plotWidget):
+ """Test setting the active curve with a curve item instance"""
+ plotWidget.addCurve((0, 1), (0, 1), legend="curve0")
+ plotWidget.addCurve((0, 1), (1, 0), legend="curve1")
+ curve0, curve1 = plotWidget.getItems()
+
+ plotWidget.setActiveCurve(curve0)
+ assert plotWidget.getActiveCurve() is curve0
+
+ plotWidget.setActiveCurve(curve1)
+ assert plotWidget.getActiveCurve() is curve1
+
+ plotWidget.setActiveCurve(None)
+ assert plotWidget.getActiveCurve() is None
+
+
+def testSetActiveImageWithInstance(plotWidget):
+ """Test setting the active image with an image item instance"""
+ plotWidget.addImage(((0, 1), (2, 3)), legend="image")
+ image = plotWidget.getItems()[0]
+
+ plotWidget.setActiveImage(None)
+ assert plotWidget.getActiveImage() is None
+
+ plotWidget.setActiveImage(image)
+ assert plotWidget.getActiveImage() is image
+
+
+def testSetActiveScatterWithInstance(plotWidget):
+ """Test setting the active scatter with a scatter item instance"""
+ plotWidget.addScatter((0, 1), (0, 1), (0, 1), legend="scatter")
+ scatter = plotWidget.getItems()[0]
+
+ plotWidget.setActiveScatter(None)
+ assert plotWidget.getActiveScatter() is None
+
+ plotWidget.setActiveScatter(scatter)
+ assert plotWidget.getActiveScatter() is scatter
diff --git a/src/silx/gui/plot/test/testPlotWidgetDataMargins.py b/src/silx/gui/plot/test/testPlotWidgetDataMargins.py
new file mode 100644
index 0000000..4eb5134
--- /dev/null
+++ b/src/silx/gui/plot/test/testPlotWidgetDataMargins.py
@@ -0,0 +1,135 @@
+# /*##########################################################################
+#
+# Copyright (c) 2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test PlotWidget features related to data margins"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "11/05/2023"
+
+import numpy
+import pytest
+
+
+def testDefaultDataMargins(plotWidget):
+ """Test default PlotWidget data margins: No margins"""
+ assert plotWidget.getDataMargins() == (0, 0, 0, 0)
+
+
+def testResetZoomDataMarginsLinearAxes(qapp, plotWidget):
+ """Test PlotWidget.setDataMargins effect on resetZoom with linear axis scales"""
+
+ margins = 0.1, 0.2, 0.3, 0.4
+ plotWidget.setDataMargins(*margins)
+
+ plotWidget.resetZoom()
+ qapp.processEvents()
+
+ retrievedMargins = plotWidget.getDataMargins()
+ assert retrievedMargins == margins
+
+ dataRange = 100 - 1
+ expectedXLimits = 1 - 0.1 * dataRange, 100 + 0.2 * dataRange
+ expectedYLimits = 1 - 0.3 * dataRange, 100 + 0.4 * dataRange
+
+ assert plotWidget.getXAxis().getLimits() == expectedXLimits
+ assert plotWidget.getYAxis().getLimits() == expectedYLimits
+ assert plotWidget.getYAxis(axis="right").getLimits() == expectedYLimits
+
+
+def testResetZoomDataMarginsLogAxes(qapp, plotWidget):
+ """Test PlotWidget.setDataMargins effect on resetZoom with log axis scales"""
+ plotWidget.getXAxis().setScale("log")
+ plotWidget.getYAxis().setScale("log")
+
+ dataMargins = 0.1, 0.2, 0.3, 0.4
+ plotWidget.setDataMargins(*dataMargins)
+
+ plotWidget.resetZoom()
+ qapp.processEvents()
+
+ retrievedMargins = plotWidget.getDataMargins()
+ assert retrievedMargins == dataMargins
+
+ logMin, logMax = numpy.log10(1), numpy.log10(100)
+ logRange = logMax - logMin
+ expectedXLimits = pow(10.0, logMin - 0.1 * logRange), pow(
+ 10.0, logMax + 0.2 * logRange
+ )
+ expectedYLimits = pow(10.0, logMin - 0.3 * logRange), pow(
+ 10.0, logMax + 0.4 * logRange
+ )
+
+ assert plotWidget.getXAxis().getLimits() == expectedXLimits
+ assert plotWidget.getYAxis().getLimits() == expectedYLimits
+ assert plotWidget.getYAxis(axis="right").getLimits() == expectedYLimits
+
+
+@pytest.mark.parametrize("margins", [False, True, (0, 0, 0, 0)])
+def testSetLimitsNoDataMargins(plotWidget, margins):
+ """Test PlotWidget.setLimits without data margins"""
+ xlimits = 1, 2
+ ylimits = 3, 4
+ y2limits = 5, 6
+ plotWidget.setLimits(*xlimits, *ylimits, *y2limits, margins=margins)
+
+ assert plotWidget.getXAxis().getLimits() == xlimits
+ assert plotWidget.getYAxis().getLimits() == ylimits
+ assert plotWidget.getYAxis(axis="right").getLimits() == y2limits
+
+
+@pytest.mark.parametrize(
+ "margins,expectedLimits",
+ [
+ # margins=False: use limits as is
+ (
+ False,
+ (1, 2, 3, 4, 5, 6),
+ ),
+ # margins=True: apply data margins
+ (
+ True,
+ (1 - 0.1, 2 + 0.2, 3 - 0.3, 4 + 0.4, 5 - 0.3, 6 + 0.4),
+ ),
+ # margins=tuple: apply provided margins
+ (
+ (0.4, 0.3, 0.2, 0.1),
+ (1 - 0.4, 2 + 0.3, 3 - 0.2, 4 + 0.1, 5 - 0.2, 6 + 0.1),
+ ),
+ ],
+)
+def testSetLimitsWithDataMargins(qapp, plotWidget, margins, expectedLimits):
+ """Test PlotWidget.setLimits with data margins"""
+ dataMargins = 0.1, 0.2, 0.3, 0.4
+ limits = 1, 2, 3, 4, 5, 6
+
+ plotWidget.setDataMargins(*dataMargins)
+ plotWidget.setLimits(*limits, margins=margins)
+ qapp.processEvents()
+
+ retrievedLimits = (
+ *plotWidget.getXAxis().getLimits(),
+ *plotWidget.getYAxis().getLimits(),
+ *plotWidget.getYAxis(axis="right").getLimits(),
+ )
+ assert retrievedLimits == expectedLimits
diff --git a/src/silx/gui/plot/test/testPlotWidgetNoBackend.py b/src/silx/gui/plot/test/testPlotWidgetNoBackend.py
index 787d5a8..d9d5706 100644
--- a/src/silx/gui/plot/test/testPlotWidgetNoBackend.py
+++ b/src/silx/gui/plot/test/testPlotWidgetNoBackend.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -34,6 +34,8 @@ from silx.utils.testutils import ParametricTestCase
import numpy
+import silx
+from silx.gui.colors import rgba
from silx.gui.plot.PlotWidget import PlotWidget
from silx.gui.plot.items.histogram import _getHistogramCurve, _computeEdges
@@ -44,9 +46,9 @@ class TestPlot(unittest.TestCase):
def testPlotTitleLabels(self):
"""Create a Plot and set the labels"""
- plot = PlotWidget(backend='none')
+ plot = PlotWidget(backend="none")
- title, xlabel, ylabel = 'the title', 'x label', 'y label'
+ title, xlabel, ylabel = "the title", "x label", "y label"
plot.setGraphTitle(title)
plot.getXAxis().setLabel(xlabel)
plot.getYAxis().setLabel(ylabel)
@@ -58,26 +60,29 @@ class TestPlot(unittest.TestCase):
def testAddNoRemove(self):
"""add objects to the Plot"""
- plot = PlotWidget(backend='none')
+ plot = PlotWidget(backend="none")
plot.addCurve(x=(1, 2, 3), y=(3, 2, 1))
- plot.addImage(numpy.arange(100.).reshape(10, -1))
- plot.addShape(numpy.array((1., 10.)),
- numpy.array((10., 10.)),
- shape="rectangle")
- plot.addXMarker(10.)
+ plot.addImage(numpy.arange(100.0).reshape(10, -1))
+ plot.addShape(
+ numpy.array((1.0, 10.0)), numpy.array((10.0, 10.0)), shape="rectangle"
+ )
+ plot.addXMarker(10.0)
class TestPlotRanges(ParametricTestCase):
"""Basic tests of Plot data ranges without backend"""
- _getValidValues = {True: lambda ar: ar > 0,
- False: lambda ar: numpy.ones(shape=ar.shape,
- dtype=bool)}
+ _getValidValues = {
+ True: lambda ar: ar > 0,
+ False: lambda ar: numpy.ones(shape=ar.shape, dtype=bool),
+ }
@staticmethod
def _getRanges(arrays, are_logs):
- gen = (TestPlotRanges._getValidValues[is_log](ar)
- for (ar, is_log) in zip(arrays, are_logs))
+ gen = (
+ TestPlotRanges._getValidValues[is_log](ar)
+ for (ar, is_log) in zip(arrays, are_logs)
+ )
indices = numpy.where(reduce(numpy.logical_and, gen))[0]
if len(indices) > 0:
ranges = [(ar[indices[0]], ar[indices[-1]]) for ar in arrays]
@@ -96,13 +101,15 @@ class TestPlotRanges(ParametricTestCase):
def testDataRangeNoPlot(self):
"""empty plot data range"""
- plot = PlotWidget(backend='none')
+ plot = PlotWidget(backend="none")
- for logX, logY in ((False, False),
- (True, False),
- (True, True),
- (False, True),
- (False, False)):
+ for logX, logY in (
+ (False, False),
+ (True, False),
+ (True, True),
+ (False, True),
+ (False, False),
+ ):
with self.subTest(logX=logX, logY=logY):
plot.getXAxis()._setLogarithmic(logX)
plot.getYAxis()._setLogarithmic(logY)
@@ -114,27 +121,25 @@ class TestPlotRanges(ParametricTestCase):
def testDataRangeLeft(self):
"""left axis range"""
- plot = PlotWidget(backend='none')
+ plot = PlotWidget(backend="none")
xData = numpy.arange(10) - 4.9 # range : -4.9 , 4.1
yData = numpy.arange(10) - 6.9 # range : -6.9 , 2.1
- plot.addCurve(x=xData,
- y=yData,
- legend='plot_0',
- yaxis='left')
+ plot.addCurve(x=xData, y=yData, legend="plot_0", yaxis="left")
- for logX, logY in ((False, False),
- (True, False),
- (True, True),
- (False, True),
- (False, False)):
+ for logX, logY in (
+ (False, False),
+ (True, False),
+ (True, True),
+ (False, True),
+ (False, False),
+ ):
with self.subTest(logX=logX, logY=logY):
plot.getXAxis()._setLogarithmic(logX)
plot.getYAxis()._setLogarithmic(logY)
dataRange = plot.getDataRange()
- xRange, yRange = self._getRanges([xData, yData],
- [logX, logY])
+ xRange, yRange = self._getRanges([xData, yData], [logX, logY])
self.assertSequenceEqual(dataRange.x, xRange)
self.assertSequenceEqual(dataRange.y, yRange)
self.assertIsNone(dataRange.yright)
@@ -142,25 +147,23 @@ class TestPlotRanges(ParametricTestCase):
def testDataRangeRight(self):
"""right axis range"""
- plot = PlotWidget(backend='none')
+ plot = PlotWidget(backend="none")
xData = numpy.arange(10) - 4.9 # range : -4.9 , 4.1
yData = numpy.arange(10) - 6.9 # range : -6.9 , 2.1
- plot.addCurve(x=xData,
- y=yData,
- legend='plot_0',
- yaxis='right')
-
- for logX, logY in ((False, False),
- (True, False),
- (True, True),
- (False, True),
- (False, False)):
+ plot.addCurve(x=xData, y=yData, legend="plot_0", yaxis="right")
+
+ for logX, logY in (
+ (False, False),
+ (True, False),
+ (True, True),
+ (False, True),
+ (False, False),
+ ):
with self.subTest(logX=logX, logY=logY):
plot.getXAxis()._setLogarithmic(logX)
plot.getYAxis()._setLogarithmic(logY)
dataRange = plot.getDataRange()
- xRange, yRange = self._getRanges([xData, yData],
- [logX, logY])
+ xRange, yRange = self._getRanges([xData, yData], [logX, logY])
self.assertSequenceEqual(dataRange.x, xRange)
self.assertIsNone(dataRange.y)
self.assertSequenceEqual(dataRange.yright, yRange)
@@ -169,69 +172,70 @@ class TestPlotRanges(ParametricTestCase):
"""image data range"""
origin = (-10, 25)
- scale = (3., 8.)
- image = numpy.arange(100.).reshape(20, 5)
-
- plot = PlotWidget(backend='none')
- plot.addImage(image,
- origin=origin, scale=scale)
-
- xRange = numpy.array([0., image.shape[1] * scale[0]]) + origin[0]
- yRange = numpy.array([0., image.shape[0] * scale[1]]) + origin[1]
-
- ranges = {(False, False): (xRange, yRange),
- (True, False): (None, None),
- (True, True): (None, None),
- (False, True): (None, None)}
-
- for logX, logY in ((False, False),
- (True, False),
- (True, True),
- (False, True),
- (False, False)):
+ scale = (3.0, 8.0)
+ image = numpy.arange(100.0).reshape(20, 5)
+
+ plot = PlotWidget(backend="none")
+ plot.addImage(image, origin=origin, scale=scale)
+
+ xRange = numpy.array([0.0, image.shape[1] * scale[0]]) + origin[0]
+ yRange = numpy.array([0.0, image.shape[0] * scale[1]]) + origin[1]
+
+ ranges = {
+ (False, False): (xRange, yRange),
+ (True, False): (None, None),
+ (True, True): (None, None),
+ (False, True): (None, None),
+ }
+
+ for logX, logY in (
+ (False, False),
+ (True, False),
+ (True, True),
+ (False, True),
+ (False, False),
+ ):
with self.subTest(logX=logX, logY=logY):
plot.getXAxis()._setLogarithmic(logX)
plot.getYAxis()._setLogarithmic(logY)
dataRange = plot.getDataRange()
xRange, yRange = ranges[logX, logY]
- self.assertTrue(numpy.array_equal(dataRange.x, xRange),
- msg='{0} != {1}'.format(dataRange.x, xRange))
- self.assertTrue(numpy.array_equal(dataRange.y, yRange),
- msg='{0} != {1}'.format(dataRange.y, yRange))
+ self.assertTrue(
+ numpy.array_equal(dataRange.x, xRange),
+ msg="{0} != {1}".format(dataRange.x, xRange),
+ )
+ self.assertTrue(
+ numpy.array_equal(dataRange.y, yRange),
+ msg="{0} != {1}".format(dataRange.y, yRange),
+ )
self.assertIsNone(dataRange.yright)
def testDataRangeLeftRight(self):
"""right+left axis range"""
- plot = PlotWidget(backend='none')
+ plot = PlotWidget(backend="none")
xData_l = numpy.arange(10) - 0.9 # range : -0.9 , 8.1
yData_l = numpy.arange(10) - 1.9 # range : -1.9 , 7.1
- plot.addCurve(x=xData_l,
- y=yData_l,
- legend='plot_l',
- yaxis='left')
+ plot.addCurve(x=xData_l, y=yData_l, legend="plot_l", yaxis="left")
xData_r = numpy.arange(10) - 4.9 # range : -4.9 , 4.1
yData_r = numpy.arange(10) - 6.9 # range : -6.9 , 2.1
- plot.addCurve(x=xData_r,
- y=yData_r,
- legend='plot_r',
- yaxis='right')
-
- for logX, logY in ((False, False),
- (True, False),
- (True, True),
- (False, True),
- (False, False)):
+ plot.addCurve(x=xData_r, y=yData_r, legend="plot_r", yaxis="right")
+
+ for logX, logY in (
+ (False, False),
+ (True, False),
+ (True, True),
+ (False, True),
+ (False, False),
+ ):
with self.subTest(logX=logX, logY=logY):
plot.getXAxis()._setLogarithmic(logX)
plot.getYAxis()._setLogarithmic(logY)
dataRange = plot.getDataRange()
- xRangeL, yRangeL = self._getRanges([xData_l, yData_l],
- [logX, logY])
- xRangeR, yRangeR = self._getRanges([xData_r, yData_r],
- [logX, logY])
+ xRangeL, yRangeL = self._getRanges([xData_l, yData_l], [logX, logY])
+ xRangeR, yRangeR = self._getRanges([xData_r, yData_r], [logX, logY])
xRangeLR = self._getRangesMinmax([xRangeL, xRangeR])
self.assertSequenceEqual(dataRange.x, xRangeLR)
self.assertSequenceEqual(dataRange.y, yRangeL)
@@ -244,51 +248,42 @@ class TestPlotRanges(ParametricTestCase):
# image sets x min and y max
# plot_left sets y min
# plot_right sets x max (and yright)
- plot = PlotWidget(backend='none')
+ plot = PlotWidget(backend="none")
origin = (-10, 5)
- scale = (3., 8.)
- image = numpy.arange(100.).reshape(20, 5)
+ scale = (3.0, 8.0)
+ image = numpy.arange(100.0).reshape(20, 5)
- plot.addImage(image,
- origin=origin, scale=scale, legend='image')
+ plot.addImage(image, origin=origin, scale=scale, legend="image")
xData_l = numpy.arange(10) - 0.9 # range : -0.9 , 8.1
yData_l = numpy.arange(10) - 1.9 # range : -1.9 , 7.1
- plot.addCurve(x=xData_l,
- y=yData_l,
- legend='plot_l',
- yaxis='left')
+ plot.addCurve(x=xData_l, y=yData_l, legend="plot_l", yaxis="left")
xData_r = numpy.arange(10) + 4.1 # range : 4.1 , 13.1
yData_r = numpy.arange(10) - 0.9 # range : -0.9 , 8.1
- plot.addCurve(x=xData_r,
- y=yData_r,
- legend='plot_r',
- yaxis='right')
-
- imgXRange = numpy.array([0., image.shape[1] * scale[0]]) + origin[0]
- imgYRange = numpy.array([0., image.shape[0] * scale[1]]) + origin[1]
-
- for logX, logY in ((False, False),
- (True, False),
- (True, True),
- (False, True),
- (False, False)):
+ plot.addCurve(x=xData_r, y=yData_r, legend="plot_r", yaxis="right")
+
+ imgXRange = numpy.array([0.0, image.shape[1] * scale[0]]) + origin[0]
+ imgYRange = numpy.array([0.0, image.shape[0] * scale[1]]) + origin[1]
+
+ for logX, logY in (
+ (False, False),
+ (True, False),
+ (True, True),
+ (False, True),
+ (False, False),
+ ):
with self.subTest(logX=logX, logY=logY):
plot.getXAxis()._setLogarithmic(logX)
plot.getYAxis()._setLogarithmic(logY)
dataRange = plot.getDataRange()
- xRangeL, yRangeL = self._getRanges([xData_l, yData_l],
- [logX, logY])
- xRangeR, yRangeR = self._getRanges([xData_r, yData_r],
- [logX, logY])
+ xRangeL, yRangeL = self._getRanges([xData_l, yData_l], [logX, logY])
+ xRangeR, yRangeR = self._getRanges([xData_r, yData_r], [logX, logY])
if logX or logY:
xRangeLR = self._getRangesMinmax([xRangeL, xRangeR])
else:
- xRangeLR = self._getRangesMinmax([xRangeL,
- xRangeR,
- imgXRange])
+ xRangeLR = self._getRangesMinmax([xRangeL, xRangeR, imgXRange])
yRangeL = self._getRangesMinmax([yRangeL, imgYRange])
self.assertSequenceEqual(dataRange.x, xRangeLR)
self.assertSequenceEqual(dataRange.y, yRangeL)
@@ -298,83 +293,97 @@ class TestPlotRanges(ParametricTestCase):
"""image data range, negative scale"""
origin = (-10, 25)
- scale = (-3., 8.)
- image = numpy.arange(100.).reshape(20, 5)
+ scale = (-3.0, 8.0)
+ image = numpy.arange(100.0).reshape(20, 5)
- plot = PlotWidget(backend='none')
- plot.addImage(image,
- origin=origin, scale=scale)
+ plot = PlotWidget(backend="none")
+ plot.addImage(image, origin=origin, scale=scale)
- xRange = numpy.array([0., image.shape[1] * scale[0]]) + origin[0]
+ xRange = numpy.array([0.0, image.shape[1] * scale[0]]) + origin[0]
xRange.sort() # negative scale!
- yRange = numpy.array([0., image.shape[0] * scale[1]]) + origin[1]
-
- ranges = {(False, False): (xRange, yRange),
- (True, False): (None, None),
- (True, True): (None, None),
- (False, True): (None, None)}
-
- for logX, logY in ((False, False),
- (True, False),
- (True, True),
- (False, True),
- (False, False)):
+ yRange = numpy.array([0.0, image.shape[0] * scale[1]]) + origin[1]
+
+ ranges = {
+ (False, False): (xRange, yRange),
+ (True, False): (None, None),
+ (True, True): (None, None),
+ (False, True): (None, None),
+ }
+
+ for logX, logY in (
+ (False, False),
+ (True, False),
+ (True, True),
+ (False, True),
+ (False, False),
+ ):
with self.subTest(logX=logX, logY=logY):
plot.getXAxis()._setLogarithmic(logX)
plot.getYAxis()._setLogarithmic(logY)
dataRange = plot.getDataRange()
xRange, yRange = ranges[logX, logY]
- self.assertTrue(numpy.array_equal(dataRange.x, xRange),
- msg='{0} != {1}'.format(dataRange.x, xRange))
- self.assertTrue(numpy.array_equal(dataRange.y, yRange),
- msg='{0} != {1}'.format(dataRange.y, yRange))
+ self.assertTrue(
+ numpy.array_equal(dataRange.x, xRange),
+ msg="{0} != {1}".format(dataRange.x, xRange),
+ )
+ self.assertTrue(
+ numpy.array_equal(dataRange.y, yRange),
+ msg="{0} != {1}".format(dataRange.y, yRange),
+ )
self.assertIsNone(dataRange.yright)
def testDataRangeImageNegativeScaleY(self):
"""image data range, negative scale"""
origin = (-10, 25)
- scale = (3., -8.)
- image = numpy.arange(100.).reshape(20, 5)
+ scale = (3.0, -8.0)
+ image = numpy.arange(100.0).reshape(20, 5)
- plot = PlotWidget(backend='none')
- plot.addImage(image,
- origin=origin, scale=scale)
+ plot = PlotWidget(backend="none")
+ plot.addImage(image, origin=origin, scale=scale)
- xRange = numpy.array([0., image.shape[1] * scale[0]]) + origin[0]
- yRange = numpy.array([0., image.shape[0] * scale[1]]) + origin[1]
+ xRange = numpy.array([0.0, image.shape[1] * scale[0]]) + origin[0]
+ yRange = numpy.array([0.0, image.shape[0] * scale[1]]) + origin[1]
yRange.sort() # negative scale!
- ranges = {(False, False): (xRange, yRange),
- (True, False): (None, None),
- (True, True): (None, None),
- (False, True): (None, None)}
-
- for logX, logY in ((False, False),
- (True, False),
- (True, True),
- (False, True),
- (False, False)):
+ ranges = {
+ (False, False): (xRange, yRange),
+ (True, False): (None, None),
+ (True, True): (None, None),
+ (False, True): (None, None),
+ }
+
+ for logX, logY in (
+ (False, False),
+ (True, False),
+ (True, True),
+ (False, True),
+ (False, False),
+ ):
with self.subTest(logX=logX, logY=logY):
plot.getXAxis()._setLogarithmic(logX)
plot.getYAxis()._setLogarithmic(logY)
dataRange = plot.getDataRange()
xRange, yRange = ranges[logX, logY]
- self.assertTrue(numpy.array_equal(dataRange.x, xRange),
- msg='{0} != {1}'.format(dataRange.x, xRange))
- self.assertTrue(numpy.array_equal(dataRange.y, yRange),
- msg='{0} != {1}'.format(dataRange.y, yRange))
+ self.assertTrue(
+ numpy.array_equal(dataRange.x, xRange),
+ msg="{0} != {1}".format(dataRange.x, xRange),
+ )
+ self.assertTrue(
+ numpy.array_equal(dataRange.y, yRange),
+ msg="{0} != {1}".format(dataRange.y, yRange),
+ )
self.assertIsNone(dataRange.yright)
def testDataRangeHiddenCurve(self):
"""curves with a hidden curve"""
- plot = PlotWidget(backend='none')
- plot.addCurve((0, 1), (0, 1), legend='shown')
- plot.addCurve((0, 1, 2), (5, 5, 5), legend='hidden')
+ plot = PlotWidget(backend="none")
+ plot.addCurve((0, 1), (0, 1), legend="shown")
+ plot.addCurve((0, 1, 2), (5, 5, 5), legend="hidden")
range1 = plot.getDataRange()
self.assertEqual(range1.x, (0, 2))
self.assertEqual(range1.y, (0, 5))
- plot.hideCurve('hidden')
+ plot.hideCurve("hidden")
range2 = plot.getDataRange()
self.assertEqual(range2.x, (0, 1))
self.assertEqual(range2.y, (0, 1))
@@ -386,108 +395,108 @@ class TestPlotGetCurveImage(unittest.TestCase):
def testGetCurve(self):
"""PlotWidget.getCurve and Plot.getActiveCurve tests"""
- plot = PlotWidget(backend='none')
+ plot = PlotWidget(backend="none")
# No curve
curve = plot.getCurve()
self.assertIsNone(curve) # No curve
plot.setActiveCurveHandling(True)
- plot.addCurve(x=(0, 1), y=(0, 1), legend='curve 0')
- plot.addCurve(x=(0, 1), y=(0, 1), legend='curve 1')
- plot.addCurve(x=(0, 1), y=(0, 1), legend='curve 2')
- plot.setActiveCurve('curve 0')
+ plot.addCurve(x=(0, 1), y=(0, 1), legend="curve 0")
+ plot.addCurve(x=(0, 1), y=(0, 1), legend="curve 1")
+ plot.addCurve(x=(0, 1), y=(0, 1), legend="curve 2")
+ plot.setActiveCurve("curve 0")
# Active curve
active = plot.getActiveCurve()
- self.assertEqual(active.getName(), 'curve 0')
+ self.assertEqual(active.getName(), "curve 0")
curve = plot.getCurve()
- self.assertEqual(curve.getName(), 'curve 0')
+ self.assertEqual(curve.getName(), "curve 0")
# No active curve and curves
plot.setActiveCurveHandling(False)
active = plot.getActiveCurve()
self.assertIsNone(active) # No active curve
curve = plot.getCurve()
- self.assertEqual(curve.getName(), 'curve 2') # Last added curve
+ self.assertEqual(curve.getName(), "curve 2") # Last added curve
# Last curve hidden
- plot.hideCurve('curve 2', True)
+ plot.hideCurve("curve 2", True)
curve = plot.getCurve()
- self.assertEqual(curve.getName(), 'curve 1') # Last added curve
+ self.assertEqual(curve.getName(), "curve 1") # Last added curve
# All curves hidden
- plot.hideCurve('curve 1', True)
- plot.hideCurve('curve 0', True)
+ plot.hideCurve("curve 1", True)
+ plot.hideCurve("curve 0", True)
curve = plot.getCurve()
self.assertIsNone(curve)
def testGetCurveOldApi(self):
"""old API PlotWidget.getCurve and Plot.getActiveCurve tests"""
- plot = PlotWidget(backend='none')
+ plot = PlotWidget(backend="none")
# No curve
curve = plot.getCurve()
self.assertIsNone(curve) # No curve
plot.setActiveCurveHandling(True)
- x = numpy.arange(10.).astype(numpy.float32)
+ x = numpy.arange(10.0).astype(numpy.float32)
y = x * x
- plot.addCurve(x=x, y=y, legend='curve 0', info=["whatever"])
- plot.addCurve(x=x, y=2*x, legend='curve 1', info="anything")
- plot.setActiveCurve('curve 0')
+ plot.addCurve(x=x, y=y, legend="curve 0", info=["whatever"])
+ plot.addCurve(x=x, y=2 * x, legend="curve 1", info="anything")
+ plot.setActiveCurve("curve 0")
# Active curve (4 elements)
xOut, yOut, legend, info = plot.getActiveCurve()[:4]
- self.assertEqual(legend, 'curve 0')
- self.assertTrue(numpy.allclose(xOut, x), 'curve 0 wrong x data')
- self.assertTrue(numpy.allclose(yOut, y), 'curve 0 wrong y data')
+ self.assertEqual(legend, "curve 0")
+ self.assertTrue(numpy.allclose(xOut, x), "curve 0 wrong x data")
+ self.assertTrue(numpy.allclose(yOut, y), "curve 0 wrong y data")
# Active curve (5 elements)
xOut, yOut, legend, info, params = plot.getCurve("curve 1")
- self.assertEqual(legend, 'curve 1')
- self.assertEqual(info, 'anything')
- self.assertTrue(numpy.allclose(xOut, x), 'curve 1 wrong x data')
- self.assertTrue(numpy.allclose(yOut, 2 * x), 'curve 1 wrong y data')
+ self.assertEqual(legend, "curve 1")
+ self.assertEqual(info, "anything")
+ self.assertTrue(numpy.allclose(xOut, x), "curve 1 wrong x data")
+ self.assertTrue(numpy.allclose(yOut, 2 * x), "curve 1 wrong y data")
def testGetImage(self):
"""PlotWidget.getImage and PlotWidget.getActiveImage tests"""
- plot = PlotWidget(backend='none')
+ plot = PlotWidget(backend="none")
# No image
image = plot.getImage()
self.assertIsNone(image)
- plot.addImage(((0, 1), (2, 3)), legend='image 0')
- plot.addImage(((0, 1), (2, 3)), legend='image 1')
+ plot.addImage(((0, 1), (2, 3)), legend="image 0")
+ plot.addImage(((0, 1), (2, 3)), legend="image 1")
# Active image
active = plot.getActiveImage()
- self.assertEqual(active.getName(), 'image 0')
+ self.assertEqual(active.getName(), "image 0")
image = plot.getImage()
- self.assertEqual(image.getName(), 'image 0')
+ self.assertEqual(image.getName(), "image 0")
# No active image
- plot.addImage(((0, 1), (2, 3)), legend='image 2')
+ plot.addImage(((0, 1), (2, 3)), legend="image 2")
plot.setActiveImage(None)
active = plot.getActiveImage()
self.assertIsNone(active)
image = plot.getImage()
- self.assertEqual(image.getName(), 'image 2')
+ self.assertEqual(image.getName(), "image 2")
# Active image
- plot.setActiveImage('image 1')
+ plot.setActiveImage("image 1")
active = plot.getActiveImage()
- self.assertEqual(active.getName(), 'image 1')
+ self.assertEqual(active.getName(), "image 1")
image = plot.getImage()
- self.assertEqual(image.getName(), 'image 1')
+ self.assertEqual(image.getName(), "image 1")
def testGetImageOldApi(self):
"""PlotWidget.getImage and PlotWidget.getActiveImage old API tests"""
- plot = PlotWidget(backend='none')
+ plot = PlotWidget(backend="none")
# No image
image = plot.getImage()
@@ -496,18 +505,18 @@ class TestPlotGetCurveImage(unittest.TestCase):
image = numpy.arange(10).astype(numpy.float32)
image.shape = 5, 2
- plot.addImage(image, legend='image 0', info=["Hi!"])
+ plot.addImage(image, legend="image 0", info=["Hi!"])
# Active image
data, legend, info, something, params = plot.getActiveImage()
- self.assertEqual(legend, 'image 0')
+ self.assertEqual(legend, "image 0")
self.assertEqual(info, ["Hi!"])
self.assertTrue(numpy.allclose(data, image), "image 0 data not correct")
def testGetAllImages(self):
"""PlotWidget.getAllImages test"""
- plot = PlotWidget(backend='none')
+ plot = PlotWidget(backend="none")
# No image
images = plot.getAllImages()
@@ -515,35 +524,34 @@ class TestPlotGetCurveImage(unittest.TestCase):
# 2 images
data = numpy.arange(100).reshape(10, 10)
- plot.addImage(data, legend='1')
- plot.addImage(data, origin=(10, 10), legend='2')
+ plot.addImage(data, legend="1")
+ plot.addImage(data, origin=(10, 10), legend="2")
images = plot.getAllImages(just_legend=True)
- self.assertEqual(list(images), ['1', '2'])
+ self.assertEqual(list(images), ["1", "2"])
images = plot.getAllImages(just_legend=False)
self.assertEqual(len(images), 2)
- self.assertEqual(images[0].getName(), '1')
- self.assertEqual(images[1].getName(), '2')
+ self.assertEqual(images[0].getName(), "1")
+ self.assertEqual(images[1].getName(), "2")
class TestPlotAddScatter(unittest.TestCase):
"""Test of plot addScatter"""
def testAddGetScatter(self):
-
- plot = PlotWidget(backend='none')
+ plot = PlotWidget(backend="none")
# No curve
scatter = plot._getItem(kind="scatter")
self.assertIsNone(scatter) # No curve
- plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend='scatter 0')
- plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend='scatter 1')
- plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend='scatter 2')
- plot._setActiveItem('scatter', 'scatter 0')
+ plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend="scatter 0")
+ plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend="scatter 1")
+ plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend="scatter 2")
+ plot.setActiveScatter("scatter 0")
# Active scatter
- active = plot._getActiveItem(kind='scatter')
- self.assertEqual(active.getName(), 'scatter 0')
+ active = plot.getActiveScatter()
+ self.assertEqual(active.getName(), "scatter 0")
# check default values
self.assertAlmostEqual(active.getSymbolSize(), active._DEFAULT_SYMBOL_SIZE)
@@ -561,26 +569,26 @@ class TestPlotAddScatter(unittest.TestCase):
self.assertEqual(s0.getSymbol(), "d")
self.assertAlmostEqual(s0.getAlpha(), 0.777)
- scatter1 = plot._getItem(kind='scatter', legend='scatter 1')
- self.assertEqual(scatter1.getName(), 'scatter 1')
+ scatter1 = plot._getItem(kind="scatter", legend="scatter 1")
+ self.assertEqual(scatter1.getName(), "scatter 1")
def testGetAllScatters(self):
"""PlotWidget.getAllImages test"""
- plot = PlotWidget(backend='none')
+ plot = PlotWidget(backend="none")
items = plot.getItems()
self.assertEqual(len(items), 0)
- plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend='scatter 0')
- plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend='scatter 1')
- plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend='scatter 2')
+ plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend="scatter 0")
+ plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend="scatter 1")
+ plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend="scatter 2")
items = plot.getItems()
self.assertEqual(len(items), 3)
- self.assertEqual(items[0].getName(), 'scatter 0')
- self.assertEqual(items[1].getName(), 'scatter 1')
- self.assertEqual(items[2].getName(), 'scatter 2')
+ self.assertEqual(items[0].getName(), "scatter 0")
+ self.assertEqual(items[1].getName(), "scatter 1")
+ self.assertEqual(items[2].getName(), "scatter 2")
class TestPlotHistogram(unittest.TestCase):
@@ -593,13 +601,13 @@ class TestPlotHistogram(unittest.TestCase):
edgesCenter = numpy.array([-0.5, 0.5, 1.5, 2.5])
# testing x values for right
- edges = _computeEdges(x, 'right')
+ edges = _computeEdges(x, "right")
numpy.testing.assert_array_equal(edges, edgesRight)
- edges = _computeEdges(x, 'center')
+ edges = _computeEdges(x, "center")
numpy.testing.assert_array_equal(edges, edgesCenter)
- edges = _computeEdges(x, 'left')
+ edges = _computeEdges(x, "left")
numpy.testing.assert_array_equal(edges, edgesLeft)
def testHistogramCurve(self):
@@ -607,11 +615,71 @@ class TestPlotHistogram(unittest.TestCase):
edges = numpy.array([0, 1, 2, 3])
xHisto, yHisto = _getHistogramCurve(y, edges)
- numpy.testing.assert_array_equal(
- yHisto, numpy.array([3, 3, 2, 2, 5, 5]))
+ numpy.testing.assert_array_equal(yHisto, numpy.array([3, 3, 2, 2, 5, 5]))
y = numpy.array([-3, 2, 5, 0])
edges = numpy.array([-2, -1, 0, 1, 2])
xHisto, yHisto = _getHistogramCurve(y, edges)
numpy.testing.assert_array_equal(
- yHisto, numpy.array([-3, -3, 2, 2, 5, 5, 0, 0]))
+ yHisto, numpy.array([-3, -3, 2, 2, 5, 5, 0, 0])
+ )
+
+
+def testSetDefaultColors(qWidgetFactory):
+ """Basic test of PlotWidget.get|setDefaultColors"""
+ plot = qWidgetFactory(PlotWidget)
+
+ # By default using config
+ assert numpy.array_equal(
+ plot.getDefaultColors(), silx.config.DEFAULT_PLOT_CURVE_COLORS
+ )
+
+ # Use own colors
+ colors = "red", "green", "blue"
+ plot.setDefaultColors(colors)
+ assert plot.getDefaultColors() == colors
+
+ # Reset to default
+ plot.setDefaultColors(None)
+ assert numpy.array_equal(
+ plot.getDefaultColors(), silx.config.DEFAULT_PLOT_CURVE_COLORS
+ )
+
+
+def testSetDefaultColorsAddCurve(qWidgetFactory):
+ """Test that PlotWidget.setDefaultColors reset color index"""
+ plot = qWidgetFactory(PlotWidget)
+
+ plot.addCurve((0, 1), (0, 0), legend="curve0")
+ plot.addCurve((0, 1), (1, 1), legend="curve1")
+ plot.addCurve((0, 1), (2, 2), legend="curve2")
+
+ colors = "#123456", "#abcdef"
+ plot.setDefaultColors(colors)
+ assert plot.getDefaultColors() == colors
+
+ # Check that the color index is reset
+ curve = plot.addCurve((1, 2), (0, 1), legend="newcurve")
+ assert curve.getColor() == rgba(colors[0])
+
+
+def testDefaultColorsUpdateConfig(qWidgetFactory):
+ """Test that color index is reset if needed when default colors config is updated"""
+ plot = qWidgetFactory(PlotWidget)
+
+ plot.addCurve((0, 1), (0, 0), legend="curve0")
+ plot.addCurve((0, 1), (1, 1), legend="curve1")
+ plot.addCurve((0, 1), (2, 2), legend="curve2")
+
+ previous_colors = silx.config.DEFAULT_PLOT_CURVE_COLORS
+ try:
+ colors = "#123456", "#abcdef"
+ silx.config.DEFAULT_PLOT_CURVE_COLORS = colors
+ assert plot.getDefaultColors() == colors
+
+ # Check that the color index is reset
+ curve = plot.addCurve((1, 2), (0, 1), legend="newcurve")
+ assert curve.getColor() == rgba(colors[0])
+
+ finally:
+ silx.config.DEFAULT_PLOT_CURVE_COLORS = previous_colors
diff --git a/src/silx/gui/plot/test/testPlotWindow.py b/src/silx/gui/plot/test/testPlotWindow.py
index 8e3f1df..8f17bf1 100644
--- a/src/silx/gui/plot/test/testPlotWindow.py
+++ b/src/silx/gui/plot/test/testPlotWindow.py
@@ -28,7 +28,6 @@ __license__ = "MIT"
__date__ = "27/06/2017"
-import unittest
import numpy
import pytest
@@ -72,12 +71,14 @@ class TestPlotWindow(TestCaseQt):
toolButton = getQToolButtonFromAction(action)
self.assertIsNot(toolButton, None)
self.mouseClick(toolButton, qt.Qt.LeftButton)
- self.assertNotEqual(getter(), initialState,
- msg='"%s" state not changed' % action.text())
+ self.assertNotEqual(
+ getter(), initialState, msg='"%s" state not changed' % action.text()
+ )
self.mouseClick(toolButton, qt.Qt.LeftButton)
- self.assertEqual(getter(), initialState,
- msg='"%s" state not changed' % action.text())
+ self.assertEqual(
+ getter(), initialState, msg='"%s" state not changed' % action.text()
+ )
# Trigger a zoom reset
self.mouseMove(self.plot)
@@ -88,8 +89,8 @@ class TestPlotWindow(TestCaseQt):
def testDockWidgets(self):
"""Test add/remove dock widgets"""
- dock1 = qt.QDockWidget('Test 1')
- dock1.setWidget(qt.QLabel('Test 1'))
+ dock1 = qt.QDockWidget("Test 1")
+ dock1.setWidget(qt.QLabel("Test 1"))
self.plot.addTabbedDockWidget(dock1)
self.qapp.processEvents()
@@ -97,17 +98,17 @@ class TestPlotWindow(TestCaseQt):
self.plot.removeDockWidget(dock1)
self.qapp.processEvents()
- dock2 = qt.QDockWidget('Test 2')
- dock2.setWidget(qt.QLabel('Test 2'))
+ dock2 = qt.QDockWidget("Test 2")
+ dock2.setWidget(qt.QLabel("Test 2"))
self.plot.addTabbedDockWidget(dock2)
self.qapp.processEvents()
- if qt.BINDING != 'PySide2':
- # Weird bug with PySide2 later upon gc.collect() when getting the layout
- self.assertNotEqual(self.plot.layout().indexOf(dock2),
- -1,
- "dock2 not properly displayed")
+ self.assertNotEqual(
+ self.plot.layout().indexOf(dock2),
+ -1,
+ "dock2 not properly displayed",
+ )
def testToolAspectRatio(self):
self.plot.toolBar()
@@ -128,12 +129,14 @@ class TestPlotWindow(TestCaseQt):
old = Colormap._computeAutoscaleRange
self._count = 0
+
def _computeAutoscaleRange(colormap, data):
self._count = self._count + 1
return 10, 20
+
Colormap._computeAutoscaleRange = _computeAutoscaleRange
try:
- colormap = Colormap(name='red')
+ colormap = Colormap(name="red")
self.plot.setVisible(True)
# Add an image
@@ -163,11 +166,10 @@ class TestPlotWindow(TestCaseQt):
ylimits = self.plot.getYAxis().getLimits()
isKeepAspectRatio = self.plot.isKeepDataAspectRatio()
- for backend in ('gl', 'mpl'):
+ for backend in ("gl", "mpl"):
with self.subTest():
self.plot.setBackend(backend)
self.plot.replot()
self.assertEqual(self.plot.getXAxis().getLimits(), xlimits)
self.assertEqual(self.plot.getYAxis().getLimits(), ylimits)
- self.assertEqual(
- self.plot.isKeepDataAspectRatio(), isKeepAspectRatio)
+ self.assertEqual(self.plot.isKeepDataAspectRatio(), isKeepAspectRatio)
diff --git a/src/silx/gui/plot/test/testRoiStatsWidget.py b/src/silx/gui/plot/test/testRoiStatsWidget.py
index 2c1c6b3..759ebe2 100644
--- a/src/silx/gui/plot/test/testRoiStatsWidget.py
+++ b/src/silx/gui/plot/test/testRoiStatsWidget.py
@@ -32,47 +32,49 @@ from silx.gui.plot.ROIStatsWidget import ROIStatsWidget
from silx.gui.plot.CurvesROIWidget import ROI
from silx.gui.plot.items.roi import RectangleROI, PolygonROI
from silx.gui.plot.StatsWidget import UpdateMode
-import unittest
import numpy
-
class _TestRoiStatsBase(TestCaseQt):
"""Base class for several unittest relative to ROIStatsWidget"""
+
def setUp(self):
TestCaseQt.setUp(self)
# define plot
self.plot = PlotWindow()
- self.plot.addImage(numpy.arange(10000).reshape(100, 100),
- legend='img1')
- self.img_item = self.plot.getImage('img1')
- self.plot.addCurve(x=numpy.linspace(0, 10, 56), y=numpy.arange(56),
- legend='curve1')
- self.curve_item = self.plot.getCurve('curve1')
- self.plot.addHistogram(edges=numpy.linspace(0, 10, 56),
- histogram=numpy.arange(56), legend='histo1')
- self.histogram_item = self.plot.getHistogram(legend='histo1')
- self.plot.addScatter(x=numpy.linspace(0, 10, 56),
- y=numpy.linspace(0, 10, 56),
- value=numpy.arange(56),
- legend='scatter1')
- self.scatter_item = self.plot.getScatter(legend='scatter1')
+ self.plot.addImage(numpy.arange(10000).reshape(100, 100), legend="img1")
+ self.img_item = self.plot.getImage("img1")
+ self.plot.addCurve(
+ x=numpy.linspace(0, 10, 56), y=numpy.arange(56), legend="curve1"
+ )
+ self.curve_item = self.plot.getCurve("curve1")
+ self.plot.addHistogram(
+ edges=numpy.linspace(0, 10, 56), histogram=numpy.arange(56), legend="histo1"
+ )
+ self.histogram_item = self.plot.getHistogram(legend="histo1")
+ self.plot.addScatter(
+ x=numpy.linspace(0, 10, 56),
+ y=numpy.linspace(0, 10, 56),
+ value=numpy.arange(56),
+ legend="scatter1",
+ )
+ self.scatter_item = self.plot.getScatter(legend="scatter1")
# stats widget
self.statsWidget = ROIStatsWidget(plot=self.plot)
# define stats
stats = [
- ('sum', numpy.sum),
- ('mean', numpy.mean),
+ ("sum", numpy.sum),
+ ("mean", numpy.mean),
]
self.statsWidget.setStats(stats=stats)
# define rois
- self.roi1D = ROI(name='range1', fromdata=0, todata=4, type_='energy')
+ self.roi1D = ROI(name="range1", fromdata=0, todata=4, type_="energy")
self.rectangle_roi = RectangleROI()
self.rectangle_roi.setGeometry(origin=(0, 0), size=(20, 20))
- self.rectangle_roi.setName('Initial ROI')
+ self.rectangle_roi.setName("Initial ROI")
self.polygon_roi = PolygonROI()
points = numpy.array([[0, 5], [5, 0], [10, 5], [5, 10]])
self.polygon_roi.setPoints(points)
@@ -95,182 +97,164 @@ class TestRoiStatsCouple(_TestRoiStatsBase):
"""
Test different possible couple (roi, plotItem).
Check that:
-
+
* computation is correct if couple is valid
* raise an error if couple is invalid
"""
+
def testROICurve(self):
"""
- Test that the couple (ROI, curveItem) can be used for stats
+ Test that the couple (ROI, curveItem) can be used for stats
"""
- item = self.statsWidget.addItem(roi=self.roi1D,
- plotItem=self.curve_item)
+ item = self.statsWidget.addItem(roi=self.roi1D, plotItem=self.curve_item)
assert item is not None
tableItems = self.statsTable()._itemToTableItems(item)
- self.assertEqual(tableItems['sum'].text(), '253')
- self.assertEqual(tableItems['mean'].text(), '11.0')
+ self.assertEqual(tableItems["sum"].text(), "253")
+ self.assertEqual(tableItems["mean"].text(), "11.0")
def testRectangleImage(self):
"""
- Test that the couple (RectangleROI, imageItem) can be used for stats
+ Test that the couple (RectangleROI, imageItem) can be used for stats
"""
- item = self.statsWidget.addItem(roi=self.rectangle_roi,
- plotItem=self.img_item)
+ item = self.statsWidget.addItem(roi=self.rectangle_roi, plotItem=self.img_item)
assert item is not None
- self.plot.addImage(numpy.ones(10000).reshape(100, 100),
- legend='img1')
+ self.plot.addImage(numpy.ones(10000).reshape(100, 100), legend="img1")
self.qapp.processEvents()
tableItems = self.statsTable()._itemToTableItems(item)
- self.assertEqual(tableItems['sum'].text(), str(float(21*21)))
- self.assertEqual(tableItems['mean'].text(), '1.0')
+ self.assertEqual(tableItems["sum"].text(), str(float(21 * 21)))
+ self.assertEqual(tableItems["mean"].text(), "1.0")
def testPolygonImage(self):
"""
- Test that the couple (PolygonROI, imageItem) can be used for stats
+ Test that the couple (PolygonROI, imageItem) can be used for stats
"""
- item = self.statsWidget.addItem(roi=self.polygon_roi,
- plotItem=self.img_item)
+ item = self.statsWidget.addItem(roi=self.polygon_roi, plotItem=self.img_item)
assert item is not None
tableItems = self.statsTable()._itemToTableItems(item)
- self.assertEqual(tableItems['sum'].text(), '22750')
- self.assertEqual(tableItems['mean'].text(), '455.0')
+ self.assertEqual(tableItems["sum"].text(), "22750")
+ self.assertEqual(tableItems["mean"].text(), "455.0")
def testROIImage(self):
"""
- Test that the couple (ROI, imageItem) is raising an error
+ Test that the couple (ROI, imageItem) is raising an error
"""
with self.assertRaises(TypeError):
- self.statsWidget.addItem(roi=self.roi1D,
- plotItem=self.img_item)
+ self.statsWidget.addItem(roi=self.roi1D, plotItem=self.img_item)
def testRectangleCurve(self):
"""
- Test that the couple (rectangleROI, curveItem) is raising an error
+ Test that the couple (rectangleROI, curveItem) is raising an error
"""
with self.assertRaises(TypeError):
- self.statsWidget.addItem(roi=self.rectangle_roi,
- plotItem=self.curve_item)
+ self.statsWidget.addItem(roi=self.rectangle_roi, plotItem=self.curve_item)
def testROIHistogram(self):
"""
- Test that the couple (PolygonROI, imageItem) can be used for stats
+ Test that the couple (PolygonROI, imageItem) can be used for stats
"""
- item = self.statsWidget.addItem(roi=self.roi1D,
- plotItem=self.histogram_item)
+ item = self.statsWidget.addItem(roi=self.roi1D, plotItem=self.histogram_item)
assert item is not None
tableItems = self.statsTable()._itemToTableItems(item)
- self.assertEqual(tableItems['sum'].text(), '253')
- self.assertEqual(tableItems['mean'].text(), '11.0')
+ self.assertEqual(tableItems["sum"].text(), "253")
+ self.assertEqual(tableItems["mean"].text(), "11.0")
def testROIScatter(self):
"""
- Test that the couple (PolygonROI, imageItem) can be used for stats
+ Test that the couple (PolygonROI, imageItem) can be used for stats
"""
- item = self.statsWidget.addItem(roi=self.roi1D,
- plotItem=self.scatter_item)
+ item = self.statsWidget.addItem(roi=self.roi1D, plotItem=self.scatter_item)
assert item is not None
tableItems = self.statsTable()._itemToTableItems(item)
- self.assertEqual(tableItems['sum'].text(), '253')
- self.assertEqual(tableItems['mean'].text(), '11.0')
+ self.assertEqual(tableItems["sum"].text(), "253")
+ self.assertEqual(tableItems["mean"].text(), "11.0")
class TestRoiStatsAddRemoveItem(_TestRoiStatsBase):
"""Test adding and removing (roi, plotItem) items"""
+
def testAddRemoveItems(self):
- item1 = self.statsWidget.addItem(roi=self.roi1D,
- plotItem=self.scatter_item)
+ item1 = self.statsWidget.addItem(roi=self.roi1D, plotItem=self.scatter_item)
self.assertTrue(item1 is not None)
self.assertEqual(self.statsTable().rowCount(), 1)
- item2 = self.statsWidget.addItem(roi=self.roi1D,
- plotItem=self.histogram_item)
+ item2 = self.statsWidget.addItem(roi=self.roi1D, plotItem=self.histogram_item)
self.assertTrue(item2 is not None)
self.assertEqual(self.statsTable().rowCount(), 2)
# try to add twice the same item
- item3 = self.statsWidget.addItem(roi=self.roi1D,
- plotItem=self.histogram_item)
+ item3 = self.statsWidget.addItem(roi=self.roi1D, plotItem=self.histogram_item)
self.assertTrue(item3 is None)
self.assertEqual(self.statsTable().rowCount(), 2)
- item4 = self.statsWidget.addItem(roi=self.roi1D,
- plotItem=self.curve_item)
+ item4 = self.statsWidget.addItem(roi=self.roi1D, plotItem=self.curve_item)
self.assertTrue(item4 is not None)
self.assertEqual(self.statsTable().rowCount(), 3)
- self.statsWidget.removeItem(plotItem=item4._plot_item,
- roi=item4._roi)
+ self.statsWidget.removeItem(plotItem=item4._plot_item, roi=item4._roi)
self.assertEqual(self.statsTable().rowCount(), 2)
# try to remove twice the same item
- self.statsWidget.removeItem(plotItem=item4._plot_item,
- roi=item4._roi)
+ self.statsWidget.removeItem(plotItem=item4._plot_item, roi=item4._roi)
self.assertEqual(self.statsTable().rowCount(), 2)
- self.statsWidget.removeItem(plotItem=item2._plot_item,
- roi=item2._roi)
- self.statsWidget.removeItem(plotItem=item1._plot_item,
- roi=item1._roi)
+ self.statsWidget.removeItem(plotItem=item2._plot_item, roi=item2._roi)
+ self.statsWidget.removeItem(plotItem=item1._plot_item, roi=item1._roi)
self.assertEqual(self.statsTable().rowCount(), 0)
class TestRoiStatsRoiUpdate(_TestRoiStatsBase):
"""Test that the stats will be updated if the roi is updated"""
+
def testChangeRoi(self):
- item = self.statsWidget.addItem(roi=self.rectangle_roi,
- plotItem=self.img_item)
+ item = self.statsWidget.addItem(roi=self.rectangle_roi, plotItem=self.img_item)
assert item is not None
tableItems = self.statsTable()._itemToTableItems(item)
- self.assertEqual(tableItems['sum'].text(), '445410')
- self.assertEqual(tableItems['mean'].text(), '1010.0')
+ self.assertEqual(tableItems["sum"].text(), "445410")
+ self.assertEqual(tableItems["mean"].text(), "1010.0")
# update roi
self.rectangle_roi.setOrigin(position=(10, 10))
- self.assertNotEqual(tableItems['sum'].text(), '445410')
- self.assertNotEqual(tableItems['mean'].text(), '1010.0')
+ self.assertNotEqual(tableItems["sum"].text(), "445410")
+ self.assertNotEqual(tableItems["mean"].text(), "1010.0")
def testUpdateModeScenario(self):
"""Test update according to a simple scenario"""
self.statsWidget._setUpdateMode(UpdateMode.AUTO)
- item = self.statsWidget.addItem(roi=self.rectangle_roi,
- plotItem=self.img_item)
+ item = self.statsWidget.addItem(roi=self.rectangle_roi, plotItem=self.img_item)
assert item is not None
tableItems = self.statsTable()._itemToTableItems(item)
- self.assertEqual(tableItems['sum'].text(), '445410')
- self.assertEqual(tableItems['mean'].text(), '1010.0')
+ self.assertEqual(tableItems["sum"].text(), "445410")
+ self.assertEqual(tableItems["mean"].text(), "1010.0")
self.statsWidget._setUpdateMode(UpdateMode.MANUAL)
self.rectangle_roi.setOrigin(position=(10, 10))
self.qapp.processEvents()
- self.assertNotEqual(tableItems['sum'].text(), '445410')
- self.assertNotEqual(tableItems['mean'].text(), '1010.0')
+ self.assertNotEqual(tableItems["sum"].text(), "445410")
+ self.assertNotEqual(tableItems["mean"].text(), "1010.0")
self.statsWidget._updateAllStats(is_request=True)
- self.assertNotEqual(tableItems['sum'].text(), '445410')
- self.assertNotEqual(tableItems['mean'].text(), '1010.0')
+ self.assertNotEqual(tableItems["sum"].text(), "445410")
+ self.assertNotEqual(tableItems["mean"].text(), "1010.0")
class TestRoiStatsPlotItemUpdate(_TestRoiStatsBase):
"""Test that the stats will be updated if the plot item is updated"""
+
def testChangeImage(self):
self.statsWidget._setUpdateMode(UpdateMode.AUTO)
- item = self.statsWidget.addItem(roi=self.rectangle_roi,
- plotItem=self.img_item)
+ item = self.statsWidget.addItem(roi=self.rectangle_roi, plotItem=self.img_item)
assert item is not None
tableItems = self.statsTable()._itemToTableItems(item)
- self.assertEqual(tableItems['mean'].text(), '1010.0')
+ self.assertEqual(tableItems["mean"].text(), "1010.0")
# update plot
- self.plot.addImage(numpy.arange(100, 10100).reshape(100, 100),
- legend='img1')
- self.assertNotEqual(tableItems['mean'].text(), '1059.5')
+ self.plot.addImage(numpy.arange(100, 10100).reshape(100, 100), legend="img1")
+ self.assertNotEqual(tableItems["mean"].text(), "1059.5")
def testUpdateModeScenario(self):
"""Test update according to a simple scenario"""
self.statsWidget._setUpdateMode(UpdateMode.MANUAL)
- item = self.statsWidget.addItem(roi=self.rectangle_roi,
- plotItem=self.img_item)
+ item = self.statsWidget.addItem(roi=self.rectangle_roi, plotItem=self.img_item)
assert item is not None
tableItems = self.statsTable()._itemToTableItems(item)
- self.assertEqual(tableItems['mean'].text(), '1010.0')
- self.plot.addImage(numpy.arange(100, 10100).reshape(100, 100),
- legend='img1')
- self.assertEqual(tableItems['mean'].text(), '1010.0')
+ self.assertEqual(tableItems["mean"].text(), "1010.0")
+ self.plot.addImage(numpy.arange(100, 10100).reshape(100, 100), legend="img1")
+ self.assertEqual(tableItems["mean"].text(), "1010.0")
self.statsWidget._updateAllStats(is_request=True)
- self.assertEqual(tableItems['mean'].text(), '1110.0')
+ self.assertEqual(tableItems["mean"].text(), "1110.0")
diff --git a/src/silx/gui/plot/test/testSaveAction.py b/src/silx/gui/plot/test/testSaveAction.py
index d5a06c6..f8ac7ee 100644
--- a/src/silx/gui/plot/test/testSaveAction.py
+++ b/src/silx/gui/plot/test/testSaveAction.py
@@ -39,9 +39,8 @@ from silx.gui.plot.actions.io import SaveAction
class TestSaveActionSaveCurvesAsSpec(unittest.TestCase):
-
def setUp(self):
- self.plot = PlotWidget(backend='none')
+ self.plot = PlotWidget(backend="none")
self.saveAction = SaveAction(plot=self.plot)
self.tempdir = tempfile.mkdtemp()
@@ -56,17 +55,16 @@ class TestSaveActionSaveCurvesAsSpec(unittest.TestCase):
self.plot.setGraphXLabel("graph x label")
self.plot.setGraphYLabel("graph y label")
- self.plot.addCurve([0, 1], [1, 2], "curve with labels",
- xlabel="curve0 X", ylabel="curve0 Y")
- self.plot.addCurve([-1, 3], [-6, 2], "curve with X label",
- xlabel="curve1 X")
- self.plot.addCurve([-2, 0], [8, 12], "curve with Y label",
- ylabel="curve2 Y")
+ self.plot.addCurve(
+ [0, 1], [1, 2], "curve with labels", xlabel="curve0 X", ylabel="curve0 Y"
+ )
+ self.plot.addCurve([-1, 3], [-6, 2], "curve with X label", xlabel="curve1 X")
+ self.plot.addCurve([-2, 0], [8, 12], "curve with Y label", ylabel="curve2 Y")
self.plot.addCurve([3, 1], [7, 6], "curve with no labels")
- self.saveAction._saveCurves(self.plot,
- self.out_fname,
- SaveAction.DEFAULT_ALL_CURVES_FILTERS[0]) # "All curves as SpecFile (*.dat)"
+ self.saveAction._saveCurves(
+ self.plot, self.out_fname, SaveAction.DEFAULT_ALL_CURVES_FILTERS[0]
+ ) # "All curves as SpecFile (*.dat)"
with open(self.out_fname, "rb") as f:
file_content = f.read()
@@ -99,33 +97,35 @@ class TestSaveActionExtension(PlotWidgetTestCase):
saveAction = SaveAction(plot=self.plot, parent=self.plot)
# Add a new file filter
- nameFilter = 'Dummy file (*.dummy)'
- saveAction.setFileFilter('all', nameFilter, self._dummySaveFunction)
- self.assertTrue(nameFilter in saveAction.getFileFilters('all'))
- self.assertEqual(saveAction.getFileFilters('all')[nameFilter],
- self._dummySaveFunction)
+ nameFilter = "Dummy file (*.dummy)"
+ saveAction.setFileFilter("all", nameFilter, self._dummySaveFunction)
+ self.assertTrue(nameFilter in saveAction.getFileFilters("all"))
+ self.assertEqual(
+ saveAction.getFileFilters("all")[nameFilter], self._dummySaveFunction
+ )
# Add a new file filter at a particular position
- nameFilter = 'Dummy file2 (*.dummy)'
- saveAction.setFileFilter('all', nameFilter,
- self._dummySaveFunction, index=3)
- self.assertTrue(nameFilter in saveAction.getFileFilters('all'))
- filters = saveAction.getFileFilters('all')
+ nameFilter = "Dummy file2 (*.dummy)"
+ saveAction.setFileFilter("all", nameFilter, self._dummySaveFunction, index=3)
+ self.assertTrue(nameFilter in saveAction.getFileFilters("all"))
+ filters = saveAction.getFileFilters("all")
self.assertEqual(filters[nameFilter], self._dummySaveFunction)
- self.assertEqual(list(filters.keys()).index(nameFilter),3)
+ self.assertEqual(list(filters.keys()).index(nameFilter), 3)
# Update an existing file filter
nameFilter = SaveAction.IMAGE_FILTER_EDF
- saveAction.setFileFilter('image', nameFilter, self._dummySaveFunction)
- self.assertEqual(saveAction.getFileFilters('image')[nameFilter],
- self._dummySaveFunction)
+ saveAction.setFileFilter("image", nameFilter, self._dummySaveFunction)
+ self.assertEqual(
+ saveAction.getFileFilters("image")[nameFilter], self._dummySaveFunction
+ )
# Change the position of an existing file filter
- nameFilter = 'Dummy file2 (*.dummy)'
- oldIndex = list(saveAction.getFileFilters('all')).index(nameFilter)
+ nameFilter = "Dummy file2 (*.dummy)"
+ oldIndex = list(saveAction.getFileFilters("all")).index(nameFilter)
newIndex = oldIndex - 1
- saveAction.setFileFilter('all', nameFilter,
- self._dummySaveFunction, index=newIndex)
- filters = saveAction.getFileFilters('all')
+ saveAction.setFileFilter(
+ "all", nameFilter, self._dummySaveFunction, index=newIndex
+ )
+ filters = saveAction.getFileFilters("all")
self.assertEqual(filters[nameFilter], self._dummySaveFunction)
self.assertEqual(list(filters.keys()).index(nameFilter), newIndex)
diff --git a/src/silx/gui/plot/test/testScatterMaskToolsWidget.py b/src/silx/gui/plot/test/testScatterMaskToolsWidget.py
index 68375b0..5dc14e1 100644
--- a/src/silx/gui/plot/test/testScatterMaskToolsWidget.py
+++ b/src/silx/gui/plot/test/testScatterMaskToolsWidget.py
@@ -30,7 +30,6 @@ __date__ = "17/01/2018"
import logging
import os.path
-import unittest
import numpy
@@ -41,8 +40,6 @@ from silx.gui.utils.testutils import getQToolButtonFromAction
from silx.gui.plot import PlotWindow, ScatterMaskToolsWidget
from .utils import PlotWidgetTestCase
-import fabio
-
_logger = logging.getLogger(__name__)
@@ -56,7 +53,8 @@ class TestScatterMaskToolsWidget(PlotWidgetTestCase, ParametricTestCase):
def setUp(self):
super(TestScatterMaskToolsWidget, self).setUp()
self.widget = ScatterMaskToolsWidget.ScatterMaskToolsDockWidget(
- plot=self.plot, name='TEST')
+ plot=self.plot, name="TEST"
+ )
self.plot.addDockWidget(qt.Qt.BottomDockWidgetArea, self.widget)
self.maskWidget = self.widget.widget()
@@ -68,10 +66,10 @@ class TestScatterMaskToolsWidget(PlotWidgetTestCase, ParametricTestCase):
def testEmptyPlot(self):
"""Empty plot, display MaskToolsDockWidget, toggle multiple masks"""
- self.maskWidget.setMultipleMasks('single')
+ self.maskWidget.setMultipleMasks("single")
self.qapp.processEvents()
- self.maskWidget.setMultipleMasks('exclusive')
+ self.maskWidget.setMultipleMasks("exclusive")
self.qapp.processEvents()
def _drag(self):
@@ -102,12 +100,14 @@ class TestScatterMaskToolsWidget(PlotWidgetTestCase, ParametricTestCase):
x, y = plot.width() // 2, plot.height() // 2
offset = min(plot.width(), plot.height()) // 10
- star = [(x, y + offset),
- (x - offset, y - offset),
- (x + offset, y),
- (x - offset, y),
- (x + offset, y - offset),
- (x, y + offset)] # Close polygon
+ star = [
+ (x, y + offset),
+ (x - offset, y - offset),
+ (x + offset, y),
+ (x - offset, y),
+ (x + offset, y - offset),
+ (x, y + offset),
+ ] # Close polygon
self.mouseMove(plot, pos=[0, 0])
for pos in star:
@@ -124,41 +124,44 @@ class TestScatterMaskToolsWidget(PlotWidgetTestCase, ParametricTestCase):
x, y = plot.width() // 2, plot.height() // 2
offset = min(plot.width(), plot.height()) // 10
- star = [(x, y + offset),
- (x - offset, y - offset),
- (x + offset, y),
- (x - offset, y),
- (x + offset, y - offset)]
+ star = [
+ (x, y + offset),
+ (x - offset, y - offset),
+ (x + offset, y),
+ (x - offset, y),
+ (x + offset, y - offset),
+ ]
self.mouseMove(plot, pos=[0, 0])
self.mouseMove(plot, pos=star[0])
self.mousePress(plot, qt.Qt.LeftButton, pos=star[0])
for pos in star[1:]:
self.mouseMove(plot, pos=pos)
- self.mouseRelease(
- plot, qt.Qt.LeftButton, pos=star[-1])
+ self.mouseRelease(plot, qt.Qt.LeftButton, pos=star[-1])
def testWithAScatter(self):
"""Plot with a Scatter: test MaskToolsWidget interactions"""
# Add and remove a scatter (this should enable/disable GUI + change mask)
self.plot.addScatter(
- x=numpy.arange(256),
- y=numpy.arange(256),
- value=numpy.random.random(256),
- legend='test')
- self.plot._setActiveItem(kind="scatter", legend="test")
+ x=numpy.arange(256),
+ y=numpy.arange(256),
+ value=numpy.random.random(256),
+ legend="test",
+ )
+ self.plot.setActiveScatter("test")
self.qapp.processEvents()
- self.plot.remove('test', kind='scatter')
+ self.plot.remove("test", kind="scatter")
self.qapp.processEvents()
self.plot.addScatter(
- x=numpy.arange(1000),
- y=1000 * (numpy.arange(1000) % 20),
- value=numpy.random.random(1000),
- legend='test')
- self.plot._setActiveItem(kind="scatter", legend="test")
+ x=numpy.arange(1000),
+ y=1000 * (numpy.arange(1000) % 20),
+ value=numpy.random.random(1000),
+ legend="test",
+ )
+ self.plot.setActiveScatter("test")
self.plot.resetZoom()
self.qapp.processEvents()
@@ -172,15 +175,13 @@ class TestScatterMaskToolsWidget(PlotWidgetTestCase, ParametricTestCase):
self.qapp.processEvents()
self._drag()
- self.assertFalse(
- numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+ self.assertFalse(numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
# unmask same region
self.maskWidget.maskStateGroup.button(0).click()
self.qapp.processEvents()
self._drag()
- self.assertTrue(
- numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+ self.assertTrue(numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
# Test draw polygon #
toolButton = getQToolButtonFromAction(self.maskWidget.polygonAction)
@@ -191,15 +192,13 @@ class TestScatterMaskToolsWidget(PlotWidgetTestCase, ParametricTestCase):
self.maskWidget.maskStateGroup.button(1).click()
self.qapp.processEvents()
self._drawPolygon()
- self.assertFalse(
- numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+ self.assertFalse(numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
# unmask same region
self.maskWidget.maskStateGroup.button(0).click()
self.qapp.processEvents()
self._drawPolygon()
- self.assertTrue(
- numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+ self.assertTrue(numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
# Test draw pencil #
toolButton = getQToolButtonFromAction(self.maskWidget.pencilAction)
@@ -213,15 +212,13 @@ class TestScatterMaskToolsWidget(PlotWidgetTestCase, ParametricTestCase):
self.maskWidget.maskStateGroup.button(1).click()
self.qapp.processEvents()
self._drawPencil()
- self.assertFalse(
- numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+ self.assertFalse(numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
# unmask same region
self.maskWidget.maskStateGroup.button(0).click()
self.qapp.processEvents()
self._drawPencil()
- self.assertTrue(
- numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+ self.assertTrue(numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
# Test no draw tool #
toolButton = getQToolButtonFromAction(self.maskWidget.browseAction)
@@ -232,11 +229,12 @@ class TestScatterMaskToolsWidget(PlotWidgetTestCase, ParametricTestCase):
def __loadSave(self, file_format):
self.plot.addScatter(
- x=numpy.arange(256),
- y=25 * (numpy.arange(256) % 10),
- value=numpy.random.random(256),
- legend='test')
- self.plot._setActiveItem(kind="scatter", legend="test")
+ x=numpy.arange(256),
+ y=25 * (numpy.arange(256) % 10),
+ value=numpy.random.random(256),
+ legend="test",
+ )
+ self.plot.setActiveScatter("test")
self.plot.resetZoom()
self.qapp.processEvents()
@@ -250,16 +248,18 @@ class TestScatterMaskToolsWidget(PlotWidgetTestCase, ParametricTestCase):
self.assertFalse(numpy.all(numpy.equal(ref_mask, 0)))
with temp_dir() as tmp:
- mask_filename = os.path.join(tmp, 'mask.' + file_format)
+ mask_filename = os.path.join(tmp, "mask." + file_format)
self.maskWidget.save(mask_filename, file_format)
self.maskWidget.resetSelectionMask()
self.assertTrue(
- numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0))
+ )
self.maskWidget.load(mask_filename)
- self.assertTrue(numpy.all(numpy.equal(
- self.maskWidget.getSelectionMask(), ref_mask)))
+ self.assertTrue(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), ref_mask))
+ )
def testLoadSaveNpy(self):
self.__loadSave("npy")
@@ -270,22 +270,24 @@ class TestScatterMaskToolsWidget(PlotWidgetTestCase, ParametricTestCase):
def testSigMaskChangedEmitted(self):
self.qapp.processEvents()
self.plot.addScatter(
- x=numpy.arange(1000),
- y=1000 * (numpy.arange(1000) % 20),
- value=numpy.ones((1000,)),
- legend='test')
- self.plot._setActiveItem(kind="scatter", legend="test")
+ x=numpy.arange(1000),
+ y=1000 * (numpy.arange(1000) % 20),
+ value=numpy.ones((1000,)),
+ legend="test",
+ )
+ self.plot.setActiveScatter("test")
self.plot.resetZoom()
self.qapp.processEvents()
- self.plot.remove('test', kind='scatter')
+ self.plot.remove("test", kind="scatter")
self.qapp.processEvents()
self.plot.addScatter(
- x=numpy.arange(1000),
- y=1000 * (numpy.arange(1000) % 20),
- value=numpy.random.random(1000),
- legend='test')
+ x=numpy.arange(1000),
+ y=1000 * (numpy.arange(1000) % 20),
+ value=numpy.random.random(1000),
+ legend="test",
+ )
l = []
diff --git a/src/silx/gui/plot/test/testScatterView.py b/src/silx/gui/plot/test/testScatterView.py
index 692612d..d6853b1 100644
--- a/src/silx/gui/plot/test/testScatterView.py
+++ b/src/silx/gui/plot/test/testScatterView.py
@@ -28,8 +28,6 @@ __license__ = "MIT"
__date__ = "06/03/2018"
-import unittest
-
import numpy
from silx.gui.plot.items import Axis, Scatter
@@ -83,7 +81,7 @@ class TestScatterView(PlotWidgetTestCase):
scale = self.plot.getYAxis().getScale()
self.assertEqual(scale, Axis.LINEAR)
- title = 'Test ScatterView'
+ title = "Test ScatterView"
self.plot.setGraphTitle(title)
self.assertEqual(self.plot.getGraphTitle(), title)
@@ -107,13 +105,15 @@ class TestScatterView(PlotWidgetTestCase):
_pts = 100
_levels = 100
_fwhm = 50
- x = numpy.random.rand(_pts)*_levels
- y = numpy.random.rand(_pts)*_levels
- value = numpy.random.rand(_pts)*_levels
- x0 = x[int(_pts/2)]
- y0 = x[int(_pts/2)]
- #2D Gaussian kernel
- alpha = numpy.exp(-4*numpy.log(2) * ((x-x0)**2 + (y-y0)**2) / _fwhm**2)
+ x = numpy.random.rand(_pts) * _levels
+ y = numpy.random.rand(_pts) * _levels
+ value = numpy.random.rand(_pts) * _levels
+ x0 = x[int(_pts / 2)]
+ y0 = x[int(_pts / 2)]
+ # 2D Gaussian kernel
+ alpha = numpy.exp(
+ -4 * numpy.log(2) * ((x - x0) ** 2 + (y - y0) ** 2) / _fwhm**2
+ )
self.plot.setData(x, y, value, alpha=alpha)
self.qapp.processEvents()
diff --git a/src/silx/gui/plot/test/testStackView.py b/src/silx/gui/plot/test/testStackView.py
index aba8678..5e0ead5 100644
--- a/src/silx/gui/plot/test/testStackView.py
+++ b/src/silx/gui/plot/test/testStackView.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -28,7 +28,6 @@ __license__ = "MIT"
__date__ = "20/03/2017"
-import unittest
import numpy
from silx.gui.utils.testutils import TestCaseQt, SignalListener
@@ -49,8 +48,10 @@ class TestStackView(TestCaseQt):
self.stackview.show()
self.qWaitForWindowExposed(self.stackview)
self.mystack = numpy.fromfunction(
- lambda i, j, k: numpy.sin(i/15.) + numpy.cos(j/4.) + 2 * numpy.sin(k/6.),
- (10, 20, 30)
+ lambda i, j, k: numpy.sin(i / 15.0)
+ + numpy.cos(j / 4.0)
+ + 2 * numpy.sin(k / 6.0),
+ (10, 20, 30),
)
def tearDown(self):
@@ -74,13 +75,11 @@ class TestStackView(TestCaseQt):
def testSetStack(self):
self.stackview.setStack(self.mystack)
- self.stackview.setColormap("viridis", autoscale=True)
+ self.stackview.setColormap("viridis")
my_trans_stack, params = self.stackview.getStack()
self.assertEqual(my_trans_stack.shape, self.mystack.shape)
- self.assertTrue(numpy.array_equal(self.mystack,
- my_trans_stack))
- self.assertEqual(params["colormap"]["name"],
- "viridis")
+ self.assertTrue(numpy.array_equal(self.mystack, my_trans_stack))
+ self.assertEqual(params["colormap"]["name"], "viridis")
def testSetStackPerspective(self):
self.stackview.setStack(self.mystack, perspective=1)
@@ -88,10 +87,15 @@ class TestStackView(TestCaseQt):
my_trans_stack, params = self.stackview.getCurrentView()
# get stack returns the transposed data, depending on the perspective
- self.assertEqual(my_trans_stack.shape,
- (self.mystack.shape[1], self.mystack.shape[0], self.mystack.shape[2]))
- self.assertTrue(numpy.array_equal(numpy.transpose(self.mystack, axes=(1, 0, 2)),
- my_trans_stack))
+ self.assertEqual(
+ my_trans_stack.shape,
+ (self.mystack.shape[1], self.mystack.shape[0], self.mystack.shape[2]),
+ )
+ self.assertTrue(
+ numpy.array_equal(
+ numpy.transpose(self.mystack, axes=(1, 0, 2)), my_trans_stack
+ )
+ )
def testSetStackListOfImages(self):
loi = [self.mystack[i] for i in range(self.mystack.shape[0])]
@@ -100,10 +104,8 @@ class TestStackView(TestCaseQt):
my_orig_stack, params = self.stackview.getStack(returnNumpyArray=True)
my_trans_stack, params = self.stackview.getStack(returnNumpyArray=True)
self.assertEqual(my_trans_stack.shape, self.mystack.shape)
- self.assertTrue(numpy.array_equal(self.mystack,
- my_trans_stack))
- self.assertTrue(numpy.array_equal(self.mystack,
- my_orig_stack))
+ self.assertTrue(numpy.array_equal(self.mystack, my_trans_stack))
+ self.assertTrue(numpy.array_equal(self.mystack, my_orig_stack))
self.assertIsInstance(my_trans_stack, numpy.ndarray)
self.stackview.setStack(loi, perspective=2)
@@ -113,88 +115,100 @@ class TestStackView(TestCaseQt):
self.assertIs(my_orig_stack, loi)
# getCurrentView(copy=False) returns a ListOfImages whose .images
# attr is the original data
- self.assertEqual(my_trans_stack.shape,
- (self.mystack.shape[2], self.mystack.shape[0], self.mystack.shape[1]))
- self.assertTrue(numpy.array_equal(numpy.array(my_trans_stack),
- numpy.transpose(self.mystack, axes=(2, 0, 1))))
- self.assertIsInstance(my_trans_stack,
- ListOfImages) # returnNumpyArray=False by default in getStack
+ self.assertEqual(
+ my_trans_stack.shape,
+ (self.mystack.shape[2], self.mystack.shape[0], self.mystack.shape[1]),
+ )
+ self.assertTrue(
+ numpy.array_equal(
+ numpy.array(my_trans_stack),
+ numpy.transpose(self.mystack, axes=(2, 0, 1)),
+ )
+ )
+ self.assertIsInstance(
+ my_trans_stack, ListOfImages
+ ) # returnNumpyArray=False by default in getStack
self.assertIs(my_trans_stack.images, loi)
def testPerspective(self):
self.stackview.setStack(numpy.arange(24).reshape((2, 3, 4)))
- self.assertEqual(self.stackview._perspective, 0,
- "Default perspective is not 0 (dim1-dim2).")
+ self.assertEqual(
+ self.stackview._perspective, 0, "Default perspective is not 0 (dim1-dim2)."
+ )
self.stackview._StackView__planeSelection.setPerspective(1)
- self.assertEqual(self.stackview._perspective, 1,
- "Plane selection combobox not updating perspective")
+ self.assertEqual(
+ self.stackview._perspective,
+ 1,
+ "Plane selection combobox not updating perspective",
+ )
self.stackview.setStack(numpy.arange(6).reshape((1, 2, 3)))
- self.assertEqual(self.stackview._perspective, 1,
- "Perspective not preserved when calling setStack "
- "without specifying the perspective parameter.")
+ self.assertEqual(
+ self.stackview._perspective,
+ 1,
+ "Perspective not preserved when calling setStack "
+ "without specifying the perspective parameter.",
+ )
self.stackview.setStack(numpy.arange(24).reshape((2, 3, 4)), perspective=2)
- self.assertEqual(self.stackview._perspective, 2,
- "Perspective not set in setStack(..., perspective=2).")
+ self.assertEqual(
+ self.stackview._perspective,
+ 2,
+ "Perspective not set in setStack(..., perspective=2).",
+ )
def testDefaultTitle(self):
"""Test that the plot title contains the proper Z information"""
- self.stackview.setStack(numpy.arange(24).reshape((4, 3, 2)),
- calibrations=[(0, 1), (-10, 10), (3.14, 3.14)])
- self.assertEqual(self.stackview._plot.getGraphTitle(),
- "Image z=0")
+ self.stackview.setStack(
+ numpy.arange(24).reshape((4, 3, 2)),
+ calibrations=[(0, 1), (-10, 10), (3.14, 3.14)],
+ )
+ self.assertEqual(self.stackview._plot.getGraphTitle(), "Image z=0")
self.stackview.setFrameNumber(2)
- self.assertEqual(self.stackview._plot.getGraphTitle(),
- "Image z=2")
+ self.assertEqual(self.stackview._plot.getGraphTitle(), "Image z=2")
self.stackview._StackView__planeSelection.setPerspective(1)
self.stackview.setFrameNumber(0)
- self.assertEqual(self.stackview._plot.getGraphTitle(),
- "Image z=-10")
+ self.assertEqual(self.stackview._plot.getGraphTitle(), "Image z=-10")
self.stackview.setFrameNumber(2)
- self.assertEqual(self.stackview._plot.getGraphTitle(),
- "Image z=10")
+ self.assertEqual(self.stackview._plot.getGraphTitle(), "Image z=10")
self.stackview._StackView__planeSelection.setPerspective(2)
self.stackview.setFrameNumber(0)
- self.assertEqual(self.stackview._plot.getGraphTitle(),
- "Image z=3.14")
+ self.assertEqual(self.stackview._plot.getGraphTitle(), "Image z=3.14")
self.stackview.setFrameNumber(1)
- self.assertEqual(self.stackview._plot.getGraphTitle(),
- "Image z=6.28")
+ self.assertEqual(self.stackview._plot.getGraphTitle(), "Image z=6.28")
def testCustomTitle(self):
"""Test setting the plot title with a user defined callback"""
- self.stackview.setStack(numpy.arange(24).reshape((4, 3, 2)),
- calibrations=[(0, 1), (-10, 10), (3.14, 3.14)])
+ self.stackview.setStack(
+ numpy.arange(24).reshape((4, 3, 2)),
+ calibrations=[(0, 1), (-10, 10), (3.14, 3.14)],
+ )
def title_callback(frame_idx):
return "Cubed index title %d" % (frame_idx**3)
self.stackview.setTitleCallback(title_callback)
- self.assertEqual(self.stackview._plot.getGraphTitle(),
- "Cubed index title 0")
+ self.assertEqual(self.stackview._plot.getGraphTitle(), "Cubed index title 0")
self.stackview.setFrameNumber(2)
- self.assertEqual(self.stackview._plot.getGraphTitle(),
- "Cubed index title 8")
+ self.assertEqual(self.stackview._plot.getGraphTitle(), "Cubed index title 8")
# perspective should not matter, only frame index
self.stackview._StackView__planeSelection.setPerspective(1)
self.stackview.setFrameNumber(0)
- self.assertEqual(self.stackview._plot.getGraphTitle(),
- "Cubed index title 0")
+ self.assertEqual(self.stackview._plot.getGraphTitle(), "Cubed index title 0")
self.stackview.setFrameNumber(2)
- self.assertEqual(self.stackview._plot.getGraphTitle(),
- "Cubed index title 8")
+ self.assertEqual(self.stackview._plot.getGraphTitle(), "Cubed index title 8")
with self.assertRaises(TypeError):
# setTitleCallback should not accept non-callable objects like strings
self.stackview.setTitleCallback(
- "Là, vous faites sirop de vingt-et-un et vous dites : "
- "beau sirop, mi-sirop, siroté, gagne-sirop, sirop-grelot,"
- " passe-montagne, sirop au bon goût.")
+ "Là, vous faites sirop de vingt-et-un et vous dites : "
+ "beau sirop, mi-sirop, siroté, gagne-sirop, sirop-grelot,"
+ " passe-montagne, sirop au bon goût."
+ )
def testStackFrameNumber(self):
self.stackview.setStack(self.mystack)
@@ -217,8 +231,10 @@ class TestStackViewMainWindow(TestCaseQt):
self.stackview.show()
self.qWaitForWindowExposed(self.stackview)
self.mystack = numpy.fromfunction(
- lambda i, j, k: numpy.sin(i/15.) + numpy.cos(j/4.) + 2 * numpy.sin(k/6.),
- (10, 20, 30)
+ lambda i, j, k: numpy.sin(i / 15.0)
+ + numpy.cos(j / 4.0)
+ + 2 * numpy.sin(k / 6.0),
+ (10, 20, 30),
)
def tearDown(self):
@@ -229,19 +245,22 @@ class TestStackViewMainWindow(TestCaseQt):
def testSetStack(self):
self.stackview.setStack(self.mystack)
- self.stackview.setColormap("viridis", autoscale=True)
+ self.stackview.setColormap("viridis")
my_trans_stack, params = self.stackview.getStack()
self.assertEqual(my_trans_stack.shape, self.mystack.shape)
- self.assertTrue(numpy.array_equal(self.mystack,
- my_trans_stack))
- self.assertEqual(params["colormap"]["name"],
- "viridis")
+ self.assertTrue(numpy.array_equal(self.mystack, my_trans_stack))
+ self.assertEqual(params["colormap"]["name"], "viridis")
def testSetStackPerspective(self):
self.stackview.setStack(self.mystack, perspective=1)
my_trans_stack, params = self.stackview.getCurrentView()
# get stack returns the transposed data, depending on the perspective
- self.assertEqual(my_trans_stack.shape,
- (self.mystack.shape[1], self.mystack.shape[0], self.mystack.shape[2]))
- self.assertTrue(numpy.array_equal(numpy.transpose(self.mystack, axes=(1, 0, 2)),
- my_trans_stack))
+ self.assertEqual(
+ my_trans_stack.shape,
+ (self.mystack.shape[1], self.mystack.shape[0], self.mystack.shape[2]),
+ )
+ self.assertTrue(
+ numpy.array_equal(
+ numpy.transpose(self.mystack, axes=(1, 0, 2)), my_trans_stack
+ )
+ )
diff --git a/src/silx/gui/plot/test/testStats.py b/src/silx/gui/plot/test/testStats.py
index c5d5181..2a2793e 100644
--- a/src/silx/gui/plot/test/testStats.py
+++ b/src/silx/gui/plot/test/testStats.py
@@ -34,13 +34,11 @@ from silx.gui.plot import StatsWidget
from silx.gui.plot.stats import statshandler
from silx.gui.utils.testutils import TestCaseQt, SignalListener
from silx.gui.plot import Plot1D, Plot2D
-from silx.gui.plot3d.SceneWidget import SceneWidget
from silx.gui.plot.items.roi import RectangleROI, PolygonROI
-from silx.gui.plot.tools.roi import RegionOfInterestManager
+from silx.gui.plot.tools.roi import RegionOfInterestManager
from silx.gui.plot.stats.stats import Stats
from silx.gui.plot.CurvesROIWidget import ROI
from silx.utils.testutils import ParametricTestCase
-import unittest
import logging
import numpy
@@ -49,6 +47,7 @@ _logger = logging.getLogger(__name__)
class TestStatsBase(object):
"""Base class for stats TestCase"""
+
def setUp(self):
self.createCurveContext()
self.createImageContext()
@@ -69,51 +68,52 @@ class TestStatsBase(object):
self.plot1d = Plot1D()
x = range(20)
y = range(20)
- self.plot1d.addCurve(x, y, legend='curve0')
+ self.plot1d.addCurve(x, y, legend="curve0")
self.curveContext = stats._CurveContext(
- item=self.plot1d.getCurve('curve0'),
+ item=self.plot1d.getCurve("curve0"),
plot=self.plot1d,
onlimits=False,
- roi=None)
+ roi=None,
+ )
def createScatterContext(self):
self.scatterPlot = Plot2D()
- lgd = 'scatter plot'
+ lgd = "scatter plot"
self.xScatterData = numpy.array([0, 2, 3, 20, 50, 60, 36])
self.yScatterData = numpy.array([2, 3, 4, 26, 69, 6, 18])
self.valuesScatterData = numpy.array([5, 6, 7, 10, 90, 20, 5])
- self.scatterPlot.addScatter(self.xScatterData, self.yScatterData,
- self.valuesScatterData, legend=lgd)
+ self.scatterPlot.addScatter(
+ self.xScatterData, self.yScatterData, self.valuesScatterData, legend=lgd
+ )
self.scatterContext = stats._ScatterContext(
item=self.scatterPlot.getScatter(lgd),
plot=self.scatterPlot,
onlimits=False,
- roi=None
+ roi=None,
)
def createImageContext(self):
self.plot2d = Plot2D()
- self._imgLgd = 'test image'
- self.imageData = numpy.arange(32*128).reshape(32, 128)
- self.plot2d.addImage(data=self.imageData,
- legend=self._imgLgd, replace=False)
+ self._imgLgd = "test image"
+ self.imageData = numpy.arange(32 * 128).reshape(32, 128)
+ self.plot2d.addImage(data=self.imageData, legend=self._imgLgd, replace=False)
self.imageContext = stats._ImageContext(
item=self.plot2d.getImage(self._imgLgd),
plot=self.plot2d,
onlimits=False,
- roi=None
+ roi=None,
)
def getBasicStats(self):
return {
- 'min': stats.StatMin(),
- 'minCoords': stats.StatCoordMin(),
- 'max': stats.StatMax(),
- 'maxCoords': stats.StatCoordMax(),
- 'std': stats.Stat(name='std', fct=numpy.std),
- 'mean': stats.Stat(name='mean', fct=numpy.mean),
- 'com': stats.StatCOM()
+ "min": stats.StatMin(),
+ "minCoords": stats.StatCoordMin(),
+ "max": stats.StatMax(),
+ "maxCoords": stats.StatCoordMax(),
+ "std": stats.Stat(name="std", fct=numpy.std),
+ "mean": stats.Stat(name="mean", fct=numpy.mean),
+ "com": stats.StatCOM(),
}
@@ -121,6 +121,7 @@ class TestStats(TestStatsBase, TestCaseQt):
"""
Test :class:`BaseClass` class and inheriting classes
"""
+
def setUp(self):
TestCaseQt.setUp(self)
TestStatsBase.setUp(self)
@@ -133,41 +134,50 @@ class TestStats(TestStatsBase, TestCaseQt):
"""Test result for simple stats on a curve"""
_stats = self.getBasicStats()
xData = yData = numpy.array(range(20))
- self.assertEqual(_stats['min'].calculate(self.curveContext), 0)
- self.assertEqual(_stats['max'].calculate(self.curveContext), 19)
- self.assertEqual(_stats['minCoords'].calculate(self.curveContext), (0,))
- self.assertEqual(_stats['maxCoords'].calculate(self.curveContext), (19,))
- self.assertEqual(_stats['std'].calculate(self.curveContext), numpy.std(yData))
- self.assertEqual(_stats['mean'].calculate(self.curveContext), numpy.mean(yData))
+ self.assertEqual(_stats["min"].calculate(self.curveContext), 0)
+ self.assertEqual(_stats["max"].calculate(self.curveContext), 19)
+ self.assertEqual(_stats["minCoords"].calculate(self.curveContext), (0,))
+ self.assertEqual(_stats["maxCoords"].calculate(self.curveContext), (19,))
+ self.assertEqual(_stats["std"].calculate(self.curveContext), numpy.std(yData))
+ self.assertEqual(_stats["mean"].calculate(self.curveContext), numpy.mean(yData))
com = numpy.sum(xData * yData) / numpy.sum(yData)
- self.assertEqual(_stats['com'].calculate(self.curveContext), com)
+ self.assertEqual(_stats["com"].calculate(self.curveContext), com)
def testBasicStatsImage(self):
"""Test result for simple stats on an image"""
_stats = self.getBasicStats()
- self.assertEqual(_stats['min'].calculate(self.imageContext), 0)
- self.assertEqual(_stats['max'].calculate(self.imageContext), 128 * 32 - 1)
- self.assertEqual(_stats['minCoords'].calculate(self.imageContext), (0, 0))
- self.assertEqual(_stats['maxCoords'].calculate(self.imageContext), (127, 31))
- self.assertEqual(_stats['std'].calculate(self.imageContext), numpy.std(self.imageData))
- self.assertEqual(_stats['mean'].calculate(self.imageContext), numpy.mean(self.imageData))
+ self.assertEqual(_stats["min"].calculate(self.imageContext), 0)
+ self.assertEqual(_stats["max"].calculate(self.imageContext), 128 * 32 - 1)
+ self.assertEqual(_stats["minCoords"].calculate(self.imageContext), (0, 0))
+ self.assertEqual(_stats["maxCoords"].calculate(self.imageContext), (127, 31))
+ self.assertEqual(
+ _stats["std"].calculate(self.imageContext), numpy.std(self.imageData)
+ )
+ self.assertEqual(
+ _stats["mean"].calculate(self.imageContext), numpy.mean(self.imageData)
+ )
yData = numpy.sum(self.imageData.astype(numpy.float64), axis=1)
xData = numpy.sum(self.imageData.astype(numpy.float64), axis=0)
dataXRange = range(self.imageData.shape[1])
dataYRange = range(self.imageData.shape[0])
- ycom = numpy.sum(yData*dataYRange) / numpy.sum(yData)
- xcom = numpy.sum(xData*dataXRange) / numpy.sum(xData)
+ ycom = numpy.sum(yData * dataYRange) / numpy.sum(yData)
+ xcom = numpy.sum(xData * dataXRange) / numpy.sum(xData)
- self.assertEqual(_stats['com'].calculate(self.imageContext), (xcom, ycom))
+ self.assertEqual(_stats["com"].calculate(self.imageContext), (xcom, ycom))
def testStatsImageAdv(self):
"""Test that scale and origin are taking into account for images"""
image2Data = numpy.arange(32 * 128).reshape(32, 128)
- self.plot2d.addImage(data=image2Data, legend=self._imgLgd,
- replace=True, origin=(100, 10), scale=(2, 0.5))
+ self.plot2d.addImage(
+ data=image2Data,
+ legend=self._imgLgd,
+ replace=True,
+ origin=(100, 10),
+ scale=(2, 0.5),
+ )
image2Context = stats._ImageContext(
item=self.plot2d.getImage(self._imgLgd),
plot=self.plot2d,
@@ -175,18 +185,19 @@ class TestStats(TestStatsBase, TestCaseQt):
roi=None,
)
_stats = self.getBasicStats()
- self.assertEqual(_stats['min'].calculate(image2Context), 0)
+ self.assertEqual(_stats["min"].calculate(image2Context), 0)
+ self.assertEqual(_stats["max"].calculate(image2Context), 128 * 32 - 1)
+ self.assertEqual(_stats["minCoords"].calculate(image2Context), (100, 10))
self.assertEqual(
- _stats['max'].calculate(image2Context), 128 * 32 - 1)
+ _stats["maxCoords"].calculate(image2Context),
+ (127 * 2.0 + 100, 31 * 0.5 + 10),
+ )
self.assertEqual(
- _stats['minCoords'].calculate(image2Context), (100, 10))
+ _stats["std"].calculate(image2Context), numpy.std(self.imageData)
+ )
self.assertEqual(
- _stats['maxCoords'].calculate(image2Context), (127*2. + 100,
- 31 * 0.5 + 10))
- self.assertEqual(_stats['std'].calculate(image2Context),
- numpy.std(self.imageData))
- self.assertEqual(_stats['mean'].calculate(image2Context),
- numpy.mean(self.imageData))
+ _stats["mean"].calculate(image2Context), numpy.mean(self.imageData)
+ )
yData = numpy.sum(self.imageData, axis=1)
xData = numpy.sum(self.imageData, axis=0)
@@ -196,30 +207,36 @@ class TestStats(TestStatsBase, TestCaseQt):
ycom = numpy.sum(yData * dataYRange) / numpy.sum(yData)
ycom = (ycom * 0.5) + 10
xcom = numpy.sum(xData * dataXRange) / numpy.sum(xData)
- xcom = (xcom * 2.) + 100
- self.assertTrue(numpy.allclose(
- _stats['com'].calculate(image2Context), (xcom, ycom)))
+ xcom = (xcom * 2.0) + 100
+ self.assertTrue(
+ numpy.allclose(_stats["com"].calculate(image2Context), (xcom, ycom))
+ )
def testBasicStatsScatter(self):
"""Test result for simple stats on a scatter"""
_stats = self.getBasicStats()
- self.assertEqual(_stats['min'].calculate(self.scatterContext), 5)
- self.assertEqual(_stats['max'].calculate(self.scatterContext), 90)
- self.assertEqual(_stats['minCoords'].calculate(self.scatterContext), (0, 2))
- self.assertEqual(_stats['maxCoords'].calculate(self.scatterContext), (50, 69))
- self.assertEqual(_stats['std'].calculate(self.scatterContext), numpy.std(self.valuesScatterData))
- self.assertEqual(_stats['mean'].calculate(self.scatterContext), numpy.mean(self.valuesScatterData))
+ self.assertEqual(_stats["min"].calculate(self.scatterContext), 5)
+ self.assertEqual(_stats["max"].calculate(self.scatterContext), 90)
+ self.assertEqual(_stats["minCoords"].calculate(self.scatterContext), (0, 2))
+ self.assertEqual(_stats["maxCoords"].calculate(self.scatterContext), (50, 69))
+ self.assertEqual(
+ _stats["std"].calculate(self.scatterContext),
+ numpy.std(self.valuesScatterData),
+ )
+ self.assertEqual(
+ _stats["mean"].calculate(self.scatterContext),
+ numpy.mean(self.valuesScatterData),
+ )
data = self.valuesScatterData.astype(numpy.float64)
comx = numpy.sum(self.xScatterData * data) / numpy.sum(data)
comy = numpy.sum(self.yScatterData * data) / numpy.sum(data)
- self.assertEqual(_stats['com'].calculate(self.scatterContext),
- (comx, comy))
+ self.assertEqual(_stats["com"].calculate(self.scatterContext), (comx, comy))
def testKindNotManagedByStat(self):
"""Make sure an exception is raised if we try to execute calculate
of the base class"""
- b = stats.StatBase(name='toto', compatibleKinds='curve')
+ b = stats.StatBase(name="toto", compatibleKinds="curve")
with self.assertRaises(NotImplementedError):
b.calculate(self.imageContext)
@@ -228,7 +245,7 @@ class TestStats(TestStatsBase, TestCaseQt):
Make sure an error is raised if we try to calculate a statistic with
a context not managed
"""
- myStat = stats.Stat(name='toto', fct=numpy.std, kinds=('curve'))
+ myStat = stats.Stat(name="toto", fct=numpy.std, kinds=("curve"))
myStat.calculate(self.curveContext)
with self.assertRaises(ValueError):
myStat.calculate(self.scatterContext)
@@ -240,43 +257,48 @@ class TestStats(TestStatsBase, TestCaseQt):
self.plot1d.getXAxis().setLimitsConstraints(minPos=2, maxPos=5)
curveContextOnLimits = stats._CurveContext(
- item=self.plot1d.getCurve('curve0'),
+ item=self.plot1d.getCurve("curve0"),
plot=self.plot1d,
onlimits=True,
- roi=None)
+ roi=None,
+ )
self.assertEqual(stat.calculate(curveContextOnLimits), 2)
self.plot2d.getXAxis().setLimitsConstraints(minPos=32)
imageContextOnLimits = stats._ImageContext(
- item=self.plot2d.getImage('test image'),
+ item=self.plot2d.getImage("test image"),
plot=self.plot2d,
onlimits=True,
- roi=None)
+ roi=None,
+ )
self.assertEqual(stat.calculate(imageContextOnLimits), 32)
self.scatterPlot.getXAxis().setLimitsConstraints(minPos=40)
scatterContextOnLimits = stats._ScatterContext(
- item=self.scatterPlot.getScatter('scatter plot'),
+ item=self.scatterPlot.getScatter("scatter plot"),
plot=self.scatterPlot,
onlimits=True,
- roi=None)
+ roi=None,
+ )
self.assertEqual(stat.calculate(scatterContextOnLimits), 20)
class TestStatsFormatter(TestCaseQt):
"""Simple test to check usage of the :class:`StatsFormatter`"""
+
def setUp(self):
TestCaseQt.setUp(self)
self.plot1d = Plot1D()
x = range(20)
y = range(20)
- self.plot1d.addCurve(x, y, legend='curve0')
+ self.plot1d.addCurve(x, y, legend="curve0")
self.curveContext = stats._CurveContext(
- item=self.plot1d.getCurve('curve0'),
+ item=self.plot1d.getCurve("curve0"),
plot=self.plot1d,
onlimits=False,
- roi=None)
+ roi=None,
+ )
self.stat = stats.StatMin()
@@ -291,27 +313,30 @@ class TestStatsFormatter(TestCaseQt):
simple cast to str"""
emptyFormatter = statshandler.StatFormatter()
self.assertEqual(
- emptyFormatter.format(self.stat.calculate(self.curveContext)), '0.000')
+ emptyFormatter.format(self.stat.calculate(self.curveContext)), "0.000"
+ )
def testSettedFormatter(self):
"""Make sure a formatter with no formatter definition will return a
simple cast to str"""
- formatter= statshandler.StatFormatter(formatter='{0:.3f}')
+ formatter = statshandler.StatFormatter(formatter="{0:.3f}")
self.assertEqual(
- formatter.format(self.stat.calculate(self.curveContext)), '0.000')
+ formatter.format(self.stat.calculate(self.curveContext)), "0.000"
+ )
class TestStatsHandler(TestCaseQt):
- """Make sure the StatHandler is correctly making the link between
+ """Make sure the StatHandler is correctly making the link between
:class:`StatBase` and :class:`StatFormatter` and checking the API is valid
"""
+
def setUp(self):
TestCaseQt.setUp(self)
self.plot1d = Plot1D()
x = range(20)
y = range(20)
- self.plot1d.addCurve(x, y, legend='curve0')
- self.curveItem = self.plot1d.getCurve('curve0')
+ self.plot1d.addCurve(x, y, legend="curve0")
+ self.curveItem = self.plot1d.getCurve("curve0")
self.stat = stats.StatMin()
@@ -324,91 +349,94 @@ class TestStatsHandler(TestCaseQt):
def testConstructor(self):
"""Make sure the constructor can deal will all possible arguments:
-
+
* tuple of :class:`StatBase` derivated classes
* tuple of tuples (:class:`StatBase`, :class:`StatFormatter`)
* tuple of tuples (str, pointer to function, kind)
"""
- handler0 = statshandler.StatsHandler(
- (stats.StatMin(), stats.StatMax())
- )
+ handler0 = statshandler.StatsHandler((stats.StatMin(), stats.StatMax()))
- res = handler0.calculate(item=self.curveItem, plot=self.plot1d,
- onlimits=False)
- self.assertTrue('min' in res)
- self.assertEqual(res['min'], '0')
- self.assertTrue('max' in res)
- self.assertEqual(res['max'], '19')
+ res = handler0.calculate(item=self.curveItem, plot=self.plot1d, onlimits=False)
+ self.assertTrue("min" in res)
+ self.assertEqual(res["min"], "0")
+ self.assertTrue("max" in res)
+ self.assertEqual(res["max"], "19")
handler1 = statshandler.StatsHandler(
(
(stats.StatMin(), statshandler.StatFormatter(formatter=None)),
- (stats.StatMax(), statshandler.StatFormatter())
+ (stats.StatMax(), statshandler.StatFormatter()),
)
)
- res = handler1.calculate(item=self.curveItem, plot=self.plot1d,
- onlimits=False)
- self.assertTrue('min' in res)
- self.assertEqual(res['min'], '0')
- self.assertTrue('max' in res)
- self.assertEqual(res['max'], '19.000')
+ res = handler1.calculate(item=self.curveItem, plot=self.plot1d, onlimits=False)
+ self.assertTrue("min" in res)
+ self.assertEqual(res["min"], "0")
+ self.assertTrue("max" in res)
+ self.assertEqual(res["max"], "19.000")
handler2 = statshandler.StatsHandler(
+ ((stats.StatMin(), None), (stats.StatMax(), statshandler.StatFormatter()))
+ )
+
+ res = handler2.calculate(item=self.curveItem, plot=self.plot1d, onlimits=False)
+ self.assertTrue("min" in res)
+ self.assertEqual(res["min"], "0")
+ self.assertTrue("max" in res)
+ self.assertEqual(res["max"], "19.000")
+
+ handler3 = statshandler.StatsHandler(
(
- (stats.StatMin(), None),
- (stats.StatMax(), statshandler.StatFormatter())
- ))
-
- res = handler2.calculate(item=self.curveItem, plot=self.plot1d,
- onlimits=False)
- self.assertTrue('min' in res)
- self.assertEqual(res['min'], '0')
- self.assertTrue('max' in res)
- self.assertEqual(res['max'], '19.000')
-
- handler3 = statshandler.StatsHandler((
- (('amin', numpy.argmin), statshandler.StatFormatter()),
- ('amax', numpy.argmax)
- ))
-
- res = handler3.calculate(item=self.curveItem, plot=self.plot1d,
- onlimits=False)
- self.assertTrue('amin' in res)
- self.assertEqual(res['amin'], '0.000')
- self.assertTrue('amax' in res)
- self.assertEqual(res['amax'], '19')
+ (("amin", numpy.argmin), statshandler.StatFormatter()),
+ ("amax", numpy.argmax),
+ )
+ )
+
+ res = handler3.calculate(item=self.curveItem, plot=self.plot1d, onlimits=False)
+ self.assertTrue("amin" in res)
+ self.assertEqual(res["amin"], "0.000")
+ self.assertTrue("amax" in res)
+ self.assertEqual(res["amax"], "19")
with self.assertRaises(ValueError):
- statshandler.StatsHandler(('name'))
+ statshandler.StatsHandler(("name"))
class TestStatsWidgetWithCurves(TestCaseQt, ParametricTestCase):
"""Basic test for StatsWidget with curves"""
+
def setUp(self):
TestCaseQt.setUp(self)
self.plot = Plot1D()
self.plot.show()
x = range(20)
y = range(20)
- self.plot.addCurve(x, y, legend='curve0')
+ self.plot.addCurve(x, y, legend="curve0")
y = range(12, 32)
- self.plot.addCurve(x, y, legend='curve1')
+ self.plot.addCurve(x, y, legend="curve1")
y = range(-2, 18)
- self.plot.addCurve(x, y, legend='curve2')
+ self.plot.addCurve(x, y, legend="curve2")
self.widget = StatsWidget.StatsWidget(plot=self.plot)
self.statsTable = self.widget._statsTable
- mystats = statshandler.StatsHandler((
- stats.StatMin(),
- (stats.StatCoordMin(), statshandler.StatFormatter(None, qt.QTableWidgetItem)),
- stats.StatMax(),
- (stats.StatCoordMax(), statshandler.StatFormatter(None, qt.QTableWidgetItem)),
- stats.StatDelta(),
- ('std', numpy.std),
- ('mean', numpy.mean),
- stats.StatCOM()
- ))
+ mystats = statshandler.StatsHandler(
+ (
+ stats.StatMin(),
+ (
+ stats.StatCoordMin(),
+ statshandler.StatFormatter(None, qt.QTableWidgetItem),
+ ),
+ stats.StatMax(),
+ (
+ stats.StatCoordMax(),
+ statshandler.StatFormatter(None, qt.QTableWidgetItem),
+ ),
+ stats.StatDelta(),
+ ("std", numpy.std),
+ ("mean", numpy.mean),
+ stats.StatCOM(),
+ )
+ )
self.statsTable.setStats(mystats)
@@ -456,42 +484,44 @@ class TestStatsWidgetWithCurves(TestCaseQt, ParametricTestCase):
def testRemoveCurve(self):
"""Make sure the Curves stats take into account the curve removal from
plot"""
- self.plot.removeCurve('curve2')
+ self.plot.removeCurve("curve2")
self.assertEqual(self.statsTable.rowCount(), 2)
for iRow in range(2):
- self.assertTrue(self.statsTable.item(iRow, 0).text() in ('curve0', 'curve1'))
+ self.assertTrue(
+ self.statsTable.item(iRow, 0).text() in ("curve0", "curve1")
+ )
- self.plot.removeCurve('curve0')
+ self.plot.removeCurve("curve0")
self.assertEqual(self.statsTable.rowCount(), 1)
- self.plot.removeCurve('curve1')
+ self.plot.removeCurve("curve1")
self.assertEqual(self.statsTable.rowCount(), 0)
def testAddCurve(self):
"""Make sure the Curves stats take into account the add curve action"""
- self.plot.addCurve(legend='curve3', x=range(10), y=range(10))
+ self.plot.addCurve(legend="curve3", x=range(10), y=range(10))
self.assertEqual(self.statsTable.rowCount(), 4)
def testUpdateCurveFromAddCurve(self):
"""Make sure the stats of the cuve will be removed after updating a
curve"""
- self.plot.addCurve(legend='curve0', x=range(10), y=range(10))
+ self.plot.addCurve(legend="curve0", x=range(10), y=range(10))
self.qapp.processEvents()
self.assertEqual(self.statsTable.rowCount(), 3)
- curve = self.plot._getItem(kind='curve', legend='curve0')
+ curve = self.plot._getItem(kind="curve", legend="curve0")
tableItems = self.statsTable._itemToTableItems(curve)
- self.assertEqual(tableItems['max'].text(), '9')
+ self.assertEqual(tableItems["max"].text(), "9")
def testUpdateCurveFromCurveObj(self):
- self.plot.getCurve('curve0').setData(x=range(4), y=range(4))
+ self.plot.getCurve("curve0").setData(x=range(4), y=range(4))
self.qapp.processEvents()
self.assertEqual(self.statsTable.rowCount(), 3)
- curve = self.plot._getItem(kind='curve', legend='curve0')
+ curve = self.plot._getItem(kind="curve", legend="curve0")
tableItems = self.statsTable._itemToTableItems(curve)
- self.assertEqual(tableItems['max'].text(), '3')
+ self.assertEqual(tableItems["max"].text(), "3")
def testSetAnotherPlot(self):
plot2 = Plot1D()
- plot2.addCurve(x=range(26), y=range(26), legend='new curve')
+ plot2.addCurve(x=range(26), y=range(26), legend="new curve")
self.statsTable.setPlot(plot2)
self.assertEqual(self.statsTable.rowCount(), 1)
self.qapp.processEvents()
@@ -501,50 +531,62 @@ class TestStatsWidgetWithCurves(TestCaseQt, ParametricTestCase):
def testUpdateMode(self):
"""Make sure the update modes are well take into account"""
- self.plot.setActiveCurve('curve0')
+ self.plot.setActiveCurve("curve0")
for display_only_active in (True, False):
with self.subTest(display_only_active=display_only_active):
self.widget.setDisplayOnlyActiveItem(display_only_active)
- self.plot.getCurve('curve0').setData(x=range(4), y=range(4))
+ self.plot.getCurve("curve0").setData(x=range(4), y=range(4))
self.widget.setUpdateMode(StatsWidget.UpdateMode.AUTO)
update_stats_action = self.widget._options.getUpdateStatsAction()
# test from api
- self.assertEqual(self.widget.getUpdateMode(), StatsWidget.UpdateMode.AUTO)
+ self.assertEqual(
+ self.widget.getUpdateMode(), StatsWidget.UpdateMode.AUTO
+ )
self.widget.show()
# check stats change in auto mode
- self.plot.getCurve('curve0').setData(x=range(4), y=range(-1, 3))
+ self.plot.getCurve("curve0").setData(x=range(4), y=range(-1, 3))
self.qapp.processEvents()
- tableItems = self.statsTable._itemToTableItems(self.plot.getCurve('curve0'))
- curve0_min = tableItems['min'].text()
- self.assertTrue(float(curve0_min) == -1.)
+ tableItems = self.statsTable._itemToTableItems(
+ self.plot.getCurve("curve0")
+ )
+ curve0_min = tableItems["min"].text()
+ self.assertTrue(float(curve0_min) == -1.0)
- self.plot.getCurve('curve0').setData(x=range(4), y=range(1, 5))
+ self.plot.getCurve("curve0").setData(x=range(4), y=range(1, 5))
self.qapp.processEvents()
- tableItems = self.statsTable._itemToTableItems(self.plot.getCurve('curve0'))
- curve0_min = tableItems['min'].text()
- self.assertTrue(float(curve0_min) == 1.)
+ tableItems = self.statsTable._itemToTableItems(
+ self.plot.getCurve("curve0")
+ )
+ curve0_min = tableItems["min"].text()
+ self.assertTrue(float(curve0_min) == 1.0)
# check stats change in manual mode only if requested
self.widget.setUpdateMode(StatsWidget.UpdateMode.MANUAL)
- self.assertEqual(self.widget.getUpdateMode(), StatsWidget.UpdateMode.MANUAL)
+ self.assertEqual(
+ self.widget.getUpdateMode(), StatsWidget.UpdateMode.MANUAL
+ )
- self.plot.getCurve('curve0').setData(x=range(4), y=range(2, 6))
+ self.plot.getCurve("curve0").setData(x=range(4), y=range(2, 6))
self.qapp.processEvents()
- tableItems = self.statsTable._itemToTableItems(self.plot.getCurve('curve0'))
- curve0_min = tableItems['min'].text()
- self.assertTrue(float(curve0_min) == 1.)
+ tableItems = self.statsTable._itemToTableItems(
+ self.plot.getCurve("curve0")
+ )
+ curve0_min = tableItems["min"].text()
+ self.assertTrue(float(curve0_min) == 1.0)
update_stats_action.trigger()
- tableItems = self.statsTable._itemToTableItems(self.plot.getCurve('curve0'))
- curve0_min = tableItems['min'].text()
- self.assertTrue(float(curve0_min) == 2.)
+ tableItems = self.statsTable._itemToTableItems(
+ self.plot.getCurve("curve0")
+ )
+ curve0_min = tableItems["min"].text()
+ self.assertTrue(float(curve0_min) == 2.0)
def testItemHidden(self):
"""Test if an item is hide, then the associated stats item is also
hide"""
- curve0 = self.plot.getCurve('curve0')
- curve1 = self.plot.getCurve('curve1')
- curve2 = self.plot.getCurve('curve2')
+ curve0 = self.plot.getCurve("curve0")
+ curve1 = self.plot.getCurve("curve1")
+ curve2 = self.plot.getCurve("curve2")
self.plot.show()
self.widget.show()
@@ -563,8 +605,8 @@ class TestStatsWidgetWithCurves(TestCaseQt, ParametricTestCase):
self.qapp.processEvents()
self.assertTrue(self.statsTable.isRowHidden(1))
tableItems = self.statsTable._itemToTableItems(curve2)
- curve2_min = tableItems['min'].text()
- self.assertTrue(float(curve2_min) == -2.)
+ curve2_min = tableItems["min"].text()
+ self.assertTrue(float(curve2_min) == -2.0)
curve0.setVisible(False)
curve1.setVisible(False)
@@ -578,27 +620,38 @@ class TestStatsWidgetWithCurves(TestCaseQt, ParametricTestCase):
class TestStatsWidgetWithImages(TestCaseQt):
"""Basic test for StatsWidget with images"""
- IMAGE_LEGEND = 'test image'
+ IMAGE_LEGEND = "test image"
def setUp(self):
TestCaseQt.setUp(self)
self.plot = Plot2D()
- self.plot.addImage(data=numpy.arange(128*128).reshape(128, 128),
- legend=self.IMAGE_LEGEND, replace=False)
+ self.plot.addImage(
+ data=numpy.arange(128 * 128).reshape(128, 128),
+ legend=self.IMAGE_LEGEND,
+ replace=False,
+ )
self.widget = StatsWidget.StatsTable(plot=self.plot)
- mystats = statshandler.StatsHandler((
- (stats.StatMin(), statshandler.StatFormatter()),
- (stats.StatCoordMin(), statshandler.StatFormatter(None, qt.QTableWidgetItem)),
- (stats.StatMax(), statshandler.StatFormatter()),
- (stats.StatCoordMax(), statshandler.StatFormatter(None, qt.QTableWidgetItem)),
- (stats.StatDelta(), statshandler.StatFormatter()),
- ('std', numpy.std),
- ('mean', numpy.mean),
- (stats.StatCOM(), statshandler.StatFormatter(None))
- ))
+ mystats = statshandler.StatsHandler(
+ (
+ (stats.StatMin(), statshandler.StatFormatter()),
+ (
+ stats.StatCoordMin(),
+ statshandler.StatFormatter(None, qt.QTableWidgetItem),
+ ),
+ (stats.StatMax(), statshandler.StatFormatter()),
+ (
+ stats.StatCoordMax(),
+ statshandler.StatFormatter(None, qt.QTableWidgetItem),
+ ),
+ (stats.StatDelta(), statshandler.StatFormatter()),
+ ("std", numpy.std),
+ ("mean", numpy.mean),
+ (stats.StatCOM(), statshandler.StatFormatter(None)),
+ )
+ )
self.widget.setStats(mystats)
@@ -613,17 +666,16 @@ class TestStatsWidgetWithImages(TestCaseQt):
TestCaseQt.tearDown(self)
def test(self):
- image = self.plot._getItem(
- kind='image', legend=self.IMAGE_LEGEND)
+ image = self.plot._getItem(kind="image", legend=self.IMAGE_LEGEND)
tableItems = self.widget._itemToTableItems(image)
- maxText = '{0:.3f}'.format((128 * 128) - 1)
- self.assertEqual(tableItems['legend'].text(), self.IMAGE_LEGEND)
- self.assertEqual(tableItems['min'].text(), '0.000')
- self.assertEqual(tableItems['max'].text(), maxText)
- self.assertEqual(tableItems['delta'].text(), maxText)
- self.assertEqual(tableItems['coords min'].text(), '0.0, 0.0')
- self.assertEqual(tableItems['coords max'].text(), '127.0, 127.0')
+ maxText = "{0:.3f}".format((128 * 128) - 1)
+ self.assertEqual(tableItems["legend"].text(), self.IMAGE_LEGEND)
+ self.assertEqual(tableItems["min"].text(), "0.000")
+ self.assertEqual(tableItems["max"].text(), maxText)
+ self.assertEqual(tableItems["delta"].text(), maxText)
+ self.assertEqual(tableItems["coords min"].text(), "0.0, 0.0")
+ self.assertEqual(tableItems["coords max"].text(), "127.0, 127.0")
def testItemHidden(self):
"""Test if an item is hide, then the associated stats item is also
@@ -638,28 +690,37 @@ class TestStatsWidgetWithImages(TestCaseQt):
class TestStatsWidgetWithScatters(TestCaseQt):
-
- SCATTER_LEGEND = 'scatter plot'
+ SCATTER_LEGEND = "scatter plot"
def setUp(self):
TestCaseQt.setUp(self)
self.scatterPlot = Plot2D()
- self.scatterPlot.addScatter([0, 1, 2, 20, 50, 60],
- [2, 3, 4, 26, 69, 6],
- [5, 6, 7, 10, 90, 20],
- legend=self.SCATTER_LEGEND)
+ self.scatterPlot.addScatter(
+ [0, 1, 2, 20, 50, 60],
+ [2, 3, 4, 26, 69, 6],
+ [5, 6, 7, 10, 90, 20],
+ legend=self.SCATTER_LEGEND,
+ )
self.widget = StatsWidget.StatsTable(plot=self.scatterPlot)
- mystats = statshandler.StatsHandler((
- stats.StatMin(),
- (stats.StatCoordMin(), statshandler.StatFormatter(None, qt.QTableWidgetItem)),
- stats.StatMax(),
- (stats.StatCoordMax(), statshandler.StatFormatter(None, qt.QTableWidgetItem)),
- stats.StatDelta(),
- ('std', numpy.std),
- ('mean', numpy.mean),
- stats.StatCOM()
- ))
+ mystats = statshandler.StatsHandler(
+ (
+ stats.StatMin(),
+ (
+ stats.StatCoordMin(),
+ statshandler.StatFormatter(None, qt.QTableWidgetItem),
+ ),
+ stats.StatMax(),
+ (
+ stats.StatCoordMax(),
+ statshandler.StatFormatter(None, qt.QTableWidgetItem),
+ ),
+ stats.StatDelta(),
+ ("std", numpy.std),
+ ("mean", numpy.mean),
+ stats.StatCOM(),
+ )
+ )
self.widget.setStats(mystats)
@@ -674,15 +735,14 @@ class TestStatsWidgetWithScatters(TestCaseQt):
TestCaseQt.tearDown(self)
def testStats(self):
- scatter = self.scatterPlot._getItem(
- kind='scatter', legend=self.SCATTER_LEGEND)
+ scatter = self.scatterPlot._getItem(kind="scatter", legend=self.SCATTER_LEGEND)
tableItems = self.widget._itemToTableItems(scatter)
- self.assertEqual(tableItems['legend'].text(), self.SCATTER_LEGEND)
- self.assertEqual(tableItems['min'].text(), '5')
- self.assertEqual(tableItems['coords min'].text(), '0, 2')
- self.assertEqual(tableItems['max'].text(), '90')
- self.assertEqual(tableItems['coords max'].text(), '50, 69')
- self.assertEqual(tableItems['delta'].text(), '85')
+ self.assertEqual(tableItems["legend"].text(), self.SCATTER_LEGEND)
+ self.assertEqual(tableItems["min"].text(), "5")
+ self.assertEqual(tableItems["coords min"].text(), "0, 2")
+ self.assertEqual(tableItems["max"].text(), "90")
+ self.assertEqual(tableItems["coords max"].text(), "50, 69")
+ self.assertEqual(tableItems["delta"].text(), "85")
class TestEmptyStatsWidget(TestCaseQt):
@@ -694,25 +754,26 @@ class TestEmptyStatsWidget(TestCaseQt):
class TestLineWidget(TestCaseQt):
"""Some test for the StatsLineWidget."""
+
def setUp(self):
TestCaseQt.setUp(self)
- mystats = statshandler.StatsHandler((
- (stats.StatMin(), statshandler.StatFormatter()),
- ))
+ mystats = statshandler.StatsHandler(
+ ((stats.StatMin(), statshandler.StatFormatter()),)
+ )
self.plot = Plot1D()
self.plot.show()
self.x = range(20)
self.y0 = range(20)
- self.curve0 = self.plot.addCurve(self.x, self.y0, legend='curve0')
+ self.plot.addCurve(self.x, self.y0, legend="curve0")
self.y1 = range(12, 32)
- self.plot.addCurve(self.x, self.y1, legend='curve1')
+ self.plot.addCurve(self.x, self.y1, legend="curve1")
self.y2 = range(-2, 18)
- self.plot.addCurve(self.x, self.y2, legend='curve2')
- self.widget = StatsWidget.BasicGridStatsWidget(plot=self.plot,
- kind='curve',
- stats=mystats)
+ self.plot.addCurve(self.x, self.y2, legend="curve2")
+ self.widget = StatsWidget.BasicGridStatsWidget(
+ plot=self.plot, kind="curve", stats=mystats
+ )
def tearDown(self):
Stats._getContext.cache_clear()
@@ -730,27 +791,37 @@ class TestLineWidget(TestCaseQt):
def testProcessing(self):
self.widget._lineStatsWidget.setStatsOnVisibleData(False)
self.qapp.processEvents()
- self.plot.setActiveCurve(legend='curve0')
- self.assertTrue(self.widget._lineStatsWidget._statQlineEdit['min'].text() == '0.000')
- self.plot.setActiveCurve(legend='curve1')
- self.assertTrue(self.widget._lineStatsWidget._statQlineEdit['min'].text() == '12.000')
+ self.plot.setActiveCurve(legend="curve0")
+ self.assertTrue(
+ self.widget._lineStatsWidget._statQlineEdit["min"].text() == "0.000"
+ )
+ self.plot.setActiveCurve(legend="curve1")
+ self.assertTrue(
+ self.widget._lineStatsWidget._statQlineEdit["min"].text() == "12.000"
+ )
self.plot.getXAxis().setLimitsConstraints(minPos=2, maxPos=5)
self.widget.setStatsOnVisibleData(True)
self.qapp.processEvents()
- self.assertTrue(self.widget._lineStatsWidget._statQlineEdit['min'].text() == '14.000')
+ self.assertTrue(
+ self.widget._lineStatsWidget._statQlineEdit["min"].text() == "14.000"
+ )
self.plot.setActiveCurve(None)
self.assertIsNone(self.plot.getActiveCurve())
self.widget.setStatsOnVisibleData(False)
self.qapp.processEvents()
- self.assertFalse(self.widget._lineStatsWidget._statQlineEdit['min'].text() == '14.000')
- self.widget.setKind('image')
- self.plot.addImage(numpy.arange(100*100).reshape(100, 100) + 0.312)
+ self.assertFalse(
+ self.widget._lineStatsWidget._statQlineEdit["min"].text() == "14.000"
+ )
+ self.widget.setKind("image")
+ self.plot.addImage(numpy.arange(100 * 100).reshape(100, 100) + 0.312)
self.qapp.processEvents()
- self.assertTrue(self.widget._lineStatsWidget._statQlineEdit['min'].text() == '0.312')
+ self.assertTrue(
+ self.widget._lineStatsWidget._statQlineEdit["min"].text() == "0.312"
+ )
def testUpdateMode(self):
"""Make sure the update modes are well take into account"""
- self.plot.setActiveCurve(self.curve0)
+ self.plot.setActiveCurve("curve0")
_autoRB = self.widget._options._autoRB
_manualRB = self.widget._options._manualRB
# test from api
@@ -759,10 +830,10 @@ class TestLineWidget(TestCaseQt):
self.assertFalse(_manualRB.isChecked())
# check stats change in auto mode
- curve0_min = self.widget._lineStatsWidget._statQlineEdit['min'].text()
+ curve0_min = self.widget._lineStatsWidget._statQlineEdit["min"].text()
new_y = numpy.array(self.y0) - 2.56
- self.plot.addCurve(x=self.x, y=new_y, legend=self.curve0)
- curve0_min2 = self.widget._lineStatsWidget._statQlineEdit['min'].text()
+ self.plot.addCurve(x=self.x, y=new_y, legend="curve0")
+ curve0_min2 = self.widget._lineStatsWidget._statQlineEdit["min"].text()
self.assertTrue(curve0_min != curve0_min2)
# check stats change in manual mode only if requested
@@ -771,11 +842,11 @@ class TestLineWidget(TestCaseQt):
self.assertTrue(_manualRB.isChecked())
new_y = numpy.array(self.y0) - 1.2
- self.plot.addCurve(x=self.x, y=new_y, legend=self.curve0)
- curve0_min3 = self.widget._lineStatsWidget._statQlineEdit['min'].text()
+ self.plot.addCurve(x=self.x, y=new_y, legend="curve0")
+ curve0_min3 = self.widget._lineStatsWidget._statQlineEdit["min"].text()
self.assertTrue(curve0_min3 == curve0_min2)
self.widget._options._updateRequested()
- curve0_min3 = self.widget._lineStatsWidget._statQlineEdit['min'].text()
+ curve0_min3 = self.widget._lineStatsWidget._statQlineEdit["min"].text()
self.assertTrue(curve0_min3 != curve0_min2)
# test from gui
@@ -791,6 +862,7 @@ class TestLineWidget(TestCaseQt):
class TestUpdateModeWidget(TestCaseQt):
"""Test UpdateModeWidget"""
+
def setUp(self):
TestCaseQt.setUp(self)
self.widget = StatsWidget.UpdateModeWidget(parent=None)
@@ -832,6 +904,7 @@ class TestStatsROI(TestStatsBase, TestCaseQt):
"""
Test stats based on ROI
"""
+
def setUp(self):
TestCaseQt.setUp(self)
self.createRois()
@@ -855,7 +928,7 @@ class TestStatsROI(TestStatsBase, TestCaseQt):
TestCaseQt.tearDown(self)
def createRois(self):
- self._1Droi = ROI(name='my1DRoi', fromdata=2.0, todata=5.0)
+ self._1Droi = ROI(name="my1DRoi", fromdata=2.0, todata=5.0)
self._2Droi_rect = RectangleROI()
self._2Droi_rect.setGeometry(size=(10, 10), origin=(10, 0))
self._2Droi_poly = PolygonROI()
@@ -865,30 +938,32 @@ class TestStatsROI(TestStatsBase, TestCaseQt):
def createCurveContext(self):
TestStatsBase.createCurveContext(self)
self.curveContext = stats._CurveContext(
- item=self.plot1d.getCurve('curve0'),
+ item=self.plot1d.getCurve("curve0"),
plot=self.plot1d,
onlimits=False,
- roi=self._1Droi)
+ roi=self._1Droi,
+ )
def createHistogramContext(self):
self.plotHisto = Plot1D()
x = range(20)
y = range(20)
- self.plotHisto.addHistogram(x, y, legend='histo0')
+ self.plotHisto.addHistogram(x, y, legend="histo0")
self.histoContext = stats._HistogramContext(
- item=self.plotHisto.getHistogram('histo0'),
+ item=self.plotHisto.getHistogram("histo0"),
plot=self.plotHisto,
onlimits=False,
- roi=self._1Droi)
+ roi=self._1Droi,
+ )
def createScatterContext(self):
TestStatsBase.createScatterContext(self)
self.scatterContext = stats._ScatterContext(
- item=self.scatterPlot.getScatter('scatter plot'),
+ item=self.scatterPlot.getScatter("scatter plot"),
plot=self.scatterPlot,
onlimits=False,
- roi=self._1Droi
+ roi=self._1Droi,
)
def createImageContext(self):
@@ -898,56 +973,68 @@ class TestStatsROI(TestStatsBase, TestCaseQt):
item=self.plot2d.getImage(self._imgLgd),
plot=self.plot2d,
onlimits=False,
- roi=self._2Droi_rect
+ roi=self._2Droi_rect,
)
self.imageContext_2 = stats._ImageContext(
item=self.plot2d.getImage(self._imgLgd),
plot=self.plot2d,
onlimits=False,
- roi=self._2Droi_poly
+ roi=self._2Droi_poly,
)
def testErrors(self):
# test if onlimits is True and give also a roi
with self.assertRaises(ValueError):
- stats._CurveContext(item=self.plot1d.getCurve('curve0'),
- plot=self.plot1d,
- onlimits=True,
- roi=self._1Droi)
+ stats._CurveContext(
+ item=self.plot1d.getCurve("curve0"),
+ plot=self.plot1d,
+ onlimits=True,
+ roi=self._1Droi,
+ )
# test if is a curve context and give an invalid 2D roi
with self.assertRaises(TypeError):
- stats._CurveContext(item=self.plot1d.getCurve('curve0'),
- plot=self.plot1d,
- onlimits=False,
- roi=self._2Droi_rect)
+ stats._CurveContext(
+ item=self.plot1d.getCurve("curve0"),
+ plot=self.plot1d,
+ onlimits=False,
+ roi=self._2Droi_rect,
+ )
def testBasicStatsCurve(self):
"""Test result for simple stats on a curve"""
_stats = self.getBasicStats()
xData = yData = numpy.array(range(0, 10))
- self.assertEqual(_stats['min'].calculate(self.curveContext), 2)
- self.assertEqual(_stats['max'].calculate(self.curveContext), 5)
- self.assertEqual(_stats['minCoords'].calculate(self.curveContext), (2,))
- self.assertEqual(_stats['maxCoords'].calculate(self.curveContext), (5,))
- self.assertEqual(_stats['std'].calculate(self.curveContext), numpy.std(yData[2:6]))
- self.assertEqual(_stats['mean'].calculate(self.curveContext), numpy.mean(yData[2:6]))
+ self.assertEqual(_stats["min"].calculate(self.curveContext), 2)
+ self.assertEqual(_stats["max"].calculate(self.curveContext), 5)
+ self.assertEqual(_stats["minCoords"].calculate(self.curveContext), (2,))
+ self.assertEqual(_stats["maxCoords"].calculate(self.curveContext), (5,))
+ self.assertEqual(
+ _stats["std"].calculate(self.curveContext), numpy.std(yData[2:6])
+ )
+ self.assertEqual(
+ _stats["mean"].calculate(self.curveContext), numpy.mean(yData[2:6])
+ )
com = numpy.sum(xData[2:6] * yData[2:6]) / numpy.sum(yData[2:6])
- self.assertEqual(_stats['com'].calculate(self.curveContext), com)
+ self.assertEqual(_stats["com"].calculate(self.curveContext), com)
def testBasicStatsImageRectRoi(self):
"""Test result for simple stats on an image"""
self.assertEqual(self.imageContext.values.compressed().size, 121)
_stats = self.getBasicStats()
- self.assertEqual(_stats['min'].calculate(self.imageContext), 10)
- self.assertEqual(_stats['max'].calculate(self.imageContext), 1300)
- self.assertEqual(_stats['minCoords'].calculate(self.imageContext), (10, 0))
- self.assertEqual(_stats['maxCoords'].calculate(self.imageContext), (20.0, 10.0))
- self.assertAlmostEqual(_stats['std'].calculate(self.imageContext),
- numpy.std(self.imageData[0:11, 10:21]))
- self.assertAlmostEqual(_stats['mean'].calculate(self.imageContext),
- numpy.mean(self.imageData[0:11, 10:21]))
+ self.assertEqual(_stats["min"].calculate(self.imageContext), 10)
+ self.assertEqual(_stats["max"].calculate(self.imageContext), 1300)
+ self.assertEqual(_stats["minCoords"].calculate(self.imageContext), (10, 0))
+ self.assertEqual(_stats["maxCoords"].calculate(self.imageContext), (20.0, 10.0))
+ self.assertAlmostEqual(
+ _stats["std"].calculate(self.imageContext),
+ numpy.std(self.imageData[0:11, 10:21]),
+ )
+ self.assertAlmostEqual(
+ _stats["mean"].calculate(self.imageContext),
+ numpy.mean(self.imageData[0:11, 10:21]),
+ )
compressed_values = self.imageContext.values.compressed()
compressed_values = compressed_values.reshape(11, 11)
@@ -957,41 +1044,47 @@ class TestStatsROI(TestStatsBase, TestCaseQt):
dataYRange = range(11)
dataXRange = range(10, 21)
- ycom = numpy.sum(yData*dataYRange) / numpy.sum(yData)
- xcom = numpy.sum(xData*dataXRange) / numpy.sum(xData)
- self.assertEqual(_stats['com'].calculate(self.imageContext), (xcom, ycom))
+ ycom = numpy.sum(yData * dataYRange) / numpy.sum(yData)
+ xcom = numpy.sum(xData * dataXRange) / numpy.sum(xData)
+ self.assertEqual(_stats["com"].calculate(self.imageContext), (xcom, ycom))
def testBasicStatsImagePolyRoi(self):
"""Test a simple rectangle ROI"""
_stats = self.getBasicStats()
- self.assertEqual(_stats['min'].calculate(self.imageContext_2), 0)
- self.assertEqual(_stats['max'].calculate(self.imageContext_2), 2432)
- self.assertEqual(_stats['minCoords'].calculate(self.imageContext_2), (0.0, 0.0))
+ self.assertEqual(_stats["min"].calculate(self.imageContext_2), 0)
+ self.assertEqual(_stats["max"].calculate(self.imageContext_2), 2432)
+ self.assertEqual(_stats["minCoords"].calculate(self.imageContext_2), (0.0, 0.0))
# not 0.0, 19.0 because not fully in. Should all pixel have a weight,
# on to manage them in stats. For now 0 if the center is not in, else 1
- self.assertEqual(_stats['maxCoords'].calculate(self.imageContext_2), (0.0, 19.0))
+ self.assertEqual(
+ _stats["maxCoords"].calculate(self.imageContext_2), (0.0, 19.0)
+ )
def testBasicStatsScatter(self):
self.assertEqual(self.scatterContext.values.compressed().size, 2)
_stats = self.getBasicStats()
- self.assertEqual(_stats['min'].calculate(self.scatterContext), 6)
- self.assertEqual(_stats['max'].calculate(self.scatterContext), 7)
- self.assertEqual(_stats['minCoords'].calculate(self.scatterContext), (2, 3))
- self.assertEqual(_stats['maxCoords'].calculate(self.scatterContext), (3, 4))
- self.assertEqual(_stats['std'].calculate(self.scatterContext), numpy.std([6, 7]))
- self.assertEqual(_stats['mean'].calculate(self.scatterContext), numpy.mean([6, 7]))
+ self.assertEqual(_stats["min"].calculate(self.scatterContext), 6)
+ self.assertEqual(_stats["max"].calculate(self.scatterContext), 7)
+ self.assertEqual(_stats["minCoords"].calculate(self.scatterContext), (2, 3))
+ self.assertEqual(_stats["maxCoords"].calculate(self.scatterContext), (3, 4))
+ self.assertEqual(
+ _stats["std"].calculate(self.scatterContext), numpy.std([6, 7])
+ )
+ self.assertEqual(
+ _stats["mean"].calculate(self.scatterContext), numpy.mean([6, 7])
+ )
def testBasicHistogram(self):
_stats = self.getBasicStats()
xData = yData = numpy.array(range(2, 6))
- self.assertEqual(_stats['min'].calculate(self.histoContext), 2)
- self.assertEqual(_stats['max'].calculate(self.histoContext), 5)
- self.assertEqual(_stats['minCoords'].calculate(self.histoContext), (2,))
- self.assertEqual(_stats['maxCoords'].calculate(self.histoContext), (5,))
- self.assertEqual(_stats['std'].calculate(self.histoContext), numpy.std(yData))
- self.assertEqual(_stats['mean'].calculate(self.histoContext), numpy.mean(yData))
+ self.assertEqual(_stats["min"].calculate(self.histoContext), 2)
+ self.assertEqual(_stats["max"].calculate(self.histoContext), 5)
+ self.assertEqual(_stats["minCoords"].calculate(self.histoContext), (2,))
+ self.assertEqual(_stats["maxCoords"].calculate(self.histoContext), (5,))
+ self.assertEqual(_stats["std"].calculate(self.histoContext), numpy.std(yData))
+ self.assertEqual(_stats["mean"].calculate(self.histoContext), numpy.mean(yData))
com = numpy.sum(xData * yData) / numpy.sum(yData)
- self.assertEqual(_stats['com'].calculate(self.histoContext), com)
+ self.assertEqual(_stats["com"].calculate(self.histoContext), com)
class TestAdvancedROIImageContext(TestCaseQt):
@@ -1016,31 +1109,35 @@ class TestAdvancedROIImageContext(TestCaseQt):
roi_origins = [(0, 0), (2, 10), (14, 20)]
img_origins = [(0, 0), (14, 20), (2, 10)]
img_scales = [1.0, 0.5, 2.0]
- _stats = {'sum': stats.Stat(name='sum', fct=numpy.sum), }
+ _stats = {
+ "sum": stats.Stat(name="sum", fct=numpy.sum),
+ }
for roi_origin in roi_origins:
for img_origin in img_origins:
for img_scale in img_scales:
- with self.subTest(roi_origin=roi_origin,
- img_origin=img_origin,
- img_scale=img_scale):
- self.plot.addImage(self.data, legend='img',
- origin=img_origin,
- scale=img_scale)
+ with self.subTest(
+ roi_origin=roi_origin,
+ img_origin=img_origin,
+ img_scale=img_scale,
+ ):
+ self.plot.addImage(
+ self.data, legend="img", origin=img_origin, scale=img_scale
+ )
roi = RectangleROI()
roi.setGeometry(origin=roi_origin, size=(20, 20))
context = stats._ImageContext(
- item=self.plot.getImage('img'),
+ item=self.plot.getImage("img"),
plot=self.plot,
onlimits=False,
- roi=roi)
+ roi=roi,
+ )
x_start = int((roi_origin[0] - img_origin[0]) / img_scale)
x_end = int(x_start + (20 / img_scale)) + 1
- y_start = int((roi_origin[1] - img_origin[1])/ img_scale)
+ y_start = int((roi_origin[1] - img_origin[1]) / img_scale)
y_end = int(y_start + (20 / img_scale)) + 1
x_start = max(x_start, 0)
x_end = min(max(x_end, 0), self.data_dims[1])
y_start = max(y_start, 0)
y_end = min(max(y_end, 0), self.data_dims[0])
th_sum = numpy.sum(self.data[y_start:y_end, x_start:x_end])
- self.assertAlmostEqual(_stats['sum'].calculate(context),
- th_sum)
+ self.assertAlmostEqual(_stats["sum"].calculate(context), th_sum)
diff --git a/src/silx/gui/plot/test/testUtilsAxis.py b/src/silx/gui/plot/test/testUtilsAxis.py
index 879ec73..d749845 100644
--- a/src/silx/gui/plot/test/testUtilsAxis.py
+++ b/src/silx/gui/plot/test/testUtilsAxis.py
@@ -28,7 +28,6 @@ __license__ = "MIT"
__date__ = "20/11/2018"
-import unittest
from silx.gui.plot import PlotWidget
from silx.gui.utils.testutils import TestCaseQt
from silx.gui.plot.utils.axis import SyncAxes
@@ -51,7 +50,9 @@ class TestAxisSync(TestCaseQt):
def testMoveFirstAxis(self):
"""Test synchronization after construction"""
- _sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()])
+ _sync = SyncAxes(
+ [self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()]
+ )
self.plot1.getXAxis().setLimits(10, 500)
self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500))
@@ -60,7 +61,9 @@ class TestAxisSync(TestCaseQt):
def testMoveSecondAxis(self):
"""Test synchronization after construction"""
- _sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()])
+ _sync = SyncAxes(
+ [self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()]
+ )
self.plot2.getXAxis().setLimits(10, 500)
self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500))
@@ -69,7 +72,9 @@ class TestAxisSync(TestCaseQt):
def testMoveTwoAxes(self):
"""Test synchronization after construction"""
- _sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()])
+ _sync = SyncAxes(
+ [self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()]
+ )
self.plot1.getXAxis().setLimits(1, 50)
self.plot2.getXAxis().setLimits(10, 500)
@@ -79,7 +84,9 @@ class TestAxisSync(TestCaseQt):
def testDestruction(self):
"""Test synchronization when sync object is destroyed"""
- sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()])
+ sync = SyncAxes(
+ [self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()]
+ )
del sync
self.plot1.getXAxis().setLimits(10, 500)
@@ -89,10 +96,13 @@ class TestAxisSync(TestCaseQt):
def testAxisDestruction(self):
"""Test synchronization when an axis disappear"""
- _sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()])
+ _sync = SyncAxes(
+ [self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()]
+ )
# Destroy the plot is possible
import weakref
+
plot = weakref.ref(self.plot2)
self.plot2 = None
result = self.qWaitForDestroy(plot)
@@ -105,7 +115,9 @@ class TestAxisSync(TestCaseQt):
def testStop(self):
"""Test synchronization after calling stop"""
- sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()])
+ sync = SyncAxes(
+ [self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()]
+ )
sync.stop()
self.plot1.getXAxis().setLimits(10, 500)
@@ -115,7 +127,9 @@ class TestAxisSync(TestCaseQt):
def testStopMovingStart(self):
"""Test synchronization after calling stop, moving an axis, then start again"""
- sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()])
+ sync = SyncAxes(
+ [self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()]
+ )
sync.stop()
self.plot1.getXAxis().setLimits(10, 500)
self.plot2.getXAxis().setLimits(1, 50)
@@ -129,26 +143,40 @@ class TestAxisSync(TestCaseQt):
def testDoubleStop(self):
"""Test double stop"""
- sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()])
+ sync = SyncAxes(
+ [self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()]
+ )
sync.stop()
self.assertRaises(RuntimeError, sync.stop)
def testDoubleStart(self):
"""Test double stop"""
- sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()])
+ sync = SyncAxes(
+ [self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()]
+ )
self.assertRaises(RuntimeError, sync.start)
def testScale(self):
"""Test scale change"""
- _sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()])
+ _sync = SyncAxes(
+ [self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()]
+ )
self.plot1.getXAxis().setScale(self.plot1.getXAxis().LOGARITHMIC)
- self.assertEqual(self.plot1.getXAxis().getScale(), self.plot1.getXAxis().LOGARITHMIC)
- self.assertEqual(self.plot2.getXAxis().getScale(), self.plot1.getXAxis().LOGARITHMIC)
- self.assertEqual(self.plot3.getXAxis().getScale(), self.plot1.getXAxis().LOGARITHMIC)
+ self.assertEqual(
+ self.plot1.getXAxis().getScale(), self.plot1.getXAxis().LOGARITHMIC
+ )
+ self.assertEqual(
+ self.plot2.getXAxis().getScale(), self.plot1.getXAxis().LOGARITHMIC
+ )
+ self.assertEqual(
+ self.plot3.getXAxis().getScale(), self.plot1.getXAxis().LOGARITHMIC
+ )
def testDirection(self):
"""Test direction change"""
- _sync = SyncAxes([self.plot1.getYAxis(), self.plot2.getYAxis(), self.plot3.getYAxis()])
+ _sync = SyncAxes(
+ [self.plot1.getYAxis(), self.plot2.getYAxis(), self.plot3.getYAxis()]
+ )
self.plot1.getYAxis().setInverted(True)
self.assertEqual(self.plot1.getYAxis().isInverted(), True)
self.assertEqual(self.plot2.getYAxis().isInverted(), True)
@@ -160,8 +188,11 @@ class TestAxisSync(TestCaseQt):
self.plot1.getXAxis().setLimits(0, 200)
self.plot2.getXAxis().setLimits(0, 20)
self.plot3.getXAxis().setLimits(0, 2)
- _sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()],
- syncLimits=False, syncCenter=True)
+ _sync = SyncAxes(
+ [self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()],
+ syncLimits=False,
+ syncCenter=True,
+ )
self.assertEqual(self.plot1.getXAxis().getLimits(), (0, 200))
self.assertEqual(self.plot2.getXAxis().getLimits(), (100 - 10, 100 + 10))
@@ -173,8 +204,12 @@ class TestAxisSync(TestCaseQt):
self.plot1.getXAxis().setLimits(0, 200)
self.plot2.getXAxis().setLimits(0, 20)
self.plot3.getXAxis().setLimits(0, 2)
- _sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()],
- syncLimits=False, syncCenter=True, syncZoom=True)
+ _sync = SyncAxes(
+ [self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()],
+ syncLimits=False,
+ syncCenter=True,
+ syncZoom=True,
+ )
# Supposing all the plots use the same size
self.assertEqual(self.plot1.getXAxis().getLimits(), (0, 200))
@@ -193,7 +228,9 @@ class TestAxisSync(TestCaseQt):
def testRemoveAxis(self):
"""Test synchronization after construction"""
- sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()])
+ sync = SyncAxes(
+ [self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()]
+ )
sync.removeAxis(self.plot3.getXAxis())
self.plot1.getXAxis().setLimits(10, 500)
diff --git a/src/silx/gui/plot/test/utils.py b/src/silx/gui/plot/test/utils.py
index faa40bb..d48a467 100644
--- a/src/silx/gui/plot/test/utils.py
+++ b/src/silx/gui/plot/test/utils.py
@@ -30,7 +30,6 @@ __date__ = "26/01/2018"
import logging
import pytest
-import unittest
from silx.gui.utils.testutils import TestCaseQt
@@ -47,6 +46,7 @@ class PlotWidgetTestCase(TestCaseQt):
plot attribute is the PlotWidget created for the test.
"""
+
__screenshot_already_taken = False
backend = None
diff --git a/src/silx/gui/plot/tools/CurveLegendsWidget.py b/src/silx/gui/plot/tools/CurveLegendsWidget.py
index c9b0101..0ebea0d 100644
--- a/src/silx/gui/plot/tools/CurveLegendsWidget.py
+++ b/src/silx/gui/plot/tools/CurveLegendsWidget.py
@@ -74,11 +74,10 @@ class _LegendWidget(qt.QWidget):
return icon.getCurve()
def _update(self):
- """Update widget according to current curve state.
- """
+ """Update widget according to current curve state."""
curve = self.getCurve()
if curve is None:
- _logger.error('Curve no more exists')
+ _logger.error("Curve no more exists")
self.setVisible(False)
return
@@ -95,9 +94,11 @@ class _LegendWidget(qt.QWidget):
:param event: Kind of change
"""
- if event in (items.ItemChangedType.VISIBLE,
- items.ItemChangedType.HIGHLIGHTED,
- items.ItemChangedType.HIGHLIGHTED_STYLE):
+ if event in (
+ items.ItemChangedType.VISIBLE,
+ items.ItemChangedType.HIGHLIGHTED,
+ items.ItemChangedType.HIGHLIGHTED_STYLE,
+ ):
self._update()
@@ -142,7 +143,7 @@ class CurveLegendsWidget(qt.QWidget):
"""
previousPlot = self.getPlotWidget()
if previousPlot is not None:
- previousPlot.sigItemAdded.disconnect( self._itemAdded)
+ previousPlot.sigItemAdded.disconnect(self._itemAdded)
previousPlot.sigItemAboutToBeRemoved.disconnect(self._itemRemoved)
for legend in list(self._legends.keys()):
self._removeLegend(legend)
@@ -168,7 +169,7 @@ class CurveLegendsWidget(qt.QWidget):
elif len(args) == 2:
point = qt.QPoint(*args)
else:
- raise ValueError('Unsupported arguments')
+ raise ValueError("Unsupported arguments")
assert isinstance(point, qt.QPoint)
widget = self.childAt(point)
@@ -202,7 +203,7 @@ class CurveLegendsWidget(qt.QWidget):
curve = plot.getCurve(legend)
if curve is None:
- _logger.error('Curve not found: %s' % legend)
+ _logger.error("Curve not found: %s" % legend)
return
widget = _LegendWidget(parent=self, curve=curve)
@@ -216,7 +217,7 @@ class CurveLegendsWidget(qt.QWidget):
"""
widget = self._legends.pop(legend, None)
if widget is None:
- _logger.warning('Unknown legend: %s' % legend)
+ _logger.warning("Unknown legend: %s" % legend)
else:
self.layout().removeWidget(widget)
widget.setParent(None)
diff --git a/src/silx/gui/plot/tools/LimitsToolBar.py b/src/silx/gui/plot/tools/LimitsToolBar.py
index d7f4bf5..5ed09f7 100644
--- a/src/silx/gui/plot/tools/LimitsToolBar.py
+++ b/src/silx/gui/plot/tools/LimitsToolBar.py
@@ -56,7 +56,7 @@ class LimitsToolBar(qt.QToolBar):
:param str title: See :class:`QToolBar`.
"""
- def __init__(self, parent=None, plot=None, title='Limits'):
+ def __init__(self, parent=None, plot=None, title="Limits"):
super(LimitsToolBar, self).__init__(title, parent)
assert plot is not None
self._plot = plot
@@ -74,32 +74,28 @@ class LimitsToolBar(qt.QToolBar):
xMin, xMax = self.plot.getXAxis().getLimits()
yMin, yMax = self.plot.getYAxis().getLimits()
- self.addWidget(qt.QLabel('Limits: '))
- self.addWidget(qt.QLabel(' X: '))
+ self.addWidget(qt.QLabel("Limits: "))
+ self.addWidget(qt.QLabel(" X: "))
self._xMinFloatEdit = FloatEdit(self, xMin)
- self._xMinFloatEdit.editingFinished[()].connect(
- self._xFloatEditChanged)
+ self._xMinFloatEdit.editingFinished[()].connect(self._xFloatEditChanged)
self.addWidget(self._xMinFloatEdit)
self._xMaxFloatEdit = FloatEdit(self, xMax)
- self._xMaxFloatEdit.editingFinished[()].connect(
- self._xFloatEditChanged)
+ self._xMaxFloatEdit.editingFinished[()].connect(self._xFloatEditChanged)
self.addWidget(self._xMaxFloatEdit)
- self.addWidget(qt.QLabel(' Y: '))
+ self.addWidget(qt.QLabel(" Y: "))
self._yMinFloatEdit = FloatEdit(self, yMin)
- self._yMinFloatEdit.editingFinished[()].connect(
- self._yFloatEditChanged)
+ self._yMinFloatEdit.editingFinished[()].connect(self._yFloatEditChanged)
self.addWidget(self._yMinFloatEdit)
self._yMaxFloatEdit = FloatEdit(self, yMax)
- self._yMaxFloatEdit.editingFinished[()].connect(
- self._yFloatEditChanged)
+ self._yMaxFloatEdit.editingFinished[()].connect(self._yFloatEditChanged)
self.addWidget(self._yMaxFloatEdit)
def _plotWidgetSlot(self, event):
"""Listen to :class:`PlotWidget` events."""
- if event['event'] not in ('limitsChanged',):
+ if event["event"] not in ("limitsChanged",):
return
xMin, xMax = self.plot.getXAxis().getLimits()
diff --git a/src/silx/gui/plot/tools/PlotToolButton.py b/src/silx/gui/plot/tools/PlotToolButton.py
new file mode 100644
index 0000000..3a14f77
--- /dev/null
+++ b/src/silx/gui/plot/tools/PlotToolButton.py
@@ -0,0 +1,92 @@
+# /*##########################################################################
+#
+# Copyright (c) 2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides an abstract PlotToolButton that can be use to create
+plot tools for a toolbar.
+"""
+
+from __future__ import annotations
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "20/12/2023"
+
+
+import logging
+import weakref
+
+from silx.gui import qt
+
+
+_logger = logging.getLogger(__name__)
+
+
+class PlotToolButton(qt.QToolButton):
+ """A QToolButton connected to a :class:`~silx.gui.plot.PlotWidget`."""
+
+ def __init__(self, parent: qt.QWidget | None = None, plot=None):
+ super(PlotToolButton, self).__init__(parent)
+ self._plotRef = None
+ if plot is not None:
+ self.setPlot(plot)
+
+ def plot(self):
+ """
+ Returns the plot connected to the widget.
+ """
+ return None if self._plotRef is None else self._plotRef()
+
+ def setPlot(self, plot):
+ """
+ Set the plot connected to the widget
+
+ :param plot: :class:`.PlotWidget` instance on which to operate.
+ """
+ previousPlot = self.plot()
+
+ if previousPlot is plot:
+ return
+ if previousPlot is not None:
+ self._disconnectPlot(previousPlot)
+
+ if plot is None:
+ self._plotRef = None
+ else:
+ self._plotRef = weakref.ref(plot)
+ self._connectPlot(plot)
+
+ def _connectPlot(self, plot):
+ """
+ Called when the plot is connected to the widget
+
+ :param plot: :class:`.PlotWidget` instance
+ """
+ pass
+
+ def _disconnectPlot(self, plot):
+ """
+ Called when the plot is disconnected from the widget
+
+ :param plot: :class:`.PlotWidget` instance
+ """
+ pass
diff --git a/src/silx/gui/plot/tools/PositionInfo.py b/src/silx/gui/plot/tools/PositionInfo.py
index cb16b80..e3b8425 100644
--- a/src/silx/gui/plot/tools/PositionInfo.py
+++ b/src/silx/gui/plot/tools/PositionInfo.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -38,7 +38,6 @@ import weakref
import numpy
-from ....utils.deprecation import deprecated
from ... import qt
from .. import items
from ...widgets.ElidedLabel import ElidedLabel
@@ -56,12 +55,13 @@ class _PositionInfoLabel(ElidedLabel):
def sizeHint(self):
hint = super().sizeHint()
- width = self.fontMetrics().boundingRect('##############').width()
+ width = self.fontMetrics().boundingRect("##############").width()
return qt.QSize(max(hint.width(), width), hint.height())
# PositionInfo ################################################################
+
class PositionInfo(qt.QWidget):
"""QWidget displaying coords converted from data coords of the mouse.
@@ -115,7 +115,7 @@ class PositionInfo(qt.QWidget):
super(PositionInfo, self).__init__(parent)
if converters is None:
- converters = (('X', lambda x, y: x), ('Y', lambda x, y: y))
+ converters = (("X", lambda x, y: x), ("Y", lambda x, y: y))
self._fields = [] # To store (QLineEdit, name, function (x, y)->v)
@@ -126,10 +126,10 @@ class PositionInfo(qt.QWidget):
# Create all QLabel and store them with the corresponding converter
for name, func in converters:
- layout.addWidget(qt.QLabel('<b>' + name + ':</b>'))
+ layout.addWidget(qt.QLabel("<b>" + name + ":</b>"))
contentWidget = _PositionInfoLabel(self)
- contentWidget.setText('------')
+ contentWidget.setText("------")
layout.addWidget(contentWidget)
self._fields.append((contentWidget, name, func))
@@ -146,11 +146,6 @@ class PositionInfo(qt.QWidget):
"""
return self._plotRef()
- @property
- @deprecated(replacement='getPlotWidget', since_version='0.8.0')
- def plot(self):
- return self.getPlotWidget()
-
def getConverters(self):
"""Return the list of converters as 2-tuple (name, function)."""
return [(name, func) for _label, name, func in self._fields]
@@ -160,17 +155,18 @@ class PositionInfo(qt.QWidget):
:param dict event: Plot event
"""
- if event['event'] == 'mouseMoved':
- x, y = event['x'], event['y']
- xPixel, yPixel = event['xpixel'], event['ypixel']
+ if event["event"] == "mouseMoved":
+ x, y = event["x"], event["y"]
+ xPixel, yPixel = event["xpixel"], event["ypixel"]
self._updateStatusBar(x, y, xPixel, yPixel)
def updateInfo(self):
"""Update displayed information"""
plot = self.getPlotWidget()
if plot is None:
- _logger.error("Trying to update PositionInfo "
- "while PlotWidget no longer exists")
+ _logger.error(
+ "Trying to update PositionInfo " "while PlotWidget no longer exists"
+ )
return
widget = plot.getWidgetHandle()
@@ -193,15 +189,15 @@ class PositionInfo(qt.QWidget):
if plot is None:
return
- styleSheet = "color: rgb(0, 0, 0);" # Default style
+ styleSheet = "" # Default style
xData, yData = x, y
snappingMode = self.getSnappingMode()
# Snapping when crosshair either not requested or active
- if (snappingMode & (self.SNAPPING_CURVE | self.SNAPPING_SCATTER) and
- (not (snappingMode & self.SNAPPING_CROSSHAIR) or
- plot.getGraphCursor())):
+ if snappingMode & (self.SNAPPING_CURVE | self.SNAPPING_SCATTER) and (
+ not (snappingMode & self.SNAPPING_CROSSHAIR) or plot.getGraphCursor()
+ ):
styleSheet = "color: rgb(255, 0, 0);" # Style far from item
if snappingMode & self.SNAPPING_ACTIVE_ONLY:
@@ -213,7 +209,7 @@ class PositionInfo(qt.QWidget):
selectedItems.append(activeCurve)
if snappingMode & self.SNAPPING_SCATTER:
- activeScatter = plot._getActiveItem(kind='scatter')
+ activeScatter = plot.getActiveScatter()
if activeScatter:
selectedItems.append(activeScatter)
@@ -224,8 +220,11 @@ class PositionInfo(qt.QWidget):
kinds.append(items.Histogram)
if snappingMode & self.SNAPPING_SCATTER:
kinds.append(items.Scatter)
- selectedItems = [item for item in plot.getItems()
- if isinstance(item, tuple(kinds)) and item.isVisible()]
+ selectedItems = [
+ item
+ for item in plot.getItems()
+ if isinstance(item, tuple(kinds)) and item.isVisible()
+ ]
# Compute distance threshold
window = plot.window()
@@ -236,12 +235,12 @@ class PositionInfo(qt.QWidget):
ratio = qt.QGuiApplication.primaryScreen().devicePixelRatio()
# Baseline squared distance threshold
- sqDistInPixels = (self.SNAP_THRESHOLD_DIST * ratio)**2
+ sqDistInPixels = (self.SNAP_THRESHOLD_DIST * ratio) ** 2
for item in selectedItems:
- if (snappingMode & self.SNAPPING_SYMBOLS_ONLY and (
- not isinstance(item, items.SymbolMixIn) or
- not item.getSymbol())):
+ if snappingMode & self.SNAPPING_SYMBOLS_ONLY and (
+ not isinstance(item, items.SymbolMixIn) or not item.getSymbol()
+ ):
# Only handled if item symbols are visible
continue
@@ -256,7 +255,7 @@ class PositionInfo(qt.QWidget):
yData = item.getValueData(copy=False)[index]
# Update label style sheet
- styleSheet = "color: rgb(0, 0, 0);"
+ styleSheet = ""
break
else: # Curve, Scatter
@@ -270,14 +269,16 @@ class PositionInfo(qt.QWidget):
if isinstance(item, items.YAxisMixIn):
axis = item.getYAxis()
else:
- axis = 'left'
+ axis = "left"
xArray = item.getXData(copy=False)[indices]
yArray = item.getYData(copy=False)[indices]
pixelPositions = plot.dataToPixel(xArray, yArray, axis=axis)
if pixelPositions is None:
continue
- sqDistances = (pixelPositions[0] - xPixel)**2 + (pixelPositions[1] - yPixel)**2
+ sqDistances = (pixelPositions[0] - xPixel) ** 2 + (
+ pixelPositions[1] - yPixel
+ ) ** 2
if not numpy.any(numpy.isfinite(sqDistances)):
continue
closestIndex = numpy.nanargmin(sqDistances)
@@ -285,7 +286,7 @@ class PositionInfo(qt.QWidget):
if closestSqDistInPixels <= sqDistInPixels:
# Update label style sheet
- styleSheet = "color: rgb(0, 0, 0);"
+ styleSheet = ""
# if close enough, snap to data point coord
xData, yData = xArray[closestIndex], yArray[closestIndex]
@@ -299,10 +300,11 @@ class PositionInfo(qt.QWidget):
text = self.valueToString(value)
label.setText(text)
except:
- label.setText('Error')
+ label.setText("Error")
_logger.error(
"Error while converting coordinates (%f, %f)"
- "with converter '%s'" % (xPixel, yPixel, name))
+ "with converter '%s'" % (xPixel, yPixel, name)
+ )
_logger.error(traceback.format_exc())
def valueToString(self, value):
@@ -311,7 +313,7 @@ class PositionInfo(qt.QWidget):
return ", ".join(value)
elif isinstance(value, numbers.Real):
# Use this for floats and int
- return '%.7g' % value
+ return "%.7g" % value
else:
# Fallback for other types
return str(value)
@@ -353,21 +355,3 @@ class PositionInfo(qt.QWidget):
:rtype: int
"""
return self._snappingMode
-
- _SNAPPING_LEGACY = (SNAPPING_CROSSHAIR |
- SNAPPING_ACTIVE_ONLY |
- SNAPPING_SYMBOLS_ONLY |
- SNAPPING_CURVE |
- SNAPPING_SCATTER)
- """Legacy snapping mode"""
-
- @property
- @deprecated(replacement="getSnappingMode", since_version="0.8")
- def autoSnapToActiveCurve(self):
- return self.getSnappingMode() == self._SNAPPING_LEGACY
-
- @autoSnapToActiveCurve.setter
- @deprecated(replacement="setSnappingMode", since_version="0.8")
- def autoSnapToActiveCurve(self, flag):
- self.setSnappingMode(
- self._SNAPPING_LEGACY if flag else self.SNAPPING_DISABLED)
diff --git a/src/silx/gui/plot/tools/RadarView.py b/src/silx/gui/plot/tools/RadarView.py
index 886f37e..8ddb98b 100644
--- a/src/silx/gui/plot/tools/RadarView.py
+++ b/src/silx/gui/plot/tools/RadarView.py
@@ -41,9 +41,9 @@ _logger = logging.getLogger(__name__)
class _DraggableRectItem(qt.QGraphicsRectItem):
"""RectItem which signals its change through visibleRectDragged."""
+
def __init__(self, *args, **kwargs):
- super(_DraggableRectItem, self).__init__(
- *args, **kwargs)
+ super(_DraggableRectItem, self).__init__(*args, **kwargs)
self._previousCursor = None
self.setFlag(qt.QGraphicsItem.ItemIsMovable)
@@ -81,8 +81,7 @@ class _DraggableRectItem(qt.QGraphicsRectItem):
def itemChange(self, change, value):
"""Callback called before applying changes to the item."""
- if (change == qt.QGraphicsItem.ItemPositionChange and
- not self._ignoreChange):
+ if change == qt.QGraphicsItem.ItemPositionChange and not self._ignoreChange:
# Makes sure that the visible area is in the data
# or that data is in the visible area if area is too wide
x, y = value.x(), value.y()
@@ -118,12 +117,12 @@ class _DraggableRectItem(qt.QGraphicsRectItem):
value.x() + self.rect().left(),
value.y() + self.rect().top(),
self.rect().width(),
- self.rect().height())
+ self.rect().height(),
+ )
return value
- return super(_DraggableRectItem, self).itemChange(
- change, value)
+ return super(_DraggableRectItem, self).itemChange(change, value)
def hoverEnterEvent(self, event):
"""Called when the mouse enters the rectangle area"""
@@ -160,37 +159,37 @@ class RadarView(qt.QGraphicsView):
It provides: left, top, width, height in data coordinates.
"""
- _DATA_PEN = qt.QPen(qt.QColor('white'))
- _DATA_BRUSH = qt.QBrush(qt.QColor('light gray'))
- _ACTIVEDATA_PEN = qt.QPen(qt.QColor('black'))
- _ACTIVEDATA_BRUSH = qt.QBrush(qt.QColor('transparent'))
+ _DATA_PEN = qt.QPen(qt.QColor("white"))
+ _DATA_BRUSH = qt.QBrush(qt.QColor("light gray"))
+ _ACTIVEDATA_PEN = qt.QPen(qt.QColor("black"))
+ _ACTIVEDATA_BRUSH = qt.QBrush(qt.QColor("transparent"))
_ACTIVEDATA_PEN.setWidth(2)
_ACTIVEDATA_PEN.setCosmetic(True)
- _VISIBLE_PEN = qt.QPen(qt.QColor('blue'))
+ _VISIBLE_PEN = qt.QPen(qt.QColor("blue"))
_VISIBLE_PEN.setWidth(2)
_VISIBLE_PEN.setCosmetic(True)
_VISIBLE_BRUSH = qt.QBrush(qt.QColor(0, 0, 0, 0))
- _TOOLTIP = 'Radar View:\nRed contour: Visible area\nGray area: The image'
+ _TOOLTIP = "Radar View:\nRed contour: Visible area\nGray area: The image"
_PIXMAP_SIZE = 256
def __init__(self, parent=None):
self.__plotRef = None
self._scene = qt.QGraphicsScene()
- self._dataRect = self._scene.addRect(0, 0, 1, 1,
- self._DATA_PEN,
- self._DATA_BRUSH)
- self._imageRect = self._scene.addRect(0, 0, 1, 1,
- self._ACTIVEDATA_PEN,
- self._ACTIVEDATA_BRUSH)
+ self._dataRect = self._scene.addRect(
+ 0, 0, 1, 1, self._DATA_PEN, self._DATA_BRUSH
+ )
+ self._imageRect = self._scene.addRect(
+ 0, 0, 1, 1, self._ACTIVEDATA_PEN, self._ACTIVEDATA_BRUSH
+ )
self._imageRect.setVisible(False)
- self._scatterRect = self._scene.addRect(0, 0, 1, 1,
- self._ACTIVEDATA_PEN,
- self._ACTIVEDATA_BRUSH)
+ self._scatterRect = self._scene.addRect(
+ 0, 0, 1, 1, self._ACTIVEDATA_PEN, self._ACTIVEDATA_BRUSH
+ )
self._scatterRect.setVisible(False)
- self._curveRect = self._scene.addRect(0, 0, 1, 1,
- self._ACTIVEDATA_PEN,
- self._ACTIVEDATA_BRUSH)
+ self._curveRect = self._scene.addRect(
+ 0, 0, 1, 1, self._ACTIVEDATA_PEN, self._ACTIVEDATA_BRUSH
+ )
self._curveRect.setVisible(False)
self._visibleRect = _DraggableRectItem(0, 0, 1, 1)
@@ -202,7 +201,7 @@ class RadarView(qt.QGraphicsView):
self.setHorizontalScrollBarPolicy(qt.Qt.ScrollBarAlwaysOff)
self.setVerticalScrollBarPolicy(qt.Qt.ScrollBarAlwaysOff)
self.setFocusPolicy(qt.Qt.NoFocus)
- self.setStyleSheet('border: 0px')
+ self.setStyleSheet("border: 0px")
self.setToolTip(self._TOOLTIP)
self.__reentrant = LockReentrant()
@@ -311,7 +310,7 @@ class RadarView(qt.QGraphicsView):
# As opposed to Plot. So invert RadarView when Plot is NOT inverted.
self.resetTransform()
if not inverted:
- self.scale(1., -1.)
+ self.scale(1.0, -1.0)
self.update()
def _viewRectDragged(self, left, top, width, height):
diff --git a/src/silx/gui/plot/tools/RulerToolButton.py b/src/silx/gui/plot/tools/RulerToolButton.py
new file mode 100644
index 0000000..55cc02f
--- /dev/null
+++ b/src/silx/gui/plot/tools/RulerToolButton.py
@@ -0,0 +1,183 @@
+# /*##########################################################################
+#
+# Copyright (c) 20023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+PlotToolButton to measure a distance in a plot
+"""
+
+__authors__ = ["H. Payno"]
+__license__ = "MIT"
+__date__ = "30/10/2023"
+
+
+import logging
+import numpy
+import weakref
+import typing
+
+from silx.gui import icons
+
+from .PlotToolButton import PlotToolButton
+
+from silx.gui.plot.tools.roi import RegionOfInterestManager
+from silx.gui.plot.items.roi import LineROI
+from silx.gui.plot import items
+
+
+_logger = logging.getLogger(__name__)
+
+
+class _RulerROI(LineROI):
+ def __init__(self, parent=None):
+ super().__init__(parent)
+ self._formatFunction: typing.Optional[
+ typing.Callable[
+ [numpy.ndarray, numpy.ndarray], str
+ ]
+ ] = None
+ self.setColor("#001122") # Only there to trig updateStyle
+
+ def registerFormatFunction(
+ self,
+ fct: typing.Callable[
+ [numpy.ndarray, numpy.ndarray], str
+ ],
+ ):
+ """Register a function for the formatting of the label"""
+ self._formatFunction = fct
+
+ def _updatedStyle(self, event, style: items.CurveStyle):
+ style = items.CurveStyle(
+ color="red",
+ gapcolor="white",
+ linestyle=(0, (5, 5)),
+ linewidth=style.getLineWidth())
+ LineROI._updatedStyle(self, event, style)
+ self._handleLabel.setColor("black")
+ self._handleLabel.setBackgroundColor("#FFFFFF60")
+ self._handleLabel.setZValue(1000)
+
+ def setEndPoints(self, startPoint: numpy.ndarray, endPoint: numpy.ndarray):
+ super().setEndPoints(startPoint=startPoint, endPoint=endPoint)
+ if self._formatFunction is not None:
+ ruler_text = self._formatFunction(
+ startPoint=startPoint, endPoint=endPoint
+ )
+ self._updateText(ruler_text)
+
+
+class RulerToolButton(PlotToolButton):
+ """
+ Button to active measurement between two point of the plot
+
+ An instance of `RulerToolButton` can be added to a plot toolbar like:
+ .. code-block:: python
+
+ plot = Plot2D()
+
+ rulerButton = RulerToolButton(parent=plot, plot=plot)
+ plot.toolBar().addWidget(rulerButton)
+ """
+
+ def __init__(
+ self,
+ parent=None,
+ plot=None,
+ ):
+ super().__init__(parent=parent, plot=plot)
+ self.setCheckable(True)
+ self._roiManager = None
+ self.__lastRoiCreated = None
+ self.setIcon(icons.getQIcon("ruler"))
+ self.toggled.connect(self._callback)
+ self._connectPlot(plot)
+
+ def setPlot(self, plot):
+ return super().setPlot(plot)
+
+ @property
+ def _lastRoiCreated(self):
+ if self.__lastRoiCreated is None:
+ return None
+ return self.__lastRoiCreated()
+
+ def _callback(self, *args, **kwargs):
+ if not self._roiManager:
+ return
+ if self._lastRoiCreated is not None:
+ self._lastRoiCreated.setVisible(self.isChecked())
+ if self.isChecked():
+ self._roiManager.start(_RulerROI, self)
+ self.__interactiveModeStarted(self._roiManager)
+ else:
+ source = self._roiManager.getInteractionSource()
+ if source is self:
+ self._roiManager.stop()
+
+ def __interactiveModeStarted(self, roiManager):
+ roiManager.sigInteractiveModeFinished.connect(self.__interactiveModeFinished)
+
+ def __interactiveModeFinished(self):
+ roiManager = self._roiManager
+ if roiManager is not None:
+ roiManager.sigInteractiveModeFinished.disconnect(
+ self.__interactiveModeFinished
+ )
+ self.setChecked(False)
+
+ def _connectPlot(self, plot):
+ """
+ Called when the plot is connected to the widget
+
+ :param plot: :class:`.PlotWidget` instance
+ """
+ if plot is None:
+ return
+ self._roiManager = RegionOfInterestManager(plot)
+ self._roiManager.sigRoiAdded.connect(self._registerCurrentROI)
+
+ def _disconnectPlot(self, plot):
+ if plot and self._lastRoiCreated is not None:
+ self._roiManager.removeRoi(self._lastRoiCreated)
+ self.__lastRoiCreated = None
+ return super()._disconnectPlot(plot)
+
+ def _registerCurrentROI(self, currentRoi):
+ if self._lastRoiCreated is None:
+ self.__lastRoiCreated = weakref.ref(currentRoi)
+ self._lastRoiCreated.registerFormatFunction(self.buildDistanceText)
+ elif currentRoi is not self._lastRoiCreated and self._roiManager is not None:
+ self._roiManager.removeRoi(self._lastRoiCreated)
+ currentRoi.registerFormatFunction(self.buildDistanceText)
+ self.__lastRoiCreated = weakref.ref(currentRoi)
+
+ def buildDistanceText(self, startPoint: numpy.ndarray, endPoint: numpy.ndarray) -> str:
+ """
+ Define the text to be displayed by the ruler.
+
+ It can be redefine to modify precision or handle other parameters
+ (handling pixel size to display metric distance, display distance
+ on each distance - for non-square pixels...)
+ """
+ distance = numpy.linalg.norm(endPoint - startPoint)
+ return f"{distance: .1f}px"
diff --git a/src/silx/utils/html.py b/src/silx/gui/plot/tools/compare/__init__.py
index 654c780..7f23852 100644
--- a/src/silx/utils/html.py
+++ b/src/silx/gui/plot/tools/compare/__init__.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+# Copyright (c) 2018 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -21,16 +21,9 @@
# THE SOFTWARE.
#
# ###########################################################################*/
+"""This module provides tools related to the compare image plot.
+"""
__authors__ = ["V. Valls"]
__license__ = "MIT"
-__date__ = "19/09/2016"
-
-from .deprecation import deprecated_warning
-
-deprecated_warning(type_='module',
- name=__name__,
- replacement='html',
- since_version='0.15.0')
-
-from html import escape # noqa
+__date__ = "09/06/2023"
diff --git a/src/silx/gui/plot/tools/compare/core.py b/src/silx/gui/plot/tools/compare/core.py
new file mode 100644
index 0000000..90dbb79
--- /dev/null
+++ b/src/silx/gui/plot/tools/compare/core.py
@@ -0,0 +1,198 @@
+# /*##########################################################################
+#
+# Copyright (c) 2018-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides main objects shared by the compare image plot.
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "09/06/2023"
+
+
+import numpy
+import enum
+import contextlib
+from typing import NamedTuple
+
+from silx.gui.plot.items.image import ImageBase
+from silx.gui.plot.items.core import ItemChangedType, ColormapMixIn
+
+from silx.opencl import ocl
+
+if ocl is not None:
+ try:
+ from silx.opencl import sift
+ except ImportError:
+ # sift module is not available (e.g., in official Debian packages)
+ sift = None
+else: # No OpenCL device or no pyopencl
+ sift = None
+
+
+@enum.unique
+class VisualizationMode(enum.Enum):
+ """Enum for each visualization mode available."""
+
+ ONLY_A = "a"
+ ONLY_B = "b"
+ VERTICAL_LINE = "vline"
+ HORIZONTAL_LINE = "hline"
+ COMPOSITE_RED_BLUE_GRAY = "rbgchannel"
+ COMPOSITE_RED_BLUE_GRAY_NEG = "rbgnegchannel"
+ COMPOSITE_A_MINUS_B = "aminusb"
+
+
+@enum.unique
+class AlignmentMode(enum.Enum):
+ """Enum for each alignment mode available."""
+
+ ORIGIN = "origin"
+ CENTER = "center"
+ STRETCH = "stretch"
+ AUTO = "auto"
+
+
+class AffineTransformation(NamedTuple):
+ """Description of a 2D affine transformation: translation, scale and
+ rotation.
+ """
+
+ tx: float
+ ty: float
+ sx: float
+ sy: float
+ rot: float
+
+
+class _CompareImageItem(ImageBase, ColormapMixIn):
+ """Description of a virtual item of images to compare, in order to share
+ the data through the silx components.
+ """
+
+ def __init__(self):
+ ImageBase.__init__(self)
+ ColormapMixIn.__init__(self)
+ self.__image1 = None
+ self.__image2 = None
+ self.__vizualisationMode = VisualizationMode.ONLY_A
+
+ def getImageData1(self):
+ return self.__image1
+
+ def getImageData2(self):
+ return self.__image2
+
+ def setImageData1(self, image1):
+ if self.__image1 is image1:
+ return
+ self.__image1 = image1
+ self._updated(ItemChangedType.DATA)
+
+ def setImageData2(self, image2):
+ if self.__image2 is image2:
+ return
+ self.__image2 = image2
+ self._updated(ItemChangedType.DATA)
+
+ def getVizualisationMode(self) -> VisualizationMode:
+ return self.__vizualisationMode
+
+ @contextlib.contextmanager
+ def _updateColormapRange(self, previousMode, mode):
+ """COMPOSITE_A_MINUS_B don't have the same data range than others.
+
+ If the colormap is using a fixed range, it is updated in order to set
+ a similar range with the new data.
+ """
+ normalize_colormap = (
+ previousMode == VisualizationMode.COMPOSITE_A_MINUS_B
+ or mode == VisualizationMode.COMPOSITE_A_MINUS_B
+ )
+ if normalize_colormap:
+ data = self._getConcatenatedData(copy=False)
+ if data is None or data.size == 0:
+ normalize_colormap = False
+ else:
+ std1 = numpy.nanstd(data)
+ mean1 = numpy.nanmean(data)
+ yield
+
+ def transfer(v, std1, mean1, std2, mean2):
+ """Transfer a value from a data range to another using statistics"""
+ if v is None:
+ return None
+ rv = (v - mean1) / std1
+ return rv * std2 + mean2
+
+ if normalize_colormap:
+ data = self._getConcatenatedData(copy=False)
+ if data is not None and data.size != 0:
+ std2 = numpy.nanstd(data)
+ mean2 = numpy.nanmean(data)
+ c = self.getColormap()
+ if c is not None:
+ vmin, vmax = c.getVRange()
+ vmin = transfer(vmin, std1, mean1, std2, mean2)
+ vmax = transfer(vmax, std1, mean1, std2, mean2)
+ c.setVRange(vmin, vmax)
+
+ def setVizualisationMode(self, mode: VisualizationMode):
+ if self.__vizualisationMode == mode:
+ return None
+ with self._updateColormapRange(self.__vizualisationMode, mode):
+ self.__vizualisationMode = mode
+ self._updated(ItemChangedType.DATA)
+
+ def _getConcatenatedData(self, copy=True):
+ if self.__image1 is None and self.__image2 is None:
+ return None
+ if self.__image1 is None:
+ return numpy.array(self.__image2, copy=copy)
+ if self.__image2 is None:
+ return numpy.array(self.__image1, copy=copy)
+
+ if self.__vizualisationMode == VisualizationMode.COMPOSITE_A_MINUS_B:
+ # In this case the histogram have to be special
+ if self.__image1.shape == self.__image2.shape:
+ return self.__image1.astype(numpy.float32) - self.__image2.astype(
+ numpy.float32
+ )
+ else:
+ d1 = self.__image1[numpy.isfinite(self.__image1)]
+ d2 = self.__image2[numpy.isfinite(self.__image2)]
+ return numpy.concatenate((d1, d2))
+
+ def _updated(self, event=None, checkVisibility=True):
+ # Synchronizes colormapped data if changed
+ if event in (ItemChangedType.DATA, ItemChangedType.MASK):
+ data = self._getConcatenatedData(copy=False)
+ return self._setColormappedData(data, copy=False)
+ super()._updated(event=event, checkVisibility=checkVisibility)
+
+ def getColormappedData(self, copy=True):
+ """
+ Reimplementation of the `ColormapMixIn.getColormappedData` method.
+
+ This is used to provide a consistent auto scale on the compared images.
+ """
+ return self._getConcatenatedData(copy=copy)
diff --git a/src/silx/gui/plot/tools/compare/profile.py b/src/silx/gui/plot/tools/compare/profile.py
new file mode 100644
index 0000000..afe0eba
--- /dev/null
+++ b/src/silx/gui/plot/tools/compare/profile.py
@@ -0,0 +1,173 @@
+# /*##########################################################################
+#
+# Copyright (c) 2018-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This provides profile ROIs.
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "09/06/2023"
+
+
+import numpy
+
+from silx.gui.plot.tools.profile import rois
+from silx.gui.plot.tools.profile import core
+from .core import _CompareImageItem
+
+
+COLOR_A = "C0"
+COLOR_B = "C8"
+
+
+class ProfileImageLineROI(rois.ProfileImageLineROI):
+ """ROI for a compare image profile between 2 points.
+
+ The X profile of this ROI is the projection into one of the x/y axes,
+ using its scale and its orientation.
+ """
+
+ def computeProfile(self, item):
+ if not isinstance(item, _CompareImageItem):
+ raise TypeError("Unexpected class %s" % type(item))
+
+ origin = item.getOrigin()
+ scale = item.getScale()
+ method = self.getProfileMethod()
+ lineWidth = self.getProfileLineWidth()
+ roiInfo = self._getRoiInfo()
+
+ def createProfile2(currentData):
+ coords, profile, _area, profileName, xLabel = core.createProfile(
+ roiInfo=roiInfo,
+ currentData=currentData,
+ origin=origin,
+ scale=scale,
+ lineWidth=lineWidth,
+ method=method,
+ )
+ return coords, profile, profileName, xLabel
+
+ currentData1 = item.getImageData1()
+ currentData2 = item.getImageData2()
+
+ yLabel = "%s" % str(method).capitalize()
+ coords, profile1, title, xLabel = createProfile2(currentData1)
+ title = title + "; width = %d" % lineWidth
+ _coords, profile2, _title, _xLabel = createProfile2(currentData2)
+
+ profile1.shape = -1
+ profile2.shape = -1
+
+ title = title.format(xlabel="width", ylabel="height")
+ xLabel = xLabel.format(xlabel="width", ylabel="height")
+ yLabel = yLabel.format(xlabel="width", ylabel="height")
+
+ data = core.CurvesProfileData(
+ coords=coords,
+ profiles=[
+ core.CurveProfileDesc(profile1, color=COLOR_A, name="profileA"),
+ core.CurveProfileDesc(profile2, color=COLOR_B, name="profileB"),
+ ],
+ title=title,
+ xLabel=xLabel,
+ yLabel=yLabel,
+ )
+ return data
+
+
+class ProfileImageDirectedLineROI(rois.ProfileImageDirectedLineROI):
+ """ROI for a compare image profile between 2 points.
+
+ The X profile of the line is displayed projected into the line itself,
+ using its scale and its orientation. It's the distance from the origin.
+ """
+
+ def computeProfile(self, item):
+ if not isinstance(item, _CompareImageItem):
+ raise TypeError("Unexpected class %s" % type(item))
+
+ from silx.image.bilinear import BilinearImage
+
+ origin = item.getOrigin()
+ scale = item.getScale()
+ method = self.getProfileMethod()
+ lineWidth = self.getProfileLineWidth()
+
+ roiInfo = self._getRoiInfo()
+ roiStart, roiEnd, _lineProjectionMode = roiInfo
+
+ startPt = (
+ (roiStart[1] - origin[1]) / scale[1],
+ (roiStart[0] - origin[0]) / scale[0],
+ )
+ endPt = ((roiEnd[1] - origin[1]) / scale[1], (roiEnd[0] - origin[0]) / scale[0])
+
+ if numpy.array_equal(startPt, endPt):
+ return None
+
+ def computeProfile(data):
+ bilinear = BilinearImage(data)
+ profile = bilinear.profile_line(
+ (startPt[0] - 0.5, startPt[1] - 0.5),
+ (endPt[0] - 0.5, endPt[1] - 0.5),
+ lineWidth,
+ method=method,
+ )
+ return profile
+
+ currentData1 = item.getImageData1()
+ currentData2 = item.getImageData2()
+ profile1 = computeProfile(currentData1)
+ profile2 = computeProfile(currentData2)
+
+ # Compute the line size
+ lineSize = numpy.sqrt(
+ (roiEnd[1] - roiStart[1]) ** 2 + (roiEnd[0] - roiStart[0]) ** 2
+ )
+ coords = numpy.linspace(
+ 0, lineSize, len(profile1), endpoint=True, dtype=numpy.float32
+ )
+
+ title = rois._lineProfileTitle(*roiStart, *roiEnd)
+ title = title + "; width = %d" % lineWidth
+ xLabel = "√({xlabel}²+{ylabel}²)"
+ yLabel = str(method).capitalize()
+
+ # Use the axis names from the original plot
+ profileManager = self.getProfileManager()
+ plot = profileManager.getPlotWidget()
+ xLabel = rois._relabelAxes(plot, xLabel)
+ title = rois._relabelAxes(plot, title)
+
+ data = core.CurvesProfileData(
+ coords=coords,
+ profiles=[
+ core.CurveProfileDesc(profile1, color=COLOR_A, name="profileA"),
+ core.CurveProfileDesc(profile2, color=COLOR_B, name="profileB"),
+ ],
+ title=title,
+ xLabel=xLabel,
+ yLabel=yLabel,
+ )
+ return data
diff --git a/src/silx/gui/plot/tools/compare/statusbar.py b/src/silx/gui/plot/tools/compare/statusbar.py
new file mode 100644
index 0000000..5e43a37
--- /dev/null
+++ b/src/silx/gui/plot/tools/compare/statusbar.py
@@ -0,0 +1,218 @@
+# /*##########################################################################
+#
+# Copyright (c) 2018-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides tool bar helper.
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "09/06/2023"
+
+
+import logging
+import weakref
+import numpy
+
+from silx.gui import qt
+
+
+_logger = logging.getLogger(__name__)
+
+
+class CompareImagesStatusBar(qt.QStatusBar):
+ """StatusBar containing specific information contained in a
+ :class:`CompareImages` widget
+
+ Use :meth:`setCompareWidget` to connect this toolbar to a specific
+ :class:`CompareImages` widget.
+
+ :param Union[qt.QWidget,None] parent: Parent of this widget.
+ """
+
+ def __init__(self, parent=None):
+ qt.QStatusBar.__init__(self, parent)
+ self.setSizeGripEnabled(False)
+ self.layout().setSpacing(0)
+ self.__compareWidget = None
+ self._label1 = qt.QLabel(self)
+ self._label1.setFrameShape(qt.QFrame.WinPanel)
+ self._label1.setFrameShadow(qt.QFrame.Sunken)
+ self._label2 = qt.QLabel(self)
+ self._label2.setFrameShape(qt.QFrame.WinPanel)
+ self._label2.setFrameShadow(qt.QFrame.Sunken)
+ self._transform = qt.QLabel(self)
+ self._transform.setFrameShape(qt.QFrame.WinPanel)
+ self._transform.setFrameShadow(qt.QFrame.Sunken)
+ self.addWidget(self._label1)
+ self.addWidget(self._label2)
+ self.addWidget(self._transform)
+ self._pos = None
+ self._updateStatusBar()
+
+ def setCompareWidget(self, widget):
+ """
+ Connect this tool bar to a specific :class:`CompareImages` widget.
+
+ :param Union[None,CompareImages] widget: The widget to connect with.
+ """
+ compareWidget = self.getCompareWidget()
+ if compareWidget is not None:
+ compareWidget.getPlot().sigPlotSignal.disconnect(self.__plotSignalReceived)
+ compareWidget.sigConfigurationChanged.disconnect(self.__dataChanged)
+ compareWidget = widget
+ if compareWidget is None:
+ self.__compareWidget = None
+ else:
+ self.__compareWidget = weakref.ref(compareWidget)
+ if compareWidget is not None:
+ compareWidget.getPlot().sigPlotSignal.connect(self.__plotSignalReceived)
+ compareWidget.sigConfigurationChanged.connect(self.__dataChanged)
+
+ def getCompareWidget(self):
+ """Returns the connected widget.
+
+ :rtype: CompareImages
+ """
+ if self.__compareWidget is None:
+ return None
+ else:
+ return self.__compareWidget()
+
+ def __plotSignalReceived(self, event):
+ """Called when old style signals at emmited from the plot."""
+ if event["event"] == "mouseMoved":
+ x, y = event["x"], event["y"]
+ self.__mouseMoved(x, y)
+
+ def __mouseMoved(self, x, y):
+ """Called when mouse move over the plot."""
+ self._pos = x, y
+ self._updateStatusBar()
+
+ def __dataChanged(self):
+ """Called when internal data from the connected widget changes."""
+ self._updateStatusBar()
+
+ def _formatData(self, data):
+ """Format pixel of an image.
+
+ It supports intensity, RGB, and RGBA.
+
+ :param Union[int,float,numpy.ndarray,str]: Value of a pixel
+ :rtype: str
+ """
+ if data is None:
+ return "No data"
+ if isinstance(data, (int, numpy.integer)):
+ return "%d" % data
+ if isinstance(data, (float, numpy.floating)):
+ return "%f" % data
+ if isinstance(data, numpy.ndarray):
+ # RGBA value
+ if data.shape == (3,):
+ return "R:%d G:%d B:%d" % (data[0], data[1], data[2])
+ elif data.shape == (4,):
+ return "R:%d G:%d B:%d A:%d" % (data[0], data[1], data[2], data[3])
+ _logger.debug("Unsupported data format %s. Cast it to string.", type(data))
+ return str(data)
+
+ def _updateStatusBar(self):
+ """Update the content of the status bar"""
+ widget = self.getCompareWidget()
+ if widget is None:
+ self._label1.setText("ImageA: NA")
+ self._label2.setText("ImageB: NA")
+ self._transform.setVisible(False)
+ else:
+ transform = widget.getTransformation()
+ self._transform.setVisible(transform is not None)
+ if transform is not None:
+ has_notable_translation = not numpy.isclose(
+ transform.tx, 0.0, atol=0.01
+ ) or not numpy.isclose(transform.ty, 0.0, atol=0.01)
+ has_notable_scale = not numpy.isclose(
+ transform.sx, 1.0, atol=0.01
+ ) or not numpy.isclose(transform.sy, 1.0, atol=0.01)
+ has_notable_rotation = not numpy.isclose(transform.rot, 0.0, atol=0.01)
+
+ strings = []
+ if has_notable_translation:
+ strings.append("Translation")
+ if has_notable_scale:
+ strings.append("Scale")
+ if has_notable_rotation:
+ strings.append("Rotation")
+ if strings == []:
+ has_translation = not numpy.isclose(
+ transform.tx, 0.0
+ ) or not numpy.isclose(transform.ty, 0.0)
+ has_scale = not numpy.isclose(
+ transform.sx, 1.0
+ ) or not numpy.isclose(transform.sy, 1.0)
+ has_rotation = not numpy.isclose(transform.rot, 0.0)
+ if has_translation or has_scale or has_rotation:
+ text = "No big changes"
+ else:
+ text = "No changes"
+ else:
+ text = "+".join(strings)
+ self._transform.setText("Align: " + text)
+
+ strings = []
+ if not numpy.isclose(transform.ty, 0.0):
+ strings.append("Translation x: %0.3fpx" % transform.tx)
+ if not numpy.isclose(transform.ty, 0.0):
+ strings.append("Translation y: %0.3fpx" % transform.ty)
+ if not numpy.isclose(transform.sx, 1.0):
+ strings.append("Scale x: %0.3f" % transform.sx)
+ if not numpy.isclose(transform.sy, 1.0):
+ strings.append("Scale y: %0.3f" % transform.sy)
+ if not numpy.isclose(transform.rot, 0.0):
+ strings.append(
+ "Rotation: %0.3fdeg" % (transform.rot * 180 / numpy.pi)
+ )
+ if strings == []:
+ text = "No transformation"
+ else:
+ text = "\n".join(strings)
+ self._transform.setToolTip(text)
+
+ if self._pos is None:
+ self._label1.setText("ImageA: NA")
+ self._label2.setText("ImageB: NA")
+ else:
+ data1, data2 = widget.getRawPixelData(self._pos[0], self._pos[1])
+ if isinstance(data1, str):
+ self._label1.setToolTip(data1)
+ text1 = "NA"
+ else:
+ self._label1.setToolTip("")
+ text1 = self._formatData(data1)
+ if isinstance(data2, str):
+ self._label2.setToolTip(data2)
+ text2 = "NA"
+ else:
+ self._label2.setToolTip("")
+ text2 = self._formatData(data2)
+ self._label1.setText("ImageA: %s" % text1)
+ self._label2.setText("ImageB: %s" % text2)
diff --git a/src/silx/gui/plot/tools/compare/toolbar.py b/src/silx/gui/plot/tools/compare/toolbar.py
new file mode 100644
index 0000000..a7f56ec
--- /dev/null
+++ b/src/silx/gui/plot/tools/compare/toolbar.py
@@ -0,0 +1,390 @@
+# /*##########################################################################
+#
+# Copyright (c) 2018-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides tool bar helper.
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "28/06/2018"
+
+
+import logging
+import weakref
+from typing import List, Optional
+
+from silx.gui import qt
+from silx.gui import icons
+from .core import AlignmentMode
+from .core import VisualizationMode
+from .core import sift
+
+
+_logger = logging.getLogger(__name__)
+
+
+class AlignmentModeToolButton(qt.QToolButton):
+ """ToolButton to select a AlignmentMode"""
+
+ sigSelected = qt.Signal(AlignmentMode)
+
+ def __init__(self, parent=None):
+ super(AlignmentModeToolButton, self).__init__(parent=parent)
+
+ menu = qt.QMenu(self)
+ self.setMenu(menu)
+
+ self.__group = qt.QActionGroup(self)
+ self.__group.setExclusive(True)
+ self.__group.triggered.connect(self.__selectionChanged)
+
+ icon = icons.getQIcon("compare-align-origin")
+ action = qt.QAction(icon, "Align images on their upper-left pixel", self)
+ action.setProperty("enum", AlignmentMode.ORIGIN)
+ action.setIconVisibleInMenu(True)
+ action.setCheckable(True)
+ self.__originAlignAction = action
+ menu.addAction(action)
+ self.__group.addAction(action)
+
+ icon = icons.getQIcon("compare-align-center")
+ action = qt.QAction(icon, "Center images", self)
+ action.setProperty("enum", AlignmentMode.CENTER)
+ action.setIconVisibleInMenu(True)
+ action.setCheckable(True)
+ self.__centerAlignAction = action
+ menu.addAction(action)
+ self.__group.addAction(action)
+
+ icon = icons.getQIcon("compare-align-stretch")
+ action = qt.QAction(icon, "Stretch the second image on the first one", self)
+ action.setProperty("enum", AlignmentMode.STRETCH)
+ action.setIconVisibleInMenu(True)
+ action.setCheckable(True)
+ self.__stretchAlignAction = action
+ menu.addAction(action)
+ self.__group.addAction(action)
+
+ icon = icons.getQIcon("compare-align-auto")
+ action = qt.QAction(icon, "Auto-alignment of the second image", self)
+ action.setProperty("enum", AlignmentMode.AUTO)
+ action.setIconVisibleInMenu(True)
+ action.setCheckable(True)
+ self.__autoAlignAction = action
+ menu.addAction(action)
+ if sift is None:
+ action.setEnabled(False)
+ action.setToolTip("Sift module is not available")
+ self.__group.addAction(action)
+
+ def getActionFromMode(self, mode: AlignmentMode) -> Optional[qt.QAction]:
+ """Returns an action from it's mode"""
+ for action in self.__group.actions():
+ actionMode = action.property("enum")
+ if mode == actionMode:
+ return action
+ return None
+
+ def setVisibleModes(self, modes: List[AlignmentMode]):
+ """Make visible only a set of modes.
+
+ The order does not matter.
+ """
+ modes = set(modes)
+ for action in self.__group.actions():
+ mode = action.property("enum")
+ action.setVisible(mode in modes)
+
+ def __selectionChanged(self, selectedAction: qt.QAction):
+ """Called when user requesting changes of the alignment mode."""
+ self.__updateMenu()
+ mode = self.getSelected()
+ self.sigSelected.emit(mode)
+
+ def __updateMenu(self):
+ """Update the state of the action containing alignment menu."""
+ selectedAction = self.__group.checkedAction()
+ if selectedAction is not None:
+ self.setText(selectedAction.text())
+ self.setIcon(selectedAction.icon())
+ self.setToolTip(selectedAction.toolTip())
+ else:
+ self.setText("")
+ self.setIcon(qt.QIcon())
+ self.setToolTip("")
+
+ def getSelected(self) -> AlignmentMode:
+ action = self.__group.checkedAction()
+ if action is None:
+ return None
+ return action.property("enum")
+
+ def setSelected(self, mode: AlignmentMode):
+ action = self.getActionFromMode(mode)
+ old = self.__group.blockSignals(True)
+ if action is not None:
+ # Check this action
+ action.setChecked(True)
+ else:
+ action = self.__group.checkedAction()
+ if action is not None:
+ # Uncheck this action
+ action.setChecked(False)
+ self.__updateMenu()
+ self.__group.blockSignals(old)
+
+
+class VisualizationModeToolButton(qt.QToolButton):
+ """ToolButton to select a VisualisationMode"""
+
+ sigSelected = qt.Signal(VisualizationMode)
+
+ def __init__(self, parent=None):
+ super(VisualizationModeToolButton, self).__init__(parent=parent)
+
+ menu = qt.QMenu(self)
+ self.setMenu(menu)
+
+ self.__group = qt.QActionGroup(self)
+ self.__group.setExclusive(True)
+ self.__group.triggered.connect(self.__selectionChanged)
+
+ icon = icons.getQIcon("compare-mode-a")
+ action = qt.QAction(icon, "Display the first image only", self)
+ action.setIconVisibleInMenu(True)
+ action.setCheckable(True)
+ action.setShortcut(qt.QKeySequence(qt.Qt.Key_A))
+ action.setProperty("enum", VisualizationMode.ONLY_A)
+ menu.addAction(action)
+ self.__aModeAction = action
+ self.__group.addAction(action)
+
+ icon = icons.getQIcon("compare-mode-b")
+ action = qt.QAction(icon, "Display the second image only", self)
+ action.setIconVisibleInMenu(True)
+ action.setCheckable(True)
+ action.setShortcut(qt.QKeySequence(qt.Qt.Key_B))
+ action.setProperty("enum", VisualizationMode.ONLY_B)
+ menu.addAction(action)
+ self.__bModeAction = action
+ self.__group.addAction(action)
+
+ icon = icons.getQIcon("compare-mode-vline")
+ action = qt.QAction(icon, "Vertical compare mode", self)
+ action.setIconVisibleInMenu(True)
+ action.setCheckable(True)
+ action.setShortcut(qt.QKeySequence(qt.Qt.Key_V))
+ action.setProperty("enum", VisualizationMode.VERTICAL_LINE)
+ menu.addAction(action)
+ self.__vlineModeAction = action
+ self.__group.addAction(action)
+
+ icon = icons.getQIcon("compare-mode-hline")
+ action = qt.QAction(icon, "Horizontal compare mode", self)
+ action.setIconVisibleInMenu(True)
+ action.setCheckable(True)
+ action.setShortcut(qt.QKeySequence(qt.Qt.Key_H))
+ action.setProperty("enum", VisualizationMode.HORIZONTAL_LINE)
+ menu.addAction(action)
+ self.__hlineModeAction = action
+ self.__group.addAction(action)
+
+ icon = icons.getQIcon("compare-mode-rb-channel")
+ action = qt.QAction(icon, "Blue/red compare mode (additive mode)", self)
+ action.setIconVisibleInMenu(True)
+ action.setCheckable(True)
+ action.setShortcut(qt.QKeySequence(qt.Qt.Key_C))
+ action.setProperty("enum", VisualizationMode.COMPOSITE_RED_BLUE_GRAY)
+ menu.addAction(action)
+ self.__brChannelModeAction = action
+ self.__group.addAction(action)
+
+ icon = icons.getQIcon("compare-mode-rbneg-channel")
+ action = qt.QAction(icon, "Yellow/cyan compare mode (subtractive mode)", self)
+ action.setIconVisibleInMenu(True)
+ action.setCheckable(True)
+ action.setShortcut(qt.QKeySequence(qt.Qt.Key_Y))
+ action.setProperty("enum", VisualizationMode.COMPOSITE_RED_BLUE_GRAY_NEG)
+ menu.addAction(action)
+ self.__ycChannelModeAction = action
+ self.__group.addAction(action)
+
+ icon = icons.getQIcon("compare-mode-a-minus-b")
+ action = qt.QAction(icon, "Raw A minus B compare mode", self)
+ action.setIconVisibleInMenu(True)
+ action.setCheckable(True)
+ action.setShortcut(qt.QKeySequence(qt.Qt.Key_W))
+ action.setProperty("enum", VisualizationMode.COMPOSITE_A_MINUS_B)
+ menu.addAction(action)
+ self.__ycChannelModeAction = action
+ self.__group.addAction(action)
+
+ def getActionFromMode(self, mode: VisualizationMode) -> Optional[qt.QAction]:
+ """Returns an action from it's mode"""
+ for action in self.__group.actions():
+ actionMode = action.property("enum")
+ if mode == actionMode:
+ return action
+ return None
+
+ def setVisibleModes(self, modes: List[VisualizationMode]):
+ """Make visible only a set of modes.
+
+ The order does not matter.
+ """
+ modes = set(modes)
+ for action in self.__group.actions():
+ mode = action.property("enum")
+ action.setVisible(mode in modes)
+
+ def __selectionChanged(self, selectedAction: qt.QAction):
+ """Called when user requesting changes of the visualization mode."""
+ self.__updateMenu()
+ mode = self.getSelected()
+ self.sigSelected.emit(mode)
+
+ def __updateMenu(self):
+ """Update the state of the action containing visualization menu."""
+ selectedAction = self.__group.checkedAction()
+ if selectedAction is not None:
+ self.setText(selectedAction.text())
+ self.setIcon(selectedAction.icon())
+ self.setToolTip(selectedAction.toolTip())
+ else:
+ self.setText("")
+ self.setIcon(qt.QIcon())
+ self.setToolTip("")
+
+ def getSelected(self) -> VisualizationMode:
+ action = self.__group.checkedAction()
+ if action is None:
+ return None
+ return action.property("enum")
+
+ def setSelected(self, mode: VisualizationMode):
+ action = self.getActionFromMode(mode)
+ old = self.__group.blockSignals(True)
+ if action is not None:
+ # Check this action
+ action.setChecked(True)
+ else:
+ action = self.__group.checkedAction()
+ if action is not None:
+ # Uncheck this action
+ action.setChecked(False)
+ self.__updateMenu()
+ self.__group.blockSignals(old)
+
+
+class CompareImagesToolBar(qt.QToolBar):
+ """ToolBar containing specific tools to custom the configuration of a
+ :class:`CompareImages` widget
+
+ Use :meth:`setCompareWidget` to connect this toolbar to a specific
+ :class:`CompareImages` widget.
+
+ :param Union[qt.QWidget,None] parent: Parent of this widget.
+ """
+
+ def __init__(self, parent=None):
+ qt.QToolBar.__init__(self, parent)
+ self.setWindowTitle("Compare images")
+
+ self.__compareWidget = None
+
+ self.__visualizationToolButton = VisualizationModeToolButton(self)
+ self.__visualizationToolButton.setPopupMode(qt.QToolButton.InstantPopup)
+ self.__visualizationToolButton.sigSelected.connect(self.__visualizationChanged)
+ self.addWidget(self.__visualizationToolButton)
+
+ self.__alignmentToolButton = AlignmentModeToolButton(self)
+ self.__alignmentToolButton.setPopupMode(qt.QToolButton.InstantPopup)
+ self.__alignmentToolButton.sigSelected.connect(self.__alignmentChanged)
+ self.addWidget(self.__alignmentToolButton)
+
+ icon = icons.getQIcon("compare-keypoints")
+ action = qt.QAction(icon, "Display/hide alignment keypoints", self)
+ action.setCheckable(True)
+ action.triggered.connect(self.__keypointVisibilityChanged)
+ self.addAction(action)
+ self.__displayKeypoints = action
+
+ def __visualizationChanged(self, mode: VisualizationMode):
+ widget = self.getCompareWidget()
+ if widget is not None:
+ widget.setVisualizationMode(mode)
+
+ def __alignmentChanged(self, mode: AlignmentMode):
+ widget = self.getCompareWidget()
+ if widget is not None:
+ widget.setAlignmentMode(mode)
+
+ def setCompareWidget(self, widget):
+ """
+ Connect this tool bar to a specific :class:`CompareImages` widget.
+
+ :param Union[None,CompareImages] widget: The widget to connect with.
+ """
+ compareWidget = self.getCompareWidget()
+ if compareWidget is not None:
+ compareWidget.sigConfigurationChanged.disconnect(
+ self.__updateSelectedActions
+ )
+ compareWidget = widget
+ self.setEnabled(compareWidget is not None)
+ if compareWidget is None:
+ self.__compareWidget = None
+ else:
+ self.__compareWidget = weakref.ref(compareWidget)
+ if compareWidget is not None:
+ widget.sigConfigurationChanged.connect(self.__updateSelectedActions)
+ self.__updateSelectedActions()
+
+ def getCompareWidget(self):
+ """Returns the connected widget.
+
+ :rtype: CompareImages
+ """
+ if self.__compareWidget is None:
+ return None
+ else:
+ return self.__compareWidget()
+
+ def __updateSelectedActions(self):
+ """
+ Update the state of this tool bar according to the state of the
+ connected :class:`CompareImages` widget.
+ """
+ widget = self.getCompareWidget()
+ if widget is None:
+ return
+ self.__visualizationToolButton.setSelected(widget.getVisualizationMode())
+ self.__alignmentToolButton.setSelected(widget.getAlignmentMode())
+ self.__displayKeypoints.setChecked(widget.getKeypointsVisible())
+
+ def __keypointVisibilityChanged(self):
+ """Called when action managing keypoints visibility changes"""
+ widget = self.getCompareWidget()
+ if widget is not None:
+ keypointsVisible = self.__displayKeypoints.isChecked()
+ widget.setKeypointsVisible(keypointsVisible)
diff --git a/src/silx/gui/plot/tools/menus.py b/src/silx/gui/plot/tools/menus.py
new file mode 100644
index 0000000..c748b6e
--- /dev/null
+++ b/src/silx/gui/plot/tools/menus.py
@@ -0,0 +1,93 @@
+# /*##########################################################################
+#
+# Copyright (c) 2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+This module provides :class:`PlotWidget`-related QMenu.
+
+The following QMenu is available:
+
+- :class:`ZoomEnabledAxesMenu`
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "12/06/2023"
+
+
+import weakref
+from typing import Optional
+
+from silx.gui import qt
+
+from ..PlotWidget import PlotWidget
+
+
+class ZoomEnabledAxesMenu(qt.QMenu):
+ """Menu to toggle axes for zoom interaction"""
+
+ def __init__(self, plot: PlotWidget, parent: Optional[qt.QWidget] = None):
+ super().__init__(parent)
+ self.setTitle("Zoom axes")
+
+ assert isinstance(plot, PlotWidget)
+ self.__plotRef = weakref.ref(plot)
+
+ self.addSection("Enabled axes")
+ self.__xAxisAction = qt.QAction("X axis", parent=self)
+ self.__yAxisAction = qt.QAction("Y left axis", parent=self)
+ self.__y2AxisAction = qt.QAction("Y right axis", parent=self)
+
+ for action in (self.__xAxisAction, self.__yAxisAction, self.__y2AxisAction):
+ action.setCheckable(True)
+ action.setChecked(True)
+ action.triggered.connect(self._axesActionTriggered)
+ self.addAction(action)
+
+ # Listen to interaction configuration change
+ plot.interaction().sigChanged.connect(self._interactionChanged)
+ # Init the state
+ self._interactionChanged()
+
+ def getPlotWidget(self) -> Optional[PlotWidget]:
+ return self.__plotRef()
+
+ def _axesActionTriggered(self, checked=False):
+ plot = self.getPlotWidget()
+ if plot is None:
+ return
+
+ plot.interaction().setZoomEnabledAxes(
+ self.__xAxisAction.isChecked(),
+ self.__yAxisAction.isChecked(),
+ self.__y2AxisAction.isChecked(),
+ )
+
+ def _interactionChanged(self):
+ plot = self.getPlotWidget()
+ if plot is None:
+ return
+
+ enabledAxes = plot.interaction().getZoomEnabledAxes()
+ self.__xAxisAction.setChecked(enabledAxes.xaxis)
+ self.__yAxisAction.setChecked(enabledAxes.yaxis)
+ self.__y2AxisAction.setChecked(enabledAxes.y2axis)
diff --git a/src/silx/gui/plot/tools/profile/ScatterProfileToolBar.py b/src/silx/gui/plot/tools/profile/ScatterProfileToolBar.py
index 09f90b7..271adb8 100644
--- a/src/silx/gui/plot/tools/profile/ScatterProfileToolBar.py
+++ b/src/silx/gui/plot/tools/profile/ScatterProfileToolBar.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2018-2019 European Synchrotron Radiation Facility
+# Copyright (c) 2018-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -29,7 +29,6 @@ __license__ = "MIT"
__date__ = "28/06/2018"
-from silx.utils import deprecation
from . import toolbar
@@ -38,16 +37,8 @@ class ScatterProfileToolBar(toolbar.ProfileToolBar):
:param parent: See :class:`QToolBar`.
:param plot: :class:`~silx.gui.plot.PlotWidget` on which to operate.
- :param str title: See :class:`QToolBar`.
"""
- def __init__(self, parent=None, plot=None, title=None):
+ def __init__(self, parent=None, plot=None):
super(ScatterProfileToolBar, self).__init__(parent, plot)
- if title is not None:
- deprecation.deprecated_warning("Attribute",
- name="title",
- reason="removed",
- since_version="0.13.0",
- only_once=True,
- skip_backtrace_count=1)
self.setScheme("scatter")
diff --git a/src/silx/gui/plot/tools/profile/core.py b/src/silx/gui/plot/tools/profile/core.py
index 5d4a674..194f459 100644
--- a/src/silx/gui/plot/tools/profile/core.py
+++ b/src/silx/gui/plot/tools/profile/core.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2018-2020 European Synchrotron Radiation Facility
+# Copyright (c) 2018-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -24,49 +24,63 @@
"""This module define core objects for profile tools.
"""
+from __future__ import annotations
+
__authors__ = ["V.A. Sole", "T. Vincent", "P. Knobel", "H. Payno", "V. Valls"]
__license__ = "MIT"
__date__ = "17/04/2020"
-import collections
+import typing
import numpy
import weakref
from silx.image.bilinear import BilinearImage
from silx.gui import qt
+from silx.gui import colors
+import silx.gui.plot.items
+
+
+class CurveProfileData(typing.NamedTuple):
+ coords: numpy.ndarray
+ profile: numpy.ndarray
+ title: str
+ xLabel: str
+ yLabel: str
+
+
+class RgbaProfileData(typing.NamedTuple):
+ coords: numpy.ndarray
+ profile: numpy.ndarray
+ profile_r: numpy.ndarray
+ profile_g: numpy.ndarray
+ profile_b: numpy.ndarray
+ profile_a: numpy.ndarray
+ title: str
+ xLabel: str
+ yLabel: str
+
+
+class ImageProfileData(typing.NamedTuple):
+ coords: numpy.ndarray
+ profile: numpy.ndarray
+ title: str
+ xLabel: str
+ yLabel: str
+ colormap: colors.Colormap
-CurveProfileData = collections.namedtuple(
- 'CurveProfileData', [
- "coords",
- "profile",
- "title",
- "xLabel",
- "yLabel",
- ])
-
-RgbaProfileData = collections.namedtuple(
- 'RgbaProfileData', [
- "coords",
- "profile",
- "profile_r",
- "profile_g",
- "profile_b",
- "profile_a",
- "title",
- "xLabel",
- "yLabel",
- ])
-
-ImageProfileData = collections.namedtuple(
- 'ImageProfileData', [
- 'coords',
- 'profile',
- 'title',
- 'xLabel',
- 'yLabel',
- 'colormap',
- ])
+class CurveProfileDesc(typing.NamedTuple):
+ profile: numpy.ndarray
+ name: typing.Optional[str] = None
+ color: typing.Optional[str] = None
+
+
+class CurvesProfileData(typing.NamedTuple):
+ coords: numpy.ndarray
+ profiles: typing.List[CurveProfileDesc]
+ title: str
+ xLabel: str
+ yLabel: str
class ProfileRoiMixIn:
@@ -107,7 +121,7 @@ class ProfileRoiMixIn:
def _setPlotItem(self, plotItem):
"""Specify the plot item to use with this profile
- :param `~silx.gui.plot.items.item.Item` plotItem: A plot item
+ :param `~silx.gui.plot.items.Item` plotItem: A plot item
"""
previousPlotItem = self.getPlotItem()
if previousPlotItem is plotItem:
@@ -118,7 +132,7 @@ class ProfileRoiMixIn:
def getPlotItem(self):
"""Returns the plot item used by this profile
- :rtype: `~silx.gui.plot.items.item.Item`
+ :rtype: `~silx.gui.plot.items.Item`
"""
if self.__plotItem is None:
return None
@@ -171,15 +185,18 @@ class ProfileRoiMixIn:
except ValueError:
pass
- def computeProfile(self, item):
+ def computeProfile(
+ self, item: silx.gui.plot.items.Item
+ ) -> typing.Union[
+ CurveProfileData, ImageProfileData, RgbaProfileData, CurvesProfileData
+ ]:
"""
Compute the profile which will be displayed.
This method is not called from the main Qt thread, but from a thread
pool.
- :param ~silx.gui.plot.items.Item item: A plot item
- :rtype: Union[CurveProfileData,ImageProfileData]
+ :param item: A plot item
"""
raise NotImplementedError()
@@ -201,7 +218,7 @@ def _alignedFullProfile(data, origin, scale, position, roiWidth, axis, method):
"""
assert axis in (0, 1)
assert len(data.shape) == 3
- assert method in ('mean', 'sum', 'none')
+ assert method in ("mean", "sum", "none")
# Convert from plot to image coords
imgPos = int((position - origin[1 - axis]) / scale[1 - axis])
@@ -215,31 +232,35 @@ def _alignedFullProfile(data, origin, scale, position, roiWidth, axis, method):
roiWidth = min(height, roiWidth) # Clip roi width to image size
# Get [start, end[ coords of the roi in the data
- start = int(int(imgPos) + 0.5 - roiWidth / 2.)
+ start = int(int(imgPos) + 0.5 - roiWidth / 2.0)
start = min(max(0, start), height - roiWidth)
end = start + roiWidth
- if method == 'none':
+ if method == "none":
profile = None
else:
if start < height and end > 0:
- if method == 'mean':
+ if method == "mean":
fct = numpy.mean
- elif method == 'sum':
+ elif method == "sum":
fct = numpy.sum
else:
- raise ValueError('method not managed')
- profile = fct(data[:, max(0, start):min(end, height), :], axis=1).astype(numpy.float32)
+ raise ValueError("method not managed")
+ profile = fct(data[:, max(0, start) : min(end, height), :], axis=1).astype(
+ numpy.float32
+ )
else:
profile = numpy.zeros((nimages, width), dtype=numpy.float32)
# Compute effective ROI in plot coords
- profileBounds = numpy.array(
- (0, width, width, 0),
- dtype=numpy.float32) * scale[axis] + origin[axis]
- roiBounds = numpy.array(
- (start, start, end, end),
- dtype=numpy.float32) * scale[1 - axis] + origin[1 - axis]
+ profileBounds = (
+ numpy.array((0, width, width, 0), dtype=numpy.float32) * scale[axis]
+ + origin[axis]
+ )
+ roiBounds = (
+ numpy.array((start, start, end, end), dtype=numpy.float32) * scale[1 - axis]
+ + origin[1 - axis]
+ )
if axis == 0: # Horizontal profile
area = profileBounds, roiBounds
@@ -272,7 +293,7 @@ def _alignedPartialProfile(data, rowRange, colRange, axis, method):
assert len(data.shape) == 3
assert rowRange[0] < rowRange[1]
assert colRange[0] < colRange[1]
- assert method in ('mean', 'sum')
+ assert method in ("mean", "sum")
nimages, height, width = data.shape
@@ -287,22 +308,23 @@ def _alignedPartialProfile(data, rowRange, colRange, axis, method):
colStart = min(max(0, colRange[0]), width)
colEnd = min(max(0, colRange[1]), width)
- if method == 'mean':
+ if method == "mean":
_fct = numpy.mean
- elif method == 'sum':
+ elif method == "sum":
_fct = numpy.sum
else:
- raise ValueError('method not managed')
+ raise ValueError("method not managed")
- imgProfile = _fct(data[:, rowStart:rowEnd, colStart:colEnd], axis=axis + 1,
- dtype=numpy.float32)
+ imgProfile = _fct(
+ data[:, rowStart:rowEnd, colStart:colEnd], axis=axis + 1, dtype=numpy.float32
+ )
# Profile including out of bound area
profile = numpy.zeros((nimages, profileLength), dtype=numpy.float32)
# Place imgProfile in full profile
- offset = - min(0, profileRange[0])
- profile[:, offset:offset + imgProfile.shape[1]] = imgProfile
+ offset = -min(0, profileRange[0])
+ profile[:, offset : offset + imgProfile.shape[1]] = imgProfile
return profile
@@ -346,14 +368,12 @@ def createProfile(roiInfo, currentData, origin, scale, lineWidth, method):
roiWidth = max(1, lineWidth)
roiStart, roiEnd, lineProjectionMode = roiInfo
- if lineProjectionMode == 'X': # Horizontal profile on the whole image
- profile, area = _alignedFullProfile(currentData3D,
- origin, scale,
- roiStart[1], roiWidth,
- axis=0,
- method=method)
+ if lineProjectionMode == "X": # Horizontal profile on the whole image
+ profile, area = _alignedFullProfile(
+ currentData3D, origin, scale, roiStart[1], roiWidth, axis=0, method=method
+ )
- if method == 'none':
+ if method == "none":
coords = None
else:
coords = numpy.arange(len(profile[0]), dtype=numpy.float32)
@@ -361,19 +381,17 @@ def createProfile(roiInfo, currentData, origin, scale, lineWidth, method):
yMin, yMax = min(area[1]), max(area[1]) - 1
if roiWidth <= 1:
- profileName = '{ylabel} = %g' % yMin
+ profileName = "{ylabel} = %g" % yMin
else:
- profileName = '{ylabel} = [%g, %g]' % (yMin, yMax)
- xLabel = '{xlabel}'
+ profileName = "{ylabel} = [%g, %g]" % (yMin, yMax)
+ xLabel = "{xlabel}"
- elif lineProjectionMode == 'Y': # Vertical profile on the whole image
- profile, area = _alignedFullProfile(currentData3D,
- origin, scale,
- roiStart[0], roiWidth,
- axis=1,
- method=method)
+ elif lineProjectionMode == "Y": # Vertical profile on the whole image
+ profile, area = _alignedFullProfile(
+ currentData3D, origin, scale, roiStart[0], roiWidth, axis=1, method=method
+ )
- if method == 'none':
+ if method == "none":
coords = None
else:
coords = numpy.arange(len(profile[0]), dtype=numpy.float32)
@@ -381,21 +399,20 @@ def createProfile(roiInfo, currentData, origin, scale, lineWidth, method):
xMin, xMax = min(area[0]), max(area[0]) - 1
if roiWidth <= 1:
- profileName = '{xlabel} = %g' % xMin
+ profileName = "{xlabel} = %g" % xMin
else:
- profileName = '{xlabel} = [%g, %g]' % (xMin, xMax)
- xLabel = '{ylabel}'
+ profileName = "{xlabel} = [%g, %g]" % (xMin, xMax)
+ xLabel = "{ylabel}"
else: # Free line profile
-
# Convert start and end points in image coords as (row, col)
- startPt = ((roiStart[1] - origin[1]) / scale[1],
- (roiStart[0] - origin[0]) / scale[0])
- endPt = ((roiEnd[1] - origin[1]) / scale[1],
- (roiEnd[0] - origin[0]) / scale[0])
+ startPt = (
+ (roiStart[1] - origin[1]) / scale[1],
+ (roiStart[0] - origin[0]) / scale[0],
+ )
+ endPt = ((roiEnd[1] - origin[1]) / scale[1], (roiEnd[0] - origin[0]) / scale[0])
- if (int(startPt[0]) == int(endPt[0]) or
- int(startPt[1]) == int(endPt[1])):
+ if int(startPt[0]) == int(endPt[0]) or int(startPt[1]) == int(endPt[1]):
# Profile is aligned with one of the axes
# Convert to int
@@ -407,62 +424,75 @@ def createProfile(roiInfo, currentData, origin, scale, lineWidth, method):
startPt, endPt = endPt, startPt
if startPt[0] == endPt[0]: # Row aligned
- rowRange = (int(startPt[0] + 0.5 - 0.5 * roiWidth),
- int(startPt[0] + 0.5 + 0.5 * roiWidth))
+ rowRange = (
+ int(startPt[0] + 0.5 - 0.5 * roiWidth),
+ int(startPt[0] + 0.5 + 0.5 * roiWidth),
+ )
colRange = startPt[1], endPt[1] + 1
- if method == 'none':
+ if method == "none":
profile = None
else:
- profile = _alignedPartialProfile(currentData3D,
- rowRange, colRange,
- axis=0,
- method=method)
+ profile = _alignedPartialProfile(
+ currentData3D, rowRange, colRange, axis=0, method=method
+ )
else: # Column aligned
rowRange = startPt[0], endPt[0] + 1
- colRange = (int(startPt[1] + 0.5 - 0.5 * roiWidth),
- int(startPt[1] + 0.5 + 0.5 * roiWidth))
- if method == 'none':
+ colRange = (
+ int(startPt[1] + 0.5 - 0.5 * roiWidth),
+ int(startPt[1] + 0.5 + 0.5 * roiWidth),
+ )
+ if method == "none":
profile = None
else:
- profile = _alignedPartialProfile(currentData3D,
- rowRange, colRange,
- axis=1,
- method=method)
+ profile = _alignedPartialProfile(
+ currentData3D, rowRange, colRange, axis=1, method=method
+ )
# Convert ranges to plot coords to draw ROI area
area = (
numpy.array(
(colRange[0], colRange[1], colRange[1], colRange[0]),
- dtype=numpy.float32) * scale[0] + origin[0],
+ dtype=numpy.float32,
+ )
+ * scale[0]
+ + origin[0],
numpy.array(
(rowRange[0], rowRange[0], rowRange[1], rowRange[1]),
- dtype=numpy.float32) * scale[1] + origin[1])
+ dtype=numpy.float32,
+ )
+ * scale[1]
+ + origin[1],
+ )
else: # General case: use bilinear interpolation
-
# Ensure startPt <= endPt
- if (startPt[1] > endPt[1] or (
- startPt[1] == endPt[1] and startPt[0] > endPt[0])):
+ if startPt[1] > endPt[1] or (
+ startPt[1] == endPt[1] and startPt[0] > endPt[0]
+ ):
startPt, endPt = endPt, startPt
- if method == 'none':
+ if method == "none":
profile = None
else:
profile = []
for slice_idx in range(currentData3D.shape[0]):
bilinear = BilinearImage(currentData3D[slice_idx, :, :])
- profile.append(bilinear.profile_line(
- (startPt[0] - 0.5, startPt[1] - 0.5),
- (endPt[0] - 0.5, endPt[1] - 0.5),
- roiWidth,
- method=method))
+ profile.append(
+ bilinear.profile_line(
+ (startPt[0] - 0.5, startPt[1] - 0.5),
+ (endPt[0] - 0.5, endPt[1] - 0.5),
+ roiWidth,
+ method=method,
+ )
+ )
profile = numpy.array(profile)
# Extend ROI with half a pixel on each end, and
# Convert back to plot coords (x, y)
- length = numpy.sqrt((endPt[0] - startPt[0]) ** 2 +
- (endPt[1] - startPt[1]) ** 2)
+ length = numpy.sqrt(
+ (endPt[0] - startPt[0]) ** 2 + (endPt[1] - startPt[1]) ** 2
+ )
dRow = (endPt[0] - startPt[0]) / length
dCol = (endPt[1] - startPt[1]) / length
@@ -474,16 +504,29 @@ def createProfile(roiInfo, currentData, origin, scale, lineWidth, method):
dRow, dCol = dCol, -dRow
area = (
- numpy.array((roiStartPt[1] - 0.5 * roiWidth * dCol,
- roiStartPt[1] + 0.5 * roiWidth * dCol,
- roiEndPt[1] + 0.5 * roiWidth * dCol,
- roiEndPt[1] - 0.5 * roiWidth * dCol),
- dtype=numpy.float32) * scale[0] + origin[0],
- numpy.array((roiStartPt[0] - 0.5 * roiWidth * dRow,
- roiStartPt[0] + 0.5 * roiWidth * dRow,
- roiEndPt[0] + 0.5 * roiWidth * dRow,
- roiEndPt[0] - 0.5 * roiWidth * dRow),
- dtype=numpy.float32) * scale[1] + origin[1])
+ numpy.array(
+ (
+ roiStartPt[1] - 0.5 * roiWidth * dCol,
+ roiStartPt[1] + 0.5 * roiWidth * dCol,
+ roiEndPt[1] + 0.5 * roiWidth * dCol,
+ roiEndPt[1] - 0.5 * roiWidth * dCol,
+ ),
+ dtype=numpy.float32,
+ )
+ * scale[0]
+ + origin[0],
+ numpy.array(
+ (
+ roiStartPt[0] - 0.5 * roiWidth * dRow,
+ roiStartPt[0] + 0.5 * roiWidth * dRow,
+ roiEndPt[0] + 0.5 * roiWidth * dRow,
+ roiEndPt[0] - 0.5 * roiWidth * dRow,
+ ),
+ dtype=numpy.float32,
+ )
+ * scale[1]
+ + origin[1],
+ )
# Convert start and end points back to plot coords
y0 = startPt[0] * scale[1] + origin[1]
@@ -492,33 +535,33 @@ def createProfile(roiInfo, currentData, origin, scale, lineWidth, method):
x1 = endPt[1] * scale[0] + origin[0]
if startPt[1] == endPt[1]:
- profileName = '{xlabel} = %g; {ylabel} = [%g, %g]' % (x0, y0, y1)
- if method == 'none':
+ profileName = "{xlabel} = %g; {ylabel} = [%g, %g]" % (x0, y0, y1)
+ if method == "none":
coords = None
else:
coords = numpy.arange(len(profile[0]), dtype=numpy.float32)
coords = coords * scale[1] + y0
- xLabel = '{ylabel}'
+ xLabel = "{ylabel}"
elif startPt[0] == endPt[0]:
- profileName = '{ylabel} = %g; {xlabel} = [%g, %g]' % (y0, x0, x1)
- if method == 'none':
+ profileName = "{ylabel} = %g; {xlabel} = [%g, %g]" % (y0, x0, x1)
+ if method == "none":
coords = None
else:
coords = numpy.arange(len(profile[0]), dtype=numpy.float32)
coords = coords * scale[0] + x0
- xLabel = '{xlabel}'
+ xLabel = "{xlabel}"
else:
m = (y1 - y0) / (x1 - x0)
b = y0 - m * x0
- profileName = '{ylabel} = %g * {xlabel} %+g' % (m, b)
- if method == 'none':
+ profileName = "{ylabel} = %g * {xlabel} %+g" % (m, b)
+ if method == "none":
coords = None
else:
- coords = numpy.linspace(x0, x1, len(profile[0]),
- endpoint=True,
- dtype=numpy.float32)
- xLabel = '{xlabel}'
+ coords = numpy.linspace(
+ x0, x1, len(profile[0]), endpoint=True, dtype=numpy.float32
+ )
+ xLabel = "{xlabel}"
return coords, profile, area, profileName, xLabel
diff --git a/src/silx/gui/plot/tools/profile/editors.py b/src/silx/gui/plot/tools/profile/editors.py
index 1d6f198..d53f775 100644
--- a/src/silx/gui/plot/tools/profile/editors.py
+++ b/src/silx/gui/plot/tools/profile/editors.py
@@ -43,7 +43,6 @@ _logger = logging.getLogger(__name__)
class _NoProfileRoiEditor(qt.QWidget):
-
sigDataCommited = qt.Signal()
def setEditorData(self, roi):
@@ -54,7 +53,6 @@ class _NoProfileRoiEditor(qt.QWidget):
class _DefaultImageProfileRoiEditor(qt.QWidget):
-
sigDataCommited = qt.Signal()
def __init__(self, parent=None):
@@ -72,7 +70,7 @@ class _DefaultImageProfileRoiEditor(qt.QWidget):
self._methodsButton = ProfileOptionToolButton(parent=self, plot=None)
self._methodsButton.sigMethodChanged.connect(self._widgetChanged)
- label = qt.QLabel('W:')
+ label = qt.QLabel("W:")
label.setToolTip("Line width in pixels")
layout.addWidget(label)
layout.addWidget(self._lineWidth)
@@ -99,7 +97,6 @@ class _DefaultImageProfileRoiEditor(qt.QWidget):
class _DefaultImageStackProfileRoiEditor(_DefaultImageProfileRoiEditor):
-
def _initLayout(self, layout):
super(_DefaultImageStackProfileRoiEditor, self)._initLayout(layout)
self._profileDim = ProfileToolButton(parent=self, plot=None)
@@ -121,7 +118,6 @@ class _DefaultImageStackProfileRoiEditor(_DefaultImageProfileRoiEditor):
class _DefaultScatterProfileRoiEditor(qt.QWidget):
-
sigDataCommited = qt.Signal()
def __init__(self, parent=None):
@@ -134,7 +130,7 @@ class _DefaultScatterProfileRoiEditor(qt.QWidget):
layout = qt.QHBoxLayout(self)
layout.setContentsMargins(0, 0, 0, 0)
- label = qt.QLabel('Samples:')
+ label = qt.QLabel("Samples:")
label.setToolTip("Number of sample points of the profile")
layout.addWidget(label)
layout.addWidget(self._nPoints)
@@ -160,6 +156,7 @@ class ProfileRoiEditorAction(qt.QWidgetAction):
:param qt.QWidget parent: Parent widget
"""
+
def __init__(self, parent=None):
super(ProfileRoiEditorAction, self).__init__(parent)
self.__roiManager = None
@@ -237,8 +234,7 @@ class ProfileRoiEditorAction(qt.QWidgetAction):
return self.__roi
def __roiPropertyChanged(self):
- """Handle changes on the property defining the ROI.
- """
+ """Handle changes on the property defining the ROI."""
self._updateWidgetValues()
def __setEditor(self, widget, editor):
@@ -265,16 +261,20 @@ class ProfileRoiEditorAction(qt.QWidgetAction):
"""Returns the editor class to use according to the ROI."""
if roi is None:
editorClass = _NoProfileRoiEditor
- elif isinstance(roi, (rois._DefaultImageStackProfileRoiMixIn,
- rois.ProfileImageStackCrossROI)):
+ elif isinstance(
+ roi,
+ (rois._DefaultImageStackProfileRoiMixIn, rois.ProfileImageStackCrossROI),
+ ):
# Must be done before the default image ROI
# Cause ImageStack ROIs inherit from Image ROIs
editorClass = _DefaultImageStackProfileRoiEditor
- elif isinstance(roi, (rois._DefaultImageProfileRoiMixIn,
- rois.ProfileImageCrossROI)):
+ elif isinstance(
+ roi, (rois._DefaultImageProfileRoiMixIn, rois.ProfileImageCrossROI)
+ ):
editorClass = _DefaultImageProfileRoiEditor
- elif isinstance(roi, (rois._DefaultScatterProfileRoiMixIn,
- rois.ProfileScatterCrossROI)):
+ elif isinstance(
+ roi, (rois._DefaultScatterProfileRoiMixIn, rois.ProfileScatterCrossROI)
+ ):
editorClass = _DefaultScatterProfileRoiEditor
else:
# Unsupported
diff --git a/src/silx/gui/plot/tools/profile/manager.py b/src/silx/gui/plot/tools/profile/manager.py
index 58c1c86..6f4ba35 100644
--- a/src/silx/gui/plot/tools/profile/manager.py
+++ b/src/silx/gui/plot/tools/profile/manager.py
@@ -64,12 +64,12 @@ class _RunnableComputeProfile(qt.QRunnable):
class _Signals(qt.QObject):
"""Signal holder"""
+
resultReady = qt.Signal(object, object)
runnerFinished = qt.Signal(object)
def __init__(self, threadPool, item, roi):
- """Constructor
- """
+ """Constructor"""
super(_RunnableComputeProfile, self).__init__()
self._signals = self._Signals()
self._signals.moveToThread(threadPool.thread())
@@ -114,8 +114,7 @@ class _RunnableComputeProfile(qt.QRunnable):
return self._signals.runnerFinished
def run(self):
- """Process the profile computation.
- """
+ """Process the profile computation."""
if not self._cancelled:
try:
profileData = self._roi.computeProfile(self._item)
@@ -141,7 +140,7 @@ class ProfileWindow(qt.QMainWindow):
def __init__(self, parent=None, backend=None):
qt.QMainWindow.__init__(self, parent=parent, flags=qt.Qt.Dialog)
- self.setWindowTitle('Profile window')
+ self.setWindowTitle("Profile window")
self._plot1D = None
self._plot2D = None
self._backend = backend
@@ -175,10 +174,11 @@ class ProfileWindow(qt.QMainWindow):
"""
# import here to avoid circular import
from ...PlotWindow import Plot1D
+
plot = Plot1D(parent=parent, backend=backend)
plot.setDataMargins(yMinMargin=0.1, yMaxMargin=0.1)
- plot.setGraphYLabel('Profile')
- plot.setGraphXLabel('')
+ plot.setGraphYLabel("Profile")
+ plot.setGraphXLabel("")
positionInfo = plot.getPositionInfoWidget()
positionInfo.setSnappingMode(positionInfo.SNAPPING_CURVE)
return plot
@@ -194,6 +194,7 @@ class ProfileWindow(qt.QMainWindow):
"""
# import here to avoid circular import
from ...PlotWindow import Plot2D
+
return Plot2D(parent=parent, backend=backend)
def getPlot1D(self, init=True):
@@ -241,12 +242,12 @@ class ProfileWindow(qt.QMainWindow):
return
self.__color = colors.rgba(roi.getColor())
- def _setImageProfile(self, data):
+ def _setImageProfile(self, data: core.ImageProfileData):
"""
Setup the window to display a new profile data which is represented
by an image.
- :param core.ImageProfileData data: Computed data profile
+ :param data: Computed data profile
"""
plot = self.getPlot2D()
@@ -254,25 +255,26 @@ class ProfileWindow(qt.QMainWindow):
plot.setGraphTitle(data.title)
plot.getXAxis().setLabel(data.xLabel)
-
coords = data.coords
colormap = data.colormap
profileScale = (coords[-1] - coords[0]) / data.profile.shape[1], 1
- plot.addImage(data.profile,
- legend="profile",
- colormap=colormap,
- origin=(coords[0], 0),
- scale=profileScale)
+ plot.addImage(
+ data.profile,
+ legend="profile",
+ colormap=colormap,
+ origin=(coords[0], 0),
+ scale=profileScale,
+ )
plot.getYAxis().setLabel("Frame index (depth)")
self._showPlot2D()
- def _setCurveProfile(self, data):
+ def _setCurveProfile(self, data: core.CurveProfileData):
"""
Setup the window to display a new profile data which is represented
by a curve.
- :param core.CurveProfileData data: Computed data profile
+ :param data: Computed data profile
"""
plot = self.getPlot1D()
@@ -281,19 +283,16 @@ class ProfileWindow(qt.QMainWindow):
plot.getXAxis().setLabel(data.xLabel)
plot.getYAxis().setLabel(data.yLabel)
- plot.addCurve(data.coords,
- data.profile,
- legend="level",
- color=self.__color)
+ plot.addCurve(data.coords, data.profile, legend="level", color=self.__color)
self._showPlot1D()
- def _setRgbaProfile(self, data):
+ def _setRgbaProfile(self, data: core.RgbaProfileData):
"""
Setup the window to display a new profile data which is represented
by a curve.
- :param core.RgbaProfileData data: Computed data profile
+ :param data: Computed data profile
"""
plot = self.getPlot1D()
@@ -304,17 +303,33 @@ class ProfileWindow(qt.QMainWindow):
self._showPlot1D()
- plot.addCurve(data.coords, data.profile,
- legend="level", color="black")
- plot.addCurve(data.coords, data.profile_r,
- legend="red", color="red")
- plot.addCurve(data.coords, data.profile_g,
- legend="green", color="green")
- plot.addCurve(data.coords, data.profile_b,
- legend="blue", color="blue")
+ plot.addCurve(data.coords, data.profile, legend="level", color="black")
+ plot.addCurve(data.coords, data.profile_r, legend="red", color="red")
+ plot.addCurve(data.coords, data.profile_g, legend="green", color="green")
+ plot.addCurve(data.coords, data.profile_b, legend="blue", color="blue")
if data.profile_a is not None:
plot.addCurve(data.coords, data.profile_a, legend="alpha", color="gray")
+ def _setCurvesProfile(self, data: core.CurvesProfileData):
+ """
+ Setup the window to display a new profile data which is represented
+ by multiple curves.
+
+ :param data: Computed data profile
+ """
+ plot = self.getPlot1D()
+
+ plot.clear()
+ plot.setGraphTitle(data.title)
+ plot.getXAxis().setLabel(data.xLabel)
+ plot.getYAxis().setLabel(data.yLabel)
+
+ self._showPlot1D()
+
+ for i, desc in enumerate(data.profiles):
+ name = desc.name if desc.name is not None else f"profile{i}"
+ plot.addCurve(data.coords, desc.profile, legend=name, color=desc.color)
+
def clear(self):
"""Clear the window profile"""
plot = self.getPlot1D(init=False)
@@ -346,6 +361,8 @@ class ProfileWindow(qt.QMainWindow):
self._setRgbaProfile(data)
elif isinstance(data, core.CurveProfileData):
self._setCurveProfile(data)
+ elif isinstance(data, core.CurvesProfileData):
+ self._setCurvesProfile(data)
else:
raise TypeError("Unsupported type %s" % type(data))
@@ -359,10 +376,10 @@ class _ClearAction(qt.QAction):
def __init__(self, parent, profileManager):
super(_ClearAction, self).__init__(parent)
self.__profileManager = weakref.ref(profileManager)
- icon = icons.getQIcon('profile-clear')
+ icon = icons.getQIcon("profile-clear")
self.setIcon(icon)
- self.setText('Clear profile')
- self.setToolTip('Clear the profiles')
+ self.setText("Clear profile")
+ self.setToolTip("Clear the profiles")
self.setCheckable(False)
self.setEnabled(False)
self.triggered.connect(profileManager.clearProfile)
@@ -420,37 +437,47 @@ class _StoreLastParamBehavior(qt.QObject):
if previousRoi is roi:
return
if previousRoi is not None:
- previousRoi.sigProfilePropertyChanged.disconnect(self._profilePropertyChanged)
+ previousRoi.sigProfilePropertyChanged.disconnect(
+ self._profilePropertyChanged
+ )
self.__profileRoi = None if roi is None else weakref.ref(roi)
if roi is not None:
roi.sigProfilePropertyChanged.connect(self._profilePropertyChanged)
def _profilePropertyChanged(self):
- """Handle changes on the properties defining the profile ROI.
- """
+ """Handle changes on the properties defining the profile ROI."""
if self.__filter.locked():
return
roi = self.sender()
self.storeProperties(roi)
def storeProperties(self, roi):
- if isinstance(roi, (rois._DefaultImageStackProfileRoiMixIn,
- rois.ProfileImageStackCrossROI)):
+ if isinstance(
+ roi,
+ (rois._DefaultImageStackProfileRoiMixIn, rois.ProfileImageStackCrossROI),
+ ):
self.__properties["method"] = roi.getProfileMethod()
self.__properties["line-width"] = roi.getProfileLineWidth()
self.__properties["type"] = roi.getProfileType()
- elif isinstance(roi, (rois._DefaultImageProfileRoiMixIn,
- rois.ProfileImageCrossROI)):
+ elif isinstance(
+ roi, (rois._DefaultImageProfileRoiMixIn, rois.ProfileImageCrossROI)
+ ):
self.__properties["method"] = roi.getProfileMethod()
self.__properties["line-width"] = roi.getProfileLineWidth()
- elif isinstance(roi, (rois._DefaultScatterProfileRoiMixIn,
- rois.ProfileScatterCrossROI)):
+ elif isinstance(
+ roi, (rois._DefaultScatterProfileRoiMixIn, rois.ProfileScatterCrossROI)
+ ):
self.__properties["npoints"] = roi.getNPoints()
def restoreProperties(self, roi):
with self.__filter:
- if isinstance(roi, (rois._DefaultImageStackProfileRoiMixIn,
- rois.ProfileImageStackCrossROI)):
+ if isinstance(
+ roi,
+ (
+ rois._DefaultImageStackProfileRoiMixIn,
+ rois.ProfileImageStackCrossROI,
+ ),
+ ):
value = self.__properties.get("method", None)
if value is not None:
roi.setProfileMethod(value)
@@ -460,16 +487,18 @@ class _StoreLastParamBehavior(qt.QObject):
value = self.__properties.get("type", None)
if value is not None:
roi.setProfileType(value)
- elif isinstance(roi, (rois._DefaultImageProfileRoiMixIn,
- rois.ProfileImageCrossROI)):
+ elif isinstance(
+ roi, (rois._DefaultImageProfileRoiMixIn, rois.ProfileImageCrossROI)
+ ):
value = self.__properties.get("method", None)
if value is not None:
roi.setProfileMethod(value)
value = self.__properties.get("line-width", None)
if value is not None:
roi.setProfileLineWidth(value)
- elif isinstance(roi, (rois._DefaultScatterProfileRoiMixIn,
- rois.ProfileScatterCrossROI)):
+ elif isinstance(
+ roi, (rois._DefaultScatterProfileRoiMixIn, rois.ProfileScatterCrossROI)
+ ):
value = self.__properties.get("npoints", None)
if value is not None:
roi.setNPoints(value)
@@ -482,12 +511,12 @@ class ProfileManager(qt.QObject):
:param plot: :class:`~silx.gui.plot.tools.roi.RegionOfInterestManager`
on which to operate.
"""
+
def __init__(self, parent=None, plot=None, roiManager=None):
super(ProfileManager, self).__init__(parent)
assert isinstance(plot, PlotWidget)
- self._plotRef = weakref.ref(
- plot, WeakMethodProxy(self.__plotDestroyed))
+ self._plotRef = weakref.ref(plot, WeakMethodProxy(self.__plotDestroyed))
# Set-up interaction manager
if roiManager is None:
@@ -590,14 +619,16 @@ class ProfileManager(qt.QObject):
if hasattr(profileRoiClass, "ICON"):
action.setIcon(icons.getQIcon(profileRoiClass.ICON))
if hasattr(profileRoiClass, "NAME"):
+
def articulify(word):
"""Add an an/a article in the front of the word"""
- first = word[1] if word[0] == 'h' else word[0]
+ first = word[1] if word[0] == "h" else word[0]
if first in "aeiou":
return "an " + word
return "a " + word
- action.setText('Define %s' % articulify(profileRoiClass.NAME))
- action.setToolTip('Enables %s selection mode' % profileRoiClass.NAME)
+
+ action.setText("Define %s" % articulify(profileRoiClass.NAME))
+ action.setToolTip("Enables %s selection mode" % profileRoiClass.NAME)
action.setSingleShot(True)
return action
@@ -623,7 +654,7 @@ class ProfileManager(qt.QObject):
rois.ProfileImageLineROI,
rois.ProfileImageDirectedLineROI,
rois.ProfileImageCrossROI,
- ]
+ ]
return [self.createProfileAction(pc, parent=parent) for pc in profileClasses]
def createScatterActions(self, parent):
@@ -638,7 +669,7 @@ class ProfileManager(qt.QObject):
rois.ProfileScatterVerticalLineROI,
rois.ProfileScatterLineROI,
rois.ProfileScatterCrossROI,
- ]
+ ]
return [self.createProfileAction(pc, parent=parent) for pc in profileClasses]
def createScatterSliceActions(self, parent):
@@ -655,7 +686,7 @@ class ProfileManager(qt.QObject):
rois.ProfileScatterHorizontalSliceROI,
rois.ProfileScatterVerticalSliceROI,
rois.ProfileScatterCrossSliceROI,
- ]
+ ]
return [self.createProfileAction(pc, parent=parent) for pc in profileClasses]
def createImageStackActions(self, parent):
@@ -673,7 +704,7 @@ class ProfileManager(qt.QObject):
rois.ProfileImageStackVerticalLineROI,
rois.ProfileImageStackLineROI,
rois.ProfileImageStackCrossROI,
- ]
+ ]
return [self.createProfileAction(pc, parent=parent) for pc in profileClasses]
def createEditorAction(self, parent):
@@ -705,8 +736,7 @@ class ProfileManager(qt.QObject):
self.setPlotItem(item)
def setProfileWindowClass(self, profileWindowClass):
- """Set the class which will be instantiated to display profile result.
- """
+ """Set the class which will be instantiated to display profile result."""
self._profileWindowClass = profileWindowClass
def setActiveItemTracking(self, tracking):
@@ -798,7 +828,7 @@ class ProfileManager(qt.QObject):
roiManager.removeRoi(roi)
if not roiManager.isDrawing():
- # Clean the selected mode
+ # Clean the selected mode
roiManager.stop()
def hasPendingOperations(self):
@@ -809,8 +839,7 @@ class ProfileManager(qt.QObject):
return len(self.__reentrantResults) > 0 or len(self._pendingRunners) > 0
def requestUpdateAllProfile(self):
- """Request to update the profile of all the managed ROIs.
- """
+ """Request to update the profile of all the managed ROIs."""
for roi in self._rois:
self.requestUpdateProfile(roi)
@@ -868,7 +897,7 @@ class ProfileManager(qt.QObject):
if roi in self.__reentrantResults:
# Store the data to process it in the main loop
# And not a sub loop created by initProfileWindow
- # This also remove the duplicated requested
+ # This also remove the duplicated requested
self.__reentrantResults[roi] = profileData
return
@@ -918,7 +947,7 @@ class ProfileManager(qt.QObject):
:param ~silx.gui.plot.items.item.Item item: AN item
:rtype: qt.QColor
"""
- color = 'pink'
+ color = "pink"
if isinstance(item, items.ColormapMixIn):
colormap = item.getColormap()
name = colormap.getName()
@@ -948,12 +977,13 @@ class ProfileManager(qt.QObject):
roi.setColor(color)
def __itemChanged(self, changeType):
- """Handle item changes.
- """
- if changeType in (items.ItemChangedType.DATA,
- items.ItemChangedType.MASK,
- items.ItemChangedType.POSITION,
- items.ItemChangedType.SCALE):
+ """Handle item changes."""
+ if changeType in (
+ items.ItemChangedType.DATA,
+ items.ItemChangedType.MASK,
+ items.ItemChangedType.POSITION,
+ items.ItemChangedType.SCALE,
+ ):
self.requestUpdateAllProfile()
elif changeType == (items.ItemChangedType.COLORMAP):
self._updateRoiColors()
@@ -1040,7 +1070,7 @@ class ProfileManager(qt.QObject):
window = self.getPlotWidget().window()
winGeom = window.frameGeometry()
- if qt.BINDING in ("PySide2", "PyQt5"):
+ if qt.BINDING == "PyQt5":
qapp = qt.QApplication.instance()
desktop = qapp.desktop()
screenGeom = desktop.availableGeometry(window)
@@ -1070,7 +1100,6 @@ class ProfileManager(qt.QObject):
left = screenGeom.width() - profileGeom.width()
profileWindow.move(left, top)
-
def clearProfileWindow(self, profileWindow):
"""Called when a profile window is not anymore needed.
diff --git a/src/silx/gui/plot/tools/profile/rois.py b/src/silx/gui/plot/tools/profile/rois.py
index 042aff1..23f086a 100644
--- a/src/silx/gui/plot/tools/profile/rois.py
+++ b/src/silx/gui/plot/tools/profile/rois.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2018-2021 European Synchrotron Radiation Facility
+# Copyright (c) 2018-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -75,13 +75,13 @@ def _lineProfileTitle(x0, y0, x1, y1):
:rtype: str
"""
if x0 == x1:
- title = '{xlabel} = %g; {ylabel} = [%g, %g]' % (x0, y0, y1)
+ title = "{xlabel} = %g; {ylabel} = [%g, %g]" % (x0, y0, y1)
elif y0 == y1:
- title = '{ylabel} = %g; {xlabel} = [%g, %g]' % (y0, x0, x1)
+ title = "{ylabel} = %g; {xlabel} = [%g, %g]" % (y0, x0, x1)
else:
m = (y1 - y0) / (x1 - x0)
b = y0 - m * x0
- title = '{ylabel} = %g * {xlabel} %+g' % (m, b)
+ title = "{ylabel} = %g * {xlabel} %+g" % (m, b)
return title
@@ -147,7 +147,8 @@ class _ImageProfileArea(items.Shape):
origin=origin,
scale=scale,
lineWidth=roi.getProfileLineWidth(),
- method="none")
+ method="none",
+ )
return area
@@ -214,8 +215,7 @@ class _SliceProfileArea(items.Shape):
class _DefaultImageProfileRoiMixIn(core.ProfileRoiMixIn):
- """Provide common behavior for silx default image profile ROI.
- """
+ """Provide common behavior for silx default image profile ROI."""
ITEM_KIND = items.ImageBase
@@ -265,21 +265,21 @@ class _DefaultImageProfileRoiMixIn(core.ProfileRoiMixIn):
def _getRoiInfo(self):
"""Wrapper to allow to reuse the previous Profile code.
-
+
It would be good to remove it at one point.
"""
if isinstance(self, roi_items.HorizontalLineROI):
- lineProjectionMode = 'X'
+ lineProjectionMode = "X"
y = self.getPosition()
roiStart = (0, y)
roiEnd = (1, y)
elif isinstance(self, roi_items.VerticalLineROI):
- lineProjectionMode = 'Y'
+ lineProjectionMode = "Y"
x = self.getPosition()
roiStart = (x, 0)
roiEnd = (x, 1)
elif isinstance(self, roi_items.LineROI):
- lineProjectionMode = 'D'
+ lineProjectionMode = "D"
roiStart, roiEnd = self.getEndPoints()
else:
assert False
@@ -294,15 +294,17 @@ class _DefaultImageProfileRoiMixIn(core.ProfileRoiMixIn):
scale = item.getScale()
method = self.getProfileMethod()
lineWidth = self.getProfileLineWidth()
+ roiInfo = self._getRoiInfo()
def createProfile2(currentData):
coords, profile, _area, profileName, xLabel = core.createProfile(
- roiInfo=self._getRoiInfo(),
+ roiInfo=roiInfo,
currentData=currentData,
origin=origin,
scale=scale,
lineWidth=lineWidth,
- method=method)
+ method=method,
+ )
return coords, profile, profileName, xLabel
currentData = item.getValueData(copy=False)
@@ -348,61 +350,61 @@ class _DefaultImageProfileRoiMixIn(core.ProfileRoiMixIn):
return data
-class ProfileImageHorizontalLineROI(roi_items.HorizontalLineROI,
- _DefaultImageProfileRoiMixIn):
+class ProfileImageHorizontalLineROI(
+ roi_items.HorizontalLineROI, _DefaultImageProfileRoiMixIn
+):
"""ROI for an horizontal profile at a location of an image"""
- ICON = 'shape-horizontal'
- NAME = 'horizontal line profile'
+ ICON = "shape-horizontal"
+ NAME = "horizontal line profile"
def __init__(self, parent=None):
roi_items.HorizontalLineROI.__init__(self, parent=parent)
_DefaultImageProfileRoiMixIn.__init__(self, parent=parent)
-class ProfileImageVerticalLineROI(roi_items.VerticalLineROI,
- _DefaultImageProfileRoiMixIn):
+class ProfileImageVerticalLineROI(
+ roi_items.VerticalLineROI, _DefaultImageProfileRoiMixIn
+):
"""ROI for a vertical profile at a location of an image"""
- ICON = 'shape-vertical'
- NAME = 'vertical line profile'
+ ICON = "shape-vertical"
+ NAME = "vertical line profile"
def __init__(self, parent=None):
roi_items.VerticalLineROI.__init__(self, parent=parent)
_DefaultImageProfileRoiMixIn.__init__(self, parent=parent)
-class ProfileImageLineROI(roi_items.LineROI,
- _DefaultImageProfileRoiMixIn):
+class ProfileImageLineROI(roi_items.LineROI, _DefaultImageProfileRoiMixIn):
"""ROI for an image profile between 2 points.
The X profile of this ROI is the projecting into one of the x/y axes,
using its scale and its orientation.
"""
- ICON = 'shape-diagonal'
- NAME = 'line profile'
+ ICON = "shape-diagonal"
+ NAME = "line profile"
def __init__(self, parent=None):
roi_items.LineROI.__init__(self, parent=parent)
_DefaultImageProfileRoiMixIn.__init__(self, parent=parent)
-class ProfileImageDirectedLineROI(roi_items.LineROI,
- _DefaultImageProfileRoiMixIn):
+class ProfileImageDirectedLineROI(roi_items.LineROI, _DefaultImageProfileRoiMixIn):
"""ROI for an image profile between 2 points.
The X profile of the line is displayed projected into the line itself,
using its scale and its orientation. It's the distance from the origin.
"""
- ICON = 'shape-diagonal-directed'
- NAME = 'directed line profile'
+ ICON = "shape-diagonal-directed"
+ NAME = "directed line profile"
def __init__(self, parent=None):
roi_items.LineROI.__init__(self, parent=parent)
_DefaultImageProfileRoiMixIn.__init__(self, parent=parent)
- self._handleStart.setSymbol('o')
+ self._handleStart.setSymbol("o")
def computeProfile(self, item):
if not isinstance(item, items.ImageBase):
@@ -419,10 +421,11 @@ class ProfileImageDirectedLineROI(roi_items.LineROI,
roiInfo = self._getRoiInfo()
roiStart, roiEnd, _lineProjectionMode = roiInfo
- startPt = ((roiStart[1] - origin[1]) / scale[1],
- (roiStart[0] - origin[0]) / scale[0])
- endPt = ((roiEnd[1] - origin[1]) / scale[1],
- (roiEnd[0] - origin[0]) / scale[0])
+ startPt = (
+ (roiStart[1] - origin[1]) / scale[1],
+ (roiStart[0] - origin[0]) / scale[0],
+ )
+ endPt = ((roiEnd[1] - origin[1]) / scale[1], (roiEnd[0] - origin[0]) / scale[0])
if numpy.array_equal(startPt, endPt):
return None
@@ -432,14 +435,16 @@ class ProfileImageDirectedLineROI(roi_items.LineROI,
(startPt[0] - 0.5, startPt[1] - 0.5),
(endPt[0] - 0.5, endPt[1] - 0.5),
lineWidth,
- method=method)
+ method=method,
+ )
# Compute the line size
- lineSize = numpy.sqrt((roiEnd[1] - roiStart[1]) ** 2 +
- (roiEnd[0] - roiStart[0]) ** 2)
- coords = numpy.linspace(0, lineSize, len(profile),
- endpoint=True,
- dtype=numpy.float32)
+ lineSize = numpy.sqrt(
+ (roiEnd[1] - roiStart[1]) ** 2 + (roiEnd[0] - roiStart[0]) ** 2
+ )
+ coords = numpy.linspace(
+ 0, lineSize, len(profile), endpoint=True, dtype=numpy.float32
+ )
title = _lineProfileTitle(*roiStart, *roiEnd)
title = title + "; width = %d" % lineWidth
@@ -530,8 +535,7 @@ class _ProfileCrossROI(roi_items.HandleBasedROI, core.ProfileRoiMixIn):
def __updateLineProperty(self, event=None, checkVisibility=True):
if event == items.ItemChangedType.NAME:
self.__handleLabel.setText(self.getName())
- elif event in [items.ItemChangedType.COLOR,
- items.ItemChangedType.VISIBLE]:
+ elif event in [items.ItemChangedType.COLOR, items.ItemChangedType.VISIBLE]:
lines = []
if self.__vline:
lines.append(self.__vline)
@@ -658,8 +662,8 @@ class ProfileImageCrossROI(_ProfileCrossROI):
It is managed using 2 sub ROIs for vertical and horizontal.
"""
- ICON = 'shape-cross'
- NAME = 'cross profile'
+ ICON = "shape-cross"
+ NAME = "cross profile"
ITEM_KIND = items.ImageBase
def _createLines(self, parent):
@@ -692,8 +696,7 @@ class ProfileImageCrossROI(_ProfileCrossROI):
class _DefaultScatterProfileRoiMixIn(core.ProfileRoiMixIn):
- """Provide common behavior for silx default scatter profile ROI.
- """
+ """Provide common behavior for silx default scatter profile ROI."""
ITEM_KIND = items.Scatter
@@ -736,7 +739,7 @@ class _DefaultScatterProfileRoiMixIn(core.ProfileRoiMixIn):
:param float y1: Profile end point Y coord
:return: (points, values) profile data or None
"""
- future = scatter._getInterpolator()
+ future = scatter._getInterpolatorFuture()
try:
interpolator = future.result()
except CancelledError:
@@ -745,15 +748,14 @@ class _DefaultScatterProfileRoiMixIn(core.ProfileRoiMixIn):
return None # Cannot init an interpolator
nPoints = self.getNPoints()
- points = numpy.transpose((
- numpy.linspace(x0, x1, nPoints, endpoint=True),
- numpy.linspace(y0, y1, nPoints, endpoint=True)))
-
- values = interpolator(points)
+ x = numpy.linspace(x0, x1, nPoints, endpoint=True)
+ y = numpy.linspace(y0, y1, nPoints, endpoint=True)
+ values = interpolator(x, y)
if not numpy.any(numpy.isfinite(values)):
return None # Profile outside convex hull
+ points = numpy.transpose((x, y))
return points, values
def computeProfile(self, item):
@@ -778,7 +780,7 @@ class _DefaultScatterProfileRoiMixIn(core.ProfileRoiMixIn):
x0 = x1 = self.getPosition()
y0, y1 = plot.getYAxis().getLimits()
else:
- raise RuntimeError('Unsupported ROI for profile: {}'.format(self.__class__))
+ raise RuntimeError("Unsupported ROI for profile: {}".format(self.__class__))
if x1 < x0 or (x1 == x0 and y1 < y0):
# Invert points
@@ -792,13 +794,14 @@ class _DefaultScatterProfileRoiMixIn(core.ProfileRoiMixIn):
points = profile[0]
values = profile[1]
- if (numpy.abs(points[-1, 0] - points[0, 0]) >
- numpy.abs(points[-1, 1] - points[0, 1])):
+ if numpy.abs(points[-1, 0] - points[0, 0]) > numpy.abs(
+ points[-1, 1] - points[0, 1]
+ ):
xProfile = points[:, 0]
- xLabel = '{xlabel}'
+ xLabel = "{xlabel}"
else:
xProfile = points[:, 1]
- xLabel = '{ylabel}'
+ xLabel = "{ylabel}"
# Use the axis names from the original
profileManager = self.getProfileManager()
@@ -811,41 +814,42 @@ class _DefaultScatterProfileRoiMixIn(core.ProfileRoiMixIn):
profile=values,
title=title,
xLabel=xLabel,
- yLabel='Profile',
+ yLabel="Profile",
)
return data
-class ProfileScatterHorizontalLineROI(roi_items.HorizontalLineROI,
- _DefaultScatterProfileRoiMixIn):
+class ProfileScatterHorizontalLineROI(
+ roi_items.HorizontalLineROI, _DefaultScatterProfileRoiMixIn
+):
"""ROI for an horizontal profile at a location of a scatter"""
- ICON = 'shape-horizontal'
- NAME = 'horizontal line profile'
+ ICON = "shape-horizontal"
+ NAME = "horizontal line profile"
def __init__(self, parent=None):
roi_items.HorizontalLineROI.__init__(self, parent=parent)
_DefaultScatterProfileRoiMixIn.__init__(self, parent=parent)
-class ProfileScatterVerticalLineROI(roi_items.VerticalLineROI,
- _DefaultScatterProfileRoiMixIn):
+class ProfileScatterVerticalLineROI(
+ roi_items.VerticalLineROI, _DefaultScatterProfileRoiMixIn
+):
"""ROI for an horizontal profile at a location of a scatter"""
- ICON = 'shape-vertical'
- NAME = 'vertical line profile'
+ ICON = "shape-vertical"
+ NAME = "vertical line profile"
def __init__(self, parent=None):
roi_items.VerticalLineROI.__init__(self, parent=parent)
_DefaultScatterProfileRoiMixIn.__init__(self, parent=parent)
-class ProfileScatterLineROI(roi_items.LineROI,
- _DefaultScatterProfileRoiMixIn):
+class ProfileScatterLineROI(roi_items.LineROI, _DefaultScatterProfileRoiMixIn):
"""ROI for an horizontal profile at a location of a scatter"""
- ICON = 'shape-diagonal'
- NAME = 'line profile'
+ ICON = "shape-diagonal"
+ NAME = "line profile"
def __init__(self, parent=None):
roi_items.LineROI.__init__(self, parent=parent)
@@ -853,11 +857,10 @@ class ProfileScatterLineROI(roi_items.LineROI,
class ProfileScatterCrossROI(_ProfileCrossROI):
- """ROI to manage a cross of profiles for scatters.
- """
+ """ROI to manage a cross of profiles for scatters."""
- ICON = 'shape-cross'
- NAME = 'cross profile'
+ ICON = "shape-cross"
+ NAME = "cross profile"
ITEM_KIND = items.Scatter
def _createLines(self, parent):
@@ -909,7 +912,9 @@ class _DefaultScatterProfileSliceRoiMixIn(core.ProfileRoiMixIn):
def _getSlice(self, item):
position = self.getPosition()
- bounds = item.getCurrentVisualizationParameter(items.Scatter.VisualizationParameter.GRID_BOUNDS)
+ bounds = item.getCurrentVisualizationParameter(
+ items.Scatter.VisualizationParameter.GRID_BOUNDS
+ )
if isinstance(self, roi_items.HorizontalLineROI):
axis = 1
elif isinstance(self, roi_items.VerticalLineROI):
@@ -920,21 +925,25 @@ class _DefaultScatterProfileSliceRoiMixIn(core.ProfileRoiMixIn):
# ROI outside of the scatter bound
return None
- major_order = item.getCurrentVisualizationParameter(items.Scatter.VisualizationParameter.GRID_MAJOR_ORDER)
- assert major_order == 'row'
- max_grid_yy, max_grid_xx = item.getCurrentVisualizationParameter(items.Scatter.VisualizationParameter.GRID_SHAPE)
+ major_order = item.getCurrentVisualizationParameter(
+ items.Scatter.VisualizationParameter.GRID_MAJOR_ORDER
+ )
+ assert major_order == "row"
+ max_grid_yy, max_grid_xx = item.getCurrentVisualizationParameter(
+ items.Scatter.VisualizationParameter.GRID_SHAPE
+ )
xx, yy, _values, _xx_error, _yy_error = item.getData(copy=False)
if isinstance(self, roi_items.HorizontalLineROI):
axis = yy
max_grid_first = max_grid_yy
max_grid_second = max_grid_xx
- major_axis = major_order == 'column'
+ major_axis = major_order == "column"
elif isinstance(self, roi_items.VerticalLineROI):
axis = xx
max_grid_first = max_grid_xx
max_grid_second = max_grid_yy
- major_axis = major_order == 'row'
+ major_axis = major_order == "row"
else:
assert False
@@ -946,7 +955,7 @@ class _DefaultScatterProfileSliceRoiMixIn(core.ProfileRoiMixIn):
# slice in the middle of the scatter
actual_size_grid_second = len(axis) // max_grid_first
start = actual_size_grid_second // 2 * max_grid_first
- vslice = axis[start:start + max_grid_first]
+ vslice = axis[start : start + max_grid_first]
if len(vslice) == 0:
return None
index = argnearest(vslice, position)
@@ -954,7 +963,7 @@ class _DefaultScatterProfileSliceRoiMixIn(core.ProfileRoiMixIn):
else:
# slice in the middle of the scatter
actual_size_grid_second = len(axis) // max_grid_first
- vslice = axis[actual_size_grid_second // 2::max_grid_second]
+ vslice = axis[actual_size_grid_second // 2 :: max_grid_second]
if len(vslice) == 0:
return None
index = argnearest(vslice, position)
@@ -999,28 +1008,30 @@ class _DefaultScatterProfileSliceRoiMixIn(core.ProfileRoiMixIn):
return data
-class ProfileScatterHorizontalSliceROI(roi_items.HorizontalLineROI,
- _DefaultScatterProfileSliceRoiMixIn):
+class ProfileScatterHorizontalSliceROI(
+ roi_items.HorizontalLineROI, _DefaultScatterProfileSliceRoiMixIn
+):
"""ROI for an horizontal profile at a location of a scatter
using data slicing.
"""
- ICON = 'slice-horizontal'
- NAME = 'horizontal data slice profile'
+ ICON = "slice-horizontal"
+ NAME = "horizontal data slice profile"
def __init__(self, parent=None):
roi_items.HorizontalLineROI.__init__(self, parent=parent)
_DefaultScatterProfileSliceRoiMixIn.__init__(self, parent=parent)
-class ProfileScatterVerticalSliceROI(roi_items.VerticalLineROI,
- _DefaultScatterProfileSliceRoiMixIn):
+class ProfileScatterVerticalSliceROI(
+ roi_items.VerticalLineROI, _DefaultScatterProfileSliceRoiMixIn
+):
"""ROI for a vertical profile at a location of a scatter
using data slicing.
"""
- ICON = 'slice-vertical'
- NAME = 'vertical data slice profile'
+ ICON = "slice-vertical"
+ NAME = "vertical data slice profile"
def __init__(self, parent=None):
roi_items.VerticalLineROI.__init__(self, parent=parent)
@@ -1028,11 +1039,10 @@ class ProfileScatterVerticalSliceROI(roi_items.VerticalLineROI,
class ProfileScatterCrossSliceROI(_ProfileCrossROI):
- """ROI to manage a cross of slicing profiles on scatters.
- """
+ """ROI to manage a cross of slicing profiles on scatters."""
- ICON = 'slice-cross'
- NAME = 'cross data slice profile'
+ ICON = "slice-cross"
+ NAME = "cross data slice profile"
ITEM_KIND = items.Scatter
def _createLines(self, parent):
@@ -1042,7 +1052,6 @@ class ProfileScatterCrossSliceROI(_ProfileCrossROI):
class _DefaultImageStackProfileRoiMixIn(_DefaultImageProfileRoiMixIn):
-
ITEM_KIND = items.ImageStack
def __init__(self, parent=None):
@@ -1073,22 +1082,24 @@ class _DefaultImageStackProfileRoiMixIn(_DefaultImageProfileRoiMixIn):
assert kind == "2D"
+ currentData = numpy.array(item.getStackData(copy=False))
+ origin = item.getOrigin()
+ scale = item.getScale()
+ colormap = item.getColormap()
+ method = self.getProfileMethod()
+ roiInfo = self._getRoiInfo()
+
def createProfile2(currentData):
coords, profile, _area, profileName, xLabel = core.createProfile(
- roiInfo=self._getRoiInfo(),
+ roiInfo=roiInfo,
currentData=currentData,
origin=origin,
scale=scale,
lineWidth=self.getProfileLineWidth(),
- method=method)
+ method=method,
+ )
return coords, profile, profileName, xLabel
- currentData = numpy.array(item.getStackData(copy=False))
- origin = item.getOrigin()
- scale = item.getScale()
- colormap = item.getColormap()
- method = self.getProfileMethod()
-
coords, profile, profileName, xLabel = createProfile2(currentData)
profileManager = self.getProfileManager()
@@ -1105,36 +1116,37 @@ class _DefaultImageStackProfileRoiMixIn(_DefaultImageProfileRoiMixIn):
return data
-class ProfileImageStackHorizontalLineROI(roi_items.HorizontalLineROI,
- _DefaultImageStackProfileRoiMixIn):
+class ProfileImageStackHorizontalLineROI(
+ roi_items.HorizontalLineROI, _DefaultImageStackProfileRoiMixIn
+):
"""ROI for an horizontal profile at a location of a stack of images"""
- ICON = 'shape-horizontal'
- NAME = 'horizontal line profile'
+ ICON = "shape-horizontal"
+ NAME = "horizontal line profile"
def __init__(self, parent=None):
roi_items.HorizontalLineROI.__init__(self, parent=parent)
_DefaultImageStackProfileRoiMixIn.__init__(self, parent=parent)
-class ProfileImageStackVerticalLineROI(roi_items.VerticalLineROI,
- _DefaultImageStackProfileRoiMixIn):
+class ProfileImageStackVerticalLineROI(
+ roi_items.VerticalLineROI, _DefaultImageStackProfileRoiMixIn
+):
"""ROI for an vertical profile at a location of a stack of images"""
- ICON = 'shape-vertical'
- NAME = 'vertical line profile'
+ ICON = "shape-vertical"
+ NAME = "vertical line profile"
def __init__(self, parent=None):
roi_items.VerticalLineROI.__init__(self, parent=parent)
_DefaultImageStackProfileRoiMixIn.__init__(self, parent=parent)
-class ProfileImageStackLineROI(roi_items.LineROI,
- _DefaultImageStackProfileRoiMixIn):
+class ProfileImageStackLineROI(roi_items.LineROI, _DefaultImageStackProfileRoiMixIn):
"""ROI for an vertical profile at a location of a stack of images"""
- ICON = 'shape-diagonal'
- NAME = 'line profile'
+ ICON = "shape-diagonal"
+ NAME = "line profile"
def __init__(self, parent=None):
roi_items.LineROI.__init__(self, parent=parent)
@@ -1144,8 +1156,8 @@ class ProfileImageStackLineROI(roi_items.LineROI,
class ProfileImageStackCrossROI(ProfileImageCrossROI):
"""ROI for an vertical profile at a location of a stack of images"""
- ICON = 'shape-cross'
- NAME = 'cross profile'
+ ICON = "shape-cross"
+ NAME = "cross profile"
ITEM_KIND = items.ImageStack
def _createLines(self, parent):
diff --git a/src/silx/gui/plot/tools/profile/toolbar.py b/src/silx/gui/plot/tools/profile/toolbar.py
index 12a734a..d073717 100644
--- a/src/silx/gui/plot/tools/profile/toolbar.py
+++ b/src/silx/gui/plot/tools/profile/toolbar.py
@@ -44,10 +44,11 @@ _logger = logging.getLogger(__name__)
class ProfileToolBar(qt.QToolBar):
"""Tool bar to provide profile for a plot.
-
+
It is an helper class. For a dedicated application it would be better to
use an own tool bar in order in order have more flexibility.
"""
+
def __init__(self, parent=None, plot=None):
super(ProfileToolBar, self).__init__(parent=parent)
self.__scheme = None
diff --git a/src/silx/gui/plot/tools/roi.py b/src/silx/gui/plot/tools/roi.py
index 1da692c..21b9409 100644
--- a/src/silx/gui/plot/tools/roi.py
+++ b/src/silx/gui/plot/tools/roi.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2018-2022 European Synchrotron Radiation Facility
+# Copyright (c) 2018-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -34,6 +34,7 @@ import logging
import time
import weakref
import functools
+from typing import Optional
import numpy
@@ -42,6 +43,8 @@ from ...utils import blockSignals
from ...utils import LockReentrant
from .. import PlotWidget
from ..items import roi as roi_items
+from ..items import ItemChangedType
+from ..items.roi import RegionOfInterest
from ...colors import rgba
@@ -87,7 +90,7 @@ class CreateRoiModeAction(qt.QAction):
iconName = "add-shape-unknown"
if name is None:
name = roiClass.__name__
- text = 'Add %s' % name
+ text = "Add %s" % name
self.setIcon(icons.getQIcon(iconName))
self.setText(text)
self.setCheckable(True)
@@ -144,7 +147,9 @@ class CreateRoiModeAction(qt.QAction):
if roiManager is not None:
roiManager.sigInteractiveRoiCreated.disconnect(self.initRoi)
roiManager.sigInteractiveRoiFinalized.disconnect(self.__finalizeRoi)
- roiManager.sigInteractiveModeFinished.disconnect(self.__interactiveModeFinished)
+ roiManager.sigInteractiveModeFinished.disconnect(
+ self.__interactiveModeFinished
+ )
self.setChecked(False)
def initRoi(self, roi):
@@ -391,7 +396,8 @@ class RegionOfInterestManager(qt.QObject):
self._roiClass = None
self._source = None
- self._color = rgba('red')
+ self._lastHoveredMarkerLabel = None
+ self._color = rgba("red")
self._label = "__RegionOfInterestManager__%d" % id(self)
@@ -404,8 +410,7 @@ class RegionOfInterestManager(qt.QObject):
parent.sigPlotSignal.connect(self._plotSignals)
- parent.sigInteractiveModeChanged.connect(
- self._plotInteractiveModeChanged)
+ parent.sigInteractiveModeChanged.connect(self._plotInteractiveModeChanged)
parent.sigItemRemoved.connect(self._itemRemoved)
@@ -432,7 +437,7 @@ class RegionOfInterestManager(qt.QObject):
:raise ValueError: If kind is not supported
"""
if not issubclass(roiClass, roi_items.RegionOfInterest):
- raise ValueError('Unsupported ROI class %s' % roiClass)
+ raise ValueError("Unsupported ROI class %s" % roiClass)
action = self._modeActions.get(roiClass, None)
if action is None: # Lazy-loading
@@ -476,19 +481,21 @@ class RegionOfInterestManager(qt.QObject):
return # Should not happen
kind = roiClass.getFirstInteractionShape()
- if kind == 'point':
- if event['event'] == 'mouseClicked' and event['button'] == 'left':
- points = numpy.array([(event['x'], event['y'])],
- dtype=numpy.float64)
+ if kind == "point":
+ if event["event"] == "mouseClicked" and event["button"] == "left":
+ points = numpy.array([(event["x"], event["y"])], dtype=numpy.float64)
# Not an interactive creation
roi = self._createInteractiveRoi(roiClass, points=points)
roi.creationFinalized()
self.sigInteractiveRoiFinalized.emit(roi)
else: # other shapes
- if (event['event'] in ('drawingProgress', 'drawingFinished') and
- event['parameters']['label'] == self._label):
- points = numpy.array((event['xdata'], event['ydata']),
- dtype=numpy.float64).T
+ if (
+ event["event"] in ("drawingProgress", "drawingFinished")
+ and event["parameters"]["label"] == self._label
+ ):
+ points = numpy.array(
+ (event["xdata"], event["ydata"]), dtype=numpy.float64
+ ).T
if self._drawnROI is None: # Create new ROI
# NOTE: Set something before createRoi, so isDrawing is True
@@ -497,8 +504,8 @@ class RegionOfInterestManager(qt.QObject):
else:
self._drawnROI.setFirstShapePoints(points)
- if event['event'] == 'drawingFinished':
- if kind == 'polygon' and len(points) > 1:
+ if event["event"] == "drawingFinished":
+ if kind == "polygon" and len(points) > 1:
self._drawnROI.setFirstShapePoints(points[:-1])
roi = self._drawnROI
self._drawnROI = None # Stop drawing
@@ -521,7 +528,7 @@ class RegionOfInterestManager(qt.QObject):
return roi
return None
- def setCurrentRoi(self, roi):
+ def setCurrentRoi(self, roi: Optional[RegionOfInterest]):
"""Set the currently selected ROI, and emit a signal.
:param Union[RegionOfInterest,None] roi: The ROI to select
@@ -545,11 +552,8 @@ class RegionOfInterestManager(qt.QObject):
self._currentRoi.setHighlighted(True)
self.sigCurrentRoiChanged.emit(roi)
- def getCurrentRoi(self):
- """Returns the currently selected ROI, else None.
-
- :rtype: Union[RegionOfInterest,None]
- """
+ def getCurrentRoi(self) -> Optional[RegionOfInterest]:
+ """Returns the currently selected ROI, else None."""
return self._currentRoi
def _plotSignals(self, event):
@@ -568,6 +572,8 @@ class RegionOfInterestManager(qt.QObject):
plot = self.parent()
marker = plot._getMarkerAt(event["xpixel"], event["ypixel"])
roi = self.__getRoiFromMarker(marker)
+ elif event["event"] == "hover":
+ self._lastHoveredMarkerLabel = event["label"]
else:
return
@@ -585,7 +591,7 @@ class RegionOfInterestManager(qt.QObject):
else:
self.setCurrentRoi(None)
- def __updateMode(self, roi):
+ def __updateMode(self, roi: RegionOfInterest):
if isinstance(roi, roi_items.InteractionModeMixIn):
available = roi.availableInteractionModes()
mode = roi.getInteractionMode()
@@ -593,46 +599,50 @@ class RegionOfInterestManager(qt.QObject):
mode = available[(imode + 1) % len(available)]
roi.setInteractionMode(mode)
- def _feedContextMenu(self, menu):
+ def _feedContextMenu(self, menu: qt.QMenu):
"""Called when the default plot context menu is about to be displayed"""
roi = self.getCurrentRoi()
if roi is not None:
if roi.isEditable():
- # Filter by data position
- # FIXME: It would be better to use GUI coords for it
- plot = self.parent()
- pos = plot.getWidgetHandle().mapFromGlobal(qt.QCursor.pos())
- data = plot.pixelToData(pos.x(), pos.y())
- if roi.contains(data):
- if isinstance(roi, roi_items.InteractionModeMixIn):
- self._contextMenuForInteractionMode(menu, roi)
-
- removeAction = qt.QAction(menu)
- removeAction.setText("Remove %s" % roi.getName())
- callback = functools.partial(self.removeRoi, roi)
- removeAction.triggered.connect(callback)
- menu.addAction(removeAction)
-
- def _contextMenuForInteractionMode(self, menu, roi):
- availableModes = roi.availableInteractionModes()
- currentMode = roi.getInteractionMode()
- submenu = qt.QMenu(menu)
- modeGroup = qt.QActionGroup(menu)
- modeGroup.setExclusive(True)
- for mode in availableModes:
- action = qt.QAction(menu)
- action.setText(mode.label)
- action.setToolTip(mode.description)
- action.setCheckable(True)
- if mode is currentMode:
- action.setChecked(True)
- else:
- callback = functools.partial(roi.setInteractionMode, mode)
- action.triggered.connect(callback)
- modeGroup.addAction(action)
- submenu.addAction(action)
- submenu.setTitle("%s interaction mode" % roi.getName())
- menu.addMenu(submenu)
+ if self._isMouseHoverRoi(roi):
+ roiMenu = self._createMenuForRoi(menu, roi)
+ menu.addMenu(roiMenu)
+
+ def _isMouseHoverRoi(self, roi: RegionOfInterest) -> bool:
+ """Check that the mouse hovers this roi"""
+ plot = self.parent()
+
+ if self._lastHoveredMarkerLabel is not None:
+ marker = plot._getMarker(self._lastHoveredMarkerLabel)
+ if marker is not None:
+ r = self.__getRoiFromMarker(marker)
+ if roi is r:
+ return True
+
+ # Filter by data position
+ # FIXME: It would be better to use GUI coords for it
+ pos = plot.getWidgetHandle().mapFromGlobal(qt.QCursor.pos())
+ data = plot.pixelToData(pos.x(), pos.y())
+ return roi.contains(data)
+
+ def _createMenuForRoi(self, parent: qt.QWidget, roi: RegionOfInterest) -> qt.QMenu:
+ """Create a QMenu for the given RegionOfInterest"""
+ roiMenu = qt.QMenu(parent)
+ roiMenu.setTitle(roi.getName())
+
+ if isinstance(roi, roi_items.InteractionModeMixIn):
+ interactionMenu = roi.createMenuForInteractionMode(roiMenu)
+ roiMenu.addMenu(interactionMenu)
+
+ removeAction = qt.QAction(roiMenu)
+ removeAction.setText("Remove")
+ callback = functools.partial(self.removeRoi, roi)
+ removeAction.triggered.connect(callback)
+ roiMenu.addAction(removeAction)
+
+ roi.populateContextMenu(roiMenu)
+
+ return roiMenu
# RegionOfInterest API
@@ -654,8 +664,7 @@ class RegionOfInterestManager(qt.QObject):
"""
if self.getRois(): # Something to reset
for roi in self._rois:
- roi.sigRegionChanged.disconnect(
- self._regionOfInterestChanged)
+ roi.sigRegionChanged.disconnect(self._regionOfInterestChanged)
roi.setParent(None)
self._rois = []
self._roisUpdated()
@@ -715,8 +724,7 @@ class RegionOfInterestManager(qt.QObject):
"""
plot = self.parent()
if plot is None:
- raise RuntimeError(
- 'Cannot add ROI: PlotWidget no more available')
+ raise RuntimeError("Cannot add ROI: PlotWidget no more available")
roi.setParent(self)
@@ -739,11 +747,12 @@ class RegionOfInterestManager(qt.QObject):
:param roi_items.RegionOfInterest roi: The ROI to remove
:raise ValueError: When ROI does not belong to this object
"""
- if not (isinstance(roi, roi_items.RegionOfInterest) and
- roi.parent() is self and
- roi in self._rois):
- raise ValueError(
- 'RegionOfInterest does not belong to this instance')
+ if not (
+ isinstance(roi, roi_items.RegionOfInterest)
+ and roi.parent() is self
+ and roi in self._rois
+ ):
+ raise ValueError("RegionOfInterest does not belong to this instance")
roi.sigAboutToBeRemoved.emit()
self.sigRoiAboutToBeRemoved.emit(roi)
@@ -834,7 +843,7 @@ class RegionOfInterestManager(qt.QObject):
self.stop()
if not issubclass(roiClass, roi_items.RegionOfInterest):
- raise ValueError('Unsupported ROI class %s' % roiClass)
+ raise ValueError("Unsupported ROI class %s" % roiClass)
plot = self.parent()
if plot is None:
@@ -859,18 +868,20 @@ class RegionOfInterestManager(qt.QObject):
plot = self.parent()
firstInteractionShapeKind = roiClass.getFirstInteractionShape()
- if firstInteractionShapeKind == 'point':
- plot.setInteractiveMode(mode='select', source=self)
+ if firstInteractionShapeKind == "point":
+ plot.setInteractiveMode(mode="select", source=self)
else:
if roiClass.showFirstInteractionShape():
color = rgba(self.getColor())
else:
color = None
- plot.setInteractiveMode(mode='select-draw',
- source=self,
- shape=firstInteractionShapeKind,
- color=color,
- label=self._label)
+ plot.setInteractiveMode(
+ mode="draw",
+ source=self,
+ shape=firstInteractionShapeKind,
+ color=color,
+ label=self._label,
+ )
def __roiInteractiveModeEnded(self):
"""Handle end of ROI draw interactive mode"""
@@ -964,7 +975,7 @@ class InteractiveRegionOfInterestManager(RegionOfInterestManager):
super(InteractiveRegionOfInterestManager, self).__init__(parent)
self._maxROI = None
self.__timeoutEndTime = None
- self.__message = ''
+ self.__message = ""
self.__validationMode = self.ValidationMode.ENTER
self.__execClass = None
@@ -991,11 +1002,10 @@ class InteractiveRegionOfInterestManager(RegionOfInterestManager):
if max_ is not None:
max_ = int(max_)
if max_ <= 0:
- raise ValueError('Max limit must be strictly positive')
+ raise ValueError("Max limit must be strictly positive")
if len(self.getRois()) > max_:
- raise ValueError(
- 'Cannot set max limit: Already too many ROIs')
+ raise ValueError("Cannot set max limit: Already too many ROIs")
self._maxROI = max_
@@ -1013,19 +1023,19 @@ class InteractiveRegionOfInterestManager(RegionOfInterestManager):
class ValidationMode(enum.Enum):
"""Mode of validation to leave blocking :meth:`exec`"""
- AUTO = 'auto'
+ AUTO = "auto"
"""Automatically ends the interactive mode once
the user terminates the last ROI shape."""
- ENTER = 'enter'
+ ENTER = "enter"
"""Ends the interactive mode when the *Enter* key is pressed."""
- AUTO_ENTER = 'auto_enter'
+ AUTO_ENTER = "auto_enter"
"""Ends the interactive mode when reaching max ROIs or
when the *Enter* key is pressed.
"""
- NONE = 'none'
+ NONE = "none"
"""Do not provide the user a way to end the interactive mode.
The end of :meth:`exec` is done through :meth:`quit` or timeout.
@@ -1051,9 +1061,10 @@ class InteractiveRegionOfInterestManager(RegionOfInterestManager):
self.__validationMode = mode
if self.isExec():
- if (self.isMaxRois() and self.getValidationMode() in
- (self.ValidationMode.AUTO,
- self.ValidationMode.AUTO_ENTER)):
+ if self.isMaxRois() and self.getValidationMode() in (
+ self.ValidationMode.AUTO,
+ self.ValidationMode.AUTO_ENTER,
+ ):
self.quit()
self.__updateMessage()
@@ -1064,17 +1075,20 @@ class InteractiveRegionOfInterestManager(RegionOfInterestManager):
if event.type() == qt.QEvent.KeyPress:
key = event.key()
- if (key in (qt.Qt.Key_Return, qt.Qt.Key_Enter) and
- self.getValidationMode() in (
- self.ValidationMode.ENTER,
- self.ValidationMode.AUTO_ENTER)):
+ if key in (
+ qt.Qt.Key_Return,
+ qt.Qt.Key_Enter,
+ ) and self.getValidationMode() in (
+ self.ValidationMode.ENTER,
+ self.ValidationMode.AUTO_ENTER,
+ ):
# Stop on return key pressed
self.quit()
return True # Stop further handling of this keys
- if (key in (qt.Qt.Key_Delete, qt.Qt.Key_Backspace) or (
- key == qt.Qt.Key_Z and
- event.modifiers() & qt.Qt.ControlModifier)):
+ if key in (qt.Qt.Key_Delete, qt.Qt.Key_Backspace) or (
+ key == qt.Qt.Key_Z and event.modifiers() & qt.Qt.ControlModifier
+ ):
rois = self.getRois()
if rois: # Something to undo
self.removeRoi(rois[-1])
@@ -1096,8 +1110,7 @@ class InteractiveRegionOfInterestManager(RegionOfInterestManager):
return self.__message
else:
remaining = self.__timeoutEndTime - time.time()
- return self.__message + (' - %d seconds remaining' %
- max(1, int(remaining)))
+ return self.__message + (" - %d seconds remaining" % max(1, int(remaining)))
# Listen to ROI updates
@@ -1110,9 +1123,10 @@ class InteractiveRegionOfInterestManager(RegionOfInterestManager):
self.removeRoi(self.getRois()[-2])
self.__updateMessage()
- if (self.isMaxRois() and
- self.getValidationMode() in (self.ValidationMode.AUTO,
- self.ValidationMode.AUTO_ENTER)):
+ if self.isMaxRois() and self.getValidationMode() in (
+ self.ValidationMode.AUTO,
+ self.ValidationMode.AUTO_ENTER,
+ ):
self.quit()
def __aboutToBeRemoved(self, *args, **kwargs):
@@ -1131,10 +1145,10 @@ class InteractiveRegionOfInterestManager(RegionOfInterestManager):
def __updateMessage(self, nbrois=None):
"""Update message"""
if not self.isExec():
- message = 'Done'
+ message = "Done"
elif not self.isStarted():
- message = 'Use %s ROI edition mode' % self.__execClass
+ message = "Use %s ROI edition mode" % self.__execClass
else:
if nbrois is None:
@@ -1144,16 +1158,18 @@ class InteractiveRegionOfInterestManager(RegionOfInterestManager):
max_ = self.getMaxRois()
if max_ is None:
- message = 'Select %ss (%d selected)' % (name, nbrois)
+ message = "Select %ss (%d selected)" % (name, nbrois)
elif max_ <= 1:
- message = 'Select a %s' % name
+ message = "Select a %s" % name
else:
- message = 'Select %d/%d %ss' % (nbrois, max_, name)
+ message = "Select %d/%d %ss" % (nbrois, max_, name)
- if (self.getValidationMode() == self.ValidationMode.ENTER and
- self.isMaxRois()):
- message += ' - Press Enter to confirm'
+ if (
+ self.getValidationMode() == self.ValidationMode.ENTER
+ and self.isMaxRois()
+ ):
+ message += " - Press Enter to confirm"
if message != self.__message:
self.__message = message
@@ -1164,9 +1180,11 @@ class InteractiveRegionOfInterestManager(RegionOfInterestManager):
def __timeoutUpdate(self):
"""Handle update of timeout"""
- if (self.__timeoutEndTime is not None and
- (self.__timeoutEndTime - time.time()) > 0):
- self.sigMessageChanged.emit(self.getMessage())
+ if (
+ self.__timeoutEndTime is not None
+ and (self.__timeoutEndTime - time.time()) > 0
+ ):
+ self.sigMessageChanged.emit(self.getMessage())
else: # Stop interactive mode and message timer
timer = self.sender()
if timer is not None:
@@ -1234,7 +1252,7 @@ class _DeleteRegionOfInterestToolButton(qt.QToolButton):
def __init__(self, parent, roi):
super(_DeleteRegionOfInterestToolButton, self).__init__(parent)
- self.setIcon(icons.getQIcon('remove'))
+ self.setIcon(icons.getQIcon("remove"))
self.setToolTip("Remove this ROI")
self.__roiRef = roi if roi is None else weakref.ref(roi)
self.clicked.connect(self.__clicked)
@@ -1252,11 +1270,20 @@ class _DeleteRegionOfInterestToolButton(qt.QToolButton):
class RegionOfInterestTableWidget(qt.QTableWidget):
"""Widget displaying the ROIs of a :class:`RegionOfInterestManager`"""
+ # Columns indices of the different displayed information
+ (
+ _LABEL_VISIBLE_COL,
+ _EDITABLE_COL,
+ _KIND_COL,
+ _COORDINATES_COL,
+ _DELETE_COL,
+ ) = range(5)
+
def __init__(self, parent=None):
super(RegionOfInterestTableWidget, self).__init__(parent)
self._roiManagerRef = None
- headers = ['Label', 'Edit', 'Kind', 'Coordinates', '']
+ headers = ["Label", "Edit", "Kind", "Coordinates", ""]
self.setColumnCount(len(headers))
self.setHorizontalHeaderLabels(headers)
@@ -1278,21 +1305,17 @@ class RegionOfInterestTableWidget(qt.QTableWidget):
self.itemChanged.connect(self.__itemChanged)
def __itemChanged(self, item):
- """Handle item updates"""
+ """Handle QTableWidget item updates"""
column = item.column()
- index = item.data(qt.Qt.UserRole)
-
- if index is not None:
- manager = self.getRegionOfInterestManager()
- roi = manager.getRois()[index]
- else:
+ roi = item.data(qt.Qt.UserRole)
+ if roi is None:
return
if column == 0:
# First collect information from item, then update ROI
- # Otherwise, this causes issues issues
+ # Otherwise, this causes issues
checked = item.checkState() == qt.Qt.Checked
- text= item.text()
+ text = item.text()
roi.setVisible(checked)
roi.setName(text)
elif column == 1:
@@ -1300,7 +1323,7 @@ class RegionOfInterestTableWidget(qt.QTableWidget):
elif column in (2, 3, 4):
pass # TODO
else:
- logger.error('Unhandled column %d', column)
+ logger.error("Unhandled column %d", column)
def setRegionOfInterestManager(self, manager):
"""Set the :class:`RegionOfInterestManager` object to sync with
@@ -1312,7 +1335,13 @@ class RegionOfInterestTableWidget(qt.QTableWidget):
previousManager = self.getRegionOfInterestManager()
if previousManager is not None:
- previousManager.sigRoiChanged.disconnect(self._sync)
+ previousManager.sigRoiAdded.disconnect(self.__roiAdded)
+ previousManager.sigRoiAboutToBeRemoved.disconnect(
+ self.__roiAboutToBeRemoved
+ )
+ for roi in previousManager.getRois():
+ self.__disconnectRoi(roi)
+
self.setRowCount(0)
self._roiManagerRef = weakref.ref(manager)
@@ -1320,7 +1349,10 @@ class RegionOfInterestTableWidget(qt.QTableWidget):
self._sync()
if manager is not None:
- manager.sigRoiChanged.connect(self._sync)
+ for roi in manager.getRois():
+ self.__connectRoi(roi)
+ manager.sigRoiAdded.connect(self.__roiAdded)
+ manager.sigRoiAboutToBeRemoved.connect(self.__roiAboutToBeRemoved)
def _getReadableRoiDescription(self, roi):
"""Returns modelisation of a ROI as a readable sequence of values.
@@ -1345,6 +1377,75 @@ class RegionOfInterestTableWidget(qt.QTableWidget):
logger.debug("Backtrace", exc_info=True)
return text
+ def __connectRoi(self, roi: RegionOfInterest):
+ """Start listening ROI signals"""
+ roi.sigItemChanged.connect(self.__roiItemChanged)
+ roi.sigRegionChanged.connect(self.__roiRegionChanged)
+
+ def __disconnectRoi(self, roi: RegionOfInterest):
+ """Stop listening ROI signals"""
+ roi.sigItemChanged.disconnect(self.__roiItemChanged)
+ roi.sigRegionChanged.disconnect(self.__roiRegionChanged)
+
+ def __getRoiRow(self, roi: RegionOfInterest) -> int:
+ """Returns row index of given region of interest
+
+ :raises ValueError: If region of interest is not in the list
+ """
+ manager = self.getRegionOfInterestManager()
+ if manager is None:
+ return
+ return manager.getRois().index(roi)
+
+ def __roiAdded(self, roi: RegionOfInterest):
+ """Handle new ROI added to the manager"""
+ self.__connectRoi(roi)
+ self._sync()
+
+ def __roiAboutToBeRemoved(self, roi: RegionOfInterest):
+ """Handle removing a ROI from the manager"""
+ self.__disconnectRoi(roi)
+ self.removeRow(self.__getRoiRow(roi))
+
+ def __roiItemChanged(self, event: ItemChangedType):
+ """Handle ROI sigItemChanged events"""
+ roi = self.sender()
+ if roi is None:
+ return
+
+ try:
+ row = self.__getRoiRow(roi)
+ except ValueError:
+ return
+
+ if event == ItemChangedType.VISIBLE:
+ item = self.item(row, self._LABEL_VISIBLE_COL)
+ item.setCheckState(qt.Qt.Checked if roi.isVisible() else qt.Qt.Unchecked)
+ return
+
+ if event == ItemChangedType.NAME:
+ item = self.item(row, self._LABEL_VISIBLE_COL)
+ item.setText(roi.getName())
+ return
+
+ if event == ItemChangedType.EDITABLE:
+ item = self.item(row, self._EDITABLE_COL)
+ item.setCheckState(qt.Qt.Checked if roi.isEditable() else qt.Qt.Unchecked)
+ return
+
+ def __roiRegionChanged(self):
+ """Handle change of ROI coordinates"""
+ roi = self.sender()
+ if roi is None:
+ return
+
+ item = self.item(self.__getRoiRow(roi), self._COORDINATES_COL)
+ if item is None:
+ return
+
+ text = self._getReadableRoiDescription(roi)
+ item.setText(text)
+
def _sync(self):
"""Update widget content according to ROI manger"""
manager = self.getRegionOfInterestManager()
@@ -1360,21 +1461,19 @@ class RegionOfInterestTableWidget(qt.QTableWidget):
baseFlags = qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled
# Label and visible
- label = roi.getName()
- item = qt.QTableWidgetItem(label)
+ item = qt.QTableWidgetItem()
item.setFlags(baseFlags | qt.Qt.ItemIsEditable | qt.Qt.ItemIsUserCheckable)
- item.setData(qt.Qt.UserRole, index)
- item.setCheckState(
- qt.Qt.Checked if roi.isVisible() else qt.Qt.Unchecked)
- self.setItem(index, 0, item)
+ item.setData(qt.Qt.UserRole, roi)
+ item.setText(roi.getName())
+ item.setCheckState(qt.Qt.Checked if roi.isVisible() else qt.Qt.Unchecked)
+ self.setItem(index, self._LABEL_VISIBLE_COL, item)
# Editable
item = qt.QTableWidgetItem()
item.setFlags(baseFlags | qt.Qt.ItemIsUserCheckable)
- item.setData(qt.Qt.UserRole, index)
- item.setCheckState(
- qt.Qt.Checked if roi.isEditable() else qt.Qt.Unchecked)
- self.setItem(index, 1, item)
+ item.setData(qt.Qt.UserRole, roi)
+ item.setCheckState(qt.Qt.Checked if roi.isEditable() else qt.Qt.Unchecked)
+ self.setItem(index, self._EDITABLE_COL, item)
item.setTextAlignment(qt.Qt.AlignCenter)
item.setText(None)
@@ -1385,19 +1484,18 @@ class RegionOfInterestTableWidget(qt.QTableWidget):
label = roi.__class__.__name__
item = qt.QTableWidgetItem(label.capitalize())
item.setFlags(baseFlags)
- self.setItem(index, 2, item)
+ self.setItem(index, self._KIND_COL, item)
+ # Coordinates
item = qt.QTableWidgetItem()
item.setFlags(baseFlags)
-
- # Coordinates
text = self._getReadableRoiDescription(roi)
item.setText(text)
- self.setItem(index, 3, item)
+ self.setItem(index, self._COORDINATES_COL, item)
# Delete
- delBtn = _DeleteRegionOfInterestToolButton(None, roi)
widget = qt.QWidget(self)
+ delBtn = _DeleteRegionOfInterestToolButton(widget, roi)
layout = qt.QHBoxLayout()
layout.setContentsMargins(2, 2, 2, 2)
layout.setSpacing(0)
@@ -1405,7 +1503,7 @@ class RegionOfInterestTableWidget(qt.QTableWidget):
layout.addStretch(1)
layout.addWidget(delBtn)
layout.addStretch(1)
- self.setCellWidget(index, 4, widget)
+ self.setCellWidget(index, self._DELETE_COL, widget)
def getRegionOfInterestManager(self):
"""Returns the :class:`RegionOfInterestManager` this widget supervise.
diff --git a/src/silx/gui/plot/tools/test/testCurveLegendsWidget.py b/src/silx/gui/plot/tools/test/testCurveLegendsWidget.py
index 657d328..9f1a184 100644
--- a/src/silx/gui/plot/tools/test/testCurveLegendsWidget.py
+++ b/src/silx/gui/plot/tools/test/testCurveLegendsWidget.py
@@ -26,8 +26,6 @@ __license__ = "MIT"
__date__ = "02/08/2018"
-import unittest
-
from silx.gui import qt
from silx.utils.testutils import ParametricTestCase
from silx.gui.utils.testutils import TestCaseQt
@@ -46,7 +44,7 @@ class TestCurveLegendsWidget(TestCaseQt, ParametricTestCase):
self.legends.setPlotWidget(self.plot)
dock = qt.QDockWidget()
- dock.setWindowTitle('Curve Legends')
+ dock.setWindowTitle("Curve Legends")
dock.setWidget(self.legends)
self.plot.addTabbedDockWidget(dock)
@@ -68,9 +66,9 @@ class TestCurveLegendsWidget(TestCaseQt, ParametricTestCase):
def testAddRemoveCurves(self):
"""Test CurveLegendsWidget while adding/removing curves"""
- self.plot.addCurve((0, 1), (1, 2), legend='a')
+ self.plot.addCurve((0, 1), (1, 2), legend="a")
self._assertNbLegends(1)
- self.plot.addCurve((0, 1), (2, 3), legend='b')
+ self.plot.addCurve((0, 1), (2, 3), legend="b")
self._assertNbLegends(2)
# Detached/attach
@@ -84,28 +82,35 @@ class TestCurveLegendsWidget(TestCaseQt, ParametricTestCase):
self._assertNbLegends(0)
def testUpdateCurves(self):
- """Test CurveLegendsWidget while updating curves """
- self.plot.addCurve((0, 1), (1, 2), legend='a')
+ """Test CurveLegendsWidget while updating curves"""
+ self.plot.addCurve((0, 1), (1, 2), legend="a")
self._assertNbLegends(1)
- self.plot.addCurve((0, 1), (2, 3), legend='b')
+ self.plot.addCurve((0, 1), (2, 3), legend="b")
self._assertNbLegends(2)
# Activate curve
- self.plot.setActiveCurve('a')
+ self.plot.setActiveCurve("a")
self.qapp.processEvents()
- self.plot.setActiveCurve('b')
+ self.plot.setActiveCurve("b")
self.qapp.processEvents()
# Change curve style
- curve = self.plot.getCurve('a')
+ curve = self.plot.getCurve("a")
curve.setLineWidth(2)
- for linestyle in (':', '', '--', '-'):
+ for linestyle in (
+ ":",
+ "",
+ "--",
+ "-",
+ (0.0, (5.0, 5.0)),
+ (5.0, (10.0, 2.0, 2.0, 5.0)),
+ ):
with self.subTest(linestyle=linestyle):
curve.setLineStyle(linestyle)
self.qapp.processEvents()
self.qWait(1000)
- for symbol in ('o', 'd', '', 's'):
+ for symbol in ("o", "d", "", "s"):
with self.subTest(symbol=symbol):
curve.setSymbol(symbol)
self.qapp.processEvents()
diff --git a/src/silx/gui/plot/tools/test/testProfile.py b/src/silx/gui/plot/tools/test/testProfile.py
index ad40e67..61b95a6 100644
--- a/src/silx/gui/plot/tools/test/testProfile.py
+++ b/src/silx/gui/plot/tools/test/testProfile.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2018-2021 European Synchrotron Radiation Facility
+# Copyright (c) 2018-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -26,14 +26,11 @@ __license__ = "MIT"
__date__ = "28/06/2018"
-import unittest
import contextlib
import numpy
import logging
from silx.gui import qt
-from silx.utils import deprecation
-from silx.utils import testutils
from silx.gui.utils.testutils import TestCaseQt
from silx.utils.testutils import ParametricTestCase
@@ -49,7 +46,6 @@ _logger = logging.getLogger(__name__)
class TestRois(TestCaseQt):
-
def test_init(self):
"""Check that the constructor is not called twice"""
roi = rois.ProfileImageVerticalLineROI()
@@ -59,7 +55,6 @@ class TestRois(TestCaseQt):
class TestInteractions(TestCaseQt):
-
@contextlib.contextmanager
def defaultPlot(self):
try:
@@ -168,7 +163,7 @@ class TestInteractions(TestCaseQt):
self.assertEqual(len(profileRois), 3)
else:
self.assertEqual(len(profileRois), 1)
- # The first one should be the expected one
+ # The first one should be the expected one
roi = profileRois[0]
# Test that something was displayed
@@ -227,14 +222,14 @@ class TestInteractions(TestCaseQt):
if isinstance(editor, editors._NoProfileRoiEditor):
pass
elif isinstance(editor, editors._DefaultImageStackProfileRoiEditor):
- # GUI to ROI
+ # GUI to ROI
editor._lineWidth.setValue(2)
self.assertEqual(roi.getProfileLineWidth(), 2)
editor._methodsButton.setMethod("sum")
self.assertEqual(roi.getProfileMethod(), "sum")
editor._profileDim.setDimension(1)
self.assertEqual(roi.getProfileType(), "1D")
- # ROI to GUI
+ # ROI to GUI
roi.setProfileLineWidth(3)
self.assertEqual(editor._lineWidth.value(), 3)
roi.setProfileMethod("mean")
@@ -242,21 +237,21 @@ class TestInteractions(TestCaseQt):
roi.setProfileType("2D")
self.assertEqual(editor._profileDim.getDimension(), 2)
elif isinstance(editor, editors._DefaultImageProfileRoiEditor):
- # GUI to ROI
+ # GUI to ROI
editor._lineWidth.setValue(2)
self.assertEqual(roi.getProfileLineWidth(), 2)
editor._methodsButton.setMethod("sum")
self.assertEqual(roi.getProfileMethod(), "sum")
- # ROI to GUI
+ # ROI to GUI
roi.setProfileLineWidth(3)
self.assertEqual(editor._lineWidth.value(), 3)
roi.setProfileMethod("mean")
self.assertEqual(editor._methodsButton.getMethod(), "mean")
elif isinstance(editor, editors._DefaultScatterProfileRoiEditor):
- # GUI to ROI
+ # GUI to ROI
editor._nPoints.setValue(100)
self.assertEqual(roi.getNPoints(), 100)
- # ROI to GUI
+ # ROI to GUI
roi.setNPoints(200)
self.assertEqual(editor._nPoints.value(), 200)
else:
@@ -268,17 +263,32 @@ class TestInteractions(TestCaseQt):
(rois.ProfileImageVerticalLineROI, editors._DefaultImageProfileRoiEditor),
(rois.ProfileImageLineROI, editors._DefaultImageProfileRoiEditor),
(rois.ProfileImageCrossROI, editors._DefaultImageProfileRoiEditor),
- (rois.ProfileScatterHorizontalLineROI, editors._DefaultScatterProfileRoiEditor),
- (rois.ProfileScatterVerticalLineROI, editors._DefaultScatterProfileRoiEditor),
+ (
+ rois.ProfileScatterHorizontalLineROI,
+ editors._DefaultScatterProfileRoiEditor,
+ ),
+ (
+ rois.ProfileScatterVerticalLineROI,
+ editors._DefaultScatterProfileRoiEditor,
+ ),
(rois.ProfileScatterLineROI, editors._DefaultScatterProfileRoiEditor),
(rois.ProfileScatterCrossROI, editors._DefaultScatterProfileRoiEditor),
(rois.ProfileScatterHorizontalSliceROI, editors._NoProfileRoiEditor),
(rois.ProfileScatterVerticalSliceROI, editors._NoProfileRoiEditor),
(rois.ProfileScatterCrossSliceROI, editors._NoProfileRoiEditor),
- (rois.ProfileImageStackHorizontalLineROI, editors._DefaultImageStackProfileRoiEditor),
- (rois.ProfileImageStackVerticalLineROI, editors._DefaultImageStackProfileRoiEditor),
+ (
+ rois.ProfileImageStackHorizontalLineROI,
+ editors._DefaultImageStackProfileRoiEditor,
+ ),
+ (
+ rois.ProfileImageStackVerticalLineROI,
+ editors._DefaultImageStackProfileRoiEditor,
+ ),
(rois.ProfileImageStackLineROI, editors._DefaultImageStackProfileRoiEditor),
- (rois.ProfileImageStackCrossROI, editors._DefaultImageStackProfileRoiEditor),
+ (
+ rois.ProfileImageStackCrossROI,
+ editors._DefaultImageStackProfileRoiEditor,
+ ),
]
with self.defaultPlot() as plot:
profileManager = manager.ProfileManager(plot, plot)
@@ -288,7 +298,7 @@ class TestInteractions(TestCaseQt):
roi = roiClass()
roi._setProfileManager(profileManager)
try:
- # Force widget creation
+ # Force widget creation
menu = qt.QMenu(plot)
menu.addAction(editorAction)
widgets = editorAction.createdWidgets()
@@ -319,10 +329,8 @@ class TestProfileToolBar(TestCaseQt, ParametricTestCase):
self.mouseMove(self.plot) # Move to center
self.qapp.processEvents()
- deprecation.FORCE = True
def tearDown(self):
- deprecation.FORCE = False
self.qapp.processEvents()
profileManager = self.toolBar.getProfileManager()
profileManager.clearProfile()
@@ -338,7 +346,7 @@ class TestProfileToolBar(TestCaseQt, ParametricTestCase):
"""Test horizontal and vertical profile, without and with image"""
# Use Plot backend widget to submit mouse events
widget = self.plot.getWidgetHandle()
- for method in ('sum', 'mean'):
+ for method in ("sum", "mean"):
with self.subTest(method=method):
# 2 positions to use for mouse events
pos1 = widget.width() * 0.4, widget.height() * 0.4
@@ -353,8 +361,7 @@ class TestProfileToolBar(TestCaseQt, ParametricTestCase):
self.mouseClick(widget, qt.Qt.LeftButton, pos=pos1)
# with image
- self.plot.addImage(
- numpy.arange(100 * 100).reshape(100, -1))
+ self.plot.addImage(numpy.arange(100 * 100).reshape(100, -1))
self.mousePress(widget, qt.Qt.LeftButton, pos=pos1)
self.mouseMove(widget, pos=pos2)
self.mouseRelease(widget, qt.Qt.LeftButton, pos=pos2)
@@ -368,16 +375,14 @@ class TestProfileToolBar(TestCaseQt, ParametricTestCase):
if not manager.hasPendingOperations():
break
- @testutils.validate_logging(deprecation.depreclog.name, warning=4)
def testDiagonalProfile(self):
"""Test diagonal profile, without and with image"""
# Use Plot backend widget to submit mouse events
widget = self.plot.getWidgetHandle()
- self.plot.addImage(
- numpy.arange(100 * 100).reshape(100, -1))
+ self.plot.addImage(numpy.arange(100 * 100).reshape(100, -1))
- for method in ('sum', 'mean'):
+ for method in ("sum", "mean"):
with self.subTest(method=method):
# 2 positions to use for mouse events
pos1 = widget.width() * 0.4, widget.height() * 0.4
@@ -414,10 +419,12 @@ class TestProfileToolBar(TestCaseQt, ParametricTestCase):
if not manager.hasPendingOperations():
break
- curveItem = self.toolBar.getProfilePlot().getAllCurves()[0]
- if method == 'sum':
+ curveItem = (
+ roi.getProfileWindow().getCurrentPlotWidget().getAllCurves()[0]
+ )
+ if method == "sum":
self.assertTrue(curveItem.getData()[1].max() > 10000)
- elif method == 'mean':
+ elif method == "mean":
self.assertTrue(curveItem.getData()[1].max() < 10000)
# Remove the ROI so the profile window is also removed
@@ -426,77 +433,26 @@ class TestProfileToolBar(TestCaseQt, ParametricTestCase):
self.qWait(100)
-class TestDeprecatedProfileToolBar(TestCaseQt):
- """Tests old features of the ProfileToolBar widget."""
-
- def setUp(self):
- self.plot = None
- super(TestDeprecatedProfileToolBar, self).setUp()
-
- def tearDown(self):
- if self.plot is not None:
- self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.plot.close()
- self.plot = None
- self.qWait()
-
- super(TestDeprecatedProfileToolBar, self).tearDown()
-
- @testutils.validate_logging(deprecation.depreclog.name, warning=2)
- def testCustomProfileWindow(self):
- from silx.gui.plot import ProfileMainWindow
-
- self.plot = PlotWindow()
- profileWindow = ProfileMainWindow.ProfileMainWindow(self.plot)
- toolBar = Profile.ProfileToolBar(parent=self.plot,
- plot=self.plot,
- profileWindow=profileWindow)
-
- self.plot.show()
- self.qWaitForWindowExposed(self.plot)
- profileWindow.show()
- self.qWaitForWindowExposed(profileWindow)
- self.qapp.processEvents()
-
- self.plot.addImage(numpy.arange(10 * 10).reshape(10, -1))
- profile = rois.ProfileImageHorizontalLineROI()
- profile.setPosition(5)
- toolBar.getProfileManager().getRoiManager().addRoi(profile)
- toolBar.getProfileManager().getRoiManager().setCurrentRoi(profile)
-
- for _ in range(20):
- self.qWait(200)
- if not toolBar.getProfileManager().hasPendingOperations():
- break
-
- # There is a displayed profile
- self.assertIsNotNone(profileWindow.getProfile())
- self.assertIs(toolBar.getProfileMainWindow(), profileWindow)
-
- # There is nothing anymore but the window is still there
- toolBar.getProfileManager().clearProfile()
- self.qapp.processEvents()
- self.assertIsNone(profileWindow.getProfile())
-
-
class TestProfile3DToolBar(TestCaseQt):
- """Tests for Profile3DToolBar widget.
- """
+ """Tests for Profile3DToolBar widget."""
+
def setUp(self):
super(TestProfile3DToolBar, self).setUp()
self.plot = StackView()
self.plot.show()
self.qWaitForWindowExposed(self.plot)
- self.plot.setStack(numpy.array([
- [[0, 1, 2], [3, 4, 5]],
- [[6, 7, 8], [9, 10, 11]],
- [[12, 13, 14], [15, 16, 17]]
- ]))
- deprecation.FORCE = True
+ self.plot.setStack(
+ numpy.array(
+ [
+ [[0, 1, 2], [3, 4, 5]],
+ [[6, 7, 8], [9, 10, 11]],
+ [[12, 13, 14], [15, 16, 17]],
+ ]
+ )
+ )
def tearDown(self):
- deprecation.FORCE = False
profileManager = self.plot.getProfileToolbar().getProfileManager()
profileManager.clearProfile()
profileManager = None
@@ -506,7 +462,6 @@ class TestProfile3DToolBar(TestCaseQt):
super(TestProfile3DToolBar, self).tearDown()
- @testutils.validate_logging(deprecation.depreclog.name, warning=2)
def testMethodProfile2D(self):
"""Test that the profile can have a different method if we want to
compute then in 1D or in 2D"""
@@ -530,15 +485,13 @@ class TestProfile3DToolBar(TestCaseQt):
break
# check 2D 'mean' profile
- profilePlot = toolBar.getProfilePlot()
+ profilePlot = roi.getProfileWindow().getCurrentPlotWidget()
data = profilePlot.getAllImages()[0].getData()
expected = numpy.array([[1, 4], [7, 10], [13, 16]])
numpy.testing.assert_almost_equal(data, expected)
- @testutils.validate_logging(deprecation.depreclog.name, warning=2)
def testMethodSumLine(self):
- """Simple interaction test to make sure the sum is correctly computed
- """
+ """Simple interaction test to make sure the sum is correctly computed"""
toolBar = self.plot.getProfileToolbar()
toolBar.lineAction.trigger()
@@ -563,14 +516,13 @@ class TestProfile3DToolBar(TestCaseQt):
break
# check 2D 'sum' profile
- profilePlot = toolBar.getProfilePlot()
+ profilePlot = roi.getProfileWindow().getCurrentPlotWidget()
data = profilePlot.getAllImages()[0].getData()
expected = numpy.array([[3, 12], [21, 30], [39, 48]])
numpy.testing.assert_almost_equal(data, expected)
class TestGetProfilePlot(TestCaseQt):
-
def setUp(self):
self.plot = None
super(TestGetProfilePlot, self).setUp()
@@ -618,8 +570,7 @@ class TestGetProfilePlot(TestCaseQt):
self.plot.show()
self.qWaitForWindowExposed(self.plot)
- self.plot.setStack(numpy.array([[[0, 1], [2, 3]],
- [[4, 5], [6, 7]]]))
+ self.plot.setStack(numpy.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]]))
toolBar = self.plot.getProfileToolbar()
diff --git a/src/silx/gui/plot/tools/test/testROI.py b/src/silx/gui/plot/tools/test/testRoiCore.py
index 6ce1553..e7f6d8a 100644
--- a/src/silx/gui/plot/tools/test/testROI.py
+++ b/src/silx/gui/plot/tools/test/testRoiCore.py
@@ -26,7 +26,6 @@ __license__ = "MIT"
__date__ = "28/06/2018"
-import unittest
import numpy.testing
from silx.gui import qt
@@ -37,243 +36,6 @@ import silx.gui.plot.items.roi as roi_items
from silx.gui.plot.tools import roi
-class TestRoiItems(TestCaseQt):
-
- def testLine_geometry(self):
- item = roi_items.LineROI()
- startPoint = numpy.array([1, 2])
- endPoint = numpy.array([3, 4])
- item.setEndPoints(startPoint, endPoint)
- numpy.testing.assert_allclose(item.getEndPoints()[0], startPoint)
- numpy.testing.assert_allclose(item.getEndPoints()[1], endPoint)
-
- def testHLine_geometry(self):
- item = roi_items.HorizontalLineROI()
- item.setPosition(15)
- self.assertEqual(item.getPosition(), 15)
-
- def testVLine_geometry(self):
- item = roi_items.VerticalLineROI()
- item.setPosition(15)
- self.assertEqual(item.getPosition(), 15)
-
- def testPoint_geometry(self):
- point = numpy.array([1, 2])
- item = roi_items.PointROI()
- item.setPosition(point)
- numpy.testing.assert_allclose(item.getPosition(), point)
-
- def testRectangle_originGeometry(self):
- origin = numpy.array([0, 0])
- size = numpy.array([10, 20])
- center = numpy.array([5, 10])
- item = roi_items.RectangleROI()
- item.setGeometry(origin=origin, size=size)
- numpy.testing.assert_allclose(item.getOrigin(), origin)
- numpy.testing.assert_allclose(item.getSize(), size)
- numpy.testing.assert_allclose(item.getCenter(), center)
-
- def testRectangle_centerGeometry(self):
- origin = numpy.array([0, 0])
- size = numpy.array([10, 20])
- center = numpy.array([5, 10])
- item = roi_items.RectangleROI()
- item.setGeometry(center=center, size=size)
- numpy.testing.assert_allclose(item.getOrigin(), origin)
- numpy.testing.assert_allclose(item.getSize(), size)
- numpy.testing.assert_allclose(item.getCenter(), center)
-
- def testRectangle_setCenterGeometry(self):
- origin = numpy.array([0, 0])
- size = numpy.array([10, 20])
- item = roi_items.RectangleROI()
- item.setGeometry(origin=origin, size=size)
- newCenter = numpy.array([0, 0])
- item.setCenter(newCenter)
- expectedOrigin = numpy.array([-5, -10])
- numpy.testing.assert_allclose(item.getOrigin(), expectedOrigin)
- numpy.testing.assert_allclose(item.getCenter(), newCenter)
- numpy.testing.assert_allclose(item.getSize(), size)
-
- def testRectangle_setOriginGeometry(self):
- origin = numpy.array([0, 0])
- size = numpy.array([10, 20])
- item = roi_items.RectangleROI()
- item.setGeometry(origin=origin, size=size)
- newOrigin = numpy.array([10, 10])
- item.setOrigin(newOrigin)
- expectedCenter = numpy.array([15, 20])
- numpy.testing.assert_allclose(item.getOrigin(), newOrigin)
- numpy.testing.assert_allclose(item.getCenter(), expectedCenter)
- numpy.testing.assert_allclose(item.getSize(), size)
-
- def testCircle_geometry(self):
- center = numpy.array([0, 0])
- radius = 10.
- item = roi_items.CircleROI()
- item.setGeometry(center=center, radius=radius)
- numpy.testing.assert_allclose(item.getCenter(), center)
- numpy.testing.assert_allclose(item.getRadius(), radius)
-
- def testCircle_setCenter(self):
- center = numpy.array([0, 0])
- radius = 10.
- item = roi_items.CircleROI()
- item.setGeometry(center=center, radius=radius)
- newCenter = numpy.array([-10, 0])
- item.setCenter(newCenter)
- numpy.testing.assert_allclose(item.getCenter(), newCenter)
- numpy.testing.assert_allclose(item.getRadius(), radius)
-
- def testCircle_setRadius(self):
- center = numpy.array([0, 0])
- radius = 10.
- item = roi_items.CircleROI()
- item.setGeometry(center=center, radius=radius)
- newRadius = 5.1
- item.setRadius(newRadius)
- numpy.testing.assert_allclose(item.getCenter(), center)
- numpy.testing.assert_allclose(item.getRadius(), newRadius)
-
- def testCircle_contains(self):
- center = numpy.array([2, -1])
- radius = 1.
- item = roi_items.CircleROI()
- item.setGeometry(center=center, radius=radius)
- self.assertTrue(item.contains([1, -1]))
- self.assertFalse(item.contains([0, 0]))
- self.assertTrue(item.contains([2, 0]))
- self.assertFalse(item.contains([3.01, -1]))
-
- def testEllipse_contains(self):
- center = numpy.array([-2, 0])
- item = roi_items.EllipseROI()
- item.setCenter(center)
- item.setOrientation(numpy.pi / 4.0)
- item.setMajorRadius(2)
- item.setMinorRadius(1)
- print(item.getMinorRadius(), item.getMajorRadius())
- self.assertFalse(item.contains([0, 0]))
- self.assertTrue(item.contains([-1, 1]))
- self.assertTrue(item.contains([-3, 0]))
- self.assertTrue(item.contains([-2, 0]))
- self.assertTrue(item.contains([-2, 1]))
- self.assertFalse(item.contains([-4, 1]))
-
- def testRectangle_isIn(self):
- origin = numpy.array([0, 0])
- size = numpy.array([10, 20])
- item = roi_items.RectangleROI()
- item.setGeometry(origin=origin, size=size)
- self.assertTrue(item.contains(position=(0, 0)))
- self.assertTrue(item.contains(position=(2, 14)))
- self.assertFalse(item.contains(position=(14, 12)))
-
- def testPolygon_emptyGeometry(self):
- points = numpy.empty((0, 2))
- item = roi_items.PolygonROI()
- item.setPoints(points)
- numpy.testing.assert_allclose(item.getPoints(), points)
-
- def testPolygon_geometry(self):
- points = numpy.array([[10, 10], [12, 10], [50, 1]])
- item = roi_items.PolygonROI()
- item.setPoints(points)
- numpy.testing.assert_allclose(item.getPoints(), points)
-
- def testPolygon_isIn(self):
- points = numpy.array([[0, 0], [0, 10], [5, 10]])
- item = roi_items.PolygonROI()
- item.setPoints(points)
- self.assertTrue(item.contains((0, 0)))
- self.assertFalse(item.contains((6, 2)))
- self.assertFalse(item.contains((-2, 5)))
- self.assertFalse(item.contains((2, -1)))
- self.assertFalse(item.contains((8, 1)))
- self.assertTrue(item.contains((1, 8)))
-
- def testArc_getToSetGeometry(self):
- """Test that we can use getGeometry as input to setGeometry"""
- item = roi_items.ArcROI()
- item.setFirstShapePoints(numpy.array([[5, 10], [50, 100]]))
- item.setGeometry(*item.getGeometry())
-
- def testArc_degenerated_point(self):
- item = roi_items.ArcROI()
- center = numpy.array([10, 20])
- innerRadius, outerRadius, startAngle, endAngle = 0, 0, 0, 0
- item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
-
- def testArc_degenerated_line(self):
- item = roi_items.ArcROI()
- center = numpy.array([10, 20])
- innerRadius, outerRadius, startAngle, endAngle = 0, 100, numpy.pi, numpy.pi
- item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
-
- def testArc_special_circle(self):
- item = roi_items.ArcROI()
- center = numpy.array([10, 20])
- innerRadius, outerRadius, startAngle, endAngle = 0, 100, numpy.pi, 3 * numpy.pi
- item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
- numpy.testing.assert_allclose(item.getCenter(), center)
- self.assertAlmostEqual(item.getInnerRadius(), innerRadius)
- self.assertAlmostEqual(item.getOuterRadius(), outerRadius)
- self.assertAlmostEqual(item.getStartAngle(), item.getEndAngle() - numpy.pi * 2.0)
- self.assertTrue(item.isClosed())
-
- def testArc_special_donut(self):
- item = roi_items.ArcROI()
- center = numpy.array([10, 20])
- innerRadius, outerRadius, startAngle, endAngle = 1, 100, numpy.pi, 3 * numpy.pi
- item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
- numpy.testing.assert_allclose(item.getCenter(), center)
- self.assertAlmostEqual(item.getInnerRadius(), innerRadius)
- self.assertAlmostEqual(item.getOuterRadius(), outerRadius)
- self.assertAlmostEqual(item.getStartAngle(), item.getEndAngle() - numpy.pi * 2.0)
- self.assertTrue(item.isClosed())
-
- def testArc_clockwiseGeometry(self):
- """Test that we can use getGeometry as input to setGeometry"""
- item = roi_items.ArcROI()
- center = numpy.array([10, 20])
- innerRadius, outerRadius, startAngle, endAngle = 1, 100, numpy.pi * 0.5, numpy.pi
- item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
- numpy.testing.assert_allclose(item.getCenter(), center)
- self.assertAlmostEqual(item.getInnerRadius(), innerRadius)
- self.assertAlmostEqual(item.getOuterRadius(), outerRadius)
- self.assertAlmostEqual(item.getStartAngle(), startAngle)
- self.assertAlmostEqual(item.getEndAngle(), endAngle)
- self.assertAlmostEqual(item.isClosed(), False)
-
- def testArc_anticlockwiseGeometry(self):
- """Test that we can use getGeometry as input to setGeometry"""
- item = roi_items.ArcROI()
- center = numpy.array([10, 20])
- innerRadius, outerRadius, startAngle, endAngle = 1, 100, numpy.pi * 0.5, -numpy.pi * 0.5
- item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
- numpy.testing.assert_allclose(item.getCenter(), center)
- self.assertAlmostEqual(item.getInnerRadius(), innerRadius)
- self.assertAlmostEqual(item.getOuterRadius(), outerRadius)
- self.assertAlmostEqual(item.getStartAngle(), startAngle)
- self.assertAlmostEqual(item.getEndAngle(), endAngle)
- self.assertAlmostEqual(item.isClosed(), False)
-
- def testHRange_geometry(self):
- item = roi_items.HorizontalRangeROI()
- vmin = 1
- vmax = 3
- item.setRange(vmin, vmax)
- self.assertAlmostEqual(item.getMin(), vmin)
- self.assertAlmostEqual(item.getMax(), vmax)
- self.assertAlmostEqual(item.getCenter(), 2)
-
- def testBand_getToSetGeometry(self):
- """Test that we can use getGeometry as input to setGeometry"""
- item = roi_items.BandROI()
- item.setFirstShapePoints(numpy.array([[5, 10], [50, 100]]))
- item.setGeometry(*item.getGeometry())
-
-
class TestRegionOfInterestManager(TestCaseQt, ParametricTestCase):
"""Tests for RegionOfInterestManager class"""
@@ -300,25 +62,44 @@ class TestRegionOfInterestManager(TestCaseQt, ParametricTestCase):
def test(self):
"""Test ROI of different shapes"""
tests = ( # shape, points=[list of (x, y), list of (x, y)]
- (roi_items.PointROI, numpy.array(([(10., 15.)], [(20., 25.)]))),
- (roi_items.RectangleROI,
- numpy.array((((1., 10.), (11., 20.)),
- ((2., 3.), (12., 13.))))),
- (roi_items.PolygonROI,
- numpy.array((((0., 1.), (0., 10.), (10., 0.)),
- ((5., 6.), (5., 16.), (15., 6.))))),
- (roi_items.LineROI,
- numpy.array((((10., 20.), (10., 30.)),
- ((30., 40.), (30., 50.))))),
- (roi_items.HorizontalLineROI,
- numpy.array((((10., 20.), (10., 30.)),
- ((30., 40.), (30., 50.))))),
- (roi_items.VerticalLineROI,
- numpy.array((((10., 20.), (10., 30.)),
- ((30., 40.), (30., 50.))))),
- (roi_items.HorizontalLineROI,
- numpy.array((((10., 20.), (10., 30.)),
- ((30., 40.), (30., 50.))))),
+ (roi_items.PointROI, numpy.array(([(10.0, 15.0)], [(20.0, 25.0)]))),
+ (
+ roi_items.RectangleROI,
+ numpy.array((((1.0, 10.0), (11.0, 20.0)), ((2.0, 3.0), (12.0, 13.0)))),
+ ),
+ (
+ roi_items.PolygonROI,
+ numpy.array(
+ (
+ ((0.0, 1.0), (0.0, 10.0), (10.0, 0.0)),
+ ((5.0, 6.0), (5.0, 16.0), (15.0, 6.0)),
+ )
+ ),
+ ),
+ (
+ roi_items.LineROI,
+ numpy.array(
+ (((10.0, 20.0), (10.0, 30.0)), ((30.0, 40.0), (30.0, 50.0)))
+ ),
+ ),
+ (
+ roi_items.HorizontalLineROI,
+ numpy.array(
+ (((10.0, 20.0), (10.0, 30.0)), ((30.0, 40.0), (30.0, 50.0)))
+ ),
+ ),
+ (
+ roi_items.VerticalLineROI,
+ numpy.array(
+ (((10.0, 20.0), (10.0, 30.0)), ((30.0, 40.0), (30.0, 50.0)))
+ ),
+ ),
+ (
+ roi_items.HorizontalLineROI,
+ numpy.array(
+ (((10.0, 20.0), (10.0, 30.0)), ((30.0, 40.0), (30.0, 50.0)))
+ ),
+ ),
)
for roiClass, points in tests:
@@ -453,7 +234,12 @@ class TestRegionOfInterestManager(TestCaseQt, ParametricTestCase):
# Arc
item = roi_items.ArcROI()
center = numpy.array([10, 20])
- innerRadius, outerRadius, startAngle, endAngle = 1, 100, numpy.pi * 0.5, numpy.pi
+ innerRadius, outerRadius, startAngle, endAngle = (
+ 1,
+ 100,
+ numpy.pi * 0.5,
+ numpy.pi,
+ )
item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
rois.append(item)
# Horizontal Range
@@ -493,12 +279,20 @@ class TestRegionOfInterestManager(TestCaseQt, ParametricTestCase):
manager.removeRoi(item1)
self.assertIs(manager.getCurrentRoi(), None)
+ def testInitROIWithParent(self):
+ manager = roi.RegionOfInterestManager(self.plot)
+ item = roi_items.PointROI(manager)
+ manager.addRoi(item)
+ self.qapp.processEvents()
+ manager.removeRoi(item)
+ self.qapp.processEvents()
+
def testMaxROI(self):
"""Test Max ROI"""
- origin1 = numpy.array([1., 10.])
- size1 = numpy.array([10., 10.])
- origin2 = numpy.array([2., 3.])
- size2 = numpy.array([10., 10.])
+ origin1 = numpy.array([1.0, 10.0])
+ size1 = numpy.array([10.0, 10.0])
+ origin2 = numpy.array([2.0, 3.0])
+ size2 = numpy.array([10.0, 10.0])
manager = roi.InteractiveRegionOfInterestManager(self.plot)
self.roiTableWidget.setRegionOfInterestManager(manager)
@@ -587,16 +381,17 @@ class TestRegionOfInterestManager(TestCaseQt, ParametricTestCase):
mx, my = self.plot.dataToPixel(*center)
self.mouseMove(widget, pos=(mx, my))
self.mousePress(widget, qt.Qt.LeftButton, pos=(mx, my))
- self.mouseMove(widget, pos=(mx, my+25))
- self.mouseMove(widget, pos=(mx, my+50))
- self.mouseRelease(widget, qt.Qt.LeftButton, pos=(mx, my+50))
+ self.mouseMove(widget, pos=(mx, my + 25))
+ self.mouseMove(widget, pos=(mx, my + 50))
+ self.mouseRelease(widget, qt.Qt.LeftButton, pos=(mx, my + 50))
result = numpy.array(item.getEndPoints())
# x location is still the same
numpy.testing.assert_allclose(points[:, 0], result[:, 0], atol=0.5)
# size is still the same
- numpy.testing.assert_allclose(points[1] - points[0],
- result[1] - result[0], atol=0.5)
+ numpy.testing.assert_allclose(
+ points[1] - points[0], result[1] - result[0], atol=0.5
+ )
# But Y is not the same
self.assertNotEqual(points[0, 1], result[0, 1])
self.assertNotEqual(points[1, 1], result[1, 1])
@@ -667,7 +462,7 @@ class TestRegionOfInterestManager(TestCaseQt, ParametricTestCase):
self.assertIs(item.getInteractionMode(), roi_items.ArcROI.ThreePointMode)
self.qWait(500)
- # Click on the center
+ # Click on the center
widget = self.plot.getWidgetHandle()
mx, my = self.plot.dataToPixel(*center)
@@ -710,7 +505,7 @@ class TestRegionOfInterestManager(TestCaseQt, ParametricTestCase):
assert item.getInteractionMode() is roi_items.BandROI.BoundedMode
self.qWait(500)
- # Click on the center
+ # Click on the center
widget = self.plot.getWidgetHandle()
mx, my = self.plot.dataToPixel(xcenter, ycenter)
diff --git a/src/silx/gui/plot/tools/test/testRoiItems.py b/src/silx/gui/plot/tools/test/testRoiItems.py
new file mode 100644
index 0000000..9bd9690
--- /dev/null
+++ b/src/silx/gui/plot/tools/test/testRoiItems.py
@@ -0,0 +1,313 @@
+# /*##########################################################################
+#
+# Copyright (c) 2018-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "28/06/2018"
+
+
+import pytest
+import numpy.testing
+
+import silx.gui.plot.items.roi as roi_items
+
+
+def testLine_geometry(qapp):
+ item = roi_items.LineROI()
+ startPoint = numpy.array([1, 2])
+ endPoint = numpy.array([3, 4])
+ item.setEndPoints(startPoint, endPoint)
+ numpy.testing.assert_allclose(item.getEndPoints()[0], startPoint)
+ numpy.testing.assert_allclose(item.getEndPoints()[1], endPoint)
+
+
+def testHLine_geometry(qapp):
+ item = roi_items.HorizontalLineROI()
+ item.setPosition(15)
+ assert item.getPosition() == 15
+
+
+def testVLine_geometry(qapp):
+ item = roi_items.VerticalLineROI()
+ item.setPosition(15)
+ assert item.getPosition() == 15
+
+
+def testPoint_geometry(qapp):
+ point = numpy.array([1, 2])
+ item = roi_items.PointROI()
+ item.setPosition(point)
+ numpy.testing.assert_allclose(item.getPosition(), point)
+
+
+def testRectangle_originGeometry(qapp):
+ origin = numpy.array([0, 0])
+ size = numpy.array([10, 20])
+ center = numpy.array([5, 10])
+ item = roi_items.RectangleROI()
+ item.setGeometry(origin=origin, size=size)
+ numpy.testing.assert_allclose(item.getOrigin(), origin)
+ numpy.testing.assert_allclose(item.getSize(), size)
+ numpy.testing.assert_allclose(item.getCenter(), center)
+
+
+def testRectangle_centerGeometry(qapp):
+ origin = numpy.array([0, 0])
+ size = numpy.array([10, 20])
+ center = numpy.array([5, 10])
+ item = roi_items.RectangleROI()
+ item.setGeometry(center=center, size=size)
+ numpy.testing.assert_allclose(item.getOrigin(), origin)
+ numpy.testing.assert_allclose(item.getSize(), size)
+ numpy.testing.assert_allclose(item.getCenter(), center)
+
+
+def testRectangle_setCenterGeometry(qapp):
+ origin = numpy.array([0, 0])
+ size = numpy.array([10, 20])
+ item = roi_items.RectangleROI()
+ item.setGeometry(origin=origin, size=size)
+ newCenter = numpy.array([0, 0])
+ item.setCenter(newCenter)
+ expectedOrigin = numpy.array([-5, -10])
+ numpy.testing.assert_allclose(item.getOrigin(), expectedOrigin)
+ numpy.testing.assert_allclose(item.getCenter(), newCenter)
+ numpy.testing.assert_allclose(item.getSize(), size)
+
+
+def testRectangle_setOriginGeometry(qapp):
+ origin = numpy.array([0, 0])
+ size = numpy.array([10, 20])
+ item = roi_items.RectangleROI()
+ item.setGeometry(origin=origin, size=size)
+ newOrigin = numpy.array([10, 10])
+ item.setOrigin(newOrigin)
+ expectedCenter = numpy.array([15, 20])
+ numpy.testing.assert_allclose(item.getOrigin(), newOrigin)
+ numpy.testing.assert_allclose(item.getCenter(), expectedCenter)
+ numpy.testing.assert_allclose(item.getSize(), size)
+
+
+def testCircle_geometry(qapp):
+ center = numpy.array([0, 0])
+ radius = 10.0
+ item = roi_items.CircleROI()
+ item.setGeometry(center=center, radius=radius)
+ numpy.testing.assert_allclose(item.getCenter(), center)
+ numpy.testing.assert_allclose(item.getRadius(), radius)
+
+
+def testCircle_setCenter(qapp):
+ center = numpy.array([0, 0])
+ radius = 10.0
+ item = roi_items.CircleROI()
+ item.setGeometry(center=center, radius=radius)
+ newCenter = numpy.array([-10, 0])
+ item.setCenter(newCenter)
+ numpy.testing.assert_allclose(item.getCenter(), newCenter)
+ numpy.testing.assert_allclose(item.getRadius(), radius)
+
+
+def testCircle_setRadius(qapp):
+ center = numpy.array([0, 0])
+ radius = 10.0
+ item = roi_items.CircleROI()
+ item.setGeometry(center=center, radius=radius)
+ newRadius = 5.1
+ item.setRadius(newRadius)
+ numpy.testing.assert_allclose(item.getCenter(), center)
+ numpy.testing.assert_allclose(item.getRadius(), newRadius)
+
+
+def testCircle_contains(qapp):
+ center = numpy.array([2, -1])
+ radius = 1.0
+ item = roi_items.CircleROI()
+ item.setGeometry(center=center, radius=radius)
+ assert item.contains([1, -1])
+ assert not item.contains([0, 0])
+ assert item.contains([2, 0])
+ assert not item.contains([3.01, -1])
+
+
+def testEllipse_contains(qapp):
+ center = numpy.array([-2, 0])
+ item = roi_items.EllipseROI()
+ item.setCenter(center)
+ item.setOrientation(numpy.pi / 4.0)
+ item.setMajorRadius(2)
+ item.setMinorRadius(1)
+ print(item.getMinorRadius(), item.getMajorRadius())
+ assert not item.contains([0, 0])
+ assert item.contains([-1, 1])
+ assert item.contains([-3, 0])
+ assert item.contains([-2, 0])
+ assert item.contains([-2, 1])
+ assert not item.contains([-4, 1])
+
+
+def testRectangle_isIn(qapp):
+ origin = numpy.array([0, 0])
+ size = numpy.array([10, 20])
+ item = roi_items.RectangleROI()
+ item.setGeometry(origin=origin, size=size)
+ assert item.contains(position=(0, 0))
+ assert item.contains(position=(2, 14))
+ assert not item.contains(position=(14, 12))
+
+
+def testPolygon_emptyGeometry(qapp):
+ points = numpy.empty((0, 2))
+ item = roi_items.PolygonROI()
+ item.setPoints(points)
+ numpy.testing.assert_allclose(item.getPoints(), points)
+
+
+def testPolygon_geometry(qapp):
+ points = numpy.array([[10, 10], [12, 10], [50, 1]])
+ item = roi_items.PolygonROI()
+ item.setPoints(points)
+ numpy.testing.assert_allclose(item.getPoints(), points)
+
+
+def testPolygon_isIn(qapp):
+ points = numpy.array([[0, 0], [0, 10], [5, 10]])
+ item = roi_items.PolygonROI()
+ item.setPoints(points)
+ assert item.contains((0, 0))
+ assert not item.contains((6, 2))
+ assert not item.contains((-2, 5))
+ assert not item.contains((2, -1))
+ assert not item.contains((8, 1))
+ assert item.contains((1, 8))
+
+
+def testArc_getToSetGeometry(qapp):
+ """Test that we can use getGeometry as input to setGeometry"""
+ item = roi_items.ArcROI()
+ item.setFirstShapePoints(numpy.array([[5, 10], [50, 100]]))
+ item.setGeometry(*item.getGeometry())
+
+
+def testArc_degenerated_point(qapp):
+ item = roi_items.ArcROI()
+ center = numpy.array([10, 20])
+ innerRadius, outerRadius, startAngle, endAngle = 0, 0, 0, 0
+ item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
+
+
+def testArc_degenerated_line(qapp):
+ item = roi_items.ArcROI()
+ center = numpy.array([10, 20])
+ innerRadius, outerRadius, startAngle, endAngle = 0, 100, numpy.pi, numpy.pi
+ item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
+
+
+def testArc_special_circle(qapp):
+ item = roi_items.ArcROI()
+ center = numpy.array([10, 20])
+ innerRadius, outerRadius, startAngle, endAngle = 0, 100, numpy.pi, 3 * numpy.pi
+ item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
+ numpy.testing.assert_allclose(item.getCenter(), center)
+ assert item.getInnerRadius() == pytest.approx(innerRadius)
+ assert item.getOuterRadius() == pytest.approx(outerRadius)
+ assert item.getStartAngle() == pytest.approx(item.getEndAngle() - numpy.pi * 2.0)
+ assert item.isClosed()
+
+
+def testArc_special_donut(qapp):
+ item = roi_items.ArcROI()
+ center = numpy.array([10, 20])
+ innerRadius, outerRadius, startAngle, endAngle = 1, 100, numpy.pi, 3 * numpy.pi
+ item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
+ numpy.testing.assert_allclose(item.getCenter(), center)
+ assert item.getInnerRadius() == pytest.approx(innerRadius)
+ assert item.getOuterRadius() == pytest.approx(outerRadius)
+ assert item.getStartAngle() == pytest.approx(item.getEndAngle() - numpy.pi * 2.0)
+ assert item.isClosed()
+
+
+def testArc_clockwiseGeometry(qapp):
+ """Test that we can use getGeometry as input to setGeometry"""
+ item = roi_items.ArcROI()
+ center = numpy.array([10, 20])
+ innerRadius, outerRadius, startAngle, endAngle = 1, 100, numpy.pi * 0.5, numpy.pi
+ item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
+ numpy.testing.assert_allclose(item.getCenter(), center)
+ assert item.getInnerRadius() == pytest.approx(innerRadius)
+ assert item.getOuterRadius() == pytest.approx(outerRadius)
+ assert item.getStartAngle() == pytest.approx(startAngle)
+ assert item.getEndAngle() == pytest.approx(endAngle)
+ assert not item.isClosed()
+
+
+def testArc_anticlockwiseGeometry(qapp):
+ """Test that we can use getGeometry as input to setGeometry"""
+ item = roi_items.ArcROI()
+ center = numpy.array([10, 20])
+ innerRadius, outerRadius, startAngle, endAngle = (
+ 1,
+ 100,
+ numpy.pi * 0.5,
+ -numpy.pi * 0.5,
+ )
+ item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
+ numpy.testing.assert_allclose(item.getCenter(), center)
+ assert item.getInnerRadius() == pytest.approx(innerRadius)
+ assert item.getOuterRadius() == pytest.approx(outerRadius)
+ assert item.getStartAngle() == pytest.approx(startAngle)
+ assert item.getEndAngle() == pytest.approx(endAngle)
+ assert not item.isClosed()
+
+
+def testArc_position(qapp):
+ """Test validity of getPosition"""
+ item = roi_items.ArcROI()
+ center = numpy.array([10, 20])
+ innerRadius, outerRadius, startAngle, endAngle = 1, 100, numpy.pi * 0.5, numpy.pi
+ item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
+ assert item.getPosition(roi_items.ArcROI.Role.START) == pytest.approx((10.0, 70.5))
+ assert item.getPosition(roi_items.ArcROI.Role.STOP) == pytest.approx((-40.5, 20.0))
+ assert item.getPosition(roi_items.ArcROI.Role.MIDDLE) == pytest.approx(
+ (-25.71, 55.71), abs=0.1
+ )
+ assert item.getPosition(roi_items.ArcROI.Role.CENTER) == pytest.approx(
+ (10.0, 20), abs=0.1
+ )
+
+
+def testHRange_geometry(qapp):
+ item = roi_items.HorizontalRangeROI()
+ vmin = 1
+ vmax = 3
+ item.setRange(vmin, vmax)
+ assert item.getMin() == pytest.approx(vmin)
+ assert item.getMax() == pytest.approx(vmax)
+ assert item.getCenter() == pytest.approx(2)
+
+
+def testBand_getToSetGeometry(qapp):
+ """Test that we can use getGeometry as input to setGeometry"""
+ item = roi_items.BandROI()
+ item.setFirstShapePoints(numpy.array([[5, 10], [50, 100]]))
+ item.setGeometry(*item.getGeometry())
diff --git a/src/silx/gui/plot/tools/test/testScatterProfileToolBar.py b/src/silx/gui/plot/tools/test/testScatterProfileToolBar.py
index 9b9caa1..29c9ad0 100644
--- a/src/silx/gui/plot/tools/test/testScatterProfileToolBar.py
+++ b/src/silx/gui/plot/tools/test/testScatterProfileToolBar.py
@@ -26,7 +26,6 @@ __license__ = "MIT"
__date__ = "28/06/2018"
-import unittest
import numpy
from silx.gui import qt
@@ -67,8 +66,9 @@ class TestScatterProfileToolBar(TestCaseQt, ParametricTestCase):
# Add a scatter plot
self.plot.addScatter(
- x=(0., 1., 1., 0.), y=(0., 0., 1., 1.), value=(0., 1., 2., 3.))
- self.plot.resetZoom(dataMargins=(.1, .1, .1, .1))
+ x=(0.0, 1.0, 1.0, 0.0), y=(0.0, 0.0, 1.0, 1.0), value=(0.0, 1.0, 2.0, 3.0)
+ )
+ self.plot.resetZoom(dataMargins=(0.1, 0.1, 0.1, 0.1))
self.qapp.processEvents()
# Set a ROI profile
@@ -107,8 +107,9 @@ class TestScatterProfileToolBar(TestCaseQt, ParametricTestCase):
# Add a scatter plot
self.plot.addScatter(
- x=(0., 1., 1., 0.), y=(0., 0., 1., 1.), value=(0., 1., 2., 3.))
- self.plot.resetZoom(dataMargins=(.1, .1, .1, .1))
+ x=(0.0, 1.0, 1.0, 0.0), y=(0.0, 0.0, 1.0, 1.0), value=(0.0, 1.0, 2.0, 3.0)
+ )
+ self.plot.resetZoom(dataMargins=(0.1, 0.1, 0.1, 0.1))
self.qapp.processEvents()
# Set a ROI profile
@@ -160,13 +161,14 @@ class TestScatterProfileToolBar(TestCaseQt, ParametricTestCase):
# Add a scatter plot
self.plot.addScatter(
- x=(0., 1., 1., 0.), y=(0., 0., 1., 1.), value=(0., 1., 2., 3.))
- self.plot.resetZoom(dataMargins=(.1, .1, .1, .1))
+ x=(0.0, 1.0, 1.0, 0.0), y=(0.0, 0.0, 1.0, 1.0), value=(0.0, 1.0, 2.0, 3.0)
+ )
+ self.plot.resetZoom(dataMargins=(0.1, 0.1, 0.1, 0.1))
self.qapp.processEvents()
# Set a ROI profile
roi = rois.ProfileScatterLineROI()
- roi.setEndPoints(numpy.array([0., 0.]), numpy.array([1., 1.]))
+ roi.setEndPoints(numpy.array([0.0, 0.0]), numpy.array([1.0, 1.0]))
roi.setNPoints(8)
roiManager.addRoi(roi)
diff --git a/src/silx/gui/plot/tools/test/testTools.py b/src/silx/gui/plot/tools/test/testTools.py
index 507b922..1212ead 100644
--- a/src/silx/gui/plot/tools/test/testTools.py
+++ b/src/silx/gui/plot/tools/test/testTools.py
@@ -29,11 +29,9 @@ __date__ = "02/03/2018"
import functools
-import unittest
import numpy
from silx.utils.testutils import LoggingValidator
-from silx.gui.utils.testutils import qWaitForWindowExposedAndActivate
from silx.gui import qt
from silx.gui.plot import PlotWindow
from silx.gui.plot import tools
@@ -82,28 +80,28 @@ class TestPositionInfo(PlotWidgetTestCase):
def testDefaultConverters(self):
"""Test PositionInfo with default converters"""
positionWidget = tools.PositionInfo(plot=self.plot)
- self._test(positionWidget, ('X', 'Y'))
+ self._test(positionWidget, ("X", "Y"))
def testCustomConverters(self):
"""Test PositionInfo with custom converters"""
converters = [
- ('Coords', lambda x, y: (int(x), int(y))),
- ('Radius', lambda x, y: numpy.sqrt(x * x + y * y)),
- ('Angle', lambda x, y: numpy.degrees(numpy.arctan2(y, x)))
+ ("Coords", lambda x, y: (int(x), int(y))),
+ ("Radius", lambda x, y: numpy.sqrt(x * x + y * y)),
+ ("Angle", lambda x, y: numpy.degrees(numpy.arctan2(y, x))),
]
- positionWidget = tools.PositionInfo(plot=self.plot,
- converters=converters)
- self._test(positionWidget, ('Coords', 'Radius', 'Angle'))
+ positionWidget = tools.PositionInfo(plot=self.plot, converters=converters)
+ self._test(positionWidget, ("Coords", "Radius", "Angle"))
def testFailingConverters(self):
"""Test PositionInfo with failing custom converters"""
+
def raiseException(x, y):
raise RuntimeError()
positionWidget = tools.PositionInfo(
- plot=self.plot,
- converters=[('Exception', raiseException)])
- self._test(positionWidget, ['Exception'], error=2)
+ plot=self.plot, converters=[("Exception", raiseException)]
+ )
+ self._test(positionWidget, ["Exception"], error=2)
def testUpdate(self):
"""Test :meth:`PositionInfo.updateInfo`"""
@@ -115,7 +113,8 @@ class TestPositionInfo(PlotWidgetTestCase):
positionWidget = tools.PositionInfo(
plot=self.plot,
- converters=[('Call count', functools.partial(update, calls))])
+ converters=[("Call count", functools.partial(update, calls))],
+ )
positionWidget.updateInfo()
self.assertEqual(len(calls), 1)
@@ -125,10 +124,12 @@ class TestPlotToolsToolbars(PlotWidgetTestCase):
"""Tests toolbars from silx.gui.plot.tools"""
def test(self):
- """"Add all toolbars"""
- for tbClass in (tools.InteractiveModeToolBar,
- tools.ImageToolBar,
- tools.CurveToolBar,
- tools.OutputToolBar):
+ """ "Add all toolbars"""
+ for tbClass in (
+ tools.InteractiveModeToolBar,
+ tools.ImageToolBar,
+ tools.CurveToolBar,
+ tools.OutputToolBar,
+ ):
tb = tbClass(parent=self.plot, plot=self.plot)
self.plot.addToolBar(tb)
diff --git a/src/silx/gui/plot/tools/toolbars.py b/src/silx/gui/plot/tools/toolbars.py
index bb89942..7f38f1c 100644
--- a/src/silx/gui/plot/tools/toolbars.py
+++ b/src/silx/gui/plot/tools/toolbars.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2018-2020 European Synchrotron Radiation Facility
+# Copyright (c) 2018-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -33,7 +33,6 @@ from ... import qt
from .. import actions
from ..PlotWidget import PlotWidget
from .. import PlotToolButtons
-from ....utils.deprecation import deprecated
class InteractiveModeToolBar(qt.QToolBar):
@@ -44,17 +43,15 @@ class InteractiveModeToolBar(qt.QToolBar):
:param str title: Title of the toolbar.
"""
- def __init__(self, parent=None, plot=None, title='Plot Interaction'):
+ def __init__(self, parent=None, plot=None, title="Plot Interaction"):
super(InteractiveModeToolBar, self).__init__(title, parent)
assert isinstance(plot, PlotWidget)
- self._zoomModeAction = actions.mode.ZoomModeAction(
- parent=self, plot=plot)
+ self._zoomModeAction = actions.mode.ZoomModeAction(parent=self, plot=plot)
self.addAction(self._zoomModeAction)
- self._panModeAction = actions.mode.PanModeAction(
- parent=self, plot=plot)
+ self._panModeAction = actions.mode.PanModeAction(parent=self, plot=plot)
self.addAction(self._panModeAction)
def getZoomModeAction(self):
@@ -80,7 +77,7 @@ class OutputToolBar(qt.QToolBar):
:param str title: Title of the toolbar.
"""
- def __init__(self, parent=None, plot=None, title='Plot Output'):
+ def __init__(self, parent=None, plot=None, title="Plot Output"):
super(OutputToolBar, self).__init__(title, parent)
assert isinstance(plot, PlotWidget)
@@ -124,25 +121,25 @@ class ImageToolBar(qt.QToolBar):
:param str title: Title of the toolbar.
"""
- def __init__(self, parent=None, plot=None, title='Image'):
+ def __init__(self, parent=None, plot=None, title="Image"):
super(ImageToolBar, self).__init__(title, parent)
assert isinstance(plot, PlotWidget)
- self._resetZoomAction = actions.control.ResetZoomAction(
- parent=self, plot=plot)
+ self._resetZoomAction = actions.control.ResetZoomAction(parent=self, plot=plot)
self.addAction(self._resetZoomAction)
- self._colormapAction = actions.control.ColormapAction(
- parent=self, plot=plot)
+ self._colormapAction = actions.control.ColormapAction(parent=self, plot=plot)
self.addAction(self._colormapAction)
self._keepDataAspectRatioButton = PlotToolButtons.AspectToolButton(
- parent=self, plot=plot)
+ parent=self, plot=plot
+ )
self.addWidget(self._keepDataAspectRatioButton)
self._yAxisInvertedButton = PlotToolButtons.YAxisOriginToolButton(
- parent=self, plot=plot)
+ parent=self, plot=plot
+ )
self.addWidget(self._yAxisInvertedButton)
def getResetZoomAction(self):
@@ -182,37 +179,40 @@ class CurveToolBar(qt.QToolBar):
:param str title: Title of the toolbar.
"""
- def __init__(self, parent=None, plot=None, title='Image'):
+ def __init__(self, parent=None, plot=None, title="Image"):
super(CurveToolBar, self).__init__(title, parent)
assert isinstance(plot, PlotWidget)
- self._resetZoomAction = actions.control.ResetZoomAction(
- parent=self, plot=plot)
+ self._resetZoomAction = actions.control.ResetZoomAction(parent=self, plot=plot)
self.addAction(self._resetZoomAction)
self._xAxisAutoScaleAction = actions.control.XAxisAutoScaleAction(
- parent=self, plot=plot)
+ parent=self, plot=plot
+ )
self.addAction(self._xAxisAutoScaleAction)
self._yAxisAutoScaleAction = actions.control.YAxisAutoScaleAction(
- parent=self, plot=plot)
+ parent=self, plot=plot
+ )
self.addAction(self._yAxisAutoScaleAction)
self._xAxisLogarithmicAction = actions.control.XAxisLogarithmicAction(
- parent=self, plot=plot)
+ parent=self, plot=plot
+ )
self.addAction(self._xAxisLogarithmicAction)
self._yAxisLogarithmicAction = actions.control.YAxisLogarithmicAction(
- parent=self, plot=plot)
+ parent=self, plot=plot
+ )
self.addAction(self._yAxisLogarithmicAction)
- self._gridAction = actions.control.GridAction(
- parent=self, plot=plot)
+ self._gridAction = actions.control.GridAction(parent=self, plot=plot)
self.addAction(self._gridAction)
self._curveStyleAction = actions.control.CurveStyleAction(
- parent=self, plot=plot)
+ parent=self, plot=plot
+ )
self.addAction(self._curveStyleAction)
def getResetZoomAction(self):
@@ -273,37 +273,38 @@ class ScatterToolBar(qt.QToolBar):
:param str title: Title of the toolbar.
"""
- def __init__(self, parent=None, plot=None, title='Scatter Tools'):
+ def __init__(self, parent=None, plot=None, title="Scatter Tools"):
super(ScatterToolBar, self).__init__(title, parent)
assert isinstance(plot, PlotWidget)
- self._resetZoomAction = actions.control.ResetZoomAction(
- parent=self, plot=plot)
+ self._resetZoomAction = actions.control.ResetZoomAction(parent=self, plot=plot)
self.addAction(self._resetZoomAction)
self._xAxisLogarithmicAction = actions.control.XAxisLogarithmicAction(
- parent=self, plot=plot)
+ parent=self, plot=plot
+ )
self.addAction(self._xAxisLogarithmicAction)
self._yAxisLogarithmicAction = actions.control.YAxisLogarithmicAction(
- parent=self, plot=plot)
+ parent=self, plot=plot
+ )
self.addAction(self._yAxisLogarithmicAction)
self._keepDataAspectRatioButton = PlotToolButtons.AspectToolButton(
- parent=self, plot=plot)
+ parent=self, plot=plot
+ )
self.addWidget(self._keepDataAspectRatioButton)
- self._gridAction = actions.control.GridAction(
- parent=self, plot=plot)
+ self._gridAction = actions.control.GridAction(parent=self, plot=plot)
self.addAction(self._gridAction)
- self._colormapAction = actions.control.ColormapAction(
- parent=self, plot=plot)
+ self._colormapAction = actions.control.ColormapAction(parent=self, plot=plot)
self.addAction(self._colormapAction)
- self._visualizationToolButton = \
- PlotToolButtons.ScatterVisualizationToolButton(parent=self, plot=plot)
+ self._visualizationToolButton = PlotToolButtons.ScatterVisualizationToolButton(
+ parent=self, plot=plot
+ )
self.addWidget(self._visualizationToolButton)
def getResetZoomAction(self):
@@ -354,8 +355,3 @@ class ScatterToolBar(qt.QToolBar):
:rtype: ScatterVisualizationToolButton
"""
return self._visualizationToolButton
-
- @deprecated(replacement='getScatterVisualizationToolButton',
- since_version='0.11.0')
- def getSymbolToolButton(self):
- return self.getScatterVisualizationToolButton()
diff --git a/src/silx/gui/plot/utils/axis.py b/src/silx/gui/plot/utils/axis.py
index 419a71c..4c6bcef 100644
--- a/src/silx/gui/plot/utils/axis.py
+++ b/src/silx/gui/plot/utils/axis.py
@@ -56,14 +56,16 @@ class SyncAxes(object):
.. versionadded:: 0.6
"""
- def __init__(self, axes,
- syncLimits=True,
- syncScale=True,
- syncDirection=True,
- syncCenter=False,
- syncZoom=False,
- filterHiddenPlots=False
- ):
+ def __init__(
+ self,
+ axes,
+ syncLimits=True,
+ syncScale=True,
+ syncDirection=True,
+ syncCenter=False,
+ syncZoom=False,
+ filterHiddenPlots=False,
+ ):
"""
Constructor
@@ -79,12 +81,13 @@ class SyncAxes(object):
"""
object.__init__(self)
- def implies(x, y): return bool(y ** x)
+ def implies(x, y):
+ return bool(y**x)
- assert(implies(syncZoom, not syncLimits))
- assert(implies(syncCenter, not syncLimits))
- assert(implies(syncLimits, not syncCenter))
- assert(implies(syncLimits, not syncZoom))
+ assert implies(syncZoom, not syncLimits)
+ assert implies(syncCenter, not syncLimits)
+ assert implies(syncLimits, not syncCenter)
+ assert implies(syncLimits, not syncZoom)
self.__filterHiddenPlots = filterHiddenPlots
self.__locked = False
@@ -313,7 +316,7 @@ class SyncAxes(object):
elif isinstance(axis, YAxis):
return bounds[3]
else:
- assert(False)
+ assert False
def __getLimitsFromCenter(self, axis, pos, pixelSize=None):
"""Returns the limits to apply to this axis to move the `pos` into the
diff --git a/src/silx/gui/plot/utils/intersections.py b/src/silx/gui/plot/utils/intersections.py
index 4f6ed23..faf6641 100644
--- a/src/silx/gui/plot/utils/intersections.py
+++ b/src/silx/gui/plot/utils/intersections.py
@@ -24,7 +24,9 @@
"""This module contains utils class for axes management.
"""
-__authors__ = ["H. Payno", ]
+__authors__ = [
+ "H. Payno",
+]
__license__ = "MIT"
__date__ = "18/05/2020"
@@ -59,11 +61,11 @@ def lines_intersection(line1_pt1, line1_pt2, line2_pt1, line2_pt2):
return None
return (
(num / denom.astype(float)) * dir_line2[0] + line2_pt1[0],
- (num / denom.astype(float)) * dir_line2[1] + line2_pt1[1])
+ (num / denom.astype(float)) * dir_line2[1] + line2_pt1[1],
+ )
-def segments_intersection(seg1_start_pt, seg1_end_pt, seg2_start_pt,
- seg2_end_pt):
+def segments_intersection(seg1_start_pt, seg1_end_pt, seg2_start_pt, seg2_end_pt):
"""
Compute intersection between two segments
@@ -74,10 +76,12 @@ def segments_intersection(seg1_start_pt, seg1_end_pt, seg2_start_pt,
:return: numpy.array if an intersection exists, else None
:rtype: Union[None,numpy.array]
"""
- intersection = lines_intersection(line1_pt1=seg1_start_pt,
- line1_pt2=seg1_end_pt,
- line2_pt1=seg2_start_pt,
- line2_pt2=seg2_end_pt)
+ intersection = lines_intersection(
+ line1_pt1=seg1_start_pt,
+ line1_pt2=seg1_end_pt,
+ line2_pt1=seg2_start_pt,
+ line2_pt2=seg2_end_pt,
+ )
if intersection is not None:
max_x_seg1 = max(seg1_start_pt[0], seg1_end_pt[0])
max_x_seg2 = max(seg2_start_pt[0], seg2_end_pt[0])
@@ -93,8 +97,10 @@ def segments_intersection(seg1_start_pt, seg1_end_pt, seg2_start_pt,
max_tmp_x = min(max_x_seg1, max_x_seg2)
min_tmp_y = max(min_y_seg1, min_y_seg2)
max_tmp_y = min(max_y_seg1, max_y_seg2)
- if (min_tmp_x <= intersection[0] <= max_tmp_x and
- min_tmp_y <= intersection[1] <= max_tmp_y):
+ if (
+ min_tmp_x <= intersection[0] <= max_tmp_x
+ and min_tmp_y <= intersection[1] <= max_tmp_y
+ ):
return intersection
else:
return None
diff --git a/src/silx/gui/plot3d/ParamTreeView.py b/src/silx/gui/plot3d/ParamTreeView.py
index b648251..34ed1aa 100644
--- a/src/silx/gui/plot3d/ParamTreeView.py
+++ b/src/silx/gui/plot3d/ParamTreeView.py
@@ -31,12 +31,14 @@ This module contains:
:class:`FloatEditor`, :class:`Vector3DEditor`,
:class:`Vector4DEditor`, :class:`IntSliderEditor`, :class:`BooleanEditor`
"""
+from __future__ import annotations
__authors__ = ["T. Vincent"]
__license__ = "MIT"
__date__ = "05/12/2017"
+from collections.abc import Sequence
import numbers
import sys
@@ -49,25 +51,19 @@ class FloatEditor(_FloatEdit):
"""Editor widget for float.
:param parent: The widget's parent
- :param float value: The initial editor value
+ :param value: The initial editor value
"""
- valueChanged = qt.Signal(float)
- """Signal emitted when the float value has changed"""
-
- def __init__(self, parent=None, value=None):
+ def __init__(self, parent: qt.QWidget | None = None, value: float | None = None):
super(FloatEditor, self).__init__(parent, value)
self.setAlignment(qt.Qt.AlignLeft)
- self.editingFinished.connect(self._emit)
-
- def _emit(self):
- self.valueChanged.emit(self.value)
- value = qt.Property(float,
- fget=_FloatEdit.value,
- fset=_FloatEdit.setValue,
- user=True,
- notify=valueChanged)
+ valueProperty = qt.Property(
+ float,
+ fget=_FloatEdit.value,
+ fset=_FloatEdit.setValue,
+ user=True,
+ )
"""Qt user property of the float value this widget edits"""
@@ -78,59 +74,49 @@ class Vector3DEditor(qt.QWidget):
:param flags: The widgets's flags
"""
- valueChanged = qt.Signal(qt.QVector3D)
- """Signal emitted when the QVector3D value has changed"""
-
- def __init__(self, parent=None, flags=qt.Qt.Widget):
+ def __init__(
+ self,
+ parent: qt.QWidget | None = None,
+ flags: qt.Qt.WindowType = qt.Qt.Widget,
+ ):
super(Vector3DEditor, self).__init__(parent, flags)
layout = qt.QHBoxLayout(self)
# layout.setSpacing(0)
layout.setContentsMargins(0, 0, 0, 0)
- self.setLayout(layout)
- self._xEdit = _FloatEdit(parent=self, value=0.)
+
+ self._xEdit = _FloatEdit(parent=self, value=0.0)
self._xEdit.setAlignment(qt.Qt.AlignLeft)
- # self._xEdit.editingFinished.connect(self._emit)
- self._yEdit = _FloatEdit(parent=self, value=0.)
+ self._yEdit = _FloatEdit(parent=self, value=0.0)
self._yEdit.setAlignment(qt.Qt.AlignLeft)
- # self._yEdit.editingFinished.connect(self._emit)
- self._zEdit = _FloatEdit(parent=self, value=0.)
+ self._zEdit = _FloatEdit(parent=self, value=0.0)
self._zEdit.setAlignment(qt.Qt.AlignLeft)
- # self._zEdit.editingFinished.connect(self._emit)
- layout.addWidget(qt.QLabel('x:'))
+
+ layout.addWidget(qt.QLabel("x:"))
layout.addWidget(self._xEdit)
- layout.addWidget(qt.QLabel('y:'))
+ layout.addWidget(qt.QLabel("y:"))
layout.addWidget(self._yEdit)
- layout.addWidget(qt.QLabel('z:'))
+ layout.addWidget(qt.QLabel("z:"))
layout.addWidget(self._zEdit)
layout.addStretch(1)
- def _emit(self):
- vector = self.value
- self.valueChanged.emit(vector)
-
- def getValue(self):
- """Returns the QVector3D value of this widget
-
- :rtype: QVector3D
- """
+ def getValue(self) -> qt.QVector3D:
+ """Returns the QVector3D value of this widget"""
return qt.QVector3D(
- self._xEdit.value(), self._yEdit.value(), self._zEdit.value())
-
- def setValue(self, value):
- """Set the QVector3D value
+ self._xEdit.value(), self._yEdit.value(), self._zEdit.value()
+ )
- :param QVector3D value: The new value
- """
+ def setValue(self, value: qt.QVector3D):
+ """Set the QVector3D value"""
self._xEdit.setValue(value.x())
self._yEdit.setValue(value.y())
self._zEdit.setValue(value.z())
- self.valueChanged.emit(value)
- value = qt.Property(qt.QVector3D,
- fget=getValue,
- fset=setValue,
- user=True,
- notify=valueChanged)
+ value = qt.Property(
+ qt.QVector3D,
+ fget=getValue,
+ fset=setValue,
+ user=True,
+ )
"""Qt user property of the QVector3D value this widget edits"""
@@ -141,65 +127,57 @@ class Vector4DEditor(qt.QWidget):
:param flags: The widgets's flags
"""
- valueChanged = qt.Signal(qt.QVector4D)
- """Signal emitted when the QVector4D value has changed"""
-
- def __init__(self, parent=None, flags=qt.Qt.Widget):
+ def __init__(
+ self,
+ parent: qt.QWidget | None = None,
+ flags: qt.Qt.WindowType = qt.Qt.Widget,
+ ):
super(Vector4DEditor, self).__init__(parent, flags)
layout = qt.QHBoxLayout(self)
# layout.setSpacing(0)
layout.setContentsMargins(0, 0, 0, 0)
- self.setLayout(layout)
- self._xEdit = _FloatEdit(parent=self, value=0.)
+
+ self._xEdit = _FloatEdit(parent=self, value=0.0)
self._xEdit.setAlignment(qt.Qt.AlignLeft)
- # self._xEdit.editingFinished.connect(self._emit)
- self._yEdit = _FloatEdit(parent=self, value=0.)
+ self._yEdit = _FloatEdit(parent=self, value=0.0)
self._yEdit.setAlignment(qt.Qt.AlignLeft)
- # self._yEdit.editingFinished.connect(self._emit)
- self._zEdit = _FloatEdit(parent=self, value=0.)
+ self._zEdit = _FloatEdit(parent=self, value=0.0)
self._zEdit.setAlignment(qt.Qt.AlignLeft)
- # self._zEdit.editingFinished.connect(self._emit)
- self._wEdit = _FloatEdit(parent=self, value=0.)
+ self._wEdit = _FloatEdit(parent=self, value=0.0)
self._wEdit.setAlignment(qt.Qt.AlignLeft)
- # self._wEdit.editingFinished.connect(self._emit)
- layout.addWidget(qt.QLabel('x:'))
+
+ layout.addWidget(qt.QLabel("x:"))
layout.addWidget(self._xEdit)
- layout.addWidget(qt.QLabel('y:'))
+ layout.addWidget(qt.QLabel("y:"))
layout.addWidget(self._yEdit)
- layout.addWidget(qt.QLabel('z:'))
+ layout.addWidget(qt.QLabel("z:"))
layout.addWidget(self._zEdit)
- layout.addWidget(qt.QLabel('w:'))
+ layout.addWidget(qt.QLabel("w:"))
layout.addWidget(self._wEdit)
layout.addStretch(1)
- def _emit(self):
- vector = self.value
- self.valueChanged.emit(vector)
-
- def getValue(self):
- """Returns the QVector4D value of this widget
-
- :rtype: QVector4D
- """
- return qt.QVector4D(self._xEdit.value(), self._yEdit.value(),
- self._zEdit.value(), self._wEdit.value())
-
- def setValue(self, value):
- """Set the QVector4D value
-
- :param QVector4D value: The new value
- """
+ def getValue(self) -> qt.QVector4D:
+ """Returns the QVector4D value of this widget"""
+ return qt.QVector4D(
+ self._xEdit.value(),
+ self._yEdit.value(),
+ self._zEdit.value(),
+ self._wEdit.value(),
+ )
+
+ def setValue(self, value: qt.QVector4D):
+ """Set the QVector4D value"""
self._xEdit.setValue(value.x())
self._yEdit.setValue(value.y())
self._zEdit.setValue(value.z())
self._wEdit.setValue(value.w())
- self.valueChanged.emit(value)
- value = qt.Property(qt.QVector4D,
- fget=getValue,
- fset=setValue,
- user=True,
- notify=valueChanged)
+ value = qt.Property(
+ qt.QVector4D,
+ fget=getValue,
+ fset=setValue,
+ user=True,
+ )
"""Qt user property of the QVector4D value this widget edits"""
@@ -211,7 +189,7 @@ class IntSliderEditor(qt.QSlider):
:param parent: The widget's parent
"""
- def __init__(self, parent=None):
+ def __init__(self, parent: qt.QWidget | None = None):
super(IntSliderEditor, self).__init__(parent)
self.setOrientation(qt.Qt.Horizontal)
self.setSingleStep(1)
@@ -222,14 +200,39 @@ class IntSliderEditor(qt.QSlider):
class BooleanEditor(qt.QCheckBox):
"""Checkbox editor for bool.
- This is a QCheckBox with white background.
+ Wrap a QCheckBox to define a different user property with `clicked` signal.
:param parent: The widget's parent
"""
- def __init__(self, parent=None):
+ valueChanged = qt.Signal(bool)
+ """Signal emitted when value is changed by the user"""
+
+ def __init__(self, parent: qt.QWidget | None = None):
super(BooleanEditor, self).__init__(parent)
- self.setStyleSheet("background: white;")
+ self.setBackgroundRole(qt.QPalette.Base)
+ self.setAutoFillBackground(True)
+
+ layout = qt.QHBoxLayout(self)
+ layout.setContentsMargins(0, 0, 0, 0)
+ self.__checkbox = qt.QCheckBox(self)
+ self.__checkbox.clicked.connect(self.valueChanged)
+ layout.addWidget(self.__checkbox)
+
+ def getValue(self) -> bool:
+ return self.__checkbox.isChecked()
+
+ def setValue(self, value: bool):
+ self.__checkbox.setChecked(value)
+
+ value = qt.Property(
+ bool,
+ fget=getValue,
+ fset=setValue,
+ user=True,
+ notify=valueChanged,
+ )
+ """Qt user property of the bool value this widget edits"""
class ParameterTreeDelegate(qt.QStyledItemDelegate):
@@ -248,77 +251,60 @@ class ParameterTreeDelegate(qt.QStyledItemDelegate):
}
"""Specific editors for different type of data"""
- def __init__(self, parent=None):
+ def __init__(self, parent: qt.QWidget | None = None):
super(ParameterTreeDelegate, self).__init__(parent)
- def paint(self, painter, option, index):
+ def paint(
+ self,
+ painter: qt.QPainter,
+ option: qt.QStyleOptionViewItem,
+ index: qt.QModelIndex,
+ ):
"""See :meth:`QStyledItemDelegate.paint`"""
data = index.data(qt.Qt.DisplayRole)
- if isinstance(data, (qt.QVector3D, qt.QVector4D)):
- if isinstance(data, qt.QVector3D):
- text = '(x: %g; y: %g; z: %g)' % (data.x(), data.y(), data.z())
- elif isinstance(data, qt.QVector4D):
- text = '(%g; %g; %g; %g)' % (data.x(), data.y(), data.z(), data.w())
- else:
- text = ''
-
- painter.save()
- painter.setRenderHint(qt.QPainter.Antialiasing, True)
-
- # Select palette color group
- colorGroup = qt.QPalette.Inactive
- if option.state & qt.QStyle.State_Active:
- colorGroup = qt.QPalette.Active
- if not option.state & qt.QStyle.State_Enabled:
- colorGroup = qt.QPalette.Disabled
-
- # Draw background if selected
- if option.state & qt.QStyle.State_Selected:
- brush = option.palette.brush(colorGroup,
- qt.QPalette.Highlight)
- painter.fillRect(option.rect, brush)
-
- # Draw text
- if option.state & qt.QStyle.State_Selected:
- colorRole = qt.QPalette.HighlightedText
- else:
- colorRole = qt.QPalette.WindowText
- color = option.palette.color(colorGroup, colorRole)
- painter.setPen(qt.QPen(color))
- painter.drawText(option.rect, qt.Qt.AlignLeft, text)
-
- painter.restore()
-
- # The following commented code does the same as QPainter based code
- # but it does not work with PySide
- # self.initStyleOption(option, index)
- # option.text = text
- # widget = option.widget
- # style = qt.QApplication.style() if not widget else widget.style()
- # style.drawControl(qt.QStyle.CE_ItemViewItem, option, painter, widget)
+ if not isinstance(data, (qt.QVector3D, qt.QVector4D)):
+ super(ParameterTreeDelegate, self).paint(painter, option, index)
+ return
+ if isinstance(data, qt.QVector3D):
+ text = "(x: %g; y: %g; z: %g)" % (data.x(), data.y(), data.z())
+ elif isinstance(data, qt.QVector4D):
+ text = "(%g; %g; %g; %g)" % (data.x(), data.y(), data.z(), data.w())
else:
- super(ParameterTreeDelegate, self).paint(painter, option, index)
+ text = ""
+
+ self.initStyleOption(option, index)
+ option.text = text
+ widget = option.widget
+ style = qt.QApplication.style() if not widget else widget.style()
+ style.drawControl(qt.QStyle.CE_ItemViewItem, option, painter, widget)
def _commit(self, *args):
"""Commit data to the model from editors"""
sender = self.sender()
self.commitData.emit(sender)
- def editorEvent(self, event, model, option, index):
+ def editorEvent(
+ self,
+ event: qt.QEvent,
+ model: qt.QAbstractItemModel,
+ option: qt.QStyleOptionViewItem,
+ index: qt.QModelIndex,
+ ):
"""See :meth:`QStyledItemDelegate.editorEvent`"""
- if (event.type() == qt.QEvent.MouseButtonPress and
- isinstance(index.data(qt.Qt.EditRole), qt.QColor)):
+ if event.type() == qt.QEvent.MouseButtonPress and isinstance(
+ index.data(qt.Qt.EditRole), qt.QColor
+ ):
initialColor = index.data(qt.Qt.EditRole)
- def callback(color):
+ def callback(color: qt.QColor):
theModel = index.model()
theModel.setData(index, color, qt.Qt.EditRole)
dialog = qt.QColorDialog(self.parent())
# dialog.setOption(qt.QColorDialog.ShowAlphaChannel, True)
- if sys.platform == 'darwin':
+ if sys.platform == "darwin":
# Use of native color dialog on macos might cause problems
dialog.setOption(qt.QColorDialog.DontUseNativeDialog, True)
dialog.setCurrentColor(initialColor)
@@ -330,9 +316,15 @@ class ParameterTreeDelegate(qt.QStyledItemDelegate):
return True
else:
return super(ParameterTreeDelegate, self).editorEvent(
- event, model, option, index)
-
- def createEditor(self, parent, option, index):
+ event, model, option, index
+ )
+
+ def createEditor(
+ self,
+ parent: qt.QWidget,
+ option: qt.QStyleOptionViewItem,
+ index: qt.QModelIndex,
+ ):
"""See :meth:`QStyledItemDelegate.createEditor`"""
data = index.data(qt.Qt.EditRole)
editorHint = index.data(qt.Qt.UserRole)
@@ -372,14 +364,8 @@ class ParameterTreeDelegate(qt.QStyledItemDelegate):
if userProperty.isValid() and userProperty.hasNotifySignal():
notifySignal = userProperty.notifySignal()
signature = notifySignal.methodSignature()
- if qt.BINDING == 'PySide2':
- signature = signature.data()
- else:
- signature = bytes(signature)
-
- if hasattr(signature, 'decode'): # For PySide with python3
- signature = signature.decode('ascii')
- signalName = signature.split('(')[0]
+ signature = bytes(signature).decode("ascii")
+ signalName = signature.split("(")[0]
signal = getattr(editor, signalName)
signal.connect(self._commit)
@@ -387,12 +373,18 @@ class ParameterTreeDelegate(qt.QStyledItemDelegate):
else: # Default handling for default types
return super(ParameterTreeDelegate, self).createEditor(
- parent, option, index)
+ parent, option, index
+ )
editor.setAutoFillBackground(True)
return editor
- def setModelData(self, editor, model, index):
+ def setModelData(
+ self,
+ editor: qt.QWidget,
+ model: qt.QAbstractItemModel,
+ index: qt.QModelIndex,
+ ):
"""See :meth:`QStyledItemDelegate.setModelData`"""
if isinstance(editor, tuple(self.EDITORS.values())):
# Special handling of Python classes
@@ -420,7 +412,7 @@ class ParamTreeView(qt.QTreeView):
:param parent: The widget's parent.
"""
- def __init__(self, parent=None):
+ def __init__(self, parent: qt.QWidget | None = None):
super(ParamTreeView, self).__init__(parent)
header = self.header()
@@ -435,65 +427,67 @@ class ParamTreeView(qt.QTreeView):
self.expanded.connect(self._expanded)
- self.setEditTriggers(qt.QAbstractItemView.CurrentChanged |
- qt.QAbstractItemView.DoubleClicked)
+ self.setEditTriggers(
+ qt.QAbstractItemView.CurrentChanged | qt.QAbstractItemView.DoubleClicked
+ )
self.__persistentEditors = set()
- def _openEditorForIndex(self, index):
+ def _openEditorForIndex(self, index: qt.QModelIndex):
"""Check if it has to open a persistent editor for a specific cell.
- :param QModelIndex index: The cell index
+ :param index: The cell index
"""
if index.flags() & qt.Qt.ItemIsEditable:
data = index.data(qt.Qt.EditRole)
editorHint = index.data(qt.Qt.UserRole)
- if (isinstance(data, bool) or
- callable(editorHint) or
- (isinstance(data, numbers.Number) and editorHint)):
+ if (
+ isinstance(data, bool)
+ or callable(editorHint)
+ or (isinstance(data, numbers.Number) and editorHint)
+ ):
self.openPersistentEditor(index)
self.__persistentEditors.add(index)
- def _openEditors(self, parent=qt.QModelIndex()):
+ def _openEditors(self, parent: qt.QModelIndex = qt.QModelIndex()):
"""Open persistent editors in a subtree starting at parent.
- :param QModelIndex parent: The root of the subtree to process.
+ :param parent: The root of the subtree to process.
"""
model = self.model()
if model is not None:
for index in visitQAbstractItemModel(model, parent):
self._openEditorForIndex(index)
- def setModel(self, model):
- """Set the model this TreeView is displaying
-
- :param QAbstractItemModel model:
- """
+ def setModel(self, model: qt.QAbstractItemModel):
+ """Set the model this TreeView is displaying"""
super(ParamTreeView, self).setModel(model)
self._openEditors()
- def rowsInserted(self, parent, start, end):
+ def rowsInserted(self, parent: qt.QModelIndex, start: int, end: int):
"""See :meth:`QTreeView.rowsInserted`"""
super(ParamTreeView, self).rowsInserted(parent, start, end)
model = self.model()
if model is not None:
- for row in range(start, end+1):
+ for row in range(start, end + 1):
self._openEditorForIndex(model.index(row, 1, parent))
self._openEditors(model.index(row, 0, parent))
- def _expanded(self, index):
+ def _expanded(self, index: qt.QModelIndex):
"""Handle QTreeView expanded signal"""
name = index.data(qt.Qt.DisplayRole)
- if name == 'Transform':
+ if name == "Transform":
rotateIndex = self.model().index(1, 0, index)
self.setExpanded(rotateIndex, True)
- def dataChanged(self, topLeft, bottomRight, roles=()):
+ def dataChanged(
+ self,
+ topLeft: qt.QModelIndex,
+ bottomRight: qt.QModelIndex,
+ roles: Sequence[int] = (),
+ ):
"""Handle model dataChanged signal eventually closing editors"""
- if roles: # Qt 5
- super(ParamTreeView, self).dataChanged(topLeft, bottomRight, roles)
- else: # Qt4 compatibility
- super(ParamTreeView, self).dataChanged(topLeft, bottomRight)
+ super(ParamTreeView, self).dataChanged(topLeft, bottomRight, roles)
if not roles or qt.Qt.UserRole in roles: # Check editorHint update
for row in range(topLeft.row(), bottomRight.row() + 1):
for column in range(topLeft.column(), bottomRight.column() + 1):
@@ -503,15 +497,15 @@ class ParamTreeView(qt.QTreeView):
self.closePersistentEditor(index)
self._openEditorForIndex(index)
- def _isPersistentEditorOpen(self, index):
- """Returns True if a persistent editor is opened for index
-
- :param QModelIndex index:
- :rtype: bool
- """
+ def _isPersistentEditorOpen(self, index: qt.QModelIndex) -> bool:
+ """Returns True if a persistent editor is opened for index"""
return index in self.__persistentEditors
- def selectionCommand(self, index, event=None):
+ def selectionCommand(
+ self,
+ index: qt.QModelIndex,
+ event: qt.QEvent | None = None,
+ ) -> qt.QItemSelectionModel.SelectionFlag:
"""Filter out selection of not selectable items"""
if index.flags() & qt.Qt.ItemIsSelectable:
return super(ParamTreeView, self).selectionCommand(index, event)
diff --git a/src/silx/gui/plot3d/Plot3DWidget.py b/src/silx/gui/plot3d/Plot3DWidget.py
index 09e06a2..9a88fe3 100644
--- a/src/silx/gui/plot3d/Plot3DWidget.py
+++ b/src/silx/gui/plot3d/Plot3DWidget.py
@@ -66,10 +66,9 @@ class _OverviewViewport(scene.Viewport):
# Add a point to draw the background (in a group with depth mask)
backgroundPoint = primitives.ColorPoints(
- x=0., y=0., z=0.,
- color=(1., 1., 1., 0.5),
- size=self._SIZE)
- backgroundPoint.marker = 'o'
+ x=0.0, y=0.0, z=0.0, color=(1.0, 1.0, 1.0, 0.5), size=self._SIZE
+ )
+ backgroundPoint.marker = "o"
noDepthGroup = primitives.GroupNoDepth(mask=True, notest=True)
noDepthGroup.children.append(backgroundPoint)
self.scene.children.append(noDepthGroup)
@@ -86,11 +85,12 @@ class _OverviewViewport(scene.Viewport):
Sync the overview camera to point in the same direction
but from a sphere centered on origin.
"""
- position = -12. * source.extrinsic.direction
+ position = -12.0 * source.extrinsic.direction
self.camera.extrinsic.position = position
self.camera.extrinsic.setOrientation(
- source.extrinsic.direction, source.extrinsic.up)
+ source.extrinsic.direction, source.extrinsic.up
+ )
class Plot3DWidget(glu.OpenGLWidget):
@@ -116,10 +116,10 @@ class Plot3DWidget(glu.OpenGLWidget):
class FogMode(_Enum):
"""Different mode to render the scene with fog"""
- NONE = 'none'
+ NONE = "none"
"""No fog effect"""
- LINEAR = 'linear'
+ LINEAR = "linear"
"""Linear fog through the whole scene"""
def __init__(self, parent=None, f=qt.Qt.Widget):
@@ -131,7 +131,8 @@ class Plot3DWidget(glu.OpenGLWidget):
depthBufferSize=0,
stencilBufferSize=0,
version=(2, 1),
- f=f)
+ f=f,
+ )
self.setAutoFillBackground(False)
self.setMouseTracking(True)
@@ -145,22 +146,24 @@ class Plot3DWidget(glu.OpenGLWidget):
# Main viewport
self.viewport = scene.Viewport()
- self._sceneScale = transform.Scale(1., 1., 1.)
- self.viewport.scene.transforms = [self._sceneScale,
- transform.Translate(0., 0., 0.)]
+ self._sceneScale = transform.Scale(1.0, 1.0, 1.0)
+ self.viewport.scene.transforms = [
+ self._sceneScale,
+ transform.Translate(0.0, 0.0, 0.0),
+ ]
# Overview area
self.overview = _OverviewViewport(self.viewport.camera)
- self.setBackgroundColor((0.2, 0.2, 0.2, 1.))
+ self.setBackgroundColor((0.2, 0.2, 0.2, 1.0))
# Window describing on screen area to render
- self._window = scene.Window(mode='framebuffer')
+ self._window = scene.Window(mode="framebuffer")
self._window.viewports = [self.viewport, self.overview]
self._window.addListener(self._redraw)
self.eventHandler = None
- self.setInteractiveMode('rotate')
+ self.setInteractiveMode("rotate")
def __clickHandler(self, *args):
"""Handle interaction state machine click"""
@@ -180,31 +183,35 @@ class Plot3DWidget(glu.OpenGLWidget):
if mode is None:
self.eventHandler = None
- elif mode == 'rotate':
+ elif mode == "rotate":
self.eventHandler = interaction.RotateCameraControl(
self.viewport,
orbitAroundCenter=False,
- mode='position',
+ mode="position",
scaleTransform=self._sceneScale,
- selectCB=self.__clickHandler)
+ selectCB=self.__clickHandler,
+ )
- elif mode == 'pan':
+ elif mode == "pan":
self.eventHandler = interaction.PanCameraControl(
self.viewport,
orbitAroundCenter=False,
- mode='position',
+ mode="position",
scaleTransform=self._sceneScale,
- selectCB=self.__clickHandler)
+ selectCB=self.__clickHandler,
+ )
elif isinstance(mode, interaction.StateMachine):
self.eventHandler = mode
else:
- raise ValueError('Unsupported interactive mode %s', str(mode))
+ raise ValueError("Unsupported interactive mode %s", str(mode))
- if (self.eventHandler is not None and
- qt.QApplication.keyboardModifiers() & qt.Qt.ControlModifier):
- self.eventHandler.handleEvent('keyPress', qt.Qt.Key_Control)
+ if (
+ self.eventHandler is not None
+ and qt.QApplication.keyboardModifiers() & qt.Qt.ControlModifier
+ ):
+ self.eventHandler.handleEvent("keyPress", qt.Qt.Key_Control)
self.sigInteractiveModeChanged.emit()
@@ -216,9 +223,9 @@ class Plot3DWidget(glu.OpenGLWidget):
if self.eventHandler is None:
return None
if isinstance(self.eventHandler, interaction.RotateCameraControl):
- return 'rotate'
+ return "rotate"
elif isinstance(self.eventHandler, interaction.PanCameraControl):
- return 'pan'
+ return "pan"
else:
return None
@@ -227,13 +234,12 @@ class Plot3DWidget(glu.OpenGLWidget):
:param str projection: In 'perspective', 'orthographic'.
"""
- if projection == 'orthographic':
+ if projection == "orthographic":
projection = transform.Orthographic(size=self.viewport.size)
- elif projection == 'perspective':
- projection = transform.Perspective(fovy=30.,
- size=self.viewport.size)
+ elif projection == "perspective":
+ projection = transform.Perspective(fovy=30.0, size=self.viewport.size)
else:
- raise RuntimeError('Unsupported projection: %s' % projection)
+ raise RuntimeError("Unsupported projection: %s" % projection)
self.viewport.camera.intrinsic = projection
self.viewport.resetCamera()
@@ -245,11 +251,11 @@ class Plot3DWidget(glu.OpenGLWidget):
"""
projection = self.viewport.camera.intrinsic
if isinstance(projection, transform.Orthographic):
- return 'orthographic'
+ return "orthographic"
elif isinstance(projection, transform.Perspective):
- return 'perspective'
+ return "perspective"
else:
- raise RuntimeError('Unknown projection in use')
+ raise RuntimeError("Unknown projection in use")
def setBackgroundColor(self, color):
"""Set the background color of the OpenGL view.
@@ -261,7 +267,7 @@ class Plot3DWidget(glu.OpenGLWidget):
color = rgba(color)
if color != self.viewport.background:
self.viewport.background = color
- self.sigStyleChanged.emit('backgroundColor')
+ self.sigStyleChanged.emit("backgroundColor")
def getBackgroundColor(self):
"""Returns the RGBA background color (QColor)."""
@@ -276,7 +282,7 @@ class Plot3DWidget(glu.OpenGLWidget):
mode = self.FogMode.from_value(mode)
if mode != self.getFogMode():
self.viewport.fog.isOn = mode is self.FogMode.LINEAR
- self.sigStyleChanged.emit('fogMode')
+ self.sigStyleChanged.emit("fogMode")
def getFogMode(self):
"""Returns the kind of fog in use
@@ -307,13 +313,13 @@ class Plot3DWidget(glu.OpenGLWidget):
self._window.viewports = [self.viewport, self.overview]
else:
self._window.viewports = [self.viewport]
- self.sigStyleChanged.emit('orientationIndicatorVisible')
+ self.sigStyleChanged.emit("orientationIndicatorVisible")
def centerScene(self):
"""Position the center of the scene at the center of rotation."""
self.viewport.resetCamera()
- def resetZoom(self, face='front'):
+ def resetZoom(self, face="front"):
"""Reset the camera position to a default.
:param str face: The direction the camera is looking at:
@@ -344,7 +350,9 @@ class Plot3DWidget(glu.OpenGLWidget):
if self.viewport.dirty:
self.viewport.adjustCameraDepthExtent()
- self._window.render(self.context(), self.getDevicePixelRatio())
+ self._window.render(
+ self.context(), self.getDotsPerInch(), self.getDevicePixelRatio()
+ )
if self._firstRender: # TODO remove this ugly hack
self._firstRender = False
@@ -366,7 +374,7 @@ class Plot3DWidget(glu.OpenGLWidget):
:rtype: QImage
"""
if not self.isValid():
- _logger.error('OpenGL 2.1 not available, cannot save OpenGL image')
+ _logger.error("OpenGL 2.1 not available, cannot save OpenGL image")
height, width = self._window.shape
image = numpy.zeros((height, width, 3), dtype=numpy.uint8)
@@ -380,22 +388,22 @@ class Plot3DWidget(glu.OpenGLWidget):
x, y = qt.getMouseEventPosition(event)
xpixel = x * self.getDevicePixelRatio()
ypixel = y * self.getDevicePixelRatio()
- angle = event.angleDelta().y() / 8.
+ angle = event.angleDelta().y() / 8.0
event.accept()
if self.eventHandler is not None and angle != 0 and self.isValid():
self.makeCurrent()
- self.eventHandler.handleEvent('wheel', xpixel, ypixel, angle)
+ self.eventHandler.handleEvent("wheel", xpixel, ypixel, angle)
def keyPressEvent(self, event):
keyCode = event.key()
# No need to accept QKeyEvent
converter = {
- qt.Qt.Key_Left: 'left',
- qt.Qt.Key_Right: 'right',
- qt.Qt.Key_Up: 'up',
- qt.Qt.Key_Down: 'down'
+ qt.Qt.Key_Left: "left",
+ qt.Qt.Key_Right: "right",
+ qt.Qt.Key_Up: "up",
+ qt.Qt.Key_Down: "down",
}
direction = converter.get(keyCode, None)
if direction is not None:
@@ -407,10 +415,12 @@ class Plot3DWidget(glu.OpenGLWidget):
self.viewport.orbitCamera(direction)
else:
- if (keyCode == qt.Qt.Key_Control and
- self.eventHandler is not None and
- self.isValid()):
- self.eventHandler.handleEvent('keyPress', keyCode)
+ if (
+ keyCode == qt.Qt.Key_Control
+ and self.eventHandler is not None
+ and self.isValid()
+ ):
+ self.eventHandler.handleEvent("keyPress", keyCode)
# Key not handled, call base class implementation
super(Plot3DWidget, self).keyPressEvent(event)
@@ -418,17 +428,19 @@ class Plot3DWidget(glu.OpenGLWidget):
def keyReleaseEvent(self, event):
"""Catch Ctrl key release"""
keyCode = event.key()
- if (keyCode == qt.Qt.Key_Control and
- self.eventHandler is not None and
- self.isValid()):
- self.eventHandler.handleEvent('keyRelease', keyCode)
+ if (
+ keyCode == qt.Qt.Key_Control
+ and self.eventHandler is not None
+ and self.isValid()
+ ):
+ self.eventHandler.handleEvent("keyRelease", keyCode)
super(Plot3DWidget, self).keyReleaseEvent(event)
# Mouse events #
_MOUSE_BTNS = {
- qt.Qt.LeftButton: 'left',
- qt.Qt.RightButton: 'right',
- qt.Qt.MiddleButton: 'middle',
+ qt.Qt.LeftButton: "left",
+ qt.Qt.RightButton: "right",
+ qt.Qt.MiddleButton: "middle",
}
def mousePressEvent(self, event):
@@ -440,7 +452,7 @@ class Plot3DWidget(glu.OpenGLWidget):
if self.eventHandler is not None and self.isValid():
self.makeCurrent()
- self.eventHandler.handleEvent('press', xpixel, ypixel, btn)
+ self.eventHandler.handleEvent("press", xpixel, ypixel, btn)
def mouseMoveEvent(self, event):
x, y = qt.getMouseEventPosition(event)
@@ -450,7 +462,7 @@ class Plot3DWidget(glu.OpenGLWidget):
if self.eventHandler is not None and self.isValid():
self.makeCurrent()
- self.eventHandler.handleEvent('move', xpixel, ypixel)
+ self.eventHandler.handleEvent("move", xpixel, ypixel)
def mouseReleaseEvent(self, event):
x, y = qt.getMouseEventPosition(event)
@@ -461,4 +473,4 @@ class Plot3DWidget(glu.OpenGLWidget):
if self.eventHandler is not None and self.isValid():
self.makeCurrent()
- self.eventHandler.handleEvent('release', xpixel, ypixel, btn)
+ self.eventHandler.handleEvent("release", xpixel, ypixel, btn)
diff --git a/src/silx/gui/plot3d/SFViewParamTree.py b/src/silx/gui/plot3d/SFViewParamTree.py
index cc78cec..6eea5ae 100644
--- a/src/silx/gui/plot3d/SFViewParamTree.py
+++ b/src/silx/gui/plot3d/SFViewParamTree.py
@@ -48,7 +48,7 @@ _logger = logging.getLogger(__name__)
class ModelColumns(object):
NameColumn, ValueColumn, ColumnMax = range(3)
- ColumnNames = ['Name', 'Value']
+ ColumnNames = ["Name", "Value"]
class SubjectItem(qt.QStandardItem):
@@ -86,7 +86,6 @@ class SubjectItem(qt.QStandardItem):
"""
def __init__(self, subject, *args):
-
super(SubjectItem, self).__init__(*args)
self.setEditable(self.editable)
@@ -119,8 +118,7 @@ class SubjectItem(qt.QStandardItem):
@subject.setter
def subject(self, subject):
if self.__subject is not None:
- raise ValueError('Subject already set '
- ' (subject change not supported).')
+ raise ValueError("Subject already set " " (subject change not supported).")
if subject is None:
self.__subject = None
else:
@@ -136,9 +134,8 @@ class SubjectItem(qt.QStandardItem):
def gen_slot(_sigIdx):
def slotfn(*args, **kwargs):
- self._subjectChanged(signalIdx=_sigIdx,
- args=args,
- kwargs=kwargs)
+ self._subjectChanged(signalIdx=_sigIdx, args=args, kwargs=kwargs)
+
return slotfn
if self.__subject is not None:
@@ -293,8 +290,10 @@ class SubjectItem(qt.QStandardItem):
# View settings ###############################################################
+
class ColorItem(SubjectItem):
"""color item."""
+
editable = True
persistent = True
@@ -303,8 +302,7 @@ class ColorItem(SubjectItem):
editor.color = self.getColor()
# Wrapping call in lambda is a workaround for PySide with Python 3
- editor.sigColorChanged.connect(
- lambda color: self._editorSlot(color))
+ editor.sigColorChanged.connect(lambda color: self._editorSlot(color))
return editor
def _editorSlot(self, color):
@@ -323,7 +321,7 @@ class ColorItem(SubjectItem):
class BackgroundColorItem(ColorItem):
- itemName = 'Background'
+ itemName = "Background"
def setColor(self, color):
self.subject.setBackgroundColor(color)
@@ -333,7 +331,7 @@ class BackgroundColorItem(ColorItem):
class ForegroundColorItem(ColorItem):
- itemName = 'Foreground'
+ itemName = "Foreground"
def setColor(self, color):
self.subject.setForegroundColor(color)
@@ -343,7 +341,7 @@ class ForegroundColorItem(ColorItem):
class HighlightColorItem(ColorItem):
- itemName = 'Highlight'
+ itemName = "Highlight"
def setColor(self, color):
self.subject.setHighlightColor(color)
@@ -354,6 +352,7 @@ class HighlightColorItem(ColorItem):
class _LightDirectionAngleBaseItem(SubjectItem):
"""Base class for directional light angle item."""
+
editable = True
persistent = True
@@ -380,8 +379,7 @@ class _LightDirectionAngleBaseItem(SubjectItem):
editor.setValue(int(self._pullData()))
# Wrapping call in lambda is a workaround for PySide with Python 3
- editor.valueChanged.connect(
- lambda value: self._pushData(value))
+ editor.valueChanged.connect(lambda value: self._pushData(value))
return editor
@@ -402,10 +400,10 @@ class LightAzimuthAngleItem(_LightDirectionAngleBaseItem):
return self.subject.sigAzimuthAngleChanged
def _pullData(self):
- return self.subject.getAzimuthAngle()
+ return self.subject.getAzimuthAngle()
def _pushData(self, value, role=qt.Qt.UserRole):
- self.subject.setAzimuthAngle(value)
+ self.subject.setAzimuthAngle(value)
class LightAltitudeAngleItem(_LightDirectionAngleBaseItem):
@@ -415,15 +413,14 @@ class LightAltitudeAngleItem(_LightDirectionAngleBaseItem):
return self.subject.sigAltitudeAngleChanged
def _pullData(self):
- return self.subject.getAltitudeAngle()
+ return self.subject.getAltitudeAngle()
def _pushData(self, value, role=qt.Qt.UserRole):
- self.subject.setAltitudeAngle(value)
+ self.subject.setAltitudeAngle(value)
class _DirectionalLightProxy(qt.QObject):
- """Proxy to handle directional light with angles rather than vector.
- """
+ """Proxy to handle directional light with angles rather than vector."""
sigAzimuthAngleChanged = qt.Signal()
"""Signal sent when the azimuth angle has changed."""
@@ -435,8 +432,8 @@ class _DirectionalLightProxy(qt.QObject):
super(_DirectionalLightProxy, self).__init__()
self._light = light
light.addListener(self._directionUpdated)
- self._azimuth = 0.
- self._altitude = 0.
+ self._azimuth = 0.0
+ self._altitude = 0.0
def getAzimuthAngle(self):
"""Returns the signed angle in the horizontal plane.
@@ -482,14 +479,16 @@ class _DirectionalLightProxy(qt.QObject):
"""Handle light direction update in the scene"""
# Invert direction to manipulate the 'source' pointing to
# the center of the viewport
- x, y, z = - self._light.direction
+ x, y, z = -self._light.direction
# Horizontal plane is plane xz
azimuth = numpy.degrees(numpy.arctan2(x, z))
- altitude = numpy.degrees(numpy.pi/2. - numpy.arccos(y))
+ altitude = numpy.degrees(numpy.pi / 2.0 - numpy.arccos(y))
- if (abs(azimuth - self.getAzimuthAngle()) > 0.01 and
- abs(abs(altitude) - 90.) >= 0.001): # Do not update when at zenith
+ if (
+ abs(azimuth - self.getAzimuthAngle()) > 0.01
+ and abs(abs(altitude) - 90.0) >= 0.001
+ ): # Do not update when at zenith
self.setAzimuthAngle(azimuth)
if abs(altitude - self.getAltitudeAngle()) > 0.01:
@@ -498,10 +497,10 @@ class _DirectionalLightProxy(qt.QObject):
def _updateLight(self):
"""Update light direction in the scene"""
azimuth = numpy.radians(self._azimuth)
- delta = numpy.pi/2. - numpy.radians(self._altitude)
- z = - numpy.sin(delta) * numpy.cos(azimuth)
- x = - numpy.sin(delta) * numpy.sin(azimuth)
- y = - numpy.cos(delta)
+ delta = numpy.pi / 2.0 - numpy.radians(self._altitude)
+ z = -numpy.sin(delta) * numpy.cos(azimuth)
+ x = -numpy.sin(delta) * numpy.sin(azimuth)
+ y = -numpy.cos(delta)
self._light.direction = x, y, z
@@ -510,20 +509,18 @@ class DirectionalLightGroup(SubjectItem):
Root Item for the directional light
"""
- def __init__(self,subject, *args):
- self._light = _DirectionalLightProxy(
- subject.getPlot3DWidget().viewport.light)
+ def __init__(self, subject, *args):
+ self._light = _DirectionalLightProxy(subject.getPlot3DWidget().viewport.light)
super(DirectionalLightGroup, self).__init__(subject, *args)
def _init(self):
-
- nameItem = qt.QStandardItem('Azimuth')
+ nameItem = qt.QStandardItem("Azimuth")
nameItem.setEditable(False)
valueItem = LightAzimuthAngleItem(self._light)
self.appendRow([nameItem, valueItem])
- nameItem = qt.QStandardItem('Altitude')
+ nameItem = qt.QStandardItem("Altitude")
nameItem.setEditable(False)
valueItem = LightAltitudeAngleItem(self._light)
self.appendRow([nameItem, valueItem])
@@ -534,7 +531,8 @@ class BoundingBoxItem(SubjectItem):
Item is checkable.
"""
- itemName = 'Bounding Box'
+
+ itemName = "Bounding Box"
def _init(self):
visible = self.subject.isBoundingBoxVisible()
@@ -542,7 +540,7 @@ class BoundingBoxItem(SubjectItem):
self.setCheckState(qt.Qt.Checked if visible else qt.Qt.Unchecked)
def leftClicked(self):
- checked = (self.checkState() == qt.Qt.Checked)
+ checked = self.checkState() == qt.Qt.Checked
if checked != self.subject.isBoundingBoxVisible():
self.subject.setBoundingBoxVisible(checked)
@@ -552,7 +550,8 @@ class OrientationIndicatorItem(SubjectItem):
Item is checkable.
"""
- itemName = 'Axes indicator'
+
+ itemName = "Axes indicator"
def _init(self):
plot3d = self.subject.getPlot3DWidget()
@@ -562,7 +561,7 @@ class OrientationIndicatorItem(SubjectItem):
def leftClicked(self):
plot3d = self.subject.getPlot3DWidget()
- checked = (self.checkState() == qt.Qt.Checked)
+ checked = self.checkState() == qt.Qt.Checked
if checked != plot3d.isOrientationIndicatorVisible():
plot3d.setOrientationIndicatorVisible(checked)
@@ -571,28 +570,30 @@ class ViewSettingsItem(qt.QStandardItem):
"""Viewport settings"""
def __init__(self, subject, *args):
-
super(ViewSettingsItem, self).__init__(*args)
self.setEditable(False)
- classes = (BackgroundColorItem,
- ForegroundColorItem,
- HighlightColorItem,
- BoundingBoxItem,
- OrientationIndicatorItem)
+ classes = (
+ BackgroundColorItem,
+ ForegroundColorItem,
+ HighlightColorItem,
+ BoundingBoxItem,
+ OrientationIndicatorItem,
+ )
for cls in classes:
titleItem = qt.QStandardItem(cls.itemName)
titleItem.setEditable(False)
self.appendRow([titleItem, cls(subject)])
- nameItem = DirectionalLightGroup(subject, 'Light Direction')
+ nameItem = DirectionalLightGroup(subject, "Light Direction")
valueItem = qt.QStandardItem()
self.appendRow([nameItem, valueItem])
# Data information ############################################################
+
class DataChangedItem(SubjectItem):
"""
Base class for items listening to ScalarFieldView.sigDataChanged
@@ -609,42 +610,41 @@ class DataChangedItem(SubjectItem):
class DataTypeItem(DataChangedItem):
- itemName = 'dtype'
+ itemName = "dtype"
def _pullData(self):
data = self.subject.getData(copy=False)
- return ((data is not None) and str(data.dtype)) or 'N/A'
+ return ((data is not None) and str(data.dtype)) or "N/A"
class DataShapeItem(DataChangedItem):
- itemName = 'size'
+ itemName = "size"
def _pullData(self):
data = self.subject.getData(copy=False)
if data is None:
- return 'N/A'
+ return "N/A"
else:
return str(list(reversed(data.shape)))
class OffsetItem(DataChangedItem):
- itemName = 'offset'
+ itemName = "offset"
def _pullData(self):
offset = self.subject.getTranslation()
- return ((offset is not None) and str(offset)) or 'N/A'
+ return ((offset is not None) and str(offset)) or "N/A"
class ScaleItem(DataChangedItem):
- itemName = 'scale'
+ itemName = "scale"
def _pullData(self):
scale = self.subject.getScale()
- return ((scale is not None) and str(scale)) or 'N/A'
+ return ((scale is not None) and str(scale)) or "N/A"
class MatrixItem(DataChangedItem):
-
def __init__(self, subject, row, *args):
self.__row = row
super(MatrixItem, self).__init__(subject, *args)
@@ -655,9 +655,7 @@ class MatrixItem(DataChangedItem):
class DataSetItem(qt.QStandardItem):
-
def __init__(self, subject, *args):
-
super(DataSetItem, self).__init__(*args)
self.setEditable(False)
@@ -668,7 +666,7 @@ class DataSetItem(qt.QStandardItem):
titleItem.setEditable(False)
self.appendRow([titleItem, klass(subject)])
- matrixItem = qt.QStandardItem('matrix')
+ matrixItem = qt.QStandardItem("matrix")
matrixItem.setEditable(False)
valueItem = qt.QStandardItem()
self.appendRow([matrixItem, valueItem])
@@ -686,6 +684,7 @@ class DataSetItem(qt.QStandardItem):
# Isosurface ##################################################################
+
class IsoSurfaceRootItem(SubjectItem):
"""
Root (i.e : column index 0) Isosurface item.
@@ -697,8 +696,7 @@ class IsoSurfaceRootItem(SubjectItem):
def getSignals(self):
subject = self.subject
- return [subject.sigColorChanged,
- subject.sigVisibilityChanged]
+ return [subject.sigColorChanged, subject.sigVisibilityChanged]
def _subjectChanged(self, signalIdx=None, args=None, kwargs=None):
if signalIdx == 0:
@@ -717,17 +715,18 @@ class IsoSurfaceRootItem(SubjectItem):
self.setData(color, qt.Qt.DecorationRole)
self.setCheckState((visible and qt.Qt.Checked) or qt.Qt.Unchecked)
- nameItem = qt.QStandardItem('Level')
- sliderItem = IsoSurfaceLevelSlider(self.subject,
- self._isoLevelSliderNormalization)
+ nameItem = qt.QStandardItem("Level")
+ sliderItem = IsoSurfaceLevelSlider(
+ self.subject, self._isoLevelSliderNormalization
+ )
self.appendRow([nameItem, sliderItem])
- nameItem = qt.QStandardItem('Color')
+ nameItem = qt.QStandardItem("Color")
nameItem.setEditable(False)
valueItem = IsoSurfaceColorItem(self.subject)
self.appendRow([nameItem, valueItem])
- nameItem = qt.QStandardItem('Opacity')
+ nameItem = qt.QStandardItem("Opacity")
nameItem.setTextAlignment(qt.Qt.AlignLeft | qt.Qt.AlignTop)
nameItem.setEditable(False)
valueItem = IsoSurfaceAlphaItem(self.subject)
@@ -741,10 +740,12 @@ class IsoSurfaceRootItem(SubjectItem):
def queryRemove(self, view=None):
buttons = qt.QMessageBox.Ok | qt.QMessageBox.Cancel
- ans = qt.QMessageBox.question(view,
- 'Remove isosurface',
- 'Remove the selected iso-surface?',
- buttons=buttons)
+ ans = qt.QMessageBox.question(
+ view,
+ "Remove isosurface",
+ "Remove the selected iso-surface?",
+ buttons=buttons,
+ )
if ans == qt.QMessageBox.Ok:
sfview = self.subject.parent()
if sfview:
@@ -753,7 +754,7 @@ class IsoSurfaceRootItem(SubjectItem):
return False
def leftClicked(self):
- checked = (self.checkState() == qt.Qt.Checked)
+ checked = self.checkState() == qt.Qt.Checked
visible = self.subject.isVisible()
if checked != visible:
self.subject.setVisible(checked)
@@ -763,12 +764,12 @@ class IsoSurfaceLevelItem(SubjectItem):
"""
Base class for the isosurface level items.
"""
+
editable = True
def getSignals(self):
subject = self.subject
- return [subject.sigLevelChanged,
- subject.sigVisibilityChanged]
+ return [subject.sigLevelChanged, subject.sigVisibilityChanged]
def getEditor(self, parent, option, index):
return FloatEdit(parent)
@@ -796,15 +797,14 @@ class _IsoLevelSlider(qt.QSlider):
super(_IsoLevelSlider, self).__init__(parent=parent)
self.subject = subject
- if normalization == 'arcsinh':
+ if normalization == "arcsinh":
self.__norm = numpy.arcsinh
self.__invNorm = numpy.sinh
- elif normalization == 'linear':
+ elif normalization == "linear":
self.__norm = lambda x: x
self.__invNorm = lambda x: x
else:
- raise ValueError(
- "Unsupported normalization %s", normalization)
+ raise ValueError("Unsupported normalization %s", normalization)
self.sliderReleased.connect(self.__sliderReleased)
@@ -845,6 +845,7 @@ class IsoSurfaceLevelSlider(IsoSurfaceLevelItem):
"""
Isosurface level item with a slider editor.
"""
+
nTicks = 1000
persistent = True
@@ -874,6 +875,7 @@ class IsoSurfaceColorItem(SubjectItem):
"""
Isosurface color item.
"""
+
editable = True
persistent = True
@@ -886,8 +888,7 @@ class IsoSurfaceColorItem(SubjectItem):
color.setAlpha(255)
editor.color = color
# Wrapping call in lambda is a workaround for PySide with Python 3
- editor.sigColorChanged.connect(
- lambda color: self.__editorChanged(color))
+ editor.sigColorChanged.connect(lambda color: self.__editorChanged(color))
return editor
def __editorChanged(self, color):
@@ -903,6 +904,7 @@ class QColorEditor(qt.QWidget):
"""
QColor editor.
"""
+
sigColorChanged = qt.Signal(object)
color = property(lambda self: qt.QColor(self.__color))
@@ -938,7 +940,7 @@ class QColorEditor(qt.QWidget):
def __showColorDialog(self):
dialog = qt.QColorDialog(parent=self)
- if sys.platform == 'darwin':
+ if sys.platform == "darwin":
# Use of native color dialog on macos might cause problems
dialog.setOption(qt.QColorDialog.DontUseNativeDialog, True)
@@ -964,6 +966,7 @@ class IsoSurfaceAlphaItem(SubjectItem):
"""
Isosurface alpha item.
"""
+
editable = True
persistent = True
@@ -983,8 +986,7 @@ class IsoSurfaceAlphaItem(SubjectItem):
editor.setValue(color.alpha())
# Wrapping call in lambda is a workaround for PySide with Python 3
- editor.valueChanged.connect(
- lambda value: self.__editorChanged(value))
+ editor.valueChanged.connect(lambda value: self.__editorChanged(value))
return editor
@@ -1010,9 +1012,9 @@ class IsoSurfaceAlphaLegendItem(SubjectItem):
layout = qt.QHBoxLayout()
layout.setContentsMargins(0, 0, 0, 0)
layout.setSpacing(0)
- layout.addWidget(qt.QLabel('0'))
+ layout.addWidget(qt.QLabel("0"))
layout.addStretch(1)
- layout.addWidget(qt.QLabel('1'))
+ layout.addWidget(qt.QLabel("1"))
editor = qt.QWidget(parent)
editor.setLayout(layout)
@@ -1033,7 +1035,6 @@ class IsoSurfaceCount(SubjectItem):
class IsoSurfaceAddRemoveWidget(qt.QWidget):
-
sigViewTask = qt.Signal(str)
"""Signal for the tree view to perform some task"""
@@ -1045,13 +1046,13 @@ class IsoSurfaceAddRemoveWidget(qt.QWidget):
layout.setSpacing(0)
addBtn = qt.QToolButton(self)
- addBtn.setText('+')
+ addBtn.setText("+")
addBtn.setToolButtonStyle(qt.Qt.ToolButtonTextOnly)
layout.addWidget(addBtn)
addBtn.clicked.connect(self.__addClicked)
removeBtn = qt.QToolButton(self)
- removeBtn.setText('-')
+ removeBtn.setText("-")
removeBtn.setToolButtonStyle(qt.Qt.ToolButtonTextOnly)
layout.addWidget(removeBtn)
removeBtn.clicked.connect(self.__removeClicked)
@@ -1066,17 +1067,17 @@ class IsoSurfaceAddRemoveWidget(qt.QWidget):
if dataRange is None:
dataRange = [0, 1]
- sfview.addIsosurface(
- numpy.mean((dataRange[0], dataRange[-1])), '#0000FF')
+ sfview.addIsosurface(numpy.mean((dataRange[0], dataRange[-1])), "#0000FF")
def __removeClicked(self):
- self.sigViewTask.emit('remove_iso')
+ self.sigViewTask.emit("remove_iso")
class IsoSurfaceAddRemoveItem(SubjectItem):
"""
Item displaying a simple QToolButton allowing to add an isosurface.
"""
+
persistent = True
def getEditor(self, parent, option, index):
@@ -1101,30 +1102,30 @@ class IsoSurfaceGroup(SubjectItem):
if len(args) >= 1:
isosurface = args[0]
if not isinstance(isosurface, Isosurface):
- raise ValueError('Expected an isosurface instance.')
+ raise ValueError("Expected an isosurface instance.")
self.__addIsosurface(isosurface)
else:
- raise ValueError('Expected an isosurface instance.')
+ raise ValueError("Expected an isosurface instance.")
elif signalIdx == 1:
if len(args) >= 1:
isosurface = args[0]
if not isinstance(isosurface, Isosurface):
- raise ValueError('Expected an isosurface instance.')
+ raise ValueError("Expected an isosurface instance.")
self.__removeIsosurface(isosurface)
else:
- raise ValueError('Expected an isosurface instance.')
+ raise ValueError("Expected an isosurface instance.")
def __addIsosurface(self, isosurface):
valueItem = IsoSurfaceRootItem(
- subject=isosurface,
- normalization=self._isoLevelSliderNormalization)
+ subject=isosurface, normalization=self._isoLevelSliderNormalization
+ )
nameItem = IsoSurfaceLevelItem(subject=isosurface)
self.insertRow(max(0, self.rowCount() - 1), [valueItem, nameItem])
def __removeIsosurface(self, isosurface):
for row in range(self.rowCount()):
child = self.child(row)
- subject = getattr(child, 'subject', None)
+ subject = getattr(child, "subject", None)
if subject == isosurface:
self.takeRow(row)
break
@@ -1143,6 +1144,7 @@ class IsoSurfaceGroup(SubjectItem):
# Cutting Plane ###############################################################
+
class ColormapBase(SubjectItem):
"""
Mixin class for colormap items.
@@ -1157,6 +1159,7 @@ class PlaneMinRangeItem(ColormapBase):
colormap minVal item.
Editor is a QLineEdit with a QDoubleValidator
"""
+
editable = True
def _pullData(self):
@@ -1197,6 +1200,7 @@ class PlaneMaxRangeItem(ColormapBase):
colormap maxVal item.
Editor is a QLineEdit with a QDoubleValidator
"""
+
editable = True
def _pullData(self):
@@ -1233,27 +1237,39 @@ class PlaneOrientationItem(SubjectItem):
Plane orientation item.
Editor is a QComboBox.
"""
+
editable = True
_PLANE_ACTIONS = (
- ('3d-plane-normal-x', 'Plane 0',
- 'Set plane perpendicular to red axis', (1., 0., 0.)),
- ('3d-plane-normal-y', 'Plane 1',
- 'Set plane perpendicular to green axis', (0., 1., 0.)),
- ('3d-plane-normal-z', 'Plane 2',
- 'Set plane perpendicular to blue axis', (0., 0., 1.)),
+ (
+ "3d-plane-normal-x",
+ "Plane 0",
+ "Set plane perpendicular to red axis",
+ (1.0, 0.0, 0.0),
+ ),
+ (
+ "3d-plane-normal-y",
+ "Plane 1",
+ "Set plane perpendicular to green axis",
+ (0.0, 1.0, 0.0),
+ ),
+ (
+ "3d-plane-normal-z",
+ "Plane 2",
+ "Set plane perpendicular to blue axis",
+ (0.0, 0.0, 1.0),
+ ),
)
def getSignals(self):
return [self.subject.getCutPlanes()[0].sigPlaneChanged]
def _pullData(self):
- currentNormal = self.subject.getCutPlanes()[0].getNormal(
- coordinates='scene')
+ currentNormal = self.subject.getCutPlanes()[0].getNormal(coordinates="scene")
for _, text, _, normal in self._PLANE_ACTIONS:
if numpy.allclose(normal, currentNormal):
return text
- return ''
+ return ""
def getEditor(self, parent, option, index):
editor = qt.QComboBox(parent)
@@ -1262,13 +1278,14 @@ class PlaneOrientationItem(SubjectItem):
# Wrapping call in lambda is a workaround for PySide with Python 3
editor.currentIndexChanged[int].connect(
- lambda index: self.__editorChanged(index))
+ lambda index: self.__editorChanged(index)
+ )
return editor
def __editorChanged(self, index):
normal = self._PLANE_ACTIONS[index][3]
plane = self.subject.getCutPlanes()[0]
- plane.setNormal(normal, coordinates='scene')
+ plane.setNormal(normal, coordinates="scene")
plane.moveToCenter()
def setEditorData(self, editor):
@@ -1295,7 +1312,8 @@ class PlaneInterpolationItem(SubjectItem):
interpolation = self.subject.getCutPlanes()[0].getInterpolation()
self.setCheckable(True)
self.setCheckState(
- qt.Qt.Checked if interpolation == 'linear' else qt.Qt.Unchecked)
+ qt.Qt.Checked if interpolation == "linear" else qt.Qt.Unchecked
+ )
self.setData(self._pullData(), role=qt.Qt.DisplayRole, pushData=False)
def getSignals(self):
@@ -1303,7 +1321,7 @@ class PlaneInterpolationItem(SubjectItem):
def leftClicked(self):
checked = self.checkState() == qt.Qt.Checked
- self._setInterpolation('linear' if checked else 'nearest')
+ self._setInterpolation("linear" if checked else "nearest")
def _pullData(self):
interpolation = self.subject.getCutPlanes()[0].getInterpolation()
@@ -1323,8 +1341,7 @@ class PlaneDisplayBelowMinItem(SubjectItem):
def _init(self):
display = self.subject.getCutPlanes()[0].getDisplayValuesBelowMin()
self.setCheckable(True)
- self.setCheckState(
- qt.Qt.Checked if display else qt.Qt.Unchecked)
+ self.setCheckState(qt.Qt.Checked if display else qt.Qt.Unchecked)
self.setData(self._pullData(), role=qt.Qt.DisplayRole, pushData=False)
def getSignals(self):
@@ -1348,12 +1365,21 @@ class PlaneColormapItem(ColormapBase):
colormap name item.
Editor is a QComboBox
"""
+
editable = True
- listValues = ['gray', 'reversed gray',
- 'temperature', 'red',
- 'green', 'blue',
- 'viridis', 'magma', 'inferno', 'plasma']
+ listValues = [
+ "gray",
+ "reversed gray",
+ "temperature",
+ "red",
+ "green",
+ "blue",
+ "viridis",
+ "magma",
+ "inferno",
+ "plasma",
+ ]
def getEditor(self, parent, option, index):
editor = qt.QComboBox(parent)
@@ -1361,7 +1387,8 @@ class PlaneColormapItem(ColormapBase):
# Wrapping call in lambda is a workaround for PySide with Python 3
editor.currentIndexChanged[int].connect(
- lambda index: self.__editorChanged(index))
+ lambda index: self.__editorChanged(index)
+ )
return editor
@@ -1375,7 +1402,7 @@ class PlaneColormapItem(ColormapBase):
try:
index = self.listValues.index(colormapName)
except ValueError:
- _logger.error('Unsupported colormap: %s', colormapName)
+ _logger.error("Unsupported colormap: %s", colormapName)
else:
editor.setCurrentIndex(index)
return True
@@ -1397,12 +1424,13 @@ class PlaneAutoScaleItem(ColormapBase):
def _init(self):
colorMap = self.subject.getCutPlanes()[0].getColormap()
self.setCheckable(True)
- self.setCheckState((colorMap.isAutoscale() and qt.Qt.Checked)
- or qt.Qt.Unchecked)
+ self.setCheckState(
+ (colorMap.isAutoscale() and qt.Qt.Checked) or qt.Qt.Unchecked
+ )
self.setData(self._pullData(), role=qt.Qt.DisplayRole, pushData=False)
def leftClicked(self):
- checked = (self.checkState() == qt.Qt.Checked)
+ checked = self.checkState() == qt.Qt.Checked
self._setAutoScale(checked)
def _setAutoScale(self, auto):
@@ -1424,9 +1452,9 @@ class PlaneAutoScaleItem(ColormapBase):
auto = self.subject.getCutPlanes()[0].getColormap().isAutoscale()
self._setAutoScale(auto)
if auto:
- data = 'Auto'
+ data = "Auto"
else:
- data = 'User'
+ data = "User"
return data
@@ -1435,6 +1463,7 @@ class NormalizationNode(ColormapBase):
colormap normalization item.
Item is a QComboBox.
"""
+
editable = True
listValues = list(Colormap.NORMALIZATIONS)
@@ -1444,17 +1473,20 @@ class NormalizationNode(ColormapBase):
# Wrapping call in lambda is a workaround for PySide with Python 3
editor.currentIndexChanged[int].connect(
- lambda index: self.__editorChanged(index))
+ lambda index: self.__editorChanged(index)
+ )
return editor
def __editorChanged(self, index):
colorMap = self.subject.getCutPlanes()[0].getColormap()
normalization = self.listValues[index]
- self.subject.getCutPlanes()[0].setColormap(name=colorMap.getName(),
- norm=normalization,
- vmin=colorMap.getVMin(),
- vmax=colorMap.getVMax())
+ self.subject.getCutPlanes()[0].setColormap(
+ name=colorMap.getName(),
+ norm=normalization,
+ vmin=colorMap.getVMin(),
+ vmax=colorMap.getVMax(),
+ )
def setEditorData(self, editor):
normalization = self.subject.getCutPlanes()[0].getColormap().getNormalization()
@@ -1474,48 +1506,49 @@ class PlaneGroup(SubjectItem):
"""
Root Item for the plane items.
"""
+
def _init(self):
valueItem = qt.QStandardItem()
valueItem.setEditable(False)
- nameItem = PlaneVisibleItem(self.subject, 'Visible')
+ nameItem = PlaneVisibleItem(self.subject, "Visible")
self.appendRow([nameItem, valueItem])
- nameItem = qt.QStandardItem('Colormap')
+ nameItem = qt.QStandardItem("Colormap")
nameItem.setEditable(False)
valueItem = PlaneColormapItem(self.subject)
self.appendRow([nameItem, valueItem])
- nameItem = qt.QStandardItem('Normalization')
+ nameItem = qt.QStandardItem("Normalization")
nameItem.setEditable(False)
valueItem = NormalizationNode(self.subject)
self.appendRow([nameItem, valueItem])
- nameItem = qt.QStandardItem('Orientation')
+ nameItem = qt.QStandardItem("Orientation")
nameItem.setEditable(False)
valueItem = PlaneOrientationItem(self.subject)
self.appendRow([nameItem, valueItem])
- nameItem = qt.QStandardItem('Interpolation')
+ nameItem = qt.QStandardItem("Interpolation")
nameItem.setEditable(False)
valueItem = PlaneInterpolationItem(self.subject)
self.appendRow([nameItem, valueItem])
- nameItem = qt.QStandardItem('Autoscale')
+ nameItem = qt.QStandardItem("Autoscale")
nameItem.setEditable(False)
valueItem = PlaneAutoScaleItem(self.subject)
self.appendRow([nameItem, valueItem])
- nameItem = qt.QStandardItem('Min')
+ nameItem = qt.QStandardItem("Min")
nameItem.setEditable(False)
valueItem = PlaneMinRangeItem(self.subject)
self.appendRow([nameItem, valueItem])
- nameItem = qt.QStandardItem('Max')
+ nameItem = qt.QStandardItem("Max")
nameItem.setEditable(False)
valueItem = PlaneMaxRangeItem(self.subject)
self.appendRow([nameItem, valueItem])
- nameItem = qt.QStandardItem('Values<=Min')
+ nameItem = qt.QStandardItem("Values<=Min")
nameItem.setEditable(False)
valueItem = PlaneDisplayBelowMinItem(self.subject)
self.appendRow([nameItem, valueItem])
@@ -1526,15 +1559,15 @@ class PlaneVisibleItem(SubjectItem):
Plane visibility item.
Item is checkable.
"""
+
def _init(self):
plane = self.subject.getCutPlanes()[0]
self.setCheckable(True)
- self.setCheckState((plane.isVisible() and qt.Qt.Checked)
- or qt.Qt.Unchecked)
+ self.setCheckState((plane.isVisible() and qt.Qt.Checked) or qt.Qt.Unchecked)
def leftClicked(self):
plane = self.subject.getCutPlanes()[0]
- checked = (self.checkState() == qt.Qt.Checked)
+ checked = self.checkState() == qt.Qt.Checked
if checked != plane.isVisible():
plane.setVisible(checked)
if plane.isVisible():
@@ -1543,6 +1576,7 @@ class PlaneVisibleItem(SubjectItem):
# Tree ########################################################################
+
class ItemDelegate(qt.QStyledItemDelegate):
"""
Delegate for the QTreeView filled with SubjectItems.
@@ -1560,13 +1594,11 @@ class ItemDelegate(qt.QStyledItemDelegate):
editor = item.getEditor(parent, option, index)
if editor:
editor.setAutoFillBackground(True)
- if hasattr(editor, 'sigViewTask'):
+ if hasattr(editor, "sigViewTask"):
editor.sigViewTask.connect(self.__viewTask)
return editor
- editor = super(ItemDelegate, self).createEditor(parent,
- option,
- index)
+ editor = super(ItemDelegate, self).createEditor(parent, option, index)
return editor
def updateEditorGeometry(self, editor, option, index):
@@ -1597,7 +1629,7 @@ class TreeView(qt.QTreeView):
def __init__(self, parent=None):
super(TreeView, self).__init__(parent)
self.__openedIndex = None
- self._isoLevelSliderNormalization = 'linear'
+ self._isoLevelSliderNormalization = "linear"
self.setIconSize(qt.QSize(16, 16))
@@ -1620,26 +1652,30 @@ class TreeView(qt.QTreeView):
"""
model = qt.QStandardItemModel()
model.setColumnCount(ModelColumns.ColumnMax)
- model.setHorizontalHeaderLabels(['Name', 'Value'])
+ model.setHorizontalHeaderLabels(["Name", "Value"])
item = qt.QStandardItem()
item.setEditable(False)
- model.appendRow([ViewSettingsItem(sfView, 'Style'), item])
+ model.appendRow([ViewSettingsItem(sfView, "Style"), item])
item = qt.QStandardItem()
item.setEditable(False)
- model.appendRow([DataSetItem(sfView, 'Data'), item])
+ model.appendRow([DataSetItem(sfView, "Data"), item])
item = IsoSurfaceCount(sfView)
item.setEditable(False)
- model.appendRow([IsoSurfaceGroup(sfView,
- self._isoLevelSliderNormalization,
- 'Isosurfaces'),
- item])
+ model.appendRow(
+ [
+ IsoSurfaceGroup(
+ sfView, self._isoLevelSliderNormalization, "Isosurfaces"
+ ),
+ item,
+ ]
+ )
item = qt.QStandardItem()
item.setEditable(False)
- model.appendRow([PlaneGroup(sfView, 'Cutting Plane'), item])
+ model.appendRow([PlaneGroup(sfView, "Cutting Plane"), item])
self.setModel(model)
@@ -1685,21 +1721,24 @@ class TreeView(qt.QTreeView):
meth = self.closePersistentEditor
curParent = parent
- children = [model.index(row, 0, curParent)
- for row in range(model.rowCount(curParent))]
+ children = [
+ model.index(row, 0, curParent) for row in range(model.rowCount(curParent))
+ ]
columnCount = model.columnCount()
while len(children) > 0:
curParent = children.pop(-1)
- children.extend([model.index(row, 0, curParent)
- for row in range(model.rowCount(curParent))])
+ children.extend(
+ [
+ model.index(row, 0, curParent)
+ for row in range(model.rowCount(curParent))
+ ]
+ )
for colIdx in range(columnCount):
- sibling = model.sibling(curParent.row(),
- colIdx,
- curParent)
+ sibling = model.sibling(curParent.row(), colIdx, curParent)
item = model.itemFromIndex(sibling)
if isinstance(item, SubjectItem) and item.persistent:
meth(sibling)
@@ -1781,9 +1820,8 @@ class TreeView(qt.QTreeView):
parentItem.removeRow(iso.row())
else:
qt.QMessageBox.information(
- self,
- 'Remove isosurface',
- 'Select an iso-surface to remove it')
+ self, "Remove isosurface", "Select an iso-surface to remove it"
+ )
def __clicked(self, index):
"""
@@ -1797,7 +1835,7 @@ class TreeView(qt.QTreeView):
item.leftClicked()
def __delegateEvent(self, task):
- if task == 'remove_iso':
+ if task == "remove_iso":
self.__removeIsosurfaces()
def setIsoLevelSliderNormalization(self, normalization):
@@ -1807,5 +1845,5 @@ class TreeView(qt.QTreeView):
:param str normalization: Either 'linear' or 'arcsinh'
"""
- assert normalization in ('linear', 'arcsinh')
+ assert normalization in ("linear", "arcsinh")
self._isoLevelSliderNormalization = normalization
diff --git a/src/silx/gui/plot3d/ScalarFieldView.py b/src/silx/gui/plot3d/ScalarFieldView.py
index 0633221..e1d34fd 100644
--- a/src/silx/gui/plot3d/ScalarFieldView.py
+++ b/src/silx/gui/plot3d/ScalarFieldView.py
@@ -76,9 +76,9 @@ class Isosurface(qt.QObject):
def __init__(self, parent):
super(Isosurface, self).__init__(parent=parent)
- self._level = float('nan')
+ self._level = float("nan")
self._autoLevelFunction = None
- self._color = rgba('#FFD700FF')
+ self._color = rgba("#FFD700FF")
self._data = None
self._group = scene.Group()
@@ -91,7 +91,7 @@ class Isosurface(qt.QObject):
if data is None:
self._data = None
else:
- self._data = numpy.array(data, copy=copy, order='C')
+ self._data = numpy.array(data, copy=copy, order="C")
self._update()
@@ -167,7 +167,7 @@ class Isosurface(qt.QObject):
if color != self._color:
self._color = color
if len(self._group.children) != 0:
- self._group.children[0].setAttribute('color', self._color)
+ self._group.children[0].setAttribute("color", self._color)
self.sigColorChanged.emit()
def _update(self):
@@ -176,7 +176,7 @@ class Isosurface(qt.QObject):
if self._data is None:
if self.isAutoLevel():
- self._level = float('nan')
+ self._level = float("nan")
else:
if self.isAutoLevel():
@@ -191,12 +191,12 @@ class Isosurface(qt.QObject):
"Error while executing iso level function %s.%s",
module,
name,
- exc_info=True)
- level = float('nan')
+ exc_info=True,
+ )
+ level = float("nan")
else:
- _logger.info(
- 'Computed iso-level in %f s.', time.time() - st)
+ _logger.info("Computed iso-level in %f s.", time.time() - st)
if level != self._level:
self._level = level
@@ -206,19 +206,19 @@ class Isosurface(qt.QObject):
return
st = time.time()
- vertices, normals, indices = MarchingCubes(
- self._data,
- isolevel=self._level)
- _logger.info('Computed iso-surface in %f s.', time.time() - st)
+ vertices, normals, indices = MarchingCubes(self._data, isolevel=self._level)
+ _logger.info("Computed iso-surface in %f s.", time.time() - st)
if len(vertices) == 0:
return
else:
- mesh = primitives.Mesh3D(vertices,
- colors=self._color,
- normals=normals,
- mode='triangles',
- indices=indices)
+ mesh = primitives.Mesh3D(
+ vertices,
+ colors=self._color,
+ normals=normals,
+ mode="triangles",
+ indices=indices,
+ )
self._group.children = [mesh]
@@ -233,9 +233,9 @@ class SelectedRegion(object):
:param scale: Scale from array to data coordinates (sx, sy, sz)
"""
- def __init__(self, arrayRange, dataBBox,
- translation=(0., 0., 0.),
- scale=(1., 1., 1.)):
+ def __init__(
+ self, arrayRange, dataBBox, translation=(0.0, 0.0, 0.0), scale=(1.0, 1.0, 1.0)
+ ):
self._arrayRange = numpy.array(arrayRange, copy=True, dtype=numpy.int64)
assert self._arrayRange.shape == (3, 2)
assert numpy.all(self._arrayRange[:, 1] >= self._arrayRange[:, 0])
@@ -261,9 +261,11 @@ class SelectedRegion(object):
:return: A numpy array with (zslice, yslice, zslice)
:rtype: numpy.ndarray
"""
- return (slice(*self._arrayRange[0]),
- slice(*self._arrayRange[1]),
- slice(*self._arrayRange[2]))
+ return (
+ slice(*self._arrayRange[0]),
+ slice(*self._arrayRange[1]),
+ slice(*self._arrayRange[2]),
+ )
def getDataRange(self):
"""Range in the data coordinates of the selection: 3x2 array of float
@@ -348,12 +350,13 @@ class CutPlane(qt.QObject):
# Plane with texture on the data bounding box
self._dataPlane = cutplane.CutPlane(normal=(0, 1, 0))
self._dataPlane.strokeVisible = False
- self._dataPlane.alpha = 1.
+ self._dataPlane.alpha = 1.0
self._dataPlane.visible = self._visible
self._dataPlane.plane.addListener(self._planePositionChanged)
self._colormap = Colormap(
- name='gray', normalization='linear', vmin=None, vmax=None)
+ name="gray", normalization="linear", vmin=None, vmax=None
+ )
self.getColormap().sigChanged.connect(self._colormapChanged)
self._updateSceneColormap()
@@ -369,8 +372,8 @@ class CutPlane(qt.QObject):
bounds = self._planeStroke.parent.bounds(dataBounds=True)
if bounds is not None:
self._planeStroke.plane.point = numpy.clip(
- self._planeStroke.plane.point,
- a_min=bounds[0], a_max=bounds[1])
+ self._planeStroke.plane.point, a_min=bounds[0], a_max=bounds[1]
+ )
@staticmethod
def _syncPlanes(master, slave):
@@ -379,14 +382,12 @@ class CutPlane(qt.QObject):
:param PlaneInGroup master: Reference PlaneInGroup
:param PlaneInGroup slave: PlaneInGroup to align
"""
- masterToSlave = transform.StaticTransformList([
- slave.objectToSceneTransform.inverse(),
- master.objectToSceneTransform])
-
- point = masterToSlave.transformPoint(
- master.plane.point)
- normal = masterToSlave.transformNormal(
- master.plane.normal)
+ masterToSlave = transform.StaticTransformList(
+ [slave.objectToSceneTransform.inverse(), master.objectToSceneTransform]
+ )
+
+ point = masterToSlave.transformPoint(master.plane.point)
+ normal = masterToSlave.transformNormal(master.plane.normal)
slave.plane.setPlane(point, normal)
def _sfViewDataChanged(self):
@@ -407,8 +408,7 @@ class CutPlane(qt.QObject):
def _sfViewTransformChanged(self):
"""Handle transform changed in the ScalarFieldView"""
self._keepPlaneInBBox()
- self._syncPlanes(master=self._planeStroke,
- slave=self._dataPlane)
+ self._syncPlanes(master=self._planeStroke, slave=self._dataPlane)
self.sigPlaneChanged.emit()
def _planeChanged(self, source, *args, **kwargs):
@@ -423,14 +423,11 @@ class CutPlane(qt.QObject):
if self.__syncPlane:
self.__syncPlane = False
if source is self._planeStroke.plane:
- self._syncPlanes(master=self._planeStroke,
- slave=self._dataPlane)
+ self._syncPlanes(master=self._planeStroke, slave=self._dataPlane)
elif source is self._dataPlane.plane:
- self._syncPlanes(master=self._dataPlane,
- slave=self._planeStroke)
+ self._syncPlanes(master=self._dataPlane, slave=self._planeStroke)
else:
- _logger.error('Received an unknown object %s',
- str(source))
+ _logger.error("Received an unknown object %s", str(source))
if self._planeStroke.visible or self._dataPlane.visible:
self.sigPlaneChanged.emit()
@@ -447,7 +444,7 @@ class CutPlane(qt.QObject):
"""Returns whether the cut plane is defined or not (bool)"""
return self._planeStroke.isValid
- def _plane(self, coordinates='array'):
+ def _plane(self, coordinates="array"):
"""Returns the scene plane to set.
:param str coordinates: The coordinate system to use:
@@ -455,15 +452,14 @@ class CutPlane(qt.QObject):
:rtype: Plane
:raise ValueError: If coordinates is not correct
"""
- if coordinates == 'scene':
+ if coordinates == "scene":
return self._planeStroke.plane
- elif coordinates == 'array':
+ elif coordinates == "array":
return self._dataPlane.plane
else:
- raise ValueError(
- 'Unsupported coordinates: %s' % str(coordinates))
+ raise ValueError("Unsupported coordinates: %s" % str(coordinates))
- def getNormal(self, coordinates='array'):
+ def getNormal(self, coordinates="array"):
"""Returns the normal of the plane (as a unit vector)
:param str coordinates: The coordinate system to use:
@@ -474,7 +470,7 @@ class CutPlane(qt.QObject):
"""
return self._plane(coordinates).normal
- def setNormal(self, normal, coordinates='array'):
+ def setNormal(self, normal, coordinates="array"):
"""Set the normal of the plane.
:param normal: 3-tuple of float: nx, ny, nz
@@ -484,7 +480,7 @@ class CutPlane(qt.QObject):
"""
self._plane(coordinates).normal = normal
- def getPoint(self, coordinates='array'):
+ def getPoint(self, coordinates="array"):
"""Returns a point on the plane.
:param str coordinates: The coordinate system to use:
@@ -495,7 +491,7 @@ class CutPlane(qt.QObject):
"""
return self._plane(coordinates).point
- def setPoint(self, point, constraint=True, coordinates='array'):
+ def setPoint(self, point, constraint=True, coordinates="array"):
"""Set a point contained in the plane.
Warning: The plane might not intersect the bounding box of the data.
@@ -511,7 +507,7 @@ class CutPlane(qt.QObject):
if constraint:
self._keepPlaneInBBox()
- def getParameters(self, coordinates='array'):
+ def getParameters(self, coordinates="array"):
"""Returns the plane equation parameters: a*x + b*y + c*z + d = 0
:param str coordinates: The coordinate system to use:
@@ -522,7 +518,7 @@ class CutPlane(qt.QObject):
"""
return self._plane(coordinates).parameters
- def setParameters(self, parameters, constraint=True, coordinates='array'):
+ def setParameters(self, parameters, constraint=True, coordinates="array"):
"""Set the plane equation parameters: a*x + b*y + c*z + d = 0
Warning: The plane might not intersect the bounding box of the data.
@@ -644,11 +640,7 @@ class CutPlane(qt.QObject):
"""
return self._colormap
- def setColormap(self,
- name='gray',
- norm=None,
- vmin=None,
- vmax=None):
+ def setColormap(self, name="gray", norm=None, vmin=None, vmax=None):
"""Set the colormap to use.
By either providing a :class:`Colormap` object or
@@ -662,8 +654,9 @@ class CutPlane(qt.QObject):
:param float vmin: The minimum value of the range or None for autoscale
:param float vmax: The maximum value of the range or None for autoscale
"""
- _logger.debug('setColormap %s %s (%s, %s)',
- name, str(norm), str(vmin), str(vmax))
+ _logger.debug(
+ "setColormap %s %s (%s, %s)", name, str(norm), str(vmin), str(vmax)
+ )
self._colormap.sigChanged.disconnect(self._colormapChanged)
@@ -672,9 +665,10 @@ class CutPlane(qt.QObject):
self._colormap = name
else:
if norm is None:
- norm = 'linear'
+ norm = "linear"
self._colormap = Colormap(
- name=name, normalization=norm, vmin=vmin, vmax=vmax)
+ name=name, normalization=norm, vmin=vmin, vmax=vmax
+ )
self._colormap.sigChanged.connect(self._colormapChanged)
self._colormapChanged()
@@ -718,12 +712,12 @@ class _CutPlaneImage(object):
self._isValid = False
self._data = numpy.zeros((0, 0), dtype=numpy.float32)
self._index = 0
- self._xLabel = ''
- self._yLabel = ''
- self._normalLabel = ''
- self._scale = float('nan'), float('nan')
- self._translation = float('nan'), float('nan')
- self._position = float('nan')
+ self._xLabel = ""
+ self._yLabel = ""
+ self._normalLabel = ""
+ self._scale = float("nan"), float("nan")
+ self._translation = float("nan"), float("nan")
+ self._position = float("nan")
sfView = cutPlane.parent()
if not sfView or not cutPlane.isValid():
@@ -735,10 +729,10 @@ class _CutPlaneImage(object):
_logger.info("No data available")
return
- normal = cutPlane.getNormal(coordinates='array')
- point = cutPlane.getPoint(coordinates='array')
+ normal = cutPlane.getNormal(coordinates="array")
+ point = cutPlane.getPoint(coordinates="array")
- if numpy.linalg.norm(numpy.cross(normal, (1., 0., 0.))) < 0.0017:
+ if numpy.linalg.norm(numpy.cross(normal, (1.0, 0.0, 0.0))) < 0.0017:
if not 0 <= point[0] <= data.shape[2]:
_logger.info("Plane outside dataset")
return
@@ -746,7 +740,7 @@ class _CutPlaneImage(object):
slice_ = data[:, :, index]
xAxisIndex, yAxisIndex, normalAxisIndex = 1, 2, 0 # y, z, x
- elif numpy.linalg.norm(numpy.cross(normal, (0., 1., 0.))) < 0.0017:
+ elif numpy.linalg.norm(numpy.cross(normal, (0.0, 1.0, 0.0))) < 0.0017:
if not 0 <= point[1] <= data.shape[1]:
_logger.info("Plane outside dataset")
return
@@ -754,7 +748,7 @@ class _CutPlaneImage(object):
slice_ = numpy.transpose(data[:, index, :])
xAxisIndex, yAxisIndex, normalAxisIndex = 2, 0, 1 # z, x, y
- elif numpy.linalg.norm(numpy.cross(normal, (0., 0., 1.))) < 0.0017:
+ elif numpy.linalg.norm(numpy.cross(normal, (0.0, 0.0, 1.0))) < 0.0017:
if not 0 <= point[2] <= data.shape[0]:
_logger.info("Plane outside dataset")
return
@@ -762,8 +756,9 @@ class _CutPlaneImage(object):
slice_ = data[index, :, :]
xAxisIndex, yAxisIndex, normalAxisIndex = 0, 1, 2 # x, y, z
else:
- _logger.warning('Unsupported normal: (%f, %f, %f)',
- normal[0], normal[1], normal[2])
+ _logger.warning(
+ "Unsupported normal: (%f, %f, %f)", normal[0], normal[1], normal[2]
+ )
return
# Store cut plane image info
@@ -774,8 +769,11 @@ class _CutPlaneImage(object):
# Only store extra information when no transform matrix is set
# Otherwise this information can be meaningless
- if numpy.all(numpy.equal(sfView.getTransformMatrix(),
- numpy.identity(3, dtype=numpy.float32))):
+ if numpy.all(
+ numpy.equal(
+ sfView.getTransformMatrix(), numpy.identity(3, dtype=numpy.float32)
+ )
+ ):
labels = sfView.getAxesLabels()
self._xLabel = labels[xAxisIndex]
self._yLabel = labels[yAxisIndex]
@@ -787,8 +785,9 @@ class _CutPlaneImage(object):
translation = sfView.getTranslation()
self._translation = translation[xAxisIndex], translation[yAxisIndex]
- self._position = float(index * scale[normalAxisIndex] +
- translation[normalAxisIndex])
+ self._position = float(
+ index * scale[normalAxisIndex] + translation[normalAxisIndex]
+ )
def isValid(self):
"""Returns True if the cut plane image is defined (bool)"""
@@ -860,7 +859,8 @@ class ScalarFieldView(Plot3DWindow):
def __init__(self, parent=None):
super(ScalarFieldView, self).__init__(parent)
self._colormap = Colormap(
- name='gray', normalization='linear', vmin=None, vmax=None)
+ name="gray", normalization="linear", vmin=None, vmax=None
+ )
self._selectedRange = None
# Store iso-surfaces
@@ -869,35 +869,37 @@ class ScalarFieldView(Plot3DWindow):
# Transformations
self._dataScale = transform.Scale()
self._dataTranslate = transform.Translate()
- self._dataTransform = transform.Matrix() # default to identity
+ self._dataTransform = transform.Matrix() # default to identity
- self._foregroundColor = 1., 1., 1., 1.
- self._highlightColor = 0.7, 0.7, 0., 1.
+ self._foregroundColor = 1.0, 1.0, 1.0, 1.0
+ self._highlightColor = 0.7, 0.7, 0.0, 1.0
self._data = None
self._dataRange = None
self._group = primitives.BoundedGroup()
self._group.transforms = [
- self._dataTranslate, self._dataTransform, self._dataScale]
+ self._dataTranslate,
+ self._dataTransform,
+ self._dataScale,
+ ]
self._bbox = axes.LabelledAxes()
self._bbox.children = [self._group]
- self._outerScale = transform.Scale(1., 1., 1.)
+ self._outerScale = transform.Scale(1.0, 1.0, 1.0)
self._bbox.transforms = [self._outerScale]
self.getPlot3DWidget().viewport.scene.children.append(self._bbox)
self._selectionBox = primitives.Box()
self._selectionBox.strokeSmooth = False
- self._selectionBox.strokeWidth = 1.
+ self._selectionBox.strokeWidth = 1.0
# self._selectionBox.fillColor = 1., 1., 1., 0.3
# self._selectionBox.fillCulling = 'back'
self._selectionBox.visible = False
self._group.children.append(self._selectionBox)
self._cutPlane = CutPlane(sfView=self)
- self._cutPlane.sigVisibilityChanged.connect(
- self._planeVisibilityChanged)
+ self._cutPlane.sigVisibilityChanged.connect(self._planeVisibilityChanged)
planeStroke, dataPlane = self._cutPlane._get3DPrimitives()
self._bbox.children.append(planeStroke)
self._group.children.append(dataPlane)
@@ -905,13 +907,16 @@ class ScalarFieldView(Plot3DWindow):
self._isogroup = primitives.GroupDepthOffset()
self._isogroup.transforms = [
# Convert from z, y, x from marching cubes to x, y, z
- transform.Matrix((
- (0., 0., 1., 0.),
- (0., 1., 0., 0.),
- (1., 0., 0., 0.),
- (0., 0., 0., 1.))),
+ transform.Matrix(
+ (
+ (0.0, 0.0, 1.0, 0.0),
+ (0.0, 1.0, 0.0, 0.0),
+ (1.0, 0.0, 0.0, 0.0),
+ (0.0, 0.0, 0.0, 1.0),
+ )
+ ),
# Offset to match cutting plane coords
- transform.Translate(0.5, 0.5, 0.5)
+ transform.Translate(0.5, 0.5, 0.5),
]
self._group.children.append(self._isogroup)
@@ -931,7 +936,7 @@ class ScalarFieldView(Plot3DWindow):
stream = qt.QDataStream(ioDevice)
- stream.writeString('<ScalarFieldView>')
+ stream.writeString("<ScalarFieldView>")
isoSurfaces = self.getIsosurfaces()
@@ -940,7 +945,7 @@ class ScalarFieldView(Plot3DWindow):
# TODO : delegate the serialization to the serialized items
# isosurfaces
if nIsoSurfaces:
- tagIn = '<IsoSurfaces nIso={0}>'.format(nIsoSurfaces)
+ tagIn = "<IsoSurfaces nIso={0}>".format(nIsoSurfaces)
stream.writeString(tagIn)
for surface in isoSurfaces:
@@ -951,16 +956,16 @@ class ScalarFieldView(Plot3DWindow):
stream.writeDouble(level)
stream.writeBool(visible)
- stream.writeString('</IsoSurfaces>')
+ stream.writeString("</IsoSurfaces>")
- stream.writeString('<Style>')
+ stream.writeString("<Style>")
background = self.getBackgroundColor()
foreground = self.getForegroundColor()
highlight = self.getHighlightColor()
stream << background << foreground << highlight
- stream.writeString('</Style>')
+ stream.writeString("</Style>")
- stream.writeString('</ScalarFieldView>')
+ stream.writeString("</ScalarFieldView>")
def loadConfig(self, ioDevice):
"""
@@ -972,14 +977,13 @@ class ScalarFieldView(Plot3DWindow):
tagStack = deque()
- tagInRegex = re.compile('<(?P<itemId>[^ /]*) *'
- '(?P<args>.*)>')
+ tagInRegex = re.compile("<(?P<itemId>[^ /]*) *" "(?P<args>.*)>")
- tagOutRegex = re.compile('</(?P<itemId>[^ ]*)>')
+ tagOutRegex = re.compile("</(?P<itemId>[^ ]*)>")
- tagRootInRegex = re.compile('<ScalarFieldView>')
+ tagRootInRegex = re.compile("<ScalarFieldView>")
- isoSurfaceArgsRegex = re.compile('nIso=(?P<nIso>[0-9]*)')
+ isoSurfaceArgsRegex = re.compile("nIso=(?P<nIso>[0-9]*)")
stream = qt.QDataStream(ioDevice)
@@ -988,26 +992,27 @@ class ScalarFieldView(Plot3DWindow):
if tagMatch is None:
# TODO : explicit error
- raise ValueError('Unknown data.')
+ raise ValueError("Unknown data.")
- itemId = 'ScalarFieldView'
+ itemId = "ScalarFieldView"
tagStack.append(itemId)
while True:
-
tag = stream.readString()
tagMatch = tagOutRegex.match(tag)
if tagMatch:
- closeId = tagMatch.groupdict()['itemId']
+ closeId = tagMatch.groupdict()["itemId"]
if closeId != itemId:
# TODO : explicit error
- raise ValueError('Unexpected closing tag {0} '
- '(expected {1})'
- ''.format(closeId, itemId))
+ raise ValueError(
+ "Unexpected closing tag {0} "
+ "(expected {1})"
+ "".format(closeId, itemId)
+ )
- if itemId == 'ScalarFieldView':
+ if itemId == "ScalarFieldView":
# reached end
break
else:
@@ -1019,23 +1024,24 @@ class ScalarFieldView(Plot3DWindow):
if tagMatch is None:
# TODO : explicit error
- raise ValueError('Unknown data.')
+ raise ValueError("Unknown data.")
tagStack.append(itemId)
matchDict = tagMatch.groupdict()
- itemId = matchDict['itemId']
+ itemId = matchDict["itemId"]
# TODO : delegate the deserialization to the serialized items
- if itemId == 'IsoSurfaces':
- argsMatch = isoSurfaceArgsRegex.match(matchDict['args'])
+ if itemId == "IsoSurfaces":
+ argsMatch = isoSurfaceArgsRegex.match(matchDict["args"])
if not argsMatch:
# TODO : explicit error
- raise ValueError('Failed to parse args "{0}".'
- ''.format(matchDict['args']))
+ raise ValueError(
+ 'Failed to parse args "{0}".' "".format(matchDict["args"])
+ )
argsDict = argsMatch.groupdict()
- nIso = int(argsDict['nIso'])
+ nIso = int(argsDict["nIso"])
if nIso:
for surface in self.getIsosurfaces():
self.removeIsosurface(surface)
@@ -1046,7 +1052,7 @@ class ScalarFieldView(Plot3DWindow):
visible = stream.readBool()
surface = self.addIsosurface(level, color=color)
surface.setVisible(visible)
- elif itemId == 'Style':
+ elif itemId == "Style":
background = qt.QColor()
foreground = qt.QColor()
highlight = qt.QColor()
@@ -1055,22 +1061,23 @@ class ScalarFieldView(Plot3DWindow):
self.setForegroundColor(foreground)
self.setHighlightColor(highlight)
else:
- raise ValueError('Unknown entry tag {0}.'
- ''.format(itemId))
+ raise ValueError("Unknown entry tag {0}." "".format(itemId))
def _initPanPlaneAction(self):
"""Creates and init the pan plane action"""
self._panPlaneAction = qt.QAction(self)
- self._panPlaneAction.setIcon(icons.getQIcon('3d-plane-pan'))
- self._panPlaneAction.setText('Pan plane')
+ self._panPlaneAction.setIcon(icons.getQIcon("3d-plane-pan"))
+ self._panPlaneAction.setText("Pan plane")
self._panPlaneAction.setCheckable(True)
self._panPlaneAction.setToolTip(
- 'Pan the cutting plane. Press <b>Ctrl</b> to rotate the scene.')
+ "Pan the cutting plane. Press <b>Ctrl</b> to rotate the scene."
+ )
self._panPlaneAction.setEnabled(False)
self._panPlaneAction.triggered[bool].connect(self._planeActionTriggered)
self.getPlot3DWidget().sigInteractiveModeChanged.connect(
- self._interactiveModeChanged)
+ self._interactiveModeChanged
+ )
toolbar = self.findChild(InteractiveModeToolBar)
if toolbar is not None:
@@ -1078,10 +1085,10 @@ class ScalarFieldView(Plot3DWindow):
def _planeActionTriggered(self, checked=False):
self._panPlaneAction.setChecked(True)
- self.setInteractiveMode('plane')
+ self.setInteractiveMode("plane")
def _interactiveModeChanged(self):
- self._panPlaneAction.setChecked(self.getInteractiveMode() == 'plane')
+ self._panPlaneAction.setChecked(self.getInteractiveMode() == "plane")
self._updateColors()
def _planeVisibilityChanged(self, visible):
@@ -1089,9 +1096,9 @@ class ScalarFieldView(Plot3DWindow):
if visible != self._panPlaneAction.isEnabled():
self._panPlaneAction.setEnabled(visible)
if visible:
- self.setInteractiveMode('plane')
+ self.setInteractiveMode("plane")
elif self._panPlaneAction.isChecked():
- self.setInteractiveMode('rotate')
+ self.setInteractiveMode("rotate")
def setInteractiveMode(self, mode):
"""Choose the current interaction.
@@ -1102,23 +1109,24 @@ class ScalarFieldView(Plot3DWindow):
return
sceneScale = self.getPlot3DWidget().viewport.scene.transforms[0]
- if mode == 'plane':
+ if mode == "plane":
mode = interaction.PanPlaneZoomOnWheelControl(
self.getPlot3DWidget().viewport,
self._cutPlane._get3DPrimitives()[0],
- mode='position',
+ mode="position",
orbitAroundCenter=False,
- scaleTransform=sceneScale)
+ scaleTransform=sceneScale,
+ )
self.getPlot3DWidget().setInteractiveMode(mode)
self._updateColors()
def getInteractiveMode(self):
- """Returns the current interaction mode, see :meth:`setInteractiveMode`
- """
- if isinstance(self.getPlot3DWidget().eventHandler,
- interaction.PanPlaneZoomOnWheelControl):
- return 'plane'
+ """Returns the current interaction mode, see :meth:`setInteractiveMode`"""
+ if isinstance(
+ self.getPlot3DWidget().eventHandler, interaction.PanPlaneZoomOnWheelControl
+ ):
+ return "plane"
else:
return self.getPlot3DWidget().getInteractiveMode()
@@ -1143,7 +1151,7 @@ class ScalarFieldView(Plot3DWindow):
self.centerScene()
else:
- data = numpy.array(data, copy=copy, dtype=numpy.float32, order='C')
+ data = numpy.array(data, copy=copy, dtype=numpy.float32, order="C")
assert data.ndim == 3
assert min(data.shape) >= 2
@@ -1160,7 +1168,7 @@ class ScalarFieldView(Plot3DWindow):
if dataRange is not None:
min_positive = dataRange.min_positive
if min_positive is None:
- min_positive = float('nan')
+ min_positive = float("nan")
dataRange = dataRange.minimum, min_positive, dataRange.maximum
self._dataRange = dataRange
@@ -1203,7 +1211,7 @@ class ScalarFieldView(Plot3DWindow):
# Transformations
- def setOuterScale(self, sx=1., sy=1., sz=1.):
+ def setOuterScale(self, sx=1.0, sy=1.0, sz=1.0):
"""Set the scale to apply to the whole scene including the axes.
This is useful when axis lengths in data space are really different.
@@ -1222,7 +1230,7 @@ class ScalarFieldView(Plot3DWindow):
"""
return self._outerScale.scale
- def setScale(self, sx=1., sy=1., sz=1.):
+ def setScale(self, sx=1.0, sy=1.0, sz=1.0):
"""Set the scale of the 3D scalar field (i.e., size of a voxel).
:param float sx: Scale factor along the X axis
@@ -1236,11 +1244,10 @@ class ScalarFieldView(Plot3DWindow):
self.centerScene() # Reset viewpoint
def getScale(self):
- """Returns the scales provided by :meth:`setScale` as a numpy.ndarray.
- """
+ """Returns the scales provided by :meth:`setScale` as a numpy.ndarray."""
return self._dataScale.scale
- def setTranslation(self, x=0., y=0., z=0.):
+ def setTranslation(self, x=0.0, y=0.0, z=0.0):
"""Set the translation of the origin of the data array in data coordinates.
:param float x: Offset of the data origin on the X axis
@@ -1254,8 +1261,7 @@ class ScalarFieldView(Plot3DWindow):
self.centerScene() # Reset viewpoint
def getTranslation(self):
- """Returns the offset set by :meth:`setTranslation` as a numpy.ndarray.
- """
+ """Returns the offset set by :meth:`setTranslation` as a numpy.ndarray."""
return self._dataTranslate.translation
def setTransformMatrix(self, matrix3x3):
@@ -1346,9 +1352,7 @@ class ScalarFieldView(Plot3DWindow):
:return: object describing the labels
"""
- return self._Labels((self._bbox.xlabel,
- self._bbox.ylabel,
- self._bbox.zlabel))
+ return self._Labels((self._bbox.xlabel, self._bbox.ylabel, self._bbox.zlabel))
# Colors
@@ -1356,7 +1360,7 @@ class ScalarFieldView(Plot3DWindow):
"""Update item depending on foreground/highlight color"""
self._bbox.tickColor = self._foregroundColor
self._selectionBox.strokeColor = self._foregroundColor
- if self.getInteractiveMode() == 'plane':
+ if self.getInteractiveMode() == "plane":
self._cutPlane.setStrokeColor(self._highlightColor)
self._bbox.color = self._foregroundColor
else:
@@ -1435,18 +1439,17 @@ class ScalarFieldView(Plot3DWindow):
elif None in (xrange_, yrange, zrange):
# One of the range is None and no data available
- raise RuntimeError(
- 'Data is not set, cannot get default range from it.')
+ raise RuntimeError("Data is not set, cannot get default range from it.")
# Clip selected region to data shape and make sure min <= max
- selectedRange = numpy.array((
- (max(0, min(*zrange)),
- min(self._data.shape[0], max(*zrange))),
- (max(0, min(*yrange)),
- min(self._data.shape[1], max(*yrange))),
- (max(0, min(*xrange_)),
- min(self._data.shape[2], max(*xrange_))),
- ), dtype=numpy.int64)
+ selectedRange = numpy.array(
+ (
+ (max(0, min(*zrange)), min(self._data.shape[0], max(*zrange))),
+ (max(0, min(*yrange)), min(self._data.shape[1], max(*yrange))),
+ (max(0, min(*xrange_)), min(self._data.shape[2], max(*xrange_))),
+ ),
+ dtype=numpy.int64,
+ )
# numpy.equal supports None
if not numpy.all(numpy.equal(selectedRange, self._selectedRange)):
@@ -1460,7 +1463,8 @@ class ScalarFieldView(Plot3DWindow):
scales = self._selectedRange[:, 1] - self._selectedRange[:, 0]
self._selectionBox.size = scales[::-1]
self._selectionBox.transforms = [
- transform.Translate(*self._selectedRange[::-1, 0])]
+ transform.Translate(*self._selectedRange[::-1, 0])
+ ]
self.sigSelectedRegionChanged.emit(self.getSelectedRegion())
@@ -1470,10 +1474,14 @@ class ScalarFieldView(Plot3DWindow):
return None
else:
dataBBox = self._group.transforms.transformBounds(
- self._selectedRange[::-1].T).T
- return SelectedRegion(self._selectedRange, dataBBox,
- translation=self.getTranslation(),
- scale=self.getScale())
+ self._selectedRange[::-1].T
+ ).T
+ return SelectedRegion(
+ self._selectedRange,
+ dataBBox,
+ translation=self.getTranslation(),
+ scale=self.getScale(),
+ )
# Handle iso-surfaces
@@ -1528,8 +1536,8 @@ class ScalarFieldView(Plot3DWindow):
:param isosurface: The isosurface object to remove"""
if isosurface not in self.getIsosurfaces():
_logger.warning(
- "Try to remove isosurface that is not in the list: %s",
- str(isosurface))
+ "Try to remove isosurface that is not in the list: %s", str(isosurface)
+ )
else:
isosurface.sigLevelChanged.disconnect(self._updateIsosurfaces)
self._isosurfaces.remove(isosurface)
@@ -1544,6 +1552,5 @@ class ScalarFieldView(Plot3DWindow):
def _updateIsosurfaces(self, level=None):
"""Handle updates of iso-surfaces level and add/remove"""
# Sorting using minus, this supposes data 'object' to be max values
- sortedIso = sorted(self.getIsosurfaces(),
- key=lambda iso: - iso.getLevel())
+ sortedIso = sorted(self.getIsosurfaces(), key=lambda iso: -iso.getLevel())
self._isogroup.children = [iso._get3DPrimitive() for iso in sortedIso]
diff --git a/src/silx/gui/plot3d/SceneWidget.py b/src/silx/gui/plot3d/SceneWidget.py
index 910820c..d4d21cb 100644
--- a/src/silx/gui/plot3d/SceneWidget.py
+++ b/src/silx/gui/plot3d/SceneWidget.py
@@ -42,7 +42,7 @@ from .scene import interaction
from ._model import SceneModel, visitQAbstractItemModel
from ._model.items import Item3DRow
-__all__ = ['items', 'SceneWidget']
+__all__ = ["items", "SceneWidget"]
class _SceneSelectionHighlightManager(object):
@@ -88,8 +88,7 @@ class _SceneSelectionHighlightManager(object):
else: # disabled
self.__unselectItem(current)
- selection.sigCurrentChanged.disconnect(
- self.__currentChanged)
+ selection.sigCurrentChanged.disconnect(self.__currentChanged)
def getSceneWidget(self):
"""Returns the SceneWidget this class controls highlight for.
@@ -101,7 +100,7 @@ class _SceneSelectionHighlightManager(object):
def __selectItem(self, current):
"""Highlight given item.
- :param ~silx.gui.plot3d.items.Item3D current: New current or None
+ :param ~silx.gui.plot3d.items.Item3D current: New current or None
"""
if current is None:
return
@@ -131,8 +130,9 @@ class _SceneSelectionHighlightManager(object):
# Restore bbox visibility and color
current.sigItemChanged.disconnect(self.__selectedChanged)
- if (self._previousBBoxState is not None and
- isinstance(current, items.DataItem3D)):
+ if self._previousBBoxState is not None and isinstance(
+ current, items.DataItem3D
+ ):
current.setBoundingBoxVisible(self._previousBBoxState)
current._setForegroundColor(sceneWidget.getForegroundColor())
@@ -160,10 +160,10 @@ class _SceneSelectionHighlightManager(object):
class HighlightMode(enum.Enum):
""":class:`SceneSelection` highlight modes"""
- NONE = 'noHighlight'
+ NONE = "noHighlight"
"""Do not highlight selected item"""
- BOUNDING_BOX = 'boundingBox'
+ BOUNDING_BOX = "boundingBox"
"""Highlight selected item bounding box"""
@@ -244,12 +244,10 @@ class SceneSelection(qt.QObject):
item.sigItemChanged.connect(self.__currentChanged)
self.__current = weakref.ref(item)
else:
- raise ValueError(
- 'Item is not in this SceneWidget: %s' % str(item))
+ raise ValueError("Item is not in this SceneWidget: %s" % str(item))
else:
- raise ValueError(
- 'Not an Item3D: %s' % str(item))
+ raise ValueError("Not an Item3D: %s" % str(item))
current = self.getCurrentItem()
self.sigCurrentChanged.emit(current, previous)
@@ -282,24 +280,29 @@ class SceneSelection(qt.QObject):
:raise ValueError: If the selection model does not correspond
to the same :class:`SceneWidget`
"""
- if (not isinstance(selectionModel, qt.QItemSelectionModel) or
- not isinstance(selectionModel.model(), SceneModel) or
- selectionModel.model().sceneWidget() is not self.parent()):
- raise ValueError("Expecting a QItemSelectionModel "
- "attached to the same SceneWidget")
+ if (
+ not isinstance(selectionModel, qt.QItemSelectionModel)
+ or not isinstance(selectionModel.model(), SceneModel)
+ or selectionModel.model().sceneWidget() is not self.parent()
+ ):
+ raise ValueError(
+ "Expecting a QItemSelectionModel " "attached to the same SceneWidget"
+ )
# Disconnect from previous selection model
previousSelectionModel = self._getSyncSelectionModel()
if previousSelectionModel is not None:
previousSelectionModel.selectionChanged.disconnect(
- self.__selectionModelSelectionChanged)
+ self.__selectionModelSelectionChanged
+ )
self.__selectionModel = selectionModel
if selectionModel is not None:
# Connect to new selection model
selectionModel.selectionChanged.connect(
- self.__selectionModelSelectionChanged)
+ self.__selectionModelSelectionChanged
+ )
self.__updateSelectionModel()
def __selectionModelSelectionChanged(self, selected, deselected):
@@ -341,15 +344,19 @@ class SceneSelection(qt.QObject):
model = selectionModel.model()
for index in visitQAbstractItemModel(model):
itemRow = index.internalPointer()
- if (isinstance(itemRow, Item3DRow) and
- itemRow.item() is currentItem and
- index.flags() & qt.Qt.ItemIsSelectable):
+ if (
+ isinstance(itemRow, Item3DRow)
+ and itemRow.item() is currentItem
+ and index.flags() & qt.Qt.ItemIsSelectable
+ ):
# This is the item we are looking for: select it in the model
self.__syncInProgress = True
selectionModel.select(
- index, qt.QItemSelectionModel.Clear |
- qt.QItemSelectionModel.Select |
- qt.QItemSelectionModel.Current)
+ index,
+ qt.QItemSelectionModel.Clear
+ | qt.QItemSelectionModel.Select
+ | qt.QItemSelectionModel.Current,
+ )
self.__syncInProgress = False
break
@@ -363,15 +370,14 @@ class SceneWidget(Plot3DWidget):
self._selection = None # Store lazy-loaded SceneSelection
self._items = []
- self._textColor = 1., 1., 1., 1.
- self._foregroundColor = 1., 1., 1., 1.
- self._highlightColor = 0.7, 0.7, 0., 1.
+ self._textColor = 1.0, 1.0, 1.0, 1.0
+ self._foregroundColor = 1.0, 1.0, 1.0, 1.0
+ self._highlightColor = 0.7, 0.7, 0.0, 1.0
self._sceneGroup = RootGroupWithAxesItem(parent=self)
- self._sceneGroup.setLabel('Data')
+ self._sceneGroup.setLabel("Data")
- self.viewport.scene.children.append(
- self._sceneGroup._getScenePrimitive())
+ self.viewport.scene.children.append(self._sceneGroup._getScenePrimitive())
def model(self):
"""Returns the model corresponding the scene of this widget
@@ -419,20 +425,21 @@ class SceneWidget(Plot3DWidget):
devicePixelRatio = self.getDevicePixelRatio()
for result in self.getSceneGroup().pickItems(
- x * devicePixelRatio, y * devicePixelRatio, condition):
+ x * devicePixelRatio, y * devicePixelRatio, condition
+ ):
yield result
# Interactive modes
def _handleSelectionChanged(self, current, previous):
"""Handle change of selection to update interactive mode"""
- if self.getInteractiveMode() == 'panSelectedPlane':
+ if self.getInteractiveMode() == "panSelectedPlane":
if isinstance(current, items.PlaneMixIn):
# Update pan plane to use new selected plane
- self.setInteractiveMode('panSelectedPlane')
+ self.setInteractiveMode("panSelectedPlane")
else: # Switch to rotate scene if new selection is not a plane
- self.setInteractiveMode('rotate')
+ self.setInteractiveMode("rotate")
def setInteractiveMode(self, mode):
"""Set the interactive mode.
@@ -443,26 +450,25 @@ class SceneWidget(Plot3DWidget):
:param str mode:
The interactive mode: 'rotate', 'pan', 'panSelectedPlane' or None
"""
- if self.getInteractiveMode() == 'panSelectedPlane':
- self.selection().sigCurrentChanged.disconnect(
- self._handleSelectionChanged)
+ if self.getInteractiveMode() == "panSelectedPlane":
+ self.selection().sigCurrentChanged.disconnect(self._handleSelectionChanged)
- if mode == 'panSelectedPlane':
+ if mode == "panSelectedPlane":
selected = self.selection().getCurrentItem()
if isinstance(selected, items.PlaneMixIn):
mode = interaction.PanPlaneZoomOnWheelControl(
self.viewport,
selected._getPlane(),
- mode='position',
+ mode="position",
orbitAroundCenter=False,
- scaleTransform=self._sceneScale)
+ scaleTransform=self._sceneScale,
+ )
- self.selection().sigCurrentChanged.connect(
- self._handleSelectionChanged)
+ self.selection().sigCurrentChanged.connect(self._handleSelectionChanged)
else: # No selected plane, fallback to rotate scene
- mode = 'rotate'
+ mode = "rotate"
super(SceneWidget, self).setInteractiveMode(mode)
@@ -472,7 +478,7 @@ class SceneWidget(Plot3DWidget):
:rtype: str
"""
if isinstance(self.eventHandler, interaction.PanPlaneZoomOnWheelControl):
- return 'panSelectedPlane'
+ return "panSelectedPlane"
else:
return super(SceneWidget, self).getInteractiveMode()
@@ -631,7 +637,7 @@ class SceneWidget(Plot3DWidget):
bbox = self._sceneGroup._getScenePrimitive()
bbox.tickColor = color
- self.sigStyleChanged.emit('textColor')
+ self.sigStyleChanged.emit("textColor")
def getForegroundColor(self):
"""Return color used for bounding box
@@ -657,7 +663,7 @@ class SceneWidget(Plot3DWidget):
if item is not selected:
item._setForegroundColor(color)
- self.sigStyleChanged.emit('foregroundColor')
+ self.sigStyleChanged.emit("foregroundColor")
def getHighlightColor(self):
"""Return color used for highlighted item bounding box
@@ -681,4 +687,4 @@ class SceneWidget(Plot3DWidget):
if selected is not None:
selected._setForegroundColor(color)
- self.sigStyleChanged.emit('highlightColor')
+ self.sigStyleChanged.emit("highlightColor")
diff --git a/src/silx/gui/plot3d/SceneWindow.py b/src/silx/gui/plot3d/SceneWindow.py
index d88cfa9..98c93fd 100644
--- a/src/silx/gui/plot3d/SceneWindow.py
+++ b/src/silx/gui/plot3d/SceneWindow.py
@@ -44,7 +44,7 @@ from .ParamTreeView import ParamTreeView
from . import items # noqa
-__all__ = ['items', 'SceneWidget', 'SceneWindow']
+__all__ = ["items", "SceneWidget", "SceneWindow"]
class _PanPlaneAction(InteractiveModeAction):
@@ -54,27 +54,24 @@ class _PanPlaneAction(InteractiveModeAction):
:param ~silx.gui.plot3d.Plot3DWidget.Plot3DWidget plot3d:
Plot3DWidget the action is associated with
"""
+
def __init__(self, parent, plot3d=None):
- super(_PanPlaneAction, self).__init__(
- parent, 'panSelectedPlane', plot3d)
- self.setIcon(icons.getQIcon('3d-plane-pan'))
- self.setText('Pan plane')
+ super(_PanPlaneAction, self).__init__(parent, "panSelectedPlane", plot3d)
+ self.setIcon(icons.getQIcon("3d-plane-pan"))
+ self.setText("Pan plane")
self.setCheckable(True)
- self.setToolTip(
- 'Pan selected plane. Press <b>Ctrl</b> to rotate the scene.')
+ self.setToolTip("Pan selected plane. Press <b>Ctrl</b> to rotate the scene.")
def _planeChanged(self, event):
"""Handle plane updates"""
- if event in (items.ItemChangedType.VISIBLE,
- items.ItemChangedType.POSITION):
+ if event in (items.ItemChangedType.VISIBLE, items.ItemChangedType.POSITION):
plane = self.sender()
- isPlaneInteractive = \
- plane._getPlane().plane.isPlane and plane.isVisible()
+ isPlaneInteractive = plane._getPlane().plane.isPlane and plane.isVisible()
if isPlaneInteractive != self.isEnabled():
self.setEnabled(isPlaneInteractive)
- mode = 'panSelectedPlane' if isPlaneInteractive else 'rotate'
+ mode = "panSelectedPlane" if isPlaneInteractive else "rotate"
self.getPlot3DWidget().setInteractiveMode(mode)
def _selectionChanged(self, current, previous):
@@ -85,24 +82,21 @@ class _PanPlaneAction(InteractiveModeAction):
if isinstance(current, items.PlaneMixIn):
current.sigItemChanged.connect(self._planeChanged)
self.setEnabled(True)
- self.getPlot3DWidget().setInteractiveMode('panSelectedPlane')
+ self.getPlot3DWidget().setInteractiveMode("panSelectedPlane")
else:
self.setEnabled(False)
def setPlot3DWidget(self, widget):
previous = self.getPlot3DWidget()
if isinstance(previous, SceneWidget):
- previous.selection().sigCurrentChanged.disconnect(
- self._selectionChanged)
- self._selectionChanged(
- None, previous.selection().getCurrentItem())
+ previous.selection().sigCurrentChanged.disconnect(self._selectionChanged)
+ self._selectionChanged(None, previous.selection().getCurrentItem())
super(_PanPlaneAction, self).setPlot3DWidget(widget)
if isinstance(widget, SceneWidget):
self._selectionChanged(widget.selection().getCurrentItem(), None)
- widget.selection().sigCurrentChanged.connect(
- self._selectionChanged)
+ widget.selection().sigCurrentChanged.connect(self._selectionChanged)
class SceneWindow(qt.QMainWindow):
@@ -128,16 +122,17 @@ class SceneWindow(qt.QMainWindow):
self._interactiveModeToolBar = InteractiveModeToolBar(parent=self)
panPlaneAction = _PanPlaneAction(self, plot3d=self._sceneWidget)
- self._interactiveModeToolBar.addAction(
- self._positionInfo.toggleAction())
+ self._interactiveModeToolBar.addAction(self._positionInfo.toggleAction())
self._interactiveModeToolBar.addAction(panPlaneAction)
self._viewpointToolBar = ViewpointToolBar(parent=self)
self._outputToolBar = OutputToolBar(parent=self)
- for toolbar in (self._interactiveModeToolBar,
- self._viewpointToolBar,
- self._outputToolBar):
+ for toolbar in (
+ self._interactiveModeToolBar,
+ self._viewpointToolBar,
+ self._outputToolBar,
+ ):
toolbar.setPlot3DWidget(self._sceneWidget)
self.addToolBar(toolbar)
self.addActions(toolbar.actions())
@@ -146,20 +141,18 @@ class SceneWindow(qt.QMainWindow):
self._paramTreeView.setModel(self._sceneWidget.model())
selectionModel = self._paramTreeView.selectionModel()
- self._sceneWidget.selection()._setSyncSelectionModel(
- selectionModel)
+ self._sceneWidget.selection()._setSyncSelectionModel(selectionModel)
paramDock = qt.QDockWidget()
- paramDock.setWindowTitle('Object parameters')
+ paramDock.setWindowTitle("Object parameters")
paramDock.setWidget(self._paramTreeView)
self.addDockWidget(qt.Qt.RightDockWidgetArea, paramDock)
self._sceneGroupResetWidget = GroupPropertiesWidget()
- self._sceneGroupResetWidget.setGroup(
- self._sceneWidget.getSceneGroup())
+ self._sceneGroupResetWidget.setGroup(self._sceneWidget.getSceneGroup())
resetDock = qt.QDockWidget()
- resetDock.setWindowTitle('Global parameters')
+ resetDock.setWindowTitle("Global parameters")
resetDock.setWidget(self._sceneGroupResetWidget)
self.addDockWidget(qt.Qt.RightDockWidgetArea, resetDock)
self.tabifyDockWidget(paramDock, resetDock)
diff --git a/src/silx/gui/plot3d/__init__.py b/src/silx/gui/plot3d/__init__.py
index e0cb688..470d37b 100644
--- a/src/silx/gui/plot3d/__init__.py
+++ b/src/silx/gui/plot3d/__init__.py
@@ -35,4 +35,4 @@ __date__ = "18/01/2017"
try:
import OpenGL as _OpenGL
except ImportError:
- raise ImportError('PyOpenGL is not installed')
+ raise ImportError("PyOpenGL is not installed")
diff --git a/src/silx/gui/plot3d/_model/core.py b/src/silx/gui/plot3d/_model/core.py
index 30d45ec..cb34ab9 100644
--- a/src/silx/gui/plot3d/_model/core.py
+++ b/src/silx/gui/plot3d/_model/core.py
@@ -242,7 +242,7 @@ class StaticRow(BaseRow):
:param children: Iterable of BaseRow to start with (not signaled)
"""
- def __init__(self, display=('', None), roles=None, children=()):
+ def __init__(self, display=("", None), roles=None, children=()):
super(StaticRow, self).__init__(children)
self._dataByRoles = {} if roles is None else roles
self._dataByRoles[qt.Qt.DisplayRole] = display
@@ -278,15 +278,16 @@ class ProxyRow(BaseRow):
:param editorHint: Data to provide as UserRole for editor selection/setup
"""
- def __init__(self,
- name='',
- fget=None,
- fset=None,
- notify=None,
- toModelData=None,
- fromModelData=None,
- editorHint=None):
-
+ def __init__(
+ self,
+ name="",
+ fget=None,
+ fset=None,
+ notify=None,
+ toModelData=None,
+ fromModelData=None,
+ editorHint=None,
+ ):
super(ProxyRow, self).__init__()
self.__name = name
self.__editorHint = editorHint
@@ -317,8 +318,9 @@ class ProxyRow(BaseRow):
elif column == 1:
if role == qt.Qt.UserRole: # EditorHint
return self.__editorHint
- elif role == qt.Qt.DisplayRole or (role == qt.Qt.EditRole and
- self._fset is not None):
+ elif role == qt.Qt.DisplayRole or (
+ role == qt.Qt.EditRole and self._fset is not None
+ ):
data = self._fget()
if self._toModelData is not None:
data = self._toModelData(data)
@@ -364,6 +366,6 @@ class AngleDegreeRow(ProxyRow):
def data(self, column, role):
if column == 1 and role == qt.Qt.DisplayRole:
- return u'%g°' % super(AngleDegreeRow, self).data(column, role)
+ return "%g°" % super(AngleDegreeRow, self).data(column, role)
else:
return super(AngleDegreeRow, self).data(column, role)
diff --git a/src/silx/gui/plot3d/_model/items.py b/src/silx/gui/plot3d/_model/items.py
index c6bf69a..8441be7 100644
--- a/src/silx/gui/plot3d/_model/items.py
+++ b/src/silx/gui/plot3d/_model/items.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -30,7 +30,6 @@ __license__ = "MIT"
__date__ = "24/04/2018"
-from collections import OrderedDict
import functools
import logging
import weakref
@@ -74,15 +73,17 @@ class ItemProxyRow(ProxyRow):
:param editorHint: Data to provide as UserRole for editor selection/setup
"""
- def __init__(self,
- item,
- name='',
- fget=None,
- fset=None,
- events=None,
- toModelData=None,
- fromModelData=None,
- editorHint=None):
+ def __init__(
+ self,
+ item,
+ name="",
+ fget=None,
+ fset=None,
+ events=None,
+ toModelData=None,
+ fromModelData=None,
+ editorHint=None,
+ ):
super(ItemProxyRow, self).__init__(
name=name,
fget=fget,
@@ -90,10 +91,10 @@ class ItemProxyRow(ProxyRow):
notify=None,
toModelData=toModelData,
fromModelData=fromModelData,
- editorHint=editorHint)
+ editorHint=editorHint,
+ )
- if isinstance(events, (items.ItemChangedType,
- items.Item3DChangedType)):
+ if isinstance(events, (items.ItemChangedType, items.Item3DChangedType)):
events = (events,)
self.__events = events
item.sigItemChanged.connect(self._itemChanged)
@@ -122,8 +123,7 @@ class ItemAngleDegreeRow(AngleDegreeRow, ItemProxyRow):
class _DirectionalLightProxy(qt.QObject):
- """Proxy to handle directional light with angles rather than vector.
- """
+ """Proxy to handle directional light with angles rather than vector."""
sigAzimuthAngleChanged = qt.Signal()
"""Signal sent when the azimuth angle has changed."""
@@ -184,11 +184,11 @@ class _DirectionalLightProxy(qt.QObject):
"""Handle light direction update in the scene"""
# Invert direction to manipulate the 'source' pointing to
# the center of the viewport
- x, y, z = - self._light.direction
+ x, y, z = -self._light.direction
# Horizontal plane is plane xz
azimuth = int(round(numpy.degrees(numpy.arctan2(x, z))))
- altitude = int(round(numpy.degrees(numpy.pi/2. - numpy.arccos(y))))
+ altitude = int(round(numpy.degrees(numpy.pi / 2.0 - numpy.arccos(y))))
if azimuth != self.getAzimuthAngle():
self.setAzimuthAngle(azimuth)
@@ -199,12 +199,12 @@ class _DirectionalLightProxy(qt.QObject):
def _updateLight(self):
"""Update light direction in the scene"""
azimuth = numpy.radians(self._azimuth)
- delta = numpy.pi/2. - numpy.radians(self._altitude)
- if delta == 0.: # Avoids zenith position
+ delta = numpy.pi / 2.0 - numpy.radians(self._altitude)
+ if delta == 0.0: # Avoids zenith position
delta = 0.0001
- z = - numpy.sin(delta) * numpy.cos(azimuth)
- x = - numpy.sin(delta) * numpy.sin(azimuth)
- y = - numpy.cos(delta)
+ z = -numpy.sin(delta) * numpy.cos(azimuth)
+ x = -numpy.sin(delta) * numpy.sin(azimuth)
+ y = -numpy.cos(delta)
self._light.direction = x, y, z
@@ -216,69 +216,87 @@ class Settings(StaticRow):
def __init__(self, sceneWidget):
background = ColorProxyRow(
- name='Background',
+ name="Background",
fget=sceneWidget.getBackgroundColor,
fset=sceneWidget.setBackgroundColor,
- notify=sceneWidget.sigStyleChanged)
+ notify=sceneWidget.sigStyleChanged,
+ )
foreground = ColorProxyRow(
- name='Foreground',
+ name="Foreground",
fget=sceneWidget.getForegroundColor,
fset=sceneWidget.setForegroundColor,
- notify=sceneWidget.sigStyleChanged)
+ notify=sceneWidget.sigStyleChanged,
+ )
text = ColorProxyRow(
- name='Text',
+ name="Text",
fget=sceneWidget.getTextColor,
fset=sceneWidget.setTextColor,
- notify=sceneWidget.sigStyleChanged)
+ notify=sceneWidget.sigStyleChanged,
+ )
highlight = ColorProxyRow(
- name='Highlight',
+ name="Highlight",
fget=sceneWidget.getHighlightColor,
fset=sceneWidget.setHighlightColor,
- notify=sceneWidget.sigStyleChanged)
+ notify=sceneWidget.sigStyleChanged,
+ )
axesIndicator = ProxyRow(
- name='Axes Indicator',
+ name="Axes Indicator",
fget=sceneWidget.isOrientationIndicatorVisible,
fset=sceneWidget.setOrientationIndicatorVisible,
- notify=sceneWidget.sigStyleChanged)
+ notify=sceneWidget.sigStyleChanged,
+ )
# Light direction
self._lightProxy = _DirectionalLightProxy(sceneWidget.viewport.light)
azimuthNode = ProxyRow(
- name='Azimuth',
+ name="Azimuth",
fget=self._lightProxy.getAzimuthAngle,
fset=self._lightProxy.setAzimuthAngle,
notify=self._lightProxy.sigAzimuthAngleChanged,
- editorHint=(-90, 90))
+ editorHint=(-90, 90),
+ )
altitudeNode = ProxyRow(
- name='Altitude',
+ name="Altitude",
fget=self._lightProxy.getAltitudeAngle,
fset=self._lightProxy.setAltitudeAngle,
notify=self._lightProxy.sigAltitudeAngleChanged,
- editorHint=(-90, 90))
+ editorHint=(-90, 90),
+ )
- lightDirection = StaticRow(('Light Direction', None),
- children=(azimuthNode, altitudeNode))
+ lightDirection = StaticRow(
+ ("Light Direction", None), children=(azimuthNode, altitudeNode)
+ )
# Fog
fog = ProxyRow(
- name='Fog',
+ name="Fog",
fget=sceneWidget.getFogMode,
fset=sceneWidget.setFogMode,
notify=sceneWidget.sigStyleChanged,
toModelData=lambda mode: mode is Plot3DWidget.FogMode.LINEAR,
- fromModelData=lambda mode: Plot3DWidget.FogMode.LINEAR if mode else Plot3DWidget.FogMode.NONE)
+ fromModelData=lambda mode: Plot3DWidget.FogMode.LINEAR
+ if mode
+ else Plot3DWidget.FogMode.NONE,
+ )
# Settings row
- children = (background, foreground, text, highlight,
- axesIndicator, lightDirection, fog)
- super(Settings, self).__init__(('Settings', None), children=children)
+ children = (
+ background,
+ foreground,
+ text,
+ highlight,
+ axesIndicator,
+ lightDirection,
+ fog,
+ )
+ super(Settings, self).__init__(("Settings", None), children=children)
class Item3DRow(BaseRow):
@@ -296,8 +314,8 @@ class Item3DRow(BaseRow):
super(Item3DRow, self).__init__()
self.setFlags(
- self.flags(0) | qt.Qt.ItemIsUserCheckable | qt.Qt.ItemIsSelectable,
- 0)
+ self.flags(0) | qt.Qt.ItemIsUserCheckable | qt.Qt.ItemIsSelectable, 0
+ )
self.setFlags(self.flags(1) | qt.Qt.ItemIsSelectable, 1)
self._item = weakref.ref(item)
@@ -325,12 +343,12 @@ class Item3DRow(BaseRow):
return qt.Qt.Unchecked
elif role == qt.Qt.DecorationRole:
- return icons.getQIcon('item-3dim')
+ return icons.getQIcon("item-3dim")
elif role == qt.Qt.DisplayRole:
if self.__name is None:
item = self.item()
- return '' if item is None else item.getLabel()
+ return "" if item is None else item.getLabel()
else:
return self.__name
@@ -340,7 +358,7 @@ class Item3DRow(BaseRow):
if column == 0 and role == qt.Qt.CheckStateRole:
item = self.item()
if item is not None:
- item.setVisible(value == qt.Qt.Checked)
+ item.setVisible(qt.Qt.CheckState(value) == qt.Qt.Checked)
return True
else:
return False
@@ -359,10 +377,11 @@ class DataItem3DBoundingBoxRow(ItemProxyRow):
def __init__(self, item):
super(DataItem3DBoundingBoxRow, self).__init__(
item=item,
- name='Bounding box',
+ name="Bounding box",
fget=item.isBoundingBoxVisible,
fset=item.setBoundingBoxVisible,
- events=items.Item3DChangedType.BOUNDING_BOX_VISIBLE)
+ events=items.Item3DChangedType.BOUNDING_BOX_VISIBLE,
+ )
class MatrixProxyRow(ItemProxyRow):
@@ -378,10 +397,11 @@ class MatrixProxyRow(ItemProxyRow):
super(MatrixProxyRow, self).__init__(
item=item,
- name='',
+ name="",
fget=self._getMatrixRow,
fset=self._setMatrixRow,
- events=items.Item3DChangedType.TRANSFORM)
+ events=items.Item3DChangedType.TRANSFORM,
+ )
def _getMatrixRow(self):
"""Returns the matrix row.
@@ -422,19 +442,20 @@ class DataItem3DTransformRow(StaticRow):
:param DataItem3D item: The item for which to display/control transform
"""
- _ROTATION_CENTER_OPTIONS = 'Origin', 'Lower', 'Center', 'Upper'
+ _ROTATION_CENTER_OPTIONS = "Origin", "Lower", "Center", "Upper"
def __init__(self, item):
- super(DataItem3DTransformRow, self).__init__(('Transform', None))
+ super(DataItem3DTransformRow, self).__init__(("Transform", None))
self._item = weakref.ref(item)
translation = ItemProxyRow(
item=item,
- name='Translation',
+ name="Translation",
fget=item.getTranslation,
fset=self._setTranslation,
events=items.Item3DChangedType.TRANSFORM,
- toModelData=lambda data: qt.QVector3D(*data))
+ toModelData=lambda data: qt.QVector3D(*data),
+ )
self.addRow(translation)
# Here to keep a reference
@@ -443,69 +464,80 @@ class DataItem3DTransformRow(StaticRow):
self._zSetCenter = functools.partial(self._setCenter, index=2)
rotateCenter = StaticRow(
- ('Center', None),
+ ("Center", None),
children=(
- ItemProxyRow(item=item,
- name='X axis',
- fget=item.getRotationCenter,
- fset=self._xSetCenter,
- events=items.Item3DChangedType.TRANSFORM,
- toModelData=functools.partial(
- self._centerToModelData, index=0),
- editorHint=self._ROTATION_CENTER_OPTIONS),
- ItemProxyRow(item=item,
- name='Y axis',
- fget=item.getRotationCenter,
- fset=self._ySetCenter,
- events=items.Item3DChangedType.TRANSFORM,
- toModelData=functools.partial(
- self._centerToModelData, index=1),
- editorHint=self._ROTATION_CENTER_OPTIONS),
- ItemProxyRow(item=item,
- name='Z axis',
- fget=item.getRotationCenter,
- fset=self._zSetCenter,
- events=items.Item3DChangedType.TRANSFORM,
- toModelData=functools.partial(
- self._centerToModelData, index=2),
- editorHint=self._ROTATION_CENTER_OPTIONS),
- ))
+ ItemProxyRow(
+ item=item,
+ name="X axis",
+ fget=item.getRotationCenter,
+ fset=self._xSetCenter,
+ events=items.Item3DChangedType.TRANSFORM,
+ toModelData=functools.partial(self._centerToModelData, index=0),
+ editorHint=self._ROTATION_CENTER_OPTIONS,
+ ),
+ ItemProxyRow(
+ item=item,
+ name="Y axis",
+ fget=item.getRotationCenter,
+ fset=self._ySetCenter,
+ events=items.Item3DChangedType.TRANSFORM,
+ toModelData=functools.partial(self._centerToModelData, index=1),
+ editorHint=self._ROTATION_CENTER_OPTIONS,
+ ),
+ ItemProxyRow(
+ item=item,
+ name="Z axis",
+ fget=item.getRotationCenter,
+ fset=self._zSetCenter,
+ events=items.Item3DChangedType.TRANSFORM,
+ toModelData=functools.partial(self._centerToModelData, index=2),
+ editorHint=self._ROTATION_CENTER_OPTIONS,
+ ),
+ ),
+ )
rotate = StaticRow(
- ('Rotation', None),
+ ("Rotation", None),
children=(
ItemAngleDegreeRow(
item=item,
- name='Angle',
+ name="Angle",
fget=item.getRotation,
fset=self._setAngle,
events=items.Item3DChangedType.TRANSFORM,
- toModelData=lambda data: data[0]),
+ toModelData=lambda data: data[0],
+ ),
ItemProxyRow(
item=item,
- name='Axis',
+ name="Axis",
fget=item.getRotation,
fset=self._setAxis,
events=items.Item3DChangedType.TRANSFORM,
- toModelData=lambda data: qt.QVector3D(*data[1])),
- rotateCenter
- ))
+ toModelData=lambda data: qt.QVector3D(*data[1]),
+ ),
+ rotateCenter,
+ ),
+ )
self.addRow(rotate)
scale = ItemProxyRow(
item=item,
- name='Scale',
+ name="Scale",
fget=item.getScale,
fset=self._setScale,
events=items.Item3DChangedType.TRANSFORM,
- toModelData=lambda data: qt.QVector3D(*data))
+ toModelData=lambda data: qt.QVector3D(*data),
+ )
self.addRow(scale)
matrix = StaticRow(
- ('Matrix', None),
- children=(MatrixProxyRow(item, 0),
- MatrixProxyRow(item, 1),
- MatrixProxyRow(item, 2)))
+ ("Matrix", None),
+ children=(
+ MatrixProxyRow(item, 0),
+ MatrixProxyRow(item, 1),
+ MatrixProxyRow(item, 2),
+ ),
+ )
self.addRow(matrix)
def item(self):
@@ -522,8 +554,8 @@ class DataItem3DTransformRow(StaticRow):
value = center[index]
if isinstance(value, str):
return value.title()
- elif value == 0.:
- return 'Origin'
+ elif value == 0.0:
+ return "Origin"
else:
return str(value)
@@ -535,8 +567,8 @@ class DataItem3DTransformRow(StaticRow):
"""
item = self.item()
if item is not None:
- if value == 'Origin':
- value = 0.
+ if value == "Origin":
+ value = 0.0
elif value not in self._ROTATION_CENTER_OPTIONS:
value = float(value)
else:
@@ -583,8 +615,8 @@ class DataItem3DTransformRow(StaticRow):
item = self.item()
if item is not None:
sx, sy, sz = scale.x(), scale.y(), scale.z()
- if sx == 0. or sy == 0. or sz == 0.:
- _logger.warning('Cannot set scale to 0: ignored')
+ if sx == 0.0 or sy == 0.0 or sz == 0.0:
+ _logger.warning("Cannot set scale to 0: ignored")
else:
item.setScale(scale.x(), scale.y(), scale.z())
@@ -650,13 +682,14 @@ class InterpolationRow(ItemProxyRow):
modes = [mode.title() for mode in item.INTERPOLATION_MODES]
super(InterpolationRow, self).__init__(
item=item,
- name='Interpolation',
+ name="Interpolation",
fget=item.getInterpolation,
fset=item.setInterpolation,
events=items.Item3DChangedType.INTERPOLATION,
toModelData=lambda mode: mode.title(),
fromModelData=lambda mode: mode.lower(),
- editorHint=modes)
+ editorHint=modes,
+ )
class _ColormapBaseProxyRow(ProxyRow):
@@ -736,15 +769,14 @@ class _ColormapBoundRow(_ColormapBaseProxyRow):
def __init__(self, item, name, index):
self._index = index
_ColormapBaseProxyRow.__init__(
- self,
- item,
- name=name,
- fget=self._getBound,
- fset=self._setBound)
+ self, item, name=name, fget=self._getBound, fset=self._setBound
+ )
- self.setToolTip('Colormap %s bound:\n'
- 'Check to set bound manually, '
- 'uncheck for autoscale' % name.lower())
+ self.setToolTip(
+ "Colormap %s bound:\n"
+ "Check to set bound manually, "
+ "uncheck for autoscale" % name.lower()
+ )
def _getRawBound(self):
"""Proxy to get raw colormap bound
@@ -770,7 +802,7 @@ class _ColormapBoundRow(_ColormapBaseProxyRow):
bound = self._getColormapRange()[self._index]
return bound
else:
- return 1. # Fallback
+ return 1.0 # Fallback
def _setBound(self, value):
"""Proxy to set colormap bound.
@@ -816,7 +848,11 @@ class _ColormapBoundRow(_ColormapBaseProxyRow):
def setData(self, column, value, role):
if column == 0 and role == qt.Qt.CheckStateRole:
if self._colormap is not None:
- bound = self._getBound() if value == qt.Qt.Checked else None
+ bound = (
+ self._getBound()
+ if qt.Qt.CheckState(value) == qt.Qt.Checked
+ else None
+ )
self._setBound(bound)
return True
else:
@@ -838,10 +874,13 @@ class _ColormapGammaRow(_ColormapBaseProxyRow):
item,
name="Gamma",
fget=self._getGammaNormalizationParameter,
- fset=self._setGammaNormalizationParameter)
+ fset=self._setGammaNormalizationParameter,
+ )
- self.setToolTip('Colormap gamma correction parameter:\n'
- 'Only meaningful for gamma normalization.')
+ self.setToolTip(
+ "Colormap gamma correction parameter:\n"
+ "Only meaningful for gamma normalization."
+ )
def _getGammaNormalizationParameter(self):
"""Proxy for :meth:`Colormap.getGammaNormalizationParameter`"""
@@ -860,11 +899,11 @@ class _ColormapGammaRow(_ColormapBaseProxyRow):
if self._colormap is not None:
return self._colormap.getNormalization()
else:
- return ''
+ return ""
def flags(self, column):
if column in (0, 1):
- if self._getNormalization() == 'gamma':
+ if self._getNormalization() == "gamma":
flags = qt.Qt.ItemIsEditable | qt.Qt.ItemIsEnabled
else:
flags = qt.Qt.NoItemFlags # Disabled if not gamma correction
@@ -881,10 +920,7 @@ class ColormapRow(_ColormapBaseProxyRow):
"""
def __init__(self, item):
- super(ColormapRow, self).__init__(
- item,
- name='Colormap',
- fget=self._get)
+ super(ColormapRow, self).__init__(item, name="Colormap", fget=self._get)
self._colormapImage = None
@@ -892,33 +928,42 @@ class ColormapRow(_ColormapBaseProxyRow):
for cmap in preferredColormaps():
self._colormapsMapping[cmap.title()] = cmap
- self.addRow(ProxyRow(
- name='Name',
- fget=self._getName,
- fset=self._setName,
- notify=self._sigColormapChanged,
- editorHint=list(self._colormapsMapping.keys())))
+ self.addRow(
+ ProxyRow(
+ name="Name",
+ fget=self._getName,
+ fset=self._setName,
+ notify=self._sigColormapChanged,
+ editorHint=list(self._colormapsMapping.keys()),
+ )
+ )
norms = [norm.title() for norm in self._colormap.NORMALIZATIONS]
- self.addRow(ProxyRow(
- name='Normalization',
- fget=self._getNormalization,
- fset=self._setNormalization,
- notify=self._sigColormapChanged,
- editorHint=norms))
+ self.addRow(
+ ProxyRow(
+ name="Normalization",
+ fget=self._getNormalization,
+ fset=self._setNormalization,
+ notify=self._sigColormapChanged,
+ editorHint=norms,
+ )
+ )
self.addRow(_ColormapGammaRow(item))
modes = [mode.title() for mode in self._colormap.AUTOSCALE_MODES]
- self.addRow(ProxyRow(
- name='Autoscale Mode',
- fget=self._getAutoscaleMode,
- fset=self._setAutoscaleMode,
- notify=self._sigColormapChanged,
- editorHint=modes))
-
- self.addRow(_ColormapBoundRow(item, name='Min.', index=0))
- self.addRow(_ColormapBoundRow(item, name='Max.', index=1))
+ self.addRow(
+ ProxyRow(
+ name="Autoscale Mode",
+ fget=self._getAutoscaleMode,
+ fset=self._setAutoscaleMode,
+ notify=self._sigColormapChanged,
+ editorHint=modes,
+ )
+ )
+
+ self.addRow(_ColormapBoundRow(item, name="Min.", index=0))
+ self.addRow(_ColormapBoundRow(item, name="Max.", index=1))
self._sigColormapChanged.connect(self._updateColormapImage)
@@ -942,7 +987,7 @@ class ColormapRow(_ColormapBaseProxyRow):
if self._colormap is not None and self._colormap.getName() is not None:
return self._colormap.getName().title()
else:
- return ''
+ return ""
def _setName(self, name):
"""Proxy for :meth:`Colormap.setName`"""
@@ -956,7 +1001,7 @@ class ColormapRow(_ColormapBaseProxyRow):
if self._colormap is not None:
return self._colormap.getNormalization().title()
else:
- return ''
+ return ""
def _setNormalization(self, normalization):
"""Proxy for :meth:`Colormap.setNormalization`"""
@@ -968,7 +1013,7 @@ class ColormapRow(_ColormapBaseProxyRow):
if self._colormap is not None:
return self._colormap.getAutoscaleMode().title()
else:
- return ''
+ return ""
def _setAutoscaleMode(self, mode):
"""Proxy for :meth:`Colormap.setAutoscaleMode`"""
@@ -1001,11 +1046,12 @@ class SymbolRow(ItemProxyRow):
names = [item.getSymbolName(s) for s in item.getSupportedSymbols()]
super(SymbolRow, self).__init__(
item=item,
- name='Marker',
+ name="Marker",
fget=item.getSymbolName,
fset=item.setSymbol,
events=items.ItemChangedType.SYMBOL,
- editorHint=names)
+ editorHint=names,
+ )
class SymbolSizeRow(ItemProxyRow):
@@ -1017,11 +1063,12 @@ class SymbolSizeRow(ItemProxyRow):
def __init__(self, item):
super(SymbolSizeRow, self).__init__(
item=item,
- name='Marker size',
+ name="Marker size",
fget=item.getSymbolSize,
fset=item.setSymbolSize,
events=items.ItemChangedType.SYMBOL_SIZE,
- editorHint=(1, 20)) # TODO link with OpenGL max point size
+ editorHint=(1, 20),
+ ) # TODO link with OpenGL max point size
class PlaneEquationRow(ItemProxyRow):
@@ -1033,12 +1080,13 @@ class PlaneEquationRow(ItemProxyRow):
def __init__(self, item):
super(PlaneEquationRow, self).__init__(
item=item,
- name='Equation',
+ name="Equation",
fget=item.getParameters,
fset=item.setParameters,
events=items.ItemChangedType.POSITION,
toModelData=lambda data: qt.QVector4D(*data),
- fromModelData=lambda data: (data.x(), data.y(), data.z(), data.w()))
+ fromModelData=lambda data: (data.x(), data.y(), data.z(), data.w()),
+ )
self._item = weakref.ref(item)
def data(self, column, role):
@@ -1046,8 +1094,12 @@ class PlaneEquationRow(ItemProxyRow):
item = self._item()
if item is not None:
params = item.getParameters()
- return ('%gx %+gy %+gz %+g = 0' %
- (params[0], params[1], params[2], params[3]))
+ return "%gx %+gy %+gz %+g = 0" % (
+ params[0],
+ params[1],
+ params[2],
+ params[3],
+ )
return super(PlaneEquationRow, self).data(column, role)
@@ -1057,26 +1109,33 @@ class PlaneRow(ItemProxyRow):
:param Item3D item: Scene item with plane equation property
"""
- _PLANES = OrderedDict((('Plane 0', (1., 0., 0.)),
- ('Plane 1', (0., 1., 0.)),
- ('Plane 2', (0., 0., 1.)),
- ('-', None)))
+ _PLANES = dict(
+ (
+ ("Plane 0", (1.0, 0.0, 0.0)),
+ ("Plane 1", (0.0, 1.0, 0.0)),
+ ("Plane 2", (0.0, 0.0, 1.0)),
+ ("-", None),
+ )
+ )
"""Mapping of plane names to normals"""
- _PLANE_ICONS = {'Plane 0': '3d-plane-normal-x',
- 'Plane 1': '3d-plane-normal-y',
- 'Plane 2': '3d-plane-normal-z',
- '-': '3d-plane'}
+ _PLANE_ICONS = {
+ "Plane 0": "3d-plane-normal-x",
+ "Plane 1": "3d-plane-normal-y",
+ "Plane 2": "3d-plane-normal-z",
+ "-": "3d-plane",
+ }
"""Mapping of plane names to normals"""
def __init__(self, item):
super(PlaneRow, self).__init__(
item=item,
- name='Plane',
+ name="Plane",
fget=self.__getPlaneName,
fset=self.__setPlaneName,
events=items.ItemChangedType.POSITION,
- editorHint=tuple(self._PLANES.keys()))
+ editorHint=tuple(self._PLANES.keys()),
+ )
self._item = weakref.ref(item)
self._lastName = None
@@ -1101,7 +1160,7 @@ class PlaneRow(ItemProxyRow):
for name, normal in self._PLANES.items():
if numpy.array_equal(planeNormal, normal):
return name
- return '-'
+ return "-"
def __setPlaneName(self, data):
"""Set plane normal according to given plane name
@@ -1129,18 +1188,20 @@ class ComplexModeRow(ItemProxyRow):
:param Item3D item: Scene item with symbol property
"""
- def __init__(self, item, name='Mode'):
- names = [m.value.replace('_', ' ').title()
- for m in item.supportedComplexModes()]
+ def __init__(self, item, name="Mode"):
+ names = [
+ m.value.replace("_", " ").title() for m in item.supportedComplexModes()
+ ]
super(ComplexModeRow, self).__init__(
item=item,
name=name,
fget=item.getComplexMode,
fset=item.setComplexMode,
events=items.ItemChangedType.COMPLEX_MODE,
- toModelData=lambda data: data.value.replace('_', ' ').title(),
- fromModelData=lambda data: data.lower().replace(' ', '_'),
- editorHint=names)
+ toModelData=lambda data: data.value.replace("_", " ").title(),
+ fromModelData=lambda data: data.lower().replace(" ", "_"),
+ editorHint=names,
+ )
class RemoveIsosurfaceRow(BaseRow):
@@ -1161,7 +1222,7 @@ class RemoveIsosurfaceRow(BaseRow):
layout.setSpacing(0)
removeBtn = qt.QToolButton()
- removeBtn.setText('Delete')
+ removeBtn.setText("Delete")
removeBtn.setToolButtonStyle(qt.Qt.ToolButtonTextOnly)
layout.addWidget(removeBtn)
removeBtn.clicked.connect(self._removeClicked)
@@ -1216,28 +1277,37 @@ class IsosurfaceRow(Item3DRow):
item.sigItemChanged.connect(self._levelChanged)
- self.addRow(ItemProxyRow(
- item=item,
- name='Level',
- fget=self._getValueForLevelSlider,
- fset=self._setLevelFromSliderValue,
- events=items.Item3DChangedType.ISO_LEVEL,
- editorHint=self._LEVEL_SLIDER_RANGE))
-
- self.addRow(ItemColorProxyRow(
- item=item,
- name='Color',
- fget=self._rgbColor,
- fset=self._setRgbColor,
- events=items.ItemChangedType.COLOR))
-
- self.addRow(ItemProxyRow(
- item=item,
- name='Opacity',
- fget=self._opacity,
- fset=self._setOpacity,
- events=items.ItemChangedType.COLOR,
- editorHint=(0, 255)))
+ self.addRow(
+ ItemProxyRow(
+ item=item,
+ name="Level",
+ fget=self._getValueForLevelSlider,
+ fset=self._setLevelFromSliderValue,
+ events=items.Item3DChangedType.ISO_LEVEL,
+ editorHint=self._LEVEL_SLIDER_RANGE,
+ )
+ )
+
+ self.addRow(
+ ItemColorProxyRow(
+ item=item,
+ name="Color",
+ fget=self._rgbColor,
+ fset=self._setRgbColor,
+ events=items.ItemChangedType.COLOR,
+ )
+ )
+
+ self.addRow(
+ ItemProxyRow(
+ item=item,
+ name="Opacity",
+ fget=self._opacity,
+ fset=self._setOpacity,
+ events=items.ItemChangedType.COLOR,
+ editorHint=(0, 255),
+ )
+ )
self.addRow(RemoveIsosurfaceRow(item))
@@ -1256,7 +1326,7 @@ class IsosurfaceRow(Item3DRow):
if dataMax != dataMin:
offset = (item.getLevel() - dataMin) / (dataMax - dataMin)
else:
- offset = 0.
+ offset = 0.0
sliderMin, sliderMax = self._LEVEL_SLIDER_RANGE
value = sliderMin + (sliderMax - sliderMin) * offset
@@ -1342,8 +1412,8 @@ class IsosurfaceRow(Item3DRow):
return self._rgbColor()
elif column == 1 and role in (qt.Qt.DisplayRole, qt.Qt.EditRole):
- item = self.item()
- return None if item is None else item.getLevel()
+ item = self.item()
+ return None if item is None else item.getLevel()
return super(IsosurfaceRow, self).data(column, role)
@@ -1363,9 +1433,11 @@ class ComplexIsosurfaceRow(IsosurfaceRow):
:param ComplexIsosurface item:
"""
- _EVENTS = (items.ItemChangedType.VISIBLE,
- items.ItemChangedType.COLOR,
- items.ItemChangedType.COMPLEX_MODE)
+ _EVENTS = (
+ items.ItemChangedType.VISIBLE,
+ items.ItemChangedType.COLOR,
+ items.ItemChangedType.COMPLEX_MODE,
+ )
"""Events for which to update the first column in the tree"""
def __init__(self, item):
@@ -1415,8 +1487,10 @@ class ComplexIsosurfaceRow(IsosurfaceRow):
def data(self, column, role):
if column == 0 and role == qt.Qt.DecorationRole:
item = self.item()
- if (item is not None and
- item.getComplexMode() != items.ComplexMixIn.ComplexMode.NONE):
+ if (
+ item is not None
+ and item.getComplexMode() != items.ComplexMixIn.ComplexMode.NONE
+ ):
return self._colormapRow.getColormapImage()
return super(ComplexIsosurfaceRow, self).data(column, role)
@@ -1441,7 +1515,7 @@ class AddIsosurfaceRow(BaseRow):
layout.setSpacing(0)
addBtn = qt.QToolButton()
- addBtn.setText('+')
+ addBtn.setText("+")
addBtn.setToolButtonStyle(qt.Qt.ToolButtonTextOnly)
layout.addWidget(addBtn)
addBtn.clicked.connect(self._addClicked)
@@ -1474,11 +1548,9 @@ class AddIsosurfaceRow(BaseRow):
if volume is not None:
dataRange = volume.getDataRange()
if dataRange is None:
- dataRange = 0., 1.
+ dataRange = 0.0, 1.0
- volume.addIsosurface(
- numpy.mean((dataRange[0], dataRange[-1])),
- '#0000FF')
+ volume.addIsosurface(numpy.mean((dataRange[0], dataRange[-1])), "#0000FF")
class VolumeIsoSurfacesRow(StaticRow):
@@ -1489,8 +1561,7 @@ class VolumeIsoSurfacesRow(StaticRow):
"""
def __init__(self, volume):
- super(VolumeIsoSurfacesRow, self).__init__(
- ('Isosurfaces', None))
+ super(VolumeIsoSurfacesRow, self).__init__(("Isosurfaces", None))
self._volume = weakref.ref(volume)
volume.sigIsosurfaceAdded.connect(self._isosurfaceAdded)
@@ -1551,7 +1622,7 @@ class Scatter2DPropertyMixInRow(object):
"""
def __init__(self, item, propertyName):
- assert propertyName in ('lineWidth', 'symbol', 'symbolSize')
+ assert propertyName in ("lineWidth", "symbol", "symbolSize")
self.__propertyName = propertyName
self.__isEnabled = item.isPropertyEnabled(propertyName)
@@ -1600,7 +1671,7 @@ class Scatter2DSymbolRow(Scatter2DPropertyMixInRow, SymbolRow):
def __init__(self, item):
SymbolRow.__init__(self, item)
- Scatter2DPropertyMixInRow.__init__(self, item, 'symbol')
+ Scatter2DPropertyMixInRow.__init__(self, item, "symbol")
class Scatter2DSymbolSizeRow(Scatter2DPropertyMixInRow, SymbolSizeRow):
@@ -1613,7 +1684,7 @@ class Scatter2DSymbolSizeRow(Scatter2DPropertyMixInRow, SymbolSizeRow):
def __init__(self, item):
SymbolSizeRow.__init__(self, item)
- Scatter2DPropertyMixInRow.__init__(self, item, 'symbolSize')
+ Scatter2DPropertyMixInRow.__init__(self, item, "symbolSize")
class Scatter2DLineWidth(Scatter2DPropertyMixInRow, ItemProxyRow):
@@ -1626,14 +1697,16 @@ class Scatter2DLineWidth(Scatter2DPropertyMixInRow, ItemProxyRow):
def __init__(self, item):
# TODO link editorHint with OpenGL max line width
- ItemProxyRow.__init__(self,
- item=item,
- name='Line width',
- fget=item.getLineWidth,
- fset=item.setLineWidth,
- events=items.ItemChangedType.LINE_WIDTH,
- editorHint=(1, 10))
- Scatter2DPropertyMixInRow.__init__(self, item, 'lineWidth')
+ ItemProxyRow.__init__(
+ self,
+ item=item,
+ name="Line width",
+ fget=item.getLineWidth,
+ fset=item.setLineWidth,
+ events=items.ItemChangedType.LINE_WIDTH,
+ editorHint=(1, 10),
+ )
+ Scatter2DPropertyMixInRow.__init__(self, item, "lineWidth")
def initScatter2DNode(node, item):
@@ -1642,22 +1715,28 @@ def initScatter2DNode(node, item):
:param Item3DRow node: The model node to setup
:param Scatter2D item: The Scatter2D the node is representing
"""
- node.addRow(ItemProxyRow(
- item=item,
- name='Mode',
- fget=item.getVisualization,
- fset=item.setVisualization,
- events=items.ItemChangedType.VISUALIZATION_MODE,
- editorHint=[m.value.title() for m in item.supportedVisualizations()],
- toModelData=lambda data: data.value.title(),
- fromModelData=lambda data: data.lower()))
-
- node.addRow(ItemProxyRow(
- item=item,
- name='Height map',
- fget=item.isHeightMap,
- fset=item.setHeightMap,
- events=items.Item3DChangedType.HEIGHT_MAP))
+ node.addRow(
+ ItemProxyRow(
+ item=item,
+ name="Mode",
+ fget=item.getVisualization,
+ fset=item.setVisualization,
+ events=items.ItemChangedType.VISUALIZATION_MODE,
+ editorHint=[m.value.title() for m in item.supportedVisualizations()],
+ toModelData=lambda data: data.value.title(),
+ fromModelData=lambda data: data.lower(),
+ )
+ )
+
+ node.addRow(
+ ItemProxyRow(
+ item=item,
+ name="Height map",
+ fget=item.isHeightMap,
+ fset=item.setHeightMap,
+ events=items.Item3DChangedType.HEIGHT_MAP,
+ )
+ )
node.addRow(ColormapRow(item))
@@ -1691,12 +1770,15 @@ def initVolumeCutPlaneNode(node, item):
node.addRow(ColormapRow(item))
- node.addRow(ItemProxyRow(
- item=item,
- name='Show <=Min',
- fget=item.getDisplayValuesBelowMin,
- fset=item.setDisplayValuesBelowMin,
- events=items.ItemChangedType.ALPHA))
+ node.addRow(
+ ItemProxyRow(
+ item=item,
+ name="Show <=Min",
+ fget=item.getDisplayValuesBelowMin,
+ fset=item.setDisplayValuesBelowMin,
+ events=items.ItemChangedType.ALPHA,
+ )
+ )
node.addRow(InterpolationRow(item))
diff --git a/src/silx/gui/plot3d/_model/model.py b/src/silx/gui/plot3d/_model/model.py
index 5276878..2c687f2 100644
--- a/src/silx/gui/plot3d/_model/model.py
+++ b/src/silx/gui/plot3d/_model/model.py
@@ -176,6 +176,6 @@ class SceneModel(qt.QAbstractItemModel):
def headerData(self, section, orientation, role=qt.Qt.DisplayRole):
"""See :meth:`QAbstractItemModel.headerData`"""
if orientation == qt.Qt.Horizontal and role == qt.Qt.DisplayRole:
- return 'Item' if section == 0 else 'Value'
+ return "Item" if section == 0 else "Value"
else:
return None
diff --git a/src/silx/gui/plot3d/actions/io.py b/src/silx/gui/plot3d/actions/io.py
index 47b0ce5..3c6212f 100644
--- a/src/silx/gui/plot3d/actions/io.py
+++ b/src/silx/gui/plot3d/actions/io.py
@@ -57,9 +57,9 @@ class CopyAction(Plot3DAction):
def __init__(self, parent, plot3d=None):
super(CopyAction, self).__init__(parent, plot3d)
- self.setIcon(getQIcon('edit-copy'))
- self.setText('Copy')
- self.setToolTip('Copy a snapshot of the 3D scene to the clipboard')
+ self.setIcon(getQIcon("edit-copy"))
+ self.setText("Copy")
+ self.setToolTip("Copy a snapshot of the 3D scene to the clipboard")
self.setCheckable(False)
self.setShortcut(qt.QKeySequence.Copy)
self.setShortcutContext(qt.Qt.WidgetShortcut)
@@ -68,7 +68,7 @@ class CopyAction(Plot3DAction):
def _triggered(self, checked=False):
plot3d = self.getPlot3DWidget()
if plot3d is None:
- _logger.error('Cannot copy widget, no associated Plot3DWidget')
+ _logger.error("Cannot copy widget, no associated Plot3DWidget")
else:
image = plot3d.grabGL()
qt.QApplication.clipboard().setImage(image)
@@ -85,9 +85,9 @@ class SaveAction(Plot3DAction):
def __init__(self, parent, plot3d=None):
super(SaveAction, self).__init__(parent, plot3d)
- self.setIcon(getQIcon('document-save'))
- self.setText('Save...')
- self.setToolTip('Save a snapshot of the 3D scene')
+ self.setIcon(getQIcon("document-save"))
+ self.setText("Save...")
+ self.setToolTip("Save a snapshot of the 3D scene")
self.setCheckable(False)
self.setShortcut(qt.QKeySequence.Save)
self.setShortcutContext(qt.Qt.WidgetShortcut)
@@ -96,13 +96,14 @@ class SaveAction(Plot3DAction):
def _triggered(self, checked=False):
plot3d = self.getPlot3DWidget()
if plot3d is None:
- _logger.error('Cannot save widget, no associated Plot3DWidget')
+ _logger.error("Cannot save widget, no associated Plot3DWidget")
else:
dialog = qt.QFileDialog(self.parent())
- dialog.setWindowTitle('Save snapshot as')
+ dialog.setWindowTitle("Save snapshot as")
dialog.setModal(True)
- dialog.setNameFilters(('Plot3D Snapshot PNG (*.png)',
- 'Plot3D Snapshot JPEG (*.jpg)'))
+ dialog.setNameFilters(
+ ("Plot3D Snapshot PNG (*.png)", "Plot3D Snapshot JPEG (*.jpg)")
+ )
dialog.setFileMode(qt.QFileDialog.AnyFile)
dialog.setAcceptMode(qt.QFileDialog.AcceptSave)
@@ -116,17 +117,18 @@ class SaveAction(Plot3DAction):
# Forces the filename extension to match the chosen filter
extension = nameFilter.split()[-1][2:-1]
- if (len(filename) <= len(extension) or
- filename[-len(extension):].lower() != extension.lower()):
+ if (
+ len(filename) <= len(extension)
+ or filename[-len(extension) :].lower() != extension.lower()
+ ):
filename += extension
image = plot3d.grabGL()
if not image.save(filename):
- _logger.error('Failed to save image as %s', filename)
+ _logger.error("Failed to save image as %s", filename)
qt.QMessageBox.critical(
- self.parent(),
- 'Save snapshot as',
- 'Failed to save snapshot')
+ self.parent(), "Save snapshot as", "Failed to save snapshot"
+ )
class PrintAction(Plot3DAction):
@@ -140,9 +142,9 @@ class PrintAction(Plot3DAction):
def __init__(self, parent, plot3d=None):
super(PrintAction, self).__init__(parent, plot3d)
- self.setIcon(getQIcon('document-print'))
- self.setText('Print...')
- self.setToolTip('Print a snapshot of the 3D scene')
+ self.setIcon(getQIcon("document-print"))
+ self.setText("Print...")
+ self.setToolTip("Print a snapshot of the 3D scene")
self.setCheckable(False)
self.setShortcut(qt.QKeySequence.Print)
self.setShortcutContext(qt.Qt.WidgetShortcut)
@@ -158,11 +160,11 @@ class PrintAction(Plot3DAction):
def _triggered(self, checked=False):
plot3d = self.getPlot3DWidget()
if plot3d is None:
- _logger.error('Cannot print widget, no associated Plot3DWidget')
+ _logger.error("Cannot print widget, no associated Plot3DWidget")
else:
printer = self.getPrinter()
dialog = qt.QPrintDialog(printer, plot3d)
- dialog.setWindowTitle('Print Plot3D snapshot')
+ dialog.setWindowTitle("Print Plot3D snapshot")
if not dialog.exec():
return
@@ -174,19 +176,15 @@ class PrintAction(Plot3DAction):
return
pageRect = printer.pageRect(qt.QPrinter.DevicePixel)
- if (pageRect.width() < image.width() or
- pageRect.height() < image.height()):
+ if pageRect.width() < image.width() or pageRect.height() < image.height():
# Downscale to page
xScale = pageRect.width() / image.width()
yScale = pageRect.height() / image.height()
scale = min(xScale, yScale)
else:
- scale = 1.
+ scale = 1.0
- rect = qt.QRectF(0,
- 0,
- scale * image.width(),
- scale * image.height())
+ rect = qt.QRectF(0, 0, scale * image.width(), scale * image.height())
painter.drawImage(rect, image)
painter.end()
@@ -201,15 +199,14 @@ class VideoAction(Plot3DAction):
Plot3DWidget the action is associated with
"""
- PNG_SERIE_FILTER = 'Serie of PNG files (*.png)'
- MNG_FILTER = 'Multiple-image Network Graphics file (*.mng)'
+ PNG_SERIE_FILTER = "Serie of PNG files (*.png)"
+ MNG_FILTER = "Multiple-image Network Graphics file (*.mng)"
def __init__(self, parent, plot3d=None):
super(VideoAction, self).__init__(parent, plot3d)
- self.setText('Record video..')
- self.setIcon(getQIcon('camera'))
- self.setToolTip(
- 'Record a video of a 360 degrees rotation of the 3D scene.')
+ self.setText("Record video..")
+ self.setIcon(getQIcon("camera"))
+ self.setToolTip("Record a video of a 360 degrees rotation of the 3D scene.")
self.setCheckable(False)
self.triggered[bool].connect(self._triggered)
@@ -217,15 +214,13 @@ class VideoAction(Plot3DAction):
"""Action triggered callback"""
plot3d = self.getPlot3DWidget()
if plot3d is None:
- _logger.warning(
- 'Ignoring action triggered without Plot3DWidget set')
+ _logger.warning("Ignoring action triggered without Plot3DWidget set")
return
dialog = qt.QFileDialog(parent=plot3d)
- dialog.setWindowTitle('Save video as...')
+ dialog.setWindowTitle("Save video as...")
dialog.setModal(True)
- dialog.setNameFilters([self.PNG_SERIE_FILTER,
- self.MNG_FILTER])
+ dialog.setNameFilters([self.PNG_SERIE_FILTER, self.MNG_FILTER])
dialog.setFileMode(qt.QFileDialog.AnyFile)
dialog.setAcceptMode(qt.QFileDialog.AcceptSave)
@@ -237,18 +232,20 @@ class VideoAction(Plot3DAction):
# Forces the filename extension to match the chosen filter
extension = nameFilter.split()[-1][2:-1]
- if (len(filename) <= len(extension) or
- filename[-len(extension):].lower() != extension.lower()):
+ if (
+ len(filename) <= len(extension)
+ or filename[-len(extension) :].lower() != extension.lower()
+ ):
filename += extension
- nbFrames = int(4. * 25) # 4 seconds, 25 fps
+ nbFrames = int(4.0 * 25) # 4 seconds, 25 fps
if nameFilter == self.PNG_SERIE_FILTER:
self._saveAsPNGSerie(filename, nbFrames)
elif nameFilter == self.MNG_FILTER:
self._saveAsMNG(filename, nbFrames)
else:
- _logger.error('Unsupported file filter: %s', nameFilter)
+ _logger.error("Unsupported file filter: %s", nameFilter)
def _saveAsPNGSerie(self, filename, nbFrames):
"""Save video as serie of PNG files.
@@ -263,10 +260,11 @@ class VideoAction(Plot3DAction):
# Define filename template
nbDigits = int(numpy.log10(nbFrames)) + 1
- indexFormat = '%%0%dd' % nbDigits
- extensionIndex = filename.rfind('.')
- filenameFormat = \
+ indexFormat = "%%0%dd" % nbDigits
+ extensionIndex = filename.rfind(".")
+ filenameFormat = (
filename[:extensionIndex] + indexFormat + filename[extensionIndex:]
+ )
try:
for index, image in enumerate(self._video360(nbFrames)):
@@ -285,7 +283,7 @@ class VideoAction(Plot3DAction):
frames = (convertQImageToArray(im) for im in self._video360(nbFrames))
try:
- with open(filename, 'wb') as file_:
+ with open(filename, "wb") as file_:
for chunk in mng.convert(frames, nb_images=nbFrames):
file_.write(chunk)
except GeneratorExit:
@@ -300,11 +298,11 @@ class VideoAction(Plot3DAction):
plot3d = self.getPlot3DWidget()
assert plot3d is not None
- angleStep = 360. / nbFrames
+ angleStep = 360.0 / nbFrames
# Create progress bar dialog
dialog = qt.QDialog(plot3d)
- dialog.setWindowTitle('Record Video')
+ dialog.setWindowTitle("Record Video")
layout = qt.QVBoxLayout(dialog)
progress = qt.QProgressBar()
progress.setRange(0, nbFrames)
@@ -323,7 +321,7 @@ class VideoAction(Plot3DAction):
progress.setValue(frame)
image = plot3d.grabGL()
yield image
- plot3d.viewport.orbitCamera('left', angleStep)
+ plot3d.viewport.orbitCamera("left", angleStep)
qapp.processEvents()
if not dialog.isVisible():
break # It as been rejected by the abort button
@@ -331,4 +329,4 @@ class VideoAction(Plot3DAction):
dialog.accept()
if dialog.result() == qt.QDialog.Rejected:
- raise GeneratorExit('Aborted')
+ raise GeneratorExit("Aborted")
diff --git a/src/silx/gui/plot3d/actions/mode.py b/src/silx/gui/plot3d/actions/mode.py
index 179fe05..99a83b4 100644
--- a/src/silx/gui/plot3d/actions/mode.py
+++ b/src/silx/gui/plot3d/actions/mode.py
@@ -63,8 +63,9 @@ class InteractiveModeAction(Plot3DAction):
plot3d = self.getPlot3DWidget()
if plot3d is None:
_logger.error(
- 'Cannot set %s interaction, no associated Plot3DWidget' %
- self._interaction)
+ "Cannot set %s interaction, no associated Plot3DWidget"
+ % self._interaction
+ )
else:
plot3d.setInteractiveMode(self._interaction)
self.setChecked(True)
@@ -74,8 +75,7 @@ class InteractiveModeAction(Plot3DAction):
# Disconnect from previous Plot3DWidget
plot3d = self.getPlot3DWidget()
if plot3d is not None:
- plot3d.sigInteractiveModeChanged.disconnect(
- self._interactiveModeChanged)
+ plot3d.sigInteractiveModeChanged.disconnect(self._interactiveModeChanged)
super(InteractiveModeAction, self).setPlot3DWidget(widget)
@@ -84,13 +84,12 @@ class InteractiveModeAction(Plot3DAction):
self.setChecked(False)
else:
self.setChecked(widget.getInteractiveMode() == self._interaction)
- widget.sigInteractiveModeChanged.connect(
- self._interactiveModeChanged)
+ widget.sigInteractiveModeChanged.connect(self._interactiveModeChanged)
def _interactiveModeChanged(self):
plot3d = self.getPlot3DWidget()
if plot3d is None:
- _logger.error('Received a signal while there is no widget')
+ _logger.error("Received a signal while there is no widget")
else:
self.setChecked(plot3d.getInteractiveMode() == self._interaction)
@@ -104,11 +103,11 @@ class RotateArcballAction(InteractiveModeAction):
"""
def __init__(self, parent, plot3d=None):
- super(RotateArcballAction, self).__init__(parent, 'rotate', plot3d)
+ super(RotateArcballAction, self).__init__(parent, "rotate", plot3d)
- self.setIcon(getQIcon('rotate-3d'))
- self.setText('Rotate')
- self.setToolTip('Rotate the view. Press <b>Ctrl</b> to pan.')
+ self.setIcon(getQIcon("rotate-3d"))
+ self.setText("Rotate")
+ self.setToolTip("Rotate the view. Press <b>Ctrl</b> to pan.")
class PanAction(InteractiveModeAction):
@@ -120,11 +119,11 @@ class PanAction(InteractiveModeAction):
"""
def __init__(self, parent, plot3d=None):
- super(PanAction, self).__init__(parent, 'pan', plot3d)
+ super(PanAction, self).__init__(parent, "pan", plot3d)
- self.setIcon(getQIcon('pan'))
- self.setText('Pan')
- self.setToolTip('Pan the view. Press <b>Ctrl</b> to rotate.')
+ self.setIcon(getQIcon("pan"))
+ self.setText("Pan")
+ self.setToolTip("Pan the view. Press <b>Ctrl</b> to rotate.")
class PickingModeAction(Plot3DAction):
@@ -145,9 +144,9 @@ class PickingModeAction(Plot3DAction):
def __init__(self, parent, plot3d=None):
super(PickingModeAction, self).__init__(parent, plot3d)
- self.setIcon(getQIcon('pointing-hand'))
- self.setText('Picking')
- self.setToolTip('Toggle picking with left button click')
+ self.setIcon(getQIcon("pointing-hand"))
+ self.setText("Picking")
+ self.setToolTip("Toggle picking with left button click")
self.setCheckable(True)
self.triggered[bool].connect(self._triggered)
diff --git a/src/silx/gui/plot3d/actions/viewpoint.py b/src/silx/gui/plot3d/actions/viewpoint.py
index c3d640e..57a7c7a 100644
--- a/src/silx/gui/plot3d/actions/viewpoint.py
+++ b/src/silx/gui/plot3d/actions/viewpoint.py
@@ -50,9 +50,10 @@ class _SetViewpointAction(Plot3DAction):
:param ~silx.gui.plot3d.Plot3DWidget.Plot3DWidget plot3d:
Plot3DWidget the action is associated with
"""
+
def __init__(self, parent, face, plot3d=None):
super(_SetViewpointAction, self).__init__(parent, plot3d)
- assert face in ('side', 'front', 'back', 'left', 'right', 'top', 'bottom')
+ assert face in ("side", "front", "back", "left", "right", "top", "bottom")
self._face = face
self.setIconVisibleInMenu(True)
@@ -62,8 +63,7 @@ class _SetViewpointAction(Plot3DAction):
def _triggered(self, checked=False):
plot3d = self.getPlot3DWidget()
if plot3d is None:
- _logger.error(
- 'Cannot start/stop rotation, no associated Plot3DWidget')
+ _logger.error("Cannot start/stop rotation, no associated Plot3DWidget")
else:
plot3d.viewport.camera.extrinsic.reset(face=self._face)
plot3d.centerScene()
@@ -76,12 +76,13 @@ class FrontViewpointAction(_SetViewpointAction):
:param ~silx.gui.plot3d.Plot3DWidget.Plot3DWidget plot3d:
Plot3DWidget the action is associated with
"""
+
def __init__(self, parent, plot3d=None):
- super(FrontViewpointAction, self).__init__(parent, 'front', plot3d)
+ super(FrontViewpointAction, self).__init__(parent, "front", plot3d)
- self.setIcon(getQIcon('cube-front'))
- self.setText('Front')
- self.setToolTip('View along the -Z axis')
+ self.setIcon(getQIcon("cube-front"))
+ self.setText("Front")
+ self.setToolTip("View along the -Z axis")
class BackViewpointAction(_SetViewpointAction):
@@ -91,12 +92,13 @@ class BackViewpointAction(_SetViewpointAction):
:param ~silx.gui.plot3d.Plot3DWidget.Plot3DWidget plot3d:
Plot3DWidget the action is associated with
"""
+
def __init__(self, parent, plot3d=None):
- super(BackViewpointAction, self).__init__(parent, 'back', plot3d)
+ super(BackViewpointAction, self).__init__(parent, "back", plot3d)
- self.setIcon(getQIcon('cube-back'))
- self.setText('Back')
- self.setToolTip('View along the +Z axis')
+ self.setIcon(getQIcon("cube-back"))
+ self.setText("Back")
+ self.setToolTip("View along the +Z axis")
class LeftViewpointAction(_SetViewpointAction):
@@ -106,12 +108,13 @@ class LeftViewpointAction(_SetViewpointAction):
:param ~silx.gui.plot3d.Plot3DWidget.Plot3DWidget plot3d:
Plot3DWidget the action is associated with
"""
+
def __init__(self, parent, plot3d=None):
- super(LeftViewpointAction, self).__init__(parent, 'left', plot3d)
+ super(LeftViewpointAction, self).__init__(parent, "left", plot3d)
- self.setIcon(getQIcon('cube-left'))
- self.setText('Left')
- self.setToolTip('View along the +X axis')
+ self.setIcon(getQIcon("cube-left"))
+ self.setText("Left")
+ self.setToolTip("View along the +X axis")
class RightViewpointAction(_SetViewpointAction):
@@ -121,12 +124,13 @@ class RightViewpointAction(_SetViewpointAction):
:param ~silx.gui.plot3d.Plot3DWidget.Plot3DWidget plot3d:
Plot3DWidget the action is associated with
"""
+
def __init__(self, parent, plot3d=None):
- super(RightViewpointAction, self).__init__(parent, 'right', plot3d)
+ super(RightViewpointAction, self).__init__(parent, "right", plot3d)
- self.setIcon(getQIcon('cube-right'))
- self.setText('Right')
- self.setToolTip('View along the -X axis')
+ self.setIcon(getQIcon("cube-right"))
+ self.setText("Right")
+ self.setToolTip("View along the -X axis")
class TopViewpointAction(_SetViewpointAction):
@@ -136,12 +140,13 @@ class TopViewpointAction(_SetViewpointAction):
:param ~silx.gui.plot3d.Plot3DWidget.Plot3DWidget plot3d:
Plot3DWidget the action is associated with
"""
+
def __init__(self, parent, plot3d=None):
- super(TopViewpointAction, self).__init__(parent, 'top', plot3d)
+ super(TopViewpointAction, self).__init__(parent, "top", plot3d)
- self.setIcon(getQIcon('cube-top'))
- self.setText('Top')
- self.setToolTip('View along the -Y axis')
+ self.setIcon(getQIcon("cube-top"))
+ self.setText("Top")
+ self.setToolTip("View along the -Y axis")
class BottomViewpointAction(_SetViewpointAction):
@@ -151,12 +156,13 @@ class BottomViewpointAction(_SetViewpointAction):
:param ~silx.gui.plot3d.Plot3DWidget.Plot3DWidget plot3d:
Plot3DWidget the action is associated with
"""
+
def __init__(self, parent, plot3d=None):
- super(BottomViewpointAction, self).__init__(parent, 'bottom', plot3d)
+ super(BottomViewpointAction, self).__init__(parent, "bottom", plot3d)
- self.setIcon(getQIcon('cube-bottom'))
- self.setText('Bottom')
- self.setToolTip('View along the +Y axis')
+ self.setIcon(getQIcon("cube-bottom"))
+ self.setText("Bottom")
+ self.setToolTip("View along the +Y axis")
class SideViewpointAction(_SetViewpointAction):
@@ -166,12 +172,13 @@ class SideViewpointAction(_SetViewpointAction):
:param ~silx.gui.plot3d.Plot3DWidget.Plot3DWidget plot3d:
Plot3DWidget the action is associated with
"""
+
def __init__(self, parent, plot3d=None):
- super(SideViewpointAction, self).__init__(parent, 'side', plot3d)
+ super(SideViewpointAction, self).__init__(parent, "side", plot3d)
- self.setIcon(getQIcon('cube'))
- self.setText('Side')
- self.setToolTip('Side view')
+ self.setIcon(getQIcon("cube"))
+ self.setText("Side")
+ self.setToolTip("Side view")
class RotateViewpoint(Plot3DAction):
@@ -185,7 +192,7 @@ class RotateViewpoint(Plot3DAction):
_TIMEOUT_MS = 50
"""Time interval between to frames (in milliseconds)"""
- _DEGREE_PER_SECONDS = 360. / 5.
+ _DEGREE_PER_SECONDS = 360.0 / 5.0
"""Rotation speed of the animation"""
def __init__(self, parent, plot3d=None):
@@ -197,18 +204,16 @@ class RotateViewpoint(Plot3DAction):
self._timer.setInterval(self._TIMEOUT_MS) # 20fps
self._timer.timeout.connect(self._rotate)
- self.setIcon(getQIcon('cube-rotate'))
- self.setText('Rotate scene')
- self.setToolTip('Rotate the 3D scene around the vertical axis')
+ self.setIcon(getQIcon("cube-rotate"))
+ self.setText("Rotate scene")
+ self.setToolTip("Rotate the 3D scene around the vertical axis")
self.setCheckable(True)
self.triggered[bool].connect(self._triggered)
-
def _triggered(self, checked=False):
plot3d = self.getPlot3DWidget()
if plot3d is None:
- _logger.error(
- 'Cannot start/stop rotation, no associated Plot3DWidget')
+ _logger.error("Cannot start/stop rotation, no associated Plot3DWidget")
elif checked:
self._previousTime = time.time()
self._timer.start()
@@ -219,10 +224,10 @@ class RotateViewpoint(Plot3DAction):
def _rotate(self):
"""Perform a step of the rotation"""
if self._previousTime is None:
- _logger.error('Previous time not set!')
- angleStep = 0.
+ _logger.error("Previous time not set!")
+ angleStep = 0.0
else:
angleStep = self._DEGREE_PER_SECONDS * (time.time() - self._previousTime)
- self.getPlot3DWidget().viewport.orbitCamera('left', angleStep)
+ self.getPlot3DWidget().viewport.orbitCamera("left", angleStep)
self._previousTime = time.time()
diff --git a/src/silx/gui/plot3d/conftest.py b/src/silx/gui/plot3d/conftest.py
index da02238..37c35d5 100644
--- a/src/silx/gui/plot3d/conftest.py
+++ b/src/silx/gui/plot3d/conftest.py
@@ -1,5 +1,6 @@
import pytest
+
@pytest.mark.usefixtures("use_opengl")
def setup_module(module):
pass
diff --git a/src/silx/gui/plot3d/items/__init__.py b/src/silx/gui/plot3d/items/__init__.py
index 3d22103..b091ffc 100644
--- a/src/silx/gui/plot3d/items/__init__.py
+++ b/src/silx/gui/plot3d/items/__init__.py
@@ -31,8 +31,13 @@ __date__ = "15/11/2017"
from .core import DataItem3D, Item3D, GroupItem, GroupWithAxesItem # noqa
from .core import ItemChangedType, Item3DChangedType # noqa
-from .mixins import (ColormapMixIn, ComplexMixIn, InterpolationMixIn, # noqa
- PlaneMixIn, SymbolMixIn) # noqa
+from .mixins import (
+ ColormapMixIn,
+ ComplexMixIn,
+ InterpolationMixIn, # noqa
+ PlaneMixIn,
+ SymbolMixIn,
+) # noqa
from .clipplane import ClipPlane # noqa
from .image import ImageData, ImageRgba, HeightMapData, HeightMapRGBA # noqa
from .mesh import Mesh, ColormapMesh, Box, Cylinder, Hexagon # noqa
diff --git a/src/silx/gui/plot3d/items/_pick.py b/src/silx/gui/plot3d/items/_pick.py
index 49e1a5b..aad5daf 100644
--- a/src/silx/gui/plot3d/items/_pick.py
+++ b/src/silx/gui/plot3d/items/_pick.py
@@ -53,7 +53,7 @@ class PickContext(object):
self._widgetPosition = x, y
assert isinstance(viewport, Viewport)
self._viewport = viewport
- self._ndcZRange = -1., 1.
+ self._ndcZRange = -1.0, 1.0
self._enabled = True
self._condition = condition
@@ -108,7 +108,7 @@ class PickContext(object):
"""
return self._enabled
- def setNDCZRange(self, near=-1., far=1.):
+ def setNDCZRange(self, near=-1.0, far=1.0):
"""Set near and far Z value in normalized device coordinates
This allows to clip the ray to a subset of the NDC range
@@ -142,36 +142,33 @@ class PickContext(object):
or None if picked point is outside viewport
:rtype: Union[None,numpy.ndarray]
"""
- assert frame in ('ndc', 'camera', 'scene') or isinstance(frame, Base)
+ assert frame in ("ndc", "camera", "scene") or isinstance(frame, Base)
positionNdc = self.getNDCPosition()
if positionNdc is None:
return None
near, far = self._ndcZRange
- rayNdc = numpy.array((positionNdc + (near, 1.),
- positionNdc + (far, 1.)),
- dtype=numpy.float64)
- if frame == 'ndc':
+ rayNdc = numpy.array(
+ (positionNdc + (near, 1.0), positionNdc + (far, 1.0)), dtype=numpy.float64
+ )
+ if frame == "ndc":
return rayNdc
viewport = self.getViewport()
rayCamera = viewport.camera.intrinsic.transformPoints(
- rayNdc,
- direct=False,
- perspectiveDivide=True)
- if frame == 'camera':
+ rayNdc, direct=False, perspectiveDivide=True
+ )
+ if frame == "camera":
return rayCamera
- rayScene = viewport.camera.extrinsic.transformPoints(
- rayCamera, direct=False)
- if frame == 'scene':
+ rayScene = viewport.camera.extrinsic.transformPoints(rayCamera, direct=False)
+ if frame == "scene":
return rayScene
# frame is a scene Base object
- rayObject = frame.objectToSceneTransform.transformPoints(
- rayScene, direct=False)
+ rayObject = frame.objectToSceneTransform.transformPoints(rayScene, direct=False)
return rayObject
@@ -193,8 +190,7 @@ class PickingResult(_PickingResult):
"""
super(PickingResult, self).__init__(item, indices)
- self._objectPositions = numpy.array(
- positions, copy=False, dtype=numpy.float64)
+ self._objectPositions = numpy.array(positions, copy=False, dtype=numpy.float64)
# Store matrices to generate positions on demand
primitive = item._getScenePrimitive()
@@ -219,7 +215,7 @@ class PickingResult(_PickingResult):
item = self.getItem()
if self._fetchdata is None:
- if hasattr(item, 'getData'):
+ if hasattr(item, "getData"):
data = item.getData(copy=False)
else:
return None
@@ -228,7 +224,7 @@ class PickingResult(_PickingResult):
return numpy.array(data[indices], copy=copy)
- def getPositions(self, frame='scene', copy=True):
+ def getPositions(self, frame="scene", copy=True):
"""Returns picking positions in item coordinates.
:param str frame: The frame in which the positions are returned
@@ -239,24 +235,26 @@ class PickingResult(_PickingResult):
:return: Nx3 array of (x, y, z) coordinates
:rtype: numpy.ndarray
"""
- if frame == 'ndc':
+ if frame == "ndc":
if self._ndcPositions is None: # Lazy-loading
self._ndcPositions = self._objectToNDCTransform.transformPoints(
- self._objectPositions, perspectiveDivide=True)
+ self._objectPositions, perspectiveDivide=True
+ )
positions = self._ndcPositions
- elif frame == 'scene':
+ elif frame == "scene":
if self._scenePositions is None: # Lazy-loading
self._scenePositions = self._objectToSceneTransform.transformPoints(
- self._objectPositions)
+ self._objectPositions
+ )
positions = self._scenePositions
- elif frame == 'object':
+ elif frame == "object":
positions = self._objectPositions
else:
- raise ValueError('Unsupported frame argument: %s' % str(frame))
+ raise ValueError("Unsupported frame argument: %s" % str(frame))
return numpy.array(positions, copy=copy)
diff --git a/src/silx/gui/plot3d/items/clipplane.py b/src/silx/gui/plot3d/items/clipplane.py
index 83a3c0e..283230b 100644
--- a/src/silx/gui/plot3d/items/clipplane.py
+++ b/src/silx/gui/plot3d/items/clipplane.py
@@ -47,7 +47,8 @@ class ClipPlane(Item3D, PlaneMixIn):
def __init__(self, parent=None):
plane = primitives.ClipPlane()
Item3D.__init__(self, parent=parent, primitive=plane)
- PlaneMixIn.__init__(self, plane=plane)
+ PlaneMixIn.__init__(self)
+ self._setPlane(plane)
def __pickPreProcessing(self, context):
"""Common processing for :meth:`_pickPostProcess` and :meth:`_pickFull`
@@ -73,12 +74,15 @@ class ClipPlane(Item3D, PlaneMixIn):
rayObject[0, :3],
rayObject[1, :3],
planeNorm=self.getNormal(),
- planePt=self.getPoint())
+ planePt=self.getPoint(),
+ )
# A single intersection inside bounding box
- picked = (len(points) == 1 and
- numpy.all(bounds[0] <= points[0]) and
- numpy.all(points[0] <= bounds[1]))
+ picked = (
+ len(points) == 1
+ and numpy.all(bounds[0] <= points[0])
+ and numpy.all(points[0] <= bounds[1])
+ )
return picked, points, rayObject
@@ -96,18 +100,20 @@ class ClipPlane(Item3D, PlaneMixIn):
if picked: # A single intersection inside bounding box
# Clip NDC z range for following brother items
ndcIntersect = plane.objectToNDCTransform.transformPoint(
- points[0], perspectiveDivide=True)
+ points[0], perspectiveDivide=True
+ )
ndcNormal = plane.objectToNDCTransform.transformNormal(
- self.getNormal())
+ self.getNormal()
+ )
if ndcNormal[2] < 0:
- context.setNDCZRange(-1., ndcIntersect[2])
+ context.setNDCZRange(-1.0, ndcIntersect[2])
else:
- context.setNDCZRange(ndcIntersect[2], 1.)
+ context.setNDCZRange(ndcIntersect[2], 1.0)
else:
# TODO check this might not be correct
- rayObject[:, 3] = 1. # Make sure 4h coordinate is one
- if numpy.sum(rayObject[0] * self.getParameters()) < 0.:
+ rayObject[:, 3] = 1.0 # Make sure 4h coordinate is one
+ if numpy.sum(rayObject[0] * self.getParameters()) < 0.0:
# Disable picking for remaining brothers
context.setEnabled(False)
diff --git a/src/silx/gui/plot3d/items/core.py b/src/silx/gui/plot3d/items/core.py
index 5fbe62c..4caf41d 100644
--- a/src/silx/gui/plot3d/items/core.py
+++ b/src/silx/gui/plot3d/items/core.py
@@ -44,25 +44,25 @@ from ._pick import PickContext
class Item3DChangedType(enum.Enum):
"""Type of modification provided by :attr:`Item3D.sigItemChanged` signal."""
- INTERPOLATION = 'interpolationChanged'
+ INTERPOLATION = "interpolationChanged"
"""Item3D image interpolation changed flag."""
- TRANSFORM = 'transformChanged'
+ TRANSFORM = "transformChanged"
"""Item3D transform changed flag."""
- HEIGHT_MAP = 'heightMapChanged'
+ HEIGHT_MAP = "heightMapChanged"
"""Item3D height map changed flag."""
- ISO_LEVEL = 'isoLevelChanged'
+ ISO_LEVEL = "isoLevelChanged"
"""Isosurface level changed flag."""
- LABEL = 'labelChanged'
+ LABEL = "labelChanged"
"""Item's label changed flag."""
- BOUNDING_BOX_VISIBLE = 'boundingBoxVisibleChanged'
+ BOUNDING_BOX_VISIBLE = "boundingBoxVisibleChanged"
"""Item's bounding box visibility changed"""
- ROOT_ITEM = 'rootItemChanged'
+ ROOT_ITEM = "rootItemChanged"
"""Item's root changed flag."""
@@ -85,7 +85,9 @@ class Item3D(qt.QObject):
"""
def __init__(self, parent, primitive=None):
- qt.QObject.__init__(self, parent)
+ qt.QObject.__init__(self)
+ if parent is not None:
+ self.setParent(parent)
if primitive is None:
primitive = scene.Group()
@@ -97,12 +99,9 @@ class Item3D(qt.QObject):
labelIndex = self._LABEL_INDICES[self.__class__]
self._label = str(self.__class__.__name__)
if labelIndex != 0:
- self._label += u' %d' % labelIndex
+ self._label += " %d" % labelIndex
self._LABEL_INDICES[self.__class__] += 1
- if isinstance(parent, Item3D):
- parent.sigItemChanged.connect(self.__parentItemChanged)
-
def setParent(self, parent):
"""Override set parent to handle root item change"""
previousParent = self.parent()
@@ -203,7 +202,7 @@ class Item3D(qt.QObject):
:param color: RGBA color
:type color: tuple of 4 float in [0., 1.]
"""
- if hasattr(super(Item3D, self), '_setForegroundColor'):
+ if hasattr(super(Item3D, self), "_setForegroundColor"):
super(Item3D, self)._setForegroundColor(color)
def __syncForegroundColor(self):
@@ -213,8 +212,7 @@ class Item3D(qt.QObject):
if root is not None:
widget = root.parent()
if isinstance(widget, qt.QWidget):
- self._setForegroundColor(
- widget.getForegroundColor().getRgbF())
+ self._setForegroundColor(widget.getForegroundColor().getRgbF())
# picking
@@ -225,10 +223,12 @@ class Item3D(qt.QObject):
:return: Data indices at picked position or None
:rtype: Union[None,PickingResult]
"""
- if (self.isVisible() and
- context.isEnabled() and
- context.isItemPickable(self) and
- self._pickFastCheck(context)):
+ if (
+ self.isVisible()
+ and context.isEnabled()
+ and context.isItemPickable(self)
+ and self._pickFastCheck(context)
+ ):
return self._pickFull(context)
return None
@@ -251,8 +251,10 @@ class Item3D(qt.QObject):
bounds = primitive.objectToNDCTransform.transformBounds(bounds)
- return (bounds[0, 0] <= positionNdc[0] <= bounds[1, 0] and
- bounds[0, 1] <= positionNdc[1] <= bounds[1, 1])
+ return (
+ bounds[0, 0] <= positionNdc[0] <= bounds[1, 0]
+ and bounds[0, 1] <= positionNdc[1] <= bounds[1, 1]
+ )
def _pickFull(self, context):
"""Perform precise picking in this item at given widget position.
@@ -295,17 +297,21 @@ class DataItem3D(Item3D):
# Group transforms to do to data before rotation
# This is useful to handle rotation center relative to bbox
self._transformObjectToRotate = transform.TransformList(
- [self._matrix, self._scale])
+ [self._matrix, self._scale]
+ )
self._transformObjectToRotate.addListener(self._updateRotationCenter)
- self._rotationCenter = 0., 0., 0.
+ self._rotationCenter = 0.0, 0.0, 0.0
- self.__transforms = transform.TransformList([
- self._translate,
- self._rotateForwardTranslation,
- self._rotate,
- self._rotateBackwardTranslation,
- self._transformObjectToRotate])
+ self.__transforms = transform.TransformList(
+ [
+ self._translate,
+ self._rotateForwardTranslation,
+ self._rotate,
+ self._rotateBackwardTranslation,
+ self._transformObjectToRotate,
+ ]
+ )
self._getScenePrimitive().transforms = self.__transforms
@@ -327,7 +333,7 @@ class DataItem3D(Item3D):
"""
return self.__transforms
- def setScale(self, sx=1., sy=1., sz=1.):
+ def setScale(self, sx=1.0, sy=1.0, sz=1.0):
"""Set the scale of the item in the scene.
:param float sx: Scale factor along the X axis
@@ -346,7 +352,7 @@ class DataItem3D(Item3D):
"""
return self._scale.scale
- def setTranslation(self, x=0., y=0., z=0.):
+ def setTranslation(self, x=0.0, y=0.0, z=0.0):
"""Set the translation of the origin of the item in the scene.
:param float x: Offset of the data origin on the X axis
@@ -365,7 +371,7 @@ class DataItem3D(Item3D):
"""
return self._translate.translation
- _ROTATION_CENTER_TAGS = 'lower', 'center', 'upper'
+ _ROTATION_CENTER_TAGS = "lower", "center", "upper"
def _updateRotationCenter(self, *args, **kwargs):
"""Update rotation center relative to bounding box"""
@@ -374,28 +380,31 @@ class DataItem3D(Item3D):
# Patch position relative to bounding box
if position in self._ROTATION_CENTER_TAGS:
bounds = self._getScenePrimitive().bounds(
- transformed=False, dataBounds=True)
+ transformed=False, dataBounds=True
+ )
bounds = self._transformObjectToRotate.transformBounds(bounds)
if bounds is None:
- position = 0.
- elif position == 'lower':
+ position = 0.0
+ elif position == "lower":
position = bounds[0, index]
- elif position == 'center':
+ elif position == "center":
position = 0.5 * (bounds[0, index] + bounds[1, index])
- elif position == 'upper':
+ elif position == "upper":
position = bounds[1, index]
center.append(position)
- if not numpy.all(numpy.equal(
- center, self._rotateForwardTranslation.translation)):
+ if not numpy.all(
+ numpy.equal(center, self._rotateForwardTranslation.translation)
+ ):
self._rotateForwardTranslation.translation = center
- self._rotateBackwardTranslation.translation = \
- - self._rotateForwardTranslation.translation
+ self._rotateBackwardTranslation.translation = (
+ -self._rotateForwardTranslation.translation
+ )
self._updated(Item3DChangedType.TRANSFORM)
- def setRotationCenter(self, x=0., y=0., z=0.):
+ def setRotationCenter(self, x=0.0, y=0.0, z=0.0):
"""Set the center of rotation of the item.
Position of the rotation center is either a float
@@ -430,7 +439,7 @@ class DataItem3D(Item3D):
"""
return self._rotationCenter
- def setRotation(self, angle=0., axis=(0., 0., 1.)):
+ def setRotation(self, angle=0.0, axis=(0.0, 0.0, 1.0)):
"""Set the rotation of the item in the scene
:param float angle: The rotation angle in degrees.
@@ -439,8 +448,9 @@ class DataItem3D(Item3D):
axis = numpy.array(axis, dtype=numpy.float32)
assert axis.ndim == 1
assert axis.size == 3
- if (self._rotate.angle != angle or
- not numpy.all(numpy.equal(axis, self._rotate.axis))):
+ if self._rotate.angle != angle or not numpy.all(
+ numpy.equal(axis, self._rotate.axis)
+ ):
self._rotate.setAngleAxis(angle, axis)
self._updated(Item3DChangedType.TRANSFORM)
@@ -522,7 +532,7 @@ class BaseNodeItem(DataItem3D):
:rtype: tuple
"""
- raise NotImplementedError('getItems must be implemented in subclass')
+ raise NotImplementedError("getItems must be implemented in subclass")
def visit(self, included=True):
"""Generator visiting the group content.
@@ -535,7 +545,7 @@ class BaseNodeItem(DataItem3D):
yield self
for child in self.getItems():
yield child
- if hasattr(child, 'visit'):
+ if hasattr(child, "visit"):
for item in child.visit(included=False):
yield item
@@ -554,8 +564,7 @@ class BaseNodeItem(DataItem3D):
"""
viewport = self._getScenePrimitive().viewport
if viewport is None:
- raise RuntimeError(
- 'Cannot perform picking: Item not attached to a widget')
+ raise RuntimeError("Cannot perform picking: Item not attached to a widget")
context = PickContext(x, y, viewport, condition)
for result in self._pickItems(context):
@@ -638,12 +647,10 @@ class _BaseGroupItem(BaseNodeItem):
item.setParent(self)
if index is None:
- self._getGroupPrimitive().children.append(
- item._getScenePrimitive())
+ self._getGroupPrimitive().children.append(item._getScenePrimitive())
self._items.append(item)
else:
- self._getGroupPrimitive().children.insert(
- index, item._getScenePrimitive())
+ self._getGroupPrimitive().children.insert(index, item._getScenePrimitive())
self._items.insert(index, item)
self.sigItemAdded.emit(item)
@@ -691,8 +698,9 @@ class GroupWithAxesItem(_BaseGroupItem):
:param parent: The View widget this item belongs to.
"""
- super(GroupWithAxesItem, self).__init__(parent=parent,
- group=axes.LabelledAxes())
+ super(GroupWithAxesItem, self).__init__(
+ parent=parent, group=axes.LabelledAxes()
+ )
# Axes labels
@@ -747,9 +755,9 @@ class GroupWithAxesItem(_BaseGroupItem):
:return: object describing the labels
"""
labelledAxes = self._getScenePrimitive()
- return self._Labels((labelledAxes.xlabel,
- labelledAxes.ylabel,
- labelledAxes.zlabel))
+ return self._Labels(
+ (labelledAxes.xlabel, labelledAxes.ylabel, labelledAxes.zlabel)
+ )
class RootGroupWithAxesItem(GroupWithAxesItem):
diff --git a/src/silx/gui/plot3d/items/image.py b/src/silx/gui/plot3d/items/image.py
index 669e97d..d4d31c6 100644
--- a/src/silx/gui/plot3d/items/image.py
+++ b/src/silx/gui/plot3d/items/image.py
@@ -66,11 +66,12 @@ class _Image(DataItem3D, InterpolationMixIn):
points = utils.segmentPlaneIntersect(
rayObject[0, :3],
rayObject[1, :3],
- planeNorm=numpy.array((0., 0., 1.), dtype=numpy.float64),
- planePt=numpy.array((0., 0., 0.), dtype=numpy.float64))
+ planeNorm=numpy.array((0.0, 0.0, 1.0), dtype=numpy.float64),
+ planePt=numpy.array((0.0, 0.0, 0.0), dtype=numpy.float64),
+ )
if len(points) == 1: # Single intersection
- if points[0][0] < 0. or points[0][1] < 0.:
+ if points[0][0] < 0.0 or points[0][1] < 0.0:
return None # Outside image
row, column = int(points[0][1]), int(points[0][0])
data = self.getData(copy=False)
@@ -78,8 +79,9 @@ class _Image(DataItem3D, InterpolationMixIn):
if row < height and column < width:
return PickingResult(
self,
- positions=[(points[0][0], points[0][1], 0.)],
- indices=([row], [column]))
+ positions=[(points[0][0], points[0][1], 0.0)],
+ indices=([row], [column]),
+ )
else:
return None # Outside image
else: # Either no intersection or segment and image are coplanar
@@ -183,7 +185,7 @@ class _HeightMap(DataItem3D):
DataItem3D.__init__(self, parent=parent)
self.__data = numpy.zeros((0, 0), dtype=numpy.float32)
- def _pickFull(self, context, threshold=0., sort='depth'):
+ def _pickFull(self, context, threshold=0.0, sort="depth"):
"""Perform picking in this item at given widget position.
:param PickContext context: Current picking context
@@ -197,9 +199,9 @@ class _HeightMap(DataItem3D):
:return: Object holding the results or None
:rtype: Union[None,PickingResult]
"""
- assert sort in ('index', 'depth')
+ assert sort in ("index", "depth")
- rayNdc = context.getPickingSegment(frame='ndc')
+ rayNdc = context.getPickingSegment(frame="ndc")
if rayNdc is None: # No picking outside viewport
return None
@@ -212,40 +214,46 @@ class _HeightMap(DataItem3D):
height, width = heightData.shape
z = numpy.ravel(heightData)
y, x = numpy.mgrid[0:height, 0:width]
- dataPoints = numpy.transpose((numpy.ravel(x),
- numpy.ravel(y),
- z,
- numpy.ones_like(z)))
+ dataPoints = numpy.transpose(
+ (numpy.ravel(x), numpy.ravel(y), z, numpy.ones_like(z))
+ )
primitive = self._getScenePrimitive()
pointsNdc = primitive.objectToNDCTransform.transformPoints(
- dataPoints, perspectiveDivide=True)
+ dataPoints, perspectiveDivide=True
+ )
# Perform picking
distancesNdc = numpy.abs(pointsNdc[:, :2] - rayNdc[0, :2])
# TODO issue with symbol size: using pixel instead of points
- threshold += 1. # symbol size
- thresholdNdc = 2. * threshold / numpy.array(primitive.viewport.size)
- picked = numpy.where(numpy.logical_and(
+ threshold += 1.0 # symbol size
+ thresholdNdc = 2.0 * threshold / numpy.array(primitive.viewport.size)
+ picked = numpy.where(
+ numpy.logical_and(
numpy.all(distancesNdc < thresholdNdc, axis=1),
- numpy.logical_and(rayNdc[0, 2] <= pointsNdc[:, 2],
- pointsNdc[:, 2] <= rayNdc[1, 2])))[0]
+ numpy.logical_and(
+ rayNdc[0, 2] <= pointsNdc[:, 2], pointsNdc[:, 2] <= rayNdc[1, 2]
+ ),
+ )
+ )[0]
- if sort == 'depth':
+ if sort == "depth":
# Sort picked points from front to back
picked = picked[numpy.argsort(pointsNdc[picked, 2])]
if picked.size > 0:
# Convert indices from 1D to 2D
- return PickingResult(self,
- positions=dataPoints[picked, :3],
- indices=(picked // width, picked % width),
- fetchdata=self.getData)
+ return PickingResult(
+ self,
+ positions=dataPoints[picked, :3],
+ indices=(picked // width, picked % width),
+ fetchdata=self.getData,
+ )
else:
return None
- def setData(self, data, copy: bool=True):
+ def setData(self, data, copy: bool = True):
"""Set the height field data.
:param data:
@@ -258,7 +266,7 @@ class _HeightMap(DataItem3D):
self.__data = data
self._updated(ItemChangedType.DATA)
- def getData(self, copy: bool=True) -> numpy.ndarray:
+ def getData(self, copy: bool = True) -> numpy.ndarray:
"""Get the height field 2D data.
:param bool copy:
@@ -306,23 +314,22 @@ class HeightMapData(_HeightMap, ColormapMixIn):
if data.shape != heightData.shape: # data and height size miss-match
# Colormapped data is interpolated (nearest-neighbour) to match the height field
- data = data[numpy.floor(y * data.shape[0] / height).astype(numpy.int32),
- numpy.floor(x * data.shape[1] / height).astype(numpy.int32)]
+ data = data[
+ numpy.floor(y * data.shape[0] / height).astype(numpy.int32),
+ numpy.floor(x * data.shape[1] / height).astype(numpy.int32),
+ ]
x = numpy.ravel(x)
y = numpy.ravel(y)
primitive = primitives.Points(
- x=x,
- y=y,
- z=numpy.ravel(heightData),
- value=numpy.ravel(data),
- size=1)
- primitive.marker = 's'
+ x=x, y=y, z=numpy.ravel(heightData), value=numpy.ravel(data), size=1
+ )
+ primitive.marker = "s"
ColormapMixIn._setSceneColormap(self, primitive.colormap)
self._getScenePrimitive().children = [primitive]
- def setColormappedData(self, data, copy: bool=True):
+ def setColormappedData(self, data, copy: bool = True):
"""Set the 2D data used to compute colors.
:param data: 2D array of data
@@ -335,7 +342,7 @@ class HeightMapData(_HeightMap, ColormapMixIn):
self.__data = data
self._updated(ItemChangedType.DATA)
- def getColormappedData(self, copy: bool=True) -> numpy.ndarray:
+ def getColormappedData(self, copy: bool = True) -> numpy.ndarray:
"""Returns the 2D data used to compute colors.
:param copy:
@@ -380,8 +387,10 @@ class HeightMapRGBA(_HeightMap):
if rgba.shape[:2] != heightData.shape: # image and height size miss-match
# RGBA data is interpolated (nearest-neighbour) to match the height field
- rgba = rgba[numpy.floor(y * rgba.shape[0] / height).astype(numpy.int32),
- numpy.floor(x * rgba.shape[1] / height).astype(numpy.int32)]
+ rgba = rgba[
+ numpy.floor(y * rgba.shape[0] / height).astype(numpy.int32),
+ numpy.floor(x * rgba.shape[1] / height).astype(numpy.int32),
+ ]
x = numpy.ravel(x)
y = numpy.ravel(y)
@@ -391,11 +400,12 @@ class HeightMapRGBA(_HeightMap):
y=y,
z=numpy.ravel(heightData),
color=rgba.reshape(-1, rgba.shape[-1]),
- size=1)
- primitive.marker = 's'
+ size=1,
+ )
+ primitive.marker = "s"
self._getScenePrimitive().children = [primitive]
- def setColorData(self, data, copy: bool=True):
+ def setColorData(self, data, copy: bool = True):
"""Set the RGB(A) image to use.
Supported array format: float32 in [0, 1], uint8.
@@ -413,7 +423,7 @@ class HeightMapRGBA(_HeightMap):
self.__rgba = data
self._updated(ItemChangedType.DATA)
- def getColorData(self, copy: bool=True) -> numpy.ndarray:
+ def getColorData(self, copy: bool = True) -> numpy.ndarray:
"""Get the RGB(A) image data.
:param copy: True (default) to get a copy,
diff --git a/src/silx/gui/plot3d/items/mesh.py b/src/silx/gui/plot3d/items/mesh.py
index dc1df3e..89056c3 100644
--- a/src/silx/gui/plot3d/items/mesh.py
+++ b/src/silx/gui/plot3d/items/mesh.py
@@ -82,7 +82,7 @@ class _MeshBase(DataItem3D):
if self._getMesh() is None:
return numpy.empty((0, 3), dtype=numpy.float32)
else:
- return self._getMesh().getAttribute('position', copy=copy)
+ return self._getMesh().getAttribute("position", copy=copy)
def getNormalData(self, copy=True):
"""Get the mesh vertex normals.
@@ -96,7 +96,7 @@ class _MeshBase(DataItem3D):
if self._getMesh() is None:
return None
else:
- return self._getMesh().getAttribute('normal', copy=copy)
+ return self._getMesh().getAttribute("normal", copy=copy)
def getIndices(self, copy=True):
"""Get the vertex indices.
@@ -143,21 +143,23 @@ class _MeshBase(DataItem3D):
positions = utils.unindexArrays(mode, vertexIndices, positions)[0]
triangles = positions.reshape(-1, 3, 3)
else:
- if mode == 'triangles':
+ if mode == "triangles":
triangles = positions.reshape(-1, 3, 3)
- elif mode == 'triangle_strip':
+ elif mode == "triangle_strip":
# Expand strip
- triangles = numpy.empty((len(positions) - 2, 3, 3),
- dtype=positions.dtype)
+ triangles = numpy.empty(
+ (len(positions) - 2, 3, 3), dtype=positions.dtype
+ )
triangles[:, 0] = positions[:-2]
triangles[:, 1] = positions[1:-1]
triangles[:, 2] = positions[2:]
- elif mode == 'fan':
+ elif mode == "fan":
# Expand fan
- triangles = numpy.empty((len(positions) - 2, 3, 3),
- dtype=positions.dtype)
+ triangles = numpy.empty(
+ (len(positions) - 2, 3, 3), dtype=positions.dtype
+ )
triangles[:, 0] = positions[0]
triangles[:, 1] = positions[1:-1]
triangles[:, 2] = positions[2:]
@@ -167,7 +169,8 @@ class _MeshBase(DataItem3D):
return None
trianglesIndices, t, barycentric = glu.segmentTrianglesIntersection(
- rayObject, triangles)
+ rayObject, triangles
+ )
if len(trianglesIndices) == 0:
return None
@@ -177,13 +180,13 @@ class _MeshBase(DataItem3D):
# Get vertex index from triangle index and closest point in triangle
closest = numpy.argmax(barycentric, axis=1)
- if mode == 'triangles':
+ if mode == "triangles":
indices = trianglesIndices * 3 + closest
- elif mode == 'triangle_strip':
+ elif mode == "triangle_strip":
indices = trianglesIndices + closest
- elif mode == 'fan':
+ elif mode == "fan":
indices = trianglesIndices + closest # For corners 1 and 2
indices[closest == 0] = 0 # For first corner (common)
@@ -191,10 +194,9 @@ class _MeshBase(DataItem3D):
# Convert from indices in expanded triangles to input vertices
indices = vertexIndices[indices]
- return PickingResult(self,
- positions=points,
- indices=indices,
- fetchdata=self.getPositionData)
+ return PickingResult(
+ self, positions=points, indices=indices, fetchdata=self.getPositionData
+ )
class Mesh(_MeshBase):
@@ -206,13 +208,9 @@ class Mesh(_MeshBase):
def __init__(self, parent=None):
_MeshBase.__init__(self, parent=parent)
- def setData(self,
- position,
- color,
- normal=None,
- mode='triangles',
- indices=None,
- copy=True):
+ def setData(
+ self, position, color, normal=None, mode="triangles", indices=None, copy=True
+ ):
"""Set mesh geometry data.
Supported drawing modes are: 'triangles', 'triangle_strip', 'fan'
@@ -227,12 +225,13 @@ class Mesh(_MeshBase):
:param bool copy: True (default) to copy the data,
False to use as is (do not modify!).
"""
- assert mode in ('triangles', 'triangle_strip', 'fan')
+ assert mode in ("triangles", "triangle_strip", "fan")
if position is None or len(position) == 0:
mesh = None
else:
mesh = primitives.Mesh3D(
- position, color, normal, mode=mode, indices=indices, copy=copy)
+ position, color, normal, mode=mode, indices=indices, copy=copy
+ )
self._setMesh(mesh)
def getData(self, copy=True):
@@ -244,10 +243,12 @@ class Mesh(_MeshBase):
:return: The positions, colors, normals and mode
:rtype: tuple of numpy.ndarray
"""
- return (self.getPositionData(copy=copy),
- self.getColorData(copy=copy),
- self.getNormalData(copy=copy),
- self.getDrawMode())
+ return (
+ self.getPositionData(copy=copy),
+ self.getColorData(copy=copy),
+ self.getNormalData(copy=copy),
+ self.getDrawMode(),
+ )
def getColorData(self, copy=True):
"""Get the mesh vertex colors.
@@ -261,7 +262,7 @@ class Mesh(_MeshBase):
if self._getMesh() is None:
return numpy.empty((0, 4), dtype=numpy.float32)
else:
- return self._getMesh().getAttribute('color', copy=copy)
+ return self._getMesh().getAttribute("color", copy=copy)
class ColormapMesh(_MeshBase, ColormapMixIn):
@@ -274,13 +275,9 @@ class ColormapMesh(_MeshBase, ColormapMixIn):
_MeshBase.__init__(self, parent=parent)
ColormapMixIn.__init__(self, function.Colormap())
- def setData(self,
- position,
- value,
- normal=None,
- mode='triangles',
- indices=None,
- copy=True):
+ def setData(
+ self, position, value, normal=None, mode="triangles", indices=None, copy=True
+ ):
"""Set mesh geometry data.
Supported drawing modes are: 'triangles', 'triangle_strip', 'fan'
@@ -295,18 +292,21 @@ class ColormapMesh(_MeshBase, ColormapMixIn):
:param bool copy: True (default) to copy the data,
False to use as is (do not modify!).
"""
- assert mode in ('triangles', 'triangle_strip', 'fan')
+ assert mode in ("triangles", "triangle_strip", "fan")
if position is None or len(position) == 0:
mesh = None
else:
mesh = primitives.ColormapMesh3D(
position=position,
- value=numpy.array(value, copy=False).reshape(-1, 1), # Make it a 2D array
+ value=numpy.array(value, copy=False).reshape(
+ -1, 1
+ ), # Make it a 2D array
colormap=self._getSceneColormap(),
normal=normal,
mode=mode,
indices=indices,
- copy=copy)
+ copy=copy,
+ )
self._setMesh(mesh)
self._setColormappedData(self.getValueData(copy=False), copy=False)
@@ -320,10 +320,12 @@ class ColormapMesh(_MeshBase, ColormapMixIn):
:return: The positions, values, normals and mode
:rtype: tuple of numpy.ndarray
"""
- return (self.getPositionData(copy=copy),
- self.getValueData(copy=copy),
- self.getNormalData(copy=copy),
- self.getDrawMode())
+ return (
+ self.getPositionData(copy=copy),
+ self.getValueData(copy=copy),
+ self.getNormalData(copy=copy),
+ self.getDrawMode(),
+ )
def getValueData(self, copy=True):
"""Get the mesh vertex values.
@@ -337,7 +339,7 @@ class ColormapMesh(_MeshBase, ColormapMixIn):
if self._getMesh() is None:
return numpy.empty((0,), dtype=numpy.float32)
else:
- return self._getMesh().getAttribute('value', copy=copy)
+ return self._getMesh().getAttribute("value", copy=copy)
class _CylindricalVolume(DataItem3D):
@@ -362,8 +364,7 @@ class _CylindricalVolume(DataItem3D):
"""
raise NotImplementedError("Must be implemented in subclass")
- def _setData(self, position, radius, height, angles, color, flatFaces,
- rotation):
+ def _setData(self, position, radius, height, angles, color, flatFaces, rotation):
"""Set volume geometry data.
:param numpy.ndarray position:
@@ -384,10 +385,8 @@ class _CylindricalVolume(DataItem3D):
else:
self._nbFaces = len(angles) - 1
- volume = numpy.empty(shape=(len(angles) - 1, 12, 3),
- dtype=numpy.float32)
- normal = numpy.empty(shape=(len(angles) - 1, 12, 3),
- dtype=numpy.float32)
+ volume = numpy.empty(shape=(len(angles) - 1, 12, 3), dtype=numpy.float32)
+ normal = numpy.empty(shape=(len(angles) - 1, 12, 3), dtype=numpy.float32)
for i in range(0, len(angles) - 1):
# c6
@@ -404,71 +403,103 @@ class _CylindricalVolume(DataItem3D):
# \ /
# \/
# c1
- c1 = numpy.array([0, 0, -height/2])
+ c1 = numpy.array([0, 0, -height / 2])
c1 = rotation.transformPoint(c1)
- c2 = numpy.array([radius * numpy.cos(angles[i]),
- radius * numpy.sin(angles[i]),
- -height/2])
+ c2 = numpy.array(
+ [
+ radius * numpy.cos(angles[i]),
+ radius * numpy.sin(angles[i]),
+ -height / 2,
+ ]
+ )
c2 = rotation.transformPoint(c2)
- c3 = numpy.array([radius * numpy.cos(angles[i+1]),
- radius * numpy.sin(angles[i+1]),
- -height/2])
+ c3 = numpy.array(
+ [
+ radius * numpy.cos(angles[i + 1]),
+ radius * numpy.sin(angles[i + 1]),
+ -height / 2,
+ ]
+ )
c3 = rotation.transformPoint(c3)
- c4 = numpy.array([radius * numpy.cos(angles[i]),
- radius * numpy.sin(angles[i]),
- height/2])
+ c4 = numpy.array(
+ [
+ radius * numpy.cos(angles[i]),
+ radius * numpy.sin(angles[i]),
+ height / 2,
+ ]
+ )
c4 = rotation.transformPoint(c4)
- c5 = numpy.array([radius * numpy.cos(angles[i+1]),
- radius * numpy.sin(angles[i+1]),
- height/2])
+ c5 = numpy.array(
+ [
+ radius * numpy.cos(angles[i + 1]),
+ radius * numpy.sin(angles[i + 1]),
+ height / 2,
+ ]
+ )
c5 = rotation.transformPoint(c5)
- c6 = numpy.array([0, 0, height/2])
+ c6 = numpy.array([0, 0, height / 2])
c6 = rotation.transformPoint(c6)
- volume[i] = numpy.array([c1, c3, c2,
- c2, c3, c4,
- c3, c5, c4,
- c4, c5, c6])
+ volume[i] = numpy.array(
+ [c1, c3, c2, c2, c3, c4, c3, c5, c4, c4, c5, c6]
+ )
if flatFaces:
- normal[i] = numpy.array([numpy.cross(c3-c1, c2-c1), # c1
- numpy.cross(c2-c3, c1-c3), # c3
- numpy.cross(c1-c2, c3-c2), # c2
- numpy.cross(c3-c2, c4-c2), # c2
- numpy.cross(c4-c3, c2-c3), # c3
- numpy.cross(c2-c4, c3-c4), # c4
- numpy.cross(c5-c3, c4-c3), # c3
- numpy.cross(c4-c5, c3-c5), # c5
- numpy.cross(c3-c4, c5-c4), # c4
- numpy.cross(c5-c4, c6-c4), # c4
- numpy.cross(c6-c5, c5-c5), # c5
- numpy.cross(c4-c6, c5-c6)]) # c6
+ normal[i] = numpy.array(
+ [
+ numpy.cross(c3 - c1, c2 - c1), # c1
+ numpy.cross(c2 - c3, c1 - c3), # c3
+ numpy.cross(c1 - c2, c3 - c2), # c2
+ numpy.cross(c3 - c2, c4 - c2), # c2
+ numpy.cross(c4 - c3, c2 - c3), # c3
+ numpy.cross(c2 - c4, c3 - c4), # c4
+ numpy.cross(c5 - c3, c4 - c3), # c3
+ numpy.cross(c4 - c5, c3 - c5), # c5
+ numpy.cross(c3 - c4, c5 - c4), # c4
+ numpy.cross(c5 - c4, c6 - c4), # c4
+ numpy.cross(c6 - c5, c5 - c5), # c5
+ numpy.cross(c4 - c6, c5 - c6),
+ ]
+ ) # c6
else:
- normal[i] = numpy.array([numpy.cross(c3-c1, c2-c1),
- numpy.cross(c2-c3, c1-c3),
- numpy.cross(c1-c2, c3-c2),
- c2-c1, c3-c1, c4-c6, # c2 c2 c4
- c3-c1, c5-c6, c4-c6, # c3 c5 c4
- numpy.cross(c5-c4, c6-c4),
- numpy.cross(c6-c5, c5-c5),
- numpy.cross(c4-c6, c5-c6)])
+ normal[i] = numpy.array(
+ [
+ numpy.cross(c3 - c1, c2 - c1),
+ numpy.cross(c2 - c3, c1 - c3),
+ numpy.cross(c1 - c2, c3 - c2),
+ c2 - c1,
+ c3 - c1,
+ c4 - c6, # c2 c2 c4
+ c3 - c1,
+ c5 - c6,
+ c4 - c6, # c3 c5 c4
+ numpy.cross(c5 - c4, c6 - c4),
+ numpy.cross(c6 - c5, c5 - c5),
+ numpy.cross(c4 - c6, c5 - c6),
+ ]
+ )
# Multiplication according to the number of positions
- vertices = numpy.tile(volume.reshape(-1, 3), (len(position), 1))\
- .reshape((-1, 3))
- normals = numpy.tile(normal.reshape(-1, 3), (len(position), 1))\
- .reshape((-1, 3))
+ vertices = numpy.tile(volume.reshape(-1, 3), (len(position), 1)).reshape(
+ (-1, 3)
+ )
+ normals = numpy.tile(normal.reshape(-1, 3), (len(position), 1)).reshape(
+ (-1, 3)
+ )
# Translations
- numpy.add(vertices, numpy.tile(position, (1, (len(angles)-1) * 12))
- .reshape((-1, 3)), out=vertices)
+ numpy.add(
+ vertices,
+ numpy.tile(position, (1, (len(angles) - 1) * 12)).reshape((-1, 3)),
+ out=vertices,
+ )
# Colors
if numpy.ndim(color) == 2:
- color = numpy.tile(color, (1, 12 * (len(angles) - 1)))\
- .reshape(-1, 3)
+ color = numpy.tile(color, (1, 12 * (len(angles) - 1))).reshape(-1, 3)
self._mesh = primitives.Mesh3D(
- vertices, color, normals, mode='triangles', copy=False)
+ vertices, color, normals, mode="triangles", copy=False
+ )
self._getScenePrimitive().children.append(self._mesh)
self._updated(ItemChangedType.DATA)
@@ -488,11 +519,10 @@ class _CylindricalVolume(DataItem3D):
return None
rayObject = rayObject[:, :3]
- positions = self._mesh.getAttribute('position', copy=False)
+ positions = self._mesh.getAttribute("position", copy=False)
triangles = positions.reshape(-1, 3, 3) # 'triangle' draw mode
- trianglesIndices, t = glu.segmentTrianglesIntersection(
- rayObject, triangles)[:2]
+ trianglesIndices, t = glu.segmentTrianglesIntersection(rayObject, triangles)[:2]
if len(trianglesIndices) == 0:
return None
@@ -511,10 +541,9 @@ class _CylindricalVolume(DataItem3D):
points = t.reshape(-1, 1) * (rayObject[1] - rayObject[0]) + rayObject[0]
- return PickingResult(self,
- positions=points,
- indices=indices,
- fetchdata=self.getPosition)
+ return PickingResult(
+ self, positions=points, indices=indices, fetchdata=self.getPosition
+ )
class Box(_CylindricalVolume):
@@ -533,8 +562,13 @@ class Box(_CylindricalVolume):
self.rotation = None
self.setData()
- def setData(self, size=(1, 1, 1), color=(1, 1, 1),
- position=(0, 0, 0), rotation=(0, (0, 0, 0))):
+ def setData(
+ self,
+ size=(1, 1, 1),
+ color=(1, 1, 1),
+ position=(0, 0, 0),
+ rotation=(0, (0, 0, 0)),
+ ):
"""
Set Box geometry data.
@@ -550,28 +584,28 @@ class Box(_CylindricalVolume):
self.position = numpy.atleast_2d(numpy.array(position, copy=True))
self.size = numpy.array(size, copy=True)
self.color = numpy.array(color, copy=True)
- self.rotation = Rotate(rotation[0],
- rotation[1][0], rotation[1][1], rotation[1][2])
+ self.rotation = Rotate(
+ rotation[0], rotation[1][0], rotation[1][1], rotation[1][2]
+ )
- assert (numpy.ndim(self.color) == 1 or
- len(self.color) == len(self.position))
+ assert numpy.ndim(self.color) == 1 or len(self.color) == len(self.position)
- diagonal = numpy.sqrt(self.size[0]**2 + self.size[1]**2)
+ diagonal = numpy.sqrt(self.size[0] ** 2 + self.size[1] ** 2)
alpha = 2 * numpy.arcsin(self.size[1] / diagonal)
beta = 2 * numpy.arcsin(self.size[0] / diagonal)
- angles = numpy.array([0,
- alpha,
- alpha + beta,
- alpha + beta + alpha,
- 2 * numpy.pi])
+ angles = numpy.array(
+ [0, alpha, alpha + beta, alpha + beta + alpha, 2 * numpy.pi]
+ )
numpy.subtract(angles, 0.5 * alpha, out=angles)
- self._setData(self.position,
- numpy.sqrt(self.size[0]**2 + self.size[1]**2)/2,
- self.size[2],
- angles,
- self.color,
- True,
- self.rotation)
+ self._setData(
+ self.position,
+ numpy.sqrt(self.size[0] ** 2 + self.size[1] ** 2) / 2,
+ self.size[2],
+ angles,
+ self.color,
+ True,
+ self.rotation,
+ )
def getPosition(self, copy=True):
"""Get box(es) position(s).
@@ -622,8 +656,15 @@ class Cylinder(_CylindricalVolume):
self.rotation = None
self.setData()
- def setData(self, radius=1, height=1, color=(1, 1, 1), nbFaces=20,
- position=(0, 0, 0), rotation=(0, (0, 0, 0))):
+ def setData(
+ self,
+ radius=1,
+ height=1,
+ color=(1, 1, 1),
+ nbFaces=20,
+ position=(0, 0, 0),
+ rotation=(0, (0, 0, 0)),
+ ):
"""
Set the cylinder geometry data
@@ -644,20 +685,22 @@ class Cylinder(_CylindricalVolume):
self.height = float(height)
self.color = numpy.array(color, copy=True)
self.nbFaces = int(nbFaces)
- self.rotation = Rotate(rotation[0],
- rotation[1][0], rotation[1][1], rotation[1][2])
-
- assert (numpy.ndim(self.color) == 1 or
- len(self.color) == len(self.position))
-
- angles = numpy.linspace(0, 2*numpy.pi, self.nbFaces + 1)
- self._setData(self.position,
- self.radius,
- self.height,
- angles,
- self.color,
- False,
- self.rotation)
+ self.rotation = Rotate(
+ rotation[0], rotation[1][0], rotation[1][1], rotation[1][2]
+ )
+
+ assert numpy.ndim(self.color) == 1 or len(self.color) == len(self.position)
+
+ angles = numpy.linspace(0, 2 * numpy.pi, self.nbFaces + 1)
+ self._setData(
+ self.position,
+ self.radius,
+ self.height,
+ angles,
+ self.color,
+ False,
+ self.rotation,
+ )
def getPosition(self, copy=True):
"""Get cylinder(s) position(s).
@@ -716,8 +759,14 @@ class Hexagon(_CylindricalVolume):
self.rotation = None
self.setData()
- def setData(self, radius=1, height=1, color=(1, 1, 1),
- position=(0, 0, 0), rotation=(0, (0, 0, 0))):
+ def setData(
+ self,
+ radius=1,
+ height=1,
+ color=(1, 1, 1),
+ position=(0, 0, 0),
+ rotation=(0, (0, 0, 0)),
+ ):
"""
Set the uniform hexagonal prism geometry data
@@ -735,20 +784,22 @@ class Hexagon(_CylindricalVolume):
self.radius = float(radius)
self.height = float(height)
self.color = numpy.array(color, copy=True)
- self.rotation = Rotate(rotation[0], rotation[1][0], rotation[1][1],
- rotation[1][2])
-
- assert (numpy.ndim(self.color) == 1 or
- len(self.color) == len(self.position))
-
- angles = numpy.linspace(0, 2*numpy.pi, 7)
- self._setData(self.position,
- self.radius,
- self.height,
- angles,
- self.color,
- True,
- self.rotation)
+ self.rotation = Rotate(
+ rotation[0], rotation[1][0], rotation[1][1], rotation[1][2]
+ )
+
+ assert numpy.ndim(self.color) == 1 or len(self.color) == len(self.position)
+
+ angles = numpy.linspace(0, 2 * numpy.pi, 7)
+ self._setData(
+ self.position,
+ self.radius,
+ self.height,
+ angles,
+ self.color,
+ True,
+ self.rotation,
+ )
def getPosition(self, copy=True):
"""Get hexagonal prim(s) position(s).
@@ -758,7 +809,7 @@ class Hexagon(_CylindricalVolume):
False to get internal representation (do not modify!).
:return: Position(s) of hexagonal prism(s) as a (N, 3) array.
:rtype: numpy.ndarray
- """
+ """
return numpy.array(self.position, copy=copy)
def getRadius(self):
diff --git a/src/silx/gui/plot3d/items/mixins.py b/src/silx/gui/plot3d/items/mixins.py
index 45b569d..c69c3ac 100644
--- a/src/silx/gui/plot3d/items/mixins.py
+++ b/src/silx/gui/plot3d/items/mixins.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2017-2020 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -29,11 +29,8 @@ __license__ = "MIT"
__date__ = "24/04/2018"
-import collections
import numpy
-from silx.math.combo import min_max
-
from ...plot.items.core import ItemMixInBase
from ...plot.items.core import ColormapMixIn as _ColormapMixIn
from ...plot.items.core import SymbolMixIn as _SymbolMixIn
@@ -53,24 +50,21 @@ class InterpolationMixIn(ItemMixInBase):
This object MUST have an interpolation property that is updated.
"""
- NEAREST_INTERPOLATION = 'nearest'
+ NEAREST_INTERPOLATION = "nearest"
"""Nearest interpolation mode (see :meth:`setInterpolation`)"""
- LINEAR_INTERPOLATION = 'linear'
+ LINEAR_INTERPOLATION = "linear"
"""Linear interpolation mode (see :meth:`setInterpolation`)"""
INTERPOLATION_MODES = NEAREST_INTERPOLATION, LINEAR_INTERPOLATION
"""Supported interpolation modes for :meth:`setInterpolation`"""
- def __init__(self, mode=NEAREST_INTERPOLATION, primitive=None):
- self.__primitive = primitive
+ def __init__(self):
+ self.__primitive = None
+ self.__interpolationMode = self.NEAREST_INTERPOLATION
self._syncPrimitiveInterpolation()
- self.__interpolationMode = None
- self.setInterpolation(mode)
-
def _setPrimitive(self, primitive):
-
"""Set the scene object for which to sync interpolation"""
self.__primitive = primitive
self._syncPrimitiveInterpolation()
@@ -151,24 +145,28 @@ class ComplexMixIn(_ComplexMixIn):
_ComplexMixIn.ComplexMode.IMAGINARY,
_ComplexMixIn.ComplexMode.ABSOLUTE,
_ComplexMixIn.ComplexMode.PHASE,
- _ComplexMixIn.ComplexMode.SQUARE_AMPLITUDE)
+ _ComplexMixIn.ComplexMode.SQUARE_AMPLITUDE,
+ )
"""Overrides supported ComplexMode"""
class SymbolMixIn(_SymbolMixIn):
"""Mix-in class for symbol and symbolSize properties for Item3D"""
- _SUPPORTED_SYMBOLS = collections.OrderedDict((
- ('o', 'Circle'),
- ('d', 'Diamond'),
- ('s', 'Square'),
- ('+', 'Plus'),
- ('x', 'Cross'),
- ('*', 'Star'),
- ('|', 'Vertical Line'),
- ('_', 'Horizontal Line'),
- ('.', 'Point'),
- (',', 'Pixel')))
+ _SUPPORTED_SYMBOLS = dict(
+ (
+ ("o", "Circle"),
+ ("d", "Diamond"),
+ ("s", "Square"),
+ ("+", "Plus"),
+ ("x", "Cross"),
+ ("*", "Star"),
+ ("|", "Vertical Line"),
+ ("_", "Horizontal Line"),
+ (".", "Point"),
+ (",", "Pixel"),
+ )
+ )
def _getSceneSymbol(self):
"""Returns a symbol name and size suitable for scene primitives.
@@ -177,11 +175,11 @@ class SymbolMixIn(_SymbolMixIn):
"""
symbol = self.getSymbol()
size = self.getSymbolSize()
- if symbol == ',': # pixel
- return 's', 1.
- elif symbol == '.': # point
+ if symbol == ",": # pixel
+ return "s", 1.0
+ elif symbol == ".": # point
# Size as in plot OpenGL backend, mimic matplotlib
- return 'o', numpy.ceil(0.5 * size) + 1.
+ return "o", numpy.ceil(0.5 * size) + 1.0
else:
return symbol, size
@@ -189,18 +187,24 @@ class SymbolMixIn(_SymbolMixIn):
class PlaneMixIn(ItemMixInBase):
"""Mix-in class for plane items (based on PlaneInGroup primitive)"""
- def __init__(self, plane):
+ def __init__(self):
+ self.__plane = None
+ self._setPlane(primitives.PlaneInGroup())
+
+ def _setPlane(self, plane: primitives.PlaneInGroup):
+ """Set plane primitive"""
+ if self.__plane is not None:
+ self.__plane.removeListener(self._planeChanged)
+ self.__plane.plane.removeListener(self._planePositionChanged)
+
assert isinstance(plane, primitives.PlaneInGroup)
self.__plane = plane
- self.__plane.alpha = 1.
+ self.__plane.alpha = 1.0
self.__plane.addListener(self._planeChanged)
self.__plane.plane.addListener(self._planePositionChanged)
- def _getPlane(self):
- """Returns plane primitive
-
- :rtype: primitives.PlaneInGroup
- """
+ def _getPlane(self) -> primitives.PlaneInGroup:
+ """Returns plane primitive"""
return self.__plane
def _planeChanged(self, source, *args, **kwargs):
@@ -211,7 +215,9 @@ class PlaneMixIn(ItemMixInBase):
def _planePositionChanged(self, source, *args, **kwargs):
"""Handle update of cut plane position and normal"""
- if self.__plane.visible: # TODO send even if hidden? or send also when showing if moved while hidden
+ if (
+ self.__plane.visible
+ ): # TODO send even if hidden? or send also when showing if moved while hidden
self._updated(ItemChangedType.POSITION)
# Plane position
@@ -283,5 +289,5 @@ class PlaneMixIn(ItemMixInBase):
:param color: RGBA color as 4 floats in [0, 1]
"""
self.__plane.color = rgba(color)
- if hasattr(super(PlaneMixIn, self), '_setForegroundColor'):
+ if hasattr(super(PlaneMixIn, self), "_setForegroundColor"):
super(PlaneMixIn, self)._setForegroundColor(color)
diff --git a/src/silx/gui/plot3d/items/scatter.py b/src/silx/gui/plot3d/items/scatter.py
index c93db88..b8f2f39 100644
--- a/src/silx/gui/plot3d/items/scatter.py
+++ b/src/silx/gui/plot3d/items/scatter.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2017-2020 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -28,16 +28,13 @@ __authors__ = ["T. Vincent"]
__license__ = "MIT"
__date__ = "15/11/2017"
-try:
- from collections import abc
-except ImportError: # Python2 support
- import collections as abc
+from collections import abc
import logging
+import sys
import numpy
+from matplotlib.tri import Triangulation
-from ....utils.deprecation import deprecated
from ... import _glutils as glu
-from ...plot._utils.delaunay import delaunay
from ..scene import function, primitives, utils
from ...plot.items import ScatterVisualizationMixIn
@@ -65,7 +62,8 @@ class Scatter3D(DataItem3D, ColormapMixIn, SymbolMixIn):
noData = numpy.zeros((0, 1), dtype=numpy.float32)
symbol, size = self._getSceneSymbol()
self._scatter = primitives.Points(
- x=noData, y=noData, z=noData, value=noData, size=size)
+ x=noData, y=noData, z=noData, value=noData, size=size
+ )
self._scatter.marker = symbol
self._getScenePrimitive().children.append(self._scatter)
@@ -77,7 +75,7 @@ class Scatter3D(DataItem3D, ColormapMixIn, SymbolMixIn):
if event in (ItemChangedType.SYMBOL, ItemChangedType.SYMBOL_SIZE):
symbol, size = self._getSceneSymbol()
self._scatter.marker = symbol
- self._scatter.setAttribute('size', size, copy=True)
+ self._scatter.setAttribute("size", size, copy=True)
super(Scatter3D, self)._updated(event)
@@ -92,10 +90,10 @@ class Scatter3D(DataItem3D, ColormapMixIn, SymbolMixIn):
True (default) to copy the data,
False to use provided data (do not modify!)
"""
- self._scatter.setAttribute('x', x, copy=copy)
- self._scatter.setAttribute('y', y, copy=copy)
- self._scatter.setAttribute('z', z, copy=copy)
- self._scatter.setAttribute('value', value, copy=copy)
+ self._scatter.setAttribute("x", x, copy=copy)
+ self._scatter.setAttribute("y", y, copy=copy)
+ self._scatter.setAttribute("z", z, copy=copy)
+ self._scatter.setAttribute("value", value, copy=copy)
self._setColormappedData(self.getValueData(copy=False), copy=False)
self._updated(ItemChangedType.DATA)
@@ -107,10 +105,12 @@ class Scatter3D(DataItem3D, ColormapMixIn, SymbolMixIn):
False to return internal data (do not modify!)
:return: (x, y, z, value)
"""
- return (self.getXData(copy),
- self.getYData(copy),
- self.getZData(copy),
- self.getValueData(copy))
+ return (
+ self.getXData(copy),
+ self.getYData(copy),
+ self.getZData(copy),
+ self.getValueData(copy),
+ )
def getXData(self, copy=True):
"""Returns X data coordinates.
@@ -120,7 +120,7 @@ class Scatter3D(DataItem3D, ColormapMixIn, SymbolMixIn):
:return: X coordinates
:rtype: numpy.ndarray
"""
- return self._scatter.getAttribute('x', copy=copy).reshape(-1)
+ return self._scatter.getAttribute("x", copy=copy).reshape(-1)
def getYData(self, copy=True):
"""Returns Y data coordinates.
@@ -130,7 +130,7 @@ class Scatter3D(DataItem3D, ColormapMixIn, SymbolMixIn):
:return: Y coordinates
:rtype: numpy.ndarray
"""
- return self._scatter.getAttribute('y', copy=copy).reshape(-1)
+ return self._scatter.getAttribute("y", copy=copy).reshape(-1)
def getZData(self, copy=True):
"""Returns Z data coordinates.
@@ -140,7 +140,7 @@ class Scatter3D(DataItem3D, ColormapMixIn, SymbolMixIn):
:return: Z coordinates
:rtype: numpy.ndarray
"""
- return self._scatter.getAttribute('z', copy=copy).reshape(-1)
+ return self._scatter.getAttribute("z", copy=copy).reshape(-1)
def getValueData(self, copy=True):
"""Returns data values.
@@ -150,14 +150,9 @@ class Scatter3D(DataItem3D, ColormapMixIn, SymbolMixIn):
:return: data values
:rtype: numpy.ndarray
"""
- return self._scatter.getAttribute('value', copy=copy).reshape(-1)
+ return self._scatter.getAttribute("value", copy=copy).reshape(-1)
- @deprecated(reason="Consistency with PlotWidget items",
- replacement="getValueData", since_version="0.10.0")
- def getValues(self, copy=True):
- return self.getValueData(copy)
-
- def _pickFull(self, context, threshold=0., sort='depth'):
+ def _pickFull(self, context, threshold=0.0, sort="depth"):
"""Perform picking in this item at given widget position.
:param PickContext context: Current picking context
@@ -171,9 +166,9 @@ class Scatter3D(DataItem3D, ColormapMixIn, SymbolMixIn):
:return: Object holding the results or None
:rtype: Union[None,PickingResult]
"""
- assert sort in ('index', 'depth')
+ assert sort in ("index", "depth")
- rayNdc = context.getPickingSegment(frame='ndc')
+ rayNdc = context.getPickingSegment(frame="ndc")
if rayNdc is None: # No picking outside viewport
return None
@@ -184,49 +179,57 @@ class Scatter3D(DataItem3D, ColormapMixIn, SymbolMixIn):
primitive = self._getScenePrimitive()
- dataPoints = numpy.transpose((xData,
- self.getYData(copy=False),
- self.getZData(copy=False),
- numpy.ones_like(xData)))
+ dataPoints = numpy.transpose(
+ (
+ xData,
+ self.getYData(copy=False),
+ self.getZData(copy=False),
+ numpy.ones_like(xData),
+ )
+ )
pointsNdc = primitive.objectToNDCTransform.transformPoints(
- dataPoints, perspectiveDivide=True)
+ dataPoints, perspectiveDivide=True
+ )
# Perform picking
distancesNdc = numpy.abs(pointsNdc[:, :2] - rayNdc[0, :2])
# TODO issue with symbol size: using pixel instead of points
threshold += self.getSymbolSize()
- thresholdNdc = 2. * threshold / numpy.array(primitive.viewport.size)
- picked = numpy.where(numpy.logical_and(
+ thresholdNdc = 2.0 * threshold / numpy.array(primitive.viewport.size)
+ picked = numpy.where(
+ numpy.logical_and(
numpy.all(distancesNdc < thresholdNdc, axis=1),
- numpy.logical_and(rayNdc[0, 2] <= pointsNdc[:, 2],
- pointsNdc[:, 2] <= rayNdc[1, 2])))[0]
+ numpy.logical_and(
+ rayNdc[0, 2] <= pointsNdc[:, 2], pointsNdc[:, 2] <= rayNdc[1, 2]
+ ),
+ )
+ )[0]
- if sort == 'depth':
+ if sort == "depth":
# Sort picked points from front to back
picked = picked[numpy.argsort(pointsNdc[picked, 2])]
if picked.size > 0:
- return PickingResult(self,
- positions=dataPoints[picked, :3],
- indices=picked,
- fetchdata=self.getValueData)
+ return PickingResult(
+ self,
+ positions=dataPoints[picked, :3],
+ indices=picked,
+ fetchdata=self.getValueData,
+ )
else:
return None
-class Scatter2D(DataItem3D, ColormapMixIn, SymbolMixIn,
- ScatterVisualizationMixIn):
+class Scatter2D(DataItem3D, ColormapMixIn, SymbolMixIn, ScatterVisualizationMixIn):
"""2D scatter data with settable visualization mode.
:param parent: The View widget this item belongs to.
"""
_VISUALIZATION_PROPERTIES = {
- ScatterVisualizationMixIn.Visualization.POINTS:
- ('symbol', 'symbolSize'),
- ScatterVisualizationMixIn.Visualization.LINES:
- ('lineWidth',),
+ ScatterVisualizationMixIn.Visualization.POINTS: ("symbol", "symbolSize"),
+ ScatterVisualizationMixIn.Visualization.LINES: ("lineWidth",),
ScatterVisualizationMixIn.Visualization.SOLID: (),
}
"""Dict {visualization mode: property names used in this mode}"""
@@ -241,7 +244,7 @@ class Scatter2D(DataItem3D, ColormapMixIn, SymbolMixIn,
ScatterVisualizationMixIn.__init__(self)
self._heightMap = False
- self._lineWidth = 1.
+ self._lineWidth = 1.0
self._x = numpy.zeros((0,), dtype=numpy.float32)
self._y = numpy.zeros((0,), dtype=numpy.float32)
@@ -260,7 +263,7 @@ class Scatter2D(DataItem3D, ColormapMixIn, SymbolMixIn,
for child in self._getScenePrimitive().children:
if isinstance(child, primitives.Points):
child.marker = symbol
- child.setAttribute('size', size, copy=True)
+ child.setAttribute("size", size, copy=True)
elif event is ItemChangedType.VISIBLE:
# TODO smart update?, need dirty flags
@@ -281,7 +284,7 @@ class Scatter2D(DataItem3D, ColormapMixIn, SymbolMixIn,
By default, it is the current visualization mode.
:return:
"""
- assert name in ('lineWidth', 'symbol', 'symbolSize')
+ assert name in ("lineWidth", "symbol", "symbolSize")
if visualization is None:
visualization = self.getVisualization()
assert visualization in self.supportedVisualizations()
@@ -322,11 +325,11 @@ class Scatter2D(DataItem3D, ColormapMixIn, SymbolMixIn,
:param float width: Width in pixels
"""
width = float(width)
- assert width >= 1.
+ assert width >= 1.0
if width != self._lineWidth:
self._lineWidth = width
for child in self._getScenePrimitive().children:
- if hasattr(child, 'lineWidth'):
+ if hasattr(child, "lineWidth"):
child.lineWidth = width
self._updated(ItemChangedType.LINE_WIDTH)
@@ -342,15 +345,14 @@ class Scatter2D(DataItem3D, ColormapMixIn, SymbolMixIn,
True (default) to make a copy of the data,
False to avoid copy if possible (do not modify the arrays).
"""
- x = numpy.array(
- x, copy=copy, dtype=numpy.float32, order='C').reshape(-1)
- y = numpy.array(
- y, copy=copy, dtype=numpy.float32, order='C').reshape(-1)
+ x = numpy.array(x, copy=copy, dtype=numpy.float32, order="C").reshape(-1)
+ y = numpy.array(y, copy=copy, dtype=numpy.float32, order="C").reshape(-1)
assert len(x) == len(y)
if isinstance(value, abc.Iterable):
value = numpy.array(
- value, copy=copy, dtype=numpy.float32, order='C').reshape(-1)
+ value, copy=copy, dtype=numpy.float32, order="C"
+ ).reshape(-1)
assert len(value) == len(x)
else: # Single scalar
value = numpy.array((float(value),), dtype=numpy.float32)
@@ -376,9 +378,11 @@ class Scatter2D(DataItem3D, ColormapMixIn, SymbolMixIn,
False to return internal data (do not modify!)
:return: (x, y, value)
"""
- return (self.getXData(copy=copy),
- self.getYData(copy=copy),
- self.getValueData(copy=copy))
+ return (
+ self.getXData(copy=copy),
+ self.getYData(copy=copy),
+ self.getValueData(copy=copy),
+ )
def getXData(self, copy=True):
"""Returns X data coordinates.
@@ -410,12 +414,7 @@ class Scatter2D(DataItem3D, ColormapMixIn, SymbolMixIn,
"""
return numpy.array(self._value, copy=copy)
- @deprecated(reason="Consistency with PlotWidget items",
- replacement="getValueData", since_version="0.10.0")
- def getValues(self, copy=True):
- return self.getValueData(copy)
-
- def _pickPoints(self, context, points, threshold=1., sort='depth'):
+ def _pickPoints(self, context, points, threshold=1.0, sort="depth"):
"""Perform picking while in 'points' visualization mode
:param PickContext context: Current picking context
@@ -429,34 +428,41 @@ class Scatter2D(DataItem3D, ColormapMixIn, SymbolMixIn,
:return: Object holding the results or None
:rtype: Union[None,PickingResult]
"""
- assert sort in ('index', 'depth')
+ assert sort in ("index", "depth")
- rayNdc = context.getPickingSegment(frame='ndc')
+ rayNdc = context.getPickingSegment(frame="ndc")
if rayNdc is None: # No picking outside viewport
return None
# Project data to NDC
primitive = self._getScenePrimitive()
pointsNdc = primitive.objectToNDCTransform.transformPoints(
- points, perspectiveDivide=True)
+ points, perspectiveDivide=True
+ )
# Perform picking
distancesNdc = numpy.abs(pointsNdc[:, :2] - rayNdc[0, :2])
thresholdNdc = threshold / numpy.array(primitive.viewport.size)
- picked = numpy.where(numpy.logical_and(
- numpy.all(distancesNdc < thresholdNdc, axis=1),
- numpy.logical_and(rayNdc[0, 2] <= pointsNdc[:, 2],
- pointsNdc[:, 2] <= rayNdc[1, 2])))[0]
+ picked = numpy.where(
+ numpy.logical_and(
+ numpy.all(distancesNdc < thresholdNdc, axis=1),
+ numpy.logical_and(
+ rayNdc[0, 2] <= pointsNdc[:, 2], pointsNdc[:, 2] <= rayNdc[1, 2]
+ ),
+ )
+ )[0]
- if sort == 'depth':
+ if sort == "depth":
# Sort picked points from front to back
picked = picked[numpy.argsort(pointsNdc[picked, 2])]
if picked.size > 0:
- return PickingResult(self,
- positions=points[picked, :3],
- indices=picked,
- fetchdata=self.getValueData)
+ return PickingResult(
+ self,
+ positions=points[picked, :3],
+ indices=picked,
+ fetchdata=self.getValueData,
+ )
else:
return None
@@ -477,7 +483,8 @@ class Scatter2D(DataItem3D, ColormapMixIn, SymbolMixIn,
trianglesIndices = self._cachedTrianglesIndices.reshape(-1, 3)
triangles = points[trianglesIndices, :3]
selectedIndices, t, barycentric = glu.segmentTrianglesIntersection(
- rayObject, triangles)
+ rayObject, triangles
+ )
closest = numpy.argmax(barycentric, axis=1)
indices = trianglesIndices.reshape(-1, 3)[selectedIndices, closest]
@@ -488,10 +495,9 @@ class Scatter2D(DataItem3D, ColormapMixIn, SymbolMixIn,
# Compute intersection points and get closest data point
positions = t.reshape(-1, 1) * (rayObject[1] - rayObject[0]) + rayObject[0]
- return PickingResult(self,
- positions=positions,
- indices=indices,
- fetchdata=self.getValueData)
+ return PickingResult(
+ self, positions=positions, indices=indices, fetchdata=self.getValueData
+ )
def _pickFull(self, context):
"""Perform picking in this item at given widget position.
@@ -509,22 +515,20 @@ class Scatter2D(DataItem3D, ColormapMixIn, SymbolMixIn,
else:
zData = numpy.zeros_like(xData)
- points = numpy.transpose((xData,
- self.getYData(copy=False),
- zData,
- numpy.ones_like(xData)))
+ points = numpy.transpose(
+ (xData, self.getYData(copy=False), zData, numpy.ones_like(xData))
+ )
mode = self.getVisualization()
if mode is self.Visualization.POINTS:
# TODO issue with symbol size: using pixel instead of points
# Get "corrected" symbol size
_, threshold = self._getSceneSymbol()
- return self._pickPoints(
- context, points, threshold=max(3., threshold))
+ return self._pickPoints(context, points, threshold=max(3.0, threshold))
elif mode is self.Visualization.LINES:
# Picking only at point
- return self._pickPoints(context, points, threshold=5.)
+ return self._pickPoints(context, points, threshold=5.0)
else: # mode == 'solid'
return self._pickSolid(context, points)
@@ -543,36 +547,38 @@ class Scatter2D(DataItem3D, ColormapMixIn, SymbolMixIn,
heightMap = self.isHeightMap()
if mode is self.Visualization.POINTS:
- z = value if heightMap else 0.
+ z = value if heightMap else 0.0
symbol, size = self._getSceneSymbol()
primitive = primitives.Points(
- x=x, y=y, z=z, value=value,
- size=size,
- colormap=self._getSceneColormap())
+ x=x, y=y, z=z, value=value, size=size, colormap=self._getSceneColormap()
+ )
primitive.marker = symbol
else:
# TODO run delaunay in a thread
# Compute lines/triangles indices if not cached
if self._cachedTrianglesIndices is None:
- triangulation = delaunay(x, y)
- if triangulation is None:
+ try:
+ triangulation = Triangulation(x, y)
+ except (RuntimeError, ValueError):
+ _logger.debug("Delaunay tesselation failed: %s", sys.exc_info()[1])
return None
self._cachedTrianglesIndices = numpy.ravel(
- triangulation.simplices.astype(numpy.uint32))
+ triangulation.triangles.astype(numpy.uint32)
+ )
- if (mode is self.Visualization.LINES and
- self._cachedLinesIndices is None):
+ if mode is self.Visualization.LINES and self._cachedLinesIndices is None:
# Compute line indices
self._cachedLinesIndices = utils.triangleToLineIndices(
- self._cachedTrianglesIndices, unicity=True)
+ self._cachedTrianglesIndices, unicity=True
+ )
if mode is self.Visualization.LINES:
indices = self._cachedLinesIndices
- renderMode = 'lines'
+ renderMode = "lines"
else:
indices = self._cachedTrianglesIndices
- renderMode = 'triangles'
+ renderMode = "triangles"
# TODO supports x, y instead of copy
if heightMap:
@@ -590,14 +596,15 @@ class Scatter2D(DataItem3D, ColormapMixIn, SymbolMixIn,
if len(value) > 1:
value = value[indices]
triangleNormals = utils.trianglesNormal(coordinates)
- normal = numpy.empty((len(triangleNormals) * 3, 3),
- dtype=numpy.float32)
+ normal = numpy.empty(
+ (len(triangleNormals) * 3, 3), dtype=numpy.float32
+ )
normal[0::3, :] = triangleNormals
normal[1::3, :] = triangleNormals
normal[2::3, :] = triangleNormals
indices = None
else:
- normal = (0., 0., 1.)
+ normal = (0.0, 0.0, 1.0)
else:
normal = None
@@ -607,7 +614,8 @@ class Scatter2D(DataItem3D, ColormapMixIn, SymbolMixIn,
normal=normal,
colormap=self._getSceneColormap(),
indices=indices,
- mode=renderMode)
+ mode=renderMode,
+ )
primitive.lineWidth = self.getLineWidth()
primitive.lineSmooth = False
diff --git a/src/silx/gui/plot3d/items/volume.py b/src/silx/gui/plot3d/items/volume.py
index b3007fa..7696794 100644
--- a/src/silx/gui/plot3d/items/volume.py
+++ b/src/silx/gui/plot3d/items/volume.py
@@ -58,12 +58,13 @@ class CutPlane(Item3D, ColormapMixIn, InterpolationMixIn, PlaneMixIn):
"""
def __init__(self, parent):
- plane = cutplane.CutPlane(normal=(0, 1, 0))
-
Item3D.__init__(self, parent=None)
ColormapMixIn.__init__(self)
InterpolationMixIn.__init__(self)
- PlaneMixIn.__init__(self, plane=plane)
+ PlaneMixIn.__init__(self)
+
+ plane = cutplane.CutPlane(normal=(0, 1, 0))
+ self._setPlane(plane)
self._dataRange = None
self._data = None
@@ -92,10 +93,13 @@ class CutPlane(Item3D, ColormapMixIn, InterpolationMixIn, PlaneMixIn):
self._dataRange = range_
if range_ is None:
range_ = None, None, None
- self._setColormappedData(self._data, copy=False,
- min_=range_[0],
- minPositive=range_[1],
- max_=range_[2])
+ self._setColormappedData(
+ self._data,
+ copy=False,
+ min_=range_[0],
+ minPositive=range_[1],
+ max_=range_[2],
+ )
self._updated(ItemChangedType.DATA)
@@ -184,10 +188,11 @@ class CutPlane(Item3D, ColormapMixIn, InterpolationMixIn, PlaneMixIn):
rayObject[0, :3],
rayObject[1, :3],
planeNorm=self.getNormal(),
- planePt=self.getPoint())
+ planePt=self.getPoint(),
+ )
if len(points) == 1: # Single intersection
- if numpy.any(points[0] < 0.):
+ if numpy.any(points[0] < 0.0):
return None # Outside volume
z, y, x = int(points[0][2]), int(points[0][1]), int(points[0][0])
@@ -197,9 +202,9 @@ class CutPlane(Item3D, ColormapMixIn, InterpolationMixIn, PlaneMixIn):
depth, height, width = data.shape
if z < depth and y < height and x < width:
- return PickingResult(self,
- positions=[points[0]],
- indices=([z], [y], [x]))
+ return PickingResult(
+ self, positions=[points[0]], indices=([z], [y], [x])
+ )
else:
return None # Outside image
else: # Either no intersection or segment and image are coplanar
@@ -215,9 +220,9 @@ class Isosurface(Item3D):
def __init__(self, parent):
Item3D.__init__(self, parent=None)
self._data = None
- self._level = float('nan')
+ self._level = float("nan")
self._autoLevelFunction = None
- self._color = rgba('#FFD700FF')
+ self._color = rgba("#FFD700FF")
self.setParent(parent)
def _syncDataWithParent(self):
@@ -310,7 +315,7 @@ class Isosurface(Item3D):
"""
primitive = self._getScenePrimitive()
if len(primitive.children) != 0:
- primitive.children[0].setAttribute('color', color)
+ primitive.children[0].setAttribute("color", color)
def setColor(self, color):
"""Set the color of the iso-surface
@@ -334,7 +339,7 @@ class Isosurface(Item3D):
if data is None:
if self.isAutoLevel():
- self._level = float('nan')
+ self._level = float("nan")
else:
if self.isAutoLevel():
@@ -349,12 +354,12 @@ class Isosurface(Item3D):
"Error while executing iso level function %s.%s",
module_,
name,
- exc_info=True)
- level = float('nan')
+ exc_info=True,
+ )
+ level = float("nan")
else:
- _logger.info(
- 'Computed iso-level in %f s.', time.time() - st)
+ _logger.info("Computed iso-level in %f s.", time.time() - st)
if level != self._level:
self._level = level
@@ -362,10 +367,8 @@ class Isosurface(Item3D):
if numpy.isfinite(self._level):
st = time.time()
- vertices, normals, indices = MarchingCubes(
- data,
- isolevel=self._level)
- _logger.info('Computed iso-surface in %f s.', time.time() - st)
+ vertices, normals, indices = MarchingCubes(data, isolevel=self._level)
+ _logger.info("Computed iso-surface in %f s.", time.time() - st)
if len(vertices) != 0:
return vertices, normals, indices
@@ -378,12 +381,14 @@ class Isosurface(Item3D):
vertices, normals, indices = self._computeIsosurface()
if vertices is not None:
- mesh = primitives.Mesh3D(vertices,
- colors=self._color,
- normals=normals,
- mode='triangles',
- indices=indices,
- copy=False)
+ mesh = primitives.Mesh3D(
+ vertices,
+ colors=self._color,
+ normals=normals,
+ mode="triangles",
+ indices=indices,
+ copy=False,
+ )
self._getScenePrimitive().children = [mesh]
def _pickFull(self, context):
@@ -399,8 +404,7 @@ class Isosurface(Item3D):
rayObject = rayObject[:, :3]
data = self.getData(copy=False)
- bins = utils.segmentVolumeIntersect(
- rayObject, numpy.array(data.shape) - 1)
+ bins = utils.segmentVolumeIntersect(rayObject, numpy.array(data.shape) - 1)
if bins is None:
return None
@@ -413,8 +417,10 @@ class Isosurface(Item3D):
# check bin candidates
level = self.getLevel()
- mask = numpy.logical_and(numpy.nanmin(binsData, axis=1) <= level,
- level <= numpy.nanmax(binsData, axis=1))
+ mask = numpy.logical_and(
+ numpy.nanmin(binsData, axis=1) <= level,
+ level <= numpy.nanmax(binsData, axis=1),
+ )
bins = bins[mask]
binsData = binsData[mask]
@@ -476,19 +482,23 @@ class ScalarField3D(BaseNodeItem):
self._isogroup = primitives.GroupDepthOffset()
self._isogroup.transforms = [
# Convert from z, y, x from marching cubes to x, y, z
- transform.Matrix((
- (0., 0., 1., 0.),
- (0., 1., 0., 0.),
- (1., 0., 0., 0.),
- (0., 0., 0., 1.))),
+ transform.Matrix(
+ (
+ (0.0, 0.0, 1.0, 0.0),
+ (0.0, 1.0, 0.0, 0.0),
+ (1.0, 0.0, 0.0, 0.0),
+ (0.0, 0.0, 0.0, 1.0),
+ )
+ ),
# Offset to match cutting plane coords
- transform.Translate(0.5, 0.5, 0.5)
+ transform.Translate(0.5, 0.5, 0.5),
]
self._getScenePrimitive().children = [
self._boundedGroup,
self._cutPlane._getScenePrimitive(),
- self._isogroup]
+ self._isogroup,
+ ]
@staticmethod
def _computeRangeFromData(data):
@@ -507,7 +517,7 @@ class ScalarField3D(BaseNodeItem):
if dataRange is not None:
min_positive = dataRange.min_positive
if min_positive is None:
- min_positive = float('nan')
+ min_positive = float("nan")
return dataRange.minimum, min_positive, dataRange.maximum
def setData(self, data, copy=True):
@@ -526,7 +536,7 @@ class ScalarField3D(BaseNodeItem):
self._boundedGroup.shape = None
else:
- data = numpy.array(data, copy=copy, dtype=numpy.float32, order='C')
+ data = numpy.array(data, copy=copy, dtype=numpy.float32, order="C")
assert data.ndim == 3
assert min(data.shape) >= 2
@@ -625,8 +635,8 @@ class ScalarField3D(BaseNodeItem):
"""
if isosurface not in self.getIsosurfaces():
_logger.warning(
- "Try to remove isosurface that is not in the list: %s",
- str(isosurface))
+ "Try to remove isosurface that is not in the list: %s", str(isosurface)
+ )
else:
isosurface.sigItemChanged.disconnect(self._isosurfaceItemChanged)
self._isosurfaces.remove(isosurface)
@@ -646,8 +656,9 @@ class ScalarField3D(BaseNodeItem):
def _updateIsosurfaces(self):
"""Handle updates of iso-surfaces level and add/remove"""
# Sorting using minus, this supposes data 'object' to be max values
- sortedIso = sorted(self.getIsosurfaces(),
- key=lambda isosurface: - isosurface.getLevel())
+ sortedIso = sorted(
+ self.getIsosurfaces(), key=lambda isosurface: -isosurface.getLevel()
+ )
self._isogroup.children = [iso._getScenePrimitive() for iso in sortedIso]
# BaseNodeItem
@@ -664,6 +675,7 @@ class ScalarField3D(BaseNodeItem):
# ComplexField3D #
##################
+
class ComplexCutPlane(CutPlane, ComplexMixIn):
"""Class representing a cutting plane in a :class:`ComplexField3D` item.
@@ -701,8 +713,9 @@ class ComplexIsosurface(Isosurface, ComplexMixIn, ColormapMixIn):
:param parent: The DataItem3D this iso-surface belongs to
"""
- _SUPPORTED_COMPLEX_MODES = \
- (ComplexMixIn.ComplexMode.NONE,) + ComplexMixIn._SUPPORTED_COMPLEX_MODES
+ _SUPPORTED_COMPLEX_MODES = (
+ ComplexMixIn.ComplexMode.NONE,
+ ) + ComplexMixIn._SUPPORTED_COMPLEX_MODES
"""Overrides supported ComplexMode"""
def __init__(self, parent):
@@ -717,8 +730,9 @@ class ComplexIsosurface(Isosurface, ComplexMixIn, ColormapMixIn):
:param List[float] color: RGBA channels in [0, 1]
"""
primitive = self._getScenePrimitive()
- if (len(primitive.children) != 0 and
- isinstance(primitive.children[0], primitives.ColormapMesh3D)):
+ if len(primitive.children) != 0 and isinstance(
+ primitive.children[0], primitives.ColormapMesh3D
+ ):
primitive.children[0].alpha = self._color[3]
else:
super(ComplexIsosurface, self)._updateColor(color)
@@ -729,15 +743,14 @@ class ComplexIsosurface(Isosurface, ComplexMixIn, ColormapMixIn):
if parent is None:
self._data = None
else:
- self._data = parent.getData(
- mode=parent.getComplexMode(), copy=False)
+ self._data = parent.getData(mode=parent.getComplexMode(), copy=False)
if parent is None or self.getComplexMode() == self.ComplexMode.NONE:
self._setColormappedData(None, copy=False)
else:
self._setColormappedData(
- parent.getData(mode=self.getComplexMode(), copy=False),
- copy=False)
+ parent.getData(mode=self.getComplexMode(), copy=False), copy=False
+ )
self._updateScenePrimitive()
@@ -755,8 +768,7 @@ class ComplexIsosurface(Isosurface, ComplexMixIn, ColormapMixIn):
if event == ItemChangedType.COMPLEX_MODE:
self._syncDataWithParent()
- elif event in (ItemChangedType.COLORMAP,
- Item3DChangedType.INTERPOLATION):
+ elif event in (ItemChangedType.COLORMAP, Item3DChangedType.INTERPOLATION):
self._updateScenePrimitive()
super(ComplexIsosurface, self)._updated(event)
@@ -772,7 +784,7 @@ class ComplexIsosurface(Isosurface, ComplexMixIn, ColormapMixIn):
if values is not None:
vertices, normals, indices = self._computeIsosurface()
if vertices is not None:
- values = interp3d(values, vertices, method='linear_omp')
+ values = interp3d(values, vertices, method="linear_omp")
# TODO reuse isosurface when only color changes...
mesh = primitives.ColormapMesh3D(
@@ -780,9 +792,10 @@ class ComplexIsosurface(Isosurface, ComplexMixIn, ColormapMixIn):
value=values.reshape(-1, 1),
colormap=self._getSceneColormap(),
normal=normals,
- mode='triangles',
+ mode="triangles",
indices=indices,
- copy=False)
+ copy=False,
+ )
mesh.alpha = self._color[3]
self._getScenePrimitive().children = [mesh]
@@ -826,7 +839,7 @@ class ComplexField3D(ScalarField3D, ComplexMixIn):
self._boundedGroup.shape = None
else:
- data = numpy.array(data, copy=copy, dtype=numpy.complex64, order='C')
+ data = numpy.array(data, copy=copy, dtype=numpy.complex64, order="C")
assert data.ndim == 3
assert min(data.shape) >= 2
diff --git a/src/silx/gui/plot3d/scene/axes.py b/src/silx/gui/plot3d/scene/axes.py
index 9f6ac6c..9102732 100644
--- a/src/silx/gui/plot3d/scene/axes.py
+++ b/src/silx/gui/plot3d/scene/axes.py
@@ -40,40 +40,37 @@ _logger = logging.getLogger(__name__)
class LabelledAxes(primitives.GroupBBox):
- """A group displaying a bounding box with axes labels around its children.
- """
+ """A group displaying a bounding box with axes labels around its children."""
def __init__(self):
super(LabelledAxes, self).__init__()
self._ticksForBounds = None
- self._font = text.Font()
+ self._font = text.Font(size=10)
self._boxVisibility = True
# TODO offset labels from anchor in pixels
self._xlabel = text.Text2D(font=self._font)
- self._xlabel.align = 'center'
- self._xlabel.transforms = [self._boxTransforms,
- transform.Translate(tx=0.5)]
+ self._xlabel.align = "center"
+ self._xlabel.transforms = [self._boxTransforms, transform.Translate(tx=0.5)]
self._children.insert(-1, self._xlabel)
self._ylabel = text.Text2D(font=self._font)
- self._ylabel.align = 'center'
- self._ylabel.transforms = [self._boxTransforms,
- transform.Translate(ty=0.5)]
+ self._ylabel.align = "center"
+ self._ylabel.transforms = [self._boxTransforms, transform.Translate(ty=0.5)]
self._children.insert(-1, self._ylabel)
self._zlabel = text.Text2D(font=self._font)
- self._zlabel.align = 'center'
- self._zlabel.transforms = [self._boxTransforms,
- transform.Translate(tz=0.5)]
+ self._zlabel.align = "center"
+ self._zlabel.transforms = [self._boxTransforms, transform.Translate(tz=0.5)]
self._children.insert(-1, self._zlabel)
# Init tick lines with dummy pos
self._tickLines = primitives.DashedLines(
- positions=((0., 0., 0.), (0., 0., 0.)))
+ positions=((0.0, 0.0, 0.0), (0.0, 0.0, 0.0))
+ )
self._tickLines.dash = 5, 10
self._tickLines.visible = False
self._children.insert(-1, self._tickLines)
@@ -82,7 +79,7 @@ class LabelledAxes(primitives.GroupBBox):
self._children.insert(-1, self._tickLabels)
# Sync color
- self.tickColor = 1., 1., 1., 1.
+ self.tickColor = 1.0, 1.0, 1.0, 1.0
def _updateBoxAndAxes(self):
"""Update bbox and axes position and size according to children.
@@ -93,7 +90,7 @@ class LabelledAxes(primitives.GroupBBox):
bounds = self._group.bounds(dataBounds=True)
if bounds is not None:
- tx, ty, tz = (bounds[1] - bounds[0]) / 2.
+ tx, ty, tz = (bounds[1] - bounds[0]) / 2.0
else:
tx, ty, tz = 0.5, 0.5, 0.5
@@ -116,7 +113,7 @@ class LabelledAxes(primitives.GroupBBox):
self._ylabel.foreground = color
self._zlabel.foreground = color
transparentColor = color[0], color[1], color[2], color[3] * 0.6
- self._tickLines.setAttribute('color', transparentColor)
+ self._tickLines.setAttribute("color", transparentColor)
for label in self._tickLabels.children:
label.foreground = color
@@ -185,8 +182,9 @@ class LabelledAxes(primitives.GroupBBox):
self._tickLines.visible = False
self._tickLabels.children = [] # Reset previous labels
- elif (self._ticksForBounds is None or
- not numpy.all(numpy.equal(bounds, self._ticksForBounds))):
+ elif self._ticksForBounds is None or not numpy.all(
+ numpy.equal(bounds, self._ticksForBounds)
+ ):
self._ticksForBounds = bounds
# Update ticks
@@ -198,21 +196,21 @@ class LabelledAxes(primitives.GroupBBox):
# Update tick lines
coords = numpy.empty(
- ((len(xticks) + len(yticks) + len(zticks)), 4, 3),
- dtype=numpy.float32)
+ ((len(xticks) + len(yticks) + len(zticks)), 4, 3), dtype=numpy.float32
+ )
coords[:, :, :] = bounds[0, :] # account for offset from origin
- xcoords = coords[:len(xticks)]
+ xcoords = coords[: len(xticks)]
xcoords[:, :, 0] = numpy.asarray(xticks)[:, numpy.newaxis]
xcoords[:, 1, 1] += ticklength[1] # X ticks on XY plane
xcoords[:, 3, 2] += ticklength[2] # X ticks on XZ plane
- ycoords = coords[len(xticks):len(xticks) + len(yticks)]
+ ycoords = coords[len(xticks) : len(xticks) + len(yticks)]
ycoords[:, :, 1] = numpy.asarray(yticks)[:, numpy.newaxis]
ycoords[:, 1, 0] += ticklength[0] # Y ticks on XY plane
ycoords[:, 3, 2] += ticklength[2] # Y ticks on YZ plane
- zcoords = coords[len(xticks) + len(yticks):]
+ zcoords = coords[len(xticks) + len(yticks) :]
zcoords[:, :, 2] = numpy.asarray(zticks)[:, numpy.newaxis]
zcoords[:, 1, 0] += ticklength[0] # Z ticks on XZ plane
zcoords[:, 3, 1] += ticklength[1] # Z ticks on YZ plane
@@ -222,30 +220,36 @@ class LabelledAxes(primitives.GroupBBox):
# Update labels
color = self.tickColor
- offsets = bounds[0] - ticklength / 20.
+ offsets = bounds[0] - ticklength / 20.0
labels = []
for tick, label in zip(xticks, xlabels):
text2d = text.Text2D(text=label, font=self.font)
- text2d.align = 'center'
+ text2d.align = "center"
+ text2d.valign = "center"
text2d.foreground = color
- text2d.transforms = [transform.Translate(
- tx=tick, ty=offsets[1], tz=offsets[2])]
+ text2d.transforms = [
+ transform.Translate(tx=tick, ty=offsets[1], tz=offsets[2])
+ ]
labels.append(text2d)
for tick, label in zip(yticks, ylabels):
text2d = text.Text2D(text=label, font=self.font)
- text2d.align = 'center'
+ text2d.align = "center"
+ text2d.valign = "center"
text2d.foreground = color
- text2d.transforms = [transform.Translate(
- tx=offsets[0], ty=tick, tz=offsets[2])]
+ text2d.transforms = [
+ transform.Translate(tx=offsets[0], ty=tick, tz=offsets[2])
+ ]
labels.append(text2d)
for tick, label in zip(zticks, zlabels):
text2d = text.Text2D(text=label, font=self.font)
- text2d.align = 'center'
+ text2d.align = "center"
+ text2d.valign = "center"
text2d.foreground = color
- text2d.transforms = [transform.Translate(
- tx=offsets[0], ty=offsets[1], tz=tick)]
+ text2d.transforms = [
+ transform.Translate(tx=offsets[0], ty=offsets[1], tz=tick)
+ ]
labels.append(text2d)
self._tickLabels.children = labels # Reset previous labels
diff --git a/src/silx/gui/plot3d/scene/camera.py b/src/silx/gui/plot3d/scene/camera.py
index a6bc642..5248c39 100644
--- a/src/silx/gui/plot3d/scene/camera.py
+++ b/src/silx/gui/plot3d/scene/camera.py
@@ -35,6 +35,7 @@ from . import transform
# CameraExtrinsic #############################################################
+
class CameraExtrinsic(transform.Transform):
"""Transform matrix to handle camera position and orientation.
@@ -46,21 +47,19 @@ class CameraExtrinsic(transform.Transform):
:type up: numpy.ndarray-like of 3 float32.
"""
- def __init__(self, position=(0., 0., 0.),
- direction=(0., 0., -1.),
- up=(0., 1., 0.)):
-
+ def __init__(
+ self, position=(0.0, 0.0, 0.0), direction=(0.0, 0.0, -1.0), up=(0.0, 1.0, 0.0)
+ ):
super(CameraExtrinsic, self).__init__()
self._position = None
self.position = position # set _position
- self._side = 1., 0., 0.
- self._up = 0., 1., 0.
- self._direction = 0., 0., -1.
+ self._side = 1.0, 0.0, 0.0
+ self._up = 0.0, 1.0, 0.0
+ self._direction = 0.0, 0.0, -1.0
self.setOrientation(direction=direction, up=up) # set _direction, _up
def _makeMatrix(self):
- return transform.mat4LookAtDir(self._position,
- self._direction, self._up)
+ return transform.mat4LookAtDir(self._position, self._direction, self._up)
def copy(self):
"""Return an independent copy"""
@@ -93,8 +92,8 @@ class CameraExtrinsic(transform.Transform):
# Update side and up to make sure they are perpendicular and normalized
side = numpy.cross(direction, up)
sidenormal = numpy.linalg.norm(side)
- if sidenormal == 0.:
- raise RuntimeError('direction and up vectors are parallel.')
+ if sidenormal == 0.0:
+ raise RuntimeError("direction and up vectors are parallel.")
# Alternative: when one of the input parameter is None, it is
# possible to guess correct vectors using previous direction and up
side /= sidenormal
@@ -128,8 +127,7 @@ class CameraExtrinsic(transform.Transform):
@property
def up(self):
- """Vector pointing upward in the image plane (ndarray of 3 float32).
- """
+ """Vector pointing upward in the image plane (ndarray of 3 float32)."""
return self._up.copy()
@up.setter
@@ -143,7 +141,7 @@ class CameraExtrinsic(transform.Transform):
ndarray of 3 float32"""
return self._side.copy()
- def move(self, direction, step=1.):
+ def move(self, direction, step=1.0):
"""Move the camera relative to the image plane.
:param str direction: Direction relative to image plane.
@@ -152,35 +150,35 @@ class CameraExtrinsic(transform.Transform):
:param float step: The step of the pan to perform in the coordinate
in which the camera position is defined.
"""
- if direction in ('up', 'down'):
- vector = self.up * (1. if direction == 'up' else -1.)
- elif direction in ('left', 'right'):
- vector = self.side * (1. if direction == 'right' else -1.)
- elif direction in ('forward', 'backward'):
- vector = self.direction * (1. if direction == 'forward' else -1.)
+ if direction in ("up", "down"):
+ vector = self.up * (1.0 if direction == "up" else -1.0)
+ elif direction in ("left", "right"):
+ vector = self.side * (1.0 if direction == "right" else -1.0)
+ elif direction in ("forward", "backward"):
+ vector = self.direction * (1.0 if direction == "forward" else -1.0)
else:
- raise ValueError('Unsupported direction: %s' % direction)
+ raise ValueError("Unsupported direction: %s" % direction)
self.position += step * vector
- def rotate(self, direction, angle=1.):
+ def rotate(self, direction, angle=1.0):
"""First-person rotation of the camera towards the direction.
:param str direction: Direction of movement relative to image plane.
In: 'up', 'down', 'left', 'right'.
:param float angle: The angle in degrees of the rotation.
"""
- if direction in ('up', 'down'):
- axis = self.side * (1. if direction == 'up' else -1.)
- elif direction in ('left', 'right'):
- axis = self.up * (1. if direction == 'left' else -1.)
+ if direction in ("up", "down"):
+ axis = self.side * (1.0 if direction == "up" else -1.0)
+ elif direction in ("left", "right"):
+ axis = self.up * (1.0 if direction == "left" else -1.0)
else:
- raise ValueError('Unsupported direction: %s' % direction)
+ raise ValueError("Unsupported direction: %s" % direction)
matrix = transform.mat4RotateFromAngleAxis(numpy.radians(angle), *axis)
newdir = numpy.dot(matrix[:3, :3], self.direction)
- if direction in ('up', 'down'):
+ if direction in ("up", "down"):
# Rotate up to avoid up and new direction to be (almost) co-linear
newup = numpy.dot(matrix[:3, :3], self.up)
self.setOrientation(newdir, newup)
@@ -188,7 +186,7 @@ class CameraExtrinsic(transform.Transform):
# No need to rotate up here as it is the rotation axis
self.direction = newdir
- def orbit(self, direction, center=(0., 0., 0.), angle=1.):
+ def orbit(self, direction, center=(0.0, 0.0, 0.0), angle=1.0):
"""Rotate the camera around a point.
:param str direction: Direction of movement relative to image plane.
@@ -197,33 +195,32 @@ class CameraExtrinsic(transform.Transform):
:type center: numpy.ndarray-like of 3 float32.
:param float angle: he angle in degrees of the rotation.
"""
- if direction in ('up', 'down'):
- axis = self.side * (1. if direction == 'down' else -1.)
- elif direction in ('left', 'right'):
- axis = self.up * (1. if direction == 'right' else -1.)
+ if direction in ("up", "down"):
+ axis = self.side * (1.0 if direction == "down" else -1.0)
+ elif direction in ("left", "right"):
+ axis = self.up * (1.0 if direction == "right" else -1.0)
else:
- raise ValueError('Unsupported direction: %s' % direction)
+ raise ValueError("Unsupported direction: %s" % direction)
# Rotate viewing direction
- rotmatrix = transform.mat4RotateFromAngleAxis(
- numpy.radians(angle), *axis)
+ rotmatrix = transform.mat4RotateFromAngleAxis(numpy.radians(angle), *axis)
self.direction = numpy.dot(rotmatrix[:3, :3], self.direction)
# Rotate position around center
center = numpy.array(center, copy=False, dtype=numpy.float32)
matrix = numpy.dot(transform.mat4Translate(*center), rotmatrix)
matrix = numpy.dot(matrix, transform.mat4Translate(*(-center)))
- position = numpy.append(self.position, 1.)
+ position = numpy.append(self.position, 1.0)
self.position = numpy.dot(matrix, position)[:3]
_RESET_CAMERA_ORIENTATIONS = {
- 'side': ((-1., -1., -1.), (0., 1., 0.)),
- 'front': ((0., 0., -1.), (0., 1., 0.)),
- 'back': ((0., 0., 1.), (0., 1., 0.)),
- 'top': ((0., -1., 0.), (0., 0., -1.)),
- 'bottom': ((0., 1., 0.), (0., 0., 1.)),
- 'right': ((-1., 0., 0.), (0., 1., 0.)),
- 'left': ((1., 0., 0.), (0., 1., 0.))
+ "side": ((-1.0, -1.0, -1.0), (0.0, 1.0, 0.0)),
+ "front": ((0.0, 0.0, -1.0), (0.0, 1.0, 0.0)),
+ "back": ((0.0, 0.0, 1.0), (0.0, 1.0, 0.0)),
+ "top": ((0.0, -1.0, 0.0), (0.0, 0.0, -1.0)),
+ "bottom": ((0.0, 1.0, 0.0), (0.0, 0.0, 1.0)),
+ "right": ((-1.0, 0.0, 0.0), (0.0, 1.0, 0.0)),
+ "left": ((1.0, 0.0, 0.0), (0.0, 1.0, 0.0)),
}
def reset(self, face=None):
@@ -233,12 +230,12 @@ class CameraExtrinsic(transform.Transform):
side, front, back, top, bottom, right, left.
"""
if face not in self._RESET_CAMERA_ORIENTATIONS:
- raise ValueError('Unsupported face: %s' % face)
+ raise ValueError("Unsupported face: %s" % face)
distance = numpy.linalg.norm(self.position)
direction, up = self._RESET_CAMERA_ORIENTATIONS[face]
self.setOrientation(direction, up)
- self.position = - self.direction * distance
+ self.position = -self.direction * distance
class Camera(transform.Transform):
@@ -260,9 +257,16 @@ class Camera(transform.Transform):
:type up: numpy.ndarray-like of 3 float32.
"""
- def __init__(self, fovy=30., near=0.1, far=1., size=(1., 1.),
- position=(0., 0., 0.),
- direction=(0., 0., -1.), up=(0., 1., 0.)):
+ def __init__(
+ self,
+ fovy=30.0,
+ near=0.1,
+ far=1.0,
+ size=(1.0, 1.0),
+ position=(0.0, 0.0, 0.0),
+ direction=(0.0, 0.0, -1.0),
+ up=(0.0, 1.0, 0.0),
+ ):
super(Camera, self).__init__()
self._intrinsic = transform.Perspective(fovy, near, far, size)
self._intrinsic.addListener(self._transformChanged)
@@ -289,8 +293,8 @@ class Camera(transform.Transform):
center = 0.5 * (bounds[0] + bounds[1])
radius = numpy.linalg.norm(0.5 * (bounds[1] - bounds[0]))
- if radius == 0.: # bounds are all collapsed
- radius = 1.
+ if radius == 0.0: # bounds are all collapsed
+ radius = 1.0
if isinstance(self.intrinsic, transform.Perspective):
# Get the viewpoint distance from the bounds center
@@ -302,8 +306,7 @@ class Camera(transform.Transform):
offset = radius / numpy.sin(0.5 * minfov)
# Update camera
- self.extrinsic.position = \
- center - offset * self.extrinsic.direction
+ self.extrinsic.position = center - offset * self.extrinsic.direction
self.intrinsic.setDepthExtent(offset - radius, offset + radius)
elif isinstance(self.intrinsic, transform.Orthographic):
@@ -312,14 +315,14 @@ class Camera(transform.Transform):
left=center[0] - radius,
right=center[0] + radius,
bottom=center[1] - radius,
- top=center[1] + radius)
+ top=center[1] + radius,
+ )
# Update camera
self.extrinsic.position = 0, 0, 0
- self.intrinsic.setDepthExtent(center[2] - radius,
- center[2] + radius)
+ self.intrinsic.setDepthExtent(center[2] - radius, center[2] + radius)
else:
- raise RuntimeError('Unsupported camera: %s' % self.intrinsic)
+ raise RuntimeError("Unsupported camera: %s" % self.intrinsic)
@property
def intrinsic(self):
diff --git a/src/silx/gui/plot3d/scene/core.py b/src/silx/gui/plot3d/scene/core.py
index c32a2c1..8773301 100644
--- a/src/silx/gui/plot3d/scene/core.py
+++ b/src/silx/gui/plot3d/scene/core.py
@@ -49,6 +49,7 @@ from .viewport import Viewport
# Nodes #######################################################################
+
class Base(event.Notifier):
"""A scene node with common features."""
@@ -64,10 +65,8 @@ class Base(event.Notifier):
# notifying properties
- visible = event.notifyProperty('_visible',
- doc="Visibility flag of the node")
- pickable = event.notifyProperty('_pickable',
- doc="True to make node pickable")
+ visible = event.notifyProperty("_visible", doc="Visibility flag of the node")
+ pickable = event.notifyProperty("_pickable", doc="True to make node pickable")
# Access to tree path
@@ -84,7 +83,7 @@ class Base(event.Notifier):
:param Base parent: The parent.
"""
if parent is not None and self._parentRef is not None:
- raise RuntimeError('Trying to add a node at two places.')
+ raise RuntimeError("Trying to add a node at two places.")
# Alternative: remove it from previous children list
self._parentRef = None if parent is None else weakref.ref(parent)
@@ -96,11 +95,11 @@ class Base(event.Notifier):
then the :class:`Viewport` is the first element of path.
"""
if self.parent is None:
- return self,
+ return (self,)
elif isinstance(self.parent, Viewport):
return self.parent, self
else:
- return self.parent.path + (self, )
+ return self.parent.path + (self,)
@property
def viewport(self):
@@ -154,7 +153,7 @@ class Base(event.Notifier):
# If it is a TransformList, do not create one to enable sharing.
self._transforms = iterable
else:
- assert hasattr(iterable, '__iter__')
+ assert hasattr(iterable, "__iter__")
self._transforms = transform.TransformList(iterable)
self._transforms.addListener(self._transformChanged)
@@ -163,8 +162,9 @@ class Base(event.Notifier):
# Bounds
- _CUBE_CORNERS = numpy.array(list(itertools.product((0., 1.), repeat=3)),
- dtype=numpy.float32)
+ _CUBE_CORNERS = numpy.array(
+ list(itertools.product((0.0, 1.0), repeat=3)), dtype=numpy.float32
+ )
"""Unit cube corners used to transform bounds"""
def _bounds(self, dataBounds=False):
@@ -256,7 +256,8 @@ class PrivateGroup(Base):
def _listWillChangeHook(self, methodName, *args, **kwargs):
super(PrivateGroup.ChildrenList, self)._listWillChangeHook(
- methodName, *args, **kwargs)
+ methodName, *args, **kwargs
+ )
for item in self:
item._setParent(None)
@@ -264,7 +265,8 @@ class PrivateGroup(Base):
for item in self:
item._setParent(self._parentRef())
super(PrivateGroup.ChildrenList, self)._listWasChangedHook(
- methodName, *args, **kwargs)
+ methodName, *args, **kwargs
+ )
def __init__(self, parent, children):
self._parentRef = weakref.ref(parent)
@@ -303,8 +305,7 @@ class PrivateGroup(Base):
bounds = []
for child in self._children:
if child.visible:
- childBounds = child.bounds(
- transformed=True, dataBounds=dataBounds)
+ childBounds = child.bounds(transformed=True, dataBounds=dataBounds)
if childBounds is not None:
bounds.append(childBounds)
@@ -312,9 +313,10 @@ class PrivateGroup(Base):
return None
else:
bounds = numpy.array(bounds, dtype=numpy.float32)
- return numpy.array((bounds[:, 0, :].min(axis=0),
- bounds[:, 1, :].max(axis=0)),
- dtype=numpy.float32)
+ return numpy.array(
+ (bounds[:, 0, :].min(axis=0), bounds[:, 1, :].max(axis=0)),
+ dtype=numpy.float32,
+ )
def prepareGL2(self, ctx):
pass
diff --git a/src/silx/gui/plot3d/scene/cutplane.py b/src/silx/gui/plot3d/scene/cutplane.py
index bfd578f..f3b7494 100644
--- a/src/silx/gui/plot3d/scene/cutplane.py
+++ b/src/silx/gui/plot3d/scene/cutplane.py
@@ -42,7 +42,8 @@ from . import transform, utils
class ColormapMesh3D(Geometry):
"""A 3D mesh with color from a 3D texture."""
- _shaders = ("""
+ _shaders = (
+ """
attribute vec3 position;
attribute vec3 normal;
@@ -67,7 +68,8 @@ class ColormapMesh3D(Geometry):
gl_Position = matrix * vec4(position, 1.0);
}
""",
- string.Template("""
+ string.Template(
+ """
varying vec4 vCameraPosition;
varying vec3 vPosition;
varying vec3 vNormal;
@@ -91,32 +93,41 @@ class ColormapMesh3D(Geometry):
$scenePostCall(vCameraPosition);
}
- """))
-
- def __init__(self, position, normal, data, copy=True,
- mode='triangles', indices=None, colormap=None):
+ """
+ ),
+ )
+
+ def __init__(
+ self,
+ position,
+ normal,
+ data,
+ copy=True,
+ mode="triangles",
+ indices=None,
+ colormap=None,
+ ):
assert mode in self._TRIANGLE_MODES
- data = numpy.array(data, copy=copy, order='C')
+ data = numpy.array(data, copy=copy, order="C")
assert data.ndim == 3
self._data = data
self._texture = None
self._update_texture = True
self._update_texture_filter = False
- self._alpha = 1.
+ self._alpha = 1.0
self._colormap = colormap or Colormap() # Default colormap
self._colormap.addListener(self._cmapChanged)
- self._interpolation = 'linear'
- super(ColormapMesh3D, self).__init__(mode,
- indices,
- position=position,
- normal=normal)
+ self._interpolation = "linear"
+ super(ColormapMesh3D, self).__init__(
+ mode, indices, position=position, normal=normal
+ )
self.isBackfaceVisible = True
- self.textureOffset = 0., 0., 0.
+ self.textureOffset = 0.0, 0.0, 0.0
"""Offset to add to texture coordinates"""
def setData(self, data, copy=True):
- data = numpy.array(data, copy=copy, order='C')
+ data = numpy.array(data, copy=copy, order="C")
assert data.ndim == 3
self._data = data
self._update_texture = True
@@ -131,7 +142,7 @@ class ColormapMesh3D(Geometry):
@interpolation.setter
def interpolation(self, interpolation):
- assert interpolation in ('linear', 'nearest')
+ assert interpolation in ("linear", "nearest")
self._interpolation = interpolation
self._update_texture_filter = True
self.notify()
@@ -159,21 +170,24 @@ class ColormapMesh3D(Geometry):
if self._texture is not None:
self._texture.discard()
- if self.interpolation == 'nearest':
+ if self.interpolation == "nearest":
filter_ = gl.GL_NEAREST
else:
filter_ = gl.GL_LINEAR
self._update_texture = False
self._update_texture_filter = False
self._texture = _glutils.Texture(
- gl.GL_R32F, self._data, gl.GL_RED,
+ gl.GL_R32F,
+ self._data,
+ gl.GL_RED,
minFilter=filter_,
magFilter=filter_,
- wrap=gl.GL_CLAMP_TO_EDGE)
+ wrap=gl.GL_CLAMP_TO_EDGE,
+ )
if self._update_texture_filter:
self._update_texture_filter = False
- if self.interpolation == 'nearest':
+ if self.interpolation == "nearest":
filter_ = gl.GL_NEAREST
else:
filter_ = gl.GL_LINEAR
@@ -190,8 +204,8 @@ class ColormapMesh3D(Geometry):
lightingFunction=ctx.viewport.light.fragmentDef,
lightingCall=ctx.viewport.light.fragmentCall,
colormapDecl=self.colormap.decl,
- colormapCall=self.colormap.call
- )
+ colormapCall=self.colormap.call,
+ )
program = ctx.glCtx.prog(self._shaders[0], fragment)
program.use()
@@ -202,18 +216,16 @@ class ColormapMesh3D(Geometry):
gl.glCullFace(gl.GL_BACK)
gl.glEnable(gl.GL_CULL_FACE)
- program.setUniformMatrix('matrix', ctx.objectToNDC.matrix)
- program.setUniformMatrix('transformMat',
- ctx.objectToCamera.matrix,
- safe=True)
- gl.glUniform1f(program.uniforms['alpha'], self._alpha)
+ program.setUniformMatrix("matrix", ctx.objectToNDC.matrix)
+ program.setUniformMatrix("transformMat", ctx.objectToCamera.matrix, safe=True)
+ gl.glUniform1f(program.uniforms["alpha"], self._alpha)
shape = self._data.shape
- scales = 1./shape[2], 1./shape[1], 1./shape[0]
- gl.glUniform3f(program.uniforms['dataScale'], *scales)
- gl.glUniform3f(program.uniforms['texCoordsOffset'], *self.textureOffset)
+ scales = 1.0 / shape[2], 1.0 / shape[1], 1.0 / shape[0]
+ gl.glUniform3f(program.uniforms["dataScale"], *scales)
+ gl.glUniform3f(program.uniforms["texCoordsOffset"], *self.textureOffset)
- gl.glUniform1i(program.uniforms['data'], self._texture.texUnit)
+ gl.glUniform1i(program.uniforms["data"], self._texture.texUnit)
ctx.setupProgram(program)
@@ -227,11 +239,11 @@ class ColormapMesh3D(Geometry):
class CutPlane(PlaneInGroup):
"""A cutting plane in a 3D texture"""
- def __init__(self, point=(0., 0., 0.), normal=(0., 0., 1.)):
+ def __init__(self, point=(0.0, 0.0, 0.0), normal=(0.0, 0.0, 1.0)):
self._data = None
self._mesh = None
- self._alpha = 1.
- self._interpolation = 'linear'
+ self._alpha = 1.0
+ self._interpolation = "linear"
self._colormap = Colormap()
super(CutPlane, self).__init__(point, normal)
@@ -243,7 +255,7 @@ class CutPlane(PlaneInGroup):
self._mesh = None
else:
- data = numpy.array(data, copy=copy, order='C')
+ data = numpy.array(data, copy=copy, order="C")
assert data.ndim == 3
self._data = data
if self._mesh is not None:
@@ -273,7 +285,7 @@ class CutPlane(PlaneInGroup):
@interpolation.setter
def interpolation(self, interpolation):
- assert interpolation in ('nearest', 'linear')
+ assert interpolation in ("nearest", "linear")
if interpolation != self.interpolation:
self._interpolation = interpolation
if self._mesh is not None:
@@ -282,45 +294,47 @@ class CutPlane(PlaneInGroup):
def prepareGL2(self, ctx):
if self.isValid:
-
contourVertices = self.contourVertices
if self._mesh is None and self._data is not None:
- self._mesh = ColormapMesh3D(contourVertices,
- normal=self.plane.normal,
- data=self._data,
- copy=False,
- mode='fan',
- colormap=self.colormap)
+ self._mesh = ColormapMesh3D(
+ contourVertices,
+ normal=self.plane.normal,
+ data=self._data,
+ copy=False,
+ mode="fan",
+ colormap=self.colormap,
+ )
self._mesh.alpha = self._alpha
self._mesh.interpolation = self.interpolation
self._children.insert(0, self._mesh)
if self._mesh is not None:
- if (contourVertices is None or
- len(contourVertices) == 0):
+ if contourVertices is None or len(contourVertices) == 0:
self._mesh.visible = False
else:
self._mesh.visible = True
- self._mesh.setAttribute('normal', self.plane.normal)
- self._mesh.setAttribute('position', contourVertices)
+ self._mesh.setAttribute("normal", self.plane.normal)
+ self._mesh.setAttribute("position", contourVertices)
needTextureOffset = False
- if self.interpolation == 'nearest':
+ if self.interpolation == "nearest":
# If cut plane is co-linear with array bin edges add texture offset
planePt = self.plane.point
- for index, normal in enumerate(((1., 0., 0.),
- (0., 1., 0.),
- (0., 0., 1.))):
- if (numpy.all(numpy.equal(self.plane.normal, normal)) and
- int(planePt[index]) == planePt[index]):
+ for index, normal in enumerate(
+ ((1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (0.0, 0.0, 1.0))
+ ):
+ if (
+ numpy.all(numpy.equal(self.plane.normal, normal))
+ and int(planePt[index]) == planePt[index]
+ ):
needTextureOffset = True
break
if needTextureOffset:
self._mesh.textureOffset = self.plane.normal * 1e-6
else:
- self._mesh.textureOffset = 0., 0., 0.
+ self._mesh.textureOffset = 0.0, 0.0, 0.0
super(CutPlane, self).prepareGL2(ctx)
@@ -333,8 +347,8 @@ class CutPlane(PlaneInGroup):
vertices = self.contourVertices
if vertices is not None:
return numpy.array(
- (vertices.min(axis=0), vertices.max(axis=0)),
- dtype=numpy.float32)
+ (vertices.min(axis=0), vertices.max(axis=0)), dtype=numpy.float32
+ )
else:
return None # Plane in not slicing the data volume
else:
@@ -342,9 +356,9 @@ class CutPlane(PlaneInGroup):
return None
else:
depth, height, width = self._data.shape
- return numpy.array(((0., 0., 0.),
- (width, height, depth)),
- dtype=numpy.float32)
+ return numpy.array(
+ ((0.0, 0.0, 0.0), (width, height, depth)), dtype=numpy.float32
+ )
@property
def contourVertices(self):
@@ -364,7 +378,8 @@ class CutPlane(PlaneInGroup):
boxVertices = bounds[0] + boxVertices * (bounds[1] - bounds[0])
lineIndices = Box.getLineIndices(copy=False)
vertices = utils.boxPlaneIntersect(
- boxVertices, lineIndices, self.plane.normal, self.plane.point)
+ boxVertices, lineIndices, self.plane.normal, self.plane.point
+ )
self._cache = bounds, vertices if len(vertices) != 0 else None
@@ -382,6 +397,6 @@ class CutPlane(PlaneInGroup):
# If it is a TransformList, do not create one to enable sharing.
self._transforms = iterable
else:
- assert hasattr(iterable, '__iter__')
+ assert hasattr(iterable, "__iter__")
self._transforms = transform.TransformList(iterable)
self._transforms.addListener(self._transformChanged)
diff --git a/src/silx/gui/plot3d/scene/event.py b/src/silx/gui/plot3d/scene/event.py
index 637eddf..4c6dd47 100644
--- a/src/silx/gui/plot3d/scene/event.py
+++ b/src/silx/gui/plot3d/scene/event.py
@@ -37,6 +37,7 @@ _logger = logging.getLogger(__name__)
# Notifier ####################################################################
+
class Notifier(object):
"""Base class for object with notification mechanism."""
@@ -53,7 +54,7 @@ class Notifier(object):
if listener not in self._listeners:
self._listeners.append(listener)
else:
- _logger.warning('Ignoring addition of an already registered listener')
+ _logger.warning("Ignoring addition of an already registered listener")
def removeListener(self, listener):
"""Remove a previously registered listener.
@@ -63,7 +64,7 @@ class Notifier(object):
try:
self._listeners.remove(listener)
except ValueError:
- _logger.warning('Trying to remove a listener that is not registered')
+ _logger.warning("Trying to remove a listener that is not registered")
def notify(self, *args, **kwargs):
"""Notify all registered listeners with the given parameters.
@@ -89,19 +90,24 @@ def notifyProperty(attrName, copy=False, converter=None, doc=None):
:return: A property with getter and setter
"""
if copy:
+
def getter(self):
return getattr(self, attrName).copy()
+
else:
+
def getter(self):
return getattr(self, attrName)
if converter is None:
+
def setter(self, value):
if getattr(self, attrName) != value:
setattr(self, attrName, value)
self.notify()
else:
+
def setter(self, value):
value = converter(value)
if getattr(self, attrName) != value:
@@ -117,7 +123,7 @@ class HookList(list):
def __init__(self, iterable):
super(HookList, self).__init__(iterable)
- self._listWasChangedHook('__init__', iterable)
+ self._listWasChangedHook("__init__", iterable)
def _listWillChangeHook(self, methodName, *args, **kwargs):
"""To override. Called before modifying the list.
@@ -140,57 +146,56 @@ class HookList(list):
def _wrapper(self, methodName, *args, **kwargs):
"""Generic wrapper of list methods calling the hooks."""
self._listWillChangeHook(methodName, *args, **kwargs)
- result = getattr(super(HookList, self),
- methodName)(*args, **kwargs)
+ result = getattr(super(HookList, self), methodName)(*args, **kwargs)
self._listWasChangedHook(methodName, *args, **kwargs)
return result
# Add methods
def __iadd__(self, *args, **kwargs):
- return self._wrapper('__iadd__', *args, **kwargs)
+ return self._wrapper("__iadd__", *args, **kwargs)
def __imul__(self, *args, **kwargs):
- return self._wrapper('__imul__', *args, **kwargs)
+ return self._wrapper("__imul__", *args, **kwargs)
def append(self, *args, **kwargs):
- return self._wrapper('append', *args, **kwargs)
+ return self._wrapper("append", *args, **kwargs)
def extend(self, *args, **kwargs):
- return self._wrapper('extend', *args, **kwargs)
+ return self._wrapper("extend", *args, **kwargs)
def insert(self, *args, **kwargs):
- return self._wrapper('insert', *args, **kwargs)
+ return self._wrapper("insert", *args, **kwargs)
# Remove methods
def __delitem__(self, *args, **kwargs):
- return self._wrapper('__delitem__', *args, **kwargs)
+ return self._wrapper("__delitem__", *args, **kwargs)
def __delslice__(self, *args, **kwargs):
- return self._wrapper('__delslice__', *args, **kwargs)
+ return self._wrapper("__delslice__", *args, **kwargs)
def remove(self, *args, **kwargs):
- return self._wrapper('remove', *args, **kwargs)
+ return self._wrapper("remove", *args, **kwargs)
def pop(self, *args, **kwargs):
- return self._wrapper('pop', *args, **kwargs)
+ return self._wrapper("pop", *args, **kwargs)
# Set methods
def __setitem__(self, *args, **kwargs):
- return self._wrapper('__setitem__', *args, **kwargs)
+ return self._wrapper("__setitem__", *args, **kwargs)
def __setslice__(self, *args, **kwargs):
- return self._wrapper('__setslice__', *args, **kwargs)
+ return self._wrapper("__setslice__", *args, **kwargs)
# In place methods
def sort(self, *args, **kwargs):
- return self._wrapper('sort', *args, **kwargs)
+ return self._wrapper("sort", *args, **kwargs)
def reverse(self, *args, **kwargs):
- return self._wrapper('reverse', *args, **kwargs)
+ return self._wrapper("reverse", *args, **kwargs)
class NotifierList(HookList, Notifier):
diff --git a/src/silx/gui/plot3d/scene/function.py b/src/silx/gui/plot3d/scene/function.py
index 3d0a62f..cde7cad 100644
--- a/src/silx/gui/plot3d/scene/function.py
+++ b/src/silx/gui/plot3d/scene/function.py
@@ -44,8 +44,7 @@ _logger = logging.getLogger(__name__)
class ProgramFunction(object):
- """Class providing a function to add to a GLProgram shaders.
- """
+ """Class providing a function to add to a GLProgram shaders."""
def setupProgram(self, context, program):
"""Sets-up uniforms of a program using this shader function.
@@ -63,6 +62,7 @@ class Fog(event.Notifier, ProgramFunction):
The background of the viewport is used as fog color,
otherwise it defaults to white.
"""
+
# TODO: add more controls (set fog range), add more fog modes
_fragDecl = """
@@ -120,26 +120,29 @@ class Fog(event.Notifier, ProgramFunction):
"""
# Provide scene z extent in camera coords
bounds = viewport.camera.extrinsic.transformBounds(
- viewport.scene.bounds(transformed=True, dataBounds=True))
+ viewport.scene.bounds(transformed=True, dataBounds=True)
+ )
return bounds[:, 2]
def setupProgram(self, context, program):
if not self.isOn:
return
- far, near = context.cache(key='zExtentCamera',
- factory=self._zExtentCamera,
- viewport=context.viewport)
+ far, near = context.cache(
+ key="zExtentCamera", factory=self._zExtentCamera, viewport=context.viewport
+ )
extent = far - near
- gl.glUniform2f(program.uniforms['fogExtentInfo'],
- 0.9/extent if extent != 0. else 0.,
- near)
+ gl.glUniform2f(
+ program.uniforms["fogExtentInfo"],
+ 0.9 / extent if extent != 0.0 else 0.0,
+ near,
+ )
# Use background color as fog color
bgColor = context.viewport.background
if bgColor is None:
- bgColor = 1., 1., 1.
- gl.glUniform3f(program.uniforms['fogColor'], *bgColor[:3])
+ bgColor = 1.0, 1.0, 1.0
+ gl.glUniform3f(program.uniforms["fogColor"], *bgColor[:3])
class ClippingPlane(ProgramFunction):
@@ -183,7 +186,7 @@ class ClippingPlane(ProgramFunction):
void clipping(vec4 position) {}
"""
- def __init__(self, point=(0., 0., 0.), normal=(0., 0., 0.)):
+ def __init__(self, point=(0.0, 0.0, 0.0), normal=(0.0, 0.0, 0.0)):
self._plane = utils.Plane(point, normal)
@property
@@ -209,7 +212,7 @@ class ClippingPlane(ProgramFunction):
It MUST be in use and using this function.
"""
if self.plane.isPlane:
- gl.glUniform4f(program.uniforms['planeEq'], *self.plane.parameters)
+ gl.glUniform4f(program.uniforms["planeEq"], *self.plane.parameters)
class DirectionalLight(event.Notifier, ProgramFunction):
@@ -279,9 +282,14 @@ class DirectionalLight(event.Notifier, ProgramFunction):
}
"""
- def __init__(self, direction=None,
- ambient=(1., 1., 1.), diffuse=(0., 0., 0.),
- specular=(1., 1., 1.), shininess=0):
+ def __init__(
+ self,
+ direction=None,
+ ambient=(1.0, 1.0, 1.0),
+ diffuse=(0.0, 0.0, 0.0),
+ specular=(1.0, 1.0, 1.0),
+ shininess=0,
+ ):
super(DirectionalLight, self).__init__()
self._direction = None
self.direction = direction # Set _direction
@@ -291,10 +299,10 @@ class DirectionalLight(event.Notifier, ProgramFunction):
self._specular = specular
self._shininess = shininess
- ambient = event.notifyProperty('_ambient')
- diffuse = event.notifyProperty('_diffuse')
- specular = event.notifyProperty('_specular')
- shininess = event.notifyProperty('_shininess')
+ ambient = event.notifyProperty("_ambient")
+ diffuse = event.notifyProperty("_diffuse")
+ specular = event.notifyProperty("_specular")
+ shininess = event.notifyProperty("_shininess")
@property
def isOn(self):
@@ -359,28 +367,29 @@ class DirectionalLight(event.Notifier, ProgramFunction):
if self.isOn and self._direction is not None:
# Transform light direction from camera space to object coords
lightdir = context.objectToCamera.transformDir(
- self._direction, direct=False)
+ self._direction, direct=False
+ )
lightdir /= numpy.linalg.norm(lightdir)
- gl.glUniform3f(program.uniforms['dLight.lightDir'], *lightdir)
+ gl.glUniform3f(program.uniforms["dLight.lightDir"], *lightdir)
# Convert view position to object coords
viewpos = context.objectToCamera.transformPoint(
- numpy.array((0., 0., 0., 1.), dtype=numpy.float32),
+ numpy.array((0.0, 0.0, 0.0, 1.0), dtype=numpy.float32),
direct=False,
- perspectiveDivide=True)[:3]
- gl.glUniform3f(program.uniforms['dLight.viewPos'], *viewpos)
+ perspectiveDivide=True,
+ )[:3]
+ gl.glUniform3f(program.uniforms["dLight.viewPos"], *viewpos)
- gl.glUniform3f(program.uniforms['dLight.ambient'], *self.ambient)
- gl.glUniform3f(program.uniforms['dLight.diffuse'], *self.diffuse)
- gl.glUniform3f(program.uniforms['dLight.specular'], *self.specular)
- gl.glUniform1f(program.uniforms['dLight.shininess'],
- self.shininess)
+ gl.glUniform3f(program.uniforms["dLight.ambient"], *self.ambient)
+ gl.glUniform3f(program.uniforms["dLight.diffuse"], *self.diffuse)
+ gl.glUniform3f(program.uniforms["dLight.specular"], *self.specular)
+ gl.glUniform1f(program.uniforms["dLight.shininess"], self.shininess)
class Colormap(event.Notifier, ProgramFunction):
-
- _declTemplate = string.Template("""
+ _declTemplate = string.Template(
+ """
uniform sampler2D cmap_texture;
uniform int cmap_normalization;
uniform float cmap_parameter;
@@ -429,7 +438,8 @@ class Colormap(event.Notifier, ProgramFunction):
}
return color;
}
- """)
+ """
+ )
_discardCode = """
if (value == 0.) {
@@ -439,13 +449,13 @@ class Colormap(event.Notifier, ProgramFunction):
call = "colormap"
- NORMS = 'linear', 'log', 'sqrt', 'gamma', 'arcsinh'
+ NORMS = "linear", "log", "sqrt", "gamma", "arcsinh"
"""Tuple of supported normalizations."""
_COLORMAP_TEXTURE_UNIT = 1
"""Texture unit to use for storing the colormap"""
- def __init__(self, colormap=None, norm='linear', gamma=0., range_=(1., 10.)):
+ def __init__(self, colormap=None, norm="linear", gamma=0.0, range_=(1.0, 10.0)):
"""Shader function to apply a colormap to a value.
:param colormap: RGB(A) color look-up table (default: gray)
@@ -459,11 +469,11 @@ class Colormap(event.Notifier, ProgramFunction):
# Init privates to default
self._colormap = None
- self._norm = 'linear'
- self._gamma = -1.
- self._range = 1., 10.
+ self._norm = "linear"
+ self._gamma = -1.0
+ self._range = 1.0, 10.0
self._displayValuesBelowMin = True
- self._nancolor = numpy.array((1., 1., 1., 0.), dtype=numpy.float32)
+ self._nancolor = numpy.array((1.0, 1.0, 1.0, 0.0), dtype=numpy.float32)
self._texture = None
self._textureToDiscard = None
@@ -471,8 +481,7 @@ class Colormap(event.Notifier, ProgramFunction):
if colormap is None:
# default colormap
colormap = numpy.empty((256, 3), dtype=numpy.uint8)
- colormap[:] = numpy.arange(256,
- dtype=numpy.uint8)[:, numpy.newaxis]
+ colormap[:] = numpy.arange(256, dtype=numpy.uint8)[:, numpy.newaxis]
# Set to values through properties to perform asserts and updates
self.colormap = colormap
@@ -484,7 +493,8 @@ class Colormap(event.Notifier, ProgramFunction):
def decl(self):
"""Source code of the function declaration"""
return self._declTemplate.substitute(
- discard="" if self.displayValuesBelowMin else self._discardCode)
+ discard="" if self.displayValuesBelowMin else self._discardCode
+ )
@property
def colormap(self):
@@ -503,17 +513,21 @@ class Colormap(event.Notifier, ProgramFunction):
data = numpy.empty(
(16, self._colormap.shape[0], self._colormap.shape[1]),
- dtype=self._colormap.dtype)
+ dtype=self._colormap.dtype,
+ )
data[:] = self._colormap
format_ = gl.GL_RGBA if data.shape[-1] == 4 else gl.GL_RGB
self._texture = _glutils.Texture(
- format_, data, format_,
+ format_,
+ data,
+ format_,
texUnit=self._COLORMAP_TEXTURE_UNIT,
minFilter=gl.GL_NEAREST,
magFilter=gl.GL_NEAREST,
- wrap=gl.GL_CLAMP_TO_EDGE)
+ wrap=gl.GL_CLAMP_TO_EDGE,
+ )
self.notify()
@@ -524,7 +538,7 @@ class Colormap(event.Notifier, ProgramFunction):
@nancolor.setter
def nancolor(self, color):
- color = numpy.clip(numpy.array(color, dtype=numpy.float32), 0., 1.)
+ color = numpy.clip(numpy.array(color, dtype=numpy.float32), 0.0, 1.0)
assert color.ndim == 1
assert len(color) == 4
if not numpy.array_equal(self._nancolor, color):
@@ -545,7 +559,7 @@ class Colormap(event.Notifier, ProgramFunction):
if norm != self._norm:
assert norm in self.NORMS
self._norm = norm
- if norm in ('log', 'sqrt'):
+ if norm in ("log", "sqrt"):
self.range_ = self.range_ # To test for positive range_
self.notify()
@@ -557,7 +571,7 @@ class Colormap(event.Notifier, ProgramFunction):
@gamma.setter
def gamma(self, gamma):
if gamma != self._gamma:
- assert gamma >= 0.
+ assert gamma >= 0.0
self._gamma = gamma
self.notify()
@@ -577,15 +591,13 @@ class Colormap(event.Notifier, ProgramFunction):
assert len(range_) == 2
range_ = float(range_[0]), float(range_[1])
- if self.norm == 'log' and (range_[0] <= 0. or range_[1] <= 0.):
- _logger.warning(
- "Log normalization and negative range: updating range.")
+ if self.norm == "log" and (range_[0] <= 0.0 or range_[1] <= 0.0):
+ _logger.warning("Log normalization and negative range: updating range.")
minPos = numpy.finfo(numpy.float32).tiny
range_ = max(range_[0], minPos), max(range_[1], minPos)
- elif self.norm == 'sqrt' and (range_[0] < 0. or range_[1] < 0.):
- _logger.warning(
- "Sqrt normalization and negative range: updating range.")
- range_ = max(range_[0], 0.), max(range_[1], 0.)
+ elif self.norm == "sqrt" and (range_[0] < 0.0 or range_[1] < 0.0):
+ _logger.warning("Sqrt normalization and negative range: updating range.")
+ range_ = max(range_[0], 0.0), max(range_[1], 0.0)
if range_ != self._range:
self._range = range_
@@ -593,8 +605,7 @@ class Colormap(event.Notifier, ProgramFunction):
@property
def displayValuesBelowMin(self):
- """True to display values below colormap min, False to discard them.
- """
+ """True to display values below colormap min, False to discard them."""
return self._displayValuesBelowMin
@displayValuesBelowMin.setter
@@ -615,33 +626,34 @@ class Colormap(event.Notifier, ProgramFunction):
self._texture.bind()
- gl.glUniform1i(program.uniforms['cmap_texture'],
- self._texture.texUnit)
+ gl.glUniform1i(program.uniforms["cmap_texture"], self._texture.texUnit)
min_, max_ = self.range_
- param = 0.
- if self._norm == 'log':
+ param = 0.0
+ if self._norm == "log":
min_, max_ = numpy.log10(min_), numpy.log10(max_)
normID = 1
- elif self._norm == 'sqrt':
+ elif self._norm == "sqrt":
min_, max_ = numpy.sqrt(min_), numpy.sqrt(max_)
normID = 2
- elif self._norm == 'gamma':
+ elif self._norm == "gamma":
# Keep min_, max_ as is
param = self._gamma
normID = 3
- elif self._norm == 'arcsinh':
+ elif self._norm == "arcsinh":
min_, max_ = numpy.arcsinh(min_), numpy.arcsinh(max_)
normID = 4
else: # Linear
normID = 0
- gl.glUniform1i(program.uniforms['cmap_normalization'], normID)
- gl.glUniform1f(program.uniforms['cmap_parameter'], param)
- gl.glUniform1f(program.uniforms['cmap_min'], min_)
- gl.glUniform1f(program.uniforms['cmap_oneOverRange'],
- (1. / (max_ - min_)) if max_ != min_ else 0.)
- gl.glUniform4f(program.uniforms['nancolor'], *self._nancolor)
+ gl.glUniform1i(program.uniforms["cmap_normalization"], normID)
+ gl.glUniform1f(program.uniforms["cmap_parameter"], param)
+ gl.glUniform1f(program.uniforms["cmap_min"], min_)
+ gl.glUniform1f(
+ program.uniforms["cmap_oneOverRange"],
+ (1.0 / (max_ - min_)) if max_ != min_ else 0.0,
+ )
+ gl.glUniform4f(program.uniforms["nancolor"], *self._nancolor)
def prepareGL2(self, context):
if self._textureToDiscard is not None:
diff --git a/src/silx/gui/plot3d/scene/interaction.py b/src/silx/gui/plot3d/scene/interaction.py
index 91fab23..debf670 100644
--- a/src/silx/gui/plot3d/scene/interaction.py
+++ b/src/silx/gui/plot3d/scene/interaction.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2015-2019 European Synchrotron Radiation Facility
+# Copyright (c) 2015-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -31,8 +31,12 @@ import logging
import numpy
from silx.gui import qt
-from silx.gui.plot.Interaction import \
- StateMachine, State, LEFT_BTN, RIGHT_BTN # , MIDDLE_BTN
+from silx.gui.plot.Interaction import (
+ StateMachine,
+ State,
+ LEFT_BTN,
+ RIGHT_BTN,
+) # , MIDDLE_BTN
from . import transform
@@ -41,35 +45,32 @@ _logger = logging.getLogger(__name__)
class ClickOrDrag(StateMachine):
- """Click or drag interaction for a given button.
+ """Click or drag interaction for a given button."""
- """
- #TODO: merge this class with silx.gui.plot.Interaction.ClickOrDrag
+ # TODO: merge this class with silx.gui.plot.Interaction.ClickOrDrag
- DRAG_THRESHOLD_SQUARE_DIST = 5 ** 2
+ DRAG_THRESHOLD_SQUARE_DIST = 5**2
class Idle(State):
def onPress(self, x, y, btn):
if btn == self.machine.button:
- self.goto('clickOrDrag', x, y)
+ self.goto("clickOrDrag", x, y)
return True
class ClickOrDrag(State):
def enterState(self, x, y):
self.initPos = x, y
- enter = enterState # silx v.0.3 support, remove when 0.4 out
-
def onMove(self, x, y):
dx = (x - self.initPos[0]) ** 2
dy = (y - self.initPos[1]) ** 2
- if (dx ** 2 + dy ** 2) >= self.machine.DRAG_THRESHOLD_SQUARE_DIST:
- self.goto('drag', self.initPos, (x, y))
+ if (dx**2 + dy**2) >= self.machine.DRAG_THRESHOLD_SQUARE_DIST:
+ self.goto("drag", self.initPos, (x, y))
def onRelease(self, x, y, btn):
if btn == self.machine.button:
self.machine.click(x, y)
- self.goto('idle')
+ self.goto("idle")
class Drag(State):
def enterState(self, initPos, curPos):
@@ -77,24 +78,22 @@ class ClickOrDrag(StateMachine):
self.machine.beginDrag(*initPos)
self.machine.drag(*curPos)
- enter = enterState # silx v.0.3 support, remove when 0.4 out
-
def onMove(self, x, y):
self.machine.drag(x, y)
def onRelease(self, x, y, btn):
if btn == self.machine.button:
self.machine.endDrag(self.initPos, (x, y))
- self.goto('idle')
+ self.goto("idle")
def __init__(self, button=LEFT_BTN):
self.button = button
states = {
- 'idle': ClickOrDrag.Idle,
- 'clickOrDrag': ClickOrDrag.ClickOrDrag,
- 'drag': ClickOrDrag.Drag
+ "idle": ClickOrDrag.Idle,
+ "clickOrDrag": ClickOrDrag.ClickOrDrag,
+ "drag": ClickOrDrag.Drag,
}
- super(ClickOrDrag, self).__init__(states, 'idle')
+ super(ClickOrDrag, self).__init__(states, "idle")
def click(self, x, y):
"""Called upon a left or right button click.
@@ -126,8 +125,9 @@ class ClickOrDrag(StateMachine):
class CameraSelectRotate(ClickOrDrag):
"""Camera rotation using an arcball-like interaction."""
- def __init__(self, viewport, orbitAroundCenter=True, button=RIGHT_BTN,
- selectCB=None):
+ def __init__(
+ self, viewport, orbitAroundCenter=True, button=RIGHT_BTN, selectCB=None
+ ):
self._viewport = viewport
self._orbitAroundCenter = orbitAroundCenter
self._selectCB = selectCB
@@ -144,7 +144,7 @@ class CameraSelectRotate(ClickOrDrag):
position = self._viewport._getXZYGL(x, y)
# This assume no object lie on the far plane
# Alternative, change the depth range so that far is < 1
- if ndcZ != 1. and position is not None:
+ if ndcZ != 1.0 and position is not None:
self._selectCB((x, y, ndcZ), position)
def beginDrag(self, x, y):
@@ -152,7 +152,7 @@ class CameraSelectRotate(ClickOrDrag):
if not self._orbitAroundCenter:
# Try to use picked object position as center of rotation
ndcZ = self._viewport._pickNdcZGL(x, y)
- if ndcZ != 1.:
+ if ndcZ != 1.0:
# Hit an object, use picked point as center
centerPos = self._viewport._getXZYGL(x, y) # Can return None
@@ -177,12 +177,11 @@ class CameraSelectRotate(ClickOrDrag):
position = self._startExtrinsic.position
else:
minsize = min(self._viewport.size)
- distance = numpy.sqrt(dx ** 2 + dy ** 2)
+ distance = numpy.sqrt(dx**2 + dy**2)
angle = distance / minsize * numpy.pi
# Take care of y inversion
- direction = dx * self._startExtrinsic.side - \
- dy * self._startExtrinsic.up
+ direction = dx * self._startExtrinsic.side - dy * self._startExtrinsic.up
direction /= numpy.linalg.norm(direction)
axis = numpy.cross(direction, self._startExtrinsic.direction)
axis /= numpy.linalg.norm(axis)
@@ -194,10 +193,9 @@ class CameraSelectRotate(ClickOrDrag):
up = rotation.transformDir(self._startExtrinsic.up)
# Rotate position around center
- trlist = transform.StaticTransformList((
- self._center,
- rotation,
- self._center.inverse()))
+ trlist = transform.StaticTransformList(
+ (self._center, rotation, self._center.inverse())
+ )
position = trlist.transformPoint(self._startExtrinsic.position)
camerapos = self._viewport.camera.extrinsic
@@ -223,7 +221,7 @@ class CameraSelectPan(ClickOrDrag):
position = self._viewport._getXZYGL(x, y)
# This assume no object lie on the far plane
# Alternative, change the depth range so that far is < 1
- if ndcZ != 1. and position is not None:
+ if ndcZ != 1.0 and position is not None:
self._selectCB((x, y, ndcZ), position)
def beginDrag(self, x, y):
@@ -231,8 +229,9 @@ class CameraSelectPan(ClickOrDrag):
ndcZ = self._viewport._pickNdcZGL(x, y)
# ndcZ is the panning plane
if ndc is not None and ndcZ is not None:
- self._lastPosNdc = numpy.array((ndc[0], ndc[1], ndcZ, 1.),
- dtype=numpy.float32)
+ self._lastPosNdc = numpy.array(
+ (ndc[0], ndc[1], ndcZ, 1.0), dtype=numpy.float32
+ )
else:
self._lastPosNdc = None
@@ -240,14 +239,17 @@ class CameraSelectPan(ClickOrDrag):
if self._lastPosNdc is not None:
ndc = self._viewport.windowToNdc(x, y)
if ndc is not None:
- ndcPos = numpy.array((ndc[0], ndc[1], self._lastPosNdc[2], 1.),
- dtype=numpy.float32)
+ ndcPos = numpy.array(
+ (ndc[0], ndc[1], self._lastPosNdc[2], 1.0), dtype=numpy.float32
+ )
# Convert last and current NDC positions to scene coords
scenePos = self._viewport.camera.transformPoint(
- ndcPos, direct=False, perspectiveDivide=True)
+ ndcPos, direct=False, perspectiveDivide=True
+ )
lastScenePos = self._viewport.camera.transformPoint(
- self._lastPosNdc, direct=False, perspectiveDivide=True)
+ self._lastPosNdc, direct=False, perspectiveDivide=True
+ )
# Get translation in scene coords
translation = scenePos[:3] - lastScenePos[:3]
@@ -264,21 +266,21 @@ class CameraWheel(object):
"""StateMachine like class, just handling wheel events."""
# TODO choose scale of motion? Translation or Scale?
- def __init__(self, viewport, mode='center', scaleTransform=None):
- assert mode in ('center', 'position', 'scale')
+ def __init__(self, viewport, mode="center", scaleTransform=None):
+ assert mode in ("center", "position", "scale")
self._viewport = viewport
- if mode == 'center':
+ if mode == "center":
self._zoomTo = self._zoomToCenter
- elif mode == 'position':
+ elif mode == "position":
self._zoomTo = self._zoomToPosition
- elif mode == 'scale':
+ elif mode == "scale":
self._zoomTo = self._zoomByScale
self._scale = scaleTransform
else:
- raise ValueError('Unsupported mode: %s' % mode)
+ raise ValueError("Unsupported mode: %s" % mode)
def handleEvent(self, eventName, *args, **kwargs):
- if eventName == 'wheel':
+ if eventName == "wheel":
return self._zoomTo(*args, **kwargs)
def _zoomToCenter(self, x, y, angleInDegrees):
@@ -286,7 +288,7 @@ class CameraWheel(object):
Only works with perspective camera.
"""
- direction = 'forward' if angleInDegrees > 0 else 'backward'
+ direction = "forward" if angleInDegrees > 0 else "backward"
self._viewport.camera.move(direction)
return True
@@ -297,20 +299,22 @@ class CameraWheel(object):
"""
ndc = self._viewport.windowToNdc(x, y)
if ndc is not None:
- near = numpy.array((ndc[0], ndc[1], -1., 1.), dtype=numpy.float32)
+ near = numpy.array((ndc[0], ndc[1], -1.0, 1.0), dtype=numpy.float32)
nearscene = self._viewport.camera.transformPoint(
- near, direct=False, perspectiveDivide=True)
+ near, direct=False, perspectiveDivide=True
+ )
- far = numpy.array((ndc[0], ndc[1], 1., 1.), dtype=numpy.float32)
+ far = numpy.array((ndc[0], ndc[1], 1.0, 1.0), dtype=numpy.float32)
farscene = self._viewport.camera.transformPoint(
- far, direct=False, perspectiveDivide=True)
+ far, direct=False, perspectiveDivide=True
+ )
dirscene = farscene[:3] - nearscene[:3]
dirscene /= numpy.linalg.norm(dirscene)
if angleInDegrees < 0:
- dirscene *= -1.
+ dirscene *= -1.0
# TODO which scale
self._viewport.camera.extrinsic.position += dirscene
@@ -327,43 +331,43 @@ class CameraWheel(object):
if ndc is not None:
ndcz = self._viewport._pickNdcZGL(x, y)
- position = numpy.array((ndc[0], ndc[1], ndcz),
- dtype=numpy.float32)
+ position = numpy.array((ndc[0], ndc[1], ndcz), dtype=numpy.float32)
positionscene = self._viewport.camera.transformPoint(
- position, direct=False, perspectiveDivide=True)
+ position, direct=False, perspectiveDivide=True
+ )
camtopos = extrinsic.position - positionscene
- step = 0.2 * (1. if angleInDegrees < 0 else -1.)
+ step = 0.2 * (1.0 if angleInDegrees < 0 else -1.0)
extrinsic.position += step * camtopos
elif isinstance(projection, transform.Orthographic):
# For orthographic projection, change projection borders
ndcx, ndcy = self._viewport.windowToNdc(x, y, checkInside=False)
- step = 0.2 * (1. if angleInDegrees < 0 else -1.)
+ step = 0.2 * (1.0 if angleInDegrees < 0 else -1.0)
- dx = (ndcx + 1) / 2.
+ dx = (ndcx + 1) / 2.0
stepwidth = step * (projection.right - projection.left)
left = projection.left - dx * stepwidth
- right = projection.right + (1. - dx) * stepwidth
+ right = projection.right + (1.0 - dx) * stepwidth
- dy = (ndcy + 1) / 2.
+ dy = (ndcy + 1) / 2.0
stepheight = step * (projection.top - projection.bottom)
bottom = projection.bottom - dy * stepheight
- top = projection.top + (1. - dy) * stepheight
+ top = projection.top + (1.0 - dy) * stepheight
projection.setClipping(left, right, bottom, top)
else:
- raise RuntimeError('Unsupported camera', projection)
+ raise RuntimeError("Unsupported camera", projection)
return True
def _zoomByScale(self, x, y, angleInDegrees):
"""Zoom by scaling scene (do not keep pixel under mouse invariant)."""
scalefactor = 1.1
- if angleInDegrees < 0.:
- scalefactor = 1. / scalefactor
+ if angleInDegrees < 0.0:
+ scalefactor = 1.0 / scalefactor
self._scale.scale = scalefactor * self._scale.scale
self._viewport.adjustCameraDepthExtent()
@@ -376,12 +380,13 @@ class FocusManager(StateMachine):
On press an event handler can acquire focus.
By default it looses focus when all buttons are released.
"""
+
class Idle(State):
def onPress(self, x, y, btn):
for eventHandler in self.machine.currentEventHandler:
- requestFocus = eventHandler.handleEvent('press', x, y, btn)
+ requestFocus = eventHandler.handleEvent("press", x, y, btn)
if requestFocus:
- self.goto('focus', eventHandler, btn)
+ self.goto("focus", eventHandler, btn)
break
def _processEvent(self, *args):
@@ -391,47 +396,42 @@ class FocusManager(StateMachine):
break
def onMove(self, x, y):
- self._processEvent('move', x, y)
+ self._processEvent("move", x, y)
def onRelease(self, x, y, btn):
- self._processEvent('release', x, y, btn)
+ self._processEvent("release", x, y, btn)
def onWheel(self, x, y, angle):
- self._processEvent('wheel', x, y, angle)
+ self._processEvent("wheel", x, y, angle)
class Focus(State):
def enterState(self, eventHandler, btn):
self.eventHandler = eventHandler
self.focusBtns = {btn} # Set
- enter = enterState # silx v.0.3 support, remove when 0.4 out
-
def onPress(self, x, y, btn):
self.focusBtns.add(btn)
- self.eventHandler.handleEvent('press', x, y, btn)
+ self.eventHandler.handleEvent("press", x, y, btn)
def onMove(self, x, y):
- self.eventHandler.handleEvent('move', x, y)
+ self.eventHandler.handleEvent("move", x, y)
def onRelease(self, x, y, btn):
self.focusBtns.discard(btn)
- requestfocus = self.eventHandler.handleEvent('release', x, y, btn)
+ requestfocus = self.eventHandler.handleEvent("release", x, y, btn)
if len(self.focusBtns) == 0 and not requestfocus:
- self.goto('idle')
+ self.goto("idle")
def onWheel(self, x, y, angleInDegrees):
- self.eventHandler.handleEvent('wheel', x, y, angleInDegrees)
+ self.eventHandler.handleEvent("wheel", x, y, angleInDegrees)
def __init__(self, eventHandlers=(), ctrlEventHandlers=None):
self.defaultEventHandlers = eventHandlers
self.ctrlEventHandlers = ctrlEventHandlers
self.currentEventHandler = self.defaultEventHandlers
- states = {
- 'idle': FocusManager.Idle,
- 'focus': FocusManager.Focus
- }
- super(FocusManager, self).__init__(states, 'idle')
+ states = {"idle": FocusManager.Idle, "focus": FocusManager.Focus}
+ super(FocusManager, self).__init__(states, "idle")
def onKeyPress(self, key):
if key == qt.Qt.Key_Control and self.ctrlEventHandlers is not None:
@@ -450,43 +450,65 @@ class RotateCameraControl(FocusManager):
"""Combine wheel and rotate state machine for left button
and pan when ctrl is pressed
"""
- def __init__(self, viewport,
- orbitAroundCenter=False,
- mode='center', scaleTransform=None,
- selectCB=None):
- handlers = (CameraWheel(viewport, mode, scaleTransform),
- CameraSelectRotate(
- viewport, orbitAroundCenter, LEFT_BTN, selectCB))
- ctrlHandlers = (CameraWheel(viewport, mode, scaleTransform),
- CameraSelectPan(viewport, LEFT_BTN, selectCB))
+
+ def __init__(
+ self,
+ viewport,
+ orbitAroundCenter=False,
+ mode="center",
+ scaleTransform=None,
+ selectCB=None,
+ ):
+ handlers = (
+ CameraWheel(viewport, mode, scaleTransform),
+ CameraSelectRotate(viewport, orbitAroundCenter, LEFT_BTN, selectCB),
+ )
+ ctrlHandlers = (
+ CameraWheel(viewport, mode, scaleTransform),
+ CameraSelectPan(viewport, LEFT_BTN, selectCB),
+ )
super(RotateCameraControl, self).__init__(handlers, ctrlHandlers)
class PanCameraControl(FocusManager):
"""Combine wheel, selectPan and rotate state machine for left button
and rotate when ctrl is pressed"""
- def __init__(self, viewport,
- orbitAroundCenter=False,
- mode='center', scaleTransform=None,
- selectCB=None):
- handlers = (CameraWheel(viewport, mode, scaleTransform),
- CameraSelectPan(viewport, LEFT_BTN, selectCB))
- ctrlHandlers = (CameraWheel(viewport, mode, scaleTransform),
- CameraSelectRotate(
- viewport, orbitAroundCenter, LEFT_BTN, selectCB))
+
+ def __init__(
+ self,
+ viewport,
+ orbitAroundCenter=False,
+ mode="center",
+ scaleTransform=None,
+ selectCB=None,
+ ):
+ handlers = (
+ CameraWheel(viewport, mode, scaleTransform),
+ CameraSelectPan(viewport, LEFT_BTN, selectCB),
+ )
+ ctrlHandlers = (
+ CameraWheel(viewport, mode, scaleTransform),
+ CameraSelectRotate(viewport, orbitAroundCenter, LEFT_BTN, selectCB),
+ )
super(PanCameraControl, self).__init__(handlers, ctrlHandlers)
class CameraControl(FocusManager):
"""Combine wheel, selectPan and rotate state machine."""
- def __init__(self, viewport,
- orbitAroundCenter=False,
- mode='center', scaleTransform=None,
- selectCB=None):
- handlers = (CameraWheel(viewport, mode, scaleTransform),
- CameraSelectPan(viewport, LEFT_BTN, selectCB),
- CameraSelectRotate(
- viewport, orbitAroundCenter, RIGHT_BTN, selectCB))
+
+ def __init__(
+ self,
+ viewport,
+ orbitAroundCenter=False,
+ mode="center",
+ scaleTransform=None,
+ selectCB=None,
+ ):
+ handlers = (
+ CameraWheel(viewport, mode, scaleTransform),
+ CameraSelectPan(viewport, LEFT_BTN, selectCB),
+ CameraSelectRotate(viewport, orbitAroundCenter, RIGHT_BTN, selectCB),
+ )
super(CameraControl, self).__init__(handlers)
@@ -532,14 +554,14 @@ class PlaneRotate(ClickOrDrag):
# Normalize x and y on a unit circle
spherecoords = (position - center) / float(radius)
- squarelength = numpy.sum(spherecoords ** 2)
+ squarelength = numpy.sum(spherecoords**2)
# Project on the unit sphere and compute z coordinates
if squarelength > 1.0: # Outside sphere: project
spherecoords /= numpy.sqrt(squarelength)
zsphere = 0.0
else: # In sphere: compute z
- zsphere = numpy.sqrt(1. - squarelength)
+ zsphere = numpy.sqrt(1.0 - squarelength)
spherecoords = numpy.append(spherecoords, zsphere)
return spherecoords
@@ -552,8 +574,7 @@ class PlaneRotate(ClickOrDrag):
# Store the plane normal
self._beginNormal = self._plane.plane.normal
- _logger.debug(
- 'Begin arcball, plane center %s', str(self._plane.center))
+ _logger.debug("Begin arcball, plane center %s", str(self._plane.center))
# Do the arcball on the screen
radius = min(self._viewport.size)
@@ -562,12 +583,15 @@ class PlaneRotate(ClickOrDrag):
else:
center = self._plane.objectToNDCTransform.transformPoint(
- self._plane.center, perspectiveDivide=True)
+ self._plane.center, perspectiveDivide=True
+ )
self._beginCenter = self._viewport.ndcToWindow(
- center[0], center[1], checkInside=False)
+ center[0], center[1], checkInside=False
+ )
self._startVector = self._sphereUnitVector(
- radius, self._beginCenter, (x, y))
+ radius, self._beginCenter, (x, y)
+ )
def drag(self, x, y):
if self._beginCenter is None:
@@ -575,24 +599,21 @@ class PlaneRotate(ClickOrDrag):
# Compute rotation: this is twice the rotation of the arcball
radius = min(self._viewport.size)
- currentvector = self._sphereUnitVector(
- radius, self._beginCenter, (x, y))
+ currentvector = self._sphereUnitVector(radius, self._beginCenter, (x, y))
crossprod = numpy.cross(self._startVector, currentvector)
dotprod = numpy.dot(self._startVector, currentvector)
quaternion = numpy.append(crossprod, dotprod)
# Rotation was computed with Y downward, but apply in NDC, invert Y
- quaternion[1] *= -1.
+ quaternion[1] *= -1.0
rotation = transform.Rotate()
rotation.quaternion = quaternion
# Convert to NDC, rotate, convert back to object
- normal = self._plane.objectToNDCTransform.transformNormal(
- self._beginNormal)
+ normal = self._plane.objectToNDCTransform.transformNormal(self._beginNormal)
normal = rotation.transformNormal(normal)
- normal = self._plane.objectToNDCTransform.transformNormal(
- normal, direct=False)
+ normal = self._plane.objectToNDCTransform.transformNormal(normal, direct=False)
self._plane.plane.normal = normal
def endDrag(self, x, y):
@@ -607,7 +628,7 @@ class PlanePan(ClickOrDrag):
self._viewport = viewport
self._beginPlanePoint = None
self._beginPos = None
- self._dragNdcZ = 0.
+ self._dragNdcZ = 0.0
super(PlanePan, self).__init__(button)
def click(self, x, y):
@@ -618,16 +639,17 @@ class PlanePan(ClickOrDrag):
ndcZ = self._viewport._pickNdcZGL(x, y)
# ndcZ is the panning plane
if ndc is not None and ndcZ is not None:
- ndcPos = numpy.array((ndc[0], ndc[1], ndcZ, 1.),
- dtype=numpy.float32)
+ ndcPos = numpy.array((ndc[0], ndc[1], ndcZ, 1.0), dtype=numpy.float32)
scenePos = self._viewport.camera.transformPoint(
- ndcPos, direct=False, perspectiveDivide=True)
+ ndcPos, direct=False, perspectiveDivide=True
+ )
self._beginPos = self._plane.objectToSceneTransform.transformPoint(
- scenePos, direct=False)
+ scenePos, direct=False
+ )
self._dragNdcZ = ndcZ
else:
self._beginPos = None
- self._dragNdcZ = 0.
+ self._dragNdcZ = 0.0
self._beginPlanePoint = self._plane.plane.point
@@ -635,14 +657,17 @@ class PlanePan(ClickOrDrag):
if self._beginPos is not None:
ndc = self._viewport.windowToNdc(x, y)
if ndc is not None:
- ndcPos = numpy.array((ndc[0], ndc[1], self._dragNdcZ, 1.),
- dtype=numpy.float32)
+ ndcPos = numpy.array(
+ (ndc[0], ndc[1], self._dragNdcZ, 1.0), dtype=numpy.float32
+ )
# Convert last and current NDC positions to scene coords
scenePos = self._viewport.camera.transformPoint(
- ndcPos, direct=False, perspectiveDivide=True)
+ ndcPos, direct=False, perspectiveDivide=True
+ )
curPos = self._plane.objectToSceneTransform.transformPoint(
- scenePos, direct=False)
+ scenePos, direct=False
+ )
# Get translation in scene coords
translation = curPos[:3] - self._beginPos[:3]
@@ -652,8 +677,7 @@ class PlanePan(ClickOrDrag):
# Keep plane point in bounds
bounds = self._plane.parent.bounds(dataBounds=True)
if bounds is not None:
- newPoint = numpy.clip(
- newPoint, a_min=bounds[0], a_max=bounds[1])
+ newPoint = numpy.clip(newPoint, a_min=bounds[0], a_max=bounds[1])
# Only update plane if it is in some bounds
self._plane.plane.point = newPoint
@@ -664,35 +688,45 @@ class PlanePan(ClickOrDrag):
class PlaneControl(FocusManager):
"""Combine wheel, selectPan and rotate state machine for plane control."""
- def __init__(self, viewport, plane,
- mode='center', scaleTransform=None):
- handlers = (CameraWheel(viewport, mode, scaleTransform),
- PlanePan(viewport, plane, LEFT_BTN),
- PlaneRotate(viewport, plane, RIGHT_BTN))
+
+ def __init__(self, viewport, plane, mode="center", scaleTransform=None):
+ handlers = (
+ CameraWheel(viewport, mode, scaleTransform),
+ PlanePan(viewport, plane, LEFT_BTN),
+ PlaneRotate(viewport, plane, RIGHT_BTN),
+ )
super(PlaneControl, self).__init__(handlers)
class PanPlaneRotateCameraControl(FocusManager):
"""Combine wheel, pan plane and camera rotate state machine."""
- def __init__(self, viewport, plane,
- mode='center', scaleTransform=None):
- handlers = (CameraWheel(viewport, mode, scaleTransform),
- PlanePan(viewport, plane, LEFT_BTN),
- CameraSelectRotate(viewport,
- orbitAroundCenter=False,
- button=RIGHT_BTN))
+
+ def __init__(self, viewport, plane, mode="center", scaleTransform=None):
+ handlers = (
+ CameraWheel(viewport, mode, scaleTransform),
+ PlanePan(viewport, plane, LEFT_BTN),
+ CameraSelectRotate(viewport, orbitAroundCenter=False, button=RIGHT_BTN),
+ )
super(PanPlaneRotateCameraControl, self).__init__(handlers)
class PanPlaneZoomOnWheelControl(FocusManager):
"""Combine zoom on wheel and pan plane state machines."""
- def __init__(self, viewport, plane,
- mode='center',
- orbitAroundCenter=False,
- scaleTransform=None):
- handlers = (CameraWheel(viewport, mode, scaleTransform),
- PlanePan(viewport, plane, LEFT_BTN))
- ctrlHandlers = (CameraWheel(viewport, mode, scaleTransform),
- CameraSelectRotate(
- viewport, orbitAroundCenter, LEFT_BTN))
+
+ def __init__(
+ self,
+ viewport,
+ plane,
+ mode="center",
+ orbitAroundCenter=False,
+ scaleTransform=None,
+ ):
+ handlers = (
+ CameraWheel(viewport, mode, scaleTransform),
+ PlanePan(viewport, plane, LEFT_BTN),
+ )
+ ctrlHandlers = (
+ CameraWheel(viewport, mode, scaleTransform),
+ CameraSelectRotate(viewport, orbitAroundCenter, LEFT_BTN),
+ )
super(PanPlaneZoomOnWheelControl, self).__init__(handlers, ctrlHandlers)
diff --git a/src/silx/gui/plot3d/scene/primitives.py b/src/silx/gui/plot3d/scene/primitives.py
index 6d3c4ff..93070c3 100644
--- a/src/silx/gui/plot3d/scene/primitives.py
+++ b/src/silx/gui/plot3d/scene/primitives.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2015-2021 European Synchrotron Radiation Facility
+# Copyright (c) 2015-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -26,10 +26,7 @@ __authors__ = ["T. Vincent"]
__license__ = "MIT"
__date__ = "24/04/2018"
-try:
- from collections import abc
-except ImportError: # Python2 support
- import collections as abc
+from collections import abc
import ctypes
from functools import reduce
import logging
@@ -53,6 +50,7 @@ _logger = logging.getLogger(__name__)
# Geometry ####################################################################
+
class Geometry(core.Elem):
"""Set of vertices with normals and colors.
@@ -65,39 +63,36 @@ class Geometry(core.Elem):
"""
_ATTR_INFO = {
- 'position': {'dims': (1, 2), 'lastDim': (2, 3, 4)},
- 'normal': {'dims': (1, 2), 'lastDim': (3,)},
- 'color': {'dims': (1, 2), 'lastDim': (3, 4)},
+ "position": {"dims": (1, 2), "lastDim": (2, 3, 4)},
+ "normal": {"dims": (1, 2), "lastDim": (3,)},
+ "color": {"dims": (1, 2), "lastDim": (3, 4)},
}
_MODE_CHECKS = { # Min, Modulo
- 'lines': (2, 2), 'line_strip': (2, 0), 'loop': (2, 0),
- 'points': (1, 0),
- 'triangles': (3, 3), 'triangle_strip': (3, 0), 'fan': (3, 0)
+ "lines": (2, 2),
+ "line_strip": (2, 0),
+ "loop": (2, 0),
+ "points": (1, 0),
+ "triangles": (3, 3),
+ "triangle_strip": (3, 0),
+ "fan": (3, 0),
}
_MODES = {
- 'lines': gl.GL_LINES,
- 'line_strip': gl.GL_LINE_STRIP,
- 'loop': gl.GL_LINE_LOOP,
-
- 'points': gl.GL_POINTS,
-
- 'triangles': gl.GL_TRIANGLES,
- 'triangle_strip': gl.GL_TRIANGLE_STRIP,
- 'fan': gl.GL_TRIANGLE_FAN
+ "lines": gl.GL_LINES,
+ "line_strip": gl.GL_LINE_STRIP,
+ "loop": gl.GL_LINE_LOOP,
+ "points": gl.GL_POINTS,
+ "triangles": gl.GL_TRIANGLES,
+ "triangle_strip": gl.GL_TRIANGLE_STRIP,
+ "fan": gl.GL_TRIANGLE_FAN,
}
- _LINE_MODES = 'lines', 'line_strip', 'loop'
+ _LINE_MODES = "lines", "line_strip", "loop"
- _TRIANGLE_MODES = 'triangles', 'triangle_strip', 'fan'
+ _TRIANGLE_MODES = "triangles", "triangle_strip", "fan"
- def __init__(self,
- mode,
- indices=None,
- copy=True,
- attrib0='position',
- **attributes):
+ def __init__(self, mode, indices=None, copy=True, attrib0="position", **attributes):
super(Geometry, self).__init__()
self._attrib0 = str(attrib0)
@@ -146,26 +141,26 @@ class Geometry(core.Elem):
"""
# Convert single value (int, float, numpy types) to tuple
if not isinstance(array, abc.Iterable):
- array = (array, )
+ array = (array,)
# Makes sure it is an array
array = numpy.array(array, copy=False)
dtype = None
- if array.dtype.kind == 'f' and array.dtype.itemsize != 4:
+ if array.dtype.kind == "f" and array.dtype.itemsize != 4:
# Cast to float32
- _logger.info('Cast array to float32')
+ _logger.info("Cast array to float32")
dtype = numpy.float32
elif array.dtype.itemsize > 4:
# Cast (u)int64 to (u)int32
- if array.dtype.kind == 'i':
- _logger.info('Cast array to int32')
+ if array.dtype.kind == "i":
+ _logger.info("Cast array to int32")
dtype = numpy.int32
- elif array.dtype.kind == 'u':
- _logger.info('Cast array to uint32')
+ elif array.dtype.kind == "u":
+ _logger.info("Cast array to uint32")
dtype = numpy.uint32
- return numpy.array(array, dtype=dtype, order='C', copy=copy)
+ return numpy.array(array, dtype=dtype, order="C", copy=copy)
@property
def nbVertices(self):
@@ -200,17 +195,16 @@ class Geometry(core.Elem):
array = self._glReadyArray(array, copy=copy)
if name not in self._ATTR_INFO:
- _logger.debug('Not checking attribute %s dimensions', name)
+ _logger.debug("Not checking attribute %s dimensions", name)
else:
checks = self._ATTR_INFO[name]
- if (array.ndim == 1 and checks['lastDim'] == (1,) and
- len(array) > 1):
+ if array.ndim == 1 and checks["lastDim"] == (1,) and len(array) > 1:
array = array.reshape((len(array), 1))
# Checks
- assert array.ndim in checks['dims'], "Attr %s" % name
- assert array.shape[-1] in checks['lastDim'], "Attr %s" % name
+ assert array.ndim in checks["dims"], "Attr %s" % name
+ assert array.shape[-1] in checks["lastDim"], "Attr %s" % name
# Makes sure attrib0 is considered as an array of values
if name == self.attrib0 and array.ndim == 1:
@@ -277,7 +271,8 @@ class Geometry(core.Elem):
assert len(array) in (1, 2, 3, 4)
gl.glDisableVertexAttribArray(attribute)
_glVertexAttribFunc = getattr(
- _glutils.gl, 'glVertexAttrib{}f'.format(len(array)))
+ _glutils.gl, "glVertexAttrib{}f".format(len(array))
+ )
_glVertexAttribFunc(attribute, *array)
else:
# TODO As is this is a never event, remove?
@@ -288,7 +283,8 @@ class Geometry(core.Elem):
_glutils.numpyToGLType(array.dtype),
gl.GL_FALSE,
0,
- array)
+ array,
+ )
def setIndices(self, indices, copy=True):
"""Set the primitive indices to use.
@@ -297,13 +293,13 @@ class Geometry(core.Elem):
:param bool copy: True (default) to copy the data, False to use as is
"""
# Trigger garbage collection of previous indices VBO if any
- self._vbos.pop('__indices__', None)
+ self._vbos.pop("__indices__", None)
if indices is None:
self._indices = None
else:
indices = self._glReadyArray(indices, copy=copy).ravel()
- assert indices.dtype.name in ('uint8', 'uint16', 'uint32')
+ assert indices.dtype.name in ("uint8", "uint16", "uint32")
if _logger.getEffectiveLevel() <= logging.DEBUG:
# This might be a costy check
assert indices.max() < self.nbVertices
@@ -364,19 +360,22 @@ class Geometry(core.Elem):
min_ = numpy.nanmin(attribute, axis=0)
max_ = numpy.nanmax(attribute, axis=0)
else:
- min_, max_ = numpy.zeros((2, attribute.shape[1]), dtype=numpy.float32)
+ min_, max_ = numpy.zeros(
+ (2, attribute.shape[1]), dtype=numpy.float32
+ )
- toCopy = min(len(min_), 3-index)
+ toCopy = min(len(min_), 3 - index)
if toCopy != len(min_):
- _logger.error("Attribute defining bounds"
- " has too many dimensions")
+ _logger.error(
+ "Attribute defining bounds" " has too many dimensions"
+ )
- self.__bounds[0, index:index+toCopy] = min_[:toCopy]
- self.__bounds[1, index:index+toCopy] = max_[:toCopy]
+ self.__bounds[0, index : index + toCopy] = min_[:toCopy]
+ self.__bounds[1, index : index + toCopy] = max_[:toCopy]
index += toCopy
- self.__bounds[numpy.isnan(self.__bounds)] = 0. # Avoid NaNs
+ self.__bounds[numpy.isnan(self.__bounds)] = 0.0 # Avoid NaNs
return self.__bounds.copy()
@@ -389,11 +388,13 @@ class Geometry(core.Elem):
self._vbos[name] = ctx.glCtx.makeVboAttrib(array)
self._unsyncAttributes = []
- if self._indices is not None and '__indices__' not in self._vbos:
- vbo = ctx.glCtx.makeVbo(self._indices,
- usage=gl.GL_STATIC_DRAW,
- target=gl.GL_ELEMENT_ARRAY_BUFFER)
- self._vbos['__indices__'] = vbo
+ if self._indices is not None and "__indices__" not in self._vbos:
+ vbo = ctx.glCtx.makeVbo(
+ self._indices,
+ usage=gl.GL_STATIC_DRAW,
+ target=gl.GL_ELEMENT_ARRAY_BUFFER,
+ )
+ self._vbos["__indices__"] = vbo
def _draw(self, program=None, nbVertices=None):
"""Perform OpenGL draw calls.
@@ -413,18 +414,23 @@ class Geometry(core.Elem):
else:
if nbVertices is None:
nbVertices = self._indices.size
- with self._vbos['__indices__']:
- gl.glDrawElements(self._MODES[self._mode],
- nbVertices,
- _glutils.numpyToGLType(self._indices.dtype),
- ctypes.c_void_p(0))
+ with self._vbos["__indices__"]:
+ gl.glDrawElements(
+ self._MODES[self._mode],
+ nbVertices,
+ _glutils.numpyToGLType(self._indices.dtype),
+ ctypes.c_void_p(0),
+ )
# Lines #######################################################################
+
class Lines(Geometry):
"""A set of segments"""
- _shaders = ("""
+
+ _shaders = (
+ """
attribute vec3 position;
attribute vec3 normal;
attribute vec4 color;
@@ -446,7 +452,8 @@ class Lines(Geometry):
vColor = color;
}
""",
- string.Template("""
+ string.Template(
+ """
varying vec4 vCameraPosition;
varying vec3 vPosition;
varying vec3 vNormal;
@@ -461,33 +468,43 @@ class Lines(Geometry):
gl_FragColor = $lightingCall(vColor, vPosition, vNormal);
$scenePostCall(vCameraPosition);
}
- """))
-
- def __init__(self, positions, normals=None, colors=(1., 1., 1., 1.),
- indices=None, mode='lines', width=1.):
- if mode == 'strip':
- mode = 'line_strip'
+ """
+ ),
+ )
+
+ def __init__(
+ self,
+ positions,
+ normals=None,
+ colors=(1.0, 1.0, 1.0, 1.0),
+ indices=None,
+ mode="lines",
+ width=1.0,
+ ):
+ if mode == "strip":
+ mode = "line_strip"
assert mode in self._LINE_MODES
self._width = width
self._smooth = True
- super(Lines, self).__init__(mode, indices,
- position=positions,
- normal=normals,
- color=colors)
+ super(Lines, self).__init__(
+ mode, indices, position=positions, normal=normals, color=colors
+ )
- width = event.notifyProperty('_width', converter=float,
- doc="Width of the line in pixels.")
+ width = event.notifyProperty(
+ "_width", converter=float, doc="Width of the line in pixels."
+ )
smooth = event.notifyProperty(
- '_smooth',
+ "_smooth",
converter=bool,
- doc="Smooth line rendering enabled (bool, default: True)")
+ doc="Smooth line rendering enabled (bool, default: True)",
+ )
def renderGL2(self, ctx):
# Prepare program
- isnormals = 'normal' in self._attributes
+ isnormals = "normal" in self._attributes
if isnormals:
fraglightfunction = ctx.viewport.light.fragmentDef
else:
@@ -498,7 +515,8 @@ class Lines(Geometry):
scenePreCall=ctx.fragCallPre,
scenePostCall=ctx.fragCallPost,
lightingFunction=fraglightfunction,
- lightingCall=ctx.viewport.light.fragmentCall)
+ lightingCall=ctx.viewport.light.fragmentCall,
+ )
prog = ctx.glCtx.prog(self._shaders[0], fragment)
prog.use()
@@ -507,10 +525,8 @@ class Lines(Geometry):
gl.glLineWidth(self.width)
- prog.setUniformMatrix('matrix', ctx.objectToNDC.matrix)
- prog.setUniformMatrix('transformMat',
- ctx.objectToCamera.matrix,
- safe=True)
+ prog.setUniformMatrix("matrix", ctx.objectToNDC.matrix)
+ prog.setUniformMatrix("transformMat", ctx.objectToCamera.matrix, safe=True)
ctx.setupProgram(prog)
@@ -524,7 +540,8 @@ class DashedLines(Lines):
This MUST be defined as a set of lines (no strip or loop).
"""
- _shaders = ("""
+ _shaders = (
+ """
attribute vec3 position;
attribute vec3 origin;
attribute vec3 normal;
@@ -554,7 +571,8 @@ class DashedLines(Lines):
vOriginFragCoord = (ndcOrigin.xy + vec2(1.0, 1.0)) * 0.5 * viewportSize + vec2(0.5, 0.5);
}
""", # noqa
- string.Template("""
+ string.Template(
+ """
varying vec4 vCameraPosition;
varying vec3 vPosition;
varying vec3 vNormal;
@@ -579,16 +597,19 @@ class DashedLines(Lines):
$scenePostCall(vCameraPosition);
}
- """))
+ """
+ ),
+ )
- def __init__(self, positions, colors=(1., 1., 1., 1.),
- indices=None, width=1.):
+ def __init__(self, positions, colors=(1.0, 1.0, 1.0, 1.0), indices=None, width=1.0):
self._dash = 1, 0
- super(DashedLines, self).__init__(positions=positions,
- colors=colors,
- indices=indices,
- mode='lines',
- width=width)
+ super(DashedLines, self).__init__(
+ positions=positions,
+ colors=colors,
+ indices=indices,
+ mode="lines",
+ width=width,
+ )
@property
def dash(self):
@@ -609,7 +630,7 @@ class DashedLines(Lines):
:returns: Coordinates of lines
:rtype: numpy.ndarray of float32 of shape (N, 2, Ndim)
"""
- return self.getAttribute('position', copy=copy)
+ return self.getAttribute("position", copy=copy)
def setPositions(self, positions, copy=True):
"""Set line coordinates.
@@ -617,27 +638,27 @@ class DashedLines(Lines):
:param positions: Array of line coordinates
:param bool copy: True to copy input array, False to use as is
"""
- self.setAttribute('position', positions, copy=copy)
+ self.setAttribute("position", positions, copy=copy)
# Update line origins from given positions
- origins = numpy.array(positions, copy=True, order='C')
+ origins = numpy.array(positions, copy=True, order="C")
origins[1::2] = origins[::2]
- self.setAttribute('origin', origins, copy=False)
+ self.setAttribute("origin", origins, copy=False)
def renderGL2(self, context):
# Prepare program
- isnormals = 'normal' in self._attributes
+ isnormals = "normal" in self._attributes
if isnormals:
fraglightfunction = context.viewport.light.fragmentDef
else:
- fraglightfunction = \
- context.viewport.light.fragmentShaderFunctionNoop
+ fraglightfunction = context.viewport.light.fragmentShaderFunctionNoop
fragment = self._shaders[1].substitute(
sceneDecl=context.fragDecl,
scenePreCall=context.fragCallPre,
scenePostCall=context.fragCallPost,
lightingFunction=fraglightfunction,
- lightingCall=context.viewport.light.fragmentCall)
+ lightingCall=context.viewport.light.fragmentCall,
+ )
program = context.glCtx.prog(self._shaders[0], fragment)
program.use()
@@ -646,14 +667,13 @@ class DashedLines(Lines):
gl.glLineWidth(self.width)
- program.setUniformMatrix('matrix', context.objectToNDC.matrix)
- program.setUniformMatrix('transformMat',
- context.objectToCamera.matrix,
- safe=True)
+ program.setUniformMatrix("matrix", context.objectToNDC.matrix)
+ program.setUniformMatrix(
+ "transformMat", context.objectToCamera.matrix, safe=True
+ )
- gl.glUniform2f(
- program.uniforms['viewportSize'], *context.viewport.size)
- gl.glUniform2f(program.uniforms['dash'], *self.dash)
+ gl.glUniform2f(program.uniforms["viewportSize"], *context.viewport.size)
+ gl.glUniform2f(program.uniforms["dash"], *self.dash)
context.setupProgram(program)
@@ -663,42 +683,64 @@ class DashedLines(Lines):
class Box(core.PrivateGroup):
"""Rectangular box"""
- _lineIndices = numpy.array((
- (0, 1), (1, 2), (2, 3), (3, 0), # Lines with z=0
- (0, 4), (1, 5), (2, 6), (3, 7), # Lines from z=0 to z=1
- (4, 5), (5, 6), (6, 7), (7, 4)), # Lines with z=1
- dtype=numpy.uint8)
+ _lineIndices = numpy.array(
+ (
+ (0, 1),
+ (1, 2),
+ (2, 3),
+ (3, 0), # Lines with z=0
+ (0, 4),
+ (1, 5),
+ (2, 6),
+ (3, 7), # Lines from z=0 to z=1
+ (4, 5),
+ (5, 6),
+ (6, 7),
+ (7, 4),
+ ), # Lines with z=1
+ dtype=numpy.uint8,
+ )
_faceIndices = numpy.array(
- (0, 3, 1, 2, 5, 6, 4, 7, 7, 6, 6, 2, 7, 3, 4, 0, 5, 1),
- dtype=numpy.uint8)
-
- _vertices = numpy.array((
- # Corners with z=0
- (0., 0., 0.), (1., 0., 0.), (1., 1., 0.), (0., 1., 0.),
- # Corners with z=1
- (0., 0., 1.), (1., 0., 1.), (1., 1., 1.), (0., 1., 1.)),
- dtype=numpy.float32)
-
- def __init__(self, stroke=(1., 1., 1., 1.), fill=(1., 1., 1., 0.)):
+ (0, 3, 1, 2, 5, 6, 4, 7, 7, 6, 6, 2, 7, 3, 4, 0, 5, 1), dtype=numpy.uint8
+ )
+
+ _vertices = numpy.array(
+ (
+ # Corners with z=0
+ (0.0, 0.0, 0.0),
+ (1.0, 0.0, 0.0),
+ (1.0, 1.0, 0.0),
+ (0.0, 1.0, 0.0),
+ # Corners with z=1
+ (0.0, 0.0, 1.0),
+ (1.0, 0.0, 1.0),
+ (1.0, 1.0, 1.0),
+ (0.0, 1.0, 1.0),
+ ),
+ dtype=numpy.float32,
+ )
+
+ def __init__(self, stroke=(1.0, 1.0, 1.0, 1.0), fill=(1.0, 1.0, 1.0, 0.0)):
super(Box, self).__init__()
- self._fill = Mesh3D(self._vertices,
- colors=rgba(fill),
- mode='triangle_strip',
- indices=self._faceIndices)
- self._fill.visible = self.fillColor[-1] != 0.
+ self._fill = Mesh3D(
+ self._vertices,
+ colors=rgba(fill),
+ mode="triangle_strip",
+ indices=self._faceIndices,
+ )
+ self._fill.visible = self.fillColor[-1] != 0.0
- self._stroke = Lines(self._vertices,
- indices=self._lineIndices,
- colors=rgba(stroke),
- mode='lines')
- self._stroke.visible = self.strokeColor[-1] != 0.
- self.strokeWidth = 1.
+ self._stroke = Lines(
+ self._vertices, indices=self._lineIndices, colors=rgba(stroke), mode="lines"
+ )
+ self._stroke.visible = self.strokeColor[-1] != 0.0
+ self.strokeWidth = 1.0
self._children = [self._stroke, self._fill]
- self._size = 1., 1., 1.
+ self._size = 1.0, 1.0, 1.0
@classmethod
def getLineIndices(cls, copy=True):
@@ -732,11 +774,11 @@ class Box(core.PrivateGroup):
if size != self.size:
self._size = size
self._fill.setAttribute(
- 'position',
- self._vertices * numpy.array(size, dtype=numpy.float32))
+ "position", self._vertices * numpy.array(size, dtype=numpy.float32)
+ )
self._stroke.setAttribute(
- 'position',
- self._vertices * numpy.array(size, dtype=numpy.float32))
+ "position", self._vertices * numpy.array(size, dtype=numpy.float32)
+ )
self.notify()
@property
@@ -766,29 +808,29 @@ class Box(core.PrivateGroup):
@property
def strokeColor(self):
"""RGBA color of the box lines (4-tuple of float in [0, 1])"""
- return tuple(self._stroke.getAttribute('color', copy=False))
+ return tuple(self._stroke.getAttribute("color", copy=False))
@strokeColor.setter
def strokeColor(self, color):
color = rgba(color)
if color != self.strokeColor:
- self._stroke.setAttribute('color', color)
+ self._stroke.setAttribute("color", color)
# Fully transparent = hidden
- self._stroke.visible = color[-1] != 0.
+ self._stroke.visible = color[-1] != 0.0
self.notify()
@property
def fillColor(self):
"""RGBA color of the box faces (4-tuple of float in [0, 1])"""
- return tuple(self._fill.getAttribute('color', copy=False))
+ return tuple(self._fill.getAttribute("color", copy=False))
@fillColor.setter
def fillColor(self, color):
color = rgba(color)
if color != self.fillColor:
- self._fill.setAttribute('color', color)
+ self._fill.setAttribute("color", color)
# Fully transparent = hidden
- self._fill.visible = color[-1] != 0.
+ self._fill.visible = color[-1] != 0.0
self.notify()
@property
@@ -802,21 +844,34 @@ class Box(core.PrivateGroup):
class Axes(Lines):
"""3D RGB orthogonal axes"""
- _vertices = numpy.array(((0., 0., 0.), (1., 0., 0.),
- (0., 0., 0.), (0., 1., 0.),
- (0., 0., 0.), (0., 0., 1.)),
- dtype=numpy.float32)
- _colors = numpy.array(((255, 0, 0, 255), (255, 0, 0, 255),
- (0, 255, 0, 255), (0, 255, 0, 255),
- (0, 0, 255, 255), (0, 0, 255, 255)),
- dtype=numpy.uint8)
+ _vertices = numpy.array(
+ (
+ (0.0, 0.0, 0.0),
+ (1.0, 0.0, 0.0),
+ (0.0, 0.0, 0.0),
+ (0.0, 1.0, 0.0),
+ (0.0, 0.0, 0.0),
+ (0.0, 0.0, 1.0),
+ ),
+ dtype=numpy.float32,
+ )
+
+ _colors = numpy.array(
+ (
+ (255, 0, 0, 255),
+ (255, 0, 0, 255),
+ (0, 255, 0, 255),
+ (0, 255, 0, 255),
+ (0, 0, 255, 255),
+ (0, 0, 255, 255),
+ ),
+ dtype=numpy.uint8,
+ )
def __init__(self):
- super(Axes, self).__init__(self._vertices,
- colors=self._colors,
- width=3.)
- self._size = 1., 1., 1.
+ super(Axes, self).__init__(self._vertices, colors=self._colors, width=3.0)
+ self._size = 1.0, 1.0, 1.0
@property
def size(self):
@@ -830,8 +885,8 @@ class Axes(Lines):
if size != self.size:
self._size = size
self.setAttribute(
- 'position',
- self._vertices * numpy.array(size, dtype=numpy.float32))
+ "position", self._vertices * numpy.array(size, dtype=numpy.float32)
+ )
self.notify()
@@ -841,39 +896,67 @@ class BoxWithAxes(Lines):
:param color: RGBA color of the box
"""
- _vertices = numpy.array((
- # Axes corners
- (0., 0., 0.), (1., 0., 0.),
- (0., 0., 0.), (0., 1., 0.),
- (0., 0., 0.), (0., 0., 1.),
- # Box corners with z=0
- (1., 0., 0.), (1., 1., 0.), (0., 1., 0.),
- # Box corners with z=1
- (0., 0., 1.), (1., 0., 1.), (1., 1., 1.), (0., 1., 1.)),
- dtype=numpy.float32)
-
- _axesColors = numpy.array(((1., 0., 0., 1.), (1., 0., 0., 1.),
- (0., 1., 0., 1.), (0., 1., 0., 1.),
- (0., 0., 1., 1.), (0., 0., 1., 1.)),
- dtype=numpy.float32)
-
- _lineIndices = numpy.array((
- (0, 1), (2, 3), (4, 5), # Axes lines
- (6, 7), (7, 8), # Box lines with z=0
- (6, 10), (7, 11), (8, 12), # Box lines from z=0 to z=1
- (9, 10), (10, 11), (11, 12), (12, 9)), # Box lines with z=1
- dtype=numpy.uint8)
-
- def __init__(self, color=(1., 1., 1., 1.)):
- self._color = (1., 1., 1., 1.)
+ _vertices = numpy.array(
+ (
+ # Axes corners
+ (0.0, 0.0, 0.0),
+ (1.0, 0.0, 0.0),
+ (0.0, 0.0, 0.0),
+ (0.0, 1.0, 0.0),
+ (0.0, 0.0, 0.0),
+ (0.0, 0.0, 1.0),
+ # Box corners with z=0
+ (1.0, 0.0, 0.0),
+ (1.0, 1.0, 0.0),
+ (0.0, 1.0, 0.0),
+ # Box corners with z=1
+ (0.0, 0.0, 1.0),
+ (1.0, 0.0, 1.0),
+ (1.0, 1.0, 1.0),
+ (0.0, 1.0, 1.0),
+ ),
+ dtype=numpy.float32,
+ )
+
+ _axesColors = numpy.array(
+ (
+ (1.0, 0.0, 0.0, 1.0),
+ (1.0, 0.0, 0.0, 1.0),
+ (0.0, 1.0, 0.0, 1.0),
+ (0.0, 1.0, 0.0, 1.0),
+ (0.0, 0.0, 1.0, 1.0),
+ (0.0, 0.0, 1.0, 1.0),
+ ),
+ dtype=numpy.float32,
+ )
+
+ _lineIndices = numpy.array(
+ (
+ (0, 1),
+ (2, 3),
+ (4, 5), # Axes lines
+ (6, 7),
+ (7, 8), # Box lines with z=0
+ (6, 10),
+ (7, 11),
+ (8, 12), # Box lines from z=0 to z=1
+ (9, 10),
+ (10, 11),
+ (11, 12),
+ (12, 9),
+ ), # Box lines with z=1
+ dtype=numpy.uint8,
+ )
+
+ def __init__(self, color=(1.0, 1.0, 1.0, 1.0)):
+ self._color = (1.0, 1.0, 1.0, 1.0)
colors = numpy.ones((len(self._vertices), 4), dtype=numpy.float32)
- colors[:len(self._axesColors), :] = self._axesColors
+ colors[: len(self._axesColors), :] = self._axesColors
- super(BoxWithAxes, self).__init__(self._vertices,
- indices=self._lineIndices,
- colors=colors,
- width=2.)
- self._size = 1., 1., 1.
+ super(BoxWithAxes, self).__init__(
+ self._vertices, indices=self._lineIndices, colors=colors, width=2.0
+ )
+ self._size = 1.0, 1.0, 1.0
self.color = color
@property
@@ -887,9 +970,9 @@ class BoxWithAxes(Lines):
if color != self._color:
self._color = color
colors = numpy.empty((len(self._vertices), 4), dtype=numpy.float32)
- colors[:len(self._axesColors), :] = self._axesColors
- colors[len(self._axesColors):, :] = self._color
- self.setAttribute('color', colors) # Do the notification
+ colors[: len(self._axesColors), :] = self._axesColors
+ colors[len(self._axesColors) :, :] = self._color
+ self.setAttribute("color", colors) # Do the notification
@property
def size(self):
@@ -903,8 +986,8 @@ class BoxWithAxes(Lines):
if size != self.size:
self._size = size
self.setAttribute(
- 'position',
- self._vertices * numpy.array(size, dtype=numpy.float32))
+ "position", self._vertices * numpy.array(size, dtype=numpy.float32)
+ )
self.notify()
@@ -916,29 +999,29 @@ class PlaneInGroup(core.PrivateGroup):
Cannot set the transform attribute of this primitive.
This primitive never has any bounds.
"""
+
# TODO inherit from Lines directly?, make sure the plane remains visible?
- def __init__(self, point=(0., 0., 0.), normal=(0., 0., 1.)):
+ def __init__(self, point=(0.0, 0.0, 0.0), normal=(0.0, 0.0, 1.0)):
super(PlaneInGroup, self).__init__()
self._cache = None, None # Store bounds, vertices
self._outline = None
self._color = None
- self.color = 1., 1., 1., 1. # Set _color
- self._width = 2.
+ self.color = 1.0, 1.0, 1.0, 1.0 # Set _color
+ self._width = 2.0
self._strokeVisible = True
self._plane = utils.Plane(point, normal)
self._plane.addListener(self._planeChanged)
def moveToCenter(self):
- """Place the plane at the center of the data, not changing orientation.
- """
+ """Place the plane at the center of the data, not changing orientation."""
if self.parent is not None:
bounds = self.parent.bounds(dataBounds=True)
if bounds is not None:
- center = (bounds[0] + bounds[1]) / 2.
- _logger.debug('Moving plane to center: %s', str(center))
+ center = (bounds[0] + bounds[1]) / 2.0
+ _logger.debug("Moving plane to center: %s", str(center))
self.plane.point = center
@property
@@ -950,7 +1033,7 @@ class PlaneInGroup(core.PrivateGroup):
def color(self, color):
self._color = numpy.array(color, copy=True, dtype=numpy.float32)
if self._outline is not None:
- self._outline.setAttribute('color', self._color)
+ self._outline.setAttribute("color", self._color)
self.notify() # This is OK as Lines are rebuild for each rendering
@property
@@ -1019,7 +1102,8 @@ class PlaneInGroup(core.PrivateGroup):
boxVertices = bounds[0] + boxVertices * (bounds[1] - bounds[0])
lineIndices = Box.getLineIndices(copy=False)
vertices = utils.boxPlaneIntersect(
- boxVertices, lineIndices, self.plane.normal, self.plane.point)
+ boxVertices, lineIndices, self.plane.normal, self.plane.point
+ )
self._cache = bounds, vertices if len(vertices) != 0 else None
@@ -1041,15 +1125,15 @@ class PlaneInGroup(core.PrivateGroup):
def prepareGL2(self, ctx):
if self.isValid:
if self._outline is None: # Init outline
- self._outline = Lines(self.contourVertices,
- mode='loop',
- colors=self.color)
+ self._outline = Lines(
+ self.contourVertices, mode="loop", colors=self.color
+ )
self._outline.width = self._width
self._outline.visible = self._strokeVisible
self._children.append(self._outline)
# Update vertices, TODO only when necessary
- self._outline.setAttribute('position', self.contourVertices)
+ self._outline.setAttribute("position", self.contourVertices)
super(PlaneInGroup, self).prepareGL2(ctx)
@@ -1094,28 +1178,36 @@ class BoundedGroup(core.Group):
def _bounds(self, dataBounds=False):
if dataBounds and self.size is not None:
- return numpy.array(((0., 0., 0.), self.size),
- dtype=numpy.float32)
+ return numpy.array(((0.0, 0.0, 0.0), self.size), dtype=numpy.float32)
else:
return super(BoundedGroup, self)._bounds(dataBounds)
# Points ######################################################################
+
class _Points(Geometry):
"""Base class to render a set of points."""
- DIAMOND = 'd'
- CIRCLE = 'o'
- SQUARE = 's'
- PLUS = '+'
- X_MARKER = 'x'
- ASTERISK = '*'
- H_LINE = '_'
- V_LINE = '|'
-
- SUPPORTED_MARKERS = (DIAMOND, CIRCLE, SQUARE, PLUS,
- X_MARKER, ASTERISK, H_LINE, V_LINE)
+ DIAMOND = "d"
+ CIRCLE = "o"
+ SQUARE = "s"
+ PLUS = "+"
+ X_MARKER = "x"
+ ASTERISK = "*"
+ H_LINE = "_"
+ V_LINE = "|"
+
+ SUPPORTED_MARKERS = (
+ DIAMOND,
+ CIRCLE,
+ SQUARE,
+ PLUS,
+ X_MARKER,
+ ASTERISK,
+ H_LINE,
+ V_LINE,
+ )
"""List of supported markers:
- 'd' diamond
@@ -1204,10 +1296,12 @@ class _Points(Geometry):
return 0.0;
}
}
- """
+ """,
}
- _shaders = (string.Template("""
+ _shaders = (
+ string.Template(
+ """
#version 120
attribute float x;
@@ -1234,8 +1328,10 @@ class _Points(Geometry):
gl_PointSize = size;
vSize = size;
}
- """),
- string.Template("""
+ """
+ ),
+ string.Template(
+ """
#version 120
varying vec4 vCameraPosition;
@@ -1260,25 +1356,23 @@ class _Points(Geometry):
$scenePostCall(vCameraPosition);
}
- """))
+ """
+ ),
+ )
_ATTR_INFO = {
- 'x': {'dims': (1, 2), 'lastDim': (1,)},
- 'y': {'dims': (1, 2), 'lastDim': (1,)},
- 'z': {'dims': (1, 2), 'lastDim': (1,)},
- 'size': {'dims': (1, 2), 'lastDim': (1,)},
+ "x": {"dims": (1, 2), "lastDim": (1,)},
+ "y": {"dims": (1, 2), "lastDim": (1,)},
+ "z": {"dims": (1, 2), "lastDim": (1,)},
+ "size": {"dims": (1, 2), "lastDim": (1,)},
}
- def __init__(self, x, y, z, value, size=1., indices=None):
- super(_Points, self).__init__('points', indices,
- x=x,
- y=y,
- z=z,
- value=value,
- size=size,
- attrib0='x')
- self.boundsAttributeNames = 'x', 'y', 'z'
- self._marker = 'o'
+ def __init__(self, x, y, z, value, size=1.0, indices=None):
+ super(_Points, self).__init__(
+ "points", indices, x=x, y=y, z=z, value=value, size=size, attrib0="x"
+ )
+ self.boundsAttributeNames = "x", "y", "z"
+ self._marker = "o"
@property
def marker(self):
@@ -1297,20 +1391,16 @@ class _Points(Geometry):
self.notify()
def _shaderValueDefinition(self):
- """Type definition, fragment shader declaration, fragment shader call
- """
- raise NotImplementedError(
- "This method must be implemented in subclass")
+ """Type definition, fragment shader declaration, fragment shader call"""
+ raise NotImplementedError("This method must be implemented in subclass")
def _renderGL2PreDrawHook(self, ctx, program):
"""Override in subclass to run code before calling gl draw"""
pass
def renderGL2(self, ctx):
- valueType, valueToColorDecl, valueToColorCall = \
- self._shaderValueDefinition()
- vertexShader = self._shaders[0].substitute(
- valueType=valueType)
+ valueType, valueToColorDecl, valueToColorCall = self._shaderValueDefinition()
+ vertexShader = self._shaders[0].substitute(valueType=valueType)
fragmentShader = self._shaders[1].substitute(
sceneDecl=ctx.fragDecl,
scenePreCall=ctx.fragCallPre,
@@ -1318,19 +1408,17 @@ class _Points(Geometry):
valueType=valueType,
valueToColorDecl=valueToColorDecl,
valueToColorCall=valueToColorCall,
- alphaSymbolDecl=self._MARKER_FUNCTIONS[self.marker])
- program = ctx.glCtx.prog(vertexShader, fragmentShader,
- attrib0=self.attrib0)
+ alphaSymbolDecl=self._MARKER_FUNCTIONS[self.marker],
+ )
+ program = ctx.glCtx.prog(vertexShader, fragmentShader, attrib0=self.attrib0)
program.use()
gl.glEnable(gl.GL_VERTEX_PROGRAM_POINT_SIZE) # OpenGL 2
gl.glEnable(gl.GL_POINT_SPRITE) # OpenGL 2
# gl.glEnable(gl.GL_PROGRAM_POINT_SIZE)
- program.setUniformMatrix('matrix', ctx.objectToNDC.matrix)
- program.setUniformMatrix('transformMat',
- ctx.objectToCamera.matrix,
- safe=True)
+ program.setUniformMatrix("matrix", ctx.objectToNDC.matrix)
+ program.setUniformMatrix("transformMat", ctx.objectToCamera.matrix, safe=True)
ctx.setupProgram(program)
@@ -1343,16 +1431,12 @@ class Points(_Points):
"""A set of data points with an associated value and size."""
_ATTR_INFO = _Points._ATTR_INFO.copy()
- _ATTR_INFO.update({'value': {'dims': (1, 2), 'lastDim': (1,)}})
+ _ATTR_INFO.update({"value": {"dims": (1, 2), "lastDim": (1,)}})
- def __init__(self, x, y, z, value=0., size=1.,
- indices=None, colormap=None):
- super(Points, self).__init__(x=x,
- y=y,
- z=z,
- indices=indices,
- size=size,
- value=value)
+ def __init__(self, x, y, z, value=0.0, size=1.0, indices=None, colormap=None):
+ super(Points, self).__init__(
+ x=x, y=y, z=z, indices=indices, size=size, value=value
+ )
self._colormap = colormap or Colormap() # Default colormap
self._colormap.addListener(self._cmapChanged)
@@ -1367,9 +1451,8 @@ class Points(_Points):
self.notify(*args, **kwargs)
def _shaderValueDefinition(self):
- """Type definition, fragment shader declaration, fragment shader call
- """
- return 'float', self.colormap.decl, self.colormap.call
+ """Type definition, fragment shader declaration, fragment shader call"""
+ return "float", self.colormap.decl, self.colormap.call
def _renderGL2PreDrawHook(self, ctx, program):
"""Set-up colormap before calling gl draw"""
@@ -1380,21 +1463,16 @@ class ColorPoints(_Points):
"""A set of points with an associated color and size."""
_ATTR_INFO = _Points._ATTR_INFO.copy()
- _ATTR_INFO.update({'value': {'dims': (1, 2), 'lastDim': (3, 4)}})
+ _ATTR_INFO.update({"value": {"dims": (1, 2), "lastDim": (3, 4)}})
- def __init__(self, x, y, z, color=(1., 1., 1., 1.), size=1.,
- indices=None):
- super(ColorPoints, self).__init__(x=x,
- y=y,
- z=z,
- indices=indices,
- size=size,
- value=color)
+ def __init__(self, x, y, z, color=(1.0, 1.0, 1.0, 1.0), size=1.0, indices=None):
+ super(ColorPoints, self).__init__(
+ x=x, y=y, z=z, indices=indices, size=size, value=color
+ )
def _shaderValueDefinition(self):
- """Type definition, fragment shader declaration, fragment shader call
- """
- return 'vec4', '', ''
+ """Type definition, fragment shader declaration, fragment shader call"""
+ return "vec4", "", ""
def setColor(self, color, copy=True):
"""Set colors
@@ -1404,7 +1482,7 @@ class ColorPoints(_Points):
:param bool copy: True to copy colors (default),
False to use provided array (Do not modify!)
"""
- self.setAttribute('value', color, copy=copy)
+ self.setAttribute("value", color, copy=copy)
def getColor(self, copy=True):
"""Returns the color or array of colors of the points.
@@ -1414,13 +1492,14 @@ class ColorPoints(_Points):
:return: Color or array of colors
:rtype: numpy.ndarray
"""
- return self.getAttribute('value', copy=copy)
+ return self.getAttribute("value", copy=copy)
class GridPoints(Geometry):
# GLSL 1.30 !
"""Data points on a regular grid with an associated value and size."""
- _shaders = ("""
+ _shaders = (
+ """
#version 130
in float value;
@@ -1478,7 +1557,8 @@ class GridPoints(Geometry):
gl_PointSize = size;
}
""",
- string.Template("""
+ string.Template(
+ """
#version 130
in vec4 vCameraPosition;
@@ -1495,18 +1575,27 @@ class GridPoints(Geometry):
$scenePostCall(vCameraPosition);
}
- """))
+ """
+ ),
+ )
_ATTR_INFO = {
- 'value': {'dims': (1, 2), 'lastDim': (1,)},
- 'size': {'dims': (1, 2), 'lastDim': (1,)}
+ "value": {"dims": (1, 2), "lastDim": (1,)},
+ "size": {"dims": (1, 2), "lastDim": (1,)},
}
# TODO Add colormap, shape?
# TODO could also use a texture to store values
- def __init__(self, values=0., shape=None, sizes=1., indices=None,
- minValue=None, maxValue=None):
+ def __init__(
+ self,
+ values=0.0,
+ shape=None,
+ sizes=1.0,
+ indices=None,
+ minValue=None,
+ maxValue=None,
+ ):
if isinstance(values, abc.Iterable):
values = numpy.array(values, copy=False)
@@ -1522,16 +1611,14 @@ class GridPoints(Geometry):
assert len(self._shape) in (1, 2, 3)
- super(GridPoints, self).__init__('points', indices,
- value=values,
- size=sizes)
+ super(GridPoints, self).__init__("points", indices, value=values, size=sizes)
- data = self.getAttribute('value', copy=False)
+ data = self.getAttribute("value", copy=False)
self._minValue = data.min() if minValue is None else minValue
self._maxValue = data.max() if maxValue is None else maxValue
- minValue = event.notifyProperty('_minValue')
- maxValue = event.notifyProperty('_maxValue')
+ minValue = event.notifyProperty("_minValue")
+ maxValue = event.notifyProperty("_maxValue")
def _bounds(self, dataBounds=False):
# Get bounds from values shape
@@ -1544,7 +1631,8 @@ class GridPoints(Geometry):
fragment = self._shaders[1].substitute(
sceneDecl=ctx.fragDecl,
scenePreCall=ctx.fragCallPre,
- scenePostCall=ctx.fragCallPost)
+ scenePostCall=ctx.fragCallPost,
+ )
prog = ctx.glCtx.prog(self._shaders[0], fragment)
prog.use()
@@ -1552,25 +1640,26 @@ class GridPoints(Geometry):
gl.glEnable(gl.GL_POINT_SPRITE) # OpenGL 2
# gl.glEnable(gl.GL_PROGRAM_POINT_SIZE)
- prog.setUniformMatrix('matrix', ctx.objectToNDC.matrix)
- prog.setUniformMatrix('transformMat',
- ctx.objectToCamera.matrix,
- safe=True)
+ prog.setUniformMatrix("matrix", ctx.objectToNDC.matrix)
+ prog.setUniformMatrix("transformMat", ctx.objectToCamera.matrix, safe=True)
ctx.setupProgram(prog)
- gl.glUniform3i(prog.uniforms['gridDims'],
- self._shape[2] if len(self._shape) == 3 else 1,
- self._shape[1] if len(self._shape) >= 2 else 1,
- self._shape[0])
+ gl.glUniform3i(
+ prog.uniforms["gridDims"],
+ self._shape[2] if len(self._shape) == 3 else 1,
+ self._shape[1] if len(self._shape) >= 2 else 1,
+ self._shape[0],
+ )
- gl.glUniform2f(prog.uniforms['valRange'], self.minValue, self.maxValue)
+ gl.glUniform2f(prog.uniforms["valRange"], self.minValue, self.maxValue)
self._draw(prog, nbVertices=reduce(lambda a, b: a * b, self._shape))
# Spheres #####################################################################
+
class Spheres(Geometry):
"""A set of spheres.
@@ -1581,6 +1670,7 @@ class Spheres(Geometry):
- Do not render distorion by perspective projection.
- If the sphere center is clipped, the whole sphere is not displayed.
"""
+
# TODO check those links
# Accounting for perspective projection
# http://iquilezles.org/www/articles/sphereproj/sphereproj.htm
@@ -1593,7 +1683,8 @@ class Spheres(Geometry):
# TODO some issues with small scaling and regular grid or due to sampling
- _shaders = ("""
+ _shaders = (
+ """
#version 120
attribute vec3 position;
@@ -1632,7 +1723,8 @@ class Spheres(Geometry):
vViewDepth = vCameraPosition.z;
}
""",
- string.Template("""
+ string.Template(
+ """
# version 120
uniform mat4 projMat;
@@ -1672,20 +1764,21 @@ class Spheres(Geometry):
$scenePostCall(vCameraPosition);
}
- """))
+ """
+ ),
+ )
_ATTR_INFO = {
- 'position': {'dims': (2, ), 'lastDim': (2, 3, 4)},
- 'radius': {'dims': (1, 2), 'lastDim': (1, )},
- 'color': {'dims': (1, 2), 'lastDim': (3, 4)},
+ "position": {"dims": (2,), "lastDim": (2, 3, 4)},
+ "radius": {"dims": (1, 2), "lastDim": (1,)},
+ "color": {"dims": (1, 2), "lastDim": (3, 4)},
}
- def __init__(self, positions, radius=1., colors=(1., 1., 1., 1.)):
+ def __init__(self, positions, radius=1.0, colors=(1.0, 1.0, 1.0, 1.0)):
self.__bounds = None
- super(Spheres, self).__init__('points', None,
- position=positions,
- radius=radius,
- color=colors)
+ super(Spheres, self).__init__(
+ "points", None, position=positions, radius=radius, color=colors
+ )
def renderGL2(self, ctx):
fragment = self._shaders[1].substitute(
@@ -1693,7 +1786,8 @@ class Spheres(Geometry):
scenePreCall=ctx.fragCallPre,
scenePostCall=ctx.fragCallPost,
lightingFunction=ctx.viewport.light.fragmentDef,
- lightingCall=ctx.viewport.light.fragmentCall)
+ lightingCall=ctx.viewport.light.fragmentCall,
+ )
prog = ctx.glCtx.prog(self._shaders[0], fragment)
prog.use()
@@ -1703,14 +1797,12 @@ class Spheres(Geometry):
gl.glEnable(gl.GL_POINT_SPRITE) # OpenGL 2
# gl.glEnable(gl.GL_PROGRAM_POINT_SIZE)
- prog.setUniformMatrix('projMat', ctx.projection.matrix)
- prog.setUniformMatrix('transformMat',
- ctx.objectToCamera.matrix,
- safe=True)
+ prog.setUniformMatrix("projMat", ctx.projection.matrix)
+ prog.setUniformMatrix("transformMat", ctx.objectToCamera.matrix, safe=True)
ctx.setupProgram(prog)
- gl.glUniform2f(prog.uniforms['screenSize'], *ctx.viewport.size)
+ gl.glUniform2f(prog.uniforms["screenSize"], *ctx.viewport.size)
self._draw(prog)
@@ -1718,21 +1810,25 @@ class Spheres(Geometry):
if self.__bounds is None:
self.__bounds = numpy.zeros((2, 3), dtype=numpy.float32)
# Support vertex with to 2 to 4 coordinates
- positions = self._attributes['position']
- radius = self._attributes['radius']
- self.__bounds[0, :positions.shape[1]] = \
- (positions - radius).min(axis=0)[:3]
- self.__bounds[1, :positions.shape[1]] = \
- (positions + radius).max(axis=0)[:3]
+ positions = self._attributes["position"]
+ radius = self._attributes["radius"]
+ self.__bounds[0, : positions.shape[1]] = (positions - radius).min(axis=0)[
+ :3
+ ]
+ self.__bounds[1, : positions.shape[1]] = (positions + radius).max(axis=0)[
+ :3
+ ]
return self.__bounds.copy()
# Meshes ######################################################################
+
class Mesh3D(Geometry):
"""A conventional 3D mesh"""
- _shaders = ("""
+ _shaders = (
+ """
attribute vec3 position;
attribute vec3 normal;
attribute vec4 color;
@@ -1756,7 +1852,8 @@ class Mesh3D(Geometry):
gl_Position = matrix * vec4(position, 1.0);
}
""",
- string.Template("""
+ string.Template(
+ """
varying vec4 vCameraPosition;
varying vec3 vPosition;
varying vec3 vNormal;
@@ -1773,21 +1870,17 @@ class Mesh3D(Geometry):
$scenePostCall(vCameraPosition);
}
- """))
-
- def __init__(self,
- positions,
- colors,
- normals=None,
- mode='triangles',
- indices=None,
- copy=True):
+ """
+ ),
+ )
+
+ def __init__(
+ self, positions, colors, normals=None, mode="triangles", indices=None, copy=True
+ ):
assert mode in self._TRIANGLE_MODES
- super(Mesh3D, self).__init__(mode, indices,
- position=positions,
- normal=normals,
- color=colors,
- copy=copy)
+ super(Mesh3D, self).__init__(
+ mode, indices, position=positions, normal=normals, color=colors, copy=copy
+ )
self._culling = None
@@ -1801,13 +1894,13 @@ class Mesh3D(Geometry):
@culling.setter
def culling(self, culling):
- assert culling in ('back', 'front', None)
+ assert culling in ("back", "front", None)
if culling != self._culling:
self._culling = culling
self.notify()
def renderGL2(self, ctx):
- isnormals = 'normal' in self._attributes
+ isnormals = "normal" in self._attributes
if isnormals:
fragLightFunction = ctx.viewport.light.fragmentDef
else:
@@ -1818,7 +1911,8 @@ class Mesh3D(Geometry):
scenePreCall=ctx.fragCallPre,
scenePostCall=ctx.fragCallPost,
lightingFunction=fragLightFunction,
- lightingCall=ctx.viewport.light.fragmentCall)
+ lightingCall=ctx.viewport.light.fragmentCall,
+ )
prog = ctx.glCtx.prog(self._shaders[0], fragment)
prog.use()
@@ -1826,14 +1920,12 @@ class Mesh3D(Geometry):
ctx.viewport.light.setupProgram(ctx, prog)
if self.culling is not None:
- cullFace = gl.GL_FRONT if self.culling == 'front' else gl.GL_BACK
+ cullFace = gl.GL_FRONT if self.culling == "front" else gl.GL_BACK
gl.glCullFace(cullFace)
gl.glEnable(gl.GL_CULL_FACE)
- prog.setUniformMatrix('matrix', ctx.objectToNDC.matrix)
- prog.setUniformMatrix('transformMat',
- ctx.objectToCamera.matrix,
- safe=True)
+ prog.setUniformMatrix("matrix", ctx.objectToNDC.matrix)
+ prog.setUniformMatrix("transformMat", ctx.objectToCamera.matrix, safe=True)
ctx.setupProgram(prog)
@@ -1846,7 +1938,8 @@ class Mesh3D(Geometry):
class ColormapMesh3D(Geometry):
"""A 3D mesh with color computed from a colormap"""
- _shaders = ("""
+ _shaders = (
+ """
attribute vec3 position;
attribute vec3 normal;
attribute float value;
@@ -1870,7 +1963,8 @@ class ColormapMesh3D(Geometry):
gl_Position = matrix * vec4(position, 1.0);
}
""",
- string.Template("""
+ string.Template(
+ """
uniform float alpha;
varying vec4 vCameraPosition;
@@ -1892,21 +1986,23 @@ class ColormapMesh3D(Geometry):
$scenePostCall(vCameraPosition);
}
- """))
-
- def __init__(self,
- position,
- value,
- colormap=None,
- normal=None,
- mode='triangles',
- indices=None,
- copy=True):
- super(ColormapMesh3D, self).__init__(mode, indices,
- position=position,
- normal=normal,
- value=value,
- copy=copy)
+ """
+ ),
+ )
+
+ def __init__(
+ self,
+ position,
+ value,
+ colormap=None,
+ normal=None,
+ mode="triangles",
+ indices=None,
+ copy=True,
+ ):
+ super(ColormapMesh3D, self).__init__(
+ mode, indices, position=position, normal=normal, value=value, copy=copy
+ )
self._alpha = 1.0
self._lineWidth = 1.0
@@ -1915,17 +2011,19 @@ class ColormapMesh3D(Geometry):
self._colormap = colormap or Colormap() # Default colormap
self._colormap.addListener(self._cmapChanged)
- lineWidth = event.notifyProperty('_lineWidth', converter=float,
- doc="Width of the line in pixels.")
+ lineWidth = event.notifyProperty(
+ "_lineWidth", converter=float, doc="Width of the line in pixels."
+ )
lineSmooth = event.notifyProperty(
- '_lineSmooth',
+ "_lineSmooth",
converter=bool,
- doc="Smooth line rendering enabled (bool, default: True)")
+ doc="Smooth line rendering enabled (bool, default: True)",
+ )
alpha = event.notifyProperty(
- '_alpha', converter=float,
- doc="Transparency of the mesh, float in [0, 1]")
+ "_alpha", converter=float, doc="Transparency of the mesh, float in [0, 1]"
+ )
@property
def culling(self):
@@ -1937,7 +2035,7 @@ class ColormapMesh3D(Geometry):
@culling.setter
def culling(self, culling):
- assert culling in ('back', 'front', None)
+ assert culling in ("back", "front", None)
if culling != self._culling:
self._culling = culling
self.notify()
@@ -1952,7 +2050,7 @@ class ColormapMesh3D(Geometry):
self.notify(*args, **kwargs)
def renderGL2(self, ctx):
- if 'normal' in self._attributes:
+ if "normal" in self._attributes:
self._renderGL2(ctx)
else: # Disable lighting
with self.viewport.light.turnOff():
@@ -1966,7 +2064,8 @@ class ColormapMesh3D(Geometry):
lightingFunction=ctx.viewport.light.fragmentDef,
lightingCall=ctx.viewport.light.fragmentCall,
colormapDecl=self.colormap.decl,
- colormapCall=self.colormap.call)
+ colormapCall=self.colormap.call,
+ )
program = ctx.glCtx.prog(self._shaders[0], fragment)
program.use()
@@ -1975,15 +2074,13 @@ class ColormapMesh3D(Geometry):
self.colormap.setupProgram(ctx, program)
if self.culling is not None:
- cullFace = gl.GL_FRONT if self.culling == 'front' else gl.GL_BACK
+ cullFace = gl.GL_FRONT if self.culling == "front" else gl.GL_BACK
gl.glCullFace(cullFace)
gl.glEnable(gl.GL_CULL_FACE)
- program.setUniformMatrix('matrix', ctx.objectToNDC.matrix)
- program.setUniformMatrix('transformMat',
- ctx.objectToCamera.matrix,
- safe=True)
- gl.glUniform1f(program.uniforms['alpha'], self._alpha)
+ program.setUniformMatrix("matrix", ctx.objectToNDC.matrix)
+ program.setUniformMatrix("transformMat", ctx.objectToCamera.matrix, safe=True)
+ gl.glUniform1f(program.uniforms["alpha"], self._alpha)
if self.drawMode in self._LINE_MODES:
gl.glLineWidth(self.lineWidth)
@@ -1998,10 +2095,12 @@ class ColormapMesh3D(Geometry):
# ImageData ##################################################################
+
class _Image(Geometry):
"""Base class for ImageData and ImageRgba"""
- _shaders = ("""
+ _shaders = (
+ """
attribute vec2 position;
uniform mat4 matrix;
@@ -2022,7 +2121,8 @@ class _Image(Geometry):
gl_Position = matrix * positionVec4;
}
""",
- string.Template("""
+ string.Template(
+ """
varying vec4 vCameraPosition;
varying vec3 vPosition;
varying vec2 vTexCoords;
@@ -2048,22 +2148,24 @@ class _Image(Geometry):
$scenePostCall(vCameraPosition);
}
- """))
+ """
+ ),
+ )
- _UNIT_SQUARE = numpy.array(((0., 0.), (1., 0.), (0., 1.), (1., 1.)),
- dtype=numpy.float32)
+ _UNIT_SQUARE = numpy.array(
+ ((0.0, 0.0), (1.0, 0.0), (0.0, 1.0), (1.0, 1.0)), dtype=numpy.float32
+ )
def __init__(self, data, copy=True):
- super(_Image, self).__init__(mode='triangle_strip',
- position=self._UNIT_SQUARE)
+ super(_Image, self).__init__(mode="triangle_strip", position=self._UNIT_SQUARE)
self._texture = None
self._update_texture = True
self._update_texture_filter = False
self._data = None
self.setData(data, copy)
- self._alpha = 1.
- self._interpolation = 'linear'
+ self._alpha = 1.0
+ self._interpolation = "linear"
self.isBackfaceVisible = True
@@ -2077,7 +2179,9 @@ class _Image(Geometry):
self._update_texture = True
# By updating the position rather than always using a unit square
# we benefit from Geometry bounds handling
- self.setAttribute('position', self._UNIT_SQUARE * (self._data.shape[1], self._data.shape[0]))
+ self.setAttribute(
+ "position", self._UNIT_SQUARE * (self._data.shape[1], self._data.shape[0])
+ )
self.notify()
def getData(self, copy=True):
@@ -2090,7 +2194,7 @@ class _Image(Geometry):
@interpolation.setter
def interpolation(self, interpolation):
- assert interpolation in ('linear', 'nearest')
+ assert interpolation in ("linear", "nearest")
self._interpolation = interpolation
self._update_texture_filter = True
self.notify()
@@ -2110,15 +2214,14 @@ class _Image(Geometry):
:return: 2-tuple of gl flags (internalFormat, format)
"""
- raise NotImplementedError(
- "This method must be implemented in a subclass")
+ raise NotImplementedError("This method must be implemented in a subclass")
def prepareGL2(self, ctx):
if self._texture is None or self._update_texture:
if self._texture is not None:
self._texture.discard()
- if self.interpolation == 'nearest':
+ if self.interpolation == "nearest":
filter_ = gl.GL_NEAREST
else:
filter_ = gl.GL_LINEAR
@@ -2134,11 +2237,12 @@ class _Image(Geometry):
format_,
minFilter=filter_,
magFilter=filter_,
- wrap=gl.GL_CLAMP_TO_EDGE)
+ wrap=gl.GL_CLAMP_TO_EDGE,
+ )
if self._update_texture_filter and self._texture is not None:
self._update_texture_filter = False
- if self.interpolation == 'nearest':
+ if self.interpolation == "nearest":
filter_ = gl.GL_NEAREST
else:
filter_ = gl.GL_LINEAR
@@ -2160,8 +2264,7 @@ class _Image(Geometry):
def _shaderImageColorDecl(self):
"""Returns fragment shader imageColor function declaration"""
- raise NotImplementedError(
- "This method must be implemented in a subclass")
+ raise NotImplementedError("This method must be implemented in a subclass")
def _renderGL2(self, ctx):
fragment = self._shaders[1].substitute(
@@ -2170,8 +2273,8 @@ class _Image(Geometry):
scenePostCall=ctx.fragCallPost,
lightingFunction=ctx.viewport.light.fragmentDef,
lightingCall=ctx.viewport.light.fragmentCall,
- imageDecl=self._shaderImageColorDecl()
- )
+ imageDecl=self._shaderImageColorDecl(),
+ )
program = ctx.glCtx.prog(self._shaders[0], fragment)
program.use()
@@ -2181,16 +2284,14 @@ class _Image(Geometry):
gl.glCullFace(gl.GL_BACK)
gl.glEnable(gl.GL_CULL_FACE)
- program.setUniformMatrix('matrix', ctx.objectToNDC.matrix)
- program.setUniformMatrix('transformMat',
- ctx.objectToCamera.matrix,
- safe=True)
- gl.glUniform1f(program.uniforms['alpha'], self._alpha)
+ program.setUniformMatrix("matrix", ctx.objectToNDC.matrix)
+ program.setUniformMatrix("transformMat", ctx.objectToCamera.matrix, safe=True)
+ gl.glUniform1f(program.uniforms["alpha"], self._alpha)
shape = self._data.shape
- gl.glUniform2f(program.uniforms['dataScale'], 1./shape[1], 1./shape[0])
+ gl.glUniform2f(program.uniforms["dataScale"], 1.0 / shape[1], 1.0 / shape[0])
- gl.glUniform1i(program.uniforms['data'], self._texture.texUnit)
+ gl.glUniform1i(program.uniforms["data"], self._texture.texUnit)
ctx.setupProgram(program)
@@ -2207,7 +2308,8 @@ class _Image(Geometry):
class ImageData(_Image):
"""Display a 2x2 data array with a texture."""
- _imageDecl = string.Template("""
+ _imageDecl = string.Template(
+ """
$colormapDecl
vec4 imageColor(sampler2D data, vec2 texCoords) {
@@ -2215,7 +2317,8 @@ class ImageData(_Image):
vec4 color = $colormapCall(value);
return color;
}
- """)
+ """
+ )
def __init__(self, data, copy=True, colormap=None):
super(ImageData, self).__init__(data, copy=copy)
@@ -2224,7 +2327,7 @@ class ImageData(_Image):
self._colormap.addListener(self._cmapChanged)
def setData(self, data, copy=True):
- data = numpy.array(data, copy=copy, order='C', dtype=numpy.float32)
+ data = numpy.array(data, copy=copy, order="C", dtype=numpy.float32)
# TODO support (u)int8|16
assert data.ndim == 2
@@ -2247,12 +2350,13 @@ class ImageData(_Image):
def _shaderImageColorDecl(self):
return self._imageDecl.substitute(
- colormapDecl=self.colormap.decl,
- colormapCall=self.colormap.call)
+ colormapDecl=self.colormap.decl, colormapCall=self.colormap.call
+ )
# ImageRgba ##################################################################
+
class ImageRgba(_Image):
"""Display a 2x2 RGBA image with a texture.
@@ -2270,10 +2374,10 @@ class ImageRgba(_Image):
super(ImageRgba, self).__init__(data, copy=copy)
def setData(self, data, copy=True):
- data = numpy.array(data, copy=copy, order='C')
+ data = numpy.array(data, copy=copy, order="C")
assert data.ndim == 3
assert data.shape[2] in (3, 4)
- if data.dtype.kind == 'f':
+ if data.dtype.kind == "f":
if data.dtype != numpy.dtype(numpy.float32):
_logger.warning("Converting image data to float32")
data = numpy.array(data, dtype=numpy.float32, copy=False)
@@ -2295,6 +2399,7 @@ class ImageRgba(_Image):
# TODO lighting, clipping as groups?
# group composition?
+
class GroupDepthOffset(core.Group):
"""A group using 2-pass rendering and glDepthRange to avoid Z-fighting"""
@@ -2306,7 +2411,7 @@ class GroupDepthOffset(core.Group):
def prepareGL2(self, ctx):
if self._epsilon is None:
depthbits = gl.glGetInteger(gl.GL_DEPTH_BITS)
- self._epsilon = 1. / (1 << (depthbits - 1))
+ self._epsilon = 1.0 / (1 << (depthbits - 1))
def renderGL2(self, ctx):
if self.isDepthRangeOn:
@@ -2319,38 +2424,34 @@ class GroupDepthOffset(core.Group):
with gl.enabled(gl.GL_CULL_FACE):
gl.glCullFace(gl.GL_BACK)
for child in self.children:
- gl.glColorMask(
- gl.GL_FALSE, gl.GL_FALSE, gl.GL_FALSE, gl.GL_FALSE)
+ gl.glColorMask(gl.GL_FALSE, gl.GL_FALSE, gl.GL_FALSE, gl.GL_FALSE)
gl.glDepthMask(gl.GL_TRUE)
- gl.glDepthRange(self._epsilon, 1.)
+ gl.glDepthRange(self._epsilon, 1.0)
child.render(ctx)
- gl.glColorMask(
- gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE)
+ gl.glColorMask(gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE)
gl.glDepthMask(gl.GL_FALSE)
- gl.glDepthRange(0., 1. - self._epsilon)
+ gl.glDepthRange(0.0, 1.0 - self._epsilon)
child.render(ctx)
gl.glCullFace(gl.GL_FRONT)
for child in reversed(self.children):
- gl.glColorMask(
- gl.GL_FALSE, gl.GL_FALSE, gl.GL_FALSE, gl.GL_FALSE)
+ gl.glColorMask(gl.GL_FALSE, gl.GL_FALSE, gl.GL_FALSE, gl.GL_FALSE)
gl.glDepthMask(gl.GL_TRUE)
- gl.glDepthRange(self._epsilon, 1.)
+ gl.glDepthRange(self._epsilon, 1.0)
child.render(ctx)
- gl.glColorMask(
- gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE)
+ gl.glColorMask(gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE)
gl.glDepthMask(gl.GL_FALSE)
- gl.glDepthRange(0., 1. - self._epsilon)
+ gl.glDepthRange(0.0, 1.0 - self._epsilon)
child.render(ctx)
gl.glDepthMask(gl.GL_TRUE)
- gl.glDepthRange(0., 1.)
+ gl.glDepthRange(0.0, 1.0)
# gl.glDepthFunc(gl.GL_LEQUAL)
# TODO use epsilon for all rendering?
# TODO issue with picking in depth buffer!
@@ -2382,7 +2483,7 @@ class GroupNoDepth(core.Group):
class GroupBBox(core.PrivateGroup):
"""A group displaying a bounding box around the children."""
- def __init__(self, children=(), color=(1., 1., 1., 1.)):
+ def __init__(self, children=(), color=(1.0, 1.0, 1.0, 1.0)):
super(GroupBBox, self).__init__()
self._group = core.Group(children)
@@ -2394,7 +2495,7 @@ class GroupBBox(core.PrivateGroup):
self._boxWithAxes.smooth = False
self._boxWithAxes.transforms = self._boxTransforms
- self._box = Box(stroke=color, fill=(1., 1., 1., 0.))
+ self._box = Box(stroke=color, fill=(1.0, 1.0, 1.0, 0.0))
self._box.strokeSmooth = False
self._box.transforms = self._boxTransforms
self._box.visible = False
@@ -2404,7 +2505,7 @@ class GroupBBox(core.PrivateGroup):
self._axes.transforms = self._boxTransforms
self._axes.visible = False
- self.strokeWidth = 2.
+ self.strokeWidth = 2.0
self._children = [self._boxWithAxes, self._box, self._axes, self._group]
@@ -2415,7 +2516,7 @@ class GroupBBox(core.PrivateGroup):
origin = bounds[0]
size = bounds[1] - bounds[0]
else:
- origin, size = (0., 0., 0.), (1., 1., 1.)
+ origin, size = (0.0, 0.0, 0.0), (1.0, 1.0, 1.0)
self._boxTransforms[0].translation = origin
@@ -2484,8 +2585,9 @@ class GroupBBox(core.PrivateGroup):
@axesVisible.setter
def axesVisible(self, visible):
- self._updateBoxAndAxesVisibility(axesVisible=bool(visible),
- boxVisible=self.boxVisible)
+ self._updateBoxAndAxesVisibility(
+ axesVisible=bool(visible), boxVisible=self.boxVisible
+ )
@property
def boxVisible(self):
@@ -2494,12 +2596,14 @@ class GroupBBox(core.PrivateGroup):
@boxVisible.setter
def boxVisible(self, visible):
- self._updateBoxAndAxesVisibility(axesVisible=self.axesVisible,
- boxVisible=bool(visible))
+ self._updateBoxAndAxesVisibility(
+ axesVisible=self.axesVisible, boxVisible=bool(visible)
+ )
# Clipping Plane ##############################################################
+
class ClipPlane(PlaneInGroup):
"""A clipping plane attached to a box"""
@@ -2510,8 +2614,9 @@ class ClipPlane(PlaneInGroup):
# Set-up clipping plane for following brothers
# No need of perspective divide, no projection
- point = ctx.objectToCamera.transformPoint(self.plane.point,
- perspectiveDivide=False)
+ point = ctx.objectToCamera.transformPoint(
+ self.plane.point, perspectiveDivide=False
+ )
normal = ctx.objectToCamera.transformNormal(self.plane.normal)
ctx.setClipPlane(point, normal)
diff --git a/src/silx/gui/plot3d/scene/test/test_transform.py b/src/silx/gui/plot3d/scene/test/test_transform.py
index 2998c65..cba384d 100644
--- a/src/silx/gui/plot3d/scene/test/test_transform.py
+++ b/src/silx/gui/plot3d/scene/test/test_transform.py
@@ -34,7 +34,6 @@ from silx.gui.plot3d.scene import transform
class TestTransformList(unittest.TestCase):
-
def assertSameArrays(self, a, b):
return self.assertTrue(numpy.allclose(a, b, atol=1e-06))
@@ -45,25 +44,36 @@ class TestTransformList(unittest.TestCase):
self.assertSameArrays(refmatrix, transforms.matrix)
# Append translate
- transforms.append(transform.Translate(1., 1., 1.))
- refmatrix = numpy.array(((1., 0., 0., 1.),
- (0., 1., 0., 1.),
- (0., 0., 1., 1.),
- (0., 0., 0., 1.)), dtype=numpy.float32)
+ transforms.append(transform.Translate(1.0, 1.0, 1.0))
+ refmatrix = numpy.array(
+ (
+ (1.0, 0.0, 0.0, 1.0),
+ (0.0, 1.0, 0.0, 1.0),
+ (0.0, 0.0, 1.0, 1.0),
+ (0.0, 0.0, 0.0, 1.0),
+ ),
+ dtype=numpy.float32,
+ )
self.assertSameArrays(refmatrix, transforms.matrix)
# Extend scale
- transforms.extend([transform.Scale(0.1, 2., 1.)])
- refmatrix = numpy.dot(refmatrix,
- numpy.array(((0.1, 0., 0., 0.),
- (0., 2., 0., 0.),
- (0., 0., 1., 0.),
- (0., 0., 0., 1.)),
- dtype=numpy.float32))
+ transforms.extend([transform.Scale(0.1, 2.0, 1.0)])
+ refmatrix = numpy.dot(
+ refmatrix,
+ numpy.array(
+ (
+ (0.1, 0.0, 0.0, 0.0),
+ (0.0, 2.0, 0.0, 0.0),
+ (0.0, 0.0, 1.0, 0.0),
+ (0.0, 0.0, 0.0, 1.0),
+ ),
+ dtype=numpy.float32,
+ ),
+ )
self.assertSameArrays(refmatrix, transforms.matrix)
# Insert rotate
- transforms.insert(0, transform.Rotate(360.))
+ transforms.insert(0, transform.Rotate(360.0))
self.assertSameArrays(refmatrix, transforms.matrix)
# Update translate and check for listener called
@@ -71,6 +81,7 @@ class TestTransformList(unittest.TestCase):
def listener(source):
self._callCount += 1
+
transforms.addListener(listener)
transforms[1].tx += 1
diff --git a/src/silx/gui/plot3d/scene/test/test_utils.py b/src/silx/gui/plot3d/scene/test/test_utils.py
index a9ba6bc..81f99d6 100644
--- a/src/silx/gui/plot3d/scene/test/test_utils.py
+++ b/src/silx/gui/plot3d/scene/test/test_utils.py
@@ -27,7 +27,6 @@ __license__ = "MIT"
__date__ = "17/01/2018"
-import unittest
from silx.utils.testutils import ParametricTestCase
import numpy
@@ -37,34 +36,35 @@ from silx.gui.plot3d.scene import utils
# angleBetweenVectors #########################################################
-class TestAngleBetweenVectors(ParametricTestCase):
+class TestAngleBetweenVectors(ParametricTestCase):
TESTS = { # name: (refvector, vectors, norm, refangles)
- 'single vector':
- ((1., 0., 0.), (1., 0., 0.), (0., 0., 1.), 0.),
- 'single vector, no norm':
- ((1., 0., 0.), (1., 0., 0.), None, 0.),
-
- 'with orthogonal norm':
- ((1., 0., 0.),
- ((1., 0., 0.), (0., 1., 0.), (-1., 0., 0.), (0., -1., 0.)),
- (0., 0., 1.),
- (0., 90., 180., 270.)),
-
- 'with coplanar norm': # = similar to no norm
- ((1., 0., 0.),
- ((1., 0., 0.), (0., 1., 0.), (-1., 0., 0.), (0., -1., 0.)),
- (1., 0., 0.),
- (0., 90., 180., 90.)),
-
- 'without norm':
- ((1., 0., 0.),
- ((1., 0., 0.), (0., 1., 0.), (-1., 0., 0.), (0., -1., 0.)),
- None,
- (0., 90., 180., 90.)),
-
- 'not unit vectors':
- ((2., 2., 0.), ((1., 1., 0.), (1., -1., 0.)), None, (0., 90.)),
+ "single vector": ((1.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 0.0, 1.0), 0.0),
+ "single vector, no norm": ((1.0, 0.0, 0.0), (1.0, 0.0, 0.0), None, 0.0),
+ "with orthogonal norm": (
+ (1.0, 0.0, 0.0),
+ ((1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (-1.0, 0.0, 0.0), (0.0, -1.0, 0.0)),
+ (0.0, 0.0, 1.0),
+ (0.0, 90.0, 180.0, 270.0),
+ ),
+ "with coplanar norm": ( # = similar to no norm
+ (1.0, 0.0, 0.0),
+ ((1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (-1.0, 0.0, 0.0), (0.0, -1.0, 0.0)),
+ (1.0, 0.0, 0.0),
+ (0.0, 90.0, 180.0, 90.0),
+ ),
+ "without norm": (
+ (1.0, 0.0, 0.0),
+ ((1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (-1.0, 0.0, 0.0), (0.0, -1.0, 0.0)),
+ None,
+ (0.0, 90.0, 180.0, 90.0),
+ ),
+ "not unit vectors": (
+ (2.0, 2.0, 0.0),
+ ((1.0, 1.0, 0.0), (1.0, -1.0, 0.0)),
+ None,
+ (0.0, 90.0),
+ ),
}
def testAngleBetweenVectorsFunction(self):
@@ -78,15 +78,14 @@ class TestAngleBetweenVectors(ParametricTestCase):
if norm is not None:
norm = numpy.array(norm)
- testangles = utils.angleBetweenVectors(
- refvector, vectors, norm)
+ testangles = utils.angleBetweenVectors(refvector, vectors, norm)
- self.assertTrue(
- numpy.allclose(testangles, refangles, atol=1e-5))
+ self.assertTrue(numpy.allclose(testangles, refangles, atol=1e-5))
# Plane #######################################################################
+
class AssertNotificationContext(object):
"""Context that checks if an event.Notifier is sending events."""
@@ -118,9 +117,9 @@ class TestPlaneParameters(ParametricTestCase):
"""Test Plane.parameters read/write and notifications."""
PARAMETERS = {
- 'unit normal': (1., 0., 0., 1.),
- 'not unit normal': (1., 1., 0., 1.),
- 'd = 0': (1., 0., 0., 0.)
+ "unit normal": (1.0, 0.0, 0.0, 1.0),
+ "not unit normal": (1.0, 1.0, 0.0, 1.0),
+ "d = 0": (1.0, 0.0, 0.0, 0.0),
}
def testParameters(self):
@@ -136,12 +135,9 @@ class TestPlaneParameters(ParametricTestCase):
normparams = parameters / numpy.linalg.norm(parameters[:3])
self.assertTrue(numpy.allclose(plane.parameters, normparams))
- ZEROS_PARAMETERS = (
- (0., 0., 0., 0.),
- (0., 0., 0., 1.)
- )
+ ZEROS_PARAMETERS = ((0.0, 0.0, 0.0, 0.0), (0.0, 0.0, 0.0, 1.0))
- ZEROS = 0., 0., 0., 0.
+ ZEROS = 0.0, 0.0, 0.0, 0.0
def testParametersNoPlane(self):
"""Test Plane.parameters with ||normal|| == 0 ."""
@@ -152,24 +148,25 @@ class TestPlaneParameters(ParametricTestCase):
with self.subTest(parameters=parameters):
with AssertNotificationContext(plane, count=0):
plane.parameters = parameters
- self.assertTrue(
- numpy.allclose(plane.parameters, self.ZEROS, 0., 0.))
+ self.assertTrue(numpy.allclose(plane.parameters, self.ZEROS, 0.0, 0.0))
# unindexArrays ###############################################################
+
class TestUnindexArrays(ParametricTestCase):
"""Test unindexArrays function."""
def testBasicModes(self):
"""Test for modes: points, lines and triangles"""
indices = numpy.array((1, 2, 0))
- arrays = (numpy.array((0., 1., 2.)),
- numpy.array(((0, 0), (1, 1), (2, 2))))
- refresults = (numpy.array((1., 2., 0.)),
- numpy.array(((1, 1), (2, 2), (0, 0))))
+ arrays = (numpy.array((0.0, 1.0, 2.0)), numpy.array(((0, 0), (1, 1), (2, 2))))
+ refresults = (
+ numpy.array((1.0, 2.0, 0.0)),
+ numpy.array(((1, 1), (2, 2), (0, 0))),
+ )
- for mode in ('points', 'lines', 'triangles'):
+ for mode in ("points", "lines", "triangles"):
with self.subTest(mode=mode):
testresults = utils.unindexArrays(mode, indices, *arrays)
for ref, test in zip(refresults, testresults):
@@ -178,15 +175,16 @@ class TestUnindexArrays(ParametricTestCase):
def testPackedLines(self):
"""Test for modes: line_strip, loop"""
indices = numpy.array((1, 2, 0))
- arrays = (numpy.array((0., 1., 2.)),
- numpy.array(((0, 0), (1, 1), (2, 2))))
+ arrays = (numpy.array((0.0, 1.0, 2.0)), numpy.array(((0, 0), (1, 1), (2, 2))))
results = {
- 'line_strip': (
- numpy.array((1., 2., 2., 0.)),
- numpy.array(((1, 1), (2, 2), (2, 2), (0, 0)))),
- 'loop': (
- numpy.array((1., 2., 2., 0., 0., 1.)),
- numpy.array(((1, 1), (2, 2), (2, 2), (0, 0), (0, 0), (1, 1)))),
+ "line_strip": (
+ numpy.array((1.0, 2.0, 2.0, 0.0)),
+ numpy.array(((1, 1), (2, 2), (2, 2), (0, 0))),
+ ),
+ "loop": (
+ numpy.array((1.0, 2.0, 2.0, 0.0, 0.0, 1.0)),
+ numpy.array(((1, 1), (2, 2), (2, 2), (0, 0), (0, 0), (1, 1))),
+ ),
}
for mode, refresults in results.items():
@@ -198,15 +196,19 @@ class TestUnindexArrays(ParametricTestCase):
def testPackedTriangles(self):
"""Test for modes: triangle_strip, fan"""
indices = numpy.array((1, 2, 0, 3))
- arrays = (numpy.array((0., 1., 2., 3.)),
- numpy.array(((0, 0), (1, 1), (2, 2), (3, 3))))
+ arrays = (
+ numpy.array((0.0, 1.0, 2.0, 3.0)),
+ numpy.array(((0, 0), (1, 1), (2, 2), (3, 3))),
+ )
results = {
- 'triangle_strip': (
- numpy.array((1., 2., 0., 2., 0., 3.)),
- numpy.array(((1, 1), (2, 2), (0, 0), (2, 2), (0, 0), (3, 3)))),
- 'fan': (
- numpy.array((1., 2., 0., 1., 0., 3.)),
- numpy.array(((1, 1), (2, 2), (0, 0), (1, 1), (0, 0), (3, 3)))),
+ "triangle_strip": (
+ numpy.array((1.0, 2.0, 0.0, 2.0, 0.0, 3.0)),
+ numpy.array(((1, 1), (2, 2), (0, 0), (2, 2), (0, 0), (3, 3))),
+ ),
+ "fan": (
+ numpy.array((1.0, 2.0, 0.0, 1.0, 0.0, 3.0)),
+ numpy.array(((1, 1), (2, 2), (0, 0), (1, 1), (0, 0), (3, 3))),
+ ),
}
for mode, refresults in results.items():
@@ -221,35 +223,49 @@ class TestUnindexArrays(ParametricTestCase):
# negative indices
with self.assertRaises(AssertionError):
- utils.unindexArrays('points', (-1, 0), *arrays)
+ utils.unindexArrays("points", (-1, 0), *arrays)
# Too high indices
with self.assertRaises(AssertionError):
- utils.unindexArrays('points', (0, 10), *arrays)
+ utils.unindexArrays("points", (0, 10), *arrays)
# triangleNormals #############################################################
+
class TestTriangleNormals(ParametricTestCase):
"""Test triangleNormals function."""
def test(self):
"""Test for modes: points, lines and triangles"""
positions = numpy.array(
- ((0., 0., 0.), (1., 0., 0.), (0., 1., 0.), # normal = Z
- (1., 1., 1.), (1., 2., 3.), (4., 5., 6.), # Random triangle
- # Degenerated triangles:
- (0., 0., 0.), (1., 0., 0.), (2., 0., 0.), # Colinear points
- (1., 1., 1.), (1., 1., 1.), (1., 1., 1.), # All same point
- ),
- dtype='float32')
+ (
+ (0.0, 0.0, 0.0),
+ (1.0, 0.0, 0.0),
+ (0.0, 1.0, 0.0), # normal = Z
+ (1.0, 1.0, 1.0),
+ (1.0, 2.0, 3.0),
+ (4.0, 5.0, 6.0), # Random triangle
+ # Degenerated triangles:
+ (0.0, 0.0, 0.0),
+ (1.0, 0.0, 0.0),
+ (2.0, 0.0, 0.0), # Colinear points
+ (1.0, 1.0, 1.0),
+ (1.0, 1.0, 1.0),
+ (1.0, 1.0, 1.0), # All same point
+ ),
+ dtype="float32",
+ )
normals = numpy.array(
- ((0., 0., 1.),
- (-0.40824829, 0.81649658, -0.40824829),
- (0., 0., 0.),
- (0., 0., 0.)),
- dtype='float32')
+ (
+ (0.0, 0.0, 1.0),
+ (-0.40824829, 0.81649658, -0.40824829),
+ (0.0, 0.0, 0.0),
+ (0.0, 0.0, 0.0),
+ ),
+ dtype="float32",
+ )
testnormals = utils.trianglesNormal(positions)
self.assertTrue(numpy.allclose(testnormals, normals))
diff --git a/src/silx/gui/plot3d/scene/text.py b/src/silx/gui/plot3d/scene/text.py
index 3c4e692..79cdb13 100644
--- a/src/silx/gui/plot3d/scene/text.py
+++ b/src/silx/gui/plot3d/scene/text.py
@@ -33,7 +33,7 @@ import numpy
from silx.gui.colors import rgba
-from ... import _glutils
+from ... import _glutils, qt
from ..._glutils import gl
from ..._glutils import font as _font
@@ -62,24 +62,18 @@ class Font(event.Notifier):
super(Font, self).__init__()
name = event.notifyProperty(
- '_name',
- doc="""Name of the font (str)""",
- converter=str)
+ "_name", doc="""Name of the font (str)""", converter=str
+ )
size = event.notifyProperty(
- '_size',
- doc="""Font size in points (int)""",
- converter=int)
+ "_size", doc="""Font size in points (int)""", converter=int
+ )
- weight = event.notifyProperty(
- '_weight',
- doc="""Font size in points (int)""",
- converter=int)
+ weight = event.notifyProperty("_weight", doc="""Font weight (int)""", converter=int)
italic = event.notifyProperty(
- '_italic',
- doc="""True for italic (bool)""",
- converter=bool)
+ "_italic", doc="""True for italic (bool)""", converter=bool
+ )
class Text2D(primitives.Geometry):
@@ -90,14 +84,14 @@ class Text2D(primitives.Geometry):
"""
# Text anchor values
- CENTER = 'center'
+ CENTER = "center"
- LEFT = 'left'
- RIGHT = 'right'
+ LEFT = "left"
+ RIGHT = "right"
- TOP = 'top'
- BASELINE = 'baseline'
- BOTTOM = 'bottom'
+ TOP = "top"
+ BASELINE = "baseline"
+ BOTTOM = "bottom"
_ALIGN = LEFT, CENTER, RIGHT
_VALIGN = TOP, BASELINE, CENTER, BOTTOM
@@ -106,30 +100,31 @@ class Text2D(primitives.Geometry):
"""Internal cache storing already rasterized text"""
# TODO limit cache size and discard least recent used
- def __init__(self, text='', font=None):
+ def __init__(self, text="", font=None):
self._dirtyTexture = True
self._dirtyAlign = True
self._baselineOffset = 0
self._text = text
self._font = font if font is not None else Font()
- self._foreground = 1., 1., 1., 1.
- self._background = 0., 0., 0., 0.
+ self._foreground = 1.0, 1.0, 1.0, 1.0
+ self._background = 0.0, 0.0, 0.0, 0.0
self._overlay = False
- self._align = 'left'
- self._valign = 'baseline'
- self._devicePixelRatio = 1.0 # Store it to check for changes
+ self._align = "left"
+ self._valign = "baseline"
+ self._dotsPerInch = 96.0 # Store it to check for changes
self._texture = None
self._textureDirty = True
super(Text2D, self).__init__(
- 'triangle_strip',
+ "triangle_strip",
copy=False,
# Keep an array for position as it is bound to attr 0 and MUST
# be active and an array at least on Mac OS X
position=numpy.zeros((4, 3), dtype=numpy.float32),
- vertexID=numpy.arange(4., dtype=numpy.float32).reshape(4, 1),
- offsetInViewportCoords=(0., 0.))
+ vertexID=numpy.arange(4.0, dtype=numpy.float32).reshape(4, 1),
+ offsetInViewportCoords=(0.0, 0.0),
+ )
@property
def text(self):
@@ -162,18 +157,22 @@ class Text2D(primitives.Geometry):
self.notify()
foreground = event.notifyProperty(
- '_foreground', doc="""RGBA color of the text: 4 float in [0, 1]""",
- converter=rgba)
+ "_foreground",
+ doc="""RGBA color of the text: 4 float in [0, 1]""",
+ converter=rgba,
+ )
background = event.notifyProperty(
- '_background',
+ "_background",
doc="RGBA background color of the text field: 4 float in [0, 1]",
- converter=rgba)
+ converter=rgba,
+ )
overlay = event.notifyProperty(
- '_overlay',
+ "_overlay",
doc="True to always display text on top of the scene (default: False)",
- converter=bool)
+ converter=bool,
+ )
def _setAlign(self, align):
assert align in self._ALIGN
@@ -186,7 +185,8 @@ class Text2D(primitives.Geometry):
_setAlign,
doc="""Horizontal anchor position of the text field (str).
- Either 'left' (default), 'center' or 'right'.""")
+ Either 'left' (default), 'center' or 'right'.""",
+ )
def _setVAlign(self, valign):
assert valign in self._VALIGN
@@ -199,37 +199,45 @@ class Text2D(primitives.Geometry):
_setVAlign,
doc="""Vertical anchor position of the text field (str).
- Either 'top', 'baseline' (default), 'center' or 'bottom'""")
+ Either 'top', 'baseline' (default), 'center' or 'bottom'""",
+ )
- def _raster(self, devicePixelRatio):
+ def _raster(self, dotsPerInch: float):
"""Raster current primitive to a bitmap
- :param float devicePixelRatio:
- The ratio between device and device-independent pixels
+ :param dotsPerInch: Screen resolution in pixels per inch
:return: Corresponding image in grayscale and baseline offset from top
:rtype: (HxW numpy.ndarray of uint8, int)
"""
- params = (self.text,
- self.font.name,
- self.font.size,
- self.font.weight,
- self.font.italic,
- devicePixelRatio)
-
- if params not in self._rasterTextCache: # Add to cache
- self._rasterTextCache[params] = _font.rasterText(*params)
-
- array, offset = self._rasterTextCache[params]
+ key = (
+ self.text,
+ self.font.name,
+ self.font.size,
+ self.font.weight,
+ self.font.italic,
+ dotsPerInch,
+ )
+
+ if key not in self._rasterTextCache: # Add to cache
+ font = qt.QFont(
+ self.font.name,
+ self.font.size,
+ self.font.weight,
+ self.font.italic,
+ )
+ self._rasterTextCache[key] = _font.rasterText(self.text, font, dotsPerInch)
+
+ array, offset = self._rasterTextCache[key]
return array.copy(), offset
def _bounds(self, dataBounds=False):
return None
def prepareGL2(self, context):
- # Check if devicePixelRatio has changed since last rendering
- devicePixelRatio = context.glCtx.devicePixelRatio
- if self._devicePixelRatio != devicePixelRatio:
- self._devicePixelRatio = devicePixelRatio
+ # Check if dotsPerInch has changed since last rendering
+ dotsPerInch = context.glCtx.dotsPerInch
+ if self._dotsPerInch != dotsPerInch:
+ self._dotsPerInch = dotsPerInch
self._dirtyTexture = True
if self._dirtyTexture:
@@ -241,13 +249,15 @@ class Text2D(primitives.Geometry):
self._baselineOffset = 0
if self.text:
- image, self._baselineOffset = self._raster(
- self._devicePixelRatio)
+ image, self._baselineOffset = self._raster(dotsPerInch)
self._texture = _glutils.Texture(
- gl.GL_R8, image, gl.GL_RED,
+ gl.GL_R8,
+ image,
+ gl.GL_RED,
minFilter=gl.GL_NEAREST,
magFilter=gl.GL_NEAREST,
- wrap=gl.GL_CLAMP_TO_EDGE)
+ wrap=gl.GL_CLAMP_TO_EDGE,
+ )
self._texture.prepare()
self._dirtyAlign = True # To force update of offset
@@ -257,32 +267,33 @@ class Text2D(primitives.Geometry):
if self._texture is not None:
height, width = self._texture.shape
- if self._align == 'left':
- ox = 0.
- elif self._align == 'center':
- ox = - width // 2
- elif self._align == 'right':
- ox = - width
+ if self._align == "left":
+ ox = 0.0
+ elif self._align == "center":
+ ox = -width // 2
+ elif self._align == "right":
+ ox = -width
else:
_logger.error("Unsupported align: %s", self._align)
- ox = 0.
+ ox = 0.0
- if self._valign == 'top':
- oy = 0.
- elif self._valign == 'baseline':
+ if self._valign == "top":
+ oy = 0.0
+ elif self._valign == "baseline":
oy = self._baselineOffset
- elif self._valign == 'center':
+ elif self._valign == "center":
oy = height // 2
- elif self._valign == 'bottom':
+ elif self._valign == "bottom":
oy = height
else:
_logger.error("Unsupported valign: %s", self._valign)
- oy = 0.
+ oy = 0.0
offsets = (ox, oy) + numpy.array(
- ((0., 0.), (width, 0.), (0., -height), (width, -height)),
- dtype=numpy.float32)
- self.setAttribute('offsetInViewportCoords', offsets)
+ ((0.0, 0.0), (width, 0.0), (0.0, -height), (width, -height)),
+ dtype=numpy.float32,
+ )
+ self.setAttribute("offsetInViewportCoords", offsets)
super(Text2D, self).prepareGL2(context)
@@ -293,14 +304,12 @@ class Text2D(primitives.Geometry):
program = context.glCtx.prog(*self._shaders)
program.use()
- program.setUniformMatrix('matrix', context.objectToNDC.matrix)
- gl.glUniform2f(
- program.uniforms['viewportSize'], *context.viewport.size)
- gl.glUniform4f(program.uniforms['foreground'], *self.foreground)
- gl.glUniform4f(program.uniforms['background'], *self.background)
- gl.glUniform1i(program.uniforms['texture'], self._texture.texUnit)
- gl.glUniform1i(program.uniforms['isOverlay'],
- 1 if self._overlay else 0)
+ program.setUniformMatrix("matrix", context.objectToNDC.matrix)
+ gl.glUniform2f(program.uniforms["viewportSize"], *context.viewport.size)
+ gl.glUniform4f(program.uniforms["foreground"], *self.foreground)
+ gl.glUniform4f(program.uniforms["background"], *self.background)
+ gl.glUniform1i(program.uniforms["texture"], self._texture.texUnit)
+ gl.glUniform1i(program.uniforms["isOverlay"], 1 if self._overlay else 0)
self._texture.bind()
@@ -351,7 +360,6 @@ class Text2D(primitives.Geometry):
vertexID < 1.5 ? 0.0 : 1.0);
}
""", # noqa
-
"""
varying vec2 texCoords;
@@ -373,12 +381,12 @@ class Text2D(primitives.Geometry):
}
}
}
- """)
+ """,
+ )
class LabelledAxes(primitives.GroupBBox):
- """A group displaying a bounding box with axes labels around its children.
- """
+ """A group displaying a bounding box with axes labels around its children."""
def __init__(self):
super(LabelledAxes, self).__init__()
@@ -389,26 +397,23 @@ class LabelledAxes(primitives.GroupBBox):
# TODO offset labels from anchor in pixels
self._xlabel = Text2D(font=self._font)
- self._xlabel.align = 'center'
- self._xlabel.transforms = [self._boxTransforms,
- transform.Translate(tx=0.5)]
+ self._xlabel.align = "center"
+ self._xlabel.transforms = [self._boxTransforms, transform.Translate(tx=0.5)]
self._children.append(self._xlabel)
self._ylabel = Text2D(font=self._font)
- self._ylabel.align = 'center'
- self._ylabel.transforms = [self._boxTransforms,
- transform.Translate(ty=0.5)]
+ self._ylabel.align = "center"
+ self._ylabel.transforms = [self._boxTransforms, transform.Translate(ty=0.5)]
self._children.append(self._ylabel)
self._zlabel = Text2D(font=self._font)
- self._zlabel.align = 'center'
- self._zlabel.transforms = [self._boxTransforms,
- transform.Translate(tz=0.5)]
+ self._zlabel.align = "center"
+ self._zlabel.transforms = [self._boxTransforms, transform.Translate(tz=0.5)]
self._children.append(self._zlabel)
self._tickLines = primitives.Lines( # Init tick lines with dummy pos
- positions=((0., 0., 0.), (0., 0., 0.)),
- mode='lines')
+ positions=((0.0, 0.0, 0.0), (0.0, 0.0, 0.0)), mode="lines"
+ )
self._tickLines.visible = False
self._children.append(self._tickLines)
@@ -465,13 +470,14 @@ class LabelledAxes(primitives.GroupBBox):
self._tickLines.visible = False
self._tickLabels.children = [] # Reset previous labels
- elif (self._ticksForBounds is None or
- not numpy.all(numpy.equal(bounds, self._ticksForBounds))):
+ elif self._ticksForBounds is None or not numpy.all(
+ numpy.equal(bounds, self._ticksForBounds)
+ ):
self._ticksForBounds = bounds
# Update ticks
# TODO make ticks having a constant length on the screen
- ticklength = numpy.abs(bounds[1] - bounds[0]) / 20.
+ ticklength = numpy.abs(bounds[1] - bounds[0]) / 20.0
xticks, xlabels = ticklayout.ticks(*bounds[:, 0])
yticks, ylabels = ticklayout.ticks(*bounds[:, 1])
@@ -479,26 +485,26 @@ class LabelledAxes(primitives.GroupBBox):
# Update tick lines
coords = numpy.empty(
- ((len(xticks) + len(yticks) + len(zticks)), 4, 3),
- dtype=numpy.float32)
+ ((len(xticks) + len(yticks) + len(zticks)), 4, 3), dtype=numpy.float32
+ )
coords[:, :, :] = bounds[0, :] # account for offset from origin
- xcoords = coords[:len(xticks)]
+ xcoords = coords[: len(xticks)]
xcoords[:, :, 0] = numpy.asarray(xticks)[:, numpy.newaxis]
xcoords[:, 1, 1] += ticklength[1] # X ticks on XY plane
xcoords[:, 3, 2] += ticklength[2] # X ticks on XZ plane
- ycoords = coords[len(xticks):len(xticks) + len(yticks)]
+ ycoords = coords[len(xticks) : len(xticks) + len(yticks)]
ycoords[:, :, 1] = numpy.asarray(yticks)[:, numpy.newaxis]
ycoords[:, 1, 0] += ticklength[0] # Y ticks on XY plane
ycoords[:, 3, 2] += ticklength[2] # Y ticks on YZ plane
- zcoords = coords[len(xticks) + len(yticks):]
+ zcoords = coords[len(xticks) + len(yticks) :]
zcoords[:, :, 2] = numpy.asarray(zticks)[:, numpy.newaxis]
zcoords[:, 1, 0] += ticklength[0] # Z ticks on XZ plane
zcoords[:, 3, 1] += ticklength[1] # Z ticks on YZ plane
- self._tickLines.setAttribute('position', coords.reshape(-1, 3))
+ self._tickLines.setAttribute("position", coords.reshape(-1, 3))
self._tickLines.visible = True
# Update labels
@@ -506,23 +512,26 @@ class LabelledAxes(primitives.GroupBBox):
labels = []
for tick, label in zip(xticks, xlabels):
text = Text2D(text=label, font=self.font)
- text.align = 'center'
- text.transforms = [transform.Translate(
- tx=tick, ty=offsets[1], tz=offsets[2])]
+ text.align = "center"
+ text.transforms = [
+ transform.Translate(tx=tick, ty=offsets[1], tz=offsets[2])
+ ]
labels.append(text)
for tick, label in zip(yticks, ylabels):
text = Text2D(text=label, font=self.font)
- text.align = 'center'
- text.transforms = [transform.Translate(
- tx=offsets[0], ty=tick, tz=offsets[2])]
+ text.align = "center"
+ text.transforms = [
+ transform.Translate(tx=offsets[0], ty=tick, tz=offsets[2])
+ ]
labels.append(text)
for tick, label in zip(zticks, zlabels):
text = Text2D(text=label, font=self.font)
- text.align = 'center'
- text.transforms = [transform.Translate(
- tx=offsets[0], ty=offsets[1], tz=tick)]
+ text.align = "center"
+ text.transforms = [
+ transform.Translate(tx=offsets[0], ty=offsets[1], tz=tick)
+ ]
labels.append(text)
self._tickLabels.children = labels # Reset previous labels
diff --git a/src/silx/gui/plot3d/scene/transform.py b/src/silx/gui/plot3d/scene/transform.py
index 5c2cbb3..20e2453 100644
--- a/src/silx/gui/plot3d/scene/transform.py
+++ b/src/silx/gui/plot3d/scene/transform.py
@@ -38,6 +38,7 @@ from . import event
# Projections
+
def mat4LookAtDir(position, direction, up):
"""Creates matrix to look in direction from position.
@@ -54,24 +55,22 @@ def mat4LookAtDir(position, direction, up):
direction = numpy.array(direction, copy=True, dtype=numpy.float32)
dirnorm = numpy.linalg.norm(direction)
- assert dirnorm != 0.
+ assert dirnorm != 0.0
direction /= dirnorm
- side = numpy.cross(direction,
- numpy.array(up, copy=False, dtype=numpy.float32))
+ side = numpy.cross(direction, numpy.array(up, copy=False, dtype=numpy.float32))
sidenorm = numpy.linalg.norm(side)
- assert sidenorm != 0.
+ assert sidenorm != 0.0
up = numpy.cross(side / sidenorm, direction)
upnorm = numpy.linalg.norm(up)
- assert upnorm != 0.
+ assert upnorm != 0.0
up /= upnorm
matrix = numpy.identity(4, dtype=numpy.float32)
matrix[0, :3] = side
matrix[1, :3] = up
matrix[2, :3] = -direction
- return numpy.dot(matrix,
- mat4Translate(-position[0], -position[1], -position[2]))
+ return numpy.dot(matrix, mat4Translate(-position[0], -position[1], -position[2]))
def mat4LookAt(position, center, up):
@@ -97,11 +96,15 @@ def mat4Frustum(left, right, bottom, top, near, far):
See glFrustum.
"""
- return numpy.array((
- (2.*near / (right-left), 0., (right+left) / (right-left), 0.),
- (0., 2.*near / (top-bottom), (top+bottom) / (top-bottom), 0.),
- (0., 0., -(far+near) / (far-near), -2.*far*near / (far-near)),
- (0., 0., -1., 0.)), dtype=numpy.float32)
+ return numpy.array(
+ (
+ (2.0 * near / (right - left), 0.0, (right + left) / (right - left), 0.0),
+ (0.0, 2.0 * near / (top - bottom), (top + bottom) / (top - bottom), 0.0),
+ (0.0, 0.0, -(far + near) / (far - near), -2.0 * far * near / (far - near)),
+ (0.0, 0.0, -1.0, 0.0),
+ ),
+ dtype=numpy.float32,
+ )
def mat4Perspective(fovy, width, height, near, far):
@@ -120,15 +123,19 @@ def mat4Perspective(fovy, width, height, near, far):
assert fovy != 0
assert height != 0
assert width != 0
- assert near > 0.
+ assert near > 0.0
assert far > near
aspectratio = width / height
- f = 1. / numpy.tan(numpy.radians(fovy) / 2.)
- return numpy.array((
- (f / aspectratio, 0., 0., 0.),
- (0., f, 0., 0.),
- (0., 0., (far + near) / (near - far), 2. * far * near / (near - far)),
- (0., 0., -1., 0.)), dtype=numpy.float32)
+ f = 1.0 / numpy.tan(numpy.radians(fovy) / 2.0)
+ return numpy.array(
+ (
+ (f / aspectratio, 0.0, 0.0, 0.0),
+ (0.0, f, 0.0, 0.0),
+ (0.0, 0.0, (far + near) / (near - far), 2.0 * far * near / (near - far)),
+ (0.0, 0.0, -1.0, 0.0),
+ ),
+ dtype=numpy.float32,
+ )
def mat4Orthographic(left, right, bottom, top, near, far):
@@ -136,34 +143,47 @@ def mat4Orthographic(left, right, bottom, top, near, far):
See glOrtho.
"""
- return numpy.array((
- (2. / (right - left), 0., 0., - (right + left) / (right - left)),
- (0., 2. / (top - bottom), 0., - (top + bottom) / (top - bottom)),
- (0., 0., -2. / (far - near), - (far + near) / (far - near)),
- (0., 0., 0., 1.)), dtype=numpy.float32)
+ return numpy.array(
+ (
+ (2.0 / (right - left), 0.0, 0.0, -(right + left) / (right - left)),
+ (0.0, 2.0 / (top - bottom), 0.0, -(top + bottom) / (top - bottom)),
+ (0.0, 0.0, -2.0 / (far - near), -(far + near) / (far - near)),
+ (0.0, 0.0, 0.0, 1.0),
+ ),
+ dtype=numpy.float32,
+ )
# Affine
+
def mat4Translate(tx, ty, tz):
"""4x4 translation matrix."""
- return numpy.array((
- (1., 0., 0., tx),
- (0., 1., 0., ty),
- (0., 0., 1., tz),
- (0., 0., 0., 1.)), dtype=numpy.float32)
+ return numpy.array(
+ (
+ (1.0, 0.0, 0.0, tx),
+ (0.0, 1.0, 0.0, ty),
+ (0.0, 0.0, 1.0, tz),
+ (0.0, 0.0, 0.0, 1.0),
+ ),
+ dtype=numpy.float32,
+ )
def mat4Scale(sx, sy, sz):
"""4x4 scale matrix."""
- return numpy.array((
- (sx, 0., 0., 0.),
- (0., sy, 0., 0.),
- (0., 0., sz, 0.),
- (0., 0., 0., 1.)), dtype=numpy.float32)
-
-
-def mat4RotateFromAngleAxis(angle, x=0., y=0., z=1.):
+ return numpy.array(
+ (
+ (sx, 0.0, 0.0, 0.0),
+ (0.0, sy, 0.0, 0.0),
+ (0.0, 0.0, sz, 0.0),
+ (0.0, 0.0, 0.0, 1.0),
+ ),
+ dtype=numpy.float32,
+ )
+
+
+def mat4RotateFromAngleAxis(angle, x=0.0, y=0.0, z=1.0):
"""4x4 rotation matrix from angle and axis.
:param float angle: The rotation angle in radians.
@@ -173,11 +193,30 @@ def mat4RotateFromAngleAxis(angle, x=0., y=0., z=1.):
"""
ca = numpy.cos(angle)
sa = numpy.sin(angle)
- return numpy.array((
- ((1.-ca) * x*x + ca, (1.-ca) * x*y - sa*z, (1.-ca) * x*z + sa*y, 0.),
- ((1.-ca) * x*y + sa*z, (1.-ca) * y*y + ca, (1.-ca) * y*z - sa*x, 0.),
- ((1.-ca) * x*z - sa*y, (1.-ca) * y*z + sa*x, (1.-ca) * z*z + ca, 0.),
- (0., 0., 0., 1.)), dtype=numpy.float32)
+ return numpy.array(
+ (
+ (
+ (1.0 - ca) * x * x + ca,
+ (1.0 - ca) * x * y - sa * z,
+ (1.0 - ca) * x * z + sa * y,
+ 0.0,
+ ),
+ (
+ (1.0 - ca) * x * y + sa * z,
+ (1.0 - ca) * y * y + ca,
+ (1.0 - ca) * y * z - sa * x,
+ 0.0,
+ ),
+ (
+ (1.0 - ca) * x * z - sa * y,
+ (1.0 - ca) * y * z + sa * x,
+ (1.0 - ca) * z * z + ca,
+ 0.0,
+ ),
+ (0.0, 0.0, 0.0, 1.0),
+ ),
+ dtype=numpy.float32,
+ )
def mat4RotateFromQuaternion(quaternion):
@@ -189,14 +228,33 @@ def mat4RotateFromQuaternion(quaternion):
quaternion /= numpy.linalg.norm(quaternion)
qx, qy, qz, qw = quaternion
- return numpy.array((
- (1. - 2.*(qy**2 + qz**2), 2.*(qx*qy - qw*qz), 2.*(qx*qz + qw*qy), 0.),
- (2.*(qx*qy + qw*qz), 1. - 2.*(qx**2 + qz**2), 2.*(qy*qz - qw*qx), 0.),
- (2.*(qx*qz - qw*qy), 2.*(qy*qz + qw*qx), 1. - 2.*(qx**2 + qy**2), 0.),
- (0., 0., 0., 1.)), dtype=numpy.float32)
-
-
-def mat4Shear(axis, sx=0., sy=0., sz=0.):
+ return numpy.array(
+ (
+ (
+ 1.0 - 2.0 * (qy**2 + qz**2),
+ 2.0 * (qx * qy - qw * qz),
+ 2.0 * (qx * qz + qw * qy),
+ 0.0,
+ ),
+ (
+ 2.0 * (qx * qy + qw * qz),
+ 1.0 - 2.0 * (qx**2 + qz**2),
+ 2.0 * (qy * qz - qw * qx),
+ 0.0,
+ ),
+ (
+ 2.0 * (qx * qz - qw * qy),
+ 2.0 * (qy * qz + qw * qx),
+ 1.0 - 2.0 * (qx**2 + qy**2),
+ 0.0,
+ ),
+ (0.0, 0.0, 0.0, 1.0),
+ ),
+ dtype=numpy.float32,
+ )
+
+
+def mat4Shear(axis, sx=0.0, sy=0.0, sz=0.0):
"""4x4 shear matrix: Skew two axes relative to a third fixed one.
shearFactor = tan(shearAngle)
@@ -207,22 +265,22 @@ def mat4Shear(axis, sx=0., sy=0., sz=0.):
:param float sy: The shear factor for the Y axis relative to axis.
:param float sz: The shear factor for the Z axis relative to axis.
"""
- assert axis in ('x', 'y', 'z')
+ assert axis in ("x", "y", "z")
matrix = numpy.identity(4, dtype=numpy.float32)
# Make the shear column
- index = 'xyz'.find(axis)
- shearcolumn = numpy.array((sx, sy, sz, 0.), dtype=numpy.float32)
- shearcolumn[index] = 1.
+ index = "xyz".find(axis)
+ shearcolumn = numpy.array((sx, sy, sz, 0.0), dtype=numpy.float32)
+ shearcolumn[index] = 1.0
matrix[:, index] = shearcolumn
return matrix
# Transforms ##################################################################
-class Transform(event.Notifier):
+class Transform(event.Notifier):
def __init__(self, static=False):
"""Base class for (row-major) 4x4 matrix transforms.
@@ -236,8 +294,7 @@ class Transform(event.Notifier):
self.addListener(self._changed) # Listening self for changes
def __repr__(self):
- return '%s(%s)' % (self.__class__.__init__,
- repr(self.getMatrix(copy=False)))
+ return "%s(%s)" % (self.__class__.__init__, repr(self.getMatrix(copy=False)))
def inverse(self):
"""Return the Transform of the inverse.
@@ -290,8 +347,8 @@ class Transform(event.Notifier):
return self._inverse
inverseMatrix = property(
- getInverseMatrix,
- doc="The 4x4 matrix of the inverse of this transform.")
+ getInverseMatrix, doc="The 4x4 matrix of the inverse of this transform."
+ )
# Listener
@@ -328,14 +385,13 @@ class Transform(event.Notifier):
if dimension == 3: # Add 4th coordinate
points = numpy.append(
- points,
- numpy.ones((1, points.shape[1]), dtype=points.dtype),
- axis=0)
+ points, numpy.ones((1, points.shape[1]), dtype=points.dtype), axis=0
+ )
result = numpy.transpose(numpy.dot(matrix, points))
if perspectiveDivide:
- mask = result[:, 3] != 0.
+ mask = result[:, 3] != 0.0
result[mask] /= result[mask, 3][:, numpy.newaxis]
return result[:, :3] if dimension == 3 else result
@@ -364,9 +420,9 @@ class Transform(event.Notifier):
matrix = self.getMatrix(copy=False)
else:
matrix = self.getInverseMatrix(copy=False)
- result = numpy.dot(matrix, self._prepareVector(point, 1.))
+ result = numpy.dot(matrix, self._prepareVector(point, 1.0))
- if perspectiveDivide and result[3] != 0.:
+ if perspectiveDivide and result[3] != 0.0:
result /= result[3]
if len(point) == 3:
@@ -404,8 +460,9 @@ class Transform(event.Notifier):
matrix = self.getMatrix(copy=False).T
return numpy.dot(matrix[:3, :3], normal[:3])
- _CUBE_CORNERS = numpy.array(list(itertools.product((0., 1.), repeat=3)),
- dtype=numpy.float32)
+ _CUBE_CORNERS = numpy.array(
+ list(itertools.product((0.0, 1.0), repeat=3)), dtype=numpy.float32
+ )
"""Unit cube corners used by :meth:`transformBounds`"""
def transformBounds(self, bounds, direct=True):
@@ -419,8 +476,7 @@ class Transform(event.Notifier):
:rtype: 2x3 numpy.ndarray of float32
"""
corners = numpy.ones((8, 4), dtype=numpy.float32)
- corners[:, :3] = bounds[0] + \
- self._CUBE_CORNERS * (bounds[1] - bounds[0])
+ corners[:, :3] = bounds[0] + self._CUBE_CORNERS * (bounds[1] - bounds[0])
if direct:
matrix = self.getMatrix(copy=False)
@@ -502,8 +558,8 @@ class StaticTransformList(Transform):
# Affine ######################################################################
-class Matrix(Transform):
+class Matrix(Transform):
def __init__(self, matrix=None):
"""4x4 Matrix.
@@ -528,16 +584,17 @@ class Matrix(Transform):
self.notify()
# Redefined here to add a setter
- matrix = property(Transform.getMatrix, setMatrix,
- doc="The 4x4 matrix of this transform.")
+ matrix = property(
+ Transform.getMatrix, setMatrix, doc="The 4x4 matrix of this transform."
+ )
class Translate(Transform):
"""4x4 translation matrix."""
- def __init__(self, tx=0., ty=0., tz=0.):
+ def __init__(self, tx=0.0, ty=0.0, tz=0.0):
super(Translate, self).__init__()
- self._tx, self._ty, self._tz = 0., 0., 0.
+ self._tx, self._ty, self._tz = 0.0, 0.0, 0.0
self.setTranslate(tx, ty, tz)
def _makeMatrix(self):
@@ -592,16 +649,16 @@ class Translate(Transform):
class Scale(Transform):
"""4x4 scale matrix."""
- def __init__(self, sx=1., sy=1., sz=1.):
+ def __init__(self, sx=1.0, sy=1.0, sz=1.0):
super(Scale, self).__init__()
- self._sx, self._sy, self._sz = 0., 0., 0.
+ self._sx, self._sy, self._sz = 0.0, 0.0, 0.0
self.setScale(sx, sy, sz)
def _makeMatrix(self):
return mat4Scale(self.sx, self.sy, self.sz)
def _makeInverse(self):
- return mat4Scale(1. / self.sx, 1. / self.sy, 1. / self.sz)
+ return mat4Scale(1.0 / self.sx, 1.0 / self.sy, 1.0 / self.sz)
@property
def sx(self):
@@ -638,20 +695,19 @@ class Scale(Transform):
def setScale(self, sx=None, sy=None, sz=None):
if sx is not None:
- assert sx != 0.
+ assert sx != 0.0
self._sx = sx
if sy is not None:
- assert sy != 0.
+ assert sy != 0.0
self._sy = sy
if sz is not None:
- assert sz != 0.
+ assert sz != 0.0
self._sz = sz
self.notify()
class Rotate(Transform):
-
- def __init__(self, angle=0., ax=0., ay=0., az=1.):
+ def __init__(self, angle=0.0, ax=0.0, ay=0.0, az=1.0):
"""4x4 rotation matrix.
:param float angle: The rotation angle in degrees.
@@ -660,7 +716,7 @@ class Rotate(Transform):
:param float az: The z coordinate of the rotation axis.
"""
super(Rotate, self).__init__()
- self._angle = 0.
+ self._angle = 0.0
self._axis = None
self.setAngleAxis(angle, (ax, ay, az))
@@ -695,9 +751,9 @@ class Rotate(Transform):
axis = numpy.array(axis, copy=True, dtype=numpy.float32)
assert axis.size == 3
norm = numpy.linalg.norm(axis)
- if norm == 0.: # No axis, set rotation angle to 0.
- self._angle = 0.
- self._axis = numpy.array((0., 0., 1.), dtype=numpy.float32)
+ if norm == 0.0: # No axis, set rotation angle to 0.
+ self._angle = 0.0
+ self._axis = numpy.array((0.0, 0.0, 1.0), dtype=numpy.float32)
else:
self._axis = axis / norm
@@ -710,8 +766,8 @@ class Rotate(Transform):
Where: ||(x, y, z)|| = sin(angle/2), w = cos(angle/2).
"""
- if numpy.linalg.norm(self._axis) == 0.:
- return numpy.array((0., 0., 0., 1.), dtype=numpy.float32)
+ if numpy.linalg.norm(self._axis) == 0.0:
+ return numpy.array((0.0, 0.0, 0.0, 1.0), dtype=numpy.float32)
else:
quaternion = numpy.empty((4,), dtype=numpy.float32)
@@ -731,7 +787,7 @@ class Rotate(Transform):
# Get angle
sinhalfangle = numpy.linalg.norm(quaternion[0:3])
coshalfangle = quaternion[3]
- angle = 2. * numpy.arctan2(sinhalfangle, coshalfangle)
+ angle = 2.0 * numpy.arctan2(sinhalfangle, coshalfangle)
# Axis will be normalized in setAngleAxis
self.setAngleAxis(numpy.degrees(angle), quaternion[0:3])
@@ -741,14 +797,16 @@ class Rotate(Transform):
return mat4RotateFromAngleAxis(angle, *self.axis)
def _makeInverse(self):
- return numpy.array(self.getMatrix(copy=False).transpose(),
- copy=True, order='C',
- dtype=numpy.float32)
+ return numpy.array(
+ self.getMatrix(copy=False).transpose(),
+ copy=True,
+ order="C",
+ dtype=numpy.float32,
+ )
class Shear(Transform):
-
- def __init__(self, axis, sx=0., sy=0., sz=0.):
+ def __init__(self, axis, sx=0.0, sy=0.0, sz=0.0):
"""4x4 shear/skew matrix of 2 axes relative to the third one.
:param str axis: The axis to keep fixed, in 'x', 'y', 'z'
@@ -756,7 +814,7 @@ class Shear(Transform):
:param float sy: The shear factor for the y axis.
:param float sz: The shear factor for the z axis.
"""
- assert axis in ('x', 'y', 'z')
+ assert axis in ("x", "y", "z")
super(Shear, self).__init__()
self._axis = axis
self._factors = sx, sy, sz
@@ -781,6 +839,7 @@ class Shear(Transform):
# Projection ##################################################################
+
class _Projection(Transform):
"""Base class for projection matrix.
@@ -795,12 +854,12 @@ class _Projection(Transform):
:type size: 2-tuple of float
"""
- def __init__(self, near, far, checkDepthExtent=False, size=(1., 1.)):
+ def __init__(self, near, far, checkDepthExtent=False, size=(1.0, 1.0)):
super(_Projection, self).__init__()
self._checkDepthExtent = checkDepthExtent
self._depthExtent = 1, 10
self.setDepthExtent(near, far) # set _depthExtent
- self._size = 1., 1.
+ self._size = 1.0, 1.0
self.size = size # set _size
def setDepthExtent(self, near=None, far=None):
@@ -813,7 +872,7 @@ class _Projection(Transform):
far = float(far) if far is not None else self._depthExtent[1]
if self._checkDepthExtent:
- assert near > 0.
+ assert near > 0.0
assert far > near
self._depthExtent = near, far
@@ -874,18 +933,27 @@ class Orthographic(_Projection):
True (default) to keep aspect ratio, False otherwise.
"""
- def __init__(self, left=0., right=1., bottom=1., top=0., near=-1., far=1.,
- size=(1., 1.), keepaspect=True):
+ def __init__(
+ self,
+ left=0.0,
+ right=1.0,
+ bottom=1.0,
+ top=0.0,
+ near=-1.0,
+ far=1.0,
+ size=(1.0, 1.0),
+ keepaspect=True,
+ ):
self._left, self._right = left, right
self._bottom, self._top = bottom, top
self._keepaspect = bool(keepaspect)
- super(Orthographic, self).__init__(near, far, checkDepthExtent=False,
- size=size)
+ super(Orthographic, self).__init__(near, far, checkDepthExtent=False, size=size)
# _update called when setting size
def _makeMatrix(self):
return mat4Orthographic(
- self.left, self.right, self.bottom, self.top, self.near, self.far)
+ self.left, self.right, self.bottom, self.top, self.near, self.far
+ )
def _update(self, left, right, bottom, top):
if self.keepaspect:
@@ -895,14 +963,12 @@ class Orthographic(_Projection):
orthoaspect = abs(left - right) / abs(bottom - top)
if orthoaspect >= aspect: # Keep width, enlarge height
- newheight = \
- numpy.sign(top - bottom) * abs(left - right) / aspect
+ newheight = numpy.sign(top - bottom) * abs(left - right) / aspect
bottom = 0.5 * (bottom + top) - 0.5 * newheight
top = bottom + newheight
else: # Keep height, enlarge width
- newwidth = \
- numpy.sign(right - left) * abs(bottom - top) * aspect
+ newwidth = numpy.sign(right - left) * abs(bottom - top) * aspect
left = 0.5 * (left + right) - 0.5 * newwidth
right = left + newwidth
@@ -929,17 +995,15 @@ class Orthographic(_Projection):
self._update(left, right, bottom, top)
self.notify()
- left = property(lambda self: self._left,
- doc="Coord of the left clipping plane.")
+ left = property(lambda self: self._left, doc="Coord of the left clipping plane.")
- right = property(lambda self: self._right,
- doc="Coord of the right clipping plane.")
+ right = property(lambda self: self._right, doc="Coord of the right clipping plane.")
- bottom = property(lambda self: self._bottom,
- doc="Coord of the bottom clipping plane.")
+ bottom = property(
+ lambda self: self._bottom, doc="Coord of the bottom clipping plane."
+ )
- top = property(lambda self: self._top,
- doc="Coord of the top clipping plane.")
+ top = property(lambda self: self._top, doc="Coord of the top clipping plane.")
@property
def size(self):
@@ -982,13 +1046,12 @@ class Ortho2DWidget(_Projection):
:type size: 2-tuple of float
"""
- def __init__(self, near=-1., far=1., size=(1., 1.)):
-
+ def __init__(self, near=-1.0, far=1.0, size=(1.0, 1.0)):
super(Ortho2DWidget, self).__init__(near, far, size)
def _makeMatrix(self):
width, height = self.size
- return mat4Orthographic(0., width, height, 0., self.near, self.far)
+ return mat4Orthographic(0.0, width, height, 0.0, self.near, self.far)
class Perspective(_Projection):
@@ -1002,10 +1065,9 @@ class Perspective(_Projection):
:type size: 2-tuple of float
"""
- def __init__(self, fovy=90., near=0.1, far=1., size=(1., 1.)):
-
+ def __init__(self, fovy=90.0, near=0.1, far=1.0, size=(1.0, 1.0)):
super(Perspective, self).__init__(near, far, checkDepthExtent=True)
- self._fovy = 90.
+ self._fovy = 90.0
self.fovy = fovy # Set _fovy
self.size = size # Set _ size
diff --git a/src/silx/gui/plot3d/scene/utils.py b/src/silx/gui/plot3d/scene/utils.py
index 48fc2f5..c856f15 100644
--- a/src/silx/gui/plot3d/scene/utils.py
+++ b/src/silx/gui/plot3d/scene/utils.py
@@ -42,6 +42,7 @@ _logger = logging.getLogger(__name__)
# numpy #######################################################################
+
def _uniqueAlongLastAxis(a):
"""Numpy unique on the last axis of a 2D array
@@ -57,12 +58,12 @@ def _uniqueAlongLastAxis(a):
assert len(a.shape) == 2
# Construct a type over last array dimension to run unique on a 1D array
- if a.dtype.char in numpy.typecodes['AllInteger']:
+ if a.dtype.char in numpy.typecodes["AllInteger"]:
# Bit-wise comparison of the 2 indices of a line at once
# Expect a C contiguous array of shape N, 2
uniquedt = numpy.dtype((numpy.void, a.itemsize * a.shape[-1]))
- elif a.dtype.char in numpy.typecodes['Float']:
- uniquedt = [('f{i}'.format(i=i), a.dtype) for i in range(a.shape[-1])]
+ elif a.dtype.char in numpy.typecodes["Float"]:
+ uniquedt = [("f{i}".format(i=i), a.dtype) for i in range(a.shape[-1])]
else:
raise TypeError("Unsupported type {dtype}".format(dtype=a.dtype))
@@ -72,6 +73,7 @@ def _uniqueAlongLastAxis(a):
# conversions #################################################################
+
def triangleToLineIndices(triangleIndices, unicity=False):
"""Generates lines indices from triangle indices.
@@ -88,8 +90,7 @@ def triangleToLineIndices(triangleIndices, unicity=False):
triangleIndices = triangleIndices.reshape(-1, 3)
# Pack line indices by triangle and by edge
- lineindices = numpy.empty((len(triangleIndices), 3, 2),
- dtype=triangleIndices.dtype)
+ lineindices = numpy.empty((len(triangleIndices), 3, 2), dtype=triangleIndices.dtype)
lineindices[:, 0] = triangleIndices[:, :2] # edge = t0, t1
lineindices[:, 1] = triangleIndices[:, 1:] # edge =t1, t2
lineindices[:, 2] = triangleIndices[:, ::2] # edge = t0, t2
@@ -103,7 +104,7 @@ def triangleToLineIndices(triangleIndices, unicity=False):
return lineindices
-def verticesNormalsToLines(vertices, normals, scale=1.):
+def verticesNormalsToLines(vertices, normals, scale=1.0):
"""Return vertices of lines representing normals at given positions.
:param vertices: Positions of the points.
@@ -137,13 +138,19 @@ def unindexArrays(mode, indices, *arrays):
"""
indices = numpy.array(indices, copy=False)
- assert mode in ('points',
- 'lines', 'line_strip', 'loop',
- 'triangles', 'triangle_strip', 'fan')
-
- if mode in ('lines', 'line_strip', 'loop'):
+ assert mode in (
+ "points",
+ "lines",
+ "line_strip",
+ "loop",
+ "triangles",
+ "triangle_strip",
+ "fan",
+ )
+
+ if mode in ("lines", "line_strip", "loop"):
assert len(indices) >= 2
- elif mode in ('triangles', 'triangle_strip', 'fan'):
+ elif mode in ("triangles", "triangle_strip", "fan"):
assert len(indices) >= 3
assert indices.min() >= 0
@@ -151,27 +158,27 @@ def unindexArrays(mode, indices, *arrays):
for data in arrays:
assert len(data) >= max_index
- if mode == 'line_strip':
+ if mode == "line_strip":
unpacked = numpy.empty((2 * (len(indices) - 1),), dtype=indices.dtype)
unpacked[0::2] = indices[:-1]
unpacked[1::2] = indices[1:]
indices = unpacked
- elif mode == 'loop':
+ elif mode == "loop":
unpacked = numpy.empty((2 * len(indices),), dtype=indices.dtype)
unpacked[0::2] = indices
unpacked[1:-1:2] = indices[1:]
unpacked[-1] = indices[0]
indices = unpacked
- elif mode == 'triangle_strip':
+ elif mode == "triangle_strip":
unpacked = numpy.empty((3 * (len(indices) - 2),), dtype=indices.dtype)
unpacked[0::3] = indices[:-2]
unpacked[1::3] = indices[1:-1]
unpacked[2::3] = indices[2:]
indices = unpacked
- elif mode == 'fan':
+ elif mode == "fan":
unpacked = numpy.empty((3 * (len(indices) - 2),), dtype=indices.dtype)
unpacked[0::3] = indices[0]
unpacked[1::3] = indices[1:-1]
@@ -220,8 +227,9 @@ def trianglesNormal(positions):
positions = numpy.array(positions, copy=False).reshape(-1, 3, 3)
- normals = numpy.cross(positions[:, 1] - positions[:, 0],
- positions[:, 2] - positions[:, 0])
+ normals = numpy.cross(
+ positions[:, 1] - positions[:, 0], positions[:, 2] - positions[:, 0]
+ )
# Normalize normals
norms = numpy.linalg.norm(normals, axis=1)
@@ -232,6 +240,7 @@ def trianglesNormal(positions):
# grid ########################################################################
+
def gridVertices(dim0Array, dim1Array, dtype):
"""Generate an array of 2D positions from 2 arrays of 1D coordinates.
@@ -308,29 +317,28 @@ def linesGridIndices(dim0, dim1):
nbsegmentalongdim1 = 2 * (dim1 - 1)
nbsegmentalongdim0 = 2 * (dim0 - 1)
- indices = numpy.empty(nbsegmentalongdim1 * dim0 +
- nbsegmentalongdim0 * dim1,
- dtype=numpy.uint32)
+ indices = numpy.empty(
+ nbsegmentalongdim1 * dim0 + nbsegmentalongdim0 * dim1, dtype=numpy.uint32
+ )
# Line indices over dim0
- onedim1line = (numpy.arange(nbsegmentalongdim1,
- dtype=numpy.uint32) + 1) // 2
- indices[:dim0 * nbsegmentalongdim1] = \
- (dim1 * numpy.arange(dim0, dtype=numpy.uint32)[:, None] +
- onedim1line[None, :]).ravel()
+ onedim1line = (numpy.arange(nbsegmentalongdim1, dtype=numpy.uint32) + 1) // 2
+ indices[: dim0 * nbsegmentalongdim1] = (
+ dim1 * numpy.arange(dim0, dtype=numpy.uint32)[:, None] + onedim1line[None, :]
+ ).ravel()
# Line indices over dim1
- onedim0line = (numpy.arange(nbsegmentalongdim0,
- dtype=numpy.uint32) + 1) // 2
- indices[dim0 * nbsegmentalongdim1:] = \
- (numpy.arange(dim1, dtype=numpy.uint32)[:, None] +
- dim1 * onedim0line[None, :]).ravel()
+ onedim0line = (numpy.arange(nbsegmentalongdim0, dtype=numpy.uint32) + 1) // 2
+ indices[dim0 * nbsegmentalongdim1 :] = (
+ numpy.arange(dim1, dtype=numpy.uint32)[:, None] + dim1 * onedim0line[None, :]
+ ).ravel()
return indices
# intersection ################################################################
+
def angleBetweenVectors(refVector, vectors, norm=None):
"""Return the angle between 2 vectors.
@@ -357,10 +365,10 @@ def angleBetweenVectors(refVector, vectors, norm=None):
vectors = numpy.array([v / numpy.linalg.norm(v) for v in vectors])
dots = numpy.sum(refVector * vectors, axis=-1)
- angles = numpy.arccos(numpy.clip(dots, -1., 1.))
+ angles = numpy.arccos(numpy.clip(dots, -1.0, 1.0))
if norm is not None:
- signs = numpy.sum(norm * numpy.cross(refVector, vectors), axis=-1) < 0.
- angles[signs] = numpy.pi * 2. - angles[signs]
+ signs = numpy.sum(norm * numpy.cross(refVector, vectors), axis=-1) < 0.0
+ angles[signs] = numpy.pi * 2.0 - angles[signs]
return angles[0] if singlevector else angles
@@ -391,8 +399,8 @@ def segmentPlaneIntersect(s0, s1, planeNorm, planePt):
else: # No intersection
return []
- alpha = - numpy.dot(planeNorm, s0 - planePt) / dotnormseg
- if 0. <= alpha <= 1.: # Intersection with segment
+ alpha = -numpy.dot(planeNorm, s0 - planePt) / dotnormseg
+ if 0.0 <= alpha <= 1.0: # Intersection with segment
return [s0 + alpha * segdir]
else: # intersection outside segment
return []
@@ -459,8 +467,9 @@ def clipSegmentToBounds(segment, bounds):
points.shape = -1, 3 # Set back to 2D array
# Find intersection points that are included in the volume
- mask = numpy.logical_and(numpy.all(bounds[0] <= points, axis=1),
- numpy.all(points <= bounds[1], axis=1))
+ mask = numpy.logical_and(
+ numpy.all(bounds[0] <= points, axis=1), numpy.all(points <= bounds[1], axis=1)
+ )
intersections = numpy.unique(offsets[mask])
if len(intersections) != 2:
return None
@@ -519,12 +528,12 @@ def segmentVolumeIntersect(segment, nbins):
# Get corresponding line parameters
t = []
if numpy.all(0 <= p0) and numpy.all(p0 <= nbins):
- t.append([0.]) # p0 within volume, add it
+ t.append([0.0]) # p0 within volume, add it
t += [(edgesByDim[i] - p0[i]) / delta[i] for i in range(dim) if delta[i] != 0]
if numpy.all(0 <= p1) and numpy.all(p1 <= nbins):
- t.append([1.]) # p1 within volume, add it
+ t.append([1.0]) # p1 within volume, add it
t = numpy.concatenate(t)
- t.sort(kind='mergesort')
+ t.sort(kind="mergesort")
# Remove duplicates
unique = numpy.ones((len(t),), dtype=bool)
@@ -536,13 +545,14 @@ def segmentVolumeIntersect(segment, nbins):
# bin edges/line intersection points
points = t.reshape(-1, 1) * delta + p0
- centers = (points[:-1] + points[1:]) / 2.
+ centers = (points[:-1] + points[1:]) / 2.0
bins = numpy.floor(centers).astype(numpy.int64)
return bins
# Plane #######################################################################
+
class Plane(event.Notifier):
"""Object handling a plane and notifying plane changes.
@@ -552,7 +562,7 @@ class Plane(event.Notifier):
:type normal: 3-tuple of float.
"""
- def __init__(self, point=(0., 0., 0.), normal=(0., 0., 1.)):
+ def __init__(self, point=(0.0, 0.0, 0.0), normal=(0.0, 0.0, 1.0)):
super(Plane, self).__init__()
assert len(point) == 3
@@ -583,7 +593,7 @@ class Plane(event.Notifier):
normal = numpy.array(normal, copy=True, dtype=numpy.float32)
norm = numpy.linalg.norm(normal)
- if norm != 0.:
+ if norm != 0.0:
normal /= norm
if not numpy.all(numpy.equal(self._normal, normal)):
@@ -591,8 +601,11 @@ class Plane(event.Notifier):
planechanged = True
if planechanged:
- _logger.debug('Plane updated:\n\tpoint: %s\n\tnormal: %s',
- str(self._point), str(self._normal))
+ _logger.debug(
+ "Plane updated:\n\tpoint: %s\n\tnormal: %s",
+ str(self._point),
+ str(self._normal),
+ )
self.notify()
@property
@@ -616,8 +629,7 @@ class Plane(event.Notifier):
@property
def parameters(self):
"""Plane equation parameters: a*x + b*y + c*z + d = 0."""
- return numpy.append(self._normal,
- - numpy.dot(self._point, self._normal))
+ return numpy.append(self._normal, -numpy.dot(self._point, self._normal))
@parameters.setter
def parameters(self, parameters):
@@ -630,13 +642,13 @@ class Plane(event.Notifier):
parameters /= norm
normal = parameters[:3]
- point = - parameters[3] * normal
+ point = -parameters[3] * normal
self.setPlane(point, normal)
@property
def isPlane(self):
"""True if a plane is defined (i.e., ||normal|| != 0)."""
- return numpy.any(self.normal != 0.)
+ return numpy.any(self.normal != 0.0)
def move(self, step):
"""Move the plane of step along the normal."""
diff --git a/src/silx/gui/plot3d/scene/viewport.py b/src/silx/gui/plot3d/scene/viewport.py
index bff77e2..c39d3ef 100644
--- a/src/silx/gui/plot3d/scene/viewport.py
+++ b/src/silx/gui/plot3d/scene/viewport.py
@@ -59,17 +59,19 @@ class RenderContext(object):
:param Context glContext: The operating system OpenGL context in use.
"""
- _FRAGMENT_SHADER_SRC = string.Template("""
+ _FRAGMENT_SHADER_SRC = string.Template(
+ """
void scene_post(vec4 cameraPosition) {
gl_FragColor = $fogCall(gl_FragColor, cameraPosition);
}
- """)
+ """
+ )
def __init__(self, viewport, glContext):
self._viewport = viewport
self._glContext = glContext
self._transformStack = [viewport.camera.extrinsic]
- self._clipPlane = ClippingPlane(normal=(0., 0., 0.))
+ self._clipPlane = ClippingPlane(normal=(0.0, 0.0, 0.0))
# cache
self.__cache = {}
@@ -118,8 +120,7 @@ class RenderContext(object):
Do not modify.
"""
- return transform.StaticTransformList(
- (self.projection, self.objectToCamera))
+ return transform.StaticTransformList((self.projection, self.objectToCamera))
def pushTransform(self, transform_, multiply=True):
"""Push a :class:`Transform` on the transform stack.
@@ -132,7 +133,8 @@ class RenderContext(object):
if multiply:
assert len(self._transformStack) >= 1
transform_ = transform.StaticTransformList(
- (self._transformStack[-1], transform_))
+ (self._transformStack[-1], transform_)
+ )
self._transformStack.append(transform_)
@@ -149,7 +151,7 @@ class RenderContext(object):
"""The current clipping plane (ClippingPlane)"""
return self._clipPlane
- def setClipPlane(self, point=(0., 0., 0.), normal=(0., 0., 0.)):
+ def setClipPlane(self, point=(0.0, 0.0, 0.0), normal=(0.0, 0.0, 0.0)):
"""Set the clipping plane to use
For now only handles a single clipping plane.
@@ -173,11 +175,15 @@ class RenderContext(object):
@property
def fragDecl(self):
"""Fragment shader declaration for scene shader functions"""
- return '\n'.join((
- self.clipper.fragDecl,
- self.viewport.fog.fragDecl,
- self._FRAGMENT_SHADER_SRC.substitute(
- fogCall=self.viewport.fog.fragCall)))
+ return "\n".join(
+ (
+ self.clipper.fragDecl,
+ self.viewport.fog.fragDecl,
+ self._FRAGMENT_SHADER_SRC.substitute(
+ fogCall=self.viewport.fog.fragCall
+ ),
+ )
+ )
@property
def fragCallPre(self):
@@ -204,6 +210,7 @@ class Viewport(event.Notifier):
def __init__(self, framebuffer=0):
from . import Group # Here to avoid cyclic import
+
super(Viewport, self).__init__()
self._dirty = True
self._origin = 0, 0
@@ -212,15 +219,16 @@ class Viewport(event.Notifier):
self.scene = Group() # The stuff to render, add overlaid scenes?
self.scene._setParent(self)
self.scene.addListener(self._changed)
- self._background = 0., 0., 0., 1.
- self._camera = camera.Camera(fovy=30., near=1., far=100.,
- position=(0., 0., 12.))
+ self._background = 0.0, 0.0, 0.0, 1.0
+ self._camera = camera.Camera(
+ fovy=30.0, near=1.0, far=100.0, position=(0.0, 0.0, 12.0)
+ )
self._camera.addListener(self._changed)
self._transforms = transform.TransformList([self._camera])
- self._light = DirectionalLight(direction=(0., 0., -1.),
- ambient=(0.3, 0.3, 0.3),
- diffuse=(0.7, 0.7, 0.7))
+ self._light = DirectionalLight(
+ direction=(0.0, 0.0, -1.0), ambient=(0.3, 0.3, 0.3), diffuse=(0.7, 0.7, 0.7)
+ )
self._light.addListener(self._changed)
self._fog = Fog()
self._fog.isOn = False
@@ -352,7 +360,7 @@ class Viewport(event.Notifier):
gl.glEnable(gl.GL_DEPTH_TEST)
gl.glDepthFunc(gl.GL_LEQUAL)
- gl.glDepthRange(0., 1.)
+ gl.glDepthRange(0.0, 1.0)
# gl.glEnable(gl.GL_POLYGON_OFFSET_FILL)
# gl.glPolygonOffset(1., 1.)
@@ -361,15 +369,16 @@ class Viewport(event.Notifier):
gl.glEnable(gl.GL_LINE_SMOOTH)
if self.background is None:
- gl.glClear(gl.GL_STENCIL_BUFFER_BIT |
- gl.GL_DEPTH_BUFFER_BIT)
+ gl.glClear(gl.GL_STENCIL_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT)
else:
gl.glClearColor(*self.background)
# Prepare OpenGL
- gl.glClear(gl.GL_COLOR_BUFFER_BIT |
- gl.GL_STENCIL_BUFFER_BIT |
- gl.GL_DEPTH_BUFFER_BIT)
+ gl.glClear(
+ gl.GL_COLOR_BUFFER_BIT
+ | gl.GL_STENCIL_BUFFER_BIT
+ | gl.GL_DEPTH_BUFFER_BIT
+ )
ctx = RenderContext(self, glContext)
self.scene.render(ctx)
@@ -384,15 +393,16 @@ class Viewport(event.Notifier):
"""
bounds = self.scene.bounds(transformed=True)
if bounds is None:
- bounds = numpy.array(((0., 0., 0.), (1., 1., 1.)),
- dtype=numpy.float32)
+ bounds = numpy.array(
+ ((0.0, 0.0, 0.0), (1.0, 1.0, 1.0)), dtype=numpy.float32
+ )
bounds = self.camera.extrinsic.transformBounds(bounds)
if isinstance(self.camera.intrinsic, transform.Perspective):
# This needs to be reworked
- zbounds = - bounds[:, 2]
+ zbounds = -bounds[:, 2]
zextent = max(numpy.fabs(zbounds[0] - zbounds[1]), 0.0001)
- near = max(zextent / 1000., 0.95 * zbounds[1])
+ near = max(zextent / 1000.0, 0.95 * zbounds[1])
far = max(near + 0.1, 1.05 * zbounds[0])
self.camera.intrinsic.setDepthExtent(near, far)
@@ -401,7 +411,7 @@ class Viewport(event.Notifier):
border = max(abs(bounds[:, 2]))
self.camera.intrinsic.setDepthExtent(-border, border)
else:
- raise RuntimeError('Unsupported camera', self.camera.intrinsic)
+ raise RuntimeError("Unsupported camera", self.camera.intrinsic)
def resetCamera(self):
"""Change camera to have the whole scene in the viewing frustum.
@@ -411,11 +421,12 @@ class Viewport(event.Notifier):
"""
bounds = self.scene.bounds(transformed=True)
if bounds is None:
- bounds = numpy.array(((0., 0., 0.), (1., 1., 1.)),
- dtype=numpy.float32)
+ bounds = numpy.array(
+ ((0.0, 0.0, 0.0), (1.0, 1.0, 1.0)), dtype=numpy.float32
+ )
self.camera.resetCamera(bounds)
- def orbitCamera(self, direction, angle=1.):
+ def orbitCamera(self, direction, angle=1.0):
"""Rotate the camera around center of the scene.
:param str direction: Direction of movement relative to image plane.
@@ -424,8 +435,9 @@ class Viewport(event.Notifier):
"""
bounds = self.scene.bounds(transformed=True)
if bounds is None:
- bounds = numpy.array(((0., 0., 0.), (1., 1., 1.)),
- dtype=numpy.float32)
+ bounds = numpy.array(
+ ((0.0, 0.0, 0.0), (1.0, 1.0, 1.0)), dtype=numpy.float32
+ )
center = 0.5 * (bounds[0] + bounds[1])
self.camera.orbit(direction, center, angle)
@@ -439,35 +451,36 @@ class Viewport(event.Notifier):
"""
bounds = self.scene.bounds(transformed=True)
if bounds is None:
- bounds = numpy.array(((0., 0., 0.), (1., 1., 1.)),
- dtype=numpy.float32)
+ bounds = numpy.array(
+ ((0.0, 0.0, 0.0), (1.0, 1.0, 1.0)), dtype=numpy.float32
+ )
bounds = self.camera.extrinsic.transformBounds(bounds)
center = 0.5 * (bounds[0] + bounds[1])
- ndcCenter = self.camera.intrinsic.transformPoint(
- center, perspectiveDivide=True)
+ ndcCenter = self.camera.intrinsic.transformPoint(center, perspectiveDivide=True)
- step *= 2. # NDC has size 2
+ step *= 2.0 # NDC has size 2
- if direction == 'up':
+ if direction == "up":
ndcCenter[1] -= step
- elif direction == 'down':
+ elif direction == "down":
ndcCenter[1] += step
- elif direction == 'right':
+ elif direction == "right":
ndcCenter[0] -= step
- elif direction == 'left':
+ elif direction == "left":
ndcCenter[0] += step
- elif direction == 'forward':
+ elif direction == "forward":
ndcCenter[2] += step
- elif direction == 'backward':
+ elif direction == "backward":
ndcCenter[2] -= step
else:
- raise ValueError('Unsupported direction: %s' % direction)
+ raise ValueError("Unsupported direction: %s" % direction)
newCenter = self.camera.intrinsic.transformPoint(
- ndcCenter, direct=False, perspectiveDivide=True)
+ ndcCenter, direct=False, perspectiveDivide=True
+ )
self.camera.move(direction, numpy.linalg.norm(newCenter - center))
@@ -495,11 +508,11 @@ class Viewport(event.Notifier):
x, y = winX - ox, winY - oy
- if checkInside and (x < 0. or x > width or y < 0. or y > height):
+ if checkInside and (x < 0.0 or x > width or y < 0.0 or y > height):
return None # Out of viewport
- ndcx = 2. * x / float(width) - 1.
- ndcy = 1. - 2. * y / float(height)
+ ndcx = 2.0 * x / float(width) - 1.0
+ ndcy = 1.0 - 2.0 * y / float(height)
return ndcx, ndcy
def ndcToWindow(self, ndcX, ndcY, checkInside=True):
@@ -512,15 +525,14 @@ class Viewport(event.Notifier):
:return: (x, y) window coordinates or None.
Origin top-left, x to the right, y goes downward.
"""
- if (checkInside and
- (ndcX < -1. or ndcX > 1. or ndcY < -1. or ndcY > 1.)):
+ if checkInside and (ndcX < -1.0 or ndcX > 1.0 or ndcY < -1.0 or ndcY > 1.0):
return None # Outside viewport
ox, oy = self._origin
width, height = self.size
- winx = ox + width * 0.5 * (ndcX + 1.)
- winy = oy + height * 0.5 * (1. - ndcY)
+ winx = ox + width * 0.5 * (ndcX + 1.0)
+ winy = oy + height * 0.5 * (1.0 - ndcY)
return winx, winy
def _pickNdcZGL(self, x, y, offset=0):
@@ -550,20 +562,19 @@ class Viewport(event.Notifier):
if offset == 0: # Fast path
# glReadPixels is not GL|ES friendly
- depth = gl.glReadPixels(
- x, y, 1, 1, gl.GL_DEPTH_COMPONENT, gl.GL_FLOAT)[0]
+ depthPatch = gl.glReadPixels(x, y, 1, 1, gl.GL_DEPTH_COMPONENT, gl.GL_FLOAT)
+ depth = numpy.ravel(depthPatch)[0]
else:
offset = abs(int(offset))
- size = 2*offset + 1
+ size = 2 * offset + 1
depthPatch = gl.glReadPixels(
- x - offset, y - offset,
- size, size,
- gl.GL_DEPTH_COMPONENT, gl.GL_FLOAT)
+ x - offset, y - offset, size, size, gl.GL_DEPTH_COMPONENT, gl.GL_FLOAT
+ )
depthPatch = depthPatch.ravel() # Work in 1D
# TODO cache sortedIndices to avoid computing it each time
# Compute distance of each pixels to the center of the patch
- offsetToCenter = numpy.arange(- offset, offset + 1, dtype=numpy.float32) ** 2
+ offsetToCenter = numpy.arange(-offset, offset + 1, dtype=numpy.float32) ** 2
sqDistToCenter = numpy.add.outer(offsetToCenter, offsetToCenter)
# Use distance to center to sort values from the patch
@@ -571,26 +582,26 @@ class Viewport(event.Notifier):
sortedValues = depthPatch[sortedIndices]
# Take first depth that is not 1 in the sorted values
- hits = sortedValues[sortedValues != 1.]
- depth = 1. if len(hits) == 0 else hits[0]
+ hits = sortedValues[sortedValues != 1.0]
+ depth = 1.0 if len(hits) == 0 else hits[0]
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0)
# Z in NDC in [-1., 1.]
- return float(depth) * 2. - 1.
+ return float(depth) * 2.0 - 1.0
def _getXZYGL(self, x, y):
ndc = self.windowToNdc(x, y)
if ndc is None:
return None # Outside viewport
ndcz = self._pickNdcZGL(x, y)
- ndcpos = numpy.array((ndc[0], ndc[1], ndcz, 1.), dtype=numpy.float32)
+ ndcpos = numpy.array((ndc[0], ndc[1], ndcz, 1.0), dtype=numpy.float32)
camerapos = self.camera.intrinsic.transformPoint(
- ndcpos, direct=False, perspectiveDivide=True)
+ ndcpos, direct=False, perspectiveDivide=True
+ )
- scenepos = self.camera.extrinsic.transformPoint(camerapos,
- direct=False)
+ scenepos = self.camera.extrinsic.transformPoint(camerapos, direct=False)
return scenepos[:3]
def pick(self, x, y):
diff --git a/src/silx/gui/plot3d/scene/window.py b/src/silx/gui/plot3d/scene/window.py
index c8f4cee..2a6d93b 100644
--- a/src/silx/gui/plot3d/scene/window.py
+++ b/src/silx/gui/plot3d/scene/window.py
@@ -58,6 +58,7 @@ class Context(object):
self._context = glContextHandle
self._isCurrent = False
self._devicePixelRatio = 1.0
+ self._dotsPerInch = 96.0
@property
def isCurrent(self):
@@ -75,6 +76,16 @@ class Context(object):
self._isCurrent = bool(isCurrent)
@property
+ def dotsPerInch(self) -> float:
+ """Number of physical dots per inch on the screen"""
+ return self._dotsPerInch
+
+ @dotsPerInch.setter
+ def dotsPerInch(self, dpi: float):
+ assert dpi > 0.0
+ self._dotsPerInch = float(dpi)
+
+ @property
def devicePixelRatio(self):
"""Ratio between device and device independent pixels (float)
@@ -112,6 +123,7 @@ class ContextGL2(Context):
:param glContextHandle: System specific OpenGL context handle.
"""
+
def __init__(self, glContextHandle):
super(ContextGL2, self).__init__(glContextHandle)
@@ -121,7 +133,7 @@ class ContextGL2(Context):
# programs
- def prog(self, vertexShaderSrc, fragmentShaderSrc, attrib0='position'):
+ def prog(self, vertexShaderSrc, fragmentShaderSrc, attrib0="position"):
"""Cache program within context.
WARNING: No clean-up.
@@ -138,14 +150,14 @@ class ContextGL2(Context):
program = self._programs.get(key, None)
if program is None:
program = _glutils.Program(
- vertexShaderSrc, fragmentShaderSrc, attrib0=attrib0)
+ vertexShaderSrc, fragmentShaderSrc, attrib0=attrib0
+ )
self._programs[key] = program
return program
# VBOs
- def makeVbo(self, data=None, sizeInBytes=None,
- usage=None, target=None):
+ def makeVbo(self, data=None, sizeInBytes=None, usage=None, target=None):
"""Create a VBO in this context with the data.
Current limitations:
@@ -193,7 +205,8 @@ class ContextGL2(Context):
size=data.shape[0],
dimension=dimension,
offset=0,
- stride=0)
+ stride=0,
+ )
def _deadVbo(self, vboRef):
"""Callback handling dead VBOAttribs."""
@@ -228,13 +241,18 @@ class Window(event.Notifier):
update the texture only when needed.
"""
- _position = numpy.array(((-1., -1., 0., 0.),
- (1., -1., 1., 0.),
- (-1., 1., 0., 1.),
- (1., 1., 1., 1.)),
- dtype=numpy.float32)
-
- _shaders = ("""
+ _position = numpy.array(
+ (
+ (-1.0, -1.0, 0.0, 0.0),
+ (1.0, -1.0, 1.0, 0.0),
+ (-1.0, 1.0, 0.0, 1.0),
+ (1.0, 1.0, 1.0, 1.0),
+ ),
+ dtype=numpy.float32,
+ )
+
+ _shaders = (
+ """
attribute vec4 position;
varying vec2 textureCoord;
@@ -243,7 +261,7 @@ class Window(event.Notifier):
textureCoord = position.zw;
}
""",
- """
+ """
uniform sampler2D texture;
varying vec2 textureCoord;
@@ -251,9 +269,10 @@ class Window(event.Notifier):
gl_FragColor = texture2D(texture, textureCoord);
gl_FragColor.a = 1.0;
}
- """)
+ """,
+ )
- def __init__(self, mode='framebuffer'):
+ def __init__(self, mode="framebuffer"):
super(Window, self).__init__()
self._dirty = True
self._size = 0, 0
@@ -263,8 +282,8 @@ class Window(event.Notifier):
self._framebufferid = 0
self._framebuffers = {} # Cache of framebuffers
- assert mode in ('direct', 'framebuffer')
- self._isframebuffer = mode == 'framebuffer'
+ assert mode in ("direct", "framebuffer")
+ self._isframebuffer = mode == "framebuffer"
@property
def dirty(self):
@@ -316,8 +335,9 @@ class Window(event.Notifier):
self._dirty = True
self.notify(*args, **kwargs)
- framebufferid = property(lambda self: self._framebufferid,
- doc="Framebuffer ID used to perform rendering")
+ framebufferid = property(
+ lambda self: self._framebufferid, doc="Framebuffer ID used to perform rendering"
+ )
def grab(self, glcontext):
"""Returns the raster of the scene as an RGB numpy array
@@ -332,21 +352,21 @@ class Window(event.Notifier):
previousFramebuffer = gl.glGetInteger(gl.GL_FRAMEBUFFER_BINDING)
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, self.framebufferid)
gl.glPixelStorei(gl.GL_PACK_ALIGNMENT, 1)
- gl.glReadPixels(
- 0, 0, width, height, gl.GL_RGB, gl.GL_UNSIGNED_BYTE, image)
+ gl.glReadPixels(0, 0, width, height, gl.GL_RGB, gl.GL_UNSIGNED_BYTE, image)
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, previousFramebuffer)
# glReadPixels gives bottom to top,
# while images are stored as top to bottom
image = numpy.flipud(image)
- return numpy.array(image, copy=False, order='C')
+ return numpy.array(image, copy=False, order="C")
- def render(self, glcontext, devicePixelRatio):
+ def render(self, glcontext, dotsPerInch: float, devicePixelRatio: float):
"""Perform the rendering of attached viewports
:param glcontext: System identifier of the OpenGL context
- :param float devicePixelRatio:
+ :param dotsPerInch: Screen physical resolution in pixels per inch
+ :param devicePixelRatio:
Ratio between device and device-independent pixels
"""
if self.size == (0, 0):
@@ -356,6 +376,7 @@ class Window(event.Notifier):
self._contexts[glcontext] = ContextGL2(glcontext) # New context
with self._contexts[glcontext] as context:
+ context.dotsPerInch = dotsPerInch
context.devicePixelRatio = devicePixelRatio
if self._isframebuffer:
self._renderWithOffscreenFramebuffer(context)
@@ -384,18 +405,22 @@ class Window(event.Notifier):
if self.dirty or context not in self._framebuffers:
# Need to redraw framebuffer content
- if (context not in self._framebuffers or
- self._framebuffers[context].shape != self.shape):
+ if (
+ context not in self._framebuffers
+ or self._framebuffers[context].shape != self.shape
+ ):
# Need to rebuild framebuffer
if context in self._framebuffers:
self._framebuffers[context].discard()
- fbo = _glutils.FramebufferTexture(gl.GL_RGBA,
- shape=self.shape,
- minFilter=gl.GL_NEAREST,
- magFilter=gl.GL_NEAREST,
- wrap=gl.GL_CLAMP_TO_EDGE)
+ fbo = _glutils.FramebufferTexture(
+ gl.GL_RGBA,
+ shape=self.shape,
+ minFilter=gl.GL_NEAREST,
+ magFilter=gl.GL_NEAREST,
+ wrap=gl.GL_CLAMP_TO_EDGE,
+ )
self._framebuffers[context] = fbo
self._framebufferid = fbo.name
@@ -415,16 +440,18 @@ class Window(event.Notifier):
gl.glDisable(gl.GL_DEPTH_TEST)
gl.glDisable(gl.GL_SCISSOR_TEST)
# gl.glScissor(0, 0, width, height)
- gl.glClearColor(0., 0., 0., 0.)
+ gl.glClearColor(0.0, 0.0, 0.0, 0.0)
gl.glClear(gl.GL_COLOR_BUFFER_BIT)
- gl.glUniform1i(program.uniforms['texture'], fbo.texture.texUnit)
- gl.glEnableVertexAttribArray(program.attributes['position'])
- gl.glVertexAttribPointer(program.attributes['position'],
- 4,
- gl.GL_FLOAT,
- gl.GL_FALSE,
- 0,
- self._position)
+ gl.glUniform1i(program.uniforms["texture"], fbo.texture.texUnit)
+ gl.glEnableVertexAttribArray(program.attributes["position"])
+ gl.glVertexAttribPointer(
+ program.attributes["position"],
+ 4,
+ gl.GL_FLOAT,
+ gl.GL_FALSE,
+ 0,
+ self._position,
+ )
fbo.texture.bind()
gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, len(self._position))
gl.glBindTexture(gl.GL_TEXTURE_2D, 0)
diff --git a/src/silx/gui/plot3d/test/testGL.py b/src/silx/gui/plot3d/test/testGL.py
index d1d53ef..a2627eb 100644
--- a/src/silx/gui/plot3d/test/testGL.py
+++ b/src/silx/gui/plot3d/test/testGL.py
@@ -53,14 +53,18 @@ class TestOpenGL(TestCaseQt):
"""Perform the rendering and logging"""
if not self._dump:
self._dump = True
- _logger.info('OpenGL info:')
- _logger.info('\tQt OpenGL context version: %d.%d', *self.getOpenGLVersion())
- _logger.info('\tGL_VERSION: %s' % gl.glGetString(gl.GL_VERSION))
- _logger.info('\tGL_SHADING_LANGUAGE_VERSION: %s' %
- gl.glGetString(gl.GL_SHADING_LANGUAGE_VERSION))
- _logger.debug('\tGL_EXTENSIONS: %s' % gl.glGetString(gl.GL_EXTENSIONS))
+ _logger.info("OpenGL info:")
+ _logger.info(
+ "\tQt OpenGL context version: %d.%d", *self.getOpenGLVersion()
+ )
+ _logger.info("\tGL_VERSION: %s" % gl.glGetString(gl.GL_VERSION))
+ _logger.info(
+ "\tGL_SHADING_LANGUAGE_VERSION: %s"
+ % gl.glGetString(gl.GL_SHADING_LANGUAGE_VERSION)
+ )
+ _logger.debug("\tGL_EXTENSIONS: %s" % gl.glGetString(gl.GL_EXTENSIONS))
- gl.glClearColor(1., 1., 1., 1.)
+ gl.glClearColor(1.0, 1.0, 1.0, 1.0)
gl.glClear(gl.GL_COLOR_BUFFER_BIT)
def testOpenGL(self):
diff --git a/src/silx/gui/plot3d/test/testScalarFieldView.py b/src/silx/gui/plot3d/test/testScalarFieldView.py
index 1e06e3f..f81b985 100644
--- a/src/silx/gui/plot3d/test/testScalarFieldView.py
+++ b/src/silx/gui/plot3d/test/testScalarFieldView.py
@@ -83,8 +83,8 @@ class TestScalarFieldView(TestCaseQt, ParametricTestCase):
data = self._buildData(size=32)
self.widget.setData(data)
- self.widget.addIsosurface(0.5, (1., 0., 0., 0.5))
- self.widget.addIsosurface(0.7, qt.QColor('green'))
+ self.widget.addIsosurface(0.5, (1.0, 0.0, 0.0, 0.5))
+ self.widget.addIsosurface(0.7, qt.QColor("green"))
self.qapp.processEvents()
def testNotFinite(self):
@@ -94,9 +94,9 @@ class TestScalarFieldView(TestCaseQt, ParametricTestCase):
data = self._buildData(size=32)
data[8, :, :] = numpy.nan
data[16, :, :] = numpy.inf
- data[24, :, :] = - numpy.inf
+ data[24, :, :] = -numpy.inf
- self.widget.addIsosurface(0.5, 'red')
+ self.widget.addIsosurface(0.5, "red")
self.widget.setData(data, copy=True)
self.qapp.processEvents()
self.widget.setData(None)
@@ -114,13 +114,13 @@ class TestScalarFieldView(TestCaseQt, ParametricTestCase):
data = self._buildData(size=32)
self.widget.setData(data)
- self.widget.addIsosurface(0.5, (1., 0., 0., 0.5))
- self.widget.addIsosurface(0.7, qt.QColor('green'))
+ self.widget.addIsosurface(0.5, (1.0, 0.0, 0.0, 0.5))
+ self.widget.addIsosurface(0.7, qt.QColor("green"))
self.qapp.processEvents()
# Add a second TreeView
paramTreeWidget = TreeView(self.widget)
- paramTreeWidget.setIsoLevelSliderNormalization('arcsinh')
+ paramTreeWidget.setIsoLevelSliderNormalization("arcsinh")
paramTreeWidget.setSfView(self.widget)
dock = qt.QDockWidget()
diff --git a/src/silx/gui/plot3d/test/testSceneWidget.py b/src/silx/gui/plot3d/test/testSceneWidget.py
index e7f3b3f..cb3767c 100644
--- a/src/silx/gui/plot3d/test/testSceneWidget.py
+++ b/src/silx/gui/plot3d/test/testSceneWidget.py
@@ -62,7 +62,7 @@ class TestSceneWidget(TestCaseQt, ParametricTestCase):
scatter.setTranslation(10, 10)
scatter.setScale(10, 10, 10)
- self.widget.resetZoom('front')
+ self.widget.resetZoom("front")
self.qapp.processEvents()
self.widget.setFogMode(self.widget.FogMode.LINEAR)
diff --git a/src/silx/gui/plot3d/test/testSceneWidgetPicking.py b/src/silx/gui/plot3d/test/testSceneWidgetPicking.py
index c0ad3b0..1c32899 100644
--- a/src/silx/gui/plot3d/test/testSceneWidgetPicking.py
+++ b/src/silx/gui/plot3d/test/testSceneWidgetPicking.py
@@ -67,15 +67,14 @@ class TestSceneWidgetPicking(TestCaseQt, ParametricTestCase):
imageData.setData(numpy.arange(100).reshape(10, 10))
imageRgba = items.ImageRgba()
- imageRgba.setData(
- numpy.arange(300, dtype=numpy.uint8).reshape(10, 10, 3))
+ imageRgba.setData(numpy.arange(300, dtype=numpy.uint8).reshape(10, 10, 3))
for item in (imageData, imageRgba):
with self.subTest(item=item.__class__.__name__):
# Add item
self.widget.clearItems()
self.widget.addItem(item)
- self.widget.resetZoom('front')
+ self.widget.resetZoom("front")
self.qapp.processEvents()
# Picking on data (at widget center)
@@ -83,12 +82,12 @@ class TestSceneWidgetPicking(TestCaseQt, ParametricTestCase):
self.assertEqual(len(picking), 1)
self.assertIs(picking[0].getItem(), item)
- self.assertEqual(picking[0].getPositions('ndc').shape, (1, 3))
+ self.assertEqual(picking[0].getPositions("ndc").shape, (1, 3))
data = picking[0].getData()
self.assertEqual(len(data), 1)
- self.assertTrue(numpy.array_equal(
- data,
- item.getData()[picking[0].getIndices()]))
+ self.assertTrue(
+ numpy.array_equal(data, item.getData()[picking[0].getIndices()])
+ )
# Picking outside data
picking = list(self.widget.pickItems(1, 1))
@@ -109,7 +108,7 @@ class TestSceneWidgetPicking(TestCaseQt, ParametricTestCase):
# Add item
self.widget.clearItems()
self.widget.addItem(item)
- self.widget.resetZoom('front')
+ self.widget.resetZoom("front")
self.qapp.processEvents()
# Picking on data (at widget center)
@@ -117,12 +116,14 @@ class TestSceneWidgetPicking(TestCaseQt, ParametricTestCase):
self.assertEqual(len(picking), 1)
self.assertIs(picking[0].getItem(), item)
- nbPos = len(picking[0].getPositions('ndc'))
+ nbPos = len(picking[0].getPositions("ndc"))
data = picking[0].getData()
self.assertEqual(nbPos, len(data))
- self.assertTrue(numpy.array_equal(
- data,
- item.getValueData()[picking[0].getIndices()]))
+ self.assertTrue(
+ numpy.array_equal(
+ data, item.getValueData()[picking[0].getIndices()]
+ )
+ )
# Picking outside data
picking = list(self.widget.pickItems(1, 1))
@@ -137,7 +138,7 @@ class TestSceneWidgetPicking(TestCaseQt, ParametricTestCase):
if dtype == numpy.complex64:
volume.setComplexMode(volume.ComplexMode.REAL)
refData = numpy.real(refData)
- self.widget.resetZoom('front')
+ self.widget.resetZoom("front")
cutplane = volume.getCutPlanes()[0]
if dtype == numpy.complex64:
@@ -159,13 +160,12 @@ class TestSceneWidgetPicking(TestCaseQt, ParametricTestCase):
data = picking[0].getData()
self.assertEqual(len(data), 1)
self.assertEqual(picking[0].getPositions().shape, (1, 3))
- self.assertTrue(numpy.array_equal(
- data,
- refData[picking[0].getIndices()]))
+ self.assertTrue(
+ numpy.array_equal(data, refData[picking[0].getIndices()])
+ )
# Picking on data with an isosurface
- isosurface = volume.addIsosurface(
- level=500, color=(1., 0., 0., .5))
+ isosurface = volume.addIsosurface(level=500, color=(1.0, 0.0, 0.0, 0.5))
picking = list(self.widget.pickItems(*self._widgetCenter()))
self.assertEqual(len(picking), 2)
self.assertIs(picking[0].getItem(), cutplane)
@@ -173,9 +173,9 @@ class TestSceneWidgetPicking(TestCaseQt, ParametricTestCase):
self.assertEqual(picking[1].getPositions().shape, (1, 3))
data = picking[1].getData()
self.assertEqual(len(data), 1)
- self.assertTrue(numpy.array_equal(
- data,
- refData[picking[1].getIndices()]))
+ self.assertTrue(
+ numpy.array_equal(data, refData[picking[1].getIndices()])
+ )
# Picking outside data
picking = list(self.widget.pickItems(1, 1))
@@ -188,27 +188,29 @@ class TestSceneWidgetPicking(TestCaseQt, ParametricTestCase):
triangles = items.Mesh()
triangles.setData(
- position=((0, 0, 0), (1, 0, 0), (1, 1, 0),
- (0, 0, 0), (1, 1, 0), (0, 1, 0)),
+ position=((0, 0, 0), (1, 0, 0), (1, 1, 0), (0, 0, 0), (1, 1, 0), (0, 1, 0)),
color=(1, 0, 0, 1),
- mode='triangles')
+ mode="triangles",
+ )
triangleStrip = items.Mesh()
triangleStrip.setData(
position=(((1, 0, 0), (0, 0, 0), (1, 1, 0), (0, 1, 0))),
color=(0, 1, 0, 1),
- mode='triangle_strip')
+ mode="triangle_strip",
+ )
triangleFan = items.Mesh()
triangleFan.setData(
position=((0, 0, 0), (1, 0, 0), (1, 1, 0), (0, 1, 0)),
color=(0, 0, 1, 1),
- mode='fan')
+ mode="fan",
+ )
for item in (triangles, triangleStrip, triangleFan):
with self.subTest(mode=item.getDrawMode()):
# Add item
self.widget.clearItems()
self.widget.addItem(item)
- self.widget.resetZoom('front')
+ self.widget.resetZoom("front")
self.qapp.processEvents()
# Picking on data (at widget center)
@@ -219,9 +221,11 @@ class TestSceneWidgetPicking(TestCaseQt, ParametricTestCase):
nbPos = len(picking[0].getPositions())
data = picking[0].getData()
self.assertEqual(nbPos, len(data))
- self.assertTrue(numpy.array_equal(
- data,
- item.getPositionData()[picking[0].getIndices()]))
+ self.assertTrue(
+ numpy.array_equal(
+ data, item.getPositionData()[picking[0].getIndices()]
+ )
+ )
# Picking outside data
picking = list(self.widget.pickItems(1, 1))
@@ -235,29 +239,35 @@ class TestSceneWidgetPicking(TestCaseQt, ParametricTestCase):
position=((0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0)),
color=(1, 0, 0, 1),
indices=numpy.array( # dummy triangles and square
- (0, 0, 1, 0, 1, 2, 1, 2, 3), dtype=numpy.uint8),
- mode='triangles')
+ (0, 0, 1, 0, 1, 2, 1, 2, 3), dtype=numpy.uint8
+ ),
+ mode="triangles",
+ )
triangleStrip = items.Mesh()
triangleStrip.setData(
position=((0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0)),
color=(0, 1, 0, 1),
indices=numpy.array( # dummy triangles and square
- (1, 0, 0, 1, 2, 3), dtype=numpy.uint8),
- mode='triangle_strip')
+ (1, 0, 0, 1, 2, 3), dtype=numpy.uint8
+ ),
+ mode="triangle_strip",
+ )
triangleFan = items.Mesh()
triangleFan.setData(
position=((0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0)),
color=(0, 0, 1, 1),
indices=numpy.array( # dummy triangle, square, dummy
- (1, 1, 0, 2, 3, 3), dtype=numpy.uint8),
- mode='fan')
+ (1, 1, 0, 2, 3, 3), dtype=numpy.uint8
+ ),
+ mode="fan",
+ )
for item in (triangles, triangleStrip, triangleFan):
with self.subTest(mode=item.getDrawMode()):
# Add item
self.widget.clearItems()
self.widget.addItem(item)
- self.widget.resetZoom('front')
+ self.widget.resetZoom("front")
self.qapp.processEvents()
# Picking on data (at widget center)
@@ -268,9 +278,11 @@ class TestSceneWidgetPicking(TestCaseQt, ParametricTestCase):
nbPos = len(picking[0].getPositions())
data = picking[0].getData()
self.assertEqual(nbPos, len(data))
- self.assertTrue(numpy.array_equal(
- data,
- item.getPositionData()[picking[0].getIndices()]))
+ self.assertTrue(
+ numpy.array_equal(
+ data, item.getPositionData()[picking[0].getIndices()]
+ )
+ )
# Picking outside data
picking = list(self.widget.pickItems(1, 1))
@@ -279,7 +291,7 @@ class TestSceneWidgetPicking(TestCaseQt, ParametricTestCase):
def testPickCylindricalMesh(self):
"""Test picking of Box, Cylinder and Hexagon items"""
- positions = numpy.array(((0., 0., 0.), (1., 1., 0.), (2., 2., 0.)))
+ positions = numpy.array(((0.0, 0.0, 0.0), (1.0, 1.0, 0.0), (2.0, 2.0, 0.0)))
box = items.Box()
box.setData(position=positions)
cylinder = items.Cylinder()
@@ -292,7 +304,7 @@ class TestSceneWidgetPicking(TestCaseQt, ParametricTestCase):
# Add item
self.widget.clearItems()
self.widget.addItem(item)
- self.widget.resetZoom('front')
+ self.widget.resetZoom("front")
self.qapp.processEvents()
# Picking on data (at widget center)
@@ -305,9 +317,9 @@ class TestSceneWidgetPicking(TestCaseQt, ParametricTestCase):
print(item.__class__.__name__, [positions[1]], data)
self.assertTrue(numpy.all(numpy.equal(positions[1], data)))
self.assertEqual(nbPos, len(data))
- self.assertTrue(numpy.array_equal(
- data,
- item.getPosition()[picking[0].getIndices()]))
+ self.assertTrue(
+ numpy.array_equal(data, item.getPosition()[picking[0].getIndices()])
+ )
# Picking outside data
picking = list(self.widget.pickItems(1, 1))
diff --git a/src/silx/gui/plot3d/test/testSceneWindow.py b/src/silx/gui/plot3d/test/testSceneWindow.py
index 09e097c..f2dc486 100644
--- a/src/silx/gui/plot3d/test/testSceneWindow.py
+++ b/src/silx/gui/plot3d/test/testSceneWindow.py
@@ -62,23 +62,25 @@ class TestSceneWindow(TestCaseQt, ParametricTestCase):
items = []
# RGB image
- image = sceneWidget.addImage(numpy.random.random(
- 10*10*3).astype(numpy.float32).reshape(10, 10, 3))
- image.setLabel('RGB image')
+ image = sceneWidget.addImage(
+ numpy.random.random(10 * 10 * 3).astype(numpy.float32).reshape(10, 10, 3)
+ )
+ image.setLabel("RGB image")
items.append(image)
self.assertEqual(sceneWidget.getItems(), tuple(items))
# Data image
image = sceneWidget.addImage(
- numpy.arange(100, dtype=numpy.float32).reshape(10, 10))
- image.setTranslation(10.)
+ numpy.arange(100, dtype=numpy.float32).reshape(10, 10)
+ )
+ image.setTranslation(10.0)
items.append(image)
self.assertEqual(sceneWidget.getItems(), tuple(items))
# 2D scatter
scatter = sceneWidget.add2DScatter(
- *numpy.random.random(3000).astype(numpy.float32).reshape(3, -1),
- index=0)
+ *numpy.random.random(3000).astype(numpy.float32).reshape(3, -1), index=0
+ )
scatter.setTranslation(0, 10)
scatter.setScale(10, 10, 10)
items.insert(0, scatter)
@@ -86,7 +88,8 @@ class TestSceneWindow(TestCaseQt, ParametricTestCase):
# 3D scatter
scatter = sceneWidget.add3DScatter(
- *numpy.random.random(4000).astype(numpy.float32).reshape(4, -1))
+ *numpy.random.random(4000).astype(numpy.float32).reshape(4, -1)
+ )
scatter.setTranslation(10, 10)
scatter.setScale(10, 10, 10)
items.append(scatter)
@@ -94,44 +97,48 @@ class TestSceneWindow(TestCaseQt, ParametricTestCase):
# 3D array of float
volume = sceneWidget.addVolume(
- numpy.arange(10**3, dtype=numpy.float32).reshape(10, 10, 10))
+ numpy.arange(10**3, dtype=numpy.float32).reshape(10, 10, 10)
+ )
volume.setTranslation(0, 0, 10)
volume.setRotation(45, (0, 0, 1))
- volume.addIsosurface(500, 'red')
- volume.getCutPlanes()[0].getColormap().setName('viridis')
+ volume.addIsosurface(500, "red")
+ volume.getCutPlanes()[0].getColormap().setName("viridis")
items.append(volume)
self.assertEqual(sceneWidget.getItems(), tuple(items))
# 3D array of complex
volume = sceneWidget.addVolume(
- numpy.arange(10**3).reshape(10, 10, 10).astype(numpy.complex64))
+ numpy.arange(10**3).reshape(10, 10, 10).astype(numpy.complex64)
+ )
volume.setTranslation(10, 0, 10)
volume.setRotation(45, (0, 0, 1))
volume.setComplexMode(volume.ComplexMode.REAL)
- volume.addIsosurface(500, (1., 0., 0., .5))
+ volume.addIsosurface(500, (1.0, 0.0, 0.0, 0.5))
items.append(volume)
self.assertEqual(sceneWidget.getItems(), tuple(items))
- sceneWidget.resetZoom('front')
+ sceneWidget.resetZoom("front")
self.qapp.processEvents()
def testHeightMap(self):
"""Test height map items"""
sceneWidget = self.window.getSceneWidget()
- height = numpy.arange(10000).reshape(100, 100) /100.
+ height = numpy.arange(10000).reshape(100, 100) / 100.0
for shape in ((100, 100), (4, 5), (150, 20), (110, 110)):
with self.subTest(shape=shape):
items = []
# Colormapped data height map
- data = numpy.arange(numpy.prod(shape)).astype(numpy.float32).reshape(shape)
+ data = (
+ numpy.arange(numpy.prod(shape)).astype(numpy.float32).reshape(shape)
+ )
heightmap = HeightMapData()
heightmap.setData(height)
heightmap.setColormappedData(data)
- heightmap.getColormap().setName('viridis')
+ heightmap.getColormap().setName("viridis")
items.append(heightmap)
sceneWidget.addItem(heightmap)
@@ -142,12 +149,12 @@ class TestSceneWindow(TestCaseQt, ParametricTestCase):
heightmap = HeightMapRGBA()
heightmap.setData(height)
heightmap.setColorData(colors)
- heightmap.setTranslation(100., 0., 0.)
+ heightmap.setTranslation(100.0, 0.0, 0.0)
items.append(heightmap)
sceneWidget.addItem(heightmap)
self.assertEqual(sceneWidget.getItems(), tuple(items))
- sceneWidget.resetZoom('front')
+ sceneWidget.resetZoom("front")
self.qapp.processEvents()
sceneWidget.clearItems()
@@ -203,17 +210,18 @@ class TestSceneWindow(TestCaseQt, ParametricTestCase):
def testInteractiveMode(self):
"""Test changing interactive mode"""
sceneWidget = self.window.getSceneWidget()
- center = numpy.array((sceneWidget.width() //2, sceneWidget.height() // 2))
+ center = numpy.array((sceneWidget.width() // 2, sceneWidget.height() // 2))
self.mouseMove(sceneWidget, pos=center)
self.mouseClick(sceneWidget, qt.Qt.LeftButton, pos=center)
volume = sceneWidget.addVolume(
- numpy.arange(10**3).astype(numpy.float32).reshape(10, 10, 10))
- sceneWidget.selection().setCurrentItem( volume.getCutPlanes()[0])
- sceneWidget.resetZoom('side')
+ numpy.arange(10**3).astype(numpy.float32).reshape(10, 10, 10)
+ )
+ sceneWidget.selection().setCurrentItem(volume.getCutPlanes()[0])
+ sceneWidget.resetZoom("side")
- for mode in (None, 'rotate', 'pan', 'panSelectedPlane'):
+ for mode in (None, "rotate", "pan", "panSelectedPlane"):
with self.subTest(mode=mode):
sceneWidget.setInteractiveMode(mode)
self.qapp.processEvents()
@@ -221,14 +229,14 @@ class TestSceneWindow(TestCaseQt, ParametricTestCase):
self.mouseMove(sceneWidget, pos=center)
self.mousePress(sceneWidget, qt.Qt.LeftButton, pos=center)
- self.mouseMove(sceneWidget, pos=center-10)
- self.mouseMove(sceneWidget, pos=center-20)
- self.mouseRelease(sceneWidget, qt.Qt.LeftButton, pos=center-20)
+ self.mouseMove(sceneWidget, pos=center - 10)
+ self.mouseMove(sceneWidget, pos=center - 20)
+ self.mouseRelease(sceneWidget, qt.Qt.LeftButton, pos=center - 20)
self.keyPress(sceneWidget, qt.Qt.Key_Control)
self.mouseMove(sceneWidget, pos=center)
self.mousePress(sceneWidget, qt.Qt.LeftButton, pos=center)
- self.mouseMove(sceneWidget, pos=center-10)
- self.mouseMove(sceneWidget, pos=center-20)
- self.mouseRelease(sceneWidget, qt.Qt.LeftButton, pos=center-20)
+ self.mouseMove(sceneWidget, pos=center - 10)
+ self.mouseMove(sceneWidget, pos=center - 20)
+ self.mouseRelease(sceneWidget, qt.Qt.LeftButton, pos=center - 20)
self.keyRelease(sceneWidget, qt.Qt.Key_Control)
diff --git a/src/silx/gui/plot3d/test/testStatsWidget.py b/src/silx/gui/plot3d/test/testStatsWidget.py
index e1411bf..71dcbd9 100644
--- a/src/silx/gui/plot3d/test/testStatsWidget.py
+++ b/src/silx/gui/plot3d/test/testStatsWidget.py
@@ -72,18 +72,19 @@ class TestSceneWidget(TestCaseQt, ParametricTestCase):
# Data image
image = self.sceneWidget.addImage(numpy.arange(100).reshape(10, 10))
- image.setLabel('Image')
+ image.setLabel("Image")
# RGB image
imageRGB = self.sceneWidget.addImage(
- numpy.arange(300, dtype=numpy.uint8).reshape(10, 10, 3))
- imageRGB.setLabel('RGB Image')
+ numpy.arange(300, dtype=numpy.uint8).reshape(10, 10, 3)
+ )
+ imageRGB.setLabel("RGB Image")
# 2D scatter
data = numpy.arange(100)
scatter2D = self.sceneWidget.add2DScatter(x=data, y=data, value=data)
- scatter2D.setLabel('2D Scatter')
+ scatter2D.setLabel("2D Scatter")
# 3D scatter
scatter3D = self.sceneWidget.add3DScatter(x=data, y=data, z=data, value=data)
- scatter3D.setLabel('3D Scatter')
+ scatter3D.setLabel("3D Scatter")
# Add a group
group = items.GroupItem()
self.sceneWidget.addItem(group)
@@ -91,7 +92,7 @@ class TestSceneWidget(TestCaseQt, ParametricTestCase):
data = numpy.arange(64**3).reshape(64, 64, 64)
scalarField = items.ScalarField3D()
scalarField.setData(data, copy=False)
- scalarField.setLabel('3D Scalar field')
+ scalarField.setLabel("3D Scalar field")
group.addItem(scalarField)
statsTable = self.statsWidget._getStatsTable()
@@ -104,7 +105,7 @@ class TestSceneWidget(TestCaseQt, ParametricTestCase):
self.assertEqual(statsTable.rowCount(), 0)
for item in (image, scatter2D, scatter3D, scalarField):
- with self.subTest('selection only', item=item.getLabel()):
+ with self.subTest("selection only", item=item.getLabel()):
self.sceneWidget.selection().setCurrentItem(item)
self.assertEqual(statsTable.rowCount(), 1)
self._checkItem(item)
@@ -114,7 +115,7 @@ class TestSceneWidget(TestCaseQt, ParametricTestCase):
self.assertEqual(statsTable.rowCount(), 4)
for item in (image, scatter2D, scatter3D, scalarField):
- with self.subTest('all items', item=item.getLabel()):
+ with self.subTest("all items", item=item.getLabel()):
self._checkItem(item)
def _checkItem(self, item):
@@ -130,9 +131,9 @@ class TestSceneWidget(TestCaseQt, ParametricTestCase):
statsTable = self.statsWidget._getStatsTable()
tableItems = statsTable._itemToTableItems(item)
self.assertTrue(len(tableItems) > 0)
- self.assertEqual(tableItems['legend'].text(), item.getLabel())
- self.assertEqual(float(tableItems['min'].text()), numpy.min(data))
- self.assertEqual(float(tableItems['max'].text()), numpy.max(data))
+ self.assertEqual(tableItems["legend"].text(), item.getLabel())
+ self.assertEqual(float(tableItems["min"].text()), numpy.min(data))
+ self.assertEqual(float(tableItems["max"].text()), numpy.max(data))
# TODO
@@ -192,10 +193,19 @@ class TestScalarFieldView(TestCaseQt):
self.assertEqual(statsTable.rowCount(), 1)
for column in range(statsTable.columnCount()):
- self.assertEqual(float(self._getTextFor(0, 'min')), numpy.min(data))
- self.assertEqual(float(self._getTextFor(0, 'max')), numpy.max(data))
+ self.assertEqual(float(self._getTextFor(0, "min")), numpy.min(data))
+ self.assertEqual(float(self._getTextFor(0, "max")), numpy.max(data))
sum_ = numpy.sum(data)
- comz = numpy.sum(numpy.arange(data.shape[0]) * numpy.sum(data, axis=(1, 2))) / sum_
- comy = numpy.sum(numpy.arange(data.shape[1]) * numpy.sum(data, axis=(0, 2))) / sum_
- comx = numpy.sum(numpy.arange(data.shape[2]) * numpy.sum(data, axis=(0, 1))) / sum_
- self.assertEqual(self._getTextFor(0, 'COM'), str((comx, comy, comz)))
+ comz = (
+ numpy.sum(numpy.arange(data.shape[0]) * numpy.sum(data, axis=(1, 2)))
+ / sum_
+ )
+ comy = (
+ numpy.sum(numpy.arange(data.shape[1]) * numpy.sum(data, axis=(0, 2)))
+ / sum_
+ )
+ comx = (
+ numpy.sum(numpy.arange(data.shape[2]) * numpy.sum(data, axis=(0, 1)))
+ / sum_
+ )
+ self.assertEqual(self._getTextFor(0, "COM"), str((comx, comy, comz)))
diff --git a/src/silx/gui/plot3d/tools/GroupPropertiesWidget.py b/src/silx/gui/plot3d/tools/GroupPropertiesWidget.py
index 922df3a..11f45cc 100644
--- a/src/silx/gui/plot3d/tools/GroupPropertiesWidget.py
+++ b/src/silx/gui/plot3d/tools/GroupPropertiesWidget.py
@@ -56,16 +56,16 @@ class GroupPropertiesWidget(qt.QWidget):
self.setLayout(layout)
# Colormap
- colormapButton = qt.QPushButton('Set...')
+ colormapButton = qt.QPushButton("Set...")
colormapButton.setToolTip("Set colormap for all items")
colormapButton.clicked.connect(self._colormapButtonClicked)
- layout.addRow('Colormap', colormapButton)
+ layout.addRow("Colormap", colormapButton)
self._markerComboBox = qt.QComboBox(self)
self._markerComboBox.addItems(SymbolMixIn.getSupportedSymbolNames())
# Marker
- markerButton = qt.QPushButton('Set')
+ markerButton = qt.QPushButton("Set")
markerButton.setToolTip("Set marker for all items")
markerButton.clicked.connect(self._markerButtonClicked)
@@ -74,7 +74,7 @@ class GroupPropertiesWidget(qt.QWidget):
markerLayout.addWidget(self._markerComboBox, 1)
markerLayout.addWidget(markerButton, 0)
- layout.addRow('Marker', markerLayout)
+ layout.addRow("Marker", markerLayout)
# Marker size
self._markerSizeSlider = qt.QSlider()
@@ -83,18 +83,18 @@ class GroupPropertiesWidget(qt.QWidget):
self._markerSizeSlider.setRange(1, self.MAX_MARKER_SIZE)
self._markerSizeSlider.setValue(1)
- markerSizeButton = qt.QPushButton('Set')
+ markerSizeButton = qt.QPushButton("Set")
markerSizeButton.setToolTip("Set marker size for all items")
markerSizeButton.clicked.connect(self._markerSizeButtonClicked)
markerSizeLayout = qt.QHBoxLayout()
markerSizeLayout.setContentsMargins(0, 0, 0, 0)
- markerSizeLayout.addWidget(qt.QLabel('1'))
+ markerSizeLayout.addWidget(qt.QLabel("1"))
markerSizeLayout.addWidget(self._markerSizeSlider, 1)
markerSizeLayout.addWidget(qt.QLabel(str(self.MAX_MARKER_SIZE)))
markerSizeLayout.addWidget(markerSizeButton, 0)
- layout.addRow('Marker Size', markerSizeLayout)
+ layout.addRow("Marker Size", markerSizeLayout)
# Line width
self._lineWidthSlider = qt.QSlider()
@@ -103,18 +103,18 @@ class GroupPropertiesWidget(qt.QWidget):
self._lineWidthSlider.setRange(1, self.MAX_LINE_WIDTH)
self._lineWidthSlider.setValue(1)
- lineWidthButton = qt.QPushButton('Set')
+ lineWidthButton = qt.QPushButton("Set")
lineWidthButton.setToolTip("Set line width for all items")
lineWidthButton.clicked.connect(self._lineWidthButtonClicked)
lineWidthLayout = qt.QHBoxLayout()
lineWidthLayout.setContentsMargins(0, 0, 0, 0)
- lineWidthLayout.addWidget(qt.QLabel('1'))
+ lineWidthLayout.addWidget(qt.QLabel("1"))
lineWidthLayout.addWidget(self._lineWidthSlider, 1)
lineWidthLayout.addWidget(qt.QLabel(str(self.MAX_LINE_WIDTH)))
lineWidthLayout.addWidget(lineWidthButton, 0)
- layout.addRow('Line Width', lineWidthLayout)
+ layout.addRow("Line Width", lineWidthLayout)
self._colormapDialog = None # To store dialog
self._colormap = Colormap()
@@ -159,7 +159,8 @@ class GroupPropertiesWidget(qt.QWidget):
itemCmap.setColormapLUT(colormap.getColormapLUT())
itemCmap.setNormalization(colormap.getNormalization())
itemCmap.setGammaNormalizationParameter(
- colormap.getGammaNormalizationParameter())
+ colormap.getGammaNormalizationParameter()
+ )
itemCmap.setVRange(colormap.getVMin(), colormap.getVMax())
else:
# Reset colormap
@@ -195,5 +196,5 @@ class GroupPropertiesWidget(qt.QWidget):
lineWidth = self._lineWidthSlider.value()
for item in group.visit():
- if hasattr(item, 'setLineWidth'):
+ if hasattr(item, "setLineWidth"):
item.setLineWidth(lineWidth)
diff --git a/src/silx/gui/plot3d/tools/PositionInfoWidget.py b/src/silx/gui/plot3d/tools/PositionInfoWidget.py
index 1998533..bffe952 100644
--- a/src/silx/gui/plot3d/tools/PositionInfoWidget.py
+++ b/src/silx/gui/plot3d/tools/PositionInfoWidget.py
@@ -55,18 +55,19 @@ class PositionInfoWidget(qt.QWidget):
self.setToolTip("Double-click on a data point to show its value")
layout = qt.QBoxLayout(qt.QBoxLayout.LeftToRight, self)
- self._xLabel = self._addInfoField('X')
- self._yLabel = self._addInfoField('Y')
- self._zLabel = self._addInfoField('Z')
- self._dataLabel = self._addInfoField('Data')
- self._itemLabel = self._addInfoField('Item')
+ self._xLabel = self._addInfoField("X")
+ self._yLabel = self._addInfoField("Y")
+ self._zLabel = self._addInfoField("Z")
+ self._dataLabel = self._addInfoField("Data")
+ self._itemLabel = self._addInfoField("Item")
layout.addStretch(1)
self._action = actions.mode.PickingModeAction(parent=self)
- self._action.setText('Selection')
+ self._action.setText("Selection")
self._action.setToolTip(
- 'Toggle selection information update with left button click')
+ "Toggle selection information update with left button click"
+ )
self._action.sigSceneClicked.connect(self.pick)
self._action.changed.connect(self.__actionChanged)
self._action.setChecked(False) # Disabled by default
@@ -94,14 +95,14 @@ class PositionInfoWidget(qt.QWidget):
subLayout = qt.QHBoxLayout()
subLayout.setContentsMargins(0, 0, 0, 0)
- subLayout.addWidget(qt.QLabel(label + ':'))
+ subLayout.addWidget(qt.QLabel(label + ":"))
- widget = qt.QLabel('-')
+ widget = qt.QLabel("-")
widget.setAlignment(qt.Qt.AlignLeft | qt.Qt.AlignVCenter)
widget.setTextInteractionFlags(qt.Qt.TextSelectableByMouse)
metrics = widget.fontMetrics()
- if qt.BINDING in ('PySide2', 'PyQt5'):
+ if qt.BINDING == "PyQt5":
width = metrics.width("#######")
else: # Qt6
width = metrics.horizontalAdvance("#######")
@@ -139,22 +140,29 @@ class PositionInfoWidget(qt.QWidget):
def clear(self):
"""Clean-up displayed values"""
- for widget in (self._xLabel, self._yLabel, self._zLabel,
- self._dataLabel, self._itemLabel):
- widget.setText('-')
-
- _SUPPORTED_ITEMS = (items.Scatter3D,
- items.Scatter2D,
- items.ImageData,
- items.ImageRgba,
- items.HeightMapData,
- items.HeightMapRGBA,
- items.Mesh,
- items.Box,
- items.Cylinder,
- items.Hexagon,
- volume.CutPlane,
- volume.Isosurface)
+ for widget in (
+ self._xLabel,
+ self._yLabel,
+ self._zLabel,
+ self._dataLabel,
+ self._itemLabel,
+ ):
+ widget.setText("-")
+
+ _SUPPORTED_ITEMS = (
+ items.Scatter3D,
+ items.Scatter2D,
+ items.ImageData,
+ items.ImageRgba,
+ items.HeightMapData,
+ items.HeightMapRGBA,
+ items.Mesh,
+ items.Box,
+ items.Cylinder,
+ items.Hexagon,
+ volume.CutPlane,
+ volume.Isosurface,
+ )
"""Type of items that are picked"""
def _isSupportedItem(self, item):
@@ -177,15 +185,14 @@ class PositionInfoWidget(qt.QWidget):
sceneWidget = self.getSceneWidget()
if sceneWidget is None: # No associated widget
- _logger.info('Picking without associated SceneWidget')
+ _logger.info("Picking without associated SceneWidget")
return
# Find closest (and latest in the tree) supported item
- closestNdcZ = float('inf')
+ closestNdcZ = float("inf")
picking = None
- for result in sceneWidget.pickItems(x, y,
- condition=self._isSupportedItem):
- ndcZ = result.getPositions('ndc', copy=False)[0, 2]
+ for result in sceneWidget.pickItems(x, y, condition=self._isSupportedItem):
+ ndcZ = result.getPositions("ndc", copy=False)[0, 2]
if ndcZ <= closestNdcZ:
closestNdcZ = ndcZ
picking = result
@@ -195,7 +202,7 @@ class PositionInfoWidget(qt.QWidget):
item = picking.getItem()
self._itemLabel.setText(item.getLabel())
- positions = picking.getPositions('scene', copy=False)
+ positions = picking.getPositions("scene", copy=False)
x, y, z = positions[0]
self._xLabel.setText("%g" % x)
self._yLabel.setText("%g" % y)
@@ -204,8 +211,8 @@ class PositionInfoWidget(qt.QWidget):
data = picking.getData(copy=False)
if data is not None:
data = data[0]
- if hasattr(data, '__len__'):
- text = ' '.join(["%.3g"] * len(data)) % tuple(data)
+ if hasattr(data, "__len__"):
+ text = " ".join(["%.3g"] * len(data)) % tuple(data)
else:
text = "%g" % data
self._dataLabel.setText(text)
@@ -214,7 +221,7 @@ class PositionInfoWidget(qt.QWidget):
"""Update information according to cursor position"""
widget = self.getSceneWidget()
if widget is None:
- _logger.info('Update without associated SceneWidget')
+ _logger.info("Update without associated SceneWidget")
self.clear()
return
diff --git a/src/silx/gui/plot3d/tools/ViewpointTools.py b/src/silx/gui/plot3d/tools/ViewpointTools.py
index ab26c96..3554972 100644
--- a/src/silx/gui/plot3d/tools/ViewpointTools.py
+++ b/src/silx/gui/plot3d/tools/ViewpointTools.py
@@ -57,8 +57,8 @@ class ViewpointToolButton(qt.QToolButton):
self.setMenu(menu)
self.setPopupMode(qt.QToolButton.InstantPopup)
- self.setIcon(getQIcon('cube'))
- self.setToolTip('Reset the viewpoint to a defined position')
+ self.setIcon(getQIcon("cube"))
+ self.setToolTip("Reset the viewpoint to a defined position")
def setPlot3DWidget(self, widget):
"""Set the Plot3DWidget this toolbar is associated with
diff --git a/src/silx/gui/plot3d/tools/test/testPositionInfoWidget.py b/src/silx/gui/plot3d/tools/test/testPositionInfoWidget.py
index e988817..ae95fca 100644
--- a/src/silx/gui/plot3d/tools/test/testPositionInfoWidget.py
+++ b/src/silx/gui/plot3d/tools/test/testPositionInfoWidget.py
@@ -27,8 +27,6 @@ __license__ = "MIT"
__date__ = "03/10/2018"
-import unittest
-
import numpy
from silx.gui.utils.testutils import TestCaseQt
@@ -68,12 +66,11 @@ class TestPositionInfoWidget(TestCaseQt):
def test(self):
"""Test PositionInfoWidget"""
- self.assertIs(self.positionInfoWidget.getSceneWidget(),
- self.sceneWidget)
+ self.assertIs(self.positionInfoWidget.getSceneWidget(), self.sceneWidget)
data = numpy.arange(100)
self.sceneWidget.add2DScatter(x=data, y=data, value=data)
- self.sceneWidget.resetZoom('front')
+ self.sceneWidget.resetZoom("front")
# Double click at the center
self.mouseDClick(self.sceneWidget, button=qt.Qt.LeftButton)
diff --git a/src/silx/gui/plot3d/tools/toolbars.py b/src/silx/gui/plot3d/tools/toolbars.py
index c89f6c6..152e548 100644
--- a/src/silx/gui/plot3d/tools/toolbars.py
+++ b/src/silx/gui/plot3d/tools/toolbars.py
@@ -58,7 +58,7 @@ class Plot3DWidgetToolBar(qt.QToolBar):
:param str title: Title of the toolbar.
"""
- def __init__(self, parent=None, title=''):
+ def __init__(self, parent=None, title=""):
super(Plot3DWidgetToolBar, self).__init__(title, parent)
self._plot3DRef = None
@@ -97,7 +97,7 @@ class InteractiveModeToolBar(Plot3DWidgetToolBar):
:param str title: Title of the toolbar.
"""
- def __init__(self, parent=None, title='Plot3D Interaction'):
+ def __init__(self, parent=None, title="Plot3D Interaction"):
super(InteractiveModeToolBar, self).__init__(parent, title)
self._rotateAction = actions.mode.RotateArcballAction(parent=self)
@@ -128,7 +128,7 @@ class OutputToolBar(Plot3DWidgetToolBar):
:param str title: Title of the toolbar.
"""
- def __init__(self, parent=None, title='Plot3D Output'):
+ def __init__(self, parent=None, title="Plot3D Output"):
super(OutputToolBar, self).__init__(parent, title)
self._copyAction = actions.io.CopyAction(parent=self)
@@ -179,7 +179,7 @@ class ViewpointToolBar(Plot3DWidgetToolBar):
:param str title: Title of the toolbar
"""
- def __init__(self, parent=None, title='Viewpoint control'):
+ def __init__(self, parent=None, title="Viewpoint control"):
super(ViewpointToolBar, self).__init__(parent, title)
self._viewpointToolButton = ViewpointToolButton(parent=self)
diff --git a/src/silx/gui/plot3d/utils/mng.py b/src/silx/gui/plot3d/utils/mng.py
index 52f619f..3c63266 100644
--- a/src/silx/gui/plot3d/utils/mng.py
+++ b/src/silx/gui/plot3d/utils/mng.py
@@ -47,10 +47,10 @@ def _png_chunk(name, data):
:param str name: Chunk type
:param byte data: Chunk payload
"""
- length = struct.pack('>I', len(data))
- name = [char.encode('ascii') for char in name]
- chunk = struct.pack('cccc', *name) + data
- crc = struct.pack('>I', zlib.crc32(chunk) & 0xffffffff)
+ length = struct.pack(">I", len(data))
+ name = [char.encode("ascii") for char in name]
+ chunk = struct.pack("cccc", *name) + data
+ crc = struct.pack(">I", zlib.crc32(chunk) & 0xFFFFFFFF)
return length + chunk + crc
@@ -76,43 +76,46 @@ def convert(images, nb_images=0, fps=25):
height, width = image.shape[:2]
# MNG signature
- yield b'\x8aMNG\r\n\x1a\n'
+ yield b"\x8aMNG\r\n\x1a\n"
# MHDR chunk: File header
- yield _png_chunk('MHDR', struct.pack(
- ">IIIIIII",
- width,
- height,
- fps, # ticks
- nb_images + 1, # layer count
- nb_images, # frame count
- nb_images, # play time
- 1)) # profile: MNG-VLC no alpha: only least significant bit 1
+ yield _png_chunk(
+ "MHDR",
+ struct.pack(
+ ">IIIIIII",
+ width,
+ height,
+ fps, # ticks
+ nb_images + 1, # layer count
+ nb_images, # frame count
+ nb_images, # play time
+ 1,
+ ),
+ ) # profile: MNG-VLC no alpha: only least significant bit 1
assert image.shape == (height, width, 3)
- assert image.dtype == numpy.dtype('uint8')
+ assert image.dtype == numpy.dtype("uint8")
# IHDR chunk: Image header
depth = 8 # 8 bit per channel
color_type = 2 # 'truecolor' = RGB
interlace = 0 # No
- yield _png_chunk('IHDR', struct.pack(">IIBBBBB",
- width,
- height,
- depth,
- color_type,
- 0, 0, interlace))
+ yield _png_chunk(
+ "IHDR",
+ struct.pack(">IIBBBBB", width, height, depth, color_type, 0, 0, interlace),
+ )
# Add filter 'None' before each scanline
- prepared_data = b'\x00' + b'\x00'.join(
- line.tobytes() for line in image) # TODO optimize that
+ prepared_data = b"\x00" + b"\x00".join(
+ line.tobytes() for line in image
+ ) # TODO optimize that
compressed_data = zlib.compress(prepared_data, 8)
# IDAT chunk: Payload
- yield _png_chunk('IDAT', compressed_data)
+ yield _png_chunk("IDAT", compressed_data)
# IEND chunk: Image footer
- yield _png_chunk('IEND', b'')
+ yield _png_chunk("IEND", b"")
# MEND chunk: footer
- yield _png_chunk('MEND', b'')
+ yield _png_chunk("MEND", b"")
diff --git a/src/silx/gui/qt/__init__.py b/src/silx/gui/qt/__init__.py
index bc75041..675a178 100644
--- a/src/silx/gui/qt/__init__.py
+++ b/src/silx/gui/qt/__init__.py
@@ -25,11 +25,12 @@
- `PyQt5 <http://pyqt.sourceforge.net/Docs/PyQt5/>`_
- `PySide6 <https://pypi.org/project/PySide6/>`_
-- `PySide2 <https://pypi.org/project/PySide2/>`_
- `PyQt6 <https://pypi.org/project/PyQt6/>`_
-If a Qt binding is already loaded, it will use it, otherwise the different
-Qt bindings are tried in this order: PyQt5, PySide6, PySide2, PyQt6.
+If a Qt binding is already loaded, it will be used.
+If the `QT_API` environment variable is set to one of the supported Qt bindings
+(case insensitive), this binding is loaded if available, otherwise the
+different Qt bindings are tried in this order: PyQt5, PySide6, PyQt6.
The name of the loaded Qt binding is stored in the BINDING variable.
@@ -48,7 +49,8 @@ see `qtpy <https://pypi.org/project/QtPy/>`_.
"""
from ._qt import * # noqa
-if BINDING in ('PySide2', 'PySide6'):
+
+if BINDING == "PySide6":
# Import loadUi wrapper
- from ._pyside_dynamic import loadUi # noqa
+ from ._pyside_dynamic import loadUi # noqa
from ._utils import * # noqa
diff --git a/src/silx/gui/qt/_pyqt6.py b/src/silx/gui/qt/_pyqt6.py
index 15b49bb..4f28d40 100644
--- a/src/silx/gui/qt/_pyqt6.py
+++ b/src/silx/gui/qt/_pyqt6.py
@@ -32,7 +32,6 @@ import enum
import logging
import PyQt6.sip
-from PyQt6.QtCore import Qt
_logger = logging.getLogger(__name__)
@@ -46,19 +45,35 @@ def patch_enums(*modules):
for module in modules:
for clsName in dir(module):
cls = getattr(module, clsName, None)
- if isinstance(cls, PyQt6.sip.wrappertype) and clsName.startswith('Q'):
- for qenumName in dir(cls):
- if qenumName[0].isupper():
- qenum = getattr(cls, qenumName, None)
- if isinstance(qenum, enum.EnumMeta):
- if qenum is getattr(cls.__mro__[1], qenumName, None):
- continue # Only handle it once
- for item in qenum:
- # Special cases to avoid overrides and mimic PySide6
- if clsName == 'QColorSpace' and qenumName in (
- 'Primaries', 'TransferFunction'):
- break
- if qenumName in ('DeviceType', 'PointerType'):
- break
+ if not isinstance(cls, PyQt6.sip.wrappertype) or not clsName.startswith(
+ "Q"
+ ):
+ continue
- setattr(cls, item.name, item)
+ for qenumName in dir(cls):
+ if not qenumName[0].isupper():
+ continue
+ # Special cases to avoid overrides and mimic PySide6
+ if clsName == "QColorSpace" and qenumName in (
+ "Primaries",
+ "TransferFunction",
+ ):
+ continue
+ if qenumName in ("DeviceType", "PointerType"):
+ continue
+
+ qenum = getattr(cls, qenumName)
+ if not isinstance(qenum, enum.EnumMeta):
+ continue
+
+ if any(
+ map(
+ lambda ancestor: isinstance(ancestor, PyQt6.sip.wrappertype)
+ and qenum is getattr(ancestor, qenumName, None),
+ cls.__mro__[1:],
+ )
+ ):
+ continue # Only handle it once in case of inheritance
+
+ for name, value in qenum.__members__.items():
+ setattr(cls, name, value)
diff --git a/src/silx/gui/qt/_pyside_dynamic.py b/src/silx/gui/qt/_pyside_dynamic.py
index 80520ac..4c1ceba 100644
--- a/src/silx/gui/qt/_pyside_dynamic.py
+++ b/src/silx/gui/qt/_pyside_dynamic.py
@@ -1,26 +1,58 @@
-
-# Taken from: https://gist.github.com/cpbotha/1b42a20c8f3eb9bb7cb8
-# Plus: https://github.com/spyder-ide/qtpy/commit/001a862c401d757feb63025f88dbb4601d353c84
-
+# Adapted from https://github.com/spyder-ide/qtpy/blob/296dee3da8aba381b3cf17da34a6d17626e50357/qtpy/uic.py
+# In PySide, loadUi does not exist, so we define it using QUiLoader, and
+# then make sure we expose that function. This is adapted from qt-helpers
+# which was released under a 3-clause BSD license:
+# qt-helpers - a common front-end to various Qt modules
+#
+# Copyright (c) 2015, Chris Beaumont and Thomas Robitaille
+#
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in the
+# documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of the Glue project nor the names of its contributors
+# may be used to endorse or promote products derived from this software
+# without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
+# IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
+# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+# Which itself was based on the solution at
+#
+# https://gist.github.com/cpbotha/1b42a20c8f3eb9bb7cb8
+#
+# which was released under the MIT license:
+#
# Copyright (c) 2011 Sebastian Wiesner <lunaryorn@gmail.com>
# Modifications by Charl Botha <cpbotha@vxlabs.com>
-# * customWidgets support (registerCustomWidget() causes segfault in
-# pyside 1.1.2 on Ubuntu 12.04 x86_64)
-# * workingDirectory support in loadUi
-
-# found this here:
-# https://github.com/lunaryorn/snippets/blob/master/qt4/designer/pyside_dynamic.py
-
+#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
-
+#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
-
+#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
@@ -29,40 +61,68 @@
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
-"""
- How to load a user interface dynamically with PySide.
-
- .. moduleauthor:: Sebastian Wiesner <lunaryorn@gmail.com>
-"""
+"""How to load a user interface dynamically with PySide6"""
import logging
from ._qt import BINDING
-if BINDING == 'PySide2':
- from PySide2.QtCore import QMetaObject, Property, Qt
- from PySide2.QtWidgets import QFrame
- from PySide2.QtUiTools import QUiLoader
-elif BINDING == 'PySide6':
- from PySide6.QtCore import QMetaObject, Property, Qt
- from PySide6.QtWidgets import QFrame
- from PySide6.QtUiTools import QUiLoader
-else:
+
+if BINDING != "PySide6":
raise RuntimeError("Unsupported Qt binding: %s", BINDING)
+from PySide6.QtCore import QMetaObject, Property, Qt
+from PySide6.QtWidgets import QFrame
+from PySide6.QtUiTools import QUiLoader
+
_logger = logging.getLogger(__name__)
+# Specific custom widgets
+
+
+class _Line(QFrame):
+ """Widget to use as 'Line' Qt designer"""
+
+ def __init__(self, parent=None):
+ super(_Line, self).__init__(parent)
+ self.setFrameShape(QFrame.HLine)
+ self.setFrameShadow(QFrame.Sunken)
+
+ def getOrientation(self):
+ shape = self.frameShape()
+ if shape == QFrame.HLine:
+ return Qt.Horizontal
+ elif shape == QFrame.VLine:
+ return Qt.Vertical
+ else:
+ raise RuntimeError("Wrong shape: %d", shape)
+
+ def setOrientation(self, orientation):
+ if orientation == Qt.Horizontal:
+ self.setFrameShape(QFrame.HLine)
+ elif orientation == Qt.Vertical:
+ self.setFrameShape(QFrame.VLine)
+ else:
+ raise ValueError("Unsupported orientation %s" % str(orientation))
+
+ orientation = Property("Qt::Orientation", getOrientation, setOrientation)
+
+
+CUSTOM_WIDGETS = {"Line": _Line}
+"""Default custom widgets for `loadUi`"""
+
+
class UiLoader(QUiLoader):
"""
- Subclass :class:`~PySide.QtUiTools.QUiLoader` to create the user interface
- in a base instance.
+ Subclass of :class:`~PySide.QtUiTools.QUiLoader` to create the user
+ interface in a base instance.
Unlike :class:`~PySide.QtUiTools.QUiLoader` itself this class does not
create a new instance of the top-level widget, but creates the user
- interface in an existing instance of the top-level class.
+ interface in an existing instance of the top-level class if needed.
- This mimics the behaviour of :func:`PyQt*.uic.loadUi`.
+ This mimics the behaviour of :func:`PyQt4.uic.loadUi`.
"""
def __init__(self, baseinstance, customWidgets=None):
@@ -74,129 +134,91 @@ class UiLoader(QUiLoader):
subclass thereof.
``customWidgets`` is a dictionary mapping from class name to class
- object for widgets that you've promoted in the Qt Designer
- interface. Usually, this should be done by calling
- registerCustomWidget on the QUiLoader, but
- with PySide 1.1.2 on Ubuntu 12.04 x86_64 this causes a segfault.
+ object for custom widgets. Usually, this should be done by calling
+ registerCustomWidget on the QUiLoader, but with PySide 1.1.2 on
+ Ubuntu 12.04 x86_64 this causes a segfault.
``parent`` is the parent object of this loader.
"""
QUiLoader.__init__(self, baseinstance)
+
self.baseinstance = baseinstance
- self.customWidgets = {}
- self.uifile = None
- self.customWidgets.update(customWidgets)
- def createWidget(self, class_name, parent=None, name=''):
+ if customWidgets is None:
+ self.customWidgets = {}
+ else:
+ self.customWidgets = customWidgets
+
+ def createWidget(self, class_name, parent=None, name=""):
"""
Function that is called for each widget defined in ui file,
overridden here to populate baseinstance instead.
"""
if parent is None and self.baseinstance:
- # supposed to create the top-level widget, return the base instance
- # instead
+ # supposed to create the top-level widget, return the base
+ # instance instead
return self.baseinstance
else:
- if class_name in self.availableWidgets():
+ # For some reason, Line is not in the list of available
+ # widgets, but works fine, so we have to special case it here.
+ if class_name in self.availableWidgets() or class_name == "Line":
# create a new widget for child widgets
widget = QUiLoader.createWidget(self, class_name, parent, name)
else:
- # if not in the list of availableWidgets,
- # must be a custom widget
- # this will raise KeyError if the user has not supplied the
- # relevant class_name in the dictionary, or TypeError, if
- # customWidgets is None
- if class_name not in self.customWidgets:
- raise Exception('No custom widget ' + class_name +
- ' found in customWidgets param of' +
- 'UiFile %s.' % self.uifile)
+ # If not in the list of availableWidgets, must be a custom
+ # widget. This will raise KeyError if the user has not
+ # supplied the relevant class_name in the dictionary or if
+ # customWidgets is empty.
try:
widget = self.customWidgets[class_name](parent)
- except Exception:
- _logger.error("Fail to instanciate widget %s from file %s", class_name, self.uifile)
- raise
+ except KeyError as error:
+ raise Exception(
+ f"No custom widget {class_name} " "found in customWidgets"
+ ) from error
if self.baseinstance:
# set an attribute for the new child widget on the base
- # instance, just like PyQt*.uic.loadUi does.
+ # instance, just like PyQt4.uic.loadUi does.
setattr(self.baseinstance, name, widget)
- # this outputs the various widget names, e.g.
- # sampleGraphicsView, dockWidget, samplesTableView etc.
- # print(name)
-
return widget
- def _parse_custom_widgets(self, ui_file):
- """
- This function is used to parse a ui file and look for the <customwidgets>
- section, then automatically load all the custom widget classes.
- """
- import importlib
- from xml.etree.ElementTree import ElementTree
-
- # Parse the UI file
- etree = ElementTree()
- ui = etree.parse(ui_file)
-
- # Get the customwidgets section
- custom_widgets = ui.find('customwidgets')
-
- if custom_widgets is None:
- return
-
- custom_widget_classes = {}
-
- for custom_widget in custom_widgets.getchildren():
- cw_class = custom_widget.find('class').text
- cw_header = custom_widget.find('header').text
-
- module = importlib.import_module(cw_header)
-
- custom_widget_classes[cw_class] = getattr(module, cw_class)
+def _get_custom_widgets(ui_file):
+ """
+ This function is used to parse a ui file and look for the <customwidgets>
+ section, then automatically load all the custom widget classes.
+ """
- self.customWidgets.update(custom_widget_classes)
+ import sys
+ import importlib
+ from xml.etree.ElementTree import ElementTree
- def load(self, uifile):
- self._parse_custom_widgets(uifile)
- self.uifile = uifile
- return QUiLoader.load(self, uifile)
+ # Parse the UI file
+ etree = ElementTree()
+ ui = etree.parse(ui_file)
+ # Get the customwidgets section
+ custom_widgets = ui.find("customwidgets")
-class _Line(QFrame):
- """Widget to use as 'Line' Qt designer"""
- def __init__(self, parent=None):
- super(_Line, self).__init__(parent)
- self.setFrameShape(QFrame.HLine)
- self.setFrameShadow(QFrame.Sunken)
+ if custom_widgets is None:
+ return {}
- def getOrientation(self):
- shape = self.frameShape()
- if shape == QFrame.HLine:
- return Qt.Horizontal
- elif shape == QFrame.VLine:
- return Qt.Vertical
- else:
- raise RuntimeError("Wrong shape: %d", shape)
+ custom_widget_classes = {}
- def setOrientation(self, orientation):
- if orientation == Qt.Horizontal:
- self.setFrameShape(QFrame.HLine)
- elif orientation == Qt.Vertical:
- self.setFrameShape(QFrame.VLine)
- else:
- raise ValueError("Unsupported orientation %s" % str(orientation))
+ for custom_widget in list(custom_widgets):
+ cw_class = custom_widget.find("class").text
+ cw_header = custom_widget.find("header").text
- orientation = Property("Qt::Orientation", getOrientation, setOrientation)
+ module = importlib.import_module(cw_header)
+ custom_widget_classes[cw_class] = getattr(module, cw_class)
-CUSTOM_WIDGETS = {"Line": _Line}
-"""Default custom widgets for `loadUi`"""
+ return custom_widget_classes
def loadUi(uifile, baseinstance=None, package=None, resource_suffix=None):
@@ -205,30 +227,36 @@ def loadUi(uifile, baseinstance=None, package=None, resource_suffix=None):
``uifile`` is a string containing a file name of the UI file to load.
- If ``baseinstance`` is ``None``, the a new instance of the top-level widget
- will be created. Otherwise, the user interface is created within the given
- ``baseinstance``. In this case ``baseinstance`` must be an instance of the
- top-level widget class in the UI file to load, or a subclass thereof. In
- other words, if you've created a ``QMainWindow`` interface in the designer,
- ``baseinstance`` must be a ``QMainWindow`` or a subclass thereof, too. You
- cannot load a ``QMainWindow`` UI file with a plain
- :class:`~PySide.QtGui.QWidget` as ``baseinstance``.
+ If ``baseinstance`` is ``None``, the a new instance of the top-level
+ widget will be created. Otherwise, the user interface is created within
+ the given ``baseinstance``. In this case ``baseinstance`` must be an
+ instance of the top-level widget class in the UI file to load, or a
+ subclass thereof. In other words, if you've created a ``QMainWindow``
+ interface in the designer, ``baseinstance`` must be a ``QMainWindow``
+ or a subclass thereof, too. You cannot load a ``QMainWindow`` UI file
+ with a plain :class:`~PySide.QtGui.QWidget` as ``baseinstance``.
- :method:`~PySide.QtCore.QMetaObject.connectSlotsByName()` is called on the
- created user interface, so you can implemented your slots according to its
- conventions in your widget class.
+ :method:`~PySide.QtCore.QMetaObject.connectSlotsByName()` is called on
+ the created user interface, so you can implemented your slots according
+ to its conventions in your widget class.
Return ``baseinstance``, if ``baseinstance`` is not ``None``. Otherwise
return the newly created instance of the user interface.
"""
if package is not None:
- _logger.warning(
- "loadUi package parameter not implemented with PySide")
+ _logger.warning("loadUi package parameter not implemented with PySide")
if resource_suffix is not None:
- _logger.warning(
- "loadUi resource_suffix parameter not implemented with PySide")
+ _logger.warning("loadUi resource_suffix parameter not implemented with PySide")
+
+ # We parse the UI file and import any required custom widgets
+ customWidgets = _get_custom_widgets(uifile)
+
+ # Add CUSTOM_WIDGETS
+ for name, klass in CUSTOM_WIDGETS.items():
+ customWidgets.setdefault(name, klass)
+
+ loader = UiLoader(baseinstance, customWidgets)
- loader = UiLoader(baseinstance, customWidgets=CUSTOM_WIDGETS)
widget = loader.load(uifile)
QMetaObject.connectSlotsByName(widget)
return widget
diff --git a/src/silx/gui/qt/_qt.py b/src/silx/gui/qt/_qt.py
index b92fce2..e069f4b 100644
--- a/src/silx/gui/qt/_qt.py
+++ b/src/silx/gui/qt/_qt.py
@@ -28,20 +28,23 @@ __license__ = "MIT"
__date__ = "12/01/2022"
+import importlib
import logging
+import os
import sys
import traceback
+from packaging.version import Version
from silx.utils import deprecation
_logger = logging.getLogger(__name__)
BINDING = None
-"""The name of the Qt binding in use: PyQt5, PySide2, PySide6, PyQt6."""
+"""The name of the Qt binding in use: PyQt5, PySide6, PyQt6."""
QtBinding = None # noqa
-"""The Qt binding module in use: PyQt5, PySide2, PySide6, PyQt6."""
+"""The Qt binding module in use: PyQt5, PySide6, PyQt6."""
HAS_SVG = False
"""True if Qt provides support for Scalable Vector Graphics (QtSVG)."""
@@ -49,49 +52,68 @@ HAS_SVG = False
HAS_OPENGL = False
"""True if Qt provides support for OpenGL (QtOpenGL)."""
-# First check for an already loaded wrapper
-for _binding in ('PySide2', 'PyQt5', 'PySide6', 'PyQt6'):
- if _binding + '.QtCore' in sys.modules:
- BINDING = _binding
- break
-else: # Then try Qt bindings
- try:
- import PyQt5.QtCore # noqa
- except ImportError:
- if 'PyQt5' in sys.modules:
- del sys.modules["PyQt5"]
- try:
- import PySide6.QtCore # noqa
- except ImportError:
- if 'PySide6' in sys.modules:
- del sys.modules["PySide6"]
+
+def _select_binding() -> str:
+ """Select and load a Qt binding
+
+ Qt binding is selected according to:
+ - Already loaded binding
+ - QT_API environment variable
+ - Bindings order of priority
+
+ :raises ImportError:
+ :returns: Loaded binding
+ """
+ bindings = "PyQt5", "PySide6", "PyQt6"
+
+ envvar = os.environ.get("QT_API", "").lower()
+
+ # First check for an already loaded binding
+ for binding in bindings:
+ if f"{binding}.QtCore" in sys.modules:
+ if envvar and envvar != binding.lower():
+ _logger.warning(
+ f"Cannot satisfy QT_API={envvar} environment variable, {binding} is already loaded"
+ )
+ return binding
+
+ # Check if QT_API can be satisfied
+ if envvar:
+ selection = [b for b in bindings if envvar == b.lower()]
+ if not selection:
+ _logger.warning(f"Environment variable QT_API={envvar} is not supported")
+ else:
+ binding = selection[0]
try:
- import PySide2.QtCore # noqa
+ importlib.import_module(f"{binding}.QtCore")
except ImportError:
- if 'PySide2' in sys.modules:
- del sys.modules["PySide2"]
- try:
- import PyQt6.QtCore # noqa
- except ImportError:
- if 'PyQt6' in sys.modules:
- del sys.modules["PyQt6"]
-
- raise ImportError(
- 'No Qt wrapper found. Install PyQt5, PySide2, PySide6, PyQt6.')
- else:
- BINDING = 'PyQt6'
+ _logger.warning(
+ f"Cannot import {binding} specified by QT_API environment variable"
+ )
else:
- BINDING = 'PySide2'
+ return binding
+
+ # Try to load binding
+ for binding in bindings:
+ try:
+ importlib.import_module(f"{binding}.QtCore")
+ except ImportError:
+ if binding in sys.modules:
+ del sys.modules[binding]
else:
- BINDING = 'PySide6'
- else:
- BINDING = 'PyQt5'
+ return binding
+
+ raise ImportError("No Qt wrapper found. Install PyQt5, PySide6, PyQt6.")
+
+
+BINDING = _select_binding()
-if BINDING == 'PyQt5':
- _logger.debug('Using PyQt5 bindings')
+if BINDING == "PyQt5":
+ _logger.debug("Using PyQt5 bindings")
from PyQt5 import QtCore
- if sys.version_info >= (3, 10) and QtCore.PYQT_VERSION < 0x50e02:
+
+ if sys.version_info >= (3, 10) and QtCore.PYQT_VERSION < 0x50E02:
raise RuntimeError(
"PyQt5 v%s is not supported, please upgrade it." % QtCore.PYQT_VERSION_STR
)
@@ -129,75 +151,22 @@ if BINDING == 'PyQt5':
# Disable PyQt5's cooperative multi-inheritance since other bindings do not provide it.
# See https://www.riverbankcomputing.com/static/Docs/PyQt5/multiinheritance.html?highlight=inheritance
- class _Foo(object): pass
- class QObject(QObject, _Foo): pass
-
-
-elif BINDING == 'PySide2':
- deprecation.deprecated_warning(
- type_="Qt Binding",
- name="PySide2",
- replacement="PySide6",
- since_version="1.1",
- )
-
- import PySide2 as QtBinding # noqa
+ class _Foo(object):
+ pass
- from PySide2.QtCore import * # noqa
- from PySide2.QtGui import * # noqa
- from PySide2.QtWidgets import * # noqa
- from PySide2.QtPrintSupport import * # noqa
+ class QObject(QObject, _Foo):
+ pass
- try:
- from PySide2.QtOpenGL import * # noqa
- except ImportError:
- _logger.info("PySide2.QtOpenGL not available")
- HAS_OPENGL = False
- else:
- HAS_OPENGL = True
-
- try:
- from PySide2.QtSvg import * # noqa
- except ImportError:
- _logger.info("PySide2.QtSvg not available")
- HAS_SVG = False
- else:
- HAS_SVG = True
-
- pyqtSignal = Signal
-
- # Qt6 compatibility:
- # with PySide2 `exec` method has a special behavior
- class _ExecMixIn:
- """Mix-in class providind `exec` compatibility"""
- def exec(self, *args, **kwargs):
- return super().exec_(*args, **kwargs)
-
- # QtWidgets
- QApplication.exec = QApplication.exec_
- class QColorDialog(_ExecMixIn, QColorDialog): pass
- class QDialog(_ExecMixIn, QDialog): pass
- class QErrorMessage(_ExecMixIn, QErrorMessage): pass
- class QFileDialog(_ExecMixIn, QFileDialog): pass
- class QFontDialog(_ExecMixIn, QFontDialog): pass
- class QInputDialog(_ExecMixIn, QInputDialog): pass
- class QMenu(_ExecMixIn, QMenu): pass
- class QMessageBox(_ExecMixIn, QMessageBox): pass
- class QProgressDialog(_ExecMixIn, QProgressDialog): pass
- #QtCore
- class QCoreApplication(_ExecMixIn, QCoreApplication): pass
- class QEventLoop(_ExecMixIn, QEventLoop): pass
- if hasattr(QTextStreamManipulator, "exec_"):
- # exec_ only wrapped in PySide2 and NOT in PyQt5
- class QTextStreamManipulator(_ExecMixIn, QTextStreamManipulator): pass
- class QThread(_ExecMixIn, QThread): pass
-
-
-elif BINDING == 'PySide6':
- _logger.debug('Using PySide6 bindings')
+elif BINDING == "PySide6":
+ _logger.debug("Using PySide6 bindings")
import PySide6 as QtBinding # noqa
+ if Version(QtBinding.__version__) < Version("6.4"):
+ raise RuntimeError(
+ f"PySide6 v{QtBinding.__version__} is not supported, please upgrade it."
+ )
+
from PySide6.QtCore import * # noqa
from PySide6.QtGui import * # noqa
from PySide6.QtWidgets import * # noqa
@@ -223,13 +192,14 @@ elif BINDING == 'PySide6':
pyqtSignal = Signal
-elif BINDING == 'PyQt6':
- _logger.debug('Using PyQt6 bindings')
+elif BINDING == "PyQt6":
+ _logger.debug("Using PyQt6 bindings")
# Monkey-patch module to expose enum values for compatibility
# All Qt modules loaded here should be patched.
from . import _pyqt6
from PyQt6 import QtCore
+
if QtCore.PYQT_VERSION < int("0x60300", 16):
raise RuntimeError(
"PyQt6 v%s is not supported, please upgrade it." % QtCore.PYQT_VERSION_STR
@@ -237,8 +207,10 @@ elif BINDING == 'PyQt6':
from PyQt6 import QtGui, QtWidgets, QtPrintSupport, QtOpenGL, QtSvg
from PyQt6 import QtTest as _QtTest
+
_pyqt6.patch_enums(
- QtCore, QtGui, QtWidgets, QtPrintSupport, QtOpenGL, QtSvg, _QtTest)
+ QtCore, QtGui, QtWidgets, QtPrintSupport, QtOpenGL, QtSvg, _QtTest
+ )
import PyQt6 as QtBinding # noqa
@@ -274,11 +246,14 @@ elif BINDING == 'PyQt6':
# Disable PyQt6 cooperative multi-inheritance since other bindings do not provide it.
# See https://www.riverbankcomputing.com/static/Docs/PyQt6/multiinheritance.html?highlight=inheritance
- class _Foo(object): pass
- class QObject(QObject, _Foo): pass
+ class _Foo(object):
+ pass
+
+ class QObject(QObject, _Foo):
+ pass
else:
- raise ImportError('No Qt wrapper found. Install PyQt5, PySide2, PySide6 or PyQt6')
+ raise ImportError("No Qt wrapper found. Install PyQt5, PySide6 or PyQt6")
# provide a exception handler but not implement it by default
@@ -295,11 +270,11 @@ def exceptionHandler(type_, value, trace):
sys.excepthook = qt.exceptionHandler
"""
- _logger.error("%s %s %s", type_, value, ''.join(traceback.format_tb(trace)))
+ _logger.error("%s %s %s", type_, value, "".join(traceback.format_tb(trace)))
msg = QMessageBox()
msg.setWindowTitle("Unhandled exception")
msg.setIcon(QMessageBox.Critical)
msg.setInformativeText("%s %s\nPlease report details" % (type_, value))
- msg.setDetailedText(("%s " % value) + ''.join(traceback.format_tb(trace)))
+ msg.setDetailedText(("%s " % value) + "".join(traceback.format_tb(trace)))
msg.raise_()
msg.exec()
diff --git a/src/silx/gui/qt/_utils.py b/src/silx/gui/qt/_utils.py
index fb2b8ce..1015c29 100644
--- a/src/silx/gui/qt/_utils.py
+++ b/src/silx/gui/qt/_utils.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+# Copyright (c) 2004-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -38,7 +38,7 @@ def getMouseEventPosition(event):
:param QMouseEvent event:
:returns: (x, y) as a tuple of float
"""
- if _qt.BINDING in ("PyQt5", "PySide2"):
+ if _qt.BINDING == "PyQt5":
return float(event.x()), float(event.y())
# Qt6
position = event.position()
@@ -48,13 +48,8 @@ def getMouseEventPosition(event):
def supportedImageFormats():
"""Return a set of string of file format extensions supported by the
Qt runtime."""
- if _qt.BINDING == 'PySide2':
- def convert(data):
- return str(data.data(), 'ascii')
- else:
- convert = lambda data: str(data, 'ascii')
formats = _qt.QImageReader.supportedImageFormats()
- return set([convert(data) for data in formats])
+ return set([str(data, "ascii") for data in formats])
__globalThreadPoolInstance = None
@@ -62,7 +57,7 @@ __globalThreadPoolInstance = None
def silxGlobalThreadPool():
- """"Manage an own QThreadPool to avoid issue on Qt5 Windows with the
+ """Manage an own QThreadPool to avoid issue on Qt5 Windows with the
default Qt global thread pool.
A thread pool is create in lazy loading. With a maximum of 4 threads.
@@ -71,7 +66,7 @@ def silxGlobalThreadPool():
:rtype: qt.QThreadPool
"""
global __globalThreadPoolInstance
- if __globalThreadPoolInstance is None:
+ if __globalThreadPoolInstance is None:
tp = _qt.QThreadPool()
# Setting maxThreadCount fixes a segfault with PyQt 5.9.1 on Windows
maxThreadCount = min(4, tp.maxThreadCount())
diff --git a/src/silx/gui/qt/inspect.py b/src/silx/gui/qt/inspect.py
index c7fe32a..990b5fa 100644
--- a/src/silx/gui/qt/inspect.py
+++ b/src/silx/gui/qt/inspect.py
@@ -36,7 +36,7 @@ __date__ = "08/10/2018"
from . import _qt as qt
-if qt.BINDING == 'PyQt5':
+if qt.BINDING == "PyQt5":
try:
from PyQt5.sip import isdeleted as _isdeleted # noqa
from PyQt5.sip import ispycreated as createdByPython # noqa
@@ -46,7 +46,6 @@ if qt.BINDING == 'PyQt5':
from sip import ispycreated as createdByPython # noqa
from sip import ispyowned as ownedByPython # noqa
-
def isValid(obj):
"""Returns True if underlying C++ object is valid.
@@ -55,20 +54,10 @@ if qt.BINDING == 'PyQt5':
"""
return not _isdeleted(obj)
-elif qt.BINDING == 'PySide2':
- try:
- from PySide2.shiboken2 import isValid # noqa
- from PySide2.shiboken2 import createdByPython # noqa
- from PySide2.shiboken2 import ownedByPython # noqa
- except ImportError:
- from shiboken2 import isValid # noqa
- from shiboken2 import createdByPython # noqa
- from shiboken2 import ownedByPython # noqa
-
-elif qt.BINDING == 'PySide6':
+elif qt.BINDING == "PySide6":
from shiboken6 import isValid, createdByPython, ownedByPython # noqa
-elif qt.BINDING == 'PyQt6':
+elif qt.BINDING == "PyQt6":
from PyQt6.sip import isdeleted as _isdeleted # noqa
from PyQt6.sip import ispycreated as createdByPython # noqa
from PyQt6.sip import ispyowned as ownedByPython # noqa
@@ -81,9 +70,7 @@ elif qt.BINDING == 'PyQt6':
"""
return not _isdeleted(obj)
-
-
else:
raise ImportError("Unsupported Qt binding %s" % qt.BINDING)
-__all__ = ['isValid', 'createdByPython', 'ownedByPython']
+__all__ = ["isValid", "createdByPython", "ownedByPython"]
diff --git a/src/silx/gui/test/test_colors.py b/src/silx/gui/test/test_colors.py
index b0e6139..8c252a7 100755
--- a/src/silx/gui/test/test_colors.py
+++ b/src/silx/gui/test/test_colors.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2015-2020 European Synchrotron Radiation Facility
+# Copyright (c) 2015-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -30,6 +30,9 @@ __date__ = "09/11/2018"
import unittest
import numpy
+import pytest
+
+import silx
from silx.utils.testutils import ParametricTestCase
from silx.gui import qt
from silx.gui import colors
@@ -38,38 +41,39 @@ from silx.gui.plot import items
from silx.utils.exceptions import NotEditableError
-class TestColor(ParametricTestCase):
- """Basic tests of rgba function"""
-
- TEST_COLORS = { # name: (colors, expected values)
- 'blue': ('blue', (0., 0., 1., 1.)),
- '#010203': ('#010203', (1. / 255., 2. / 255., 3. / 255., 1.)),
- '#01020304': ('#01020304', (1. / 255., 2. / 255., 3. / 255., 4. / 255.)),
- '3 x uint8': (numpy.array((1, 255, 0), dtype=numpy.uint8),
- (1 / 255., 1., 0., 1.)),
- '4 x uint8': (numpy.array((1, 255, 0, 1), dtype=numpy.uint8),
- (1 / 255., 1., 0., 1 / 255.)),
- '3 x float overflow': ((3., 0.5, 1.), (1., 0.5, 1., 1.)),
- }
-
- def testRGBA(self):
- """"Test rgba function with accepted values"""
- for name, test in self.TEST_COLORS.items():
- color, expected = test
- with self.subTest(msg=name):
- result = colors.rgba(color)
- self.assertEqual(result, expected)
-
- def testQColor(self):
- """"Test getQColor function with accepted values"""
- for name, test in self.TEST_COLORS.items():
- color, expected = test
- with self.subTest(msg=name):
- result = colors.asQColor(color)
- self.assertAlmostEqual(result.redF(), expected[0], places=4)
- self.assertAlmostEqual(result.greenF(), expected[1], places=4)
- self.assertAlmostEqual(result.blueF(), expected[2], places=4)
- self.assertAlmostEqual(result.alphaF(), expected[3], places=4)
+RGBA_TEST_CASES = (
+ # name
+ ("blue", (0.0, 0.0, 1.0, 1.0)),
+ # code
+ ("#010203", (1.0 / 255.0, 2.0 / 255.0, 3.0 / 255.0, 1.0)),
+ ("#01020304", (1.0 / 255.0, 2.0 / 255.0, 3.0 / 255.0, 4.0 / 255.0)),
+ # index name
+ ("C0", colors.rgba(silx.config.DEFAULT_PLOT_CURVE_COLORS[0])),
+ ("C2", colors.rgba(silx.config.DEFAULT_PLOT_CURVE_COLORS[2])),
+ # 3 uint
+ (numpy.array((1, 255, 0), dtype=numpy.uint8), (1 / 255.0, 1.0, 0.0, 1.0)),
+ # 4 uint
+ (numpy.array((1, 255, 0, 1), dtype=numpy.uint8), (1 / 255.0, 1.0, 0.0, 1 / 255.0)),
+ # float with overflow
+ ((3.0, 0.5, 1.0), (1.0, 0.5, 1.0, 1.0)),
+)
+
+
+@pytest.mark.parametrize("input, expected", RGBA_TEST_CASES)
+def testRgba(input, expected):
+ """Test rgba function with accepted values"""
+ result = colors.rgba(input)
+ assert result == expected
+
+
+@pytest.mark.parametrize("input, expected", RGBA_TEST_CASES)
+def testAsQColor(input, expected):
+ """Test asQColor function with accepted values"""
+ result = colors.asQColor(input)
+ assert result.redF() == pytest.approx(expected[0], abs=1e-5)
+ assert result.greenF() == pytest.approx(expected[1], abs=1e-5)
+ assert result.blueF() == pytest.approx(expected[2], abs=1e-5)
+ assert result.alphaF() == pytest.approx(expected[3], abs=1e-5)
class TestApplyColormapToData(ParametricTestCase):
@@ -77,24 +81,23 @@ class TestApplyColormapToData(ParametricTestCase):
def testApplyColormapToData(self):
"""Simple test of applyColormapToData function"""
- colormap = Colormap(name='gray', normalization='linear',
- vmin=0, vmax=255)
+ colormap = Colormap(name="gray", normalization="linear", vmin=0, vmax=255)
size = 10
- expected = numpy.empty((size, 4), dtype='uint8')
- expected[:, 0] = numpy.arange(size, dtype='uint8')
+ expected = numpy.empty((size, 4), dtype="uint8")
+ expected[:, 0] = numpy.arange(size, dtype="uint8")
expected[:, 1] = expected[:, 0]
expected[:, 2] = expected[:, 0]
expected[:, 3] = 255
- for dtype in ('uint8', 'int32', 'float32', 'float64'):
+ for dtype in ("uint8", "int32", "float32", "float64"):
with self.subTest(dtype=dtype):
array = numpy.arange(size, dtype=dtype)
result = colormap.applyToData(data=array)
self.assertTrue(numpy.all(numpy.equal(result, expected)))
def testAutoscaleFromDataReference(self):
- colormap = Colormap(name='gray', normalization='linear')
+ colormap = Colormap(name="gray", normalization="linear")
data = numpy.array([50])
reference = numpy.array([0, 100])
value = colormap.applyToData(data, reference)
@@ -102,7 +105,7 @@ class TestApplyColormapToData(ParametricTestCase):
self.assertEqual(value[0, 0], 128)
def testAutoscaleFromItemReference(self):
- colormap = Colormap(name='gray', normalization='linear')
+ colormap = Colormap(name="gray", normalization="linear")
data = numpy.array([50])
image = items.ImageData()
image.setData(numpy.array([[0, 100]]))
@@ -112,11 +115,11 @@ class TestApplyColormapToData(ParametricTestCase):
def testNaNColor(self):
"""Test Colormap.applyToData with NaN values"""
- colormap = Colormap(name='gray', normalization='linear')
- colormap.setNaNColor('red')
+ colormap = Colormap(name="gray", normalization="linear")
+ colormap.setNaNColor("red")
self.assertEqual(colormap.getNaNColor(), qt.QColor(255, 0, 0))
- data = numpy.array([50., numpy.nan])
+ data = numpy.array([50.0, numpy.nan])
image = items.ImageData()
image.setData(numpy.array([[0, 100]]))
value = colormap.applyToData(data, reference=image)
@@ -126,8 +129,7 @@ class TestApplyColormapToData(ParametricTestCase):
class TestDictAPI(unittest.TestCase):
- """Make sure the old dictionary API is working
- """
+ """Make sure the old dictionary API is working"""
def setUp(self):
self.vmin = -1.0
@@ -135,75 +137,79 @@ class TestDictAPI(unittest.TestCase):
def testGetItem(self):
"""test the item getter API ([xxx])"""
- colormap = Colormap(name='viridis',
- normalization=Colormap.LINEAR,
- vmin=self.vmin,
- vmax=self.vmax)
- self.assertTrue(colormap['name'] == 'viridis')
- self.assertTrue(colormap['normalization'] == Colormap.LINEAR)
- self.assertTrue(colormap['vmin'] == self.vmin)
- self.assertTrue(colormap['vmax'] == self.vmax)
+ colormap = Colormap(
+ name="viridis",
+ normalization=Colormap.LINEAR,
+ vmin=self.vmin,
+ vmax=self.vmax,
+ )
+ self.assertTrue(colormap["name"] == "viridis")
+ self.assertTrue(colormap["normalization"] == Colormap.LINEAR)
+ self.assertTrue(colormap["vmin"] == self.vmin)
+ self.assertTrue(colormap["vmax"] == self.vmax)
with self.assertRaises(KeyError):
- colormap['toto']
+ colormap["toto"]
def testGetDict(self):
"""Test the getDict function API"""
- clmObject = Colormap(name='viridis',
- normalization=Colormap.LINEAR,
- vmin=self.vmin,
- vmax=self.vmax)
+ clmObject = Colormap(
+ name="viridis",
+ normalization=Colormap.LINEAR,
+ vmin=self.vmin,
+ vmax=self.vmax,
+ )
clmDict = clmObject._toDict()
- self.assertTrue(clmDict['name'] == 'viridis')
- self.assertTrue(clmDict['autoscale'] is False)
- self.assertTrue(clmDict['vmin'] == self.vmin)
- self.assertTrue(clmDict['vmax'] == self.vmax)
- self.assertTrue(clmDict['normalization'] == Colormap.LINEAR)
+ self.assertTrue(clmDict["name"] == "viridis")
+ self.assertTrue(clmDict["autoscale"] is False)
+ self.assertTrue(clmDict["vmin"] == self.vmin)
+ self.assertTrue(clmDict["vmax"] == self.vmax)
+ self.assertTrue(clmDict["normalization"] == Colormap.LINEAR)
clmObject.setVRange(None, None)
- self.assertTrue(clmObject._toDict()['autoscale'] is True)
+ self.assertTrue(clmObject._toDict()["autoscale"] is True)
def testSetValidDict(self):
"""Test that if a colormap is created from a dict then it is correctly
created and the values are copied (so if some values from the dict
is changing, this won't affect the Colormap object"""
clm_dict = {
- 'name': 'temperature',
- 'vmin': 1.0,
- 'vmax': 2.0,
- 'normalization': 'linear',
- 'colors': None,
- 'autoscale': False
+ "name": "temperature",
+ "vmin": 1.0,
+ "vmax": 2.0,
+ "normalization": "linear",
+ "colors": None,
+ "autoscale": False,
}
# Test that the colormap is correctly created
colormapObject = Colormap._fromDict(clm_dict)
- self.assertTrue(colormapObject.getName() == clm_dict['name'])
- self.assertTrue(colormapObject.getColormapLUT() == clm_dict['colors'])
- self.assertTrue(colormapObject.getVMin() == clm_dict['vmin'])
- self.assertTrue(colormapObject.getVMax() == clm_dict['vmax'])
- self.assertTrue(colormapObject.isAutoscale() == clm_dict['autoscale'])
+ self.assertTrue(colormapObject.getName() == clm_dict["name"])
+ self.assertTrue(colormapObject.getColormapLUT() == clm_dict["colors"])
+ self.assertTrue(colormapObject.getVMin() == clm_dict["vmin"])
+ self.assertTrue(colormapObject.getVMax() == clm_dict["vmax"])
+ self.assertTrue(colormapObject.isAutoscale() == clm_dict["autoscale"])
# Check that the colormap has copied the values
- clm_dict['vmin'] = None
- clm_dict['vmax'] = None
- clm_dict['colors'] = [1.0, 2.0]
- clm_dict['autoscale'] = True
- clm_dict['normalization'] = Colormap.LOGARITHM
- clm_dict['name'] = 'viridis'
-
- self.assertFalse(colormapObject.getName() == clm_dict['name'])
- self.assertFalse(colormapObject.getColormapLUT() == clm_dict['colors'])
- self.assertFalse(colormapObject.getVMin() == clm_dict['vmin'])
- self.assertFalse(colormapObject.getVMax() == clm_dict['vmax'])
- self.assertFalse(colormapObject.isAutoscale() == clm_dict['autoscale'])
+ clm_dict["vmin"] = None
+ clm_dict["vmax"] = None
+ clm_dict["colors"] = [1.0, 2.0]
+ clm_dict["autoscale"] = True
+ clm_dict["normalization"] = Colormap.LOGARITHM
+ clm_dict["name"] = "viridis"
+
+ self.assertFalse(colormapObject.getName() == clm_dict["name"])
+ self.assertFalse(colormapObject.getColormapLUT() == clm_dict["colors"])
+ self.assertFalse(colormapObject.getVMin() == clm_dict["vmin"])
+ self.assertFalse(colormapObject.getVMax() == clm_dict["vmax"])
+ self.assertFalse(colormapObject.isAutoscale() == clm_dict["autoscale"])
def testMissingKeysFromDict(self):
"""Make sure we can create a Colormap object from a dictionary even if
there is missing keys except if those keys are 'colors' or 'name'
"""
- colormap = Colormap._fromDict({'name': 'blue'})
+ colormap = Colormap._fromDict({"name": "blue"})
self.assertTrue(colormap.getVMin() is None)
- colormap = Colormap._fromDict({'colors': numpy.zeros((5, 3))})
+ colormap = Colormap._fromDict({"colors": numpy.zeros((5, 3))})
self.assertTrue(colormap.getName() is None)
with self.assertRaises(ValueError):
@@ -214,12 +220,12 @@ class TestDictAPI(unittest.TestCase):
knowed
"""
clm_dict = {
- 'name': 'temperature',
- 'vmin': 1.0,
- 'vmax': 2.0,
- 'normalization': 'toto',
- 'colors': None,
- 'autoscale': False
+ "name": "temperature",
+ "vmin": 1.0,
+ "vmax": 2.0,
+ "normalization": "toto",
+ "colors": None,
+ "autoscale": False,
}
with self.assertRaises(ValueError):
Colormap._fromDict(clm_dict)
@@ -227,26 +233,26 @@ class TestDictAPI(unittest.TestCase):
def testNumericalColors(self):
"""Make sure the old API using colors=int was supported"""
clm_dict = {
- 'name': 'temperature',
- 'vmin': 1.0,
- 'vmax': 2.0,
- 'colors': 256,
- 'autoscale': False
+ "name": "temperature",
+ "vmin": 1.0,
+ "vmax": 2.0,
+ "colors": 256,
+ "autoscale": False,
}
Colormap._fromDict(clm_dict)
class TestObjectAPI(ParametricTestCase):
"""Test the new Object API of the colormap"""
+
def testVMinVMax(self):
"""Test getter and setter associated to vmin and vmax values"""
vmin = 1.0
vmax = 2.0
- colormapObject = Colormap(name='viridis',
- vmin=vmin,
- vmax=vmax,
- normalization=Colormap.LINEAR)
+ colormapObject = Colormap(
+ name="viridis", vmin=vmin, vmax=vmax, normalization=Colormap.LINEAR
+ )
with self.assertRaises(ValueError):
colormapObject.setVMin(3)
@@ -265,15 +271,14 @@ class TestObjectAPI(ParametricTestCase):
self.assertTrue(colormapObject.isAutoscale() is True)
def testCopy(self):
- """Make sure the copy function is correctly processing
- """
- colormapObject = Colormap(name=None,
- colors=numpy.array([[1., 0., 0.],
- [0., 1., 0.],
- [0., 0., 1.]]),
- vmin=None,
- vmax=None,
- normalization=Colormap.LOGARITHM)
+ """Make sure the copy function is correctly processing"""
+ colormapObject = Colormap(
+ name=None,
+ colors=numpy.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]),
+ vmin=None,
+ vmax=None,
+ normalization=Colormap.LOGARITHM,
+ )
colormapObject2 = colormapObject.copy()
self.assertTrue(colormapObject == colormapObject2)
@@ -290,23 +295,11 @@ class TestObjectAPI(ParametricTestCase):
applying
"""
# test linear scale
- data = numpy.array([-1, 1, 2, 3, float('nan')])
- cl1 = Colormap(name='gray',
- normalization=Colormap.LINEAR,
- vmin=0,
- vmax=2)
- cl2 = Colormap(name='gray',
- normalization=Colormap.LINEAR,
- vmin=None,
- vmax=2)
- cl3 = Colormap(name='gray',
- normalization=Colormap.LINEAR,
- vmin=0,
- vmax=None)
- cl4 = Colormap(name='gray',
- normalization=Colormap.LINEAR,
- vmin=None,
- vmax=None)
+ data = numpy.array([-1, 1, 2, 3, float("nan")])
+ cl1 = Colormap(name="gray", normalization=Colormap.LINEAR, vmin=0, vmax=2)
+ cl2 = Colormap(name="gray", normalization=Colormap.LINEAR, vmin=None, vmax=2)
+ cl3 = Colormap(name="gray", normalization=Colormap.LINEAR, vmin=0, vmax=None)
+ cl4 = Colormap(name="gray", normalization=Colormap.LINEAR, vmin=None, vmax=None)
self.assertTrue(cl1.getColormapRange(data) == (0, 2))
self.assertTrue(cl2.getColormapRange(data) == (-1, 2))
@@ -315,30 +308,23 @@ class TestObjectAPI(ParametricTestCase):
# test linear with annoying cases
self.assertEqual(cl3.getColormapRange((-1, -2)), (0, 0))
- self.assertEqual(cl4.getColormapRange(()), (0., 1.))
- self.assertEqual(cl4.getColormapRange(
- (float('nan'), float('inf'), 1., -float('inf'), 2)), (1., 2.))
- self.assertEqual(cl4.getColormapRange(
- (float('nan'), float('inf'))), (0., 1.))
+ self.assertEqual(cl4.getColormapRange(()), (0.0, 1.0))
+ self.assertEqual(
+ cl4.getColormapRange((float("nan"), float("inf"), 1.0, -float("inf"), 2)),
+ (1.0, 2.0),
+ )
+ self.assertEqual(cl4.getColormapRange((float("nan"), float("inf"))), (0.0, 1.0))
# test log scale
- data = numpy.array([float('nan'), -1, 1, 10, 100, 1000])
- cl1 = Colormap(name='gray',
- normalization=Colormap.LOGARITHM,
- vmin=1,
- vmax=100)
- cl2 = Colormap(name='gray',
- normalization=Colormap.LOGARITHM,
- vmin=None,
- vmax=100)
- cl3 = Colormap(name='gray',
- normalization=Colormap.LOGARITHM,
- vmin=1,
- vmax=None)
- cl4 = Colormap(name='gray',
- normalization=Colormap.LOGARITHM,
- vmin=None,
- vmax=None)
+ data = numpy.array([float("nan"), -1, 1, 10, 100, 1000])
+ cl1 = Colormap(name="gray", normalization=Colormap.LOGARITHM, vmin=1, vmax=100)
+ cl2 = Colormap(
+ name="gray", normalization=Colormap.LOGARITHM, vmin=None, vmax=100
+ )
+ cl3 = Colormap(name="gray", normalization=Colormap.LOGARITHM, vmin=1, vmax=None)
+ cl4 = Colormap(
+ name="gray", normalization=Colormap.LOGARITHM, vmin=None, vmax=None
+ )
self.assertTrue(cl1.getColormapRange(data) == (1, 100))
self.assertTrue(cl2.getColormapRange(data) == (1, 100))
@@ -347,12 +333,15 @@ class TestObjectAPI(ParametricTestCase):
# test log with annoying cases
self.assertEqual(cl3.getColormapRange((0.1, 0.2)), (1, 1))
- self.assertEqual(cl4.getColormapRange((-2., -1.)), (1., 1.))
- self.assertEqual(cl4.getColormapRange(()), (1., 10.))
- self.assertEqual(cl4.getColormapRange(
- (float('nan'), float('inf'), 1., -float('inf'), 2)), (1., 2.))
- self.assertEqual(cl4.getColormapRange(
- (float('nan'), float('inf'))), (1., 10.))
+ self.assertEqual(cl4.getColormapRange((-2.0, -1.0)), (1.0, 1.0))
+ self.assertEqual(cl4.getColormapRange(()), (1.0, 10.0))
+ self.assertEqual(
+ cl4.getColormapRange((float("nan"), float("inf"), 1.0, -float("inf"), 2)),
+ (1.0, 2.0),
+ )
+ self.assertEqual(
+ cl4.getColormapRange((float("nan"), float("inf"))), (1.0, 10.0)
+ )
def testApplyToData(self):
"""Test applyToData on different datasets"""
@@ -362,11 +351,10 @@ class TestObjectAPI(ParametricTestCase):
numpy.array((-numpy.inf, numpy.inf, 1.0, 2.0)), # Some infinite
]
- for normalization in ('linear', 'log'):
- colormap = Colormap(name='gray',
- normalization=normalization,
- vmin=None,
- vmax=None)
+ for normalization in ("linear", "log"):
+ colormap = Colormap(
+ name="gray", normalization=normalization, vmin=None, vmax=None
+ )
for data in datasets:
with self.subTest(data=data):
@@ -378,14 +366,13 @@ class TestObjectAPI(ParametricTestCase):
def testGetNColors(self):
"""Test getNColors method"""
# specific LUT
- colormap = Colormap(name=None,
- colors=((0., 0., 0.), (1., 1., 1.)),
- vmin=1000,
- vmax=2000)
+ colormap = Colormap(
+ name=None, colors=((0.0, 0.0, 0.0), (1.0, 1.0, 1.0)), vmin=1000, vmax=2000
+ )
colors = colormap.getNColors()
- self.assertTrue(numpy.all(numpy.equal(
- colors,
- ((0, 0, 0, 255), (255, 255, 255, 255)))))
+ self.assertTrue(
+ numpy.all(numpy.equal(colors, ((0, 0, 0, 255), (255, 255, 255, 255))))
+ )
def testEditableMode(self):
"""Make sure the colormap will raise NotEditableError when try to
@@ -393,17 +380,17 @@ class TestObjectAPI(ParametricTestCase):
colormap = Colormap()
colormap.setEditable(False)
with self.assertRaises(NotEditableError):
- colormap.setVRange(0., 1.)
+ colormap.setVRange(0.0, 1.0)
with self.assertRaises(NotEditableError):
- colormap.setVMin(1.)
+ colormap.setVMin(1.0)
with self.assertRaises(NotEditableError):
- colormap.setVMax(1.)
+ colormap.setVMax(1.0)
with self.assertRaises(NotEditableError):
colormap.setNormalization(Colormap.LOGARITHM)
with self.assertRaises(NotEditableError):
- colormap.setName('magma')
+ colormap.setName("magma")
with self.assertRaises(NotEditableError):
- colormap.setColormapLUT([[0., 0., 0.], [1., 1., 1.]])
+ colormap.setColormapLUT([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]])
with self.assertRaises(NotEditableError):
colormap._setFromDict(colormap._toDict())
state = colormap.saveState()
@@ -430,7 +417,9 @@ class TestObjectAPI(ParametricTestCase):
def testSet(self):
colormap = Colormap()
- other = Colormap(name="viridis", vmin=1, vmax=2, normalization=Colormap.LOGARITHM)
+ other = Colormap(
+ name="viridis", vmin=1, vmax=2, normalization=Colormap.LOGARITHM
+ )
self.assertNotEqual(colormap, other)
colormap.setFromColormap(other)
self.assertIsNot(colormap, other)
@@ -443,13 +432,10 @@ class TestObjectAPI(ParametricTestCase):
self.assertEqual(colormap.getAutoscaleMode(), Colormap.MINMAX)
def testStoreRestore(self):
- colormaps = [
- Colormap(name="viridis"),
- Colormap(normalization=Colormap.SQRT)
- ]
+ colormaps = [Colormap(name="viridis"), Colormap(normalization=Colormap.SQRT)]
cmap = Colormap(normalization=Colormap.GAMMA)
cmap.setGammaNormalizationParameter(1.2)
- cmap.setNaNColor('red')
+ cmap.setNaNColor("red")
colormaps.append(cmap)
for expected in colormaps:
with self.subTest(colormap=expected):
@@ -459,29 +445,37 @@ class TestObjectAPI(ParametricTestCase):
self.assertEqual(expected, result)
def testStorageV1(self):
- state = b'\x00\x00\x00\x10\x00C\x00o\x00l\x00o\x00r\x00m\x00a\x00p\x00\x00'\
- b'\x00\x01\x00\x00\x00\x0E\x00v\x00i\x00r\x00i\x00d\x00i\x00s\x00'\
- b'\x00\x00\x00\x06\x00?\xF0\x00\x00\x00\x00\x00\x00\x00\x00\x00'\
- b'\x00\x06\x00@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x06\x00'\
- b'l\x00o\x00g'
+ state = (
+ b"\x00\x00\x00\x10\x00C\x00o\x00l\x00o\x00r\x00m\x00a\x00p\x00\x00"
+ b"\x00\x01\x00\x00\x00\x0E\x00v\x00i\x00r\x00i\x00d\x00i\x00s\x00"
+ b"\x00\x00\x00\x06\x00?\xF0\x00\x00\x00\x00\x00\x00\x00\x00\x00"
+ b"\x00\x06\x00@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x06\x00"
+ b"l\x00o\x00g"
+ )
state = qt.QByteArray(state)
colormap = Colormap()
colormap.restoreState(state)
- expected = Colormap(name="viridis", vmin=1, vmax=2, normalization=Colormap.LOGARITHM)
+ expected = Colormap(
+ name="viridis", vmin=1, vmax=2, normalization=Colormap.LOGARITHM
+ )
self.assertEqual(colormap, expected)
def testStorageV2(self):
- state = b'\x00\x00\x00\x10\x00C\x00o\x00l\x00o\x00r\x00m\x00a\x00p\x00'\
- b'\x00\x00\x02\x00\x00\x00\x0e\x00v\x00i\x00r\x00i\x00d\x00i\x00'\
- b's\x00\x00\x00\x00\x06\x00?\xf0\x00\x00\x00\x00\x00\x00\x00\x00'\
- b'\x00\x00\x06\x00@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x06'\
- b'\x00l\x00o\x00g\x00\x00\x00\x0c\x00m\x00i\x00n\x00m\x00a\x00x'
+ state = (
+ b"\x00\x00\x00\x10\x00C\x00o\x00l\x00o\x00r\x00m\x00a\x00p\x00"
+ b"\x00\x00\x02\x00\x00\x00\x0e\x00v\x00i\x00r\x00i\x00d\x00i\x00"
+ b"s\x00\x00\x00\x00\x06\x00?\xf0\x00\x00\x00\x00\x00\x00\x00\x00"
+ b"\x00\x00\x06\x00@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x06"
+ b"\x00l\x00o\x00g\x00\x00\x00\x0c\x00m\x00i\x00n\x00m\x00a\x00x"
+ )
state = qt.QByteArray(state)
colormap = Colormap()
colormap.restoreState(state)
- expected = Colormap(name="viridis", vmin=1, vmax=2, normalization=Colormap.LOGARITHM)
+ expected = Colormap(
+ name="viridis", vmin=1, vmax=2, normalization=Colormap.LOGARITHM
+ )
expected.setGammaNormalizationParameter(1.5)
self.assertEqual(colormap, expected)
@@ -498,7 +492,7 @@ class TestPreferredColormaps(unittest.TestCase):
colors.setPreferredColormaps(self._colormaps)
def test(self):
- colormaps = 'viridis', 'magma'
+ colormaps = "viridis", "magma"
colors.setPreferredColormaps(colormaps)
self.assertEqual(colors.preferredColormaps(), colormaps)
@@ -507,10 +501,10 @@ class TestPreferredColormaps(unittest.TestCase):
colors.setPreferredColormaps(())
with self.assertRaises(ValueError):
- colors.setPreferredColormaps(('This is not a colormap',))
+ colors.setPreferredColormaps(("This is not a colormap",))
- colormaps = 'red', 'green'
- colors.setPreferredColormaps(('This is not a colormap',) + colormaps)
+ colormaps = "red", "green"
+ colors.setPreferredColormaps(("This is not a colormap",) + colormaps)
self.assertEqual(colors.preferredColormaps(), colormaps)
@@ -522,7 +516,7 @@ class TestRegisteredLut(unittest.TestCase):
lut = numpy.arange(8 * 3)
lut.shape = -1, 3
lut = lut / (8.0 * 3)
- colors.registerLUT("test_8", colors=lut, cursor_color='blue')
+ colors.registerLUT("test_8", colors=lut, cursor_color="blue")
def testColormap(self):
colormap = Colormap("test_8")
@@ -530,7 +524,7 @@ class TestRegisteredLut(unittest.TestCase):
def testCursor(self):
color = colors.cursorColorForColormap("test_8")
- self.assertEqual(color, 'blue')
+ self.assertEqual(color, "blue")
def testLut(self):
colormap = Colormap("test_8")
@@ -554,7 +548,10 @@ class TestRegisteredLut(unittest.TestCase):
self.assertEqual(lut[0, 0], 255)
def testFloatRGBA(self):
- lut = numpy.array([[1.0, 0, 0, 128 / 256.0], [0.5, 0, 0, 1.0], [0.0, 0, 0, 1.0]], dtype="float")
+ lut = numpy.array(
+ [[1.0, 0, 0, 128 / 256.0], [0.5, 0, 0, 1.0], [0.0, 0, 0, 1.0]],
+ dtype="float",
+ )
colors.registerLUT("test_type", lut)
colormap = colors.Colormap(name="test_type")
lut = colormap.getNColors(3)
@@ -564,28 +561,75 @@ class TestRegisteredLut(unittest.TestCase):
class TestAutoscaleRange(ParametricTestCase):
-
def testAutoscaleRange(self):
nan = numpy.nan
- data_std_inside = numpy.array([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2])
- data_std_inside_nan = numpy.array([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, numpy.nan])
+ data_std_inside = numpy.array(
+ [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2]
+ )
+ data_std_inside_nan = numpy.array(
+ [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, numpy.nan]
+ )
data = [
# Positive values
(Colormap.LINEAR, Colormap.MINMAX, numpy.array([10, 20, 50]), (10, 50)),
- (Colormap.LOGARITHM, Colormap.MINMAX, numpy.array([10, 50, 100]), (10, 100)),
- (Colormap.LINEAR, Colormap.STDDEV3, data_std_inside, (0.026671473215424735, 1.9733285267845753)),
- (Colormap.LOGARITHM, Colormap.STDDEV3, data_std_inside, (1, 1.6733506885453602)),
+ (
+ Colormap.LOGARITHM,
+ Colormap.MINMAX,
+ numpy.array([10, 50, 100]),
+ (10, 100),
+ ),
+ (
+ Colormap.LINEAR,
+ Colormap.STDDEV3,
+ data_std_inside,
+ (0.026671473215424735, 1.9733285267845753),
+ ),
+ (
+ Colormap.LOGARITHM,
+ Colormap.STDDEV3,
+ data_std_inside,
+ (1, 1.6733506885453602),
+ ),
(Colormap.LINEAR, Colormap.STDDEV3, numpy.array([10, 100]), (10, 100)),
(Colormap.LOGARITHM, Colormap.STDDEV3, numpy.array([10, 100]), (10, 100)),
-
# With nan
- (Colormap.LINEAR, Colormap.MINMAX, numpy.array([10, 20, 50, nan]), (10, 50)),
- (Colormap.LOGARITHM, Colormap.MINMAX, numpy.array([10, 50, 100, nan]), (10, 100)),
- (Colormap.LINEAR, Colormap.STDDEV3, data_std_inside_nan, (0.026671473215424735, 1.9733285267845753)),
- (Colormap.LOGARITHM, Colormap.STDDEV3, data_std_inside_nan, (1, 1.6733506885453602)),
+ (
+ Colormap.LINEAR,
+ Colormap.MINMAX,
+ numpy.array([10, 20, 50, nan]),
+ (10, 50),
+ ),
+ (
+ Colormap.LOGARITHM,
+ Colormap.MINMAX,
+ numpy.array([10, 50, 100, nan]),
+ (10, 100),
+ ),
+ (
+ Colormap.LINEAR,
+ Colormap.STDDEV3,
+ data_std_inside_nan,
+ (0.026671473215424735, 1.9733285267845753),
+ ),
+ (
+ Colormap.LOGARITHM,
+ Colormap.STDDEV3,
+ data_std_inside_nan,
+ (1, 1.6733506885453602),
+ ),
# With negative
- (Colormap.LOGARITHM, Colormap.MINMAX, numpy.array([10, 50, 100, -50]), (10, 100)),
- (Colormap.LOGARITHM, Colormap.STDDEV3, numpy.array([10, 100, -10]), (10, 100)),
+ (
+ Colormap.LOGARITHM,
+ Colormap.MINMAX,
+ numpy.array([10, 50, 100, -50]),
+ (10, 100),
+ ),
+ (
+ Colormap.LOGARITHM,
+ Colormap.STDDEV3,
+ numpy.array([10, 100, -10]),
+ (10, 100),
+ ),
]
for norm, mode, array, expectedRange in data:
with self.subTest(norm=norm, mode=mode, array=array):
diff --git a/src/silx/gui/test/test_console.py b/src/silx/gui/test/test_console.py
index f636287..4a25fe3 100644
--- a/src/silx/gui/test/test_console.py
+++ b/src/silx/gui/test/test_console.py
@@ -51,8 +51,8 @@ def console(qapp_utils):
pytest.skip("IPythonDockWidget is not available")
console = IPythonDockWidget(
- available_vars={"a": _a, "f": _f},
- custom_banner="Welcome!\n")
+ available_vars={"a": _a, "f": _f}, custom_banner="Welcome!\n"
+ )
console.show()
qapp_utils.qWaitForWindowExposed(console)
yield console
@@ -67,6 +67,6 @@ def testShow(console):
def testInteract(console, qapp_utils):
qapp_utils.mouseClick(console, qt.Qt.LeftButton)
- qapp_utils.keyClicks(console, 'import silx')
+ qapp_utils.keyClicks(console, "import silx")
qapp_utils.keyClick(console, qt.Qt.Key_Enter)
qapp_utils.qapp.processEvents()
diff --git a/src/silx/gui/test/test_icons.py b/src/silx/gui/test/test_icons.py
index 59c7e00..6797398 100644
--- a/src/silx/gui/test/test_icons.py
+++ b/src/silx/gui/test/test_icons.py
@@ -51,8 +51,12 @@ class TestIcons(TestCaseQt):
os.mkdir(os.path.join(cls.tmpDirectory, "gui"))
destination = os.path.join(cls.tmpDirectory, "gui", "icons")
os.mkdir(destination)
- shutil.copy(silx.resources.resource_filename("gui/icons/zoom-in.png"), destination)
- shutil.copy(silx.resources.resource_filename("gui/icons/zoom-out.svg"), destination)
+ shutil.copy(
+ silx.resources.resource_filename("gui/icons/zoom-in.png"), destination
+ )
+ shutil.copy(
+ silx.resources.resource_filename("gui/icons/zoom-out.svg"), destination
+ )
@classmethod
def tearDownClass(cls):
@@ -62,7 +66,9 @@ class TestIcons(TestCaseQt):
def setUp(self):
# Store the original configuration
self._oldResources = dict(silx.resources._RESOURCE_DIRECTORIES)
- silx.resources.register_resource_directory("test", "foo.bar", forced_path=self.tmpDirectory)
+ silx.resources.register_resource_directory(
+ "test", "foo.bar", forced_path=self.tmpDirectory
+ )
unittest.TestCase.setUp(self)
def tearDown(self):
diff --git a/src/silx/gui/test/test_qt.py b/src/silx/gui/test/test_qt.py
index 692d7f7..17bdc72 100644
--- a/src/silx/gui/test/test_qt.py
+++ b/src/silx/gui/test/test_qt.py
@@ -36,6 +36,7 @@ from silx.test.utils import temp_dir
from silx.gui.utils.testutils import TestCaseQt
from silx.gui import qt
+
try:
from silx.gui.qt import inspect as qt_inspect
except ImportError:
@@ -146,7 +147,7 @@ class TestLoadUi(TestCaseQt):
uifile = os.path.join(tmp, "test.ui")
# write file
- with open(uifile, mode='w') as f:
+ with open(uifile, mode="w") as f:
f.write(self.TEST_UI)
class TestMainWindow(qt.QMainWindow):
@@ -185,12 +186,11 @@ class TestQtInspect(unittest.TestCase):
self.assertFalse(qt_inspect.isValid(obj))
-@pytest.mark.skipif(qt.BINDING not in ("PyQt5", "PySide2"),
- reason="PyQt5/PySide2 only test")
+@pytest.mark.skipif(qt.BINDING != "PyQt5", reason="PyQt5 only test")
def test_exec_():
"""Test the exec_ is still useable with Qt5 bindings"""
klasses = [
- #QtWidgets
+ # QtWidgets
qt.QApplication,
qt.QColorDialog,
qt.QDialog,
@@ -201,11 +201,15 @@ def test_exec_():
qt.QMenu,
qt.QMessageBox,
qt.QProgressDialog,
- #QtCore
+ # QtCore
qt.QCoreApplication,
qt.QEventLoop,
qt.QThread,
]
for klass in klasses:
- assert hasattr(klass, "exec") and callable(klass.exec), "%s.exec missing" % klass.__name__
- assert hasattr(klass, "exec_") and callable(klass.exec_), "%s.exec_ missing" % klass.__name__
+ assert hasattr(klass, "exec") and callable(klass.exec), (
+ "%s.exec missing" % klass.__name__
+ )
+ assert hasattr(klass, "exec_") and callable(klass.exec_), (
+ "%s.exec_ missing" % klass.__name__
+ )
diff --git a/src/silx/gui/test/utils.py b/src/silx/gui/test/utils.py
deleted file mode 100644
index 1cfee67..0000000
--- a/src/silx/gui/test/utils.py
+++ /dev/null
@@ -1,40 +0,0 @@
-# /*##########################################################################
-#
-# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Color conversion function, color dictionary and colormap tools."""
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "05/10/2018"
-
-import silx.utils.deprecation
-
-silx.utils.deprecation.deprecated_warning("Module",
- name="silx.gui.test.utils",
- reason="moved",
- replacement="silx.gui.utils.testutils",
- since_version="0.9.0",
- only_once=True,
- skip_backtrace_count=1)
-
-from ..utils.testutils import * # noqa
diff --git a/src/silx/gui/utils/__init__.py b/src/silx/gui/utils/__init__.py
index 4fae646..248aa16 100755
--- a/src/silx/gui/utils/__init__.py
+++ b/src/silx/gui/utils/__init__.py
@@ -47,9 +47,9 @@ def blockSignals(*objs):
obj.blockSignals(previous)
-class LockReentrant():
- """Context manager to lock a code block and check the state.
- """
+class LockReentrant:
+ """Context manager to lock a code block and check the state."""
+
def __init__(self):
self.__locked = False
@@ -72,4 +72,5 @@ def getQEventName(eventType):
:returns: str
"""
from . import qtutils
+
return qtutils.getQEventName(eventType)
diff --git a/src/silx/gui/utils/glutils/__init__.py b/src/silx/gui/utils/glutils/__init__.py
index 2651402..8e34605 100644
--- a/src/silx/gui/utils/glutils/__init__.py
+++ b/src/silx/gui/utils/glutils/__init__.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2020-2021 European Synchrotron Radiation Facility
+# Copyright (c) 2020-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -23,6 +23,8 @@
# ###########################################################################*/
"""This module provides the :func:`isOpenGLAvailable` utility function.
"""
+from __future__ import annotations
+
import os
import sys
@@ -37,9 +39,9 @@ class _isOpenGLAvailableResult:
an `error` string attribute storting the possible error message.
"""
- def __init__(self, status=True, error=''):
- self.__status = bool(status)
+ def __init__(self, error: str = "", status: bool = False):
self.__error = str(error)
+ self.__status = bool(status)
status = property(lambda self: self.__status, doc="True if OpenGL is working")
error = property(lambda self: self.__error, doc="Error message")
@@ -48,101 +50,119 @@ class _isOpenGLAvailableResult:
return self.status
def __repr__(self):
- return '<_isOpenGLAvailableResult: %s, "%s">' % (self.status, self.error)
+ return f'<_isOpenGLAvailableResult: {self.status}, "{self.error}">'
-def _runtimeOpenGLCheck(version, shareOpenGLContexts):
+def _runtimeOpenGLCheck(
+ version: tuple[int, int],
+ shareOpenGLContexts: bool,
+) -> _isOpenGLAvailableResult:
"""Run OpenGL check in a subprocess.
This is done by starting a subprocess that displays a Qt OpenGL widget.
- :param List[int] version:
+ :param version:
The minimal required OpenGL version as a 2-tuple (major, minor).
- :param bool shareOpenGLContexts:
+ :param shareOpenGLContexts:
True to test the `QApplication` with `AA_ShareOpenGLContexts`.
- :return: An error string that is empty if no error occured
- :rtype: str
+ :return: Result status and error message
"""
major, minor = str(version[0]), str(version[1])
env = os.environ.copy()
- env['PYTHONPATH'] = os.pathsep.join(
- [os.path.abspath(p) for p in sys.path])
+ env["PYTHONPATH"] = os.pathsep.join([os.path.abspath(p) for p in sys.path])
+
+ cmd = [sys.executable, "-s", "-S", __file__, major, minor]
+ if shareOpenGLContexts:
+ cmd.append("--shareOpenGLContexts")
try:
- cmd = [sys.executable, '-s', '-S', __file__, major, minor]
- if shareOpenGLContexts:
- cmd.append("--shareOpenGLContexts")
- error = subprocess.check_output(cmd, env=env, timeout=2)
+ output = subprocess.check_output(cmd, env=env, timeout=2)
except subprocess.TimeoutExpired:
- status = False
error = "Qt OpenGL widget hang"
- if sys.platform.startswith('linux'):
- error += ':\nIf connected remotely, GLX forwarding might be disabled.'
+ if sys.platform.startswith("linux"):
+ error += ":\nIf connected remotely, GLX forwarding might be disabled."
+ return _isOpenGLAvailableResult(error)
except subprocess.CalledProcessError as e:
- status = False
- error = "Qt OpenGL widget error: retcode=%d, error=%s" % (e.returncode, e.output)
- else:
- status = True
- error = error.decode()
- return _isOpenGLAvailableResult(status, error)
+ return _isOpenGLAvailableResult(
+ f"Qt OpenGL widget error: retcode={e.returncode}, error={e.output}"
+ )
+
+ return _isOpenGLAvailableResult(output.decode(), status=True)
_runtimeCheckCache = {} # Cache runtime check results: {version: result}
-def isOpenGLAvailable(version=(2, 1), runtimeCheck=True, shareOpenGLContexts=False):
+def isOpenGLAvailable(
+ version: tuple[int, int] = (2, 1),
+ runtimeCheck: bool = True,
+ shareOpenGLContexts: bool = False,
+) -> _isOpenGLAvailableResult:
"""Check if OpenGL is available through Qt and actually working.
After some basic tests, this is done by starting a subprocess that
displays a Qt OpenGL widget.
- :param List[int] version:
+ :param version:
The minimal required OpenGL version as a 2-tuple (major, minor).
Default: (2, 1)
- :param bool shareOpenGLContexts:
+ :param runtimeCheck:
+ True (default) to run the test creating a Qt OpenGL widget in a subprocess,
+ False to avoid this check.
+ :param shareOpenGLContexts:
True to test the `QApplication` with `AA_ShareOpenGLContexts`.
This only can be checked with `runtimeCheck` enabled.
Default is false.
- :param bool runtimeCheck:
- True (default) to run the test creating a Qt OpenGL widget in a subprocess,
- False to avoid this check.
:return: A result object that evaluates to True if successful and
which has a `status` boolean attribute (True if successful) and
an `error` string attribute that is not empty if `status` is False.
"""
- error = ''
-
- if sys.platform.startswith('linux') and not os.environ.get('DISPLAY', ''):
+ if sys.platform.startswith("linux") and not os.environ.get("DISPLAY", ""):
# On Linux and no DISPLAY available (e.g., ssh without -X)
- error = 'DISPLAY environment variable not set'
-
- else:
- # Check pyopengl availability
- try:
- import silx.gui._glutils.gl # noqa
- except ImportError:
- error = "Cannot import OpenGL wrapper: pyopengl is not installed"
- else:
- # Pre checks for Qt < 5.4
- if not hasattr(qt, 'QOpenGLWidget'):
- if not qt.HAS_OPENGL:
- error = '%s.QtOpenGL not available' % qt.BINDING
-
- elif qt.BINDING in ('PySide2', 'PyQt5') and qt.QApplication.instance() and not qt.QGLFormat.hasOpenGL():
- # qt.QGLFormat.hasOpenGL MUST be called with a QApplication created
- # so this is only checked if the QApplication is already created
- error = 'Qt reports OpenGL not available'
-
- result = _isOpenGLAvailableResult(error == '', error)
+ return _isOpenGLAvailableResult("DISPLAY environment variable not set")
+
+ # Check pyopengl availability
+ try:
+ from silx.gui._glutils import gl
+ except ImportError:
+ return _isOpenGLAvailableResult(
+ "Cannot import OpenGL wrapper: pyopengl is not installed"
+ )
+
+ # Pre checks for Qt < 5.4
+ if not hasattr(qt, "QOpenGLWidget"):
+ if not qt.HAS_OPENGL:
+ return _isOpenGLAvailableResult(f"{qt.BINDING}.QtOpenGL not available")
+
+ if (
+ qt.BINDING == "PyQt5"
+ and qt.QApplication.instance()
+ and not qt.QGLFormat.hasOpenGL()
+ ):
+ # qt.QGLFormat.hasOpenGL MUST be called with a QApplication created
+ # so this is only checked if the QApplication is already created
+ return _isOpenGLAvailableResult("Qt reports OpenGL not available")
+
+ # Check compatibility between Qt platform and pyopengl selected platform
+ qt_qpa_platform = qt.QGuiApplication.platformName()
+ pyopengl_platform = gl.getPlatform()
+ if (qt_qpa_platform == "wayland" and pyopengl_platform != "EGLPlatform") or (
+ qt_qpa_platform == "xcb" and pyopengl_platform != "GLXPlatform"
+ ):
+ return _isOpenGLAvailableResult(
+ f"Qt platform '{qt_qpa_platform}' is not compatible with PyOpenGL platform '{pyopengl_platform}'"
+ )
keyCache = version, shareOpenGLContexts
- if result: # No error so far, runtime check
- if keyCache in _runtimeCheckCache: # Use cache
- result = _runtimeCheckCache[keyCache]
- elif runtimeCheck: # Run test in subprocess
- result = _runtimeOpenGLCheck(version, shareOpenGLContexts)
- _runtimeCheckCache[keyCache] = result
+ if keyCache in _runtimeCheckCache: # Use cache
+ return _runtimeCheckCache[keyCache]
+
+ if not runtimeCheck:
+ return _isOpenGLAvailableResult(status=True)
+ # Run test in subprocess
+ result = _runtimeOpenGLCheck(version, shareOpenGLContexts)
+ _runtimeCheckCache[keyCache] = result
return result
@@ -154,15 +174,16 @@ if __name__ == "__main__":
class _TestOpenGLWidget(OpenGLWidget):
"""Widget checking that OpenGL is indeed available
- :param List[int] version: (major, minor) minimum OpenGL version
+ :param version: (major, minor) minimum OpenGL version
"""
- def __init__(self, version):
+ def __init__(self, version: tuple[int, int]):
super(_TestOpenGLWidget, self).__init__(
alphaBufferSize=0,
depthBufferSize=0,
stencilBufferSize=0,
- version=version)
+ version=version,
+ )
def paintEvent(self, event):
super(_TestOpenGLWidget, self).paintEvent(event)
@@ -176,25 +197,25 @@ if __name__ == "__main__":
qt.QTimer.singleShot(100, app.quit)
def paintGL(self):
- gl.glClearColor(1., 0., 0., 0.)
+ gl.glClearColor(1.0, 0.0, 0.0, 0.0)
gl.glClear(gl.GL_COLOR_BUFFER_BIT)
-
parser = argparse.ArgumentParser()
- parser.add_argument('major')
- parser.add_argument('minor')
- parser.add_argument('--shareOpenGLContexts', action="store_true")
+ parser.add_argument("major")
+ parser.add_argument("minor")
+ parser.add_argument("--shareOpenGLContexts", action="store_true")
args = parser.parse_args(args=sys.argv[1:])
if args.shareOpenGLContexts:
qt.QCoreApplication.setAttribute(qt.Qt.AA_ShareOpenGLContexts)
app = qt.QApplication([])
- window = qt.QMainWindow(flags=
- qt.Qt.Popup |
- qt.Qt.FramelessWindowHint |
- qt.Qt.NoDropShadowWindowHint |
- qt.Qt.WindowStaysOnTopHint)
+ window = qt.QMainWindow(
+ flags=qt.Qt.Popup
+ | qt.Qt.FramelessWindowHint
+ | qt.Qt.NoDropShadowWindowHint
+ | qt.Qt.WindowStaysOnTopHint
+ )
window.setAttribute(qt.Qt.WA_ShowWithoutActivating)
window.move(0, 0)
window.resize(3, 3)
diff --git a/src/silx/gui/utils/image.py b/src/silx/gui/utils/image.py
index 1757e3e..b9ab7c3 100644
--- a/src/silx/gui/utils/image.py
+++ b/src/silx/gui/utils/image.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2017-2022 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -39,7 +39,7 @@ from numpy.lib.stride_tricks import as_strided as _as_strided
from .. import qt
-def convertArrayToQImage(array):
+def convertArrayToQImage(array: numpy.ndarray) -> qt.QImage:
"""Convert an array-like image to a QImage.
The created QImage is using a copy of the array data.
@@ -50,73 +50,81 @@ def convertArrayToQImage(array):
Channels are expected to be either RGB or RGBA.
:type array: numpy.ndarray of uint8
:return: Corresponding Qt image with RGB888 or ARGB32 format.
- :rtype: QImage
"""
- array = numpy.array(array, copy=False, order='C', dtype=numpy.uint8)
+ array = numpy.array(array, copy=False, order="C", dtype=numpy.uint8)
if array.ndim != 3 or array.shape[2] not in (3, 4):
- raise ValueError(
- 'Image must be a 3D array with 3 or 4 channels per pixel')
+ raise ValueError("Image must be a 3D array with 3 or 4 channels per pixel")
if array.shape[2] == 4:
format_ = qt.QImage.Format_ARGB32
# RGBA -> ARGB + take care of endianness
- if sys.byteorder == 'little': # RGBA -> BGRA
+ if sys.byteorder == "little": # RGBA -> BGRA
array = array[:, :, (2, 1, 0, 3)]
else: # big endian: RGBA -> ARGB
array = array[:, :, (3, 0, 1, 2)]
- array = numpy.array(array, order='C') # Make a contiguous array
+ array = numpy.array(array, order="C") # Make a contiguous array
else: # array.shape[2] == 3
format_ = qt.QImage.Format_RGB888
height, width, depth = array.shape
qimage = qt.QImage(
- array.data,
- width,
- height,
- array.strides[0], # bytesPerLine
- format_)
+ array.data, width, height, array.strides[0], format_ # bytesPerLine
+ )
return qimage.copy() # Making a copy of the image and its data
-def convertQImageToArray(image):
+def convertQImageToArray(image: qt.QImage) -> numpy.ndarray:
"""Convert a QImage to a numpy array.
- If QImage format is not Format_RGB888, Format_RGBA8888 or Format_ARGB32,
- it is first converted to one of this format depending on
- the presence of an alpha channel.
+ If QImage format is not one of:
+
+ - Format_Grayscale8
+ - Format_RGB888
+ - Format_RGBA8888
+ - Format_ARGB32,
+
+ it is first converted to one of this format.
The created numpy array is using a copy of the QImage data.
:param QImage image: The QImage to convert.
- :return: The image array of RGB or RGBA channels of shape
- (height, width, channels (3 or 4))
- :rtype: numpy.ndarray of uint8
+ :return: Image array of uint8 of shape:
+
+ - (height, width) for grayscale images
+ - (height, width, channels (3 or 4)) for RGB and RGBA images
"""
- rgba8888 = getattr(qt.QImage, 'Format_RGBA8888', None) # Only in Qt5
+ supportedFormats = (
+ qt.QImage.Format_Grayscale8,
+ qt.QImage.Format_ARGB32,
+ qt.QImage.Format_RGB888,
+ qt.QImage.Format_RGBA8888,
+ )
# Convert to supported format if needed
- if image.format() not in (qt.QImage.Format_ARGB32,
- qt.QImage.Format_RGB888,
- rgba8888):
+ if image.format() not in supportedFormats:
if image.hasAlphaChannel():
- image = image.convertToFormat(
- rgba8888 if rgba8888 is not None else qt.QImage.Format_ARGB32)
+ image = image.convertToFormat(qt.QImage.Format_RGBA8888)
else:
image = image.convertToFormat(qt.QImage.Format_RGB888)
format_ = image.format()
- channels = 3 if format_ == qt.QImage.Format_RGB888 else 4
+ if format_ == qt.QImage.Format_Grayscale8:
+ channels = 1
+ elif format_ == qt.QImage.Format_RGB888:
+ channels = 3
+ else:
+ channels = 4
ptr = image.bits()
- if qt.BINDING == 'PyQt5':
+ if qt.BINDING == "PyQt5":
ptr.setsize(image.byteCount())
- elif qt.BINDING == 'PyQt6':
+ elif qt.BINDING == "PyQt6":
ptr.setsize(image.sizeInBytes())
- elif qt.BINDING in ('PySide2', 'PySide6'):
+ elif qt.BINDING == "PySide6":
ptr = ptr.tobytes()
else:
raise RuntimeError("Unsupported Qt binding: %s" % qt.BINDING)
@@ -125,17 +133,21 @@ def convertQImageToArray(image):
view = _as_strided(
numpy.frombuffer(ptr, dtype=numpy.uint8),
shape=(image.height(), image.width(), channels),
- strides=(image.bytesPerLine(), channels, 1))
+ strides=(image.bytesPerLine(), channels, 1),
+ )
if format_ == qt.QImage.Format_ARGB32:
# Convert from ARGB to RGBA
# Not a byte-ordered format: do care about endianness
- if sys.byteorder == 'little': # BGRA -> RGBA
+ if sys.byteorder == "little": # BGRA -> RGBA
view = view[:, :, (2, 1, 0, 3)]
else: # big endian: ARGB -> RGBA
view = view[:, :, (1, 2, 3, 0)]
+ if channels == 1: # Remove channel dimension
+ view = view[:, :, 0]
+
# Format_RGB888 and Format_RGBA8888 do not need reshuffling channels:
# They are byte-ordered and already in the right order
- return numpy.array(view, copy=True, order='C')
+ return numpy.array(view, copy=True, order="C")
diff --git a/src/silx/gui/utils/matplotlib.py b/src/silx/gui/utils/matplotlib.py
index 277a303..c51ccd2 100644
--- a/src/silx/gui/utils/matplotlib.py
+++ b/src/silx/gui/utils/matplotlib.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2016-2022 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2024 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -29,6 +29,8 @@ It MUST be imported prior to any other import of matplotlib.
It provides the matplotlib :class:`FigureCanvasQTAgg` class corresponding
to the used backend.
"""
+from __future__ import annotations
+
__authors__ = ["T. Vincent"]
__license__ = "MIT"
@@ -36,116 +38,148 @@ __date__ = "02/05/2018"
import io
-from pkg_resources import parse_version
import matplotlib
import numpy
from .. import qt
+# This must be performed before any import from matplotlib
+if qt.BINDING in ("PySide6", "PyQt6", "PyQt5"):
+ matplotlib.use("Qt5Agg", force=False)
+ from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg # noqa
+
+else:
+ raise ImportError("Unsupported Qt binding: %s" % qt.BINDING)
+
+
+from matplotlib.font_manager import FontProperties
+from matplotlib.mathtext import MathTextParser
+from matplotlib.ticker import ScalarFormatter as _ScalarFormatter
+from matplotlib import figure, font_manager
+from packaging.version import Version
+
+_MATPLOTLIB_VERSION = Version(matplotlib.__version__)
+
+
+class DefaultTickFormatter(_ScalarFormatter):
+ """Tick label formatter"""
+
+ def __init__(self):
+ super().__init__(useOffset=True, useMathText=True)
+ self.set_scientific(True)
+ self.create_dummy_axis()
+
+ if _MATPLOTLIB_VERSION < Version("3.1.0"):
+
+ def format_ticks(self, values):
+ self.set_locs(values)
+ return [self(value, i) for i, value in enumerate(values)]
+
+
+_FONT_STYLES = {
+ qt.QFont.StyleNormal: "normal",
+ qt.QFont.StyleItalic: "italic",
+ qt.QFont.StyleOblique: "oblique",
+}
+
+
+def qFontToFontProperties(font: qt.QFont):
+ """Convert a QFont to a matplotlib FontProperties"""
+ weightFactor = 10 if qt.BINDING == "PyQt5" else 1
+ families = [font.family(), font.defaultFamily()]
+ if _MATPLOTLIB_VERSION >= Version("3.6.0"):
+ # Prevent 'Font family not found' warnings
+ availableNames = font_manager.get_font_names()
+ families = [f for f in families if f in availableNames]
+ families.append(font_manager.fontManager.defaultFamily["ttf"])
+
+ if "Sans" in font.family():
+ families.insert(0, "sans-serif")
+
+ return FontProperties(
+ family=families,
+ style=_FONT_STYLES[font.style()],
+ weight=weightFactor * font.weight(),
+ size=font.pointSizeF(),
+ )
+
-def rasterMathText(text, font, size=-1, weight=-1, italic=False, devicePixelRatio=1.0):
+def rasterMathText(
+ text: str,
+ font: qt.QFont,
+ dotsPerInch: float = 96.0,
+) -> tuple[numpy.ndarray, float]:
"""Raster text using matplotlib supporting latex-like math syntax.
It supports multiple lines.
- :param str text: The text to raster
- :param font: Font name or QFont to use
- :type font: str or :class:`QFont`
- :param int size:
- Font size in points
- Used only if font is given as name.
- :param int weight:
- Font weight in [0, 99], see QFont.Weight.
- Used only if font is given as name.
- :param bool italic:
- True for italic font (default: False).
- Used only if font is given as name.
- :param float devicePixelRatio:
- The current ratio between device and device-independent pixel
- (default: 1.0)
+ :param text: The text to raster
+ :param font: Font to use
+ :param dotsPerInch: The DPI resolution of the created image
:return: Corresponding image in gray scale and baseline offset from top
- :rtype: (HxW numpy.ndarray of uint8, int)
"""
# Implementation adapted from:
# https://github.com/matplotlib/matplotlib/blob/d624571a19aec7c7d4a24123643288fc27db17e7/lib/matplotlib/mathtext.py#L264
- # Lazy import to avoid imports before setting matplotlib's rcParams
- from matplotlib.font_manager import FontProperties
- from matplotlib.mathtext import MathTextParser
- from matplotlib import figure
-
- dpi = 96 # default
- qapp = qt.QApplication.instance()
- if qapp:
- screen = qapp.primaryScreen()
- if screen:
- dpi = screen.logicalDotsPerInchY()
-
- # Make sure dpi is even, it causes issues with array reshape otherwise
- dpi = ((dpi * devicePixelRatio) // 2) * 2
-
stripped_text = text.strip("\n")
+ font_prop = qFontToFontProperties(font)
parser = MathTextParser("path")
- width, height, depth, _, _ = parser.parse(stripped_text, dpi=dpi)
- width *= 2
- height *= 2 * (stripped_text.count("\n") + 1)
-
- if not isinstance(font, qt.QFont):
- font = qt.QFont(font, size, weight, italic)
- prop = FontProperties(
- family=font.family(),
- style="italic" if font.italic() else "normal",
- weight=10 * font.weight(),
- size=font.pointSize(),
+ lines_info = [
+ parser.parse(line, prop=font_prop, dpi=dotsPerInch)
+ for line in stripped_text.split("\n")
+ ]
+ max_line_width = max(info[0] for info in lines_info)
+ # Use lp string as minimum height/ascent
+ ref_info = parser.parse("lp", prop=font_prop, dpi=dotsPerInch)
+ line_height = max(
+ ref_info[1],
+ *(info[1] for info in lines_info),
)
+ first_line_ascent = max(
+ ref_info[1] - ref_info[2], lines_info[0][1] - lines_info[0][2]
+ )
+
+ linespacing = 1.2
- fig = figure.Figure(figsize=(width / dpi, height / dpi))
- fig.text(0, depth / height, stripped_text, fontproperties=prop)
+ figure_height = numpy.ceil(line_height * len(lines_info) * linespacing) + 2
+ fig = figure.Figure(
+ figsize=(
+ (max_line_width + 1) / dotsPerInch,
+ figure_height / dotsPerInch,
+ )
+ )
+ fig.set_dpi(dotsPerInch)
+ text = fig.text(
+ 0,
+ 1,
+ stripped_text,
+ fontproperties=font_prop,
+ verticalalignment="top",
+ )
+ text.set_linespacing(linespacing)
with io.BytesIO() as buffer:
- fig.savefig(buffer, dpi=dpi, format="raw")
+ fig.savefig(buffer, dpi=dotsPerInch, format="raw")
+ canvas_width, canvas_height = fig.get_window_extent().max
buffer.seek(0)
image = numpy.frombuffer(buffer.read(), dtype=numpy.uint8).reshape(
- int(height), int(width), 4
+ int(canvas_height), int(canvas_width), 4
)
# RGB to inverted R channel
array = 255 - image[:, :, 0]
- # Remove leading and trailing empty columns/rows but one on each side
+ # Remove leading/trailing empty columns and trailing rows but one on each side
filled_rows = numpy.nonzero(numpy.sum(array, axis=1))[0]
filled_columns = numpy.nonzero(numpy.sum(array, axis=0))[0]
if len(filled_rows) == 0 or len(filled_columns) == 0:
- return array, image.shape[0] - 1
-
- clipped_array = numpy.ascontiguousarray(
- array[
- max(0, filled_rows[0] - 1) : filled_rows[-1] + 2,
- max(0, filled_columns[0] - 1) : filled_columns[-1] + 2,
- ]
+ return array, first_line_ascent
+ return (
+ numpy.ascontiguousarray(
+ array[
+ 0 : filled_rows[-1] + 2,
+ max(0, filled_columns[0] - 1) : filled_columns[-1] + 2,
+ ]
+ ),
+ first_line_ascent,
)
-
- return clipped_array, image.shape[0] - 1 # baseline not available
-
-
-def _matplotlib_use(backend, force):
- """Wrapper of `matplotlib.use` to set-up backend.
-
- It adds extra initialization for PySide2 with matplotlib < 2.2.
- """
- # This is kept for compatibility with matplotlib < 2.2
- if (
- parse_version(matplotlib.__version__) < parse_version("2.2")
- and qt.BINDING == "PySide2"
- ):
- matplotlib.rcParams["backend.qt5"] = "PySide2"
-
- matplotlib.use(backend, force=force)
-
-
-if qt.BINDING in ("PySide6", "PyQt6", "PyQt5", "PySide2"):
- _matplotlib_use("Qt5Agg", force=False)
- from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg # noqa
-
-else:
- raise ImportError("Unsupported Qt binding: %s" % qt.BINDING)
diff --git a/src/silx/gui/utils/projecturl.py b/src/silx/gui/utils/projecturl.py
index 116017e..125e8e7 100644
--- a/src/silx/gui/utils/projecturl.py
+++ b/src/silx/gui/utils/projecturl.py
@@ -67,7 +67,8 @@ def getDocumentationUrl(subpath):
"minor": version.MINOR,
"micro": version.MICRO,
"relev": version.RELEV,
- "subpath": subpath}
+ "subpath": subpath,
+ }
template = BASE_DOC_URL
if template is None:
template = _DEFAULT_BASE_DOC_URL
diff --git a/src/silx/gui/utils/signal.py b/src/silx/gui/utils/signal.py
index cd376a9..00a4d9b 100644
--- a/src/silx/gui/utils/signal.py
+++ b/src/silx/gui/utils/signal.py
@@ -30,8 +30,8 @@ import weakref
from time import time
from silx.gui.utils import concurrent
-__all__ = ['SignalProxy']
-__authors__ = ['L. Campagnola', 'M. Liberty']
+__all__ = ["SignalProxy"]
+__authors__ = ["L. Campagnola", "M. Liberty"]
__license__ = "MIT"
@@ -91,7 +91,9 @@ class SignalProxy(qt.QObject):
leakTime = max(0, (lastFlush + (1.0 / self.rateLimit)) - now)
concurrent.submitToQtMainThread(self.timer.stop)
- concurrent.submitToQtMainThread(self.timer.start, (min(leakTime, self.delay) * 1000) + 1)
+ concurrent.submitToQtMainThread(
+ self.timer.start, (min(leakTime, self.delay) * 1000) + 1
+ )
# self.timer.stop()
# self.timer.start((min(leakTime, self.delay) * 1000) + 1)
@@ -119,22 +121,19 @@ class SignalProxy(qt.QObject):
pass
-if __name__ == '__main__':
+if __name__ == "__main__":
app = qt.QApplication([])
win = qt.QMainWindow()
spin = qt.QSpinBox()
win.setCentralWidget(spin)
win.show()
-
def fn(*args):
print("Raw signal:", args)
-
def fn2(*args):
print("Delayed signal:", args)
-
spin.valueChanged.connect(fn)
# proxy = proxyConnect(spin, QtCore.SIGNAL('valueChanged(int)'), fn)
proxy = SignalProxy(spin.valueChanged, delay=0.5, slot=fn2)
diff --git a/src/silx/gui/utils/test/test.py b/src/silx/gui/utils/test/test.py
index 42bf5a2..59c031e 100644
--- a/src/silx/gui/utils/test/test.py
+++ b/src/silx/gui/utils/test/test.py
@@ -28,7 +28,6 @@ __license__ = "MIT"
__date__ = "01/08/2019"
-import unittest
from silx.gui import qt
from silx.gui.utils.testutils import TestCaseQt, SignalListener
diff --git a/src/silx/gui/utils/test/test_async.py b/src/silx/gui/utils/test/test_async.py
index 1fd8509..ef61df2 100644
--- a/src/silx/gui/utils/test/test_async.py
+++ b/src/silx/gui/utils/test/test_async.py
@@ -29,8 +29,6 @@ __date__ = "09/03/2018"
import threading
-import unittest
-
from concurrent.futures import wait
from silx.gui import qt
@@ -51,7 +49,7 @@ class TestSubmitToQtThread(TestCaseQt):
return value1, value2
def _taskWithException(self, *args, **kwargs):
- raise RuntimeError('task exception')
+ raise RuntimeError("task exception")
def testFromMainThread(self):
"""Call submitToQtMainThread from the main thread"""
@@ -97,10 +95,11 @@ class TestSubmitToQtThread(TestCaseQt):
if not thread.is_alive():
break
else:
- self.fail(('Thread task still running'))
+ self.fail(("Thread task still running"))
def testFromQtThread(self):
"""Call submitToQtMainThread from a Qt thread pool"""
+
class Runner(qt.QRunnable):
def __init__(self, fn):
super(Runner, self).__init__()
@@ -121,4 +120,4 @@ class TestSubmitToQtThread(TestCaseQt):
if done:
break
else:
- self.fail('Thread pool task still running')
+ self.fail("Thread pool task still running")
diff --git a/src/silx/gui/utils/test/test_glutils.py b/src/silx/gui/utils/test/test_glutils.py
index 4921f16..fb19e36 100644
--- a/src/silx/gui/utils/test/test_glutils.py
+++ b/src/silx/gui/utils/test/test_glutils.py
@@ -37,7 +37,9 @@ from silx.gui.utils.glutils import isOpenGLAvailable
_logger = logging.getLogger(__name__)
-@pytest.mark.parametrize("params", (((2, 1), False), ((2, 1), False), ((1000, 1), False), ((2, 1), True)))
+@pytest.mark.parametrize(
+ "params", (((2, 1), False), ((2, 1), False), ((1000, 1), False), ((2, 1), True))
+)
def testOpenGLAvailable(params):
version, shareOpenGLContexts = params
result = isOpenGLAvailable(version=version, shareOpenGLContexts=shareOpenGLContexts)
diff --git a/src/silx/gui/utils/test/test_image.py b/src/silx/gui/utils/test/test_image.py
index 07bc396..9ae1b80 100644
--- a/src/silx/gui/utils/test/test_image.py
+++ b/src/silx/gui/utils/test/test_image.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -28,51 +28,58 @@ __license__ = "MIT"
__date__ = "16/01/2017"
import numpy
-import unittest
+import pytest
from silx.gui import qt
-from silx.utils.testutils import ParametricTestCase
-from silx.gui.utils.testutils import TestCaseQt
from silx.gui.utils.image import convertArrayToQImage, convertQImageToArray
-class TestQImageConversion(TestCaseQt, ParametricTestCase):
- """Tests conversion of QImage to/from numpy array."""
+@pytest.mark.parametrize(
+ "format_, channels",
+ [
+ (qt.QImage.Format_RGB888, 3), # Native support
+ (qt.QImage.Format_ARGB32, 4), # Native support
+ ],
+)
+def testConvertArrayToQImage(format_, channels):
+ """Test conversion of numpy array to QImage"""
+ image = numpy.arange(3 * 3 * channels, dtype=numpy.uint8).reshape(3, 3, channels)
+ qimage = convertArrayToQImage(image)
- def testConvertArrayToQImage(self):
- """Test conversion of numpy array to QImage"""
- for format_, channels in [('Format_RGB888', 3),
- ('Format_ARGB32', 4)]:
- with self.subTest(format_):
- image = numpy.arange(
- 3*3*channels, dtype=numpy.uint8).reshape(3, 3, channels)
- qimage = convertArrayToQImage(image)
+ assert (qimage.height(), qimage.width()) == image.shape[:2]
+ assert qimage.format() == format_
- self.assertEqual(qimage.height(), image.shape[0])
- self.assertEqual(qimage.width(), image.shape[1])
- self.assertEqual(qimage.format(), getattr(qt.QImage, format_))
+ for row in range(3):
+ for col in range(3):
+ # Qrgb has no alpha channel, not compared
+ # Qt uses x,y while array is row,col...
+ assert qt.QColor(qimage.pixel(col, row)) == qt.QColor(*image[row, col, :3])
- for row in range(3):
- for col in range(3):
- # Qrgb has no alpha channel, not compared
- # Qt uses x,y while array is row,col...
- self.assertEqual(qt.QColor(qimage.pixel(col, row)),
- qt.QColor(*image[row, col, :3]))
+@pytest.mark.parametrize(
+ "format_, channels",
+ [
+ (qt.QImage.Format_RGB888, 3), # Native support
+ (qt.QImage.Format_ARGB32, 4), # Native support
+ (qt.QImage.Format_RGB32, 3), # Conversion to RGB
+ ],
+)
+def testConvertQImageToArray(format_, channels):
+ """Test conversion of QImage to numpy array"""
+ color = numpy.arange(channels) # RGB(A) values
+ qimage = qt.QImage(3, 3, format_)
+ qimage.fill(qt.QColor(*color))
+ image = convertQImageToArray(qimage)
- def testConvertQImageToArray(self):
- """Test conversion of QImage to numpy array"""
- for format_, channels in [
- ('Format_RGB888', 3), # Native support
- ('Format_ARGB32', 4), # Native support
- ('Format_RGB32', 3)]: # Conversion to RGB
- with self.subTest(format_):
- color = numpy.arange(channels) # RGB(A) values
- qimage = qt.QImage(3, 3, getattr(qt.QImage, format_))
- qimage.fill(qt.QColor(*color))
- image = convertQImageToArray(qimage)
+ assert (qimage.height(), qimage.width(), len(color)) == image.shape
+ assert numpy.all(numpy.equal(image, color))
- self.assertEqual(qimage.height(), image.shape[0])
- self.assertEqual(qimage.width(), image.shape[1])
- self.assertEqual(image.shape[2], len(color))
- self.assertTrue(numpy.all(numpy.equal(image, color)))
+
+def testConvertQImageToArrayGrayscale():
+ """Test conversion of grayscale QImage to numpy array"""
+ qimage = qt.QImage(3, 3, qt.QImage.Format_Grayscale8)
+ qimage.fill(1)
+ image = convertQImageToArray(qimage)
+
+ assert (qimage.height(), qimage.width()) == image.shape
+ assert numpy.all(numpy.equal(image, 1))
diff --git a/src/silx/gui/utils/test/test_qtutils.py b/src/silx/gui/utils/test/test_qtutils.py
index c5ff2d2..23e6cdf 100755
--- a/src/silx/gui/utils/test/test_qtutils.py
+++ b/src/silx/gui/utils/test/test_qtutils.py
@@ -28,7 +28,6 @@ __license__ = "MIT"
__date__ = "01/08/2019"
-import unittest
from silx.gui import qt
from silx.gui import utils
from silx.gui.utils.testutils import TestCaseQt
diff --git a/src/silx/gui/utils/test/test_testutils.py b/src/silx/gui/utils/test/test_testutils.py
index e8a0123..2277cb3 100644
--- a/src/silx/gui/utils/test/test_testutils.py
+++ b/src/silx/gui/utils/test/test_testutils.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2017-2019 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -28,16 +28,13 @@ __license__ = "MIT"
__date__ = "16/01/2017"
import unittest
-import sys
-from silx.gui import qt
from ..testutils import TestCaseQt
class TestOutcome(unittest.TestCase):
"""Tests conversion of QImage to/from numpy array."""
- @unittest.skipIf(sys.version_info.major <= 2, 'Python3 only')
def testNoneOutcome(self):
test = TestCaseQt()
test._currentTestSucceeded()
diff --git a/src/silx/gui/utils/testutils.py b/src/silx/gui/utils/testutils.py
index 1ec9b0b..76d0b9b 100644
--- a/src/silx/gui/utils/testutils.py
+++ b/src/silx/gui/utils/testutils.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2016-2022 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -25,7 +25,7 @@
__authors__ = ["T. Vincent"]
__license__ = "MIT"
-__date__ = "22/07/2022"
+__date__ = "22/11/2023"
import gc
@@ -42,16 +42,14 @@ from silx.gui import qt
from silx.gui.qt import inspect as _inspect
-if qt.BINDING == 'PySide2':
- from PySide2.QtTest import QTest
-elif qt.BINDING == 'PyQt5':
+if qt.BINDING == "PyQt5":
from PyQt5.QtTest import QTest
-elif qt.BINDING == 'PySide6':
+elif qt.BINDING == "PySide6":
from PySide6.QtTest import QTest
-elif qt.BINDING == 'PyQt6':
+elif qt.BINDING == "PyQt6":
from PyQt6.QtTest import QTest
else:
- raise ImportError('Unsupported Qt bindings')
+ raise ImportError("Unsupported Qt bindings")
def qWaitForWindowExposedAndActivate(window, timeout=None):
@@ -85,7 +83,7 @@ class TestCaseQt(unittest.TestCase):
To allow some widgets to remain alive at the end of a test, set the
allowedLeakingWidgets attribute to the number of widgets that can remain
alive at the end of the test.
- With PySide2, this test is not run for now as it seems PySide2
+ With PySide, this test is not run for now as it seems PySide
is leaking widgets internally.
All keyboard and mouse event simulation methods call qWait(20) after
@@ -111,8 +109,9 @@ class TestCaseQt(unittest.TestCase):
@classmethod
def exceptionHandler(cls, exceptionClass, exception, stack):
import traceback
- message = (''.join(traceback.format_tb(stack)))
- template = 'Traceback (most recent call last):\n{2}{0}: {1}'
+
+ message = "".join(traceback.format_tb(stack))
+ template = "Traceback (most recent call last):\n{2}{0}: {1}"
message = template.format(exceptionClass.__name__, exception, message)
cls._exceptions.append(message)
@@ -133,41 +132,46 @@ class TestCaseQt(unittest.TestCase):
def setUp(self):
"""Get the list of existing widgets."""
self.allowedLeakingWidgets = 0
- if qt.BINDING in ('PySide2', 'PySide6'):
+ if qt.BINDING == "PySide6":
self.__previousWidgets = None
else:
self.__previousWidgets = self.qapp.allWidgets()
self.__class__._exceptions = []
def _currentTestSucceeded(self):
- if hasattr(self, '_outcome'):
- if hasattr(self, '_feedErrorsToResult'):
- # For Python 3.4 -3.10
- result = self.defaultTestResult() # these 2 methods have no side effects
- if hasattr(self._outcome, 'errors'):
- self._feedErrorsToResult(result, self._outcome.errors)
- else:
- # Python 3.11+
- result = self._outcome.result
+ if hasattr(self, "_feedErrorsToResult"):
+ # Python 3.4 - 3.10 (These two methods have no side effects)
+ result = self.defaultTestResult()
+ if hasattr(self._outcome, "errors"):
+ self._feedErrorsToResult(result, self._outcome.errors)
+ elif hasattr(self._outcome, "result"):
+ # Python 3.11+
+ result = self._outcome.result
+
+ if self._outcome is None:
+ return True
+ elif hasattr(self._outcome, "success"):
+ # using pytest
+ return self._outcome.success
else:
- # For Python < 3.4
- result = getattr(self, '_outcomeForDoCleanups', self._resultForDoCleanups)
-
- skipped = self.id() in [case.id() for case, _ in result.skipped]
- error = self.id() in [case.id() for case, _ in result.errors]
- failure = self.id() in [case.id() for case, _ in result.failures]
- return not error and not failure and not skipped
+ # using unittest
+ return all(test != self for test, text in result.errors + result.failures)
def _checkForUnreleasedWidgets(self):
"""Test fixture checking that no more widgets exists."""
if self.__previousWidgets is None:
- return # Do not test for leaking widgets with PySide2
+ return # Do not test for leaking widgets with PySide
gc.collect()
- widgets = [widget for widget in self.qapp.allWidgets()
- if (widget not in self.__previousWidgets and
- _inspect.createdByPython(widget))]
+ widgets = [
+ widget
+ for widget in self.qapp.allWidgets()
+ if (
+ widget not in self.__previousWidgets
+ and _inspect.createdByPython(widget)
+ )
+ ]
self.__previousWidgets = None
allowedLeakingWidgets = self.allowedLeakingWidgets
@@ -175,12 +179,11 @@ class TestCaseQt(unittest.TestCase):
if widgets and len(widgets) <= allowedLeakingWidgets:
_logger.info(
- '%s: %d remaining widgets after test' % (self.id(),
- len(widgets)))
+ "%s: %d remaining widgets after test" % (self.id(), len(widgets))
+ )
if len(widgets) > allowedLeakingWidgets:
- raise RuntimeError(
- "Test ended with widgets alive: %s" % str(widgets))
+ raise RuntimeError("Test ended with widgets alive: %s" % str(widgets))
def tearDown(self):
self.qapp.processEvents()
@@ -208,8 +211,9 @@ class TestCaseQt(unittest.TestCase):
Click = QTest.Click
"""Key click action code"""
- QTest = property(lambda self: QTest,
- doc="""The Qt QTest class from the used Qt binding.""")
+ QTest = property(
+ lambda self: QTest, doc="""The Qt QTest class from the used Qt binding."""
+ )
def keyClick(self, widget, key, modifier=qt.Qt.NoModifier, delay=-1):
"""Simulate clicking a key.
@@ -227,8 +231,7 @@ class TestCaseQt(unittest.TestCase):
QTest.keyClicks(widget, sequence, modifier, delay)
self.qWait(20)
- def keyEvent(self, action, widget, key,
- modifier=qt.Qt.NoModifier, delay=-1):
+ def keyEvent(self, action, widget, key, modifier=qt.Qt.NoModifier, delay=-1):
"""Sends a Qt key event.
See QTest.keyEvent for details.
@@ -321,14 +324,13 @@ class TestCaseQt(unittest.TestCase):
if ms is None:
ms = cls.DEFAULT_TIMEOUT_WAIT
- if qt.BINDING in ('PySide2', 'PySide6'):
- # PySide2 has no qWait, provide a replacement
+ if qt.BINDING == "PySide6":
+ # PySide has no qWait, provide a replacement
timeout = int(ms)
endTimeMS = int(time.time() * 1000) + timeout
qapp = qt.QApplication.instance()
while timeout > 0:
- qapp.processEvents(qt.QEventLoop.AllEvents,
- timeout)
+ qapp.processEvents(qt.QEventLoop.AllEvents, timeout)
timeout = endTimeMS - int(time.time() * 1000)
else:
QTest.qWait(int(ms) + cls.TIMEOUT_WAIT)
@@ -420,8 +422,7 @@ class TestCaseQt(unittest.TestCase):
class SignalListener(object):
- """Util to listen a Qt event and store parameters
- """
+ """Util to listen a Qt event and store parameters"""
def __init__(self):
self.__calls = []
@@ -503,7 +504,7 @@ def getQToolButtonFromAction(action):
def findChildren(parent, kind, name=None):
- if qt.BINDING in ("PySide2", "PySide6") and name is not None:
+ if qt.BINDING == "PySide6" and name is not None:
result = []
for obj in parent.findChildren(kind):
if obj.objectName() == name:
diff --git a/src/silx/gui/widgets/ElidedLabel.py b/src/silx/gui/widgets/ElidedLabel.py
index 3760ec0..ae45931 100644
--- a/src/silx/gui/widgets/ElidedLabel.py
+++ b/src/silx/gui/widgets/ElidedLabel.py
@@ -62,7 +62,7 @@ class ElidedLabel(qt.QLabel):
def __updateMinimumSize(self):
metrics = self.fontMetrics()
- if qt.BINDING in ('PySide2', 'PyQt5'):
+ if qt.BINDING == "PyQt5":
width = metrics.width("...")
else: # Qt6
width = metrics.horizontalAdvance("...")
@@ -93,7 +93,7 @@ class ElidedLabel(qt.QLabel):
"""
return self.__text
- @deprecated(replacement='text', since_version='1.1.0')
+ @deprecated(replacement="text", since_version="1.1.0")
def getText(self):
return self.text()
@@ -109,7 +109,7 @@ class ElidedLabel(qt.QLabel):
"""
return self.__toolTip
- @deprecated(replacement='toolTip', since_version='1.1.0')
+ @deprecated(replacement="toolTip", since_version="1.1.0")
def getToolTip(self):
return self.toolTip()
diff --git a/src/silx/gui/widgets/FloatEdit.py b/src/silx/gui/widgets/FloatEdit.py
index 61f518f..f9d7331 100644
--- a/src/silx/gui/widgets/FloatEdit.py
+++ b/src/silx/gui/widgets/FloatEdit.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+# Copyright (c) 2004-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -23,6 +23,8 @@
# ###########################################################################*/
"""Module contains a float editor
"""
+from __future__ import annotations
+
__authors__ = ["V.A. Sole", "T. Vincent"]
__license__ = "MIT"
@@ -34,18 +36,33 @@ from .. import qt
class FloatEdit(qt.QLineEdit):
"""Field to edit a float value.
- :param parent: See :class:`QLineEdit`
- :param float value: The value to set the QLineEdit to.
+ The value can be modified with :meth:`value` and :meth:`setValue`.
+
+ The property :meth:`widgetResizable` allow to change the default
+ behaviour in order to automatically resize the widget to the displayed value.
+ Use :meth:`setMinimumWidth` to enforce the minimum width.
+
+ :param parent: Parent of the widget
+ :param value: The value to set the QLineEdit to.
"""
- def __init__(self, parent=None, value=None):
+
+ _QLineEditPrivateHorizontalMargin = 2
+ """Constant from Qt source code"""
+
+ def __init__(self, parent: qt.QWidget | None = None, value: float | None = None):
qt.QLineEdit.__init__(self, parent)
validator = qt.QDoubleValidator(self)
+ self.__widgetResizable: bool = False
+ self.__minimumWidth = 30
+ """Store the minimum width requested by the user, the real one is
+ dynamic"""
self.setValidator(validator)
self.setAlignment(qt.Qt.AlignRight)
+ self.textChanged.connect(self.__textChanged)
if value is not None:
self.setValue(value)
- def value(self):
+ def value(self) -> float:
"""Return the QLineEdit current value as a float."""
text = self.text()
value, validated = self.validator().locale().toDouble(text)
@@ -53,16 +70,85 @@ class FloatEdit(qt.QLineEdit):
self.setValue(value)
return value
- def setValue(self, value):
+ def setValue(self, value: float):
"""Set the current value of the LineEdit
- :param float value: The value to set the QLineEdit to.
+ :param value: The value to set the QLineEdit to.
"""
locale = self.validator().locale()
if qt.BINDING == "PySide6":
# Fix for PySide6 not selecting the right method
- text = locale.toString(float(value), 'g')
+ text = locale.toString(float(value), "g")
else:
text = locale.toString(float(value))
self.setText(text)
+ if self.__widgetResizable:
+ self.__forceMinimumWidthFromContent()
+
+ def __textChanged(self, text: str):
+ if self.__widgetResizable:
+ self.__forceMinimumWidthFromContent()
+
+ def widgetResizable(self) -> bool:
+ """
+ Returns whether or not the widget auto resizes itself based on it's content
+ """
+ return self.__widgetResizable
+
+ def setWidgetResizable(self, resizable: bool):
+ """
+ If true, the widget will automatically resize itself to its displayed content.
+
+ This avoids to have to scroll to see the widget's content, and allow to take
+ advantage of extra space.
+ """
+ if self.__widgetResizable == resizable:
+ return
+ self.__widgetResizable = resizable
+ self.updateGeometry()
+ if resizable:
+ self.__forceMinimumWidthFromContent()
+ else:
+ qt.QLineEdit.setMinimumWidth(self, self.__minimumWidth)
+
+ def __minimumWidthFromContent(self) -> int:
+ """Minimum size for the widget to properly read the actual number"""
+ text = self.text()
+ font = self.font()
+ metrics = qt.QFontMetrics(font)
+ margins = self.textMargins()
+ width = (
+ metrics.horizontalAdvance(text)
+ + self._QLineEditPrivateHorizontalMargin * 2
+ + margins.left()
+ + margins.right()
+ )
+ width = max(self.__minimumWidth, width)
+ opt = qt.QStyleOptionFrame()
+ self.initStyleOption(opt)
+ s = self.style().sizeFromContents(
+ qt.QStyle.CT_LineEdit, opt, qt.QSize(width, self.height())
+ )
+ return s.width()
+
+ def sizeHint(self) -> qt.QSize:
+ sizeHint = qt.QLineEdit.sizeHint(self)
+ if not self.__widgetResizable:
+ return sizeHint
+ width = self.__minimumWidthFromContent()
+ return qt.QSize(width, sizeHint.height())
+
+ def __forceMinimumWidthFromContent(self):
+ width = self.__minimumWidthFromContent()
+ qt.QLineEdit.setMinimumWidth(self, width)
+ self.updateGeometry()
+
+ def setMinimumWidth(self, width: int):
+ self.__minimumWidth = width
+ qt.QLineEdit.setMinimumWidth(self, width)
+ self.updateGeometry()
+
+ def minimumWidth(self) -> int:
+ """Returns the user defined minimum width."""
+ return self.__minimumWidth
diff --git a/src/silx/gui/widgets/FlowLayout.py b/src/silx/gui/widgets/FlowLayout.py
index 917aa09..691cb06 100644
--- a/src/silx/gui/widgets/FlowLayout.py
+++ b/src/silx/gui/widgets/FlowLayout.py
@@ -105,13 +105,13 @@ class FlowLayout(qt.QLayout):
spaceX = widget.style().layoutSpacing(
qt.QSizePolicy.PushButton,
qt.QSizePolicy.PushButton,
- qt.Qt.Horizontal)
+ qt.Qt.Horizontal,
+ )
spaceY = self.verticalSpacing()
if spaceY == -1:
spaceY = widget.style().layoutSpacing(
- qt.QSizePolicy.PushButton,
- qt.QSizePolicy.PushButton,
- qt.Qt.Vertical)
+ qt.QSizePolicy.PushButton, qt.QSizePolicy.PushButton, qt.Qt.Vertical
+ )
nextX = x + item.sizeHint().width() + spaceX
if (nextX - spaceX) > effectiveRect.right() and lineHeight > 0:
diff --git a/src/silx/gui/widgets/FormGridLayout.py b/src/silx/gui/widgets/FormGridLayout.py
index 6068d30..a1a26b2 100644
--- a/src/silx/gui/widgets/FormGridLayout.py
+++ b/src/silx/gui/widgets/FormGridLayout.py
@@ -39,6 +39,7 @@ class FormGridLayout(qt.QGridLayout):
This allow a bit more flexibility, like allow vertical expanding
of the rows.
"""
+
def __init__(self, parent):
super(FormGridLayout, self).__init__(parent)
self.__cursor = 0
@@ -51,7 +52,11 @@ class FormGridLayout(qt.QGridLayout):
something = qt.QLabel(something)
self.addWidget(something, row, column, rowSpan, columnSpan)
- def addRow(self, label: typing.Union[str, qt.QWidget, qt.QLayout], field: typing.Union[None, qt.QWidget, qt.QLayout] = None):
+ def addRow(
+ self,
+ label: typing.Union[str, qt.QWidget, qt.QLayout],
+ field: typing.Union[None, qt.QWidget, qt.QLayout] = None,
+ ):
"""
Adds a new row to the bottom of this form layout.
diff --git a/src/silx/gui/widgets/FrameBrowser.py b/src/silx/gui/widgets/FrameBrowser.py
index 17a9148..c03b2a8 100644
--- a/src/silx/gui/widgets/FrameBrowser.py
+++ b/src/silx/gui/widgets/FrameBrowser.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -32,7 +32,6 @@
"""
from silx.gui import qt
from silx.gui import icons
-from silx.utils import deprecation
__authors__ = ["V.A. Sole", "P. Knobel"]
__license__ = "MIT"
@@ -94,7 +93,9 @@ class FrameBrowser(qt.QWidget):
else:
first, last = 0, n
- self._lineEdit.setFixedWidth(self._lineEdit.fontMetrics().boundingRect('%05d' % last).width())
+ self._lineEdit.setFixedWidth(
+ self._lineEdit.fontMetrics().boundingRect("%05d" % last).width()
+ )
validator = qt.QIntValidator(first, last, self._lineEdit)
self._lineEdit.setValidator(validator)
self._lineEdit.setText("%d" % first)
@@ -152,7 +153,7 @@ class FrameBrowser(qt.QWidget):
"event": "indexChanged",
"old": self._index,
"new": new_value,
- "id": id(self)
+ "id": id(self),
}
self._index = new_value
self.sigIndexChanged.emit(ddict)
@@ -182,11 +183,6 @@ class FrameBrowser(qt.QWidget):
# Update limits
self._label.setText(" limits: %d, %d " % (bottom, top))
- @deprecation.deprecated(replacement="FrameBrowser.setRange",
- since_version="0.8")
- def setLimits(self, first, last):
- return self.setRange(first, last)
-
def setNFrames(self, nframes):
"""Set minimum=0 and maximum=nframes-1 frame numbers.
@@ -199,11 +195,6 @@ class FrameBrowser(qt.QWidget):
# display 1-based index in label
self._label.setText(" of %d " % top)
- @deprecation.deprecated(replacement="FrameBrowser.getValue",
- since_version="0.8")
- def getCurrentIndex(self):
- return self._index
-
def getValue(self):
"""Return current frame index"""
return self._index
@@ -243,6 +234,7 @@ class HorizontalSliderWithBrowser(qt.QAbstractSlider):
:param QWidget parent: Optional parent widget
"""
+
def __init__(self, parent=None):
qt.QAbstractSlider.__init__(self, parent)
self.setOrientation(qt.Qt.Horizontal)
@@ -302,14 +294,13 @@ class HorizontalSliderWithBrowser(qt.QAbstractSlider):
self._browser.setRange(first, last)
def _sliderSlot(self, value):
- """Emit selected value when slider is activated
- """
+ """Emit selected value when slider is activated"""
self._browser.setValue(value)
self.valueChanged.emit(value)
def _browserSlot(self, ddict):
"""Emit selected value when browser state is changed"""
- self._slider.setValue(ddict['new'])
+ self._slider.setValue(ddict["new"])
def setValue(self, value):
"""Set value
diff --git a/src/silx/gui/widgets/LegendIconWidget.py b/src/silx/gui/widgets/LegendIconWidget.py
index d0d2f5c..ae86c35 100755
--- a/src/silx/gui/widgets/LegendIconWidget.py
+++ b/src/silx/gui/widgets/LegendIconWidget.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2004-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2004-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -44,27 +44,27 @@ _logger = logging.getLogger(__name__)
# Courtesy of the pyqtgraph project
_Symbols = None
-""""Cache supported symbols as Qt paths"""
+"""Cache supported symbols as Qt paths"""
-_NoSymbols = (None, 'None', 'none', '', ' ')
+_NoSymbols = (None, "None", "none", "", " ")
"""List of values resulting in no symbol being displayed for a curve"""
_LineStyles = {
None: qt.Qt.NoPen,
- 'None': qt.Qt.NoPen,
- 'none': qt.Qt.NoPen,
- '': qt.Qt.NoPen,
- ' ': qt.Qt.NoPen,
- '-': qt.Qt.SolidLine,
- '--': qt.Qt.DashLine,
- ':': qt.Qt.DotLine,
- '-.': qt.Qt.DashDotLine
+ "None": qt.Qt.NoPen,
+ "none": qt.Qt.NoPen,
+ "": qt.Qt.NoPen,
+ " ": qt.Qt.NoPen,
+ "-": qt.Qt.SolidLine,
+ "--": qt.Qt.DashLine,
+ ":": qt.Qt.DotLine,
+ "-.": qt.Qt.DashDotLine,
}
"""Conversion from matplotlib-like linestyle to Qt"""
-_NoLineStyle = (None, 'None', 'none', '', ' ')
+_NoLineStyle = (None, "None", "none", "", " ")
"""List of style values resulting in no line being displayed for a curve"""
@@ -82,22 +82,45 @@ def _initSymbols():
if _Symbols is not None:
return
- symbols = dict([(name, qt.QPainterPath())
- for name in ['o', 's', 't', 'd', '+', 'x', '.', ',']])
- symbols['o'].addEllipse(qt.QRectF(.1, .1, .8, .8))
- symbols['.'].addEllipse(qt.QRectF(.3, .3, .4, .4))
- symbols[','].addEllipse(qt.QRectF(.4, .4, .2, .2))
- symbols['s'].addRect(qt.QRectF(.1, .1, .8, .8))
+ symbols = dict(
+ [(name, qt.QPainterPath()) for name in ["o", "s", "t", "d", "+", "x", ".", ","]]
+ )
+ symbols["o"].addEllipse(qt.QRectF(0.1, 0.1, 0.8, 0.8))
+ symbols["."].addEllipse(qt.QRectF(0.3, 0.3, 0.4, 0.4))
+ symbols[","].addEllipse(qt.QRectF(0.4, 0.4, 0.2, 0.2))
+ symbols["s"].addRect(qt.QRectF(0.1, 0.1, 0.8, 0.8))
coords = {
- 't': [(0.5, 0.), (.1, .8), (.9, .8)],
- 'd': [(0.1, 0.5), (0.5, 0.), (0.9, 0.5), (0.5, 1.)],
- '+': [(0.0, 0.40), (0.40, 0.40), (0.40, 0.), (0.60, 0.),
- (0.60, 0.40), (1., 0.40), (1., 0.60), (0.60, 0.60),
- (0.60, 1.), (0.40, 1.), (0.40, 0.60), (0., 0.60)],
- 'x': [(0.0, 0.40), (0.40, 0.40), (0.40, 0.), (0.60, 0.),
- (0.60, 0.40), (1., 0.40), (1., 0.60), (0.60, 0.60),
- (0.60, 1.), (0.40, 1.), (0.40, 0.60), (0., 0.60)]
+ "t": [(0.5, 0.0), (0.1, 0.8), (0.9, 0.8)],
+ "d": [(0.1, 0.5), (0.5, 0.0), (0.9, 0.5), (0.5, 1.0)],
+ "+": [
+ (0.0, 0.40),
+ (0.40, 0.40),
+ (0.40, 0.0),
+ (0.60, 0.0),
+ (0.60, 0.40),
+ (1.0, 0.40),
+ (1.0, 0.60),
+ (0.60, 0.60),
+ (0.60, 1.0),
+ (0.40, 1.0),
+ (0.40, 0.60),
+ (0.0, 0.60),
+ ],
+ "x": [
+ (0.0, 0.40),
+ (0.40, 0.40),
+ (0.40, 0.0),
+ (0.60, 0.0),
+ (0.60, 0.40),
+ (1.0, 0.40),
+ (1.0, 0.60),
+ (0.60, 0.60),
+ (0.60, 1.0),
+ (0.40, 1.0),
+ (0.40, 0.60),
+ (0.0, 0.60),
+ ],
}
for s, c in coords.items():
symbols[s].moveTo(*c[0])
@@ -106,9 +129,9 @@ def _initSymbols():
symbols[s].closeSubpath()
tr = qt.QTransform()
tr.rotate(45)
- symbols['x'].translate(qt.QPointF(-0.5, -0.5))
- symbols['x'] = tr.map(symbols['x'])
- symbols['x'].translate(qt.QPointF(0.5, 0.5))
+ symbols["x"].translate(qt.QPointF(-0.5, -0.5))
+ symbols["x"] = tr.map(symbols["x"])
+ symbols["x"].translate(qt.QPointF(0.5, 0.5))
_Symbols = symbols
@@ -130,10 +153,11 @@ class LegendIconWidget(qt.QWidget):
# Line attributes
self.lineStyle = qt.Qt.NoPen
- self.lineWidth = 1.
+ self.__dashPattern = []
+ self.lineWidth = 1.0
self.lineColor = qt.Qt.green
- self.symbol = ''
+ self.symbol = ""
# Symbol attributes
self.symbolStyle = qt.Qt.SolidPattern
self.symbolColor = qt.Qt.green
@@ -147,8 +171,7 @@ class LegendIconWidget(qt.QWidget):
# Control widget size: sizeHint "is the only acceptable
# alternative, so the widget can never grow or shrink"
# (c.f. Qt Doc, enum QSizePolicy::Policy)
- self.setSizePolicy(qt.QSizePolicy.Fixed,
- qt.QSizePolicy.Fixed)
+ self.setSizePolicy(qt.QSizePolicy.Fixed, qt.QSizePolicy.Fixed)
def sizeHint(self):
return qt.QSize(50, 15)
@@ -190,12 +213,21 @@ class LegendIconWidget(qt.QWidget):
- '--': dashed
- ':': dotted
- '-.': dash and dot
+ - (offset, (dash pattern))
- :param str style: The linestyle to use
+ :param style: The linestyle to use
"""
+ print("setLineStyle", style)
if style not in _LineStyles:
- raise ValueError('Unknown style: %s', style)
- self.lineStyle = _LineStyles[style]
+ self.lineStyle = qt.Qt.SolidLine
+ dashPattern = style[1]
+ if dashPattern is None or dashPattern == ():
+ self.__dashPattern = None
+ else:
+ self.__dashPattern = style[1]
+ else:
+ self.lineStyle = _LineStyles[style]
+ self.__dashPattern = None
self.update()
def _toLut(self, colormap):
@@ -308,7 +340,7 @@ class LegendIconWidget(qt.QWidget):
# current -> width = 2.5, height = 1.0
scale = float(self.height())
ratio = float(self.width()) / scale
- symbolOffset = qt.QPointF(.5 * (ratio - 1.), 0.)
+ symbolOffset = qt.QPointF(0.5 * (ratio - 1.0), 0.0)
# Determine and scale offset
offset = qt.QPointF(float(rect.left()) / scale, float(rect.top()) / scale)
@@ -316,8 +348,7 @@ class LegendIconWidget(qt.QWidget):
if self.isEnabled():
overrideColor = None
else:
- overrideColor = palette.color(qt.QPalette.Disabled,
- qt.QPalette.WindowText)
+ overrideColor = palette.color(qt.QPalette.Disabled, qt.QPalette.WindowText)
# Draw BG rectangle (for debugging)
# bottomRight = qt.QPointF(
@@ -349,21 +380,23 @@ class LegendIconWidget(qt.QWidget):
llist = []
if self.showLine:
linePath = qt.QPainterPath()
- linePath.moveTo(0., 0.5)
+ linePath.moveTo(0.0, 0.5)
linePath.lineTo(ratio, 0.5)
# linePath.lineTo(2.5, 0.5)
lineBrush = qt.QBrush(
- self.lineColor if overrideColor is None else overrideColor)
+ self.lineColor if overrideColor is None else overrideColor
+ )
linePen = qt.QPen(
lineBrush,
(self.lineWidth / self.height()),
self.lineStyle,
- qt.Qt.FlatCap
+ qt.Qt.FlatCap,
)
+ if self.__dashPattern is not None:
+ linePen.setDashPattern(self.__dashPattern)
llist.append((linePath, linePen, lineBrush))
- isValidSymbol = (len(self.symbol) and
- self.symbol not in _NoSymbols)
+ isValidSymbol = len(self.symbol) and self.symbol not in _NoSymbols
if self.showSymbol and isValidSymbol:
if self.symbolColormap is None:
# PITFALL ahead: Let this be a warning to others
@@ -373,15 +406,14 @@ class LegendIconWidget(qt.QWidget):
symbolPath.translate(symbolOffset)
symbolBrush = qt.QBrush(
self.symbolColor if overrideColor is None else overrideColor,
- self.symbolStyle)
+ self.symbolStyle,
+ )
symbolPen = qt.QPen(
self.symbolOutlineBrush, # Brush
- 1. / self.height(), # Width
- qt.Qt.SolidLine # Style
+ 1.0 / self.height(), # Width
+ qt.Qt.SolidLine, # Style
)
- llist.append((symbolPath,
- symbolPen,
- symbolBrush))
+ llist.append((symbolPath, symbolPen, symbolBrush))
else:
nbSymbols = int(ratio + 2)
for i in range(nbSymbols):
@@ -390,21 +422,21 @@ class LegendIconWidget(qt.QWidget):
else:
image = self.getGrayedColormapImage(self.symbolColormap)
pos = int((_COLORMAP_PIXMAP_SIZE / nbSymbols) * i)
- pos = numpy.clip(pos, 0, _COLORMAP_PIXMAP_SIZE-1)
+ pos = numpy.clip(pos, 0, _COLORMAP_PIXMAP_SIZE - 1)
color = image.pixelColor(pos, 0)
- delta = qt.QPointF(ratio * ((i - (nbSymbols-1)/2) / nbSymbols), 0)
+ delta = qt.QPointF(
+ ratio * ((i - (nbSymbols - 1) / 2) / nbSymbols), 0
+ )
symbolPath = qt.QPainterPath(_Symbols[self.symbol])
symbolPath.translate(symbolOffset + delta)
symbolBrush = qt.QBrush(color, self.symbolStyle)
symbolPen = qt.QPen(
self.symbolOutlineBrush, # Brush
- 1. / self.height(), # Width
- qt.Qt.SolidLine # Style
+ 1.0 / self.height(), # Width
+ qt.Qt.SolidLine, # Style
)
- llist.append((symbolPath,
- symbolPen,
- symbolBrush))
+ llist.append((symbolPath, symbolPen, symbolBrush))
# Draw
for path, pen, brush in llist:
diff --git a/src/silx/gui/widgets/MedianFilterDialog.py b/src/silx/gui/widgets/MedianFilterDialog.py
index 982736c..5fe134f 100644
--- a/src/silx/gui/widgets/MedianFilterDialog.py
+++ b/src/silx/gui/widgets/MedianFilterDialog.py
@@ -42,8 +42,10 @@ from silx.gui import qt
_logger = logging.getLogger(__name__)
+
class MedianFilterDialog(qt.QDialog):
"""QDialog window featuring a :class:`BackgroundWidget`"""
+
sigFilterOptChanged = qt.Signal(int, bool)
def __init__(self, parent=None):
@@ -54,11 +56,11 @@ class MedianFilterDialog(qt.QDialog):
self.setLayout(self.mainLayout)
# filter width GUI
- self.mainLayout.addWidget(qt.QLabel('filter width:', parent = self))
+ self.mainLayout.addWidget(qt.QLabel("filter width:", parent=self))
self._filterWidth = qt.QSpinBox(parent=self)
self._filterWidth.setMinimum(1)
self._filterWidth.setValue(1)
- self._filterWidth.setSingleStep(2);
+ self._filterWidth.setSingleStep(2)
widthTooltip = """radius width of the pixel including in the filter
for each pixel"""
self._filterWidth.setToolTip(widthTooltip)
@@ -66,14 +68,16 @@ class MedianFilterDialog(qt.QDialog):
self.mainLayout.addWidget(self._filterWidth)
# filter option GUI
- self._filterOption = qt.QCheckBox('conditional', parent=self)
+ self._filterOption = qt.QCheckBox("conditional", parent=self)
conditionalTooltip = """if check, implement a conditional filter"""
self._filterOption.stateChanged.connect(self._filterOptionChanged)
self.mainLayout.addWidget(self._filterOption)
def _filterOptionChanged(self):
"""Call back used when the filter values are changed"""
- if self._filterWidth.value()%2 == 0:
- _logger.warning('median filter only accept odd values')
+ if self._filterWidth.value() % 2 == 0:
+ _logger.warning("median filter only accept odd values")
else:
- self.sigFilterOptChanged.emit(self._filterWidth.value(), self._filterOption.isChecked()) \ No newline at end of file
+ self.sigFilterOptChanged.emit(
+ self._filterWidth.value(), self._filterOption.isChecked()
+ )
diff --git a/src/silx/gui/widgets/PeriodicTable.py b/src/silx/gui/widgets/PeriodicTable.py
index 1fc3bab..2923cc6 100644
--- a/src/silx/gui/widgets/PeriodicTable.py
+++ b/src/silx/gui/widgets/PeriodicTable.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+# Copyright (c) 2004-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -136,122 +136,123 @@ __authors__ = ["E. Papillon", "V.A. Sole", "P. Knobel"]
__license__ = "MIT"
__date__ = "26/01/2017"
-from collections import OrderedDict
import logging
from silx.gui import qt
_logger = logging.getLogger(__name__)
# Symbol Atomic Number col row name mass subcategory
-_elements = [("H", 1, 1, 1, "hydrogen", 1.00800, "diatomic nonmetal"),
- ("He", 2, 18, 1, "helium", 4.0030, "noble gas"),
- ("Li", 3, 1, 2, "lithium", 6.94000, "alkali metal"),
- ("Be", 4, 2, 2, "beryllium", 9.01200, "alkaline earth metal"),
- ("B", 5, 13, 2, "boron", 10.8110, "metalloid"),
- ("C", 6, 14, 2, "carbon", 12.0100, "polyatomic nonmetal"),
- ("N", 7, 15, 2, "nitrogen", 14.0080, "diatomic nonmetal"),
- ("O", 8, 16, 2, "oxygen", 16.0000, "diatomic nonmetal"),
- ("F", 9, 17, 2, "fluorine", 19.0000, "diatomic nonmetal"),
- ("Ne", 10, 18, 2, "neon", 20.1830, "noble gas"),
- ("Na", 11, 1, 3, "sodium", 22.9970, "alkali metal"),
- ("Mg", 12, 2, 3, "magnesium", 24.3200, "alkaline earth metal"),
- ("Al", 13, 13, 3, "aluminium", 26.9700, "post transition metal"),
- ("Si", 14, 14, 3, "silicon", 28.0860, "metalloid"),
- ("P", 15, 15, 3, "phosphorus", 30.9750, "polyatomic nonmetal"),
- ("S", 16, 16, 3, "sulphur", 32.0660, "polyatomic nonmetal"),
- ("Cl", 17, 17, 3, "chlorine", 35.4570, "diatomic nonmetal"),
- ("Ar", 18, 18, 3, "argon", 39.9440, "noble gas"),
- ("K", 19, 1, 4, "potassium", 39.1020, "alkali metal"),
- ("Ca", 20, 2, 4, "calcium", 40.0800, "alkaline earth metal"),
- ("Sc", 21, 3, 4, "scandium", 44.9600, "transition metal"),
- ("Ti", 22, 4, 4, "titanium", 47.9000, "transition metal"),
- ("V", 23, 5, 4, "vanadium", 50.9420, "transition metal"),
- ("Cr", 24, 6, 4, "chromium", 51.9960, "transition metal"),
- ("Mn", 25, 7, 4, "manganese", 54.9400, "transition metal"),
- ("Fe", 26, 8, 4, "iron", 55.8500, "transition metal"),
- ("Co", 27, 9, 4, "cobalt", 58.9330, "transition metal"),
- ("Ni", 28, 10, 4, "nickel", 58.6900, "transition metal"),
- ("Cu", 29, 11, 4, "copper", 63.5400, "transition metal"),
- ("Zn", 30, 12, 4, "zinc", 65.3800, "transition metal"),
- ("Ga", 31, 13, 4, "gallium", 69.7200, "post transition metal"),
- ("Ge", 32, 14, 4, "germanium", 72.5900, "metalloid"),
- ("As", 33, 15, 4, "arsenic", 74.9200, "metalloid"),
- ("Se", 34, 16, 4, "selenium", 78.9600, "polyatomic nonmetal"),
- ("Br", 35, 17, 4, "bromine", 79.9200, "diatomic nonmetal"),
- ("Kr", 36, 18, 4, "krypton", 83.8000, "noble gas"),
- ("Rb", 37, 1, 5, "rubidium", 85.4800, "alkali metal"),
- ("Sr", 38, 2, 5, "strontium", 87.6200, "alkaline earth metal"),
- ("Y", 39, 3, 5, "yttrium", 88.9050, "transition metal"),
- ("Zr", 40, 4, 5, "zirconium", 91.2200, "transition metal"),
- ("Nb", 41, 5, 5, "niobium", 92.9060, "transition metal"),
- ("Mo", 42, 6, 5, "molybdenum", 95.9500, "transition metal"),
- ("Tc", 43, 7, 5, "technetium", 99.0000, "transition metal"),
- ("Ru", 44, 8, 5, "ruthenium", 101.0700, "transition metal"),
- ("Rh", 45, 9, 5, "rhodium", 102.9100, "transition metal"),
- ("Pd", 46, 10, 5, "palladium", 106.400, "transition metal"),
- ("Ag", 47, 11, 5, "silver", 107.880, "transition metal"),
- ("Cd", 48, 12, 5, "cadmium", 112.410, "transition metal"),
- ("In", 49, 13, 5, "indium", 114.820, "post transition metal"),
- ("Sn", 50, 14, 5, "tin", 118.690, "post transition metal"),
- ("Sb", 51, 15, 5, "antimony", 121.760, "metalloid"),
- ("Te", 52, 16, 5, "tellurium", 127.600, "metalloid"),
- ("I", 53, 17, 5, "iodine", 126.910, "diatomic nonmetal"),
- ("Xe", 54, 18, 5, "xenon", 131.300, "noble gas"),
- ("Cs", 55, 1, 6, "caesium", 132.910, "alkali metal"),
- ("Ba", 56, 2, 6, "barium", 137.360, "alkaline earth metal"),
- ("La", 57, 3, 6, "lanthanum", 138.920, "lanthanide"),
- ("Ce", 58, 4, 9, "cerium", 140.130, "lanthanide"),
- ("Pr", 59, 5, 9, "praseodymium", 140.920, "lanthanide"),
- ("Nd", 60, 6, 9, "neodymium", 144.270, "lanthanide"),
- ("Pm", 61, 7, 9, "promethium", 147.000, "lanthanide"),
- ("Sm", 62, 8, 9, "samarium", 150.350, "lanthanide"),
- ("Eu", 63, 9, 9, "europium", 152.000, "lanthanide"),
- ("Gd", 64, 10, 9, "gadolinium", 157.260, "lanthanide"),
- ("Tb", 65, 11, 9, "terbium", 158.930, "lanthanide"),
- ("Dy", 66, 12, 9, "dysprosium", 162.510, "lanthanide"),
- ("Ho", 67, 13, 9, "holmium", 164.940, "lanthanide"),
- ("Er", 68, 14, 9, "erbium", 167.270, "lanthanide"),
- ("Tm", 69, 15, 9, "thulium", 168.940, "lanthanide"),
- ("Yb", 70, 16, 9, "ytterbium", 173.040, "lanthanide"),
- ("Lu", 71, 17, 9, "lutetium", 174.990, "lanthanide"),
- ("Hf", 72, 4, 6, "hafnium", 178.500, "transition metal"),
- ("Ta", 73, 5, 6, "tantalum", 180.950, "transition metal"),
- ("W", 74, 6, 6, "tungsten", 183.920, "transition metal"),
- ("Re", 75, 7, 6, "rhenium", 186.200, "transition metal"),
- ("Os", 76, 8, 6, "osmium", 190.200, "transition metal"),
- ("Ir", 77, 9, 6, "iridium", 192.200, "transition metal"),
- ("Pt", 78, 10, 6, "platinum", 195.090, "transition metal"),
- ("Au", 79, 11, 6, "gold", 197.200, "transition metal"),
- ("Hg", 80, 12, 6, "mercury", 200.610, "transition metal"),
- ("Tl", 81, 13, 6, "thallium", 204.390, "post transition metal"),
- ("Pb", 82, 14, 6, "lead", 207.210, "post transition metal"),
- ("Bi", 83, 15, 6, "bismuth", 209.000, "post transition metal"),
- ("Po", 84, 16, 6, "polonium", 209.000, "post transition metal"),
- ("At", 85, 17, 6, "astatine", 210.000, "metalloid"),
- ("Rn", 86, 18, 6, "radon", 222.000, "noble gas"),
- ("Fr", 87, 1, 7, "francium", 223.000, "alkali metal"),
- ("Ra", 88, 2, 7, "radium", 226.000, "alkaline earth metal"),
- ("Ac", 89, 3, 7, "actinium", 227.000, "actinide"),
- ("Th", 90, 4, 10, "thorium", 232.000, "actinide"),
- ("Pa", 91, 5, 10, "proactinium", 231.03588, "actinide"),
- ("U", 92, 6, 10, "uranium", 238.070, "actinide"),
- ("Np", 93, 7, 10, "neptunium", 237.000, "actinide"),
- ("Pu", 94, 8, 10, "plutonium", 239.100, "actinide"),
- ("Am", 95, 9, 10, "americium", 243, "actinide"),
- ("Cm", 96, 10, 10, "curium", 247, "actinide"),
- ("Bk", 97, 11, 10, "berkelium", 247, "actinide"),
- ("Cf", 98, 12, 10, "californium", 251, "actinide"),
- ("Es", 99, 13, 10, "einsteinium", 252, "actinide"),
- ("Fm", 100, 14, 10, "fermium", 257, "actinide"),
- ("Md", 101, 15, 10, "mendelevium", 258, "actinide"),
- ("No", 102, 16, 10, "nobelium", 259, "actinide"),
- ("Lr", 103, 17, 10, "lawrencium", 262, "actinide"),
- ("Rf", 104, 4, 7, "rutherfordium", 261, "transition metal"),
- ("Db", 105, 5, 7, "dubnium", 262, "transition metal"),
- ("Sg", 106, 6, 7, "seaborgium", 266, "transition metal"),
- ("Bh", 107, 7, 7, "bohrium", 264, "transition metal"),
- ("Hs", 108, 8, 7, "hassium", 269, "transition metal"),
- ("Mt", 109, 9, 7, "meitnerium", 268)]
+_elements = [
+ ("H", 1, 1, 1, "hydrogen", 1.00800, "diatomic nonmetal"),
+ ("He", 2, 18, 1, "helium", 4.0030, "noble gas"),
+ ("Li", 3, 1, 2, "lithium", 6.94000, "alkali metal"),
+ ("Be", 4, 2, 2, "beryllium", 9.01200, "alkaline earth metal"),
+ ("B", 5, 13, 2, "boron", 10.8110, "metalloid"),
+ ("C", 6, 14, 2, "carbon", 12.0100, "polyatomic nonmetal"),
+ ("N", 7, 15, 2, "nitrogen", 14.0080, "diatomic nonmetal"),
+ ("O", 8, 16, 2, "oxygen", 16.0000, "diatomic nonmetal"),
+ ("F", 9, 17, 2, "fluorine", 19.0000, "diatomic nonmetal"),
+ ("Ne", 10, 18, 2, "neon", 20.1830, "noble gas"),
+ ("Na", 11, 1, 3, "sodium", 22.9970, "alkali metal"),
+ ("Mg", 12, 2, 3, "magnesium", 24.3200, "alkaline earth metal"),
+ ("Al", 13, 13, 3, "aluminium", 26.9700, "post transition metal"),
+ ("Si", 14, 14, 3, "silicon", 28.0860, "metalloid"),
+ ("P", 15, 15, 3, "phosphorus", 30.9750, "polyatomic nonmetal"),
+ ("S", 16, 16, 3, "sulphur", 32.0660, "polyatomic nonmetal"),
+ ("Cl", 17, 17, 3, "chlorine", 35.4570, "diatomic nonmetal"),
+ ("Ar", 18, 18, 3, "argon", 39.9440, "noble gas"),
+ ("K", 19, 1, 4, "potassium", 39.1020, "alkali metal"),
+ ("Ca", 20, 2, 4, "calcium", 40.0800, "alkaline earth metal"),
+ ("Sc", 21, 3, 4, "scandium", 44.9600, "transition metal"),
+ ("Ti", 22, 4, 4, "titanium", 47.9000, "transition metal"),
+ ("V", 23, 5, 4, "vanadium", 50.9420, "transition metal"),
+ ("Cr", 24, 6, 4, "chromium", 51.9960, "transition metal"),
+ ("Mn", 25, 7, 4, "manganese", 54.9400, "transition metal"),
+ ("Fe", 26, 8, 4, "iron", 55.8500, "transition metal"),
+ ("Co", 27, 9, 4, "cobalt", 58.9330, "transition metal"),
+ ("Ni", 28, 10, 4, "nickel", 58.6900, "transition metal"),
+ ("Cu", 29, 11, 4, "copper", 63.5400, "transition metal"),
+ ("Zn", 30, 12, 4, "zinc", 65.3800, "transition metal"),
+ ("Ga", 31, 13, 4, "gallium", 69.7200, "post transition metal"),
+ ("Ge", 32, 14, 4, "germanium", 72.5900, "metalloid"),
+ ("As", 33, 15, 4, "arsenic", 74.9200, "metalloid"),
+ ("Se", 34, 16, 4, "selenium", 78.9600, "polyatomic nonmetal"),
+ ("Br", 35, 17, 4, "bromine", 79.9200, "diatomic nonmetal"),
+ ("Kr", 36, 18, 4, "krypton", 83.8000, "noble gas"),
+ ("Rb", 37, 1, 5, "rubidium", 85.4800, "alkali metal"),
+ ("Sr", 38, 2, 5, "strontium", 87.6200, "alkaline earth metal"),
+ ("Y", 39, 3, 5, "yttrium", 88.9050, "transition metal"),
+ ("Zr", 40, 4, 5, "zirconium", 91.2200, "transition metal"),
+ ("Nb", 41, 5, 5, "niobium", 92.9060, "transition metal"),
+ ("Mo", 42, 6, 5, "molybdenum", 95.9500, "transition metal"),
+ ("Tc", 43, 7, 5, "technetium", 99.0000, "transition metal"),
+ ("Ru", 44, 8, 5, "ruthenium", 101.0700, "transition metal"),
+ ("Rh", 45, 9, 5, "rhodium", 102.9100, "transition metal"),
+ ("Pd", 46, 10, 5, "palladium", 106.400, "transition metal"),
+ ("Ag", 47, 11, 5, "silver", 107.880, "transition metal"),
+ ("Cd", 48, 12, 5, "cadmium", 112.410, "transition metal"),
+ ("In", 49, 13, 5, "indium", 114.820, "post transition metal"),
+ ("Sn", 50, 14, 5, "tin", 118.690, "post transition metal"),
+ ("Sb", 51, 15, 5, "antimony", 121.760, "metalloid"),
+ ("Te", 52, 16, 5, "tellurium", 127.600, "metalloid"),
+ ("I", 53, 17, 5, "iodine", 126.910, "diatomic nonmetal"),
+ ("Xe", 54, 18, 5, "xenon", 131.300, "noble gas"),
+ ("Cs", 55, 1, 6, "caesium", 132.910, "alkali metal"),
+ ("Ba", 56, 2, 6, "barium", 137.360, "alkaline earth metal"),
+ ("La", 57, 3, 6, "lanthanum", 138.920, "lanthanide"),
+ ("Ce", 58, 4, 9, "cerium", 140.130, "lanthanide"),
+ ("Pr", 59, 5, 9, "praseodymium", 140.920, "lanthanide"),
+ ("Nd", 60, 6, 9, "neodymium", 144.270, "lanthanide"),
+ ("Pm", 61, 7, 9, "promethium", 147.000, "lanthanide"),
+ ("Sm", 62, 8, 9, "samarium", 150.350, "lanthanide"),
+ ("Eu", 63, 9, 9, "europium", 152.000, "lanthanide"),
+ ("Gd", 64, 10, 9, "gadolinium", 157.260, "lanthanide"),
+ ("Tb", 65, 11, 9, "terbium", 158.930, "lanthanide"),
+ ("Dy", 66, 12, 9, "dysprosium", 162.510, "lanthanide"),
+ ("Ho", 67, 13, 9, "holmium", 164.940, "lanthanide"),
+ ("Er", 68, 14, 9, "erbium", 167.270, "lanthanide"),
+ ("Tm", 69, 15, 9, "thulium", 168.940, "lanthanide"),
+ ("Yb", 70, 16, 9, "ytterbium", 173.040, "lanthanide"),
+ ("Lu", 71, 17, 9, "lutetium", 174.990, "lanthanide"),
+ ("Hf", 72, 4, 6, "hafnium", 178.500, "transition metal"),
+ ("Ta", 73, 5, 6, "tantalum", 180.950, "transition metal"),
+ ("W", 74, 6, 6, "tungsten", 183.920, "transition metal"),
+ ("Re", 75, 7, 6, "rhenium", 186.200, "transition metal"),
+ ("Os", 76, 8, 6, "osmium", 190.200, "transition metal"),
+ ("Ir", 77, 9, 6, "iridium", 192.200, "transition metal"),
+ ("Pt", 78, 10, 6, "platinum", 195.090, "transition metal"),
+ ("Au", 79, 11, 6, "gold", 197.200, "transition metal"),
+ ("Hg", 80, 12, 6, "mercury", 200.610, "transition metal"),
+ ("Tl", 81, 13, 6, "thallium", 204.390, "post transition metal"),
+ ("Pb", 82, 14, 6, "lead", 207.210, "post transition metal"),
+ ("Bi", 83, 15, 6, "bismuth", 209.000, "post transition metal"),
+ ("Po", 84, 16, 6, "polonium", 209.000, "post transition metal"),
+ ("At", 85, 17, 6, "astatine", 210.000, "metalloid"),
+ ("Rn", 86, 18, 6, "radon", 222.000, "noble gas"),
+ ("Fr", 87, 1, 7, "francium", 223.000, "alkali metal"),
+ ("Ra", 88, 2, 7, "radium", 226.000, "alkaline earth metal"),
+ ("Ac", 89, 3, 7, "actinium", 227.000, "actinide"),
+ ("Th", 90, 4, 10, "thorium", 232.000, "actinide"),
+ ("Pa", 91, 5, 10, "proactinium", 231.03588, "actinide"),
+ ("U", 92, 6, 10, "uranium", 238.070, "actinide"),
+ ("Np", 93, 7, 10, "neptunium", 237.000, "actinide"),
+ ("Pu", 94, 8, 10, "plutonium", 239.100, "actinide"),
+ ("Am", 95, 9, 10, "americium", 243, "actinide"),
+ ("Cm", 96, 10, 10, "curium", 247, "actinide"),
+ ("Bk", 97, 11, 10, "berkelium", 247, "actinide"),
+ ("Cf", 98, 12, 10, "californium", 251, "actinide"),
+ ("Es", 99, 13, 10, "einsteinium", 252, "actinide"),
+ ("Fm", 100, 14, 10, "fermium", 257, "actinide"),
+ ("Md", 101, 15, 10, "mendelevium", 258, "actinide"),
+ ("No", 102, 16, 10, "nobelium", 259, "actinide"),
+ ("Lr", 103, 17, 10, "lawrencium", 262, "actinide"),
+ ("Rf", 104, 4, 7, "rutherfordium", 261, "transition metal"),
+ ("Db", 105, 5, 7, "dubnium", 262, "transition metal"),
+ ("Sg", 106, 6, 7, "seaborgium", 266, "transition metal"),
+ ("Bh", 107, 7, 7, "bohrium", 264, "transition metal"),
+ ("Hs", 108, 8, 7, "hassium", 269, "transition metal"),
+ ("Mt", 109, 9, 7, "meitnerium", 268),
+]
class PeriodicTableItem(object):
@@ -279,8 +280,8 @@ class PeriodicTableItem(object):
:param str subcategory: Subcategory, based on physical properties
(e.g. "alkali metal", "noble gas"...)
"""
- def __init__(self, symbol, Z, col, row, name, mass,
- subcategory=""):
+
+ def __init__(self, symbol, Z, col, row, name, mass, subcategory=""):
self.symbol = symbol
"""Atomic symbol (e.g. H, He, Li...)"""
self.Z = Z
@@ -302,10 +303,7 @@ class PeriodicTableItem(object):
if idx == 6:
_logger.warning("density not implemented in silx, returning 0.")
- ret = [self.symbol, self.Z,
- self.col, self.row,
- self.name, self.mass,
- 0.]
+ ret = [self.symbol, self.Z, self.col, self.row, self.name, self.mass, 0.0]
return ret[idx]
def __len__(self):
@@ -320,6 +318,7 @@ class ColoredPeriodicTableItem(PeriodicTableItem):
:param str bgcolor: Custom background color for element in
periodic table, as a RGB string *#RRGGBB*"""
+
COLORS = {
"diatomic nonmetal": "#7FFF00", # chartreuse
"noble gas": "#00FFFF", # cyan
@@ -331,14 +330,12 @@ class ColoredPeriodicTableItem(PeriodicTableItem):
"post transition metal": "#D3D3D3", # light gray
"lanthanide": "#FFB6C1", # light pink
"actinide": "#F08080", # Light Coral
- "": "#FFFFFF" # white
+ "": "#FFFFFF", # white
}
"""Dictionary defining RGB colors for each subcategory."""
- def __init__(self, symbol, Z, col, row, name, mass,
- subcategory="", bgcolor=None):
- PeriodicTableItem.__init__(self, symbol, Z, col, row, name, mass,
- subcategory)
+ def __init__(self, symbol, Z, col, row, name, mass, subcategory="", bgcolor=None):
+ PeriodicTableItem.__init__(self, symbol, Z, col, row, name, mass, subcategory)
self.bgcolor = self.COLORS.get(subcategory, "#FFFFFF")
"""Background color of element in the periodic table,
@@ -356,8 +353,8 @@ _defaultTableItems = [ColoredPeriodicTableItem(*info) for info in _elements]
class _ElementButton(qt.QPushButton):
- """Atomic element button, used as a cell in the periodic table
- """
+ """Atomic element button, used as a cell in the periodic table"""
+
sigElementEnter = qt.pyqtSignal(object)
"""Signal emitted as the cursor enters the widget"""
sigElementLeave = qt.pyqtSignal(object)
@@ -380,8 +377,9 @@ class _ElementButton(qt.QPushButton):
self.setFlat(1)
self.setCheckable(0)
- self.setSizePolicy(qt.QSizePolicy(qt.QSizePolicy.Expanding,
- qt.QSizePolicy.Expanding))
+ self.setSizePolicy(
+ qt.QSizePolicy(qt.QSizePolicy.Expanding, qt.QSizePolicy.Expanding)
+ )
self.selected = False
self.current = False
@@ -454,18 +452,19 @@ class _ElementButton(qt.QPushButton):
self.brush = qt.QBrush(self.bgcolor)
else:
self.brush = qt.QBrush()
- palette.setBrush(self.backgroundRole(),
- self.brush)
+ palette.setBrush(self.backgroundRole(), self.brush)
self.setPalette(palette)
self.update()
def paintEvent(self, pEvent):
# get button geometry
widgGeom = self.rect()
- paintGeom = qt.QRect(widgGeom.left() + 1,
- widgGeom.top() + 1,
- widgGeom.width() - 2,
- widgGeom.height() - 2)
+ paintGeom = qt.QRect(
+ widgGeom.left() + 1,
+ widgGeom.top() + 1,
+ widgGeom.width() - 2,
+ widgGeom.height() - 2,
+ )
# paint background color
painter = qt.QPainter(self)
@@ -521,6 +520,7 @@ class PeriodicTable(qt.QWidget):
pt.sigElementClicked.connect(my_slot)
"""
+
sigElementClicked = qt.pyqtSignal(object)
"""When any element is clicked in the table, the widget emits
this signal and sends a :class:`PeriodicTableItem` object.
@@ -551,8 +551,9 @@ class PeriodicTable(qt.QWidget):
selection is only possible with method :meth:`setSelection`.
"""
- def __init__(self, parent=None, name="PeriodicTable", elements=None,
- selectable=False):
+ def __init__(
+ self, parent=None, name="PeriodicTable", elements=None, selectable=False
+ ):
self.selectable = selectable
qt.QWidget.__init__(self, parent)
self.setWindowTitle(name)
@@ -576,7 +577,7 @@ class PeriodicTable(qt.QWidget):
self._eltCurrent = None
"""Current :class:`_ElementButton` (last clicked)"""
- self._eltButtons = OrderedDict()
+ self._eltButtons = {}
"""Dictionary of all :class:`_ElementButton`. Keys are the symbols
("H", "He", "Li"...)"""
@@ -617,7 +618,7 @@ class PeriodicTable(qt.QWidget):
def _elementClicked(self, item):
"""Emit :attr:`sigElementClicked`,
toggle selected state of element
-
+
:param PeriodicTableItem item: Element clicked
"""
if self._eltCurrent is not None:
@@ -652,7 +653,7 @@ class PeriodicTable(qt.QWidget):
if isinstance(symbols[0], PeriodicTableItem):
symbols = [elmt.symbol for elmt in symbols]
- for (e, b) in self._eltButtons.items():
+ for e, b in self._eltButtons.items():
b.setSelected(e in symbols)
self.sigSelectionChanged.emit(self.getSelection())
@@ -696,6 +697,7 @@ class PeriodicCombo(qt.QComboBox):
a predefined list with minimal information (symbol, atomic number,
name, mass).
"""
+
sigSelectionChanged = qt.pyqtSignal(object)
"""Signal emitted when the selection changes. Send
:class:`PeriodicTableItem` object representing selected
@@ -752,6 +754,7 @@ class PeriodicList(qt.QTreeWidget):
:param single: *True* for single element selection with mouse click,
*False* for multiple element selection mode.
"""
+
sigSelectionChanged = qt.pyqtSignal(object)
"""When any element is selected/unselected in the widget, it emits
this signal and sends a list of currently selected
@@ -774,8 +777,11 @@ class PeriodicList(qt.QTreeWidget):
self.setRootIsDecorated(0)
self.itemClicked.connect(self.__selectionChanged)
- self.setSelectionMode(qt.QAbstractItemView.SingleSelection if single
- else qt.QAbstractItemView.ExtendedSelection)
+ self.setSelectionMode(
+ qt.QAbstractItemView.SingleSelection
+ if single
+ else qt.QAbstractItemView.ExtendedSelection
+ )
self.__fill_widget(elements)
self.resizeColumnToContents(0)
self.resizeColumnToContents(1)
@@ -783,7 +789,7 @@ class PeriodicList(qt.QTreeWidget):
self.resizeColumnToContents(2)
def __fill_widget(self, elements):
- """Fill tree widget with elements """
+ """Fill tree widget with elements"""
if elements is None:
elements = _defaultTableItems
@@ -813,8 +819,11 @@ class PeriodicList(qt.QTreeWidget):
:return: Selected elements
:rtype: List[PeriodicTableItem]"""
- return [_defaultTableItems[idx] for idx in range(len(self.tree_items))
- if self.tree_items[idx].isSelected()]
+ return [
+ _defaultTableItems[idx]
+ for idx in range(len(self.tree_items))
+ if self.tree_items[idx].isSelected()
+ ]
# setSelection is a bad name (name of a QTreeWidget method)
def setSelectedElements(self, symbolList):
@@ -827,4 +836,6 @@ class PeriodicList(qt.QTreeWidget):
if isinstance(symbolList[0], PeriodicTableItem):
symbolList = [elmt.symbol for elmt in symbolList]
for idx in range(len(self.tree_items)):
- self.tree_items[idx].setSelected(_defaultTableItems[idx].symbol in symbolList)
+ self.tree_items[idx].setSelected(
+ _defaultTableItems[idx].symbol in symbolList
+ )
diff --git a/src/silx/gui/widgets/PrintGeometryDialog.py b/src/silx/gui/widgets/PrintGeometryDialog.py
index db905fb..652e1bc 100644
--- a/src/silx/gui/widgets/PrintGeometryDialog.py
+++ b/src/silx/gui/widgets/PrintGeometryDialog.py
@@ -34,6 +34,7 @@ class PrintGeometryWidget(qt.QWidget):
Use methods :meth:`setPrintGeometry` and :meth:`getPrintGeometry`
to interact with the widget.
"""
+
def __init__(self, parent=None):
super(PrintGeometryWidget, self).__init__(parent)
self.mainLayout = qt.QGridLayout(self)
@@ -107,21 +108,21 @@ class PrintGeometryWidget(qt.QWidget):
print geometry dictionary."""
ddict = {}
if self._inchButton.isChecked():
- ddict['units'] = "inches"
+ ddict["units"] = "inches"
elif self._cmButton.isChecked():
- ddict['units'] = "centimeters"
+ ddict["units"] = "centimeters"
else:
- ddict['units'] = "page"
+ ddict["units"] = "page"
- ddict['xOffset'] = self._xOffset.value()
- ddict['yOffset'] = self._yOffset.value()
- ddict['width'] = self._width.value()
- ddict['height'] = self._height.value()
+ ddict["xOffset"] = self._xOffset.value()
+ ddict["yOffset"] = self._yOffset.value()
+ ddict["width"] = self._width.value()
+ ddict["height"] = self._height.value()
if self._aspect.isChecked():
- ddict['keepAspectRatio'] = True
+ ddict["keepAspectRatio"] = True
else:
- ddict['keepAspectRatio'] = False
+ ddict["keepAspectRatio"] = False
return ddict
def setPrintGeometry(self, geometry=None):
@@ -144,22 +145,28 @@ class PrintGeometryWidget(qt.QWidget):
if geometry is None:
geometry = {}
oldDict = self.getPrintGeometry()
- for key in ["units", "xOffset", "yOffset",
- "width", "height", "keepAspectRatio"]:
+ for key in [
+ "units",
+ "xOffset",
+ "yOffset",
+ "width",
+ "height",
+ "keepAspectRatio",
+ ]:
geometry[key] = geometry.get(key, oldDict[key])
- if geometry['units'].lower().startswith("inc"):
+ if geometry["units"].lower().startswith("inc"):
self._inchButton.setChecked(True)
- elif geometry['units'].lower().startswith("c"):
+ elif geometry["units"].lower().startswith("c"):
self._cmButton.setChecked(True)
else:
self._pageButton.setChecked(True)
- self._xOffset.setText("%s" % float(geometry['xOffset']))
- self._yOffset.setText("%s" % float(geometry['yOffset']))
- self._width.setText("%s" % float(geometry['width']))
- self._height.setText("%s" % float(geometry['height']))
- if geometry['keepAspectRatio']:
+ self._xOffset.setText("%s" % float(geometry["xOffset"]))
+ self._yOffset.setText("%s" % float(geometry["yOffset"]))
+ self._width.setText("%s" % float(geometry["width"]))
+ self._height.setText("%s" % float(geometry["height"]))
+ if geometry["keepAspectRatio"]:
self._aspect.setChecked(True)
else:
self._aspect.setChecked(False)
diff --git a/src/silx/gui/widgets/PrintPreview.py b/src/silx/gui/widgets/PrintPreview.py
index dd6af1f..285f12c 100644
--- a/src/silx/gui/widgets/PrintPreview.py
+++ b/src/silx/gui/widgets/PrintPreview.py
@@ -42,10 +42,9 @@ _logger = logging.getLogger(__name__)
class PrintPreviewDialog(qt.QDialog):
- """Print preview dialog widget.
- """
- def __init__(self, parent=None, printer=None):
+ """Print preview dialog widget."""
+ def __init__(self, parent=None, printer=None):
qt.QDialog.__init__(self, parent)
self.setWindowTitle("Print Preview")
self.setModal(False)
@@ -108,8 +107,7 @@ class PrintPreviewDialog(qt.QDialog):
cancelBut.setToolTip("Remove all items")
cancelBut.clicked.connect(self._clearAll)
- removeBut = qt.QPushButton("Remove",
- toolBar)
+ removeBut = qt.QPushButton("Remove", toolBar)
removeBut.setToolTip("Remove selected item (use left click to select)")
removeBut.clicked.connect(self._remove)
@@ -160,18 +158,17 @@ class PrintPreviewDialog(qt.QDialog):
self.targetLabel.setText("Undefined printer")
return
if self.printer.outputFileName():
- self.targetLabel.setText("File:" +
- self.printer.outputFileName())
+ self.targetLabel.setText("File:" + self.printer.outputFileName())
else:
- self.targetLabel.setText("Printer:" +
- self.printer.printerName())
+ self.targetLabel.setText("Printer:" + self.printer.printerName())
def _updatePrinter(self):
"""Resize :attr:`page`, :attr:`scene` and :attr:`view` to :attr:`printer`
width and height."""
printer = self.printer
- assert printer is not None, \
- "_updatePrinter should not be called unless a printer is defined"
+ assert (
+ printer is not None
+ ), "_updatePrinter should not be called unless a printer is defined"
if self.scene is None:
self.scene = qt.QGraphicsScene()
self.scene.setBackgroundBrush(qt.QColor(qt.Qt.lightGray))
@@ -204,9 +201,12 @@ class PrintPreviewDialog(qt.QDialog):
:param str comment: Comment displayed below the image
:param commentPosition: "CENTER" or "LEFT"
"""
- self.addPixmap(qt.QPixmap.fromImage(image),
- title=title, comment=comment,
- commentPosition=commentPosition)
+ self.addPixmap(
+ qt.QPixmap.fromImage(image),
+ title=title,
+ comment=comment,
+ commentPosition=commentPosition,
+ )
def addPixmap(self, pixmap, title=None, comment=None, commentPosition=None):
"""Add a pixmap to the print preview scene
@@ -223,14 +223,13 @@ class PrintPreviewDialog(qt.QDialog):
_logger.error("printer is not set, cannot add pixmap to page")
return
if title is None:
- title = ' ' * 88
+ title = " " * 88
if comment is None:
- comment = ' ' * 88
+ comment = " " * 88
if commentPosition is None:
commentPosition = "CENTER"
rectItem = qt.QGraphicsRectItem(self.page)
- rectItem.setRect(qt.QRectF(1, 1,
- pixmap.width(), pixmap.height()))
+ rectItem.setRect(qt.QRectF(1, 1, pixmap.width(), pixmap.height()))
pen = rectItem.pen()
color = qt.QColor(qt.Qt.red)
@@ -269,9 +268,15 @@ class PrintPreviewDialog(qt.QDialog):
rectItem.moveBy(20, 40)
- def addSvgItem(self, item, title=None,
- comment=None, commentPosition=None,
- viewBox=None, keepRatio=True):
+ def addSvgItem(
+ self,
+ item,
+ title=None,
+ comment=None,
+ commentPosition=None,
+ viewBox=None,
+ keepRatio=True,
+ ):
"""Add a SVG item to the scene.
:param QSvgRenderer item: SVG item to be added to the scene.
@@ -296,9 +301,9 @@ class PrintPreviewDialog(qt.QDialog):
return
if title is None:
- title = 50 * ' '
+ title = 50 * " "
if comment is None:
- comment = 80 * ' '
+ comment = 80 * " "
if commentPosition is None:
commentPosition = "CENTER"
@@ -319,8 +324,9 @@ class PrintPreviewDialog(qt.QDialog):
svgItem.setFlag(qt.QGraphicsItem.ItemIsMovable, True)
svgItem.setFlag(qt.QGraphicsItem.ItemIsFocusable, False)
- rectItemResizeRect = _GraphicsResizeRectItem(svgItem, self.scene,
- keepratio=keepRatio)
+ rectItemResizeRect = _GraphicsResizeRectItem(
+ svgItem, self.scene, keepratio=keepRatio
+ )
rectItemResizeRect.setZValue(2)
self._svgItems.append(item)
@@ -357,9 +363,13 @@ class PrintPreviewDialog(qt.QDialog):
if alignment == qt.Qt.AlignLeft:
deltax = 0
else:
- deltax = (svgItem.boundingRect().width() - commentItem.boundingRect().width()) / 2.
- commentItem.moveBy(svgItem.boundingRect().x() + deltax,
- svgItem.boundingRect().y() + svgItem.boundingRect().height())
+ deltax = (
+ svgItem.boundingRect().width() - commentItem.boundingRect().width()
+ ) / 2.0
+ commentItem.moveBy(
+ svgItem.boundingRect().x() + deltax,
+ svgItem.boundingRect().y() + svgItem.boundingRect().height(),
+ )
# Title
textItem = qt.QGraphicsTextItem(title, svgItem)
@@ -368,9 +378,12 @@ class PrintPreviewDialog(qt.QDialog):
textItem.setFlag(qt.QGraphicsItem.ItemIsMovable, True)
title_offset = 0.5 * textItem.boundingRect().width()
- textItem.moveBy(svgItem.boundingRect().x() +
- 0.5 * svgItem.boundingRect().width() - title_offset * scale,
- svgItem.boundingRect().y())
+ textItem.moveBy(
+ svgItem.boundingRect().x()
+ + 0.5 * svgItem.boundingRect().width()
+ - title_offset * scale,
+ svgItem.boundingRect().y(),
+ )
textItem.setScale(scale)
def setup(self):
@@ -387,7 +400,9 @@ class PrintPreviewDialog(qt.QDialog):
if self.printer.width() <= 0 or self.printer.height() <= 0:
self.message = qt.QMessageBox(self)
self.message.setIcon(qt.QMessageBox.Critical)
- self.message.setText("Unknown library error \non printer initialization")
+ self.message.setText(
+ "Unknown library error \non printer initialization"
+ )
self.message.setWindowTitle("Library Error")
self.message.setModal(0)
self.printer = None
@@ -412,8 +427,9 @@ class PrintPreviewDialog(qt.QDialog):
self.setup()
if self.printer is None:
self.hide()
- _logger.warning("Printer setup failed or was cancelled, " +
- "but printer is required.")
+ _logger.warning(
+ "Printer setup failed or was cancelled, " + "but printer is required."
+ )
return self.printer is not None
def setOutputFileName(self, name):
@@ -461,19 +477,27 @@ class PrintPreviewDialog(qt.QDialog):
_logger.error("Cannot initialize printer")
return
try:
- self.scene.render(painter, qt.QRectF(0, 0, printer.width(), printer.height()),
- qt.QRectF(self.page.rect().x(), self.page.rect().y(),
- self.page.rect().width(), self.page.rect().height()),
- qt.Qt.KeepAspectRatio)
+ self.scene.render(
+ painter,
+ qt.QRectF(0, 0, printer.width(), printer.height()),
+ qt.QRectF(
+ self.page.rect().x(),
+ self.page.rect().y(),
+ self.page.rect().width(),
+ self.page.rect().height(),
+ ),
+ qt.Qt.KeepAspectRatio,
+ )
painter.end()
self.hide()
self.accept()
self._toBeCleared = True
- except: # FIXME
+ except: # FIXME
painter.end()
- qt.QMessageBox.critical(self, "ERROR",
- 'Printing problem:\n %s' % sys.exc_info()[1])
- _logger.error('printing problem:\n %s' % sys.exc_info()[1])
+ qt.QMessageBox.critical(
+ self, "ERROR", "Printing problem:\n %s" % sys.exc_info()[1]
+ )
+ _logger.error("printing problem:\n %s" % sys.exc_info()[1])
return
def _zoomPlus(self):
@@ -501,8 +525,7 @@ class PrintPreviewDialog(qt.QDialog):
self._toBeCleared = False
def _remove(self):
- """Remove selected item in :attr:`scene`.
- """
+ """Remove selected item in :attr:`scene`."""
itemlist = self.scene.items()
# this loop is not efficient if there are many items ...
@@ -518,6 +541,7 @@ class SingletonPrintPreviewDialog(PrintPreviewDialog):
a single print preview dialog. This enables sending
multiple images to a single page to be printed.
"""
+
_instance = None
def __new__(self, *var, **kw):
@@ -530,6 +554,7 @@ class _GraphicsSvgRectItem(qt.QGraphicsRectItem):
""":class:`qt.QGraphicsRectItem` with an attached
:class:`qt.QSvgRenderer`, and with a painter redefined to render
the SVG item."""
+
def setSvgRenderer(self, renderer):
"""
@@ -543,6 +568,7 @@ class _GraphicsSvgRectItem(qt.QGraphicsRectItem):
class _GraphicsResizeRectItem(qt.QGraphicsRectItem):
"""Resizable QGraphicsRectItem."""
+
def __init__(self, parent=None, scene=None, keepratio=True):
qt.QGraphicsRectItem.__init__(self, parent)
rect = parent.boundingRect()
@@ -561,7 +587,7 @@ class _GraphicsResizeRectItem(qt.QGraphicsRectItem):
pen.setStyle(qt.Qt.NoPen)
self.setPen(pen)
self.setBrush(color)
- self.setFlag(self.ItemIsMovable, True)
+ self.setFlag(qt.QGraphicsItem.ItemIsMovable, True)
self.show()
def hoverEnterEvent(self, event):
@@ -602,10 +628,7 @@ class _GraphicsResizeRectItem(qt.QGraphicsRectItem):
self._h = rect.height()
self._ratio = self._w / self._h
self._newRect = qt.QGraphicsRectItem(parent)
- self._newRect.setRect(qt.QRectF(self._x,
- self._y,
- self._w,
- self._h))
+ self._newRect.setRect(qt.QRectF(self._x, self._y, self._w, self._h))
qt.QGraphicsRectItem.mousePressEvent(self, event)
def mouseMoveEvent(self, event):
@@ -616,20 +639,27 @@ class _GraphicsResizeRectItem(qt.QGraphicsRectItem):
r1 = (self._w + deltax) / self._w
r2 = (self._h + deltay) / self._h
if r1 < r2:
- self._newRect.setRect(qt.QRectF(self._x,
- self._y,
- self._w + deltax,
- (self._w + deltax) / self._ratio))
+ self._newRect.setRect(
+ qt.QRectF(
+ self._x,
+ self._y,
+ self._w + deltax,
+ (self._w + deltax) / self._ratio,
+ )
+ )
else:
- self._newRect.setRect(qt.QRectF(self._x,
- self._y,
- (self._h + deltay) * self._ratio,
- self._h + deltay))
+ self._newRect.setRect(
+ qt.QRectF(
+ self._x,
+ self._y,
+ (self._h + deltay) * self._ratio,
+ self._h + deltay,
+ )
+ )
else:
- self._newRect.setRect(qt.QRectF(self._x,
- self._y,
- self._w + deltax,
- self._h + deltay))
+ self._newRect.setRect(
+ qt.QRectF(self._x, self._y, self._w + deltax, self._h + deltay)
+ )
qt.QGraphicsRectItem.mouseMoveEvent(self, event)
def mouseReleaseEvent(self, event):
@@ -649,8 +679,7 @@ class _GraphicsResizeRectItem(qt.QGraphicsRectItem):
# apply the scale to the previous transformation matrix
previousTransform = parent.transform()
- parent.setTransform(
- previousTransform.scale(scalex, scaley))
+ parent.setTransform(previousTransform.scale(scalex, scaley))
self.scene().removeItem(self._newRect)
self._newRect = None
@@ -658,8 +687,7 @@ class _GraphicsResizeRectItem(qt.QGraphicsRectItem):
def main():
- """
- """
+ """ """
if len(sys.argv) < 2:
print("give an image file as parameter please.")
sys.exit(1)
@@ -678,19 +706,20 @@ def main():
if filename[-3:] == "svg":
item = qt.QSvgRenderer(filename, w.page)
- w.addSvgItem(item, title=filename,
- comment=comment, commentPosition="CENTER")
+ w.addSvgItem(item, title=filename, comment=comment, commentPosition="CENTER")
else:
- w.addPixmap(qt.QPixmap.fromImage(qt.QImage(filename)),
- title=filename,
- comment=comment,
- commentPosition="CENTER")
+ w.addPixmap(
+ qt.QPixmap.fromImage(qt.QImage(filename)),
+ title=filename,
+ comment=comment,
+ commentPosition="CENTER",
+ )
w.addImage(qt.QImage(filename), comment=comment, commentPosition="LEFT")
sys.exit(w.exec())
-if __name__ == '__main__':
+if __name__ == "__main__":
a = qt.QApplication(sys.argv)
main()
a.exec()
diff --git a/src/silx/gui/widgets/RangeSlider.py b/src/silx/gui/widgets/RangeSlider.py
index 4db0470..c96ae14 100644
--- a/src/silx/gui/widgets/RangeSlider.py
+++ b/src/silx/gui/widgets/RangeSlider.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2015-2021 European Synchrotron Radiation Facility
+# Copyright (c) 2015-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -29,7 +29,7 @@
__authors__ = ["D. Naudet", "T. Vincent"]
__license__ = "MIT"
-__date__ = "26/11/2018"
+__date__ = "14/12/2023"
import numpy as numpy
@@ -91,10 +91,10 @@ class RangeSlider(qt.QWidget):
def __init__(self, parent=None):
self.__pixmap = None
self.__positionCount = None
- self.__firstValue = 0.
- self.__secondValue = 1.
- self.__minValue = 0.
- self.__maxValue = 1.
+ self.__firstValue = 0.0
+ self.__secondValue = 1.0
+ self.__minValue = 0.0
+ self.__maxValue = 1.0
self.__hoverRect = qt.QRect()
self.__hoverControl = None
@@ -102,8 +102,8 @@ class RangeSlider(qt.QWidget):
self.__moving = None
self.__icons = {
- 'first': icons.getQIcon('previous'),
- 'second': icons.getQIcon('next')
+ "first": icons.getQIcon("previous"),
+ "second": icons.getQIcon("next"),
}
# call the super constructor AFTER defining all members that
@@ -121,8 +121,17 @@ class RangeSlider(qt.QWidget):
def event(self, event):
t = event.type()
- if t == qt.QEvent.HoverEnter or t == qt.QEvent.HoverLeave or t == qt.QEvent.HoverMove:
- return self.__updateHoverControl(event.pos())
+ if (
+ t == qt.QEvent.HoverEnter
+ or t == qt.QEvent.HoverLeave
+ or t == qt.QEvent.HoverMove
+ ):
+ if qt.BINDING in ("PyQt5",):
+ # qt-5
+ return self.__updateHoverControl(event.pos())
+ else:
+ # qt-6
+ return self.__updateHoverControl(event.position().toPoint())
else:
return super(RangeSlider, self).event(event)
@@ -256,8 +265,7 @@ class RangeSlider(qt.QWidget):
:param int first:
:param int second:
"""
- self.setValues(self.__positionToValue(first),
- self.__positionToValue(second))
+ self.setValues(self.__positionToValue(first), self.__positionToValue(second))
# Value (float) API
@@ -500,15 +508,25 @@ class RangeSlider(qt.QWidget):
self.setGroovePixmap(qpixmap)
# Handle interaction
+ def _mouseEventPosition(self, event):
+ if qt.BINDING in ("PyQt5",):
+ # qt-5 returns QPoint
+ position = event.pos()
+ else:
+ # qt-6 returns QPointF
+ # convert it to QPoint
+ position = event.position().toPoint()
+ return position
def mousePressEvent(self, event):
super(RangeSlider, self).mousePressEvent(event)
if event.buttons() == qt.Qt.LeftButton:
picked = None
- for name in ('first', 'second'):
+ for name in ("first", "second"):
area = self.__sliderRect(name)
- if area.contains(event.pos()):
+ position = self._mouseEventPosition(event)
+ if area.contains(position):
picked = name
break
@@ -520,12 +538,13 @@ class RangeSlider(qt.QWidget):
super(RangeSlider, self).mouseMoveEvent(event)
if self.__moving is not None:
+ event_pos = self._mouseEventPosition(event)
delta = self._SLIDER_WIDTH // 2
- if self.__moving == 'first':
- position = self.__xPixelToPosition(event.pos().x() + delta)
+ if self.__moving == "first":
+ position = self.__xPixelToPosition(event_pos.x() + delta)
self.setFirstPosition(position)
else:
- position = self.__xPixelToPosition(event.pos().x() - delta)
+ position = self.__xPixelToPosition(event_pos.x() - delta)
self.setSecondPosition(position)
def mouseReleaseEvent(self, event):
@@ -545,13 +564,13 @@ class RangeSlider(qt.QWidget):
key = event.key()
if event.modifiers() == qt.Qt.NoModifier and self.__focus is not None:
if key in (qt.Qt.Key_Left, qt.Qt.Key_Down):
- if self.__focus == 'first':
+ if self.__focus == "first":
self.setFirstPosition(self.getFirstPosition() - 1)
else:
self.setSecondPosition(self.getSecondPosition() - 1)
return # accept event
elif key in (qt.Qt.Key_Right, qt.Qt.Key_Up):
- if self.__focus == 'first':
+ if self.__focus == "first":
self.setFirstPosition(self.getFirstPosition() + 1)
else:
self.setSecondPosition(self.getSecondPosition() + 1)
@@ -565,8 +584,10 @@ class RangeSlider(qt.QWidget):
super(RangeSlider, self).resizeEvent(event)
# If no step, signal position update when width change
- if (self.getPositionCount() is None and
- event.size().width() != event.oldSize().width()):
+ if (
+ self.getPositionCount() is None
+ and event.size().width() != event.oldSize().width()
+ ):
self.sigPositionChanged.emit(*self.getPositions())
# Handle repaint
@@ -589,15 +610,15 @@ class RangeSlider(qt.QWidget):
:rtype: QRect
:raise ValueError: If wrong name
"""
- assert name in ('first', 'second')
- if name == 'first':
- offset = - self._SLIDER_WIDTH
+ assert name in ("first", "second")
+ if name == "first":
+ offset = -self._SLIDER_WIDTH
position = self.getFirstPosition()
- elif name == 'second':
+ elif name == "second":
offset = 0
position = self.getSecondPosition()
else:
- raise ValueError('Unknown name')
+ raise ValueError("Unknown name")
sliderArea = self.__sliderAreaRect()
@@ -605,26 +626,20 @@ class RangeSlider(qt.QWidget):
xOffset = int((sliderArea.width() - 1) * position / maxPos)
xPos = sliderArea.left() + xOffset + offset
- return qt.QRect(xPos,
- sliderArea.top(),
- self._SLIDER_WIDTH,
- sliderArea.height())
+ return qt.QRect(xPos, sliderArea.top(), self._SLIDER_WIDTH, sliderArea.height())
def __drawArea(self):
- return self.rect().adjusted(self._SLIDER_WIDTH, 0,
- -self._SLIDER_WIDTH, 0)
+ return self.rect().adjusted(self._SLIDER_WIDTH, 0, -self._SLIDER_WIDTH, 0)
def __sliderAreaRect(self):
- return self.__drawArea().adjusted(self._SLIDER_WIDTH // 2,
- 0,
- -self._SLIDER_WIDTH // 2 + 1,
- 0)
+ return self.__drawArea().adjusted(
+ self._SLIDER_WIDTH // 2, 0, -self._SLIDER_WIDTH // 2 + 1, 0
+ )
def __pixMapRect(self):
- return self.__sliderAreaRect().adjusted(0,
- self._PIXMAP_VOFFSET,
- -1,
- -self._PIXMAP_VOFFSET)
+ return self.__sliderAreaRect().adjusted(
+ 0, self._PIXMAP_VOFFSET, -1, -self._PIXMAP_VOFFSET
+ )
def paintEvent(self, event):
painter = qt.QPainter(self)
@@ -638,12 +653,10 @@ class RangeSlider(qt.QWidget):
option = qt.QStyleOptionProgressBar()
option.initFrom(self)
option.rect = area
- option.state = (qt.QStyle.State_Enabled if self.isEnabled()
- else qt.QStyle.State_None)
- style.drawControl(qt.QStyle.CE_ProgressBarGroove,
- option,
- painter,
- self)
+ option.state = (
+ qt.QStyle.State_Enabled if self.isEnabled() else qt.QStyle.State_None
+ )
+ style.drawControl(qt.QStyle.CE_ProgressBarGroove, option, painter, self)
painter.save()
pen = painter.pen()
@@ -654,13 +667,13 @@ class RangeSlider(qt.QWidget):
painter.restore()
if self.isEnabled():
- rect = area.adjusted(self._SLIDER_WIDTH // 2,
- self._PIXMAP_VOFFSET,
- -self._SLIDER_WIDTH // 2,
- -self._PIXMAP_VOFFSET + 1)
- painter.drawPixmap(rect,
- self.__pixmap,
- self.__pixmap.rect())
+ rect = area.adjusted(
+ self._SLIDER_WIDTH // 2,
+ self._PIXMAP_VOFFSET,
+ -self._SLIDER_WIDTH // 2,
+ -self._PIXMAP_VOFFSET + 1,
+ )
+ painter.drawPixmap(rect, self.__pixmap, self.__pixmap.rect())
else:
option = StyleOptionRangeSlider()
option.initFrom(self)
@@ -671,8 +684,9 @@ class RangeSlider(qt.QWidget):
option.handlerRect2 = self.__sliderRect("second")
option.minimum = self.__minValue
option.maximum = self.__maxValue
- option.state = (qt.QStyle.State_Enabled if self.isEnabled()
- else qt.QStyle.State_None)
+ option.state = (
+ qt.QStyle.State_Enabled if self.isEnabled() else qt.QStyle.State_None
+ )
if self.__hoverControl == "groove":
option.state |= qt.QStyle.State_MouseOver
elif option.state & qt.QStyle.State_MouseOver:
@@ -682,7 +696,7 @@ class RangeSlider(qt.QWidget):
# Avoid glitch when moving handles
hoverControl = self.__moving or self.__hoverControl
- for name in ('first', 'second'):
+ for name in ("first", "second"):
rect = self.__sliderRect(name)
option = qt.QStyleOptionButton()
option.initFrom(self)
@@ -697,8 +711,7 @@ class RangeSlider(qt.QWidget):
elif option.state & qt.QStyle.State_HasFocus:
option.state ^= qt.QStyle.State_HasFocus
option.rect = rect
- style.drawControl(
- qt.QStyle.CE_PushButton, option, painter, self)
+ style.drawControl(qt.QStyle.CE_PushButton, option, painter, self)
def sizeHint(self):
return qt.QSize(200, self.minimumHeight())
@@ -731,17 +744,24 @@ class RangeSlider(qt.QWidget):
buttonColor = option.palette.button().color()
val = qt.qGray(buttonColor.rgb())
buttonColor = buttonColor.lighter(100 + max(1, (180 - val) // 6))
- buttonColor.setHsv(buttonColor.hue(), (buttonColor.saturation() * 3) // 4, buttonColor.value())
+ buttonColor.setHsv(
+ buttonColor.hue(), (buttonColor.saturation() * 3) // 4, buttonColor.value()
+ )
grooveColor = qt.QColor()
- grooveColor.setHsv(buttonColor.hue(),
- min(255, (int)(buttonColor.saturation())),
- min(255, (int)(buttonColor.value() * 0.9)))
+ grooveColor.setHsv(
+ buttonColor.hue(),
+ min(255, (int)(buttonColor.saturation())),
+ min(255, (int)(buttonColor.value() * 0.9)),
+ )
selectedInnerContrastLine = qt.QColor(255, 255, 255, 30)
outline = option.palette.color(qt.QPalette.Window).darker(140)
- if (option.state & qt.QStyle.State_HasFocus and option.state & qt.QStyle.State_KeyboardFocusChange):
+ if (
+ option.state & qt.QStyle.State_HasFocus
+ and option.state & qt.QStyle.State_KeyboardFocusChange
+ ):
outline = highlight.darker(125)
if outline.value() > 160:
outline.setHsl(highlight.hue(), highlight.saturation(), 160)
@@ -760,7 +780,9 @@ class RangeSlider(qt.QWidget):
# Draw slider background for the value
gradient = qt.QLinearGradient()
gradient.setStart(selectedRangeRect.center().x(), selectedRangeRect.top())
- gradient.setFinalStop(selectedRangeRect.center().x(), selectedRangeRect.bottom())
+ gradient.setFinalStop(
+ selectedRangeRect.center().x(), selectedRangeRect.bottom()
+ )
painter.setRenderHint(qt.QPainter.Antialiasing, True)
painter.setPen(qt.QPen(selectedOutline))
gradient.setColorAt(0, activeHighlight)
diff --git a/src/silx/gui/widgets/StackedProgressBar.py b/src/silx/gui/widgets/StackedProgressBar.py
new file mode 100644
index 0000000..87a5896
--- /dev/null
+++ b/src/silx/gui/widgets/StackedProgressBar.py
@@ -0,0 +1,314 @@
+# /*##########################################################################
+#
+# Copyright (c) 2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+
+from __future__ import annotations
+
+from typing import NamedTuple, Any, ValuesView
+from silx.gui import qt
+
+
+class ProgressItem(NamedTuple):
+ """Item storing the state of a stacked progress item"""
+
+ value: int
+ """Progression of the item"""
+
+ visible: bool
+ """Is the item displayed"""
+
+ color: qt.QColor
+ """Color of the progress"""
+
+ striped: bool
+ """If true, apply a stripe color to the gradiant"""
+
+ animated: bool
+ """If true, the stripe is animated"""
+
+ toolTip: str
+ """Tool tip of this item"""
+
+ userData: Any
+ """Any user data"""
+
+
+class _UndefinedType:
+ pass
+
+
+_Undefined = _UndefinedType()
+
+
+class StackedProgressBar(qt.QProgressBar):
+ """
+ Multiple stacked progress bar in single component
+ """
+
+ def __init__(self, parent: qt.Qwidget | None = None):
+ super().__init__(parent=parent)
+ self.__stacks: dict[str, ProgressItem] = {}
+ self._animated: int = 0
+ self._timer = qt.QTimer(self)
+ self._timer.setInterval(80)
+ self._timer.timeout.connect(self._tick)
+ self._spacing: int = 0
+ self._spacingCollapsible: bool = True
+
+ def _tick(self):
+ self._animated += 2
+ self.update()
+
+ def setSpacing(self, spacing: int):
+ """Spacing between items, in pixels"""
+ if self._spacing == spacing:
+ return
+ self._spacing = spacing
+ self.update()
+
+ def spacing(self) -> int:
+ return self._spacing
+
+ def setSpacingCollapsible(self, collapse: bool):
+ """
+ Set whether consecutive spacing should be collapsed.
+
+ It can be useful to disable that to ensure pixel perfect
+ rendering is some use cases.
+
+
+ By default, this property is true.
+ """
+ if self._spacingCollapsible == collapse:
+ return
+ self._spacingCollapsible = collapse
+ self.update()
+
+ def spacingCollapsible(self) -> bool:
+ return self._spacingCollapsible
+
+ def clear(self):
+ """Remove every stacked items from the widget"""
+ if len(self.__stacks) == 0:
+ return
+ self.__stacks.clear()
+ self.update()
+
+ def setProgressItem(
+ self,
+ name: str,
+ value: int | None | _UndefinedType = _Undefined,
+ visible: bool | _UndefinedType = _Undefined,
+ color: qt.QColor | None | _UndefinedType = _Undefined,
+ striped: bool | _UndefinedType = _Undefined,
+ animated: bool | _UndefinedType = _Undefined,
+ toolTip: str | None | _UndefinedType = _Undefined,
+ userData: Any = _Undefined,
+ ):
+ """Add or update a stacked items by its name"""
+
+ previousItem = self.__stacks.get(name)
+
+ if previousItem is not None:
+ if value is _Undefined:
+ value = previousItem.value
+ if visible is _Undefined:
+ visible = previousItem.visible
+ if striped is _Undefined:
+ striped = previousItem.striped
+ if color is _Undefined:
+ color = previousItem.color
+ if toolTip is _Undefined:
+ toolTip = previousItem.toolTip
+ if animated is _Undefined:
+ animated = previousItem.animated
+ if userData is _Undefined:
+ userData = previousItem.userData
+ else:
+ if value is _Undefined:
+ value = 0
+ if visible is _Undefined:
+ visible = True
+ if striped is _Undefined:
+ striped = False
+ if color is _Undefined:
+ color = qt.QColor()
+ if toolTip is _Undefined:
+ toolTip = ""
+ if animated is _Undefined:
+ animated = False
+ if userData is _Undefined:
+ userData = None
+
+ newItem = ProgressItem(
+ value=value,
+ visible=visible,
+ color=color,
+ striped=striped,
+ animated=animated,
+ toolTip=toolTip,
+ userData=userData,
+ )
+ if previousItem == newItem:
+ return
+ self.__stacks[name] = newItem
+ animated = any([s.animated for s in self.__stacks.values()])
+ self._setAnimated(animated)
+ self.update()
+
+ def _setAnimated(self, animated: bool):
+ if animated == self._timer.isActive():
+ return
+ if animated:
+ self._timer.start()
+ else:
+ self._timer.stop()
+
+ def removeProgressItem(self, name: str):
+ """Remove a stacked item by its name"""
+ s = self.__stacks.pop(name, None)
+ if s is None:
+ return
+ self.update()
+
+ def _brushFromProgressItem(self, item: ProgressItem) -> qt.QPalette | None:
+ if item.color is None:
+ return None
+
+ palette = qt.QPalette()
+ color = qt.QColor(item.color)
+
+ if item.striped:
+ if item.animated:
+ delta = self._animated
+ else:
+ delta = 0
+ color2 = color.lighter(120)
+ shadowGradient = qt.QLinearGradient()
+ shadowGradient.setSpread(qt.QGradient.RepeatSpread)
+ shadowGradient.setStart(-delta, 0)
+ shadowGradient.setFinalStop(8 - delta, -8)
+ shadowGradient.setColorAt(0.0, color)
+ shadowGradient.setColorAt(0.5, color)
+ shadowGradient.setColorAt(0.50001, color2)
+ shadowGradient.setColorAt(1.0, color2)
+ brush = qt.QBrush(shadowGradient)
+ palette.setBrush(qt.QPalette.Highlight, brush)
+ palette.setBrush(qt.QPalette.Window, color2)
+ else:
+ palette.setColor(qt.QPalette.Highlight, color)
+
+ return palette
+
+ def paintEvent(self, event):
+ painter = qt.QStylePainter(self)
+ opt = qt.QStyleOptionProgressBar()
+ self.initStyleOption(opt)
+ painter.drawControl(qt.QStyle.CE_ProgressBarGroove, opt)
+ self._drawProgressItems(painter, self.__stacks.values())
+
+ def _drawProgressItems(self, painter: qt.QPainter, items: ValuesView[ProgressItem]):
+ opt = qt.QStyleOptionProgressBar()
+ self.initStyleOption(opt)
+
+ visibleItems = [i for i in items if i.value and i.visible]
+ xpos: int = 0
+ w = opt.rect.width()
+ if self._spacingCollapsible:
+ cumspacing = max(0, len(visibleItems) - 1) * self._spacing
+ w -= cumspacing
+ vw = opt.maximum - opt.minimum
+ opt.minimum = 0
+ opt.maximum = w
+
+ for item in visibleItems:
+ xwidth = int(item.value * w / vw)
+ opt.progress = xwidth * 2
+ palette = self._brushFromProgressItem(item)
+ if palette is not None:
+ opt.palette = palette
+ self._drawProgressItem(painter, opt, xpos, xwidth)
+ xpos += xwidth + self._spacing
+
+ def _drawProgressItem(
+ self,
+ painter: qt.QPainter,
+ option: qt.QStyleOptionProgressBar,
+ xpos: int,
+ xwidth: int,
+ ):
+ if xwidth == 0:
+ return
+ rect: qt.QRect = option.rect
+ style = self.style()
+
+ if option.minimum == 0 and option.maximum == 0:
+ return
+ x0 = rect.x() + 3
+ y0 = rect.y()
+
+ h = rect.height()
+ w = rect.width()
+ xmaxwith = min(x0 + xpos + xwidth, w - 1) - x0 - xpos
+ if xmaxwith < 0:
+ return
+ rect = qt.QRect(x0 + xpos, y0, xmaxwith, h)
+ opt = qt.QStyleOptionProgressBar()
+ opt.state = qt.QStyle.State_None
+ margin = 1
+ opt.rect = rect.marginsAdded(qt.QMargins(margin, margin, margin, margin))
+ opt.palette = option.palette
+ style.drawPrimitive(qt.QStyle.PE_IndicatorProgressChunk, opt, painter, self)
+
+ def getProgressItemByPosition(self, pos: qt.QPoint) -> ProgressItem | None:
+ """Returns the stacked item at a position of the component."""
+ minimum = self.minimum()
+ maximum = self.maximum()
+ vRange = maximum - minimum
+ w = self.width()
+ v = pos.x() * vRange / w
+ current = 0
+ for item in self.__stacks.values():
+ if not item.visible:
+ continue
+ current += item.value
+ if v < current:
+ return item
+ return None
+
+ def tooltipFromProgressItem(self, item: ProgressItem) -> str | None:
+ """Returns the tooltip to display over an item.
+
+ It is triggered when the tooltip have to be displayed.
+ """
+ return item.toolTip
+
+ def event(self, event: qt.QEvent):
+ if event.type() == qt.QEvent.ToolTip:
+ item = self.getProgressItemByPosition(event.pos())
+ if item is not None:
+ toolTip = self.tooltipFromProgressItem(item)
+ if toolTip:
+ qt.QToolTip.showText(event.globalPos(), toolTip, self)
+ return True
+ return super().event(event)
diff --git a/src/silx/gui/widgets/TableWidget.py b/src/silx/gui/widgets/TableWidget.py
index 9bada5e..7f6c1eb 100644
--- a/src/silx/gui/widgets/TableWidget.py
+++ b/src/silx/gui/widgets/TableWidget.py
@@ -82,10 +82,12 @@ class CopySelectedCellsAction(qt.QAction):
:param table: :class:`QTableView` to which this action belongs.
"""
+
def __init__(self, table):
if not isinstance(table, qt.QTableView):
- raise ValueError('CopySelectedCellsAction must be initialised ' +
- 'with a QTableWidget.')
+ raise ValueError(
+ "CopySelectedCellsAction must be initialised " + "with a QTableWidget."
+ )
super(CopySelectedCellsAction, self).__init__(table)
self.setText("Copy selection")
self.setToolTip("Copy selected cells into the clipboard.")
@@ -125,11 +127,11 @@ class CopySelectedCellsAction(qt.QAction):
data_model.setData(index, "")
copied_text += col_separator
# remove the right-most tabulation
- copied_text = copied_text[:-len(col_separator)]
+ copied_text = copied_text[: -len(col_separator)]
# add a newline
copied_text += row_separator
# remove final newline
- copied_text = copied_text[:-len(row_separator)]
+ copied_text = copied_text[: -len(row_separator)]
# put this text into clipboard
qapp = qt.QApplication.instance()
@@ -146,10 +148,12 @@ class CopyAllCellsAction(qt.QAction):
:param table: :class:`QTableView` to which this action belongs.
"""
+
def __init__(self, table):
if not isinstance(table, qt.QTableView):
- raise ValueError('CopyAllCellsAction must be initialised ' +
- 'with a QTableWidget.')
+ raise ValueError(
+ "CopyAllCellsAction must be initialised " + "with a QTableWidget."
+ )
super(CopyAllCellsAction, self).__init__(table)
self.setText("Copy all")
self.setToolTip("Copy all cells into the clipboard.")
@@ -175,11 +179,11 @@ class CopyAllCellsAction(qt.QAction):
data_model.setData(index, "")
copied_text += col_separator
# remove the right-most tabulation
- copied_text = copied_text[:-len(col_separator)]
+ copied_text = copied_text[: -len(col_separator)]
# add a newline
copied_text += row_separator
# remove final newline
- copied_text = copied_text[:-len(row_separator)]
+ copied_text = copied_text[: -len(row_separator)]
# put this text into clipboard
qapp = qt.QApplication.instance()
@@ -206,6 +210,7 @@ class CutSelectedCellsAction(CopySelectedCellsAction):
corresponding cell in the origin table.
:param table: :class:`QTableView` to which this action belongs."""
+
def __init__(self, table):
super(CutSelectedCellsAction, self).__init__(table)
self.setText("Cut selection")
@@ -228,6 +233,7 @@ class CutAllCellsAction(CopyAllCellsAction):
newline characters.
:param table: :class:`QTableView` to which this action belongs."""
+
def __init__(self, table):
super(CutAllCellsAction, self).__init__(table)
self.setText("Cut all")
@@ -266,17 +272,21 @@ class PasteCellsAction(qt.QAction):
:param table: :class:`QTableView` to which this action belongs.
"""
+
def __init__(self, table):
if not isinstance(table, qt.QTableView):
- raise ValueError('PasteCellsAction must be initialised ' +
- 'with a QTableWidget.')
+ raise ValueError(
+ "PasteCellsAction must be initialised " + "with a QTableWidget."
+ )
super(PasteCellsAction, self).__init__(table)
self.table = table
self.setText("Paste")
self.setShortcut(qt.QKeySequence.Paste)
self.setShortcutContext(qt.Qt.WidgetShortcut)
- self.setToolTip("Paste data. The selected cell is the top-left" +
- "corner of the paste area.")
+ self.setToolTip(
+ "Paste data. The selected cell is the top-left"
+ + "corner of the paste area."
+ )
self.triggered.connect(self.pasteCellFromClipboard)
def pasteCellFromClipboard(self):
@@ -309,8 +319,10 @@ class PasteCellsAction(qt.QAction):
target_row = selected_row + row_offset
target_col = selected_col + col_offset
- if target_row >= data_model.rowCount() or\
- target_col >= data_model.columnCount():
+ if (
+ target_row >= data_model.rowCount()
+ or target_col >= data_model.columnCount()
+ ):
out_of_range_cells += 1
continue
@@ -348,10 +360,12 @@ class CopySingleCellAction(qt.QAction):
:param table: :class:`QTableView` to which this action belongs.
"""
+
def __init__(self, table):
if not isinstance(table, qt.QTableView):
- raise ValueError('CopySingleCellAction must be initialised ' +
- 'with a QTableWidget.')
+ raise ValueError(
+ "CopySingleCellAction must be initialised " + "with a QTableWidget."
+ )
super(CopySingleCellAction, self).__init__(table)
self.setText("Copy cell")
self.setToolTip("Copy cell content into the clipboard.")
@@ -359,8 +373,7 @@ class CopySingleCellAction(qt.QAction):
self.table = table
def copyCellToClipboard(self):
- """
- """
+ """ """
cell_text = self.table._text_last_cell_clicked
if cell_text is None:
return
@@ -392,6 +405,7 @@ class TableWidget(qt.QTableWidget):
:param bool cut: Enable cut action
:param bool paste: Enable paste action
"""
+
def __init__(self, parent=None, cut=False, paste=False):
super(TableWidget, self).__init__(parent)
self._text_last_cell_clicked = None
@@ -457,8 +471,10 @@ class TableWidget(qt.QTableWidget):
self.cutSelectedCellsAction.setEnabled(False)
if self.copySingleCellAction is None:
self.copySingleCellAction = CopySingleCellAction(self)
- self.insertAction(self.copySelectedCellsAction, # before first action
- self.copySingleCellAction)
+ self.insertAction(
+ self.copySelectedCellsAction, # before first action
+ self.copySingleCellAction,
+ )
self.copySingleCellAction.setVisible(True)
self.copySingleCellAction.setEnabled(True)
else:
@@ -498,6 +514,7 @@ class TableView(qt.QTableView):
:param bool cut: Enable cut action
:param bool paste: Enable paste action
"""
+
def __init__(self, parent=None, cut=False, paste=False):
super(TableView, self).__init__(parent)
self._text_last_cell_clicked = None
@@ -514,7 +531,7 @@ class TableView(qt.QTableView):
def mousePressEvent(self, event):
qindex = self.indexAt(event.pos())
- if self.copyAllCellsAction is not None: # model was set
+ if self.copyAllCellsAction is not None: # model was set
self._text_last_cell_clicked = self.model().data(qindex)
super(TableView, self).mousePressEvent(event)
@@ -567,8 +584,7 @@ class TableView(qt.QTableView):
# compare action type and parent widget with those of existing actions
for existing_action in self.actions():
if type(action) == type(existing_action):
- if hasattr(action, "table") and\
- action.table is existing_action.table:
+ if hasattr(action, "table") and action.table is existing_action.table:
return None
super(TableView, self).addAction(action)
@@ -587,8 +603,10 @@ class TableView(qt.QTableView):
self.cutSelectedCellsAction.setEnabled(False)
if self.copySingleCellAction is None:
self.copySingleCellAction = CopySingleCellAction(self)
- self.insertAction(self.copySelectedCellsAction, # before first action
- self.copySingleCellAction)
+ self.insertAction(
+ self.copySelectedCellsAction, # before first action
+ self.copySingleCellAction,
+ )
self.copySingleCellAction.setVisible(True)
self.copySingleCellAction.setEnabled(True)
else:
diff --git a/src/silx/gui/widgets/ThreadPoolPushButton.py b/src/silx/gui/widgets/ThreadPoolPushButton.py
index 8a1d428..12eb95b 100644
--- a/src/silx/gui/widgets/ThreadPoolPushButton.py
+++ b/src/silx/gui/widgets/ThreadPoolPushButton.py
@@ -57,7 +57,9 @@ class _Wrapper(qt.QRunnable):
except Exception as e:
module = self.__callable.__module__
name = self.__callable.__name__
- _logger.error("Error while executing callable %s.%s.", module, name, exc_info=True)
+ _logger.error(
+ "Error while executing callable %s.%s.", module, name, exc_info=True
+ )
holder.failed.emit(e)
finally:
holder.finished.emit()
diff --git a/src/silx/gui/widgets/UrlList.py b/src/silx/gui/widgets/UrlList.py
new file mode 100644
index 0000000..3800d10
--- /dev/null
+++ b/src/silx/gui/widgets/UrlList.py
@@ -0,0 +1,139 @@
+# /*##########################################################################
+#
+# Copyright (c) 2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+
+from __future__ import annotations
+
+import typing
+import logging
+from collections.abc import Iterable
+from silx.io.url import DataUrl
+from silx.gui import qt
+from silx.utils.deprecation import deprecated
+
+_logger = logging.getLogger(__name__)
+
+
+class UrlList(qt.QListWidget):
+ """List of URLs with user selection"""
+
+ sigCurrentUrlChanged = qt.Signal(str)
+ """Signal emitted when the active/current URL has changed.
+
+ This signal emits the empty string when there is no longer an active URL.
+ """
+
+ sigUrlRemoved = qt.Signal(str)
+ """Signal emit when an url is removed from the URL list.
+
+ Provides the url (DataUrl) as a string
+ """
+
+ def __init__(self, parent=None):
+ super().__init__(parent)
+ self._editable = False
+ # are we in 'editable' mode: for now if true then we can remove some items from the list
+
+ # menu to be triggered when in edition from right-click
+ self._menu = qt.QMenu()
+ self._removeAction = qt.QAction(text="Remove", parent=self)
+ self._removeAction.setShortcuts(
+ [
+ # qt.Qt.Key_Delete,
+ qt.QKeySequence.Delete,
+ ]
+ )
+ self._menu.addAction(self._removeAction)
+
+ # connect signal / Slot
+ self.currentItemChanged.connect(self._notifyCurrentUrlChanged)
+
+ def setEditable(self, editable: bool):
+ """Toggle whether the user can remove some URLs from the list"""
+ if editable != self._editable:
+ self._editable = editable
+ # discusable choice: should we change the selection mode ? No much meaning
+ # to be in ExtendedSelection if we are not in editable mode. But does it has more
+ # meaning to change the selection mode ?
+ if editable:
+ self._removeAction.triggered.connect(self._removeSelectedItems)
+ self.addAction(self._removeAction)
+ else:
+ self._removeAction.triggered.disconnect(self._removeSelectedItems)
+ self.removeAction(self._removeAction)
+
+ @deprecated(replacement="addUrls", since_version="2.0")
+ def setUrls(self, urls: Iterable[DataUrl]) -> None:
+ self.addUrls(urls)
+
+ def addUrls(self, urls: Iterable[DataUrl]) -> None:
+ """Append multiple DataUrl to the list"""
+ self.addItems([url.path() for url in urls])
+
+ def removeUrl(self, url: str):
+ """Remove given URL from the list"""
+ sel_items = self.findItems(url, qt.Qt.MatchExactly)
+ if len(sel_items) > 0:
+ assert len(sel_items) == 0, "at most one item expected"
+ self.removeItemWidget(sel_items[0])
+
+ def _notifyCurrentUrlChanged(self, current, previous):
+ if current is None:
+ self.sigCurrentUrlChanged.emit("")
+ else:
+ self.sigCurrentUrlChanged.emit(current.text())
+
+ def setUrl(self, url: typing.Optional[DataUrl]) -> None:
+ """Set the current URL.
+
+ :param url: The new selected URL. Use `None` to clear the selection.
+ """
+ if url is None:
+ self.clearSelection()
+ self.sigCurrentUrlChanged.emit("")
+ else:
+ assert isinstance(url, DataUrl)
+ sel_items = self.findItems(url.path(), qt.Qt.MatchExactly)
+ if sel_items is None:
+ _logger.warning(url.path(), " is not registered in the list.")
+ elif len(sel_items) > 0:
+ item = sel_items[0]
+ self.setCurrentItem(item)
+ self.sigCurrentUrlChanged.emit(item.text())
+
+ def _removeSelectedItems(self):
+ if not self._editable:
+ raise ValueError("UrlList is not set as 'editable'")
+ urls = []
+ for item in self.selectedItems():
+ url = item.text()
+ self.takeItem(self.row(item))
+ urls.append(url)
+ # as the connected slot of 'sigUrlRemoved' can modify the items, better handling all at the end
+ for url in urls:
+ self.sigUrlRemoved.emit(url)
+
+ def contextMenuEvent(self, event):
+ if self._editable:
+ globalPos = self.mapToGlobal(event.pos())
+ self._menu.exec_(globalPos)
diff --git a/src/silx/gui/widgets/UrlSelectionTable.py b/src/silx/gui/widgets/UrlSelectionTable.py
index bc75d32..051ff32 100644
--- a/src/silx/gui/widgets/UrlSelectionTable.py
+++ b/src/silx/gui/widgets/UrlSelectionTable.py
@@ -1,5 +1,5 @@
# /*##########################################################################
-# Copyright (C) 2017-2021 European Synchrotron Radiation Facility
+# Copyright (C) 2017-2023 European Synchrotron Radiation Facility
#
# This file is part of the PyMca X-ray Fluorescence Toolkit developed at
# the ESRF by the Software group.
@@ -29,25 +29,88 @@ __author__ = ["H. Payno"]
__license__ = "MIT"
__date__ = "19/03/2018"
-from silx.gui import qt
-from collections import OrderedDict
-from silx.gui.widgets.TableWidget import TableWidget
-from silx.io.url import DataUrl
+import os
import functools
import logging
-import os
+from silx.gui import qt
+from silx.gui import utils as qtutils
+from silx.gui.widgets.TableWidget import TableWidget
+from silx.io.url import DataUrl, slice_sequence_to_string
+from silx.utils.deprecation import deprecated, deprecated_warning
+from silx.gui import constants
logger = logging.getLogger(__name__)
+class _IntegratedRadioButton(qt.QWidget):
+ """RadioButton integrated in the QTableWidget as a centered widget"""
+
+ toggled = qt.Signal()
+
+ def __init__(self, parent=None):
+ qt.QWidget.__init__(self, parent=parent)
+ self.setContentsMargins(1, 1, 1, 1)
+ layout = qt.QHBoxLayout(self)
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.setSpacing(1)
+
+ self._radio = qt.QRadioButton(parent=self)
+ self._radio.setObjectName("radio")
+ self._radio.setAutoExclusive(False)
+ self._radio.setMinimumSize(self._radio.minimumSizeHint())
+ self._radio.setMaximumSize(self._radio.minimumSizeHint())
+ self._radio.toggled.connect(self.toggled.emit)
+ layout.addWidget(self._radio)
+ self.setSizePolicy(qt.QSizePolicy.Fixed, qt.QSizePolicy.Fixed)
+
+ def setChecked(self, checked: bool):
+ self._radio.setChecked(checked)
+
+ def isChecked(self) -> bool:
+ return self._radio.isChecked()
+
+
+class _DataUrlItem(qt.QTableWidgetItem):
+ FILENAME = 0
+ DATAPATH = 1
+ SLICE = 2
+
+ def __init__(self, url, display: int):
+ qt.QTableWidgetItem.__init__(self)
+ self._url = url
+ self._display = display
+
+ if self._display == self.FILENAME:
+ text = os.path.basename(self._url.file_path())
+ elif self._display == self.DATAPATH:
+ text = self._url.data_path()
+ elif self._display == self.SLICE:
+ s = self._url.data_slice()
+ if s is not None:
+ text = slice_sequence_to_string(self._url.data_slice())
+ else:
+ text = ""
+ else:
+ raise RuntimeError(f"Unsupported display node: {self._display}")
+
+ toolTip = self._url.path()
+
+ self.setText(text)
+ self.setToolTip(toolTip)
+
+ def dataUrl(self):
+ return self._url
+
+
class UrlSelectionTable(TableWidget):
"""Table used to select the color channel to be displayed for each"""
- COLUMS_INDEX = OrderedDict([
- ('url', 0),
- ('img A', 1),
- ('img B', 2),
- ])
+ FILENAME_COLUMN = 0
+ DATAPATH_COLUMN = 1
+ SLICE_COLUMN = 2
+ IMG_A_COLUMN = 3
+ IMG_B_COLUMN = 4
+ NB_COLUMNS = 5
sigImageAChanged = qt.Signal(str)
"""Signal emitted when the image A change. Param is the image url path"""
@@ -62,12 +125,38 @@ class UrlSelectionTable(TableWidget):
def clear(self):
qt.QTableWidget.clear(self)
self.setRowCount(0)
- self.setColumnCount(len(self.COLUMS_INDEX))
- self.setHorizontalHeaderLabels(list(self.COLUMS_INDEX.keys()))
- self.verticalHeader().hide()
- self.horizontalHeader().setSectionResizeMode(0,
- qt.QHeaderView.Stretch)
+ self.setColumnCount(self.NB_COLUMNS)
+ self.setSelectionBehavior(qt.QAbstractItemView.SelectRows)
+ self.setSelectionMode(qt.QAbstractItemView.NoSelection)
+ item = qt.QTableWidgetItem()
+ item.setText("Filename")
+ item.setToolTip("Filename to the data")
+ self.setHorizontalHeaderItem(self.FILENAME_COLUMN, item)
+ item = qt.QTableWidgetItem()
+ item.setText("Datapath")
+ item.setToolTip("Data path to the dataset")
+ self.setHorizontalHeaderItem(self.DATAPATH_COLUMN, item)
+ item = qt.QTableWidgetItem()
+ item.setText("Slice")
+ item.setToolTip("Slice applied to the dataset")
+ self.setHorizontalHeaderItem(self.SLICE_COLUMN, item)
+ item = qt.QTableWidgetItem()
+ item.setText("A")
+ item.setToolTip("Selected image as A")
+ self.setHorizontalHeaderItem(self.IMG_A_COLUMN, item)
+ item = qt.QTableWidgetItem()
+ item.setText("B")
+ item.setToolTip("Selected image as B")
+ self.setHorizontalHeaderItem(self.IMG_B_COLUMN, item)
+
+ self.verticalHeader().hide()
+ setSectionResizeMode = self.horizontalHeader().setSectionResizeMode
+ setSectionResizeMode(self.FILENAME_COLUMN, qt.QHeaderView.ResizeToContents)
+ setSectionResizeMode(self.DATAPATH_COLUMN, qt.QHeaderView.Stretch)
+ setSectionResizeMode(self.SLICE_COLUMN, qt.QHeaderView.ResizeToContents)
+ setSectionResizeMode(self.IMG_A_COLUMN, qt.QHeaderView.ResizeToContents)
+ setSectionResizeMode(self.IMG_B_COLUMN, qt.QHeaderView.ResizeToContents)
self.setSortingEnabled(True)
self._checkBoxes = {}
@@ -79,11 +168,12 @@ class UrlSelectionTable(TableWidget):
for url in urls:
self.addUrl(url=url)
- def addUrl(self, url, **kwargs):
+ def addUrl(self, url: DataUrl, **kwargs):
"""
+ Append this DataUrl to the end of the list of URLs.
- :param url:
- :param args:
+ :param url:
+ :param args:
:return: index of the created items row
:rtype int
"""
@@ -91,79 +181,167 @@ class UrlSelectionTable(TableWidget):
row = self.rowCount()
self.setRowCount(row + 1)
- _item = qt.QTableWidgetItem()
- _item.setText(os.path.basename(url.path()))
- _item.setFlags(qt.Qt.ItemIsEnabled | qt.Qt.ItemIsSelectable)
- self.setItem(row, self.COLUMS_INDEX['url'], _item)
+ item = _DataUrlItem(url, _DataUrlItem.FILENAME)
+ item.setFlags(qt.Qt.ItemIsEnabled | qt.Qt.ItemIsSelectable)
+ self.setItem(row, self.FILENAME_COLUMN, item)
- widgetImgA = qt.QRadioButton(parent=self)
- widgetImgA.setAutoExclusive(False)
- self.setCellWidget(row, self.COLUMS_INDEX['img A'], widgetImgA)
- callbackImgA = functools.partial(self._activeImgAChanged, url.path())
+ item = _DataUrlItem(url, _DataUrlItem.DATAPATH)
+ item.setFlags(qt.Qt.ItemIsEnabled | qt.Qt.ItemIsSelectable)
+ self.setItem(row, self.DATAPATH_COLUMN, item)
+
+ item = _DataUrlItem(url, _DataUrlItem.SLICE)
+ item.setFlags(qt.Qt.ItemIsEnabled | qt.Qt.ItemIsSelectable)
+ self.setItem(row, self.SLICE_COLUMN, item)
+
+ widgetImgA = _IntegratedRadioButton(parent=self)
+ self.setCellWidget(row, self.IMG_A_COLUMN, widgetImgA)
+ callbackImgA = functools.partial(self._activeImgAChanged, row)
widgetImgA.toggled.connect(callbackImgA)
- widgetImgB = qt.QRadioButton(parent=self)
- widgetImgA.setAutoExclusive(False)
- self.setCellWidget(row, self.COLUMS_INDEX['img B'], widgetImgB)
- callbackImgB = functools.partial(self._activeImgBChanged, url.path())
+ widgetImgB = _IntegratedRadioButton(parent=self)
+ self.setCellWidget(row, self.IMG_B_COLUMN, widgetImgB)
+ callbackImgB = functools.partial(self._activeImgBChanged, row)
widgetImgB.toggled.connect(callbackImgB)
- self._checkBoxes[url.path()] = {'img A': widgetImgA,
- 'img B': widgetImgB}
+ self._checkBoxes[row] = {
+ self.IMG_A_COLUMN: widgetImgA,
+ self.IMG_B_COLUMN: widgetImgB,
+ }
self.resizeColumnsToContents()
return row
- def _activeImgAChanged(self, name):
- self._updatecheckBoxes('img A', name)
- self.sigImageAChanged.emit(name)
+ def _getItemFromUrlPath(self, urlPath: str) -> _DataUrlItem:
+ """Returns the Qt item storing this urlPath, else None"""
+ for r in range(self.rowCount()):
+ item = self.item(r, self.FILENAME_COLUMN)
+ url = item.dataUrl()
+ if url.path() == urlPath:
+ return item
+ return None
+
+ def setError(self, urlPath: str, message: str):
+ """Flag this urlPath with an error in the UI."""
+ item = self._getItemFromUrlPath(urlPath)
+ if item is None:
+ return
+ if message == "":
+ item.setIcon(qt.QIcon())
+ item.setToolTip("")
+ else:
+ style = qt.QApplication.style()
+ icon = style.standardIcon(qt.QStyle.SP_MessageBoxCritical)
+ item.setIcon(icon)
+ item.setToolTip(f"Error: {message}")
- def _activeImgBChanged(self, name):
- self._updatecheckBoxes('img B', name)
- self.sigImageBChanged.emit(name)
+ def _activeImgAChanged(self, row):
+ if self._checkBoxes[row][self.IMG_A_COLUMN].isChecked():
+ self._updateCheckBoxes(self.IMG_A_COLUMN, row)
+ url = self.item(row, self.FILENAME_COLUMN).dataUrl()
+ self.sigImageAChanged.emit(url.path())
+ else:
+ self.sigImageAChanged.emit(None)
- def _updatecheckBoxes(self, whichImg, name):
- assert name in self._checkBoxes
- assert whichImg in self._checkBoxes[name]
- if self._checkBoxes[name][whichImg].isChecked():
- for radioUrl in self._checkBoxes:
- if radioUrl != name:
- self._checkBoxes[radioUrl][whichImg].blockSignals(True)
- self._checkBoxes[radioUrl][whichImg].setChecked(False)
- self._checkBoxes[radioUrl][whichImg].blockSignals(False)
+ def _activeImgBChanged(self, row):
+ if self._checkBoxes[row][self.IMG_B_COLUMN].isChecked():
+ self._updateCheckBoxes(self.IMG_B_COLUMN, row)
+ url = self.item(row, self.FILENAME_COLUMN).dataUrl()
+ self.sigImageBChanged.emit(url.path())
+ else:
+ self.sigImageBChanged.emit(None)
+ def _updateCheckBoxes(self, column, row):
+ for r in range(self.rowCount()):
+ if r == row:
+ continue
+ c = self._checkBoxes[r][column]
+ with qtutils.blockSignals(c):
+ c.setChecked(False)
+
+ @deprecated(
+ replacement="getUrlSelection",
+ since_version="2.0",
+ reason="Conflict with Qt API",
+ )
def getSelection(self):
+ return self.getUrlSelection()
+
+ def setSelection(self, url_img_a, url_img_b):
+ if isinstance(url_img_a, qt.QRect):
+ return super().setSelection(url_img_a, url_img_b)
+ deprecated_warning(
+ "Function",
+ "setSelection",
+ replacement="setUrlSelection",
+ since_version="2.0",
+ reason="Conflict with Qt API",
+ )
+ return self.setUrlSelection(url_img_a, url_img_b)
+
+ def getUrlSelection(self):
"""
:return: url selected for img A and img B.
"""
imgA = imgB = None
- for radioUrl in self._checkBoxes:
- if self._checkBoxes[radioUrl]['img A'].isChecked():
- imgA = radioUrl
- if self._checkBoxes[radioUrl]['img B'].isChecked():
- imgB = radioUrl
+ for row in range(self.rowCount()):
+ url = self.item(row, self.FILENAME_COLUMN).dataUrl()
+ if self._checkBoxes[row][self.IMG_A_COLUMN].isChecked():
+ imgA = url
+ if self._checkBoxes[row][self.IMG_B_COLUMN].isChecked():
+ imgB = url
return imgA, imgB
- def setSelection(self, url_img_a, url_img_b):
+ def setUrlSelection(self, url_img_a, url_img_b):
"""
:param ddict: key: image url, values: list of active channels
"""
- for radioUrl in self._checkBoxes:
- for img in ('img A', 'img B'):
- self._checkBoxes[radioUrl][img].blockSignals(True)
- self._checkBoxes[radioUrl][img].setChecked(False)
- self._checkBoxes[radioUrl][img].blockSignals(False)
-
- self._checkBoxes[radioUrl][img].blockSignals(True)
- self._checkBoxes[url_img_a]['img A'].setChecked(True)
- self._checkBoxes[radioUrl][img].blockSignals(False)
-
- self._checkBoxes[radioUrl][img].blockSignals(True)
- self._checkBoxes[url_img_b]['img B'].setChecked(True)
- self._checkBoxes[radioUrl][img].blockSignals(False)
+ rowA = None
+ rowB = None
+ for row in range(self.rowCount()):
+ for img in (self.IMG_A_COLUMN, self.IMG_B_COLUMN):
+ c = self._checkBoxes[row][img]
+ with qtutils.blockSignals(c):
+ c.setChecked(False)
+ url = self.item(row, self.FILENAME_COLUMN).dataUrl()
+ if url.path() == url_img_a:
+ rowA = row
+ if url.path() == url_img_b:
+ rowB = row
+
+ if rowA is not None:
+ c = self._checkBoxes[rowA][self.IMG_A_COLUMN]
+ with qtutils.blockSignals(c):
+ c.setChecked(True)
+
+ if rowB is not None:
+ c = self._checkBoxes[rowB][self.IMG_B_COLUMN]
+ with qtutils.blockSignals(c):
+ c.setChecked(True)
+
self.sigImageAChanged.emit(url_img_a)
self.sigImageBChanged.emit(url_img_b)
def removeUrl(self, url):
raise NotImplementedError("")
+
+ def supportedDropActions(self):
+ """Inherited method to redefine supported drop actions."""
+ return qt.Qt.CopyAction | qt.Qt.MoveAction
+
+ def mimeTypes(self):
+ """Inherited method to redefine draggable mime types."""
+ return [constants.SILX_URI_MIMETYPE]
+
+ def dropMimeData(
+ self, row: int, column: int, mimedata: qt.QMimeType, action: qt.Qt.DropAction
+ ):
+ """Inherited method to handle a drop operation to this model."""
+ if action == qt.Qt.IgnoreAction:
+ return True
+ if mimedata.hasFormat(constants.SILX_URI_MIMETYPE):
+ urlText = str(mimedata.data(constants.SILX_URI_MIMETYPE), "utf-8")
+ url = DataUrl(urlText)
+ self.addUrl(url)
+ return True
+ return False
diff --git a/src/silx/gui/widgets/WaitingOverlay.py b/src/silx/gui/widgets/WaitingOverlay.py
new file mode 100644
index 0000000..f6872d6
--- /dev/null
+++ b/src/silx/gui/widgets/WaitingOverlay.py
@@ -0,0 +1,111 @@
+import weakref
+from typing import Optional
+from silx.gui.widgets.WaitingPushButton import WaitingPushButton
+from silx.gui import qt
+from silx.gui.qt import inspect as qt_inspect
+from silx.gui.plot import PlotWidget
+
+
+class WaitingOverlay(qt.QWidget):
+ """Widget overlaying another widget with a processing wheel icon.
+
+ :param parent: widget on top of which to display the "processing/waiting wheel"
+ """
+
+ def __init__(self, parent: qt.QWidget) -> None:
+ super().__init__(parent)
+ self.setContentsMargins(0, 0, 0, 0)
+
+ self._waitingButton = WaitingPushButton(self)
+ self._waitingButton.setDown(True)
+ self._waitingButton.setWaiting(True)
+ self._waitingButton.setStyleSheet(
+ "QPushButton { background-color: rgba(150, 150, 150, 40); border: 0px; border-radius: 10px; }"
+ )
+ self._registerParent(parent)
+
+ def text(self) -> str:
+ """Returns displayed text"""
+ return self._waitingButton.text()
+
+ def setText(self, text: str):
+ """Set displayed text"""
+ self._waitingButton.setText(text)
+ self._resize()
+
+ def _listenedWidget(self, parent: qt.QWidget) -> qt.QWidget:
+ """Returns widget to register event filter to according to parent"""
+ if isinstance(parent, PlotWidget):
+ return parent.getWidgetHandle()
+ return parent
+
+ def _backendChanged(self):
+ self._listenedWidget(self.parent()).installEventFilter(self)
+ self._resizeLater()
+
+ def _registerParent(self, parent: Optional[qt.QWidget]):
+ if parent is None:
+ return
+ self._listenedWidget(parent).installEventFilter(self)
+ if isinstance(parent, PlotWidget):
+ parent.sigBackendChanged.connect(self._backendChanged)
+ self._resize()
+
+ def _unregisterParent(self, parent: Optional[qt.QWidget]):
+ if parent is None:
+ return
+ if isinstance(parent, PlotWidget):
+ parent.sigBackendChanged.disconnect(self._backendChanged)
+ self._listenedWidget(parent).removeEventFilter(self)
+
+ def setParent(self, parent: qt.QWidget):
+ self._unregisterParent(self.parent())
+ super().setParent(parent)
+ self._registerParent(parent)
+
+ def showEvent(self, event: qt.QShowEvent):
+ super().showEvent(event)
+ self._waitingButton.setVisible(True)
+
+ def hideEvent(self, event: qt.QHideEvent):
+ super().hideEvent(event)
+ self._waitingButton.setVisible(False)
+
+ def _resize(self):
+ if not qt_inspect.isValid(self):
+ return # For _resizeLater in case the widget has been deleted
+
+ parent = self.parent()
+ if parent is None:
+ return
+
+ size = self._waitingButton.sizeHint()
+ if isinstance(parent, PlotWidget):
+ offset = parent.getWidgetHandle().mapTo(parent, qt.QPoint(0, 0))
+ left, top, width, height = parent.getPlotBoundsInPixels()
+ rect = qt.QRect(
+ qt.QPoint(
+ int(offset.x() + left + width / 2 - size.width() / 2),
+ int(offset.y() + top + height / 2 - size.height() / 2),
+ ),
+ size,
+ )
+ else:
+ position = parent.size()
+ position = (position - size) / 2
+ rect = qt.QRect(qt.QPoint(position.width(), position.height()), size)
+ self.setGeometry(rect)
+ self.raise_()
+
+ def _resizeLater(self):
+ qt.QTimer.singleShot(0, self._resize)
+
+ def eventFilter(self, watched: qt.QWidget, event: qt.QEvent):
+ if event.type() == qt.QEvent.Resize:
+ self._resize()
+ self._resizeLater() # Defer resize for the receiver to have handled it
+ return super().eventFilter(watched, event)
+
+ # expose Waiting push button API
+ def setIconSize(self, size):
+ self._waitingButton.setIconSize(size)
diff --git a/src/silx/gui/widgets/WaitingPushButton.py b/src/silx/gui/widgets/WaitingPushButton.py
index 8bd9ea0..ff31286 100644
--- a/src/silx/gui/widgets/WaitingPushButton.py
+++ b/src/silx/gui/widgets/WaitingPushButton.py
@@ -104,8 +104,10 @@ class WaitingPushButton(qt.QPushButton):
w += self.style().pixelMetric(qt.QStyle.PM_MenuButtonIndicator, opt, self)
contentSize = qt.QSize(w, h)
- sizeHint = self.style().sizeFromContents(qt.QStyle.CT_PushButton, opt, contentSize, self)
- if qt.BINDING in ('PySide2', 'PyQt5'): # Qt6: globalStrut not available
+ sizeHint = self.style().sizeFromContents(
+ qt.QStyle.CT_PushButton, opt, contentSize, self
+ )
+ if qt.BINDING == "PyQt5": # Qt6: globalStrut not available
sizeHint = sizeHint.expandedTo(qt.QApplication.globalStrut())
return sizeHint
@@ -126,7 +128,9 @@ class WaitingPushButton(qt.QPushButton):
"""
return self.__disabled_when_waiting
- disabledWhenWaiting = qt.Property(bool, isDisabledWhenWaiting, setDisabledWhenWaiting)
+ disabledWhenWaiting = qt.Property(
+ bool, isDisabledWhenWaiting, setDisabledWhenWaiting
+ )
"""Property to enable/disable the auto disabled state when the button is waiting."""
def __setWaitingIcon(self, icon):
diff --git a/src/silx/gui/widgets/test/test_boxlayoutdockwidget.py b/src/silx/gui/widgets/test/test_boxlayoutdockwidget.py
index dd0ddf4..45f0152 100644
--- a/src/silx/gui/widgets/test/test_boxlayoutdockwidget.py
+++ b/src/silx/gui/widgets/test/test_boxlayoutdockwidget.py
@@ -27,8 +27,6 @@ __authors__ = ["T. Vincent"]
__license__ = "MIT"
__date__ = "06/03/2018"
-import unittest
-
from silx.gui.widgets.BoxLayoutDockWidget import BoxLayoutDockWidget
from silx.gui import qt
from silx.gui.utils.testutils import TestCaseQt
@@ -53,8 +51,8 @@ class TestBoxLayoutDockWidget(TestCaseQt):
"""Test update of layout direction according to dock area"""
# Create a widget with a QBoxLayout
layout = qt.QBoxLayout(qt.QBoxLayout.LeftToRight)
- layout.addWidget(qt.QLabel('First'))
- layout.addWidget(qt.QLabel('Second'))
+ layout.addWidget(qt.QLabel("First"))
+ layout.addWidget(qt.QLabel("Second"))
widget = qt.QWidget()
widget.setLayout(layout)
diff --git a/src/silx/gui/widgets/test/test_elidedlabel.py b/src/silx/gui/widgets/test/test_elidedlabel.py
index d7e2cdc..fbf63f0 100644
--- a/src/silx/gui/widgets/test/test_elidedlabel.py
+++ b/src/silx/gui/widgets/test/test_elidedlabel.py
@@ -32,7 +32,6 @@ from silx.gui.utils import testutils
class TestElidedLabel(testutils.TestCaseQt):
-
def setUp(self):
self.label = ElidedLabel()
self.label.show()
diff --git a/src/silx/gui/widgets/test/test_floatedit.py b/src/silx/gui/widgets/test/test_floatedit.py
new file mode 100644
index 0000000..c5edded
--- /dev/null
+++ b/src/silx/gui/widgets/test/test_floatedit.py
@@ -0,0 +1,82 @@
+# /*##########################################################################
+#
+# Copyright (c) 2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Tests for FloatEdit"""
+
+__license__ = "MIT"
+
+import pytest
+from silx.gui import qt
+from silx.gui.widgets.FloatEdit import FloatEdit
+
+
+@pytest.fixture
+def floatEdit(qWidgetFactory):
+ widget = qWidgetFactory(FloatEdit)
+ yield widget
+
+
+@pytest.fixture
+def floatEditHolder(qWidgetFactory, floatEdit):
+ widget = qWidgetFactory(qt.QWidget)
+ layout = qt.QHBoxLayout(widget)
+ layout.addStretch()
+ layout.addWidget(floatEdit)
+ yield widget
+
+
+def test_show(floatEdit):
+ pass
+
+
+def test_value(floatEdit):
+ floatEdit.setValue(1.5)
+ assert floatEdit.value() == 1.5
+
+
+def test_no_widgetresize(floatEditHolder, floatEdit):
+ floatEditHolder.resize(50, 50)
+ floatEdit.setValue(123)
+ a = floatEdit.width()
+ floatEdit.setValue(123456789123456789.123456789123456789)
+ b = floatEdit.width()
+ assert b == a
+
+
+def test_widgetresize(qapp_utils, floatEditHolder, floatEdit):
+ floatEditHolder.resize(50, 50)
+ floatEdit.setWidgetResizable(True)
+ # Initial
+ floatEdit.setValue(123)
+ qapp_utils.qWait()
+ a = floatEdit.width()
+ # Grow
+ floatEdit.setValue(123456789123456789.123456789123456789)
+ qapp_utils.qWait()
+ b = floatEdit.width()
+ # Shrink
+ floatEdit.setValue(123)
+ qapp_utils.qWait()
+ c = floatEdit.width()
+ assert b > a
+ assert a <= c < b
diff --git a/src/silx/gui/widgets/test/test_flowlayout.py b/src/silx/gui/widgets/test/test_flowlayout.py
index 07f6697..c39e2a5 100644
--- a/src/silx/gui/widgets/test/test_flowlayout.py
+++ b/src/silx/gui/widgets/test/test_flowlayout.py
@@ -27,8 +27,6 @@ __authors__ = ["T. Vincent"]
__license__ = "MIT"
__date__ = "02/08/2018"
-import unittest
-
from silx.gui.widgets.FlowLayout import FlowLayout
from silx.gui import qt
from silx.gui.utils.testutils import TestCaseQt
@@ -55,8 +53,8 @@ class TestFlowLayout(TestCaseQt):
layout = FlowLayout()
self.widget.setLayout(layout)
- layout.addWidget(qt.QLabel('first'))
- layout.addWidget(qt.QLabel('second'))
+ layout.addWidget(qt.QLabel("first"))
+ layout.addWidget(qt.QLabel("second"))
self.assertEqual(layout.count(), 2)
layout.setHorizontalSpacing(10)
diff --git a/src/silx/gui/widgets/test/test_framebrowser.py b/src/silx/gui/widgets/test/test_framebrowser.py
index 7fa621b..bb80a58 100644
--- a/src/silx/gui/widgets/test/test_framebrowser.py
+++ b/src/silx/gui/widgets/test/test_framebrowser.py
@@ -26,8 +26,6 @@ __license__ = "MIT"
__date__ = "23/03/2018"
-import unittest
-
from silx.gui.utils.testutils import TestCaseQt
from silx.gui.widgets.FrameBrowser import FrameBrowser
diff --git a/src/silx/gui/widgets/test/test_hierarchicaltableview.py b/src/silx/gui/widgets/test/test_hierarchicaltableview.py
index 8f6a2a0..5ef36a0 100644
--- a/src/silx/gui/widgets/test/test_hierarchicaltableview.py
+++ b/src/silx/gui/widgets/test/test_hierarchicaltableview.py
@@ -25,15 +25,12 @@ __authors__ = ["V. Valls"]
__license__ = "MIT"
__date__ = "07/04/2017"
-import unittest
-
from .. import HierarchicalTableView
from silx.gui.utils.testutils import TestCaseQt
from silx.gui import qt
class TableModel(HierarchicalTableView.HierarchicalTableModel):
-
def __init__(self, parent):
HierarchicalTableView.HierarchicalTableModel.__init__(self, parent)
self.__content = {}
diff --git a/src/silx/gui/widgets/test/test_legendiconwidget.py b/src/silx/gui/widgets/test/test_legendiconwidget.py
index cfebc62..d31de23 100644
--- a/src/silx/gui/widgets/test/test_legendiconwidget.py
+++ b/src/silx/gui/widgets/test/test_legendiconwidget.py
@@ -27,8 +27,6 @@ __authors__ = ["V. Valls"]
__license__ = "MIT"
__date__ = "23/10/2020"
-import unittest
-
from silx.gui import qt
from silx.gui.widgets.LegendIconWidget import LegendIconWidget
from silx.gui.utils.testutils import TestCaseQt
diff --git a/src/silx/gui/widgets/test/test_periodictable.py b/src/silx/gui/widgets/test/test_periodictable.py
index f687e36..a2efed1 100644
--- a/src/silx/gui/widgets/test/test_periodictable.py
+++ b/src/silx/gui/widgets/test/test_periodictable.py
@@ -25,8 +25,6 @@ __authors__ = ["P. Knobel"]
__license__ = "MIT"
__date__ = "05/12/2016"
-import unittest
-
from .. import PeriodicTable
from silx.gui.utils.testutils import TestCaseQt
from silx.gui import qt
@@ -49,9 +47,8 @@ class TestPeriodicTable(TestCaseQt):
def testCustomElements(self):
PTI = PeriodicTable.ColoredPeriodicTableItem
my_items = [
- PTI("Xx", 42, 43, 44, "xaxatorium", 1002.2,
- bgcolor="#FF0000"),
- PTI("Yy", 25, 22, 44, "yoyotrium", 8.8)
+ PTI("Xx", 42, 43, 44, "xaxatorium", 1002.2, bgcolor="#FF0000"),
+ PTI("Yy", 25, 22, 44, "yoyotrium", 8.8),
]
pt = PeriodicTable.PeriodicTable(elements=my_items)
@@ -63,8 +60,7 @@ class TestPeriodicTable(TestCaseQt):
self.assertEqual(selection[0].Z, 42)
self.assertEqual(selection[0].col, 43)
self.assertAlmostEqual(selection[0].mass, 1002.2)
- self.assertEqual(qt.QColor(selection[0].bgcolor),
- qt.QColor(qt.Qt.red))
+ self.assertEqual(qt.QColor(selection[0].bgcolor), qt.QColor(qt.Qt.red))
self.assertTrue(pt.isElementSelected("Xx"))
self.assertFalse(pt.isElementSelected("Yy"))
@@ -78,7 +74,7 @@ class TestPeriodicTable(TestCaseQt):
my_items = [
MyPTI("Xx", 42, 43, 44, "xaxatorium", 1002.2, "spam"),
- MyPTI("Yy", 25, 22, 44, "yoyotrium", 8.8, "eggs")
+ MyPTI("Yy", 25, 22, 44, "yoyotrium", 8.8, "eggs"),
]
pt = PeriodicTable.PeriodicTable(elements=my_items)
@@ -96,6 +92,7 @@ class TestPeriodicTable(TestCaseQt):
class TestPeriodicCombo(TestCaseQt):
"""Basic test for ArrayTableWidget with a numpy array"""
+
def setUp(self):
super(TestPeriodicCombo, self).setUp()
self.pc = PeriodicTable.PeriodicCombo()
@@ -112,8 +109,7 @@ class TestPeriodicCombo(TestCaseQt):
def testSelect(self):
self.pc.setSelection("Sb")
selection = self.pc.getSelection()
- self.assertIsInstance(selection,
- PeriodicTable.PeriodicTableItem)
+ self.assertIsInstance(selection, PeriodicTable.PeriodicTableItem)
self.assertEqual(selection.symbol, "Sb")
self.assertEqual(selection.Z, 51)
self.assertEqual(selection.name, "antimony")
@@ -121,6 +117,7 @@ class TestPeriodicCombo(TestCaseQt):
class TestPeriodicList(TestCaseQt):
"""Basic test for ArrayTableWidget with a numpy array"""
+
def setUp(self):
super(TestPeriodicList, self).setUp()
self.pl = PeriodicTable.PeriodicList()
@@ -138,8 +135,7 @@ class TestPeriodicList(TestCaseQt):
self.pl.setSelectedElements(["Li", "He", "Au"])
sel_elmts = self.pl.getSelection()
- self.assertEqual(len(sel_elmts), 3,
- "Wrong number of elements selected")
+ self.assertEqual(len(sel_elmts), 3, "Wrong number of elements selected")
for e in sel_elmts:
self.assertIsInstance(e, PeriodicTable.PeriodicTableItem)
self.assertIn(e.symbol, ["Li", "He", "Au"])
diff --git a/src/silx/gui/widgets/test/test_printpreview.py b/src/silx/gui/widgets/test/test_printpreview.py
index b703d63..e88853b 100644
--- a/src/silx/gui/widgets/test/test_printpreview.py
+++ b/src/silx/gui/widgets/test/test_printpreview.py
@@ -28,7 +28,6 @@ __license__ = "MIT"
__date__ = "19/07/2017"
-import unittest
from silx.gui.utils.testutils import TestCaseQt
from silx.gui.widgets.PrintPreview import PrintPreviewDialog
from silx.gui import qt
@@ -52,11 +51,17 @@ class TestPrintPreview(TestCaseQt):
def testAddSvg(self):
p = qt.QPrinter()
d = PrintPreviewDialog(printer=p)
- d.addSvgItem(qt.QSvgRenderer(resource_filename("gui/icons/clipboard.svg"), d.page))
+ d.addSvgItem(
+ qt.QSvgRenderer(resource_filename("gui/icons/clipboard.svg"), d.page)
+ )
self.qapp.processEvents()
def testAddPixmap(self):
p = qt.QPrinter()
d = PrintPreviewDialog(printer=p)
- d.addPixmap(qt.QPixmap.fromImage(qt.QImage(resource_filename("gui/icons/clipboard.png"))))
+ d.addPixmap(
+ qt.QPixmap.fromImage(
+ qt.QImage(resource_filename("gui/icons/clipboard.png"))
+ )
+ )
self.qapp.processEvents()
diff --git a/src/silx/gui/widgets/test/test_rangeslider.py b/src/silx/gui/widgets/test/test_rangeslider.py
index 6ed50af..a59315b 100644
--- a/src/silx/gui/widgets/test/test_rangeslider.py
+++ b/src/silx/gui/widgets/test/test_rangeslider.py
@@ -27,8 +27,6 @@ __authors__ = ["T. Vincent"]
__license__ = "MIT"
__date__ = "01/08/2018"
-import unittest
-
from silx.gui import qt, colors
from silx.gui.widgets.RangeSlider import RangeSlider
from silx.gui.utils.testutils import TestCaseQt
@@ -54,26 +52,26 @@ class TestRangeSlider(TestCaseQt, ParametricTestCase):
# Play with range
self.slider.setRange(1, 2)
- self.assertEqual(self.slider.getRange(), (1., 2.))
- self.assertEqual(self.slider.getValues(), (1., 1.))
+ self.assertEqual(self.slider.getRange(), (1.0, 2.0))
+ self.assertEqual(self.slider.getValues(), (1.0, 1.0))
self.slider.setMinimum(-1)
- self.assertEqual(self.slider.getRange(), (-1., 2.))
- self.assertEqual(self.slider.getValues(), (1., 1.))
+ self.assertEqual(self.slider.getRange(), (-1.0, 2.0))
+ self.assertEqual(self.slider.getValues(), (1.0, 1.0))
self.slider.setMaximum(0)
- self.assertEqual(self.slider.getRange(), (-1., 0.))
- self.assertEqual(self.slider.getValues(), (0., 0.))
+ self.assertEqual(self.slider.getRange(), (-1.0, 0.0))
+ self.assertEqual(self.slider.getValues(), (0.0, 0.0))
# Play with values
- self.slider.setFirstValue(-2.)
- self.assertEqual(self.slider.getValues(), (-1., 0.))
+ self.slider.setFirstValue(-2.0)
+ self.assertEqual(self.slider.getValues(), (-1.0, 0.0))
self.slider.setFirstValue(-0.5)
- self.assertEqual(self.slider.getValues(), (-0.5, 0.))
+ self.assertEqual(self.slider.getValues(), (-0.5, 0.0))
- self.slider.setSecondValue(2.)
- self.assertEqual(self.slider.getValues(), (-0.5, 0.))
+ self.slider.setSecondValue(2.0)
+ self.assertEqual(self.slider.getValues(), (-0.5, 0.0))
self.slider.setSecondValue(-0.1)
self.assertEqual(self.slider.getValues(), (-0.5, -0.1))
@@ -87,14 +85,14 @@ class TestRangeSlider(TestCaseQt, ParametricTestCase):
self.assertEqual(self.slider.getFirstPosition(), 3)
self.slider.setPositionCount(3) # Value is adjusted
- self.assertEqual(self.slider.getValues(), (0.5, 1.))
+ self.assertEqual(self.slider.getValues(), (0.5, 1.0))
self.assertEqual(self.slider.getPositions(), (1, 2))
def testGroove(self):
"""Test Groove pixmap"""
profile = list(range(100))
- for cmap in ('jet', colors.Colormap('viridis')):
+ for cmap in ("jet", colors.Colormap("viridis")):
with self.subTest(str(cmap)):
self.slider.setGroovePixmapFromProfile(profile, cmap)
pixmap = self.slider.getGroovePixmap()
diff --git a/src/silx/gui/widgets/test/test_stackedprogressbar.py b/src/silx/gui/widgets/test/test_stackedprogressbar.py
new file mode 100644
index 0000000..17267b9
--- /dev/null
+++ b/src/silx/gui/widgets/test/test_stackedprogressbar.py
@@ -0,0 +1,60 @@
+# /*##########################################################################
+#
+# Copyright (c) 2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Tests for StackedProgressBar"""
+
+__license__ = "MIT"
+
+import pytest
+from silx.gui import qt
+from silx.gui.widgets.StackedProgressBar import StackedProgressBar
+
+
+@pytest.fixture
+def stackedProgressBar(qWidgetFactory):
+ yield qWidgetFactory(StackedProgressBar)
+
+
+def test_show(qapp_utils, stackedProgressBar: StackedProgressBar):
+ pass
+
+
+def test_value(qapp_utils, stackedProgressBar: StackedProgressBar):
+ stackedProgressBar.setRange(0, 100)
+ stackedProgressBar.setProgressItem("foo", value=0)
+ stackedProgressBar.setProgressItem("foo", value=50)
+ stackedProgressBar.setProgressItem("foo", value=100)
+
+
+def test_animation(qapp_utils, stackedProgressBar: StackedProgressBar):
+ stackedProgressBar.setRange(0, 100)
+ stackedProgressBar.setProgressItem("foo", value=0, striped=True, animated=True)
+ stackedProgressBar.setProgressItem("foo", value=50)
+ stackedProgressBar.setProgressItem("foo", value=100)
+
+
+def test_stack(qapp_utils, stackedProgressBar: StackedProgressBar):
+ stackedProgressBar.setRange(0, 100)
+ stackedProgressBar.setProgressItem("foo1", value=10, color=qt.QColor("#FF0000"))
+ stackedProgressBar.setProgressItem("foo2", value=50, color=qt.QColor("#00FF00"))
+ stackedProgressBar.setProgressItem("foo3", value=20, color=qt.QColor("#0000FF"))
diff --git a/src/silx/gui/widgets/test/test_tablewidget.py b/src/silx/gui/widgets/test/test_tablewidget.py
index 9b1e53f..d631e45 100644
--- a/src/silx/gui/widgets/test/test_tablewidget.py
+++ b/src/silx/gui/widgets/test/test_tablewidget.py
@@ -28,7 +28,6 @@ __license__ = "MIT"
__date__ = "05/12/2016"
-import unittest
from silx.gui.utils.testutils import TestCaseQt
from silx.gui.widgets.TableWidget import TableWidget
diff --git a/src/silx/gui/widgets/test/test_threadpoolpushbutton.py b/src/silx/gui/widgets/test/test_threadpoolpushbutton.py
index a3eca33..cc0b0c5 100644
--- a/src/silx/gui/widgets/test/test_threadpoolpushbutton.py
+++ b/src/silx/gui/widgets/test/test_threadpoolpushbutton.py
@@ -28,7 +28,6 @@ __license__ = "MIT"
__date__ = "17/01/2018"
-import unittest
import time
from silx.gui import qt
from silx.gui.utils.testutils import TestCaseQt
@@ -38,7 +37,6 @@ from silx.utils.testutils import LoggingValidator
class TestThreadPoolPushButton(TestCaseQt):
-
def setUp(self):
super(TestThreadPoolPushButton, self).setUp()
self._result = []
@@ -113,7 +111,7 @@ class TestThreadPoolPushButton(TestCaseQt):
button.succeeded.connect(listener.partial(test="Unexpected success"))
button.failed.connect(listener.partial(test="exception"))
button.finished.connect(listener.partial(test="f"))
- with LoggingValidator('silx.gui.widgets.ThreadPoolPushButton', error=1):
+ with LoggingValidator("silx.gui.widgets.ThreadPoolPushButton", error=1):
button.executeCallable()
self.qapp.processEvents()
time.sleep(0.1)
diff --git a/src/silx/gui/widgets/test/test_urlselectiontable.py b/src/silx/gui/widgets/test/test_urlselectiontable.py
new file mode 100644
index 0000000..dd75f08
--- /dev/null
+++ b/src/silx/gui/widgets/test/test_urlselectiontable.py
@@ -0,0 +1,72 @@
+# /*##########################################################################
+#
+# Copyright (c) 2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Tests for UrlSelectionTable"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "09/05/2023"
+
+import pytest
+import weakref
+from silx.gui.widgets.UrlSelectionTable import UrlSelectionTable
+from silx.gui import qt
+from silx.io.url import DataUrl
+
+
+@pytest.fixture
+def urlSelectionTable(qapp, qapp_utils):
+ widget = UrlSelectionTable()
+ widget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ yield widget
+ widget.close()
+ ref = weakref.ref(widget)
+ widget = None
+ qapp_utils.qWaitForDestroy(ref)
+
+
+def test_show(qapp_utils, urlSelectionTable):
+ qapp_utils.qWaitForWindowExposed(urlSelectionTable)
+
+
+def test_add_urls(urlSelectionTable):
+ urlSelectionTable.addUrl(DataUrl("aaaa"))
+ urlSelectionTable.addUrl(DataUrl("bbbb"))
+ urlSelectionTable.addUrl(DataUrl("cccc"))
+ assert urlSelectionTable.rowCount() == 3
+
+
+def test_clear(urlSelectionTable):
+ urlSelectionTable.addUrl(DataUrl("aaaa"))
+ assert urlSelectionTable.rowCount() == 1
+ urlSelectionTable.clear()
+ assert urlSelectionTable.rowCount() == 0
+
+
+def test_set_remove_error(urlSelectionTable):
+ urlSelectionTable.addUrl(DataUrl("aaaa"))
+ item = urlSelectionTable._getItemFromUrlPath("aaaa")
+ urlSelectionTable.setError("aaaa", "Oh... no...")
+ assert not item.icon().isNull()
+ urlSelectionTable.setError("aaaa", "")
+ assert item.icon().isNull()
diff --git a/src/silx/gui/widgets/test/test_waitingoverlay.py b/src/silx/gui/widgets/test/test_waitingoverlay.py
new file mode 100644
index 0000000..713c4cb
--- /dev/null
+++ b/src/silx/gui/widgets/test/test_waitingoverlay.py
@@ -0,0 +1,31 @@
+import pytest
+from silx.gui import qt
+from silx.gui.widgets.WaitingOverlay import WaitingOverlay
+from silx.gui.plot import Plot2D
+from silx.gui.plot.PlotWidget import PlotWidget
+
+
+@pytest.mark.parametrize("widget_parent", (Plot2D, qt.QFrame))
+def test_show(qapp, qapp_utils, widget_parent):
+ """Simple test of the WaitingOverlay component"""
+ widget = widget_parent()
+ widget.setAttribute(qt.Qt.WA_DeleteOnClose)
+
+ waitingOverlay = WaitingOverlay(widget)
+ waitingOverlay.setAttribute(qt.Qt.WA_DeleteOnClose)
+
+ widget.show()
+ qapp_utils.qWaitForWindowExposed(widget)
+ assert waitingOverlay._waitingButton.isWaiting()
+
+ waitingOverlay.setText("test")
+ qapp.processEvents()
+ assert waitingOverlay.text() == "test"
+ qapp_utils.qWait(1000)
+
+ waitingOverlay.hide()
+ qapp.processEvents()
+
+ widget.close()
+ waitingOverlay.close()
+ qapp.processEvents()
diff --git a/src/silx/image/_boundingbox.py b/src/silx/image/_boundingbox.py
index c016471..f114062 100644
--- a/src/silx/image/_boundingbox.py
+++ b/src/silx/image/_boundingbox.py
@@ -39,6 +39,7 @@ class _BoundingBox:
:param tuple bottom_left: (y, x) bottom left point
:param tuple top_right: (y, x) top right point
"""
+
def __init__(self, bottom_left, top_right):
self.bottom_left = bottom_left
self.top_right = top_right
@@ -59,9 +60,8 @@ class _BoundingBox:
if isinstance(item, _BoundingBox):
return self.contains(item.bottom_left) and self.contains(item.top_right)
else:
- return (
- (self.min_x <= item[1] <= self.max_x) and
- (self.min_y <= item[0] <= self.max_y)
+ return (self.min_x <= item[1] <= self.max_x) and (
+ self.min_y <= item[0] <= self.max_y
)
def collide(self, bb):
@@ -74,9 +74,8 @@ class _BoundingBox:
:rtype: bool
"""
assert isinstance(bb, _BoundingBox)
- return (
- (self.min_x < bb.max_x and self.max_x > bb.min_x) and
- (self.min_y < bb.max_y and self.max_y > bb.min_y)
+ return (self.min_x < bb.max_x and self.max_x > bb.min_x) and (
+ self.min_y < bb.max_y and self.max_y > bb.min_y
)
@staticmethod
diff --git a/src/silx/image/backprojection.py b/src/silx/image/backprojection.py
index b208d3e..350be34 100644
--- a/src/silx/image/backprojection.py
+++ b/src/silx/image/backprojection.py
@@ -21,4 +21,4 @@
#
# ############################################################################*/
-from silx.opencl.backprojection import *
+from silx.opencl.backprojection import * # noqa
diff --git a/src/silx/image/bilinear.pyx b/src/silx/image/bilinear.pyx
index 31ba354..cfa8675 100644
--- a/src/silx/image/bilinear.pyx
+++ b/src/silx/image/bilinear.pyx
@@ -9,7 +9,7 @@
# Project: silx (originally pyFAI)
# https://github.com/silx-kit/silx
#
-# Copyright (C) 2012-2020 European Synchrotron Radiation Facility, Grenoble, France
+# Copyright (C) 2012-2023 European Synchrotron Radiation Facility, Grenoble, France
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -32,7 +32,7 @@
"""Bilinear interpolator, peak finder, line-profile for images"""
__authors__ = ["J. Kieffer"]
__license__ = "MIT"
-__date__ = "26/11/2020"
+__date__ = "21/12/2023"
# C-level imports
from libc.stdint cimport uint8_t
@@ -67,9 +67,9 @@ cdef class BilinearImage:
# C-level declarations
cpdef Py_ssize_t coarse_local_maxi(self, Py_ssize_t)
- cdef Py_ssize_t c_local_maxi(self, Py_ssize_t) nogil
- cdef data_t c_funct(self, data_t, data_t) nogil
- cdef void _init_min_max(self) nogil
+ cdef Py_ssize_t c_local_maxi(self, Py_ssize_t) noexcept nogil
+ cdef data_t c_funct(self, data_t, data_t) noexcept nogil
+ cdef void _init_min_max(self) noexcept nogil
def __cinit__(self, data not None, mask=None):
"""Constructor
@@ -102,7 +102,7 @@ cdef class BilinearImage:
"""
return self.c_funct(coord[1], coord[0])
- cdef void _init_min_max(self) nogil:
+ cdef void _init_min_max(self) noexcept nogil:
"Calculate the min & max"
cdef:
Py_ssize_t i, j
@@ -118,7 +118,7 @@ cdef class BilinearImage:
self.maxi = maxi
self.mini = mini
- cdef data_t c_funct(self, data_t x, data_t y) nogil:
+ cdef data_t c_funct(self, data_t x, data_t y) noexcept nogil:
"""Function f(x, y) where f is a continuous function
made from the image.
@@ -305,7 +305,7 @@ cdef class BilinearImage:
"""
return self.c_local_maxi(x)
- cdef Py_ssize_t c_local_maxi(self, Py_ssize_t idx) nogil:
+ cdef Py_ssize_t c_local_maxi(self, Py_ssize_t idx) noexcept nogil:
"""Return the nearest local maximum without sub-pixel refinement
:param idx: start index (=row*width+column)
diff --git a/src/silx/image/marchingsquares/__init__.py b/src/silx/image/marchingsquares/__init__.py
index 1c6f15e..a310e70 100644
--- a/src/silx/image/marchingsquares/__init__.py
+++ b/src/silx/image/marchingsquares/__init__.py
@@ -50,9 +50,12 @@ def _factory(engine, image, mask):
return MarchingSquaresMergeImpl(image, mask)
elif engine == "skimage":
from _skimage import MarchingSquaresSciKitImage
+
return MarchingSquaresSciKitImage(image, mask)
else:
- raise ValueError("Engine '%s' is not supported ('merge' or 'skimage' expected).")
+ raise ValueError(
+ "Engine '%s' is not supported ('merge' or 'skimage' expected)."
+ )
def find_pixels(image, level, mask=None):
@@ -79,9 +82,9 @@ def find_pixels(image, level, mask=None):
:returns: An array of coordinates in y/x
:rtype: numpy.ndarray
"""
- assert(image is not None)
+ assert image is not None
if mask is not None:
- assert(image.shape == mask.shape)
+ assert image.shape == mask.shape
engine = "merge"
impl = _factory(engine, image, mask)
return impl.find_pixels(level)
@@ -108,9 +111,9 @@ def find_contours(image, level, mask=None):
:returns: A list of array containing y-x coordinates of points
:rtype: List[numpy.ndarray]
"""
- assert(image is not None)
+ assert image is not None
if mask is not None:
- assert(image.shape == mask.shape)
+ assert image.shape == mask.shape
engine = "merge"
impl = _factory(engine, image, mask)
return impl.find_contours(level)
diff --git a/src/silx/image/marchingsquares/_mergeimpl.pyx b/src/silx/image/marchingsquares/_mergeimpl.pyx
index ce4786f..84e53bb 100644
--- a/src/silx/image/marchingsquares/_mergeimpl.pyx
+++ b/src/silx/image/marchingsquares/_mergeimpl.pyx
@@ -1,5 +1,11 @@
+#cython: embedsignature=True, language_level=3
+## This is for optimisation
+#cython: boundscheck=False, wraparound=False, cdivision=True, initializedcheck=False,
+## This is for developping:
+##cython: profile=True, warn.undeclared=True, warn.unused=True, warn.unused_result=False, warn.unused_arg=True
+
# /*##########################################################################
-# Copyright (C) 2018-2020 European Synchrotron Radiation Facility
+# Copyright (C) 2018-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -26,7 +32,7 @@ Marching squares implementation based on a merge of segements and polygons.
__authors__ = ["Almar Klein", "Jerome Kieffer", "Valentin Valls"]
__license__ = "MIT"
-__date__ = "23/04/2018"
+__date__ = "21/12/2023"
import numpy
cimport numpy as cnumpy
@@ -78,7 +84,7 @@ cdef cppclass PolygonDescription:
point_index_t end
clist[point_t] points
- PolygonDescription() nogil:
+ PolygonDescription() noexcept nogil:
pass
"""Description of a tile context.
@@ -101,7 +107,7 @@ cdef cppclass TileContext:
clist[coord_t] final_pixels
cset[coord_t] pixels
- TileContext() nogil:
+ TileContext() noexcept nogil:
pass
@@ -133,7 +139,7 @@ cdef class _MarchingSquaresAlgorithm(object):
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
- cdef void marching_squares(self, cnumpy.float64_t level) nogil:
+ cdef void marching_squares(self, cnumpy.float64_t level) noexcept nogil:
"""
Main method to execute the marching squares.
@@ -188,7 +194,7 @@ cdef class _MarchingSquaresAlgorithm(object):
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
- cdef void reduction_2d(self, int dim_x, int dim_y, TileContext **contexts) nogil:
+ cdef void reduction_2d(self, int dim_x, int dim_y, TileContext **contexts) noexcept nogil:
"""
Reduce the problem merging first neighbours together in a recursive
process. Optimized with OpenMP.
@@ -237,7 +243,7 @@ cdef class _MarchingSquaresAlgorithm(object):
cdef inline void merge_array_contexts(self,
TileContext **contexts,
int index1,
- int index2) nogil:
+ int index2) noexcept nogil:
"""
Merge contexts from `index2` to `index1` and delete the one from index2.
If the one from index1 was NULL, the one from index2 is moved to index1
@@ -265,7 +271,7 @@ cdef class _MarchingSquaresAlgorithm(object):
@cython.cdivision(True)
cdef void sequencial_reduction(self,
int nb_contexts,
- TileContext **contexts) nogil:
+ TileContext **contexts) noexcept nogil:
"""
Reduce the problem sequencially without taking care of the topology
@@ -286,7 +292,7 @@ cdef class _MarchingSquaresAlgorithm(object):
@cython.cdivision(True)
cdef void marching_squares_mp(self,
TileContext *context,
- cnumpy.float64_t level) nogil:
+ cnumpy.float64_t level) noexcept nogil:
"""
Main entry of the marching squares algorithm for each threads.
@@ -362,7 +368,7 @@ cdef class _MarchingSquaresAlgorithm(object):
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
- cdef void after_marching_squares(self, TileContext *context) nogil:
+ cdef void after_marching_squares(self, TileContext *context) noexcept nogil:
"""
Called by each threads after execution of the marching squares
algorithm. Called before merging together the contextes.
@@ -379,7 +385,7 @@ cdef class _MarchingSquaresAlgorithm(object):
int x,
int y,
int pattern,
- cnumpy.float64_t level) nogil:
+ cnumpy.float64_t level) noexcept nogil:
"""
Called by the marching squares algorithm each time a pattern is found.
@@ -396,7 +402,7 @@ cdef class _MarchingSquaresAlgorithm(object):
@cython.cdivision(True)
cdef void merge_context(self,
TileContext *context,
- TileContext *other) nogil:
+ TileContext *other) noexcept nogil:
"""
Merge into a context another context.
@@ -413,7 +419,7 @@ cdef class _MarchingSquaresAlgorithm(object):
cnumpy.float64_t level,
int* dim_x,
int* dim_y,
- int* nb_valid_contexts) nogil:
+ int* nb_valid_contexts) noexcept nogil:
"""
Create and initialize a 2d-array of contexts.
@@ -473,7 +479,7 @@ cdef class _MarchingSquaresAlgorithm(object):
int x,
int y,
int dim_x,
- int dim_y) nogil:
+ int dim_y) noexcept nogil:
"""
Allocate and initialize a context.
@@ -507,7 +513,7 @@ cdef class _MarchingSquaresAlgorithm(object):
cnumpy.uint32_t y,
cnumpy.uint8_t edge,
cnumpy.float64_t level,
- point_t *result_point) nogil:
+ point_t *result_point) noexcept nogil:
"""
Compute the location of a point of the polygons according to the level
and the neighbours.
@@ -551,7 +557,7 @@ cdef class _MarchingSquaresAlgorithm(object):
cnumpy.uint32_t y,
cnumpy.uint8_t edge,
cnumpy.float64_t level,
- coord_t *result_coord) nogil:
+ coord_t *result_coord) noexcept nogil:
"""
Compute the location of pixel which contains the point of the polygons
according to the level and the neighbours.
@@ -594,7 +600,7 @@ cdef class _MarchingSquaresAlgorithm(object):
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
- cdef point_index_t create_point_index(self, int yx, cnumpy.uint8_t edge) nogil:
+ cdef point_index_t create_point_index(self, int yx, cnumpy.uint8_t edge) noexcept nogil:
"""
Create a unique identifier for a point of a polygon based on the
pattern location and the edge.
@@ -633,7 +639,7 @@ cdef class _MarchingSquaresContours(_MarchingSquaresAlgorithm):
int x,
int y,
int pattern,
- cnumpy.float64_t level) nogil:
+ cnumpy.float64_t level) noexcept nogil:
cdef:
int segment
for segment in range(CELL_TO_EDGE[pattern][0]):
@@ -648,7 +654,7 @@ cdef class _MarchingSquaresContours(_MarchingSquaresAlgorithm):
int x, int y,
cnumpy.uint8_t begin_edge,
cnumpy.uint8_t end_edge,
- cnumpy.float64_t level) nogil:
+ cnumpy.float64_t level) noexcept nogil:
cdef:
int i, yx
point_t point
@@ -757,7 +763,7 @@ cdef class _MarchingSquaresContours(_MarchingSquaresAlgorithm):
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
- cdef void merge_context(self, TileContext *context, TileContext *other) nogil:
+ cdef void merge_context(self, TileContext *context, TileContext *other) noexcept nogil:
cdef:
map[point_index_t, PolygonDescription*].iterator it_begin
map[point_index_t, PolygonDescription*].iterator it_end
@@ -928,7 +934,7 @@ cdef class _MarchingSquaresPixels(_MarchingSquaresAlgorithm):
int x,
int y,
int pattern,
- cnumpy.float64_t level) nogil:
+ cnumpy.float64_t level) noexcept nogil:
cdef:
int segment
for segment in range(CELL_TO_EDGE[pattern][0]):
@@ -943,7 +949,7 @@ cdef class _MarchingSquaresPixels(_MarchingSquaresAlgorithm):
int x, int y,
cnumpy.uint8_t begin_edge,
cnumpy.uint8_t end_edge,
- cnumpy.float64_t level) nogil:
+ cnumpy.float64_t level) noexcept nogil:
cdef:
coord_t coord
self.compute_ipoint(x, y, begin_edge, level, &coord)
@@ -954,7 +960,7 @@ cdef class _MarchingSquaresPixels(_MarchingSquaresAlgorithm):
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
- cdef void after_marching_squares(self, TileContext *context) nogil:
+ cdef void after_marching_squares(self, TileContext *context) noexcept nogil:
cdef:
coord_t coord
cset[coord_t].iterator it_coord
@@ -976,7 +982,7 @@ cdef class _MarchingSquaresPixels(_MarchingSquaresAlgorithm):
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
- cdef void merge_context(self, TileContext *context, TileContext *other) nogil:
+ cdef void merge_context(self, TileContext *context, TileContext *other) noexcept nogil:
cdef:
cset[coord_t].iterator it_coord
@@ -1161,7 +1167,7 @@ cdef class MarchingSquaresMergeImpl(object):
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
- cdef void _compute_minmax_on_block(self, int block_x, int block_y, int block_index) nogil:
+ cdef void _compute_minmax_on_block(self, int block_x, int block_y, int block_index) noexcept nogil:
"""
Initialize the minmax cache.
@@ -1228,7 +1234,7 @@ cdef class MarchingSquaresMergeImpl(object):
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
- cdef void _create_minmax_cache(self) nogil:
+ cdef void _create_minmax_cache(self) noexcept nogil:
"""
Create and initialize minmax cache.
"""
diff --git a/src/silx/image/marchingsquares/_skimage.py b/src/silx/image/marchingsquares/_skimage.py
index 7fa97d5..2e136f7 100644
--- a/src/silx/image/marchingsquares/_skimage.py
+++ b/src/silx/image/marchingsquares/_skimage.py
@@ -80,7 +80,7 @@ class MarchingSquaresSciKitImage(object):
if len(polyline) == 0:
continue
integer_polyline = numpy.floor(polyline + delta)
- result[size:size + len(polyline)] = integer_polyline
+ result[size : size + len(polyline)] = integer_polyline
size += len(polyline)
if len(result) == 0:
diff --git a/src/silx/image/marchingsquares/test/test_funcapi.py b/src/silx/image/marchingsquares/test/test_funcapi.py
index e9d2d7d..0e5471c 100644
--- a/src/silx/image/marchingsquares/test/test_funcapi.py
+++ b/src/silx/image/marchingsquares/test/test_funcapi.py
@@ -32,7 +32,6 @@ import silx.image.marchingsquares
class MockMarchingSquares(object):
-
last = None
def __init__(self, image, mask=None):
diff --git a/src/silx/image/marchingsquares/test/test_mergeimpl.py b/src/silx/image/marchingsquares/test/test_mergeimpl.py
index bfa1263..db36b54 100644
--- a/src/silx/image/marchingsquares/test/test_mergeimpl.py
+++ b/src/silx/image/marchingsquares/test/test_mergeimpl.py
@@ -32,7 +32,6 @@ from .._mergeimpl import MarchingSquaresMergeImpl
class TestMergeImplApi(unittest.TestCase):
-
def test_image_not_an_array(self):
bad_image = 1
self.assertRaises(ValueError, MarchingSquaresMergeImpl, bad_image)
@@ -114,7 +113,6 @@ class TestMergeImplApi(unittest.TestCase):
class TestMergeImplContours(unittest.TestCase):
-
def test_merge_segments(self):
image = numpy.zeros((4, 4))
image[(2, 3), :] = 1
@@ -234,8 +232,8 @@ class TestMergeImplContours(unittest.TestCase):
def test_image(self):
# example from skimage
- x, y = numpy.ogrid[-numpy.pi:numpy.pi:100j, -numpy.pi:numpy.pi:100j]
- image = numpy.sin(numpy.exp((numpy.sin(x)**3 + numpy.cos(y)**2)))
+ x, y = numpy.ogrid[-numpy.pi : numpy.pi : 100j, -numpy.pi : numpy.pi : 100j]
+ image = numpy.sin(numpy.exp((numpy.sin(x) ** 3 + numpy.cos(y) ** 2)))
mask = None
ms = MarchingSquaresMergeImpl(image, mask)
polygons = ms.find_contours(0.5)
@@ -244,8 +242,8 @@ class TestMergeImplContours(unittest.TestCase):
def test_image_tiled(self):
# example from skimage
- x, y = numpy.ogrid[-numpy.pi:numpy.pi:100j, -numpy.pi:numpy.pi:100j]
- image = numpy.sin(numpy.exp((numpy.sin(x)**3 + numpy.cos(y)**2)))
+ x, y = numpy.ogrid[-numpy.pi : numpy.pi : 100j, -numpy.pi : numpy.pi : 100j]
+ image = numpy.sin(numpy.exp((numpy.sin(x) ** 3 + numpy.cos(y) ** 2)))
mask = None
ms = MarchingSquaresMergeImpl(image, mask, group_size=50)
polygons = ms.find_contours(0.5)
@@ -254,8 +252,8 @@ class TestMergeImplContours(unittest.TestCase):
def test_image_tiled_minmax(self):
# example from skimage
- x, y = numpy.ogrid[-numpy.pi:numpy.pi:100j, -numpy.pi:numpy.pi:100j]
- image = numpy.sin(numpy.exp((numpy.sin(x)**3 + numpy.cos(y)**2)))
+ x, y = numpy.ogrid[-numpy.pi : numpy.pi : 100j, -numpy.pi : numpy.pi : 100j]
+ image = numpy.sin(numpy.exp((numpy.sin(x) ** 3 + numpy.cos(y) ** 2)))
mask = None
ms = MarchingSquaresMergeImpl(image, mask, group_size=50, use_minmax_cache=True)
polygons = ms.find_contours(0.5)
diff --git a/src/silx/image/medianfilter.py b/src/silx/image/medianfilter.py
index 1938357..005058c 100644
--- a/src/silx/image/medianfilter.py
+++ b/src/silx/image/medianfilter.py
@@ -35,6 +35,7 @@ import logging
from silx.math import medianfilter as medianfilter_cpp
from silx.opencl import ocl as _ocl
+
if _ocl is not None:
from silx.opencl import medfilt as medfilt_opencl
else: # No OpenCL device or pyopencl not installed
@@ -44,15 +45,15 @@ else: # No OpenCL device or pyopencl not installed
_logger = logging.getLogger(__name__)
-MEDFILT_ENGINES = ['cpp', 'opencl']
+MEDFILT_ENGINES = ["cpp", "opencl"]
-def medfilt2d(image, kernel_size=3, engine='cpp'):
+def medfilt2d(image, kernel_size=3, engine="cpp"):
"""Apply a median filter on an image.
This median filter is using a 'nearest' padding for values
past the array edges. If you want more padding options or
- functionalities for the median filter (conditional filter
+ functionalities for the median filter (conditional filter
for example) please have a look at
:mod:`silx.math.medianfilter`.
@@ -73,41 +74,43 @@ def medfilt2d(image, kernel_size=3, engine='cpp'):
"""
if engine not in MEDFILT_ENGINES:
- err = 'silx doesn\'t have an implementation for the requested engine: '
- err += '%s' % engine
+ err = "silx doesn't have an implementation for the requested engine: "
+ err += "%s" % engine
raise ValueError(err)
if len(image.shape) != 2:
- raise ValueError('medfilt2d deals with arrays of dimension 2 only')
+ raise ValueError("medfilt2d deals with arrays of dimension 2 only")
- if engine == 'cpp':
- return medianfilter_cpp.medfilt(data=image,
- kernel_size=kernel_size,
- conditional=False)
- elif engine == 'opencl':
+ if engine == "cpp":
+ return medianfilter_cpp.medfilt(
+ data=image, kernel_size=kernel_size, conditional=False
+ )
+ elif engine == "opencl":
if medfilt_opencl is None:
- wrn = 'opencl median filter not available. '
- wrn += 'Launching cpp implementation.'
+ wrn = "opencl median filter not available. "
+ wrn += "Launching cpp implementation."
_logger.warning(wrn)
# instead call the cpp implementation
- return medianfilter_cpp.medfilt(data=image,
- kernel_size=kernel_size,
- conditional=False)
+ return medianfilter_cpp.medfilt(
+ data=image, kernel_size=kernel_size, conditional=False
+ )
else:
try:
- medianfilter = medfilt_opencl.MedianFilter2D(image.shape,
- devicetype="gpu")
+ medianfilter = medfilt_opencl.MedianFilter2D(
+ image.shape, devicetype="gpu"
+ )
res = medianfilter.medfilt2d(image, kernel_size)
- except(RuntimeError, MemoryError, ImportError):
- wrn = 'Exception occured in opencl median filter. '
- wrn += 'To get more information see debug log.'
- wrn += 'Launching cpp implementation.'
+ except (RuntimeError, MemoryError, ImportError):
+ wrn = "Exception occured in opencl median filter. "
+ wrn += "To get more information see debug log."
+ wrn += "Launching cpp implementation."
_logger.warning(wrn)
- _logger.debug("median filter - openCL implementation issue.",
- exc_info=True)
+ _logger.debug(
+ "median filter - openCL implementation issue.", exc_info=True
+ )
# instead call the cpp implementation
- res = medianfilter_cpp.medfilt(data=image,
- kernel_size=kernel_size,
- conditional=False)
+ res = medianfilter_cpp.medfilt(
+ data=image, kernel_size=kernel_size, conditional=False
+ )
return res
diff --git a/src/silx/image/phantomgenerator.py b/src/silx/image/phantomgenerator.py
index 1893368..118cb84 100644
--- a/src/silx/image/phantomgenerator.py
+++ b/src/silx/image/phantomgenerator.py
@@ -58,7 +58,7 @@ class PhantomGenerator(object):
_Ellipsoid(0.046, 0.046, 0.02, 0.0, -0.10, -0.25, 0.0, 0.01),
_Ellipsoid(0.046, 0.023, 0.02, -0.08, -0.605, -0.25, 0.0, 0.01),
_Ellipsoid(0.023, 0.023, 0.10, 0.0, -0.605, -0.25, 0.0, 0.01),
- _Ellipsoid(0.023, 0.046, 0.10, 0.06, -0.605, -0.25, 0.0, 0.01)
+ _Ellipsoid(0.023, 0.046, 0.10, 0.06, -0.605, -0.25, 0.0, 0.01),
]
@staticmethod
@@ -71,13 +71,15 @@ class PhantomGenerator(object):
produce every ellipsoid
:return numpy.ndarray: shepp logan phantom
"""
- assert(ellipsoidID is None or (ellipsoidID >= 0 and ellipsoidID < len(PhantomGenerator.SHEPP_LOGAN)))
+ assert ellipsoidID is None or (
+ ellipsoidID >= 0 and ellipsoidID < len(PhantomGenerator.SHEPP_LOGAN)
+ )
if ellipsoidID is None:
- area = PhantomGenerator._get2DPhantom(n,
- PhantomGenerator.SHEPP_LOGAN)
+ area = PhantomGenerator._get2DPhantom(n, PhantomGenerator.SHEPP_LOGAN)
else:
- area = PhantomGenerator._get2DPhantom(n,
- [PhantomGenerator.SHEPP_LOGAN[ellipsoidID]])
+ area = PhantomGenerator._get2DPhantom(
+ n, [PhantomGenerator.SHEPP_LOGAN[ellipsoidID]]
+ )
indices = numpy.abs(area) > 0
area[indices] = numpy.multiply(area[indices] + 0.1, 5)
@@ -86,11 +88,11 @@ class PhantomGenerator(object):
@staticmethod
def _get2DPhantom(n, phantomSpec):
area = numpy.ndarray(shape=(n, n))
- area.fill(0.)
+ area.fill(0.0)
count = 0
for ell in phantomSpec:
- count = count+1
+ count = count + 1
for x in range(n):
sumSquareXandY = PhantomGenerator._getSquareXandYsum(n, x, ell)
indices = sumSquareXandY <= 1
@@ -99,18 +101,18 @@ class PhantomGenerator(object):
@staticmethod
def _getSquareXandYsum(n, x, ell):
- supportX1 = numpy.ndarray(shape=(n, ))
- supportX2 = numpy.ndarray(shape=(n, ))
- support_consts = numpy.ndarray(shape=(n, ))
+ supportX1 = numpy.ndarray(shape=(n,))
+ supportX2 = numpy.ndarray(shape=(n,))
+ support_consts = numpy.ndarray(shape=(n,))
- xScaled = float(2*x-n)/float(n)
+ xScaled = float(2 * x - n) / float(n)
xCos = xScaled * ell.cosAlpha
xSin = -xScaled * ell.sinAlpha
supportX1.fill(xCos)
supportX2.fill(xSin)
supportY1 = numpy.arange(n)
- support_consts.fill(2.)
+ support_consts.fill(2.0)
supportY1 = numpy.multiply(support_consts, supportY1)
support_consts.fill(n)
supportY1 = numpy.subtract(supportY1, support_consts)
@@ -119,11 +121,9 @@ class PhantomGenerator(object):
supportY2 = numpy.array(supportY1)
support_consts.fill(ell.sinAlpha)
- supportY1 = numpy.add(supportX1,
- numpy.multiply(supportY1, support_consts))
+ supportY1 = numpy.add(supportX1, numpy.multiply(supportY1, support_consts))
support_consts.fill(ell.cosAlpha)
- supportY2 = numpy.add(supportX2,
- numpy.multiply(supportY2, support_consts))
+ supportY2 = numpy.add(supportX2, numpy.multiply(supportY2, support_consts))
support_consts.fill(ell.x0)
supportY1 = numpy.subtract(supportY1, support_consts)
@@ -131,19 +131,17 @@ class PhantomGenerator(object):
supportY2 = numpy.subtract(supportY2, support_consts)
support_consts.fill(ell.a)
- supportY1 = numpy.power((numpy.divide(supportY1, support_consts)),
- 2)
+ supportY1 = numpy.power((numpy.divide(supportY1, support_consts)), 2)
support_consts.fill(ell.b)
- supportY2 = numpy.power(numpy.divide(supportY2, support_consts),
- 2)
+ supportY2 = numpy.power(numpy.divide(supportY2, support_consts), 2)
return numpy.add(supportY1, supportY2)
@staticmethod
def _getSquareZ(n, ell):
supportZ1 = numpy.arange(n)
- support_consts = numpy.ndarray(shape=(n, ))
- support_consts.fill(2.)
+ support_consts = numpy.ndarray(shape=(n,))
+ support_consts.fill(2.0)
supportZ1 = numpy.multiply(support_consts, supportZ1)
support_consts.fill(n)
supportZ1 = numpy.subtract(supportZ1, support_consts)
@@ -154,6 +152,4 @@ class PhantomGenerator(object):
supportZ1 = numpy.subtract(supportZ1, ell.z0)
support_consts.fill(ell.c)
- return numpy.power(numpy.divide(supportZ1, support_consts),
- 2)
-
+ return numpy.power(numpy.divide(supportZ1, support_consts), 2)
diff --git a/src/silx/image/projection.py b/src/silx/image/projection.py
index 0b58323..251ac1f 100644
--- a/src/silx/image/projection.py
+++ b/src/silx/image/projection.py
@@ -21,4 +21,4 @@
#
# ############################################################################*/
-from silx.opencl.projection import *
+from silx.opencl.projection import * # noqa
diff --git a/src/silx/image/reconstruction.py b/src/silx/image/reconstruction.py
index 2ed95c0..8800962 100644
--- a/src/silx/image/reconstruction.py
+++ b/src/silx/image/reconstruction.py
@@ -21,4 +21,4 @@
#
# ############################################################################*/
-from silx.opencl.reconstruction import *
+from silx.opencl.reconstruction import * # noqa
diff --git a/src/silx/image/sift.py b/src/silx/image/sift.py
index 57599cd..e42e367 100644
--- a/src/silx/image/sift.py
+++ b/src/silx/image/sift.py
@@ -21,4 +21,4 @@
#
# ############################################################################*/
-from silx.opencl.sift import *
+from silx.opencl.sift import * # noqa
diff --git a/src/silx/image/test/test_bb.py b/src/silx/image/test/test_bb.py
index 19f5f39..f174b28 100644
--- a/src/silx/image/test/test_bb.py
+++ b/src/silx/image/test/test_bb.py
@@ -35,6 +35,7 @@ from silx.image._boundingbox import _BoundingBox
class TestBB(unittest.TestCase):
"""Some simple test on the bounding box class"""
+
def test_creation(self):
"""test some constructors"""
pts = numpy.array([(0, 0), (10, 20), (20, 0)])
@@ -60,14 +61,24 @@ class TestBB(unittest.TestCase):
def test_collide(self):
"""test the collide function"""
bb1 = _BoundingBox(bottom_left=(6, 2), top_right=(12, 6))
- self.assertTrue(bb1.collide(_BoundingBox(bottom_left=(6, 2), top_right=(12, 6))))
+ self.assertTrue(
+ bb1.collide(_BoundingBox(bottom_left=(6, 2), top_right=(12, 6)))
+ )
bb1 = _BoundingBox(bottom_left=(6, 2), top_right=(12, 6))
- self.assertFalse(bb1.collide(_BoundingBox(bottom_left=(12, 2), top_right=(12, 2))))
+ self.assertFalse(
+ bb1.collide(_BoundingBox(bottom_left=(12, 2), top_right=(12, 2)))
+ )
def test_isIn_bb(self):
"""test the isIn function with other bounding box"""
bb1 = _BoundingBox(bottom_left=(6, 2), top_right=(12, 6))
- self.assertTrue(bb1.contains(_BoundingBox(bottom_left=(6, 2), top_right=(12, 6))))
+ self.assertTrue(
+ bb1.contains(_BoundingBox(bottom_left=(6, 2), top_right=(12, 6)))
+ )
bb1 = _BoundingBox(bottom_left=(6, 2), top_right=(12, 6))
- self.assertTrue(bb1.contains(_BoundingBox(bottom_left=(12, 2), top_right=(12, 2))))
- self.assertFalse(_BoundingBox(bottom_left=(12, 2), top_right=(12, 2)).contains(bb1))
+ self.assertTrue(
+ bb1.contains(_BoundingBox(bottom_left=(12, 2), top_right=(12, 2)))
+ )
+ self.assertFalse(
+ _BoundingBox(bottom_left=(12, 2), top_right=(12, 2)).contains(bb1)
+ )
diff --git a/src/silx/image/test/test_bilinear.py b/src/silx/image/test/test_bilinear.py
index d0abe64..c9b866f 100644
--- a/src/silx/image/test/test_bilinear.py
+++ b/src/silx/image/test/test_bilinear.py
@@ -28,18 +28,20 @@ __date__ = "25/11/2020"
import unittest
import numpy
import logging
+
logger = logging.getLogger(__name__)
from ..bilinear import BilinearImage
class TestBilinear(unittest.TestCase):
"""basic maximum search test"""
+
N = 1000
def test_max_search_round(self):
"""test maximum search using random points: maximum is at the pixel center"""
- a = numpy.arange(100) - 40.
- b = numpy.arange(100) - 60.
+ a = numpy.arange(100) - 40.0
+ b = numpy.arange(100) - 60.0
ga = numpy.exp(-a * a / 4000)
gb = numpy.exp(-b * b / 6000)
gg = numpy.outer(ga, gb)
@@ -57,7 +59,7 @@ class TestBilinear(unittest.TestCase):
else:
logger.debug("Good guess maximum (%i,%i) -> (%.1f,%.1f)", i, j, k, l)
ok += 1
- logger.debug("Success rate: %.1f", 100. * ok / self.N)
+ logger.debug("Success rate: %.1f", 100.0 * ok / self.N)
self.assertEqual(ok, self.N, "Maximum is always found")
def test_max_search_half(self):
@@ -77,12 +79,12 @@ class TestBilinear(unittest.TestCase):
else:
logger.debug("Good guess maximum (%i,%i) -> (%.1f,%.1f)", i, j, k, l)
ok += 1
- logger.debug("Success rate: %.1f", 100. * ok / self.N)
+ logger.debug("Success rate: %.1f", 100.0 * ok / self.N)
self.assertEqual(ok, self.N, "Maximum is always found")
def test_map(self):
N = 6
- y, x = numpy.ogrid[:N,:N + 10]
+ y, x = numpy.ogrid[:N, : N + 10]
img = x + y
b = BilinearImage(img)
x2d = numpy.zeros_like(y) + x
@@ -90,15 +92,19 @@ class TestBilinear(unittest.TestCase):
res1 = b.map_coordinates((y2d, x2d))
self.assertEqual(abs(res1 - img).max(), 0, "images are the same (corners)")
- x2d = numpy.zeros_like(y) + (x[:,:-1] + 0.5)
- y2d = numpy.zeros_like(x[:,:-1]) + y
+ x2d = numpy.zeros_like(y) + (x[:, :-1] + 0.5)
+ y2d = numpy.zeros_like(x[:, :-1]) + y
res1 = b.map_coordinates((y2d, x2d))
- self.assertEqual(abs(res1 - img[:,:-1] - 0.5).max(), 0, "images are the same (middle)")
+ self.assertEqual(
+ abs(res1 - img[:, :-1] - 0.5).max(), 0, "images are the same (middle)"
+ )
- x2d = numpy.zeros_like(y[:-1,:]) + (x[:,:-1] + 0.5)
- y2d = numpy.zeros_like(x[:,:-1]) + (y[:-1,:] + 0.5)
+ x2d = numpy.zeros_like(y[:-1, :]) + (x[:, :-1] + 0.5)
+ y2d = numpy.zeros_like(x[:, :-1]) + (y[:-1, :] + 0.5)
res1 = b.map_coordinates((y2d, x2d))
- self.assertEqual(abs(res1 - img[:-1, 1:]).max(), 0, "images are the same (center)")
+ self.assertEqual(
+ abs(res1 - img[:-1, 1:]).max(), 0, "images are the same (center)"
+ )
def test_mask_grad(self):
N = 100
@@ -114,22 +120,30 @@ class TestBilinear(unittest.TestCase):
self.assertEqual(b.maxi, N * N - 1, "maxi is N²-1")
self.assertEqual(b.mini, 0, "mini is 0")
- y, x = numpy.ogrid[:N,:N]
+ y, x = numpy.ogrid[:N, :N]
x2d = numpy.zeros_like(y) + x
y2d = numpy.zeros_like(x) + y
res1 = b.map_coordinates((y2d, x2d))
- self.assertEqual(numpy.nanmax(abs(res1 - img)), 0, "images are the same (corners), or Nan ")
+ self.assertEqual(
+ numpy.nanmax(abs(res1 - img)), 0, "images are the same (corners), or Nan "
+ )
- x2d = numpy.zeros_like(y) + (x[:,:-1] + 0.5)
- y2d = numpy.zeros_like(x[:,:-1]) + y
+ x2d = numpy.zeros_like(y) + (x[:, :-1] + 0.5)
+ y2d = numpy.zeros_like(x[:, :-1]) + y
res1 = b.map_coordinates((y2d, x2d))
- self.assertLessEqual(numpy.max(abs(res1 - img[:, 1:] + 1 / 2.)), 0.5, "images are the same (middle) +/- 0.5")
-
- x2d = numpy.zeros_like(y[:-1]) + (x[:,:-1] + 0.5)
- y2d = numpy.zeros_like(x[:,:-1]) + (y[:-1] + 0.5)
+ self.assertLessEqual(
+ numpy.max(abs(res1 - img[:, 1:] + 1 / 2.0)),
+ 0.5,
+ "images are the same (middle) +/- 0.5",
+ )
+
+ x2d = numpy.zeros_like(y[:-1]) + (x[:, :-1] + 0.5)
+ y2d = numpy.zeros_like(x[:, :-1]) + (y[:-1] + 0.5)
res1 = b.map_coordinates((y2d, x2d))
- exp = 0.25 * (img[:-1,:-1] + img[:-1, 1:] + img[1:,:-1] + img[1:, 1:])
- self.assertLessEqual(abs(res1 - exp).max(), N / 4, "images are almost the same (center)")
+ exp = 0.25 * (img[:-1, :-1] + img[:-1, 1:] + img[1:, :-1] + img[1:, 1:])
+ self.assertLessEqual(
+ abs(res1 - exp).max(), N / 4, "images are almost the same (center)"
+ )
def test_profile_grad(self):
N = 100
@@ -138,7 +152,11 @@ class TestBilinear(unittest.TestCase):
res1 = b.profile_line((0, 0), (N - 1, N - 1))
l = numpy.ceil(numpy.sqrt(2) * N)
self.assertEqual(len(res1), l, "Profile has correct length")
- self.assertLess((res1[:-2] - res1[1:-1]).std(), 1e-3, "profile is linear (excluding last point)")
+ self.assertLess(
+ (res1[:-2] - res1[1:-1]).std(),
+ 1e-3,
+ "profile is linear (excluding last point)",
+ )
def test_profile_gaus(self):
N = 100
@@ -154,13 +172,15 @@ class TestBilinear(unittest.TestCase):
self.assertLess(abs(res_ver - g).max(), 1e-5, "correct vertical profile")
# Profile with linewidth=3
- expected_profile = img[:, N // 2 - 1:N // 2 + 2].mean(axis=1)
+ expected_profile = img[:, N // 2 - 1 : N // 2 + 2].mean(axis=1)
res_hor = b.profile_line((N // 2, 0), (N // 2, N - 1), linewidth=3)
res_ver = b.profile_line((0, N // 2), (N - 1, N // 2), linewidth=3)
self.assertEqual(len(res_hor), N, "Profile has correct length")
self.assertEqual(len(res_ver), N, "Profile has correct length")
- self.assertLess(abs(res_hor - expected_profile).max(), 1e-5,
- "correct horizontal profile")
- self.assertLess(abs(res_ver - expected_profile).max(), 1e-5,
- "correct vertical profile")
+ self.assertLess(
+ abs(res_hor - expected_profile).max(), 1e-5, "correct horizontal profile"
+ )
+ self.assertLess(
+ abs(res_ver - expected_profile).max(), 1e-5, "correct vertical profile"
+ )
diff --git a/src/silx/image/test/test_medianfilter.py b/src/silx/image/test/test_medianfilter.py
index 3215d69..d9e70ed 100644
--- a/src/silx/image/test/test_medianfilter.py
+++ b/src/silx/image/test/test_medianfilter.py
@@ -40,8 +40,7 @@ class TestMedianFilterEngines(unittest.TestCase):
"""Make sure we have access to all the different implementation of
median filter from image medfilt"""
-
- IMG = numpy.arange(10000.).reshape(100, 100)
+ IMG = numpy.arange(10000.0).reshape(100, 100)
KERNEL = (1, 1)
@@ -50,7 +49,8 @@ class TestMedianFilterEngines(unittest.TestCase):
res = medianfilter.medfilt2d(
image=TestMedianFilterEngines.IMG,
kernel_size=TestMedianFilterEngines.KERNEL,
- engine='cpp')
+ engine="cpp",
+ )
self.assertTrue(numpy.array_equal(res, TestMedianFilterEngines.IMG))
@unittest.skipUnless(ocl, "PyOpenCl is missing")
@@ -59,5 +59,6 @@ class TestMedianFilterEngines(unittest.TestCase):
res = medianfilter.medfilt2d(
image=TestMedianFilterEngines.IMG,
kernel_size=TestMedianFilterEngines.KERNEL,
- engine='opencl')
+ engine="opencl",
+ )
self.assertTrue(numpy.array_equal(res, TestMedianFilterEngines.IMG))
diff --git a/src/silx/image/test/test_shapes.py b/src/silx/image/test/test_shapes.py
index 1adc112..e936e64 100644
--- a/src/silx/image/test/test_shapes.py
+++ b/src/silx/image/test/test_shapes.py
@@ -47,52 +47,58 @@ class TestPolygonFill(ParametricTestCase):
mask_shape = 4, 4
tests = {
# test name: [(row min, row max), (col min, col max)]
- 'square in': [(1, 3), (1, 3)],
- 'square out': [(1, 3), (1, 10)],
- 'square around': [(-1, 5), (-1, 5)],
- }
+ "square in": [(1, 3), (1, 3)],
+ "square out": [(1, 3), (1, 10)],
+ "square around": [(-1, 5), (-1, 5)],
+ }
for test_name, (rows, cols) in tests.items():
- with self.subTest(msg=test_name, rows=rows, cols=cols,
- mask_shape=mask_shape):
+ with self.subTest(
+ msg=test_name, rows=rows, cols=cols, mask_shape=mask_shape
+ ):
ref_mask = numpy.zeros(mask_shape, dtype=numpy.uint8)
- ref_mask[max(0, rows[0]):rows[1],
- max(0, cols[0]):cols[1]] = True
-
- vertices = [(rows[0], cols[0]), (rows[1], cols[0]),
- (rows[1], cols[1]), (rows[0], cols[1])]
+ ref_mask[max(0, rows[0]) : rows[1], max(0, cols[0]) : cols[1]] = True
+
+ vertices = [
+ (rows[0], cols[0]),
+ (rows[1], cols[0]),
+ (rows[1], cols[1]),
+ (rows[0], cols[1]),
+ ]
mask = shapes.polygon_fill_mask(vertices, ref_mask.shape)
is_equal = numpy.all(numpy.equal(ref_mask, mask))
if not is_equal:
- _logger.debug('%s failed with mask != ref_mask:',
- test_name)
- _logger.debug('result:\n%s', str(mask))
- _logger.debug('ref:\n%s', str(ref_mask))
+ _logger.debug("%s failed with mask != ref_mask:", test_name)
+ _logger.debug("result:\n%s", str(mask))
+ _logger.debug("ref:\n%s", str(ref_mask))
self.assertTrue(is_equal)
def test_eight(self):
"""Tests with eight shape with different rotation and direction"""
- ref_mask = numpy.array((
- (1, 1, 1, 1, 1, 0),
- (0, 1, 1, 1, 0, 0),
- (0, 0, 1, 0, 0, 0),
- (0, 0, 1, 0, 0, 0),
- (0, 1, 1, 1, 0, 0),
- (0, 0, 0, 0, 0, 0)), dtype=numpy.uint8)
- ref_mask_rot = numpy.asarray(numpy.logical_not(ref_mask),
- dtype=numpy.uint8)
+ ref_mask = numpy.array(
+ (
+ (1, 1, 1, 1, 1, 0),
+ (0, 1, 1, 1, 0, 0),
+ (0, 0, 1, 0, 0, 0),
+ (0, 0, 1, 0, 0, 0),
+ (0, 1, 1, 1, 0, 0),
+ (0, 0, 0, 0, 0, 0),
+ ),
+ dtype=numpy.uint8,
+ )
+ ref_mask_rot = numpy.asarray(numpy.logical_not(ref_mask), dtype=numpy.uint8)
ref_mask_rot[:, -1] = 0
ref_mask_rot[-1, :] = 0
tests = {
- 'dir 1': ([(0, 0), (5, 5), (5, 0), (0, 5)], ref_mask),
- 'dir 1, rot 90': ([(5, 0), (0, 5), (5, 5), (0, 0)], ref_mask_rot),
- 'dir 1, rot 180': ([(5, 5), (0, 0), (0, 5), (5, 0)], ref_mask),
- 'dir 1, rot -90': ([(0, 5), (5, 0), (0, 0), (5, 5)], ref_mask_rot),
- 'dir 2': ([(0, 0), (0, 5), (5, 0), (5, 5)], ref_mask),
- 'dir 2, rot 90': ([(5, 0), (0, 0), (5, 5), (0, 5)], ref_mask_rot),
- 'dir 2, rot 180': ([(5, 5), (5, 0), (0, 5), (0, 0)], ref_mask),
- 'dir 2, rot -90': ([(0, 5), (5, 5), (0, 0), (5, 0)], ref_mask_rot),
+ "dir 1": ([(0, 0), (5, 5), (5, 0), (0, 5)], ref_mask),
+ "dir 1, rot 90": ([(5, 0), (0, 5), (5, 5), (0, 0)], ref_mask_rot),
+ "dir 1, rot 180": ([(5, 5), (0, 0), (0, 5), (5, 0)], ref_mask),
+ "dir 1, rot -90": ([(0, 5), (5, 0), (0, 0), (5, 5)], ref_mask_rot),
+ "dir 2": ([(0, 0), (0, 5), (5, 0), (5, 5)], ref_mask),
+ "dir 2, rot 90": ([(5, 0), (0, 0), (5, 5), (0, 5)], ref_mask_rot),
+ "dir 2, rot 180": ([(5, 5), (5, 0), (0, 5), (0, 0)], ref_mask),
+ "dir 2, rot -90": ([(0, 5), (5, 5), (0, 0), (5, 0)], ref_mask_rot),
}
for test_name, (vertices, ref_mask) in tests.items():
@@ -100,10 +106,9 @@ class TestPolygonFill(ParametricTestCase):
mask = shapes.polygon_fill_mask(vertices, ref_mask.shape)
is_equal = numpy.all(numpy.equal(ref_mask, mask))
if not is_equal:
- _logger.debug('%s failed with mask != ref_mask:',
- test_name)
- _logger.debug('result:\n%s', str(mask))
- _logger.debug('ref:\n%s', str(ref_mask))
+ _logger.debug("%s failed with mask != ref_mask:", test_name)
+ _logger.debug("result:\n%s", str(mask))
+ _logger.debug("ref:\n%s", str(ref_mask))
self.assertTrue(is_equal)
def test_shapes(self):
@@ -112,41 +117,50 @@ class TestPolygonFill(ParametricTestCase):
# name: (
# polygon corners as a list of (row, col),
# ref_mask)
- 'concave polygon': (
+ "concave polygon": (
[(1, 1), (4, 3), (1, 5), (2, 3)],
- numpy.array((
- (0, 0, 0, 0, 0, 0, 0, 0),
- (0, 0, 0, 0, 0, 0, 0, 0),
- (0, 0, 1, 1, 1, 0, 0, 0),
- (0, 0, 0, 1, 0, 0, 0, 0),
- (0, 0, 0, 0, 0, 0, 0, 0),
- (0, 0, 0, 0, 0, 0, 0, 0)), dtype=numpy.uint8)),
- 'concave polygon partly outside mask': (
+ numpy.array(
+ (
+ (0, 0, 0, 0, 0, 0, 0, 0),
+ (0, 0, 0, 0, 0, 0, 0, 0),
+ (0, 0, 1, 1, 1, 0, 0, 0),
+ (0, 0, 0, 1, 0, 0, 0, 0),
+ (0, 0, 0, 0, 0, 0, 0, 0),
+ (0, 0, 0, 0, 0, 0, 0, 0),
+ ),
+ dtype=numpy.uint8,
+ ),
+ ),
+ "concave polygon partly outside mask": (
[(-1, -1), (4, 3), (1, 5), (2, 3)],
- numpy.array((
- (1, 0, 0, 0, 0, 0),
- (0, 1, 0, 0, 0, 0),
- (0, 0, 1, 1, 1, 0),
- (0, 0, 0, 1, 0, 0),
- (0, 0, 0, 0, 0, 0),
- (0, 0, 0, 0, 0, 0),
- (0, 0, 0, 0, 0, 0),
- (0, 0, 0, 0, 0, 0)), dtype=numpy.uint8)),
- 'polygon surrounding mask': (
- [(-1, -1), (-1, 7), (7, 7), (7, -1), (0, -1),
- (8, -2), (8, 8), (-2, 8)],
- numpy.zeros((6, 6), dtype=numpy.uint8))
- }
+ numpy.array(
+ (
+ (1, 0, 0, 0, 0, 0),
+ (0, 1, 0, 0, 0, 0),
+ (0, 0, 1, 1, 1, 0),
+ (0, 0, 0, 1, 0, 0),
+ (0, 0, 0, 0, 0, 0),
+ (0, 0, 0, 0, 0, 0),
+ (0, 0, 0, 0, 0, 0),
+ (0, 0, 0, 0, 0, 0),
+ ),
+ dtype=numpy.uint8,
+ ),
+ ),
+ "polygon surrounding mask": (
+ [(-1, -1), (-1, 7), (7, 7), (7, -1), (0, -1), (8, -2), (8, 8), (-2, 8)],
+ numpy.zeros((6, 6), dtype=numpy.uint8),
+ ),
+ }
for test_name, (vertices, ref_mask) in tests.items():
with self.subTest(msg=test_name):
mask = shapes.polygon_fill_mask(vertices, ref_mask.shape)
is_equal = numpy.all(numpy.equal(ref_mask, mask))
if not is_equal:
- _logger.debug('%s failed with mask != ref_mask:',
- test_name)
- _logger.debug('result:\n%s', str(mask))
- _logger.debug('ref:\n%s', str(ref_mask))
+ _logger.debug("%s failed with mask != ref_mask:", test_name)
+ _logger.debug("result:\n%s", str(mask))
+ _logger.debug("ref:\n%s", str(ref_mask))
self.assertTrue(is_equal)
@@ -157,14 +171,14 @@ class TestDrawLine(ParametricTestCase):
"""Test drawing horizontal, vertical and diagonal lines"""
lines = { # test_name: (drow, dcol)
- 'Horizontal line, col0 < col1': (0, 10),
- 'Horizontal line, col0 > col1': (0, -10),
- 'Vertical line, row0 < row1': (10, 0),
- 'Vertical line, row0 > row1': (-10, 0),
- 'Diagonal col0 < col1 and row0 < row1': (10, 10),
- 'Diagonal col0 < col1 and row0 > row1': (-10, 10),
- 'Diagonal col0 > col1 and row0 < row1': (10, -10),
- 'Diagonal col0 > col1 and row0 > row1': (-10, -10),
+ "Horizontal line, col0 < col1": (0, 10),
+ "Horizontal line, col0 > col1": (0, -10),
+ "Vertical line, row0 < row1": (10, 0),
+ "Vertical line, row0 > row1": (-10, 0),
+ "Diagonal col0 < col1 and row0 < row1": (10, 10),
+ "Diagonal col0 < col1 and row0 > row1": (-10, 10),
+ "Diagonal col0 > col1 and row0 < row1": (10, -10),
+ "Diagonal col0 > col1 and row0 > row1": (-10, -10),
}
row0, col0 = 1, 2 # Start point
@@ -201,19 +215,18 @@ class TestDrawLine(ParametricTestCase):
row0, col0 = 1, 1
dy, dx = 3, 5
- ref_coords = numpy.array(
- [(0, 0), (1, 1), (1, 2), (2, 3), (2, 4), (3, 5)])
+ ref_coords = numpy.array([(0, 0), (1, 1), (1, 2), (2, 3), (2, 4), (3, 5)])
# Build lines for the 8 octants from this coordinantes
lines = { # name: (drow, dcol, ref_coords)
- '1st octant': (dy, dx, ref_coords),
- '2nd octant': (dx, dy, ref_coords[:, (1, 0)]), # invert x and y
- '3rd octant': (dx, -dy, ref_coords[:, (1, 0)] * (1, -1)),
- '4th octant': (dy, -dx, ref_coords * (1, -1)),
- '5th octant': (-dy, -dx, ref_coords * (-1, -1)),
- '6th octant': (-dx, -dy, ref_coords[:, (1, 0)] * (-1, -1)),
- '7th octant': (-dx, dy, ref_coords[:, (1, 0)] * (-1, 1)),
- '8th octant': (-dy, dx, ref_coords * (-1, 1))
+ "1st octant": (dy, dx, ref_coords),
+ "2nd octant": (dx, dy, ref_coords[:, (1, 0)]), # invert x and y
+ "3rd octant": (dx, -dy, ref_coords[:, (1, 0)] * (1, -1)),
+ "4th octant": (dy, -dx, ref_coords * (1, -1)),
+ "5th octant": (-dy, -dx, ref_coords * (-1, -1)),
+ "6th octant": (-dx, -dy, ref_coords[:, (1, 0)] * (-1, -1)),
+ "7th octant": (-dx, dy, ref_coords[:, (1, 0)] * (-1, 1)),
+ "8th octant": (-dy, dx, ref_coords * (-1, 1)),
}
# Test with different starting points with positive and negative coords
@@ -224,8 +237,7 @@ class TestDrawLine(ParametricTestCase):
# Transpose from ((row0, col0), ...) to (rows, cols)
ref_coords = numpy.transpose(ref_coords + (row0, col0))
- with self.subTest(msg=name,
- pt0=(row0, col0), pt1=(row1, col1)):
+ with self.subTest(msg=name, pt0=(row0, col0), pt1=(row1, col1)):
result = shapes.draw_line(row0, col0, row1, col1)
self.assertTrue(self.isEqual(name, result, ref_coords))
@@ -233,36 +245,64 @@ class TestDrawLine(ParametricTestCase):
"""Test of line width"""
lines = { # test_name: row0, col0, row1, col1, width, ref
- 'horizontal w=2':
- (0, 0, 0, 1, 2, ((0, 1, 0, 1),
- (0, 0, 1, 1))),
- 'horizontal w=3':
- (0, 0, 0, 1, 3, ((-1, 0, 1, -1, 0, 1),
- (0, 0, 0, 1, 1, 1))),
- 'vertical w=2':
- (0, 0, 1, 0, 2, ((0, 0, 1, 1),
- (0, 1, 0, 1))),
- 'vertical w=3':
- (0, 0, 1, 0, 3, ((0, 0, 0, 1, 1, 1),
- (-1, 0, 1, -1, 0, 1))),
- 'diagonal w=3':
- (0, 0, 1, 1, 3, ((-1, 0, 1, 0, 1, 2),
- (0, 0, 0, 1, 1, 1))),
- '1st octant w=3':
- (0, 0, 1, 2, 3,
- numpy.array(((-1, 0), (0, 0), (1, 0),
- (0, 1), (1, 1), (2, 1),
- (0, 2), (1, 2), (2, 2))).T),
- '2nd octant w=3':
- (0, 0, 2, 1, 3,
- numpy.array(((0, -1), (0, 0), (0, 1),
- (1, 0), (1, 1), (1, 2),
- (2, 0), (2, 1), (2, 2))).T),
+ "horizontal w=2": (0, 0, 0, 1, 2, ((0, 1, 0, 1), (0, 0, 1, 1))),
+ "horizontal w=3": (
+ 0,
+ 0,
+ 0,
+ 1,
+ 3,
+ ((-1, 0, 1, -1, 0, 1), (0, 0, 0, 1, 1, 1)),
+ ),
+ "vertical w=2": (0, 0, 1, 0, 2, ((0, 0, 1, 1), (0, 1, 0, 1))),
+ "vertical w=3": (0, 0, 1, 0, 3, ((0, 0, 0, 1, 1, 1), (-1, 0, 1, -1, 0, 1))),
+ "diagonal w=3": (0, 0, 1, 1, 3, ((-1, 0, 1, 0, 1, 2), (0, 0, 0, 1, 1, 1))),
+ "1st octant w=3": (
+ 0,
+ 0,
+ 1,
+ 2,
+ 3,
+ numpy.array(
+ (
+ (-1, 0),
+ (0, 0),
+ (1, 0),
+ (0, 1),
+ (1, 1),
+ (2, 1),
+ (0, 2),
+ (1, 2),
+ (2, 2),
+ )
+ ).T,
+ ),
+ "2nd octant w=3": (
+ 0,
+ 0,
+ 2,
+ 1,
+ 3,
+ numpy.array(
+ (
+ (0, -1),
+ (0, 0),
+ (0, 1),
+ (1, 0),
+ (1, 1),
+ (1, 2),
+ (2, 0),
+ (2, 1),
+ (2, 2),
+ )
+ ).T,
+ ),
}
for test_name, (row0, col0, row1, col1, width, ref) in lines.items():
- with self.subTest(msg=test_name,
- pt0=(row0, col0), pt1=(row1, col1), width=width):
+ with self.subTest(
+ msg=test_name, pt0=(row0, col0), pt1=(row1, col1), width=width
+ ):
result = shapes.draw_line(row0, col0, row1, col1, width)
self.assertTrue(self.isEqual(test_name, result, ref))
@@ -270,10 +310,9 @@ class TestDrawLine(ParametricTestCase):
"""Test equality of two numpy arrays and log them if different"""
is_equal = numpy.all(numpy.equal(result, ref))
if not is_equal:
- _logger.debug('%s failed with result != ref:',
- test_name)
- _logger.debug('result:\n%s', str(result))
- _logger.debug('ref:\n%s', str(ref))
+ _logger.debug("%s failed with result != ref:", test_name)
+ _logger.debug("result:\n%s", str(result))
+ _logger.debug("ref:\n%s", str(ref))
return is_equal
@@ -283,8 +322,9 @@ class TestCircleFill(ParametricTestCase):
def testCircle(self):
"""Test circle_fill with different input parameters"""
- square3x3 = numpy.array(((-1, -1, -1, 0, 0, 0, 1, 1, 1),
- (-1, 0, 1, -1, 0, 1, -1, 0, 1)))
+ square3x3 = numpy.array(
+ ((-1, -1, -1, 0, 0, 0, 1, 1, 1), (-1, 0, 1, -1, 0, 1, -1, 0, 1))
+ )
tests = [
# crow, ccol, radius, ref_coords = (ref_rows, ref_cols)
@@ -292,21 +332,97 @@ class TestCircleFill(ParametricTestCase):
(10, 15, 1, ((10,), (15,))),
(0, 0, 1.5, square3x3),
(5, 10, 2, (5 + square3x3[0], 10 + square3x3[1])),
- (10, 20, 3.5, (
- 10 + numpy.array((-3, -3, -3,
- -2, -2, -2, -2, -2,
- -1, -1, -1, -1, -1, -1, -1,
- 0, 0, 0, 0, 0, 0, 0,
- 1, 1, 1, 1, 1, 1, 1,
- 2, 2, 2, 2, 2,
- 3, 3, 3)),
- 20 + numpy.array((-1, 0, 1,
- -2, -1, 0, 1, 2,
- -3, -2, -1, 0, 1, 2, 3,
- -3, -2, -1, 0, 1, 2, 3,
- -3, -2, -1, 0, 1, 2, 3,
- -2, -1, 0, 1, 2,
- -1, 0, 1)))),
+ (
+ 10,
+ 20,
+ 3.5,
+ (
+ 10
+ + numpy.array(
+ (
+ -3,
+ -3,
+ -3,
+ -2,
+ -2,
+ -2,
+ -2,
+ -2,
+ -1,
+ -1,
+ -1,
+ -1,
+ -1,
+ -1,
+ -1,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 3,
+ 3,
+ 3,
+ )
+ ),
+ 20
+ + numpy.array(
+ (
+ -1,
+ 0,
+ 1,
+ -2,
+ -1,
+ 0,
+ 1,
+ 2,
+ -3,
+ -2,
+ -1,
+ 0,
+ 1,
+ 2,
+ 3,
+ -3,
+ -2,
+ -1,
+ 0,
+ 1,
+ 2,
+ 3,
+ -3,
+ -2,
+ -1,
+ 0,
+ 1,
+ 2,
+ 3,
+ -2,
+ -1,
+ 0,
+ 1,
+ 2,
+ -1,
+ 0,
+ 1,
+ )
+ ),
+ ),
+ ),
]
for crow, ccol, radius, ref_coords in tests:
@@ -314,8 +430,8 @@ class TestCircleFill(ParametricTestCase):
coords = shapes.circle_fill(crow, ccol, radius)
is_equal = numpy.all(numpy.equal(coords, ref_coords))
if not is_equal:
- _logger.debug('result:\n%s', str(coords))
- _logger.debug('ref:\n%s', str(ref_coords))
+ _logger.debug("result:\n%s", str(coords))
+ _logger.debug("ref:\n%s", str(ref_coords))
self.assertTrue(is_equal)
diff --git a/src/silx/image/test/test_tomography.py b/src/silx/image/test/test_tomography.py
index e697cbc..73191d6 100644
--- a/src/silx/image/test/test_tomography.py
+++ b/src/silx/image/test/test_tomography.py
@@ -34,10 +34,9 @@ import numpy
from silx.test.utils import utilstest
from silx.image import tomography
-class TestTomography(unittest.TestCase):
- """
- """
+class TestTomography(unittest.TestCase):
+ """ """
def setUp(self):
self.sinoTrueData = numpy.load(utilstest.getfile("sino500.npz"))["data"]
@@ -47,7 +46,7 @@ class TestTomography(unittest.TestCase):
self.assertTrue(numpy.isclose(centerTD, 256, rtol=0.01))
def testCalcCenterCorr(self):
- centerTrueData = tomography.calc_center_corr(self.sinoTrueData,
- fullrot=False,
- props=1)
+ centerTrueData = tomography.calc_center_corr(
+ self.sinoTrueData, fullrot=False, props=1
+ )
self.assertTrue(numpy.isclose(centerTrueData, 256, rtol=0.01))
diff --git a/src/silx/image/tomography.py b/src/silx/image/tomography.py
index d64afde..826aff6 100644
--- a/src/silx/image/tomography.py
+++ b/src/silx/image/tomography.py
@@ -40,6 +40,7 @@ from silx.math.fit import leastsq
# -------------------- Filtering-related functions -----------------------------
# ------------------------------------------------------------------------------
+
def compute_ramlak_filter(dwidth_padded, dtype=np.float32):
"""
Compute the Ramachandran-Lakshminarayanan (Ram-Lak) filter, used in
@@ -50,11 +51,11 @@ def compute_ramlak_filter(dwidth_padded, dtype=np.float32):
"""
L = dwidth_padded
h = np.zeros(L, dtype=dtype)
- L2 = L//2+1
- h[0] = 1/4.
- j = np.linspace(1, L2, L2//2, False).astype(dtype) # np < 1.9.0
- h[1:L2:2] = -1./(pi**2 * j**2)
- h[L2:] = np.copy(h[1:L2-1][::-1])
+ L2 = L // 2 + 1
+ h[0] = 1 / 4.0
+ j = np.linspace(1, L2, L2 // 2, False).astype(dtype) # np < 1.9.0
+ h[1:L2:2] = -1.0 / (pi**2 * j**2)
+ h[L2:] = np.copy(h[1 : L2 - 1][::-1])
return h
@@ -66,14 +67,14 @@ def tukey(N, alpha=0.5):
:param float alpha:
"""
apod = np.zeros(N)
- x = np.arange(N)/(N-1)
+ x = np.arange(N) / (N - 1)
r = alpha
- M1 = (0 <= x) * (x < r/2)
- M2 = (r/2 <= x) * (x <= 1 - r/2)
- M3 = (1 - r/2 < x) * (x <= 1)
- apod[M1] = (1 + np.cos(2*pi/r * (x[M1] - r/2)))/2.
- apod[M2] = 1.
- apod[M3] = (1 + np.cos(2*pi/r * (x[M3] - 1 + r/2)))/2.
+ M1 = (0 <= x) * (x < r / 2)
+ M2 = (r / 2 <= x) * (x <= 1 - r / 2)
+ M3 = (1 - r / 2 < x) * (x <= 1)
+ apod[M1] = (1 + np.cos(2 * pi / r * (x[M1] - r / 2))) / 2.0
+ apod[M2] = 1.0
+ apod[M3] = (1 + np.cos(2 * pi / r * (x[M3] - 1 + r / 2))) / 2.0
return apod
@@ -83,11 +84,11 @@ def lanczos(N):
:param int N: window width
"""
- x = np.arange(N)/(N-1)
- return np.sin(pi*(2*x-1))/(pi*(2*x-1))
+ x = np.arange(N) / (N - 1)
+ return np.sin(pi * (2 * x - 1)) / (pi * (2 * x - 1))
-def compute_fourier_filter(dwidth_padded, filter_name, cutoff=1.):
+def compute_fourier_filter(dwidth_padded, filter_name, cutoff=1.0):
"""
Compute the filter used for FBP.
@@ -98,7 +99,7 @@ def compute_fourier_filter(dwidth_padded, filter_name, cutoff=1.):
:param cutoff: Cut-off frequency, if relevant.
"""
Nf = dwidth_padded
- #~ filt_f = np.abs(np.fft.fftfreq(Nf))
+ # ~ filt_f = np.abs(np.fft.fftfreq(Nf))
rl = compute_ramlak_filter(Nf, dtype=np.float64)
filt_f = np.fft.fft(rl)
@@ -110,20 +111,22 @@ def compute_fourier_filter(dwidth_padded, filter_name, cutoff=1.):
d = cutoff
apodization = {
# ~OK
- "shepp-logan": np.sin(w[1:Nf]/(2*d))/(w[1:Nf]/(2*d)),
+ "shepp-logan": np.sin(w[1:Nf] / (2 * d)) / (w[1:Nf] / (2 * d)),
# ~OK
- "cosine": np.cos(w[1:Nf]/(2*d)),
+ "cosine": np.cos(w[1:Nf] / (2 * d)),
# OK
- "hamming": 0.54*np.ones_like(filt_f)[1:Nf] + .46 * np.cos(w[1:Nf]/d),
+ "hamming": 0.54 * np.ones_like(filt_f)[1:Nf] + 0.46 * np.cos(w[1:Nf] / d),
# OK
- "hann": (np.ones_like(filt_f)[1:Nf] + np.cos(w[1:Nf]/d))/2.,
+ "hann": (np.ones_like(filt_f)[1:Nf] + np.cos(w[1:Nf] / d)) / 2.0,
# These one is not compatible with Astra - TODO investigate why
- "tukey": np.fft.fftshift(tukey(dwidth_padded, alpha=d/2.))[1:Nf],
+ "tukey": np.fft.fftshift(tukey(dwidth_padded, alpha=d / 2.0))[1:Nf],
"lanczos": np.fft.fftshift(lanczos(dwidth_padded))[1:Nf],
}
if filter_name not in apodization:
- raise ValueError("Unknown filter %s. Available filters are %s" %
- (filter_name, str(apodization.keys())))
+ raise ValueError(
+ "Unknown filter %s. Available filters are %s"
+ % (filter_name, str(apodization.keys()))
+ )
filt_f[1:Nf] *= apodization[filter_name]
return filt_f
@@ -142,11 +145,11 @@ def generate_powers():
# not multiple of 4 (Ram-Lak filter behaves strangely when
# dwidth_padded/2 is not even)
minval = 2 if prime == 2 else 0
- valuations.append(range(minval, maxpow[prime]+1))
+ valuations.append(range(minval, maxpow[prime] + 1))
powers = product(*valuations)
res = []
for pw in powers:
- res.append(np.prod(list(map(lambda x : x[0]**x[1], zip(primes, pw)))))
+ res.append(np.prod(list(map(lambda x: x[0] ** x[1], zip(primes, pw)))))
return np.unique(res)
@@ -158,7 +161,7 @@ def get_next_power(n, powers=None):
if powers is None:
powers = generate_powers()
idx = bisect(powers, n)
- if powers[idx-1] == n:
+ if powers[idx - 1] == n:
return n
return powers[idx]
@@ -168,7 +171,6 @@ def get_next_power(n, powers=None):
# ------------------------------------------------------------------------------
-
def calc_center_corr(sino, fullrot=False, props=1):
"""
Compute a guess of the Center of Rotation (CoR) of a given sinogram.
@@ -189,7 +191,7 @@ def calc_center_corr(sino, fullrot=False, props=1):
n_a, n_d = sino.shape
first = 0
- last = -1 if not(fullrot) else n_a // 2
+ last = -1 if not (fullrot) else n_a // 2
proj1 = sino[first, :]
proj2 = sino[last, :][::-1]
@@ -202,11 +204,11 @@ def calc_center_corr(sino, fullrot=False, props=1):
pos = np.argmax(corr)
if pos > n_d // 2:
pos -= n_d
- return (n_d + pos) / 2.
+ return (n_d + pos) / 2.0
else:
corr_argsorted = np.argsort(corr)[:props]
corr_argsorted[corr_argsorted > n_d // 2] -= n_d
- return (n_d + corr_argsorted) / 2.
+ return (n_d + corr_argsorted) / 2.0
def _sine_function(t, offset, amplitude, phase):
@@ -214,7 +216,7 @@ def _sine_function(t, offset, amplitude, phase):
Helper function for calc_center_centroid
"""
n_angles = t.shape[0]
- res = amplitude * np.sin(2 * pi * (1. / (2 * n_angles)) * t + phase)
+ res = amplitude * np.sin(2 * pi * (1.0 / (2 * n_angles)) * t + phase)
return offset + res
@@ -224,8 +226,8 @@ def _sine_function_derivative(t, params, eval_idx):
"""
offset, amplitude, phase = params
n_angles = t.shape[0]
- w = 2.0 * pi * (1. / (2.0 * n_angles)) * t + phase
- grad = (1.0, np.sin(w), amplitude*np.cos(w))
+ w = 2.0 * pi * (1.0 / (2.0 * n_angles)) * t + phase
+ grad = (1.0, np.sin(w), amplitude * np.cos(w))
return grad[eval_idx]
@@ -243,37 +245,38 @@ def calc_center_centroid(sino):
n_a, n_d = sino.shape
# Compute the vector of centroids of the sinogram
i = np.arange(n_d)
- centroids = np.sum(sino*i, axis=1)/np.sum(sino, axis=1)
+ centroids = np.sum(sino * i, axis=1) / np.sum(sino, axis=1)
# Fit with a sine function : phase, amplitude, offset
# Using non-linear Levenberg–Marquardt algorithm
angles = np.linspace(0, n_a, n_a, True)
# Initial parameter vector
cmax, cmin = centroids.max(), centroids.min()
- offs = (cmax + cmin) / 2.
- amp = (cmax - cmin) / 2.
+ offs = (cmax + cmin) / 2.0
+ amp = (cmax - cmin) / 2.0
phi = 1.1
p0 = (offs, amp, phi)
constraints = np.zeros((3, 3))
- popt, _ = leastsq(model=_sine_function,
- xdata=angles,
- ydata=centroids,
- p0=p0,
- sigma=None,
- constraints=constraints,
- model_deriv=None,
- epsfcn=None,
- deltachi=None,
- full_output=0,
- check_finite=True,
- left_derivative=False,
- max_iter=100)
+ popt, _ = leastsq(
+ model=_sine_function,
+ xdata=angles,
+ ydata=centroids,
+ p0=p0,
+ sigma=None,
+ constraints=constraints,
+ model_deriv=None,
+ epsfcn=None,
+ deltachi=None,
+ full_output=0,
+ check_finite=True,
+ left_derivative=False,
+ max_iter=100,
+ )
return popt[0]
-
# ------------------------------------------------------------------------------
# -------------------- Visualization-related functions -------------------------
# ------------------------------------------------------------------------------
@@ -292,9 +295,8 @@ def rescale_intensity(img, from_subimg=None, percentiles=None):
percentiles = [2, 98]
else:
assert type(percentiles) in (tuple, list)
- assert(len(percentiles) == 2)
+ assert len(percentiles) == 2
data = from_subimg if from_subimg is not None else img
imin, imax = np.percentile(data, percentiles)
res = np.clip(img, imin, imax)
return res
-
diff --git a/src/silx/image/utils.py b/src/silx/image/utils.py
index 2659112..5ee9b7b 100644
--- a/src/silx/image/utils.py
+++ b/src/silx/image/utils.py
@@ -24,6 +24,7 @@
import numpy as np
from math import ceil
+
def gaussian_kernel(sigma, cutoff=4, force_odd_size=False):
"""
Generates a Gaussian convolution kernel.
@@ -47,6 +48,6 @@ def gaussian_kernel(sigma, cutoff=4, force_odd_size=False):
if force_odd_size and size % 2 == 0:
size += 1
x = np.arange(size) - (size - 1.0) / 2.0
- g = np.exp(-(x / sigma) ** 2 / 2.0)
+ g = np.exp(-((x / sigma) ** 2) / 2.0)
g /= g.sum()
return g
diff --git a/src/silx/io/_sliceh5.py b/src/silx/io/_sliceh5.py
new file mode 100644
index 0000000..ba7c542
--- /dev/null
+++ b/src/silx/io/_sliceh5.py
@@ -0,0 +1,221 @@
+# /*##########################################################################
+# Copyright (C) 2022-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""Provides a wrapper to expose a dataset slice as a `commonh5.Dataset`."""
+
+from __future__ import annotations
+
+from typing import Tuple, Union
+
+import h5py
+import numpy
+
+from . import commonh5
+from . import utils
+
+
+IndexType = Union[int, slice, type(Ellipsis)]
+IndicesType = Union[IndexType, Tuple[IndexType, ...]]
+NormalisedIndicesType = Tuple[Union[int, slice], ...]
+
+
+def _expand_indices(
+ ndim: int,
+ indices: IndicesType,
+) -> NormalisedIndicesType:
+ """Replace Ellipsis and complete indices to match ndim"""
+ if not isinstance(indices, tuple):
+ indices = (indices,)
+
+ nb_ellipsis = indices.count(Ellipsis)
+ if nb_ellipsis > 1:
+ raise IndexError("an index can only have a single ellipsis ('...')")
+ if nb_ellipsis == 1:
+ ellipsis_index = indices.index(Ellipsis)
+ return (
+ indices[:ellipsis_index]
+ + (slice(None),) * max(0, (ndim - len(indices) + 1))
+ + indices[ellipsis_index + 1 :]
+ )
+
+ if len(indices) > ndim:
+ raise IndexError(
+ f"too many indices ({len(indices)}) for the number of dimensions ({ndim})"
+ )
+ return indices + (slice(None),) * (ndim - len(indices))
+
+
+def _get_selection_shape(
+ shape: tuple[int, ...],
+ indices: NormalisedIndicesType,
+) -> tuple[int, ...]:
+ """Returns the shape of the selection of indices in a dataset of the given shape"""
+ assert len(shape) == len(indices)
+
+ selected_indices = (
+ index.indices(length)
+ for length, index in zip(shape, indices)
+ if isinstance(index, slice)
+ )
+ return tuple(
+ int(max(0, numpy.ceil((stop - start) / stride)))
+ for start, stop, stride in selected_indices
+ )
+
+
+def _combine_indices(
+ outer_shape: tuple[int, ...],
+ outer_indices: NormalisedIndicesType,
+ indices: IndicesType,
+) -> NormalisedIndicesType:
+ """Returns the combination of outer_indices and indices"""
+ inner_shape = _get_selection_shape(outer_shape, outer_indices)
+ inner_indices = _expand_indices(len(inner_shape), indices)
+ inner_iter = zip(range(len(inner_shape)), inner_shape, inner_indices)
+
+ combined_indices = []
+ for outer_length, outer_index in zip(outer_shape, outer_indices):
+ if isinstance(outer_index, int):
+ combined_indices.append(outer_index)
+ continue
+
+ outer_start, outer_stop, outer_stride = outer_index.indices(outer_length)
+ inner_axis, inner_length, inner_index = next(inner_iter)
+
+ if isinstance(inner_index, int):
+ if inner_index < -inner_length or inner_index >= inner_length:
+ raise IndexError(
+ f"index {inner_index} is out of bounds for axis {inner_axis} with size {inner_length}"
+ )
+ index = outer_start + outer_stride * inner_index
+ if inner_index < 0:
+ index += outer_stride * inner_length
+ combined_indices.append(index)
+ continue
+
+ inner_start, inner_stop, inner_stride = inner_index.indices(inner_length)
+ combined_indices.append(
+ slice(
+ outer_start + outer_stride * inner_start,
+ outer_start + outer_stride * inner_stop,
+ outer_stride * inner_stride,
+ )
+ )
+
+ return tuple(combined_indices)
+
+
+class DatasetSlice(commonh5.Dataset):
+ """Wrapper a dataset indexed selection as a commonh5.Dataset.
+ :param h5file: h5py-like file containing the dataset
+ :param dataset: h5py-like dataset from which to access a slice
+ :param indices: The indexing to select
+ :param attrs: dataset attributes
+ """
+
+ def __init__(
+ self,
+ dataset: Union[h5py.Dataset, commonh5.Dataset],
+ indices: IndicesType,
+ attrs: dict,
+ ):
+ if not utils.is_dataset(dataset):
+ raise ValueError(f"Unsupported dataset '{dataset}'")
+
+ self.__dataset = dataset
+ self.__file = dataset.file # Keep a ref on file to fix issue recovering it
+ self.__indices = indices
+ self.__expanded_indices = _expand_indices(len(self.__dataset.shape), indices)
+ self.__shape = _get_selection_shape(
+ self.__dataset.shape, self.__expanded_indices
+ )
+ super().__init__(
+ self.__dataset.name, data=None, parent=self.__file, attrs=attrs
+ )
+
+ def _get_data(self) -> Union[h5py.Dataset, commonh5.Dataset]:
+ # Give access to the underlying (h5py) dataset, not the selected data
+ # All commonh5.Dataset methods using _get_data must be overridden
+ return self.__dataset
+
+ @property
+ def dtype(self) -> numpy.dtype:
+ return self.__dataset.dtype
+
+ @property
+ def shape(self) -> tuple[int, ...]:
+ return self.__shape
+
+ @property
+ def size(self) -> int:
+ return numpy.prod(self.shape)
+
+ def __len__(self) -> int:
+ return self.shape[0]
+
+ def __getitem__(self, item):
+ if item is Ellipsis:
+ return numpy.array(self.__dataset[self.__expanded_indices], copy=False)
+ if item == ():
+ return self.__dataset[self.__expanded_indices]
+
+ if not self.__shape:
+ raise IndexError("invalid index to scalar variable.")
+
+ return self.__dataset[
+ _combine_indices(
+ self.__dataset.shape,
+ self.__expanded_indices,
+ item,
+ )
+ ]
+
+ @property
+ def value(self):
+ return self[()]
+
+ def __iter__(self):
+ return self[()].__iter__()
+
+ @property
+ def file(self) -> Union[h5py.File, commonh5.File]:
+ if isinstance(self.__file, h5py.File) and not self.__file.id:
+ return None
+ return self.__file
+
+ @property
+ def name(self) -> str:
+ return self.basename
+
+ @property
+ def indices(self) -> IndicesType:
+ return self.__indices
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.close()
+
+ def close(self):
+ """Close the file"""
+ self.__file.close()
diff --git a/src/silx/io/commonh5.py b/src/silx/io/commonh5.py
index 25744b4..8948e49 100644
--- a/src/silx/io/commonh5.py
+++ b/src/silx/io/commonh5.py
@@ -1,5 +1,5 @@
# /*##########################################################################
-# Copyright (C) 2016-2021 European Synchrotron Radiation Facility
+# Copyright (C) 2016-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -24,11 +24,7 @@
This module contains generic objects, emulating *h5py* groups, datasets and
files. They are used in :mod:`spech5` and :mod:`fabioh5`.
"""
-import collections
-try:
- from collections import abc
-except ImportError: # Python2 support
- import collections as abc
+from collections import abc
import weakref
import h5py
@@ -181,8 +177,7 @@ class Node(object):
@property
def name(self):
- """Returns the HDF5 name of this node.
- """
+ """Returns the HDF5 name of this node."""
parent = self.parent
if parent is None:
return "/"
@@ -192,8 +187,7 @@ class Node(object):
@property
def basename(self):
- """Returns the HDF5 basename of this node.
- """
+ """Returns the HDF5 basename of this node."""
return self.__basename
def _is_editable(self):
@@ -323,13 +317,18 @@ class Dataset(Node):
elif item == tuple():
return self._get_data()
else:
- raise ValueError("Scalar can only be reached with an ellipsis or an empty tuple")
+ raise ValueError(
+ "Scalar can only be reached with an ellipsis or an empty tuple"
+ )
return self._get_data().__getitem__(item)
def __str__(self):
basename = self.name.split("/")[-1]
- return '<HDF5-like dataset "%s": shape %s, type "%s">' % \
- (basename, self.shape, self.dtype.str)
+ return '<HDF5-like dataset "%s": shape %s, type "%s">' % (
+ basename,
+ self.shape,
+ self.dtype.str,
+ )
def __getslice__(self, i, j):
"""Returns the slice of the data exposed by this dataset.
@@ -391,7 +390,7 @@ class Dataset(Node):
def __array__(self, dtype=None):
# Special case for (0,)*-shape datasets
- if numpy.product(self.shape) == 0:
+ if numpy.prod(self.shape) == 0:
return self[()]
else:
return numpy.array(self[...], dtype=self.dtype if dtype is None else dtype)
@@ -491,8 +490,7 @@ class Dataset(Node):
return self[()] >= other
def __getattr__(self, item):
- """Proxy to underlying numpy array methods.
- """
+ """Proxy to underlying numpy array methods."""
data = self._get_data()
if hasattr(data, item):
return getattr(data, item)
@@ -588,6 +586,7 @@ class SoftLink(Node):
In this implementation, the path to the target must be absolute.
"""
+
def __init__(self, name, path, parent=None):
assert str(path).startswith("/") # TODO: h5py also allows a relative path
@@ -615,7 +614,7 @@ class Group(Node):
def __init__(self, name, parent=None, attrs=None):
Node.__init__(self, name, parent, attrs=attrs)
- self.__items = collections.OrderedDict()
+ self.__items = {}
def _get_items(self):
"""Returns the child items as a name-node dictionary.
@@ -661,8 +660,9 @@ class Group(Node):
result = result.file.get(l_target)
if result is None:
raise KeyError(
- "Unable to open object (broken SoftLink %s -> %s)" %
- (l_name, l_target))
+ "Unable to open object (broken SoftLink %s -> %s)"
+ % (l_name, l_target)
+ )
if not item_name:
# trailing "/" in name (legal for accessing Groups only)
if isinstance(result, Group):
@@ -679,9 +679,13 @@ class Group(Node):
raise KeyError(msg % (link.name, link.path))
# Convert SoftLink into typed group/dataset
if isinstance(target, Group):
- result = _LinkToGroup(name=link.basename, target=target, parent=link.parent)
+ result = _LinkToGroup(
+ name=link.basename, target=target, parent=link.parent
+ )
elif isinstance(target, Dataset):
- result = _LinkToDataset(name=link.basename, target=target, parent=link.parent)
+ result = _LinkToDataset(
+ name=link.basename, target=target, parent=link.parent
+ )
else:
raise TypeError("Unexpected target type %s" % type(target))
@@ -875,11 +879,9 @@ class Group(Node):
call `func(name)` for links and recurse into target groups.
"""
origin_name = self.name
- return self._visit(func, origin_name, visit_links,
- visititems=True)
+ return self._visit(func, origin_name, visit_links, visititems=True)
- def _visit(self, func, origin_name,
- visit_links=False, visititems=False):
+ def _visit(self, func, origin_name, visit_links=False, visititems=False):
"""
:param origin_name: name of first group that initiated the recursion
@@ -889,7 +891,7 @@ class Group(Node):
for member in self.values():
ret = None
if not isinstance(member, SoftLink) or visit_links:
- relative_name = member.name[len(origin_name):]
+ relative_name = member.name[len(origin_name) :]
# remove leading slash and unnecessary trailing slash
relative_name = relative_name.strip("/")
if visititems:
@@ -918,13 +920,15 @@ class Group(Node):
name = name[1:]
return self.file.create_group(name)
- elements = name.split('/')
+ elements = name.split("/")
group = self
for basename in elements:
if basename in group:
group = group[basename]
if not isinstance(group, Group):
- raise RuntimeError("Unable to create group (group parent is missing")
+ raise RuntimeError(
+ "Unable to create group (group parent is missing"
+ )
else:
node = Group(basename)
group.add_node(node)
@@ -1031,7 +1035,7 @@ class File(Group):
self._file_name = name
if mode is None:
mode = "r"
- assert(mode in ["r", "w"])
+ assert mode in ["r", "w"]
self._mode = mode
@property
@@ -1054,7 +1058,6 @@ class File(Group):
self.close()
def close(self):
- """Close the object, and free up associated resources.
- """
+ """Close the object, and free up associated resources."""
# should be implemented in subclass
pass
diff --git a/src/silx/io/configdict.py b/src/silx/io/configdict.py
index c028211..e2a012e 100644
--- a/src/silx/io/configdict.py
+++ b/src/silx/io/configdict.py
@@ -1,5 +1,5 @@
# /*##########################################################################
-# Copyright (C) 2004-2018 European Synchrotron Radiation Facility
+# Copyright (C) 2004-2023 European Synchrotron Radiation Facility
#
# This file is part of the PyMca X-ray Fluorescence Toolkit developed at
# the ESRF by the Software group.
@@ -88,17 +88,9 @@ __author__ = ["E. Papillon", "V.A. Sole", "P. Knobel"]
__license__ = "MIT"
__date__ = "15/09/2016"
-from collections import OrderedDict
import numpy
import re
-import sys
-if sys.version_info < (3, ):
- import ConfigParser as configparser
-else:
- import configparser
-
-
-string_types = (basestring,) if sys.version_info[0] == 2 else (str,) # noqa
+import configparser
def _boolean(sstr):
@@ -112,9 +104,9 @@ def _boolean(sstr):
:raise: ``ValueError`` if ``sstr`` is not a valid string representation
of a boolean
"""
- if sstr.lower() in ['1', 'yes', 'true', 'on']:
+ if sstr.lower() in ["1", "yes", "true", "on"]:
return True
- if sstr.lower() in ['0', 'no', 'false', 'off']:
+ if sstr.lower() in ["0", "no", "false", "off"]:
return False
msg = "Cannot coerce string '%s' to a boolean value. " % sstr
msg += "Valid boolean strings: '1', 'yes', 'true', 'on', "
@@ -171,20 +163,19 @@ def _parse_container(sstr):
if not sstr:
raise ValueError
- if sstr.find(',') == -1:
+ if sstr.find(",") == -1:
# it is not a list
- if (sstr[0] == '[') and (sstr[-1] == ']'):
+ if (sstr[0] == "[") and (sstr[-1] == "]"):
# this looks like an array
try:
# try parsing as a 1D array
return numpy.array([float(x) for x in sstr[1:-1].split()])
except ValueError:
# try parsing as a 2D array
- if (sstr[2] == '[') and (sstr[-3] == ']'):
- nrows = len(sstr[3:-3].split('] ['))
- data = sstr[3:-3].replace('] [', ' ')
- data = numpy.array([float(x) for x in
- data.split()])
+ if (sstr[2] == "[") and (sstr[-3] == "]"):
+ nrows = len(sstr[3:-3].split("] ["))
+ data = sstr[3:-3].replace("] [", " ")
+ data = numpy.array([float(x) for x in data.split()])
data.shape = nrows, -1
return data
# not a list and not an array
@@ -215,21 +206,22 @@ def _parse_list_line(sstr):
# (_parse_simple_types recognizes ^@ as a comma)
sstr.replace(r"\,", "^@")
# it is a list
- if sstr.endswith(','):
- if ',' in sstr[:-1]:
- return [_parse_simple_types(sstr2.strip())
- for sstr2 in sstr[:-1].split(',')]
+ if sstr.endswith(","):
+ if "," in sstr[:-1]:
+ return [
+ _parse_simple_types(sstr2.strip()) for sstr2 in sstr[:-1].split(",")
+ ]
else:
return [_parse_simple_types(sstr[:-1].strip())]
else:
- return [_parse_simple_types(sstr2.strip())
- for sstr2 in sstr.split(',')]
+ return [_parse_simple_types(sstr2.strip()) for sstr2 in sstr.split(",")]
class OptionStr(str):
"""String class providing typecasting methods to parse values in a
:class:`ConfigDict` generated configuration file.
"""
+
def toint(self):
"""
:return: integer
@@ -288,7 +280,7 @@ class OptionStr(str):
return _parse_simple_types(self)
-class ConfigDict(OrderedDict):
+class ConfigDict(dict):
"""Store configuration parameters as an ordered dictionary.
Parameters can be grouped into sections, by storing them as
@@ -318,9 +310,10 @@ class ConfigDict(OrderedDict):
:param filelist: List of configuration files to be read and added into
dict after ``defaultdict`` and ``initdict``
"""
+
def __init__(self, defaultdict=None, initdict=None, filelist=None):
- self.default = defaultdict if defaultdict is not None else OrderedDict()
- OrderedDict.__init__(self, self.default)
+ self.default = defaultdict if defaultdict is not None else {}
+ super().__init__(self.default)
self.filelist = []
if initdict is not None:
@@ -329,19 +322,17 @@ class ConfigDict(OrderedDict):
self.read(filelist)
def reset(self):
- """ Revert to default values
- """
+ """Revert to default values"""
self.clear()
self.update(self.default)
def clear(self):
- """ Clear dictionnary
- """
- OrderedDict.clear(self)
+ """Clear dictionnary"""
+ super().clear()
self.filelist = []
def __tolist(self, mylist):
- """ If ``mylist` is not a list, encapsulate it in a list and return
+ """If ``mylist` is not a list, encapsulate it in a list and return
it.
:param mylist: List to encapsulate
@@ -411,10 +402,10 @@ class ConfigDict(OrderedDict):
for sect in readsect:
ddict = self
- for subsectw in sect.split('.'):
+ for subsectw in sect.split("."):
subsect = subsectw.replace("_|_", ".")
if not subsect in ddict:
- ddict[subsect] = OrderedDict()
+ ddict[subsect] = {}
ddict = ddict[subsect]
for opt in cfg.options(sect):
ddict[opt] = self.__parse_data(cfg.get(sect, opt))
@@ -431,9 +422,9 @@ class ConfigDict(OrderedDict):
return OptionStr(data).tobestguess()
def tostring(self):
- """Return INI file content generated by :meth:`write` as a string
- """
+ """Return INI file content generated by :meth:`write` as a string"""
import StringIO
+
tmp = StringIO.StringIO()
self.__write(tmp, self)
return tmp.getvalue()
@@ -469,15 +460,14 @@ class ConfigDict(OrderedDict):
the interpolation syntax
(https://docs.python.org/3/library/configparser.html#interpolation-of-values).
"""
- non_str = r'^([0-9]+|[0-9]*\.[0-9]*|none|false|true|on|off|yes|no)$'
+ non_str = r"^([0-9]+|[0-9]*\.[0-9]*|none|false|true|on|off|yes|no)$"
if re.match(non_str, sstr.lower()):
sstr = "\\" + sstr
# Escape commas
sstr = sstr.replace(",", r"\,")
- if sys.version_info >= (3, ):
- # Escape % characters except in "%%" and "%("
- sstr = re.sub(r'%([^%\(])', r'%%\1', sstr)
+ # Escape % characters except in "%%" and "%("
+ sstr = re.sub(r"%([^%\(])", r"%%\1", sstr)
return sstr
@@ -492,49 +482,52 @@ class ConfigDict(OrderedDict):
dictkey = []
for key in ddict.keys():
- if hasattr(ddict[key], 'keys'):
+ if hasattr(ddict[key], "keys"):
# subsections are added at the end of a section
dictkey.append(key)
elif isinstance(ddict[key], list):
- fp.write('%s = ' % key)
+ fp.write("%s = " % key)
llist = []
- sep = ', '
+ sep = ", "
for item in ddict[key]:
if isinstance(item, list):
if len(item) == 1:
- if isinstance(item[0], string_types):
+ if isinstance(item[0], str):
self._escape_str(item[0])
- llist.append('%s,' % self._escape_str(item[0]))
+ llist.append("%s," % self._escape_str(item[0]))
else:
- llist.append('%s,' % item[0])
+ llist.append("%s," % item[0])
else:
item2 = []
for val in item:
- if isinstance(val, string_types):
+ if isinstance(val, str):
val = self._escape_str(val)
item2.append(val)
- llist.append(', '.join([str(val) for val in item2]))
- sep = '\n\t'
- elif isinstance(item, string_types):
+ llist.append(", ".join([str(val) for val in item2]))
+ sep = "\n\t"
+ elif isinstance(item, str):
llist.append(self._escape_str(item))
else:
llist.append(str(item))
- fp.write('%s\n' % (sep.join(llist)))
- elif isinstance(ddict[key], string_types):
- fp.write('%s = %s\n' % (key, self._escape_str(ddict[key])))
+ fp.write("%s\n" % (sep.join(llist)))
+ elif isinstance(ddict[key], str):
+ fp.write("%s = %s\n" % (key, self._escape_str(ddict[key])))
else:
if isinstance(ddict[key], numpy.ndarray):
- fp.write('%s =' % key + ' [ ' +
- ' '.join([str(val) for val in ddict[key]]) +
- ' ]\n')
+ fp.write(
+ "%s =" % key
+ + " [ "
+ + " ".join([str(val) for val in ddict[key]])
+ + " ]\n"
+ )
else:
- fp.write('%s = %s\n' % (key, ddict[key]))
+ fp.write("%s = %s\n" % (key, ddict[key]))
for key in dictkey:
if secthead is None:
newsecthead = key.replace(".", "_|_")
else:
- newsecthead = '%s.%s' % (secthead, key.replace(".", "_|_"))
+ newsecthead = "%s.%s" % (secthead, key.replace(".", "_|_"))
- fp.write('\n[%s]\n' % newsecthead)
+ fp.write("\n[%s]\n" % newsecthead)
self.__write(fp, ddict[key], newsecthead)
diff --git a/src/silx/io/convert.py b/src/silx/io/convert.py
index a4c5dc3..6254b14 100644
--- a/src/silx/io/convert.py
+++ b/src/silx/io/convert.py
@@ -1,5 +1,5 @@
# /*##########################################################################
-# Copyright (C) 2016-2021 European Synchrotron Radiation Facility
+# Copyright (C) 2016-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -28,26 +28,6 @@ Read the documentation of :mod:`silx.io.spech5`, :mod:`silx.io.fioh5` and :mod:`
information on the structure of the output HDF5 files.
Text strings are written to the HDF5 datasets as variable-length utf-8.
-
-.. warning::
-
- The output format for text strings changed in silx version 0.7.0.
- Prior to that, text was output as fixed-length ASCII.
-
- To be on the safe side, when reading back a HDF5 file written with an
- older version of silx, you can test for the presence of a *decode*
- attribute. To ensure that you always work with unicode text::
-
- >>> import h5py
- >>> h5f = h5py.File("my_scans.h5", "r")
- >>> title = h5f["/68.1/title"]
- >>> if hasattr(title, "decode"):
- ... title = title.decode()
-
-
-.. note:: This module has a dependency on the `h5py <http://www.h5py.org/>`_
- library, which is not a mandatory dependency for `silx`. You might need
- to install it if you don't already have it.
"""
__authors__ = ["P. Knobel"]
@@ -68,8 +48,7 @@ from . import fabioh5
_logger = logging.getLogger(__name__)
-def _create_link(h5f, link_name, target_name,
- link_type="soft", overwrite_data=False):
+def _create_link(h5f, link_name, target_name, link_type="soft", overwrite_data=False):
"""Create a link in a HDF5 file
If member with name ``link_name`` already exists, delete it first or
@@ -85,12 +64,12 @@ def _create_link(h5f, link_name, target_name,
if link_name not in h5f:
_logger.debug("Creating link " + link_name + " -> " + target_name)
elif overwrite_data:
- _logger.warning("Overwriting " + link_name + " with link to " +
- target_name)
+ _logger.warning("Overwriting " + link_name + " with link to " + target_name)
del h5f[link_name]
else:
- _logger.warning(link_name + " already exist. Cannot create link to " +
- target_name)
+ _logger.warning(
+ link_name + " already exist. Cannot create link to " + target_name
+ )
return None
if link_type == "hard":
@@ -108,9 +87,7 @@ def _attr_utf8(attr_value):
:return: Attr ready to be written by h5py as utf8
"""
if isinstance(attr_value, (bytes, str)):
- out_attr_value = numpy.array(
- attr_value,
- dtype=h5py.special_dtype(vlen=str))
+ out_attr_value = numpy.array(attr_value, dtype=h5py.special_dtype(vlen=str))
else:
out_attr_value = attr_value
@@ -118,14 +95,16 @@ def _attr_utf8(attr_value):
class Hdf5Writer(object):
- """Converter class to write the content of a data file to a HDF5 file.
- """
- def __init__(self,
- h5path='/',
- overwrite_data=False,
- link_type="soft",
- create_dataset_args=None,
- min_size=500):
+ """Converter class to write the content of a data file to a HDF5 file."""
+
+ def __init__(
+ self,
+ h5path="/",
+ overwrite_data=False,
+ link_type="soft",
+ create_dataset_args=None,
+ min_size=500,
+ ):
"""
:param h5path: Target path where the scan groups will be written
@@ -155,7 +134,7 @@ class Hdf5Writer(object):
self.min_size = min_size
- self.overwrite_data = overwrite_data # boolean
+ self.overwrite_data = overwrite_data # boolean
self.link_type = link_type
"""'soft' or 'hard' """
@@ -184,14 +163,17 @@ class Hdf5Writer(object):
root_grp = h5f[self.h5path]
for key in infile.attrs:
if self.overwrite_data or key not in root_grp.attrs:
- root_grp.attrs.create(key,
- _attr_utf8(infile.attrs[key]))
+ root_grp.attrs.create(key, _attr_utf8(infile.attrs[key]))
# Handle links at the end, when their targets are created
for link_name, target_name in self._links:
- _create_link(self._h5f, link_name, target_name,
- link_type=self.link_type,
- overwrite_data=self.overwrite_data)
+ _create_link(
+ self._h5f,
+ link_name,
+ target_name,
+ link_type=self.link_type,
+ overwrite_data=self.overwrite_data,
+ )
self._links = []
def append_member_to_h5(self, h5like_name, obj):
@@ -215,10 +197,12 @@ class Hdf5Writer(object):
if isinstance(obj, fabioh5.FrameData) and len(obj.shape) > 2:
# special case of multiframe data
# write frame by frame to save memory usage low
- ds = self._h5f.create_dataset(h5_name,
- shape=obj.shape,
- dtype=obj.dtype,
- **self.create_dataset_args)
+ ds = self._h5f.create_dataset(
+ h5_name,
+ shape=obj.shape,
+ dtype=obj.dtype,
+ **self.create_dataset_args,
+ )
for i, frame in enumerate(obj):
ds[i] = frame
else:
@@ -226,16 +210,16 @@ class Hdf5Writer(object):
if obj.size < self.min_size:
ds = self._h5f.create_dataset(h5_name, data=obj[()])
else:
- ds = self._h5f.create_dataset(h5_name, data=obj[()],
- **self.create_dataset_args)
+ ds = self._h5f.create_dataset(
+ h5_name, data=obj[()], **self.create_dataset_args
+ )
else:
ds = self._h5f[h5_name]
# add HDF5 attributes
for key in obj.attrs:
if self.overwrite_data or key not in ds.attrs:
- ds.attrs.create(key,
- _attr_utf8(obj.attrs[key]))
+ ds.attrs.create(key, _attr_utf8(obj.attrs[key]))
if not self.overwrite_data and member_initially_exists:
_logger.warning("Not overwriting existing dataset: " + h5_name)
@@ -250,15 +234,21 @@ class Hdf5Writer(object):
# add HDF5 attributes
for key in obj.attrs:
if self.overwrite_data or key not in grp.attrs:
- grp.attrs.create(key,
- _attr_utf8(obj.attrs[key]))
+ grp.attrs.create(key, _attr_utf8(obj.attrs[key]))
else:
_logger.warning("Unsuppored entity, ignoring: %s", h5_name)
-def write_to_h5(infile, h5file, h5path='/', mode="a",
- overwrite_data=False, link_type="soft",
- create_dataset_args=None, min_size=500):
+def write_to_h5(
+ infile,
+ h5file,
+ h5path="/",
+ mode="a",
+ overwrite_data=False,
+ link_type="soft",
+ create_dataset_args=None,
+ min_size=500,
+):
"""Write content of a h5py-like object into a HDF5 file.
Warning: External links in `infile` are ignored.
@@ -287,11 +277,13 @@ def write_to_h5(infile, h5file, h5path='/', mode="a",
The structure of the spec data in an HDF5 file is described in the
documentation of :mod:`silx.io.spech5`.
"""
- writer = Hdf5Writer(h5path=h5path,
- overwrite_data=overwrite_data,
- link_type=link_type,
- create_dataset_args=create_dataset_args,
- min_size=min_size)
+ writer = Hdf5Writer(
+ h5path=h5path,
+ overwrite_data=overwrite_data,
+ link_type=link_type,
+ create_dataset_args=create_dataset_args,
+ min_size=min_size,
+ )
# both infile and h5file can be either file handle or a file name: 4 cases
if not isinstance(h5file, h5py.File) and not is_group(infile):
@@ -328,7 +320,10 @@ def convert(infile, h5file, mode="w-", create_dataset_args=None):
compression parameters. Don't specify ``name`` and ``data``.
"""
if mode not in ["w", "w-"]:
- raise IOError("File mode must be 'w' or 'w-'. Use write_to_h5" +
- " to append data to an existing HDF5 file.")
- write_to_h5(infile, h5file, h5path='/', mode=mode,
- create_dataset_args=create_dataset_args)
+ raise IOError(
+ "File mode must be 'w' or 'w-'. Use write_to_h5"
+ + " to append data to an existing HDF5 file."
+ )
+ write_to_h5(
+ infile, h5file, h5path="/", mode=mode, create_dataset_args=create_dataset_args
+ )
diff --git a/src/silx/io/dictdump.py b/src/silx/io/dictdump.py
index 094a51f..7722842 100644
--- a/src/silx/io/dictdump.py
+++ b/src/silx/io/dictdump.py
@@ -1,5 +1,5 @@
# /*##########################################################################
-# Copyright (C) 2016-2022 European Synchrotron Radiation Facility
+# Copyright (C) 2016-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -24,13 +24,13 @@
by text strings to following file formats: `HDF5, INI, JSON`
"""
-from collections import OrderedDict
from collections.abc import Mapping
import json
import logging
import numpy
import os.path
import h5py
+
try:
from pint import Quantity as PintQuantity
except ImportError:
@@ -49,7 +49,6 @@ from .utils import is_file as is_h5_file_like
from .utils import open as h5open
from .utils import h5py_read_dataset
from .utils import H5pyAttributesReadWrapper
-from silx.utils.deprecation import deprecated_warning
__authors__ = ["P. Knobel"]
__license__ = "MIT"
@@ -95,6 +94,7 @@ class _SafeH5FileWrite:
function. The object is created in the initial call if a path is provided,
and it is closed only at the end when all the processing is finished.
"""
+
def __init__(self, h5file, mode="w"):
"""
:param h5file: HDF5 file path or :class:`h5py.File` instance
@@ -128,6 +128,7 @@ class _SafeH5FileRead:
that SPEC files and all formats supported by fabio can also be opened,
but in read-only mode.
"""
+
def __init__(self, h5file):
"""
@@ -177,9 +178,14 @@ def _normalize_h5_path(h5root, h5path):
return h5file, h5path
-def dicttoh5(treedict, h5file, h5path='/',
- mode="w", overwrite_data=None,
- create_dataset_args=None, update_mode=None):
+def dicttoh5(
+ treedict,
+ h5file,
+ h5path="/",
+ mode="w",
+ create_dataset_args=None,
+ update_mode=None,
+):
"""Write a nested dictionary to a HDF5 file, using keys as member names.
If a dictionary value is a sub-dictionary, a group is created. If it is
@@ -209,9 +215,6 @@ def dicttoh5(treedict, h5file, h5path='/',
``"w"`` (write, existing file is lost), ``"w-"`` (write, fail if
exists) or ``"a"`` (read/write if exists, create otherwise).
This parameter is ignored if ``h5file`` is a file handle.
- :param overwrite_data: Deprecated. ``True`` is approximately equivalent
- to ``update_mode="modify"`` and ``False`` is equivalent to
- ``update_mode="add"``.
:param create_dataset_args: Dictionary of args you want to pass to
``h5f.create_dataset``. This allows you to specify filters and
compression parameters. Don't specify ``name`` and ``data``.
@@ -253,32 +256,16 @@ def dicttoh5(treedict, h5file, h5path='/',
create_dataset_args=create_ds_args)
"""
- if overwrite_data is not None:
- reason = (
- "`overwrite_data=True` becomes `update_mode='modify'` and "
- "`overwrite_data=False` becomes `update_mode='add'`"
- )
- deprecated_warning(
- type_="argument",
- name="overwrite_data",
- reason=reason,
- replacement="update_mode",
- since_version="0.15",
- )
-
if update_mode is None:
- if overwrite_data:
- update_mode = "modify"
- else:
- update_mode = "add"
- else:
- if update_mode not in UPDATE_MODE_VALID_EXISTING_VALUES:
- raise ValueError((
+ update_mode = "add"
+
+ if update_mode not in UPDATE_MODE_VALID_EXISTING_VALUES:
+ raise ValueError(
+ (
"Argument 'update_mode' can only have values: {}"
"".format(UPDATE_MODE_VALID_EXISTING_VALUES)
- ))
- if overwrite_data is not None:
- logger.warning("The argument `overwrite_data` is ignored")
+ )
+ )
if not isinstance(treedict, Mapping):
raise TypeError("'treedict' must be a dictionary")
@@ -301,7 +288,9 @@ def dicttoh5(treedict, h5file, h5path='/',
del h5f[h5path]
h5f.create_group(h5path)
else:
- logger.info(f'Cannot overwrite {h5f.file.filename}::{h5f[h5path].name} with update_mode="{update_mode}"')
+ logger.info(
+ f'Cannot overwrite {h5f.file.filename}::{h5f[h5path].name} with update_mode="{update_mode}"'
+ )
return
else:
h5f.create_group(h5path)
@@ -322,9 +311,13 @@ def dicttoh5(treedict, h5file, h5path='/',
del h5f[h5name]
exists = False
if value:
- dicttoh5(value, h5f, h5name,
- update_mode=update_mode,
- create_dataset_args=create_dataset_args)
+ dicttoh5(
+ value,
+ h5f,
+ h5name,
+ update_mode=update_mode,
+ create_dataset_args=create_dataset_args,
+ )
elif not exists:
h5f.create_group(h5name)
elif is_link(value):
@@ -338,7 +331,9 @@ def dicttoh5(treedict, h5file, h5path='/',
else:
# HDF5 dataset
if exists and not change_allowed:
- logger.info(f'Cannot modify dataset {h5f.file.filename}::{h5f[h5name].name} with update_mode="{update_mode}"')
+ logger.info(
+ f'Cannot modify dataset {h5f.file.filename}::{h5f[h5name].name} with update_mode="{update_mode}"'
+ )
continue
data = _prepare_hdf5_write_value(value)
@@ -352,20 +347,28 @@ def dicttoh5(treedict, h5file, h5path='/',
# Delete the existing dataset
if update_mode != "replace":
if not is_dataset(h5f[h5name]):
- logger.info(f'Cannot overwrite {h5f.file.filename}::{h5f[h5name].name} with update_mode="{update_mode}"')
+ logger.info(
+ f'Cannot overwrite {h5f.file.filename}::{h5f[h5name].name} with update_mode="{update_mode}"'
+ )
continue
attrs_backup = dict(h5f[h5name].attrs)
del h5f[h5name]
# Create dataset
# can't apply filters on scalars (datasets with shape == ())
- if data.shape == () or create_dataset_args is None:
- h5f.create_dataset(h5name,
- data=data)
- else:
- h5f.create_dataset(h5name,
- data=data,
- **create_dataset_args)
+ try:
+ if data.shape == () or create_dataset_args is None:
+ h5f.create_dataset(h5name, data=data)
+ else:
+ h5f.create_dataset(h5name, data=data, **create_dataset_args)
+ except Exception as e:
+ if isinstance(data, numpy.ndarray):
+ dtype = f"numpy.ndarray-{data.dtype}"
+ else:
+ dtype = type(data)
+ raise ValueError(
+ f"Failed to create dataset '{h5name}' with data ({dtype}) = {data}"
+ ) from e
if attrs_backup:
h5f[h5name].attrs.update(attrs_backup)
@@ -391,20 +394,20 @@ def dicttoh5(treedict, h5file, h5path='/',
else:
# Add/modify HDF5 attribute
if exists and not change_allowed:
- logger.info(f'Cannot modify attribute {h5f.file.filename}::{h5f[h5name].name}@{attr_name} with update_mode="{update_mode}"')
+ logger.info(
+ f'Cannot modify attribute {h5f.file.filename}::{h5f[h5name].name}@{attr_name} with update_mode="{update_mode}"'
+ )
continue
data = _prepare_hdf5_write_value(value)
h5a[attr_name] = data
def _has_nx_class(treedict, key=""):
- return key + "@NX_class" in treedict or \
- (key, "NX_class") in treedict
+ return key + "@NX_class" in treedict or (key, "NX_class") in treedict
def _ensure_nx_class(treedict, parents=tuple()):
- """Each group needs an "NX_class" attribute.
- """
+ """Each group needs an "NX_class" attribute."""
if _has_nx_class(treedict):
return
nparents = len(parents)
@@ -416,13 +419,11 @@ def _ensure_nx_class(treedict, parents=tuple()):
treedict[("", "NX_class")] = "NXcollection"
-def nexus_to_h5_dict(
- treedict, parents=tuple(), add_nx_class=True, has_nx_class=False
-):
+def nexus_to_h5_dict(treedict, parents=tuple(), add_nx_class=True, has_nx_class=False):
"""The following conversions are applied:
* key with "{name}@{attr_name}" notation: key converted to 2-tuple
* key with ">{url}" notation: strip ">" and convert value to
- h5py.SoftLink or h5py.ExternalLink
+ h5py.SoftLink or h5py.ExternalLink
:param treedict: Nested dictionary/tree structure with strings as keys
and array-like objects as leafs. The ``"/"`` character can be used
@@ -469,9 +470,10 @@ def nexus_to_h5_dict(
key_has_nx_class = add_nx_class and _has_nx_class(treedict, key)
copy[key] = nexus_to_h5_dict(
value,
- parents=parents+(key,),
+ parents=parents + (key,),
add_nx_class=add_nx_class,
- has_nx_class=key_has_nx_class)
+ has_nx_class=key_has_nx_class,
+ )
elif PintQuantity is not None and isinstance(value, PintQuantity):
copy[key] = value.magnitude
@@ -534,23 +536,25 @@ def _handle_error(mode: str, exception, msg: str, *args) -> None:
:param str msg: Error message template
:param List[str] args: Arguments for error message template
"""
- if mode == 'ignore':
+ if mode == "ignore":
return # no-op
- elif mode == 'log':
+ elif mode == "log":
logger.error(msg, *args)
- elif mode == 'raise':
+ elif mode == "raise":
raise exception(msg % args)
else:
raise ValueError("Unsupported error handling: %s" % mode)
-def h5todict(h5file,
- path="/",
- exclude_names=None,
- asarray=True,
- dereference_links=True,
- include_attributes=False,
- errors='raise'):
+def h5todict(
+ h5file,
+ path="/",
+ exclude_names=None,
+ asarray=True,
+ dereference_links=True,
+ include_attributes=False,
+ errors="raise",
+):
"""Read a HDF5 file and return a nested dictionary with the complete file
structure and all data.
@@ -599,20 +603,18 @@ def h5todict(h5file,
with _SafeH5FileRead(h5file) as h5f:
ddict = {}
if path not in h5f:
- _handle_error(
- errors, KeyError, 'Path "%s" does not exist in file.', path)
+ _handle_error(errors, KeyError, 'Path "%s" does not exist in file.', path)
return ddict
try:
root = h5f[path]
except KeyError as e:
if not isinstance(h5f.get(path, getlink=True), h5py.HardLink):
- _handle_error(errors,
- KeyError,
- 'Cannot retrieve path "%s" (broken link)',
- path)
+ _handle_error(
+ errors, KeyError, 'Cannot retrieve path "%s" (broken link)', path
+ )
else:
- _handle_error(errors, KeyError, ', '.join(e.args))
+ _handle_error(errors, KeyError, ", ".join(e.args))
return ddict
# Read the attributes of the group
@@ -636,31 +638,35 @@ def h5todict(h5file,
h5obj = h5f[h5name]
except KeyError as e:
if not isinstance(h5f.get(h5name, getlink=True), h5py.HardLink):
- _handle_error(errors,
- KeyError,
- 'Cannot retrieve path "%s" (broken link)',
- h5name)
+ _handle_error(
+ errors,
+ KeyError,
+ 'Cannot retrieve path "%s" (broken link)',
+ h5name,
+ )
else:
- _handle_error(errors, KeyError, ', '.join(e.args))
+ _handle_error(errors, KeyError, ", ".join(e.args))
continue
if is_group(h5obj):
# Child is an HDF5 group
- ddict[key] = h5todict(h5f,
- h5name,
- exclude_names=exclude_names,
- asarray=asarray,
- dereference_links=dereference_links,
- include_attributes=include_attributes)
+ ddict[key] = h5todict(
+ h5f,
+ h5name,
+ exclude_names=exclude_names,
+ asarray=asarray,
+ dereference_links=dereference_links,
+ include_attributes=include_attributes,
+ errors=errors,
+ )
else:
# Child is an HDF5 dataset
try:
data = h5py_read_dataset(h5obj)
except OSError:
- _handle_error(errors,
- OSError,
- 'Cannot retrieve dataset "%s"',
- h5name)
+ _handle_error(
+ errors, OSError, 'Cannot retrieve dataset "%s"', h5name
+ )
else:
if asarray: # Convert HDF5 dataset to numpy array
data = numpy.array(data, copy=False)
@@ -728,9 +734,7 @@ def dicttonx(treedict, h5file, h5path="/", add_nx_class=None, **kw):
parents = tuple(p for p in h5path.split("/") if p)
if add_nx_class is None:
add_nx_class = kw.get("update_mode", None) in (None, "add")
- nxtreedict = nexus_to_h5_dict(
- treedict, parents=parents, add_nx_class=add_nx_class
- )
+ nxtreedict = nexus_to_h5_dict(treedict, parents=parents, add_nx_class=add_nx_class)
dicttoh5(nxtreedict, h5file, h5path=h5path, **kw)
@@ -806,7 +810,7 @@ def dump(ddict, ffile, mode="w", fmat=None):
"""
if fmat is None:
# If file-like object get its name, else use ffile as filename
- filename = getattr(ffile, 'name', ffile)
+ filename = getattr(ffile, "name", ffile)
fmat = os.path.splitext(filename)[1][1:] # Strip extension leading '.'
fmat = fmat.lower()
@@ -823,7 +827,7 @@ def dump(ddict, ffile, mode="w", fmat=None):
def load(ffile, fmat=None):
"""Load dictionary from a file
- When loading from a JSON or INI file, an OrderedDict is returned to
+ When loading from a JSON or INI file, the returned dict
preserve the values' insertion order.
:param ffile: File name or file-like object with a ``read`` method
@@ -831,7 +835,7 @@ def load(ffile, fmat=None):
When None (the default), it uses the filename extension as the format.
Loading from a HDF5 file requires `h5py <http://www.h5py.org/>`_ to be
installed.
- :return: Dictionary (ordered dictionary for JSON and INI)
+ :return: Dictionary
:raises IOError: if file format is not supported
"""
must_be_closed = False
@@ -849,7 +853,7 @@ def load(ffile, fmat=None):
fmat = fmat.lower()
if fmat == "json":
- return json.load(f, object_pairs_hook=OrderedDict)
+ return json.load(f)
if fmat in ["hdf5", "h5"]:
return h5todict(fname)
elif fmat in ["ini", "cfg"]:
diff --git a/src/silx/io/fabioh5.py b/src/silx/io/fabioh5.py
index c5ef964..89e838b 100755
--- a/src/silx/io/fabioh5.py
+++ b/src/silx/io/fabioh5.py
@@ -1,5 +1,5 @@
# /*##########################################################################
-# Copyright (C) 2016-2021 European Synchrotron Radiation Facility
+# Copyright (C) 2016-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -31,7 +31,6 @@
"""
-import collections
import datetime
import logging
import numbers
@@ -77,19 +76,6 @@ def supported_extensions():
return _fabio_extensions
-class _FileSeries(fabio.file_series.file_series):
- """
- .. note:: Overwrite a function to fix an issue in fabio.
- """
- def jump(self, num):
- """
- Goto a position in sequence
- """
- assert num < len(self) and num >= 0, "num out of range"
- self._current = num
- return self[self._current]
-
-
class FrameData(commonh5.LazyLoadableDataset):
"""Expose a cube of image from a Fabio file using `FabioReader` as
cache."""
@@ -108,8 +94,7 @@ class FrameData(commonh5.LazyLoadableDataset):
return self.__fabio_reader.get_data()
def _update_cache(self):
- if isinstance(self.__fabio_reader.fabio_file(),
- fabio.file_series.file_series):
+ if isinstance(self.__fabio_reader.fabio_file(), fabio.file_series.file_series):
# Reading all the files is taking too much time
# Reach the information from the only first frame
first_image = self.__fabio_reader.fabio_file().first_image()
@@ -140,9 +125,9 @@ class FrameData(commonh5.LazyLoadableDataset):
def __getitem__(self, item):
# optimization for fetching a single frame if data not already loaded
if not self._is_initialized:
- if isinstance(item, int) and \
- isinstance(self.__fabio_reader.fabio_file(),
- fabio.file_series.file_series):
+ if isinstance(item, int) and isinstance(
+ self.__fabio_reader.fabio_file(), fabio.file_series.file_series
+ ):
if item < 0:
# negative indexing
item += len(self)
@@ -158,8 +143,7 @@ class RawHeaderData(commonh5.LazyLoadableDataset):
self.__fabio_reader = fabio_reader
def _create_data(self):
- """Initialize hold data by merging all headers of each frames.
- """
+ """Initialize hold data by merging all headers of each frames."""
headers = []
types = set([])
for fabio_frame in self.__fabio_reader.iter_frames():
@@ -199,8 +183,7 @@ class RawHeaderData(commonh5.LazyLoadableDataset):
class MetadataGroup(commonh5.LazyLoadableGroup):
- """Abstract class for groups containing a reference to a fabio image.
- """
+ """Abstract class for groups containing a reference to a fabio image."""
def __init__(self, name, metadata_reader, kind, parent=None, attrs=None):
commonh5.LazyLoadableGroup.__init__(self, name, parent, attrs)
@@ -220,8 +203,7 @@ class MetadataGroup(commonh5.LazyLoadableGroup):
class DetectorGroup(commonh5.LazyLoadableGroup):
- """Define the detector group (sub group of instrument) using Fabio data.
- """
+ """Define the detector group (sub group of instrument) using Fabio data."""
def __init__(self, name, fabio_reader, parent=None, attrs=None):
if attrs is None:
@@ -241,8 +223,7 @@ class DetectorGroup(commonh5.LazyLoadableGroup):
class ImageGroup(commonh5.LazyLoadableGroup):
- """Define the image group (sub group of measurement) using Fabio data.
- """
+ """Define the image group (sub group of measurement) using Fabio data."""
def __init__(self, name, fabio_reader, parent=None, attrs=None):
commonh5.LazyLoadableGroup.__init__(self, name, parent, attrs)
@@ -281,8 +262,7 @@ class NxDataPreviewGroup(commonh5.LazyLoadableGroup):
class SampleGroup(commonh5.LazyLoadableGroup):
- """Define the image group (sub group of measurement) using Fabio data.
- """
+ """Define the image group (sub group of measurement) using Fabio data."""
def __init__(self, name, fabio_reader, parent=None):
attrs = {"NXclass": "NXsample"}
@@ -309,8 +289,7 @@ class SampleGroup(commonh5.LazyLoadableGroup):
class MeasurementGroup(commonh5.LazyLoadableGroup):
- """Define the measurement group for fabio file.
- """
+ """Define the measurement group for fabio file."""
def __init__(self, name, fabio_reader, parent=None, attrs=None):
commonh5.LazyLoadableGroup.__init__(self, name, parent, attrs)
@@ -369,9 +348,13 @@ class FabioReader(object):
def __load(self, file_name=None, fabio_image=None, file_series=None):
if file_name is not None and fabio_image:
- raise TypeError("Parameters file_name and fabio_image are mutually exclusive.")
+ raise TypeError(
+ "Parameters file_name and fabio_image are mutually exclusive."
+ )
if file_name is not None and fabio_image:
- raise TypeError("Parameters fabio_image and file_series are mutually exclusive.")
+ raise TypeError(
+ "Parameters fabio_image and file_series are mutually exclusive."
+ )
self.__must_be_closed = False
@@ -382,14 +365,18 @@ class FabioReader(object):
if isinstance(fabio_image, fabio.fabioimage.FabioImage):
self.__fabio_file = fabio_image
else:
- raise TypeError("FabioImage expected but %s found.", fabio_image.__class__)
+ raise TypeError(
+ "FabioImage expected but %s found.", fabio_image.__class__
+ )
elif file_series is not None:
if isinstance(file_series, list):
- self.__fabio_file = _FileSeries(file_series)
+ self.__fabio_file = fabio.file_series.file_series(file_series)
elif isinstance(file_series, fabio.file_series.file_series):
self.__fabio_file = file_series
else:
- raise TypeError("file_series or list expected but %s found.", file_series.__class__)
+ raise TypeError(
+ "file_series or list expected but %s found.", file_series.__class__
+ )
def close(self):
"""Close the object, and free up associated resources.
@@ -401,10 +388,7 @@ class FabioReader(object):
may fail.
"""
if self.__must_be_closed:
- # Make sure the API of fabio provide it a 'close' method
- # TODO the test can be removed if fabio version >= 0.8
- if hasattr(self.__fabio_file, "close"):
- self.__fabio_file.close()
+ self.__fabio_file.close()
self.__fabio_file = None
def fabio_file(self):
@@ -428,7 +412,7 @@ class FabioReader(object):
for file_number in range(len(self.__fabio_file)):
with self.__fabio_file.jump_image(file_number) as fabio_image:
# return the first frame only
- assert(fabio_image.nframes == 1)
+ assert fabio_image.nframes == 1
yield fabio_image
elif isinstance(self.__fabio_file, fabio.fabioimage.FabioImage):
for frame_count in range(self.__fabio_file.nframes):
@@ -514,7 +498,9 @@ class FabioReader(object):
if not isinstance(value, numpy.ndarray):
if kind in [self.COUNTER, self.POSITIONER]:
# Force normalization for counters and positioners
- old = self._set_vector_normalization(at_least_32bits=True, signed_type=True)
+ old = self._set_vector_normalization(
+ at_least_32bits=True, signed_type=True
+ )
else:
old = None
value = self._convert_metadata_vector(value)
@@ -590,7 +576,7 @@ class FabioReader(object):
return previous
def _normalize_vector_type(self, dtype):
- """Normalize the """
+ """Normalize the"""
if self.__at_least_32bits:
if numpy.issubdtype(dtype, numpy.signedinteger):
dtype = numpy.result_type(dtype, numpy.uint32)
@@ -602,7 +588,7 @@ class FabioReader(object):
dtype = numpy.result_type(dtype, numpy.complex64)
if self.__signed_type:
if numpy.issubdtype(dtype, numpy.unsignedinteger):
- signed = numpy.dtype("%s%i" % ('i', dtype.itemsize))
+ signed = numpy.dtype("%s%i" % ("i", dtype.itemsize))
dtype = numpy.result_type(dtype, signed)
return dtype
@@ -652,7 +638,7 @@ class FabioReader(object):
if result_type.kind == "S":
none_value = b""
elif result_type.kind == "U":
- none_value = u""
+ none_value = ""
elif result_type.kind == "f":
none_value = numpy.float64("NaN")
elif result_type.kind == "i":
@@ -692,7 +678,7 @@ class FabioReader(object):
# convert to a numpy associative array
key_dtype = numpy.min_scalar_type(list(value.keys()))
value_dtype = numpy.min_scalar_type(list(value.values()))
- associative_type = [('key', key_dtype), ('value', value_dtype)]
+ associative_type = [("key", key_dtype), ("value", value_dtype)]
assert key_dtype.kind != "O" and value_dtype.kind != "O"
return numpy.array(list(value.items()), dtype=associative_type)
if isinstance(value, numbers.Number):
@@ -702,7 +688,7 @@ class FabioReader(object):
if isinstance(value, bytes):
try:
- value = value.decode('utf-8')
+ value = value.decode("utf-8")
except UnicodeDecodeError:
return numpy.void(value)
@@ -818,7 +804,7 @@ class EdfFabioReader(FabioReader):
pos_values = header.get(pos_values_key, "")
pos_values = pos_values.split()
- result = collections.OrderedDict()
+ result = {}
nbitems = max(len(mnemonic_values), len(pos_values))
for i in range(nbitems):
if i < len(mnemonic_values):
@@ -874,7 +860,9 @@ class EdfFabioReader(FabioReader):
if len(ub_data) > 9:
_logger.warning("UB_mne and UB_pos contains more than expected keys.")
if len(s_data) > 6:
- _logger.warning("sample_mne and sample_pos contains more than expected keys.")
+ _logger.warning(
+ "sample_mne and sample_pos contains more than expected keys."
+ )
data = numpy.array([s_data["U0"], s_data["U1"], s_data["U2"]], dtype=float)
unit_cell_abc = data
@@ -882,10 +870,16 @@ class EdfFabioReader(FabioReader):
data = numpy.array([s_data["U3"], s_data["U4"], s_data["U5"]], dtype=float)
unit_cell_alphabetagamma = data
- ub_matrix = numpy.array([[
- [ub_data["UB0"], ub_data["UB1"], ub_data["UB2"]],
- [ub_data["UB3"], ub_data["UB4"], ub_data["UB5"]],
- [ub_data["UB6"], ub_data["UB7"], ub_data["UB8"]]]], dtype=float)
+ ub_matrix = numpy.array(
+ [
+ [
+ [ub_data["UB0"], ub_data["UB1"], ub_data["UB2"]],
+ [ub_data["UB3"], ub_data["UB4"], ub_data["UB5"]],
+ [ub_data["UB6"], ub_data["UB7"], ub_data["UB8"]],
+ ]
+ ],
+ dtype=float,
+ )
self.__unit_cell_abc = unit_cell_abc
self.__unit_cell_alphabetagamma = unit_cell_alphabetagamma
@@ -939,8 +933,7 @@ class EdfFabioReader(FabioReader):
class File(commonh5.File):
- """Class which handle a fabio image as a mimick of a h5py.File.
- """
+ """Class which handle a fabio image as a mimick of a h5py.File."""
def __init__(self, file_name=None, fabio_image=None, file_series=None):
"""
@@ -953,15 +946,19 @@ class File(commonh5.File):
list of file name or a :class:`fabio.file_series.file_series`
instance
"""
- self.__fabio_reader = self.create_fabio_reader(file_name, fabio_image, file_series)
+ self.__fabio_reader = self.create_fabio_reader(
+ file_name, fabio_image, file_series
+ )
if fabio_image is not None:
file_name = fabio_image.filename
scan = self.create_scan_group(self.__fabio_reader)
- attrs = {"NX_class": "NXroot",
- "file_time": datetime.datetime.now().isoformat(),
- "creator": "silx %s" % silx_version,
- "default": scan.basename}
+ attrs = {
+ "NX_class": "NXroot",
+ "file_time": datetime.datetime.now().isoformat(),
+ "creator": "silx %s" % silx_version,
+ "default": scan.basename,
+ }
if file_name is not None:
attrs["file_name"] = file_name
commonh5.File.__init__(self, name=file_name, attrs=attrs)
@@ -981,9 +978,16 @@ class File(commonh5.File):
}
scan = commonh5.Group("scan_0", attrs=scan_attrs)
instrument = commonh5.Group("instrument", attrs={"NX_class": "NXinstrument"})
- measurement = MeasurementGroup("measurement", fabio_reader, attrs={"NX_class": "NXcollection"})
+ measurement = MeasurementGroup(
+ "measurement", fabio_reader, attrs={"NX_class": "NXcollection"}
+ )
file_ = commonh5.Group("file", attrs={"NX_class": "NXcollection"})
- positioners = MetadataGroup("positioners", fabio_reader, FabioReader.POSITIONER, attrs={"NX_class": "NXpositioner"})
+ positioners = MetadataGroup(
+ "positioners",
+ fabio_reader,
+ FabioReader.POSITIONER,
+ attrs={"NX_class": "NXpositioner"},
+ )
raw_header = RawHeaderData("scan_header", fabio_reader, self)
detector = DetectorGroup("detector_0", fabio_reader)
@@ -1031,7 +1035,7 @@ class File(commonh5.File):
elif first_image is not None:
use_edf_reader = isinstance(first_image, fabio.edfimage.EdfImage)
else:
- assert(False)
+ assert False
if use_edf_reader:
reader = EdfFabioReader(file_name, fabio_image, file_series)
diff --git a/src/silx/io/fioh5.py b/src/silx/io/fioh5.py
index 0a86bbf..a88d35b 100644
--- a/src/silx/io/fioh5.py
+++ b/src/silx/io/fioh5.py
@@ -154,15 +154,17 @@ logger1 = logging.getLogger(__name__)
if h5py.version.version_tuple[0] < 3:
text_dtype = h5py.special_dtype(vlen=str) # old API
else:
- text_dtype = 'O' # variable-length string (supported as of h5py > 3.0)
+ text_dtype = "O" # variable-length string (supported as of h5py > 3.0)
ABORTLINENO = 5
-dtypeConverter = {'STRING': text_dtype,
- 'DOUBLE': 'f8',
- 'FLOAT': 'f4',
- 'INTEGER': 'i8',
- 'BOOLEAN': '?'}
+dtypeConverter = {
+ "STRING": text_dtype,
+ "DOUBLE": "f8",
+ "FLOAT": "f4",
+ "INTEGER": "i8",
+ "BOOLEAN": "?",
+}
def is_fiofile(filename):
@@ -192,56 +194,51 @@ def is_fiofile(filename):
class FioFile(object):
- """This class opens a FIO file and reads the data.
-
- """
+ """This class opens a FIO file and reads the data."""
def __init__(self, filepath):
# parse filename
filename = os.path.basename(filepath)
- fnowithsuffix = filename.split('_')[-1]
+ fnowithsuffix = filename.split("_")[-1]
try:
- self.scanno = int(fnowithsuffix.split('.')[0])
+ self.scanno = int(fnowithsuffix.split(".")[0])
except Exception:
self.scanno = None
logger1.warning("Cannot parse scan number of file %s", filename)
- with open(filepath, 'r') as fiof:
-
+ with open(filepath, "r") as fiof:
prev = 0
line_counter = 0
- while(True):
+ while True:
line = fiof.readline()
- if line.startswith('!'): # skip comments
+ if line.startswith("!"): # skip comments
prev = fiof.tell()
line_counter = 0
continue
- if line.startswith('%c'): # comment section
+ if line.startswith("%c"): # comment section
line_counter = 0
- self.commentsection = ''
+ self.commentsection = ""
line = fiof.readline()
- while(not line.startswith('%')
- and not line.startswith('!')):
+ while not line.startswith("%") and not line.startswith("!"):
self.commentsection += line
prev = fiof.tell()
line = fiof.readline()
- if line.startswith('%p'): # parameter section
+ if line.startswith("%p"): # parameter section
line_counter = 0
- self.parameterssection = ''
+ self.parameterssection = ""
line = fiof.readline()
- while(not line.startswith('%')
- and not line.startswith('!')):
+ while not line.startswith("%") and not line.startswith("!"):
self.parameterssection += line
prev = fiof.tell()
line = fiof.readline()
- if line.startswith('%d'): # data type definitions
+ if line.startswith("%d"): # data type definitions
line_counter = 0
self.datacols = []
self.names = []
self.dtypes = []
line = fiof.readline()
- while(line.startswith(' Col')):
+ while line.startswith(" Col"):
splitline = line.split()
name = splitline[-2]
self.names.append(name)
@@ -255,13 +252,16 @@ class FioFile(object):
line_counter += 1
if line_counter > ABORTLINENO:
- raise IOError("Invalid fio file: Found no data "
- "after %s lines" % ABORTLINENO)
+ raise IOError(
+ "Invalid fio file: Found no data "
+ "after %s lines" % ABORTLINENO
+ )
- self.data = numpy.loadtxt(fiof,
- dtype={'names': tuple(self.names),
- 'formats': tuple(self.dtypes)},
- comments="!")
+ self.data = numpy.loadtxt(
+ fiof,
+ dtype={"names": tuple(self.names), "formats": tuple(self.dtypes)},
+ comments="!",
+ )
# ToDo: read only last line of file,
# which sometimes contains the end of acquisition timestamp.
@@ -271,7 +271,7 @@ class FioFile(object):
# parse parameter section:
try:
for line in self.parameterssection.splitlines():
- param, value = line.split(' = ')
+ param, value = line.split(" = ")
self.parameter[param] = value
except Exception:
logger1.warning("Cannot parse parameter section")
@@ -288,7 +288,7 @@ class FioFile(object):
raise Exception("acquisition str not found")
self.user = l2[:acqpos][4:].strip()
- self.start_time = l2[acqpos+len(acquiMarker):].strip()
+ self.start_time = l2[acqpos + len(acquiMarker) :].strip()
commentlines = commentlines[2:]
self.comments = "\n".join(commentlines[2:])
@@ -324,15 +324,13 @@ class FioH5NodeDataset(commonh5.Dataset):
data_kind = array.dtype.kind
if data_kind in ["S", "U"]:
- value = numpy.asarray(array,
- dtype=text_dtype)
+ value = numpy.asarray(array, dtype=text_dtype)
else:
value = array # numerical data is already the correct datatype
commonh5.Dataset.__init__(self, name, value, parent, attrs)
def __getattr__(self, item):
- """Proxy to underlying numpy array methods.
- """
+ """Proxy to underlying numpy array methods."""
if hasattr(self[()], item):
return getattr(self[()], item)
@@ -363,11 +361,12 @@ class FioH5(commonh5.File):
except Exception as e:
raise IOError("FIO file %s cannot be read.") from e
- attrs = {"NX_class": to_h5py_utf8("NXroot"),
- "file_time": to_h5py_utf8(
- datetime.datetime.now().isoformat()),
- "file_name": to_h5py_utf8(filename),
- "creator": to_h5py_utf8("silx fioh5 %s" % silx_version)}
+ attrs = {
+ "NX_class": to_h5py_utf8("NXroot"),
+ "file_time": to_h5py_utf8(datetime.datetime.now().isoformat()),
+ "file_name": to_h5py_utf8(filename),
+ "creator": to_h5py_utf8("silx fioh5 %s" % silx_version),
+ }
commonh5.File.__init__(self, filename, attrs=attrs)
if fiof.scanno is not None:
@@ -387,33 +386,40 @@ class FioScanGroup(commonh5.Group):
:param str scan_key: Scan key (e.g. "1.1")
:param scan: FioFile object
"""
- if hasattr(scan, 'user'):
+ if hasattr(scan, "user"):
userattr = to_h5py_utf8(scan.user)
else:
- userattr = to_h5py_utf8('')
- commonh5.Group.__init__(self, scan_key, parent=parent,
- attrs={"NX_class": to_h5py_utf8("NXentry"),
- "user": userattr})
+ userattr = to_h5py_utf8("")
+ commonh5.Group.__init__(
+ self,
+ scan_key,
+ parent=parent,
+ attrs={"NX_class": to_h5py_utf8("NXentry"), "user": userattr},
+ )
# 'title', 'start_time' and 'user' are defaults
# in Sardana created files:
- if hasattr(scan, 'title'):
+ if hasattr(scan, "title"):
title = scan.title
else:
title = scan_key # use scan number as default title
- self.add_node(FioH5NodeDataset(name="title",
- data=to_h5py_utf8(title),
- parent=self))
+ self.add_node(
+ FioH5NodeDataset(name="title", data=to_h5py_utf8(title), parent=self)
+ )
- if hasattr(scan, 'start_time'):
+ if hasattr(scan, "start_time"):
start_time = scan.start_time
- self.add_node(FioH5NodeDataset(name="start_time",
- data=to_h5py_utf8(start_time),
- parent=self))
-
- self.add_node(FioH5NodeDataset(name="comments",
- data=to_h5py_utf8(scan.comments),
- parent=self))
+ self.add_node(
+ FioH5NodeDataset(
+ name="start_time", data=to_h5py_utf8(start_time), parent=self
+ )
+ )
+
+ self.add_node(
+ FioH5NodeDataset(
+ name="comments", data=to_h5py_utf8(scan.comments), parent=self
+ )
+ )
self.add_node(FioInstrumentGroup(parent=self, scan=scan))
self.add_node(FioMeasurementGroup(parent=self, scan=scan))
@@ -426,14 +432,18 @@ class FioMeasurementGroup(commonh5.Group):
:param parent: parent Group
:param scan: FioFile object
"""
- commonh5.Group.__init__(self, name="measurement", parent=parent,
- attrs={"NX_class": to_h5py_utf8("NXcollection")})
+ commonh5.Group.__init__(
+ self,
+ name="measurement",
+ parent=parent,
+ attrs={"NX_class": to_h5py_utf8("NXcollection")},
+ )
for label in scan.names:
safe_label = label.replace("/", "%")
- self.add_node(FioH5NodeDataset(name=safe_label,
- data=scan.data[label],
- parent=self))
+ self.add_node(
+ FioH5NodeDataset(name=safe_label, data=scan.data[label], parent=self)
+ )
class FioInstrumentGroup(commonh5.Group):
@@ -443,14 +453,20 @@ class FioInstrumentGroup(commonh5.Group):
:param parent: parent Group
:param scan: FioFile object
"""
- commonh5.Group.__init__(self, name="instrument", parent=parent,
- attrs={"NX_class": to_h5py_utf8("NXinstrument")})
+ commonh5.Group.__init__(
+ self,
+ name="instrument",
+ parent=parent,
+ attrs={"NX_class": to_h5py_utf8("NXinstrument")},
+ )
self.add_node(FioParameterGroup(parent=self, scan=scan))
self.add_node(FioFileGroup(parent=self, scan=scan))
- self.add_node(FioH5NodeDataset(name="comment",
- data=to_h5py_utf8(scan.comments),
- parent=self))
+ self.add_node(
+ FioH5NodeDataset(
+ name="comment", data=to_h5py_utf8(scan.comments), parent=self
+ )
+ )
class FioFileGroup(commonh5.Group):
@@ -460,16 +476,24 @@ class FioFileGroup(commonh5.Group):
:param parent: parent Group
:param scan: FioFile object
"""
- commonh5.Group.__init__(self, name="fiofile", parent=parent,
- attrs={"NX_class": to_h5py_utf8("NXcollection")})
-
- self.add_node(FioH5NodeDataset(name="comments",
- data=to_h5py_utf8(scan.commentsection),
- parent=self))
-
- self.add_node(FioH5NodeDataset(name="parameter",
- data=to_h5py_utf8(scan.parameterssection),
- parent=self))
+ commonh5.Group.__init__(
+ self,
+ name="fiofile",
+ parent=parent,
+ attrs={"NX_class": to_h5py_utf8("NXcollection")},
+ )
+
+ self.add_node(
+ FioH5NodeDataset(
+ name="comments", data=to_h5py_utf8(scan.commentsection), parent=self
+ )
+ )
+
+ self.add_node(
+ FioH5NodeDataset(
+ name="parameter", data=to_h5py_utf8(scan.parameterssection), parent=self
+ )
+ )
class FioParameterGroup(commonh5.Group):
@@ -479,11 +503,19 @@ class FioParameterGroup(commonh5.Group):
:param parent: parent Group
:param scan: FioFile object
"""
- commonh5.Group.__init__(self, name="parameter", parent=parent,
- attrs={"NX_class": to_h5py_utf8("NXcollection")})
+ commonh5.Group.__init__(
+ self,
+ name="parameter",
+ parent=parent,
+ attrs={"NX_class": to_h5py_utf8("NXcollection")},
+ )
for label in scan.parameter:
safe_label = label.replace("/", "%")
- self.add_node(FioH5NodeDataset(name=safe_label,
- data=to_h5py_utf8(scan.parameter[label]),
- parent=self))
+ self.add_node(
+ FioH5NodeDataset(
+ name=safe_label,
+ data=to_h5py_utf8(scan.parameter[label]),
+ parent=self,
+ )
+ )
diff --git a/src/silx/io/h5link_utils.py b/src/silx/io/h5link_utils.py
new file mode 100644
index 0000000..39f9ae4
--- /dev/null
+++ b/src/silx/io/h5link_utils.py
@@ -0,0 +1,77 @@
+import os
+from typing import NamedTuple, Optional
+from .utils import is_dataset
+
+
+class ExternalDatasetInfo(NamedTuple):
+ type: str
+ nfiles: int
+ first_file_path: str
+ first_data_path: Optional[str] = None
+
+ @property
+ def first_source_url(self):
+ if self.first_data_path:
+ if self.first_data_path.startswith("/"):
+ return self.first_file_path + "::" + self.first_data_path
+ else:
+ return self.first_file_path + "::/" + self.first_data_path
+ return self.first_file_path
+
+
+def external_dataset_info(hdf5obj) -> Optional[ExternalDatasetInfo]:
+ """When the object is a virtual dataset or an external dataset,
+ return information on the external files. Return `None` otherwise.
+
+ Note that this has nothing to do with external HDF5 links."""
+ if not is_dataset(hdf5obj):
+ return
+ if hasattr(hdf5obj, "is_virtual") and hdf5obj.is_virtual:
+ sources = hdf5obj.virtual_sources()
+ if not sources:
+ return ExternalDatasetInfo(
+ type="Virtual",
+ nfiles=0,
+ first_file_path="",
+ )
+
+ first_source = sources[0]
+ first_file_path = first_source.file_name
+ if first_file_path == ".":
+ first_file_path = hdf5obj.file.filename
+ elif not os.path.isabs(first_file_path):
+ dirname = os.path.dirname(hdf5obj.file.filename)
+ first_file_path = os.path.normpath(
+ os.path.join(
+ dirname,
+ first_file_path,
+ )
+ )
+
+ return ExternalDatasetInfo(
+ type="Virtual",
+ nfiles=len(sources),
+ first_file_path=first_file_path,
+ first_data_path=first_source.dset_name,
+ )
+ if hasattr(hdf5obj, "external"):
+ sources = hdf5obj.external
+ if not sources:
+ return
+
+ first_source = sources[0]
+ first_file_path = first_source[0]
+ if not os.path.isabs(first_file_path):
+ dirname = os.path.dirname(hdf5obj.file.filename)
+ first_file_path = os.path.normpath(
+ os.path.join(
+ dirname,
+ first_file_path,
+ )
+ )
+
+ return ExternalDatasetInfo(
+ type="Raw",
+ nfiles=len(sources),
+ first_file_path=first_file_path,
+ )
diff --git a/src/silx/io/h5py_utils.py b/src/silx/io/h5py_utils.py
index 139bff3..478f72c 100644
--- a/src/silx/io/h5py_utils.py
+++ b/src/silx/io/h5py_utils.py
@@ -1,5 +1,5 @@
# /*##########################################################################
-# Copyright (C) 2016-2021 European Synchrotron Radiation Facility
+# Copyright (C) 2016-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -27,7 +27,7 @@ parallel writing and reading.
__authors__ = ["W. de Nolf"]
__license__ = "MIT"
-__date__ = "27/01/2020"
+__date__ = "28/11/2023"
import os
@@ -47,8 +47,11 @@ IS_WINDOWS = sys.platform == "win32"
H5PY_HEX_VERSION = calc_hexversion(*h5py.version.version_tuple[:3])
HDF5_HEX_VERSION = calc_hexversion(*h5py.version.hdf5_version_tuple[:3])
-HDF5_SWMR_VERSION = calc_hexversion(*h5py.get_config().swmr_min_hdf5_version[:3])
-HAS_SWMR = HDF5_HEX_VERSION >= HDF5_SWMR_VERSION
+if h5py.version.version_tuple >= (3, 10):
+ HDF5_SWMR_VERSION = 1, 9, 178
+else:
+ HDF5_SWMR_VERSION = h5py.get_config().swmr_min_hdf5_version[:3]
+HAS_SWMR = HDF5_HEX_VERSION >= calc_hexversion(*HDF5_SWMR_VERSION)
HAS_TRACK_ORDER = H5PY_HEX_VERSION >= calc_hexversion(2, 9, 0)
@@ -117,8 +120,9 @@ def _is_h5py_exception(e):
:returns bool:
"""
for frame in traceback.walk_tb(e.__traceback__):
- if frame[0].f_locals.get("__package__", None) == "h5py":
- return True
+ for namespace in (frame[0].f_locals, frame[0].f_globals):
+ if namespace.get("__package__", None) == "h5py":
+ return True
return False
@@ -242,7 +246,11 @@ def _top_level_names(filename, include_only=group_has_end_time, **open_options):
top_level_names = retry()(_top_level_names)
-safe_top_level_names = retry_in_subprocess()(_top_level_names)
+if hasattr(sys, "frozen") and sys.frozen:
+ # multiprocessing not working on frozen binaries
+ safe_top_level_names = top_level_names
+else:
+ safe_top_level_names = retry_in_subprocess()(_top_level_names)
class Hdf5FileLockingManager:
diff --git a/src/silx/io/nxdata/__init__.py b/src/silx/io/nxdata/__init__.py
index 51efc68..23ac745 100644
--- a/src/silx/io/nxdata/__init__.py
+++ b/src/silx/io/nxdata/__init__.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2018 European Synchrotron Radiation Facility
+# Copyright (c) 2018-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -59,7 +59,14 @@ Functions
.. autofunction:: save_NXdata
"""
-from .parse import NXdata, get_default, is_valid_nxdata, InvalidNXdataError, \
- is_NXentry_with_default_NXdata, is_NXroot_with_default_NXdata, is_group_with_default_NXdata
-from ._utils import get_attr_as_unicode, get_attr_as_string, nxdata_logger
+from .parse import (
+ NXdata,
+ get_default,
+ is_valid_nxdata,
+ InvalidNXdataError,
+ is_NXentry_with_default_NXdata,
+ is_NXroot_with_default_NXdata,
+ is_group_with_default_NXdata,
+)
+from ._utils import get_attr_as_unicode, nxdata_logger
from .write import save_NXdata
diff --git a/src/silx/io/nxdata/_utils.py b/src/silx/io/nxdata/_utils.py
index 3aa3846..61bdf11 100644
--- a/src/silx/io/nxdata/_utils.py
+++ b/src/silx/io/nxdata/_utils.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -29,7 +29,6 @@ import logging
import numpy
from silx.io import is_dataset
-from silx.utils.deprecation import deprecated
__authors__ = ["P. Knobel"]
@@ -40,21 +39,18 @@ __date__ = "17/04/2018"
nxdata_logger = logging.getLogger("silx.io.nxdata")
-INTERPDIM = {"scalar": 0,
- "spectrum": 1,
- "image": 2,
- "rgba-image": 3, # "hsla-image": 3, "cmyk-image": 3, # TODO
- "vertex": 1} # 3D scatter: 1D signal + 3 axes (x, y, z) of same legth
+INTERPDIM = {
+ "scalar": 0,
+ "spectrum": 1,
+ "image": 2,
+ "rgba-image": 3, # "hsla-image": 3, "cmyk-image": 3, # TODO
+ "vertex": 1,
+} # 3D scatter: 1D signal + 3 axes (x, y, z) of same legth
"""Number of signal dimensions associated to each possible @interpretation
attribute.
"""
-@deprecated(since_version="0.8.0", replacement="get_attr_as_unicode")
-def get_attr_as_string(*args, **kwargs):
- return get_attr_as_unicode(*args, **kwargs)
-
-
def get_attr_as_unicode(item, attr_name, default=None):
"""Return item.attrs[attr_name] as unicode or as a
list of unicode.
@@ -107,13 +103,16 @@ def get_signal_name(group):
"""
signal_name = get_attr_as_unicode(group, "signal", default=None)
if signal_name is None:
- nxdata_logger.info("NXdata group %s does not define a signal attr. "
- "Testing legacy specification.", group.name)
+ nxdata_logger.info(
+ "NXdata group %s does not define a signal attr. "
+ "Testing legacy specification.",
+ group.name,
+ )
for key in group:
if "signal" in group[key].attrs:
signal_name = key
signal_attr = group[key].attrs["signal"]
- if signal_attr in [1, b"1", u"1"]:
+ if signal_attr in [1, b"1", "1"]:
# This is the main (default) signal
break
return signal_name
@@ -121,8 +120,9 @@ def get_signal_name(group):
def get_auxiliary_signals_names(group):
"""Return list of auxiliary signals names"""
- auxiliary_signals_names = get_attr_as_unicode(group, "auxiliary_signals",
- default=[])
+ auxiliary_signals_names = get_attr_as_unicode(
+ group, "auxiliary_signals", default=[]
+ )
if isinstance(auxiliary_signals_names, (str, bytes)):
auxiliary_signals_names = [auxiliary_signals_names]
return auxiliary_signals_names
@@ -133,11 +133,12 @@ def validate_auxiliary_signals(group, signal_name, auxiliary_signals_names):
issues = []
for asn in auxiliary_signals_names:
if asn not in group or not is_dataset(group[asn]):
- issues.append(
- "Cannot find auxiliary signal dataset '%s'" % asn)
+ issues.append("Cannot find auxiliary signal dataset '%s'" % asn)
elif group[signal_name].shape != group[asn].shape:
- issues.append("Auxiliary signal dataset '%s' does not" % asn +
- " have the same shape as the main signal.")
+ issues.append(
+ "Auxiliary signal dataset '%s' does not" % asn
+ + " have the same shape as the main signal."
+ )
return issues
@@ -147,9 +148,10 @@ def validate_number_of_axes(group, signal_name, num_axes):
if 1 < ndims < num_axes:
# ndim = 1 with several axes could be a scatter
issues.append(
- "More @axes defined than there are " +
- "signal dimensions: " +
- "%d axes, %d dimensions." % (num_axes, ndims))
+ "More @axes defined than there are "
+ + "signal dimensions: "
+ + "%d axes, %d dimensions." % (num_axes, ndims)
+ )
# case of less axes than dimensions: number of axes must match
# dimensionality defined by @interpretation
@@ -158,25 +160,30 @@ def validate_number_of_axes(group, signal_name, num_axes):
if interpretation is None:
interpretation = get_attr_as_unicode(group, "interpretation")
if interpretation is None:
- issues.append("No @interpretation and not enough" +
- " @axes defined.")
+ issues.append("No @interpretation and not enough" + " @axes defined.")
elif interpretation not in INTERPDIM:
- issues.append("Unrecognized @interpretation=" + interpretation +
- " for data with wrong number of defined @axes.")
+ issues.append(
+ "Unrecognized @interpretation="
+ + interpretation
+ + " for data with wrong number of defined @axes."
+ )
elif interpretation == "rgba-image":
if ndims != 3 or group[signal_name].shape[-1] not in [3, 4]:
issues.append(
- "Inconsistent RGBA Image. Expected 3 dimensions with " +
- "last one of length 3 or 4. Got ndim=%d " % ndims +
- "with last dimension of length %d." % group[signal_name].shape[-1])
+ "Inconsistent RGBA Image. Expected 3 dimensions with "
+ + "last one of length 3 or 4. Got ndim=%d " % ndims
+ + "with last dimension of length %d." % group[signal_name].shape[-1]
+ )
if num_axes != 2:
issues.append(
"Inconsistent number of axes for RGBA Image. Expected "
- "3, but got %d." % ndims)
+ "3, but got %d." % ndims
+ )
elif num_axes != INTERPDIM[interpretation]:
issues.append(
- "%d-D signal with @interpretation=%s " % (ndims, interpretation) +
- "must define %d or %d axes." % (ndims, INTERPDIM[interpretation]))
+ "%d-D signal with @interpretation=%s " % (ndims, interpretation)
+ + "must define %d or %d axes." % (ndims, INTERPDIM[interpretation])
+ )
return issues
diff --git a/src/silx/io/nxdata/parse.py b/src/silx/io/nxdata/parse.py
index 0c9d7e7..61e311e 100644
--- a/src/silx/io/nxdata/parse.py
+++ b/src/silx/io/nxdata/parse.py
@@ -47,9 +47,16 @@ import numpy
from silx.io.utils import is_group, is_file, is_dataset, h5py_read_dataset
-from ._utils import get_attr_as_unicode, INTERPDIM, nxdata_logger, \
- get_uncertainties_names, get_signal_name, \
- get_auxiliary_signals_names, validate_auxiliary_signals, validate_number_of_axes
+from ._utils import (
+ get_attr_as_unicode,
+ INTERPDIM,
+ nxdata_logger,
+ get_uncertainties_names,
+ get_signal_name,
+ get_auxiliary_signals_names,
+ validate_auxiliary_signals,
+ validate_number_of_axes,
+)
__authors__ = ["P. Knobel"]
@@ -80,55 +87,61 @@ class _SilxStyle(object):
try:
style = json.loads(stylestr)
except json.JSONDecodeError:
- nxdata_logger.error(
- "Ignoring SILX_style, cannot parse: %s", stylestr)
+ nxdata_logger.error("Ignoring SILX_style, cannot parse: %s", stylestr)
return
if not isinstance(style, dict):
- nxdata_logger.error(
- "Ignoring SILX_style, cannot parse: %s", stylestr)
+ nxdata_logger.error("Ignoring SILX_style, cannot parse: %s", stylestr)
- if 'axes_scale_types' in style:
- axes_scale_types = style['axes_scale_types']
+ if "axes_scale_types" in style:
+ axes_scale_types = style["axes_scale_types"]
if isinstance(axes_scale_types, str):
# Convert single argument to list
axes_scale_types = [axes_scale_types]
if not isinstance(axes_scale_types, list):
- nxdata_logger.error(
- "Ignoring SILX_style:axes_scale_types, not a list")
+ nxdata_logger.error("Ignoring SILX_style:axes_scale_types, not a list")
else:
for scale_type in axes_scale_types:
- if scale_type not in ('linear', 'log'):
+ if scale_type not in ("linear", "log"):
nxdata_logger.error(
- "Ignoring SILX_style:axes_scale_types, invalid value: %s", str(scale_type))
+ "Ignoring SILX_style:axes_scale_types, invalid value: %s",
+ str(scale_type),
+ )
break
else: # All values are valid
if len(axes_scale_types) > naxes:
nxdata_logger.error(
- "Clipping SILX_style:axes_scale_types, too many values")
+ "Clipping SILX_style:axes_scale_types, too many values"
+ )
axes_scale_types = axes_scale_types[:naxes]
elif len(axes_scale_types) < naxes:
# Extend axes_scale_types with None to match number of axes
- axes_scale_types = [None] * (naxes - len(axes_scale_types)) + axes_scale_types
+ axes_scale_types = [None] * (
+ naxes - len(axes_scale_types)
+ ) + axes_scale_types
self._axes_scale_types = tuple(axes_scale_types)
- if 'signal_scale_type' in style:
- scale_type = style['signal_scale_type']
- if scale_type not in ('linear', 'log'):
+ if "signal_scale_type" in style:
+ scale_type = style["signal_scale_type"]
+ if scale_type not in ("linear", "log"):
nxdata_logger.error(
- "Ignoring SILX_style:signal_scale_type, invalid value: %s", str(scale_type))
+ "Ignoring SILX_style:signal_scale_type, invalid value: %s",
+ str(scale_type),
+ )
else:
self._signal_scale_type = scale_type
axes_scale_types = property(
lambda self: self._axes_scale_types,
- doc="Tuple of NXdata axes scale types (None, 'linear' or 'log'). List[str]")
+ doc="Tuple of NXdata axes scale types (None, 'linear' or 'log'). List[str]",
+ )
signal_scale_type = property(
lambda self: self._signal_scale_type,
- doc="NXdata signal scale type (None, 'linear' or 'log'). str")
+ doc="NXdata signal scale type (None, 'linear' or 'log'). str",
+ )
class NXdata(object):
@@ -145,6 +158,7 @@ class NXdata(object):
where :meth:`silx.io.nxdata.is_valid_nxdata` has already been called
prior to instantiating this :class:`NXdata`.
"""
+
def __init__(self, group, validate=True):
super(NXdata, self).__init__()
self._plot_style = None
@@ -200,8 +214,7 @@ class NXdata(object):
self.signal_name = self.signal_dataset_name
# ndim will be available in very recent h5py versions only
- self.signal_ndim = getattr(self.signal, "ndim",
- len(self.signal.shape))
+ self.signal_ndim = getattr(self.signal, "ndim", len(self.signal.shape))
self.signal_is_0d = self.signal_ndim == 0
self.signal_is_1d = self.signal_ndim == 1
@@ -212,12 +225,16 @@ class NXdata(object):
# check if axis dataset defines @long_name
for _, dsname in enumerate(self.axes_dataset_names):
if dsname is not None and "long_name" in self.group[dsname].attrs:
- self.axes_names.append(get_attr_as_unicode(self.group[dsname], "long_name"))
+ self.axes_names.append(
+ get_attr_as_unicode(self.group[dsname], "long_name")
+ )
else:
self.axes_names.append(dsname)
# excludes scatters
- self.signal_is_1d = self.signal_is_1d and len(self.axes) <= 1 # excludes n-D scatters
+ self.signal_is_1d = (
+ self.signal_is_1d and len(self.axes) <= 1
+ ) # excludes n-D scatters
self._plot_style = _SilxStyle(self)
@@ -231,8 +248,10 @@ class NXdata(object):
signal_name = get_signal_name(self.group)
if signal_name is None:
- self.issues.append("No @signal attribute on the NXdata group, "
- "and no dataset with a @signal=1 attr found")
+ self.issues.append(
+ "No @signal attribute on the NXdata group, "
+ "and no dataset with a @signal=1 attr found"
+ )
# very difficult to do more consistency tests without signal
return
@@ -241,9 +260,9 @@ class NXdata(object):
return
auxiliary_signals_names = get_auxiliary_signals_names(self.group)
- self.issues += validate_auxiliary_signals(self.group,
- signal_name,
- auxiliary_signals_names)
+ self.issues += validate_auxiliary_signals(
+ self.group, signal_name, auxiliary_signals_names
+ )
axes_names = get_attr_as_unicode(self.group, "axes")
if axes_names is None:
@@ -258,8 +277,9 @@ class NXdata(object):
axes_names = [axes_names]
if axes_names:
- self.issues += validate_number_of_axes(self.group, signal_name,
- num_axes=len(axes_names))
+ self.issues += validate_number_of_axes(
+ self.group, signal_name, num_axes=len(axes_names)
+ )
# Test consistency of @uncertainties
uncertainties_names = get_uncertainties_names(self.group, signal_name)
@@ -268,11 +288,15 @@ class NXdata(object):
if len(uncertainties_names) < len(axes_names):
# ignore the field to avoid index error in the axes loop
uncertainties_names = None
- self.issues.append("@uncertainties does not define the same " +
- "number of fields than @axes. Field ignored")
+ self.issues.append(
+ "@uncertainties does not define the same "
+ + "number of fields than @axes. Field ignored"
+ )
else:
- self.issues.append("@uncertainties does not define the same " +
- "number of fields than @axes")
+ self.issues.append(
+ "@uncertainties does not define the same "
+ + "number of fields than @axes"
+ )
# Test individual axes
is_scatter = True # true if all axes have the same size as the signal
@@ -281,7 +305,6 @@ class NXdata(object):
signal_size *= dim
polynomial_axes_names = []
for i, axis_name in enumerate(axes_names):
-
if axis_name == ".":
continue
if axis_name not in self.group or not is_dataset(self.group[axis_name]):
@@ -299,37 +322,37 @@ class NXdata(object):
else:
# for a 1-d axis,
fg_idx = self.group[axis_name].attrs.get("first_good", 0)
- lg_idx = self.group[axis_name].attrs.get("last_good", len(self.group[axis_name]) - 1)
+ lg_idx = self.group[axis_name].attrs.get(
+ "last_good", len(self.group[axis_name]) - 1
+ )
axis_len = lg_idx + 1 - fg_idx
if axis_len != signal_size:
if axis_len not in self.group[signal_name].shape + (1, 2):
self.issues.append(
- "Axis %s number of elements does not " % axis_name +
- "correspond to the length of any signal dimension,"
- " it does not appear to be a constant or a linear calibration," +
- " and this does not seem to be a scatter plot.")
+ "Axis %s number of elements does not " % axis_name
+ + "correspond to the length of any signal dimension,"
+ " it does not appear to be a constant or a linear calibration,"
+ + " and this does not seem to be a scatter plot."
+ )
continue
elif axis_len in (1, 2):
polynomial_axes_names.append(axis_name)
is_scatter = False
- else:
- if not is_scatter:
- self.issues.append(
- "Axis %s number of elements is equal " % axis_name +
- "to the length of the signal, but this does not seem" +
- " to be a scatter (other axes have different sizes)")
- continue
# Test individual uncertainties
errors_name = axis_name + "_errors"
if errors_name not in self.group and uncertainties_names is not None:
errors_name = uncertainties_names[i]
- if errors_name in self.group and axis_name not in polynomial_axes_names:
+ if (
+ errors_name in self.group
+ and axis_name not in polynomial_axes_names
+ ):
if self.group[errors_name].shape != self.group[axis_name].shape:
self.issues.append(
- "Errors '%s' does not have the same " % errors_name +
- "dimensions as axis '%s'." % axis_name)
+ "Errors '%s' does not have the same " % errors_name
+ + "dimensions as axis '%s'." % axis_name
+ )
# test dimensions of errors associated with signal
@@ -345,8 +368,9 @@ class NXdata(object):
# In principle just the same size should be enough but
# NeXus documentation imposes to have the same shape
self.issues.append(
- "Dataset containing standard deviations must " +
- "have the same dimensions as the signal.")
+ "Dataset containing standard deviations must "
+ + "have the same dimensions as the signal."
+ )
@property
def signal_dataset_name(self):
@@ -358,7 +382,7 @@ class NXdata(object):
# find a dataset with @signal == 1
for dsname in self.group:
signal_attr = self.group[dsname].attrs.get("signal")
- if signal_attr in [1, b"1", u"1"]:
+ if signal_attr in [1, b"1", "1"]:
# This is the main (default) signal
signal_dataset_name = dsname
break
@@ -380,10 +404,13 @@ class NXdata(object):
raise InvalidNXdataError("Unable to parse invalid NXdata")
signal_dataset_name = get_attr_as_unicode(self.group, "signal")
if signal_dataset_name is not None:
- auxiliary_signals_names = get_attr_as_unicode(self.group, "auxiliary_signals")
+ auxiliary_signals_names = get_attr_as_unicode(
+ self.group, "auxiliary_signals"
+ )
if auxiliary_signals_names is not None:
- if not isinstance(auxiliary_signals_names,
- (tuple, list, numpy.ndarray)):
+ if not isinstance(
+ auxiliary_signals_names, (tuple, list, numpy.ndarray)
+ ):
# tolerate a single string, but coerce into a list
return [auxiliary_signals_names]
return list(auxiliary_signals_names)
@@ -398,16 +425,22 @@ class NXdata(object):
ds = self.group[dsname]
signal_attr = ds.attrs.get("signal")
if signal_attr is not None and not is_dataset(ds):
- nxdata_logger.warning("Item %s with @signal=%s is not a dataset (%s)",
- dsname, signal_attr, type(ds))
+ nxdata_logger.warning(
+ "Item %s with @signal=%s is not a dataset (%s)",
+ dsname,
+ signal_attr,
+ type(ds),
+ )
continue
if signal_attr is not None:
try:
signal_number = int(signal_attr)
except (ValueError, TypeError):
- nxdata_logger.warning("Could not parse attr @signal=%s on "
- "dataset %s as an int",
- signal_attr, dsname)
+ nxdata_logger.warning(
+ "Could not parse attr @signal=%s on " "dataset %s as an int",
+ signal_attr,
+ dsname,
+ )
continue
numbered_names.append((signal_number, dsname))
return [a[1] for a in sorted(numbered_names)]
@@ -465,17 +498,26 @@ class NXdata(object):
if not self.is_valid:
raise InvalidNXdataError("Unable to parse invalid NXdata")
- allowed_interpretations = [None, "scaler", "scalar", "spectrum", "image",
- "rgba-image", # "hsla-image", "cmyk-image"
- "vertex"]
+ allowed_interpretations = [
+ None,
+ "scaler",
+ "scalar",
+ "spectrum",
+ "image",
+ "rgba-image", # "hsla-image", "cmyk-image"
+ "vertex",
+ ]
interpretation = get_attr_as_unicode(self.signal, "interpretation")
if interpretation is None:
interpretation = get_attr_as_unicode(self.group, "interpretation")
if interpretation not in allowed_interpretations:
- nxdata_logger.warning("Interpretation %s is not valid." % interpretation +
- " Valid values: " + ", ".join(str(s) for s in allowed_interpretations))
+ nxdata_logger.warning(
+ "Interpretation %s is not valid." % interpretation
+ + " Valid values: "
+ + ", ".join(str(s) for s in allowed_interpretations)
+ )
return interpretation
@property
@@ -529,7 +571,7 @@ class NXdata(object):
continue
fg_idx = axis.attrs.get("first_good", 0)
lg_idx = axis.attrs.get("last_good", len(axis) - 1)
- axes[i] = axis[fg_idx:lg_idx + 1]
+ axes[i] = axis[fg_idx : lg_idx + 1]
self._axes = axes
return self._axes
@@ -548,7 +590,7 @@ class NXdata(object):
if not self.is_valid:
raise InvalidNXdataError("Unable to parse invalid NXdata")
- numbered_names = [] # used in case of @axis=0 (old spec)
+ numbered_names = [] # used in case of @axis=0 (old spec)
axes_dataset_names = get_attr_as_unicode(self.group, "axes")
if axes_dataset_names is None:
# try @axes on signal dataset (older NXdata specification)
@@ -567,8 +609,10 @@ class NXdata(object):
try:
axis_num = int(axis_attr)
except (ValueError, TypeError):
- nxdata_logger.warning("Could not interpret attr @axis as"
- "int on dataset %s", dsname)
+ nxdata_logger.warning(
+ "Could not interpret attr @axis as" "int on dataset %s",
+ dsname,
+ )
continue
numbered_names.append((axis_num, dsname))
@@ -634,8 +678,11 @@ class NXdata(object):
title = self.group.get("title")
data_dataset_names = [self.signal_name] + self.axes_dataset_names
- if (title is not None and is_dataset(title) and
- "title" not in data_dataset_names):
+ if (
+ title is not None
+ and is_dataset(title)
+ and "title" not in data_dataset_names
+ ):
return str(h5py_read_dataset(title))
title = self.group.attrs.get("title")
@@ -680,7 +727,7 @@ class NXdata(object):
errors_name = axis_name + "_errors"
if errors_name in self.group and is_dataset(self.group[errors_name]):
if fg_idx != 0 or lg_idx != (len_axis - 1):
- return self.group[errors_name][fg_idx:lg_idx + 1]
+ return self.group[errors_name][fg_idx : lg_idx + 1]
else:
return self.group[errors_name]
# case of uncertainties dataset name provided in @uncertainties
@@ -701,13 +748,17 @@ class NXdata(object):
if hasattr(axes_ds_names[0], "decode"):
axes_ds_names = [ax_name.decode("utf-8") for ax_name in axes_ds_names]
if axis_name not in axes_ds_names:
- raise KeyError("group attr @axes does not mention a dataset " +
- "named '%s'" % axis_name)
- errors = self.group[uncertainties_names[list(axes_ds_names).index(axis_name)]]
+ raise KeyError(
+ "group attr @axes does not mention a dataset "
+ + "named '%s'" % axis_name
+ )
+ errors = self.group[
+ uncertainties_names[list(axes_ds_names).index(axis_name)]
+ ]
if fg_idx == 0 and lg_idx == (len_axis - 1):
- return errors # dataset
+ return errors # dataset
else:
- return errors[fg_idx:lg_idx + 1] # numpy array
+ return errors[fg_idx : lg_idx + 1] # numpy array
return None
@property
@@ -798,7 +849,9 @@ class NXdata(object):
# the axis, if any, must be of the same length as the last dimension
# of the signal, or of length 2 (a + b *x scale)
if self.axes[-1] is not None and len(self.axes[-1]) not in [
- self.signal.shape[-1], 2]:
+ self.signal.shape[-1],
+ 2,
+ ]:
return False
if self.interpretation is None:
# We no longer test whether x values are monotonic
@@ -820,8 +873,7 @@ class NXdata(object):
return False
if self.signal_is_0d or self.signal_is_1d:
return False
- if not self.signal_is_2d and \
- self.interpretation not in ["image", "rgba-image"]:
+ if not self.signal_is_2d and self.interpretation not in ["image", "rgba-image"]:
return False
if self.signal_is_3d and self.interpretation == "rgba-image":
if self.signal.shape[-1] not in [3, 4]:
@@ -848,7 +900,12 @@ class NXdata(object):
raise InvalidNXdataError("Unable to parse invalid NXdata")
if self.signal_ndim < 3 or self.interpretation in [
- "scalar", "scaler", "spectrum", "image", "rgba-image"]:
+ "scalar",
+ "scaler",
+ "spectrum",
+ "image",
+ "rgba-image",
+ ]:
return False
stack_shape = self.signal.shape[-3:]
for i, axis in enumerate(self.axes[-3:]):
@@ -879,7 +936,7 @@ class NXdata(object):
return True
-def is_valid_nxdata(group): # noqa
+def is_valid_nxdata(group): # noqa
"""Check if a h5py group is a **valid** NX_data group.
:param group: h5py-like group
@@ -968,8 +1025,7 @@ def is_NXroot_with_default_NXdata(group, validate=True):
return False
default_nxentry_group = group.get(default_nxentry_name)
- return is_NXentry_with_default_NXdata(default_nxentry_group,
- validate=validate)
+ return is_NXentry_with_default_NXdata(default_nxentry_group, validate=validate)
def _get_default(
@@ -998,7 +1054,7 @@ def _get_default(
return None
-def get_default(group, validate: bool=True) -> Optional[NXdata]:
+def get_default(group, validate: bool = True) -> Optional[NXdata]:
"""Find the default :class:`NXdata` group in given group.
`@default` attributes are recursively followed until finding a group with
diff --git a/src/silx/io/nxdata/write.py b/src/silx/io/nxdata/write.py
index 5dfe1df..7f429e9 100644
--- a/src/silx/io/nxdata/write.py
+++ b/src/silx/io/nxdata/write.py
@@ -40,12 +40,21 @@ def _str_to_utf8(text):
return numpy.array(text, dtype=h5py.special_dtype(vlen=str))
-def save_NXdata(filename, signal, axes=None,
- signal_name="data", axes_names=None,
- signal_long_name=None, axes_long_names=None,
- signal_errors=None, axes_errors=None,
- title=None, interpretation=None,
- nxentry_name="entry", nxdata_name=None):
+def save_NXdata(
+ filename,
+ signal,
+ axes=None,
+ signal_name="data",
+ axes_names=None,
+ signal_long_name=None,
+ axes_long_names=None,
+ signal_errors=None,
+ axes_errors=None,
+ title=None,
+ interpretation=None,
+ nxentry_name="entry",
+ nxdata_name=None,
+):
"""Write data to an NXdata group.
.. note::
@@ -93,13 +102,15 @@ def save_NXdata(filename, signal, axes=None,
:return: True if save was successful, else False.
"""
if h5py is None:
- raise ImportError("h5py could not be imported, but is required by "
- "save_NXdata function")
+ raise ImportError(
+ "h5py could not be imported, but is required by " "save_NXdata function"
+ )
if axes_names is not None:
assert axes is not None, "Axes names defined, but missing axes arrays"
- assert len(axes) == len(axes_names), \
- "Mismatch between number of axes and axes_names"
+ assert len(axes) == len(
+ axes_names
+ ), "Mismatch between number of axes and axes_names"
if axes is not None and axes_names is None:
axes_names = []
@@ -131,7 +142,7 @@ def save_NXdata(filename, signal, axes=None,
# set this entry as default
h5f.attrs["default"] = _str_to_utf8(nxentry_name)
if "NX_class" not in entry.attrs:
- entry.attrs["NX_class"] = u"NXentry"
+ entry.attrs["NX_class"] = "NXentry"
else:
# write NXdata into the root of the file (invalid nexus!)
entry = h5f
@@ -139,21 +150,25 @@ def save_NXdata(filename, signal, axes=None,
# Create NXdata group
if nxdata_name is not None:
if nxdata_name in entry:
- _logger.error("Cannot assign an NXdata group to an existing"
- " group or dataset")
+ _logger.error(
+ "Cannot assign an NXdata group to an existing" " group or dataset"
+ )
return False
else:
# no name specified, take one that is available
nxdata_name = "data0"
i = 1
while nxdata_name in entry:
- _logger.info("%s item already exists in NXentry group," +
- " trying %s", nxdata_name, "data%d" % i)
+ _logger.info(
+ "%s item already exists in NXentry group," + " trying %s",
+ nxdata_name,
+ "data%d" % i,
+ )
nxdata_name = "data%d" % i
i += 1
data_group = entry.create_group(nxdata_name)
- data_group.attrs["NX_class"] = u"NXdata"
+ data_group.attrs["NX_class"] = "NXdata"
data_group.attrs["signal"] = _str_to_utf8(signal_name)
if axes:
data_group.attrs["axes"] = _str_to_utf8(axes_names)
@@ -163,8 +178,7 @@ def save_NXdata(filename, signal, axes=None,
# better way imho
data_group.attrs["title"] = _str_to_utf8(title)
- signal_dataset = data_group.create_dataset(signal_name,
- data=signal)
+ signal_dataset = data_group.create_dataset(signal_name, data=signal)
if signal_long_name:
signal_dataset.attrs["long_name"] = _str_to_utf8(signal_long_name)
if interpretation:
@@ -172,28 +186,28 @@ def save_NXdata(filename, signal, axes=None,
for i, axis_array in enumerate(axes):
if axis_array is None:
- assert axes_names[i] in [".", None], \
+ assert axes_names[i] in [".", None], (
"Axis name defined for dim %d but no axis array" % i
+ )
continue
- axis_dataset = data_group.create_dataset(axes_names[i],
- data=axis_array)
+ axis_dataset = data_group.create_dataset(axes_names[i], data=axis_array)
if axes_long_names is not None:
axis_dataset.attrs["long_name"] = _str_to_utf8(axes_long_names[i])
if signal_errors is not None:
- data_group.create_dataset("errors",
- data=signal_errors)
+ data_group.create_dataset("errors", data=signal_errors)
if axes_errors is not None:
- assert isinstance(axes_errors, (list, tuple)), \
- "axes_errors must be a list or a tuple of ndarray or None"
- assert len(axes_errors) == len(axes_names), \
- "Mismatch between number of axes_errors and axes_names"
+ assert isinstance(
+ axes_errors, (list, tuple)
+ ), "axes_errors must be a list or a tuple of ndarray or None"
+ assert len(axes_errors) == len(
+ axes_names
+ ), "Mismatch between number of axes_errors and axes_names"
for i, axis_errors in enumerate(axes_errors):
if axis_errors is not None:
dsname = axes_names[i] + "_errors"
- data_group.create_dataset(dsname,
- data=axis_errors)
+ data_group.create_dataset(dsname, data=axis_errors)
if "default" not in entry.attrs:
# set this NXdata as default
entry.attrs["default"] = nxdata_name
diff --git a/src/silx/io/octaveh5.py b/src/silx/io/octaveh5.py
index 67fb1e2..5f5d81d 100644
--- a/src/silx/io/octaveh5.py
+++ b/src/silx/io/octaveh5.py
@@ -48,6 +48,7 @@ Here is an example of a simple read and write :
"""
import logging
+
logger = logging.getLogger(__name__)
import numpy as np
import h5py
@@ -58,20 +59,19 @@ __date__ = "05/10/2016"
class Octaveh5(object):
- """This class allows communication between octave and python using hdf5 format.
- """
+ """This class allows communication between octave and python using hdf5 format."""
def __init__(self, octave_targetted_version=3.8):
"""Constructor
:param octave_targetted_version: the version of Octave for which we want to write this hdf5 file.
-
+
This is needed because for old Octave version we need to had a hack(adding one extra character)
"""
self.file = None
self.octave_targetted_version = octave_targetted_version
- def open(self, h5file, mode='r'):
+ def open(self, h5file, mode="r"):
"""Open the h5 file which has been write by octave
:param h5file: The path of the file to read
@@ -81,7 +81,7 @@ class Octaveh5(object):
self.file = h5py.File(h5file, mode)
return self
except IOError as e:
- if mode == 'a':
+ if mode == "a":
reason = "\n %s: Can t find or create " % h5file
else:
reason = "\n %s: File not found" % h5file
@@ -113,15 +113,17 @@ class Octaveh5(object):
for key, val in iter(dict(gr_level2).items()):
data_dict[str(key)] = list(val.items())[1][1][()]
- if list(val.items())[0][1][()] != np.string_('sq_string'):
+ if list(val.items())[0][1][()] != np.string_("sq_string"):
data_dict[str(key)] = float(data_dict[str(key)])
else:
- if list(val.items())[0][1][()] == np.string_('sq_string'):
+ if list(val.items())[0][1][()] == np.string_("sq_string"):
# in the case the string has been stored as an nd-array of char
if type(data_dict[str(key)]) is np.ndarray:
- data_dict[str(key)] = "".join(chr(item) for item in data_dict[str(key)])
+ data_dict[str(key)] = "".join(
+ chr(item) for item in data_dict[str(key)]
+ )
else:
- data_dict[str(key)] = data_dict[str(key)].decode('UTF-8')
+ data_dict[str(key)] = data_dict[str(key)].decode("UTF-8")
# In the case Octave have added an extra character at the end
if self.octave_targetted_version < 3.8:
@@ -141,30 +143,36 @@ class Octaveh5(object):
return
group_l1 = self.file.create_group(struct_name)
- group_l1.attrs['OCTAVE_GLOBAL'] = np.uint8(1)
- group_l1.attrs['OCTAVE_NEW_FORMAT'] = np.uint8(1)
- group_l1.create_dataset("type", data=np.string_('scalar struct'), dtype="|S14")
- group_l2 = group_l1.create_group('value')
+ group_l1.attrs["OCTAVE_GLOBAL"] = np.uint8(1)
+ group_l1.attrs["OCTAVE_NEW_FORMAT"] = np.uint8(1)
+ group_l1.create_dataset("type", data=np.string_("scalar struct"), dtype="|S14")
+ group_l2 = group_l1.create_group("value")
for ftparams in data_dict:
group_l3 = group_l2.create_group(ftparams)
- group_l3.attrs['OCTAVE_NEW_FORMAT'] = np.uint8(1)
+ group_l3.attrs["OCTAVE_NEW_FORMAT"] = np.uint8(1)
if type(data_dict[ftparams]) == str:
- group_l3.create_dataset("type", (), data=np.string_('sq_string'), dtype="|S10")
+ group_l3.create_dataset(
+ "type", (), data=np.string_("sq_string"), dtype="|S10"
+ )
if self.octave_targetted_version < 3.8:
- group_l3.create_dataset("value", data=np.string_(data_dict[ftparams] + '0'))
+ group_l3.create_dataset(
+ "value", data=np.string_(data_dict[ftparams] + "0")
+ )
else:
- group_l3.create_dataset("value", data=np.string_(data_dict[ftparams]))
+ group_l3.create_dataset(
+ "value", data=np.string_(data_dict[ftparams])
+ )
else:
- group_l3.create_dataset("type", (), data=np.string_('scalar'), dtype="|S7")
+ group_l3.create_dataset(
+ "type", (), data=np.string_("scalar"), dtype="|S7"
+ )
group_l3.create_dataset("value", data=data_dict[ftparams])
def close(self):
- """Close the file after calling read function
- """
+ """Close the file after calling read function"""
if self.file:
self.file.close()
def __del__(self):
- """Destructor
- """
+ """Destructor"""
self.close()
diff --git a/src/silx/io/rawh5.py b/src/silx/io/rawh5.py
index 31b554d..dc117c4 100644
--- a/src/silx/io/rawh5.py
+++ b/src/silx/io/rawh5.py
@@ -37,7 +37,6 @@ _logger = logging.getLogger(__name__)
class _FreeDataset(commonh5.Dataset):
-
def _check_data(self, data):
"""Release the constriants checked on types cause we can reach more
types than the one available on h5py, and it is not supposed to be
@@ -55,6 +54,7 @@ class NumpyFile(commonh5.File):
:param str name: Filename to load
"""
+
def __init__(self, name=None):
commonh5.File.__init__(self, name=name, mode="w")
np_file = numpy.load(name)
diff --git a/src/silx/io/specfile.pyx b/src/silx/io/specfile.pyx
index dde6d82..ca43419 100644
--- a/src/silx/io/specfile.pyx
+++ b/src/silx/io/specfile.pyx
@@ -1,5 +1,5 @@
# /*##########################################################################
-# Copyright (C) 2016-2018 European Synchrotron Radiation Facility
+# Copyright (C) 2016-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -110,7 +110,6 @@ import os.path
import logging
import numpy
import re
-import sys
_logger = logging.getLogger(__name__)
@@ -599,7 +598,7 @@ class Scan(object):
def _string_to_char_star(string_):
"""Convert a string to ASCII encoded bytes when using python3"""
- if sys.version_info[0] >= 3 and not isinstance(string_, bytes):
+ if not isinstance(string_, bytes):
return bytes(string_, "ascii")
return string_
diff --git a/src/silx/io/specfilewrapper.py b/src/silx/io/specfilewrapper.py
index d8ee90b..b257738 100644
--- a/src/silx/io/specfilewrapper.py
+++ b/src/silx/io/specfilewrapper.py
@@ -105,6 +105,7 @@ class Specfile(SpecFile):
- :meth:`epoch`
- :meth:`title`
"""
+
def __init__(self, filename):
SpecFile.__init__(self, filename)
@@ -167,8 +168,7 @@ class Specfile(SpecFile):
except (ValueError, IndexError):
# self.index can raise an index error
# int() can raise a value error
- raise KeyError(msg + "\nValid keys: '" +
- "', '".join(self.keys()) + "'")
+ raise KeyError(msg + "\nValid keys: '" + "', '".join(self.keys()) + "'")
except AttributeError:
# e.g. "AttrErr: 'float' object has no attribute 'split'"
raise TypeError(msg)
@@ -258,6 +258,7 @@ class scandata(Scan): # noqa
- :meth:`fileheader`
- :meth:`nbmca`
"""
+
def __init__(self, specfile, scan_index):
Scan.__init__(self, specfile, scan_index)
@@ -317,7 +318,7 @@ class scandata(Scan): # noqa
"""Return the date from the scan header line ``#D``"""
return self._specfile.date(self._index)
- def fileheader(self, key=''): # noqa
+ def fileheader(self, key=""): # noqa
"""Return a list of file header lines"""
# key is there for compatibility
return self.file_header
diff --git a/src/silx/io/spech5.py b/src/silx/io/spech5.py
index 05ce9f0..4f358e8 100644
--- a/src/silx/io/spech5.py
+++ b/src/silx/io/spech5.py
@@ -1,5 +1,5 @@
# /*##########################################################################
-# Copyright (C) 2016-2021 European Synchrotron Radiation Facility
+# Copyright (C) 2016-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -155,32 +155,6 @@ You can test for existence of data or groups::
True
>>> "spam" in sfh5["1.1"]
False
-
-.. note::
-
- Text used to be stored with a dtype ``numpy.string_`` in silx versions
- prior to *0.7.0*. The type ``numpy.string_`` is a byte-string format.
- The consequence of this is that you had to decode strings before using
- them in **Python 3**::
-
- >>> from silx.io.spech5 import SpecH5
- >>> sfh5 = SpecH5("31oct98.dat")
- >>> sfh5["/68.1/title"]
- b'68 ascan tx3 -28.5 -24.5 20 0.5'
- >>> sfh5["/68.1/title"].decode()
- '68 ascan tx3 -28.5 -24.5 20 0.5'
-
- From silx version *0.7.0* onwards, text is now stored as unicode. This
- corresponds to the default text type in python 3, and to the *unicode*
- type in Python 2.
-
- To be on the safe side, you can test for the presence of a *decode*
- attribute, to ensure that you always work with unicode text::
-
- >>> title = sfh5["/68.1/title"]
- >>> if hasattr(title, "decode"):
- ... title = title.decode()
-
"""
import datetime
@@ -245,8 +219,9 @@ def _motor_in_scan(sf, scan_key, motor_name):
:raise: ``KeyError`` if scan_key not found in SpecFile
"""
if scan_key not in sf:
- raise KeyError("Scan key %s " % scan_key +
- "does not exist in SpecFile %s" % sf.filename)
+ raise KeyError(
+ "Scan key %s " % scan_key + "does not exist in SpecFile %s" % sf.filename
+ )
ret = motor_name in sf[scan_key].motor_names
if not ret and "%" in motor_name:
motor_name = motor_name.replace("%", "/")
@@ -263,8 +238,9 @@ def _column_label_in_scan(sf, scan_key, column_label):
:raise: ``KeyError`` if scan_key not found in SpecFile
"""
if scan_key not in sf:
- raise KeyError("Scan key %s " % scan_key +
- "does not exist in SpecFile %s" % sf.filename)
+ raise KeyError(
+ "Scan key %s " % scan_key + "does not exist in SpecFile %s" % sf.filename
+ )
ret = column_label in sf[scan_key].labels
if not ret and "%" in column_label:
column_label = column_label.replace("%", "/")
@@ -349,8 +325,9 @@ def _parse_ctime(ctime_lines, analyser_index=0):
else:
ctime_line = ctimes_lines_list[analyser_index]
if not len(ctime_line.split()) == 3:
- raise ValueError("Incorrect format for @CTIME header line " +
- '(expected "@CTIME %f %f %f").')
+ raise ValueError(
+ "Incorrect format for @CTIME header line " + '(expected "@CTIME %f %f %f").'
+ )
return list(map(float, ctime_line.split()))
@@ -380,36 +357,52 @@ def spec_date_to_iso8601(date, zone=None):
>>> spec_date_to_iso8601("Sat 2015/03/14 03:53:50")
'2015-03-14T03:53:50'
"""
- months = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul',
- 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
- days = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
-
- days_rx = '(?P<day>' + '|'.join(days) + ')'
- months_rx = '(?P<month>' + '|'.join(months) + ')'
- year_rx = r'(?P<year>\d{4})'
- day_nb_rx = r'(?P<day_nb>[0-3 ]\d)'
- month_nb_rx = r'(?P<month_nb>[0-1]\d)'
- hh_rx = r'(?P<hh>[0-2]\d)'
- mm_rx = r'(?P<mm>[0-5]\d)'
- ss_rx = r'(?P<ss>[0-5]\d)'
- tz_rx = r'(?P<tz>[+-]\d\d:\d\d){0,1}'
+ months = [
+ "Jan",
+ "Feb",
+ "Mar",
+ "Apr",
+ "May",
+ "Jun",
+ "Jul",
+ "Aug",
+ "Sep",
+ "Oct",
+ "Nov",
+ "Dec",
+ ]
+ days = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"]
+
+ days_rx = "(?P<day>" + "|".join(days) + ")"
+ months_rx = "(?P<month>" + "|".join(months) + ")"
+ year_rx = r"(?P<year>\d{4})"
+ day_nb_rx = r"(?P<day_nb>[0-3 ]\d)"
+ month_nb_rx = r"(?P<month_nb>[0-1]\d)"
+ hh_rx = r"(?P<hh>[0-2]\d)"
+ mm_rx = r"(?P<mm>[0-5]\d)"
+ ss_rx = r"(?P<ss>[0-5]\d)"
+ tz_rx = r"(?P<tz>[+-]\d\d:\d\d){0,1}"
# date formats must have either month_nb (1..12) or month (Jan, Feb, ...)
- re_tpls = ['{days} {months} {day_nb} {hh}:{mm}:{ss}{tz} {year}',
- '{days} {year}/{month_nb}/{day_nb} {hh}:{mm}:{ss}{tz}']
+ re_tpls = [
+ "{days} {months} {day_nb} {hh}:{mm}:{ss}{tz} {year}",
+ "{days} {year}/{month_nb}/{day_nb} {hh}:{mm}:{ss}{tz}",
+ ]
grp_d = None
for rx in re_tpls:
- full_rx = rx.format(days=days_rx,
- months=months_rx,
- year=year_rx,
- day_nb=day_nb_rx,
- month_nb=month_nb_rx,
- hh=hh_rx,
- mm=mm_rx,
- ss=ss_rx,
- tz=tz_rx)
+ full_rx = rx.format(
+ days=days_rx,
+ months=months_rx,
+ year=year_rx,
+ day_nb=day_nb_rx,
+ month_nb=month_nb_rx,
+ hh=hh_rx,
+ mm=mm_rx,
+ ss=ss_rx,
+ tz=tz_rx,
+ )
m = re.match(full_rx, date)
if m:
@@ -417,30 +410,24 @@ def spec_date_to_iso8601(date, zone=None):
break
if not grp_d:
- raise ValueError('Date format not recognized : {0}'.format(date))
+ raise ValueError("Date format not recognized : {0}".format(date))
- year = grp_d['year']
+ year = grp_d["year"]
- month = grp_d.get('month_nb')
+ month = grp_d.get("month_nb")
if not month:
- month = '{0:02d}'.format(months.index(grp_d.get('month')) + 1)
+ month = "{0:02d}".format(months.index(grp_d.get("month")) + 1)
- day = grp_d['day_nb']
+ day = grp_d["day_nb"]
- tz = grp_d['tz']
+ tz = grp_d["tz"]
if not tz:
tz = zone
- time = '{0}:{1}:{2}'.format(grp_d['hh'],
- grp_d['mm'],
- grp_d['ss'])
+ time = "{0}:{1}:{2}".format(grp_d["hh"], grp_d["mm"], grp_d["ss"])
- full_date = '{0}-{1}-{2}T{3}{4}'.format(year,
- month,
- day,
- time,
- tz if tz else '')
+ full_date = "{0}-{1}-{2}T{3}{4}".format(year, month, day, time, tz if tz else "")
return full_date
@@ -483,6 +470,7 @@ class SpecH5Dataset(object):
Datasets must also inherit :class:`SpecH5NodeDataset` or
:class:`SpecH5LazyNodeDataset` which actually implement all the
API."""
+
pass
@@ -492,6 +480,7 @@ class SpecH5NodeDataset(commonh5.Dataset, SpecH5Dataset):
proxy behavior that allows to mimic the numpy array stored in this
class.
"""
+
def __init__(self, name, data, parent=None, attrs=None):
# get proper value types, to inherit from numpy
# attributes (dtype, shape, size)
@@ -509,8 +498,7 @@ class SpecH5NodeDataset(commonh5.Dataset, SpecH5Dataset):
data_kind = array.dtype.kind
if data_kind in ["S", "U"]:
- value = numpy.asarray(array,
- dtype=text_dtype)
+ value = numpy.asarray(array, dtype=text_dtype)
elif data_kind in ["f"]:
value = numpy.asarray(array, dtype=numpy.float32)
else:
@@ -518,8 +506,7 @@ class SpecH5NodeDataset(commonh5.Dataset, SpecH5Dataset):
commonh5.Dataset.__init__(self, name, value, parent, attrs)
def __getattr__(self, item):
- """Proxy to underlying numpy array methods.
- """
+ """Proxy to underlying numpy array methods."""
if hasattr(self[()], item):
return getattr(self[()], item)
@@ -535,9 +522,9 @@ class SpecH5LazyNodeDataset(commonh5.LazyLoadableDataset, SpecH5Dataset):
implemented to return the numpy data exposed by the dataset. This factory
method is only called once, when the data is needed.
"""
+
def __getattr__(self, item):
- """Proxy to underlying numpy array methods.
- """
+ """Proxy to underlying numpy array methods."""
if hasattr(self[()], item):
return getattr(self[()], item)
@@ -564,6 +551,7 @@ class SpecH5Group(object):
Groups must also inherit :class:`silx.io.commonh5.Group`, which
actually implements all the methods and attributes."""
+
pass
@@ -585,11 +573,12 @@ class SpecH5(commonh5.File, SpecH5Group):
self._sf = SpecFile(filename)
- attrs = {"NX_class": to_h5py_utf8("NXroot"),
- "file_time": to_h5py_utf8(
- datetime.datetime.now().isoformat()),
- "file_name": to_h5py_utf8(filename),
- "creator": to_h5py_utf8("silx spech5 %s" % silx_version)}
+ attrs = {
+ "NX_class": to_h5py_utf8("NXroot"),
+ "file_time": to_h5py_utf8(datetime.datetime.now().isoformat()),
+ "file_name": to_h5py_utf8(filename),
+ "creator": to_h5py_utf8("silx spech5 %s" % silx_version),
+ }
commonh5.File.__init__(self, filename, attrs=attrs)
for scan_key in self._sf.keys():
@@ -610,42 +599,51 @@ class ScanGroup(commonh5.Group, SpecH5Group):
:param str scan_key: Scan key (e.g. "1.1")
:param scan: specfile.Scan object
"""
- commonh5.Group.__init__(self, scan_key, parent=parent,
- attrs={"NX_class": to_h5py_utf8("NXentry")})
+ commonh5.Group.__init__(
+ self, scan_key, parent=parent, attrs={"NX_class": to_h5py_utf8("NXentry")}
+ )
# take title in #S after stripping away scan number and spaces
s_hdr_line = scan.scan_header_dict["S"]
title = s_hdr_line.lstrip("0123456789").lstrip()
- self.add_node(SpecH5NodeDataset(name="title",
- data=to_h5py_utf8(title),
- parent=self))
+ self.add_node(
+ SpecH5NodeDataset(name="title", data=to_h5py_utf8(title), parent=self)
+ )
if "D" in scan.scan_header_dict:
try:
start_time_str = spec_date_to_iso8601(scan.scan_header_dict["D"])
except (IndexError, ValueError):
- logger1.warning("Could not parse date format in scan %s header." +
- " Using original date not converted to ISO-8601",
- scan_key)
+ logger1.warning(
+ "Could not parse date format in scan %s header."
+ + " Using original date not converted to ISO-8601",
+ scan_key,
+ )
start_time_str = scan.scan_header_dict["D"]
elif "D" in scan.file_header_dict:
- logger1.warning("No #D line in scan %s header. " +
- "Using file header for start_time.",
- scan_key)
+ logger1.warning(
+ "No #D line in scan %s header. " + "Using file header for start_time.",
+ scan_key,
+ )
try:
start_time_str = spec_date_to_iso8601(scan.file_header_dict["D"])
except (IndexError, ValueError):
- logger1.warning("Could not parse date format in scan %s header. " +
- "Using original date not converted to ISO-8601",
- scan_key)
+ logger1.warning(
+ "Could not parse date format in scan %s header. "
+ + "Using original date not converted to ISO-8601",
+ scan_key,
+ )
start_time_str = scan.file_header_dict["D"]
else:
- logger1.warning("No #D line in %s header. Setting date to empty string.",
- scan_key)
+ logger1.warning(
+ "No #D line in %s header. Setting date to empty string.", scan_key
+ )
start_time_str = ""
- self.add_node(SpecH5NodeDataset(name="start_time",
- data=to_h5py_utf8(start_time_str),
- parent=self))
+ self.add_node(
+ SpecH5NodeDataset(
+ name="start_time", data=to_h5py_utf8(start_time_str), parent=self
+ )
+ )
self.add_node(InstrumentGroup(parent=self, scan=scan))
self.add_node(MeasurementGroup(parent=self, scan=scan))
@@ -660,42 +658,60 @@ class InstrumentGroup(commonh5.Group, SpecH5Group):
:param parent: parent Group
:param scan: specfile.Scan object
"""
- commonh5.Group.__init__(self, name="instrument", parent=parent,
- attrs={"NX_class": to_h5py_utf8("NXinstrument")})
+ commonh5.Group.__init__(
+ self,
+ name="instrument",
+ parent=parent,
+ attrs={"NX_class": to_h5py_utf8("NXinstrument")},
+ )
self.add_node(InstrumentSpecfileGroup(parent=self, scan=scan))
self.add_node(PositionersGroup(parent=self, scan=scan))
num_analysers = _get_number_of_mca_analysers(scan)
for anal_idx in range(num_analysers):
- self.add_node(InstrumentMcaGroup(parent=self,
- analyser_index=anal_idx,
- scan=scan))
+ self.add_node(
+ InstrumentMcaGroup(parent=self, analyser_index=anal_idx, scan=scan)
+ )
class InstrumentSpecfileGroup(commonh5.Group, SpecH5Group):
def __init__(self, parent, scan):
- commonh5.Group.__init__(self, name="specfile", parent=parent,
- attrs={"NX_class": to_h5py_utf8("NXcollection")})
- self.add_node(SpecH5NodeDataset(
+ commonh5.Group.__init__(
+ self,
+ name="specfile",
+ parent=parent,
+ attrs={"NX_class": to_h5py_utf8("NXcollection")},
+ )
+ self.add_node(
+ SpecH5NodeDataset(
name="file_header",
data=to_h5py_utf8(scan.file_header),
parent=self,
- attrs={}))
- self.add_node(SpecH5NodeDataset(
+ attrs={},
+ )
+ )
+ self.add_node(
+ SpecH5NodeDataset(
name="scan_header",
data=to_h5py_utf8(scan.scan_header),
parent=self,
- attrs={}))
+ attrs={},
+ )
+ )
class PositionersGroup(commonh5.Group, SpecH5Group):
def __init__(self, parent, scan):
- commonh5.Group.__init__(self, name="positioners", parent=parent,
- attrs={"NX_class": to_h5py_utf8("NXcollection")})
+ commonh5.Group.__init__(
+ self,
+ name="positioners",
+ parent=parent,
+ attrs={"NX_class": to_h5py_utf8("NXcollection")},
+ )
dataset_info = [] # Store list of positioner's (name, value)
- is_error = False # True if error encountered
+ is_error = False # True if error encountered
for motor_name in scan.motor_names:
safe_motor_name = motor_name.replace("/", "%")
@@ -709,31 +725,34 @@ class PositionersGroup(commonh5.Group, SpecH5Group):
motor_value = scan.motor_position_by_name(motor_name)
except SfErrColNotFound:
is_error = True
- motor_value = float('inf')
+ motor_value = float("inf")
dataset_info.append((safe_motor_name, motor_value))
if is_error: # Filter-out scalar values
logger1.warning("Mismatching number of elements in #P and #O: Ignoring")
dataset_info = [
- (name, value) for name, value in dataset_info
- if not isinstance(value, float)]
+ (name, value)
+ for name, value in dataset_info
+ if not isinstance(value, float)
+ ]
for name, value in dataset_info:
- self.add_node(SpecH5NodeDataset(
- name=name,
- data=value,
- parent=self))
+ self.add_node(SpecH5NodeDataset(name=name, data=value, parent=self))
class InstrumentMcaGroup(commonh5.Group, SpecH5Group):
def __init__(self, parent, analyser_index, scan):
name = "mca_%d" % analyser_index
- commonh5.Group.__init__(self, name=name, parent=parent,
- attrs={"NX_class": to_h5py_utf8("NXdetector")})
-
- mcaDataDataset = McaDataDataset(parent=self,
- analyser_index=analyser_index,
- scan=scan)
+ commonh5.Group.__init__(
+ self,
+ name=name,
+ parent=parent,
+ attrs={"NX_class": to_h5py_utf8("NXdetector")},
+ )
+
+ mcaDataDataset = McaDataDataset(
+ parent=self, analyser_index=analyser_index, scan=scan
+ )
self.add_node(mcaDataDataset)
spectrum_length = mcaDataDataset.shape[-1]
mcaDataDataset = None
@@ -746,7 +765,7 @@ class InstrumentMcaGroup(commonh5.Group, SpecH5Group):
calibration_dataset = scan.mca.calibration[analyser_index]
channels_dataset = scan.mca.channels[analyser_index]
- channels_length = len(channels_dataset)
+ channels_length = len(channels_dataset)
if (channels_length > 1) and (spectrum_length > 0):
logger1.info("Spectrum and channels length mismatch")
# this should always be the case
@@ -756,37 +775,48 @@ class InstrumentMcaGroup(commonh5.Group, SpecH5Group):
# only trust first channel and increment
channel0 = channels_dataset[0]
increment = channels_dataset[1] - channels_dataset[0]
- channels_dataset = numpy.linspace(channel0,
- channel0 + increment * spectrum_length,
- spectrum_length, endpoint=False)
-
- self.add_node(SpecH5NodeDataset(name="calibration",
- data=calibration_dataset,
- parent=self))
- self.add_node(SpecH5NodeDataset(name="channels",
- data=channels_dataset,
- parent=self))
+ channels_dataset = numpy.linspace(
+ channel0,
+ channel0 + increment * spectrum_length,
+ spectrum_length,
+ endpoint=False,
+ )
+
+ self.add_node(
+ SpecH5NodeDataset(name="calibration", data=calibration_dataset, parent=self)
+ )
+ self.add_node(
+ SpecH5NodeDataset(name="channels", data=channels_dataset, parent=self)
+ )
if "CTIME" in scan.mca_header_dict:
- ctime_line = scan.mca_header_dict['CTIME']
- preset_time, live_time, elapsed_time = _parse_ctime(ctime_line, analyser_index)
- self.add_node(SpecH5NodeDataset(name="preset_time",
- data=preset_time,
- parent=self))
- self.add_node(SpecH5NodeDataset(name="live_time",
- data=live_time,
- parent=self))
- self.add_node(SpecH5NodeDataset(name="elapsed_time",
- data=elapsed_time,
- parent=self))
+ ctime_line = scan.mca_header_dict["CTIME"]
+ preset_time, live_time, elapsed_time = _parse_ctime(
+ ctime_line, analyser_index
+ )
+ self.add_node(
+ SpecH5NodeDataset(name="preset_time", data=preset_time, parent=self)
+ )
+ self.add_node(
+ SpecH5NodeDataset(name="live_time", data=live_time, parent=self)
+ )
+ self.add_node(
+ SpecH5NodeDataset(name="elapsed_time", data=elapsed_time, parent=self)
+ )
class McaDataDataset(SpecH5LazyNodeDataset):
"""Lazy loadable dataset for MCA data"""
+
def __init__(self, parent, analyser_index, scan):
commonh5.LazyLoadableDataset.__init__(
- self, name="data", parent=parent,
- attrs={"interpretation": to_h5py_utf8("spectrum"),})
+ self,
+ name="data",
+ parent=parent,
+ attrs={
+ "interpretation": to_h5py_utf8("spectrum"),
+ },
+ )
self._scan = scan
self._analyser_index = analyser_index
self._shape = None
@@ -812,7 +842,7 @@ class McaDataDataset(SpecH5LazyNodeDataset):
def dtype(self):
# we initialize the data with numpy.empty() without specifying a dtype
# in _demultiplex_mca()
- return numpy.empty((1, )).dtype
+ return numpy.empty((1,)).dtype
def __len__(self):
return self.shape[0]
@@ -824,8 +854,7 @@ class McaDataDataset(SpecH5LazyNodeDataset):
if item < 0:
# negative indexing
item += len(self)
- return self._scan.mca[self._analyser_index +
- item * self._num_analysers]
+ return self._scan.mca[self._analyser_index + item * self._num_analysers]
# accessing a slice or element of a single spectrum [i, j:k]
try:
spectrum_idx, channel_idx_or_slice = item
@@ -848,13 +877,21 @@ class MeasurementGroup(commonh5.Group, SpecH5Group):
:param parent: parent Group
:param scan: specfile.Scan object
"""
- commonh5.Group.__init__(self, name="measurement", parent=parent,
- attrs={"NX_class": to_h5py_utf8("NXcollection"),})
+ commonh5.Group.__init__(
+ self,
+ name="measurement",
+ parent=parent,
+ attrs={
+ "NX_class": to_h5py_utf8("NXcollection"),
+ },
+ )
for label in scan.labels:
safe_label = label.replace("/", "%")
- self.add_node(SpecH5NodeDataset(name=safe_label,
- data=scan.data_column_by_name(label),
- parent=self))
+ self.add_node(
+ SpecH5NodeDataset(
+ name=safe_label, data=scan.data_column_by_name(label), parent=self
+ )
+ )
num_analysers = _get_number_of_mca_analysers(scan)
for anal_idx in range(num_analysers):
@@ -864,16 +901,13 @@ class MeasurementGroup(commonh5.Group, SpecH5Group):
class MeasurementMcaGroup(commonh5.Group, SpecH5Group):
def __init__(self, parent, analyser_index):
basename = "mca_%d" % analyser_index
- commonh5.Group.__init__(self, name=basename, parent=parent,
- attrs={})
+ commonh5.Group.__init__(self, name=basename, parent=parent, attrs={})
target_name = self.name.replace("measurement", "instrument")
- self.add_node(commonh5.SoftLink(name="data",
- path=target_name + "/data",
- parent=self))
- self.add_node(commonh5.SoftLink(name="info",
- path=target_name,
- parent=self))
+ self.add_node(
+ commonh5.SoftLink(name="data", path=target_name + "/data", parent=self)
+ )
+ self.add_node(commonh5.SoftLink(name="info", path=target_name, parent=self))
class SampleGroup(commonh5.Group, SpecH5Group):
@@ -883,24 +917,46 @@ class SampleGroup(commonh5.Group, SpecH5Group):
:param parent: parent Group
:param scan: specfile.Scan object
"""
- commonh5.Group.__init__(self, name="sample", parent=parent,
- attrs={"NX_class": to_h5py_utf8("NXsample"),})
+ commonh5.Group.__init__(
+ self,
+ name="sample",
+ parent=parent,
+ attrs={
+ "NX_class": to_h5py_utf8("NXsample"),
+ },
+ )
if _unit_cell_in_scan(scan):
- self.add_node(SpecH5NodeDataset(name="unit_cell",
- data=_parse_unit_cell(scan.scan_header_dict["G1"]),
- parent=self,
- attrs={"interpretation": to_h5py_utf8("scalar")}))
- self.add_node(SpecH5NodeDataset(name="unit_cell_abc",
- data=_parse_unit_cell(scan.scan_header_dict["G1"])[0, 0:3],
- parent=self,
- attrs={"interpretation": to_h5py_utf8("scalar")}))
- self.add_node(SpecH5NodeDataset(name="unit_cell_alphabetagamma",
- data=_parse_unit_cell(scan.scan_header_dict["G1"])[0, 3:6],
- parent=self,
- attrs={"interpretation": to_h5py_utf8("scalar")}))
+ self.add_node(
+ SpecH5NodeDataset(
+ name="unit_cell",
+ data=_parse_unit_cell(scan.scan_header_dict["G1"]),
+ parent=self,
+ attrs={"interpretation": to_h5py_utf8("scalar")},
+ )
+ )
+ self.add_node(
+ SpecH5NodeDataset(
+ name="unit_cell_abc",
+ data=_parse_unit_cell(scan.scan_header_dict["G1"])[0, 0:3],
+ parent=self,
+ attrs={"interpretation": to_h5py_utf8("scalar")},
+ )
+ )
+ self.add_node(
+ SpecH5NodeDataset(
+ name="unit_cell_alphabetagamma",
+ data=_parse_unit_cell(scan.scan_header_dict["G1"])[0, 3:6],
+ parent=self,
+ attrs={"interpretation": to_h5py_utf8("scalar")},
+ )
+ )
if _ub_matrix_in_scan(scan):
- self.add_node(SpecH5NodeDataset(name="ub_matrix",
- data=_parse_UB_matrix(scan.scan_header_dict["G3"]),
- parent=self,
- attrs={"interpretation": to_h5py_utf8("scalar")}))
+ self.add_node(
+ SpecH5NodeDataset(
+ name="ub_matrix",
+ data=_parse_UB_matrix(scan.scan_header_dict["G3"]),
+ parent=self,
+ attrs={"interpretation": to_h5py_utf8("scalar")},
+ )
+ )
diff --git a/src/silx/io/spectoh5.py b/src/silx/io/spectoh5.py
deleted file mode 100644
index 0f4f1c5..0000000
--- a/src/silx/io/spectoh5.py
+++ /dev/null
@@ -1,80 +0,0 @@
-# /*##########################################################################
-# Copyright (C) 2016-2017 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-"""Deprecated module. Use :mod:`convert` instead."""
-
-from .convert import Hdf5Writer
-from .convert import write_to_h5
-from .convert import convert as other_convert
-
-from silx.utils import deprecation
-
-deprecation.deprecated_warning(type_="Module",
- name="silx.io.spectoh5",
- since_version="0.6",
- replacement="silx.io.convert")
-
-
-class SpecToHdf5Writer(Hdf5Writer):
- def __init__(self, h5path='/', overwrite_data=False,
- link_type="hard", create_dataset_args=None):
- deprecation.deprecated_warning(
- type_="Class",
- name="SpecToHdf5Writer",
- since_version="0.6",
- replacement="silx.io.convert.Hdf5Writer")
- Hdf5Writer.__init__(self, h5path, overwrite_data,
- link_type, create_dataset_args)
-
- # methods whose signatures changed
- def write(self, sfh5, h5f):
- Hdf5Writer.write(self, infile=sfh5, h5f=h5f)
-
- def append_spec_member_to_h5(self, spec_h5_name, obj):
- Hdf5Writer.append_member_to_h5(self,
- h5like_name=spec_h5_name,
- obj=obj)
-
-
-@deprecation.deprecated(replacement="silx.io.convert.write_to_h5",
- since_version="0.6")
-def write_spec_to_h5(specfile, h5file, h5path='/',
- mode="a", overwrite_data=False,
- link_type="hard", create_dataset_args=None):
-
- write_to_h5(infile=specfile,
- h5file=h5file,
- h5path=h5path,
- mode=mode,
- overwrite_data=overwrite_data,
- link_type=link_type,
- create_dataset_args=create_dataset_args)
-
-
-@deprecation.deprecated(replacement="silx.io.convert.convert",
- since_version="0.6")
-def convert(specfile, h5file, mode="w-",
- create_dataset_args=None):
- other_convert(infile=specfile,
- h5file=h5file,
- mode=mode,
- create_dataset_args=create_dataset_args)
diff --git a/src/silx/io/test/test_commonh5.py b/src/silx/io/test/test_commonh5.py
index d554d27..1b0a3a6 100644
--- a/src/silx/io/test/test_commonh5.py
+++ b/src/silx/io/test/test_commonh5.py
@@ -46,6 +46,7 @@ except ImportError:
class _TestCommonFeatures(unittest.TestCase):
"""Test common features supported by h5py and our implementation."""
+
__test__ = False # ignore abstract class tests
@classmethod
@@ -108,7 +109,7 @@ class _TestCommonFeatures(unittest.TestCase):
self.assertTrue(isinstance(link, (h5py.SoftLink, commonh5.SoftLink)))
self.assertTrue(silx.io.utils.is_softlink(link))
self.assertEqual(classlink, h5py.SoftLink)
-
+
def test_external_link(self):
node = self.h5["link/external_link"]
self.assertEqual(node.name, "/target/dataset")
@@ -121,7 +122,9 @@ class _TestCommonFeatures(unittest.TestCase):
node = self.h5["link/external_link_to_link"]
self.assertEqual(node.name, "/target/link")
class_ = self.h5.get("link/external_link_to_link", getclass=True)
- classlink = self.h5.get("link/external_link_to_link", getlink=True, getclass=True)
+ classlink = self.h5.get(
+ "link/external_link_to_link", getlink=True, getclass=True
+ )
self.assertEqual(class_, h5py.Dataset)
self.assertEqual(classlink, h5py.ExternalLink)
@@ -156,6 +159,7 @@ class _TestCommonFeatures(unittest.TestCase):
class TestCommonFeatures_h5py(_TestCommonFeatures):
"""Check if h5py is compliant with what we expect."""
+
__test__ = True # because _TestCommonFeatures is ignored
@classmethod
@@ -171,7 +175,9 @@ class TestCommonFeatures_h5py(_TestCommonFeatures):
h5["group/dataset"] = 50
h5["link/soft_link"] = h5py.SoftLink("/group/dataset")
h5["link/external_link"] = h5py.ExternalLink("external.h5", "/target/dataset")
- h5["link/external_link_to_link"] = h5py.ExternalLink("external.h5", "/target/link")
+ h5["link/external_link_to_link"] = h5py.ExternalLink(
+ "external.h5", "/target/link"
+ )
return h5
@@ -184,6 +190,7 @@ class TestCommonFeatures_h5py(_TestCommonFeatures):
class TestCommonFeatures_commonH5(_TestCommonFeatures):
"""Check if commonh5 is compliant with h5py."""
+
__test__ = True # because _TestCommonFeatures is ignored
@classmethod
@@ -265,7 +272,7 @@ class TestSpecificCommonH5(unittest.TestCase):
def test_create_unicode_dataset(self):
f = commonh5.File(name="Foo", mode="w")
try:
- f.create_dataset("foo", data=numpy.array(u"aaaa"))
+ f.create_dataset("foo", data=numpy.array("aaaa"))
self.fail()
except TypeError:
pass
diff --git a/src/silx/io/test/test_dictdump.py b/src/silx/io/test/test_dictdump.py
index e31d7a8..2bd376e 100644
--- a/src/silx/io/test/test_dictdump.py
+++ b/src/silx/io/test/test_dictdump.py
@@ -1,5 +1,5 @@
# /*##########################################################################
-# Copyright (C) 2016-2022 European Synchrotron Radiation Facility
+# Copyright (C) 2016-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -27,15 +27,16 @@ __license__ = "MIT"
__date__ = "17/01/2018"
-from collections import defaultdict, OrderedDict
+from collections import defaultdict
from copy import deepcopy
-from io import BytesIO
import os
+import re
import tempfile
import unittest
import h5py
import numpy
+
try:
import pint
except ImportError:
@@ -53,13 +54,6 @@ from ..utils import is_link
from ..utils import h5py_read_dataset
-@pytest.fixture
-def tmp_h5py_file():
- with BytesIO() as buffer:
- with h5py.File(buffer, mode="w") as h5file:
- yield h5file
-
-
def tree():
"""Tree data structure as a recursive nested dictionary"""
return defaultdict(tree)
@@ -82,20 +76,17 @@ link_attrs["links"]["group"]["dataset"] = 10
link_attrs["links"]["group"]["relative_softlink"] = h5py.SoftLink("dataset")
link_attrs["links"]["relative_softlink"] = h5py.SoftLink("group/dataset")
link_attrs["links"]["absolute_softlink"] = h5py.SoftLink("/links/group/dataset")
-link_attrs["links"]["external_link"] = h5py.ExternalLink(ext_filename, "/ext_group/dataset")
+link_attrs["links"]["external_link"] = h5py.ExternalLink(
+ ext_filename, "/ext_group/dataset"
+)
class DictTestCase(unittest.TestCase):
-
def assertRecursiveEqual(self, expected, actual, nodes=tuple()):
err_msg = "\n\n Tree nodes: {}".format(nodes)
if isinstance(expected, dict):
self.assertTrue(isinstance(actual, dict), msg=err_msg)
- self.assertEqual(
- set(expected.keys()),
- set(actual.keys()),
- msg=err_msg
- )
+ self.assertEqual(set(expected.keys()), set(actual.keys()), msg=err_msg)
for k in actual:
self.assertRecursiveEqual(
expected[k],
@@ -112,7 +103,6 @@ class DictTestCase(unittest.TestCase):
class H5DictTestCase(DictTestCase):
-
def _dictRoundTripNormalize(self, treedict):
"""Convert the dictionary as expected from a round-trip
treedict -> dicttoh5 -> h5todict -> newtreedict
@@ -155,12 +145,16 @@ class TestDictToH5(H5DictTestCase):
os.rmdir(self.tempdir)
def testH5CityAttrs(self):
- filters = {'shuffle': True,
- 'fletcher32': True}
- dicttoh5(city_attrs, self.h5_fname, h5path='/city attributes',
- mode="w", create_dataset_args=filters)
+ filters = {"shuffle": True, "fletcher32": True}
+ dicttoh5(
+ city_attrs,
+ self.h5_fname,
+ h5path="/city attributes",
+ mode="w",
+ create_dataset_args=filters,
+ )
- h5f = h5py.File(self.h5_fname, mode='r')
+ h5f = h5py.File(self.h5_fname, mode="r")
self.assertIn("Tourcoing/area", h5f["/city attributes/Europe/France"])
ds = h5f["/city attributes/Europe/France/Grenoble/inhabitants"]
@@ -168,7 +162,7 @@ class TestDictToH5(H5DictTestCase):
# filters only apply to datasets that are not scalars (shape != () )
ds = h5f["/city attributes/Europe/France/Grenoble/coordinates"]
- #self.assertEqual(ds.compression, "gzip")
+ # self.assertEqual(ds.compression, "gzip")
self.assertTrue(ds.fletcher32)
self.assertTrue(ds.shuffle)
@@ -176,25 +170,11 @@ class TestDictToH5(H5DictTestCase):
ddict = load(self.h5_fname, fmat="hdf5")
self.assertAlmostEqual(
- min(ddict["city attributes"]["Europe"]["France"]["Grenoble"]["coordinates"]),
- 5.7196)
-
- def testH5OverwriteDeprecatedApi(self):
- dd = ConfigDict({'t': True})
-
- dicttoh5(h5file=self.h5_fname, treedict=dd, mode='a')
- dd = ConfigDict({'t': False})
- dicttoh5(h5file=self.h5_fname, treedict=dd, mode='a',
- overwrite_data=False)
-
- res = h5todict(self.h5_fname)
- assert(res['t'] == True)
-
- dicttoh5(h5file=self.h5_fname, treedict=dd, mode='a',
- overwrite_data=True)
-
- res = h5todict(self.h5_fname)
- assert(res['t'] == False)
+ min(
+ ddict["city attributes"]["Europe"]["France"]["Grenoble"]["coordinates"]
+ ),
+ 5.7196,
+ )
def testAttributes(self):
"""Any kind of attribute can be described"""
@@ -207,15 +187,15 @@ class TestDictToH5(H5DictTestCase):
}
with h5py.File(self.h5_fname, "w") as h5file:
dictdump.dicttoh5(ddict, h5file)
- self.assertEqual(h5file["group"].attrs['group_attr'], 10)
- self.assertEqual(h5file.attrs['root_attr'], 11)
- self.assertEqual(h5file["dataset"].attrs['dataset_attr'], 12)
- self.assertEqual(h5file["group"].attrs['group_attr2'], 13)
+ self.assertEqual(h5file["group"].attrs["group_attr"], 10)
+ self.assertEqual(h5file.attrs["root_attr"], 11)
+ self.assertEqual(h5file["dataset"].attrs["dataset_attr"], 12)
+ self.assertEqual(h5file["group"].attrs["group_attr2"], 13)
def testPathAttributes(self):
"""A group is requested at a path"""
ddict = {
- ("", "NX_class"): 'NXcollection',
+ ("", "NX_class"): "NXcollection",
}
with h5py.File(self.h5_fname, "w") as h5file:
# This should not warn
@@ -234,8 +214,8 @@ class TestDictToH5(H5DictTestCase):
with h5py.File(self.h5_fname, "w") as h5file:
dictdump.dicttoh5(ddict1, h5file, h5path="g1")
dictdump.dicttoh5(ddict2, h5file, h5path="g2")
- self.assertEqual(h5file["g1/d"].attrs['a'], "ox")
- self.assertEqual(h5file["g2/d"].attrs['a'], "ox")
+ self.assertEqual(h5file["g1/d"].attrs["a"], "ox")
+ self.assertEqual(h5file["g2/d"].attrs["a"], "ox")
def testAttributeValues(self):
"""Any NX data types can be used"""
@@ -269,7 +249,7 @@ class TestDictToH5(H5DictTestCase):
}
with h5py.File(self.h5_fname, "w") as h5file:
dictdump.dicttoh5(ddict, h5file)
- self.assertEqual(h5file["group"].attrs['attr'], 10)
+ self.assertEqual(h5file["group"].attrs["attr"], 10)
def testFlatDict(self):
"""Description of a tree with a single level of keys"""
@@ -281,8 +261,8 @@ class TestDictToH5(H5DictTestCase):
with h5py.File(self.h5_fname, "w") as h5file:
dictdump.dicttoh5(ddict, h5file)
self.assertEqual(h5file["group/group/dataset"][()], 10)
- self.assertEqual(h5file["group/group/dataset"].attrs['attr'], 11)
- self.assertEqual(h5file["group/group"].attrs['attr'], 12)
+ self.assertEqual(h5file["group/group/dataset"].attrs["attr"], 11)
+ self.assertEqual(h5file["group/group"].attrs["attr"], 12)
def testLinks(self):
with h5py.File(self.h5_ext_fname, "w") as h5file:
@@ -298,15 +278,14 @@ class TestDictToH5(H5DictTestCase):
def testDumpNumpyArray(self):
ddict = {
- 'darks': {
- '0': numpy.array([[0, 0, 0], [0, 0, 0]], dtype=numpy.uint16)
- }
+ "darks": {"0": numpy.array([[0, 0, 0], [0, 0, 0]], dtype=numpy.uint16)}
}
with h5py.File(self.h5_fname, "w") as h5file:
dictdump.dicttoh5(ddict, h5file)
with h5py.File(self.h5_fname, "r") as h5file:
- numpy.testing.assert_array_equal(h5py_read_dataset(h5file["darks"]["0"]),
- ddict['darks']['0'])
+ numpy.testing.assert_array_equal(
+ h5py_read_dataset(h5file["darks"]["0"]), ddict["darks"]["0"]
+ )
def testOverwrite(self):
# Tree structure that will be tested
@@ -323,17 +302,17 @@ class TestDictToH5(H5DictTestCase):
"subgroup1": group1.copy(),
"subgroup2": group1.copy(),
("subgroup1", "attr1"): "original1",
- ("subgroup2", "attr1"): "original1"
+ ("subgroup2", "attr1"): "original1",
}
group2.update(group1)
# initial HDF5 tree
otreedict = {
- ('', 'attr1'): "original1",
- ('', 'attr2'): "original2",
- 'group1': group1,
- 'group2': group2,
- ('group1', 'attr1'): "original1",
- ('group2', 'attr1'): "original1"
+ ("", "attr1"): "original1",
+ ("", "attr2"): "original2",
+ "group1": group1,
+ "group2": group2,
+ ("group1", "attr1"): "original1",
+ ("group2", "attr1"): "original1",
}
wtreedict = None # dumped dictionary
etreedict = None # expected HDF5 tree after dump
@@ -346,24 +325,16 @@ class TestDictToH5(H5DictTestCase):
)
def append_file(update_mode):
- dicttoh5(
- wtreedict,
- h5file=self.h5_fname,
- mode="a",
- update_mode=update_mode
- )
+ dicttoh5(wtreedict, h5file=self.h5_fname, mode="a", update_mode=update_mode)
def assert_file():
- rtreedict = h5todict(
- self.h5_fname,
- include_attributes=True,
- asarray=False
- )
+ rtreedict = h5todict(self.h5_fname, include_attributes=True, asarray=False)
netreedict = self.dictRoundTripNormalize(etreedict)
try:
self.assertRecursiveEqual(netreedict, rtreedict)
except AssertionError:
from pprint import pprint
+
print("\nDUMP:")
pprint(wtreedict)
print("\nEXPECTED:")
@@ -379,10 +350,7 @@ class TestDictToH5(H5DictTestCase):
# Test wrong arguments
with self.assertRaises(ValueError):
dicttoh5(
- otreedict,
- h5file=self.h5_fname,
- mode="w",
- update_mode="wrong-value"
+ otreedict, h5file=self.h5_fname, mode="w", update_mode="wrong-value"
)
# No writing
@@ -540,6 +508,13 @@ def test_dicttoh5_pint(tmp_h5py_file):
assert numpy.array_equal(result[key], value.magnitude)
+def test_dicttoh5_not_serializable(tmp_h5py_file):
+ treedict = {"group": {"dset": [{"a": 1}]}}
+ err_msg = "Failed to create dataset '/group/dset' with data (numpy.ndarray-object) = [{'a': 1}]"
+ with pytest.raises(ValueError, match=re.escape(err_msg)):
+ dicttoh5(treedict, tmp_h5py_file)
+
+
class TestH5ToDict(H5DictTestCase):
def setUp(self):
self.tempdir = tempfile.mkdtemp()
@@ -557,8 +532,11 @@ class TestH5ToDict(H5DictTestCase):
os.rmdir(self.tempdir)
def testExcludeNames(self):
- ddict = h5todict(self.h5_fname, path="/Europe/France",
- exclude_names=["ourcoing", "inhab", "toto"])
+ ddict = h5todict(
+ self.h5_fname,
+ path="/Europe/France",
+ exclude_names=["ourcoing", "inhab", "toto"],
+ )
self.assertNotIn("Tourcoing", ddict)
self.assertIn("Grenoble", ddict)
@@ -569,7 +547,9 @@ class TestH5ToDict(H5DictTestCase):
def testAsArrayTrue(self):
"""Test with asarray=True, the default"""
ddict = h5todict(self.h5_fname, path="/Europe/France/Grenoble")
- self.assertTrue(numpy.array_equal(ddict["inhabitants"], numpy.array(inhabitants)))
+ self.assertTrue(
+ numpy.array_equal(ddict["inhabitants"], numpy.array(inhabitants))
+ )
def testAsArrayFalse(self):
"""Test with asarray=False"""
@@ -591,14 +571,16 @@ class TestH5ToDict(H5DictTestCase):
self.assertTrue(is_link(ddict["group"]["relative_softlink"]))
def testStrings(self):
- ddict = {"dset_bytes": b"bytes",
- "dset_utf8": "utf8",
- "dset_2bytes": [b"bytes", b"bytes"],
- "dset_2utf8": ["utf8", "utf8"],
- ("", "attr_bytes"): b"bytes",
- ("", "attr_utf8"): "utf8",
- ("", "attr_2bytes"): [b"bytes", b"bytes"],
- ("", "attr_2utf8"): ["utf8", "utf8"]}
+ ddict = {
+ "dset_bytes": b"bytes",
+ "dset_utf8": "utf8",
+ "dset_2bytes": [b"bytes", b"bytes"],
+ "dset_2utf8": ["utf8", "utf8"],
+ ("", "attr_bytes"): b"bytes",
+ ("", "attr_utf8"): "utf8",
+ ("", "attr_2bytes"): [b"bytes", b"bytes"],
+ ("", "attr_2utf8"): ["utf8", "utf8"],
+ }
dicttoh5(ddict, self.h5_fname, mode="w")
adict = h5todict(self.h5_fname, include_attributes=True, asarray=False)
self.assertEqual(ddict["dset_bytes"], adict["dset_bytes"])
@@ -607,8 +589,12 @@ class TestH5ToDict(H5DictTestCase):
self.assertEqual(ddict[("", "attr_utf8")], adict[("", "attr_utf8")])
numpy.testing.assert_array_equal(ddict["dset_2bytes"], adict["dset_2bytes"])
numpy.testing.assert_array_equal(ddict["dset_2utf8"], adict["dset_2utf8"])
- numpy.testing.assert_array_equal(ddict[("", "attr_2bytes")], adict[("", "attr_2bytes")])
- numpy.testing.assert_array_equal(ddict[("", "attr_2utf8")], adict[("", "attr_2utf8")])
+ numpy.testing.assert_array_equal(
+ ddict[("", "attr_2bytes")], adict[("", "attr_2bytes")]
+ )
+ numpy.testing.assert_array_equal(
+ ddict[("", "attr_2utf8")], adict[("", "attr_2utf8")]
+ )
class TestDictToNx(H5DictTestCase):
@@ -635,10 +621,10 @@ class TestDictToNx(H5DictTestCase):
}
with h5py.File(self.h5_fname, "w") as h5file:
dictdump.dicttonx(ddict, h5file)
- self.assertEqual(h5file["group"].attrs['group_attr1'], 10)
- self.assertEqual(h5file.attrs['root_attr'], 11)
- self.assertEqual(h5file["dataset"].attrs['dataset_attr'], "12")
- self.assertEqual(h5file["group"].attrs['group_attr2'], 13)
+ self.assertEqual(h5file["group"].attrs["group_attr1"], 10)
+ self.assertEqual(h5file.attrs["root_attr"], 11)
+ self.assertEqual(h5file["dataset"].attrs["dataset_attr"], "12")
+ self.assertEqual(h5file["group"].attrs["group_attr2"], 13)
def testKeyOrder(self):
ddict1 = {
@@ -652,8 +638,8 @@ class TestDictToNx(H5DictTestCase):
with h5py.File(self.h5_fname, "w") as h5file:
dictdump.dicttonx(ddict1, h5file, h5path="g1")
dictdump.dicttonx(ddict2, h5file, h5path="g2")
- self.assertEqual(h5file["g1/d"].attrs['a'], "ox")
- self.assertEqual(h5file["g2/d"].attrs['a'], "ox")
+ self.assertEqual(h5file["g1/d"].attrs["a"], "ox")
+ self.assertEqual(h5file["g2/d"].attrs["a"], "ox")
def testAttributeValues(self):
"""Any NX data types can be used"""
@@ -689,16 +675,20 @@ class TestDictToNx(H5DictTestCase):
with h5py.File(self.h5_fname, "w") as h5file:
dictdump.dicttonx(ddict, h5file)
self.assertEqual(h5file["group/group/dataset"][()], 10)
- self.assertEqual(h5file["group/group/dataset"].attrs['attr'], 11)
- self.assertEqual(h5file["group/group"].attrs['attr'], 12)
+ self.assertEqual(h5file["group/group/dataset"].attrs["attr"], 11)
+ self.assertEqual(h5file["group/group"].attrs["attr"], 12)
def testLinks(self):
ddict = {"ext_group": {"dataset": 10}}
dictdump.dicttonx(ddict, self.h5_ext_fname)
- ddict = {"links": {"group": {"dataset": 10, ">relative_softlink": "dataset"},
- ">relative_softlink": "group/dataset",
- ">absolute_softlink": "/links/group/dataset",
- ">external_link": "nx_ext.h5::/ext_group/dataset"}}
+ ddict = {
+ "links": {
+ "group": {"dataset": 10, ">relative_softlink": "dataset"},
+ ">relative_softlink": "group/dataset",
+ ">absolute_softlink": "/links/group/dataset",
+ ">external_link": "nx_ext.h5::/ext_group/dataset",
+ }
+ }
dictdump.dicttonx(ddict, self.h5_fname)
with h5py.File(self.h5_fname, "r") as h5file:
self.assertEqual(h5file["links/group/dataset"][()], 10)
@@ -708,8 +698,14 @@ class TestDictToNx(H5DictTestCase):
self.assertEqual(h5file["links/external_link"][()], 10)
def testUpLinks(self):
- ddict = {"data": {"group": {"dataset": 10, ">relative_softlink": "dataset"}},
- "links": {"group": {"subgroup": {">relative_softlink": "../../../data/group/dataset"}}}}
+ ddict = {
+ "data": {"group": {"dataset": 10, ">relative_softlink": "dataset"}},
+ "links": {
+ "group": {
+ "subgroup": {">relative_softlink": "../../../data/group/dataset"}
+ }
+ },
+ }
dictdump.dicttonx(ddict, self.h5_fname)
with h5py.File(self.h5_fname, "r") as h5file:
self.assertEqual(h5file["/links/group/subgroup/relative_softlink"][()], 10)
@@ -766,7 +762,7 @@ class TestDictToNx(H5DictTestCase):
mode="a",
h5path=entry_name,
update_mode=update_mode,
- add_nx_class=add_nx_class
+ add_nx_class=add_nx_class,
)
def assert_file():
@@ -780,6 +776,7 @@ class TestDictToNx(H5DictTestCase):
self.assertRecursiveEqual(netreedict, rtreedict)
except AssertionError:
from pprint import pprint
+
print("\nDUMP:")
pprint(wtreedict)
print("\nEXPECTED:")
@@ -877,10 +874,14 @@ class TestNxToDict(H5DictTestCase):
"""Write links and dereference on read"""
ddict = {"ext_group": {"dataset": 10}}
dictdump.dicttonx(ddict, self.h5_ext_fname)
- ddict = {"links": {"group": {"dataset": 10, ">relative_softlink": "dataset"},
- ">relative_softlink": "group/dataset",
- ">absolute_softlink": "/links/group/dataset",
- ">external_link": "nx_ext.h5::/ext_group/dataset"}}
+ ddict = {
+ "links": {
+ "group": {"dataset": 10, ">relative_softlink": "dataset"},
+ ">relative_softlink": "group/dataset",
+ ">absolute_softlink": "/links/group/dataset",
+ ">external_link": "nx_ext.h5::/ext_group/dataset",
+ }
+ }
dictdump.dicttonx(ddict, self.h5_fname)
ddict = dictdump.h5todict(self.h5_fname, dereference_links=True)
@@ -893,48 +894,57 @@ class TestNxToDict(H5DictTestCase):
"""Write/read links"""
ddict = {"ext_group": {"dataset": 10}}
dictdump.dicttonx(ddict, self.h5_ext_fname)
- ddict = {"links": {"group": {"dataset": 10, ">relative_softlink": "dataset"},
- ">relative_softlink": "group/dataset",
- ">absolute_softlink": "/links/group/dataset",
- ">external_link": "nx_ext.h5::/ext_group/dataset"}}
+ ddict = {
+ "links": {
+ "group": {"dataset": 10, ">relative_softlink": "dataset"},
+ ">relative_softlink": "group/dataset",
+ ">absolute_softlink": "/links/group/dataset",
+ ">external_link": "nx_ext.h5::/ext_group/dataset",
+ }
+ }
dictdump.dicttonx(ddict, self.h5_fname)
ddict = dictdump.nxtodict(self.h5_fname, dereference_links=False)
self.assertTrue(ddict["links"][">absolute_softlink"], "dataset")
self.assertTrue(ddict["links"][">relative_softlink"], "group/dataset")
self.assertTrue(ddict["links"][">external_link"], "/links/group/dataset")
- self.assertTrue(ddict["links"]["group"][">relative_softlink"], "nx_ext.h5::/ext_group/datase")
+ self.assertTrue(
+ ddict["links"]["group"][">relative_softlink"],
+ "nx_ext.h5::/ext_group/datase",
+ )
def testNotExistingPath(self):
"""Test converting not existing path"""
- with h5py.File(self.h5_fname, 'a') as f:
- f['data'] = 1
+ with h5py.File(self.h5_fname, "a") as f:
+ f["data"] = 1
- ddict = h5todict(self.h5_fname, path="/I/am/not/a/path", errors='ignore')
+ ddict = h5todict(self.h5_fname, path="/I/am/not/a/path", errors="ignore")
self.assertFalse(ddict)
with LoggingValidator(dictdump_logger, error=1):
- ddict = h5todict(self.h5_fname, path="/I/am/not/a/path", errors='log')
+ ddict = h5todict(self.h5_fname, path="/I/am/not/a/path", errors="log")
self.assertFalse(ddict)
with self.assertRaises(KeyError):
- h5todict(self.h5_fname, path="/I/am/not/a/path", errors='raise')
+ h5todict(self.h5_fname, path="/I/am/not/a/path", errors="raise")
def testBrokenLinks(self):
"""Test with broken links"""
- with h5py.File(self.h5_fname, 'a') as f:
+ with h5py.File(self.h5_fname, "a") as f:
f["/Mars/BrokenSoftLink"] = h5py.SoftLink("/Idontexists")
- f["/Mars/BrokenExternalLink"] = h5py.ExternalLink("notexistingfile.h5", "/Idontexists")
+ f["/Mars/BrokenExternalLink"] = h5py.ExternalLink(
+ "notexistingfile.h5", "/Idontexists"
+ )
- ddict = h5todict(self.h5_fname, path="/Mars", errors='ignore')
+ ddict = h5todict(self.h5_fname, path="/Mars", errors="ignore")
self.assertFalse(ddict)
with LoggingValidator(dictdump_logger, error=2):
- ddict = h5todict(self.h5_fname, path="/Mars", errors='log')
+ ddict = h5todict(self.h5_fname, path="/Mars", errors="log")
self.assertFalse(ddict)
with self.assertRaises(KeyError):
- h5todict(self.h5_fname, path="/Mars", errors='raise')
+ h5todict(self.h5_fname, path="/Mars", errors="raise")
class TestDictToJson(DictTestCase):
@@ -968,86 +978,92 @@ class TestDictToIni(DictTestCase):
"""Ensure values and types of data is preserved when dictionary is
written to file and read back."""
testdict = {
- 'simple_types': {
- 'float': 1.0,
- 'int': 1,
- 'percent string': '5 % is too much',
- 'backslash string': 'i can use \\',
- 'empty_string': '',
- 'nonestring': 'None',
- 'nonetype': None,
- 'interpstring': 'interpolation: %(percent string)s',
+ "simple_types": {
+ "float": 1.0,
+ "int": 1,
+ "percent string": "5 % is too much",
+ "backslash string": "i can use \\",
+ "empty_string": "",
+ "nonestring": "None",
+ "nonetype": None,
+ "interpstring": "interpolation: %(percent string)s",
+ },
+ "containers": {
+ "list": [-1, "string", 3.0, False, None],
+ "array": numpy.array([1.0, 2.0, 3.0]),
+ "dict": {
+ "key1": "Hello World",
+ "key2": 2.0,
+ },
},
- 'containers': {
- 'list': [-1, 'string', 3.0, False, None],
- 'array': numpy.array([1.0, 2.0, 3.0]),
- 'dict': {
- 'key1': 'Hello World',
- 'key2': 2.0,
- }
- }
}
dump(testdict, self.ini_fname)
- #read the data back
+ # read the data back
readdict = load(self.ini_fname)
testdictkeys = list(testdict.keys())
readkeys = list(readdict.keys())
- self.assertTrue(len(readkeys) == len(testdictkeys),
- "Number of read keys not equal")
+ self.assertTrue(
+ len(readkeys) == len(testdictkeys), "Number of read keys not equal"
+ )
- self.assertEqual(readdict['simple_types']["interpstring"],
- "interpolation: 5 % is too much")
+ self.assertEqual(
+ readdict["simple_types"]["interpstring"], "interpolation: 5 % is too much"
+ )
- testdict['simple_types']["interpstring"] = "interpolation: 5 % is too much"
+ testdict["simple_types"]["interpstring"] = "interpolation: 5 % is too much"
for key in testdict["simple_types"]:
- original = testdict['simple_types'][key]
- read = readdict['simple_types'][key]
- self.assertEqual(read, original,
- "Read <%s> instead of <%s>" % (read, original))
+ original = testdict["simple_types"][key]
+ read = readdict["simple_types"][key]
+ self.assertEqual(
+ read, original, "Read <%s> instead of <%s>" % (read, original)
+ )
for key in testdict["containers"]:
original = testdict["containers"][key]
read = readdict["containers"][key]
- if key == 'array':
- self.assertEqual(read.all(), original.all(),
- "Read <%s> instead of <%s>" % (read, original))
+ if key == "array":
+ self.assertEqual(
+ read.all(),
+ original.all(),
+ "Read <%s> instead of <%s>" % (read, original),
+ )
else:
- self.assertEqual(read, original,
- "Read <%s> instead of <%s>" % (read, original))
+ self.assertEqual(
+ read, original, "Read <%s> instead of <%s>" % (read, original)
+ )
def testConfigDictOrder(self):
"""Ensure order is preserved when dictionary is
written to file and read back."""
- test_dict = {'banana': 3, 'apple': 4, 'pear': 1, 'orange': 2}
+ test_dict = {"banana": 3, "apple": 4, "pear": 1, "orange": 2}
# sort by key
- test_ordered_dict1 = OrderedDict(sorted(test_dict.items(),
- key=lambda t: t[0]))
+ test_ordered_dict1 = dict(sorted(test_dict.items(), key=lambda t: t[0]))
# sort by value
- test_ordered_dict2 = OrderedDict(sorted(test_dict.items(),
- key=lambda t: t[1]))
+ test_ordered_dict2 = dict(sorted(test_dict.items(), key=lambda t: t[1]))
# add the two ordered dict as sections of a third ordered dict
- test_ordered_dict3 = OrderedDict()
+ test_ordered_dict3 = {}
test_ordered_dict3["section1"] = test_ordered_dict1
test_ordered_dict3["section2"] = test_ordered_dict2
- # write to ini and read back as a ConfigDict (inherits OrderedDict)
- dump(test_ordered_dict3,
- self.ini_fname, fmat="ini")
+ # write to ini and read back as a ConfigDict
+ dump(test_ordered_dict3, self.ini_fname, fmat="ini")
read_instance = ConfigDict()
read_instance.read(self.ini_fname)
# loop through original and read-back dictionaries,
# test identical order for key/value pairs
- for orig_key, section in zip(test_ordered_dict3.keys(),
- read_instance.keys()):
+ for orig_key, section in zip(test_ordered_dict3.keys(), read_instance.keys()):
self.assertEqual(orig_key, section)
- for orig_key2, read_key in zip(test_ordered_dict3[section].keys(),
- read_instance[section].keys()):
+ for orig_key2, read_key in zip(
+ test_ordered_dict3[section].keys(), read_instance[section].keys()
+ ):
self.assertEqual(orig_key2, read_key)
- self.assertEqual(test_ordered_dict3[section][orig_key2],
- read_instance[section][read_key])
+ self.assertEqual(
+ test_ordered_dict3[section][orig_key2],
+ read_instance[section][read_key],
+ )
diff --git a/src/silx/io/test/test_fabioh5.py b/src/silx/io/test/test_fabioh5.py
index fdeb1c3..9c92f15 100755
--- a/src/silx/io/test/test_fabioh5.py
+++ b/src/silx/io/test/test_fabioh5.py
@@ -1,5 +1,5 @@
# /*##########################################################################
-# Copyright (C) 2016-2018 European Synchrotron Radiation Facility
+# Copyright (C) 2016-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -36,6 +36,7 @@ import shutil
_logger = logging.getLogger(__name__)
import fabio
+import fabio.file_series
import h5py
from .. import commonh5
@@ -43,9 +44,7 @@ from .. import fabioh5
class TestFabioH5(unittest.TestCase):
-
def setUp(self):
-
header = {
"integer": "-100",
"float": "1.0",
@@ -191,14 +190,16 @@ class TestFabioH5(unittest.TestCase):
self.assertEqual(dataset[0, 1], 2.0)
def test_metadata_list_looks_like_list(self):
- dataset = self.h5_image["/scan_0/instrument/detector_0/others/string_looks_like_list"]
+ dataset = self.h5_image[
+ "/scan_0/instrument/detector_0/others/string_looks_like_list"
+ ]
self.assertEqual(dataset.h5py_class, h5py.Dataset)
self.assertEqual(dataset[()], numpy.string_("2000 hi!"))
self.assertEqual(dataset.dtype.type, numpy.string_)
self.assertEqual(dataset.shape, (1,))
def test_float_32(self):
- float_list = [u'1.2', u'1.3', u'1.4']
+ float_list = ["1.2", "1.3", "1.4"]
data = numpy.array([[0, 0], [0, 0]], dtype=numpy.int8)
fabio_image = None
for float_item in float_list:
@@ -212,15 +213,22 @@ class TestFabioH5(unittest.TestCase):
# There is no equality between items
self.assertEqual(len(data), len(set(data)))
# At worst a float32
- self.assertIn(data.dtype.kind, ['d', 'f'])
+ self.assertIn(data.dtype.kind, ["d", "f"])
self.assertLessEqual(data.dtype.itemsize, 32 / 8)
def test_float_64(self):
float_list = [
- u'1469117129.082226',
- u'1469117136.684986', u'1469117144.312749', u'1469117151.892507',
- u'1469117159.474265', u'1469117167.100027', u'1469117174.815799',
- u'1469117182.437561', u'1469117190.094326', u'1469117197.721089']
+ "1469117129.082226",
+ "1469117136.684986",
+ "1469117144.312749",
+ "1469117151.892507",
+ "1469117159.474265",
+ "1469117167.100027",
+ "1469117174.815799",
+ "1469117182.437561",
+ "1469117190.094326",
+ "1469117197.721089",
+ ]
data = numpy.array([[0, 0], [0, 0]], dtype=numpy.int8)
fabio_image = None
for float_item in float_list:
@@ -234,12 +242,12 @@ class TestFabioH5(unittest.TestCase):
# There is no equality between items
self.assertEqual(len(data), len(set(data)))
# At least a float64
- self.assertIn(data.dtype.kind, ['d', 'f'])
+ self.assertIn(data.dtype.kind, ["d", "f"])
self.assertGreaterEqual(data.dtype.itemsize, 64 / 8)
def test_mixed_float_size__scalar(self):
# We expect to have a precision of 32 bits
- float_list = [u'1.2', u'1.3001']
+ float_list = ["1.2", "1.3001"]
expected_float_result = [1.2, 1.3001]
data = numpy.array([[0, 0], [0, 0]], dtype=numpy.int8)
fabio_image = None
@@ -252,14 +260,14 @@ class TestFabioH5(unittest.TestCase):
h5_image = fabioh5.File(fabio_image=fabio_image)
data = h5_image["/scan_0/instrument/detector_0/others/float_item"]
# At worst a float32
- self.assertIn(data.dtype.kind, ['d', 'f'])
+ self.assertIn(data.dtype.kind, ["d", "f"])
self.assertLessEqual(data.dtype.itemsize, 32 / 8)
for computed, expected in zip(data, expected_float_result):
numpy.testing.assert_almost_equal(computed, expected, 5)
def test_mixed_float_size__list(self):
# We expect to have a precision of 32 bits
- float_list = [u'1.2 1.3001']
+ float_list = ["1.2 1.3001"]
expected_float_result = numpy.array([[1.2, 1.3001]])
data = numpy.array([[0, 0], [0, 0]], dtype=numpy.int8)
fabio_image = None
@@ -272,14 +280,14 @@ class TestFabioH5(unittest.TestCase):
h5_image = fabioh5.File(fabio_image=fabio_image)
data = h5_image["/scan_0/instrument/detector_0/others/float_item"]
# At worst a float32
- self.assertIn(data.dtype.kind, ['d', 'f'])
+ self.assertIn(data.dtype.kind, ["d", "f"])
self.assertLessEqual(data.dtype.itemsize, 32 / 8)
for computed, expected in zip(data, expected_float_result):
numpy.testing.assert_almost_equal(computed, expected, 5)
def test_mixed_float_size__list_of_list(self):
# We expect to have a precision of 32 bits
- float_list = [u'1.2 1.3001', u'1.3001 1.3001']
+ float_list = ["1.2 1.3001", "1.3001 1.3001"]
expected_float_result = numpy.array([[1.2, 1.3001], [1.3001, 1.3001]])
data = numpy.array([[0, 0], [0, 0]], dtype=numpy.int8)
fabio_image = None
@@ -292,7 +300,7 @@ class TestFabioH5(unittest.TestCase):
h5_image = fabioh5.File(fabio_image=fabio_image)
data = h5_image["/scan_0/instrument/detector_0/others/float_item"]
# At worst a float32
- self.assertIn(data.dtype.kind, ['d', 'f'])
+ self.assertIn(data.dtype.kind, ["d", "f"])
self.assertLessEqual(data.dtype.itemsize, 32 / 8)
for computed, expected in zip(data, expected_float_result):
numpy.testing.assert_almost_equal(computed, expected, 5)
@@ -300,10 +308,12 @@ class TestFabioH5(unittest.TestCase):
def test_ub_matrix(self):
"""Data from mediapix.edf"""
header = {}
- header["UB_mne"] = 'UB0 UB1 UB2 UB3 UB4 UB5 UB6 UB7 UB8'
- header["UB_pos"] = '1.99593e-16 2.73682e-16 -1.54 -1.08894 1.08894 1.6083e-16 1.08894 1.08894 9.28619e-17'
- header["sample_mne"] = 'U0 U1 U2 U3 U4 U5'
- header["sample_pos"] = '4.08 4.08 4.08 90 90 90'
+ header["UB_mne"] = "UB0 UB1 UB2 UB3 UB4 UB5 UB6 UB7 UB8"
+ header[
+ "UB_pos"
+ ] = "1.99593e-16 2.73682e-16 -1.54 -1.08894 1.08894 1.6083e-16 1.08894 1.08894 9.28619e-17"
+ header["sample_mne"] = "U0 U1 U2 U3 U4 U5"
+ header["sample_pos"] = "4.08 4.08 4.08 90 90 90"
data = numpy.array([[0, 0], [0, 0]], dtype=numpy.int8)
fabio_image = fabio.edfimage.EdfImage(data=data, header=header)
h5_image = fabioh5.File(fabio_image=fabio_image)
@@ -311,27 +321,33 @@ class TestFabioH5(unittest.TestCase):
self.assertIsNotNone(sample)
self.assertEqual(sample.attrs["NXclass"], "NXsample")
- d = sample['unit_cell_abc']
+ d = sample["unit_cell_abc"]
expected = numpy.array([4.08, 4.08, 4.08])
self.assertIsNotNone(d)
- self.assertEqual(d.shape, (3, ))
- self.assertIn(d.dtype.kind, ['d', 'f'])
+ self.assertEqual(d.shape, (3,))
+ self.assertIn(d.dtype.kind, ["d", "f"])
numpy.testing.assert_array_almost_equal(d[...], expected)
- d = sample['unit_cell_alphabetagamma']
+ d = sample["unit_cell_alphabetagamma"]
expected = numpy.array([90.0, 90.0, 90.0])
self.assertIsNotNone(d)
- self.assertEqual(d.shape, (3, ))
- self.assertIn(d.dtype.kind, ['d', 'f'])
+ self.assertEqual(d.shape, (3,))
+ self.assertIn(d.dtype.kind, ["d", "f"])
numpy.testing.assert_array_almost_equal(d[...], expected)
- d = sample['ub_matrix']
- expected = numpy.array([[[1.99593e-16, 2.73682e-16, -1.54],
- [-1.08894, 1.08894, 1.6083e-16],
- [1.08894, 1.08894, 9.28619e-17]]])
+ d = sample["ub_matrix"]
+ expected = numpy.array(
+ [
+ [
+ [1.99593e-16, 2.73682e-16, -1.54],
+ [-1.08894, 1.08894, 1.6083e-16],
+ [1.08894, 1.08894, 9.28619e-17],
+ ]
+ ]
+ )
self.assertIsNotNone(d)
self.assertEqual(d.shape, (1, 3, 3))
- self.assertIn(d.dtype.kind, ['d', 'f'])
+ self.assertIn(d.dtype.kind, ["d", "f"])
numpy.testing.assert_array_almost_equal(d[...], expected)
def test_interpretation_mca_edf(self):
@@ -341,7 +357,8 @@ class TestFabioH5(unittest.TestCase):
"Title": "zapimage samy -4.975 -5.095 80 500 samz -4.091 -4.171 70 0",
"MCA a": -23.812,
"MCA b": 2.7107,
- "MCA c": 8.1164e-06}
+ "MCA c": 8.1164e-06,
+ }
data = numpy.array([[0, 0], [0, 0]], dtype=numpy.int8)
fabio_image = fabio.edfimage.EdfImage(data=data, header=header)
@@ -371,7 +388,9 @@ class TestFabioH5(unittest.TestCase):
detector2 = self.h5_image["/scan_0/measurement/image_0/info"]
self.assertIsNot(detector1, detector2)
self.assertEqual(list(detector1.items()), list(detector2.items()))
- self.assertEqual(self.h5_image.get(detector2.name, getlink=True).path, detector1.name)
+ self.assertEqual(
+ self.h5_image.get(detector2.name, getlink=True).path, detector1.name
+ )
def test_detector_data_link(self):
data1 = self.h5_image["/scan_0/instrument/detector_0/data"]
@@ -384,11 +403,11 @@ class TestFabioH5(unittest.TestCase):
"""Test that it does not fail"""
try:
header = {}
- header["foo"] = b'abc'
+ header["foo"] = b"abc"
data = numpy.array([[0, 0], [0, 0]], dtype=numpy.int8)
fabio_image = fabio.edfimage.edfimage(data=data, header=header)
header = {}
- header["foo"] = b'a\x90bc\xFE'
+ header["foo"] = b"a\x90bc\xFE"
fabio_image.append_frame(data=data, header=header)
except Exception as e:
_logger.error(e.args[0])
@@ -405,11 +424,11 @@ class TestFabioH5(unittest.TestCase):
"""Test that it does not fail"""
try:
header = {}
- header["foo"] = b'abc'
+ header["foo"] = b"abc"
data = numpy.array([[0, 0], [0, 0]], dtype=numpy.int8)
fabio_image = fabio.edfimage.edfimage(data=data, header=header)
header = {}
- header["foo"] = u'abc\u2764'
+ header["foo"] = "abc\u2764"
fabio_image.append_frame(data=data, header=header)
except Exception as e:
_logger.error(e.args[0])
@@ -424,13 +443,10 @@ class TestFabioH5(unittest.TestCase):
class TestFabioH5MultiFrames(unittest.TestCase):
-
@classmethod
def setUpClass(cls):
-
names = ["A", "B", "C", "D"]
- values = [["32000", "-10", "5.0", "1"],
- ["-32000", "-10", "5.0", "1"]]
+ values = [["32000", "-10", "5.0", "1"], ["-32000", "-10", "5.0", "1"]]
fabio_file = None
@@ -446,7 +462,7 @@ class TestFabioH5MultiFrames(unittest.TestCase):
"motor_mne": " ".join(names),
"motor_pos": " ".join(values[i % len(values)]),
"counter_mne": " ".join(names),
- "counter_pos": " ".join(values[i % len(values)])
+ "counter_pos": " ".join(values[i % len(values)]),
}
for iname, name in enumerate(names):
header[name] = values[i % len(values)][iname]
@@ -509,10 +525,8 @@ class TestFabioH5MultiFrames(unittest.TestCase):
class TestFabioH5WithEdf(unittest.TestCase):
-
@classmethod
def setUpClass(cls):
-
cls.tmp_directory = tempfile.mkdtemp()
cls.edf_filename = os.path.join(cls.tmp_directory, "test.edf")
@@ -550,15 +564,14 @@ class TestFabioH5WithEdf(unittest.TestCase):
class _TestableFrameData(fabioh5.FrameData):
"""Allow to test if the full data is reached."""
+
def _create_data(self):
raise RuntimeError("Not supposed to be called")
class TestFabioH5WithFileSeries(unittest.TestCase):
-
@classmethod
def setUpClass(cls):
-
cls.tmp_directory = tempfile.mkdtemp()
cls.edf_filenames = []
@@ -602,12 +615,12 @@ class TestFabioH5WithFileSeries(unittest.TestCase):
self._testH5Image(h5_image)
def testFileSeries(self):
- file_series = fabioh5._FileSeries(self.edf_filenames)
+ file_series = fabio.file_series.file_series(self.edf_filenames)
h5_image = fabioh5.File(file_series=file_series)
self._testH5Image(h5_image)
def testFrameDataCache(self):
- file_series = fabioh5._FileSeries(self.edf_filenames)
+ file_series = fabio.file_series.file_series(self.edf_filenames)
reader = fabioh5.FabioReader(file_series=file_series)
frameData = _TestableFrameData("foo", reader)
self.assertEqual(frameData.dtype.kind, "i")
diff --git a/src/silx/io/test/test_fioh5.py b/src/silx/io/test/test_fioh5.py
index 18200c9..fed22a2 100644
--- a/src/silx/io/test/test_fioh5.py
+++ b/src/silx/io/test/test_fioh5.py
@@ -23,19 +23,11 @@
"""Tests for fioh5"""
import numpy
import os
-import io
-import sys
import tempfile
import unittest
-import datetime
-import logging
-from silx.utils import testutils
+from ..fioh5 import FioH5, is_fiofile, logger1, dtypeConverter
-from .. import fioh5
-from ..fioh5 import (FioH5, FioH5NodeDataset, is_fiofile, logger1, dtypeConverter)
-
-import h5py
__authors__ = ["T. Fuchs"]
__license__ = "MIT"
@@ -80,15 +72,14 @@ ScanName = ascan
"""
-
class TestFioH5(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.temp_dir = tempfile.TemporaryDirectory()
- #fd, cls.fname = tempfile.mkstemp()
+ # fd, cls.fname = tempfile.mkstemp()
cls.fname_numbered = os.path.join(cls.temp_dir.name, "eh1scan_00005.fio")
-
- with open(cls.fname_numbered, 'w') as fiof:
+
+ with open(cls.fname_numbered, "w") as fiof:
fiof.write(fioftext)
@classmethod
@@ -98,10 +89,10 @@ class TestFioH5(unittest.TestCase):
def setUp(self):
self.fioh5 = FioH5(self.fname_numbered)
-
+
def tearDown(self):
self.fioh5.close()
-
+
def testScanNumber(self):
# scan number is parsed from the file name.
self.assertIn("/5.1", self.fioh5)
@@ -121,7 +112,7 @@ class TestFioH5(unittest.TestCase):
self.assertNotIn("/5.1/measurement/omega(encoder)/", self.fioh5)
# No gamma
self.assertNotIn("/5.1/measurement/gamma", self.fioh5)
-
+
def testContainsGroup(self):
self.assertIn("measurement", self.fioh5["/5.1/"])
self.assertIn("measurement", self.fioh5["/5.1"])
@@ -129,99 +120,101 @@ class TestFioH5(unittest.TestCase):
self.assertNotIn("5.2", self.fioh5["/"])
self.assertIn("measurement/filename", self.fioh5["/5.1"])
# illegal trailing "/" after dataset name
- self.assertNotIn("measurement/filename/",
- self.fioh5["/5.1"])
+ self.assertNotIn("measurement/filename/", self.fioh5["/5.1"])
# full path to element in group (OK)
- self.assertIn("/5.1/measurement/filename",
- self.fioh5["/5.1/measurement"])
-
+ self.assertIn("/5.1/measurement/filename", self.fioh5["/5.1/measurement"])
+
def testDataType(self):
meas = self.fioh5["/5.1/measurement/"]
- self.assertEqual(meas["omega(encoder)"].dtype, dtypeConverter['DOUBLE'])
- self.assertEqual(meas["channel"].dtype, dtypeConverter['INTEGER'])
- self.assertEqual(meas["filename"].dtype, dtypeConverter['STRING'])
- self.assertEqual(meas["time_s"].dtype, dtypeConverter['FLOAT'])
- self.assertEqual(meas["enable"].dtype, dtypeConverter['BOOLEAN'])
-
+ self.assertEqual(meas["omega(encoder)"].dtype, dtypeConverter["DOUBLE"])
+ self.assertEqual(meas["channel"].dtype, dtypeConverter["INTEGER"])
+ self.assertEqual(meas["filename"].dtype, dtypeConverter["STRING"])
+ self.assertEqual(meas["time_s"].dtype, dtypeConverter["FLOAT"])
+ self.assertEqual(meas["enable"].dtype, dtypeConverter["BOOLEAN"])
+
def testDataColumn(self):
- self.assertAlmostEqual(sum(self.fioh5["/5.1/measurement/omega(encoder)"]),
- 1802.23418821)
+ self.assertAlmostEqual(
+ sum(self.fioh5["/5.1/measurement/omega(encoder)"]), 1802.23418821
+ )
self.assertTrue(numpy.all(self.fioh5["/5.1/measurement/enable"]))
-
+
# --- comment section tests ---
-
+
def testComment(self):
# should hold the complete comment section
- self.assertEqual(self.fioh5["/5.1/instrument/fiofile/comments"],
-"""ascan omega 180.0 180.5 3:10/1 4
+ self.assertEqual(
+ self.fioh5["/5.1/instrument/fiofile/comments"],
+ """ascan omega 180.0 180.5 3:10/1 4
user username, acquisition started at Thu Dec 12 18:00:00 2021
sweep motor lag: 1.0e-03
channel 3: Detector
-""")
-
+""",
+ )
+
def testDate(self):
# there is no convention on how to format the time. So just check its existence.
- self.assertEqual(self.fioh5["/5.1/start_time"],
- u"Thu Dec 12 18:00:00 2021")
-
+ self.assertEqual(self.fioh5["/5.1/start_time"], "Thu Dec 12 18:00:00 2021")
+
def testTitle(self):
- self.assertEqual(self.fioh5["/5.1/title"],
- u"ascan omega 180.0 180.5 3:10/1 4")
-
-
+ self.assertEqual(self.fioh5["/5.1/title"], "ascan omega 180.0 180.5 3:10/1 4")
+
# --- parameter section tests ---
-
+
def testParameter(self):
# should hold the complete parameter section
- self.assertEqual(self.fioh5["/5.1/instrument/fiofile/parameter"],
-"""channel3_exposure = 1.000000e+00
+ self.assertEqual(
+ self.fioh5["/5.1/instrument/fiofile/parameter"],
+ """channel3_exposure = 1.000000e+00
ScanName = ascan
-""")
-
+""",
+ )
+
def testParsedParameter(self):
# no dtype is given, so everything is str.
- self.assertEqual(self.fioh5["/5.1/instrument/parameter/channel3_exposure"],
- u"1.000000e+00")
- self.assertEqual(self.fioh5["/5.1/instrument/parameter/ScanName"], u"ascan")
-
+ self.assertEqual(
+ self.fioh5["/5.1/instrument/parameter/channel3_exposure"], "1.000000e+00"
+ )
+ self.assertEqual(self.fioh5["/5.1/instrument/parameter/ScanName"], "ascan")
+
def testNotFioH5(self):
testfilename = os.path.join(self.temp_dir.name, "eh1scan_00010.fio")
- with open(testfilename, 'w') as fiof:
+ with open(testfilename, "w") as fiof:
fiof.write("!Not a fio file!")
self.assertRaises(IOError, FioH5, testfilename)
-
+
self.assertTrue(is_fiofile(self.fname_numbered))
self.assertFalse(is_fiofile(testfilename))
-
+
os.unlink(testfilename)
-
+
class TestUnnumberedFioH5(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.temp_dir = tempfile.TemporaryDirectory()
cls.fname_nosuffix = os.path.join(cls.temp_dir.name, "eh1scan_nosuffix.fio")
-
- with open(cls.fname_nosuffix, 'w') as fiof:
+
+ with open(cls.fname_nosuffix, "w") as fiof:
fiof.write(fioftext)
@classmethod
def tearDownClass(cls):
cls.temp_dir.cleanup()
del cls.temp_dir
-
+
def setUp(self):
self.fioh5 = FioH5(self.fname_nosuffix)
-
+
def testLogMissingScanno(self):
- with self.assertLogs(logger1,level='WARNING') as cm:
+ with self.assertLogs(logger1, level="WARNING") as cm:
fioh5 = FioH5(self.fname_nosuffix)
self.assertIn("Cannot parse scan number of file", cm.output[0])
-
+
def testFallbackName(self):
self.assertIn("/eh1scan_nosuffix", self.fioh5)
-
+
+
brokenHeaderText = """
!
! Comments
@@ -258,41 +251,46 @@ ScanName = ascan
180.448418821 3 00010 exposure 1576165750.20308
"""
+
class TestBrokenHeaderFioH5(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.temp_dir = tempfile.TemporaryDirectory()
cls.fname_numbered = os.path.join(cls.temp_dir.name, "eh1scan_00005.fio")
-
- with open(cls.fname_numbered, 'w') as fiof:
+
+ with open(cls.fname_numbered, "w") as fiof:
fiof.write(brokenHeaderText)
@classmethod
def tearDownClass(cls):
cls.temp_dir.cleanup()
del cls.temp_dir
-
+
def setUp(self):
self.fioh5 = FioH5(self.fname_numbered)
-
+
def testLogBrokenHeader(self):
- with self.assertLogs(logger1,level='WARNING') as cm:
+ with self.assertLogs(logger1, level="WARNING") as cm:
fioh5 = FioH5(self.fname_numbered)
self.assertIn("Cannot parse parameter section", cm.output[0])
self.assertIn("Cannot parse default comment section", cm.output[1])
-
+
def testComment(self):
# should hold the complete comment section
- self.assertEqual(self.fioh5["/5.1/instrument/fiofile/comments"],
-"""ascan omega 180.0 180.5 3:10/1 4
+ self.assertEqual(
+ self.fioh5["/5.1/instrument/fiofile/comments"],
+ """ascan omega 180.0 180.5 3:10/1 4
user username, acquisited at Thu Dec 12 100 2021
sweep motor lavgvf.0e-03
channel 3: Detector
-""")
+""",
+ )
def testParameter(self):
# should hold the complete parameter section
- self.assertEqual(self.fioh5["/5.1/instrument/fiofile/parameter"],
-"""channel3_exposu65 1.000000e+00
+ self.assertEqual(
+ self.fioh5["/5.1/instrument/fiofile/parameter"],
+ """channel3_exposu65 1.000000e+00
ScanName = ascan
-""")
+""",
+ )
diff --git a/src/silx/io/test/test_h5link_utils.py b/src/silx/io/test/test_h5link_utils.py
new file mode 100644
index 0000000..4140003
--- /dev/null
+++ b/src/silx/io/test/test_h5link_utils.py
@@ -0,0 +1,116 @@
+import os
+import pytest
+import h5py
+import numpy
+from silx.io import open
+from silx.io import h5link_utils
+
+
+@pytest.fixture(scope="module")
+def hdf5_with_external_data(tmpdir_factory):
+ tmpdir = tmpdir_factory.mktemp("hdf5_with_external_data")
+ master = str(tmpdir / "master.h5")
+ external_h5 = str(tmpdir / "external.h5")
+ external_raw = str(tmpdir / "external.raw")
+
+ data = numpy.array([100, 1000, 10000], numpy.uint16)
+ tshape = (1,) + data.shape
+
+ with h5py.File(master, "w") as fmaster:
+ dset = fmaster.create_dataset("data", data=data)
+
+ fmaster["int"] = h5py.SoftLink("data")
+
+ layout = h5py.VirtualLayout(shape=tshape, dtype=data.dtype)
+ layout[0] = h5py.VirtualSource(".", "data", shape=data.shape)
+ fmaster.create_virtual_dataset("vds0", layout)
+
+ with h5py.File(external_h5, "w") as f:
+ dset = f.create_dataset("data", data=data)
+ layout = h5py.VirtualLayout(shape=tshape, dtype=data.dtype)
+ layout[0] = h5py.VirtualSource(dset)
+ fmaster.create_virtual_dataset("vds1", layout)
+
+ layout = h5py.VirtualLayout(shape=tshape, dtype=data.dtype)
+ layout[0] = h5py.VirtualSource(
+ external_h5,
+ "data",
+ shape=data.shape,
+ )
+ fmaster.create_virtual_dataset("vds2", layout)
+ fmaster["ext1"] = h5py.ExternalLink(external_h5, "data")
+
+ layout = h5py.VirtualLayout(shape=tshape, dtype=data.dtype)
+ layout[0] = h5py.VirtualSource(
+ "external.h5",
+ "data",
+ shape=data.shape,
+ )
+ fmaster.create_virtual_dataset("vds3", layout)
+ fmaster["ext2"] = h5py.ExternalLink("external.h5", "data")
+
+ layout = h5py.VirtualLayout(shape=tshape, dtype=data.dtype)
+ layout[0] = h5py.VirtualSource(
+ "./external.h5",
+ "data",
+ shape=data.shape,
+ )
+ fmaster.create_virtual_dataset("vds4", layout)
+ fmaster["ext3"] = h5py.ExternalLink("./external.h5", "data")
+
+ data.tofile(external_raw)
+
+ external = [(external_raw, 0, 16 * 3)]
+ fmaster.create_dataset(
+ "raw1", external=external, shape=tshape, dtype=data.dtype
+ )
+
+ external = [("external.raw", 0, 16 * 3)]
+ fmaster.create_dataset(
+ "raw2", external=external, shape=tshape, dtype=data.dtype
+ )
+
+ external = [("./external.raw", 0, 16 * 3)]
+ fmaster.create_dataset(
+ "raw3", external=external, shape=tshape, dtype=data.dtype
+ )
+
+ # Validate links
+ expected = data.tolist()
+ cwd = os.getcwd()
+ with h5py.File(master, "r") as master:
+ for name in master:
+ if name in ("raw2", "raw3"):
+ os.chdir(str(tmpdir))
+ try:
+ data = master[name][()].flatten().tolist()
+ except Exception:
+ assert False, name
+ finally:
+ if name in ("raw2", "raw3"):
+ os.chdir(cwd)
+ assert data == expected, name
+
+ return tmpdir
+
+
+@pytest.mark.skipif("VirtualLayout" not in dir(h5py), reason="h5py is too old")
+def test_external_dataset_info(hdf5_with_external_data):
+ tmpdir = hdf5_with_external_data
+ master = str(tmpdir / "master.h5")
+ external_h5 = str(tmpdir / "external.h5")
+ external_raw = str(tmpdir / "external.raw")
+ with open(master) as f:
+ for name in f:
+ hdf5obj = f[name]
+ info = h5link_utils.external_dataset_info(hdf5obj)
+ if name in ("data", "int", "ext1", "ext2", "ext3"):
+ assert info is None, name
+ elif name == "vds0":
+ assert info.first_source_url == f"{master}::/data"
+ elif name in ("vds1", "vds2", "vds3", "vds4"):
+ assert info.first_source_url == f"{external_h5}::/data"
+ elif name in ("raw1", "raw2", "raw3"):
+ assert info.first_source_url == external_raw
+ else:
+ assert False, name
diff --git a/src/silx/io/test/test_nxdata.py b/src/silx/io/test/test_nxdata.py
index 52a2b8a..1c64a71 100644
--- a/src/silx/io/test/test_nxdata.py
+++ b/src/silx/io/test/test_nxdata.py
@@ -43,7 +43,9 @@ text_dtype = h5py.special_dtype(vlen=str)
class TestNXdata(unittest.TestCase):
def setUp(self):
- tmp = tempfile.NamedTemporaryFile(prefix="nxdata_examples_", suffix=".h5", delete=True)
+ tmp = tempfile.NamedTemporaryFile(
+ prefix="nxdata_examples_", suffix=".h5", delete=True
+ )
tmp.file.close()
self.h5fname = tmp.name
self.h5f = h5py.File(tmp.name, "w")
@@ -66,7 +68,9 @@ class TestNXdata(unittest.TestCase):
g0d1 = g0d.create_group("4D_scalars")
g0d1.attrs["NX_class"] = "NXdata"
g0d1.attrs["signal"] = "scalars"
- ds = g0d1.create_dataset("scalars", data=numpy.arange(2 * 2 * 3 * 10).reshape((2, 2, 3, 10)))
+ ds = g0d1.create_dataset(
+ "scalars", data=numpy.arange(2 * 2 * 3 * 10).reshape((2, 2, 3, 10))
+ )
ds.attrs["interpretation"] = "scalar"
# SPECTRA
@@ -75,11 +79,16 @@ class TestNXdata(unittest.TestCase):
g1d0 = g1d.create_group("1D_spectrum")
g1d0.attrs["NX_class"] = "NXdata"
g1d0.attrs["signal"] = "count"
- g1d0.attrs["auxiliary_signals"] = numpy.array(["count2", "count3"],
- dtype=text_dtype)
+ g1d0.attrs["auxiliary_signals"] = numpy.array(
+ ["count2", "count3"], dtype=text_dtype
+ )
g1d0.attrs["axes"] = "energy_calib"
- g1d0.attrs["uncertainties"] = numpy.array(["energy_errors", ],
- dtype=text_dtype)
+ g1d0.attrs["uncertainties"] = numpy.array(
+ [
+ "energy_errors",
+ ],
+ dtype=text_dtype,
+ )
g1d0.create_dataset("count", data=numpy.arange(10))
g1d0.create_dataset("count2", data=0.5 * numpy.arange(10))
d = g1d0.create_dataset("count3", data=0.4 * numpy.arange(10))
@@ -97,12 +106,20 @@ class TestNXdata(unittest.TestCase):
g1d2 = g1d.create_group("4D_spectra")
g1d2.attrs["NX_class"] = "NXdata"
g1d2.attrs["signal"] = "counts"
- g1d2.attrs["axes"] = numpy.array(["energy", ], dtype=text_dtype)
- ds = g1d2.create_dataset("counts", data=numpy.arange(2 * 2 * 3 * 10).reshape((2, 2, 3, 10)))
+ g1d2.attrs["axes"] = numpy.array(
+ [
+ "energy",
+ ],
+ dtype=text_dtype,
+ )
+ ds = g1d2.create_dataset(
+ "counts", data=numpy.arange(2 * 2 * 3 * 10).reshape((2, 2, 3, 10))
+ )
ds.attrs["interpretation"] = "spectrum"
ds = g1d2.create_dataset("errors", data=4.5 * numpy.random.rand(2, 2, 3, 10))
- ds = g1d2.create_dataset("energy", data=5 + 10 * numpy.arange(15),
- shuffle=True, compression="gzip")
+ ds = g1d2.create_dataset(
+ "energy", data=5 + 10 * numpy.arange(15), shuffle=True, compression="gzip"
+ )
ds.attrs["long_name"] = "Calibrated energy"
ds.attrs["first_good"] = 3
ds.attrs["last_good"] = 12
@@ -115,8 +132,9 @@ class TestNXdata(unittest.TestCase):
g2d0.attrs["NX_class"] = "NXdata"
g2d0.attrs["signal"] = "image"
g2d0.attrs["auxiliary_signals"] = "image2"
- g2d0.attrs["axes"] = numpy.array(["rows_calib", "columns_coordinates"],
- dtype=text_dtype)
+ g2d0.attrs["axes"] = numpy.array(
+ ["rows_calib", "columns_coordinates"], dtype=text_dtype
+ )
g2d0.create_dataset("image", data=numpy.arange(4 * 6).reshape((4, 6)))
g2d0.create_dataset("image2", data=numpy.arange(4 * 6).reshape((4, 6)))
ds = g2d0.create_dataset("rows_calib", data=(10, 5))
@@ -127,24 +145,34 @@ class TestNXdata(unittest.TestCase):
g2d1.attrs["NX_class"] = "NXdata"
g2d1.attrs["signal"] = "data"
g2d1.attrs["title"] = "Title as group attr"
- g2d1.attrs["axes"] = numpy.array(["rows_coordinates", "columns_coordinates"],
- dtype=text_dtype)
+ g2d1.attrs["axes"] = numpy.array(
+ ["rows_coordinates", "columns_coordinates"], dtype=text_dtype
+ )
g2d1.create_dataset("data", data=numpy.arange(64 * 128).reshape((64, 128)))
- g2d1.create_dataset("rows_coordinates", data=numpy.arange(64) + numpy.random.rand(64))
- g2d1.create_dataset("columns_coordinates", data=numpy.arange(128) + 2.5 * numpy.random.rand(128))
+ g2d1.create_dataset(
+ "rows_coordinates", data=numpy.arange(64) + numpy.random.rand(64)
+ )
+ g2d1.create_dataset(
+ "columns_coordinates", data=numpy.arange(128) + 2.5 * numpy.random.rand(128)
+ )
g2d2 = g2d.create_group("3D_images")
g2d2.attrs["NX_class"] = "NXdata"
g2d2.attrs["signal"] = "images"
- ds = g2d2.create_dataset("images", data=numpy.arange(2 * 4 * 6).reshape((2, 4, 6)))
+ ds = g2d2.create_dataset(
+ "images", data=numpy.arange(2 * 4 * 6).reshape((2, 4, 6))
+ )
ds.attrs["interpretation"] = "image"
g2d3 = g2d.create_group("5D_images")
g2d3.attrs["NX_class"] = "NXdata"
g2d3.attrs["signal"] = "images"
- g2d3.attrs["axes"] = numpy.array(["rows_coordinates", "columns_coordinates"],
- dtype=text_dtype)
- ds = g2d3.create_dataset("images", data=numpy.arange(2 * 2 * 2 * 4 * 6).reshape((2, 2, 2, 4, 6)))
+ g2d3.attrs["axes"] = numpy.array(
+ ["rows_coordinates", "columns_coordinates"], dtype=text_dtype
+ )
+ ds = g2d3.create_dataset(
+ "images", data=numpy.arange(2 * 2 * 2 * 4 * 6).reshape((2, 2, 2, 4, 6))
+ )
ds.attrs["interpretation"] = "image"
g2d3.create_dataset("rows_coordinates", data=5 + 10 * numpy.arange(4))
g2d3.create_dataset("columns_coordinates", data=0.5 + 0.02 * numpy.arange(6))
@@ -152,15 +180,18 @@ class TestNXdata(unittest.TestCase):
g2d4 = g2d.create_group("RGBA_image")
g2d4.attrs["NX_class"] = "NXdata"
g2d4.attrs["signal"] = "image"
- g2d4.attrs["axes"] = numpy.array(["rows_calib", "columns_coordinates"],
- dtype=text_dtype)
- rgba_image = numpy.linspace(0, 1, num=7*8*3).reshape((7, 8, 3))
- rgba_image[:, :, 1] = 1 - rgba_image[:, :, 1] # invert G channel to add some color
+ g2d4.attrs["axes"] = numpy.array(
+ ["rows_calib", "columns_coordinates"], dtype=text_dtype
+ )
+ rgba_image = numpy.linspace(0, 1, num=7 * 8 * 3).reshape((7, 8, 3))
+ rgba_image[:, :, 1] = (
+ 1 - rgba_image[:, :, 1]
+ ) # invert G channel to add some color
ds = g2d4.create_dataset("image", data=rgba_image)
ds.attrs["interpretation"] = "rgba-image"
ds = g2d4.create_dataset("rows_calib", data=(10, 5))
ds.attrs["long_name"] = "Calibrated Y"
- g2d4.create_dataset("columns_coordinates", data=0.5+0.02*numpy.arange(8))
+ g2d4.create_dataset("columns_coordinates", data=0.5 + 0.02 * numpy.arange(8))
# SCATTER
g = self.h5f.create_group("scatters")
@@ -168,7 +199,12 @@ class TestNXdata(unittest.TestCase):
gd0 = g.create_group("x_y_scatter")
gd0.attrs["NX_class"] = "NXdata"
gd0.attrs["signal"] = "y"
- gd0.attrs["axes"] = numpy.array(["x", ], dtype=text_dtype)
+ gd0.attrs["axes"] = numpy.array(
+ [
+ "x",
+ ],
+ dtype=text_dtype,
+ )
gd0.create_dataset("y", data=numpy.random.rand(128) - 0.5)
gd0.create_dataset("x", data=2 * numpy.random.rand(128))
gd0.create_dataset("x_errors", data=0.05 * numpy.random.rand(128))
@@ -191,8 +227,9 @@ class TestNXdata(unittest.TestCase):
for group in self.h5f:
for subgroup in self.h5f[group]:
self.assertTrue(
- nxdata.is_valid_nxdata(self.h5f[group][subgroup]),
- "%s/%s not found to be a valid NXdata group" % (group, subgroup))
+ nxdata.is_valid_nxdata(self.h5f[group][subgroup]),
+ "%s/%s not found to be a valid NXdata group" % (group, subgroup),
+ )
def testScalars(self):
nxd = nxdata.NXdata(self.h5f["scalars/0D_scalar"])
@@ -216,8 +253,9 @@ class TestNXdata(unittest.TestCase):
self.assertEqual(nxd.interpretation, "scalar")
nxd = nxdata.NXdata(self.h5f["scalars/4D_scalars"])
- self.assertFalse(nxd.signal_is_0d or nxd.signal_is_1d or
- nxd.signal_is_2d or nxd.signal_is_3d)
+ self.assertFalse(
+ nxd.signal_is_0d or nxd.signal_is_1d or nxd.signal_is_2d or nxd.signal_is_3d
+ )
self.assertEqual(nxd.signal[1, 0, 1, 4], 74)
self.assertEqual(nxd.axes_names, [None, None, None, None])
self.assertEqual(nxd.axes_dataset_names, [None, None, None, None])
@@ -230,8 +268,7 @@ class TestNXdata(unittest.TestCase):
nxd = nxdata.NXdata(self.h5f["spectra/1D_spectrum"])
self.assertTrue(nxd.signal_is_1d)
self.assertTrue(nxd.is_curve)
- self.assertTrue(numpy.array_equal(numpy.array(nxd.signal),
- numpy.arange(10)))
+ self.assertTrue(numpy.array_equal(numpy.array(nxd.signal), numpy.arange(10)))
self.assertEqual(nxd.axes_names, ["energy_calib"])
self.assertEqual(nxd.axes_dataset_names, ["energy_calib"])
self.assertEqual(nxd.axes[0][0], 10)
@@ -241,12 +278,11 @@ class TestNXdata(unittest.TestCase):
self.assertIsNone(nxd.interpretation)
self.assertEqual(nxd.title, "Title as dataset (like nexpy)")
- self.assertEqual(nxd.auxiliary_signals_dataset_names,
- ["count2", "count3"])
- self.assertEqual(nxd.auxiliary_signals_names,
- ["count2", "3rd counter"])
- self.assertAlmostEqual(nxd.auxiliary_signals[1][2],
- 0.8) # numpy.arange(10) * 0.4
+ self.assertEqual(nxd.auxiliary_signals_dataset_names, ["count2", "count3"])
+ self.assertEqual(nxd.auxiliary_signals_names, ["count2", "3rd counter"])
+ self.assertAlmostEqual(
+ nxd.auxiliary_signals[1][2], 0.8
+ ) # numpy.arange(10) * 0.4
nxd = nxdata.NXdata(self.h5f["spectra/2D_spectra"])
self.assertTrue(nxd.signal_is_2d)
@@ -259,34 +295,39 @@ class TestNXdata(unittest.TestCase):
self.assertEqual(nxd.interpretation, "spectrum")
nxd = nxdata.NXdata(self.h5f["spectra/4D_spectra"])
- self.assertFalse(nxd.signal_is_0d or nxd.signal_is_1d or
- nxd.signal_is_2d or nxd.signal_is_3d)
+ self.assertFalse(
+ nxd.signal_is_0d or nxd.signal_is_1d or nxd.signal_is_2d or nxd.signal_is_3d
+ )
self.assertTrue(nxd.is_curve)
- self.assertEqual(nxd.axes_names,
- [None, None, None, "Calibrated energy"])
- self.assertEqual(nxd.axes_dataset_names,
- [None, None, None, "energy"])
+ self.assertEqual(nxd.axes_names, [None, None, None, "Calibrated energy"])
+ self.assertEqual(nxd.axes_dataset_names, [None, None, None, "energy"])
self.assertEqual(nxd.axes[:3], [None, None, None])
- self.assertEqual(nxd.axes[3].shape, (10, )) # dataset shape (15, ) sliced [3:12]
+ self.assertEqual(nxd.axes[3].shape, (10,)) # dataset shape (15, ) sliced [3:12]
self.assertIsNotNone(nxd.errors)
self.assertEqual(nxd.errors.shape, (2, 2, 3, 10))
self.assertFalse(nxd.is_scatter or nxd.is_x_y_value_scatter)
self.assertEqual(nxd.interpretation, "spectrum")
- self.assertEqual(nxd.get_axis_errors("energy").shape,
- (10,))
+ self.assertEqual(nxd.get_axis_errors("energy").shape, (10,))
# test getting axis errors by long_name
- self.assertTrue(numpy.array_equal(nxd.get_axis_errors("Calibrated energy"),
- nxd.get_axis_errors("energy")))
- self.assertTrue(numpy.array_equal(nxd.get_axis_errors(b"Calibrated energy"),
- nxd.get_axis_errors("energy")))
+ self.assertTrue(
+ numpy.array_equal(
+ nxd.get_axis_errors("Calibrated energy"), nxd.get_axis_errors("energy")
+ )
+ )
+ self.assertTrue(
+ numpy.array_equal(
+ nxd.get_axis_errors(b"Calibrated energy"), nxd.get_axis_errors("energy")
+ )
+ )
def testImages(self):
nxd = nxdata.NXdata(self.h5f["images/2D_regular_image"])
self.assertTrue(nxd.signal_is_2d)
self.assertTrue(nxd.is_image)
self.assertEqual(nxd.axes_names, ["Calibrated Y", "columns_coordinates"])
- self.assertEqual(list(nxd.axes_dataset_names),
- ["rows_calib", "columns_coordinates"])
+ self.assertEqual(
+ list(nxd.axes_dataset_names), ["rows_calib", "columns_coordinates"]
+ )
self.assertIsNone(nxd.errors)
self.assertFalse(nxd.is_scatter or nxd.is_x_y_value_scatter)
self.assertIsNone(nxd.interpretation)
@@ -298,8 +339,9 @@ class TestNXdata(unittest.TestCase):
self.assertTrue(nxd.is_image)
self.assertEqual(nxd.axes_dataset_names, nxd.axes_names)
- self.assertEqual(list(nxd.axes_dataset_names),
- ["rows_coordinates", "columns_coordinates"])
+ self.assertEqual(
+ list(nxd.axes_dataset_names), ["rows_coordinates", "columns_coordinates"]
+ )
self.assertEqual(len(nxd.axes), 2)
self.assertIsNone(nxd.errors)
self.assertFalse(nxd.is_scatter or nxd.is_x_y_value_scatter)
@@ -308,12 +350,17 @@ class TestNXdata(unittest.TestCase):
nxd = nxdata.NXdata(self.h5f["images/5D_images"])
self.assertTrue(nxd.is_image)
- self.assertFalse(nxd.signal_is_0d or nxd.signal_is_1d or
- nxd.signal_is_2d or nxd.signal_is_3d)
- self.assertEqual(nxd.axes_names,
- [None, None, None, 'rows_coordinates', 'columns_coordinates'])
- self.assertEqual(nxd.axes_dataset_names,
- [None, None, None, 'rows_coordinates', 'columns_coordinates'])
+ self.assertFalse(
+ nxd.signal_is_0d or nxd.signal_is_1d or nxd.signal_is_2d or nxd.signal_is_3d
+ )
+ self.assertEqual(
+ nxd.axes_names,
+ [None, None, None, "rows_coordinates", "columns_coordinates"],
+ )
+ self.assertEqual(
+ nxd.axes_dataset_names,
+ [None, None, None, "rows_coordinates", "columns_coordinates"],
+ )
self.assertIsNone(nxd.errors)
self.assertFalse(nxd.is_scatter or nxd.is_x_y_value_scatter)
self.assertEqual(nxd.interpretation, "image")
@@ -322,35 +369,28 @@ class TestNXdata(unittest.TestCase):
self.assertTrue(nxd.is_image)
self.assertEqual(nxd.interpretation, "rgba-image")
self.assertTrue(nxd.signal_is_3d)
- self.assertEqual(nxd.axes_names, ["Calibrated Y",
- "columns_coordinates",
- None])
- self.assertEqual(list(nxd.axes_dataset_names),
- ["rows_calib", "columns_coordinates", None])
+ self.assertEqual(nxd.axes_names, ["Calibrated Y", "columns_coordinates", None])
+ self.assertEqual(
+ list(nxd.axes_dataset_names), ["rows_calib", "columns_coordinates", None]
+ )
def testScatters(self):
nxd = nxdata.NXdata(self.h5f["scatters/x_y_scatter"])
self.assertTrue(nxd.signal_is_1d)
self.assertEqual(nxd.axes_names, ["x"])
- self.assertEqual(nxd.axes_dataset_names,
- ["x"])
+ self.assertEqual(nxd.axes_dataset_names, ["x"])
self.assertIsNotNone(nxd.errors)
- self.assertEqual(nxd.get_axis_errors("x").shape,
- (128, ))
+ self.assertEqual(nxd.get_axis_errors("x").shape, (128,))
self.assertTrue(nxd.is_scatter)
self.assertFalse(nxd.is_x_y_value_scatter)
self.assertIsNone(nxd.interpretation)
nxd = nxdata.NXdata(self.h5f["scatters/x_y_value_scatter"])
self.assertFalse(nxd.signal_is_1d)
- self.assertTrue(nxd.axes_dataset_names,
- nxd.axes_names)
- self.assertEqual(nxd.axes_dataset_names,
- ["x", "y"])
- self.assertEqual(nxd.get_axis_errors("x").shape,
- (128, ))
- self.assertEqual(nxd.get_axis_errors("y").shape,
- (128, ))
+ self.assertTrue(nxd.axes_dataset_names, nxd.axes_names)
+ self.assertEqual(nxd.axes_dataset_names, ["x", "y"])
+ self.assertEqual(nxd.get_axis_errors("x").shape, (128,))
+ self.assertEqual(nxd.get_axis_errors("y").shape, (128,))
self.assertEqual(len(nxd.axes), 2)
self.assertIsNone(nxd.errors)
self.assertTrue(nxd.is_scatter)
@@ -360,8 +400,9 @@ class TestNXdata(unittest.TestCase):
class TestLegacyNXdata(unittest.TestCase):
def setUp(self):
- tmp = tempfile.NamedTemporaryFile(prefix="nxdata_legacy_examples_",
- suffix=".h5", delete=True)
+ tmp = tempfile.NamedTemporaryFile(
+ prefix="nxdata_legacy_examples_", suffix=".h5", delete=True
+ )
tmp.file.close()
self.h5fname = tmp.name
self.h5f = h5py.File(tmp.name, "w")
@@ -373,80 +414,61 @@ class TestLegacyNXdata(unittest.TestCase):
g = self.h5f.create_group("2D")
g.attrs["NX_class"] = "NXdata"
- ds0 = g.create_dataset("image0",
- data=numpy.arange(4 * 6).reshape((4, 6)))
+ ds0 = g.create_dataset("image0", data=numpy.arange(4 * 6).reshape((4, 6)))
ds0.attrs["signal"] = 1
ds0.attrs["long_name"] = "My first image"
- ds1 = g.create_dataset("image1",
- data=numpy.arange(4 * 6).reshape((4, 6)))
+ ds1 = g.create_dataset("image1", data=numpy.arange(4 * 6).reshape((4, 6)))
ds1.attrs["signal"] = "2"
ds1.attrs["long_name"] = "My 2nd image"
- ds2 = g.create_dataset("image2",
- data=numpy.arange(4 * 6).reshape((4, 6)))
+ ds2 = g.create_dataset("image2", data=numpy.arange(4 * 6).reshape((4, 6)))
ds2.attrs["signal"] = 3
nxd = nxdata.NXdata(self.h5f["2D"])
self.assertEqual(nxd.signal_dataset_name, "image0")
self.assertEqual(nxd.signal_name, "My first image")
- self.assertEqual(nxd.signal.shape,
- (4, 6))
+ self.assertEqual(nxd.signal.shape, (4, 6))
self.assertEqual(len(nxd.auxiliary_signals), 2)
- self.assertEqual(nxd.auxiliary_signals[1].shape,
- (4, 6))
+ self.assertEqual(nxd.auxiliary_signals[1].shape, (4, 6))
- self.assertEqual(nxd.auxiliary_signals_dataset_names,
- ["image1", "image2"])
- self.assertEqual(nxd.auxiliary_signals_names,
- ["My 2nd image", "image2"])
+ self.assertEqual(nxd.auxiliary_signals_dataset_names, ["image1", "image2"])
+ self.assertEqual(nxd.auxiliary_signals_names, ["My 2nd image", "image2"])
def testAxesOnSignalDataset(self):
g = self.h5f.create_group("2D")
g.attrs["NX_class"] = "NXdata"
- ds0 = g.create_dataset("image0",
- data=numpy.arange(4 * 6).reshape((4, 6)))
+ ds0 = g.create_dataset("image0", data=numpy.arange(4 * 6).reshape((4, 6)))
ds0.attrs["signal"] = 1
ds0.attrs["axes"] = "yaxis:xaxis"
- ds1 = g.create_dataset("yaxis",
- data=numpy.arange(4))
- ds2 = g.create_dataset("xaxis",
- data=numpy.arange(6))
+ ds1 = g.create_dataset("yaxis", data=numpy.arange(4))
+ ds2 = g.create_dataset("xaxis", data=numpy.arange(6))
nxd = nxdata.NXdata(self.h5f["2D"])
- self.assertEqual(nxd.axes_dataset_names,
- ["yaxis", "xaxis"])
- self.assertTrue(numpy.array_equal(nxd.axes[0],
- numpy.arange(4)))
- self.assertTrue(numpy.array_equal(nxd.axes[1],
- numpy.arange(6)))
+ self.assertEqual(nxd.axes_dataset_names, ["yaxis", "xaxis"])
+ self.assertTrue(numpy.array_equal(nxd.axes[0], numpy.arange(4)))
+ self.assertTrue(numpy.array_equal(nxd.axes[1], numpy.arange(6)))
def testAxesOnAxesDatasets(self):
g = self.h5f.create_group("2D")
g.attrs["NX_class"] = "NXdata"
- ds0 = g.create_dataset("image0",
- data=numpy.arange(4 * 6).reshape((4, 6)))
+ ds0 = g.create_dataset("image0", data=numpy.arange(4 * 6).reshape((4, 6)))
ds0.attrs["signal"] = 1
- ds1 = g.create_dataset("yaxis",
- data=numpy.arange(4))
+ ds1 = g.create_dataset("yaxis", data=numpy.arange(4))
ds1.attrs["axis"] = 0
- ds2 = g.create_dataset("xaxis",
- data=numpy.arange(6))
+ ds2 = g.create_dataset("xaxis", data=numpy.arange(6))
ds2.attrs["axis"] = "1"
nxd = nxdata.NXdata(self.h5f["2D"])
- self.assertEqual(nxd.axes_dataset_names,
- ["yaxis", "xaxis"])
- self.assertTrue(numpy.array_equal(nxd.axes[0],
- numpy.arange(4)))
- self.assertTrue(numpy.array_equal(nxd.axes[1],
- numpy.arange(6)))
+ self.assertEqual(nxd.axes_dataset_names, ["yaxis", "xaxis"])
+ self.assertTrue(numpy.array_equal(nxd.axes[0], numpy.arange(4)))
+ self.assertTrue(numpy.array_equal(nxd.axes[1], numpy.arange(6)))
def testAsciiUndefinedAxesAttrs(self):
"""Some files may not be using utf8 for str attrs"""
@@ -455,20 +477,16 @@ class TestLegacyNXdata(unittest.TestCase):
g.attrs["signal"] = b"image0"
g.attrs["axes"] = b"yaxis", b"."
- g.create_dataset("image0",
- data=numpy.arange(4 * 6).reshape((4, 6)))
- g.create_dataset("yaxis",
- data=numpy.arange(4))
+ g.create_dataset("image0", data=numpy.arange(4 * 6).reshape((4, 6)))
+ g.create_dataset("yaxis", data=numpy.arange(4))
nxd = nxdata.NXdata(self.h5f["bytes_attrs"])
- self.assertEqual(nxd.axes_dataset_names,
- ["yaxis", None])
+ self.assertEqual(nxd.axes_dataset_names, ["yaxis", None])
class TestSaveNXdata(unittest.TestCase):
def setUp(self):
- tmp = tempfile.NamedTemporaryFile(prefix="nxdata",
- suffix=".h5", delete=True)
+ tmp = tempfile.NamedTemporaryFile(prefix="nxdata", suffix=".h5", delete=True)
tmp.file.close()
self.h5fname = tmp.name
@@ -476,64 +494,60 @@ class TestSaveNXdata(unittest.TestCase):
sig = numpy.array([0, 1, 2])
a0 = numpy.array([2, 3, 4])
a1 = numpy.array([3, 4, 5])
- nxdata.save_NXdata(filename=self.h5fname,
- signal=sig,
- axes=[a0, a1],
- signal_name="sig",
- axes_names=["a0", "a1"],
- nxentry_name="a",
- nxdata_name="mydata")
+ nxdata.save_NXdata(
+ filename=self.h5fname,
+ signal=sig,
+ axes=[a0, a1],
+ signal_name="sig",
+ axes_names=["a0", "a1"],
+ nxentry_name="a",
+ nxdata_name="mydata",
+ )
h5f = h5py.File(self.h5fname, "r")
self.assertTrue(nxdata.is_valid_nxdata(h5f["a/mydata"]))
nxd = nxdata.NXdata(h5f["/a/mydata"])
- self.assertTrue(numpy.array_equal(nxd.signal,
- sig))
- self.assertTrue(numpy.array_equal(nxd.axes[0],
- a0))
+ self.assertTrue(numpy.array_equal(nxd.signal, sig))
+ self.assertTrue(numpy.array_equal(nxd.axes[0], a0))
h5f.close()
def testSimplestSave(self):
sig = numpy.array([0, 1, 2])
- nxdata.save_NXdata(filename=self.h5fname,
- signal=sig)
+ nxdata.save_NXdata(filename=self.h5fname, signal=sig)
h5f = h5py.File(self.h5fname, "r")
self.assertTrue(nxdata.is_valid_nxdata(h5f["/entry/data0"]))
nxd = nxdata.NXdata(h5f["/entry/data0"])
- self.assertTrue(numpy.array_equal(nxd.signal,
- sig))
+ self.assertTrue(numpy.array_equal(nxd.signal, sig))
h5f.close()
def testSaveDefaultAxesNames(self):
sig = numpy.array([0, 1, 2])
a0 = numpy.array([2, 3, 4])
a1 = numpy.array([3, 4, 5])
- nxdata.save_NXdata(filename=self.h5fname,
- signal=sig,
- axes=[a0, a1],
- signal_name="sig",
- axes_names=None,
- axes_long_names=["a", "b"],
- nxentry_name="a",
- nxdata_name="mydata")
+ nxdata.save_NXdata(
+ filename=self.h5fname,
+ signal=sig,
+ axes=[a0, a1],
+ signal_name="sig",
+ axes_names=None,
+ axes_long_names=["a", "b"],
+ nxentry_name="a",
+ nxdata_name="mydata",
+ )
h5f = h5py.File(self.h5fname, "r")
self.assertTrue(nxdata.is_valid_nxdata(h5f["a/mydata"]))
nxd = nxdata.NXdata(h5f["/a/mydata"])
- self.assertTrue(numpy.array_equal(nxd.signal,
- sig))
- self.assertTrue(numpy.array_equal(nxd.axes[0],
- a0))
- self.assertEqual(nxd.axes_dataset_names,
- [u"dim0", u"dim1"])
- self.assertEqual(nxd.axes_names,
- [u"a", u"b"])
+ self.assertTrue(numpy.array_equal(nxd.signal, sig))
+ self.assertTrue(numpy.array_equal(nxd.axes[0], a0))
+ self.assertEqual(nxd.axes_dataset_names, ["dim0", "dim1"])
+ self.assertEqual(nxd.axes_names, ["a", "b"])
h5f.close()
@@ -546,22 +560,22 @@ class TestSaveNXdata(unittest.TestCase):
sig = numpy.array([0, 1, 2])
a0 = numpy.array([2, 3, 4])
a1 = numpy.array([3, 4, 5])
- nxdata.save_NXdata(filename=self.h5fname,
- signal=sig,
- axes=[a0, a1],
- signal_name="sig",
- axes_names=["a0", "a1"],
- nxentry_name="myentry",
- nxdata_name="toto")
+ nxdata.save_NXdata(
+ filename=self.h5fname,
+ signal=sig,
+ axes=[a0, a1],
+ signal_name="sig",
+ axes_names=["a0", "a1"],
+ nxentry_name="myentry",
+ nxdata_name="toto",
+ )
h5f = h5py.File(self.h5fname, "r")
self.assertTrue(nxdata.is_valid_nxdata(h5f["myentry/toto"]))
nxd = nxdata.NXdata(h5f["myentry/toto"])
- self.assertTrue(numpy.array_equal(nxd.signal,
- sig))
- self.assertTrue(numpy.array_equal(nxd.axes[0],
- a0))
+ self.assertTrue(numpy.array_equal(nxd.signal, sig))
+ self.assertTrue(numpy.array_equal(nxd.axes[0], a0))
h5f.close()
@@ -585,9 +599,10 @@ class TestGetDefault:
"data": (1, 2, 3),
}
}
- }
+ },
},
- hdf5_file)
+ hdf5_file,
+ )
default = nxdata.get_default(hdf5_file)
assert isinstance(default, nxdata.NXdata)
assert default.group.name == "/nxentry/nxprocess/nxdata"
@@ -604,10 +619,11 @@ class TestGetDefault:
("", "signal"): "data",
"data": (1, 2, 3),
}
- }
- }
+ },
+ },
},
- hdf5_file)
+ hdf5_file,
+ )
default = nxdata.get_default(hdf5_file)
assert isinstance(default, nxdata.NXdata)
assert default.group.name == "/nxentry/nxprocess/nxdata"
@@ -622,10 +638,11 @@ class TestGetDefault:
("", "NX_class"): "NXdata",
("", "signal"): "data",
"data": (1, 2, 3),
- }
- }
+ },
+ },
},
- hdf5_file)
+ hdf5_file,
+ )
default = nxdata.get_default(hdf5_file)
assert isinstance(default, nxdata.NXdata)
assert default.group.name == "/nxentry/nxdata"
@@ -642,10 +659,11 @@ class TestGetDefault:
("", "signal"): "data",
"data": (1, 2, 3),
}
- }
- }
+ },
+ },
},
- hdf5_file)
+ hdf5_file,
+ )
default = nxdata.get_default(hdf5_file)
assert isinstance(default, nxdata.NXdata)
assert default.group.name == "/nxentry/nxprocess/nxdata"
@@ -662,11 +680,12 @@ class TestGetDefault:
("", "NX_class"): "NXdata",
("", "signal"): "data",
"data": (1, 2, 3),
- }
- }
- }
+ },
+ },
+ },
},
- hdf5_file)
+ hdf5_file,
+ )
default = nxdata.get_default(hdf5_file)
assert isinstance(default, nxdata.NXdata)
assert default.group.name == "/nxentry/nxprocess/nxdata"
@@ -683,11 +702,12 @@ class TestGetDefault:
("", "NX_class"): "NXdata",
("", "signal"): "data",
"data": (1, 2, 3),
- }
- }
- }
+ },
+ },
+ },
},
- hdf5_file)
+ hdf5_file,
+ )
default = nxdata.get_default(hdf5_file)
assert isinstance(default, nxdata.NXdata)
assert default.group.name == "/nxentry/nxprocess/nxdata"
@@ -699,8 +719,9 @@ class TestGetDefault:
("", "default"): "/nxentry",
"nxentry": {
("", "default"): "/nxentry",
- }
+ },
},
- hdf5_file)
+ hdf5_file,
+ )
default = nxdata.get_default(hdf5_file)
assert default is None
diff --git a/src/silx/io/test/test_octaveh5.py b/src/silx/io/test/test_octaveh5.py
index 19b8ad6..479ef85 100644
--- a/src/silx/io/test/test_octaveh5.py
+++ b/src/silx/io/test/test_octaveh5.py
@@ -42,41 +42,84 @@ except ImportError:
class TestOctaveH5(unittest.TestCase):
@staticmethod
def _get_struct_FT():
- return {
- 'NO_CHECK': 0.0, 'SHOWSLICE': 1.0, 'DOTOMO': 1.0, 'DATABASE': 0.0, 'ANGLE_OFFSET': 0.0,
- 'VOLSELECTION_REMEMBER': 0.0, 'NUM_PART': 4.0, 'VOLOUTFILE': 0.0, 'RINGSCORRECTION': 0.0,
- 'DO_TEST_SLICE': 1.0, 'ZEROOFFMASK': 1.0, 'VERSION': 'fastomo3 version 2.0',
- 'CORRECT_SPIKES_THRESHOLD': 0.040000000000000001, 'SHOWPROJ': 0.0, 'HALF_ACQ': 0.0,
- 'ANGLE_OFFSET_VALUE': 0.0, 'FIXEDSLICE': 'middle', 'VOLSELECT': 'total' }
+ return {
+ "NO_CHECK": 0.0,
+ "SHOWSLICE": 1.0,
+ "DOTOMO": 1.0,
+ "DATABASE": 0.0,
+ "ANGLE_OFFSET": 0.0,
+ "VOLSELECTION_REMEMBER": 0.0,
+ "NUM_PART": 4.0,
+ "VOLOUTFILE": 0.0,
+ "RINGSCORRECTION": 0.0,
+ "DO_TEST_SLICE": 1.0,
+ "ZEROOFFMASK": 1.0,
+ "VERSION": "fastomo3 version 2.0",
+ "CORRECT_SPIKES_THRESHOLD": 0.040000000000000001,
+ "SHOWPROJ": 0.0,
+ "HALF_ACQ": 0.0,
+ "ANGLE_OFFSET_VALUE": 0.0,
+ "FIXEDSLICE": "middle",
+ "VOLSELECT": "total",
+ }
+
@staticmethod
def _get_struct_PYHSTEXE():
return {
- 'EXE': 'PyHST2_2015d', 'VERBOSE': 0.0, 'OFFV': 'PyHST2_2015d', 'TOMO': 0.0,
- 'VERBOSE_FILE': 'pyhst_out.txt', 'DIR': '/usr/bin/', 'OFFN': 'pyhst2'}
+ "EXE": "PyHST2_2015d",
+ "VERBOSE": 0.0,
+ "OFFV": "PyHST2_2015d",
+ "TOMO": 0.0,
+ "VERBOSE_FILE": "pyhst_out.txt",
+ "DIR": "/usr/bin/",
+ "OFFN": "pyhst2",
+ }
@staticmethod
def _get_struct_FTAXIS():
return {
- 'POSITION_VALUE': 12345.0, 'COR_ERROR': 0.0, 'FILESDURINGSCAN': 0.0, 'PLOTFIGURE': 1.0,
- 'DIM1': 0.0, 'OVERSAMPLING': 5.0, 'TO_THE_CENTER': 1.0, 'POSITION': 'fixed',
- 'COR_POSITION': 0.0, 'HA': 0.0 }
-
+ "POSITION_VALUE": 12345.0,
+ "COR_ERROR": 0.0,
+ "FILESDURINGSCAN": 0.0,
+ "PLOTFIGURE": 1.0,
+ "DIM1": 0.0,
+ "OVERSAMPLING": 5.0,
+ "TO_THE_CENTER": 1.0,
+ "POSITION": "fixed",
+ "COR_POSITION": 0.0,
+ "HA": 0.0,
+ }
+
@staticmethod
def _get_struct_PAGANIN():
return {
- 'MKEEP_MASK': 0.0, 'UNSHARP_SIGMA': 0.80000000000000004, 'DILATE': 2.0, 'UNSHARP_COEFF': 3.0,
- 'MEDIANR': 4.0, 'DB': 500.0, 'MKEEP_ABS': 0.0, 'MODE': 0.0, 'THRESHOLD': 0.5,
- 'MKEEP_BONE': 0.0, 'DB2': 100.0, 'MKEEP_CORR': 0.0, 'MKEEP_SOFT': 0.0 }
+ "MKEEP_MASK": 0.0,
+ "UNSHARP_SIGMA": 0.80000000000000004,
+ "DILATE": 2.0,
+ "UNSHARP_COEFF": 3.0,
+ "MEDIANR": 4.0,
+ "DB": 500.0,
+ "MKEEP_ABS": 0.0,
+ "MODE": 0.0,
+ "THRESHOLD": 0.5,
+ "MKEEP_BONE": 0.0,
+ "DB2": 100.0,
+ "MKEEP_CORR": 0.0,
+ "MKEEP_SOFT": 0.0,
+ }
@staticmethod
def _get_struct_BEAMGEO():
- return {'DIST': 55.0, 'SY': 0.0, 'SX': 0.0, 'TYPE': 'p'}
-
+ return {"DIST": 55.0, "SY": 0.0, "SX": 0.0, "TYPE": "p"}
def setUp(self):
- self.tempdir = tempfile.mkdtemp()
- self.test_3_6_fname = os.path.join(self.tempdir, "silx_tmp_t00_octaveTest_3_6.h5")
- self.test_3_8_fname = os.path.join(self.tempdir, "silx_tmp_t00_octaveTest_3_8.h5")
+ self.tempdir = tempfile.mkdtemp()
+ self.test_3_6_fname = os.path.join(
+ self.tempdir, "silx_tmp_t00_octaveTest_3_6.h5"
+ )
+ self.test_3_8_fname = os.path.join(
+ self.tempdir, "silx_tmp_t00_octaveTest_3_8.h5"
+ )
def tearDown(self):
if os.path.isfile(self.test_3_6_fname):
@@ -88,68 +131,67 @@ class TestOctaveH5(unittest.TestCase):
"""
Simple test to write and reaf the structure compatible with the octave h5 using structure.
This test is for # test for octave version > 3.8
- """
+ """
writer = Octaveh5()
- writer.open(self.test_3_8_fname, 'a')
+ writer.open(self.test_3_8_fname, "a")
# step 1 writing the file
- writer.write('FT', self._get_struct_FT())
- writer.write('PYHSTEXE', self._get_struct_PYHSTEXE())
- writer.write('FTAXIS', self._get_struct_FTAXIS())
- writer.write('PAGANIN', self._get_struct_PAGANIN())
- writer.write('BEAMGEO', self._get_struct_BEAMGEO())
+ writer.write("FT", self._get_struct_FT())
+ writer.write("PYHSTEXE", self._get_struct_PYHSTEXE())
+ writer.write("FTAXIS", self._get_struct_FTAXIS())
+ writer.write("PAGANIN", self._get_struct_PAGANIN())
+ writer.write("BEAMGEO", self._get_struct_BEAMGEO())
writer.close()
# step 2 reading the file
reader = Octaveh5().open(self.test_3_8_fname)
# 2.1 check FT
- data_readed = reader.get('FT')
- self.assertEqual(data_readed, self._get_struct_FT() )
+ data_readed = reader.get("FT")
+ self.assertEqual(data_readed, self._get_struct_FT())
# 2.2 check PYHSTEXE
- data_readed = reader.get('PYHSTEXE')
- self.assertEqual(data_readed, self._get_struct_PYHSTEXE() )
+ data_readed = reader.get("PYHSTEXE")
+ self.assertEqual(data_readed, self._get_struct_PYHSTEXE())
# 2.3 check FTAXIS
- data_readed = reader.get('FTAXIS')
- self.assertEqual(data_readed, self._get_struct_FTAXIS() )
+ data_readed = reader.get("FTAXIS")
+ self.assertEqual(data_readed, self._get_struct_FTAXIS())
# 2.4 check PAGANIN
- data_readed = reader.get('PAGANIN')
- self.assertEqual(data_readed, self._get_struct_PAGANIN() )
+ data_readed = reader.get("PAGANIN")
+ self.assertEqual(data_readed, self._get_struct_PAGANIN())
# 2.5 check BEAMGEO
- data_readed = reader.get('BEAMGEO')
- self.assertEqual(data_readed, self._get_struct_BEAMGEO() )
+ data_readed = reader.get("BEAMGEO")
+ self.assertEqual(data_readed, self._get_struct_BEAMGEO())
reader.close()
def testWritedIsReadedOldOctaveVersion(self):
- """The same test as testWritedIsReaded but for octave version < 3.8
- """
+ """The same test as testWritedIsReaded but for octave version < 3.8"""
# test for octave version < 3.8
writer = Octaveh5(3.6)
- writer.open(self.test_3_6_fname, 'a')
+ writer.open(self.test_3_6_fname, "a")
# step 1 writing the file
- writer.write('FT', self._get_struct_FT())
- writer.write('PYHSTEXE', self._get_struct_PYHSTEXE())
- writer.write('FTAXIS', self._get_struct_FTAXIS())
- writer.write('PAGANIN', self._get_struct_PAGANIN())
- writer.write('BEAMGEO', self._get_struct_BEAMGEO())
+ writer.write("FT", self._get_struct_FT())
+ writer.write("PYHSTEXE", self._get_struct_PYHSTEXE())
+ writer.write("FTAXIS", self._get_struct_FTAXIS())
+ writer.write("PAGANIN", self._get_struct_PAGANIN())
+ writer.write("BEAMGEO", self._get_struct_BEAMGEO())
writer.close()
# step 2 reading the file
reader = Octaveh5(3.6).open(self.test_3_6_fname)
# 2.1 check FT
- data_readed = reader.get('FT')
- self.assertEqual(data_readed, self._get_struct_FT() )
+ data_readed = reader.get("FT")
+ self.assertEqual(data_readed, self._get_struct_FT())
# 2.2 check PYHSTEXE
- data_readed = reader.get('PYHSTEXE')
- self.assertEqual(data_readed, self._get_struct_PYHSTEXE() )
+ data_readed = reader.get("PYHSTEXE")
+ self.assertEqual(data_readed, self._get_struct_PYHSTEXE())
# 2.3 check FTAXIS
- data_readed = reader.get('FTAXIS')
- self.assertEqual(data_readed, self._get_struct_FTAXIS() )
+ data_readed = reader.get("FTAXIS")
+ self.assertEqual(data_readed, self._get_struct_FTAXIS())
# 2.4 check PAGANIN
- data_readed = reader.get('PAGANIN')
- self.assertEqual(data_readed, self._get_struct_PAGANIN() )
+ data_readed = reader.get("PAGANIN")
+ self.assertEqual(data_readed, self._get_struct_PAGANIN())
# 2.5 check BEAMGEO
- data_readed = reader.get('BEAMGEO')
- self.assertEqual(data_readed, self._get_struct_BEAMGEO() )
+ data_readed = reader.get("BEAMGEO")
+ self.assertEqual(data_readed, self._get_struct_BEAMGEO())
reader.close()
diff --git a/src/silx/io/test/test_rawh5.py b/src/silx/io/test/test_rawh5.py
index 947be0f..fb5caec 100644
--- a/src/silx/io/test/test_rawh5.py
+++ b/src/silx/io/test/test_rawh5.py
@@ -32,11 +32,10 @@ import unittest
import tempfile
import numpy
import shutil
-from ..import rawh5
+from .. import rawh5
class TestNumpyFile(unittest.TestCase):
-
@classmethod
def setUpClass(cls):
cls.tmpDirectory = tempfile.mkdtemp()
@@ -55,11 +54,11 @@ class TestNumpyFile(unittest.TestCase):
def testNumpyZFile(self):
filename = "%s/%s.npz" % (self.tmpDirectory, self.id())
- a = numpy.array(u"aaaaa")
+ a = numpy.array("aaaaa")
b = numpy.array([1, 2, 3, 4])
c = numpy.random.rand(5, 5)
d = numpy.array(b"aaaaa")
- e = numpy.array(u"i \u2661 my mother")
+ e = numpy.array("i \u2661 my mother")
numpy.savez(filename, a, b=b, c=c, d=d, e=e)
h5 = rawh5.NumpyFile(filename)
self.assertIn("arr_0", h5)
@@ -76,8 +75,8 @@ class TestNumpyFile(unittest.TestCase):
def testNumpyZFileContainingDirectories(self):
filename = "%s/%s.npz" % (self.tmpDirectory, self.id())
data = {}
- data['a/b/c'] = numpy.arange(10)
- data['a/b/e'] = numpy.arange(10)
+ data["a/b/c"] = numpy.arange(10)
+ data["a/b/e"] = numpy.arange(10)
numpy.savez(filename, **data)
h5 = rawh5.NumpyFile(filename)
self.assertIn("a/b/c", h5)
diff --git a/src/silx/io/test/test_sliceh5.py b/src/silx/io/test/test_sliceh5.py
new file mode 100644
index 0000000..8ccf14a
--- /dev/null
+++ b/src/silx/io/test/test_sliceh5.py
@@ -0,0 +1,104 @@
+# /*##########################################################################
+# Copyright (C) 2022-2023 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+import contextlib
+from io import BytesIO
+
+import h5py
+import numpy
+import pytest
+
+import silx.io
+from silx.io import commonh5
+from silx.io._sliceh5 import DatasetSlice, _combine_indices
+
+
+@contextlib.contextmanager
+def h5py_file(filename, mode):
+ with BytesIO() as buffer:
+ with h5py.File(buffer, mode) as h5file:
+ yield h5file
+
+
+@pytest.fixture(params=[commonh5.File, h5py_file])
+def temp_h5file(request):
+ temp_file_context = request.param
+ with temp_file_context("tempfile.h5", "w") as h5file:
+ yield h5file
+
+
+@pytest.mark.parametrize("indices", [1, slice(None), (1, slice(1, 4))])
+def test_datasetslice(temp_h5file, indices):
+ data = numpy.arange(50).reshape(10, 5)
+ ref_data = numpy.array(data[indices], copy=False)
+
+ h5dataset = temp_h5file.create_group("group").create_dataset("dataset", data=data)
+
+ with DatasetSlice(h5dataset, indices, attrs={}) as dset:
+ assert silx.io.is_dataset(dset)
+ assert dset.file == temp_h5file
+ assert dset.shape == ref_data.shape
+ assert dset.size == ref_data.size
+ assert dset.dtype == ref_data.dtype
+ assert len(dset) == len(ref_data)
+ assert numpy.array_equal(dset[()], ref_data)
+ assert dset.name == h5dataset.name
+
+
+def test_datasetslice_on_external_link(tmp_path):
+ data = numpy.arange(10).reshape(5, 2)
+
+ external_filename = str(tmp_path / "external.h5")
+ ext_dataset_name = "/external_data"
+ with h5py.File(external_filename, "w") as h5file:
+ h5file[ext_dataset_name] = data
+
+ with h5py.File(tmp_path / "test.h5", "w") as h5file:
+ h5file["group/data"] = h5py.ExternalLink(external_filename, ext_dataset_name)
+
+ with DatasetSlice(h5file["group/data"], slice(None), attrs={}) as dset:
+ assert dset.name == ext_dataset_name
+ assert numpy.array_equal(dset[()], data)
+
+
+@pytest.mark.parametrize(
+ "shape,outer_indices,indices",
+ [
+ ((2, 5, 10), (-1, slice(None), slice(None)), slice(None)),
+ ((2, 5, 10), (-1, slice(None), slice(None)), Ellipsis),
+ # negative strides
+ ((5, 10), (slice(1, 5, 2), slice(2, 8)), (slice(2, 3), slice(4, None, -2))),
+ (
+ (5, 10),
+ (slice(4, None, -1), slice(9, 3, -2)),
+ (slice(1, 3), slice(3, 0, -1)),
+ ),
+ ((5, 10), (slice(1, 8, 2), slice(None)), slice(2, 8)), # slice overflow
+ ],
+)
+def test_combine_indices(shape, outer_indices, indices):
+ data = numpy.arange(numpy.prod(shape)).reshape(shape)
+ ref_data = data[outer_indices][indices]
+
+ combined_indices = _combine_indices(shape, outer_indices, indices)
+
+ assert numpy.array_equal(data[combined_indices], ref_data)
diff --git a/src/silx/io/test/test_specfile.py b/src/silx/io/test/test_specfile.py
index 748e31c..1b84a65 100644
--- a/src/silx/io/test/test_specfile.py
+++ b/src/silx/io/test/test_specfile.py
@@ -1,5 +1,5 @@
# /*##########################################################################
-# Copyright (C) 2016-2021 European Synchrotron Radiation Facility
+# Copyright (C) 2016-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -123,7 +123,7 @@ sftext = """#F /tmp/sf.dat
loc = locale.getlocale(locale.LC_NUMERIC)
try:
- locale.setlocale(locale.LC_NUMERIC, 'de_DE.utf8')
+ locale.setlocale(locale.LC_NUMERIC, "de_DE.utf8")
except locale.Error:
try_DE = False
else:
@@ -135,25 +135,16 @@ class TestSpecFile(unittest.TestCase):
@classmethod
def setUpClass(cls):
fd, cls.fname1 = tempfile.mkstemp(text=False)
- if sys.version_info < (3, ):
- os.write(fd, sftext)
- else:
- os.write(fd, bytes(sftext, 'ascii'))
+ os.write(fd, bytes(sftext, "ascii"))
os.close(fd)
fd2, cls.fname2 = tempfile.mkstemp(text=False)
- if sys.version_info < (3, ):
- os.write(fd2, sftext[370:923])
- else:
- os.write(fd2, bytes(sftext[370:923], 'ascii'))
+ os.write(fd2, bytes(sftext[370:923], "ascii"))
os.close(fd2)
fd3, cls.fname3 = tempfile.mkstemp(text=False)
txt = sftext[371:923]
- if sys.version_info < (3, ):
- os.write(fd3, txt)
- else:
- os.write(fd3, bytes(txt, 'ascii'))
+ os.write(fd3, bytes(txt, "ascii"))
os.close(fd3)
@classmethod
@@ -186,58 +177,42 @@ class TestSpecFile(unittest.TestCase):
SpecFile("doesnt_exist.dat")
# test filename types unicode and bytes
- if sys.version_info[0] < 3:
- try:
- SpecFile(self.fname1)
- except TypeError:
- self.fail("failed to handle filename as python2 str")
- try:
- SpecFile(unicode(self.fname1))
- except TypeError:
- self.fail("failed to handle filename as python2 unicode")
- else:
- try:
- SpecFile(self.fname1)
- except TypeError:
- self.fail("failed to handle filename as python3 str")
- try:
- SpecFile(bytes(self.fname1, 'utf-8'))
- except TypeError:
- self.fail("failed to handle filename as python3 bytes")
+ try:
+ SpecFile(self.fname1)
+ except TypeError:
+ self.fail("failed to handle filename as python3 str")
+ try:
+ SpecFile(bytes(self.fname1, "utf-8"))
+ except TypeError:
+ self.fail("failed to handle filename as python3 bytes")
def test_number_of_scans(self):
self.assertEqual(4, len(self.sf))
def test_list_of_scan_indices(self):
- self.assertEqual(self.sf.list(),
- [1, 25, 26, 1])
- self.assertEqual(self.sf.keys(),
- ["1.1", "25.1", "26.1", "1.2"])
+ self.assertEqual(self.sf.list(), [1, 25, 26, 1])
+ self.assertEqual(self.sf.keys(), ["1.1", "25.1", "26.1", "1.2"])
def test_index_number_order(self):
self.assertEqual(self.sf.index(1, 2), 3) # sf["1.2"]==sf[3]
- self.assertEqual(self.sf.number(1), 25) # sf[1]==sf["25"]
- self.assertEqual(self.sf.order(3), 2) # sf[3]==sf["1.2"]
+ self.assertEqual(self.sf.number(1), 25) # sf[1]==sf["25"]
+ self.assertEqual(self.sf.order(3), 2) # sf[3]==sf["1.2"]
with self.assertRaises(specfile.SfErrScanNotFound):
self.sf.index(3, 2)
with self.assertRaises(specfile.SfErrScanNotFound):
self.sf.index(99)
def assertRaisesRegex(self, *args, **kwargs):
- # Python 2 compatibility
- if sys.version_info.major >= 3:
- return super(TestSpecFile, self).assertRaisesRegex(*args, **kwargs)
- else:
- return self.assertRaisesRegexp(*args, **kwargs)
+ return super(TestSpecFile, self).assertRaisesRegex(*args, **kwargs)
def test_getitem(self):
self.assertIsInstance(self.sf[2], Scan)
self.assertIsInstance(self.sf["1.2"], Scan)
# int out of range
- with self.assertRaisesRegex(IndexError, 'Scan index must be in ran'):
+ with self.assertRaisesRegex(IndexError, "Scan index must be in ran"):
self.sf[107]
# float indexing not allowed
- with self.assertRaisesRegex(TypeError, 'The scan identification k'):
+ with self.assertRaisesRegex(TypeError, "The scan identification k"):
self.sf[1.2]
# non existant scan with "N.M" indexing
with self.assertRaises(KeyError):
@@ -247,8 +222,7 @@ class TestSpecFile(unittest.TestCase):
i = 0
for scan in self.sf:
if i == 1:
- self.assertEqual(scan.motor_positions,
- self.sf[1].motor_positions)
+ self.assertEqual(scan.motor_positions, self.sf[1].motor_positions)
i += 1
# number of returned scans
self.assertEqual(i, len(self.sf))
@@ -259,63 +233,64 @@ class TestSpecFile(unittest.TestCase):
self.assertEqual(self.scan25.index, 1)
def test_scan_headers(self):
- self.assertEqual(self.scan25.scan_header_dict['S'],
- "25 ascan c3th 1.33245 1.52245 40 0.15")
- self.assertEqual(self.scan1.header[17], '#G0 0')
+ self.assertEqual(
+ self.scan25.scan_header_dict["S"],
+ "25 ascan c3th 1.33245 1.52245 40 0.15",
+ )
+ self.assertEqual(self.scan1.header[17], "#G0 0")
self.assertEqual(len(self.scan1.header), 29)
# parsing headers with long keys
- self.assertEqual(self.scan1.scan_header_dict['UMI0'],
- 'Current AutoM Shutter')
+ self.assertEqual(
+ self.scan1.scan_header_dict["UMI0"], "Current AutoM Shutter"
+ )
# parsing empty headers
- self.assertEqual(self.scan1.scan_header_dict['Q'], '')
+ self.assertEqual(self.scan1.scan_header_dict["Q"], "")
# duplicate headers: concatenated (with newline)
- self.assertEqual(self.scan1_2.scan_header_dict["U"],
- "first duplicate line\nsecond duplicate line")
+ self.assertEqual(
+ self.scan1_2.scan_header_dict["U"],
+ "first duplicate line\nsecond duplicate line",
+ )
def test_file_headers(self):
- self.assertEqual(self.scan1.header[1],
- '#E 1455180875')
- self.assertEqual(self.scan1.file_header_dict['F'],
- '/tmp/sf.dat')
+ self.assertEqual(self.scan1.header[1], "#E 1455180875")
+ self.assertEqual(self.scan1.file_header_dict["F"], "/tmp/sf.dat")
def test_multiple_file_headers(self):
"""Scan 1.2 is after the second file header, with a different
Epoch"""
- self.assertEqual(self.scan1_2.header[1],
- '#E 1455180876')
+ self.assertEqual(self.scan1_2.header[1], "#E 1455180876")
def test_scan_labels(self):
- self.assertEqual(self.scan1.labels,
- ['first column', 'second column', '3rd_col'])
+ self.assertEqual(
+ self.scan1.labels, ["first column", "second column", "3rd_col"]
+ )
def test_data(self):
# data_line() and data_col() take 1-based indices as arg
- self.assertAlmostEqual(self.scan1.data_line(1)[2],
- 1.56)
+ self.assertAlmostEqual(self.scan1.data_line(1)[2], 1.56)
# tests for data transposition between original file and .data attr
- self.assertAlmostEqual(self.scan1.data[2, 0],
- 8)
+ self.assertAlmostEqual(self.scan1.data[2, 0], 8)
self.assertEqual(self.scan1.data.shape, (3, 4))
self.assertAlmostEqual(numpy.sum(self.scan1.data), 113.631)
def test_data_column_by_name(self):
- self.assertAlmostEqual(self.scan25.data_column_by_name("col2")[1],
- 1.2)
+ self.assertAlmostEqual(self.scan25.data_column_by_name("col2")[1], 1.2)
# Scan.data is transposed after readinq, so column is the first index
- self.assertAlmostEqual(numpy.sum(self.scan25.data_column_by_name("col2")),
- numpy.sum(self.scan25.data[2, :]))
+ self.assertAlmostEqual(
+ numpy.sum(self.scan25.data_column_by_name("col2")),
+ numpy.sum(self.scan25.data[2, :]),
+ )
with self.assertRaises(specfile.SfErrColNotFound):
self.scan25.data_column_by_name("ygfxgfyxg")
def test_motors(self):
self.assertEqual(len(self.scan1.motor_names), 6)
self.assertEqual(len(self.scan1.motor_positions), 6)
- self.assertAlmostEqual(sum(self.scan1.motor_positions),
- 223.385912)
- self.assertEqual(self.scan1.motor_names[1], 'MRTSlit UP')
+ self.assertAlmostEqual(sum(self.scan1.motor_positions), 223.385912)
+ self.assertEqual(self.scan1.motor_names[1], "MRTSlit UP")
self.assertAlmostEqual(
- self.scan25.motor_position_by_name('MRTSlit UP'),
- -1.66875)
+ self.scan25.motor_position_by_name("MRTSlit UP"), -1.66875
+ )
def test_absence_of_file_header(self):
"""We expect Scan.file_header to be an empty list in the absence
@@ -324,8 +299,7 @@ class TestSpecFile(unittest.TestCase):
self.assertEqual(len(self.scan1_no_fhdr.motor_names), 0)
# motor positions can still be read in the scan header
# even in the absence of motor names
- self.assertAlmostEqual(sum(self.scan1_no_fhdr.motor_positions),
- 223.385912)
+ self.assertAlmostEqual(sum(self.scan1_no_fhdr.motor_positions), 223.385912)
self.assertEqual(len(self.scan1_no_fhdr.header), 15)
self.assertEqual(len(self.scan1_no_fhdr.scan_header), 15)
self.assertEqual(len(self.scan1_no_fhdr.file_header), 0)
@@ -337,8 +311,9 @@ class TestSpecFile(unittest.TestCase):
self.assertEqual(len(self.scan1_no_fhdr_crash.motor_names), 0)
# motor positions can still be read in the scan header
# even in the absence of motor names
- self.assertAlmostEqual(sum(self.scan1_no_fhdr_crash.motor_positions),
- 223.385912)
+ self.assertAlmostEqual(
+ sum(self.scan1_no_fhdr_crash.motor_positions), 223.385912
+ )
self.assertEqual(len(self.scan1_no_fhdr_crash.scan_header), 15)
self.assertEqual(len(self.scan1_no_fhdr_crash.file_header), 0)
@@ -349,8 +324,9 @@ class TestSpecFile(unittest.TestCase):
self.assertEqual(sum(self.scan1_2.mca[2]), 21.7)
# Negative indexing
- self.assertEqual(sum(self.scan1_2.mca[len(self.scan1_2.mca) - 1]),
- sum(self.scan1_2.mca[-1]))
+ self.assertEqual(
+ sum(self.scan1_2.mca[len(self.scan1_2.mca) - 1]), sum(self.scan1_2.mca[-1])
+ )
# Test iterator
line_count, total_sum = (0, 0)
@@ -364,34 +340,26 @@ class TestSpecFile(unittest.TestCase):
self.assertEqual(self.scan1.mca_header_dict, {})
self.assertEqual(len(self.scan1_2.mca_header_dict), 4)
self.assertEqual(self.scan1_2.mca_header_dict["CALIB"], "1 2 3")
- self.assertEqual(self.scan1_2.mca.calibration,
- [[1., 2., 3.]])
+ self.assertEqual(self.scan1_2.mca.calibration, [[1.0, 2.0, 3.0]])
# default calib in the absence of #@CALIB
- self.assertEqual(self.scan25.mca.calibration,
- [[0., 1., 0.]])
- self.assertEqual(self.scan1_2.mca.channels,
- [[0, 1, 2]])
+ self.assertEqual(self.scan25.mca.calibration, [[0.0, 1.0, 0.0]])
+ self.assertEqual(self.scan1_2.mca.channels, [[0, 1, 2]])
# absence of #@CHANN and spectra
- self.assertEqual(self.scan25.mca.channels,
- [])
+ self.assertEqual(self.scan25.mca.channels, [])
@testutils.validate_logging(specfile._logger.name, warning=1)
def test_empty_scan(self):
"""Test reading a scan with no data points"""
- self.assertEqual(len(self.empty_scan.labels),
- 3)
+ self.assertEqual(len(self.empty_scan.labels), 3)
col1 = self.empty_scan.data_column_by_name("second column")
- self.assertEqual(col1.shape, (0, ))
+ self.assertEqual(col1.shape, (0,))
class TestSFLocale(unittest.TestCase):
@classmethod
def setUpClass(cls):
fd, cls.fname = tempfile.mkstemp(text=False)
- if sys.version_info < (3, ):
- os.write(fd, sftext)
- else:
- os.write(fd, bytes(sftext, 'ascii'))
+ os.write(fd, bytes(sftext, "ascii"))
os.close(fd)
@classmethod
@@ -401,19 +369,18 @@ class TestSFLocale(unittest.TestCase):
def crunch_data(self):
self.sf3 = SpecFile(self.fname)
- self.assertAlmostEqual(self.sf3[0].data_line(1)[2],
- 1.56)
+ self.assertAlmostEqual(self.sf3[0].data_line(1)[2], 1.56)
self.sf3.close()
@unittest.skipIf(not try_DE, "de_DE.utf8 locale not installed")
def test_locale_de_DE(self):
- locale.setlocale(locale.LC_NUMERIC, 'de_DE.utf8')
+ locale.setlocale(locale.LC_NUMERIC, "de_DE.utf8")
self.crunch_data()
def test_locale_user(self):
- locale.setlocale(locale.LC_NUMERIC, '') # use user's preferred locale
+ locale.setlocale(locale.LC_NUMERIC, "") # use user's preferred locale
self.crunch_data()
def test_locale_C(self):
- locale.setlocale(locale.LC_NUMERIC, 'C') # use default (C) locale
+ locale.setlocale(locale.LC_NUMERIC, "C") # use default (C) locale
self.crunch_data()
diff --git a/src/silx/io/test/test_specfilewrapper.py b/src/silx/io/test/test_specfilewrapper.py
index 7d2ce60..e830023 100644
--- a/src/silx/io/test/test_specfilewrapper.py
+++ b/src/silx/io/test/test_specfilewrapper.py
@@ -1,5 +1,5 @@
# /*##########################################################################
-# Copyright (C) 2016 European Synchrotron Radiation Facility
+# Copyright (C) 2016-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -26,7 +26,6 @@ __authors__ = ["P. Knobel"]
__license__ = "MIT"
__date__ = "15/05/2017"
-import locale
import logging
import numpy
import os
@@ -112,10 +111,7 @@ class TestSpecfilewrapper(unittest.TestCase):
@classmethod
def setUpClass(cls):
fd, cls.fname1 = tempfile.mkstemp(text=False)
- if sys.version_info < (3, ):
- os.write(fd, sftext)
- else:
- os.write(fd, bytes(sftext, 'ascii'))
+ os.write(fd, bytes(sftext, "ascii"))
os.close(fd)
@classmethod
@@ -135,60 +131,59 @@ class TestSpecfilewrapper(unittest.TestCase):
self.assertEqual(3, len(self.sf))
def test_list_of_scan_indices(self):
- self.assertEqual(self.sf.list(),
- '1,25,1')
- self.assertEqual(self.sf.keys(),
- ["1.1", "25.1", "1.2"])
+ self.assertEqual(self.sf.list(), "1,25,1")
+ self.assertEqual(self.sf.keys(), ["1.1", "25.1", "1.2"])
def test_scan_headers(self):
- self.assertEqual(self.scan25.header('S'),
- ["#S 25 ascan c3th 1.33245 1.52245 40 0.15"])
- self.assertEqual(self.scan1.header("G0"), ['#G0 0'])
+ self.assertEqual(
+ self.scan25.header("S"), ["#S 25 ascan c3th 1.33245 1.52245 40 0.15"]
+ )
+ self.assertEqual(self.scan1.header("G0"), ["#G0 0"])
# parsing headers with long keys
# parsing empty headers
- self.assertEqual(self.scan1.header('Q'), ['#Q '])
+ self.assertEqual(self.scan1.header("Q"), ["#Q "])
def test_file_headers(self):
- self.assertEqual(self.scan1.header("E"),
- ['#E 1455180875'])
- self.assertEqual(self.sf.title(),
- "imaging")
- self.assertEqual(self.sf.epoch(),
- 1455180875)
- self.assertEqual(self.sf.allmotors(),
- ["Pslit HGap", "MRTSlit UP", "MRTSlit DOWN",
- "Sslit1 VOff", "Sslit1 HOff", "Sslit1 VGap"])
+ self.assertEqual(self.scan1.header("E"), ["#E 1455180875"])
+ self.assertEqual(self.sf.title(), "imaging")
+ self.assertEqual(self.sf.epoch(), 1455180875)
+ self.assertEqual(
+ self.sf.allmotors(),
+ [
+ "Pslit HGap",
+ "MRTSlit UP",
+ "MRTSlit DOWN",
+ "Sslit1 VOff",
+ "Sslit1 HOff",
+ "Sslit1 VGap",
+ ],
+ )
def test_scan_labels(self):
- self.assertEqual(self.scan1.alllabels(),
- ['first column', 'second column', '3rd_col'])
+ self.assertEqual(
+ self.scan1.alllabels(), ["first column", "second column", "3rd_col"]
+ )
def test_data(self):
- self.assertAlmostEqual(self.scan1.dataline(3)[2],
- -3.14)
- self.assertAlmostEqual(self.scan1.datacol(1)[2],
- 3.14)
+ self.assertAlmostEqual(self.scan1.dataline(3)[2], -3.14)
+ self.assertAlmostEqual(self.scan1.datacol(1)[2], 3.14)
# tests for data transposition between original file and .data attr
- self.assertAlmostEqual(self.scan1.data()[2, 0],
- 8)
+ self.assertAlmostEqual(self.scan1.data()[2, 0], 8)
self.assertEqual(self.scan1.data().shape, (3, 4))
self.assertAlmostEqual(numpy.sum(self.scan1.data()), 113.631)
def test_date(self):
- self.assertEqual(self.scan1.date(),
- "Thu Feb 11 09:55:20 2016")
+ self.assertEqual(self.scan1.date(), "Thu Feb 11 09:55:20 2016")
def test_motors(self):
self.assertEqual(len(self.sf.allmotors()), 6)
self.assertEqual(len(self.scan1.allmotorpos()), 6)
- self.assertAlmostEqual(sum(self.scan1.allmotorpos()),
- 223.385912)
- self.assertEqual(self.sf.allmotors()[1], 'MRTSlit UP')
+ self.assertAlmostEqual(sum(self.scan1.allmotorpos()), 223.385912)
+ self.assertEqual(self.sf.allmotors()[1], "MRTSlit UP")
def test_mca(self):
self.assertEqual(self.scan1_2.mca(2)[2], 5)
self.assertEqual(sum(self.scan1_2.mca(3)), 21.7)
def test_mca_header(self):
- self.assertEqual(self.scan1_2.header("CALIB"),
- ["#@CALIB 1 2 3"])
+ self.assertEqual(self.scan1_2.header("CALIB"), ["#@CALIB 1 2 3"])
diff --git a/src/silx/io/test/test_spech5.py b/src/silx/io/test/test_spech5.py
index 456a538..93175f7 100644
--- a/src/silx/io/test/test_spech5.py
+++ b/src/silx/io/test/test_spech5.py
@@ -1,5 +1,5 @@
# /*##########################################################################
-# Copyright (C) 2016-2021 European Synchrotron Radiation Facility
+# Copyright (C) 2016-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -24,7 +24,6 @@
import numpy
import os
import io
-import sys
import tempfile
import unittest
import datetime
@@ -33,7 +32,7 @@ from functools import partial
from silx.utils import testutils
from .. import spech5
-from ..spech5 import (SpecH5, SpecH5Dataset, spec_date_to_iso8601)
+from ..spech5 import SpecH5, SpecH5Dataset, spec_date_to_iso8601
from .. import specfile
import h5py
@@ -118,19 +117,22 @@ class TestSpecDate(unittest.TestCase):
"""
Test of the spec_date_to_iso8601 function.
"""
+
# TODO : time zone tests
# TODO : error cases
@classmethod
def setUpClass(cls):
import locale
+
# FYI : not threadsafe
cls.locale_saved = locale.setlocale(locale.LC_TIME)
- locale.setlocale(locale.LC_TIME, 'C')
+ locale.setlocale(locale.LC_TIME, "C")
@classmethod
def tearDownClass(cls):
import locale
+
# FYI : not threadsafe
locale.setlocale(locale.LC_TIME, cls.locale_saved)
@@ -145,75 +147,64 @@ class TestSpecDate(unittest.TestCase):
self.n_minutes = [0, 9, 42, 59]
self.n_hours = [0, 2, 17, 23]
- self.formats = ['%a %b %d %H:%M:%S %Y', '%a %Y/%m/%d %H:%M:%S']
-
- self.check_date_formats = partial(self.__check_date_formats,
- year=self.n_years[0],
- month=self.n_months[0],
- day=self.n_days[0],
- hour=self.n_hours[0],
- minute=self.n_minutes[0],
- second=self.n_seconds[0],
- msg=None)
-
- def __check_date_formats(self,
- year,
- month,
- day,
- hour,
- minute,
- second,
- msg=None):
+ self.formats = ["%a %b %d %H:%M:%S %Y", "%a %Y/%m/%d %H:%M:%S"]
+
+ self.check_date_formats = partial(
+ self.__check_date_formats,
+ year=self.n_years[0],
+ month=self.n_months[0],
+ day=self.n_days[0],
+ hour=self.n_hours[0],
+ minute=self.n_minutes[0],
+ second=self.n_seconds[0],
+ msg=None,
+ )
+
+ def __check_date_formats(self, year, month, day, hour, minute, second, msg=None):
dt = datetime.datetime(year, month, day, hour, minute, second)
expected_date = dt.isoformat()
for i_fmt, fmt in enumerate(self.formats):
spec_date = dt.strftime(fmt)
iso_date = spec_date_to_iso8601(spec_date)
- self.assertEqual(iso_date,
- expected_date,
- msg='Testing {0}. format={1}. '
- 'Expected "{2}", got "{3} ({4})" (dt={5}).'
- ''.format(msg,
- i_fmt,
- expected_date,
- iso_date,
- spec_date,
- dt))
+ self.assertEqual(
+ iso_date,
+ expected_date,
+ msg="Testing {0}. format={1}. "
+ 'Expected "{2}", got "{3} ({4})" (dt={5}).'
+ "".format(msg, i_fmt, expected_date, iso_date, spec_date, dt),
+ )
def testYearsNominal(self):
for year in self.n_years:
- self.check_date_formats(year=year, msg='year')
+ self.check_date_formats(year=year, msg="year")
def testMonthsNominal(self):
for month in self.n_months:
- self.check_date_formats(month=month, msg='month')
+ self.check_date_formats(month=month, msg="month")
def testDaysNominal(self):
for day in self.n_days:
- self.check_date_formats(day=day, msg='day')
+ self.check_date_formats(day=day, msg="day")
def testHoursNominal(self):
for hour in self.n_hours:
- self.check_date_formats(hour=hour, msg='hour')
+ self.check_date_formats(hour=hour, msg="hour")
def testMinutesNominal(self):
for minute in self.n_minutes:
- self.check_date_formats(minute=minute, msg='minute')
+ self.check_date_formats(minute=minute, msg="minute")
def testSecondsNominal(self):
for second in self.n_seconds:
- self.check_date_formats(second=second, msg='second')
+ self.check_date_formats(second=second, msg="second")
class TestSpecH5(unittest.TestCase):
@classmethod
def setUpClass(cls):
fd, cls.fname = tempfile.mkstemp()
- if sys.version_info < (3, ):
- os.write(fd, sftext)
- else:
- os.write(fd, bytes(sftext, 'ascii'))
+ os.write(fd, bytes(sftext, "ascii"))
os.close(fd)
@classmethod
@@ -251,32 +242,25 @@ class TestSpecH5(unittest.TestCase):
self.assertNotIn("25.2", self.sfh5["/"])
self.assertIn("instrument/positioners/Sslit1 HOff", self.sfh5["/1.1"])
# illegal trailing "/" after dataset name
- self.assertNotIn("instrument/positioners/Sslit1 HOff/",
- self.sfh5["/1.1"])
+ self.assertNotIn("instrument/positioners/Sslit1 HOff/", self.sfh5["/1.1"])
# full path to element in group (OK)
- self.assertIn("/1.1/instrument/positioners/Sslit1 HOff",
- self.sfh5["/1.1/instrument"])
+ self.assertIn(
+ "/1.1/instrument/positioners/Sslit1 HOff", self.sfh5["/1.1/instrument"]
+ )
def testDataColumn(self):
- self.assertAlmostEqual(sum(self.sfh5["/1.2/measurement/duo"]),
- 12.0)
+ self.assertAlmostEqual(sum(self.sfh5["/1.2/measurement/duo"]), 12.0)
self.assertAlmostEqual(
- sum(self.sfh5["1.1"]["measurement"]["MRTSlit UP"]),
- 87.891, places=4)
+ sum(self.sfh5["1.1"]["measurement"]["MRTSlit UP"]), 87.891, places=4
+ )
def testDate(self):
# start time is in Iso8601 format
- self.assertEqual(self.sfh5["/1.1/start_time"],
- u"2016-02-11T09:55:20")
- self.assertEqual(self.sfh5["25.1/start_time"],
- u"2015-03-14T03:53:50")
+ self.assertEqual(self.sfh5["/1.1/start_time"], "2016-02-11T09:55:20")
+ self.assertEqual(self.sfh5["25.1/start_time"], "2015-03-14T03:53:50")
def assertRaisesRegex(self, *args, **kwargs):
- # Python 2 compatibility
- if sys.version_info.major >= 3:
- return super(TestSpecH5, self).assertRaisesRegex(*args, **kwargs)
- else:
- return self.assertRaisesRegexp(*args, **kwargs)
+ return super(TestSpecH5, self).assertRaisesRegex(*args, **kwargs)
def testDatasetInstanceAttr(self):
"""The SpecH5Dataset objects must implement some dummy attributes
@@ -286,26 +270,24 @@ class TestSpecH5(unittest.TestCase):
# error message must be explicit
with self.assertRaisesRegex(
- AttributeError,
- "SpecH5Dataset has no attribute tOTo"):
+ AttributeError, "SpecH5Dataset has no attribute tOTo"
+ ):
dummy = self.sfh5["/1.1/start_time"].tOTo
def testGet(self):
"""Test :meth:`SpecH5Group.get`"""
# default value of param *default* is None
self.assertIsNone(self.sfh5.get("toto"))
- self.assertEqual(self.sfh5["25.1"].get("toto", default=-3),
- -3)
+ self.assertEqual(self.sfh5["25.1"].get("toto", default=-3), -3)
- self.assertEqual(self.sfh5.get("/1.1/start_time", default=-3),
- u"2016-02-11T09:55:20")
+ self.assertEqual(
+ self.sfh5.get("/1.1/start_time", default=-3), "2016-02-11T09:55:20"
+ )
def testGetClass(self):
"""Test :meth:`SpecH5Group.get`"""
- self.assertIs(self.sfh5["1.1"].get("start_time", getclass=True),
- h5py.Dataset)
- self.assertIs(self.sfh5["1.1"].get("instrument", getclass=True),
- h5py.Group)
+ self.assertIs(self.sfh5["1.1"].get("start_time", getclass=True), h5py.Dataset)
+ self.assertIs(self.sfh5["1.1"].get("instrument", getclass=True), h5py.Group)
# spech5 does not define external link, so there is no way
# a group can *get* a SpecH5 class
@@ -322,30 +304,37 @@ class TestSpecH5(unittest.TestCase):
def testGetItemGroup(self):
group = self.sfh5["25.1"]["instrument"]
- self.assertEqual(list(group["positioners"].keys()),
- ["Pslit HGap", "MRTSlit UP", "MRTSlit DOWN",
- "Sslit1 VOff", "Sslit1 HOff", "Sslit1 VGap"])
+ self.assertEqual(
+ list(group["positioners"].keys()),
+ [
+ "Pslit HGap",
+ "MRTSlit UP",
+ "MRTSlit DOWN",
+ "Sslit1 VOff",
+ "Sslit1 HOff",
+ "Sslit1 VGap",
+ ],
+ )
with self.assertRaises(KeyError):
group["Holy Grail"]
def testGetitemSpecH5(self):
- self.assertEqual(self.sfh5["/1.2/instrument/positioners"],
- self.sfh5["1.2"]["instrument"]["positioners"])
+ self.assertEqual(
+ self.sfh5["/1.2/instrument/positioners"],
+ self.sfh5["1.2"]["instrument"]["positioners"],
+ )
def testH5pyClass(self):
"""Test :attr:`h5py_class` returns the corresponding h5py class
(h5py.File, h5py.Group, h5py.Dataset)"""
a_file = self.sfh5
- self.assertIs(a_file.h5py_class,
- h5py.File)
+ self.assertIs(a_file.h5py_class, h5py.File)
a_group = self.sfh5["/1.2/measurement"]
- self.assertIs(a_group.h5py_class,
- h5py.Group)
+ self.assertIs(a_group.h5py_class, h5py.Group)
a_dataset = self.sfh5["/1.1/instrument/positioners/Sslit1 HOff"]
- self.assertIs(a_dataset.h5py_class,
- h5py.Dataset)
+ self.assertIs(a_dataset.h5py_class, h5py.Dataset)
def testHeader(self):
file_header = self.sfh5["/1.2/instrument/specfile/file_header"]
@@ -357,67 +346,79 @@ class TestSpecH5(unittest.TestCase):
self.assertEqual(len(scan_header), 9)
# line 4 of file header
- self.assertEqual(
- file_header[3],
- u"#C imaging User = opid17")
+ self.assertEqual(file_header[3], "#C imaging User = opid17")
# line 4 of scan header
scan_header = self.sfh5["25.1/instrument/specfile/scan_header"]
- self.assertEqual(
- scan_header[3],
- u"#P1 4.74255 6.197579 2.238283")
+ self.assertEqual(scan_header[3], "#P1 4.74255 6.197579 2.238283")
def testLinks(self):
- self.assertTrue(numpy.array_equal(
- self.sfh5["/1.2/measurement/mca_0/data"],
- self.sfh5["/1.2/instrument/mca_0/data"])
+ self.assertTrue(
+ numpy.array_equal(
+ self.sfh5["/1.2/measurement/mca_0/data"],
+ self.sfh5["/1.2/instrument/mca_0/data"],
+ )
+ )
+ self.assertTrue(
+ numpy.array_equal(
+ self.sfh5["/1.2/measurement/mca_0/info/data"],
+ self.sfh5["/1.2/instrument/mca_0/data"],
+ )
)
- self.assertTrue(numpy.array_equal(
- self.sfh5["/1.2/measurement/mca_0/info/data"],
- self.sfh5["/1.2/instrument/mca_0/data"])
+ self.assertTrue(
+ numpy.array_equal(
+ self.sfh5["/1.2/measurement/mca_0/info/channels"],
+ self.sfh5["/1.2/instrument/mca_0/channels"],
+ )
)
- self.assertTrue(numpy.array_equal(
- self.sfh5["/1.2/measurement/mca_0/info/channels"],
- self.sfh5["/1.2/instrument/mca_0/channels"])
+ self.assertEqual(
+ self.sfh5["/1.2/measurement/mca_0/info/"].keys(),
+ self.sfh5["/1.2/instrument/mca_0/"].keys(),
)
- self.assertEqual(self.sfh5["/1.2/measurement/mca_0/info/"].keys(),
- self.sfh5["/1.2/instrument/mca_0/"].keys())
- self.assertEqual(self.sfh5["/1.2/measurement/mca_0/info/preset_time"],
- self.sfh5["/1.2/instrument/mca_0/preset_time"])
- self.assertEqual(self.sfh5["/1.2/measurement/mca_0/info/live_time"],
- self.sfh5["/1.2/instrument/mca_0/live_time"])
- self.assertEqual(self.sfh5["/1.2/measurement/mca_0/info/elapsed_time"],
- self.sfh5["/1.2/instrument/mca_0/elapsed_time"])
+ self.assertEqual(
+ self.sfh5["/1.2/measurement/mca_0/info/preset_time"],
+ self.sfh5["/1.2/instrument/mca_0/preset_time"],
+ )
+ self.assertEqual(
+ self.sfh5["/1.2/measurement/mca_0/info/live_time"],
+ self.sfh5["/1.2/instrument/mca_0/live_time"],
+ )
+ self.assertEqual(
+ self.sfh5["/1.2/measurement/mca_0/info/elapsed_time"],
+ self.sfh5["/1.2/instrument/mca_0/elapsed_time"],
+ )
def testListScanIndices(self):
- self.assertEqual(list(self.sfh5.keys()),
- ["1.1", "25.1", "1.2", "1000.1", "1001.1"])
- self.assertEqual(self.sfh5["1.2"].attrs,
- {"NX_class": "NXentry", })
+ self.assertEqual(
+ list(self.sfh5.keys()), ["1.1", "25.1", "1.2", "1000.1", "1001.1"]
+ )
+ self.assertEqual(
+ self.sfh5["1.2"].attrs,
+ {
+ "NX_class": "NXentry",
+ },
+ )
def testMcaAbsent(self):
def access_absent_mca():
"""This must raise a KeyError, because scan 1.1 has no MCA"""
return self.sfh5["/1.1/measurement/mca_0/"]
+
self.assertRaises(KeyError, access_absent_mca)
def testMcaCalib(self):
mca0_calib = self.sfh5["/1.2/measurement/mca_0/info/calibration"]
mca1_calib = self.sfh5["/1.2/measurement/mca_1/info/calibration"]
- self.assertEqual(mca0_calib.tolist(),
- [1, 2, 3])
+ self.assertEqual(mca0_calib.tolist(), [1, 2, 3])
# calibration is unique in this scan and applies to all analysers
- self.assertEqual(mca0_calib.tolist(),
- mca1_calib.tolist())
+ self.assertEqual(mca0_calib.tolist(), mca1_calib.tolist())
def testMcaChannels(self):
mca0_chann = self.sfh5["/1.2/measurement/mca_0/info/channels"]
mca1_chann = self.sfh5["/1.2/measurement/mca_1/info/channels"]
- self.assertEqual(mca0_chann.tolist(),
- [0, 1, 2])
- self.assertEqual(mca0_chann.tolist(),
- mca1_chann.tolist())
+ self.assertEqual(mca0_chann.tolist(), [0, 1, 2])
+ self.assertEqual(mca0_chann.tolist(), mca1_chann.tolist())
def testMcaCtime(self):
"""Tests for #@CTIME mca header"""
@@ -428,31 +429,26 @@ class TestSpecH5(unittest.TestCase):
mca0_preset_time = self.sfh5["/1.2/instrument/mca_0/preset_time"]
mca1_preset_time = self.sfh5["/1.2/instrument/mca_1/preset_time"]
- self.assertLess(mca0_preset_time - 123.4,
- 10**-5)
+ self.assertLess(mca0_preset_time - 123.4, 10**-5)
# ctime is unique in a this scan and applies to all analysers
- self.assertEqual(mca0_preset_time,
- mca1_preset_time)
+ self.assertEqual(mca0_preset_time, mca1_preset_time)
mca0_live_time = self.sfh5["/1.2/instrument/mca_0/live_time"]
mca1_live_time = self.sfh5["/1.2/instrument/mca_1/live_time"]
- self.assertLess(mca0_live_time - 234.5,
- 10**-5)
- self.assertEqual(mca0_live_time,
- mca1_live_time)
+ self.assertLess(mca0_live_time - 234.5, 10**-5)
+ self.assertEqual(mca0_live_time, mca1_live_time)
mca0_elapsed_time = self.sfh5["/1.2/instrument/mca_0/elapsed_time"]
mca1_elapsed_time = self.sfh5["/1.2/instrument/mca_1/elapsed_time"]
- self.assertLess(mca0_elapsed_time - 345.6,
- 10**-5)
- self.assertEqual(mca0_elapsed_time,
- mca1_elapsed_time)
+ self.assertLess(mca0_elapsed_time - 345.6, 10**-5)
+ self.assertEqual(mca0_elapsed_time, mca1_elapsed_time)
def testMcaData(self):
# sum 1st MCA in scan 1.2 over rows
mca_0_data = self.sfh5["/1.2/measurement/mca_0/data"]
- for summed_row, expected in zip(mca_0_data.sum(axis=1).tolist(),
- [3.0, 12.1, 21.7]):
+ for summed_row, expected in zip(
+ mca_0_data.sum(axis=1).tolist(), [3.0, 12.1, 21.7]
+ ):
self.assertAlmostEqual(summed_row, expected, places=4)
# sum 3rd MCA in scan 1.2 along both axis
@@ -464,11 +460,11 @@ class TestSpecH5(unittest.TestCase):
def testMotorPosition(self):
positioners_group = self.sfh5["/1.1/instrument/positioners"]
# MRTSlit DOWN position is defined in #P0 san header line
- self.assertAlmostEqual(float(positioners_group["MRTSlit DOWN"]),
- 0.87125)
+ self.assertAlmostEqual(float(positioners_group["MRTSlit DOWN"]), 0.87125)
# MRTSlit UP position is defined in first data column
- for a, b in zip(positioners_group["MRTSlit UP"].tolist(),
- [-1.23, 8.478100E+01, 3.14, 1.2]):
+ for a, b in zip(
+ positioners_group["MRTSlit UP"].tolist(), [-1.23, 8.478100e01, 3.14, 1.2]
+ ):
self.assertAlmostEqual(float(a), b, places=4)
def testNumberMcaAnalysers(self):
@@ -476,41 +472,38 @@ class TestSpecH5(unittest.TestCase):
self.assertEqual(len(self.sfh5["1.2"]["measurement"]), 5)
def testTitle(self):
- self.assertEqual(self.sfh5["/25.1/title"],
- u"ascan c3th 1.33245 1.52245 40 0.15")
+ self.assertEqual(
+ self.sfh5["/25.1/title"], "ascan c3th 1.33245 1.52245 40 0.15"
+ )
def testValues(self):
group = self.sfh5["/25.1"]
self.assertTrue(hasattr(group, "values"))
self.assertTrue(callable(group.values))
- self.assertIn(self.sfh5["/25.1/title"],
- self.sfh5["/25.1"].values())
+ self.assertIn(self.sfh5["/25.1/title"], self.sfh5["/25.1"].values())
# visit and visititems ignore links
def testVisit(self):
name_list = []
self.sfh5.visit(name_list.append)
- self.assertIn('1.2/instrument/positioners/Pslit HGap', name_list)
+ self.assertIn("1.2/instrument/positioners/Pslit HGap", name_list)
self.assertIn("1.2/instrument/specfile/scan_header", name_list)
self.assertEqual(len(name_list), 117)
# test also visit of a subgroup, with various group name formats
name_list_leading_and_trailing_slash = []
- self.sfh5['/1.2/instrument/'].visit(name_list_leading_and_trailing_slash.append)
+ self.sfh5["/1.2/instrument/"].visit(name_list_leading_and_trailing_slash.append)
name_list_leading_slash = []
- self.sfh5['/1.2/instrument'].visit(name_list_leading_slash.append)
+ self.sfh5["/1.2/instrument"].visit(name_list_leading_slash.append)
name_list_trailing_slash = []
- self.sfh5['1.2/instrument/'].visit(name_list_trailing_slash.append)
+ self.sfh5["1.2/instrument/"].visit(name_list_trailing_slash.append)
name_list_no_slash = []
- self.sfh5['1.2/instrument'].visit(name_list_no_slash.append)
+ self.sfh5["1.2/instrument"].visit(name_list_no_slash.append)
# no differences expected in the output names
- self.assertEqual(name_list_leading_and_trailing_slash,
- name_list_leading_slash)
- self.assertEqual(name_list_leading_slash,
- name_list_trailing_slash)
- self.assertEqual(name_list_leading_slash,
- name_list_no_slash)
+ self.assertEqual(name_list_leading_and_trailing_slash, name_list_leading_slash)
+ self.assertEqual(name_list_leading_slash, name_list_trailing_slash)
+ self.assertEqual(name_list_leading_slash, name_list_no_slash)
self.assertIn("positioners/Pslit HGap", name_list_no_slash)
self.assertIn("positioners", name_list_no_slash)
@@ -519,32 +512,35 @@ class TestSpecH5(unittest.TestCase):
def func_generator(l):
"""return a function appending names to list l"""
+
def func(name, obj):
if isinstance(obj, SpecH5Dataset):
l.append(name)
+
return func
self.sfh5.visititems(func_generator(dataset_name_list))
- self.assertIn('1.2/instrument/positioners/Pslit HGap', dataset_name_list)
+ self.assertIn("1.2/instrument/positioners/Pslit HGap", dataset_name_list)
self.assertEqual(len(dataset_name_list), 85)
# test also visit of a subgroup, with various group name formats
name_list_leading_and_trailing_slash = []
- self.sfh5['/1.2/instrument/'].visititems(func_generator(name_list_leading_and_trailing_slash))
+ self.sfh5["/1.2/instrument/"].visititems(
+ func_generator(name_list_leading_and_trailing_slash)
+ )
name_list_leading_slash = []
- self.sfh5['/1.2/instrument'].visititems(func_generator(name_list_leading_slash))
+ self.sfh5["/1.2/instrument"].visititems(func_generator(name_list_leading_slash))
name_list_trailing_slash = []
- self.sfh5['1.2/instrument/'].visititems(func_generator(name_list_trailing_slash))
+ self.sfh5["1.2/instrument/"].visititems(
+ func_generator(name_list_trailing_slash)
+ )
name_list_no_slash = []
- self.sfh5['1.2/instrument'].visititems(func_generator(name_list_no_slash))
+ self.sfh5["1.2/instrument"].visititems(func_generator(name_list_no_slash))
# no differences expected in the output names
- self.assertEqual(name_list_leading_and_trailing_slash,
- name_list_leading_slash)
- self.assertEqual(name_list_leading_slash,
- name_list_trailing_slash)
- self.assertEqual(name_list_leading_slash,
- name_list_no_slash)
+ self.assertEqual(name_list_leading_and_trailing_slash, name_list_leading_slash)
+ self.assertEqual(name_list_leading_slash, name_list_trailing_slash)
+ self.assertEqual(name_list_leading_slash, name_list_no_slash)
self.assertIn("positioners/Pslit HGap", name_list_no_slash)
def testNotSpecH5(self):
@@ -609,10 +605,7 @@ class TestSpecH5MultiMca(unittest.TestCase):
@classmethod
def setUpClass(cls):
fd, cls.fname = tempfile.mkstemp(text=False)
- if sys.version_info < (3, ):
- os.write(fd, sftext_multi_mca_headers)
- else:
- os.write(fd, bytes(sftext_multi_mca_headers, 'ascii'))
+ os.write(fd, bytes(sftext_multi_mca_headers, "ascii"))
os.close(fd)
@classmethod
@@ -628,43 +621,32 @@ class TestSpecH5MultiMca(unittest.TestCase):
def testMcaCalib(self):
mca0_calib = self.sfh5["/1.1/measurement/mca_0/info/calibration"]
mca1_calib = self.sfh5["/1.1/measurement/mca_1/info/calibration"]
- self.assertEqual(mca0_calib.tolist(),
- [1, 2, 3])
- self.assertAlmostEqual(sum(mca1_calib.tolist()),
- sum([5.5, 6.6, 7.7]),
- places=5)
+ self.assertEqual(mca0_calib.tolist(), [1, 2, 3])
+ self.assertAlmostEqual(sum(mca1_calib.tolist()), sum([5.5, 6.6, 7.7]), places=5)
def testMcaChannels(self):
mca0_chann = self.sfh5["/1.1/measurement/mca_0/info/channels"]
mca1_chann = self.sfh5["/1.1/measurement/mca_1/info/channels"]
- self.assertEqual(mca0_chann.tolist(),
- [0., 1., 2.])
+ self.assertEqual(mca0_chann.tolist(), [0.0, 1.0, 2.0])
# @CHANN is unique in this scan and applies to all analysers
- self.assertEqual(mca1_chann.tolist(),
- [1., 2., 3.])
+ self.assertEqual(mca1_chann.tolist(), [1.0, 2.0, 3.0])
def testMcaCtime(self):
"""Tests for #@CTIME mca header"""
mca0_preset_time = self.sfh5["/1.1/instrument/mca_0/preset_time"]
mca1_preset_time = self.sfh5["/1.1/instrument/mca_1/preset_time"]
- self.assertLess(mca0_preset_time - 123.4,
- 10**-5)
- self.assertLess(mca1_preset_time - 10,
- 10**-5)
+ self.assertLess(mca0_preset_time - 123.4, 10**-5)
+ self.assertLess(mca1_preset_time - 10, 10**-5)
mca0_live_time = self.sfh5["/1.1/instrument/mca_0/live_time"]
mca1_live_time = self.sfh5["/1.1/instrument/mca_1/live_time"]
- self.assertLess(mca0_live_time - 234.5,
- 10**-5)
- self.assertLess(mca1_live_time - 11,
- 10**-5)
+ self.assertLess(mca0_live_time - 234.5, 10**-5)
+ self.assertLess(mca1_live_time - 11, 10**-5)
mca0_elapsed_time = self.sfh5["/1.1/instrument/mca_0/elapsed_time"]
mca1_elapsed_time = self.sfh5["/1.1/instrument/mca_1/elapsed_time"]
- self.assertLess(mca0_elapsed_time - 345.6,
- 10**-5)
- self.assertLess(mca1_elapsed_time - 12,
- 10**-5)
+ self.assertLess(mca0_elapsed_time - 345.6, 10**-5)
+ self.assertLess(mca1_elapsed_time - 12, 10**-5)
sftext_no_cols = r"""#F C:/DATA\test.mca
@@ -736,13 +718,11 @@ sftext_no_cols = r"""#F C:/DATA\test.mca
class TestSpecH5NoDataCols(unittest.TestCase):
"""Test reading SPEC files with only MCA data"""
+
@classmethod
def setUpClass(cls):
fd, cls.fname = tempfile.mkstemp()
- if sys.version_info < (3, ):
- os.write(fd, sftext_no_cols)
- else:
- os.write(fd, bytes(sftext_no_cols, 'ascii'))
+ os.write(fd, bytes(sftext_no_cols, "ascii"))
os.close(fd)
@classmethod
@@ -757,33 +737,23 @@ class TestSpecH5NoDataCols(unittest.TestCase):
def testScan1(self):
# 1.1: single analyser, single spectrum, 151 channels
- self.assertIn("mca_0",
- self.sfh5["1.1/instrument/"])
- self.assertEqual(self.sfh5["1.1/instrument/mca_0/data"].shape,
- (1, 151))
- self.assertNotIn("mca_1",
- self.sfh5["1.1/instrument/"])
+ self.assertIn("mca_0", self.sfh5["1.1/instrument/"])
+ self.assertEqual(self.sfh5["1.1/instrument/mca_0/data"].shape, (1, 151))
+ self.assertNotIn("mca_1", self.sfh5["1.1/instrument/"])
def testScan2(self):
# 2.1: single analyser, 9 spectra, 3 channels
- self.assertIn("mca_0",
- self.sfh5["2.1/instrument/"])
- self.assertEqual(self.sfh5["2.1/instrument/mca_0/data"].shape,
- (9, 3))
- self.assertNotIn("mca_1",
- self.sfh5["2.1/instrument/"])
+ self.assertIn("mca_0", self.sfh5["2.1/instrument/"])
+ self.assertEqual(self.sfh5["2.1/instrument/mca_0/data"].shape, (9, 3))
+ self.assertNotIn("mca_1", self.sfh5["2.1/instrument/"])
def testScan3(self):
# 3.1: 3 analysers, 3 spectra/analyser, 3 channels
for i in range(3):
- self.assertIn("mca_%d" % i,
- self.sfh5["3.1/instrument/"])
- self.assertEqual(
- self.sfh5["3.1/instrument/mca_%d/data" % i].shape,
- (3, 3))
+ self.assertIn("mca_%d" % i, self.sfh5["3.1/instrument/"])
+ self.assertEqual(self.sfh5["3.1/instrument/mca_%d/data" % i].shape, (3, 3))
- self.assertNotIn("mca_3",
- self.sfh5["3.1/instrument/"])
+ self.assertNotIn("mca_3", self.sfh5["3.1/instrument/"])
sf_text_slash = r"""#F /data/id09/archive/logspecfiles/laue/2016/scan_231_laue_16-11-29.dat
@@ -807,13 +777,11 @@ class TestSpecH5SlashInLabels(unittest.TestCase):
The / character must be substituted with a %
"""
+
@classmethod
def setUpClass(cls):
fd, cls.fname = tempfile.mkstemp()
- if sys.version_info < (3, ):
- os.write(fd, sf_text_slash)
- else:
- os.write(fd, bytes(sf_text_slash, 'ascii'))
+ os.write(fd, bytes(sf_text_slash, "ascii"))
os.close(fd)
@classmethod
@@ -829,66 +797,73 @@ class TestSpecH5SlashInLabels(unittest.TestCase):
def testLabels(self):
"""Ensure `/` is substituted with `%` and
ensure legitimate `%` in names are still working"""
- self.assertEqual(list(self.sfh5["1.1/measurement/"].keys()),
- ["GONY%mm", "PD3%A"])
+ self.assertEqual(
+ list(self.sfh5["1.1/measurement/"].keys()), ["GONY%mm", "PD3%A"]
+ )
# substituted "%"
- self.assertIn("GONY%mm",
- self.sfh5["1.1/measurement/"])
- self.assertNotIn("GONY/mm",
- self.sfh5["1.1/measurement/"])
- self.assertAlmostEqual(self.sfh5["1.1/measurement/GONY%mm"][0],
- -2.015, places=4)
+ self.assertIn("GONY%mm", self.sfh5["1.1/measurement/"])
+ self.assertNotIn("GONY/mm", self.sfh5["1.1/measurement/"])
+ self.assertAlmostEqual(
+ self.sfh5["1.1/measurement/GONY%mm"][0], -2.015, places=4
+ )
# legitimate "%"
- self.assertIn("PD3%A",
- self.sfh5["1.1/measurement/"])
+ self.assertIn("PD3%A", self.sfh5["1.1/measurement/"])
def testMotors(self):
"""Ensure `/` is substituted with `%` and
ensure legitimate `%` in names are still working"""
- self.assertEqual(list(self.sfh5["1.1/instrument/positioners"].keys()),
- ["Pslit%HGap", "MRTSlit%UP"])
+ self.assertEqual(
+ list(self.sfh5["1.1/instrument/positioners"].keys()),
+ ["Pslit%HGap", "MRTSlit%UP"],
+ )
# substituted "%"
- self.assertIn("Pslit%HGap",
- self.sfh5["1.1/instrument/positioners"])
- self.assertNotIn("Pslit/HGap",
- self.sfh5["1.1/instrument/positioners"])
+ self.assertIn("Pslit%HGap", self.sfh5["1.1/instrument/positioners"])
+ self.assertNotIn("Pslit/HGap", self.sfh5["1.1/instrument/positioners"])
self.assertAlmostEqual(
- self.sfh5["1.1/instrument/positioners/Pslit%HGap"],
- 180.005, places=4)
+ self.sfh5["1.1/instrument/positioners/Pslit%HGap"], 180.005, places=4
+ )
# legitimate "%"
- self.assertIn("MRTSlit%UP",
- self.sfh5["1.1/instrument/positioners"])
+ self.assertIn("MRTSlit%UP", self.sfh5["1.1/instrument/positioners"])
def testUnitCellUBMatrix(tmp_path):
"""Test unit cell (#G1) and UB matrix (#G3)"""
file_path = tmp_path / "spec.dat"
- file_path.write_bytes(bytes("""
+ file_path.write_bytes(
+ bytes(
+ """
#S 1 OK
#G1 0 1 2 3 4 5
#G3 0 1 2 3 4 5 6 7 8
-""", encoding="ascii"))
+""",
+ encoding="ascii",
+ )
+ )
with SpecH5(str(file_path)) as spech5:
assert numpy.array_equal(
- spech5["/1.1/sample/ub_matrix"],
- numpy.arange(9).reshape(1, 3, 3))
- assert numpy.array_equal(
- spech5["/1.1/sample/unit_cell"], [[0, 1, 2, 3, 4, 5]])
- assert numpy.array_equal(
- spech5["/1.1/sample/unit_cell_abc"], [0, 1, 2])
+ spech5["/1.1/sample/ub_matrix"], numpy.arange(9).reshape(1, 3, 3)
+ )
+ assert numpy.array_equal(spech5["/1.1/sample/unit_cell"], [[0, 1, 2, 3, 4, 5]])
+ assert numpy.array_equal(spech5["/1.1/sample/unit_cell_abc"], [0, 1, 2])
assert numpy.array_equal(
- spech5["/1.1/sample/unit_cell_alphabetagamma"], [3, 4, 5])
+ spech5["/1.1/sample/unit_cell_alphabetagamma"], [3, 4, 5]
+ )
def testMalformedUnitCellUBMatrix(tmp_path):
"""Test malformed unit cell (#G1) and UB matrix (#G3): 1 value"""
file_path = tmp_path / "spec.dat"
- file_path.write_bytes(bytes("""
+ file_path.write_bytes(
+ bytes(
+ """
#S 1 all malformed=0
#G1 0
#G3 0
-""", encoding="ascii"))
+""",
+ encoding="ascii",
+ )
+ )
with SpecH5(str(file_path)) as spech5:
assert "sample" not in spech5["1.1"]
@@ -896,33 +871,42 @@ def testMalformedUnitCellUBMatrix(tmp_path):
def testMalformedUBMatrix(tmp_path):
"""Test malformed UB matrix (#G3): all zeros"""
file_path = tmp_path / "spec.dat"
- file_path.write_bytes(bytes("""
+ file_path.write_bytes(
+ bytes(
+ """
#S 1 G3 all 0
#G1 0 1 2 3 4 5
#G3 0 0 0 0 0 0 0 0 0
-""", encoding="ascii"))
+""",
+ encoding="ascii",
+ )
+ )
with SpecH5(str(file_path)) as spech5:
assert "ub_matrix" not in spech5["/1.1/sample"]
+ assert numpy.array_equal(spech5["/1.1/sample/unit_cell"], [[0, 1, 2, 3, 4, 5]])
+ assert numpy.array_equal(spech5["/1.1/sample/unit_cell_abc"], [0, 1, 2])
assert numpy.array_equal(
- spech5["/1.1/sample/unit_cell"], [[0, 1, 2, 3, 4, 5]])
- assert numpy.array_equal(
- spech5["/1.1/sample/unit_cell_abc"], [0, 1, 2])
- assert numpy.array_equal(
- spech5["/1.1/sample/unit_cell_alphabetagamma"], [3, 4, 5])
+ spech5["/1.1/sample/unit_cell_alphabetagamma"], [3, 4, 5]
+ )
def testMalformedUnitCell(tmp_path):
"""Test malformed unit cell (#G1): missing values"""
file_path = tmp_path / "spec.dat"
- file_path.write_bytes(bytes("""
+ file_path.write_bytes(
+ bytes(
+ """
#S 1 G1 malformed missing values
#G1 0 1 2
#G3 0 1 2 3 4 5 6 7 8
-""", encoding="ascii"))
+""",
+ encoding="ascii",
+ )
+ )
with SpecH5(str(file_path)) as spech5:
assert "unit_cell" not in spech5["/1.1/sample"]
assert "unit_cell_abc" not in spech5["/1.1/sample"]
assert "unit_cell_alphabetagamma" not in spech5["/1.1/sample"]
assert numpy.array_equal(
- spech5["/1.1/sample/ub_matrix"],
- numpy.arange(9).reshape(1, 3, 3))
+ spech5["/1.1/sample/ub_matrix"], numpy.arange(9).reshape(1, 3, 3)
+ )
diff --git a/src/silx/io/test/test_spectoh5.py b/src/silx/io/test/test_spectoh5.py
index 5465ece..a3426ea 100644
--- a/src/silx/io/test/test_spectoh5.py
+++ b/src/silx/io/test/test_spectoh5.py
@@ -24,7 +24,6 @@
from numpy import array_equal
import os
-import sys
import tempfile
import unittest
@@ -113,53 +112,55 @@ class TestConvertSpecHDF5(unittest.TestCase):
def testAppendToHDF5(self):
write_to_h5(self.sfh5, self.h5f, h5path="/foo/bar/spam")
self.assertTrue(
- array_equal(self.h5f["/1.2/measurement/mca_1/data"],
- self.h5f["/foo/bar/spam/1.2/measurement/mca_1/data"])
+ array_equal(
+ self.h5f["/1.2/measurement/mca_1/data"],
+ self.h5f["/foo/bar/spam/1.2/measurement/mca_1/data"],
+ )
)
def testWriteSpecH5Group(self):
"""Test passing a SpecH5Group as parameter, instead of a Spec filename
or a SpecH5."""
g = self.sfh5["1.1/instrument"]
- self.assertIsInstance(g, SpecH5Group) # let's be paranoid
+ self.assertIsInstance(g, SpecH5Group) # let's be paranoid
write_to_h5(g, self.h5f, h5path="my instruments")
- self.assertAlmostEqual(self.h5f["my instruments/positioners/Sslit1 HOff"][tuple()],
- 16.197579, places=4)
+ self.assertAlmostEqual(
+ self.h5f["my instruments/positioners/Sslit1 HOff"][tuple()],
+ 16.197579,
+ places=4,
+ )
def testTitle(self):
"""Test the value of a dataset"""
title12 = h5py_read_dataset(self.h5f["/1.2/title"])
- self.assertEqual(title12,
- u"aaaaaa")
+ self.assertEqual(title12, "aaaaaa")
def testAttrs(self):
# Test root group (file) attributes
- self.assertEqual(self.h5f.attrs["NX_class"],
- u"NXroot")
+ self.assertEqual(self.h5f.attrs["NX_class"], "NXroot")
# Test dataset attributes
ds = self.h5f["/1.2/instrument/mca_1/data"]
self.assertTrue("interpretation" in ds.attrs)
- self.assertEqual(list(ds.attrs.values()),
- [u"spectrum"])
+ self.assertEqual(list(ds.attrs.values()), ["spectrum"])
# Test group attributes
grp = self.h5f["1.1"]
- self.assertEqual(grp.attrs["NX_class"],
- u"NXentry")
- self.assertEqual(len(list(grp.attrs.keys())),
- 1)
+ self.assertEqual(grp.attrs["NX_class"], "NXentry")
+ self.assertEqual(len(list(grp.attrs.keys())), 1)
def testHdf5HasSameMembers(self):
spec_member_list = []
def append_spec_members(name):
spec_member_list.append(name)
+
self.sfh5.visit(append_spec_members)
hdf5_member_list = []
def append_hdf5_members(name):
hdf5_member_list.append(name)
+
self.h5f.visit(append_hdf5_members)
# 1. For some reason, h5py visit method doesn't include the leading
@@ -168,15 +169,18 @@ class TestConvertSpecHDF5(unittest.TestCase):
# have a leading "/"
spec_member_list = [m.lstrip("/") for m in spec_member_list]
- self.assertEqual(set(hdf5_member_list),
- set(spec_member_list))
+ self.assertEqual(set(hdf5_member_list), set(spec_member_list))
def testLinks(self):
self.assertTrue(
- array_equal(self.sfh5["/1.2/measurement/mca_0/data"],
- self.h5f["/1.2/measurement/mca_0/data"])
+ array_equal(
+ self.sfh5["/1.2/measurement/mca_0/data"],
+ self.h5f["/1.2/measurement/mca_0/data"],
+ )
)
self.assertTrue(
- array_equal(self.h5f["/1.2/instrument/mca_1/channels"],
- self.h5f["/1.2/measurement/mca_1/info/channels"])
+ array_equal(
+ self.h5f["/1.2/instrument/mca_1/channels"],
+ self.h5f["/1.2/measurement/mca_1/info/channels"],
+ )
)
diff --git a/src/silx/io/test/test_url.py b/src/silx/io/test/test_url.py
index 8cbfb34..61f9883 100644
--- a/src/silx/io/test/test_url.py
+++ b/src/silx/io/test/test_url.py
@@ -27,190 +27,274 @@ __license__ = "MIT"
__date__ = "29/01/2018"
-import unittest
+import pytest
from ..url import DataUrl
-class TestDataUrl(unittest.TestCase):
-
- def assertUrl(self, url, expected):
- self.assertEqual(url.is_valid(), expected[0])
- self.assertEqual(url.is_absolute(), expected[1])
- self.assertEqual(url.scheme(), expected[2])
- self.assertEqual(url.file_path(), expected[3])
- self.assertEqual(url.data_path(), expected[4])
- self.assertEqual(url.data_slice(), expected[5])
-
- def test_fabio_absolute(self):
- url = DataUrl("fabio:///data/image.edf?slice=2")
- expected = [True, True, "fabio", "/data/image.edf", None, (2, )]
- self.assertUrl(url, expected)
-
- def test_fabio_absolute_windows(self):
- url = DataUrl("fabio:///C:/data/image.edf?slice=2")
- expected = [True, True, "fabio", "C:/data/image.edf", None, (2, )]
- self.assertUrl(url, expected)
-
- def test_silx_absolute(self):
- url = DataUrl("silx:///data/image.h5?path=/data/dataset&slice=1,5")
- expected = [True, True, "silx", "/data/image.h5", "/data/dataset", (1, 5)]
- self.assertUrl(url, expected)
-
- def test_commandline_shell_separator(self):
- url = DataUrl("silx:///data/image.h5::path=/data/dataset&slice=1,5")
- expected = [True, True, "silx", "/data/image.h5", "/data/dataset", (1, 5)]
- self.assertUrl(url, expected)
-
- def test_silx_absolute2(self):
- url = DataUrl("silx:///data/image.edf?/scan_0/detector/data")
- expected = [True, True, "silx", "/data/image.edf", "/scan_0/detector/data", None]
- self.assertUrl(url, expected)
-
- def test_silx_absolute_windows(self):
- url = DataUrl("silx:///C:/data/image.h5?/scan_0/detector/data")
- expected = [True, True, "silx", "C:/data/image.h5", "/scan_0/detector/data", None]
- self.assertUrl(url, expected)
-
- def test_silx_relative(self):
- url = DataUrl("silx:./image.h5")
- expected = [True, False, "silx", "./image.h5", None, None]
- self.assertUrl(url, expected)
-
- def test_fabio_relative(self):
- url = DataUrl("fabio:./image.edf")
- expected = [True, False, "fabio", "./image.edf", None, None]
- self.assertUrl(url, expected)
-
- def test_silx_relative2(self):
- url = DataUrl("silx:image.h5")
- expected = [True, False, "silx", "image.h5", None, None]
- self.assertUrl(url, expected)
-
- def test_fabio_relative2(self):
- url = DataUrl("fabio:image.edf")
- expected = [True, False, "fabio", "image.edf", None, None]
- self.assertUrl(url, expected)
-
- def test_file_relative(self):
- url = DataUrl("image.edf")
- expected = [True, False, None, "image.edf", None, None]
- self.assertUrl(url, expected)
-
- def test_file_relative2(self):
- url = DataUrl("./foo/bar/image.edf")
- expected = [True, False, None, "./foo/bar/image.edf", None, None]
- self.assertUrl(url, expected)
-
- def test_file_relative3(self):
- url = DataUrl("foo/bar/image.edf")
- expected = [True, False, None, "foo/bar/image.edf", None, None]
- self.assertUrl(url, expected)
-
- def test_file_absolute(self):
- url = DataUrl("/data/image.edf")
- expected = [True, True, None, "/data/image.edf", None, None]
- self.assertUrl(url, expected)
-
- def test_file_absolute_windows(self):
- url = DataUrl("C:/data/image.edf")
- expected = [True, True, None, "C:/data/image.edf", None, None]
- self.assertUrl(url, expected)
-
- def test_absolute_with_path(self):
- url = DataUrl("/foo/foobar.h5?/foo/bar")
- expected = [True, True, None, "/foo/foobar.h5", "/foo/bar", None]
- self.assertUrl(url, expected)
-
- def test_windows_file_data_slice(self):
- url = DataUrl("C:/foo/foobar.h5?path=/foo/bar&slice=5,1")
- expected = [True, True, None, "C:/foo/foobar.h5", "/foo/bar", (5, 1)]
- self.assertUrl(url, expected)
-
- def test_scheme_file_data_slice(self):
- url = DataUrl("silx:/foo/foobar.h5?path=/foo/bar&slice=5,1")
- expected = [True, True, "silx", "/foo/foobar.h5", "/foo/bar", (5, 1)]
- self.assertUrl(url, expected)
-
- def test_scheme_windows_file_data_slice(self):
- url = DataUrl("silx:C:/foo/foobar.h5?path=/foo/bar&slice=5,1")
- expected = [True, True, "silx", "C:/foo/foobar.h5", "/foo/bar", (5, 1)]
- self.assertUrl(url, expected)
-
- def test_empty(self):
- url = DataUrl("")
- expected = [False, False, None, "", None, None]
- self.assertUrl(url, expected)
-
- def test_unknown_scheme(self):
- url = DataUrl("foo:/foo/foobar.h5?path=/foo/bar&slice=5,1")
- expected = [False, True, "foo", "/foo/foobar.h5", "/foo/bar", (5, 1)]
- self.assertUrl(url, expected)
-
- def test_slice(self):
- url = DataUrl("/a.h5?path=/b&slice=5,1")
- expected = [True, True, None, "/a.h5", "/b", (5, 1)]
- self.assertUrl(url, expected)
-
- def test_slice2(self):
- url = DataUrl("/a.h5?path=/b&slice=2:5")
- expected = [True, True, None, "/a.h5", "/b", (slice(2, 5),)]
- self.assertUrl(url, expected)
-
- def test_slice3(self):
- url = DataUrl("/a.h5?path=/b&slice=::2")
- expected = [True, True, None, "/a.h5", "/b", (slice(None, None, 2),)]
- self.assertUrl(url, expected)
-
- def test_slice_ellipsis(self):
- url = DataUrl("/a.h5?path=/b&slice=...")
- expected = [True, True, None, "/a.h5", "/b", (Ellipsis, )]
- self.assertUrl(url, expected)
-
- def test_slice_slicing(self):
- url = DataUrl("/a.h5?path=/b&slice=:")
- expected = [True, True, None, "/a.h5", "/b", (slice(None), )]
- self.assertUrl(url, expected)
-
- def test_slice_missing_element(self):
- url = DataUrl("/a.h5?path=/b&slice=5,,1")
- expected = [False, True, None, "/a.h5", "/b", None]
- self.assertUrl(url, expected)
-
- def test_slice_no_elements(self):
- url = DataUrl("/a.h5?path=/b&slice=")
- expected = [False, True, None, "/a.h5", "/b", None]
- self.assertUrl(url, expected)
-
- def test_create_relative_url(self):
- url = DataUrl(scheme="silx", file_path="./foo.h5", data_path="/", data_slice=(5, 1))
- self.assertFalse(url.is_absolute())
- url2 = DataUrl(url.path())
- self.assertEqual(url, url2)
-
- def test_create_absolute_url(self):
- url = DataUrl(scheme="silx", file_path="/foo.h5", data_path="/", data_slice=(5, 1))
- url2 = DataUrl(url.path())
- self.assertEqual(url, url2)
-
- def test_create_absolute_windows_url(self):
- url = DataUrl(scheme="silx", file_path="C:/foo.h5", data_path="/", data_slice=(5, 1))
- url2 = DataUrl(url.path())
- self.assertEqual(url, url2)
-
- def test_create_slice_url(self):
- url = DataUrl(scheme="silx", file_path="/foo.h5", data_path="/", data_slice=(5, 1, Ellipsis, slice(None)))
- url2 = DataUrl(url.path())
- self.assertEqual(url, url2)
-
- def test_wrong_url(self):
- url = DataUrl(scheme="silx", file_path="/foo.h5", data_slice=(5, 1))
- self.assertFalse(url.is_valid())
-
- def test_path_creation(self):
- """make sure the construction of path succeed and that we can
- recreate a DataUrl from a path"""
- for data_slice in (1, (1,)):
- with self.subTest(data_slice=data_slice):
- url = DataUrl(scheme="silx", file_path="/foo.h5", data_slice=data_slice)
- path = url.path()
- DataUrl(path=path)
+def assert_url(url, expected):
+ assert url.is_valid() == expected[0]
+ assert url.is_absolute() == expected[1]
+ assert url.scheme() == expected[2]
+ assert url.file_path() == expected[3]
+ assert url.data_path() == expected[4]
+ assert url.data_slice() == expected[5]
+
+
+def test_fabio_absolute():
+ url = DataUrl("fabio:///data/image.edf?slice=2")
+ expected = [True, True, "fabio", "/data/image.edf", None, (2,)]
+ assert_url(url, expected)
+
+
+def test_fabio_absolute_windows():
+ url = DataUrl("fabio:///C:/data/image.edf?slice=2")
+ expected = [True, True, "fabio", "C:/data/image.edf", None, (2,)]
+ assert_url(url, expected)
+
+
+def test_silx_absolute():
+ url = DataUrl("silx:///data/image.h5?path=/data/dataset&slice=1,5")
+ expected = [True, True, "silx", "/data/image.h5", "/data/dataset", (1, 5)]
+ assert_url(url, expected)
+
+
+def test_commandline_shell_separator():
+ url = DataUrl("silx:///data/image.h5::path=/data/dataset&slice=1,5")
+ expected = [True, True, "silx", "/data/image.h5", "/data/dataset", (1, 5)]
+ assert_url(url, expected)
+
+
+def test_silx_absolute2():
+ url = DataUrl("silx:///data/image.edf?/scan_0/detector/data")
+ expected = [True, True, "silx", "/data/image.edf", "/scan_0/detector/data", None]
+ assert_url(url, expected)
+
+
+def test_silx_absolute_windows():
+ url = DataUrl("silx:///C:/data/image.h5?/scan_0/detector/data")
+ expected = [True, True, "silx", "C:/data/image.h5", "/scan_0/detector/data", None]
+ assert_url(url, expected)
+
+
+def test_silx_relative():
+ url = DataUrl("silx:./image.h5")
+ expected = [True, False, "silx", "./image.h5", None, None]
+ assert_url(url, expected)
+
+
+def test_fabio_relative():
+ url = DataUrl("fabio:./image.edf")
+ expected = [True, False, "fabio", "./image.edf", None, None]
+ assert_url(url, expected)
+
+
+def test_silx_relative2():
+ url = DataUrl("silx:image.h5")
+ expected = [True, False, "silx", "image.h5", None, None]
+ assert_url(url, expected)
+
+
+def test_fabio_relative2():
+ url = DataUrl("fabio:image.edf")
+ expected = [True, False, "fabio", "image.edf", None, None]
+ assert_url(url, expected)
+
+
+def test_file_relative():
+ url = DataUrl("image.edf")
+ expected = [True, False, None, "image.edf", None, None]
+ assert_url(url, expected)
+
+
+def test_file_relative2():
+ url = DataUrl("./foo/bar/image.edf")
+ expected = [True, False, None, "./foo/bar/image.edf", None, None]
+ assert_url(url, expected)
+
+
+def test_file_relative3():
+ url = DataUrl("foo/bar/image.edf")
+ expected = [True, False, None, "foo/bar/image.edf", None, None]
+ assert_url(url, expected)
+
+
+def test_file_absolute():
+ url = DataUrl("/data/image.edf")
+ expected = [True, True, None, "/data/image.edf", None, None]
+ assert_url(url, expected)
+
+
+def test_file_absolute_windows():
+ url = DataUrl("C:/data/image.edf")
+ expected = [True, True, None, "C:/data/image.edf", None, None]
+ assert_url(url, expected)
+
+
+def test_absolute_with_path():
+ url = DataUrl("/foo/foobar.h5?/foo/bar")
+ expected = [True, True, None, "/foo/foobar.h5", "/foo/bar", None]
+ assert_url(url, expected)
+
+
+def test_windows_file_data_slice():
+ url = DataUrl("C:/foo/foobar.h5?path=/foo/bar&slice=5,1")
+ expected = [True, True, None, "C:/foo/foobar.h5", "/foo/bar", (5, 1)]
+ assert_url(url, expected)
+
+
+def test_scheme_file_data_slice():
+ url = DataUrl("silx:/foo/foobar.h5?path=/foo/bar&slice=5,1")
+ expected = [True, True, "silx", "/foo/foobar.h5", "/foo/bar", (5, 1)]
+ assert_url(url, expected)
+
+
+def test_scheme_windows_file_data_slice():
+ url = DataUrl("silx:C:/foo/foobar.h5?path=/foo/bar&slice=5,1")
+ expected = [True, True, "silx", "C:/foo/foobar.h5", "/foo/bar", (5, 1)]
+ assert_url(url, expected)
+
+
+def test_empty():
+ url = DataUrl("")
+ expected = [False, False, None, "", None, None]
+ assert_url(url, expected)
+
+
+def test_unknown_scheme():
+ url = DataUrl("foo:/foo/foobar.h5?path=/foo/bar&slice=5,1")
+ expected = [False, True, "foo", "/foo/foobar.h5", "/foo/bar", (5, 1)]
+ assert_url(url, expected)
+
+
+def test_slice():
+ url = DataUrl("/a.h5?path=/b&slice=5,1")
+ expected = [True, True, None, "/a.h5", "/b", (5, 1)]
+ assert_url(url, expected)
+
+
+def test_slice2():
+ url = DataUrl("/a.h5?path=/b&slice=2:5")
+ expected = [True, True, None, "/a.h5", "/b", (slice(2, 5),)]
+ assert_url(url, expected)
+
+
+def test_slice3():
+ url = DataUrl("/a.h5?path=/b&slice=::2")
+ expected = [True, True, None, "/a.h5", "/b", (slice(None, None, 2),)]
+ assert_url(url, expected)
+
+
+def test_slice_ellipsis():
+ url = DataUrl("/a.h5?path=/b&slice=...")
+ expected = [True, True, None, "/a.h5", "/b", (Ellipsis,)]
+ assert_url(url, expected)
+
+
+def test_slice_slicing():
+ url = DataUrl("/a.h5?path=/b&slice=:")
+ expected = [True, True, None, "/a.h5", "/b", (slice(None),)]
+ assert_url(url, expected)
+
+
+def test_slice_missing_element():
+ url = DataUrl("/a.h5?path=/b&slice=5,,1")
+ expected = [False, True, None, "/a.h5", "/b", None]
+ assert_url(url, expected)
+
+
+def test_slice_no_elements():
+ url = DataUrl("/a.h5?path=/b&slice=")
+ expected = [False, True, None, "/a.h5", "/b", None]
+ assert_url(url, expected)
+
+
+def test_create_relative_url():
+ url = DataUrl(scheme="silx", file_path="./foo.h5", data_path="/", data_slice=(5, 1))
+ assert not url.is_absolute()
+ url2 = DataUrl(url.path())
+ assert url == url2
+
+
+def test_create_absolute_url():
+ url = DataUrl(scheme="silx", file_path="/foo.h5", data_path="/", data_slice=(5, 1))
+ url2 = DataUrl(url.path())
+ assert url == url2
+
+
+def test_create_absolute_windows_url():
+ url = DataUrl(
+ scheme="silx", file_path="C:/foo.h5", data_path="/", data_slice=(5, 1)
+ )
+ url2 = DataUrl(url.path())
+ assert url == url2
+
+
+def test_create_slice_url():
+ url = DataUrl(
+ scheme="silx",
+ file_path="/foo.h5",
+ data_path="/",
+ data_slice=(5, 1, Ellipsis, slice(None)),
+ )
+ url2 = DataUrl(url.path())
+ assert url == url2
+
+
+def test_wrong_url():
+ url = DataUrl(scheme="silx", file_path="/foo.h5", data_slice=(5, 1))
+ assert not url.is_valid()
+
+
+@pytest.mark.parametrize(
+ "data",
+ [
+ (1, "silx:///foo.h5?slice=1"),
+ ((1,), "silx:///foo.h5?slice=1"),
+ (slice(None), "silx:///foo.h5?slice=:"),
+ (slice(1, None), "silx:///foo.h5?slice=1:"),
+ (slice(None, -2), "silx:///foo.h5?slice=:-2"),
+ (slice(1, None, 3), "silx:///foo.h5?slice=1::3"),
+ (slice(None, 2, 3), "silx:///foo.h5?slice=:2:3"),
+ (slice(None, None, 3), "silx:///foo.h5?slice=::3"),
+ (slice(1, 2, 3), "silx:///foo.h5?slice=1:2:3"),
+ ((1, slice(1, 2)), "silx:///foo.h5?slice=1,1:2"),
+ ],
+)
+def test_path_creation(data):
+ """make sure the construction of path succeed and that we can
+ recreate a DataUrl from a path"""
+ data_slice, expected_path = data
+ url = DataUrl(scheme="silx", file_path="/foo.h5", data_slice=data_slice)
+ path = url.path()
+ DataUrl(path=path)
+ assert path == expected_path
+
+
+def test_file_path_none():
+ """
+ make sure a file path can be None
+ """
+ url = DataUrl(scheme="silx", file_path=None, data_path="/path/to/data")
+ assert url.file_path() is None
+ assert url.scheme() == "silx"
+ assert url.data_path() == "/path/to/data"
+
+
+def test_data_path_none():
+ """
+ make sure a data path can be None
+ """
+ url = DataUrl(scheme="silx", file_path="my_file.hdf5", data_path=None)
+ assert url.file_path() == "my_file.hdf5"
+ assert url.scheme() == "silx"
+ assert url.data_path() is None
+
+
+def test_scheme_none():
+ """
+ make sure a scheme can be None
+ """
+ url = DataUrl(scheme=None, file_path="my_file.hdf5", data_path="/path/to/data")
+ assert url.file_path() == "my_file.hdf5"
+ assert url.scheme() is None
+ assert url.data_path() == "/path/to/data"
diff --git a/src/silx/io/test/test_utils.py b/src/silx/io/test/test_utils.py
index b9fc3ab..a9c7f6a 100644
--- a/src/silx/io/test/test_utils.py
+++ b/src/silx/io/test/test_utils.py
@@ -57,7 +57,9 @@ expected_spec1 = r"""#F .*
3 6\.00
"""
-expected_spec2 = expected_spec1 + r"""
+expected_spec2 = (
+ expected_spec1
+ + r"""
#S 2 Ordinate2
#D .*
#N 2
@@ -66,6 +68,7 @@ expected_spec2 = expected_spec1 + r"""
2 8\.00
3 9\.00
"""
+)
expected_spec2reg = r"""#F .*
#D .*
@@ -79,7 +82,9 @@ expected_spec2reg = r"""#F .*
3 6\.00 9\.00
"""
-expected_spec2irr = expected_spec1 + r"""
+expected_spec2irr = (
+ expected_spec1
+ + r"""
#S 2 Ordinate2
#D .*
#N 2
@@ -87,6 +92,7 @@ expected_spec2irr = expected_spec1 + r"""
1 7\.00
2 8\.00
"""
+)
expected_csv = r"""Abscissa;Ordinate1;Ordinate2
1;4\.00;7\.00e\+00
@@ -102,8 +108,7 @@ expected_csv2 = r"""x;y0;y1
class TestSave(unittest.TestCase):
- """Test saving curves as SpecFile:
- """
+ """Test saving curves as SpecFile:"""
def setUp(self):
self.tempdir = tempfile.mkdtemp()
@@ -127,10 +132,17 @@ class TestSave(unittest.TestCase):
shutil.rmtree(self.tempdir)
def test_save_csv(self):
- utils.save1D(self.csv_fname, self.x, self.y,
- xlabel=self.xlab, ylabels=self.ylabs,
- filetype="csv", fmt=["%d", "%.2f", "%.2e"],
- csvdelim=";", autoheader=True)
+ utils.save1D(
+ self.csv_fname,
+ self.x,
+ self.y,
+ xlabel=self.xlab,
+ ylabels=self.ylabs,
+ filetype="csv",
+ fmt=["%d", "%.2f", "%.2e"],
+ csvdelim=";",
+ autoheader=True,
+ )
csvf = open(self.csv_fname)
actual_csv = csvf.read()
@@ -142,21 +154,28 @@ class TestSave(unittest.TestCase):
"""npy file is saved with numpy.save after building a numpy array
and converting it to a named record array"""
npyf = open(self.npy_fname, "wb")
- utils.save1D(npyf, self.x, self.y,
- xlabel=self.xlab, ylabels=self.ylabs)
+ utils.save1D(npyf, self.x, self.y, xlabel=self.xlab, ylabels=self.ylabs)
npyf.close()
npy_recarray = numpy.load(self.npy_fname)
self.assertEqual(npy_recarray.shape, (3,))
- self.assertTrue(numpy.array_equal(npy_recarray['Ordinate1'],
- numpy.array((4, 5, 6))))
+ self.assertTrue(
+ numpy.array_equal(npy_recarray["Ordinate1"], numpy.array((4, 5, 6)))
+ )
def test_savespec_filename(self):
"""Save SpecFile using savespec()"""
- utils.savespec(self.spec_fname, self.x, self.y[0], xlabel=self.xlab,
- ylabel=self.ylabs[0], fmt=["%d", "%.2f"],
- close_file=True, scan_number=1)
+ utils.savespec(
+ self.spec_fname,
+ self.x,
+ self.y[0],
+ xlabel=self.xlab,
+ ylabel=self.ylabs[0],
+ fmt=["%d", "%.2f"],
+ close_file=True,
+ scan_number=1,
+ )
specf = open(self.spec_fname)
actual_spec = specf.read()
@@ -167,15 +186,28 @@ class TestSave(unittest.TestCase):
"""Save SpecFile using savespec(), passing a file handle"""
# first savespec: open, write file header, save y[0] as scan 1,
# return file handle
- specf = utils.savespec(self.spec_fname, self.x, self.y[0],
- xlabel=self.xlab, ylabel=self.ylabs[0],
- fmt=["%d", "%.2f"], close_file=False)
+ specf = utils.savespec(
+ self.spec_fname,
+ self.x,
+ self.y[0],
+ xlabel=self.xlab,
+ ylabel=self.ylabs[0],
+ fmt=["%d", "%.2f"],
+ close_file=False,
+ )
# second savespec: save y[1] as scan 2, close file
- utils.savespec(specf, self.x, self.y[1], xlabel=self.xlab,
- ylabel=self.ylabs[1], fmt=["%d", "%.2f"],
- write_file_header=False, close_file=True,
- scan_number=2)
+ utils.savespec(
+ specf,
+ self.x,
+ self.y[1],
+ xlabel=self.xlab,
+ ylabel=self.ylabs[1],
+ fmt=["%d", "%.2f"],
+ write_file_header=False,
+ close_file=True,
+ scan_number=2,
+ )
specf = open(self.spec_fname)
actual_spec = specf.read()
@@ -184,8 +216,15 @@ class TestSave(unittest.TestCase):
def test_save_spec_reg(self):
"""Save SpecFile using save() on a regular pattern"""
- utils.save1D(self.spec_fname, self.x, self.y, xlabel=self.xlab,
- ylabels=self.ylabs, filetype="spec", fmt=["%d", "%.2f"])
+ utils.save1D(
+ self.spec_fname,
+ self.x,
+ self.y,
+ xlabel=self.xlab,
+ ylabels=self.ylabs,
+ filetype="spec",
+ fmt=["%d", "%.2f"],
+ )
specf = open(self.spec_fname)
actual_spec = specf.read()
@@ -197,8 +236,15 @@ class TestSave(unittest.TestCase):
"""Save SpecFile using save() on an irregular pattern"""
# invalid test case ?!
return
- utils.save1D(self.spec_fname, self.x, self.y_irr, xlabel=self.xlab,
- ylabels=self.ylabs, filetype="spec", fmt=["%d", "%.2f"])
+ utils.save1D(
+ self.spec_fname,
+ self.x,
+ self.y_irr,
+ xlabel=self.xlab,
+ ylabels=self.ylabs,
+ filetype="spec",
+ fmt=["%d", "%.2f"],
+ )
specf = open(self.spec_fname)
actual_spec = specf.read()
@@ -218,8 +264,9 @@ class TestSave(unittest.TestCase):
self.xlab = "Abscissa"
self.y = [[4, 5, 6], [7, 8, 9]]
self.ylabs = ["Ordinate1", "Ordinate2"]
- utils.save1D(self.csv_fname, self.x, self.y,
- autoheader=True, fmt=["%d", "%.2f", "%.2e"])
+ utils.save1D(
+ self.csv_fname, self.x, self.y, autoheader=True, fmt=["%d", "%.2f", "%.2e"]
+ )
csvf = open(self.csv_fname)
actual_csv = csvf.read()
@@ -237,11 +284,11 @@ def assert_match_any_string_in_list(test, pattern, list_of_strings):
class TestH5Ls(unittest.TestCase):
"""Test displaying the following HDF5 file structure:
- +foo
- +bar
- <HDF5 dataset "spam": shape (2, 2), type "<i8">
- <HDF5 dataset "tmp": shape (3,), type "<i8">
- <HDF5 dataset "data": shape (1,), type "<f8">
+ +foo
+ +bar
+ <HDF5 dataset "spam": shape (2, 2), type "<i8">
+ <HDF5 dataset "tmp": shape (3,), type "<i8">
+ <HDF5 dataset "data": shape (1,), type "<f8">
"""
@@ -249,8 +296,11 @@ class TestH5Ls(unittest.TestCase):
for string_ in list_of_strings:
if re.match(pattern, string_):
return None
- raise AssertionError("regex pattern %s does not match any" % pattern +
- " string in list " + str(list_of_strings))
+ raise AssertionError(
+ "regex pattern %s does not match any" % pattern
+ + " string in list "
+ + str(list_of_strings)
+ )
def testHdf5(self):
fd, self.h5_fname = tempfile.mkstemp(text=False)
@@ -269,11 +319,11 @@ class TestH5Ls(unittest.TestCase):
self.assertIn("+foo", lines)
self.assertIn("\t+bar", lines)
- match = r'\t\t<HDF5 dataset "tmp": shape \(3,\), type "<i[48]">'
+ match = r'\t\t<HDF5 dataset "tmp": shape \(3,\), type "[<>]i[48]">'
self.assertMatchAnyStringInList(match, lines)
- match = r'\t\t<HDF5 dataset "spam": shape \(2, 2\), type "<i[48]">'
+ match = r'\t\t<HDF5 dataset "spam": shape \(2, 2\), type "[<>]i[48]">'
self.assertMatchAnyStringInList(match, lines)
- match = r'\t<HDF5 dataset "data": shape \(1,\), type "<f[48]">'
+ match = r'\t<HDF5 dataset "data": shape \(1,\), type "[<>]f[48]">'
self.assertMatchAnyStringInList(match, lines)
os.unlink(self.h5_fname)
@@ -321,15 +371,22 @@ class TestOpen(unittest.TestCase):
@classmethod
def createResources(cls, directory):
-
cls.h5_filename = os.path.join(directory, "test.h5")
h5 = h5py.File(cls.h5_filename, mode="w")
h5["group/group/dataset"] = 50
h5.close()
cls.spec_filename = os.path.join(directory, "test.dat")
- utils.savespec(cls.spec_filename, [1], [1.1], xlabel="x", ylabel="y",
- fmt=["%d", "%.2f"], close_file=True, scan_number=1)
+ utils.savespec(
+ cls.spec_filename,
+ [1],
+ [1.1],
+ xlabel="x",
+ ylabel="y",
+ fmt=["%d", "%.2f"],
+ close_file=True,
+ scan_number=1,
+ )
cls.edf_filename = os.path.join(directory, "test.edf")
header = fabio.fabioimage.OrderedDict()
@@ -340,7 +397,7 @@ class TestOpen(unittest.TestCase):
cls.txt_filename = os.path.join(directory, "test.txt")
f = io.open(cls.txt_filename, "w+t")
- f.write(u"Kikoo")
+ f.write("Kikoo")
f.close()
cls.missing_filename = os.path.join(directory, "test.missing")
@@ -403,7 +460,9 @@ class TestOpen(unittest.TestCase):
self.assertRaises(IOError, utils.open, self.missing_filename)
def test_silx_scheme(self):
- url = silx.io.url.DataUrl(scheme="silx", file_path=self.h5_filename, data_path="/")
+ url = silx.io.url.DataUrl(
+ scheme="silx", file_path=self.h5_filename, data_path="/"
+ )
with utils.open(url.path()) as f:
self.assertIsNotNone(f)
self.assertTrue(silx.io.utils.is_file(f))
@@ -446,9 +505,7 @@ class TestNodes(unittest.TestCase):
os.unlink(name)
def test_h5py_like_file(self):
-
class Foo(object):
-
def __init__(self):
self.h5_class = utils.H5Type.FILE
@@ -458,9 +515,7 @@ class TestNodes(unittest.TestCase):
self.assertFalse(utils.is_dataset(obj))
def test_h5py_like_group(self):
-
class Foo(object):
-
def __init__(self):
self.h5_class = utils.H5Type.GROUP
@@ -470,9 +525,7 @@ class TestNodes(unittest.TestCase):
self.assertFalse(utils.is_dataset(obj))
def test_h5py_like_dataset(self):
-
class Foo(object):
-
def __init__(self):
self.h5_class = utils.H5Type.DATASET
@@ -482,9 +535,7 @@ class TestNodes(unittest.TestCase):
self.assertTrue(utils.is_dataset(obj))
def test_bad(self):
-
class Foo(object):
-
def __init__(self):
pass
@@ -494,9 +545,7 @@ class TestNodes(unittest.TestCase):
self.assertFalse(utils.is_dataset(obj))
def test_bad_api(self):
-
class Foo(object):
-
def __init__(self):
self.h5_class = int
@@ -516,7 +565,6 @@ class TestGetData(unittest.TestCase):
@classmethod
def createResources(cls, directory):
-
cls.h5_filename = os.path.join(directory, "test.h5")
h5 = h5py.File(cls.h5_filename, mode="w")
h5["group/group/scalar"] = 50
@@ -525,8 +573,16 @@ class TestGetData(unittest.TestCase):
h5.close()
cls.spec_filename = os.path.join(directory, "test.dat")
- utils.savespec(cls.spec_filename, [1], [1.1], xlabel="x", ylabel="y",
- fmt=["%d", "%.2f"], close_file=True, scan_number=1)
+ utils.savespec(
+ cls.spec_filename,
+ [1],
+ [1.1],
+ xlabel="x",
+ ylabel="y",
+ fmt=["%d", "%.2f"],
+ close_file=True,
+ scan_number=1,
+ )
cls.edf_filename = os.path.join(directory, "test.edf")
cls.edf_multiframe_filename = os.path.join(directory, "test_multi.edf")
@@ -540,7 +596,7 @@ class TestGetData(unittest.TestCase):
cls.txt_filename = os.path.join(directory, "test.txt")
f = io.open(cls.txt_filename, "w+t")
- f.write(u"Kikoo")
+ f.write("Kikoo")
f.close()
cls.missing_filename = os.path.join(directory, "test.missing")
@@ -614,109 +670,150 @@ class TestGetData(unittest.TestCase):
def _h5_py_version_older_than(version):
- v_majeur, v_mineur, v_micro = [int(i) for i in h5py.version.version.split('.')[:3]]
- r_majeur, r_mineur, r_micro = [int(i) for i in version.split('.')]
- return calc_hexversion(v_majeur, v_mineur, v_micro) >= calc_hexversion(r_majeur, r_mineur, r_micro)
+ v_majeur, v_mineur, v_micro = [int(i) for i in h5py.version.version.split(".")[:3]]
+ r_majeur, r_mineur, r_micro = [int(i) for i in version.split(".")]
+ return calc_hexversion(v_majeur, v_mineur, v_micro) >= calc_hexversion(
+ r_majeur, r_mineur, r_micro
+ )
-@unittest.skipUnless(_h5_py_version_older_than('2.9.0'), 'h5py version < 2.9.0')
+@unittest.skipUnless(_h5_py_version_older_than("2.9.0"), "h5py version < 2.9.0")
class TestRawFileToH5(unittest.TestCase):
"""Test conversion of .vol file to .h5 external dataset"""
def setUp(self):
self.tempdir = tempfile.mkdtemp()
- self._vol_file = os.path.join(self.tempdir, 'test_vol.vol')
- self._file_info = os.path.join(self.tempdir, 'test_vol.info.vol')
+ self._vol_file = os.path.join(self.tempdir, "test_vol.vol")
+ self._file_info = os.path.join(self.tempdir, "test_vol.info.vol")
self._dataset_shape = 100, 20, 5
- data = numpy.random.random(self._dataset_shape[0] *
- self._dataset_shape[1] *
- self._dataset_shape[2]).astype(dtype=numpy.float32).reshape(self._dataset_shape)
+ data = (
+ numpy.random.random(
+ self._dataset_shape[0] * self._dataset_shape[1] * self._dataset_shape[2]
+ )
+ .astype(dtype=numpy.float32)
+ .reshape(self._dataset_shape)
+ )
numpy.save(file=self._vol_file, arr=data)
# those are storing into .noz file
- assert os.path.exists(self._vol_file + '.npy')
- os.rename(self._vol_file + '.npy', self._vol_file)
- self.h5_file = os.path.join(self.tempdir, 'test_h5.h5')
- self.external_dataset_path = '/root/my_external_dataset'
- self._data_url = silx.io.url.DataUrl(file_path=self.h5_file,
- data_path=self.external_dataset_path)
- with open(self._file_info, 'w') as _fi:
- _fi.write('NUM_X = %s\n' % self._dataset_shape[2])
- _fi.write('NUM_Y = %s\n' % self._dataset_shape[1])
- _fi.write('NUM_Z = %s\n' % self._dataset_shape[0])
+ assert os.path.exists(self._vol_file + ".npy")
+ os.rename(self._vol_file + ".npy", self._vol_file)
+ self.h5_file = os.path.join(self.tempdir, "test_h5.h5")
+ self.external_dataset_path = "/root/my_external_dataset"
+ self._data_url = silx.io.url.DataUrl(
+ file_path=self.h5_file, data_path=self.external_dataset_path
+ )
+ with open(self._file_info, "w") as _fi:
+ _fi.write("NUM_X = %s\n" % self._dataset_shape[2])
+ _fi.write("NUM_Y = %s\n" % self._dataset_shape[1])
+ _fi.write("NUM_Z = %s\n" % self._dataset_shape[0])
def tearDown(self):
shutil.rmtree(self.tempdir)
def check_dataset(self, h5_file, data_path, shape):
"""Make sure the external dataset is valid"""
- with h5py.File(h5_file, 'r') as _file:
+ with h5py.File(h5_file, "r") as _file:
return data_path in _file and _file[data_path].shape == shape
def test_h5_file_not_existing(self):
"""Test that can create a file with external dataset from scratch"""
- utils.rawfile_to_h5_external_dataset(bin_file=self._vol_file,
- output_url=self._data_url,
- shape=(100, 20, 5),
- dtype=numpy.float32)
- self.assertTrue(self.check_dataset(h5_file=self.h5_file,
- data_path=self.external_dataset_path,
- shape=self._dataset_shape))
+ utils.rawfile_to_h5_external_dataset(
+ bin_file=self._vol_file,
+ output_url=self._data_url,
+ shape=(100, 20, 5),
+ dtype=numpy.float32,
+ )
+ self.assertTrue(
+ self.check_dataset(
+ h5_file=self.h5_file,
+ data_path=self.external_dataset_path,
+ shape=self._dataset_shape,
+ )
+ )
os.remove(self.h5_file)
- utils.vol_to_h5_external_dataset(vol_file=self._vol_file,
- output_url=self._data_url,
- info_file=self._file_info)
- self.assertTrue(self.check_dataset(h5_file=self.h5_file,
- data_path=self.external_dataset_path,
- shape=self._dataset_shape))
+ utils.vol_to_h5_external_dataset(
+ vol_file=self._vol_file,
+ output_url=self._data_url,
+ info_file=self._file_info,
+ )
+ self.assertTrue(
+ self.check_dataset(
+ h5_file=self.h5_file,
+ data_path=self.external_dataset_path,
+ shape=self._dataset_shape,
+ )
+ )
def test_h5_file_existing(self):
"""Test that can add the external dataset from an existing file"""
- with h5py.File(self.h5_file, 'w') as _file:
- _file['/root/dataset1'] = numpy.zeros((100, 100))
- _file['/root/group/dataset2'] = numpy.ones((100, 100))
- utils.rawfile_to_h5_external_dataset(bin_file=self._vol_file,
- output_url=self._data_url,
- shape=(100, 20, 5),
- dtype=numpy.float32)
- self.assertTrue(self.check_dataset(h5_file=self.h5_file,
- data_path=self.external_dataset_path,
- shape=self._dataset_shape))
+ with h5py.File(self.h5_file, "w") as _file:
+ _file["/root/dataset1"] = numpy.zeros((100, 100))
+ _file["/root/group/dataset2"] = numpy.ones((100, 100))
+ utils.rawfile_to_h5_external_dataset(
+ bin_file=self._vol_file,
+ output_url=self._data_url,
+ shape=(100, 20, 5),
+ dtype=numpy.float32,
+ )
+ self.assertTrue(
+ self.check_dataset(
+ h5_file=self.h5_file,
+ data_path=self.external_dataset_path,
+ shape=self._dataset_shape,
+ )
+ )
def test_vol_file_not_existing(self):
"""Make sure error is raised if .vol file does not exists"""
os.remove(self._vol_file)
- utils.rawfile_to_h5_external_dataset(bin_file=self._vol_file,
- output_url=self._data_url,
- shape=(100, 20, 5),
- dtype=numpy.float32)
-
- self.assertTrue(self.check_dataset(h5_file=self.h5_file,
- data_path=self.external_dataset_path,
- shape=self._dataset_shape))
+ utils.rawfile_to_h5_external_dataset(
+ bin_file=self._vol_file,
+ output_url=self._data_url,
+ shape=(100, 20, 5),
+ dtype=numpy.float32,
+ )
+
+ self.assertTrue(
+ self.check_dataset(
+ h5_file=self.h5_file,
+ data_path=self.external_dataset_path,
+ shape=self._dataset_shape,
+ )
+ )
def test_conflicts(self):
"""Test several conflict cases"""
# test if path already exists
- utils.rawfile_to_h5_external_dataset(bin_file=self._vol_file,
- output_url=self._data_url,
- shape=(100, 20, 5),
- dtype=numpy.float32)
+ utils.rawfile_to_h5_external_dataset(
+ bin_file=self._vol_file,
+ output_url=self._data_url,
+ shape=(100, 20, 5),
+ dtype=numpy.float32,
+ )
with self.assertRaises(ValueError):
- utils.rawfile_to_h5_external_dataset(bin_file=self._vol_file,
- output_url=self._data_url,
- shape=(100, 20, 5),
- overwrite=False,
- dtype=numpy.float32)
-
- utils.rawfile_to_h5_external_dataset(bin_file=self._vol_file,
- output_url=self._data_url,
- shape=(100, 20, 5),
- overwrite=True,
- dtype=numpy.float32)
-
- self.assertTrue(self.check_dataset(h5_file=self.h5_file,
- data_path=self.external_dataset_path,
- shape=self._dataset_shape))
+ utils.rawfile_to_h5_external_dataset(
+ bin_file=self._vol_file,
+ output_url=self._data_url,
+ shape=(100, 20, 5),
+ overwrite=False,
+ dtype=numpy.float32,
+ )
+
+ utils.rawfile_to_h5_external_dataset(
+ bin_file=self._vol_file,
+ output_url=self._data_url,
+ shape=(100, 20, 5),
+ overwrite=True,
+ dtype=numpy.float32,
+ )
+
+ self.assertTrue(
+ self.check_dataset(
+ h5_file=self.h5_file,
+ data_path=self.external_dataset_path,
+ shape=self._dataset_shape,
+ )
+ )
class TestH5Strings(unittest.TestCase):
@@ -725,86 +822,153 @@ class TestH5Strings(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.tempdir = tempfile.mkdtemp()
- cls.vlenstr = h5py.special_dtype(vlen=str)
- cls.vlenbytes = h5py.special_dtype(vlen=bytes)
- try:
- cls.unicode = unicode
- except NameError:
- cls.unicode = str
@classmethod
def tearDownClass(cls):
shutil.rmtree(cls.tempdir)
def setUp(self):
- self.file = h5py.File(os.path.join(self.tempdir, 'file.h5'), mode="w")
+ self.file = h5py.File(os.path.join(self.tempdir, "file.h5"), mode="w")
def tearDown(self):
self.file.close()
@classmethod
- def _make_array(cls, value, n):
+ def _make_array(cls, value, n, vlen=True):
if isinstance(value, bytes):
- dtype = cls.vlenbytes
- elif isinstance(value, cls.unicode):
- dtype = cls.vlenstr
+ if vlen:
+ dtype = h5py.special_dtype(vlen=bytes)
+ else:
+ if hasattr(h5py, "string_dtype"):
+ dtype = h5py.string_dtype("ascii", len(value))
+ else:
+ dtype = f"|S{len(value)}"
+ elif isinstance(value, str):
+ if vlen:
+ dtype = h5py.special_dtype(vlen=str)
+ else:
+ value = value.encode("utf-8")
+ if hasattr(h5py, "string_dtype"):
+ dtype = h5py.string_dtype("utf-8", len(value))
+ else:
+ dtype = f"|S{len(value)}"
else:
- return numpy.array([value] * n)
+ dtype = None
return numpy.array([value] * n, dtype=dtype)
@classmethod
def _get_charset(cls, value):
if isinstance(value, bytes):
return h5py.h5t.CSET_ASCII
- elif isinstance(value, cls.unicode):
+ elif isinstance(value, str):
return h5py.h5t.CSET_UTF8
else:
return None
def _check_dataset(self, value, result=None):
- # Write+read scalar
- if result:
+ if result is not None:
decode_ascii = True
else:
decode_ascii = False
result = value
+
+ # Write+read scalar
charset = self._get_charset(value)
self.file["data"] = value
data = utils.h5py_read_dataset(self.file["data"], decode_ascii=decode_ascii)
- assert type(data) == type(result), data
+ assert isinstance(data, type(result)), data
assert data == result, data
- if charset:
+ if charset is not None:
assert self.file["data"].id.get_type().get_cset() == charset
# Write+read variable length
+ no_unicode_support = isinstance(value, str) and not hasattr(
+ h5py, "string_dtype"
+ )
+ if no_unicode_support:
+ decode_ascii = True
self.file["vlen_data"] = self._make_array(value, 2)
- data = utils.h5py_read_dataset(self.file["vlen_data"], decode_ascii=decode_ascii, index=0)
- assert type(data) == type(result), data
+ data = utils.h5py_read_dataset(
+ self.file["vlen_data"], decode_ascii=decode_ascii, index=0
+ )
+ assert isinstance(data, type(result)), data
assert data == result, data
- data = utils.h5py_read_dataset(self.file["vlen_data"], decode_ascii=decode_ascii)
+ data = utils.h5py_read_dataset(
+ self.file["vlen_data"], decode_ascii=decode_ascii
+ )
numpy.testing.assert_array_equal(data, [result] * 2)
- if charset:
+ if charset is not None:
assert self.file["vlen_data"].id.get_type().get_cset() == charset
+ self.file["vlen_empty_array"] = self._make_array(value, 0)
+ data = utils.h5py_read_dataset(
+ self.file["vlen_empty_array"], decode_ascii=decode_ascii
+ )
+ assert data.shape == (0,)
+
+ # Write+read fixed length
+ self.file["flen_data"] = self._make_array(value, 2, vlen=False)
+ data = utils.h5py_read_dataset(
+ self.file["flen_data"], decode_ascii=decode_ascii, index=0
+ )
+ assert isinstance(data, type(result)), data
+ assert data == result, data
+ data = utils.h5py_read_dataset(
+ self.file["flen_data"], decode_ascii=decode_ascii
+ )
+ numpy.testing.assert_array_equal(data, [result] * 2)
+ if charset is not None and not no_unicode_support:
+ assert self.file["flen_data"].id.get_type().get_cset() == charset
+
def _check_attribute(self, value, result=None):
- if result:
+ if result is not None:
decode_ascii = True
else:
decode_ascii = False
result = value
+
+ # Write+read scalar
self.file.attrs["data"] = value
- data = utils.h5py_read_attribute(self.file.attrs, "data", decode_ascii=decode_ascii)
- assert type(data) == type(result), data
+ data = utils.h5py_read_attribute(
+ self.file.attrs, "data", decode_ascii=decode_ascii
+ )
+ assert isinstance(data, type(result)), data
assert data == result, data
+ # Write+read variable length
+ no_unicode_support = isinstance(value, str) and not hasattr(
+ h5py, "string_dtype"
+ )
+ if no_unicode_support:
+ decode_ascii = True
self.file.attrs["vlen_data"] = self._make_array(value, 2)
- data = utils.h5py_read_attribute(self.file.attrs, "vlen_data", decode_ascii=decode_ascii)
- assert type(data[0]) == type(result), data[0]
+ data = utils.h5py_read_attribute(
+ self.file.attrs, "vlen_data", decode_ascii=decode_ascii
+ )
+ assert isinstance(data[0], type(result)), data[0]
+ assert data[0] == result, data[0]
+ numpy.testing.assert_array_equal(data, [result] * 2)
+
+ data = utils.h5py_read_attributes(self.file.attrs, decode_ascii=decode_ascii)[
+ "vlen_data"
+ ]
+ assert isinstance(data[0], type(result)), data[0]
assert data[0] == result, data[0]
numpy.testing.assert_array_equal(data, [result] * 2)
- data = utils.h5py_read_attributes(self.file.attrs, decode_ascii=decode_ascii)["vlen_data"]
- assert type(data[0]) == type(result), data[0]
+ # Write+read fixed length
+ self.file.attrs["flen_data"] = self._make_array(value, 2, vlen=False)
+ data = utils.h5py_read_attribute(
+ self.file.attrs, "flen_data", decode_ascii=decode_ascii
+ )
+ assert isinstance(data[0], type(result)), data[0]
+ assert data[0] == result, data[0]
+ numpy.testing.assert_array_equal(data, [result] * 2)
+
+ data = utils.h5py_read_attributes(self.file.attrs, decode_ascii=decode_ascii)[
+ "flen_data"
+ ]
+ assert isinstance(data[0], type(result)), data[0]
assert data[0] == result, data[0]
numpy.testing.assert_array_equal(data, [result] * 2)
@@ -881,7 +1045,9 @@ def test_visitall_hdf5(tmp_path):
with h5py.File(filepath, mode="w") as h5file:
h5file["group/dataset"] = 50
h5file["link/soft_link"] = h5py.SoftLink("/group/dataset")
- h5file["link/external_link"] = h5py.ExternalLink("external.h5", "/target/dataset")
+ h5file["link/external_link"] = h5py.ExternalLink(
+ "external.h5", "/target/dataset"
+ )
with h5py.File(filepath, mode="r") as h5file:
visited_items = {}
@@ -906,12 +1072,13 @@ def test_visitall_hdf5(tmp_path):
"/link/external_link": (h5py.ExternalLink, ("external.h5", "/target/dataset")),
}
+
def test_visitall_commonh5():
"""Visit commonh5 File object"""
fobj = commonh5.File("filename.file", mode="w")
group = fobj.create_group("group")
dataset = group.create_dataset("dataset", data=numpy.array(50))
- group["soft_link"] = dataset # Create softlink
+ group["soft_link"] = dataset # Create softlink
visited_items = dict(utils.visitall(fobj))
assert len(visited_items) == 3
@@ -934,7 +1101,12 @@ def test_match_hdf5(tmp_path):
result = list(utils.match(h5f, "/entry_*/*"))
- assert sorted(result) == ['entry_0000/data', 'entry_0000/group', 'entry_0001/data', 'entry_0001/group']
+ assert sorted(result) == [
+ "entry_0000/data",
+ "entry_0000/group",
+ "entry_0001/data",
+ "entry_0001/group",
+ ]
def test_match_commonh5():
@@ -949,4 +1121,21 @@ def test_match_commonh5():
result = list(utils.match(fobj, "/entry_*/*"))
- assert sorted(result) == ['entry_0000/data', 'entry_0000/group', 'entry_0001/data', 'entry_0001/group']
+ assert sorted(result) == [
+ "entry_0000/data",
+ "entry_0000/group",
+ "entry_0001/data",
+ "entry_0001/group",
+ ]
+
+
+def test_recursive_match_commonh5():
+ """Test match function with commonh5 objects"""
+ with commonh5.File("filename.file", mode="w") as fobj:
+ fobj["entry_0000/bar/data"] = 0
+ fobj["entry_0001/foo/data"] = 1
+ fobj["entry_0001/foo/data1"] = 2
+ fobj["entry_0003"] = 3
+
+ result = list(utils.match(fobj, "**/data"))
+ assert result == ["entry_0000/bar/data", "entry_0001/foo/data"]
diff --git a/src/silx/io/test/test_write_to_h5.py b/src/silx/io/test/test_write_to_h5.py
index fe855e1..b74bf0f 100644
--- a/src/silx/io/test/test_write_to_h5.py
+++ b/src/silx/io/test/test_write_to_h5.py
@@ -30,7 +30,6 @@ from silx.io import spech5
from silx.io.convert import write_to_h5
from silx.io.dictdump import h5todict
from silx.io import commonh5
-from silx.io.spech5 import SpecH5
def test_with_commonh5(tmp_path):
@@ -38,13 +37,13 @@ def test_with_commonh5(tmp_path):
fobj = commonh5.File("filename.txt", mode="w")
group = fobj.create_group("group")
dataset = group.create_dataset("dataset", data=numpy.array(50))
- group["soft_link"] = dataset # Create softlink
+ group["soft_link"] = dataset # Create softlink
output_filepath = tmp_path / "output.h5"
write_to_h5(fobj, str(output_filepath))
assert h5todict(str(output_filepath)) == {
- 'group': {'dataset': numpy.array(50), 'soft_link': numpy.array(50)},
+ "group": {"dataset": numpy.array(50), "soft_link": numpy.array(50)},
}
with h5py.File(output_filepath, mode="r") as h5file:
soft_link = h5file.get("/group/soft_link", getlink=True)
@@ -63,7 +62,7 @@ def test_with_hdf5(tmp_path):
output_filepath = tmp_path / "output.h5"
write_to_h5(str(filepath), str(output_filepath))
assert h5todict(str(output_filepath)) == {
- 'group': {'dataset': 50, 'soft_link': 50},
+ "group": {"dataset": 50, "soft_link": 50},
}
with h5py.File(output_filepath, mode="r") as h5file:
soft_link = h5file.get("group/soft_link", getlink=True)
@@ -76,13 +75,14 @@ def test_with_spech5(tmp_path):
filepath = tmp_path / "file.spec"
filepath.write_bytes(
bytes(
-"""#F /tmp/sf.dat
+ """#F /tmp/sf.dat
#S 1 cmd
#L a b
1 2
""",
- encoding='ascii')
+ encoding="ascii",
+ )
)
output_filepath = tmp_path / "output.h5"
@@ -98,20 +98,23 @@ def test_with_spech5(tmp_path):
else:
numpy.array_equal(item1, item2)
- assert_equal(h5todict(str(output_filepath)), {
- '1.1': {
- 'instrument': {
- 'positioners': {},
- 'specfile': {
- 'file_header': ['#F /tmp/sf.dat'],
- 'scan_header': ['#S 1 cmd', '#L a b'],
+ assert_equal(
+ h5todict(str(output_filepath)),
+ {
+ "1.1": {
+ "instrument": {
+ "positioners": {},
+ "specfile": {
+ "file_header": ["#F /tmp/sf.dat"],
+ "scan_header": ["#S 1 cmd", "#L a b"],
+ },
},
+ "measurement": {
+ "a": [1.0],
+ "b": [2.0],
+ },
+ "start_time": "",
+ "title": "cmd",
},
- 'measurement': {
- 'a': [1.],
- 'b': [2.],
- },
- 'start_time': '',
- 'title': 'cmd',
},
- })
+ )
diff --git a/src/silx/io/url.py b/src/silx/io/url.py
index 71b3103..a3e04e4 100644
--- a/src/silx/io/url.py
+++ b/src/silx/io/url.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -23,17 +23,52 @@
# ###########################################################################*/
"""URL module"""
+from __future__ import annotations
+
__authors__ = ["V. Valls"]
__license__ = "MIT"
__date__ = "29/01/2018"
import logging
from collections.abc import Iterable
+from typing import Union
import urllib.parse
+from pathlib import Path
_logger = logging.getLogger(__name__)
+SliceLike = Union[slice, int, type(Ellipsis)]
+
+
+def _slice_to_string(s: SliceLike) -> str:
+ """Convert a Python slice into a string"""
+ if s == Ellipsis:
+ return "..."
+ elif isinstance(s, slice):
+ result = ""
+ if s.start is None:
+ result += ":"
+ else:
+ result += f"{s.start}:"
+ if s.stop is not None:
+ result += f"{s.stop}"
+ if s.step is not None:
+ result += f":{s.step}"
+ return result
+ elif isinstance(s, int):
+ return str(s)
+ else:
+ raise TypeError("Unexpected slicing type. Found %s" % type(s))
+
+
+def slice_sequence_to_string(data_slice: Iterable[SliceLike] | SliceLike) -> str:
+ """Convert a Python slice sequence or a slice into a string"""
+ if isinstance(data_slice, Iterable):
+ return ",".join([_slice_to_string(s) for s in data_slice])
+ else:
+ return _slice_to_string(data_slice)
+
class DataUrl(object):
"""Non-mutable object to parse a string representing a resource data
@@ -55,7 +90,7 @@ class DataUrl(object):
>>> DataUrl("silx:///data/image.edf?path=/scan_0/detector/data")
>>> DataUrl("silx:///C:/data/image.edf?path=/scan_0/detector/data")
- >>> # `path=` can be omited if there is no other query keys
+ >>> # `path=` can be omitted if there are no other query keys
>>> DataUrl("silx:///data/image.h5?/data/dataset")
>>> # is the same as
>>> DataUrl("silx:///data/image.h5?path=/data/dataset")
@@ -72,31 +107,41 @@ class DataUrl(object):
>>> DataUrl("silx:image.h5")
>>> DataUrl("fabio:image.edf")
- >>> # Is also support parsing of file access for convenience
+ >>> # It also supports parsing of file access for convenience
>>> DataUrl("./foo/bar/image.edf")
>>> DataUrl("C:/data/")
- :param str path: Path representing a link to a data. If specified, other
- arguments are not used.
- :param str file_path: Link to the file containing the the data.
+ :param path: Path representing a link to a data. If specified, other
+ arguments must not be provided.
+ :param file_path: Link to the file containing the the data.
None if there is no data selection.
- :param str data_path: Data selection applyed to the data file selected.
+ :param data_path: Data selection applied to the data file selected.
None if there is no data selection.
- :param Tuple[int,slice,Ellipse] data_slice: Slicing applyed of the selected
- data. None if no slicing applyed.
- :param Union[str,None] scheme: Scheme of the URL. "silx", "fabio"
+ :param data_slice: Slicing applied of the selected
+ data. None if no slicing applied.
+ :param scheme: Scheme of the URL. "silx", "fabio"
is supported. Other strings can be provided, but :meth:`is_valid` will
be false.
"""
- def __init__(self, path=None, file_path=None, data_path=None, data_slice=None, scheme=None):
+
+ def __init__(
+ self,
+ path: str | Path | None = None,
+ file_path: str | Path | None = None,
+ data_path: str | None = None,
+ data_slice: tuple[SliceLike, ...] | None = None,
+ scheme: str | None = None,
+ ):
self.__is_valid = False
if path is not None:
- assert(file_path is None)
- assert(data_path is None)
- assert(data_slice is None)
- assert(scheme is None)
- self.__parse_from_path(path)
+ assert file_path is None
+ assert data_path is None
+ assert data_slice is None
+ assert scheme is None
+ self.__parse_from_path(str(path))
else:
+ if file_path is not None:
+ file_path = str(file_path)
self.__file_path = file_path
self.__data_path = data_path
self.__data_slice = data_slice
@@ -130,6 +175,7 @@ class DataUrl(object):
def __str__(self):
if self.is_valid() or self.__path is None:
+
def quote_string(string):
if isinstance(string, str):
return "'%s'" % string
@@ -137,11 +183,13 @@ class DataUrl(object):
return string
template = "DataUrl(valid=%s, scheme=%s, file_path=%s, data_path=%s, data_slice=%s)"
- return template % (self.__is_valid,
- quote_string(self.__scheme),
- quote_string(self.__file_path),
- quote_string(self.__data_path),
- self.__data_slice)
+ return template % (
+ self.__is_valid,
+ quote_string(self.__scheme),
+ quote_string(self.__file_path),
+ quote_string(self.__data_path),
+ self.__data_slice,
+ )
else:
template = "DataUrl(valid=%s, string=%s)"
return template % (self.__is_valid, self.__path)
@@ -159,32 +207,36 @@ class DataUrl(object):
elif self.__scheme == "silx":
# If there is a slice you must have a data path
# But you can have a data path without slice
- slice_implies_data = (self.__data_path is None and self.__data_slice is None) or self.__data_path is not None
+ slice_implies_data = (
+ self.__data_path is None and self.__data_slice is None
+ ) or self.__data_path is not None
self.__is_valid = slice_implies_data
else:
self.__is_valid = False
@staticmethod
- def _parse_slice(slice_string):
+ def _parse_slice(slice_string: str) -> tuple[SliceLike, ...]:
"""Parse a slicing sequence and return an associated tuple.
It supports a sequence of `...`, `:`, and integers separated by a coma.
-
- :rtype: tuple
"""
- def str_to_slice(string):
+
+ def string_to_slice(string: str) -> SliceLike:
+ """Convert a string to a Python slice"""
if string == "...":
return Ellipsis
- elif ':' in string:
+ elif ":" in string:
if string == ":":
return slice(None)
else:
+
def get_value(my_str):
- if my_str in ('', None):
+ if my_str in ("", None):
return None
else:
return int(my_str)
- sss = string.split(':')
+
+ sss = string.split(":")
start = get_value(sss[0])
stop = get_value(sss[1] if len(sss) > 1 else None)
step = get_value(sss[2] if len(sss) > 2 else None)
@@ -196,23 +248,23 @@ class DataUrl(object):
raise ValueError("An empty slice is not valid")
tokens = slice_string.split(",")
- data_slice = []
+ data_slice: list[SliceLike] = []
for t in tokens:
try:
- data_slice.append(str_to_slice(t))
+ data_slice.append(string_to_slice(t))
except ValueError:
raise ValueError("'%s' is not a valid slicing" % t)
return tuple(data_slice)
- def __parse_from_path(self, path):
+ def __parse_from_path(self, path: str):
"""Parse the path and initialize attributes.
- :param str path: Path representing the URL.
+ :param path: Path representing the URL.
"""
self.__path = path
# only replace if ? not here already. Otherwise can mess sith
# data_slice if == ::2 for example
- if '?' not in path:
+ if "?" not in path:
path = path.replace("::", "?", 1)
url = urllib.parse.urlparse(path)
@@ -228,7 +280,7 @@ class DataUrl(object):
file_path = url.path
# Check absolute windows path
- if len(file_path) > 2 and file_path[0] == '/':
+ if len(file_path) > 2 and file_path[0] == "/":
if file_path[1] == ":" or file_path[2] == ":":
file_path = file_path[1:]
@@ -252,7 +304,10 @@ class DataUrl(object):
if name in merged_query:
values = merged_query.pop(name)
if len(values) > 1:
- _logger.warning("More than one query key named '%s'. The last one is used.", name)
+ _logger.warning(
+ "More than one query key named '%s'. The last one is used.",
+ name,
+ )
value = values[-1]
else:
value = None
@@ -278,31 +333,15 @@ class DataUrl(object):
else:
self.__is_valid = False
- def is_valid(self):
- """Returns true if the URL is valid. Else attributes can be None.
-
- :rtype: bool
- """
+ def is_valid(self) -> bool:
+ """Returns true if the URL is valid. Else attributes can be None."""
return self.__is_valid
- def path(self):
- """Returns the string representing the URL.
-
- :rtype: str
- """
+ def path(self) -> str:
+ """Returns the string representing the URL."""
if self.__path is not None:
return self.__path
- def slice_to_string(data_slice):
- if data_slice == Ellipsis:
- return "..."
- elif data_slice == slice(None):
- return ":"
- elif isinstance(data_slice, int):
- return str(data_slice)
- else:
- raise TypeError("Unexpected slicing type. Found %s" % type(data_slice))
-
if self.__data_path is not None and self.__data_slice is None:
query = self.__data_path
else:
@@ -310,10 +349,7 @@ class DataUrl(object):
if self.__data_path is not None:
queries.append("path=" + self.__data_path)
if self.__data_slice is not None:
- if isinstance(self.__data_slice, Iterable):
- data_slice = ",".join([slice_to_string(s) for s in self.__data_slice])
- else:
- data_slice = slice_to_string(self.__data_slice)
+ data_slice = slice_sequence_to_string(self.__data_slice)
queries.append("slice=" + data_slice)
query = "&".join(queries)
@@ -335,11 +371,8 @@ class DataUrl(object):
return path
- def is_absolute(self):
- """Returns true if the file path is an absolute path.
-
- :rtype: bool
- """
+ def is_absolute(self) -> bool:
+ """Returns true if the file path is an absolute path."""
file_path = self.file_path()
if file_path is None:
return False
@@ -356,32 +389,21 @@ class DataUrl(object):
return True
return False
- def file_path(self):
- """Returns the path to the file containing the data.
-
- :rtype: str
- """
+ def file_path(self) -> str:
+ """Returns the path to the file containing the data."""
return self.__file_path
- def data_path(self):
- """Returns the path inside the file to the data.
-
- :rtype: str
- """
+ def data_path(self) -> str | None:
+ """Returns the path inside the file to the data."""
return self.__data_path
- def data_slice(self):
+ def data_slice(self) -> tuple[SliceLike, ...] | None:
"""Returns the slicing applied to the data.
It is a tuple containing numbers, slice or ellipses.
-
- :rtype: Tuple[int, slice, Ellipse]
"""
return self.__data_slice
- def scheme(self):
- """Returns the scheme. It can be None if no scheme is specified.
-
- :rtype: Union[str, None]
- """
+ def scheme(self) -> str | None:
+ """Returns the scheme. It can be None if no scheme is specified."""
return self.__scheme
diff --git a/src/silx/io/utils.py b/src/silx/io/utils.py
index 0588138..ae6a55b 100644
--- a/src/silx/io/utils.py
+++ b/src/silx/io/utils.py
@@ -1,5 +1,5 @@
# /*##########################################################################
-# Copyright (C) 2016-2022 European Synchrotron Radiation Facility
+# Copyright (C) 2016-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -32,14 +32,14 @@ import os.path
import sys
import time
import logging
-import collections
-from typing import Generator
+from typing import Generator, Union, Optional
import urllib.parse
import numpy
from silx.utils.proxy import Proxy
-import silx.io.url
+from .url import DataUrl
+from . import h5py_utils
from .._version import calc_hexversion
import h5py
@@ -59,6 +59,7 @@ NEXUS_HDF5_EXT = [".h5", ".nx5", ".nxs", ".hdf", ".hdf5", ".cxi"]
class H5Type(enum.Enum):
"""Identify a set of HDF5 concepts"""
+
DATASET = 1
GROUP = 2
FILE = 3
@@ -70,7 +71,6 @@ class H5Type(enum.Enum):
_CLASSES_TYPE = None
"""Store mapping between classes and types"""
-string_types = (basestring,) if sys.version_info[0] == 2 else (str,) # noqa
builtin_open = open
@@ -86,7 +86,7 @@ def supported_extensions(flat_formats=True):
extensions (an extension is a string like "\\*.ext").
:rtype: Dict[str, Set[str]]
"""
- formats = collections.OrderedDict()
+ formats = {}
formats["HDF5 files"] = set(["*.h5", "*.hdf", "*.hdf5"])
formats["NeXus files"] = set(["*.nx", "*.nxs", "*.h5", "*.hdf", "*.hdf5"])
formats["NeXus layout from spec files"] = set(["*.dat", "*.spec", "*.mca"])
@@ -96,7 +96,9 @@ def supported_extensions(flat_formats=True):
except ImportError:
fabioh5 = None
if fabioh5 is not None:
- formats["NeXus layout from fabio files"] = set(fabioh5.supported_extensions())
+ formats["NeXus layout from fabio files"] = set(
+ fabioh5.supported_extensions()
+ )
extensions = ["*.npz"]
if flat_formats:
@@ -108,9 +110,21 @@ def supported_extensions(flat_formats=True):
return formats
-def save1D(fname, x, y, xlabel=None, ylabels=None, filetype=None,
- fmt="%.7g", csvdelim=";", newline="\n", header="",
- footer="", comments="#", autoheader=False):
+def save1D(
+ fname,
+ x,
+ y,
+ xlabel=None,
+ ylabels=None,
+ filetype=None,
+ fmt="%.7g",
+ csvdelim=";",
+ newline="\n",
+ header="",
+ footer="",
+ comments="#",
+ autoheader=False,
+):
"""Saves any number of curves to various formats: `Specfile`, `CSV`,
`txt` or `npy`. All curves must have the same number of points and share
the same ``x`` values.
@@ -169,19 +183,17 @@ def save1D(fname, x, y, xlabel=None, ylabels=None, filetype=None,
available_formats = ["spec", "csv", "txt", "ndarray"]
if filetype is None:
- exttypes = {".dat": "spec",
- ".csv": "csv",
- ".txt": "txt",
- ".npy": "ndarray"}
- outfname = (fname if not hasattr(fname, "name") else
- fname.name)
+ exttypes = {".dat": "spec", ".csv": "csv", ".txt": "txt", ".npy": "ndarray"}
+ outfname = fname if not hasattr(fname, "name") else fname.name
fileext = os.path.splitext(outfname)[1]
if fileext in exttypes:
filetype = exttypes[fileext]
else:
- raise IOError("File type unspecified and could not be " +
- "inferred from file extension (not in " +
- "txt, dat, csv, npy)")
+ raise IOError(
+ "File type unspecified and could not be "
+ + "inferred from file extension (not in "
+ + "txt, dat, csv, npy)"
+ )
else:
filetype = filetype.lower()
@@ -199,8 +211,9 @@ def save1D(fname, x, y, xlabel=None, ylabels=None, filetype=None,
elif isinstance(ylabels, (list, tuple)):
# if ylabels is provided as a list, every element must
# be a string
- ylabels = [ylabel if isinstance(ylabel, string_types) else "y%d" % i
- for ylabel in ylabels]
+ ylabels = [
+ ylabel if isinstance(ylabel, str) else "y%d" % i for ylabel in ylabels
+ ]
if filetype.lower() == "spec":
# Check if we have regular data:
@@ -211,9 +224,18 @@ def save1D(fname, x, y, xlabel=None, ylabels=None, filetype=None,
if regular:
if isinstance(fmt, (list, tuple)) and len(fmt) < (len(ylabels) + 1):
fmt = fmt + [fmt[-1] * (1 + len(ylabels) - len(fmt))]
- specf = savespec(fname, x, y, xlabel, ylabels, fmt=fmt,
- scan_number=1, mode="w", write_file_header=True,
- close_file=False)
+ specf = savespec(
+ fname,
+ x,
+ y,
+ xlabel,
+ ylabels,
+ fmt=fmt,
+ scan_number=1,
+ mode="w",
+ write_file_header=True,
+ close_file=False,
+ )
else:
y_array = numpy.asarray(y)
# make sure y_array is a 2D array even for a single curve
@@ -223,14 +245,32 @@ def save1D(fname, x, y, xlabel=None, ylabels=None, filetype=None,
raise IndexError("y must be a 1D or 2D array")
# First curve
- specf = savespec(fname, x, y_array[0], xlabel, ylabels[0], fmt=fmt,
- scan_number=1, mode="w", write_file_header=True,
- close_file=False)
+ specf = savespec(
+ fname,
+ x,
+ y_array[0],
+ xlabel,
+ ylabels[0],
+ fmt=fmt,
+ scan_number=1,
+ mode="w",
+ write_file_header=True,
+ close_file=False,
+ )
# Other curves
for i in range(1, y_array.shape[0]):
- specf = savespec(specf, x, y_array[i], xlabel, ylabels[i],
- fmt=fmt, scan_number=i + 1, mode="w",
- write_file_header=False, close_file=False)
+ specf = savespec(
+ specf,
+ x,
+ y_array[i],
+ xlabel,
+ ylabels[i],
+ fmt=fmt,
+ scan_number=i + 1,
+ mode="w",
+ write_file_header=False,
+ close_file=False,
+ )
# close file if we created it
if not hasattr(fname, "write"):
@@ -261,9 +301,16 @@ def save1D(fname, x, y, xlabel=None, ylabels=None, filetype=None,
if filetype.lower() in ["csv", "txt"]:
X = X.transpose()
- savetxt(fname, X, fmt=fmt, delimiter=csvdelim,
- newline=newline, header=header, footer=footer,
- comments=comments)
+ savetxt(
+ fname,
+ X,
+ fmt=fmt,
+ delimiter=csvdelim,
+ newline=newline,
+ header=header,
+ footer=footer,
+ comments=comments,
+ )
elif filetype.lower() == "ndarray":
if xlabel is not None and ylabels is not None:
@@ -271,14 +318,21 @@ def save1D(fname, x, y, xlabel=None, ylabels=None, filetype=None,
# .transpose is needed here because recarray labels
# apply to columns
- X = numpy.core.records.fromrecords(X.transpose(),
- names=labels)
+ X = numpy.core.records.fromrecords(X.transpose(), names=labels)
numpy.save(fname, X)
# Replace with numpy.savetxt when dropping support of numpy < 1.7.0
-def savetxt(fname, X, fmt="%.7g", delimiter=";", newline="\n",
- header="", footer="", comments="#"):
+def savetxt(
+ fname,
+ X,
+ fmt="%.7g",
+ delimiter=";",
+ newline="\n",
+ header="",
+ footer="",
+ comments="#",
+):
"""``numpy.savetxt`` backport of header and footer arguments from
numpy=1.7.0.
@@ -286,31 +340,35 @@ def savetxt(fname, X, fmt="%.7g", delimiter=";", newline="\n",
http://docs.scipy.org/doc/numpy-1.10.0/reference/generated/numpy.savetxt.html
"""
if not hasattr(fname, "name"):
- ffile = builtin_open(fname, 'wb')
+ ffile = builtin_open(fname, "wb")
else:
ffile = fname
if header:
- if sys.version_info[0] >= 3:
- header = header.encode("utf-8")
- ffile.write(header)
+ ffile.write(header.encode("utf-8"))
numpy.savetxt(ffile, X, fmt, delimiter, newline)
if footer:
- footer = (comments + footer.replace(newline, newline + comments) +
- newline)
- if sys.version_info[0] >= 3:
- footer = footer.encode("utf-8")
- ffile.write(footer)
+ footer = comments + footer.replace(newline, newline + comments) + newline
+ ffile.write(footer.encode("utf-8"))
if not hasattr(fname, "name"):
ffile.close()
-def savespec(specfile, x, y, xlabel="X", ylabel="Y", fmt="%.7g",
- scan_number=1, mode="w", write_file_header=True,
- close_file=False):
+def savespec(
+ specfile,
+ x,
+ y,
+ xlabel="X",
+ ylabel="Y",
+ fmt="%.7g",
+ scan_number=1,
+ mode="w",
+ write_file_header=True,
+ close_file=False,
+):
"""Saves one curve to a SpecFile.
The curve is saved as a scan with two data columns. To save multiple
@@ -324,7 +382,7 @@ def savespec(specfile, x, y, xlabel="X", ylabel="Y", fmt="%.7g",
:param y: 1D-array (or list), or list of them of ordinates values.
All dataset must have the same length as x
:param xlabel: Abscissa label (default ``"X"``)
- :param ylabel: Ordinate label, may be a list of labels when multiple curves
+ :param ylabel: Ordinate label, may be a list of labels when multiple curves
are to be saved together.
:param fmt: Format string for data. You can specify a short format
string that defines a single format for both ``x`` and ``y`` values,
@@ -366,13 +424,15 @@ def savespec(specfile, x, y, xlabel="X", ylabel="Y", fmt="%.7g",
assert len(labels) == ncol
print(xlabel, ylabel, fmt, ncol, x_array, y_array)
- if isinstance(fmt, string_types) and fmt.count("%") == 1:
+ if isinstance(fmt, str) and fmt.count("%") == 1:
full_fmt_string = " ".join([fmt] * ncol)
elif isinstance(fmt, (list, tuple)) and len(fmt) == ncol:
full_fmt_string = " ".join(fmt)
else:
- raise ValueError("`fmt` must be a single format string or a list of " +
- "format strings with as many format as ncolumns")
+ raise ValueError(
+ "`fmt` must be a single format string or a list of "
+ + "format strings with as many format as ncolumns"
+ )
if not hasattr(specfile, "write"):
f = builtin_open(specfile, mode)
@@ -381,14 +441,16 @@ def savespec(specfile, x, y, xlabel="X", ylabel="Y", fmt="%.7g",
current_date = "#D %s" % (time.ctime(time.time()))
if write_file_header:
- lines = [ "#F %s" % f.name, current_date, ""]
+ lines = ["#F %s" % f.name, current_date, ""]
else:
lines = [""]
- lines += [ "#S %d %s" % (scan_number, labels[1]),
- current_date,
- "#N %d" % ncol,
- "#L " + " ".join(labels)]
+ lines += [
+ "#S %d %s" % (scan_number, labels[1]),
+ current_date,
+ "#N %d" % ncol,
+ "#L " + " ".join(labels),
+ ]
for i in data.T:
lines.append(full_fmt_string % tuple(i))
@@ -429,27 +491,27 @@ def h5ls(h5group, lvl=0):
.. note:: This function requires `h5py <http://www.h5py.org/>`_ to be
installed.
"""
- h5repr = ''
+ h5repr = ""
if is_group(h5group):
h5f = h5group
- elif isinstance(h5group, string_types):
+ elif isinstance(h5group, str):
h5f = open(h5group) # silx.io.open
else:
raise TypeError("h5group must be a hdf5-like group object or a file name.")
for key in h5f.keys():
# group
- if hasattr(h5f[key], 'keys'):
- h5repr += '\t' * lvl + '+' + key
- h5repr += '\n'
+ if hasattr(h5f[key], "keys"):
+ h5repr += "\t" * lvl + "+" + key
+ h5repr += "\n"
h5repr += h5ls(h5f[key], lvl + 1)
# dataset
else:
- h5repr += '\t' * lvl
+ h5repr += "\t" * lvl
h5repr += str(h5f[key])
- h5repr += '\n'
+ h5repr += "\n"
- if isinstance(h5group, string_types):
+ if isinstance(h5group, str):
h5f.close()
return h5repr
@@ -482,42 +544,49 @@ def _open_local_file(filename):
if extension in [".npz", ".npy"]:
try:
from . import rawh5
+
return rawh5.NumpyFile(filename)
except (IOError, ValueError) as e:
- debugging_info.append((sys.exc_info(),
- "File '%s' can't be read as a numpy file." % filename))
+ debugging_info.append(
+ (
+ sys.exc_info(),
+ "File '%s' can't be read as a numpy file." % filename,
+ )
+ )
if h5py.is_hdf5(filename):
- try:
- return h5py.File(filename, "r")
- except OSError:
- return h5py.File(filename, "r", libver='latest', swmr=True)
+ return h5py_utils.File(filename, "r")
try:
from . import fabioh5
+
return fabioh5.File(filename)
except ImportError:
debugging_info.append((sys.exc_info(), "fabioh5 can't be loaded."))
except Exception:
- debugging_info.append((sys.exc_info(),
- "File '%s' can't be read as fabio file." % filename))
+ debugging_info.append(
+ (sys.exc_info(), "File '%s' can't be read as fabio file." % filename)
+ )
try:
from . import spech5
+
return spech5.SpecH5(filename)
except ImportError:
- debugging_info.append((sys.exc_info(),
- "spech5 can't be loaded."))
+ debugging_info.append((sys.exc_info(), "spech5 can't be loaded."))
except IOError:
- debugging_info.append((sys.exc_info(),
- "File '%s' can't be read as spec file." % filename))
+ debugging_info.append(
+ (sys.exc_info(), "File '%s' can't be read as spec file." % filename)
+ )
try:
from . import fioh5
+
return fioh5.FioH5(filename)
except IOError:
- debugging_info.append((sys.exc_info(),
- "File '%s' can't be read as fio file." % filename))
+ debugging_info.append(
+ (sys.exc_info(), "File '%s' can't be read as fio file." % filename)
+ )
finally:
for exc_info, message in debugging_info:
@@ -594,7 +663,7 @@ def open(filename): # pylint:disable=redefined-builtin
:raises: IOError if the file can't be loaded or path can't be found
:rtype: h5py-like node
"""
- url = silx.io.url.DataUrl(filename)
+ url = DataUrl(filename)
if url.scheme() in [None, "file", "silx"]:
# That's a local file
@@ -612,13 +681,11 @@ def open(filename): # pylint:disable=redefined-builtin
endpoint = "%s://%s" % (uri.scheme, uri.netloc)
if path.startswith("/"):
path = path[1:]
- return h5pyd.File(path, 'r', endpoint=endpoint)
-
- if url.data_slice():
- raise IOError("URL '%s' containing slicing is not supported" % filename)
+ return h5pyd.File(path, "r", endpoint=endpoint)
- if url.data_path() in [None, "/", ""]:
- # The full file is requested
+ if url.data_path() in [None, "/", ""]: # The full file is requested
+ if url.data_slice():
+ raise IOError(f"URL '{filename}' containing slicing is not supported")
return h5_file
else:
# Only a children is requested
@@ -626,6 +693,17 @@ def open(filename): # pylint:disable=redefined-builtin
msg = "File '%s' does not contain path '%s'." % (filename, url.data_path())
raise IOError(msg)
node = h5_file[url.data_path()]
+
+ if url.data_slice() is not None:
+ from . import _sliceh5 # Lazy-import to avoid circular dependency
+
+ try:
+ return _sliceh5.DatasetSlice(node, url.data_slice(), attrs=node.attrs)
+ except ValueError:
+ raise IOError(
+ f"URL {filename} contains slicing, but it is not a dataset"
+ )
+
proxy = _MainNode(node, h5_file)
return proxy
@@ -642,7 +720,7 @@ def _get_classes_type():
if _CLASSES_TYPE is not None:
return _CLASSES_TYPE
- _CLASSES_TYPE = collections.OrderedDict()
+ _CLASSES_TYPE = {}
_CLASSES_TYPE[commonh5.Dataset] = H5Type.DATASET
_CLASSES_TYPE[commonh5.File] = H5Type.FILE
@@ -793,7 +871,7 @@ def is_link(obj):
return t in {H5Type.SOFT_LINK, H5Type.EXTERNAL_LINK}
-def _visitall(item, path=''):
+def _visitall(item, path=""):
"""Helper function for func:`visitall`.
:param item: Item to visit
@@ -807,7 +885,7 @@ def _visitall(item, path=''):
link = item.get(name, getlink=True)
else:
link = child_item
- child_path = '/'.join((path, name))
+ child_path = "/".join((path, name))
ret = link if link is not None and is_link(link) else child_item
yield child_path, ret
@@ -822,16 +900,33 @@ def visitall(item):
:param item: The item to visit.
"""
- yield from _visitall(item, '')
+ yield from _visitall(item, "")
+def iter_groups(group, _root=None):
+ """Pythonic implementation of h5py.Group visit()"""
+ for name in group.keys():
+ entity = group.get(name)
+ if is_group(entity):
+ yield name
+ for subgroup in iter_groups(entity, _root=name):
+ yield f"{name}/{subgroup}"
+
def match(group, path_pattern: str) -> Generator[str, None, None]:
"""Generator of paths inside given h5py-like `group` matching `path_pattern`"""
if not is_group(group):
raise ValueError(f"Not a h5py-like group: {group}")
- path_parts = path_pattern.strip("/").split("/", 1)
+ path_parts = path_pattern.replace("\\", "/").strip("/").split("/", 1)
+ if path_parts[0] == "**":
+ # recursive match
+ for subpath in iter_groups(group):
+ sub = group.get(subpath)
+ for groupname in match(sub, path_parts[1]):
+ yield f"{subpath}/{groupname}"
+ return
+
for matching_path in fnmatch.filter(group.keys(), path_parts[0]):
if len(path_parts) == 1: # No more sub-path, stop recursion
yield matching_path
@@ -843,7 +938,7 @@ def match(group, path_pattern: str) -> Generator[str, None, None]:
yield f"{matching_path}/{matching_subpath}"
-def get_data(url):
+def get_data(url: Union[str, DataUrl]):
"""Returns a numpy data from an URL.
Examples:
@@ -868,7 +963,7 @@ def get_data(url):
.. seealso:: :class:`silx.io.url.DataUrl`
- :param Union[str,silx.io.url.DataUrl]: A data URL
+ :param url: A data URL
:rtype: Union[numpy.ndarray, numpy.generic]
:raises ImportError: If the mandatory library to read the file is not
available.
@@ -877,8 +972,8 @@ def get_data(url):
:meth:`fabio.open` or :meth:`silx.io.open`. In this last case more
informations are displayed in debug mode.
"""
- if not isinstance(url, silx.io.url.DataUrl):
- url = silx.io.url.DataUrl(url)
+ if not isinstance(url, DataUrl):
+ url = DataUrl(url)
if not url.is_valid():
raise ValueError("URL '%s' is not valid" % url.path())
@@ -895,8 +990,10 @@ def get_data(url):
raise ValueError("Data path from URL '%s' not found" % url.path())
data = h5[data_path]
- if not silx.io.is_dataset(data):
- raise ValueError("Data path from URL '%s' is not a dataset" % url.path())
+ if not is_dataset(data):
+ raise ValueError(
+ "Data path from URL '%s' is not a dataset" % url.path()
+ )
if data_slice is not None:
data = h5py_read_dataset(data, index=data_slice)
@@ -906,24 +1003,36 @@ def get_data(url):
elif url.scheme() == "fabio":
import fabio
+
data_slice = url.data_slice()
if data_slice is None:
data_slice = (0,)
if data_slice is None or len(data_slice) != 1:
- raise ValueError("Fabio slice expect a single frame, but %s found" % data_slice)
+ raise ValueError(
+ "Fabio slice expect a single frame, but %s found" % data_slice
+ )
index = data_slice[0]
if not isinstance(index, int):
- raise ValueError("Fabio slice expect a single integer, but %s found" % data_slice)
+ raise ValueError(
+ "Fabio slice expect a single integer, but %s found" % data_slice
+ )
try:
fabio_file = fabio.open(url.file_path())
except Exception:
- logger.debug("Error while opening %s with fabio", url.file_path(), exc_info=True)
- raise IOError("Error while opening %s with fabio (use debug for more information)" % url.path())
+ logger.debug(
+ "Error while opening %s with fabio", url.file_path(), exc_info=True
+ )
+ raise IOError(
+ "Error while opening %s with fabio (use debug for more information)"
+ % url.path()
+ )
if fabio_file.nframes == 1:
if index != 0:
- raise ValueError("Only a single frame available. Slice %s out of range" % index)
+ raise ValueError(
+ "Only a single frame available. Slice %s out of range" % index
+ )
data = fabio_file.data
else:
data = fabio_file.getframe(index).data
@@ -931,14 +1040,31 @@ def get_data(url):
# There is no explicit close
fabio_file = None
+ elif url.scheme() is None:
+ for scheme in ("silx", "fabio"):
+ specificUrl = DataUrl(
+ file_path=url.file_path(),
+ data_slice=url.data_slice(),
+ data_path=url.data_path(),
+ scheme=scheme,
+ )
+ try:
+ data = get_data(specificUrl)
+ except Exception:
+ logger.debug(
+ "Error while trying to loading %s as %s", url, scheme, exc_info=True
+ )
+ else:
+ break
+ else:
+ raise ValueError(f"Data from '{url}' is not readable as silx nor fabio")
else:
raise ValueError("Scheme '%s' not supported" % url.scheme())
return data
-def rawfile_to_h5_external_dataset(bin_file, output_url, shape, dtype,
- overwrite=False):
+def rawfile_to_h5_external_dataset(bin_file, output_url, shape, dtype, overwrite=False):
"""
Create a HDF5 dataset at `output_url` pointing to the given vol_file.
@@ -950,29 +1076,34 @@ def rawfile_to_h5_external_dataset(bin_file, output_url, shape, dtype,
:param numpy.dtype dtype: Data type of the volume elements (default: float32)
:param bool overwrite: True to allow overwriting (default: False).
"""
- assert isinstance(output_url, silx.io.url.DataUrl)
+ assert isinstance(output_url, DataUrl)
assert isinstance(shape, (tuple, list))
- v_majeur, v_mineur, v_micro = [int(i) for i in h5py.version.version.split('.')[:3]]
- if calc_hexversion(v_majeur, v_mineur, v_micro)< calc_hexversion(2,9,0):
- raise Exception('h5py >= 2.9 should be installed to access the '
- 'external feature.')
+ v_majeur, v_mineur, v_micro = [int(i) for i in h5py.version.version.split(".")[:3]]
+ if calc_hexversion(v_majeur, v_mineur, v_micro) < calc_hexversion(2, 9, 0):
+ raise Exception(
+ "h5py >= 2.9 should be installed to access the " "external feature."
+ )
with h5py.File(output_url.file_path(), mode="a") as _h5_file:
if output_url.data_path() in _h5_file:
if overwrite is False:
- raise ValueError('data_path already exists')
+ raise ValueError("data_path already exists")
else:
- logger.warning('will overwrite path %s' % output_url.data_path())
+ logger.warning("will overwrite path %s" % output_url.data_path())
del _h5_file[output_url.data_path()]
external = [(bin_file, 0, h5py.h5f.UNLIMITED)]
- _h5_file.create_dataset(output_url.data_path(),
- shape,
- dtype=dtype,
- external=external)
-
-
-def vol_to_h5_external_dataset(vol_file, output_url, info_file=None,
- vol_dtype=numpy.float32, overwrite=False):
+ _h5_file.create_dataset(
+ output_url.data_path(), shape, dtype=dtype, external=external
+ )
+
+
+def vol_to_h5_external_dataset(
+ vol_file,
+ output_url: DataUrl,
+ info_file: Optional[str] = None,
+ vol_dtype=numpy.float32,
+ overwrite=False,
+):
"""
Create a HDF5 dataset at `output_url` pointing to the given vol_file.
@@ -980,8 +1111,8 @@ def vol_to_h5_external_dataset(vol_file, output_url, info_file=None,
vol-file then you should specify her location.
:param str vol_file: Path to the .vol file
- :param DataUrl output_url: HDF5 URL where to save the external dataset
- :param Union[str,None] info_file:
+ :param output_url: HDF5 URL where to save the external dataset
+ :param info_file:
.vol.info file name written by pyhst and containing the shape information
:param numpy.dtype vol_dtype: Data type of the volume elements (default: float32)
:param bool overwrite: True to allow overwriting (default: False).
@@ -989,10 +1120,12 @@ def vol_to_h5_external_dataset(vol_file, output_url, info_file=None,
"""
_info_file = info_file
if _info_file is None:
- _info_file = vol_file + '.info'
+ _info_file = vol_file + ".info"
if not os.path.exists(_info_file):
- logger.error('info_file not given and %s does not exists, please'
- 'specify .vol.info file' % _info_file)
+ logger.error(
+ "info_file not given and %s does not exists, please"
+ "specify .vol.info file" % _info_file
+ )
return
def info_file_to_dict():
@@ -1000,29 +1133,49 @@ def vol_to_h5_external_dataset(vol_file, output_url, info_file=None,
with builtin_open(info_file, "r") as _file:
lines = _file.readlines()
for line in lines:
- if not '=' in line:
+ if not "=" in line:
continue
- l = line.rstrip().replace(' ', '')
- l = l.split('#')[0]
- key, value = l.split('=')
+ l = line.rstrip().replace(" ", "")
+ l = l.split("#")[0]
+ key, value = l.split("=")
ddict[key.lower()] = value
return ddict
ddict = info_file_to_dict()
- if 'num_x' not in ddict or 'num_y' not in ddict or 'num_z' not in ddict:
- raise ValueError(
- 'Unable to retrieve volume shape from %s' % info_file)
+ if "num_x" not in ddict or "num_y" not in ddict or "num_z" not in ddict:
+ raise ValueError("Unable to retrieve volume shape from %s" % info_file)
- dimX = int(ddict['num_x'])
- dimY = int(ddict['num_y'])
- dimZ = int(ddict['num_z'])
+ dimX = int(ddict["num_x"])
+ dimY = int(ddict["num_y"])
+ dimZ = int(ddict["num_z"])
shape = (dimZ, dimY, dimX)
- return rawfile_to_h5_external_dataset(bin_file=vol_file,
- output_url=output_url,
- shape=shape,
- dtype=vol_dtype,
- overwrite=overwrite)
+ return rawfile_to_h5_external_dataset(
+ bin_file=vol_file,
+ output_url=output_url,
+ shape=shape,
+ dtype=vol_dtype,
+ overwrite=overwrite,
+ )
+
+
+def hdf5_to_python_type(value, decode_ascii, encoding):
+ """Convert HDF5 type to proper python type.
+
+ :param value:
+ :param bool decode_ascii:
+ :param encoding str:
+ """
+ if encoding == "ascii":
+ is_bytes = h5py_value_isinstance(value, bytes)
+ if is_bytes and decode_ascii:
+ return h5py_decode_value(value, encoding="utf-8")
+ if not is_bytes and not decode_ascii:
+ return h5py_encode_value(value, encoding="utf-8")
+ elif encoding == "utf-8":
+ if h5py_value_isinstance(value, bytes):
+ return h5py_decode_value(value, encoding="utf-8")
+ return value
def h5py_decode_value(value, encoding="utf-8", errors="surrogateescape"):
@@ -1034,8 +1187,8 @@ def h5py_decode_value(value, encoding="utf-8", errors="surrogateescape"):
"""
try:
if numpy.isscalar(value):
- return value.decode(encoding, errors=errors)
- str_item = [b.decode(encoding, errors=errors) for b in value.flat]
+ return _decode_string(value, encoding, errors)
+ str_item = [_decode_string(b, encoding, errors) for b in value.flat]
return numpy.array(str_item, dtype=object).reshape(value.shape)
except UnicodeDecodeError:
return value
@@ -1050,13 +1203,55 @@ def h5py_encode_value(value, encoding="utf-8", errors="surrogateescape"):
"""
try:
if numpy.isscalar(value):
- return value.encode(encoding, errors=errors)
- bytes_item = [s.encode(encoding, errors=errors) for s in value.flat]
+ return _encode_string(value, encoding, errors)
+ bytes_item = [_encode_string(s, encoding, errors=errors) for s in value.flat]
return numpy.array(bytes_item, dtype=object).reshape(value.shape)
except UnicodeEncodeError:
return value
+def h5py_value_isinstance(value, vtype):
+ """Keep string when value cannot be encoding
+
+ :param value: string or array of strings
+ :param vtype:
+ :return bool:
+ """
+ if numpy.isscalar(value):
+ try:
+ value = value.item()
+ except AttributeError:
+ pass
+ else:
+ try:
+ value = value[0]
+ except IndexError:
+ pass
+ return isinstance(value, vtype)
+
+
+def _decode_string(string, encoding, errors):
+ """
+ :param value: string
+ :param encoding str:
+ :param errors str:
+ """
+ if isinstance(string, str):
+ return string
+ return string.decode(encoding, errors=errors)
+
+
+def _encode_string(string, encoding, errors):
+ """
+ :param value: string
+ :param encoding str:
+ :param errors str:
+ """
+ if isinstance(string, bytes):
+ return string
+ return string.encode(encoding, errors=errors)
+
+
class H5pyDatasetReadWrapper:
"""Wrapper to handle H5T_STRING decoding on-the-fly when reading
a dataset. Uniform behaviour for h5py 2.x and h5py 3.x
@@ -1066,13 +1261,12 @@ class H5pyDatasetReadWrapper:
Therefore an H5T_STRING with ASCII encoding is not decoded by default.
"""
- H5PY_AUTODECODE_NONASCII = int(h5py.version.version.split(".")[0]) < 3
-
def __init__(self, dset, decode_ascii=False):
"""
:param h5py.Dataset dset:
:param bool decode_ascii:
"""
+ # Get the string encoding (if a string)
try:
string_info = h5py.h5t.check_string_dtype(dset.dtype)
except AttributeError:
@@ -1091,23 +1285,14 @@ class H5pyDatasetReadWrapper:
except AttributeError:
# Not an H5T_STRING
encoding = None
- if encoding == "ascii" and not decode_ascii:
- encoding = None
- if encoding != "ascii" and self.H5PY_AUTODECODE_NONASCII:
- # Decoding is already done by the h5py library
- encoding = None
- if encoding == "ascii":
- # ASCII can be decoded as UTF-8
- encoding = "utf-8"
+
self._encoding = encoding
+ self._decode_ascii = decode_ascii
self._dset = dset
def __getitem__(self, args):
value = self._dset[args]
- if self._encoding:
- return h5py_decode_value(value, encoding=self._encoding)
- else:
- return value
+ return hdf5_to_python_type(value, self._decode_ascii, self._encoding)
class H5pyAttributesReadWrapper:
@@ -1119,8 +1304,6 @@ class H5pyAttributesReadWrapper:
Therefore an H5T_STRING with ASCII encoding is not decoded by default.
"""
- H5PY_AUTODECODE = int(h5py.version.version.split(".")[0]) >= 3
-
def __init__(self, attrs, decode_ascii=False):
"""
:param h5py.Dataset dset:
@@ -1153,17 +1336,7 @@ class H5pyAttributesReadWrapper:
# Not an H5T_STRING
return value
- if self.H5PY_AUTODECODE:
- if encoding == "ascii" and not self._decode_ascii:
- # Undo decoding by the h5py library
- return h5py_encode_value(value, encoding="utf-8")
- else:
- if encoding == "ascii" and self._decode_ascii:
- # Decode ASCII as UTF-8 for consistency
- return h5py_decode_value(value, encoding="utf-8")
-
- # Decoding is already done by the h5py library
- return value
+ return hdf5_to_python_type(value, self._decode_ascii, encoding)
def items(self):
for k in self._attrs.keys():
diff --git a/src/silx/math/_colormap.pyx b/src/silx/math/_colormap.pyx
index e1409fa..a15b4ff 100644
--- a/src/silx/math/_colormap.pyx
+++ b/src/silx/math/_colormap.pyx
@@ -1,6 +1,11 @@
+#cython: embedsignature=True, language_level=3
+## This is for optimisation
+##cython: boundscheck=False, wraparound=False, cdivision=True, initializedcheck=False,
+## This is for developping:
+##cython: profile=True, warn.undeclared=True, warn.unused=True, warn.unused_result=False, warn.unused_arg=True
# /*##########################################################################
#
-# Copyright (c) 2018-2021 European Synchrotron Radiation Facility
+# Copyright (c) 2018-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -26,7 +31,7 @@
__authors__ = ["T. Vincent"]
__license__ = "MIT"
-__date__ = "16/05/2018"
+__date__ = "21/12/2023"
import os
@@ -34,6 +39,7 @@ cimport cython
from cython.parallel import prange
cimport numpy as cnumpy
from libc.math cimport frexp, sinh, sqrt
+from libc.math cimport pow as c_pow
from .math_compatibility cimport asinh, isnan, isfinite, lrint, INFINITY, NAN
import logging
@@ -100,7 +106,7 @@ ctypedef fused image_types:
# Normalization
-ctypedef double (*NormalizationFunction)(double) nogil
+# ctypedef double (*NormalizationFunction)(double) nogil
cdef class Normalization:
@@ -152,7 +158,7 @@ cdef class Normalization:
<double> data1d[index], vmin, vmax)
return numpy.array(result).reshape(data.shape)
- cdef double apply_double(self, double value, double vmin, double vmax) nogil:
+ cdef double apply_double(self, double value, double vmin, double vmax) noexcept nogil:
"""Apply normalization to a floating point value
Override in subclass
@@ -163,7 +169,7 @@ cdef class Normalization:
"""
return value
- cdef double revert_double(self, double value, double vmin, double vmax) nogil:
+ cdef double revert_double(self, double value, double vmin, double vmax) noexcept nogil:
"""Apply inverse of normalization to a floating point value
Override in subclass
@@ -178,10 +184,10 @@ cdef class Normalization:
cdef class LinearNormalization(Normalization):
"""Linear normalization"""
- cdef double apply_double(self, double value, double vmin, double vmax) nogil:
+ cdef double apply_double(self, double value, double vmin, double vmax) noexcept nogil:
return value
- cdef double revert_double(self, double value, double vmin, double vmax) nogil:
+ cdef double revert_double(self, double value, double vmin, double vmax) noexcept nogil:
return value
@@ -207,7 +213,7 @@ cdef class LogarithmicNormalization(Normalization):
@cython.boundscheck(False)
@cython.nonecheck(False)
@cython.cdivision(True)
- cdef double apply_double(self, double value, double vmin, double vmax) nogil:
+ cdef double apply_double(self, double value, double vmin, double vmax) noexcept nogil:
"""Return log10(value) fast approximation based on LUT"""
cdef double result = NAN # if value < 0.0 or value == NAN
cdef int exponent, index_lut
@@ -226,28 +232,28 @@ cdef class LogarithmicNormalization(Normalization):
self.lut[index_lut])
return result
- cdef double revert_double(self, double value, double vmin, double vmax) nogil:
- return 10**value
+ cdef double revert_double(self, double value, double vmin, double vmax) noexcept nogil:
+ return c_pow(10, value)
cdef class ArcsinhNormalization(Normalization):
"""Inverse hyperbolic sine normalization"""
- cdef double apply_double(self, double value, double vmin, double vmax) nogil:
+ cdef double apply_double(self, double value, double vmin, double vmax) noexcept nogil:
return asinh(value)
- cdef double revert_double(self, double value, double vmin, double vmax) nogil:
+ cdef double revert_double(self, double value, double vmin, double vmax) noexcept nogil:
return sinh(value)
cdef class SqrtNormalization(Normalization):
"""Square root normalization"""
- cdef double apply_double(self, double value, double vmin, double vmax) nogil:
+ cdef double apply_double(self, double value, double vmin, double vmax) noexcept nogil:
return sqrt(value)
- cdef double revert_double(self, double value, double vmin, double vmax) nogil:
- return value**2
+ cdef double revert_double(self, double value, double vmin, double vmax) noexcept nogil:
+ return value*value
cdef class PowerNormalization(Normalization):
@@ -268,7 +274,8 @@ cdef class PowerNormalization(Normalization):
# Needed for multiple inheritance to work
pass
- cdef double apply_double(self, double value, double vmin, double vmax) nogil:
+ @cython.cdivision(True)
+ cdef double apply_double(self, double value, double vmin, double vmax) noexcept nogil:
if vmin == vmax:
return 0.
elif value <= vmin:
@@ -276,15 +283,16 @@ cdef class PowerNormalization(Normalization):
elif value >= vmax:
return 1.
else:
- return ((value - vmin) / (vmax - vmin))**self.gamma
+ return c_pow(((value - vmin) / (vmax - vmin)), self.gamma)
- cdef double revert_double(self, double value, double vmin, double vmax) nogil:
+ @cython.cdivision(True)
+ cdef double revert_double(self, double value, double vmin, double vmax) noexcept nogil:
if value <= 0.:
return vmin
elif value >= 1.:
return vmax
else:
- return vmin + (vmax - vmin) * value**(1.0/self.gamma)
+ return vmin + (vmax - vmin) * c_pow(value, (1.0/self.gamma))
# Colormap
diff --git a/src/silx/math/calibration.py b/src/silx/math/calibration.py
index 79be585..7a86a9a 100644
--- a/src/silx/math/calibration.py
+++ b/src/silx/math/calibration.py
@@ -1,5 +1,5 @@
# /*##########################################################################
-# Copyright (C) 2018 European Synchrotron Radiation Facility
+# Copyright (C) 2018-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -31,13 +31,13 @@ Classes
- :class:`ArrayCalibration`
"""
+import functools
import numpy
class AbstractCalibration(object):
- """A calibration is a transformation to be applied to an axis (i.e. a 1D array).
+ """A calibration is a transformation to be applied to an axis (i.e. a 1D array)."""
- """
def __init__(self):
super(AbstractCalibration, self).__init__()
@@ -46,8 +46,9 @@ class AbstractCalibration(object):
:param x: Axis (1-D array), or value"""
raise NotImplementedError(
- "AbstractCalibration can not be used directly. " +
- "You must subclass it and implement __call__")
+ "AbstractCalibration can not be used directly. "
+ + "You must subclass it and implement __call__"
+ )
def is_affine(self):
"""Returns True for an affine calibration of the form
@@ -57,12 +58,13 @@ class AbstractCalibration(object):
def get_slope(self):
raise NotImplementedError(
- "get_slope is implemented only for affine calibrations")
+ "get_slope is implemented only for affine calibrations"
+ )
class NoCalibration(AbstractCalibration):
- """No calibration :math:`x \\mapsto x`
- """
+ """No calibration :math:`x \\mapsto x`"""
+
def __init__(self):
super(NoCalibration, self).__init__()
@@ -73,7 +75,7 @@ class NoCalibration(AbstractCalibration):
return True
def get_slope(self):
- return 1.
+ return 1.0
class LinearCalibration(AbstractCalibration):
@@ -83,6 +85,7 @@ class LinearCalibration(AbstractCalibration):
:param y_intercept: y-intercept
:param slope: Slope of the affine transformation
"""
+
def __init__(self, y_intercept, slope):
super(LinearCalibration, self).__init__()
self.constant = y_intercept
@@ -108,37 +111,44 @@ class ArrayCalibration(AbstractCalibration):
channels (:math:`0, 1, ..., n-1`).
:param x1: Calibration array"""
+
def __init__(self, x1):
super(ArrayCalibration, self).__init__()
if not isinstance(x1, (list, tuple)) and not hasattr(x1, "shape"):
raise TypeError(
- "The calibration array must be a sequence (list, dataset, array)")
+ "The calibration array must be a sequence (list, dataset, array)"
+ )
self.calibration_array = numpy.array(x1)
- self._is_affine = None
+ if self.calibration_array.ndim != 1:
+ raise ValueError(
+ f"1D array expected, got {self.calibration_array.ndim}D array"
+ )
+ if self.calibration_array.size == 0:
+ raise ValueError("Calibration array must not be empty")
def __call__(self, x):
# calibrate the entire axis
- if isinstance(x, (list, tuple, numpy.ndarray)) and \
- len(self.calibration_array) == len(x):
+ if isinstance(x, (list, tuple, numpy.ndarray)) and len(
+ self.calibration_array
+ ) == len(x):
return self.calibration_array
# calibrate one value, by index
if isinstance(x, int) and x < len(self.calibration_array):
return self.calibration_array[x]
- raise ValueError("ArrayCalibration must be applied to array of same size "
- "or to index.")
+ raise ValueError(
+ "ArrayCalibration must be applied to array of same size " "or to index."
+ )
+ @functools.lru_cache()
def is_affine(self):
"""If all values in the calibration array are regularly spaced,
return True."""
- if self._is_affine is None:
- delta_x = self.calibration_array[1:] - self.calibration_array[:-1]
- # use a less strict relative tolerance to account for rounding errors
- # e.g. when using float64 into float32 (see #1823)
- if not numpy.isclose(delta_x, delta_x[0], rtol=1e-4).all():
- self._is_affine = False
- else:
- self._is_affine = True
- return self._is_affine
+ if self.calibration_array.size < 2:
+ return False
+ delta = numpy.diff(self.calibration_array)
+ # use a less strict relative tolerance to account for rounding errors
+ # e.g. when using float64 into float32 (see #1823)
+ return numpy.allclose(delta, delta[0], rtol=1e-4)
def get_slope(self):
"""If the calibration array is regularly spaced, return the spacing."""
@@ -153,6 +163,7 @@ class FunctionCalibration(AbstractCalibration):
"""Calibration defined by a function *f*, such as :math:`x \\mapsto f(x)`*.
:param function: Calibration function"""
+
def __init__(self, function, is_affine=False):
super(FunctionCalibration, self).__init__()
if not hasattr(function, "__call__"):
diff --git a/src/silx/math/colormap.py b/src/silx/math/colormap.py
index 8c05b63..065e09c 100644
--- a/src/silx/math/colormap.py
+++ b/src/silx/math/colormap.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2018-2021 European Synchrotron Radiation Facility
+# Copyright (c) 2018-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -29,6 +29,8 @@ __date__ = "25/08/2021"
import collections
+import numbers
+from typing import NamedTuple
import warnings
import numpy
@@ -41,23 +43,27 @@ from ._colormap import cmap # noqa
__all__ = ["apply_colormap", "cmap"]
-_LUT_DESCRIPTION = collections.namedtuple("_LUT_DESCRIPTION", ["source", "cursor_color"])
+_LUT_DESCRIPTION = collections.namedtuple(
+ "_LUT_DESCRIPTION", ["source", "cursor_color"]
+)
"""Description of a LUT for internal purpose."""
-_AVAILABLE_LUTS = collections.OrderedDict([
- ('gray', _LUT_DESCRIPTION('builtin', '#ff66ff')),
- ('reversed gray', _LUT_DESCRIPTION('builtin', '#ff66ff')),
- ('red', _LUT_DESCRIPTION('builtin', '#00ff00')),
- ('green', _LUT_DESCRIPTION('builtin', '#ff66ff')),
- ('blue', _LUT_DESCRIPTION('builtin', '#ffff00')),
- ('viridis', _LUT_DESCRIPTION('resource', '#ff66ff')),
- ('cividis', _LUT_DESCRIPTION('resource', '#ff66ff')),
- ('magma', _LUT_DESCRIPTION('resource', '#00ff00')),
- ('inferno', _LUT_DESCRIPTION('resource', '#00ff00')),
- ('plasma', _LUT_DESCRIPTION('resource', '#00ff00')),
- ('temperature', _LUT_DESCRIPTION('builtin', '#ff66ff')),
-])
+_AVAILABLE_LUTS = dict(
+ [
+ ("gray", _LUT_DESCRIPTION("builtin", "#ff66ff")),
+ ("reversed gray", _LUT_DESCRIPTION("builtin", "#ff66ff")),
+ ("red", _LUT_DESCRIPTION("builtin", "#00ff00")),
+ ("green", _LUT_DESCRIPTION("builtin", "#ff66ff")),
+ ("blue", _LUT_DESCRIPTION("builtin", "#ffff00")),
+ ("viridis", _LUT_DESCRIPTION("resource", "#ff66ff")),
+ ("cividis", _LUT_DESCRIPTION("resource", "#ff66ff")),
+ ("magma", _LUT_DESCRIPTION("resource", "#00ff00")),
+ ("inferno", _LUT_DESCRIPTION("resource", "#00ff00")),
+ ("plasma", _LUT_DESCRIPTION("resource", "#00ff00")),
+ ("temperature", _LUT_DESCRIPTION("builtin", "#ff66ff")),
+ ]
+)
"""Description for internal porpose of all the default LUT provided by the library."""
@@ -80,11 +86,11 @@ def array_to_rgba8888(colors):
if colors.dtype == numpy.uint8:
pass
- elif colors.dtype.kind == 'f':
+ elif colors.dtype.kind == "f":
# Each bin is [N, N+1[ except the last one: [255, 256]
- colors = numpy.clip(colors.astype(numpy.float64) * 256, 0., 255.)
+ colors = numpy.clip(colors.astype(numpy.float64) * 256, 0.0, 255.0)
colors = colors.astype(numpy.uint8)
- elif colors.dtype.kind in 'iu':
+ elif colors.dtype.kind in "iu":
colors = numpy.clip(colors, 0, 255)
colors = colors.astype(numpy.uint8)
@@ -112,17 +118,17 @@ def _create_colormap_lut(name):
lut = numpy.zeros((256, 4), dtype=numpy.uint8)
lut[:, 3] = 255
- if name == 'gray':
+ if name == "gray":
lut[:, :3] = numpy.arange(256, dtype=numpy.uint8).reshape(-1, 1)
- elif name == 'reversed gray':
+ elif name == "reversed gray":
lut[:, :3] = numpy.arange(255, -1, -1, dtype=numpy.uint8).reshape(-1, 1)
- elif name == 'red':
+ elif name == "red":
lut[:, 0] = numpy.arange(256, dtype=numpy.uint8)
- elif name == 'green':
+ elif name == "green":
lut[:, 1] = numpy.arange(256, dtype=numpy.uint8)
- elif name == 'blue':
+ elif name == "blue":
lut[:, 2] = numpy.arange(256, dtype=numpy.uint8)
- elif name == 'temperature':
+ elif name == "temperature":
# Red
lut[128:192, 0] = numpy.arange(2, 255, 4, dtype=numpy.uint8)
lut[192:, 0] = 255
@@ -145,12 +151,14 @@ def _create_colormap_lut(name):
return lut
else:
- raise RuntimeError("Internal LUT source '%s' unsupported" % description.source)
+ raise RuntimeError(
+ "Internal LUT source '%s' unsupported" % description.source
+ )
raise ValueError("Unknown colormap '%s'" % name)
-def register_colormap(name, lut, cursor_color='#000000'):
+def register_colormap(name, lut, cursor_color="#000000"):
"""Register a custom colormap LUT
It can override existing LUT names.
@@ -162,7 +170,7 @@ def register_colormap(name, lut, cursor_color='#000000'):
:param str cursor_color: Color used to display overlay over images using
colormap with this LUT.
"""
- description = _LUT_DESCRIPTION('user', cursor_color)
+ description = _LUT_DESCRIPTION("user", cursor_color)
colors = array_to_rgba8888(lut)
_AVAILABLE_LUTS[name] = description
@@ -187,7 +195,7 @@ def get_colormap_cursor_color(name):
color = description.cursor_color
if color is not None:
return color
- return 'black'
+ return "black"
def get_colormap_lut(name):
@@ -207,6 +215,7 @@ def get_colormap_lut(name):
# Normalizations
+
class _NormalizationMixIn:
"""Colormap normalization mix-in class"""
@@ -258,7 +267,7 @@ class _NormalizationMixIn:
vmax = min(dmax, stdmax)
else:
- raise ValueError('Unsupported mode: %s' % mode)
+ raise ValueError("Unsupported mode: %s" % mode)
# Check returned range and handle fallbacks
if vmin is None or not numpy.isfinite(vmin):
@@ -294,19 +303,21 @@ class _NormalizationMixIn:
:rtype: Tuple[float,float]
"""
# Use [0, 1] as data range for normalization not using range
- normdata = self.apply(data, 0., 1.)
- if normdata.dtype.kind == 'f': # Replaces inf by NaN
+ normdata = self.apply(data, 0.0, 1.0)
+ if normdata.dtype.kind == "f": # Replaces inf by NaN
normdata[numpy.isfinite(normdata) == False] = numpy.nan
if normdata.size == 0: # Fallback
return None, None
with warnings.catch_warnings():
- warnings.simplefilter('ignore', category=RuntimeWarning)
+ warnings.simplefilter("ignore", category=RuntimeWarning)
# Ignore nanmean "Mean of empty slice" warning and
# nanstd "Degrees of freedom <= 0 for slice" warning
mean, std = numpy.nanmean(normdata), numpy.nanstd(normdata)
- return self.revert(mean - 3 * std, 0., 1.), self.revert(mean + 3 * std, 0., 1.)
+ return self.revert(mean - 3 * std, 0.0, 1.0), self.revert(
+ mean + 3 * std, 0.0, 1.0
+ )
class _LinearNormalizationMixIn(_NormalizationMixIn):
@@ -321,13 +332,13 @@ class _LinearNormalizationMixIn(_NormalizationMixIn):
:returns: (vmin, vmax)
:rtype: Tuple[float,float]
"""
- if data.dtype.kind == 'f': # Replaces inf by NaN
+ if data.dtype.kind == "f": # Replaces inf by NaN
data = numpy.array(data, copy=True) # Work on a copy
data[numpy.isfinite(data) == False] = numpy.nan
if data.size == 0: # Fallback
return None, None
with warnings.catch_warnings():
- warnings.simplefilter('ignore', category=RuntimeWarning)
+ warnings.simplefilter("ignore", category=RuntimeWarning)
# Ignore nanmean "Mean of empty slice" warning and
# nanstd "Degrees of freedom <= 0 for slice" warning
mean, std = numpy.nanmean(data), numpy.nanstd(data)
@@ -336,6 +347,7 @@ class _LinearNormalizationMixIn(_NormalizationMixIn):
class LinearNormalization(_colormap.LinearNormalization, _LinearNormalizationMixIn):
"""Linear normalization"""
+
def __init__(self):
_colormap.LinearNormalization.__init__(self)
_LinearNormalizationMixIn.__init__(self)
@@ -351,7 +363,7 @@ class LogarithmicNormalization(_colormap.LogarithmicNormalization, _Normalizatio
_NormalizationMixIn.__init__(self)
def is_valid(self, value):
- return value > 0.
+ return value > 0.0
def autoscale_minmax(self, data):
result = _min_max(data, min_positive=True, finite=True)
@@ -368,7 +380,7 @@ class SqrtNormalization(_colormap.SqrtNormalization, _NormalizationMixIn):
_NormalizationMixIn.__init__(self)
def is_valid(self, value):
- return value >= 0.
+ return value >= 0.0
class GammaNormalization(_colormap.PowerNormalization, _LinearNormalizationMixIn):
@@ -378,6 +390,7 @@ class GammaNormalization(_colormap.PowerNormalization, _LinearNormalizationMixIn
:param gamma: Gamma correction factor
"""
+
def __init__(self, gamma):
_colormap.PowerNormalization.__init__(self, gamma)
_LinearNormalizationMixIn.__init__(self)
@@ -404,15 +417,37 @@ _BASIC_NORMALIZATIONS = {
"arcsinh": ArcsinhNormalization(),
}
+
+def _get_normalizer(norm, gamma):
+ """Returns corresponding Normalization instance"""
+ if norm == "gamma":
+ return GammaNormalization(gamma)
+ return _BASIC_NORMALIZATIONS[norm]
+
+
+def _get_range(normalizer, data, autoscale, vmin, vmax):
+ """Returns effective range"""
+ if vmin is None or vmax is None:
+ auto_vmin, auto_vmax = normalizer.autoscale(data, autoscale)
+ if vmin is None: # Set vmin respecting provided vmax
+ vmin = auto_vmin if vmax is None else min(auto_vmin, vmax)
+ if vmax is None:
+ vmax = max(auto_vmax, vmin) # Handle max_ <= 0 for log scale
+ return vmin, vmax
+
+
_DEFAULT_NAN_COLOR = 255, 255, 255, 0
-def apply_colormap(data,
- colormap: str,
- norm: str="linear",
- autoscale: str="minmax",
- vmin=None,
- vmax=None,
- gamma=1.0):
+
+def apply_colormap(
+ data,
+ colormap: str,
+ norm: str = "linear",
+ autoscale: str = "minmax",
+ vmin=None,
+ vmax=None,
+ gamma=1.0,
+):
"""Apply colormap to data with given normalization and autoscale.
:param numpy.ndarray data: Data on which to apply the colormap
@@ -426,19 +461,8 @@ def apply_colormap(data,
:returns: Array of colors
"""
colors = get_colormap_lut(colormap)
-
- if norm == "gamma":
- normalizer = GammaNormalization(gamma)
- else:
- normalizer = _BASIC_NORMALIZATIONS[norm]
-
- if vmin is None or vmax is None:
- auto_vmin, auto_vmax = normalizer.autoscale(data, autoscale)
- if vmin is None: # Set vmin respecting provided vmax
- vmin = auto_vmin if vmax is None else min(auto_vmin, vmax)
- if vmax is None:
- vmax = max(auto_vmax, vmin) # Handle max_ <= 0 for log scale
-
+ normalizer = _get_normalizer(norm, gamma)
+ vmin, vmax = _get_range(normalizer, data, autoscale, vmin, vmax)
return _colormap.cmap(
data,
colors,
@@ -447,3 +471,45 @@ def apply_colormap(data,
normalizer,
_DEFAULT_NAN_COLOR,
)
+
+
+_UINT8_LUT = numpy.arange(256, dtype=numpy.uint8).reshape(-1, 1)
+
+
+class NormalizeResult(NamedTuple):
+ data: numpy.ndarray
+ vmin: numbers.Number
+ vmax: numbers.Number
+
+
+def normalize(
+ data,
+ norm: str = "linear",
+ autoscale: str = "minmax",
+ vmin=None,
+ vmax=None,
+ gamma=1.0,
+):
+ """Normalize data to an array of uint8.
+
+ :param numpy.ndarray data: Data to normalize
+ :param str norm: Normalization to apply
+ :param str autoscale: Autoscale mode: "minmax" (default) or "stddev3"
+ :param vmin: Lower bound, None (default) to autoscale
+ :param vmax: Upper bound, None (default) to autoscale
+ :param float gamma:
+ Gamma correction parameter (used only for "gamma" normalization)
+ :returns: Array of normalized values, vmin, vmax
+ """
+ normalizer = _get_normalizer(norm, gamma)
+ vmin, vmax = _get_range(normalizer, data, autoscale, vmin, vmax)
+ norm_data = _colormap.cmap(
+ data,
+ _UINT8_LUT,
+ vmin,
+ vmax,
+ normalizer,
+ nan_color=_UINT8_LUT[0],
+ )
+ norm_data.shape = data.shape
+ return NormalizeResult(norm_data, vmin, vmax)
diff --git a/src/silx/math/fft/basefft.py b/src/silx/math/fft/basefft.py
index c608fde..6e9fac8 100644
--- a/src/silx/math/fft/basefft.py
+++ b/src/silx/math/fft/basefft.py
@@ -23,7 +23,7 @@
#
# ###########################################################################*/
import numpy as np
-from pkg_resources import parse_version
+from packaging.version import Version
def check_version(package, required_version):
@@ -37,8 +37,8 @@ def check_version(package, required_version):
ver = getattr(package, "version")
except Exception:
return False
- req_v = parse_version(required_version)
- ver_v = parse_version(ver)
+ req_v = Version(required_version)
+ ver_v = Version(ver)
return ver_v >= req_v
@@ -46,6 +46,7 @@ class BaseFFT(object):
"""
Base class for all FFT backends.
"""
+
def __init__(self, **kwargs):
self.__get_args(**kwargs)
@@ -82,25 +83,20 @@ class BaseFFT(object):
np.dtype("float32"): np.complex64,
np.dtype("float64"): np.complex128,
np.dtype("complex64"): np.complex64,
- np.dtype("complex128"): np.complex128
- }
- dp = {
- np.dtype("float32"): np.float64,
- np.dtype("complex64"): np.complex128
+ np.dtype("complex128"): np.complex128,
}
+ dp = {np.dtype("float32"): np.float64, np.dtype("complex64"): np.complex128}
self.dtype_in = np.dtype(self.dtype)
if self.dtype_in not in dtypes_mapping:
- raise ValueError("Invalid input data type: got %s" %
- self.dtype_in
- )
+ raise ValueError("Invalid input data type: got %s" % self.dtype_in)
self.dtype_out = dtypes_mapping[self.dtype_in]
def __calc_shape(self):
# TODO allow for C2C even for real input data (?)
if self.dtype_in in [np.float32, np.float64]:
- last_dim = self.shape[-1]//2 + 1
+ last_dim = self.shape[-1] // 2 + 1
# FFTW convention
- self.shape_out = self.shape[:-1] + (self.shape[-1]//2 + 1,)
+ self.shape_out = self.shape[:-1] + (self.shape[-1] // 2 + 1,)
else:
self.shape_out = self.shape
@@ -121,7 +117,7 @@ class BaseFFT(object):
raise ValueError("This should be implemented by back-end FFT")
def allocate_arrays(self):
- if not(self.data_allocated):
+ if not (self.data_allocated):
self.data_in = self._allocate(self.shape, self.dtype_in)
self.data_out = self._allocate(self.shape_out, self.dtype_out)
self.data_allocated = True
@@ -130,13 +126,22 @@ class BaseFFT(object):
if data is None:
return self.data_in
else:
- return self.set_data(self.data_in, data, self.shape, self.dtype_in, copy=copy, name="data_in")
+ return self.set_data(
+ self.data_in, data, self.shape, self.dtype_in, copy=copy, name="data_in"
+ )
def set_output_data(self, data, copy=True):
if data is None:
return self.data_out
else:
- return self.set_data(self.data_out, data, self.shape_out, self.dtype_out, copy=copy, name="data_out")
+ return self.set_data(
+ self.data_out,
+ data,
+ self.shape_out,
+ self.dtype_out,
+ copy=copy,
+ name="data_out",
+ )
def fft(self, array, **kwargs):
raise ValueError("This should be implemented by back-end FFT")
diff --git a/src/silx/math/fft/clfft.py b/src/silx/math/fft/clfft.py
index 2e41e47..488102a 100644
--- a/src/silx/math/fft/clfft.py
+++ b/src/silx/math/fft/clfft.py
@@ -25,12 +25,14 @@
import numpy as np
from .basefft import BaseFFT, check_version
+
try:
import pyopencl as cl
import pyopencl.array as parray
import gpyfft
from gpyfft.fft import FFT as cl_fft
from ...opencl.common import ocl
+
__have_clfft__ = True
except ImportError:
__have_clfft__ = False
@@ -58,6 +60,7 @@ class CLFFT(BaseFFT):
:param bool choose_best_device:
Whether to automatically choose the best available OpenCL device.
"""
+
def __init__(
self,
shape=None,
@@ -70,8 +73,11 @@ class CLFFT(BaseFFT):
fast_math=False,
choose_best_device=True,
):
- if not(__have_clfft__) or not(__have_clfft__):
- raise ImportError("Please install pyopencl and gpyfft >= %s to use the OpenCL back-end" % __required_gpyfft_version__)
+ if not (__have_clfft__) or not (__have_clfft__):
+ raise ImportError(
+ "Please install pyopencl and gpyfft >= %s to use the OpenCL back-end"
+ % __required_gpyfft_version__
+ )
super().__init__(
shape=shape,
@@ -116,18 +122,16 @@ class CLFFT(BaseFFT):
ary.fill(0)
return ary
-
def check_array(self, array, shape, dtype, copy=True):
if array.shape != shape:
- raise ValueError("Invalid data shape: expected %s, got %s" %
- (shape, array.shape)
+ raise ValueError(
+ "Invalid data shape: expected %s, got %s" % (shape, array.shape)
)
if array.dtype != dtype:
- raise ValueError("Invalid data type: expected %s, got %s" %
- (dtype, array.dtype)
+ raise ValueError(
+ "Invalid data type: expected %s, got %s" % (dtype, array.dtype)
)
-
def set_data(self, dst, src, shape, dtype, copy=True, name=None):
"""
dst is a device array owned by the current instance
@@ -140,10 +144,10 @@ class CLFFT(BaseFFT):
if name == "data_out":
# Makes little sense to provide output=numpy_array
return dst
- if not(src.flags["C_CONTIGUOUS"]):
+ if not (src.flags["C_CONTIGUOUS"]):
src = np.ascontiguousarray(src, dtype=dtype)
# working on underlying buffer is notably faster
- #~ dst[:] = src[:]
+ # ~ dst[:] = src[:]
evt = cl.enqueue_copy(self.queue, dst.data, src)
evt.wait()
elif isinstance(src, parray.Array):
@@ -153,22 +157,20 @@ class CLFFT(BaseFFT):
if name is None:
# This should not happen
raise ValueError("Please provide either copy=True or name != None")
- assert id(self.refs[name]) == id(dst) # DEBUG
+ assert id(self.refs[name]) == id(dst) # DEBUG
setattr(self, name, src)
return src
else:
raise ValueError(
- "Invalid array type %s, expected numpy.ndarray or pyopencl.array" %
- type(src)
+ "Invalid array type %s, expected numpy.ndarray or pyopencl.array"
+ % type(src)
)
return dst
-
def recover_array_references(self):
self.data_in = self.refs["data_in"]
self.data_out = self.refs["data_out"]
-
def init_context_queue(self):
if self.ctx is None:
if self.choose_best_device:
@@ -177,7 +179,6 @@ class CLFFT(BaseFFT):
self.ctx = cl.create_some_context()
self.queue = cl.CommandQueue(self.ctx)
-
def compute_forward_plan(self):
self.plan_forward = cl_fft(
self.ctx,
@@ -189,7 +190,6 @@ class CLFFT(BaseFFT):
real=self.real_transform,
)
-
def compute_inverse_plan(self):
self.plan_inverse = cl_fft(
self.ctx,
@@ -201,26 +201,22 @@ class CLFFT(BaseFFT):
real=self.real_transform,
)
-
def update_forward_plan_arrays(self):
self.plan_forward.data = self.data_in
self.plan_forward.result = self.data_out
-
def update_inverse_plan_arrays(self):
self.plan_inverse.data = self.data_out
self.plan_inverse.result = self.data_in
-
def copy_output_if_numpy(self, dst, src):
if isinstance(dst, parray.Array):
return
# working on underlying buffer is notably faster
- #~ dst[:] = src[:]
+ # ~ dst[:] = src[:]
evt = cl.enqueue_copy(self.queue, dst, src.data)
evt.wait()
-
def fft(self, array, output=None, do_async=False):
"""
Perform a (forward) Fast Fourier Transform.
@@ -236,8 +232,8 @@ class CLFFT(BaseFFT):
self.set_input_data(array, copy=False)
self.set_output_data(output, copy=False)
self.update_forward_plan_arrays()
- event, = self.plan_forward.enqueue()
- if not(do_async):
+ (event,) = self.plan_forward.enqueue()
+ if not (do_async):
event.wait()
if output is not None:
self.copy_output_if_numpy(output, self.data_out)
@@ -247,7 +243,6 @@ class CLFFT(BaseFFT):
self.recover_array_references()
return res
-
def ifft(self, array, output=None, do_async=False):
"""
Perform a (inverse) Fast Fourier Transform.
@@ -263,8 +258,8 @@ class CLFFT(BaseFFT):
self.set_output_data(array, copy=False)
self.set_input_data(output, copy=False)
self.update_inverse_plan_arrays()
- event, = self.plan_inverse.enqueue(forward=False)
- if not(do_async):
+ (event,) = self.plan_inverse.enqueue(forward=False)
+ if not (do_async):
event.wait()
if output is not None:
self.copy_output_if_numpy(output, self.data_in)
@@ -274,7 +269,6 @@ class CLFFT(BaseFFT):
self.recover_array_references()
return res
-
def __del__(self):
# It seems that gpyfft underlying clFFT destructors are not called.
# This results in the following warning:
@@ -282,4 +276,3 @@ class CLFFT(BaseFFT):
# Please consider explicitly calling clfftTeardown( )
del self.plan_forward
del self.plan_inverse
-
diff --git a/src/silx/math/fft/cufft.py b/src/silx/math/fft/cufft.py
index 4bc7806..c609439 100644
--- a/src/silx/math/fft/cufft.py
+++ b/src/silx/math/fft/cufft.py
@@ -25,11 +25,13 @@
import numpy as np
from .basefft import BaseFFT
+
try:
import pycuda.gpuarray as gpuarray
from skcuda.fft import Plan
from skcuda.fft import fft as cu_fft
from skcuda.fft import ifft as cu_ifft
+
__have_cufft__ = True
except ImportError:
__have_cufft__ = False
@@ -47,6 +49,7 @@ class CUFFT(BaseFFT):
Stream with which to associate the plan. If no stream is specified,
the default stream is used.
"""
+
def __init__(
self,
shape=None,
@@ -57,8 +60,10 @@ class CUFFT(BaseFFT):
normalize="rescale",
stream=None,
):
- if not(__have_cufft__) or not(__have_cufft__):
- raise ImportError("Please install pycuda and scikit-cuda to use the CUDA back-end")
+ if not (__have_cufft__) or not (__have_cufft__):
+ raise ImportError(
+ "Please install pycuda and scikit-cuda to use the CUDA back-end"
+ )
super().__init__(
shape=shape,
@@ -103,7 +108,9 @@ class CUFFT(BaseFFT):
3: [(1, 2), (2, 1), (1,), (2,)],
}
if self.axes not in supported_axes[data_ndims]:
- raise NotImplementedError("With the CUDA backend, batched transform is only supported along fastest dimensions")
+ raise NotImplementedError(
+ "With the CUDA backend, batched transform is only supported along fastest dimensions"
+ )
self.cufft_batch_size = self.shape[0]
self.cufft_shape = self.shape[1:]
if data_ndims == 3 and len(self.axes) == 1:
@@ -120,15 +127,17 @@ class CUFFT(BaseFFT):
raise NotImplementedError(
"Normalization mode 'ortho' is not implemented with CUDA backend yet."
)
- self.cufft_scale_inverse = (self.normalize == "rescale")
+ self.cufft_scale_inverse = self.normalize == "rescale"
def check_array(self, array, shape, dtype, copy=True):
if array.shape != shape:
- raise ValueError("Invalid data shape: expected %s, got %s" %
- (shape, array.shape))
+ raise ValueError(
+ "Invalid data shape: expected %s, got %s" % (shape, array.shape)
+ )
if array.dtype != dtype:
- raise ValueError("Invalid data type: expected %s, got %s" %
- (dtype, array.dtype))
+ raise ValueError(
+ "Invalid data type: expected %s, got %s" % (dtype, array.dtype)
+ )
def set_data(self, dst, src, shape, dtype, copy=True, name=None):
"""
@@ -142,7 +151,7 @@ class CUFFT(BaseFFT):
if name == "data_out":
# Makes little sense to provide output=numpy_array
return dst
- if not(src.flags["C_CONTIGUOUS"]):
+ if not (src.flags["C_CONTIGUOUS"]):
src = np.ascontiguousarray(src, dtype=dtype)
dst[:] = src[:]
elif isinstance(src, gpuarray.GPUArray):
@@ -157,8 +166,8 @@ class CUFFT(BaseFFT):
return src
else:
raise ValueError(
- "Invalid array type %s, expected numpy.ndarray or pycuda.gpuarray" %
- type(src)
+ "Invalid array type %s, expected numpy.ndarray or pycuda.gpuarray"
+ % type(src)
)
return dst
@@ -176,12 +185,12 @@ class CUFFT(BaseFFT):
# cufft extensible plan API is only supported after 0.5.1
# (commit 65288d28ca0b93e1234133f8d460dc6becb65121)
# but there is still no official 0.5.2
- #~ auto_allocate=True # cufft extensible plan API
+ # ~ auto_allocate=True # cufft extensible plan API
)
def compute_inverse_plan(self):
self.plan_inverse = Plan(
- self.cufft_shape, # not shape_out
+ self.cufft_shape, # not shape_out
self.dtype_out,
self.dtype,
batch=self.cufft_batch_size,
@@ -189,7 +198,7 @@ class CUFFT(BaseFFT):
# cufft extensible plan API is only supported after 0.5.1
# (commit 65288d28ca0b93e1234133f8d460dc6becb65121)
# but there is still no official 0.5.2
- #~ auto_allocate=True
+ # ~ auto_allocate=True
)
def copy_output_if_numpy(self, dst, src):
@@ -209,12 +218,7 @@ class CUFFT(BaseFFT):
data_in = self.set_input_data(array, copy=False)
data_out = self.set_output_data(output, copy=False)
- cu_fft(
- data_in,
- data_out,
- self.plan_forward,
- scale=False
- )
+ cu_fft(data_in, data_out, self.plan_forward, scale=False)
if output is not None:
self.copy_output_if_numpy(output, self.data_out)
diff --git a/src/silx/math/fft/fft.py b/src/silx/math/fft/fft.py
index 23de0cb..7daf17b 100644
--- a/src/silx/math/fft/fft.py
+++ b/src/silx/math/fft/fft.py
@@ -34,7 +34,7 @@ def FFT(
axes=None,
normalize="rescale",
backend="numpy",
- **kwargs
+ **kwargs,
):
"""
Initialize a FFT plan.
@@ -94,6 +94,6 @@ def FFT(
shape_out=shape_out,
axes=axes,
normalize=normalize,
- **kwargs
+ **kwargs,
)
return F
diff --git a/src/silx/math/fft/fftw.py b/src/silx/math/fft/fftw.py
index 797543b..69edbb6 100644
--- a/src/silx/math/fft/fftw.py
+++ b/src/silx/math/fft/fftw.py
@@ -259,7 +259,6 @@ class FFTW(BaseFFT):
return data_out
-
def get_wisdom_metadata():
"""
Get metadata on the current platform.
@@ -269,7 +268,7 @@ def get_wisdom_metadata():
"""
return {
# "venv"
- "executable": sys_executable,
+ "executable": sys_executable,
# encapsulates sys.platform, platform.machine(), platform.architecture(), platform.libc_ver(), ...
"hostname": gethostname(),
"available_threads": len(os.sched_getaffinity(0)),
@@ -293,7 +292,7 @@ def export_wisdom(fname, on_existing="overwrite"):
if on_existing == "raise":
raise ValueError("File already exists: %s" % fname)
if on_existing == "append":
- import_wisdom(fname, on_mismatch="ignore") # ?
+ import_wisdom(fname, on_mismatch="ignore") # ?
current_wisdom = pyfftw.export_wisdom()
res = get_wisdom_metadata()
for i, w in enumerate(current_wisdom):
@@ -320,8 +319,12 @@ def import_wisdom(fname, match=["hostname"], on_mismatch="warn"):
- "warn": print a warning, don't crash
- "ignore": do nothing
"""
+
def handle_mismatch(item, loaded_value, current_value):
- msg = "Platform configuration mismatch: %s: currently have '%s', loaded '%s'" % (item, current_value, loaded_value)
+ msg = (
+ "Platform configuration mismatch: %s: currently have '%s', loaded '%s'"
+ % (item, current_value, loaded_value)
+ )
if on_mismatch == "raise":
raise ValueError(msg)
if on_mismatch == "warn":
@@ -332,16 +335,27 @@ def import_wisdom(fname, match=["hostname"], on_mismatch="warn"):
for metadata_name in match:
if metadata_name not in wis_metadata:
raise ValueError(
- "Cannot match metadata '%s'. Available are: %s" % (match, str(wis_metadata.keys()))
+ "Cannot match metadata '%s'. Available are: %s"
+ % (match, str(wis_metadata.keys()))
)
if loaded_wisdom[metadata_name] != wis_metadata[metadata_name]:
- handle_mismatch(metadata_name, loaded_wisdom[metadata_name], wis_metadata[metadata_name])
+ handle_mismatch(
+ metadata_name, loaded_wisdom[metadata_name], wis_metadata[metadata_name]
+ )
return
- w = tuple(loaded_wisdom[k][()] for k in loaded_wisdom.keys() if k not in wis_metadata.keys())
+ w = tuple(
+ loaded_wisdom[k][()]
+ for k in loaded_wisdom.keys()
+ if k not in wis_metadata.keys()
+ )
pyfftw.import_wisdom(w)
-def get_wisdom_file(directory=None, name_template="fftw_wisdom_{whoami}_{hostname}.npz", create_dirs=True):
+def get_wisdom_file(
+ directory=None,
+ name_template="fftw_wisdom_{whoami}_{hostname}.npz",
+ create_dirs=True,
+):
"""
Get a file path for storing FFTW wisdom.
@@ -355,10 +369,7 @@ def get_wisdom_file(directory=None, name_template="fftw_wisdom_{whoami}_{hostnam
Whether to create (possibly nested) directories if needed.
"""
directory = directory or gettempdir()
- file_basename = name_template.format(
- whoami=os.getlogin(),
- hostname=gethostname()
- )
+ file_basename = name_template.format(whoami=os.getlogin(), hostname=gethostname())
out_file = os.path.join(directory, file_basename)
if create_dirs:
Path(os.path.dirname(out_file)).mkdir(parents=True, exist_ok=True)
diff --git a/src/silx/math/fft/npfft.py b/src/silx/math/fft/npfft.py
index fc7d1c9..3fe0754 100644
--- a/src/silx/math/fft/npfft.py
+++ b/src/silx/math/fft/npfft.py
@@ -24,7 +24,7 @@
# ###########################################################################*/
import numpy as np
import warnings
-from pkg_resources import parse_version
+from packaging.version import Version
from .basefft import BaseFFT
@@ -76,7 +76,7 @@ class NPFFT(BaseFFT):
self.numpy_args_ifft = {"norm": "ortho"}
elif self.normalize == "none": # no normalisation on both fft & ifft
- if parse_version(np.version.version) < parse_version("1.20"):
+ if Version(np.version.version) < Version("1.20"):
# "backward" & "forward" keywords were introduced in 1.20 and we support numpy >= 1.8
warnings.warn(
"Numpy version %s does not allow to non-normalization. Effective normalization will be 'rescale'"
diff --git a/src/silx/math/fft/test/test_fft.py b/src/silx/math/fft/test/test_fft.py
index b696317..abe7842 100644
--- a/src/silx/math/fft/test/test_fft.py
+++ b/src/silx/math/fft/test/test_fft.py
@@ -1,7 +1,7 @@
#!/usr/bin/env python
# /*##########################################################################
#
-# Copyright (c) 2018-2021 European Synchrotron Radiation Facility
+# Copyright (c) 2018-2022 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -28,11 +28,15 @@ from os import path
import logging
import numpy as np
import unittest
-from pkg_resources import parse_version
+from packaging.version import Version
import pytest
from tempfile import TemporaryDirectory
+
try:
- from scipy.misc import ascent
+ try:
+ from scipy.misc import ascent
+ except:
+ from scipy.datasets import ascent
__have_scipy = True
except ImportError:
__have_scipy = False
@@ -40,13 +44,19 @@ from silx.utils.testutils import ParametricTestCase
from silx.math.fft.fft import FFT
from silx.math.fft.clfft import __have_clfft__
from silx.math.fft.cufft import __have_cufft__
-from silx.math.fft.fftw import __have_fftw__, import_wisdom, export_wisdom, get_wisdom_file
+from silx.math.fft.fftw import (
+ __have_fftw__,
+ import_wisdom,
+ export_wisdom,
+ get_wisdom_file,
+)
if __have_cufft__:
import atexit
import pycuda.driver as cuda
from pycuda.tools import clear_context_caches
+
def get_cuda_context(device_id=None, cleanup_at_exit=True):
"""
Create or get a CUDA context.
@@ -66,18 +76,22 @@ def get_cuda_context(device_id=None, cleanup_at_exit=True):
# Unlike Context.make_context(), the newly-created context is not made current.
context = cuda.Device(device_id).retain_primary_context()
context.push()
+
# Register a clean-up function at exit
def _finish_up(context):
if context is not None:
context.pop()
context = None
clear_context_caches()
+
if cleanup_at_exit:
atexit.register(_finish_up, context)
return context
+
logger = logging.getLogger(__name__)
+
class TransformInfos(object):
def __init__(self):
self.dimensions = [
@@ -96,8 +110,16 @@ class TransformInfos(object):
self.sizes = {
"1D": [(128,), (127,)],
"2D": [(128, 128), (128, 127), (127, 128), (127, 127)],
- "3D": [(64, 64, 64), (64, 64, 63), (64, 63, 64), (63, 64, 64),
- (64, 63, 63), (63, 64, 63), (63, 63, 64), (63, 63, 63)]
+ "3D": [
+ (64, 64, 64),
+ (64, 64, 63),
+ (64, 63, 64),
+ (63, 64, 64),
+ (64, 63, 63),
+ (63, 64, 63),
+ (63, 63, 64),
+ (63, 63, 63),
+ ],
}
self.axes = {
"1D": None,
@@ -144,8 +166,9 @@ class TestFFT(ParametricTestCase):
"""
return np.max(np.abs(arr1 - arr2))
- @unittest.skipIf(not __have_cufft__,
- "cuda back-end requires pycuda and scikit-cuda")
+ @unittest.skipIf(
+ not __have_cufft__, "cuda back-end requires pycuda and scikit-cuda"
+ )
def test_cuda(self):
get_cuda_context()
@@ -154,15 +177,14 @@ class TestFFT(ParametricTestCase):
self.__run_tests(backend="cuda")
- @unittest.skipIf(not __have_clfft__,
- "opencl back-end requires pyopencl and gpyfft")
+ @unittest.skipIf(not __have_clfft__, "opencl back-end requires pyopencl and gpyfft")
def test_opencl(self):
from silx.opencl.common import ocl
+
if ocl is not None:
self.__run_tests(backend="opencl", ctx=ocl.create_context())
- @unittest.skipIf(not __have_fftw__,
- "fftw back-end requires pyfftw")
+ @unittest.skipIf(not __have_fftw__, "fftw back-end requires pyfftw")
def test_fftw(self):
self.__run_tests(backend="fftw")
@@ -180,14 +202,20 @@ class TestFFT(ParametricTestCase):
def __test(self, backend, trdim, mode, size, **extra_args):
"""Compare given backend with numpy for given conditions"""
- logger.debug("backend: %s, trdim: %s, mode: %s, size: %s",
- backend, trdim, mode, str(size))
+ logger.debug(
+ "backend: %s, trdim: %s, mode: %s, size: %s",
+ backend,
+ trdim,
+ mode,
+ str(size),
+ )
if size == "3D" and self.test_options.TEST_LOW_MEM:
self.skipTest("low mem")
ndim = len(size)
input_data = self.test_data.data_refs[ndim].astype(
- self.transform_infos.modes[mode])
+ self.transform_infos.modes[mode]
+ )
tol = self.tol[np.dtype(input_data.dtype)]
if trdim == "3D":
tol *= 10 # Error is relatively high in high dimensions
@@ -202,34 +230,29 @@ class TestFFT(ParametricTestCase):
"backend": backend,
}
fft_args.update(extra_args)
- F = FFT(
- **fft_args
- )
+ F = FFT(**fft_args)
F_np = FFT(
- template=input_data,
- axes=self.transform_infos.axes[trdim],
- backend="numpy"
+ template=input_data, axes=self.transform_infos.axes[trdim], backend="numpy"
)
# Forward FFT
res = F.fft(input_data)
res_np = F_np.fft(input_data)
mae = self.calc_mae(res, res_np)
- all_close = np.allclose(res, res_np, atol=tol, rtol=tol),
+ all_close = (np.allclose(res, res_np, atol=tol, rtol=tol),)
self.assertTrue(
all_close,
- "FFT %s:%s, MAE(%s, numpy) = %f (tol = %.2e)" % (mode, trdim, backend, mae, tol)
+ "FFT %s:%s, MAE(%s, numpy) = %f (tol = %.2e)"
+ % (mode, trdim, backend, mae, tol),
)
# Inverse FFT
res2 = F.ifft(res)
mae = self.calc_mae(res2, input_data)
self.assertTrue(
- mae < tol,
- "IFFT %s:%s, MAE(%s, numpy) = %f" % (mode, trdim, backend, mae)
+ mae < tol, "IFFT %s:%s, MAE(%s, numpy) = %f" % (mode, trdim, backend, mae)
)
-
# Test normalizations. silx FFT has three normalization modes:
# - "rescale" (default). FFT is unscaled, IFFT is scaled by 1/N.
# This corresponds to numpy normalize=None i.e normalize="backward"
@@ -248,7 +271,7 @@ class TestFFT(ParametricTestCase):
},
"cuda": {
"supported_normalizations": ["rescale", "none"],
- }
+ },
}
@staticmethod
@@ -267,7 +290,7 @@ class TestFFT(ParametricTestCase):
elif silx_normalization_mode == "ortho":
return np.fft.irfftn(data, axes=axes, norm="ortho")
elif silx_normalization_mode == "none":
- res = np.fft.irfftn(data, axes=axes, norm=None)
+ res = np.fft.irfftn(data, axes=axes, norm=None)
# This assumes a FFT on all the axes, won't work on batched FFT
N = res.size
return res * N
@@ -279,8 +302,8 @@ class TestFFT(ParametricTestCase):
return self._test_norms_with_backend("fftw")
@unittest.skipIf(
- parse_version(np.version.version) <= parse_version("1.19.5"),
- "normalization does not work for numpy <= 1.19.5"
+ Version(np.version.version) <= Version("1.19.5"),
+ "normalization does not work for numpy <= 1.19.5",
)
def test_norms_numpy(self):
return self._test_norms_with_backend("numpy")
@@ -288,10 +311,13 @@ class TestFFT(ParametricTestCase):
@unittest.skipIf(not __have_clfft__, "opencl back-end requires pyopencl and gpyfft")
def test_norms_opencl(self):
from silx.opencl.common import ocl
+
if ocl is not None:
return self._test_norms_with_backend("opencl")
- @unittest.skipIf(not __have_cufft__, "cuda back-end requires pycuda and scikit-cuda")
+ @unittest.skipIf(
+ not __have_cufft__, "cuda back-end requires pycuda and scikit-cuda"
+ )
def test_norms_cuda(self):
get_cuda_context()
return self._test_norms_with_backend("cuda")
@@ -306,12 +332,16 @@ class TestFFT(ParametricTestCase):
fft = FFT(template=data, backend=backend_name, normalize=norm)
res = fft.fft(data)
ref = self._compute_numpy_normalized_fft(data, fft.axes, norm)
- assert np.allclose(res, ref, atol=tol, rtol=tol), "Something wrong with %s norm=%s" % (backend_name, norm)
+ assert np.allclose(
+ res, ref, atol=tol, rtol=tol
+ ), "Something wrong with %s norm=%s" % (backend_name, norm)
res2 = fft.ifft(res)
ref2 = self._compute_numpy_normalized_ifft(ref, fft.axes, norm)
# unscaled IFFT yields very large values. Use a relatively high "atol"
- assert np.allclose(res2, ref2, atol=res2.max()/1e6), "Something wrong with I%s norm=%s" % (backend_name, norm)
+ assert np.allclose(
+ res2, ref2, atol=res2.max() / 1e6
+ ), "Something wrong with I%s norm=%s" % (backend_name, norm)
@unittest.skipUnless(__have_scipy, "scipy is missing")
@@ -356,13 +386,12 @@ class TestNumpyFFT(ParametricTestCase):
logger.debug("trdim: %s, mode: %s, size: %s", trdim, mode, str(size))
ndim = len(size)
input_data = self.test_data.data_refs[ndim].astype(
- self.transform_infos.modes[mode])
+ self.transform_infos.modes[mode]
+ )
np_fft, np_ifft = self.transforms[trdim][np.isrealobj(input_data)]
F = FFT(
- template=input_data,
- axes=self.transform_infos.axes[trdim],
- backend="numpy"
+ template=input_data, axes=self.transform_infos.axes[trdim], backend="numpy"
)
# Test FFT
res = F.fft(input_data)
@@ -375,21 +404,20 @@ class TestNumpyFFT(ParametricTestCase):
self.assertTrue(np.allclose(res2, ref2))
-@pytest.mark.skipif(not(__have_fftw__), reason="Need fftw/pyfftw for this test")
+@pytest.mark.skipif(not (__have_fftw__), reason="Need fftw/pyfftw for this test")
def test_fftw_wisdom():
"""
Test FFTW wisdom import/export mechanism
"""
- assert path.isdir(path.dirname(get_wisdom_file())) # Default: tempdir.gettempdir()
+ assert path.isdir(path.dirname(get_wisdom_file())) # Default: tempdir.gettempdir()
with TemporaryDirectory(prefix="test_fftw_wisdom") as dname:
subdir = path.join(dname, "subdir")
get_wisdom_file(directory=subdir, create_dirs=False)
- assert not(path.isdir(subdir))
+ assert not (path.isdir(subdir))
fname = get_wisdom_file(directory=subdir, create_dirs=True)
assert path.isdir(subdir)
export_wisdom(fname)
assert path.isfile(fname)
import_wisdom(fname)
-
diff --git a/src/silx/math/fit/__init__.py b/src/silx/math/fit/__init__.py
index 7dd6d32..da1b03d 100644
--- a/src/silx/math/fit/__init__.py
+++ b/src/silx/math/fit/__init__.py
@@ -27,9 +27,7 @@ __date__ = "22/06/2016"
from .leastsq import leastsq, chisq_alpha_beta
-from .leastsq import \
- CFREE, CPOSITIVE, CQUOTED, CFIXED, \
- CFACTOR, CDELTA, CSUM
+from .leastsq import CFREE, CPOSITIVE, CQUOTED, CFIXED, CFACTOR, CDELTA, CSUM
from .functions import *
from .filters import *
diff --git a/src/silx/math/fit/bgtheories.py b/src/silx/math/fit/bgtheories.py
index d0f4987..e698927 100644
--- a/src/silx/math/fit/bgtheories.py
+++ b/src/silx/math/fit/bgtheories.py
@@ -1,6 +1,6 @@
-#/*##########################################################################
+# /*##########################################################################
#
-# Copyright (c) 2004-2020 European Synchrotron Radiation Facility
+# Copyright (c) 2004-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -69,10 +69,8 @@ __authors__ = ["P. Knobel"]
__license__ = "MIT"
__date__ = "16/01/2017"
-from collections import OrderedDict
import numpy
-from silx.math.fit.filters import strip, snip1d,\
- savitsky_golay
+from silx.math.fit.filters import strip, snip1d, savitsky_golay
from silx.math.fit.fittheory import FitTheory
CONFIG = {
@@ -84,7 +82,7 @@ CONFIG = {
"StripIterations": 5000,
"StripThresholdFactor": 1.0,
"SnipWidth": 16,
- "EstimatePolyOnStrip": True
+ "EstimatePolyOnStrip": True,
}
# to avoid costly computations when parameters stay the same
@@ -115,9 +113,9 @@ def _convert_anchors_to_indices(x):
of indices is empty, return None.
"""
# convert anchor X abscissa to index
- if CONFIG['AnchorsFlag'] and CONFIG['AnchorsList'] is not None:
+ if CONFIG["AnchorsFlag"] and CONFIG["AnchorsList"] is not None:
anchors_indices = []
- for anchor_x in CONFIG['AnchorsList']:
+ for anchor_x in CONFIG["AnchorsList"]:
if anchor_x <= x[0]:
continue
# take the first index where x > anchor_x
@@ -152,12 +150,13 @@ def strip_bg(x, y0, width, niter):
global _BG_OLD_ANCHORS
global _BG_OLD_ANCHORS_FLAG
- parameters_changed =\
- _BG_STRIP_OLDPARS != [width, niter] or\
- _BG_SMOOTH_OLDWIDTH != CONFIG["SmoothingWidth"] or\
- _BG_SMOOTH_OLDFLAG != CONFIG["SmoothingFlag"] or\
- _BG_OLD_ANCHORS_FLAG != CONFIG["AnchorsFlag"] or\
- _BG_OLD_ANCHORS != CONFIG["AnchorsList"]
+ parameters_changed = (
+ _BG_STRIP_OLDPARS != [width, niter]
+ or _BG_SMOOTH_OLDWIDTH != CONFIG["SmoothingWidth"]
+ or _BG_SMOOTH_OLDFLAG != CONFIG["SmoothingFlag"]
+ or _BG_OLD_ANCHORS_FLAG != CONFIG["AnchorsFlag"]
+ or _BG_OLD_ANCHORS != CONFIG["AnchorsList"]
+ )
# same parameters
if not parameters_changed:
@@ -177,11 +176,13 @@ def strip_bg(x, y0, width, niter):
anchors_indices = _convert_anchors_to_indices(x)
- background = strip(y1,
- w=width,
- niterations=niter,
- factor=CONFIG["StripThresholdFactor"],
- anchors=anchors_indices)
+ background = strip(
+ y1,
+ w=width,
+ niterations=niter,
+ factor=CONFIG["StripThresholdFactor"],
+ anchors=anchors_indices,
+ )
_BG_STRIP_OLDBG = background
@@ -198,12 +199,13 @@ def snip_bg(x, y0, width):
global _BG_OLD_ANCHORS
global _BG_OLD_ANCHORS_FLAG
- parameters_changed =\
- _BG_SNIP_OLDWIDTH != width or\
- _BG_SMOOTH_OLDWIDTH != CONFIG["SmoothingWidth"] or\
- _BG_SMOOTH_OLDFLAG != CONFIG["SmoothingFlag"] or\
- _BG_OLD_ANCHORS_FLAG != CONFIG["AnchorsFlag"] or\
- _BG_OLD_ANCHORS != CONFIG["AnchorsList"]
+ parameters_changed = (
+ _BG_SNIP_OLDWIDTH != width
+ or _BG_SMOOTH_OLDWIDTH != CONFIG["SmoothingWidth"]
+ or _BG_SMOOTH_OLDFLAG != CONFIG["SmoothingFlag"]
+ or _BG_OLD_ANCHORS_FLAG != CONFIG["AnchorsFlag"]
+ or _BG_OLD_ANCHORS != CONFIG["AnchorsList"]
+ )
# same parameters
if not parameters_changed:
@@ -230,14 +232,13 @@ def snip_bg(x, y0, width):
previous_anchor = 0
for anchor_index in anchors_indices:
if (anchor_index > previous_anchor) and (anchor_index < len(y1)):
- background[previous_anchor:anchor_index] =\
- snip1d(y1[previous_anchor:anchor_index],
- width)
- previous_anchor = anchor_index
+ background[previous_anchor:anchor_index] = snip1d(
+ y1[previous_anchor:anchor_index], width
+ )
+ previous_anchor = anchor_index
if previous_anchor < len(y1):
- background[previous_anchor:] = snip1d(y1[previous_anchor:],
- width)
+ background[previous_anchor:] = snip1d(y1[previous_anchor:], width)
_BG_SNIP_OLDBG = background
@@ -250,9 +251,7 @@ def estimate_linear(x, y):
Strip peaks, then perform a linear regression.
"""
- bg = strip_bg(x, y,
- width=CONFIG["StripWidth"],
- niter=CONFIG["StripIterations"])
+ bg = strip_bg(x, y, width=CONFIG["StripWidth"], niter=CONFIG["StripIterations"])
n = float(len(bg))
Sy = numpy.sum(bg)
Sx = float(numpy.sum(x))
@@ -278,8 +277,7 @@ def estimate_strip(x, y):
Return parameters as defined in CONFIG dict,
set constraints to FIXED.
"""
- estimated_par = [CONFIG["StripWidth"],
- CONFIG["StripIterations"]]
+ estimated_par = [CONFIG["StripWidth"], CONFIG["StripIterations"]]
constraints = numpy.zeros((len(estimated_par), 3), numpy.float64)
# code = 3: FIXED
constraints[0][0] = 3
@@ -311,46 +309,37 @@ def poly(x, y, *pars):
def estimate_poly(x, y, deg=2):
- """Estimate polynomial coefficients.
-
- """
+ """Estimate polynomial coefficients."""
# extract bg signal with strip, to estimate polynomial on background
if CONFIG["EstimatePolyOnStrip"]:
- y = strip_bg(x, y,
- CONFIG["StripWidth"],
- CONFIG["StripIterations"])
+ y = strip_bg(x, y, CONFIG["StripWidth"], CONFIG["StripIterations"])
pcoeffs = numpy.polyfit(x, y, deg)
cons = numpy.zeros((deg + 1, 3), numpy.float64)
return pcoeffs, cons
def estimate_quadratic_poly(x, y):
- """Estimate quadratic polynomial coefficients.
- """
+ """Estimate quadratic polynomial coefficients."""
return estimate_poly(x, y, deg=2)
def estimate_cubic_poly(x, y):
- """Estimate cubic polynomial coefficients.
- """
+ """Estimate cubic polynomial coefficients."""
return estimate_poly(x, y, deg=3)
def estimate_quartic_poly(x, y):
- """Estimate degree 4 polynomial coefficients.
- """
+ """Estimate degree 4 polynomial coefficients."""
return estimate_poly(x, y, deg=4)
def estimate_quintic_poly(x, y):
- """Estimate degree 5 polynomial coefficients.
- """
+ """Estimate degree 5 polynomial coefficients."""
return estimate_poly(x, y, deg=5)
def configure(**kw):
- """Update the CONFIG dict
- """
+ """Update the CONFIG dict"""
# inspect **kw to find known keys, update them in CONFIG
for key in CONFIG:
if key in kw:
@@ -359,81 +348,112 @@ def configure(**kw):
return CONFIG
-THEORY = OrderedDict(
- (('No Background',
- FitTheory(
+THEORY = dict(
+ (
+ (
+ "No Background",
+ FitTheory(
description="No background function",
function=lambda x, y0: numpy.zeros_like(x),
parameters=[],
- is_background=True)),
- ('Constant',
- FitTheory(
- description='Constant background',
+ is_background=True,
+ ),
+ ),
+ (
+ "Constant",
+ FitTheory(
+ description="Constant background",
function=lambda x, y0, c: c * numpy.ones_like(x),
- parameters=['Constant', ],
+ parameters=[
+ "Constant",
+ ],
estimate=lambda x, y: ([min(y)], [[0, 0, 0]]),
- is_background=True)),
- ('Linear',
- FitTheory(
- description="Linear background, parameters 'Constant' and"
- " 'Slope'",
+ is_background=True,
+ ),
+ ),
+ (
+ "Linear",
+ FitTheory(
+ description="Linear background, parameters 'Constant' and" " 'Slope'",
function=lambda x, y0, a, b: a + b * x,
- parameters=['Constant', 'Slope'],
+ parameters=["Constant", "Slope"],
estimate=estimate_linear,
configure=configure,
- is_background=True)),
- ('Strip',
- FitTheory(
+ is_background=True,
+ ),
+ ),
+ (
+ "Strip",
+ FitTheory(
description="Compute background using a strip filter\n"
- "Parameters 'StripWidth', 'StripIterations'",
+ "Parameters 'StripWidth', 'StripIterations'",
function=strip_bg,
- parameters=['StripWidth', 'StripIterations'],
+ parameters=["StripWidth", "StripIterations"],
estimate=estimate_strip,
configure=configure,
- is_background=True)),
- ('Snip',
- FitTheory(
+ is_background=True,
+ ),
+ ),
+ (
+ "Snip",
+ FitTheory(
description="Compute background using a snip filter\n"
- "Parameter 'SnipWidth'",
+ "Parameter 'SnipWidth'",
function=snip_bg,
- parameters=['SnipWidth'],
+ parameters=["SnipWidth"],
estimate=estimate_snip,
configure=configure,
- is_background=True)),
- ('Degree 2 Polynomial',
- FitTheory(
+ is_background=True,
+ ),
+ ),
+ (
+ "Degree 2 Polynomial",
+ FitTheory(
description="Quadratic polynomial background, Parameters "
- "'a', 'b' and 'c'\ny = a*x^2 + b*x +c",
+ "'a', 'b' and 'c'\ny = a*x^2 + b*x +c",
function=poly,
- parameters=['a', 'b', 'c'],
+ parameters=["a", "b", "c"],
estimate=estimate_quadratic_poly,
configure=configure,
- is_background=True)),
- ('Degree 3 Polynomial',
- FitTheory(
+ is_background=True,
+ ),
+ ),
+ (
+ "Degree 3 Polynomial",
+ FitTheory(
description="Cubic polynomial background, Parameters "
- "'a', 'b', 'c' and 'd'\n"
- "y = a*x^3 + b*x^2 + c*x + d",
+ "'a', 'b', 'c' and 'd'\n"
+ "y = a*x^3 + b*x^2 + c*x + d",
function=poly,
- parameters=['a', 'b', 'c', 'd'],
+ parameters=["a", "b", "c", "d"],
estimate=estimate_cubic_poly,
configure=configure,
- is_background=True)),
- ('Degree 4 Polynomial',
- FitTheory(
+ is_background=True,
+ ),
+ ),
+ (
+ "Degree 4 Polynomial",
+ FitTheory(
description="Quartic polynomial background\n"
- "y = a*x^4 + b*x^3 + c*x^2 + d*x + e",
+ "y = a*x^4 + b*x^3 + c*x^2 + d*x + e",
function=poly,
- parameters=['a', 'b', 'c', 'd', 'e'],
+ parameters=["a", "b", "c", "d", "e"],
estimate=estimate_quartic_poly,
configure=configure,
- is_background=True)),
- ('Degree 5 Polynomial',
- FitTheory(
+ is_background=True,
+ ),
+ ),
+ (
+ "Degree 5 Polynomial",
+ FitTheory(
description="Quaintic polynomial background\n"
- "y = a*x^5 + b*x^4 + c*x^3 + d*x^2 + e*x + f",
+ "y = a*x^5 + b*x^4 + c*x^3 + d*x^2 + e*x + f",
function=poly,
- parameters=['a', 'b', 'c', 'd', 'e', 'f'],
+ parameters=["a", "b", "c", "d", "e", "f"],
estimate=estimate_quintic_poly,
configure=configure,
- is_background=True))))
+ is_background=True,
+ ),
+ ),
+ )
+)
diff --git a/src/silx/math/fit/fitmanager.py b/src/silx/math/fit/fitmanager.py
index cbb1e34..983cbf7 100644
--- a/src/silx/math/fit/fitmanager.py
+++ b/src/silx/math/fit/fitmanager.py
@@ -1,6 +1,6 @@
# /*#########################################################################
#
-# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+# Copyright (c) 2004-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -39,7 +39,6 @@ This module deals with:
- providing different background models
"""
-from collections import OrderedDict
import logging
import numpy
from numpy.linalg.linalg import LinAlgError
@@ -82,19 +81,19 @@ class FitManager(object):
uncertainties are set to 1.
:type weight_flag: boolean
"""
+
def __init__(self, x=None, y=None, sigmay=None, weight_flag=False):
- """
- """
+ """ """
self.fitconfig = {
- 'WeightFlag': weight_flag,
- 'fitbkg': 'No Background',
- 'fittheory': None,
+ "WeightFlag": weight_flag,
+ "fitbkg": "No Background",
+ "fittheory": None,
# Next few parameters are defined for compatibility with legacy theories
# which take the background as argument for their estimation function
- 'StripWidth': 2,
- 'StripIterations': 5000,
- 'StripThresholdFactor': 1.0,
- 'SmoothingFlag': False
+ "StripWidth": 2,
+ "StripIterations": 5000,
+ "StripThresholdFactor": 1.0,
+ "SmoothingFlag": False,
}
"""Dictionary of fit configuration parameters.
These parameters can be modified using the :meth:`configure` method.
@@ -109,7 +108,7 @@ class FitManager(object):
algorithm (:func:`silx.math.fit.peak_search`)
"""
- self.theories = OrderedDict()
+ self.theories = {}
"""Dictionary of fit theories, defining functions to be fitted
to individual peaks.
@@ -134,7 +133,7 @@ class FitManager(object):
"""Name of currently selected theory. This name matches a key in
:attr:`theories`."""
- self.bgtheories = OrderedDict()
+ self.bgtheories = {}
"""Dictionary of background theories.
See :attr:`theories` for documentation on theories.
@@ -143,7 +142,7 @@ class FitManager(object):
# Load default theories (constant, linear, strip)
self.loadbgtheories(bgtheories)
- self.selectedbg = 'No Background'
+ self.selectedbg = "No Background"
"""Name of currently selected background theory. This name must be
an existing key in :attr:`bgtheories`."""
@@ -220,10 +219,18 @@ class FitManager(object):
"""
self.bgtheories[bgname] = bgtheory
- def addtheory(self, name, theory=None,
- function=None, parameters=None,
- estimate=None, configure=None, derivative=None,
- description=None, pymca_legacy=False):
+ def addtheory(
+ self,
+ name,
+ theory=None,
+ function=None,
+ parameters=None,
+ estimate=None,
+ configure=None,
+ derivative=None,
+ description=None,
+ pymca_legacy=False,
+ ):
"""Add a new theory to dictionary :attr:`theories`.
You can pass a name and a :class:`FitTheory` object as arguments, or
@@ -267,17 +274,26 @@ class FitManager(object):
estimate=estimate,
configure=configure,
derivative=derivative,
- pymca_legacy=pymca_legacy
+ pymca_legacy=pymca_legacy,
)
else:
- raise TypeError("You must supply a FitTheory object or define " +
- "a fit function and its parameters.")
+ raise TypeError(
+ "You must supply a FitTheory object or define "
+ + "a fit function and its parameters."
+ )
- def addbgtheory(self, name, theory=None,
- function=None, parameters=None,
- estimate=None, configure=None,
- derivative=None, description=None):
+ def addbgtheory(
+ self,
+ name,
+ theory=None,
+ function=None,
+ parameters=None,
+ estimate=None,
+ configure=None,
+ derivative=None,
+ description=None,
+ ):
"""Add a new theory to dictionary :attr:`bgtheories`.
You can pass a name and a :class:`FitTheory` object as arguments, or
@@ -314,12 +330,14 @@ class FitManager(object):
estimate=estimate,
configure=configure,
derivative=derivative,
- is_background=True
+ is_background=True,
)
else:
- raise TypeError("You must supply a FitTheory object or define " +
- "a background function and its parameters.")
+ raise TypeError(
+ "You must supply a FitTheory object or define "
+ + "a background function and its parameters."
+ )
def configure(self, **kw):
"""Configure the current theory by filling or updating the
@@ -394,21 +412,22 @@ class FitManager(object):
update a widget displaying a status message.
:return: Estimated parameters
"""
- self.state = 'Estimate in progress'
+ self.state = "Estimate in progress"
self.chisq = None
if callback is not None:
- callback(data={'chisq': self.chisq,
- 'status': self.state})
-
- CONS = {0: 'FREE',
- 1: 'POSITIVE',
- 2: 'QUOTED',
- 3: 'FIXED',
- 4: 'FACTOR',
- 5: 'DELTA',
- 6: 'SUM',
- 7: 'IGNORE'}
+ callback(data={"chisq": self.chisq, "status": self.state})
+
+ CONS = {
+ 0: "FREE",
+ 1: "POSITIVE",
+ 2: "QUOTED",
+ 3: "FIXED",
+ 4: "FACTOR",
+ 5: "DELTA",
+ 6: "SUM",
+ 7: "IGNORE",
+ }
# Filter-out not finite data
xwork = self.xdata[self._finite_mask]
@@ -421,9 +440,9 @@ class FitManager(object):
try:
fun_params, fun_constraints = self.estimate_fun(xwork, ywork)
except LinAlgError:
- self.state = 'Estimate failed'
+ self.state = "Estimate failed"
if callback is not None:
- callback(data={'status': self.state})
+ callback(data={"status": self.state})
raise
# build the names
@@ -446,7 +465,7 @@ class FitManager(object):
xmin = min(xwork)
xmax = max(xwork)
nb_bg_params = len(bg_params)
- for (pindex, pname) in enumerate(self.parameter_names):
+ for pindex, pname in enumerate(self.parameter_names):
# First come background parameters
if pindex < nb_bg_params:
estimation_value = bg_params[pindex]
@@ -471,24 +490,27 @@ class FitManager(object):
cons1 += nb_bg_params
cons2 = fun_constraints[fun_param_index][2]
- self.fit_results.append({'name': pname,
- 'estimation': estimation_value,
- 'group': group_number,
- 'code': constraint_code,
- 'cons1': cons1,
- 'cons2': cons2,
- 'fitresult': 0.0,
- 'sigma': 0.0,
- 'xmin': xmin,
- 'xmax': xmax})
-
- self.state = 'Ready to Fit'
+ self.fit_results.append(
+ {
+ "name": pname,
+ "estimation": estimation_value,
+ "group": group_number,
+ "code": constraint_code,
+ "cons1": cons1,
+ "cons2": cons2,
+ "fitresult": 0.0,
+ "sigma": 0.0,
+ "xmin": xmin,
+ "xmax": xmax,
+ }
+ )
+
+ self.state = "Ready to Fit"
self.chisq = None
self.niter = 0
if callback is not None:
- callback(data={'chisq': self.chisq,
- 'status': self.state})
+ callback(data={"chisq": self.chisq, "status": self.state})
return numpy.append(bg_params, fun_params)
def fit(self):
@@ -522,11 +544,11 @@ class FitManager(object):
paramlist = self.fit_results
active_params = []
for param in paramlist:
- if param['code'] not in ['IGNORE', 7]:
+ if param["code"] not in ["IGNORE", 7]:
if not estimated:
- active_params.append(param['fitresult'])
+ active_params.append(param["fitresult"])
else:
- active_params.append(param['estimation'])
+ active_params.append(param["estimation"])
# Mask x with not finite (support nD x)
finite_mask = numpy.all(numpy.isfinite(x), axis=tuple(range(1, x.ndim)))
@@ -538,7 +560,8 @@ class FitManager(object):
# Create result with same number as elements as x, filling holes with NaNs
result = numpy.full((x.shape[0],), numpy.nan, dtype=numpy.float64)
result[finite_mask] = self.fitfunction(
- numpy.array(x[finite_mask], copy=True), *active_params)
+ numpy.array(x[finite_mask], copy=True), *active_params
+ )
return result
def get_estimation(self):
@@ -595,14 +618,16 @@ class FitManager(object):
:raise: ImportError if theories cannot be imported
"""
from types import ModuleType
+
if isinstance(theories, ModuleType):
theories_module = theories
else:
# if theories is not a module, it must be a string
- string_types = (basestring,) if sys.version_info[0] == 2 else (str,) # noqa
- if not isinstance(theories, string_types):
- raise ImportError("theory must be a python module, a module" +
- "name or a python filename")
+ if not isinstance(theories, str):
+ raise ImportError(
+ "theory must be a python module, a module"
+ + "name or a python filename"
+ )
# if theories is a filename
if os.path.isfile(theories):
sys.path.append(os.path.dirname(theories))
@@ -654,14 +679,16 @@ class FitManager(object):
:raise: ImportError if theories cannot be imported
"""
from types import ModuleType
+
if isinstance(theories, ModuleType):
theories_module = theories
else:
# if theories is not a module, it must be a string
- string_types = (basestring,) if sys.version_info[0] == 2 else (str,) # noqa
- if not isinstance(theories, string_types):
- raise ImportError("theory must be a python module, a module" +
- "name or a python filename")
+ if not isinstance(theories, str):
+ raise ImportError(
+ "theory must be a python module, a module"
+ + "name or a python filename"
+ )
# if theories is a filename
if os.path.isfile(theories):
sys.path.append(os.path.dirname(theories))
@@ -746,10 +773,14 @@ class FitManager(object):
# default weight
if sigmay is None:
self.sigmay0 = None
- self.sigmay = numpy.sqrt(self.ydata) if self.fitconfig["WeightFlag"] else None
+ self.sigmay = (
+ numpy.sqrt(self.ydata) if self.fitconfig["WeightFlag"] else None
+ )
else:
self.sigmay0 = numpy.array(sigmay)
- self.sigmay = numpy.array(sigmay) if self.fitconfig["WeightFlag"] else None
+ self.sigmay = (
+ numpy.array(sigmay) if self.fitconfig["WeightFlag"] else None
+ )
# take the data between limits, using boolean array indexing
if (xmin is not None or xmax is not None) and len(self.xdata):
@@ -761,8 +792,11 @@ class FitManager(object):
self.sigmay = self.sigmay[bool_array] if sigmay is not None else None
self._finite_mask = numpy.logical_and(
- numpy.all(numpy.isfinite(self.xdata), axis=tuple(range(1, self.xdata.ndim))),
- numpy.isfinite(self.ydata))
+ numpy.all(
+ numpy.isfinite(self.xdata), axis=tuple(range(1, self.xdata.ndim))
+ ),
+ numpy.isfinite(self.ydata),
+ )
def enableweight(self):
"""This method can be called to set :attr:`sigmay`. If :attr:`sigmay0` was filled with
@@ -770,7 +804,9 @@ class FitManager(object):
Else, use ``sqrt(self.ydata)``.
"""
if self.sigmay0 is None:
- self.sigmay = numpy.sqrt(self.ydata) if self.fitconfig["WeightFlag"] else None
+ self.sigmay = (
+ numpy.sqrt(self.ydata) if self.fitconfig["WeightFlag"] else None
+ )
else:
self.sigmay = self.sigmay0
@@ -822,19 +858,18 @@ class FitManager(object):
"""
# self.dataupdate()
- self.state = 'Fit in progress'
+ self.state = "Fit in progress"
self.chisq = None
if callback is not None:
- callback(data={'chisq': self.chisq,
- 'status': self.state})
+ callback(data={"chisq": self.chisq, "status": self.state})
param_val = []
param_constraints = []
# Initial values are set to the ones computed in estimate()
for param in self.fit_results:
- param_val.append(param['estimation'])
- param_constraints.append([param['code'], param['cons1'], param['cons2']])
+ param_val.append(param["estimation"])
+ param_constraints.append([param["code"], param["cons1"], param["cons2"]])
# Filter-out not finite data
ywork = self.ydata[self._finite_mask]
@@ -842,31 +877,34 @@ class FitManager(object):
try:
params, covariance_matrix, infodict = leastsq(
- self.fitfunction, # bg + actual model function
- xwork, ywork, param_val,
- sigma=self.sigmay,
- constraints=param_constraints,
- model_deriv=self.theories[self.selectedtheory].derivative,
- full_output=True, left_derivative=True)
+ self.fitfunction, # bg + actual model function
+ xwork,
+ ywork,
+ param_val,
+ sigma=self.sigmay,
+ constraints=param_constraints,
+ model_deriv=self.theories[self.selectedtheory].derivative,
+ full_output=True,
+ left_derivative=True,
+ )
except LinAlgError:
- self.state = 'Fit failed'
- callback(data={'status': self.state})
+ self.state = "Fit failed"
+ callback(data={"status": self.state})
raise
- sigmas = infodict['uncertainties']
+ sigmas = infodict["uncertainties"]
for i, param in enumerate(self.fit_results):
- if param['code'] != 'IGNORE':
- param['fitresult'] = params[i]
- param['sigma'] = sigmas[i]
+ if param["code"] != "IGNORE":
+ param["fitresult"] = params[i]
+ param["sigma"] = sigmas[i]
self.chisq = infodict["reduced_chisq"]
self.niter = infodict["niter"]
- self.state = 'Ready'
+ self.state = "Ready"
if callback is not None:
- callback(data={'chisq': self.chisq,
- 'status': self.state})
+ callback(data={"chisq": self.chisq, "status": self.state})
return params, sigmas, infodict
@@ -964,7 +1002,7 @@ class FitManager(object):
"""
estimatefunction = self.theories[self.selectedtheory].estimate
- if hasattr(estimatefunction, '__call__'):
+ if hasattr(estimatefunction, "__call__"):
if not self.theories[self.selectedtheory].pymca_legacy:
return estimatefunction(x, y)
else:
@@ -974,59 +1012,76 @@ class FitManager(object):
else:
if self.fitconfig["SmoothingFlag"]:
y = smooth1d(y)
- bg = strip(y,
- w=self.fitconfig["StripWidth"],
- niterations=self.fitconfig["StripIterations"],
- factor=self.fitconfig["StripThresholdFactor"])
+ bg = strip(
+ y,
+ w=self.fitconfig["StripWidth"],
+ niterations=self.fitconfig["StripIterations"],
+ factor=self.fitconfig["StripThresholdFactor"],
+ )
# fitconfig can be filled by user defined config function
- xscaling = self.fitconfig.get('Xscaling', 1.0)
- yscaling = self.fitconfig.get('Yscaling', 1.0)
+ xscaling = self.fitconfig.get("Xscaling", 1.0)
+ yscaling = self.fitconfig.get("Yscaling", 1.0)
return estimatefunction(x, y, bg, xscaling, yscaling)
else:
- raise TypeError("Estimation function in attribute " +
- "theories[%s]" % self.selectedtheory +
- " must be callable.")
+ raise TypeError(
+ "Estimation function in attribute "
+ + "theories[%s]" % self.selectedtheory
+ + " must be callable."
+ )
def _load_legacy_theories(self, theories_module):
"""Load theories from a custom module in the old PyMca format.
See PyMca5.PyMcaMath.fitting.SpecfitFunctions for an example.
"""
- mandatory_attributes = ["THEORY", "PARAMETERS",
- "FUNCTION", "ESTIMATE"]
+ mandatory_attributes = ["THEORY", "PARAMETERS", "FUNCTION", "ESTIMATE"]
err_msg = "Custom fit function file must define: "
err_msg += ", ".join(mandatory_attributes)
for attr in mandatory_attributes:
if not hasattr(theories_module, attr):
raise ImportError(err_msg)
- derivative = theories_module.DERIVATIVE if hasattr(theories_module, "DERIVATIVE") else None
- configure = theories_module.CONFIGURE if hasattr(theories_module, "CONFIGURE") else None
- estimate = theories_module.ESTIMATE if hasattr(theories_module, "ESTIMATE") else None
+ derivative = (
+ theories_module.DERIVATIVE
+ if hasattr(theories_module, "DERIVATIVE")
+ else None
+ )
+ configure = (
+ theories_module.CONFIGURE if hasattr(theories_module, "CONFIGURE") else None
+ )
+ estimate = (
+ theories_module.ESTIMATE if hasattr(theories_module, "ESTIMATE") else None
+ )
if isinstance(theories_module.THEORY, (list, tuple)):
# multiple fit functions
for i in range(len(theories_module.THEORY)):
deriv = derivative[i] if derivative is not None else None
config = configure[i] if configure is not None else None
estim = estimate[i] if estimate is not None else None
- self.addtheory(theories_module.THEORY[i],
- FitTheory(
- theories_module.FUNCTION[i],
- theories_module.PARAMETERS[i],
- estim,
- config,
- deriv,
- pymca_legacy=True))
+ self.addtheory(
+ theories_module.THEORY[i],
+ FitTheory(
+ theories_module.FUNCTION[i],
+ theories_module.PARAMETERS[i],
+ estim,
+ config,
+ deriv,
+ pymca_legacy=True,
+ ),
+ )
else:
# single fit function
- self.addtheory(theories_module.THEORY,
- FitTheory(
- theories_module.FUNCTION,
- theories_module.PARAMETERS,
- estimate,
- configure,
- derivative,
- pymca_legacy=True))
+ self.addtheory(
+ theories_module.THEORY,
+ FitTheory(
+ theories_module.FUNCTION,
+ theories_module.PARAMETERS,
+ estimate,
+ configure,
+ derivative,
+ pymca_legacy=True,
+ ),
+ )
def test():
@@ -1037,9 +1092,7 @@ def test():
# Create synthetic data with a sum of gaussian functions
x = numpy.arange(1000).astype(numpy.float64)
- p = [1000, 100., 250,
- 255, 690., 45,
- 1500, 800.5, 95]
+ p = [1000, 100.0, 250, 255, 690.0, 45, 1500, 800.5, 95]
y = 0.5 * x + 13 + sum_gauss(x, *p)
# Fitting
@@ -1048,9 +1101,9 @@ def test():
# overlapping peaks at x=690 and x=800.5
fit.setdata(x=x, y=y)
fit.loadtheories(fittheories)
- fit.settheory('Gaussians')
+ fit.settheory("Gaussians")
fit.loadbgtheories(bgtheories)
- fit.setbackground('Linear')
+ fit.setbackground("Linear")
fit.estimate()
fit.runfit()
@@ -1058,8 +1111,8 @@ def test():
print("Obtained parameters : ")
dummy_list = []
for param in fit.fit_results:
- print(param['name'], ' = ', param['fitresult'])
- dummy_list.append(param['fitresult'])
+ print(param["name"], " = ", param["fitresult"])
+ dummy_list.append(param["fitresult"])
print("chisq = ", fit.chisq)
# Plot
@@ -1071,6 +1124,7 @@ def test():
try:
from silx.gui import qt
from silx.gui.plot.PlotWindow import PlotWindow
+
app = qt.QApplication([])
pw = PlotWindow(control=True)
pw.addCurve(x, y, "Original")
diff --git a/src/silx/math/fit/fittheories.py b/src/silx/math/fit/fittheories.py
index 76f2478..f20cbf1 100644
--- a/src/silx/math/fit/fittheories.py
+++ b/src/silx/math/fit/fittheories.py
@@ -1,6 +1,6 @@
-#/*##########################################################################
+# /*##########################################################################
#
-# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+# Copyright (c) 2004-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -59,10 +59,7 @@ a dictionary :const:`THEORY`: with the following structure::
.. note::
- Consider using an OrderedDict instead of a regular dictionary, when
- defining your own theory dictionary, if the order matters to you.
- This will likely be the case if you intend to load a selection of
- functions in a GUI such as :class:`silx.gui.fit.FitManager`.
+ The order of the provided dictionary is taken into account.
Theory names can be customized (e.g. ``gauss, lorentz, splitgauss``…).
@@ -79,7 +76,6 @@ Module members:
---------------
"""
import numpy
-from collections import OrderedDict
import logging
from silx.math.fit import functions
@@ -96,58 +92,59 @@ __date__ = "15/05/2017"
DEFAULT_CONFIG = {
- 'NoConstraintsFlag': False,
- 'PositiveFwhmFlag': True,
- 'PositiveHeightAreaFlag': True,
- 'SameFwhmFlag': False,
- 'QuotedPositionFlag': False, # peak not outside data range
- 'QuotedEtaFlag': False, # force 0 < eta < 1
+ "NoConstraintsFlag": False,
+ "PositiveFwhmFlag": True,
+ "PositiveHeightAreaFlag": True,
+ "SameFwhmFlag": False,
+ "QuotedPositionFlag": False, # peak not outside data range
+ "QuotedEtaFlag": False, # force 0 < eta < 1
# Peak detection
- 'AutoScaling': False,
- 'Yscaling': 1.0,
- 'FwhmPoints': 8,
- 'AutoFwhm': True,
- 'Sensitivity': 2.5,
- 'ForcePeakPresence': True,
+ "AutoScaling": False,
+ "Yscaling": 1.0,
+ "FwhmPoints": 8,
+ "AutoFwhm": True,
+ "Sensitivity": 2.5,
+ "ForcePeakPresence": True,
# Hypermet
- 'HypermetTails': 15,
- 'QuotedFwhmFlag': 0,
- 'MaxFwhm2InputRatio': 1.5,
- 'MinFwhm2InputRatio': 0.4,
+ "HypermetTails": 15,
+ "QuotedFwhmFlag": 0,
+ "MaxFwhm2InputRatio": 1.5,
+ "MinFwhm2InputRatio": 0.4,
# short tail parameters
- 'MinGaussArea4ShortTail': 50000.,
- 'InitialShortTailAreaRatio': 0.050,
- 'MaxShortTailAreaRatio': 0.100,
- 'MinShortTailAreaRatio': 0.0010,
- 'InitialShortTailSlopeRatio': 0.70,
- 'MaxShortTailSlopeRatio': 2.00,
- 'MinShortTailSlopeRatio': 0.50,
+ "MinGaussArea4ShortTail": 50000.0,
+ "InitialShortTailAreaRatio": 0.050,
+ "MaxShortTailAreaRatio": 0.100,
+ "MinShortTailAreaRatio": 0.0010,
+ "InitialShortTailSlopeRatio": 0.70,
+ "MaxShortTailSlopeRatio": 2.00,
+ "MinShortTailSlopeRatio": 0.50,
# long tail parameters
- 'MinGaussArea4LongTail': 1000.0,
- 'InitialLongTailAreaRatio': 0.050,
- 'MaxLongTailAreaRatio': 0.300,
- 'MinLongTailAreaRatio': 0.010,
- 'InitialLongTailSlopeRatio': 20.0,
- 'MaxLongTailSlopeRatio': 50.0,
- 'MinLongTailSlopeRatio': 5.0,
+ "MinGaussArea4LongTail": 1000.0,
+ "InitialLongTailAreaRatio": 0.050,
+ "MaxLongTailAreaRatio": 0.300,
+ "MinLongTailAreaRatio": 0.010,
+ "InitialLongTailSlopeRatio": 20.0,
+ "MaxLongTailSlopeRatio": 50.0,
+ "MinLongTailSlopeRatio": 5.0,
# step tail
- 'MinGaussHeight4StepTail': 5000.,
- 'InitialStepTailHeightRatio': 0.002,
- 'MaxStepTailHeightRatio': 0.0100,
- 'MinStepTailHeightRatio': 0.0001,
+ "MinGaussHeight4StepTail": 5000.0,
+ "InitialStepTailHeightRatio": 0.002,
+ "MaxStepTailHeightRatio": 0.0100,
+ "MinStepTailHeightRatio": 0.0001,
# Hypermet constraints
# position in range [estimated position +- estimated fwhm/2]
- 'HypermetQuotedPositionFlag': True,
- 'DeltaPositionFwhmUnits': 0.5,
- 'SameSlopeRatioFlag': 1,
- 'SameAreaRatioFlag': 1,
+ "HypermetQuotedPositionFlag": True,
+ "DeltaPositionFwhmUnits": 0.5,
+ "SameSlopeRatioFlag": 1,
+ "SameAreaRatioFlag": 1,
# Strip bg removal
- 'StripBackgroundFlag': True,
- 'SmoothingFlag': True,
- 'SmoothingWidth': 5,
- 'StripWidth': 2,
- 'StripIterations': 5000,
- 'StripThresholdFactor': 1.0}
+ "StripBackgroundFlag": True,
+ "SmoothingFlag": True,
+ "SmoothingWidth": 5,
+ "StripWidth": 2,
+ "StripIterations": 5000,
+ "StripThresholdFactor": 1.0,
+}
"""This dictionary defines default configuration parameters that have effects
on fit functions and estimation functions, mainly on fit constraints.
This dictionary is accessible as attribute :attr:`FitTheories.config`,
@@ -168,6 +165,7 @@ CIGNORED = 7
class FitTheories(object):
"""Class wrapping functions from :class:`silx.math.fit.functions`
and providing estimate functions for all of these fit functions."""
+
def __init__(self, config=None):
if config is None:
self.config = DEFAULT_CONFIG
@@ -189,13 +187,18 @@ class FitTheories(object):
For example, 15 can be expressed as ``1111`` in base 2, so a flag of
15 means all terms are active.
"""
- g_term = self.config['HypermetTails'] & 1
- st_term = (self.config['HypermetTails'] >> 1) & 1
- lt_term = (self.config['HypermetTails'] >> 2) & 1
- step_term = (self.config['HypermetTails'] >> 3) & 1
- return functions.sum_ahypermet(x, *pars,
- gaussian_term=g_term, st_term=st_term,
- lt_term=lt_term, step_term=step_term)
+ g_term = self.config["HypermetTails"] & 1
+ st_term = (self.config["HypermetTails"] >> 1) & 1
+ lt_term = (self.config["HypermetTails"] >> 2) & 1
+ step_term = (self.config["HypermetTails"] >> 3) & 1
+ return functions.sum_ahypermet(
+ x,
+ *pars,
+ gaussian_term=g_term,
+ st_term=st_term,
+ lt_term=lt_term,
+ step_term=step_term,
+ )
def poly(self, x, *pars):
"""Order n polynomial.
@@ -208,51 +211,41 @@ class FitTheories(object):
@staticmethod
def estimate_poly(x, y, n=2):
- """Estimate polynomial coefficients for a degree n polynomial.
-
- """
+ """Estimate polynomial coefficients for a degree n polynomial."""
pcoeffs = numpy.polyfit(x, y, n)
constraints = numpy.zeros((n + 1, 3), numpy.float64)
return pcoeffs, constraints
def estimate_quadratic(self, x, y):
- """Estimate quadratic coefficients
-
- """
+ """Estimate quadratic coefficients"""
return self.estimate_poly(x, y, n=2)
def estimate_cubic(self, x, y):
- """Estimate coefficients for a degree 3 polynomial
-
- """
+ """Estimate coefficients for a degree 3 polynomial"""
return self.estimate_poly(x, y, n=3)
def estimate_quartic(self, x, y):
- """Estimate coefficients for a degree 4 polynomial
-
- """
+ """Estimate coefficients for a degree 4 polynomial"""
return self.estimate_poly(x, y, n=4)
def estimate_quintic(self, x, y):
- """Estimate coefficients for a degree 5 polynomial
-
- """
+ """Estimate coefficients for a degree 5 polynomial"""
return self.estimate_poly(x, y, n=5)
def strip_bg(self, y):
"""Return the strip background of y, using parameters from
:attr:`config` dictionary (*StripBackgroundFlag, StripWidth,
StripIterations, StripThresholdFactor*)"""
- remove_strip_bg = self.config.get('StripBackgroundFlag', False)
+ remove_strip_bg = self.config.get("StripBackgroundFlag", False)
if remove_strip_bg:
- if self.config['SmoothingFlag']:
- y = savitsky_golay(y, self.config['SmoothingWidth'])
- strip_width = self.config['StripWidth']
- strip_niterations = self.config['StripIterations']
- strip_thr_factor = self.config['StripThresholdFactor']
- return strip(y, w=strip_width,
- niterations=strip_niterations,
- factor=strip_thr_factor)
+ if self.config["SmoothingFlag"]:
+ y = savitsky_golay(y, self.config["SmoothingWidth"])
+ strip_width = self.config["StripWidth"]
+ strip_niterations = self.config["StripIterations"]
+ strip_thr_factor = self.config["StripThresholdFactor"]
+ return strip(
+ y, w=strip_width, niterations=strip_niterations, factor=strip_thr_factor
+ )
else:
return numpy.zeros_like(y)
@@ -268,7 +261,7 @@ class FitTheories(object):
yy = numpy.array(y, copy=False)
# smooth
- convolution_kernel = numpy.ones(shape=(3,)) / 3.
+ convolution_kernel = numpy.ones(shape=(3,)) / 3.0
ysmooth = numpy.convolve(y, convolution_kernel, mode="same")
# remove zeros
@@ -277,9 +270,9 @@ class FitTheories(object):
ysmooth = ysmooth[idx_array]
# compute scaling factor
- chisq = numpy.mean((yy - ysmooth)**2 / numpy.fabs(yy))
+ chisq = numpy.mean((yy - ysmooth) ** 2 / numpy.fabs(yy))
if chisq > 0:
- return 1. / chisq
+ return 1.0 / chisq
else:
return 1.0
@@ -299,16 +292,22 @@ class FitTheories(object):
# add padding
ysearch = numpy.ones((len(y) + 2 * fwhm,), numpy.float64)
ysearch[0:fwhm] = y[0]
- ysearch[-1:-fwhm - 1:-1] = y[len(y)-1]
- ysearch[fwhm:fwhm + len(y)] = y[:]
+ ysearch[-1 : -fwhm - 1 : -1] = y[len(y) - 1]
+ ysearch[fwhm : fwhm + len(y)] = y[:]
- scaling = self.guess_yscaling(y) if self.config["AutoScaling"] else self.config["Yscaling"]
+ scaling = (
+ self.guess_yscaling(y)
+ if self.config["AutoScaling"]
+ else self.config["Yscaling"]
+ )
if len(ysearch) > 1.5 * fwhm:
- peaks = peak_search(scaling * ysearch,
- fwhm=fwhm, sensitivity=sensitivity)
- return [peak_index - fwhm for peak_index in peaks
- if 0 <= peak_index - fwhm < len(y)]
+ peaks = peak_search(scaling * ysearch, fwhm=fwhm, sensitivity=sensitivity)
+ return [
+ peak_index - fwhm
+ for peak_index in peaks
+ if 0 <= peak_index - fwhm < len(y)
+ ]
else:
return []
@@ -331,32 +330,32 @@ class FitTheories(object):
bg = self.strip_bg(y)
- if self.config['AutoFwhm']:
+ if self.config["AutoFwhm"]:
search_fwhm = guess_fwhm(y)
else:
- search_fwhm = int(float(self.config['FwhmPoints']))
- search_sens = float(self.config['Sensitivity'])
+ search_fwhm = int(float(self.config["FwhmPoints"]))
+ search_sens = float(self.config["Sensitivity"])
if search_fwhm < 3:
_logger.warning("Setting peak fwhm to 3 (lower limit)")
search_fwhm = 3
- self.config['FwhmPoints'] = 3
+ self.config["FwhmPoints"] = 3
if search_sens < 1:
- _logger.warning("Setting peak search sensitivity to 1. " +
- "(lower limit to filter out noise peaks)")
+ _logger.warning(
+ "Setting peak search sensitivity to 1. "
+ + "(lower limit to filter out noise peaks)"
+ )
search_sens = 1
- self.config['Sensitivity'] = 1
+ self.config["Sensitivity"] = 1
npoints = len(y)
# Find indices of peaks in data array
- peaks = self.peak_search(y,
- fwhm=search_fwhm,
- sensitivity=search_sens)
+ peaks = self.peak_search(y, fwhm=search_fwhm, sensitivity=search_sens)
if not len(peaks):
- forcepeak = int(float(self.config.get('ForcePeakPresence', 0)))
+ forcepeak = int(float(self.config.get("ForcePeakPresence", 0)))
if forcepeak:
delta = y - bg
# get index of global maximum
@@ -371,13 +370,11 @@ class FitTheories(object):
peakpos = x[int(peaks[0])]
if abs(peakpos) < 1.0e-16:
peakpos = 0.0
- param = numpy.array(
- [y[int(peaks[0])] - bg[int(peaks[0])], peakpos, sig])
+ param = numpy.array([y[int(peaks[0])] - bg[int(peaks[0])], peakpos, sig])
height_largest_peak = param[0]
peak_index = 1
for i in peaks[1:]:
- param2 = numpy.array(
- [y[int(i)] - bg[int(i)], x[int(i)], sig])
+ param2 = numpy.array([y[int(i)] - bg[int(i)], x[int(i)], sig])
param = numpy.concatenate((param, param2))
if param2[0] > height_largest_peak:
height_largest_peak = param2[0]
@@ -391,60 +388,62 @@ class FitTheories(object):
cons = numpy.zeros((len(param), 3), numpy.float64)
# peak height must be positive
- cons[0:len(param):3, 0] = CPOSITIVE
+ cons[0 : len(param) : 3, 0] = CPOSITIVE
# force peaks to stay around their position
- cons[1:len(param):3, 0] = CQUOTED
+ cons[1 : len(param) : 3, 0] = CQUOTED
# set possible peak range to estimated peak +- guessed fwhm
if len(xw) > search_fwhm:
fwhmx = numpy.fabs(xw[int(search_fwhm)] - xw[0])
- cons[1:len(param):3, 1] = param[1:len(param):3] - 0.5 * fwhmx
- cons[1:len(param):3, 2] = param[1:len(param):3] + 0.5 * fwhmx
+ cons[1 : len(param) : 3, 1] = param[1 : len(param) : 3] - 0.5 * fwhmx
+ cons[1 : len(param) : 3, 2] = param[1 : len(param) : 3] + 0.5 * fwhmx
else:
- shape = [max(1, int(x)) for x in (param[1:len(param):3])]
- cons[1:len(param):3, 1] = min(xw) * numpy.ones(
- shape,
- numpy.float64)
- cons[1:len(param):3, 2] = max(xw) * numpy.ones(
- shape,
- numpy.float64)
+ shape = [max(1, int(x)) for x in (param[1 : len(param) : 3])]
+ cons[1 : len(param) : 3, 1] = min(xw) * numpy.ones(shape, numpy.float64)
+ cons[1 : len(param) : 3, 2] = max(xw) * numpy.ones(shape, numpy.float64)
# ensure fwhm is positive
- cons[2:len(param):3, 0] = CPOSITIVE
+ cons[2 : len(param) : 3, 0] = CPOSITIVE
# run a quick iterative fit (4 iterations) to improve
# estimations
- fittedpar, _, _ = leastsq(functions.sum_gauss, xw, yw, param,
- max_iter=4, constraints=cons.tolist(),
- full_output=True)
+ fittedpar, _, _ = leastsq(
+ functions.sum_gauss,
+ xw,
+ yw,
+ param,
+ max_iter=4,
+ constraints=cons.tolist(),
+ full_output=True,
+ )
# set final constraints based on config parameters
cons = numpy.zeros((len(fittedpar), 3), numpy.float64)
peak_index = 0
for i in range(len(peaks)):
# Setup height area constrains
- if not self.config['NoConstraintsFlag']:
- if self.config['PositiveHeightAreaFlag']:
+ if not self.config["NoConstraintsFlag"]:
+ if self.config["PositiveHeightAreaFlag"]:
cons[peak_index, 0] = CPOSITIVE
cons[peak_index, 1] = 0
cons[peak_index, 2] = 0
peak_index += 1
# Setup position constrains
- if not self.config['NoConstraintsFlag']:
- if self.config['QuotedPositionFlag']:
+ if not self.config["NoConstraintsFlag"]:
+ if self.config["QuotedPositionFlag"]:
cons[peak_index, 0] = CQUOTED
cons[peak_index, 1] = min(x)
cons[peak_index, 2] = max(x)
peak_index += 1
# Setup positive FWHM constrains
- if not self.config['NoConstraintsFlag']:
- if self.config['PositiveFwhmFlag']:
+ if not self.config["NoConstraintsFlag"]:
+ if self.config["PositiveFwhmFlag"]:
cons[peak_index, 0] = CPOSITIVE
cons[peak_index, 1] = 0
cons[peak_index, 2] = 0
- if self.config['SameFwhmFlag']:
+ if self.config["SameFwhmFlag"]:
if i != index_largest_peak:
cons[peak_index, 0] = CFACTOR
cons[peak_index, 1] = 3 * index_largest_peak + 2
@@ -475,8 +474,12 @@ class FitTheories(object):
height = fittedpar[3 * i]
fwhm = fittedpar[3 * i + 2]
# Replace height with area in fittedpar
- fittedpar[3 * i] = numpy.sqrt(2 * numpy.pi) * height * fwhm / (
- 2.0 * numpy.sqrt(2 * numpy.log(2)))
+ fittedpar[3 * i] = (
+ numpy.sqrt(2 * numpy.pi)
+ * height
+ * fwhm
+ / (2.0 * numpy.sqrt(2 * numpy.log(2)))
+ )
return fittedpar, cons
def estimate_alorentz(self, x, y):
@@ -501,7 +504,7 @@ class FitTheories(object):
height = fittedpar[3 * i]
fwhm = fittedpar[3 * i + 2]
# Replace height with area in fittedpar
- fittedpar[3 * i] = (height * fwhm * 0.5 * numpy.pi)
+ fittedpar[3 * i] = height * fwhm * 0.5 * numpy.pi
return fittedpar, cons
def estimate_splitgauss(self, x, y):
@@ -548,10 +551,12 @@ class FitTheories(object):
if cons[3 * i + 2, 0] == CFACTOR:
# convert indices of related parameters
# (this happens if SameFwhmFlag == True)
- estimated_constraints[4 * i + 2, 1] = \
+ estimated_constraints[4 * i + 2, 1] = (
int(cons[3 * i + 2, 1] / 3) * 4 + 2
- estimated_constraints[4 * i + 3, 1] = \
+ )
+ estimated_constraints[4 * i + 3, 1] = (
int(cons[3 * i + 2, 1] / 3) * 4 + 3
+ )
return estimated_parameters, estimated_constraints
def estimate_pvoigt(self, x, y):
@@ -580,8 +585,8 @@ class FitTheories(object):
newpar = []
newcons = numpy.zeros((4 * npeaks, 3), numpy.float64)
# find out related parameters proper index
- if not self.config['NoConstraintsFlag']:
- if self.config['SameFwhmFlag']:
+ if not self.config["NoConstraintsFlag"]:
+ if self.config["SameFwhmFlag"]:
j = 0
# get the index of the free FWHM
for i in range(npeaks):
@@ -611,7 +616,7 @@ class FitTheories(object):
newcons[4 * i + 3, 0] = CFREE
newcons[4 * i + 3, 1] = 0
newcons[4 * i + 3, 2] = 0
- if self.config['QuotedEtaFlag']:
+ if self.config["QuotedEtaFlag"]:
newcons[4 * i + 3, 0] = CQUOTED
newcons[4 * i + 3, 1] = 0.0
newcons[4 * i + 3, 2] = 1.0
@@ -641,8 +646,8 @@ class FitTheories(object):
newpar = []
newcons = numpy.zeros((5 * npeaks, 3), numpy.float64)
# find out related parameters proper index
- if not self.config['NoConstraintsFlag']:
- if self.config['SameFwhmFlag']:
+ if not self.config["NoConstraintsFlag"]:
+ if self.config["SameFwhmFlag"]:
j = 0
# get the index of the free FWHM
for i in range(npeaks):
@@ -692,12 +697,103 @@ class FitTheories(object):
newcons[5 * i + 4, 0] = CFREE
newcons[5 * i + 4, 1] = 0
newcons[5 * i + 4, 2] = 0
- if self.config['QuotedEtaFlag']:
+ if self.config["QuotedEtaFlag"]:
newcons[5 * i + 4, 0] = CQUOTED
newcons[5 * i + 4, 1] = 0.0
newcons[5 * i + 4, 2] = 1.0
return newpar, newcons
+ def estimate_splitpvoigt2(self, x, y):
+ """Estimation of *Height, Position, FWHM1, FWHM2, eta1, eta2* of peaks, for
+ asymmetric pseudo-Voigt curves.
+
+ This functions uses :meth:`estimate_height_position_fwhm`, then
+ adds an identical FWHM2 parameter and a constant estimation of
+ *eta1* and *eta2* (0.5) to the fit parameters for each peak, and the corresponding
+ constraints.
+
+ Constraint for the eta parameter can be set to QUOTED (0.--1.)
+ by setting :attr:`config`['QuotedEtaFlag'] to ``True``.
+ If this is not the case, the constraint code is set to FREE.
+
+ :param x: Array of abscissa values
+ :param y: Array of ordinate values (``y = f(x)``)
+ :return: Tuple of estimated fit parameters and fit constraints.
+ Parameters to be estimated for each peak are:
+ *Height, Position, FWHM1, FWHM2, eta1, eta2*.
+ """
+ fittedpar, cons = self.estimate_height_position_fwhm(x, y)
+ npeaks = len(fittedpar) // 3
+ newpar = []
+ newcons = numpy.zeros((5 * npeaks, 3), numpy.float64)
+ # find out related parameters proper index
+ if not self.config["NoConstraintsFlag"]:
+ if self.config["SameFwhmFlag"]:
+ j = 0
+ # get the index of the free FWHM
+ for i in range(npeaks):
+ if cons[3 * i + 2, 0] != 4:
+ j = i
+ for i in range(npeaks):
+ if i != j:
+ cons[3 * i + 2, 1] = 4 * j + 2
+ for i in range(npeaks):
+ # height
+ newpar.append(fittedpar[3 * i])
+ # position
+ newpar.append(fittedpar[3 * i + 1])
+ # fwhm1
+ newpar.append(fittedpar[3 * i + 2])
+ # fwhm2 estimate equal to fwhm1
+ newpar.append(fittedpar[3 * i + 2])
+ # eta1
+ newpar.append(0.5)
+ # eta2
+ newpar.append(0.5)
+ # constraint codes
+ # ----------------
+ # height
+ newcons[6 * i, 0] = cons[3 * i, 0]
+ # position
+ newcons[6 * i + 1, 0] = cons[3 * i + 1, 0]
+ # fwhm1
+ newcons[6 * i + 2, 0] = cons[3 * i + 2, 0]
+ # fwhm2
+ newcons[6 * i + 3, 0] = cons[3 * i + 2, 0]
+ # cons 1
+ # ------
+ newcons[6 * i, 1] = cons[3 * i, 1]
+ newcons[6 * i + 1, 1] = cons[3 * i + 1, 1]
+ newcons[6 * i + 2, 1] = cons[3 * i + 2, 1]
+ newcons[6 * i + 3, 1] = cons[3 * i + 2, 1]
+ # cons 2
+ # ------
+ newcons[6 * i, 2] = cons[3 * i, 2]
+ newcons[6 * i + 1, 2] = cons[3 * i + 1, 2]
+ newcons[6 * i + 2, 2] = cons[3 * i + 2, 2]
+ newcons[6 * i + 3, 2] = cons[3 * i + 2, 2]
+
+ if cons[3 * i + 2, 0] == CFACTOR:
+ # fwhm2 constraint depends on fwhm1
+ newcons[6 * i + 3, 1] = newcons[6 * i + 2, 1] + 1
+ # eta1 constraints
+ newcons[6 * i + 4, 0] = CFREE
+ newcons[6 * i + 4, 1] = 0
+ newcons[6 * i + 4, 2] = 0
+ if self.config["QuotedEtaFlag"]:
+ newcons[6 * i + 4, 0] = CQUOTED
+ newcons[6 * i + 4, 1] = 0.0
+ newcons[6 * i + 4, 2] = 1.0
+ # eta2 constraints
+ newcons[6 * i + 5, 0] = CFREE
+ newcons[6 * i + 5, 1] = 0
+ newcons[6 * i + 5, 2] = 0
+ if self.config["QuotedEtaFlag"]:
+ newcons[6 * i + 5, 0] = CQUOTED
+ newcons[6 * i + 5, 1] = 0.0
+ newcons[6 * i + 5, 2] = 1.0
+ return newpar, newcons
+
def estimate_apvoigt(self, x, y):
"""Estimation of *Area, Position, FWHM1, eta* of peaks, for
pseudo-Voigt curves.
@@ -718,9 +814,9 @@ class FitTheories(object):
for i in range(npeaks):
height = fittedpar[4 * i]
fwhm = fittedpar[4 * i + 2]
- fittedpar[4 * i] = 0.5 * (height * fwhm * 0.5 * numpy.pi) +\
- 0.5 * (height * fwhm / (2.0 * numpy.sqrt(2 * numpy.log(2)))
- ) * numpy.sqrt(2 * numpy.pi)
+ fittedpar[4 * i] = 0.5 * (height * fwhm * 0.5 * numpy.pi) + 0.5 * (
+ height * fwhm / (2.0 * numpy.sqrt(2 * numpy.log(2)))
+ ) * numpy.sqrt(2 * numpy.pi)
return fittedpar, cons
def estimate_ahypermet(self, x, y):
@@ -734,7 +830,7 @@ class FitTheories(object):
*area, position, fwhm, st_area_r, st_slope_r,
lt_area_r, lt_slope_r, step_height_r* .
"""
- yscaling = self.config.get('Yscaling', 1.0)
+ yscaling = self.config.get("Yscaling", 1.0)
if yscaling == 0:
yscaling = 1.0
fittedpar, cons = self.estimate_height_position_fwhm(x, y)
@@ -743,8 +839,8 @@ class FitTheories(object):
newcons = numpy.zeros((8 * npeaks, 3), numpy.float64)
main_peak = 0
# find out related parameters proper index
- if not self.config['NoConstraintsFlag']:
- if self.config['SameFwhmFlag']:
+ if not self.config["NoConstraintsFlag"]:
+ if self.config["SameFwhmFlag"]:
j = 0
# get the index of the free FWHM
for i in range(npeaks):
@@ -762,8 +858,9 @@ class FitTheories(object):
height = fittedpar[3 * i]
position = fittedpar[3 * i + 1]
fwhm = fittedpar[3 * i + 2]
- area = (height * fwhm / (2.0 * numpy.sqrt(2 * numpy.log(2)))
- ) * numpy.sqrt(2 * numpy.pi)
+ area = (height * fwhm / (2.0 * numpy.sqrt(2 * numpy.log(2)))) * numpy.sqrt(
+ 2 * numpy.pi
+ )
# the gaussian parameters
newpar.append(area)
newpar.append(position)
@@ -774,20 +871,20 @@ class FitTheories(object):
st_term = 1
lt_term = 1
step_term = 1
- if self.config['HypermetTails'] != 0:
- g_term = self.config['HypermetTails'] & 1
- st_term = (self.config['HypermetTails'] >> 1) & 1
- lt_term = (self.config['HypermetTails'] >> 2) & 1
- step_term = (self.config['HypermetTails'] >> 3) & 1
+ if self.config["HypermetTails"] != 0:
+ g_term = self.config["HypermetTails"] & 1
+ st_term = (self.config["HypermetTails"] >> 1) & 1
+ lt_term = (self.config["HypermetTails"] >> 2) & 1
+ step_term = (self.config["HypermetTails"] >> 3) & 1
if g_term == 0:
# fix the gaussian parameters
newcons[8 * i, 0] = CFIXED
newcons[8 * i + 1, 0] = CFIXED
newcons[8 * i + 2, 0] = CFIXED
# the short tail parameters
- if ((area * yscaling) <
- self.config['MinGaussArea4ShortTail']) | \
- (st_term == 0):
+ if ((area * yscaling) < self.config["MinGaussArea4ShortTail"]) | (
+ st_term == 0
+ ):
newpar.append(0.0)
newpar.append(0.0)
newcons[8 * i + 3, 0] = CFIXED
@@ -797,18 +894,18 @@ class FitTheories(object):
newcons[8 * i + 4, 1] = 0.0
newcons[8 * i + 4, 2] = 0.0
else:
- newpar.append(self.config['InitialShortTailAreaRatio'])
- newpar.append(self.config['InitialShortTailSlopeRatio'])
+ newpar.append(self.config["InitialShortTailAreaRatio"])
+ newpar.append(self.config["InitialShortTailSlopeRatio"])
newcons[8 * i + 3, 0] = CQUOTED
- newcons[8 * i + 3, 1] = self.config['MinShortTailAreaRatio']
- newcons[8 * i + 3, 2] = self.config['MaxShortTailAreaRatio']
+ newcons[8 * i + 3, 1] = self.config["MinShortTailAreaRatio"]
+ newcons[8 * i + 3, 2] = self.config["MaxShortTailAreaRatio"]
newcons[8 * i + 4, 0] = CQUOTED
- newcons[8 * i + 4, 1] = self.config['MinShortTailSlopeRatio']
- newcons[8 * i + 4, 2] = self.config['MaxShortTailSlopeRatio']
+ newcons[8 * i + 4, 1] = self.config["MinShortTailSlopeRatio"]
+ newcons[8 * i + 4, 2] = self.config["MaxShortTailSlopeRatio"]
# the long tail parameters
- if ((area * yscaling) <
- self.config['MinGaussArea4LongTail']) | \
- (lt_term == 0):
+ if ((area * yscaling) < self.config["MinGaussArea4LongTail"]) | (
+ lt_term == 0
+ ):
newpar.append(0.0)
newpar.append(0.0)
newcons[8 * i + 5, 0] = CFIXED
@@ -818,50 +915,50 @@ class FitTheories(object):
newcons[8 * i + 6, 1] = 0.0
newcons[8 * i + 6, 2] = 0.0
else:
- newpar.append(self.config['InitialLongTailAreaRatio'])
- newpar.append(self.config['InitialLongTailSlopeRatio'])
+ newpar.append(self.config["InitialLongTailAreaRatio"])
+ newpar.append(self.config["InitialLongTailSlopeRatio"])
newcons[8 * i + 5, 0] = CQUOTED
- newcons[8 * i + 5, 1] = self.config['MinLongTailAreaRatio']
- newcons[8 * i + 5, 2] = self.config['MaxLongTailAreaRatio']
+ newcons[8 * i + 5, 1] = self.config["MinLongTailAreaRatio"]
+ newcons[8 * i + 5, 2] = self.config["MaxLongTailAreaRatio"]
newcons[8 * i + 6, 0] = CQUOTED
- newcons[8 * i + 6, 1] = self.config['MinLongTailSlopeRatio']
- newcons[8 * i + 6, 2] = self.config['MaxLongTailSlopeRatio']
+ newcons[8 * i + 6, 1] = self.config["MinLongTailSlopeRatio"]
+ newcons[8 * i + 6, 2] = self.config["MaxLongTailSlopeRatio"]
# the step parameters
- if ((height * yscaling) <
- self.config['MinGaussHeight4StepTail']) | \
- (step_term == 0):
+ if ((height * yscaling) < self.config["MinGaussHeight4StepTail"]) | (
+ step_term == 0
+ ):
newpar.append(0.0)
newcons[8 * i + 7, 0] = CFIXED
newcons[8 * i + 7, 1] = 0.0
newcons[8 * i + 7, 2] = 0.0
else:
- newpar.append(self.config['InitialStepTailHeightRatio'])
+ newpar.append(self.config["InitialStepTailHeightRatio"])
newcons[8 * i + 7, 0] = CQUOTED
- newcons[8 * i + 7, 1] = self.config['MinStepTailHeightRatio']
- newcons[8 * i + 7, 2] = self.config['MaxStepTailHeightRatio']
+ newcons[8 * i + 7, 1] = self.config["MinStepTailHeightRatio"]
+ newcons[8 * i + 7, 2] = self.config["MaxStepTailHeightRatio"]
# if self.config['NoConstraintsFlag'] == 1:
# newcons=numpy.zeros((8*npeaks, 3),numpy.float64)
if npeaks > 0:
if g_term:
- if self.config['PositiveHeightAreaFlag']:
+ if self.config["PositiveHeightAreaFlag"]:
for i in range(npeaks):
newcons[8 * i, 0] = CPOSITIVE
- if self.config['PositiveFwhmFlag']:
+ if self.config["PositiveFwhmFlag"]:
for i in range(npeaks):
newcons[8 * i + 2, 0] = CPOSITIVE
- if self.config['SameFwhmFlag']:
+ if self.config["SameFwhmFlag"]:
for i in range(npeaks):
if i != main_peak:
newcons[8 * i + 2, 0] = CFACTOR
newcons[8 * i + 2, 1] = 8 * main_peak + 2
newcons[8 * i + 2, 2] = 1.0
- if self.config['HypermetQuotedPositionFlag']:
+ if self.config["HypermetQuotedPositionFlag"]:
for i in range(npeaks):
- delta = self.config['DeltaPositionFwhmUnits'] * fwhm
+ delta = self.config["DeltaPositionFwhmUnits"] * fwhm
newcons[8 * i + 1, 0] = CQUOTED
newcons[8 * i + 1, 1] = newpar[8 * i + 1] - delta
newcons[8 * i + 1, 2] = newpar[8 * i + 1] + delta
- if self.config['SameSlopeRatioFlag']:
+ if self.config["SameSlopeRatioFlag"]:
for i in range(npeaks):
if i != main_peak:
newcons[8 * i + 4, 0] = CFACTOR
@@ -870,7 +967,7 @@ class FitTheories(object):
newcons[8 * i + 6, 0] = CFACTOR
newcons[8 * i + 6, 1] = 8 * main_peak + 6
newcons[8 * i + 6, 2] = 1.0
- if self.config['SameAreaRatioFlag']:
+ if self.config["SameAreaRatioFlag"]:
for i in range(npeaks):
if i != main_peak:
newcons[8 * i + 3, 0] = CFACTOR
@@ -897,16 +994,15 @@ class FitTheories(object):
"""
crappyfilter = [-0.25, -0.75, 0.0, 0.75, 0.25]
cutoff = len(crappyfilter) // 2
- y_deriv = numpy.convolve(y,
- crappyfilter,
- mode="valid")
+ y_deriv = numpy.convolve(y, crappyfilter, mode="valid")
# make the derivative's peak have the same amplitude as the step
if max(y_deriv) > 0:
y_deriv = y_deriv * max(y) / max(y_deriv)
fittedpar, newcons = self.estimate_height_position_fwhm(
- x[cutoff:-cutoff], y_deriv)
+ x[cutoff:-cutoff], y_deriv
+ )
data_amplitude = max(y) - min(y)
@@ -914,38 +1010,44 @@ class FitTheories(object):
if len(fittedpar):
npeaks = len(fittedpar) // 3
largest_index = 0
- largest = [data_amplitude,
- fittedpar[3 * largest_index + 1],
- fittedpar[3 * largest_index + 2]]
+ largest = [
+ data_amplitude,
+ fittedpar[3 * largest_index + 1],
+ fittedpar[3 * largest_index + 2],
+ ]
for i in range(npeaks):
if fittedpar[3 * i] > largest[0]:
largest_index = i
- largest = [data_amplitude,
- fittedpar[3 * largest_index + 1],
- fittedpar[3 * largest_index + 2]]
+ largest = [
+ data_amplitude,
+ fittedpar[3 * largest_index + 1],
+ fittedpar[3 * largest_index + 2],
+ ]
else:
# no peak was found
- largest = [data_amplitude, # height
- x[len(x)//2], # center: middle of x range
- self.config["FwhmPoints"] * (x[1] - x[0])] # fwhm: default value
+ largest = [
+ data_amplitude, # height
+ x[len(x) // 2], # center: middle of x range
+ self.config["FwhmPoints"] * (x[1] - x[0]),
+ ] # fwhm: default value
# Setup constrains
newcons = numpy.zeros((3, 3), numpy.float64)
- if not self.config['NoConstraintsFlag']:
- # Setup height constrains
- if self.config['PositiveHeightAreaFlag']:
+ if not self.config["NoConstraintsFlag"]:
+ # Setup height constrains
+ if self.config["PositiveHeightAreaFlag"]:
newcons[0, 0] = CPOSITIVE
newcons[0, 1] = 0
newcons[0, 2] = 0
# Setup position constrains
- if self.config['QuotedPositionFlag']:
+ if self.config["QuotedPositionFlag"]:
newcons[1, 0] = CQUOTED
newcons[1, 1] = min(x)
newcons[1, 2] = max(x)
# Setup positive FWHM constrains
- if self.config['PositiveFwhmFlag']:
+ if self.config["PositiveFwhmFlag"]:
newcons[2, 0] = CPOSITIVE
newcons[2, 1] = 0
newcons[2, 2] = 0
@@ -984,27 +1086,27 @@ class FitTheories(object):
largest = [height, position, fwhm, beamfwhm]
cons = numpy.zeros((4, 3), numpy.float64)
# Setup constrains
- if not self.config['NoConstraintsFlag']:
+ if not self.config["NoConstraintsFlag"]:
# Setup height constrains
- if self.config['PositiveHeightAreaFlag']:
+ if self.config["PositiveHeightAreaFlag"]:
cons[0, 0] = CPOSITIVE
cons[0, 1] = 0
cons[0, 2] = 0
# Setup position constrains
- if self.config['QuotedPositionFlag']:
+ if self.config["QuotedPositionFlag"]:
cons[1, 0] = CQUOTED
cons[1, 1] = min(x)
cons[1, 2] = max(x)
# Setup positive FWHM constrains
- if self.config['PositiveFwhmFlag']:
+ if self.config["PositiveFwhmFlag"]:
cons[2, 0] = CPOSITIVE
cons[2, 1] = 0
cons[2, 2] = 0
# Setup positive FWHM constrains
- if self.config['PositiveFwhmFlag']:
+ if self.config["PositiveFwhmFlag"]:
cons[3, 0] = CPOSITIVE
cons[3, 1] = 0
cons[3, 2] = 0
@@ -1030,8 +1132,7 @@ class FitTheories(object):
if max(y_deriv) > 0:
y_deriv = y_deriv * max(y) / max(y_deriv)
- fittedpar, cons = self.estimate_height_position_fwhm(
- x[cutoff:-cutoff], y_deriv)
+ fittedpar, cons = self.estimate_height_position_fwhm(x[cutoff:-cutoff], y_deriv)
# for height, use the data amplitude after removing the background
data_amplitude = max(y) - min(y)
@@ -1040,38 +1141,44 @@ class FitTheories(object):
if len(fittedpar):
npeaks = len(fittedpar) // 3
largest_index = 0
- largest = [data_amplitude,
- fittedpar[3 * largest_index + 1],
- fittedpar[3 * largest_index + 2]]
+ largest = [
+ data_amplitude,
+ fittedpar[3 * largest_index + 1],
+ fittedpar[3 * largest_index + 2],
+ ]
for i in range(npeaks):
if fittedpar[3 * i] > largest[0]:
largest_index = i
- largest = [fittedpar[3 * largest_index],
- fittedpar[3 * largest_index + 1],
- fittedpar[3 * largest_index + 2]]
+ largest = [
+ fittedpar[3 * largest_index],
+ fittedpar[3 * largest_index + 1],
+ fittedpar[3 * largest_index + 2],
+ ]
else:
# no peak was found
- largest = [data_amplitude, # height
- x[len(x)//2], # center: middle of x range
- self.config["FwhmPoints"] * (x[1] - x[0])] # fwhm: default value
+ largest = [
+ data_amplitude, # height
+ x[len(x) // 2], # center: middle of x range
+ self.config["FwhmPoints"] * (x[1] - x[0]),
+ ] # fwhm: default value
newcons = numpy.zeros((3, 3), numpy.float64)
# Setup constrains
- if not self.config['NoConstraintsFlag']:
- # Setup height constraints
- if self.config['PositiveHeightAreaFlag']:
+ if not self.config["NoConstraintsFlag"]:
+ # Setup height constraints
+ if self.config["PositiveHeightAreaFlag"]:
newcons[0, 0] = CPOSITIVE
newcons[0, 1] = 0
newcons[0, 2] = 0
# Setup position constraints
- if self.config['QuotedPositionFlag']:
+ if self.config["QuotedPositionFlag"]:
newcons[1, 0] = CQUOTED
newcons[1, 1] = min(x)
newcons[1, 2] = max(x)
# Setup positive FWHM constraints
- if self.config['PositiveFwhmFlag']:
+ if self.config["PositiveFwhmFlag"]:
newcons[2, 0] = CPOSITIVE
newcons[2, 1] = 0
newcons[2, 2] = 0
@@ -1096,17 +1203,17 @@ class FitTheories(object):
:param y: Array of ordinate values (``y = f(x)``)
:return: Tuple of estimated fit parameters and fit constraints.
"""
- yscaling = self.config.get('Yscaling', 1.0)
+ yscaling = self.config.get("Yscaling", 1.0)
if yscaling == 0:
yscaling = 1.0
bg = self.strip_bg(y)
- if self.config['AutoFwhm']:
+ if self.config["AutoFwhm"]:
search_fwhm = guess_fwhm(y)
else:
- search_fwhm = int(float(self.config['FwhmPoints']))
- search_sens = float(self.config['Sensitivity'])
+ search_fwhm = int(float(self.config["FwhmPoints"]))
+ search_sens = float(self.config["Sensitivity"])
if search_fwhm < 3:
search_fwhm = 3
@@ -1115,8 +1222,7 @@ class FitTheories(object):
search_sens = 1
if len(y) > 1.5 * search_fwhm:
- peaks = peak_search(yscaling * y, fwhm=search_fwhm,
- sensitivity=search_sens)
+ peaks = peak_search(yscaling * y, fwhm=search_fwhm, sensitivity=search_sens)
else:
peaks = []
npeaks = len(peaks)
@@ -1136,7 +1242,7 @@ class FitTheories(object):
for i in range(npeaks):
height += y[int(peaks[i])] - bg[int(peaks[i])]
if i != npeaks - 1:
- delta += (x[int(peaks[i + 1])] - x[int(peaks[i])])
+ delta += x[int(peaks[i + 1])] - x[int(peaks[i])]
# delta between peaks
if npeaks > 1:
@@ -1160,8 +1266,8 @@ class FitTheories(object):
cons[1, 0] = CFREE
j = 2
# Setup height area constrains
- if not self.config['NoConstraintsFlag']:
- if self.config['PositiveHeightAreaFlag']:
+ if not self.config["NoConstraintsFlag"]:
+ if self.config["PositiveHeightAreaFlag"]:
# POSITIVE = 1
cons[j, 0] = CPOSITIVE
cons[j, 1] = 0
@@ -1169,8 +1275,8 @@ class FitTheories(object):
j += 1
# Setup position constrains
- if not self.config['NoConstraintsFlag']:
- if self.config['QuotedPositionFlag']:
+ if not self.config["NoConstraintsFlag"]:
+ if self.config["QuotedPositionFlag"]:
# QUOTED = 2
cons[j, 0] = CQUOTED
cons[j, 1] = min(x)
@@ -1178,8 +1284,8 @@ class FitTheories(object):
j += 1
# Setup positive FWHM constrains
- if not self.config['NoConstraintsFlag']:
- if self.config['PositiveFwhmFlag']:
+ if not self.config["NoConstraintsFlag"]:
+ if self.config["PositiveFwhmFlag"]:
# POSITIVE=1
cons[j, 0] = CPOSITIVE
cons[j, 1] = 0
@@ -1208,127 +1314,223 @@ class FitTheories(object):
self.config[key] = kw[key]
return self.config
+
fitfuns = FitTheories()
-THEORY = OrderedDict((
- ('Gaussians',
- FitTheory(description='Gaussian functions',
- function=functions.sum_gauss,
- parameters=('Height', 'Position', 'FWHM'),
- estimate=fitfuns.estimate_height_position_fwhm,
- configure=fitfuns.configure)),
- ('Lorentz',
- FitTheory(description='Lorentzian functions',
- function=functions.sum_lorentz,
- parameters=('Height', 'Position', 'FWHM'),
- estimate=fitfuns.estimate_height_position_fwhm,
- configure=fitfuns.configure)),
- ('Area Gaussians',
- FitTheory(description='Gaussian functions (area)',
- function=functions.sum_agauss,
- parameters=('Area', 'Position', 'FWHM'),
- estimate=fitfuns.estimate_agauss,
- configure=fitfuns.configure)),
- ('Area Lorentz',
- FitTheory(description='Lorentzian functions (area)',
- function=functions.sum_alorentz,
- parameters=('Area', 'Position', 'FWHM'),
- estimate=fitfuns.estimate_alorentz,
- configure=fitfuns.configure)),
- ('Pseudo-Voigt Line',
- FitTheory(description='Pseudo-Voigt functions',
- function=functions.sum_pvoigt,
- parameters=('Height', 'Position', 'FWHM', 'Eta'),
- estimate=fitfuns.estimate_pvoigt,
- configure=fitfuns.configure)),
- ('Area Pseudo-Voigt',
- FitTheory(description='Pseudo-Voigt functions (area)',
- function=functions.sum_apvoigt,
- parameters=('Area', 'Position', 'FWHM', 'Eta'),
- estimate=fitfuns.estimate_apvoigt,
- configure=fitfuns.configure)),
- ('Split Gaussian',
- FitTheory(description='Asymmetric gaussian functions',
- function=functions.sum_splitgauss,
- parameters=('Height', 'Position', 'LowFWHM',
- 'HighFWHM'),
- estimate=fitfuns.estimate_splitgauss,
- configure=fitfuns.configure)),
- ('Split Lorentz',
- FitTheory(description='Asymmetric lorentzian functions',
- function=functions.sum_splitlorentz,
- parameters=('Height', 'Position', 'LowFWHM', 'HighFWHM'),
- estimate=fitfuns.estimate_splitgauss,
- configure=fitfuns.configure)),
- ('Split Pseudo-Voigt',
- FitTheory(description='Asymmetric pseudo-Voigt functions',
- function=functions.sum_splitpvoigt,
- parameters=('Height', 'Position', 'LowFWHM',
- 'HighFWHM', 'Eta'),
- estimate=fitfuns.estimate_splitpvoigt,
- configure=fitfuns.configure)),
- ('Step Down',
- FitTheory(description='Step down function',
- function=functions.sum_stepdown,
- parameters=('Height', 'Position', 'FWHM'),
- estimate=fitfuns.estimate_stepdown,
- configure=fitfuns.configure)),
- ('Step Up',
- FitTheory(description='Step up function',
- function=functions.sum_stepup,
- parameters=('Height', 'Position', 'FWHM'),
- estimate=fitfuns.estimate_stepup,
- configure=fitfuns.configure)),
- ('Slit',
- FitTheory(description='Slit function',
- function=functions.sum_slit,
- parameters=('Height', 'Position', 'FWHM', 'BeamFWHM'),
- estimate=fitfuns.estimate_slit,
- configure=fitfuns.configure)),
- ('Atan',
- FitTheory(description='Arctan step up function',
- function=functions.atan_stepup,
- parameters=('Height', 'Position', 'Width'),
- estimate=fitfuns.estimate_stepup,
- configure=fitfuns.configure)),
- ('Hypermet',
- FitTheory(description='Hypermet functions',
- function=fitfuns.ahypermet, # customized version of functions.sum_ahypermet
- parameters=('G_Area', 'Position', 'FWHM', 'ST_Area',
- 'ST_Slope', 'LT_Area', 'LT_Slope', 'Step_H'),
- estimate=fitfuns.estimate_ahypermet,
- configure=fitfuns.configure)),
- # ('Periodic Gaussians',
- # FitTheory(description='Periodic gaussian functions',
- # function=functions.periodic_gauss,
- # parameters=('N', 'Delta', 'Height', 'Position', 'FWHM'),
- # estimate=fitfuns.estimate_periodic_gauss,
- # configure=fitfuns.configure))
- ('Degree 2 Polynomial',
- FitTheory(description='Degree 2 polynomial'
- '\ny = a*x^2 + b*x +c',
- function=fitfuns.poly,
- parameters=['a', 'b', 'c'],
- estimate=fitfuns.estimate_quadratic)),
- ('Degree 3 Polynomial',
- FitTheory(description='Degree 3 polynomial'
- '\ny = a*x^3 + b*x^2 + c*x + d',
- function=fitfuns.poly,
- parameters=['a', 'b', 'c', 'd'],
- estimate=fitfuns.estimate_cubic)),
- ('Degree 4 Polynomial',
- FitTheory(description='Degree 4 polynomial'
- '\ny = a*x^4 + b*x^3 + c*x^2 + d*x + e',
- function=fitfuns.poly,
- parameters=['a', 'b', 'c', 'd', 'e'],
- estimate=fitfuns.estimate_quartic)),
- ('Degree 5 Polynomial',
- FitTheory(description='Degree 5 polynomial'
- '\ny = a*x^5 + b*x^4 + c*x^3 + d*x^2 + e*x + f',
- function=fitfuns.poly,
- parameters=['a', 'b', 'c', 'd', 'e', 'f'],
- estimate=fitfuns.estimate_quintic)),
-))
+THEORY = dict(
+ (
+ (
+ "Gaussians",
+ FitTheory(
+ description="Gaussian functions",
+ function=functions.sum_gauss,
+ parameters=("Height", "Position", "FWHM"),
+ estimate=fitfuns.estimate_height_position_fwhm,
+ configure=fitfuns.configure,
+ ),
+ ),
+ (
+ "Lorentz",
+ FitTheory(
+ description="Lorentzian functions",
+ function=functions.sum_lorentz,
+ parameters=("Height", "Position", "FWHM"),
+ estimate=fitfuns.estimate_height_position_fwhm,
+ configure=fitfuns.configure,
+ ),
+ ),
+ (
+ "Area Gaussians",
+ FitTheory(
+ description="Gaussian functions (area)",
+ function=functions.sum_agauss,
+ parameters=("Area", "Position", "FWHM"),
+ estimate=fitfuns.estimate_agauss,
+ configure=fitfuns.configure,
+ ),
+ ),
+ (
+ "Area Lorentz",
+ FitTheory(
+ description="Lorentzian functions (area)",
+ function=functions.sum_alorentz,
+ parameters=("Area", "Position", "FWHM"),
+ estimate=fitfuns.estimate_alorentz,
+ configure=fitfuns.configure,
+ ),
+ ),
+ (
+ "Pseudo-Voigt Line",
+ FitTheory(
+ description="Pseudo-Voigt functions",
+ function=functions.sum_pvoigt,
+ parameters=("Height", "Position", "FWHM", "Eta"),
+ estimate=fitfuns.estimate_pvoigt,
+ configure=fitfuns.configure,
+ ),
+ ),
+ (
+ "Area Pseudo-Voigt",
+ FitTheory(
+ description="Pseudo-Voigt functions (area)",
+ function=functions.sum_apvoigt,
+ parameters=("Area", "Position", "FWHM", "Eta"),
+ estimate=fitfuns.estimate_apvoigt,
+ configure=fitfuns.configure,
+ ),
+ ),
+ (
+ "Split Gaussian",
+ FitTheory(
+ description="Asymmetric gaussian functions",
+ function=functions.sum_splitgauss,
+ parameters=("Height", "Position", "LowFWHM", "HighFWHM"),
+ estimate=fitfuns.estimate_splitgauss,
+ configure=fitfuns.configure,
+ ),
+ ),
+ (
+ "Split Lorentz",
+ FitTheory(
+ description="Asymmetric lorentzian functions",
+ function=functions.sum_splitlorentz,
+ parameters=("Height", "Position", "LowFWHM", "HighFWHM"),
+ estimate=fitfuns.estimate_splitgauss,
+ configure=fitfuns.configure,
+ ),
+ ),
+ (
+ "Split Pseudo-Voigt",
+ FitTheory(
+ description="Asymmetric pseudo-Voigt functions",
+ function=functions.sum_splitpvoigt,
+ parameters=("Height", "Position", "LowFWHM", "HighFWHM", "Eta"),
+ estimate=fitfuns.estimate_splitpvoigt,
+ configure=fitfuns.configure,
+ ),
+ ),
+ (
+ "Split Pseudo-Voigt 2",
+ FitTheory(
+ description="Asymmetric pseudo-Voigt functions",
+ function=functions.sum_splitpvoigt2,
+ parameters=(
+ "Height",
+ "Position",
+ "LowFWHM",
+ "HighFWHM",
+ "LowEta",
+ "HighEta",
+ ),
+ estimate=fitfuns.estimate_splitpvoigt2,
+ configure=fitfuns.configure,
+ ),
+ ),
+ (
+ "Step Down",
+ FitTheory(
+ description="Step down function",
+ function=functions.sum_stepdown,
+ parameters=("Height", "Position", "FWHM"),
+ estimate=fitfuns.estimate_stepdown,
+ configure=fitfuns.configure,
+ ),
+ ),
+ (
+ "Step Up",
+ FitTheory(
+ description="Step up function",
+ function=functions.sum_stepup,
+ parameters=("Height", "Position", "FWHM"),
+ estimate=fitfuns.estimate_stepup,
+ configure=fitfuns.configure,
+ ),
+ ),
+ (
+ "Slit",
+ FitTheory(
+ description="Slit function",
+ function=functions.sum_slit,
+ parameters=("Height", "Position", "FWHM", "BeamFWHM"),
+ estimate=fitfuns.estimate_slit,
+ configure=fitfuns.configure,
+ ),
+ ),
+ (
+ "Atan",
+ FitTheory(
+ description="Arctan step up function",
+ function=functions.atan_stepup,
+ parameters=("Height", "Position", "Width"),
+ estimate=fitfuns.estimate_stepup,
+ configure=fitfuns.configure,
+ ),
+ ),
+ (
+ "Hypermet",
+ FitTheory(
+ description="Hypermet functions",
+ function=fitfuns.ahypermet, # customized version of functions.sum_ahypermet
+ parameters=(
+ "G_Area",
+ "Position",
+ "FWHM",
+ "ST_Area",
+ "ST_Slope",
+ "LT_Area",
+ "LT_Slope",
+ "Step_H",
+ ),
+ estimate=fitfuns.estimate_ahypermet,
+ configure=fitfuns.configure,
+ ),
+ ),
+ # ('Periodic Gaussians',
+ # FitTheory(description='Periodic gaussian functions',
+ # function=functions.periodic_gauss,
+ # parameters=('N', 'Delta', 'Height', 'Position', 'FWHM'),
+ # estimate=fitfuns.estimate_periodic_gauss,
+ # configure=fitfuns.configure))
+ (
+ "Degree 2 Polynomial",
+ FitTheory(
+ description="Degree 2 polynomial" "\ny = a*x^2 + b*x +c",
+ function=fitfuns.poly,
+ parameters=["a", "b", "c"],
+ estimate=fitfuns.estimate_quadratic,
+ ),
+ ),
+ (
+ "Degree 3 Polynomial",
+ FitTheory(
+ description="Degree 3 polynomial" "\ny = a*x^3 + b*x^2 + c*x + d",
+ function=fitfuns.poly,
+ parameters=["a", "b", "c", "d"],
+ estimate=fitfuns.estimate_cubic,
+ ),
+ ),
+ (
+ "Degree 4 Polynomial",
+ FitTheory(
+ description="Degree 4 polynomial"
+ "\ny = a*x^4 + b*x^3 + c*x^2 + d*x + e",
+ function=fitfuns.poly,
+ parameters=["a", "b", "c", "d", "e"],
+ estimate=fitfuns.estimate_quartic,
+ ),
+ ),
+ (
+ "Degree 5 Polynomial",
+ FitTheory(
+ description="Degree 5 polynomial"
+ "\ny = a*x^5 + b*x^4 + c*x^3 + d*x^2 + e*x + f",
+ function=fitfuns.poly,
+ parameters=["a", "b", "c", "d", "e", "f"],
+ estimate=fitfuns.estimate_quintic,
+ ),
+ ),
+ )
+)
"""Dictionary of fit theories: fit functions and their associated estimation
function, parameters list, configuration function and description.
"""
@@ -1336,16 +1538,20 @@ function, parameters list, configuration function and description.
def test(a):
from silx.math.fit import fitmanager
+
x = numpy.arange(1000).astype(numpy.float64)
- p = [1500, 100., 50.0,
- 1500, 700., 50.0]
+ p = [1500, 100.0, 50.0, 1500, 700.0, 50.0]
y_synthetic = functions.sum_gauss(x, *p) + 1
fit = fitmanager.FitManager(x, y_synthetic)
- fit.addtheory('Gaussians', functions.sum_gauss, ['Height', 'Position', 'FWHM'],
- a.estimate_height_position_fwhm)
- fit.settheory('Gaussians')
- fit.setbackground('Linear')
+ fit.addtheory(
+ "Gaussians",
+ functions.sum_gauss,
+ ["Height", "Position", "FWHM"],
+ a.estimate_height_position_fwhm,
+ )
+ fit.settheory("Gaussians")
+ fit.setbackground("Linear")
fit.estimate()
fit.runfit()
@@ -1353,12 +1559,13 @@ def test(a):
y_fit = fit.gendata()
print("Fit parameter names: %s" % str(fit.get_names()))
- print("Theoretical parameters: %s" % str(numpy.append([1, 0], p)))
+ print("Theoretical parameters: %s" % str(numpy.append([1, 0], p)))
print("Fitted parameters: %s" % str(fit.get_fitted_parameters()))
try:
from silx.gui import qt
from silx.gui.plot import plot1D
+
app = qt.QApplication([])
# Offset of 1 to see the difference in log scale
diff --git a/src/silx/math/fit/fittheory.py b/src/silx/math/fit/fittheory.py
index ab3ae43..4d2b19b 100644
--- a/src/silx/math/fit/fittheory.py
+++ b/src/silx/math/fit/fittheory.py
@@ -1,4 +1,4 @@
-#/*##########################################################################
+# /*##########################################################################
#
# Copyright (c) 2004-2018 European Synchrotron Radiation Facility
#
@@ -35,19 +35,28 @@ __date__ = "09/08/2016"
class FitTheory(object):
"""This class defines a fit theory, which consists of:
- - a model function, the actual function to be fitted
- - parameters names
- - an estimation function, that return the estimated initial parameters
- that serve as input for :func:`silx.math.fit.leastsq`
- - an optional configuration function, that can be used to modify
- configuration parameters to alter the behavior of the fit function
- and the estimation function
- - an optional derivative function, that replaces the default model
- derivative used in :func:`silx.math.fit.leastsq`
+ - a model function, the actual function to be fitted
+ - parameters names
+ - an estimation function, that return the estimated initial parameters
+ that serve as input for :func:`silx.math.fit.leastsq`
+ - an optional configuration function, that can be used to modify
+ configuration parameters to alter the behavior of the fit function
+ and the estimation function
+ - an optional derivative function, that replaces the default model
+ derivative used in :func:`silx.math.fit.leastsq`
"""
- def __init__(self, function, parameters,
- estimate=None, configure=None, derivative=None,
- description=None, pymca_legacy=False, is_background=False):
+
+ def __init__(
+ self,
+ function,
+ parameters,
+ estimate=None,
+ configure=None,
+ derivative=None,
+ description=None,
+ pymca_legacy=False,
+ is_background=False,
+ ):
"""
:param function function: Actual function. See documentation for
:attr:`function`.
@@ -155,6 +164,6 @@ class FitTheory(object):
"""Default estimate function. Return an array of *ones* as the
initial estimated parameters, and set all constraints to zero
(FREE)"""
- estimated_parameters = [1. for _ in self.parameters]
+ estimated_parameters = [1.0 for _ in self.parameters]
estimated_constraints = [[0, 0, 0] for _ in self.parameters]
return estimated_parameters, estimated_constraints
diff --git a/src/silx/math/fit/functions.pyx b/src/silx/math/fit/functions.pyx
index a69086c..e7102a5 100644
--- a/src/silx/math/fit/functions.pyx
+++ b/src/silx/math/fit/functions.pyx
@@ -33,6 +33,7 @@ List of fit functions:
- :func:`sum_apvoigt`
- :func:`sum_pvoigt`
- :func:`sum_splitpvoigt`
+ - :func:`sum_splitpvoigt2`
- :func:`sum_lorentz`
- :func:`sum_alorentz`
@@ -143,9 +144,7 @@ def sum_gauss(x, *params):
double[::1] params_c
double[::1] y_c
- if not len(params):
- raise IndexError("No gaussian parameters specified. " +
- "At least 3 parameters are required.")
+ _validate_parameters(params, 3)
# ensure float64 (double) type and 1D contiguous data layout in memory
x_c = numpy.array(x,
@@ -191,9 +190,7 @@ def sum_agauss(x, *params):
double[::1] params_c
double[::1] y_c
- if not len(params):
- raise IndexError("No gaussian parameters specified. " +
- "At least 3 parameters are required.")
+ _validate_parameters(params, 3)
x_c = numpy.array(x,
copy=False,
@@ -241,9 +238,7 @@ def sum_fastagauss(x, *params):
double[::1] params_c
double[::1] y_c
- if not len(params):
- raise IndexError("No gaussian parameters specified. " +
- "At least 3 parameters are required.")
+ _validate_parameters(params, 3)
x_c = numpy.array(x,
copy=False,
@@ -290,9 +285,7 @@ def sum_splitgauss(x, *params):
double[::1] params_c
double[::1] y_c
- if not len(params):
- raise IndexError("No gaussian parameters specified. " +
- "At least 4 parameters are required.")
+ _validate_parameters(params, 4)
x_c = numpy.array(x,
copy=False,
@@ -327,7 +320,7 @@ def sum_apvoigt(x, *params):
- *area* is the area underneath both G(x) and L(x)
- *centroid* is the peak x-coordinate for both functions
- *fwhm* is the full-width at half maximum of both functions
- - *eta* is the Lorentz factor: PV(x) = eta * L(x) + (1 - eta) * G(x)
+ - *eta* is the Lorentzian fraction: PV(x) = eta * L(x) + (1 - eta) * G(x)
:param x: Independent variable where the gaussians are calculated
:type x: numpy.ndarray
@@ -341,9 +334,8 @@ def sum_apvoigt(x, *params):
double[::1] params_c
double[::1] y_c
- if not len(params):
- raise IndexError("No parameters specified. " +
- "At least 4 parameters are required.")
+ _validate_parameters(params, 4)
+
x_c = numpy.array(x,
copy=False,
dtype=numpy.float64,
@@ -377,7 +369,7 @@ def sum_pvoigt(x, *params):
- *height* is the peak amplitude of G(x) and L(x)
- *centroid* is the peak x-coordinate for both functions
- *fwhm* is the full-width at half maximum of both functions
- - *eta* is the Lorentz factor: PV(x) = eta * L(x) + (1 - eta) * G(x)
+ - *eta* is the Lorentzian fraction: PV(x) = eta * L(x) + (1 - eta) * G(x)
:param x: Independent variable where the gaussians are calculated
:type x: numpy.ndarray
@@ -391,9 +383,7 @@ def sum_pvoigt(x, *params):
double[::1] params_c
double[::1] y_c
- if not len(params):
- raise IndexError("No parameters specified. " +
- "At least 4 parameters are required.")
+ _validate_parameters(params, 4)
x_c = numpy.array(x,
copy=False,
@@ -425,13 +415,13 @@ def sum_splitpvoigt(x, *params):
profile using a linear combination of a Gaussian curve ``G(x)`` and a
Lorentzian curve ``L(x)`` instead of their convolution.
- - *height* is the peak amplitudefor G(x) and L(x)
+ - *height* is the peak amplitude for G(x) and L(x)
- *centroid* is the peak x-coordinate for both functions
- *fwhm1* is the full-width at half maximum of both functions
when ``x < centroid``
- *fwhm2* is the full-width at half maximum of both functions
when ``x > centroid``
- - *eta* is the Lorentz factor: PV(x) = eta * L(x) + (1 - eta) * G(x)
+ - *eta* is the Lorentzian fraction: PV(x) = eta * L(x) + (1 - eta) * G(x)
:param x: Independent variable where the gaussians are calculated
:type x: numpy.ndarray
@@ -446,9 +436,7 @@ def sum_splitpvoigt(x, *params):
double[::1] params_c
double[::1] y_c
- if not len(params):
- raise IndexError("No parameters specified. " +
- "At least 5 parameters are required.")
+ _validate_parameters(params, 5)
x_c = numpy.array(x,
copy=False,
@@ -472,6 +460,60 @@ def sum_splitpvoigt(x, *params):
return numpy.asarray(y_c).reshape(x.shape)
+def sum_splitpvoigt2(x, *params):
+ """Return a sum of split pseudo-Voigt functions, defined by *(height,
+ centroid, fwhm1, fwhm2, eta1, eta2)*.
+
+ The pseudo-Voigt profile ``PV(x)`` is an approximation of the Voigt
+ profile using a linear combination of a Gaussian curve ``G(x)`` and a
+ Lorentzian curve ``L(x)`` instead of their convolution.
+
+ - *height* is the peak amplitude for G(x) and L(x)
+ - *centroid* is the peak x-coordinate for both functions
+ - *fwhm1* is the full-width at half maximum of both functions
+ when ``x < centroid``
+ - *fwhm2* is the full-width at half maximum of both functions
+ when ``x > centroid``
+ - *eta1* is the Lorentzian fraction when ``x < centroid``
+ - *eta2* is the Lorentzian fraction when ``x > centroid``
+
+ :param x: Independent variable where the gaussians are calculated
+ :type x: numpy.ndarray
+ :param params: Array of pseudo-Voigt parameters (length must be a multiple
+ of 6):
+ *(height1, centroid1, fwhm11, fwhm21, eta11, eta21,...)*
+ :return: Array of sum of split pseudo-Voigt functions at each ``x``
+ coordinate
+ """
+ cdef:
+ double[::1] x_c
+ double[::1] params_c
+ double[::1] y_c
+
+ _validate_parameters(params, 6)
+
+ x_c = numpy.array(x,
+ copy=False,
+ dtype=numpy.float64,
+ order='C').reshape(-1)
+ params_c = numpy.array(params,
+ copy=False,
+ dtype=numpy.float64,
+ order='C').reshape(-1)
+ y_c = numpy.empty(shape=(x.size,),
+ dtype=numpy.float64)
+
+ status = functions_wrapper.sum_splitpvoigt2(
+ &x_c[0], x.size,
+ &params_c[0], params_c.size,
+ &y_c[0])
+
+ if status:
+ raise IndexError("Wrong number of parameters for function")
+
+ return numpy.asarray(y_c).reshape(x.shape)
+
+
def sum_lorentz(x, *params):
"""Return a sum of Lorentz distributions, also known as Cauchy distribution,
defined by *(height, centroid, fwhm)*.
@@ -493,9 +535,7 @@ def sum_lorentz(x, *params):
double[::1] params_c
double[::1] y_c
- if not len(params):
- raise IndexError("No parameters specified. " +
- "At least 3 parameters are required.")
+ _validate_parameters(params, 3)
x_c = numpy.array(x,
copy=False,
@@ -540,9 +580,7 @@ def sum_alorentz(x, *params):
double[::1] params_c
double[::1] y_c
- if not len(params):
- raise IndexError("No parameters specified. " +
- "At least 3 parameters are required.")
+ _validate_parameters(params, 3)
x_c = numpy.array(x,
copy=False,
@@ -588,9 +626,7 @@ def sum_splitlorentz(x, *params):
double[::1] params_c
double[::1] y_c
- if not len(params):
- raise IndexError("No parameters specified. " +
- "At least 4 parameters are required.")
+ _validate_parameters(params, 4)
x_c = numpy.array(x,
copy=False,
@@ -636,9 +672,8 @@ def sum_stepdown(x, *params):
double[::1] params_c
double[::1] y_c
- if not len(params):
- raise IndexError("No parameters specified. " +
- "At least 3 parameters are required.")
+ _validate_parameters(params, 3)
+
x_c = numpy.array(x,
copy=False,
dtype=numpy.float64,
@@ -684,9 +719,7 @@ def sum_stepup(x, *params):
double[::1] params_c
double[::1] y_c
- if not len(params):
- raise IndexError("No parameters specified. " +
- "At least 3 parameters are required.")
+ _validate_parameters(params, 3)
x_c = numpy.array(x,
copy=False,
@@ -735,9 +768,7 @@ def sum_slit(x, *params):
double[::1] params_c
double[::1] y_c
- if not len(params):
- raise IndexError("No parameters specified. " +
- "At least 4 parameters are required.")
+ _validate_parameters(params, 4)
x_c = numpy.array(x,
copy=False,
@@ -797,9 +828,9 @@ def sum_ahypermet(x, *params,
*(area1, position1, fwhm1, st_area_r1, st_slope_r1, lt_area_r1,
lt_slope_r1, step_height_r1...)*
:param gaussian_term: If ``True``, enable gaussian term. Default ``True``
- :param st_term: If ``True``, enable gaussian term. Default ``True``
- :param lt_term: If ``True``, enable gaussian term. Default ``True``
- :param step_term: If ``True``, enable gaussian term. Default ``True``
+ :param st_term: If ``True``, enable short tail term. Default ``True``
+ :param lt_term: If ``True``, enable long tail term. Default ``True``
+ :param step_term: If ``True``, enable step term. Default ``True``
:return: Array of sum of hypermet functions at each ``x`` coordinate
"""
cdef:
@@ -807,9 +838,7 @@ def sum_ahypermet(x, *params,
double[::1] params_c
double[::1] y_c
- if not len(params):
- raise IndexError("No parameters specified. " +
- "At least 8 parameters are required.")
+ _validate_parameters(params, 8)
# Sum binary flags to activate various terms of the equation
tail_flags = 1 if gaussian_term else 0
@@ -883,9 +912,9 @@ def sum_fastahypermet(x, *params,
*(area1, position1, fwhm1, st_area_r1, st_slope_r1, lt_area_r1,
lt_slope_r1, step_height_r1...)*
:param gaussian_term: If ``True``, enable gaussian term. Default ``True``
- :param st_term: If ``True``, enable gaussian term. Default ``True``
- :param lt_term: If ``True``, enable gaussian term. Default ``True``
- :param step_term: If ``True``, enable gaussian term. Default ``True``
+ :param st_term: If ``True``, enable short tail term. Default ``True``
+ :param lt_term: If ``True``, enable long tail term. Default ``True``
+ :param step_term: If ``True``, enable step term. Default ``True``
:return: Array of sum of hypermet functions at each ``x`` coordinate
"""
cdef:
@@ -893,9 +922,7 @@ def sum_fastahypermet(x, *params,
double[::1] params_c
double[::1] y_c
- if not len(params):
- raise IndexError("No parameters specified. " +
- "At least 8 parameters are required.")
+ _validate_parameters(params, 8)
# Sum binary flags to activate various terms of the equation
tail_flags = 1 if gaussian_term else 0
@@ -955,7 +982,7 @@ def atan_stepup(x, a, b, c):
return a * (0.5 + (numpy.arctan((1.0 * x - b) / c) / numpy.pi))
-def periodic_gauss(x, *pars):
+def periodic_gauss(x, *params):
"""
Return a sum of gaussian functions defined by
*(npeaks, delta, height, centroid, fwhm)*,
@@ -968,17 +995,22 @@ def periodic_gauss(x, *pars):
- *fwhm* is the full-width at half maximum for all the gaussians
:param x: Independent variable where the function is calculated
- :param pars: *(npeaks, delta, height, centroid, fwhm)*
+ :param params: *(npeaks, delta, height, centroid, fwhm)*
:return: Sum of ``npeaks`` gaussians
"""
- if not len(pars):
- raise IndexError("No parameters specified. " +
- "At least 5 parameters are required.")
+ _validate_parameters(params, 5)
- newpars = numpy.zeros((pars[0], 3), numpy.float64)
- for i in range(int(pars[0])):
- newpars[i, 0] = pars[2]
- newpars[i, 1] = pars[3] + i * pars[1]
- newpars[:, 2] = pars[4]
+ newpars = numpy.zeros((params[0], 3), numpy.float64)
+ for i in range(int(params[0])):
+ newpars[i, 0] = params[2]
+ newpars[i, 1] = params[3] + i * params[1]
+ newpars[:, 2] = params[4]
return sum_gauss(x, newpars)
+
+
+def _validate_parameters(params, multiple):
+ if len(params) == 0:
+ raise IndexError("No parameters specified.")
+ if len(params) % multiple:
+ raise IndexError(f"The number of parameters should be a multiple of {multiple}.")
diff --git a/src/silx/math/fit/functions/include/functions.h b/src/silx/math/fit/functions/include/functions.h
index de4209b..cf084b2 100644
--- a/src/silx/math/fit/functions/include/functions.h
+++ b/src/silx/math/fit/functions/include/functions.h
@@ -53,6 +53,7 @@ int sum_splitgauss(double* x, int len_x, double* pgauss, int len_pgauss, double*
int sum_apvoigt(double* x, int len_x, double* pvoigt, int len_pvoigt, double* y);
int sum_pvoigt(double* x, int len_x, double* pvoigt, int len_pvoigt, double* y);
int sum_splitpvoigt(double* x, int len_x, double* pvoigt, int len_pvoigt, double* y);
+int sum_splitpvoigt2(double* x, int len_x, double* pvoigt, int len_pvoigt, double* y);
int sum_lorentz(double* x, int len_x, double* plorentz, int len_plorentz, double* y);
int sum_alorentz(double* x, int len_x, double* plorentz, int len_plorentz, double* y);
diff --git a/src/silx/math/fit/functions/src/funs.c b/src/silx/math/fit/functions/src/funs.c
index aae173f..4b41fce 100644
--- a/src/silx/math/fit/functions/src/funs.c
+++ b/src/silx/math/fit/functions/src/funs.c
@@ -434,7 +434,7 @@ int sum_splitgauss(double* x, int len_x, double* pgauss, int len_pgauss, double*
*area* is the area underneath both G(x) and L(x)
*centroid* is the peak x-coordinate for both functions
*fwhm* is the full-width at half maximum of both functions
- *eta* is the Lorentz factor: PV(x) = eta * L(x) + (1 - eta) * G(x)
+ *eta* is the Lorentzian fraction: PV(x) = eta * L(x) + (1 - eta) * G(x)
Parameters:
-----------
@@ -504,7 +504,7 @@ int sum_apvoigt(double* x, int len_x, double* pvoigt, int len_pvoigt, double* y)
*height* is the peak amplitude of G(x) and L(x)
*centroid* is the peak x-coordinate for both functions
*fwhm* is the full-width at half maximum of both functions
- *eta* is the Lorentz factor: PV(x) = eta * L(x) + (1 - eta) * G(x)
+ *eta* is the Lorentzian fraction: PV(x) = eta * L(x) + (1 - eta) * G(x)
Parameters:
-----------
@@ -573,7 +573,7 @@ int sum_pvoigt(double* x, int len_x, double* pvoigt, int len_pvoigt, double* y)
*centroid* is the peak x-coordinate for both functions
*fwhm1* is the full-width at half maximum of both functions for x < centroid
*fwhm2* is the full-width at half maximum of both functions for x > centroid
- *eta* is the Lorentz factor: PV(x) = eta * L(x) + (1 - eta) * G(x)
+ *eta* is the Lorentzian fraction: PV(x) = eta * L(x) + (1 - eta) * G(x)
Parameters:
-----------
@@ -650,6 +650,98 @@ int sum_splitpvoigt(double* x, int len_x, double* pvoigt, int len_pvoigt, double
return(0);
}
+/* sum_splitpvoigt2
+ Sum of split pseudo-Voigt functions, defined by
+ (height, centroid, fwhm1, fwhm2, eta1, eta2).
+
+ The pseudo-Voigt profile PV(x) is an approximation of the Voigt profile
+ using a linear combination of a Gaussian curve G(x) and a Lorentzian curve
+ L(x) instead of their convolution.
+
+ *height* is the peak amplitude of G(x) and L(x)
+ *centroid* is the peak x-coordinate for both functions
+ *fwhm1* is the full-width at half maximum of both functions for x < centroid
+ *fwhm2* is the full-width at half maximum of both functions for x > centroid
+ *eta1* is the Lorentzian fraction for x < centroid
+ *eta2* is the Lorentzian fraction for x > centroid
+
+ Parameters:
+ -----------
+
+ - x: Independant variable where the gaussians are calculated.
+ - len_x: Number of elements in the x array.
+ - pvoigt: Array of Voigt function parameters:
+ (height1, centroid1, fwhm11, fwhm21, eta11, eta21, ...)
+ - len_voigt: Number of elements in the pvoigt array. Must be
+ a multiple of 6.
+ - y: Output array. Must have memory allocated for the same number
+ of elements as x (len_x).
+
+*/
+int sum_splitpvoigt2(double* x, int len_x, double* pvoigt, int len_pvoigt, double* y)
+{
+ int i, j;
+ double dhelp, x_minus_centroid, inv_two_sqrt_two_log2, sigma1, sigma2;
+ double height, centroid, fwhm1, fwhm2, eta1, eta2;
+
+ if (test_params(len_pvoigt, 6, "sum_splitpvoigt2", "height, centroid, fwhm1, fwhm2, eta1, eta2")) {
+ return(1);
+ }
+
+ /* Initialize output array */
+ for (j=0; j<len_x; j++) {
+ y[j] = 0.;
+ }
+
+ inv_two_sqrt_two_log2 = 1.0 / (2.0 * sqrt(2.0 * LOG2));
+
+ for (i=0; i<len_pvoigt/6; i++) {
+ height = pvoigt[6*i];
+ centroid = pvoigt[6*i+1];
+ fwhm1 = pvoigt[6*i+2];
+ fwhm2 = pvoigt[6*i+3];
+ eta1 = pvoigt[6*i+4];
+ eta2 = pvoigt[6*i+5];
+
+ sigma1 = fwhm1 * inv_two_sqrt_two_log2;
+ sigma2 = fwhm2 * inv_two_sqrt_two_log2;
+
+ for (j=0; j<len_x; j++) {
+ x_minus_centroid = (x[j] - centroid);
+
+ /* Use fwhm2 and eta2 when x > centroid */
+ if (x_minus_centroid > 0) {
+ /* Lorentzian term */
+ dhelp = (2.0 * x_minus_centroid) / fwhm2;
+ dhelp = 1.0 + (dhelp * dhelp);
+ y[j] += eta2 * height / dhelp;
+
+ /* Gaussian term */
+ dhelp = x_minus_centroid / sigma2;
+ if (dhelp <= 35) {
+ dhelp = exp(-0.5 * dhelp * dhelp);
+ y[j] += (1 - eta2) * height * dhelp;
+ }
+ }
+ /* Use fwhm1 and eta1 when x < centroid */
+ else {
+ /* Lorentzian term */
+ dhelp = (2.0 * x_minus_centroid) / fwhm1;
+ dhelp = 1.0 + (dhelp * dhelp);
+ y[j] += eta1 * height / dhelp;
+
+ /* Gaussian term */
+ dhelp = x_minus_centroid / sigma1;
+ if (dhelp <= 35) {
+ dhelp = exp(-0.5 * dhelp * dhelp);
+ y[j] += (1 - eta1) * height * dhelp;
+ }
+ }
+ }
+ }
+ return(0);
+}
+
/* sum_lorentz
Sum of Lorentz functions, defined by (height, centroid, fwhm).
diff --git a/src/silx/math/fit/functions_wrapper.pxd b/src/silx/math/fit/functions_wrapper.pxd
index 38de94a..232a14b 100644
--- a/src/silx/math/fit/functions_wrapper.pxd
+++ b/src/silx/math/fit/functions_wrapper.pxd
@@ -102,6 +102,12 @@ cdef extern from "functions.h":
int len_pvoigt,
double* y)
+ int sum_splitpvoigt2(double* x,
+ int len_x,
+ double* pvoigt,
+ int len_pvoigt,
+ double* y)
+
int sum_lorentz(double* x,
int len_x,
double* plorentz,
diff --git a/src/silx/math/fit/leastsq.py b/src/silx/math/fit/leastsq.py
index e49977f..9a1e2ad 100644
--- a/src/silx/math/fit/leastsq.py
+++ b/src/silx/math/fit/leastsq.py
@@ -46,21 +46,31 @@ import copy
_logger = logging.getLogger(__name__)
# codes understood by the routine
-CFREE = 0
-CPOSITIVE = 1
-CQUOTED = 2
-CFIXED = 3
-CFACTOR = 4
-CDELTA = 5
-CSUM = 6
-CIGNORED = 7
-
-def leastsq(model, xdata, ydata, p0, sigma=None,
- constraints=None, model_deriv=None, epsfcn=None,
- deltachi=None, full_output=None,
- check_finite=True,
- left_derivative=False,
- max_iter=100):
+CFREE = 0
+CPOSITIVE = 1
+CQUOTED = 2
+CFIXED = 3
+CFACTOR = 4
+CDELTA = 5
+CSUM = 6
+CIGNORED = 7
+
+
+def leastsq(
+ model,
+ xdata,
+ ydata,
+ p0,
+ sigma=None,
+ constraints=None,
+ model_deriv=None,
+ epsfcn=None,
+ deltachi=None,
+ full_output=None,
+ check_finite=True,
+ left_derivative=False,
+ max_iter=100,
+):
"""
Use non-linear least squares Levenberg-Marquardt algorithm to fit a function, f, to
data with optional constraints on the fitted parameters.
@@ -272,7 +282,9 @@ def leastsq(model, xdata, ydata, p0, sigma=None,
filter_xdata = True
if filter_xdata:
if xdata.size != ydata.size:
- raise ValueError("xdata contains non-finite data that cannot be filtered")
+ raise ValueError(
+ "xdata contains non-finite data that cannot be filtered"
+ )
else:
# we leave the xdata as they where
old_shape = xdata.shape
@@ -324,25 +336,27 @@ def leastsq(model, xdata, ydata, p0, sigma=None,
elif txt in ["IGNORED", "IGNORE"]:
constraints[i][0] = CIGNORED
else:
- #I should raise an exception
+ # I should raise an exception
raise ValueError("Unknown constraint %s" % constraints[i][0])
if constraints[i][0] > 0:
constrained_fit = True
if constrained_fit:
if full_output is None:
- _logger.info("Recommended to set full_output to True when using constraints")
+ _logger.info(
+ "Recommended to set full_output to True when using constraints"
+ )
# Levenberg-Marquardt algorithm
fittedpar = parameters.__copy__()
flambda = 0.001
iiter = max_iter
- #niter = 0
- last_evaluation=None
+ # niter = 0
+ last_evaluation = None
x = xdata
y = ydata
chisq0 = -1
iteration_counter = 0
- while (iiter > 0):
+ while iiter > 0:
weight = weight0
"""
I cannot evaluate the initial chisq here because I do not know
@@ -357,60 +371,67 @@ def leastsq(model, xdata, ydata, p0, sigma=None,
"""
iteration_counter += 1
chisq0, alpha0, beta, internal_output = chisq_alpha_beta(
- model, fittedpar,
- x, y, weight, constraints=constraints,
- model_deriv=model_deriv,
- epsfcn=epsfcn,
- left_derivative=left_derivative,
- last_evaluation=last_evaluation,
- full_output=True)
+ model,
+ fittedpar,
+ x,
+ y,
+ weight,
+ constraints=constraints,
+ model_deriv=model_deriv,
+ epsfcn=epsfcn,
+ left_derivative=left_derivative,
+ last_evaluation=last_evaluation,
+ full_output=True,
+ )
n_free = internal_output["n_free"]
free_index = internal_output["free_index"]
noigno = internal_output["noigno"]
fitparam = internal_output["fitparam"]
function_calls = internal_output["function_calls"]
function_call_counter += function_calls
- #print("chisq0 = ", chisq0, n_free, fittedpar)
- #raise
+ # print("chisq0 = ", chisq0, n_free, fittedpar)
+ # raise
nr, nc = alpha0.shape
flag = 0
- #lastdeltachi = chisq0
+ # lastdeltachi = chisq0
while flag == 0:
alpha = alpha0 * (1.0 + flambda * numpy.identity(nr))
deltapar = numpy.dot(beta, inv(alpha))
if constraints is None:
- newpar = fitparam + deltapar [0]
+ newpar = fitparam + deltapar[0]
else:
newpar = parameters.__copy__()
pwork = numpy.zeros(deltapar.shape, numpy.float64)
for i in range(n_free):
if constraints is None:
- pwork [0] [i] = fitparam [i] + deltapar [0] [i]
- elif constraints [free_index[i]][0] == CFREE:
- pwork [0] [i] = fitparam [i] + deltapar [0] [i]
- elif constraints [free_index[i]][0] == CPOSITIVE:
- #abs method
- pwork [0] [i] = fitparam [i] + deltapar [0] [i]
- #square method
- #pwork [0] [i] = (numpy.sqrt(fitparam [i]) + deltapar [0] [i]) * \
+ pwork[0][i] = fitparam[i] + deltapar[0][i]
+ elif constraints[free_index[i]][0] == CFREE:
+ pwork[0][i] = fitparam[i] + deltapar[0][i]
+ elif constraints[free_index[i]][0] == CPOSITIVE:
+ # abs method
+ pwork[0][i] = fitparam[i] + deltapar[0][i]
+ # square method
+ # pwork [0] [i] = (numpy.sqrt(fitparam [i]) + deltapar [0] [i]) * \
# (numpy.sqrt(fitparam [i]) + deltapar [0] [i])
elif constraints[free_index[i]][0] == CQUOTED:
- pmax = max(constraints[free_index[i]][1],
- constraints[free_index[i]][2])
- pmin = min(constraints[free_index[i]][1],
- constraints[free_index[i]][2])
+ pmax = max(
+ constraints[free_index[i]][1], constraints[free_index[i]][2]
+ )
+ pmin = min(
+ constraints[free_index[i]][1], constraints[free_index[i]][2]
+ )
A = 0.5 * (pmax + pmin)
B = 0.5 * (pmax - pmin)
if B != 0:
- pwork [0] [i] = A + \
- B * numpy.sin(numpy.arcsin((fitparam[i] - A)/B)+ \
- deltapar [0] [i])
+ pwork[0][i] = A + B * numpy.sin(
+ numpy.arcsin((fitparam[i] - A) / B) + deltapar[0][i]
+ )
else:
txt = "Error processing constrained fit\n"
txt += "Parameter limits are %g and %g\n" % (pmin, pmax)
- txt += "A = %g B = %g" % (A, B)
+ txt += "A = %g B = %g" % (A, B)
raise ValueError("Invalid parameter limits")
- newpar[free_index[i]] = pwork [0] [i]
+ newpar[free_index[i]] = pwork[0][i]
newpar = numpy.array(_get_parameters(newpar, constraints))
workpar = numpy.take(newpar, noigno)
yfit = model(x, *workpar)
@@ -422,7 +443,7 @@ def leastsq(model, xdata, ydata, p0, sigma=None,
_logger.warning(msg)
yfit.shape = -1
function_call_counter += 1
- chisq = (weight * pow(y-yfit, 2)).sum()
+ chisq = (weight * pow(y - yfit, 2)).sum()
absdeltachi = chisq0 - chisq
if absdeltachi < 0:
flambda *= 10.0
@@ -440,7 +461,9 @@ def leastsq(model, xdata, ydata, p0, sigma=None,
iiter = 0
elif absdeltachi < numpy.sqrt(epsfcn):
iiter = 0
- _logger.info("Iteration finished due to too small absolute chi decrement")
+ _logger.info(
+ "Iteration finished due to too small absolute chi decrement"
+ )
chisq0 = chisq
flambda = flambda / 10.0
last_evaluation = yfit
@@ -462,13 +485,18 @@ def leastsq(model, xdata, ydata, p0, sigma=None,
new_constraints[idx][1] = 0
new_constraints[idx][2] = 0
chisq, alpha, beta, internal_output = chisq_alpha_beta(
- model, fittedpar,
- x, y, weight, constraints=new_constraints,
- model_deriv=model_deriv,
- epsfcn=epsfcn,
- left_derivative=left_derivative,
- last_evaluation=last_evaluation,
- full_output=True)
+ model,
+ fittedpar,
+ x,
+ y,
+ weight,
+ constraints=new_constraints,
+ model_deriv=model_deriv,
+ epsfcn=epsfcn,
+ left_derivative=left_derivative,
+ last_evaluation=last_evaluation,
+ full_output=True,
+ )
# obtained chisq should be identical to chisq0
try:
cov = inv(alpha)
@@ -478,7 +506,9 @@ def leastsq(model, xdata, ydata, p0, sigma=None,
if cov is not None:
for idx, value in enumerate(flag_special):
if value in [CFIXED, CIGNORED]:
- cov = numpy.insert(numpy.insert(cov, idx, 0, axis=1), idx, 0, axis=0)
+ cov = numpy.insert(
+ numpy.insert(cov, idx, 0, axis=1), idx, 0, axis=0
+ )
cov[idx, idx] = fittedpar[idx] * fittedpar[idx]
if not full_output:
@@ -488,18 +518,32 @@ def leastsq(model, xdata, ydata, p0, sigma=None,
sigmapar = _get_sigma_parameters(fittedpar, sigma0, constraints)
ddict = {}
ddict["chisq"] = chisq0
- ddict["reduced_chisq"] = chisq0 / (len(yfit)-n_free)
+ ddict["reduced_chisq"] = chisq0 / (len(yfit) - n_free)
ddict["covariance"] = cov0
ddict["uncertainties"] = sigmapar
ddict["fvec"] = last_evaluation
ddict["nfev"] = function_call_counter
ddict["niter"] = iteration_counter
- return fittedpar, cov, ddict #, chisq/(len(yfit)-len(sigma0)), sigmapar,niter,lastdeltachi
-
-def chisq_alpha_beta(model, parameters, x, y, weight, constraints=None,
- model_deriv=None, epsfcn=None, left_derivative=False,
- last_evaluation=None, full_output=False):
-
+ return (
+ fittedpar,
+ cov,
+ ddict,
+ ) # , chisq/(len(yfit)-len(sigma0)), sigmapar,niter,lastdeltachi
+
+
+def chisq_alpha_beta(
+ model,
+ parameters,
+ x,
+ y,
+ weight,
+ constraints=None,
+ model_deriv=None,
+ epsfcn=None,
+ left_derivative=False,
+ last_evaluation=None,
+ full_output=False,
+):
"""
Get chi square, the curvature matrix alpha and the matrix beta according to the input parameters.
If all the parameters are unconstrained, the covariance matrix is the inverse of the alpha matrix.
@@ -597,10 +641,10 @@ def chisq_alpha_beta(model, parameters, x, y, weight, constraints=None,
epsfcn = numpy.finfo(numpy.float64).eps
else:
epsfcn = max(epsfcn, numpy.finfo(numpy.float64).eps)
- #nr0, nc = data.shape
+ # nr0, nc = data.shape
n_param = len(parameters)
if constraints is None:
- derivfactor = numpy.ones((n_param, ))
+ derivfactor = numpy.ones((n_param,))
n_free = n_param
noigno = numpy.arange(n_param)
free_index = noigno * 1
@@ -615,30 +659,34 @@ def chisq_alpha_beta(model, parameters, x, y, weight, constraints=None,
if constraints[i][0] != CIGNORED:
noigno.append(i)
if constraints[i][0] == CFREE:
- fitparam.append(parameters [i])
+ fitparam.append(parameters[i])
derivfactor.append(1.0)
free_index.append(i)
n_free += 1
elif constraints[i][0] == CPOSITIVE:
fitparam.append(abs(parameters[i]))
derivfactor.append(1.0)
- #fitparam.append(numpy.sqrt(abs(parameters[i])))
- #derivfactor.append(2.0*numpy.sqrt(abs(parameters[i])))
+ # fitparam.append(numpy.sqrt(abs(parameters[i])))
+ # derivfactor.append(2.0*numpy.sqrt(abs(parameters[i])))
free_index.append(i)
n_free += 1
elif constraints[i][0] == CQUOTED:
pmax = max(constraints[i][1], constraints[i][2])
- pmin =min(constraints[i][1], constraints[i][2])
- if ((pmax-pmin) > 0) & \
- (parameters[i] <= pmax) & \
- (parameters[i] >= pmin):
+ pmin = min(constraints[i][1], constraints[i][2])
+ if (
+ ((pmax - pmin) > 0)
+ & (parameters[i] <= pmax)
+ & (parameters[i] >= pmin)
+ ):
A = 0.5 * (pmax + pmin)
B = 0.5 * (pmax - pmin)
fitparam.append(parameters[i])
- derivfactor.append(B*numpy.cos(numpy.arcsin((parameters[i] - A)/B)))
+ derivfactor.append(
+ B * numpy.cos(numpy.arcsin((parameters[i] - A) / B))
+ )
free_index.append(i)
n_free += 1
- elif (pmax-pmin) > 0:
+ elif (pmax - pmin) > 0:
print("WARNING: Quoted parameter outside boundaries")
print("Initial value = %f" % parameters[i])
print("Limits are %f and %f" % (pmin, pmax))
@@ -646,15 +694,15 @@ def chisq_alpha_beta(model, parameters, x, y, weight, constraints=None,
fitparam = numpy.array(fitparam, numpy.float64)
alpha = numpy.zeros((n_free, n_free), numpy.float64)
beta = numpy.zeros((1, n_free), numpy.float64)
- #delta = (fitparam + numpy.equal(fitparam, 0.0)) * 0.00001
+ # delta = (fitparam + numpy.equal(fitparam, 0.0)) * 0.00001
delta = (fitparam + numpy.equal(fitparam, 0.0)) * numpy.sqrt(epsfcn)
- nr = y.size
+ nr = y.size
##############
# Prior to each call to the function one has to re-calculate the
# parameters
pwork = parameters.__copy__()
for i in range(n_free):
- pwork [free_index[i]] = fitparam [i]
+ pwork[free_index[i]] = fitparam[i]
if n_free == 0:
raise ValueError("No free parameters to fit")
function_calls = 0
@@ -667,26 +715,26 @@ def chisq_alpha_beta(model, parameters, x, y, weight, constraints=None,
function_calls += 1
for i in range(n_free):
if model_deriv is None:
- #pwork = parameters.__copy__()
- pwork[free_index[i]] = fitparam [i] + delta [i]
+ # pwork = parameters.__copy__()
+ pwork[free_index[i]] = fitparam[i] + delta[i]
newpar = _get_parameters(pwork.tolist(), constraints)
newpar = numpy.take(newpar, noigno)
f1 = model(x, *newpar)
f1.shape = -1
function_calls += 1
if left_derivative:
- pwork[free_index[i]] = fitparam [i] - delta [i]
+ pwork[free_index[i]] = fitparam[i] - delta[i]
newpar = _get_parameters(pwork.tolist(), constraints)
- newpar=numpy.take(newpar, noigno)
+ newpar = numpy.take(newpar, noigno)
f2 = model(x, *newpar)
function_calls += 1
help0 = (f1 - f2) / (2.0 * delta[i])
else:
help0 = (f1 - f2) / (delta[i])
help0 = help0 * derivfactor[i]
- pwork[free_index[i]] = fitparam [i]
- #removed I resize outside the loop:
- #help0 = numpy.resize(help0, (1, nr))
+ pwork[free_index[i]] = fitparam[i]
+ # removed I resize outside the loop:
+ # help0 = numpy.resize(help0, (1, nr))
else:
help0 = model_deriv(x, pwork, free_index[i])
help0 = help0 * derivfactor[i]
@@ -696,7 +744,7 @@ def chisq_alpha_beta(model, parameters, x, y, weight, constraints=None,
else:
deriv = numpy.concatenate((deriv, help0), 0)
- #line added to resize outside the loop
+ # line added to resize outside the loop
deriv = numpy.resize(deriv, (n_free, nr))
if last_evaluation is None:
if constraints is None:
@@ -719,7 +767,7 @@ def chisq_alpha_beta(model, parameters, x, y, weight, constraints=None,
beta = help1
else:
beta = numpy.concatenate((beta, help1), 1)
- help1 = numpy.inner(deriv, weight*derivi)
+ help1 = numpy.inner(deriv, weight * derivi)
if i == 0:
alpha = help1
else:
@@ -752,13 +800,13 @@ def _get_parameters(parameters, constraints):
if constraints is None:
return parameters * 1
newparam = []
- #first I make the free parameters
- #because the quoted ones put troubles
+ # first I make the free parameters
+ # because the quoted ones put troubles
for i in range(len(constraints)):
if constraints[i][0] == CFREE:
newparam.append(parameters[i])
elif constraints[i][0] == CPOSITIVE:
- #newparam.append(parameters[i] * parameters[i])
+ # newparam.append(parameters[i] * parameters[i])
newparam.append(abs(parameters[i]))
elif constraints[i][0] == CQUOTED:
newparam.append(parameters[i])
@@ -779,7 +827,7 @@ def _get_parameters(parameters, constraints):
# using this module
newparam[i] = 0
elif constraints[i][0] == CSUM:
- newparam[i] = constraints[i][2]-newparam[int(constraints[i][1])]
+ newparam[i] = constraints[i][2] - newparam[int(constraints[i][1])]
return newparam
@@ -805,31 +853,31 @@ def _get_sigma_parameters(parameters, sigma0, constraints):
sigma_par = numpy.zeros(parameters.shape, numpy.float64)
for i in range(len(constraints)):
if constraints[i][0] == CFREE:
- sigma_par [i] = sigma0[n_free]
+ sigma_par[i] = sigma0[n_free]
n_free += 1
elif constraints[i][0] == CPOSITIVE:
- #sigma_par [i] = 2.0 * sigma0[n_free]
- sigma_par [i] = sigma0[n_free]
+ # sigma_par [i] = 2.0 * sigma0[n_free]
+ sigma_par[i] = sigma0[n_free]
n_free += 1
elif constraints[i][0] == CQUOTED:
- pmax = max(constraints [i][1], constraints [i][2])
- pmin = min(constraints [i][1], constraints [i][2])
+ pmax = max(constraints[i][1], constraints[i][2])
+ pmin = min(constraints[i][1], constraints[i][2])
# A = 0.5 * (pmax + pmin)
B = 0.5 * (pmax - pmin)
- if (B > 0) & (parameters [i] < pmax) & (parameters [i] > pmin):
- sigma_par [i] = abs(B * numpy.cos(parameters[i]) * sigma0[n_free])
+ if (B > 0) & (parameters[i] < pmax) & (parameters[i] > pmin):
+ sigma_par[i] = abs(B * numpy.cos(parameters[i]) * sigma0[n_free])
n_free += 1
else:
- sigma_par [i] = parameters[i]
+ sigma_par[i] = parameters[i]
elif abs(constraints[i][0]) == CFIXED:
sigma_par[i] = parameters[i]
for i in range(len(constraints)):
if constraints[i][0] == CFACTOR:
- sigma_par [i] = constraints[i][2]*sigma_par[int(constraints[i][1])]
+ sigma_par[i] = constraints[i][2] * sigma_par[int(constraints[i][1])]
elif constraints[i][0] == CDELTA:
- sigma_par [i] = sigma_par[int(constraints[i][1])]
+ sigma_par[i] = sigma_par[int(constraints[i][1])]
elif constraints[i][0] == CSUM:
- sigma_par [i] = sigma_par[int(constraints[i][1])]
+ sigma_par[i] = sigma_par[int(constraints[i][1])]
return sigma_par
@@ -852,24 +900,29 @@ def main(argv=None):
dummy = 2.3548200450309493 * (t - param[3]) / param[4]
return param[0] + param[1] * t + param[2] * myexp(-0.5 * dummy * dummy)
-
def myexp(x):
# put a (bad) filter to avoid over/underflows
# with no python looping
- return numpy.exp(x * numpy.less(abs(x), 250)) -\
- 1.0 * numpy.greater_equal(abs(x), 250)
+ return numpy.exp(x * numpy.less(abs(x), 250)) - 1.0 * numpy.greater_equal(
+ abs(x), 250
+ )
xx = numpy.arange(npoints, dtype=numpy.float64)
- yy = gauss(xx, *[10.5, 2, 1000.0, 20., 15])
+ yy = gauss(xx, *[10.5, 2, 1000.0, 20.0, 15])
sy = numpy.sqrt(abs(yy))
- parameters = [0.0, 1.0, 900.0, 25., 10]
+ parameters = [0.0, 1.0, 900.0, 25.0, 10]
stime = time.time()
- fittedpar, cov, ddict = leastsq(gauss, xx, yy, parameters,
- sigma=sy,
- left_derivative=False,
- full_output=True,
- check_finite=True)
+ fittedpar, cov, ddict = leastsq(
+ gauss,
+ xx,
+ yy,
+ parameters,
+ sigma=sy,
+ left_derivative=False,
+ full_output=True,
+ check_finite=True,
+ )
etime = time.time()
sigmapars = numpy.sqrt(numpy.diag(cov))
print("Took ", etime - stime, "seconds")
@@ -879,22 +932,20 @@ def main(argv=None):
print("Sigma pars = ", sigmapars)
try:
from scipy.optimize import curve_fit as cfit
+
SCIPY = True
except ImportError:
SCIPY = False
if SCIPY:
counter = 0
stime = time.time()
- scipy_fittedpar, scipy_cov = cfit(gauss,
- xx,
- yy,
- parameters,
- sigma=sy)
+ scipy_fittedpar, scipy_cov = cfit(gauss, xx, yy, parameters, sigma=sy)
etime = time.time()
print("Scipy Took ", etime - stime, "seconds")
print("Counter = ", counter)
print("scipy = ", scipy_fittedpar)
print("Sigma = ", numpy.sqrt(numpy.diag(scipy_cov)))
+
if __name__ == "__main__":
main()
diff --git a/src/silx/math/fit/test/test_bgtheories.py b/src/silx/math/fit/test/test_bgtheories.py
index 40f0831..8dd8d81 100644
--- a/src/silx/math/fit/test/test_bgtheories.py
+++ b/src/silx/math/fit/test/test_bgtheories.py
@@ -30,13 +30,13 @@ from silx.math.fit.functions import sum_gauss
class TestBgTheories(unittest.TestCase):
- """
- """
+ """ """
+
def setUp(self):
self.x = numpy.arange(100)
- self.y = 10 + 0.05 * self.x + sum_gauss(self.x, 10., 45., 15.)
+ self.y = 10 + 0.05 * self.x + sum_gauss(self.x, 10.0, 45.0, 15.0)
# add a very narrow high amplitude peak to test strip and snip
- self.y += sum_gauss(self.x, 100., 75., 2.)
+ self.y += sum_gauss(self.x, 100.0, 75.0, 2.0)
self.narrow_peak_index = list(self.x).index(75)
random.seed()
@@ -46,46 +46,47 @@ class TestBgTheories(unittest.TestCase):
def testTheoriesAttrs(self):
for theory_name in bgtheories.THEORY:
self.assertIsInstance(theory_name, str)
- self.assertTrue(hasattr(bgtheories.THEORY[theory_name],
- "function"))
- self.assertTrue(hasattr(bgtheories.THEORY[theory_name].function,
- "__call__"))
+ self.assertTrue(hasattr(bgtheories.THEORY[theory_name], "function"))
+ self.assertTrue(
+ hasattr(bgtheories.THEORY[theory_name].function, "__call__")
+ )
# Ensure legacy functions are not renamed accidentally
self.assertTrue(
- {"No Background", "Constant", "Linear", "Strip", "Snip"}.issubset(
- set(bgtheories.THEORY)))
+ {"No Background", "Constant", "Linear", "Strip", "Snip"}.issubset(
+ set(bgtheories.THEORY)
+ )
+ )
def testNoBg(self):
nobgfun = bgtheories.THEORY["No Background"].function
- self.assertTrue(numpy.array_equal(nobgfun(self.x, self.y),
- numpy.zeros_like(self.x)))
+ self.assertTrue(
+ numpy.array_equal(nobgfun(self.x, self.y), numpy.zeros_like(self.x))
+ )
# default estimate
- self.assertEqual(bgtheories.THEORY["No Background"].estimate(self.x, self.y),
- ([], []))
+ self.assertEqual(
+ bgtheories.THEORY["No Background"].estimate(self.x, self.y), ([], [])
+ )
def testConstant(self):
consfun = bgtheories.THEORY["Constant"].function
c = random.random() * 100
- self.assertTrue(numpy.array_equal(consfun(self.x, self.y, c),
- c * numpy.ones_like(self.x)))
+ self.assertTrue(
+ numpy.array_equal(consfun(self.x, self.y, c), c * numpy.ones_like(self.x))
+ )
# default estimate
esti_par, cons = bgtheories.THEORY["Constant"].estimate(self.x, self.y)
- self.assertEqual(cons,
- [[0, 0, 0]])
- self.assertAlmostEqual(esti_par,
- min(self.y))
+ self.assertEqual(cons, [[0, 0, 0]])
+ self.assertAlmostEqual(esti_par, min(self.y))
def testLinear(self):
linfun = bgtheories.THEORY["Linear"].function
a = random.random() * 100
b = random.random() * 100
- self.assertTrue(numpy.array_equal(linfun(self.x, self.y, a, b),
- a + b * self.x))
+ self.assertTrue(numpy.array_equal(linfun(self.x, self.y, a, b), a + b * self.x))
# default estimate
esti_par, cons = bgtheories.THEORY["Linear"].estimate(self.x, self.y)
- self.assertEqual(cons,
- [[0, 0, 0], [0, 0, 0]])
+ self.assertEqual(cons, [[0, 0, 0], [0, 0, 0]])
self.assertAlmostEqual(esti_par[0], 10, places=3)
self.assertAlmostEqual(esti_par[1], 0.05, places=3)
@@ -108,8 +109,7 @@ class TestBgTheories(unittest.TestCase):
bg = stripfun(self.x, self.y, width, niter)
# assert peak amplitude has been decreased
- self.assertLess(bg[self.narrow_peak_index],
- self.y[self.narrow_peak_index])
+ self.assertLess(bg[self.narrow_peak_index], self.y[self.narrow_peak_index])
# default estimate
for i in anchors_indices:
@@ -138,9 +138,11 @@ class TestBgTheories(unittest.TestCase):
bg = snipfun(self.x, self.y, width)
# assert peak amplitude has been decreased
- self.assertLess(bg[self.narrow_peak_index],
- self.y[self.narrow_peak_index],
- "Snip didn't decrease the peak amplitude.")
+ self.assertLess(
+ bg[self.narrow_peak_index],
+ self.y[self.narrow_peak_index],
+ "Snip didn't decrease the peak amplitude.",
+ )
# anchored data must remain fixed
for i in anchors_indices:
diff --git a/src/silx/math/fit/test/test_filters.py b/src/silx/math/fit/test/test_filters.py
index 5b8b070..645991e 100644
--- a/src/silx/math/fit/test/test_filters.py
+++ b/src/silx/math/fit/test/test_filters.py
@@ -35,35 +35,70 @@ class TestSmooth(unittest.TestCase):
noise and the result of smoothing that signal is less than 5%. We compare
the sum of all samples in each curve.
"""
+
def setUp(self):
x = numpy.arange(5000)
# (height1, center1, fwhm1, beamfwhm...)
- slit_params = (50, 500, 200, 100,
- 50, 600, 80, 30,
- 20, 2000, 150, 150,
- 50, 2250, 110, 100,
- 40, 3000, 50, 10,
- 23, 4980, 250, 20)
+ slit_params = (
+ 50,
+ 500,
+ 200,
+ 100,
+ 50,
+ 600,
+ 80,
+ 30,
+ 20,
+ 2000,
+ 150,
+ 150,
+ 50,
+ 2250,
+ 110,
+ 100,
+ 40,
+ 3000,
+ 50,
+ 10,
+ 23,
+ 4980,
+ 250,
+ 20,
+ )
self.y1 = functions.sum_slit(x, *slit_params)
# 5% noise
- self.y1 = add_relative_noise(self.y1, 5.)
+ self.y1 = add_relative_noise(self.y1, 5.0)
# (height1, center1, fwhm1...)
- step_params = (50, 500, 200,
- 50, 600, 80,
- 20, 2000, 150,
- 50, 2250, 110,
- 40, 3000, 50,
- 23, 4980, 250,)
+ step_params = (
+ 50,
+ 500,
+ 200,
+ 50,
+ 600,
+ 80,
+ 20,
+ 2000,
+ 150,
+ 50,
+ 2250,
+ 110,
+ 40,
+ 3000,
+ 50,
+ 23,
+ 4980,
+ 250,
+ )
self.y2 = functions.sum_stepup(x, *step_params)
# 5% noise
- self.y2 = add_relative_noise(self.y2, 5.)
+ self.y2 = add_relative_noise(self.y2, 5.0)
self.y3 = functions.sum_stepdown(x, *step_params)
# 5% noise
- self.y3 = add_relative_noise(self.y3, 5.)
+ self.y3 = add_relative_noise(self.y3, 5.0)
def tearDown(self):
pass
@@ -76,9 +111,12 @@ class TestSmooth(unittest.TestCase):
# we added +-5% of random noise. The difference must be much lower
# than 5%.
diff = abs(sum(smoothed_y) - sum(y)) / sum(y)
- self.assertLess(diff, 0.05,
- "Difference between data with 5%% noise and " +
- "smoothed data is > 5%% (%f %%)" % (diff * 100))
+ self.assertLess(
+ diff,
+ 0.05,
+ "Difference between data with 5%% noise and "
+ + "smoothed data is > 5%% (%f %%)" % (diff * 100),
+ )
# Try various smoothing levels
npts += 25
@@ -89,8 +127,9 @@ class TestSmooth(unittest.TestCase):
smoothed_y = filters.smooth1d(self.y1)
for i in range(1, len(self.y1) - 1):
- self.assertAlmostEqual(4 * smoothed_y[i],
- self.y1[i-1] + 2 * self.y1[i] + self.y1[i+1])
+ self.assertAlmostEqual(
+ 4 * smoothed_y[i], self.y1[i - 1] + 2 * self.y1[i] + self.y1[i + 1]
+ )
def testSmooth2d(self):
"""Test that a 2D smoothing is the same as two successive and
@@ -117,5 +156,4 @@ class TestSmooth(unittest.TestCase):
for i in range(0, y.shape[0]):
for j in range(0, y.shape[1]):
- self.assertAlmostEqual(smoothed_y[i, j],
- expected_smooth[i, j])
+ self.assertAlmostEqual(smoothed_y[i, j], expected_smooth[i, j])
diff --git a/src/silx/math/fit/test/test_fit.py b/src/silx/math/fit/test/test_fit.py
index 39a04f9..a25a94b 100644
--- a/src/silx/math/fit/test/test_fit.py
+++ b/src/silx/math/fit/test/test_fit.py
@@ -43,6 +43,7 @@ class Test_leastsq(unittest.TestCase):
def setUp(self):
try:
from silx.math.fit import leastsq
+
self.instance = leastsq
except ImportError:
self.instance = None
@@ -50,9 +51,10 @@ class Test_leastsq(unittest.TestCase):
def myexp(x):
# put a (bad) filter to avoid over/underflows
# with no python looping
- with numpy.errstate(invalid='ignore'):
- return numpy.exp(x*numpy.less(abs(x), 250)) - \
- 1.0 * numpy.greater_equal(abs(x), 250)
+ with numpy.errstate(invalid="ignore"):
+ return numpy.exp(
+ x * numpy.less(abs(x), 250)
+ ) - 1.0 * numpy.greater_equal(abs(x), 250)
self.my_exp = myexp
@@ -60,8 +62,8 @@ class Test_leastsq(unittest.TestCase):
params = numpy.array(params, copy=False, dtype=numpy.float64)
result = params[0] + params[1] * x
for i in range(2, len(params), 3):
- p = params[i:(i+3)]
- dummy = 2.3548200450309493*(x - p[1])/p[2]
+ p = params[i : (i + 3)]
+ dummy = 2.3548200450309493 * (x - p[1]) / p[2]
result += p[0] * self.my_exp(-0.5 * dummy * dummy)
return result
@@ -75,17 +77,17 @@ class Test_leastsq(unittest.TestCase):
gaussian_peak = (idx - 2) // 3
gaussian_parameter = (idx - 2) % 3
actual_idx = 2 + 3 * gaussian_peak
- p = params[actual_idx:(actual_idx+3)]
+ p = params[actual_idx : (actual_idx + 3)]
if gaussian_parameter == 0:
return self.gauss(x, *[0, 0, 1.0, p[1], p[2]])
if gaussian_parameter == 1:
tmp = self.gauss(x, *[0, 0, p[0], p[1], p[2]])
- tmp *= 2.3548200450309493*(x - p[1])/p[2]
- return tmp * 2.3548200450309493/p[2]
+ tmp *= 2.3548200450309493 * (x - p[1]) / p[2]
+ return tmp * 2.3548200450309493 / p[2]
if gaussian_parameter == 2:
tmp = self.gauss(x, *[0, 0, p[0], p[1], p[2]])
- tmp *= 2.3548200450309493*(x - p[1])/p[2]
- return tmp * 2.3548200450309493*(x - p[1])/(p[2]*p[2])
+ tmp *= 2.3548200450309493 * (x - p[1]) / p[2]
+ return tmp * 2.3548200450309493 * (x - p[1]) / (p[2] * p[2])
self.gauss_derivative = gauss_derivative
@@ -98,14 +100,15 @@ class Test_leastsq(unittest.TestCase):
self.model_derivative = None
def testImport(self):
- self.assertTrue(self.instance is not None,
- "Cannot import leastsq from silx.math.fit")
+ self.assertTrue(
+ self.instance is not None, "Cannot import leastsq from silx.math.fit"
+ )
def testUnconstrainedFitNoWeight(self):
- parameters_actual = [10.5, 2, 1000.0, 20., 15]
- x = numpy.arange(10000.)
+ parameters_actual = [10.5, 2, 1000.0, 20.0, 15]
+ x = numpy.arange(10000.0)
y = self.gauss(x, *parameters_actual)
- parameters_estimate = [0.0, 1.0, 900.0, 25., 10]
+ parameters_estimate = [0.0, 1.0, 900.0, 25.0, 10]
model_function = self.gauss
fittedpar, cov = self.instance(model_function, x, y, parameters_estimate)
@@ -113,32 +116,36 @@ class Test_leastsq(unittest.TestCase):
if not test_condition:
msg = "Unsuccessfull fit\n"
for i in range(len(fittedpar)):
- msg += "Expected %g obtained %g\n" % (parameters_actual[i],
- fittedpar[i])
+ msg += "Expected %g obtained %g\n" % (
+ parameters_actual[i],
+ fittedpar[i],
+ )
self.assertTrue(test_condition, msg)
def testUnconstrainedFitWeight(self):
- parameters_actual = [10.5,2,1000.0,20.,15]
- x = numpy.arange(10000.)
+ parameters_actual = [10.5, 2, 1000.0, 20.0, 15]
+ x = numpy.arange(10000.0)
y = self.gauss(x, *parameters_actual)
sigma = numpy.sqrt(y)
- parameters_estimate = [0.0, 1.0, 900.0, 25., 10]
+ parameters_estimate = [0.0, 1.0, 900.0, 25.0, 10]
model_function = self.gauss
- fittedpar, cov = self.instance(model_function, x, y,
- parameters_estimate,
- sigma=sigma)
+ fittedpar, cov = self.instance(
+ model_function, x, y, parameters_estimate, sigma=sigma
+ )
test_condition = numpy.allclose(parameters_actual, fittedpar)
if not test_condition:
msg = "Unsuccessfull fit\n"
for i in range(len(fittedpar)):
- msg += "Expected %g obtained %g\n" % (parameters_actual[i],
- fittedpar[i])
+ msg += "Expected %g obtained %g\n" % (
+ parameters_actual[i],
+ fittedpar[i],
+ )
self.assertTrue(test_condition, msg)
def testDerivativeFunction(self):
- parameters_actual = [10.5, 2, 10000.0, 20., 150, 5000, 900., 300]
- x = numpy.arange(10000.)
+ parameters_actual = [10.5, 2, 10000.0, 20.0, 150, 5000, 900.0, 300]
+ x = numpy.arange(10000.0)
y = self.gauss(x, *parameters_actual)
delta = numpy.sqrt(numpy.finfo(numpy.float64).eps)
for i in range(len(parameters_actual)):
@@ -155,44 +162,47 @@ class Test_leastsq(unittest.TestCase):
p[i] = parameters_actual[i] - delta_par
yMinus = self.gauss(x, *p)
numerical_derivative = (yPlus - yMinus) / (2 * delta_par)
- #numerical_derivative = (self.gauss(x, *p) - y) / delta_par
+ # numerical_derivative = (self.gauss(x, *p) - y) / delta_par
p[i] = parameters_actual[i]
derivative = self.gauss_derivative(x, p, i)
diff = numerical_derivative - derivative
- test_condition = numpy.allclose(numerical_derivative,
- derivative, atol=5.0e-6)
+ test_condition = numpy.allclose(
+ numerical_derivative, derivative, atol=5.0e-6
+ )
if not test_condition:
msg = "Error calculating derivative of parameter %d." % i
msg += "\n diff min = %g diff max = %g" % (diff.min(), diff.max())
self.assertTrue(test_condition, msg)
def testConstrainedFit(self):
- CFREE = 0
- CPOSITIVE = 1
- CQUOTED = 2
- CFIXED = 3
- CFACTOR = 4
- CDELTA = 5
- CSUM = 6
- parameters_actual = [10.5, 2, 10000.0, 20., 150, 5000, 900., 300]
- x = numpy.arange(10000.)
+ CFREE = 0
+ CPOSITIVE = 1
+ CQUOTED = 2
+ CFIXED = 3
+ CFACTOR = 4
+ CDELTA = 5
+ CSUM = 6
+ parameters_actual = [10.5, 2, 10000.0, 20.0, 150, 5000, 900.0, 300]
+ x = numpy.arange(10000.0)
y = self.gauss(x, *parameters_actual)
- parameters_estimate = [0.0, 1.0, 900.0, 25., 10, 400, 850, 200]
+ parameters_estimate = [0.0, 1.0, 900.0, 25.0, 10, 400, 850, 200]
model_function = self.gauss
model_deriv = self.gauss_derivative
constraints_all_free = [[0, 0, 0]] * len(parameters_actual)
constraints_all_positive = [[1, 0, 0]] * len(parameters_actual)
constraints_delta_position = [[0, 0, 0]] * len(parameters_actual)
constraints_delta_position[6] = [CDELTA, 3, 880]
- constraints_sum_position = constraints_all_positive * 1
+ constraints_sum_position = constraints_all_positive * 1
constraints_sum_position[6] = [CSUM, 3, 920]
constraints_factor = constraints_delta_position * 1
constraints_factor[2] = [CFACTOR, 5, 2]
- constraints_list = [None,
- constraints_all_free,
- constraints_all_positive,
- constraints_delta_position,
- constraints_sum_position]
+ constraints_list = [
+ None,
+ constraints_all_free,
+ constraints_all_positive,
+ constraints_delta_position,
+ constraints_sum_position,
+ ]
# for better code coverage, the warning recommending to set full_output
# to True when using constraints should be shown at least once
@@ -203,152 +213,176 @@ class Test_leastsq(unittest.TestCase):
elif index == 3:
full_output = 0
for model_deriv in [None, self.gauss_derivative]:
- for sigma in [None, numpy.sqrt(y)]:
- fittedpar, cov = self.instance(model_function, x, y,
- parameters_estimate,
- sigma=sigma,
- constraints=constraints,
- model_deriv=model_deriv,
- full_output=full_output)[:2]
+ for sigma in [None, numpy.sqrt(y)]:
+ fittedpar, cov = self.instance(
+ model_function,
+ x,
+ y,
+ parameters_estimate,
+ sigma=sigma,
+ constraints=constraints,
+ model_deriv=model_deriv,
+ full_output=full_output,
+ )[:2]
full_output = True
test_condition = numpy.allclose(parameters_actual, fittedpar)
if not test_condition:
msg = "Unsuccessfull fit\n"
for i in range(len(fittedpar)):
- msg += "Expected %g obtained %g\n" % (parameters_actual[i],
- fittedpar[i])
+ msg += "Expected %g obtained %g\n" % (
+ parameters_actual[i],
+ fittedpar[i],
+ )
self.assertTrue(test_condition, msg)
def testUnconstrainedFitAnalyticalDerivative(self):
- parameters_actual = [10.5, 2, 1000.0, 20., 15]
- x = numpy.arange(10000.)
+ parameters_actual = [10.5, 2, 1000.0, 20.0, 15]
+ x = numpy.arange(10000.0)
y = self.gauss(x, *parameters_actual)
sigma = numpy.sqrt(y)
- parameters_estimate = [0.0, 1.0, 900.0, 25., 10]
+ parameters_estimate = [0.0, 1.0, 900.0, 25.0, 10]
model_function = self.gauss
model_deriv = self.gauss_derivative
- fittedpar, cov = self.instance(model_function, x, y,
- parameters_estimate,
- sigma=sigma,
- model_deriv=model_deriv)
+ fittedpar, cov = self.instance(
+ model_function,
+ x,
+ y,
+ parameters_estimate,
+ sigma=sigma,
+ model_deriv=model_deriv,
+ )
test_condition = numpy.allclose(parameters_actual, fittedpar)
if not test_condition:
msg = "Unsuccessfull fit\n"
for i in range(len(fittedpar)):
- msg += "Expected %g obtained %g\n" % (parameters_actual[i],
- fittedpar[i])
+ msg += "Expected %g obtained %g\n" % (
+ parameters_actual[i],
+ fittedpar[i],
+ )
self.assertTrue(test_condition, msg)
@testutils.validate_logging(fitlogger.name, warning=2)
def testBadlyShapedData(self):
- parameters_actual = [10.5, 2, 1000.0, 20., 15]
- x = numpy.arange(10000.).reshape(1000, 10)
+ parameters_actual = [10.5, 2, 1000.0, 20.0, 15]
+ x = numpy.arange(10000.0).reshape(1000, 10)
y = self.gauss(x, *parameters_actual)
sigma = numpy.sqrt(y)
- parameters_estimate = [0.0, 1.0, 900.0, 25., 10]
+ parameters_estimate = [0.0, 1.0, 900.0, 25.0, 10]
model_function = self.gauss
for check_finite in [True, False]:
- fittedpar, cov = self.instance(model_function, x, y,
- parameters_estimate,
- sigma=sigma,
- check_finite=check_finite)
+ fittedpar, cov = self.instance(
+ model_function,
+ x,
+ y,
+ parameters_estimate,
+ sigma=sigma,
+ check_finite=check_finite,
+ )
test_condition = numpy.allclose(parameters_actual, fittedpar)
if not test_condition:
msg = "Unsuccessfull fit\n"
for i in range(len(fittedpar)):
- msg += "Expected %g obtained %g\n" % (parameters_actual[i],
- fittedpar[i])
+ msg += "Expected %g obtained %g\n" % (
+ parameters_actual[i],
+ fittedpar[i],
+ )
self.assertTrue(test_condition, msg)
@testutils.validate_logging(fitlogger.name, warning=3)
def testDataWithNaN(self):
- parameters_actual = [10.5, 2, 1000.0, 20., 15]
- x = numpy.arange(10000.).reshape(1000, 10)
+ parameters_actual = [10.5, 2, 1000.0, 20.0, 15]
+ x = numpy.arange(10000.0).reshape(1000, 10)
y = self.gauss(x, *parameters_actual)
sigma = numpy.sqrt(y)
- parameters_estimate = [0.0, 1.0, 900.0, 25., 10]
+ parameters_estimate = [0.0, 1.0, 900.0, 25.0, 10]
model_function = self.gauss
x[500] = numpy.inf
# check default behavior
try:
- self.instance(model_function, x, y,
- parameters_estimate,
- sigma=sigma)
+ self.instance(model_function, x, y, parameters_estimate, sigma=sigma)
except ValueError:
info = "%s" % sys.exc_info()[1]
self.assertTrue("array must not contain inf" in info)
# check requested behavior
try:
- self.instance(model_function, x, y,
- parameters_estimate,
- sigma=sigma,
- check_finite=True)
+ self.instance(
+ model_function,
+ x,
+ y,
+ parameters_estimate,
+ sigma=sigma,
+ check_finite=True,
+ )
except ValueError:
info = "%s" % sys.exc_info()[1]
self.assertTrue("array must not contain inf" in info)
- fittedpar, cov = self.instance(model_function, x, y,
- parameters_estimate,
- sigma=sigma,
- check_finite=False)
+ fittedpar, cov = self.instance(
+ model_function, x, y, parameters_estimate, sigma=sigma, check_finite=False
+ )
test_condition = numpy.allclose(parameters_actual, fittedpar)
if not test_condition:
msg = "Unsuccessfull fit\n"
for i in range(len(fittedpar)):
- msg += "Expected %g obtained %g\n" % (parameters_actual[i],
- fittedpar[i])
+ msg += "Expected %g obtained %g\n" % (
+ parameters_actual[i],
+ fittedpar[i],
+ )
self.assertTrue(test_condition, msg)
# testing now with ydata containing NaN
- x = numpy.arange(10000.).reshape(1000, 10)
+ x = numpy.arange(10000.0).reshape(1000, 10)
y[500] = numpy.nan
- fittedpar, cov = self.instance(model_function, x, y,
- parameters_estimate,
- sigma=sigma,
- check_finite=False)
+ fittedpar, cov = self.instance(
+ model_function, x, y, parameters_estimate, sigma=sigma, check_finite=False
+ )
test_condition = numpy.allclose(parameters_actual, fittedpar)
if not test_condition:
msg = "Unsuccessfull fit\n"
for i in range(len(fittedpar)):
- msg += "Expected %g obtained %g\n" % (parameters_actual[i],
- fittedpar[i])
+ msg += "Expected %g obtained %g\n" % (
+ parameters_actual[i],
+ fittedpar[i],
+ )
self.assertTrue(test_condition, msg)
# testing now with sigma containing NaN
sigma[300] = numpy.nan
- fittedpar, cov = self.instance(model_function, x, y,
- parameters_estimate,
- sigma=sigma,
- check_finite=False)
+ fittedpar, cov = self.instance(
+ model_function, x, y, parameters_estimate, sigma=sigma, check_finite=False
+ )
test_condition = numpy.allclose(parameters_actual, fittedpar)
if not test_condition:
msg = "Unsuccessfull fit\n"
for i in range(len(fittedpar)):
- msg += "Expected %g obtained %g\n" % (parameters_actual[i],
- fittedpar[i])
+ msg += "Expected %g obtained %g\n" % (
+ parameters_actual[i],
+ fittedpar[i],
+ )
self.assertTrue(test_condition, msg)
def testUncertainties(self):
"""Test for validity of uncertainties in returned full-output
dictionary. This is a non-regression test for pull request #197"""
- parameters_actual = [10.5, 2, 1000.0, 20., 15, 2001.0, 30.1, 16]
- x = numpy.arange(10000.)
+ parameters_actual = [10.5, 2, 1000.0, 20.0, 15, 2001.0, 30.1, 16]
+ x = numpy.arange(10000.0)
y = self.gauss(x, *parameters_actual)
- parameters_estimate = [0.0, 1.0, 900.0, 25., 10., 1500., 20., 2.0]
+ parameters_estimate = [0.0, 1.0, 900.0, 25.0, 10.0, 1500.0, 20.0, 2.0]
# test that uncertainties are not 0.
- fittedpar, cov, infodict = self.instance(self.gauss, x, y, parameters_estimate,
- full_output=True)
+ fittedpar, cov, infodict = self.instance(
+ self.gauss, x, y, parameters_estimate, full_output=True
+ )
uncertainties = infodict["uncertainties"]
self.assertEqual(len(uncertainties), len(parameters_actual))
self.assertEqual(len(uncertainties), len(fittedpar))
for uncertainty in uncertainties:
- self.assertNotAlmostEqual(uncertainty, 0.)
+ self.assertNotAlmostEqual(uncertainty, 0.0)
# set constraint FIXED for half the parameters.
# This should cause leastsq to return 100% uncertainty.
@@ -361,12 +395,16 @@ class Test_leastsq(unittest.TestCase):
constraints.append([CFIXED, 0, 0])
else:
constraints.append([CFREE, 0, 0])
- fittedpar, cov, infodict = self.instance(self.gauss, x, y, parameters_estimate,
- constraints=constraints,
- full_output=True)
+ fittedpar, cov, infodict = self.instance(
+ self.gauss,
+ x,
+ y,
+ parameters_estimate,
+ constraints=constraints,
+ full_output=True,
+ )
uncertainties = infodict["uncertainties"]
for i in range(len(parameters_estimate)):
if i % 2:
# test that all FIXED parameters have 100% uncertainty
- self.assertAlmostEqual(uncertainties[i],
- parameters_estimate[i])
+ self.assertAlmostEqual(uncertainties[i], parameters_estimate[i])
diff --git a/src/silx/math/fit/test/test_fitmanager.py b/src/silx/math/fit/test/test_fitmanager.py
index cc35ccf..5229df5 100644
--- a/src/silx/math/fit/test/test_fitmanager.py
+++ b/src/silx/math/fit/test/test_fitmanager.py
@@ -1,5 +1,5 @@
# /*##########################################################################
-# Copyright (C) 2016-2020 European Synchrotron Radiation Facility
+# Copyright (C) 2016-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -114,6 +114,7 @@ class TestFitmanager(ParametricTestCase):
"""
Unit tests of multi-peak functions.
"""
+
def setUp(self):
pass
@@ -126,9 +127,7 @@ class TestFitmanager(ParametricTestCase):
# Create synthetic data with a sum of gaussian functions
x = numpy.arange(1000).astype(numpy.float64)
- p = [1000, 100., 250,
- 255, 650., 45,
- 1500, 800.5, 95]
+ p = [1000, 100.0, 250, 255, 650.0, 45, 1500, 800.5, 95]
linear_bg = 2.65 * x + 13
y = linear_bg + sum_gauss(x, *p)
@@ -139,10 +138,10 @@ class TestFitmanager(ParametricTestCase):
x_with_nans[5::15] = numpy.nan
tests = {
- 'all finite': (x, y),
- 'y with NaNs': (x, y_with_nans),
- 'x with NaNs': (x_with_nans, y),
- }
+ "all finite": (x, y),
+ "y with NaNs": (x, y_with_nans),
+ "x with NaNs": (x_with_nans, y),
+ }
for name, (xdata, ydata) in tests.items():
with self.subTest(name=name):
@@ -151,8 +150,8 @@ class TestFitmanager(ParametricTestCase):
fit.setdata(x=xdata, y=ydata)
fit.loadtheories(fittheories)
# Use one of the default fit functions
- fit.settheory('Gaussians')
- fit.setbackground('Linear')
+ fit.settheory("Gaussians")
+ fit.setbackground("Linear")
fit.estimate()
fit.runfit()
@@ -167,19 +166,17 @@ class TestFitmanager(ParametricTestCase):
for i, param in enumerate(fit.fit_results[2:]):
param_number = i // 3 + 1
if i % 3 == 0:
- self.assertEqual(param["name"],
- "Height%d" % param_number)
+ self.assertEqual(param["name"], "Height%d" % param_number)
elif i % 3 == 1:
- self.assertEqual(param["name"],
- "Position%d" % param_number)
+ self.assertEqual(param["name"], "Position%d" % param_number)
elif i % 3 == 2:
- self.assertEqual(param["name"],
- "FWHM%d" % param_number)
+ self.assertEqual(param["name"], "FWHM%d" % param_number)
- self.assertAlmostEqual(param["fitresult"],
- p[i])
- self.assertAlmostEqual(_order_of_magnitude(param["estimation"]),
- _order_of_magnitude(p[i]))
+ self.assertAlmostEqual(param["fitresult"], p[i])
+ self.assertAlmostEqual(
+ _order_of_magnitude(param["estimation"]),
+ _order_of_magnitude(p[i]),
+ )
def testLoadCustomFitFunction(self):
"""Test FitManager using a custom fit function defined in an external
@@ -198,35 +195,29 @@ class TestFitmanager(ParametricTestCase):
# Create a temporary function definition file, and import it
with temp_dir() as tmpDir:
- tmpfile = os.path.join(tmpDir, 'customfun.py')
+ tmpfile = os.path.join(tmpDir, "customfun.py")
# custom_function_definition
fd = open(tmpfile, "w")
fd.write(custom_function_definition)
fd.close()
fit.loadtheories(tmpfile)
- tmpfile_pyc = os.path.join(tmpDir, 'customfun.pyc')
+ tmpfile_pyc = os.path.join(tmpDir, "customfun.pyc")
if os.path.exists(tmpfile_pyc):
os.unlink(tmpfile_pyc)
os.unlink(tmpfile)
- fit.settheory('my fit theory')
+ fit.settheory("my fit theory")
# Test configure
fit.configure(d=4.5)
fit.estimate()
fit.runfit()
- self.assertEqual(fit.fit_results[0]["name"],
- "A1")
- self.assertAlmostEqual(fit.fit_results[0]["fitresult"],
- 1.5)
- self.assertEqual(fit.fit_results[1]["name"],
- "B1")
- self.assertAlmostEqual(fit.fit_results[1]["fitresult"],
- 2.5)
- self.assertEqual(fit.fit_results[2]["name"],
- "C1")
- self.assertAlmostEqual(fit.fit_results[2]["fitresult"],
- 3.5)
+ self.assertEqual(fit.fit_results[0]["name"], "A1")
+ self.assertAlmostEqual(fit.fit_results[0]["fitresult"], 1.5)
+ self.assertEqual(fit.fit_results[1]["name"], "B1")
+ self.assertAlmostEqual(fit.fit_results[1]["fitresult"], 2.5)
+ self.assertEqual(fit.fit_results[2]["name"], "C1")
+ self.assertAlmostEqual(fit.fit_results[2]["fitresult"], 3.5)
def testLoadOldCustomFitFunction(self):
"""Test FitManager using a custom fit function defined in an external
@@ -245,34 +236,28 @@ class TestFitmanager(ParametricTestCase):
# Create a temporary function definition file, and import it
with temp_dir() as tmpDir:
- tmpfile = os.path.join(tmpDir, 'oldcustomfun.py')
+ tmpfile = os.path.join(tmpDir, "oldcustomfun.py")
# custom_function_definition
fd = open(tmpfile, "w")
fd.write(old_custom_function_definition)
fd.close()
fit.loadtheories(tmpfile)
- tmpfile_pyc = os.path.join(tmpDir, 'oldcustomfun.pyc')
+ tmpfile_pyc = os.path.join(tmpDir, "oldcustomfun.pyc")
if os.path.exists(tmpfile_pyc):
os.unlink(tmpfile_pyc)
os.unlink(tmpfile)
- fit.settheory('my fit theory')
+ fit.settheory("my fit theory")
fit.configure(d=4.5)
fit.estimate()
fit.runfit()
- self.assertEqual(fit.fit_results[0]["name"],
- "A1")
- self.assertAlmostEqual(fit.fit_results[0]["fitresult"],
- 1.5)
- self.assertEqual(fit.fit_results[1]["name"],
- "B1")
- self.assertAlmostEqual(fit.fit_results[1]["fitresult"],
- 2.5)
- self.assertEqual(fit.fit_results[2]["name"],
- "C1")
- self.assertAlmostEqual(fit.fit_results[2]["fitresult"],
- 3.5)
+ self.assertEqual(fit.fit_results[0]["name"], "A1")
+ self.assertAlmostEqual(fit.fit_results[0]["fitresult"], 1.5)
+ self.assertEqual(fit.fit_results[1]["name"], "B1")
+ self.assertAlmostEqual(fit.fit_results[1]["fitresult"], 2.5)
+ self.assertEqual(fit.fit_results[2]["name"], "C1")
+ self.assertAlmostEqual(fit.fit_results[2]["fitresult"], 3.5)
def testAddTheory(self, estimate=True):
"""Test FitManager using a custom fit function imported with
@@ -290,19 +275,19 @@ class TestFitmanager(ParametricTestCase):
fit.setdata(x=x, y=y)
# Define and add the fit theory
- CONFIG = {'d': 1.}
+ CONFIG = {"d": 1.0}
def myfun(x_, a_, b_, c_):
- """"Model function"""
- return (a_ * x_**2 + b_ * x_ + c_) / CONFIG['d']
+ """Model function"""
+ return (a_ * x_**2 + b_ * x_ + c_) / CONFIG["d"]
def myesti(x_, y_):
- """"Initial parameters for iterative fit:
+ """Initial parameters for iterative fit:
(a, b, c) = (1, 1, 1)
Constraints all set to 0 (FREE)"""
- return (1., 1., 1.), ((0, 0, 0), (0, 0, 0), (0, 0, 0))
+ return (1.0, 1.0, 1.0), ((0, 0, 0), (0, 0, 0), (0, 0, 0))
- def myconfig(d_=1., **kw):
+ def myconfig(d_=1.0, **kw):
"""This function can modify CONFIG"""
CONFIG["d"] = d_
return CONFIG
@@ -320,41 +305,41 @@ class TestFitmanager(ParametricTestCase):
return delta_fun / delta_par
- fit.addtheory("polynomial",
- FitTheory(function=myfun,
- parameters=["A", "B", "C"],
- estimate=myesti if estimate else None,
- configure=myconfig,
- derivative=myderiv))
-
- fit.settheory('polynomial')
+ fit.addtheory(
+ "polynomial",
+ FitTheory(
+ function=myfun,
+ parameters=["A", "B", "C"],
+ estimate=myesti if estimate else None,
+ configure=myconfig,
+ derivative=myderiv,
+ ),
+ )
+
+ fit.settheory("polynomial")
fit.configure(d_=4.5)
fit.estimate()
params1, sigmas, infodict = fit.runfit()
- self.assertEqual(fit.fit_results[0]["name"],
- "A1")
- self.assertAlmostEqual(fit.fit_results[0]["fitresult"],
- -3.14)
- self.assertEqual(fit.fit_results[1]["name"],
- "B1")
+ self.assertEqual(fit.fit_results[0]["name"], "A1")
+ self.assertAlmostEqual(fit.fit_results[0]["fitresult"], -3.14)
+ self.assertEqual(fit.fit_results[1]["name"], "B1")
# params1[1] is the same as fit.fit_results[1]["fitresult"]
- self.assertAlmostEqual(params1[1],
- 1234.5)
- self.assertEqual(fit.fit_results[2]["name"],
- "C1")
- self.assertAlmostEqual(params1[2],
- 10000)
+ self.assertAlmostEqual(params1[1], 1234.5)
+ self.assertEqual(fit.fit_results[2]["name"], "C1")
+ self.assertAlmostEqual(params1[2], 10000)
# change configuration scaling factor and check that the fit returns
# different values
- fit.configure(d_=5.)
+ fit.configure(d_=5.0)
fit.estimate()
params2, sigmas, infodict = fit.runfit()
for p1, p2 in zip(params1, params2):
- self.assertFalse(numpy.array_equal(p1, p2),
- "Fit parameters are equal even though the " +
- "configuration has been changed")
+ self.assertFalse(
+ numpy.array_equal(p1, p2),
+ "Fit parameters are equal even though the "
+ + "configuration has been changed",
+ )
def testNoEstimate(self):
"""Ensure that the in the absence of the estimation function,
@@ -365,8 +350,10 @@ class TestFitmanager(ParametricTestCase):
def testStep(self):
"""Test fit manager on a step function with a more complex estimate
function than the gaussian (convolution filter)"""
- for theory_name, theory_fun in (('Step Down', sum_stepdown),
- ('Step Up', sum_stepup)):
+ for theory_name, theory_fun in (
+ ("Step Down", sum_stepdown),
+ ("Step Up", sum_stepup),
+ ):
# Create synthetic data with a sum of gaussian functions
x = numpy.arange(1000).astype(numpy.float64)
@@ -381,7 +368,7 @@ class TestFitmanager(ParametricTestCase):
fit.setdata(x=x, y=y)
fit.loadtheories(fittheories)
fit.settheory(theory_name)
- fit.setbackground('Constant')
+ fit.setbackground("Constant")
fit.estimate()
@@ -391,8 +378,10 @@ class TestFitmanager(ParametricTestCase):
self.assertAlmostEqual(params[0], 13, places=5)
for i, param in enumerate(params[1:]):
self.assertAlmostEqual(param, p[i], places=5)
- self.assertAlmostEqual(_order_of_magnitude(fit.fit_results[i+1]["estimation"]),
- _order_of_magnitude(p[i]))
+ self.assertAlmostEqual(
+ _order_of_magnitude(fit.fit_results[i + 1]["estimation"]),
+ _order_of_magnitude(p[i]),
+ )
def quadratic(x, a, b, c):
@@ -405,6 +394,7 @@ def cubic(x, a, b, c, d):
class TestPolynomials(unittest.TestCase):
"""Test polynomial fit theories and fit background"""
+
def setUp(self):
self.x = numpy.arange(100).astype(numpy.float64)
@@ -424,8 +414,7 @@ class TestPolynomials(unittest.TestCase):
fit_params = fm.runfit()[0]
for p, pfit in zip(poly_params + gaussian_params, fit_params):
- self.assertAlmostEqual(p,
- pfit)
+ self.assertAlmostEqual(p, pfit)
def testCubicBg(self):
gaussian_params = [1000, 45, 8]
@@ -442,8 +431,7 @@ class TestPolynomials(unittest.TestCase):
fit_params = fm.runfit()[0]
for p, pfit in zip(poly_params + gaussian_params, fit_params):
- self.assertAlmostEqual(p,
- pfit)
+ self.assertAlmostEqual(p, pfit)
def testQuarticcBg(self):
gaussian_params = [10000, 69, 25]
@@ -460,9 +448,7 @@ class TestPolynomials(unittest.TestCase):
fit_params = fm.runfit()[0]
for p, pfit in zip(poly_params + gaussian_params, fit_params):
- self.assertAlmostEqual(p,
- pfit,
- places=5)
+ self.assertAlmostEqual(p, pfit, places=5)
def _testPoly(self, poly_params, theory, places=5):
p = numpy.poly1d(poly_params)
@@ -480,18 +466,13 @@ class TestPolynomials(unittest.TestCase):
self.assertAlmostEqual(p, pfit, places=places)
def testQuadratic(self):
- self._testPoly([0.05, -2, 3],
- "Degree 2 Polynomial")
+ self._testPoly([0.05, -2, 3], "Degree 2 Polynomial")
def testCubic(self):
- self._testPoly([0.0005, -0.05, 3, -4],
- "Degree 3 Polynomial")
+ self._testPoly([0.0005, -0.05, 3, -4], "Degree 3 Polynomial")
def testQuartic(self):
- self._testPoly([1, -2, 3, -4, -5],
- "Degree 4 Polynomial")
+ self._testPoly([1, -2, 3, -4, -5], "Degree 4 Polynomial")
def testQuintic(self):
- self._testPoly([1, -2, 3, -4, -5, 6],
- "Degree 5 Polynomial",
- places=4)
+ self._testPoly([1, -2, 3, -4, -5, 6], "Degree 5 Polynomial", places=4)
diff --git a/src/silx/math/fit/test/test_functions.py b/src/silx/math/fit/test/test_functions.py
index 71cce8b..525925c 100644
--- a/src/silx/math/fit/test/test_functions.py
+++ b/src/silx/math/fit/test/test_functions.py
@@ -34,36 +34,59 @@ __authors__ = ["P. Knobel"]
__license__ = "MIT"
__date__ = "21/07/2016"
+
class Test_functions(unittest.TestCase):
"""
Unit tests of multi-peak functions.
"""
+
def setUp(self):
self.x = numpy.arange(11)
# height, center, sigma1, sigma2
- (h, c, s1, s2) = (7., 5., 3., 2.1)
+ (h, c, s1, s2) = (7.0, 5.0, 3.0, 2.1)
self.g_params = {
"height": h,
"center": c,
- #"sigma": s,
+ # "sigma": s,
"fwhm1": 2 * math.sqrt(2 * math.log(2)) * s1,
"fwhm2": 2 * math.sqrt(2 * math.log(2)) * s2,
- "area1": h * s1 * math.sqrt(2 * math.pi)
+ "area1": h * s1 * math.sqrt(2 * math.pi),
}
# result of `7 * scipy.signal.gaussian(11, 3)`
self.scipy_gaussian = numpy.array(
- [1.74546546, 2.87778603, 4.24571462, 5.60516182, 6.62171628,
- 7., 6.62171628, 5.60516182, 4.24571462, 2.87778603,
- 1.74546546]
+ [
+ 1.74546546,
+ 2.87778603,
+ 4.24571462,
+ 5.60516182,
+ 6.62171628,
+ 7.0,
+ 6.62171628,
+ 5.60516182,
+ 4.24571462,
+ 2.87778603,
+ 1.74546546,
+ ]
)
# result of:
# numpy.concatenate((7 * scipy.signal.gaussian(11, 3)[0:5],
# 7 * scipy.signal.gaussian(11, 2.1)[5:11]))
self.scipy_asym_gaussian = numpy.array(
- [1.74546546, 2.87778603, 4.24571462, 5.60516182, 6.62171628,
- 7., 6.24968751, 4.44773692, 2.52313452, 1.14093853, 0.41124877]
+ [
+ 1.74546546,
+ 2.87778603,
+ 4.24571462,
+ 5.60516182,
+ 6.62171628,
+ 7.0,
+ 6.24968751,
+ 4.44773692,
+ 2.52313452,
+ 1.14093853,
+ 0.41124877,
+ ]
)
def tearDown(self):
@@ -71,41 +94,48 @@ class Test_functions(unittest.TestCase):
def testGauss(self):
"""Compare sum_gauss with scipy.signals.gaussian"""
- y = functions.sum_gauss(self.x,
- self.g_params["height"],
- self.g_params["center"],
- self.g_params["fwhm1"])
+ y = functions.sum_gauss(
+ self.x,
+ self.g_params["height"],
+ self.g_params["center"],
+ self.g_params["fwhm1"],
+ )
for i in range(11):
self.assertAlmostEqual(y[i], self.scipy_gaussian[i])
def testAGauss(self):
"""Compare sum_agauss with scipy.signals.gaussian"""
- y = functions.sum_agauss(self.x,
- self.g_params["area1"],
- self.g_params["center"],
- self.g_params["fwhm1"])
+ y = functions.sum_agauss(
+ self.x,
+ self.g_params["area1"],
+ self.g_params["center"],
+ self.g_params["fwhm1"],
+ )
for i in range(11):
self.assertAlmostEqual(y[i], self.scipy_gaussian[i])
def testFastAGauss(self):
"""Compare sum_fastagauss with scipy.signals.gaussian
Limit precision to 3 decimal places."""
- y = functions.sum_fastagauss(self.x,
- self.g_params["area1"],
- self.g_params["center"],
- self.g_params["fwhm1"])
+ y = functions.sum_fastagauss(
+ self.x,
+ self.g_params["area1"],
+ self.g_params["center"],
+ self.g_params["fwhm1"],
+ )
for i in range(11):
self.assertAlmostEqual(y[i], self.scipy_gaussian[i], 3)
-
def testSplitGauss(self):
"""Compare sum_splitgauss with scipy.signals.gaussian"""
- y = functions.sum_splitgauss(self.x,
- self.g_params["height"],
- self.g_params["center"],
- self.g_params["fwhm1"],
- self.g_params["fwhm2"])
+ y = functions.sum_splitgauss(
+ self.x,
+ self.g_params["height"],
+ self.g_params["center"],
+ self.g_params["fwhm1"],
+ self.g_params["fwhm2"],
+ )
for i in range(11):
self.assertAlmostEqual(y[i], self.scipy_asym_gaussian[i])
@@ -120,18 +150,14 @@ class Test_functions(unittest.TestCase):
x = [-5, -2, -1.5, -0.6, 0, 0.1, 2, 3]
erfx = functions.erf(x)
for i in range(len(x)):
- self.assertAlmostEqual(erfx[i],
- math.erf(x[i]),
- places=5)
+ self.assertAlmostEqual(erfx[i], math.erf(x[i]), places=5)
# ndarray
x = numpy.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
erfx = functions.erf(x)
for i in range(x.shape[0]):
for j in range(x.shape[1]):
- self.assertAlmostEqual(erfx[i, j],
- math.erf(x[i, j]),
- places=5)
+ self.assertAlmostEqual(erfx[i, j], math.erf(x[i, j]), places=5)
def testErfc(self):
"""Compare erf with math.erf"""
@@ -162,15 +188,14 @@ class Test_functions(unittest.TestCase):
for x, y in zip(x0, y0):
self.assertAlmostEqual(
- 11.1 * (0.5 + math.atan((x - 22.2) / 3.33) / math.pi),
- y
+ 11.1 * (0.5 + math.atan((x - 22.2) / 3.33) / math.pi), y
)
def testStepUp(self):
"""sanity check for step up:
- - derivative must be largest around the step center
- - max value must be close to height parameter
+ - derivative must be largest around the step center
+ - max value must be close to height parameter
"""
x0 = numpy.arange(1000)
@@ -187,14 +212,13 @@ class Test_functions(unittest.TestCase):
# Test center position within +- 1 sample of max derivative
index_max_deriv = numpy.argmax(deriv0)
- self.assertLess(abs(index_max_deriv - center),
- 1)
+ self.assertLess(abs(index_max_deriv - center), 1)
def testStepDown(self):
"""sanity check for step down:
- - absolute value of derivative must be largest around the step center
- - max value must be close to height parameter
+ - absolute value of derivative must be largest around the step center
+ - max value must be close to height parameter
"""
x0 = numpy.arange(1000)
@@ -207,18 +231,19 @@ class Test_functions(unittest.TestCase):
self.assertAlmostEqual(max(y0), height, places=1)
self.assertAlmostEqual(min(y0), 0, places=1)
- deriv0 = _numerical_derivative(functions.sum_stepdown, x0, [height, center, fwhm])
+ deriv0 = _numerical_derivative(
+ functions.sum_stepdown, x0, [height, center, fwhm]
+ )
# Test center position within +- 1 sample of max derivative
index_min_deriv = numpy.argmax(-deriv0)
- self.assertLess(abs(index_min_deriv - center),
- 1)
+ self.assertLess(abs(index_min_deriv - center), 1)
def testSlit(self):
"""sanity check for slit:
- - absolute value of derivative must be largest around the step center
- - max value must be close to height parameter
+ - absolute value of derivative must be largest around the step center
+ - max value must be close to height parameter
"""
x0 = numpy.arange(1000)
@@ -231,16 +256,16 @@ class Test_functions(unittest.TestCase):
self.assertAlmostEqual(max(y0), height, places=1)
self.assertAlmostEqual(min(y0), 0, places=1)
- deriv0 = _numerical_derivative(functions.sum_slit, x0, [height, center, fwhm, beamfwhm])
+ deriv0 = _numerical_derivative(
+ functions.sum_slit, x0, [height, center, fwhm, beamfwhm]
+ )
# Test step up center position (center - fwhm/2) within +- 1 sample of max derivative
index_max_deriv = numpy.argmax(deriv0)
- self.assertLess(abs(index_max_deriv - (center - fwhm/2)),
- 1)
+ self.assertLess(abs(index_max_deriv - (center - fwhm / 2)), 1)
# Test step down center position (center + fwhm/2) within +- 1 sample of min derivative
index_min_deriv = numpy.argmin(deriv0)
- self.assertLess(abs(index_min_deriv - (center + fwhm/2)),
- 1)
+ self.assertLess(abs(index_min_deriv - (center + fwhm / 2)), 1)
def _numerical_derivative(f, x, params=[], delta_factor=0.0001):
diff --git a/src/silx/math/fit/test/test_peaks.py b/src/silx/math/fit/test/test_peaks.py
index 23e4061..d6b9db5 100644
--- a/src/silx/math/fit/test/test_peaks.py
+++ b/src/silx/math/fit/test/test_peaks.py
@@ -24,108 +24,414 @@
Tests for peaks module
"""
-import unittest
import numpy
-import math
+import pytest
from silx.math.fit import functions
from silx.math.fit import peaks
-class Test_peak_search(unittest.TestCase):
- """
- Unit tests of peak_search on various types of multi-peak functions.
- """
- def setUp(self):
- self.x = numpy.arange(5000)
- # (height1, center1, fwhm1, ...)
- self.h_c_fwhm = (50, 500, 100,
- 50, 600, 80,
- 20, 2000, 100,
- 50, 2250, 110,
- 40, 3000, 99,
- 23, 4980, 80)
- # (height1, center1, fwhm1, eta1 ...)
- self.h_c_fwhm_eta = (50, 500, 100, 0.4,
- 50, 600, 80, 0.5,
- 20, 2000, 100, 0.6,
- 50, 2250, 110, 0.7,
- 40, 3000, 99, 0.8,
- 23, 4980, 80, 0.3,)
- # (height1, center1, fwhm11, fwhm21, ...)
- self.h_c_fwhm_fwhm = (50, 500, 100, 85,
- 50, 600, 80, 110,
- 20, 2000, 100, 100,
- 50, 2250, 110, 99,
- 40, 3000, 99, 110,
- 23, 4980, 80, 80,)
- # (height1, center1, fwhm11, fwhm21, eta1 ...)
- self.h_c_fwhm_fwhm_eta = (50, 500, 100, 85, 0.4,
- 50, 600, 80, 110, 0.5,
- 20, 2000, 100, 100, 0.6,
- 50, 2250, 110, 99, 0.7,
- 40, 3000, 99, 110, 0.8,
- 23, 4980, 80, 80, 0.3,)
- # (area1, center1, fwhm1, ...)
- self.a_c_fwhm = (2550, 500, 100,
- 2000, 600, 80,
- 500, 2000, 100,
- 4000, 2250, 110,
- 2300, 3000, 99,
- 3333, 4980, 80)
- # (area1, center1, fwhm1, eta1 ...)
- self.a_c_fwhm_eta = (500, 500, 100, 0.4,
- 500, 600, 80, 0.5,
- 200, 2000, 100, 0.6,
- 500, 2250, 110, 0.7,
- 400, 3000, 99, 0.8,
- 230, 4980, 80, 0.3,)
- # (area, position, fwhm, st_area_r, st_slope_r, lt_area_r, lt_slope_r, step_height_r)
- self.hypermet_params = (1000, 500, 200, 0.2, 100, 0.3, 100, 0.05,
- 1000, 1000, 200, 0.2, 100, 0.3, 100, 0.05,
- 1000, 2000, 200, 0.2, 100, 0.3, 100, 0.05,
- 1000, 2350, 200, 0.2, 100, 0.3, 100, 0.05,
- 1000, 3000, 200, 0.2, 100, 0.3, 100, 0.05,
- 1000, 4900, 200, 0.2, 100, 0.3, 100, 0.05,)
+_PEAK_PARAMETERS = {
+ "sum_gauss": (
+ 50,
+ 500,
+ 100,
+ 50,
+ 600,
+ 80,
+ 20,
+ 2000,
+ 100,
+ 50,
+ 2250,
+ 110,
+ 40,
+ 3000,
+ 99,
+ 23,
+ 4980,
+ 80,
+ ),
+ "sum_lorentz": (
+ 50,
+ 500,
+ 100,
+ 50,
+ 600,
+ 80,
+ 20,
+ 2000,
+ 100,
+ 50,
+ 2250,
+ 110,
+ 40,
+ 3000,
+ 99,
+ 23,
+ 4980,
+ 80,
+ ),
+ "sum_pvoigt": (
+ 50,
+ 500,
+ 100,
+ 0.4,
+ 50,
+ 600,
+ 80,
+ 0.5,
+ 20,
+ 2000,
+ 100,
+ 0.6,
+ 50,
+ 2250,
+ 110,
+ 0.7,
+ 40,
+ 3000,
+ 99,
+ 0.8,
+ 23,
+ 4980,
+ 80,
+ 0.3,
+ ),
+ "sum_splitgauss": (
+ 50,
+ 500,
+ 100,
+ 85,
+ 50,
+ 600,
+ 80,
+ 110,
+ 20,
+ 2000,
+ 100,
+ 100,
+ 50,
+ 2250,
+ 110,
+ 99,
+ 40,
+ 3000,
+ 99,
+ 110,
+ 23,
+ 4980,
+ 80,
+ 80,
+ ),
+ "sum_splitlorentz": (
+ 50,
+ 500,
+ 100,
+ 85,
+ 50,
+ 600,
+ 80,
+ 110,
+ 20,
+ 2000,
+ 100,
+ 100,
+ 50,
+ 2250,
+ 110,
+ 99,
+ 40,
+ 3000,
+ 99,
+ 110,
+ 23,
+ 4980,
+ 80,
+ 80,
+ ),
+ "sum_splitpvoigt": (
+ 50,
+ 500,
+ 100,
+ 85,
+ 0.4,
+ 50,
+ 600,
+ 80,
+ 110,
+ 0.5,
+ 20,
+ 2000,
+ 100,
+ 100,
+ 0.6,
+ 50,
+ 2250,
+ 110,
+ 99,
+ 0.7,
+ 40,
+ 3000,
+ 99,
+ 110,
+ 0.8,
+ 23,
+ 4980,
+ 80,
+ 80,
+ 0.3,
+ ),
+ "sum_splitpvoigt2": (
+ 50,
+ 500,
+ 100,
+ 85,
+ 0.4,
+ 0.7,
+ 50,
+ 600,
+ 80,
+ 110,
+ 0.5,
+ 0.3,
+ 20,
+ 2000,
+ 100,
+ 100,
+ 0.6,
+ 0.4,
+ 50,
+ 2250,
+ 110,
+ 99,
+ 0.7,
+ 1,
+ 40,
+ 3000,
+ 99,
+ 110,
+ 0.8,
+ 0,
+ 23,
+ 4980,
+ 80,
+ 80,
+ 0.3,
+ 0.5,
+ ),
+ "sum_agauss": (
+ 2550,
+ 500,
+ 100,
+ 2000,
+ 600,
+ 80,
+ 500,
+ 2000,
+ 100,
+ 4000,
+ 2250,
+ 110,
+ 2300,
+ 3000,
+ 99,
+ 3333,
+ 4980,
+ 80,
+ ),
+ "sum_fastagauss": (
+ 2550,
+ 500,
+ 100,
+ 2000,
+ 600,
+ 80,
+ 500,
+ 2000,
+ 100,
+ 4000,
+ 2250,
+ 110,
+ 2300,
+ 3000,
+ 99,
+ 3333,
+ 4980,
+ 80,
+ ),
+ "sum_alorentz": (
+ 2550,
+ 500,
+ 100,
+ 2000,
+ 600,
+ 80,
+ 500,
+ 2000,
+ 100,
+ 4000,
+ 2250,
+ 110,
+ 2300,
+ 3000,
+ 99,
+ 3333,
+ 4980,
+ 80,
+ ),
+ "sum_apvoigt": (
+ 500,
+ 500,
+ 100,
+ 0.4,
+ 500,
+ 600,
+ 80,
+ 0.5,
+ 200,
+ 2000,
+ 100,
+ 0.6,
+ 500,
+ 2250,
+ 110,
+ 0.7,
+ 400,
+ 3000,
+ 99,
+ 0.8,
+ 230,
+ 4980,
+ 80,
+ 0.3,
+ ),
+ "sum_ahypermet": (
+ 1000,
+ 500,
+ 200,
+ 0.2,
+ 100,
+ 0.3,
+ 100,
+ 0.05,
+ 1000,
+ 1000,
+ 200,
+ 0.2,
+ 100,
+ 0.3,
+ 100,
+ 0.05,
+ 1000,
+ 2000,
+ 200,
+ 0.2,
+ 100,
+ 0.3,
+ 100,
+ 0.05,
+ 1000,
+ 2350,
+ 200,
+ 0.2,
+ 100,
+ 0.3,
+ 100,
+ 0.05,
+ 1000,
+ 3000,
+ 200,
+ 0.2,
+ 100,
+ 0.3,
+ 100,
+ 0.05,
+ 1000,
+ 4900,
+ 200,
+ 0.2,
+ 100,
+ 0.3,
+ 100,
+ 0.05,
+ ),
+ "sum_fastahypermet": (
+ 1000,
+ 500,
+ 200,
+ 0.2,
+ 100,
+ 0.3,
+ 100,
+ 0.05,
+ 1000,
+ 1000,
+ 200,
+ 0.2,
+ 100,
+ 0.3,
+ 100,
+ 0.05,
+ 1000,
+ 2000,
+ 200,
+ 0.2,
+ 100,
+ 0.3,
+ 100,
+ 0.05,
+ 1000,
+ 2350,
+ 200,
+ 0.2,
+ 100,
+ 0.3,
+ 100,
+ 0.05,
+ 1000,
+ 3000,
+ 200,
+ 0.2,
+ 100,
+ 0.3,
+ 100,
+ 0.05,
+ 1000,
+ 4900,
+ 200,
+ 0.2,
+ 100,
+ 0.3,
+ 100,
+ 0.05,
+ ),
+}
- def tearDown(self):
- pass
+@pytest.mark.parametrize("peak_profile", list(_PEAK_PARAMETERS))
+def test_peak_functions(peak_profile):
+ x = numpy.arange(5000)
+ peak_params = _PEAK_PARAMETERS[peak_profile]
+ func = getattr(functions, peak_profile)
- def get_peaks(self, function, params):
- """
+ with pytest.raises(IndexError):
+ func(x)
+ with pytest.raises(IndexError):
+ func(x, *peak_params, 0)
- :param function: Multi-peak function
- :param params: Parameter for this function
- :return: list of (peak, relevance) tuples
- """
- y = function(self.x, *params)
- return peaks.peak_search(y=y, fwhm=100, relevance_info=True)
+ y = func(x, *peak_params)
+ assert x.shape == y.shape
- def testPeakSearch_various_functions(self):
- """Run peak search on a variety of synthetic functions, and
- check that result falls within +-25 samples of the actual peak
- (reasonable delta considering a fwhm of ~100 samples) and effects
- of overlapping peaks)."""
- f_p = ((functions.sum_gauss, self.h_c_fwhm ),
- (functions.sum_lorentz, self.h_c_fwhm),
- (functions.sum_pvoigt, self.h_c_fwhm_eta),
- (functions.sum_splitgauss, self.h_c_fwhm_fwhm),
- (functions.sum_splitlorentz, self.h_c_fwhm_fwhm),
- (functions.sum_splitpvoigt, self.h_c_fwhm_fwhm_eta),
- (functions.sum_agauss, self.a_c_fwhm),
- (functions.sum_fastagauss, self.a_c_fwhm),
- (functions.sum_alorentz, self.a_c_fwhm),
- (functions.sum_apvoigt, self.a_c_fwhm_eta),
- (functions.sum_ahypermet, self.hypermet_params),
- (functions.sum_fastahypermet, self.hypermet_params),)
- for function, params in f_p:
- peaks = self.get_peaks(function, params)
+@pytest.mark.parametrize("peak_profile", list(_PEAK_PARAMETERS))
+def test_peak_search(peak_profile):
+ x = numpy.arange(5000)
+ peak_params = _PEAK_PARAMETERS[peak_profile]
+ func = getattr(functions, peak_profile)
+ y = func(x, *peak_params)
+ estimated_peak_params = peaks.peak_search(y=y, fwhm=100, relevance_info=True)
- self.assertEqual(len(peaks), 6,
- "Wrong number of peaks detected")
-
- for i in range(6):
- theoretical_peak_index = params[i*(len(params)//6) + 1]
- found_peak_index = peaks[i][0]
- self.assertLess(abs(found_peak_index - theoretical_peak_index), 25)
+ assert len(estimated_peak_params) == 6, "Wrong number of peaks detected"
+ for i, (peak_position, *_) in enumerate(estimated_peak_params):
+ theoretical_peak_position = peak_params[i * (len(peak_params) // 6) + 1]
+ assert abs(peak_position - theoretical_peak_position) < 25
diff --git a/src/silx/math/histogram.py b/src/silx/math/histogram.py
index e00daa9..d22ab1f 100644
--- a/src/silx/math/histogram.py
+++ b/src/silx/math/histogram.py
@@ -152,15 +152,17 @@ class Histogramnd(object):
Computes the multidimensional histogram of some data.
"""
- def __init__(self,
- sample,
- histo_range,
- n_bins,
- weights=None,
- weight_min=None,
- weight_max=None,
- last_bin_closed=False,
- wh_dtype=None):
+ def __init__(
+ self,
+ sample,
+ histo_range,
+ n_bins,
+ weights=None,
+ weight_min=None,
+ weight_max=None,
+ last_bin_closed=False,
+ wh_dtype=None,
+ ):
"""
:param sample:
The data to be histogrammed.
@@ -240,14 +242,16 @@ class Histogramnd(object):
if sample is None:
self.__data = [None, None, None]
else:
- self.__data = _chistogramnd(sample,
- self.__histo_range,
- self.__n_bins,
- weights=weights,
- weight_min=weight_min,
- weight_max=weight_max,
- last_bin_closed=self.__last_bin_closed,
- wh_dtype=self.__wh_dtype)
+ self.__data = _chistogramnd(
+ sample,
+ self.__histo_range,
+ self.__n_bins,
+ weights=weights,
+ weight_min=weight_min,
+ weight_max=weight_max,
+ last_bin_closed=self.__last_bin_closed,
+ wh_dtype=self.__wh_dtype,
+ )
def __getitem__(self, key):
"""
@@ -263,11 +267,7 @@ class Histogramnd(object):
"""
return self.__data[key]
- def accumulate(self,
- sample,
- weights=None,
- weight_min=None,
- weight_max=None):
+ def accumulate(self, sample, weights=None, weight_min=None, weight_max=None):
"""
Computes the multidimensional histogram of some data and accumulates it
into the histogram held by this instance of Histogramnd.
@@ -315,16 +315,18 @@ class Histogramnd(object):
as *weights*.
:type weight_max: *optional*, scalar
"""
- result = _chistogramnd(sample,
- self.__histo_range,
- self.__n_bins,
- weights=weights,
- weight_min=weight_min,
- weight_max=weight_max,
- last_bin_closed=self.__last_bin_closed,
- histo=self.__data[0],
- weighted_histo=self.__data[1],
- wh_dtype=self.__wh_dtype)
+ result = _chistogramnd(
+ sample,
+ self.__histo_range,
+ self.__n_bins,
+ weights=weights,
+ weight_min=weight_min,
+ weight_max=weight_max,
+ last_bin_closed=self.__last_bin_closed,
+ histo=self.__data[0],
+ weighted_histo=self.__data[1],
+ wh_dtype=self.__wh_dtype,
+ )
if self.__data[0] is None:
self.__data = result
elif self.__data[1] is None and result[1] is not None:
@@ -357,12 +359,7 @@ class HistogramndLut(object):
share the same coordinates (*sample*) have to be mapped onto the same grid.
"""
- def __init__(self,
- sample,
- histo_range,
- n_bins,
- last_bin_closed=False,
- dtype=None):
+ def __init__(self, sample, histo_range, n_bins, last_bin_closed=False, dtype=None):
"""
:param sample:
The coordinates of the data to be histogrammed.
@@ -397,10 +394,9 @@ class HistogramndLut(object):
the LAST bin to be closed.
:type last_bin_closed: *optional*, :class:`python.boolean`
"""
- lut, histo, edges = _histo_get_lut(sample,
- histo_range,
- n_bins,
- last_bin_closed=last_bin_closed)
+ lut, histo, edges = _histo_get_lut(
+ sample, histo_range, n_bins, last_bin_closed=last_bin_closed
+ )
self.__n_bins = np.array(histo.shape)
self.__histo_range = histo_range
@@ -477,10 +473,7 @@ class HistogramndLut(object):
"""
return self.__last_bin_closed
- def accumulate(self,
- weights,
- weight_min=None,
- weight_max=None):
+ def accumulate(self, weights, weight_min=None, weight_max=None):
"""
Computes the multidimensional histogram of some data and adds it to
the current histogram stored by this instance. The results can be
@@ -513,14 +506,16 @@ class HistogramndLut(object):
if self.__dtype is None:
self.__dtype = weights.dtype
- histo, w_histo = _histo_from_lut(weights,
- self.__lut,
- histo=self.__histo,
- weighted_histo=self.__weighted_histo,
- shape=self.__shape,
- dtype=self.__dtype,
- weight_min=weight_min,
- weight_max=weight_max)
+ histo, w_histo = _histo_from_lut(
+ weights,
+ self.__lut,
+ histo=self.__histo,
+ weighted_histo=self.__weighted_histo,
+ shape=self.__shape,
+ dtype=self.__dtype,
+ weight_min=weight_min,
+ weight_max=weight_max,
+ )
if self.__histo is None:
self.__histo = histo
@@ -528,12 +523,9 @@ class HistogramndLut(object):
if self.__weighted_histo is None:
self.__weighted_histo = w_histo
- def apply_lut(self,
- weights,
- histo=None,
- weighted_histo=None,
- weight_min=None,
- weight_max=None):
+ def apply_lut(
+ self, weights, histo=None, weighted_histo=None, weight_min=None, weight_max=None
+ ):
"""
Computes the multidimensional histogram of some data and returns the
result (it is NOT added to the current histogram stored by this
@@ -577,16 +569,19 @@ class HistogramndLut(object):
as *weights*.
:type weight_max: *optional*, scalar
"""
- histo, w_histo = _histo_from_lut(weights,
- self.__lut,
- histo=histo,
- weighted_histo=weighted_histo,
- shape=self.__shape,
- dtype=self.__dtype,
- weight_min=weight_min,
- weight_max=weight_max)
+ histo, w_histo = _histo_from_lut(
+ weights,
+ self.__lut,
+ histo=histo,
+ weighted_histo=weighted_histo,
+ shape=self.__shape,
+ dtype=self.__dtype,
+ weight_min=weight_min,
+ weight_max=weight_max,
+ )
self.__dtype = w_histo.dtype
return histo, w_histo
-if __name__ == '__main__':
+
+if __name__ == "__main__":
pass
diff --git a/src/silx/math/histogramnd/include/histogramnd_c.h b/src/silx/math/histogramnd/include/histogramnd_c.h
index 25293b9..8d6365c 100644
--- a/src/silx/math/histogramnd/include/histogramnd_c.h
+++ b/src/silx/math/histogramnd/include/histogramnd_c.h
@@ -1,5 +1,5 @@
/*##########################################################################
-# Copyright (C) 2016 European Synchrotron Radiation Facility
+# Copyright (C) 2016-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -24,13 +24,7 @@
#ifndef HISTOGRAMND_C_H
#define HISTOGRAMND_C_H
-/* checking for MSVC version because VS 2008 doesnt fully support C99
- so inttypes.h and stdint.h are not provided with the compiler. */
-#if defined(_MSC_VER) && _MSC_VER < 1600
- #include "msvc/stdint.h"
-#else
- #include <inttypes.h>
-#endif
+#include <inttypes.h>
#include <stddef.h>
#include "templates.h"
diff --git a/src/silx/math/histogramnd/include/msvc/stdint.h b/src/silx/math/histogramnd/include/msvc/stdint.h
deleted file mode 100644
index e236bb0..0000000
--- a/src/silx/math/histogramnd/include/msvc/stdint.h
+++ /dev/null
@@ -1,247 +0,0 @@
-// ISO C9x compliant stdint.h for Microsoft Visual Studio
-// Based on ISO/IEC 9899:TC2 Committee draft (May 6, 2005) WG14/N1124
-//
-// Copyright (c) 2006-2008 Alexander Chemeris
-//
-// Redistribution and use in source and binary forms, with or without
-// modification, are permitted provided that the following conditions are met:
-//
-// 1. Redistributions of source code must retain the above copyright notice,
-// this list of conditions and the following disclaimer.
-//
-// 2. Redistributions in binary form must reproduce the above copyright
-// notice, this list of conditions and the following disclaimer in the
-// documentation and/or other materials provided with the distribution.
-//
-// 3. The name of the author may be used to endorse or promote products
-// derived from this software without specific prior written permission.
-//
-// THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED
-// WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
-// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
-// EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
-// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
-// OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
-// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
-// OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
-// ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-//
-///////////////////////////////////////////////////////////////////////////////
-
-#ifndef _MSC_VER // [
-#error "Use this header only with Microsoft Visual C++ compilers!"
-#endif // _MSC_VER ]
-
-#ifndef _MSC_STDINT_H_ // [
-#define _MSC_STDINT_H_
-
-#if _MSC_VER > 1000
-#pragma once
-#endif
-
-#include <limits.h>
-
-// For Visual Studio 6 in C++ mode and for many Visual Studio versions when
-// compiling for ARM we should wrap <wchar.h> include with 'extern "C++" {}'
-// or compiler give many errors like this:
-// error C2733: second C linkage of overloaded function 'wmemchr' not allowed
-#ifdef __cplusplus
-extern "C" {
-#endif
-# include <wchar.h>
-#ifdef __cplusplus
-}
-#endif
-
-// Define _W64 macros to mark types changing their size, like intptr_t.
-#ifndef _W64
-# if !defined(__midl) && (defined(_X86_) || defined(_M_IX86)) && _MSC_VER >= 1300
-# define _W64 __w64
-# else
-# define _W64
-# endif
-#endif
-
-
-// 7.18.1 Integer types
-
-// 7.18.1.1 Exact-width integer types
-
-// Visual Studio 6 and Embedded Visual C++ 4 doesn't
-// realize that, e.g. char has the same size as __int8
-// so we give up on __intX for them.
-#if (_MSC_VER < 1300)
- typedef char int8_t;
- typedef short int16_t;
- typedef int int32_t;
- typedef unsigned char uint8_t;
- typedef unsigned short uint16_t;
- typedef unsigned int uint32_t;
-#else
- typedef __int8 int8_t;
- typedef __int16 int16_t;
- typedef __int32 int32_t;
- typedef unsigned __int8 uint8_t;
- typedef unsigned __int16 uint16_t;
- typedef unsigned __int32 uint32_t;
-#endif
-typedef __int64 int64_t;
-typedef unsigned __int64 uint64_t;
-
-
-// 7.18.1.2 Minimum-width integer types
-typedef int8_t int_least8_t;
-typedef int16_t int_least16_t;
-typedef int32_t int_least32_t;
-typedef int64_t int_least64_t;
-typedef uint8_t uint_least8_t;
-typedef uint16_t uint_least16_t;
-typedef uint32_t uint_least32_t;
-typedef uint64_t uint_least64_t;
-
-// 7.18.1.3 Fastest minimum-width integer types
-typedef int8_t int_fast8_t;
-typedef int16_t int_fast16_t;
-typedef int32_t int_fast32_t;
-typedef int64_t int_fast64_t;
-typedef uint8_t uint_fast8_t;
-typedef uint16_t uint_fast16_t;
-typedef uint32_t uint_fast32_t;
-typedef uint64_t uint_fast64_t;
-
-// 7.18.1.4 Integer types capable of holding object pointers
-#ifdef _WIN64 // [
- typedef __int64 intptr_t;
- typedef unsigned __int64 uintptr_t;
-#else // _WIN64 ][
- typedef _W64 int intptr_t;
- typedef _W64 unsigned int uintptr_t;
-#endif // _WIN64 ]
-
-// 7.18.1.5 Greatest-width integer types
-typedef int64_t intmax_t;
-typedef uint64_t uintmax_t;
-
-
-// 7.18.2 Limits of specified-width integer types
-
-#if !defined(__cplusplus) || defined(__STDC_LIMIT_MACROS) // [ See footnote 220 at page 257 and footnote 221 at page 259
-
-// 7.18.2.1 Limits of exact-width integer types
-#define INT8_MIN ((int8_t)_I8_MIN)
-#define INT8_MAX _I8_MAX
-#define INT16_MIN ((int16_t)_I16_MIN)
-#define INT16_MAX _I16_MAX
-#define INT32_MIN ((int32_t)_I32_MIN)
-#define INT32_MAX _I32_MAX
-#define INT64_MIN ((int64_t)_I64_MIN)
-#define INT64_MAX _I64_MAX
-#define UINT8_MAX _UI8_MAX
-#define UINT16_MAX _UI16_MAX
-#define UINT32_MAX _UI32_MAX
-#define UINT64_MAX _UI64_MAX
-
-// 7.18.2.2 Limits of minimum-width integer types
-#define INT_LEAST8_MIN INT8_MIN
-#define INT_LEAST8_MAX INT8_MAX
-#define INT_LEAST16_MIN INT16_MIN
-#define INT_LEAST16_MAX INT16_MAX
-#define INT_LEAST32_MIN INT32_MIN
-#define INT_LEAST32_MAX INT32_MAX
-#define INT_LEAST64_MIN INT64_MIN
-#define INT_LEAST64_MAX INT64_MAX
-#define UINT_LEAST8_MAX UINT8_MAX
-#define UINT_LEAST16_MAX UINT16_MAX
-#define UINT_LEAST32_MAX UINT32_MAX
-#define UINT_LEAST64_MAX UINT64_MAX
-
-// 7.18.2.3 Limits of fastest minimum-width integer types
-#define INT_FAST8_MIN INT8_MIN
-#define INT_FAST8_MAX INT8_MAX
-#define INT_FAST16_MIN INT16_MIN
-#define INT_FAST16_MAX INT16_MAX
-#define INT_FAST32_MIN INT32_MIN
-#define INT_FAST32_MAX INT32_MAX
-#define INT_FAST64_MIN INT64_MIN
-#define INT_FAST64_MAX INT64_MAX
-#define UINT_FAST8_MAX UINT8_MAX
-#define UINT_FAST16_MAX UINT16_MAX
-#define UINT_FAST32_MAX UINT32_MAX
-#define UINT_FAST64_MAX UINT64_MAX
-
-// 7.18.2.4 Limits of integer types capable of holding object pointers
-#ifdef _WIN64 // [
-# define INTPTR_MIN INT64_MIN
-# define INTPTR_MAX INT64_MAX
-# define UINTPTR_MAX UINT64_MAX
-#else // _WIN64 ][
-# define INTPTR_MIN INT32_MIN
-# define INTPTR_MAX INT32_MAX
-# define UINTPTR_MAX UINT32_MAX
-#endif // _WIN64 ]
-
-// 7.18.2.5 Limits of greatest-width integer types
-#define INTMAX_MIN INT64_MIN
-#define INTMAX_MAX INT64_MAX
-#define UINTMAX_MAX UINT64_MAX
-
-// 7.18.3 Limits of other integer types
-
-#ifdef _WIN64 // [
-# define PTRDIFF_MIN _I64_MIN
-# define PTRDIFF_MAX _I64_MAX
-#else // _WIN64 ][
-# define PTRDIFF_MIN _I32_MIN
-# define PTRDIFF_MAX _I32_MAX
-#endif // _WIN64 ]
-
-#define SIG_ATOMIC_MIN INT_MIN
-#define SIG_ATOMIC_MAX INT_MAX
-
-#ifndef SIZE_MAX // [
-# ifdef _WIN64 // [
-# define SIZE_MAX _UI64_MAX
-# else // _WIN64 ][
-# define SIZE_MAX _UI32_MAX
-# endif // _WIN64 ]
-#endif // SIZE_MAX ]
-
-// WCHAR_MIN and WCHAR_MAX are also defined in <wchar.h>
-#ifndef WCHAR_MIN // [
-# define WCHAR_MIN 0
-#endif // WCHAR_MIN ]
-#ifndef WCHAR_MAX // [
-# define WCHAR_MAX _UI16_MAX
-#endif // WCHAR_MAX ]
-
-#define WINT_MIN 0
-#define WINT_MAX _UI16_MAX
-
-#endif // __STDC_LIMIT_MACROS ]
-
-
-// 7.18.4 Limits of other integer types
-
-#if !defined(__cplusplus) || defined(__STDC_CONSTANT_MACROS) // [ See footnote 224 at page 260
-
-// 7.18.4.1 Macros for minimum-width integer constants
-
-#define INT8_C(val) val##i8
-#define INT16_C(val) val##i16
-#define INT32_C(val) val##i32
-#define INT64_C(val) val##i64
-
-#define UINT8_C(val) val##ui8
-#define UINT16_C(val) val##ui16
-#define UINT32_C(val) val##ui32
-#define UINT64_C(val) val##ui64
-
-// 7.18.4.2 Macros for greatest-width integer constants
-#define INTMAX_C INT64_C
-#define UINTMAX_C UINT64_C
-
-#endif // __STDC_CONSTANT_MACROS ]
-
-
-#endif // _MSC_STDINT_H_ ]
diff --git a/src/silx/math/medianfilter/__init__.py b/src/silx/math/medianfilter/__init__.py
index 5c199e3..7e0863d 100644
--- a/src/silx/math/medianfilter/__init__.py
+++ b/src/silx/math/medianfilter/__init__.py
@@ -26,4 +26,4 @@ __license__ = "MIT"
__date__ = "02/05/2017"
-from .medianfilter import (medfilt, medfilt1d, medfilt2d)
+from .medianfilter import medfilt, medfilt1d, medfilt2d
diff --git a/src/silx/math/medianfilter/test/benchmark.py b/src/silx/math/medianfilter/test/benchmark.py
index ebe4ac4..284e3bc 100644
--- a/src/silx/math/medianfilter/test/benchmark.py
+++ b/src/silx/math/medianfilter/test/benchmark.py
@@ -73,9 +73,7 @@ class BenchmarkMedianFilter(object):
medfilt2d_silx(self.img, width)
def execScipy():
- scipy.ndimage.median_filter(input=self.img,
- size=width,
- mode='nearest')
+ scipy.ndimage.median_filter(input=self.img, size=width, mode="nearest")
def execPymca():
medfilt2d_pymca(self.img, width)
@@ -85,18 +83,21 @@ class BenchmarkMedianFilter(object):
t = Timer(execSilx)
execTime["silx"] = t.timeit(BenchmarkMedianFilter.NB_ITER)
logger.info(
- 'exec time silx (kernel size = %s) is %s' % (width, execTime["silx"]))
+ "exec time silx (kernel size = %s) is %s" % (width, execTime["silx"])
+ )
if scipy is not None:
t = Timer(execScipy)
execTime["scipy"] = t.timeit(BenchmarkMedianFilter.NB_ITER)
logger.info(
- 'exec time scipy (kernel size = %s) is %s' % (width, execTime["scipy"]))
+ "exec time scipy (kernel size = %s) is %s" % (width, execTime["scipy"])
+ )
if pymca is not None:
t = Timer(execPymca)
execTime["pymca"] = t.timeit(BenchmarkMedianFilter.NB_ITER)
logger.info(
- 'exec time pymca (kernel size = %s) is %s' % (width, execTime["pymca"]))
+ "exec time pymca (kernel size = %s) is %s" % (width, execTime["pymca"])
+ )
return execTime
@@ -111,11 +112,11 @@ app = qt.QApplication([])
kernels = [3, 5, 7, 11, 15]
benchmark = BenchmarkMedianFilter(imageWidth=1000, kernels=kernels)
plot = Plot1D()
-plot.addCurve(x=kernels, y=benchmark.getExecTimeFor("silx"), legend='silx')
+plot.addCurve(x=kernels, y=benchmark.getExecTimeFor("silx"), legend="silx")
if scipy is not None:
- plot.addCurve(x=kernels, y=benchmark.getExecTimeFor("scipy"), legend='scipy')
+ plot.addCurve(x=kernels, y=benchmark.getExecTimeFor("scipy"), legend="scipy")
if pymca is not None:
- plot.addCurve(x=kernels, y=benchmark.getExecTimeFor("pymca"), legend='pymca')
+ plot.addCurve(x=kernels, y=benchmark.getExecTimeFor("pymca"), legend="pymca")
plot.show()
app.exec()
del app
diff --git a/src/silx/math/medianfilter/test/test_medianfilter.py b/src/silx/math/medianfilter/test/test_medianfilter.py
index 15ee92e..62b1338 100644
--- a/src/silx/math/medianfilter/test/test_medianfilter.py
+++ b/src/silx/math/medianfilter/test/test_medianfilter.py
@@ -1,5 +1,5 @@
# ##########################################################################
-# Copyright (C) 2017-2018 European Synchrotron Radiation Facility
+# Copyright (C) 2017-2022 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -32,29 +32,35 @@ from silx.math.medianfilter import medfilt2d, medfilt1d
from silx.math.medianfilter.medianfilter import reflect, mirror
from silx.math.medianfilter.medianfilter import MODES as silx_mf_modes
from silx.utils.testutils import ParametricTestCase
+
try:
import scipy
- import scipy.misc
except:
scipy = None
else:
+ try:
+ from scipy.misc import ascent
+ except:
+ from scipy.datasets import ascent
import scipy.ndimage
import logging
+
_logger = logging.getLogger(__name__)
-RANDOM_FLOAT_MAT = numpy.array([
- [0.05564293, 0.62717157, 0.75002406, 0.40555336, 0.70278975],
- [0.76532598, 0.02839148, 0.05272484, 0.65166994, 0.42161216],
- [0.23067427, 0.74219128, 0.56049024, 0.44406320, 0.28773158],
- [0.81025249, 0.20303021, 0.68382382, 0.46372299, 0.81281709],
- [0.94691602, 0.07813661, 0.81651256, 0.84220106, 0.33623165]])
+RANDOM_FLOAT_MAT = numpy.array(
+ [
+ [0.05564293, 0.62717157, 0.75002406, 0.40555336, 0.70278975],
+ [0.76532598, 0.02839148, 0.05272484, 0.65166994, 0.42161216],
+ [0.23067427, 0.74219128, 0.56049024, 0.44406320, 0.28773158],
+ [0.81025249, 0.20303021, 0.68382382, 0.46372299, 0.81281709],
+ [0.94691602, 0.07813661, 0.81651256, 0.84220106, 0.33623165],
+ ]
+)
-RANDOM_INT_MAT = numpy.array([
- [0, 5, 2, 6, 1],
- [2, 3, 1, 7, 1],
- [9, 8, 6, 7, 8],
- [5, 6, 8, 2, 4]])
+RANDOM_INT_MAT = numpy.array(
+ [[0, 5, 2, 6, 1], [2, 3, 1, 7, 1], [9, 8, 6, 7, 8], [5, 6, 8, 2, 4]]
+)
class TestMedianFilterNearest(ParametricTestCase):
@@ -65,10 +71,9 @@ class TestMedianFilterNearest(ParametricTestCase):
dataIn = numpy.arange(100, dtype=numpy.int32)
dataIn = dataIn.reshape((10, 10))
- dataOut = medfilt2d(image=dataIn,
- kernel_size=(3, 3),
- conditional=False,
- mode='nearest')
+ dataOut = medfilt2d(
+ image=dataIn, kernel_size=(3, 3), conditional=False, mode="nearest"
+ )
self.assertTrue(dataOut[0, 0] == 1)
self.assertTrue(dataOut[9, 0] == 90)
self.assertTrue(dataOut[9, 9] == 98)
@@ -80,15 +85,11 @@ class TestMedianFilterNearest(ParametricTestCase):
def testFilter3_9(self):
"Test median filter on a 3x3 matrix with a 3x3 kernel."
- dataIn = numpy.array([0, -1, 1,
- 12, 6, -2,
- 100, 4, 12],
- dtype=numpy.int16)
+ dataIn = numpy.array([0, -1, 1, 12, 6, -2, 100, 4, 12], dtype=numpy.int16)
dataIn = dataIn.reshape((3, 3))
- dataOut = medfilt2d(image=dataIn,
- kernel_size=(3, 3),
- conditional=False,
- mode='nearest')
+ dataOut = medfilt2d(
+ image=dataIn, kernel_size=(3, 3), conditional=False, mode="nearest"
+ )
self.assertTrue(dataOut.shape == dataIn.shape)
self.assertTrue(dataOut[1, 1] == 4)
self.assertTrue(dataOut[0, 0] == 0)
@@ -96,24 +97,25 @@ class TestMedianFilterNearest(ParametricTestCase):
self.assertTrue(dataOut[1, 0] == 6)
def testFilterWidthOne(self):
- """Make sure a filter of one by one give the same result as the input
- """
+ """Make sure a filter of one by one give the same result as the input"""
dataIn = numpy.arange(100, dtype=numpy.int32)
dataIn = dataIn.reshape((10, 10))
- dataOut = medfilt2d(image=dataIn,
- kernel_size=(1, 1),
- conditional=False,
- mode='nearest')
+ dataOut = medfilt2d(
+ image=dataIn, kernel_size=(1, 1), conditional=False, mode="nearest"
+ )
self.assertTrue(numpy.array_equal(dataIn, dataOut))
def testFilter3_1d(self):
"""Test binding and result of the 1d filter"""
- self.assertTrue(numpy.array_equal(
- medfilt1d(RANDOM_INT_MAT[0], kernel_size=3, conditional=False,
- mode='nearest'),
- [0, 2, 5, 2, 1])
+ self.assertTrue(
+ numpy.array_equal(
+ medfilt1d(
+ RANDOM_INT_MAT[0], kernel_size=3, conditional=False, mode="nearest"
+ ),
+ [0, 2, 5, 2, 1],
+ )
)
def testFilter3Conditionnal(self):
@@ -123,10 +125,9 @@ class TestMedianFilterNearest(ParametricTestCase):
dataIn = numpy.arange(100, dtype=numpy.int32)
dataIn = dataIn.reshape((10, 10))
- dataOut = medfilt2d(image=dataIn,
- kernel_size=(3, 3),
- conditional=True,
- mode='nearest')
+ dataOut = medfilt2d(
+ image=dataIn, kernel_size=(3, 3), conditional=True, mode="nearest"
+ )
self.assertTrue(dataOut[0, 0] == 1)
self.assertTrue(dataOut[0, 1] == 1)
self.assertTrue(numpy.array_equal(dataOut[1:8, 1:8], dataIn[1:8, 1:8]))
@@ -136,10 +137,9 @@ class TestMedianFilterNearest(ParametricTestCase):
"""Simple test of a 3x3 median filter on a 1D array"""
dataIn = numpy.arange(100, dtype=numpy.int32)
- dataOut = medfilt2d(image=dataIn,
- kernel_size=(5),
- conditional=False,
- mode='nearest')
+ dataOut = medfilt2d(
+ image=dataIn, kernel_size=(5), conditional=False, mode="nearest"
+ )
self.assertTrue(dataOut[0] == 0)
self.assertTrue(dataOut[9] == 9)
@@ -148,22 +148,20 @@ class TestMedianFilterNearest(ParametricTestCase):
def testNaNs(self):
"""Test median filter on image with NaNs in nearest mode"""
# Data with a NaN in first corner
- nan_corner = numpy.arange(100.).reshape(10, 10)
+ nan_corner = numpy.arange(100.0).reshape(10, 10)
nan_corner[0, 0] = numpy.nan
- output = medfilt2d(
- nan_corner, kernel_size=3, conditional=False, mode='nearest')
+ output = medfilt2d(nan_corner, kernel_size=3, conditional=False, mode="nearest")
self.assertEqual(output[0, 0], 10)
self.assertEqual(output[0, 1], 2)
self.assertEqual(output[1, 0], 11)
self.assertEqual(output[1, 1], 12)
# Data with some NaNs
- some_nans = numpy.arange(100.).reshape(10, 10)
+ some_nans = numpy.arange(100.0).reshape(10, 10)
some_nans[0, 1] = numpy.nan
some_nans[1, 1] = numpy.nan
some_nans[1, 0] = numpy.nan
- output = medfilt2d(
- some_nans, kernel_size=3, conditional=False, mode='nearest')
+ output = medfilt2d(some_nans, kernel_size=3, conditional=False, mode="nearest")
self.assertEqual(output[0, 0], 0)
self.assertEqual(output[0, 1], 2)
self.assertEqual(output[1, 0], 20)
@@ -178,34 +176,36 @@ class TestMedianFilterReflect(ParametricTestCase):
img = numpy.arange(9, dtype=numpy.int32)
img = img.reshape(3, 3)
kernel = (3, 3)
- res = medfilt2d(image=img,
- kernel_size=kernel,
- conditional=False,
- mode='reflect')
- self.assertTrue(
- numpy.array_equal(res.ravel(), [1, 2, 2, 3, 4, 5, 6, 6, 7]))
+ res = medfilt2d(
+ image=img, kernel_size=kernel, conditional=False, mode="reflect"
+ )
+ self.assertTrue(numpy.array_equal(res.ravel(), [1, 2, 2, 3, 4, 5, 6, 6, 7]))
def testRandom10(self):
"""Test a (5, 3) window to a RANDOM_FLOAT_MAT"""
kernel = (5, 3)
- thRes = numpy.array([
- [0.23067427, 0.56049024, 0.56049024, 0.4440632, 0.42161216],
- [0.23067427, 0.62717157, 0.56049024, 0.56049024, 0.46372299],
- [0.62717157, 0.62717157, 0.56049024, 0.56049024, 0.4440632],
- [0.76532598, 0.68382382, 0.56049024, 0.56049024, 0.42161216],
- [0.81025249, 0.68382382, 0.56049024, 0.68382382, 0.46372299]])
+ thRes = numpy.array(
+ [
+ [0.23067427, 0.56049024, 0.56049024, 0.4440632, 0.42161216],
+ [0.23067427, 0.62717157, 0.56049024, 0.56049024, 0.46372299],
+ [0.62717157, 0.62717157, 0.56049024, 0.56049024, 0.4440632],
+ [0.76532598, 0.68382382, 0.56049024, 0.56049024, 0.42161216],
+ [0.81025249, 0.68382382, 0.56049024, 0.68382382, 0.46372299],
+ ]
+ )
- res = medfilt2d(image=RANDOM_FLOAT_MAT,
- kernel_size=kernel,
- conditional=False,
- mode='reflect')
+ res = medfilt2d(
+ image=RANDOM_FLOAT_MAT,
+ kernel_size=kernel,
+ conditional=False,
+ mode="reflect",
+ )
self.assertTrue(numpy.array_equal(thRes, res))
def testApplyReflect1D(self):
- """Test the reflect function used for the median filter in reflect mode
- """
+ """Test the reflect function used for the median filter in reflect mode"""
# test for inside values
self.assertTrue(reflect(2, 3) == 2)
# test for boundaries values
@@ -227,38 +227,38 @@ class TestMedianFilterReflect(ParametricTestCase):
option"""
kernel = (3, 1)
- thRes = numpy.array([
- [0.05564293, 0.62717157, 0.75002406, 0.40555336, 0.70278975],
- [0.23067427, 0.62717157, 0.56049024, 0.44406320, 0.42161216],
- [0.76532598, 0.20303021, 0.56049024, 0.46372299, 0.42161216],
- [0.81025249, 0.20303021, 0.68382382, 0.46372299, 0.33623165],
- [0.94691602, 0.07813661, 0.81651256, 0.84220106, 0.33623165]])
+ thRes = numpy.array(
+ [
+ [0.05564293, 0.62717157, 0.75002406, 0.40555336, 0.70278975],
+ [0.23067427, 0.62717157, 0.56049024, 0.44406320, 0.42161216],
+ [0.76532598, 0.20303021, 0.56049024, 0.46372299, 0.42161216],
+ [0.81025249, 0.20303021, 0.68382382, 0.46372299, 0.33623165],
+ [0.94691602, 0.07813661, 0.81651256, 0.84220106, 0.33623165],
+ ]
+ )
- res = medfilt2d(image=RANDOM_FLOAT_MAT,
- kernel_size=kernel,
- conditional=True,
- mode='reflect')
+ res = medfilt2d(
+ image=RANDOM_FLOAT_MAT, kernel_size=kernel, conditional=True, mode="reflect"
+ )
self.assertTrue(numpy.array_equal(thRes, res))
def testNaNs(self):
"""Test median filter on image with NaNs in reflect mode"""
# Data with a NaN in first corner
- nan_corner = numpy.arange(100.).reshape(10, 10)
+ nan_corner = numpy.arange(100.0).reshape(10, 10)
nan_corner[0, 0] = numpy.nan
- output = medfilt2d(
- nan_corner, kernel_size=3, conditional=False, mode='reflect')
+ output = medfilt2d(nan_corner, kernel_size=3, conditional=False, mode="reflect")
self.assertEqual(output[0, 0], 10)
self.assertEqual(output[0, 1], 2)
self.assertEqual(output[1, 0], 11)
self.assertEqual(output[1, 1], 12)
# Data with some NaNs
- some_nans = numpy.arange(100.).reshape(10, 10)
+ some_nans = numpy.arange(100.0).reshape(10, 10)
some_nans[0, 1] = numpy.nan
some_nans[1, 1] = numpy.nan
some_nans[1, 0] = numpy.nan
- output = medfilt2d(
- some_nans, kernel_size=3, conditional=False, mode='reflect')
+ output = medfilt2d(some_nans, kernel_size=3, conditional=False, mode="reflect")
self.assertEqual(output[0, 0], 0)
self.assertEqual(output[0, 1], 2)
self.assertEqual(output[1, 0], 20)
@@ -266,20 +266,21 @@ class TestMedianFilterReflect(ParametricTestCase):
def testFilter3_1d(self):
"""Test binding and result of the 1d filter"""
- self.assertTrue(numpy.array_equal(
- medfilt1d(RANDOM_INT_MAT[0], kernel_size=5, conditional=False,
- mode='reflect'),
- [2, 2, 2, 2, 2])
+ self.assertTrue(
+ numpy.array_equal(
+ medfilt1d(
+ RANDOM_INT_MAT[0], kernel_size=5, conditional=False, mode="reflect"
+ ),
+ [2, 2, 2, 2, 2],
+ )
)
class TestMedianFilterMirror(ParametricTestCase):
- """Unit test for the median filter in mirror mode
- """
+ """Unit test for the median filter in mirror mode"""
def testApplyMirror1D(self):
- """Test the reflect function used for the median filter in mirror mode
- """
+ """Test the reflect function used for the median filter in mirror mode"""
# test for inside values
self.assertTrue(mirror(2, 3) == 2)
# test for boundaries values
@@ -299,17 +300,19 @@ class TestMedianFilterMirror(ParametricTestCase):
"""Test a (5, 3) window to a random array"""
kernel = (3, 5)
- thRes = numpy.array([
- [0.05272484, 0.40555336, 0.42161216, 0.42161216, 0.42161216],
- [0.56049024, 0.56049024, 0.4440632, 0.4440632, 0.4440632],
- [0.56049024, 0.46372299, 0.46372299, 0.46372299, 0.46372299],
- [0.68382382, 0.56049024, 0.56049024, 0.46372299, 0.56049024],
- [0.68382382, 0.46372299, 0.68382382, 0.46372299, 0.68382382]])
+ thRes = numpy.array(
+ [
+ [0.05272484, 0.40555336, 0.42161216, 0.42161216, 0.42161216],
+ [0.56049024, 0.56049024, 0.4440632, 0.4440632, 0.4440632],
+ [0.56049024, 0.46372299, 0.46372299, 0.46372299, 0.46372299],
+ [0.68382382, 0.56049024, 0.56049024, 0.46372299, 0.56049024],
+ [0.68382382, 0.46372299, 0.68382382, 0.46372299, 0.68382382],
+ ]
+ )
- res = medfilt2d(image=RANDOM_FLOAT_MAT,
- kernel_size=kernel,
- conditional=False,
- mode='mirror')
+ res = medfilt2d(
+ image=RANDOM_FLOAT_MAT, kernel_size=kernel, conditional=False, mode="mirror"
+ )
self.assertTrue(numpy.array_equal(thRes, res))
@@ -318,39 +321,39 @@ class TestMedianFilterMirror(ParametricTestCase):
option"""
kernel = (1, 3)
- thRes = numpy.array([
- [0.62717157, 0.62717157, 0.62717157, 0.70278975, 0.40555336],
- [0.02839148, 0.05272484, 0.05272484, 0.42161216, 0.65166994],
- [0.74219128, 0.56049024, 0.56049024, 0.44406320, 0.44406320],
- [0.20303021, 0.68382382, 0.46372299, 0.68382382, 0.46372299],
- [0.07813661, 0.81651256, 0.81651256, 0.81651256, 0.84220106]])
+ thRes = numpy.array(
+ [
+ [0.62717157, 0.62717157, 0.62717157, 0.70278975, 0.40555336],
+ [0.02839148, 0.05272484, 0.05272484, 0.42161216, 0.65166994],
+ [0.74219128, 0.56049024, 0.56049024, 0.44406320, 0.44406320],
+ [0.20303021, 0.68382382, 0.46372299, 0.68382382, 0.46372299],
+ [0.07813661, 0.81651256, 0.81651256, 0.81651256, 0.84220106],
+ ]
+ )
- res = medfilt2d(image=RANDOM_FLOAT_MAT,
- kernel_size=kernel,
- conditional=True,
- mode='mirror')
+ res = medfilt2d(
+ image=RANDOM_FLOAT_MAT, kernel_size=kernel, conditional=True, mode="mirror"
+ )
self.assertTrue(numpy.array_equal(thRes, res))
def testNaNs(self):
"""Test median filter on image with NaNs in mirror mode"""
# Data with a NaN in first corner
- nan_corner = numpy.arange(100.).reshape(10, 10)
+ nan_corner = numpy.arange(100.0).reshape(10, 10)
nan_corner[0, 0] = numpy.nan
- output = medfilt2d(
- nan_corner, kernel_size=3, conditional=False, mode='mirror')
+ output = medfilt2d(nan_corner, kernel_size=3, conditional=False, mode="mirror")
self.assertEqual(output[0, 0], 11)
self.assertEqual(output[0, 1], 11)
self.assertEqual(output[1, 0], 11)
self.assertEqual(output[1, 1], 12)
# Data with some NaNs
- some_nans = numpy.arange(100.).reshape(10, 10)
+ some_nans = numpy.arange(100.0).reshape(10, 10)
some_nans[0, 1] = numpy.nan
some_nans[1, 1] = numpy.nan
some_nans[1, 0] = numpy.nan
- output = medfilt2d(
- some_nans, kernel_size=3, conditional=False, mode='mirror')
+ output = medfilt2d(some_nans, kernel_size=3, conditional=False, mode="mirror")
self.assertEqual(output[0, 0], 0)
self.assertEqual(output[0, 1], 12)
self.assertEqual(output[1, 0], 21)
@@ -358,32 +361,37 @@ class TestMedianFilterMirror(ParametricTestCase):
def testFilter3_1d(self):
"""Test binding and result of the 1d filter"""
- self.assertTrue(numpy.array_equal(
- medfilt1d(RANDOM_INT_MAT[0], kernel_size=5, conditional=False,
- mode='mirror'),
- [2, 5, 2, 5, 2])
+ self.assertTrue(
+ numpy.array_equal(
+ medfilt1d(
+ RANDOM_INT_MAT[0], kernel_size=5, conditional=False, mode="mirror"
+ ),
+ [2, 5, 2, 5, 2],
+ )
)
+
class TestMedianFilterShrink(ParametricTestCase):
- """Unit test for the median filter in mirror mode
- """
+ """Unit test for the median filter in mirror mode"""
def testRandom_3x3(self):
"""Test the median filter in shrink mode and with the conditionnal
option"""
kernel = (3, 3)
- thRes = numpy.array([
- [0.62717157, 0.62717157, 0.62717157, 0.65166994, 0.65166994],
- [0.62717157, 0.56049024, 0.56049024, 0.44406320, 0.44406320],
- [0.74219128, 0.56049024, 0.46372299, 0.46372299, 0.46372299],
- [0.74219128, 0.68382382, 0.56049024, 0.56049024, 0.46372299],
- [0.81025249, 0.81025249, 0.68382382, 0.81281709, 0.81281709]])
+ thRes = numpy.array(
+ [
+ [0.62717157, 0.62717157, 0.62717157, 0.65166994, 0.65166994],
+ [0.62717157, 0.56049024, 0.56049024, 0.44406320, 0.44406320],
+ [0.74219128, 0.56049024, 0.46372299, 0.46372299, 0.46372299],
+ [0.74219128, 0.68382382, 0.56049024, 0.56049024, 0.46372299],
+ [0.81025249, 0.81025249, 0.68382382, 0.81281709, 0.81281709],
+ ]
+ )
- res = medfilt2d(image=RANDOM_FLOAT_MAT,
- kernel_size=kernel,
- conditional=False,
- mode='shrink')
+ res = medfilt2d(
+ image=RANDOM_FLOAT_MAT, kernel_size=kernel, conditional=False, mode="shrink"
+ )
self.assertTrue(numpy.array_equal(thRes, res))
@@ -396,25 +404,21 @@ class TestMedianFilterShrink(ParametricTestCase):
kernel2 = (1, 11)
kernel3 = (1, 21)
- thRes = numpy.array([[2, 2, 2, 2, 2],
- [2, 2, 2, 2, 2],
- [8, 8, 8, 8, 8],
- [5, 5, 5, 5, 5]])
+ thRes = numpy.array(
+ [[2, 2, 2, 2, 2], [2, 2, 2, 2, 2], [8, 8, 8, 8, 8], [5, 5, 5, 5, 5]]
+ )
- resK1 = medfilt2d(image=RANDOM_INT_MAT,
- kernel_size=kernel1,
- conditional=False,
- mode='shrink')
+ resK1 = medfilt2d(
+ image=RANDOM_INT_MAT, kernel_size=kernel1, conditional=False, mode="shrink"
+ )
- resK2 = medfilt2d(image=RANDOM_INT_MAT,
- kernel_size=kernel2,
- conditional=False,
- mode='shrink')
+ resK2 = medfilt2d(
+ image=RANDOM_INT_MAT, kernel_size=kernel2, conditional=False, mode="shrink"
+ )
- resK3 = medfilt2d(image=RANDOM_INT_MAT,
- kernel_size=kernel3,
- conditional=False,
- mode='shrink')
+ resK3 = medfilt2d(
+ image=RANDOM_INT_MAT, kernel_size=kernel3, conditional=False, mode="shrink"
+ )
self.assertTrue(numpy.array_equal(resK1, thRes))
self.assertTrue(numpy.array_equal(resK2, resK1))
@@ -425,56 +429,53 @@ class TestMedianFilterShrink(ParametricTestCase):
option"""
kernel = (3, 3)
- thRes = numpy.array([
- [0.05564293, 0.62717157, 0.62717157, 0.40555336, 0.65166994],
- [0.62717157, 0.56049024, 0.05272484, 0.65166994, 0.42161216],
- [0.23067427, 0.74219128, 0.56049024, 0.44406320, 0.46372299],
- [0.81025249, 0.20303021, 0.68382382, 0.46372299, 0.81281709],
- [0.81025249, 0.81025249, 0.81651256, 0.81281709, 0.81281709]])
+ thRes = numpy.array(
+ [
+ [0.05564293, 0.62717157, 0.62717157, 0.40555336, 0.65166994],
+ [0.62717157, 0.56049024, 0.05272484, 0.65166994, 0.42161216],
+ [0.23067427, 0.74219128, 0.56049024, 0.44406320, 0.46372299],
+ [0.81025249, 0.20303021, 0.68382382, 0.46372299, 0.81281709],
+ [0.81025249, 0.81025249, 0.81651256, 0.81281709, 0.81281709],
+ ]
+ )
- res = medfilt2d(image=RANDOM_FLOAT_MAT,
- kernel_size=kernel,
- conditional=True,
- mode='shrink')
+ res = medfilt2d(
+ image=RANDOM_FLOAT_MAT, kernel_size=kernel, conditional=True, mode="shrink"
+ )
self.assertTrue(numpy.array_equal(res, thRes))
def testRandomInt(self):
- """Test 3x3 kernel on RANDOM_INT_MAT
- """
+ """Test 3x3 kernel on RANDOM_INT_MAT"""
kernel = (3, 3)
- thRes = numpy.array([[3, 2, 5, 2, 6],
- [5, 3, 6, 6, 7],
- [6, 6, 6, 6, 7],
- [8, 8, 7, 7, 7]])
+ thRes = numpy.array(
+ [[3, 2, 5, 2, 6], [5, 3, 6, 6, 7], [6, 6, 6, 6, 7], [8, 8, 7, 7, 7]]
+ )
- resK1 = medfilt2d(image=RANDOM_INT_MAT,
- kernel_size=kernel,
- conditional=False,
- mode='shrink')
+ resK1 = medfilt2d(
+ image=RANDOM_INT_MAT, kernel_size=kernel, conditional=False, mode="shrink"
+ )
self.assertTrue(numpy.array_equal(resK1, thRes))
def testNaNs(self):
"""Test median filter on image with NaNs in shrink mode"""
# Data with a NaN in first corner
- nan_corner = numpy.arange(100.).reshape(10, 10)
+ nan_corner = numpy.arange(100.0).reshape(10, 10)
nan_corner[0, 0] = numpy.nan
- output = medfilt2d(
- nan_corner, kernel_size=3, conditional=False, mode='shrink')
+ output = medfilt2d(nan_corner, kernel_size=3, conditional=False, mode="shrink")
self.assertEqual(output[0, 0], 10)
self.assertEqual(output[0, 1], 10)
self.assertEqual(output[1, 0], 11)
self.assertEqual(output[1, 1], 12)
# Data with some NaNs
- some_nans = numpy.arange(100.).reshape(10, 10)
+ some_nans = numpy.arange(100.0).reshape(10, 10)
some_nans[0, 1] = numpy.nan
some_nans[1, 1] = numpy.nan
some_nans[1, 0] = numpy.nan
- output = medfilt2d(
- some_nans, kernel_size=3, conditional=False, mode='shrink')
+ output = medfilt2d(some_nans, kernel_size=3, conditional=False, mode="shrink")
self.assertEqual(output[0, 0], 0)
self.assertEqual(output[0, 1], 2)
self.assertEqual(output[1, 0], 20)
@@ -482,40 +483,51 @@ class TestMedianFilterShrink(ParametricTestCase):
def testFilter3_1d(self):
"""Test binding and result of the 1d filter"""
- self.assertTrue(numpy.array_equal(
- medfilt1d(RANDOM_INT_MAT[0], kernel_size=3, conditional=False,
- mode='shrink'),
- [5, 2, 5, 2, 6])
+ self.assertTrue(
+ numpy.array_equal(
+ medfilt1d(
+ RANDOM_INT_MAT[0], kernel_size=3, conditional=False, mode="shrink"
+ ),
+ [5, 2, 5, 2, 6],
+ )
)
+
class TestMedianFilterConstant(ParametricTestCase):
- """Unit test for the median filter in constant mode
- """
+ """Unit test for the median filter in constant mode"""
def testRandom10(self):
"""Test a (5, 3) window to a random array"""
kernel = (3, 5)
- thRes = numpy.array([
- [0., 0.02839148, 0.05564293, 0.02839148, 0.],
- [0.05272484, 0.40555336, 0.4440632, 0.42161216, 0.28773158],
- [0.05272484, 0.44406320, 0.46372299, 0.42161216, 0.28773158],
- [0.20303021, 0.46372299, 0.56049024, 0.44406320, 0.33623165],
- [0., 0.07813661, 0.33623165, 0.07813661, 0.]])
+ thRes = numpy.array(
+ [
+ [0.0, 0.02839148, 0.05564293, 0.02839148, 0.0],
+ [0.05272484, 0.40555336, 0.4440632, 0.42161216, 0.28773158],
+ [0.05272484, 0.44406320, 0.46372299, 0.42161216, 0.28773158],
+ [0.20303021, 0.46372299, 0.56049024, 0.44406320, 0.33623165],
+ [0.0, 0.07813661, 0.33623165, 0.07813661, 0.0],
+ ]
+ )
- res = medfilt2d(image=RANDOM_FLOAT_MAT,
- kernel_size=kernel,
- conditional=False,
- mode='constant')
+ res = medfilt2d(
+ image=RANDOM_FLOAT_MAT,
+ kernel_size=kernel,
+ conditional=False,
+ mode="constant",
+ )
self.assertTrue(numpy.array_equal(thRes, res))
- RANDOM_FLOAT_MAT = numpy.array([
- [0.05564293, 0.62717157, 0.75002406, 0.40555336, 0.70278975],
- [0.76532598, 0.02839148, 0.05272484, 0.65166994, 0.42161216],
- [0.23067427, 0.74219128, 0.56049024, 0.44406320, 0.28773158],
- [0.81025249, 0.20303021, 0.68382382, 0.46372299, 0.81281709],
- [0.94691602, 0.07813661, 0.81651256, 0.84220106, 0.33623165]])
+ RANDOM_FLOAT_MAT = numpy.array(
+ [
+ [0.05564293, 0.62717157, 0.75002406, 0.40555336, 0.70278975],
+ [0.76532598, 0.02839148, 0.05272484, 0.65166994, 0.42161216],
+ [0.23067427, 0.74219128, 0.56049024, 0.44406320, 0.28773158],
+ [0.81025249, 0.20303021, 0.68382382, 0.46372299, 0.81281709],
+ [0.94691602, 0.07813661, 0.81651256, 0.84220106, 0.33623165],
+ ]
+ )
def testRandom10Conditionnal(self):
"""Test the median filter in reflect mode and with the conditionnal
@@ -524,45 +536,46 @@ class TestMedianFilterConstant(ParametricTestCase):
print(RANDOM_FLOAT_MAT)
- thRes = numpy.array([
- [0.05564293, 0.62717157, 0.62717157, 0.70278975, 0.40555336],
- [0.02839148, 0.05272484, 0.05272484, 0.42161216, 0.42161216],
- [0.23067427, 0.56049024, 0.56049024, 0.44406320, 0.28773158],
- [0.20303021, 0.68382382, 0.46372299, 0.68382382, 0.46372299],
- [0.07813661, 0.81651256, 0.81651256, 0.81651256, 0.33623165]])
+ thRes = numpy.array(
+ [
+ [0.05564293, 0.62717157, 0.62717157, 0.70278975, 0.40555336],
+ [0.02839148, 0.05272484, 0.05272484, 0.42161216, 0.42161216],
+ [0.23067427, 0.56049024, 0.56049024, 0.44406320, 0.28773158],
+ [0.20303021, 0.68382382, 0.46372299, 0.68382382, 0.46372299],
+ [0.07813661, 0.81651256, 0.81651256, 0.81651256, 0.33623165],
+ ]
+ )
- res = medfilt2d(image=RANDOM_FLOAT_MAT,
- kernel_size=kernel,
- conditional=True,
- mode='constant')
+ res = medfilt2d(
+ image=RANDOM_FLOAT_MAT,
+ kernel_size=kernel,
+ conditional=True,
+ mode="constant",
+ )
self.assertTrue(numpy.array_equal(thRes, res))
def testNaNs(self):
"""Test median filter on image with NaNs in constant mode"""
# Data with a NaN in first corner
- nan_corner = numpy.arange(100.).reshape(10, 10)
+ nan_corner = numpy.arange(100.0).reshape(10, 10)
nan_corner[0, 0] = numpy.nan
- output = medfilt2d(nan_corner,
- kernel_size=3,
- conditional=False,
- mode='constant',
- cval=0)
+ output = medfilt2d(
+ nan_corner, kernel_size=3, conditional=False, mode="constant", cval=0
+ )
self.assertEqual(output[0, 0], 0)
self.assertEqual(output[0, 1], 2)
self.assertEqual(output[1, 0], 10)
self.assertEqual(output[1, 1], 12)
# Data with some NaNs
- some_nans = numpy.arange(100.).reshape(10, 10)
+ some_nans = numpy.arange(100.0).reshape(10, 10)
some_nans[0, 1] = numpy.nan
some_nans[1, 1] = numpy.nan
some_nans[1, 0] = numpy.nan
- output = medfilt2d(some_nans,
- kernel_size=3,
- conditional=False,
- mode='constant',
- cval=0)
+ output = medfilt2d(
+ some_nans, kernel_size=3, conditional=False, mode="constant", cval=0
+ )
self.assertEqual(output[0, 0], 0)
self.assertEqual(output[0, 1], 0)
self.assertEqual(output[1, 0], 0)
@@ -570,12 +583,16 @@ class TestMedianFilterConstant(ParametricTestCase):
def testFilter3_1d(self):
"""Test binding and result of the 1d filter"""
- self.assertTrue(numpy.array_equal(
- medfilt1d(RANDOM_INT_MAT[0], kernel_size=5, conditional=False,
- mode='constant'),
- [0, 2, 2, 2, 1])
+ self.assertTrue(
+ numpy.array_equal(
+ medfilt1d(
+ RANDOM_INT_MAT[0], kernel_size=5, conditional=False, mode="constant"
+ ),
+ [0, 2, 2, 2, 1],
+ )
)
+
class TestGeneralExecution(ParametricTestCase):
"""Some general test on median filter application"""
@@ -584,15 +601,20 @@ class TestGeneralExecution(ParametricTestCase):
filter
"""
for mode in silx_mf_modes:
- for testType in [numpy.float32, numpy.float64, numpy.int16,
- numpy.uint16, numpy.int32, numpy.int64,
- numpy.uint64]:
+ for testType in [
+ numpy.float32,
+ numpy.float64,
+ numpy.int16,
+ numpy.uint16,
+ numpy.int32,
+ numpy.int64,
+ numpy.uint64,
+ ]:
with self.subTest(mode=mode, type=testType):
data = (numpy.random.rand(10, 10) * 65000).astype(testType)
- out = medfilt2d(image=data,
- kernel_size=(3, 3),
- conditional=False,
- mode=mode)
+ out = medfilt2d(
+ image=data, kernel_size=(3, 3), conditional=False, mode=mode
+ )
self.assertTrue(out.dtype.type is testType)
def testInputDataIsNotModify(self):
@@ -603,10 +625,9 @@ class TestGeneralExecution(ParametricTestCase):
for mode in silx_mf_modes:
with self.subTest(mode=mode):
- medfilt2d(image=dataIn,
- kernel_size=(3, 3),
- conditional=False,
- mode=mode)
+ medfilt2d(
+ image=dataIn, kernel_size=(3, 3), conditional=False, mode=mode
+ )
self.assertTrue(numpy.array_equal(dataIn, dataInCopy))
def testAllNaNs(self):
@@ -622,7 +643,8 @@ class TestGeneralExecution(ParametricTestCase):
kernel_size=3,
conditional=conditional,
mode=mode,
- cval=numpy.nan)
+ cval=numpy.nan,
+ )
self.assertTrue(numpy.all(numpy.isnan(output)))
def testConditionalWithNaNs(self):
@@ -635,29 +657,25 @@ class TestGeneralExecution(ParametricTestCase):
nan_mask[4, :] = True
nan_mask[6, 4] = True
image[nan_mask] = numpy.nan
- output = medfilt2d(
- image,
- kernel_size=3,
- conditional=True,
- mode=mode)
+ output = medfilt2d(image, kernel_size=3, conditional=True, mode=mode)
out_isnan = numpy.isnan(output)
self.assertTrue(numpy.all(out_isnan[nan_mask]))
- self.assertFalse(
- numpy.any(out_isnan[numpy.logical_not(nan_mask)]))
+ self.assertFalse(numpy.any(out_isnan[numpy.logical_not(nan_mask)]))
def _getScipyAndSilxCommonModes():
"""return the mode which are comparable between silx and scipy"""
modes = silx_mf_modes.copy()
- del modes['shrink']
+ del modes["shrink"]
return modes
@unittest.skipUnless(scipy is not None, "scipy not available")
class TestVsScipy(ParametricTestCase):
"""Compare scipy.ndimage.median_filter vs silx.math.medianfilter
- on comparable
+ on comparable
"""
+
def testWithArange(self):
"""Test vs scipy with different kernels on arange matrix"""
data = numpy.arange(10000, dtype=numpy.int32)
@@ -668,13 +686,12 @@ class TestVsScipy(ParametricTestCase):
for kernel in kernels:
for mode in modesToTest:
with self.subTest(kernel=kernel, mode=mode):
- resScipy = scipy.ndimage.median_filter(input=data,
- size=kernel,
- mode=mode)
- resSilx = medfilt2d(image=data,
- kernel_size=kernel,
- conditional=False,
- mode=mode)
+ resScipy = scipy.ndimage.median_filter(
+ input=data, size=kernel, mode=mode
+ )
+ resSilx = medfilt2d(
+ image=data, kernel_size=kernel, conditional=False, mode=mode
+ )
self.assertTrue(numpy.array_equal(resScipy, resSilx))
@@ -685,23 +702,22 @@ class TestVsScipy(ParametricTestCase):
for kernel in kernels:
for mode in modesToTest:
with self.subTest(kernel=kernel, mode=mode):
- resScipy = scipy.ndimage.median_filter(input=RANDOM_FLOAT_MAT,
- size=kernel,
- mode=mode)
+ resScipy = scipy.ndimage.median_filter(
+ input=RANDOM_FLOAT_MAT, size=kernel, mode=mode
+ )
- resSilx = medfilt2d(image=RANDOM_FLOAT_MAT,
- kernel_size=kernel,
- conditional=False,
- mode=mode)
+ resSilx = medfilt2d(
+ image=RANDOM_FLOAT_MAT,
+ kernel_size=kernel,
+ conditional=False,
+ mode=mode,
+ )
self.assertTrue(numpy.array_equal(resScipy, resSilx))
- def testAscentOrLena(self):
- """Test vs scipy with """
- if hasattr(scipy.misc, 'ascent'):
- img = scipy.misc.ascent()
- else:
- img = scipy.misc.lena()
+ def testAscent(self):
+ """Test vs scipy with"""
+ img = ascent()
kernels = [(3, 1), (3, 5), (5, 9), (9, 3)]
modesToTest = _getScipyAndSilxCommonModes()
@@ -709,13 +725,12 @@ class TestVsScipy(ParametricTestCase):
for kernel in kernels:
for mode in modesToTest:
with self.subTest(kernel=kernel, mode=mode):
- resScipy = scipy.ndimage.median_filter(input=img,
- size=kernel,
- mode=mode)
-
- resSilx = medfilt2d(image=img,
- kernel_size=kernel,
- conditional=False,
- mode=mode)
+ resScipy = scipy.ndimage.median_filter(
+ input=img, size=kernel, mode=mode
+ )
+
+ resSilx = medfilt2d(
+ image=img, kernel_size=kernel, conditional=False, mode=mode
+ )
self.assertTrue(numpy.array_equal(resScipy, resSilx))
diff --git a/src/silx/math/test/benchmark_combo.py b/src/silx/math/test/benchmark_combo.py
index 484bc93..e679a28 100644
--- a/src/silx/math/test/benchmark_combo.py
+++ b/src/silx/math/test/benchmark_combo.py
@@ -28,13 +28,10 @@ __date__ = "17/01/2018"
import logging
-import os.path
import time
-import unittest
import numpy
-from silx.test.utils import temp_dir
from silx.utils.testutils import ParametricTestCase
from silx.math import combo
@@ -46,40 +43,49 @@ _logger.setLevel(logging.DEBUG)
class TestBenchmarkMinMax(ParametricTestCase):
"""Benchmark of min max combo"""
- DTYPES = ('float32', 'float64',
- 'int8', 'int16', 'int32', 'int64',
- 'uint8', 'uint16', 'uint32', 'uint64')
-
- ARANGE = 'ascent', 'descent', 'random'
+ DTYPES = (
+ "float32",
+ "float64",
+ "int8",
+ "int16",
+ "int32",
+ "int64",
+ "uint8",
+ "uint16",
+ "uint32",
+ "uint64",
+ )
+
+ ARANGE = "ascent", "descent", "random"
EXPONENT = 3, 4, 5, 6, 7
def test_benchmark_min_max(self):
"""Benchmark min_max without min positive.
-
+
Compares with:
-
+
- numpy.nanmin, numpy.nanmax and
- numpy.argmin, numpy.argmax
It runs bench for different types, different data size and 3
data sets: increasing , decreasing and random data.
"""
- durations = {'min/max': [], 'argmin/max': [], 'combo': []}
+ durations = {"min/max": [], "argmin/max": [], "combo": []}
- _logger.info('Benchmark against argmin/argmax and nanmin/nanmax')
+ _logger.info("Benchmark against argmin/argmax and nanmin/nanmax")
for dtype in self.DTYPES:
for arange in self.ARANGE:
for exponent in self.EXPONENT:
size = 10**exponent
with self.subTest(dtype=dtype, size=size, arange=arange):
- if arange == 'ascent':
+ if arange == "ascent":
data = numpy.arange(0, size, 1, dtype=dtype)
- elif arange == 'descent':
+ elif arange == "descent":
data = numpy.arange(size, 0, -1, dtype=dtype)
else:
- if dtype in ('float32', 'float64'):
+ if dtype in ("float32", "float64"):
data = numpy.random.random(size)
else:
data = numpy.random.randint(10**6, size=size)
@@ -88,55 +94,58 @@ class TestBenchmarkMinMax(ParametricTestCase):
start = time.time()
ref_min = numpy.nanmin(data)
ref_max = numpy.nanmax(data)
- durations['min/max'].append(time.time() - start)
+ durations["min/max"].append(time.time() - start)
start = time.time()
ref_argmin = numpy.argmin(data)
ref_argmax = numpy.argmax(data)
- durations['argmin/max'].append(time.time() - start)
+ durations["argmin/max"].append(time.time() - start)
start = time.time()
result = combo.min_max(data, min_positive=False)
- durations['combo'].append(time.time() - start)
+ durations["combo"].append(time.time() - start)
_logger.info(
- '%s-%s-10**%d\tx%.2f argmin/max x%.2f min/max',
- dtype, arange, exponent,
- durations['argmin/max'][-1] / durations['combo'][-1],
- durations['min/max'][-1] / durations['combo'][-1])
+ "%s-%s-10**%d\tx%.2f argmin/max x%.2f min/max",
+ dtype,
+ arange,
+ exponent,
+ durations["argmin/max"][-1] / durations["combo"][-1],
+ durations["min/max"][-1] / durations["combo"][-1],
+ )
self.assertEqual(result.minimum, ref_min)
self.assertEqual(result.maximum, ref_max)
self.assertEqual(result.argmin, ref_argmin)
self.assertEqual(result.argmax, ref_argmax)
- self.show_results('min/max', durations, 'combo')
+ self.show_results("min/max", durations, "combo")
def test_benchmark_min_pos(self):
"""Benchmark min_max wit min positive.
-
+
Compares with:
-
+
- numpy.nanmin(data[data > 0]); numpy.nanmin(pos); numpy.nanmax(pos)
It runs bench for different types, different data size and 3
data sets: increasing , decreasing and random data.
"""
- durations = {'min/max': [], 'combo': []}
+ durations = {"min/max": [], "combo": []}
- _logger.info('Benchmark against min, max, positive min')
+ _logger.info("Benchmark against min, max, positive min")
for dtype in self.DTYPES:
for arange in self.ARANGE:
for exponent in self.EXPONENT:
size = 10**exponent
with self.subTest(dtype=dtype, size=size, arange=arange):
- if arange == 'ascent':
+ if arange == "ascent":
data = numpy.arange(0, size, 1, dtype=dtype)
- elif arange == 'descent':
+ elif arange == "descent":
data = numpy.arange(size, 0, -1, dtype=dtype)
else:
- if dtype in ('float32', 'float64'):
+ if dtype in ("float32", "float64"):
data = numpy.random.random(size)
else:
data = numpy.random.randint(10**6, size=size)
@@ -146,44 +155,47 @@ class TestBenchmarkMinMax(ParametricTestCase):
ref_min_positive = numpy.nanmin(data[data > 0])
ref_min = numpy.nanmin(data)
ref_max = numpy.nanmax(data)
- durations['min/max'].append(time.time() - start)
+ durations["min/max"].append(time.time() - start)
start = time.time()
result = combo.min_max(data, min_positive=True)
- durations['combo'].append(time.time() - start)
+ durations["combo"].append(time.time() - start)
_logger.info(
- '%s-%s-10**%d\tx%.2f min/minpos/max',
- dtype, arange, exponent,
- durations['min/max'][-1] / durations['combo'][-1])
+ "%s-%s-10**%d\tx%.2f min/minpos/max",
+ dtype,
+ arange,
+ exponent,
+ durations["min/max"][-1] / durations["combo"][-1],
+ )
self.assertEqual(result.min_positive, ref_min_positive)
self.assertEqual(result.minimum, ref_min)
self.assertEqual(result.maximum, ref_max)
- self.show_results('min/max/min positive', durations, 'combo')
+ self.show_results("min/max/min positive", durations, "combo")
def show_results(self, title, durations, ref_key):
try:
from matplotlib import pyplot
except ImportError:
- _logger.warning('matplotlib not available')
+ _logger.warning("matplotlib not available")
return
pyplot.title(title)
- pyplot.xlabel('-'.join(self.DTYPES))
- pyplot.ylabel('duration (sec)')
+ pyplot.xlabel("-".join(self.DTYPES))
+ pyplot.ylabel("duration (sec)")
for label, values in durations.items():
pyplot.semilogy(values, label=label)
pyplot.legend()
pyplot.show()
pyplot.title(title)
- pyplot.xlabel('-'.join(self.DTYPES))
- pyplot.ylabel('Duration ratio')
+ pyplot.xlabel("-".join(self.DTYPES))
+ pyplot.ylabel("Duration ratio")
ref = numpy.array(durations[ref_key])
for label, values in durations.items():
values = numpy.array(values)
- pyplot.plot(values/ref, label=label + ' / ' + ref_key)
+ pyplot.plot(values / ref, label=label + " / " + ref_key)
pyplot.legend()
pyplot.show()
diff --git a/src/silx/math/test/histo_benchmarks.py b/src/silx/math/test/histo_benchmarks.py
index 6cc5507..051ace2 100644
--- a/src/silx/math/test/histo_benchmarks.py
+++ b/src/silx/math/test/histo_benchmarks.py
@@ -36,29 +36,22 @@ def print_times(t0s, t1s, t2s, t3s):
np_times = t2s - t1s
np_w_times = t3s - t2s
- time_txt = 'min : {0: <7.3f}; max : {1: <7.3f}; avg : {2: <7.3f}'
-
- print('\tTimes :')
- print('\tC : ' + time_txt.format(c_times.min(),
- c_times.max(),
- c_times.mean()))
- print('\tNP : ' + time_txt.format(np_times.min(),
- np_times.max(),
- np_times.mean()))
- print('\tNP(W) : ' + time_txt.format(np_w_times.min(),
- np_w_times.max(),
- np_w_times.mean()))
-
-
-def commpare_results(txt,
- times,
- result_c,
- result_np,
- result_np_w,
- sample,
- weights,
- raise_ex=False):
-
+ time_txt = "min : {0: <7.3f}; max : {1: <7.3f}; avg : {2: <7.3f}"
+
+ print("\tTimes :")
+ print("\tC : " + time_txt.format(c_times.min(), c_times.max(), c_times.mean()))
+ print(
+ "\tNP : " + time_txt.format(np_times.min(), np_times.max(), np_times.mean())
+ )
+ print(
+ "\tNP(W) : "
+ + time_txt.format(np_w_times.min(), np_w_times.max(), np_w_times.mean())
+ )
+
+
+def commpare_results(
+ txt, times, result_c, result_np, result_np_w, sample, weights, raise_ex=False
+):
if result_np:
hits_cmp = np.array_equal(result_c[0], result_np[0])
else:
@@ -69,61 +62,64 @@ def commpare_results(txt,
else:
weights_cmp = None
- if((hits_cmp is not None and not hits_cmp) or
- (weights_cmp is not None and not weights_cmp)):
- err_txt = (txt + ' : results arent the same : '
- 'hits : {0}, '
- 'weights : {1}.'
- ''.format('OK' if hits_cmp else 'NOK',
- 'OK' if weights_cmp else 'NOK'))
- print('\t' + err_txt)
+ if (hits_cmp is not None and not hits_cmp) or (
+ weights_cmp is not None and not weights_cmp
+ ):
+ err_txt = (
+ txt + " : results arent the same : "
+ "hits : {0}, "
+ "weights : {1}."
+ "".format("OK" if hits_cmp else "NOK", "OK" if weights_cmp else "NOK")
+ )
+ print("\t" + err_txt)
if raise_ex:
raise ValueError(err_txt)
return False
- result_txt = ' : results OK. c : {0: <7.3f};'.format(times[0])
+ result_txt = " : results OK. c : {0: <7.3f};".format(times[0])
if result_np or result_np_w:
- result_txt += (' np : {0: <7.3f}; '
- 'np (weights) {1: <7.3f}.'
- ''.format(times[1], times[2]))
- print('\t' + txt + result_txt)
+ result_txt += (
+ " np : {0: <7.3f}; "
+ "np (weights) {1: <7.3f}."
+ "".format(times[1], times[2])
+ )
+ print("\t" + txt + result_txt)
return True
-def benchmark(n_loops,
- sample_shape,
- sample_rng,
- weights_rng,
- histo_range,
- n_bins,
- weight_min,
- weight_max,
- last_bin_closed,
- dtype=np.double,
- do_weights=True,
- do_numpy=True):
-
+def benchmark(
+ n_loops,
+ sample_shape,
+ sample_rng,
+ weights_rng,
+ histo_range,
+ n_bins,
+ weight_min,
+ weight_max,
+ last_bin_closed,
+ dtype=np.double,
+ do_weights=True,
+ do_numpy=True,
+):
int_min = 0
int_max = 100000
- sample = np.random.randint(int_min,
- high=int_max,
- size=sample_shape).astype(np.double)
- sample = (sample_rng[0] +
- (sample - int_min) *
- (sample_rng[1] - sample_rng[0]) /
- (int_max - int_min))
+ sample = np.random.randint(int_min, high=int_max, size=sample_shape).astype(
+ np.double
+ )
+ sample = sample_rng[0] + (sample - int_min) * (sample_rng[1] - sample_rng[0]) / (
+ int_max - int_min
+ )
sample = sample.astype(dtype)
if do_weights:
- weights = np.random.randint(int_min,
- high=int_max,
- size=(ssetup.pyample_shape[0],))
+ weights = np.random.randint(
+ int_min, high=int_max, size=(ssetup.pyample_shape[0],)
+ )
weights = weights.astype(np.double)
- weights = (weights_rng[0] +
- (weights - int_min) *
- (weights_rng[1] - weights_rng[0]) /
- (int_max - int_min))
+ weights = weights_rng[0] + (weights - int_min) * (
+ weights_rng[1] - weights_rng[0]
+ ) / (int_max - int_min)
else:
weights = None
@@ -134,23 +130,22 @@ def benchmark(n_loops,
for i in range(n_loops):
t0s.append(time.time())
- result_c = histogramnd(sample,
- histo_range,
- n_bins,
- weights=weights,
- weight_min=weight_min,
- weight_max=weight_max,
- last_bin_closed=last_bin_closed)
+ result_c = histogramnd(
+ sample,
+ histo_range,
+ n_bins,
+ weights=weights,
+ weight_min=weight_min,
+ weight_max=weight_max,
+ last_bin_closed=last_bin_closed,
+ )
t1s.append(time.time())
if do_numpy:
- result_np = np.histogramdd(sample,
- bins=n_bins,
- range=histo_range)
+ result_np = np.histogramdd(sample, bins=n_bins, range=histo_range)
t2s.append(time.time())
- result_np_w = np.histogramdd(sample,
- bins=n_bins,
- range=histo_range,
- weights=weights)
+ result_np_w = np.histogramdd(
+ sample, bins=n_bins, range=histo_range, weights=weights
+ )
t3s.append(time.time())
else:
result_np = None
@@ -158,24 +153,24 @@ def benchmark(n_loops,
t2s.append(0)
t3s.append(0)
- commpare_results('Run {0}'.format(i),
- [t1s[-1] - t0s[-1], t2s[-1] - t1s[-1], t3s[-1] - t2s[-1]],
- result_c,
- result_np,
- result_np_w,
- sample,
- weights)
+ commpare_results(
+ "Run {0}".format(i),
+ [t1s[-1] - t0s[-1], t2s[-1] - t1s[-1], t3s[-1] - t2s[-1]],
+ result_c,
+ result_np,
+ result_np_w,
+ sample,
+ weights,
+ )
print_times(np.array(t0s), np.array(t1s), np.array(t2s), np.array(t3s))
-def run_benchmark(dtype=np.double,
- do_weights=True,
- do_numpy=True):
+def run_benchmark(dtype=np.double, do_weights=True, do_numpy=True):
n_loops = 5
- weights_rng = [0., 100.]
- sample_rng = [0., 100.]
+ weights_rng = [0.0, 100.0]
+ sample_rng = [0.0, 100.0]
weight_min = None
weight_max = None
@@ -187,25 +182,27 @@ def run_benchmark(dtype=np.double,
# ====================================================
# ====================================================
- print('==========================')
- print(' 1D [{0}]'.format(dtype))
- print('==========================')
+ print("==========================")
+ print(" 1D [{0}]".format(dtype))
+ print("==========================")
sample_shape = (10**7,)
- histo_range = [[0., 100.]]
+ histo_range = [[0.0, 100.0]]
n_bins = 30
- benchmark(n_loops,
- sample_shape,
- sample_rng,
- weights_rng,
- histo_range,
- n_bins,
- weight_min,
- weight_max,
- last_bin_closed,
- dtype=dtype,
- do_weights=True,
- do_numpy=do_numpy)
+ benchmark(
+ n_loops,
+ sample_shape,
+ sample_rng,
+ weights_rng,
+ histo_range,
+ n_bins,
+ weight_min,
+ weight_max,
+ last_bin_closed,
+ dtype=dtype,
+ do_weights=True,
+ do_numpy=do_numpy,
+ )
# ====================================================
# ====================================================
@@ -213,25 +210,27 @@ def run_benchmark(dtype=np.double,
# ====================================================
# ====================================================
- print('==========================')
- print(' 2D [{0}]'.format(dtype))
- print('==========================')
+ print("==========================")
+ print(" 2D [{0}]".format(dtype))
+ print("==========================")
sample_shape = (10**7, 2)
- histo_range = [[0., 100.], [0., 100.]]
+ histo_range = [[0.0, 100.0], [0.0, 100.0]]
n_bins = 30
- benchmark(n_loops,
- sample_shape,
- sample_rng,
- weights_rng,
- histo_range,
- n_bins,
- weight_min,
- weight_max,
- last_bin_closed,
- dtype=dtype,
- do_weights=True,
- do_numpy=do_numpy)
+ benchmark(
+ n_loops,
+ sample_shape,
+ sample_rng,
+ weights_rng,
+ histo_range,
+ n_bins,
+ weight_min,
+ weight_max,
+ last_bin_closed,
+ dtype=dtype,
+ do_weights=True,
+ do_numpy=do_numpy,
+ )
# ====================================================
# ====================================================
@@ -239,30 +238,35 @@ def run_benchmark(dtype=np.double,
# ====================================================
# ====================================================
- print('==========================')
- print(' 3D [{0}]'.format(dtype))
- print('==========================')
+ print("==========================")
+ print(" 3D [{0}]".format(dtype))
+ print("==========================")
sample_shape = (10**7, 3)
- histo_range = np.array([[0., 100.], [0., 100.], [0., 100.]])
+ histo_range = np.array([[0.0, 100.0], [0.0, 100.0], [0.0, 100.0]])
n_bins = 30
- benchmark(n_loops,
- sample_shape,
- sample_rng,
- weights_rng,
- histo_range,
- n_bins,
- weight_min,
- weight_max,
- last_bin_closed,
- dtype=dtype,
- do_weights=True,
- do_numpy=do_numpy)
-
-if __name__ == '__main__':
- types = (np.double, np.int32, np.float32,)
+ benchmark(
+ n_loops,
+ sample_shape,
+ sample_rng,
+ weights_rng,
+ histo_range,
+ n_bins,
+ weight_min,
+ weight_max,
+ last_bin_closed,
+ dtype=dtype,
+ do_weights=True,
+ do_numpy=do_numpy,
+ )
+
+
+if __name__ == "__main__":
+ types = (
+ np.double,
+ np.int32,
+ np.float32,
+ )
for t in types:
- run_benchmark(t,
- do_weights=True,
- do_numpy=True)
+ run_benchmark(t, do_weights=True, do_numpy=True)
diff --git a/src/silx/math/test/test_HistogramndLut_nominal.py b/src/silx/math/test/test_HistogramndLut_nominal.py
index 907a592..fba0778 100644
--- a/src/silx/math/test/test_HistogramndLut_nominal.py
+++ b/src/silx/math/test/test_HistogramndLut_nominal.py
@@ -34,10 +34,12 @@ from silx.math import HistogramndLut
def _get_bin_edges(histo_range, n_bins, n_dims):
edges = []
for i_dim in range(n_dims):
- edges.append(histo_range[i_dim, 0] +
- np.arange(n_bins[i_dim] + 1) *
- (histo_range[i_dim, 1] - histo_range[i_dim, 0]) /
- n_bins[i_dim])
+ edges.append(
+ histo_range[i_dim, 0]
+ + np.arange(n_bins[i_dim] + 1)
+ * (histo_range[i_dim, 1] - histo_range[i_dim, 0])
+ / n_bins[i_dim]
+ )
return tuple(edges)
@@ -50,6 +52,7 @@ class _TestHistogramndLut_nominal(unittest.TestCase):
"""
Unit tests of the HistogramndLut class.
"""
+
__test__ = False # ignore abstract class
ndims = None
@@ -58,22 +61,16 @@ class _TestHistogramndLut_nominal(unittest.TestCase):
ndims = self.ndims
if ndims is None:
self.skipTest("Abstract class")
- self.tested_dim = ndims-1
+ self.tested_dim = ndims - 1
if ndims is None:
- raise ValueError('ndims class member not set.')
+ raise ValueError("ndims class member not set.")
- sample = np.array([5.5, -3.3,
- 0., -0.5,
- 3.3, 8.8,
- -7.7, 6.0,
- -4.0])
+ sample = np.array([5.5, -3.3, 0.0, -0.5, 3.3, 8.8, -7.7, 6.0, -4.0])
- weights = np.array([500.5, -300.3,
- 0.01, -0.5,
- 300.3, 800.8,
- -700.7, 600.6,
- -400.4])
+ weights = np.array(
+ [500.5, -300.3, 0.01, -0.5, 300.3, 800.8, -700.7, 600.6, -400.4]
+ )
n_elems = len(sample)
@@ -86,7 +83,7 @@ class _TestHistogramndLut_nominal(unittest.TestCase):
if ndims == 1:
self.sample = sample
else:
- self.sample[..., ndims-1] = sample
+ self.sample[..., ndims - 1] = sample
self.weights = weights
@@ -97,124 +94,106 @@ class _TestHistogramndLut_nominal(unittest.TestCase):
# bin [2, y] because of the bin ranges [-2, 2] and n_bins = 4
# for the first dimension)
self.other_axes_index = 2
- self.histo_range = np.repeat([[-2., 2.]], ndims, axis=0)
- self.histo_range[ndims-1] = [-4., 6.]
+ self.histo_range = np.repeat([[-2.0, 2.0]], ndims, axis=0)
+ self.histo_range[ndims - 1] = [-4.0, 6.0]
- self.n_bins = np.array([4]*ndims)
- self.n_bins[ndims-1] = 5
+ self.n_bins = np.array([4] * ndims)
+ self.n_bins[ndims - 1] = 5
if ndims == 1:
+
def fill_histo(h, v, dim, op=None):
if op:
h[:] = op(h[:], v)
else:
h[:] = v
+
self.fill_histo = fill_histo
else:
+
def fill_histo(h, v, dim, op=None):
- idx = [self.other_axes_index]*len(h.shape)
+ idx = [self.other_axes_index] * len(h.shape)
idx[dim] = slice(0, None)
idx = tuple(idx)
if op:
h[idx] = op(h[idx], v)
else:
h[idx] = v
+
self.fill_histo = fill_histo
def test_nominal_bin_edges(self):
-
- instance = HistogramndLut(self.sample,
- self.histo_range,
- self.n_bins)
+ instance = HistogramndLut(self.sample, self.histo_range, self.n_bins)
bin_edges = instance.bins_edges
- expected_edges = _get_bin_edges(self.histo_range,
- self.n_bins,
- self.ndims)
+ expected_edges = _get_bin_edges(self.histo_range, self.n_bins, self.ndims)
for i_edges, edges in enumerate(expected_edges):
- self.assertTrue(np.array_equal(bin_edges[i_edges],
- expected_edges[i_edges]),
- msg='Testing bin_edges for dim {0}'
- ''.format(i_edges+1))
+ self.assertTrue(
+ np.array_equal(bin_edges[i_edges], expected_edges[i_edges]),
+ msg="Testing bin_edges for dim {0}" "".format(i_edges + 1),
+ )
def test_nominal_histo_range(self):
-
- instance = HistogramndLut(self.sample,
- self.histo_range,
- self.n_bins)
+ instance = HistogramndLut(self.sample, self.histo_range, self.n_bins)
histo_range = instance.histo_range
self.assertTrue(np.array_equal(histo_range, self.histo_range))
def test_nominal_last_bin_closed(self):
-
- instance = HistogramndLut(self.sample,
- self.histo_range,
- self.n_bins)
+ instance = HistogramndLut(self.sample, self.histo_range, self.n_bins)
last_bin_closed = instance.last_bin_closed
self.assertEqual(last_bin_closed, False)
- instance = HistogramndLut(self.sample,
- self.histo_range,
- self.n_bins,
- last_bin_closed=True)
+ instance = HistogramndLut(
+ self.sample, self.histo_range, self.n_bins, last_bin_closed=True
+ )
last_bin_closed = instance.last_bin_closed
self.assertEqual(last_bin_closed, True)
- instance = HistogramndLut(self.sample,
- self.histo_range,
- self.n_bins,
- last_bin_closed=False)
+ instance = HistogramndLut(
+ self.sample, self.histo_range, self.n_bins, last_bin_closed=False
+ )
last_bin_closed = instance.last_bin_closed
self.assertEqual(last_bin_closed, False)
def test_nominal_n_bins_array(self):
-
test_n_bins = np.arange(self.ndims) + 10
- instance = HistogramndLut(self.sample,
- self.histo_range,
- test_n_bins)
+ instance = HistogramndLut(self.sample, self.histo_range, test_n_bins)
n_bins = instance.n_bins
self.assertTrue(np.array_equal(test_n_bins, n_bins))
def test_nominal_n_bins_scalar(self):
-
test_n_bins = 10
expected_n_bins = np.array([test_n_bins] * self.ndims)
- instance = HistogramndLut(self.sample,
- self.histo_range,
- test_n_bins)
+ instance = HistogramndLut(self.sample, self.histo_range, test_n_bins)
n_bins = instance.n_bins
self.assertTrue(np.array_equal(expected_n_bins, n_bins))
def test_nominal_histo_ref(self):
- """
- """
+ """ """
expected_h_tpl = np.array([2, 1, 1, 1, 1])
expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 500.5])
expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims - 1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims - 1)
- instance = HistogramndLut(self.sample,
- self.histo_range,
- self.n_bins)
+ instance = HistogramndLut(self.sample, self.histo_range, self.n_bins)
instance.accumulate(self.weights)
@@ -245,20 +224,17 @@ class _TestHistogramndLut_nominal(unittest.TestCase):
self.assertTrue(np.array_equal(w_histo_2, w_histo_ref))
def test_nominal_accumulate_once(self):
- """
- """
+ """ """
expected_h_tpl = np.array([2, 1, 1, 1, 1])
expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 500.5])
expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims - 1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims - 1)
- instance = HistogramndLut(self.sample,
- self.histo_range,
- self.n_bins)
+ instance = HistogramndLut(self.sample, self.histo_range, self.n_bins)
instance.accumulate(self.weights)
@@ -270,28 +246,24 @@ class _TestHistogramndLut_nominal(unittest.TestCase):
self.assertTrue(np.array_equal(histo, expected_h))
self.assertTrue(np.array_equal(w_histo, expected_c))
self.assertTrue(np.array_equal(instance.histo(), expected_h))
- self.assertTrue(np.array_equal(instance.weighted_histo(),
- expected_c))
+ self.assertTrue(np.array_equal(instance.weighted_histo(), expected_c))
def test_nominal_accumulate_twice(self):
- """
- """
+ """ """
expected_h_tpl = np.array([2, 1, 1, 1, 1])
expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 500.5])
expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims - 1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims - 1)
# calling accumulate twice
expected_h *= 2
expected_c *= 2
- instance = HistogramndLut(self.sample,
- self.histo_range,
- self.n_bins)
+ instance = HistogramndLut(self.sample, self.histo_range, self.n_bins)
instance.accumulate(self.weights)
@@ -305,24 +277,20 @@ class _TestHistogramndLut_nominal(unittest.TestCase):
self.assertTrue(np.array_equal(histo, expected_h))
self.assertTrue(np.array_equal(w_histo, expected_c))
self.assertTrue(np.array_equal(instance.histo(), expected_h))
- self.assertTrue(np.array_equal(instance.weighted_histo(),
- expected_c))
+ self.assertTrue(np.array_equal(instance.weighted_histo(), expected_c))
def test_nominal_apply_lut_once(self):
- """
- """
+ """ """
expected_h_tpl = np.array([2, 1, 1, 1, 1])
expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 500.5])
expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims - 1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims - 1)
- instance = HistogramndLut(self.sample,
- self.histo_range,
- self.n_bins)
+ instance = HistogramndLut(self.sample, self.histo_range, self.n_bins)
histo, w_histo = instance.apply_lut(self.weights)
@@ -334,29 +302,26 @@ class _TestHistogramndLut_nominal(unittest.TestCase):
self.assertEqual(instance.weighted_histo(), None)
def test_nominal_apply_lut_twice(self):
- """
- """
+ """ """
expected_h_tpl = np.array([2, 1, 1, 1, 1])
expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 500.5])
expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims - 1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims - 1)
# calling apply_lut twice
expected_h *= 2
expected_c *= 2
- instance = HistogramndLut(self.sample,
- self.histo_range,
- self.n_bins)
+ instance = HistogramndLut(self.sample, self.histo_range, self.n_bins)
histo, w_histo = instance.apply_lut(self.weights)
- histo_2, w_histo_2 = instance.apply_lut(self.weights,
- histo=histo,
- weighted_histo=w_histo)
+ histo_2, w_histo_2 = instance.apply_lut(
+ self.weights, histo=histo, weighted_histo=w_histo
+ )
self.assertEqual(id(histo), id(histo_2))
self.assertEqual(id(w_histo), id(w_histo_2))
@@ -368,21 +333,19 @@ class _TestHistogramndLut_nominal(unittest.TestCase):
self.assertEqual(instance.weighted_histo(), None)
def test_nominal_accumulate_last_bin_closed(self):
- """
- """
+ """ """
expected_h_tpl = np.array([2, 1, 1, 1, 2])
expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 1101.1])
expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims - 1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims - 1)
- instance = HistogramndLut(self.sample,
- self.histo_range,
- self.n_bins,
- last_bin_closed=True)
+ instance = HistogramndLut(
+ self.sample, self.histo_range, self.n_bins, last_bin_closed=True
+ )
instance.accumulate(self.weights)
@@ -395,27 +358,22 @@ class _TestHistogramndLut_nominal(unittest.TestCase):
self.assertTrue(np.array_equal(w_histo, expected_c))
def test_nominal_accumulate_weight_min_max(self):
- """
- """
+ """ """
weight_min = -299.9
weight_max = 499.9
expected_h_tpl = np.array([0, 1, 1, 1, 0])
- expected_c_tpl = np.array([0., -0.5, 0.01, 300.3, 0.])
+ expected_c_tpl = np.array([0.0, -0.5, 0.01, 300.3, 0.0])
expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims - 1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims - 1)
- instance = HistogramndLut(self.sample,
- self.histo_range,
- self.n_bins)
+ instance = HistogramndLut(self.sample, self.histo_range, self.n_bins)
- instance.accumulate(self.weights,
- weight_min=weight_min,
- weight_max=weight_max)
+ instance.accumulate(self.weights, weight_min=weight_min, weight_max=weight_max)
histo = instance.histo()
w_histo = instance.weighted_histo()
@@ -435,13 +393,12 @@ class _TestHistogramndLut_nominal(unittest.TestCase):
expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims - 1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims - 1)
- instance = HistogramndLut(self.sample,
- self.histo_range,
- self.n_bins,
- dtype=np.int32)
+ instance = HistogramndLut(
+ self.sample, self.histo_range, self.n_bins, dtype=np.int32
+ )
instance.accumulate(self.weights)
@@ -458,18 +415,17 @@ class _TestHistogramndLut_nominal(unittest.TestCase):
int32 weights, float32 weighted_histogram
"""
expected_h_tpl = np.array([2, 1, 1, 1, 1])
- expected_c_tpl = np.array([-700., 0., 0., 300., 500.])
+ expected_c_tpl = np.array([-700.0, 0.0, 0.0, 300.0, 500.0])
expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
expected_c = np.zeros(shape=self.n_bins, dtype=np.float32)
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims - 1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims - 1)
- instance = HistogramndLut(self.sample,
- self.histo_range,
- self.n_bins,
- dtype=np.float32)
+ instance = HistogramndLut(
+ self.sample, self.histo_range, self.n_bins, dtype=np.float32
+ )
instance.accumulate(self.weights.astype(np.int32))
@@ -491,12 +447,10 @@ class _TestHistogramndLut_nominal(unittest.TestCase):
expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
expected_c = np.zeros(shape=self.n_bins, dtype=np.int32)
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims - 1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims - 1)
- instance = HistogramndLut(self.sample,
- self.histo_range,
- self.n_bins)
+ instance = HistogramndLut(self.sample, self.histo_range, self.n_bins)
instance.accumulate(self.weights.astype(np.int32))
@@ -518,12 +472,10 @@ class _TestHistogramndLut_nominal(unittest.TestCase):
expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
expected_c = np.zeros(shape=self.n_bins, dtype=np.int32)
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims - 1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims - 1)
- instance = HistogramndLut(self.sample,
- self.histo_range,
- self.n_bins)
+ instance = HistogramndLut(self.sample, self.histo_range, self.n_bins)
instance.accumulate(self.weights.astype(np.int32))
instance.accumulate(self.weights)
@@ -546,13 +498,9 @@ class _TestHistogramndLut_nominal(unittest.TestCase):
type = self.sample.dtype.newbyteorder("L")
sampleL = self.sample.astype(type)
- histo_inst = HistogramndLut(sampleB,
- self.histo_range,
- self.n_bins)
+ histo_inst = HistogramndLut(sampleB, self.histo_range, self.n_bins)
- histo_inst = HistogramndLut(sampleL,
- self.histo_range,
- self.n_bins)
+ histo_inst = HistogramndLut(sampleL, self.histo_range, self.n_bins)
class TestHistogramndLut_nominal_1d(_TestHistogramndLut_nominal):
diff --git a/src/silx/math/test/test_calibration.py b/src/silx/math/test/test_calibration.py
index 1c961be..27c4c57 100644
--- a/src/silx/math/test/test_calibration.py
+++ b/src/silx/math/test/test_calibration.py
@@ -1,5 +1,5 @@
# /*##########################################################################
-# Copyright (C) 2018 European Synchrotron Radiation Facility
+# Copyright (C) 2018-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -31,8 +31,12 @@ import unittest
import numpy
-from silx.math.calibration import NoCalibration, LinearCalibration, \
- ArrayCalibration, FunctionCalibration
+from silx.math.calibration import (
+ NoCalibration,
+ LinearCalibration,
+ ArrayCalibration,
+ FunctionCalibration,
+)
X = numpy.array([3.14, 2.73, 1337])
@@ -46,11 +50,10 @@ class TestNoCalibration(unittest.TestCase):
self.assertTrue(self.calib.is_affine())
def testSlope(self):
- self.assertEqual(self.calib.get_slope(), 1.)
+ self.assertEqual(self.calib.get_slope(), 1.0)
def testYIntercept(self):
- self.assertEqual(self.calib(0.),
- 0.)
+ self.assertEqual(self.calib(0.0), 0.0)
def testCall(self):
self.assertTrue(numpy.array_equal(self.calib(X), X))
@@ -60,8 +63,7 @@ class TestLinearCalibration(unittest.TestCase):
def setUp(self):
self.y_intercept = 1.5
self.slope = 2.5
- self.calib = LinearCalibration(y_intercept=self.y_intercept,
- slope=self.slope)
+ self.calib = LinearCalibration(y_intercept=self.y_intercept, slope=self.slope)
def testIsAffine(self):
self.assertTrue(self.calib.is_affine())
@@ -70,17 +72,17 @@ class TestLinearCalibration(unittest.TestCase):
self.assertEqual(self.calib.get_slope(), self.slope)
def testYIntercept(self):
- self.assertEqual(self.calib(0.),
- self.y_intercept)
+ self.assertEqual(self.calib(0.0), self.y_intercept)
def testCall(self):
- self.assertTrue(numpy.array_equal(self.calib(X),
- self.y_intercept + self.slope * X))
+ self.assertTrue(
+ numpy.array_equal(self.calib(X), self.y_intercept + self.slope * X)
+ )
class TestArrayCalibration(unittest.TestCase):
def setUp(self):
- self.arr = numpy.array([45.2, 25.3, 666., -8.])
+ self.arr = numpy.array([45.2, 25.3, 666.0, -8.0])
self.calib = ArrayCalibration(self.arr)
self.affine_calib = ArrayCalibration([0.1, 0.2, 0.3])
@@ -91,12 +93,10 @@ class TestArrayCalibration(unittest.TestCase):
def testSlope(self):
with self.assertRaises(AttributeError):
self.calib.get_slope()
- self.assertEqual(self.affine_calib.get_slope(),
- 0.1)
+ self.assertEqual(self.affine_calib.get_slope(), 0.1)
def testYIntercept(self):
- self.assertEqual(self.calib(0),
- self.arr[0])
+ self.assertEqual(self.calib(0), self.arr[0])
def testCall(self):
with self.assertRaises(ValueError):
@@ -107,22 +107,27 @@ class TestArrayCalibration(unittest.TestCase):
# floats are not valid indices
self.calib(3.14)
- self.assertTrue(
- numpy.array_equal(self.calib([1, 2, 3, 4]),
- self.arr))
+ self.assertTrue(numpy.array_equal(self.calib([1, 2, 3, 4]), self.arr))
for idx, value in enumerate(self.arr):
self.assertEqual(self.calib(idx), value)
+ def testEmptyArray(self):
+ with self.assertRaises(ValueError):
+ ArrayCalibration(numpy.array([]))
+
+ def testOneElementArray(self):
+ calib = ArrayCalibration(numpy.array([1]))
+ self.assertFalse(calib.is_affine())
+
class TestFunctionCalibration(unittest.TestCase):
def setUp(self):
self.non_affine_fun = numpy.sin
self.non_affine_calib = FunctionCalibration(self.non_affine_fun)
- self.affine_fun = lambda x: 52. * x + 0.01
- self.affine_calib = FunctionCalibration(self.affine_fun,
- is_affine=True)
+ self.affine_fun = lambda x: 52.0 * x + 0.01
+ self.affine_calib = FunctionCalibration(self.affine_fun, is_affine=True)
def testIsAffine(self):
self.assertFalse(self.non_affine_calib.is_affine())
@@ -131,12 +136,9 @@ class TestFunctionCalibration(unittest.TestCase):
def testSlope(self):
with self.assertRaises(AttributeError):
self.non_affine_calib.get_slope()
- self.assertAlmostEqual(self.affine_calib.get_slope(),
- 52.)
+ self.assertAlmostEqual(self.affine_calib.get_slope(), 52.0)
def testCall(self):
for x in X:
- self.assertAlmostEqual(self.non_affine_calib(x),
- self.non_affine_fun(x))
- self.assertAlmostEqual(self.affine_calib(x),
- self.affine_fun(x))
+ self.assertAlmostEqual(self.non_affine_calib(x), self.non_affine_fun(x))
+ self.assertAlmostEqual(self.affine_calib(x), self.affine_fun(x))
diff --git a/src/silx/math/test/test_colormap.py b/src/silx/math/test/test_colormap.py
index 144ee5f..4d09f0d 100644
--- a/src/silx/math/test/test_colormap.py
+++ b/src/silx/math/test/test_colormap.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2018-2021 European Synchrotron Radiation Facility
+# Copyright (c) 2018-2022 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -32,6 +32,7 @@ import logging
import sys
import numpy
+import pytest
from silx.utils.testutils import ParametricTestCase
from silx.math import colormap
@@ -45,20 +46,23 @@ class TestNormalization(ParametricTestCase):
def _testCodec(self, normalization, rtol=1e-5):
"""Test apply/revert for normalizations"""
- test_data = (numpy.arange(1, 10, dtype=numpy.int32),
- numpy.linspace(1., 100., 1000, dtype=numpy.float32),
- numpy.linspace(-1., 1., 100, dtype=numpy.float32),
- 1.,
- 1)
+ test_data = (
+ numpy.arange(1, 10, dtype=numpy.int32),
+ numpy.linspace(1.0, 100.0, 1000, dtype=numpy.float32),
+ numpy.linspace(-1.0, 1.0, 100, dtype=numpy.float32),
+ 1.0,
+ 1,
+ )
for index in range(len(test_data)):
with self.subTest(normalization=normalization, data_index=index):
data = test_data[index]
- normalized = normalization.apply(data, 1., 100.)
- result = normalization.revert(normalized, 1., 100.)
+ normalized = normalization.apply(data, 1.0, 100.0)
+ result = normalization.revert(normalized, 1.0, 100.0)
- self.assertTrue(numpy.array_equal(
- numpy.isnan(normalized), numpy.isnan(result)))
+ self.assertTrue(
+ numpy.array_equal(numpy.isnan(normalized), numpy.isnan(result))
+ )
if isinstance(data, numpy.ndarray):
notNaN = numpy.logical_not(numpy.isnan(result))
@@ -78,10 +82,10 @@ class TestNormalization(ParametricTestCase):
self._testCodec(normalization, rtol=1e-3)
# Specific extra tests
- self.assertTrue(numpy.isnan(normalization.apply(-1., 1., 100.)))
- self.assertTrue(numpy.isnan(normalization.apply(numpy.nan, 1., 100.)))
- self.assertEqual(normalization.apply(numpy.inf, 1., 100.), numpy.inf)
- self.assertEqual(normalization.apply(0, 1., 100.), - numpy.inf)
+ self.assertTrue(numpy.isnan(normalization.apply(-1.0, 1.0, 100.0)))
+ self.assertTrue(numpy.isnan(normalization.apply(numpy.nan, 1.0, 100.0)))
+ self.assertEqual(normalization.apply(numpy.inf, 1.0, 100.0), numpy.inf)
+ self.assertEqual(normalization.apply(0, 1.0, 100.0), -numpy.inf)
def testArcsinhNormalization(self):
"""Test for ArcsinhNormalization"""
@@ -93,24 +97,25 @@ class TestNormalization(ParametricTestCase):
self._testCodec(normalization)
# Specific extra tests
- self.assertTrue(numpy.isnan(normalization.apply(-1., 0., 100.)))
- self.assertTrue(numpy.isnan(normalization.apply(numpy.nan, 0., 100.)))
- self.assertEqual(normalization.apply(numpy.inf, 0., 100.), numpy.inf)
- self.assertEqual(normalization.apply(0, 0., 100.), 0.)
+ self.assertTrue(numpy.isnan(normalization.apply(-1.0, 0.0, 100.0)))
+ self.assertTrue(numpy.isnan(normalization.apply(numpy.nan, 0.0, 100.0)))
+ self.assertEqual(normalization.apply(numpy.inf, 0.0, 100.0), numpy.inf)
+ self.assertEqual(normalization.apply(0, 0.0, 100.0), 0.0)
class TestColormap(ParametricTestCase):
"""Test silx.math.colormap.cmap"""
NORMALIZATIONS = (
- 'linear',
- 'log',
- 'arcsinh',
- 'sqrt',
+ "linear",
+ "log",
+ "arcsinh",
+ "sqrt",
colormap.LinearNormalization(),
colormap.LogarithmicNormalization(),
- colormap.GammaNormalization(2.),
- colormap.GammaNormalization(0.5))
+ colormap.GammaNormalization(2.0),
+ colormap.GammaNormalization(0.5),
+ )
@staticmethod
def ref_colormap(data, colors, vmin, vmax, normalization, nan_color):
@@ -123,22 +128,25 @@ class TestColormap(ParametricTestCase):
:param str normalization: Normalization to use
:param Union[numpy.ndarray, None] nan_color: Color to use for NaN
"""
- norm_functions = {'linear': lambda v: v,
- 'log': numpy.log10,
- 'arcsinh': numpy.arcsinh,
- 'sqrt': numpy.sqrt}
+ norm_functions = {
+ "linear": lambda v: v,
+ "log": numpy.log10,
+ "arcsinh": numpy.arcsinh,
+ "sqrt": numpy.sqrt,
+ }
if isinstance(normalization, str):
norm_function = norm_functions[normalization]
else:
+
def norm_function(value):
return normalization.apply(value, vmin, vmax)
- with numpy.errstate(divide='ignore', invalid='ignore'):
+ with numpy.errstate(divide="ignore", invalid="ignore"):
# Ignore divide by zero and invalid value encountered in log10, sqrt
norm_data, vmin, vmax = map(norm_function, (data, vmin, vmax))
- if normalization == 'arcsinh' and sys.platform == 'win32':
+ if normalization == "arcsinh" and sys.platform == "win32":
# There is a difference of behavior of numpy.arcsinh
# between Windows and other OS for results of infinite values
# This makes Windows behaves as Linux and MacOS
@@ -149,10 +157,9 @@ class TestColormap(ParametricTestCase):
scale = nb_colors / (vmax - vmin)
# Substraction must be done in float to avoid overflow with uint
- indices = numpy.clip(scale * (norm_data - float(vmin)),
- 0, nb_colors - 1)
+ indices = numpy.clip(scale * (norm_data - float(vmin)), 0, nb_colors - 1)
indices[numpy.isnan(indices)] = nb_colors # Use an extra index for NaN
- indices = indices.astype('uint')
+ indices = indices.astype("uint")
# Add NaN color to array
if nan_color is None:
@@ -171,11 +178,11 @@ class TestColormap(ParametricTestCase):
:param str normalization: Normalization to use
:param Union[numpy.ndarray, None] nan_color: Color to use for NaN
"""
- image = colormap.cmap(
- data, colors, vmin, vmax, normalization, nan_color)
+ image = colormap.cmap(data, colors, vmin, vmax, normalization, nan_color)
ref_image = self.ref_colormap(
- data, colors, vmin, vmax, normalization, nan_color)
+ data, colors, vmin, vmax, normalization, nan_color
+ )
self.assertTrue(numpy.allclose(ref_image, image))
self.assertEqual(image.dtype, colors.dtype)
@@ -191,16 +198,20 @@ class TestColormap(ParametricTestCase):
colors[:, 3] = 255
# Generates (u)int and floats types
- dtypes = [e + k + i for e in '<>' for k in 'uif' for i in '1248'
- if k != 'f' or i != '1']
+ dtypes = [
+ e + k + i
+ for e in "<>"
+ for k in "uif"
+ for i in "1248"
+ if k != "f" or i != "1"
+ ]
dtypes.append(numpy.dtype(numpy.longdouble).name) # Add long double
for normalization in self.NORMALIZATIONS:
for dtype in dtypes:
with self.subTest(dtype=dtype, normalization=normalization):
- _logger.info('normalization: %s, dtype: %s',
- normalization, dtype)
- data = numpy.arange(-5, 15, dtype=dtype).reshape(4, 5)
+ _logger.info("normalization: %s, dtype: %s", normalization, dtype)
+ data = numpy.arange(-5, 15).astype(dtype).reshape(4, 5)
self._test(data, colors, 1, 10, normalization, None)
@@ -211,21 +222,20 @@ class TestColormap(ParametricTestCase):
colors[:, 3] = 255
test_data = { # message: data
- 'no finite values': (float('inf'), float('-inf'), float('nan')),
- 'only NaN': (float('nan'), float('nan'), float('nan')),
- 'mix finite/not finite': (float('inf'), float('-inf'), 1., float('nan')),
+ "no finite values": (float("inf"), float("-inf"), float("nan")),
+ "only NaN": (float("nan"), float("nan"), float("nan")),
+ "mix finite/not finite": (float("inf"), float("-inf"), 1.0, float("nan")),
}
for normalization in self.NORMALIZATIONS:
for msg, data in test_data.items():
with self.subTest(msg, normalization=normalization):
- _logger.info('normalization: %s, %s', normalization, msg)
+ _logger.info("normalization: %s, %s", normalization, msg)
data = numpy.array(data, dtype=numpy.float64)
self._test(data, colors, 1, 10, normalization, (0, 0, 0, 0))
def test_errors(self):
- """Test raising exception for bad vmin, vmax, normalization parameters
- """
+ """Test raising exception for bad vmin, vmax, normalization parameters"""
colors = numpy.zeros((256, 4), dtype=numpy.uint8)
colors[:, 0] = numpy.arange(len(colors))
colors[:, 3] = 255
@@ -233,18 +243,18 @@ class TestColormap(ParametricTestCase):
data = numpy.arange(10, dtype=numpy.float64)
test_params = [ # (vmin, vmax, normalization)
- (-1., 2., 'log'),
- (0., 1., 'log'),
- (1., 0., 'log'),
- (-1., 1., 'sqrt'),
- (1., -1., 'sqrt'),
+ (-1.0, 2.0, "log"),
+ (0.0, 1.0, "log"),
+ (1.0, 0.0, "log"),
+ (-1.0, 1.0, "sqrt"),
+ (1.0, -1.0, "sqrt"),
]
for vmin, vmax, normalization in test_params:
- with self.subTest(
- vmin=vmin, vmax=vmax, normalization=normalization):
- _logger.info('normalization: %s, range: [%f, %f]',
- normalization, vmin, vmax)
+ with self.subTest(vmin=vmin, vmax=vmax, normalization=normalization):
+ _logger.info(
+ "normalization: %s, range: [%f, %f]", normalization, vmin, vmax
+ )
with self.assertRaises(ValueError):
self._test(data, colors, vmin, vmax, normalization, None)
@@ -262,5 +272,35 @@ def test_apply_colormap():
autoscale="minmax",
vmin=None,
vmax=None,
- gamma=1.0)
+ gamma=1.0,
+ )
assert numpy.array_equal(colors, expected_colors)
+
+
+testdata_normalize = [
+ (numpy.arange(512), numpy.arange(512) // 2, 0, 511),
+ ((numpy.nan, numpy.inf, -numpy.inf), (0, 255, 0), 0, 1),
+ ((numpy.nan, numpy.inf, -numpy.inf, 1), (0, 255, 0, 0), 1, 1),
+]
+
+
+@pytest.mark.parametrize(
+ "data,expected_data,expected_vmin,expected_vmax",
+ testdata_normalize,
+)
+def test_normalize(data, expected_data, expected_vmin, expected_vmax):
+ """Basic test of silx.math.colormap.normalize"""
+ result = colormap.normalize(
+ numpy.asarray(data),
+ norm="linear",
+ autoscale="minmax",
+ vmin=None,
+ vmax=None,
+ gamma=1.0,
+ )
+ assert result.vmin == expected_vmin
+ assert result.vmax == expected_vmax
+ assert numpy.array_equal(
+ result.data,
+ numpy.asarray(expected_data, dtype=numpy.uint8),
+ )
diff --git a/src/silx/math/test/test_combo.py b/src/silx/math/test/test_combo.py
index eed0625..917be55 100644
--- a/src/silx/math/test/test_combo.py
+++ b/src/silx/math/test/test_combo.py
@@ -1,5 +1,5 @@
# /*##########################################################################
-# Copyright (C) 2016-2020 European Synchrotron Radiation Facility
+# Copyright (C) 2016-2022 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -27,8 +27,6 @@ __license__ = "MIT"
__date__ = "17/01/2018"
-import unittest
-
import numpy
from silx.utils.testutils import ParametricTestCase
@@ -39,11 +37,11 @@ from silx.math.combo import min_max
class TestMinMax(ParametricTestCase):
"""Tests of min max combo"""
- FLOATING_DTYPES = 'float32', 'float64'
+ FLOATING_DTYPES = "float32", "float64"
if hasattr(numpy, "float128"):
- FLOATING_DTYPES += ('float128',)
- SIGNED_INT_DTYPES = 'int8', 'int16', 'int32', 'int64'
- UNSIGNED_INT_DTYPES = 'uint8', 'uint16', 'uint32', 'uint64'
+ FLOATING_DTYPES += ("float128",)
+ SIGNED_INT_DTYPES = "int8", "int16", "int32", "int64"
+ UNSIGNED_INT_DTYPES = "uint8", "uint16", "uint32", "uint64"
DTYPES = FLOATING_DTYPES + SIGNED_INT_DTYPES + UNSIGNED_INT_DTYPES
def _numpy_min_max(self, data, min_positive=False, finite=False):
@@ -55,7 +53,7 @@ class TestMinMax(ParametricTestCase):
"""
data = numpy.array(data, copy=False)
if data.size == 0:
- raise ValueError('Zero-sized array')
+ raise ValueError("Zero-sized array")
minimum = None
argmin = None
@@ -84,7 +82,7 @@ class TestMinMax(ParametricTestCase):
argmax = numpy.where(data == maximum)[0][0]
if min_positive:
- with numpy.errstate(invalid='ignore'):
+ with numpy.errstate(invalid="ignore"):
# Ignore invalid value encountered in greater
pos_data = filtered_data[filtered_data > 0]
if pos_data.size > 0:
@@ -100,8 +98,9 @@ class TestMinMax(ParametricTestCase):
:param bool min_positive: True to test with positive min
:param bool finite: True to only test finite values
"""
- minimum, min_pos, maximum, argmin, argmin_pos, argmax = \
- self._numpy_min_max(data, min_positive, finite)
+ minimum, min_pos, maximum, argmin, argmin_pos, argmax = self._numpy_min_max(
+ data, min_positive, finite
+ )
result = min_max(data, min_positive, finite)
@@ -114,30 +113,28 @@ class TestMinMax(ParametricTestCase):
def assertSimilar(self, a, b):
"""Assert that a and b are both None or NaN or that a == b."""
- self.assertTrue((a is None and b is None) or
- (numpy.isnan(a) and numpy.isnan(b)) or
- a == b)
+ self.assertTrue(
+ (a is None and b is None) or (numpy.isnan(a) and numpy.isnan(b)) or a == b
+ )
def test_different_datasets(self):
"""Test min_max with different numpy.arange datasets."""
size = 1000
for dtype in self.DTYPES:
-
- tests = {
- '0 to N': (0, 1),
- 'N-1 to 0': (size - 1, -1)}
+ tests = {"0 to N": (0, 1), "N-1 to 0": (size - 1, -1)}
if dtype not in self.UNSIGNED_INT_DTYPES:
- tests['N/2 to -N/2'] = size // 2, -1
- tests['0 to -N'] = 0, -1
+ tests["N/2 to -N/2"] = size // 2, -1
+ tests["0 to -N"] = 0, -1
for name, (start, step) in tests.items():
for min_positive in (True, False):
- with self.subTest(dtype=dtype,
- min_positive=min_positive,
- data=name):
- data = numpy.arange(
- start, start + step * size, step, dtype=dtype)
+ with self.subTest(
+ dtype=dtype, min_positive=min_positive, data=name
+ ):
+ data = numpy.arange(start, start + step * size, step).astype(
+ dtype
+ )
self._test_min_max(data, min_positive)
@@ -147,18 +144,18 @@ class TestMinMax(ParametricTestCase):
with self.subTest(dtype=dtype):
with self.assertRaises(TypeError):
min_max(None)
-
+
data = numpy.array((), dtype=dtype)
with self.assertRaises(ValueError):
min_max(data)
NAN_TEST_DATA = [
- (float('nan'), float('nan')), # All NaNs
- (float('nan'), 1.0), # NaN first and positive
- (float('nan'), -1.0), # NaN first and negative
- (1.0, 2.0, float('nan')), # NaN last and positive
- (-1.0, -2.0, float('nan')), # NaN last and negative
- (1.0, float('nan'), -1.0), # Some NaN
+ (float("nan"), float("nan")), # All NaNs
+ (float("nan"), 1.0), # NaN first and positive
+ (float("nan"), -1.0), # NaN first and negative
+ (1.0, 2.0, float("nan")), # NaN last and positive
+ (-1.0, -2.0, float("nan")), # NaN last and negative
+ (1.0, float("nan"), -1.0), # Some NaN
]
def test_nandata(self):
@@ -170,12 +167,12 @@ class TestMinMax(ParametricTestCase):
self._test_min_max(data, min_positive=True)
INF_TEST_DATA = [
- [float('inf')] * 3, # All +inf
- [float('-inf')] * 3, # All -inf
- (float('inf'), float('-inf')), # + and - inf
- (float('inf'), float('-inf'), float('nan')), # +/-inf, nan last
- (float('nan'), float('-inf'), float('inf')), # +/-inf, nan first
- (float('inf'), float('nan'), float('-inf')), # +/-inf, nan center
+ [float("inf")] * 3, # All +inf
+ [float("-inf")] * 3, # All -inf
+ (float("inf"), float("-inf")), # + and - inf
+ (float("inf"), float("-inf"), float("nan")), # +/-inf, nan last
+ (float("nan"), float("-inf"), float("inf")), # +/-inf, nan first
+ (float("inf"), float("nan"), float("-inf")), # +/-inf, nan center
]
def test_infdata(self):
@@ -189,10 +186,10 @@ class TestMinMax(ParametricTestCase):
def test_finite(self):
"""Test min_max with finite=True"""
tests = [
- (-1., 2., 0.), # Basic test
- (float('nan'), float('inf'), float('-inf')), # NaN + Inf
- (float('nan'), float('inf'), -2, float('-inf')), # NaN + Inf + 1 value
- (float('inf'), -3, -2), # values + inf
+ (-1.0, 2.0, 0.0), # Basic test
+ (float("nan"), float("inf"), float("-inf")), # NaN + Inf
+ (float("nan"), float("inf"), -2, float("-inf")), # NaN + Inf + 1 value
+ (float("inf"), -3, -2), # values + inf
]
tests += self.INF_TEST_DATA
tests += self.NAN_TEST_DATA
diff --git a/src/silx/math/test/test_histogramnd_error.py b/src/silx/math/test/test_histogramnd_error.py
index d01cab9..c640b4a 100644
--- a/src/silx/math/test/test_histogramnd_error.py
+++ b/src/silx/math/test/test_histogramnd_error.py
@@ -1,5 +1,5 @@
# /*##########################################################################
-# Copyright (C) 2016 European Synchrotron Radiation Facility
+# Copyright (C) 2016-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -28,8 +28,6 @@ __date__ = "01/02/2016"
"""
Tests of the histogramnd function, error cases.
"""
-import sys
-import platform
import unittest
import numpy as np
@@ -47,61 +45,61 @@ class _Test_chistogramnd_errors(unittest.TestCase):
"""
Unit tests of the chistogramnd error cases.
"""
+
__test__ = False # ignore abstract class
def setUp(self):
self.skipTest("Abstract class")
def test_weights_shape(self):
- """
- """
+ """ """
for err_w_shape in self.err_weights_shapes:
- test_msg = ('Testing invalid weights shape : {0}'
- ''.format(err_w_shape))
+ test_msg = "Testing invalid weights shape : {0}" "".format(err_w_shape)
- err_weights = np.random.randint(0,
- high=10,
- size=err_w_shape)
+ err_weights = np.random.randint(0, high=10, size=err_w_shape)
err_weights = err_weights.astype(np.double)
ex_str = None
try:
- histo, cumul = histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=err_weights)[0:2]
+ histo, cumul = histogramnd(
+ self.sample, self.histo_range, self.n_bins, weights=err_weights
+ )[0:2]
except ValueError as ex:
ex_str = str(ex)
self.assertIsNotNone(ex_str, msg=test_msg)
- self.assertEqual(ex_str,
- '<weights> must be an array whose length '
- 'is equal to the number of samples.')
+ self.assertEqual(
+ ex_str,
+ "<weights> must be an array whose length "
+ "is equal to the number of samples.",
+ )
def test_histo_range_shape(self):
- """
- """
+ """ """
n_dims = 1 if len(self.s_shape) == 1 else self.s_shape[1]
- expected_txt_tpl = ('<histo_range> error : expected {n_dims} sets '
- 'of lower and upper bin edges, '
- 'got the following instead : {histo_range}. '
- '(provided <sample> contains '
- '{n_dims}D values)')
+ expected_txt_tpl = (
+ "<histo_range> error : expected {n_dims} sets "
+ "of lower and upper bin edges, "
+ "got the following instead : {histo_range}. "
+ "(provided <sample> contains "
+ "{n_dims}D values)"
+ )
for err_histo_range in self.err_histo_range_shapes:
- test_msg = ('Testing invalid histo_range shape : {0}'
- ''.format(err_histo_range))
+ test_msg = "Testing invalid histo_range shape : {0}" "".format(
+ err_histo_range
+ )
- expected_txt = expected_txt_tpl.format(histo_range=err_histo_range,
- n_dims=n_dims)
+ expected_txt = expected_txt_tpl.format(
+ histo_range=err_histo_range, n_dims=n_dims
+ )
ex_str = None
try:
- histo, cumul = histogramnd(self.sample,
- err_histo_range,
- self.n_bins,
- weights=self.weights)[0:2]
+ histo, cumul = histogramnd(
+ self.sample, err_histo_range, self.n_bins, weights=self.weights
+ )[0:2]
except ValueError as ex:
ex_str = str(ex)
@@ -109,24 +107,23 @@ class _Test_chistogramnd_errors(unittest.TestCase):
self.assertEqual(ex_str, expected_txt, msg=test_msg)
def test_nbins_shape(self):
- """
- """
+ """ """
- expected_txt = ('n_bins must be either a scalar (same number '
- 'of bins for all dimensions) or '
- 'an array (number of bins for each '
- 'dimension).')
+ expected_txt = (
+ "n_bins must be either a scalar (same number "
+ "of bins for all dimensions) or "
+ "an array (number of bins for each "
+ "dimension)."
+ )
for err_n_bins in self.err_n_bins_shapes:
- test_msg = ('Testing invalid n_bins shape : {0}'
- ''.format(err_n_bins))
+ test_msg = "Testing invalid n_bins shape : {0}" "".format(err_n_bins)
ex_str = None
try:
- histo, cumul = histogramnd(self.sample,
- self.histo_range,
- err_n_bins,
- weights=self.weights)[0:2]
+ histo, cumul = histogramnd(
+ self.sample, self.histo_range, err_n_bins, weights=self.weights
+ )[0:2]
except ValueError as ex:
ex_str = str(ex)
@@ -134,20 +131,17 @@ class _Test_chistogramnd_errors(unittest.TestCase):
self.assertEqual(ex_str, expected_txt, msg=test_msg)
def test_nbins_values(self):
- """
- """
- expected_txt = ('<n_bins> : only positive values allowed.')
+ """ """
+ expected_txt = "<n_bins> : only positive values allowed."
for err_n_bins in self.err_n_bins_values:
- test_msg = ('Testing invalid n_bins value : {0}'
- ''.format(err_n_bins))
+ test_msg = "Testing invalid n_bins value : {0}" "".format(err_n_bins)
ex_str = None
try:
- histo, cumul = histogramnd(self.sample,
- self.histo_range,
- err_n_bins,
- weights=self.weights)[0:2]
+ histo, cumul = histogramnd(
+ self.sample, self.histo_range, err_n_bins, weights=self.weights
+ )[0:2]
except ValueError as ex:
ex_str = str(ex)
@@ -155,33 +149,28 @@ class _Test_chistogramnd_errors(unittest.TestCase):
self.assertEqual(ex_str, expected_txt, msg=test_msg)
def test_histo_shape(self):
- """
- """
+ """ """
for err_h_shape in self.err_histo_shapes:
+ test_msg = "Testing invalid histo shape : {0}" "".format(err_h_shape)
- # windows & python 2.7 : numpy shapes are long values
- if platform.system() == 'Windows':
- version = (sys.version_info.major, sys.version_info.minor)
- if version <= (2, 7):
- err_h_shape = tuple([long(val) for val in err_h_shape])
-
- test_msg = ('Testing invalid histo shape : {0}'
- ''.format(err_h_shape))
-
- expected_txt = ('Provided <histo> array doesn\'t have '
- 'a shape compatible with <n_bins> '
- ': should be {0} instead of {1}.'
- ''.format(self.h_shape, err_h_shape))
+ expected_txt = (
+ "Provided <histo> array doesn't have "
+ "a shape compatible with <n_bins> "
+ ": should be {0} instead of {1}."
+ "".format(self.h_shape, err_h_shape)
+ )
histo = np.zeros(shape=err_h_shape, dtype=np.uint32)
ex_str = None
try:
- histo, cumul = histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights,
- histo=histo)[0:2]
+ histo, cumul = histogramnd(
+ self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights,
+ histo=histo,
+ )[0:2]
except ValueError as ex:
ex_str = str(ex)
@@ -189,26 +178,28 @@ class _Test_chistogramnd_errors(unittest.TestCase):
self.assertEqual(ex_str, expected_txt, msg=test_msg)
def test_histo_dtype(self):
- """
- """
+ """ """
for err_h_dtype in self.err_histo_dtypes:
- test_msg = ('Testing invalid histo dtype : {0}'
- ''.format(err_h_dtype))
+ test_msg = "Testing invalid histo dtype : {0}" "".format(err_h_dtype)
histo = np.zeros(shape=self.h_shape, dtype=err_h_dtype)
- expected_txt = ('Provided <histo> array doesn\'t have '
- 'the expected type '
- ': should be {0} instead of {1}.'
- ''.format(np.uint32, histo.dtype))
+ expected_txt = (
+ "Provided <histo> array doesn't have "
+ "the expected type "
+ ": should be {0} instead of {1}."
+ "".format(np.uint32, histo.dtype)
+ )
ex_str = None
try:
- histo, cumul = histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights,
- histo=histo)[0:2]
+ histo, cumul = histogramnd(
+ self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights,
+ histo=histo,
+ )[0:2]
except ValueError as ex:
ex_str = str(ex)
@@ -216,34 +207,31 @@ class _Test_chistogramnd_errors(unittest.TestCase):
self.assertEqual(ex_str, expected_txt, msg=test_msg)
def test_weighted_histo_shape(self):
- """
- """
+ """ """
# using the same values as histo
for err_h_shape in self.err_histo_shapes:
+ test_msg = "Testing invalid weighted_histo shape : {0}" "".format(
+ err_h_shape
+ )
- # windows & python 2.7 : numpy shapes are long values
- if platform.system() == 'Windows':
- version = (sys.version_info.major, sys.version_info.minor)
- if version <= (2, 7):
- err_h_shape = tuple([long(val) for val in err_h_shape])
-
- test_msg = ('Testing invalid weighted_histo shape : {0}'
- ''.format(err_h_shape))
-
- expected_txt = ('Provided <weighted_histo> array doesn\'t have '
- 'a shape compatible with <n_bins> '
- ': should be {0} instead of {1}.'
- ''.format(self.h_shape, err_h_shape))
+ expected_txt = (
+ "Provided <weighted_histo> array doesn't have "
+ "a shape compatible with <n_bins> "
+ ": should be {0} instead of {1}."
+ "".format(self.h_shape, err_h_shape)
+ )
cumul = np.zeros(shape=err_h_shape, dtype=np.double)
ex_str = None
try:
- histo, cumul = histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights,
- weighted_histo=cumul)[0:2]
+ histo, cumul = histogramnd(
+ self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights,
+ weighted_histo=cumul,
+ )[0:2]
except ValueError as ex:
ex_str = str(ex)
@@ -251,51 +239,54 @@ class _Test_chistogramnd_errors(unittest.TestCase):
self.assertEqual(ex_str, expected_txt, msg=test_msg)
def test_cumul_dtype(self):
- """
- """
+ """ """
# using the same values as histo
for err_h_dtype in self.err_histo_dtypes:
- test_msg = ('Testing invalid weighted_histo dtype : {0}'
- ''.format(err_h_dtype))
+ test_msg = "Testing invalid weighted_histo dtype : {0}" "".format(
+ err_h_dtype
+ )
cumul = np.zeros(shape=self.h_shape, dtype=err_h_dtype)
- expected_txt = ('Provided <weighted_histo> array doesn\'t have '
- 'the expected type '
- ': should be {0} or {1} instead of {2}.'
- ''.format(np.float64, np.float32, cumul.dtype))
+ expected_txt = (
+ "Provided <weighted_histo> array doesn't have "
+ "the expected type "
+ ": should be {0} or {1} instead of {2}."
+ "".format(np.float64, np.float32, cumul.dtype)
+ )
ex_str = None
try:
- histo, cumul = histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights,
- weighted_histo=cumul)[0:2]
+ histo, cumul = histogramnd(
+ self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights,
+ weighted_histo=cumul,
+ )[0:2]
except ValueError as ex:
ex_str = str(ex)
self.assertIsNotNone(ex_str, msg=test_msg)
self.assertEqual(ex_str, expected_txt, msg=test_msg)
-
+
def test_wh_histo_dtype(self):
- """
- """
+ """ """
# using the same values as histo
for err_h_dtype in self.err_histo_dtypes:
- test_msg = ('Testing invalid wh_dtype dtype : {0}'
- ''.format(err_h_dtype))
+ test_msg = "Testing invalid wh_dtype dtype : {0}" "".format(err_h_dtype)
- expected_txt = ('<wh_dtype> type not supported : {0}.'
- ''.format(err_h_dtype))
+ expected_txt = "<wh_dtype> type not supported : {0}." "".format(err_h_dtype)
ex_str = None
try:
- histo, cumul = histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights,
- wh_dtype=err_h_dtype)[0:2]
+ histo, cumul = histogramnd(
+ self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights,
+ wh_dtype=err_h_dtype,
+ )[0:2]
except ValueError as ex:
ex_str = str(ex)
@@ -303,26 +294,22 @@ class _Test_chistogramnd_errors(unittest.TestCase):
self.assertEqual(ex_str, expected_txt, msg=test_msg)
def test_unmanaged_dtypes(self):
- """
- """
+ """ """
for err_unmanaged_dtype in self.err_unmanaged_dtypes:
- test_msg = ('Testing unmanaged dtypes : {0}'
- ''.format(err_unmanaged_dtype))
+ test_msg = "Testing unmanaged dtypes : {0}" "".format(err_unmanaged_dtype)
sample = self.sample.astype(err_unmanaged_dtype[0])
weights = self.weights.astype(err_unmanaged_dtype[1])
- expected_txt = ('Case not supported - sample:{0} '
- 'and weights:{1}.'
- ''.format(sample.dtype,
- weights.dtype))
+ expected_txt = (
+ "Case not supported - sample:{0} "
+ "and weights:{1}."
+ "".format(sample.dtype, weights.dtype)
+ )
ex_str = None
try:
- histogramnd(sample,
- self.histo_range,
- self.n_bins,
- weights=weights)
+ histogramnd(sample, self.histo_range, self.n_bins, weights=weights)
except TypeError as ex:
ex_str = str(ex)
@@ -330,23 +317,24 @@ class _Test_chistogramnd_errors(unittest.TestCase):
self.assertEqual(ex_str, expected_txt, msg=test_msg)
def test_uncontiguous_histo(self):
- """
- """
+ """ """
# non contiguous array
shape = np.array(self.n_bins, ndmin=1)
shape[0] *= 2
histo_tmp = np.zeros(shape)
histo = histo_tmp[::2, ...]
- expected_txt = ('<histo> must be a C_CONTIGUOUS numpy array.')
+ expected_txt = "<histo> must be a C_CONTIGUOUS numpy array."
ex_str = None
try:
- histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights,
- histo=histo)
+ histogramnd(
+ self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights,
+ histo=histo,
+ )
except ValueError as ex:
ex_str = str(ex)
@@ -354,23 +342,24 @@ class _Test_chistogramnd_errors(unittest.TestCase):
self.assertEqual(ex_str, expected_txt)
def test_uncontiguous_weighted_histo(self):
- """
- """
+ """ """
# non contiguous array
shape = np.array(self.n_bins, ndmin=1)
shape[0] *= 2
cumul_tmp = np.zeros(shape)
cumul = cumul_tmp[::2, ...]
- expected_txt = ('<weighted_histo> must be a C_CONTIGUOUS numpy array.')
+ expected_txt = "<weighted_histo> must be a C_CONTIGUOUS numpy array."
ex_str = None
try:
- histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights,
- weighted_histo=cumul)
+ histogramnd(
+ self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights,
+ weighted_histo=cumul,
+ )
except ValueError as ex:
ex_str = str(ex)
@@ -382,6 +371,7 @@ class Test_chistogramnd_1D_errors(_Test_chistogramnd_errors):
"""
Unit tests of the 1D histogramnd error cases.
"""
+
__test__ = True # because _Test_chistogramnd_errors is ignored
def setUp(self):
@@ -390,48 +380,43 @@ class Test_chistogramnd_1D_errors(_Test_chistogramnd_errors):
self.s_shape = (self.n_elements,)
self.w_shape = (self.n_elements,)
- self.histo_range = [0., 100.]
+ self.histo_range = [0.0, 100.0]
self.n_bins = 10
self.h_shape = (self.n_bins,)
- self.sample = np.random.randint(0,
- high=10,
- size=self.s_shape)
+ self.sample = np.random.randint(0, high=10, size=self.s_shape)
self.sample = self.sample.astype(np.double)
- self.weights = np.random.randint(0,
- high=10,
- size=self.w_shape)
+ self.weights = np.random.randint(0, high=10, size=self.w_shape)
self.weights = self.weights.astype(np.double)
- self.err_weights_shapes = ((self.n_elements+1,),
- (self.n_elements-1,),
- (self.n_elements-1, 3))
- self.err_histo_range_shapes = ([0.],
- [0., 1., 2.],
- [[0.], [1.]])
- self.err_n_bins_shapes = ([10, 2],
- [[10], [2]])
- self.err_n_bins_values = (0,
- [-10],
- None)
- self.err_histo_shapes = ((self.n_bins+1,),
- (self.n_bins-1,),
- (self.n_bins, self.n_bins))
+ self.err_weights_shapes = (
+ (self.n_elements + 1,),
+ (self.n_elements - 1,),
+ (self.n_elements - 1, 3),
+ )
+ self.err_histo_range_shapes = ([0.0], [0.0, 1.0, 2.0], [[0.0], [1.0]])
+ self.err_n_bins_shapes = ([10, 2], [[10], [2]])
+ self.err_n_bins_values = (0, [-10], None)
+ self.err_histo_shapes = (
+ (self.n_bins + 1,),
+ (self.n_bins - 1,),
+ (self.n_bins, self.n_bins),
+ )
# these are used for testing the histo parameter as well
# as the weighted_histo parameter.
- self.err_histo_dtypes = (np.uint16,
- np.float16)
+ self.err_histo_dtypes = (np.uint16, np.float16)
- self.err_unmanaged_dtypes = ((np.double, np.uint16),
- (np.uint16, np.double),
- (np.uint16, np.uint16))
+ self.err_unmanaged_dtypes = (
+ (np.double, np.uint16),
+ (np.uint16, np.double),
+ (np.uint16, np.uint16),
+ )
-class Test_chistogramnd_ND_range(unittest.TestCase):
- """
- """
+class Test_chistogramnd_ND_range(unittest.TestCase):
+ """ """
def test_invalid_histo_range(self):
data = np.random.random((60, 60))
@@ -440,21 +425,18 @@ class Test_chistogramnd_ND_range(unittest.TestCase):
with self.assertRaises(ValueError):
histo_range = data.min(), np.inf
- Histogramnd(sample=data.ravel(),
- histo_range=histo_range,
- n_bins=nbins)
+ Histogramnd(sample=data.ravel(), histo_range=histo_range, n_bins=nbins)
histo_range = data.min(), np.nan
- Histogramnd(sample=data.ravel(),
- histo_range=histo_range,
- n_bins=nbins)
+ Histogramnd(sample=data.ravel(), histo_range=histo_range, n_bins=nbins)
class Test_chistogramnd_ND_errors(_Test_chistogramnd_errors):
"""
Unit tests of the 3D histogramnd error cases.
"""
+
__test__ = True # because _Test_chistogramnd_errors is ignored
def setUp(self):
@@ -463,56 +445,43 @@ class Test_chistogramnd_ND_errors(_Test_chistogramnd_errors):
self.s_shape = (self.n_elements, 3)
self.w_shape = (self.n_elements,)
- self.histo_range = [[0., 100.], [0., 100.], [0., 100.]]
+ self.histo_range = [[0.0, 100.0], [0.0, 100.0], [0.0, 100.0]]
self.n_bins = (10, 20, 30)
self.h_shape = self.n_bins
- self.sample = np.random.randint(0,
- high=10,
- size=self.s_shape)
+ self.sample = np.random.randint(0, high=10, size=self.s_shape)
self.sample = self.sample.astype(np.double)
- self.weights = np.random.randint(0,
- high=10,
- size=self.w_shape)
+ self.weights = np.random.randint(0, high=10, size=self.w_shape)
self.weights = self.weights.astype(np.double)
- self.err_weights_shapes = ((self.n_elements+1,),
- (self.n_elements-1,),
- (self.n_elements-1, 3))
- self.err_histo_range_shapes = ([0.],
- [0., 1.],
- [[0., 10.], [0., 10.]],
- [0., 10., 0, 10., 0, 10.])
- self.err_n_bins_shapes = ([10, 2],
- [[10], [20], [30]])
- self.err_n_bins_values = (0,
- [-10],
- [10, 20, -4],
- None,
- [10, None, 30])
- self.err_histo_shapes = ((self.n_bins[0]+1,
- self.n_bins[1],
- self.n_bins[2]),
- (self.n_bins[0],
- self.n_bins[1],
- self.n_bins[2]-1),
- (self.n_bins[0],
- self.n_bins[1]),
- (self.n_bins[1],
- self.n_bins[0],
- self.n_bins[2]),
- (self.n_bins[0],
- self.n_bins[1],
- self.n_bins[2],
- 10)
- )
+ self.err_weights_shapes = (
+ (self.n_elements + 1,),
+ (self.n_elements - 1,),
+ (self.n_elements - 1, 3),
+ )
+ self.err_histo_range_shapes = (
+ [0.0],
+ [0.0, 1.0],
+ [[0.0, 10.0], [0.0, 10.0]],
+ [0.0, 10.0, 0, 10.0, 0, 10.0],
+ )
+ self.err_n_bins_shapes = ([10, 2], [[10], [20], [30]])
+ self.err_n_bins_values = (0, [-10], [10, 20, -4], None, [10, None, 30])
+ self.err_histo_shapes = (
+ (self.n_bins[0] + 1, self.n_bins[1], self.n_bins[2]),
+ (self.n_bins[0], self.n_bins[1], self.n_bins[2] - 1),
+ (self.n_bins[0], self.n_bins[1]),
+ (self.n_bins[1], self.n_bins[0], self.n_bins[2]),
+ (self.n_bins[0], self.n_bins[1], self.n_bins[2], 10),
+ )
# these are used for testing the histo parameter as well
# as the weighted_histo parameter.
- self.err_histo_dtypes = (np.uint16,
- np.float16)
+ self.err_histo_dtypes = (np.uint16, np.float16)
- self.err_unmanaged_dtypes = ((np.double, np.uint16),
- (np.uint16, np.double),
- (np.uint16, np.uint16))
+ self.err_unmanaged_dtypes = (
+ (np.double, np.uint16),
+ (np.uint16, np.double),
+ (np.uint16, np.uint16),
+ )
diff --git a/src/silx/math/test/test_histogramnd_nominal.py b/src/silx/math/test/test_histogramnd_nominal.py
index 9a8c3c3..235f138 100644
--- a/src/silx/math/test/test_histogramnd_nominal.py
+++ b/src/silx/math/test/test_histogramnd_nominal.py
@@ -25,7 +25,6 @@ Nominal tests of the histogramnd function.
"""
import unittest
-import pytest
import numpy as np
@@ -36,10 +35,12 @@ from silx.math import Histogramnd
def _get_bin_edges(histo_range, n_bins, n_dims):
edges = []
for i_dim in range(n_dims):
- edges.append(histo_range[i_dim, 0] +
- np.arange(n_bins[i_dim] + 1) *
- (histo_range[i_dim, 1] - histo_range[i_dim, 0]) /
- n_bins[i_dim])
+ edges.append(
+ histo_range[i_dim, 0]
+ + np.arange(n_bins[i_dim] + 1)
+ * (histo_range[i_dim, 1] - histo_range[i_dim, 0])
+ / n_bins[i_dim]
+ )
return tuple(edges)
@@ -52,6 +53,7 @@ class _Test_chistogramnd_nominal(unittest.TestCase):
"""
Unit tests of the histogramnd function.
"""
+
__test__ = False # ignore abstract classe
ndims = None
@@ -60,22 +62,16 @@ class _Test_chistogramnd_nominal(unittest.TestCase):
if type(self).__name__.startswith("_"):
self.skipTest("Abstract class")
ndims = self.ndims
- self.tested_dim = ndims-1
+ self.tested_dim = ndims - 1
if ndims is None:
- raise ValueError('ndims class member not set.')
+ raise ValueError("ndims class member not set.")
- sample = np.array([5.5, -3.3,
- 0., -0.5,
- 3.3, 8.8,
- -7.7, 6.0,
- -4.0])
+ sample = np.array([5.5, -3.3, 0.0, -0.5, 3.3, 8.8, -7.7, 6.0, -4.0])
- weights = np.array([500.5, -300.3,
- 0.01, -0.5,
- 300.3, 800.8,
- -700.7, 600.6,
- -400.4])
+ weights = np.array(
+ [500.5, -300.3, 0.01, -0.5, 300.3, 800.8, -700.7, 600.6, -400.4]
+ )
n_elems = len(sample)
@@ -88,7 +84,7 @@ class _Test_chistogramnd_nominal(unittest.TestCase):
if ndims == 1:
self.sample = sample
else:
- self.sample[..., ndims-1] = sample
+ self.sample[..., ndims - 1] = sample
self.weights = weights
@@ -99,50 +95,50 @@ class _Test_chistogramnd_nominal(unittest.TestCase):
# bin [2, y] because of the bin ranges [-2, 2] and n_bins = 4
# for the first dimension)
self.other_axes_index = 2
- self.histo_range = np.repeat([[-2., 2.]], ndims, axis=0)
- self.histo_range[ndims-1] = [-4., 6.]
+ self.histo_range = np.repeat([[-2.0, 2.0]], ndims, axis=0)
+ self.histo_range[ndims - 1] = [-4.0, 6.0]
- self.n_bins = np.array([4]*ndims)
- self.n_bins[ndims-1] = 5
+ self.n_bins = np.array([4] * ndims)
+ self.n_bins[ndims - 1] = 5
if ndims == 1:
+
def fill_histo(h, v, dim, op=None):
if op:
h[:] = op(h[:], v)
else:
h[:] = v
+
self.fill_histo = fill_histo
else:
+
def fill_histo(h, v, dim, op=None):
- idx = [self.other_axes_index]*len(h.shape)
+ idx = [self.other_axes_index] * len(h.shape)
idx[dim] = slice(0, None)
idx = tuple(idx)
if op:
h[idx] = op(h[idx], v)
else:
h[idx] = v
+
self.fill_histo = fill_histo
def test_nominal(self):
- """
- """
+ """ """
expected_h_tpl = np.array([2, 1, 1, 1, 1])
expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 500.5])
expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims - 1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims - 1)
- histo, cumul, bin_edges = histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights)
+ histo, cumul, bin_edges = histogramnd(
+ self.sample, self.histo_range, self.n_bins, weights=self.weights
+ )
- expected_edges = _get_bin_edges(self.histo_range,
- self.n_bins,
- self.ndims)
+ expected_edges = _get_bin_edges(self.histo_range, self.n_bins, self.ndims)
self.assertEqual(cumul.dtype, np.float64)
self.assertEqual(histo.dtype, np.uint32)
@@ -150,44 +146,44 @@ class _Test_chistogramnd_nominal(unittest.TestCase):
self.assertTrue(np.array_equal(cumul, expected_c))
for i_edges, edges in enumerate(expected_edges):
- self.assertTrue(np.array_equal(bin_edges[i_edges],
- expected_edges[i_edges]),
- msg='Testing bin_edges for dim {0}'
- ''.format(i_edges+1))
+ self.assertTrue(
+ np.array_equal(bin_edges[i_edges], expected_edges[i_edges]),
+ msg="Testing bin_edges for dim {0}" "".format(i_edges + 1),
+ )
def test_nominal_wh_dtype(self):
- """
- """
+ """ """
expected_h_tpl = np.array([2, 1, 1, 1, 1])
expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 500.5])
expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
expected_c = np.zeros(shape=self.n_bins, dtype=np.float32)
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims - 1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims - 1)
- histo, cumul, bin_edges = histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights,
- wh_dtype=np.float32)
+ histo, cumul, bin_edges = histogramnd(
+ self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights,
+ wh_dtype=np.float32,
+ )
self.assertEqual(cumul.dtype, np.float32)
self.assertTrue(np.array_equal(histo, expected_h))
self.assertTrue(np.allclose(cumul, expected_c))
def test_nominal_uncontiguous_sample(self):
- """
- """
+ """ """
expected_h_tpl = np.array([2, 1, 1, 1, 1])
expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 500.5])
expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims - 1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims - 1)
shape = list(self.sample.shape)
shape[0] *= 2
@@ -195,13 +191,14 @@ class _Test_chistogramnd_nominal(unittest.TestCase):
uncontig_sample = sample[::2, ...]
uncontig_sample[:] = self.sample
- self.assertFalse(uncontig_sample.flags['C_CONTIGUOUS'],
- msg='Making sure the array is not contiguous.')
+ self.assertFalse(
+ uncontig_sample.flags["C_CONTIGUOUS"],
+ msg="Making sure the array is not contiguous.",
+ )
- histo, cumul, bin_edges = histogramnd(uncontig_sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights)
+ histo, cumul, bin_edges = histogramnd(
+ uncontig_sample, self.histo_range, self.n_bins, weights=self.weights
+ )
self.assertEqual(cumul.dtype, np.float64)
self.assertEqual(histo.dtype, np.uint32)
@@ -209,16 +206,15 @@ class _Test_chistogramnd_nominal(unittest.TestCase):
self.assertTrue(np.array_equal(cumul, expected_c))
def test_nominal_uncontiguous_weights(self):
- """
- """
+ """ """
expected_h_tpl = np.array([2, 1, 1, 1, 1])
expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 500.5])
expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims - 1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims - 1)
shape = list(self.weights.shape)
shape[0] *= 2
@@ -226,13 +222,14 @@ class _Test_chistogramnd_nominal(unittest.TestCase):
uncontig_weights = weights[::2, ...]
uncontig_weights[:] = self.weights
- self.assertFalse(uncontig_weights.flags['C_CONTIGUOUS'],
- msg='Making sure the array is not contiguous.')
+ self.assertFalse(
+ uncontig_weights.flags["C_CONTIGUOUS"],
+ msg="Making sure the array is not contiguous.",
+ )
- histo, cumul, bin_edges = histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=uncontig_weights)
+ histo, cumul, bin_edges = histogramnd(
+ self.sample, self.histo_range, self.n_bins, weights=uncontig_weights
+ )
self.assertEqual(cumul.dtype, np.float64)
self.assertEqual(histo.dtype, np.uint32)
@@ -240,25 +237,22 @@ class _Test_chistogramnd_nominal(unittest.TestCase):
self.assertTrue(np.array_equal(cumul, expected_c))
def test_nominal_wo_weights(self):
- """
- """
+ """ """
expected_h_tpl = np.array([2, 1, 1, 1, 1])
expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims - 1)
- histo, cumul = histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=None)[0:2]
+ histo, cumul = histogramnd(
+ self.sample, self.histo_range, self.n_bins, weights=None
+ )[0:2]
self.assertTrue(np.array_equal(histo, expected_h))
self.assertTrue(cumul is None)
def test_nominal_wo_weights_w_cumul(self):
- """
- """
+ """ """
expected_h_tpl = np.array([2, 1, 1, 1, 1])
expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
@@ -267,23 +261,24 @@ class _Test_chistogramnd_nominal(unittest.TestCase):
# it is not cleared by histogramnd
cumul_in = np.ones(self.n_bins, dtype=np.double)
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims - 1)
- histo, cumul = histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=None,
- weighted_histo=cumul_in)[0:2]
+ histo, cumul = histogramnd(
+ self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=None,
+ weighted_histo=cumul_in,
+ )[0:2]
self.assertTrue(np.array_equal(histo, expected_h))
self.assertTrue(cumul is None)
- self.assertTrue(np.array_equal(cumul_in,
- np.ones(shape=self.n_bins,
- dtype=np.double)))
+ self.assertTrue(
+ np.array_equal(cumul_in, np.ones(shape=self.n_bins, dtype=np.double))
+ )
def test_nominal_wo_weights_w_histo(self):
- """
- """
+ """ """
expected_h_tpl = np.array([2, 1, 1, 1, 1])
expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
@@ -292,67 +287,66 @@ class _Test_chistogramnd_nominal(unittest.TestCase):
# it is not cleared by histogramnd
histo_in = np.ones(self.n_bins, dtype=np.uint32)
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims - 1)
- histo, cumul = histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=None,
- histo=histo_in)[0:2]
+ histo, cumul = histogramnd(
+ self.sample, self.histo_range, self.n_bins, weights=None, histo=histo_in
+ )[0:2]
self.assertTrue(np.array_equal(histo, expected_h + 1))
self.assertTrue(cumul is None)
self.assertEqual(id(histo), id(histo_in))
def test_nominal_last_bin_closed(self):
- """
- """
+ """ """
expected_h_tpl = np.array([2, 1, 1, 1, 2])
expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 1101.1])
expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims - 1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims - 1)
- histo, cumul = histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights,
- last_bin_closed=True)[0:2]
+ histo, cumul = histogramnd(
+ self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights,
+ last_bin_closed=True,
+ )[0:2]
self.assertTrue(np.array_equal(histo, expected_h))
self.assertTrue(np.array_equal(cumul, expected_c))
def test_int32_weights_double_weights_range(self):
- """
- """
+ """ """
weight_min = -299.9 # ===> will be cast to -299
weight_max = 499.9 # ===> will be cast to 499
expected_h_tpl = np.array([0, 1, 1, 1, 0])
- expected_c_tpl = np.array([0., 0., 0., 300., 0.])
+ expected_c_tpl = np.array([0.0, 0.0, 0.0, 300.0, 0.0])
expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims - 1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims - 1)
- histo, cumul = histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights.astype(np.int32),
- weight_min=weight_min,
- weight_max=weight_max)[0:2]
+ histo, cumul = histogramnd(
+ self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights.astype(np.int32),
+ weight_min=weight_min,
+ weight_max=weight_max,
+ )[0:2]
self.assertTrue(np.array_equal(histo, expected_h))
self.assertTrue(np.array_equal(cumul, expected_c))
def test_reuse_histo(self):
- """
- """
+ """ """
expected_h_tpl = np.array([2, 3, 2, 2, 2])
expected_c_tpl = np.array([0.0, -7007, -5.0, 0.1, 3003.0])
@@ -360,13 +354,12 @@ class _Test_chistogramnd_nominal(unittest.TestCase):
expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims - 1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims - 1)
- histo, cumul = histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights)[0:2]
+ histo, cumul = histogramnd(
+ self.sample, self.histo_range, self.n_bins, weights=self.weights
+ )[0:2]
sample_2 = self.sample[:]
if len(sample_2.shape) == 1:
@@ -376,19 +369,20 @@ class _Test_chistogramnd_nominal(unittest.TestCase):
sample_2[idx] += 2
- histo_2, cumul = histogramnd(sample_2, # <==== !!
- self.histo_range,
- self.n_bins,
- weights=10 * self.weights, # <==== !!
- histo=histo)[0:2]
+ histo_2, cumul = histogramnd(
+ sample_2, # <==== !!
+ self.histo_range,
+ self.n_bins,
+ weights=10 * self.weights, # <==== !!
+ histo=histo,
+ )[0:2]
self.assertTrue(np.array_equal(histo, expected_h))
self.assertTrue(np.array_equal(cumul, expected_c))
self.assertEqual(id(histo), id(histo_2))
def test_reuse_cumul(self):
- """
- """
+ """ """
expected_h_tpl = np.array([0, 2, 1, 1, 1])
expected_c_tpl = np.array([-700.7, -7007.5, -4.99, 300.4, 3503.5])
@@ -396,13 +390,12 @@ class _Test_chistogramnd_nominal(unittest.TestCase):
expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims - 1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims - 1)
- histo, cumul = histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights)[0:2]
+ histo, cumul = histogramnd(
+ self.sample, self.histo_range, self.n_bins, weights=self.weights
+ )[0:2]
sample_2 = self.sample[:]
if len(sample_2.shape) == 1:
@@ -412,11 +405,13 @@ class _Test_chistogramnd_nominal(unittest.TestCase):
sample_2[idx] += 2
- histo, cumul_2 = histogramnd(sample_2, # <==== !!
- self.histo_range,
- self.n_bins,
- weights=10 * self.weights, # <==== !!
- weighted_histo=cumul)[0:2]
+ histo, cumul_2 = histogramnd(
+ sample_2, # <==== !!
+ self.histo_range,
+ self.n_bins,
+ weights=10 * self.weights, # <==== !!
+ weighted_histo=cumul,
+ )[0:2]
self.assertEqual(cumul.dtype, np.float64)
self.assertTrue(np.array_equal(histo, expected_h))
@@ -424,23 +419,22 @@ class _Test_chistogramnd_nominal(unittest.TestCase):
self.assertEqual(id(cumul), id(cumul_2))
def test_reuse_cumul_float(self):
- """
- """
+ """ """
expected_h_tpl = np.array([0, 2, 1, 1, 1])
- expected_c_tpl = np.array([-700.7, -7007.5, -4.99, 300.4, 3503.5],
- dtype=np.float32)
+ expected_c_tpl = np.array(
+ [-700.7, -7007.5, -4.99, 300.4, 3503.5], dtype=np.float32
+ )
expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims - 1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims - 1)
- histo, cumul = histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights)[0:2]
+ histo, cumul = histogramnd(
+ self.sample, self.histo_range, self.n_bins, weights=self.weights
+ )[0:2]
# converting the cumul array to float
cumul = cumul.astype(np.float32)
@@ -453,21 +447,25 @@ class _Test_chistogramnd_nominal(unittest.TestCase):
sample_2[idx] += 2
- histo, cumul_2 = histogramnd(sample_2, # <==== !!
- self.histo_range,
- self.n_bins,
- weights=10 * self.weights, # <==== !!
- weighted_histo=cumul)[0:2]
+ histo, cumul_2 = histogramnd(
+ sample_2, # <==== !!
+ self.histo_range,
+ self.n_bins,
+ weights=10 * self.weights, # <==== !!
+ weighted_histo=cumul,
+ )[0:2]
self.assertEqual(cumul.dtype, np.float32)
self.assertTrue(np.array_equal(histo, expected_h))
self.assertEqual(id(cumul), id(cumul_2))
self.assertTrue(np.allclose(cumul, expected_c, rtol=10e-15))
+
class _Test_Histogramnd_nominal(unittest.TestCase):
"""
Unit tests of the Histogramnd class.
"""
+
__test__ = False # ignore abstract class
ndims = None
@@ -476,22 +474,16 @@ class _Test_Histogramnd_nominal(unittest.TestCase):
ndims = self.ndims
if ndims is None:
self.skipTest("Abstract class")
- self.tested_dim = ndims-1
+ self.tested_dim = ndims - 1
if ndims is None:
- raise ValueError('ndims class member not set.')
+ raise ValueError("ndims class member not set.")
- sample = np.array([5.5, -3.3,
- 0., -0.5,
- 3.3, 8.8,
- -7.7, 6.0,
- -4.0])
+ sample = np.array([5.5, -3.3, 0.0, -0.5, 3.3, 8.8, -7.7, 6.0, -4.0])
- weights = np.array([500.5, -300.3,
- 0.01, -0.5,
- 300.3, 800.8,
- -700.7, 600.6,
- -400.4])
+ weights = np.array(
+ [500.5, -300.3, 0.01, -0.5, 300.3, 800.8, -700.7, 600.6, -400.4]
+ )
n_elems = len(sample)
@@ -504,7 +496,7 @@ class _Test_Histogramnd_nominal(unittest.TestCase):
if ndims == 1:
self.sample = sample
else:
- self.sample[..., ndims-1] = sample
+ self.sample[..., ndims - 1] = sample
self.weights = weights
@@ -515,52 +507,52 @@ class _Test_Histogramnd_nominal(unittest.TestCase):
# bin [2, y] because of the bin ranges [-2, 2] and n_bins = 4
# for the first dimension)
self.other_axes_index = 2
- self.histo_range = np.repeat([[-2., 2.]], ndims, axis=0)
- self.histo_range[ndims-1] = [-4., 6.]
+ self.histo_range = np.repeat([[-2.0, 2.0]], ndims, axis=0)
+ self.histo_range[ndims - 1] = [-4.0, 6.0]
- self.n_bins = np.array([4]*ndims)
- self.n_bins[ndims-1] = 5
+ self.n_bins = np.array([4] * ndims)
+ self.n_bins[ndims - 1] = 5
if ndims == 1:
+
def fill_histo(h, v, dim, op=None):
if op:
h[:] = op(h[:], v)
else:
h[:] = v
+
self.fill_histo = fill_histo
else:
+
def fill_histo(h, v, dim, op=None):
- idx = [self.other_axes_index]*len(h.shape)
+ idx = [self.other_axes_index] * len(h.shape)
idx[dim] = slice(0, None)
idx = tuple(idx)
if op:
h[idx] = op(h[idx], v)
else:
h[idx] = v
+
self.fill_histo = fill_histo
def test_nominal(self):
- """
- """
+ """ """
expected_h_tpl = np.array([2, 1, 1, 1, 1])
expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 500.5])
expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims - 1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims - 1)
+
+ histo = Histogramnd(
+ self.sample, self.histo_range, self.n_bins, weights=self.weights
+ )
- histo = Histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights)
-
histo, cumul, bin_edges = histo
- expected_edges = _get_bin_edges(self.histo_range,
- self.n_bins,
- self.ndims)
+ expected_edges = _get_bin_edges(self.histo_range, self.n_bins, self.ndims)
self.assertEqual(cumul.dtype, np.float64)
self.assertEqual(histo.dtype, np.uint32)
@@ -568,44 +560,44 @@ class _Test_Histogramnd_nominal(unittest.TestCase):
self.assertTrue(np.array_equal(cumul, expected_c))
for i_edges, edges in enumerate(expected_edges):
- self.assertTrue(np.array_equal(bin_edges[i_edges],
- expected_edges[i_edges]),
- msg='Testing bin_edges for dim {0}'
- ''.format(i_edges+1))
+ self.assertTrue(
+ np.array_equal(bin_edges[i_edges], expected_edges[i_edges]),
+ msg="Testing bin_edges for dim {0}" "".format(i_edges + 1),
+ )
def test_nominal_wh_dtype(self):
- """
- """
+ """ """
expected_h_tpl = np.array([2, 1, 1, 1, 1])
expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 500.5])
expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
expected_c = np.zeros(shape=self.n_bins, dtype=np.float32)
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims - 1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims - 1)
- histo, cumul, bin_edges = Histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights,
- wh_dtype=np.float32)
+ histo, cumul, bin_edges = Histogramnd(
+ self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights,
+ wh_dtype=np.float32,
+ )
self.assertEqual(cumul.dtype, np.float32)
self.assertTrue(np.array_equal(histo, expected_h))
self.assertTrue(np.allclose(cumul, expected_c))
def test_nominal_uncontiguous_sample(self):
- """
- """
+ """ """
expected_h_tpl = np.array([2, 1, 1, 1, 1])
expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 500.5])
expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims - 1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims - 1)
shape = list(self.sample.shape)
shape[0] *= 2
@@ -613,13 +605,14 @@ class _Test_Histogramnd_nominal(unittest.TestCase):
uncontig_sample = sample[::2, ...]
uncontig_sample[:] = self.sample
- self.assertFalse(uncontig_sample.flags['C_CONTIGUOUS'],
- msg='Making sure the array is not contiguous.')
+ self.assertFalse(
+ uncontig_sample.flags["C_CONTIGUOUS"],
+ msg="Making sure the array is not contiguous.",
+ )
- histo, cumul, bin_edges = Histogramnd(uncontig_sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights)
+ histo, cumul, bin_edges = Histogramnd(
+ uncontig_sample, self.histo_range, self.n_bins, weights=self.weights
+ )
self.assertEqual(cumul.dtype, np.float64)
self.assertEqual(histo.dtype, np.uint32)
@@ -627,16 +620,15 @@ class _Test_Histogramnd_nominal(unittest.TestCase):
self.assertTrue(np.array_equal(cumul, expected_c))
def test_nominal_uncontiguous_weights(self):
- """
- """
+ """ """
expected_h_tpl = np.array([2, 1, 1, 1, 1])
expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 500.5])
expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims - 1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims - 1)
shape = list(self.weights.shape)
shape[0] *= 2
@@ -644,13 +636,14 @@ class _Test_Histogramnd_nominal(unittest.TestCase):
uncontig_weights = weights[::2, ...]
uncontig_weights[:] = self.weights
- self.assertFalse(uncontig_weights.flags['C_CONTIGUOUS'],
- msg='Making sure the array is not contiguous.')
+ self.assertFalse(
+ uncontig_weights.flags["C_CONTIGUOUS"],
+ msg="Making sure the array is not contiguous.",
+ )
- histo, cumul, bin_edges = Histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=uncontig_weights)
+ histo, cumul, bin_edges = Histogramnd(
+ self.sample, self.histo_range, self.n_bins, weights=uncontig_weights
+ )
self.assertEqual(cumul.dtype, np.float64)
self.assertEqual(histo.dtype, np.uint32)
@@ -658,75 +651,72 @@ class _Test_Histogramnd_nominal(unittest.TestCase):
self.assertTrue(np.array_equal(cumul, expected_c))
def test_nominal_wo_weights(self):
- """
- """
+ """ """
expected_h_tpl = np.array([2, 1, 1, 1, 1])
expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims - 1)
- histo, cumul = Histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=None)[0:2]
+ histo, cumul = Histogramnd(
+ self.sample, self.histo_range, self.n_bins, weights=None
+ )[0:2]
self.assertTrue(np.array_equal(histo, expected_h))
self.assertTrue(cumul is None)
def test_nominal_last_bin_closed(self):
- """
- """
+ """ """
expected_h_tpl = np.array([2, 1, 1, 1, 2])
expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 1101.1])
expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims - 1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims - 1)
- histo, cumul = Histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights,
- last_bin_closed=True)[0:2]
+ histo, cumul = Histogramnd(
+ self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights,
+ last_bin_closed=True,
+ )[0:2]
self.assertTrue(np.array_equal(histo, expected_h))
self.assertTrue(np.array_equal(cumul, expected_c))
def test_int32_weights_double_weights_range(self):
- """
- """
+ """ """
weight_min = -299.9 # ===> will be cast to -299
weight_max = 499.9 # ===> will be cast to 499
expected_h_tpl = np.array([0, 1, 1, 1, 0])
- expected_c_tpl = np.array([0., 0., 0., 300., 0.])
+ expected_c_tpl = np.array([0.0, 0.0, 0.0, 300.0, 0.0])
expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims - 1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims - 1)
- histo, cumul = Histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights.astype(np.int32),
- weight_min=weight_min,
- weight_max=weight_max)[0:2]
+ histo, cumul = Histogramnd(
+ self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights.astype(np.int32),
+ weight_min=weight_min,
+ weight_max=weight_max,
+ )[0:2]
self.assertTrue(np.array_equal(histo, expected_h))
self.assertTrue(np.array_equal(cumul, expected_c))
def test_nominal_no_sample(self):
- """
- """
+ """ """
- histo_inst = Histogramnd(None,
- self.histo_range,
- self.n_bins)
+ histo_inst = Histogramnd(None, self.histo_range, self.n_bins)
histo, weighted_histo, edges = histo_inst
@@ -738,31 +728,25 @@ class _Test_Histogramnd_nominal(unittest.TestCase):
self.assertIsNone(histo_inst.edges)
def test_empty_init_accumulate(self):
- """
- """
+ """ """
expected_h_tpl = np.array([2, 1, 1, 1, 1])
expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 500.5])
expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims - 1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims - 1)
- histo_inst = Histogramnd(None,
- self.histo_range,
- self.n_bins)
+ histo_inst = Histogramnd(None, self.histo_range, self.n_bins)
- histo_inst.accumulate(self.sample,
- weights=self.weights)
+ histo_inst.accumulate(self.sample, weights=self.weights)
histo = histo_inst.histo
cumul = histo_inst.weighted_histo
bin_edges = histo_inst.edges
- expected_edges = _get_bin_edges(self.histo_range,
- self.n_bins,
- self.ndims)
+ expected_edges = _get_bin_edges(self.histo_range, self.n_bins, self.ndims)
self.assertEqual(cumul.dtype, np.float64)
self.assertEqual(histo.dtype, np.uint32)
@@ -770,14 +754,13 @@ class _Test_Histogramnd_nominal(unittest.TestCase):
self.assertTrue(np.array_equal(cumul, expected_c))
for i_edges, edges in enumerate(expected_edges):
- self.assertTrue(np.array_equal(bin_edges[i_edges],
- expected_edges[i_edges]),
- msg='Testing bin_edges for dim {0}'
- ''.format(i_edges+1))
+ self.assertTrue(
+ np.array_equal(bin_edges[i_edges], expected_edges[i_edges]),
+ msg="Testing bin_edges for dim {0}" "".format(i_edges + 1),
+ )
def test_accumulate(self):
- """
- """
+ """ """
expected_h_tpl = np.array([2, 3, 2, 2, 2])
expected_c_tpl = np.array([-700.7, -7007.5, -4.99, 300.4, 3503.5])
@@ -785,13 +768,12 @@ class _Test_Histogramnd_nominal(unittest.TestCase):
expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims - 1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims - 1)
- histo_inst = Histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights)
+ histo_inst = Histogramnd(
+ self.sample, self.histo_range, self.n_bins, weights=self.weights
+ )
sample_2 = self.sample[:]
if len(sample_2.shape) == 1:
@@ -801,8 +783,9 @@ class _Test_Histogramnd_nominal(unittest.TestCase):
sample_2[idx] += 2
- histo_inst.accumulate(sample_2, # <==== !!
- weights=10 * self.weights) # <==== !!
+ histo_inst.accumulate(
+ sample_2, weights=10 * self.weights # <==== !!
+ ) # <==== !!
histo = histo_inst.histo
cumul = histo_inst.weighted_histo
@@ -813,8 +796,7 @@ class _Test_Histogramnd_nominal(unittest.TestCase):
self.assertTrue(np.allclose(cumul, expected_c, rtol=10e-15))
def test_accumulate_no_weights(self):
- """
- """
+ """ """
expected_h_tpl = np.array([2, 3, 2, 2, 2])
expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 500.5])
@@ -822,13 +804,12 @@ class _Test_Histogramnd_nominal(unittest.TestCase):
expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims - 1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims - 1)
- histo_inst = Histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights)
+ histo_inst = Histogramnd(
+ self.sample, self.histo_range, self.n_bins, weights=self.weights
+ )
sample_2 = self.sample[:]
if len(sample_2.shape) == 1:
@@ -849,8 +830,7 @@ class _Test_Histogramnd_nominal(unittest.TestCase):
self.assertTrue(np.allclose(cumul, expected_c, rtol=10e-15))
def test_accumulate_no_weights_at_init(self):
- """
- """
+ """ """
expected_h_tpl = np.array([2, 3, 2, 2, 2])
expected_c_tpl = np.array([0.0, -700.7, -0.5, 0.01, 300.3])
@@ -858,13 +838,12 @@ class _Test_Histogramnd_nominal(unittest.TestCase):
expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims - 1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims - 1)
- histo_inst = Histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=None) # <==== !!
+ histo_inst = Histogramnd(
+ self.sample, self.histo_range, self.n_bins, weights=None
+ ) # <==== !!
cumul = histo_inst.weighted_histo
self.assertIsNone(cumul)
@@ -877,8 +856,7 @@ class _Test_Histogramnd_nominal(unittest.TestCase):
sample_2[idx] += 2
- histo_inst.accumulate(sample_2,
- weights=self.weights) # <==== !!
+ histo_inst.accumulate(sample_2, weights=self.weights) # <==== !!
histo = histo_inst.histo
cumul = histo_inst.weighted_histo
@@ -895,15 +873,13 @@ class _Test_Histogramnd_nominal(unittest.TestCase):
type = self.sample.dtype.newbyteorder("L")
sampleL = self.sample.astype(type)
- histo_inst = Histogramnd(sampleB,
- self.histo_range,
- self.n_bins,
- weights=self.weights)
+ histo_inst = Histogramnd(
+ sampleB, self.histo_range, self.n_bins, weights=self.weights
+ )
- histo_inst = Histogramnd(sampleL,
- self.histo_range,
- self.n_bins,
- weights=self.weights)
+ histo_inst = Histogramnd(
+ sampleL, self.histo_range, self.n_bins, weights=self.weights
+ )
class Test_chistogram_nominal_1d(_Test_chistogramnd_nominal):
diff --git a/src/silx/math/test/test_histogramnd_vs_np.py b/src/silx/math/test/test_histogramnd_vs_np.py
index d1fb8be..23167f6 100644
--- a/src/silx/math/test/test_histogramnd_vs_np.py
+++ b/src/silx/math/test/test_histogramnd_vs_np.py
@@ -25,7 +25,6 @@ Tests for the histogramnd function.
Results are compared to numpy's histogramdd.
"""
-import os
import unittest
import operator
import pytest
@@ -38,8 +37,7 @@ from silx.math.chistogramnd import chistogramnd as histogramnd
# ==============================================================
# ==============================================================
-_RTOL_DICT = {np.float64: 10**-13,
- np.float32: 10**-5}
+_RTOL_DICT = {np.float64: 10**-13, np.float32: 10**-5}
# ==============================================================
# ==============================================================
@@ -51,16 +49,12 @@ def _add_values_to_array_if_missing(array, values, n_values):
if len(array.shape) == 1:
if not max_in_col:
- rnd_idx = np.random.randint(0,
- high=len(array)-1,
- size=(n_values,))
+ rnd_idx = np.random.randint(0, high=len(array) - 1, size=(n_values,))
array[rnd_idx] = values
else:
for i in range(len(max_in_col)):
if not max_in_col[i]:
- rnd_idx = np.random.randint(0,
- high=len(array)-1,
- size=(n_values,))
+ rnd_idx = np.random.randint(0, high=len(array) - 1, size=(n_values,))
array[rnd_idx, i] = values[i]
@@ -71,13 +65,10 @@ def _get_values_index(array, values, op=operator.lt):
return np.where(idx)[0]
-def _get_in_range_indices(array,
- minvalues,
- maxvalues,
- minop=operator.ge,
- maxop=operator.lt):
- idx = np.logical_and(minop(array, minvalues),
- maxop(array, maxvalues))
+def _get_in_range_indices(
+ array, minvalues, maxvalues, minop=operator.ge, maxop=operator.lt
+):
+ idx = np.logical_and(minop(array, minvalues), maxop(array, maxvalues))
if array.ndim > 1:
idx = np.all(idx, axis=1)
return np.where(idx)[0]
@@ -87,6 +78,7 @@ class _TestHistogramnd(unittest.TestCase):
"""
Unit tests of the histogramnd function.
"""
+
__test__ = False # ignore abstract class
sample_rng = None
@@ -103,7 +95,6 @@ class _TestHistogramnd(unittest.TestCase):
dtype_weights = None
def generate_data(self):
-
self.longMessage = True
int_min = 0
@@ -113,31 +104,33 @@ class _TestHistogramnd(unittest.TestCase):
if self.n_dims == 1:
shape = (n_elements,)
else:
- shape = (n_elements, self.n_dims,)
+ shape = (
+ n_elements,
+ self.n_dims,
+ )
self.rng_state = np.random.get_state()
- self.state_msg = ('Current RNG state :\n'
- '{0}'.format(self.rng_state))
+ self.state_msg = "Current RNG state :\n" "{0}".format(self.rng_state)
- sample = np.random.randint(int_min,
- high=int_max,
- size=shape)
+ sample = np.random.randint(int_min, high=int_max, size=shape)
sample = sample.astype(self.dtype_sample)
- sample = (self.sample_rng[0] +
- (sample-int_min) *
- (self.sample_rng[1]-self.sample_rng[0]) /
- (int_max-int_min)).astype(self.dtype_sample)
-
- weights = np.random.randint(int_min,
- high=int_max,
- size=(n_elements,))
+ sample = (
+ self.sample_rng[0]
+ + (sample - int_min)
+ * (self.sample_rng[1] - self.sample_rng[0])
+ / (int_max - int_min)
+ ).astype(self.dtype_sample)
+
+ weights = np.random.randint(int_min, high=int_max, size=(n_elements,))
weights = weights.astype(self.dtype_weights)
- weights = (self.weights_rng[0] +
- (weights-int_min) *
- (self.weights_rng[1]-self.weights_rng[0]) /
- (int_max-int_min)).astype(self.dtype_weights)
+ weights = (
+ self.weights_rng[0]
+ + (weights - int_min)
+ * (self.weights_rng[1] - self.weights_rng[0])
+ / (int_max - int_min)
+ ).astype(self.dtype_weights)
# !!!!!!!!!!!!!!!!!!!!!!!!!!!!
# !!!!!!!!!!!!!!!!!!!!!!!!!!!!
@@ -149,21 +142,15 @@ class _TestHistogramnd(unittest.TestCase):
# adding some values that are equal to the max
# in order to test the opened/closed last bin
bins_max = [b[1] for b in self.histo_range]
- _add_values_to_array_if_missing(sample,
- bins_max,
- 100)
+ _add_values_to_array_if_missing(sample, bins_max, 100)
# adding some values that are equal to the min weight value
# in order to test the filters
- _add_values_to_array_if_missing(weights,
- self.weights_rng[0],
- 100)
+ _add_values_to_array_if_missing(weights, self.weights_rng[0], 100)
# adding some values that are equal to the max weight value
# in order to test the filters
- _add_values_to_array_if_missing(weights,
- self.weights_rng[1],
- 100)
+ _add_values_to_array_if_missing(weights, self.weights_rng[1], 100)
return sample, weights
@@ -179,150 +166,146 @@ class _TestHistogramnd(unittest.TestCase):
return np.allclose(ar_a, ar_b, self.rtol)
def test_bin_ranges(self):
- """
-
- """
- result_c = histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights,
- last_bin_closed=True)
+ """ """
+ result_c = histogramnd(
+ self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights,
+ last_bin_closed=True,
+ )
- result_np = np.histogramdd(self.sample,
- bins=self.n_bins,
- range=self.histo_range)
+ result_np = np.histogramdd(
+ self.sample, bins=self.n_bins, range=self.histo_range
+ )
for i_edges, edges in enumerate(result_c[2]):
# allclose for now until I can try with the latest version (TBD)
# of numpy
- self.assertTrue(np.allclose(edges,
- result_np[1][i_edges]),
- msg='{0}. Testing bin_edges for dim {1}.'
- ''.format(self.state_msg, i_edges+1))
+ self.assertTrue(
+ np.allclose(edges, result_np[1][i_edges]),
+ msg="{0}. Testing bin_edges for dim {1}."
+ "".format(self.state_msg, i_edges + 1),
+ )
def test_last_bin_closed(self):
- """
-
- """
- result_c = histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights,
- last_bin_closed=True)
+ """ """
+ result_c = histogramnd(
+ self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights,
+ last_bin_closed=True,
+ )
- result_np = np.histogramdd(self.sample,
- bins=self.n_bins,
- range=self.histo_range)
+ result_np = np.histogramdd(
+ self.sample, bins=self.n_bins, range=self.histo_range
+ )
- result_np_w = np.histogramdd(self.sample,
- bins=self.n_bins,
- range=self.histo_range,
- weights=self.weights)
+ result_np_w = np.histogramdd(
+ self.sample, bins=self.n_bins, range=self.histo_range, weights=self.weights
+ )
# comparing "hits"
- hits_cmp = np.array_equal(result_c[0],
- result_np[0])
+ hits_cmp = np.array_equal(result_c[0], result_np[0])
# comparing weights
- weights_cmp = np.array_equal(result_c[1],
- result_np_w[0])
+ weights_cmp = np.array_equal(result_c[1], result_np_w[0])
self.assertTrue(hits_cmp, msg=self.state_msg)
self.assertTrue(weights_cmp, msg=self.state_msg)
bins_min = [rng[0] for rng in self.histo_range]
bins_max = [rng[1] for rng in self.histo_range]
- inrange_idx = _get_in_range_indices(self.sample,
- bins_min,
- bins_max,
- minop=operator.ge,
- maxop=operator.le)
+ inrange_idx = _get_in_range_indices(
+ self.sample, bins_min, bins_max, minop=operator.ge, maxop=operator.le
+ )
- self.assertEqual(result_c[0].sum(), inrange_idx.shape[0],
- msg=self.state_msg)
+ self.assertEqual(result_c[0].sum(), inrange_idx.shape[0], msg=self.state_msg)
# we have to sum the weights using the same precision as the
# histogramnd function
weights_sum = self.weights[inrange_idx].astype(result_c[1].dtype).sum()
- self.assertTrue(self.array_compare(result_c[1].sum(), weights_sum),
- msg=self.state_msg)
+ self.assertTrue(
+ self.array_compare(result_c[1].sum(), weights_sum), msg=self.state_msg
+ )
def test_last_bin_open(self):
- """
-
- """
- result_c = histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights,
- last_bin_closed=False)
+ """ """
+ result_c = histogramnd(
+ self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights,
+ last_bin_closed=False,
+ )
bins_max = [rng[1] for rng in self.histo_range]
filtered_idx = _get_values_index(self.sample, bins_max)
- result_np = np.histogramdd(self.sample[filtered_idx],
- bins=self.n_bins,
- range=self.histo_range)
+ result_np = np.histogramdd(
+ self.sample[filtered_idx], bins=self.n_bins, range=self.histo_range
+ )
- result_np_w = np.histogramdd(self.sample[filtered_idx],
- bins=self.n_bins,
- range=self.histo_range,
- weights=self.weights[filtered_idx])
+ result_np_w = np.histogramdd(
+ self.sample[filtered_idx],
+ bins=self.n_bins,
+ range=self.histo_range,
+ weights=self.weights[filtered_idx],
+ )
# comparing "hits"
hits_cmp = np.array_equal(result_c[0], result_np[0])
# comparing weights
- weights_cmp = np.array_equal(result_c[1],
- result_np_w[0])
+ weights_cmp = np.array_equal(result_c[1], result_np_w[0])
self.assertTrue(hits_cmp, msg=self.state_msg)
self.assertTrue(weights_cmp, msg=self.state_msg)
bins_min = [rng[0] for rng in self.histo_range]
bins_max = [rng[1] for rng in self.histo_range]
- inrange_idx = _get_in_range_indices(self.sample,
- bins_min,
- bins_max,
- minop=operator.ge,
- maxop=operator.lt)
-
- self.assertEqual(result_c[0].sum(), len(inrange_idx),
- msg=self.state_msg)
+ inrange_idx = _get_in_range_indices(
+ self.sample, bins_min, bins_max, minop=operator.ge, maxop=operator.lt
+ )
+
+ self.assertEqual(result_c[0].sum(), len(inrange_idx), msg=self.state_msg)
# we have to sum the weights using the same precision as the
# histogramnd function
weights_sum = self.weights[inrange_idx].astype(result_c[1].dtype).sum()
- self.assertTrue(self.array_compare(result_c[1].sum(), weights_sum),
- msg=self.state_msg)
+ self.assertTrue(
+ self.array_compare(result_c[1].sum(), weights_sum), msg=self.state_msg
+ )
def test_filter_min(self):
- """
-
- """
- result_c = histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights,
- last_bin_closed=True,
- weight_min=self.filter_min)
+ """ """
+ result_c = histogramnd(
+ self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights,
+ last_bin_closed=True,
+ weight_min=self.filter_min,
+ )
# !!!!!!!!!!!!!!!!!!!!!!!!!!!!
filter_min = self.dtype_weights(self.filter_min)
- weight_idx = _get_values_index(self.weights,
- filter_min, # <------ !!!
- operator.ge)
+ weight_idx = _get_values_index(
+ self.weights, filter_min, operator.ge # <------ !!!
+ )
- result_np = np.histogramdd(self.sample[weight_idx],
- bins=self.n_bins,
- range=self.histo_range)
+ result_np = np.histogramdd(
+ self.sample[weight_idx], bins=self.n_bins, range=self.histo_range
+ )
- result_np_w = np.histogramdd(self.sample[weight_idx],
- bins=self.n_bins,
- range=self.histo_range,
- weights=self.weights[weight_idx])
+ result_np_w = np.histogramdd(
+ self.sample[weight_idx],
+ bins=self.n_bins,
+ range=self.histo_range,
+ weights=self.weights[weight_idx],
+ )
# comparing "hits"
- hits_cmp = np.array_equal(result_c[0],
- result_np[0])
+ hits_cmp = np.array_equal(result_c[0], result_np[0])
# comparing weights
weights_cmp = np.array_equal(result_c[1], result_np_w[0])
@@ -331,53 +314,56 @@ class _TestHistogramnd(unittest.TestCase):
bins_min = [rng[0] for rng in self.histo_range]
bins_max = [rng[1] for rng in self.histo_range]
- inrange_idx = _get_in_range_indices(self.sample[weight_idx],
- bins_min,
- bins_max,
- minop=operator.ge,
- maxop=operator.le)
+ inrange_idx = _get_in_range_indices(
+ self.sample[weight_idx],
+ bins_min,
+ bins_max,
+ minop=operator.ge,
+ maxop=operator.le,
+ )
inrange_idx = weight_idx[inrange_idx]
- self.assertEqual(result_c[0].sum(), len(inrange_idx),
- msg=self.state_msg)
+ self.assertEqual(result_c[0].sum(), len(inrange_idx), msg=self.state_msg)
# we have to sum the weights using the same precision as the
# histogramnd function
weights_sum = self.weights[inrange_idx].astype(result_c[1].dtype).sum()
- self.assertTrue(self.array_compare(result_c[1].sum(), weights_sum),
- msg=self.state_msg)
+ self.assertTrue(
+ self.array_compare(result_c[1].sum(), weights_sum), msg=self.state_msg
+ )
def test_filter_max(self):
- """
-
- """
- result_c = histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights,
- last_bin_closed=True,
- weight_max=self.filter_max)
+ """ """
+ result_c = histogramnd(
+ self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights,
+ last_bin_closed=True,
+ weight_max=self.filter_max,
+ )
# !!!!!!!!!!!!!!!!!!!!!!!!!!!!
filter_max = self.dtype_weights(self.filter_max)
- weight_idx = _get_values_index(self.weights,
- filter_max, # <------ !!!
- operator.le)
+ weight_idx = _get_values_index(
+ self.weights, filter_max, operator.le # <------ !!!
+ )
- result_np = np.histogramdd(self.sample[weight_idx],
- bins=self.n_bins,
- range=self.histo_range)
+ result_np = np.histogramdd(
+ self.sample[weight_idx], bins=self.n_bins, range=self.histo_range
+ )
- result_np_w = np.histogramdd(self.sample[weight_idx],
- bins=self.n_bins,
- range=self.histo_range,
- weights=self.weights[weight_idx])
+ result_np_w = np.histogramdd(
+ self.sample[weight_idx],
+ bins=self.n_bins,
+ range=self.histo_range,
+ weights=self.weights[weight_idx],
+ )
# comparing "hits"
- hits_cmp = np.array_equal(result_c[0],
- result_np[0])
+ hits_cmp = np.array_equal(result_c[0], result_np[0])
# comparing weights
weights_cmp = np.array_equal(result_c[1], result_np_w[0])
@@ -386,57 +372,62 @@ class _TestHistogramnd(unittest.TestCase):
bins_min = [rng[0] for rng in self.histo_range]
bins_max = [rng[1] for rng in self.histo_range]
- inrange_idx = _get_in_range_indices(self.sample[weight_idx],
- bins_min,
- bins_max,
- minop=operator.ge,
- maxop=operator.le)
+ inrange_idx = _get_in_range_indices(
+ self.sample[weight_idx],
+ bins_min,
+ bins_max,
+ minop=operator.ge,
+ maxop=operator.le,
+ )
inrange_idx = weight_idx[inrange_idx]
- self.assertEqual(result_c[0].sum(), len(inrange_idx),
- msg=self.state_msg)
+ self.assertEqual(result_c[0].sum(), len(inrange_idx), msg=self.state_msg)
# we have to sum the weights using the same precision as the
# histogramnd function
weights_sum = self.weights[inrange_idx].astype(result_c[1].dtype).sum()
- self.assertTrue(self.array_compare(result_c[1].sum(), weights_sum),
- msg=self.state_msg)
+ self.assertTrue(
+ self.array_compare(result_c[1].sum(), weights_sum), msg=self.state_msg
+ )
def test_filter_minmax(self):
- """
-
- """
- result_c = histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights,
- last_bin_closed=True,
- weight_min=self.filter_min,
- weight_max=self.filter_max)
+ """ """
+ result_c = histogramnd(
+ self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights,
+ last_bin_closed=True,
+ weight_min=self.filter_min,
+ weight_max=self.filter_max,
+ )
# !!!!!!!!!!!!!!!!!!!!!!!!!!!!
filter_min = self.dtype_weights(self.filter_min)
filter_max = self.dtype_weights(self.filter_max)
- weight_idx = _get_in_range_indices(self.weights,
- filter_min, # <------ !!!
- filter_max, # <------ !!!
- minop=operator.ge,
- maxop=operator.le)
+ weight_idx = _get_in_range_indices(
+ self.weights,
+ filter_min, # <------ !!!
+ filter_max, # <------ !!!
+ minop=operator.ge,
+ maxop=operator.le,
+ )
- result_np = np.histogramdd(self.sample[weight_idx],
- bins=self.n_bins,
- range=self.histo_range)
+ result_np = np.histogramdd(
+ self.sample[weight_idx], bins=self.n_bins, range=self.histo_range
+ )
- result_np_w = np.histogramdd(self.sample[weight_idx],
- bins=self.n_bins,
- range=self.histo_range,
- weights=self.weights[weight_idx])
+ result_np_w = np.histogramdd(
+ self.sample[weight_idx],
+ bins=self.n_bins,
+ range=self.histo_range,
+ weights=self.weights[weight_idx],
+ )
# comparing "hits"
- hits_cmp = np.array_equal(result_c[0],
- result_np[0])
+ hits_cmp = np.array_equal(result_c[0], result_np[0])
# comparing weights
weights_cmp = np.array_equal(result_c[1], result_np_w[0])
@@ -445,122 +436,113 @@ class _TestHistogramnd(unittest.TestCase):
bins_min = [rng[0] for rng in self.histo_range]
bins_max = [rng[1] for rng in self.histo_range]
- inrange_idx = _get_in_range_indices(self.sample[weight_idx],
- bins_min,
- bins_max,
- minop=operator.ge,
- maxop=operator.le)
+ inrange_idx = _get_in_range_indices(
+ self.sample[weight_idx],
+ bins_min,
+ bins_max,
+ minop=operator.ge,
+ maxop=operator.le,
+ )
inrange_idx = weight_idx[inrange_idx]
- self.assertEqual(result_c[0].sum(), len(inrange_idx),
- msg=self.state_msg)
+ self.assertEqual(result_c[0].sum(), len(inrange_idx), msg=self.state_msg)
# we have to sum the weights using the same precision as the
# histogramnd function
weights_sum = self.weights[inrange_idx].astype(result_c[1].dtype).sum()
- self.assertTrue(self.array_compare(result_c[1].sum(), weights_sum),
- msg=self.state_msg)
+ self.assertTrue(
+ self.array_compare(result_c[1].sum(), weights_sum), msg=self.state_msg
+ )
def test_reuse_histo(self):
- """
-
- """
- result_c_1 = histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights,
- last_bin_closed=True)
+ """ """
+ result_c_1 = histogramnd(
+ self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights,
+ last_bin_closed=True,
+ )
- result_np_1 = np.histogramdd(self.sample,
- bins=self.n_bins,
- range=self.histo_range)
+ result_np_1 = np.histogramdd(
+ self.sample, bins=self.n_bins, range=self.histo_range
+ )
- np.histogramdd(self.sample,
- bins=self.n_bins,
- range=self.histo_range,
- weights=self.weights)
+ np.histogramdd(
+ self.sample, bins=self.n_bins, range=self.histo_range, weights=self.weights
+ )
sample_2, weights_2 = self.generate_data()
- result_c_2 = histogramnd(sample_2,
- self.histo_range,
- self.n_bins,
- weights=weights_2,
- last_bin_closed=True,
- histo=result_c_1[0])
+ result_c_2 = histogramnd(
+ sample_2,
+ self.histo_range,
+ self.n_bins,
+ weights=weights_2,
+ last_bin_closed=True,
+ histo=result_c_1[0],
+ )
- result_np_2 = np.histogramdd(sample_2,
- bins=self.n_bins,
- range=self.histo_range)
+ result_np_2 = np.histogramdd(sample_2, bins=self.n_bins, range=self.histo_range)
- result_np_w_2 = np.histogramdd(sample_2,
- bins=self.n_bins,
- range=self.histo_range,
- weights=weights_2)
+ result_np_w_2 = np.histogramdd(
+ sample_2, bins=self.n_bins, range=self.histo_range, weights=weights_2
+ )
# comparing "hits"
- hits_cmp = np.array_equal(result_c_2[0],
- result_np_1[0] +
- result_np_2[0])
+ hits_cmp = np.array_equal(result_c_2[0], result_np_1[0] + result_np_2[0])
# comparing weights
- weights_cmp = np.array_equal(result_c_2[1],
- result_np_w_2[0])
+ weights_cmp = np.array_equal(result_c_2[1], result_np_w_2[0])
self.assertTrue(hits_cmp, msg=self.state_msg)
self.assertTrue(weights_cmp, msg=self.state_msg)
def test_reuse_cumul(self):
- """
-
- """
- result_c = histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights,
- last_bin_closed=True)
+ """ """
+ result_c = histogramnd(
+ self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights,
+ last_bin_closed=True,
+ )
- np.histogramdd(self.sample,
- bins=self.n_bins,
- range=self.histo_range)
+ np.histogramdd(self.sample, bins=self.n_bins, range=self.histo_range)
- result_np_w = np.histogramdd(self.sample,
- bins=self.n_bins,
- range=self.histo_range,
- weights=self.weights)
+ result_np_w = np.histogramdd(
+ self.sample, bins=self.n_bins, range=self.histo_range, weights=self.weights
+ )
sample_2, weights_2 = self.generate_data()
- result_c_2 = histogramnd(sample_2,
- self.histo_range,
- self.n_bins,
- weights=weights_2,
- last_bin_closed=True,
- weighted_histo=result_c[1])
+ result_c_2 = histogramnd(
+ sample_2,
+ self.histo_range,
+ self.n_bins,
+ weights=weights_2,
+ last_bin_closed=True,
+ weighted_histo=result_c[1],
+ )
- result_np_2 = np.histogramdd(sample_2,
- bins=self.n_bins,
- range=self.histo_range)
+ result_np_2 = np.histogramdd(sample_2, bins=self.n_bins, range=self.histo_range)
- result_np_w_2 = np.histogramdd(sample_2,
- bins=self.n_bins,
- range=self.histo_range,
- weights=weights_2)
+ result_np_w_2 = np.histogramdd(
+ sample_2, bins=self.n_bins, range=self.histo_range, weights=weights_2
+ )
# comparing "hits"
- hits_cmp = np.array_equal(result_c_2[0],
- result_np_2[0])
+ hits_cmp = np.array_equal(result_c_2[0], result_np_2[0])
# comparing weights
self.assertTrue(hits_cmp, msg=self.state_msg)
- self.assertTrue(self.array_compare(result_c_2[1],
- result_np_w[0] + result_np_w_2[0]),
- msg=self.state_msg)
+ self.assertTrue(
+ self.array_compare(result_c_2[1], result_np_w[0] + result_np_w_2[0]),
+ msg=self.state_msg,
+ )
def test_reuse_cumul_float(self):
- """
-
- """
+ """ """
n_bins = np.array(self.n_bins, ndmin=1)
if len(self.sample.shape) == 2:
if len(n_bins) == self.sample.shape[1]:
@@ -572,43 +554,46 @@ class _TestHistogramnd(unittest.TestCase):
shp = (self.n_bins,)
cumul = np.zeros(shp, dtype=np.float32)
- result_c_1 = histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights,
- last_bin_closed=True,
- weighted_histo=cumul)
+ result_c_1 = histogramnd(
+ self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights,
+ last_bin_closed=True,
+ weighted_histo=cumul,
+ )
- result_np_1 = np.histogramdd(self.sample,
- bins=self.n_bins,
- range=self.histo_range)
+ result_np_1 = np.histogramdd(
+ self.sample, bins=self.n_bins, range=self.histo_range
+ )
- result_np_w_1 = np.histogramdd(self.sample,
- bins=self.n_bins,
- range=self.histo_range,
- weights=self.weights)
+ result_np_w_1 = np.histogramdd(
+ self.sample, bins=self.n_bins, range=self.histo_range, weights=self.weights
+ )
# comparing "hits"
- hits_cmp = np.array_equal(result_c_1[0],
- result_np_1[0])
+ hits_cmp = np.array_equal(result_c_1[0], result_np_1[0])
self.assertTrue(hits_cmp, msg=self.state_msg)
self.assertEqual(result_c_1[1].dtype, np.float32, msg=self.state_msg)
bins_min = [rng[0] for rng in self.histo_range]
bins_max = [rng[1] for rng in self.histo_range]
- inrange_idx = _get_in_range_indices(self.sample,
- bins_min,
- bins_max,
- minop=operator.ge,
- maxop=operator.le)
- weights_sum = \
- self.weights[inrange_idx].astype(np.float32).sum(dtype=np.float64)
- self.assertTrue(np.allclose(result_c_1[1].sum(dtype=np.float64),
- weights_sum), msg=self.state_msg)
- self.assertTrue(np.allclose(result_c_1[1].sum(dtype=np.float64),
- result_np_w_1[0].sum(dtype=np.float64)),
- msg=self.state_msg)
+ inrange_idx = _get_in_range_indices(
+ self.sample, bins_min, bins_max, minop=operator.ge, maxop=operator.le
+ )
+ weights_sum = self.weights[inrange_idx].astype(np.float32).sum(dtype=np.float64)
+ self.assertTrue(
+ np.allclose(result_c_1[1].sum(dtype=np.float64), weights_sum),
+ msg=self.state_msg,
+ )
+ self.assertTrue(
+ np.allclose(
+ result_c_1[1].sum(dtype=np.float64),
+ result_np_w_1[0].sum(dtype=np.float64),
+ ),
+ msg=self.state_msg,
+ )
@pytest.mark.usefixtures("use_large_memory")
def test_histo_big_array(self):
@@ -622,35 +607,23 @@ class _TestHistogramnd(unittest.TestCase):
n_repeat = (2**31 + 10) // self.sample.size
sample = np.repeat(self.sample, n_repeat)
n_bins = int(1e6)
- result_c = histogramnd(
- sample,
- self.histo_range,
- n_bins,
- last_bin_closed=True
- )
- result_np = np.histogramdd(
- sample,
- n_bins,
- range=self.histo_range
- )
+ result_c = histogramnd(sample, self.histo_range, n_bins, last_bin_closed=True)
+ result_np = np.histogramdd(sample, n_bins, range=self.histo_range)
for i_edges, edges in enumerate(result_c[2]):
self.assertTrue(
- np.allclose(
- edges,
- result_np[1][i_edges]
- ),
- msg='{0}. Testing bin_edges for dim {1}.'''.format(self.state_msg, i_edges+1)
+ np.allclose(edges, result_np[1][i_edges]),
+ msg="{0}. Testing bin_edges for dim {1}."
+ "".format(self.state_msg, i_edges + 1),
)
-
-
class _TestHistogramnd_1d(_TestHistogramnd):
"""
Unit tests of the 1D histogramnd function.
"""
- sample_rng = [-55., 100.]
- weights_rng = [-70., 150.]
+
+ sample_rng = [-55.0, 100.0]
+ weights_rng = [-70.0, 150.0]
n_dims = 1
filter_min = -15.6
filter_max = 85.7
@@ -665,13 +638,14 @@ class _TestHistogramnd_2d(_TestHistogramnd):
"""
Unit tests of the 1D histogramnd function.
"""
+
sample_rng = [-50.2, 100.99]
- weights_rng = [70., 150.]
+ weights_rng = [70.0, 150.0]
n_dims = 2
filter_min = 81.7
filter_max = 135.3
- histo_range = [[10., 90.], [20., 70.]]
+ histo_range = [[10.0, 90.0], [20.0, 70.0]]
n_bins = 30
dtype = None
@@ -681,13 +655,14 @@ class _TestHistogramnd_3d(_TestHistogramnd):
"""
Unit tests of the 1D histogramnd function.
"""
+
sample_rng = [10.2, 200.9]
- weights_rng = [0., 100.]
+ weights_rng = [0.0, 100.0]
n_dims = 3
filter_min = 31.5
filter_max = 83.7
- histo_range = [[30.8, 150.2], [20.1, 90.9], [10.1, 195.]]
+ histo_range = [[30.8, 150.2], [20.1, 90.9], [10.1, 195.0]]
n_bins = 30
dtype = None
diff --git a/src/silx/math/test/test_interpolate.py b/src/silx/math/test/test_interpolate.py
index cff8bd9..a2b5455 100644
--- a/src/silx/math/test/test_interpolate.py
+++ b/src/silx/math/test/test_interpolate.py
@@ -31,6 +31,7 @@ __date__ = "11/07/2019"
import unittest
import numpy
+
try:
from scipy.interpolate import interpn
except ImportError:
@@ -55,7 +56,8 @@ class TestInterp3d(ParametricTestCase):
[numpy.arange(dim, dtype=data.dtype) for dim in data.shape],
data,
points,
- method='linear')
+ method="linear",
+ )
def test_random_data(self):
"""Test interp3d with random data"""
@@ -63,14 +65,14 @@ class TestInterp3d(ParametricTestCase):
npoints = 10
ref_data = numpy.random.random((size, size, size))
- ref_points = numpy.random.random(npoints*3).reshape(npoints, 3) * (size -1)
+ ref_points = numpy.random.random(npoints * 3).reshape(npoints, 3) * (size - 1)
for dtype in (numpy.float32, numpy.float64):
data = ref_data.astype(dtype)
points = ref_points.astype(dtype)
ref_result = self.ref_interp3d(data, points)
- for method in (u'linear', u'linear_omp'):
+ for method in ("linear", "linear_omp"):
with self.subTest(method=method):
result = interpolate.interp3d(data, points, method=method)
self.assertTrue(numpy.allclose(ref_result, result))
@@ -80,29 +82,27 @@ class TestInterp3d(ParametricTestCase):
data = numpy.ones((3, 3, 3), dtype=numpy.float64)
data[0, 0, 0] = numpy.nan
data[2, 2, 2] = numpy.inf
- points = numpy.array([(0.5, 0.5, 0.5),
- (1.5, 1.5, 1.5)])
+ points = numpy.array([(0.5, 0.5, 0.5), (1.5, 1.5, 1.5)])
- for method in (u'linear', u'linear_omp'):
+ for method in ("linear", "linear_omp"):
with self.subTest(method=method):
- result = interpolate.interp3d(
- data, points, method=method)
+ result = interpolate.interp3d(data, points, method=method)
self.assertTrue(numpy.isnan(result[0]))
self.assertTrue(result[1] == numpy.inf)
def test_points_outside(self):
"""Test interp3d with points outside the volume"""
data = numpy.ones((4, 4, 4), dtype=numpy.float64)
- points = numpy.array([(-0.1, -0.1, -0.1),
- (3.1, 3.1, 3.1),
- (-0.1, 1., 1.),
- (1., 1., 3.1)])
+ points = numpy.array(
+ [(-0.1, -0.1, -0.1), (3.1, 3.1, 3.1), (-0.1, 1.0, 1.0), (1.0, 1.0, 3.1)]
+ )
- for method in (u'linear', u'linear_omp'):
- for fill_value in (numpy.nan, 0., -1.):
+ for method in ("linear", "linear_omp"):
+ for fill_value in (numpy.nan, 0.0, -1.0):
with self.subTest(method=method):
result = interpolate.interp3d(
- data, points, method=method, fill_value=fill_value)
+ data, points, method=method, fill_value=fill_value
+ )
if numpy.isnan(fill_value):
self.assertTrue(numpy.all(numpy.isnan(result)))
else:
@@ -111,14 +111,13 @@ class TestInterp3d(ParametricTestCase):
def test_integer_points(self):
"""Test interp3d with integer points coord"""
data = numpy.arange(4**3, dtype=numpy.float64).reshape(4, 4, 4)
- points = numpy.array([(0., 0., 0.),
- (0., 0., 1.),
- (2., 3., 0.),
- (3., 3., 3.)])
+ points = numpy.array(
+ [(0.0, 0.0, 0.0), (0.0, 0.0, 1.0), (2.0, 3.0, 0.0), (3.0, 3.0, 3.0)]
+ )
ref_result = data[tuple(points.T.astype(numpy.int32))]
- for method in (u'linear', u'linear_omp'):
+ for method in ("linear", "linear_omp"):
with self.subTest(method=method):
result = interpolate.interp3d(data, points, method=method)
self.assertTrue(numpy.allclose(ref_result, result))
diff --git a/src/silx/math/test/test_marchingcubes.py b/src/silx/math/test/test_marchingcubes.py
index 7c60414..7ac171e 100644
--- a/src/silx/math/test/test_marchingcubes.py
+++ b/src/silx/math/test/test_marchingcubes.py
@@ -26,8 +26,6 @@ __authors__ = ["T. Vincent"]
__license__ = "MIT"
__date__ = "17/01/2018"
-import unittest
-
import numpy
from silx.utils.testutils import ParametricTestCase
@@ -38,8 +36,7 @@ from silx.math import marchingcubes
class TestMarchingCubes(ParametricTestCase):
"""Tests of marching cubes"""
- def assertAllClose(self, array1, array2, msg=None,
- rtol=1e-05, atol=1e-08):
+ def assertAllClose(self, array1, array2, msg=None, rtol=1e-05, atol=1e-08):
"""Assert that the 2 numpy.ndarrays are almost equal.
:param str msg: Message to provide when assert fails
@@ -55,9 +52,9 @@ class TestMarchingCubes(ParametricTestCase):
# No isosurface
cube_zero = numpy.zeros((2, 2, 2), dtype=numpy.float32)
- result = marchingcubes.MarchingCubes(cube_zero, 1.)
+ result = marchingcubes.MarchingCubes(cube_zero, 1.0)
self.assertEqual(result.shape, cube_zero.shape)
- self.assertEqual(result.isolevel, 1.)
+ self.assertEqual(result.isolevel, 1.0)
self.assertEqual(result.invert_normals, True)
vertices, normals, indices = result
@@ -84,43 +81,45 @@ class TestMarchingCubes(ParametricTestCase):
# isosurface perpendicular to dim 0 (Z)
cube = numpy.array(
- (((0., 0.), (0., 0.)),
- ((1., 1.), (1., 1.))), dtype=numpy.float32)
+ (((0.0, 0.0), (0.0, 0.0)), ((1.0, 1.0), (1.0, 1.0))), dtype=numpy.float32
+ )
level = 0.5
vertices, normals, indices = marchingcubes.MarchingCubes(
- cube, level, invert_normals=False)
+ cube, level, invert_normals=False
+ )
self.assertAllClose(vertices[:, 0], level)
- self.assertAllClose(normals, (1., 0., 0.))
+ self.assertAllClose(normals, (1.0, 0.0, 0.0))
self.assertEqual(len(indices), 2)
# isosurface perpendicular to dim 1 (Y)
cube = numpy.array(
- (((0., 0.), (1., 1.)),
- ((0., 0.), (1., 1.))), dtype=numpy.float32)
+ (((0.0, 0.0), (1.0, 1.0)), ((0.0, 0.0), (1.0, 1.0))), dtype=numpy.float32
+ )
level = 0.2
vertices, normals, indices = marchingcubes.MarchingCubes(cube, level)
self.assertAllClose(vertices[:, 1], level)
- self.assertAllClose(normals, (0., -1., 0.))
+ self.assertAllClose(normals, (0.0, -1.0, 0.0))
self.assertEqual(len(indices), 2)
# isosurface perpendicular to dim 2 (X)
cube = numpy.array(
- (((0., 1.), (0., 1.)),
- ((0., 1.), (0., 1.))), dtype=numpy.float32)
+ (((0.0, 1.0), (0.0, 1.0)), ((0.0, 1.0), (0.0, 1.0))), dtype=numpy.float32
+ )
level = 0.9
vertices, normals, indices = marchingcubes.MarchingCubes(
- cube, level, invert_normals=False)
+ cube, level, invert_normals=False
+ )
self.assertAllClose(vertices[:, 2], level)
- self.assertAllClose(normals, (0., 0., 1.))
+ self.assertAllClose(normals, (0.0, 0.0, 1.0))
self.assertEqual(len(indices), 2)
# isosurface normal in dim1, dim 0 (Y, Z) plane
cube = numpy.array(
- (((0., 0.), (0., 0.)),
- ((0., 0.), (1., 1.))), dtype=numpy.float32)
+ (((0.0, 0.0), (0.0, 0.0)), ((0.0, 0.0), (1.0, 1.0))), dtype=numpy.float32
+ )
level = 0.5
vertices, normals, indices = marchingcubes.MarchingCubes(cube, level)
- self.assertAllClose(normals[:, 2], 0.)
+ self.assertAllClose(normals[:, 2], 0.0)
self.assertEqual(len(indices), 2)
def test_sampling(self):
@@ -146,26 +145,26 @@ class TestMarchingCubes(ParametricTestCase):
with self.subTest(sampling=sampling):
sampling = numpy.array(sampling)
- data = 1e6 * numpy.ones(
- sampling * size, dtype=numpy.float32)
+ data = 1e6 * numpy.ones(sampling * size, dtype=numpy.float32)
# Copy ref chessboard in data according to sampling
- data[::sampling[0], ::sampling[1], ::sampling[2]] = chessboard
+ data[:: sampling[0], :: sampling[1], :: sampling[2]] = chessboard
- result = marchingcubes.MarchingCubes(data, isolevel,
- sampling=sampling)
+ result = marchingcubes.MarchingCubes(data, isolevel, sampling=sampling)
# Compare vertices normalized with shape
self.assertAllClose(
ref_result.get_vertices() / ref_result.shape,
result.get_vertices() / result.shape,
- atol=0., rtol=0.)
+ atol=0.0,
+ rtol=0.0,
+ )
# Compare normals
# This comparison only works for normals aligned with axes
# otherwise non uniform sampling would make different normals
- self.assertAllClose(ref_result.get_normals(),
- result.get_normals(),
- atol=0., rtol=0.)
+ self.assertAllClose(
+ ref_result.get_normals(), result.get_normals(), atol=0.0, rtol=0.0
+ )
- self.assertAllClose(ref_result.get_indices(),
- result.get_indices(),
- atol=0., rtol=0.)
+ self.assertAllClose(
+ ref_result.get_indices(), result.get_indices(), atol=0.0, rtol=0.0
+ )
diff --git a/src/silx/opencl/atomic.py b/src/silx/opencl/atomic.py
new file mode 100644
index 0000000..16d3eff
--- /dev/null
+++ b/src/silx/opencl/atomic.py
@@ -0,0 +1,93 @@
+#
+# Project: S I L X project
+# https://github.com/silx-kit/silx
+#
+# Copyright (C) 2012-2023 European Synchrotron Radiation Facility, Grenoble, France
+#
+# Principal author: Jérôme Kieffer (Jerome.Kieffer@ESRF.eu)
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+# .
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+# .
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+
+"""
+Utilities around atomic operation in OpenCL
+"""
+
+__author__ = "Jerome Kieffer"
+__license__ = "MIT"
+__date__ = "14/06/2023"
+__copyright__ = "2023-2023, ESRF, Grenoble"
+__contact__ = "jerome.kieffer@esrf.fr"
+
+import numpy
+import pyopencl
+import pyopencl.array as cla
+
+
+def check_atomic32(device):
+ try:
+ ctx = pyopencl.Context(devices=[device])
+ except:
+ return False, f"Unable to create context on {device}"
+ else:
+ queue = pyopencl.CommandQueue(ctx)
+ src = """
+kernel void check_atomic32(global int* ary){
+int res = atom_inc(ary);
+}
+"""
+ try:
+ prg = pyopencl.Program(ctx, src).build()
+ except Exception as err:
+ return False, f"{type(err)}: {err}"
+ a = numpy.zeros(1, numpy.int32)
+ d = cla.to_device(queue, a)
+ prg.check_atomic32(queue, (1024,), (32,), d.data).wait()
+ value = d.get()[0]
+ return value == 1024, f"Got the proper value 1024=={value}"
+
+
+def check_atomic64(device):
+ try:
+ ctx = pyopencl.Context(devices=[device])
+ except:
+ return False, f"Unable to create context on {device}"
+ else:
+ queue = pyopencl.CommandQueue(ctx)
+ if (
+ device.platform.name == "Portable Computing Language"
+ and "GPU" in pyopencl.device_type.to_string(device.type).upper()
+ ):
+ # this configuration is known to seg-fault
+ return False, "PoCL + GPU do not support atomic64"
+ src = """
+#pragma OPENCL EXTENSION cl_khr_fp64: enable
+#pragma OPENCL EXTENSION cl_khr_int64_base_atomics : enable
+kernel void check_atomic64(global long* ary){
+long res = atom_inc(ary);
+}
+"""
+ try:
+ prg = pyopencl.Program(ctx, src).build()
+ except Exception as err:
+ return False, f"{type(err)}: {err}"
+ a = numpy.zeros(1, numpy.int64)
+ d = cla.to_device(queue, a)
+ prg.check_atomic64(queue, (1024,), (32,), d.data).wait()
+ value = d.get()[0]
+ return value == 1024, f"Got the proper value 1024=={value}"
diff --git a/src/silx/opencl/backprojection.py b/src/silx/opencl/backprojection.py
index 9f747c1..5af2bc5 100644
--- a/src/silx/opencl/backprojection.py
+++ b/src/silx/opencl/backprojection.py
@@ -1,7 +1,7 @@
#!/usr/bin/env python
# /*##########################################################################
#
-# Copyright (c) 2016 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -34,8 +34,6 @@ import numpy as np
from .common import pyopencl
from .processing import EventDescription, OpenclProcessing, BufferDescription
from .sinofilter import SinoFilter
-from .sinofilter import fourier_filter as fourier_filter_
-from ..utils.deprecation import deprecated
if pyopencl:
mf = pyopencl.mem_flags
@@ -61,12 +59,23 @@ def _idivup(a, b):
class Backprojection(OpenclProcessing):
"""A class for performing the backprojection using OpenCL"""
+
kernel_files = ["backproj.cl", "array_utils.cl"]
- def __init__(self, sino_shape, slice_shape=None, axis_position=None,
- angles=None, filter_name=None, ctx=None, devicetype="all",
- platformid=None, deviceid=None, profile=False,
- extra_options=None):
+ def __init__(
+ self,
+ sino_shape,
+ slice_shape=None,
+ axis_position=None,
+ angles=None,
+ filter_name=None,
+ ctx=None,
+ devicetype="all",
+ platformid=None,
+ deviceid=None,
+ profile=False,
+ extra_options=None,
+ ):
"""Constructor of the OpenCL (filtered) backprojection
:param sino_shape: shape of the sinogram. The sinogram is in the format
@@ -98,19 +107,26 @@ class Backprojection(OpenclProcessing):
# assuming no discrete GPU
# raise NotImplementedError("Backprojection is not implemented on CPU for OS X yet")
- OpenclProcessing.__init__(self, ctx=ctx, devicetype=devicetype,
- platformid=platformid, deviceid=deviceid,
- profile=profile)
+ OpenclProcessing.__init__(
+ self,
+ ctx=ctx,
+ devicetype=devicetype,
+ platformid=platformid,
+ deviceid=deviceid,
+ profile=profile,
+ )
- self._init_geometry(sino_shape, slice_shape, angles, axis_position,
- extra_options)
+ self._init_geometry(
+ sino_shape, slice_shape, angles, axis_position, extra_options
+ )
self._allocate_memory()
self._compute_angles()
self._init_kernels()
self._init_filter(filter_name)
- def _init_geometry(self, sino_shape, slice_shape, angles, axis_position,
- extra_options):
+ def _init_geometry(
+ self, sino_shape, slice_shape, angles, axis_position, extra_options
+ ):
"""Geometry Initialization
:param sino_shape: shape of the sinogram. The sinogram is in the format
@@ -134,12 +150,12 @@ class Backprojection(OpenclProcessing):
self.slice_shape = slice_shape
self.dimrec_shape = (
_idivup(self.slice_shape[0], 32) * 32,
- _idivup(self.slice_shape[1], 32) * 32
+ _idivup(self.slice_shape[1], 32) * 32,
)
if axis_position:
self.axis_pos = np.float32(axis_position)
else:
- self.axis_pos = np.float32((sino_shape[1] - 1.) / 2)
+ self.axis_pos = np.float32((sino_shape[1] - 1.0) / 2)
self.axis_array = None # TODO: add axis correction front-end
self._init_extra_options(extra_options)
@@ -149,11 +165,11 @@ class Backprojection(OpenclProcessing):
:param dict extra_options: Advanced extra options
"""
self.extra_options = {
- "cutoff": 1.,
+ "cutoff": 1.0,
"use_numpy_fft": False,
# It is axis_pos - (num_bins-1)/2 in PyHST
- "gpu_offset_x": 0., #self.axis_pos - (self.num_bins - 1) / 2.,
- "gpu_offset_y": 0., #self.axis_pos - (self.num_bins - 1) / 2.
+ "gpu_offset_x": 0.0, # self.axis_pos - (self.num_bins - 1) / 2.,
+ "gpu_offset_y": 0.0, # self.axis_pos - (self.num_bins - 1) / 2.
}
if extra_options is not None:
self.extra_options.update(extra_options)
@@ -166,7 +182,9 @@ class Backprojection(OpenclProcessing):
# Device memory
self.buffers = [
BufferDescription("_d_slice", self.dimrec_shape, np.float32, mf.READ_WRITE),
- BufferDescription("d_sino", self.shape, np.float32, mf.READ_WRITE), # before transferring to texture (if available)
+ BufferDescription(
+ "d_sino", self.shape, np.float32, mf.READ_WRITE
+ ), # before transferring to texture (if available)
BufferDescription("d_cos", (self.num_projs,), np.float32, mf.READ_ONLY),
BufferDescription("d_sin", (self.num_projs,), np.float32, mf.READ_ONLY),
BufferDescription("d_axes", (self.num_projs,), np.float32, mf.READ_ONLY),
@@ -191,27 +209,29 @@ class Backprojection(OpenclProcessing):
if self.axis_array:
self.cl_mem["d_axes"][:] = self.axis_array.astype(np.float32)[:]
else:
- self.cl_mem["d_axes"][:] = np.ones(self.num_projs, dtype="f") * self.axis_pos
+ self.cl_mem["d_axes"][:] = (
+ np.ones(self.num_projs, dtype="f") * self.axis_pos
+ )
def _init_kernels(self):
compile_options = None
- if not(self._use_textures):
+ if not (self._use_textures):
compile_options = "-DDONT_USE_TEXTURES"
OpenclProcessing.compile_kernels(
- self,
- self.kernel_files,
- compile_options=compile_options
+ self, self.kernel_files, compile_options=compile_options
)
# check that workgroup can actually be (16, 16)
- self.compiletime_workgroup_size = self.kernels.max_workgroup_size("backproj_cpu_kernel")
+ self.compiletime_workgroup_size = self.kernels.max_workgroup_size(
+ "backproj_cpu_kernel"
+ )
# Workgroup and ndrange sizes are always the same
self.wg = (16, 16)
self.ndrange = (
_idivup(int(self.dimrec_shape[1]), 32) * self.wg[0],
- _idivup(int(self.dimrec_shape[0]), 32) * self.wg[1]
+ _idivup(int(self.dimrec_shape[0]), 32) * self.wg[1],
)
# Prepare arguments for the kernel call
- if not(self._use_textures):
+ if not (self._use_textures):
d_sino_ref = self.d_sino.data
else:
d_sino_ref = self.d_sino_tex
@@ -226,7 +246,7 @@ class Backprojection(OpenclProcessing):
self.cl_mem["_d_slice"].data,
# d_sino (__read_only image2d_t or float*)
d_sino_ref,
- # gpu_offset_x (float32) 
+ # gpu_offset_x (float32)
np.float32(self.extra_options["gpu_offset_x"]),
# gpu_offset_y (float32)
np.float32(self.extra_options["gpu_offset_y"]),
@@ -237,7 +257,7 @@ class Backprojection(OpenclProcessing):
# d_axis (__global float32*)
self.cl_mem["d_axes"].data,
# shared mem (__local float32*)
- self._get_local_mem()
+ self._get_local_mem(),
)
def _allocate_textures(self):
@@ -273,7 +293,7 @@ class Backprojection(OpenclProcessing):
np.int32(self.dimrec_shape[1]),
np.int32((0, 0)),
np.int32((0, 0)),
- slice_shape_ocl
+ slice_shape_ocl,
)
return self.kernels.cpy2d(self.queue, ndrange, wg, *kernel_args)
@@ -281,47 +301,39 @@ class Backprojection(OpenclProcessing):
if isinstance(sino, parray.Array):
return self._transfer_device_to_texture(sino)
sino2 = sino
- if not(sino.flags["C_CONTIGUOUS"] and sino.dtype == np.float32):
+ if not (sino.flags["C_CONTIGUOUS"] and sino.dtype == np.float32):
sino2 = np.ascontiguousarray(sino, dtype=np.float32)
- if not(self._use_textures):
- ev = pyopencl.enqueue_copy(
- self.queue,
- self.d_sino.data,
- sino2
- )
+ if not (self._use_textures):
+ ev = pyopencl.enqueue_copy(self.queue, self.d_sino.data, sino2)
what = "transfer filtered sino H->D buffer"
ev.wait()
else:
ev = pyopencl.enqueue_copy(
- self.queue,
- self.d_sino_tex,
- sino2,
- origin=(0, 0),
- region=self.shape[::-1]
- )
+ self.queue,
+ self.d_sino_tex,
+ sino2,
+ origin=(0, 0),
+ region=self.shape[::-1],
+ )
what = "transfer filtered sino H->D texture"
return EventDescription(what, ev)
def _transfer_device_to_texture(self, d_sino):
- if not(self._use_textures):
+ if not (self._use_textures):
if id(self.d_sino) == id(d_sino):
return
- ev = pyopencl.enqueue_copy(
- self.queue,
- self.d_sino.data,
- d_sino
- )
+ ev = pyopencl.enqueue_copy(self.queue, self.d_sino.data, d_sino)
what = "transfer filtered sino D->D buffer"
ev.wait()
else:
ev = pyopencl.enqueue_copy(
- self.queue,
- self.d_sino_tex,
- d_sino.data,
- offset=0,
- origin=(0, 0),
- region=self.shape[::-1]
- )
+ self.queue,
+ self.d_sino_tex,
+ d_sino.data,
+ offset=0,
+ origin=(0, 0),
+ region=self.shape[::-1],
+ )
what = "transfer filtered sino D->D texture"
return EventDescription(what, ev)
@@ -337,20 +349,17 @@ class Backprojection(OpenclProcessing):
with self.sem:
events.append(self._transfer_to_texture(sino))
# Call the backprojection kernel
- if not(self._use_textures):
+ if not (self._use_textures):
kernel_to_call = self.kernels.backproj_cpu_kernel
else:
kernel_to_call = self.kernels.backproj_kernel
kernel_to_call(
- self.queue,
- self.ndrange,
- self.wg,
- *self._backproj_kernel_args
+ self.queue, self.ndrange, self.wg, *self._backproj_kernel_args
)
# Return
if output is None:
res = self.cl_mem["_d_slice"].get()
- res = res[:self.slice_shape[0], :self.slice_shape[1]]
+ res = res[: self.slice_shape[0], : self.slice_shape[1]]
else:
res = output
self._cpy2d_to_slice(output)
@@ -377,18 +386,3 @@ class Backprojection(OpenclProcessing):
return res
__call__ = filtered_backprojection
-
-
- # -------------------
- # - Compatibility -
- # -------------------
-
- @deprecated(replacement="Backprojection.sino_filter", since_version="0.10")
- def filter_projections(self, sino, rescale=True):
- self.sino_filter(sino, output=self.d_sino)
-
-
-
-def fourier_filter(sino, filter_=None, fft_size=None):
- return fourier_filter_(sino, filter_=filter_, fft_size=fft_size)
-
diff --git a/src/silx/opencl/codec/bitshuffle_lz4.py b/src/silx/opencl/codec/bitshuffle_lz4.py
new file mode 100644
index 0000000..b0992b9
--- /dev/null
+++ b/src/silx/opencl/codec/bitshuffle_lz4.py
@@ -0,0 +1,214 @@
+#!/usr/bin/env python
+#
+# Project: Sift implementation in Python + OpenCL
+# https://github.com/silx-kit/silx
+#
+# Copyright (C) 2022-2023 European Synchrotron Radiation Facility, Grenoble, France
+#
+# Permission is hereby granted, free of charge, to any person
+# obtaining a copy of this software and associated documentation
+# files (the "Software"), to deal in the Software without
+# restriction, including without limitation the rights to use,
+# copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the
+# Software is furnished to do so, subject to the following
+# conditions:
+#
+# The above copyright notice and this permission notice shall be
+# included in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
+# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
+# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
+# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
+# OTHER DEALINGS IN THE SOFTWARE.
+
+"""
+This module provides a class for CBF byte offset compression/decompression.
+"""
+
+__authors__ = ["Jérôme Kieffer"]
+__contact__ = "jerome.kieffer@esrf.eu"
+__license__ = "MIT"
+__copyright__ = "European Synchrotron Radiation Facility, Grenoble, France"
+__date__ = "09/11/2022"
+__status__ = "production"
+
+
+import os
+import struct
+import numpy
+from ..common import ocl, pyopencl, kernel_workgroup_size
+from ..processing import BufferDescription, EventDescription, OpenclProcessing
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class BitshuffleLz4(OpenclProcessing):
+ """Perform the bitshuffle-lz4 decompression on the GPU
+ See :class:`OpenclProcessing` for optional arguments description.
+ :param int cmp_size:
+ Size of the raw stream for decompression.
+ It can be (slightly) larger than the array.
+ :param int dec_size:
+ Size of the decompression output array
+ (mandatory for decompression)
+ :param dtype: dtype of decompressed data
+ """
+
+ LZ4_BLOCK_SIZE = 8192
+
+ def __init__(
+ self,
+ cmp_size,
+ dec_size,
+ dtype,
+ ctx=None,
+ devicetype="all",
+ platformid=None,
+ deviceid=None,
+ block_size=None,
+ profile=False,
+ ):
+ """Constructor of the class:
+
+ :param cmp_size: size of the compressed data buffer (in bytes)
+ :param dec_size: size of the compressed data buffer (in words)
+ :param dtype: data type of one work in decompressed array
+
+ For the other, see the doc of OpenclProcessing
+ """
+ OpenclProcessing.__init__(
+ self,
+ ctx=ctx,
+ devicetype=devicetype,
+ platformid=platformid,
+ deviceid=deviceid,
+ block_size=block_size,
+ profile=profile,
+ )
+ if self.block_size is None:
+ try:
+ self.block_size = self.ctx.devices[0].preferred_work_group_size_multiple
+ except:
+ self.block_size = self.device.max_work_group_size
+
+ self.cmp_size = numpy.uint64(cmp_size)
+ self.dec_size = numpy.uint64(dec_size)
+ self.dec_dtype = numpy.dtype(dtype)
+ self.num_blocks = numpy.uint32(
+ (self.dec_dtype.itemsize * self.dec_size + self.LZ4_BLOCK_SIZE - 1)
+ // self.LZ4_BLOCK_SIZE
+ )
+
+ buffers = [
+ BufferDescription("nb_blocks", 1, numpy.uint32, None),
+ BufferDescription("block_position", self.num_blocks, numpy.uint64, None),
+ BufferDescription("cmp", self.cmp_size, numpy.uint8, None),
+ BufferDescription("dec", self.dec_size, self.dec_dtype, None),
+ ]
+
+ self.allocate_buffers(buffers, use_array=True)
+
+ self.compile_kernels([os.path.join("codec", "bitshuffle_lz4")])
+ self.block_size = min(
+ self.block_size,
+ kernel_workgroup_size(self.program, "bslz4_decompress_block"),
+ )
+
+ def decompress(self, raw, out=None, wg=None, nbytes=None):
+ """This function actually performs the decompression by calling the kernels
+ :param numpy.ndarray raw: The compressed data as a 1D numpy array of char or string
+ :param pyopencl.array out: pyopencl array in which to place the result.
+ :param wg: tuneable parameter with the workgroup size.
+ :param int nbytes: (Optional) Number of bytes occupied by the chunk in raw.
+ :return: The decompressed image as an pyopencl array.
+ :rtype: pyopencl.array
+ """
+
+ events = []
+ with self.sem:
+ if nbytes is not None:
+ assert nbytes <= raw.size
+ len_raw = numpy.uint64(nbytes)
+ elif isinstance(raw, pyopencl.Buffer):
+ len_raw = numpy.uint64(raw.size)
+ else:
+ len_raw = numpy.uint64(len(raw))
+
+ if isinstance(raw, pyopencl.array.Array):
+ cmp_buffer = raw.data
+ num_blocks = self.num_blocks
+ elif isinstance(raw, pyopencl.Buffer):
+ cmp_buffer = raw
+ num_blocks = self.num_blocks
+ else:
+ if len_raw > self.cmp_size:
+ self.cmp_size = len_raw
+ logger.info("increase cmp buffer size to %s", self.cmp_size)
+ self.cl_mem["cmp"] = pyopencl.array.empty(
+ self.queue, self.cmp_size, dtype=numpy.uint8
+ )
+ evt = pyopencl.enqueue_copy(
+ self.queue, self.cl_mem["cmp"].data, raw, is_blocking=False
+ )
+ events.append(EventDescription("copy raw H -> D", evt))
+ cmp_buffer = self.cl_mem["cmp"].data
+
+ dest_size = struct.unpack(">Q", raw[:8])
+ self_dest_nbyte = self.dec_size * self.dec_dtype.itemsize
+ if dest_size < self_dest_nbyte:
+ num_blocks = numpy.uint32(
+ (dest_size + self.LZ4_BLOCK_SIZE - 1) // self.LZ4_BLOCK_SIZE
+ )
+ elif dest_size > self_dest_nbyte:
+ num_blocks = numpy.uint32(
+ (dest_size + self.LZ4_BLOCK_SIZE - 1) // self.LZ4_BLOCK_SIZE
+ )
+ self.cl_mem["dec"] = pyopencl.array.empty(
+ self.queue, dest_size, self.dec_dtype
+ )
+ self.dec_size = dest_size // self.dec_dtype.itemsize
+ else:
+ num_blocks = self.num_blocks
+
+ wg = int(wg or self.block_size)
+
+ evt = self.program.lz4_unblock(
+ self.queue,
+ (1,),
+ (1,),
+ cmp_buffer,
+ len_raw,
+ self.cl_mem["block_position"].data,
+ num_blocks,
+ self.cl_mem["nb_blocks"].data,
+ )
+ events.append(EventDescription("LZ4 unblock", evt))
+
+ if out is None:
+ out = self.cl_mem["dec"]
+ else:
+ assert out.dtype == self.dec_dtype
+ assert out.size == self.dec_size
+
+ evt = self.program.bslz4_decompress_block(
+ self.queue,
+ (self.num_blocks * wg,),
+ (wg,),
+ cmp_buffer,
+ out.data,
+ self.cl_mem["block_position"].data,
+ self.cl_mem["nb_blocks"].data,
+ numpy.uint8(self.dec_dtype.itemsize),
+ )
+ events.append(EventDescription("LZ4 decompress", evt))
+ self.profile_multi(events)
+ return out
+
+ __call__ = decompress
diff --git a/src/silx/opencl/codec/byte_offset.py b/src/silx/opencl/codec/byte_offset.py
index e497a73..e3df9b2 100644
--- a/src/silx/opencl/codec/byte_offset.py
+++ b/src/silx/opencl/codec/byte_offset.py
@@ -3,7 +3,7 @@
# Project: Sift implementation in Python + OpenCL
# https://github.com/silx-kit/silx
#
-# Copyright (C) 2013-2020 European Synchrotron Radiation Facility, Grenoble, France
+# Copyright (C) 2013-2023 European Synchrotron Radiation Facility, Grenoble, France
#
# Permission is hereby granted, free of charge, to any person
# obtaining a copy of this software and associated documentation
@@ -45,10 +45,12 @@ from ..common import ocl, pyopencl
from ..processing import BufferDescription, EventDescription, OpenclProcessing
import logging
+
logger = logging.getLogger(__name__)
if pyopencl:
import pyopencl.version
+
if pyopencl.version.VERSION < (2016, 0):
from pyopencl.scan import GenericScanKernel, GenericDebugScanKernel
else:
@@ -61,23 +63,36 @@ else:
class ByteOffset(OpenclProcessing):
"""Perform the byte offset compression/decompression on the GPU
- See :class:`OpenclProcessing` for optional arguments description.
-
- :param int raw_size:
- Size of the raw stream for decompression.
- It can be (slightly) larger than the array.
- :param int dec_size:
- Size of the decompression output array
- (mandatory for decompression)
- """
-
- def __init__(self, raw_size=None, dec_size=None,
- ctx=None, devicetype="all",
- platformid=None, deviceid=None,
- block_size=None, profile=False):
- OpenclProcessing.__init__(self, ctx=ctx, devicetype=devicetype,
- platformid=platformid, deviceid=deviceid,
- block_size=block_size, profile=profile)
+ See :class:`OpenclProcessing` for optional arguments description.
+
+ :param int raw_size:
+ Size of the raw stream for decompression.
+ It can be (slightly) larger than the array.
+ :param int dec_size:
+ Size of the decompression output array
+ (mandatory for decompression)
+ """
+
+ def __init__(
+ self,
+ raw_size=None,
+ dec_size=None,
+ ctx=None,
+ devicetype="all",
+ platformid=None,
+ deviceid=None,
+ block_size=None,
+ profile=False,
+ ):
+ OpenclProcessing.__init__(
+ self,
+ ctx=ctx,
+ devicetype=devicetype,
+ platformid=platformid,
+ deviceid=deviceid,
+ block_size=block_size,
+ profile=profile,
+ )
if self.block_size is None:
self.block_size = self.device.max_work_group_size
wg = self.block_size
@@ -94,7 +109,9 @@ class ByteOffset(OpenclProcessing):
BufferDescription("raw", self.padded_raw_size, numpy.int8, None),
BufferDescription("mask", self.padded_raw_size, numpy.int32, None),
BufferDescription("values", self.padded_raw_size, numpy.int32, None),
- BufferDescription("exceptions", self.padded_raw_size, numpy.int32, None)
+ BufferDescription(
+ "exceptions", self.padded_raw_size, numpy.int32, None
+ ),
]
if dec_size is None:
@@ -103,18 +120,17 @@ class ByteOffset(OpenclProcessing):
self.dec_size = numpy.int32(dec_size)
buffers += [
BufferDescription("data_float", self.dec_size, numpy.float32, None),
- BufferDescription("data_int", self.dec_size, numpy.int32, None)
+ BufferDescription("data_int", self.dec_size, numpy.int32, None),
]
self.allocate_buffers(buffers, use_array=True)
self.compile_kernels([os.path.join("codec", "byte_offset")])
self.kernels.__setattr__("scan", self._init_double_scan())
- self.kernels.__setattr__("compression_scan",
- self._init_compression_scan())
+ self.kernels.__setattr__("compression_scan", self._init_compression_scan())
def _init_double_scan(self):
- """"generates a double scan on indexes and values in one operation"""
+ """generates a double scan on indexes and values in one operation"""
arguments = "__global int *value", "__global int *index"
int2 = pyopencl.tools.get_or_register_dtype("int2")
input_expr = "index[i]>0 ? (int2)(0, 0) : (int2)(value[i], 1)"
@@ -123,21 +139,25 @@ class ByteOffset(OpenclProcessing):
output_statement = "value[i] = item.s0; index[i+1] = item.s1;"
if self.block_size > 256:
- knl = GenericScanKernel(self.ctx,
- dtype=int2,
- arguments=arguments,
- input_expr=input_expr,
- scan_expr=scan_expr,
- neutral=neutral,
- output_statement=output_statement)
+ knl = GenericScanKernel(
+ self.ctx,
+ dtype=int2,
+ arguments=arguments,
+ input_expr=input_expr,
+ scan_expr=scan_expr,
+ neutral=neutral,
+ output_statement=output_statement,
+ )
else: # MacOS on CPU
- knl = GenericDebugScanKernel(self.ctx,
- dtype=int2,
- arguments=arguments,
- input_expr=input_expr,
- scan_expr=scan_expr,
- neutral=neutral,
- output_statement=output_statement)
+ knl = GenericDebugScanKernel(
+ self.ctx,
+ dtype=int2,
+ arguments=arguments,
+ input_expr=input_expr,
+ scan_expr=scan_expr,
+ neutral=neutral,
+ output_statement=output_statement,
+ )
return knl
def decode(self, raw, as_float=False, out=None):
@@ -150,8 +170,9 @@ class ByteOffset(OpenclProcessing):
:return: The decompressed image as an pyopencl array.
:rtype: pyopencl.array
"""
- assert self.dec_size is not None, \
- "dec_size is a mandatory ByteOffset init argument for decompression"
+ assert (
+ self.dec_size is not None
+ ), "dec_size is a mandatory ByteOffset init argument for decompression"
events = []
with self.sem:
@@ -162,67 +183,96 @@ class ByteOffset(OpenclProcessing):
self.padded_raw_size = (self.raw_size + wg - 1) & ~(wg - 1)
logger.info("increase raw buffer size to %s", self.padded_raw_size)
buffers = {
- "raw": pyopencl.array.empty(self.queue, self.padded_raw_size, dtype=numpy.int8),
- "mask": pyopencl.array.empty(self.queue, self.padded_raw_size, dtype=numpy.int32),
- "exceptions": pyopencl.array.empty(self.queue, self.padded_raw_size, dtype=numpy.int32),
- "values": pyopencl.array.empty(self.queue, self.padded_raw_size, dtype=numpy.int32),
- }
+ "raw": pyopencl.array.empty(
+ self.queue, self.padded_raw_size, dtype=numpy.int8
+ ),
+ "mask": pyopencl.array.empty(
+ self.queue, self.padded_raw_size, dtype=numpy.int32
+ ),
+ "exceptions": pyopencl.array.empty(
+ self.queue, self.padded_raw_size, dtype=numpy.int32
+ ),
+ "values": pyopencl.array.empty(
+ self.queue, self.padded_raw_size, dtype=numpy.int32
+ ),
+ }
self.cl_mem.update(buffers)
else:
wg = self.block_size
- evt = pyopencl.enqueue_copy(self.queue, self.cl_mem["raw"].data,
- raw,
- is_blocking=False)
+ evt = pyopencl.enqueue_copy(
+ self.queue, self.cl_mem["raw"].data, raw, is_blocking=False
+ )
events.append(EventDescription("copy raw H -> D", evt))
- evt = self.kernels.fill_int_mem(self.queue, (self.padded_raw_size,), (wg,),
- self.cl_mem["mask"].data,
- numpy.int32(self.padded_raw_size),
- numpy.int32(0),
- numpy.int32(0))
+ evt = self.kernels.fill_int_mem(
+ self.queue,
+ (self.padded_raw_size,),
+ (wg,),
+ self.cl_mem["mask"].data,
+ numpy.int32(self.padded_raw_size),
+ numpy.int32(0),
+ numpy.int32(0),
+ )
events.append(EventDescription("memset mask", evt))
- evt = self.kernels.fill_int_mem(self.queue, (1,), (1,),
- self.cl_mem["counter"].data,
- numpy.int32(1),
- numpy.int32(0),
- numpy.int32(0))
+ evt = self.kernels.fill_int_mem(
+ self.queue,
+ (1,),
+ (1,),
+ self.cl_mem["counter"].data,
+ numpy.int32(1),
+ numpy.int32(0),
+ numpy.int32(0),
+ )
events.append(EventDescription("memset counter", evt))
- evt = self.kernels.mark_exceptions(self.queue, (self.padded_raw_size,), (wg,),
- self.cl_mem["raw"].data,
- len_raw,
- numpy.int32(self.raw_size),
- self.cl_mem["mask"].data,
- self.cl_mem["values"].data,
- self.cl_mem["counter"].data,
- self.cl_mem["exceptions"].data)
+ evt = self.kernels.mark_exceptions(
+ self.queue,
+ (self.padded_raw_size,),
+ (wg,),
+ self.cl_mem["raw"].data,
+ len_raw,
+ numpy.int32(self.raw_size),
+ self.cl_mem["mask"].data,
+ self.cl_mem["values"].data,
+ self.cl_mem["counter"].data,
+ self.cl_mem["exceptions"].data,
+ )
events.append(EventDescription("mark exceptions", evt))
nb_exceptions = numpy.empty(1, dtype=numpy.int32)
- evt = pyopencl.enqueue_copy(self.queue, nb_exceptions, self.cl_mem["counter"].data,
- is_blocking=False)
+ evt = pyopencl.enqueue_copy(
+ self.queue,
+ nb_exceptions,
+ self.cl_mem["counter"].data,
+ is_blocking=False,
+ )
events.append(EventDescription("copy counter D -> H", evt))
evt.wait()
nbexc = int(nb_exceptions[0])
if nbexc == 0:
logger.info("nbexc %i", nbexc)
else:
- evt = self.kernels.treat_exceptions(self.queue, (nbexc,), (1,),
- self.cl_mem["raw"].data,
- len_raw,
- self.cl_mem["mask"].data,
- self.cl_mem["exceptions"].data,
- self.cl_mem["values"].data
- )
+ evt = self.kernels.treat_exceptions(
+ self.queue,
+ (nbexc,),
+ (1,),
+ self.cl_mem["raw"].data,
+ len_raw,
+ self.cl_mem["mask"].data,
+ self.cl_mem["exceptions"].data,
+ self.cl_mem["values"].data,
+ )
events.append(EventDescription("treat_exceptions", evt))
- #self.cl_mem["copy_values"] = self.cl_mem["values"].copy()
- #self.cl_mem["copy_mask"] = self.cl_mem["mask"].copy()
- evt = self.kernels.scan(self.cl_mem["values"],
- self.cl_mem["mask"],
- queue=self.queue,
- size=int(len_raw),
- wait_for=(evt,))
+ # self.cl_mem["copy_values"] = self.cl_mem["values"].copy()
+ # self.cl_mem["copy_mask"] = self.cl_mem["mask"].copy()
+ evt = self.kernels.scan(
+ self.cl_mem["values"],
+ self.cl_mem["mask"],
+ queue=self.queue,
+ size=int(len_raw),
+ wait_for=(evt,),
+ )
events.append(EventDescription("double scan", evt))
- #evt.wait()
+ # evt.wait()
if out is not None:
if out.dtype == numpy.float32:
copy_results = self.kernels.copy_result_float
@@ -235,15 +285,18 @@ class ByteOffset(OpenclProcessing):
else:
out = self.cl_mem["data_int"]
copy_results = self.kernels.copy_result_int
- evt = copy_results(self.queue, (self.padded_raw_size,), (wg,),
- self.cl_mem["values"].data,
- self.cl_mem["mask"].data,
- len_raw,
- self.dec_size,
- out.data
- )
+ evt = copy_results(
+ self.queue,
+ (self.padded_raw_size,),
+ (wg,),
+ self.cl_mem["values"].data,
+ self.cl_mem["mask"].data,
+ len_raw,
+ self.dec_size,
+ out.data,
+ )
events.append(EventDescription("copy_results", evt))
- #evt.wait()
+ # evt.wait()
if self.profile:
self.events += events
return out
@@ -291,7 +344,9 @@ class ByteOffset(OpenclProcessing):
}
}
"""
- arguments = "__global const int *data, __global char *compressed, __global int *size"
+ arguments = (
+ "__global const int *data, __global char *compressed, __global int *size"
+ )
input_expr = "compressed_size((i == 0) ? data[0] : (data[i] - data[i - 1]))"
scan_expr = "a+b"
neutral = "0"
@@ -303,23 +358,27 @@ class ByteOffset(OpenclProcessing):
"""
if self.block_size >= 64:
- knl = GenericScanKernel(self.ctx,
- dtype=numpy.int32,
- preamble=preamble,
- arguments=arguments,
- input_expr=input_expr,
- scan_expr=scan_expr,
- neutral=neutral,
- output_statement=output_statement)
+ knl = GenericScanKernel(
+ self.ctx,
+ dtype=numpy.int32,
+ preamble=preamble,
+ arguments=arguments,
+ input_expr=input_expr,
+ scan_expr=scan_expr,
+ neutral=neutral,
+ output_statement=output_statement,
+ )
else: # MacOS on CPU
- knl = GenericDebugScanKernel(self.ctx,
- dtype=numpy.int32,
- preamble=preamble,
- arguments=arguments,
- input_expr=input_expr,
- scan_expr=scan_expr,
- neutral=neutral,
- output_statement=output_statement)
+ knl = GenericDebugScanKernel(
+ self.ctx,
+ dtype=numpy.int32,
+ preamble=preamble,
+ arguments=arguments,
+ input_expr=input_expr,
+ scan_expr=scan_expr,
+ neutral=neutral,
+ output_statement=output_statement,
+ )
return knl
def encode(self, data, out=None):
@@ -348,28 +407,39 @@ class ByteOffset(OpenclProcessing):
data = numpy.ascontiguousarray(data, dtype=numpy.int32).ravel()
# Make sure data array exists and is large enough
- if ("data_input" not in self.cl_mem or
- self.cl_mem["data_input"].size < data.size):
+ if (
+ "data_input" not in self.cl_mem
+ or self.cl_mem["data_input"].size < data.size
+ ):
logger.info("increase data input buffer size to %s", data.size)
- self.cl_mem.update({
- "data_input": pyopencl.array.empty(self.queue,
- data.size,
- dtype=numpy.int32)})
+ self.cl_mem.update(
+ {
+ "data_input": pyopencl.array.empty(
+ self.queue, data.size, dtype=numpy.int32
+ )
+ }
+ )
d_data = self.cl_mem["data_input"]
evt = pyopencl.enqueue_copy(
- self.queue, d_data.data, data, is_blocking=False)
+ self.queue, d_data.data, data, is_blocking=False
+ )
events.append(EventDescription("copy data H -> D", evt))
# Make sure compressed array exists and is large enough
compressed_size = d_data.size * 7
- if ("compressed" not in self.cl_mem or
- self.cl_mem["compressed"].size < compressed_size):
+ if (
+ "compressed" not in self.cl_mem
+ or self.cl_mem["compressed"].size < compressed_size
+ ):
logger.info("increase compressed buffer size to %s", compressed_size)
- self.cl_mem.update({
- "compressed": pyopencl.array.empty(self.queue,
- compressed_size,
- dtype=numpy.int8)})
+ self.cl_mem.update(
+ {
+ "compressed": pyopencl.array.empty(
+ self.queue, compressed_size, dtype=numpy.int8
+ )
+ }
+ )
d_compressed = self.cl_mem["compressed"]
d_size = self.cl_mem["counter"] # Shared with decompression
@@ -384,13 +454,15 @@ class ByteOffset(OpenclProcessing):
shape=(byte_count,),
dtype=numpy.int8,
allocator=functools.partial(
- d_compressed.base_data.get_sub_region,
- d_compressed.offset))
+ d_compressed.base_data.get_sub_region, d_compressed.offset
+ ),
+ )
elif out.size < byte_count:
raise ValueError(
"Provided output buffer is not large enough: "
- "requires %d bytes, got %d" % (byte_count, out.size))
+ "requires %d bytes, got %d" % (byte_count, out.size)
+ )
else: # out.size >= byte_count
# Create an array with a sub-region of out and this class queue
@@ -398,13 +470,15 @@ class ByteOffset(OpenclProcessing):
self.queue,
shape=(byte_count,),
dtype=numpy.int8,
- allocator=functools.partial(out.base_data.get_sub_region,
- out.offset))
+ allocator=functools.partial(
+ out.base_data.get_sub_region, out.offset
+ ),
+ )
- evt = pyopencl.enqueue_copy(self.queue, out.data, d_compressed.data,
- byte_count=byte_count)
- events.append(
- EventDescription("copy D -> D: internal -> out", evt))
+ evt = pyopencl.enqueue_copy(
+ self.queue, out.data, d_compressed.data, byte_count=byte_count
+ )
+ events.append(EventDescription("copy D -> D: internal -> out", evt))
if self.profile:
self.events += events
diff --git a/src/silx/opencl/codec/test/test_bitshuffle_lz4.py b/src/silx/opencl/codec/test/test_bitshuffle_lz4.py
new file mode 100644
index 0000000..6c5891e
--- /dev/null
+++ b/src/silx/opencl/codec/test/test_bitshuffle_lz4.py
@@ -0,0 +1,126 @@
+#!/usr/bin/env python
+#
+# Project: Bitshuffle-LZ4 decompression in OpenCL
+# https://github.com/silx-kit/silx
+#
+# Copyright (C) 2022-2023 European Synchrotron Radiation Facility,
+# Grenoble, France
+# Permission is hereby granted, free of charge, to any person
+# obtaining a copy of this software and associated documentation
+# files (the "Software"), to deal in the Software without
+# restriction, including without limitation the rights to use,
+# copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the
+# Software is furnished to do so, subject to the following
+# conditions:
+#
+# The above copyright notice and this permission notice shall be
+# included in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
+# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
+# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
+# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
+# OTHER DEALINGS IN THE SOFTWARE.
+
+"""
+Test suite for byte-offset decompression
+"""
+
+__authors__ = ["Jérôme Kieffer"]
+__contact__ = "jerome.kieffer@esrf.eu"
+__license__ = "MIT"
+__copyright__ = "2022 European Synchrotron Radiation Facility, Grenoble, France"
+__date__ = "07/11/2022"
+
+import struct
+import numpy
+import pytest
+
+try:
+ import bitshuffle
+except ImportError:
+ bitshuffle = None
+from silx.opencl.common import ocl, pyopencl
+from silx.opencl.codec.bitshuffle_lz4 import BitshuffleLz4
+
+
+TESTCASES = ( # dtype, shape
+ ("uint64", (103, 503)),
+ ("int64", (101, 509)),
+ ("uint32", (229, 659)),
+ ("int32", (233, 653)),
+ ("uint16", (743, 647)),
+ ("int16", (751, 643)),
+ ("uint8", (157, 1373)),
+ ("int8", (163, 1367)),
+)
+
+
+@pytest.mark.skipif(
+ not ocl or not pyopencl or bitshuffle is None,
+ reason="PyOpenCl or bitshuffle is missing",
+)
+class TestBitshuffleLz4:
+ """Test pyopencl bishuffle+LZ4 decompression"""
+
+ @staticmethod
+ def _create_test_data(shape, lam=100, dtype="uint32"):
+ """Create test (image, compressed stream) pair.
+
+ :param shape: Shape of test image
+ :param lam: Expectation of interval argument for numpy.random.poisson
+ :return: (reference image array, compressed stream)
+ """
+ ref = numpy.random.poisson(lam, size=shape).astype(dtype)
+ raw = (
+ struct.pack(">Q", ref.nbytes)
+ + b"\x00" * 4
+ + bitshuffle.compress_lz4(ref).tobytes()
+ )
+ return ref, raw
+
+ @pytest.mark.parametrize("dtype,shape", TESTCASES)
+ def test_decompress(self, dtype, shape):
+ """
+ Tests the byte offset decompression on GPU with various configuration
+ """
+ ref, raw = self._create_test_data(shape=shape, dtype=dtype)
+ bs = BitshuffleLz4(len(raw), numpy.prod(shape), dtype=dtype)
+ res = bs.decompress(raw).get()
+ assert numpy.array_equal(res, ref.ravel()), "Checks decompression works"
+
+ @pytest.mark.parametrize("dtype,shape", TESTCASES)
+ def test_decompress_from_buffer(self, dtype, shape):
+ """Test reading compressed data from pyopencl Buffer"""
+ ref, raw = self._create_test_data(shape=shape, dtype=dtype)
+
+ bs = BitshuffleLz4(0, numpy.prod(shape), dtype=dtype)
+
+ buffer = pyopencl.Buffer(
+ bs.ctx,
+ flags=pyopencl.mem_flags.COPY_HOST_PTR | pyopencl.mem_flags.READ_ONLY,
+ hostbuf=raw,
+ )
+
+ res = bs.decompress(buffer).get()
+ assert numpy.array_equal(res, ref.ravel()), "Checks decompression works"
+
+ @pytest.mark.parametrize("dtype,shape", TESTCASES)
+ def test_decompress_from_array(self, dtype, shape):
+ """Test reading compressed data from pyopencl Array"""
+ ref, raw = self._create_test_data(shape=shape, dtype=dtype)
+
+ bs = BitshuffleLz4(0, numpy.prod(shape), dtype=dtype)
+
+ array = pyopencl.array.to_device(
+ bs.queue,
+ numpy.frombuffer(raw, dtype=numpy.uint8),
+ array_queue=bs.queue,
+ )
+
+ res = bs.decompress(array).get()
+ assert numpy.array_equal(res, ref.ravel()), "Checks decompression works"
diff --git a/src/silx/opencl/codec/test/test_byte_offset.py b/src/silx/opencl/codec/test/test_byte_offset.py
index 9ed53bc..0e58076 100644
--- a/src/silx/opencl/codec/test/test_byte_offset.py
+++ b/src/silx/opencl/codec/test/test_byte_offset.py
@@ -44,13 +44,12 @@ from silx.opencl.common import ocl, pyopencl
from silx.opencl.codec import byte_offset
import fabio
import unittest
+
logger = logging.getLogger(__name__)
-@unittest.skipUnless(ocl and pyopencl,
- "PyOpenCl is missing")
+@unittest.skipUnless(ocl and pyopencl, "PyOpenCl is missing")
class TestByteOffset(unittest.TestCase):
-
@staticmethod
def _create_test_data(shape, nexcept, lam=200):
"""Create test (image, compressed stream) pair.
@@ -84,7 +83,9 @@ class TestByteOffset(unittest.TestCase):
except (RuntimeError, pyopencl.RuntimeError) as err:
logger.warning(err)
if sys.platform == "darwin":
- raise unittest.SkipTest("Byte-offset decompression is known to be buggy on MacOS-CPU")
+ raise unittest.SkipTest(
+ "Byte-offset decompression is known to be buggy on MacOS-CPU"
+ )
else:
raise err
print(bo.block_size)
@@ -97,9 +98,11 @@ class TestByteOffset(unittest.TestCase):
delta_cy = abs(ref.ravel() - res_cy).max()
delta_cl = abs(ref.ravel() - res_cl.get()).max()
- logger.debug("Global execution time: fabio %.3fms, OpenCL: %.3fms.",
- 1000.0 * (t1 - t0),
- 1000.0 * (t2 - t1))
+ logger.debug(
+ "Global execution time: fabio %.3fms, OpenCL: %.3fms.",
+ 1000.0 * (t1 - t0),
+ 1000.0 * (t2 - t1),
+ )
bo.log_profile()
# print(ref)
# print(res_cl.get())
@@ -108,8 +111,8 @@ class TestByteOffset(unittest.TestCase):
def test_many_decompress(self, ntest=10):
"""
- tests the byte offset decompression on GPU, many images to ensure there
- is not leaking in memory
+ tests the byte offset decompression on GPU, many images to ensure there
+ is not leaking in memory
"""
shape = (991, 997)
size = numpy.prod(shape)
@@ -120,7 +123,9 @@ class TestByteOffset(unittest.TestCase):
except (RuntimeError, pyopencl.RuntimeError) as err:
logger.warning(err)
if sys.platform == "darwin":
- raise unittest.SkipTest("Byte-offset decompression is known to be buggy on MacOS-CPU")
+ raise unittest.SkipTest(
+ "Byte-offset decompression is known to be buggy on MacOS-CPU"
+ )
else:
raise err
t0 = time.time()
@@ -132,9 +137,11 @@ class TestByteOffset(unittest.TestCase):
delta_cl = abs(ref.ravel() - res_cl.get()).max()
self.assertEqual(delta_cy, 0, "Checks fabio works")
self.assertEqual(delta_cl, 0, "Checks opencl works")
- logger.debug("Global execution time: fabio %.3fms, OpenCL: %.3fms.",
- 1000.0 * (t1 - t0),
- 1000.0 * (t2 - t1))
+ logger.debug(
+ "Global execution time: fabio %.3fms, OpenCL: %.3fms.",
+ 1000.0 * (t1 - t0),
+ 1000.0 * (t2 - t1),
+ )
for i in range(ntest):
ref, raw = self._create_test_data(shape=shape, nexcept=2729, lam=200)
@@ -149,9 +156,11 @@ class TestByteOffset(unittest.TestCase):
self.assertEqual(delta_cy, 0, "Checks fabio works #%i" % i)
self.assertEqual(delta_cl, 0, "Checks opencl works #%i" % i)
- logger.debug("Global execution time: fabio %.3fms, OpenCL: %.3fms.",
- 1000.0 * (t1 - t0),
- 1000.0 * (t2 - t1))
+ logger.debug(
+ "Global execution time: fabio %.3fms, OpenCL: %.3fms.",
+ 1000.0 * (t1 - t0),
+ 1000.0 * (t2 - t1),
+ )
bo.log_profile(stats=True)
def test_encode(self):
@@ -171,8 +180,7 @@ class TestByteOffset(unittest.TestCase):
compressed_stream = compressed_array.get().tobytes()
self.assertEqual(raw, compressed_stream)
- logger.debug("Global execution time: OpenCL: %.3fms.",
- 1000.0 * (t1 - t0))
+ logger.debug("Global execution time: OpenCL: %.3fms.", 1000.0 * (t1 - t0))
bo.log_profile()
def test_encode_to_array(self):
@@ -222,14 +230,15 @@ class TestByteOffset(unittest.TestCase):
self.assertEqual(raw, compressed_stream)
- logger.debug("Global execution time: fabio %.3fms, OpenCL: %.3fms.",
- 1000.0 * (t1 - t0),
- 1000.0 * (t2 - t1))
+ logger.debug(
+ "Global execution time: fabio %.3fms, OpenCL: %.3fms.",
+ 1000.0 * (t1 - t0),
+ 1000.0 * (t2 - t1),
+ )
bo.log_profile()
def test_encode_to_bytes_from_array(self):
- """Test byte offset compression to bytes from a pyopencl array.
- """
+ """Test byte offset compression to bytes from a pyopencl array."""
ref, raw = self._create_test_data(shape=(2713, 2719), nexcept=2729)
try:
@@ -238,8 +247,7 @@ class TestByteOffset(unittest.TestCase):
logger.warning(err)
raise err
- d_ref = pyopencl.array.to_device(
- bo.queue, ref.astype(numpy.int32).ravel())
+ d_ref = pyopencl.array.to_device(bo.queue, ref.astype(numpy.int32).ravel())
t0 = time.time()
res_fabio = fabio.compression.compByteOffset(ref)
@@ -249,9 +257,11 @@ class TestByteOffset(unittest.TestCase):
self.assertEqual(raw, compressed_stream)
- logger.debug("Global execution time: fabio %.3fms, OpenCL: %.3fms.",
- 1000.0 * (t1 - t0),
- 1000.0 * (t2 - t1))
+ logger.debug(
+ "Global execution time: fabio %.3fms, OpenCL: %.3fms.",
+ 1000.0 * (t1 - t0),
+ 1000.0 * (t2 - t1),
+ )
bo.log_profile()
def test_many_encode(self, ntest=10):
@@ -275,9 +285,11 @@ class TestByteOffset(unittest.TestCase):
bo_durations.append(1000.0 * (t2 - t1))
self.assertEqual(raw, compressed_stream)
- logger.debug("Global execution time: fabio %.3fms, OpenCL: %.3fms.",
- 1000.0 * (t1 - t0),
- 1000.0 * (t2 - t1))
+ logger.debug(
+ "Global execution time: fabio %.3fms, OpenCL: %.3fms.",
+ 1000.0 * (t1 - t0),
+ 1000.0 * (t2 - t1),
+ )
for i in range(ntest):
ref, raw = self._create_test_data(shape=shape, nexcept=2729, lam=200)
@@ -290,11 +302,15 @@ class TestByteOffset(unittest.TestCase):
bo_durations.append(1000.0 * (t2 - t1))
self.assertEqual(raw, compressed_stream)
- logger.debug("Global execution time: fabio %.3fms, OpenCL: %.3fms.",
- 1000.0 * (t1 - t0),
- 1000.0 * (t2 - t1))
-
- logger.debug("OpenCL execution time: Mean: %fms, Min: %fms, Max: %fms",
- numpy.mean(bo_durations),
- numpy.min(bo_durations),
- numpy.max(bo_durations))
+ logger.debug(
+ "Global execution time: fabio %.3fms, OpenCL: %.3fms.",
+ 1000.0 * (t1 - t0),
+ 1000.0 * (t2 - t1),
+ )
+
+ logger.debug(
+ "OpenCL execution time: Mean: %fms, Min: %fms, Max: %fms",
+ numpy.mean(bo_durations),
+ numpy.min(bo_durations),
+ numpy.max(bo_durations),
+ )
diff --git a/src/silx/opencl/common.py b/src/silx/opencl/common.py
index cf51406..30c9ef7 100644
--- a/src/silx/opencl/common.py
+++ b/src/silx/opencl/common.py
@@ -3,7 +3,7 @@
# Project: S I L X project
# https://github.com/silx-kit/silx
#
-# Copyright (C) 2012-2021 European Synchrotron Radiation Facility, Grenoble, France
+# Copyright (C) 2012-2023 European Synchrotron Radiation Facility, Grenoble, France
#
# Principal author: Jérôme Kieffer (Jerome.Kieffer@ESRF.eu)
#
@@ -33,10 +33,17 @@ __author__ = "Jerome Kieffer"
__contact__ = "Jerome.Kieffer@ESRF.eu"
__license__ = "MIT"
__copyright__ = "2012-2017 European Synchrotron Radiation Facility, Grenoble, France"
-__date__ = "29/09/2021"
+__date__ = "09/09/2023"
__status__ = "stable"
-__all__ = ["ocl", "pyopencl", "mf", "release_cl_buffers", "allocate_cl_buffers",
- "measure_workgroup_size", "kernel_workgroup_size"]
+__all__ = [
+ "ocl",
+ "pyopencl",
+ "mf",
+ "release_cl_buffers",
+ "allocate_cl_buffers",
+ "measure_workgroup_size",
+ "kernel_workgroup_size",
+]
import os
import logging
@@ -48,58 +55,69 @@ from .utils import get_opencl_code
logger = logging.getLogger(__name__)
if os.environ.get("SILX_OPENCL") in ["0", "False"]:
- logger.info("Use of OpenCL has been disabled from environment variable: SILX_OPENCL=0")
+ logger.info(
+ "Use of OpenCL has been disabled from environment variable: SILX_OPENCL=0"
+ )
pyopencl = None
else:
try:
import pyopencl
except ImportError:
- logger.warning("Unable to import pyOpenCl. Please install it from: https://pypi.org/project/pyopencl")
+ logger.warning(
+ "Unable to import pyOpenCl. Please install it from: https://pypi.org/project/pyopencl"
+ )
pyopencl = None
else:
try:
pyopencl.get_platforms()
except pyopencl.LogicError:
- logger.warning("The module pyOpenCL has been imported but can't be used here")
+ logger.warning(
+ "The module pyOpenCL has been imported but can't be used here"
+ )
pyopencl = None
- else:
- import pyopencl.array as array
- mf = pyopencl.mem_flags
-if pyopencl is None:
+if pyopencl is not None:
+ import pyopencl.array as array
+ mf = pyopencl.mem_flags
+ from .atomic import check_atomic32, check_atomic64
+else:
# Define default mem flags
class mf(object):
WRITE_ONLY = 1
READ_ONLY = 1
READ_WRITE = 1
-FLOP_PER_CORE = {"GPU": 64, # GPU, Fermi at least perform 64 flops per cycle/multicore, G80 were at 24 or 48 ...
- "CPU": 4, # CPU, at least intel's have 4 operation per cycle
- "ACC": 8} # ACC: the Xeon-phi (MIC) appears to be able to process 8 Flops per hyperthreaded-core
+
+FLOP_PER_CORE = {
+ "GPU": 64, # GPU, Fermi at least perform 64 flops per cycle/multicore, G80 were at 24 or 48 ...
+ "CPU": 4, # CPU, at least intel's have 4 operation per cycle
+ "ACC": 8,
+} # ACC: the Xeon-phi (MIC) appears to be able to process 8 Flops per hyperthreaded-core
# Sources : https://en.wikipedia.org/wiki/CUDA
-NVIDIA_FLOP_PER_CORE = {(1, 0): 24, # Guessed !
- (1, 1): 24, # Measured on G98 [Quadro NVS 295]
- (1, 2): 24, # Guessed !
- (1, 3): 24, # measured on a GT285 (GT200)
- (2, 0): 64, # Measured on a 580 (GF110)
- (2, 1): 96, # Measured on Quadro2000 GF106GL
- (3, 0): 384, # Guessed!
- (3, 5): 384, # Measured on K20
- (3, 7): 384, # K80: Guessed!
- (5, 0): 256, # Maxwell 4 warps/SM 2 flops/ CU
- (5, 2): 256, # Titan-X
- (5, 3): 256, # TX1
- (6, 0): 128, # GP100
- (6, 1): 128, # GP104
- (6, 2): 128, # ?
- (7, 0): 128, # Volta # measured on Telsa V100
- (7, 2): 128, # Volta ?
- (7, 5): 128, # Turing # measured on RTX 6000
- (8, 0): 128, # Ampere # measured on Tesla A100
- (8, 6): 256, # Ampere # measured on RTX A5000
- }
+NVIDIA_FLOP_PER_CORE = {
+ (1, 0): 24, # Guessed !
+ (1, 1): 24, # Measured on G98 [Quadro NVS 295]
+ (1, 2): 24, # Guessed !
+ (1, 3): 24, # measured on a GT285 (GT200)
+ (2, 0): 64, # Measured on a 580 (GF110)
+ (2, 1): 96, # Measured on Quadro2000 GF106GL
+ (3, 0): 384, # Guessed!
+ (3, 5): 384, # Measured on K20
+ (3, 7): 384, # K80: Guessed!
+ (5, 0): 256, # Maxwell 4 warps/SM 2 flops/ CU
+ (5, 2): 256, # Titan-X
+ (5, 3): 256, # TX1
+ (6, 0): 128, # GP100
+ (6, 1): 128, # GP104
+ (6, 2): 128, # ?
+ (7, 0): 128, # Volta # measured on Telsa V100
+ (7, 2): 128, # Volta ?
+ (7, 5): 128, # Turing # measured on RTX 6000
+ (8, 0): 128, # Ampere # measured on Tesla A100
+ (8, 6): 256, # Ampere # measured on RTX A5000
+}
AMD_FLOP_PER_CORE = 160 # Measured on a M7820 10 core, 700MHz 1120GFlops
@@ -109,9 +127,24 @@ class Device(object):
Simple class that contains the structure of an OpenCL device
"""
- def __init__(self, name="None", dtype=None, version=None, driver_version=None,
- extensions="", memory=None, available=None,
- cores=None, frequency=None, flop_core=None, idx=0, workgroup=1):
+ def __init__(
+ self,
+ name="None",
+ dtype=None,
+ version=None,
+ driver_version=None,
+ extensions="",
+ memory=None,
+ available=None,
+ cores=None,
+ frequency=None,
+ flop_core=None,
+ idx=0,
+ workgroup=1,
+ atomic32=None,
+ atomic64=None,
+ platform=None,
+ ):
"""
Simple container with some important data for the OpenCL device description.
@@ -127,6 +160,7 @@ class Device(object):
:param flop_core: Flopating Point operation per core per cycle
:param idx: index of the device within the platform
:param workgroup: max workgroup size
+ :param platform: the platform to which this device is attached
"""
self.name = name.strip()
self.type = dtype
@@ -139,12 +173,15 @@ class Device(object):
self.frequency = frequency
self.id = idx
self.max_work_group_size = workgroup
+ self.atomic32 = atomic32
+ self.atomic64 = atomic64
if not flop_core:
flop_core = FLOP_PER_CORE.get(dtype, 1)
if cores and frequency:
self.flops = cores * frequency * flop_core
else:
self.flops = flop_core
+ self.platform = platform
def __repr__(self):
return "%s" % self.name
@@ -155,19 +192,20 @@ class Device(object):
:return: string
"""
- lst = ["Name\t\t:\t%s" % self.name,
- "Type\t\t:\t%s" % self.type,
- "Memory\t\t:\t%.3f MB" % (self.memory / 2.0 ** 20),
- "Cores\t\t:\t%s CU" % self.cores,
- "Frequency\t:\t%s MHz" % self.frequency,
- "Speed\t\t:\t%.3f GFLOPS" % (self.flops / 1000.),
- "Version\t\t:\t%s" % self.version,
- "Available\t:\t%s" % self.available]
+ lst = [
+ "Name\t\t:\t%s" % self.name,
+ "Type\t\t:\t%s" % self.type,
+ "Memory\t\t:\t%.3f MB" % (self.memory / 2.0**20),
+ "Cores\t\t:\t%s CU" % self.cores,
+ "Frequency\t:\t%s MHz" % self.frequency,
+ "Speed\t\t:\t%.3f GFLOPS" % (self.flops / 1000.0),
+ "Version\t\t:\t%s" % self.version,
+ "Available\t:\t%s" % self.available,
+ ]
return os.linesep.join(lst)
def set_unavailable(self):
- """Use this method to flag a faulty device
- """
+ """Use this method to flag a faulty device"""
self.available = False
@@ -176,7 +214,9 @@ class Platform(object):
Simple class that contains the structure of an OpenCL platform
"""
- def __init__(self, name="None", vendor="None", version=None, extensions=None, idx=0):
+ def __init__(
+ self, name="None", vendor="None", version=None, extensions=None, idx=0
+ ):
"""
Class containing all descriptions of a platform and all devices description within that platform.
@@ -202,6 +242,7 @@ class Platform(object):
:param device: Device instance
"""
+ device.platform = self
self.devices.append(device)
def get_device(self, key):
@@ -242,19 +283,25 @@ def _measure_workgroup_size(device_or_context, fast=False):
platformid = pyopencl.get_platforms().index(platform)
deviceid = platform.get_devices().index(device_or_context)
ocl.platforms[platformid].devices[deviceid].set_unavailable()
- raise RuntimeError("Unable to create context on %s/%s: %s" % (platform, device_or_context, error))
+ raise RuntimeError(
+ "Unable to create context on %s/%s: %s"
+ % (platform, device_or_context, error)
+ )
else:
device = device_or_context
elif isinstance(device_or_context, pyopencl.Context):
ctx = device_or_context
device = device_or_context.devices[0]
elif isinstance(device_or_context, (tuple, list)) and len(device_or_context) == 2:
- ctx = ocl.create_context(platformid=device_or_context[0],
- deviceid=device_or_context[1])
+ ctx = ocl.create_context(
+ platformid=device_or_context[0], deviceid=device_or_context[1]
+ )
device = ctx.devices[0]
else:
- raise RuntimeError("""given parameter device_or_context is not an
- instanciation of a device or a context""")
+ raise RuntimeError(
+ """given parameter device_or_context is not an
+ instanciation of a device or a context"""
+ )
shape = device.max_work_group_size
# get the context
@@ -269,7 +316,9 @@ def _measure_workgroup_size(device_or_context, fast=False):
program = pyopencl.Program(ctx, get_opencl_code("addition")).build()
if fast:
- max_valid_wg = program.addition.get_work_group_info(pyopencl.kernel_work_group_info.WORK_GROUP_SIZE, device)
+ max_valid_wg = program.addition.get_work_group_info(
+ pyopencl.kernel_work_group_info.WORK_GROUP_SIZE, device
+ )
else:
maxi = int(round(numpy.log2(shape)))
for i in range(maxi + 1):
@@ -277,11 +326,19 @@ def _measure_workgroup_size(device_or_context, fast=False):
wg = 1 << i
try:
evt = program.addition(
- queue, (shape,), (wg,),
- d_data.data, d_data_1.data, d_res.data, numpy.int32(shape))
+ queue,
+ (shape,),
+ (wg,),
+ d_data.data,
+ d_data_1.data,
+ d_res.data,
+ numpy.int32(shape),
+ )
evt.wait()
except Exception as error:
- logger.info("%s on device %s for WG=%s/%s", error, device.name, wg, shape)
+ logger.info(
+ "%s on device %s for WG=%s/%s", error, device.name, wg, shape
+ )
program = queue = d_res = d_data_1 = d_data = None
break
else:
@@ -291,7 +348,9 @@ def _measure_workgroup_size(device_or_context, fast=False):
if wg > max_valid_wg:
max_valid_wg = wg
else:
- logger.warning("ArithmeticError on %s for WG=%s/%s", wg, device.name, shape)
+ logger.warning(
+ "ArithmeticError on %s for WG=%s/%s", wg, device.name, shape
+ )
return max_valid_wg
@@ -314,15 +373,25 @@ class OpenCL(object):
if pyopencl:
platform = device = pypl = devtype = extensions = pydev = None
for idx, platform in enumerate(pyopencl.get_platforms()):
- pypl = Platform(platform.name, platform.vendor, platform.version, platform.extensions, idx)
+ pypl = Platform(
+ platform.name,
+ platform.vendor,
+ platform.version,
+ platform.extensions,
+ idx,
+ )
for idd, device in enumerate(platform.get_devices()):
####################################################
# Nvidia does not report int64 atomics (we are using) ...
# this is a hack around as any nvidia GPU with double-precision supports int64 atomics
####################################################
extensions = device.extensions
- if (pypl.vendor == "NVIDIA Corporation") and ('cl_khr_fp64' in extensions):
- extensions += ' cl_khr_int64_base_atomics cl_khr_int64_extended_atomics'
+ if (pypl.vendor == "NVIDIA Corporation") and (
+ "cl_khr_fp64" in extensions
+ ):
+ extensions += (
+ " cl_khr_int64_base_atomics cl_khr_int64_extended_atomics"
+ )
try:
devtype = pyopencl.device_type.to_string(device.type).upper()
except ValueError:
@@ -337,14 +406,23 @@ class OpenCL(object):
devtype = "CPU"
else:
devtype = devtype[:3]
- if _is_nvidia_gpu(device.vendor, devtype) and ("compute_capability_major_nv" in dir(device)):
+ if _is_nvidia_gpu(device.vendor, devtype) and (
+ "compute_capability_major_nv" in dir(device)
+ ):
try:
- comput_cap = device.compute_capability_major_nv, device.compute_capability_minor_nv
+ comput_cap = (
+ device.compute_capability_major_nv,
+ device.compute_capability_minor_nv,
+ )
except pyopencl.LogicError:
flop_core = FLOP_PER_CORE["GPU"]
else:
- flop_core = NVIDIA_FLOP_PER_CORE.get(comput_cap, FLOP_PER_CORE["GPU"])
- elif (pypl.vendor == "Advanced Micro Devices, Inc.") and (devtype == "GPU"):
+ flop_core = NVIDIA_FLOP_PER_CORE.get(
+ comput_cap, FLOP_PER_CORE["GPU"]
+ )
+ elif (pypl.vendor == "Advanced Micro Devices, Inc.") and (
+ devtype == "GPU"
+ ):
flop_core = AMD_FLOP_PER_CORE
elif devtype == "CPU":
flop_core = FLOP_PER_CORE.get(devtype, 1)
@@ -352,14 +430,29 @@ class OpenCL(object):
flop_core = 1
workgroup = device.max_work_group_size
if (devtype == "CPU") and (pypl.vendor == "Apple"):
- logger.info("For Apple's OpenCL on CPU: Measuring actual valid max_work_goup_size.")
+ logger.info(
+ "For Apple's OpenCL on CPU: Measuring actual valid max_work_goup_size."
+ )
workgroup = _measure_workgroup_size(device, fast=True)
if (devtype == "GPU") and os.environ.get("GPU") == "False":
# Environment variable to disable GPU devices
continue
- pydev = Device(device.name, devtype, device.version, device.driver_version, extensions,
- device.global_mem_size, bool(device.available), device.max_compute_units,
- device.max_clock_frequency, flop_core, idd, workgroup)
+ pydev = Device(
+ device.name,
+ devtype,
+ device.version,
+ device.driver_version,
+ extensions,
+ device.global_mem_size,
+ bool(device.available),
+ device.max_compute_units,
+ device.max_clock_frequency,
+ flop_core,
+ idd,
+ workgroup,
+ check_atomic32(device)[0],
+ check_atomic64(device)[0],
+ )
pypl.add_device(pydev)
nb_devices += 1
platforms.append(pypl)
@@ -368,9 +461,11 @@ class OpenCL(object):
def __repr__(self):
out = ["OpenCL devices:"]
for platformid, platform in enumerate(self.platforms):
- deviceids = ["(%s,%s) %s" % (platformid, deviceid, dev.name)
- for deviceid, dev in enumerate(platform.devices)]
- out.append("[%s] %s: " % (platformid, platform.name) + ", ".join(deviceids))
+ deviceids = [
+ f"({platformid},{deviceid}) {dev.name}"
+ for deviceid, dev in enumerate(platform.devices)
+ ]
+ out.append(f"[{platformid}] {platform.name}: " + ", ".join(deviceids))
return os.linesep.join(out)
def get_platform(self, key):
@@ -392,7 +487,9 @@ class OpenCL(object):
out = self.platforms[platid]
return out
- def select_device(self, dtype="ALL", memory=None, extensions=None, best=True, **kwargs):
+ def select_device(
+ self, dtype="ALL", memory=None, extensions=None, best=True, **kwargs
+ ):
"""
Select a device based on few parameters (at the end, keep the one with most memory)
@@ -436,8 +533,15 @@ class OpenCL(object):
# Nothing found
return None
- def create_context(self, devicetype="ALL", useFp64=False, platformid=None,
- deviceid=None, cached=True, memory=None, extensions=None):
+ def create_context(
+ self,
+ devicetype="ALL",
+ platformid=None,
+ deviceid=None,
+ cached=True,
+ memory=None,
+ extensions=None,
+ ):
"""
Choose a device and initiate a context.
@@ -447,7 +551,6 @@ class OpenCL(object):
E.g.: If Nvidia driver is installed, GPU will succeed but CPU will fail.
The AMD SDK kit is required for CPU via OpenCL.
:param devicetype: string in ["cpu","gpu", "all", "acc"]
- :param useFp64: boolean specifying if double precision will be used: deprecated use extensions=["cl_khr_fp64"]
:param platformid: integer
:param deviceid: integer
:param cached: True if we want to cache the context
@@ -457,37 +560,60 @@ class OpenCL(object):
"""
if extensions is None:
extensions = []
- if useFp64:
- logger.warning("Deprecation: please select your device using the extension name!, i.e. extensions=['cl_khr_fp64']")
- extensions.append('cl_khr_fp64')
+ ctx = None
if (platformid is not None) and (deviceid is not None):
platformid = int(platformid)
deviceid = int(deviceid)
elif "PYOPENCL_CTX" in os.environ:
- pyopencl_ctx = [int(i) if i.isdigit() else 0 for i in os.environ["PYOPENCL_CTX"].split(":")]
- pyopencl_ctx += [0] * (2 - len(pyopencl_ctx)) # pad with 0
- platformid, deviceid = pyopencl_ctx
+ ctx = pyopencl.create_some_context()
+ # try:
+ device = ctx.devices[0]
+ platforms = [
+ i for i, p in enumerate(ocl.platforms) if device.platform.name == p.name
+ ]
+ if platforms:
+ platformid = platforms[0]
+ devices = [
+ i
+ for i, d in enumerate(ocl.platforms[platformid].devices)
+ if device.name == d.name
+ ]
+ if devices:
+ deviceid = devices[0]
+ if cached:
+ self.context_cache[(platformid, deviceid)] = ctx
else:
ids = ocl.select_device(type=devicetype, extensions=extensions)
if ids:
platformid, deviceid = ids
- ctx = None
- if (platformid is not None) and (deviceid is not None):
+
+ if (ctx is None) and (platformid is not None) and (deviceid is not None):
if (platformid, deviceid) in self.context_cache:
ctx = self.context_cache[(platformid, deviceid)]
else:
try:
- ctx = pyopencl.Context(devices=[pyopencl.get_platforms()[platformid].get_devices()[deviceid]])
+ ctx = pyopencl.Context(
+ devices=[
+ pyopencl.get_platforms()[platformid].get_devices()[deviceid]
+ ]
+ )
except pyopencl._cl.LogicError as error:
self.platforms[platformid].devices[deviceid].set_unavailable()
- logger.warning("Unable to create context on %s/%s: %s", platformid, deviceid, error)
+ logger.warning(
+ "Unable to create context on %s/%s: %s",
+ platformid,
+ deviceid,
+ error,
+ )
ctx = None
else:
if cached:
self.context_cache[(platformid, deviceid)] = ctx
if ctx is None:
- logger.warning("Last chance to get an OpenCL device ... probably not the one requested")
+ logger.warning(
+ "Last chance to get an OpenCL device ... probably not the one requested"
+ )
ctx = pyopencl.create_some_context(interactive=False)
return ctx
@@ -557,8 +683,9 @@ def allocate_cl_buffers(buffers, device=None, context=None):
for _, _, dtype, size in buffers:
ualloc += numpy.dtype(dtype).itemsize * size
memory = device.memory
- logger.info("%.3fMB are needed on device which has %.3fMB",
- ualloc / 1.0e6, memory / 1.0e6)
+ logger.info(
+ "%.3fMB are needed on device which has %.3fMB", ualloc / 1.0e6, memory / 1.0e6
+ )
if ualloc >= memory:
memError = "Fatal error in allocate_buffers."
memError += "Not enough device memory for buffers"
@@ -568,8 +695,9 @@ def allocate_cl_buffers(buffers, device=None, context=None):
# do the allocation
try:
for name, flag, dtype, size in buffers:
- mem[name] = pyopencl.Buffer(context, flag,
- numpy.dtype(dtype).itemsize * size)
+ mem[name] = pyopencl.Buffer(
+ context, flag, numpy.dtype(dtype).itemsize * size
+ )
except pyopencl.MemoryError as error:
release_cl_buffers(mem)
raise MemoryError(error)
@@ -586,16 +714,15 @@ def allocate_texture(ctx, shape, hostbuf=None, support_1D=False):
do not support 1D images, so 1D images are handled as 2D with one row
:param support_1D: force the image to be 1D if the shape has only one dim
"""
- if len(shape) == 1 and not(support_1D):
+ if len(shape) == 1 and not (support_1D):
shape = (1,) + shape
return pyopencl.Image(
ctx,
pyopencl.mem_flags.READ_ONLY | pyopencl.mem_flags.USE_HOST_PTR,
pyopencl.ImageFormat(
- pyopencl.channel_order.INTENSITY,
- pyopencl.channel_type.FLOAT
+ pyopencl.channel_order.INTENSITY, pyopencl.channel_type.FLOAT
),
- hostbuf=numpy.zeros(shape[::-1], dtype=numpy.float32)
+ hostbuf=numpy.zeros(shape[::-1], dtype=numpy.float32),
)
@@ -617,7 +744,7 @@ def check_textures_availability(ctx):
# There is no way to detect this until a kernel is compiled
try:
cc = ctx.devices[0].compute_capability_major_nv
- textures_available &= (cc >= 3)
+ textures_available &= cc >= 3
except (pyopencl.LogicError, AttributeError): # probably not a Nvidia GPU
pass
#
@@ -657,7 +784,7 @@ def query_kernel_info(program, kernel, what="WORK_GROUP_SIZE"):
:param kernel: kernel or name of the kernel
:param what: what is the query about ?
:return: int or 3-int for the workgroup size.
-
+
Possible information available are:
* 'COMPILE_WORK_GROUP_SIZE': Returns the work-group size specified inside the kernel (__attribute__((reqd_work_gr oup_size(X, Y, Z))))
* 'GLOBAL_WORK_SIZE': maximum global size that can be used to execute a kernel #OCL2.1!
@@ -665,15 +792,17 @@ def query_kernel_info(program, kernel, what="WORK_GROUP_SIZE"):
* 'PREFERRED_WORK_GROUP_SIZE_MULTIPLE': preferred multiple of workgroup size for launch. This is a performance hint.
* 'PRIVATE_MEM_SIZE' Returns the minimum amount of private memory, in bytes, used by each workitem in the kernel
* 'WORK_GROUP_SIZE': maximum work-group size that can be used to execute a kernel on a specific device given by device
-
+
Further information on:
https://www.khronos.org/registry/OpenCL/sdk/1.1/docs/man/xhtml/clGetKernelWorkGroupInfo.html
-
+
"""
assert isinstance(program, pyopencl.Program)
if not isinstance(kernel, pyopencl.Kernel):
kernel_name = kernel
- assert kernel in (k.function_name for k in program.all_kernels()), "the kernel exists"
+ assert kernel in (
+ k.function_name for k in program.all_kernels()
+ ), "the kernel exists"
kernel = program.__getattr__(kernel_name)
device = program.devices[0]
diff --git a/src/silx/opencl/conftest.py b/src/silx/opencl/conftest.py
index 1fdc516..f6cf5de 100644
--- a/src/silx/opencl/conftest.py
+++ b/src/silx/opencl/conftest.py
@@ -1,5 +1,6 @@
import pytest
+
@pytest.mark.usefixtures("use_opencl")
def setup_module(module):
pass
diff --git a/src/silx/opencl/convolution.py b/src/silx/opencl/convolution.py
index 481e8fb..99ecd02 100644
--- a/src/silx/opencl/convolution.py
+++ b/src/silx/opencl/convolution.py
@@ -1,7 +1,7 @@
#!/usr/bin/env python
# /*##########################################################################
#
-# Copyright (c) 2019 European Synchrotron Radiation Facility
+# Copyright (c) 2019-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -29,20 +29,30 @@ __license__ = "MIT"
__date__ = "01/08/2019"
import numpy as np
-from copy import copy # python2
from .common import pyopencl as cl
import pyopencl.array as parray
from .processing import OpenclProcessing, EventDescription
from .utils import ConvolutionInfos
+
class Convolution(OpenclProcessing):
"""
A class for performing convolution on CPU/GPU with OpenCL.
"""
- def __init__(self, shape, kernel, axes=None, mode=None, ctx=None,
- devicetype="all", platformid=None, deviceid=None,
- profile=False, extra_options=None):
+ def __init__(
+ self,
+ shape,
+ kernel,
+ axes=None,
+ mode=None,
+ ctx=None,
+ devicetype="all",
+ platformid=None,
+ deviceid=None,
+ profile=False,
+ extra_options=None,
+ ):
"""Constructor of OpenCL Convolution.
:param shape: shape of the array.
@@ -70,9 +80,14 @@ class Convolution(OpenclProcessing):
"allocate_tmp_array": True,
"dont_use_textures": False,
"""
- OpenclProcessing.__init__(self, ctx=ctx, devicetype=devicetype,
- platformid=platformid, deviceid=deviceid,
- profile=profile)
+ OpenclProcessing.__init__(
+ self,
+ ctx=ctx,
+ devicetype=devicetype,
+ platformid=platformid,
+ deviceid=deviceid,
+ profile=profile,
+ )
self._configure_extra_options(extra_options)
self._determine_use_case(shape, kernel, axes)
@@ -88,7 +103,7 @@ class Convolution(OpenclProcessing):
}
extra_opts = extra_options or {}
self.extra_options.update(extra_opts)
- self.use_textures = not(self.extra_options["dont_use_textures"])
+ self.use_textures = not (self.extra_options["dont_use_textures"])
self.use_textures &= self.check_textures_availability()
def _get_dimensions(self, shape, kernel):
@@ -133,8 +148,7 @@ class Convolution(OpenclProcessing):
if axes in convol_infos.allowed_axes[uc_name]:
self.use_case_name = uc_name
self.use_case_desc = uc_params["name"]
- #~ self.use_case_kernels = uc_params["kernels"].copy()
- self.use_case_kernels = copy(uc_params["kernels"]) # TODO use the above line once we get rid of python2
+ self.use_case_kernels = uc_params["kernels"].copy()
if self.use_case_name is None:
raise ValueError(
"Cannot find a use case for data ndim = %d, kernel ndim = %d and axes=%s"
@@ -143,8 +157,7 @@ class Convolution(OpenclProcessing):
# TODO implement this use case
if self.use_case_name == "batched_separable_2D_1D_3D":
raise NotImplementedError(
- "The use case %s is not implemented"
- % self.use_case_name
+ "The use case %s is not implemented" % self.use_case_name
)
#
self.axes = axes
@@ -168,7 +181,7 @@ class Convolution(OpenclProcessing):
"allocate_tmp_array": "data_tmp",
}
# Nonseparable transforms do not need tmp array
- if not(self.separable):
+ if not (self.separable):
self.extra_options["allocate_tmp_array"] = False
# Allocate arrays
for option_name, array_name in option_array_names.items():
@@ -182,7 +195,7 @@ class Convolution(OpenclProcessing):
if isinstance(self.kernel, np.ndarray):
self.d_kernel = parray.to_device(self.queue, self.kernel)
else:
- if not(isinstance(self.kernel, parray.Array)):
+ if not (isinstance(self.kernel, parray.Array)):
raise ValueError("kernel must be either numpy array or pyopencl array")
self.d_kernel = self.kernel
self._old_input_ref = None
@@ -207,7 +220,7 @@ class Convolution(OpenclProcessing):
% (self.mode, str(mp.keys()))
)
# TODO
- if not(self.use_textures) and self.mode.lower() == "constant":
+ if not (self.use_textures) and self.mode.lower() == "constant":
raise NotImplementedError(
"mode='constant' is not implemented without textures yet"
)
@@ -228,28 +241,30 @@ class Convolution(OpenclProcessing):
compile_options = [str("-DUSED_CONV_MODE=%d" % self._c_conv_mode)]
if self.use_textures:
kernel_files = ["convolution_textures.cl"]
- compile_options.extend([
- str("-DIMAGE_DIMS=%d" % self.data_ndim),
- str("-DFILTER_DIMS=%d" % self.kernel_ndim),
- ])
+ compile_options.extend(
+ [
+ str("-DIMAGE_DIMS=%d" % self.data_ndim),
+ str("-DFILTER_DIMS=%d" % self.kernel_ndim),
+ ]
+ )
d_kernel_ref = self.d_kernel_tex
else:
kernel_files = ["convolution.cl"]
d_kernel_ref = self.d_kernel.data
- self.compile_kernels(
- kernel_files=kernel_files,
- compile_options=compile_options
- )
+ self.compile_kernels(kernel_files=kernel_files, compile_options=compile_options)
self.ndrange = self.shape[::-1]
self.wg = None
kernel_args = [
self.queue,
- self.ndrange, self.wg,
+ self.ndrange,
+ self.wg,
None,
None,
d_kernel_ref,
np.int32(self.kernel.shape[0]),
- self.Nx, self.Ny, self.Nz
+ self.Nx,
+ self.Ny,
+ self.Nz,
]
if self.kernel_ndim == 2:
kernel_args.insert(6, np.int32(self.kernel.shape[1]))
@@ -263,10 +278,7 @@ class Convolution(OpenclProcessing):
if self.separable:
if self.data_tmp is not None:
self.swap_pattern = {
- 2: [
- ("data_in", "data_tmp"),
- ("data_tmp", "data_out")
- ],
+ 2: [("data_in", "data_tmp"), ("data_tmp", "data_out")],
3: [
("data_in", "data_out"),
("data_out", "data_tmp"),
@@ -322,14 +334,14 @@ class Convolution(OpenclProcessing):
else:
raise ValueError("Please provide either arr= or shape=")
if ndim < dim_min or ndim > dim_max:
- raise ValueError("%s dimensions should be between %d and %d"
- % (name, dim_min, dim_max)
+ raise ValueError(
+ "%s dimensions should be between %d and %d" % (name, dim_min, dim_max)
)
return ndim
def _check_array(self, arr):
# TODO allow cl.Buffer
- if not(isinstance(arr, parray.Array) or isinstance(arr, np.ndarray)):
+ if not (isinstance(arr, parray.Array) or isinstance(arr, np.ndarray)):
raise TypeError("Expected either pyopencl.array.Array or numpy.ndarray")
# TODO composition with ImageProcessing/cast
if arr.dtype != np.float32:
@@ -351,14 +363,12 @@ class Convolution(OpenclProcessing):
self.data_in = array
data_in_ref = self.data_in
if output is not None:
- if not(isinstance(output, np.ndarray)):
+ if not (isinstance(output, np.ndarray)):
self._old_output_ref = self.data_out
self.data_out = output
# Update OpenCL kernel arguments with new array references
self.kernel_args = self._configure_kernel_args(
- self.kernel_args,
- data_in_ref,
- self.data_out
+ self.kernel_args, data_in_ref, self.data_out
)
def _separable_convolution(self):
@@ -372,9 +382,7 @@ class Convolution(OpenclProcessing):
# Batched: one kernel call in total
opencl_kernel = self.kernels.get_kernel(self.use_case_kernels[axis])
opencl_kernel_args = self._configure_kernel_args(
- self.kernel_args,
- input_ref,
- output_ref
+ self.kernel_args, input_ref, output_ref
)
ev = opencl_kernel(*opencl_kernel_args)
if self.profile:
@@ -395,9 +403,7 @@ class Convolution(OpenclProcessing):
self.data_out = self._old_output_ref
self._old_output_ref = None
self.kernel_args = self._configure_kernel_args(
- self.kernel_args,
- self.data_in,
- self.data_out
+ self.kernel_args, self.data_in, self.data_out
)
def _get_output(self, output):
@@ -433,7 +439,4 @@ class Convolution(OpenclProcessing):
res = self._get_output(output)
return res
-
__call__ = convolve
-
-
diff --git a/src/silx/opencl/image.py b/src/silx/opencl/image.py
index 6a4a854..ec30e66 100644
--- a/src/silx/opencl/image.py
+++ b/src/silx/opencl/image.py
@@ -2,7 +2,7 @@
# Project: silx
# https://github.com/silx-kit/silx
#
-# Copyright (C) 2012-2017 European Synchrotron Radiation Facility, Grenoble, France
+# Copyright (C) 2012-2023 European Synchrotron Radiation Facility, Grenoble, France
#
# Principal author: Jérôme Kieffer (Jerome.Kieffer@ESRF.eu)
#
@@ -37,7 +37,6 @@ __contact__ = "jerome.kieffer@esrf.fr"
import os
import logging
import numpy
-from collections import OrderedDict
from math import floor, ceil, sqrt, log
from .common import pyopencl, kernel_workgroup_size
@@ -49,20 +48,30 @@ logger = logging.getLogger(__name__)
class ImageProcessing(OpenclProcessing):
-
kernel_files = ["cast", "map", "max_min", "histogram"]
- converter = {numpy.dtype(numpy.uint8): "u8_to_float",
- numpy.dtype(numpy.int8): "s8_to_float",
- numpy.dtype(numpy.uint16): "u16_to_float",
- numpy.dtype(numpy.int16): "s16_to_float",
- numpy.dtype(numpy.uint32): "u32_to_float",
- numpy.dtype(numpy.int32): "s32_to_float",
- }
-
- def __init__(self, shape=None, ncolors=1, template=None,
- ctx=None, devicetype="all", platformid=None, deviceid=None,
- block_size=None, memory=None, profile=False):
+ converter = {
+ numpy.dtype(numpy.uint8): "u8_to_float",
+ numpy.dtype(numpy.int8): "s8_to_float",
+ numpy.dtype(numpy.uint16): "u16_to_float",
+ numpy.dtype(numpy.int16): "s16_to_float",
+ numpy.dtype(numpy.uint32): "u32_to_float",
+ numpy.dtype(numpy.int32): "s32_to_float",
+ }
+
+ def __init__(
+ self,
+ shape=None,
+ ncolors=1,
+ template=None,
+ ctx=None,
+ devicetype="all",
+ platformid=None,
+ deviceid=None,
+ block_size=None,
+ memory=None,
+ profile=False,
+ ):
"""Constructor of the ImageProcessing class
:param ctx: actual working context, left to None for automatic
@@ -76,9 +85,16 @@ class ImageProcessing(OpenclProcessing):
:param profile: switch on profiling to be able to profile at the kernel
level, store profiling elements (makes code slightly slower)
"""
- OpenclProcessing.__init__(self, ctx=ctx, devicetype=devicetype,
- platformid=platformid, deviceid=deviceid,
- block_size=block_size, memory=memory, profile=profile)
+ OpenclProcessing.__init__(
+ self,
+ ctx=ctx,
+ devicetype=devicetype,
+ platformid=platformid,
+ deviceid=deviceid,
+ block_size=block_size,
+ memory=memory,
+ profile=profile,
+ )
if template is not None:
shape = template.shape
if len(shape) > 2:
@@ -91,36 +107,51 @@ class ImageProcessing(OpenclProcessing):
self.ncolors = ncolors
self.shape = shape
assert shape is not None
- self.buffer_shape = self.shape if self.ncolors == 1 else self.shape + (self.ncolors,)
+ self.buffer_shape = (
+ self.shape if self.ncolors == 1 else self.shape + (self.ncolors,)
+ )
kernel_files = [os.path.join("image", i) for i in self.kernel_files]
- self.compile_kernels(kernel_files,
- compile_options="-DNB_COLOR=%i" % self.ncolors)
+ self.compile_kernels(
+ kernel_files, compile_options="-DNB_COLOR=%i" % self.ncolors
+ )
if self.ncolors == 1:
img_shape = self.shape
else:
img_shape = self.shape + (self.ncolors,)
- buffers = [BufferDescription("image0_d", img_shape, numpy.float32, None),
- BufferDescription("image1_d", img_shape, numpy.float32, None),
- BufferDescription("image2_d", img_shape, numpy.float32, None),
- BufferDescription("max_min_d", 2, numpy.float32, None),
- BufferDescription("cnt_d", 1, numpy.int32, None), ]
+ buffers = [
+ BufferDescription("image0_d", img_shape, numpy.float32, None),
+ BufferDescription("image1_d", img_shape, numpy.float32, None),
+ BufferDescription("image2_d", img_shape, numpy.float32, None),
+ BufferDescription("max_min_d", 2, numpy.float32, None),
+ BufferDescription("cnt_d", 1, numpy.int32, None),
+ ]
# Temporary buffer for max-min reduction
- self.wg_red = kernel_workgroup_size(self.program, self.kernels.max_min_reduction_stage1)
+ self.wg_red = kernel_workgroup_size(
+ self.program, self.kernels.max_min_reduction_stage1
+ )
if self.wg_red > 1:
- self.wg_red = min(self.wg_red,
- numpy.int32(1 << int(floor(log(sqrt(numpy.prod(self.shape)), 2)))))
- tmp = BufferDescription("tmp_max_min_d", 2 * self.wg_red, numpy.float32, None)
+ self.wg_red = min(
+ self.wg_red,
+ numpy.int32(1 << int(floor(log(sqrt(numpy.prod(self.shape)), 2)))),
+ )
+ tmp = BufferDescription(
+ "tmp_max_min_d", 2 * self.wg_red, numpy.float32, None
+ )
buffers.append(tmp)
self.allocate_buffers(buffers, use_array=True)
self.cl_mem["cnt_d"].fill(0)
def __repr__(self):
- return "ImageProcessing for shape=%s, %i colors initalized on %s" % \
- (self.shape, self.ncolors, self.ctx.devices[0].name)
-
- def _get_in_out_buffers(self, img=None, copy=True, out=None,
- out_dtype=None, out_size=None):
+ return "ImageProcessing for shape=%s, %i colors initalized on %s" % (
+ self.shape,
+ self.ncolors,
+ self.ctx.devices[0].name,
+ )
+
+ def _get_in_out_buffers(
+ self, img=None, copy=True, out=None, out_dtype=None, out_size=None
+ ):
"""Internal method used to select the proper buffers before processing.
:param img: expects a numpy array or a pyopencl.array of dim 2 or 3
@@ -129,7 +160,7 @@ class ImageProcessing(OpenclProcessing):
:param out_dtype: enforce the type of the output buffer (optional)
:param out_size: enforce the size of the output buffer (optional)
:return: input_buffer, output_buffer
-
+
Nota: this is not locked.
"""
events = []
@@ -148,7 +179,9 @@ class ImageProcessing(OpenclProcessing):
if out_dtype != numpy.float32 and out_size:
name = "%s_%s_d" % (numpy.dtype(out_dtype), out_size)
if name not in self.cl_mem:
- output_array = self.cl_mem[name] = pyopencl.array.empty(self.queue, (out_size,), out_dtype)
+ output_array = self.cl_mem[name] = pyopencl.array.empty(
+ self.queue, (out_size,), out_dtype
+ )
else:
output_array = self.cl_mem[name]
else:
@@ -158,7 +191,9 @@ class ImageProcessing(OpenclProcessing):
input_array = self.cl_mem["image1_d"]
if isinstance(img, pyopencl.array.Array):
if copy:
- evt = pyopencl.enqueue_copy(self.queue, self.cl_mem["image1_d"].data, img.data)
+ evt = pyopencl.enqueue_copy(
+ self.queue, self.cl_mem["image1_d"].data, img.data
+ )
input_array = self.cl_mem["image1_d"]
events.append(EventDescription("copy D->D", evt))
else:
@@ -169,11 +204,19 @@ class ImageProcessing(OpenclProcessing):
# assume this is numpy
if img.dtype.itemsize > 4:
logger.warning("Casting to float32 on CPU")
- evt = pyopencl.enqueue_copy(self.queue, self.cl_mem["image1_d"].data, numpy.ascontiguousarray(img, numpy.float32))
+ evt = pyopencl.enqueue_copy(
+ self.queue,
+ self.cl_mem["image1_d"].data,
+ numpy.ascontiguousarray(img, numpy.float32),
+ )
input_array = self.cl_mem["image1_d"]
events.append(EventDescription("cast+copy H->D", evt))
else:
- evt = pyopencl.enqueue_copy(self.queue, self.cl_mem["image1_d"].data, numpy.ascontiguousarray(img))
+ evt = pyopencl.enqueue_copy(
+ self.queue,
+ self.cl_mem["image1_d"].data,
+ numpy.ascontiguousarray(img),
+ )
input_array = self.cl_mem["image1_d"]
events.append(EventDescription("copy H->D", evt))
if self.profile:
@@ -181,8 +224,8 @@ class ImageProcessing(OpenclProcessing):
return input_array, output_array
def to_float(self, img, copy=True, out=None):
- """ Takes any array and convert it to a float array for ease of processing.
-
+ """Takes any array and convert it to a float array for ease of processing.
+
:param img: expects a numpy array or a pyopencl.array of dim 2 or 3
:param copy: set to False to directly re-use a pyopencl array
:param out: provide an output buffer to store the result
@@ -194,16 +237,23 @@ class ImageProcessing(OpenclProcessing):
input_array, output_array = self._get_in_out_buffers(img, copy, out)
if (img.dtype.itemsize > 4) or (img.dtype == numpy.float32):
# copy device -> device, already there as float32
- ev = pyopencl.enqueue_copy(self.queue, output_array.data, input_array.data)
+ ev = pyopencl.enqueue_copy(
+ self.queue, output_array.data, input_array.data
+ )
events.append(EventDescription("copy D->D", ev))
else:
# Cast to float:
name = self.converter[img.dtype]
kernel = self.kernels.get_kernel(name)
- ev = kernel(self.queue, (self.shape[1], self.shape[0]), None,
- input_array.data, output_array.data,
- numpy.int32(self.shape[1]), numpy.int32(self.shape[0])
- )
+ ev = kernel(
+ self.queue,
+ (self.shape[1], self.shape[0]),
+ None,
+ input_array.data,
+ output_array.data,
+ numpy.int32(self.shape[1]),
+ numpy.int32(self.shape[0]),
+ )
events.append(EventDescription("cast %s" % name, ev))
if self.profile:
@@ -218,14 +268,14 @@ class ImageProcessing(OpenclProcessing):
def normalize(self, img, mini=0.0, maxi=1.0, copy=True, out=None):
"""Scale the intensity of the image so that the minimum is 0 and the
maximum is 1.0 (or any value suggested).
-
+
:param img: numpy array or pyopencl array of dim 2 or 3 and of type float
:param mini: Expected minimum value
:param maxi: expected maxiumum value
:param copy: set to False to use directly the input buffer
:param out: provides an output buffer. prevents a copy D->H
-
- This uses a min/max reduction in two stages plus a map operation
+
+ This uses a min/max reduction in two stages plus a map operation
"""
assert img.shape == self.buffer_shape
events = []
@@ -235,34 +285,55 @@ class ImageProcessing(OpenclProcessing):
if self.wg_red == 1:
# Probably on MacOS CPU WG==1 --> serial code.
kernel = self.kernels.get_kernel("max_min_serial")
- evt = kernel(self.queue, (1,), (1,),
- input_array.data,
- size,
- self.cl_mem["max_min_d"].data)
+ evt = kernel(
+ self.queue,
+ (1,),
+ (1,),
+ input_array.data,
+ size,
+ self.cl_mem["max_min_d"].data,
+ )
ed = EventDescription("max_min_serial", evt)
events.append(ed)
else:
stage1 = self.kernels.max_min_reduction_stage1
stage2 = self.kernels.max_min_reduction_stage2
local_mem = pyopencl.LocalMemory(int(self.wg_red * 8))
- k1 = stage1(self.queue, (int(self.wg_red ** 2),), (int(self.wg_red),),
- input_array.data,
- self.cl_mem["tmp_max_min_d"].data,
- size,
- local_mem)
- k2 = stage2(self.queue, (int(self.wg_red),), (int(self.wg_red),),
- self.cl_mem["tmp_max_min_d"].data,
- self.cl_mem["max_min_d"].data,
- local_mem)
-
- events += [EventDescription("max_min_stage1", k1),
- EventDescription("max_min_stage2", k2)]
-
- evt = self.kernels.normalize_image(self.queue, (self.shape[1], self.shape[0]), None,
- input_array.data, output_array.data,
- numpy.int32(self.shape[1]), numpy.int32(self.shape[0]),
- self.cl_mem["max_min_d"].data,
- numpy.float32(mini), numpy.float32(maxi))
+ k1 = stage1(
+ self.queue,
+ (int(self.wg_red**2),),
+ (int(self.wg_red),),
+ input_array.data,
+ self.cl_mem["tmp_max_min_d"].data,
+ size,
+ local_mem,
+ )
+ k2 = stage2(
+ self.queue,
+ (int(self.wg_red),),
+ (int(self.wg_red),),
+ self.cl_mem["tmp_max_min_d"].data,
+ self.cl_mem["max_min_d"].data,
+ local_mem,
+ )
+
+ events += [
+ EventDescription("max_min_stage1", k1),
+ EventDescription("max_min_stage2", k2),
+ ]
+
+ evt = self.kernels.normalize_image(
+ self.queue,
+ (self.shape[1], self.shape[0]),
+ None,
+ input_array.data,
+ output_array.data,
+ numpy.int32(self.shape[1]),
+ numpy.int32(self.shape[0]),
+ self.cl_mem["max_min_d"].data,
+ numpy.float32(mini),
+ numpy.float32(maxi),
+ )
events.append(EventDescription("normalize", evt))
if self.profile:
self.events += events
@@ -274,32 +345,32 @@ class ImageProcessing(OpenclProcessing):
output_array.finish()
return output_array
- def histogram(self, img=None, nbins=255, range=None,
- log_scale=False, copy=True, out=None):
+ def histogram(
+ self, img=None, nbins=255, range=None, log_scale=False, copy=True, out=None
+ ):
"""Compute the histogram of a set of data.
-
+
:param img: input image. If None, use the one already on the device
:param nbins: number of bins
- :param range: the lower and upper range of the bins. If not provided,
- range is simply ``(a.min(), a.max())``. Values outside the
- range are ignored. The first element of the range must be
+ :param range: the lower and upper range of the bins. If not provided,
+ range is simply ``(a.min(), a.max())``. Values outside the
+ range are ignored. The first element of the range must be
less than or equal to the second.
- :param log_scale: perform the binning in lograrithmic scale.
+ :param log_scale: perform the binning in lograrithmic scale.
Open to extension
:param copy: unset to directly use the input buffer without copy
- :param out: use a provided array for offering the result
+ :param out: use a provided array for offering the result
:return: histogram (size=nbins), edges (size=nbins+1)
- API similar to numpy
+ API similar to numpy
"""
assert img.shape == self.buffer_shape
input_array = self.to_float(img, copy=copy, out=self.cl_mem["image0_d"])
events = []
with self.sem:
- input_array, output_array = self._get_in_out_buffers(input_array, copy=False,
- out=out,
- out_dtype=numpy.int32,
- out_size=nbins)
+ input_array, output_array = self._get_in_out_buffers(
+ input_array, copy=False, out=out, out_dtype=numpy.int32, out_size=nbins
+ )
if range is None:
# measure actually the bounds
@@ -308,27 +379,43 @@ class ImageProcessing(OpenclProcessing):
# Probably on MacOS CPU WG==1 --> serial code.
kernel = self.kernels.get_kernel("max_min_serial")
- evt = kernel(self.queue, (1,), (1,),
- input_array.data,
- size,
- self.cl_mem["max_min_d"].data)
+ evt = kernel(
+ self.queue,
+ (1,),
+ (1,),
+ input_array.data,
+ size,
+ self.cl_mem["max_min_d"].data,
+ )
events.append(EventDescription("max_min_serial", evt))
else:
stage1 = self.kernels.max_min_reduction_stage1
stage2 = self.kernels.max_min_reduction_stage2
- local_mem = pyopencl.LocalMemory(int(self.wg_red * 2 * numpy.dtype("float32").itemsize))
- k1 = stage1(self.queue, (int(self.wg_red ** 2),), (int(self.wg_red),),
- input_array.data,
- self.cl_mem["tmp_max_min_d"].data,
- size,
- local_mem)
- k2 = stage2(self.queue, (int(self.wg_red),), (int(self.wg_red),),
- self.cl_mem["tmp_max_min_d"].data,
- self.cl_mem["max_min_d"].data,
- local_mem)
-
- events += [EventDescription("max_min_stage1", k1),
- EventDescription("max_min_stage2", k2)]
+ local_mem = pyopencl.LocalMemory(
+ int(self.wg_red * 2 * numpy.dtype("float32").itemsize)
+ )
+ k1 = stage1(
+ self.queue,
+ (int(self.wg_red**2),),
+ (int(self.wg_red),),
+ input_array.data,
+ self.cl_mem["tmp_max_min_d"].data,
+ size,
+ local_mem,
+ )
+ k2 = stage2(
+ self.queue,
+ (int(self.wg_red),),
+ (int(self.wg_red),),
+ self.cl_mem["tmp_max_min_d"].data,
+ self.cl_mem["max_min_d"].data,
+ local_mem,
+ )
+
+ events += [
+ EventDescription("max_min_stage1", k1),
+ EventDescription("max_min_stage2", k2),
+ ]
maxi, mini = self.cl_mem["max_min_d"].get()
else:
mini = numpy.float32(min(range))
@@ -338,13 +425,17 @@ class ImageProcessing(OpenclProcessing):
tmp_size = nb_engines * nbins
name = "tmp_int32_%s_d" % (tmp_size)
if name not in self.cl_mem:
- tmp_array = self.cl_mem[name] = pyopencl.array.empty(self.queue, (tmp_size,), numpy.int32)
+ tmp_array = self.cl_mem[name] = pyopencl.array.empty(
+ self.queue, (tmp_size,), numpy.int32
+ )
else:
tmp_array = self.cl_mem[name]
edge_name = "tmp_float32_%s_d" % (nbins + 1)
if edge_name not in self.cl_mem:
- edges_array = self.cl_mem[edge_name] = pyopencl.array.empty(self.queue, (nbins + 1,), numpy.float32)
+ edges_array = self.cl_mem[edge_name] = pyopencl.array.empty(
+ self.queue, (nbins + 1,), numpy.float32
+ )
else:
edges_array = self.cl_mem[edge_name]
@@ -356,21 +447,27 @@ class ImageProcessing(OpenclProcessing):
else:
map_operation = numpy.int32(0)
kernel = self.kernels.get_kernel("histogram")
- wg = min(device.max_work_group_size,
- 1 << (int(ceil(log(nbins, 2)))),
- self.kernels.max_workgroup_size(kernel))
- evt = kernel(self.queue, (wg * nb_engines,), (wg,),
- input_array.data,
- numpy.int32(input_array.size),
- mini,
- maxi,
- map_operation,
- output_array.data,
- edges_array.data,
- numpy.int32(nbins),
- tmp_array.data,
- self.cl_mem["cnt_d"].data,
- shared)
+ wg = min(
+ device.max_work_group_size,
+ 1 << (int(ceil(log(nbins, 2)))),
+ self.kernels.max_workgroup_size(kernel),
+ )
+ evt = kernel(
+ self.queue,
+ (wg * nb_engines,),
+ (wg,),
+ input_array.data,
+ numpy.int32(input_array.size),
+ mini,
+ maxi,
+ map_operation,
+ output_array.data,
+ edges_array.data,
+ numpy.int32(nbins),
+ tmp_array.data,
+ self.cl_mem["cnt_d"].data,
+ shared,
+ )
events.append(EventDescription("histogram", evt))
if self.profile:
diff --git a/src/silx/opencl/linalg.py b/src/silx/opencl/linalg.py
index 77d826b..573ebce 100644
--- a/src/silx/opencl/linalg.py
+++ b/src/silx/opencl/linalg.py
@@ -34,14 +34,23 @@ from .common import pyopencl
from .processing import EventDescription, OpenclProcessing
import pyopencl.array as parray
+
cl = pyopencl
class LinAlg(OpenclProcessing):
-
kernel_files = ["linalg.cl"]
- def __init__(self, shape, do_checks=False, ctx=None, devicetype="all", platformid=None, deviceid=None, profile=False):
+ def __init__(
+ self,
+ shape,
+ do_checks=False,
+ ctx=None,
+ devicetype="all",
+ platformid=None,
+ deviceid=None,
+ profile=False,
+ ):
"""
Create a "Linear Algebra" plan for a given image shape.
@@ -56,32 +65,34 @@ class LinAlg(OpenclProcessing):
store profiling elements (makes code slightly slower)
"""
- OpenclProcessing.__init__(self, ctx=ctx, devicetype=devicetype,
- platformid=platformid, deviceid=deviceid,
- profile=profile)
+ OpenclProcessing.__init__(
+ self,
+ ctx=ctx,
+ devicetype=devicetype,
+ platformid=platformid,
+ deviceid=deviceid,
+ profile=profile,
+ )
self.d_gradient = parray.empty(self.queue, shape, np.complex64)
self.d_gradient.fill(np.complex64(0.0))
self.d_image = parray.empty(self.queue, shape, np.float32)
self.d_image.fill(np.float32(0.0))
- self.add_to_cl_mem({
- "d_gradient": self.d_gradient,
- "d_image": self.d_image
- })
+ self.add_to_cl_mem({"d_gradient": self.d_gradient, "d_image": self.d_image})
self.wg2D = None
self.shape = shape
- self.ndrange2D = (
- int(self.shape[1]),
- int(self.shape[0])
- )
+ self.ndrange2D = (int(self.shape[1]), int(self.shape[0]))
self.do_checks = bool(do_checks)
OpenclProcessing.compile_kernels(self, self.kernel_files)
@staticmethod
def check_array(array, dtype, shape, arg_name):
if array.shape != shape or array.dtype != dtype:
- raise ValueError("%s should be a %s array of type %s" %(arg_name, str(shape), str(dtype)))
+ raise ValueError(
+ "%s should be a %s array of type %s"
+ % (arg_name, str(shape), str(dtype))
+ )
def get_data_references(self, src, dst, default_src_ref, default_dst_ref):
"""
@@ -97,7 +108,9 @@ class LinAlg(OpenclProcessing):
elif isinstance(dst, cl.Buffer):
dst_ref = dst
else:
- raise ValueError("dst should be either pyopencl.array.Array or pyopencl.Buffer")
+ raise ValueError(
+ "dst should be either pyopencl.array.Array or pyopencl.Buffer"
+ )
else:
dst_ref = default_dst_ref
@@ -127,21 +140,15 @@ class LinAlg(OpenclProcessing):
self.check_array(image, np.float32, self.shape, "image")
if dst is not None:
self.check_array(dst, np.complex64, self.shape, "dst")
- img_ref, grad_ref = self.get_data_references(image, dst, self.d_image.data, self.d_gradient.data)
+ img_ref, grad_ref = self.get_data_references(
+ image, dst, self.d_image.data, self.d_gradient.data
+ )
# Prepare the kernel call
- kernel_args = [
- img_ref,
- grad_ref,
- n_x,
- n_y
- ]
+ kernel_args = [img_ref, grad_ref, n_x, n_y]
# Call the gradient kernel
evt = self.kernels.kern_gradient2D(
- self.queue,
- self.ndrange2D,
- self.wg2D,
- *kernel_args
+ self.queue, self.ndrange2D, self.wg2D, *kernel_args
)
self.events.append(EventDescription("gradient2D", evt))
# TODO: should the wait be done in any case ?
@@ -184,21 +191,15 @@ class LinAlg(OpenclProcessing):
self.check_array(gradient, np.complex64, self.shape, "gradient")
if dst is not None:
self.check_array(dst, np.float32, self.shape, "dst")
- grad_ref, img_ref = self.get_data_references(gradient, dst, self.d_gradient.data, self.d_image.data)
+ grad_ref, img_ref = self.get_data_references(
+ gradient, dst, self.d_gradient.data, self.d_image.data
+ )
# Prepare the kernel call
- kernel_args = [
- grad_ref,
- img_ref,
- n_x,
- n_y
- ]
+ kernel_args = [grad_ref, img_ref, n_x, n_y]
# Call the gradient kernel
evt = self.kernels.kern_divergence2D(
- self.queue,
- self.ndrange2D,
- self.wg2D,
- *kernel_args
+ self.queue, self.ndrange2D, self.wg2D, *kernel_args
)
self.events.append(EventDescription("divergence2D", evt))
# TODO: should the wait be done in any case ?
diff --git a/src/silx/opencl/medfilt.py b/src/silx/opencl/medfilt.py
index ae63eb2..a18c5a4 100644
--- a/src/silx/opencl/medfilt.py
+++ b/src/silx/opencl/medfilt.py
@@ -2,7 +2,7 @@
# Project: Azimuthal integration
# https://github.com/silx-kit/pyFAI
#
-# Copyright (C) 2012-2017 European Synchrotron Radiation Facility, Grenoble, France
+# Copyright (C) 2012-2023 European Synchrotron Radiation Facility, Grenoble, France
#
# Principal author: Jérôme Kieffer (Jerome.Kieffer@ESRF.eu)
#
@@ -39,7 +39,6 @@ __contact__ = "jerome.kieffer@esrf.fr"
import logging
import numpy
-from collections import OrderedDict
from .common import pyopencl, kernel_workgroup_size
from .processing import EventDescription, OpenclProcessing, BufferDescription
@@ -53,23 +52,33 @@ logger = logging.getLogger(__name__)
class MedianFilter2D(OpenclProcessing):
"""A class for doing median filtering using OpenCL"""
+
buffers = [
- BufferDescription("result", 1, numpy.float32, mf.WRITE_ONLY),
- BufferDescription("image_raw", 1, numpy.float32, mf.READ_ONLY),
- BufferDescription("image", 1, numpy.float32, mf.READ_WRITE),
- ]
+ BufferDescription("result", 1, numpy.float32, mf.WRITE_ONLY),
+ BufferDescription("image_raw", 1, numpy.float32, mf.READ_ONLY),
+ BufferDescription("image", 1, numpy.float32, mf.READ_WRITE),
+ ]
kernel_files = ["preprocess.cl", "bitonic.cl", "medfilt.cl"]
- mapping = {numpy.int8: "s8_to_float",
- numpy.uint8: "u8_to_float",
- numpy.int16: "s16_to_float",
- numpy.uint16: "u16_to_float",
- numpy.uint32: "u32_to_float",
- numpy.int32: "s32_to_float"}
-
- def __init__(self, shape, kernel_size=(3, 3),
- ctx=None, devicetype="all", platformid=None, deviceid=None,
- block_size=None, profile=False
- ):
+ mapping = {
+ numpy.int8: "s8_to_float",
+ numpy.uint8: "u8_to_float",
+ numpy.int16: "s16_to_float",
+ numpy.uint16: "u16_to_float",
+ numpy.uint32: "u32_to_float",
+ numpy.int32: "s32_to_float",
+ }
+
+ def __init__(
+ self,
+ shape,
+ kernel_size=(3, 3),
+ ctx=None,
+ devicetype="all",
+ platformid=None,
+ deviceid=None,
+ block_size=None,
+ profile=False,
+ ):
"""Constructor of the OpenCL 2D median filtering class
:param shape: shape of the images to treat
@@ -83,34 +92,56 @@ class MedianFilter2D(OpenclProcessing):
:param profile: switch on profiling to be able to profile at the kernel level,
store profiling elements (makes code slightly slower)
"""
- OpenclProcessing.__init__(self, ctx=ctx, devicetype=devicetype,
- platformid=platformid, deviceid=deviceid,
- block_size=block_size, profile=profile)
+ OpenclProcessing.__init__(
+ self,
+ ctx=ctx,
+ devicetype=devicetype,
+ platformid=platformid,
+ deviceid=deviceid,
+ block_size=block_size,
+ profile=profile,
+ )
self.shape = shape
self.size = self.shape[0] * self.shape[1]
self.kernel_size = self.calc_kernel_size(kernel_size)
self.workgroup_size = (self.calc_wg(self.kernel_size), 1) # 3D kernel
- self.buffers = [BufferDescription(i.name, i.size * self.size, i.dtype, i.flags)
- for i in self.__class__.buffers]
+ self.buffers = [
+ BufferDescription(i.name, i.size * self.size, i.dtype, i.flags)
+ for i in self.__class__.buffers
+ ]
self.allocate_buffers()
self.local_mem = self._get_local_mem(self.workgroup_size[0])
- OpenclProcessing.compile_kernels(self, self.kernel_files, "-D NIMAGE=%i" % self.size)
+ OpenclProcessing.compile_kernels(
+ self, self.kernel_files, "-D NIMAGE=%i" % self.size
+ )
self.set_kernel_arguments()
def set_kernel_arguments(self):
- """Parametrize all kernel arguments
- """
+ """Parametrize all kernel arguments"""
for val in self.mapping.values():
- self.cl_kernel_args[val] = OrderedDict(((i, self.cl_mem[i]) for i in ("image_raw", "image")))
- self.cl_kernel_args["medfilt2d"] = OrderedDict((("image", self.cl_mem["image"]),
- ("result", self.cl_mem["result"]),
- ("local", self.local_mem),
- ("khs1", numpy.int32(self.kernel_size[0] // 2)), # Kernel half-size along dim1 (lines)
- ("khs2", numpy.int32(self.kernel_size[1] // 2)), # Kernel half-size along dim2 (columns)
- ("height", numpy.int32(self.shape[0])), # Image size along dim1 (lines)
- ("width", numpy.int32(self.shape[1]))))
-# ('debug', self.cl_mem["debug"]))) # Image size along dim2 (columns))
+ self.cl_kernel_args[val] = dict(
+ ((i, self.cl_mem[i]) for i in ("image_raw", "image"))
+ )
+ self.cl_kernel_args["medfilt2d"] = dict(
+ (
+ ("image", self.cl_mem["image"]),
+ ("result", self.cl_mem["result"]),
+ ("local", self.local_mem),
+ (
+ "khs1",
+ numpy.int32(self.kernel_size[0] // 2),
+ ), # Kernel half-size along dim1 (lines)
+ (
+ "khs2",
+ numpy.int32(self.kernel_size[1] // 2),
+ ), # Kernel half-size along dim2 (columns)
+ ("height", numpy.int32(self.shape[0])), # Image size along dim1 (lines)
+ ("width", numpy.int32(self.shape[1])),
+ )
+ )
+
+ # ('debug', self.cl_mem["debug"]))) # Image size along dim2 (columns))
def _get_local_mem(self, wg):
return pyopencl.LocalMemory(wg * 32) # 4byte per float, 8 element per thread
@@ -125,13 +156,26 @@ class MedianFilter2D(OpenclProcessing):
dest_type = numpy.dtype([i.dtype for i in self.buffers if i.name == dest][0])
events = []
if (data.dtype == dest_type) or (data.dtype.itemsize > dest_type.itemsize):
- copy_image = pyopencl.enqueue_copy(self.queue, self.cl_mem[dest], numpy.ascontiguousarray(data, dest_type))
+ copy_image = pyopencl.enqueue_copy(
+ self.queue, self.cl_mem[dest], numpy.ascontiguousarray(data, dest_type)
+ )
events.append(EventDescription("copy H->D %s" % dest, copy_image))
else:
- copy_image = pyopencl.enqueue_copy(self.queue, self.cl_mem["image_raw"], numpy.ascontiguousarray(data))
+ copy_image = pyopencl.enqueue_copy(
+ self.queue, self.cl_mem["image_raw"], numpy.ascontiguousarray(data)
+ )
kernel = getattr(self.program, self.mapping[data.dtype.type])
- cast_to_float = kernel(self.queue, (self.size,), None, self.cl_mem["image_raw"], self.cl_mem[dest])
- events += [EventDescription("copy H->D %s" % dest, copy_image), EventDescription("cast to float", cast_to_float)]
+ cast_to_float = kernel(
+ self.queue,
+ (self.size,),
+ None,
+ self.cl_mem["image_raw"],
+ self.cl_mem[dest],
+ )
+ events += [
+ EventDescription("copy H->D %s" % dest, copy_image),
+ EventDescription("cast to float", cast_to_float),
+ ]
if self.profile:
self.events += events
@@ -180,7 +224,9 @@ class MedianFilter2D(OpenclProcessing):
amws = kernel_workgroup_size(self.program, "medfilt2d")
logger.warning("max actual workgroup size: %s, expected: %s", amws, wg)
if wg > amws:
- raise RuntimeError("Workgroup size is too big for medfilt2d: %s>%s" % (wg, amws))
+ raise RuntimeError(
+ "Workgroup size is too big for medfilt2d: %s>%s" % (wg, amws)
+ )
localmem = self._get_local_mem(wg)
@@ -197,11 +243,11 @@ class MedianFilter2D(OpenclProcessing):
kwargs["khs2"] = kernel_half_size[1]
kwargs["height"] = numpy.int32(image.shape[0])
kwargs["width"] = numpy.int32(image.shape[1])
-# for k, v in kwargs.items():
-# print("%s: %s (%s)" % (k, v, type(v)))
- mf2d = self.kernels.medfilt2d(self.queue,
- (wg, image.shape[1]),
- (wg, 1), *list(kwargs.values()))
+ # for k, v in kwargs.items():
+ # print("%s: %s (%s)" % (k, v, type(v)))
+ mf2d = self.kernels.medfilt2d(
+ self.queue, (wg, image.shape[1]), (wg, 1), *list(kwargs.values())
+ )
events.append(EventDescription("median filter 2d", mf2d))
result = numpy.empty(image.shape, numpy.float32)
@@ -211,12 +257,12 @@ class MedianFilter2D(OpenclProcessing):
if self.profile:
self.events += events
return result
+
__call__ = medfilt2d
@staticmethod
def calc_kernel_size(kernel_size):
- """format the kernel size to be a 2-length numpy array of int32
- """
+ """format the kernel size to be a 2-length numpy array of int32"""
kernel_size = numpy.asarray(kernel_size, dtype=numpy.int32)
if kernel_size.shape == ():
kernel_size = numpy.repeat(kernel_size.item(), 2).astype(numpy.int32)
@@ -249,7 +295,7 @@ class _MedFilt2d(object):
* The filling mode in scipy.signal.medfilt2d is zero-padding
* This implementation is equivalent to:
- scipy.ndimage.filters.median_filter(ary, kernel_size, mode="nearest")
+ scipy.ndimage.median_filter(ary, kernel_size, mode="nearest")
"""
image = numpy.atleast_2d(ary)
@@ -263,4 +309,5 @@ class _MedFilt2d(object):
cls.median_filter = MedianFilter2D(new_shape, kernel_size, ctx=ctx)
return cls.median_filter.medfilt2d(image, kernel_size=kernel_size)
+
medfilt2d = _MedFilt2d.medfilt2d
diff --git a/src/silx/opencl/processing.py b/src/silx/opencl/processing.py
index c223354..6db21d0 100644
--- a/src/silx/opencl/processing.py
+++ b/src/silx/opencl/processing.py
@@ -3,7 +3,7 @@
# Project: S I L X project
# https://github.com/silx-kit/silx
#
-# Copyright (C) 2012-2018 European Synchrotron Radiation Facility, Grenoble, France
+# Copyright (C) 2012-2023 European Synchrotron Radiation Facility, Grenoble, France
#
# Principal author: Jérôme Kieffer (Jerome.Kieffer@ESRF.eu)
#
@@ -37,22 +37,31 @@ __author__ = "Jerome Kieffer"
__contact__ = "Jerome.Kieffer@ESRF.eu"
__license__ = "MIT"
__copyright__ = "European Synchrotron Radiation Facility, Grenoble, France"
-__date__ = "06/10/2022"
+__date__ = "09/11/2022"
__status__ = "stable"
import sys
import os
import logging
import gc
-from collections import namedtuple, OrderedDict
+from collections import namedtuple
import numpy
import threading
-from .common import ocl, pyopencl, release_cl_buffers, query_kernel_info, allocate_texture, check_textures_availability
+from .common import (
+ ocl,
+ pyopencl,
+ release_cl_buffers,
+ query_kernel_info,
+ allocate_texture,
+ check_textures_availability,
+)
from .utils import concatenate_cl_kernel
import platform
BufferDescription = namedtuple("BufferDescription", ["name", "size", "dtype", "flags"])
-EventDescription = namedtuple("EventDescription", ["name", "event"]) # Deprecated, please use ProfileDescription
+EventDescription = namedtuple(
+ "EventDescription", ["name", "event"]
+) # Deprecated, please use ProfileDescription
ProfileDescription = namedtuple("ProfileDescription", ["name", "start", "stop"])
logger = logging.getLogger(__name__)
@@ -72,8 +81,9 @@ class KernelContainer(object):
def get_kernels(self):
"return the dictionary with all kernels"
- return dict(item for item in self.__dict__.items()
- if not item[0].startswith("_"))
+ return dict(
+ item for item in self.__dict__.items() if not item[0].startswith("_")
+ )
def get_kernel(self, name):
"get a kernel from its name"
@@ -96,7 +106,9 @@ class KernelContainer(object):
else:
kernel = self.get_kernel(kernel_name)
- return query_kernel_info(self._program, kernel, "PREFERRED_WORK_GROUP_SIZE_MULTIPLE")
+ return query_kernel_info(
+ self._program, kernel, "PREFERRED_WORK_GROUP_SIZE_MULTIPLE"
+ )
class OpenclProcessing(object):
@@ -108,14 +120,24 @@ class OpenclProcessing(object):
* Functions to compile kernels, cache them and clean them
* helper functions to clone the object
"""
+
# Example of how to create an output buffer of 10 floats
- buffers = [BufferDescription("output", 10, numpy.float32, None),
- ]
+ buffers = [
+ BufferDescription("output", 10, numpy.float32, None),
+ ]
# list of kernel source files to be concatenated before compilation of the program
kernel_files = []
- def __init__(self, ctx=None, devicetype="all", platformid=None, deviceid=None,
- block_size=None, memory=None, profile=False):
+ def __init__(
+ self,
+ ctx=None,
+ devicetype="all",
+ platformid=None,
+ deviceid=None,
+ block_size=None,
+ memory=None,
+ profile=False,
+ ):
"""Constructor of the abstract OpenCL processing class
:param ctx: actual working context, left to None for automatic
@@ -140,9 +162,12 @@ class OpenclProcessing(object):
if ctx:
self.ctx = ctx
else:
- self.ctx = ocl.create_context(devicetype=devicetype,
- platformid=platformid, deviceid=deviceid,
- memory=memory)
+ self.ctx = ocl.create_context(
+ devicetype=devicetype,
+ platformid=platformid,
+ deviceid=deviceid,
+ memory=memory,
+ )
device_name = self.ctx.devices[0].name.strip()
platform_name = self.ctx.devices[0].platform.name.strip()
platform = ocl.get_platform(platform_name)
@@ -158,8 +183,7 @@ class OpenclProcessing(object):
return check_textures_availability(self.ctx)
def __del__(self):
- """Destructor: release all buffers and programs
- """
+ """Destructor: release all buffers and programs"""
try:
self.reset_log()
self.free_kernels()
@@ -201,19 +225,27 @@ class OpenclProcessing(object):
ualloc = 0
for buf in buffers:
ualloc += numpy.dtype(buf.dtype).itemsize * numpy.prod(buf.size)
- logger.info("%.3fMB are needed on device: %s, which has %.3fMB",
- ualloc / 1.0e6, self.device, self.device.memory / 1.0e6)
+ logger.info(
+ "%.3fMB are needed on device: %s, which has %.3fMB",
+ ualloc / 1.0e6,
+ self.device,
+ self.device.memory / 1.0e6,
+ )
if ualloc >= self.device.memory:
- raise MemoryError("Fatal error in allocate_buffers. Not enough "
- " device memory for buffers (%lu requested, %lu available)"
- % (ualloc, self.device.memory))
+ raise MemoryError(
+ "Fatal error in allocate_buffers. Not enough "
+ " device memory for buffers (%lu requested, %lu available)"
+ % (ualloc, self.device.memory)
+ )
# do the allocation
try:
if use_array:
for buf in buffers:
- mem[buf.name] = pyopencl.array.empty(self.queue, buf.size, buf.dtype)
+ mem[buf.name] = pyopencl.array.empty(
+ self.queue, buf.size, buf.dtype
+ )
else:
for buf in buffers:
size = numpy.dtype(buf.dtype).itemsize * numpy.prod(buf.size)
@@ -241,8 +273,7 @@ class OpenclProcessing(object):
return self.kernels.max_workgroup_size(kernel_name)
def free_buffers(self):
- """free all device.memory allocated on the device
- """
+ """free all device.memory allocated on the device"""
with self.sem:
for key, buf in list(self.cl_mem.items()):
if buf is not None:
@@ -272,21 +303,22 @@ class OpenclProcessing(object):
compile_options = compile_options or self.get_compiler_options()
logger.info("Compiling file %s with options %s", kernel_files, compile_options)
try:
- self.program = pyopencl.Program(self.ctx, kernel_src).build(options=compile_options)
+ self.program = pyopencl.Program(self.ctx, kernel_src).build(
+ options=compile_options
+ )
except (pyopencl.MemoryError, pyopencl.LogicError) as error:
raise MemoryError(error)
else:
self.kernels = KernelContainer(self.program)
def free_kernels(self):
- """Free all kernels
- """
+ """Free all kernels"""
for kernel in self.cl_kernel_args:
self.cl_kernel_args[kernel] = []
self.kernels = None
self.program = None
-# Methods about Profiling
+ # Methods about Profiling
def set_profiling(self, value=True):
"""Switch On/Off the profiling flag of the command queue to allow debugging
@@ -301,10 +333,16 @@ class OpenclProcessing(object):
if self.queue is not None:
self.queue.finish()
if self.profile:
- self.queue = pyopencl.CommandQueue(self.ctx,
- properties=pyopencl.command_queue_properties.PROFILING_ENABLE)
+ self.queue = pyopencl.CommandQueue(
+ self.ctx,
+ properties=pyopencl.command_queue_properties.PROFILING_ENABLE,
+ )
else:
self.queue = pyopencl.CommandQueue(self.ctx)
+ # Update all memory-objects with the new queue:
+ for obj, cl_obj in list(self.cl_mem.items()):
+ if isinstance(cl_obj, pyopencl.array.Array):
+ self.cl_mem[obj] = cl_obj.with_queue(self.queue)
def profile_add(self, event, desc):
"""
@@ -332,7 +370,11 @@ class OpenclProcessing(object):
if isinstance(event_desc, ProfileDescription):
self.events.append(event_desc)
else:
- if isinstance(event_desc, EventDescription) or "__len__" in dir(e) and len(e) == 2:
+ if (
+ isinstance(event_desc, EventDescription)
+ or "__len__" in dir(event_desc)
+ and len(event_desc) == 2
+ ):
desc, event = event_desc
else:
desc = "?"
@@ -349,16 +391,20 @@ class OpenclProcessing(object):
def log_profile(self, stats=False):
"""If we are in profiling mode, prints out all timing for every single OpenCL call
-
+
:param stats: if True, prints the statistics on each kernel instead of all execution timings
:return: list of lines to print
"""
total_time = 0.0
out = [""]
if stats:
- stats = OrderedDict()
- out.append(f"OpenCL kernel profiling statistics in milliseconds for: {self.__class__.__name__}")
- out.append(f"{'Kernel name':>50} (count): min median max mean std")
+ stats = {}
+ out.append(
+ f"OpenCL kernel profiling statistics in milliseconds for: {self.__class__.__name__}"
+ )
+ out.append(
+ f"{'Kernel name':>50} (count): min median max mean std"
+ )
else:
stats = None
out.append(f"Profiling info for OpenCL: {self.__class__.__name__}")
@@ -369,7 +415,11 @@ class OpenclProcessing(object):
name = e[0]
t0 = e[1]
t1 = e[2]
- elif isinstance(e, EventDescription) or "__len__" in dir(e) and len(e) == 2:
+ elif (
+ isinstance(e, EventDescription)
+ or "__len__" in dir(e)
+ and len(e) == 2
+ ):
name = e[0]
pr = e[1].profile
t0 = pr.start
@@ -391,9 +441,13 @@ class OpenclProcessing(object):
if stats is not None:
for k, v in stats.items():
n = numpy.array(v)
- out.append(f"{k:>50} ({len(v):5}): {n.min():8.3f} {numpy.median(n):8.3f} {n.max():8.3f} {n.mean():8.3f} {n.std():8.3f}")
+ out.append(
+ f"{k:>50} ({len(v):5}): {n.min():8.3f} {numpy.median(n):8.3f} {n.max():8.3f} {n.mean():8.3f} {n.std():8.3f}"
+ )
out.append("_" * 80)
- out.append(f"{'Total OpenCL execution time':>50} : {total_time:.3f}ms")
+ out.append(
+ f"{'Total OpenCL execution time':>50} : {total_time:.3f}ms"
+ )
logger.info(os.linesep.join(out))
return out
@@ -405,7 +459,7 @@ class OpenclProcessing(object):
with self.sem:
self.events = []
-# Methods about textures
+ # Methods about textures
def allocate_texture(self, shape, hostbuf=None, support_1D=False):
return allocate_texture(self.ctx, shape, hostbuf=hostbuf, support_1D=support_1D)
@@ -424,8 +478,8 @@ class OpenclProcessing(object):
# force 2D with one row in this case
# ~ ndim = 2
shp = (1,) + shp
- copy_kwargs = {"origin":(0,) * ndim, "region": shp[::-1]}
- if not(isinstance(arr, numpy.ndarray)): # assuming pyopencl.array.Array
+ copy_kwargs = {"origin": (0,) * ndim, "region": shp[::-1]}
+ if not (isinstance(arr, numpy.ndarray)): # assuming pyopencl.array.Array
# D->D copy
copy_args[2] = arr.data
copy_kwargs["offset"] = 0
@@ -436,9 +490,11 @@ class OpenclProcessing(object):
def x87_volatile_option(self):
# this is running 32 bits OpenCL woth POCL
if self._X87_VOLATILE is None:
- if (platform.machine() in ("i386", "i686", "x86_64", "AMD64") and
- (tuple.__itemsize__ == 4) and
- self.ctx.devices[0].platform.name == 'Portable Computing Language'):
+ if (
+ platform.machine() in ("i386", "i686", "x86_64", "AMD64")
+ and (tuple.__itemsize__ == 4)
+ and self.ctx.devices[0].platform.name == "Portable Computing Language"
+ ):
self._X87_VOLATILE = "-DX87_VOLATILE=volatile"
else:
self._X87_VOLATILE = ""
@@ -455,6 +511,7 @@ class OpenclProcessing(object):
option_list.append(self.x87_volatile_option)
return " ".join(i for i in option_list if i)
+
# This should be implemented by concrete class
# def __copy__(self):
# """Shallow copy of the object
diff --git a/src/silx/opencl/projection.py b/src/silx/opencl/projection.py
index a02e28b..cf4b625 100644
--- a/src/silx/opencl/projection.py
+++ b/src/silx/opencl/projection.py
@@ -48,14 +48,23 @@ class Projection(OpenclProcessing):
A class for performing a tomographic projection (Radon Transform) using
OpenCL
"""
+
kernel_files = ["proj.cl", "array_utils.cl"]
logger.warning("Forward Projecter is untested and unsuported for now")
- def __init__(self, slice_shape, angles, axis_position=None,
- detector_width=None, normalize=False, ctx=None,
- devicetype="all", platformid=None, deviceid=None,
- profile=False
- ):
+ def __init__(
+ self,
+ slice_shape,
+ angles,
+ axis_position=None,
+ detector_width=None,
+ normalize=False,
+ ctx=None,
+ devicetype="all",
+ platformid=None,
+ deviceid=None,
+ profile=False,
+ ):
"""Constructor of the OpenCL projector.
:param slice_shape: shape of the slice: (num_rows, num_columns).
@@ -84,9 +93,14 @@ class Projection(OpenclProcessing):
# if sys.platform.startswith('darwin'): # assuming no discrete GPU
# raise NotImplementedError("Backprojection is not implemented on CPU for OS X yet")
- OpenclProcessing.__init__(self, ctx=ctx, devicetype=devicetype,
- platformid=platformid, deviceid=deviceid,
- profile=profile)
+ OpenclProcessing.__init__(
+ self,
+ ctx=ctx,
+ devicetype=devicetype,
+ platformid=platformid,
+ deviceid=deviceid,
+ profile=profile,
+ )
self.shape = slice_shape
self.axis_pos = axis_position
self.angles = angles
@@ -95,24 +109,27 @@ class Projection(OpenclProcessing):
# Default values
if self.axis_pos is None:
- self.axis_pos = (self.shape[1] - 1) / 2.
+ self.axis_pos = (self.shape[1] - 1) / 2.0
if self.dwidth is None:
self.dwidth = self.shape[1]
- if not(np.iterable(self.angles)):
+ if not (np.iterable(self.angles)):
if self.angles is None:
self.nprojs = self.shape[0]
else:
self.nprojs = self.angles
- self.angles = np.linspace(start=0,
- stop=np.pi,
- num=self.nprojs,
- endpoint=False).astype(dtype=np.float32)
+ self.angles = np.linspace(
+ start=0, stop=np.pi, num=self.nprojs, endpoint=False
+ ).astype(dtype=np.float32)
else:
self.nprojs = len(self.angles)
- self.offset_x = -np.float32((self.shape[1] - 1) / 2. - self.axis_pos) # TODO: custom
- self.offset_y = -np.float32((self.shape[0] - 1) / 2. - self.axis_pos) # TODO: custom
+ self.offset_x = -np.float32(
+ (self.shape[1] - 1) / 2.0 - self.axis_pos
+ ) # TODO: custom
+ self.offset_y = -np.float32(
+ (self.shape[0] - 1) / 2.0 - self.axis_pos
+ ) # TODO: custom
# Reset axis_pos once offset are computed
- self.axis_pos0 = np.float64((self.shape[1] - 1) / 2.)
+ self.axis_pos0 = np.float64((self.shape[1] - 1) / 2.0)
# Workgroup, ndrange and shared size
self.dimgrid_x = _idivup(self.dwidth, 16)
@@ -123,118 +140,122 @@ class Projection(OpenclProcessing):
self.wg = (16, 16)
self.ndrange = (
int(self.dimgrid_x) * self.wg[0], # int(): pyopencl <= 2015.1
- int(self.dimgrid_y) * self.wg[1] # int(): pyopencl <= 2015.1
+ int(self.dimgrid_y) * self.wg[1], # int(): pyopencl <= 2015.1
)
self._use_textures = self.check_textures_availability()
# Allocate memory
self.buffers = [
- BufferDescription("_d_sino", self._dimrecx * self._dimrecy, np.float32, mf.READ_WRITE),
+ BufferDescription(
+ "_d_sino", self._dimrecx * self._dimrecy, np.float32, mf.READ_WRITE
+ ),
BufferDescription("d_angles", self._dimrecy, np.float32, mf.READ_ONLY),
BufferDescription("d_beginPos", self._dimrecy * 2, np.int32, mf.READ_ONLY),
- BufferDescription("d_strideJoseph", self._dimrecy * 2, np.int32, mf.READ_ONLY),
- BufferDescription("d_strideLine", self._dimrecy * 2, np.int32, mf.READ_ONLY),
+ BufferDescription(
+ "d_strideJoseph", self._dimrecy * 2, np.int32, mf.READ_ONLY
+ ),
+ BufferDescription(
+ "d_strideLine", self._dimrecy * 2, np.int32, mf.READ_ONLY
+ ),
]
d_axis_corrections = parray.empty(self.queue, self.nprojs, np.float32)
d_axis_corrections.fill(np.float32(0.0))
- self.add_to_cl_mem(
- {
- "d_axis_corrections": d_axis_corrections
- }
+ self.add_to_cl_mem({"d_axis_corrections": d_axis_corrections})
+ self._tmp_extended_img = np.zeros(
+ (self.shape[0] + 2, self.shape[1] + 2), dtype=np.float32
)
- self._tmp_extended_img = np.zeros((self.shape[0] + 2, self.shape[1] + 2),
- dtype=np.float32)
- if not(self._use_textures):
+ if not (self._use_textures):
self.allocate_slice()
else:
self.allocate_textures()
self.allocate_buffers()
- self._ex_sino = np.zeros((self._dimrecy, self._dimrecx),
- dtype=np.float32)
- if not(self._use_textures):
- self.cl_mem["d_slice"].fill(0.)
+ self._ex_sino = np.zeros((self._dimrecy, self._dimrecx), dtype=np.float32)
+ if not (self._use_textures):
+ self.cl_mem["d_slice"].fill(0.0)
# enqueue_fill_buffer has issues if opencl 1.2 is not present
# ~ pyopencl.enqueue_fill_buffer(
- # ~ self.queue,
- # ~ self.cl_mem["d_slice"],
- # ~ np.float32(0),
- # ~ 0,
- # ~ self._tmp_extended_img.size * _sizeof(np.float32)
+ # ~ self.queue,
+ # ~ self.cl_mem["d_slice"],
+ # ~ np.float32(0),
+ # ~ 0,
+ # ~ self._tmp_extended_img.size * _sizeof(np.float32)
# ~ )
# Precomputations
self.compute_angles()
self.proj_precomputations()
- self.cl_mem["d_axis_corrections"].fill(0.)
+ self.cl_mem["d_axis_corrections"].fill(0.0)
# enqueue_fill_buffer has issues if opencl 1.2 is not present
# ~ pyopencl.enqueue_fill_buffer(
- # ~ self.queue,
- # ~ self.cl_mem["d_axis_corrections"],
- # ~ np.float32(0),
- # ~ 0,
- # ~ self.nprojs*_sizeof(np.float32)
- # ~ )
+ # ~ self.queue,
+ # ~ self.cl_mem["d_axis_corrections"],
+ # ~ np.float32(0),
+ # ~ 0,
+ # ~ self.nprojs*_sizeof(np.float32)
+ # ~ )
# Shorthands
self._d_sino = self.cl_mem["_d_sino"]
compile_options = None
- if not(self._use_textures):
+ if not (self._use_textures):
compile_options = "-DDONT_USE_TEXTURES"
OpenclProcessing.compile_kernels(
- self,
- self.kernel_files,
- compile_options=compile_options
+ self, self.kernel_files, compile_options=compile_options
)
# check that workgroup can actually be (16, 16)
- self.compiletime_workgroup_size = self.kernels.max_workgroup_size("forward_kernel_cpu")
+ self.compiletime_workgroup_size = self.kernels.max_workgroup_size(
+ "forward_kernel_cpu"
+ )
def compute_angles(self):
angles2 = np.zeros(self._dimrecy, dtype=np.float32) # dimrecy != num_projs
- angles2[:self.nprojs] = np.copy(self.angles)
- angles2[self.nprojs:] = angles2[self.nprojs - 1]
+ angles2[: self.nprojs] = np.copy(self.angles)
+ angles2[self.nprojs :] = angles2[self.nprojs - 1]
self.angles2 = angles2
pyopencl.enqueue_copy(self.queue, self.cl_mem["d_angles"], angles2)
def allocate_slice(self):
- ary = parray.empty(self.queue, (self.shape[1] + 2, self.shape[1] + 2), np.float32)
+ ary = parray.empty(
+ self.queue, (self.shape[1] + 2, self.shape[1] + 2), np.float32
+ )
ary.fill(0)
self.add_to_cl_mem({"d_slice": ary})
def allocate_textures(self):
self.d_image_tex = pyopencl.Image(
- self.ctx,
- mf.READ_ONLY | mf.USE_HOST_PTR,
- pyopencl.ImageFormat(
- pyopencl.channel_order.INTENSITY,
- pyopencl.channel_type.FLOAT
- ), hostbuf=np.ascontiguousarray(self._tmp_extended_img.T),
- )
+ self.ctx,
+ mf.READ_ONLY | mf.USE_HOST_PTR,
+ pyopencl.ImageFormat(
+ pyopencl.channel_order.INTENSITY, pyopencl.channel_type.FLOAT
+ ),
+ hostbuf=np.ascontiguousarray(self._tmp_extended_img.T),
+ )
def transfer_to_texture(self, image):
image2 = image
- if not(image.flags["C_CONTIGUOUS"] and image.dtype == np.float32):
+ if not (image.flags["C_CONTIGUOUS"] and image.dtype == np.float32):
image2 = np.ascontiguousarray(image)
- if not(self._use_textures):
+ if not (self._use_textures):
# TODO: create NoneEvent
return self.transfer_to_slice(image2)
# ~ return pyopencl.enqueue_copy(
- # ~ self.queue,
- # ~ self.cl_mem["d_slice"].data,
- # ~ image2,
- # ~ origin=(1, 1),
- # ~ region=image.shape[::-1]
- # ~ )
+ # ~ self.queue,
+ # ~ self.cl_mem["d_slice"].data,
+ # ~ image2,
+ # ~ origin=(1, 1),
+ # ~ region=image.shape[::-1]
+ # ~ )
else:
return pyopencl.enqueue_copy(
- self.queue,
- self.d_image_tex,
- image2,
- origin=(1, 1),
- region=image.shape[::-1]
- )
+ self.queue,
+ self.d_image_tex,
+ image2,
+ origin=(1, 1),
+ region=image.shape[::-1],
+ )
def transfer_device_to_texture(self, d_image):
- if not(self._use_textures):
+ if not (self._use_textures):
# TODO this copy should not be necessary
return self.cpy2d_to_slice(d_image)
else:
@@ -244,7 +265,10 @@ class Projection(OpenclProcessing):
d_image,
offset=0,
origin=(1, 1),
- region=(int(self.shape[1]), int(self.shape[0])) # self.shape[::-1] # pyopencl <= 2015.2
+ region=(
+ int(self.shape[1]),
+ int(self.shape[0]),
+ ), # self.shape[::-1] # pyopencl <= 2015.2
)
def transfer_to_slice(self, image):
@@ -323,7 +347,7 @@ class Projection(OpenclProcessing):
np.int32(self._dimrecx),
np.int32((0, 0)),
np.int32((0, 0)),
- sino_shape_ocl
+ sino_shape_ocl,
)
return self.kernels.cpy2d(self.queue, ndrange, wg, *kernel_args)
@@ -331,7 +355,10 @@ class Projection(OpenclProcessing):
"""
copy a Nx * Ny slice to self.d_slice which is (Nx+2)*(Ny+2)
"""
- ndrange = (int(self.shape[1]), int(self.shape[0])) # self.shape[::-1] # pyopencl < 2015.2
+ ndrange = (
+ int(self.shape[1]),
+ int(self.shape[0]),
+ ) # self.shape[::-1] # pyopencl < 2015.2
wg = None
slice_shape_ocl = np.int32(ndrange)
kernel_args = (
@@ -341,7 +368,7 @@ class Projection(OpenclProcessing):
np.int32(self.shape[1]),
np.int32((1, 1)),
np.int32((0, 0)),
- slice_shape_ocl
+ slice_shape_ocl,
)
return self.kernels.cpy2d(self.queue, ndrange, wg, *kernel_args)
@@ -364,7 +391,7 @@ class Projection(OpenclProcessing):
self.transfer_to_slice(image)
slice_ref = self.cl_mem["d_slice"].data
else:
- if not(self._use_textures):
+ if not (self._use_textures):
slice_ref = self.cl_mem["d_slice"].data
else:
slice_ref = self.d_image_tex
@@ -386,23 +413,17 @@ class Projection(OpenclProcessing):
self.offset_x,
self.offset_y,
np.int32(1), # josephnoclip, 1 by default
- np.int32(self.normalize)
+ np.int32(self.normalize),
)
# Call the kernel
- if not(self._use_textures):
+ if not (self._use_textures):
event_pj = self.kernels.forward_kernel_cpu(
- self.queue,
- self.ndrange,
- self.wg,
- *kernel_args
+ self.queue, self.ndrange, self.wg, *kernel_args
)
else:
event_pj = self.kernels.forward_kernel(
- self.queue,
- self.ndrange,
- self.wg,
- *kernel_args
+ self.queue, self.ndrange, self.wg, *kernel_args
)
events.append(EventDescription("projection", event_pj))
if dst is None:
@@ -410,7 +431,7 @@ class Projection(OpenclProcessing):
ev = pyopencl.enqueue_copy(self.queue, self._ex_sino, self._d_sino)
events.append(EventDescription("copy D->H result", ev))
ev.wait()
- res = np.copy(self._ex_sino[:self.nprojs, :self.dwidth])
+ res = np.copy(self._ex_sino[: self.nprojs, : self.dwidth])
else:
ev = self.cpy2d_to_sino(dst)
events.append(EventDescription("copy D->D result", ev))
diff --git a/src/silx/opencl/reconstruction.py b/src/silx/opencl/reconstruction.py
index c85fd42..c80a0ef 100644
--- a/src/silx/opencl/reconstruction.py
+++ b/src/silx/opencl/reconstruction.py
@@ -39,6 +39,7 @@ from .linalg import LinAlg
import pyopencl.array as parray
from pyopencl.elementwise import ElementwiseKernel
+
logger = logging.getLogger(__name__)
cl = pyopencl
@@ -65,13 +66,26 @@ class ReconstructionAlgorithm(OpenclProcessing):
store profiling elements (makes code slightly slower)
"""
- def __init__(self, sino_shape, slice_shape=None, axis_position=None, angles=None,
- ctx=None, devicetype="all", platformid=None, deviceid=None,
- profile=False
- ):
- OpenclProcessing.__init__(self, ctx=ctx, devicetype=devicetype,
- platformid=platformid, deviceid=deviceid,
- profile=profile)
+ def __init__(
+ self,
+ sino_shape,
+ slice_shape=None,
+ axis_position=None,
+ angles=None,
+ ctx=None,
+ devicetype="all",
+ platformid=None,
+ deviceid=None,
+ profile=False,
+ ):
+ OpenclProcessing.__init__(
+ self,
+ ctx=ctx,
+ devicetype=devicetype,
+ platformid=platformid,
+ deviceid=deviceid,
+ profile=profile,
+ )
# Create a backprojector
self.backprojector = Backprojection(
@@ -80,7 +94,7 @@ class ReconstructionAlgorithm(OpenclProcessing):
axis_position=axis_position,
angles=angles,
ctx=self.ctx,
- profile=profile
+ profile=profile,
)
# Create a projector
self.projector = Projection(
@@ -90,7 +104,7 @@ class ReconstructionAlgorithm(OpenclProcessing):
detector_width=self.backprojector.num_bins,
normalize=False,
ctx=self.ctx,
- profile=profile
+ profile=profile,
)
self.sino_shape = sino_shape
self.is_cpu = self.backprojector.is_cpu
@@ -99,32 +113,34 @@ class ReconstructionAlgorithm(OpenclProcessing):
self.d_data.fill(0.0)
self.d_sino = parray.empty_like(self.d_data)
self.d_sino.fill(0.0)
- self.d_x = parray.empty(self.queue,
- self.backprojector.slice_shape,
- dtype=np.float32)
+ self.d_x = parray.empty(
+ self.queue, self.backprojector.slice_shape, dtype=np.float32
+ )
self.d_x.fill(0.0)
self.d_x_old = parray.empty_like(self.d_x)
self.d_x_old.fill(0.0)
- self.add_to_cl_mem({
- "d_data": self.d_data,
- "d_sino": self.d_sino,
- "d_x": self.d_x,
- "d_x_old": self.d_x_old,
- })
+ self.add_to_cl_mem(
+ {
+ "d_data": self.d_data,
+ "d_sino": self.d_sino,
+ "d_x": self.d_x,
+ "d_x_old": self.d_x_old,
+ }
+ )
def proj(self, d_slice, d_sino):
"""
Project d_slice to d_sino
"""
- self.projector.transfer_device_to_texture(d_slice.data) #.wait()
+ self.projector.transfer_device_to_texture(d_slice.data) # .wait()
self.projector.projection(dst=d_sino)
def backproj(self, d_sino, d_slice):
"""
Backproject d_sino to d_slice
"""
- self.backprojector.transfer_device_to_texture(d_sino.data) #.wait()
+ self.backprojector.transfer_device_to_texture(d_sino.data) # .wait()
self.backprojector.backprojection(dst=d_slice)
@@ -153,15 +169,30 @@ class SIRT(ReconstructionAlgorithm):
implementation.
"""
- def __init__(self, sino_shape, slice_shape=None, axis_position=None, angles=None,
- ctx=None, devicetype="all", platformid=None, deviceid=None,
- profile=False
- ):
-
- ReconstructionAlgorithm.__init__(self, sino_shape, slice_shape=slice_shape,
- axis_position=axis_position, angles=angles,
- ctx=ctx, devicetype=devicetype, platformid=platformid,
- deviceid=deviceid, profile=profile)
+ def __init__(
+ self,
+ sino_shape,
+ slice_shape=None,
+ axis_position=None,
+ angles=None,
+ ctx=None,
+ devicetype="all",
+ platformid=None,
+ deviceid=None,
+ profile=False,
+ ):
+ ReconstructionAlgorithm.__init__(
+ self,
+ sino_shape,
+ slice_shape=slice_shape,
+ axis_position=axis_position,
+ angles=angles,
+ ctx=ctx,
+ devicetype=devicetype,
+ platformid=platformid,
+ deviceid=deviceid,
+ profile=profile,
+ )
self.compute_preconditioners()
def compute_preconditioners(self):
@@ -178,26 +209,31 @@ class SIRT(ReconstructionAlgorithm):
# r_{i,i} = 1/(sum_j a_{i,j})
slice_ones = np.ones(self.backprojector.slice_shape, dtype=np.float32)
- R = 1./self.projector.projection(slice_ones) # could be all done on GPU, but I want extra checks
- R[np.logical_not(np.isfinite(R))] = 1. # In the case where the rotation axis is excentred
+ R = 1.0 / self.projector.projection(
+ slice_ones
+ ) # could be all done on GPU, but I want extra checks
+ R[
+ np.logical_not(np.isfinite(R))
+ ] = 1.0 # In the case where the rotation axis is excentred
self.d_R = parray.to_device(self.queue, R)
# c_{j,j} = 1/(sum_i a_{i,j})
sino_ones = np.ones(self.sino_shape, dtype=np.float32)
- C = 1./self.backprojector.backprojection(sino_ones)
- C[np.logical_not(np.isfinite(C))] = 1. # In the case where the rotation axis is excentred
+ C = 1.0 / self.backprojector.backprojection(sino_ones)
+ C[
+ np.logical_not(np.isfinite(C))
+ ] = 1.0 # In the case where the rotation axis is excentred
self.d_C = parray.to_device(self.queue, C)
- self.add_to_cl_mem({
- "d_R": self.d_R,
- "d_C": self.d_C
- })
+ self.add_to_cl_mem({"d_R": self.d_R, "d_C": self.d_C})
# TODO: compute and possibly return the residual
def run(self, data, n_it):
"""
Run n_it iterations of the SIRT algorithm.
"""
- cl.enqueue_copy(self.queue, self.d_data.data, np.ascontiguousarray(data.astype(np.float32)))
+ cl.enqueue_copy(
+ self.queue, self.d_data.data, np.ascontiguousarray(data.astype(np.float32))
+ )
d_x_old = self.d_x_old
d_x = self.d_x
@@ -254,26 +290,44 @@ class TV(ReconstructionAlgorithm):
the AMD opencl implementation.
"""
- def __init__(self, sino_shape, slice_shape=None, axis_position=None, angles=None,
- ctx=None, devicetype="all", platformid=None, deviceid=None,
- profile=False
- ):
- ReconstructionAlgorithm.__init__(self, sino_shape, slice_shape=slice_shape,
- axis_position=axis_position, angles=angles,
- ctx=ctx, devicetype=devicetype, platformid=platformid,
- deviceid=deviceid, profile=profile)
+ def __init__(
+ self,
+ sino_shape,
+ slice_shape=None,
+ axis_position=None,
+ angles=None,
+ ctx=None,
+ devicetype="all",
+ platformid=None,
+ deviceid=None,
+ profile=False,
+ ):
+ ReconstructionAlgorithm.__init__(
+ self,
+ sino_shape,
+ slice_shape=slice_shape,
+ axis_position=axis_position,
+ angles=angles,
+ ctx=ctx,
+ devicetype=devicetype,
+ platformid=platformid,
+ deviceid=deviceid,
+ profile=profile,
+ )
self.compute_preconditioners()
# Create a LinAlg instance
self.linalg = LinAlg(self.backprojector.slice_shape, ctx=self.ctx)
# Positivity constraint
- self.elwise_clamp = ElementwiseKernel(self.ctx, "float *a", "a[i] = max(a[i], 0.0f);")
+ self.elwise_clamp = ElementwiseKernel(
+ self.ctx, "float *a", "a[i] = max(a[i], 0.0f);"
+ )
# Projection onto the L-infinity ball of radius Lambda
self.elwise_proj_linf = ElementwiseKernel(
self.ctx,
"float2* a, float Lambda",
"a[i].x = copysign(min(fabs(a[i].x), Lambda), a[i].x); a[i].y = copysign(min(fabs(a[i].y), Lambda), a[i].y);",
- "elwise_proj_linf"
+ "elwise_proj_linf",
)
# Additional arrays
self.linalg.gradient(self.d_x)
@@ -284,11 +338,13 @@ class TV(ReconstructionAlgorithm):
self.d_p.fill(0)
self.d_q.fill(0)
self.d_tmp.fill(0)
- self.add_to_cl_mem({
- "d_p": self.d_p,
- "d_q": self.d_q,
- "d_tmp": self.d_tmp,
- })
+ self.add_to_cl_mem(
+ {
+ "d_p": self.d_p,
+ "d_q": self.d_q,
+ "d_tmp": self.d_tmp,
+ }
+ )
self.theta = 1.0
@@ -308,30 +364,36 @@ class TV(ReconstructionAlgorithm):
# Compute the diagonal preconditioner "Sigma"
slice_ones = np.ones(self.backprojector.slice_shape, dtype=np.float32)
- Sigma_k = 1./self.projector.projection(slice_ones)
- Sigma_k[np.logical_not(np.isfinite(Sigma_k))] = 1.
+ Sigma_k = 1.0 / self.projector.projection(slice_ones)
+ Sigma_k[np.logical_not(np.isfinite(Sigma_k))] = 1.0
self.d_Sigma_k = parray.to_device(self.queue, Sigma_k)
self.d_Sigma_kp1 = self.d_Sigma_k + 1 # TODO: memory vs computation
- self.Sigma_grad = 1/2.0 # For discrete gradient, sum|D_i,j| = 2 along lines or cols
+ self.Sigma_grad = (
+ 1 / 2.0
+ ) # For discrete gradient, sum|D_i,j| = 2 along lines or cols
# Compute the diagonal preconditioner "Tau"
sino_ones = np.ones(self.sino_shape, dtype=np.float32)
C = self.backprojector.backprojection(sino_ones)
- Tau = 1./(C + 2.)
+ Tau = 1.0 / (C + 2.0)
self.d_Tau = parray.to_device(self.queue, Tau)
- self.add_to_cl_mem({
- "d_Sigma_k": self.d_Sigma_k,
- "d_Sigma_kp1": self.d_Sigma_kp1,
- "d_Tau": self.d_Tau
- })
+ self.add_to_cl_mem(
+ {
+ "d_Sigma_k": self.d_Sigma_k,
+ "d_Sigma_kp1": self.d_Sigma_kp1,
+ "d_Tau": self.d_Tau,
+ }
+ )
def run(self, data, n_it, Lambda, pos_constraint=False):
"""
Run n_it iterations of the TV-regularized reconstruction,
with the regularization parameter Lambda.
"""
- cl.enqueue_copy(self.queue, self.d_data.data, np.ascontiguousarray(data.astype(np.float32)))
+ cl.enqueue_copy(
+ self.queue, self.d_data.data, np.ascontiguousarray(data.astype(np.float32))
+ )
d_x = self.d_x
d_x_old = self.d_x_old
@@ -348,7 +410,7 @@ class TV(ReconstructionAlgorithm):
for k in range(0, n_it):
# Update primal variables
d_x_old[:] = d_x[:]
- #~ x = x + Tau*div(p) - Tau*Kadj(q)
+ # ~ x = x + Tau*div(p) - Tau*Kadj(q)
self.backproj(d_q, d_tmp)
self.linalg.divergence(d_p)
# TODO: this in less than three ops (one kernel ?)
@@ -360,20 +422,20 @@ class TV(ReconstructionAlgorithm):
self.elwise_clamp(d_x)
# Update dual variables
- #~ p = proj_linf(p + Sigma_grad*gradient(x + theta*(x - x_old)), Lambda)
+ # ~ p = proj_linf(p + Sigma_grad*gradient(x + theta*(x - x_old)), Lambda)
d_tmp[:] = d_x[:]
# FIXME: mul_add is out of place, put an equivalent thing in linalg...
- #~ d_tmp.mul_add(1 + theta, d_x_old, -theta)
- d_tmp *= 1+self.theta
- d_tmp -= self.theta*d_x_old
+ # ~ d_tmp.mul_add(1 + theta, d_x_old, -theta)
+ d_tmp *= 1 + self.theta
+ d_tmp -= self.theta * d_x_old
self.linalg.gradient(d_tmp)
# TODO: out of place mul_add
- #~ d_p.mul_add(1, L.cl_mem["d_gradient"], Sigma_grad)
+ # ~ d_p.mul_add(1, L.cl_mem["d_gradient"], Sigma_grad)
self.linalg.cl_mem["d_gradient"] *= self.Sigma_grad
d_p += self.linalg.cl_mem["d_gradient"]
self.elwise_proj_linf(d_p, Lambda)
- #~ q = (q + Sigma_k*K(x + theta*(x - x_old)) - Sigma_k*data)/(1.0 + Sigma_k)
+ # ~ q = (q + Sigma_k*K(x + theta*(x - x_old)) - Sigma_k*data)/(1.0 + Sigma_k)
self.proj(d_tmp, d_sino)
# TODO: this in less instructions
d_sino -= self.d_data
diff --git a/src/silx/opencl/sinofilter.py b/src/silx/opencl/sinofilter.py
index 890267e..fc447de 100644
--- a/src/silx/opencl/sinofilter.py
+++ b/src/silx/opencl/sinofilter.py
@@ -1,7 +1,7 @@
#!/usr/bin/env python
# /*##########################################################################
#
-# Copyright (c) 2016-2019 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -38,8 +38,6 @@ from .processing import OpenclProcessing
from ..math.fft.clfft import CLFFT, __have_clfft__
from ..math.fft.npfft import NPFFT
from ..image.tomography import generate_powers, get_next_power, compute_fourier_filter
-from ..utils.deprecation import deprecated
-
class SinoFilter(OpenclProcessing):
@@ -50,12 +48,21 @@ class SinoFilter(OpenclProcessing):
- In 2D: (n_a, d_x): n_a filterings (1D FFT of size d_x)
- In 3D: (n_z, n_a, d_x): n_z*n_a filterings (1D FFT of size d_x)
"""
+
kernel_files = ["array_utils.cl"]
powers = generate_powers()
- def __init__(self, sino_shape, filter_name=None, ctx=None,
- devicetype="all", platformid=None, deviceid=None,
- profile=False, extra_options=None):
+ def __init__(
+ self,
+ sino_shape,
+ filter_name=None,
+ ctx=None,
+ devicetype="all",
+ platformid=None,
+ deviceid=None,
+ profile=False,
+ extra_options=None,
+ ):
"""Constructor of OpenCL FFT-Convolve.
:param sino_shape: shape of the sinogram.
@@ -72,9 +79,14 @@ class SinoFilter(OpenclProcessing):
:param dict extra_options: Advanced extra options.
Current options are: cutoff, use_numpy_fft
"""
- OpenclProcessing.__init__(self, ctx=ctx, devicetype=devicetype,
- platformid=platformid, deviceid=deviceid,
- profile=profile)
+ OpenclProcessing.__init__(
+ self,
+ ctx=ctx,
+ devicetype=devicetype,
+ platformid=platformid,
+ deviceid=deviceid,
+ profile=profile,
+ )
self._init_extra_options(extra_options)
self._calculate_shapes(sino_shape)
@@ -92,8 +104,9 @@ class SinoFilter(OpenclProcessing):
if self.ndim == 2:
n_angles, dwidth = sino_shape
else:
- raise ValueError("Invalid sinogram number of dimensions: "
- "expected 2 dimensions")
+ raise ValueError(
+ "Invalid sinogram number of dimensions: " "expected 2 dimensions"
+ )
self.sino_shape = sino_shape
self.n_angles = n_angles
self.dwidth = dwidth
@@ -110,14 +123,14 @@ class SinoFilter(OpenclProcessing):
Current options are: cutoff,
"""
self.extra_options = {
- "cutoff": 1.,
+ "cutoff": 1.0,
"use_numpy_fft": False,
}
if extra_options is not None:
self.extra_options.update(extra_options)
def _init_fft(self):
- if __have_clfft__ and not(self.extra_options["use_numpy_fft"]):
+ if __have_clfft__ and not (self.extra_options["use_numpy_fft"]):
self.fft_backend = "opencl"
self.fft = CLFFT(
self.sino_padded_shape,
@@ -127,17 +140,22 @@ class SinoFilter(OpenclProcessing):
)
else:
self.fft_backend = "numpy"
- print("The gpyfft module was not found. The Fourier transforms "
- "will be done on CPU. For more performances, it is advised "
- "to install gpyfft.""")
+ print(
+ "The gpyfft module was not found. The Fourier transforms "
+ "will be done on CPU. For more performances, it is advised "
+ "to install gpyfft."
+ ""
+ )
self.fft = NPFFT(
template=np.zeros(self.sino_padded_shape, "f"),
axes=(-1,),
)
def _allocate_memory(self):
- self.d_filter_f = parray.zeros(self.queue, (self.sino_f_shape[-1],), np.complex64)
- self.is_cpu = (self.device.type == "CPU")
+ self.d_filter_f = parray.zeros(
+ self.queue, (self.sino_f_shape[-1],), np.complex64
+ )
+ self.is_cpu = self.device.type == "CPU"
# These are already allocated by FFT() if using the opencl backend
if self.fft_backend == "opencl":
self.d_sino_padded = self.fft.data_in
@@ -160,7 +178,9 @@ class SinoFilter(OpenclProcessing):
self.dwidth_padded,
self.filter_name,
cutoff=self.extra_options["cutoff"],
- )[:self.dwidth_padded // 2 + 1] # R2C
+ )[
+ : self.dwidth_padded // 2 + 1
+ ] # R2C
self.set_filter(filter_f, normalize=True)
def set_filter(self, h_filt, normalize=True):
@@ -181,7 +201,7 @@ class SinoFilter(OpenclProcessing):
"""
% (self.sino_f_shape[-1], h_filt.size)
)
- if not(np.iscomplexobj(h_filt)):
+ if not (np.iscomplexobj(h_filt)):
print("Warning: expected a complex Fourier filter")
self.filter_f = h_filt
if normalize:
@@ -192,24 +212,27 @@ class SinoFilter(OpenclProcessing):
def _init_kernels(self):
OpenclProcessing.compile_kernels(self, self.kernel_files)
h, w = self.d_sino_f.shape
- self.mult_kern_args = (self.queue, (int(w), (int(h))), None,
- self.d_sino_f.data,
- self.d_filter_f.data,
- np.int32(w),
- np.int32(h))
+ self.mult_kern_args = (
+ self.queue,
+ (int(w), (int(h))),
+ None,
+ self.d_sino_f.data,
+ self.d_filter_f.data,
+ np.int32(w),
+ np.int32(h),
+ )
def check_array(self, arr):
if arr.dtype != np.float32:
raise ValueError("Expected data type = numpy.float32")
if arr.shape != self.sino_shape:
- raise ValueError("Expected sinogram shape %s, got %s" %
- (self.sino_shape, arr.shape))
- if not(isinstance(arr, np.ndarray) or isinstance(arr, parray.Array)):
- raise ValueError("Expected either numpy.ndarray or "
- "pyopencl.array.Array")
-
- def copy2d(self, dst, src, transfer_shape, dst_offset=(0, 0),
- src_offset=(0, 0)):
+ raise ValueError(
+ "Expected sinogram shape %s, got %s" % (self.sino_shape, arr.shape)
+ )
+ if not (isinstance(arr, np.ndarray) or isinstance(arr, parray.Array)):
+ raise ValueError("Expected either numpy.ndarray or " "pyopencl.array.Array")
+
+ def copy2d(self, dst, src, transfer_shape, dst_offset=(0, 0), src_offset=(0, 0)):
"""
:param dst:
@@ -219,18 +242,23 @@ class SinoFilter(OpenclProcessing):
:param src_offset:
"""
shape = tuple(int(i) for i in transfer_shape[::-1])
- ev = self.kernels.cpy2d(self.queue, shape, None,
- dst.data,
- src.data,
- np.int32(dst.shape[1]),
- np.int32(src.shape[1]),
- np.int32(dst_offset),
- np.int32(src_offset),
- np.int32(transfer_shape[::-1]))
+ ev = self.kernels.cpy2d(
+ self.queue,
+ shape,
+ None,
+ dst.data,
+ src.data,
+ np.int32(dst.shape[1]),
+ np.int32(src.shape[1]),
+ np.int32(dst_offset),
+ np.int32(src_offset),
+ np.int32(transfer_shape[::-1]),
+ )
ev.wait()
- def copy2d_host(self, dst, src, transfer_shape, dst_offset=(0, 0),
- src_offset=(0, 0)):
+ def copy2d_host(
+ self, dst, src, transfer_shape, dst_offset=(0, 0), src_offset=(0, 0)
+ ):
"""
:param dst:
@@ -242,7 +270,9 @@ class SinoFilter(OpenclProcessing):
s = transfer_shape
do = dst_offset
so = src_offset
- dst[do[0]:do[0] + s[0], do[1]:do[1] + s[1]] = src[so[0]:so[0] + s[0], so[1]:so[1] + s[1]]
+ dst[do[0] : do[0] + s[0], do[1] : do[1] + s[1]] = src[
+ so[0] : so[0] + s[0], so[1] : so[1] + s[1]
+ ]
def _prepare_input_sino(self, sino):
"""
@@ -269,7 +299,7 @@ class SinoFilter(OpenclProcessing):
self.d_sino_padded.finish() # should not be required here
else:
# Numpy backend: FFT/mult/IFFT are done on host.
- if not(isinstance(sino, np.ndarray)):
+ if not (isinstance(sino, np.ndarray)):
# Numpy backend + pyopencl input: need to copy D->H
self.tmp_sino_host[:] = sino[:]
h_sino_ref = self.tmp_sino_host
@@ -293,9 +323,11 @@ class SinoFilter(OpenclProcessing):
# As pyopencl does not support rectangular copies, we first have
# to call a kernel doing rectangular copy D->D, then do a copy
# D->H.
- self.copy2d(dst=self.tmp_sino_device,
- src=self.d_sino_padded,
- transfer_shape=self.sino_shape)
+ self.copy2d(
+ dst=self.tmp_sino_device,
+ src=self.d_sino_padded,
+ transfer_shape=self.sino_shape,
+ )
if self.is_cpu:
self.tmp_sino_device.finish() # should not be required here
res[:] = self.tmp_sino_device.get()[:]
@@ -306,11 +338,13 @@ class SinoFilter(OpenclProcessing):
if self.is_cpu:
res.finish() # should not be required here
else:
- if not(isinstance(res, np.ndarray)):
+ if not (isinstance(res, np.ndarray)):
# Numpy backend + pyopencl output: rect copy H->H + copy H->D
- self.copy2d_host(dst=self.tmp_sino_host,
- src=self.d_sino_padded,
- transfer_shape=self.sino_shape)
+ self.copy2d_host(
+ dst=self.tmp_sino_host,
+ src=self.d_sino_padded,
+ transfer_shape=self.sino_shape,
+ )
res[:] = self.tmp_sino_host[:]
else:
# Numpy backend + numpy output: rect copy H->H
@@ -331,9 +365,7 @@ class SinoFilter(OpenclProcessing):
def _multiply_fourier(self):
if self.fft_backend == "opencl":
# Everything is on device. Call the multiplication kernel.
- ev = self.kernels.inplace_complex_mul_2Dby1D(
- *self.mult_kern_args
- )
+ ev = self.kernels.inplace_complex_mul_2Dby1D(*self.mult_kern_args)
ev.wait()
if self.is_cpu:
self.d_sino_f.finish() # should not be required here
@@ -376,57 +408,3 @@ class SinoFilter(OpenclProcessing):
# ~ return output
__call__ = filter_sino
-
-
-
-
-# -------------------
-# - Compatibility -
-# -------------------
-
-
-def nextpow2(N):
- p = 1
- while p < N:
- p *= 2
- return p
-
-
-@deprecated(replacement="Backprojection.sino_filter", since_version="0.10")
-def fourier_filter(sino, filter_=None, fft_size=None):
- """Simple np based implementation of fourier space filter.
- This function is deprecated, please use silx.opencl.sinofilter.SinoFilter.
-
- :param sino: of shape shape = (num_projs, num_bins)
- :param filter: filter function to apply in fourier space
- :fft_size: size on which perform the fft. May be larger than the sino array
- :return: filtered sinogram
- """
- assert sino.ndim == 2
- num_projs, num_bins = sino.shape
- if fft_size is None:
- fft_size = nextpow2(num_bins * 2 - 1)
- else:
- assert fft_size >= num_bins
- if fft_size == num_bins:
- sino_zeropadded = sino.astype(np.float32)
- else:
- sino_zeropadded = np.zeros((num_projs, fft_size),
- dtype=np.complex64)
- sino_zeropadded[:, :num_bins] = sino.astype(np.float32)
-
- if filter_ is None:
- h = np.zeros(fft_size, dtype=np.float32)
- L2 = fft_size // 2 + 1
- h[0] = 1 / 4.
- j = np.linspace(1, L2, L2 // 2, False)
- h[1:L2:2] = -1. / (np.pi ** 2 * j ** 2)
- h[L2:] = np.copy(h[1:L2 - 1][::-1])
- filter_ = np.fft.fft(h).astype(np.complex64)
-
- # Linear convolution
- sino_f = np.fft.fft(sino, fft_size)
- sino_f = sino_f * filter_
- sino_filtered = np.fft.ifft(sino_f)[:, :num_bins].real
-
- return np.ascontiguousarray(sino_filtered.real, dtype=np.float32)
diff --git a/src/silx/opencl/sparse.py b/src/silx/opencl/sparse.py
index 709e3c7..9baa3a0 100644
--- a/src/silx/opencl/sparse.py
+++ b/src/silx/opencl/sparse.py
@@ -35,11 +35,13 @@ from pyopencl.scan import GenericScanKernel
from pyopencl.tools import dtype_to_ctype
from .common import pyopencl as cl
from .processing import OpenclProcessing, EventDescription, BufferDescription
+
mf = cl.mem_flags
CSRData = namedtuple("CSRData", ["data", "indices", "indptr"])
+
def tuple_to_csrdata(arrs):
"""
Converts a 3-tuple to a CSRData namedtuple.
@@ -49,13 +51,23 @@ def tuple_to_csrdata(arrs):
return CSRData(data=arrs[0], indices=arrs[1], indptr=arrs[2])
-
class CSR(OpenclProcessing):
kernel_files = ["sparse.cl"]
- def __init__(self, shape, dtype="f", max_nnz=None, idx_dtype=numpy.int32,
- ctx=None, devicetype="all", platformid=None, deviceid=None,
- block_size=None, memory=None, profile=False):
+ def __init__(
+ self,
+ shape,
+ dtype="f",
+ max_nnz=None,
+ idx_dtype=numpy.int32,
+ ctx=None,
+ devicetype="all",
+ platformid=None,
+ deviceid=None,
+ block_size=None,
+ memory=None,
+ profile=False,
+ ):
"""
Compute Compressed Sparse Row format of an image (2D matrix).
It is designed to be compatible with scipy.sparse.csr_matrix.
@@ -77,10 +89,16 @@ class CSR(OpenclProcessing):
for information on the other parameters.
"""
- OpenclProcessing.__init__(self, ctx=ctx, devicetype=devicetype,
- platformid=platformid, deviceid=deviceid,
- block_size=block_size, memory=memory,
- profile=profile)
+ OpenclProcessing.__init__(
+ self,
+ ctx=ctx,
+ devicetype=devicetype,
+ platformid=platformid,
+ deviceid=deviceid,
+ block_size=block_size,
+ memory=memory,
+ profile=profile,
+ )
self._set_parameters(shape, dtype, max_nnz, idx_dtype)
self._allocate_memory()
self._setup_kernels()
@@ -93,23 +111,23 @@ class CSR(OpenclProcessing):
self.shape = shape
self.size = numpy.prod(shape)
self._set_idx_dtype(idx_dtype)
- assert len(shape) == 2 #
+ assert len(shape) == 2 #
if max_nnz is None:
- self.max_nnz = numpy.prod(shape) # worst case
+ self.max_nnz = numpy.prod(shape) # worst case
else:
self.max_nnz = int(max_nnz)
self._set_dtype(dtype)
-
def _set_idx_dtype(self, idx_dtype):
idx_dtype = numpy.dtype(idx_dtype)
if idx_dtype.kind not in ["i", "u"]:
raise ValueError("Not an integer type: %s" % idx_dtype)
# scan value type must have size divisible by 4 bytes
if idx_dtype.itemsize % 4 != 0:
- raise ValueError("Due to an internal pyopencl limitation, idx_dtype type must have size divisible by 4 bytes")
- self.indice_dtype = idx_dtype #
-
+ raise ValueError(
+ "Due to an internal pyopencl limitation, idx_dtype type must have size divisible by 4 bytes"
+ )
+ self.indice_dtype = idx_dtype #
def _set_dtype(self, dtype):
self.dtype = numpy.dtype(dtype)
@@ -119,42 +137,44 @@ class CSR(OpenclProcessing):
self._c_zero_str = "0.0f"
elif self.dtype == numpy.dtype(numpy.float64):
self._c_zero_str = "0.0"
- else: # assuming integer
+ else: # assuming integer
self._c_zero_str = "0"
self.c_dtype = dtype_to_ctype(self.dtype)
self.idx_c_dtype = dtype_to_ctype(self.indice_dtype)
-
def _allocate_memory(self):
- self.is_cpu = (self.device.type == "CPU") # move to OpenclProcessing ?
+ self.is_cpu = self.device.type == "CPU" # move to OpenclProcessing ?
self.buffers = [
BufferDescription("array", (self.size,), self.dtype, mf.READ_ONLY),
BufferDescription("data", (self.max_nnz,), self.dtype, mf.READ_WRITE),
- BufferDescription("indices", (self.max_nnz,), self.indice_dtype, mf.READ_WRITE),
- BufferDescription("indptr", (self.shape[0]+1,), self.indice_dtype, mf.READ_WRITE),
+ BufferDescription(
+ "indices", (self.max_nnz,), self.indice_dtype, mf.READ_WRITE
+ ),
+ BufferDescription(
+ "indptr", (self.shape[0] + 1,), self.indice_dtype, mf.READ_WRITE
+ ),
]
self.allocate_buffers(use_array=True)
for arr_name in ["array", "data", "indices", "indptr"]:
setattr(self, arr_name, self.cl_mem[arr_name])
- self.cl_mem[arr_name].fill(0) # allocate_buffers() uses empty()
+ self.cl_mem[arr_name].fill(0) # allocate_buffers() uses empty()
self._old_array = self.array
self._old_data = self.data
self._old_indices = self.indices
self._old_indptr = self.indptr
-
def _setup_kernels(self):
self._setup_compaction_kernel()
self._setup_decompaction_kernel()
-
def _setup_compaction_kernel(self):
kernel_signature = str(
"__global %s *data, \
__global %s *data_compacted, \
__global %s *indices, \
__global %s* indptr \
- """ % (self.c_dtype, self.c_dtype, self.idx_c_dtype, self.idx_c_dtype)
+ "
+ "" % (self.c_dtype, self.c_dtype, self.idx_c_dtype, self.idx_c_dtype)
)
if self.dtype.kind == "f":
map_nonzero_expr = "(fabs(data[i]) > %s) ? 1 : 0" % self._c_zero_str
@@ -164,10 +184,12 @@ class CSR(OpenclProcessing):
raise ValueError("Unknown data type")
self.scan_kernel = GenericScanKernel(
- self.ctx, self.indice_dtype,
+ self.ctx,
+ self.indice_dtype,
arguments=kernel_signature,
input_expr=map_nonzero_expr,
- scan_expr="a+b", neutral="0",
+ scan_expr="a+b",
+ neutral="0",
output_statement="""
// item is the running sum of input_expr(i), i.e the cumsum of "nonzero"
if (prev_item != item) {
@@ -183,7 +205,6 @@ class CSR(OpenclProcessing):
preamble="#define GET_INDEX(i) (i % IMAGE_WIDTH)",
)
-
def _setup_decompaction_kernel(self):
OpenclProcessing.compile_kernels(
self,
@@ -192,18 +213,17 @@ class CSR(OpenclProcessing):
"-DIMAGE_WIDTH=%d" % self.shape[1],
"-DDTYPE=%s" % self.c_dtype,
"-DIDX_DTYPE=%s" % self.idx_c_dtype,
- ]
+ ],
)
device = self.ctx.devices[0]
wg_x = min(
device.max_work_group_size,
32,
- self.kernels.max_workgroup_size("densify_csr")
+ self.kernels.max_workgroup_size("densify_csr"),
)
self._decomp_wg = (wg_x, 1)
self._decomp_grid = (self._decomp_wg[0], self.shape[0])
-
# --------------------------------------------------------------------------
# -------------------------- Array utils -----------------------------------
# --------------------------------------------------------------------------
@@ -219,7 +239,6 @@ class CSR(OpenclProcessing):
assert arr.size == self.size
assert arr.dtype == self.dtype
-
# TODO handle pyopencl Buffer
def check_sparse_arrays(self, csr_data):
"""
@@ -234,12 +253,11 @@ class CSR(OpenclProcessing):
assert arr.ndim == 1
assert csr_data.data.size <= self.max_nnz
assert csr_data.indices.size <= self.max_nnz
- assert csr_data.indptr.size == self.shape[0]+1
+ assert csr_data.indptr.size == self.shape[0] + 1
assert csr_data.data.dtype == self.dtype
assert csr_data.indices.dtype == self.indice_dtype
assert csr_data.indptr.dtype == self.indice_dtype
-
def set_array(self, arr):
"""
Set the provided array as the current context 2D matrix.
@@ -259,23 +277,25 @@ class CSR(OpenclProcessing):
else:
raise ValueError("Expected pyopencl array or numpy array")
-
def set_sparse_arrays(self, csr_data):
if csr_data is None:
return
self.check_sparse_arrays(csr_data)
- for name, arr in {"data": csr_data.data, "indices": csr_data.indices, "indptr": csr_data.indptr}.items():
+ for name, arr in {
+ "data": csr_data.data,
+ "indices": csr_data.indices,
+ "indptr": csr_data.indptr,
+ }.items():
# The current array is a device array. Don't copy, use it directly
if isinstance(arr, parray.Array):
setattr(self, "_old_" + name, getattr(self, name))
setattr(self, name, arr)
# The current array is a numpy.ndarray: copy H2D
elif isinstance(arr, numpy.ndarray):
- getattr(self, name)[:arr.size] = arr[:]
+ getattr(self, name)[: arr.size] = arr[:]
else:
raise ValueError("Unsupported array type: %s" % type(arr))
-
def _recover_arrays_references(self):
"""
Recover the previous arrays references, and return the references of the
@@ -290,7 +310,6 @@ class CSR(OpenclProcessing):
setattr(self, name, getattr(self, "_old_" + name))
return array, (data, indices, indptr)
-
def get_sparse_arrays(self, output):
"""
Get the 2D dense array of the current context.
@@ -311,7 +330,6 @@ class CSR(OpenclProcessing):
res = output
return res
-
def get_array(self, output):
if output is None:
res = self.array.get().reshape(self.shape)
@@ -341,7 +359,7 @@ class CSR(OpenclProcessing):
self.indices,
self.indptr,
)
- #~ evt.wait()
+ # ~ evt.wait()
self.profile_add(evt, "sparsification kernel")
res = self.get_sparse_arrays(output)
self._recover_arrays_references()
@@ -352,9 +370,7 @@ class CSR(OpenclProcessing):
# --------------------------------------------------------------------------
def densify(self, data, indices, indptr, output=None):
- self.set_sparse_arrays(
- CSRData(data=data, indices=indices, indptr=indptr)
- )
+ self.set_sparse_arrays(CSRData(data=data, indices=indices, indptr=indptr))
self.set_array(output)
evt = self.kernels.densify_csr(
self.queue,
@@ -366,9 +382,8 @@ class CSR(OpenclProcessing):
self.array.data,
numpy.int32(self.shape[0]),
)
- #~ evt.wait()
+ # ~ evt.wait()
self.profile_add(evt, "desparsification kernel")
res = self.get_array(output)
self._recover_arrays_references()
return res
-
diff --git a/src/silx/opencl/statistics.py b/src/silx/opencl/statistics.py
index 9197dd1..26d23e6 100644
--- a/src/silx/opencl/statistics.py
+++ b/src/silx/opencl/statistics.py
@@ -2,7 +2,7 @@
# Project: SILX
# https://github.com/silx-kit/silx
#
-# Copyright (C) 2012-2019 European Synchrotron Radiation Facility, Grenoble, France
+# Copyright (C) 2012-2023 European Synchrotron Radiation Facility, Grenoble, France
#
# Principal author: Jérôme Kieffer (Jerome.Kieffer@ESRF.eu)
#
@@ -36,7 +36,7 @@ __contact__ = "jerome.kieffer@esrf.fr"
import logging
import numpy
-from collections import OrderedDict, namedtuple
+from collections import namedtuple
from math import sqrt
from .common import pyopencl
@@ -46,6 +46,7 @@ from .utils import concatenate_cl_kernel
if pyopencl:
mf = pyopencl.mem_flags
from pyopencl.reduction import ReductionKernel
+
try:
from pyopencl import cltypes
except ImportError:
@@ -58,8 +59,9 @@ else:
raise ImportError("pyopencl is not installed")
logger = logging.getLogger(__name__)
-StatResults = namedtuple("StatResults", ["min", "max", "cnt", "sum", "mean",
- "var", "std"])
+StatResults = namedtuple(
+ "StatResults", ["min", "max", "cnt", "sum", "mean", "var", "std"]
+)
zero8 = "(float8)(FLT_MAX, -FLT_MAX, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f)"
# min max cnt cnt_e sum sum_e var var_e
@@ -81,33 +83,52 @@ class Statistics(OpenclProcessing):
Switch on profiling to be able to profile at the kernel level,
store profiling elements (makes code slightly slower)
"""
+
buffers = [
BufferDescription("raw", 1, numpy.float32, mf.READ_ONLY),
BufferDescription("converted", 1, numpy.float32, mf.READ_WRITE),
]
kernel_files = ["preprocess.cl"]
- mapping = {numpy.int8: "s8_to_float",
- numpy.uint8: "u8_to_float",
- numpy.int16: "s16_to_float",
- numpy.uint16: "u16_to_float",
- numpy.uint32: "u32_to_float",
- numpy.int32: "s32_to_float"}
-
- def __init__(self, size=None, dtype=None, template=None,
- ctx=None, devicetype="all", platformid=None, deviceid=None,
- block_size=None, profile=False
- ):
- OpenclProcessing.__init__(self, ctx=ctx, devicetype=devicetype,
- platformid=platformid, deviceid=deviceid,
- block_size=block_size, profile=profile)
+ mapping = {
+ numpy.int8: "s8_to_float",
+ numpy.uint8: "u8_to_float",
+ numpy.int16: "s16_to_float",
+ numpy.uint16: "u16_to_float",
+ numpy.uint32: "u32_to_float",
+ numpy.int32: "s32_to_float",
+ }
+
+ def __init__(
+ self,
+ size=None,
+ dtype=None,
+ template=None,
+ ctx=None,
+ devicetype="all",
+ platformid=None,
+ deviceid=None,
+ block_size=None,
+ profile=False,
+ ):
+ OpenclProcessing.__init__(
+ self,
+ ctx=ctx,
+ devicetype=devicetype,
+ platformid=platformid,
+ deviceid=deviceid,
+ block_size=block_size,
+ profile=profile,
+ )
self.size = size
self.dtype = dtype
if template is not None:
self.size = template.size
self.dtype = template.dtype
- self.buffers = [BufferDescription(i.name, i.size * self.size, i.dtype, i.flags)
- for i in self.__class__.buffers]
+ self.buffers = [
+ BufferDescription(i.name, i.size * self.size, i.dtype, i.flags)
+ for i in self.__class__.buffers
+ ]
self.allocate_buffers(use_array=True)
self.compile_kernels()
@@ -116,43 +137,54 @@ class Statistics(OpenclProcessing):
def set_kernel_arguments(self):
"""Parametrize all kernel arguments"""
for val in self.mapping.values():
- self.cl_kernel_args[val] = OrderedDict(((i, self.cl_mem[i]) for i in ("raw", "converted")))
+ self.cl_kernel_args[val] = dict(
+ ((i, self.cl_mem[i]) for i in ("raw", "converted"))
+ )
def compile_kernels(self):
"""Compile the kernel"""
- OpenclProcessing.compile_kernels(self,
- self.kernel_files,
- "-D NIMAGE=%i" % self.size)
+ OpenclProcessing.compile_kernels(
+ self, self.kernel_files, "-D NIMAGE=%i" % self.size
+ )
compiler_options = self.get_compiler_options(x87_volatile=True)
src = concatenate_cl_kernel(("doubleword.cl", "statistics.cl"))
- self.reduction_comp = ReductionKernel(self.ctx,
- dtype_out=float8,
- neutral=zero8,
- map_expr="map_statistics(data, i)",
- reduce_expr="reduce_statistics(a,b)",
- arguments="__global float *data",
- preamble=src,
- options=compiler_options)
- self.reduction_simple = ReductionKernel(self.ctx,
- dtype_out=float8,
- neutral=zero8,
- map_expr="map_statistics(data, i)",
- reduce_expr="reduce_statistics_simple(a,b)",
- arguments="__global float *data",
- preamble=src,
- options=compiler_options)
+ self.reduction_comp = ReductionKernel(
+ self.ctx,
+ dtype_out=float8,
+ neutral=zero8,
+ map_expr="map_statistics(data, i)",
+ reduce_expr="reduce_statistics(a,b)",
+ arguments="__global float *data",
+ preamble=src,
+ options=compiler_options,
+ )
+ self.reduction_simple = ReductionKernel(
+ self.ctx,
+ dtype_out=float8,
+ neutral=zero8,
+ map_expr="map_statistics(data, i)",
+ reduce_expr="reduce_statistics_simple(a,b)",
+ arguments="__global float *data",
+ preamble=src,
+ options=compiler_options,
+ )
if "cl_khr_fp64" in self.device.extensions:
- self.reduction_double = ReductionKernel(self.ctx,
- dtype_out=float8,
- neutral=zero8,
- map_expr="map_statistics(data, i)",
- reduce_expr="reduce_statistics_double(a,b)",
- arguments="__global float *data",
- preamble=src,
- options=compiler_options)
+ self.reduction_double = ReductionKernel(
+ self.ctx,
+ dtype_out=float8,
+ neutral=zero8,
+ map_expr="map_statistics(data, i)",
+ reduce_expr="reduce_statistics_double(a,b)",
+ arguments="__global float *data",
+ preamble=src,
+ options=compiler_options,
+ )
else:
- logger.info("Device %s does not support double-precision arithmetics, fall-back on compensated one", self.device)
+ logger.info(
+ "Device %s does not support double-precision arithmetics, fall-back on compensated one",
+ self.device,
+ )
self.reduction_double = self.reduction_comp
def send_buffer(self, data, dest):
@@ -167,23 +199,27 @@ class Statistics(OpenclProcessing):
dest_type = numpy.dtype([i.dtype for i in self.buffers if i.name == dest][0])
events = []
if (data.dtype == dest_type) or (data.dtype.itemsize > dest_type.itemsize):
- copy_image = pyopencl.enqueue_copy(self.queue,
- self.cl_mem[dest].data,
- numpy.ascontiguousarray(data, dest_type))
+ copy_image = pyopencl.enqueue_copy(
+ self.queue,
+ self.cl_mem[dest].data,
+ numpy.ascontiguousarray(data, dest_type),
+ )
events.append(EventDescription("copy H->D %s" % dest, copy_image))
else:
- copy_image = pyopencl.enqueue_copy(self.queue,
- self.cl_mem["raw"].data,
- numpy.ascontiguousarray(data))
+ copy_image = pyopencl.enqueue_copy(
+ self.queue, self.cl_mem["raw"].data, numpy.ascontiguousarray(data)
+ )
kernel = getattr(self.program, self.mapping[data.dtype.type])
- cast_to_float = kernel(self.queue,
- (self.size,),
- None,
- self.cl_mem["raw"].data,
- self.cl_mem[dest].data)
+ cast_to_float = kernel(
+ self.queue,
+ (self.size,),
+ None,
+ self.cl_mem["raw"].data,
+ self.cl_mem[dest].data,
+ )
events += [
EventDescription("copy H->D raw", copy_image),
- EventDescription(f"cast to float {dest}", cast_to_float)
+ EventDescription(f"cast to float {dest}", cast_to_float),
]
if self.profile:
self.events += events
@@ -193,7 +229,7 @@ class Statistics(OpenclProcessing):
"""Actually calculate the statics on the data
:param numpy.ndarray data: numpy array with the image
- :param comp: use Kahan compensated arithmetics for the calculation
+ :param comp: use Kahan compensated arithmetics for the calculation
:return: Statistics named tuple
:rtype: StatResults
"""
@@ -216,9 +252,11 @@ class Statistics(OpenclProcessing):
reduction = self.reduction_double
else:
reduction = self.reduction_comp
- res_d, evt = reduction(self.cl_mem["converted"][:self.size],
- queue=self.queue,
- return_event=True)
+ res_d, evt = reduction(
+ self.cl_mem["converted"][: self.size],
+ queue=self.queue,
+ return_event=True,
+ )
events.append(EventDescription(f"statistical reduction {comp}", evt))
if self.profile:
self.events += events
@@ -229,13 +267,7 @@ class Statistics(OpenclProcessing):
sum_ = 1.0 * res_h["s4"] + res_h["s5"]
m2 = 1.0 * res_h["s6"] + res_h["s7"]
var = m2 / (count - 1.0)
- res = StatResults(min_,
- max_,
- count,
- sum_,
- sum_ / count,
- var,
- sqrt(var))
+ res = StatResults(min_, max_, count, sum_, sum_ / count, var, sqrt(var))
return res
__call__ = process
diff --git a/src/silx/opencl/test/test_addition.py b/src/silx/opencl/test/test_addition.py
index d6cf1ac..98beab4 100644
--- a/src/silx/opencl/test/test_addition.py
+++ b/src/silx/opencl/test/test_addition.py
@@ -40,16 +40,17 @@ import pytest
import unittest
from ..common import ocl, _measure_workgroup_size, query_kernel_info
+
if ocl:
import pyopencl
import pyopencl.array
from ..utils import get_opencl_code
+
logger = logging.getLogger(__name__)
@unittest.skipUnless(ocl, "PyOpenCl is missing")
class TestAddition(unittest.TestCase):
-
@classmethod
def setUpClass(cls):
super(TestAddition, cls).setUpClass()
@@ -58,8 +59,9 @@ class TestAddition(unittest.TestCase):
if logger.getEffectiveLevel() <= logging.INFO:
cls.PROFILE = True
cls.queue = pyopencl.CommandQueue(
- cls.ctx,
- properties=pyopencl.command_queue_properties.PROFILING_ENABLE)
+ cls.ctx,
+ properties=pyopencl.command_queue_properties.PROFILING_ENABLE,
+ )
else:
cls.PROFILE = False
cls.queue = pyopencl.CommandQueue(cls.ctx)
@@ -68,7 +70,10 @@ class TestAddition(unittest.TestCase):
@classmethod
def tearDownClass(cls):
super(TestAddition, cls).tearDownClass()
- print("Maximum valid workgroup size %s on device %s" % (cls.max_valid_wg, cls.ctx.devices[0]))
+ print(
+ "Maximum valid workgroup size %s on device %s"
+ % (cls.max_valid_wg, cls.ctx.devices[0])
+ )
cls.ctx = None
cls.queue = None
@@ -95,11 +100,20 @@ class TestAddition(unittest.TestCase):
d_array_result = pyopencl.array.empty_like(self.d_array_img)
wg = 1 << i
try:
- evt = self.program.addition(self.queue, (self.shape,), (wg,),
- self.d_array_img.data, self.d_array_5.data, d_array_result.data, numpy.int32(self.shape))
+ evt = self.program.addition(
+ self.queue,
+ (self.shape,),
+ (wg,),
+ self.d_array_img.data,
+ self.d_array_5.data,
+ d_array_result.data,
+ numpy.int32(self.shape),
+ )
evt.wait()
except Exception as error:
- max_valid_wg = self.program.addition.get_work_group_info(pyopencl.kernel_work_group_info.WORK_GROUP_SIZE, self.ctx.devices[0])
+ max_valid_wg = self.program.addition.get_work_group_info(
+ pyopencl.kernel_work_group_info.WORK_GROUP_SIZE, self.ctx.devices[0]
+ )
msg = "Error %s on WG=%s: %s" % (error, wg, max_valid_wg)
self.assertLess(max_valid_wg, wg, msg)
break
@@ -117,23 +131,39 @@ class TestAddition(unittest.TestCase):
for platform in ocl.platforms:
for did, device in enumerate(platform.devices):
meas = _measure_workgroup_size((platform.id, device.id))
- self.assertEqual(meas, device.max_work_group_size,
- "Workgroup size for %s/%s: %s == %s" % (platform, device, meas, device.max_work_group_size))
+ self.assertEqual(
+ meas,
+ device.max_work_group_size,
+ "Workgroup size for %s/%s: %s == %s"
+ % (platform, device, meas, device.max_work_group_size),
+ )
def test_query(self):
"""
tests that all devices are working properly ... lengthy and error prone
"""
- for what in ("COMPILE_WORK_GROUP_SIZE",
- "LOCAL_MEM_SIZE",
- "PREFERRED_WORK_GROUP_SIZE_MULTIPLE",
- "PRIVATE_MEM_SIZE",
- "WORK_GROUP_SIZE"):
- logger.info("%s: %s", what, query_kernel_info(program=self.program, kernel="addition", what=what))
-
- # Not all ICD work properly ....
- #self.assertEqual(3, len(query_kernel_info(program=self.program, kernel="addition", what="COMPILE_WORK_GROUP_SIZE")), "3D kernel")
-
- min_wg = query_kernel_info(program=self.program, kernel="addition", what="PREFERRED_WORK_GROUP_SIZE_MULTIPLE")
- max_wg = query_kernel_info(program=self.program, kernel="addition", what="WORK_GROUP_SIZE")
+ for what in (
+ "COMPILE_WORK_GROUP_SIZE",
+ "LOCAL_MEM_SIZE",
+ "PREFERRED_WORK_GROUP_SIZE_MULTIPLE",
+ "PRIVATE_MEM_SIZE",
+ "WORK_GROUP_SIZE",
+ ):
+ logger.info(
+ "%s: %s",
+ what,
+ query_kernel_info(program=self.program, kernel="addition", what=what),
+ )
+
+ # Not all ICD work properly ....
+ # self.assertEqual(3, len(query_kernel_info(program=self.program, kernel="addition", what="COMPILE_WORK_GROUP_SIZE")), "3D kernel")
+
+ min_wg = query_kernel_info(
+ program=self.program,
+ kernel="addition",
+ what="PREFERRED_WORK_GROUP_SIZE_MULTIPLE",
+ )
+ max_wg = query_kernel_info(
+ program=self.program, kernel="addition", what="WORK_GROUP_SIZE"
+ )
self.assertEqual(max_wg % min_wg, 0, msg="max_wg is a multiple of min_wg")
diff --git a/src/silx/opencl/test/test_array_utils.py b/src/silx/opencl/test/test_array_utils.py
index 125d323..98f8bf3 100644
--- a/src/silx/opencl/test/test_array_utils.py
+++ b/src/silx/opencl/test/test_array_utils.py
@@ -1,7 +1,7 @@
#!/usr/bin/env python
# /*##########################################################################
#
-# Copyright (c) 2016 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2022 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -30,34 +30,26 @@ __copyright__ = "2013-2017 European Synchrotron Radiation Facility, Grenoble, Fr
__date__ = "14/06/2017"
-import time
import logging
import numpy as np
import unittest
+
try:
import mako
except ImportError:
mako = None
from ..common import ocl
+
if ocl:
import pyopencl as cl
import pyopencl.array as parray
- from .. import linalg
from ..utils import get_opencl_code
-from silx.test.utils import utilstest
logger = logging.getLogger(__name__)
-try:
- from scipy.ndimage.filters import laplace
- _has_scipy = True
-except ImportError:
- _has_scipy = False
-
@unittest.skipUnless(ocl and mako, "PyOpenCl is missing")
class TestCpy2d(unittest.TestCase):
-
def setUp(self):
if ocl is None:
return
@@ -65,8 +57,8 @@ class TestCpy2d(unittest.TestCase):
if logger.getEffectiveLevel() <= logging.INFO:
self.PROFILE = True
self.queue = cl.CommandQueue(
- self.ctx,
- properties=cl.command_queue_properties.PROFILING_ENABLE)
+ self.ctx, properties=cl.command_queue_properties.PROFILING_ENABLE
+ )
else:
self.PROFILE = False
self.queue = cl.CommandQueue(self.ctx)
@@ -93,8 +85,12 @@ class TestCpy2d(unittest.TestCase):
self.offset1 = (offset1_y, offset1_x)
self.offset2 = (offset2_y, offset2_x)
# Compute the size of the rectangle to transfer
- size_y = np.random.randint(2, high=min(self.shape1[0], self.shape2[0]) - max(offset1_y, offset2_y) + 1)
- size_x = np.random.randint(2, high=min(self.shape1[1], self.shape2[1]) - max(offset1_x, offset2_x) + 1)
+ size_y = np.random.randint(
+ 2, high=min(self.shape1[0], self.shape2[0]) - max(offset1_y, offset2_y) + 1
+ )
+ size_x = np.random.randint(
+ 2, high=min(self.shape1[1], self.shape2[1]) - max(offset1_x, offset2_x) + 1
+ )
self.transfer_shape = (size_y, size_x)
def tearDown(self):
@@ -110,7 +106,9 @@ class TestCpy2d(unittest.TestCase):
def compare(self, result, reference):
errmax = np.max(np.abs(result - reference))
logger.info("Max error = %e" % (errmax))
- self.assertTrue(errmax == 0, str("Max error is too high"))#. PRNG state was %s" % str(self.prng_state)))
+ self.assertTrue(
+ errmax == 0, str("Max error is too high")
+ ) # . PRNG state was %s" % str(self.prng_state)))
@unittest.skipUnless(ocl and mako, "pyopencl is missing")
def test_cpy2d(self):
@@ -121,18 +119,26 @@ class TestCpy2d(unittest.TestCase):
o1 = self.offset1
o2 = self.offset2
T = self.transfer_shape
- logger.info("""Testing D->D rectangular copy with (N1_y, N1_x) = %s,
+ logger.info(
+ """Testing D->D rectangular copy with (N1_y, N1_x) = %s,
(N2_y, N2_x) = %s:
- array2[%d:%d, %d:%d] = array1[%d:%d, %d:%d]""" %
- (
- str(self.shape1), str(self.shape2),
- o2[0], o2[0] + T[0],
- o2[1], o2[1] + T[1],
- o1[0], o1[0] + T[0],
- o1[1], o1[1] + T[1]
- )
- )
- self.array2[o2[0]:o2[0] + T[0], o2[1]:o2[1] + T[1]] = self.array1[o1[0]:o1[0] + T[0], o1[1]:o1[1] + T[1]]
+ array2[%d:%d, %d:%d] = array1[%d:%d, %d:%d]"""
+ % (
+ str(self.shape1),
+ str(self.shape2),
+ o2[0],
+ o2[0] + T[0],
+ o2[1],
+ o2[1] + T[1],
+ o1[0],
+ o1[0] + T[0],
+ o1[1],
+ o1[1] + T[1],
+ )
+ )
+ self.array2[o2[0] : o2[0] + T[0], o2[1] : o2[1] + T[1]] = self.array1[
+ o1[0] : o1[0] + T[0], o1[1] : o1[1] + T[1]
+ ]
kernel_args = (
self.d_array2.data,
self.d_array1.data,
@@ -140,7 +146,7 @@ class TestCpy2d(unittest.TestCase):
np.int32(self.shape1[1]),
np.int32(self.offset2[::-1]),
np.int32(self.offset1[::-1]),
- np.int32(self.transfer_shape[::-1])
+ np.int32(self.transfer_shape[::-1]),
)
wg = None
ndrange = self.transfer_shape[::-1]
diff --git a/src/silx/opencl/test/test_backprojection.py b/src/silx/opencl/test/test_backprojection.py
index 501cf2f..b08c972 100644
--- a/src/silx/opencl/test/test_backprojection.py
+++ b/src/silx/opencl/test/test_backprojection.py
@@ -35,11 +35,13 @@ import logging
import numpy as np
import unittest
from math import pi
+
try:
import mako
except ImportError:
mako = None
from ..common import ocl
+
if ocl:
from .. import backprojection
from ...image.tomography import compute_fourier_filter
@@ -56,7 +58,7 @@ def generate_coords(img_shp, center=None):
l_r, l_c = float(img_shp[0]), float(img_shp[1])
R, C = np.mgrid[:l_r, :l_c]
if center is None:
- center0, center1 = l_r / 2., l_c / 2.
+ center0, center1 = l_r / 2.0, l_c / 2.0
else:
center0, center1 = center
R = R + 0.5 - center0
@@ -72,7 +74,7 @@ def clip_circle(img, center=None, radius=None):
M = R * R + C * C
res = np.zeros_like(img)
if radius is None:
- radius = img.shape[0] / 2. - 1
+ radius = img.shape[0] / 2.0 - 1
mask = M < radius * radius
res[mask] = img[mask]
return res
@@ -80,20 +82,21 @@ def clip_circle(img, center=None, radius=None):
@unittest.skipUnless(ocl and mako, "PyOpenCl is missing")
class TestFBP(unittest.TestCase):
-
def setUp(self):
if ocl is None:
return
self.getfiles()
self.fbp = backprojection.Backprojection(self.sino.shape, profile=True)
if self.fbp.compiletime_workgroup_size < 16 * 16:
- self.skipTest("Current implementation of OpenCL backprojection is "
- "not supported on this platform yet")
+ self.skipTest(
+ "Current implementation of OpenCL backprojection is "
+ "not supported on this platform yet"
+ )
# Astra does not use the same backprojector implementation.
# Therefore, we cannot expect results to be the "same" (up to float32
# numerical error)
self.tol = 5e-2
- if not(self.fbp._use_textures) or self.fbp.device.type == "CPU":
+ if not (self.fbp._use_textures) or self.fbp.device.type == "CPU":
# Precision is less when using CPU
# (either CPU textures or "manual" linear interpolation)
self.tol *= 2
@@ -130,8 +133,12 @@ class TestFBP(unittest.TestCase):
ref_clipped = clip_circle(self.reference_rec)
delta = abs(res_clipped - ref_clipped)
bad = delta > 1
- logger.debug("Absolute difference: %s with %s outlier pixels out of %s"
- "", delta.max(), bad.sum(), np.prod(bad.shape))
+ logger.debug(
+ "Absolute difference: %s with %s outlier pixels out of %s" "",
+ delta.max(),
+ bad.sum(),
+ np.prod(bad.shape),
+ )
return delta.max()
@unittest.skipUnless(ocl and mako, "pyopencl is missing")
@@ -157,17 +164,14 @@ class TestFBP(unittest.TestCase):
for i in range(10):
res = self.fbp.filtered_backprojection(self.sino)
errmax = np.max(np.abs(res - res0))
- self.assertTrue(errmax < 1.e-6, "Max error is too high")
+ self.assertTrue(errmax < 1.0e-6, "Max error is too high")
@unittest.skipUnless(ocl and mako, "pyopencl is missing")
def test_fbp_filters(self):
"""
Test the different available filters of silx FBP.
"""
- avail_filters = [
- "ramlak", "shepp-logan", "cosine", "hamming",
- "hann"
- ]
+ avail_filters = ["ramlak", "shepp-logan", "cosine", "hamming", "hann"]
# Create a Dirac delta function at a single angle view.
# As the filters are radially invarant:
# - backprojection yields an image where each line is a Dirac.
@@ -176,7 +180,7 @@ class TestFBP(unittest.TestCase):
# test will also ensure that backprojection behaves well.
dirac = np.zeros_like(self.sino)
na, dw = dirac.shape
- dirac[0, dw//2] = na / pi * 2
+ dirac[0, dw // 2] = na / pi * 2
for filter_name in avail_filters:
B = backprojection.Backprojection(dirac.shape, filter_name=filter_name)
@@ -184,17 +188,15 @@ class TestFBP(unittest.TestCase):
# Check that radial invariance is kept
std0 = np.max(np.abs(np.std(r, axis=0)))
self.assertTrue(
- std0 < 5.e-6,
- "Something wrong with FBP(filter=%s)" % filter_name
+ std0 < 5.0e-6, "Something wrong with FBP(filter=%s)" % filter_name
)
# Check that the filter is retrieved
- r_f = np.fft.fft(np.fft.fftshift(r[0])).real / 2. # filter factor
+ r_f = np.fft.fft(np.fft.fftshift(r[0])).real / 2.0 # filter factor
ref_filter_f = compute_fourier_filter(dw, filter_name)
errmax = np.max(np.abs(r_f - ref_filter_f))
logger.info("FBP filter %s: max error=%e" % (filter_name, errmax))
self.assertTrue(
- errmax < 1.e-3,
- "Something wrong with FBP(filter=%s)" % filter_name
+ errmax < 1.0e-3, "Something wrong with FBP(filter=%s)" % filter_name
)
@unittest.skipUnless(ocl and mako, "pyopencl is missing")
@@ -202,13 +204,14 @@ class TestFBP(unittest.TestCase):
# Generate a 513-sinogram.
# The padded width will be nextpow(513*2).
# silx [0.10, 0.10.1] will give 1029, which makes R2C transform fail.
- sino = np.pad(self.sino, ((0, 0), (1, 0)), mode='edge')
- B = backprojection.Backprojection(sino.shape, axis_position=self.fbp.axis_pos+1)
+ sino = np.pad(self.sino, ((0, 0), (1, 0)), mode="edge")
+ B = backprojection.Backprojection(
+ sino.shape, axis_position=self.fbp.axis_pos + 1
+ )
res = B(sino)
# Compare with self.reference_rec. Tolerance is high as backprojector
# is not fully shift-invariant.
errmax = np.max(np.abs(clip_circle(res[1:, 1:] - self.reference_rec)))
self.assertLess(
- errmax, 1.e-1,
- "Something wrong with FBP on odd-sized sinogram"
+ errmax, 1.0e-1, "Something wrong with FBP on odd-sized sinogram"
)
diff --git a/src/silx/opencl/test/test_convolution.py b/src/silx/opencl/test/test_convolution.py
index e38a36a..86716f4 100644
--- a/src/silx/opencl/test/test_convolution.py
+++ b/src/silx/opencl/test/test_convolution.py
@@ -1,7 +1,7 @@
#!/usr/bin/env python
# /*##########################################################################
#
-# Copyright (c) 2019 European Synchrotron Radiation Facility
+# Copyright (c) 2019-2022 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -41,7 +41,11 @@ from silx.image.utils import gaussian_kernel
try:
from scipy.ndimage import convolve, convolve1d
- from scipy.misc import ascent
+
+ try:
+ from scipy.misc import ascent
+ except:
+ from scipy.datasets import ascent
scipy_convolve = convolve
scipy_convolve1d = convolve1d
@@ -58,7 +62,6 @@ logger = logging.getLogger(__name__)
class ConvolutionData:
-
def __init__(self, param):
self.param = param
self.mode = param["boundary_handling"]
@@ -204,7 +207,7 @@ def convolution_data_params():
)
params = []
for boundary_handling, use_texture, input_dev, output_dev in param_vals:
- param={
+ param = {
"boundary_handling": boundary_handling,
"input_on_device": input_dev,
"output_on_device": output_dev,
@@ -236,26 +239,31 @@ def convolution_data(request):
def test_1D(convolution_data):
convolution_data.template_test("test_1D")
+
@pytest.mark.skipif(ocl is None, reason="OpenCL is missing")
@pytest.mark.skipif(scipy_convolve is None, reason="scipy is missing")
def test_separable_2D(convolution_data):
convolution_data.template_test("test_separable_2D")
+
@pytest.mark.skipif(ocl is None, reason="OpenCL is missing")
@pytest.mark.skipif(scipy_convolve is None, reason="scipy is missing")
def test_separable_3D(convolution_data):
convolution_data.template_test("test_separable_3D")
+
@pytest.mark.skipif(ocl is None, reason="OpenCL is missing")
@pytest.mark.skipif(scipy_convolve is None, reason="scipy is missing")
def test_nonseparable_2D(convolution_data):
convolution_data.template_test("test_nonseparable_2D")
+
@pytest.mark.skipif(ocl is None, reason="OpenCL is missing")
@pytest.mark.skipif(scipy_convolve is None, reason="scipy is missing")
def test_nonseparable_3D(convolution_data):
convolution_data.template_test("test_nonseparable_3D")
+
@pytest.mark.skipif(ocl is None, reason="OpenCL is missing")
@pytest.mark.skipif(scipy_convolve is None, reason="scipy is missing")
def test_batched_2D(convolution_data):
diff --git a/src/silx/opencl/test/test_doubleword.py b/src/silx/opencl/test/test_doubleword.py
index 8ab594d..493d8c8 100644
--- a/src/silx/opencl/test/test_doubleword.py
+++ b/src/silx/opencl/test/test_doubleword.py
@@ -46,6 +46,7 @@ except ImportError as error:
pyopencl = None
from .. import ocl
+
if ocl is not None:
from ..utils import read_cl_file
from .. import pyopencl
@@ -65,14 +66,21 @@ class TestDoubleWord(unittest.TestCase):
@classmethod
def setUpClass(cls):
if pyopencl is None or ocl is None:
- raise unittest.SkipTest("OpenCL module (pyopencl) is not present or no device available")
+ raise unittest.SkipTest(
+ "OpenCL module (pyopencl) is not present or no device available"
+ )
cls.ctx = ocl.create_context(devicetype="GPU")
- cls.queue = pyopencl.CommandQueue(cls.ctx, properties=pyopencl.command_queue_properties.PROFILING_ENABLE)
+ cls.queue = pyopencl.CommandQueue(
+ cls.ctx, properties=pyopencl.command_queue_properties.PROFILING_ENABLE
+ )
# this is running 32 bits OpenCL woth POCL
- if (platform.machine() in ("i386", "i686", "x86_64") and (tuple.__itemsize__ == 4) and
- cls.ctx.devices[0].platform.name == 'Portable Computing Language'):
+ if (
+ platform.machine() in ("i386", "i686", "x86_64")
+ and (tuple.__itemsize__ == 4)
+ and cls.ctx.devices[0].platform.name == "Portable Computing Language"
+ ):
cls.args = "-DX87_VOLATILE=volatile"
else:
cls.args = ""
@@ -94,38 +102,68 @@ class TestDoubleWord(unittest.TestCase):
cls.doubleword = None
def test_fast_sum2(self):
- test_kernel = ElementwiseKernel(self.ctx,
- "float *a, float *b, float *res_h, float *res_l",
- "float2 tmp = fast_fp_plus_fp(a[i], b[i]); res_h[i] = tmp.s0; res_l[i] = tmp.s1",
- preamble=self.doubleword)
+ test_kernel = ElementwiseKernel(
+ self.ctx,
+ "float *a, float *b, float *res_h, float *res_l",
+ "float2 tmp = fast_fp_plus_fp(a[i], b[i]); res_h[i] = tmp.s0; res_l[i] = tmp.s1",
+ preamble=self.doubleword,
+ )
a_g = pyopencl.array.to_device(self.queue, self.ah)
b_g = pyopencl.array.to_device(self.queue, self.bl)
res_l = pyopencl.array.empty_like(a_g)
res_h = pyopencl.array.empty_like(a_g)
test_kernel(a_g, b_g, res_h, res_l)
self.assertEqual(abs(self.ah + self.bl - res_h.get()).max(), 0, "Major matches")
- self.assertGreater(abs(self.ah.astype(numpy.float64) + self.bl - res_h.get()).max(), 0, "Exact mismatches")
- self.assertEqual(abs(self.ah.astype(numpy.float64) + self.bl - (res_h.get().astype(numpy.float64) + res_l.get())).max(), 0, "Exact matches")
+ self.assertGreater(
+ abs(self.ah.astype(numpy.float64) + self.bl - res_h.get()).max(),
+ 0,
+ "Exact mismatches",
+ )
+ self.assertEqual(
+ abs(
+ self.ah.astype(numpy.float64)
+ + self.bl
+ - (res_h.get().astype(numpy.float64) + res_l.get())
+ ).max(),
+ 0,
+ "Exact matches",
+ )
def test_sum2(self):
- test_kernel = ElementwiseKernel(self.ctx,
- "float *a, float *b, float *res_h, float *res_l",
- "float2 tmp = fp_plus_fp(a[i],b[i]); res_h[i]=tmp.s0; res_l[i]=tmp.s1;",
- preamble=self.doubleword)
+ test_kernel = ElementwiseKernel(
+ self.ctx,
+ "float *a, float *b, float *res_h, float *res_l",
+ "float2 tmp = fp_plus_fp(a[i],b[i]); res_h[i]=tmp.s0; res_l[i]=tmp.s1;",
+ preamble=self.doubleword,
+ )
a_g = pyopencl.array.to_device(self.queue, self.ah)
b_g = pyopencl.array.to_device(self.queue, self.bh)
res_l = pyopencl.array.empty_like(a_g)
res_h = pyopencl.array.empty_like(a_g)
test_kernel(a_g, b_g, res_h, res_l)
self.assertEqual(abs(self.ah + self.bh - res_h.get()).max(), 0, "Major matches")
- self.assertGreater(abs(self.ah.astype(numpy.float64) + self.bh - res_h.get()).max(), 0, "Exact mismatches")
- self.assertEqual(abs(self.ah.astype(numpy.float64) + self.bh - (res_h.get().astype(numpy.float64) + res_l.get())).max(), 0, "Exact matches")
+ self.assertGreater(
+ abs(self.ah.astype(numpy.float64) + self.bh - res_h.get()).max(),
+ 0,
+ "Exact mismatches",
+ )
+ self.assertEqual(
+ abs(
+ self.ah.astype(numpy.float64)
+ + self.bh
+ - (res_h.get().astype(numpy.float64) + res_l.get())
+ ).max(),
+ 0,
+ "Exact matches",
+ )
def test_prod2(self):
- test_kernel = ElementwiseKernel(self.ctx,
- "float *a, float *b, float *res_h, float *res_l",
- "float2 tmp = fp_times_fp(a[i],b[i]); res_h[i]=tmp.s0; res_l[i]=tmp.s1;",
- preamble=self.doubleword)
+ test_kernel = ElementwiseKernel(
+ self.ctx,
+ "float *a, float *b, float *res_h, float *res_l",
+ "float2 tmp = fp_times_fp(a[i],b[i]); res_h[i]=tmp.s0; res_l[i]=tmp.s1;",
+ preamble=self.doubleword,
+ )
a_g = pyopencl.array.to_device(self.queue, self.ah)
b_g = pyopencl.array.to_device(self.queue, self.bh)
res_l = pyopencl.array.empty_like(a_g)
@@ -134,14 +172,22 @@ class TestDoubleWord(unittest.TestCase):
res_m = res_h.get()
res = res_h.get().astype(numpy.float64) + res_l.get()
self.assertEqual(abs(self.ah * self.bh - res_m).max(), 0, "Major matches")
- self.assertGreater(abs(self.ah.astype(numpy.float64) * self.bh - res_m).max(), 0, "Exact mismatches")
- self.assertEqual(abs(self.ah.astype(numpy.float64) * self.bh - res).max(), 0, "Exact matches")
+ self.assertGreater(
+ abs(self.ah.astype(numpy.float64) * self.bh - res_m).max(),
+ 0,
+ "Exact mismatches",
+ )
+ self.assertEqual(
+ abs(self.ah.astype(numpy.float64) * self.bh - res).max(), 0, "Exact matches"
+ )
def test_dw_plus_fp(self):
- test_kernel = ElementwiseKernel(self.ctx,
- "float *ah, float *al, float *b, float *res_h, float *res_l",
- "float2 tmp = dw_plus_fp((float2)(ah[i], al[i]),b[i]); res_h[i]=tmp.s0; res_l[i]=tmp.s1;",
- preamble=self.doubleword)
+ test_kernel = ElementwiseKernel(
+ self.ctx,
+ "float *ah, float *al, float *b, float *res_h, float *res_l",
+ "float2 tmp = dw_plus_fp((float2)(ah[i], al[i]),b[i]); res_h[i]=tmp.s0; res_l[i]=tmp.s1;",
+ preamble=self.doubleword,
+ )
ah_g = pyopencl.array.to_device(self.queue, self.ah)
al_g = pyopencl.array.to_device(self.queue, self.al)
b_g = pyopencl.array.to_device(self.queue, self.bh)
@@ -151,14 +197,22 @@ class TestDoubleWord(unittest.TestCase):
res_m = res_h.get()
res = res_h.get().astype(numpy.float64) + res_l.get()
self.assertLess(abs(self.a + self.bh - res_m).max(), EPS32, "Major matches")
- self.assertGreater(abs(self.a + self.bh - res_m).max(), EPS64, "Exact mismatches")
- self.assertLess(abs(self.ah.astype(numpy.float64) + self.al + self.bh - res).max(), 2 * EPS32 ** 2, "Exact matches")
+ self.assertGreater(
+ abs(self.a + self.bh - res_m).max(), EPS64, "Exact mismatches"
+ )
+ self.assertLess(
+ abs(self.ah.astype(numpy.float64) + self.al + self.bh - res).max(),
+ 2 * EPS32**2,
+ "Exact matches",
+ )
def test_dw_plus_dw(self):
- test_kernel = ElementwiseKernel(self.ctx,
- "float *ah, float *al, float *bh, float *bl, float *res_h, float *res_l",
- "float2 tmp = dw_plus_dw((float2)(ah[i], al[i]),(float2)(bh[i], bl[i])); res_h[i]=tmp.s0; res_l[i]=tmp.s1;",
- preamble=self.doubleword)
+ test_kernel = ElementwiseKernel(
+ self.ctx,
+ "float *ah, float *al, float *bh, float *bl, float *res_h, float *res_l",
+ "float2 tmp = dw_plus_dw((float2)(ah[i], al[i]),(float2)(bh[i], bl[i])); res_h[i]=tmp.s0; res_l[i]=tmp.s1;",
+ preamble=self.doubleword,
+ )
ah_g = pyopencl.array.to_device(self.queue, self.ah)
al_g = pyopencl.array.to_device(self.queue, self.al)
bh_g = pyopencl.array.to_device(self.queue, self.bh)
@@ -169,14 +223,20 @@ class TestDoubleWord(unittest.TestCase):
res_m = res_h.get()
res = res_h.get().astype(numpy.float64) + res_l.get()
self.assertLess(abs(self.a + self.b - res_m).max(), EPS32, "Major matches")
- self.assertGreater(abs(self.a + self.b - res_m).max(), EPS64, "Exact mismatches")
- self.assertLess(abs(self.a + self.b - res).max(), 3 * EPS32 ** 2, "Exact matches")
+ self.assertGreater(
+ abs(self.a + self.b - res_m).max(), EPS64, "Exact mismatches"
+ )
+ self.assertLess(
+ abs(self.a + self.b - res).max(), 3 * EPS32**2, "Exact matches"
+ )
def test_dw_times_fp(self):
- test_kernel = ElementwiseKernel(self.ctx,
- "float *ah, float *al, float *b, float *res_h, float *res_l",
- "float2 tmp = dw_times_fp((float2)(ah[i], al[i]),b[i]); res_h[i]=tmp.s0; res_l[i]=tmp.s1;",
- preamble=self.doubleword)
+ test_kernel = ElementwiseKernel(
+ self.ctx,
+ "float *ah, float *al, float *b, float *res_h, float *res_l",
+ "float2 tmp = dw_times_fp((float2)(ah[i], al[i]),b[i]); res_h[i]=tmp.s0; res_l[i]=tmp.s1;",
+ preamble=self.doubleword,
+ )
ah_g = pyopencl.array.to_device(self.queue, self.ah)
al_g = pyopencl.array.to_device(self.queue, self.al)
b_g = pyopencl.array.to_device(self.queue, self.bh)
@@ -186,14 +246,20 @@ class TestDoubleWord(unittest.TestCase):
res_m = res_h.get()
res = res_h.get().astype(numpy.float64) + res_l.get()
self.assertLess(abs(self.a * self.bh - res_m).max(), EPS32, "Major matches")
- self.assertGreater(abs(self.a * self.bh - res_m).max(), EPS64, "Exact mismatches")
- self.assertLess(abs(self.a * self.bh - res).max(), 2 * EPS32 ** 2, "Exact matches")
+ self.assertGreater(
+ abs(self.a * self.bh - res_m).max(), EPS64, "Exact mismatches"
+ )
+ self.assertLess(
+ abs(self.a * self.bh - res).max(), 2 * EPS32**2, "Exact matches"
+ )
def test_dw_times_dw(self):
- test_kernel = ElementwiseKernel(self.ctx,
- "float *ah, float *al, float *bh, float *bl, float *res_h, float *res_l",
- "float2 tmp = dw_times_dw((float2)(ah[i], al[i]),(float2)(bh[i], bl[i])); res_h[i]=tmp.s0; res_l[i]=tmp.s1;",
- preamble=self.doubleword)
+ test_kernel = ElementwiseKernel(
+ self.ctx,
+ "float *ah, float *al, float *bh, float *bl, float *res_h, float *res_l",
+ "float2 tmp = dw_times_dw((float2)(ah[i], al[i]),(float2)(bh[i], bl[i])); res_h[i]=tmp.s0; res_l[i]=tmp.s1;",
+ preamble=self.doubleword,
+ )
ah_g = pyopencl.array.to_device(self.queue, self.ah)
al_g = pyopencl.array.to_device(self.queue, self.al)
bh_g = pyopencl.array.to_device(self.queue, self.bh)
@@ -204,14 +270,20 @@ class TestDoubleWord(unittest.TestCase):
res_m = res_h.get()
res = res_h.get().astype(numpy.float64) + res_l.get()
self.assertLess(abs(self.a * self.b - res_m).max(), EPS32, "Major matches")
- self.assertGreater(abs(self.a * self.b - res_m).max(), EPS64, "Exact mismatches")
- self.assertLess(abs(self.a * self.b - res).max(), 5 * EPS32 ** 2, "Exact matches")
+ self.assertGreater(
+ abs(self.a * self.b - res_m).max(), EPS64, "Exact mismatches"
+ )
+ self.assertLess(
+ abs(self.a * self.b - res).max(), 5 * EPS32**2, "Exact matches"
+ )
def test_dw_div_fp(self):
- test_kernel = ElementwiseKernel(self.ctx,
- "float *ah, float *al, float *b, float *res_h, float *res_l",
- "float2 tmp = dw_div_fp((float2)(ah[i], al[i]),b[i]); res_h[i]=tmp.s0; res_l[i]=tmp.s1;",
- preamble=self.doubleword)
+ test_kernel = ElementwiseKernel(
+ self.ctx,
+ "float *ah, float *al, float *b, float *res_h, float *res_l",
+ "float2 tmp = dw_div_fp((float2)(ah[i], al[i]),b[i]); res_h[i]=tmp.s0; res_l[i]=tmp.s1;",
+ preamble=self.doubleword,
+ )
ah_g = pyopencl.array.to_device(self.queue, self.ah)
al_g = pyopencl.array.to_device(self.queue, self.al)
b_g = pyopencl.array.to_device(self.queue, self.bh)
@@ -221,14 +293,20 @@ class TestDoubleWord(unittest.TestCase):
res_m = res_h.get()
res = res_h.get().astype(numpy.float64) + res_l.get()
self.assertLess(abs(self.a / self.bh - res_m).max(), EPS32, "Major matches")
- self.assertGreater(abs(self.a / self.bh - res_m).max(), EPS64, "Exact mismatches")
- self.assertLess(abs(self.a / self.bh - res).max(), 3 * EPS32 ** 2, "Exact matches")
+ self.assertGreater(
+ abs(self.a / self.bh - res_m).max(), EPS64, "Exact mismatches"
+ )
+ self.assertLess(
+ abs(self.a / self.bh - res).max(), 3 * EPS32**2, "Exact matches"
+ )
def test_dw_div_dw(self):
- test_kernel = ElementwiseKernel(self.ctx,
- "float *ah, float *al, float *bh, float *bl, float *res_h, float *res_l",
- "float2 tmp = dw_div_dw((float2)(ah[i], al[i]),(float2)(bh[i], bl[i])); res_h[i]=tmp.s0; res_l[i]=tmp.s1;",
- preamble=self.doubleword)
+ test_kernel = ElementwiseKernel(
+ self.ctx,
+ "float *ah, float *al, float *bh, float *bl, float *res_h, float *res_l",
+ "float2 tmp = dw_div_dw((float2)(ah[i], al[i]),(float2)(bh[i], bl[i])); res_h[i]=tmp.s0; res_l[i]=tmp.s1;",
+ preamble=self.doubleword,
+ )
ah_g = pyopencl.array.to_device(self.queue, self.ah)
al_g = pyopencl.array.to_device(self.queue, self.al)
bh_g = pyopencl.array.to_device(self.queue, self.bh)
@@ -239,5 +317,9 @@ class TestDoubleWord(unittest.TestCase):
res_m = res_h.get()
res = res_h.get().astype(numpy.float64) + res_l.get()
self.assertLess(abs(self.a / self.b - res_m).max(), EPS32, "Major matches")
- self.assertGreater(abs(self.a / self.b - res_m).max(), EPS64, "Exact mismatches")
- self.assertLess(abs(self.a / self.b - res).max(), 6 * EPS32 ** 2, "Exact matches")
+ self.assertGreater(
+ abs(self.a / self.b - res_m).max(), EPS64, "Exact mismatches"
+ )
+ self.assertLess(
+ abs(self.a / self.b - res).max(), 6 * EPS32**2, "Exact matches"
+ )
diff --git a/src/silx/opencl/test/test_image.py b/src/silx/opencl/test/test_image.py
index 4ea8960..691ea82 100644
--- a/src/silx/opencl/test/test_image.py
+++ b/src/silx/opencl/test/test_image.py
@@ -39,11 +39,13 @@ import numpy
import unittest
from ..common import ocl, _measure_workgroup_size
+
if ocl:
import pyopencl
import pyopencl.array
from ...test.utils import utilstest
from ..image import ImageProcessing
+
logger = logging.getLogger(__name__)
try:
from PIL import Image
@@ -53,7 +55,6 @@ except ImportError:
@unittest.skipUnless(ocl and Image, "PyOpenCl/Image is missing")
class TestImage(unittest.TestCase):
-
@classmethod
def setUpClass(cls):
super(TestImage, cls).setUpClass()
@@ -99,7 +100,9 @@ class TestImage(unittest.TestCase):
tmp = pyopencl.array.empty(self.ip.ctx, self.data.shape, "float32")
res = self.ip.to_float(self.data, out=tmp)
res2 = self.ip.normalize(tmp, -100, 100, copy=False)
- norm = (self.data.astype(numpy.float32) - self.data.min()) / (self.data.max() - self.data.min())
+ norm = (self.data.astype(numpy.float32) - self.data.min()) / (
+ self.data.max() - self.data.min()
+ )
ref2 = 200 * norm - 100
self.assertLess(abs(res2 - ref2).max(), 3e-5, "content")
@@ -108,15 +111,21 @@ class TestImage(unittest.TestCase):
"""
Test on a greyscaled image ... of Lena :)
"""
- lena_bw = (0.2126 * self.data[:, :, 0] +
- 0.7152 * self.data[:, :, 1] +
- 0.0722 * self.data[:, :, 2]).astype("int32")
+ lena_bw = (
+ 0.2126 * self.data[:, :, 0]
+ + 0.7152 * self.data[:, :, 1]
+ + 0.0722 * self.data[:, :, 2]
+ ).astype("int32")
ref = numpy.histogram(lena_bw, 255)
ip = ImageProcessing(ctx=self.ctx, template=lena_bw, profile=True)
res = ip.histogram(lena_bw, 255)
ip.log_profile()
- delta = (ref[0] - res[0])
- deltap = (ref[1] - res[1])
+ delta = ref[0] - res[0]
+ deltap = ref[1] - res[1]
self.assertEqual(delta.sum(), 0, "errors are self-compensated")
self.assertLessEqual(abs(delta).max(), 1, "errors are small")
- self.assertLessEqual(abs(deltap).max(), 3e-5, "errors on position are small: %s" % (abs(deltap).max()))
+ self.assertLessEqual(
+ abs(deltap).max(),
+ 3e-5,
+ "errors on position are small: %s" % (abs(deltap).max()),
+ )
diff --git a/src/silx/opencl/test/test_kahan.py b/src/silx/opencl/test/test_kahan.py
index 62ed047..069d7de 100644
--- a/src/silx/opencl/test/test_kahan.py
+++ b/src/silx/opencl/test/test_kahan.py
@@ -47,6 +47,7 @@ except ImportError as error:
pyopencl = None
from .. import ocl
+
if ocl is not None:
from ..utils import read_cl_file
from .. import pyopencl
@@ -61,14 +62,21 @@ class TestKahan(unittest.TestCase):
@classmethod
def setUpClass(cls):
if pyopencl is None or ocl is None:
- raise unittest.SkipTest("OpenCL module (pyopencl) is not present or no device available")
+ raise unittest.SkipTest(
+ "OpenCL module (pyopencl) is not present or no device available"
+ )
cls.ctx = ocl.create_context(devicetype="GPU")
- cls.queue = pyopencl.CommandQueue(cls.ctx, properties=pyopencl.command_queue_properties.PROFILING_ENABLE)
+ cls.queue = pyopencl.CommandQueue(
+ cls.ctx, properties=pyopencl.command_queue_properties.PROFILING_ENABLE
+ )
# this is running 32 bits OpenCL woth POCL
- if (platform.machine() in ("i386", "i686", "x86_64") and (tuple.__itemsize__ == 4) and
- cls.ctx.devices[0].platform.name == 'Portable Computing Language'):
+ if (
+ platform.machine() in ("i386", "i686", "x86_64")
+ and (tuple.__itemsize__ == 4)
+ and cls.ctx.devices[0].platform.name == "Portable Computing Language"
+ ):
cls.args = "-DX87_VOLATILE=volatile"
else:
cls.args = ""
@@ -80,7 +88,7 @@ class TestKahan(unittest.TestCase):
@staticmethod
def dummy_sum(ary, dtype=None):
- "perform the actual sum in a dummy way "
+ "perform the actual sum in a dummy way"
if dtype is None:
dtype = ary.dtype.type
sum_ = dtype(0)
@@ -95,8 +103,10 @@ class TestKahan(unittest.TestCase):
ref64 = numpy.sum(data, dtype=numpy.float64)
ref32 = self.dummy_sum(data)
- if (ref64 == ref32):
- logger.warning("Kahan: invalid tests as float32 provides the same result as float64")
+ if ref64 == ref32:
+ logger.warning(
+ "Kahan: invalid tests as float32 provides the same result as float64"
+ )
# Dummy kernel to evaluate
src = """
kernel void summation(global float* data,
@@ -112,11 +122,15 @@ class TestKahan(unittest.TestCase):
result[1] = acc.s1;
}
"""
- prg = pyopencl.Program(self.ctx, read_cl_file("kahan.cl") + src).build(self.args)
+ prg = pyopencl.Program(self.ctx, read_cl_file("kahan.cl") + src).build(
+ self.args
+ )
ones_d = pyopencl.array.to_device(self.queue, data)
res_d = pyopencl.array.empty(self.queue, 2, numpy.float32)
res_d.fill(0)
- evt = prg.summation(self.queue, (1,), (1,), ones_d.data, numpy.int32(N), res_d.data)
+ evt = prg.summation(
+ self.queue, (1,), (1,), ones_d.data, numpy.int32(N), res_d.data
+ )
evt.wait()
res = res_d.get().sum(dtype=numpy.float64)
self.assertEqual(ref64, res, "test_kahan")
@@ -128,8 +142,10 @@ class TestKahan(unittest.TestCase):
ref64 = numpy.dot(data.astype(numpy.float64), data.astype(numpy.float64))
ref32 = numpy.dot(data, data)
- if (ref64 == ref32):
- logger.warning("dot16: invalid tests as float32 provides the same result as float64")
+ if ref64 == ref32:
+ logger.warning(
+ "dot16: invalid tests as float32 provides the same result as float64"
+ )
# Dummy kernel to evaluate
src = """
kernel void test_dot16(global float* data,
@@ -195,11 +211,15 @@ class TestKahan(unittest.TestCase):
"""
- prg = pyopencl.Program(self.ctx, read_cl_file("kahan.cl") + src).build(self.args)
+ prg = pyopencl.Program(self.ctx, read_cl_file("kahan.cl") + src).build(
+ self.args
+ )
ones_d = pyopencl.array.to_device(self.queue, data)
res_d = pyopencl.array.empty(self.queue, 2, numpy.float32)
res_d.fill(0)
- evt = prg.test_dot16(self.queue, (1,), (1,), ones_d.data, numpy.int32(N), res_d.data)
+ evt = prg.test_dot16(
+ self.queue, (1,), (1,), ones_d.data, numpy.int32(N), res_d.data
+ )
evt.wait()
res = res_d.get().sum(dtype="float64")
self.assertEqual(ref64, res, "test_dot16")
@@ -209,9 +229,13 @@ class TestKahan(unittest.TestCase):
data1 = data[1::2]
ref64 = numpy.dot(data0.astype(numpy.float64), data1.astype(numpy.float64))
ref32 = numpy.dot(data0, data1)
- if (ref64 == ref32):
- logger.warning("dot8: invalid tests as float32 provides the same result as float64")
- evt = prg.test_dot8(self.queue, (1,), (1,), ones_d.data, numpy.int32(N), res_d.data)
+ if ref64 == ref32:
+ logger.warning(
+ "dot8: invalid tests as float32 provides the same result as float64"
+ )
+ evt = prg.test_dot8(
+ self.queue, (1,), (1,), ones_d.data, numpy.int32(N), res_d.data
+ )
evt.wait()
res = res_d.get().sum(dtype="float64")
self.assertEqual(ref64, res, "test_dot8")
@@ -221,9 +245,13 @@ class TestKahan(unittest.TestCase):
data1 = data[3::4]
ref64 = numpy.dot(data0.astype(numpy.float64), data1.astype(numpy.float64))
ref32 = numpy.dot(data0, data1)
- if (ref64 == ref32):
- logger.warning("dot4: invalid tests as float32 provides the same result as float64")
- evt = prg.test_dot4(self.queue, (1,), (1,), ones_d.data, numpy.int32(N), res_d.data)
+ if ref64 == ref32:
+ logger.warning(
+ "dot4: invalid tests as float32 provides the same result as float64"
+ )
+ evt = prg.test_dot4(
+ self.queue, (1,), (1,), ones_d.data, numpy.int32(N), res_d.data
+ )
evt.wait()
res = res_d.get().sum(dtype="float64")
self.assertEqual(ref64, res, "test_dot4")
@@ -233,9 +261,13 @@ class TestKahan(unittest.TestCase):
data1 = numpy.array([data[3], data[11], data[15]])
ref64 = numpy.dot(data0.astype(numpy.float64), data1.astype(numpy.float64))
ref32 = numpy.dot(data0, data1)
- if (ref64 == ref32):
- logger.warning("dot3: invalid tests as float32 provides the same result as float64")
- evt = prg.test_dot3(self.queue, (1,), (1,), ones_d.data, numpy.int32(N), res_d.data)
+ if ref64 == ref32:
+ logger.warning(
+ "dot3: invalid tests as float32 provides the same result as float64"
+ )
+ evt = prg.test_dot3(
+ self.queue, (1,), (1,), ones_d.data, numpy.int32(N), res_d.data
+ )
evt.wait()
res = res_d.get().sum(dtype="float64")
self.assertEqual(ref64, res, "test_dot3")
@@ -245,9 +277,13 @@ class TestKahan(unittest.TestCase):
data1 = numpy.array([data[1], data[15]])
ref64 = numpy.dot(data0.astype(numpy.float64), data1.astype(numpy.float64))
ref32 = numpy.dot(data0, data1)
- if (ref64 == ref32):
- logger.warning("dot2: invalid tests as float32 provides the same result as float64")
- evt = prg.test_dot2(self.queue, (1,), (1,), ones_d.data, numpy.int32(N), res_d.data)
+ if ref64 == ref32:
+ logger.warning(
+ "dot2: invalid tests as float32 provides the same result as float64"
+ )
+ evt = prg.test_dot2(
+ self.queue, (1,), (1,), ones_d.data, numpy.int32(N), res_d.data
+ )
evt.wait()
res = res_d.get().sum(dtype="float64")
self.assertEqual(ref64, res, "test_dot2")
diff --git a/src/silx/opencl/test/test_linalg.py b/src/silx/opencl/test/test_linalg.py
index da99480..0b0a443 100644
--- a/src/silx/opencl/test/test_linalg.py
+++ b/src/silx/opencl/test/test_linalg.py
@@ -1,7 +1,7 @@
#!/usr/bin/env python
# /*##########################################################################
#
-# Copyright (c) 2016 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2022 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -30,15 +30,16 @@ __copyright__ = "2013-2017 European Synchrotron Radiation Facility, Grenoble, Fr
__date__ = "01/08/2019"
-import time
import logging
import numpy as np
import unittest
+
try:
import mako
except ImportError:
mako = None
from ..common import ocl
+
if ocl:
import pyopencl as cl
import pyopencl.array as parray
@@ -47,7 +48,8 @@ from silx.test.utils import utilstest
logger = logging.getLogger(__name__)
try:
- from scipy.ndimage.filters import laplace
+ from scipy.ndimage import laplace
+
_has_scipy = True
except ImportError:
_has_scipy = False
@@ -55,13 +57,18 @@ except ImportError:
# TODO move this function in math or image ?
def gradient(img):
- '''
+ """
Compute the gradient of an image as a numpy array
Code from https://github.com/emmanuelle/tomo-tv/
- '''
- shape = [img.ndim, ] + list(img.shape)
+ """
+ shape = [
+ img.ndim,
+ ] + list(img.shape)
gradient = np.zeros(shape, dtype=img.dtype)
- slice_all = [0, slice(None, -1),]
+ slice_all = [
+ 0,
+ slice(None, -1),
+ ]
for d in range(img.ndim):
gradient[tuple(slice_all)] = np.diff(img, axis=d)
slice_all[0] = d + 1
@@ -71,10 +78,10 @@ def gradient(img):
# TODO move this function in math or image ?
def divergence(grad):
- '''
+ """
Compute the divergence of a gradient
Code from https://github.com/emmanuelle/tomo-tv/
- '''
+ """
res = np.zeros(grad.shape[1:])
for d in range(grad.shape[0]):
this_grad = np.rollaxis(grad[d], d)
@@ -87,7 +94,6 @@ def divergence(grad):
@unittest.skipUnless(ocl and mako, "PyOpenCl is missing")
class TestLinAlg(unittest.TestCase):
-
def setUp(self):
if ocl is None:
return
@@ -106,7 +112,9 @@ class TestLinAlg(unittest.TestCase):
self.div_ref = divergence(self.grad_ref)
self.image2 = np.zeros_like(self.image)
# Device images
- self.gradient_parray = parray.empty(self.la.queue, self.image.shape, np.complex64)
+ self.gradient_parray = parray.empty(
+ self.la.queue, self.image.shape, np.complex64
+ )
self.gradient_parray.fill(0)
# we should be using cl.Buffer(self.la.ctx, cl.mem_flags.READ_WRITE, size=self.image.nbytes*2),
# but platforms not suporting openCL 1.2 have a problem with enqueue_fill_buffer,
@@ -153,46 +161,78 @@ class TestLinAlg(unittest.TestCase):
arrays = {
"numpy.ndarray": self.image,
"buffer": self.image_buffer,
- "parray": self.image_parray
+ "parray": self.image_parray,
}
for desc, image in arrays.items():
# Test with dst on host (numpy.ndarray)
res = self.la.gradient(image, return_to_host=True)
- self.compare(res, self.grad_ref, 1e-6, str("gradient[src=%s, dst=numpy.ndarray]" % desc))
+ self.compare(
+ res,
+ self.grad_ref,
+ 1e-6,
+ str("gradient[src=%s, dst=numpy.ndarray]" % desc),
+ )
# Test with dst on device (pyopencl.Buffer)
self.la.gradient(image, dst=self.gradient_buffer)
cl.enqueue_copy(self.la.queue, self.grad, self.gradient_buffer)
self.grad2[0] = self.grad.real
self.grad2[1] = self.grad.imag
- self.compare(self.grad2, self.grad_ref, 1e-6, str("gradient[src=%s, dst=buffer]" % desc))
+ self.compare(
+ self.grad2,
+ self.grad_ref,
+ 1e-6,
+ str("gradient[src=%s, dst=buffer]" % desc),
+ )
# Test with dst on device (pyopencl.Array)
self.la.gradient(image, dst=self.gradient_parray)
self.grad = self.gradient_parray.get()
self.grad2[0] = self.grad.real
self.grad2[1] = self.grad.imag
- self.compare(self.grad2, self.grad_ref, 1e-6, str("gradient[src=%s, dst=parray]" % desc))
+ self.compare(
+ self.grad2,
+ self.grad_ref,
+ 1e-6,
+ str("gradient[src=%s, dst=parray]" % desc),
+ )
@unittest.skipUnless(ocl and mako, "pyopencl is missing")
def test_divergence(self):
arrays = {
"numpy.ndarray": self.grad_ref,
"buffer": self.grad_ref_buffer,
- "parray": self.grad_ref_parray
+ "parray": self.grad_ref_parray,
}
for desc, grad in arrays.items():
# Test with dst on host (numpy.ndarray)
res = self.la.divergence(grad, return_to_host=True)
- self.compare(res, self.div_ref, 1e-6, str("divergence[src=%s, dst=numpy.ndarray]" % desc))
+ self.compare(
+ res,
+ self.div_ref,
+ 1e-6,
+ str("divergence[src=%s, dst=numpy.ndarray]" % desc),
+ )
# Test with dst on device (pyopencl.Buffer)
self.la.divergence(grad, dst=self.image_buffer)
cl.enqueue_copy(self.la.queue, self.image2, self.image_buffer)
- self.compare(self.image2, self.div_ref, 1e-6, str("divergence[src=%s, dst=buffer]" % desc))
+ self.compare(
+ self.image2,
+ self.div_ref,
+ 1e-6,
+ str("divergence[src=%s, dst=buffer]" % desc),
+ )
# Test with dst on device (pyopencl.Array)
self.la.divergence(grad, dst=self.image_parray)
self.image2 = self.image_parray.get()
- self.compare(self.image2, self.div_ref, 1e-6, str("divergence[src=%s, dst=parray]" % desc))
-
- @unittest.skipUnless(ocl and mako and _has_scipy, "pyopencl and/or scipy is missing")
+ self.compare(
+ self.image2,
+ self.div_ref,
+ 1e-6,
+ str("divergence[src=%s, dst=parray]" % desc),
+ )
+
+ @unittest.skipUnless(
+ ocl and mako and _has_scipy, "pyopencl and/or scipy is missing"
+ )
def test_laplacian(self):
laplacian_ref = laplace(self.image)
# Laplacian = div(grad)
diff --git a/src/silx/opencl/test/test_medfilt.py b/src/silx/opencl/test/test_medfilt.py
index e657d0d..2ef4490 100644
--- a/src/silx/opencl/test/test_medfilt.py
+++ b/src/silx/opencl/test/test_medfilt.py
@@ -31,8 +31,8 @@ Simple test of the median filter
__authors__ = ["Jérôme Kieffer"]
__contact__ = "jerome.kieffer@esrf.eu"
__license__ = "MIT"
-__copyright__ = "2013-2017 European Synchrotron Radiation Facility, Grenoble, France"
-__date__ = "05/07/2018"
+__copyright__ = "2013-2022 European Synchrotron Radiation Facility, Grenoble, France"
+__date__ = "09/05/2023"
import sys
@@ -41,11 +41,13 @@ import logging
import numpy
import unittest
from collections import namedtuple
+
try:
import mako
except ImportError:
mako = None
from ..common import ocl
+
if ocl:
import pyopencl
import pyopencl.array
@@ -58,20 +60,26 @@ Result = namedtuple("Result", ["size", "error", "sp_time", "oc_time"])
try:
from scipy.misc import ascent
except:
- def ascent():
- """Dummy image from random data"""
- return numpy.random.random((512, 512))
+ try:
+ from scipy.datasets import ascent
+ except:
+
+ def ascent():
+ """Dummy image from random data"""
+ return numpy.random.random((512, 512))
+
+
try:
- from scipy.ndimage import filters
- median_filter = filters.median_filter
+ from scipy.ndimage import median_filter
+
HAS_SCIPY = True
except:
HAS_SCIPY = False
from silx.math import medfilt2d as median_filter
+
@unittest.skipUnless(ocl and mako, "PyOpenCl is missing")
class TestMedianFilter(unittest.TestCase):
-
def setUp(self):
if ocl is None:
return
@@ -108,8 +116,15 @@ class TestMedianFilter(unittest.TestCase):
if r is None:
logger.info("test_medfilt: size: %s: skipped")
else:
- logger.info("test_medfilt: size: %s error %s, t_ref: %.3fs, t_ocl: %.3fs" % r)
- self.assertEqual(r.error, 0, 'Results are correct')
+ logger.info(
+ "test_medfilt: size: %s error %s, t_ref: %.3fs, t_ocl: %.3fs" % r
+ )
+ if (
+ self.medianfilter.device.platform.name.lower()
+ != "portable computing language"
+ ):
+ # Known broken
+ self.assertEqual(r.error, 0, "Results are correct")
def benchmark(self, limit=36):
"Run some benchmarking"
diff --git a/src/silx/opencl/test/test_projection.py b/src/silx/opencl/test/test_projection.py
index d093e4b..550a2f6 100644
--- a/src/silx/opencl/test/test_projection.py
+++ b/src/silx/opencl/test/test_projection.py
@@ -34,11 +34,13 @@ import time
import logging
import numpy as np
import unittest
+
try:
import mako
except ImportError:
mako = None
from ..common import ocl
+
if ocl:
from .. import projection
from silx.test.utils import utilstest
@@ -48,17 +50,18 @@ logger = logging.getLogger(__name__)
@unittest.skipUnless(ocl and mako, "PyOpenCl is missing")
class TestProj(unittest.TestCase):
-
def setUp(self):
if ocl is None:
return
# ~ if sys.platform.startswith('darwin'):
- # ~ self.skipTest("Projection is not implemented on CPU for OS X yet")
+ # ~ self.skipTest("Projection is not implemented on CPU for OS X yet")
self.getfiles()
n_angles = self.sino.shape[0]
self.proj = projection.Projection(self.phantom.shape, n_angles)
if self.proj.compiletime_workgroup_size < 16 * 16:
- self.skipTest("Current implementation of OpenCL projection is not supported on this platform yet")
+ self.skipTest(
+ "Current implementation of OpenCL projection is not supported on this platform yet"
+ )
def tearDown(self):
self.phantom = None
@@ -108,11 +111,11 @@ class TestProj(unittest.TestCase):
msg = str("Max error = %e" % err)
logger.info(msg)
# Interpolation differs at some lines, giving relative error of 10/50000
- self.assertTrue(err < 20., "Max error is too high")
+ self.assertTrue(err < 20.0, "Max error is too high")
# Test multiple reconstructions
# -----------------------------
res0 = np.copy(res)
for i in range(10):
res = self.proj.projection(self.phantom)
errmax = np.max(np.abs(res - res0))
- self.assertTrue(errmax < 1.e-6, "Max error is too high")
+ self.assertTrue(errmax < 1.0e-6, "Max error is too high")
diff --git a/src/silx/opencl/test/test_sparse.py b/src/silx/opencl/test/test_sparse.py
index 62a1399..db58220 100644
--- a/src/silx/opencl/test/test_sparse.py
+++ b/src/silx/opencl/test/test_sparse.py
@@ -29,6 +29,7 @@ import unittest
import logging
from itertools import product
from ..common import ocl
+
if ocl:
import pyopencl.array as parray
from silx.opencl.sparse import CSR
@@ -39,13 +40,14 @@ except ImportError:
logger = logging.getLogger(__name__)
-
def generate_sparse_random_data(
shape=(1000,),
- data_min=0, data_max=100,
+ data_min=0,
+ data_max=100,
density=0.1,
use_only_integers=True,
- dtype="f"):
+ dtype="f",
+):
"""
Generate random sparse data where.
@@ -75,7 +77,6 @@ def generate_sparse_random_data(
return (d * mask).astype(dtype)
-
@unittest.skipUnless(ocl and sp, "PyOpenCl/scipy is missing")
class TestCSR(unittest.TestCase):
"""Test CSR format"""
@@ -87,20 +88,19 @@ class TestCSR(unittest.TestCase):
dtypes = [np.float32, np.int32, np.uint16]
self._test_configs = list(product(input_on_device, output_on_device, dtypes))
-
def compute_ref_sparsification(self, array):
ref_sparse = sp.csr_matrix(array)
return ref_sparse
-
def test_sparsification(self):
for input_on_device, output_on_device, dtype in self._test_configs:
self._test_sparsification(input_on_device, output_on_device, dtype)
-
def _test_sparsification(self, input_on_device, output_on_device, dtype):
current_config = "input on device: %s, output on device: %s, dtype: %s" % (
- str(input_on_device), str(output_on_device), str(dtype)
+ str(input_on_device),
+ str(output_on_device),
+ str(dtype),
)
logger.debug("CSR: %s" % current_config)
# Generate data and reference CSR
@@ -132,29 +132,27 @@ class TestCSR(unittest.TestCase):
nnz = ref_sparse.nnz
self.assertTrue(
np.allclose(data[:nnz], ref_sparse.data),
- "something wrong with sparsified data (%s)"
- % current_config
+ "something wrong with sparsified data (%s)" % current_config,
)
self.assertTrue(
np.allclose(indices[:nnz], ref_sparse.indices),
- "something wrong with sparsified indices (%s)"
- % current_config
+ "something wrong with sparsified indices (%s)" % current_config,
)
self.assertTrue(
np.allclose(indptr, ref_sparse.indptr),
"something wrong with sparsified indices pointers (indptr) (%s)"
- % current_config
+ % current_config,
)
-
def test_desparsification(self):
for input_on_device, output_on_device, dtype in self._test_configs:
self._test_desparsification(input_on_device, output_on_device, dtype)
-
def _test_desparsification(self, input_on_device, output_on_device, dtype):
current_config = "input on device: %s, output on device: %s, dtype: %s" % (
- str(input_on_device), str(output_on_device), str(dtype)
+ str(input_on_device),
+ str(output_on_device),
+ str(dtype),
)
logger.debug("CSR: %s" % current_config)
# Generate data and reference CSR
@@ -182,6 +180,5 @@ class TestCSR(unittest.TestCase):
# Compare
self.assertTrue(
np.allclose(arr.reshape(array.shape), array),
- "something wrong with densified data (%s)"
- % current_config
+ "something wrong with densified data (%s)" % current_config,
)
diff --git a/src/silx/opencl/test/test_stats.py b/src/silx/opencl/test/test_stats.py
index f8ab1a7..7637211 100644
--- a/src/silx/opencl/test/test_stats.py
+++ b/src/silx/opencl/test/test_stats.py
@@ -39,17 +39,18 @@ import numpy
import unittest
from ..common import ocl
+
if ocl:
import pyopencl
import pyopencl.array
from ..statistics import StatResults, Statistics
from ..utils import get_opencl_code
+
logger = logging.getLogger(__name__)
@unittest.skipUnless(ocl, "PyOpenCl is missing")
class TestStatistics(unittest.TestCase):
-
@classmethod
def setUpClass(cls):
cls.size = 1 << 20 # 1 million elements
@@ -57,9 +58,15 @@ class TestStatistics(unittest.TestCase):
fdata = cls.data.astype("float64")
t0 = time.perf_counter()
std = fdata.std()
- cls.ref = StatResults(fdata.min(), fdata.max(), float(fdata.size),
- fdata.sum(), fdata.mean(), std ** 2,
- std)
+ cls.ref = StatResults(
+ fdata.min(),
+ fdata.max(),
+ float(fdata.size),
+ fdata.sum(),
+ fdata.mean(),
+ std**2,
+ std,
+ )
t1 = time.perf_counter()
cls.ref_time = t1 - t0
@@ -70,11 +77,12 @@ class TestStatistics(unittest.TestCase):
@classmethod
def validate(cls, res):
return (
- (res.min == cls.ref.min) and
- (res.max == cls.ref.max) and
- (res.cnt == cls.ref.cnt) and
- abs(res.mean - cls.ref.mean) < 0.01 and
- abs(res.std - cls.ref.std) < 0.1)
+ (res.min == cls.ref.min)
+ and (res.max == cls.ref.max)
+ and (res.cnt == cls.ref.cnt)
+ and abs(res.mean - cls.ref.mean) < 0.01
+ and abs(res.std - cls.ref.std) < 0.1
+ )
def test_measurement(self):
"""
@@ -95,11 +103,26 @@ class TestStatistics(unittest.TestCase):
t0 = time.perf_counter()
res = s(self.data, comp=comp)
t1 = time.perf_counter()
- logger.info("Runtime on %s/%s : %.3fms x%.1f", platform, device, 1000 * (t1 - t0), self.ref_time / (t1 - t0))
+ logger.info(
+ "Runtime on %s/%s : %.3fms x%.1f",
+ platform,
+ device,
+ 1000 * (t1 - t0),
+ self.ref_time / (t1 - t0),
+ )
if failed_init or not self.validate(res):
- logger.error("failed_init %s; Computation modes %s", failed_init, comp)
- logger.error("Failed on platform %s device %s", platform, device)
+ logger.error(
+ "failed_init %s; Computation modes %s",
+ failed_init,
+ comp,
+ )
+ logger.error(
+ "Failed on platform %s device %s", platform, device
+ )
logger.error("Reference results: %s", self.ref)
logger.error("Faulty results: %s", res)
- self.assertTrue(False, f"Stat calculation failed on {platform},{device} in mode {comp}")
+ self.assertTrue(
+ False,
+ f"Stat calculation failed on {platform},{device} in mode {comp}",
+ )
diff --git a/src/silx/opencl/utils.py b/src/silx/opencl/utils.py
index cc9f62d..c332402 100644
--- a/src/silx/opencl/utils.py
+++ b/src/silx/opencl/utils.py
@@ -43,9 +43,13 @@ def calc_size(shape, blocksize):
Calculate the optimal size for a kernel according to the workgroup size
"""
if "__len__" in dir(blocksize):
- return tuple((int(i) + int(j) - 1) & ~(int(j) - 1) for i, j in zip(shape, blocksize))
+ return tuple(
+ (int(i) + int(j) - 1) & ~(int(j) - 1) for i, j in zip(shape, blocksize)
+ )
else:
- return tuple((int(i) + int(blocksize) - 1) & ~(int(blocksize) - 1) for i in shape)
+ return tuple(
+ (int(i) + int(blocksize) - 1) & ~(int(blocksize) - 1) for i in shape
+ )
def nextpower(n):
@@ -88,8 +92,7 @@ def get_cl_file(resource):
"""
if not resource.endswith(".cl"):
resource += ".cl"
- return resources._resource_filename(resource,
- default_directory="opencl")
+ return resources._resource_filename(resource, default_directory="opencl")
def read_cl_file(filename):
@@ -118,8 +121,6 @@ def concatenate_cl_kernel(filenames):
return os.linesep.join(read_cl_file(fn) for fn in filenames)
-
-
class ConvolutionInfos(object):
allowed_axes = {
"1D": [None],
@@ -132,10 +133,10 @@ class ConvolutionInfos(object):
(2, 0, 1),
(2, 1, 0),
(1, 0, 2),
- (0, 2, 1)
+ (0, 2, 1),
],
"batched_1D_3D": [(0,), (1,), (2,)],
- "batched_separable_2D_1D_3D": [(0,), (1,), (2,)], # unsupported (?)
+ "batched_separable_2D_1D_3D": [(0,), (1,), (2,)], # unsupported (?)
"2D": [None],
"batched_2D_3D": [(0,), (1,), (2,)],
"separable_3D_2D_3D": [
@@ -202,10 +203,3 @@ class ConvolutionInfos(object):
},
},
}
-
-
-
-
-
-
-
diff --git a/src/silx/resources/__init__.py b/src/silx/resources/__init__.py
index b53f15b..4946600 100644
--- a/src/silx/resources/__init__.py
+++ b/src/silx/resources/__init__.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -27,7 +27,7 @@ All access to data and documentation files MUST be made through the functions
of this modules to ensure access across different distribution schemes:
- Installing from source or from wheel
-- Installing package as a zip (through the use of pkg_resources)
+- Installing package as a zip
- Linux packaging willing to install data files (and doc files) in
alternative folders. In this case, this file must be patched.
- Frozen fat binary application using silx (frozen with cx_Freeze or py2app).
@@ -52,28 +52,28 @@ of this modules to ensure access across different distribution schemes:
options={'py2app': {'packages': ['silx']}}
)
"""
+from __future__ import annotations
__authors__ = ["V.A. Sole", "Thomas Vincent", "J. Kieffer"]
__license__ = "MIT"
__date__ = "08/03/2019"
+import atexit
+import contextlib
+import functools
+import importlib
+import importlib.resources
+import logging
import os
import sys
-import logging
-import importlib
-
+from typing import NamedTuple, Optional
-logger = logging.getLogger(__name__)
+if sys.version_info < (3, 9):
+ import pkg_resources
-# pkg_resources is useful when this package is stored in a zip
-# When pkg_resources is not available, the resources dir defaults to the
-# directory containing this module.
-try:
- import pkg_resources
-except ImportError:
- pkg_resources = None
+logger = logging.getLogger(__name__)
# For packaging purpose, patch this variable to use an alternative directory
@@ -87,66 +87,56 @@ _RESOURCES_DIR = None
# cx_Freeze frozen support
# See http://cx-freeze.readthedocs.io/en/latest/faq.html#using-data-files
-if getattr(sys, 'frozen', False):
+if getattr(sys, "frozen", False):
# Running in a frozen application:
# We expect resources to be located either in a silx/resources/ dir
# relative to the executable or within this package.
- _dir = os.path.join(os.path.dirname(sys.executable), 'silx', 'resources')
+ _dir = os.path.join(os.path.dirname(sys.executable), "silx", "resources")
if os.path.isdir(_dir):
_RESOURCES_DIR = _dir
-class _ResourceDirectory(object):
+class _ResourceDirectory(NamedTuple):
"""Store a source of resources"""
- def __init__(self, package_name, package_path=None, forced_path=None):
- if forced_path is None:
- if package_path is None:
- if pkg_resources is None:
- # In this case we have to compute the package path
- # Else it will not be used
- module = importlib.import_module(package_name)
- package_path = os.path.abspath(os.path.dirname(module.__file__))
- self.package_name = package_name
- self.package_path = package_path
- self.forced_path = forced_path
+ package_name: str
+ forced_path: Optional[str] = None
-_SILX_DIRECTORY = _ResourceDirectory(
- package_name=__name__,
- package_path=os.path.abspath(os.path.dirname(__file__)),
- forced_path=_RESOURCES_DIR)
+_SILX_DIRECTORY = _ResourceDirectory(package_name=__name__, forced_path=_RESOURCES_DIR)
_RESOURCE_DIRECTORIES = {}
_RESOURCE_DIRECTORIES["silx"] = _SILX_DIRECTORY
-def register_resource_directory(name, package_name, forced_path=None):
+def register_resource_directory(
+ name: str, package_name: str, forced_path: Optional[str] = None
+):
"""Register another resource directory to the available list.
By default only the directory "silx" is available.
.. versionadded:: 0.6
- :param str name: Name of the resource directory. It is used on the resource
+ :param name: Name of the resource directory. It is used on the resource
name to specify the resource directory to use. The resource
"silx:foo.png" will use the "silx" resource directory.
- :param str package_name: Python name of the package containing resources.
+ :param package_name: Python name of the package containing resources.
For example "silx.resources".
- :param str forced_path: Path containing the resources. If specified
- `pkg_resources` nor `package_name` will be used
+ :param forced_path: Path containing the resources. If specified
+ neither `importlib` nor `package_name` will be used
For example "silx.resources".
:raises ValueError: If the resource directory name already exists.
"""
if name in _RESOURCE_DIRECTORIES:
raise ValueError("Resource directory name %s already exists" % name)
resource_directory = _ResourceDirectory(
- package_name=package_name,
- forced_path=forced_path)
+ package_name=package_name, forced_path=forced_path
+ )
_RESOURCE_DIRECTORIES[name] = resource_directory
-def list_dir(resource):
+def list_dir(resource: str) -> list[str]:
"""List the content of a resource directory.
Result are not prefixed by the resource name.
@@ -155,9 +145,8 @@ def list_dir(resource):
example "silx:foo.png" identify the resource "foo.png" from the resource
directory "silx". See also :func:`register_resource_directory`.
- :param str resource: Name of the resource directory to list
+ :param resource: Name of the resource directory to list
:return: list of name contained in the directory
- :rtype: List
"""
resource_directory, resource_name = _get_package_and_resource(resource)
@@ -165,50 +154,49 @@ def list_dir(resource):
# if set, use this directory
path = resource_filename(resource)
return os.listdir(path)
- elif pkg_resources is None:
- # Fallback if pkg_resources is not available
- path = resource_filename(resource)
- return os.listdir(path)
- else:
- # Preferred way to get resources as it supports zipfile package
- package_name = resource_directory.package_name
- return pkg_resources.resource_listdir(package_name, resource_name)
+ if sys.version_info < (3, 9):
+ return pkg_resources.resource_listdir(
+ resource_directory.package_name, resource_name
+ )
-def is_dir(resource):
+ path = importlib.resources.files(resource_directory.package_name) / resource_name
+ return [entry.name for entry in path.iterdir()]
+
+
+def is_dir(resource: str) -> bool:
"""True is the resource is a resource directory.
The resource name can be prefixed by the name of a resource directory. For
example "silx:foo.png" identify the resource "foo.png" from the resource
directory "silx". See also :func:`register_resource_directory`.
- :param str resource: Name of the resource
- :rtype: bool
+ :param resource: Name of the resource
"""
path = resource_filename(resource)
return os.path.isdir(path)
-def exists(resource):
+def exists(resource: str) -> bool:
"""True is the resource exists.
- :param str resource: Name of the resource
- :rtype: bool
+ :param resource: Name of the resource
"""
path = resource_filename(resource)
return os.path.exists(path)
-def _get_package_and_resource(resource, default_directory=None):
+def _get_package_and_resource(
+ resource: str, default_directory: Optional[str] = None
+) -> tuple[_ResourceDirectory, str]:
"""
Return the resource directory class and a cleaned resource name without
prefix.
- :param str: resource: Name of the resource with resource prefix.
- :param str default_directory: If the resource is not prefixed, the resource
+ :param resource: Name of the resource with resource prefix.
+ :param default_directory: If the resource is not prefixed, the resource
will be searched on this default directory of the silx resource
directory.
- :rtype: tuple(_ResourceDirectory, str)
:raises ValueError: If the resource name uses an unregistred resource
directory name
"""
@@ -217,14 +205,14 @@ def _get_package_and_resource(resource, default_directory=None):
else:
prefix = "silx"
if default_directory is not None:
- resource = os.path.join(default_directory, resource)
+ resource = f"{default_directory}/{resource}"
if prefix not in _RESOURCE_DIRECTORIES:
raise ValueError("Resource '%s' uses an unregistred prefix", resource)
resource_directory = _RESOURCE_DIRECTORIES[prefix]
return resource_directory, resource
-def resource_filename(resource):
+def resource_filename(resource: str) -> str:
"""Return filename corresponding to resource.
The existence of the resource is not checked.
@@ -233,18 +221,41 @@ def resource_filename(resource):
example "silx:foo.png" identify the resource "foo.png" from the resource
directory "silx". See also :func:`register_resource_directory`.
- :param str resource: Resource path relative to resource directory
- using '/' path separator. It can be either a file or
- a directory.
+ :param resource: Resource path relative to resource directory
+ using '/' path separator. It can be either a file or
+ a directory.
:raises ValueError: If the resource name uses an unregistred resource
directory name
:return: Absolute resource path in the file system
- :rtype: str
"""
return _resource_filename(resource, default_directory=None)
-def _resource_filename(resource, default_directory=None):
+# Manage resource files life-cycle
+_file_manager = contextlib.ExitStack()
+atexit.register(_file_manager.close)
+
+
+@functools.lru_cache(maxsize=None)
+def _get_resource_filename(package: str, resource: str) -> str:
+ """Returns path to requested resource in package
+
+ :param package: Name of the package in which to look for the resource
+ :param resource: Resource path relative to package using '/' path separator
+ :return: Abolute resource path in the file system
+ """
+ if sys.version_info < (3, 9):
+ return pkg_resources.resource_filename(package, resource)
+
+ # Caching prevents extracting the resource twice
+ file_context = importlib.resources.as_file(
+ importlib.resources.files(package) / resource
+ )
+ path = _file_manager.enter_context(file_context)
+ return str(path.absolute())
+
+
+def _resource_filename(resource: str, default_directory: Optional[str] = None) -> str:
"""Return filename corresponding to resource.
The existence of the resource is not checked.
@@ -253,32 +264,25 @@ def _resource_filename(resource, default_directory=None):
example "silx:foo.png" identify the resource "foo.png" from the resource
directory "silx". See also :func:`register_resource_directory`.
- :param str resource: Resource path relative to resource directory
- using '/' path separator. It can be either a file or
- a directory.
- :param str default_directory: If the resource is not prefixed, the resource
+ :param resource: Resource path relative to resource directory
+ using '/' path separator. It can be either a file or
+ a directory.
+ :param default_directory: If the resource is not prefixed, the resource
will be searched on this default directory of the silx resource
directory. It should only be used internally by silx.
:return: Absolute resource path in the file system
- :rtype: str
"""
- resource_directory, resource_name = _get_package_and_resource(resource,
- default_directory=default_directory)
+ resource_directory, resource_name = _get_package_and_resource(
+ resource, default_directory=default_directory
+ )
if resource_directory.forced_path is not None:
# if set, use this directory
base_dir = resource_directory.forced_path
- resource_path = os.path.join(base_dir, *resource_name.split('/'))
+ resource_path = os.path.join(base_dir, *resource_name.split("/"))
return resource_path
- elif pkg_resources is None:
- # Fallback if pkg_resources is not available
- base_dir = resource_directory.package_path
- resource_path = os.path.join(base_dir, *resource_name.split('/'))
- return resource_path
- else:
- # Preferred way to get resources as it supports zipfile package
- package_name = resource_directory.package_name
- return pkg_resources.resource_filename(package_name, resource_name)
+
+ return _get_resource_filename(resource_directory.package_name, resource_name)
# Expose ExternalResources for compatibility (since silx 0.11)
diff --git a/src/silx/resources/gui/icons/ruler.png b/src/silx/resources/gui/icons/ruler.png
new file mode 100644
index 0000000..0ff603f
--- /dev/null
+++ b/src/silx/resources/gui/icons/ruler.png
Binary files differ
diff --git a/src/silx/resources/gui/icons/ruler.svg b/src/silx/resources/gui/icons/ruler.svg
new file mode 100644
index 0000000..268b1db
--- /dev/null
+++ b/src/silx/resources/gui/icons/ruler.svg
@@ -0,0 +1,216 @@
+<?xml version="1.0" encoding="UTF-8" standalone="no"?>
+<svg
+ width="32"
+ height="32"
+ viewBox="0 0 8.4666657 8.4666657"
+ version="1.1"
+ id="svg5"
+ inkscape:version="1.2.2 (b0a8486541, 2022-12-01)"
+ sodipodi:docname="ruler.svg"
+ inkscape:export-filename="ruler.png"
+ inkscape:export-xdpi="100"
+ inkscape:export-ydpi="100"
+ xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
+ xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
+ xmlns="http://www.w3.org/2000/svg"
+ xmlns:svg="http://www.w3.org/2000/svg"
+ xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
+ xmlns:cc="http://creativecommons.org/ns#"
+ xmlns:dc="http://purl.org/dc/elements/1.1/">
+ <metadata
+ id="metadata35">
+ <rdf:RDF>
+ <cc:Work
+ rdf:about="">
+ <dc:format>image/svg+xml</dc:format>
+ <dc:type
+ rdf:resource="http://purl.org/dc/dcmitype/StillImage" />
+ </cc:Work>
+ </rdf:RDF>
+ </metadata>
+ <sodipodi:namedview
+ id="namedview7"
+ pagecolor="#ffffff"
+ bordercolor="#000000"
+ borderopacity="0.25"
+ inkscape:showpageshadow="2"
+ inkscape:pageopacity="0.0"
+ inkscape:pagecheckerboard="0"
+ inkscape:deskcolor="#d1d1d1"
+ inkscape:document-units="mm"
+ showgrid="false"
+ inkscape:zoom="13.455443"
+ inkscape:cx="-18.988598"
+ inkscape:cy="0.2229581"
+ inkscape:window-width="1920"
+ inkscape:window-height="1163"
+ inkscape:window-x="1920"
+ inkscape:window-y="0"
+ inkscape:window-maximized="1"
+ inkscape:current-layer="g1102"
+ inkscape:document-rotation="0"
+ showguides="true">
+ <inkscape:grid
+ type="xygrid"
+ id="grid3452"
+ originx="0"
+ originy="0" />
+ </sodipodi:namedview>
+ <defs
+ id="defs2" />
+ <g
+ inkscape:label="Layer 1"
+ inkscape:groupmode="layer"
+ id="layer1">
+ <g
+ id="g1102"
+ transform="translate(-0.36925443,-7.7531893)">
+ <g
+ id="path1743">
+ <path
+ style="color:#000000;fill:#ffffff;stroke-width:0.517192;stroke-miterlimit:3.5;-inkscape-stroke:none"
+ d="M 6.5495575,11.909879 5.7344716,11.369562"
+ id="path418" />
+ <path
+ style="color:#000000;fill:#000000;stroke-miterlimit:3.5;-inkscape-stroke:none"
+ d="M 5.8769531,11.154297 5.5917969,11.585938 6.40625,12.125 6.6933594,11.695313 Z"
+ id="path420" />
+ <g
+ id="g408">
+ <g
+ id="path410">
+ <path
+ style="color:#000000;fill:#ffffff;fill-rule:evenodd;stroke-width:0.0456346pt;-inkscape-stroke:none"
+ d="m 5.5442899,11.243491 c 0.069591,-0.10498 0.211272,-0.133702 0.3162524,-0.06411 0.1049803,0.06959 0.133702,0.211272 0.064111,0.316253 -0.069591,0.10498 -0.2135542,0.132189 -0.3162524,0.06411 -0.1049803,-0.06959 -0.133702,-0.211272 -0.064111,-0.316253 z"
+ id="path414" />
+ <path
+ style="color:#000000;fill:#000000;fill-rule:evenodd;-inkscape-stroke:none"
+ d="m 5.8769531,11.154297 c -0.1187003,-0.07869 -0.278736,-0.04643 -0.3574219,0.07227 -0.078686,0.1187 -0.046435,0.28069 0.072266,0.359375 0.116723,0.07738 0.2784923,0.04485 0.3574218,-0.07422 0.078686,-0.118701 0.046435,-0.278737 -0.072266,-0.357422 z m -0.033203,0.05078 c 0.09126,0.06049 0.1151836,0.182177 0.054687,0.273438 -0.060252,0.09089 -0.1847643,0.115422 -0.2734375,0.05664 -0.09126,-0.0605 -0.1151836,-0.184129 -0.054687,-0.27539 0.060496,-0.09126 0.1821772,-0.115185 0.2734375,-0.05469 z"
+ id="path416" />
+ </g>
+ </g>
+ </g>
+ <g
+ id="rect234"
+ transform="matrix(0.77265229,0.63482945,-0.6944385,0.71955206,0,0)">
+ <path
+ style="color:#000000;fill:#ffffff;stroke-width:0.399005;stroke-linecap:round;stroke-linejoin:round;stroke-miterlimit:3.5;-inkscape-stroke:none"
+ d="m 10.480519,2.0891316 h 2.354846 c 0.12095,0 0.218321,0.097371 0.218321,0.2183209 v 8.0683205 c 0,0.120949 -0.09737,0.218321 -0.218321,0.218321 h -2.354846 c -0.120949,0 -0.218321,-0.09737 -0.218321,-0.218321 V 2.3074525 c 0,-0.1209498 0.09737,-0.2183209 0.218321,-0.2183209 z"
+ id="path330" />
+ <path
+ style="color:#000000;fill:#000000;stroke-linecap:round;stroke-linejoin:round;stroke-miterlimit:3.5;-inkscape-stroke:none"
+ d="m 10.480469,1.8886719 c -0.228023,0 -0.417969,0.1899469 -0.417969,0.4179687 V 10.375 c 0,0.228024 0.189947,0.417969 0.417969,0.417969 h 2.355469 c 0.228024,0 0.417968,-0.189947 0.417968,-0.417969 V 2.3066406 c 0,-0.2280229 -0.189946,-0.4179687 -0.417968,-0.4179687 z m 0,0.4003906 h 2.355469 c 0.01388,0 0.01758,0.0037 0.01758,0.017578 V 10.375 c 0,0.01388 -0.0037,0.01953 -0.01758,0.01953 h -2.355469 c -0.01388,0 -0.01953,-0.0057 -0.01953,-0.01953 V 2.3066406 c 0,-0.013877 0.0057,-0.017578 0.01953,-0.017578 z"
+ id="path332" />
+ </g>
+ <g
+ id="path1743-9-3">
+ <path
+ style="color:#000000;fill:#000000;stroke-miterlimit:3.5;-inkscape-stroke:none;paint-order:stroke markers fill"
+ d="m 5.9492187,10.972656 -0.3945312,0.439453 0.796875,0.716797 0.3945312,-0.4375 z"
+ id="path404" />
+ <g
+ id="g394">
+ <g
+ id="path396">
+ <path
+ style="color:#000000;fill:#ffffff;fill-rule:evenodd;stroke-width:0.0520108pt;-inkscape-stroke:none"
+ d="m 5.5585792,11.01876 c 0.095985,-0.10674 0.260515,-0.11547 0.3672555,-0.01949 0.1067406,0.09599 0.1154702,0.260515 0.019486,0.367256 -0.095985,0.10674 -0.2628354,0.113383 -0.3672555,0.01949 C 5.4713246,11.290036 5.462595,11.125501 5.5585792,11.01876 Z"
+ id="path400" />
+ <path
+ style="color:#000000;fill:#000000;fill-rule:evenodd;-inkscape-stroke:none"
+ d="m 5.9492187,10.972656 c -0.1206907,-0.108524 -0.3074864,-0.09725 -0.4160156,0.02344 -0.1085286,0.120691 -0.099207,0.307491 0.021484,0.416015 0.1186803,0.106716 0.3071504,0.09958 0.4160156,-0.02148 0.108528,-0.120691 0.099205,-0.309435 -0.021484,-0.417969 z m -0.046875,0.05274 c 0.092792,0.08345 0.1010184,0.225568 0.017578,0.318359 -0.083105,0.09242 -0.2281997,0.0967 -0.3183594,0.01563 -0.092789,-0.08343 -0.1010177,-0.225569 -0.017578,-0.318359 0.083441,-0.09279 0.2255693,-0.09906 0.3183593,-0.01563 z"
+ id="path402" />
+ </g>
+ </g>
+ </g>
+ <g
+ id="path1741-9">
+ <path
+ style="color:#000000;fill:#000000;stroke-miterlimit:3.5;-inkscape-stroke:none"
+ d="M 6.4042969,9.6074219 6.0117187,10.052734 7.3417969,11.224609 7.734375,10.779297 Z"
+ id="path390" />
+ <g
+ id="g380">
+ <g
+ id="path382">
+ <path
+ style="color:#000000;fill:#000000;fill-rule:evenodd;stroke-width:0.0524193pt;-inkscape-stroke:none"
+ d="m 6.0116332,9.6575594 c 0.095671,-0.1085287 0.261399,-0.1189636 0.3699276,-0.023292 0.1085287,0.095671 0.1189636,0.261399 0.023292,0.3699276 -0.095671,0.108528 -0.2637583,0.116884 -0.3699276,0.02329 C 5.9263967,9.9318156 5.9159618,9.766088 6.0116332,9.6575594 Z"
+ id="path386" />
+ <path
+ style="color:#000000;fill:#000000;fill-rule:evenodd;-inkscape-stroke:none"
+ d="m 6.4042969,9.6074219 c -0.1227123,-0.108175 -0.3097944,-0.095369 -0.4179688,0.027344 -0.1081747,0.1227123 -0.097322,0.3097959 0.025391,0.4179684 0.1206682,0.106376 0.3114126,0.0977 0.4199219,-0.02539 0.108175,-0.1227125 0.095369,-0.3117477 -0.027344,-0.4199221 z m -0.044922,0.052734 c 0.094345,0.083167 0.1026993,0.2259678 0.019531,0.3203125 C 6.2960737,10.074433 6.150264,10.082765 6.0585937,10.001953 5.9642495,9.9187874 5.9539415,9.7759853 6.0371094,9.6816406 6.1202768,9.5872959 6.2650303,9.5769882 6.359375,9.6601563 Z"
+ id="path388" />
+ </g>
+ </g>
+ </g>
+ <g
+ id="path1741-9-7">
+ <path
+ style="color:#000000;fill:#000000;stroke-miterlimit:3.5;-inkscape-stroke:none"
+ d="m 4.2851562,11.6875 -0.3925781,0.445313 1.3300781,1.171875 0.3925782,-0.445313 z"
+ id="path376" />
+ <g
+ id="g366">
+ <g
+ id="path368">
+ <path
+ style="color:#000000;fill:#000000;fill-rule:evenodd;stroke-width:0.0524193pt;-inkscape-stroke:none"
+ d="m 3.8930299,11.736561 c 0.095671,-0.108528 0.261399,-0.118963 0.3699276,-0.02329 0.1085286,0.09567 0.1189634,0.261399 0.023292,0.369928 -0.095671,0.108528 -0.2637583,0.116883 -0.3699276,0.02329 -0.1085286,-0.09567 -0.1189634,-0.261399 -0.023292,-0.369928 z"
+ id="path372" />
+ <path
+ style="color:#000000;fill:#000000;fill-rule:evenodd;-inkscape-stroke:none"
+ d="m 4.2851562,11.6875 c -0.1227122,-0.108177 -0.3097943,-0.09732 -0.4179687,0.02539 -0.1081748,0.122712 -0.097322,0.311748 0.025391,0.419922 0.1206681,0.106374 0.3114125,0.09575 0.4199219,-0.02734 0.1081748,-0.122713 0.095369,-0.309795 -0.027344,-0.417969 z m -0.044922,0.05273 c 0.094344,0.08317 0.1026991,0.225968 0.019531,0.320313 -0.082832,0.09396 -0.2286422,0.100343 -0.3203125,0.01953 -0.094344,-0.08317 -0.1026991,-0.225967 -0.019531,-0.320312 0.083167,-0.09434 0.2259677,-0.102701 0.3203125,-0.01953 z"
+ id="path374" />
+ </g>
+ </g>
+ </g>
+ <g
+ id="path1741-9-3">
+ <path
+ style="color:#000000;fill:#000000;stroke-miterlimit:3.5;-inkscape-stroke:none"
+ d="m 2.4824219,13.669922 -0.3945313,0.445312 1.3300782,1.171875 0.3925781,-0.445312 z"
+ id="path362" />
+ <g
+ id="g352">
+ <g
+ id="path354">
+ <path
+ style="color:#000000;fill:#000000;fill-rule:evenodd;stroke-width:0.0524193pt;-inkscape-stroke:none"
+ d="m 2.0886536,13.720045 c 0.095671,-0.108528 0.2613989,-0.118963 0.3699275,-0.02329 0.1085287,0.09567 0.1189637,0.261399 0.023292,0.369928 -0.095671,0.108528 -0.2637582,0.116883 -0.3699275,0.02329 -0.1085287,-0.09567 -0.1189637,-0.261399 -0.023292,-0.369928 z"
+ id="path358" />
+ <path
+ style="color:#000000;fill:#000000;fill-rule:evenodd;-inkscape-stroke:none"
+ d="M 2.4824219,13.669922 C 2.3597097,13.561745 2.1706744,13.574554 2.0625,13.697266 c -0.1081751,0.122712 -0.097322,0.309795 0.025391,0.417968 0.1206681,0.106375 0.3114126,0.0977 0.4199219,-0.02539 0.1081751,-0.122713 0.097322,-0.311749 -0.025391,-0.419922 z m -0.046875,0.05273 c 0.094344,0.08317 0.1026993,0.225968 0.019531,0.320313 -0.082832,0.09396 -0.2286422,0.102296 -0.3203125,0.02148 -0.094344,-0.08317 -0.1026993,-0.22792 -0.019531,-0.322265 0.083167,-0.09435 0.2259677,-0.102701 0.3203125,-0.01953 z"
+ id="path360" />
+ </g>
+ </g>
+ </g>
+ <g
+ id="path1743-9-3-6">
+ <path
+ style="color:#000000;fill:#ffffff;stroke-width:0.589456;stroke-miterlimit:3.5;-inkscape-stroke:none"
+ d="M 4.4307455,14.144747 3.6331377,13.427511"
+ id="path346" />
+ <path
+ style="color:#000000;fill:#000000;stroke-miterlimit:3.5;-inkscape-stroke:none"
+ d="m 3.8300781,13.208984 -0.3945312,0.4375 0.7988281,0.716797 0.3925781,-0.4375 z"
+ id="path348" />
+ <g
+ id="g336">
+ <g
+ id="path338">
+ <path
+ style="color:#000000;fill:#ffffff;fill-rule:evenodd;stroke-width:0.0520108pt;-inkscape-stroke:none"
+ d="m 3.4397673,13.253626 c 0.095985,-0.106741 0.2605151,-0.11547 0.3672556,-0.01949 0.1067404,0.09598 0.1154698,0.260515 0.019485,0.367255 -0.095985,0.106741 -0.2628356,0.113383 -0.3672556,0.01949 -0.1067404,-0.09598 -0.1154698,-0.260515 -0.019485,-0.367255 z"
+ id="path342" />
+ <path
+ style="color:#000000;fill:#000000;fill-rule:evenodd;-inkscape-stroke:none"
+ d="m 3.8300781,13.208984 c -0.1206906,-0.108523 -0.3074864,-0.09921 -0.4160156,0.02149 -0.1085295,0.12069 -0.099207,0.307491 0.021484,0.416015 0.1186801,0.106716 0.3071504,0.09958 0.4160156,-0.02148 0.1085295,-0.12069 0.099207,-0.307491 -0.021484,-0.416016 z m -0.046875,0.05078 c 0.092789,0.08343 0.1010181,0.22557 0.017578,0.318359 -0.083105,0.09242 -0.2281998,0.09865 -0.3183594,0.01758 -0.092789,-0.08343 -0.1010181,-0.22557 -0.017578,-0.318359 0.083441,-0.09279 0.2255692,-0.101014 0.3183593,-0.01758 z"
+ id="path344" />
+ </g>
+ </g>
+ </g>
+ </g>
+ </g>
+</svg>
diff --git a/src/silx/resources/opencl/codec/bitshuffle_lz4.cl b/src/silx/resources/opencl/codec/bitshuffle_lz4.cl
new file mode 100644
index 0000000..71f617a
--- /dev/null
+++ b/src/silx/resources/opencl/codec/bitshuffle_lz4.cl
@@ -0,0 +1,625 @@
+/*
+ * Project: SILX: Bitshuffle LZ4 decompressor
+ *
+ * Copyright (C) 2022 European Synchrotron Radiation Facility
+ * Grenoble, France
+ *
+ * Principal authors: J. Kieffer (kieffer@esrf.fr)
+ *
+ * Permission is hereby granted, free of charge, to any person
+ * obtaining a copy of this software and associated documentation
+ * files (the "Software"), to deal in the Software without
+ * restriction, including without limitation the rights to use,
+ * copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the
+ * Software is furnished to do so, subject to the following
+ * conditions:
+ *
+ * The above copyright notice and this permission notice shall be
+ * included in all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+ * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
+ * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+ * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
+ * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
+ * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+ * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
+ * OTHER DEALINGS IN THE SOFTWARE.
+ */
+
+/* To decompress bitshuffle-LZ4 data in parallel on GPU one needs to:
+ * - Find all begining of blocks, this is performed by the ... kernel.
+ * - Decompress each block by one workgroup.
+ * - bitshuffle the data from one workgroup
+ */
+
+#ifndef LZ4_BLOCK_SIZE
+# define LZ4_BLOCK_SIZE 8192
+#endif
+#define LZ4_BLOCK_EXTRA 400
+#ifdef __ENDIAN_LITTLE__
+#define SWAP_BE 1
+#define SWAP_LE 0
+#else
+#define SWAP_BE 0
+#define SWAP_LE 1
+#endif
+
+
+#define int8_t char
+#define uint8_t uchar
+#define int16_t short
+#define uint16_t ushort
+#define int32_t int
+#define uint32_t uint
+#define int64_t long
+#define uint64_t ulong
+
+#define position_t uint
+#define token_t uchar2
+
+//Some function used as part of bitshuffle:
+
+inline token_t decode_token(uint8_t value){
+ return (token_t)(value >> 4, // literals
+ value & 0x0f); // matches
+}
+
+inline bool has_liter_over(token_t token)
+{
+ return token.s0 >= 15;
+}
+
+inline bool has_match_over(token_t token)
+{
+ return token.s1 >= 15;
+}
+
+//parse overflow, return the number of overflow and the new position
+inline uint2 read_overflow(local uint8_t* buffer,
+ position_t buffer_size,
+ position_t idx){
+ position_t num = 0;
+ uint8_t next = 0xff;
+ while (next == 0xff && idx < buffer_size){
+ next = buffer[idx];
+ idx += 1;
+ num += next;
+ }
+ return (uint2)(num, idx);
+}
+
+inline void copy_no_overlap(local uint8_t* dest,
+ const position_t dest_position,
+ local uint8_t* source,
+ const position_t src_position,
+ const position_t length){
+ for (position_t i=get_local_id(0); i<length; i+=get_local_size(0)) {
+ dest[dest_position+i] = source[src_position+i];
+ }
+}
+
+inline void copy_repeat(local uint8_t* dest,
+ const position_t dest_position,
+ local uint8_t* source,
+ const position_t src_position,
+ const position_t dist,
+ const position_t length){
+
+ // if there is overlap, it means we repeat, so we just
+ // need to organize our copy around that
+ for (position_t i=get_local_id(0); i<length; i+=get_local_size(0)) {
+ dest[dest_position+i] = source[src_position + i%dist];
+ }
+}
+
+inline void copy_collab(local uint8_t* dest,
+ const position_t dest_position,
+ local uint8_t* source,
+ const position_t src_position,
+ const position_t dist,
+ const position_t length){
+ //Generic copy function
+ if (dist < length) {
+ copy_repeat(dest, dest_position, source, src_position, dist, length);
+ }
+ else {
+ copy_no_overlap(dest, dest_position, source, src_position, length);
+ }
+}
+
+// Function to read larger integers at various position. Endianness is addressed as well with the swap flag
+uint64_t load64_at(global uint8_t *src,
+ const uint64_t position,
+ const bool swap){
+ uchar8 vector;
+ if (swap){
+ vector = (uchar8)(src[position+7],src[position+6],
+ src[position+5],src[position+4],
+ src[position+3],src[position+2],
+ src[position+1],src[position+0]);
+ }
+ else{
+ vector = (uchar8)(src[position+0],src[position+1],
+ src[position+2],src[position+3],
+ src[position+4],src[position+5],
+ src[position+6],src[position+7]);
+ }
+ return as_ulong(vector);
+}
+
+uint32_t load32_at(global uint8_t *src,
+ const uint64_t position,
+ const bool swap){
+ uchar4 vector;
+ if (swap){
+ vector = (uchar4)(
+ src[position+3],src[position+2],
+ src[position+1],src[position+0]);
+ }
+ else{
+ vector = (uchar4)(src[position+0],src[position+1],
+ src[position+2],src[position+3]);
+ }
+ return as_uint(vector);
+}
+
+uint16_t load16_at(local uint8_t *src,
+ const uint64_t position,
+ const bool swap){
+ uchar2 vector;
+ if (swap){
+ vector = (uchar2)(src[position+1],src[position+0]);
+ }
+ else{
+ vector = (uchar2)(src[position+0],src[position+1]);
+ }
+ return as_ushort(vector);
+}
+
+//Calculate the begining and the end of the block corresponding to the block=gid
+inline void _lz4_unblock(global uint8_t *src,
+ const uint64_t size,
+ local uint64_t *block_position){
+ uint32_t gid = get_group_id(0);
+ uint32_t lid = get_local_id(0);
+ if (lid == 0){
+ uint64_t block_start=16;
+ uint32_t block_size = load32_at(src, 12, SWAP_BE);
+ uint64_t block_end = block_start + block_size;
+
+ for (uint32_t block_idx=0; block_idx<gid; block_idx++){
+ // printf("gid %u idx %u %lu-%lu\n",gid, block_idx,block_start,block_end);
+ block_start = block_end + 4;
+ if (block_start>=size){
+ printf("Read beyond end of source buffer at gid %u %lu>%lu\n",gid, block_start, size);
+ block_start = 0;
+ block_end = 0;
+ break;
+ }
+ block_size = load32_at(src, block_end, SWAP_BE);
+ block_end = block_start + block_size;
+ }
+ block_position[0] = block_start;
+ block_position[1] = block_end;
+// if (gid>get_num_groups(0)-10) printf("Success finish unblock gid %u block: %lu - %lu\n",gid,block_start,block_end);
+ }
+ barrier(CLK_LOCAL_MEM_FENCE);
+}
+
+
+//Decompress one block in shared memory
+inline uint32_t lz4_decompress_local_block( local uint8_t* local_cmp,
+ local uint8_t* local_dec,
+ const uint32_t cmp_buffer_size,
+ const uint32_t dec_buffer_size){
+
+ uint32_t gid = get_group_id(0); // One block is decompressed by one workgroup
+ uint32_t lid = get_local_id(0); // This is the thread position in the group...
+ uint32_t wg = get_local_size(0); // workgroup size
+
+ position_t dec_idx = 0;
+ position_t cmp_idx = 0;
+ while (cmp_idx < cmp_buffer_size) {
+ // read header byte
+ token_t tok = decode_token(local_cmp[cmp_idx]);
+ // if (lid==0) printf("gid %u at idx %u/%u. Token is litterials: %u; matches: %u\n", gid, cmp_idx, cmp_buffer_size,tok.s0, tok.s1);
+
+ cmp_idx+=1;
+
+ // read the length of the literals
+ position_t num_literals = tok.s0;
+ if (has_liter_over(tok)) {
+ uint2 tmp = read_overflow(local_cmp,
+ cmp_buffer_size,
+ cmp_idx);
+ num_literals += tmp.s0;
+ cmp_idx = tmp.s1;
+ }
+ const position_t start_literal = cmp_idx;
+
+ // copy the literals to the dst stream in parallel
+ // if (lid==0) printf("gid %u: copy literals from %u to %u <%u (len %u)\n", gid, cmp_idx,num_literals+cmp_idx,cmp_buffer_size,num_literals);
+ copy_no_overlap(local_dec, dec_idx, local_cmp, cmp_idx, num_literals);
+ cmp_idx += num_literals;
+ dec_idx += num_literals;
+
+ // Note that the last sequence stops right after literals field.
+ // There are specific parsing rules to respect to be compatible with the
+ // reference decoder : 1) The last 5 bytes are always literals 2) The last
+ // match cannot start within the last 12 bytes Consequently, a file with
+ // less then 13 bytes can only be represented as literals These rules are in
+ // place to benefit speed and ensure buffer limits are never crossed.
+ if (cmp_idx < cmp_buffer_size) {
+
+ // read the offset
+ uint16_t offset = load16_at(local_cmp, cmp_idx, SWAP_LE);
+ // if (lid==0) printf("gid %u: offset is %u at %u\n",gid, offset, cmp_idx);
+ if (offset == 0) {
+ //corruped block
+ if (lid == 0)
+ printf("Corrupted block #%u\n", gid);
+ return 0;
+ }
+
+ cmp_idx += 2;
+
+ // read the match length
+ position_t match = 4 + tok.s1;
+ if (has_match_over(tok)) {
+ uint2 tmp = read_overflow(local_cmp,
+ cmp_buffer_size,
+ cmp_idx);
+ match += tmp.s0;
+ cmp_idx = tmp.s1;
+ }
+
+ //syncronize threads before reading shared memory
+ barrier(CLK_LOCAL_MEM_FENCE);
+
+ // copy match
+ copy_collab(local_dec, dec_idx, local_dec, dec_idx - offset, offset, match);
+ dec_idx += match;
+ }
+ }
+ //syncronize threads before reading shared memory
+ barrier(CLK_LOCAL_MEM_FENCE);
+ return dec_idx;
+}
+
+//Perform the bifshuffling on 8-bits objects
+inline void bitunshuffle8( local uint8_t* inp,
+ local uint8_t* out,
+ const uint32_t buffer_size){ //8k ... or less.
+// uint32_t gid = get_group_id(0);
+ uint32_t lid = get_local_id(0);
+ uint32_t wg = get_local_size(0);
+ uint32_t u8_buffer_size = buffer_size; // /1 -> 8k
+
+ // One thread deals with one or several output data
+ for (uint32_t dpos=lid; dpos<u8_buffer_size; dpos+=wg){
+ uint8_t res = 0;
+ // read bits at several places...
+ for (uint32_t bit=0; bit<8; bit++){
+ uint32_t read_bit = bit*u8_buffer_size + dpos;
+ uint32_t u8_word_pos = read_bit>>3; // /8
+ uint32_t u8_bit_pos = read_bit&7; // %8
+ // if (lid==0) printf("dpos %u bit %u read at %u,%u\n",dpos,bit,u8_word_pos,u8_bit_pos);
+ res |= ((inp[u8_word_pos]>>u8_bit_pos) & 1)<<bit;
+ }
+ // if (lid==0) printf("dpos %u res %u\n",dpos,res);
+ out[dpos] = res;
+ }
+}
+
+
+//Perform the bifshuffling on 16-bits objects
+inline void bitunshuffle16( local uint8_t* inp,
+ local uint8_t* out,
+ const uint32_t buffer_size){ //8k ... or less.
+// uint32_t gid = get_group_id(0);
+ uint32_t lid = get_local_id(0);
+ uint32_t wg = get_local_size(0);
+ uint32_t u16_buffer_size = buffer_size>>1; // /2 -> 4k
+
+ // One thread deals with one or several output data
+ for (uint32_t dpos=lid; dpos<u16_buffer_size; dpos+=wg){
+ uint16_t res = 0;
+ // read bits at several places...
+ for (uint32_t bit=0; bit<16; bit++){
+ uint32_t read_bit = bit*u16_buffer_size + dpos;
+ uint32_t u8_word_pos = read_bit>>3; // /8
+ uint32_t u8_bit_pos = read_bit&7; // %8
+ // if (lid==0) printf("dpos %u bit %u read at %u,%u\n",dpos,bit,u8_word_pos,u8_bit_pos);
+ res |= ((inp[u8_word_pos]>>u8_bit_pos) & 1)<<bit;
+ }
+ // if (lid==0) printf("dpos %u res %u\n",dpos,res);
+ uchar2 tmp = as_uchar2(res);
+ out[2*dpos] = tmp.s0;
+ out[2*dpos+1] = tmp.s1;
+ }
+}
+
+
+//Perform the bifshuffling on 32-bits objects
+inline void bitunshuffle32( local uint8_t* inp,
+ local uint8_t* out,
+ const uint32_t buffer_size){ //8k ... or less.
+// uint32_t gid = get_group_id(0);
+ uint32_t lid = get_local_id(0);
+ uint32_t wg = get_local_size(0);
+ uint32_t u32_buffer_size = buffer_size>>2; // /4 -> 2k
+
+ // One thread deals with one or several output data
+ for (uint32_t dpos=lid; dpos<u32_buffer_size; dpos+=wg){
+ uint32_t res = 0;
+ // read bits at several places...
+ for (uint32_t bit=0; bit<32; bit++){
+ uint32_t read_bit = bit*u32_buffer_size + dpos;
+ uint32_t u8_word_pos = read_bit>>3; // /8
+ uint32_t u8_bit_pos = read_bit&7; // %8
+ // if (lid==0) printf("dpos %u bit %u read at %u,%u\n",dpos,bit,u8_word_pos,u8_bit_pos);
+ res |= ((inp[u8_word_pos]>>u8_bit_pos) & 1)<<bit;
+ }
+ // if (lid==0) printf("dpos %u res %u\n",dpos,res);
+ uchar4 tmp = as_uchar4(res);
+ out[4*dpos] = tmp.s0;
+ out[4*dpos+1] = tmp.s1;
+ out[4*dpos+2] = tmp.s2;
+ out[4*dpos+3] = tmp.s3;
+ }
+}
+
+//Perform the bifshuffling on 32-bits objects
+inline void bitunshuffle64( local uint8_t* inp,
+ local uint8_t* out,
+ const uint32_t buffer_size){ //8k ... or less.
+// uint32_t gid = get_group_id(0);
+ uint32_t lid = get_local_id(0);
+ uint32_t wg = get_local_size(0);
+ uint32_t u64_buffer_size = buffer_size>>3; // /8 -> 1k
+
+ // One thread deals with one or several output data
+ for (uint32_t dpos=lid; dpos<u64_buffer_size; dpos+=wg){
+ uint64_t res = 0;
+ // read bits at several places...
+ for (uint32_t bit=0; bit<64; bit++){
+ uint32_t read_bit = bit*u64_buffer_size + dpos;
+ uint32_t u8_word_pos = read_bit>>3; // /8
+ uint32_t u8_bit_pos = read_bit&7; // %8
+ // if (lid==0) printf("dpos %u bit %u read at %u,%u\n",dpos,bit,u8_word_pos,u8_bit_pos);
+ res |= ((inp[u8_word_pos]>>u8_bit_pos) & 1)<<bit;
+ }
+ // if (lid==0) printf("dpos %u res %u\n",dpos,res);
+ uchar8 tmp = as_uchar8(res);
+ out[8*dpos] = tmp.s0;
+ out[8*dpos+1] = tmp.s1;
+ out[8*dpos+2] = tmp.s2;
+ out[8*dpos+3] = tmp.s3;
+ out[8*dpos+4] = tmp.s4;
+ out[8*dpos+5] = tmp.s5;
+ out[8*dpos+6] = tmp.s6;
+ out[8*dpos+7] = tmp.s7;
+ }
+}
+
+
+/* Preprocessing kernel which performs:
+- Memset arrays
+- read block position stored in block_position array
+
+Param:
+- src: input buffer in global memory
+- size: input buffer size
+- block_position: output buffer in local memory containing the index of the begining of each block
+- max_blocks: allocated memory for block_position array (output)
+- nb_blocks: output buffer with the actual number of blocks in src (output).
+
+Return: Nothing, this is a kernel
+
+Hint on workgroup size: little kernel ... wg=1, 1 wg is enough.
+*/
+
+kernel void lz4_unblock(global uint8_t *src,
+ const uint64_t size,
+ global uint64_t *block_start,
+ const uint32_t max_blocks,
+ global uint32_t *nb_blocks){
+
+ uint64_t total_nbytes = load64_at(src,0,SWAP_BE);
+ uint32_t block_nbytes = load32_at(src,8,SWAP_BE);
+
+ uint32_t block_idx = 0;
+ uint64_t pos = 12;
+ uint32_t block_size;
+
+ while ((pos+4<size) && (block_idx<max_blocks)){
+ block_size = load32_at(src, pos, SWAP_BE);
+ block_start[block_idx] = pos + 4;
+ block_idx +=1;
+ pos += 4 + block_size;
+ }
+ nb_blocks[0] = block_idx;
+}
+
+// decompress a frame blockwise.
+// Needs the block position to be known in advance (block_start) calculated from lz4_unblock.
+// one workgroup treats on block.
+
+kernel void bslz4_decompress_block( global uint8_t* comp_src,
+ global uint8_t* dec_dest,
+ global uint64_t* block_start,
+ global uint32_t *nb_blocks,
+ const uint8_t item_size){
+
+ uint32_t gid = get_group_id(0); // One block is decompressed by one workgroup
+ uint32_t lid = get_local_id(0); // This is the thread position in the group...
+ uint32_t wg = get_local_size(0); // workgroup size
+
+ //guard if the number of wg scheduled is too large
+ if (gid >=nb_blocks[0]) return;
+
+ // No need to guard, the number of blocks can be calculated in advance.
+ uint64_t start_read = block_start[gid];
+ if (start_read<12) return;
+
+ local uint8_t local_cmp[LZ4_BLOCK_SIZE+LZ4_BLOCK_EXTRA];
+ local uint8_t local_dec[LZ4_BLOCK_SIZE];
+
+ uint32_t cmp_buffer_size = load32_at(comp_src, start_read-4, SWAP_BE);
+ uint64_t end_read = start_read + cmp_buffer_size;
+ // Copy locally the compressed buffer and memset the destination buffer
+ for (uint32_t i=lid; i<cmp_buffer_size; i+=wg){
+ uint64_t read_pos = start_read + i;
+ if (read_pos<end_read)
+ local_cmp[i] = comp_src[read_pos];
+ else
+ local_cmp[i] = 0;
+ }
+ for (uint32_t i=lid+cmp_buffer_size; i<LZ4_BLOCK_SIZE+LZ4_BLOCK_EXTRA; i+=wg){
+ local_cmp[i] = 0;
+ }
+ for (uint32_t i=lid; i<LZ4_BLOCK_SIZE; i+=wg){
+ local_dec[i] = 0;
+ }
+ barrier(CLK_LOCAL_MEM_FENCE);
+
+ //All the work is performed here:
+ uint32_t dec_size = lz4_decompress_local_block( local_cmp, local_dec, cmp_buffer_size, LZ4_BLOCK_SIZE);
+
+ barrier(CLK_LOCAL_MEM_FENCE);
+ local uint8_t* local_buffer;
+
+ //Perform bit-unshuffle
+ if (item_size == 1){
+// if ((gid==0) && (lid==0)) printf("bitunshuffle8");
+ bitunshuffle8(local_dec, local_cmp, dec_size);
+ local_buffer=local_cmp;
+ }
+ else if (item_size == 2){
+// if ((gid==0) && (lid==0)) printf("bitunshuffle16");
+ bitunshuffle16(local_dec, local_cmp, dec_size);
+ local_buffer=local_cmp;
+ }
+ else if (item_size == 4){
+// if ((gid==0) && (lid==0)) printf("bitunshuffle32");
+ bitunshuffle32(local_dec, local_cmp, dec_size);
+ local_buffer=local_cmp;
+ }
+ else if (item_size == 8){
+// if ((gid==0) && (lid==0)) printf("bitunshuffle64");
+ bitunshuffle64(local_dec, local_cmp, dec_size);
+ local_buffer=local_cmp;
+ }
+ else {
+ local_buffer = local_dec;
+ }
+
+
+ //Finally copy the destination data from local to global memory:
+ uint64_t start_write = LZ4_BLOCK_SIZE*gid;
+ barrier(CLK_LOCAL_MEM_FENCE);
+ for (uint32_t i=lid; i<dec_size; i+=wg){
+ dec_dest[start_write + i] = local_buffer[i];
+ }
+
+ if (gid+1==get_num_groups(0)){
+ uint64_t total_nbytes = load64_at(comp_src,0,SWAP_BE);
+ uint64_t end_write = dec_size + start_write;
+ int32_t remaining = total_nbytes - end_write;
+// if (lid==0) printf("gid %u is last block has %u elements. Writing ends at %u/%lu, copy remaining %d\n",gid, dec_size, end_write, total_nbytes, remaining);
+ if ((remaining>0) && (remaining<item_size*8)){
+ for (uint32_t i=lid; i<remaining; i++){
+ dec_dest[end_write + i] = comp_src[end_read+i];
+ }
+ }
+ }
+
+}
+
+// decompress a frame blockwise.
+// block-start are searched by one thread from each workgroup ... not very efficient
+// one workgroup treats on block.
+
+kernel void bslz4_decompress_frame(
+ global uint8_t* comp_src,
+ const uint64_t src_size,
+ global uint8_t* dec_dest,
+ const uint8_t item_size){
+
+ uint32_t gid = get_group_id(0); // One block is decompressed by one workgroup
+ uint32_t lid = get_local_id(0); // This is the thread position in the group...
+ uint32_t wg = get_local_size(0); // workgroup size
+
+ local uint8_t local_cmp[LZ4_BLOCK_SIZE+LZ4_BLOCK_EXTRA];
+ local uint8_t local_dec[LZ4_BLOCK_SIZE];
+ local uint64_t block[2]; // will contain begining and end of the current block
+
+ uint64_t start_read, end_read;
+ uint32_t cmp_buffer_size;
+ _lz4_unblock(comp_src, src_size, block);
+ start_read = block[0];
+ end_read = block[1];
+ cmp_buffer_size = end_read - start_read;
+ if (cmp_buffer_size == 0){
+ if (lid == 0) printf("gid=%u: Empty buffer\n", gid);
+ return;
+ }
+
+ // Copy locally the compressed buffer and memset the destination buffer
+ for (uint32_t i=lid; i<cmp_buffer_size; i+=wg){
+ uint64_t read_pos = start_read + i;
+ if (read_pos<end_read)
+ local_cmp[i] = comp_src[read_pos];
+ else
+ local_cmp[i] = 0;
+ }
+ for (uint32_t i=lid+cmp_buffer_size; i<LZ4_BLOCK_SIZE+LZ4_BLOCK_EXTRA; i+=wg){
+ local_cmp[i] = 0;
+ }
+ for (uint32_t i=lid; i<LZ4_BLOCK_SIZE; i+=wg){
+ local_dec[i] = 0;
+ }
+ barrier(CLK_LOCAL_MEM_FENCE);
+
+ //All the work is performed here:
+ uint32_t dec_size;
+ dec_size = lz4_decompress_local_block( local_cmp, local_dec, cmp_buffer_size, LZ4_BLOCK_SIZE);
+
+ barrier(CLK_LOCAL_MEM_FENCE);
+ local uint8_t* local_buffer;
+
+ //Perform bit-unshuffle
+ if (item_size == 1){
+// if ((gid==0) && (lid==0)) printf("bitunshuffle8");
+ bitunshuffle8(local_dec, local_cmp, dec_size);
+ local_buffer=local_cmp;
+ }
+ else if (item_size == 2){
+// if ((gid==0) && (lid==0)) printf("bitunshuffle16");
+ bitunshuffle16(local_dec, local_cmp, dec_size);
+ local_buffer=local_cmp;
+ }
+ else if (item_size == 4){
+// if ((gid==0) && (lid==0)) printf("bitunshuffle32");
+ bitunshuffle32(local_dec, local_cmp, dec_size);
+ local_buffer=local_cmp;
+ }
+ else if (item_size == 8){
+// if ((gid==0) && (lid==0)) printf("bitunshuffle64");
+ bitunshuffle64(local_dec, local_cmp, dec_size);
+ local_buffer=local_cmp;
+ }
+ else {
+ local_buffer = local_dec;
+ }
+
+ //Finally copy the destination data from local to global memory:
+ uint64_t start_write = LZ4_BLOCK_SIZE*gid;
+ barrier(CLK_LOCAL_MEM_FENCE);
+ for (uint32_t i=lid; i<dec_size; i+=wg){
+ dec_dest[start_write + i] = local_buffer[i];
+ }
+
+}
diff --git a/src/silx/resources/opencl/doubleword.cl b/src/silx/resources/opencl/doubleword.cl
index a0ebfda..02a8aba 100644
--- a/src/silx/resources/opencl/doubleword.cl
+++ b/src/silx/resources/opencl/doubleword.cl
@@ -29,6 +29,7 @@
*
* We use the trick to declare some variable "volatile" to enforce the actual
* precision reduction of those variables.
+ * This has to be used in combination with #pragma clang fp contract(on)
*/
#ifndef X87_VOLATILE
@@ -37,6 +38,7 @@
//Algorithm 1, p23, theorem 1.1.12. Requires e_x > e_y, valid if |x| > |y|
inline fp2 fast_fp_plus_fp(fp x, fp y){
+ #pragma clang fp contract(on)
X87_VOLATILE fp s = x + y;
X87_VOLATILE fp z = s - x;
fp e = y - z;
@@ -45,6 +47,7 @@ inline fp2 fast_fp_plus_fp(fp x, fp y){
//Algorithm 2, p24, same as fast_fp_plus_fp without the condition on e_x and e_y
inline fp2 fp_plus_fp(fp x, fp y){
+ #pragma clang fp contract(on)
X87_VOLATILE fp s = x + y;
X87_VOLATILE fp xp = s - y;
X87_VOLATILE fp yp = s - xp;
@@ -62,6 +65,7 @@ inline fp2 fp_times_fp(fp x, fp y){
//Algorithm 7, p38: Addition of a FP to a DW. 10flop bounds:2u²+5u³
inline fp2 dw_plus_fp(fp2 x, fp y){
+ #pragma clang fp contract(on)
fp2 s = fp_plus_fp(x.s0, y);
X87_VOLATILE fp v = x.s1 + s.s1;
return fast_fp_plus_fp(s.s0, v);
@@ -83,6 +87,7 @@ inline fp2 dw_times_fp(fp2 x, fp y){
//Algorithm 14, p52: Multiplication DW*DW, 8 flops bounds:6u²
inline fp2 dw_times_dw(fp2 x, fp2 y){
+ #pragma clang fp contract(on)
fp2 c = fp_times_fp(x.s0, y.s0);
X87_VOLATILE fp l = fma(x.s1, y.s0, x.s0 * y.s1);
return fast_fp_plus_fp(c.s0, c.s1 + l);
@@ -90,6 +95,7 @@ inline fp2 dw_times_dw(fp2 x, fp2 y){
//Algorithm 17, p55: Division DW / FP, 10flops bounds: 3.5u²
inline fp2 dw_div_fp(fp2 x, fp y){
+ #pragma clang fp contract(on)
X87_VOLATILE fp th = x.s0 / y;
fp2 pi = fp_times_fp(th, y);
fp2 d = x - pi;
@@ -100,6 +106,7 @@ inline fp2 dw_div_fp(fp2 x, fp y){
//Derived from algorithm 20, p64: Inversion 1/ DW, 22 flops
inline fp2 inv_dw(fp2 y){
+ #pragma clang fp contract(on)
X87_VOLATILE fp th = one/y.s0;
X87_VOLATILE fp rh = fma(-y.s0, th, one);
X87_VOLATILE fp rl = -y.s1 * th;
diff --git a/src/silx/sx/__init__.py b/src/silx/sx/__init__.py
index 8922989..8a19d61 100644
--- a/src/silx/sx/__init__.py
+++ b/src/silx/sx/__init__.py
@@ -52,11 +52,11 @@ _logger = _logging.getLogger(__name__)
# Init logging when used from the console
-if hasattr(_sys, 'ps1'):
+if hasattr(_sys, "ps1"):
_logging.basicConfig()
# Probe DISPLAY available on linux
-_NO_DISPLAY = _sys.platform.startswith('linux') and not _os.environ.get('DISPLAY')
+_NO_DISPLAY = _sys.platform.startswith("linux") and not _os.environ.get("DISPLAY")
# Probe ipython
try:
@@ -66,10 +66,10 @@ except (NameError, ImportError):
# Probe ipython/jupyter notebook
if _get_ipython is not None and _get_ipython() is not None:
-
# Notebook detection probably fragile
- _IS_NOTEBOOK = ('parent_appname' in _get_ipython().config['IPKernelApp'] or
- hasattr(_get_ipython(), 'kernel'))
+ _IS_NOTEBOOK = "parent_appname" in _get_ipython().config["IPKernelApp"] or hasattr(
+ _get_ipython(), "kernel"
+ )
else:
_IS_NOTEBOOK = False
@@ -81,30 +81,39 @@ _qapp = None
def enable_gui():
"""Populate silx.sx module with silx.gui features and initialise Qt"""
if _NO_DISPLAY: # Missing DISPLAY under linux
- _logger.warning(
- 'Not loading silx.gui features: No DISPLAY available')
+ _logger.warning("Not loading silx.gui features: No DISPLAY available")
return
global qt, _qapp
if _get_ipython is not None and _get_ipython() is not None:
- _get_ipython().enable_pylab(gui='qt', import_all=False)
+ _get_ipython().enable_pylab(gui="qt", import_all=False)
from silx.gui import qt
+
# Create QApplication and keep reference only if needed
if not qt.QApplication.instance():
_qapp = qt.QApplication([])
- if hasattr(_sys, 'ps1'): # If from console, change windows icon
+ if hasattr(_sys, "ps1"): # If from console, change windows icon
# Change windows default icon
from silx.gui import icons
+
app = qt.QApplication.instance()
- app.setWindowIcon(icons.getQIcon('silx'))
+ app.setWindowIcon(icons.getQIcon("silx"))
global ImageView, PlotWidget, PlotWindow, Plot1D
global Plot2D, StackView, ScatterView, TickMode
- from silx.gui.plot import (ImageView, PlotWidget, PlotWindow, Plot1D,
- Plot2D, StackView, ScatterView, TickMode) # noqa
+ from silx.gui.plot import (
+ ImageView,
+ PlotWidget,
+ PlotWindow,
+ Plot1D,
+ Plot2D,
+ StackView,
+ ScatterView,
+ TickMode,
+ ) # noqa
global plot, imshow, scatter, ginput
from ._plot import plot, imshow, scatter, ginput # noqa
@@ -113,7 +122,8 @@ def enable_gui():
import OpenGL
except ImportError:
_logger.warning(
- 'Not loading silx.gui.plot3d features: PyOpenGL is not installed')
+ "Not loading silx.gui.plot3d features: PyOpenGL is not installed"
+ )
else:
global contour3d, points3d
from ._plot3d import contour3d, points3d # noqa
@@ -121,8 +131,7 @@ def enable_gui():
# Load Qt and widgets only if running from console and display available
if _IS_NOTEBOOK:
- _logger.warning(
- 'Not loading silx.gui features: Running from the notebook')
+ _logger.warning("Not loading silx.gui features: Running from the notebook")
else:
enable_gui()
@@ -131,6 +140,7 @@ else:
if _get_ipython is not None and _get_ipython() is not None:
if not _NO_DISPLAY: # Not loading pylab without display
from IPython.core.pylabtools import import_pylab as _import_pylab
+
_import_pylab(_get_ipython().user_ns, import_all=False)
diff --git a/src/silx/sx/_plot.py b/src/silx/sx/_plot.py
index 155adba..22e1a2f 100644
--- a/src/silx/sx/_plot.py
+++ b/src/silx/sx/_plot.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -29,11 +29,7 @@ __license__ = "MIT"
__date__ = "06/11/2018"
-import collections
-try:
- from collections import abc
-except ImportError: # Python2 support
- import collections as abc
+from collections import abc
import logging
import weakref
@@ -118,17 +114,17 @@ def plot(*args, **kwargs):
:rtype: silx.gui.plot.Plot1D
"""
plt = Plot1D()
- if 'title' in kwargs:
- plt.setGraphTitle(kwargs['title'])
- if 'xlabel' in kwargs:
- plt.getXAxis().setLabel(kwargs['xlabel'])
- if 'ylabel' in kwargs:
- plt.getYAxis().setLabel(kwargs['ylabel'])
-
- color = kwargs.get('color')
- linestyle = kwargs.get('linestyle')
- linewidth = kwargs.get('linewidth')
- marker = kwargs.get('marker')
+ if "title" in kwargs:
+ plt.setGraphTitle(kwargs["title"])
+ if "xlabel" in kwargs:
+ plt.getXAxis().setLabel(kwargs["xlabel"])
+ if "ylabel" in kwargs:
+ plt.getYAxis().setLabel(kwargs["ylabel"])
+
+ color = kwargs.get("color")
+ linestyle = kwargs.get("linestyle")
+ linewidth = kwargs.get("linewidth")
+ marker = kwargs.get("marker")
# Parse args and store curves as (x, y, style string)
args = list(args)
@@ -172,43 +168,54 @@ def plot(*args, **kwargs):
for c in possible_colors[1:]:
if len(c) > len(curve_color):
curve_color = c
- style = style[len(curve_color):]
+ style = style[len(curve_color) :]
if style:
# Run twice to handle inversion symbol/linestyle
for _i in range(2):
# Handle linestyle
- for line in (' ', '--', '-', '-.', ':'):
+ for line in (" ", "--", "-", "-.", ":"):
if style.endswith(line):
curve_linestyle = line
- style = style[:-len(line)]
+ style = style[: -len(line)]
break
# Handle symbol
- for curve_marker in ('o', '.', ',', '+', 'x', 'd', 's'):
+ for curve_marker in ("o", ".", ",", "+", "x", "d", "s"):
if style.endswith(curve_marker):
curve_symbol = style[-1]
style = style[:-1]
break
# As in matplotlib, marker, linestyle and color override other style
- plt.addCurve(x, y,
- legend=('curve_%d' % index),
- symbol=marker or curve_symbol,
- linestyle=linestyle or curve_linestyle,
- linewidth=linewidth,
- color=color or curve_color)
+ plt.addCurve(
+ x,
+ y,
+ legend=("curve_%d" % index),
+ symbol=marker or curve_symbol,
+ linestyle=linestyle or curve_linestyle,
+ linewidth=linewidth,
+ color=color or curve_color,
+ )
plt.show()
_plots.insert(0, plt)
return plt
-def imshow(data=None, cmap=None, norm=colors.Colormap.LINEAR,
- vmin=None, vmax=None,
- aspect=False,
- origin='upper', scale=(1., 1.),
- title='', xlabel='X', ylabel='Y'):
+def imshow(
+ data=None,
+ cmap=None,
+ norm=colors.Colormap.LINEAR,
+ vmin=None,
+ vmax=None,
+ aspect=False,
+ origin="upper",
+ scale=(1.0, 1.0),
+ title="",
+ xlabel="X",
+ ylabel="Y",
+):
"""
Plot an image in a :class:`~silx.gui.plot.PlotWindow.Plot2D` widget.
@@ -273,18 +280,17 @@ def imshow(data=None, cmap=None, norm=colors.Colormap.LINEAR,
colormap.setVMax(vmax)
# Handle aspect
- if aspect in (None, False, 'auto', 'normal'):
+ if aspect in (None, False, "auto", "normal"):
plt.setKeepDataAspectRatio(False)
- elif aspect in (True, 'equal') or aspect == 1:
+ elif aspect in (True, "equal") or aspect == 1:
plt.setKeepDataAspectRatio(True)
else:
- _logger.warning(
- 'imshow: Unhandled aspect argument: %s', str(aspect))
+ _logger.warning("imshow: Unhandled aspect argument: %s", str(aspect))
# Handle matplotlib-like origin
- if origin in ('upper', 'lower'):
- plt.setYAxisInverted(origin == 'upper')
- origin = 0., 0. # Set origin to the definition of silx
+ if origin in ("upper", "lower"):
+ plt.setYAxisInverted(origin == "upper")
+ origin = 0.0, 0.0 # Set origin to the definition of silx
if data is not None:
data = numpy.array(data, copy=True)
@@ -300,10 +306,17 @@ def imshow(data=None, cmap=None, norm=colors.Colormap.LINEAR,
return plt
-def scatter(x=None, y=None, value=None, size=None,
- marker=None,
- cmap=None, norm=colors.Colormap.LINEAR,
- vmin=None, vmax=None):
+def scatter(
+ x=None,
+ y=None,
+ value=None,
+ size=None,
+ marker=None,
+ cmap=None,
+ norm=colors.Colormap.LINEAR,
+ vmin=None,
+ vmax=None,
+):
"""
Plot scattered data in a :class:`~silx.gui.plot.ScatterView` widget.
@@ -451,7 +464,7 @@ class _GInputHandler(roi.InteractiveRegionOfInterestManager):
super(_GInputHandler, self).__init__(plot)
self._timeout = timeout
- self.__selections = collections.OrderedDict()
+ self.__selections = {}
window = plot.window() # Retrieve window containing PlotWidget
statusBar = window.statusBar()
@@ -481,7 +494,9 @@ class _GInputHandler(roi.InteractiveRegionOfInterestManager):
window.addToolBar(toolbar)
toolbar.addAction(self.getInteractionModeAction(roi_items.PointROI))
- super(_GInputHandler, self).exec(roiClass=roi_items.PointROI, timeout=self._timeout)
+ super(_GInputHandler, self).exec(
+ roiClass=roi_items.PointROI, timeout=self._timeout
+ )
if isinstance(toolbar, InteractiveModeToolBar):
toolbar.removeAction(self.getInteractionModeAction(roi_items.PointROI))
@@ -507,18 +522,19 @@ class _GInputHandler(roi.InteractiveRegionOfInterestManager):
raise RuntimeError("Unexpected item")
x, y = roi.getPosition()
- xPixel, yPixel = plot.dataToPixel(x, y, axis='left', check=False)
+ xPixel, yPixel = plot.dataToPixel(x, y, axis="left", check=False)
# Pick item at selected position
pickingResult = plot._pickTopMost(
- xPixel, yPixel,
- lambda item: isinstance(item, (items.ImageBase, items.Curve)))
+ xPixel,
+ yPixel,
+ lambda item: isinstance(item, (items.ImageBase, items.Curve)),
+ )
if pickingResult is None:
- result = _GInputResult((x, y),
- item=None,
- indices=numpy.array((), dtype=int),
- data=None)
+ result = _GInputResult(
+ (x, y), item=None, indices=numpy.array((), dtype=int), data=None
+ )
else:
item = pickingResult.getItem()
indices = pickingResult.getIndices(copy=True)
@@ -526,18 +542,19 @@ class _GInputHandler(roi.InteractiveRegionOfInterestManager):
if isinstance(item, items.Curve):
xData = item.getXData(copy=False)[indices]
yData = item.getYData(copy=False)[indices]
- result = _GInputResult((x, y),
- item=item,
- indices=indices,
- data=numpy.array((xData, yData)).T)
+ result = _GInputResult(
+ (x, y),
+ item=item,
+ indices=indices,
+ data=numpy.array((xData, yData)).T,
+ )
elif isinstance(item, items.ImageBase):
row, column = indices[0][0], indices[1][0]
data = item.getData(copy=False)[row, column]
- result = _GInputResult((x, y),
- item=item,
- indices=(row, column),
- data=data)
+ result = _GInputResult(
+ (x, y), item=item, indices=(row, column), data=data
+ )
self.__selections[roi] = result
@@ -548,7 +565,7 @@ class _GInputHandler(roi.InteractiveRegionOfInterestManager):
"""
if isinstance(roi, roi_items.PointROI):
# Only handle points
- roi.setName('%d' % len(self.__selections))
+ roi.setName("%d" % len(self.__selections))
self.__updateSelection(roi)
roi.sigRegionChanged.connect(self.__regionChanged)
@@ -610,14 +627,14 @@ def ginput(n=1, timeout=30, plot=None):
plot.show()
if plot is None:
- _logger.warning('No plot available to perform ginput, create one')
+ _logger.warning("No plot available to perform ginput, create one")
plot = Plot1D()
plot.show()
_plots.insert(0, plot)
plot.raise_() # So window becomes the top level one
- _logger.info('Performing ginput with plot widget %s', str(plot))
+ _logger.info("Performing ginput with plot widget %s", str(plot))
handler = _GInputHandler(plot, n, timeout)
points = handler.exec()
diff --git a/src/silx/sx/_plot3d.py b/src/silx/sx/_plot3d.py
index c6833ac..1dc9ea5 100644
--- a/src/silx/sx/_plot3d.py
+++ b/src/silx/sx/_plot3d.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2018-2019 European Synchrotron Radiation Facility
+# Copyright (c) 2018-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -29,10 +29,7 @@ __license__ = "MIT"
__date__ = "24/04/2018"
-try:
- from collections import abc
-except ImportError: # Python2 support
- import collections as abc
+from collections import abc
import logging
import numpy
@@ -47,14 +44,16 @@ from ..gui.colors import rgba
_logger = logging.getLogger(__name__)
-def contour3d(scalars,
- contours=1,
- copy=True,
- color=None,
- colormap='viridis',
- vmin=None,
- vmax=None,
- opacity=1.):
+def contour3d(
+ scalars,
+ contours=1,
+ copy=True,
+ color=None,
+ colormap="viridis",
+ vmin=None,
+ vmax=None,
+ opacity=1.0,
+):
"""
Plot isosurfaces of a 3D scalar field in a :class:`~silx.gui.plot3d.ScalarFieldView.ScalarFieldView` widget.
@@ -135,7 +134,7 @@ def contour3d(scalars,
# Prepare and apply opacity
assert isinstance(opacity, float)
- opacity = min(max(0., opacity), 1.) # Clip opacity
+ opacity = min(max(0.0, opacity), 1.0) # Clip opacity
colors[:, -1] = (colors[:, -1] * opacity).astype(numpy.uint8)
# Prepare widget
@@ -151,7 +150,7 @@ def contour3d(scalars,
# Add the parameter tree to the main window in a dock widget
dock = qt.QDockWidget(scalarField)
- dock.setWindowTitle('Parameters')
+ dock.setWindowTitle("Parameters")
dock.setWidget(treeView)
scalarField.addDockWidget(qt.Qt.RightDockWidgetArea, dock)
@@ -164,22 +163,26 @@ def contour3d(scalars,
_POINTS3D_MODE_CONVERSION = {
- '2dcircle': 'o',
- '2dcross': 'x',
- '2ddash': '_',
- '2ddiamond': 'd',
- '2dsquare': 's',
- 'point': ','
+ "2dcircle": "o",
+ "2dcross": "x",
+ "2ddash": "_",
+ "2ddiamond": "d",
+ "2dsquare": "s",
+ "point": ",",
}
-def points3d(x, y, z=None,
- values=0.,
- copy=True,
- colormap='viridis',
- vmin=None,
- vmax=None,
- mode=None):
+def points3d(
+ x,
+ y,
+ z=None,
+ values=0.0,
+ copy=True,
+ colormap="viridis",
+ vmin=None,
+ vmax=None,
+ mode=None,
+):
"""
Plot a 3D scatter plot in a :class:`~silx.gui.plot3d.SceneWindow.SceneWindow` widget.
diff --git a/src/silx/test/__init__.py b/src/silx/test/__init__.py
index 31a892a..0f3d5de 100644
--- a/src/silx/test/__init__.py
+++ b/src/silx/test/__init__.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2015-2022 European Synchrotron Radiation Facility
+# Copyright (c) 2015-2024 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -25,28 +25,39 @@
"""
import logging
+import subprocess
+import sys
try:
import pytest
except ImportError:
logging.getLogger(__name__).error(
- "pytest is required to run the tests, please install it.")
+ "pytest is required to run the tests, please install it."
+ )
raise
-def run_tests(module: str='silx', verbosity: int=0, args=()):
- """Run tests
+
+def run_tests(module: str = "silx", verbosity: int = 0, args=()):
+ """Run tests in a subprocess
:param module: Name of the silx module to test (default: 'silx')
:param verbosity: Requested level of verbosity
:param args: List of extra arguments to pass to `pytest`
"""
- return pytest.main([
- '--pyargs',
- module,
- '--verbosity',
- str(verbosity),
- '-o python_files=["test/test*.py","test/Test*.py"]',
- '-o python_classes=["Test"]',
- '-o python_functions=["test"]',
- ] + list(args))
+ return subprocess.run(
+ [
+ sys.executable,
+ "-m",
+ "pytest",
+ "--pyargs",
+ module,
+ "--verbosity",
+ str(verbosity),
+ '-o python_files=["test/test*.py","test/Test*.py"]',
+ '-o python_classes=["Test"]',
+ '-o python_functions=["test"]',
+ ]
+ + list(args),
+ check=False,
+ ).returncode
diff --git a/src/silx/test/test_resources.py b/src/silx/test/test_resources.py
index 3344da0..52c0df7 100644
--- a/src/silx/test/test_resources.py
+++ b/src/silx/test/test_resources.py
@@ -37,7 +37,6 @@ import silx.resources
class TestResources(unittest.TestCase):
-
@classmethod
def setUpClass(cls):
super(TestResources, cls).setUpClass()
@@ -70,74 +69,78 @@ class TestResources(unittest.TestCase):
def test_resource_dir(self):
"""Get a resource directory"""
- icons_dirname = silx.resources.resource_filename('gui/icons/')
+ icons_dirname = silx.resources.resource_filename("gui/icons/")
self.assertTrue(os.path.isdir(icons_dirname))
def test_resource_file(self):
"""Get a resource file name"""
- filename = silx.resources.resource_filename('gui/icons/colormap.png')
+ filename = silx.resources.resource_filename("gui/icons/colormap.png")
self.assertTrue(os.path.isfile(filename))
def test_resource_nonexistent(self):
"""Get a non existent resource"""
- filename = silx.resources.resource_filename('non_existent_file.txt')
+ filename = silx.resources.resource_filename("non_existent_file.txt")
self.assertFalse(os.path.exists(filename))
def test_isdir(self):
- self.assertTrue(silx.resources.is_dir('gui/icons'))
+ self.assertTrue(silx.resources.is_dir("gui/icons"))
def test_not_isdir(self):
- self.assertFalse(silx.resources.is_dir('gui/icons/colormap.png'))
+ self.assertFalse(silx.resources.is_dir("gui/icons/colormap.png"))
def test_list_dir(self):
- result = silx.resources.list_dir('gui/icons')
+ result = silx.resources.list_dir("gui/icons")
self.assertTrue(len(result) > 10)
# With prefixed resources
def test_resource_dir_with_prefix(self):
"""Get a resource directory"""
- icons_dirname = silx.resources.resource_filename('silx:gui/icons/')
+ icons_dirname = silx.resources.resource_filename("silx:gui/icons/")
self.assertTrue(os.path.isdir(icons_dirname))
def test_resource_file_with_prefix(self):
"""Get a resource file name"""
- filename = silx.resources.resource_filename('silx:gui/icons/colormap.png')
+ filename = silx.resources.resource_filename("silx:gui/icons/colormap.png")
self.assertTrue(os.path.isfile(filename))
def test_resource_nonexistent_with_prefix(self):
"""Get a non existent resource"""
- filename = silx.resources.resource_filename('silx:non_existent_file.txt')
+ filename = silx.resources.resource_filename("silx:non_existent_file.txt")
self.assertFalse(os.path.exists(filename))
def test_isdir_with_prefix(self):
- self.assertTrue(silx.resources.is_dir('silx:gui/icons'))
+ self.assertTrue(silx.resources.is_dir("silx:gui/icons"))
def test_not_isdir_with_prefix(self):
- self.assertFalse(silx.resources.is_dir('silx:gui/icons/colormap.png'))
+ self.assertFalse(silx.resources.is_dir("silx:gui/icons/colormap.png"))
def test_list_dir_with_prefix(self):
- result = silx.resources.list_dir('silx:gui/icons')
+ result = silx.resources.list_dir("silx:gui/icons")
self.assertTrue(len(result) > 10)
# Test new repository
def test_repository_not_exists(self):
"""The resource from 'test' is available"""
- self.assertRaises(ValueError, silx.resources.resource_filename, 'test:foo.png')
+ self.assertRaises(ValueError, silx.resources.resource_filename, "test:foo.png")
def test_adding_test_directory(self):
"""The resource from 'test' is available"""
- silx.resources.register_resource_directory("test", "silx.test.resources", forced_path=self.tmpDirectory)
- path = silx.resources.resource_filename('test:gui/icons/foo.png')
+ silx.resources.register_resource_directory(
+ "test", "silx.test.resources", forced_path=self.tmpDirectory
+ )
+ path = silx.resources.resource_filename("test:gui/icons/foo.png")
self.assertTrue(os.path.exists(path))
def test_adding_test_directory_no_override(self):
"""The resource from 'silx' is still available"""
- silx.resources.register_resource_directory("test", "silx.test.resources", forced_path=self.tmpDirectory)
- filename1 = silx.resources.resource_filename('gui/icons/close.png')
- filename2 = silx.resources.resource_filename('silx:gui/icons/close.png')
- filename3 = silx.resources.resource_filename('test:gui/icons/close.png')
+ silx.resources.register_resource_directory(
+ "test", "silx.test.resources", forced_path=self.tmpDirectory
+ )
+ filename1 = silx.resources.resource_filename("gui/icons/close.png")
+ filename2 = silx.resources.resource_filename("silx:gui/icons/close.png")
+ filename3 = silx.resources.resource_filename("test:gui/icons/close.png")
self.assertTrue(os.path.isfile(filename1))
self.assertTrue(os.path.isfile(filename2))
self.assertTrue(os.path.isfile(filename3))
@@ -147,31 +150,17 @@ class TestResources(unittest.TestCase):
def test_adding_test_directory_non_existing(self):
"""A resource while not exists in test is not available anyway it exists
in silx"""
- silx.resources.register_resource_directory("test", "silx.test.resources", forced_path=self.tmpDirectory)
+ silx.resources.register_resource_directory(
+ "test", "silx.test.resources", forced_path=self.tmpDirectory
+ )
resource_name = "gui/icons/colormap.png"
- path = silx.resources.resource_filename('test:' + resource_name)
- path2 = silx.resources.resource_filename('silx:' + resource_name)
+ path = silx.resources.resource_filename("test:" + resource_name)
+ path2 = silx.resources.resource_filename("silx:" + resource_name)
self.assertFalse(os.path.exists(path))
self.assertTrue(os.path.exists(path2))
-class TestResourcesWithoutPkgResources(TestResources):
-
- @classmethod
- def setUpClass(cls):
- super(TestResourcesWithoutPkgResources, cls).setUpClass()
- cls._old = silx.resources.pkg_resources
- silx.resources.pkg_resources = None
-
- @classmethod
- def tearDownClass(cls):
- silx.resources.pkg_resources = cls._old
- del cls._old
- super(TestResourcesWithoutPkgResources, cls).tearDownClass()
-
-
class TestResourcesWithCustomDirectory(TestResources):
-
@classmethod
def setUpClass(cls):
super(TestResourcesWithCustomDirectory, cls).setUpClass()
diff --git a/src/silx/test/test_sx.py b/src/silx/test/test_sx.py
index 1107964..1b8449b 100644
--- a/src/silx/test/test_sx.py
+++ b/src/silx/test/test_sx.py
@@ -38,6 +38,7 @@ from silx.gui.colors import Colormap
def sx(qapp):
"""Lazy loading to avoid it to create QApplication before qapp fixture"""
from silx import sx
+
if sx._IS_NOTEBOOK:
pytest.skip("notebook context")
if sx._NO_DISPLAY:
@@ -55,44 +56,43 @@ def test_plot(sx, qapp_utils):
qapp_utils.exposeAndClose(plt)
# y
- plt = sx.plot(y, title='y')
+ plt = sx.plot(y, title="y")
qapp_utils.exposeAndClose(plt)
# y, style
- plt = sx.plot(y, 'blued ', title='y, "blued "')
+ plt = sx.plot(y, "blued ", title='y, "blued "')
qapp_utils.exposeAndClose(plt)
# x, y
- plt = sx.plot(x, y, title='x, y')
+ plt = sx.plot(x, y, title="x, y")
qapp_utils.exposeAndClose(plt)
# x, y, style
- plt = sx.plot(x, y, 'ro-', xlabel='x', title='x, y, "ro-"')
+ plt = sx.plot(x, y, "ro-", xlabel="x", title='x, y, "ro-"')
qapp_utils.exposeAndClose(plt)
# x, y, style, y
- plt = sx.plot(x, y, 'ro-', y ** 2, xlabel='x', ylabel='y',
- title='x, y, "ro-", y ** 2')
+ plt = sx.plot(
+ x, y, "ro-", y**2, xlabel="x", ylabel="y", title='x, y, "ro-", y ** 2'
+ )
qapp_utils.exposeAndClose(plt)
# x, y, style, y, style
- plt = sx.plot(x, y, 'ro-', y ** 2, 'b--',
- title='x, y, "ro-", y ** 2, "b--"')
+ plt = sx.plot(x, y, "ro-", y**2, "b--", title='x, y, "ro-", y ** 2, "b--"')
qapp_utils.exposeAndClose(plt)
# x, y, style, x, y, style
- plt = sx.plot(x, y, 'ro-', x, y ** 2, 'b--',
- title='x, y, "ro-", x, y ** 2, "b--"')
+ plt = sx.plot(x, y, "ro-", x, y**2, "b--", title='x, y, "ro-", x, y ** 2, "b--"')
qapp_utils.exposeAndClose(plt)
# x, y, x, y
- plt = sx.plot(x, y, x, y ** 2, title='x, y, x, y ** 2')
+ plt = sx.plot(x, y, x, y**2, title="x, y, x, y ** 2")
qapp_utils.exposeAndClose(plt)
def test_imshow(sx, qapp_utils):
"""Test imshow function"""
- img = numpy.arange(100.).reshape(10, 10) + 1
+ img = numpy.arange(100.0).reshape(10, 10) + 1
# Nothing
plt = sx.imshow()
@@ -103,34 +103,33 @@ def test_imshow(sx, qapp_utils):
qapp_utils.exposeAndClose(plt)
# image, named cmap
- plt = sx.imshow(img, cmap='jet', title='jet cmap')
+ plt = sx.imshow(img, cmap="jet", title="jet cmap")
qapp_utils.exposeAndClose(plt)
# image, custom colormap
- plt = sx.imshow(img, cmap=Colormap(), title='custom colormap')
+ plt = sx.imshow(img, cmap=Colormap(), title="custom colormap")
qapp_utils.exposeAndClose(plt)
# image, log cmap
- plt = sx.imshow(img, norm='log', title='log cmap')
+ plt = sx.imshow(img, norm="log", title="log cmap")
qapp_utils.exposeAndClose(plt)
# image, fixed range
- plt = sx.imshow(img, vmin=10, vmax=20,
- title='[10,20] cmap')
+ plt = sx.imshow(img, vmin=10, vmax=20, title="[10,20] cmap")
qapp_utils.exposeAndClose(plt)
# image, keep ratio
- plt = sx.imshow(img, aspect=True,
- title='keep ratio')
+ plt = sx.imshow(img, aspect=True, title="keep ratio")
qapp_utils.exposeAndClose(plt)
# image, change origin and scale
- plt = sx.imshow(img, origin=(10, 10), scale=(2, 2),
- title='origin=(10, 10), scale=(2, 2)')
+ plt = sx.imshow(
+ img, origin=(10, 10), scale=(2, 2), title="origin=(10, 10), scale=(2, 2)"
+ )
qapp_utils.exposeAndClose(plt)
# image, origin='lower'
- plt = sx.imshow(img, origin='upper', title='origin="lower"')
+ plt = sx.imshow(img, origin="upper", title='origin="lower"')
qapp_utils.exposeAndClose(plt)
@@ -149,7 +148,7 @@ def test_scatter(sx, qapp_utils):
qapp_utils.exposeAndClose(plt)
# single value
- plt = sx.scatter(x, y, 10.)
+ plt = sx.scatter(x, y, 10.0)
qapp_utils.exposeAndClose(plt)
# set size
@@ -157,7 +156,7 @@ def test_scatter(sx, qapp_utils):
qapp_utils.exposeAndClose(plt)
# set colormap
- plt = sx.scatter(x, y, values, cmap='jet')
+ plt = sx.scatter(x, y, values, cmap="jet")
qapp_utils.exposeAndClose(plt)
# set colormap range
@@ -165,7 +164,7 @@ def test_scatter(sx, qapp_utils):
qapp_utils.exposeAndClose(plt)
# set colormap normalisation
- plt = sx.scatter(x, y, values, norm='log')
+ plt = sx.scatter(x, y, values, norm="log")
qapp_utils.exposeAndClose(plt)
@@ -207,9 +206,8 @@ def test_contour3d(sx, qapp_utils):
pytest.skip("OpenGL context is not valid")
# N contours + color
- colors = ['red', 'green', 'blue']
- window = sx.contour3d(data, copy=False, contours=len(colors),
- color=colors)
+ colors = ["red", "green", "blue"]
+ window = sx.contour3d(data, copy=False, contours=len(colors), color=colors)
isosurfaces = window.getIsosurfaces()
assert len(isosurfaces) == len(colors)
@@ -218,23 +216,23 @@ def test_contour3d(sx, qapp_utils):
# by isolevel, single color
contours = 0.2, 0.5
- window = sx.contour3d(data, copy=False, contours=contours,
- color='yellow')
+ window = sx.contour3d(data, copy=False, contours=contours, color="yellow")
isosurfaces = window.getIsosurfaces()
assert len(isosurfaces) == len(contours)
for iso, level in zip(isosurfaces, contours):
assert iso.getLevel() == level
- assert rgba(iso.getColor()) == rgba('yellow')
+ assert rgba(iso.getColor()) == rgba("yellow")
# Single isolevel, colormap
- window = sx.contour3d(data, copy=False, contours=0.5,
- colormap='gray', vmin=0.6, opacity=0.4)
+ window = sx.contour3d(
+ data, copy=False, contours=0.5, colormap="gray", vmin=0.6, opacity=0.4
+ )
isosurfaces = window.getIsosurfaces()
assert len(isosurfaces) == 1
assert isosurfaces[0].getLevel() == 0.5
- assert rgba(isosurfaces[0].getColor()) == (0., 0., 0., 0.4)
+ assert rgba(isosurfaces[0].getColor()) == (0.0, 0.0, 0.0, 0.4)
@pytest.mark.usefixtures("use_opengl")
@@ -253,12 +251,14 @@ def test_points3d(sx, qapp_utils):
pytest.skip("OpenGL context is not valid")
# 3D positions, values
- window = sx.points3d(x, y, z, values, mode='2dsquare',
- colormap='magma', vmin=0.4, vmax=0.5)
+ window = sx.points3d(
+ x, y, z, values, mode="2dsquare", colormap="magma", vmin=0.4, vmax=0.5
+ )
# 2D positions, no value
window = sx.points3d(x, y)
# 2D positions, values
- window = sx.points3d(x, y, values=values, mode=',',
- colormap='magma', vmin=0.4, vmax=0.5)
+ window = sx.points3d(
+ x, y, values=values, mode=",", colormap="magma", vmin=0.4, vmax=0.5
+ )
diff --git a/src/silx/test/utils.py b/src/silx/test/utils.py
index 5178e4b..72afdf1 100644
--- a/src/silx/test/utils.py
+++ b/src/silx/test/utils.py
@@ -41,15 +41,16 @@ import tempfile
from ..resources import ExternalResources
-utilstest = ExternalResources(project="silx",
- url_base="http://www.silx.org/pub/silx/",
- env_key="SILX_DATA",
- timeout=60)
+utilstest = ExternalResources(
+ project="silx",
+ url_base="http://www.silx.org/pub/silx/",
+ env_key="SILX_DATA",
+ timeout=60,
+)
"This is the instance to be used. Singleton-like feature provided by module"
class _TestOptions(object):
-
def __init__(self):
self.WITH_QT_TEST = True
"""Qt tests are included"""
@@ -82,32 +83,32 @@ class _TestOptions(object):
if parsed_options is not None and not parsed_options.gui:
self.WITH_QT_TEST = False
self.WITH_QT_TEST_REASON = "Skipped by command line"
- elif os.environ.get('WITH_QT_TEST', 'True') == 'False':
+ elif os.environ.get("WITH_QT_TEST", "True") == "False":
self.WITH_QT_TEST = False
self.WITH_QT_TEST_REASON = "Skipped by WITH_QT_TEST env var"
- elif sys.platform.startswith('linux') and not os.environ.get('DISPLAY', ''):
+ elif sys.platform.startswith("linux") and not os.environ.get("DISPLAY", ""):
self.WITH_QT_TEST = False
self.WITH_QT_TEST_REASON = "DISPLAY env variable not set"
if parsed_options is not None and not parsed_options.opencl:
self.WITH_OPENCL_TEST_REASON = "Skipped by command line"
self.WITH_OPENCL_TEST = False
- elif os.environ.get('SILX_OPENCL', 'True') == 'False':
+ elif os.environ.get("SILX_OPENCL", "True") == "False":
self.WITH_OPENCL_TEST_REASON = "Skipped by SILX_OPENCL env var"
self.WITH_OPENCL_TEST = False
if not self.WITH_OPENCL_TEST:
# That's an easy way to skip OpenCL tests
# It disable the use of OpenCL on the full silx project
- os.environ['SILX_OPENCL'] = "False"
+ os.environ["SILX_OPENCL"] = "False"
if parsed_options is not None and not parsed_options.opengl:
self.WITH_GL_TEST = False
self.WITH_GL_TEST_REASON = "Skipped by command line"
- elif os.environ.get('WITH_GL_TEST', 'True') == 'False':
+ elif os.environ.get("WITH_GL_TEST", "True") == "False":
self.WITH_GL_TEST = False
self.WITH_GL_TEST_REASON = "Skipped by WITH_GL_TEST env var"
- elif sys.platform.startswith('linux') and not os.environ.get('DISPLAY', ''):
+ elif sys.platform.startswith("linux") and not os.environ.get("DISPLAY", ""):
self.WITH_GL_TEST = False
self.WITH_GL_TEST_REASON = "DISPLAY env variable not set"
else:
@@ -120,7 +121,7 @@ class _TestOptions(object):
if parsed_options is not None and parsed_options.low_mem:
self.TEST_LOW_MEM = True
self.TEST_LOW_MEM_REASON = "Skipped by command line"
- elif os.environ.get('SILX_TEST_LOW_MEM', 'True') == 'False':
+ elif os.environ.get("SILX_TEST_LOW_MEM", "True") == "False":
self.TEST_LOW_MEM = True
self.TEST_LOW_MEM_REASON = "Skipped by SILX_TEST_LOW_MEM env var"
@@ -137,6 +138,7 @@ class _TestOptions(object):
# Temporary directory context #################################################
+
@contextlib.contextmanager
def temp_dir():
"""with context providing a temporary directory.
@@ -153,7 +155,7 @@ def temp_dir():
# Synthetic data and random noise #############################################
-def add_gaussian_noise(y, stdev=1., mean=0.):
+def add_gaussian_noise(y, stdev=1.0, mean=0.0):
"""Add random gaussian noise to synthetic data.
:param ndarray y: Array of synthetic data
@@ -178,7 +180,7 @@ def add_poisson_noise(y):
return yn
-def add_relative_noise(y, max_noise=5.):
+def add_relative_noise(y, max_noise=5.0):
"""Add relative random noise to synthetic data. The maximum noise level
is given in percents.
@@ -194,4 +196,4 @@ def add_relative_noise(y, max_noise=5.):
"""
noise = max_noise * (2 * numpy.random.random(size=y.size) - 1)
noise.shape = y.shape
- return y * (1. + noise / 100.)
+ return y * (1.0 + noise / 100.0)
diff --git a/src/silx/third_party/EdfFile.py b/src/silx/third_party/EdfFile.py
index 0606d1c..a9e2e1b 100644
--- a/src/silx/third_party/EdfFile.py
+++ b/src/silx/third_party/EdfFile.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2004-2020 European Synchrotron Radiation Facility
+# Copyright (c) 2004-2023 European Synchrotron Radiation Facility
#
# This file is part of the PyMca X-ray Fluorescence Toolkit developed at
# the ESRF by the Software group.
@@ -95,13 +95,16 @@ DEBUG = 0
import sys
import numpy
import os.path
+
try:
import gzip
+
GZIP = True
except:
GZIP = False
try:
import bz2
+
BZ2 = True
except:
BZ2 = False
@@ -110,10 +113,19 @@ MARCCD_SUPPORT = False
PILATUS_CBF_SUPPORT = False
CAN_USE_FASTEDF = False
-# Using local TiffIO
-from . import TiffIO
+from fabio import TiffIO
+
TIFF_SUPPORT = True
+from silx.utils.deprecation import deprecated_warning
+
+deprecated_warning(
+ "Module",
+ "silx.third_party.EdfFile",
+ since_version="2.0.0",
+ replacement="fabio.open and fabio.edfimage.EdfImage",
+)
+
# Constants
HEADER_BLOCK_SIZE = 1024
@@ -128,7 +140,8 @@ STATIC_HEADER_ELEMENTS = (
"Offset_1",
"Offset_2",
"Offset_3",
- "Size")
+ "Size",
+)
STATIC_HEADER_ELEMENTS_CAPS = (
"HEADERID",
@@ -141,7 +154,8 @@ STATIC_HEADER_ELEMENTS_CAPS = (
"OFFSET_1",
"OFFSET_2",
"OFFSET_3",
- "SIZE")
+ "SIZE",
+)
LOWER_CASE = 0
UPPER_CASE = 1
@@ -151,11 +165,10 @@ VALUES = 2
class Image(object):
- """
- """
+ """ """
+
def __init__(self):
- """ Constructor
- """
+ """Constructor"""
self.Header = {}
self.StaticHeader = {}
self.HeaderPosition = 0
@@ -169,10 +182,10 @@ class Image(object):
class EdfFile(object):
- """
- """
+ """ """
+
def __init__(self, FileName, access=None, fastedf=None):
- """ Constructor
+ """Constructor
:param FileName: Name of the file (either existing or to be created)
:type FileName: string
@@ -199,8 +212,7 @@ class EdfFile(object):
else:
self.SysByteOrder = "LowByteFirst"
- if hasattr(FileName, "seek") and\
- hasattr(FileName, "read"):
+ if hasattr(FileName, "seek") and hasattr(FileName, "read"):
# this looks like a file descriptor ...
self.__ownedOpen = False
self.File = FileName
@@ -208,13 +220,13 @@ class EdfFile(object):
self.FileName = self.File.name
except AttributeError:
self.FileName = self.File.filename
- elif FileName.lower().endswith('.gz'):
+ elif FileName.lower().endswith(".gz"):
if GZIP:
self.__ownedOpen = False
self.File = gzip.GzipFile(FileName)
else:
raise IOError("No gzip module support in this system")
- elif FileName.lower().endswith('.bz2'):
+ elif FileName.lower().endswith(".bz2"):
if BZ2:
self.__ownedOpen = False
self.File = bz2.BZ2File(FileName)
@@ -228,8 +240,8 @@ class EdfFile(object):
if access[0].upper() == "R":
if not os.path.isfile(self.FileName):
raise IOError("File %s not found" % FileName)
- if 'b' not in access:
- access += 'b'
+ if "b" not in access:
+ access += "b"
if 1:
if not os.path.isfile(self.FileName):
# write access
@@ -239,13 +251,13 @@ class EdfFile(object):
self.File = open(self.FileName, access)
self.File.seek(0, 0)
return
- if 'b' not in access:
- access += 'b'
+ if "b" not in access:
+ access += "b"
self.File = open(self.FileName, access)
return
else:
if access is None:
- if (os.access(self.FileName, os.W_OK)):
+ if os.access(self.FileName, os.W_OK):
access = "r+b"
else:
access = "rb"
@@ -253,15 +265,17 @@ class EdfFile(object):
self.File.seek(0, 0)
twoChars = self.File.read(2)
tiff = False
- if sys.version < '3.0':
+ if sys.version < "3.0":
if twoChars in ["II", "MM"]:
tiff = True
elif twoChars in [eval('b"II"'), eval('b"MM"')]:
- tiff = True
+ tiff = True
if tiff:
fileExtension = os.path.splitext(self.FileName)[-1]
- if fileExtension.lower() in [".tif", ".tiff"] or\
- sys.version > '2.9':
+ if (
+ fileExtension.lower() in [".tif", ".tiff"]
+ or sys.version > "2.9"
+ ):
if not TIFF_SUPPORT:
raise IOError("TIFF support not implemented")
else:
@@ -274,15 +288,15 @@ class EdfFile(object):
else:
self.MARCCD = True
basename = os.path.basename(FileName).upper()
- if basename.endswith('.CBF'):
+ if basename.endswith(".CBF"):
if not PILATUS_CBF_SUPPORT:
raise IOError("CBF support not implemented")
if twoChars[0] != "{":
self.PILATUS_CBF = True
- elif basename.endswith('.SPE'):
+ elif basename.endswith(".SPE"):
if twoChars[0] != "$":
self.SPE = True
- elif basename.endswith('EDF.GZ') or basename.endswith('CCD.GZ'):
+ elif basename.endswith("EDF.GZ") or basename.endswith("CCD.GZ"):
self.GZIP = True
else:
try:
@@ -312,13 +326,13 @@ class EdfFile(object):
Index = 0
line = self.File.readline()
selectedLines = [""]
- if sys.version > '2.6':
+ if sys.version > "2.6":
selectedLines.append(eval('b""'))
parsingHeader = False
while line not in selectedLines:
# decode to make sure I have character string
# str to make sure python 2.x sees it as string and not unicode
- if sys.version < '3.0':
+ if sys.version < "3.0":
if type(line) != type(str("")):
line = "%s" % line
else:
@@ -326,10 +340,10 @@ class EdfFile(object):
line = str(line.decode())
except UnicodeDecodeError:
try:
- line = str(line.decode('utf-8'))
+ line = str(line.decode("utf-8"))
except UnicodeDecodeError:
try:
- line = str(line.decode('latin-1'))
+ line = str(line.decode("latin-1"))
except UnicodeDecodeError:
line = "%s" % line
if (line.count("{\n") >= 1) or (line.count("{\r\n") >= 1):
@@ -352,14 +366,18 @@ class EdfFile(object):
self.Images[Index].StaticHeader[typeItem] = valueItem
else:
self.Images[Index].Header[typeItem] = valueItem
- if ((line.count("}\n") >= 1) or (line.count("}\r") >= 1)) and (parsingHeader):
+ if ((line.count("}\n") >= 1) or (line.count("}\r") >= 1)) and (
+ parsingHeader
+ ):
parsingHeader = False
# for i in STATIC_HEADER_ELEMENTS_CAPS:
# if self.Images[Index].StaticHeader[i]=="":
# raise "Bad File Format"
self.Images[Index].DataPosition = self.File.tell()
# self.File.seek(int(self.Images[Index].StaticHeader["Size"]), 1)
- StaticPar = SetDictCase(self.Images[Index].StaticHeader, UPPER_CASE, KEYS)
+ StaticPar = SetDictCase(
+ self.Images[Index].StaticHeader, UPPER_CASE, KEYS
+ )
if "SIZE" in StaticPar.keys():
self.Images[Index].Size = int(StaticPar["SIZE"])
if self.Images[Index].Size <= 0:
@@ -403,60 +421,63 @@ class EdfFile(object):
header_keys = []
header = {}
try:
- """ read an adsc header """
+ """read an adsc header"""
line = infile.readline()
bytesread = len(line)
- while '}' not in line:
- if '=' in line:
- (key, val) = line.split('=')
+ while "}" not in line:
+ if "=" in line:
+ (key, val) = line.split("=")
header_keys.append(key.strip())
- header[key.strip()] = val.strip(' ;\n')
+ header[key.strip()] = val.strip(" ;\n")
line = infile.readline()
bytesread = bytesread + len(line)
except:
raise Exception("Error processing adsc header")
# banned by bzip/gzip???
try:
- infile.seek(int(header['HEADER_BYTES']), 0)
+ infile.seek(int(header["HEADER_BYTES"]), 0)
except TypeError:
# Gzipped does not allow a seek and read header is not
# promising to stop in the right place
infile.close()
infile = self._open(fname, "rb")
- infile.read(int(header['HEADER_BYTES']))
+ infile.read(int(header["HEADER_BYTES"]))
binary = infile.read()
infile.close()
# now read the data into the array
- self.Images[Index].Dim1 = int(header['SIZE1'])
- self.Images[Index].Dim2 = int(header['SIZE2'])
+ self.Images[Index].Dim1 = int(header["SIZE1"])
+ self.Images[Index].Dim2 = int(header["SIZE2"])
self.Images[Index].NumDim = 2
- self.Images[Index].DataType = 'UnsignedShort'
+ self.Images[Index].DataType = "UnsignedShort"
try:
self.__data = numpy.reshape(
numpy.copy(numpy.frombuffer(binary, numpy.uint16)),
- (self.Images[Index].Dim2, self.Images[Index].Dim1))
+ (self.Images[Index].Dim2, self.Images[Index].Dim1),
+ )
except ValueError:
- msg = 'Size spec in ADSC-header does not match size of image data field'
+ msg = "Size spec in ADSC-header does not match size of image data field"
raise IOError(msg)
- if 'little' in header['BYTE_ORDER']:
- self.Images[Index].ByteOrder = 'LowByteFirst'
+ if "little" in header["BYTE_ORDER"]:
+ self.Images[Index].ByteOrder = "LowByteFirst"
else:
- self.Images[Index].ByteOrder = 'HighByteFirst'
+ self.Images[Index].ByteOrder = "HighByteFirst"
if self.SysByteOrder.upper() != self.Images[Index].ByteOrder.upper():
self.__data = self.__data.byteswap()
self.Images[Index].ByteOrder = self.SysByteOrder
- self.Images[Index].StaticHeader['Dim_1'] = self.Images[Index].Dim1
- self.Images[Index].StaticHeader['Dim_2'] = self.Images[Index].Dim2
- self.Images[Index].StaticHeader['Offset_1'] = 0
- self.Images[Index].StaticHeader['Offset_2'] = 0
- self.Images[Index].StaticHeader['DataType'] = self.Images[Index].DataType
+ self.Images[Index].StaticHeader["Dim_1"] = self.Images[Index].Dim1
+ self.Images[Index].StaticHeader["Dim_2"] = self.Images[Index].Dim2
+ self.Images[Index].StaticHeader["Offset_1"] = 0
+ self.Images[Index].StaticHeader["Offset_2"] = 0
+ self.Images[Index].StaticHeader["DataType"] = self.Images[Index].DataType
self.__makeSureFileIsClosed()
def _wrapTIFF(self):
- self._wrappedInstance = TiffIO.TiffIO(self.File, cache_length=0, mono_output=True)
+ self._wrappedInstance = TiffIO.TiffIO(
+ self.File, cache_length=0, mono_output=True
+ )
self.NumImages = self._wrappedInstance.getNumberOfImages()
if self.NumImages < 1:
return
@@ -471,17 +492,17 @@ class EdfFile(object):
for Index in range(self.NumImages):
info = self._wrappedInstance.getInfo(Index)
self.Images.append(Image())
- self.Images[Index].Dim1 = info['nRows']
- self.Images[Index].Dim2 = info['nColumns']
+ self.Images[Index].Dim1 = info["nRows"]
+ self.Images[Index].Dim2 = info["nColumns"]
self.Images[Index].NumDim = 2
if data is None:
data = self._wrappedInstance.getData(0)
self.Images[Index].DataType = self.__GetDefaultEdfType__(data.dtype)
- self.Images[Index].StaticHeader['Dim_1'] = self.Images[Index].Dim1
- self.Images[Index].StaticHeader['Dim_2'] = self.Images[Index].Dim2
- self.Images[Index].StaticHeader['Offset_1'] = 0
- self.Images[Index].StaticHeader['Offset_2'] = 0
- self.Images[Index].StaticHeader['DataType'] = self.Images[Index].DataType
+ self.Images[Index].StaticHeader["Dim_1"] = self.Images[Index].Dim1
+ self.Images[Index].StaticHeader["Dim_2"] = self.Images[Index].Dim2
+ self.Images[Index].StaticHeader["Offset_1"] = 0
+ self.Images[Index].StaticHeader["Offset_2"] = 0
+ self.Images[Index].StaticHeader["DataType"] = self.Images[Index].DataType
self.Images[Index].Header.update(info)
def _wrapMarCCD(self):
@@ -491,7 +512,7 @@ class EdfFile(object):
raise NotImplementedError("Look at the module EdfFile from PyMca")
def _wrapSPE(self):
- if 0 and sys.version < '3.0':
+ if 0 and sys.version < "3.0":
self.File.seek(42)
xdim = numpy.int64(numpy.fromfile(self.File, numpy.int16, 1)[0])
self.File.seek(656)
@@ -500,12 +521,15 @@ class EdfFile(object):
self.__data = numpy.fromfile(self.File, numpy.uint16, int(xdim * ydim))
else:
import struct
+
self.File.seek(0)
a = self.File.read()
- xdim = numpy.int64(struct.unpack('<h', a[42:44])[0])
- ydim = numpy.int64(struct.unpack('<h', a[656:658])[0])
- fmt = '<%dH' % int(xdim * ydim)
- self.__data = numpy.array(struct.unpack(fmt, a[4100:int(4100 + int(2 * xdim * ydim))])).astype(numpy.uint16)
+ xdim = numpy.int64(struct.unpack("<h", a[42:44])[0])
+ ydim = numpy.int64(struct.unpack("<h", a[656:658])[0])
+ fmt = "<%dH" % int(xdim * ydim)
+ self.__data = numpy.array(
+ struct.unpack(fmt, a[4100 : int(4100 + int(2 * xdim * ydim))])
+ ).astype(numpy.uint16)
self.__data.shape = ydim, xdim
Index = 0
self.Images.append(Image())
@@ -513,19 +537,18 @@ class EdfFile(object):
self.Images[Index].Dim1 = ydim
self.Images[Index].Dim2 = xdim
self.Images[Index].NumDim = 2
- self.Images[Index].DataType = 'UnsignedShort'
- self.Images[Index].ByteOrder = 'LowByteFirst'
+ self.Images[Index].DataType = "UnsignedShort"
+ self.Images[Index].ByteOrder = "LowByteFirst"
if self.SysByteOrder.upper() != self.Images[Index].ByteOrder.upper():
self.__data = self.__data.byteswap()
- self.Images[Index].StaticHeader['Dim_1'] = self.Images[Index].Dim1
- self.Images[Index].StaticHeader['Dim_2'] = self.Images[Index].Dim2
- self.Images[Index].StaticHeader['Offset_1'] = 0
- self.Images[Index].StaticHeader['Offset_2'] = 0
- self.Images[Index].StaticHeader['DataType'] = self.Images[Index].DataType
+ self.Images[Index].StaticHeader["Dim_1"] = self.Images[Index].Dim1
+ self.Images[Index].StaticHeader["Dim_2"] = self.Images[Index].Dim2
+ self.Images[Index].StaticHeader["Offset_1"] = 0
+ self.Images[Index].StaticHeader["Offset_2"] = 0
+ self.Images[Index].StaticHeader["DataType"] = self.Images[Index].DataType
def GetNumImages(self):
- """ Returns number of images of the object (and associated file)
- """
+ """Returns number of images of the object (and associated file)"""
return self.NumImages
def GetData(self, *var, **kw):
@@ -536,34 +559,34 @@ class EdfFile(object):
self.__makeSureFileIsClosed()
def _GetData(self, Index, DataType="", Pos=None, Size=None):
- """ Returns numpy array with image data
- Index: The zero-based index of the image in the file
- DataType: The edf type of the array to be returnd
- If ommited, it is used the default one for the type
- indicated in the image header
- Attention to the absence of UnsignedShort,
- UnsignedInteger and UnsignedLong types in
- Numpy Python
- Default relation between Edf types and NumPy's typecodes:
- SignedByte int8 b
- UnsignedByte uint8 B
- SignedShort int16 h
- UnsignedShort uint16 H
- SignedInteger int32 i
- UnsignedInteger uint32 I
- SignedLong int32 i
- UnsignedLong uint32 I
- Signed64 int64 (l in 64bit, q in 32 bit)
- Unsigned64 uint64 (L in 64bit, Q in 32 bit)
- FloatValue float32 f
- DoubleValue float64 d
- Pos: Tuple (x) or (x,y) or (x,y,z) that indicates the begining
- of data to be read. If ommited, set to the origin (0),
- (0,0) or (0,0,0)
- Size: Tuple, size of the data to be returned as x) or (x,y) or
- (x,y,z) if ommited, is the distance from Pos to the end.
-
- If Pos and Size not mentioned, returns the whole data.
+ """Returns numpy array with image data
+ Index: The zero-based index of the image in the file
+ DataType: The edf type of the array to be returnd
+ If ommited, it is used the default one for the type
+ indicated in the image header
+ Attention to the absence of UnsignedShort,
+ UnsignedInteger and UnsignedLong types in
+ Numpy Python
+ Default relation between Edf types and NumPy's typecodes:
+ SignedByte int8 b
+ UnsignedByte uint8 B
+ SignedShort int16 h
+ UnsignedShort uint16 H
+ SignedInteger int32 i
+ UnsignedInteger uint32 I
+ SignedLong int32 i
+ UnsignedLong uint32 I
+ Signed64 int64 (l in 64bit, q in 32 bit)
+ Unsigned64 uint64 (L in 64bit, Q in 32 bit)
+ FloatValue float32 f
+ DoubleValue float64 d
+ Pos: Tuple (x) or (x,y) or (x,y,z) that indicates the begining
+ of data to be read. If ommited, set to the origin (0),
+ (0,0) or (0,0,0)
+ Size: Tuple, size of the data to be returned as x) or (x,y) or
+ (x,y,z) if ommited, is the distance from Pos to the end.
+
+ If Pos and Size not mentioned, returns the whole data.
"""
fastedf = self.fastedf
if Index < 0 or Index >= self.NumImages:
@@ -578,7 +601,9 @@ class EdfFile(object):
return data
else:
self.File.seek(self.Images[Index].DataPosition, 0)
- datatype = self.__GetDefaultNumpyType__(self.Images[Index].DataType, index=Index)
+ datatype = self.__GetDefaultNumpyType__(
+ self.Images[Index].DataType, index=Index
+ )
try:
datasize = self.__GetSizeNumpyType__(datatype)
except TypeError:
@@ -587,12 +612,23 @@ class EdfFile(object):
if self.Images[Index].NumDim == 3:
image = self.Images[Index]
sizeToRead = image.Dim1 * image.Dim2 * image.Dim3 * datasize
- Data = numpy.copy(numpy.frombuffer(self.File.read(sizeToRead), datatype))
- Data = numpy.reshape(Data, (self.Images[Index].Dim3, self.Images[Index].Dim2, self.Images[Index].Dim1))
+ Data = numpy.copy(
+ numpy.frombuffer(self.File.read(sizeToRead), datatype)
+ )
+ Data = numpy.reshape(
+ Data,
+ (
+ self.Images[Index].Dim3,
+ self.Images[Index].Dim2,
+ self.Images[Index].Dim1,
+ ),
+ )
elif self.Images[Index].NumDim == 2:
image = self.Images[Index]
sizeToRead = image.Dim1 * image.Dim2 * datasize
- Data = numpy.copy(numpy.frombuffer(self.File.read(sizeToRead), datatype))
+ Data = numpy.copy(
+ numpy.frombuffer(self.File.read(sizeToRead), datatype)
+ )
# print "datatype = ",datatype
# print "Data.type = ", Data.dtype.char
# print "self.Images[Index].DataType ", self.Images[Index].DataType
@@ -600,22 +636,27 @@ class EdfFile(object):
# print "datasize = ",datasize
# print "sizeToRead ",sizeToRead
# print "lenData = ", len(Data)
- Data = numpy.reshape(Data, (self.Images[Index].Dim2, self.Images[Index].Dim1))
+ Data = numpy.reshape(
+ Data, (self.Images[Index].Dim2, self.Images[Index].Dim1)
+ )
elif self.Images[Index].NumDim == 1:
sizeToRead = self.Images[Index].Dim1 * datasize
- Data = numpy.copy(numpy.frombuffer(self.File.read(sizeToRead), datatype))
+ Data = numpy.copy(
+ numpy.frombuffer(self.File.read(sizeToRead), datatype)
+ )
elif self.ADSC or self.MARCCD or self.PILATUS_CBF or self.SPE:
- return self.__data[Pos[1]:(Pos[1] + Size[1]),
- Pos[0]:(Pos[0] + Size[0])]
+ return self.__data[Pos[1] : (Pos[1] + Size[1]), Pos[0] : (Pos[0] + Size[0])]
elif self.TIFF:
data = self._wrappedInstance.getData(Index)
- return data[Pos[1]:(Pos[1] + Size[1]), Pos[0]:(Pos[0] + Size[0])]
+ return data[Pos[1] : (Pos[1] + Size[1]), Pos[0] : (Pos[0] + Size[0])]
elif fastedf and CAN_USE_FASTEDF:
raise NotImplementedError("Look at the module EdfFile from PyMCA")
else:
if fastedf:
print("It could not use fast routines")
- type_ = self.__GetDefaultNumpyType__(self.Images[Index].DataType, index=Index)
+ type_ = self.__GetDefaultNumpyType__(
+ self.Images[Index].DataType, index=Index
+ )
size_pixel = self.__GetSizeNumpyType__(type_)
Data = numpy.array([], type_)
if self.Images[Index].NumDim == 1:
@@ -627,8 +668,12 @@ class EdfFile(object):
Size = list(Size)
if Size[0] == 0:
Size[0] = sizex - Pos[0]
- self.File.seek((Pos[0] * size_pixel) + self.Images[Index].DataPosition, 0)
- Data = numpy.copy(numpy.frombuffer(self.File.read(Size[0] * size_pixel), type_))
+ self.File.seek(
+ (Pos[0] * size_pixel) + self.Images[Index].DataPosition, 0
+ )
+ Data = numpy.copy(
+ numpy.frombuffer(self.File.read(Size[0] * size_pixel), type_)
+ )
elif self.Images[Index].NumDim == 2:
if Pos is None:
Pos = (0, 0)
@@ -645,8 +690,14 @@ class EdfFile(object):
Data = numpy.zeros((Size[1], Size[0]), type_)
dataindex = 0
for y in range(Pos[1], Pos[1] + Size[1]):
- self.File.seek((((y * sizex) + Pos[0]) * size_pixel) + self.Images[Index].DataPosition, 0)
- line = numpy.copy(numpy.frombuffer(self.File.read(Size[0] * size_pixel), type_))
+ self.File.seek(
+ (((y * sizex) + Pos[0]) * size_pixel)
+ + self.Images[Index].DataPosition,
+ 0,
+ )
+ line = numpy.copy(
+ numpy.frombuffer(self.File.read(Size[0] * size_pixel), type_)
+ )
Data[dataindex, :] = line[:]
# Data=numpy.concatenate((Data,line))
dataindex += 1
@@ -659,7 +710,11 @@ class EdfFile(object):
if Size is None:
Size = (0, 0, 0)
Size = list(Size)
- sizex, sizey, sizez = self.Images[Index].Dim1, self.Images[Index].Dim2, self.Images[Index].Dim3
+ sizex, sizey, sizez = (
+ self.Images[Index].Dim1,
+ self.Images[Index].Dim2,
+ self.Images[Index].Dim3,
+ )
if Size[0] == 0:
Size[0] = sizex - Pos[0]
if Size[1] == 0:
@@ -668,8 +723,16 @@ class EdfFile(object):
Size[2] = sizez - Pos[2]
for z in range(Pos[2], Pos[2] + Size[2]):
for y in range(Pos[1], Pos[1] + Size[1]):
- self.File.seek(((((z * sizey + y) * sizex) + Pos[0]) * size_pixel) + self.Images[Index].DataPosition, 0)
- line = numpy.copy(numpy.frombuffer(self.File.read(Size[0] * size_pixel), type_))
+ self.File.seek(
+ ((((z * sizey + y) * sizex) + Pos[0]) * size_pixel)
+ + self.Images[Index].DataPosition,
+ 0,
+ )
+ line = numpy.copy(
+ numpy.frombuffer(
+ self.File.read(Size[0] * size_pixel), type_
+ )
+ )
Data = numpy.concatenate((Data, line))
Data = numpy.reshape(Data, (Size[2], Size[1], Size[0]))
@@ -680,16 +743,18 @@ class EdfFile(object):
return Data
def GetPixel(self, Index, Position):
- """ Returns double value of the pixel, regardless the format of the array
- Index: The zero-based index of the image in the file
- Position: Tuple with the coordinete (x), (x,y) or (x,y,z)
+ """Returns double value of the pixel, regardless the format of the array
+ Index: The zero-based index of the image in the file
+ Position: Tuple with the coordinete (x), (x,y) or (x,y,z)
"""
if Index < 0 or Index >= self.NumImages:
raise ValueError("EdfFile: Index out of limit")
if len(Position) != self.Images[Index].NumDim:
raise ValueError("EdfFile: coordinate with wrong dimension ")
- size_pixel = self.__GetSizeNumpyType__(self.__GetDefaultNumpyType__(self.Images[Index].DataType, index=Index))
+ size_pixel = self.__GetSizeNumpyType__(
+ self.__GetDefaultNumpyType__(self.Images[Index].DataType, index=Index)
+ )
offset = Position[0] * size_pixel
if self.Images[Index].NumDim > 1:
size_row = size_pixel * self.Images[Index].Dim1
@@ -698,20 +763,23 @@ class EdfFile(object):
size_img = size_row * self.Images[Index].Dim2
offset = offset + (Position[2] * size_img)
self.File.seek(self.Images[Index].DataPosition + offset, 0)
- Data = numpy.copy(numpy.frombuffer(self.File.read(size_pixel),
- self.__GetDefaultNumpyType__(self.Images[Index].DataType,
- index=Index)))
+ Data = numpy.copy(
+ numpy.frombuffer(
+ self.File.read(size_pixel),
+ self.__GetDefaultNumpyType__(self.Images[Index].DataType, index=Index),
+ )
+ )
if self.SysByteOrder.upper() != self.Images[Index].ByteOrder.upper():
Data = Data.byteswap()
Data = self.__SetDataType__(Data, "DoubleValue")
return Data[0]
def GetHeader(self, Index):
- """ Returns dictionary with image header fields.
- Does not include the basic fields (static) defined by data shape,
- type and file position. These are get with GetStaticHeader
- method.
- Index: The zero-based index of the image in the file
+ """Returns dictionary with image header fields.
+ Does not include the basic fields (static) defined by data shape,
+ type and file position. These are get with GetStaticHeader
+ method.
+ Index: The zero-based index of the image in the file
"""
if Index < 0 or Index >= self.NumImages:
raise ValueError("Index out of limit")
@@ -722,10 +790,10 @@ class EdfFile(object):
return ret
def GetStaticHeader(self, Index):
- """ Returns dictionary with static parameters
- Data format and file position dependent information
- (dim1,dim2,size,datatype,byteorder,headerId,Image)
- Index: The zero-based index of the image in the file
+ """Returns dictionary with static parameters
+ Data format and file position dependent information
+ (dim1,dim2,size,datatype,byteorder,headerId,Image)
+ Index: The zero-based index of the image in the file
"""
if Index < 0 or Index >= self.NumImages:
raise ValueError("Index out of limit")
@@ -743,37 +811,37 @@ class EdfFile(object):
self.__makeSureFileIsClosed()
def _WriteImage(self, Header, Data, Append=1, DataType="", ByteOrder=""):
- """ Writes image to the file.
- Header: Dictionary containing the non-static header
- information (static information is generated
- according to position of image and data format
- Append: If equals to 0, overwrites the file. Otherwise, appends
- to the end of the file
- DataType: The data type to be saved to the file:
- SignedByte
- UnsignedByte
- SignedShort
- UnsignedShort
- SignedInteger
- UnsignedInteger
- SignedLong
- UnsignedLong
- FloatValue
- DoubleValue
- Default: according to Data array typecode:
- 1: SignedByte
- b: UnsignedByte
- s: SignedShort
- w: UnsignedShort
- i: SignedInteger
- l: SignedLong
- u: UnsignedLong
- f: FloatValue
- d: DoubleValue
- ByteOrder: Byte order of the data in file:
- HighByteFirst
- LowByteFirst
- Default: system's byte order
+ """Writes image to the file.
+ Header: Dictionary containing the non-static header
+ information (static information is generated
+ according to position of image and data format
+ Append: If equals to 0, overwrites the file. Otherwise, appends
+ to the end of the file
+ DataType: The data type to be saved to the file:
+ SignedByte
+ UnsignedByte
+ SignedShort
+ UnsignedShort
+ SignedInteger
+ UnsignedInteger
+ SignedLong
+ UnsignedLong
+ FloatValue
+ DoubleValue
+ Default: according to Data array typecode:
+ 1: SignedByte
+ b: UnsignedByte
+ s: SignedShort
+ w: UnsignedShort
+ i: SignedInteger
+ l: SignedLong
+ u: UnsignedLong
+ f: FloatValue
+ d: DoubleValue
+ ByteOrder: Byte order of the data in file:
+ HighByteFirst
+ LowByteFirst
+ Default: system's byte order
"""
if Append == 0:
self.File.truncate(0)
@@ -804,7 +872,9 @@ class EdfFile(object):
self.Images[Index].StaticHeader["Dim_1"] = "%d" % self.Images[Index].Dim1
self.Images[Index].StaticHeader["Dim_2"] = "%d" % self.Images[Index].Dim2
self.Images[Index].StaticHeader["Dim_3"] = "%d" % self.Images[Index].Dim3
- self.Images[Index].Size = Data.shape[0] * Data.shape[1] * Data.shape[2] * scalarSize
+ self.Images[Index].Size = (
+ Data.shape[0] * Data.shape[1] * Data.shape[2] * scalarSize
+ )
self.Images[Index].NumDim = 3
elif len(Data.shape) > 3:
raise TypeError("EdfFile: Data dimension not suported")
@@ -822,7 +892,9 @@ class EdfFile(object):
self.Images[Index].StaticHeader["Size"] = "%d" % self.Images[Index].Size
self.Images[Index].StaticHeader["Image"] = Index + 1
- self.Images[Index].StaticHeader["HeaderID"] = "EH:%06d:000000:000000" % self.Images[Index].StaticHeader["Image"]
+ self.Images[Index].StaticHeader["HeaderID"] = (
+ "EH:%06d:000000:000000" % self.Images[Index].StaticHeader["Image"]
+ )
self.Images[Index].StaticHeader["ByteOrder"] = self.Images[Index].ByteOrder
self.Images[Index].StaticHeader["DataType"] = self.Images[Index].DataType
@@ -831,11 +903,15 @@ class EdfFile(object):
StrHeader = "{\n"
for i in STATIC_HEADER_ELEMENTS:
if i in self.Images[Index].StaticHeader.keys():
- StrHeader = StrHeader + ("%s = %s ;\n" % (i, self.Images[Index].StaticHeader[i]))
+ StrHeader = StrHeader + (
+ "%s = %s ;\n" % (i, self.Images[Index].StaticHeader[i])
+ )
for i in Header.keys():
StrHeader = StrHeader + ("%s = %s ;\n" % (i, Header[i]))
self.Images[Index].Header[i] = Header[i]
- newsize = (((len(StrHeader) + 1) // HEADER_BLOCK_SIZE) + 1) * HEADER_BLOCK_SIZE - 2
+ newsize = (
+ ((len(StrHeader) + 1) // HEADER_BLOCK_SIZE) + 1
+ ) * HEADER_BLOCK_SIZE - 2
newsize = int(newsize)
StrHeader = StrHeader.ljust(newsize)
StrHeader = StrHeader + "}\n"
@@ -890,13 +966,11 @@ class EdfFile(object):
return
def __GetDefaultNumpyType__(self, EdfType, index=None):
- """ Internal method: returns NumPy type according to Edf type
- """
+ """Internal method: returns NumPy type according to Edf type"""
return self.GetDefaultNumpyType(EdfType, index)
def __GetDefaultEdfType__(self, NumpyType):
- """ Internal method: returns Edf type according Numpy type
- """
+ """Internal method: returns Edf type according Numpy type"""
if NumpyType in ["b", numpy.int8]:
return "SignedByte"
elif NumpyType in ["B", numpy.uint8]:
@@ -910,12 +984,12 @@ class EdfFile(object):
elif NumpyType in ["I", numpy.uint32]:
return "UnsignedInteger"
elif NumpyType == "l":
- if sys.platform == 'linux2':
+ if sys.platform == "linux2":
return "Signed64"
else:
return "SignedLong"
elif NumpyType == "L":
- if sys.platform == 'linux2':
+ if sys.platform == "linux2":
return "Unsigned64"
else:
return "UnsignedLong"
@@ -931,8 +1005,7 @@ class EdfFile(object):
raise TypeError("unknown NumpyType %s" % NumpyType)
def __GetSizeNumpyType__(self, NumpyType):
- """ Internal method: returns size of NumPy's Array Types
- """
+ """Internal method: returns size of NumPy's Array Types"""
if NumpyType in ["b", numpy.int8]:
return 1
elif NumpyType in ["B", numpy.uint8]:
@@ -946,15 +1019,15 @@ class EdfFile(object):
elif NumpyType in ["I", numpy.uint32]:
return 4
elif NumpyType == "l":
- if sys.platform == 'linux2':
- return 8 # 64 bit
+ if sys.platform == "linux2":
+ return 8 # 64 bit
else:
- return 4 # 32 bit
+ return 4 # 32 bit
elif NumpyType == "L":
- if sys.platform == 'linux2':
- return 8 # 64 bit
+ if sys.platform == "linux2":
+ return 8 # 64 bit
else:
- return 4 # 32 bit
+ return 4 # 32 bit
elif NumpyType in ["f", numpy.float32]:
return 4
elif NumpyType in ["d", numpy.float64]:
@@ -971,8 +1044,7 @@ class EdfFile(object):
raise TypeError("unknown NumpyType %s" % NumpyType)
def __SetDataType__(self, Array, DataType):
- """ Internal method: array type convertion
- """
+ """Internal method: array type convertion"""
# AVOID problems not using FromEdfType= Array.dtype.char
FromEdfType = Array.dtype
ToEdfType = self.__GetDefaultNumpyType__(DataType)
@@ -988,14 +1060,13 @@ class EdfFile(object):
pass
def GetDefaultNumpyType(self, EdfType, index=None):
- """ Returns NumPy type according Edf type
- """
+ """Returns NumPy type according Edf type"""
if index is None:
return GetDefaultNumpyType(EdfType)
EdfType = EdfType.upper()
- if EdfType in ['SIGNED64']:
+ if EdfType in ["SIGNED64"]:
return numpy.int64
- if EdfType in ['UNSIGNED64']:
+ if EdfType in ["UNSIGNED64"]:
return numpy.uint64
if EdfType in ["SIGNEDLONG", "UNSIGNEDLONG"]:
dim1 = 1
@@ -1027,11 +1098,10 @@ class EdfFile(object):
def GetDefaultNumpyType(EdfType):
- """ Returns NumPy type according Edf type
- """
+ """Returns NumPy type according Edf type"""
EdfType = EdfType.upper()
if EdfType == "SIGNEDBYTE":
- return numpy.int8 # "b"
+ return numpy.int8 # "b"
elif EdfType == "UNSIGNEDBYTE":
return numpy.uint8 # "B"
elif EdfType == "SIGNEDSHORT":
@@ -1061,10 +1131,10 @@ def GetDefaultNumpyType(EdfType):
def SetDictCase(Dict, Case, Flag):
- """ Returns dictionary with keys and/or values converted into upper or lowercase
- Dict: input dictionary
- Case: LOWER_CASE, UPPER_CASE
- Flag: KEYS, VALUES or KEYS | VALUES
+ """Returns dictionary with keys and/or values converted into upper or lowercase
+ Dict: input dictionary
+ Case: LOWER_CASE, UPPER_CASE
+ Flag: KEYS, VALUES or KEYS | VALUES
"""
newdict = {}
for i in Dict.keys():
@@ -1086,9 +1156,9 @@ def SetDictCase(Dict, Case, Flag):
def GetRegion(Arr, Pos, Size):
"""Returns array with refion of Arr.
- Arr must be 1d, 2d or 3d
- Pos and Size are tuples in the format (x) or (x,y) or (x,y,z)
- Both parameters must have the same size as the dimention of Arr
+ Arr must be 1d, 2d or 3d
+ Pos and Size are tuples in the format (x) or (x,y) or (x,y,z)
+ Both parameters must have the same size as the dimention of Arr
"""
Dim = len(Arr.shape)
if len(Pos) != Dim:
@@ -1096,12 +1166,12 @@ def GetRegion(Arr, Pos, Size):
if len(Size) != Dim:
return None
- if (Dim == 1):
+ if Dim == 1:
SizeX = Size[0]
if SizeX == 0:
SizeX = Arr.shape[0] - Pos[0]
ArrRet = numpy.take(Arr, range(Pos[0], Pos[0] + SizeX))
- elif (Dim == 2):
+ elif Dim == 2:
SizeX = Size[0]
SizeY = Size[1]
if SizeX == 0:
@@ -1110,7 +1180,7 @@ def GetRegion(Arr, Pos, Size):
SizeY = Arr.shape[0] - Pos[1]
ArrRet = numpy.take(Arr, range(Pos[1], Pos[1] + SizeY))
ArrRet = numpy.take(ArrRet, range(Pos[0], Pos[0] + SizeX), 1)
- elif (Dim == 3):
+ elif Dim == 3:
SizeX = Size[0]
SizeY = Size[1]
SizeZ = Size[2]
@@ -1154,11 +1224,18 @@ if __name__ == "__main__":
x = numpy.arange(100)
x.shape = 5, 20
- for item in ["SignedByte", "UnsignedByte",
- "SignedShort", "UnsignedShort",
- "SignedLong", "UnsignedLong",
- "Signed64", "Unsigned64",
- "FloatValue", "DoubleValue"]:
+ for item in [
+ "SignedByte",
+ "UnsignedByte",
+ "SignedShort",
+ "UnsignedShort",
+ "SignedLong",
+ "UnsignedLong",
+ "Signed64",
+ "Unsigned64",
+ "FloatValue",
+ "DoubleValue",
+ ]:
fname = item + ".edf"
if os.path.exists(fname):
os.remove(fname)
@@ -1200,7 +1277,7 @@ if __name__ == "__main__":
exe.WriteImage({}, la, 0, "")
# Appends short array with new header items
- exe.WriteImage({'Name': 'Alexandre', 'Date': '16/07/2001'}, sa)
+ exe.WriteImage({"Name": "Alexandre", "Date": "16/07/2001"}, sa)
# Appends short array, in Edf type unsigned
exe.WriteImage({}, sa, DataType="UnsignedShort")
diff --git a/src/silx/third_party/TiffIO.py b/src/silx/third_party/TiffIO.py
index 7526a75..b9dd829 100644
--- a/src/silx/third_party/TiffIO.py
+++ b/src/silx/third_party/TiffIO.py
@@ -1,1268 +1,10 @@
-# /*##########################################################################
-#
-# The PyMca X-Ray Fluorescence Toolkit
-#
-# Copyright (c) 2004-2020 European Synchrotron Radiation Facility
-#
-# This file is part of the PyMca X-ray Fluorescence Toolkit developed at
-# the ESRF by the Software group.
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-__author__ = "V.A. Sole - ESRF Data Analysis"
-__contact__ = "sole@esrf.fr"
-__license__ = "MIT"
-__copyright__ = "European Synchrotron Radiation Facility, Grenoble, France"
+from silx.utils.deprecation import deprecated_warning
-import sys
-import os
-import struct
-import numpy
-
-DEBUG = 0
-ALLOW_MULTIPLE_STRIPS = False
-
-TAG_ID = { 256:"NumberOfColumns", # S or L ImageWidth
- 257:"NumberOfRows", # S or L ImageHeight
- 258:"BitsPerSample", # S Number of bits per component
- 259:"Compression", # SHORT (1 - NoCompression, ...
- 262:"PhotometricInterpretation", # SHORT (0 - WhiteIsZero, 1 -BlackIsZero, 2 - RGB, 3 - Palette color
- 270:"ImageDescription", # ASCII
- 273:"StripOffsets", # S or L, for each strip, the byte offset of the strip
- 277:"SamplesPerPixel", # SHORT (>=3) only for RGB images
- 278:"RowsPerStrip", # S or L, number of rows in each back may be not for the last
- 279:"StripByteCounts", # S or L, The number of bytes in the strip AFTER any compression
- 305:"Software", # ASCII
- 306:"Date", # ASCII
- 320:"Colormap", # Colormap of Palette-color Images
- 339:"SampleFormat", # SHORT Interpretation of data in each pixel
- }
-
-#TILES ARE TO BE SUPPORTED TOO ...
-TAG_NUMBER_OF_COLUMNS = 256
-TAG_NUMBER_OF_ROWS = 257
-TAG_BITS_PER_SAMPLE = 258
-TAG_PHOTOMETRIC_INTERPRETATION = 262
-TAG_COMPRESSION = 259
-TAG_IMAGE_DESCRIPTION = 270
-TAG_STRIP_OFFSETS = 273
-TAG_SAMPLES_PER_PIXEL = 277
-TAG_ROWS_PER_STRIP = 278
-TAG_STRIP_BYTE_COUNTS = 279
-TAG_SOFTWARE = 305
-TAG_DATE = 306
-TAG_COLORMAP = 320
-TAG_SAMPLE_FORMAT = 339
-
-FIELD_TYPE = {1:('BYTE', "B"),
- 2:('ASCII', "s"), #string ending with binary zero
- 3:('SHORT', "H"),
- 4:('LONG', "I"),
- 5:('RATIONAL',"II"),
- 6:('SBYTE', "b"),
- 7:('UNDEFINED',"B"),
- 8:('SSHORT', "h"),
- 9:('SLONG', "i"),
- 10:('SRATIONAL',"ii"),
- 11:('FLOAT', "f"),
- 12:('DOUBLE', "d")}
-
-FIELD_TYPE_OUT = { 'B': 1,
- 's': 2,
- 'H': 3,
- 'I': 4,
- 'II': 5,
- 'b': 6,
- 'h': 8,
- 'i': 9,
- 'ii': 10,
- 'f': 11,
- 'd': 12}
-
-#sample formats (http://www.awaresystems.be/imaging/tiff/tiffflags/sampleformat.html)
-SAMPLE_FORMAT_UINT = 1
-SAMPLE_FORMAT_INT = 2
-SAMPLE_FORMAT_FLOAT = 3 #floating point
-SAMPLE_FORMAT_VOID = 4 #undefined data, usually assumed UINT
-SAMPLE_FORMAT_COMPLEXINT = 5
-SAMPLE_FORMAT_COMPLEXIEEEFP = 6
-
-
-
-class TiffIO(object):
- def __init__(self, filename, mode=None, cache_length=20, mono_output=False):
- if mode is None:
- mode = 'rb'
- if 'b' not in mode:
- mode = mode + 'b'
- if 'a' in mode.lower():
- raise IOError("Mode %s makes no sense on TIFF files. Consider 'rb+'" % mode)
- if ('w' in mode):
- if '+' not in mode:
- mode += '+'
-
- if hasattr(filename, "seek") and\
- hasattr(filename, "read"):
- fd = filename
- self._access = None
- else:
- #the b is needed for windows and python 3
- fd = open(filename, mode)
- self._access = mode
-
- self._initInternalVariables(fd)
- self._maxImageCacheLength = cache_length
- self._forceMonoOutput = mono_output
-
- def _initInternalVariables(self, fd=None):
- if fd is None:
- fd = self.fd
- else:
- self.fd = fd
- # read the order
- fd.seek(0)
- order = fd.read(2).decode()
- if len(order):
- if order == "II":
- #intel, little endian
- fileOrder = "little"
- self._structChar = '<'
- elif order == "MM":
- #motorola, high endian
- fileOrder = "big"
- self._structChar = '>'
- else:
- raise IOError("File is not a Mar CCD file, nor a TIFF file")
- a = fd.read(2)
- fortyTwo = struct.unpack(self._structChar+"H",a)[0]
- if fortyTwo != 42:
- raise IOError("Invalid TIFF version %d" % fortyTwo)
- else:
- if DEBUG:
- print("VALID TIFF VERSION")
- if sys.byteorder != fileOrder:
- swap = True
- else:
- swap = False
- else:
- if sys.byteorder == "little":
- self._structChar = '<'
- else:
- self._structChar = '>'
- swap = False
- self._swap = swap
- self._IFD = []
- self._imageDataCacheIndex = []
- self._imageDataCache = []
- self._imageInfoCacheIndex = []
- self._imageInfoCache = []
- self.getImageFileDirectories(fd)
-
- def __makeSureFileIsOpen(self):
- if not self.fd.closed:
- return
- if DEBUG:
- print("Reopening closed file")
- fileName = self.fd.name
- if self._access is None:
- #we do not own the file
- #open in read mode
- newFile = open(fileName,'rb')
- else:
- newFile = open(fileName, self._access)
- self.fd = newFile
-
- def __makeSureFileIsClosed(self):
- if self._access is None:
- #we do not own the file
- if DEBUG:
- print("Not closing not owned file")
- return
-
- if not self.fd.closed:
- self.fd.close()
-
- def close(self):
- return self.__makeSureFileIsClosed()
-
- def getNumberOfImages(self):
- #update for the case someone has done anything?
- self._updateIFD()
- return len(self._IFD)
-
- def _updateIFD(self):
- self.__makeSureFileIsOpen()
- self.getImageFileDirectories()
- self.__makeSureFileIsClosed()
-
- def getImageFileDirectories(self, fd=None):
- if fd is None:
- fd = self.fd
- else:
- self.fd = fd
- st = self._structChar
- fd.seek(4)
- self._IFD = []
- nImages = 0
- fmt = st + 'I'
- inStr = fd.read(struct.calcsize(fmt))
- if not len(inStr):
- offsetToIFD = 0
- else:
- offsetToIFD = struct.unpack(fmt, inStr)[0]
- if DEBUG:
- print("Offset to first IFD = %d" % offsetToIFD)
- while offsetToIFD != 0:
- self._IFD.append(offsetToIFD)
- nImages += 1
- fd.seek(offsetToIFD)
- fmt = st + 'H'
- numberOfDirectoryEntries = struct.unpack(fmt,fd.read(struct.calcsize(fmt)))[0]
- if DEBUG:
- print("Number of directory entries = %d" % numberOfDirectoryEntries)
-
- fmt = st + 'I'
- fd.seek(offsetToIFD + 2 + 12 * numberOfDirectoryEntries)
- offsetToIFD = struct.unpack(fmt,fd.read(struct.calcsize(fmt)))[0]
- if DEBUG:
- print("Next Offset to IFD = %d" % offsetToIFD)
- #offsetToIFD = 0
- if DEBUG:
- print("Number of images found = %d" % nImages)
- return nImages
-
- def _parseImageFileDirectory(self, nImage):
- offsetToIFD = self._IFD[nImage]
- st = self._structChar
- fd = self.fd
- fd.seek(offsetToIFD)
- fmt = st + 'H'
- numberOfDirectoryEntries = struct.unpack(fmt,fd.read(struct.calcsize(fmt)))[0]
- if DEBUG:
- print("Number of directory entries = %d" % numberOfDirectoryEntries)
-
- fmt = st + 'HHI4s'
- tagIDList = []
- fieldTypeList = []
- nValuesList = []
- valueOffsetList = []
- for i in range(numberOfDirectoryEntries):
- tagID, fieldType, nValues, valueOffset = struct.unpack(fmt, fd.read(12))
- tagIDList.append(tagID)
- fieldTypeList.append(fieldType)
- nValuesList.append(nValues)
- if nValues == 1:
- ftype, vfmt = FIELD_TYPE[fieldType]
- if ftype not in ['ASCII', 'RATIONAL', 'SRATIONAL']:
- vfmt = st + vfmt
- actualValue = struct.unpack(vfmt, valueOffset[0: struct.calcsize(vfmt)])[0]
- valueOffsetList.append(actualValue)
- else:
- valueOffsetList.append(valueOffset)
- elif (nValues < 5) and (fieldType == 2):
- ftype, vfmt = FIELD_TYPE[fieldType]
- vfmt = st + "%d%s" % (nValues,vfmt)
- actualValue = struct.unpack(vfmt, valueOffset[0: struct.calcsize(vfmt)])[0]
- valueOffsetList.append(actualValue)
- else:
- valueOffsetList.append(valueOffset)
- if DEBUG:
- if tagID in TAG_ID:
- print("tagID = %s" % TAG_ID[tagID])
- else:
- print("tagID = %d" % tagID)
- print("fieldType = %s" % FIELD_TYPE[fieldType][0])
- print("nValues = %d" % nValues)
- #if nValues == 1:
- # print("valueOffset = %s" % valueOffset)
- return tagIDList, fieldTypeList, nValuesList, valueOffsetList
-
-
-
- def _readIFDEntry(self, tag, tagIDList, fieldTypeList, nValuesList, valueOffsetList):
- fd = self.fd
- st = self._structChar
- idx = tagIDList.index(tag)
- nValues = nValuesList[idx]
- output = []
- ftype, vfmt = FIELD_TYPE[fieldTypeList[idx]]
- vfmt = st + "%d%s" % (nValues, vfmt)
- requestedBytes = struct.calcsize(vfmt)
- if nValues == 1:
- output.append(valueOffsetList[idx])
- elif requestedBytes < 5:
- output.append(valueOffsetList[idx])
- else:
- fd.seek(struct.unpack(st+"I", valueOffsetList[idx])[0])
- output = struct.unpack(vfmt, fd.read(requestedBytes))
- return output
-
- def getData(self, nImage, **kw):
- if nImage >= len(self._IFD):
- #update prior to raise an index error error
- self._updateIFD()
- return self._readImage(nImage, **kw)
-
- def getImage(self, nImage):
- return self.getData(nImage)
-
- def getInfo(self, nImage, **kw):
- if nImage >= len(self._IFD):
- #update prior to raise an index error error
- self._updateIFD()
- # current = self._IFD[nImage]
- return self._readInfo(nImage)
-
- def _readInfo(self, nImage, close=True):
- if nImage in self._imageInfoCacheIndex:
- if DEBUG:
- print("Reading info from cache")
- return self._imageInfoCache[self._imageInfoCacheIndex.index(nImage)]
-
- #read the header
- self.__makeSureFileIsOpen()
- tagIDList, fieldTypeList, nValuesList, valueOffsetList = self._parseImageFileDirectory(nImage)
-
- #rows and columns
- nColumns = valueOffsetList[tagIDList.index(TAG_NUMBER_OF_COLUMNS)]
- nRows = valueOffsetList[tagIDList.index(TAG_NUMBER_OF_ROWS)]
-
- #bits per sample
- idx = tagIDList.index(TAG_BITS_PER_SAMPLE)
- nBits = valueOffsetList[idx]
- if nValuesList[idx] != 1:
- #this happens with RGB and friends, nBits is not a single value
- nBits = self._readIFDEntry(TAG_BITS_PER_SAMPLE,
- tagIDList, fieldTypeList, nValuesList, valueOffsetList)
-
-
- if TAG_COLORMAP in tagIDList:
- idx = tagIDList.index(TAG_COLORMAP)
- tmpColormap = self._readIFDEntry(TAG_COLORMAP,
- tagIDList, fieldTypeList, nValuesList, valueOffsetList)
- if max(tmpColormap) > 255:
- tmpColormap = numpy.array(tmpColormap, dtype=numpy.uint16)
- tmpColormap = (tmpColormap/256.).astype(numpy.uint8)
- else:
- tmpColormap = numpy.array(tmpColormap, dtype=numpy.uint8)
- tmpColormap.shape = 3, -1
- colormap = numpy.zeros((tmpColormap.shape[-1], 3), tmpColormap.dtype)
- colormap[:,:] = tmpColormap.T
- tmpColormap = None
- else:
- colormap = None
-
- #sample format
- if TAG_SAMPLE_FORMAT in tagIDList:
- sampleFormat = valueOffsetList[tagIDList.index(TAG_SAMPLE_FORMAT)]
- else:
- #set to unknown
- sampleFormat = SAMPLE_FORMAT_VOID
-
- # compression
- compression = False
- compression_type = 1
- if TAG_COMPRESSION in tagIDList:
- compression_type = valueOffsetList[tagIDList.index(TAG_COMPRESSION)]
- if compression_type == 1:
- compression = False
- else:
- compression = True
-
- #photometric interpretation
- interpretation = 1
- if TAG_PHOTOMETRIC_INTERPRETATION in tagIDList:
- interpretation = valueOffsetList[tagIDList.index(TAG_PHOTOMETRIC_INTERPRETATION)]
- else:
- print("WARNING: Non standard TIFF. Photometric interpretation TAG missing")
- helpString = ""
- if sys.version > '2.6':
- helpString = eval('b""')
-
- if TAG_IMAGE_DESCRIPTION in tagIDList:
- imageDescription = self._readIFDEntry(TAG_IMAGE_DESCRIPTION,
- tagIDList, fieldTypeList, nValuesList, valueOffsetList)
- if type(imageDescription) in [type([1]), type((1,))]:
- imageDescription =helpString.join(imageDescription)
- else:
- imageDescription = "%d/%d" % (nImage+1, len(self._IFD))
-
- if sys.version < '3.0':
- defaultSoftware = "Unknown Software"
- else:
- defaultSoftware = bytes("Unknown Software",
- encoding='utf-8')
- if TAG_SOFTWARE in tagIDList:
- software = self._readIFDEntry(TAG_SOFTWARE,
- tagIDList, fieldTypeList, nValuesList, valueOffsetList)
- if type(software) in [type([1]), type((1,))]:
- software =helpString.join(software)
- else:
- software = defaultSoftware
-
- if software == defaultSoftware:
- try:
- if sys.version < '3.0':
- if imageDescription.upper().startswith("IMAGEJ"):
- software = imageDescription.split("=")[0]
- else:
- tmpString = imageDescription.decode()
- if tmpString.upper().startswith("IMAGEJ"):
- software = bytes(tmpString.split("=")[0],
- encoding='utf-8')
- except:
- pass
-
- if TAG_DATE in tagIDList:
- date = self._readIFDEntry(TAG_DATE,
- tagIDList, fieldTypeList, nValuesList, valueOffsetList)
- if type(date) in [type([1]), type((1,))]:
- date =helpString.join(date)
- else:
- date = "Unknown Date"
-
- stripOffsets = self._readIFDEntry(TAG_STRIP_OFFSETS,
- tagIDList, fieldTypeList, nValuesList, valueOffsetList)
- if TAG_ROWS_PER_STRIP in tagIDList:
- rowsPerStrip = self._readIFDEntry(TAG_ROWS_PER_STRIP,
- tagIDList, fieldTypeList, nValuesList, valueOffsetList)[0]
- else:
- rowsPerStrip = nRows
- print("WARNING: Non standard TIFF. Rows per strip TAG missing")
-
- if TAG_STRIP_BYTE_COUNTS in tagIDList:
- stripByteCounts = self._readIFDEntry(TAG_STRIP_BYTE_COUNTS,
- tagIDList, fieldTypeList, nValuesList, valueOffsetList)
- else:
- print("WARNING: Non standard TIFF. Strip byte counts TAG missing")
- if hasattr(nBits, 'index'):
- expectedSum = 0
- for n in nBits:
- expectedSum += int(nRows * nColumns * n / 8)
- else:
- expectedSum = int(nRows * nColumns * nBits / 8)
- stripByteCounts = [expectedSum]
-
- if close:
- self.__makeSureFileIsClosed()
-
- if self._forceMonoOutput and (interpretation > 1):
- #color image but asked monochrome output
- nBits = 32
- colormap = None
- sampleFormat = SAMPLE_FORMAT_FLOAT
- interpretation = 1
- #we cannot rely on any cache in this case
- useInfoCache = False
- if DEBUG:
- print("FORCED MONO")
- else:
- useInfoCache = True
-
- info = {}
- info["nRows"] = nRows
- info["nColumns"] = nColumns
- info["nBits"] = nBits
- info["compression"] = compression
- info["compression_type"] = compression_type
- info["imageDescription"] = imageDescription
- info["stripOffsets"] = stripOffsets #This contains the file offsets to the data positions
- info["rowsPerStrip"] = rowsPerStrip
- info["stripByteCounts"] = stripByteCounts #bytes in strip since I do not support compression
- info["software"] = software
- info["date"] = date
- info["colormap"] = colormap
- info["sampleFormat"] = sampleFormat
- info["photometricInterpretation"] = interpretation
- infoDict = {}
- if sys.version < '3.0':
- testString = 'PyMca'
- else:
- testString = eval('b"PyMca"')
- if software.startswith(testString):
- #str to make sure python 2.x sees it as string and not unicode
- if sys.version < '3.0':
- descriptionString = imageDescription
- else:
- descriptionString = str(imageDescription.decode())
- #interpret the image description in terms of supplied
- #information at writing time
- items = descriptionString.split('=')
- for i in range(int(len(items)/2)):
- key = "%s" % items[i*2]
- #get rid of the \n at the end of the value
- value = "%s" % items[i*2+1][:-1]
- infoDict[key] = value
- info['info'] = infoDict
-
- if (self._maxImageCacheLength > 0) and useInfoCache:
- self._imageInfoCacheIndex.insert(0,nImage)
- self._imageInfoCache.insert(0, info)
- if len(self._imageInfoCacheIndex) > self._maxImageCacheLength:
- self._imageInfoCacheIndex = self._imageInfoCacheIndex[:self._maxImageCacheLength]
- self._imageInfoCache = self._imageInfoCache[:self._maxImageCacheLength]
- return info
-
- def _readImage(self, nImage, **kw):
- if DEBUG:
- print("Reading image %d" % nImage)
- if 'close' in kw:
- close = kw['close']
- else:
- close = True
- rowMin = kw.get('rowMin', None)
- rowMax = kw.get('rowMax', None)
- if nImage in self._imageDataCacheIndex:
- if DEBUG:
- print("Reading image data from cache")
- return self._imageDataCache[self._imageDataCacheIndex.index(nImage)]
-
- self.__makeSureFileIsOpen()
- if self._forceMonoOutput:
- oldMono = True
- else:
- oldMono = False
- try:
- self._forceMonoOutput = False
- info = self._readInfo(nImage, close=False)
- self._forceMonoOutput = oldMono
- except:
- self._forceMonoOutput = oldMono
- raise
- compression = info['compression']
- compression_type = info['compression_type']
- if compression:
- if compression_type != 32773:
- raise IOError("Compressed TIFF images not supported except packbits")
- else:
- #PackBits compression
- if DEBUG:
- print("Using PackBits compression")
-
- interpretation = info["photometricInterpretation"]
- if interpretation == 2:
- #RGB
- pass
- #raise IOError("RGB Image. Only grayscale images supported")
- elif interpretation == 3:
- #Palette Color Image
- pass
- #raise IOError("Palette-color Image. Only grayscale images supported")
- elif interpretation > 2:
- #Palette Color Image
- raise IOError("Only grayscale images supported")
-
- nRows = info["nRows"]
- nColumns = info["nColumns"]
- nBits = info["nBits"]
- colormap = info["colormap"]
- sampleFormat = info["sampleFormat"]
-
- if rowMin is None:
- rowMin = 0
-
- if rowMax is None:
- rowMax = nRows - 1
-
- if rowMin < 0:
- rowMin = nRows - rowMin
-
- if rowMax < 0:
- rowMax = nRows - rowMax
-
- if rowMax < rowMin:
- txt = "Max Row smaller than Min Row. Reverse selection not supported"
- raise NotImplementedError(txt)
-
- if rowMin >= nRows:
- raise IndexError("Image only has %d rows" % nRows)
-
- if rowMax >= nRows:
- raise IndexError("Image only has %d rows" % nRows)
-
- if sampleFormat == SAMPLE_FORMAT_FLOAT:
- if nBits == 32:
- dtype = numpy.float32
- elif nBits == 64:
- dtype = numpy.float64
- else:
- raise ValueError("Unsupported number of bits for a float: %d" % nBits)
- elif sampleFormat in [SAMPLE_FORMAT_UINT, SAMPLE_FORMAT_VOID]:
- if nBits in [8, (8, 8, 8), [8, 8, 8]]:
- dtype = numpy.uint8
- elif nBits in [16, (16, 16, 16), [16, 16, 16]]:
- dtype = numpy.uint16
- elif nBits in [32, (32, 32, 32), [32, 32, 32]]:
- dtype = numpy.uint32
- elif nBits in [64, (64, 64, 64), [64, 64, 64]]:
- dtype = numpy.uint64
- else:
- raise ValueError("Unsupported number of bits for unsigned int: %s" % (nBits,))
- elif sampleFormat == SAMPLE_FORMAT_INT:
- if nBits in [8, (8, 8, 8), [8, 8, 8]]:
- dtype = numpy.int8
- elif nBits in [16, (16, 16, 16), [16, 16, 16]]:
- dtype = numpy.int16
- elif nBits in [32, (32, 32, 32), [32, 32, 32]]:
- dtype = numpy.int32
- elif nBits in [64, (64, 64, 64), [64, 64, 64]]:
- dtype = numpy.int64
- else:
- raise ValueError("Unsupported number of bits for signed int: %s" % (nBits,))
- else:
- raise ValueError("Unsupported combination. Bits = %s Format = %d" % (nBits, sampleFormat))
- if hasattr(nBits, 'index'):
- image = numpy.zeros((nRows, nColumns, len(nBits)), dtype=dtype)
- elif colormap is not None:
- #should I use colormap dtype?
- image = numpy.zeros((nRows, nColumns, 3), dtype=dtype)
- else:
- image = numpy.zeros((nRows, nColumns), dtype=dtype)
-
- fd = self.fd
- st = self._structChar
- stripOffsets = info["stripOffsets"] #This contains the file offsets to the data positions
- rowsPerStrip = info["rowsPerStrip"]
- stripByteCounts = info["stripByteCounts"] #bytes in strip since I do not support compression
-
- rowStart = 0
- if len(stripOffsets) == 1:
- bytesPerRow = int(stripByteCounts[0]/rowsPerStrip)
- if nRows == rowsPerStrip:
- actualBytesPerRow = int(image.nbytes/nRows)
- if actualBytesPerRow != bytesPerRow:
- print("Warning: Bogus StripByteCounts information")
- bytesPerRow = actualBytesPerRow
- fd.seek(stripOffsets[0] + rowMin * bytesPerRow)
- nBytes = (rowMax-rowMin+1) * bytesPerRow
- if self._swap:
- readout = numpy.copy(numpy.frombuffer(fd.read(nBytes), dtype)).byteswap()
- else:
- readout = numpy.copy(numpy.frombuffer(fd.read(nBytes), dtype))
- if hasattr(nBits, 'index'):
- readout.shape = -1, nColumns, len(nBits)
- elif info['colormap'] is not None:
- readout = colormap[readout]
- else:
- readout.shape = -1, nColumns
- image[rowMin:rowMax+1, :] = readout
- else:
- for i in range(len(stripOffsets)):
- #the amount of rows
- nRowsToRead = rowsPerStrip
- rowEnd = int(min(rowStart+nRowsToRead, nRows))
- if rowEnd < rowMin:
- rowStart += nRowsToRead
- continue
- if (rowStart > rowMax):
- break
- #we are in position
- fd.seek(stripOffsets[i])
- #the amount of bytes to read
- nBytes = stripByteCounts[i]
- if compression_type == 32773:
- try:
- bufferBytes = bytes()
- except:
- #python 2.5 ...
- bufferBytes = ""
- #packBits
- readBytes = 0
- #intermediate buffer
- tmpBuffer = fd.read(nBytes)
- while readBytes < nBytes:
- n = struct.unpack('b', tmpBuffer[readBytes:(readBytes+1)])[0]
- readBytes += 1
- if n >= 0:
- #should I prevent reading more than the
- #length of the chain? Let's python raise
- #the exception...
- bufferBytes += tmpBuffer[readBytes:\
- readBytes+(n+1)]
- readBytes += (n+1)
- elif n > -128:
- bufferBytes += (-n+1) * tmpBuffer[readBytes:(readBytes+1)]
- readBytes += 1
- else:
- #if read -128 ignore the byte
- continue
- if self._swap:
- readout = numpy.copy(numpy.frombuffer(bufferBytes, dtype)).byteswap()
- else:
- readout = numpy.copy(numpy.frombuffer(bufferBytes, dtype))
- if hasattr(nBits, 'index'):
- readout.shape = -1, nColumns, len(nBits)
- elif info['colormap'] is not None:
- readout = colormap[readout]
- readout.shape = -1, nColumns, 3
- else:
- readout.shape = -1, nColumns
- image[rowStart:rowEnd, :] = readout
- else:
- if 1:
- #use numpy
- if self._swap:
- readout = numpy.copy(numpy.frombuffer(fd.read(nBytes), dtype)).byteswap()
- else:
- readout = numpy.copy(numpy.frombuffer(fd.read(nBytes), dtype))
- if hasattr(nBits, 'index'):
- readout.shape = -1, nColumns, len(nBits)
- elif colormap is not None:
- readout = colormap[readout]
- readout.shape = -1, nColumns, 3
- else:
- readout.shape = -1, nColumns
- image[rowStart:rowEnd, :] = readout
- else:
- #using struct
- readout = numpy.array(struct.unpack(st+"%df" % int(nBytes/4), fd.read(nBytes)),
- dtype=dtype)
- if hasattr(nBits, 'index'):
- readout.shape = -1, nColumns, len(nBits)
- elif colormap is not None:
- readout = colormap[readout]
- readout.shape = -1, nColumns, 3
- else:
- readout.shape = -1, nColumns
- image[rowStart:rowEnd, :] = readout
- rowStart += nRowsToRead
- if close:
- self.__makeSureFileIsClosed()
-
- if len(image.shape) == 3:
- #color image
- if self._forceMonoOutput:
- #color image, convert to monochrome
- image = (image[:,:,0] * 0.114 +\
- image[:,:,1] * 0.587 +\
- image[:,:,2] * 0.299).astype(numpy.float32)
-
- if (rowMin == 0) and (rowMax == (nRows-1)):
- self._imageDataCacheIndex.insert(0,nImage)
- self._imageDataCache.insert(0, image)
- if len(self._imageDataCacheIndex) > self._maxImageCacheLength:
- self._imageDataCacheIndex = self._imageDataCacheIndex[:self._maxImageCacheLength]
- self._imageDataCache = self._imageDataCache[:self._maxImageCacheLength]
-
- return image
-
- def writeImage(self, image0, info=None, software=None, date=None):
- if software is None:
- software = 'PyMca.TiffIO'
- #if date is None:
- # date = time.ctime()
-
- self.__makeSureFileIsOpen()
- fd = self.fd
- #prior to do anything, perform some tests
- if not len(image0.shape):
- raise ValueError("Empty image")
- if len(image0.shape) == 1:
- #get a different view
- image = image0[:]
- image.shape = 1, -1
- else:
- image = image0
-
- if image.dtype == numpy.float64:
- image = image.astype(numpy.float32)
- fd.seek(0)
- mode = fd.mode
- name = fd.name
- if 'w' in mode:
- #we have to overwrite the file
- self.__makeSureFileIsClosed()
- fd = None
- if os.path.exists(name):
- os.remove(name)
- fd = open(name, mode='wb+')
- self._initEmptyFile(fd)
- self.fd = fd
-
- #read the file size
- self.__makeSureFileIsOpen()
- fd = self.fd
- fd.seek(0, os.SEEK_END)
- endOfFile = fd.tell()
- if fd.tell() == 0:
- self._initEmptyFile(fd)
- fd.seek(0, os.SEEK_END)
- endOfFile = fd.tell()
-
- #init internal variables
- self._initInternalVariables(fd)
- st = self._structChar
-
- #get the image file directories
- nImages = self.getImageFileDirectories()
- if DEBUG:
- print("File contains %d images" % nImages)
- if nImages == 0:
- fd.seek(4)
- fmt = st + 'I'
- fd.write(struct.pack(fmt, endOfFile))
- else:
- fd.seek(self._IFD[-1])
- fmt = st + 'H'
- numberOfDirectoryEntries = struct.unpack(fmt,fd.read(struct.calcsize(fmt)))[0]
- fmt = st + 'I'
- pos = self._IFD[-1] + 2 + 12 * numberOfDirectoryEntries
- fd.seek(pos)
- fmt = st + 'I'
- fd.write(struct.pack(fmt, endOfFile))
- fd.flush()
-
- #and we can write at the end of the file, find out the file length
- fd.seek(0, os.SEEK_END)
-
- #get the description information from the input information
- if info is None:
- description = info
- else:
- description = "%s" % ""
- for key in info.keys():
- description += "%s=%s\n" % (key, info[key])
-
- #get the image file directory
- outputIFD = self._getOutputIFD(image, description=description,
- software=software,
- date=date)
-
- #write the new IFD
- fd.write(outputIFD)
-
- #write the image
- if self._swap:
- fd.write(image.byteswap().tobytes())
- else:
- fd.write(image.tobytes())
-
- fd.flush()
- self.fd=fd
- self.__makeSureFileIsClosed()
-
- def _initEmptyFile(self, fd=None):
- if fd is None:
- fd = self.fd
- if sys.byteorder == "little":
- order = "II"
- #intel, little endian
- fileOrder = "little"
- self._structChar = '<'
- else:
- order = "MM"
- #motorola, high endian
- fileOrder = "big"
- self._structChar = '>'
- st = self._structChar
- if fileOrder == sys.byteorder:
- self._swap = False
- else:
- self._swap = True
- fd.seek(0)
- if sys.version < '3.0':
- fd.write(struct.pack(st+'2s', order))
- fd.write(struct.pack(st+'H', 42))
- fd.write(struct.pack(st+'I', 0))
- else:
- fd.write(struct.pack(st+'2s', bytes(order,'utf-8')))
- fd.write(struct.pack(st+'H', 42))
- fd.write(struct.pack(st+'I', 0))
- fd.flush()
-
- def _getOutputIFD(self, image, description=None, software=None, date=None):
- #the tags have to be in order
- #the very minimum is
- #256:"NumberOfColumns", # S or L ImageWidth
- #257:"NumberOfRows", # S or L ImageHeight
- #258:"BitsPerSample", # S Number of bits per component
- #259:"Compression", # SHORT (1 - NoCompression, ...
- #262:"PhotometricInterpretation", # SHORT (0 - WhiteIsZero, 1 -BlackIsZero, 2 - RGB, 3 - Palette color
- #270:"ImageDescription", # ASCII
- #273:"StripOffsets", # S or L, for each strip, the byte offset of the strip
- #277:"SamplesPerPixel", # SHORT (>=3) only for RGB images
- #278:"RowsPerStrip", # S or L, number of rows in each back may be not for the last
- #279:"StripByteCounts", # S or L, The number of bytes in the strip AFTER any compression
- #305:"Software", # ASCII
- #306:"Date", # ASCII
- #339:"SampleFormat", # SHORT Interpretation of data in each pixel
-
- nDirectoryEntries = 9
- imageDescription = None
- if description is not None:
- descriptionLength = len(description)
- while descriptionLength < 4:
- description = description + " "
- descriptionLength = len(description)
- if sys.version >= '3.0':
- description = bytes(description, 'utf-8')
- elif type(description) != type(""):
- try:
- description = description.decode('utf-8')
- except UnicodeDecodeError:
- try:
- description = description.decode('latin-1')
- except UnicodeDecodeError:
- description = "%s" % description
- if sys.version > '2.6':
- description=description.encode('utf-8', errors="ignore")
- description = "%s" % description
- descriptionLength = len(description)
- imageDescription = struct.pack("%ds" % descriptionLength, description)
- nDirectoryEntries += 1
-
- #software
- if software is not None:
- softwareLength = len(software)
- while softwareLength < 4:
- software = software + " "
- softwareLength = len(software)
- if sys.version >= '3.0':
- software = bytes(software, 'utf-8')
- softwarePackedString = struct.pack("%ds" % softwareLength, software)
- nDirectoryEntries += 1
- else:
- softwareLength = 0
-
- if date is not None:
- dateLength = len(date)
- if sys.version >= '3.0':
- date = bytes(date, 'utf-8')
- datePackedString = struct.pack("%ds" % dateLength, date)
- dateLength = len(datePackedString)
- nDirectoryEntries += 1
- else:
- dateLength = 0
-
- if len(image.shape) == 2:
- nRows, nColumns = image.shape
- nChannels = 1
- elif len(image.shape) == 3:
- nRows, nColumns, nChannels = image.shape
- else:
- raise RuntimeError("Image does not have the right shape")
- dtype = image.dtype
- bitsPerSample = int(dtype.str[-1]) * 8
-
- #only uncompressed data
- compression = 1
-
- #interpretation, black is zero
- if nChannels == 1:
- interpretation = 1
- bitsPerSampleLength = 0
- elif nChannels == 3:
- interpretation = 2
- bitsPerSampleLength = 3 * 2 # To store 3 shorts
- nDirectoryEntries += 1 # For SamplesPerPixel
- else:
- raise RuntimeError(
- "Image with %d color channel(s) not supported" % nChannels)
-
- #image description
- if imageDescription is not None:
- descriptionLength = len(imageDescription)
- else:
- descriptionLength = 0
-
- #strip offsets
- #we are putting them after the directory and the directory is
- #at the end of the file
- self.fd.seek(0, os.SEEK_END)
- endOfFile = self.fd.tell()
- if endOfFile == 0:
- #empty file
- endOfFile = 8
-
- #rows per strip
- if ALLOW_MULTIPLE_STRIPS:
- #try to segment the image in several pieces
- if not (nRows % 4):
- rowsPerStrip = int(nRows/4)
- elif not (nRows % 10):
- rowsPerStrip = int(nRows/10)
- elif not (nRows % 8):
- rowsPerStrip = int(nRows/8)
- elif not (nRows % 4):
- rowsPerStrip = int(nRows/4)
- elif not (nRows % 2):
- rowsPerStrip = int(nRows/2)
- else:
- rowsPerStrip = nRows
- else:
- rowsPerStrip = nRows
-
- #stripByteCounts
- stripByteCounts = int(nColumns * rowsPerStrip *
- bitsPerSample * nChannels / 8)
-
- if descriptionLength > 4:
- stripOffsets0 = endOfFile + dateLength + descriptionLength +\
- 2 + 12 * nDirectoryEntries + 4
- else:
- stripOffsets0 = endOfFile + dateLength + \
- 2 + 12 * nDirectoryEntries + 4
-
- if softwareLength > 4:
- stripOffsets0 += softwareLength
-
- stripOffsets0 += bitsPerSampleLength
-
- stripOffsets = [stripOffsets0]
- stripOffsetsLength = 0
- stripOffsetsString = None
-
- st = self._structChar
-
- if rowsPerStrip != nRows:
- nStripOffsets = int(nRows/rowsPerStrip)
- fmt = st + 'I'
- stripOffsetsLength = struct.calcsize(fmt) * nStripOffsets
- stripOffsets0 += stripOffsetsLength
- #the length for the stripByteCounts will be the same
- stripOffsets0 += stripOffsetsLength
- stripOffsets = []
- for i in range(nStripOffsets):
- value = stripOffsets0 + i * stripByteCounts
- stripOffsets.append(value)
- if i == 0:
- stripOffsetsString = struct.pack(fmt, value)
- stripByteCountsString = struct.pack(fmt, stripByteCounts)
- else:
- stripOffsetsString += struct.pack(fmt, value)
- stripByteCountsString += struct.pack(fmt, stripByteCounts)
-
- if DEBUG:
- print("IMAGE WILL START AT %d" % stripOffsets[0])
-
- #sample format
- if dtype in [numpy.float32, numpy.float64] or\
- dtype.str[-2] == 'f':
- sampleFormat = SAMPLE_FORMAT_FLOAT
- elif dtype in [numpy.uint8, numpy.uint16, numpy.uint32, numpy.uint64]:
- sampleFormat = SAMPLE_FORMAT_UINT
- elif dtype in [numpy.int8, numpy.int16, numpy.int32, numpy.int64]:
- sampleFormat = SAMPLE_FORMAT_INT
- else:
- raise ValueError("Unsupported data type %s" % dtype)
-
- info = {}
- info["nColumns"] = nColumns
- info["nRows"] = nRows
- info["nBits"] = bitsPerSample
- info["compression"] = compression
- info["photometricInterpretation"] = interpretation
- info["stripOffsets"] = stripOffsets
- if interpretation == 2:
- info["samplesPerPixel"] = 3 # No support for extra samples
- info["rowsPerStrip"] = rowsPerStrip
- info["stripByteCounts"] = stripByteCounts
- info["date"] = date
- info["sampleFormat"] = sampleFormat
-
- outputIFD = ""
- if sys.version > '2.6':
- outputIFD = eval('b""')
-
- fmt = st + "H"
- outputIFD += struct.pack(fmt, nDirectoryEntries)
-
- fmt = st + "HHII"
- outputIFD += struct.pack(fmt, TAG_NUMBER_OF_COLUMNS,
- FIELD_TYPE_OUT['I'],
- 1,
- info["nColumns"])
- outputIFD += struct.pack(fmt, TAG_NUMBER_OF_ROWS,
- FIELD_TYPE_OUT['I'],
- 1,
- info["nRows"])
-
- if info["photometricInterpretation"] == 1:
- fmt = st + 'HHIHH'
- outputIFD += struct.pack(fmt, TAG_BITS_PER_SAMPLE,
- FIELD_TYPE_OUT['H'],
- 1,
- info["nBits"], 0)
- elif info["photometricInterpretation"] == 2:
- fmt = st + 'HHII'
- outputIFD += struct.pack(fmt, TAG_BITS_PER_SAMPLE,
- FIELD_TYPE_OUT['H'],
- 3,
- info["stripOffsets"][0] - \
- 2 * stripOffsetsLength - \
- descriptionLength - \
- dateLength - \
- softwareLength - \
- bitsPerSampleLength)
- else:
- raise RuntimeError("Unsupported photometric interpretation")
-
- fmt = st + 'HHIHH'
- outputIFD += struct.pack(fmt, TAG_COMPRESSION,
- FIELD_TYPE_OUT['H'],
- 1,
- info["compression"],0)
- fmt = st + 'HHIHH'
- outputIFD += struct.pack(fmt, TAG_PHOTOMETRIC_INTERPRETATION,
- FIELD_TYPE_OUT['H'],
- 1,
- info["photometricInterpretation"],0)
-
- if imageDescription is not None:
- descriptionLength = len(imageDescription)
- if descriptionLength > 4:
- fmt = st + 'HHII'
- outputIFD += struct.pack(fmt, TAG_IMAGE_DESCRIPTION,
- FIELD_TYPE_OUT['s'],
- descriptionLength,
- info["stripOffsets"][0]-\
- 2*stripOffsetsLength-\
- descriptionLength)
- else:
- #it has to have length 4
- fmt = st + 'HHI%ds' % descriptionLength
- outputIFD += struct.pack(fmt, TAG_IMAGE_DESCRIPTION,
- FIELD_TYPE_OUT['s'],
- descriptionLength,
- description)
-
- if len(stripOffsets) == 1:
- fmt = st + 'HHII'
- outputIFD += struct.pack(fmt, TAG_STRIP_OFFSETS,
- FIELD_TYPE_OUT['I'],
- 1,
- info["stripOffsets"][0])
- else:
- fmt = st + 'HHII'
- outputIFD += struct.pack(fmt, TAG_STRIP_OFFSETS,
- FIELD_TYPE_OUT['I'],
- len(stripOffsets),
- info["stripOffsets"][0]-2*stripOffsetsLength)
-
- if info["photometricInterpretation"] == 2:
- fmt = st + 'HHIHH'
- outputIFD += struct.pack(fmt, TAG_SAMPLES_PER_PIXEL,
- FIELD_TYPE_OUT['H'],
- 1,
- info["samplesPerPixel"], 0)
-
- fmt = st + 'HHII'
- outputIFD += struct.pack(fmt, TAG_ROWS_PER_STRIP,
- FIELD_TYPE_OUT['I'],
- 1,
- info["rowsPerStrip"])
-
- if len(stripOffsets) == 1:
- fmt = st + 'HHII'
- outputIFD += struct.pack(fmt, TAG_STRIP_BYTE_COUNTS,
- FIELD_TYPE_OUT['I'],
- 1,
- info["stripByteCounts"])
- else:
- fmt = st + 'HHII'
- outputIFD += struct.pack(fmt, TAG_STRIP_BYTE_COUNTS,
- FIELD_TYPE_OUT['I'],
- len(stripOffsets),
- info["stripOffsets"][0]-stripOffsetsLength)
-
- if software is not None:
- if softwareLength > 4:
- fmt = st + 'HHII'
- outputIFD += struct.pack(fmt, TAG_SOFTWARE,
- FIELD_TYPE_OUT['s'],
- softwareLength,
- info["stripOffsets"][0]-\
- 2*stripOffsetsLength-\
- descriptionLength-softwareLength-dateLength)
- else:
- #it has to have length 4
- fmt = st + 'HHI%ds' % softwareLength
- outputIFD += struct.pack(fmt, TAG_SOFTWARE,
- FIELD_TYPE_OUT['s'],
- softwareLength,
- softwarePackedString)
-
- if date is not None:
- fmt = st + 'HHII'
- outputIFD += struct.pack(fmt, TAG_DATE,
- FIELD_TYPE_OUT['s'],
- dateLength,
- info["stripOffsets"][0]-\
- 2*stripOffsetsLength-\
- descriptionLength-dateLength)
-
- fmt = st + 'HHIHH'
- outputIFD += struct.pack(fmt, TAG_SAMPLE_FORMAT,
- FIELD_TYPE_OUT['H'],
- 1,
- info["sampleFormat"],0)
- fmt = st + 'I'
- outputIFD += struct.pack(fmt, 0)
-
- if info["photometricInterpretation"] == 2:
- outputIFD += struct.pack('HHH', info["nBits"],
- info["nBits"], info["nBits"])
-
- if softwareLength > 4:
- outputIFD += softwarePackedString
-
- if date is not None:
- outputIFD += datePackedString
-
- if imageDescription is not None:
- if descriptionLength > 4:
- outputIFD += imageDescription
-
- if stripOffsetsString is not None:
- outputIFD += stripOffsetsString
- outputIFD += stripByteCountsString
-
- return outputIFD
-
-
-if __name__ == "__main__":
- filename = sys.argv[1]
- dtype = numpy.uint16
- if not os.path.exists(filename):
- print("Testing file creation")
- tif = TiffIO(filename, mode = 'wb+')
- data = numpy.arange(10000).astype(dtype)
- data.shape = 100, 100
- tif.writeImage(data, info={'Title':'1st'})
- tif = None
- if os.path.exists(filename):
- print("Testing image appending")
- tif = TiffIO(filename, mode = 'rb+')
- tif.writeImage((data*2).astype(dtype), info={'Title':'2nd'})
- tif = None
- tif = TiffIO(filename)
- print("Number of images = %d" % tif.getNumberOfImages())
- for i in range(tif.getNumberOfImages()):
- info = tif.getInfo(i)
- for key in info:
- if key not in ["colormap"]:
- print("%s = %s" % (key, info[key]))
- elif info['colormap'] is not None:
- print("RED %s = %s" % (key, info[key][0:10, 0]))
- print("GREEN %s = %s" % (key, info[key][0:10, 1]))
- print("BLUE %s = %s" % (key, info[key][0:10, 2]))
- data = tif.getImage(i)[0, 0:10]
- print("data [0, 0:10] = ", data)
+deprecated_warning(
+ "Module",
+ "silx.third_party.TiffIO",
+ since_version="2.0.0",
+ replacement="fabio.TiffIO",
+)
+from fabio.TiffIO import *
diff --git a/src/silx/third_party/__init__.py b/src/silx/third_party/__init__.py
index 529ae3f..388430b 100644
--- a/src/silx/third_party/__init__.py
+++ b/src/silx/third_party/__init__.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2015-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2015-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -30,3 +30,12 @@ This is meant for internal use.
__authors__ = ["Jérôme Kieffer"]
__license__ = "MIT"
__date__ = "09/10/2015"
+
+from silx.utils.deprecation import deprecated_warning
+
+deprecated_warning(
+ "Module",
+ "silx.third_party",
+ since_version="2.0.0",
+ replacement="fabio",
+)
diff --git a/src/silx/utils/ExternalResources.py b/src/silx/utils/ExternalResources.py
index 429314e..8172b66 100644
--- a/src/silx/utils/ExternalResources.py
+++ b/src/silx/utils/ExternalResources.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -29,16 +29,20 @@ __license__ = "MIT"
__date__ = "21/12/2021"
-import os
-import threading
+import hashlib
import json
import logging
+import os
+import sys
+import tarfile
+import threading
import tempfile
import unittest
import urllib.request
import urllib.error
-import hashlib
-from collections import OrderedDict
+import zipfile
+
+
logger = logging.getLogger(__name__)
@@ -48,10 +52,7 @@ class ExternalResources(object):
"""
- def __init__(self, project,
- url_base,
- env_key=None,
- timeout=60):
+ def __init__(self, project, url_base, env_key=None, timeout=60):
"""Constructor of the class
:param str project: name of the project, like "silx"
@@ -86,6 +87,7 @@ class ExternalResources(object):
if data_home is None:
try:
import getpass
+
name = getpass.getuser()
except Exception:
if "getlogin" in dir(os):
@@ -110,7 +112,7 @@ class ExternalResources(object):
return
h = self.hash()
if filename is not None:
- fullfilename = os.path.join(self.data_home, filename)
+ fullfilename = os.path.join(self.data_home, filename)
if os.path.exists(fullfilename):
with open(fullfilename, "rb") as fd:
data = fd.read()
@@ -130,8 +132,8 @@ class ExternalResources(object):
jdata = json.load(f)
if isinstance(jdata, dict):
self.all_data = jdata
- else:
- #recalculate the hash only if the data was stored as a list
+ else:
+ # recalculate the hash only if the data was stored as a list
self.all_data = {k: self.get_hash(k) for k in jdata}
self.save_json()
self._initialized = True
@@ -154,14 +156,17 @@ class ExternalResources(object):
fullfilename = os.path.abspath(os.path.join(self.data_home, filename))
if not os.path.isfile(fullfilename):
- logger.debug("Trying to download image %s, timeout set to %ss",
- filename, self.timeout)
+ logger.debug(
+ "Trying to download image %s, timeout set to %ss",
+ filename,
+ self.timeout,
+ )
dictProxies = {}
if "http_proxy" in os.environ:
- dictProxies['http'] = os.environ["http_proxy"]
- dictProxies['https'] = os.environ["http_proxy"]
+ dictProxies["http"] = os.environ["http_proxy"]
+ dictProxies["https"] = os.environ["http_proxy"]
if "https_proxy" in os.environ:
- dictProxies['https'] = os.environ["https_proxy"]
+ dictProxies["https"] = os.environ["https_proxy"]
if dictProxies:
proxy_handler = urllib.request.ProxyHandler(dictProxies)
opener = urllib.request.build_opener(proxy_handler).open
@@ -170,8 +175,9 @@ class ExternalResources(object):
logger.debug("wget %s/%s", self.url_base, filename)
try:
- data = opener("%s/%s" % (self.url_base, filename),
- data=None, timeout=self.timeout).read()
+ data = opener(
+ "%s/%s" % (self.url_base, filename), data=None, timeout=self.timeout
+ ).read()
logger.info("Image %s successfully downloaded.", filename)
except urllib.error.URLError:
raise unittest.SkipTest("network unreachable.")
@@ -184,8 +190,11 @@ class ExternalResources(object):
with open(fullfilename, mode="wb") as outfile:
outfile.write(data)
except IOError:
- raise IOError("unable to write downloaded \
- data to disk at %s" % self.data_home)
+ raise IOError(
+ "unable to write downloaded \
+ data to disk at %s"
+ % self.data_home
+ )
if not os.path.isfile(fullfilename):
raise RuntimeError(
@@ -193,7 +202,9 @@ class ExternalResources(object):
If you are behind a firewall, please set both environment variable http_proxy and https_proxy.
This even works under windows !
Otherwise please try to download the images manually from
- %s/%s""" % (filename, self.url_base, filename))
+ %s/%s"""
+ % (filename, self.url_base, filename)
+ )
else:
self.all_data[filename] = self.get_hash(data=data)
self.save_json()
@@ -207,13 +218,13 @@ class ExternalResources(object):
self.all_data.pop(filename)
os.unlink(fullfilename)
return self.getfile(filename)
-
+
return fullfilename
def save_json(self):
image_list = list(self.all_data.keys())
image_list.sort()
- dico = OrderedDict([(i, self.all_data[i]) for i in image_list])
+ dico = dict([(i, self.all_data[i]) for i in image_list])
try:
with open(self.testdata, "w") as fp:
json.dump(dico, fp, indent=4)
@@ -229,26 +240,25 @@ class ExternalResources(object):
:return: list of files with their full path.
"""
lodn = dirname.lower()
- if (lodn.endswith("tar") or lodn.endswith("tgz") or
- lodn.endswith("tbz2") or lodn.endswith("tar.gz") or
- lodn.endswith("tar.bz2")):
- import tarfile
- engine = tarfile.TarFile.open
- elif lodn.endswith("zip"):
- import zipfile
- engine = zipfile.ZipFile
- else:
- raise RuntimeError("Unsupported archive format. Only tar and zip "
- "are currently supported")
full_path = self.getfile(dirname)
- with engine(full_path, mode="r") as fd:
- output = os.path.join(self.data_home, dirname + "__content")
- fd.extractall(output)
- if lodn.endswith("zip"):
- result = [os.path.join(output, i) for i in fd.namelist()]
- else:
- result = [os.path.join(output, i) for i in fd.getnames()]
- return result
+ output = os.path.join(self.data_home, dirname + "__content")
+
+ if lodn.endswith(("tar", "tgz", "tbz2", "tar.gz", "tar.bz2")):
+ with tarfile.TarFile.open(full_path, mode="r") as fd:
+ # Avoid unsafe filter deprecation warning during transistion of mode change
+ if (3, 12) <= sys.version_info < (3, 14):
+ fd.extraction_filter = tarfile.data_filter
+ fd.extractall(output)
+ return [os.path.join(output, i) for i in fd.getnames()]
+
+ if lodn.endswith("zip"):
+ with zipfile.ZipFile(full_path, mode="r") as fd:
+ fd.extractall(output)
+ return [os.path.join(output, i) for i in fd.namelist()]
+
+ raise RuntimeError(
+ "Unsupported archive format. Only tar and zip " "are currently supported"
+ )
def get_file_and_repack(self, filename):
"""
@@ -295,7 +305,9 @@ class ExternalResources(object):
"""Could not automatically download test images %s!
If you are behind a firewall, please set the environment variable http_proxy.
Otherwise please try to download the images manually from
- %s""" % (self.url_base, filename))
+ %s"""
+ % (self.url_base, filename)
+ )
try:
import bz2
@@ -318,8 +330,11 @@ class ExternalResources(object):
with open(fullimagename_raw, "wb") as fullimage:
fullimage.write(decompressed)
except IOError:
- raise IOError("unable to write decompressed \
- data to disk at %s" % self.data_home)
+ raise IOError(
+ "unable to write decompressed \
+ data to disk at %s"
+ % self.data_home
+ )
if not gz_file_exists:
if gzip is None:
@@ -327,8 +342,11 @@ class ExternalResources(object):
try:
gzip.open(fullimagename_gz, "wb").write(decompressed)
except IOError:
- raise IOError("unable to write gzipped \
- data to disk at %s" % self.data_home)
+ raise IOError(
+ "unable to write gzipped \
+ data to disk at %s"
+ % self.data_home
+ )
return fullimagename
diff --git a/src/silx/utils/array_like.py b/src/silx/utils/array_like.py
index d9c7b73..b9b976b 100644
--- a/src/silx/utils/array_like.py
+++ b/src/silx/utils/array_like.py
@@ -46,8 +46,6 @@ Functions:
"""
-import sys
-
import numpy
import numbers
@@ -211,10 +209,9 @@ class ListOfImages(object):
:param images: list of 2D numpy arrays, or :class:`ListOfImages` object
:param transposition: Tuple of dimension numbers in the wanted order
"""
- def __init__(self, images, transposition=None):
- """
- """
+ def __init__(self, images, transposition=None):
+ """ """
super(ListOfImages, self).__init__()
# if images is a ListOfImages instance, get the underlying data
@@ -223,19 +220,16 @@ class ListOfImages(object):
images = images.images
# test stack of images is as expected
- assert is_list_of_arrays(images), \
- "Image stack must be a list of arrays"
+ assert is_list_of_arrays(images), "Image stack must be a list of arrays"
image0_shape = images[0].shape
for image in images:
- assert image.ndim == 2, \
- "Images must be 2D numpy arrays"
- assert image.shape == image0_shape, \
- "All images must have the same shape"
+ assert image.ndim == 2, "Images must be 2D numpy arrays"
+ assert image.shape == image0_shape, "All images must have the same shape"
self.images = images
"""List of images"""
- self.shape = (len(images), ) + image0_shape
+ self.shape = (len(images),) + image0_shape
"""Tuple of array dimensions"""
self.dtype = get_concatenated_dtype(images)
"""Data-type of the global array"""
@@ -257,14 +251,14 @@ class ListOfImages(object):
if transposition is not None:
assert len(transposition) == self.ndim
- assert set(transposition) == set(list(range(self.ndim))), \
- "Transposition must be a sequence containing all dimensions"
+ assert set(transposition) == set(
+ list(range(self.ndim))
+ ), "Transposition must be a sequence containing all dimensions"
self.transposition = transposition
self.__sort_shape()
def __sort_shape(self):
- """Sort shape in the order defined in :attr:`transposition`
- """
+ """Sort shape in the order defined in :attr:`transposition`"""
new_shape = tuple(self.shape[dim] for dim in self.transposition)
self.shape = new_shape
@@ -277,8 +271,9 @@ class ListOfImages(object):
:return: Sorted tuple of indices, to access original data
"""
assert len(indices) == self.ndim
- sorted_indices = tuple(idx for (_, idx) in
- sorted(zip(self.transposition, indices)))
+ sorted_indices = tuple(
+ idx for (_, idx) in sorted(zip(self.transposition, indices))
+ )
return sorted_indices
def __array__(self, dtype=None):
@@ -286,8 +281,9 @@ class ListOfImages(object):
If a transposition has been done on this images, return
a transposed view of a numpy array."""
- return numpy.transpose(numpy.array(self.images, dtype=dtype),
- self.transposition)
+ return numpy.transpose(
+ numpy.array(self.images, dtype=dtype), self.transposition
+ )
def __len__(self):
return self.shape[0]
@@ -313,8 +309,7 @@ class ListOfImages(object):
elif list(self.transposition) != list(range(self.ndim)):
transposition = [self.transposition[i] for i in transposition]
- return ListOfImages(self.images,
- transposition)
+ return ListOfImages(self.images, transposition)
@property
def T(self):
@@ -346,8 +341,9 @@ class ListOfImages(object):
# n-dimensional slicing
if len(item) != self.ndim:
raise IndexError(
- "N-dim slicing requires a tuple of N indices/slices. " +
- "Needed dimensions: %d" % self.ndim)
+ "N-dim slicing requires a tuple of N indices/slices. "
+ + "Needed dimensions: %d" % self.ndim
+ )
# get list of indices sorted in the original images order
sorted_indices = self.__sort_indices(item)
@@ -379,8 +375,7 @@ class ListOfImages(object):
# single list elements selected
if isinstance(images_selection, numpy.ndarray):
- return numpy.transpose(images_selection[array_idx],
- axes=output_dimensions)
+ return numpy.transpose(images_selection[array_idx], axes=output_dimensions)
# muliple list elements selected
else:
# apply selection first
@@ -388,8 +383,7 @@ class ListOfImages(object):
for img in images_selection:
output_stack.append(img[array_idx])
# then cast into a numpy array, and transpose
- return numpy.transpose(numpy.array(output_stack),
- axes=output_dimensions)
+ return numpy.transpose(numpy.array(output_stack), axes=output_dimensions)
def min(self):
"""
@@ -428,10 +422,9 @@ class DatasetView(object):
:param transposition: List of dimensions sorted in the order of
transposition (relative to the original h5py dataset)
"""
- def __init__(self, dataset, transposition=None):
- """
- """
+ def __init__(self, dataset, transposition=None):
+ """ """
super(DatasetView, self).__init__()
self.dataset = dataset
"""original dataset"""
@@ -463,14 +456,14 @@ class DatasetView(object):
if transposition is not None:
assert len(transposition) == self.ndim
- assert set(transposition) == set(list(range(self.ndim))), \
- "Transposition must be a list containing all dimensions"
+ assert set(transposition) == set(
+ list(range(self.ndim))
+ ), "Transposition must be a list containing all dimensions"
self.transposition = transposition
self.__sort_shape()
def __sort_shape(self):
- """Sort shape in the order defined in :attr:`transposition`
- """
+ """Sort shape in the order defined in :attr:`transposition`"""
new_shape = tuple(self.shape[dim] for dim in self.transposition)
self.shape = new_shape
@@ -483,8 +476,9 @@ class DatasetView(object):
:return: Sorted tuple of indices, to access original data
"""
assert len(indices) == self.ndim
- sorted_indices = tuple(idx for (_, idx) in
- sorted(zip(self.transposition, indices)))
+ sorted_indices = tuple(
+ idx for (_, idx) in sorted(zip(self.transposition, indices))
+ )
return sorted_indices
def __getitem__(self, item):
@@ -516,8 +510,9 @@ class DatasetView(object):
# n-dimensional slicing
if len(item) != self.ndim:
raise IndexError(
- "N-dim slicing requires a tuple of N indices/slices. " +
- "Needed dimensions: %d" % self.ndim)
+ "N-dim slicing requires a tuple of N indices/slices. "
+ + "Needed dimensions: %d" % self.ndim
+ )
# get list of indices sorted in the original dataset order
sorted_indices = self.__sort_indices(item)
@@ -546,16 +541,16 @@ class DatasetView(object):
assert (len(output_dimensions) + len(frozen_dimensions)) == self.ndim
assert set(output_dimensions) == set(range(len(output_dimensions)))
- return numpy.transpose(output_data_not_transposed,
- axes=output_dimensions)
+ return numpy.transpose(output_data_not_transposed, axes=output_dimensions)
def __array__(self, dtype=None):
"""Cast the dataset into a numpy array, and return it.
If a transposition has been done on this dataset, return
a transposed view of a numpy array."""
- return numpy.transpose(numpy.array(self.dataset, dtype=dtype),
- self.transposition)
+ return numpy.transpose(
+ numpy.array(self.dataset, dtype=dtype), self.transposition
+ )
def __len__(self):
return self.shape[0]
@@ -580,8 +575,7 @@ class DatasetView(object):
elif list(self.transposition) != list(range(self.ndim)):
transposition = [self.transposition[i] for i in transposition]
- return DatasetView(self.dataset,
- transposition)
+ return DatasetView(self.dataset, transposition)
@property
def T(self):
diff --git a/src/silx/utils/debug.py b/src/silx/utils/debug.py
index ec361ac..8e9eba6 100644
--- a/src/silx/utils/debug.py
+++ b/src/silx/utils/debug.py
@@ -45,6 +45,7 @@ def log_method(func, class_name=None):
:param callable func: The function to patch
:param str class_name: In case a method, provide the class name
"""
+
def wrapper(*args, **kwargs):
global _indent
@@ -60,6 +61,7 @@ def log_method(func, class_name=None):
_indent -= 1
debug_logger.warning("%sreturn (%s)" % (indent, name))
return result
+
return wrapper
@@ -89,7 +91,12 @@ def log_all_methods(base_class):
:param class base_class: The class to patch
"""
- methodTypes = (types.MethodType, types.FunctionType, types.BuiltinFunctionType, types.BuiltinMethodType)
+ methodTypes = (
+ types.MethodType,
+ types.FunctionType,
+ types.BuiltinFunctionType,
+ types.BuiltinMethodType,
+ )
for name, func in inspect.getmembers(base_class):
if isinstance(func, methodTypes):
if func.__name__ not in ["__subclasshook__", "__new__"]:
diff --git a/src/silx/utils/deprecation.py b/src/silx/utils/deprecation.py
index 81d7ed1..4a69bc6 100644
--- a/src/silx/utils/deprecation.py
+++ b/src/silx/utils/deprecation.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -27,7 +27,6 @@ __authors__ = ["Jerome Kieffer", "H. Payno", "P. Knobel"]
__license__ = "MIT"
__date__ = "26/02/2018"
-import sys
import logging
import functools
import traceback
@@ -41,7 +40,15 @@ FORCE = False
It is needed for reproducible tests.
"""
-def deprecated(func=None, reason=None, replacement=None, since_version=None, only_once=True, skip_backtrace_count=1):
+
+def deprecated(
+ func=None,
+ reason=None,
+ replacement=None,
+ since_version=None,
+ only_once=True,
+ skip_backtrace_count=1,
+):
"""
Decorator that deprecates the use of a function
@@ -56,28 +63,37 @@ def deprecated(func=None, reason=None, replacement=None, since_version=None, onl
:param int skip_backtrace_count: Amount of last backtrace to ignore when
logging the backtrace
"""
+
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
- name = func.func_name if sys.version_info[0] < 3 else func.__name__
-
- deprecated_warning(type_='Function',
- name=name,
- reason=reason,
- replacement=replacement,
- since_version=since_version,
- only_once=only_once,
- skip_backtrace_count=skip_backtrace_count)
+ deprecated_warning(
+ type_="Function",
+ name=func.__name__,
+ reason=reason,
+ replacement=replacement,
+ since_version=since_version,
+ only_once=only_once,
+ skip_backtrace_count=skip_backtrace_count,
+ )
return func(*args, **kwargs)
+
return wrapper
+
if func is not None:
return decorator(func)
return decorator
-def deprecated_warning(type_, name, reason=None, replacement=None,
- since_version=None, only_once=True,
- skip_backtrace_count=0):
+def deprecated_warning(
+ type_,
+ name,
+ reason=None,
+ replacement=None,
+ since_version=None,
+ only_once=True,
+ skip_backtrace_count=0,
+):
"""
Function to log a deprecation warning
diff --git a/src/silx/utils/files.py b/src/silx/utils/files.py
index ab8d417..e240af1 100644
--- a/src/silx/utils/files.py
+++ b/src/silx/utils/files.py
@@ -31,6 +31,7 @@ __date__ = "19/09/2016"
import os.path
import glob
+
def expand_filenames(filenames):
"""
Takes a list of paths and expand it into a list of files.
diff --git a/src/silx/utils/launcher.py b/src/silx/utils/launcher.py
index 20752b3..ed94f5d 100644
--- a/src/silx/utils/launcher.py
+++ b/src/silx/utils/launcher.py
@@ -87,9 +87,9 @@ class LauncherCommand(object):
# reach the 'main' function
if not hasattr(module, "main"):
- raise TypeError("Module expect to have a 'main' function")
- else:
- main = getattr(module, "main")
+ raise TypeError(f"Module {module.__name__} must have a 'main' function")
+
+ main = getattr(module, "main")
return main
@contextlib.contextmanager
@@ -140,12 +140,9 @@ class Launcher(object):
and execute the commands.
"""
- def __init__(self,
- prog=None,
- usage=None,
- description=None,
- epilog=None,
- version=None):
+ def __init__(
+ self, prog=None, usage=None, description=None, epilog=None, version=None
+ ):
"""
:param str prog: Name of the program. If it is not defined it uses the
first argument of `sys.argv`
@@ -168,7 +165,8 @@ class Launcher(object):
help_command = LauncherCommand(
"help",
description="Show help of the following command",
- function=self.execute_help)
+ function=self.execute_help,
+ )
self.add_command(command=help_command)
def add_command(self, name=None, module_name=None, description=None, command=None):
@@ -183,17 +181,15 @@ class Launcher(object):
:param LauncherCommand command: A `LauncherCommand`
"""
if command is not None:
- assert(name is None and module_name is None and description is None)
+ assert name is None and module_name is None and description is None
else:
command = LauncherCommand(
- name=name,
- description=description,
- module_name=module_name)
+ name=name, description=description, module_name=module_name
+ )
self._commands[command.name] = command
def print_help(self):
- """Print the help to stdout.
- """
+ """Print the help to stdout."""
usage = self.usage
if usage is None:
usage = "usage: {0.prog} [--version|--help] <command> [<args>]"
@@ -226,10 +222,11 @@ class Launcher(object):
description = "Display help information about %s" % self.prog
parser = argparse.ArgumentParser(description=description)
parser.add_argument(
- 'command',
+ "command",
default=None,
nargs=argparse.OPTIONAL,
- help='Command in which aving help')
+ help="Command in which aving help",
+ )
try:
options = parser.parse_args(argv[1:])
diff --git a/src/silx/utils/number.py b/src/silx/utils/number.py
index 72106e7..630c79c 100755
--- a/src/silx/utils/number.py
+++ b/src/silx/utils/number.py
@@ -52,7 +52,9 @@ if _biggest_float is None:
_float_types = (numpy.float64, numpy.float32, numpy.float16)
-_parse_numeric_value = re.compile(r"^\s*[-+]?0*(\d+?)?(?:\.(\d+))?(?:[eE]([-+]?\d+))?\s*$")
+_parse_numeric_value = re.compile(
+ r"^\s*[-+]?0*(\d+?)?(?:\.(\d+))?(?:[eE]([-+]?\d+))?\s*$"
+)
def is_longdouble_64bits():
@@ -129,14 +131,24 @@ def min_numerical_convertible_type(string, check_accuracy=True):
expected = number + decimal
# This format the number without python convertion
try:
- result = numpy.array2string(value, precision=len(number) + len(decimal), floatmode="fixed")
+ result = numpy.array2string(
+ value, precision=len(number) + len(decimal), floatmode="fixed"
+ )
except TypeError:
# numpy 1.8.2 do not have floatmode argument
- _logger.warning("Not able to check accuracy of the conversion of '%s' using %s", string, _biggest_float)
+ _logger.warning(
+ "Not able to check accuracy of the conversion of '%s' using %s",
+ string,
+ _biggest_float,
+ )
return numpy_type
result = result.replace(".", "").replace("-", "")
if not result.startswith(expected):
- _logger.warning("Not able to convert '%s' using %s without losing precision", string, _biggest_float)
+ _logger.warning(
+ "Not able to convert '%s' using %s without losing precision",
+ string,
+ _biggest_float,
+ )
return numpy_type
diff --git a/src/silx/utils/property.py b/src/silx/utils/property.py
index 029f28e..76d2cdf 100644
--- a/src/silx/utils/property.py
+++ b/src/silx/utils/property.py
@@ -45,5 +45,6 @@ class classproperty(property):
def VALUE(self):
return 10
"""
+
def __get__(self, cls, owner):
return classmethod(self.fget).__get__(None, owner)()
diff --git a/src/silx/utils/proxy.py b/src/silx/utils/proxy.py
index 7801b4b..f0da3c9 100644
--- a/src/silx/utils/proxy.py
+++ b/src/silx/utils/proxy.py
@@ -180,8 +180,7 @@ def _docstring(dest, origin):
try:
origin = getattr(origin, dest.__name__)
except AttributeError:
- raise ValueError(
- "origin class has no %s method" % dest.__name__)
+ raise ValueError("origin class has no %s method" % dest.__name__)
dest.__doc__ = origin.__doc__
return dest
diff --git a/src/silx/utils/retry.py b/src/silx/utils/retry.py
index 804bcb6..a365abf 100644
--- a/src/silx/utils/retry.py
+++ b/src/silx/utils/retry.py
@@ -1,5 +1,5 @@
# /*##########################################################################
-# Copyright (C) 2016-2017 European Synchrotron Radiation Facility
+# Copyright (C) 2016-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -27,14 +27,13 @@ no longer fail.
__authors__ = ["W. de Nolf"]
__license__ = "MIT"
-__date__ = "05/02/2020"
+__date__ = "28/11/2023"
import time
import inspect
from functools import wraps
from contextlib import contextmanager
-import multiprocessing
from queue import Empty
@@ -252,6 +251,8 @@ def retry_in_subprocess(
if retry_period is None:
retry_period = RETRY_PERIOD
+ import multiprocessing
+
def decorator(method):
@wraps(method)
def wrapper(*args, **kw):
diff --git a/src/silx/utils/test/test_array_like.py b/src/silx/utils/test/test_array_like.py
index 309b9ff..74e2604 100644
--- a/src/silx/utils/test/test_array_like.py
+++ b/src/silx/utils/test/test_array_like.py
@@ -34,12 +34,17 @@ import tempfile
import unittest
from ..array_like import DatasetView, ListOfImages
-from ..array_like import get_dtype, get_concatenated_dtype, get_shape,\
- is_array, is_nested_sequence, is_list_of_arrays
+from ..array_like import (
+ get_dtype,
+ get_concatenated_dtype,
+ get_shape,
+ is_array,
+ is_nested_sequence,
+ is_list_of_arrays,
+)
class TestTransposedDatasetView(unittest.TestCase):
-
def setUp(self):
# dataset attributes
self.ndim = 3
@@ -57,8 +62,14 @@ class TestTransposedDatasetView(unittest.TestCase):
self.h5f = h5py.File(self.h5_fname, "r")
- self.all_permutations = [(0, 1, 2), (0, 2, 1), (1, 0, 2), (1, 2, 0),
- (2, 0, 1), (2, 1, 0)]
+ self.all_permutations = [
+ (0, 1, 2),
+ (0, 2, 1),
+ (1, 0, 2),
+ (1, 2, 0),
+ (2, 0, 1),
+ (2, 1, 0),
+ ]
def tearDown(self):
self.h5f.close()
@@ -86,15 +97,12 @@ class TestTransposedDatasetView(unittest.TestCase):
# reversing the dimensions twice results in no change
rtrans = list(reversed(range(self.ndim)))
- self.assertTrue(numpy.array_equal(
- a,
- a.transpose(rtrans).transpose(rtrans)))
+ self.assertTrue(numpy.array_equal(a, a.transpose(rtrans).transpose(rtrans)))
for i in range(a.shape[0]):
for j in range(a.shape[1]):
for k in range(a.shape[2]):
- self.assertEqual(self.h5f["volume"][i, j, k],
- a[i, j, k])
+ self.assertEqual(self.h5f["volume"][i, j, k], a[i, j, k])
def _testTransposition(self, transposition):
"""test transposed dataset
@@ -102,56 +110,50 @@ class TestTransposedDatasetView(unittest.TestCase):
:param tuple transposition: List of dimensions (0... n-1) sorted
in the desired order
"""
- a = DatasetView(self.h5f["volume"],
- transposition=transposition)
+ a = DatasetView(self.h5f["volume"], transposition=transposition)
self._testSize(a)
# sort shape of transposed object, to hopefully find the original shape
- sorted_shape = tuple(dim_size for (_, dim_size) in
- sorted(zip(transposition, a.shape)))
+ sorted_shape = tuple(
+ dim_size for (_, dim_size) in sorted(zip(transposition, a.shape))
+ )
self.assertEqual(sorted_shape, self.original_shape)
a_as_array = numpy.array(self.h5f["volume"]).transpose(transposition)
# test the __array__ method
- self.assertTrue(numpy.array_equal(
- numpy.array(a),
- a_as_array))
+ self.assertTrue(numpy.array_equal(numpy.array(a), a_as_array))
# test slicing
- for selection in [(2, slice(None), slice(None)),
- (slice(None), 1, slice(0, 8)),
- (slice(0, 3), slice(None), 3),
- (1, 3, slice(None)),
- (slice(None), 2, 1),
- (4, slice(1, 9, 2), 2)]:
+ for selection in [
+ (2, slice(None), slice(None)),
+ (slice(None), 1, slice(0, 8)),
+ (slice(0, 3), slice(None), 3),
+ (1, 3, slice(None)),
+ (slice(None), 2, 1),
+ (4, slice(1, 9, 2), 2),
+ ]:
self.assertIsInstance(a[selection], numpy.ndarray)
- self.assertTrue(numpy.array_equal(
- a[selection],
- a_as_array[selection]))
+ self.assertTrue(numpy.array_equal(a[selection], a_as_array[selection]))
# test the DatasetView.__getitem__ for single values
# (step adjusted to test at least 3 indices in each dimension)
for i in range(0, a.shape[0], a.shape[0] // 3):
for j in range(0, a.shape[1], a.shape[1] // 3):
for k in range(0, a.shape[2], a.shape[2] // 3):
- sorted_indices = tuple(idx for (_, idx) in
- sorted(zip(transposition, [i, j, k])))
+ sorted_indices = tuple(
+ idx for (_, idx) in sorted(zip(transposition, [i, j, k]))
+ )
viewed_value = a[i, j, k]
corresponding_original_value = self.h5f["volume"][sorted_indices]
- self.assertEqual(viewed_value,
- corresponding_original_value)
+ self.assertEqual(viewed_value, corresponding_original_value)
# reversing the dimensions twice results in no change
rtrans = list(reversed(range(self.ndim)))
- self.assertTrue(numpy.array_equal(
- a,
- a.transpose(rtrans).transpose(rtrans)))
+ self.assertTrue(numpy.array_equal(a, a.transpose(rtrans).transpose(rtrans)))
# test .T property
- self.assertTrue(numpy.array_equal(
- a.T,
- a.transpose(rtrans)))
+ self.assertTrue(numpy.array_equal(a.T, a.transpose(rtrans)))
def testTransposition012(self):
"""transposition = (0, 1, 2)
@@ -184,22 +186,23 @@ class TestTransposedDatasetView(unittest.TestCase):
self._testDoubleTransposition(trans1, trans2)
def _testDoubleTransposition(self, transposition1, transposition2):
- a = DatasetView(self.h5f["volume"],
- transposition=transposition1).transpose(transposition2)
+ a = DatasetView(self.h5f["volume"], transposition=transposition1).transpose(
+ transposition2
+ )
b = self.volume.transpose(transposition1).transpose(transposition2)
- self.assertTrue(numpy.array_equal(a, b),
- "failed with double transposition %s %s" % (transposition1, transposition2))
+ self.assertTrue(
+ numpy.array_equal(a, b),
+ "failed with double transposition %s %s" % (transposition1, transposition2),
+ )
def test1DIndex(self):
a = DatasetView(self.h5f["volume"])
- self.assertTrue(numpy.array_equal(self.volume[1],
- a[1]))
+ self.assertTrue(numpy.array_equal(self.volume[1], a[1]))
b = DatasetView(self.h5f["volume"], transposition=(1, 0, 2))
- self.assertTrue(numpy.array_equal(self.volume[:, 1, :],
- b[1]))
+ self.assertTrue(numpy.array_equal(self.volume[:, 1, :], b[1]))
class TestTransposedListOfImages(unittest.TestCase):
@@ -215,13 +218,18 @@ class TestTransposedListOfImages(unittest.TestCase):
self.images = []
for i in range(self.original_shape[0]):
- self.images.append(
- volume[i])
+ self.images.append(volume[i])
self.images_as_3D_array = numpy.array(self.images)
- self.all_permutations = [(0, 1, 2), (0, 2, 1), (1, 0, 2), (1, 2, 0),
- (2, 0, 1), (2, 1, 0)]
+ self.all_permutations = [
+ (0, 1, 2),
+ (0, 2, 1),
+ (1, 0, 2),
+ (1, 2, 0),
+ (2, 0, 1),
+ (2, 1, 0),
+ ]
def tearDown(self):
pass
@@ -248,19 +256,14 @@ class TestTransposedListOfImages(unittest.TestCase):
for i in range(a.shape[0]):
for j in range(a.shape[1]):
for k in range(a.shape[2]):
- self.assertEqual(self.images[i][j, k],
- a[i, j, k])
+ self.assertEqual(self.images[i][j, k], a[i, j, k])
# reversing the dimensions twice results in no change
rtrans = list(reversed(range(self.ndim)))
- self.assertTrue(numpy.array_equal(
- a,
- a.transpose(rtrans).transpose(rtrans)))
+ self.assertTrue(numpy.array_equal(a, a.transpose(rtrans).transpose(rtrans)))
# test .T property
- self.assertTrue(numpy.array_equal(
- a.T,
- a.transpose(rtrans)))
+ self.assertTrue(numpy.array_equal(a.T, a.transpose(rtrans)))
def _testTransposition(self, transposition):
"""test transposed dataset
@@ -268,33 +271,31 @@ class TestTransposedListOfImages(unittest.TestCase):
:param tuple transposition: List of dimensions (0... n-1) sorted
in the desired order
"""
- a = ListOfImages(self.images,
- transposition=transposition)
+ a = ListOfImages(self.images, transposition=transposition)
self._testSize(a)
# sort shape of transposed object, to hopefully find the original shape
- sorted_shape = tuple(dim_size for (_, dim_size) in
- sorted(zip(transposition, a.shape)))
+ sorted_shape = tuple(
+ dim_size for (_, dim_size) in sorted(zip(transposition, a.shape))
+ )
self.assertEqual(sorted_shape, self.original_shape)
a_as_array = numpy.array(self.images).transpose(transposition)
# test the DatasetView.__array__ method
- self.assertTrue(numpy.array_equal(
- numpy.array(a),
- a_as_array))
+ self.assertTrue(numpy.array_equal(numpy.array(a), a_as_array))
# test slicing
- for selection in [(2, slice(None), slice(None)),
- (slice(None), 1, slice(0, 8)),
- (slice(0, 3), slice(None), 3),
- (1, 3, slice(None)),
- (slice(None), 2, 1),
- (4, slice(1, 9, 2), 2)]:
+ for selection in [
+ (2, slice(None), slice(None)),
+ (slice(None), 1, slice(0, 8)),
+ (slice(0, 3), slice(None), 3),
+ (1, 3, slice(None)),
+ (slice(None), 2, 1),
+ (4, slice(1, 9, 2), 2),
+ ]:
self.assertIsInstance(a[selection], numpy.ndarray)
- self.assertTrue(numpy.array_equal(
- a[selection],
- a_as_array[selection]))
+ self.assertTrue(numpy.array_equal(a[selection], a_as_array[selection]))
# test the DatasetView.__getitem__ for single values
# (step adjusted to test at least 3 indices in each dimension)
@@ -302,31 +303,32 @@ class TestTransposedListOfImages(unittest.TestCase):
for j in range(0, a.shape[1], a.shape[1] // 3):
for k in range(0, a.shape[2], a.shape[2] // 3):
viewed_value = a[i, j, k]
- sorted_indices = tuple(idx for (_, idx) in
- sorted(zip(transposition, [i, j, k])))
- corresponding_original_value = self.images[sorted_indices[0]][sorted_indices[1:]]
- self.assertEqual(viewed_value,
- corresponding_original_value)
+ sorted_indices = tuple(
+ idx for (_, idx) in sorted(zip(transposition, [i, j, k]))
+ )
+ corresponding_original_value = self.images[sorted_indices[0]][
+ sorted_indices[1:]
+ ]
+ self.assertEqual(viewed_value, corresponding_original_value)
# reversing the dimensions twice results in no change
rtrans = list(reversed(range(self.ndim)))
- self.assertTrue(numpy.array_equal(
- a,
- a.transpose(rtrans).transpose(rtrans)))
+ self.assertTrue(numpy.array_equal(a, a.transpose(rtrans).transpose(rtrans)))
# test .T property
- self.assertTrue(numpy.array_equal(
- a.T,
- a.transpose(rtrans)))
+ self.assertTrue(numpy.array_equal(a.T, a.transpose(rtrans)))
def _testDoubleTransposition(self, transposition1, transposition2):
- a = ListOfImages(self.images,
- transposition=transposition1).transpose(transposition2)
+ a = ListOfImages(self.images, transposition=transposition1).transpose(
+ transposition2
+ )
b = self.images_as_3D_array.transpose(transposition1).transpose(transposition2)
- self.assertTrue(numpy.array_equal(a, b),
- "failed with double transposition %s %s" % (transposition1, transposition2))
+ self.assertTrue(
+ numpy.array_equal(a, b),
+ "failed with double transposition %s %s" % (transposition1, transposition2),
+ )
def testTransposition012(self):
"""transposition = (0, 1, 2)
@@ -360,52 +362,46 @@ class TestTransposedListOfImages(unittest.TestCase):
def test1DIndex(self):
a = ListOfImages(self.images)
- self.assertTrue(numpy.array_equal(self.images[1],
- a[1]))
+ self.assertTrue(numpy.array_equal(self.images[1], a[1]))
b = ListOfImages(self.images, transposition=(1, 0, 2))
- self.assertTrue(numpy.array_equal(self.images_as_3D_array[:, 1, :],
- b[1]))
+ self.assertTrue(numpy.array_equal(self.images_as_3D_array[:, 1, :], b[1]))
class TestFunctions(unittest.TestCase):
"""Test functions to guess the dtype and shape of an array_like
object"""
+
def testListOfLists(self):
l = [[0, 1, 2], [2, 3, 4]]
- self.assertEqual(get_dtype(l),
- numpy.dtype(int))
- self.assertEqual(get_shape(l),
- (2, 3))
+ self.assertEqual(get_dtype(l), numpy.dtype(int))
+ self.assertEqual(get_shape(l), (2, 3))
self.assertTrue(is_nested_sequence(l))
self.assertFalse(is_array(l))
self.assertFalse(is_list_of_arrays(l))
- l = [[0., 1.], [2., 3.]]
- self.assertEqual(get_dtype(l),
- numpy.dtype(float))
- self.assertEqual(get_shape(l),
- (2, 2))
+ l = [[0.0, 1.0], [2.0, 3.0]]
+ self.assertEqual(get_dtype(l), numpy.dtype(float))
+ self.assertEqual(get_shape(l), (2, 2))
self.assertTrue(is_nested_sequence(l))
self.assertFalse(is_array(l))
self.assertFalse(is_list_of_arrays(l))
# concatenated dtype of int and float
- l = [numpy.array([[0, 1, 2], [2, 3, 4]]),
- numpy.array([[0., 1., 2.], [2., 3., 4.]])]
+ l = [
+ numpy.array([[0, 1, 2], [2, 3, 4]]),
+ numpy.array([[0.0, 1.0, 2.0], [2.0, 3.0, 4.0]]),
+ ]
- self.assertEqual(get_concatenated_dtype(l),
- numpy.array(l).dtype)
- self.assertEqual(get_shape(l),
- (2, 2, 3))
+ self.assertEqual(get_concatenated_dtype(l), numpy.array(l).dtype)
+ self.assertEqual(get_shape(l), (2, 2, 3))
self.assertFalse(is_nested_sequence(l))
self.assertFalse(is_array(l))
self.assertTrue(is_list_of_arrays(l))
def testNumpyArray(self):
a = numpy.array([[0, 1], [2, 3]])
- self.assertEqual(get_dtype(a),
- a.dtype)
+ self.assertEqual(get_dtype(a), a.dtype)
self.assertFalse(is_nested_sequence(a))
self.assertTrue(is_array(a))
self.assertFalse(is_list_of_arrays(a))
@@ -419,8 +415,7 @@ class TestFunctions(unittest.TestCase):
h5f["dataset"] = a
d = h5f["dataset"]
- self.assertEqual(get_dtype(d),
- numpy.dtype(int))
+ self.assertEqual(get_dtype(d), numpy.dtype(int))
self.assertFalse(is_nested_sequence(d))
self.assertTrue(is_array(d))
self.assertFalse(is_list_of_arrays(d))
diff --git a/src/silx/utils/test/test_debug.py b/src/silx/utils/test/test_debug.py
index 6b7b5d6..9895514 100644
--- a/src/silx/utils/test/test_debug.py
+++ b/src/silx/utils/test/test_debug.py
@@ -35,7 +35,6 @@ from silx.utils import testutils
@debug.log_all_methods
class _Foobar(object):
-
def a(self):
return None
diff --git a/src/silx/utils/test/test_deprecation.py b/src/silx/utils/test/test_deprecation.py
index 798221a..1115c5d 100644
--- a/src/silx/utils/test/test_deprecation.py
+++ b/src/silx/utils/test/test_deprecation.py
@@ -44,11 +44,15 @@ class TestDeprecation(unittest.TestCase):
def deprecatedWithParams(self):
pass
- @deprecation.deprecated(reason="r", replacement="r", since_version="v", only_once=True)
+ @deprecation.deprecated(
+ reason="r", replacement="r", since_version="v", only_once=True
+ )
def deprecatedOnlyOnce(self):
pass
- @deprecation.deprecated(reason="r", replacement="r", since_version="v", only_once=False)
+ @deprecation.deprecated(
+ reason="r", replacement="r", since_version="v", only_once=False
+ )
def deprecatedEveryTime(self):
pass
@@ -71,6 +75,7 @@ class TestDeprecation(unittest.TestCase):
def testLoggedSingleTime(self):
def log():
self.deprecatedOnlyOnce()
+
log()
log()
log()
diff --git a/src/silx/utils/test/test_enum.py b/src/silx/utils/test/test_enum.py
index df6b266..c652abf 100644
--- a/src/silx/utils/test/test_enum.py
+++ b/src/silx/utils/test/test_enum.py
@@ -1,6 +1,6 @@
# /*##########################################################################
#
-# Copyright (c) 2019 European Synchrotron Radiation Facility
+# Copyright (c) 2019-2023 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -28,55 +28,22 @@ __license__ = "MIT"
__date__ = "29/04/2019"
-import sys
-import unittest
-
-import enum
+import pytest
from silx.utils.enum import Enum
-class TestEnum(unittest.TestCase):
- """Tests for enum module."""
-
- def test(self):
- """Test with Enum"""
- class Success(Enum):
- A = 1
- B = 'B'
- self._check_enum_content(Success)
-
- @unittest.skipIf(sys.version_info.major <= 2, 'Python3 only')
- def test(self):
- """Test Enum with member redefinition"""
- with self.assertRaises(TypeError):
- class Failure(Enum):
- A = 1
- A = 'B'
-
- def test_unique(self):
- """Test with enum.unique"""
- with self.assertRaises(ValueError):
- @enum.unique
- class Failure(Enum):
- A = 1
- B = 1
-
- @enum.unique
- class Success(Enum):
- A = 1
- B = 'B'
- self._check_enum_content(Success)
+def test_enum_methods():
+ """Test Enum"""
- def _check_enum_content(self, enum_):
- """Check that the content of an enum is: <A: 1, B: 2>.
+ class Success(Enum):
+ A = 1
+ B = "B"
- :param Enum enum_:
- """
- self.assertEqual(enum_.members(), (enum_.A, enum_.B))
- self.assertEqual(enum_.names(), ('A', 'B'))
- self.assertEqual(enum_.values(), (1, 'B'))
+ assert Success.members() == (Success.A, Success.B)
+ assert Success.names() == ("A", "B")
+ assert Success.values() == (1, "B")
- self.assertEqual(enum_.from_value(1), enum_.A)
- self.assertEqual(enum_.from_value('B'), enum_.B)
- with self.assertRaises(ValueError):
- enum_.from_value(3)
+ assert Success.from_value(1) == Success.A
+ assert Success.from_value("B") == Success.B
+ with pytest.raises(ValueError):
+ Success.from_value(3)
diff --git a/src/silx/utils/test/test_external_resources.py b/src/silx/utils/test/test_external_resources.py
index 6279460..565554e 100644
--- a/src/silx/utils/test/test_external_resources.py
+++ b/src/silx/utils/test/test_external_resources.py
@@ -40,7 +40,7 @@ from silx.utils.ExternalResources import ExternalResources
def isSilxWebsiteAvailable():
try:
- urllib.request.urlopen('http://www.silx.org', timeout=1)
+ urllib.request.urlopen("http://www.silx.org", timeout=1)
return True
except urllib.error.URLError:
return False
@@ -58,7 +58,9 @@ class TestExternalResources(unittest.TestCase):
raise unittest.SkipTest("Network or silx website not available")
def setUp(self):
- self.resources = ExternalResources("toto%d" % os.getpid(), "http://www.silx.org/pub/silx/")
+ self.resources = ExternalResources(
+ "toto%d" % os.getpid(), "http://www.silx.org/pub/silx/"
+ )
def tearDown(self):
if self.resources.data_home:
diff --git a/src/silx/utils/test/test_launcher.py b/src/silx/utils/test/test_launcher.py
index 3261df5..bfb5041 100644
--- a/src/silx/utils/test/test_launcher.py
+++ b/src/silx/utils/test/test_launcher.py
@@ -34,8 +34,7 @@ from silx.utils.testutils import ParametricTestCase
from .. import launcher
-class CallbackMock():
-
+class CallbackMock:
def __init__(self, result=None):
self._execute_count = 0
self._execute_argv = None
@@ -83,7 +82,6 @@ class TestLauncherCommand(unittest.TestCase):
class TestModuleCommand(ParametricTestCase):
-
def setUp(self):
module_name = "silx.utils.test.test_launcher_command"
command = launcher.LauncherCommand("foo", module_name=module_name)
diff --git a/src/silx/utils/test/test_launcher_command.py b/src/silx/utils/test/test_launcher_command.py
index ff9f336..94bd09a 100644
--- a/src/silx/utils/test/test_launcher_command.py
+++ b/src/silx/utils/test/test_launcher_command.py
@@ -32,7 +32,6 @@ import sys
def main(argv):
-
if "--help" in argv:
# Common behaviour of ArgumentParser
sys.exit(0)
diff --git a/src/silx/utils/test/test_number.py b/src/silx/utils/test/test_number.py
index 8c6d1a2..c4c29a1 100644
--- a/src/silx/utils/test/test_number.py
+++ b/src/silx/utils/test/test_number.py
@@ -28,8 +28,7 @@ __date__ = "05/06/2018"
import logging
import numpy
-import unittest
-import pkg_resources
+from packaging.version import Version
from silx.utils import number
from silx.utils import testutils
@@ -37,7 +36,6 @@ _logger = logging.getLogger(__name__)
class TestConversionTypes(testutils.ParametricTestCase):
-
def testEmptyFail(self):
self.assertRaises(ValueError, number.min_numerical_convertible_type, "")
@@ -116,13 +114,13 @@ class TestConversionTypes(testutils.ParametricTestCase):
self.skipIfFloat80NotSupported()
dtype = number.min_numerical_convertible_type("1000000000.00001013")
- if pkg_resources.parse_version(numpy.version.version) <= pkg_resources.parse_version("1.10.4"):
+ if Version(numpy.version.version) <= Version("1.10.4"):
# numpy 1.8.2 -> Debian 8
# Checking a float128 precision with numpy 1.8.2 using abs(diff) is not working.
# It looks like the difference is done using float64 (diff == 0.0)
expected = (numpy.longdouble, numpy.float64)
else:
- expected = (numpy.longdouble, )
+ expected = (numpy.longdouble,)
self.assertIn(dtype, expected)
def testExponent32(self):
@@ -149,25 +147,27 @@ class TestConversionTypes(testutils.ParametricTestCase):
def testLosePrecisionUsingFloat80(self):
self.skipIfFloat80NotSupported()
- if pkg_resources.parse_version(numpy.version.version) <= pkg_resources.parse_version("1.10.4"):
+ if Version(numpy.version.version) <= Version("1.10.4"):
self.skipTest("numpy > 1.10.4 expected")
# value does not fit even in a 128 bits mantissa
value = "1.0340282366920938463463374607431768211456"
func = testutils.validate_logging(number._logger.name, warning=1)
func = func(number.min_numerical_convertible_type)
dtype = func(value)
- self.assertIn(dtype, (numpy.longdouble, ))
+ self.assertIn(dtype, (numpy.longdouble,))
def testMillisecondEpochTime(self):
- datetimes = ['1465803236.495412',
- '1465803236.999362',
- '1465803237.504311',
- '1465803238.009261',
- '1465803238.512211',
- '1465803239.016160',
- '1465803239.520110',
- '1465803240.026059',
- '1465803240.529009']
+ datetimes = [
+ "1465803236.495412",
+ "1465803236.999362",
+ "1465803237.504311",
+ "1465803238.009261",
+ "1465803238.512211",
+ "1465803239.016160",
+ "1465803239.520110",
+ "1465803240.026059",
+ "1465803240.529009",
+ ]
for datetime in datetimes:
with self.subTest(datetime=datetime):
dtype = number.min_numerical_convertible_type(datetime)
diff --git a/src/silx/utils/test/test_proxy.py b/src/silx/utils/test/test_proxy.py
index 7af4d1f..fae74a3 100644
--- a/src/silx/utils/test/test_proxy.py
+++ b/src/silx/utils/test/test_proxy.py
@@ -35,7 +35,6 @@ from silx.utils.proxy import Proxy, docstring
class Thing(object):
-
def __init__(self, value):
self.value = value
@@ -264,7 +263,6 @@ class TestInheritedProxy(unittest.TestCase):
class TestPickle(unittest.TestCase):
-
def test_dumps(self):
obj = Thing(10)
p = Proxy(obj)
@@ -295,8 +293,7 @@ class TestDocstring(unittest.TestCase):
def method(self):
pass
- self.assertEqual(Derived.method.__doc__,
- TestDocstring.Base.method.__doc__)
+ self.assertEqual(Derived.method.__doc__, TestDocstring.Base.method.__doc__)
def test_composition(self):
class Composed(object):
@@ -311,11 +308,9 @@ class TestDocstring(unittest.TestCase):
def renamed(self):
return self._base.method()
- self.assertEqual(Composed.method.__doc__,
- TestDocstring.Base.method.__doc__)
+ self.assertEqual(Composed.method.__doc__, TestDocstring.Base.method.__doc__)
- self.assertEqual(Composed.renamed.__doc__,
- TestDocstring.Base.method.__doc__)
+ self.assertEqual(Composed.renamed.__doc__, TestDocstring.Base.method.__doc__)
def test_function(self):
def f():
diff --git a/src/silx/utils/test/test_weakref.py b/src/silx/utils/test/test_weakref.py
index 4e3bf21..dd1d6f1 100644
--- a/src/silx/utils/test/test_weakref.py
+++ b/src/silx/utils/test/test_weakref.py
@@ -34,6 +34,7 @@ from .. import weakref
class Dummy(object):
"""Dummy class to use it as geanie pig"""
+
def inc(self, a):
return a + 1
@@ -74,6 +75,7 @@ class TestWeakMethod(unittest.TestCase):
def testDeadFunction(self):
def inc(a):
return a + 1
+
callable_ = weakref.WeakMethod(inc)
inc = None
self.assertIsNone(callable_())
@@ -93,6 +95,7 @@ class TestWeakMethod(unittest.TestCase):
def callback(ref):
self.__count += 1
self.assertIs(callable_, ref)
+
dummy = Dummy()
callable_ = weakref.WeakMethod(dummy.inc, callback)
dummy = None
@@ -104,6 +107,7 @@ class TestWeakMethod(unittest.TestCase):
def callback(ref):
self.__count += 1
self.assertIs(callable_, ref)
+
dummy = Dummy()
dummy.inc2 = lambda self, a: a + 1
callable_ = weakref.WeakMethod(dummy.inc2, callback)
@@ -116,6 +120,7 @@ class TestWeakMethod(unittest.TestCase):
def callback(ref):
self.__count += 1
self.assertIs(callable_, ref)
+
store = lambda a: a + 1 # noqa: E731
callable_ = weakref.WeakMethod(store, callback)
store = None
@@ -143,7 +148,6 @@ class TestWeakMethod(unittest.TestCase):
class TestWeakMethodProxy(unittest.TestCase):
-
def testMethod(self):
dummy = Dummy()
callable_ = weakref.WeakMethodProxy(dummy.inc)
diff --git a/src/silx/utils/testutils.py b/src/silx/utils/testutils.py
index b331829..84e8fb0 100755
--- a/src/silx/utils/testutils.py
+++ b/src/silx/utils/testutils.py
@@ -46,9 +46,12 @@ _logger = logging.getLogger(__name__)
if sys.hexversion >= 0x030400F0: # Python >= 3.4
+
class ParametricTestCase(unittest.TestCase):
pass
+
else:
+
class ParametricTestCase(unittest.TestCase):
"""TestCase with subTest support for Python < 3.4.
@@ -63,8 +66,8 @@ else:
def subTest(self, msg=None, **params):
"""Use as unittest.TestCase.subTest method in Python >= 3.4."""
# Format arguments as: '[msg] (key=value, ...)'
- param_str = ', '.join(['%s=%s' % (k, v) for k, v in params.items()])
- self._subtest_msg = '[%s] (%s)' % (msg or '', param_str)
+ param_str = ", ".join(["%s=%s" % (k, v) for k, v in params.items()])
+ self._subtest_msg = "[%s] (%s)" % (msg or "", param_str)
yield
self._subtest_msg = None
@@ -72,8 +75,9 @@ else:
short_desc = super(ParametricTestCase, self).shortDescription()
if self._subtest_msg is not None:
# Append subTest message to shortDescription
- short_desc = ' '.join(
- [msg for msg in (short_desc, self._subtest_msg) if msg])
+ short_desc = " ".join(
+ [msg for msg in (short_desc, self._subtest_msg) if msg]
+ )
return short_desc if short_desc else None
@@ -143,8 +147,16 @@ class LoggingValidator(logging.Handler):
:raises RuntimeError: If the message counts are the expected ones.
"""
- def __init__(self, logger=None, critical=None, error=None,
- warning=None, info=None, debug=None, notset=None):
+ def __init__(
+ self,
+ logger=None,
+ critical=None,
+ error=None,
+ warning=None,
+ info=None,
+ debug=None,
+ notset=None,
+ ):
if logger is None:
logger = logging.getLogger()
elif not isinstance(logger, logging.Logger):
@@ -159,10 +171,12 @@ class LoggingValidator(logging.Handler):
logging.WARNING: warning,
logging.INFO: info,
logging.DEBUG: debug,
- logging.NOTSET: notset
+ logging.NOTSET: notset,
}
- self._expected_count = sum([v for k, v in self.expected_count_by_level.items() if v is not None])
+ self._expected_count = sum(
+ [v for k, v in self.expected_count_by_level.items() if v is not None]
+ )
"""Amount of any logging expected"""
super(LoggingValidator, self).__init__()
@@ -189,15 +203,14 @@ class LoggingValidator(logging.Handler):
return len(self.records) >= self._expected_count
def get_count_by_level(self):
- """Returns the current message count by level.
- """
+ """Returns the current message count by level."""
count = {
logging.CRITICAL: 0,
logging.ERROR: 0,
logging.WARNING: 0,
logging.INFO: 0,
logging.DEBUG: 0,
- logging.NOTSET: 0
+ logging.NOTSET: 0,
}
for record in self.records:
level = record.levelno
@@ -230,18 +243,30 @@ class LoggingValidator(logging.Handler):
message += ", "
count = count_by_level[level]
expected_count = expected_count_by_level[level]
- message += "%d %s (got %d)" % (expected_count, logging.getLevelName(level), count)
+ message += "%d %s (got %d)" % (
+ expected_count,
+ logging.getLevelName(level),
+ count,
+ )
raise LoggingRuntimeError(
- 'Expected %s' % message, records=list(self.records))
+ "Expected %s" % message, records=list(self.records)
+ )
def emit(self, record):
"""Override :meth:`logging.Handler.emit`"""
self.records.append(record)
-def validate_logging(logger=None, critical=None, error=None,
- warning=None, info=None, debug=None, notset=None):
+def validate_logging(
+ logger=None,
+ critical=None,
+ error=None,
+ warning=None,
+ info=None,
+ debug=None,
+ notset=None,
+):
"""Decorator checking number of logging messages.
Propagation of logging messages is disabled by this decorator.
@@ -270,16 +295,20 @@ def validate_logging(logger=None, critical=None, error=None,
:param int notset: Expected number of NOTSET messages.
Default: Do not check.
"""
+
def decorator(func):
test_context = LoggingValidator(
- logger, critical, error, warning, info, debug, notset)
+ logger, critical, error, warning, info, debug, notset
+ )
@functools.wraps(func)
def wrapper(*args, **kwargs):
with test_context:
result = func(*args, **kwargs)
return result
+
return wrapper
+
return decorator
@@ -290,7 +319,8 @@ class TestLogging(LoggingValidator):
"Class",
"TestLogging",
since_version="1.0.0",
- replacement="LoggingValidator")
+ replacement="LoggingValidator",
+ )
super().__init__(*args, **kwargs)
@@ -325,6 +355,7 @@ class EnsureImportError(object):
if it is already imported. It only ensures that any attempt to import
it again will cause an ImportError to be raised.
"""
+
def __init__(self, name):
"""
diff --git a/src/silx/utils/weakref.py b/src/silx/utils/weakref.py
index 62a9232..96dbca4 100644
--- a/src/silx/utils/weakref.py
+++ b/src/silx/utils/weakref.py
@@ -164,6 +164,7 @@ class WeakMethodProxy(WeakMethod):
"""Wraps a callable object like a function or a bound method
with a weakref proxy.
"""
+
def __call__(self, *args, **kwargs):
"""Dereference the method and call it if the method is still alive.
Else raises an ReferenceError.
@@ -199,8 +200,7 @@ class WeakList(list):
self.__is_valid = False
def __create_ref(self, obj):
- """Create a weakref from an object. It uses the `ref` module function.
- """
+ """Create a weakref from an object. It uses the `ref` module function."""
return ref(obj, self.__invalidate)
def __clean(self):
@@ -255,7 +255,7 @@ class WeakList(list):
:param key: Index to delete
:type key: int or slice
- """
+ """
self.__clean()
del self.__list[key]