summaryrefslogtreecommitdiff
path: root/src/silx/gui
diff options
context:
space:
mode:
Diffstat (limited to 'src/silx/gui')
-rw-r--r--src/silx/gui/__init__.py49
-rw-r--r--src/silx/gui/_glutils/Context.py75
-rw-r--r--src/silx/gui/_glutils/FramebufferTexture.py168
-rw-r--r--src/silx/gui/_glutils/OpenGLWidget.py422
-rw-r--r--src/silx/gui/_glutils/Program.py202
-rw-r--r--src/silx/gui/_glutils/Texture.py352
-rw-r--r--src/silx/gui/_glutils/VertexBuffer.py266
-rw-r--r--src/silx/gui/_glutils/__init__.py43
-rw-r--r--src/silx/gui/_glutils/font.py156
-rw-r--r--src/silx/gui/_glutils/gl.py168
-rw-r--r--src/silx/gui/_glutils/utils.py123
-rwxr-xr-xsrc/silx/gui/colors.py1036
-rw-r--r--src/silx/gui/conftest.py5
-rw-r--r--src/silx/gui/console.py202
-rw-r--r--src/silx/gui/data/ArrayTableModel.py650
-rw-r--r--src/silx/gui/data/ArrayTableWidget.py492
-rw-r--r--src/silx/gui/data/DataViewer.py593
-rw-r--r--src/silx/gui/data/DataViewerFrame.py217
-rw-r--r--src/silx/gui/data/DataViewerSelector.py175
-rw-r--r--src/silx/gui/data/DataViews.py2059
-rw-r--r--src/silx/gui/data/Hdf5TableView.py634
-rw-r--r--src/silx/gui/data/HexaTableView.py272
-rw-r--r--src/silx/gui/data/NXdataWidgets.py1086
-rw-r--r--src/silx/gui/data/NumpyAxesSelector.py578
-rw-r--r--src/silx/gui/data/RecordTableView.py439
-rw-r--r--src/silx/gui/data/TextFormatter.py386
-rw-r--r--src/silx/gui/data/_RecordPlot.py92
-rw-r--r--src/silx/gui/data/_VolumeWindow.py148
-rw-r--r--src/silx/gui/data/__init__.py35
-rw-r--r--src/silx/gui/data/setup.py41
-rw-r--r--src/silx/gui/data/test/__init__.py24
-rw-r--r--src/silx/gui/data/test/test_arraywidget.py316
-rw-r--r--src/silx/gui/data/test/test_dataviewer.py304
-rw-r--r--src/silx/gui/data/test/test_numpyaxesselector.py150
-rw-r--r--src/silx/gui/data/test/test_textformatter.py199
-rw-r--r--src/silx/gui/dialog/AbstractDataFileDialog.py1731
-rw-r--r--src/silx/gui/dialog/ColormapDialog.py1775
-rw-r--r--src/silx/gui/dialog/DataFileDialog.py340
-rw-r--r--src/silx/gui/dialog/DatasetDialog.py122
-rw-r--r--src/silx/gui/dialog/FileTypeComboBox.py226
-rw-r--r--src/silx/gui/dialog/GroupDialog.py230
-rw-r--r--src/silx/gui/dialog/ImageFileDialog.py354
-rw-r--r--src/silx/gui/dialog/SafeFileIconProvider.py154
-rw-r--r--src/silx/gui/dialog/SafeFileSystemModel.py802
-rw-r--r--src/silx/gui/dialog/__init__.py29
-rw-r--r--src/silx/gui/dialog/setup.py40
-rw-r--r--src/silx/gui/dialog/test/__init__.py24
-rw-r--r--src/silx/gui/dialog/test/test_colormapdialog.py395
-rw-r--r--src/silx/gui/dialog/test/test_datafiledialog.py924
-rw-r--r--src/silx/gui/dialog/test/test_imagefiledialog.py772
-rw-r--r--src/silx/gui/dialog/utils.py99
-rw-r--r--src/silx/gui/fit/BackgroundWidget.py534
-rw-r--r--src/silx/gui/fit/FitConfig.py543
-rw-r--r--src/silx/gui/fit/FitWidget.py751
-rw-r--r--src/silx/gui/fit/FitWidgets.py555
-rw-r--r--src/silx/gui/fit/Parameters.py882
-rw-r--r--src/silx/gui/fit/__init__.py28
-rw-r--r--src/silx/gui/fit/setup.py43
-rw-r--r--src/silx/gui/fit/test/__init__.py24
-rw-r--r--src/silx/gui/fit/test/testBackgroundWidget.py72
-rw-r--r--src/silx/gui/fit/test/testFitConfig.py84
-rw-r--r--src/silx/gui/fit/test/testFitWidget.py124
-rw-r--r--src/silx/gui/hdf5/Hdf5Formatter.py240
-rw-r--r--src/silx/gui/hdf5/Hdf5HeaderView.py184
-rwxr-xr-xsrc/silx/gui/hdf5/Hdf5Item.py642
-rw-r--r--src/silx/gui/hdf5/Hdf5LoadingItem.py77
-rw-r--r--src/silx/gui/hdf5/Hdf5Node.py238
-rw-r--r--src/silx/gui/hdf5/Hdf5TreeModel.py742
-rw-r--r--src/silx/gui/hdf5/Hdf5TreeView.py269
-rw-r--r--src/silx/gui/hdf5/NexusSortFilterProxyModel.py224
-rw-r--r--src/silx/gui/hdf5/__init__.py44
-rw-r--r--src/silx/gui/hdf5/_utils.py461
-rw-r--r--src/silx/gui/hdf5/setup.py41
-rw-r--r--src/silx/gui/hdf5/test/__init__.py24
-rwxr-xr-xsrc/silx/gui/hdf5/test/test_hdf5.py1092
-rw-r--r--src/silx/gui/icons.py425
-rw-r--r--src/silx/gui/plot/AlphaSlider.py300
-rw-r--r--src/silx/gui/plot/ColorBar.py883
-rw-r--r--src/silx/gui/plot/Colormap.py42
-rw-r--r--src/silx/gui/plot/ColormapDialog.py43
-rw-r--r--src/silx/gui/plot/Colors.py90
-rw-r--r--src/silx/gui/plot/CompareImages.py1259
-rw-r--r--src/silx/gui/plot/ComplexImageView.py518
-rw-r--r--src/silx/gui/plot/CurvesROIWidget.py1581
-rw-r--r--src/silx/gui/plot/ImageStack.py640
-rw-r--r--src/silx/gui/plot/ImageView.py1057
-rw-r--r--src/silx/gui/plot/Interaction.py350
-rw-r--r--src/silx/gui/plot/ItemsSelectionDialog.py286
-rwxr-xr-xsrc/silx/gui/plot/LegendSelector.py1039
-rw-r--r--src/silx/gui/plot/LimitsHistory.py83
-rw-r--r--src/silx/gui/plot/MaskToolsWidget.py919
-rw-r--r--src/silx/gui/plot/PlotActions.py67
-rw-r--r--src/silx/gui/plot/PlotEvents.py166
-rw-r--r--src/silx/gui/plot/PlotInteraction.py1746
-rw-r--r--src/silx/gui/plot/PlotToolButtons.py592
-rw-r--r--src/silx/gui/plot/PlotTools.py43
-rwxr-xr-xsrc/silx/gui/plot/PlotWidget.py3628
-rw-r--r--src/silx/gui/plot/PlotWindow.py993
-rw-r--r--src/silx/gui/plot/PrintPreviewToolButton.py388
-rw-r--r--src/silx/gui/plot/Profile.py352
-rw-r--r--src/silx/gui/plot/ProfileMainWindow.py110
-rw-r--r--src/silx/gui/plot/ROIStatsWidget.py780
-rw-r--r--src/silx/gui/plot/ScatterMaskToolsWidget.py621
-rw-r--r--src/silx/gui/plot/ScatterView.py404
-rw-r--r--src/silx/gui/plot/StackView.py1254
-rw-r--r--src/silx/gui/plot/StatsWidget.py1658
-rw-r--r--src/silx/gui/plot/_BaseMaskToolsWidget.py1282
-rw-r--r--src/silx/gui/plot/__init__.py71
-rw-r--r--src/silx/gui/plot/_utils/__init__.py92
-rw-r--r--src/silx/gui/plot/_utils/delaunay.py62
-rw-r--r--src/silx/gui/plot/_utils/dtime_ticklayout.py442
-rw-r--r--src/silx/gui/plot/_utils/panzoom.py325
-rw-r--r--src/silx/gui/plot/_utils/setup.py42
-rw-r--r--src/silx/gui/plot/_utils/test/__init__.py24
-rw-r--r--src/silx/gui/plot/_utils/test/test_dtime_ticklayout.py79
-rw-r--r--src/silx/gui/plot/_utils/test/test_ticklayout.py81
-rw-r--r--src/silx/gui/plot/_utils/ticklayout.py267
-rw-r--r--src/silx/gui/plot/actions/PlotAction.py78
-rw-r--r--src/silx/gui/plot/actions/PlotToolAction.py150
-rw-r--r--src/silx/gui/plot/actions/__init__.py42
-rwxr-xr-xsrc/silx/gui/plot/actions/control.py694
-rw-r--r--src/silx/gui/plot/actions/fit.py485
-rw-r--r--src/silx/gui/plot/actions/histogram.py542
-rw-r--r--src/silx/gui/plot/actions/io.py819
-rw-r--r--src/silx/gui/plot/actions/medfilt.py147
-rw-r--r--src/silx/gui/plot/actions/mode.py104
-rwxr-xr-xsrc/silx/gui/plot/backends/BackendBase.py568
-rwxr-xr-xsrc/silx/gui/plot/backends/BackendMatplotlib.py1557
-rwxr-xr-xsrc/silx/gui/plot/backends/BackendOpenGL.py1420
-rw-r--r--src/silx/gui/plot/backends/__init__.py29
-rw-r--r--src/silx/gui/plot/backends/glutils/GLPlotCurve.py1380
-rw-r--r--src/silx/gui/plot/backends/glutils/GLPlotFrame.py1210
-rw-r--r--src/silx/gui/plot/backends/glutils/GLPlotImage.py756
-rw-r--r--src/silx/gui/plot/backends/glutils/GLPlotItem.py99
-rw-r--r--src/silx/gui/plot/backends/glutils/GLPlotTriangles.py197
-rw-r--r--src/silx/gui/plot/backends/glutils/GLSupport.py158
-rw-r--r--src/silx/gui/plot/backends/glutils/GLText.py287
-rw-r--r--src/silx/gui/plot/backends/glutils/GLTexture.py241
-rw-r--r--src/silx/gui/plot/backends/glutils/PlotImageFile.py153
-rw-r--r--src/silx/gui/plot/backends/glutils/__init__.py46
-rw-r--r--src/silx/gui/plot/items/__init__.py53
-rw-r--r--src/silx/gui/plot/items/_arc_roi.py878
-rw-r--r--src/silx/gui/plot/items/_pick.py72
-rw-r--r--src/silx/gui/plot/items/_roi_base.py835
-rw-r--r--src/silx/gui/plot/items/axis.py560
-rw-r--r--src/silx/gui/plot/items/complex.py386
-rw-r--r--src/silx/gui/plot/items/core.py1733
-rw-r--r--src/silx/gui/plot/items/curve.py325
-rw-r--r--src/silx/gui/plot/items/histogram.py389
-rw-r--r--src/silx/gui/plot/items/image.py641
-rw-r--r--src/silx/gui/plot/items/image_aggregated.py229
-rwxr-xr-xsrc/silx/gui/plot/items/marker.py281
-rw-r--r--src/silx/gui/plot/items/roi.py1519
-rw-r--r--src/silx/gui/plot/items/scatter.py1002
-rw-r--r--src/silx/gui/plot/items/shape.py287
-rw-r--r--src/silx/gui/plot/matplotlib/Colormap.py249
-rw-r--r--src/silx/gui/plot/matplotlib/__init__.py37
-rw-r--r--src/silx/gui/plot/setup.py54
-rw-r--r--src/silx/gui/plot/stats/__init__.py33
-rw-r--r--src/silx/gui/plot/stats/stats.py890
-rw-r--r--src/silx/gui/plot/stats/statshandler.py202
-rw-r--r--src/silx/gui/plot/test/__init__.py24
-rw-r--r--src/silx/gui/plot/test/testAlphaSlider.py204
-rw-r--r--src/silx/gui/plot/test/testColorBar.py340
-rw-r--r--src/silx/gui/plot/test/testCompareImages.py106
-rw-r--r--src/silx/gui/plot/test/testComplexImageView.py84
-rw-r--r--src/silx/gui/plot/test/testCurvesROIWidget.py465
-rw-r--r--src/silx/gui/plot/test/testImageStack.py186
-rw-r--r--src/silx/gui/plot/test/testImageView.py194
-rw-r--r--src/silx/gui/plot/test/testInteraction.py78
-rw-r--r--src/silx/gui/plot/test/testItem.py360
-rw-r--r--src/silx/gui/plot/test/testLegendSelector.py130
-rw-r--r--src/silx/gui/plot/test/testLimitConstraints.py114
-rw-r--r--src/silx/gui/plot/test/testMaskToolsWidget.py306
-rw-r--r--src/silx/gui/plot/test/testPixelIntensityHistoAction.py145
-rw-r--r--src/silx/gui/plot/test/testPlotActions.py110
-rw-r--r--src/silx/gui/plot/test/testPlotInteraction.py160
-rwxr-xr-xsrc/silx/gui/plot/test/testPlotWidget.py2113
-rw-r--r--src/silx/gui/plot/test/testPlotWidgetNoBackend.py618
-rw-r--r--src/silx/gui/plot/test/testPlotWindow.py174
-rw-r--r--src/silx/gui/plot/test/testRoiStatsWidget.py277
-rw-r--r--src/silx/gui/plot/test/testSaveAction.py132
-rw-r--r--src/silx/gui/plot/test/testScatterMaskToolsWidget.py306
-rw-r--r--src/silx/gui/plot/test/testScatterView.py123
-rw-r--r--src/silx/gui/plot/test/testStackView.py248
-rw-r--r--src/silx/gui/plot/test/testStats.py1047
-rw-r--r--src/silx/gui/plot/test/testUtilsAxis.py203
-rw-r--r--src/silx/gui/plot/test/utils.py93
-rw-r--r--src/silx/gui/plot/tools/CurveLegendsWidget.py247
-rw-r--r--src/silx/gui/plot/tools/LimitsToolBar.py131
-rw-r--r--src/silx/gui/plot/tools/PositionInfo.py373
-rw-r--r--src/silx/gui/plot/tools/RadarView.py361
-rw-r--r--src/silx/gui/plot/tools/__init__.py50
-rw-r--r--src/silx/gui/plot/tools/profile/ScatterProfileToolBar.py54
-rw-r--r--src/silx/gui/plot/tools/profile/__init__.py38
-rw-r--r--src/silx/gui/plot/tools/profile/core.py525
-rw-r--r--src/silx/gui/plot/tools/profile/editors.py307
-rw-r--r--src/silx/gui/plot/tools/profile/manager.py1079
-rw-r--r--src/silx/gui/plot/tools/profile/rois.py1156
-rw-r--r--src/silx/gui/plot/tools/profile/toolbar.py172
-rw-r--r--src/silx/gui/plot/tools/roi.py1417
-rw-r--r--src/silx/gui/plot/tools/test/__init__.py24
-rw-r--r--src/silx/gui/plot/tools/test/testCurveLegendsWidget.py113
-rw-r--r--src/silx/gui/plot/tools/test/testProfile.py654
-rw-r--r--src/silx/gui/plot/tools/test/testROI.py682
-rw-r--r--src/silx/gui/plot/tools/test/testScatterProfileToolBar.py184
-rw-r--r--src/silx/gui/plot/tools/test/testTools.py135
-rw-r--r--src/silx/gui/plot/tools/toolbars.py362
-rw-r--r--src/silx/gui/plot/utils/__init__.py30
-rw-r--r--src/silx/gui/plot/utils/axis.py398
-rw-r--r--src/silx/gui/plot/utils/intersections.py101
-rw-r--r--src/silx/gui/plot3d/ParamTreeView.py522
-rw-r--r--src/silx/gui/plot3d/Plot3DWidget.py463
-rw-r--r--src/silx/gui/plot3d/Plot3DWindow.py88
-rw-r--r--src/silx/gui/plot3d/SFViewParamTree.py1814
-rw-r--r--src/silx/gui/plot3d/ScalarFieldView.py1552
-rw-r--r--src/silx/gui/plot3d/SceneWidget.py687
-rw-r--r--src/silx/gui/plot3d/SceneWindow.py219
-rw-r--r--src/silx/gui/plot3d/__init__.py40
-rw-r--r--src/silx/gui/plot3d/_model/__init__.py35
-rw-r--r--src/silx/gui/plot3d/_model/core.py372
-rw-r--r--src/silx/gui/plot3d/_model/items.py1759
-rw-r--r--src/silx/gui/plot3d/_model/model.py184
-rw-r--r--src/silx/gui/plot3d/actions/Plot3DAction.py71
-rw-r--r--src/silx/gui/plot3d/actions/__init__.py34
-rw-r--r--src/silx/gui/plot3d/actions/io.py337
-rw-r--r--src/silx/gui/plot3d/actions/mode.py178
-rw-r--r--src/silx/gui/plot3d/actions/viewpoint.py231
-rw-r--r--src/silx/gui/plot3d/conftest.py5
-rw-r--r--src/silx/gui/plot3d/items/__init__.py43
-rw-r--r--src/silx/gui/plot3d/items/_pick.py265
-rw-r--r--src/silx/gui/plot3d/items/clipplane.py136
-rw-r--r--src/silx/gui/plot3d/items/core.py778
-rw-r--r--src/silx/gui/plot3d/items/image.py425
-rw-r--r--src/silx/gui/plot3d/items/mesh.py792
-rw-r--r--src/silx/gui/plot3d/items/mixins.py288
-rw-r--r--src/silx/gui/plot3d/items/scatter.py617
-rw-r--r--src/silx/gui/plot3d/items/volume.py886
-rw-r--r--src/silx/gui/plot3d/scene/__init__.py34
-rw-r--r--src/silx/gui/plot3d/scene/axes.py258
-rw-r--r--src/silx/gui/plot3d/scene/camera.py353
-rw-r--r--src/silx/gui/plot3d/scene/core.py343
-rw-r--r--src/silx/gui/plot3d/scene/cutplane.py390
-rw-r--r--src/silx/gui/plot3d/scene/event.py225
-rw-r--r--src/silx/gui/plot3d/scene/function.py654
-rw-r--r--src/silx/gui/plot3d/scene/interaction.py701
-rw-r--r--src/silx/gui/plot3d/scene/primitives.py2524
-rw-r--r--src/silx/gui/plot3d/scene/test/__init__.py24
-rw-r--r--src/silx/gui/plot3d/scene/test/test_transform.py80
-rw-r--r--src/silx/gui/plot3d/scene/test/test_utils.py258
-rw-r--r--src/silx/gui/plot3d/scene/text.py535
-rw-r--r--src/silx/gui/plot3d/scene/transform.py1027
-rw-r--r--src/silx/gui/plot3d/scene/utils.py662
-rw-r--r--src/silx/gui/plot3d/scene/viewport.py603
-rw-r--r--src/silx/gui/plot3d/scene/window.py433
-rw-r--r--src/silx/gui/plot3d/setup.py50
-rw-r--r--src/silx/gui/plot3d/test/__init__.py25
-rw-r--r--src/silx/gui/plot3d/test/testGL.py73
-rw-r--r--src/silx/gui/plot3d/test/testScalarFieldView.py128
-rw-r--r--src/silx/gui/plot3d/test/testSceneWidget.py72
-rw-r--r--src/silx/gui/plot3d/test/testSceneWidgetPicking.py314
-rw-r--r--src/silx/gui/plot3d/test/testSceneWindow.py233
-rw-r--r--src/silx/gui/plot3d/test/testStatsWidget.py201
-rw-r--r--src/silx/gui/plot3d/tools/GroupPropertiesWidget.py202
-rw-r--r--src/silx/gui/plot3d/tools/PositionInfoWidget.py225
-rw-r--r--src/silx/gui/plot3d/tools/ViewpointTools.py84
-rw-r--r--src/silx/gui/plot3d/tools/__init__.py34
-rw-r--r--src/silx/gui/plot3d/tools/test/__init__.py25
-rw-r--r--src/silx/gui/plot3d/tools/test/testPositionInfoWidget.py89
-rw-r--r--src/silx/gui/plot3d/tools/toolbars.py209
-rw-r--r--src/silx/gui/plot3d/utils/__init__.py28
-rw-r--r--src/silx/gui/plot3d/utils/mng.py121
-rw-r--r--src/silx/gui/printer.py62
-rw-r--r--src/silx/gui/qt/__init__.py54
-rw-r--r--src/silx/gui/qt/_pyside_dynamic.py235
-rw-r--r--src/silx/gui/qt/_qt.py232
-rw-r--r--src/silx/gui/qt/_utils.py68
-rw-r--r--src/silx/gui/qt/inspect.py75
-rw-r--r--src/silx/gui/setup.py55
-rw-r--r--src/silx/gui/test/__init__.py24
-rwxr-xr-xsrc/silx/gui/test/test_colors.py603
-rw-r--r--src/silx/gui/test/test_console.py75
-rw-r--r--src/silx/gui/test/test_icons.py144
-rw-r--r--src/silx/gui/test/test_qt.py212
-rw-r--r--src/silx/gui/test/utils.py43
-rwxr-xr-xsrc/silx/gui/utils/__init__.py76
-rw-r--r--src/silx/gui/utils/concurrent.py105
-rw-r--r--src/silx/gui/utils/glutils/__init__.py199
-rw-r--r--src/silx/gui/utils/image.py143
-rw-r--r--src/silx/gui/utils/matplotlib.py65
-rw-r--r--src/silx/gui/utils/projecturl.py77
-rwxr-xr-xsrc/silx/gui/utils/qtutils.py196
-rw-r--r--src/silx/gui/utils/signal.py141
-rwxr-xr-xsrc/silx/gui/utils/test/__init__.py25
-rw-r--r--src/silx/gui/utils/test/test.py63
-rw-r--r--src/silx/gui/utils/test/test_async.py127
-rw-r--r--src/silx/gui/utils/test/test_glutils.py55
-rw-r--r--src/silx/gui/utils/test/test_image.py79
-rwxr-xr-xsrc/silx/gui/utils/test/test_qtutils.py65
-rw-r--r--src/silx/gui/utils/test/test_testutils.py44
-rw-r--r--src/silx/gui/utils/testutils.py508
-rw-r--r--src/silx/gui/widgets/BoxLayoutDockWidget.py90
-rw-r--r--src/silx/gui/widgets/ColormapNameComboBox.py166
-rw-r--r--src/silx/gui/widgets/ElidedLabel.py140
-rw-r--r--src/silx/gui/widgets/FloatEdit.py71
-rw-r--r--src/silx/gui/widgets/FlowLayout.py177
-rw-r--r--src/silx/gui/widgets/FrameBrowser.py324
-rw-r--r--src/silx/gui/widgets/HierarchicalTableView.py172
-rwxr-xr-xsrc/silx/gui/widgets/LegendIconWidget.py514
-rw-r--r--src/silx/gui/widgets/MedianFilterDialog.py80
-rw-r--r--src/silx/gui/widgets/MultiModeAction.py83
-rw-r--r--src/silx/gui/widgets/PeriodicTable.py831
-rw-r--r--src/silx/gui/widgets/PrintGeometryDialog.py222
-rw-r--r--src/silx/gui/widgets/PrintPreview.py697
-rw-r--r--src/silx/gui/widgets/RangeSlider.py776
-rw-r--r--src/silx/gui/widgets/TableWidget.py626
-rw-r--r--src/silx/gui/widgets/ThreadPoolPushButton.py238
-rw-r--r--src/silx/gui/widgets/UrlSelectionTable.py169
-rw-r--r--src/silx/gui/widgets/WaitingPushButton.py241
-rw-r--r--src/silx/gui/widgets/__init__.py27
-rw-r--r--src/silx/gui/widgets/setup.py41
-rw-r--r--src/silx/gui/widgets/test/__init__.py24
-rw-r--r--src/silx/gui/widgets/test/test_boxlayoutdockwidget.py72
-rw-r--r--src/silx/gui/widgets/test/test_elidedlabel.py100
-rw-r--r--src/silx/gui/widgets/test/test_flowlayout.py66
-rw-r--r--src/silx/gui/widgets/test/test_framebrowser.py62
-rw-r--r--src/silx/gui/widgets/test/test_hierarchicaltableview.py103
-rw-r--r--src/silx/gui/widgets/test/test_legendiconwidget.py63
-rw-r--r--src/silx/gui/widgets/test/test_periodictable.py148
-rw-r--r--src/silx/gui/widgets/test/test_printpreview.py63
-rw-r--r--src/silx/gui/widgets/test/test_rangeslider.py103
-rw-r--r--src/silx/gui/widgets/test/test_tablewidget.py50
-rw-r--r--src/silx/gui/widgets/test/test_threadpoolpushbutton.py124
333 files changed, 126477 insertions, 0 deletions
diff --git a/src/silx/gui/__init__.py b/src/silx/gui/__init__.py
new file mode 100644
index 0000000..b796e20
--- /dev/null
+++ b/src/silx/gui/__init__.py
@@ -0,0 +1,49 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This package provides a set of Qt widgets.
+
+It contains the following sub-packages and modules:
+
+- silx.gui.colors: Functions to handle colors and colormap
+- silx.gui.console: IPython console widget
+- silx.gui.data:
+ Widgets for displaying data arrays using table views and plot widgets
+- silx.gui.dialog: Specific dialog widgets
+- silx.gui.fit: Widgets for controlling curve fitting
+- silx.gui.hdf5: Widgets for displaying content relative to HDF5 format
+- silx.gui.icons: Functions to access embedded icons
+- silx.gui.plot: Widgets for 1D and 2D plotting and related tools
+- silx.gui.plot3d: Widgets for visualizing data in 3D based on OpenGL
+- silx.gui.printer: Shared printer used by the library
+- silx.gui.qt: Common wrapper over different Python Qt binding
+- silx.gui.utils: Miscellaneous helpers for Qt
+- silx.gui.widgets: Miscellaneous standalone widgets
+
+See silx documentation: http://www.silx.org/doc/silx/latest/
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "23/05/2016"
diff --git a/src/silx/gui/_glutils/Context.py b/src/silx/gui/_glutils/Context.py
new file mode 100644
index 0000000..c62dbb9
--- /dev/null
+++ b/src/silx/gui/_glutils/Context.py
@@ -0,0 +1,75 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2014-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Abstraction of OpenGL context.
+
+It defines a way to get current OpenGL context to support multiple
+OpenGL contexts.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "25/07/2016"
+
+import contextlib
+
+
+class _DEFAULT_CONTEXT(object):
+ """The default value for OpenGL context"""
+ pass
+
+_context = _DEFAULT_CONTEXT
+"""The current OpenGL context"""
+
+
+def getCurrent():
+ """Returns platform dependent object of current OpenGL context.
+
+ This is useful to associate OpenGL resources with the context they are
+ created in.
+
+ :return: Platform specific OpenGL context
+ """
+ return _context
+
+
+def setCurrent(context=_DEFAULT_CONTEXT):
+ """Set a platform dependent OpenGL context
+
+ :param context: Platform dependent GL context
+ """
+ global _context
+ _context = context
+
+
+@contextlib.contextmanager
+def current(context):
+ """Context manager setting the platform-dependent GL context
+
+ :param context: Platform dependent GL context
+ """
+ previous_context = getCurrent()
+ setCurrent(context)
+ yield
+ setCurrent(previous_context)
diff --git a/src/silx/gui/_glutils/FramebufferTexture.py b/src/silx/gui/_glutils/FramebufferTexture.py
new file mode 100644
index 0000000..d12a6e0
--- /dev/null
+++ b/src/silx/gui/_glutils/FramebufferTexture.py
@@ -0,0 +1,168 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2014-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Association of a texture and a framebuffer object for off-screen rendering.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "25/07/2016"
+
+
+import logging
+
+from . import gl
+from .Texture import Texture
+
+
+_logger = logging.getLogger(__name__)
+
+
+class FramebufferTexture(object):
+ """Framebuffer with a texture.
+
+ Aimed at off-screen rendering to texture.
+
+ :param internalFormat: OpenGL texture internal format
+ :param shape: Shape (height, width) of the framebuffer and texture
+ :type shape: 2-tuple of int
+ :param stencilFormat: Stencil renderbuffer format
+ :param depthFormat: Depth renderbuffer format
+ :param kwargs: Extra arguments for :class:`Texture` constructor
+ """
+
+ _PACKED_FORMAT = gl.GL_DEPTH24_STENCIL8, gl.GL_DEPTH_STENCIL
+
+ def __init__(self,
+ internalFormat,
+ shape,
+ stencilFormat=gl.GL_DEPTH24_STENCIL8,
+ depthFormat=gl.GL_DEPTH24_STENCIL8,
+ **kwargs):
+
+ self._texture = Texture(internalFormat, shape=shape, **kwargs)
+ self._texture.prepare()
+
+ self._previousFramebuffer = 0 # Used by with statement
+
+ self._name = gl.glGenFramebuffers(1)
+
+ with self: # Bind FBO
+ # Attachments
+ gl.glFramebufferTexture2D(gl.GL_FRAMEBUFFER,
+ gl.GL_COLOR_ATTACHMENT0,
+ gl.GL_TEXTURE_2D,
+ self._texture.name,
+ 0)
+
+ height, width = self._texture.shape
+
+ if stencilFormat is not None:
+ self._stencilId = gl.glGenRenderbuffers(1)
+ gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, self._stencilId)
+ gl.glRenderbufferStorage(gl.GL_RENDERBUFFER,
+ stencilFormat,
+ width, height)
+ gl.glFramebufferRenderbuffer(gl.GL_FRAMEBUFFER,
+ gl.GL_STENCIL_ATTACHMENT,
+ gl.GL_RENDERBUFFER,
+ self._stencilId)
+ else:
+ self._stencilId = None
+
+ if depthFormat is not None:
+ if self._stencilId and depthFormat in self._PACKED_FORMAT:
+ self._depthId = self._stencilId
+ else:
+ self._depthId = gl.glGenRenderbuffers(1)
+ gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, self._depthId)
+ gl.glRenderbufferStorage(gl.GL_RENDERBUFFER,
+ depthFormat,
+ width, height)
+ gl.glFramebufferRenderbuffer(gl.GL_FRAMEBUFFER,
+ gl.GL_DEPTH_ATTACHMENT,
+ gl.GL_RENDERBUFFER,
+ self._depthId)
+ else:
+ self._depthId = None
+
+ status = gl.glCheckFramebufferStatus(gl.GL_FRAMEBUFFER)
+ if status != gl.GL_FRAMEBUFFER_COMPLETE:
+ _logger.error(
+ "OpenGL framebuffer initialization not complete, display may fail (error %d)",
+ status)
+
+ @property
+ def shape(self):
+ """Shape of the framebuffer (height, width)"""
+ return self._texture.shape
+
+ @property
+ def texture(self):
+ """The texture this framebuffer is rendering to.
+
+ The life-cycle of the texture is managed by this object"""
+ return self._texture
+
+ @property
+ def name(self):
+ """OpenGL name of the framebuffer"""
+ if self._name is not None:
+ return self._name
+ else:
+ raise RuntimeError("No OpenGL framebuffer resource, \
+ discard has already been called")
+
+ def bind(self):
+ """Bind this framebuffer for rendering"""
+ gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, self.name)
+
+ # with statement
+
+ def __enter__(self):
+ self._previousFramebuffer = gl.glGetInteger(gl.GL_FRAMEBUFFER_BINDING)
+ self.bind()
+
+ def __exit__(self, exctype, excvalue, traceback):
+ gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, self._previousFramebuffer)
+ self._previousFramebuffer = None
+
+ def discard(self):
+ """Delete associated OpenGL resources including texture"""
+ if self._name is not None:
+ gl.glDeleteFramebuffers(self._name)
+ self._name = None
+
+ if self._stencilId is not None:
+ gl.glDeleteRenderbuffers(self._stencilId)
+ if self._stencilId == self._depthId:
+ self._depthId = None
+ self._stencilId = None
+ if self._depthId is not None:
+ gl.glDeleteRenderbuffers(self._depthId)
+ self._depthId = None
+
+ self._texture.discard() # Also discard the texture
+ else:
+ _logger.warning("Discard has already been called")
diff --git a/src/silx/gui/_glutils/OpenGLWidget.py b/src/silx/gui/_glutils/OpenGLWidget.py
new file mode 100644
index 0000000..2ca4649
--- /dev/null
+++ b/src/silx/gui/_glutils/OpenGLWidget.py
@@ -0,0 +1,422 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This package provides a compatibility layer for OpenGL widget.
+
+It provides a compatibility layer for Qt OpenGL widget used in silx
+across Qt<=5.3 QtOpenGL.QGLWidget and QOpenGLWidget.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "22/11/2019"
+
+
+import logging
+import sys
+
+from .. import qt
+from ..utils.glutils import isOpenGLAvailable
+from .._glutils import gl
+
+
+_logger = logging.getLogger(__name__)
+
+
+if not hasattr(qt, 'QOpenGLWidget') and not hasattr(qt, 'QGLWidget'):
+ OpenGLWidget = None
+
+else:
+ if hasattr(qt, 'QOpenGLWidget'): # PyQt>=5.4
+ _logger.info('Using QOpenGLWidget')
+ _BaseOpenGLWidget = qt.QOpenGLWidget
+
+ else:
+ _logger.info('Using QGLWidget')
+ _BaseOpenGLWidget = qt.QGLWidget
+
+ class _OpenGLWidget(_BaseOpenGLWidget):
+ """Wrapper over QOpenGLWidget and QGLWidget"""
+
+ sigOpenGLContextError = qt.Signal(str)
+ """Signal emitted when an OpenGL context error is detected at runtime.
+
+ It provides the error reason as a str.
+ """
+
+ def __init__(self, parent,
+ alphaBufferSize=0,
+ depthBufferSize=24,
+ stencilBufferSize=8,
+ version=(2, 0),
+ f=qt.Qt.WindowFlags()):
+ # True if using QGLWidget, False if using QOpenGLWidget
+ self.__legacy = not hasattr(qt, 'QOpenGLWidget')
+
+ self.__devicePixelRatio = 1.0
+ self.__requestedOpenGLVersion = int(version[0]), int(version[1])
+ self.__isValid = False
+
+ if self.__legacy: # QGLWidget
+ format_ = qt.QGLFormat()
+ format_.setAlphaBufferSize(alphaBufferSize)
+ format_.setAlpha(alphaBufferSize != 0)
+ format_.setDepthBufferSize(depthBufferSize)
+ format_.setDepth(depthBufferSize != 0)
+ format_.setStencilBufferSize(stencilBufferSize)
+ format_.setStencil(stencilBufferSize != 0)
+ format_.setVersion(*self.__requestedOpenGLVersion)
+ format_.setDoubleBuffer(True)
+
+ super(_OpenGLWidget, self).__init__(format_, parent, None, f)
+
+ else: # QOpenGLWidget
+ super(_OpenGLWidget, self).__init__(parent, f)
+
+ format_ = qt.QSurfaceFormat()
+ format_.setAlphaBufferSize(alphaBufferSize)
+ format_.setDepthBufferSize(depthBufferSize)
+ format_.setStencilBufferSize(stencilBufferSize)
+ format_.setVersion(*self.__requestedOpenGLVersion)
+ format_.setSwapBehavior(qt.QSurfaceFormat.DoubleBuffer)
+ self.setFormat(format_)
+
+ # Enable receiving mouse move events when no buttons are pressed
+ self.setMouseTracking(True)
+
+ def getDevicePixelRatio(self):
+ """Returns the ratio device-independent / device pixel size
+
+ It should be either 1.0 or 2.0.
+
+ :return: Scale factor between screen and Qt units
+ :rtype: float
+ """
+ return self.__devicePixelRatio
+
+ def getRequestedOpenGLVersion(self):
+ """Returns the requested OpenGL version.
+
+ :return: (major, minor)
+ :rtype: 2-tuple of int"""
+ return self.__requestedOpenGLVersion
+
+ def getOpenGLVersion(self):
+ """Returns the available OpenGL version.
+
+ :return: (major, minor)
+ :rtype: 2-tuple of int"""
+ if self.__legacy: # QGLWidget
+ supportedVersion = 0, 0
+
+ # Go through all OpenGL version flags checking support
+ flags = self.format().openGLVersionFlags()
+ for version in ((1, 1), (1, 2), (1, 3), (1, 4), (1, 5),
+ (2, 0), (2, 1),
+ (3, 0), (3, 1), (3, 2), (3, 3),
+ (4, 0)):
+ versionFlag = getattr(qt.QGLFormat,
+ 'OpenGL_Version_%d_%d' % version)
+ if not versionFlag & flags:
+ break
+ supportedVersion = version
+ return supportedVersion
+
+ else: # QOpenGLWidget
+ return self.format().version()
+
+ # QOpenGLWidget methods
+
+ def isValid(self):
+ """Returns True if OpenGL is available.
+
+ This adds extra checks to Qt isValid method.
+
+ :rtype: bool
+ """
+ return self.__isValid and super(_OpenGLWidget, self).isValid()
+
+ def defaultFramebufferObject(self):
+ """Returns the framebuffer object handle.
+
+ See :meth:`QOpenGLWidget.defaultFramebufferObject`
+ """
+ if self.__legacy: # QGLWidget
+ return 0
+ else: # QOpenGLWidget
+ return super(_OpenGLWidget, self).defaultFramebufferObject()
+
+ # *GL overridden methods
+
+ def initializeGL(self):
+ parent = self.parent()
+ if parent is None:
+ _logger.error('_OpenGLWidget has no parent')
+ return
+
+ # Check OpenGL version
+ if self.getOpenGLVersion() >= self.getRequestedOpenGLVersion():
+ try:
+ gl.glGetError() # clear any previous error (if any)
+ version = gl.glGetString(gl.GL_VERSION)
+ except:
+ version = None
+
+ if version:
+ self.__isValid = True
+ else:
+ errMsg = 'OpenGL not available'
+ if sys.platform.startswith('linux'):
+ errMsg += ': If connected remotely, ' \
+ 'GLX forwarding might be disabled.'
+ _logger.error(errMsg)
+ self.sigOpenGLContextError.emit(errMsg)
+ self.__isValid = False
+
+ else:
+ errMsg = 'OpenGL %d.%d not available' % \
+ self.getRequestedOpenGLVersion()
+ _logger.error('OpenGL widget disabled: %s', errMsg)
+ self.sigOpenGLContextError.emit(errMsg)
+ self.__isValid = False
+
+ if self.isValid():
+ parent.initializeGL()
+
+ def paintGL(self):
+ parent = self.parent()
+ if parent is None:
+ _logger.error('_OpenGLWidget has no parent')
+ return
+
+ devicePixelRatio = self.window().windowHandle().devicePixelRatio()
+
+ if devicePixelRatio != self.getDevicePixelRatio():
+ # Update devicePixelRatio and call resizeOpenGL
+ # as resizeGL is not always called.
+ self.__devicePixelRatio = devicePixelRatio
+ self.makeCurrent()
+ parent.resizeGL(self.width(), self.height())
+
+ if self.isValid():
+ parent.paintGL()
+
+ def resizeGL(self, width, height):
+ parent = self.parent()
+ if parent is None:
+ _logger.error('_OpenGLWidget has no parent')
+ return
+
+ if self.isValid():
+ # Call parent resizeGL with device-independent pixel unit
+ # This works over both QGLWidget and QOpenGLWidget
+ parent.resizeGL(self.width(), self.height())
+
+
+class OpenGLWidget(qt.QWidget):
+ """OpenGL widget wrapper over QGLWidget and QOpenGLWidget
+
+ This wrapper API implements a subset of QOpenGLWidget API.
+ The constructor takes a different set of arguments.
+ Methods returning object like :meth:`context` returns either
+ QGL* or QOpenGL* objects.
+
+ :param parent: Parent widget see :class:`QWidget`
+ :param int alphaBufferSize:
+ Size in bits of the alpha channel (default: 0).
+ Set to 0 to disable alpha channel.
+ :param int depthBufferSize:
+ Size in bits of the depth buffer (default: 24).
+ Set to 0 to disable depth buffer.
+ :param int stencilBufferSize:
+ Size in bits of the stencil buffer (default: 8).
+ Set to 0 to disable stencil buffer
+ :param version: Requested OpenGL version (default: (2, 0)).
+ :type version: 2-tuple of int
+ :param f: see :class:`QWidget`
+ """
+
+ def __init__(self, parent=None,
+ alphaBufferSize=0,
+ depthBufferSize=24,
+ stencilBufferSize=8,
+ version=(2, 0),
+ f=qt.Qt.WindowFlags()):
+ super(OpenGLWidget, self).__init__(parent, f)
+
+ layout = qt.QHBoxLayout(self)
+ layout.setContentsMargins(0, 0, 0, 0)
+ self.setLayout(layout)
+
+ self.__context = None
+
+ _check = isOpenGLAvailable(version=version, runtimeCheck=False)
+ if _OpenGLWidget is None or not _check:
+ _logger.error('OpenGL-based widget disabled: %s', _check.error)
+ self.__openGLWidget = None
+ label = self._createErrorQLabel(_check.error)
+ self.layout().addWidget(label)
+
+ else:
+ self.__openGLWidget = _OpenGLWidget(
+ parent=self,
+ alphaBufferSize=alphaBufferSize,
+ depthBufferSize=depthBufferSize,
+ stencilBufferSize=stencilBufferSize,
+ version=version,
+ f=f)
+ # Async connection need, otherwise issue when hiding OpenGL
+ # widget while doing the rendering..
+ self.__openGLWidget.sigOpenGLContextError.connect(
+ self._handleOpenGLInitError, qt.Qt.QueuedConnection)
+ self.layout().addWidget(self.__openGLWidget)
+
+ @staticmethod
+ def _createErrorQLabel(error):
+ """Create QLabel displaying error message in place of OpenGL widget
+
+ :param str error: The error message to display"""
+ label = qt.QLabel()
+ label.setText('OpenGL-based widget disabled:\n%s' % error)
+ label.setAlignment(qt.Qt.AlignCenter)
+ label.setWordWrap(True)
+ return label
+
+ def _handleOpenGLInitError(self, error):
+ """Handle runtime errors in OpenGL widget"""
+ if self.__openGLWidget is not None:
+ self.__openGLWidget.setVisible(False)
+ self.__openGLWidget.setParent(None)
+ self.__openGLWidget = None
+
+ label = self._createErrorQLabel(error)
+ self.layout().addWidget(label)
+
+ # Additional API
+
+ def getDevicePixelRatio(self):
+ """Returns the ratio device-independent / device pixel size
+
+ It should be either 1.0 or 2.0.
+
+ :return: Scale factor between screen and Qt units
+ :rtype: float
+ """
+ if self.__openGLWidget is None:
+ return 1.
+ else:
+ return self.__openGLWidget.getDevicePixelRatio()
+
+ def getDotsPerInch(self):
+ """Returns current screen resolution as device pixels per inch.
+
+ :rtype: float
+ """
+ screen = self.window().windowHandle().screen()
+ if screen is not None:
+ # TODO check if this is correct on different OS/screen
+ # OK on macOS10.12/qt5.13.2
+ dpi = screen.physicalDotsPerInch() * self.getDevicePixelRatio()
+ else: # Fallback
+ dpi = 96. * self.getDevicePixelRatio()
+ return dpi
+
+ def getOpenGLVersion(self):
+ """Returns the available OpenGL version.
+
+ :return: (major, minor)
+ :rtype: 2-tuple of int"""
+ if self.__openGLWidget is None:
+ return 0, 0
+ else:
+ return self.__openGLWidget.getOpenGLVersion()
+
+ # QOpenGLWidget API
+
+ def isValid(self):
+ """Returns True if OpenGL with the requested version is available.
+
+ :rtype: bool
+ """
+ if self.__openGLWidget is None:
+ return False
+ else:
+ return self.__openGLWidget.isValid()
+
+ def context(self):
+ """Return Qt OpenGL context object or None.
+
+ See :meth:`QOpenGLWidget.context` and :meth:`QGLWidget.context`
+ """
+ if self.__openGLWidget is None:
+ return None
+ else:
+ # Keep a reference on QOpenGLContext to make
+ # else PyQt5 keeps creating a new one.
+ self.__context = self.__openGLWidget.context()
+ return self.__context
+
+ def defaultFramebufferObject(self):
+ """Returns the framebuffer object handle.
+
+ See :meth:`QOpenGLWidget.defaultFramebufferObject`
+ """
+ if self.__openGLWidget is None:
+ return 0
+ else:
+ return self.__openGLWidget.defaultFramebufferObject()
+
+ def makeCurrent(self):
+ """Make the underlying OpenGL widget's context current.
+
+ See :meth:`QOpenGLWidget.makeCurrent`
+ """
+ if self.__openGLWidget is not None:
+ self.__openGLWidget.makeCurrent()
+
+ def update(self):
+ """Async update of the OpenGL widget.
+
+ See :meth:`QOpenGLWidget.update`
+ """
+ if self.__openGLWidget is not None:
+ self.__openGLWidget.update()
+
+ # QOpenGLWidget API to override
+
+ def initializeGL(self):
+ """Override to implement OpenGL initialization."""
+ pass
+
+ def paintGL(self):
+ """Override to implement OpenGL rendering."""
+ pass
+
+ def resizeGL(self, width, height):
+ """Override to implement resize of OpenGL framebuffer.
+
+ :param int width: Width in device-independent pixels
+ :param int height: Height in device-independent pixels
+ """
+ pass
diff --git a/src/silx/gui/_glutils/Program.py b/src/silx/gui/_glutils/Program.py
new file mode 100644
index 0000000..87eec5f
--- /dev/null
+++ b/src/silx/gui/_glutils/Program.py
@@ -0,0 +1,202 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2014-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a class to handle shader program compilation."""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "25/07/2016"
+
+
+import logging
+import weakref
+
+import numpy
+
+from . import Context, gl
+
+_logger = logging.getLogger(__name__)
+
+
+class Program(object):
+ """Wrap OpenGL shader program.
+
+ The program is compiled lazily (i.e., at first program :meth:`use`).
+ When the program is compiled, it stores attributes and uniforms locations.
+ So, attributes and uniforms must be used after :meth:`use`.
+
+ This object supports multiple OpenGL contexts.
+
+ :param str vertexShader: The source of the vertex shader.
+ :param str fragmentShader: The source of the fragment shader.
+ :param str attrib0:
+ Attribute's name to bind to position 0 (default: 'position').
+ On certain platform, this attribute MUST be active and with an
+ array attached to it in order for the rendering to occur....
+ """
+
+ def __init__(self, vertexShader, fragmentShader,
+ attrib0='position'):
+ self._vertexShader = vertexShader
+ self._fragmentShader = fragmentShader
+ self._attrib0 = attrib0
+ self._programs = weakref.WeakKeyDictionary()
+
+ @staticmethod
+ def _compileGL(vertexShader, fragmentShader, attrib0):
+ program = gl.glCreateProgram()
+
+ gl.glBindAttribLocation(program, 0, attrib0.encode('ascii'))
+
+ vertex = gl.glCreateShader(gl.GL_VERTEX_SHADER)
+ gl.glShaderSource(vertex, vertexShader)
+ gl.glCompileShader(vertex)
+ if gl.glGetShaderiv(vertex, gl.GL_COMPILE_STATUS) != gl.GL_TRUE:
+ raise RuntimeError(gl.glGetShaderInfoLog(vertex))
+ gl.glAttachShader(program, vertex)
+ gl.glDeleteShader(vertex)
+
+ fragment = gl.glCreateShader(gl.GL_FRAGMENT_SHADER)
+ gl.glShaderSource(fragment, fragmentShader)
+ gl.glCompileShader(fragment)
+ if gl.glGetShaderiv(fragment,
+ gl.GL_COMPILE_STATUS) != gl.GL_TRUE:
+ raise RuntimeError(gl.glGetShaderInfoLog(fragment))
+ gl.glAttachShader(program, fragment)
+ gl.glDeleteShader(fragment)
+
+ gl.glLinkProgram(program)
+ if gl.glGetProgramiv(program, gl.GL_LINK_STATUS) != gl.GL_TRUE:
+ raise RuntimeError(gl.glGetProgramInfoLog(program))
+
+ attributes = {}
+ for index in range(gl.glGetProgramiv(program,
+ gl.GL_ACTIVE_ATTRIBUTES)):
+ name = gl.glGetActiveAttrib(program, index)[0]
+ namestr = name.decode('ascii')
+ attributes[namestr] = gl.glGetAttribLocation(program, name)
+
+ uniforms = {}
+ for index in range(gl.glGetProgramiv(program, gl.GL_ACTIVE_UNIFORMS)):
+ name = gl.glGetActiveUniform(program, index)[0]
+ namestr = name.decode('ascii')
+ uniforms[namestr] = gl.glGetUniformLocation(program, name)
+
+ return program, attributes, uniforms
+
+ def _getProgramInfo(self):
+ glcontext = Context.getCurrent()
+ if glcontext not in self._programs:
+ raise RuntimeError(
+ "Program was not compiled for current OpenGL context.")
+ return self._programs[glcontext]
+
+ @property
+ def attributes(self):
+ """Vertex attributes names and locations as a dict of {str: int}.
+
+ WARNING:
+ Read-only usage.
+ To use only with a valid OpenGL context and after :meth:`use`
+ has been called for this context.
+ """
+ return self._getProgramInfo()[1]
+
+ @property
+ def uniforms(self):
+ """Program uniforms names and locations as a dict of {str: int}.
+
+ WARNING:
+ Read-only usage.
+ To use only with a valid OpenGL context and after :meth:`use`
+ has been called for this context.
+ """
+ return self._getProgramInfo()[2]
+
+ @property
+ def program(self):
+ """OpenGL id of the program.
+
+ WARNING:
+ To use only with a valid OpenGL context and after :meth:`use`
+ has been called for this context.
+ """
+ return self._getProgramInfo()[0]
+
+ # def discard(self):
+ # pass # Not implemented yet
+
+ def use(self):
+ """Make use of the program, compiling it if necessary"""
+ glcontext = Context.getCurrent()
+
+ if glcontext not in self._programs:
+ self._programs[glcontext] = self._compileGL(
+ self._vertexShader,
+ self._fragmentShader,
+ self._attrib0)
+
+ if _logger.getEffectiveLevel() <= logging.DEBUG:
+ gl.glValidateProgram(self.program)
+ if gl.glGetProgramiv(
+ self.program, gl.GL_VALIDATE_STATUS) != gl.GL_TRUE:
+ _logger.debug('Cannot validate program: %s',
+ gl.glGetProgramInfoLog(self.program))
+
+ gl.glUseProgram(self.program)
+
+ def setUniformMatrix(self, name, value, transpose=True, safe=False):
+ """Wrap glUniformMatrix[2|3|4]fv
+
+ :param str name: The name of the uniform.
+ :param value: The 2D matrix (or the array of matrices, 3D).
+ Matrices are 2x2, 3x3 or 4x4.
+ :type value: numpy.ndarray with 2 or 3 dimensions of float32
+ :param bool transpose: Whether to transpose (True, default) or not.
+ :param bool safe: False: raise an error if no uniform with this name;
+ True: silently ignores it.
+
+ :raises KeyError: if no uniform corresponds to name.
+ """
+ assert value.dtype == numpy.float32
+
+ shape = value.shape
+ assert len(shape) in (2, 3)
+ assert shape[-1] in (2, 3, 4)
+ assert shape[-1] == shape[-2] # As in OpenGL|ES 2.0
+
+ location = self.uniforms.get(name)
+ if location is not None:
+ count = 1 if len(shape) == 2 else shape[0]
+ transpose = gl.GL_TRUE if transpose else gl.GL_FALSE
+
+ if shape[-1] == 2:
+ gl.glUniformMatrix2fv(location, count, transpose, value)
+ elif shape[-1] == 3:
+ gl.glUniformMatrix3fv(location, count, transpose, value)
+ elif shape[-1] == 4:
+ gl.glUniformMatrix4fv(location, count, transpose, value)
+
+ elif not safe:
+ raise KeyError('No uniform: %s' % name)
diff --git a/src/silx/gui/_glutils/Texture.py b/src/silx/gui/_glutils/Texture.py
new file mode 100644
index 0000000..c72135a
--- /dev/null
+++ b/src/silx/gui/_glutils/Texture.py
@@ -0,0 +1,352 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2014-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a class wrapping OpenGL 2D and 3D texture."""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "04/10/2016"
+
+
+try:
+ from collections import abc
+except ImportError: # Python2 support
+ import collections as abc
+
+from ctypes import c_void_p
+import logging
+
+import numpy
+
+from . import gl, utils
+
+
+_logger = logging.getLogger(__name__)
+
+
+class Texture(object):
+ """Base class to wrap OpenGL 2D and 3D texture
+
+ :param internalFormat: OpenGL texture internal format
+ :param data: The data to copy to the texture or None for an empty texture
+ :type data: numpy.ndarray or None
+ :param format_: Input data format if different from internalFormat
+ :param shape: If data is None, shape of the texture
+ (height, width) or (depth, height, width)
+ :type shape: List[int]
+ :param int texUnit: The texture unit to use
+ :param minFilter: OpenGL texture minimization filter (default: GL_NEAREST)
+ :param magFilter: OpenGL texture magnification filter (default: GL_LINEAR)
+ :param wrap: Texture wrap mode for dimensions: (t, s) or (r, t, s)
+ If a single value is provided, it used for all dimensions.
+ :type wrap: OpenGL wrap mode or 2 or 3-tuple of wrap mode
+ """
+
+ def __init__(self, internalFormat, data=None, format_=None,
+ shape=None, texUnit=0,
+ minFilter=None, magFilter=None, wrap=None):
+
+ self._internalFormat = internalFormat
+ if format_ is None:
+ format_ = self.internalFormat
+
+ if data is None:
+ assert shape is not None
+ else:
+ assert shape is None
+ data = numpy.array(data, copy=False, order='C')
+ if format_ != gl.GL_RED:
+ shape = data.shape[:-1] # Last dimension is channels
+ else:
+ shape = data.shape
+
+ self._deferredUpdates = [(format_, data, None)]
+
+ assert len(shape) in (2, 3)
+ self._shape = tuple(shape)
+ self._ndim = len(shape)
+
+ self.texUnit = texUnit
+
+ self._texParameterUpdates = {} # Store texture params to update
+
+ self._minFilter = minFilter if minFilter is not None else gl.GL_NEAREST
+ self._texParameterUpdates[gl.GL_TEXTURE_MIN_FILTER] = self._minFilter
+
+ self._magFilter = magFilter if magFilter is not None else gl.GL_LINEAR
+ self._texParameterUpdates[gl.GL_TEXTURE_MAG_FILTER] = self._magFilter
+
+ self._name = None # Store texture ID
+
+ if wrap is not None:
+ if not isinstance(wrap, abc.Iterable):
+ wrap = [wrap] * self.ndim
+
+ assert len(wrap) == self.ndim
+
+ self._texParameterUpdates[gl.GL_TEXTURE_WRAP_S] = wrap[-1]
+ self._texParameterUpdates[gl.GL_TEXTURE_WRAP_T] = wrap[-2]
+ if self.ndim == 3:
+ self._texParameterUpdates[gl.GL_TEXTURE_WRAP_R] = wrap[0]
+
+ @property
+ def target(self):
+ """OpenGL target type of this texture"""
+ return gl.GL_TEXTURE_2D if self.ndim == 2 else gl.GL_TEXTURE_3D
+
+ @property
+ def ndim(self):
+ """The number of dimensions: 2 or 3"""
+ return self._ndim
+
+ @property
+ def internalFormat(self):
+ """Texture internal format"""
+ return self._internalFormat
+
+ @property
+ def shape(self):
+ """Shape of the texture: (height, width) or (depth, height, width)"""
+ return self._shape
+
+ @property
+ def name(self):
+ """OpenGL texture name.
+
+ It is None if not initialized or already discarded.
+ """
+ return self._name
+
+ @property
+ def minFilter(self):
+ """Minifying function parameter (GL_TEXTURE_MIN_FILTER)"""
+ return self._minFilter
+
+ @minFilter.setter
+ def minFilter(self, minFilter):
+ if minFilter != self.minFilter:
+ self._minFilter = minFilter
+ self._texParameterUpdates[gl.GL_TEXTURE_MIN_FILTER] = minFilter
+
+ @property
+ def magFilter(self):
+ """Magnification function parameter (GL_TEXTURE_MAG_FILTER)"""
+ return self._magFilter
+
+ @magFilter.setter
+ def magFilter(self, magFilter):
+ if magFilter != self.magFilter:
+ self._magFilter = magFilter
+ self._texParameterUpdates[gl.GL_TEXTURE_MAG_FILTER] = magFilter
+
+ def _isPrepareRequired(self) -> bool:
+ """Returns True if OpenGL texture needs to be updated.
+
+ :rtype: bool
+ """
+ return (self._name is None or
+ self._texParameterUpdates or
+ self._deferredUpdates)
+
+ def _prepareAndBind(self, texUnit=None):
+ """Synchronizes the OpenGL texture"""
+ if self._name is None:
+ self._name = gl.glGenTextures(1)
+
+ self._bind(texUnit)
+
+ # Synchronizes texture parameters
+ for pname, param in self._texParameterUpdates.items():
+ gl.glTexParameter(self.target, pname, param)
+ self._texParameterUpdates = {}
+
+ # Copy data to texture
+ for format_, data, offset in self._deferredUpdates:
+ gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1)
+
+ # This are the defaults, useless to set if not modified
+ # gl.glPixelStorei(gl.GL_UNPACK_ROW_LENGTH, 0)
+ # gl.glPixelStorei(gl.GL_UNPACK_SKIP_PIXELS, 0)
+ # gl.glPixelStorei(gl.GL_UNPACK_SKIP_ROWS, 0)
+ # gl.glPixelStorei(gl.GL_UNPACK_IMAGE_HEIGHT, 0)
+ # gl.glPixelStorei(gl.GL_UNPACK_SKIP_IMAGES, 0)
+
+ if data is None:
+ data = c_void_p(0)
+ type_ = gl.GL_UNSIGNED_BYTE
+ else:
+ type_ = utils.numpyToGLType(data.dtype)
+
+ if offset is None: # Initialize texture
+ if self.ndim == 2:
+ _logger.debug(
+ 'Creating 2D texture shape: (%d, %d),'
+ ' internal format: %s, format: %s, type: %s',
+ self.shape[0], self.shape[1],
+ str(self.internalFormat), str(format_), str(type_))
+
+ gl.glTexImage2D(
+ gl.GL_TEXTURE_2D,
+ 0,
+ self.internalFormat,
+ self.shape[1],
+ self.shape[0],
+ 0,
+ format_,
+ type_,
+ data)
+
+ else:
+ _logger.debug(
+ 'Creating 3D texture shape: (%d, %d, %d),'
+ ' internal format: %s, format: %s, type: %s',
+ self.shape[0], self.shape[1], self.shape[2],
+ str(self.internalFormat), str(format_), str(type_))
+
+ gl.glTexImage3D(
+ gl.GL_TEXTURE_3D,
+ 0,
+ self.internalFormat,
+ self.shape[2],
+ self.shape[1],
+ self.shape[0],
+ 0,
+ format_,
+ type_,
+ data)
+
+ else: # Update already existing texture
+ if self.ndim == 2:
+ gl.glTexSubImage2D(gl.GL_TEXTURE_2D,
+ 0,
+ offset[1],
+ offset[0],
+ data.shape[1],
+ data.shape[0],
+ format_,
+ type_,
+ data)
+
+ else:
+ gl.glTexSubImage3D(gl.GL_TEXTURE_3D,
+ 0,
+ offset[2],
+ offset[1],
+ offset[0],
+ data.shape[2],
+ data.shape[1],
+ data.shape[0],
+ format_,
+ type_,
+ data)
+
+ self._deferredUpdates = []
+
+ def _bind(self, texUnit=None):
+ """Bind the texture to a texture unit.
+
+ :param int texUnit: The texture unit to use
+ """
+ if texUnit is None:
+ texUnit = self.texUnit
+ gl.glActiveTexture(gl.GL_TEXTURE0 + texUnit)
+ gl.glBindTexture(self.target, self.name)
+
+ def _unbind(self, texUnit=None):
+ """Reset texture binding to a texture unit.
+
+ :param int texUnit: The texture unit to use
+ """
+ if texUnit is None:
+ texUnit = self.texUnit
+ gl.glActiveTexture(gl.GL_TEXTURE0 + texUnit)
+ gl.glBindTexture(self.target, 0)
+
+ def prepare(self):
+ """Synchronizes the OpenGL texture.
+
+ This method must be called with a current OpenGL context.
+ """
+ if self._isPrepareRequired():
+ self._prepareAndBind()
+ self._unbind()
+
+ def bind(self, texUnit=None):
+ """Bind the texture to a texture unit.
+
+ The OpenGL texture is updated if needed.
+
+ This method must be called with a current OpenGL context.
+
+ :param int texUnit: The texture unit to use
+ """
+ if self._isPrepareRequired():
+ self._prepareAndBind(texUnit)
+ else:
+ self._bind(texUnit)
+
+ def discard(self):
+ """Delete associated OpenGL texture.
+
+ This method must be called with a current OpenGL context.
+ """
+ if self._name is not None:
+ gl.glDeleteTextures(self._name)
+ self._name = None
+ else:
+ _logger.warning("Texture not initialized or already discarded")
+
+ # with statement
+
+ def __enter__(self):
+ self.bind()
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self._unbind()
+
+ def update(self, format_, data, offset=(0, 0, 0), copy=True):
+ """Update the content of the texture.
+
+ Texture is not resized, so data must fit into texture with the
+ given offset.
+
+ This update is performed lazily during next call to
+ :meth:`prepare` or :meth:`bind`.
+ Data MUST not be changed until then.
+
+ :param format_: The OpenGL format of the data
+ :param data: The data to use to update the texture
+ :param List[int] offset: Offset in the texture where to copy the data
+ :param bool copy:
+ True (default) to copy data, False to use as is (do not modify)
+ """
+ data = numpy.array(data, copy=copy, order='C')
+ offset = tuple(offset)
+
+ assert data.ndim == self.ndim
+ assert len(offset) >= self.ndim
+ for i in range(self.ndim):
+ assert offset[i] + data.shape[i] <= self.shape[i]
+
+ self._deferredUpdates.append((format_, data, offset))
diff --git a/src/silx/gui/_glutils/VertexBuffer.py b/src/silx/gui/_glutils/VertexBuffer.py
new file mode 100644
index 0000000..b74b748
--- /dev/null
+++ b/src/silx/gui/_glutils/VertexBuffer.py
@@ -0,0 +1,266 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2014-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a class managing an OpenGL vertex buffer."""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "10/01/2017"
+
+
+import logging
+from ctypes import c_void_p
+import numpy
+
+from . import gl
+from .utils import numpyToGLType, sizeofGLType
+
+
+_logger = logging.getLogger(__name__)
+
+
+class VertexBuffer(object):
+ """Object handling an OpenGL vertex buffer object
+
+ :param data: Data used to fill the vertex buffer
+ :type data: numpy.ndarray or None
+ :param int size: Size in bytes of the buffer or None for data size
+ :param usage: OpenGL vertex buffer expected usage pattern:
+ GL_STREAM_DRAW, GL_STATIC_DRAW (default) or GL_DYNAMIC_DRAW
+ :param target: Target buffer:
+ GL_ARRAY_BUFFER (default) or GL_ELEMENT_ARRAY_BUFFER
+ """
+ # OpenGL|ES 2.0 subset:
+ _USAGES = gl.GL_STREAM_DRAW, gl.GL_STATIC_DRAW, gl.GL_DYNAMIC_DRAW
+ _TARGETS = gl.GL_ARRAY_BUFFER, gl.GL_ELEMENT_ARRAY_BUFFER
+
+ def __init__(self,
+ data=None,
+ size=None,
+ usage=None,
+ target=None):
+ if usage is None:
+ usage = gl.GL_STATIC_DRAW
+ assert usage in self._USAGES
+
+ if target is None:
+ target = gl.GL_ARRAY_BUFFER
+ assert target in self._TARGETS
+
+ self._target = target
+ self._usage = usage
+
+ self._name = gl.glGenBuffers(1)
+ self.bind()
+
+ if data is None:
+ assert size is not None
+ self._size = size
+ gl.glBufferData(self._target,
+ self._size,
+ c_void_p(0),
+ self._usage)
+ else:
+ data = numpy.array(data, copy=False, order='C')
+ if size is not None:
+ assert size <= data.nbytes
+
+ self._size = size or data.nbytes
+ gl.glBufferData(self._target,
+ self._size,
+ data,
+ self._usage)
+
+ gl.glBindBuffer(self._target, 0)
+
+ @property
+ def target(self):
+ """The target buffer of the vertex buffer"""
+ return self._target
+
+ @property
+ def usage(self):
+ """The expected usage of the vertex buffer"""
+ return self._usage
+
+ @property
+ def name(self):
+ """OpenGL Vertex Buffer object name (int)"""
+ if self._name is not None:
+ return self._name
+ else:
+ raise RuntimeError("No OpenGL buffer resource, \
+ discard has already been called")
+
+ @property
+ def size(self):
+ """Size in bytes of the Vertex Buffer Object (int)"""
+ if self._size is not None:
+ return self._size
+ else:
+ raise RuntimeError("No OpenGL buffer resource, \
+ discard has already been called")
+
+ def bind(self):
+ """Bind the vertex buffer"""
+ gl.glBindBuffer(self._target, self.name)
+
+ def update(self, data, offset=0, size=None):
+ """Update vertex buffer content.
+
+ :param numpy.ndarray data: The data to put in the vertex buffer
+ :param int offset: Offset in bytes in the buffer where to put the data
+ :param int size: If provided, size of data to copy
+ """
+ data = numpy.array(data, copy=False, order='C')
+ if size is None:
+ size = data.nbytes
+ assert offset + size <= self.size
+ with self:
+ gl.glBufferSubData(self._target, offset, size, data)
+
+ def discard(self):
+ """Delete the vertex buffer"""
+ if self._name is not None:
+ gl.glDeleteBuffers(self._name)
+ self._name = None
+ self._size = None
+ else:
+ _logger.warning("Discard has already been called")
+
+ # with statement
+
+ def __enter__(self):
+ self.bind()
+
+ def __exit__(self, exctype, excvalue, traceback):
+ gl.glBindBuffer(self._target, 0)
+
+
+class VertexBufferAttrib(object):
+ """Describes data stored in a vertex buffer
+
+ Convenient class to store info for glVertexAttribPointer calls
+
+ :param VertexBuffer vbo: The vertex buffer storing the data
+ :param int type_: The OpenGL type of the data
+ :param int size: The number of data elements stored in the VBO
+ :param int dimension: The number of `type_` element(s) in [1, 4]
+ :param int offset: Start offset of data in the vertex buffer
+ :param int stride: Data stride in the vertex buffer
+ """
+
+ _GL_TYPES = gl.GL_UNSIGNED_BYTE, gl.GL_FLOAT, gl.GL_INT
+
+ def __init__(self,
+ vbo,
+ type_,
+ size,
+ dimension=1,
+ offset=0,
+ stride=0,
+ normalization=False):
+ self.vbo = vbo
+ assert type_ in self._GL_TYPES
+ self.type_ = type_
+ self.size = size
+ assert 1 <= dimension <= 4
+ self.dimension = dimension
+ self.offset = offset
+ self.stride = stride
+ self.normalization = bool(normalization)
+
+ @property
+ def itemsize(self):
+ """Size in bytes of a vertex buffer element (int)"""
+ return self.dimension * sizeofGLType(self.type_)
+
+ itemSize = itemsize # Backward compatibility
+
+ def setVertexAttrib(self, attribute):
+ """Call glVertexAttribPointer with objects information"""
+ normalization = gl.GL_TRUE if self.normalization else gl.GL_FALSE
+ with self.vbo:
+ gl.glVertexAttribPointer(attribute,
+ self.dimension,
+ self.type_,
+ normalization,
+ self.stride,
+ c_void_p(self.offset))
+
+ def copy(self):
+ return VertexBufferAttrib(self.vbo,
+ self.type_,
+ self.size,
+ self.dimension,
+ self.offset,
+ self.stride,
+ self.normalization)
+
+
+def vertexBuffer(arrays, prefix=None, suffix=None, usage=None):
+ """Create a single vertex buffer from multiple 1D or 2D numpy arrays.
+
+ It is possible to reserve memory before and after each array in the VBO
+
+ :param arrays: Arrays of data to store
+ :type arrays: Iterable of numpy.ndarray
+ :param prefix: If given, number of elements to reserve before each array
+ :type prefix: Iterable of int or None
+ :param suffix: If given, number of elements to reserve after each array
+ :type suffix: Iterable of int or None
+ :param int usage: vertex buffer expected usage or None for default
+ :returns: List of VertexBufferAttrib objects sharing the same vertex buffer
+ """
+ info = []
+ vbosize = 0
+
+ if prefix is None:
+ prefix = (0,) * len(arrays)
+ if suffix is None:
+ suffix = (0,) * len(arrays)
+
+ for data, pre, post in zip(arrays, prefix, suffix):
+ data = numpy.array(data, copy=False, order='C')
+ shape = data.shape
+ assert len(shape) <= 2
+ type_ = numpyToGLType(data.dtype)
+ size = shape[0] + pre + post
+ dimension = 1 if len(shape) == 1 else shape[1]
+ sizeinbytes = size * dimension * sizeofGLType(type_)
+ sizeinbytes = 4 * ((sizeinbytes + 3) >> 2) # 4 bytes alignment
+ copyoffset = vbosize + pre * dimension * sizeofGLType(type_)
+ info.append((data, type_, size, dimension,
+ vbosize, sizeinbytes, copyoffset))
+ vbosize += sizeinbytes
+
+ vbo = VertexBuffer(size=vbosize, usage=usage)
+
+ result = []
+ for data, type_, size, dimension, offset, sizeinbytes, copyoffset in info:
+ copysize = data.shape[0] * dimension * sizeofGLType(type_)
+ vbo.update(data, offset=copyoffset, size=copysize)
+ result.append(
+ VertexBufferAttrib(vbo, type_, size, dimension, offset, 0))
+ return result
diff --git a/src/silx/gui/_glutils/__init__.py b/src/silx/gui/_glutils/__init__.py
new file mode 100644
index 0000000..e88affd
--- /dev/null
+++ b/src/silx/gui/_glutils/__init__.py
@@ -0,0 +1,43 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2014-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This package provides utility functions to handle OpenGL resources.
+
+The :mod:`gl` module provides a wrapper to OpenGL based on PyOpenGL.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "25/07/2016"
+
+
+# OpenGL convenient functions
+from .OpenGLWidget import OpenGLWidget # noqa
+from . import Context # noqa
+from .FramebufferTexture import FramebufferTexture # noqa
+from .Program import Program # noqa
+from .Texture import Texture # noqa
+from .VertexBuffer import VertexBuffer, VertexBufferAttrib, vertexBuffer # noqa
+from .utils import sizeofGLType, isSupportedGLType, numpyToGLType # noqa
+from .utils import segmentTrianglesIntersection # noqa
diff --git a/src/silx/gui/_glutils/font.py b/src/silx/gui/_glutils/font.py
new file mode 100644
index 0000000..3ea474d
--- /dev/null
+++ b/src/silx/gui/_glutils/font.py
@@ -0,0 +1,156 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Text rasterisation feature leveraging Qt font and text layout support."""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "13/10/2016"
+
+
+import logging
+import numpy
+
+from ..utils.image import convertQImageToArray
+from .. import qt
+
+_logger = logging.getLogger(__name__)
+
+
+def getDefaultFontFamily():
+ """Returns the default font family of the application"""
+ return qt.QApplication.instance().font().family()
+
+
+# Font weights
+ULTRA_LIGHT = 0
+"""Lightest characters: Minimum font weight"""
+
+LIGHT = 25
+"""Light characters"""
+
+NORMAL = 50
+"""Normal characters"""
+
+SEMI_BOLD = 63
+"""Between normal and bold characters"""
+
+BOLD = 74
+"""Thicker characters"""
+
+BLACK = 87
+"""Really thick characters"""
+
+ULTRA_BLACK = 99
+"""Thickest characters: Maximum font weight"""
+
+
+def rasterText(text, font,
+ size=-1,
+ weight=-1,
+ italic=False,
+ devicePixelRatio=1.0):
+ """Raster text using Qt.
+
+ It supports multiple lines.
+
+ :param str text: The text to raster
+ :param font: Font name or QFont to use
+ :type font: str or :class:`QFont`
+ :param int size:
+ Font size in points
+ Used only if font is given as name.
+ :param int weight:
+ Font weight in [0, 99], see QFont.Weight.
+ Used only if font is given as name.
+ :param bool italic:
+ True for italic font (default: False).
+ Used only if font is given as name.
+ :param float devicePixelRatio:
+ The current ratio between device and device-independent pixel
+ (default: 1.0)
+ :return: Corresponding image in gray scale and baseline offset from top
+ :rtype: (HxW numpy.ndarray of uint8, int)
+ """
+ if not text:
+ _logger.info("Trying to raster empty text, replaced by white space")
+ text = ' ' # Replace empty text by white space to produce an image
+
+ if not isinstance(font, qt.QFont):
+ font = qt.QFont(font, size, weight, italic)
+
+ # get text size
+ image = qt.QImage(1, 1, qt.QImage.Format_RGB888)
+ painter = qt.QPainter()
+ painter.begin(image)
+ painter.setPen(qt.Qt.white)
+ painter.setFont(font)
+ bounds = painter.boundingRect(
+ qt.QRect(0, 0, 4096, 4096), qt.Qt.TextExpandTabs, text)
+ painter.end()
+
+ metrics = qt.QFontMetrics(font)
+
+ # This does not provide the correct text bbox on macOS
+ # size = metrics.size(qt.Qt.TextExpandTabs, text)
+ # bounds = metrics.boundingRect(
+ # qt.QRect(0, 0, size.width(), size.height()),
+ # qt.Qt.TextExpandTabs,
+ # text)
+
+ # Add extra border and handle devicePixelRatio
+ width = bounds.width() * devicePixelRatio + 2
+ # align line size to 32 bits to ease conversion to numpy array
+ width = 4 * ((width + 3) // 4)
+ image = qt.QImage(int(width),
+ int(bounds.height() * devicePixelRatio + 2),
+ qt.QImage.Format_RGB888)
+ image.setDevicePixelRatio(devicePixelRatio)
+
+ # TODO if Qt5 use Format_Grayscale8 instead
+ image.fill(0)
+
+ # Raster text
+ painter = qt.QPainter()
+ painter.begin(image)
+ painter.setPen(qt.Qt.white)
+ painter.setFont(font)
+ painter.drawText(bounds, qt.Qt.TextExpandTabs, text)
+ painter.end()
+
+ array = convertQImageToArray(image)
+
+ # RGB to R
+ array = numpy.ascontiguousarray(array[:, :, 0])
+
+ # Remove leading and trailing empty columns but one on each side
+ column_cumsum = numpy.cumsum(numpy.sum(array, axis=0))
+ array = array[:, column_cumsum.argmin():column_cumsum.argmax() + 2]
+
+ # Remove leading and trailing empty rows but one on each side
+ row_cumsum = numpy.cumsum(numpy.sum(array, axis=1))
+ min_row = row_cumsum.argmin()
+ array = array[min_row:row_cumsum.argmax() + 2, :]
+
+ return array, metrics.ascent() - min_row
diff --git a/src/silx/gui/_glutils/gl.py b/src/silx/gui/_glutils/gl.py
new file mode 100644
index 0000000..608d9ce
--- /dev/null
+++ b/src/silx/gui/_glutils/gl.py
@@ -0,0 +1,168 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2014-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module loads PyOpenGL and provides a namespace for OpenGL."""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "25/07/2016"
+
+
+from contextlib import contextmanager as _contextmanager
+from ctypes import c_uint
+import logging
+
+_logger = logging.getLogger(__name__)
+
+import OpenGL
+# Set the following to true for debugging
+if _logger.getEffectiveLevel() <= logging.DEBUG:
+ _logger.debug('Enabling PyOpenGL debug flags')
+ OpenGL.ERROR_LOGGING = True
+ OpenGL.ERROR_CHECKING = True
+ OpenGL.ERROR_ON_COPY = True
+else:
+ OpenGL.ERROR_LOGGING = False
+ OpenGL.ERROR_CHECKING = False
+ OpenGL.ERROR_ON_COPY = False
+
+import OpenGL.GL as _GL
+from OpenGL.GL import * # noqa
+
+# Extentions core in OpenGL 3
+from OpenGL.GL.ARB import framebuffer_object as _FBO
+from OpenGL.GL.ARB.framebuffer_object import * # noqa
+from OpenGL.GL.ARB.texture_rg import GL_R32F, GL_R16F # noqa
+from OpenGL.GL.ARB.texture_rg import GL_R16, GL_R8 # noqa
+
+# PyOpenGL 3.0.1 does not define it
+try:
+ GLchar
+except NameError:
+ from ctypes import c_char
+ GLchar = c_char
+
+
+def testGL():
+ """Test if required OpenGL version and extensions are available.
+
+ This MUST be run with an active OpenGL context.
+ """
+ version = glGetString(GL_VERSION).split()[0] # get version number
+ major, minor = int(version[0]), int(version[2])
+ if major < 2 or (major == 2 and minor < 1):
+ raise RuntimeError(
+ "Requires at least OpenGL version 2.1, running with %s" % version)
+
+ from OpenGL.GL.ARB.framebuffer_object import glInitFramebufferObjectARB
+ from OpenGL.GL.ARB.texture_rg import glInitTextureRgARB
+
+ if not glInitFramebufferObjectARB():
+ raise RuntimeError(
+ "OpenGL GL_ARB_framebuffer_object extension required !")
+
+ if not glInitTextureRgARB():
+ raise RuntimeError("OpenGL GL_ARB_texture_rg extension required !")
+
+
+# Additional setup
+if hasattr(glget, 'addGLGetConstant'):
+ glget.addGLGetConstant(GL_FRAMEBUFFER_BINDING, (1,))
+
+
+@_contextmanager
+def enabled(capacity, enable=True):
+ """Context manager enabling an OpenGL capacity.
+
+ This is not checking the current state of the capacity.
+
+ :param capacity: The OpenGL capacity enum to enable/disable
+ :param bool enable:
+ True (default) to enable during context, False to disable
+ """
+ if bool(enable) == glGetBoolean(capacity):
+ # Already in the right state: noop
+ yield
+ elif enable:
+ glEnable(capacity)
+ yield
+ glDisable(capacity)
+ else:
+ glDisable(capacity)
+ yield
+ glEnable(capacity)
+
+
+def disabled(capacity, disable=True):
+ """Context manager disabling an OpenGL capacity.
+
+ This is not checking the current state of the capacity.
+
+ :param capacity: The OpenGL capacity enum to disable/enable
+ :param bool disable:
+ True (default) to disable during context, False to enable
+ """
+ return enabled(capacity, not disable)
+
+
+# Additional OpenGL wrapping
+
+def glGetActiveAttrib(program, index):
+ """Wrap PyOpenGL glGetActiveAttrib"""
+ bufsize = glGetProgramiv(program, GL_ACTIVE_ATTRIBUTE_MAX_LENGTH)
+ length = GLsizei()
+ size = GLint()
+ type_ = GLenum()
+ name = (GLchar * bufsize)()
+
+ _GL.glGetActiveAttrib(program, index, bufsize, length, size, type_, name)
+ return name.value, size.value, type_.value
+
+
+def glDeleteRenderbuffers(buffers):
+ if not hasattr(buffers, '__len__'): # Support single int argument
+ buffers = [buffers]
+ length = len(buffers)
+ _FBO.glDeleteRenderbuffers(length, (c_uint * length)(*buffers))
+
+
+def glDeleteFramebuffers(buffers):
+ if not hasattr(buffers, '__len__'): # Support single int argument
+ buffers = [buffers]
+ length = len(buffers)
+ _FBO.glDeleteFramebuffers(length, (c_uint * length)(*buffers))
+
+
+def glDeleteBuffers(buffers):
+ if not hasattr(buffers, '__len__'): # Support single int argument
+ buffers = [buffers]
+ length = len(buffers)
+ _GL.glDeleteBuffers(length, (c_uint * length)(*buffers))
+
+
+def glDeleteTextures(textures):
+ if not hasattr(textures, '__len__'): # Support single int argument
+ textures = [textures]
+ length = len(textures)
+ _GL.glDeleteTextures((c_uint * length)(*textures))
diff --git a/src/silx/gui/_glutils/utils.py b/src/silx/gui/_glutils/utils.py
new file mode 100644
index 0000000..5886599
--- /dev/null
+++ b/src/silx/gui/_glutils/utils.py
@@ -0,0 +1,123 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2014-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides conversion functions between OpenGL and numpy types.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "10/01/2017"
+
+import numpy
+
+from OpenGL.constants import BYTE_SIZES as _BYTE_SIZES
+from OpenGL.constants import ARRAY_TO_GL_TYPE_MAPPING as _ARRAY_TO_GL_TYPE_MAPPING
+
+
+def sizeofGLType(type_):
+ """Returns the size in bytes of an element of type `type_`"""
+ return _BYTE_SIZES[type_]
+
+
+def isSupportedGLType(type_):
+ """Test if a numpy type or dtype can be converted to a GL type."""
+ return numpy.dtype(type_).char in _ARRAY_TO_GL_TYPE_MAPPING
+
+
+def numpyToGLType(type_):
+ """Returns the GL type corresponding the provided numpy type or dtype."""
+ return _ARRAY_TO_GL_TYPE_MAPPING[numpy.dtype(type_).char]
+
+
+def segmentTrianglesIntersection(segment, triangles):
+ """Check for segment/triangles intersection.
+
+ This is based on signed tetrahedron volume comparison.
+
+ See A. Kensler, A., Shirley, P.
+ Optimizing Ray-Triangle Intersection via Automated Search.
+ Symposium on Interactive Ray Tracing, vol. 0, p33-38 (2006)
+
+ :param numpy.ndarray segment:
+ Segment end points as a 2x3 array of coordinates
+ :param numpy.ndarray triangles:
+ Nx3x3 array of triangles
+ :return: (triangle indices, segment parameter, barycentric coord)
+ Indices of intersected triangles, "depth" along the segment
+ of the intersection point and barycentric coordinates of intersection
+ point in the triangle.
+ :rtype: List[numpy.ndarray]
+ """
+ # TODO triangles from vertices + indices
+ # TODO early rejection? e.g., check segment bbox vs triangle bbox
+ segment = numpy.asarray(segment)
+ assert segment.ndim == 2
+ assert segment.shape == (2, 3)
+
+ triangles = numpy.asarray(triangles)
+ assert triangles.ndim == 3
+ assert triangles.shape[1] == 3
+
+ # Test line/triangles intersection
+ d = segment[1] - segment[0]
+ t0s0 = segment[0] - triangles[:, 0, :]
+ edge01 = triangles[:, 1, :] - triangles[:, 0, :]
+ edge02 = triangles[:, 2, :] - triangles[:, 0, :]
+
+ dCrossEdge02 = numpy.cross(d, edge02)
+ t0s0CrossEdge01 = numpy.cross(t0s0, edge01)
+ volume = numpy.sum(dCrossEdge02 * edge01, axis=1)
+ del edge01
+ subVolumes = numpy.empty((len(triangles), 3), dtype=triangles.dtype)
+ subVolumes[:, 1] = numpy.sum(dCrossEdge02 * t0s0, axis=1)
+ del dCrossEdge02
+ subVolumes[:, 2] = numpy.sum(t0s0CrossEdge01 * d, axis=1)
+ subVolumes[:, 0] = volume - subVolumes[:, 1] - subVolumes[:, 2]
+ intersect = numpy.logical_or(
+ numpy.all(subVolumes >= 0., axis=1), # All positive
+ numpy.all(subVolumes <= 0., axis=1)) # All negative
+ intersect = numpy.where(intersect)[0] # Indices of intersected triangles
+
+ # Get barycentric coordinates
+ with numpy.errstate(invalid="ignore"):
+ barycentric = subVolumes[intersect] / volume[intersect].reshape(-1, 1)
+ del subVolumes
+
+ # Test segment/triangles intersection
+ volAlpha = numpy.sum(t0s0CrossEdge01[intersect] * edge02[intersect], axis=1)
+ with numpy.errstate(invalid="ignore"):
+ t = volAlpha / volume[intersect] # segment parameter of intersected triangles
+ del t0s0CrossEdge01
+ del edge02
+ del volAlpha
+ del volume
+
+ inSegmentMask = numpy.logical_and(t >= 0., t <= 1.)
+ intersect = intersect[inSegmentMask]
+ t = t[inSegmentMask]
+ barycentric = barycentric[inSegmentMask]
+
+ # Sort intersecting triangles by t
+ indices = numpy.argsort(t)
+ return intersect[indices], t[indices], barycentric[indices]
diff --git a/src/silx/gui/colors.py b/src/silx/gui/colors.py
new file mode 100755
index 0000000..12046cf
--- /dev/null
+++ b/src/silx/gui/colors.py
@@ -0,0 +1,1036 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides API to manage colors.
+"""
+
+from __future__ import absolute_import
+
+__authors__ = ["T. Vincent", "H.Payno"]
+__license__ = "MIT"
+__date__ = "29/01/2019"
+
+import numpy
+import logging
+
+from silx.gui import qt
+from silx.gui.utils import blockSignals
+from silx.math import colormap as _colormap
+from silx.utils.exceptions import NotEditableError
+from silx.utils import deprecation
+
+
+_logger = logging.getLogger(__name__)
+
+try:
+ import silx.gui.utils.matplotlib # noqa Initalize matplotlib
+ from matplotlib import cm as _matplotlib_cm
+ from matplotlib.pyplot import colormaps as _matplotlib_colormaps
+except ImportError:
+ _logger.info("matplotlib not available, only embedded colormaps available")
+ _matplotlib_cm = None
+ _matplotlib_colormaps = None
+
+
+_COLORDICT = {}
+"""Dictionary of common colors."""
+
+_COLORDICT['b'] = _COLORDICT['blue'] = '#0000ff'
+_COLORDICT['r'] = _COLORDICT['red'] = '#ff0000'
+_COLORDICT['g'] = _COLORDICT['green'] = '#00ff00'
+_COLORDICT['k'] = _COLORDICT['black'] = '#000000'
+_COLORDICT['w'] = _COLORDICT['white'] = '#ffffff'
+_COLORDICT['pink'] = '#ff66ff'
+_COLORDICT['brown'] = '#a52a2a'
+_COLORDICT['orange'] = '#ff9900'
+_COLORDICT['violet'] = '#6600ff'
+_COLORDICT['gray'] = _COLORDICT['grey'] = '#a0a0a4'
+# _COLORDICT['darkGray'] = _COLORDICT['darkGrey'] = '#808080'
+# _COLORDICT['lightGray'] = _COLORDICT['lightGrey'] = '#c0c0c0'
+_COLORDICT['y'] = _COLORDICT['yellow'] = '#ffff00'
+_COLORDICT['m'] = _COLORDICT['magenta'] = '#ff00ff'
+_COLORDICT['c'] = _COLORDICT['cyan'] = '#00ffff'
+_COLORDICT['darkBlue'] = '#000080'
+_COLORDICT['darkRed'] = '#800000'
+_COLORDICT['darkGreen'] = '#008000'
+_COLORDICT['darkBrown'] = '#660000'
+_COLORDICT['darkCyan'] = '#008080'
+_COLORDICT['darkYellow'] = '#808000'
+_COLORDICT['darkMagenta'] = '#800080'
+_COLORDICT['transparent'] = '#00000000'
+
+
+# FIXME: It could be nice to expose a functional API instead of that attribute
+COLORDICT = _COLORDICT
+
+
+DEFAULT_MIN_LIN = 0
+"""Default min value if in linear normalization"""
+DEFAULT_MAX_LIN = 1
+"""Default max value if in linear normalization"""
+
+
+def rgba(color, colorDict=None):
+ """Convert color code '#RRGGBB' and '#RRGGBBAA' to a tuple (R, G, B, A)
+ of floats.
+
+ It also supports RGB(A) from uint8 in [0, 255], float in [0, 1], and
+ QColor as color argument.
+
+ :param str color: The color to convert
+ :param dict colorDict: A dictionary of color name conversion to color code
+ :returns: RGBA colors as floats in [0., 1.]
+ :rtype: tuple
+ """
+ if colorDict is None:
+ colorDict = _COLORDICT
+
+ if hasattr(color, 'getRgb'): # QColor support
+ color = color.getRgb()
+
+ values = numpy.asarray(color).ravel()
+
+ if values.dtype.kind in 'iuf': # integer or float
+ # Color is an array
+ assert len(values) in (3, 4)
+
+ # Convert from integers in [0, 255] to float in [0, 1]
+ if values.dtype.kind in 'iu':
+ values = values / 255.
+
+ # Clip to [0, 1]
+ values[values < 0.] = 0.
+ values[values > 1.] = 1.
+
+ if len(values) == 3:
+ return values[0], values[1], values[2], 1.
+ else:
+ return tuple(values)
+
+ # We assume color is a string
+ if not color.startswith('#'):
+ color = colorDict[color]
+
+ assert len(color) in (7, 9) and color[0] == '#'
+ r = int(color[1:3], 16) / 255.
+ g = int(color[3:5], 16) / 255.
+ b = int(color[5:7], 16) / 255.
+ a = int(color[7:9], 16) / 255. if len(color) == 9 else 1.
+ return r, g, b, a
+
+
+def greyed(color, colorDict=None):
+ """Convert color code '#RRGGBB' and '#RRGGBBAA' to a grey color
+ (R, G, B, A).
+
+ It also supports RGB(A) from uint8 in [0, 255], float in [0, 1], and
+ QColor as color argument.
+
+ :param str color: The color to convert
+ :param dict colorDict: A dictionary of color name conversion to color code
+ :returns: RGBA colors as floats in [0., 1.]
+ :rtype: tuple
+ """
+ r, g, b, a = rgba(color=color, colorDict=colorDict)
+ g = 0.21 * r + 0.72 * g + 0.07 * b
+ return g, g, g, a
+
+
+def asQColor(color):
+ """Convert color code '#RRGGBB' and '#RRGGBBAA' to a `qt.QColor`.
+
+ It also supports RGB(A) from uint8 in [0, 255], float in [0, 1], and
+ QColor as color argument.
+
+ :param str color: The color to convert
+ :rtype: qt.QColor
+ """
+ color = rgba(color)
+ return qt.QColor.fromRgbF(*color)
+
+
+def cursorColorForColormap(colormapName):
+ """Get a color suitable for overlay over a colormap.
+
+ :param str colormapName: The name of the colormap.
+ :return: Name of the color.
+ :rtype: str
+ """
+ return _colormap.get_colormap_cursor_color(colormapName)
+
+
+# Colormap loader
+
+def _registerColormapFromMatplotlib(name, cursor_color='black', preferred=False):
+ colormap = _matplotlib_cm.get_cmap(name)
+ lut = colormap(numpy.linspace(0, 1, colormap.N, endpoint=True))
+ colors = _colormap.array_to_rgba8888(lut)
+ registerLUT(name, colors, cursor_color, preferred)
+
+
+def _getColormap(name):
+ """Returns the color LUT corresponding to a colormap name
+ :param str name: Name of the colormap to load
+ :returns: Corresponding table of colors
+ :rtype: numpy.ndarray
+ :raise ValueError: If no colormap corresponds to name
+ """
+ name = str(name)
+ try:
+ return _colormap.get_colormap_lut(name)
+ except ValueError:
+ # Colormap is not available, try to load it from matplotlib
+ _registerColormapFromMatplotlib(name, 'black', False)
+ return _colormap.get_colormap_lut(name)
+
+
+class Colormap(qt.QObject):
+ """Description of a colormap
+
+ If no `name` nor `colors` are provided, a default gray LUT is used.
+
+ :param str name: Name of the colormap
+ :param tuple colors: optional, custom colormap.
+ Nx3 or Nx4 numpy array of RGB(A) colors,
+ either uint8 or float in [0, 1].
+ If 'name' is None, then this array is used as the colormap.
+ :param str normalization: Normalization: 'linear' (default) or 'log'
+ :param vmin: Lower bound of the colormap or None for autoscale (default)
+ :type vmin: Union[None, float]
+ :param vmax: Upper bounds of the colormap or None for autoscale (default)
+ :type vmax: Union[None, float]
+ """
+
+ LINEAR = 'linear'
+ """constant for linear normalization"""
+
+ LOGARITHM = 'log'
+ """constant for logarithmic normalization"""
+
+ SQRT = 'sqrt'
+ """constant for square root normalization"""
+
+ GAMMA = 'gamma'
+ """Constant for gamma correction normalization"""
+
+ ARCSINH = 'arcsinh'
+ """constant for inverse hyperbolic sine normalization"""
+
+ _BASIC_NORMALIZATIONS = {
+ LINEAR: _colormap.LinearNormalization(),
+ LOGARITHM: _colormap.LogarithmicNormalization(),
+ SQRT: _colormap.SqrtNormalization(),
+ ARCSINH: _colormap.ArcsinhNormalization(),
+ }
+ """Normalizations without parameters"""
+
+ NORMALIZATIONS = LINEAR, LOGARITHM, SQRT, GAMMA, ARCSINH
+ """Tuple of managed normalizations"""
+
+ MINMAX = 'minmax'
+ """constant for autoscale using min/max data range"""
+
+ STDDEV3 = 'stddev3'
+ """constant for autoscale using mean +/- 3*std(data)
+ with a clamp on min/max of the data"""
+
+ AUTOSCALE_MODES = (MINMAX, STDDEV3)
+ """Tuple of managed auto scale algorithms"""
+
+ sigChanged = qt.Signal()
+ """Signal emitted when the colormap has changed."""
+
+ _DEFAULT_NAN_COLOR = 255, 255, 255, 0
+
+ def __init__(self, name=None, colors=None, normalization=LINEAR, vmin=None, vmax=None, autoscaleMode=MINMAX):
+ qt.QObject.__init__(self)
+ self._editable = True
+ self.__gamma = 2.0
+ # Default NaN color: fully transparent white
+ self.__nanColor = numpy.array(self._DEFAULT_NAN_COLOR, dtype=numpy.uint8)
+
+ assert normalization in Colormap.NORMALIZATIONS
+ assert autoscaleMode in Colormap.AUTOSCALE_MODES
+
+ if normalization is Colormap.LOGARITHM:
+ if (vmin is not None and vmin < 0) or (vmax is not None and vmax < 0):
+ m = "Unsuported vmin (%s) and/or vmax (%s) given for a log scale."
+ m += ' Autoscale will be performed.'
+ m = m % (vmin, vmax)
+ _logger.warning(m)
+ vmin = None
+ vmax = None
+
+ self._name = None
+ self._colors = None
+
+ if colors is not None and name is not None:
+ deprecation.deprecated_warning("Argument",
+ name="silx.gui.plot.Colors",
+ reason="name and colors can't be used at the same time",
+ since_version="0.10.0",
+ skip_backtrace_count=1)
+
+ colors = None
+
+ if name is not None:
+ self.setName(name) # And resets colormap LUT
+ elif colors is not None:
+ self.setColormapLUT(colors)
+ else:
+ # Default colormap is grey
+ self.setName("gray")
+
+ self._normalization = str(normalization)
+ self._autoscaleMode = str(autoscaleMode)
+ self._vmin = float(vmin) if vmin is not None else None
+ self._vmax = float(vmax) if vmax is not None else None
+ self.__warnBadVmin = True
+ self.__warnBadVmax = True
+
+ def setFromColormap(self, other):
+ """Set this colormap using information from the `other` colormap.
+
+ :param ~silx.gui.colors.Colormap other: Colormap to use as reference.
+ """
+ if not self.isEditable():
+ raise NotEditableError('Colormap is not editable')
+ if self == other:
+ return
+ with blockSignals(self):
+ name = other.getName()
+ if name is not None:
+ self.setName(name)
+ else:
+ self.setColormapLUT(other.getColormapLUT())
+ self.setNaNColor(other.getNaNColor())
+ self.setNormalization(other.getNormalization())
+ self.setGammaNormalizationParameter(
+ other.getGammaNormalizationParameter())
+ self.setAutoscaleMode(other.getAutoscaleMode())
+ self.setVRange(*other.getVRange())
+ self.setEditable(other.isEditable())
+ self.sigChanged.emit()
+
+ def getNColors(self, nbColors=None):
+ """Returns N colors computed by sampling the colormap regularly.
+
+ :param nbColors:
+ The number of colors in the returned array or None for the default value.
+ The default value is the size of the colormap LUT.
+ :type nbColors: int or None
+ :return: 2D array of uint8 of shape (nbColors, 4)
+ :rtype: numpy.ndarray
+ """
+ # Handle default value for nbColors
+ if nbColors is None:
+ return numpy.array(self._colors, copy=True)
+ else:
+ nbColors = int(nbColors)
+ colormap = self.copy()
+ colormap.setNormalization(Colormap.LINEAR)
+ colormap.setVRange(vmin=0, vmax=nbColors - 1)
+ colors = colormap.applyToData(
+ numpy.arange(nbColors, dtype=numpy.int32))
+ return colors
+
+ def getName(self):
+ """Return the name of the colormap
+ :rtype: str
+ """
+ return self._name
+
+ def setName(self, name):
+ """Set the name of the colormap to use.
+
+ :param str name: The name of the colormap.
+ At least the following names are supported: 'gray',
+ 'reversed gray', 'temperature', 'red', 'green', 'blue', 'jet',
+ 'viridis', 'magma', 'inferno', 'plasma'.
+ """
+ name = str(name)
+ if self._name == name:
+ return
+ if self.isEditable() is False:
+ raise NotEditableError('Colormap is not editable')
+ if name not in self.getSupportedColormaps():
+ raise ValueError("Colormap name '%s' is not supported" % name)
+ self._name = name
+ self._colors = _getColormap(self._name)
+ self.sigChanged.emit()
+
+ def getColormapLUT(self, copy=True):
+ """Return the list of colors for the colormap or None if not set.
+
+ This returns None if the colormap was set with :meth:`setName`.
+ Use :meth:`getNColors` to get the colormap LUT for any colormap.
+
+ :param bool copy: If true a copy of the numpy array is provided
+ :return: the list of colors for the colormap or None if not set
+ :rtype: numpy.ndarray or None
+ """
+ if self._name is None:
+ return numpy.array(self._colors, copy=copy)
+ else:
+ return None
+
+ def setColormapLUT(self, colors):
+ """Set the colors of the colormap.
+
+ :param numpy.ndarray colors: the colors of the LUT.
+ If float, it is converted from [0, 1] to uint8 range.
+ Otherwise it is casted to uint8.
+
+ .. warning: this will set the value of name to None
+ """
+ if self.isEditable() is False:
+ raise NotEditableError('Colormap is not editable')
+ assert colors is not None
+
+ colors = numpy.array(colors, copy=False)
+ if colors.shape == ():
+ raise TypeError("An array is expected for 'colors' argument. '%s' was found." % type(colors))
+ assert len(colors) != 0
+ assert colors.ndim >= 2
+ colors.shape = -1, colors.shape[-1]
+ self._colors = _colormap.array_to_rgba8888(colors)
+ self._name = None
+ self.sigChanged.emit()
+
+ def getNaNColor(self):
+ """Returns the color to use for Not-A-Number floating point value.
+
+ :rtype: QColor
+ """
+ return qt.QColor(*self.__nanColor)
+
+ def setNaNColor(self, color):
+ """Set the color to use for Not-A-Number floating point value.
+
+ :param color: RGB(A) color to use for NaN values
+ :type color: QColor, str, tuple of uint8 or float in [0., 1.]
+ """
+ color = (numpy.array(rgba(color)) * 255).astype(numpy.uint8)
+ if not numpy.array_equal(self.__nanColor, color):
+ self.__nanColor = color
+ self.sigChanged.emit()
+
+ def getNormalization(self):
+ """Return the normalization of the colormap.
+
+ See :meth:`setNormalization` for returned values.
+
+ :return: the normalization of the colormap
+ :rtype: str
+ """
+ return self._normalization
+
+ def setNormalization(self, norm):
+ """Set the colormap normalization.
+
+ Accepted normalizations: 'log', 'linear', 'sqrt'
+
+ :param str norm: the norm to set
+ """
+ assert norm in self.NORMALIZATIONS
+ if self.isEditable() is False:
+ raise NotEditableError('Colormap is not editable')
+ norm = str(norm)
+ if norm != self._normalization:
+ self._normalization = norm
+ self.__warnBadVmin = True
+ self.__warnBadVmax = True
+ self.sigChanged.emit()
+
+ def setGammaNormalizationParameter(self, gamma: float) -> None:
+ """Set the gamma correction parameter.
+
+ Only used for gamma correction normalization.
+
+ :param float gamma:
+ :raise ValueError: If gamma is not valid
+ """
+ if gamma < 0. or not numpy.isfinite(gamma):
+ raise ValueError("Gamma value not supported")
+ if gamma != self.__gamma:
+ self.__gamma = gamma
+ self.sigChanged.emit()
+
+ def getGammaNormalizationParameter(self) -> float:
+ """Returns the gamma correction parameter value.
+
+ :rtype: float
+ """
+ return self.__gamma
+
+ def getAutoscaleMode(self):
+ """Return the autoscale mode of the colormap ('minmax' or 'stddev3')
+
+ :rtype: str
+ """
+ return self._autoscaleMode
+
+ def setAutoscaleMode(self, mode):
+ """Set the autoscale mode: either 'minmax' or 'stddev3'
+
+ :param str mode: the mode to set
+ """
+ if self.isEditable() is False:
+ raise NotEditableError('Colormap is not editable')
+ assert mode in self.AUTOSCALE_MODES
+ if mode != self._autoscaleMode:
+ self._autoscaleMode = mode
+ self.sigChanged.emit()
+
+ def isAutoscale(self):
+ """Return True if both min and max are in autoscale mode"""
+ return self._vmin is None and self._vmax is None
+
+ def getVMin(self):
+ """Return the lower bound of the colormap
+
+ :return: the lower bound of the colormap
+ :rtype: float or None
+ """
+ return self._vmin
+
+ def setVMin(self, vmin):
+ """Set the minimal value of the colormap
+
+ :param float vmin: Lower bound of the colormap or None for autoscale
+ (default)
+ value)
+ """
+ if self.isEditable() is False:
+ raise NotEditableError('Colormap is not editable')
+ if vmin is not None:
+ if self._vmax is not None and vmin > self._vmax:
+ err = "Can't set vmin because vmin >= vmax. " \
+ "vmin = %s, vmax = %s" % (vmin, self._vmax)
+ raise ValueError(err)
+
+ if vmin != self._vmin:
+ self._vmin = vmin
+ self.__warnBadVmin = True
+ self.sigChanged.emit()
+
+ def getVMax(self):
+ """Return the upper bounds of the colormap or None
+
+ :return: the upper bounds of the colormap or None
+ :rtype: float or None
+ """
+ return self._vmax
+
+ def setVMax(self, vmax):
+ """Set the maximal value of the colormap
+
+ :param float vmax: Upper bounds of the colormap or None for autoscale
+ (default)
+ """
+ if self.isEditable() is False:
+ raise NotEditableError('Colormap is not editable')
+ if vmax is not None:
+ if self._vmin is not None and vmax < self._vmin:
+ err = "Can't set vmax because vmax <= vmin. " \
+ "vmin = %s, vmax = %s" % (self._vmin, vmax)
+ raise ValueError(err)
+
+ if vmax != self._vmax:
+ self._vmax = vmax
+ self.__warnBadVmax = True
+ self.sigChanged.emit()
+
+ def isEditable(self):
+ """ Return if the colormap is editable or not
+
+ :return: editable state of the colormap
+ :rtype: bool
+ """
+ return self._editable
+
+ def setEditable(self, editable):
+ """
+ Set the editable state of the colormap
+
+ :param bool editable: is the colormap editable
+ """
+ assert type(editable) is bool
+ self._editable = editable
+ self.sigChanged.emit()
+
+ def _getNormalizer(self):
+ """Returns normalizer object"""
+ normalization = self.getNormalization()
+ if normalization == self.GAMMA:
+ return _colormap.GammaNormalization(self.getGammaNormalizationParameter())
+ else:
+ return self._BASIC_NORMALIZATIONS[normalization]
+
+ def _computeAutoscaleRange(self, data):
+ """Compute the data range which will be used in autoscale mode.
+
+ :param numpy.ndarray data: The data for which to compute the range
+ :return: (vmin, vmax) range
+ """
+ return self._getNormalizer().autoscale(
+ data, mode=self.getAutoscaleMode())
+
+ def getColormapRange(self, data=None):
+ """Return (vmin, vmax) the range of the colormap for the given data or item.
+
+ :param Union[numpy.ndarray,~silx.gui.plot.items.ColormapMixIn] data:
+ The data or item to use for autoscale bounds.
+ :return: (vmin, vmax) corresponding to the colormap applied to data if provided.
+ :rtype: tuple
+ """
+ vmin = self._vmin
+ vmax = self._vmax
+ assert vmin is None or vmax is None or vmin <= vmax # TODO handle this in setters
+
+ normalizer = self._getNormalizer()
+
+ # Handle invalid bounds as autoscale
+ if vmin is not None and not normalizer.is_valid(vmin):
+ if self.__warnBadVmin:
+ self.__warnBadVmin = False
+ _logger.info(
+ 'Invalid vmin, switching to autoscale for lower bound')
+ vmin = None
+ if vmax is not None and not normalizer.is_valid(vmax):
+ if self.__warnBadVmax:
+ self.__warnBadVmax = False
+ _logger.info(
+ 'Invalid vmax, switching to autoscale for upper bound')
+ vmax = None
+
+ if vmin is None or vmax is None: # Handle autoscale
+ from .plot.items.core import ColormapMixIn # avoid cyclic import
+ if isinstance(data, ColormapMixIn):
+ min_, max_ = data._getColormapAutoscaleRange(self)
+ # Make sure min_, max_ are not None
+ min_ = normalizer.DEFAULT_RANGE[0] if min_ is None else min_
+ max_ = normalizer.DEFAULT_RANGE[1] if max_ is None else max_
+ else:
+ min_, max_ = normalizer.autoscale(
+ data, mode=self.getAutoscaleMode())
+
+ if vmin is None: # Set vmin respecting provided vmax
+ vmin = min_ if vmax is None else min(min_, vmax)
+
+ if vmax is None:
+ vmax = max(max_, vmin) # Handle max_ <= 0 for log scale
+
+ return vmin, vmax
+
+ def getVRange(self):
+ """Get the bounds of the colormap
+
+ :rtype: Tuple(Union[float,None],Union[float,None])
+ :returns: A tuple of 2 values for min and max. Or None instead of float
+ for autoscale
+ """
+ return self.getVMin(), self.getVMax()
+
+ def setVRange(self, vmin, vmax):
+ """Set the bounds of the colormap
+
+ :param vmin: Lower bound of the colormap or None for autoscale
+ (default)
+ :param vmax: Upper bounds of the colormap or None for autoscale
+ (default)
+ """
+ if self.isEditable() is False:
+ raise NotEditableError('Colormap is not editable')
+ if vmin is not None and vmax is not None:
+ if vmin > vmax:
+ err = "Can't set vmin and vmax because vmin >= vmax " \
+ "vmin = %s, vmax = %s" % (vmin, vmax)
+ raise ValueError(err)
+
+ if self._vmin == vmin and self._vmax == vmax:
+ return
+
+ if vmin != self._vmin:
+ self.__warnBadVmin = True
+ self._vmin = vmin
+ if vmax != self._vmax:
+ self.__warnBadVmax = True
+ self._vmax = vmax
+ self.sigChanged.emit()
+
+ def __getitem__(self, item):
+ if item == 'autoscale':
+ return self.isAutoscale()
+ elif item == 'name':
+ return self.getName()
+ elif item == 'normalization':
+ return self.getNormalization()
+ elif item == 'vmin':
+ return self.getVMin()
+ elif item == 'vmax':
+ return self.getVMax()
+ elif item == 'colors':
+ return self.getColormapLUT()
+ elif item == 'autoscaleMode':
+ return self.getAutoscaleMode()
+ else:
+ raise KeyError(item)
+
+ def _toDict(self):
+ """Return the equivalent colormap as a dictionary
+ (old colormap representation)
+
+ :return: the representation of the Colormap as a dictionary
+ :rtype: dict
+ """
+ return {
+ 'name': self._name,
+ 'colors': self.getColormapLUT(),
+ 'vmin': self._vmin,
+ 'vmax': self._vmax,
+ 'autoscale': self.isAutoscale(),
+ 'normalization': self.getNormalization(),
+ 'autoscaleMode': self.getAutoscaleMode(),
+ }
+
+ def _setFromDict(self, dic):
+ """Set values to the colormap from a dictionary
+
+ :param dict dic: the colormap as a dictionary
+ """
+ if self.isEditable() is False:
+ raise NotEditableError('Colormap is not editable')
+ name = dic['name'] if 'name' in dic else None
+ colors = dic['colors'] if 'colors' in dic else None
+ if name is not None and colors is not None:
+ if isinstance(colors, int):
+ # Filter out argument which was supported but never used
+ _logger.info("Unused 'colors' from colormap dictionary filterer.")
+ colors = None
+ vmin = dic['vmin'] if 'vmin' in dic else None
+ vmax = dic['vmax'] if 'vmax' in dic else None
+ if 'normalization' in dic:
+ normalization = dic['normalization']
+ else:
+ warn = 'Normalization not given in the dictionary, '
+ warn += 'set by default to ' + Colormap.LINEAR
+ _logger.warning(warn)
+ normalization = Colormap.LINEAR
+
+ if name is None and colors is None:
+ err = 'The colormap should have a name defined or a tuple of colors'
+ raise ValueError(err)
+ if normalization not in Colormap.NORMALIZATIONS:
+ err = 'Given normalization is not recognized (%s)' % normalization
+ raise ValueError(err)
+
+ autoscaleMode = dic.get('autoscaleMode', Colormap.MINMAX)
+ if autoscaleMode not in Colormap.AUTOSCALE_MODES:
+ err = 'Given autoscale mode is not recognized (%s)' % autoscaleMode
+ raise ValueError(err)
+
+ # If autoscale, then set boundaries to None
+ if dic.get('autoscale', False):
+ vmin, vmax = None, None
+
+ if name is not None:
+ self.setName(name)
+ else:
+ self.setColormapLUT(colors)
+ self._vmin = vmin
+ self._vmax = vmax
+ self._autoscale = True if (vmin is None and vmax is None) else False
+ self._normalization = normalization
+ self._autoscaleMode = autoscaleMode
+
+ self.__warnBadVmin = True
+ self.__warnBadVmax = True
+ self.sigChanged.emit()
+
+ @staticmethod
+ def _fromDict(dic):
+ colormap = Colormap()
+ colormap._setFromDict(dic)
+ return colormap
+
+ def copy(self):
+ """Return a copy of the Colormap.
+
+ :rtype: silx.gui.colors.Colormap
+ """
+ colormap = Colormap(name=self._name,
+ colors=self.getColormapLUT(),
+ vmin=self._vmin,
+ vmax=self._vmax,
+ normalization=self.getNormalization(),
+ autoscaleMode=self.getAutoscaleMode())
+ colormap.setNaNColor(self.getNaNColor())
+ colormap.setGammaNormalizationParameter(
+ self.getGammaNormalizationParameter())
+ colormap.setEditable(self.isEditable())
+ return colormap
+
+ def applyToData(self, data, reference=None):
+ """Apply the colormap to the data
+
+ :param Union[numpy.ndarray,~silx.gui.plot.item.ColormapMixIn] data:
+ The data to convert or the item for which to apply the colormap.
+ :param Union[numpy.ndarray,~silx.gui.plot.item.ColormapMixIn,None] reference:
+ The data or item to use as reference to compute autoscale
+ """
+ if reference is None:
+ reference = data
+ vmin, vmax = self.getColormapRange(reference)
+
+ if hasattr(data, "getColormappedData"): # Use item's data
+ data = data.getColormappedData(copy=False)
+
+ return _colormap.cmap(
+ data,
+ self._colors,
+ vmin,
+ vmax,
+ self._getNormalizer(),
+ self.__nanColor)
+
+ @staticmethod
+ def getSupportedColormaps():
+ """Get the supported colormap names as a tuple of str.
+
+ The list should at least contain and start by:
+
+ ('gray', 'reversed gray', 'temperature', 'red', 'green', 'blue',
+ 'viridis', 'magma', 'inferno', 'plasma')
+
+ :rtype: tuple
+ """
+ registered_colormaps = _colormap.get_registered_colormaps()
+ colormaps = set(registered_colormaps)
+ if _matplotlib_colormaps is not None:
+ colormaps.update(_matplotlib_colormaps())
+
+ # Put registered_colormaps first
+ colormaps = tuple(cmap for cmap in sorted(colormaps)
+ if cmap not in registered_colormaps)
+ return registered_colormaps + colormaps
+
+ def __str__(self):
+ return str(self._toDict())
+
+ def __eq__(self, other):
+ """Compare colormap values and not pointers"""
+ if other is None:
+ return False
+ if not isinstance(other, Colormap):
+ return False
+ if self.getNormalization() != other.getNormalization():
+ return False
+ if self.getNormalization() == self.GAMMA:
+ delta = self.getGammaNormalizationParameter() - other.getGammaNormalizationParameter()
+ if abs(delta) > 0.001:
+ return False
+ return (self.getName() == other.getName() and
+ self.getAutoscaleMode() == other.getAutoscaleMode() and
+ self.getVMin() == other.getVMin() and
+ self.getVMax() == other.getVMax() and
+ numpy.array_equal(self.getColormapLUT(), other.getColormapLUT())
+ )
+
+ _SERIAL_VERSION = 3
+
+ def restoreState(self, byteArray):
+ """
+ Read the colormap state from a QByteArray.
+
+ :param qt.QByteArray byteArray: Stream containing the state
+ :return: True if the restoration sussseed
+ :rtype: bool
+ """
+ if self.isEditable() is False:
+ raise NotEditableError('Colormap is not editable')
+ stream = qt.QDataStream(byteArray, qt.QIODevice.ReadOnly)
+
+ className = stream.readQString()
+ if className != self.__class__.__name__:
+ _logger.warning("Classname mismatch. Found %s." % className)
+ return False
+
+ version = stream.readUInt32()
+ if version not in numpy.arange(1, self._SERIAL_VERSION+1):
+ _logger.warning("Serial version mismatch. Found %d." % version)
+ return False
+
+ name = stream.readQString()
+ isNull = stream.readBool()
+ if not isNull:
+ vmin = stream.readQVariant()
+ else:
+ vmin = None
+ isNull = stream.readBool()
+ if not isNull:
+ vmax = stream.readQVariant()
+ else:
+ vmax = None
+
+ normalization = stream.readQString()
+ if normalization == Colormap.GAMMA:
+ gamma = stream.readFloat()
+ else:
+ gamma = None
+
+ if version == 1:
+ autoscaleMode = Colormap.MINMAX
+ else:
+ autoscaleMode = stream.readQString()
+
+ if version <= 2:
+ nanColor = self._DEFAULT_NAN_COLOR
+ else:
+ nanColor = stream.readInt32(), stream.readInt32(), stream.readInt32(), stream.readInt32()
+
+ # emit change event only once
+ old = self.blockSignals(True)
+ try:
+ self.setName(name)
+ self.setNormalization(normalization)
+ self.setAutoscaleMode(autoscaleMode)
+ self.setVRange(vmin, vmax)
+ if gamma is not None:
+ self.setGammaNormalizationParameter(gamma)
+ self.setNaNColor(nanColor)
+ finally:
+ self.blockSignals(old)
+ self.sigChanged.emit()
+ return True
+
+ def saveState(self):
+ """
+ Save state of the colomap into a QDataStream.
+
+ :rtype: qt.QByteArray
+ """
+ data = qt.QByteArray()
+ stream = qt.QDataStream(data, qt.QIODevice.WriteOnly)
+
+ stream.writeQString(self.__class__.__name__)
+ stream.writeUInt32(self._SERIAL_VERSION)
+ stream.writeQString(self.getName())
+ stream.writeBool(self.getVMin() is None)
+ if self.getVMin() is not None:
+ stream.writeQVariant(self.getVMin())
+ stream.writeBool(self.getVMax() is None)
+ if self.getVMax() is not None:
+ stream.writeQVariant(self.getVMax())
+ stream.writeQString(self.getNormalization())
+ if self.getNormalization() == Colormap.GAMMA:
+ stream.writeFloat(self.getGammaNormalizationParameter())
+ stream.writeQString(self.getAutoscaleMode())
+ nanColor = self.getNaNColor()
+ stream.writeInt32(nanColor.red())
+ stream.writeInt32(nanColor.green())
+ stream.writeInt32(nanColor.blue())
+ stream.writeInt32(nanColor.alpha())
+
+ return data
+
+
+_PREFERRED_COLORMAPS = None
+"""
+Tuple of preferred colormap names accessed with :meth:`preferredColormaps`.
+"""
+
+_DEFAULT_PREFERRED_COLORMAPS = (
+ 'gray', 'reversed gray', 'red', 'green', 'blue',
+ 'viridis', 'cividis', 'magma', 'inferno', 'plasma',
+ 'temperature',
+ 'jet', 'hsv'
+)
+
+
+def preferredColormaps():
+ """Returns the name of the preferred colormaps.
+
+ This list is used by widgets allowing to change the colormap
+ like the :class:`ColormapDialog` as a subset of colormap choices.
+
+ :rtype: tuple of str
+ """
+ global _PREFERRED_COLORMAPS
+ if _PREFERRED_COLORMAPS is None:
+ # Initialize preferred colormaps
+ setPreferredColormaps(_DEFAULT_PREFERRED_COLORMAPS)
+ return tuple(_PREFERRED_COLORMAPS)
+
+
+def setPreferredColormaps(colormaps):
+ """Set the list of preferred colormap names.
+
+ Warning: If a colormap name is not available
+ it will be removed from the list.
+
+ :param colormaps: Not empty list of colormap names
+ :type colormaps: iterable of str
+ :raise ValueError: if the list of available preferred colormaps is empty.
+ """
+ supportedColormaps = Colormap.getSupportedColormaps()
+ colormaps = [cmap for cmap in colormaps if cmap in supportedColormaps]
+ if len(colormaps) == 0:
+ raise ValueError("Cannot set preferred colormaps to an empty list")
+
+ global _PREFERRED_COLORMAPS
+ _PREFERRED_COLORMAPS = colormaps
+
+
+def registerLUT(name, colors, cursor_color='black', preferred=True):
+ """Register a custom LUT to be used with `Colormap` objects.
+
+ It can override existing LUT names.
+
+ :param str name: Name of the LUT as defined to configure colormaps
+ :param numpy.ndarray colors: The custom LUT to register.
+ Nx3 or Nx4 numpy array of RGB(A) colors,
+ either uint8 or float in [0, 1].
+ :param bool preferred: If true, this LUT will be displayed as part of the
+ preferred colormaps in dialogs.
+ :param str cursor_color: Color used to display overlay over images using
+ colormap with this LUT.
+ """
+ _colormap.register_colormap(name, colors, cursor_color)
+
+ if preferred:
+ # Invalidate the preferred cache
+ global _PREFERRED_COLORMAPS
+ if _PREFERRED_COLORMAPS is not None:
+ if name not in _PREFERRED_COLORMAPS:
+ _PREFERRED_COLORMAPS.append(name)
+ else:
+ # The cache is not yet loaded, it's fine
+ pass
+
+
+# Load some colormaps from matplotlib by default
+if _matplotlib_cm is not None:
+ _registerColormapFromMatplotlib('jet', cursor_color='pink', preferred=True)
+ _registerColormapFromMatplotlib('hsv', cursor_color='black', preferred=True)
diff --git a/src/silx/gui/conftest.py b/src/silx/gui/conftest.py
new file mode 100644
index 0000000..74b5c19
--- /dev/null
+++ b/src/silx/gui/conftest.py
@@ -0,0 +1,5 @@
+import pytest
+
+@pytest.fixture(autouse=True)
+def auto_qapp(qapp):
+ pass
diff --git a/src/silx/gui/console.py b/src/silx/gui/console.py
new file mode 100644
index 0000000..953b6a1
--- /dev/null
+++ b/src/silx/gui/console.py
@@ -0,0 +1,202 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides an IPython console widget.
+
+You can push variables - any python object - to the
+console's interactive namespace. This provides users with an advanced way
+of interacting with your program. For instance, if your program has a
+:class:`PlotWidget` or a :class:`PlotWindow`, you can push a reference to
+these widgets to allow your users to add curves, save data to files… by using
+the widgets' methods from the console.
+
+.. note::
+
+ This module has a dependency on
+ `qtconsole <https://pypi.org/project/qtconsole/>`_.
+ An ``ImportError`` will be raised if it is
+ imported while the dependencies are not satisfied.
+
+Basic usage example::
+
+ from silx.gui import qt
+ from silx.gui.console import IPythonWidget
+
+ app = qt.QApplication([])
+
+ hello_button = qt.QPushButton("Hello World!", None)
+ hello_button.show()
+
+ console = IPythonWidget()
+ console.show()
+ console.pushVariables({"the_button": hello_button})
+
+ app.exec()
+
+This program will display a console widget and a push button in two separate
+windows. You will be able to interact with the button from the console,
+for example change its text::
+
+ >>> the_button.setText("Spam spam")
+
+An IPython interactive console is a powerful tool that enables you to work
+with data and plot it.
+See `this tutorial <https://plot.ly/python/ipython-notebook-tutorial/>`_
+for more information on some of the rich features of IPython.
+"""
+__authors__ = ["Tim Rae", "V.A. Sole", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "24/05/2016"
+
+import logging
+
+from . import qt
+
+_logger = logging.getLogger(__name__)
+
+
+# This widget cannot be used inside an interactive IPython shell.
+# It would raise MultipleInstanceError("Multiple incompatible subclass
+# instances of InProcessInteractiveShell are being created").
+try:
+ __IPYTHON__
+except NameError:
+ pass # Not in IPython
+else:
+ msg = "Module " + __name__ + " cannot be used within an IPython shell"
+ raise ImportError(msg)
+
+try:
+ from qtconsole.rich_jupyter_widget import RichJupyterWidget as \
+ _RichJupyterWidget
+except ImportError:
+ try:
+ from qtconsole.rich_ipython_widget import RichJupyterWidget as \
+ _RichJupyterWidget
+ except ImportError:
+ from qtconsole.rich_ipython_widget import RichIPythonWidget as \
+ _RichJupyterWidget
+
+from qtconsole.inprocess import QtInProcessKernelManager
+
+try:
+ from ipykernel import version_info as _ipykernel_version_info
+except ImportError:
+ _ipykernel_version_info = None
+
+
+class IPythonWidget(_RichJupyterWidget):
+ """Live IPython console widget.
+
+ .. image:: img/IPythonWidget.png
+
+ :param custom_banner: Custom welcome message to be printed at the top of
+ the console.
+ """
+
+ def __init__(self, parent=None, custom_banner=None, *args, **kwargs):
+ if parent is not None:
+ kwargs["parent"] = parent
+ super(IPythonWidget, self).__init__(*args, **kwargs)
+ if custom_banner is not None:
+ self.banner = custom_banner
+ self.setWindowTitle(self.banner)
+ self.kernel_manager = kernel_manager = QtInProcessKernelManager()
+ kernel_manager.start_kernel()
+
+ # Monkey-patch to workaround issue:
+ # https://github.com/ipython/ipykernel/issues/370
+ if (_ipykernel_version_info is not None and
+ _ipykernel_version_info[0] > 4 and
+ _ipykernel_version_info[:3] <= (5, 1, 0)):
+ def _abort_queues(*args, **kwargs):
+ pass
+ kernel_manager.kernel._abort_queues = _abort_queues
+
+ self.kernel_client = kernel_client = self._kernel_manager.client()
+ kernel_client.start_channels()
+
+ def stop():
+ kernel_client.stop_channels()
+ kernel_manager.shutdown_kernel()
+ self.exit_requested.connect(stop)
+
+ def sizeHint(self):
+ """Return a reasonable default size for usage in :class:`PlotWindow`"""
+ return qt.QSize(500, 300)
+
+ def pushVariables(self, variable_dict):
+ """ Given a dictionary containing name / value pairs, push those
+ variables to the IPython console widget.
+
+ :param variable_dict: Dictionary of variables to be pushed to the
+ console's interactive namespace (```{variable_name: object, …}```)
+ """
+ self.kernel_manager.kernel.shell.push(variable_dict)
+
+
+class IPythonDockWidget(qt.QDockWidget):
+ """Dock Widget including a :class:`IPythonWidget` inside
+ a vertical layout.
+
+ .. image:: img/IPythonDockWidget.png
+
+ :param available_vars: Dictionary of variables to be pushed to the
+ console's interactive namespace: ``{"variable_name": object, …}``
+ :param custom_banner: Custom welcome message to be printed at the top of
+ the console
+ :param title: Dock widget title
+ :param parent: Parent :class:`qt.QMainWindow` containing this
+ :class:`qt.QDockWidget`
+ """
+ def __init__(self, parent=None, available_vars=None, custom_banner=None,
+ title="Console"):
+ super(IPythonDockWidget, self).__init__(title, parent)
+
+ self.ipyconsole = IPythonWidget(custom_banner=custom_banner)
+
+ self.layout().setContentsMargins(0, 0, 0, 0)
+ self.setWidget(self.ipyconsole)
+
+ if available_vars is not None:
+ self.ipyconsole.pushVariables(available_vars)
+
+ def showEvent(self, event):
+ """Make sure this widget is raised when it is shown
+ (when it is first created as a tab in PlotWindow or when it is shown
+ again after hiding).
+ """
+ self.raise_()
+
+
+def main():
+ """Run a Qt app with an IPython console"""
+ app = qt.QApplication([])
+ widget = IPythonDockWidget()
+ widget.show()
+ app.exec()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/src/silx/gui/data/ArrayTableModel.py b/src/silx/gui/data/ArrayTableModel.py
new file mode 100644
index 0000000..23b0bb2
--- /dev/null
+++ b/src/silx/gui/data/ArrayTableModel.py
@@ -0,0 +1,650 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+This module defines a data model for displaying and editing arrays of any
+number of dimensions in a table view.
+"""
+from __future__ import division
+import numpy
+import logging
+from silx.gui import qt
+from silx.gui.data.TextFormatter import TextFormatter
+
+__authors__ = ["V.A. Sole"]
+__license__ = "MIT"
+__date__ = "27/09/2017"
+
+
+_logger = logging.getLogger(__name__)
+
+
+def _is_array(data):
+ """Return True if object implements all necessary attributes to be used
+ as a numpy array.
+
+ :param object data: Array-like object (numpy array, h5py dataset...)
+ :return: boolean
+ """
+ # add more required attribute if necessary
+ for attr in ("shape", "dtype"):
+ if not hasattr(data, attr):
+ return False
+ return True
+
+
+class ArrayTableModel(qt.QAbstractTableModel):
+ """This data model provides access to 2D slices in a N-dimensional
+ array.
+
+ A slice for a 3-D array is characterized by a perspective (the number of
+ the axis orthogonal to the slice) and an index at which the slice
+ intersects the orthogonal axis.
+
+ In the n-D case, only slices parallel to the last two axes are handled. A
+ slice is therefore characterized by a list of indices locating the
+ slice on all the :math:`n - 2` orthogonal axes.
+
+ :param parent: Parent QObject
+ :param data: Numpy array, or object implementing a similar interface
+ (e.g. h5py dataset)
+ :param str fmt: Format string for representing numerical values.
+ Default is ``"%g"``.
+ :param sequence[int] perspective: See documentation
+ of :meth:`setPerspective`.
+ """
+
+ MAX_NUMBER_OF_SECTIONS = 10e6
+ """Maximum number of displayed rows and columns"""
+
+ def __init__(self, parent=None, data=None, perspective=None):
+ qt.QAbstractTableModel.__init__(self, parent)
+
+ self._array = None
+ """n-dimensional numpy array"""
+
+ self._bgcolors = None
+ """(n+1)-dimensional numpy array containing RGB(A) color data
+ for the background color
+ """
+
+ self._fgcolors = None
+ """(n+1)-dimensional numpy array containing RGB(A) color data
+ for the foreground color
+ """
+
+ self._formatter = None
+ """Formatter for text representation of data"""
+
+ formatter = TextFormatter(self)
+ formatter.setUseQuoteForText(False)
+ self.setFormatter(formatter)
+
+ self._index = None
+ """This attribute stores the slice index, as a list of indices
+ where the frame intersects orthogonal axis."""
+
+ self._perspective = None
+ """Sequence of dimensions orthogonal to the frame to be viewed.
+ For an array with ``n`` dimensions, this is a sequence of ``n-2``
+ integers. the first dimension is numbered ``0``.
+ By default, the data frames use the last two dimensions as their axes
+ and therefore the perspective is a sequence of the first ``n-2``
+ dimensions.
+ For example, for a 5-D array, the default perspective is ``(0, 1, 2)``
+ and the default frames axes are ``(3, 4)``."""
+
+ # set _data and _perspective
+ self.setArrayData(data, perspective=perspective)
+
+ def _getRowDim(self):
+ """The row axis is the first axis parallel to the frames
+ (lowest dimension number)
+
+ Return None for 0-D (scalar) or 1-D arrays
+ """
+ n_dimensions = len(self._array.shape)
+ if n_dimensions < 2:
+ # scalar or 1D array: no row index
+ return None
+ # take all dimensions and remove the orthogonal ones
+ frame_axes = set(range(0, n_dimensions)) - set(self._perspective)
+ # sanity check
+ assert len(frame_axes) == 2
+ return min(frame_axes)
+
+ def _getColumnDim(self):
+ """The column axis is the second (highest dimension) axis parallel
+ to the frames
+
+ Return None for 0-D (scalar)
+ """
+ n_dimensions = len(self._array.shape)
+ if n_dimensions < 1:
+ # scalar: no column index
+ return None
+ frame_axes = set(range(0, n_dimensions)) - set(self._perspective)
+ # sanity check
+ assert (len(frame_axes) == 2) if n_dimensions > 1 else (len(frame_axes) == 1)
+ return max(frame_axes)
+
+ def _getIndexTuple(self, table_row, table_col):
+ """Return the n-dimensional index of a value in the original array,
+ based on its row and column indices in the table view
+
+ :param table_row: Row index (0-based) of a table cell
+ :param table_col: Column index (0-based) of a table cell
+ :return: Tuple of indices of the element in the numpy array
+ """
+ row_dim = self._getRowDim()
+ col_dim = self._getColumnDim()
+
+ # get indices on all orthogonal axes
+ selection = list(self._index)
+ # insert indices on parallel axes
+ if row_dim is not None:
+ selection.insert(row_dim, table_row)
+ if col_dim is not None:
+ selection.insert(col_dim, table_col)
+ return tuple(selection)
+
+ # Methods to be implemented to subclass QAbstractTableModel
+ def rowCount(self, parent_idx=None):
+ """QAbstractTableModel method
+ Return number of rows to be displayed in table"""
+ row_dim = self._getRowDim()
+ if row_dim is None:
+ # 0-D and 1-D arrays
+ return 1
+ return min(self._array.shape[row_dim], self.MAX_NUMBER_OF_SECTIONS)
+
+ def columnCount(self, parent_idx=None):
+ """QAbstractTableModel method
+ Return number of columns to be displayed in table"""
+ col_dim = self._getColumnDim()
+ if col_dim is None:
+ # 0-D array
+ return 1
+ return min(self._array.shape[col_dim], self.MAX_NUMBER_OF_SECTIONS)
+
+ def __isClipped(self, orientation=qt.Qt.Vertical) -> bool:
+ """Returns whether or not array is clipped in a given orientation"""
+ if orientation == qt.Qt.Vertical:
+ dim = self._getRowDim()
+ else:
+ dim = self._getColumnDim()
+ return (dim is not None and
+ self._array.shape[dim] > self.MAX_NUMBER_OF_SECTIONS)
+
+ def __isClippedIndex(self, index) -> bool:
+ """Returns whether or not index's cell represents clipped data."""
+ if not index.isValid():
+ return False
+ if index.row() == self.MAX_NUMBER_OF_SECTIONS - 2:
+ return self.__isClipped(qt.Qt.Vertical)
+ if index.column() == self.MAX_NUMBER_OF_SECTIONS - 2:
+ return self.__isClipped(qt.Qt.Horizontal)
+ return False
+
+ def __clippedData(self, role=qt.Qt.DisplayRole):
+ """Return data for cells representing clipped data"""
+ if role == qt.Qt.DisplayRole:
+ return "..."
+ elif role == qt.Qt.ToolTipRole:
+ return "Dataset is too large: display is clipped"
+ else:
+ return None
+
+ def data(self, index, role=qt.Qt.DisplayRole):
+ """QAbstractTableModel method to access data values
+ in the format ready to be displayed"""
+ if index.isValid():
+ if self.__isClippedIndex(index): # Special displayed for clipped data
+ return self.__clippedData(role)
+
+ row, column = index.row(), index.column()
+
+ # When clipped, display last data of the array in last column of the table
+ if (self.__isClipped(qt.Qt.Vertical) and
+ row == self.MAX_NUMBER_OF_SECTIONS - 1):
+ row = self._array.shape[self._getRowDim()] - 1
+ if (self.__isClipped(qt.Qt.Horizontal) and
+ column == self.MAX_NUMBER_OF_SECTIONS - 1):
+ column = self._array.shape[self._getColumnDim()] - 1
+
+ selection = self._getIndexTuple(row, column)
+
+ if role == qt.Qt.DisplayRole:
+ return self._formatter.toString(self._array[selection], self._array.dtype)
+
+ if role == qt.Qt.BackgroundRole and self._bgcolors is not None:
+ r, g, b = self._bgcolors[selection][0:3]
+ if self._bgcolors.shape[-1] == 3:
+ return qt.QColor(r, g, b)
+ if self._bgcolors.shape[-1] == 4:
+ a = self._bgcolors[selection][3]
+ return qt.QColor(r, g, b, a)
+
+ if role == qt.Qt.ForegroundRole:
+ if self._fgcolors is not None:
+ r, g, b = self._fgcolors[selection][0:3]
+ if self._fgcolors.shape[-1] == 3:
+ return qt.QColor(r, g, b)
+ if self._fgcolors.shape[-1] == 4:
+ a = self._fgcolors[selection][3]
+ return qt.QColor(r, g, b, a)
+
+ # no fg color given, use black or white
+ # based on luminosity threshold
+ elif self._bgcolors is not None:
+ r, g, b = self._bgcolors[selection][0:3]
+ lum = 0.21 * r + 0.72 * g + 0.07 * b
+ if lum < 128:
+ return qt.QColor(qt.Qt.white)
+ else:
+ return qt.QColor(qt.Qt.black)
+
+ def headerData(self, section, orientation, role=qt.Qt.DisplayRole):
+ """QAbstractTableModel method
+ Return the 0-based row or column index, for display in the
+ horizontal and vertical headers"""
+ if self.__isClipped(orientation): # Header is clipped
+ if section == self.MAX_NUMBER_OF_SECTIONS - 2:
+ # Represent clipped data
+ return self.__clippedData(role)
+
+ elif section == self.MAX_NUMBER_OF_SECTIONS - 1:
+ # Display last index from data not table
+ if role == qt.Qt.DisplayRole:
+ if orientation == qt.Qt.Vertical:
+ dim = self._getRowDim()
+ else:
+ dim = self._getColumnDim()
+ return str(self._array.shape[dim] - 1)
+ else:
+ return None
+
+ if role == qt.Qt.DisplayRole:
+ return "%d" % section
+ return None
+
+ def flags(self, index):
+ """QAbstractTableModel method to inform the view whether data
+ is editable or not."""
+ if not self._editable or self.__isClippedIndex(index):
+ return qt.QAbstractTableModel.flags(self, index)
+ return qt.QAbstractTableModel.flags(self, index) | qt.Qt.ItemIsEditable
+
+ def setData(self, index, value, role=None):
+ """QAbstractTableModel method to handle editing data.
+ Cast the new value into the same format as the array before editing
+ the array value."""
+ if index.isValid() and role == qt.Qt.EditRole:
+ try:
+ # cast value to same type as array
+ v = numpy.array(value, dtype=self._array.dtype).item()
+ except ValueError:
+ return False
+
+ selection = self._getIndexTuple(index.row(),
+ index.column())
+ self._array[selection] = v
+ self.dataChanged.emit(index, index)
+ return True
+ else:
+ return False
+
+ # Public methods
+ def setArrayData(self, data, copy=True,
+ perspective=None, editable=False):
+ """Set the data array and the viewing perspective.
+
+ You can set ``copy=False`` if you need more performances, when dealing
+ with a large numpy array. In this case, a simple reference to the data
+ is used to access the data, rather than a copy of the array.
+
+ .. warning::
+
+ Any change to the data model will affect your original data
+ array, when using a reference rather than a copy..
+
+ :param data: n-dimensional numpy array, or any object that can be
+ converted to a numpy array using ``numpy.array(data)`` (e.g.
+ a nested sequence).
+ :param bool copy: If *True* (default), a copy of the array is stored
+ and the original array is not modified if the table is edited.
+ If *False*, then the behavior depends on the data type:
+ if possible (if the original array is a proper numpy array)
+ a reference to the original array is used.
+ :param perspective: See documentation of :meth:`setPerspective`.
+ If None, the default perspective is the list of the first ``n-2``
+ dimensions, to view frames parallel to the last two axes.
+ :param bool editable: Flag to enable editing data. Default *False*.
+ """
+ self.beginResetModel()
+
+ if data is None:
+ # empty array
+ self._array = numpy.array([])
+ elif copy:
+ # copy requested (default)
+ self._array = numpy.array(data, copy=True)
+ if hasattr(data, "dtype"):
+ # Avoid to lose the monkey-patched h5py dtype
+ self._array.dtype = data.dtype
+ elif not _is_array(data):
+ raise TypeError("data is not a proper array. Try setting" +
+ " copy=True to convert it into a numpy array" +
+ " (this will cause the data to be copied!)")
+ # # copy not requested, but necessary
+ # _logger.warning(
+ # "data is not an array-like object. " +
+ # "Data must be copied.")
+ # self._array = numpy.array(data, copy=True)
+ else:
+ # Copy explicitly disabled & data implements required attributes.
+ # We can use a reference.
+ self._array = data
+
+ # reset colors to None if new data shape is inconsistent
+ valid_color_shapes = (self._array.shape + (3,),
+ self._array.shape + (4,))
+ if self._bgcolors is not None:
+ if self._bgcolors.shape not in valid_color_shapes:
+ self._bgcolors = None
+ if self._fgcolors is not None:
+ if self._fgcolors.shape not in valid_color_shapes:
+ self._fgcolors = None
+
+ self.setEditable(editable)
+
+ self._index = [0 for _i in range((len(self._array.shape) - 2))]
+ self._perspective = tuple(perspective) if perspective is not None else\
+ tuple(range(0, len(self._array.shape) - 2))
+
+ self.endResetModel()
+
+ def setArrayColors(self, bgcolors=None, fgcolors=None):
+ """Set the colors for all table cells by passing an array
+ of RGB or RGBA values (integers between 0 and 255).
+
+ The shape of the colors array must be consistent with the data shape.
+
+ If the data array is n-dimensional, the colors array must be
+ (n+1)-dimensional, with the first n-dimensions identical to the data
+ array dimensions, and the last dimension length-3 (RGB) or
+ length-4 (RGBA).
+
+ :param bgcolors: RGB or RGBA colors array, defining the background color
+ for each cell in the table.
+ :param fgcolors: RGB or RGBA colors array, defining the foreground color
+ (text color) for each cell in the table.
+ """
+ # array must be RGB or RGBA
+ valid_shapes = (self._array.shape + (3,), self._array.shape + (4,))
+ errmsg = "Inconsistent shape for color array, should be %s or %s" % valid_shapes
+
+ if bgcolors is not None:
+ if not _is_array(bgcolors):
+ bgcolors = numpy.array(bgcolors)
+ assert bgcolors.shape in valid_shapes, errmsg
+
+ self._bgcolors = bgcolors
+
+ if fgcolors is not None:
+ if not _is_array(fgcolors):
+ fgcolors = numpy.array(fgcolors)
+ assert fgcolors.shape in valid_shapes, errmsg
+
+ self._fgcolors = fgcolors
+
+ def setEditable(self, editable):
+ """Set flags to make the data editable.
+
+ .. warning::
+
+ If the data is a reference to a h5py dataset open in read-only
+ mode, setting *editable=True* will fail and print a warning.
+
+ .. warning::
+
+ Making the data editable means that the underlying data structure
+ in this data model will be modified.
+ If the data is a reference to a public object (open with
+ ``copy=False``), this could have side effects. If it is a
+ reference to an HDF5 dataset, this means the file will be
+ modified.
+
+ :param bool editable: Flag to enable editing data.
+ :return: True if setting desired flag succeeded, False if it failed.
+ """
+ self._editable = editable
+ if hasattr(self._array, "file"):
+ if hasattr(self._array.file, "mode"):
+ if editable and self._array.file.mode == "r":
+ _logger.warning(
+ "Data is a HDF5 dataset open in read-only " +
+ "mode. Editing must be disabled.")
+ self._editable = False
+ return False
+ return True
+
+ def getData(self, copy=True):
+ """Return a copy of the data array, or a reference to it
+ if *copy=False* is passed as parameter.
+
+ In case the shape was modified, to convert 0-D or 1-D data
+ into 2-D data, the original shape is restored in the returned data.
+
+ :param bool copy: If *True* (default), return a copy of the data. If
+ *False*, return a reference.
+ :return: numpy array of data, or reference to original data object
+ if *copy=False*
+ """
+ data = self._array if not copy else numpy.array(self._array, copy=True)
+ return data
+
+ def setFrameIndex(self, index):
+ """Set the active slice index.
+
+ This method is only relevant to arrays with at least 3 dimensions.
+
+ :param index: Index of the active slice in the array.
+ In the general n-D case, this is a sequence of :math:`n - 2`
+ indices where the slice intersects the respective orthogonal axes.
+ :raise IndexError: If any index in the index sequence is out of bound
+ on its respective axis.
+ """
+ shape = self._array.shape
+ if len(shape) < 3:
+ # index is ignored
+ return
+
+ self.beginResetModel()
+
+ if len(shape) == 3:
+ len_ = shape[self._perspective[0]]
+ # accept integers as index in the case of 3-D arrays
+ if not hasattr(index, "__len__"):
+ self._index = [index]
+ else:
+ self._index = index
+ if not 0 <= self._index[0] < len_:
+ raise ValueError("Index must be a positive integer " +
+ "lower than %d" % len_)
+ else:
+ # general n-D case
+ for i_, idx in enumerate(index):
+ if not 0 <= idx < shape[self._perspective[i_]]:
+ raise IndexError("Invalid index %d " % idx +
+ "not in range 0-%d" % (shape[i_] - 1))
+ self._index = index
+
+ self.endResetModel()
+
+ def setFormatter(self, formatter):
+ """Set the formatter object to be used to display data from the model
+
+ :param TextFormatter formatter: Formatter to use
+ """
+ if formatter is self._formatter:
+ return
+
+ self.beginResetModel()
+
+ if self._formatter is not None:
+ self._formatter.formatChanged.disconnect(self.__formatChanged)
+
+ self._formatter = formatter
+ if self._formatter is not None:
+ self._formatter.formatChanged.connect(self.__formatChanged)
+
+ self.endResetModel()
+
+ def getFormatter(self):
+ """Returns the text formatter used.
+
+ :rtype: TextFormatter
+ """
+ return self._formatter
+
+ def __formatChanged(self):
+ """Called when the format changed.
+ """
+ self.reset()
+
+ def setPerspective(self, perspective):
+ """Set the perspective by defining a sequence listing all axes
+ orthogonal to the frame or 2-D slice to be visualized.
+
+ Alternatively, you can use :meth:`setFrameAxes` for the complementary
+ approach of specifying the two axes parallel to the frame.
+
+ In the 1-D or 2-D case, this parameter is irrelevant.
+
+ In the 3-D case, if the unit vectors describing
+ your axes are :math:`\vec{x}, \vec{y}, \vec{z}`, a perspective of 0
+ means you slices are parallel to :math:`\vec{y}\vec{z}`, 1 means they
+ are parallel to :math:`\vec{x}\vec{z}` and 2 means they
+ are parallel to :math:`\vec{x}\vec{y}`.
+
+ In the n-D case, this parameter is a sequence of :math:`n-2` axes
+ numbers.
+ For instance if you want to display 2-D frames whose axes are the
+ second and third dimensions of a 5-D array, set the perspective to
+ ``(0, 3, 4)``.
+
+ :param perspective: Sequence of dimensions/axes orthogonal to the
+ frames.
+ :raise: IndexError if any value in perspective is higher than the
+ number of dimensions minus one (first dimension is 0), or
+ if the number of values is different from the number of dimensions
+ minus two.
+ """
+ n_dimensions = len(self._array.shape)
+ if n_dimensions < 3:
+ _logger.warning(
+ "perspective is not relevant for 1D and 2D arrays")
+ return
+
+ if not hasattr(perspective, "__len__"):
+ # we can tolerate an integer for 3-D array
+ if n_dimensions == 3:
+ perspective = [perspective]
+ else:
+ raise ValueError("perspective must be a sequence of integers")
+
+ # ensure unicity of dimensions in perspective
+ perspective = tuple(set(perspective))
+
+ if len(perspective) != n_dimensions - 2 or\
+ min(perspective) < 0 or max(perspective) >= n_dimensions:
+ raise IndexError(
+ "Invalid perspective " + str(perspective) +
+ " for %d-D array " % n_dimensions +
+ "with shape " + str(self._array.shape))
+
+ self.beginResetModel()
+
+ self._perspective = perspective
+
+ # reset index
+ self._index = [0 for _i in range(n_dimensions - 2)]
+
+ self.endResetModel()
+
+ def setFrameAxes(self, row_axis, col_axis):
+ """Set the perspective by specifying the two axes parallel to the frame
+ to be visualised.
+
+ The complementary approach of defining the orthogonal axes can be used
+ with :meth:`setPerspective`.
+
+ :param int row_axis: Index (0-based) of the first dimension used as a frame
+ axis
+ :param int col_axis: Index (0-based) of the 2nd dimension used as a frame
+ axis
+ :raise: IndexError if axes are invalid
+ """
+ if row_axis > col_axis:
+ _logger.warning("The dimension of the row axis must be lower " +
+ "than the dimension of the column axis. Swapping.")
+ row_axis, col_axis = min(row_axis, col_axis), max(row_axis, col_axis)
+
+ n_dimensions = len(self._array.shape)
+ if n_dimensions < 3:
+ _logger.warning(
+ "Frame axes cannot be changed for 1D and 2D arrays")
+ return
+
+ perspective = tuple(set(range(0, n_dimensions)) - {row_axis, col_axis})
+
+ if len(perspective) != n_dimensions - 2 or\
+ min(perspective) < 0 or max(perspective) >= n_dimensions:
+ raise IndexError(
+ "Invalid perspective " + str(perspective) +
+ " for %d-D array " % n_dimensions +
+ "with shape " + str(self._array.shape))
+
+ self.beginResetModel()
+
+ self._perspective = perspective
+ # reset index
+ self._index = [0 for _i in range(n_dimensions - 2)]
+
+ self.endResetModel()
+
+
+if __name__ == "__main__":
+ app = qt.QApplication([])
+ w = qt.QTableView()
+ d = numpy.random.normal(0, 1, (5, 1000, 1000))
+ for i in range(5):
+ d[i, :, :] += i * 10
+ m = ArrayTableModel(data=d)
+ w.setModel(m)
+ m.setFrameIndex(3)
+ # m.setArrayData(numpy.ones((100,)))
+ w.show()
+ app.exec()
diff --git a/src/silx/gui/data/ArrayTableWidget.py b/src/silx/gui/data/ArrayTableWidget.py
new file mode 100644
index 0000000..baef5f4
--- /dev/null
+++ b/src/silx/gui/data/ArrayTableWidget.py
@@ -0,0 +1,492 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module defines a widget designed to display data arrays with any
+number of dimensions as 2D frames (images, slices) in a table view.
+The dimensions not displayed in the table can be browsed using improved
+sliders.
+
+The widget uses a TableView that relies on a custom abstract item
+model: :class:`silx.gui.data.ArrayTableModel`.
+"""
+from __future__ import division
+import sys
+
+from silx.gui import qt
+from silx.gui.widgets.TableWidget import TableView
+from .ArrayTableModel import ArrayTableModel
+from silx.gui.widgets.FrameBrowser import HorizontalSliderWithBrowser
+
+__authors__ = ["V.A. Sole", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "24/01/2017"
+
+
+class AxesSelector(qt.QWidget):
+ """Widget with two combo-boxes to select two dimensions among
+ all possible dimensions of an n-dimensional array.
+
+ The first combobox contains values from :math:`0` to :math:`n-2`.
+
+ The choices in the 2nd CB depend on the value selected in the first one.
+ If the value selected in the first CB is :math:`m`, the second one lets you
+ select values from :math:`m+1` to :math:`n-1`.
+
+ The two axes can be used to select the row axis and the column axis t
+ display a slice of the array data in a table view.
+ """
+ sigDimensionsChanged = qt.Signal(int, int)
+ """Signal emitted whenever one of the comboboxes is changed.
+ The signal carries the two selected dimensions."""
+
+ def __init__(self, parent=None, n=None):
+ qt.QWidget.__init__(self, parent)
+ self.layout = qt.QHBoxLayout(self)
+ self.layout.setContentsMargins(0, 2, 0, 2)
+ self.layout.setSpacing(10)
+
+ self.rowsCB = qt.QComboBox(self)
+ self.columnsCB = qt.QComboBox(self)
+
+ self.layout.addWidget(qt.QLabel("Rows dimension", self))
+ self.layout.addWidget(self.rowsCB)
+ self.layout.addWidget(qt.QLabel(" ", self))
+ self.layout.addWidget(qt.QLabel("Columns dimension", self))
+ self.layout.addWidget(self.columnsCB)
+ self.layout.addStretch(1)
+
+ self._slotsAreConnected = False
+ if n is not None:
+ self.setNDimensions(n)
+
+ def setNDimensions(self, n):
+ """Initialize combo-boxes depending on number of dimensions of array.
+ Initially, the rows dimension is the second-to-last one, and the
+ columns dimension is the last one.
+
+ Link the CBs together. MAke them emit a signal when their value is
+ changed.
+
+ :param int n: Number of dimensions of array
+ """
+ # remember the number of dimensions and the rows dimension
+ self.n = n
+ self._rowsDim = n - 2
+
+ # ensure slots are disconnected before (re)initializing widget
+ if self._slotsAreConnected:
+ self.rowsCB.currentIndexChanged.disconnect(self._rowDimChanged)
+ self.columnsCB.currentIndexChanged.disconnect(self._colDimChanged)
+
+ self._clear()
+ self.rowsCB.addItems([str(i) for i in range(n - 1)])
+ self.rowsCB.setCurrentIndex(n - 2)
+ if n >= 1:
+ self.columnsCB.addItem(str(n - 1))
+ self.columnsCB.setCurrentIndex(0)
+
+ # reconnect slots
+ self.rowsCB.currentIndexChanged.connect(self._rowDimChanged)
+ self.columnsCB.currentIndexChanged.connect(self._colDimChanged)
+ self._slotsAreConnected = True
+
+ # emit new dimensions
+ if n > 2:
+ self.sigDimensionsChanged.emit(n - 2, n - 1)
+
+ def setDimensions(self, row_dim, col_dim):
+ """Set the rows and columns dimensions.
+
+ The rows dimension must be lower than the columns dimension.
+
+ :param int row_dim: Rows dimension
+ :param int col_dim: Columns dimension
+ """
+ if row_dim >= col_dim:
+ raise IndexError("Row dimension must be lower than column dimension")
+ if not (0 <= row_dim < self.n - 1):
+ raise IndexError("Row dimension must be between 0 and %d" % (self.n - 2))
+ if not (row_dim < col_dim <= self.n - 1):
+ raise IndexError("Col dimension must be between %d and %d" % (row_dim + 1, self.n - 1))
+
+ # set the rows dimension; this triggers an update of columnsCB
+ self.rowsCB.setCurrentIndex(row_dim)
+ # columnsCB first item is "row_dim + 1". So index of "col_dim" is
+ # col_dim - (row_dim + 1)
+ self.columnsCB.setCurrentIndex(col_dim - row_dim - 1)
+
+ def getDimensions(self):
+ """Return a 2-tuple of the rows dimension and the columns dimension.
+
+ :return: 2-tuple of axes numbers (row_dimension, col_dimension)
+ """
+ return self._getRowDim(), self._getColDim()
+
+ def _clear(self):
+ """Empty the combo-boxes"""
+ self.rowsCB.clear()
+ self.columnsCB.clear()
+
+ def _getRowDim(self):
+ """Get rows dimension, selected in :attr:`rowsCB`
+ """
+ # rows combobox contains elements "0", ..."n-2",
+ # so the selected dim is always equal to the index
+ return self.rowsCB.currentIndex()
+
+ def _getColDim(self):
+ """Get columns dimension, selected in :attr:`columnsCB`"""
+ # columns combobox contains elements "row_dim+1", "row_dim+2", ..., "n-1"
+ # so the selected dim is equal to row_dim + 1 + index
+ return self._rowsDim + 1 + self.columnsCB.currentIndex()
+
+ def _rowDimChanged(self):
+ """Update columns combobox when the rows dimension is changed.
+
+ Emit :attr:`sigDimensionsChanged`"""
+ old_col_dim = self._getColDim()
+ new_row_dim = self._getRowDim()
+
+ # clear cols CB
+ self.columnsCB.currentIndexChanged.disconnect(self._colDimChanged)
+ self.columnsCB.clear()
+ # refill cols CB
+ for i in range(new_row_dim + 1, self.n):
+ self.columnsCB.addItem(str(i))
+
+ # keep previous col dimension if possible
+ new_col_cb_idx = old_col_dim - (new_row_dim + 1)
+ if new_col_cb_idx < 0:
+ # if row_dim is now greater than the previous col_dim,
+ # we select a new col_dim = row_dim + 1 (first element in cols CB)
+ new_col_cb_idx = 0
+ self.columnsCB.setCurrentIndex(new_col_cb_idx)
+
+ # reconnect slot
+ self.columnsCB.currentIndexChanged.connect(self._colDimChanged)
+
+ self._rowsDim = new_row_dim
+
+ self.sigDimensionsChanged.emit(self._getRowDim(), self._getColDim())
+
+ def _colDimChanged(self):
+ """Emit :attr:`sigDimensionsChanged`"""
+ self.sigDimensionsChanged.emit(self._getRowDim(), self._getColDim())
+
+
+def _get_shape(array_like):
+ """Return shape of an array like object.
+
+ In case the object is a nested sequence (list of lists, tuples...),
+ the size of each dimension is assumed to be uniform, and is deduced from
+ the length of the first sequence.
+
+ :param array_like: Array like object: numpy array, hdf5 dataset,
+ multi-dimensional sequence
+ :return: Shape of array, as a tuple of integers
+ """
+ if hasattr(array_like, "shape"):
+ return array_like.shape
+
+ shape = []
+ subsequence = array_like
+ while hasattr(subsequence, "__len__"):
+ shape.append(len(subsequence))
+ subsequence = subsequence[0]
+
+ return tuple(shape)
+
+
+class ArrayTableWidget(qt.QWidget):
+ """This widget is designed to display data of 2D frames (images, slices)
+ in a table view. The widget can load any n-dimensional array, and display
+ any 2-D frame/slice in the array.
+
+ The index of the dimensions orthogonal to the displayed frame can be set
+ interactively using a browser widget (sliders, buttons and text entries).
+
+ To set the data, use :meth:`setArrayData`.
+ To select the perspective, use :meth:`setPerspective` or
+ use :meth:`setFrameAxes`.
+ To select the frame, use :meth:`setFrameIndex`.
+
+ .. image:: img/ArrayTableWidget.png
+ """
+ def __init__(self, parent=None):
+ """
+
+ :param parent: parent QWidget
+ :param labels: list of labels for each dimension of the array
+ """
+ qt.QWidget.__init__(self, parent)
+ self.mainLayout = qt.QVBoxLayout(self)
+ self.mainLayout.setContentsMargins(0, 0, 0, 0)
+ self.mainLayout.setSpacing(0)
+
+ self.browserContainer = qt.QWidget(self)
+ self.browserLayout = qt.QGridLayout(self.browserContainer)
+ self.browserLayout.setContentsMargins(0, 0, 0, 0)
+ self.browserLayout.setSpacing(0)
+
+ self._dimensionLabelsText = []
+ """List of text labels sorted in the increasing order of the dimension
+ they apply to."""
+ self._browserLabels = []
+ """List of QLabel widgets."""
+ self._browserWidgets = []
+ """List of HorizontalSliderWithBrowser widgets."""
+
+ self.axesSelector = AxesSelector(self)
+
+ self.view = TableView(self)
+
+ self.mainLayout.addWidget(self.browserContainer)
+ self.mainLayout.addWidget(self.axesSelector)
+ self.mainLayout.addWidget(self.view)
+
+ self.model = ArrayTableModel(self)
+ self.view.setModel(self.model)
+
+ def setArrayData(self, data, labels=None, copy=True, editable=False):
+ """Set the data array. Update frame browsers and labels.
+
+ :param data: Numpy array or similar object (e.g. nested sequence,
+ h5py dataset...)
+ :param labels: list of labels for each dimension of the array, or
+ boolean ``True`` to use default labels ("dimension 0",
+ "dimension 1", ...). `None` to disable labels (default).
+ :param bool copy: If *True*, store a copy of *data* in the model. If
+ *False*, store a reference to *data* if possible (only possible if
+ *data* is a proper numpy array or an object that implements the
+ same methods).
+ :param bool editable: Flag to enable editing data. Default is *False*
+ """
+ self._data_shape = _get_shape(data)
+
+ n_widgets = len(self._browserWidgets)
+ n_dimensions = len(self._data_shape)
+
+ # Reset text of labels
+ self._dimensionLabelsText = []
+ for i in range(n_dimensions):
+ if labels in [True, 1]:
+ label_text = "Dimension %d" % i
+ elif labels is None or i >= len(labels):
+ label_text = ""
+ else:
+ label_text = labels[i]
+ self._dimensionLabelsText.append(label_text)
+
+ # not enough widgets, create new ones (we need n_dim - 2)
+ for i in range(n_widgets, n_dimensions - 2):
+ browser = HorizontalSliderWithBrowser(self.browserContainer)
+ self.browserLayout.addWidget(browser, i, 1)
+ self._browserWidgets.append(browser)
+ browser.valueChanged.connect(self._browserSlot)
+ browser.setEnabled(False)
+ browser.hide()
+
+ label = qt.QLabel(self.browserContainer)
+ self._browserLabels.append(label)
+ self.browserLayout.addWidget(label, i, 0)
+ label.hide()
+
+ n_widgets = len(self._browserWidgets)
+ for i in range(n_widgets):
+ label = self._browserLabels[i]
+ browser = self._browserWidgets[i]
+
+ if (i + 2) < n_dimensions:
+ label.setText(self._dimensionLabelsText[i])
+ browser.setRange(0, self._data_shape[i] - 1)
+ browser.setEnabled(True)
+ browser.show()
+ if labels is not None:
+ label.show()
+ else:
+ label.hide()
+ else:
+ browser.setEnabled(False)
+ browser.hide()
+ label.hide()
+
+ # set model
+ self.model.setArrayData(data, copy=copy, editable=editable)
+ # some linux distributions need this call
+ self.view.setModel(self.model)
+ if editable:
+ self.view.enableCut()
+ self.view.enablePaste()
+
+ # initialize & connect axesSelector
+ self.axesSelector.setNDimensions(n_dimensions)
+ self.axesSelector.sigDimensionsChanged.connect(self.setFrameAxes)
+
+ def setArrayColors(self, bgcolors=None, fgcolors=None):
+ """Set the colors for all table cells by passing an array
+ of RGB or RGBA values (integers between 0 and 255).
+
+ The shape of the colors array must be consistent with the data shape.
+
+ If the data array is n-dimensional, the colors array must be
+ (n+1)-dimensional, with the first n-dimensions identical to the data
+ array dimensions, and the last dimension length-3 (RGB) or
+ length-4 (RGBA).
+
+ :param bgcolors: RGB or RGBA colors array, defining the background color
+ for each cell in the table.
+ :param fgcolors: RGB or RGBA colors array, defining the foreground color
+ (text color) for each cell in the table.
+ """
+ self.model.setArrayColors(bgcolors, fgcolors)
+
+ def displayAxesSelector(self, isVisible):
+ """Allow to display or hide the axes selector.
+
+ :param bool isVisible: True to display the axes selector.
+ """
+ self.axesSelector.setVisible(isVisible)
+
+ def setFrameIndex(self, index):
+ """Set the active slice/image index in the n-dimensional array.
+
+ A frame is a 2D array extracted from an array. This frame is
+ necessarily parallel to 2 axes, and orthogonal to all other axes.
+
+ The index of a frame is a sequence of indices along the orthogonal
+ axes, where the frame intersects the respective axis. The indices
+ are listed in the same order as the corresponding dimensions of the
+ data array.
+
+ For example, it the data array has 5 dimensions, and we are
+ considering frames whose parallel axes are the 2nd and 4th dimensions
+ of the array, the frame index will be a sequence of length 3
+ corresponding to the indices where the frame intersects the 1st, 3rd
+ and 5th axes.
+
+ :param index: Sequence of indices defining the active data slice in
+ a n-dimensional array. The sequence length is :math:`n-2`
+ :raise: IndexError if any index in the index sequence is out of bound
+ on its respective axis.
+ """
+ self.model.setFrameIndex(index)
+
+ def _resetBrowsers(self, perspective):
+ """Adjust limits for browsers based on the perspective and the
+ size of the corresponding dimensions. Reset the index to 0.
+ Update the dimension in the labels.
+
+ :param perspective: Sequence of axes/dimensions numbers (0-based)
+ defining the axes orthogonal to the frame.
+ """
+ # for 3D arrays we can accept an int rather than a 1-tuple
+ if not hasattr(perspective, "__len__"):
+ perspective = [perspective]
+
+ # perspective must be sorted
+ perspective = sorted(perspective)
+
+ n_dimensions = len(self._data_shape)
+ for i in range(n_dimensions - 2):
+ browser = self._browserWidgets[i]
+ label = self._browserLabels[i]
+ browser.setRange(0, self._data_shape[perspective[i]] - 1)
+ browser.setValue(0)
+ label.setText(self._dimensionLabelsText[perspective[i]])
+
+ def setPerspective(self, perspective):
+ """Set the *perspective* by specifying which axes are orthogonal
+ to the frame.
+
+ For the opposite approach (defining parallel axes), use
+ :meth:`setFrameAxes` instead.
+
+ :param perspective: Sequence of unique axes numbers (0-based) defining
+ the orthogonal axes. For a n-dimensional array, the sequence
+ length is :math:`n-2`. The order is of the sequence is not taken
+ into account (the dimensions are displayed in increasing order
+ in the widget).
+ """
+ self.model.setPerspective(perspective)
+ self._resetBrowsers(perspective)
+
+ def setFrameAxes(self, row_axis, col_axis):
+ """Set the *perspective* by specifying which axes are parallel
+ to the frame.
+
+ For the opposite approach (defining orthogonal axes), use
+ :meth:`setPerspective` instead.
+
+ :param int row_axis: Index (0-based) of the first dimension used as a frame
+ axis
+ :param int col_axis: Index (0-based) of the 2nd dimension used as a frame
+ axis
+ """
+ self.model.setFrameAxes(row_axis, col_axis)
+ n_dimensions = len(self._data_shape)
+ perspective = tuple(set(range(0, n_dimensions)) - {row_axis, col_axis})
+ self._resetBrowsers(perspective)
+
+ def _browserSlot(self, value):
+ index = []
+ for browser in self._browserWidgets:
+ if browser.isEnabled():
+ index.append(browser.value())
+ self.setFrameIndex(index)
+ self.view.reset()
+
+ def getData(self, copy=True):
+ """Return a copy of the data array, or a reference to it if
+ *copy=False* is passed as parameter.
+
+ :param bool copy: If *True* (default), return a copy of the data. If
+ *False*, return a reference.
+ :return: Numpy array of data, or reference to original data object
+ if *copy=False*
+ """
+ return self.model.getData(copy=copy)
+
+
+def main():
+ import numpy
+ a = qt.QApplication([])
+ d = numpy.random.normal(0, 1, (4, 5, 1000, 1000))
+ for j in range(4):
+ for i in range(5):
+ d[j, i, :, :] += i + 10 * j
+ w = ArrayTableWidget()
+ if "2" in sys.argv:
+ print("sending a single image")
+ w.setArrayData(d[0, 0])
+ elif "3" in sys.argv:
+ print("sending 5 images")
+ w.setArrayData(d[0])
+ else:
+ print("sending 4 * 5 images ")
+ w.setArrayData(d, labels=True)
+ w.show()
+ a.exec()
+
+if __name__ == "__main__":
+ main()
diff --git a/src/silx/gui/data/DataViewer.py b/src/silx/gui/data/DataViewer.py
new file mode 100644
index 0000000..2e51439
--- /dev/null
+++ b/src/silx/gui/data/DataViewer.py
@@ -0,0 +1,593 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module defines a widget designed to display data using the most adapted
+view from the ones provided by silx.
+"""
+from __future__ import division
+
+import logging
+import os.path
+import collections
+from silx.gui import qt
+from silx.gui.data import DataViews
+from silx.gui.data.DataViews import _normalizeData
+from silx.gui.utils import blockSignals
+from silx.gui.data.NumpyAxesSelector import NumpyAxesSelector
+
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "12/02/2019"
+
+
+_logger = logging.getLogger(__name__)
+
+
+DataSelection = collections.namedtuple("DataSelection",
+ ["filename", "datapath",
+ "slice", "permutation"])
+
+
+class DataViewer(qt.QFrame):
+ """Widget to display any kind of data
+
+ .. image:: img/DataViewer.png
+
+ The method :meth:`setData` allows to set any data to the widget. Mostly
+ `numpy.array` and `h5py.Dataset` are supported with adapted views. Other
+ data types are displayed using a text viewer.
+
+ A default view is automatically selected when a data is set. The method
+ :meth:`setDisplayMode` allows to change the view. To have a graphical tool
+ to select the view, prefer using the widget :class:`DataViewerFrame`.
+
+ The dimension of the input data and the expected dimension of the selected
+ view can differ. For example you can display an image (2D) from 4D
+ data. In this case a :class:`NumpyAxesSelector` is displayed to allow the
+ user to select the axis mapping and the slicing of other axes.
+
+ .. code-block:: python
+
+ import numpy
+ data = numpy.random.rand(500,500)
+ viewer = DataViewer()
+ viewer.setData(data)
+ viewer.setVisible(True)
+ """
+
+ displayedViewChanged = qt.Signal(object)
+ """Emitted when the displayed view changes"""
+
+ dataChanged = qt.Signal()
+ """Emitted when the data changes"""
+
+ currentAvailableViewsChanged = qt.Signal()
+ """Emitted when the current available views (which support the current
+ data) change"""
+
+ def __init__(self, parent=None):
+ """Constructor
+
+ :param QWidget parent: The parent of the widget
+ """
+ super(DataViewer, self).__init__(parent)
+
+ self.__stack = qt.QStackedWidget(self)
+ self.__numpySelection = NumpyAxesSelector(self)
+ self.__numpySelection.selectedAxisChanged.connect(self.__numpyAxisChanged)
+ self.__numpySelection.selectionChanged.connect(self.__numpySelectionChanged)
+ self.__numpySelection.customAxisChanged.connect(self.__numpyCustomAxisChanged)
+
+ self.setLayout(qt.QVBoxLayout(self))
+ self.layout().addWidget(self.__stack, 1)
+
+ group = qt.QGroupBox(self)
+ group.setLayout(qt.QVBoxLayout())
+ group.layout().addWidget(self.__numpySelection)
+ group.setTitle("Axis selection")
+ self.__axisSelection = group
+
+ self.layout().addWidget(self.__axisSelection)
+
+ self.__currentAvailableViews = []
+ self.__currentView = None
+ self.__data = None
+ self.__info = None
+ self.__useAxisSelection = False
+ self.__userSelectedView = None
+ self.__hooks = None
+
+ self.__views = []
+ self.__index = {}
+ """store stack index for each views"""
+
+ self._initializeViews()
+
+ def _initializeViews(self):
+ """Inisialize the available views"""
+ views = self.createDefaultViews(self.__stack)
+ self.__views = list(views)
+ self.setDisplayMode(DataViews.EMPTY_MODE)
+
+ def setGlobalHooks(self, hooks):
+ """Set a data view hooks for all the views
+
+ :param DataViewHooks context: The hooks to use
+ """
+ self.__hooks = hooks
+ for v in self.__views:
+ v.setHooks(hooks)
+
+ def createDefaultViews(self, parent=None):
+ """Create and returns available views which can be displayed by default
+ by the data viewer. It is called internally by the widget. It can be
+ overwriten to provide a different set of viewers.
+
+ :param QWidget parent: QWidget parent of the views
+ :rtype: List[silx.gui.data.DataViews.DataView]
+ """
+ viewClasses = [
+ DataViews._EmptyView,
+ DataViews._Hdf5View,
+ DataViews._NXdataView,
+ DataViews._Plot1dView,
+ DataViews._ImageView,
+ DataViews._Plot3dView,
+ DataViews._RawView,
+ DataViews._StackView,
+ DataViews._Plot2dRecordView,
+ ]
+ views = []
+ for viewClass in viewClasses:
+ try:
+ view = viewClass(parent)
+ views.append(view)
+ except Exception:
+ _logger.warning("%s instantiation failed. View is ignored" % viewClass.__name__)
+ _logger.debug("Backtrace", exc_info=True)
+
+ return views
+
+ def clear(self):
+ """Clear the widget"""
+ self.setData(None)
+
+ def normalizeData(self, data):
+ """Returns a normalized data if the embed a numpy or a dataset.
+ Else returns the data."""
+ return _normalizeData(data)
+
+ def __getStackIndex(self, view):
+ """Get the stack index containing the view.
+
+ :param silx.gui.data.DataViews.DataView view: The view
+ """
+ if view not in self.__index:
+ widget = view.getWidget()
+ index = self.__stack.addWidget(widget)
+ self.__index[view] = index
+ else:
+ index = self.__index[view]
+ return index
+
+ def __clearCurrentView(self):
+ """Clear the current selected view"""
+ view = self.__currentView
+ if view is not None:
+ view.clear()
+
+ def __numpyCustomAxisChanged(self, name, value):
+ view = self.__currentView
+ if view is not None:
+ view.setCustomAxisValue(name, value)
+
+ def __updateNumpySelectionAxis(self):
+ """
+ Update the numpy-selector according to the needed axis names
+ """
+ with blockSignals(self.__numpySelection):
+ previousPermutation = self.__numpySelection.permutation()
+ previousSelection = self.__numpySelection.selection()
+
+ self.__numpySelection.clear()
+
+ info = self._getInfo()
+ axisNames = self.__currentView.axesNames(self.__data, info)
+ if (info.isArray and info.size != 0 and
+ self.__data is not None and axisNames is not None):
+ self.__useAxisSelection = True
+ self.__numpySelection.setAxisNames(axisNames)
+ self.__numpySelection.setCustomAxis(
+ self.__currentView.customAxisNames())
+ data = self.normalizeData(self.__data)
+ self.__numpySelection.setData(data)
+
+ # Try to restore previous permutation and selection
+ try:
+ self.__numpySelection.setSelection(
+ previousSelection, previousPermutation)
+ except ValueError as e:
+ _logger.info("Not restoring selection because: %s", e)
+
+ if hasattr(data, "shape"):
+ isVisible = not (len(axisNames) == 1 and len(data.shape) == 1)
+ else:
+ isVisible = True
+ self.__axisSelection.setVisible(isVisible)
+ else:
+ self.__useAxisSelection = False
+ self.__axisSelection.setVisible(False)
+
+ def __updateDataInView(self):
+ """
+ Update the views using the current data
+ """
+ if self.__useAxisSelection:
+ self.__displayedData = self.__numpySelection.selectedData()
+
+ permutation = self.__numpySelection.permutation()
+ normal = tuple(range(len(permutation)))
+ if permutation == normal:
+ permutation = None
+ slicing = self.__numpySelection.selection()
+ normal = tuple([slice(None)] * len(slicing))
+ if slicing == normal:
+ slicing = None
+ else:
+ self.__displayedData = self.__data
+ permutation = None
+ slicing = None
+
+ try:
+ filename = os.path.abspath(self.__data.file.filename)
+ except:
+ filename = None
+
+ try:
+ datapath = self.__data.name
+ except:
+ datapath = None
+
+ # FIXME: maybe use DataUrl, with added support of permutation
+ self.__displayedSelection = DataSelection(filename, datapath, slicing, permutation)
+
+ # TODO: would be good to avoid that, it should be synchonous
+ qt.QTimer.singleShot(10, self.__setDataInView)
+
+ def __setDataInView(self):
+ self.__currentView.setData(self.__displayedData)
+ self.__currentView.setDataSelection(self.__displayedSelection)
+
+ def setDisplayedView(self, view):
+ """Set the displayed view.
+
+ Change the displayed view according to the view itself.
+
+ :param silx.gui.data.DataViews.DataView view: The DataView to use to display the data
+ """
+ self.__userSelectedView = view
+ self._setDisplayedView(view)
+
+ def _setDisplayedView(self, view):
+ """Internal set of the displayed view.
+
+ Change the displayed view according to the view itself.
+
+ :param silx.gui.data.DataViews.DataView view: The DataView to use to display the data
+ """
+ if self.__currentView is view:
+ return
+ self.__clearCurrentView()
+ self.__currentView = view
+ self.__updateNumpySelectionAxis()
+ self.__updateDataInView()
+ stackIndex = self.__getStackIndex(self.__currentView)
+ if self.__currentView is not None:
+ self.__currentView.select()
+ self.__stack.setCurrentIndex(stackIndex)
+ self.displayedViewChanged.emit(view)
+
+ def getViewFromModeId(self, modeId):
+ """Returns the first available view which have the requested modeId.
+ Return None if modeId does not correspond to an existing view.
+
+ :param int modeId: Requested mode id
+ :rtype: silx.gui.data.DataViews.DataView
+ """
+ for view in self.__views:
+ if view.modeId() == modeId:
+ return view
+ return None
+
+ def setDisplayMode(self, modeId):
+ """Set the displayed view using display mode.
+
+ Change the displayed view according to the requested mode.
+
+ :param int modeId: Display mode, one of
+
+ - `DataViews.EMPTY_MODE`: display nothing
+ - `DataViews.PLOT1D_MODE`: display the data as a curve
+ - `DataViews.IMAGE_MODE`: display the data as an image
+ - `DataViews.PLOT3D_MODE`: display the data as an isosurface
+ - `DataViews.RAW_MODE`: display the data as a table
+ - `DataViews.STACK_MODE`: display the data as a stack of images
+ - `DataViews.HDF5_MODE`: display the data as a table of HDF5 info
+ - `DataViews.NXDATA_MODE`: display the data as NXdata
+ """
+ try:
+ view = self.getViewFromModeId(modeId)
+ except KeyError:
+ raise ValueError("Display mode %s is unknown" % modeId)
+ self._setDisplayedView(view)
+
+ def displayedView(self):
+ """Returns the current displayed view.
+
+ :rtype: silx.gui.data.DataViews.DataView
+ """
+ return self.__currentView
+
+ def addView(self, view):
+ """Allow to add a view to the dataview.
+
+ If the current data support this view, it will be displayed.
+
+ :param DataView view: A dataview
+ """
+ if self.__hooks is not None:
+ view.setHooks(self.__hooks)
+ self.__views.append(view)
+ # TODO It can be skipped if the view do not support the data
+ self.__updateAvailableViews()
+
+ def removeView(self, view):
+ """Allow to remove a view which was available from the dataview.
+
+ If the view was displayed, the widget will be updated.
+
+ :param DataView view: A dataview
+ """
+ self.__views.remove(view)
+ self.__stack.removeWidget(view.getWidget())
+ # invalidate the full index. It will be updated as expected
+ self.__index = {}
+
+ if self.__userSelectedView is view:
+ self.__userSelectedView = None
+
+ if view is self.__currentView:
+ self.__updateView()
+ else:
+ # TODO It can be skipped if the view is not part of the
+ # available views
+ self.__updateAvailableViews()
+
+ def __updateAvailableViews(self):
+ """
+ Update available views from the current data.
+ """
+ data = self.__data
+ info = self._getInfo()
+ # sort available views according to priority
+ views = []
+ for v in self.__views:
+ views.extend(v.getMatchingViews(data, info))
+ views = [(v.getCachedDataPriority(data, info), v) for v in views]
+ views = filter(lambda t: t[0] > DataViews.DataView.UNSUPPORTED, views)
+ views = sorted(views, reverse=True)
+ views = [v[1] for v in views]
+
+ # store available views
+ self.__setCurrentAvailableViews(views)
+
+ def __updateView(self):
+ """Display the data using the widget which fit the best"""
+ data = self.__data
+
+ # update available views for this data
+ self.__updateAvailableViews()
+ available = self.__currentAvailableViews
+
+ # display the view with the most priority (the default view)
+ view = self.getDefaultViewFromAvailableViews(data, available)
+ self.__clearCurrentView()
+ try:
+ self._setDisplayedView(view)
+ except Exception as e:
+ # in case there is a problem to read the data, try to use a safe
+ # view
+ view = self.getSafeViewFromAvailableViews(data, available)
+ self._setDisplayedView(view)
+ raise e
+
+ def getSafeViewFromAvailableViews(self, data, available):
+ """Returns a view which is sure to display something without failing
+ on rendering.
+
+ :param object data: data which will be displayed
+ :param List[view] available: List of available views, from highest
+ priority to lowest.
+ :rtype: DataView
+ """
+ hdf5View = self.getViewFromModeId(DataViews.HDF5_MODE)
+ if hdf5View in available:
+ return hdf5View
+ return self.getViewFromModeId(DataViews.EMPTY_MODE)
+
+ def getDefaultViewFromAvailableViews(self, data, available):
+ """Returns the default view which will be used according to available
+ views.
+
+ :param object data: data which will be displayed
+ :param List[view] available: List of available views, from highest
+ priority to lowest.
+ :rtype: DataView
+ """
+ if len(available) > 0:
+ # returns the view with the highest priority
+ if self.__userSelectedView in available:
+ return self.__userSelectedView
+ self.__userSelectedView = None
+ view = available[0]
+ else:
+ # else returns the empty view
+ view = self.getViewFromModeId(DataViews.EMPTY_MODE)
+ return view
+
+ def __setCurrentAvailableViews(self, availableViews):
+ """Set the current available viewa
+
+ :param List[DataView] availableViews: Current available viewa
+ """
+ self.__currentAvailableViews = availableViews
+ self.currentAvailableViewsChanged.emit()
+
+ def currentAvailableViews(self):
+ """Returns the list of available views for the current data
+
+ :rtype: List[DataView]
+ """
+ return self.__currentAvailableViews
+
+ def getReachableViews(self):
+ """Returns the list of reachable views from the registred available
+ views.
+
+ :rtype: List[DataView]
+ """
+ views = []
+ for v in self.availableViews():
+ views.extend(v.getReachableViews())
+ return views
+
+ def availableViews(self):
+ """Returns the list of registered views
+
+ :rtype: List[DataView]
+ """
+ return self.__views
+
+ def setData(self, data):
+ """Set the data to view.
+
+ It mostly can be a h5py.Dataset or a numpy.ndarray. Other kind of
+ objects will be displayed as text rendering.
+
+ :param numpy.ndarray data: The data.
+ """
+ self.__data = data
+ self._invalidateInfo()
+ self.__displayedData = None
+ self.__displayedSelection = None
+ self.__updateView()
+ self.__updateNumpySelectionAxis()
+ self.__updateDataInView()
+ self.dataChanged.emit()
+
+ def __numpyAxisChanged(self):
+ """
+ Called when axis selection of the numpy-selector changed
+ """
+ self.__clearCurrentView()
+
+ def __numpySelectionChanged(self):
+ """
+ Called when data selection of the numpy-selector changed
+ """
+ self.__updateDataInView()
+
+ def data(self):
+ """Returns the data"""
+ return self.__data
+
+ def _invalidateInfo(self):
+ """Invalidate DataInfo cache."""
+ self.__info = None
+
+ def _getInfo(self):
+ """Returns the DataInfo of the current selected data.
+
+ This value is cached.
+
+ :rtype: DataInfo
+ """
+ if self.__info is None:
+ self.__info = DataViews.DataInfo(self.__data)
+ return self.__info
+
+ def displayMode(self):
+ """Returns the current display mode"""
+ return self.__currentView.modeId()
+
+ def replaceView(self, modeId, newView):
+ """Replace one of the builtin data views with a custom view.
+ Return True in case of success, False in case of failure.
+
+ .. note::
+
+ This method must be called just after instantiation, before
+ the viewer is used.
+
+ :param int modeId: Unique mode ID identifying the DataView to
+ be replaced. One of:
+
+ - `DataViews.EMPTY_MODE`
+ - `DataViews.PLOT1D_MODE`
+ - `DataViews.IMAGE_MODE`
+ - `DataViews.PLOT2D_MODE`
+ - `DataViews.COMPLEX_IMAGE_MODE`
+ - `DataViews.PLOT3D_MODE`
+ - `DataViews.RAW_MODE`
+ - `DataViews.STACK_MODE`
+ - `DataViews.HDF5_MODE`
+ - `DataViews.NXDATA_MODE`
+ - `DataViews.NXDATA_INVALID_MODE`
+ - `DataViews.NXDATA_SCALAR_MODE`
+ - `DataViews.NXDATA_CURVE_MODE`
+ - `DataViews.NXDATA_XYVSCATTER_MODE`
+ - `DataViews.NXDATA_IMAGE_MODE`
+ - `DataViews.NXDATA_STACK_MODE`
+
+ :param DataViews.DataView newView: New data view
+ :return: True if replacement was successful, else False
+ """
+ assert isinstance(newView, DataViews.DataView)
+ isReplaced = False
+ for idx, view in enumerate(self.__views):
+ if view.modeId() == modeId:
+ if self.__hooks is not None:
+ newView.setHooks(self.__hooks)
+ self.__views[idx] = newView
+ isReplaced = True
+ break
+ elif isinstance(view, DataViews.CompositeDataView):
+ isReplaced = view.replaceView(modeId, newView)
+ if isReplaced:
+ break
+
+ if isReplaced:
+ self.__updateAvailableViews()
+ return isReplaced
diff --git a/src/silx/gui/data/DataViewerFrame.py b/src/silx/gui/data/DataViewerFrame.py
new file mode 100644
index 0000000..9bfb95b
--- /dev/null
+++ b/src/silx/gui/data/DataViewerFrame.py
@@ -0,0 +1,217 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module contains a DataViewer with a view selector.
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "12/02/2019"
+
+from silx.gui import qt
+from .DataViewer import DataViewer
+from .DataViewerSelector import DataViewerSelector
+
+
+class DataViewerFrame(qt.QWidget):
+ """
+ A :class:`DataViewer` with a view selector.
+
+ .. image:: img/DataViewerFrame.png
+
+ This widget provides the same API as :class:`DataViewer`. Therefore, for more
+ documentation, take a look at the documentation of the class
+ :class:`DataViewer`.
+
+ .. code-block:: python
+
+ import numpy
+ data = numpy.random.rand(500,500)
+ viewer = DataViewerFrame()
+ viewer.setData(data)
+ viewer.setVisible(True)
+
+ """
+
+ displayedViewChanged = qt.Signal(object)
+ """Emitted when the displayed view changes"""
+
+ dataChanged = qt.Signal()
+ """Emitted when the data changes"""
+
+ def __init__(self, parent=None):
+ """
+ Constructor
+
+ :param qt.QWidget parent:
+ """
+ super(DataViewerFrame, self).__init__(parent)
+
+ class _DataViewer(DataViewer):
+ """Overwrite methods to avoid to create views while the instance
+ is not created. `initializeViews` have to be called manually."""
+
+ def _initializeViews(self):
+ pass
+
+ def initializeViews(self):
+ """Avoid to create views while the instance is not created."""
+ super(_DataViewer, self)._initializeViews()
+
+ def _createDefaultViews(self, parent):
+ """Expose the original `createDefaultViews` function"""
+ return super(_DataViewer, self).createDefaultViews()
+
+ def createDefaultViews(self, parent=None):
+ """Allow the DataViewerFrame to override this function"""
+ return self.parent().createDefaultViews(parent)
+
+ self.__dataViewer = _DataViewer(self)
+ # initialize views when `self.__dataViewer` is set
+ self.__dataViewer.initializeViews()
+ self.__dataViewer.setFrameShape(qt.QFrame.StyledPanel)
+ self.__dataViewer.setFrameShadow(qt.QFrame.Sunken)
+ self.__dataViewerSelector = DataViewerSelector(self, self.__dataViewer)
+ self.__dataViewerSelector.setFlat(True)
+
+ layout = qt.QVBoxLayout(self)
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.setSpacing(0)
+ layout.addWidget(self.__dataViewer, 1)
+ layout.addWidget(self.__dataViewerSelector)
+ self.setLayout(layout)
+
+ self.__dataViewer.dataChanged.connect(self.__dataChanged)
+ self.__dataViewer.displayedViewChanged.connect(self.__displayedViewChanged)
+
+ def __dataChanged(self):
+ """Called when the data is changed"""
+ self.dataChanged.emit()
+
+ def __displayedViewChanged(self, view):
+ """Called when the displayed view changes"""
+ self.displayedViewChanged.emit(view)
+
+ def setGlobalHooks(self, hooks):
+ """Set a data view hooks for all the views
+
+ :param DataViewHooks context: The hooks to use
+ """
+ self.__dataViewer.setGlobalHooks(hooks)
+
+ def getReachableViews(self):
+ return self.__dataViewer.getReachableViews()
+
+ def availableViews(self):
+ """Returns the list of registered views
+
+ :rtype: List[DataView]
+ """
+ return self.__dataViewer.availableViews()
+
+ def currentAvailableViews(self):
+ """Returns the list of available views for the current data
+
+ :rtype: List[DataView]
+ """
+ return self.__dataViewer.currentAvailableViews()
+
+ def createDefaultViews(self, parent=None):
+ """Create and returns available views which can be displayed by default
+ by the data viewer. It is called internally by the widget. It can be
+ overwriten to provide a different set of viewers.
+
+ :param QWidget parent: QWidget parent of the views
+ :rtype: List[silx.gui.data.DataViews.DataView]
+ """
+ return self.__dataViewer._createDefaultViews(parent)
+
+ def addView(self, view):
+ """Allow to add a view to the dataview.
+
+ If the current data support this view, it will be displayed.
+
+ :param DataView view: A dataview
+ """
+ return self.__dataViewer.addView(view)
+
+ def removeView(self, view):
+ """Allow to remove a view which was available from the dataview.
+
+ If the view was displayed, the widget will be updated.
+
+ :param DataView view: A dataview
+ """
+ return self.__dataViewer.removeView(view)
+
+ def setData(self, data):
+ """Set the data to view.
+
+ It mostly can be a h5py.Dataset or a numpy.ndarray. Other kind of
+ objects will be displayed as text rendering.
+
+ :param numpy.ndarray data: The data.
+ """
+ self.__dataViewer.setData(data)
+
+ def data(self):
+ """Returns the data"""
+ return self.__dataViewer.data()
+
+ def setDisplayedView(self, view):
+ self.__dataViewer.setDisplayedView(view)
+
+ def displayedView(self):
+ return self.__dataViewer.displayedView()
+
+ def displayMode(self):
+ return self.__dataViewer.displayMode()
+
+ def setDisplayMode(self, modeId):
+ """Set the displayed view using display mode.
+
+ Change the displayed view according to the requested mode.
+
+ :param int modeId: Display mode, one of
+
+ - `EMPTY_MODE`: display nothing
+ - `PLOT1D_MODE`: display the data as a curve
+ - `PLOT2D_MODE`: display the data as an image
+ - `TEXT_MODE`: display the data as a text
+ - `ARRAY_MODE`: display the data as a table
+ """
+ return self.__dataViewer.setDisplayMode(modeId)
+
+ def getViewFromModeId(self, modeId):
+ """See :meth:`DataViewer.getViewFromModeId`"""
+ return self.__dataViewer.getViewFromModeId(modeId)
+
+ def replaceView(self, modeId, newView):
+ """Replace one of the builtin data views with a custom view.
+ See :meth:`DataViewer.replaceView` for more documentation.
+
+ :param DataViews.DataView newView: New data view
+ :return: True if replacement was successful, else False
+ """
+ return self.__dataViewer.replaceView(modeId, newView)
diff --git a/src/silx/gui/data/DataViewerSelector.py b/src/silx/gui/data/DataViewerSelector.py
new file mode 100644
index 0000000..a1e9947
--- /dev/null
+++ b/src/silx/gui/data/DataViewerSelector.py
@@ -0,0 +1,175 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module defines a widget to be able to select the available view
+of the DataViewer.
+"""
+from __future__ import division
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "12/02/2019"
+
+import weakref
+import functools
+from silx.gui import qt
+import silx.utils.weakref
+
+
+class DataViewerSelector(qt.QWidget):
+ """Widget to be able to select a custom view from the DataViewer"""
+
+ def __init__(self, parent=None, dataViewer=None):
+ """Constructor
+
+ :param QWidget parent: The parent of the widget
+ :param DataViewer dataViewer: The connected `DataViewer`
+ """
+ super(DataViewerSelector, self).__init__(parent)
+
+ self.__group = None
+ self.__buttons = {}
+ self.__buttonLayout = None
+ self.__buttonDummy = None
+ self.__dataViewer = None
+
+ # Create the fixed layout
+ self.setLayout(qt.QHBoxLayout())
+ layout = self.layout()
+ layout.setContentsMargins(0, 0, 0, 0)
+ self.__buttonLayout = qt.QHBoxLayout()
+ self.__buttonLayout.setContentsMargins(0, 0, 0, 0)
+ layout.addLayout(self.__buttonLayout)
+ layout.addStretch(1)
+
+ if dataViewer is not None:
+ self.setDataViewer(dataViewer)
+
+ def __updateButtons(self):
+ if self.__group is not None:
+ self.__group.deleteLater()
+
+ # Clean up
+ for _, b in self.__buttons.items():
+ b.deleteLater()
+ if self.__buttonDummy is not None:
+ self.__buttonDummy.deleteLater()
+ self.__buttonDummy = None
+ self.__buttons = {}
+ self.__buttonDummy = None
+
+ self.__group = qt.QButtonGroup(self)
+ if self.__dataViewer is None:
+ return
+
+ iconSize = qt.QSize(16, 16)
+
+ for view in self.__dataViewer.getReachableViews():
+ label = view.label()
+ icon = view.icon()
+ button = qt.QPushButton(label)
+ button.setIcon(icon)
+ button.setIconSize(iconSize)
+ button.setCheckable(True)
+ # the weak objects are needed to be able to destroy the widget safely
+ weakView = weakref.ref(view)
+ weakMethod = silx.utils.weakref.WeakMethodProxy(self.__setDisplayedView)
+ callback = functools.partial(weakMethod, weakView)
+ button.clicked.connect(callback)
+ self.__buttonLayout.addWidget(button)
+ self.__group.addButton(button)
+ self.__buttons[view] = button
+
+ button = qt.QPushButton("Dummy")
+ button.setCheckable(True)
+ button.setVisible(False)
+ self.__buttonLayout.addWidget(button)
+ self.__group.addButton(button)
+ self.__buttonDummy = button
+
+ self.__updateButtonsVisibility()
+ self.__displayedViewChanged(self.__dataViewer.displayedView())
+
+ def setDataViewer(self, dataViewer):
+ """Define the dataviewer connected to this status bar
+
+ :param DataViewer dataViewer: The connected `DataViewer`
+ """
+ if self.__dataViewer is dataViewer:
+ return
+ if self.__dataViewer is not None:
+ self.__dataViewer.dataChanged.disconnect(self.__updateButtonsVisibility)
+ self.__dataViewer.displayedViewChanged.disconnect(self.__displayedViewChanged)
+ self.__dataViewer = dataViewer
+ if self.__dataViewer is not None:
+ self.__dataViewer.dataChanged.connect(self.__updateButtonsVisibility)
+ self.__dataViewer.displayedViewChanged.connect(self.__displayedViewChanged)
+ self.__updateButtons()
+
+ def setFlat(self, isFlat):
+ """Set the flat state of all the buttons.
+
+ :param bool isFlat: True to display the buttons flatten.
+ """
+ for b in self.__buttons.values():
+ b.setFlat(isFlat)
+ self.__buttonDummy.setFlat(isFlat)
+
+ def __displayedViewChanged(self, view):
+ """Called on displayed view changes"""
+ selectedButton = self.__buttons.get(view, self.__buttonDummy)
+ selectedButton.setChecked(True)
+
+ def __setDisplayedView(self, refView, clickEvent=None):
+ """Display a data using the requested view
+
+ :param DataView view: Requested view
+ :param clickEvent: Event sent by the clicked event
+ """
+ if self.__dataViewer is None:
+ return
+ view = refView()
+ if view is None:
+ return
+ self.__dataViewer.setDisplayedView(view)
+
+ def __checkAvailableButtons(self):
+ views = set(self.__dataViewer.getReachableViews())
+ if views == set(self.__buttons.keys()):
+ return
+ # Recreate all the buttons
+ # TODO: We dont have to create everything again
+ # We expect the views stay quite stable
+ self.__updateButtons()
+
+ def __updateButtonsVisibility(self):
+ """Called on data changed"""
+ if self.__dataViewer is None:
+ for b in self.__buttons.values():
+ b.setVisible(False)
+ else:
+ self.__checkAvailableButtons()
+ availableViews = set(self.__dataViewer.currentAvailableViews())
+ for view, button in self.__buttons.items():
+ button.setVisible(view in availableViews)
diff --git a/src/silx/gui/data/DataViews.py b/src/silx/gui/data/DataViews.py
new file mode 100644
index 0000000..b18a813
--- /dev/null
+++ b/src/silx/gui/data/DataViews.py
@@ -0,0 +1,2059 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module defines a views used by :class:`silx.gui.data.DataViewer`.
+"""
+
+from collections import OrderedDict
+import logging
+import numbers
+import numpy
+import os
+
+import silx.io
+from silx.utils import deprecation
+from silx.gui import qt, icons
+from silx.gui.data.TextFormatter import TextFormatter
+from silx.io import nxdata
+from silx.gui.hdf5 import H5Node
+from silx.io.nxdata import get_attr_as_unicode
+from silx.gui.colors import Colormap
+from silx.gui.dialog.ColormapDialog import ColormapDialog
+
+__authors__ = ["V. Valls", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "19/02/2019"
+
+_logger = logging.getLogger(__name__)
+
+
+# DataViewer modes
+EMPTY_MODE = 0
+PLOT1D_MODE = 10
+RECORD_PLOT_MODE = 15
+IMAGE_MODE = 20
+PLOT2D_MODE = 21
+COMPLEX_IMAGE_MODE = 22
+PLOT3D_MODE = 30
+RAW_MODE = 40
+RAW_ARRAY_MODE = 41
+RAW_RECORD_MODE = 42
+RAW_SCALAR_MODE = 43
+RAW_HEXA_MODE = 44
+STACK_MODE = 50
+HDF5_MODE = 60
+NXDATA_MODE = 70
+NXDATA_INVALID_MODE = 71
+NXDATA_SCALAR_MODE = 72
+NXDATA_CURVE_MODE = 73
+NXDATA_XYVSCATTER_MODE = 74
+NXDATA_IMAGE_MODE = 75
+NXDATA_STACK_MODE = 76
+NXDATA_VOLUME_MODE = 77
+NXDATA_VOLUME_AS_STACK_MODE = 78
+
+
+def _normalizeData(data):
+ """Returns a normalized data.
+
+ If the data embed a numpy data or a dataset it is returned.
+ Else returns the input data."""
+ if isinstance(data, H5Node):
+ if data.is_broken:
+ return None
+ return data.h5py_object
+ return data
+
+
+def _normalizeComplex(data):
+ """Returns a normalized complex data.
+
+ If the data is a numpy data with complex, returns the
+ absolute value.
+ Else returns the input data."""
+ if hasattr(data, "dtype"):
+ isComplex = numpy.issubdtype(data.dtype, numpy.complexfloating)
+ else:
+ isComplex = isinstance(data, numbers.Complex)
+ if isComplex:
+ data = numpy.absolute(data)
+ return data
+
+
+class DataInfo(object):
+ """Store extracted information from a data"""
+
+ def __init__(self, data):
+ self.__priorities = {}
+ data = self.normalizeData(data)
+ self.isArray = False
+ self.interpretation = None
+ self.isNumeric = False
+ self.isVoid = False
+ self.isComplex = False
+ self.isBoolean = False
+ self.isRecord = False
+ self.hasNXdata = False
+ self.isInvalidNXdata = False
+ self.countNumericColumns = 0
+ self.shape = tuple()
+ self.dim = 0
+ self.size = 0
+
+ if data is None:
+ return
+
+ if silx.io.is_group(data):
+ nxd = nxdata.get_default(data)
+ nx_class = get_attr_as_unicode(data, "NX_class")
+ if nxd is not None:
+ self.hasNXdata = True
+ # can we plot it?
+ is_scalar = nxd.signal_is_0d or nxd.interpretation in ["scalar", "scaler"]
+ if not (is_scalar or nxd.is_curve or nxd.is_x_y_value_scatter or
+ nxd.is_image or nxd.is_stack):
+ # invalid: cannot be plotted by any widget
+ self.isInvalidNXdata = True
+ elif nx_class == "NXdata":
+ # group claiming to be NXdata could not be parsed
+ self.isInvalidNXdata = True
+ elif nx_class == "NXroot" or silx.io.is_file(data):
+ # root claiming to have a default entry
+ if "default" in data.attrs:
+ def_entry = data.attrs["default"]
+ if def_entry in data and "default" in data[def_entry].attrs:
+ # and entry claims to have default NXdata
+ self.isInvalidNXdata = True
+ elif "default" in data.attrs:
+ # group claiming to have a default NXdata could not be parsed
+ self.isInvalidNXdata = True
+
+ if isinstance(data, numpy.ndarray):
+ self.isArray = True
+ elif silx.io.is_dataset(data) and data.shape != tuple():
+ self.isArray = True
+ else:
+ self.isArray = False
+
+ if silx.io.is_dataset(data):
+ if "interpretation" in data.attrs:
+ self.interpretation = get_attr_as_unicode(data, "interpretation")
+ else:
+ self.interpretation = None
+ elif self.hasNXdata:
+ self.interpretation = nxd.interpretation
+ else:
+ self.interpretation = None
+
+ if hasattr(data, "dtype"):
+ if numpy.issubdtype(data.dtype, numpy.void):
+ # That's a real opaque type, else it is a structured type
+ self.isVoid = data.dtype.fields is None
+ self.isNumeric = numpy.issubdtype(data.dtype, numpy.number)
+ self.isRecord = data.dtype.fields is not None
+ self.isComplex = numpy.issubdtype(data.dtype, numpy.complexfloating)
+ self.isBoolean = numpy.issubdtype(data.dtype, numpy.bool_)
+ elif self.hasNXdata:
+ self.isNumeric = numpy.issubdtype(nxd.signal.dtype,
+ numpy.number)
+ self.isComplex = numpy.issubdtype(nxd.signal.dtype, numpy.complexfloating)
+ self.isBoolean = numpy.issubdtype(nxd.signal.dtype, numpy.bool_)
+ else:
+ self.isNumeric = isinstance(data, numbers.Number)
+ self.isComplex = isinstance(data, numbers.Complex)
+ self.isBoolean = isinstance(data, bool)
+ self.isRecord = False
+
+ if hasattr(data, "shape"):
+ self.shape = data.shape
+ elif self.hasNXdata:
+ self.shape = nxd.signal.shape
+ else:
+ self.shape = tuple()
+ if self.shape is not None:
+ self.dim = len(self.shape)
+
+ if hasattr(data, "shape") and data.shape is None:
+ # This test is expected to avoid to fall done on the h5py issue
+ # https://github.com/h5py/h5py/issues/1044
+ self.size = 0
+ elif hasattr(data, "size"):
+ self.size = int(data.size)
+ else:
+ self.size = 1
+
+ if hasattr(data, "dtype"):
+ if data.dtype.fields is not None:
+ for field in data.dtype.fields:
+ if numpy.issubdtype(data.dtype[field], numpy.number):
+ self.countNumericColumns += 1
+
+ def normalizeData(self, data):
+ """Returns a normalized data if the embed a numpy or a dataset.
+ Else returns the data."""
+ return _normalizeData(data)
+
+ def cachePriority(self, view, priority):
+ self.__priorities[view] = priority
+
+ def getPriority(self, view):
+ return self.__priorities[view]
+
+
+class DataViewHooks(object):
+ """A set of hooks defined to custom the behaviour of the data views."""
+
+ def getColormap(self, view):
+ """Returns a colormap for this view."""
+ return None
+
+ def getColormapDialog(self, view):
+ """Returns a color dialog for this view."""
+ return None
+
+ def viewWidgetCreated(self, view, plot):
+ """Called when the widget of the view was created"""
+ return
+
+class DataView(object):
+ """Holder for the data view."""
+
+ UNSUPPORTED = -1
+ """Priority returned when the requested data can't be displayed by the
+ view."""
+
+ TITLE_PATTERN = "{datapath}{slicing} {permuted}"
+ """Pattern used to format the title of the plot.
+
+ Supported fields: `{directory}`, `{filename}`, `{datapath}`, `{slicing}`, `{permuted}`.
+ """
+
+ def __init__(self, parent, modeId=None, icon=None, label=None):
+ """Constructor
+
+ :param qt.QWidget parent: Parent of the hold widget
+ """
+ self.__parent = parent
+ self.__widget = None
+ self.__modeId = modeId
+ if label is None:
+ label = self.__class__.__name__
+ self.__label = label
+ if icon is None:
+ icon = qt.QIcon()
+ self.__icon = icon
+ self.__hooks = None
+
+ def getHooks(self):
+ """Returns the data viewer hooks used by this view.
+
+ :rtype: DataViewHooks
+ """
+ return self.__hooks
+
+ def setHooks(self, hooks):
+ """Set the data view hooks to use with this view.
+
+ :param DataViewHooks hooks: The data view hooks to use
+ """
+ self.__hooks = hooks
+
+ def defaultColormap(self):
+ """Returns a default colormap.
+
+ :rtype: Colormap
+ """
+ colormap = None
+ if self.__hooks is not None:
+ colormap = self.__hooks.getColormap(self)
+ if colormap is None:
+ colormap = Colormap(name="viridis")
+ return colormap
+
+ def defaultColorDialog(self):
+ """Returns a default color dialog.
+
+ :rtype: ColormapDialog
+ """
+ dialog = None
+ if self.__hooks is not None:
+ dialog = self.__hooks.getColormapDialog(self)
+ if dialog is None:
+ dialog = ColormapDialog()
+ dialog.setModal(False)
+ return dialog
+
+ def icon(self):
+ """Returns the default icon"""
+ return self.__icon
+
+ def label(self):
+ """Returns the default label"""
+ return self.__label
+
+ def modeId(self):
+ """Returns the mode id"""
+ return self.__modeId
+
+ def normalizeData(self, data):
+ """Returns a normalized data if the embed a numpy or a dataset.
+ Else returns the data."""
+ return _normalizeData(data)
+
+ def customAxisNames(self):
+ """Returns names of axes which can be custom by the user and provided
+ to the view."""
+ return []
+
+ def setCustomAxisValue(self, name, value):
+ """
+ Set the value of a custom axis
+
+ :param str name: Name of the custom axis
+ :param int value: Value of the custom axis
+ """
+ pass
+
+ def isWidgetInitialized(self):
+ """Returns true if the widget is already initialized.
+ """
+ return self.__widget is not None
+
+ def select(self):
+ """Called when the view is selected to display the data.
+ """
+ return
+
+ def getWidget(self):
+ """Returns the widget hold in the view and displaying the data.
+
+ :returns: qt.QWidget
+ """
+ if self.__widget is None:
+ self.__widget = self.createWidget(self.__parent)
+ hooks = self.getHooks()
+ if hooks is not None:
+ hooks.viewWidgetCreated(self, self.__widget)
+ return self.__widget
+
+ def createWidget(self, parent):
+ """Create the the widget displaying the data
+
+ :param qt.QWidget parent: Parent of the widget
+ :returns: qt.QWidget
+ """
+ raise NotImplementedError()
+
+ def clear(self):
+ """Clear the data from the view"""
+ return None
+
+ def setData(self, data):
+ """Set the data displayed by the view
+
+ :param data: Data to display
+ :type data: numpy.ndarray or h5py.Dataset
+ """
+ return None
+
+ def __formatSlices(self, indices):
+ """Format an iterable of slice objects
+
+ :param indices: The slices to format
+ :type indices: Union[None,List[Union[slice,int]]]
+ :rtype: str
+ """
+ if indices is None:
+ return ''
+
+ def formatSlice(slice_):
+ start, stop, step = slice_.start, slice_.stop, slice_.step
+ string = ('' if start is None else str(start)) + ':'
+ if stop is not None:
+ string += str(stop)
+ if step not in (None, 1):
+ string += ':' + step
+ return string
+
+ return '[' + ', '.join(
+ formatSlice(index) if isinstance(index, slice) else str(index)
+ for index in indices) + ']'
+
+ def titleForSelection(self, selection):
+ """Build title from given selection information.
+
+ :param NamedTuple selection: Data selected
+ :rtype: str
+ """
+ if selection is None or selection.filename is None:
+ return None
+ else:
+ directory, filename = os.path.split(selection.filename)
+ try:
+ slicing = self.__formatSlices(selection.slice)
+ except Exception:
+ _logger.debug("Error while formatting slices", exc_info=True)
+ slicing = '[sliced]'
+
+ permuted = '(permuted)' if selection.permutation is not None else ''
+
+ try:
+ title = self.TITLE_PATTERN.format(
+ directory=directory,
+ filename=filename,
+ datapath=selection.datapath,
+ slicing=slicing,
+ permuted=permuted)
+ except Exception:
+ _logger.debug("Error while formatting title", exc_info=True)
+ title = selection.datapath + slicing
+
+ return title
+
+ def setDataSelection(self, selection):
+ """Set the data selection displayed by the view
+
+ If called, it have to be called directly after `setData`.
+
+ :param selection: Data selected
+ :type selection: NamedTuple
+ """
+ pass
+
+ def axesNames(self, data, info):
+ """Returns names of the expected axes of the view, according to the
+ input data. A none value will disable the default axes selectior.
+
+ :param data: Data to display
+ :type data: numpy.ndarray or h5py.Dataset
+ :param DataInfo info: Pre-computed information on the data
+ :rtype: list[str] or None
+ """
+ return []
+
+ def getReachableViews(self):
+ """Returns the views that can be returned by `getMatchingViews`.
+
+ :param object data: Any object to be displayed
+ :param DataInfo info: Information cached about this data
+ :rtype: List[DataView]
+ """
+ return [self]
+
+ def getMatchingViews(self, data, info):
+ """Returns the views according to data and info from the data.
+
+ :param object data: Any object to be displayed
+ :param DataInfo info: Information cached about this data
+ :rtype: List[DataView]
+ """
+ priority = self.getCachedDataPriority(data, info)
+ if priority == DataView.UNSUPPORTED:
+ return []
+ return [self]
+
+ def getCachedDataPriority(self, data, info):
+ try:
+ priority = info.getPriority(self)
+ except KeyError:
+ priority = self.getDataPriority(data, info)
+ info.cachePriority(self, priority)
+ return priority
+
+ def getDataPriority(self, data, info):
+ """
+ Returns the priority of using this view according to a data.
+
+ - `UNSUPPORTED` means this view can't display this data
+ - `1` means this view can display the data
+ - `100` means this view should be used for this data
+ - `1000` max value used by the views provided by silx
+ - ...
+
+ :param object data: The data to check
+ :param DataInfo info: Pre-computed information on the data
+ :rtype: int
+ """
+ return DataView.UNSUPPORTED
+
+ def __lt__(self, other):
+ return str(self) < str(other)
+
+
+class _CompositeDataView(DataView):
+ """Contains sub views"""
+
+ def getViews(self):
+ """Returns the direct sub views registered in this view.
+
+ :rtype: List[DataView]
+ """
+ raise NotImplementedError()
+
+ def getReachableViews(self):
+ """Returns all views that can be reachable at on point.
+
+ This method return any sub view provided (recursivly).
+
+ :rtype: List[DataView]
+ """
+ raise NotImplementedError()
+
+ def getMatchingViews(self, data, info):
+ """Returns sub views matching this data and info.
+
+ This method return any sub view provided (recursivly).
+
+ :param object data: Any object to be displayed
+ :param DataInfo info: Information cached about this data
+ :rtype: List[DataView]
+ """
+ raise NotImplementedError()
+
+ @deprecation.deprecated(replacement="getReachableViews", since_version="0.10")
+ def availableViews(self):
+ return self.getViews()
+
+ def isSupportedData(self, data, info):
+ """If true, the composite view allow sub views to access to this data.
+ Else this this data is considered as not supported by any of sub views
+ (incliding this composite view).
+
+ :param object data: Any object to be displayed
+ :param DataInfo info: Information cached about this data
+ :rtype: bool
+ """
+ return True
+
+
+class SelectOneDataView(_CompositeDataView):
+ """Data view which can display a data using different view according to
+ the kind of the data."""
+
+ def __init__(self, parent, modeId=None, icon=None, label=None):
+ """Constructor
+
+ :param qt.QWidget parent: Parent of the hold widget
+ """
+ super(SelectOneDataView, self).__init__(parent, modeId, icon, label)
+ self.__views = OrderedDict()
+ self.__currentView = None
+
+ def setHooks(self, hooks):
+ """Set the data context to use with this view.
+
+ :param DataViewHooks hooks: The data view hooks to use
+ """
+ super(SelectOneDataView, self).setHooks(hooks)
+ if hooks is not None:
+ for v in self.__views:
+ v.setHooks(hooks)
+
+ def addView(self, dataView):
+ """Add a new dataview to the available list."""
+ hooks = self.getHooks()
+ if hooks is not None:
+ dataView.setHooks(hooks)
+ self.__views[dataView] = None
+
+ def getReachableViews(self):
+ views = []
+ addSelf = False
+ for v in self.__views:
+ if isinstance(v, SelectManyDataView):
+ views.extend(v.getReachableViews())
+ else:
+ addSelf = True
+ if addSelf:
+ # Single views are hidden by this view
+ views.insert(0, self)
+ return views
+
+ def getMatchingViews(self, data, info):
+ if not self.isSupportedData(data, info):
+ return []
+ view = self.__getBestView(data, info)
+ if isinstance(view, SelectManyDataView):
+ return view.getMatchingViews(data, info)
+ else:
+ return [self]
+
+ def getViews(self):
+ """Returns the list of registered views
+
+ :rtype: List[DataView]
+ """
+ return list(self.__views.keys())
+
+ def __getBestView(self, data, info):
+ """Returns the best view according to priorities."""
+ if not self.isSupportedData(data, info):
+ return None
+ views = [(v.getCachedDataPriority(data, info), v) for v in self.__views.keys()]
+ views = filter(lambda t: t[0] > DataView.UNSUPPORTED, views)
+ views = sorted(views, key=lambda t: t[0], reverse=True)
+
+ if len(views) == 0:
+ return None
+ elif views[0][0] == DataView.UNSUPPORTED:
+ return None
+ else:
+ return views[0][1]
+
+ def customAxisNames(self):
+ if self.__currentView is None:
+ return
+ return self.__currentView.customAxisNames()
+
+ def setCustomAxisValue(self, name, value):
+ if self.__currentView is None:
+ return
+ self.__currentView.setCustomAxisValue(name, value)
+
+ def __updateDisplayedView(self):
+ widget = self.getWidget()
+ if self.__currentView is None:
+ return
+
+ # load the widget if it is not yet done
+ index = self.__views[self.__currentView]
+ if index is None:
+ w = self.__currentView.getWidget()
+ index = widget.addWidget(w)
+ self.__views[self.__currentView] = index
+ if widget.currentIndex() != index:
+ widget.setCurrentIndex(index)
+ self.__currentView.select()
+
+ def select(self):
+ self.__updateDisplayedView()
+ if self.__currentView is not None:
+ self.__currentView.select()
+
+ def createWidget(self, parent):
+ return qt.QStackedWidget()
+
+ def clear(self):
+ for v in self.__views.keys():
+ v.clear()
+
+ def setData(self, data):
+ if self.__currentView is None:
+ return
+ self.__updateDisplayedView()
+ self.__currentView.setData(data)
+
+ def setDataSelection(self, selection):
+ if self.__currentView is None:
+ return
+ self.__currentView.setDataSelection(selection)
+
+ def axesNames(self, data, info):
+ view = self.__getBestView(data, info)
+ self.__currentView = view
+ return view.axesNames(data, info)
+
+ def getDataPriority(self, data, info):
+ view = self.__getBestView(data, info)
+ self.__currentView = view
+ if view is None:
+ return DataView.UNSUPPORTED
+ else:
+ return view.getCachedDataPriority(data, info)
+
+ def replaceView(self, modeId, newView):
+ """Replace a data view with a custom view.
+ Return True in case of success, False in case of failure.
+
+ .. note::
+
+ This method must be called just after instantiation, before
+ the viewer is used.
+
+ :param int modeId: Unique mode ID identifying the DataView to
+ be replaced.
+ :param DataViews.DataView newView: New data view
+ :return: True if replacement was successful, else False
+ """
+ oldView = None
+ for view in self.__views:
+ if view.modeId() == modeId:
+ oldView = view
+ break
+ elif isinstance(view, _CompositeDataView):
+ # recurse
+ hooks = self.getHooks()
+ if hooks is not None:
+ newView.setHooks(hooks)
+ if view.replaceView(modeId, newView):
+ return True
+ if oldView is None:
+ return False
+
+ # replace oldView with new view in dict
+ self.__views = OrderedDict(
+ (newView, None) if view is oldView else (view, idx) for
+ view, idx in self.__views.items())
+ return True
+
+
+# NOTE: SelectOneDataView was introduced with silx 0.10
+CompositeDataView = SelectOneDataView
+
+
+class SelectManyDataView(_CompositeDataView):
+ """Data view which can select a set of sub views according to
+ the kind of the data.
+
+ This view itself is abstract and is not exposed.
+ """
+
+ def __init__(self, parent, views=None):
+ """Constructor
+
+ :param qt.QWidget parent: Parent of the hold widget
+ """
+ super(SelectManyDataView, self).__init__(parent, modeId=None, icon=None, label=None)
+ if views is None:
+ views = []
+ self.__views = views
+
+ def setHooks(self, hooks):
+ """Set the data context to use with this view.
+
+ :param DataViewHooks hooks: The data view hooks to use
+ """
+ super(SelectManyDataView, self).setHooks(hooks)
+ if hooks is not None:
+ for v in self.__views:
+ v.setHooks(hooks)
+
+ def addView(self, dataView):
+ """Add a new dataview to the available list."""
+ hooks = self.getHooks()
+ if hooks is not None:
+ dataView.setHooks(hooks)
+ self.__views.append(dataView)
+
+ def getViews(self):
+ """Returns the list of registered views
+
+ :rtype: List[DataView]
+ """
+ return list(self.__views)
+
+ def getReachableViews(self):
+ views = []
+ for v in self.__views:
+ views.extend(v.getReachableViews())
+ return views
+
+ def getMatchingViews(self, data, info):
+ """Returns the views according to data and info from the data.
+
+ :param object data: Any object to be displayed
+ :param DataInfo info: Information cached about this data
+ """
+ if not self.isSupportedData(data, info):
+ return []
+ views = [v for v in self.__views if v.getCachedDataPriority(data, info) != DataView.UNSUPPORTED]
+ return views
+
+ def customAxisNames(self):
+ raise RuntimeError("Abstract view")
+
+ def setCustomAxisValue(self, name, value):
+ raise RuntimeError("Abstract view")
+
+ def select(self):
+ raise RuntimeError("Abstract view")
+
+ def createWidget(self, parent):
+ raise RuntimeError("Abstract view")
+
+ def clear(self):
+ for v in self.__views:
+ v.clear()
+
+ def setData(self, data):
+ raise RuntimeError("Abstract view")
+
+ def axesNames(self, data, info):
+ raise RuntimeError("Abstract view")
+
+ def getDataPriority(self, data, info):
+ if not self.isSupportedData(data, info):
+ return DataView.UNSUPPORTED
+ priorities = [v.getCachedDataPriority(data, info) for v in self.__views]
+ priorities = [v for v in priorities if v != DataView.UNSUPPORTED]
+ priorities = sorted(priorities)
+ if len(priorities) == 0:
+ return DataView.UNSUPPORTED
+ return priorities[-1]
+
+ def replaceView(self, modeId, newView):
+ """Replace a data view with a custom view.
+ Return True in case of success, False in case of failure.
+
+ .. note::
+
+ This method must be called just after instantiation, before
+ the viewer is used.
+
+ :param int modeId: Unique mode ID identifying the DataView to
+ be replaced.
+ :param DataViews.DataView newView: New data view
+ :return: True if replacement was successful, else False
+ """
+ oldView = None
+ for iview, view in enumerate(self.__views):
+ if view.modeId() == modeId:
+ oldView = view
+ break
+ elif isinstance(view, CompositeDataView):
+ # recurse
+ hooks = self.getHooks()
+ if hooks is not None:
+ newView.setHooks(hooks)
+ if view.replaceView(modeId, newView):
+ return True
+
+ if oldView is None:
+ return False
+
+ # replace oldView with new view in dict
+ self.__views[iview] = newView
+ return True
+
+
+class _EmptyView(DataView):
+ """Dummy view to display nothing"""
+
+ def __init__(self, parent):
+ DataView.__init__(self, parent, modeId=EMPTY_MODE)
+
+ def axesNames(self, data, info):
+ return None
+
+ def createWidget(self, parent):
+ return qt.QLabel(parent)
+
+ def getDataPriority(self, data, info):
+ return DataView.UNSUPPORTED
+
+
+class _Plot1dView(DataView):
+ """View displaying data using a 1d plot"""
+
+ def __init__(self, parent):
+ super(_Plot1dView, self).__init__(
+ parent=parent,
+ modeId=PLOT1D_MODE,
+ label="Curve",
+ icon=icons.getQIcon("view-1d"))
+ self.__resetZoomNextTime = True
+
+ def createWidget(self, parent):
+ from silx.gui import plot
+ return plot.Plot1D(parent=parent)
+
+ def clear(self):
+ self.getWidget().clear()
+ self.__resetZoomNextTime = True
+
+ def normalizeData(self, data):
+ data = DataView.normalizeData(self, data)
+ data = _normalizeComplex(data)
+ return data
+
+ def setData(self, data):
+ data = self.normalizeData(data)
+ plotWidget = self.getWidget()
+ legend = "data"
+ plotWidget.addCurve(legend=legend,
+ x=range(len(data)),
+ y=data,
+ resetzoom=self.__resetZoomNextTime)
+ plotWidget.setActiveCurve(legend)
+ self.__resetZoomNextTime = True
+
+ def setDataSelection(self, selection):
+ self.getWidget().setGraphTitle(self.titleForSelection(selection))
+
+ def axesNames(self, data, info):
+ return ["y"]
+
+ def getDataPriority(self, data, info):
+ if info.size <= 0:
+ return DataView.UNSUPPORTED
+ if data is None or not info.isArray or not info.isNumeric:
+ return DataView.UNSUPPORTED
+ if info.dim < 1:
+ return DataView.UNSUPPORTED
+ if info.interpretation == "spectrum":
+ return 1000
+ if info.dim == 2 and info.shape[0] == 1:
+ return 210
+ if info.dim == 1:
+ return 100
+ else:
+ return 10
+
+
+class _Plot2dRecordView(DataView):
+ def __init__(self, parent):
+ super(_Plot2dRecordView, self).__init__(
+ parent=parent,
+ modeId=RECORD_PLOT_MODE,
+ label="Curve",
+ icon=icons.getQIcon("view-1d"))
+ self.__resetZoomNextTime = True
+ self._data = None
+ self._xAxisDropDown = None
+ self._yAxisDropDown = None
+ self.__fields = None
+
+ def createWidget(self, parent):
+ from ._RecordPlot import RecordPlot
+ return RecordPlot(parent=parent)
+
+ def clear(self):
+ self.getWidget().clear()
+ self.__resetZoomNextTime = True
+
+ def normalizeData(self, data):
+ data = DataView.normalizeData(self, data)
+ data = _normalizeComplex(data)
+ return data
+
+ def setData(self, data):
+ self._data = self.normalizeData(data)
+
+ all_fields = sorted(self._data.dtype.fields.items(), key=lambda e: e[1][1])
+ numeric_fields = [f[0] for f in all_fields if numpy.issubdtype(f[1][0], numpy.number)]
+ if numeric_fields == self.__fields: # Reuse previously selected fields
+ fieldNameX = self.getWidget().getXAxisFieldName()
+ fieldNameY = self.getWidget().getYAxisFieldName()
+ else:
+ self.__fields = numeric_fields
+
+ self.getWidget().setSelectableXAxisFieldNames(numeric_fields)
+ self.getWidget().setSelectableYAxisFieldNames(numeric_fields)
+ fieldNameX = None
+ fieldNameY = numeric_fields[0]
+
+ # If there is a field called time, use it for the x-axis by default
+ if "time" in numeric_fields:
+ fieldNameX = "time"
+ # Use the first field that is not "time" for the y-axis
+ if fieldNameY == "time" and len(numeric_fields) >= 2:
+ fieldNameY = numeric_fields[1]
+
+ self._plotData(fieldNameX, fieldNameY)
+
+ if not self._xAxisDropDown:
+ self._xAxisDropDown = self.getWidget().getAxesSelectionToolBar().getXAxisDropDown()
+ self._yAxisDropDown = self.getWidget().getAxesSelectionToolBar().getYAxisDropDown()
+ self._xAxisDropDown.activated.connect(self._onAxesSelectionChaned)
+ self._yAxisDropDown.activated.connect(self._onAxesSelectionChaned)
+
+ def setDataSelection(self, selection):
+ self.getWidget().setGraphTitle(self.titleForSelection(selection))
+
+ def _onAxesSelectionChaned(self):
+ fieldNameX = self._xAxisDropDown.currentData()
+ self._plotData(fieldNameX, self._yAxisDropDown.currentText())
+
+ def _plotData(self, fieldNameX, fieldNameY):
+ self.clear()
+ ydata = self._data[fieldNameY]
+ if fieldNameX is None:
+ xdata = numpy.arange(len(ydata))
+ else:
+ xdata = self._data[fieldNameX]
+ self.getWidget().addCurve(legend="data",
+ x=xdata,
+ y=ydata,
+ resetzoom=self.__resetZoomNextTime)
+ self.getWidget().setXAxisFieldName(fieldNameX)
+ self.getWidget().setYAxisFieldName(fieldNameY)
+ self.__resetZoomNextTime = True
+
+ def axesNames(self, data, info):
+ return ["data"]
+
+ def getDataPriority(self, data, info):
+ if info.size <= 0:
+ return DataView.UNSUPPORTED
+ if data is None or not info.isRecord:
+ return DataView.UNSUPPORTED
+ if info.dim < 1:
+ return DataView.UNSUPPORTED
+ if info.countNumericColumns < 2:
+ return DataView.UNSUPPORTED
+ if info.interpretation == "spectrum":
+ return 1000
+ if info.dim == 2 and info.shape[0] == 1:
+ return 210
+ if info.dim == 1:
+ return 40
+ else:
+ return 10
+
+
+class _Plot2dView(DataView):
+ """View displaying data using a 2d plot"""
+
+ def __init__(self, parent):
+ super(_Plot2dView, self).__init__(
+ parent=parent,
+ modeId=PLOT2D_MODE,
+ label="Image",
+ icon=icons.getQIcon("view-2d"))
+ self.__resetZoomNextTime = True
+
+ def createWidget(self, parent):
+ from silx.gui import plot
+ widget = plot.Plot2D(parent=parent)
+ widget.setDefaultColormap(self.defaultColormap())
+ widget.getColormapAction().setColorDialog(self.defaultColorDialog())
+ widget.getIntensityHistogramAction().setVisible(True)
+ widget.setKeepDataAspectRatio(True)
+ widget.getXAxis().setLabel('X')
+ widget.getYAxis().setLabel('Y')
+ maskToolsWidget = widget.getMaskToolsDockWidget().widget()
+ maskToolsWidget.setItemMaskUpdated(True)
+ return widget
+
+ def clear(self):
+ self.getWidget().clear()
+ self.__resetZoomNextTime = True
+
+ def normalizeData(self, data):
+ data = DataView.normalizeData(self, data)
+ data = _normalizeComplex(data)
+ return data
+
+ def setData(self, data):
+ data = self.normalizeData(data)
+ self.getWidget().addImage(legend="data",
+ data=data,
+ resetzoom=self.__resetZoomNextTime)
+ self.__resetZoomNextTime = False
+
+ def setDataSelection(self, selection):
+ self.getWidget().setGraphTitle(self.titleForSelection(selection))
+
+ def axesNames(self, data, info):
+ return ["y", "x"]
+
+ def getDataPriority(self, data, info):
+ if info.size <= 0:
+ return DataView.UNSUPPORTED
+ if (data is None or
+ not info.isArray or
+ not (info.isNumeric or info.isBoolean)):
+ return DataView.UNSUPPORTED
+ if info.dim < 2:
+ return DataView.UNSUPPORTED
+ if info.interpretation == "image":
+ return 1000
+ if info.dim == 2:
+ return 200
+ else:
+ return 190
+
+
+class _Plot3dView(DataView):
+ """View displaying data using a 3d plot"""
+
+ def __init__(self, parent):
+ super(_Plot3dView, self).__init__(
+ parent=parent,
+ modeId=PLOT3D_MODE,
+ label="Cube",
+ icon=icons.getQIcon("view-3d"))
+ try:
+ from ._VolumeWindow import VolumeWindow # noqa
+ except ImportError:
+ _logger.warning("3D visualization is not available")
+ _logger.debug("Backtrace", exc_info=True)
+ raise
+ self.__resetZoomNextTime = True
+
+ def createWidget(self, parent):
+ from ._VolumeWindow import VolumeWindow
+
+ plot = VolumeWindow(parent)
+ plot.setAxesLabels(*reversed(self.axesNames(None, None)))
+ return plot
+
+ def clear(self):
+ self.getWidget().clear()
+ self.__resetZoomNextTime = True
+
+ def setData(self, data):
+ data = self.normalizeData(data)
+ self.getWidget().setData(data)
+ self.__resetZoomNextTime = False
+
+ def axesNames(self, data, info):
+ return ["z", "y", "x"]
+
+ def getDataPriority(self, data, info):
+ if info.size <= 0:
+ return DataView.UNSUPPORTED
+ if data is None or not info.isArray or not info.isNumeric:
+ return DataView.UNSUPPORTED
+ if info.dim < 3:
+ return DataView.UNSUPPORTED
+ if min(data.shape) < 2:
+ return DataView.UNSUPPORTED
+ if info.dim == 3:
+ return 100
+ else:
+ return 10
+
+
+class _ComplexImageView(DataView):
+ """View displaying data using a ComplexImageView"""
+
+ def __init__(self, parent):
+ super(_ComplexImageView, self).__init__(
+ parent=parent,
+ modeId=COMPLEX_IMAGE_MODE,
+ label="Complex Image",
+ icon=icons.getQIcon("view-2d"))
+
+ def createWidget(self, parent):
+ from silx.gui.plot.ComplexImageView import ComplexImageView
+ widget = ComplexImageView(parent=parent)
+ widget.setColormap(self.defaultColormap(), mode=ComplexImageView.ComplexMode.ABSOLUTE)
+ widget.setColormap(self.defaultColormap(), mode=ComplexImageView.ComplexMode.SQUARE_AMPLITUDE)
+ widget.setColormap(self.defaultColormap(), mode=ComplexImageView.ComplexMode.REAL)
+ widget.setColormap(self.defaultColormap(), mode=ComplexImageView.ComplexMode.IMAGINARY)
+ widget.getPlot().getColormapAction().setColorDialog(self.defaultColorDialog())
+ widget.getPlot().getIntensityHistogramAction().setVisible(True)
+ widget.getPlot().setKeepDataAspectRatio(True)
+ widget.getXAxis().setLabel('X')
+ widget.getYAxis().setLabel('Y')
+ maskToolsWidget = widget.getPlot().getMaskToolsDockWidget().widget()
+ maskToolsWidget.setItemMaskUpdated(True)
+ return widget
+
+ def clear(self):
+ self.getWidget().setData(None)
+
+ def normalizeData(self, data):
+ data = DataView.normalizeData(self, data)
+ return data
+
+ def setData(self, data):
+ data = self.normalizeData(data)
+ self.getWidget().setData(data)
+
+ def setDataSelection(self, selection):
+ self.getWidget().getPlot().setGraphTitle(
+ self.titleForSelection(selection))
+
+ def axesNames(self, data, info):
+ return ["y", "x"]
+
+ def getDataPriority(self, data, info):
+ if info.size <= 0:
+ return DataView.UNSUPPORTED
+ if data is None or not info.isArray or not info.isComplex:
+ return DataView.UNSUPPORTED
+ if info.dim < 2:
+ return DataView.UNSUPPORTED
+ if info.interpretation == "image":
+ return 1000
+ if info.dim == 2:
+ return 200
+ else:
+ return 190
+
+
+class _ArrayView(DataView):
+ """View displaying data using a 2d table"""
+
+ def __init__(self, parent):
+ DataView.__init__(self, parent, modeId=RAW_ARRAY_MODE)
+
+ def createWidget(self, parent):
+ from silx.gui.data.ArrayTableWidget import ArrayTableWidget
+ widget = ArrayTableWidget(parent)
+ widget.displayAxesSelector(False)
+ return widget
+
+ def clear(self):
+ self.getWidget().setArrayData(numpy.array([[]]))
+
+ def setData(self, data):
+ data = self.normalizeData(data)
+ self.getWidget().setArrayData(data)
+
+ def axesNames(self, data, info):
+ return ["col", "row"]
+
+ def getDataPriority(self, data, info):
+ if info.size <= 0:
+ return DataView.UNSUPPORTED
+ if data is None or not info.isArray or info.isRecord:
+ return DataView.UNSUPPORTED
+ if info.dim < 2:
+ return DataView.UNSUPPORTED
+ if info.interpretation in ["scalar", "scaler"]:
+ return 1000
+ return 500
+
+
+class _StackView(DataView):
+ """View displaying data using a stack of images"""
+
+ def __init__(self, parent):
+ super(_StackView, self).__init__(
+ parent=parent,
+ modeId=STACK_MODE,
+ label="Image stack",
+ icon=icons.getQIcon("view-2d-stack"))
+ self.__resetZoomNextTime = True
+
+ def customAxisNames(self):
+ return ["depth"]
+
+ def setCustomAxisValue(self, name, value):
+ if name == "depth":
+ self.getWidget().setFrameNumber(value)
+ else:
+ raise Exception("Unsupported axis")
+
+ def createWidget(self, parent):
+ from silx.gui import plot
+ widget = plot.StackView(parent=parent)
+ widget.setColormap(self.defaultColormap())
+ widget.getPlotWidget().getColormapAction().setColorDialog(self.defaultColorDialog())
+ widget.setKeepDataAspectRatio(True)
+ widget.setLabels(self.axesNames(None, None))
+ # hide default option panel
+ widget.setOptionVisible(False)
+ maskToolWidget = widget.getPlotWidget().getMaskToolsDockWidget().widget()
+ maskToolWidget.setItemMaskUpdated(True)
+ return widget
+
+ def clear(self):
+ self.getWidget().clear()
+ self.__resetZoomNextTime = True
+
+ def normalizeData(self, data):
+ data = DataView.normalizeData(self, data)
+ data = _normalizeComplex(data)
+ return data
+
+ def setData(self, data):
+ data = self.normalizeData(data)
+ self.getWidget().setStack(stack=data, reset=self.__resetZoomNextTime)
+ # Override the colormap, while setStack overwrite it
+ self.getWidget().setColormap(self.defaultColormap())
+ self.__resetZoomNextTime = False
+
+ def setDataSelection(self, selection):
+ title = self.titleForSelection(selection)
+ self.getWidget().setTitleCallback(
+ lambda idx: "%s z=%d" % (title, idx))
+
+ def axesNames(self, data, info):
+ return ["depth", "y", "x"]
+
+ def getDataPriority(self, data, info):
+ if info.size <= 0:
+ return DataView.UNSUPPORTED
+ if data is None or not info.isArray or not info.isNumeric:
+ return DataView.UNSUPPORTED
+ if info.dim < 3:
+ return DataView.UNSUPPORTED
+ if info.interpretation == "image":
+ return 500
+ return 90
+
+
+class _ScalarView(DataView):
+ """View displaying data using text"""
+
+ def __init__(self, parent):
+ DataView.__init__(self, parent, modeId=RAW_SCALAR_MODE)
+
+ def createWidget(self, parent):
+ widget = qt.QTextEdit(parent)
+ widget.setTextInteractionFlags(qt.Qt.TextSelectableByMouse)
+ widget.setAlignment(qt.Qt.AlignLeft | qt.Qt.AlignTop)
+ self.__formatter = TextFormatter(parent)
+ return widget
+
+ def clear(self):
+ self.getWidget().setText("")
+
+ def setData(self, data):
+ d = self.normalizeData(data)
+ if silx.io.is_dataset(d):
+ d = d[()]
+ dtype = None
+ if data is not None:
+ if hasattr(data, "dtype"):
+ dtype = data.dtype
+ text = self.__formatter.toString(d, dtype)
+ self.getWidget().setText(text)
+
+ def axesNames(self, data, info):
+ return []
+
+ def getDataPriority(self, data, info):
+ if info.size <= 0:
+ return DataView.UNSUPPORTED
+ data = self.normalizeData(data)
+ if info.shape is None:
+ return DataView.UNSUPPORTED
+ if data is None:
+ return DataView.UNSUPPORTED
+ if silx.io.is_group(data):
+ return DataView.UNSUPPORTED
+ return 2
+
+
+class _RecordView(DataView):
+ """View displaying data using text"""
+
+ def __init__(self, parent):
+ DataView.__init__(self, parent, modeId=RAW_RECORD_MODE)
+
+ def createWidget(self, parent):
+ from .RecordTableView import RecordTableView
+ widget = RecordTableView(parent)
+ widget.setWordWrap(False)
+ return widget
+
+ def clear(self):
+ self.getWidget().setArrayData(None)
+
+ def setData(self, data):
+ data = self.normalizeData(data)
+ widget = self.getWidget()
+ widget.setArrayData(data)
+ if len(data) < 100:
+ widget.resizeRowsToContents()
+ widget.resizeColumnsToContents()
+
+ def axesNames(self, data, info):
+ return ["data"]
+
+ def getDataPriority(self, data, info):
+ if info.size <= 0:
+ return DataView.UNSUPPORTED
+ if info.isRecord:
+ return 40
+ if data is None or not info.isArray:
+ return DataView.UNSUPPORTED
+ if info.dim == 1:
+ if info.interpretation in ["scalar", "scaler"]:
+ return 1000
+ if info.shape[0] == 1:
+ return 510
+ return 500
+ elif info.isRecord:
+ return 40
+ return DataView.UNSUPPORTED
+
+
+class _HexaView(DataView):
+ """View displaying data using text"""
+
+ def __init__(self, parent):
+ DataView.__init__(self, parent, modeId=RAW_HEXA_MODE)
+
+ def createWidget(self, parent):
+ from .HexaTableView import HexaTableView
+ widget = HexaTableView(parent)
+ return widget
+
+ def clear(self):
+ self.getWidget().setArrayData(None)
+
+ def setData(self, data):
+ data = self.normalizeData(data)
+ widget = self.getWidget()
+ widget.setArrayData(data)
+
+ def axesNames(self, data, info):
+ return []
+
+ def getDataPriority(self, data, info):
+ if info.size <= 0:
+ return DataView.UNSUPPORTED
+ if info.isVoid:
+ return 2000
+ return DataView.UNSUPPORTED
+
+
+class _Hdf5View(DataView):
+ """View displaying data using text"""
+
+ def __init__(self, parent):
+ super(_Hdf5View, self).__init__(
+ parent=parent,
+ modeId=HDF5_MODE,
+ label="HDF5",
+ icon=icons.getQIcon("view-hdf5"))
+
+ def createWidget(self, parent):
+ from .Hdf5TableView import Hdf5TableView
+ widget = Hdf5TableView(parent)
+ return widget
+
+ def clear(self):
+ widget = self.getWidget()
+ widget.setData(None)
+
+ def setData(self, data):
+ widget = self.getWidget()
+ widget.setData(data)
+
+ def axesNames(self, data, info):
+ return None
+
+ def getDataPriority(self, data, info):
+ widget = self.getWidget()
+ if widget.isSupportedData(data):
+ return 1
+ else:
+ return DataView.UNSUPPORTED
+
+
+class _RawView(CompositeDataView):
+ """View displaying data as raw data.
+
+ This implementation use a 2d-array view, or a record array view, or a
+ raw text output.
+ """
+
+ def __init__(self, parent):
+ super(_RawView, self).__init__(
+ parent=parent,
+ modeId=RAW_MODE,
+ label="Raw",
+ icon=icons.getQIcon("view-raw"))
+ self.addView(_HexaView(parent))
+ self.addView(_ScalarView(parent))
+ self.addView(_ArrayView(parent))
+ self.addView(_RecordView(parent))
+
+
+class _ImageView(CompositeDataView):
+ """View displaying data as 2D image
+
+ It choose between Plot2D and ComplexImageView widgets
+ """
+
+ def __init__(self, parent):
+ super(_ImageView, self).__init__(
+ parent=parent,
+ modeId=IMAGE_MODE,
+ label="Image",
+ icon=icons.getQIcon("view-2d"))
+ self.addView(_ComplexImageView(parent))
+ self.addView(_Plot2dView(parent))
+
+
+class _InvalidNXdataView(DataView):
+ """DataView showing a simple label with an error message
+ to inform that a group with @NX_class=NXdata cannot be
+ interpreted by any NXDataview."""
+ def __init__(self, parent):
+ DataView.__init__(self, parent,
+ modeId=NXDATA_INVALID_MODE)
+ self._msg = ""
+
+ def createWidget(self, parent):
+ widget = qt.QLabel(parent)
+ widget.setWordWrap(True)
+ widget.setStyleSheet("QLabel { color : red; }")
+ return widget
+
+ def axesNames(self, data, info):
+ return []
+
+ def clear(self):
+ self.getWidget().setText("")
+
+ def setData(self, data):
+ self.getWidget().setText(self._msg)
+
+ def getDataPriority(self, data, info):
+ data = self.normalizeData(data)
+
+ if not info.isInvalidNXdata:
+ return DataView.UNSUPPORTED
+
+ if info.hasNXdata:
+ self._msg = "NXdata seems valid, but cannot be displayed "
+ self._msg += "by any existing plot widget."
+ else:
+ nx_class = get_attr_as_unicode(data, "NX_class")
+ if nx_class == "NXdata":
+ # invalid: could not even be parsed by NXdata
+ self._msg = "Group has @NX_class = NXdata, but could not be interpreted"
+ self._msg += " as valid NXdata."
+ elif nx_class == "NXroot" or silx.io.is_file(data):
+ default_entry = data[data.attrs["default"]]
+ default_nxdata_name = default_entry.attrs["default"]
+ self._msg = "NXroot group provides a @default attribute "
+ self._msg += "pointing to a NXentry which defines its own "
+ self._msg += "@default attribute, "
+ if default_nxdata_name not in default_entry:
+ self._msg += " but no corresponding NXdata group exists."
+ elif get_attr_as_unicode(default_entry[default_nxdata_name],
+ "NX_class") != "NXdata":
+ self._msg += " but the corresponding item is not a "
+ self._msg += "NXdata group."
+ else:
+ self._msg += " but the corresponding NXdata seems to be"
+ self._msg += " malformed."
+ else:
+ self._msg = "Group provides a @default attribute,"
+ default_nxdata_name = data.attrs["default"]
+ if default_nxdata_name not in data:
+ self._msg += " but no corresponding NXdata group exists."
+ elif get_attr_as_unicode(data[default_nxdata_name], "NX_class") != "NXdata":
+ self._msg += " but the corresponding item is not a "
+ self._msg += "NXdata group."
+ else:
+ self._msg += " but the corresponding NXdata seems to be"
+ self._msg += " malformed."
+ return 100
+
+
+class _NXdataBaseDataView(DataView):
+ """Base class for NXdata DataView"""
+
+ def __init__(self, *args, **kwargs):
+ DataView.__init__(self, *args, **kwargs)
+
+ def _updateColormap(self, nxdata):
+ """Update used colormap according to nxdata's SILX_style"""
+ cmap_norm = nxdata.plot_style.signal_scale_type
+ if cmap_norm is not None:
+ self.defaultColormap().setNormalization(
+ 'log' if cmap_norm == 'log' else 'linear')
+
+
+class _NXdataScalarView(_NXdataBaseDataView):
+ """DataView using a table view for displaying NXdata scalars:
+ 0-D signal or n-D signal with *@interpretation=scalar*"""
+ def __init__(self, parent):
+ _NXdataBaseDataView.__init__(
+ self, parent, modeId=NXDATA_SCALAR_MODE)
+
+ def createWidget(self, parent):
+ from silx.gui.data.ArrayTableWidget import ArrayTableWidget
+ widget = ArrayTableWidget(parent)
+ # widget.displayAxesSelector(False)
+ return widget
+
+ def axesNames(self, data, info):
+ return ["col", "row"]
+
+ def clear(self):
+ self.getWidget().setArrayData(numpy.array([[]]),
+ labels=True)
+
+ def setData(self, data):
+ data = self.normalizeData(data)
+ # data could be a NXdata or an NXentry
+ nxd = nxdata.get_default(data, validate=False)
+ signal = nxd.signal
+ self.getWidget().setArrayData(signal,
+ labels=True)
+
+ def getDataPriority(self, data, info):
+ data = self.normalizeData(data)
+
+ if info.hasNXdata and not info.isInvalidNXdata:
+ nxd = nxdata.get_default(data, validate=False)
+ if nxd.signal_is_0d or nxd.interpretation in ["scalar", "scaler"]:
+ return 100
+ return DataView.UNSUPPORTED
+
+
+class _NXdataCurveView(_NXdataBaseDataView):
+ """DataView using a Plot1D for displaying NXdata curves:
+ 1-D signal or n-D signal with *@interpretation=spectrum*.
+
+ It also handles basic scatter plots:
+ a 1-D signal with one axis whose values are not monotonically increasing.
+ """
+ def __init__(self, parent):
+ _NXdataBaseDataView.__init__(
+ self, parent, modeId=NXDATA_CURVE_MODE)
+
+ def createWidget(self, parent):
+ from silx.gui.data.NXdataWidgets import ArrayCurvePlot
+ widget = ArrayCurvePlot(parent)
+ return widget
+
+ def axesNames(self, data, info):
+ # disabled (used by default axis selector widget in Hdf5Viewer)
+ return None
+
+ def clear(self):
+ self.getWidget().clear()
+
+ def setData(self, data):
+ data = self.normalizeData(data)
+ nxd = nxdata.get_default(data, validate=False)
+ signals_names = [nxd.signal_name] + nxd.auxiliary_signals_names
+ if nxd.axes_dataset_names[-1] is not None:
+ x_errors = nxd.get_axis_errors(nxd.axes_dataset_names[-1])
+ else:
+ x_errors = None
+
+ # this fix is necessary until the next release of PyMca (5.2.3 or 5.3.0)
+ # see https://github.com/vasole/pymca/issues/144 and https://github.com/vasole/pymca/pull/145
+ if not hasattr(self.getWidget(), "setCurvesData") and \
+ hasattr(self.getWidget(), "setCurveData"):
+ _logger.warning("Using deprecated ArrayCurvePlot API, "
+ "without support of auxiliary signals")
+ self.getWidget().setCurveData(nxd.signal, nxd.axes[-1],
+ yerror=nxd.errors, xerror=x_errors,
+ ylabel=nxd.signal_name, xlabel=nxd.axes_names[-1],
+ title=nxd.title or nxd.signal_name)
+ return
+
+ self.getWidget().setCurvesData([nxd.signal] + nxd.auxiliary_signals, nxd.axes[-1],
+ yerror=nxd.errors, xerror=x_errors,
+ ylabels=signals_names, xlabel=nxd.axes_names[-1],
+ title=nxd.title or signals_names[0],
+ xscale=nxd.plot_style.axes_scale_types[-1],
+ yscale=nxd.plot_style.signal_scale_type)
+
+ def getDataPriority(self, data, info):
+ data = self.normalizeData(data)
+ if info.hasNXdata and not info.isInvalidNXdata:
+ if nxdata.get_default(data, validate=False).is_curve:
+ return 100
+ return DataView.UNSUPPORTED
+
+
+class _NXdataXYVScatterView(_NXdataBaseDataView):
+ """DataView using a Plot1D for displaying NXdata 3D scatters as
+ a scatter of coloured points (1-D signal with 2 axes)"""
+ def __init__(self, parent):
+ _NXdataBaseDataView.__init__(
+ self, parent, modeId=NXDATA_XYVSCATTER_MODE)
+
+ def createWidget(self, parent):
+ from silx.gui.data.NXdataWidgets import XYVScatterPlot
+ widget = XYVScatterPlot(parent)
+ widget.getScatterView().setColormap(self.defaultColormap())
+ widget.getScatterView().getScatterToolBar().getColormapAction().setColorDialog(
+ self.defaultColorDialog())
+ return widget
+
+ def axesNames(self, data, info):
+ # disabled (used by default axis selector widget in Hdf5Viewer)
+ return None
+
+ def clear(self):
+ self.getWidget().clear()
+
+ def setData(self, data):
+ data = self.normalizeData(data)
+ nxd = nxdata.get_default(data, validate=False)
+
+ x_axis, y_axis = nxd.axes[-2:]
+ if x_axis is None:
+ x_axis = numpy.arange(nxd.signal.size)
+ if y_axis is None:
+ y_axis = numpy.arange(nxd.signal.size)
+
+ x_label, y_label = nxd.axes_names[-2:]
+ if x_label is not None:
+ x_errors = nxd.get_axis_errors(x_label)
+ else:
+ x_errors = None
+
+ if y_label is not None:
+ y_errors = nxd.get_axis_errors(y_label)
+ else:
+ y_errors = None
+
+ self._updateColormap(nxd)
+
+ self.getWidget().setScattersData(y_axis, x_axis, values=[nxd.signal] + nxd.auxiliary_signals,
+ yerror=y_errors, xerror=x_errors,
+ ylabel=y_label, xlabel=x_label,
+ title=nxd.title,
+ scatter_titles=[nxd.signal_name] + nxd.auxiliary_signals_names,
+ xscale=nxd.plot_style.axes_scale_types[-2],
+ yscale=nxd.plot_style.axes_scale_types[-1])
+
+ def getDataPriority(self, data, info):
+ data = self.normalizeData(data)
+ if info.hasNXdata and not info.isInvalidNXdata:
+ if nxdata.get_default(data, validate=False).is_x_y_value_scatter:
+ # It have to be a little more than a NX curve priority
+ return 110
+
+ return DataView.UNSUPPORTED
+
+
+class _NXdataImageView(_NXdataBaseDataView):
+ """DataView using a Plot2D for displaying NXdata images:
+ 2-D signal or n-D signals with *@interpretation=image*."""
+ def __init__(self, parent):
+ _NXdataBaseDataView.__init__(
+ self, parent, modeId=NXDATA_IMAGE_MODE)
+
+ def createWidget(self, parent):
+ from silx.gui.data.NXdataWidgets import ArrayImagePlot
+ widget = ArrayImagePlot(parent)
+ widget.getPlot().setDefaultColormap(self.defaultColormap())
+ widget.getPlot().getColormapAction().setColorDialog(self.defaultColorDialog())
+ return widget
+
+ def axesNames(self, data, info):
+ # disabled (used by default axis selector widget in Hdf5Viewer)
+ return None
+
+ def clear(self):
+ self.getWidget().clear()
+
+ def setData(self, data):
+ data = self.normalizeData(data)
+ nxd = nxdata.get_default(data, validate=False)
+ isRgba = nxd.interpretation == "rgba-image"
+
+ self._updateColormap(nxd)
+
+ # last two axes are Y & X
+ img_slicing = slice(-2, None) if not isRgba else slice(-3, -1)
+ y_axis, x_axis = nxd.axes[img_slicing]
+ y_label, x_label = nxd.axes_names[img_slicing]
+ y_scale, x_scale = nxd.plot_style.axes_scale_types[img_slicing]
+
+ self.getWidget().setImageData(
+ [nxd.signal] + nxd.auxiliary_signals,
+ x_axis=x_axis, y_axis=y_axis,
+ signals_names=[nxd.signal_name] + nxd.auxiliary_signals_names,
+ xlabel=x_label, ylabel=y_label,
+ title=nxd.title, isRgba=isRgba,
+ xscale=x_scale, yscale=y_scale)
+
+ def getDataPriority(self, data, info):
+ data = self.normalizeData(data)
+
+ if info.hasNXdata and not info.isInvalidNXdata:
+ if nxdata.get_default(data, validate=False).is_image:
+ return 100
+
+ return DataView.UNSUPPORTED
+
+
+class _NXdataComplexImageView(_NXdataBaseDataView):
+ """DataView using a ComplexImageView for displaying NXdata complex images:
+ 2-D signal or n-D signals with *@interpretation=image*."""
+ def __init__(self, parent):
+ _NXdataBaseDataView.__init__(
+ self, parent, modeId=NXDATA_IMAGE_MODE)
+
+ def createWidget(self, parent):
+ from silx.gui.data.NXdataWidgets import ArrayComplexImagePlot
+ widget = ArrayComplexImagePlot(parent, colormap=self.defaultColormap())
+ widget.getPlot().getColormapAction().setColorDialog(self.defaultColorDialog())
+ return widget
+
+ def clear(self):
+ self.getWidget().clear()
+
+ def setData(self, data):
+ data = self.normalizeData(data)
+ nxd = nxdata.get_default(data, validate=False)
+
+ self._updateColormap(nxd)
+
+ # last two axes are Y & X
+ img_slicing = slice(-2, None)
+ y_axis, x_axis = nxd.axes[img_slicing]
+ y_label, x_label = nxd.axes_names[img_slicing]
+
+ self.getWidget().setImageData(
+ [nxd.signal] + nxd.auxiliary_signals,
+ x_axis=x_axis, y_axis=y_axis,
+ signals_names=[nxd.signal_name] + nxd.auxiliary_signals_names,
+ xlabel=x_label, ylabel=y_label,
+ title=nxd.title)
+
+ def axesNames(self, data, info):
+ # disabled (used by default axis selector widget in Hdf5Viewer)
+ return None
+
+ def getDataPriority(self, data, info):
+ data = self.normalizeData(data)
+
+ if info.hasNXdata and not info.isInvalidNXdata:
+ nxd = nxdata.get_default(data, validate=False)
+ if nxd.is_image and numpy.iscomplexobj(nxd.signal):
+ return 100
+
+ return DataView.UNSUPPORTED
+
+
+class _NXdataStackView(_NXdataBaseDataView):
+ def __init__(self, parent):
+ _NXdataBaseDataView.__init__(
+ self, parent, modeId=NXDATA_STACK_MODE)
+
+ def createWidget(self, parent):
+ from silx.gui.data.NXdataWidgets import ArrayStackPlot
+ widget = ArrayStackPlot(parent)
+ widget.getStackView().setColormap(self.defaultColormap())
+ widget.getStackView().getPlotWidget().getColormapAction().setColorDialog(self.defaultColorDialog())
+ return widget
+
+ def axesNames(self, data, info):
+ # disabled (used by default axis selector widget in Hdf5Viewer)
+ return None
+
+ def clear(self):
+ self.getWidget().clear()
+
+ def setData(self, data):
+ data = self.normalizeData(data)
+ nxd = nxdata.get_default(data, validate=False)
+ signal_name = nxd.signal_name
+ z_axis, y_axis, x_axis = nxd.axes[-3:]
+ z_label, y_label, x_label = nxd.axes_names[-3:]
+ title = nxd.title or signal_name
+
+ self._updateColormap(nxd)
+
+ widget = self.getWidget()
+ widget.setStackData(
+ nxd.signal, x_axis=x_axis, y_axis=y_axis, z_axis=z_axis,
+ signal_name=signal_name,
+ xlabel=x_label, ylabel=y_label, zlabel=z_label,
+ title=title)
+ # Override the colormap, while setStack overwrite it
+ widget.getStackView().setColormap(self.defaultColormap())
+
+ def getDataPriority(self, data, info):
+ data = self.normalizeData(data)
+ if info.hasNXdata and not info.isInvalidNXdata:
+ if nxdata.get_default(data, validate=False).is_stack:
+ return 100
+
+ return DataView.UNSUPPORTED
+
+
+class _NXdataVolumeView(_NXdataBaseDataView):
+ def __init__(self, parent):
+ _NXdataBaseDataView.__init__(
+ self, parent,
+ label="NXdata (3D)",
+ icon=icons.getQIcon("view-nexus"),
+ modeId=NXDATA_VOLUME_MODE)
+ try:
+ import silx.gui.plot3d # noqa
+ except ImportError:
+ _logger.warning("Plot3dView is not available")
+ _logger.debug("Backtrace", exc_info=True)
+ raise
+
+ def normalizeData(self, data):
+ data = super(_NXdataVolumeView, self).normalizeData(data)
+ data = _normalizeComplex(data)
+ return data
+
+ def createWidget(self, parent):
+ from silx.gui.data.NXdataWidgets import ArrayVolumePlot
+ widget = ArrayVolumePlot(parent)
+ return widget
+
+ def axesNames(self, data, info):
+ # disabled (used by default axis selector widget in Hdf5Viewer)
+ return None
+
+ def clear(self):
+ self.getWidget().clear()
+
+ def setData(self, data):
+ data = self.normalizeData(data)
+ nxd = nxdata.get_default(data, validate=False)
+ signal_name = nxd.signal_name
+ z_axis, y_axis, x_axis = nxd.axes[-3:]
+ z_label, y_label, x_label = nxd.axes_names[-3:]
+ title = nxd.title or signal_name
+
+ widget = self.getWidget()
+ widget.setData(
+ nxd.signal, x_axis=x_axis, y_axis=y_axis, z_axis=z_axis,
+ signal_name=signal_name,
+ xlabel=x_label, ylabel=y_label, zlabel=z_label,
+ title=title)
+
+ def getDataPriority(self, data, info):
+ data = self.normalizeData(data)
+ if info.hasNXdata and not info.isInvalidNXdata:
+ if nxdata.get_default(data, validate=False).is_volume:
+ return 150
+
+ return DataView.UNSUPPORTED
+
+
+class _NXdataVolumeAsStackView(_NXdataBaseDataView):
+ def __init__(self, parent):
+ _NXdataBaseDataView.__init__(
+ self, parent,
+ label="NXdata (2D)",
+ icon=icons.getQIcon("view-nexus"),
+ modeId=NXDATA_VOLUME_AS_STACK_MODE)
+
+ def createWidget(self, parent):
+ from silx.gui.data.NXdataWidgets import ArrayStackPlot
+ widget = ArrayStackPlot(parent)
+ widget.getStackView().setColormap(self.defaultColormap())
+ widget.getStackView().getPlotWidget().getColormapAction().setColorDialog(self.defaultColorDialog())
+ return widget
+
+ def axesNames(self, data, info):
+ # disabled (used by default axis selector widget in Hdf5Viewer)
+ return None
+
+ def clear(self):
+ self.getWidget().clear()
+
+ def setData(self, data):
+ data = self.normalizeData(data)
+ nxd = nxdata.get_default(data, validate=False)
+ signal_name = nxd.signal_name
+ z_axis, y_axis, x_axis = nxd.axes[-3:]
+ z_label, y_label, x_label = nxd.axes_names[-3:]
+ title = nxd.title or signal_name
+
+ self._updateColormap(nxd)
+
+ widget = self.getWidget()
+ widget.setStackData(
+ nxd.signal, x_axis=x_axis, y_axis=y_axis, z_axis=z_axis,
+ signal_name=signal_name,
+ xlabel=x_label, ylabel=y_label, zlabel=z_label,
+ title=title)
+ # Override the colormap, while setStack overwrite it
+ widget.getStackView().setColormap(self.defaultColormap())
+
+ def getDataPriority(self, data, info):
+ data = self.normalizeData(data)
+ if info.isComplex:
+ return DataView.UNSUPPORTED
+ if info.hasNXdata and not info.isInvalidNXdata:
+ if nxdata.get_default(data, validate=False).is_volume:
+ return 200
+
+ return DataView.UNSUPPORTED
+
+class _NXdataComplexVolumeAsStackView(_NXdataBaseDataView):
+ def __init__(self, parent):
+ _NXdataBaseDataView.__init__(
+ self, parent,
+ label="NXdata (2D)",
+ icon=icons.getQIcon("view-nexus"),
+ modeId=NXDATA_VOLUME_AS_STACK_MODE)
+ self._is_complex_data = False
+
+ def createWidget(self, parent):
+ from silx.gui.data.NXdataWidgets import ArrayComplexImagePlot
+ widget = ArrayComplexImagePlot(parent, colormap=self.defaultColormap())
+ widget.getPlot().getColormapAction().setColorDialog(self.defaultColorDialog())
+ return widget
+
+ def axesNames(self, data, info):
+ # disabled (used by default axis selector widget in Hdf5Viewer)
+ return None
+
+ def clear(self):
+ self.getWidget().clear()
+
+ def setData(self, data):
+ data = self.normalizeData(data)
+ nxd = nxdata.get_default(data, validate=False)
+ signal_name = nxd.signal_name
+ z_axis, y_axis, x_axis = nxd.axes[-3:]
+ z_label, y_label, x_label = nxd.axes_names[-3:]
+ title = nxd.title or signal_name
+
+ self._updateColormap(nxd)
+
+ self.getWidget().setImageData(
+ [nxd.signal] + nxd.auxiliary_signals,
+ x_axis=x_axis, y_axis=y_axis,
+ signals_names=[nxd.signal_name] + nxd.auxiliary_signals_names,
+ xlabel=x_label, ylabel=y_label, title=nxd.title)
+
+ def getDataPriority(self, data, info):
+ data = self.normalizeData(data)
+ if not info.isComplex:
+ return DataView.UNSUPPORTED
+ if info.hasNXdata and not info.isInvalidNXdata:
+ if nxdata.get_default(data, validate=False).is_volume:
+ return 200
+
+ return DataView.UNSUPPORTED
+
+
+class _NXdataView(CompositeDataView):
+ """Composite view displaying NXdata groups using the most adequate
+ widget depending on the dimensionality."""
+ def __init__(self, parent):
+ super(_NXdataView, self).__init__(
+ parent=parent,
+ label="NXdata",
+ modeId=NXDATA_MODE,
+ icon=icons.getQIcon("view-nexus"))
+
+ self.addView(_InvalidNXdataView(parent))
+ self.addView(_NXdataScalarView(parent))
+ self.addView(_NXdataCurveView(parent))
+ self.addView(_NXdataXYVScatterView(parent))
+ self.addView(_NXdataComplexImageView(parent))
+ self.addView(_NXdataImageView(parent))
+ self.addView(_NXdataStackView(parent))
+
+ # The 3D view can be displayed using 2 ways
+ nx3dViews = SelectManyDataView(parent)
+ nx3dViews.addView(_NXdataVolumeAsStackView(parent))
+ nx3dViews.addView(_NXdataComplexVolumeAsStackView(parent))
+ try:
+ nx3dViews.addView(_NXdataVolumeView(parent))
+ except Exception:
+ _logger.warning("NXdataVolumeView is not available")
+ _logger.debug("Backtrace", exc_info=True)
+ self.addView(nx3dViews)
diff --git a/src/silx/gui/data/Hdf5TableView.py b/src/silx/gui/data/Hdf5TableView.py
new file mode 100644
index 0000000..9d65a84
--- /dev/null
+++ b/src/silx/gui/data/Hdf5TableView.py
@@ -0,0 +1,634 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+This module define model and widget to display 1D slices from numpy
+array using compound data types or hdf5 databases.
+"""
+from __future__ import division
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "12/02/2019"
+
+import collections
+import functools
+import os.path
+import logging
+import h5py
+import numpy
+
+from silx.gui import qt
+import silx.io
+from .TextFormatter import TextFormatter
+import silx.gui.hdf5
+from silx.gui.widgets import HierarchicalTableView
+from ..hdf5.Hdf5Formatter import Hdf5Formatter
+from ..hdf5._utils import htmlFromDict
+
+
+_logger = logging.getLogger(__name__)
+
+
+class _CellData(object):
+ """Store a table item
+ """
+ def __init__(self, value=None, isHeader=False, span=None, tooltip=None):
+ """
+ Constructor
+
+ :param str value: Label of this property
+ :param bool isHeader: True if the cell is an header
+ :param tuple span: Tuple of row, column span
+ """
+ self.__value = value
+ self.__isHeader = isHeader
+ self.__span = span
+ self.__tooltip = tooltip
+
+ def isHeader(self):
+ """Returns true if the property is a sub-header title.
+
+ :rtype: bool
+ """
+ return self.__isHeader
+
+ def value(self):
+ """Returns the value of the item.
+ """
+ return self.__value
+
+ def span(self):
+ """Returns the span size of the cell.
+
+ :rtype: tuple
+ """
+ return self.__span
+
+ def tooltip(self):
+ """Returns the tooltip of the item.
+
+ :rtype: tuple
+ """
+ return self.__tooltip
+
+ def invalidateValue(self):
+ self.__value = None
+
+ def invalidateToolTip(self):
+ self.__tooltip = None
+
+ def data(self, role):
+ return None
+
+
+class _TableData(object):
+ """Modelize a table with header, row and column span.
+
+ It is mostly defined as a row based table.
+ """
+
+ def __init__(self, columnCount):
+ """Constructor.
+
+ :param int columnCount: Define the number of column of the table
+ """
+ self.__colCount = columnCount
+ self.__data = []
+
+ def rowCount(self):
+ """Returns the number of rows.
+
+ :rtype: int
+ """
+ return len(self.__data)
+
+ def columnCount(self):
+ """Returns the number of columns.
+
+ :rtype: int
+ """
+ return self.__colCount
+
+ def clear(self):
+ """Remove all the cells of the table"""
+ self.__data = []
+
+ def cellAt(self, row, column):
+ """Returns the cell at the row column location. Else None if there is
+ nothing.
+
+ :rtype: _CellData
+ """
+ if row < 0:
+ return None
+ if column < 0:
+ return None
+ if row >= len(self.__data):
+ return None
+ cells = self.__data[row]
+ if column >= len(cells):
+ return None
+ return cells[column]
+
+ def addHeaderRow(self, headerLabel):
+ """Append the table with header on the full row.
+
+ :param str headerLabel: label of the header.
+ """
+ item = _CellData(value=headerLabel, isHeader=True, span=(1, self.__colCount))
+ self.__data.append([item])
+
+ def addHeaderValueRow(self, headerLabel, value, tooltip=None):
+ """Append the table with a row using the first column as an header and
+ other cells as a single cell for the value.
+
+ :param str headerLabel: label of the header.
+ :param object value: value to store.
+ """
+ header = _CellData(value=headerLabel, isHeader=True)
+ value = _CellData(value=value, span=(1, self.__colCount), tooltip=tooltip)
+ self.__data.append([header, value])
+
+ def addRow(self, *args):
+ """Append the table with a row using arguments for each cells
+
+ :param list[object] args: List of cell values for the row
+ """
+ row = []
+ for value in args:
+ if not isinstance(value, _CellData):
+ value = _CellData(value=value)
+ row.append(value)
+ self.__data.append(row)
+
+
+class _CellFilterAvailableData(_CellData):
+ """Cell rendering for availability of a filter"""
+
+ _states = {
+ True: ("Available", qt.QColor(0x000000), None, None),
+ False: ("Not available", qt.QColor(0xFFFFFF), qt.QColor(0xFF0000),
+ "You have to install this filter on your system to be able to read this dataset"),
+ "na": ("n.a.", qt.QColor(0x000000), None,
+ "This version of h5py/hdf5 is not able to display the information"),
+ }
+
+ def __init__(self, filterId):
+ if h5py.version.hdf5_version_tuple >= (1, 10, 2):
+ # Previous versions only returns True if the filter was first used
+ # to decode a dataset
+ self.__availability = h5py.h5z.filter_avail(filterId)
+ else:
+ self.__availability = "na"
+ _CellData.__init__(self)
+
+ def value(self):
+ state = self._states[self.__availability]
+ return state[0]
+
+ def tooltip(self):
+ state = self._states[self.__availability]
+ return state[3]
+
+ def data(self, role=qt.Qt.DisplayRole):
+ state = self._states[self.__availability]
+ if role == qt.Qt.ForegroundRole:
+ return state[1]
+ elif role == qt.Qt.BackgroundRole:
+ return state[2]
+ else:
+ return None
+
+
+class Hdf5TableModel(HierarchicalTableView.HierarchicalTableModel):
+ """This data model provides access to HDF5 node content (File, Group,
+ Dataset). Main info, like name, file, attributes... are displayed
+ """
+
+ def __init__(self, parent=None, data=None):
+ """
+ Constructor
+
+ :param qt.QObject parent: Parent object
+ :param object data: An h5py-like object (file, group or dataset)
+ """
+ super(Hdf5TableModel, self).__init__(parent)
+
+ self.__obj = None
+ self.__data = _TableData(columnCount=5)
+ self.__formatter = None
+ self.__hdf5Formatter = Hdf5Formatter(self)
+ formatter = TextFormatter(self)
+ self.setFormatter(formatter)
+ self.setObject(data)
+
+ def rowCount(self, parent_idx=None):
+ """Returns number of rows to be displayed in table"""
+ return self.__data.rowCount()
+
+ def columnCount(self, parent_idx=None):
+ """Returns number of columns to be displayed in table"""
+ return self.__data.columnCount()
+
+ def data(self, index, role=qt.Qt.DisplayRole):
+ """QAbstractTableModel method to access data values
+ in the format ready to be displayed"""
+ if not index.isValid():
+ return None
+
+ cell = self.__data.cellAt(index.row(), index.column())
+ if cell is None:
+ return None
+
+ if role == self.SpanRole:
+ return cell.span()
+ elif role == self.IsHeaderRole:
+ return cell.isHeader()
+ elif role in (qt.Qt.DisplayRole, qt.Qt.EditRole):
+ value = cell.value()
+ if callable(value):
+ try:
+ value = value(self.__obj)
+ except Exception:
+ cell.invalidateValue()
+ raise
+ return value
+ elif role == qt.Qt.ToolTipRole:
+ value = cell.tooltip()
+ if callable(value):
+ try:
+ value = value(self.__obj)
+ except Exception:
+ cell.invalidateToolTip()
+ raise
+ return value
+ else:
+ return cell.data(role)
+ return None
+
+ def isSupportedObject(self, h5pyObject):
+ """
+ Returns true if the provided object can be modelized using this model.
+ """
+ isSupported = False
+ isSupported = isSupported or silx.io.is_group(h5pyObject)
+ isSupported = isSupported or silx.io.is_dataset(h5pyObject)
+ isSupported = isSupported or isinstance(h5pyObject, silx.gui.hdf5.H5Node)
+ return isSupported
+
+ def setObject(self, h5pyObject):
+ """Set the h5py-like object exposed by the model
+
+ :param h5pyObject: A h5py-like object. It can be a `h5py.Dataset`,
+ a `h5py.File`, a `h5py.Group`. It also can be a,
+ `silx.gui.hdf5.H5Node` which is needed to display some local path
+ information.
+ """
+ self.beginResetModel()
+
+ if h5pyObject is None or self.isSupportedObject(h5pyObject):
+ self.__obj = h5pyObject
+ else:
+ _logger.warning("Object class %s unsupported. Object ignored.", type(h5pyObject))
+ self.__initProperties()
+
+ self.endResetModel()
+
+ def __formatHdf5Type(self, dataset):
+ """Format the HDF5 type"""
+ return self.__hdf5Formatter.humanReadableHdf5Type(dataset)
+
+ def __attributeTooltip(self, attribute):
+ attributeDict = collections.OrderedDict()
+ if hasattr(attribute, "shape"):
+ attributeDict["Shape"] = self.__hdf5Formatter.humanReadableShape(attribute)
+ attributeDict["Data type"] = self.__hdf5Formatter.humanReadableType(attribute, full=True)
+ html = htmlFromDict(attributeDict, title="HDF5 Attribute")
+ return html
+
+ def __formatDType(self, dataset):
+ """Format the numpy dtype"""
+ return self.__hdf5Formatter.humanReadableType(dataset, full=True)
+
+ def __formatShape(self, dataset):
+ """Format the shape"""
+ if dataset.shape is None or len(dataset.shape) <= 1:
+ return self.__hdf5Formatter.humanReadableShape(dataset)
+ size = dataset.size
+ shape = self.__hdf5Formatter.humanReadableShape(dataset)
+ return u"%s = %s" % (shape, size)
+
+ def __formatChunks(self, dataset):
+ """Format the shape"""
+ chunks = dataset.chunks
+ if chunks is None:
+ return ""
+ shape = " \u00D7 ".join([str(i) for i in chunks])
+ sizes = numpy.product(chunks)
+ text = "%s = %s" % (shape, sizes)
+ return text
+
+ def __initProperties(self):
+ """Initialize the list of available properties according to the defined
+ h5py-like object."""
+ self.__data.clear()
+ if self.__obj is None:
+ return
+
+ obj = self.__obj
+
+ hdf5obj = obj
+ if isinstance(obj, silx.gui.hdf5.H5Node):
+ hdf5obj = obj.h5py_object
+
+ if silx.io.is_file(hdf5obj):
+ objectType = "File"
+ elif silx.io.is_group(hdf5obj):
+ objectType = "Group"
+ elif silx.io.is_dataset(hdf5obj):
+ objectType = "Dataset"
+ else:
+ objectType = obj.__class__.__name__
+ self.__data.addHeaderRow(headerLabel="HDF5 %s" % objectType)
+
+ SEPARATOR = "::"
+
+ self.__data.addHeaderRow(headerLabel="Path info")
+ showPhysicalLocation = True
+ if isinstance(obj, silx.gui.hdf5.H5Node):
+ # helpful informations if the object come from an HDF5 tree
+ self.__data.addHeaderValueRow("Basename", lambda x: x.local_basename)
+ self.__data.addHeaderValueRow("Name", lambda x: x.local_name)
+ local = lambda x: x.local_filename + SEPARATOR + x.local_name
+ self.__data.addHeaderValueRow("Local", local)
+ else:
+ # it's a real H5py object
+ self.__data.addHeaderValueRow("Basename", lambda x: os.path.basename(x.name))
+ self.__data.addHeaderValueRow("Name", lambda x: x.name)
+ if obj.file is not None:
+ self.__data.addHeaderValueRow("File", lambda x: x.file.filename)
+ if hasattr(obj, "path"):
+ # That's a link
+ if hasattr(obj, "filename"):
+ # External link
+ link = lambda x: x.filename + SEPARATOR + x.path
+ else:
+ # Soft link
+ link = lambda x: x.path
+ self.__data.addHeaderValueRow("Link", link)
+ showPhysicalLocation = False
+
+ # External data (nothing to do with external links)
+ nExtSources = 0
+ firstExtSource = None
+ extType = None
+ if silx.io.is_dataset(hdf5obj):
+ if hasattr(hdf5obj, "is_virtual"):
+ if hdf5obj.is_virtual:
+ extSources = hdf5obj.virtual_sources()
+ if extSources:
+ firstExtSource = extSources[0].file_name + SEPARATOR + extSources[0].dset_name
+ extType = "Virtual"
+ nExtSources = len(extSources)
+ if hasattr(hdf5obj, "external"):
+ extSources = hdf5obj.external
+ if extSources:
+ firstExtSource = extSources[0][0]
+ extType = "Raw"
+ nExtSources = len(extSources)
+
+ if showPhysicalLocation:
+ def _physical_location(x):
+ if isinstance(obj, silx.gui.hdf5.H5Node):
+ return x.physical_filename + SEPARATOR + x.physical_name
+ elif silx.io.is_file(obj):
+ return x.filename + SEPARATOR + x.name
+ elif obj.file is not None:
+ return x.file.filename + SEPARATOR + x.name
+ else:
+ # Guess it is a virtual node
+ return "No physical location"
+
+ self.__data.addHeaderValueRow("Physical", _physical_location)
+
+ if extType:
+ def _first_source(x):
+ # Absolute path
+ if os.path.isabs(firstExtSource):
+ return firstExtSource
+
+ # Relative path with respect to the file directory
+ if isinstance(obj, silx.gui.hdf5.H5Node):
+ filename = x.physical_filename
+ elif silx.io.is_file(obj):
+ filename = x.filename
+ elif obj.file is not None:
+ filename = x.file.filename
+ else:
+ return firstExtSource
+
+ if firstExtSource[0] == ".":
+ firstExtSource.pop(0)
+ return os.path.join(os.path.dirname(filename), firstExtSource)
+
+ self.__data.addHeaderRow(headerLabel="External sources")
+ self.__data.addHeaderValueRow("Type", extType)
+ self.__data.addHeaderValueRow("Count", str(nExtSources))
+ self.__data.addHeaderValueRow("First", _first_source)
+
+ if hasattr(obj, "dtype"):
+
+ self.__data.addHeaderRow(headerLabel="Data info")
+
+ if hasattr(obj, "id") and hasattr(obj.id, "get_type"):
+ # display the HDF5 type
+ self.__data.addHeaderValueRow("HDF5 type", self.__formatHdf5Type)
+ self.__data.addHeaderValueRow("dtype", self.__formatDType)
+ if hasattr(obj, "shape"):
+ self.__data.addHeaderValueRow("shape", self.__formatShape)
+ if hasattr(obj, "chunks") and obj.chunks is not None:
+ self.__data.addHeaderValueRow("chunks", self.__formatChunks)
+
+ # relative to compression
+ # h5py expose compression, compression_opts but are not initialized
+ # for external plugins, then we use id
+ # h5py also expose fletcher32 and shuffle attributes, but it is also
+ # part of the filters
+ if hasattr(obj, "shape") and hasattr(obj, "id"):
+ if hasattr(obj.id, "get_create_plist"):
+ dcpl = obj.id.get_create_plist()
+ if dcpl.get_nfilters() > 0:
+ self.__data.addHeaderRow(headerLabel="Compression info")
+ pos = _CellData(value="Position", isHeader=True)
+ hdf5id = _CellData(value="HDF5 ID", isHeader=True)
+ name = _CellData(value="Name", isHeader=True)
+ options = _CellData(value="Options", isHeader=True)
+ availability = _CellData(value="", isHeader=True)
+ self.__data.addRow(pos, hdf5id, name, options, availability)
+ for index in range(dcpl.get_nfilters()):
+ filterId, name, options = self.__getFilterInfo(obj, index)
+ pos = _CellData(value=str(index))
+ hdf5id = _CellData(value=str(filterId))
+ name = _CellData(value=name)
+ options = _CellData(value=options)
+ availability = _CellFilterAvailableData(filterId=filterId)
+ self.__data.addRow(pos, hdf5id, name, options, availability)
+
+ if hasattr(obj, "attrs"):
+ if len(obj.attrs) > 0:
+ self.__data.addHeaderRow(headerLabel="Attributes")
+ for key in sorted(obj.attrs.keys()):
+ callback = lambda key, x: self.__formatter.toString(x.attrs[key])
+ callbackTooltip = lambda key, x: self.__attributeTooltip(x.attrs[key])
+ self.__data.addHeaderValueRow(headerLabel=key,
+ value=functools.partial(callback, key),
+ tooltip=functools.partial(callbackTooltip, key))
+
+ def __getFilterInfo(self, dataset, filterIndex):
+ """Get a tuple of readable info from dataset filters
+
+ :param h5py.Dataset dataset: A h5py dataset
+ :param int filterId:
+ """
+ try:
+ dcpl = dataset.id.get_create_plist()
+ info = dcpl.get_filter(filterIndex)
+ filterId, _flags, cdValues, name = info
+ name = self.__formatter.toString(name)
+ options = " ".join([self.__formatter.toString(i) for i in cdValues])
+ return (filterId, name, options)
+ except Exception:
+ _logger.debug("Backtrace", exc_info=True)
+ return (None, None, None)
+
+ def object(self):
+ """Returns the internal object modelized.
+
+ :rtype: An h5py-like object
+ """
+ return self.__obj
+
+ def setFormatter(self, formatter):
+ """Set the formatter object to be used to display data from the model
+
+ :param TextFormatter formatter: Formatter to use
+ """
+ if formatter is self.__formatter:
+ return
+
+ self.__hdf5Formatter.setTextFormatter(formatter)
+
+ self.beginResetModel()
+
+ if self.__formatter is not None:
+ self.__formatter.formatChanged.disconnect(self.__formatChanged)
+
+ self.__formatter = formatter
+ if self.__formatter is not None:
+ self.__formatter.formatChanged.connect(self.__formatChanged)
+
+ self.endResetModel()
+
+ def getFormatter(self):
+ """Returns the text formatter used.
+
+ :rtype: TextFormatter
+ """
+ return self.__formatter
+
+ def __formatChanged(self):
+ """Called when the format changed.
+ """
+ self.reset()
+
+
+class Hdf5TableItemDelegate(HierarchicalTableView.HierarchicalItemDelegate):
+ """Item delegate the :class:`Hdf5TableView` with read-only text editor"""
+
+ def createEditor(self, parent, option, index):
+ """See :meth:`QStyledItemDelegate.createEditor`"""
+ editor = super().createEditor(parent, option, index)
+ if isinstance(editor, qt.QLineEdit):
+ editor.setReadOnly(True)
+ editor.deselect()
+ editor.textChanged.connect(self.__textChanged, qt.Qt.QueuedConnection)
+ self.installEventFilter(editor)
+ return editor
+
+ def __textChanged(self, text):
+ sender = self.sender()
+ if sender is not None:
+ sender.deselect()
+
+ def eventFilter(self, watched, event):
+ eventType = event.type()
+ if eventType == qt.QEvent.FocusIn:
+ watched.selectAll()
+ qt.QTimer.singleShot(0, watched.selectAll)
+ elif eventType == qt.QEvent.FocusOut:
+ watched.deselect()
+ return super().eventFilter(watched, event)
+
+
+class Hdf5TableView(HierarchicalTableView.HierarchicalTableView):
+ """A widget to display metadata about a HDF5 node using a table."""
+
+ def __init__(self, parent=None):
+ super(Hdf5TableView, self).__init__(parent)
+ self.setModel(Hdf5TableModel(self))
+ self.setItemDelegate(Hdf5TableItemDelegate(self))
+ self.setSelectionMode(qt.QAbstractItemView.NoSelection)
+
+ def isSupportedData(self, data):
+ """
+ Returns true if the provided object can be modelized using this model.
+ """
+ return self.model().isSupportedObject(data)
+
+ def setData(self, data):
+ """Set the h5py-like object exposed by the model
+
+ :param data: A h5py-like object. It can be a `h5py.Dataset`,
+ a `h5py.File`, a `h5py.Group`. It also can be a,
+ `silx.gui.hdf5.H5Node` which is needed to display some local path
+ information.
+ """
+ model = self.model()
+
+ model.setObject(data)
+ header = self.horizontalHeader()
+ header.setSectionResizeMode(0, qt.QHeaderView.Fixed)
+ header.setSectionResizeMode(1, qt.QHeaderView.ResizeToContents)
+ header.setSectionResizeMode(2, qt.QHeaderView.Stretch)
+ header.setSectionResizeMode(3, qt.QHeaderView.ResizeToContents)
+ header.setSectionResizeMode(4, qt.QHeaderView.ResizeToContents)
+ header.setStretchLastSection(False)
+
+ for row in range(model.rowCount()):
+ for column in range(model.columnCount()):
+ index = model.index(row, column)
+ if (index.isValid() and index.data(
+ HierarchicalTableView.HierarchicalTableModel.IsHeaderRole) is False):
+ self.openPersistentEditor(index)
diff --git a/src/silx/gui/data/HexaTableView.py b/src/silx/gui/data/HexaTableView.py
new file mode 100644
index 0000000..9e00a7b
--- /dev/null
+++ b/src/silx/gui/data/HexaTableView.py
@@ -0,0 +1,272 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+This module defines model and widget to display raw data using an
+hexadecimal viewer.
+"""
+from __future__ import division
+
+import collections
+
+import numpy
+
+from silx.gui import qt
+import silx.io.utils
+from silx.gui.widgets.TableWidget import CopySelectedCellsAction
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "23/05/2018"
+
+
+class _VoidConnector(object):
+ """Byte connector to a numpy.void data.
+
+ It uses a cache of 32 x 1KB and a direct read access API from HDF5.
+ """
+
+ def __init__(self, data):
+ self.__cache = collections.OrderedDict()
+ self.__len = data.itemsize
+ self.__data = data
+
+ def __getBuffer(self, bufferId):
+ if bufferId not in self.__cache:
+ pos = bufferId << 10
+ data = self.__data
+ if hasattr(data, "tobytes"):
+ data = data.tobytes()[pos:pos + 1024]
+ else:
+ # Old fashion
+ data = data.data[pos:pos + 1024]
+
+ self.__cache[bufferId] = data
+ if len(self.__cache) > 32:
+ self.__cache.popitem()
+ else:
+ data = self.__cache[bufferId]
+ return data
+
+ def __getitem__(self, pos):
+ """Returns the value of the byte at the given position.
+
+ :param uint pos: Position of the byte
+ :rtype: int
+ """
+ bufferId = pos >> 10
+ bufferPos = pos & 0b1111111111
+ data = self.__getBuffer(bufferId)
+ return data[bufferPos]
+
+ def __len__(self):
+ """
+ Returns the number of available bytes.
+
+ :rtype: uint
+ """
+ return self.__len
+
+
+class HexaTableModel(qt.QAbstractTableModel):
+ """This data model provides access to a numpy void data.
+
+ Bytes are displayed one by one as a hexadecimal viewer.
+
+ The 16th first columns display bytes as hexadecimal, the last column
+ displays the same data as ASCII.
+
+ :param qt.QObject parent: Parent object
+ :param data: A numpy array or a h5py dataset
+ """
+ def __init__(self, parent=None, data=None):
+ qt.QAbstractTableModel.__init__(self, parent)
+
+ self.__data = None
+ self.__connector = None
+ self.setArrayData(data)
+
+ if hasattr(qt.QFontDatabase, "systemFont"): # Qt >= 5.2
+ self.__font = qt.QFontDatabase.systemFont(qt.QFontDatabase.FixedFont)
+ else: # Qt < 5.2
+ self.__font = qt.QFont("Monospace")
+ self.__font.setStyleHint(qt.QFont.TypeWriter)
+ self.__palette = qt.QPalette()
+
+ def rowCount(self, parent_idx=None):
+ """Returns number of rows to be displayed in table"""
+ if self.__connector is None:
+ return 0
+ return ((len(self.__connector) - 1) >> 4) + 1
+
+ def columnCount(self, parent_idx=None):
+ """Returns number of columns to be displayed in table"""
+ return 0x10 + 1
+
+ def data(self, index, role=qt.Qt.DisplayRole):
+ """QAbstractTableModel method to access data values
+ in the format ready to be displayed"""
+ if not index.isValid():
+ return None
+
+ if self.__connector is None:
+ return None
+
+ row = index.row()
+ column = index.column()
+
+ if role == qt.Qt.DisplayRole:
+ if column == 0x10:
+ start = (row << 4)
+ text = ""
+ for i in range(0x10):
+ pos = start + i
+ if pos >= len(self.__connector):
+ break
+ value = self.__connector[pos]
+ if value > 0x20 and value < 0x7F:
+ text += chr(value)
+ else:
+ text += "."
+ return text
+ else:
+ pos = (row << 4) + column
+ if pos < len(self.__connector):
+ value = self.__connector[pos]
+ return "%02X" % value
+ else:
+ return ""
+ elif role == qt.Qt.FontRole:
+ return self.__font
+
+ elif role == qt.Qt.BackgroundRole:
+ pos = (row << 4) + column
+ if column != 0x10 and pos >= len(self.__connector):
+ return self.__palette.color(qt.QPalette.Disabled, qt.QPalette.Window)
+ else:
+ return None
+
+ return None
+
+ def headerData(self, section, orientation, role=qt.Qt.DisplayRole):
+ """Returns the 0-based row or column index, for display in the
+ horizontal and vertical headers"""
+ if section == -1:
+ # PyQt4 send -1 when there is columns but no rows
+ return None
+
+ if role == qt.Qt.DisplayRole:
+ if orientation == qt.Qt.Vertical:
+ return "%02X" % (section << 4)
+ if orientation == qt.Qt.Horizontal:
+ if section == 0x10:
+ return "ASCII"
+ else:
+ return "%02X" % section
+ elif role == qt.Qt.FontRole:
+ return self.__font
+ elif role == qt.Qt.TextAlignmentRole:
+ if orientation == qt.Qt.Vertical:
+ return qt.Qt.AlignRight
+ if orientation == qt.Qt.Horizontal:
+ if section == 0x10:
+ return qt.Qt.AlignLeft
+ else:
+ return qt.Qt.AlignCenter
+ return None
+
+ def flags(self, index):
+ """QAbstractTableModel method to inform the view whether data
+ is editable or not.
+ """
+ row = index.row()
+ column = index.column()
+ pos = (row << 4) + column
+ if column != 0x10 and pos >= len(self.__connector):
+ return qt.Qt.NoItemFlags
+ return qt.QAbstractTableModel.flags(self, index)
+
+ def setArrayData(self, data):
+ """Set the data array.
+
+ :param data: A numpy object or a dataset.
+ """
+ self.beginResetModel()
+
+ self.__connector = None
+ self.__data = data
+ if self.__data is not None:
+ if silx.io.utils.is_dataset(self.__data):
+ data = data[()]
+ elif isinstance(self.__data, numpy.ndarray):
+ data = data[()]
+ self.__connector = _VoidConnector(data)
+
+ self.endResetModel()
+
+ def arrayData(self):
+ """Returns the internal data.
+
+ :rtype: numpy.ndarray of h5py.Dataset
+ """
+ return self.__data
+
+
+class HexaTableView(qt.QTableView):
+ """TableView using HexaTableModel as default model.
+
+ It customs the column size to provide a better layout.
+ """
+ def __init__(self, parent=None):
+ """
+ Constructor
+
+ :param qt.QWidget parent: parent QWidget
+ """
+ qt.QTableView.__init__(self, parent)
+
+ model = HexaTableModel(self)
+ self.setModel(model)
+ self._copyAction = CopySelectedCellsAction(self)
+ self.addAction(self._copyAction)
+
+ def copy(self):
+ self._copyAction.trigger()
+
+ def setArrayData(self, data):
+ """Set the data array.
+
+ :param data: A numpy object or a dataset.
+ """
+ self.model().setArrayData(data)
+ self.__fixHeader()
+
+ def __fixHeader(self):
+ """Update the view according to the state of the auto-resize"""
+ header = self.horizontalHeader()
+ header.setDefaultSectionSize(30)
+ header.setStretchLastSection(True)
+ for i in range(0x10):
+ header.setSectionResizeMode(i, qt.QHeaderView.Fixed)
+ header.setSectionResizeMode(0x10, qt.QHeaderView.Stretch)
diff --git a/src/silx/gui/data/NXdataWidgets.py b/src/silx/gui/data/NXdataWidgets.py
new file mode 100644
index 0000000..54ea287
--- /dev/null
+++ b/src/silx/gui/data/NXdataWidgets.py
@@ -0,0 +1,1086 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module defines widgets used by _NXdataView.
+"""
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "12/11/2018"
+
+import logging
+import numpy
+
+from silx.gui import qt
+from silx.gui.data.NumpyAxesSelector import NumpyAxesSelector
+from silx.gui.plot import Plot1D, Plot2D, StackView, ScatterView, items
+from silx.gui.plot.ComplexImageView import ComplexImageView
+from silx.gui.colors import Colormap
+from silx.gui.widgets.FrameBrowser import HorizontalSliderWithBrowser
+
+from silx.math.calibration import ArrayCalibration, NoCalibration, LinearCalibration
+
+
+_logger = logging.getLogger(__name__)
+
+
+class ArrayCurvePlot(qt.QWidget):
+ """
+ Widget for plotting a curve from a multi-dimensional signal array
+ and a 1D axis array.
+
+ The signal array can have an arbitrary number of dimensions, the only
+ limitation being that the last dimension must have the same length as
+ the axis array.
+
+ The widget provides sliders to select indices on the first (n - 1)
+ dimensions of the signal array, and buttons to add/replace selected
+ curves to the plot.
+
+ This widget also handles simple 2D or 3D scatter plots (third dimension
+ displayed as colour of points).
+ """
+ def __init__(self, parent=None):
+ """
+
+ :param parent: Parent QWidget
+ """
+ super(ArrayCurvePlot, self).__init__(parent)
+
+ self.__signals = None
+ self.__signals_names = None
+ self.__signal_errors = None
+ self.__axis = None
+ self.__axis_name = None
+ self.__x_axis_errors = None
+ self.__values = None
+
+ self._plot = Plot1D(self)
+
+ self._selector = NumpyAxesSelector(self)
+ self._selector.setNamedAxesSelectorVisibility(False)
+ self.__selector_is_connected = False
+
+ self._plot.sigActiveCurveChanged.connect(self._setYLabelFromActiveLegend)
+
+ layout = qt.QVBoxLayout()
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.addWidget(self._plot)
+ layout.addWidget(self._selector)
+
+ self.setLayout(layout)
+
+ def getPlot(self):
+ """Returns the plot used for the display
+
+ :rtype: Plot1D
+ """
+ return self._plot
+
+ def setCurvesData(self, ys, x=None,
+ yerror=None, xerror=None,
+ ylabels=None, xlabel=None, title=None,
+ xscale=None, yscale=None):
+ """
+
+ :param List[ndarray] ys: List of arrays to be represented by the y (vertical) axis.
+ It can be multiple n-D array whose last dimension must
+ have the same length as x (and values must be None)
+ :param ndarray x: 1-D dataset used as the curve's x values. If provided,
+ its lengths must be equal to the length of the last dimension of
+ ``y`` (and equal to the length of ``value``, for a scatter plot).
+ :param ndarray yerror: Single array of errors for y (same shape), or None.
+ There can only be one array, and it applies to the first/main y
+ (no y errors for auxiliary_signals curves).
+ :param ndarray xerror: 1-D dataset of errors for x, or None
+ :param str ylabels: Labels for each curve's Y axis
+ :param str xlabel: Label for X axis
+ :param str title: Graph title
+ :param str xscale: Scale of X axis in (None, 'linear', 'log')
+ :param str yscale: Scale of Y axis in (None, 'linear', 'log')
+ """
+ self.__signals = ys
+ self.__signals_names = ylabels or (["Y"] * len(ys))
+ self.__signal_errors = yerror
+ self.__axis = x
+ self.__axis_name = xlabel
+ self.__x_axis_errors = xerror
+
+ if self.__selector_is_connected:
+ self._selector.selectionChanged.disconnect(self._updateCurve)
+ self.__selector_is_connected = False
+ self._selector.setData(ys[0])
+ self._selector.setAxisNames(["Y"])
+
+ if len(ys[0].shape) < 2:
+ self._selector.hide()
+ else:
+ self._selector.show()
+
+ self._plot.setGraphTitle(title or "")
+ if xscale is not None:
+ self._plot.getXAxis().setScale(
+ 'log' if xscale == 'log' else 'linear')
+ if yscale is not None:
+ self._plot.getYAxis().setScale(
+ 'log' if yscale == 'log' else 'linear')
+ self._updateCurve()
+
+ if not self.__selector_is_connected:
+ self._selector.selectionChanged.connect(self._updateCurve)
+ self.__selector_is_connected = True
+
+ def _updateCurve(self):
+ selection = self._selector.selection()
+ ys = [sig[selection] for sig in self.__signals]
+ y0 = ys[0]
+ len_y = len(y0)
+ x = self.__axis
+ if x is None:
+ x = numpy.arange(len_y)
+ elif numpy.isscalar(x) or len(x) == 1:
+ # constant axis
+ x = x * numpy.ones_like(y0)
+ elif len(x) == 2 and len_y != 2:
+ # linear calibration a + b * x
+ x = x[0] + x[1] * numpy.arange(len_y)
+
+ # Only remove curves that will no longer belong to the plot
+ # So remaining curves keep their settings
+ for item in self._plot.getItems():
+ if (isinstance(item, items.Curve) and
+ item.getName() not in self.__signals_names):
+ self._plot.remove(item)
+
+ for i in range(len(self.__signals)):
+ legend = self.__signals_names[i]
+
+ # errors only supported for primary signal in NXdata
+ y_errors = None
+ if i == 0 and self.__signal_errors is not None:
+ y_errors = self.__signal_errors[self._selector.selection()]
+ self._plot.addCurve(x, ys[i], legend=legend,
+ xerror=self.__x_axis_errors,
+ yerror=y_errors)
+ if i == 0:
+ self._plot.setActiveCurve(legend)
+
+ self._plot.resetZoom()
+ self._plot.getXAxis().setLabel(self.__axis_name)
+ self._plot.getYAxis().setLabel(self.__signals_names[0])
+
+ def _setYLabelFromActiveLegend(self, previous_legend, new_legend):
+ for ylabel in self.__signals_names:
+ if new_legend is not None and new_legend == ylabel:
+ self._plot.getYAxis().setLabel(ylabel)
+ break
+
+ def clear(self):
+ old = self._selector.blockSignals(True)
+ self._selector.clear()
+ self._selector.blockSignals(old)
+ self._plot.clear()
+
+
+class XYVScatterPlot(qt.QWidget):
+ """
+ Widget for plotting one or more scatters
+ (with identical x, y coordinates).
+ """
+ def __init__(self, parent=None):
+ """
+
+ :param parent: Parent QWidget
+ """
+ super(XYVScatterPlot, self).__init__(parent)
+
+ self.__y_axis = None
+ """1D array"""
+ self.__y_axis_name = None
+ self.__values = None
+ """List of 1D arrays (for multiple scatters with identical
+ x, y coordinates)"""
+
+ self.__x_axis = None
+ self.__x_axis_name = None
+ self.__x_axis_errors = None
+ self.__y_axis = None
+ self.__y_axis_name = None
+ self.__y_axis_errors = None
+
+ self._plot = ScatterView(self)
+ self._plot.setColormap(Colormap(name="viridis",
+ vmin=None, vmax=None,
+ normalization=Colormap.LINEAR))
+
+ self._slider = HorizontalSliderWithBrowser(parent=self)
+ self._slider.setMinimum(0)
+ self._slider.setValue(0)
+ self._slider.valueChanged[int].connect(self._sliderIdxChanged)
+ self._slider.setToolTip("Select auxiliary signals")
+
+ layout = qt.QGridLayout()
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.addWidget(self._plot, 0, 0)
+ layout.addWidget(self._slider, 1, 0)
+
+ self.setLayout(layout)
+
+ def _sliderIdxChanged(self, value):
+ self._updateScatter()
+
+ def getScatterView(self):
+ """Returns the :class:`ScatterView` used for the display
+
+ :rtype: ScatterView
+ """
+ return self._plot
+
+ def getPlot(self):
+ """Returns the plot used for the display
+
+ :rtype: PlotWidget
+ """
+ return self._plot.getPlotWidget()
+
+ def setScattersData(self, y, x, values,
+ yerror=None, xerror=None,
+ ylabel=None, xlabel=None,
+ title="", scatter_titles=None,
+ xscale=None, yscale=None):
+ """
+
+ :param ndarray y: 1D array for y (vertical) coordinates.
+ :param ndarray x: 1D array for x coordinates.
+ :param List[ndarray] values: List of 1D arrays of values.
+ This will be used to compute the color map and assign colors
+ to the points. There should be as many arrays in the list as
+ scatters to be represented.
+ :param ndarray yerror: 1D array of errors for y (same shape), or None.
+ :param ndarray xerror: 1D array of errors for x, or None
+ :param str ylabel: Label for Y axis
+ :param str xlabel: Label for X axis
+ :param str title: Main graph title
+ :param List[str] scatter_titles: Subtitles (one per scatter)
+ :param str xscale: Scale of X axis in (None, 'linear', 'log')
+ :param str yscale: Scale of Y axis in (None, 'linear', 'log')
+ """
+ self.__y_axis = y
+ self.__x_axis = x
+ self.__x_axis_name = xlabel or "X"
+ self.__y_axis_name = ylabel or "Y"
+ self.__x_axis_errors = xerror
+ self.__y_axis_errors = yerror
+ self.__values = values
+
+ self.__graph_title = title or ""
+ self.__scatter_titles = scatter_titles
+
+ self._slider.valueChanged[int].disconnect(self._sliderIdxChanged)
+ self._slider.setMaximum(len(values) - 1)
+ if len(values) > 1:
+ self._slider.show()
+ else:
+ self._slider.hide()
+ self._slider.setValue(0)
+ self._slider.valueChanged[int].connect(self._sliderIdxChanged)
+
+ if xscale is not None:
+ self._plot.getXAxis().setScale(
+ 'log' if xscale == 'log' else 'linear')
+ if yscale is not None:
+ self._plot.getYAxis().setScale(
+ 'log' if yscale == 'log' else 'linear')
+
+ self._updateScatter()
+
+ def _updateScatter(self):
+ x = self.__x_axis
+ y = self.__y_axis
+
+ idx = self._slider.value()
+
+ if self.__graph_title:
+ title = self.__graph_title # main NXdata @title
+ if len(self.__scatter_titles) > 1:
+ # Append dataset name only when there is many datasets
+ title += '\n' + self.__scatter_titles[idx]
+ else:
+ title = self.__scatter_titles[idx] # scatter dataset name
+
+ self._plot.setGraphTitle(title)
+ self._plot.setData(x, y, self.__values[idx],
+ xerror=self.__x_axis_errors,
+ yerror=self.__y_axis_errors)
+ self._plot.resetZoom()
+ self._plot.getXAxis().setLabel(self.__x_axis_name)
+ self._plot.getYAxis().setLabel(self.__y_axis_name)
+
+ def clear(self):
+ self._plot.getPlotWidget().clear()
+
+
+class ArrayImagePlot(qt.QWidget):
+ """
+ Widget for plotting an image from a multi-dimensional signal array
+ and two 1D axes array.
+
+ The signal array can have an arbitrary number of dimensions, the only
+ limitation being that the last two dimensions must have the same length as
+ the axes arrays.
+
+ Sliders are provided to select indices on the first (n - 2) dimensions of
+ the signal array, and the plot is updated to show the image corresponding
+ to the selection.
+
+ If one or both of the axes does not have regularly spaced values, the
+ the image is plotted as a coloured scatter plot.
+ """
+ def __init__(self, parent=None):
+ """
+
+ :param parent: Parent QWidget
+ """
+ super(ArrayImagePlot, self).__init__(parent)
+
+ self.__signals = None
+ self.__signals_names = None
+ self.__x_axis = None
+ self.__x_axis_name = None
+ self.__y_axis = None
+ self.__y_axis_name = None
+
+ self._plot = Plot2D(self)
+ self._plot.setDefaultColormap(Colormap(name="viridis",
+ vmin=None, vmax=None,
+ normalization=Colormap.LINEAR))
+ self._plot.getIntensityHistogramAction().setVisible(True)
+ self._plot.setKeepDataAspectRatio(True)
+ maskToolWidget = self._plot.getMaskToolsDockWidget().widget()
+ maskToolWidget.setItemMaskUpdated(True)
+
+ # not closable
+ self._selector = NumpyAxesSelector(self)
+ self._selector.setNamedAxesSelectorVisibility(False)
+ self._selector.selectionChanged.connect(self._updateImage)
+
+ self._auxSigSlider = HorizontalSliderWithBrowser(parent=self)
+ self._auxSigSlider.setMinimum(0)
+ self._auxSigSlider.setValue(0)
+ self._auxSigSlider.valueChanged[int].connect(self._sliderIdxChanged)
+ self._auxSigSlider.setToolTip("Select auxiliary signals")
+
+ layout = qt.QVBoxLayout()
+ layout.addWidget(self._plot)
+ layout.addWidget(self._selector)
+ layout.addWidget(self._auxSigSlider)
+
+ self.setLayout(layout)
+
+ def _sliderIdxChanged(self, value):
+ self._updateImage()
+
+ def getPlot(self):
+ """Returns the plot used for the display
+
+ :rtype: Plot2D
+ """
+ return self._plot
+
+ def setImageData(self, signals,
+ x_axis=None, y_axis=None,
+ signals_names=None,
+ xlabel=None, ylabel=None,
+ title=None, isRgba=False,
+ xscale=None, yscale=None):
+ """
+
+ :param signals: list of n-D datasets, whose last 2 dimensions are used as the
+ image's values, or list of 3D datasets interpreted as RGBA image.
+ :param x_axis: 1-D dataset used as the image's x coordinates. If
+ provided, its lengths must be equal to the length of the last
+ dimension of ``signal``.
+ :param y_axis: 1-D dataset used as the image's y. If provided,
+ its lengths must be equal to the length of the 2nd to last
+ dimension of ``signal``.
+ :param signals_names: Names for each image, used as subtitle and legend.
+ :param xlabel: Label for X axis
+ :param ylabel: Label for Y axis
+ :param title: Graph title
+ :param isRgba: True if data is a 3D RGBA image
+ :param str xscale: Scale of X axis in (None, 'linear', 'log')
+ :param str yscale: Scale of Y axis in (None, 'linear', 'log')
+ """
+ self._selector.selectionChanged.disconnect(self._updateImage)
+ self._auxSigSlider.valueChanged.disconnect(self._sliderIdxChanged)
+
+ self.__signals = signals
+ self.__signals_names = signals_names
+ self.__x_axis = x_axis
+ self.__x_axis_name = xlabel
+ self.__y_axis = y_axis
+ self.__y_axis_name = ylabel
+ self.__title = title
+
+ self._selector.clear()
+ if not isRgba:
+ self._selector.setAxisNames(["Y", "X"])
+ img_ndim = 2
+ else:
+ self._selector.setAxisNames(["Y", "X", "RGB(A) channel"])
+ img_ndim = 3
+ self._selector.setData(signals[0])
+
+ if len(signals[0].shape) <= img_ndim:
+ self._selector.hide()
+ else:
+ self._selector.show()
+
+ self._auxSigSlider.setMaximum(len(signals) - 1)
+ if len(signals) > 1:
+ self._auxSigSlider.show()
+ else:
+ self._auxSigSlider.hide()
+ self._auxSigSlider.setValue(0)
+
+ self._axis_scales = xscale, yscale
+ self._updateImage()
+ self._plot.resetZoom()
+
+ self._selector.selectionChanged.connect(self._updateImage)
+ self._auxSigSlider.valueChanged.connect(self._sliderIdxChanged)
+
+ def _updateImage(self):
+ selection = self._selector.selection()
+ auxSigIdx = self._auxSigSlider.value()
+
+ legend = self.__signals_names[auxSigIdx]
+
+ images = [img[selection] for img in self.__signals]
+ image = images[auxSigIdx]
+
+ x_axis = self.__x_axis
+ y_axis = self.__y_axis
+
+ if x_axis is None and y_axis is None:
+ xcalib = NoCalibration()
+ ycalib = NoCalibration()
+ else:
+ if x_axis is None:
+ # no calibration
+ x_axis = numpy.arange(image.shape[1])
+ elif numpy.isscalar(x_axis) or len(x_axis) == 1:
+ # constant axis
+ x_axis = x_axis * numpy.ones((image.shape[1], ))
+ elif len(x_axis) == 2:
+ # linear calibration
+ x_axis = x_axis[0] * numpy.arange(image.shape[1]) + x_axis[1]
+
+ if y_axis is None:
+ y_axis = numpy.arange(image.shape[0])
+ elif numpy.isscalar(y_axis) or len(y_axis) == 1:
+ y_axis = y_axis * numpy.ones((image.shape[0], ))
+ elif len(y_axis) == 2:
+ y_axis = y_axis[0] * numpy.arange(image.shape[0]) + y_axis[1]
+
+ xcalib = ArrayCalibration(x_axis)
+ ycalib = ArrayCalibration(y_axis)
+
+ self._plot.remove(kind=("scatter", "image",))
+ if xcalib.is_affine() and ycalib.is_affine():
+ # regular image
+ xorigin, xscale = xcalib(0), xcalib.get_slope()
+ yorigin, yscale = ycalib(0), ycalib.get_slope()
+ origin = (xorigin, yorigin)
+ scale = (xscale, yscale)
+
+ self._plot.getXAxis().setScale('linear')
+ self._plot.getYAxis().setScale('linear')
+ self._plot.addImage(image, legend=legend,
+ origin=origin, scale=scale,
+ replace=True, resetzoom=False)
+ else:
+ xaxisscale, yaxisscale = self._axis_scales
+
+ if xaxisscale is not None:
+ self._plot.getXAxis().setScale(
+ 'log' if xaxisscale == 'log' else 'linear')
+ if yaxisscale is not None:
+ self._plot.getYAxis().setScale(
+ 'log' if yaxisscale == 'log' else 'linear')
+
+ scatterx, scattery = numpy.meshgrid(x_axis, y_axis)
+ # fixme: i don't think this can handle "irregular" RGBA images
+ self._plot.addScatter(numpy.ravel(scatterx),
+ numpy.ravel(scattery),
+ numpy.ravel(image),
+ legend=legend)
+
+ if self.__title:
+ title = self.__title
+ if len(self.__signals_names) > 1:
+ # Append dataset name only when there is many datasets
+ title += '\n' + self.__signals_names[auxSigIdx]
+ else:
+ title = self.__signals_names[auxSigIdx]
+ self._plot.setGraphTitle(title)
+ self._plot.getXAxis().setLabel(self.__x_axis_name)
+ self._plot.getYAxis().setLabel(self.__y_axis_name)
+
+ def clear(self):
+ old = self._selector.blockSignals(True)
+ self._selector.clear()
+ self._selector.blockSignals(old)
+ self._plot.clear()
+
+
+class ArrayComplexImagePlot(qt.QWidget):
+ """
+ Widget for plotting an image of complex from a multi-dimensional signal array
+ and two 1D axes array.
+
+ The signal array can have an arbitrary number of dimensions, the only
+ limitation being that the last two dimensions must have the same length as
+ the axes arrays.
+
+ Sliders are provided to select indices on the first (n - 2) dimensions of
+ the signal array, and the plot is updated to show the image corresponding
+ to the selection.
+
+ If one or both of the axes does not have regularly spaced values, the
+ the image is plotted as a coloured scatter plot.
+ """
+ def __init__(self, parent=None, colormap=None):
+ """
+
+ :param parent: Parent QWidget
+ """
+ super(ArrayComplexImagePlot, self).__init__(parent)
+
+ self.__signals = None
+ self.__signals_names = None
+ self.__x_axis = None
+ self.__x_axis_name = None
+ self.__y_axis = None
+ self.__y_axis_name = None
+
+ self._plot = ComplexImageView(self)
+ if colormap is not None:
+ for mode in (ComplexImageView.ComplexMode.ABSOLUTE,
+ ComplexImageView.ComplexMode.SQUARE_AMPLITUDE,
+ ComplexImageView.ComplexMode.REAL,
+ ComplexImageView.ComplexMode.IMAGINARY):
+ self._plot.setColormap(colormap, mode)
+
+ self._plot.getPlot().getIntensityHistogramAction().setVisible(True)
+ self._plot.setKeepDataAspectRatio(True)
+ maskToolWidget = self._plot.getPlot().getMaskToolsDockWidget().widget()
+ maskToolWidget.setItemMaskUpdated(True)
+
+ # not closable
+ self._selector = NumpyAxesSelector(self)
+ self._selector.setNamedAxesSelectorVisibility(False)
+ self._selector.selectionChanged.connect(self._updateImage)
+
+ self._auxSigSlider = HorizontalSliderWithBrowser(parent=self)
+ self._auxSigSlider.setMinimum(0)
+ self._auxSigSlider.setValue(0)
+ self._auxSigSlider.valueChanged[int].connect(self._sliderIdxChanged)
+ self._auxSigSlider.setToolTip("Select auxiliary signals")
+
+ layout = qt.QVBoxLayout()
+ layout.addWidget(self._plot)
+ layout.addWidget(self._selector)
+ layout.addWidget(self._auxSigSlider)
+
+ self.setLayout(layout)
+
+ def _sliderIdxChanged(self, value):
+ self._updateImage()
+
+ def getPlot(self):
+ """Returns the plot used for the display
+
+ :rtype: PlotWidget
+ """
+ return self._plot.getPlot()
+
+ def setImageData(self, signals,
+ x_axis=None, y_axis=None,
+ signals_names=None,
+ xlabel=None, ylabel=None,
+ title=None):
+ """
+
+ :param signals: list of n-D datasets, whose last 2 dimensions are used as the
+ image's values, or list of 3D datasets interpreted as RGBA image.
+ :param x_axis: 1-D dataset used as the image's x coordinates. If
+ provided, its lengths must be equal to the length of the last
+ dimension of ``signal``.
+ :param y_axis: 1-D dataset used as the image's y. If provided,
+ its lengths must be equal to the length of the 2nd to last
+ dimension of ``signal``.
+ :param signals_names: Names for each image, used as subtitle and legend.
+ :param xlabel: Label for X axis
+ :param ylabel: Label for Y axis
+ :param title: Graph title
+ """
+ self._selector.selectionChanged.disconnect(self._updateImage)
+ self._auxSigSlider.valueChanged.disconnect(self._sliderIdxChanged)
+
+ self.__signals = signals
+ self.__signals_names = signals_names
+ self.__x_axis = x_axis
+ self.__x_axis_name = xlabel
+ self.__y_axis = y_axis
+ self.__y_axis_name = ylabel
+ self.__title = title
+
+ self._selector.clear()
+ self._selector.setAxisNames(["Y", "X"])
+ self._selector.setData(signals[0])
+
+ if len(signals[0].shape) <= 2:
+ self._selector.hide()
+ else:
+ self._selector.show()
+
+ self._auxSigSlider.setMaximum(len(signals) - 1)
+ if len(signals) > 1:
+ self._auxSigSlider.show()
+ else:
+ self._auxSigSlider.hide()
+ self._auxSigSlider.setValue(0)
+
+ self._updateImage()
+ self._plot.getPlot().resetZoom()
+
+ self._selector.selectionChanged.connect(self._updateImage)
+ self._auxSigSlider.valueChanged.connect(self._sliderIdxChanged)
+
+ def _updateImage(self):
+ selection = self._selector.selection()
+ auxSigIdx = self._auxSigSlider.value()
+
+ images = [img[selection] for img in self.__signals]
+ image = images[auxSigIdx]
+
+ x_axis = self.__x_axis
+ y_axis = self.__y_axis
+
+ if x_axis is None and y_axis is None:
+ xcalib = NoCalibration()
+ ycalib = NoCalibration()
+ else:
+ if x_axis is None:
+ # no calibration
+ x_axis = numpy.arange(image.shape[1])
+ elif numpy.isscalar(x_axis) or len(x_axis) == 1:
+ # constant axis
+ x_axis = x_axis * numpy.ones((image.shape[1], ))
+ elif len(x_axis) == 2:
+ # linear calibration
+ x_axis = x_axis[0] * numpy.arange(image.shape[1]) + x_axis[1]
+
+ if y_axis is None:
+ y_axis = numpy.arange(image.shape[0])
+ elif numpy.isscalar(y_axis) or len(y_axis) == 1:
+ y_axis = y_axis * numpy.ones((image.shape[0], ))
+ elif len(y_axis) == 2:
+ y_axis = y_axis[0] * numpy.arange(image.shape[0]) + y_axis[1]
+
+ xcalib = ArrayCalibration(x_axis)
+ ycalib = ArrayCalibration(y_axis)
+
+ self._plot.setData(image)
+ if xcalib.is_affine():
+ xorigin, xscale = xcalib(0), xcalib.get_slope()
+ else:
+ _logger.warning("Unsupported complex image X axis calibration")
+ xorigin, xscale = 0., 1.
+
+ if ycalib.is_affine():
+ yorigin, yscale = ycalib(0), ycalib.get_slope()
+ else:
+ _logger.warning("Unsupported complex image Y axis calibration")
+ yorigin, yscale = 0., 1.
+
+ self._plot.setOrigin((xorigin, yorigin))
+ self._plot.setScale((xscale, yscale))
+
+ if self.__title:
+ title = self.__title
+ if len(self.__signals_names) > 1:
+ # Append dataset name only when there is many datasets
+ title += '\n' + self.__signals_names[auxSigIdx]
+ else:
+ title = self.__signals_names[auxSigIdx]
+ self._plot.setGraphTitle(title)
+ self._plot.getXAxis().setLabel(self.__x_axis_name)
+ self._plot.getYAxis().setLabel(self.__y_axis_name)
+
+ def clear(self):
+ old = self._selector.blockSignals(True)
+ self._selector.clear()
+ self._selector.blockSignals(old)
+ self._plot.setData(None)
+
+
+class ArrayStackPlot(qt.QWidget):
+ """
+ Widget for plotting a n-D array (n >= 3) as a stack of images.
+ Three axis arrays can be provided to calibrate the axes.
+
+ The signal array can have an arbitrary number of dimensions, the only
+ limitation being that the last 3 dimensions must have the same length as
+ the axes arrays.
+
+ Sliders are provided to select indices on the first (n - 3) dimensions of
+ the signal array, and the plot is updated to load the stack corresponding
+ to the selection.
+ """
+ def __init__(self, parent=None):
+ """
+
+ :param parent: Parent QWidget
+ """
+ super(ArrayStackPlot, self).__init__(parent)
+
+ self.__signal = None
+ self.__signal_name = None
+ # the Z, Y, X axes apply to the last three dimensions of the signal
+ # (in that order)
+ self.__z_axis = None
+ self.__z_axis_name = None
+ self.__y_axis = None
+ self.__y_axis_name = None
+ self.__x_axis = None
+ self.__x_axis_name = None
+
+ self._stack_view = StackView(self)
+ maskToolWidget = self._stack_view.getPlotWidget().getMaskToolsDockWidget().widget()
+ maskToolWidget.setItemMaskUpdated(True)
+
+ self._hline = qt.QFrame(self)
+ self._hline.setFrameStyle(qt.QFrame.HLine)
+ self._hline.setFrameShadow(qt.QFrame.Sunken)
+ self._legend = qt.QLabel(self)
+ self._selector = NumpyAxesSelector(self)
+ self._selector.setNamedAxesSelectorVisibility(False)
+ self.__selector_is_connected = False
+
+ layout = qt.QVBoxLayout()
+ layout.addWidget(self._stack_view)
+ layout.addWidget(self._hline)
+ layout.addWidget(self._legend)
+ layout.addWidget(self._selector)
+
+ self.setLayout(layout)
+
+ def getStackView(self):
+ """Returns the plot used for the display
+
+ :rtype: StackView
+ """
+ return self._stack_view
+
+ def setStackData(self, signal,
+ x_axis=None, y_axis=None, z_axis=None,
+ signal_name=None,
+ xlabel=None, ylabel=None, zlabel=None,
+ title=None):
+ """
+
+ :param signal: n-D dataset, whose last 3 dimensions are used as the
+ 3D stack values.
+ :param x_axis: 1-D dataset used as the image's x coordinates. If
+ provided, its lengths must be equal to the length of the last
+ dimension of ``signal``.
+ :param y_axis: 1-D dataset used as the image's y. If provided,
+ its lengths must be equal to the length of the 2nd to last
+ dimension of ``signal``.
+ :param z_axis: 1-D dataset used as the image's z. If provided,
+ its lengths must be equal to the length of the 3rd to last
+ dimension of ``signal``.
+ :param signal_name: Label used in the legend
+ :param xlabel: Label for X axis
+ :param ylabel: Label for Y axis
+ :param zlabel: Label for Z axis
+ :param title: Graph title
+ """
+ if self.__selector_is_connected:
+ self._selector.selectionChanged.disconnect(self._updateStack)
+ self.__selector_is_connected = False
+
+ self.__signal = signal
+ self.__signal_name = signal_name or ""
+ self.__x_axis = x_axis
+ self.__x_axis_name = xlabel
+ self.__y_axis = y_axis
+ self.__y_axis_name = ylabel
+ self.__z_axis = z_axis
+ self.__z_axis_name = zlabel
+
+ self._selector.setData(signal)
+ self._selector.setAxisNames(["Y", "X", "Z"])
+
+ self._stack_view.setGraphTitle(title or "")
+ # by default, the z axis is the image position (dimension not plotted)
+ self._stack_view.getPlotWidget().getXAxis().setLabel(self.__x_axis_name or "X")
+ self._stack_view.getPlotWidget().getYAxis().setLabel(self.__y_axis_name or "Y")
+
+ self._updateStack()
+
+ ndims = len(signal.shape)
+ self._stack_view.setFirstStackDimension(ndims - 3)
+
+ # the legend label shows the selection slice producing the volume
+ # (only interesting for ndim > 3)
+ if ndims > 3:
+ self._selector.setVisible(True)
+ self._legend.setVisible(True)
+ self._hline.setVisible(True)
+ else:
+ self._selector.setVisible(False)
+ self._legend.setVisible(False)
+ self._hline.setVisible(False)
+
+ if not self.__selector_is_connected:
+ self._selector.selectionChanged.connect(self._updateStack)
+ self.__selector_is_connected = True
+
+ @staticmethod
+ def _get_origin_scale(axis):
+ """Assuming axis is a regularly spaced 1D array,
+ return a tuple (origin, scale) where:
+ - origin = axis[0]
+ - scale = (axis[n-1] - axis[0]) / (n -1)
+ :param axis: 1D numpy array
+ :return: Tuple (axis[0], (axis[-1] - axis[0]) / (len(axis) - 1))
+ """
+ return axis[0], (axis[-1] - axis[0]) / (len(axis) - 1)
+
+ def _updateStack(self):
+ """Update displayed stack according to the current axes selector
+ data."""
+ stk = self._selector.selectedData()
+ x_axis = self.__x_axis
+ y_axis = self.__y_axis
+ z_axis = self.__z_axis
+
+ calibrations = []
+ for axis in [z_axis, y_axis, x_axis]:
+
+ if axis is None:
+ calibrations.append(NoCalibration())
+ elif len(axis) == 2:
+ calibrations.append(
+ LinearCalibration(y_intercept=axis[0],
+ slope=axis[1]))
+ else:
+ calibrations.append(ArrayCalibration(axis))
+
+ legend = self.__signal_name + "["
+ for sl in self._selector.selection():
+ if sl == slice(None):
+ legend += ":, "
+ else:
+ legend += str(sl) + ", "
+ legend = legend[:-2] + "]"
+ self._legend.setText("Displayed data: " + legend)
+
+ self._stack_view.setStack(stk, calibrations=calibrations)
+ self._stack_view.setLabels(
+ labels=[self.__z_axis_name,
+ self.__y_axis_name,
+ self.__x_axis_name])
+
+ def clear(self):
+ old = self._selector.blockSignals(True)
+ self._selector.clear()
+ self._selector.blockSignals(old)
+ self._stack_view.clear()
+
+
+class ArrayVolumePlot(qt.QWidget):
+ """
+ Widget for plotting a n-D array (n >= 3) as a 3D scalar field.
+ Three axis arrays can be provided to calibrate the axes.
+
+ The signal array can have an arbitrary number of dimensions, the only
+ limitation being that the last 3 dimensions must have the same length as
+ the axes arrays.
+
+ Sliders are provided to select indices on the first (n - 3) dimensions of
+ the signal array, and the plot is updated to load the stack corresponding
+ to the selection.
+ """
+ def __init__(self, parent=None):
+ """
+
+ :param parent: Parent QWidget
+ """
+ super(ArrayVolumePlot, self).__init__(parent)
+
+ self.__signal = None
+ self.__signal_name = None
+ # the Z, Y, X axes apply to the last three dimensions of the signal
+ # (in that order)
+ self.__z_axis = None
+ self.__z_axis_name = None
+ self.__y_axis = None
+ self.__y_axis_name = None
+ self.__x_axis = None
+ self.__x_axis_name = None
+
+ from ._VolumeWindow import VolumeWindow
+
+ self._view = VolumeWindow(self)
+
+ self._hline = qt.QFrame(self)
+ self._hline.setFrameStyle(qt.QFrame.HLine)
+ self._hline.setFrameShadow(qt.QFrame.Sunken)
+ self._legend = qt.QLabel(self)
+ self._selector = NumpyAxesSelector(self)
+ self._selector.setNamedAxesSelectorVisibility(False)
+ self.__selector_is_connected = False
+
+ layout = qt.QVBoxLayout()
+ layout.addWidget(self._view)
+ layout.addWidget(self._hline)
+ layout.addWidget(self._legend)
+ layout.addWidget(self._selector)
+
+ self.setLayout(layout)
+
+ def getVolumeView(self):
+ """Returns the plot used for the display
+
+ :rtype: SceneWindow
+ """
+ return self._view
+
+ def setData(self, signal,
+ x_axis=None, y_axis=None, z_axis=None,
+ signal_name=None,
+ xlabel=None, ylabel=None, zlabel=None,
+ title=None):
+ """
+
+ :param signal: n-D dataset, whose last 3 dimensions are used as the
+ 3D stack values.
+ :param x_axis: 1-D dataset used as the image's x coordinates. If
+ provided, its lengths must be equal to the length of the last
+ dimension of ``signal``.
+ :param y_axis: 1-D dataset used as the image's y. If provided,
+ its lengths must be equal to the length of the 2nd to last
+ dimension of ``signal``.
+ :param z_axis: 1-D dataset used as the image's z. If provided,
+ its lengths must be equal to the length of the 3rd to last
+ dimension of ``signal``.
+ :param signal_name: Label used in the legend
+ :param xlabel: Label for X axis
+ :param ylabel: Label for Y axis
+ :param zlabel: Label for Z axis
+ :param title: Graph title
+ """
+ if self.__selector_is_connected:
+ self._selector.selectionChanged.disconnect(self._updateVolume)
+ self.__selector_is_connected = False
+
+ self.__signal = signal
+ self.__signal_name = signal_name or ""
+ self.__x_axis = x_axis
+ self.__x_axis_name = xlabel
+ self.__y_axis = y_axis
+ self.__y_axis_name = ylabel
+ self.__z_axis = z_axis
+ self.__z_axis_name = zlabel
+
+ self._selector.setData(signal)
+ self._selector.setAxisNames(["Y", "X", "Z"])
+
+ self._updateVolume()
+
+ # the legend label shows the selection slice producing the volume
+ # (only interesting for ndim > 3)
+ if signal.ndim > 3:
+ self._selector.setVisible(True)
+ self._legend.setVisible(True)
+ self._hline.setVisible(True)
+ else:
+ self._selector.setVisible(False)
+ self._legend.setVisible(False)
+ self._hline.setVisible(False)
+
+ if not self.__selector_is_connected:
+ self._selector.selectionChanged.connect(self._updateVolume)
+ self.__selector_is_connected = True
+
+ def _updateVolume(self):
+ """Update displayed stack according to the current axes selector
+ data."""
+ x_axis = self.__x_axis
+ y_axis = self.__y_axis
+ z_axis = self.__z_axis
+
+ offset = []
+ scale = []
+ for axis in [x_axis, y_axis, z_axis]:
+ if axis is None:
+ calibration = NoCalibration()
+ elif len(axis) == 2:
+ calibration = LinearCalibration(
+ y_intercept=axis[0], slope=axis[1])
+ else:
+ calibration = ArrayCalibration(axis)
+ if not calibration.is_affine():
+ _logger.warning("Axis has not linear values, ignored")
+ offset.append(0.)
+ scale.append(1.)
+ else:
+ offset.append(calibration(0))
+ scale.append(calibration.get_slope())
+
+ legend = self.__signal_name + "["
+ for sl in self._selector.selection():
+ if sl == slice(None):
+ legend += ":, "
+ else:
+ legend += str(sl) + ", "
+ legend = legend[:-2] + "]"
+ self._legend.setText("Displayed data: " + legend)
+
+ # Update SceneWidget
+ data = self._selector.selectedData()
+
+ volumeView = self.getVolumeView()
+ volumeView.setData(data, offset=offset, scale=scale)
+ volumeView.setAxesLabels(
+ self.__x_axis_name, self.__y_axis_name, self.__z_axis_name)
+
+ def clear(self):
+ old = self._selector.blockSignals(True)
+ self._selector.clear()
+ self._selector.blockSignals(old)
+ self.getVolumeView().clear()
diff --git a/src/silx/gui/data/NumpyAxesSelector.py b/src/silx/gui/data/NumpyAxesSelector.py
new file mode 100644
index 0000000..e6da0d4
--- /dev/null
+++ b/src/silx/gui/data/NumpyAxesSelector.py
@@ -0,0 +1,578 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module defines a widget able to convert a numpy array from n-dimensions
+to a numpy array with less dimensions.
+"""
+from __future__ import division
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "29/01/2018"
+
+import logging
+import numpy
+import functools
+from silx.gui.widgets.FrameBrowser import HorizontalSliderWithBrowser
+from silx.gui import qt
+from silx.gui.utils import blockSignals
+import silx.utils.weakref
+
+
+_logger = logging.getLogger(__name__)
+
+
+class _Axis(qt.QWidget):
+ """Widget displaying an axis.
+
+ It allows to display and scroll in the axis, and provide a widget to
+ map the axis with a named axis (the one from the view).
+ """
+
+ valueChanged = qt.Signal(int)
+ """Emitted when the location on the axis change."""
+
+ axisNameChanged = qt.Signal(object)
+ """Emitted when the user change the name of the axis."""
+
+ def __init__(self, parent=None):
+ """Constructor
+
+ :param parent: Parent of the widget
+ """
+ super(_Axis, self).__init__(parent)
+ self.__axisNumber = None
+ self.__customAxisNames = set([])
+ self.__label = qt.QLabel(self)
+ self.__axes = qt.QComboBox(self)
+ self.__axes.currentIndexChanged[int].connect(self.__axisMappingChanged)
+ self.__slider = HorizontalSliderWithBrowser(self)
+ self.__slider.valueChanged[int].connect(self.__sliderValueChanged)
+ layout = qt.QHBoxLayout(self)
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.addWidget(self.__label)
+ layout.addWidget(self.__axes)
+ layout.addWidget(self.__slider, 10000)
+ layout.addStretch(1)
+ self.setLayout(layout)
+
+ def slider(self):
+ """Returns the slider used to display axes location.
+
+ :rtype: HorizontalSliderWithBrowser
+ """
+ return self.__slider
+
+ def setAxis(self, number, position, size):
+ """Set axis information.
+
+ :param int number: The number of the axis (from the original numpy
+ array)
+ :param int position: The current position in the axis (for a slicing)
+ :param int size: The size of this axis (0..n)
+ """
+ self.__label.setText("Dimension %s" % number)
+ self.__axisNumber = number
+ self.__slider.setMaximum(size - 1)
+
+ def axisNumber(self):
+ """Returns the axis number.
+
+ :rtype: int
+ """
+ return self.__axisNumber
+
+ def setAxisName(self, axisName):
+ """Set the current used axis name.
+
+ If this name is not available an exception is raised. An empty string
+ means that no name is selected.
+
+ :param str axisName: The new name of the axis
+ :raise ValueError: When the name is not available
+ """
+ if axisName == "" and self.__axes.count() == 0:
+ self.__axes.setCurrentIndex(-1)
+ self.__updateSliderVisibility()
+ return
+
+ for index in range(self.__axes.count()):
+ name = self.__axes.itemData(index)
+ if name == axisName:
+ self.__axes.setCurrentIndex(index)
+ self.__updateSliderVisibility()
+ return
+ raise ValueError("Axis name '%s' not found", axisName)
+
+ def axisName(self):
+ """Returns the selected axis name.
+
+ If no name is selected, an empty string is returned.
+
+ :rtype: str
+ """
+ index = self.__axes.currentIndex()
+ if index == -1:
+ return ""
+ return self.__axes.itemData(index)
+
+ def setAxisNames(self, axesNames):
+ """Set the available list of names for the axis.
+
+ :param List[str] axesNames: List of available names
+ """
+ self.__axes.clear()
+ with blockSignals(self.__axes):
+ self.__axes.addItem(" ", "")
+ for axis in axesNames:
+ self.__axes.addItem(axis, axis)
+
+ self.__updateSliderVisibility()
+
+ def setCustomAxis(self, axesNames):
+ """Set the available list of named axis which can be set to a value.
+
+ :param List[str] axesNames: List of customable axis names
+ """
+ self.__customAxisNames = set(axesNames)
+ self.__updateSliderVisibility()
+
+ def __axisMappingChanged(self, index):
+ """Called when the selected name change.
+
+ :param int index: Selected index
+ """
+ self.__updateSliderVisibility()
+ name = self.axisName()
+ self.axisNameChanged.emit(name)
+
+ def __updateSliderVisibility(self):
+ """Update the visibility of the slider according to axis names and
+ customable axis names."""
+ name = self.axisName()
+ isVisible = name == "" or name in self.__customAxisNames
+ self.__slider.setVisible(isVisible)
+
+ def value(self):
+ """Returns the currently selected position in the axis.
+
+ :rtype: int
+ """
+ return self.__slider.value()
+
+ def setValue(self, value):
+ """Set the currently selected position in the axis.
+
+ :param int value:
+ """
+ self.__slider.setValue(value)
+
+ def __sliderValueChanged(self, value):
+ """Called when the selected position in the axis change.
+
+ :param int value: Position of the axis
+ """
+ self.valueChanged.emit(value)
+
+ def setNamedAxisSelectorVisibility(self, visible):
+ """Hide or show the named axis combobox.
+
+ If both the selector and the slider are hidden, hide the entire widget.
+
+ :param visible: boolean
+ """
+ self.__axes.setVisible(visible)
+ name = self.axisName()
+ self.setVisible(visible or name == "")
+
+
+class NumpyAxesSelector(qt.QWidget):
+ """Widget to select a view from a numpy array.
+
+ .. image:: img/NumpyAxesSelector.png
+
+ The widget is set with an input data using :meth:`setData`, and a requested
+ output dimension using :meth:`setAxisNames`.
+
+ Widgets are provided to selected expected input axis, and a slice on the
+ non-selected axis.
+
+ The final selected array can be reached using the getter
+ :meth:`selectedData`, and the event `selectionChanged`.
+
+ If the input data is a HDF5 Dataset, the selected output data will be a
+ new numpy array.
+ """
+
+ dataChanged = qt.Signal()
+ """Emitted when the input data change"""
+
+ selectedAxisChanged = qt.Signal()
+ """Emitted when the selected axis change"""
+
+ selectionChanged = qt.Signal()
+ """Emitted when the selected data change"""
+
+ customAxisChanged = qt.Signal(str, int)
+ """Emitted when a custom axis change"""
+
+ def __init__(self, parent=None):
+ """Constructor
+
+ :param parent: Parent of the widget
+ """
+ super(NumpyAxesSelector, self).__init__(parent)
+
+ self.__data = None
+ self.__selectedData = None
+ self.__axis = []
+ self.__axisNames = []
+ self.__customAxisNames = set([])
+ self.__namedAxesVisibility = True
+ layout = qt.QVBoxLayout(self)
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.setSizeConstraint(qt.QLayout.SetMinAndMaxSize)
+ self.setLayout(layout)
+
+ def clear(self):
+ """Clear the widget."""
+ self.setData(None)
+
+ def setAxisNames(self, axesNames):
+ """Set the axis names of the output selected data.
+
+ Axis names are defined from slower to faster axis.
+
+ The size of the list will constrain the dimension of the resulting
+ array.
+
+ :param List[str] axesNames: List of distinct strings identifying axis names
+ """
+ self.__axisNames = list(axesNames)
+ assert len(set(self.__axisNames)) == len(self.__axisNames),\
+ "Non-unique axes names: %s" % self.__axisNames
+
+ delta = len(self.__axis) - len(self.__axisNames)
+ if delta < 0:
+ delta = 0
+ for index, axis in enumerate(self.__axis):
+ with blockSignals(axis):
+ axis.setAxisNames(self.__axisNames)
+ if index >= delta and index - delta < len(self.__axisNames):
+ axis.setAxisName(self.__axisNames[index - delta])
+ else:
+ axis.setAxisName("")
+ self.__updateSelectedData()
+
+ def setCustomAxis(self, axesNames):
+ """Set the available list of named axis which can be set to a value.
+
+ :param List[str] axesNames: List of customable axis names
+ """
+ self.__customAxisNames = set(axesNames)
+ for axis in self.__axis:
+ axis.setCustomAxis(self.__customAxisNames)
+
+ def setData(self, data):
+ """Set the input data unsed by the widget.
+
+ :param numpy.ndarray data: The input data
+ """
+ if self.__data is not None:
+ # clean up
+ for widget in self.__axis:
+ self.layout().removeWidget(widget)
+ widget.deleteLater()
+ self.__axis = []
+
+ self.__data = data
+
+ if data is not None:
+ # create expected axes
+ dimensionNumber = len(data.shape)
+ delta = dimensionNumber - len(self.__axisNames)
+ for index in range(dimensionNumber):
+ axis = _Axis(self)
+ axis.setAxis(index, 0, data.shape[index])
+ axis.setAxisNames(self.__axisNames)
+ axis.setCustomAxis(self.__customAxisNames)
+ if index >= delta and index - delta < len(self.__axisNames):
+ axis.setAxisName(self.__axisNames[index - delta])
+ # this weak method was expected to be able to delete sub widget
+ callback = functools.partial(silx.utils.weakref.WeakMethodProxy(self.__axisValueChanged), axis)
+ axis.valueChanged.connect(callback)
+ # this weak method was expected to be able to delete sub widget
+ callback = functools.partial(silx.utils.weakref.WeakMethodProxy(self.__axisNameChanged), axis)
+ axis.axisNameChanged.connect(callback)
+ axis.setNamedAxisSelectorVisibility(self.__namedAxesVisibility)
+ self.layout().addWidget(axis)
+ self.__axis.append(axis)
+ self.__normalizeAxisGeometry()
+
+ self.dataChanged.emit()
+ self.__updateSelectedData()
+
+ def __normalizeAxisGeometry(self):
+ """Update axes geometry to align all axes components together."""
+ if len(self.__axis) <= 0:
+ return
+ lineEditWidth = max([a.slider().lineEdit().minimumSize().width() for a in self.__axis])
+ limitWidth = max([a.slider().limitWidget().minimumSizeHint().width() for a in self.__axis])
+ for a in self.__axis:
+ a.slider().lineEdit().setFixedWidth(lineEditWidth)
+ a.slider().limitWidget().setFixedWidth(limitWidth)
+
+ def __axisValueChanged(self, axis, value):
+ name = axis.axisName()
+ if name in self.__customAxisNames:
+ self.customAxisChanged.emit(name, value)
+ else:
+ self.__updateSelectedData()
+
+ def __axisNameChanged(self, axis, name):
+ """Called when an axis name change.
+
+ :param _Axis axis: The changed axis
+ :param str name: The new name of the axis
+ """
+ names = [x.axisName() for x in self.__axis]
+ missingName = set(self.__axisNames) - set(names) - set("")
+ if len(missingName) == 0:
+ missingName = None
+ elif len(missingName) == 1:
+ missingName = list(missingName)[0]
+ else:
+ raise Exception("Unexpected state")
+
+ axisChanged = True
+
+ if axis.axisName() == "":
+ # set the removed label to another widget if it is possible
+ availableWidget = None
+ for widget in self.__axis:
+ if widget is axis:
+ continue
+ if widget.axisName() == "":
+ availableWidget = widget
+ break
+ if availableWidget is None:
+ # If there is no other solution we set the name at the same place
+ axisChanged = False
+ availableWidget = axis
+ with blockSignals(availableWidget):
+ availableWidget.setAxisName(missingName)
+ else:
+ # there is a duplicated name somewhere
+ # we swap it with the missing name or with nothing
+ dupWidget = None
+ for widget in self.__axis:
+ if widget is axis:
+ continue
+ if widget.axisName() == axis.axisName():
+ dupWidget = widget
+ break
+ if missingName is None:
+ missingName = ""
+ with blockSignals(dupWidget):
+ dupWidget.setAxisName(missingName)
+
+ if self.__data is None:
+ return
+ if axisChanged:
+ self.selectedAxisChanged.emit()
+ self.__updateSelectedData()
+
+ def __updateSelectedData(self):
+ """Update the selected data according to the state of the widget.
+
+ It fires a `selectionChanged` event.
+ """
+ permutation = self.permutation()
+
+ if self.__data is None or permutation is None:
+ # No data or not all the expected axes are there
+ if self.__selectedData is not None:
+ self.__selectedData = None
+ self.selectionChanged.emit()
+ return
+
+ # get a view with few fixed dimensions
+ # with a h5py dataset, it create a copy
+ # TODO we can reuse the same memory in case of a copy
+ self.__selectedData = numpy.transpose(self.__data[self.selection()], permutation)
+ self.selectionChanged.emit()
+
+ def data(self):
+ """Returns the input data.
+
+ :rtype: Union[numpy.ndarray,None]
+ """
+ if self.__data is None:
+ return None
+ else:
+ return numpy.array(self.__data, copy=False)
+
+ def selectedData(self):
+ """Returns the output data.
+
+ This is equivalent to::
+
+ numpy.transpose(self.data()[self.selection()], self.permutation())
+
+ :rtype: Union[numpy.ndarray,None]
+ """
+ if self.__selectedData is None:
+ return None
+ else:
+ return numpy.array(self.__selectedData, copy=False)
+
+ def permutation(self):
+ """Returns the axes permutation to convert data subset to selected data.
+
+ If permutation cannot be computer, it returns None.
+
+ :rtype: Union[List[int],None]
+ """
+ if self.__data is None:
+ return None
+ else:
+ indices = []
+ for name in self.__axisNames:
+ index = 0
+ for axis in self.__axis:
+ if axis.axisName() == name:
+ indices.append(index)
+ break
+ if axis.axisName() != "":
+ index += 1
+ else:
+ _logger.warning("No axis corresponding to: %s", name)
+ return None
+ return tuple(indices)
+
+ def selection(self):
+ """Returns the selection tuple used to slice the data.
+
+ :rtype: tuple
+ """
+ if self.__data is None:
+ return tuple()
+ else:
+ return tuple([axis.value() if axis.axisName() == "" else slice(None)
+ for axis in self.__axis])
+
+ def setSelection(self, selection, permutation=None):
+ """Set the selection along each dimension.
+
+ tuple returned by :meth:`selection` can be provided as input,
+ provided that it is for the same the number of axes and
+ the same number of dimensions of the data.
+
+ :param List[Union[int,slice,None]] selection:
+ The selection tuple with as one element for each dimension of the data.
+ If an element is None, then the whole dimension is selected.
+ :param Union[List[int],None] permutation:
+ The data axes indices to transpose.
+ If not given, no permutation is applied
+ :raise ValueError:
+ When the selection does not match current data shape and number of axes.
+ """
+ data_shape = self.__data.shape if self.__data is not None else ()
+
+ # Check selection
+ if len(selection) != len(data_shape):
+ raise ValueError(
+ "Selection length (%d) and data ndim (%d) mismatch" %
+ (len(selection), len(data_shape)))
+
+ # Check selection type
+ selectedDataNDim = 0
+ for element, size in zip(selection, data_shape):
+ if isinstance(element, int):
+ if not 0 <= element < size:
+ raise ValueError(
+ "Selected index (%d) outside data dimension range [0-%d]" %
+ (element, size))
+ elif element is None or element == slice(None):
+ selectedDataNDim += 1
+ else:
+ raise ValueError("Unsupported element in selection: %s" % element)
+
+ ndim = len(self.__axisNames)
+ if selectedDataNDim != ndim:
+ raise ValueError(
+ "Selection dimensions (%d) and number of axes (%d) mismatch" %
+ (selectedDataNDim, ndim))
+
+ # check permutation
+ if permutation is None:
+ permutation = tuple(range(ndim))
+
+ if set(permutation) != set(range(ndim)):
+ raise ValueError(
+ "Error in provided permutation: "
+ "Wrong size, elements out of range or duplicates")
+
+ inversePermutation = numpy.argsort(permutation)
+
+ axisNameChanged = False
+ customValueChanged = []
+ with blockSignals(*self.__axis):
+ index = 0
+ for element, axis in zip(selection, self.__axis):
+ if isinstance(element, int):
+ name = ""
+ else:
+ name = self.__axisNames[inversePermutation[index]]
+ index += 1
+
+ if axis.axisName() != name:
+ axis.setAxisName(name)
+ axisNameChanged = True
+
+ for element, axis in zip(selection, self.__axis):
+ value = element if isinstance(element, int) else 0
+ if axis.value() != value:
+ axis.setValue(value)
+
+ name = axis.axisName()
+ if name in self.__customAxisNames:
+ customValueChanged.append((name, value))
+
+ # Send signals that where disabled
+ if axisNameChanged:
+ self.selectedAxisChanged.emit()
+ for name, value in customValueChanged:
+ self.customAxisChanged.emit(name, value)
+ self.__updateSelectedData()
+
+ def setNamedAxesSelectorVisibility(self, visible):
+ """Show or hide the combo-boxes allowing to map the plot axes
+ to the data dimension.
+
+ :param visible: Boolean
+ """
+ self.__namedAxesVisibility = visible
+ for axis in self.__axis:
+ axis.setNamedAxisSelectorVisibility(visible)
diff --git a/src/silx/gui/data/RecordTableView.py b/src/silx/gui/data/RecordTableView.py
new file mode 100644
index 0000000..ea73c62
--- /dev/null
+++ b/src/silx/gui/data/RecordTableView.py
@@ -0,0 +1,439 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+This module define model and widget to display 1D slices from numpy
+array using compound data types or hdf5 databases.
+"""
+from __future__ import division
+
+import itertools
+import numpy
+from silx.gui import qt
+import silx.io
+from .TextFormatter import TextFormatter
+from silx.gui.widgets.TableWidget import CopySelectedCellsAction
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "29/08/2018"
+
+
+class _MultiLineItem(qt.QItemDelegate):
+ """Draw a multiline text without hiding anything.
+
+ The paint method display a cell without any wrap. And an editor is
+ available to scroll into the selected cell.
+ """
+
+ def __init__(self, parent=None):
+ """
+ Constructor
+
+ :param qt.QWidget parent: Parent of the widget
+ """
+ qt.QItemDelegate.__init__(self, parent)
+ self.__textOptions = qt.QTextOption()
+ self.__textOptions.setFlags(qt.QTextOption.IncludeTrailingSpaces |
+ qt.QTextOption.ShowTabsAndSpaces)
+ self.__textOptions.setWrapMode(qt.QTextOption.NoWrap)
+ self.__textOptions.setAlignment(qt.Qt.AlignTop | qt.Qt.AlignLeft)
+
+ def paint(self, painter, option, index):
+ """
+ Write multiline text without using any wrap or any alignment according
+ to the cell size.
+
+ :param qt.QPainter painter: Painter context used to displayed the cell
+ :param qt.QStyleOptionViewItem option: Control how the editor is shown
+ :param qt.QIndex index: Index of the data to display
+ """
+ painter.save()
+
+ # set colors
+ painter.setPen(qt.QPen(qt.Qt.NoPen))
+ if option.state & qt.QStyle.State_Selected:
+ brush = option.palette.highlight()
+ painter.setBrush(brush)
+ else:
+ brush = index.data(qt.Qt.BackgroundRole)
+ if brush is None:
+ # default background color for a cell
+ brush = qt.Qt.white
+ painter.setBrush(brush)
+ painter.drawRect(option.rect)
+
+ if index.isValid():
+ if option.state & qt.QStyle.State_Selected:
+ brush = option.palette.highlightedText()
+ else:
+ brush = index.data(qt.Qt.ForegroundRole)
+ if brush is None:
+ brush = option.palette.text()
+ painter.setPen(qt.QPen(brush.color()))
+ text = index.data(qt.Qt.DisplayRole)
+ painter.drawText(qt.QRectF(option.rect), text, self.__textOptions)
+
+ painter.restore()
+
+ def createEditor(self, parent, option, index):
+ """
+ Returns the widget used to edit the item specified by index for editing.
+
+ We use it not to edit the content but to show the content with a
+ convenient scroll bar.
+
+ :param qt.QWidget parent: Parent of the widget
+ :param qt.QStyleOptionViewItem option: Control how the editor is shown
+ :param qt.QIndex index: Index of the data to display
+ """
+ if not index.isValid():
+ return super(_MultiLineItem, self).createEditor(parent, option, index)
+
+ editor = qt.QTextEdit(parent)
+ editor.setReadOnly(True)
+ return editor
+
+ def setEditorData(self, editor, index):
+ """
+ Read data from the model and feed the editor.
+
+ :param qt.QWidget editor: Editor widget
+ :param qt.QIndex index: Index of the data to display
+ """
+ text = index.model().data(index, qt.Qt.EditRole)
+ editor.setText(text)
+
+ def updateEditorGeometry(self, editor, option, index):
+ """
+ Update the geometry of the editor according to the changes of the view.
+
+ :param qt.QWidget editor: Editor widget
+ :param qt.QStyleOptionViewItem option: Control how the editor is shown
+ :param qt.QIndex index: Index of the data to display
+ """
+ editor.setGeometry(option.rect)
+
+
+class RecordTableModel(qt.QAbstractTableModel):
+ """This data model provides access to 1D slices from numpy array using
+ compound data types or hdf5 databases.
+
+ Each entries are displayed in a single row, and each columns contain a
+ specific field of the compound type.
+
+ It also allows to display 1D arrays of simple data types.
+ array.
+
+ :param qt.QObject parent: Parent object
+ :param numpy.ndarray data: A numpy array or a h5py dataset
+ """
+
+ MAX_NUMBER_OF_ROWS = 10e6
+ """Maximum number of display values of the dataset"""
+
+ def __init__(self, parent=None, data=None):
+ qt.QAbstractTableModel.__init__(self, parent)
+
+ self.__data = None
+ self.__is_array = False
+ self.__fields = None
+ self.__formatter = None
+ self.__editFormatter = None
+ self.setFormatter(TextFormatter(self))
+
+ # set _data
+ self.setArrayData(data)
+
+ # Methods to be implemented to subclass QAbstractTableModel
+ def rowCount(self, parent_idx=None):
+ """Returns number of rows to be displayed in table"""
+ if self.__data is None:
+ return 0
+ elif not self.__is_array:
+ return 1
+ else:
+ return min(len(self.__data), self.MAX_NUMBER_OF_ROWS)
+
+ def columnCount(self, parent_idx=None):
+ """Returns number of columns to be displayed in table"""
+ if self.__fields is None:
+ return 1
+ else:
+ return len(self.__fields)
+
+ def __clippedData(self, role=qt.Qt.DisplayRole):
+ """Return data for cells representing clipped data"""
+ if role == qt.Qt.DisplayRole:
+ return "..."
+ elif role == qt.Qt.ToolTipRole:
+ return "Dataset is too large: display is clipped"
+ else:
+ return None
+
+ def data(self, index, role=qt.Qt.DisplayRole):
+ """QAbstractTableModel method to access data values
+ in the format ready to be displayed"""
+ if not index.isValid():
+ return None
+
+ if self.__data is None:
+ return None
+
+ # Special display of one before last data for clipped table
+ if self.__isClipped() and index.row() == self.rowCount() - 2:
+ return self.__clippedData(role)
+
+ if self.__is_array:
+ row = index.row()
+ if row >= self.rowCount():
+ return None
+ elif self.__isClipped() and row == self.rowCount() - 1:
+ # Clipped array, display last value at the end
+ data = self.__data[-1]
+ else:
+ data = self.__data[row]
+ else:
+ if index.row() > 0:
+ return None
+ data = self.__data
+
+ if self.__fields is not None:
+ if index.column() >= len(self.__fields):
+ return None
+ key = self.__fields[index.column()][1]
+ data = data[key[0]]
+ if len(key) > 1:
+ data = data[key[1]]
+
+ # no dtype in case of 1D array of unicode objects (#2093)
+ dtype = getattr(data, "dtype", None)
+
+ if role == qt.Qt.DisplayRole:
+ return self.__formatter.toString(data, dtype=dtype)
+ elif role == qt.Qt.EditRole:
+ return self.__editFormatter.toString(data, dtype=dtype)
+ return None
+
+ def headerData(self, section, orientation, role=qt.Qt.DisplayRole):
+ """Returns the 0-based row or column index, for display in the
+ horizontal and vertical headers"""
+ if section == -1:
+ # PyQt4 send -1 when there is columns but no rows
+ return None
+
+ # Handle clipping of huge tables
+ if (self.__isClipped() and
+ orientation == qt.Qt.Vertical and
+ section == self.rowCount() - 2):
+ return self.__clippedData(role)
+
+ if role == qt.Qt.DisplayRole:
+ if orientation == qt.Qt.Vertical:
+ if not self.__is_array:
+ return "Scalar"
+ elif section == self.MAX_NUMBER_OF_ROWS - 1:
+ return str(len(self.__data) - 1)
+ else:
+ return str(section)
+ if orientation == qt.Qt.Horizontal:
+ if self.__fields is None:
+ if section == 0:
+ return "Data"
+ else:
+ return None
+ else:
+ if section < len(self.__fields):
+ return self.__fields[section][0]
+ else:
+ return None
+ return None
+
+ def flags(self, index):
+ """QAbstractTableModel method to inform the view whether data
+ is editable or not.
+ """
+ return qt.QAbstractTableModel.flags(self, index)
+
+ def __isClipped(self) -> bool:
+ """Returns whether the displayed array is clipped or not"""
+ return self.__data is not None and self.__is_array and len(self.__data) > self.MAX_NUMBER_OF_ROWS
+
+ def setArrayData(self, data):
+ """Set the data array and the viewing perspective.
+
+ You can set ``copy=False`` if you need more performances, when dealing
+ with a large numpy array. In this case, a simple reference to the data
+ is used to access the data, rather than a copy of the array.
+
+ .. warning::
+
+ Any change to the data model will affect your original data
+ array, when using a reference rather than a copy..
+
+ :param data: 1D numpy array, or any object that can be
+ converted to a numpy array using ``numpy.array(data)`` (e.g.
+ a nested sequence).
+ """
+ self.beginResetModel()
+
+ self.__data = data
+ if isinstance(data, numpy.ndarray):
+ self.__is_array = True
+ elif silx.io.is_dataset(data) and data.shape != tuple():
+ self.__is_array = True
+ else:
+ self.__is_array = False
+
+ self.__fields = []
+ if data is not None:
+ if data.dtype.fields is not None:
+ fields = sorted(data.dtype.fields.items(), key=lambda e: e[1][1])
+ for name, (dtype, _index) in fields:
+ if dtype.shape != tuple():
+ keys = itertools.product(*[range(x) for x in dtype.shape])
+ for key in keys:
+ label = "%s%s" % (name, list(key))
+ array_key = (name, key)
+ self.__fields.append((label, array_key))
+ else:
+ self.__fields.append((name, (name,)))
+ else:
+ self.__fields = None
+
+ self.endResetModel()
+
+ def arrayData(self):
+ """Returns the internal data.
+
+ :rtype: numpy.ndarray of h5py.Dataset
+ """
+ return self.__data
+
+ def setFormatter(self, formatter):
+ """Set the formatter object to be used to display data from the model
+
+ :param TextFormatter formatter: Formatter to use
+ """
+ if formatter is self.__formatter:
+ return
+
+ self.beginResetModel()
+
+ if self.__formatter is not None:
+ self.__formatter.formatChanged.disconnect(self.__formatChanged)
+
+ self.__formatter = formatter
+ self.__editFormatter = TextFormatter(formatter)
+ self.__editFormatter.setUseQuoteForText(False)
+
+ if self.__formatter is not None:
+ self.__formatter.formatChanged.connect(self.__formatChanged)
+
+ self.endResetModel()
+
+ def getFormatter(self):
+ """Returns the text formatter used.
+
+ :rtype: TextFormatter
+ """
+ return self.__formatter
+
+ def __formatChanged(self):
+ """Called when the format changed.
+ """
+ self.__editFormatter = TextFormatter(self, self.getFormatter())
+ self.__editFormatter.setUseQuoteForText(False)
+ self.reset()
+
+
+class _ShowEditorProxyModel(qt.QIdentityProxyModel):
+ """
+ Allow to custom the flag edit of the model
+ """
+
+ def __init__(self, parent=None):
+ """
+ Constructor
+
+ :param qt.QObject arent: parent object
+ """
+ super(_ShowEditorProxyModel, self).__init__(parent)
+ self.__forceEditable = False
+
+ def flags(self, index):
+ flag = qt.QIdentityProxyModel.flags(self, index)
+ if self.__forceEditable:
+ flag = flag | qt.Qt.ItemIsEditable
+ return flag
+
+ def forceCellEditor(self, show):
+ """
+ Enable the editable flag to allow to display cell editor.
+ """
+ if self.__forceEditable == show:
+ return
+ self.beginResetModel()
+ self.__forceEditable = show
+ self.endResetModel()
+
+
+class RecordTableView(qt.QTableView):
+ """TableView using DatabaseTableModel as default model.
+ """
+ def __init__(self, parent=None):
+ """
+ Constructor
+
+ :param qt.QWidget parent: parent QWidget
+ """
+ qt.QTableView.__init__(self, parent)
+
+ model = _ShowEditorProxyModel(self)
+ self._model = RecordTableModel()
+ model.setSourceModel(self._model)
+ self.setModel(model)
+
+ self.__multilineView = _MultiLineItem(self)
+ self.setEditTriggers(qt.QAbstractItemView.AllEditTriggers)
+ self._copyAction = CopySelectedCellsAction(self)
+ self.addAction(self._copyAction)
+
+ def copy(self):
+ self._copyAction.trigger()
+
+ def setArrayData(self, data):
+ model = self.model()
+ sourceModel = model.sourceModel()
+ sourceModel.setArrayData(data)
+
+ if data is not None:
+ if issubclass(data.dtype.type, (numpy.string_, numpy.unicode_)):
+ # TODO it would be nice to also fix fields
+ # but using it only for string array is already very useful
+ self.setItemDelegateForColumn(0, self.__multilineView)
+ model.forceCellEditor(True)
+ else:
+ self.setItemDelegateForColumn(0, None)
+ model.forceCellEditor(False)
diff --git a/src/silx/gui/data/TextFormatter.py b/src/silx/gui/data/TextFormatter.py
new file mode 100644
index 0000000..b6baca4
--- /dev/null
+++ b/src/silx/gui/data/TextFormatter.py
@@ -0,0 +1,386 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This package provides a class sharred by widget from the
+data module to format data as text in the same way."""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "24/07/2018"
+
+import logging
+import numbers
+
+import numpy
+
+from silx.gui import qt
+
+import h5py
+
+
+_logger = logging.getLogger(__name__)
+
+
+class TextFormatter(qt.QObject):
+ """Formatter to convert data to string.
+
+ The method :meth:`toString` returns a formatted string from an input data
+ using parameters set to this object.
+
+ It support most python and numpy data, expecting dictionary. Unsupported
+ data are displayed using the string representation of the object (`str`).
+
+ It provides a set of parameters to custom the formatting of integer and
+ float values (:meth:`setIntegerFormat`, :meth:`setFloatFormat`).
+
+ It also allows to custom the use of quotes to display text data
+ (:meth:`setUseQuoteForText`), and custom unit used to display imaginary
+ numbers (:meth:`setImaginaryUnit`).
+
+ The object emit an event `formatChanged` every time a parametter is
+ changed.
+ """
+
+ formatChanged = qt.Signal()
+ """Emitted when properties of the formatter change."""
+
+ def __init__(self, parent=None, formatter=None):
+ """
+ Constructor
+
+ :param qt.QObject parent: Owner of the object
+ :param TextFormatter formatter: Instantiate this object from the
+ formatter
+ """
+ qt.QObject.__init__(self, parent)
+ if formatter is not None:
+ self.__integerFormat = formatter.integerFormat()
+ self.__floatFormat = formatter.floatFormat()
+ self.__useQuoteForText = formatter.useQuoteForText()
+ self.__imaginaryUnit = formatter.imaginaryUnit()
+ self.__enumFormat = formatter.enumFormat()
+ else:
+ self.__integerFormat = "%d"
+ self.__floatFormat = "%g"
+ self.__useQuoteForText = True
+ self.__imaginaryUnit = u"j"
+ self.__enumFormat = u"%(name)s(%(value)d)"
+
+ def integerFormat(self):
+ """Returns the format string controlling how the integer data
+ are formated by this object.
+
+ This is the C-style format string used by python when formatting
+ strings with the modulus operator.
+
+ :rtype: str
+ """
+ return self.__integerFormat
+
+ def setIntegerFormat(self, value):
+ """Set format string controlling how the integer data are
+ formated by this object.
+
+ :param str value: Format string (e.g. "%d", "%i", "%08i").
+ This is the C-style format string used by python when formatting
+ strings with the modulus operator.
+ """
+ if self.__integerFormat == value:
+ return
+ self.__integerFormat = value
+ self.formatChanged.emit()
+
+ def floatFormat(self):
+ """Returns the format string controlling how the floating-point data
+ are formated by this object.
+
+ This is the C-style format string used by python when formatting
+ strings with the modulus operator.
+
+ :rtype: str
+ """
+ return self.__floatFormat
+
+ def setFloatFormat(self, value):
+ """Set format string controlling how the floating-point data are
+ formated by this object.
+
+ :param str value: Format string (e.g. "%.3f", "%d", "%-10.2f",
+ "%10.3e").
+ This is the C-style format string used by python when formatting
+ strings with the modulus operator.
+ """
+ if self.__floatFormat == value:
+ return
+ self.__floatFormat = value
+ self.formatChanged.emit()
+
+ def useQuoteForText(self):
+ """Returns true if the string data are formatted using double quotes.
+
+ Else, no quotes are used.
+ """
+ return self.__integerFormat
+
+ def setUseQuoteForText(self, useQuote):
+ """Set the use of quotes to delimit string data.
+
+ :param bool useQuote: True to use quotes.
+ """
+ if self.__useQuoteForText == useQuote:
+ return
+ self.__useQuoteForText = useQuote
+ self.formatChanged.emit()
+
+ def imaginaryUnit(self):
+ """Returns the unit display for imaginary numbers.
+
+ :rtype: str
+ """
+ return self.__imaginaryUnit
+
+ def setImaginaryUnit(self, imaginaryUnit):
+ """Set the unit display for imaginary numbers.
+
+ :param str imaginaryUnit: Unit displayed after imaginary numbers
+ """
+ if self.__imaginaryUnit == imaginaryUnit:
+ return
+ self.__imaginaryUnit = imaginaryUnit
+ self.formatChanged.emit()
+
+ def setEnumFormat(self, value):
+ """Set format string controlling how the enum data are
+ formated by this object.
+
+ :param str value: Format string (e.g. "%(name)s(%(value)d)").
+ This is the C-style format string used by python when formatting
+ strings with the modulus operator.
+ """
+ if self.__enumFormat == value:
+ return
+ self.__enumFormat = value
+ self.formatChanged.emit()
+
+ def enumFormat(self):
+ """Returns the format string controlling how the enum data
+ are formated by this object.
+
+ This is the C-style format string used by python when formatting
+ strings with the modulus operator.
+
+ :rtype: str
+ """
+ return self.__enumFormat
+
+ def __formatText(self, text):
+ if self.__useQuoteForText:
+ text = "\"%s\"" % text.replace("\\", "\\\\").replace("\"", "\\\"")
+ return text
+
+ def __formatBinary(self, data):
+ if isinstance(data, numpy.void):
+ data = data.item()
+ if isinstance(data, numpy.ndarray):
+ # Before numpy 1.15.0 the item API was returning a numpy array
+ data = data.astype(numpy.uint8)
+ else:
+ # Now it is supposed to be a bytes type
+ pass
+ data = ["\\x%02X" % d for d in data]
+ if self.__useQuoteForText:
+ return "b\"%s\"" % "".join(data)
+ else:
+ return "".join(data)
+
+ def __formatSafeAscii(self, data):
+ data = [chr(d) if (d > 0x20 and d < 0x7F) else "\\x%02X" % d for d in data]
+ if self.__useQuoteForText:
+ data = [c if c != '"' else "\\" + c for c in data]
+ return "b\"%s\"" % "".join(data)
+ else:
+ return "".join(data)
+
+ def __formatCharString(self, data):
+ """Format text of char.
+
+ From the specifications we expect to have ASCII, but we also allow
+ CP1252 in some ceases as fallback.
+
+ If no encoding fits, it will display a readable ASCII chars, with
+ escaped chars (using the python syntax) for non decoded characters.
+
+ :param data: A binary string of char expected in ASCII
+ :rtype: str
+ """
+ try:
+ text = "%s" % data.decode("ascii")
+ return self.__formatText(text)
+ except UnicodeDecodeError:
+ # Here we can spam errors, this is definitly a badly
+ # generated file
+ _logger.error("Invalid ASCII string %s.", data)
+ if data == b"\xB0":
+ _logger.error("Fallback using cp1252 encoding")
+ return self.__formatText(u"\u00B0")
+ return self.__formatSafeAscii(data)
+
+ def __formatH5pyObject(self, data, dtype):
+ # That's an HDF5 object
+ ref = h5py.check_dtype(ref=dtype)
+ if ref is not None:
+ if bool(data):
+ return "REF"
+ else:
+ return "NULL_REF"
+ vlen = h5py.check_dtype(vlen=dtype)
+ if vlen is not None:
+ if vlen == str:
+ # HDF5 UTF8
+ # With h5py>=3 reading dataset returns bytes
+ if isinstance(data, (bytes, numpy.bytes_)):
+ try:
+ data = data.decode("utf-8")
+ except UnicodeDecodeError:
+ self.__formatSafeAscii(data)
+ return self.__formatText(data)
+ elif vlen == bytes:
+ # HDF5 ASCII
+ return self.__formatCharString(data)
+ elif isinstance(vlen, numpy.dtype):
+ return self.toString(data, vlen)
+ return None
+
+ def toString(self, data, dtype=None):
+ """Format a data into a string using formatter options
+
+ :param object data: Data to render
+ :param dtype: enforce a dtype (mostly used to remember the h5py dtype,
+ special h5py dtypes are not propagated from array to items)
+ :rtype: str
+ """
+ if isinstance(data, tuple):
+ text = [self.toString(d) for d in data]
+ return "(" + " ".join(text) + ")"
+ elif isinstance(data, list):
+ text = [self.toString(d) for d in data]
+ return "[" + " ".join(text) + "]"
+ elif isinstance(data, numpy.ndarray):
+ if dtype is None:
+ dtype = data.dtype
+ if data.shape == ():
+ # it is a scaler
+ return self.toString(data[()], dtype)
+ else:
+ text = [self.toString(d, dtype) for d in data]
+ return "[" + " ".join(text) + "]"
+ if dtype is not None and dtype.kind == 'O':
+ text = self.__formatH5pyObject(data, dtype)
+ if text is not None:
+ return text
+ elif isinstance(data, numpy.void):
+ if dtype is None:
+ dtype = data.dtype
+ if dtype.fields is not None:
+ text = []
+ for index, field in enumerate(dtype.fields.items()):
+ text.append(field[0] + ":" + self.toString(data[index], field[1][0]))
+ return "(" + " ".join(text) + ")"
+ return self.__formatBinary(data)
+ elif isinstance(data, (numpy.unicode_, str)):
+ return self.__formatText(data)
+ elif isinstance(data, (numpy.string_, bytes)):
+ if dtype is None and hasattr(data, "dtype"):
+ dtype = data.dtype
+ if dtype is not None:
+ # Maybe a sub item from HDF5
+ if dtype.kind == 'S':
+ return self.__formatCharString(data)
+ elif dtype.kind == 'O':
+ text = self.__formatH5pyObject(data, dtype)
+ if text is not None:
+ return text
+ try:
+ # Try ascii/utf-8
+ text = "%s" % data.decode("utf-8")
+ return self.__formatText(text)
+ except UnicodeDecodeError:
+ pass
+ return self.__formatBinary(data)
+ elif isinstance(data, str):
+ text = "%s" % data
+ return self.__formatText(text)
+ elif isinstance(data, (numpy.integer)):
+ if dtype is None:
+ dtype = data.dtype
+ enumType = h5py.check_dtype(enum=dtype)
+ if enumType is not None:
+ for key, value in enumType.items():
+ if value == data:
+ result = {}
+ result["name"] = key
+ result["value"] = data
+ return self.__enumFormat % result
+ return self.__integerFormat % data
+ elif isinstance(data, (numbers.Integral)):
+ return self.__integerFormat % data
+ elif isinstance(data, (numbers.Real, numpy.floating)):
+ # It have to be done before complex checking
+ return self.__floatFormat % data
+ elif isinstance(data, (numpy.complexfloating, numbers.Complex)):
+ text = ""
+ if data.real != 0:
+ text += self.__floatFormat % data.real
+ if data.real != 0 and data.imag != 0:
+ if data.imag < 0:
+ template = self.__floatFormat + " - " + self.__floatFormat + self.__imaginaryUnit
+ params = (data.real, -data.imag)
+ else:
+ template = self.__floatFormat + " + " + self.__floatFormat + self.__imaginaryUnit
+ params = (data.real, data.imag)
+ else:
+ if data.imag != 0:
+ template = self.__floatFormat + self.__imaginaryUnit
+ params = (data.imag)
+ else:
+ template = self.__floatFormat
+ params = (data.real)
+ return template % params
+ elif isinstance(data, h5py.h5r.Reference):
+ dtype = h5py.special_dtype(ref=h5py.Reference)
+ text = self.__formatH5pyObject(data, dtype)
+ return text
+ elif isinstance(data, h5py.h5r.RegionReference):
+ dtype = h5py.special_dtype(ref=h5py.RegionReference)
+ text = self.__formatH5pyObject(data, dtype)
+ return text
+ elif isinstance(data, numpy.object_) or dtype is not None:
+ if dtype is None:
+ dtype = data.dtype
+ text = self.__formatH5pyObject(data, dtype)
+ if text is not None:
+ return text
+ # That's a numpy object
+ return str(data)
+ return str(data)
diff --git a/src/silx/gui/data/_RecordPlot.py b/src/silx/gui/data/_RecordPlot.py
new file mode 100644
index 0000000..5be792f
--- /dev/null
+++ b/src/silx/gui/data/_RecordPlot.py
@@ -0,0 +1,92 @@
+from silx.gui.plot.PlotWindow import PlotWindow
+from silx.gui.plot.PlotWidget import PlotWidget
+from .. import qt
+
+
+class RecordPlot(PlotWindow):
+ def __init__(self, parent=None, backend=None):
+ super(RecordPlot, self).__init__(parent=parent, backend=backend,
+ resetzoom=True, autoScale=True,
+ logScale=True, grid=True,
+ curveStyle=True, colormap=False,
+ aspectRatio=False, yInverted=False,
+ copy=True, save=True, print_=True,
+ control=True, position=True,
+ roi=True, mask=False, fit=True)
+ if parent is None:
+ self.setWindowTitle('RecordPlot')
+ self._axesSelectionToolBar = AxesSelectionToolBar(parent=self, plot=self)
+ self.addToolBar(qt.Qt.BottomToolBarArea, self._axesSelectionToolBar)
+
+ def setXAxisFieldName(self, value):
+ """Set the current selected field for the X axis.
+
+ :param Union[str,None] value:
+ """
+ label = '' if value is None else value
+ index = self._axesSelectionToolBar.getXAxisDropDown().findData(value)
+
+ if index >= 0:
+ self.getXAxis().setLabel(label)
+ self._axesSelectionToolBar.getXAxisDropDown().setCurrentIndex(index)
+
+ def getXAxisFieldName(self):
+ """Returns currently selected field for the X axis or None.
+
+ rtype: Union[str,None]
+ """
+ return self._axesSelectionToolBar.getXAxisDropDown().currentData()
+
+ def setYAxisFieldName(self, value):
+ self.getYAxis().setLabel(value)
+ index = self._axesSelectionToolBar.getYAxisDropDown().findText(value)
+ if index >= 0:
+ self._axesSelectionToolBar.getYAxisDropDown().setCurrentIndex(index)
+
+ def getYAxisFieldName(self):
+ return self._axesSelectionToolBar.getYAxisDropDown().currentText()
+
+ def setSelectableXAxisFieldNames(self, fieldNames):
+ """Add list of field names to X axis
+
+ :param List[str] fieldNames:
+ """
+ comboBox = self._axesSelectionToolBar.getXAxisDropDown()
+ comboBox.clear()
+ comboBox.addItem('-', None)
+ comboBox.insertSeparator(1)
+ for name in fieldNames:
+ comboBox.addItem(name, name)
+
+ def setSelectableYAxisFieldNames(self, fieldNames):
+ self._axesSelectionToolBar.getYAxisDropDown().clear()
+ self._axesSelectionToolBar.getYAxisDropDown().addItems(fieldNames)
+
+ def getAxesSelectionToolBar(self):
+ return self._axesSelectionToolBar
+
+class AxesSelectionToolBar(qt.QToolBar):
+ def __init__(self, parent=None, plot=None, title='Plot Axes Selection'):
+ super(AxesSelectionToolBar, self).__init__(title, parent)
+
+ assert isinstance(plot, PlotWidget)
+
+ self.addWidget(qt.QLabel("Field selection: "))
+
+ self._labelXAxis = qt.QLabel(" X: ")
+ self.addWidget(self._labelXAxis)
+
+ self._selectXAxisDropDown = qt.QComboBox()
+ self.addWidget(self._selectXAxisDropDown)
+
+ self._labelYAxis = qt.QLabel(" Y: ")
+ self.addWidget(self._labelYAxis)
+
+ self._selectYAxisDropDown = qt.QComboBox()
+ self.addWidget(self._selectYAxisDropDown)
+
+ def getXAxisDropDown(self):
+ return self._selectXAxisDropDown
+
+ def getYAxisDropDown(self):
+ return self._selectYAxisDropDown \ No newline at end of file
diff --git a/src/silx/gui/data/_VolumeWindow.py b/src/silx/gui/data/_VolumeWindow.py
new file mode 100644
index 0000000..03b6876
--- /dev/null
+++ b/src/silx/gui/data/_VolumeWindow.py
@@ -0,0 +1,148 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a widget to visualize 3D arrays"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "22/03/2019"
+
+
+import numpy
+
+from .. import qt
+from ..plot3d.SceneWindow import SceneWindow
+from ..plot3d.items import ScalarField3D, ComplexField3D, ItemChangedType
+
+
+class VolumeWindow(SceneWindow):
+ """Extends SceneWindow with a convenient API for 3D array
+
+ :param QWidget: parent
+ """
+
+ def __init__(self, parent):
+ super(VolumeWindow, self).__init__(parent)
+ self.__firstData = True
+ # Hide global parameter dock
+ self.getGroupResetWidget().parent().setVisible(False)
+
+ def setAxesLabels(self, xlabel=None, ylabel=None, zlabel=None):
+ """Set the text labels of the axes.
+
+ :param Union[str,None] xlabel: Label of the X axis
+ :param Union[str,None] ylabel: Label of the Y axis
+ :param Union[str,None] zlabel: Label of the Z axis
+ """
+ sceneWidget = self.getSceneWidget()
+ sceneWidget.getSceneGroup().setAxesLabels(
+ 'X' if xlabel is None else xlabel,
+ 'Y' if ylabel is None else ylabel,
+ 'Z' if zlabel is None else zlabel)
+
+ def clear(self):
+ """Clear any currently displayed data"""
+ sceneWidget = self.getSceneWidget()
+ items = sceneWidget.getItems()
+ if (len(items) == 1 and
+ isinstance(items[0], (ScalarField3D, ComplexField3D))):
+ items[0].setData(None)
+ else: # Safety net
+ sceneWidget.clearItems()
+
+ @staticmethod
+ def __computeIsolevel(data):
+ """Returns a suitable isolevel value for data
+
+ :param numpy.ndarray data:
+ :rtype: float
+ """
+ data = data[numpy.isfinite(data)]
+ if len(data) == 0:
+ return 0
+ else:
+ return numpy.mean(data) + numpy.std(data)
+
+ def setData(self, data, offset=(0., 0., 0.), scale=(1., 1., 1.)):
+ """Set the 3D array data to display.
+
+ :param numpy.ndarray data: 3D array of float or complex
+ :param List[float] offset: (tx, ty, tz) coordinates of the origin
+ :param List[float] scale: (sx, sy, sz) scale for each dimension
+ """
+ sceneWidget = self.getSceneWidget()
+ dataMaxCoords = numpy.array(list(reversed(data.shape))) - 1
+
+ previousItems = sceneWidget.getItems()
+ if (len(previousItems) == 1 and
+ isinstance(previousItems[0], (ScalarField3D, ComplexField3D)) and
+ numpy.iscomplexobj(data) == isinstance(previousItems[0], ComplexField3D)):
+ # Reuse existing volume item
+ volume = sceneWidget.getItems()[0]
+ volume.setData(data, copy=False)
+ # Make sure the plane goes through the dataset
+ for plane in volume.getCutPlanes():
+ point = numpy.array(plane.getPoint())
+ if numpy.any(point < (0, 0, 0)) or numpy.any(point > dataMaxCoords):
+ plane.setPoint(dataMaxCoords // 2)
+ else:
+ # Add a new volume
+ sceneWidget.clearItems()
+ volume = sceneWidget.addVolume(data, copy=False)
+ volume.setLabel('Volume')
+ for plane in volume.getCutPlanes():
+ # Make plane going through the center of the data
+ plane.setPoint(dataMaxCoords // 2)
+ plane.setVisible(False)
+ plane.sigItemChanged.connect(self.__cutPlaneUpdated)
+ volume.addIsosurface(self.__computeIsolevel, '#FF0000FF')
+
+ # Expand the parameter tree
+ model = self.getParamTreeView().model()
+ index = qt.QModelIndex() # Invalid index for top level
+ while 1:
+ rowCount = model.rowCount(parent=index)
+ if rowCount == 0:
+ break
+ index = model.index(rowCount - 1, 0, parent=index)
+ self.getParamTreeView().setExpanded(index, True)
+ if not index.isValid():
+ break
+
+ volume.setTranslation(*offset)
+ volume.setScale(*scale)
+
+ if self.__firstData: # Only center for first dataset
+ self.__firstData = False
+ sceneWidget.centerScene()
+
+ def __cutPlaneUpdated(self, event):
+ """Handle the change of visibility of the cut plane
+
+ :param event: Kind of update
+ """
+ if event == ItemChangedType.VISIBLE:
+ plane = self.sender()
+ if plane.isVisible():
+ self.getSceneWidget().selection().setCurrentItem(plane)
diff --git a/src/silx/gui/data/__init__.py b/src/silx/gui/data/__init__.py
new file mode 100644
index 0000000..560062d
--- /dev/null
+++ b/src/silx/gui/data/__init__.py
@@ -0,0 +1,35 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This package provides a set of Qt widgets for displaying data arrays using
+table views and plot widgets.
+
+.. note::
+
+ Widgets in this package may rely on additional dependencies that are
+ not mandatory for *silx*.
+ :class:`DataViewer.DataViewer` relies on :mod:`silx.gui.plot` which
+ depends on *matplotlib*. It also optionally depends on *PyOpenGL* for 3D
+ visualization.
+"""
diff --git a/src/silx/gui/data/setup.py b/src/silx/gui/data/setup.py
new file mode 100644
index 0000000..23ccbdd
--- /dev/null
+++ b/src/silx/gui/data/setup.py
@@ -0,0 +1,41 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "16/01/2017"
+
+
+from numpy.distutils.misc_util import Configuration
+
+
+def configuration(parent_package='', top_path=None):
+ config = Configuration('data', parent_package, top_path)
+ config.add_subpackage('test')
+ return config
+
+
+if __name__ == "__main__":
+ from numpy.distutils.core import setup
+ setup(configuration=configuration)
diff --git a/src/silx/gui/data/test/__init__.py b/src/silx/gui/data/test/__init__.py
new file mode 100644
index 0000000..7790ee5
--- /dev/null
+++ b/src/silx/gui/data/test/__init__.py
@@ -0,0 +1,24 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
diff --git a/src/silx/gui/data/test/test_arraywidget.py b/src/silx/gui/data/test/test_arraywidget.py
new file mode 100644
index 0000000..c84a34f
--- /dev/null
+++ b/src/silx/gui/data/test/test_arraywidget.py
@@ -0,0 +1,316 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "05/12/2016"
+
+import os
+import tempfile
+import unittest
+
+import numpy
+
+from silx.gui import qt
+from silx.gui.data import ArrayTableWidget
+from silx.gui.data.ArrayTableModel import ArrayTableModel
+from silx.gui.utils.testutils import TestCaseQt
+
+import h5py
+
+
+class TestArrayWidget(TestCaseQt):
+ """Basic test for ArrayTableWidget with a numpy array"""
+ def setUp(self):
+ super(TestArrayWidget, self).setUp()
+ self.aw = ArrayTableWidget.ArrayTableWidget()
+
+ def tearDown(self):
+ del self.aw
+ super(TestArrayWidget, self).tearDown()
+
+ def testShow(self):
+ """test for errors"""
+ self.aw.show()
+ self.qWaitForWindowExposed(self.aw)
+
+ def testSetData0D(self):
+ a = 1
+ self.aw.setArrayData(a)
+ b = self.aw.getData(copy=True)
+
+ self.assertTrue(numpy.array_equal(a, b))
+
+ # scalar/0D data has no frame index
+ self.assertEqual(len(self.aw.model._index), 0)
+ # and no perspective
+ self.assertEqual(len(self.aw.model._perspective), 0)
+
+ def testSetData1D(self):
+ a = [1, 2]
+ self.aw.setArrayData(a)
+ b = self.aw.getData(copy=True)
+
+ self.assertTrue(numpy.array_equal(a, b))
+
+ # 1D data has no frame index
+ self.assertEqual(len(self.aw.model._index), 0)
+ # and no perspective
+ self.assertEqual(len(self.aw.model._perspective), 0)
+
+ def testSetData4D(self):
+ a = numpy.reshape(numpy.linspace(0.213, 1.234, 1250),
+ (5, 5, 5, 10))
+ self.aw.setArrayData(a)
+
+ # default perspective (0, 1)
+ self.assertEqual(list(self.aw.model._perspective),
+ [0, 1])
+ self.aw.setPerspective((1, 3))
+ self.assertEqual(list(self.aw.model._perspective),
+ [1, 3])
+
+ b = self.aw.getData(copy=True)
+ self.assertTrue(numpy.array_equal(a, b))
+
+ # 4D data has a 2-tuple as frame index
+ self.assertEqual(len(self.aw.model._index), 2)
+ # default index is (0, 0)
+ self.assertEqual(list(self.aw.model._index),
+ [0, 0])
+ self.aw.setFrameIndex((3, 1))
+
+ self.assertEqual(list(self.aw.model._index),
+ [3, 1])
+
+ def testColors(self):
+ a = numpy.arange(256, dtype=numpy.uint8)
+ self.aw.setArrayData(a)
+
+ bgcolor = numpy.empty(a.shape + (3,), dtype=numpy.uint8)
+ # Black & white palette
+ bgcolor[..., 0] = a
+ bgcolor[..., 1] = a
+ bgcolor[..., 2] = a
+
+ fgcolor = numpy.bitwise_xor(bgcolor, 255)
+
+ self.aw.setArrayColors(bgcolor, fgcolor)
+
+ # test colors are as expected in model
+ for i in range(256):
+ # all RGB channels for BG equal to data value
+ self.assertEqual(
+ self.aw.model.data(self.aw.model.index(0, i),
+ role=qt.Qt.BackgroundRole),
+ qt.QColor(i, i, i),
+ "Unexpected background color"
+ )
+
+ # all RGB channels for FG equal to XOR(data value, 255)
+ self.assertEqual(
+ self.aw.model.data(self.aw.model.index(0, i),
+ role=qt.Qt.ForegroundRole),
+ qt.QColor(i ^ 255, i ^ 255, i ^ 255),
+ "Unexpected text color"
+ )
+
+ # test colors are reset to None when a new data array is loaded
+ # with different shape
+ self.aw.setArrayData(numpy.arange(300))
+
+ for i in range(300):
+ # all RGB channels for BG equal to data value
+ self.assertIsNone(
+ self.aw.model.data(self.aw.model.index(0, i),
+ role=qt.Qt.BackgroundRole))
+
+ def testDefaultFlagNotEditable(self):
+ """editable should be False by default, in setArrayData"""
+ self.aw.setArrayData([[0]])
+ idx = self.aw.model.createIndex(0, 0)
+ # model is editable
+ self.assertFalse(
+ self.aw.model.flags(idx) & qt.Qt.ItemIsEditable)
+
+ def testFlagEditable(self):
+ self.aw.setArrayData([[0]], editable=True)
+ idx = self.aw.model.createIndex(0, 0)
+ # model is editable
+ self.assertTrue(
+ self.aw.model.flags(idx) & qt.Qt.ItemIsEditable)
+
+ def testFlagNotEditable(self):
+ self.aw.setArrayData([[0]], editable=False)
+ idx = self.aw.model.createIndex(0, 0)
+ # model is editable
+ self.assertFalse(
+ self.aw.model.flags(idx) & qt.Qt.ItemIsEditable)
+
+ def testReferenceReturned(self):
+ """when setting the data with copy=False and
+ retrieving it with getData(copy=False), we should recover
+ the same original object.
+ """
+ # n-D (n >=2)
+ a0 = numpy.reshape(numpy.linspace(0.213, 1.234, 1000),
+ (10, 10, 10))
+ self.aw.setArrayData(a0, copy=False)
+ a1 = self.aw.getData(copy=False)
+
+ self.assertIs(a0, a1)
+
+ # 1D
+ b0 = numpy.linspace(0.213, 1.234, 1000)
+ self.aw.setArrayData(b0, copy=False)
+ b1 = self.aw.getData(copy=False)
+ self.assertIs(b0, b1)
+
+ def testClipping(self):
+ """Test clipping of large arrays"""
+ self.aw.show()
+ self.qWaitForWindowExposed(self.aw)
+
+ data = numpy.arange(ArrayTableModel.MAX_NUMBER_OF_SECTIONS + 10)
+
+ for shape in [(1, -1), (-1, 1)]:
+ with self.subTest(shape=shape):
+ self.aw.setArrayData(data.reshape(shape), editable=True)
+ self.qapp.processEvents()
+
+
+class TestH5pyArrayWidget(TestCaseQt):
+ """Basic test for ArrayTableWidget with a dataset.
+
+ Test flags, for dataset open in read-only or read-write modes"""
+ def setUp(self):
+ super(TestH5pyArrayWidget, self).setUp()
+ self.aw = ArrayTableWidget.ArrayTableWidget()
+ self.data = numpy.reshape(numpy.linspace(0.213, 1.234, 1000),
+ (10, 10, 10))
+ # create an h5py file with a dataset
+ self.tempdir = tempfile.mkdtemp()
+ self.h5_fname = os.path.join(self.tempdir, "array.h5")
+ h5f = h5py.File(self.h5_fname, mode='w')
+ h5f["my_array"] = self.data
+ h5f["my_scalar"] = 3.14
+ h5f["my_1D_array"] = numpy.array(numpy.arange(1000))
+ h5f.close()
+
+ def tearDown(self):
+ del self.aw
+ os.unlink(self.h5_fname)
+ os.rmdir(self.tempdir)
+ super(TestH5pyArrayWidget, self).tearDown()
+
+ def testShow(self):
+ self.aw.show()
+ self.qWaitForWindowExposed(self.aw)
+
+ def testReadOnly(self):
+ """Open H5 dataset in read-only mode, ensure the model is not editable."""
+ h5f = h5py.File(self.h5_fname, "r")
+ a = h5f["my_array"]
+ # ArrayTableModel relies on following condition
+ self.assertTrue(a.file.mode == "r")
+
+ self.aw.setArrayData(a, copy=False, editable=True)
+
+ self.assertIsInstance(a, h5py.Dataset) # simple sanity check
+ # internal representation must be a reference to original data (copy=False)
+ self.assertIsInstance(self.aw.model._array, h5py.Dataset)
+ self.assertTrue(self.aw.model._array.file.mode == "r")
+
+ b = self.aw.getData()
+ self.assertTrue(numpy.array_equal(self.data, b))
+
+ # model must have detected read-only dataset and disabled editing
+ self.assertFalse(self.aw.model._editable)
+ idx = self.aw.model.createIndex(0, 0)
+ self.assertFalse(
+ self.aw.model.flags(idx) & qt.Qt.ItemIsEditable)
+
+ # force editing read-only datasets raises IOError
+ self.assertRaises(IOError, self.aw.model.setData,
+ idx, 123.4, role=qt.Qt.EditRole)
+ h5f.close()
+
+ def testReadWrite(self):
+ h5f = h5py.File(self.h5_fname, "r+")
+ a = h5f["my_array"]
+ self.assertTrue(a.file.mode == "r+")
+
+ self.aw.setArrayData(a, copy=False, editable=True)
+ b = self.aw.getData(copy=False)
+ self.assertTrue(numpy.array_equal(self.data, b))
+
+ idx = self.aw.model.createIndex(0, 0)
+ # model is editable
+ self.assertTrue(
+ self.aw.model.flags(idx) & qt.Qt.ItemIsEditable)
+ h5f.close()
+
+ def testSetData0D(self):
+ h5f = h5py.File(self.h5_fname, "r+")
+ a = h5f["my_scalar"]
+ self.aw.setArrayData(a)
+ b = self.aw.getData(copy=True)
+
+ self.assertTrue(numpy.array_equal(a, b))
+
+ h5f.close()
+
+ def testSetData1D(self):
+ h5f = h5py.File(self.h5_fname, "r+")
+ a = h5f["my_1D_array"]
+ self.aw.setArrayData(a)
+ b = self.aw.getData(copy=True)
+
+ self.assertTrue(numpy.array_equal(a, b))
+
+ h5f.close()
+
+ def testReferenceReturned(self):
+ """when setting the data with copy=False and
+ retrieving it with getData(copy=False), we should recover
+ the same original object.
+
+ This only works for array with at least 2D. For 1D and 0D
+ arrays, a view is created at some point, which in the case
+ of an hdf5 dataset creates a copy."""
+ h5f = h5py.File(self.h5_fname, "r+")
+
+ # n-D
+ a0 = h5f["my_array"]
+ self.aw.setArrayData(a0, copy=False)
+ a1 = self.aw.getData(copy=False)
+ self.assertIs(a0, a1)
+
+ # 1D
+ b0 = h5f["my_1D_array"]
+ self.aw.setArrayData(b0, copy=False)
+ b1 = self.aw.getData(copy=False)
+ self.assertIs(b0, b1)
+
+ h5f.close()
diff --git a/src/silx/gui/data/test/test_dataviewer.py b/src/silx/gui/data/test/test_dataviewer.py
new file mode 100644
index 0000000..30b76ce
--- /dev/null
+++ b/src/silx/gui/data/test/test_dataviewer.py
@@ -0,0 +1,304 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "19/02/2019"
+
+import os
+import tempfile
+import pytest
+from contextlib import contextmanager
+
+import numpy
+from ..DataViewer import DataViewer
+from ..DataViews import DataView
+from .. import DataViews
+
+from silx.gui import qt
+
+from silx.gui.data.DataViewerFrame import DataViewerFrame
+from silx.gui.utils.testutils import SignalListener
+from silx.gui.utils.testutils import TestCaseQt
+
+import h5py
+
+
+class _DataViewMock(DataView):
+ """Dummy view to display nothing"""
+
+ def __init__(self, parent):
+ DataView.__init__(self, parent)
+
+ def axesNames(self, data, info):
+ return []
+
+ def createWidget(self, parent):
+ return qt.QLabel(parent)
+
+ def getDataPriority(self, data, info):
+ return 0
+
+
+class _TestAbstractDataViewer(TestCaseQt):
+ __test__ = False # ignore abstract class
+
+ def create_widget(self):
+ # Avoid to raise an error when testing the full module
+ self.skipTest("Not implemented")
+
+ @contextmanager
+ def h5_temporary_file(self):
+ # create tmp file
+ fd, tmp_name = tempfile.mkstemp(suffix=".h5")
+ os.close(fd)
+ data = numpy.arange(3 * 3 * 3)
+ data.shape = 3, 3, 3
+ # create h5 data
+ h5file = h5py.File(tmp_name, "w")
+ h5file["data"] = data
+ yield h5file
+ # clean up
+ h5file.close()
+ os.unlink(tmp_name)
+
+ def test_text_data(self):
+ data_list = ["aaa", int, 8, self]
+ widget = self.create_widget()
+ for data in data_list:
+ widget.setData(data)
+ self.assertEqual(DataViews.RAW_MODE, widget.displayMode())
+
+ def test_plot_1d_data(self):
+ data = numpy.arange(3 ** 1)
+ data.shape = [3] * 1
+ widget = self.create_widget()
+ widget.setData(data)
+ availableModes = set([v.modeId() for v in widget.currentAvailableViews()])
+ self.assertEqual(DataViews.RAW_MODE, widget.displayMode())
+ self.assertIn(DataViews.PLOT1D_MODE, availableModes)
+
+ def test_image_data(self):
+ data = numpy.arange(3 ** 2)
+ data.shape = [3] * 2
+ widget = self.create_widget()
+ widget.setData(data)
+ availableModes = set([v.modeId() for v in widget.currentAvailableViews()])
+ self.assertEqual(DataViews.RAW_MODE, widget.displayMode())
+ self.assertIn(DataViews.IMAGE_MODE, availableModes)
+
+ def test_image_bool(self):
+ data = numpy.zeros((10, 10), dtype=bool)
+ data[::2, ::2] = True
+ widget = self.create_widget()
+ widget.setData(data)
+ availableModes = set([v.modeId() for v in widget.currentAvailableViews()])
+ self.assertEqual(DataViews.RAW_MODE, widget.displayMode())
+ self.assertIn(DataViews.IMAGE_MODE, availableModes)
+
+ def test_image_complex_data(self):
+ data = numpy.arange(3 ** 2, dtype=numpy.complex64)
+ data.shape = [3] * 2
+ widget = self.create_widget()
+ widget.setData(data)
+ availableModes = set([v.modeId() for v in widget.currentAvailableViews()])
+ self.assertEqual(DataViews.RAW_MODE, widget.displayMode())
+ self.assertIn(DataViews.IMAGE_MODE, availableModes)
+
+ def test_plot_3d_data(self):
+ data = numpy.arange(3 ** 3)
+ data.shape = [3] * 3
+ widget = self.create_widget()
+ widget.setData(data)
+ availableModes = set([v.modeId() for v in widget.currentAvailableViews()])
+ try:
+ import silx.gui.plot3d # noqa
+ self.assertIn(DataViews.PLOT3D_MODE, availableModes)
+ except ImportError:
+ self.assertIn(DataViews.STACK_MODE, availableModes)
+ self.assertEqual(DataViews.RAW_MODE, widget.displayMode())
+
+ def test_array_1d_data(self):
+ data = numpy.array(["aaa"] * (3 ** 1))
+ data.shape = [3] * 1
+ widget = self.create_widget()
+ widget.setData(data)
+ self.assertEqual(DataViews.RAW_MODE, widget.displayedView().modeId())
+
+ def test_array_2d_data(self):
+ data = numpy.array(["aaa"] * (3 ** 2))
+ data.shape = [3] * 2
+ widget = self.create_widget()
+ widget.setData(data)
+ self.assertEqual(DataViews.RAW_MODE, widget.displayedView().modeId())
+
+ def test_array_4d_data(self):
+ data = numpy.array(["aaa"] * (3 ** 4))
+ data.shape = [3] * 4
+ widget = self.create_widget()
+ widget.setData(data)
+ self.assertEqual(DataViews.RAW_MODE, widget.displayedView().modeId())
+
+ def test_record_4d_data(self):
+ data = numpy.zeros(3 ** 4, dtype='3int8, float32, (2,3)float64')
+ data.shape = [3] * 4
+ widget = self.create_widget()
+ widget.setData(data)
+ self.assertEqual(DataViews.RAW_MODE, widget.displayedView().modeId())
+
+ def test_3d_h5_dataset(self):
+ with self.h5_temporary_file() as h5file:
+ dataset = h5file["data"]
+ widget = self.create_widget()
+ widget.setData(dataset)
+
+ def test_data_event(self):
+ listener = SignalListener()
+ widget = self.create_widget()
+ widget.dataChanged.connect(listener)
+ widget.setData(10)
+ widget.setData(None)
+ self.assertEqual(listener.callCount(), 2)
+
+ def test_display_mode_event(self):
+ listener = SignalListener()
+ widget = self.create_widget()
+ widget.displayedViewChanged.connect(listener)
+ widget.setData(10)
+ widget.setData(None)
+ modes = [v.modeId() for v in listener.arguments(argumentIndex=0)]
+ self.assertEqual(modes, [DataViews.RAW_MODE, DataViews.EMPTY_MODE])
+ listener.clear()
+
+ def test_change_display_mode(self):
+ data = numpy.arange(10 ** 4)
+ data.shape = [10] * 4
+ widget = self.create_widget()
+ widget.setData(data)
+ widget.setDisplayMode(DataViews.PLOT1D_MODE)
+ self.assertEqual(widget.displayedView().modeId(), DataViews.PLOT1D_MODE)
+ widget.setDisplayMode(DataViews.IMAGE_MODE)
+ self.assertEqual(widget.displayedView().modeId(), DataViews.IMAGE_MODE)
+ widget.setDisplayMode(DataViews.RAW_MODE)
+ self.assertEqual(widget.displayedView().modeId(), DataViews.RAW_MODE)
+ widget.setDisplayMode(DataViews.EMPTY_MODE)
+ self.assertEqual(widget.displayedView().modeId(), DataViews.EMPTY_MODE)
+
+ def test_create_default_views(self):
+ widget = self.create_widget()
+ views = widget.createDefaultViews()
+ self.assertTrue(len(views) > 0)
+
+ def test_add_view(self):
+ widget = self.create_widget()
+ view = _DataViewMock(widget)
+ widget.addView(view)
+ self.assertTrue(view in widget.availableViews())
+ self.assertTrue(view in widget.currentAvailableViews())
+
+ def test_remove_view(self):
+ widget = self.create_widget()
+ widget.setData("foobar")
+ view = widget.currentAvailableViews()[0]
+ widget.removeView(view)
+ self.assertTrue(view not in widget.availableViews())
+ self.assertTrue(view not in widget.currentAvailableViews())
+
+ def test_replace_view(self):
+ widget = self.create_widget()
+ view = _DataViewMock(widget)
+ widget.replaceView(DataViews.RAW_MODE,
+ view)
+ self.assertIsNone(widget.getViewFromModeId(DataViews.RAW_MODE))
+ self.assertTrue(view in widget.availableViews())
+ self.assertTrue(view in widget.currentAvailableViews())
+
+ def test_replace_view_in_composite(self):
+ # replace a view that is a child of a composite view
+ widget = self.create_widget()
+ view = _DataViewMock(widget)
+ replaced = widget.replaceView(DataViews.NXDATA_INVALID_MODE,
+ view)
+ self.assertTrue(replaced)
+ nxdata_view = widget.getViewFromModeId(DataViews.NXDATA_MODE)
+ self.assertNotIn(DataViews.NXDATA_INVALID_MODE,
+ [v.modeId() for v in nxdata_view.getViews()])
+ self.assertTrue(view in nxdata_view.getViews())
+
+
+class TestDataViewer(_TestAbstractDataViewer):
+ __test__ = True # because _TestAbstractDataViewer is ignored
+ def create_widget(self):
+ return DataViewer()
+
+
+class TestDataViewerFrame(_TestAbstractDataViewer):
+ __test__ = True # because _TestAbstractDataViewer is ignored
+ def create_widget(self):
+ return DataViewerFrame()
+
+
+class TestDataView(TestCaseQt):
+
+ def createComplexData(self):
+ line = [1, 2j, 3 + 3j, 4]
+ image = [line, line, line, line]
+ cube = [image, image, image, image]
+ data = numpy.array(cube, dtype=numpy.complex64)
+ return data
+
+ def createDataViewWithData(self, dataViewClass, data):
+ viewer = dataViewClass(None)
+ widget = viewer.getWidget()
+ viewer.setData(data)
+ return widget
+
+ def testCurveWithComplex(self):
+ data = self.createComplexData()
+ dataViewClass = DataViews._Plot1dView
+ widget = self.createDataViewWithData(dataViewClass, data[0, 0])
+ self.qWaitForWindowExposed(widget)
+
+ def testImageWithComplex(self):
+ data = self.createComplexData()
+ dataViewClass = DataViews._Plot2dView
+ widget = self.createDataViewWithData(dataViewClass, data[0])
+ self.qWaitForWindowExposed(widget)
+
+ @pytest.mark.usefixtures("use_opengl")
+ def testCubeWithComplex(self):
+ try:
+ import silx.gui.plot3d # noqa
+ except ImportError:
+ self.skipTest("OpenGL not available")
+ data = self.createComplexData()
+ dataViewClass = DataViews._Plot3dView
+ widget = self.createDataViewWithData(dataViewClass, data)
+ self.qWaitForWindowExposed(widget)
+
+ def testImageStackWithComplex(self):
+ data = self.createComplexData()
+ dataViewClass = DataViews._StackView
+ widget = self.createDataViewWithData(dataViewClass, data)
+ self.qWaitForWindowExposed(widget)
diff --git a/src/silx/gui/data/test/test_numpyaxesselector.py b/src/silx/gui/data/test/test_numpyaxesselector.py
new file mode 100644
index 0000000..37b8d3e
--- /dev/null
+++ b/src/silx/gui/data/test/test_numpyaxesselector.py
@@ -0,0 +1,150 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "29/01/2018"
+
+import os
+import tempfile
+import unittest
+from contextlib import contextmanager
+
+import numpy
+
+from silx.gui.data.NumpyAxesSelector import NumpyAxesSelector
+from silx.gui.utils.testutils import SignalListener
+from silx.gui.utils.testutils import TestCaseQt
+
+import h5py
+
+
+class TestNumpyAxesSelector(TestCaseQt):
+
+ def test_creation(self):
+ data = numpy.arange(3 * 3 * 3)
+ data.shape = 3, 3, 3
+ widget = NumpyAxesSelector()
+ widget.setVisible(True)
+
+ def test_none(self):
+ data = numpy.arange(3 * 3 * 3)
+ widget = NumpyAxesSelector()
+ widget.setData(data)
+ widget.setData(None)
+ result = widget.selectedData()
+ self.assertIsNone(result)
+
+ def test_output_samedim(self):
+ data = numpy.arange(3 * 3 * 3)
+ data.shape = 3, 3, 3
+ expectedResult = data
+
+ widget = NumpyAxesSelector()
+ widget.setAxisNames(["x", "y", "z"])
+ widget.setData(data)
+ result = widget.selectedData()
+ self.assertTrue(numpy.array_equal(result, expectedResult))
+
+ def test_output_moredim(self):
+ data = numpy.arange(3 * 3 * 3 * 3)
+ data.shape = 3, 3, 3, 3
+ expectedResult = data
+
+ widget = NumpyAxesSelector()
+ widget.setAxisNames(["x", "y", "z", "boum"])
+ widget.setData(data[0])
+ result = widget.selectedData()
+ self.assertIsNone(result)
+ widget.setData(data)
+ result = widget.selectedData()
+ self.assertTrue(numpy.array_equal(result, expectedResult))
+
+ def test_output_lessdim(self):
+ data = numpy.arange(3 * 3 * 3)
+ data.shape = 3, 3, 3
+ expectedResult = data[0]
+
+ widget = NumpyAxesSelector()
+ widget.setAxisNames(["y", "x"])
+ widget.setData(data)
+ result = widget.selectedData()
+ self.assertTrue(numpy.array_equal(result, expectedResult))
+
+ def test_output_1dim(self):
+ data = numpy.arange(3 * 3 * 3)
+ data.shape = 3, 3, 3
+ expectedResult = data[0, 0, 0]
+
+ widget = NumpyAxesSelector()
+ widget.setData(data)
+ result = widget.selectedData()
+ self.assertTrue(numpy.array_equal(result, expectedResult))
+
+ @contextmanager
+ def h5_temporary_file(self):
+ # create tmp file
+ fd, tmp_name = tempfile.mkstemp(suffix=".h5")
+ os.close(fd)
+ data = numpy.arange(3 * 3 * 3)
+ data.shape = 3, 3, 3
+ # create h5 data
+ h5file = h5py.File(tmp_name, "w")
+ h5file["data"] = data
+ yield h5file
+ # clean up
+ h5file.close()
+ os.unlink(tmp_name)
+
+ def test_h5py_dataset(self):
+ with self.h5_temporary_file() as h5file:
+ dataset = h5file["data"]
+ expectedResult = dataset[0]
+
+ widget = NumpyAxesSelector()
+ widget.setData(dataset)
+ widget.setAxisNames(["y", "x"])
+ result = widget.selectedData()
+ self.assertTrue(numpy.array_equal(result, expectedResult))
+
+ def test_data_event(self):
+ data = numpy.arange(3 * 3 * 3)
+ widget = NumpyAxesSelector()
+ listener = SignalListener()
+ widget.dataChanged.connect(listener)
+ widget.setData(data)
+ widget.setData(None)
+ self.assertEqual(listener.callCount(), 2)
+
+ def test_selected_data_event(self):
+ data = numpy.arange(3 * 3 * 3)
+ data.shape = 3, 3, 3
+ widget = NumpyAxesSelector()
+ listener = SignalListener()
+ widget.selectionChanged.connect(listener)
+ widget.setData(data)
+ widget.setAxisNames(["x"])
+ widget.setData(None)
+ self.assertEqual(listener.callCount(), 3)
+ listener.clear()
diff --git a/src/silx/gui/data/test/test_textformatter.py b/src/silx/gui/data/test/test_textformatter.py
new file mode 100644
index 0000000..af41def
--- /dev/null
+++ b/src/silx/gui/data/test/test_textformatter.py
@@ -0,0 +1,199 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "12/12/2017"
+
+import unittest
+import shutil
+import tempfile
+
+import numpy
+
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui.utils.testutils import SignalListener
+from ..TextFormatter import TextFormatter
+from silx.io.utils import h5py_read_dataset
+
+import h5py
+
+
+class TestTextFormatter(TestCaseQt):
+
+ def test_copy(self):
+ formatter = TextFormatter()
+ copy = TextFormatter(formatter=formatter)
+ self.assertIsNot(formatter, copy)
+ copy.setFloatFormat("%.3f")
+ self.assertEqual(formatter.integerFormat(), copy.integerFormat())
+ self.assertNotEqual(formatter.floatFormat(), copy.floatFormat())
+ self.assertEqual(formatter.useQuoteForText(), copy.useQuoteForText())
+ self.assertEqual(formatter.imaginaryUnit(), copy.imaginaryUnit())
+
+ def test_event(self):
+ listener = SignalListener()
+ formatter = TextFormatter()
+ formatter.formatChanged.connect(listener)
+ formatter.setFloatFormat("%.3f")
+ formatter.setIntegerFormat("%03i")
+ formatter.setUseQuoteForText(False)
+ formatter.setImaginaryUnit("z")
+ self.assertEqual(listener.callCount(), 4)
+
+ def test_int(self):
+ formatter = TextFormatter()
+ formatter.setIntegerFormat("%05i")
+ result = formatter.toString(512)
+ self.assertEqual(result, "00512")
+
+ def test_float(self):
+ formatter = TextFormatter()
+ formatter.setFloatFormat("%.3f")
+ result = formatter.toString(1.3)
+ self.assertEqual(result, "1.300")
+
+ def test_complex(self):
+ formatter = TextFormatter()
+ formatter.setFloatFormat("%.1f")
+ formatter.setImaginaryUnit("i")
+ result = formatter.toString(1.0 + 5j)
+ result = result.replace(" ", "")
+ self.assertEqual(result, "1.0+5.0i")
+
+ def test_string(self):
+ formatter = TextFormatter()
+ formatter.setIntegerFormat("%.1f")
+ formatter.setImaginaryUnit("z")
+ result = formatter.toString("toto")
+ self.assertEqual(result, '"toto"')
+
+ def test_numpy_void(self):
+ formatter = TextFormatter()
+ result = formatter.toString(numpy.void(b"\xFF"))
+ self.assertEqual(result, 'b"\\xFF"')
+
+ def test_char_cp1252(self):
+ # degree character in cp1252
+ formatter = TextFormatter()
+ result = formatter.toString(numpy.bytes_(b"\xB0"))
+ self.assertEqual(result, u'"\u00B0"')
+
+
+class TestTextFormatterWithH5py(TestCaseQt):
+
+ @classmethod
+ def setUpClass(cls):
+ super(TestTextFormatterWithH5py, cls).setUpClass()
+
+ cls.tmpDirectory = tempfile.mkdtemp()
+ cls.h5File = h5py.File("%s/formatter.h5" % cls.tmpDirectory, mode="w")
+ cls.formatter = TextFormatter()
+
+ @classmethod
+ def tearDownClass(cls):
+ super(TestTextFormatterWithH5py, cls).tearDownClass()
+ cls.h5File.close()
+ cls.h5File = None
+ shutil.rmtree(cls.tmpDirectory)
+
+ def create_dataset(self, data, dtype=None):
+ testName = "%s" % self.id()
+ dataset = self.h5File.create_dataset(testName, data=data, dtype=dtype)
+ return dataset
+
+ def read_dataset(self, d):
+ return self.formatter.toString(d[()], dtype=d.dtype)
+
+ def testAscii(self):
+ d = self.create_dataset(data=b"abc")
+ result = self.read_dataset(d)
+ self.assertEqual(result, '"abc"')
+
+ def testUnicode(self):
+ d = self.create_dataset(data=u"i\u2661cookies")
+ result = self.read_dataset(d)
+ self.assertEqual(len(result), 11)
+ self.assertEqual(result, u'"i\u2661cookies"')
+
+ def testBadAscii(self):
+ d = self.create_dataset(data=b"\xF0\x9F\x92\x94")
+ result = self.read_dataset(d)
+ self.assertEqual(result, 'b"\\xF0\\x9F\\x92\\x94"')
+
+ def testVoid(self):
+ d = self.create_dataset(data=numpy.void(b"abc\xF0"))
+ result = self.read_dataset(d)
+ self.assertEqual(result, 'b"\\x61\\x62\\x63\\xF0"')
+
+ def testEnum(self):
+ dtype = h5py.special_dtype(enum=('i', {"RED": 0, "GREEN": 1, "BLUE": 42}))
+ d = numpy.array(42, dtype=dtype)
+ d = self.create_dataset(data=d)
+ result = self.read_dataset(d)
+ self.assertEqual(result, 'BLUE(42)')
+
+ def testRef(self):
+ dtype = h5py.special_dtype(ref=h5py.Reference)
+ d = numpy.array(self.h5File.ref, dtype=dtype)
+ d = self.create_dataset(data=d)
+ result = self.read_dataset(d)
+ self.assertEqual(result, 'REF')
+
+ def testArrayAscii(self):
+ d = self.create_dataset(data=[b"abc"])
+ result = self.read_dataset(d)
+ self.assertEqual(result, '["abc"]')
+
+ def testArrayUnicode(self):
+ dtype = h5py.special_dtype(vlen=str)
+ d = numpy.array([u"i\u2661cookies"], dtype=dtype)
+ d = self.create_dataset(data=d)
+ result = self.read_dataset(d)
+ self.assertEqual(len(result), 13)
+ self.assertEqual(result, u'["i\u2661cookies"]')
+
+ def testArrayBadAscii(self):
+ d = self.create_dataset(data=[b"\xF0\x9F\x92\x94"])
+ result = self.read_dataset(d)
+ self.assertEqual(result, '[b"\\xF0\\x9F\\x92\\x94"]')
+
+ def testArrayVoid(self):
+ d = self.create_dataset(data=numpy.void([b"abc\xF0"]))
+ result = self.read_dataset(d)
+ self.assertEqual(result, '[b"\\x61\\x62\\x63\\xF0"]')
+
+ def testArrayEnum(self):
+ dtype = h5py.special_dtype(enum=('i', {"RED": 0, "GREEN": 1, "BLUE": 42}))
+ d = numpy.array([42, 1, 100], dtype=dtype)
+ d = self.create_dataset(data=d)
+ result = self.read_dataset(d)
+ self.assertEqual(result, '[BLUE(42) GREEN(1) 100]')
+
+ def testArrayRef(self):
+ dtype = h5py.special_dtype(ref=h5py.Reference)
+ d = numpy.array([self.h5File.ref, None], dtype=dtype)
+ d = self.create_dataset(data=d)
+ result = self.read_dataset(d)
+ self.assertEqual(result, '[REF NULL_REF]')
diff --git a/src/silx/gui/dialog/AbstractDataFileDialog.py b/src/silx/gui/dialog/AbstractDataFileDialog.py
new file mode 100644
index 0000000..5272f48
--- /dev/null
+++ b/src/silx/gui/dialog/AbstractDataFileDialog.py
@@ -0,0 +1,1731 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+This module contains an :class:`AbstractDataFileDialog`.
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "05/03/2019"
+
+
+import sys
+import os
+import logging
+import functools
+from distutils.version import LooseVersion
+
+import numpy
+
+import silx.io.url
+from silx.gui import qt
+from silx.gui.hdf5.Hdf5TreeModel import Hdf5TreeModel
+from . import utils
+from .FileTypeComboBox import FileTypeComboBox
+
+import fabio
+
+
+_logger = logging.getLogger(__name__)
+
+
+DEFAULT_SIDEBAR_URL = True
+"""Set it to false to disable initilializing of the sidebar urls with the
+default Qt list. This could allow to disable a behaviour known to segfault on
+some version of PyQt."""
+
+
+class _IconProvider(object):
+
+ FileDialogToParentDir = qt.QStyle.SP_CustomBase + 1
+
+ FileDialogToParentFile = qt.QStyle.SP_CustomBase + 2
+
+ def __init__(self):
+ self.__iconFileDialogToParentDir = None
+ self.__iconFileDialogToParentFile = None
+
+ def _createIconToParent(self, standardPixmap):
+ """
+
+ FIXME: It have to be tested for some OS (arrow icon do not have always
+ the same direction)
+ """
+ style = qt.QApplication.style()
+ baseIcon = style.standardIcon(qt.QStyle.SP_FileDialogToParent)
+ backgroundIcon = style.standardIcon(standardPixmap)
+ icon = qt.QIcon()
+
+ sizes = baseIcon.availableSizes()
+ sizes = sorted(sizes, key=lambda s: s.height())
+ sizes = filter(lambda s: s.height() < 100, sizes)
+ sizes = list(sizes)
+ if len(sizes) > 0:
+ baseSize = sizes[-1]
+ else:
+ baseSize = baseIcon.availableSizes()[0]
+ size = qt.QSize(baseSize.width(), baseSize.height() * 3 // 2)
+
+ modes = [qt.QIcon.Normal, qt.QIcon.Disabled]
+ for mode in modes:
+ pixmap = qt.QPixmap(size)
+ pixmap.fill(qt.Qt.transparent)
+ painter = qt.QPainter(pixmap)
+ painter.drawPixmap(0, 0, backgroundIcon.pixmap(baseSize, mode=mode))
+ painter.drawPixmap(0, size.height() // 3, baseIcon.pixmap(baseSize, mode=mode))
+ painter.end()
+ icon.addPixmap(pixmap, mode=mode)
+
+ return icon
+
+ def getFileDialogToParentDir(self):
+ if self.__iconFileDialogToParentDir is None:
+ self.__iconFileDialogToParentDir = self._createIconToParent(qt.QStyle.SP_DirIcon)
+ return self.__iconFileDialogToParentDir
+
+ def getFileDialogToParentFile(self):
+ if self.__iconFileDialogToParentFile is None:
+ self.__iconFileDialogToParentFile = self._createIconToParent(qt.QStyle.SP_FileIcon)
+ return self.__iconFileDialogToParentFile
+
+ def icon(self, kind):
+ if kind == self.FileDialogToParentDir:
+ return self.getFileDialogToParentDir()
+ elif kind == self.FileDialogToParentFile:
+ return self.getFileDialogToParentFile()
+ else:
+ style = qt.QApplication.style()
+ icon = style.standardIcon(kind)
+ return icon
+
+
+class _SideBar(qt.QListView):
+ """Sidebar containing shortcuts for common directories"""
+
+ def __init__(self, parent=None):
+ super(_SideBar, self).__init__(parent)
+ self.__iconProvider = qt.QFileIconProvider()
+ self.setUniformItemSizes(True)
+ model = qt.QStandardItemModel(self)
+ self.setModel(model)
+ self._initModel()
+ self.setEditTriggers(qt.QAbstractItemView.NoEditTriggers)
+
+ def iconProvider(self):
+ return self.__iconProvider
+
+ def _initModel(self):
+ urls = self._getDefaultUrls()
+ self.setUrls(urls)
+
+ def _getDefaultUrls(self):
+ """Returns the default shortcuts.
+
+ It uses the default QFileDialog shortcuts if it is possible, else
+ provides a link to the computer's root and the user's home.
+
+ :rtype: List[str]
+ """
+ urls = []
+ version = LooseVersion(qt.qVersion())
+ feed_sidebar = True
+
+ if not DEFAULT_SIDEBAR_URL:
+ _logger.debug("Skip default sidebar URLs (from setted variable)")
+ feed_sidebar = False
+ elif version < LooseVersion("5.11.2") and qt.BINDING == "PyQt5" and sys.platform in ["linux", "linux2"]:
+ # Avoid segfault on PyQt5 + gtk
+ _logger.debug("Skip default sidebar URLs (avoid PyQt5 segfault)")
+ feed_sidebar = False
+
+ if feed_sidebar:
+ # Get default shortcut
+ # There is no other way
+ d = qt.QFileDialog(self)
+ # Needed to be able to reach the sidebar urls
+ d.setOption(qt.QFileDialog.DontUseNativeDialog, True)
+ urls = d.sidebarUrls()
+ d.deleteLater()
+ d = None
+
+ if len(urls) == 0:
+ urls.append(qt.QUrl("file://"))
+ urls.append(qt.QUrl.fromLocalFile(qt.QDir.homePath()))
+
+ return urls
+
+ def setSelectedPath(self, path):
+ selected = None
+ model = self.model()
+ for i in range(model.rowCount()):
+ index = model.index(i, 0)
+ url = model.data(index, qt.Qt.UserRole)
+ urlPath = url.toLocalFile()
+ if path == urlPath:
+ selected = index
+
+ selectionModel = self.selectionModel()
+ if selected is not None:
+ selectionModel.setCurrentIndex(selected, qt.QItemSelectionModel.ClearAndSelect)
+ else:
+ selectionModel.clear()
+
+ def setUrls(self, urls):
+ model = self.model()
+ model.clear()
+
+ names = {}
+ names[qt.QDir.rootPath()] = "Computer"
+ names[qt.QDir.homePath()] = "Home"
+
+ style = qt.QApplication.style()
+ iconProvider = self.iconProvider()
+ for url in urls:
+ path = url.toLocalFile()
+ if path == "":
+ if sys.platform != "win32":
+ url = qt.QUrl(qt.QDir.rootPath())
+ name = "Computer"
+ icon = style.standardIcon(qt.QStyle.SP_ComputerIcon)
+ else:
+ fileInfo = qt.QFileInfo(path)
+ name = names.get(path, fileInfo.fileName())
+ icon = iconProvider.icon(fileInfo)
+
+ if icon.isNull():
+ icon = style.standardIcon(qt.QStyle.SP_MessageBoxCritical)
+
+ item = qt.QStandardItem()
+ item.setText(name)
+ item.setIcon(icon)
+ item.setData(url, role=qt.Qt.UserRole)
+ model.appendRow(item)
+
+ def urls(self):
+ result = []
+ model = self.model()
+ for i in range(model.rowCount()):
+ index = model.index(i, 0)
+ url = model.data(index, qt.Qt.UserRole)
+ result.append(url)
+ return result
+
+ def sizeHint(self):
+ index = self.model().index(0, 0)
+ return self.sizeHintForIndex(index) + qt.QSize(2 * self.frameWidth(), 2 * self.frameWidth())
+
+
+class _Browser(qt.QStackedWidget):
+
+ activated = qt.Signal(qt.QModelIndex)
+ selected = qt.Signal(qt.QModelIndex)
+ rootIndexChanged = qt.Signal(qt.QModelIndex)
+
+ def __init__(self, parent, listView, detailView):
+ qt.QStackedWidget.__init__(self, parent)
+ self.__listView = listView
+ self.__detailView = detailView
+ self.insertWidget(0, self.__listView)
+ self.insertWidget(1, self.__detailView)
+
+ self.__listView.activated.connect(self.__emitActivated)
+ self.__detailView.activated.connect(self.__emitActivated)
+
+ def __emitActivated(self, index):
+ self.activated.emit(index)
+
+ def __emitSelected(self, selected, deselected):
+ index = self.selectedIndex()
+ if index is not None:
+ self.selected.emit(index)
+
+ def selectedIndex(self):
+ if self.currentIndex() == 0:
+ selectionModel = self.__listView.selectionModel()
+ else:
+ selectionModel = self.__detailView.selectionModel()
+
+ if selectionModel is None:
+ return None
+
+ indexes = selectionModel.selectedIndexes()
+ # Filter non-main columns
+ indexes = [i for i in indexes if i.column() == 0]
+ if len(indexes) == 1:
+ index = indexes[0]
+ return index
+ return None
+
+ def model(self):
+ """Returns the current model."""
+ if self.currentIndex() == 0:
+ return self.__listView.model()
+ else:
+ return self.__detailView.model()
+
+ def selectIndex(self, index):
+ if self.currentIndex() == 0:
+ selectionModel = self.__listView.selectionModel()
+ else:
+ selectionModel = self.__detailView.selectionModel()
+ if selectionModel is None:
+ return
+ selectionModel.setCurrentIndex(index, qt.QItemSelectionModel.ClearAndSelect)
+
+ def viewMode(self):
+ """Returns the current view mode.
+
+ :rtype: qt.QFileDialog.ViewMode
+ """
+ if self.currentIndex() == 0:
+ return qt.QFileDialog.List
+ elif self.currentIndex() == 1:
+ return qt.QFileDialog.Detail
+ else:
+ assert(False)
+
+ def setViewMode(self, mode):
+ """Set the current view mode.
+
+ :param qt.QFileDialog.ViewMode mode: The new view mode
+ """
+ if mode == qt.QFileDialog.Detail:
+ self.showDetails()
+ elif mode == qt.QFileDialog.List:
+ self.showList()
+ else:
+ assert(False)
+
+ def showList(self):
+ self.__listView.show()
+ self.__detailView.hide()
+ self.setCurrentIndex(0)
+
+ def showDetails(self):
+ self.__listView.hide()
+ self.__detailView.show()
+ self.setCurrentIndex(1)
+ self.__detailView.updateGeometry()
+
+ def clear(self):
+ self.__listView.setRootIndex(qt.QModelIndex())
+ self.__detailView.setRootIndex(qt.QModelIndex())
+ selectionModel = self.__listView.selectionModel()
+ if selectionModel is not None:
+ selectionModel.selectionChanged.disconnect()
+ selectionModel.clear()
+ selectionModel = self.__detailView.selectionModel()
+ if selectionModel is not None:
+ selectionModel.selectionChanged.disconnect()
+ selectionModel.clear()
+ self.__listView.setModel(None)
+ self.__detailView.setModel(None)
+
+ def setRootIndex(self, index, model=None):
+ """Sets the root item to the item at the given index.
+ """
+ rootIndex = self.__listView.rootIndex()
+ newModel = model or index.model()
+ assert(newModel is not None)
+
+ if rootIndex is None or rootIndex.model() is not newModel:
+ # update the model
+ selectionModel = self.__listView.selectionModel()
+ if selectionModel is not None:
+ selectionModel.selectionChanged.disconnect()
+ selectionModel.clear()
+ selectionModel = self.__detailView.selectionModel()
+ if selectionModel is not None:
+ selectionModel.selectionChanged.disconnect()
+ selectionModel.clear()
+ pIndex = qt.QPersistentModelIndex(index)
+ self.__listView.setModel(newModel)
+ # changing the model of the tree view change the index mapping
+ # that is why we are using a persistance model index
+ self.__detailView.setModel(newModel)
+ index = newModel.index(pIndex.row(), pIndex.column(), pIndex.parent())
+ selectionModel = self.__listView.selectionModel()
+ selectionModel.selectionChanged.connect(self.__emitSelected)
+ selectionModel = self.__detailView.selectionModel()
+ selectionModel.selectionChanged.connect(self.__emitSelected)
+
+ self.__listView.setRootIndex(index)
+ self.__detailView.setRootIndex(index)
+ self.rootIndexChanged.emit(index)
+
+ def rootIndex(self):
+ """Returns the model index of the model's root item. The root item is
+ the parent item to the view's toplevel items. The root can be invalid.
+ """
+ return self.__listView.rootIndex()
+
+ __serialVersion = 1
+ """Store the current version of the serialized data"""
+
+ def visualRect(self, index):
+ """Returns the rectangle on the viewport occupied by the item at index.
+
+ :param qt.QModelIndex index: An index
+ :rtype: QRect
+ """
+ if self.currentIndex() == 0:
+ return self.__listView.visualRect(index)
+ else:
+ return self.__detailView.visualRect(index)
+
+ def viewport(self):
+ """Returns the viewport widget.
+
+ :param qt.QModelIndex index: An index
+ :rtype: QRect
+ """
+ if self.currentIndex() == 0:
+ return self.__listView.viewport()
+ else:
+ return self.__detailView.viewport()
+
+ def restoreState(self, state):
+ """Restores the dialogs's layout, history and current directory to the
+ state specified.
+
+ :param qt.QByeArray state: Stream containing the new state
+ :rtype: bool
+ """
+ stream = qt.QDataStream(state, qt.QIODevice.ReadOnly)
+
+ nameId = stream.readQString()
+ if nameId != "Browser":
+ _logger.warning("Stored state contains an invalid name id. Browser restoration cancelled.")
+ return False
+
+ version = stream.readInt32()
+ if version != self.__serialVersion:
+ _logger.warning("Stored state contains an invalid version. Browser restoration cancelled.")
+ return False
+
+ headerData = stream.readQVariant()
+ self.__detailView.header().restoreState(headerData)
+
+ viewMode = stream.readInt32()
+ self.setViewMode(viewMode)
+ return True
+
+ def saveState(self):
+ """Saves the state of the dialog's layout.
+
+ :rtype: qt.QByteArray
+ """
+ data = qt.QByteArray()
+ stream = qt.QDataStream(data, qt.QIODevice.WriteOnly)
+
+ nameId = u"Browser"
+ stream.writeQString(nameId)
+ stream.writeInt32(self.__serialVersion)
+ stream.writeQVariant(self.__detailView.header().saveState())
+ stream.writeInt32(self.viewMode())
+
+ return data
+
+
+class _FabioData(object):
+
+ def __init__(self, fabioFile):
+ self.__fabioFile = fabioFile
+
+ @property
+ def dtype(self):
+ # Let say it is a valid type
+ return numpy.dtype("float")
+
+ @property
+ def shape(self):
+ if self.__fabioFile.nframes == 0:
+ return None
+ if self.__fabioFile.nframes == 1:
+ return [slice(None), slice(None)]
+ return [self.__fabioFile.nframes, slice(None), slice(None)]
+
+ def __getitem__(self, selector):
+ if self.__fabioFile.nframes == 1 and selector == tuple():
+ return self.__fabioFile.data
+ if isinstance(selector, tuple) and len(selector) == 1:
+ selector = selector[0]
+
+ if isinstance(selector, int):
+ if 0 <= selector < self.__fabioFile.nframes:
+ if self.__fabioFile.nframes == 1:
+ return self.__fabioFile.data
+ else:
+ frame = self.__fabioFile.getframe(selector)
+ return frame.data
+ else:
+ raise ValueError("Invalid selector %s" % selector)
+ else:
+ raise TypeError("Unsupported selector type %s" % type(selector))
+
+
+class _PathEdit(qt.QLineEdit):
+ pass
+
+
+class _CatchResizeEvent(qt.QObject):
+
+ resized = qt.Signal(qt.QResizeEvent)
+
+ def __init__(self, parent, target):
+ super(_CatchResizeEvent, self).__init__(parent)
+ self.__target = target
+ self.__target_oldResizeEvent = self.__target.resizeEvent
+ self.__target.resizeEvent = self.__resizeEvent
+
+ def __resizeEvent(self, event):
+ result = self.__target_oldResizeEvent(event)
+ self.resized.emit(event)
+ return result
+
+
+class AbstractDataFileDialog(qt.QDialog):
+ """The `AbstractFileDialog` provides a generic GUI to create a custom dialog
+ allowing to access to file resources like HDF5 files or HDF5 datasets.
+
+ .. image:: img/abstractdatafiledialog.png
+
+ The dialog contains:
+
+ - Shortcuts: It provides few links to have a fast access of browsing
+ locations.
+ - Browser: It provides a display to browse throw the file system and inside
+ HDF5 files or fabio files. A file format selector is provided.
+ - URL: Display the URL available to reach the data using
+ :meth:`silx.io.get_data`, :meth:`silx.io.open`.
+ - Data selector: A widget to apply a sub selection of the browsed dataset.
+ This widget can be provided, else nothing will be used.
+ - Data preview: A widget to preview the selected data, which is the result
+ of the filter from the data selector.
+ This widget can be provided, else nothing will be used.
+ - Preview's toolbar: Provides tools used to custom data preview or data
+ selector.
+ This widget can be provided, else nothing will be used.
+ - Buttons to validate the dialog
+ """
+
+ _defaultIconProvider = None
+ """Lazy loaded default icon provider"""
+
+ def __init__(self, parent=None):
+ super(AbstractDataFileDialog, self).__init__(parent)
+ self._init()
+
+ def _init(self):
+ self.setWindowTitle("Open")
+
+ self.__openedFiles = []
+ """Store the list of files opened by the model itself."""
+ # FIXME: It should be managed one by one by Hdf5Item itself
+
+ self.__directory = None
+ self.__directoryLoadedFilter = None
+ self.__errorWhileLoadingFile = None
+ self.__selectedFile = None
+ self.__selectedData = None
+ self.__currentHistory = []
+ """Store history of URLs, last index one is the latest one"""
+ self.__currentHistoryLocation = -1
+ """Store the location in the history. Bigger is older"""
+
+ self.__processing = 0
+ """Number of asynchronous processing tasks"""
+ self.__h5 = None
+ self.__fabio = None
+
+ # On Qt5 a safe icon provider is still needed to avoid freeze
+ _logger.debug("Uses default QFileSystemModel with a SafeFileIconProvider")
+ self.__fileModel = qt.QFileSystemModel(self)
+ from .SafeFileIconProvider import SafeFileIconProvider
+ iconProvider = SafeFileIconProvider()
+ self.__fileModel.setIconProvider(iconProvider)
+
+ # The common file dialog filter only on Mac OS X
+ self.__fileModel.setNameFilterDisables(sys.platform == "darwin")
+ self.__fileModel.setReadOnly(True)
+ self.__fileModel.directoryLoaded.connect(self.__directoryLoaded)
+
+ self.__dataModel = Hdf5TreeModel(self)
+
+ self.__createWidgets()
+ self.__initLayout()
+ self.__showAsListView()
+
+ path = os.getcwd()
+ self.__fileModel_setRootPath(path)
+
+ self.__clearData()
+ self.__updatePath()
+
+ # Update the file model filter
+ self.__fileTypeCombo.setCurrentIndex(0)
+ self.__filterSelected(0)
+
+ # It is not possible to override the QObject destructor nor
+ # to access to the content of the Python object with the `destroyed`
+ # signal cause the Python method was already removed with the QWidget,
+ # while the QObject still exists.
+ # We use a static method plus explicit references to objects to
+ # release. The callback do not use any ref to self.
+ onDestroy = functools.partial(self._closeFileList, self.__openedFiles)
+ self.destroyed.connect(onDestroy)
+
+ @staticmethod
+ def _closeFileList(fileList):
+ """Static method to close explicit references to internal objects."""
+ _logger.debug("Clear AbstractDataFileDialog")
+ for obj in fileList:
+ _logger.debug("Close file %s", obj.filename)
+ obj.close()
+ fileList[:] = []
+
+ def done(self, result):
+ self._clear()
+ super(AbstractDataFileDialog, self).done(result)
+
+ def _clear(self):
+ """Explicit method to clear data stored in the dialog.
+ After this call it is not anymore possible to use the widget.
+
+ This method is triggered by the destruction of the object and the
+ QDialog :meth:`done`. Then it can be triggered more than once.
+ """
+ _logger.debug("Clear dialog")
+ self.__errorWhileLoadingFile = None
+ self.__clearData()
+ if self.__fileModel is not None:
+ # Cache the directory before cleaning the model
+ self.__directory = self.directory()
+ self.__browser.clear()
+ self.__closeFile()
+ self.__fileModel = None
+ self.__dataModel = None
+
+ def hasPendingEvents(self):
+ """Returns true if the dialog have asynchronous tasks working on the
+ background."""
+ return self.__processing > 0
+
+ # User interface
+
+ def __createWidgets(self):
+ self.__sidebar = self._createSideBar()
+ if self.__sidebar is not None:
+ sideBarModel = self.__sidebar.selectionModel()
+ sideBarModel.selectionChanged.connect(self.__shortcutSelected)
+ self.__sidebar.setSelectionMode(qt.QAbstractItemView.SingleSelection)
+
+ listView = qt.QListView(self)
+ listView.setSelectionBehavior(qt.QAbstractItemView.SelectRows)
+ listView.setSelectionMode(qt.QAbstractItemView.SingleSelection)
+ listView.setResizeMode(qt.QListView.Adjust)
+ listView.setWrapping(True)
+ listView.setEditTriggers(qt.QAbstractItemView.NoEditTriggers)
+ listView.setContextMenuPolicy(qt.Qt.CustomContextMenu)
+ utils.patchToConsumeReturnKey(listView)
+
+ treeView = qt.QTreeView(self)
+ treeView.setSelectionBehavior(qt.QAbstractItemView.SelectRows)
+ treeView.setSelectionMode(qt.QAbstractItemView.SingleSelection)
+ treeView.setRootIsDecorated(False)
+ treeView.setItemsExpandable(False)
+ treeView.setSortingEnabled(True)
+ treeView.header().setSortIndicator(0, qt.Qt.AscendingOrder)
+ treeView.header().setStretchLastSection(False)
+ treeView.setTextElideMode(qt.Qt.ElideMiddle)
+ treeView.setEditTriggers(qt.QAbstractItemView.NoEditTriggers)
+ treeView.setContextMenuPolicy(qt.Qt.CustomContextMenu)
+ treeView.setDragDropMode(qt.QAbstractItemView.InternalMove)
+ utils.patchToConsumeReturnKey(treeView)
+
+ self.__browser = _Browser(self, listView, treeView)
+ self.__browser.activated.connect(self.__browsedItemActivated)
+ self.__browser.selected.connect(self.__browsedItemSelected)
+ self.__browser.rootIndexChanged.connect(self.__rootIndexChanged)
+ self.__browser.setObjectName("browser")
+
+ self.__previewWidget = self._createPreviewWidget(self)
+
+ self.__fileTypeCombo = FileTypeComboBox(self)
+ self.__fileTypeCombo.setObjectName("fileTypeCombo")
+ self.__fileTypeCombo.setDuplicatesEnabled(False)
+ self.__fileTypeCombo.setSizeAdjustPolicy(qt.QComboBox.AdjustToMinimumContentsLengthWithIcon)
+ self.__fileTypeCombo.setSizePolicy(qt.QSizePolicy.Expanding, qt.QSizePolicy.Fixed)
+ self.__fileTypeCombo.activated[int].connect(self.__filterSelected)
+ self.__fileTypeCombo.setFabioUrlSupproted(self._isFabioFilesSupported())
+
+ self.__pathEdit = _PathEdit(self)
+ self.__pathEdit.setSizePolicy(qt.QSizePolicy.Expanding, qt.QSizePolicy.Fixed)
+ self.__pathEdit.textChanged.connect(self.__textChanged)
+ self.__pathEdit.setObjectName("url")
+ utils.patchToConsumeReturnKey(self.__pathEdit)
+
+ self.__buttons = qt.QDialogButtonBox(self)
+ self.__buttons.setSizePolicy(qt.QSizePolicy.Fixed, qt.QSizePolicy.Fixed)
+ types = qt.QDialogButtonBox.Open | qt.QDialogButtonBox.Cancel
+ self.__buttons.setStandardButtons(types)
+ self.__buttons.button(qt.QDialogButtonBox.Cancel).setObjectName("cancel")
+ self.__buttons.button(qt.QDialogButtonBox.Open).setObjectName("open")
+
+ self.__buttons.accepted.connect(self.accept)
+ self.__buttons.rejected.connect(self.reject)
+
+ self.__browseToolBar = self._createBrowseToolBar()
+ self.__backwardAction.setEnabled(False)
+ self.__forwardAction.setEnabled(False)
+ self.__fileDirectoryAction.setEnabled(False)
+ self.__parentFileDirectoryAction.setEnabled(False)
+
+ self.__selectorWidget = self._createSelectorWidget(self)
+ if self.__selectorWidget is not None:
+ self.__selectorWidget.selectionChanged.connect(self.__selectorWidgetChanged)
+
+ self.__previewToolBar = self._createPreviewToolbar(self, self.__previewWidget, self.__selectorWidget)
+
+ self.__dataIcon = qt.QLabel(self)
+ self.__dataIcon.setSizePolicy(qt.QSizePolicy.Fixed, qt.QSizePolicy.Fixed)
+ self.__dataIcon.setScaledContents(True)
+ self.__dataIcon.setMargin(2)
+ self.__dataIcon.setAlignment(qt.Qt.AlignCenter)
+
+ self.__dataInfo = qt.QLabel(self)
+ self.__dataInfo.setSizePolicy(qt.QSizePolicy.Expanding, qt.QSizePolicy.Fixed)
+
+ def _createSideBar(self):
+ sidebar = _SideBar(self)
+ sidebar.setObjectName("sidebar")
+ return sidebar
+
+ def iconProvider(self):
+ iconProvider = self.__class__._defaultIconProvider
+ if iconProvider is None:
+ iconProvider = _IconProvider()
+ self.__class__._defaultIconProvider = iconProvider
+ return iconProvider
+
+ def _createBrowseToolBar(self):
+ toolbar = qt.QToolBar(self)
+ toolbar.setIconSize(qt.QSize(16, 16))
+ iconProvider = self.iconProvider()
+
+ backward = qt.QAction(toolbar)
+ backward.setText("Back")
+ backward.setObjectName("backwardAction")
+ backward.setIcon(iconProvider.icon(qt.QStyle.SP_ArrowBack))
+ backward.triggered.connect(self.__navigateBackward)
+ self.__backwardAction = backward
+
+ forward = qt.QAction(toolbar)
+ forward.setText("Forward")
+ forward.setObjectName("forwardAction")
+ forward.setIcon(iconProvider.icon(qt.QStyle.SP_ArrowForward))
+ forward.triggered.connect(self.__navigateForward)
+ self.__forwardAction = forward
+
+ parentDirectory = qt.QAction(toolbar)
+ parentDirectory.setText("Go to parent")
+ parentDirectory.setObjectName("toParentAction")
+ parentDirectory.setIcon(iconProvider.icon(qt.QStyle.SP_FileDialogToParent))
+ parentDirectory.triggered.connect(self.__navigateToParent)
+ self.__toParentAction = parentDirectory
+
+ fileDirectory = qt.QAction(toolbar)
+ fileDirectory.setText("Root of the file")
+ fileDirectory.setObjectName("toRootFileAction")
+ fileDirectory.setIcon(iconProvider.icon(iconProvider.FileDialogToParentFile))
+ fileDirectory.triggered.connect(self.__navigateToParentFile)
+ self.__fileDirectoryAction = fileDirectory
+
+ parentFileDirectory = qt.QAction(toolbar)
+ parentFileDirectory.setText("Parent directory of the file")
+ parentFileDirectory.setObjectName("toDirectoryAction")
+ parentFileDirectory.setIcon(iconProvider.icon(iconProvider.FileDialogToParentDir))
+ parentFileDirectory.triggered.connect(self.__navigateToParentDir)
+ self.__parentFileDirectoryAction = parentFileDirectory
+
+ listView = qt.QAction(toolbar)
+ listView.setText("List view")
+ listView.setObjectName("listModeAction")
+ listView.setIcon(iconProvider.icon(qt.QStyle.SP_FileDialogListView))
+ listView.triggered.connect(self.__showAsListView)
+ listView.setCheckable(True)
+
+ detailView = qt.QAction(toolbar)
+ detailView.setText("Detail view")
+ detailView.setObjectName("detailModeAction")
+ detailView.setIcon(iconProvider.icon(qt.QStyle.SP_FileDialogDetailedView))
+ detailView.triggered.connect(self.__showAsDetailedView)
+ detailView.setCheckable(True)
+
+ self.__listViewAction = listView
+ self.__detailViewAction = detailView
+
+ toolbar.addAction(backward)
+ toolbar.addAction(forward)
+ toolbar.addSeparator()
+ toolbar.addAction(parentDirectory)
+ toolbar.addAction(fileDirectory)
+ toolbar.addAction(parentFileDirectory)
+ toolbar.addSeparator()
+ toolbar.addAction(listView)
+ toolbar.addAction(detailView)
+
+ toolbar.setStyleSheet("QToolBar { border: 0px }")
+
+ return toolbar
+
+ def __initLayout(self):
+ sideBarLayout = qt.QVBoxLayout()
+ sideBarLayout.setContentsMargins(0, 0, 0, 0)
+ dummyToolBar = qt.QWidget(self)
+ dummyToolBar.setSizePolicy(qt.QSizePolicy.Expanding, qt.QSizePolicy.Fixed)
+ dummyCombo = qt.QWidget(self)
+ dummyCombo.setSizePolicy(qt.QSizePolicy.Expanding, qt.QSizePolicy.Fixed)
+ sideBarLayout.addWidget(dummyToolBar)
+ if self.__sidebar is not None:
+ sideBarLayout.addWidget(self.__sidebar)
+ sideBarLayout.addWidget(dummyCombo)
+ sideBarWidget = qt.QWidget(self)
+ sideBarWidget.setLayout(sideBarLayout)
+
+ dummyCombo.setFixedHeight(self.__fileTypeCombo.height())
+ self.__resizeCombo = _CatchResizeEvent(self, self.__fileTypeCombo)
+ self.__resizeCombo.resized.connect(lambda e: dummyCombo.setFixedHeight(e.size().height()))
+
+ dummyToolBar.setFixedHeight(self.__browseToolBar.height())
+ self.__resizeToolbar = _CatchResizeEvent(self, self.__browseToolBar)
+ self.__resizeToolbar.resized.connect(lambda e: dummyToolBar.setFixedHeight(e.size().height()))
+
+ datasetSelection = qt.QWidget(self)
+ layoutLeft = qt.QVBoxLayout()
+ layoutLeft.setContentsMargins(0, 0, 0, 0)
+ layoutLeft.addWidget(self.__browseToolBar)
+ layoutLeft.addWidget(self.__browser)
+ layoutLeft.addWidget(self.__fileTypeCombo)
+ datasetSelection.setLayout(layoutLeft)
+ datasetSelection.setSizePolicy(qt.QSizePolicy.MinimumExpanding, qt.QSizePolicy.Expanding)
+
+ infoLayout = qt.QHBoxLayout()
+ infoLayout.setContentsMargins(0, 0, 0, 0)
+ infoLayout.addWidget(self.__dataIcon)
+ infoLayout.addWidget(self.__dataInfo)
+
+ dataFrame = qt.QFrame(self)
+ dataFrame.setFrameShape(qt.QFrame.StyledPanel)
+ layout = qt.QVBoxLayout()
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.setSpacing(0)
+ layout.addWidget(self.__previewWidget)
+ layout.addLayout(infoLayout)
+ dataFrame.setLayout(layout)
+
+ dataSelection = qt.QWidget(self)
+ dataLayout = qt.QVBoxLayout()
+ dataLayout.setContentsMargins(0, 0, 0, 0)
+ if self.__previewToolBar is not None:
+ dataLayout.addWidget(self.__previewToolBar)
+ else:
+ # Add dummy space
+ dummyToolbar2 = qt.QWidget(self)
+ dummyToolbar2.setSizePolicy(qt.QSizePolicy.Expanding, qt.QSizePolicy.Fixed)
+ dummyToolbar2.setFixedHeight(self.__browseToolBar.height())
+ self.__resizeToolbar = _CatchResizeEvent(self, self.__browseToolBar)
+ self.__resizeToolbar.resized.connect(lambda e: dummyToolbar2.setFixedHeight(e.size().height()))
+ dataLayout.addWidget(dummyToolbar2)
+
+ dataLayout.addWidget(dataFrame)
+ if self.__selectorWidget is not None:
+ dataLayout.addWidget(self.__selectorWidget)
+ else:
+ # Add dummy space
+ dummyCombo2 = qt.QWidget(self)
+ dummyCombo2.setSizePolicy(qt.QSizePolicy.Expanding, qt.QSizePolicy.Fixed)
+ dummyCombo2.setFixedHeight(self.__fileTypeCombo.height())
+ self.__resizeToolbar = _CatchResizeEvent(self, self.__fileTypeCombo)
+ self.__resizeToolbar.resized.connect(lambda e: dummyCombo2.setFixedHeight(e.size().height()))
+ dataLayout.addWidget(dummyCombo2)
+ dataSelection.setLayout(dataLayout)
+
+ self.__splitter = qt.QSplitter(self)
+ self.__splitter.setContentsMargins(0, 0, 0, 0)
+ self.__splitter.addWidget(sideBarWidget)
+ self.__splitter.addWidget(datasetSelection)
+ self.__splitter.addWidget(dataSelection)
+ self.__splitter.setStretchFactor(1, 10)
+
+ bottomLayout = qt.QHBoxLayout()
+ bottomLayout.setContentsMargins(0, 0, 0, 0)
+ bottomLayout.addWidget(self.__pathEdit)
+ bottomLayout.addWidget(self.__buttons)
+
+ layout = qt.QVBoxLayout(self)
+ layout.addWidget(self.__splitter)
+ layout.addLayout(bottomLayout)
+
+ self.setLayout(layout)
+ self.updateGeometry()
+
+ # Logic
+
+ def __navigateBackward(self):
+ """Navigate through the history one step backward."""
+ if len(self.__currentHistory) > 0 and self.__currentHistoryLocation > 0:
+ self.__currentHistoryLocation -= 1
+ url = self.__currentHistory[self.__currentHistoryLocation]
+ self.selectUrl(url)
+
+ def __navigateForward(self):
+ """Navigate through the history one step forward."""
+ if len(self.__currentHistory) > 0 and self.__currentHistoryLocation < len(self.__currentHistory) - 1:
+ self.__currentHistoryLocation += 1
+ url = self.__currentHistory[self.__currentHistoryLocation]
+ self.selectUrl(url)
+
+ def __navigateToParent(self):
+ index = self.__browser.rootIndex()
+ if index.model() is self.__fileModel:
+ # browse throw the file system
+ index = index.parent()
+ path = self.__fileModel.filePath(index)
+ self.__fileModel_setRootPath(path)
+ self.__browser.selectIndex(qt.QModelIndex())
+ self.__updatePath()
+ elif index.model() is self.__dataModel:
+ index = index.parent()
+ if index.isValid():
+ # browse throw the hdf5
+ self.__browser.setRootIndex(index)
+ self.__browser.selectIndex(qt.QModelIndex())
+ self.__updatePath()
+ else:
+ # go back to the file system
+ self.__navigateToParentDir()
+ else:
+ # Root of the file system (my computer)
+ pass
+
+ def __navigateToParentFile(self):
+ index = self.__browser.rootIndex()
+ if index.model() is self.__dataModel:
+ index = self.__dataModel.indexFromH5Object(self.__h5)
+ self.__browser.setRootIndex(index)
+ self.__browser.selectIndex(qt.QModelIndex())
+ self.__updatePath()
+
+ def __navigateToParentDir(self):
+ index = self.__browser.rootIndex()
+ if index.model() is self.__dataModel:
+ path = os.path.dirname(self.__h5.file.filename)
+ index = self.__fileModel.index(path)
+ self.__browser.setRootIndex(index)
+ self.__browser.selectIndex(qt.QModelIndex())
+ self.__closeFile()
+ self.__updatePath()
+
+ def viewMode(self):
+ """Returns the current view mode.
+
+ :rtype: qt.QFileDialog.ViewMode
+ """
+ return self.__browser.viewMode()
+
+ def setViewMode(self, mode):
+ """Set the current view mode.
+
+ :param qt.QFileDialog.ViewMode mode: The new view mode
+ """
+ if mode == qt.QFileDialog.Detail:
+ self.__browser.showDetails()
+ self.__listViewAction.setChecked(False)
+ self.__detailViewAction.setChecked(True)
+ elif mode == qt.QFileDialog.List:
+ self.__browser.showList()
+ self.__listViewAction.setChecked(True)
+ self.__detailViewAction.setChecked(False)
+ else:
+ assert(False)
+
+ def __showAsListView(self):
+ self.setViewMode(qt.QFileDialog.List)
+
+ def __showAsDetailedView(self):
+ self.setViewMode(qt.QFileDialog.Detail)
+
+ def __shortcutSelected(self):
+ self.__browser.selectIndex(qt.QModelIndex())
+ self.__clearData()
+ self.__updatePath()
+ selectionModel = self.__sidebar.selectionModel()
+ indexes = selectionModel.selectedIndexes()
+ if len(indexes) == 1:
+ index = indexes[0]
+ url = self.__sidebar.model().data(index, role=qt.Qt.UserRole)
+ path = url.toLocalFile()
+ self.__fileModel_setRootPath(path)
+
+ def __browsedItemActivated(self, index):
+ if not index.isValid():
+ return
+ if index.model() is self.__fileModel:
+ path = self.__fileModel.filePath(index)
+ if self.__fileModel.isDir(index):
+ self.__fileModel_setRootPath(path)
+ if os.path.isfile(path):
+ self.__fileActivated(index)
+ elif index.model() is self.__dataModel:
+ obj = self.__dataModel.data(index, role=Hdf5TreeModel.H5PY_OBJECT_ROLE)
+ if silx.io.is_group(obj):
+ self.__browser.setRootIndex(index)
+ else:
+ assert(False)
+
+ def __browsedItemSelected(self, index):
+ self.__dataSelected(index)
+ self.__updatePath()
+
+ def __fileModel_setRootPath(self, path):
+ """Set the root path of the fileModel with a filter on the
+ directoryLoaded event.
+
+ Without this filter an extra event is received (at least with PyQt4)
+ when we use for the first time the sidebar.
+
+ :param str path: Path to load
+ """
+ assert(path is not None)
+ if path != "" and not os.path.exists(path):
+ return
+ if self.hasPendingEvents():
+ # Make sure the asynchronous fileModel setRootPath is finished
+ qt.QApplication.instance().processEvents()
+
+ if self.__directoryLoadedFilter is not None:
+ if utils.samefile(self.__directoryLoadedFilter, path):
+ return
+ self.__directoryLoadedFilter = path
+ self.__processing += 1
+ if self.__fileModel is None:
+ return
+ index = self.__fileModel.setRootPath(path)
+ if not index.isValid():
+ # There is a problem with this path
+ # No asynchronous process will be waked up
+ self.__processing -= 1
+ self.__browser.setRootIndex(index, model=self.__fileModel)
+ self.__clearData()
+ self.__updatePath()
+
+ def __directoryLoaded(self, path):
+ if self.__directoryLoadedFilter is not None:
+ if not utils.samefile(self.__directoryLoadedFilter, path):
+ # Filter event which should not arrive in PyQt4
+ # The first click on the sidebar sent 2 events
+ self.__processing -= 1
+ return
+ if self.__fileModel is None:
+ return
+ index = self.__fileModel.index(path)
+ self.__browser.setRootIndex(index, model=self.__fileModel)
+ self.__updatePath()
+ self.__processing -= 1
+
+ def __closeFile(self):
+ self.__openedFiles[:] = []
+ self.__fileDirectoryAction.setEnabled(False)
+ self.__parentFileDirectoryAction.setEnabled(False)
+ if self.__h5 is not None:
+ self.__dataModel.removeH5pyObject(self.__h5)
+ self.__h5.close()
+ self.__h5 = None
+ if self.__fabio is not None:
+ if hasattr(self.__fabio, "close"):
+ self.__fabio.close()
+ self.__fabio = None
+
+ def __openFabioFile(self, filename):
+ self.__closeFile()
+ try:
+ self.__fabio = fabio.open(filename)
+ self.__openedFiles.append(self.__fabio)
+ self.__selectedFile = filename
+ except Exception as e:
+ _logger.error("Error while loading file %s: %s", filename, e.args[0])
+ _logger.debug("Backtrace", exc_info=True)
+ self.__errorWhileLoadingFile = filename, e.args[0]
+ return False
+ else:
+ return True
+
+ def __openSilxFile(self, filename):
+ self.__closeFile()
+ try:
+ self.__h5 = silx.io.open(filename)
+ self.__openedFiles.append(self.__h5)
+ self.__selectedFile = filename
+ except IOError as e:
+ _logger.error("Error while loading file %s: %s", filename, e.args[0])
+ _logger.debug("Backtrace", exc_info=True)
+ self.__errorWhileLoadingFile = filename, e.args[0]
+ return False
+ else:
+ self.__fileDirectoryAction.setEnabled(True)
+ self.__parentFileDirectoryAction.setEnabled(True)
+ self.__dataModel.insertH5pyObject(self.__h5)
+ return True
+
+ def __isSilxHavePriority(self, filename):
+ """Silx have priority when there is a specific decoder
+ """
+ _, ext = os.path.splitext(filename)
+ ext = "*%s" % ext
+ formats = silx.io.supported_extensions(flat_formats=False)
+ for extensions in formats.values():
+ if ext in extensions:
+ return True
+ return False
+
+ def __openFile(self, filename):
+ codec = self.__fileTypeCombo.currentCodec()
+ openners = []
+ if codec.is_autodetect():
+ if self.__isSilxHavePriority(filename):
+ openners.append(self.__openSilxFile)
+ if self._isFabioFilesSupported():
+ openners.append(self.__openFabioFile)
+ else:
+ if self._isFabioFilesSupported():
+ openners.append(self.__openFabioFile)
+ openners.append(self.__openSilxFile)
+ elif codec.is_silx_codec():
+ openners.append(self.__openSilxFile)
+ elif self._isFabioFilesSupported() and codec.is_fabio_codec():
+ # It is requested to use fabio, anyway fabio is here or not
+ openners.append(self.__openFabioFile)
+
+ for openner in openners:
+ ref = openner(filename)
+ if ref is not None:
+ return True
+ return False
+
+ def __fileActivated(self, index):
+ self.__selectedFile = None
+ path = self.__fileModel.filePath(index)
+ if os.path.isfile(path):
+ loaded = self.__openFile(path)
+ if loaded:
+ if self.__h5 is not None:
+ index = self.__dataModel.indexFromH5Object(self.__h5)
+ self.__browser.setRootIndex(index)
+ elif self.__fabio is not None:
+ data = _FabioData(self.__fabio)
+ self.__setData(data)
+ self.__updatePath()
+ else:
+ self.__clearData()
+
+ def __dataSelected(self, index):
+ selectedData = None
+ if index is not None:
+ if index.model() is self.__dataModel:
+ obj = self.__dataModel.data(index, self.__dataModel.H5PY_OBJECT_ROLE)
+ if self._isDataSupportable(obj):
+ selectedData = obj
+ elif index.model() is self.__fileModel:
+ self.__closeFile()
+ if self._isFabioFilesSupported():
+ path = self.__fileModel.filePath(index)
+ if os.path.isfile(path):
+ codec = self.__fileTypeCombo.currentCodec()
+ is_fabio_decoder = codec.is_fabio_codec()
+ is_fabio_have_priority = not codec.is_silx_codec() and not self.__isSilxHavePriority(path)
+ if is_fabio_decoder or is_fabio_have_priority:
+ # Then it's flat frame container
+ self.__openFabioFile(path)
+ if self.__fabio is not None:
+ selectedData = _FabioData(self.__fabio)
+ else:
+ assert(False)
+
+ self.__setData(selectedData)
+
+ def __filterSelected(self, index):
+ filters = self.__fileTypeCombo.itemExtensions(index)
+ self.__fileModel.setNameFilters(list(filters))
+
+ def __setData(self, data):
+ self.__data = data
+
+ if data is not None and self._isDataSupportable(data):
+ if self.__selectorWidget is not None:
+ self.__selectorWidget.setData(data)
+ if not self.__selectorWidget.isUsed():
+ # Needed to fake the fact we have to reset the zoom in preview
+ self.__selectedData = None
+ self.__setSelectedData(data)
+ self.__selectorWidget.hide()
+ else:
+ self.__selectorWidget.setVisible(self.__selectorWidget.hasVisibleSelectors())
+ # Needed to fake the fact we have to reset the zoom in preview
+ self.__selectedData = None
+ self.__selectorWidget.selectionChanged.emit()
+ else:
+ # Needed to fake the fact we have to reset the zoom in preview
+ self.__selectedData = None
+ self.__setSelectedData(data)
+ else:
+ self.__clearData()
+ self.__updatePath()
+
+ def _isDataSupported(self, data):
+ """Check if the data can be returned by the dialog.
+
+ If true, this data can be returned by the dialog and the open button
+ while be enabled. If false the button will be disabled.
+
+ :rtype: bool
+ """
+ raise NotImplementedError()
+
+ def _isDataSupportable(self, data):
+ """Check if the selected data can be supported at one point.
+
+ If true, the data selector will be checked and it will update the data
+ preview. Else the selecting is disabled.
+
+ :rtype: bool
+ """
+ raise NotImplementedError()
+
+ def __clearData(self):
+ """Clear the data part of the GUI"""
+ if self.__previewWidget is not None:
+ self.__previewWidget.setData(None)
+ if self.__selectorWidget is not None:
+ self.__selectorWidget.setData(None)
+ self.__selectorWidget.hide()
+ self.__selectedData = None
+ self.__data = None
+ self.__updateDataInfo()
+ button = self.__buttons.button(qt.QDialogButtonBox.Open)
+ button.setEnabled(False)
+
+ def __selectorWidgetChanged(self):
+ data = self.__selectorWidget.getSelectedData(self.__data)
+ self.__setSelectedData(data)
+
+ def __setSelectedData(self, data):
+ """Set the data selected by the dialog.
+
+ If :meth:`_isDataSupported` returns false, this function will be
+ inhibited and no data will be selected.
+ """
+ if isinstance(data, _FabioData):
+ data = data[()]
+ if self.__previewWidget is not None:
+ fromDataSelector = self.__selectedData is not None
+ self.__previewWidget.setData(data, fromDataSelector=fromDataSelector)
+ if self._isDataSupported(data):
+ self.__selectedData = data
+ else:
+ self.__clearData()
+ return
+ self.__updateDataInfo()
+ self.__updatePath()
+ button = self.__buttons.button(qt.QDialogButtonBox.Open)
+ button.setEnabled(True)
+
+ def __updateDataInfo(self):
+ if self.__errorWhileLoadingFile is not None:
+ filename, message = self.__errorWhileLoadingFile
+ message = "<b>Error while loading file '%s'</b><hr/>%s" % (filename, message)
+ size = self.__dataInfo.height()
+ icon = self.style().standardIcon(qt.QStyle.SP_MessageBoxCritical)
+ pixmap = icon.pixmap(size, size)
+
+ self.__dataInfo.setText("Error while loading file")
+ self.__dataInfo.setToolTip(message)
+ self.__dataIcon.setToolTip(message)
+ self.__dataIcon.setVisible(True)
+ self.__dataIcon.setPixmap(pixmap)
+
+ self.__errorWhileLoadingFile = None
+ return
+
+ self.__dataIcon.setVisible(False)
+ self.__dataInfo.setToolTip("")
+ if self.__selectedData is None:
+ self.__dataInfo.setText("No data selected")
+ else:
+ text = self._displayedDataInfo(self.__data, self.__selectedData)
+ self.__dataInfo.setVisible(text is not None)
+ if text is not None:
+ self.__dataInfo.setText(text)
+
+ def _displayedDataInfo(self, dataBeforeSelection, dataAfterSelection):
+ """Returns the text displayed under the data preview.
+
+ This zone is used to display error in case or problem of data selection
+ or problems with IO.
+
+ :param numpy.ndarray dataAfterSelection: Data as it is after the
+ selection widget (basically the data from the preview widget)
+ :param numpy.ndarray dataAfterSelection: Data as it is before the
+ selection widget (basically the data from the browsing widget)
+ :rtype: bool
+ """
+ return None
+
+ def __createUrlFromIndex(self, index, useSelectorWidget=True):
+ if index.model() is self.__fileModel:
+ filename = self.__fileModel.filePath(index)
+ dataPath = None
+ elif index.model() is self.__dataModel:
+ obj = self.__dataModel.data(index, role=Hdf5TreeModel.H5PY_OBJECT_ROLE)
+ filename = obj.file.filename
+ dataPath = obj.name
+ else:
+ # root of the computer
+ filename = ""
+ dataPath = None
+
+ if useSelectorWidget and self.__selectorWidget is not None and self.__selectorWidget.isUsed():
+ slicing = self.__selectorWidget.slicing()
+ if slicing == tuple():
+ slicing = None
+ else:
+ slicing = None
+
+ if self.__fabio is not None:
+ scheme = "fabio"
+ elif self.__h5 is not None:
+ scheme = "silx"
+ else:
+ if os.path.isfile(filename):
+ codec = self.__fileTypeCombo.currentCodec()
+ if codec.is_fabio_codec():
+ scheme = "fabio"
+ elif codec.is_silx_codec():
+ scheme = "silx"
+ else:
+ scheme = None
+ else:
+ scheme = None
+
+ url = silx.io.url.DataUrl(file_path=filename, data_path=dataPath, data_slice=slicing, scheme=scheme)
+ return url
+
+ def __updatePath(self):
+ index = self.__browser.selectedIndex()
+ if index is None:
+ index = self.__browser.rootIndex()
+ url = self.__createUrlFromIndex(index)
+ if url.path() != self.__pathEdit.text():
+ old = self.__pathEdit.blockSignals(True)
+ self.__pathEdit.setText(url.path())
+ self.__pathEdit.blockSignals(old)
+
+ def __rootIndexChanged(self, index):
+ url = self.__createUrlFromIndex(index, useSelectorWidget=False)
+
+ currentUrl = None
+ if 0 <= self.__currentHistoryLocation < len(self.__currentHistory):
+ currentUrl = self.__currentHistory[self.__currentHistoryLocation]
+
+ if currentUrl is None or currentUrl != url.path():
+ # clean up the forward history
+ self.__currentHistory = self.__currentHistory[0:self.__currentHistoryLocation + 1]
+ self.__currentHistory.append(url.path())
+ self.__currentHistoryLocation += 1
+
+ if index.model() != self.__dataModel:
+ if sys.platform == "win32":
+ # path == ""
+ isRoot = not index.isValid()
+ else:
+ # path in ["", "/"]
+ isRoot = not index.isValid() or not index.parent().isValid()
+ else:
+ isRoot = False
+
+ if index.isValid():
+ self.__dataSelected(index)
+ self.__toParentAction.setEnabled(not isRoot)
+ self.__updateActionHistory()
+ self.__updateSidebar()
+
+ def __updateSidebar(self):
+ """Called when the current directory location change"""
+ if self.__sidebar is None:
+ return
+ selectionModel = self.__sidebar.selectionModel()
+ selectionModel.selectionChanged.disconnect(self.__shortcutSelected)
+ index = self.__browser.rootIndex()
+ if index.model() == self.__fileModel:
+ path = self.__fileModel.filePath(index)
+ self.__sidebar.setSelectedPath(path)
+ elif index.model() is None:
+ path = ""
+ self.__sidebar.setSelectedPath(path)
+ else:
+ selectionModel.clear()
+ selectionModel.selectionChanged.connect(self.__shortcutSelected)
+
+ def __updateActionHistory(self):
+ self.__forwardAction.setEnabled(len(self.__currentHistory) - 1 > self.__currentHistoryLocation)
+ self.__backwardAction.setEnabled(self.__currentHistoryLocation > 0)
+
+ def __textChanged(self, text):
+ self.__pathChanged()
+
+ def _isFabioFilesSupported(self):
+ """Returns true fabio files can be loaded.
+ """
+ return True
+
+ def _isLoadableUrl(self, url):
+ """Returns true if the URL is loadable by this dialog.
+
+ :param DataUrl url: The requested URL
+ """
+ return True
+
+ def __pathChanged(self):
+ url = silx.io.url.DataUrl(path=self.__pathEdit.text())
+ if url.is_valid() or url.path() == "":
+ if url.path() in ["", "/"] or url.file_path() in ["", "/"]:
+ self.__fileModel_setRootPath(qt.QDir.rootPath())
+ elif os.path.exists(url.file_path()):
+ rootIndex = None
+ if os.path.isdir(url.file_path()):
+ self.__fileModel_setRootPath(url.file_path())
+ index = self.__fileModel.index(url.file_path())
+ elif os.path.isfile(url.file_path()):
+ if self._isLoadableUrl(url):
+ if url.scheme() == "silx":
+ loaded = self.__openSilxFile(url.file_path())
+ elif url.scheme() == "fabio" and self._isFabioFilesSupported():
+ loaded = self.__openFabioFile(url.file_path())
+ else:
+ loaded = self.__openFile(url.file_path())
+ else:
+ loaded = False
+ if loaded:
+ if self.__h5 is not None:
+ rootIndex = self.__dataModel.indexFromH5Object(self.__h5)
+ elif self.__fabio is not None:
+ index = self.__fileModel.index(url.file_path())
+ rootIndex = index
+ if rootIndex is None:
+ index = self.__fileModel.index(url.file_path())
+ index = index.parent()
+
+ if rootIndex is not None:
+ if rootIndex.model() == self.__dataModel:
+ if url.data_path() is not None:
+ dataPath = url.data_path()
+ if dataPath in self.__h5:
+ obj = self.__h5[dataPath]
+ else:
+ path = utils.findClosestSubPath(self.__h5, dataPath)
+ if path is None:
+ path = "/"
+ obj = self.__h5[path]
+
+ if silx.io.is_file(obj):
+ self.__browser.setRootIndex(rootIndex)
+ elif silx.io.is_group(obj):
+ index = self.__dataModel.indexFromH5Object(obj)
+ self.__browser.setRootIndex(index)
+ else:
+ index = self.__dataModel.indexFromH5Object(obj)
+ self.__browser.setRootIndex(index.parent())
+ self.__browser.selectIndex(index)
+ else:
+ self.__browser.setRootIndex(rootIndex)
+ self.__clearData()
+ elif rootIndex.model() == self.__fileModel:
+ # that's a fabio file
+ self.__browser.setRootIndex(rootIndex.parent())
+ self.__browser.selectIndex(rootIndex)
+ # data = _FabioData(self.__fabio)
+ # self.__setData(data)
+ else:
+ assert(False)
+ else:
+ self.__browser.setRootIndex(index, model=self.__fileModel)
+ self.__clearData()
+
+ if self.__selectorWidget is not None:
+ self.__selectorWidget.selectSlicing(url.data_slice())
+ else:
+ self.__errorWhileLoadingFile = (url.file_path(), "File not found")
+ self.__clearData()
+ else:
+ self.__errorWhileLoadingFile = (url.file_path(), "Path invalid")
+ self.__clearData()
+
+ def previewToolbar(self):
+ return self.__previewToolbar
+
+ def previewWidget(self):
+ return self.__previewWidget
+
+ def selectorWidget(self):
+ return self.__selectorWidget
+
+ def _createPreviewToolbar(self, parent, dataPreviewWidget, dataSelectorWidget):
+ return None
+
+ def _createPreviewWidget(self, parent):
+ return None
+
+ def _createSelectorWidget(self, parent):
+ return None
+
+ # Selected file
+
+ def setDirectory(self, path):
+ """Sets the data dialog's current directory."""
+ self.__fileModel_setRootPath(path)
+
+ def selectedFile(self):
+ """Returns the file path containing the selected data.
+
+ :rtype: str
+ """
+ return self.__selectedFile
+
+ def selectFile(self, filename):
+ """Sets the data dialog's current file."""
+ self.__directoryLoadedFilter = ""
+ old = self.__pathEdit.blockSignals(True)
+ try:
+ self.__pathEdit.setText(filename)
+ finally:
+ self.__pathEdit.blockSignals(old)
+ self.__pathChanged()
+
+ # Selected data
+
+ def selectUrl(self, url):
+ """Sets the data dialog's current data url.
+
+ :param Union[str,DataUrl] url: URL identifying a data (it can be a
+ `DataUrl` object)
+ """
+ if isinstance(url, silx.io.url.DataUrl):
+ url = url.path()
+ self.__directoryLoadedFilter = ""
+ old = self.__pathEdit.blockSignals(True)
+ try:
+ self.__pathEdit.setText(url)
+ finally:
+ self.__pathEdit.blockSignals(old)
+ self.__pathChanged()
+
+ def selectedUrl(self):
+ """Returns the URL from the file system to the data.
+
+ If the dialog is not validated, the path can be an intermediat
+ selected path, or an invalid path.
+
+ :rtype: str
+ """
+ return self.__pathEdit.text()
+
+ def selectedDataUrl(self):
+ """Returns the URL as a :class:`DataUrl` from the file system to the
+ data.
+
+ If the dialog is not validated, the path can be an intermediat
+ selected path, or an invalid path.
+
+ :rtype: DataUrl
+ """
+ url = self.selectedUrl()
+ return silx.io.url.DataUrl(url)
+
+ def directory(self):
+ """Returns the path from the current browsed directory.
+
+ :rtype: str
+ """
+ if self.__directory is not None:
+ # At post execution, returns the cache
+ return self.__directory
+
+ index = self.__browser.rootIndex()
+ if index.model() is self.__fileModel:
+ path = self.__fileModel.filePath(index)
+ return path
+ elif index.model() is self.__dataModel:
+ path = os.path.dirname(self.__h5.file.filename)
+ return path
+ else:
+ return ""
+
+ def _selectedData(self):
+ """Returns the internal selected data
+
+ :rtype: numpy.ndarray
+ """
+ return self.__selectedData
+
+ # Filters
+
+ def selectedNameFilter(self):
+ """Returns the filter that the user selected in the file dialog."""
+ return self.__fileTypeCombo.currentText()
+
+ # History
+
+ def history(self):
+ """Returns the browsing history of the filedialog as a list of paths.
+
+ :rtype: List<str>
+ """
+ if len(self.__currentHistory) <= 1:
+ return []
+ history = self.__currentHistory[0:self.__currentHistoryLocation]
+ return list(history)
+
+ def setHistory(self, history):
+ self.__currentHistory = []
+ self.__currentHistory.extend(history)
+ self.__currentHistoryLocation = len(self.__currentHistory) - 1
+ self.__updateActionHistory()
+
+ # Colormap
+
+ def colormap(self):
+ if self.__previewWidget is None:
+ return None
+ return self.__previewWidget.colormap()
+
+ def setColormap(self, colormap):
+ if self.__previewWidget is None:
+ raise RuntimeError("No preview widget defined")
+ self.__previewWidget.setColormap(colormap)
+
+ # Sidebar
+
+ def setSidebarUrls(self, urls):
+ """Sets the urls that are located in the sidebar."""
+ if self.__sidebar is None:
+ return
+ self.__sidebar.setUrls(urls)
+
+ def sidebarUrls(self):
+ """Returns a list of urls that are currently in the sidebar."""
+ if self.__sidebar is None:
+ return []
+ return self.__sidebar.urls()
+
+ # State
+
+ __serialVersion = 1
+ """Store the current version of the serialized data"""
+
+ @classmethod
+ def qualifiedName(cls):
+ return "%s.%s" % (cls.__module__, cls.__name__)
+
+ def restoreState(self, state):
+ """Restores the dialogs's layout, history and current directory to the
+ state specified.
+
+ :param qt.QByteArray state: Stream containing the new state
+ :rtype: bool
+ """
+ stream = qt.QDataStream(state, qt.QIODevice.ReadOnly)
+
+ qualifiedName = stream.readQString()
+ if qualifiedName != self.qualifiedName():
+ _logger.warning("Stored state contains an invalid qualified name. %s restoration cancelled.", self.__class__.__name__)
+ return False
+
+ version = stream.readInt32()
+ if version != self.__serialVersion:
+ _logger.warning("Stored state contains an invalid version. %s restoration cancelled.", self.__class__.__name__)
+ return False
+
+ result = True
+
+ splitterData = stream.readQVariant()
+ sidebarUrls = stream.readQStringList()
+ history = stream.readQStringList()
+ workingDirectory = stream.readQString()
+ browserData = stream.readQVariant()
+ viewMode = stream.readInt32()
+ colormapData = stream.readQVariant()
+
+ result &= self.__splitter.restoreState(splitterData)
+ sidebarUrls = [qt.QUrl(s) for s in sidebarUrls]
+ self.setSidebarUrls(list(sidebarUrls))
+ history = [s for s in history]
+ self.setHistory(list(history))
+ if workingDirectory is not None:
+ self.setDirectory(workingDirectory)
+ result &= self.__browser.restoreState(browserData)
+ self.setViewMode(viewMode)
+ colormap = self.colormap()
+ if colormap is not None:
+ result &= self.colormap().restoreState(colormapData)
+
+ return result
+
+ def saveState(self):
+ """Saves the state of the dialog's layout, history and current
+ directory.
+
+ :rtype: qt.QByteArray
+ """
+ data = qt.QByteArray()
+ stream = qt.QDataStream(data, qt.QIODevice.WriteOnly)
+
+ s = self.qualifiedName()
+ stream.writeQString(u"%s" % s)
+ stream.writeInt32(self.__serialVersion)
+ stream.writeQVariant(self.__splitter.saveState())
+ strings = [u"%s" % s.toString() for s in self.sidebarUrls()]
+ stream.writeQStringList(strings)
+ strings = [u"%s" % s for s in self.history()]
+ stream.writeQStringList(strings)
+ stream.writeQString(u"%s" % self.directory())
+ stream.writeQVariant(self.__browser.saveState())
+ stream.writeInt32(self.viewMode())
+ colormap = self.colormap()
+ if colormap is not None:
+ stream.writeQVariant(self.colormap().saveState())
+ else:
+ stream.writeQVariant(None)
+
+ return data
diff --git a/src/silx/gui/dialog/ColormapDialog.py b/src/silx/gui/dialog/ColormapDialog.py
new file mode 100644
index 0000000..2506e2a
--- /dev/null
+++ b/src/silx/gui/dialog/ColormapDialog.py
@@ -0,0 +1,1775 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""A QDialog widget to set-up the colormap.
+
+It uses a description of colormaps as dict compatible with :class:`Plot`.
+
+To run the following sample code, a QApplication must be initialized.
+
+Create the colormap dialog and set the colormap description and data range:
+
+>>> from silx.gui.dialog.ColormapDialog import ColormapDialog
+>>> from silx.gui.colors import Colormap
+
+>>> dialog = ColormapDialog()
+>>> colormap = Colormap(name='red', normalization='log',
+... vmin=1., vmax=2.)
+
+>>> dialog.setColormap(colormap)
+>>> colormap.setVRange(1., 100.) # This scale the width of the plot area
+>>> dialog.show()
+
+Get the colormap description (compatible with :class:`Plot`) from the dialog:
+
+>>> cmap = dialog.getColormap()
+>>> cmap.getName()
+'red'
+
+It is also possible to display an histogram of the image in the dialog.
+This updates the data range with the range of the bins.
+
+>>> import numpy
+>>> image = numpy.random.normal(size=512 * 512).reshape(512, -1)
+>>> hist, bin_edges = numpy.histogram(image, bins=10)
+>>> dialog.setHistogram(hist, bin_edges)
+
+The updates of the colormap description are also available through the signal:
+:attr:`ColormapDialog.sigColormapChanged`.
+""" # noqa
+
+__authors__ = ["V.A. Sole", "T. Vincent", "H. Payno"]
+__license__ = "MIT"
+__date__ = "08/12/2020"
+
+import enum
+import logging
+
+import numpy
+
+from .. import qt
+from .. import utils
+from ..colors import Colormap, cursorColorForColormap
+from ..plot import PlotWidget
+from ..plot.items.axis import Axis
+from ..plot.items import BoundingRect
+from silx.gui.widgets.FloatEdit import FloatEdit
+import weakref
+from silx.math.combo import min_max
+from silx.gui.plot import items
+from silx.gui import icons
+from silx.gui.qt import inspect as qtinspect
+from silx.gui.widgets.ColormapNameComboBox import ColormapNameComboBox
+from silx.gui.widgets.WaitingPushButton import WaitingPushButton
+from silx.math.histogram import Histogramnd
+from silx.utils import deprecation
+from silx.gui.plot.items.roi import RectangleROI
+from silx.gui.plot.tools.roi import RegionOfInterestManager
+
+_logger = logging.getLogger(__name__)
+
+_colormapIconPreview = {}
+
+
+class _DataRefHolder(items.Item, items.ColormapMixIn):
+ """Holder for a weakref of a numpy array.
+
+ It provides features from `ColormapMixIn`.
+ """
+
+ def __init__(self, dataRef):
+ items.Item.__init__(self)
+ items.ColormapMixIn.__init__(self)
+ self.__dataRef = dataRef
+ self._updated(items.ItemChangedType.DATA)
+
+ def getColormappedData(self, copy=True):
+ return self.__dataRef()
+
+
+class _BoundaryWidget(qt.QWidget):
+ """Widget to edit a boundary of the colormap (vmin or vmax)"""
+
+ sigAutoScaleChanged = qt.Signal(object)
+ """Signal emitted when the autoscale was changed
+
+ True is sent as an argument if autoscale is set to true.
+ """
+
+ sigValueChanged = qt.Signal(object)
+ """Signal emitted when value is changed
+
+ The new value is sent as an argument.
+ """
+
+ def __init__(self, parent=None, value=0.0):
+ qt.QWidget.__init__(self, parent=parent)
+ self.setLayout(qt.QHBoxLayout())
+ self.layout().setContentsMargins(0, 0, 0, 0)
+ self._numVal = FloatEdit(parent=self, value=value)
+ self.layout().addWidget(self._numVal)
+ self._autoCB = qt.QCheckBox('auto', parent=self)
+ self.layout().addWidget(self._autoCB)
+ self._autoCB.setChecked(False)
+ self._autoCB.setVisible(False)
+
+ self._autoCB.toggled.connect(self._autoToggled)
+ self._numVal.textEdited.connect(self.__textEdited)
+ self._numVal.editingFinished.connect(self.__editingFinished)
+ self.setFocusProxy(self._numVal)
+
+ self.__textWasEdited = False
+ """True if the text was edited, in order to send an event
+ at the end of the user interaction"""
+
+ self.__realValue = None
+ """Store the real value set by setValue, to avoid
+ rounding of the widget"""
+
+ def __textEdited(self):
+ self.__textWasEdited = True
+
+ def __editingFinished(self):
+ if self.__textWasEdited:
+ value = self._numVal.value()
+ self.__realValue = value
+ with utils.blockSignals(self._numVal):
+ # Fix the formatting
+ self._numVal.setValue(self.__realValue)
+ self.sigValueChanged.emit(value)
+ self.__textWasEdited = False
+
+ def isAutoChecked(self):
+ return self._autoCB.isChecked()
+
+ def getValue(self):
+ """Returns the stored range. If autoscale is
+ enabled, this returns None.
+ """
+ if self._autoCB.isChecked():
+ return None
+ if self.__realValue is not None:
+ return self.__realValue
+ return self._numVal.value()
+
+ def _autoToggled(self, enabled):
+ self._numVal.setEnabled(not enabled)
+ self._updateDisplayedText()
+ self.sigAutoScaleChanged.emit(enabled)
+
+ def _updateDisplayedText(self):
+ self.__textWasEdited = False
+ if self._autoCB.isChecked() and self.__realValue is not None:
+ with utils.blockSignals(self._numVal):
+ self._numVal.setValue(self.__realValue)
+
+ def setValue(self, value, isAuto=False):
+ """Set the value of the boundary.
+
+ :param float value: A finite value for the boundary
+ :param bool isAuto: If true, the finite value was automatically computed
+ from the data, else it is a fixed custom value.
+ """
+ assert value is not None
+ self._autoCB.setChecked(isAuto)
+ with utils.blockSignals(self._numVal):
+ if isAuto or self.__realValue != value:
+ if not self.__textWasEdited:
+ self._numVal.setValue(value)
+ self.__realValue = value
+ self._numVal.setEnabled(not isAuto)
+
+
+class _AutoscaleModeComboBox(qt.QComboBox):
+
+ DATA = {
+ Colormap.MINMAX: ("Min/max", "Use the data min/max"),
+ Colormap.STDDEV3: ("Mean ± 3 × stddev", "Use the data mean ± 3 × standard deviation"),
+ }
+
+ def __init__(self, parent: qt.QWidget):
+ super(_AutoscaleModeComboBox, self).__init__(parent=parent)
+ self.currentIndexChanged.connect(self.__updateTooltip)
+ self._init()
+
+ def _init(self):
+ for mode in Colormap.AUTOSCALE_MODES:
+ label, tooltip = self.DATA.get(mode, (mode, None))
+ self.addItem(label, mode)
+ if tooltip is not None:
+ self.setItemData(self.count() - 1, tooltip, qt.Qt.ToolTipRole)
+
+ def setCurrentIndex(self, index):
+ self.__updateTooltip(index)
+ super(_AutoscaleModeComboBox, self).setCurrentIndex(index)
+
+ def __updateTooltip(self, index):
+ if index > -1:
+ tooltip = self.itemData(index, qt.Qt.ToolTipRole)
+ else:
+ tooltip = ""
+ self.setToolTip(tooltip)
+
+ def currentMode(self):
+ index = self.currentIndex()
+ return self.itemData(index)
+
+ def setCurrentMode(self, mode):
+ for index in range(self.count()):
+ if mode == self.itemData(index):
+ self.setCurrentIndex(index)
+ return
+ if mode is None:
+ # If None was not a value
+ self.setCurrentIndex(-1)
+ return
+ self.addItem(mode, mode)
+ self.setCurrentIndex(self.count() - 1)
+
+
+class _AutoScaleButtons(qt.QWidget):
+
+ autoRangeChanged = qt.Signal(object)
+
+ def __init__(self, parent=None):
+ qt.QWidget.__init__(self, parent=parent)
+ layout = qt.QHBoxLayout(self)
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.setSpacing(0)
+
+ self.setFocusPolicy(qt.Qt.NoFocus)
+
+ self._bothAuto = qt.QPushButton(self)
+ self._bothAuto.setText("Autoscale")
+ self._bothAuto.setToolTip("Enable/disable the autoscale for both min and max")
+ self._bothAuto.setCheckable(True)
+ self._bothAuto.toggled[bool].connect(self.__bothToggled)
+ self._bothAuto.setFocusPolicy(qt.Qt.TabFocus)
+
+ self._minAuto = qt.QCheckBox(self)
+ self._minAuto.setText("")
+ self._minAuto.setToolTip("Enable/disable the autoscale for min")
+ self._minAuto.toggled[bool].connect(self.__minToggled)
+ self._minAuto.setFocusPolicy(qt.Qt.TabFocus)
+
+ self._maxAuto = qt.QCheckBox(self)
+ self._maxAuto.setText("")
+ self._maxAuto.setToolTip("Enable/disable the autoscale for max")
+ self._maxAuto.toggled[bool].connect(self.__maxToggled)
+ self._maxAuto.setFocusPolicy(qt.Qt.TabFocus)
+
+ layout.addStretch(1)
+ layout.addWidget(self._minAuto)
+ layout.addSpacing(20)
+ layout.addWidget(self._bothAuto)
+ layout.addSpacing(20)
+ layout.addWidget(self._maxAuto)
+ layout.addStretch(1)
+
+ def __bothToggled(self, checked):
+ autoRange = checked, checked
+ self.setAutoRange(autoRange)
+ self.autoRangeChanged.emit(autoRange)
+
+ def __minToggled(self, checked):
+ autoRange = self.getAutoRange()
+ self.setAutoRange(autoRange)
+ self.autoRangeChanged.emit(autoRange)
+
+ def __maxToggled(self, checked):
+ autoRange = self.getAutoRange()
+ self.setAutoRange(autoRange)
+ self.autoRangeChanged.emit(autoRange)
+
+ def setAutoRangeFromColormap(self, colormap):
+ vRange = colormap.getVRange()
+ autoRange = vRange[0] is None, vRange[1] is None
+ self.setAutoRange(autoRange)
+
+ def setAutoRange(self, autoRange):
+ if autoRange[0] == autoRange[1]:
+ with utils.blockSignals(self._bothAuto):
+ self._bothAuto.setChecked(autoRange[0])
+ else:
+ with utils.blockSignals(self._bothAuto):
+ self._bothAuto.setChecked(False)
+ with utils.blockSignals(self._minAuto):
+ self._minAuto.setChecked(autoRange[0])
+ with utils.blockSignals(self._maxAuto):
+ self._maxAuto.setChecked(autoRange[1])
+
+ def getAutoRange(self):
+ return self._minAuto.isChecked(), self._maxAuto.isChecked()
+
+
+@enum.unique
+class _DataInPlotMode(enum.Enum):
+ """Enum for each mode of display of the data in the plot."""
+ RANGE = 'range'
+ HISTOGRAM = 'histogram'
+
+
+class _ColormapHistogram(qt.QWidget):
+ """Display the colormap and the data as a plot."""
+
+ sigRangeMoving = qt.Signal(object, object)
+ """Emitted when a mouse interaction moves the location
+ of the colormap range in the plot.
+
+ This signal contains 2 elements:
+
+ - vmin: A float value if this range was moved, else None
+ - vmax: A float value if this range was moved, else None
+ """
+
+ sigRangeMoved = qt.Signal(object, object)
+ """Emitted when a mouse interaction stop.
+
+ This signal contains 2 elements:
+
+ - vmin: A float value if this range was moved, else None
+ - vmax: A float value if this range was moved, else None
+ """
+
+ def __init__(self, parent):
+ qt.QWidget.__init__(self, parent=parent)
+ self._dataInPlotMode = _DataInPlotMode.RANGE
+ self._finiteRange = None, None
+ self._initPlot()
+
+ self._histogramData = {}
+ """Histogram displayed in the plot"""
+
+ self._dragging = False, False
+ """True, if the min or the max handle is dragging"""
+
+ self._dataRange = {}
+ """Histogram displayed in the plot"""
+
+ self._invalidated = False
+
+ def paintEvent(self, event):
+ if self._invalidated:
+ self._updateDataInPlot()
+ self._invalidated = False
+ self._updateMarkerPosition()
+ return super(_ColormapHistogram, self).paintEvent(event)
+
+ def getFiniteRange(self):
+ """Returns the colormap range as displayed in the plot."""
+ return self._finiteRange
+
+ def setFiniteRange(self, vRange):
+ """Set the colormap range to use in the plot.
+
+ Here there is no concept of auto. The values should
+ not be None, except if there is no range or marker
+ to display.
+ """
+ # Do not reset the limit for handle about to be dragged
+ if self._dragging[0]:
+ vRange = self._finiteRange[0], vRange[1]
+ if self._dragging[1]:
+ vRange = vRange[0], self._finiteRange[1]
+
+ if vRange == self._finiteRange:
+ return
+
+ self._finiteRange = vRange
+ self.update()
+
+ def getColormap(self):
+ return self.parent().getColormap()
+
+ def _getNormalizedHistogram(self):
+ """Return an histogram already normalized according to the colormap
+ normalization.
+
+ Returns a tuple edges, counts
+ """
+ norm = self._getNorm()
+ histogram = self._histogramData.get(norm, None)
+ if histogram is None:
+ histogram = self._computeNormalizedHistogram()
+ self._histogramData[norm] = histogram
+ return histogram
+
+ def _computeNormalizedHistogram(self):
+ colormap = self.getColormap()
+ if colormap is None:
+ norm = Colormap.LINEAR
+ else:
+ norm = colormap.getNormalization()
+
+ # Try to use the histogram defined in the dialog
+ histo = self.parent()._getHistogram()
+ if histo is not None:
+ counts, edges = histo
+ normalizer = Colormap(normalization=norm)._getNormalizer()
+ mask = normalizer.is_valid(edges[:-1]) # Check lower bin edges only
+ firstValid = numpy.argmax(mask) # edges increases monotonically
+ if firstValid == 0: # Mask is all False or all True
+ return (counts, edges) if mask[0] else (None, None)
+ else: # Clip to valid values
+ return counts[firstValid:], edges[firstValid:]
+
+ data = self.parent()._getArray()
+ if data is None:
+ return None, None
+ dataRange = self._getNormalizedDataRange()
+ if dataRange[0] is None or dataRange[1] is None:
+ return None, None
+ counts, edges = self.parent().computeHistogram(data, scale=norm, dataRange=dataRange)
+ return counts, edges
+
+ def _getNormalizedDataRange(self):
+ """Return a data range already normalized according to the colormap
+ normalization.
+
+ Returns a tuple with min and max
+ """
+ norm = self._getNorm()
+ dataRange = self._dataRange.get(norm, None)
+ if dataRange is None:
+ dataRange = self._computeNormalizedDataRange()
+ self._dataRange[norm] = dataRange
+ return dataRange
+
+ def _computeNormalizedDataRange(self):
+ colormap = self.getColormap()
+ if colormap is None:
+ norm = Colormap.LINEAR
+ else:
+ norm = colormap.getNormalization()
+
+ # Try to use the one defined in the dialog
+ dataRange = self.parent()._getDataRange()
+ if dataRange is not None:
+ if norm in (Colormap.LINEAR, Colormap.GAMMA, Colormap.ARCSINH):
+ return dataRange[0], dataRange[2]
+ elif norm == Colormap.LOGARITHM:
+ return dataRange[1], dataRange[2]
+ elif norm == Colormap.SQRT:
+ return dataRange[1], dataRange[2]
+ else:
+ _logger.error("Undefined %s normalization", norm)
+
+ # Try to use the histogram defined in the dialog
+ histo = self.parent()._getHistogram()
+ if histo is not None:
+ _histo, edges = histo
+ normalizer = Colormap(normalization=norm)._getNormalizer()
+ edges = edges[normalizer.is_valid(edges)]
+ if edges.size == 0:
+ return None, None
+ else:
+ dataRange = min_max(edges, finite=True)
+ return dataRange.minimum, dataRange.maximum
+
+ item = self.parent()._getItem()
+ if item is not None:
+ # Trick to reach data range using colormap cache
+ cm = Colormap()
+ cm.setVRange(None, None)
+ cm.setNormalization(norm)
+ dataRange = item._getColormapAutoscaleRange(cm)
+ return dataRange
+
+ # If there is no item, there is no data
+ return None, None
+
+ def _getDisplayableRange(self):
+ """Returns the selected min/max range to apply to the data,
+ according to the used scale.
+
+ One or both limits can be None in case it is not displayable in the
+ current axes scale.
+
+ :returns: Tuple{float, float}
+ """
+ scale = self._plot.getXAxis().getScale()
+
+ def isDisplayable(pos):
+ if pos is None:
+ return False
+ if scale == Axis.LOGARITHMIC:
+ return pos > 0.0
+ return True
+
+ posMin, posMax = self.getFiniteRange()
+ if not isDisplayable(posMin):
+ posMin = None
+ if not isDisplayable(posMax):
+ posMax = None
+
+ return posMin, posMax
+
+ def _initPlot(self):
+ """Init the plot to display the range and the values"""
+ self._plot = PlotWidget(self)
+ self._plot.setDataMargins(0.125, 0.125, 0.125, 0.125)
+ self._plot.getXAxis().setLabel("Data Values")
+ self._plot.getYAxis().setLabel("")
+ self._plot.setInteractiveMode('select', zoomOnWheel=False)
+ self._plot.setActiveCurveHandling(False)
+ self._plot.setMinimumSize(qt.QSize(250, 200))
+ self._plot.sigPlotSignal.connect(self._plotEventReceived)
+ palette = self.palette()
+ color = palette.color(qt.QPalette.Normal, qt.QPalette.Window)
+ self._plot.setBackgroundColor(color)
+ self._plot.setDataBackgroundColor("white")
+
+ lut = numpy.arange(256)
+ lut.shape = 1, -1
+ self._plot.addImage(lut, legend='lut')
+ self._lutItem = self._plot._getItem("image", "lut")
+ self._lutItem.setVisible(False)
+
+ self._plot.addScatter(x=[], y=[], value=[], legend='lut2')
+ self._lutItem2 = self._plot._getItem("scatter", "lut2")
+ self._lutItem2.setVisible(False)
+ self.__lutY = numpy.array([-0.05] * 256)
+ self.__lutV = numpy.arange(256)
+
+ self._bound = BoundingRect()
+ self._plot.addItem(self._bound)
+ self._bound.setVisible(True)
+
+ # Add plot for histogram
+ self._plotToolbar = qt.QToolBar(self)
+ self._plotToolbar.setFloatable(False)
+ self._plotToolbar.setMovable(False)
+ self._plotToolbar.setIconSize(qt.QSize(8, 8))
+ self._plotToolbar.setStyleSheet("QToolBar { border: 0px }")
+ self._plotToolbar.setOrientation(qt.Qt.Vertical)
+
+ group = qt.QActionGroup(self._plotToolbar)
+ group.setExclusive(True)
+
+ action = qt.QAction("Data range", self)
+ action.setToolTip("Display the data range within the colormap range. A fast data processing have to be done.")
+ action.setIcon(icons.getQIcon('colormap-range'))
+ action.setCheckable(True)
+ action.setData(_DataInPlotMode.RANGE)
+ action.setChecked(action.data() == self._dataInPlotMode)
+ self._plotToolbar.addAction(action)
+ group.addAction(action)
+ action = qt.QAction("Histogram", self)
+ action.setToolTip("Display the data histogram within the colormap range. A slow data processing have to be done. ")
+ action.setIcon(icons.getQIcon('colormap-histogram'))
+ action.setCheckable(True)
+ action.setData(_DataInPlotMode.HISTOGRAM)
+ action.setChecked(action.data() == self._dataInPlotMode)
+ self._plotToolbar.addAction(action)
+ group.addAction(action)
+ group.triggered.connect(self._displayDataInPlotModeChanged)
+
+ plotBoxLayout = qt.QHBoxLayout()
+ plotBoxLayout.setContentsMargins(0, 0, 0, 0)
+ plotBoxLayout.setSpacing(2)
+ plotBoxLayout.addWidget(self._plotToolbar)
+ plotBoxLayout.addWidget(self._plot)
+ plotBoxLayout.setSizeConstraint(qt.QLayout.SetMinimumSize)
+ self.setLayout(plotBoxLayout)
+
+ def _plotEventReceived(self, event):
+ """Handle events from the plot"""
+ kind = event['event']
+
+ if kind == 'markerMoving':
+ value = event['xdata']
+ if event['label'] == 'Min':
+ self._dragging = True, False
+ self._finiteRange = value, self._finiteRange[1]
+ self._last = value, None
+ self.sigRangeMoving.emit(*self._last)
+ elif event['label'] == 'Max':
+ self._dragging = False, True
+ self._finiteRange = self._finiteRange[0], value
+ self._last = None, value
+ self.sigRangeMoving.emit(*self._last)
+ self._updateLutItem(self._finiteRange)
+ elif kind == 'markerMoved':
+ self.sigRangeMoved.emit(*self._last)
+ self._plot.resetZoom()
+ self._dragging = False, False
+ else:
+ pass
+
+ def _updateMarkerPosition(self):
+ colormap = self.getColormap()
+ posMin, posMax = self._getDisplayableRange()
+
+ if colormap is None:
+ isDraggable = False
+ else:
+ isDraggable = colormap.isEditable()
+
+ with utils.blockSignals(self):
+ if posMin is not None and not self._dragging[0]:
+ self._plot.addXMarker(
+ posMin,
+ legend='Min',
+ text='Min',
+ draggable=isDraggable,
+ color="blue",
+ constraint=self._plotMinMarkerConstraint)
+ if posMax is not None and not self._dragging[1]:
+ self._plot.addXMarker(
+ posMax,
+ legend='Max',
+ text='Max',
+ draggable=isDraggable,
+ color="blue",
+ constraint=self._plotMaxMarkerConstraint)
+
+ self._updateLutItem((posMin, posMax))
+ self._plot.resetZoom()
+
+ def _updateLutItem(self, vRange):
+ colormap = self.getColormap()
+ if colormap is None:
+ return
+
+ if vRange is None:
+ posMin, posMax = self._getDisplayableRange()
+ else:
+ posMin, posMax = vRange
+ if posMin is None or posMax is None:
+ self._lutItem.setVisible(False)
+ pos = posMax if posMin is None else posMin
+ if pos is not None:
+ self._bound.setBounds((pos, pos, -0.1, 0))
+ else:
+ self._bound.setBounds((0, 0, -0.1, 0))
+ else:
+ norm = colormap.getNormalization()
+ normColormap = colormap.copy()
+ normColormap.setEditable(True)
+ normColormap.setVRange(0, 255)
+ normColormap.setNormalization(Colormap.LINEAR)
+ if norm == Colormap.LINEAR:
+ scale = (posMax - posMin) / 256
+ self._lutItem.setColormap(normColormap)
+ self._lutItem.setOrigin((posMin, -0.09))
+ self._lutItem.setScale((scale, 0.08))
+ self._lutItem.setVisible(True)
+ self._lutItem2.setVisible(False)
+ elif norm == Colormap.LOGARITHM:
+ self._lutItem2.setVisible(False)
+ self._lutItem2.setColormap(normColormap)
+ xx = numpy.geomspace(posMin, posMax, 256)
+ self._lutItem2.setData(x=xx,
+ y=self.__lutY,
+ value=self.__lutV,
+ copy=False)
+ self._lutItem2.setSymbol("|")
+ self._lutItem2.setVisible(True)
+ self._lutItem.setVisible(False)
+ else:
+ # Fallback: Display with linear axis and applied normalization
+ self._lutItem2.setVisible(False)
+ normColormap.setNormalization(norm)
+ self._lutItem2.setColormap(normColormap)
+ xx = numpy.linspace(posMin, posMax, 256, endpoint=True)
+ self._lutItem2.setData(
+ x=xx,
+ y=self.__lutY,
+ value=self.__lutV,
+ copy=False)
+ self._lutItem2.setSymbol("|")
+ self._lutItem2.setVisible(True)
+ self._lutItem.setVisible(False)
+
+ self._bound.setBounds((posMin, posMax, -0.1, 1))
+
+ def _plotMinMarkerConstraint(self, x, y):
+ """Constraint of the min marker"""
+ _vmin, vmax = self.getFiniteRange()
+ if vmax is None:
+ return x, y
+ return min(x, vmax), y
+
+ def _plotMaxMarkerConstraint(self, x, y):
+ """Constraint of the max marker"""
+ vmin, _vmax = self.getFiniteRange()
+ if vmin is None:
+ return x, y
+ return max(x, vmin), y
+
+ def _setDataInPlotMode(self, mode):
+ if self._dataInPlotMode == mode:
+ return
+ self._dataInPlotMode = mode
+ self._updateDataInPlot()
+
+ def _displayDataInPlotModeChanged(self, action):
+ mode = action.data()
+ self._setDataInPlotMode(mode)
+
+ def invalidateData(self):
+ self._histogramData = {}
+ self._dataRange = {}
+ self._invalidated = True
+ self.update()
+
+ def _updateDataInPlot(self):
+ mode = self._dataInPlotMode
+
+ norm = self._getNorm()
+ if norm == Colormap.LINEAR:
+ scale = Axis.LINEAR
+ elif norm == Colormap.LOGARITHM:
+ scale = Axis.LOGARITHMIC
+ else:
+ scale = Axis.LINEAR
+
+ axis = self._plot.getXAxis()
+ axis.setScale(scale)
+
+ if mode == _DataInPlotMode.RANGE:
+ dataRange = self._getNormalizedDataRange()
+ xmin, xmax = dataRange
+ if xmax is None or xmin is None:
+ self._plot.remove(legend='Data', kind='histogram')
+ else:
+ histogram = numpy.array([1])
+ bin_edges = numpy.array([xmin, xmax])
+ self._plot.addHistogram(histogram,
+ bin_edges,
+ legend="Data",
+ color='gray',
+ align='center',
+ fill=True,
+ z=1)
+
+ elif mode == _DataInPlotMode.HISTOGRAM:
+ histogram, bin_edges = self._getNormalizedHistogram()
+ if histogram is None or bin_edges is None:
+ self._plot.remove(legend='Data', kind='histogram')
+ else:
+ histogram = numpy.array(histogram, copy=True)
+ bin_edges = numpy.array(bin_edges, copy=True)
+ with numpy.errstate(invalid='ignore'):
+ norm_histogram = histogram / numpy.nanmax(histogram)
+ self._plot.addHistogram(norm_histogram,
+ bin_edges,
+ legend="Data",
+ color='gray',
+ align='center',
+ fill=True,
+ z=1)
+ else:
+ _logger.error("Mode unsupported")
+
+ def sizeHint(self):
+ return self.layout().minimumSize()
+
+ def updateLut(self):
+ self._updateLutItem(None)
+
+ def _getNorm(self):
+ colormap = self.getColormap()
+ if colormap is None:
+ return Axis.LINEAR
+ else:
+ norm = colormap.getNormalization()
+ return norm
+
+ def updateNormalization(self):
+ self._updateDataInPlot()
+ self.update()
+
+
+class ColormapDialog(qt.QDialog):
+ """A QDialog widget to set the colormap.
+
+ :param parent: See :class:`QDialog`
+ :param str title: The QDialog title
+ """
+
+ visibleChanged = qt.Signal(bool)
+ """This event is sent when the dialog visibility change"""
+
+ def __init__(self, parent=None, title="Colormap Dialog"):
+ qt.QDialog.__init__(self, parent)
+ self.setWindowTitle(title)
+
+ self.__aboutToDelete = False
+ self._colormap = None
+
+ self._data = None
+ """Weak ref to an external numpy array
+ """
+ self._itemHolder = None
+ """Hard ref to a private item (used as holder to the data)
+ This allow to reuse the item cache
+ """
+ self._item = None
+ """Weak ref to an external item"""
+
+ self._colormapChange = utils.LockReentrant()
+ """Used as a semaphore to avoid editing the colormap object when we are
+ only attempt to display it.
+ Used instead of n connect and disconnect of the sigChanged. The
+ disconnection to sigChanged was also limiting when this colormapdialog
+ is used in the colormapaction and associated to the activeImageChanged.
+ (because the activeImageChanged is send when the colormap changed and
+ the self.setcolormap is a callback)
+ """
+
+ self.__colormapInvalidated = False
+ self.__dataInvalidated = False
+
+ self._histogramData = None
+
+ self._dataRange = None
+ """If defined 3-tuple containing information from a data:
+ minimum, positive minimum, maximum"""
+
+ self._colormapStoredState = None
+
+ # Colormap row
+ self._comboBoxColormap = ColormapNameComboBox(parent=self)
+ self._comboBoxColormap.currentIndexChanged[int].connect(self._comboBoxColormapUpdated)
+
+ # Normalization row
+ self._comboBoxNormalization = qt.QComboBox(parent=self)
+ normalizations = [
+ ('Linear', Colormap.LINEAR),
+ ('Gamma correction', Colormap.GAMMA),
+ ('Arcsinh', Colormap.ARCSINH),
+ ('Logarithmic', Colormap.LOGARITHM),
+ ('Square root', Colormap.SQRT)]
+ for name, userData in normalizations:
+ try:
+ icon = icons.getQIcon("colormap-norm-%s" % userData)
+ except:
+ icon = qt.QIcon()
+ self._comboBoxNormalization.addItem(icon, name, userData)
+ self._comboBoxNormalization.currentIndexChanged[int].connect(
+ self._normalizationUpdated)
+
+ self._gammaSpinBox = qt.QDoubleSpinBox(parent=self)
+ self._gammaSpinBox.setEnabled(False)
+ self._gammaSpinBox.setRange(0., 1000.)
+ self._gammaSpinBox.setDecimals(4)
+ if hasattr(qt.QDoubleSpinBox, "setStepType"):
+ # Introduced in Qt 5.12
+ self._gammaSpinBox.setStepType(qt.QDoubleSpinBox.AdaptiveDecimalStepType)
+ else:
+ self._gammaSpinBox.setSingleStep(0.1)
+ self._gammaSpinBox.valueChanged.connect(self._gammaUpdated)
+ self._gammaSpinBox.setValue(2.)
+
+ autoScaleCombo = _AutoscaleModeComboBox(self)
+ autoScaleCombo.currentIndexChanged.connect(self._autoscaleModeUpdated)
+ self._autoScaleCombo = autoScaleCombo
+
+ # Min row
+ self._minValue = _BoundaryWidget(parent=self, value=1.0)
+ self._minValue.sigAutoScaleChanged.connect(self._minAutoscaleUpdated)
+ self._minValue.sigValueChanged.connect(self._minValueUpdated)
+
+ # Max row
+ self._maxValue = _BoundaryWidget(parent=self, value=10.0)
+ self._maxValue.sigAutoScaleChanged.connect(self._maxAutoscaleUpdated)
+ self._maxValue.sigValueChanged.connect(self._maxValueUpdated)
+
+ self._autoButtons = _AutoScaleButtons(self)
+ self._autoButtons.autoRangeChanged.connect(self._autoRangeButtonsUpdated)
+
+ rangeLayout = qt.QGridLayout()
+ miniFont = qt.QFont(self.font())
+ miniFont.setPixelSize(8)
+ labelMin = qt.QLabel("Min", self)
+ labelMin.setFont(miniFont)
+ labelMin.setAlignment(qt.Qt.AlignHCenter)
+ labelMax = qt.QLabel("Max", self)
+ labelMax.setAlignment(qt.Qt.AlignHCenter)
+ labelMax.setFont(miniFont)
+ rangeLayout.addWidget(labelMin, 0, 0)
+ rangeLayout.addWidget(labelMax, 0, 1)
+ rangeLayout.addWidget(self._minValue, 1, 0)
+ rangeLayout.addWidget(self._maxValue, 1, 1)
+ rangeLayout.addWidget(self._autoButtons, 2, 0, 1, -1, qt.Qt.AlignCenter)
+
+ self._histoWidget = _ColormapHistogram(self)
+ self._histoWidget.sigRangeMoving.connect(self._histogramRangeMoving)
+ self._histoWidget.sigRangeMoved.connect(self._histogramRangeMoved)
+
+ # Scale to buttons
+ self._visibleAreaButton = qt.QPushButton(self)
+ self._visibleAreaButton.setEnabled(False)
+ self._visibleAreaButton.setText("Visible Area")
+ self._visibleAreaButton.clicked.connect(
+ self._handleScaleToVisibleAreaClicked,
+ type=qt.Qt.QueuedConnection)
+
+ # Place-holder for selected area ROI manager
+ self._roiForColormapManager = None
+
+ self._selectedAreaButton = WaitingPushButton(self)
+ self._selectedAreaButton.setEnabled(False)
+ self._selectedAreaButton.setText("Selection")
+ self._selectedAreaButton.setIcon(icons.getQIcon("add-shape-rectangle"))
+ self._selectedAreaButton.setCheckable(True)
+ self._selectedAreaButton.setDisabledWhenWaiting(False)
+ self._selectedAreaButton.toggled.connect(
+ self._handleScaleToSelectionToggled,
+ type=qt.Qt.QueuedConnection)
+
+ # define modal buttons
+ types = qt.QDialogButtonBox.Ok | qt.QDialogButtonBox.Cancel
+ self._buttonsModal = qt.QDialogButtonBox(parent=self)
+ self._buttonsModal.setStandardButtons(types)
+ self._buttonsModal.accepted.connect(self.accept)
+ self._buttonsModal.rejected.connect(self.reject)
+
+ # define non modal buttons
+ types = qt.QDialogButtonBox.Close | qt.QDialogButtonBox.Reset
+ self._buttonsNonModal = qt.QDialogButtonBox(parent=self)
+ self._buttonsNonModal.setStandardButtons(types)
+ button = self._buttonsNonModal.button(qt.QDialogButtonBox.Close)
+ button.clicked.connect(self.accept)
+ button.setDefault(True)
+ button = self._buttonsNonModal.button(qt.QDialogButtonBox.Reset)
+ button.clicked.connect(self.resetColormap)
+
+ self._buttonsModal.setFocus(qt.Qt.OtherFocusReason)
+ self._buttonsNonModal.setFocus(qt.Qt.OtherFocusReason)
+
+ # Set the colormap to default values
+ self.setColormap(Colormap(name='gray', normalization='linear',
+ vmin=None, vmax=None))
+
+ self.setModal(self.isModal())
+
+ formLayout = qt.QFormLayout(self)
+ formLayout.setContentsMargins(10, 10, 10, 10)
+ formLayout.addRow('Colormap:', self._comboBoxColormap)
+ formLayout.addRow('Normalization:', self._comboBoxNormalization)
+ formLayout.addRow('Gamma:', self._gammaSpinBox)
+ formLayout.addRow(self._histoWidget)
+ formLayout.addRow(rangeLayout)
+ label = qt.QLabel('Mode:', self)
+ self._autoscaleModeLabel = label
+ label.setToolTip("Mode for autoscale. Algorithm used to find range in auto scale.")
+ formLayout.addItem(qt.QSpacerItem(1, 1, qt.QSizePolicy.Fixed, qt.QSizePolicy.Fixed))
+ formLayout.addRow(label, autoScaleCombo)
+
+ layout = qt.QHBoxLayout()
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.addWidget(self._visibleAreaButton)
+ layout.addWidget(self._selectedAreaButton)
+ self._scaleToAreaGroup = qt.QGroupBox('Scale to:', self)
+ self._scaleToAreaGroup.setLayout(layout)
+ self._scaleToAreaGroup.setVisible(False)
+ formLayout.addRow(self._scaleToAreaGroup)
+
+ formLayout.addRow(self._buttonsModal)
+ formLayout.addRow(self._buttonsNonModal)
+ formLayout.setSizeConstraint(qt.QLayout.SetMinimumSize)
+
+ self.setTabOrder(self._comboBoxColormap, self._comboBoxNormalization)
+ self.setTabOrder(self._comboBoxNormalization, self._gammaSpinBox)
+ self.setTabOrder(self._gammaSpinBox, self._minValue)
+ self.setTabOrder(self._minValue, self._maxValue)
+ self.setTabOrder(self._maxValue, self._autoButtons)
+ self.setTabOrder(self._autoButtons, self._autoScaleCombo)
+ self.setTabOrder(self._autoScaleCombo, self._visibleAreaButton)
+ self.setTabOrder(self._visibleAreaButton, self._selectedAreaButton)
+ self.setTabOrder(self._selectedAreaButton, self._buttonsModal)
+ self.setTabOrder(self._buttonsModal, self._buttonsNonModal)
+
+ self.setFixedSize(self.sizeHint())
+ self._applyColormap()
+
+ def _invalidateColormap(self):
+ if self.isVisible():
+ self._applyColormap()
+ else:
+ self.__colormapInvalidated = True
+
+ def _invalidateData(self):
+ if self.isVisible():
+ self._updateWidgetRange()
+ self._histoWidget.invalidateData()
+ else:
+ self.__dataInvalidated = True
+
+ def _validate(self):
+ if self.__colormapInvalidated:
+ self._applyColormap()
+ if self.__dataInvalidated:
+ self._histoWidget.invalidateData()
+ if self.__dataInvalidated or self.__colormapInvalidated:
+ self._updateWidgetRange()
+ self.__dataInvalidated = False
+ self.__colormapInvalidated = False
+
+ def showEvent(self, event):
+ self.visibleChanged.emit(True)
+ super(ColormapDialog, self).showEvent(event)
+ if self.isVisible():
+ self._validate()
+
+ def closeEvent(self, event):
+ if not self.isModal():
+ self.accept()
+ super(ColormapDialog, self).closeEvent(event)
+
+ def hideEvent(self, event):
+ self.visibleChanged.emit(False)
+ super(ColormapDialog, self).hideEvent(event)
+
+ def close(self):
+ self.accept()
+ qt.QDialog.close(self)
+
+ def setModal(self, modal):
+ assert type(modal) is bool
+ self._buttonsNonModal.setVisible(not modal)
+ self._buttonsModal.setVisible(modal)
+ qt.QDialog.setModal(self, modal)
+
+ def event(self, event):
+ if event.type() == qt.QEvent.DeferredDelete:
+ self.__aboutToDelete = True
+ return super(ColormapDialog, self).event(event)
+
+ def exec(self):
+ wasModal = self.isModal()
+ self.setModal(True)
+ result = super(ColormapDialog, self).exec()
+ if not self.__aboutToDelete:
+ self.setModal(wasModal)
+ return result
+
+ def exec_(self): # Qt5 compatibility wrapper
+ return self.exec()
+
+ def _getFiniteColormapRange(self):
+ """Return a colormap range where auto ranges are fixed
+ according to the available data.
+ """
+ colormap = self.getColormap()
+ if colormap is None:
+ return 1, 10
+
+ item = self._getItem()
+ if item is not None:
+ return colormap.getColormapRange(item)
+ # If there is not item, there is no data
+ return colormap.getColormapRange(None)
+
+ @staticmethod
+ def computeDataRange(data):
+ """Compute the data range as used by :meth:`setDataRange`.
+
+ :param data: The data to process
+ :rtype: List[Union[None,float]]
+ """
+ if data is None or len(data) == 0:
+ return None, None, None
+
+ dataRange = min_max(data, min_positive=True, finite=True)
+ if dataRange.minimum is None:
+ # Only non-finite data
+ dataRange = None
+
+ if dataRange is not None:
+ dataRange = dataRange.minimum, dataRange.min_positive, dataRange.maximum
+
+ if dataRange is None or len(dataRange) != 3:
+ qt.QMessageBox.warning(
+ None, "No Data",
+ "Image data does not contain any real value")
+ dataRange = 1., 1., 10.
+
+ return dataRange
+
+ @staticmethod
+ def computeHistogram(data, scale=Axis.LINEAR, dataRange=None):
+ """Compute the data histogram as used by :meth:`setHistogram`.
+
+ :param data: The data to process
+ :param dataRange: Optional range to compute the histogram, which is a
+ tuple of min, max
+ :rtype: Tuple(List(float),List(float)
+ """
+ # For compatibility
+ if scale == Axis.LOGARITHMIC:
+ scale = Colormap.LOGARITHM
+
+ if data is None:
+ return None, None
+
+ if len(data) == 0:
+ return None, None
+
+ if data.ndim == 3: # RGB(A) images
+ _logger.info('Converting current image from RGB(A) to grayscale\
+ in order to compute the intensity distribution')
+ data = (data[:,:, 0] * 0.299 +
+ data[:,:, 1] * 0.587 +
+ data[:,:, 2] * 0.114)
+
+ # bad hack: get 256 continuous bins in the case we have a B&W
+ normalizeData = True
+ if numpy.issubdtype(data.dtype, numpy.ubyte):
+ normalizeData = False
+ elif numpy.issubdtype(data.dtype, numpy.integer):
+ if dataRange is not None:
+ xmin, xmax = dataRange
+ if xmin is not None and xmax is not None:
+ normalizeData = (xmax - xmin) > 255
+
+ if normalizeData:
+ if scale == Colormap.LOGARITHM:
+ with numpy.errstate(divide='ignore', invalid='ignore'):
+ data = numpy.log10(data)
+
+ if dataRange is not None:
+ xmin, xmax = dataRange
+ if xmin is None:
+ return None, None
+ if normalizeData:
+ if scale == Colormap.LOGARITHM:
+ xmin, xmax = numpy.log10(xmin), numpy.log10(xmax)
+ else:
+ xmin, xmax = min_max(data, min_positive=False, finite=True)
+
+ if xmin is None:
+ return None, None
+
+ nbins = min(256, int(numpy.sqrt(data.size)))
+ data_range = xmin, xmax
+
+ # bad hack: get 256 bins in the case we have a B&W
+ if numpy.issubdtype(data.dtype, numpy.integer):
+ if nbins > xmax - xmin:
+ nbins = int(xmax - xmin)
+
+ nbins = max(2, nbins)
+ data = data.ravel().astype(numpy.float32)
+
+ histogram = Histogramnd(data, n_bins=nbins, histo_range=data_range)
+ bins = histogram.edges[0]
+ if normalizeData:
+ if scale == Colormap.LOGARITHM:
+ bins = 10 ** bins
+ return histogram.histo, bins
+
+ def _getItem(self):
+ if self._itemHolder is not None:
+ return self._itemHolder
+ if self._item is None:
+ return None
+ return self._item()
+
+ def setItem(self, item):
+ """Store the plot item.
+
+ According to the state of the dialog, the item will be used to display
+ the data range or the histogram of the data using :meth:`setDataRange`
+ and :meth:`setHistogram`
+ """
+ # While event from items are not supported, we can't ignore dup items
+ # old = self._getItem()
+ # if old is item:
+ # return
+ self._data = None
+ self._itemHolder = None
+ try:
+ if item is None:
+ self._item = None
+ else:
+ if not isinstance(item, items.ColormapMixIn):
+ self._item = None
+ raise ValueError("Item %s is not supported" % item)
+ self._item = weakref.ref(item, self._itemAboutToFinalize)
+ finally:
+ self._syncScaleToButtonsEnabled()
+ self._dataRange = None
+ self._histogramData = None
+ self._invalidateData()
+
+ def _getData(self):
+ if self._data is None:
+ return None
+ return self._data()
+
+ def setData(self, data):
+ """Store the data
+
+ According to the state of the dialog, the data will be used to display
+ the data range or the histogram of the data using :meth:`setDataRange`
+ and :meth:`setHistogram`
+ """
+ oldData = self._getData()
+ if oldData is data:
+ return
+
+ self._item = None
+ self._syncScaleToButtonsEnabled()
+ if data is None:
+ self._data = None
+ self._itemHolder = None
+ else:
+ self._data = weakref.ref(data, self._dataAboutToFinalize)
+ self._itemHolder = _DataRefHolder(self._data)
+
+ self._dataRange = None
+ self._histogramData = None
+
+ self._invalidateData()
+
+ def _getArray(self):
+ data = self._getData()
+ if data is not None:
+ return data
+ item = self._getItem()
+ if item is not None:
+ return item.getColormappedData(copy=False)
+ return None
+
+ def _colormapAboutToFinalize(self, weakrefColormap):
+ """Callback when the data weakref is about to be finalized."""
+ if self._colormap is weakrefColormap and qtinspect.isValid(self):
+ self.setColormap(None)
+
+ def _dataAboutToFinalize(self, weakrefData):
+ """Callback when the data weakref is about to be finalized."""
+ if self._data is weakrefData and qtinspect.isValid(self):
+ self.setData(None)
+
+ def _itemAboutToFinalize(self, weakref):
+ """Callback when the data weakref is about to be finalized."""
+ if self._item is weakref and qtinspect.isValid(self):
+ self.setItem(None)
+
+ @deprecation.deprecated(reason="It is private data", since_version="0.13")
+ def getHistogram(self):
+ histo = self._getHistogram()
+ if histo is None:
+ return None
+ counts, bin_edges = histo
+ return numpy.array(counts, copy=True), numpy.array(bin_edges, copy=True)
+
+ def _getHistogram(self):
+ """Returns the histogram defined by the dialog as metadata
+ to describe the data in order to speed up the dialog.
+
+ :return: (hist, bin_edges)
+ :rtype: 2-tuple of numpy arrays"""
+ return self._histogramData
+
+ def setHistogram(self, hist=None, bin_edges=None):
+ """Set the histogram to display.
+
+ This update the data range with the bounds of the bins.
+
+ :param hist: array-like of counts or None to hide histogram
+ :param bin_edges: array-like of bins edges or None to hide histogram
+ """
+ if hist is None or bin_edges is None:
+ self._histogramData = None
+ else:
+ self._histogramData = numpy.array(hist), numpy.array(bin_edges)
+
+ self._invalidateData()
+
+ def getColormap(self):
+ """Return the colormap description.
+
+ :rtype: ~silx.gui.colors.Colormap
+ """
+ if self._colormap is None:
+ return None
+ return self._colormap()
+
+ def resetColormap(self):
+ """
+ Reset the colormap state before modification.
+
+ ..note :: the colormap reference state is the state when set or the
+ state when validated
+ """
+ colormap = self.getColormap()
+ if colormap is not None and self._colormapStoredState is not None:
+ if colormap != self._colormapStoredState:
+ with self._colormapChange:
+ colormap.setFromColormap(self._colormapStoredState)
+ self._applyColormap()
+
+ def _getDataRange(self):
+ """Returns the data range defined by the dialog as metadata
+ to describe the data in order to speed up the dialog.
+
+ :return: (minimum, positiveMin, maximum)
+ :rtype: 3-tuple of floats or None"""
+ return self._dataRange
+
+ def setDataRange(self, minimum=None, positiveMin=None, maximum=None):
+ """Set the range of data to use for the range of the histogram area.
+
+ :param float minimum: The minimum of the data
+ :param float positiveMin: The positive minimum of the data
+ :param float maximum: The maximum of the data
+ """
+ self._dataRange = minimum, positiveMin, maximum
+ self._invalidateData()
+
+ def _setColormapRange(self, xmin, xmax):
+ """Set a new range to the held colormap and update the
+ widget."""
+ colormap = self.getColormap()
+ if colormap is not None:
+ with self._colormapChange:
+ colormap.setVRange(xmin, xmax)
+ self._updateWidgetRange()
+
+ def setColormapRangeFromDataBounds(self, bounds):
+ """Set the range of the colormap from current item and rect.
+
+ If there is no ColormapMixIn item attached to the ColormapDialog,
+ nothing is done.
+
+ :param Union[List[float],None] bounds:
+ (xmin, xmax, ymin, ymax) Rectangular region in data space
+ """
+ if bounds is None:
+ return None # no-op
+
+ colormap = self.getColormap()
+ if colormap is None:
+ return # no-op
+
+ item = self._getItem()
+ if not isinstance(item, items.ColormapMixIn):
+ return None # no-op
+
+ data = item.getColormappedData(copy=False)
+
+ xmin, xmax, ymin, ymax = bounds
+
+ if isinstance(item, items.ImageBase):
+ ox, oy = item.getOrigin()
+ sx, sy = item.getScale()
+
+ ystart = max(0, int((ymin - oy) / sy))
+ ystop = max(0, int(numpy.ceil((ymax - oy) / sy)))
+ xstart = max(0, int((xmin - ox) / sx))
+ xstop = max(0, int(numpy.ceil((xmax - ox) / sx)))
+
+ subset = data[ystart:ystop, xstart:xstop]
+
+ elif isinstance(item, items.Scatter):
+ x = item.getXData(copy=False)
+ y = item.getYData(copy=False)
+ subset = data[
+ numpy.logical_and(
+ numpy.logical_and(xmin <= x, x <= xmax),
+ numpy.logical_and(ymin <= y, y <= ymax))]
+
+ if subset.size == 0:
+ return # no-op
+
+ vmin, vmax = colormap._computeAutoscaleRange(subset)
+ self._setColormapRange(vmin, vmax)
+
+ def _updateWidgetRange(self):
+ """Update the colormap range displayed into the widget."""
+ xmin, xmax = self._getFiniteColormapRange()
+ colormap = self.getColormap()
+ if colormap is not None:
+ vRange = colormap.getVRange()
+ autoMin, autoMax = (r is None for r in vRange)
+ else:
+ autoMin, autoMax = False, False
+
+ with utils.blockSignals(self._minValue):
+ self._minValue.setValue(xmin, autoMin)
+ with utils.blockSignals(self._maxValue):
+ self._maxValue.setValue(xmax, autoMax)
+ with utils.blockSignals(self._histoWidget):
+ self._histoWidget.setFiniteRange((xmin, xmax))
+ with utils.blockSignals(self._autoButtons):
+ self._autoButtons.setAutoRange((autoMin, autoMax))
+ self._autoscaleModeLabel.setEnabled(autoMin or autoMax)
+
+ def accept(self):
+ self.storeCurrentState()
+ qt.QDialog.accept(self)
+
+ def storeCurrentState(self):
+ """
+ save the current value sof the colormap if the user want to undo is
+ modifications
+ """
+ colormap = self.getColormap()
+ if colormap is not None:
+ self._colormapStoredState = colormap.copy()
+ else:
+ self._colormapStoredState = None
+
+ def reject(self):
+ self.resetColormap()
+ qt.QDialog.reject(self)
+
+ def setColormap(self, colormap):
+ """Set the colormap description
+
+ :param ~silx.gui.colors.Colormap colormap: the colormap to edit
+ """
+ assert colormap is None or isinstance(colormap, Colormap)
+ if self._colormapChange.locked():
+ return
+
+ oldColormap = self.getColormap()
+ if oldColormap is colormap:
+ return
+ if oldColormap is not None:
+ oldColormap.sigChanged.disconnect(self._applyColormap)
+
+ if colormap is not None:
+ colormap.sigChanged.connect(self._applyColormap)
+ colormap = weakref.ref(colormap, self._colormapAboutToFinalize)
+
+ self._colormap = colormap
+ self.storeCurrentState()
+ self._invalidateColormap()
+
+ def _updateResetButton(self):
+ resetButton = self._buttonsNonModal.button(qt.QDialogButtonBox.Reset)
+ rStateEnabled = False
+ colormap = self.getColormap()
+ if colormap is not None and colormap.isEditable():
+ # can reset only in the case the colormap changed
+ rStateEnabled = colormap != self._colormapStoredState
+ resetButton.setEnabled(rStateEnabled)
+
+ def _applyColormap(self):
+ self._updateResetButton()
+ if self._colormapChange.locked():
+ return
+
+ self._syncScaleToButtonsEnabled()
+
+ colormap = self.getColormap()
+ if colormap is None:
+ self._comboBoxColormap.setEnabled(False)
+ self._comboBoxNormalization.setEnabled(False)
+ self._gammaSpinBox.setEnabled(False)
+ self._autoScaleCombo.setEnabled(False)
+ self._minValue.setEnabled(False)
+ self._maxValue.setEnabled(False)
+ self._autoButtons.setEnabled(False)
+ self._autoscaleModeLabel.setEnabled(False)
+ self._histoWidget.setVisible(False)
+ self._histoWidget.setFiniteRange((None, None))
+ else:
+ assert colormap.getNormalization() in Colormap.NORMALIZATIONS
+ with utils.blockSignals(self._comboBoxColormap):
+ self._comboBoxColormap.setCurrentLut(colormap)
+ self._comboBoxColormap.setEnabled(colormap.isEditable())
+ with utils.blockSignals(self._comboBoxNormalization):
+ index = self._comboBoxNormalization.findData(
+ colormap.getNormalization())
+ if index < 0:
+ _logger.error('Unsupported normalization: %s' %
+ colormap.getNormalization())
+ else:
+ self._comboBoxNormalization.setCurrentIndex(index)
+ self._comboBoxNormalization.setEnabled(colormap.isEditable())
+ with utils.blockSignals(self._gammaSpinBox):
+ self._gammaSpinBox.setValue(
+ colormap.getGammaNormalizationParameter())
+ self._gammaSpinBox.setEnabled(
+ colormap.getNormalization() == 'gamma' and
+ colormap.isEditable())
+ with utils.blockSignals(self._autoScaleCombo):
+ self._autoScaleCombo.setCurrentMode(colormap.getAutoscaleMode())
+ self._autoScaleCombo.setEnabled(colormap.isEditable())
+ with utils.blockSignals(self._autoButtons):
+ self._autoButtons.setEnabled(colormap.isEditable())
+ self._autoButtons.setAutoRangeFromColormap(colormap)
+
+ vmin, vmax = colormap.getVRange()
+ if vmin is None or vmax is None:
+ # Compute it only if needed
+ dataRange = self._getFiniteColormapRange()
+ else:
+ dataRange = vmin, vmax
+
+ with utils.blockSignals(self._minValue):
+ self._minValue.setValue(vmin or dataRange[0], isAuto=vmin is None)
+ self._minValue.setEnabled(colormap.isEditable())
+ with utils.blockSignals(self._maxValue):
+ self._maxValue.setValue(vmax or dataRange[1], isAuto=vmax is None)
+ self._maxValue.setEnabled(colormap.isEditable())
+ self._autoscaleModeLabel.setEnabled(vmin is None or vmax is None)
+
+ with utils.blockSignals(self._histoWidget):
+ self._histoWidget.setVisible(True)
+ self._histoWidget.setFiniteRange(dataRange)
+ self._histoWidget.updateNormalization()
+
+ def _comboBoxColormapUpdated(self):
+ """Callback executed when the combo box with the colormap LUT
+ is updated by user input.
+ """
+ colormap = self.getColormap()
+ if colormap is not None:
+ with self._colormapChange:
+ name = self._comboBoxColormap.getCurrentName()
+ if name is not None:
+ colormap.setName(name)
+ else:
+ lut = self._comboBoxColormap.getCurrentColors()
+ colormap.setColormapLUT(lut)
+ self._histoWidget.updateLut()
+
+ def _autoRangeButtonsUpdated(self, autoRange):
+ """Callback executed when the autoscale buttons widget
+ is updated by user input.
+ """
+ dataRange = self._getFiniteColormapRange()
+
+ # Final colormap range
+ vmin = (dataRange[0] if not autoRange[0] else None)
+ vmax = (dataRange[1] if not autoRange[1] else None)
+
+ with self._colormapChange:
+ colormap = self.getColormap()
+ colormap.setVRange(vmin, vmax)
+
+ with utils.blockSignals(self._minValue):
+ self._minValue.setValue(vmin or dataRange[0], isAuto=vmin is None)
+ with utils.blockSignals(self._maxValue):
+ self._maxValue.setValue(vmax or dataRange[1], isAuto=vmax is None)
+
+ self._updateWidgetRange()
+
+ def _normalizationUpdated(self, index):
+ """Callback executed when the normalization widget
+ is updated by user input.
+ """
+ colormap = self.getColormap()
+ if colormap is not None:
+ normalization = self._comboBoxNormalization.itemData(index)
+ self._gammaSpinBox.setEnabled(normalization == 'gamma')
+
+ with self._colormapChange:
+ colormap.setNormalization(normalization)
+ self._histoWidget.updateNormalization()
+
+ self._updateWidgetRange()
+
+ def _gammaUpdated(self, value):
+ """Callback used to update the gamma normalization parameter"""
+ colormap = self.getColormap()
+ if colormap is not None:
+ colormap.setGammaNormalizationParameter(value)
+
+ def _autoscaleModeUpdated(self):
+ """Callback executed when the autoscale mode widget
+ is updated by user input.
+ """
+ mode = self._autoScaleCombo.currentMode()
+
+ colormap = self.getColormap()
+ if colormap is not None:
+ with self._colormapChange:
+ colormap.setAutoscaleMode(mode)
+
+ self._updateWidgetRange()
+
+ def _minAutoscaleUpdated(self, autoEnabled):
+ """Callback executed when the min autoscale from
+ the lineedit is updated by user input"""
+ colormap = self.getColormap()
+ xmin, xmax = colormap.getVRange()
+ if autoEnabled:
+ xmin = None
+ else:
+ xmin, _xmax = self._getFiniteColormapRange()
+ self._setColormapRange(xmin, xmax)
+
+ def _maxAutoscaleUpdated(self, autoEnabled):
+ """Callback executed when the max autoscale from
+ the lineedit is updated by user input"""
+ colormap = self.getColormap()
+ xmin, xmax = colormap.getVRange()
+ if autoEnabled:
+ xmax = None
+ else:
+ _xmin, xmax = self._getFiniteColormapRange()
+ self._setColormapRange(xmin, xmax)
+
+ def _minValueUpdated(self, value):
+ """Callback executed when the lineedit min value is
+ updated by user input"""
+ xmin = value
+ xmax = self._maxValue.getValue()
+ if xmax is not None and xmin > xmax:
+ # FIXME: This should be done in the widget itself
+ xmin = xmax
+ with utils.blockSignals(self._minValue):
+ self._minValue.setValue(xmin)
+ self._setColormapRange(xmin, xmax)
+
+ def _maxValueUpdated(self, value):
+ """Callback executed when the lineedit max value is
+ updated by user input"""
+ xmin = self._minValue.getValue()
+ xmax = value
+ if xmin is not None and xmin > xmax:
+ # FIXME: This should be done in the widget itself
+ xmax = xmin
+ with utils.blockSignals(self._maxValue):
+ self._maxValue.setValue(xmax)
+ self._setColormapRange(xmin, xmax)
+
+ def _histogramRangeMoving(self, vmin, vmax):
+ """Callback executed when for colormap range displayed in
+ the histogram widget is moving.
+
+ :param vmin: Update of the minimum range, else None
+ :param vmax: Update of the maximum range, else None
+ """
+ colormap = self.getColormap()
+ if vmin is not None:
+ with self._colormapChange:
+ colormap.setVMin(vmin)
+ self._minValue.setValue(vmin)
+ if vmax is not None:
+ with self._colormapChange:
+ colormap.setVMax(vmax)
+ self._maxValue.setValue(vmax)
+
+ def _histogramRangeMoved(self, vmin, vmax):
+ """Callback executed when for colormap range displayed in
+ the histogram widget has finished to move
+ """
+ xmin = self._minValue.getValue()
+ xmax = self._maxValue.getValue()
+ if vmin is None:
+ vmin = xmin
+ if vmax is None:
+ vmax = xmax
+ self._setColormapRange(vmin, vmax)
+
+ def _syncScaleToButtonsEnabled(self):
+ """Set the state of scale to buttons according to current item and colormap"""
+ colormap = self.getColormap()
+ enabled = self._item is not None and colormap is not None and colormap.isEditable()
+ self._scaleToAreaGroup.setVisible(enabled)
+ self._visibleAreaButton.setEnabled(enabled)
+ if not enabled:
+ self._selectedAreaButton.setChecked(False)
+ self._selectedAreaButton.setEnabled(enabled)
+
+ def _handleScaleToVisibleAreaClicked(self):
+ """Set colormap range from current item's visible area"""
+ item = self._getItem()
+ if item is None:
+ return # no-op
+
+ bounds = item.getVisibleBounds()
+ if bounds is None:
+ return # no-op
+
+ self.setColormapRangeFromDataBounds(bounds)
+
+ def _handleScaleToSelectionToggled(self, checked=False):
+ """Handle toggle of scale to selected are button"""
+ # Reset any previous ROI manager
+ if self._roiForColormapManager is not None:
+ self._roiForColormapManager.clear()
+ self._roiForColormapManager.stop()
+ self._roiForColormapManager = None
+
+ if not checked: # Reset button status
+ self._selectedAreaButton.setWaiting(False)
+ self._selectedAreaButton.setText("Selection")
+ return
+
+ item = self._getItem()
+ if item is None:
+ self._selectedAreaButton.setChecked(False)
+ return # no-op
+
+ plotWidget = item.getPlot()
+ if plotWidget is None:
+ self._selectedAreaButton.setChecked(False)
+ return # no-op
+
+ self._selectedAreaButton.setWaiting(True)
+ self._selectedAreaButton.setText("Draw Area...")
+
+ self._roiForColormapManager = RegionOfInterestManager(parent=plotWidget)
+ cmap = self.getColormap()
+ self._roiForColormapManager.setColor(
+ 'black' if cmap is None else cursorColorForColormap(cmap.getName()))
+ self._roiForColormapManager.sigInteractiveModeFinished.connect(
+ self.__roiInteractiveModeFinished)
+ self._roiForColormapManager.sigInteractiveRoiFinalized.connect(self.__roiFinalized)
+ self._roiForColormapManager.start(RectangleROI)
+
+ def __roiInteractiveModeFinished(self):
+ self._selectedAreaButton.setChecked(False)
+
+ def __roiFinalized(self, roi):
+ self._selectedAreaButton.setChecked(False)
+ if roi is not None:
+ ox, oy = roi.getOrigin()
+ width, height = roi.getSize()
+ self.setColormapRangeFromDataBounds((ox, ox+width, oy, oy+height))
+
+ def keyPressEvent(self, event):
+ """Override key handling.
+
+ It disables leaving the dialog when editing a text field.
+
+ But several press of Return key can be use to validate and close the
+ dialog.
+ """
+ if event.key() in (qt.Qt.Key_Enter, qt.Qt.Key_Return):
+ # Bypass QDialog keyPressEvent
+ # To avoid leaving the dialog when pressing enter on a text field
+ if self._minValue.hasFocus():
+ nextFocus = self._maxValue
+ elif self._maxValue.hasFocus():
+ if self.isModal():
+ nextFocus = self._buttonsModal.button(qt.QDialogButtonBox.Apply)
+ else:
+ nextFocus = self._buttonsNonModal.button(qt.QDialogButtonBox.Close)
+ else:
+ nextFocus = None
+ if nextFocus is not None:
+ nextFocus.setFocus(qt.Qt.OtherFocusReason)
+ else:
+ super(ColormapDialog, self).keyPressEvent(event)
diff --git a/src/silx/gui/dialog/DataFileDialog.py b/src/silx/gui/dialog/DataFileDialog.py
new file mode 100644
index 0000000..0d0382d
--- /dev/null
+++ b/src/silx/gui/dialog/DataFileDialog.py
@@ -0,0 +1,340 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+This module contains an :class:`DataFileDialog`.
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "14/02/2018"
+
+import enum
+import logging
+from silx.gui import qt
+from silx.gui.hdf5.Hdf5Formatter import Hdf5Formatter
+import silx.io
+from .AbstractDataFileDialog import AbstractDataFileDialog
+
+import fabio
+
+
+_logger = logging.getLogger(__name__)
+
+
+class _DataPreview(qt.QWidget):
+ """Provide a preview of the selected image"""
+
+ def __init__(self, parent=None):
+ super(_DataPreview, self).__init__(parent)
+
+ self.__formatter = Hdf5Formatter(self)
+ self.__data = None
+ self.__info = qt.QTableView(self)
+ self.__model = qt.QStandardItemModel(self)
+ self.__info.setModel(self.__model)
+ self.__info.horizontalHeader().hide()
+ self.__info.horizontalHeader().setStretchLastSection(True)
+ layout = qt.QVBoxLayout()
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.addWidget(self.__info)
+ self.setLayout(layout)
+
+ def colormap(self):
+ return None
+
+ def setColormap(self, colormap):
+ # Ignored
+ pass
+
+ def sizeHint(self):
+ return qt.QSize(200, 200)
+
+ def setData(self, data, fromDataSelector=False):
+ self.__info.setEnabled(data is not None)
+ if data is None:
+ self.__model.clear()
+ else:
+ self.__model.clear()
+
+ if silx.io.is_dataset(data):
+ kind = "Dataset"
+ elif silx.io.is_group(data):
+ kind = "Group"
+ elif silx.io.is_file(data):
+ kind = "File"
+ else:
+ kind = "Unknown"
+
+ headers = []
+
+ basename = data.name.split("/")[-1]
+ if basename == "":
+ basename = "/"
+ headers.append("Basename")
+ self.__model.appendRow([qt.QStandardItem(basename)])
+ headers.append("Kind")
+ self.__model.appendRow([qt.QStandardItem(kind)])
+ if hasattr(data, "dtype"):
+ headers.append("Type")
+ text = self.__formatter.humanReadableType(data)
+ self.__model.appendRow([qt.QStandardItem(text)])
+ if hasattr(data, "shape"):
+ headers.append("Shape")
+ text = self.__formatter.humanReadableShape(data)
+ self.__model.appendRow([qt.QStandardItem(text)])
+ if hasattr(data, "attrs") and "NX_class" in data.attrs:
+ headers.append("NX_class")
+ value = data.attrs["NX_class"]
+ formatter = self.__formatter.textFormatter()
+ old = formatter.useQuoteForText()
+ formatter.setUseQuoteForText(False)
+ text = self.__formatter.textFormatter().toString(value)
+ formatter.setUseQuoteForText(old)
+ self.__model.appendRow([qt.QStandardItem(text)])
+ self.__model.setVerticalHeaderLabels(headers)
+ self.__data = data
+
+ def __imageItem(self):
+ image = self.__plot.getImage("data")
+ return image
+
+ def data(self):
+ if self.__data is not None:
+ if hasattr(self.__data, "name"):
+ # in case of HDF5
+ if self.__data.name is None:
+ # The dataset was closed
+ self.__data = None
+ return self.__data
+
+ def clear(self):
+ self.__data = None
+ self.__info.setText("")
+
+
+class DataFileDialog(AbstractDataFileDialog):
+ """The `DataFileDialog` class provides a dialog that allow users to select
+ any datasets or groups from an HDF5-like file.
+
+ The `DataFileDialog` class enables a user to traverse the file system in
+ order to select an HDF5-like file. Then to traverse the file to select an
+ HDF5 node.
+
+ .. image:: img/datafiledialog.png
+
+ The selected data is any kind of group or dataset. It can be restricted
+ to only existing datasets or only existing groups using
+ :meth:`setFilterMode`. A callback can be defining using
+ :meth:`setFilterCallback` to filter even more data which can be returned.
+
+ Filtering data which can be returned by a `DataFileDialog` can be done like
+ that:
+
+ .. code-block:: python
+
+ # Force to return only a dataset
+ dialog = DataFileDialog()
+ dialog.setFilterMode(DataFileDialog.FilterMode.ExistingDataset)
+
+ .. code-block:: python
+
+ def customFilter(obj):
+ if "NX_class" in obj.attrs:
+ return obj.attrs["NX_class"] in [b"NXentry", u"NXentry"]
+ return False
+
+ # Force to return an NX entry
+ dialog = DataFileDialog()
+ # 1st, filter out everything which is not a group
+ dialog.setFilterMode(DataFileDialog.FilterMode.ExistingGroup)
+ # 2nd, check what NX_class is an NXentry
+ dialog.setFilterCallback(customFilter)
+
+ Executing a `DataFileDialog` can be done like that:
+
+ .. code-block:: python
+
+ dialog = DataFileDialog()
+ result = dialog.exec()
+ if result:
+ print("Selection:")
+ print(dialog.selectedFile())
+ print(dialog.selectedUrl())
+ else:
+ print("Nothing selected")
+
+ If the selection is a dataset you can access to the data using
+ :meth:`selectedData`.
+
+ If the selection is a group or if you want to read the selected object on
+ your own you can use the `silx.io` API.
+
+ .. code-block:: python
+
+ url = dialog.selectedUrl()
+ with silx.io.open(url) as data:
+ pass
+
+ Or by loading the file first
+
+ .. code-block:: python
+
+ url = dialog.selectedDataUrl()
+ with silx.io.open(url.file_path()) as h5:
+ data = h5[url.data_path()]
+
+ Or by using `h5py` library
+
+ .. code-block:: python
+
+ url = dialog.selectedDataUrl()
+ with h5py.File(url.file_path(), mode="r") as h5:
+ data = h5[url.data_path()]
+ """
+
+ class FilterMode(enum.Enum):
+ """This enum is used to indicate what the user may select in the
+ dialog; i.e. what the dialog will return if the user clicks OK."""
+
+ AnyNode = 0
+ """Any existing node from an HDF5-like file."""
+ ExistingDataset = 1
+ """An existing HDF5-like dataset."""
+ ExistingGroup = 2
+ """An existing HDF5-like group. A file root is a group."""
+
+ def __init__(self, parent=None):
+ AbstractDataFileDialog.__init__(self, parent=parent)
+ self.__filter = DataFileDialog.FilterMode.AnyNode
+ self.__filterCallback = None
+
+ def selectedData(self):
+ """Returns the selected data by using the :meth:`silx.io.get_data`
+ API with the selected URL provided by the dialog.
+
+ If the URL identify a group of a file it will raise an exception. For
+ group or file you have to use on your own the API :meth:`silx.io.open`.
+
+ :rtype: numpy.ndarray
+ :raise ValueError: If the URL do not link to a dataset
+ """
+ url = self.selectedUrl()
+ return silx.io.get_data(url)
+
+ def _createPreviewWidget(self, parent):
+ previewWidget = _DataPreview(parent)
+ previewWidget.setSizePolicy(qt.QSizePolicy.Expanding, qt.QSizePolicy.Expanding)
+ return previewWidget
+
+ def _createSelectorWidget(self, parent):
+ # There is no selector
+ return None
+
+ def _createPreviewToolbar(self, parent, dataPreviewWidget, dataSelectorWidget):
+ # There is no toolbar
+ return None
+
+ def _isDataSupportable(self, data):
+ """Check if the selected data can be supported at one point.
+
+ If true, the data selector will be checked and it will update the data
+ preview. Else the selecting is disabled.
+
+ :rtype: bool
+ """
+ # Everything is supported
+ return True
+
+ def _isFabioFilesSupported(self):
+ # Everything is supported
+ return False
+
+ def _isDataSupported(self, data):
+ """Check if the data can be returned by the dialog.
+
+ If true, this data can be returned by the dialog and the open button
+ will be enabled. If false the button will be disabled.
+
+ :rtype: bool
+ """
+ if self.__filter == DataFileDialog.FilterMode.AnyNode:
+ accepted = True
+ elif self.__filter == DataFileDialog.FilterMode.ExistingDataset:
+ accepted = silx.io.is_dataset(data)
+ elif self.__filter == DataFileDialog.FilterMode.ExistingGroup:
+ accepted = silx.io.is_group(data)
+ else:
+ raise ValueError("Filter %s is not supported" % self.__filter)
+ if not accepted:
+ return False
+ if self.__filterCallback is not None:
+ try:
+ return self.__filterCallback(data)
+ except Exception:
+ _logger.error("Error while executing custom callback", exc_info=True)
+ return False
+ return True
+
+ def setFilterCallback(self, callback):
+ """Set the filter callback. This filter is applied only if the filter
+ mode (:meth:`filterMode`) first accepts the selected data.
+
+ It is not supposed to be set while the dialog is being used.
+
+ :param callable callback: Define a custom function returning a boolean
+ and taking as argument an h5-like node. If the function returns true
+ the dialog can return the associated URL.
+ """
+ self.__filterCallback = callback
+
+ def setFilterMode(self, mode):
+ """Set the filter mode.
+
+ It is not supposed to be set while the dialog is being used.
+
+ :param DataFileDialog.FilterMode mode: The new filter.
+ """
+ self.__filter = mode
+
+ def fileMode(self):
+ """Returns the filter mode.
+
+ :rtype: DataFileDialog.FilterMode
+ """
+ return self.__filter
+
+ def _displayedDataInfo(self, dataBeforeSelection, dataAfterSelection):
+ """Returns the text displayed under the data preview.
+
+ This zone is used to display error in case or problem of data selection
+ or problems with IO.
+
+ :param numpy.ndarray dataAfterSelection: Data as it is after the
+ selection widget (basically the data from the preview widget)
+ :param numpy.ndarray dataAfterSelection: Data as it is before the
+ selection widget (basically the data from the browsing widget)
+ :rtype: bool
+ """
+ return u""
diff --git a/src/silx/gui/dialog/DatasetDialog.py b/src/silx/gui/dialog/DatasetDialog.py
new file mode 100644
index 0000000..c5ee295
--- /dev/null
+++ b/src/silx/gui/dialog/DatasetDialog.py
@@ -0,0 +1,122 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a dialog widget to select a HDF5 dataset in a
+tree.
+
+.. autoclass:: DatasetDialog
+ :members: addFile, addGroup, getSelectedDataUrl, setMode
+
+"""
+from .GroupDialog import _Hdf5ItemSelectionDialog
+import silx.io
+from silx.io.url import DataUrl
+
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "05/09/2018"
+
+
+class DatasetDialog(_Hdf5ItemSelectionDialog):
+ """This :class:`QDialog` uses a :class:`silx.gui.hdf5.Hdf5TreeView` to
+ provide a HDF5 dataset selection dialog.
+
+ The information identifying the selected node is provided as a
+ :class:`silx.io.url.DataUrl`.
+
+ Example:
+
+ .. code-block:: python
+
+ dialog = DatasetDialog()
+ dialog.addFile(filepath1)
+ dialog.addFile(filepath2)
+
+ if dialog.exec():
+ print("File path: %s" % dialog.getSelectedDataUrl().file_path())
+ print("HDF5 dataset path : %s " % dialog.getSelectedDataUrl().data_path())
+ else:
+ print("Operation cancelled :(")
+
+ """
+ def __init__(self, parent=None):
+ _Hdf5ItemSelectionDialog.__init__(self, parent)
+
+ # customization for groups
+ self.setWindowTitle("HDF5 dataset selection")
+
+ self._header.setSections([self._model.NAME_COLUMN,
+ self._model.NODE_COLUMN,
+ self._model.LINK_COLUMN,
+ self._model.TYPE_COLUMN,
+ self._model.SHAPE_COLUMN])
+ self._selectDatasetStatusText = "Select a dataset or type a new dataset name"
+
+ def setMode(self, mode):
+ """Set dialog mode DatasetDialog.SaveMode or DatasetDialog.LoadMode
+
+ :param mode: DatasetDialog.SaveMode or DatasetDialog.LoadMode
+ """
+ _Hdf5ItemSelectionDialog.setMode(self, mode)
+ if mode == DatasetDialog.SaveMode:
+ self._selectDatasetStatusText = "Select a dataset or type a new dataset name"
+ elif mode == DatasetDialog.LoadMode:
+ self._selectDatasetStatusText = "Select a dataset"
+
+ def _onActivation(self, idx):
+ # double-click or enter press: filter for datasets
+ nodes = list(self._tree.selectedH5Nodes())
+ node = nodes[0]
+ if silx.io.is_dataset(node.h5py_object):
+ self.accept()
+
+ def _updateUrl(self):
+ # overloaded to filter for datasets
+ nodes = list(self._tree.selectedH5Nodes())
+ newDatasetName = self._lineEditNewItem.text()
+ isDatasetSelected = False
+ if nodes:
+ node = nodes[0]
+ if silx.io.is_dataset(node.h5py_object):
+ data_path = node.local_name
+ isDatasetSelected = True
+ elif silx.io.is_group(node.h5py_object):
+ data_path = node.local_name
+ if newDatasetName.lstrip("/"):
+ if not data_path.endswith("/"):
+ data_path += "/"
+ data_path += newDatasetName.lstrip("/")
+ isDatasetSelected = True
+
+ if isDatasetSelected:
+ self._selectedUrl = DataUrl(file_path=node.local_filename,
+ data_path=data_path)
+ self._okButton.setEnabled(True)
+ self._labelSelection.setText(
+ self._selectedUrl.path())
+ else:
+ self._selectedUrl = None
+ self._okButton.setEnabled(False)
+ self._labelSelection.setText(self._selectDatasetStatusText)
diff --git a/src/silx/gui/dialog/FileTypeComboBox.py b/src/silx/gui/dialog/FileTypeComboBox.py
new file mode 100644
index 0000000..92529bc
--- /dev/null
+++ b/src/silx/gui/dialog/FileTypeComboBox.py
@@ -0,0 +1,226 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+This module contains utilitaries used by other dialog modules.
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "17/01/2019"
+
+import fabio
+import silx.io
+from silx.gui import qt
+
+
+class Codec(object):
+
+ def __init__(self, any_fabio=False, any_silx=False, fabio_codec=None, auto=False):
+ self.__any_fabio = any_fabio
+ self.__any_silx = any_silx
+ self.fabio_codec = fabio_codec
+ self.__auto = auto
+
+ def is_autodetect(self):
+ return self.__auto
+
+ def is_fabio_codec(self):
+ return self.__any_fabio or self.fabio_codec is not None
+
+ def is_silx_codec(self):
+ return self.__any_silx
+
+
+class FileTypeComboBox(qt.QComboBox):
+ """
+ A combobox providing all image file formats supported by fabio and silx.
+
+ It provides access for each fabio codecs individually.
+ """
+
+ EXTENSIONS_ROLE = qt.Qt.UserRole + 1
+
+ CODEC_ROLE = qt.Qt.UserRole + 2
+
+ INDENTATION = u"\u2022 "
+
+ def __init__(self, parent=None):
+ qt.QComboBox.__init__(self, parent)
+ self.__fabioUrlSupported = True
+ self.__initItems()
+
+ def setFabioUrlSupproted(self, isSupported):
+ if self.__fabioUrlSupported == isSupported:
+ return
+ self.__fabioUrlSupported = isSupported
+ self.__initItems()
+
+ def __initItems(self):
+ self.clear()
+ if self.__fabioUrlSupported:
+ self.__insertFabioFormats()
+ self.__insertSilxFormats()
+ self.__insertAllSupported()
+ self.__insertAnyFiles()
+
+ def __insertAnyFiles(self):
+ index = self.count()
+ self.addItem("All files (*)")
+ self.setItemData(index, ["*"], role=self.EXTENSIONS_ROLE)
+ self.setItemData(index, Codec(auto=True), role=self.CODEC_ROLE)
+
+ def __insertAllSupported(self):
+ allExtensions = set([])
+ for index in range(self.count()):
+ ext = self.itemExtensions(index)
+ allExtensions.update(ext)
+ allExtensions = allExtensions - set("*")
+ list(sorted(list(allExtensions)))
+ index = 0
+ self.insertItem(index, "All supported files")
+ self.setItemData(index, allExtensions, role=self.EXTENSIONS_ROLE)
+ self.setItemData(index, Codec(auto=True), role=self.CODEC_ROLE)
+
+ def __insertSilxFormats(self):
+ formats = silx.io.supported_extensions()
+
+ extensions = []
+ allExtensions = set([])
+
+ for description, ext in formats.items():
+ allExtensions.update(ext)
+ if ext == []:
+ ext = ["*"]
+ extensions.append((description, ext, "silx"))
+ extensions = list(sorted(extensions))
+
+ allExtensions = list(sorted(list(allExtensions)))
+ index = self.count()
+ self.addItem("All supported files, using Silx")
+ self.setItemData(index, allExtensions, role=self.EXTENSIONS_ROLE)
+ self.setItemData(index, Codec(any_silx=True), role=self.CODEC_ROLE)
+
+ for e in extensions:
+ index = self.count()
+ if len(e[1]) < 10:
+ self.addItem("%s%s (%s)" % (self.INDENTATION, e[0], " ".join(e[1])))
+ else:
+ self.addItem("%s%s" % (self.INDENTATION, e[0]))
+ codec = Codec(any_silx=True)
+ self.setItemData(index, e[1], role=self.EXTENSIONS_ROLE)
+ self.setItemData(index, codec, role=self.CODEC_ROLE)
+
+ def __insertFabioFormats(self):
+ formats = fabio.fabioformats.get_classes(reader=True)
+
+ from fabio import fabioutils
+ if hasattr(fabioutils, "COMPRESSED_EXTENSIONS"):
+ compressedExtensions = fabioutils.COMPRESSED_EXTENSIONS
+ else:
+ # Support for fabio < 0.9
+ compressedExtensions = set(["gz", "bz2"])
+
+ extensions = []
+ allExtensions = set([])
+
+ def extensionsIterator(reader):
+ for extension in reader.DEFAULT_EXTENSIONS:
+ yield "*.%s" % extension
+ for compressedExtension in compressedExtensions:
+ for extension in reader.DEFAULT_EXTENSIONS:
+ yield "*.%s.%s" % (extension, compressedExtension)
+
+ for reader in formats:
+ if not hasattr(reader, "DESCRIPTION"):
+ continue
+ if not hasattr(reader, "DEFAULT_EXTENSIONS"):
+ continue
+
+ displayext = reader.DEFAULT_EXTENSIONS
+ displayext = ["*.%s" % e for e in displayext]
+ ext = list(extensionsIterator(reader))
+ allExtensions.update(ext)
+ if ext == []:
+ ext = ["*"]
+ extensions.append((reader.DESCRIPTION, displayext, ext, reader.codec_name()))
+ extensions = list(sorted(extensions))
+
+ allExtensions = list(sorted(list(allExtensions)))
+ index = self.count()
+ self.addItem("All supported files, using Fabio")
+ self.setItemData(index, allExtensions, role=self.EXTENSIONS_ROLE)
+ self.setItemData(index, Codec(any_fabio=True), role=self.CODEC_ROLE)
+
+ for e in extensions:
+ description, displayExt, allExt, _codecName = e
+ index = self.count()
+ if len(e[1]) < 10:
+ self.addItem("%s%s (%s)" % (self.INDENTATION, description, " ".join(displayExt)))
+ else:
+ self.addItem("%s%s" % (self.INDENTATION, description))
+ codec = Codec(fabio_codec=_codecName)
+ self.setItemData(index, allExt, role=self.EXTENSIONS_ROLE)
+ self.setItemData(index, codec, role=self.CODEC_ROLE)
+
+ def itemExtensions(self, index):
+ """Returns the extensions associated to an index."""
+ result = self.itemData(index, self.EXTENSIONS_ROLE)
+ if result is None:
+ result = None
+ return result
+
+ def currentExtensions(self):
+ """Returns the current selected extensions."""
+ index = self.currentIndex()
+ return self.itemExtensions(index)
+
+ def indexFromCodec(self, codecName):
+ for i in range(self.count()):
+ codec = self.itemCodec(i)
+ if codecName == "auto":
+ if codec.is_autodetect():
+ return i
+ elif codecName == "silx":
+ if codec.is_silx_codec():
+ return i
+ elif codecName == "fabio":
+ if codec.is_fabio_codec() and codec.fabio_codec is None:
+ return i
+ elif codecName == codec.fabio_codec:
+ return i
+ return -1
+
+ def itemCodec(self, index):
+ """Returns the codec associated to an index."""
+ result = self.itemData(index, self.CODEC_ROLE)
+ if result is None:
+ result = None
+ return result
+
+ def currentCodec(self):
+ """Returns the current selected codec. None if nothing selected
+ or if the item is not a codec"""
+ index = self.currentIndex()
+ return self.itemCodec(index)
diff --git a/src/silx/gui/dialog/GroupDialog.py b/src/silx/gui/dialog/GroupDialog.py
new file mode 100644
index 0000000..e129a51
--- /dev/null
+++ b/src/silx/gui/dialog/GroupDialog.py
@@ -0,0 +1,230 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a dialog widget to select a HDF5 group in a
+tree.
+
+.. autoclass:: GroupDialog
+ :members: addFile, addGroup, getSelectedDataUrl, setMode
+
+"""
+from silx.gui import qt
+from silx.gui.hdf5.Hdf5TreeView import Hdf5TreeView
+import silx.io
+from silx.io.url import DataUrl
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "22/03/2018"
+
+
+class _Hdf5ItemSelectionDialog(qt.QDialog):
+ SaveMode = 1
+ """Mode used to set the HDF5 item selection dialog to *save* mode.
+ This adds a text field to type in a new item name."""
+
+ LoadMode = 2
+ """Mode used to set the HDF5 item selection dialog to *load* mode.
+ Only existing items of the HDF5 file can be selected in this mode."""
+
+ def __init__(self, parent=None):
+ qt.QDialog.__init__(self, parent)
+ self.setWindowTitle("HDF5 item selection")
+
+ self._tree = Hdf5TreeView(self)
+ self._tree.setSelectionMode(qt.QAbstractItemView.SingleSelection)
+ self._tree.activated.connect(self._onActivation)
+ self._tree.selectionModel().selectionChanged.connect(
+ self._onSelectionChange)
+
+ self._model = self._tree.findHdf5TreeModel()
+
+ self._header = self._tree.header()
+
+ self._newItemWidget = qt.QWidget(self)
+ newItemLayout = qt.QVBoxLayout(self._newItemWidget)
+ self._labelNewItem = qt.QLabel(self._newItemWidget)
+ self._labelNewItem.setText("Create new item in selected group (optional):")
+ self._lineEditNewItem = qt.QLineEdit(self._newItemWidget)
+ self._lineEditNewItem.setToolTip(
+ "Specify the name of a new item "
+ "to be created in the selected group.")
+ self._lineEditNewItem.textChanged.connect(
+ self._onNewItemNameChange)
+ newItemLayout.addWidget(self._labelNewItem)
+ newItemLayout.addWidget(self._lineEditNewItem)
+
+ _labelSelectionTitle = qt.QLabel(self)
+ _labelSelectionTitle.setText("Current selection")
+ self._labelSelection = qt.QLabel(self)
+ self._labelSelection.setStyleSheet("color: gray")
+ self._labelSelection.setWordWrap(True)
+ self._labelSelection.setText("Select an item")
+
+ buttonBox = qt.QDialogButtonBox()
+ self._okButton = buttonBox.addButton(qt.QDialogButtonBox.Ok)
+ self._okButton.setEnabled(False)
+ buttonBox.addButton(qt.QDialogButtonBox.Cancel)
+
+ buttonBox.accepted.connect(self.accept)
+ buttonBox.rejected.connect(self.reject)
+
+ vlayout = qt.QVBoxLayout(self)
+ vlayout.addWidget(self._tree)
+ vlayout.addWidget(self._newItemWidget)
+ vlayout.addWidget(_labelSelectionTitle)
+ vlayout.addWidget(self._labelSelection)
+ vlayout.addWidget(buttonBox)
+ self.setLayout(vlayout)
+
+ self.setMinimumWidth(400)
+
+ self._selectedUrl = None
+
+ def _onSelectionChange(self, old, new):
+ self._updateUrl()
+
+ def _onNewItemNameChange(self, text):
+ self._updateUrl()
+
+ def _onActivation(self, idx):
+ # double-click or enter press
+ self.accept()
+
+ def setMode(self, mode):
+ """Set dialog mode DatasetDialog.SaveMode or DatasetDialog.LoadMode
+
+ :param mode: DatasetDialog.SaveMode or DatasetDialog.LoadMode
+ """
+ if mode == self.LoadMode:
+ # hide "Create new item" field
+ self._lineEditNewItem.clear()
+ self._newItemWidget.hide()
+ elif mode == self.SaveMode:
+ self._newItemWidget.show()
+ else:
+ raise ValueError("Invalid DatasetDialog mode %s" % mode)
+
+ def addFile(self, path):
+ """Add a HDF5 file to the tree.
+ All groups it contains will be selectable in the dialog.
+
+ :param str path: File path
+ """
+ self._model.insertFile(path)
+
+ def addGroup(self, group):
+ """Add a HDF5 group to the tree. This group and all its subgroups
+ will be selectable in the dialog.
+
+ :param h5py.Group group: HDF5 group
+ """
+ self._model.insertH5pyObject(group)
+
+ def _updateUrl(self):
+ nodes = list(self._tree.selectedH5Nodes())
+ subgroupName = self._lineEditNewItem.text()
+ if nodes:
+ node = nodes[0]
+ data_path = node.local_name
+ if subgroupName.lstrip("/"):
+ if not data_path.endswith("/"):
+ data_path += "/"
+ data_path += subgroupName.lstrip("/")
+ self._selectedUrl = DataUrl(file_path=node.local_filename,
+ data_path=data_path)
+ self._okButton.setEnabled(True)
+ self._labelSelection.setText(
+ self._selectedUrl.path())
+
+ def getSelectedDataUrl(self):
+ """Return a :class:`DataUrl` with a file path and a data path.
+ Return None if the dialog was cancelled.
+
+ :return: :class:`silx.io.url.DataUrl` object pointing to the
+ selected HDF5 item.
+ """
+ return self._selectedUrl
+
+
+class GroupDialog(_Hdf5ItemSelectionDialog):
+ """This :class:`QDialog` uses a :class:`silx.gui.hdf5.Hdf5TreeView` to
+ provide a HDF5 group selection dialog.
+
+ The information identifying the selected node is provided as a
+ :class:`silx.io.url.DataUrl`.
+
+ Example:
+
+ .. code-block:: python
+
+ dialog = GroupDialog()
+ dialog.addFile(filepath1)
+ dialog.addFile(filepath2)
+
+ if dialog.exec():
+ print("File path: %s" % dialog.getSelectedDataUrl().file_path())
+ print("HDF5 group path : %s " % dialog.getSelectedDataUrl().data_path())
+ else:
+ print("Operation cancelled :(")
+
+ """
+ def __init__(self, parent=None):
+ _Hdf5ItemSelectionDialog.__init__(self, parent)
+
+ # customization for groups
+ self.setWindowTitle("HDF5 group selection")
+
+ self._header.setSections([self._model.NAME_COLUMN,
+ self._model.NODE_COLUMN,
+ self._model.LINK_COLUMN])
+
+ def _onActivation(self, idx):
+ # double-click or enter press: filter for groups
+ nodes = list(self._tree.selectedH5Nodes())
+ node = nodes[0]
+ if silx.io.is_group(node.h5py_object):
+ self.accept()
+
+ def _updateUrl(self):
+ # overloaded to filter for groups
+ nodes = list(self._tree.selectedH5Nodes())
+ subgroupName = self._lineEditNewItem.text()
+ if nodes:
+ node = nodes[0]
+ if silx.io.is_group(node.h5py_object):
+ data_path = node.local_name
+ if subgroupName.lstrip("/"):
+ if not data_path.endswith("/"):
+ data_path += "/"
+ data_path += subgroupName.lstrip("/")
+ self._selectedUrl = DataUrl(file_path=node.local_filename,
+ data_path=data_path)
+ self._okButton.setEnabled(True)
+ self._labelSelection.setText(
+ self._selectedUrl.path())
+ else:
+ self._selectedUrl = None
+ self._okButton.setEnabled(False)
+ self._labelSelection.setText("Select a group")
diff --git a/src/silx/gui/dialog/ImageFileDialog.py b/src/silx/gui/dialog/ImageFileDialog.py
new file mode 100644
index 0000000..83c6d95
--- /dev/null
+++ b/src/silx/gui/dialog/ImageFileDialog.py
@@ -0,0 +1,354 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+This module contains an :class:`ImageFileDialog`.
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "05/03/2019"
+
+import logging
+from silx.gui.plot import actions
+from silx.gui import qt
+from silx.gui.plot.PlotWidget import PlotWidget
+from .AbstractDataFileDialog import AbstractDataFileDialog
+import silx.io
+
+
+_logger = logging.getLogger(__name__)
+
+
+class _ImageSelection(qt.QWidget):
+ """Provide a widget allowing to select an image from an hypercube by
+ selecting a slice."""
+
+ selectionChanged = qt.Signal()
+ """Emitted when the selection change."""
+
+ def __init__(self, parent=None):
+ qt.QWidget.__init__(self, parent)
+ self.__shape = None
+ self.__axis = []
+ layout = qt.QVBoxLayout()
+ self.setLayout(layout)
+
+ def hasVisibleSelectors(self):
+ return self.__visibleSliders > 0
+
+ def isUsed(self):
+ if self.__shape is None:
+ return False
+ return len(self.__shape) > 2
+
+ def getSelectedData(self, data):
+ slicing = self.slicing()
+ image = data[slicing]
+ return image
+
+ def setData(self, data):
+ if data is None:
+ self.__visibleSliders = 0
+ return
+
+ shape = data.shape
+ if self.__shape is not None:
+ # clean up
+ for widget in self.__axis:
+ self.layout().removeWidget(widget)
+ widget.deleteLater()
+ self.__axis = []
+
+ self.__shape = shape
+ self.__visibleSliders = 0
+
+ if shape is not None:
+ # create expected axes
+ for index in range(len(shape) - 2):
+ axis = qt.QSlider(self)
+ axis.setMinimum(0)
+ axis.setMaximum(shape[index] - 1)
+ axis.setOrientation(qt.Qt.Horizontal)
+ if shape[index] == 1:
+ axis.setVisible(False)
+ else:
+ self.__visibleSliders += 1
+
+ axis.valueChanged.connect(self.__axisValueChanged)
+ self.layout().addWidget(axis)
+ self.__axis.append(axis)
+
+ self.selectionChanged.emit()
+
+ def __axisValueChanged(self):
+ self.selectionChanged.emit()
+
+ def slicing(self):
+ slicing = []
+ for axes in self.__axis:
+ slicing.append(axes.value())
+ return tuple(slicing)
+
+ def setSlicing(self, slicing):
+ for i, value in enumerate(slicing):
+ if i > len(self.__axis):
+ break
+ self.__axis[i].setValue(value)
+
+ def selectSlicing(self, slicing):
+ """Select a slicing.
+
+ The provided value could be unconsistent and therefore is not supposed
+ to be retrivable with a getter.
+
+ :param Union[None,Tuple[int]] slicing:
+ """
+ if slicing is None:
+ # Create a default slicing
+ needed = self.__visibleSliders
+ slicing = (0,) * needed
+ if len(slicing) < self.__visibleSliders:
+ slicing = slicing + (0,) * (self.__visibleSliders - len(slicing))
+ self.setSlicing(slicing)
+
+
+class _ImagePreview(qt.QWidget):
+ """Provide a preview of the selected image"""
+
+ def __init__(self, parent=None):
+ super(_ImagePreview, self).__init__(parent)
+
+ self.__data = None
+ self.__plot = PlotWidget(self)
+ self.__plot.setAxesDisplayed(False)
+ self.__plot.setKeepDataAspectRatio(True)
+ layout = qt.QVBoxLayout()
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.addWidget(self.__plot)
+ self.setLayout(layout)
+
+ def resizeEvent(self, event):
+ self.__updateConstraints()
+ return qt.QWidget.resizeEvent(self, event)
+
+ def sizeHint(self):
+ return qt.QSize(200, 200)
+
+ def plot(self):
+ return self.__plot
+
+ def setData(self, data, fromDataSelector=False):
+ if data is None:
+ self.clear()
+ return
+
+ resetzoom = not fromDataSelector
+ previousImage = self.data()
+ if previousImage is not None and data.shape != previousImage.shape:
+ resetzoom = True
+
+ self.__plot.addImage(legend="data", data=data, resetzoom=resetzoom)
+ self.__data = data
+ self.__updateConstraints()
+
+ def __updateConstraints(self):
+ """
+ Update the constraints depending on the size of the widget
+ """
+ image = self.data()
+ if image is None:
+ return
+ size = self.size()
+ if size.width() == 0 or size.height() == 0:
+ return
+
+ heightData, widthData = image.shape
+
+ widthContraint = heightData * size.width() / size.height()
+ if widthContraint > widthData:
+ heightContraint = heightData
+ else:
+ heightContraint = heightData * size.height() / size.width()
+ widthContraint = widthData
+
+ midWidth, midHeight = widthData * 0.5, heightData * 0.5
+ heightContraint, widthContraint = heightContraint * 0.5, widthContraint * 0.5
+
+ axis = self.__plot.getXAxis()
+ axis.setLimitsConstraints(midWidth - widthContraint, midWidth + widthContraint)
+ axis = self.__plot.getYAxis()
+ axis.setLimitsConstraints(midHeight - heightContraint, midHeight + heightContraint)
+
+ def __imageItem(self):
+ image = self.__plot.getImage("data")
+ return image
+
+ def data(self):
+ if self.__data is not None:
+ if hasattr(self.__data, "name"):
+ # in case of HDF5
+ if self.__data.name is None:
+ # The dataset was closed
+ self.__data = None
+ return self.__data
+
+ def colormap(self):
+ image = self.__imageItem()
+ if image is not None:
+ return image.getColormap()
+ return self.__plot.getDefaultColormap()
+
+ def setColormap(self, colormap):
+ self.__plot.setDefaultColormap(colormap)
+
+ def clear(self):
+ self.__data = None
+ image = self.__imageItem()
+ if image is not None:
+ self.__plot.removeImage(legend="data")
+
+
+class ImageFileDialog(AbstractDataFileDialog):
+ """The `ImageFileDialog` class provides a dialog that allow users to select
+ an image from a file.
+
+ The `ImageFileDialog` class enables a user to traverse the file system in
+ order to select one file. Then to traverse the file to select a frame or
+ a slice of a dataset.
+
+ .. image:: img/imagefiledialog_h5.png
+
+ It supports fast access to image files using `FabIO`. Which is not the case
+ of the default silx API. Image files still also can be available using the
+ NeXus layout, by editing the file type combo box.
+
+ .. image:: img/imagefiledialog_edf.png
+
+ The selected data is an numpy array with 2 dimension.
+
+ Using an `ImageFileDialog` can be done like that.
+
+ .. code-block:: python
+
+ dialog = ImageFileDialog()
+ result = dialog.exec()
+ if result:
+ print("Selection:")
+ print(dialog.selectedFile())
+ print(dialog.selectedUrl())
+ print(dialog.selectedImage())
+ else:
+ print("Nothing selected")
+ """
+
+ def selectedImage(self):
+ """Returns the selected image data as numpy
+
+ :rtype: numpy.ndarray
+ """
+ url = self.selectedUrl()
+ return silx.io.get_data(url)
+
+ def _createPreviewWidget(self, parent):
+ previewWidget = _ImagePreview(parent)
+ previewWidget.setSizePolicy(qt.QSizePolicy.Expanding, qt.QSizePolicy.Expanding)
+ return previewWidget
+
+ def _createSelectorWidget(self, parent):
+ return _ImageSelection(parent)
+
+ def _createPreviewToolbar(self, parent, dataPreviewWidget, dataSelectorWidget):
+ plot = dataPreviewWidget.plot()
+ toolbar = qt.QToolBar(parent)
+ toolbar.setIconSize(qt.QSize(16, 16))
+ toolbar.setStyleSheet("QToolBar { border: 0px }")
+ toolbar.addAction(actions.mode.ZoomModeAction(plot, parent))
+ toolbar.addAction(actions.mode.PanModeAction(plot, parent))
+ toolbar.addSeparator()
+ toolbar.addAction(actions.control.ResetZoomAction(plot, parent))
+ toolbar.addSeparator()
+ toolbar.addAction(actions.control.ColormapAction(plot, parent))
+ return toolbar
+
+ def _isDataSupportable(self, data):
+ """Check if the selected data can be supported at one point.
+
+ If true, the data selector will be checked and it will update the data
+ preview. Else the selecting is disabled.
+
+ :rtype: bool
+ """
+ if not hasattr(data, "dtype"):
+ # It is not an HDF5 dataset nor a fabio image wrapper
+ return False
+
+ if data is None or data.shape is None:
+ return False
+
+ if data.dtype.kind not in set(["f", "u", "i", "b"]):
+ return False
+
+ dim = len(data.shape)
+ return dim >= 2
+
+ def _isFabioFilesSupported(self):
+ return True
+
+ def _isDataSupported(self, data):
+ """Check if the data can be returned by the dialog.
+
+ If true, this data can be returned by the dialog and the open button
+ while be enabled. If false the button will be disabled.
+
+ :rtype: bool
+ """
+ dim = len(data.shape)
+ return dim == 2
+
+ def _displayedDataInfo(self, dataBeforeSelection, dataAfterSelection):
+ """Returns the text displayed under the data preview.
+
+ This zone is used to display error in case or problem of data selection
+ or problems with IO.
+
+ :param numpy.ndarray dataAfterSelection: Data as it is after the
+ selection widget (basically the data from the preview widget)
+ :param numpy.ndarray dataAfterSelection: Data as it is before the
+ selection widget (basically the data from the browsing widget)
+ :rtype: bool
+ """
+ destination = self.__formatShape(dataAfterSelection.shape)
+ source = self.__formatShape(dataBeforeSelection.shape)
+ return u"%s \u2192 %s" % (source, destination)
+
+ def __formatShape(self, shape):
+ result = []
+ for s in shape:
+ if isinstance(s, slice):
+ v = u"\u2026"
+ else:
+ v = str(s)
+ result.append(v)
+ return u" \u00D7 ".join(result)
diff --git a/src/silx/gui/dialog/SafeFileIconProvider.py b/src/silx/gui/dialog/SafeFileIconProvider.py
new file mode 100644
index 0000000..1e06b64
--- /dev/null
+++ b/src/silx/gui/dialog/SafeFileIconProvider.py
@@ -0,0 +1,154 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+This module contains :class:`SafeIconProvider`.
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "31/10/2017"
+
+import sys
+import logging
+from silx.gui import qt
+
+
+_logger = logging.getLogger(__name__)
+
+
+class SafeFileIconProvider(qt.QFileIconProvider):
+ """
+ This class reimplement :class:`qt.QFileIconProvider` to avoid blocking
+ access to the file system.
+
+ It avoid to use `qt.QFileInfo.absoluteFilePath` or
+ `qt.QFileInfo.canonicalPath` to reach drive icons which are known to
+ freeze the file system using network drives.
+
+ Computer root, and drive root paths are filtered. Other paths are not
+ filtered while it is anyway needed to synchronoze a drive to accesss to it.
+ """
+
+ WIN32_DRIVE_UNKNOWN = 0
+ """The drive type cannot be determined."""
+ WIN32_DRIVE_NO_ROOT_DIR = 1
+ """The root path is invalid; for example, there is no volume mounted at the
+ specified path."""
+ WIN32_DRIVE_REMOVABLE = 2
+ """The drive has removable media; for example, a floppy drive, thumb drive,
+ or flash card reader."""
+ WIN32_DRIVE_FIXED = 3
+ """The drive has fixed media; for example, a hard disk drive or flash
+ drive."""
+ WIN32_DRIVE_REMOTE = 4
+ """The drive is a remote (network) drive."""
+ WIN32_DRIVE_CDROM = 5
+ """The drive is a CD-ROM drive."""
+ WIN32_DRIVE_RAMDISK = 6
+ """The drive is a RAM disk."""
+
+ def __init__(self):
+ qt.QFileIconProvider.__init__(self)
+ self.__filterDirAndFiles = False
+ if sys.platform == "win32":
+ self._windowsTypes = {}
+ item = "Drive", qt.QStyle.SP_DriveHDIcon
+ self._windowsTypes[self.WIN32_DRIVE_UNKNOWN] = item
+ item = "Invalid root", qt.QStyle.SP_DriveHDIcon
+ self._windowsTypes[self.WIN32_DRIVE_NO_ROOT_DIR] = item
+ item = "Removable", qt.QStyle.SP_DriveNetIcon
+ self._windowsTypes[self.WIN32_DRIVE_REMOVABLE] = item
+ item = "Drive", qt.QStyle.SP_DriveHDIcon
+ self._windowsTypes[self.WIN32_DRIVE_FIXED] = item
+ item = "Remote", qt.QStyle.SP_DriveNetIcon
+ self._windowsTypes[self.WIN32_DRIVE_REMOTE] = item
+ item = "CD-ROM", qt.QStyle.SP_DriveCDIcon
+ self._windowsTypes[self.WIN32_DRIVE_CDROM] = item
+ item = "RAM disk", qt.QStyle.SP_DriveHDIcon
+ self._windowsTypes[self.WIN32_DRIVE_RAMDISK] = item
+
+ def __windowsDriveTypeId(self, info):
+ try:
+ import ctypes
+ path = info.filePath()
+ dtype = ctypes.cdll.kernel32.GetDriveTypeW(path)
+ except Exception:
+ _logger.warning("Impossible to identify drive %s" % path)
+ _logger.debug("Backtrace", exc_info=True)
+ return self.WIN32_DRIVE_UNKNOWN
+ return dtype
+
+ def __windowsDriveIcon(self, info):
+ dtype = self.__windowsDriveTypeId(info)
+ default = self._windowsTypes[self.WIN32_DRIVE_UNKNOWN]
+ driveInfo = self._windowsTypes.get(dtype, default)
+ style = qt.QApplication.instance().style()
+ icon = style.standardIcon(driveInfo[1])
+ return icon
+
+ def __windowsDriveType(self, info):
+ dtype = self.__windowsDriveTypeId(info)
+ default = self._windowsTypes[self.WIN32_DRIVE_UNKNOWN]
+ driveInfo = self._windowsTypes.get(dtype, default)
+ return driveInfo[0]
+
+ def icon(self, info):
+ if isinstance(info, qt.QFileIconProvider.IconType):
+ # It's another C++ method signature:
+ # QIcon QFileIconProvider::icon(QFileIconProvider::IconType type)
+ return super(SafeFileIconProvider, self).icon(info)
+ style = qt.QApplication.instance().style()
+ path = info.filePath()
+ if path in ["", "/"]:
+ # That's the computer root on Windows or Linux
+ result = style.standardIcon(qt.QStyle.SP_ComputerIcon)
+ elif sys.platform == "win32" and path[-2] == ":":
+ # That's a drive on Windows
+ result = self.__windowsDriveIcon(info)
+ elif self.__filterDirAndFiles:
+ if info.isDir():
+ result = style.standardIcon(qt.QStyle.SP_DirIcon)
+ else:
+ result = style.standardIcon(qt.QStyle.SP_FileIcon)
+ else:
+ result = qt.QFileIconProvider.icon(self, info)
+ return result
+
+ def type(self, info):
+ path = info.filePath()
+ if path in ["", "/"]:
+ # That's the computer root on Windows or Linux
+ result = "Computer"
+ elif sys.platform == "win32" and path[-2] == ":":
+ # That's a drive on Windows
+ result = self.__windowsDriveType(info)
+ elif self.__filterDirAndFiles:
+ if info.isDir():
+ result = "Directory"
+ else:
+ result = info.suffix()
+ else:
+ result = qt.QFileIconProvider.type(self, info)
+ return result
diff --git a/src/silx/gui/dialog/SafeFileSystemModel.py b/src/silx/gui/dialog/SafeFileSystemModel.py
new file mode 100644
index 0000000..1ec7153
--- /dev/null
+++ b/src/silx/gui/dialog/SafeFileSystemModel.py
@@ -0,0 +1,802 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+This module contains an :class:`SafeFileSystemModel`.
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "22/11/2017"
+
+import sys
+import os.path
+import logging
+import weakref
+
+from silx.gui import qt
+from .SafeFileIconProvider import SafeFileIconProvider
+
+_logger = logging.getLogger(__name__)
+
+
+class _Item(object):
+
+ def __init__(self, fileInfo):
+ self.__fileInfo = fileInfo
+ self.__parent = None
+ self.__children = None
+ self.__absolutePath = None
+
+ def isDrive(self):
+ if sys.platform == "win32":
+ return self.parent().parent() is None
+ else:
+ return False
+
+ def isRoot(self):
+ return self.parent() is None
+
+ def isFile(self):
+ """
+ Returns true if the path is a file.
+
+ It avoid to access to the `Qt.QFileInfo` in case the file is a drive.
+ """
+ if self.isDrive():
+ return False
+ return self.__fileInfo.isFile()
+
+ def isDir(self):
+ """
+ Returns true if the path is a directory.
+
+ The default `qt.QFileInfo.isDir` can freeze the file system with
+ network drives. This function avoid the freeze in case of browsing
+ the root.
+ """
+ if self.isDrive():
+ # A drive is a directory, we don't have to synchronize the
+ # drive to know that
+ return True
+ return self.__fileInfo.isDir()
+
+ def absoluteFilePath(self):
+ """
+ Returns an absolute path including the file name.
+
+ This function uses in most cases the default
+ `qt.QFileInfo.absoluteFilePath`. But it is known to freeze the file
+ system with network drives.
+
+ This function uses `qt.QFileInfo.filePath` in case of root drives, to
+ avoid this kind of issues. In case of drive, the result is the same,
+ while the file path is already absolute.
+
+ :rtype: str
+ """
+ if self.__absolutePath is None:
+ if self.isRoot():
+ path = ""
+ elif self.isDrive():
+ path = self.__fileInfo.filePath()
+ else:
+ path = os.path.join(self.parent().absoluteFilePath(), self.__fileInfo.fileName())
+ if path == "":
+ return "/"
+ self.__absolutePath = path
+ return self.__absolutePath
+
+ def child(self):
+ self.populate()
+ return self.__children
+
+ def childAt(self, position):
+ self.populate()
+ return self.__children[position]
+
+ def childCount(self):
+ self.populate()
+ return len(self.__children)
+
+ def indexOf(self, item):
+ self.populate()
+ return self.__children.index(item)
+
+ def parent(self):
+ parent = self.__parent
+ if parent is None:
+ return None
+ return parent()
+
+ def filePath(self):
+ return self.__fileInfo.filePath()
+
+ def fileName(self):
+ if self.isDrive():
+ name = self.absoluteFilePath()
+ if name[-1] == "/":
+ name = name[:-1]
+ return name
+ return os.path.basename(self.absoluteFilePath())
+
+ def fileInfo(self):
+ """
+ Returns the Qt file info.
+
+ :rtype: Qt.QFileInfo
+ """
+ return self.__fileInfo
+
+ def _setParent(self, parent):
+ self.__parent = weakref.ref(parent)
+
+ def findChildrenByPath(self, path):
+ if path == "":
+ return self
+ path = path.replace("\\", "/")
+ if path[-1] == "/":
+ path = path[:-1]
+ names = path.split("/")
+ caseSensitive = qt.QFSFileEngine(path).caseSensitive()
+ count = len(names)
+ cursor = self
+ for name in names:
+ for item in cursor.child():
+ if caseSensitive:
+ same = item.fileName() == name
+ else:
+ same = item.fileName().lower() == name.lower()
+ if same:
+ cursor = item
+ count -= 1
+ break
+ else:
+ return None
+ if count == 0:
+ break
+ else:
+ return None
+ return cursor
+
+ def populate(self):
+ if self.__children is not None:
+ return
+ self.__children = []
+ if self.isRoot():
+ items = qt.QDir.drives()
+ else:
+ directory = qt.QDir(self.absoluteFilePath())
+ filters = qt.QDir.AllEntries | qt.QDir.Hidden | qt.QDir.System
+ items = directory.entryInfoList(filters)
+ for fileInfo in items:
+ i = _Item(fileInfo)
+ self.__children.append(i)
+ i._setParent(self)
+
+
+class _RawFileSystemModel(qt.QAbstractItemModel):
+ """
+ This class implement a file system model and try to avoid freeze. On Qt4,
+ :class:`qt.QFileSystemModel` is known to freeze the file system when
+ network drives are available.
+
+ To avoid this behaviour, this class does not use
+ `qt.QFileInfo.absoluteFilePath` nor `qt.QFileInfo.canonicalPath` to reach
+ information on drives.
+
+ This model do not take care of sorting and filtering. This features are
+ managed by another model, by composition.
+
+ And because it is the end of life of Qt4, we do not implement asynchronous
+ loading of files as it is done by :class:`qt.QFileSystemModel`, nor some
+ useful features.
+ """
+
+ __directoryLoadedSync = qt.Signal(str)
+ """This signal is connected asynchronously to a slot. It allows to
+ emit directoryLoaded as an asynchronous signal."""
+
+ directoryLoaded = qt.Signal(str)
+ """This signal is emitted when the gatherer thread has finished to load the
+ path."""
+
+ rootPathChanged = qt.Signal(str)
+ """This signal is emitted whenever the root path has been changed to a
+ newPath."""
+
+ NAME_COLUMN = 0
+ SIZE_COLUMN = 1
+ TYPE_COLUMN = 2
+ LAST_MODIFIED_COLUMN = 3
+
+ def __init__(self, parent=None):
+ qt.QAbstractItemModel.__init__(self, parent)
+ self.__computer = _Item(qt.QFileInfo())
+ self.__header = "Name", "Size", "Type", "Last modification"
+ self.__currentPath = ""
+ self.__iconProvider = SafeFileIconProvider()
+ self.__directoryLoadedSync.connect(self.__emitDirectoryLoaded, qt.Qt.QueuedConnection)
+
+ def headerData(self, section, orientation, role=qt.Qt.DisplayRole):
+ if orientation == qt.Qt.Horizontal:
+ if role == qt.Qt.DisplayRole:
+ return self.__header[section]
+ if role == qt.Qt.TextAlignmentRole:
+ return qt.Qt.AlignRight if section == 1 else qt.Qt.AlignLeft
+ return None
+
+ def flags(self, index):
+ if not index.isValid():
+ return 0
+ return qt.Qt.ItemIsEnabled | qt.Qt.ItemIsSelectable
+
+ def columnCount(self, parent=qt.QModelIndex()):
+ return len(self.__header)
+
+ def rowCount(self, parent=qt.QModelIndex()):
+ item = self.__getItem(parent)
+ return item.childCount()
+
+ def data(self, index, role=qt.Qt.DisplayRole):
+ if not index.isValid():
+ return None
+
+ column = index.column()
+ if role in [qt.Qt.DisplayRole, qt.Qt.EditRole]:
+ if column == self.NAME_COLUMN:
+ return self.__displayName(index)
+ elif column == self.SIZE_COLUMN:
+ return self.size(index)
+ elif column == self.TYPE_COLUMN:
+ return self.type(index)
+ elif column == self.LAST_MODIFIED_COLUMN:
+ return self.lastModified(index)
+ else:
+ _logger.warning("data: invalid display value column %d", index.column())
+ elif role == qt.QFileSystemModel.FilePathRole:
+ return self.filePath(index)
+ elif role == qt.QFileSystemModel.FileNameRole:
+ return self.fileName(index)
+ elif role == qt.Qt.DecorationRole:
+ if column == self.NAME_COLUMN:
+ icon = self.fileIcon(index)
+ if icon is None or icon.isNull():
+ if self.isDir(index):
+ self.__iconProvider.icon(qt.QFileIconProvider.Folder)
+ else:
+ self.__iconProvider.icon(qt.QFileIconProvider.File)
+ return icon
+ elif role == qt.Qt.TextAlignmentRole:
+ if column == self.SIZE_COLUMN:
+ return qt.Qt.AlignRight
+ elif role == qt.QFileSystemModel.FilePermissions:
+ return self.permissions(index)
+
+ return None
+
+ def index(self, *args, **kwargs):
+ path_api = False
+ path_api |= len(args) >= 1 and isinstance(args[0], str)
+ path_api |= "path" in kwargs
+
+ if path_api:
+ return self.__indexFromPath(*args, **kwargs)
+ else:
+ return self.__index(*args, **kwargs)
+
+ def __index(self, row, column, parent=qt.QModelIndex()):
+ if parent.isValid() and parent.column() != 0:
+ return None
+
+ parentItem = self.__getItem(parent)
+ item = parentItem.childAt(row)
+ return self.createIndex(row, column, item)
+
+ def __indexFromPath(self, path, column=0):
+ """
+ Uses the index(str) C++ API
+
+ :rtype: qt.QModelIndex
+ """
+ if path == "":
+ return qt.QModelIndex()
+
+ item = self.__computer.findChildrenByPath(path)
+ if item is None:
+ return qt.QModelIndex()
+
+ return self.createIndex(item.parent().indexOf(item), column, item)
+
+ def parent(self, index):
+ if not index.isValid():
+ return qt.QModelIndex()
+
+ item = self.__getItem(index)
+ if index is None:
+ return qt.QModelIndex()
+
+ parent = item.parent()
+ if parent is None or parent is self.__computer:
+ return qt.QModelIndex()
+
+ return self.createIndex(parent.parent().indexOf(parent), 0, parent)
+
+ def __emitDirectoryLoaded(self, path):
+ self.directoryLoaded.emit(path)
+
+ def __emitRootPathChanged(self, path):
+ self.rootPathChanged.emit(path)
+
+ def __getItem(self, index):
+ if not index.isValid():
+ return self.__computer
+ item = index.internalPointer()
+ return item
+
+ def fileIcon(self, index):
+ item = self.__getItem(index)
+ if self.__iconProvider is not None:
+ fileInfo = item.fileInfo()
+ result = self.__iconProvider.icon(fileInfo)
+ else:
+ style = qt.QApplication.instance().style()
+ if item.isRoot():
+ result = style.standardIcon(qt.QStyle.SP_ComputerIcon)
+ elif item.isDrive():
+ result = style.standardIcon(qt.QStyle.SP_DriveHDIcon)
+ elif item.isDir():
+ result = style.standardIcon(qt.QStyle.SP_DirIcon)
+ else:
+ result = style.standardIcon(qt.QStyle.SP_FileIcon)
+ return result
+
+ def _item(self, index):
+ item = self.__getItem(index)
+ return item
+
+ def fileInfo(self, index):
+ item = self.__getItem(index)
+ result = item.fileInfo()
+ return result
+
+ def __fileIcon(self, index):
+ item = self.__getItem(index)
+ result = item.fileName()
+ return result
+
+ def __displayName(self, index):
+ item = self.__getItem(index)
+ result = item.fileName()
+ return result
+
+ def fileName(self, index):
+ item = self.__getItem(index)
+ result = item.fileName()
+ return result
+
+ def filePath(self, index):
+ item = self.__getItem(index)
+ result = item.fileInfo().filePath()
+ return result
+
+ def isDir(self, index):
+ item = self.__getItem(index)
+ result = item.isDir()
+ return result
+
+ def lastModified(self, index):
+ item = self.__getItem(index)
+ result = item.fileInfo().lastModified()
+ return result
+
+ def permissions(self, index):
+ item = self.__getItem(index)
+ result = item.fileInfo().permissions()
+ return result
+
+ def size(self, index):
+ item = self.__getItem(index)
+ result = item.fileInfo().size()
+ return result
+
+ def type(self, index):
+ item = self.__getItem(index)
+ if self.__iconProvider is not None:
+ fileInfo = item.fileInfo()
+ result = self.__iconProvider.type(fileInfo)
+ else:
+ if item.isRoot():
+ result = "Computer"
+ elif item.isDrive():
+ result = "Drive"
+ elif item.isDir():
+ result = "Directory"
+ else:
+ fileInfo = item.fileInfo()
+ result = fileInfo.suffix()
+ return result
+
+ # File manipulation
+
+ # bool remove(const QModelIndex & index) const
+ # bool rmdir(const QModelIndex & index) const
+ # QModelIndex mkdir(const QModelIndex & parent, const QString & name)
+
+ # Configuration
+
+ def rootDirectory(self):
+ return qt.QDir(self.rootPath())
+
+ def rootPath(self):
+ return self.__currentPath
+
+ def setRootPath(self, path):
+ if self.__currentPath == path:
+ return
+ self.__currentPath = path
+ item = self.__computer.findChildrenByPath(path)
+ self.__emitRootPathChanged(path)
+ if item is None or item.parent() is None:
+ return qt.QModelIndex()
+ index = self.createIndex(item.parent().indexOf(item), 0, item)
+ self.__directoryLoadedSync.emit(path)
+ return index
+
+ def iconProvider(self):
+ # FIXME: invalidate the model
+ return self.__iconProvider
+
+ def setIconProvider(self, provider):
+ # FIXME: invalidate the model
+ self.__iconProvider = provider
+
+ # bool resolveSymlinks() const
+ # void setResolveSymlinks(bool enable)
+
+ def setNameFilterDisables(self, enable):
+ return None
+
+ def nameFilterDisables(self):
+ return None
+
+ def myComputer(self, role=qt.Qt.DisplayRole):
+ return None
+
+ def setNameFilters(self, filters):
+ return
+
+ def nameFilters(self):
+ return None
+
+ def filter(self):
+ return self.__filters
+
+ def setFilter(self, filters):
+ return
+
+ def setReadOnly(self, enable):
+ assert(enable is True)
+
+ def isReadOnly(self):
+ return False
+
+
+class SafeFileSystemModel(qt.QSortFilterProxyModel):
+ """
+ This class implement a file system model and try to avoid freeze. On Qt4,
+ :class:`qt.QFileSystemModel` is known to freeze the file system when
+ network drives are available.
+
+ To avoid this behaviour, this class does not use
+ `qt.QFileInfo.absoluteFilePath` nor `qt.QFileInfo.canonicalPath` to reach
+ information on drives.
+
+ And because it is the end of life of Qt4, we do not implement asynchronous
+ loading of files as it is done by :class:`qt.QFileSystemModel`, nor some
+ useful features.
+ """
+
+ def __init__(self, parent=None):
+ qt.QSortFilterProxyModel.__init__(self, parent=parent)
+ self.__nameFilterDisables = sys.platform == "darwin"
+ self.__nameFilters = []
+ self.__filters = qt.QDir.AllEntries | qt.QDir.NoDotAndDotDot | qt.QDir.AllDirs
+ sourceModel = _RawFileSystemModel(self)
+ self.setSourceModel(sourceModel)
+
+ @property
+ def directoryLoaded(self):
+ return self.sourceModel().directoryLoaded
+
+ @property
+ def rootPathChanged(self):
+ return self.sourceModel().rootPathChanged
+
+ def index(self, *args, **kwargs):
+ path_api = False
+ path_api |= len(args) >= 1 and isinstance(args[0], str)
+ path_api |= "path" in kwargs
+
+ if path_api:
+ return self.__indexFromPath(*args, **kwargs)
+ else:
+ return self.__index(*args, **kwargs)
+
+ def __index(self, row, column, parent=qt.QModelIndex()):
+ return qt.QSortFilterProxyModel.index(self, row, column, parent)
+
+ def __indexFromPath(self, path, column=0):
+ """
+ Uses the index(str) C++ API
+
+ :rtype: qt.QModelIndex
+ """
+ if path == "":
+ return qt.QModelIndex()
+
+ index = self.sourceModel().index(path, column)
+ index = self.mapFromSource(index)
+ return index
+
+ def lessThan(self, leftSourceIndex, rightSourceIndex):
+ sourceModel = self.sourceModel()
+ sortColumn = self.sortColumn()
+ if sortColumn == _RawFileSystemModel.NAME_COLUMN:
+ leftItem = sourceModel._item(leftSourceIndex)
+ rightItem = sourceModel._item(rightSourceIndex)
+ if sys.platform != "darwin":
+ # Sort directories before files
+ leftIsDir = leftItem.isDir()
+ rightIsDir = rightItem.isDir()
+ if leftIsDir ^ rightIsDir:
+ return leftIsDir
+ return leftItem.fileName().lower() < rightItem.fileName().lower()
+ elif sortColumn == _RawFileSystemModel.SIZE_COLUMN:
+ left = sourceModel.fileInfo(leftSourceIndex)
+ right = sourceModel.fileInfo(rightSourceIndex)
+ return left.size() < right.size()
+ elif sortColumn == _RawFileSystemModel.TYPE_COLUMN:
+ left = sourceModel.type(leftSourceIndex)
+ right = sourceModel.type(rightSourceIndex)
+ return left < right
+ elif sortColumn == _RawFileSystemModel.LAST_MODIFIED_COLUMN:
+ left = sourceModel.fileInfo(leftSourceIndex)
+ right = sourceModel.fileInfo(rightSourceIndex)
+ return left.lastModified() < right.lastModified()
+ else:
+ _logger.warning("Unsupported sorted column %d", sortColumn)
+
+ return False
+
+ def __filtersAccepted(self, item, filters):
+ """
+ Check individual flag filters.
+ """
+ if not (filters & (qt.QDir.Dirs | qt.QDir.AllDirs)):
+ # Hide dirs
+ if item.isDir():
+ return False
+ if not (filters & qt.QDir.Files):
+ # Hide files
+ if item.isFile():
+ return False
+ if not (filters & qt.QDir.Drives):
+ # Hide drives
+ if item.isDrive():
+ return False
+
+ fileInfo = item.fileInfo()
+ if fileInfo is None:
+ return False
+
+ filterPermissions = (filters & qt.QDir.PermissionMask) != 0
+ if filterPermissions and (filters & (qt.QDir.Dirs | qt.QDir.Files)):
+ if (filters & qt.QDir.Readable):
+ # Hide unreadable
+ if not fileInfo.isReadable():
+ return False
+ if (filters & qt.QDir.Writable):
+ # Hide unwritable
+ if not fileInfo.isWritable():
+ return False
+ if (filters & qt.QDir.Executable):
+ # Hide unexecutable
+ if not fileInfo.isExecutable():
+ return False
+
+ if (filters & qt.QDir.NoSymLinks):
+ # Hide sym links
+ if fileInfo.isSymLink():
+ return False
+
+ if not (filters & qt.QDir.System):
+ # Hide system
+ if not item.isDir() and not item.isFile():
+ return False
+
+ fileName = item.fileName()
+ isDot = fileName == "."
+ isDotDot = fileName == ".."
+
+ if not (filters & qt.QDir.Hidden):
+ # Hide hidden
+ if not (isDot or isDotDot) and fileInfo.isHidden():
+ return False
+
+ if filters & (qt.QDir.NoDot | qt.QDir.NoDotDot | qt.QDir.NoDotAndDotDot):
+ # Hide parent/self references
+ if filters & qt.QDir.NoDot:
+ if isDot:
+ return False
+ if filters & qt.QDir.NoDotDot:
+ if isDotDot:
+ return False
+ if filters & qt.QDir.NoDotAndDotDot:
+ if isDot or isDotDot:
+ return False
+
+ return True
+
+ def filterAcceptsRow(self, sourceRow, sourceParent):
+ if not sourceParent.isValid():
+ return True
+
+ sourceModel = self.sourceModel()
+ index = sourceModel.index(sourceRow, 0, sourceParent)
+ if not index.isValid():
+ return True
+ item = sourceModel._item(index)
+
+ filters = self.__filters
+
+ if item.isDrive():
+ # Let say a user always have access to a drive
+ # It avoid to access to fileInfo then avoid to freeze the file
+ # system
+ return True
+
+ if not self.__filtersAccepted(item, filters):
+ return False
+
+ if self.__nameFilterDisables:
+ return True
+
+ if item.isDir() and (filters & qt.QDir.AllDirs):
+ # dont apply the filters to directory names
+ return True
+
+ return self.__nameFiltersAccepted(item)
+
+ def __nameFiltersAccepted(self, item):
+ if len(self.__nameFilters) == 0:
+ return True
+
+ fileName = item.fileName()
+ for reg in self.__nameFilters:
+ if reg.exactMatch(fileName):
+ return True
+ return False
+
+ def setNameFilterDisables(self, enable):
+ self.__nameFilterDisables = enable
+ self.invalidate()
+
+ def nameFilterDisables(self):
+ return self.__nameFilterDisables
+
+ def myComputer(self, role=qt.Qt.DisplayRole):
+ return self.sourceModel().myComputer(role)
+
+ def setNameFilters(self, filters):
+ self.__nameFilters = []
+ isCaseSensitive = self.__filters & qt.QDir.CaseSensitive
+ caseSensitive = qt.Qt.CaseSensitive if isCaseSensitive else qt.Qt.CaseInsensitive
+ for f in filters:
+ reg = qt.QRegExp(f, caseSensitive, qt.QRegExp.Wildcard)
+ self.__nameFilters.append(reg)
+ self.invalidate()
+
+ def nameFilters(self):
+ return [f.pattern() for f in self.__nameFilters]
+
+ def filter(self):
+ return self.__filters
+
+ def setFilter(self, filters):
+ self.__filters = filters
+ # In case of change of case sensitivity
+ self.setNameFilters(self.nameFilters())
+ self.invalidate()
+
+ def setReadOnly(self, enable):
+ assert(enable is True)
+
+ def isReadOnly(self):
+ return False
+
+ def rootPath(self):
+ return self.sourceModel().rootPath()
+
+ def setRootPath(self, path):
+ index = self.sourceModel().setRootPath(path)
+ index = self.mapFromSource(index)
+ return index
+
+ def flags(self, index):
+ sourceModel = self.sourceModel()
+ index = self.mapToSource(index)
+ filters = sourceModel.flags(index)
+
+ if self.__nameFilterDisables and not sourceModel.isDir(index):
+ item = sourceModel._item(index)
+ if not self.__nameFiltersAccepted(item):
+ filters &= ~qt.Qt.ItemIsEnabled
+
+ return filters
+
+ def fileIcon(self, index):
+ sourceModel = self.sourceModel()
+ index = self.mapToSource(index)
+ return sourceModel.fileIcon(index)
+
+ def fileInfo(self, index):
+ sourceModel = self.sourceModel()
+ index = self.mapToSource(index)
+ return sourceModel.fileInfo(index)
+
+ def fileName(self, index):
+ sourceModel = self.sourceModel()
+ index = self.mapToSource(index)
+ return sourceModel.fileName(index)
+
+ def filePath(self, index):
+ sourceModel = self.sourceModel()
+ index = self.mapToSource(index)
+ return sourceModel.filePath(index)
+
+ def isDir(self, index):
+ sourceModel = self.sourceModel()
+ index = self.mapToSource(index)
+ return sourceModel.isDir(index)
+
+ def lastModified(self, index):
+ sourceModel = self.sourceModel()
+ index = self.mapToSource(index)
+ return sourceModel.lastModified(index)
+
+ def permissions(self, index):
+ sourceModel = self.sourceModel()
+ index = self.mapToSource(index)
+ return sourceModel.permissions(index)
+
+ def size(self, index):
+ sourceModel = self.sourceModel()
+ index = self.mapToSource(index)
+ return sourceModel.size(index)
+
+ def type(self, index):
+ sourceModel = self.sourceModel()
+ index = self.mapToSource(index)
+ return sourceModel.type(index)
diff --git a/src/silx/gui/dialog/__init__.py b/src/silx/gui/dialog/__init__.py
new file mode 100644
index 0000000..77c5949
--- /dev/null
+++ b/src/silx/gui/dialog/__init__.py
@@ -0,0 +1,29 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Qt dialogs"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "11/10/2017"
diff --git a/src/silx/gui/dialog/setup.py b/src/silx/gui/dialog/setup.py
new file mode 100644
index 0000000..48ab8d8
--- /dev/null
+++ b/src/silx/gui/dialog/setup.py
@@ -0,0 +1,40 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "23/10/2017"
+
+from numpy.distutils.misc_util import Configuration
+
+
+def configuration(parent_package='', top_path=None):
+ config = Configuration('dialog', parent_package, top_path)
+ config.add_subpackage('test')
+ return config
+
+
+if __name__ == "__main__":
+ from numpy.distutils.core import setup
+ setup(configuration=configuration)
diff --git a/src/silx/gui/dialog/test/__init__.py b/src/silx/gui/dialog/test/__init__.py
new file mode 100644
index 0000000..71128fb
--- /dev/null
+++ b/src/silx/gui/dialog/test/__init__.py
@@ -0,0 +1,24 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
diff --git a/src/silx/gui/dialog/test/test_colormapdialog.py b/src/silx/gui/dialog/test/test_colormapdialog.py
new file mode 100644
index 0000000..16a5ab2
--- /dev/null
+++ b/src/silx/gui/dialog/test/test_colormapdialog.py
@@ -0,0 +1,395 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for ColormapDialog"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "09/11/2018"
+
+
+import pytest
+import weakref
+
+from silx.gui import qt
+from silx.gui.dialog import ColormapDialog
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui.colors import Colormap, preferredColormaps
+from silx.utils.testutils import ParametricTestCase
+from silx.gui.plot.items.image import ImageData
+
+import numpy
+
+
+@pytest.fixture
+def colormap():
+ colormap = Colormap(name='gray',
+ vmin=10.0, vmax=20.0,
+ normalization='linear')
+ yield colormap
+
+
+@pytest.fixture
+def colormapDialog(qapp, qapp_utils):
+ dialog = ColormapDialog.ColormapDialog()
+ dialog.setAttribute(qt.Qt.WA_DeleteOnClose)
+ yield weakref.proxy(dialog)
+ qapp.processEvents()
+ from silx.gui.qt import inspect
+ if inspect.isValid(dialog):
+ dialog.close()
+ qapp.processEvents()
+
+
+@pytest.fixture
+def colormap_class_attr(request, qapp_utils, colormap, colormapDialog):
+ """Provides few fixtures to a class as class attribute
+
+ Used as transition from TestCase to pytest
+ """
+ request.cls.qapp_utils = qapp_utils
+ request.cls.colormap = colormap
+ request.cls.colormapDiag = colormapDialog
+ yield
+ request.cls.qapp_utils = None
+ request.cls.colormap = None
+ request.cls.colormapDiag = None
+
+
+@pytest.mark.usefixtures("colormap_class_attr")
+class TestColormapDialog(TestCaseQt, ParametricTestCase):
+
+ def testGUIEdition(self):
+ """Make sure the colormap is correctly edited and also that the
+ modification are correctly updated if an other colormapdialog is
+ editing the same colormap"""
+ colormapDiag2 = ColormapDialog.ColormapDialog()
+ colormapDiag2.setColormap(self.colormap)
+ colormapDiag2.show()
+ self.colormapDiag.setColormap(self.colormap)
+ self.colormapDiag.show()
+ self.qapp.processEvents()
+
+ self.colormapDiag._comboBoxColormap._setCurrentName('red')
+ self.colormapDiag._comboBoxNormalization.setCurrentIndex(
+ self.colormapDiag._comboBoxNormalization.findData(Colormap.LOGARITHM))
+ self.assertTrue(self.colormap.getName() == 'red')
+ self.assertTrue(self.colormapDiag.getColormap().getName() == 'red')
+ self.assertTrue(self.colormap.getNormalization() == 'log')
+ self.assertTrue(self.colormap.getVMin() == 10)
+ self.assertTrue(self.colormap.getVMax() == 20)
+ # checked second colormap dialog
+ self.assertTrue(colormapDiag2._comboBoxColormap.getCurrentName() == 'red')
+ self.assertEqual(colormapDiag2._comboBoxNormalization.currentData(),
+ Colormap.LOGARITHM)
+ self.assertTrue(int(colormapDiag2._minValue.getValue()) == 10)
+ self.assertTrue(int(colormapDiag2._maxValue.getValue()) == 20)
+ colormapDiag2.close()
+
+ def testGUIModalOk(self):
+ """Make sure the colormap is modified if gone through accept"""
+ assert self.colormap.isAutoscale() is False
+ self.colormapDiag.setModal(True)
+ self.colormapDiag.show()
+ self.qapp.processEvents()
+ self.colormapDiag.setColormap(self.colormap)
+ self.assertTrue(self.colormap.getVMin() is not None)
+ self.colormapDiag._minValue.sigAutoScaleChanged.emit(True)
+ self.assertTrue(self.colormap.getVMin() is None)
+ self.colormapDiag._maxValue.sigAutoScaleChanged.emit(True)
+ self.mouseClick(
+ widget=self.colormapDiag._buttonsModal.button(qt.QDialogButtonBox.Ok),
+ button=qt.Qt.LeftButton
+ )
+ self.assertTrue(self.colormap.getVMin() is None)
+ self.assertTrue(self.colormap.getVMax() is None)
+ self.assertTrue(self.colormap.isAutoscale() is True)
+
+ def testGUIModalCancel(self):
+ """Make sure the colormap is not modified if gone through reject"""
+ assert self.colormap.isAutoscale() is False
+ self.colormapDiag.setModal(True)
+ self.colormapDiag.show()
+ self.qapp.processEvents()
+ self.colormapDiag.setColormap(self.colormap)
+ self.assertTrue(self.colormap.getVMin() is not None)
+ self.colormapDiag._minValue.sigAutoScaleChanged.emit(True)
+ self.assertTrue(self.colormap.getVMin() is None)
+ self.mouseClick(
+ widget=self.colormapDiag._buttonsModal.button(qt.QDialogButtonBox.Cancel),
+ button=qt.Qt.LeftButton
+ )
+ self.assertTrue(self.colormap.getVMin() is not None)
+
+ def testGUIModalClose(self):
+ assert self.colormap.isAutoscale() is False
+ self.colormapDiag.setModal(False)
+ self.colormapDiag.show()
+ self.qapp.processEvents()
+ self.colormapDiag.setColormap(self.colormap)
+ self.assertTrue(self.colormap.getVMin() is not None)
+ self.colormapDiag._minValue.sigAutoScaleChanged.emit(True)
+ self.assertTrue(self.colormap.getVMin() is None)
+ self.mouseClick(
+ widget=self.colormapDiag._buttonsNonModal.button(qt.QDialogButtonBox.Close),
+ button=qt.Qt.LeftButton
+ )
+ self.assertTrue(self.colormap.getVMin() is None)
+
+ def testGUIModalReset(self):
+ assert self.colormap.isAutoscale() is False
+ self.colormapDiag.setModal(False)
+ self.colormapDiag.show()
+ self.qapp.processEvents()
+ self.colormapDiag.setColormap(self.colormap)
+ self.assertTrue(self.colormap.getVMin() is not None)
+ self.colormapDiag._minValue.sigAutoScaleChanged.emit(True)
+ self.assertTrue(self.colormap.getVMin() is None)
+ self.mouseClick(
+ widget=self.colormapDiag._buttonsNonModal.button(qt.QDialogButtonBox.Reset),
+ button=qt.Qt.LeftButton
+ )
+ self.assertTrue(self.colormap.getVMin() is not None)
+ self.colormapDiag.close()
+
+ def testGUIClose(self):
+ """Make sure the colormap is modify if go through reject"""
+ assert self.colormap.isAutoscale() is False
+ self.colormapDiag.show()
+ self.qapp.processEvents()
+ self.colormapDiag.setColormap(self.colormap)
+ self.assertTrue(self.colormap.getVMin() is not None)
+ self.colormapDiag._minValue.sigAutoScaleChanged.emit(True)
+ self.assertTrue(self.colormap.getVMin() is None)
+ self.colormapDiag.close()
+ self.qapp.processEvents()
+ self.assertTrue(self.colormap.getVMin() is None)
+
+ def testSetColormapIsCorrect(self):
+ """Make sure the interface fir the colormap when set a new colormap"""
+ self.colormap.setName('red')
+ self.colormapDiag.show()
+ self.qapp.processEvents()
+ for norm in (Colormap.NORMALIZATIONS):
+ for autoscale in (True, False):
+ if autoscale is True:
+ self.colormap.setVRange(None, None)
+ else:
+ self.colormap.setVRange(11, 101)
+ self.colormap.setNormalization(norm)
+ with self.subTest(colormap=self.colormap):
+ self.colormapDiag.setColormap(self.colormap)
+ self.assertEqual(
+ self.colormapDiag._comboBoxNormalization.currentData(), norm)
+ self.assertTrue(
+ self.colormapDiag._comboBoxColormap.getCurrentName() == 'red')
+ self.assertTrue(
+ self.colormapDiag._minValue.isAutoChecked() == autoscale)
+ self.assertTrue(
+ self.colormapDiag._maxValue.isAutoChecked() == autoscale)
+ if autoscale is False:
+ self.assertTrue(self.colormapDiag._minValue.getValue() == 11)
+ self.assertTrue(self.colormapDiag._maxValue.getValue() == 101)
+ self.assertTrue(self.colormapDiag._minValue.isEnabled())
+ self.assertTrue(self.colormapDiag._maxValue.isEnabled())
+ else:
+ self.assertFalse(self.colormapDiag._minValue._numVal.isEnabled())
+ self.assertFalse(self.colormapDiag._maxValue._numVal.isEnabled())
+
+ def testColormapDel(self):
+ """Check behavior if the colormap has been deleted outside. For now
+ we make sure the colormap is still running and nothing more"""
+ colormap = Colormap(name='gray')
+ self.colormapDiag.setColormap(colormap)
+ self.colormapDiag.show()
+ self.qapp.processEvents()
+ colormap = None
+ self.assertTrue(self.colormapDiag.getColormap() is None)
+ self.colormapDiag._comboBoxColormap._setCurrentName('blue')
+
+ def testColormapEditedOutside(self):
+ """Make sure the GUI is still up to date if the colormap is modified
+ outside"""
+ self.colormapDiag.setColormap(self.colormap)
+ self.colormapDiag.show()
+ self.qapp.processEvents()
+
+ self.colormap.setName('red')
+ self.assertTrue(
+ self.colormapDiag._comboBoxColormap.getCurrentName() == 'red')
+ self.colormap.setNormalization(Colormap.LOGARITHM)
+ self.assertEqual(self.colormapDiag._comboBoxNormalization.currentData(),
+ Colormap.LOGARITHM)
+ self.colormap.setVRange(11, 201)
+ self.assertTrue(self.colormapDiag._minValue.getValue() == 11)
+ self.assertTrue(self.colormapDiag._maxValue.getValue() == 201)
+ self.assertTrue(self.colormapDiag._minValue._numVal.isEnabled())
+ self.assertTrue(self.colormapDiag._maxValue._numVal.isEnabled())
+ self.assertFalse(self.colormapDiag._minValue.isAutoChecked())
+ self.assertFalse(self.colormapDiag._maxValue.isAutoChecked())
+ self.colormap.setVRange(None, None)
+ self.assertFalse(self.colormapDiag._minValue._numVal.isEnabled())
+ self.assertFalse(self.colormapDiag._maxValue._numVal.isEnabled())
+ self.assertTrue(self.colormapDiag._minValue.isAutoChecked())
+ self.assertTrue(self.colormapDiag._maxValue.isAutoChecked())
+
+ def testSetColormapScenario(self):
+ """Test of a simple scenario of a colormap dialog editing several
+ colormap"""
+ colormap1 = Colormap(name='gray', vmin=10.0, vmax=20.0,
+ normalization='linear')
+ colormap2 = Colormap(name='red', vmin=10.0, vmax=20.0,
+ normalization='log')
+ colormap3 = Colormap(name='blue', vmin=None, vmax=None,
+ normalization='linear')
+ self.colormapDiag.setColormap(self.colormap)
+ self.colormapDiag.setColormap(colormap1)
+ del colormap1
+ self.colormapDiag.setColormap(colormap2)
+ del colormap2
+ self.colormapDiag.setColormap(colormap3)
+ del colormap3
+
+ def testNotPreferredColormap(self):
+ """Test that the colormapEditor is able to edit a colormap which is not
+ part of the 'prefered colormap'
+ """
+ def getFirstNotPreferredColormap():
+ cms = Colormap.getSupportedColormaps()
+ preferred = preferredColormaps()
+ for cm in cms:
+ if cm not in preferred:
+ return cm
+ return None
+
+ colormapName = getFirstNotPreferredColormap()
+ assert colormapName is not None
+ colormap = Colormap(name=colormapName)
+ self.colormapDiag.setColormap(colormap)
+ self.colormapDiag.show()
+ self.qapp.processEvents()
+ cb = self.colormapDiag._comboBoxColormap
+ self.assertTrue(cb.getCurrentName() == colormapName)
+ cb.setCurrentIndex(0)
+ index = cb.findLutName(colormapName)
+ assert index != 0 # if 0 then the rest of the test has no sense
+ cb.setCurrentIndex(index)
+ self.assertTrue(cb.getCurrentName() == colormapName)
+
+ def testColormapEditableMode(self):
+ """Test that the colormapDialog is correctly updated when changing the
+ colormap editable status"""
+ colormap = Colormap(normalization='linear', vmin=1.0, vmax=10.0)
+ self.colormapDiag.show()
+ self.qapp.processEvents()
+ self.colormapDiag.setColormap(colormap)
+ for editable in (True, False):
+ with self.subTest(editable=editable):
+ colormap.setEditable(editable)
+ self.assertTrue(
+ self.colormapDiag._comboBoxColormap.isEnabled() is editable)
+ self.assertTrue(
+ self.colormapDiag._minValue.isEnabled() is editable)
+ self.assertTrue(
+ self.colormapDiag._maxValue.isEnabled() is editable)
+ self.assertTrue(
+ self.colormapDiag._comboBoxNormalization.isEnabled() is editable)
+
+ # Make sure the reset button is also set to enable when edition mode is
+ # False
+ self.colormapDiag.setModal(False)
+ colormap.setEditable(True)
+ self.colormapDiag._comboBoxNormalization.setCurrentIndex(
+ self.colormapDiag._comboBoxNormalization.findData(Colormap.LOGARITHM))
+ resetButton = self.colormapDiag._buttonsNonModal.button(qt.QDialogButtonBox.Reset)
+ self.assertTrue(resetButton.isEnabled())
+ colormap.setEditable(False)
+ self.assertFalse(resetButton.isEnabled())
+
+ def testImageData(self):
+ data = numpy.random.rand(5, 5)
+ self.colormapDiag.setData(data)
+
+ def testEmptyData(self):
+ data = numpy.empty((10, 0))
+ self.colormapDiag.setData(data)
+
+ def testNoneData(self):
+ data = numpy.random.rand(5, 5)
+ self.colormapDiag.setData(data)
+ self.colormapDiag.setData(None)
+
+ def testImageItem(self):
+ """Check that an ImageData plot item can be used"""
+ dialog = self.colormapDiag
+ colormap = Colormap(name='gray', vmin=None, vmax=None)
+ data = numpy.arange(3**2).reshape(3, 3)
+ item = ImageData()
+ item.setData(data, copy=False)
+
+ dialog.setColormap(colormap)
+ dialog.show()
+ self.qapp.processEvents()
+ dialog.setItem(item)
+ vrange = dialog._getFiniteColormapRange()
+ self.assertEqual(vrange, (0, 8))
+
+ def testItemDel(self):
+ """Check that the plot items are not hard linked to the dialog"""
+ dialog = self.colormapDiag
+ colormap = Colormap(name='gray', vmin=None, vmax=None)
+ data = numpy.arange(3**2).reshape(3, 3)
+ item = ImageData()
+ item.setData(data, copy=False)
+
+ dialog.setColormap(colormap)
+ dialog.show()
+ self.qapp.processEvents()
+ dialog.setItem(item)
+ previousRange = dialog._getFiniteColormapRange()
+ del item
+ vrange = dialog._getFiniteColormapRange()
+ self.assertNotEqual(vrange, previousRange)
+
+ def testDataDel(self):
+ """Check that the data are not hard linked to the dialog"""
+ dialog = self.colormapDiag
+ colormap = Colormap(name='gray', vmin=None, vmax=None)
+ data = numpy.arange(5)
+
+ dialog.setColormap(colormap)
+ dialog.show()
+ self.qapp.processEvents()
+ dialog.setData(data)
+ previousRange = dialog._getFiniteColormapRange()
+ del data
+ vrange = dialog._getFiniteColormapRange()
+ self.assertNotEqual(vrange, previousRange)
+
+ def testDeleteWhileExec(self):
+ colormapDiag = self.colormapDiag
+ self.colormapDiag = None
+ qt.QTimer.singleShot(1000, colormapDiag.deleteLater)
+ result = colormapDiag.exec()
+ self.assertEqual(result, 0)
diff --git a/src/silx/gui/dialog/test/test_datafiledialog.py b/src/silx/gui/dialog/test/test_datafiledialog.py
new file mode 100644
index 0000000..8411c67
--- /dev/null
+++ b/src/silx/gui/dialog/test/test_datafiledialog.py
@@ -0,0 +1,924 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test for silx.gui.hdf5 module"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "08/03/2019"
+
+
+import unittest
+import tempfile
+import numpy
+import shutil
+import os
+import io
+import weakref
+import fabio
+import h5py
+import silx.io.url
+from silx.gui import qt
+from silx.gui.utils import testutils
+from ..DataFileDialog import DataFileDialog
+from silx.gui.hdf5 import Hdf5TreeModel
+
+_tmpDirectory = None
+
+
+def setUpModule():
+ global _tmpDirectory
+ _tmpDirectory = tempfile.mkdtemp(prefix=__name__)
+
+ data = numpy.arange(100 * 100)
+ data.shape = 100, 100
+
+ filename = _tmpDirectory + "/singleimage.edf"
+ image = fabio.edfimage.EdfImage(data=data)
+ image.write(filename)
+
+ filename = _tmpDirectory + "/data.h5"
+ f = h5py.File(filename, "w")
+ f["scalar"] = 10
+ f["image"] = data
+ f["cube"] = [data, data + 1, data + 2]
+ f["complex_image"] = data * 1j
+ f["group/image"] = data
+ f["nxdata/foo"] = 10
+ f["nxdata"].attrs["NX_class"] = u"NXdata"
+ f.close()
+
+ directory = os.path.join(_tmpDirectory, "data")
+ os.mkdir(directory)
+ filename = os.path.join(directory, "data.h5")
+ f = h5py.File(filename, "w")
+ f["scalar"] = 10
+ f["image"] = data
+ f["cube"] = [data, data + 1, data + 2]
+ f["complex_image"] = data * 1j
+ f["group/image"] = data
+ f["nxdata/foo"] = 10
+ f["nxdata"].attrs["NX_class"] = u"NXdata"
+ f.close()
+
+ filename = _tmpDirectory + "/badformat.h5"
+ with io.open(filename, "wb") as f:
+ f.write(b"{\nHello Nurse!")
+
+
+def tearDownModule():
+ global _tmpDirectory
+ shutil.rmtree(_tmpDirectory)
+ _tmpDirectory = None
+
+
+class _UtilsMixin(object):
+
+ def createDialog(self):
+ self._deleteDialog()
+ self._dialog = self._createDialog()
+ return self._dialog
+
+ def _createDialog(self):
+ return DataFileDialog()
+
+ def _deleteDialog(self):
+ if not hasattr(self, "_dialog"):
+ return
+ if self._dialog is not None:
+ ref = weakref.ref(self._dialog)
+ self._dialog = None
+ self.qWaitForDestroy(ref)
+
+ def qWaitForPendingActions(self, dialog):
+ for _ in range(20):
+ if not dialog.hasPendingEvents():
+ return
+ self.qWait(10)
+ raise RuntimeError("Still have pending actions")
+
+ def assertSamePath(self, path1, path2):
+ path1_ = os.path.normcase(path1)
+ path2_ = os.path.normcase(path2)
+ if path1_ != path2_:
+ # Use the unittest API to log and display error
+ self.assertEqual(path1, path2)
+
+ def assertNotSamePath(self, path1, path2):
+ path1_ = os.path.normcase(path1)
+ path2_ = os.path.normcase(path2)
+ if path1_ == path2_:
+ # Use the unittest API to log and display error
+ self.assertNotEqual(path1, path2)
+
+
+class TestDataFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
+
+ def tearDown(self):
+ self._deleteDialog()
+ testutils.TestCaseQt.tearDown(self)
+
+ def testDisplayAndKeyEscape(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+ self.assertTrue(dialog.isVisible())
+
+ self.keyClick(dialog, qt.Qt.Key_Escape)
+ self.assertFalse(dialog.isVisible())
+ self.assertEqual(dialog.result(), qt.QDialog.Rejected)
+
+ def testDisplayAndClickCancel(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+ self.assertTrue(dialog.isVisible())
+
+ button = testutils.findChildren(dialog, qt.QPushButton, name="cancel")[0]
+ self.mouseClick(button, qt.Qt.LeftButton)
+ self.assertFalse(dialog.isVisible())
+ self.assertFalse(dialog.isVisible())
+ self.assertEqual(dialog.result(), qt.QDialog.Rejected)
+
+ def testDisplayAndClickLockedOpen(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+ self.assertTrue(dialog.isVisible())
+
+ button = testutils.findChildren(dialog, qt.QPushButton, name="open")[0]
+ self.mouseClick(button, qt.Qt.LeftButton)
+ # open button locked, dialog is not closed
+ self.assertTrue(dialog.isVisible())
+ self.assertEqual(dialog.result(), qt.QDialog.Rejected)
+
+ def testSelectRoot_Activate(self):
+ dialog = self.createDialog()
+ browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+ self.assertTrue(dialog.isVisible())
+ filename = _tmpDirectory + "/data.h5"
+ dialog.selectFile(os.path.dirname(filename))
+ self.qWaitForPendingActions(dialog)
+
+ # select, then double click on the file
+ index = browser.rootIndex().model().index(filename)
+ browser.selectIndex(index)
+ browser.activated.emit(index)
+ self.qWaitForPendingActions(dialog)
+
+ button = testutils.findChildren(dialog, qt.QPushButton, name="open")[0]
+ self.assertTrue(button.isEnabled())
+ self.mouseClick(button, qt.Qt.LeftButton)
+ url = silx.io.url.DataUrl(dialog.selectedUrl())
+ self.assertTrue(url.data_path() is not None)
+ self.assertFalse(dialog.isVisible())
+ self.assertEqual(dialog.result(), qt.QDialog.Accepted)
+
+ def testSelectGroup_Activate(self):
+ dialog = self.createDialog()
+ browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+ self.assertTrue(dialog.isVisible())
+ filename = _tmpDirectory + "/data.h5"
+ dialog.selectFile(os.path.dirname(filename))
+ self.qWaitForPendingActions(dialog)
+
+ # select, then double click on the file
+ index = browser.rootIndex().model().index(filename)
+ browser.selectIndex(index)
+ browser.activated.emit(index)
+ self.qWaitForPendingActions(dialog)
+
+ # select, then double click on the file
+ index = browser.rootIndex().model().indexFromH5Object(dialog._AbstractDataFileDialog__h5["/group"])
+ browser.selectIndex(index)
+ browser.activated.emit(index)
+ self.qWaitForPendingActions(dialog)
+
+ button = testutils.findChildren(dialog, qt.QPushButton, name="open")[0]
+ self.assertTrue(button.isEnabled())
+ self.mouseClick(button, qt.Qt.LeftButton)
+ url = silx.io.url.DataUrl(dialog.selectedUrl())
+ self.assertEqual(url.data_path(), "/group")
+ self.assertFalse(dialog.isVisible())
+ self.assertEqual(dialog.result(), qt.QDialog.Accepted)
+
+ def testSelectDataset_Activate(self):
+ dialog = self.createDialog()
+ browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+ self.assertTrue(dialog.isVisible())
+ filename = _tmpDirectory + "/data.h5"
+ dialog.selectFile(os.path.dirname(filename))
+ self.qWaitForPendingActions(dialog)
+
+ # select, then double click on the file
+ index = browser.rootIndex().model().index(filename)
+ browser.selectIndex(index)
+ browser.activated.emit(index)
+ self.qWaitForPendingActions(dialog)
+
+ # select, then double click on the file
+ index = browser.rootIndex().model().indexFromH5Object(dialog._AbstractDataFileDialog__h5["/scalar"])
+ browser.selectIndex(index)
+ browser.activated.emit(index)
+ self.qWaitForPendingActions(dialog)
+
+ button = testutils.findChildren(dialog, qt.QPushButton, name="open")[0]
+ self.assertTrue(button.isEnabled())
+ self.mouseClick(button, qt.Qt.LeftButton)
+ url = silx.io.url.DataUrl(dialog.selectedUrl())
+ self.assertEqual(url.data_path(), "/scalar")
+ self.assertFalse(dialog.isVisible())
+ self.assertEqual(dialog.result(), qt.QDialog.Accepted)
+
+ def testClickOnBackToParentTool(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+
+ url = testutils.findChildren(dialog, qt.QLineEdit, name="url")[0]
+ action = testutils.findChildren(dialog, qt.QAction, name="toParentAction")[0]
+ toParentButton = testutils.getQToolButtonFromAction(action)
+ filename = _tmpDirectory + "/data/data.h5"
+
+ # init state
+ path = silx.io.url.DataUrl(file_path=filename, data_path="/group/image").path()
+ dialog.selectUrl(path)
+ self.qWaitForPendingActions(dialog)
+ path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/group/image").path()
+ self.assertSamePath(url.text(), path)
+ # test
+ self.mouseClick(toParentButton, qt.Qt.LeftButton)
+ self.qWaitForPendingActions(dialog)
+ path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/").path()
+ self.assertSamePath(url.text(), path)
+
+ self.mouseClick(toParentButton, qt.Qt.LeftButton)
+ self.qWaitForPendingActions(dialog)
+ self.assertSamePath(url.text(), _tmpDirectory + "/data")
+
+ self.mouseClick(toParentButton, qt.Qt.LeftButton)
+ self.qWaitForPendingActions(dialog)
+ self.assertSamePath(url.text(), _tmpDirectory)
+
+ def testClickOnBackToRootTool(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+
+ url = testutils.findChildren(dialog, qt.QLineEdit, name="url")[0]
+ action = testutils.findChildren(dialog, qt.QAction, name="toRootFileAction")[0]
+ button = testutils.getQToolButtonFromAction(action)
+ filename = _tmpDirectory + "/data.h5"
+
+ # init state
+ path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/group/image").path()
+ dialog.selectUrl(path)
+ self.qWaitForPendingActions(dialog)
+ self.assertSamePath(url.text(), path)
+ self.assertTrue(button.isEnabled())
+ # test
+ self.mouseClick(button, qt.Qt.LeftButton)
+ self.qWaitForPendingActions(dialog)
+ path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/").path()
+ self.assertSamePath(url.text(), path)
+ # self.assertFalse(button.isEnabled())
+
+ def testClickOnBackToDirectoryTool(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+
+ url = testutils.findChildren(dialog, qt.QLineEdit, name="url")[0]
+ action = testutils.findChildren(dialog, qt.QAction, name="toDirectoryAction")[0]
+ button = testutils.getQToolButtonFromAction(action)
+ filename = _tmpDirectory + "/data.h5"
+
+ # init state
+ path = silx.io.url.DataUrl(file_path=filename, data_path="/group/image").path()
+ dialog.selectUrl(path)
+ self.qWaitForPendingActions(dialog)
+ path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/group/image").path()
+ self.assertSamePath(url.text(), path)
+ self.assertTrue(button.isEnabled())
+ # test
+ self.mouseClick(button, qt.Qt.LeftButton)
+ self.qWaitForPendingActions(dialog)
+ self.assertSamePath(url.text(), _tmpDirectory)
+ self.assertFalse(button.isEnabled())
+
+ # FIXME: There is an unreleased qt.QWidget without nameObject
+ # No idea where it come from.
+ self.allowedLeakingWidgets = 1
+
+ def testClickOnHistoryTools(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+
+ url = testutils.findChildren(dialog, qt.QLineEdit, name="url")[0]
+ forwardAction = testutils.findChildren(dialog, qt.QAction, name="forwardAction")[0]
+ backwardAction = testutils.findChildren(dialog, qt.QAction, name="backwardAction")[0]
+ filename = _tmpDirectory + "/data.h5"
+
+ dialog.setDirectory(_tmpDirectory)
+ self.qWaitForPendingActions(dialog)
+ # No way to use QTest.mouseDClick with QListView, QListWidget
+ # Then we feed the history using selectPath
+ dialog.selectUrl(filename)
+ self.qWaitForPendingActions(dialog)
+ path2 = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/").path()
+ dialog.selectUrl(path2)
+ self.qWaitForPendingActions(dialog)
+ path3 = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/group").path()
+ dialog.selectUrl(path3)
+ self.qWaitForPendingActions(dialog)
+ self.assertFalse(forwardAction.isEnabled())
+ self.assertTrue(backwardAction.isEnabled())
+
+ button = testutils.getQToolButtonFromAction(backwardAction)
+ self.mouseClick(button, qt.Qt.LeftButton)
+ self.qWaitForPendingActions(dialog)
+ self.assertTrue(forwardAction.isEnabled())
+ self.assertTrue(backwardAction.isEnabled())
+ self.assertSamePath(url.text(), path2)
+
+ button = testutils.getQToolButtonFromAction(forwardAction)
+ self.mouseClick(button, qt.Qt.LeftButton)
+ self.qWaitForPendingActions(dialog)
+ self.assertFalse(forwardAction.isEnabled())
+ self.assertTrue(backwardAction.isEnabled())
+ self.assertSamePath(url.text(), path3)
+
+ def testSelectImageFromEdf(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+
+ # init state
+ filename = _tmpDirectory + "/singleimage.edf"
+ url = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/scan_0/instrument/detector_0/data")
+ dialog.selectUrl(url.path())
+ self.assertEqual(dialog._selectedData().shape, (100, 100))
+ self.assertSamePath(dialog.selectedFile(), filename)
+ self.assertSamePath(dialog.selectedUrl(), url.path())
+
+ def testSelectImage(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+
+ # init state
+ filename = _tmpDirectory + "/data.h5"
+ path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/image").path()
+ dialog.selectUrl(path)
+ # test
+ self.assertEqual(dialog._selectedData().shape, (100, 100))
+ self.assertSamePath(dialog.selectedFile(), filename)
+ self.assertSamePath(dialog.selectedUrl(), path)
+
+ def testSelectScalar(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+
+ # init state
+ filename = _tmpDirectory + "/data.h5"
+ path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/scalar").path()
+ dialog.selectUrl(path)
+ # test
+ self.assertEqual(dialog._selectedData()[()], 10)
+ self.assertSamePath(dialog.selectedFile(), filename)
+ self.assertSamePath(dialog.selectedUrl(), path)
+
+ def testSelectGroup(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+
+ # init state
+ filename = _tmpDirectory + "/data.h5"
+ uri = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/group")
+ dialog.selectUrl(uri.path())
+ self.qWaitForPendingActions(dialog)
+ # test
+ self.assertTrue(silx.io.is_group(dialog._selectedData()))
+ self.assertSamePath(dialog.selectedFile(), filename)
+ uri = silx.io.url.DataUrl(dialog.selectedUrl())
+ self.assertSamePath(uri.data_path(), "/group")
+
+ def testSelectRoot(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+
+ # init state
+ filename = _tmpDirectory + "/data.h5"
+ uri = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/")
+ dialog.selectUrl(uri.path())
+ self.qWaitForPendingActions(dialog)
+ # test
+ self.assertTrue(silx.io.is_file(dialog._selectedData()))
+ self.assertSamePath(dialog.selectedFile(), filename)
+ uri = silx.io.url.DataUrl(dialog.selectedUrl())
+ self.assertSamePath(uri.data_path(), "/")
+
+ def testSelectH5_Activate(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+
+ # init state
+ dialog.selectUrl(_tmpDirectory)
+ self.qWaitForPendingActions(dialog)
+ browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
+ filename = _tmpDirectory + "/data.h5"
+ path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/").path()
+ index = browser.rootIndex().model().index(filename)
+ # click
+ browser.selectIndex(index)
+ # double click
+ browser.activated.emit(index)
+ self.qWaitForPendingActions(dialog)
+ # test
+ self.assertSamePath(dialog.selectedUrl(), path)
+
+ def testSelectBadFileFormat_Activate(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+
+ # init state
+ dialog.selectUrl(_tmpDirectory)
+ self.qWaitForPendingActions(dialog)
+ browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
+ filename = _tmpDirectory + "/badformat.h5"
+ index = browser.model().index(filename)
+ browser.selectIndex(index)
+ browser.activated.emit(index)
+ self.qWaitForPendingActions(dialog)
+ # test
+ self.assertSamePath(dialog.selectedUrl(), filename)
+
+ def _countSelectableItems(self, model, rootIndex):
+ selectable = 0
+ for i in range(model.rowCount(rootIndex)):
+ index = model.index(i, 0, rootIndex)
+ flags = model.flags(index)
+ isEnabled = (int(flags) & qt.Qt.ItemIsEnabled) != 0
+ if isEnabled:
+ selectable += 1
+ return selectable
+
+ def testFilterExtensions(self):
+ dialog = self.createDialog()
+ browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+ dialog.selectUrl(_tmpDirectory)
+ self.qWaitForPendingActions(dialog)
+ self.assertEqual(self._countSelectableItems(browser.model(), browser.rootIndex()), 4)
+
+
+class TestDataFileDialog_FilterDataset(testutils.TestCaseQt, _UtilsMixin):
+
+ def tearDown(self):
+ self._deleteDialog()
+ testutils.TestCaseQt.tearDown(self)
+
+ def _createDialog(self):
+ dialog = DataFileDialog()
+ dialog.setFilterMode(DataFileDialog.FilterMode.ExistingDataset)
+ return dialog
+
+ def testSelectGroup_Activate(self):
+ dialog = self.createDialog()
+ browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+ self.assertTrue(dialog.isVisible())
+ filename = _tmpDirectory + "/data.h5"
+ dialog.selectFile(os.path.dirname(filename))
+ self.qWaitForPendingActions(dialog)
+
+ # select, then double click on the file
+ index = browser.rootIndex().model().index(filename)
+ browser.selectIndex(index)
+ browser.activated.emit(index)
+ self.qWaitForPendingActions(dialog)
+
+ # select, then double click on the file
+ index = browser.rootIndex().model().indexFromH5Object(dialog._AbstractDataFileDialog__h5["/group"])
+ browser.selectIndex(index)
+ browser.activated.emit(index)
+ self.qWaitForPendingActions(dialog)
+
+ button = testutils.findChildren(dialog, qt.QPushButton, name="open")[0]
+ self.assertFalse(button.isEnabled())
+
+ def testSelectDataset_Activate(self):
+ dialog = self.createDialog()
+ browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+ self.assertTrue(dialog.isVisible())
+ filename = _tmpDirectory + "/data.h5"
+ dialog.selectFile(os.path.dirname(filename))
+ self.qWaitForPendingActions(dialog)
+
+ # select, then double click on the file
+ index = browser.rootIndex().model().index(filename)
+ browser.selectIndex(index)
+ browser.activated.emit(index)
+ self.qWaitForPendingActions(dialog)
+
+ # select, then double click on the file
+ index = browser.rootIndex().model().indexFromH5Object(dialog._AbstractDataFileDialog__h5["/scalar"])
+ browser.selectIndex(index)
+ browser.activated.emit(index)
+ self.qWaitForPendingActions(dialog)
+
+ button = testutils.findChildren(dialog, qt.QPushButton, name="open")[0]
+ self.assertTrue(button.isEnabled())
+ self.mouseClick(button, qt.Qt.LeftButton)
+ url = silx.io.url.DataUrl(dialog.selectedUrl())
+ self.assertEqual(url.data_path(), "/scalar")
+ self.assertFalse(dialog.isVisible())
+ self.assertEqual(dialog.result(), qt.QDialog.Accepted)
+
+ data = dialog.selectedData()
+ self.assertEqual(data, 10)
+
+
+class TestDataFileDialog_FilterGroup(testutils.TestCaseQt, _UtilsMixin):
+
+ def tearDown(self):
+ self._deleteDialog()
+ testutils.TestCaseQt.tearDown(self)
+
+ def _createDialog(self):
+ dialog = DataFileDialog()
+ dialog.setFilterMode(DataFileDialog.FilterMode.ExistingGroup)
+ return dialog
+
+ def testSelectGroup_Activate(self):
+ dialog = self.createDialog()
+ browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+ self.assertTrue(dialog.isVisible())
+ filename = _tmpDirectory + "/data.h5"
+ dialog.selectFile(os.path.dirname(filename))
+ self.qWaitForPendingActions(dialog)
+
+ # select, then double click on the file
+ index = browser.rootIndex().model().index(filename)
+ browser.selectIndex(index)
+ browser.activated.emit(index)
+ self.qWaitForPendingActions(dialog)
+
+ # select, then double click on the file
+ index = browser.rootIndex().model().indexFromH5Object(dialog._AbstractDataFileDialog__h5["/group"])
+ browser.selectIndex(index)
+ browser.activated.emit(index)
+ self.qWaitForPendingActions(dialog)
+
+ button = testutils.findChildren(dialog, qt.QPushButton, name="open")[0]
+ self.assertTrue(button.isEnabled())
+ self.mouseClick(button, qt.Qt.LeftButton)
+ url = silx.io.url.DataUrl(dialog.selectedUrl())
+ self.assertEqual(url.data_path(), "/group")
+ self.assertFalse(dialog.isVisible())
+ self.assertEqual(dialog.result(), qt.QDialog.Accepted)
+
+ self.assertRaises(Exception, dialog.selectedData)
+
+ def testSelectDataset_Activate(self):
+ dialog = self.createDialog()
+ browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+ self.assertTrue(dialog.isVisible())
+ filename = _tmpDirectory + "/data.h5"
+ dialog.selectFile(os.path.dirname(filename))
+ self.qWaitForPendingActions(dialog)
+
+ # select, then double click on the file
+ index = browser.rootIndex().model().index(filename)
+ browser.selectIndex(index)
+ browser.activated.emit(index)
+ self.qWaitForPendingActions(dialog)
+
+ # select, then double click on the file
+ index = browser.rootIndex().model().indexFromH5Object(dialog._AbstractDataFileDialog__h5["/scalar"])
+ browser.selectIndex(index)
+ browser.activated.emit(index)
+ self.qWaitForPendingActions(dialog)
+
+ button = testutils.findChildren(dialog, qt.QPushButton, name="open")[0]
+ self.assertFalse(button.isEnabled())
+
+
+class TestDataFileDialog_FilterNXdata(testutils.TestCaseQt, _UtilsMixin):
+
+ def tearDown(self):
+ self._deleteDialog()
+ testutils.TestCaseQt.tearDown(self)
+
+ def _createDialog(self):
+ def customFilter(obj):
+ if "NX_class" in obj.attrs:
+ return obj.attrs["NX_class"] == u"NXdata"
+ return False
+
+ dialog = DataFileDialog()
+ dialog.setFilterMode(DataFileDialog.FilterMode.ExistingGroup)
+ dialog.setFilterCallback(customFilter)
+ return dialog
+
+ def testSelectGroupRefused_Activate(self):
+ dialog = self.createDialog()
+ browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+ self.assertTrue(dialog.isVisible())
+ filename = _tmpDirectory + "/data.h5"
+ dialog.selectFile(os.path.dirname(filename))
+ self.qWaitForPendingActions(dialog)
+
+ # select, then double click on the file
+ index = browser.rootIndex().model().index(filename)
+ browser.selectIndex(index)
+ browser.activated.emit(index)
+ self.qWaitForPendingActions(dialog)
+
+ # select, then double click on the file
+ index = browser.rootIndex().model().indexFromH5Object(dialog._AbstractDataFileDialog__h5["/group"])
+ browser.selectIndex(index)
+ browser.activated.emit(index)
+ self.qWaitForPendingActions(dialog)
+
+ button = testutils.findChildren(dialog, qt.QPushButton, name="open")[0]
+ self.assertFalse(button.isEnabled())
+
+ self.assertRaises(Exception, dialog.selectedData)
+
+ def testSelectNXdataAccepted_Activate(self):
+ dialog = self.createDialog()
+ browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+ self.assertTrue(dialog.isVisible())
+ filename = _tmpDirectory + "/data.h5"
+ dialog.selectFile(os.path.dirname(filename))
+ self.qWaitForPendingActions(dialog)
+
+ # select, then double click on the file
+ index = browser.rootIndex().model().index(filename)
+ browser.selectIndex(index)
+ browser.activated.emit(index)
+ self.qWaitForPendingActions(dialog)
+
+ # select, then double click on the file
+ index = browser.rootIndex().model().indexFromH5Object(dialog._AbstractDataFileDialog__h5["/nxdata"])
+ browser.selectIndex(index)
+ browser.activated.emit(index)
+ self.qWaitForPendingActions(dialog)
+
+ button = testutils.findChildren(dialog, qt.QPushButton, name="open")[0]
+ self.assertTrue(button.isEnabled())
+ self.mouseClick(button, qt.Qt.LeftButton)
+ url = silx.io.url.DataUrl(dialog.selectedUrl())
+ self.assertEqual(url.data_path(), "/nxdata")
+ self.assertFalse(dialog.isVisible())
+ self.assertEqual(dialog.result(), qt.QDialog.Accepted)
+
+
+class TestDataFileDialogApi(testutils.TestCaseQt, _UtilsMixin):
+
+ def tearDown(self):
+ self._deleteDialog()
+ testutils.TestCaseQt.tearDown(self)
+
+ def _createDialog(self):
+ dialog = DataFileDialog()
+ return dialog
+
+ def testSaveRestoreState(self):
+ dialog = self.createDialog()
+ dialog.setDirectory(_tmpDirectory)
+ self.qWaitForPendingActions(dialog)
+ state = dialog.saveState()
+ dialog = None
+
+ dialog2 = self.createDialog()
+ result = dialog2.restoreState(state)
+ self.assertTrue(result)
+ dialog2 = None
+
+ def printState(self):
+ """
+ Print state of the ImageFileDialog.
+
+ Can be used to add or regenerate `STATE_VERSION1_QT4` or
+ `STATE_VERSION1_QT5`.
+
+ >>> ./run_tests.py -v silx.gui.dialog.test.test_datafiledialog.TestDataFileDialogApi.printState
+ """
+ dialog = self.createDialog()
+ dialog.setDirectory("")
+ dialog.setHistory([])
+ dialog.setSidebarUrls([])
+ state = dialog.saveState()
+ string = ""
+ strings = []
+ for i in range(state.size()):
+ d = state.data()[i]
+ if not isinstance(d, int):
+ d = ord(d)
+ if d > 0x20 and d < 0x7F:
+ string += chr(d)
+ else:
+ string += "\\x%02X" % d
+ if len(string) > 60:
+ strings.append(string)
+ string = ""
+ strings.append(string)
+ strings = ["b'%s'" % s for s in strings]
+ print()
+ print("\\\n".join(strings))
+
+ STATE_VERSION1_QT4 = b''\
+ b'\x00\x00\x00Z\x00s\x00i\x00l\x00x\x00.\x00g\x00u\x00i\x00.\x00'\
+ b'd\x00i\x00a\x00l\x00o\x00g\x00.\x00D\x00a\x00t\x00a\x00F\x00i'\
+ b'\x00l\x00e\x00D\x00i\x00a\x00l\x00o\x00g\x00.\x00D\x00a\x00t\x00'\
+ b'a\x00F\x00i\x00l\x00e\x00D\x00i\x00a\x00l\x00o\x00g\x00\x00\x00'\
+ b'\x01\x00\x00\x00\x0C\x00\x00\x00\x00"\x00\x00\x00\xFF\x00\x00'\
+ b'\x00\x00\x00\x00\x00\x03\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF'\
+ b'\xFF\xFF\x01\x00\x00\x00\x06\x01\x00\x00\x00\x01\x00\x00\x00\x00'\
+ b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0C\x00\x00\x00\x00'\
+ b'}\x00\x00\x00\x0E\x00B\x00r\x00o\x00w\x00s\x00e\x00r\x00\x00\x00'\
+ b'\x01\x00\x00\x00\x0C\x00\x00\x00\x00Z\x00\x00\x00\xFF\x00\x00'\
+ b'\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00'\
+ b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'\
+ b'\x00\x01\x90\x00\x00\x00\x04\x01\x01\x00\x00\x00\x00\x00\x00\x00'\
+ b'\x00\x00\x00\x00\x00\x00\x00d\xFF\xFF\xFF\xFF\x00\x00\x00\x81'\
+ b'\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x01\x90\x00\x00\x00\x04'\
+ b'\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x00'\
+ b'\x01\xFF\xFF\xFF\xFF'
+ """Serialized state on Qt4. Generated using :meth:`printState`"""
+
+ STATE_VERSION1_QT5 = b''\
+ b'\x00\x00\x00Z\x00s\x00i\x00l\x00x\x00.\x00g\x00u\x00i\x00.\x00'\
+ b'd\x00i\x00a\x00l\x00o\x00g\x00.\x00D\x00a\x00t\x00a\x00F\x00i'\
+ b'\x00l\x00e\x00D\x00i\x00a\x00l\x00o\x00g\x00.\x00D\x00a\x00t\x00'\
+ b'a\x00F\x00i\x00l\x00e\x00D\x00i\x00a\x00l\x00o\x00g\x00\x00\x00'\
+ b'\x01\x00\x00\x00\x0C\x00\x00\x00\x00#\x00\x00\x00\xFF\x00\x00'\
+ b'\x00\x01\x00\x00\x00\x03\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF'\
+ b'\xFF\xFF\x01\xFF\xFF\xFF\xFF\x01\x00\x00\x00\x01\x00\x00\x00\x00'\
+ b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0C\x00\x00\x00'\
+ b'\x00\xAA\x00\x00\x00\x0E\x00B\x00r\x00o\x00w\x00s\x00e\x00r\x00'\
+ b'\x00\x00\x01\x00\x00\x00\x0C\x00\x00\x00\x00\x87\x00\x00\x00\xFF'\
+ b'\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00'\
+ b'\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'\
+ b'\x00\x00\x00\x01\x90\x00\x00\x00\x04\x01\x01\x00\x00\x00\x00\x00'\
+ b'\x00\x00\x00\x00\x00\x00\x00\x00\x00d\xFF\xFF\xFF\xFF\x00\x00'\
+ b'\x00\x81\x00\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00d\x00\x00'\
+ b'\x00\x01\x00\x00\x00\x00\x00\x00\x00d\x00\x00\x00\x01\x00\x00'\
+ b'\x00\x00\x00\x00\x00d\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00'\
+ b'\x00d\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x03\xE8\x00\xFF'\
+ b'\xFF\xFF\xFF\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x00\x01'
+ """Serialized state on Qt5. Generated using :meth:`printState`"""
+
+ def testAvoidRestoreRegression_Version1(self):
+ version = qt.qVersion().split(".")[0]
+ if version == "4":
+ state = self.STATE_VERSION1_QT4
+ elif version == "5":
+ state = self.STATE_VERSION1_QT5
+ else:
+ self.skipTest("Resource not available")
+
+ state = qt.QByteArray(state)
+ dialog = self.createDialog()
+ result = dialog.restoreState(state)
+ self.assertTrue(result)
+
+ def testRestoreRobusness(self):
+ """What's happen if you try to open a config file with a different
+ binding."""
+ state = qt.QByteArray(self.STATE_VERSION1_QT4)
+ dialog = self.createDialog()
+ dialog.restoreState(state)
+ state = qt.QByteArray(self.STATE_VERSION1_QT5)
+ dialog = None
+ dialog = self.createDialog()
+ dialog.restoreState(state)
+
+ def testRestoreNonExistingDirectory(self):
+ directory = os.path.join(_tmpDirectory, "dir")
+ os.mkdir(directory)
+ dialog = self.createDialog()
+ dialog.setDirectory(directory)
+ self.qWaitForPendingActions(dialog)
+ state = dialog.saveState()
+ os.rmdir(directory)
+ dialog = None
+
+ dialog2 = self.createDialog()
+ result = dialog2.restoreState(state)
+ self.assertTrue(result)
+ self.assertNotEqual(dialog2.directory(), directory)
+
+ def testHistory(self):
+ dialog = self.createDialog()
+ history = dialog.history()
+ dialog.setHistory([])
+ self.assertEqual(dialog.history(), [])
+ dialog.setHistory(history)
+ self.assertEqual(dialog.history(), history)
+
+ def testSidebarUrls(self):
+ dialog = self.createDialog()
+ urls = dialog.sidebarUrls()
+ dialog.setSidebarUrls([])
+ self.assertEqual(dialog.sidebarUrls(), [])
+ dialog.setSidebarUrls(urls)
+ self.assertEqual(dialog.sidebarUrls(), urls)
+
+ def testDirectory(self):
+ dialog = self.createDialog()
+ self.qWaitForPendingActions(dialog)
+ dialog.selectUrl(_tmpDirectory)
+ self.qWaitForPendingActions(dialog)
+ self.assertSamePath(dialog.directory(), _tmpDirectory)
+
+ def testBadFileFormat(self):
+ dialog = self.createDialog()
+ dialog.selectUrl(_tmpDirectory + "/badformat.h5")
+ self.qWaitForPendingActions(dialog)
+ self.assertIsNone(dialog._selectedData())
+
+ def testBadPath(self):
+ dialog = self.createDialog()
+ dialog.selectUrl("#$%/#$%")
+ self.qWaitForPendingActions(dialog)
+ self.assertIsNone(dialog._selectedData())
+
+ def testBadSubpath(self):
+ dialog = self.createDialog()
+ self.qWaitForPendingActions(dialog)
+
+ browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
+
+ filename = _tmpDirectory + "/data.h5"
+ url = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/group/foobar")
+ dialog.selectUrl(url.path())
+ self.qWaitForPendingActions(dialog)
+ self.assertIsNotNone(dialog._selectedData())
+
+ # an existing node is browsed, but the wrong path is selected
+ index = browser.rootIndex()
+ obj = index.model().data(index, role=Hdf5TreeModel.H5PY_OBJECT_ROLE)
+ self.assertEqual(obj.name, "/group")
+ url = silx.io.url.DataUrl(dialog.selectedUrl())
+ self.assertEqual(url.data_path(), "/group")
+
+ def testUnsupportedSlicingPath(self):
+ dialog = self.createDialog()
+ self.qWaitForPendingActions(dialog)
+ dialog.selectUrl(_tmpDirectory + "/data.h5?path=/cube&slice=0")
+ self.qWaitForPendingActions(dialog)
+ data = dialog._selectedData()
+ if data is None:
+ # Maybe nothing is selected
+ self.assertTrue(True)
+ else:
+ # Maybe the cube is selected but not sliced
+ self.assertEqual(len(data.shape), 3)
diff --git a/src/silx/gui/dialog/test/test_imagefiledialog.py b/src/silx/gui/dialog/test/test_imagefiledialog.py
new file mode 100644
index 0000000..9e204b9
--- /dev/null
+++ b/src/silx/gui/dialog/test/test_imagefiledialog.py
@@ -0,0 +1,772 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test for silx.gui.hdf5 module"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "08/03/2019"
+
+
+import unittest
+import tempfile
+import numpy
+import shutil
+import os
+import io
+import weakref
+import fabio
+import h5py
+import silx.io.url
+from silx.gui import qt
+from silx.gui.utils import testutils
+from ..ImageFileDialog import ImageFileDialog
+from silx.gui.colors import Colormap
+from silx.gui.hdf5 import Hdf5TreeModel
+
+_tmpDirectory = None
+
+
+def setUpModule():
+ global _tmpDirectory
+ _tmpDirectory = tempfile.mkdtemp(prefix=__name__)
+
+ data = numpy.arange(100 * 100)
+ data.shape = 100, 100
+
+ filename = _tmpDirectory + "/singleimage.edf"
+ image = fabio.edfimage.EdfImage(data=data)
+ image.write(filename)
+
+ filename = _tmpDirectory + "/multiframe.edf"
+ image = fabio.edfimage.EdfImage(data=data)
+ image.append_frame(data=data + 1)
+ image.append_frame(data=data + 2)
+ image.write(filename)
+
+ filename = _tmpDirectory + "/singleimage.msk"
+ image = fabio.fit2dmaskimage.Fit2dMaskImage(data=data % 2 == 1)
+ image.write(filename)
+
+ filename = _tmpDirectory + "/data.h5"
+ with h5py.File(filename, "w") as f:
+ f["scalar"] = 10
+ f["image"] = data
+ f["cube"] = [data, data + 1, data + 2]
+ f["single_frame"] = [data + 5]
+ f["complex_image"] = data * 1j
+ f["group/image"] = data
+
+ directory = os.path.join(_tmpDirectory, "data")
+ os.mkdir(directory)
+ filename = os.path.join(directory, "data.h5")
+ with h5py.File(filename, "w") as f:
+ f["scalar"] = 10
+ f["image"] = data
+ f["cube"] = [data, data + 1, data + 2]
+ f["single_frame"] = [data + 5]
+ f["complex_image"] = data * 1j
+ f["group/image"] = data
+
+ filename = _tmpDirectory + "/badformat.edf"
+ with io.open(filename, "wb") as f:
+ f.write(b"{\nHello Nurse!")
+
+
+def tearDownModule():
+ global _tmpDirectory
+ shutil.rmtree(_tmpDirectory)
+ _tmpDirectory = None
+
+
+class _UtilsMixin(object):
+
+ def createDialog(self):
+ self._deleteDialog()
+ self._dialog = self._createDialog()
+ return self._dialog
+
+ def _createDialog(self):
+ return ImageFileDialog()
+
+ def _deleteDialog(self):
+ if not hasattr(self, "_dialog"):
+ return
+ if self._dialog is not None:
+ ref = weakref.ref(self._dialog)
+ self._dialog = None
+ self.qWaitForDestroy(ref)
+
+ def qWaitForPendingActions(self, dialog):
+ for _ in range(20):
+ if not dialog.hasPendingEvents():
+ return
+ self.qWait(10)
+ raise RuntimeError("Still have pending actions")
+
+ def assertSamePath(self, path1, path2):
+ path1_ = os.path.normcase(path1)
+ path2_ = os.path.normcase(path2)
+ if path1_ != path2_:
+ # Use the unittest API to log and display error
+ self.assertEqual(path1, path2)
+
+ def assertNotSamePath(self, path1, path2):
+ path1_ = os.path.normcase(path1)
+ path2_ = os.path.normcase(path2)
+ if path1_ == path2_:
+ # Use the unittest API to log and display error
+ self.assertNotEqual(path1, path2)
+
+
+class TestImageFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
+
+ def tearDown(self):
+ self._deleteDialog()
+ testutils.TestCaseQt.tearDown(self)
+
+ def testDisplayAndKeyEscape(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+ self.assertTrue(dialog.isVisible())
+
+ self.keyClick(dialog, qt.Qt.Key_Escape)
+ self.assertFalse(dialog.isVisible())
+ self.assertEqual(dialog.result(), qt.QDialog.Rejected)
+
+ def testDisplayAndClickCancel(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+ self.assertTrue(dialog.isVisible())
+
+ button = testutils.findChildren(dialog, qt.QPushButton, name="cancel")[0]
+ self.mouseClick(button, qt.Qt.LeftButton)
+ self.assertFalse(dialog.isVisible())
+ self.assertFalse(dialog.isVisible())
+ self.assertEqual(dialog.result(), qt.QDialog.Rejected)
+
+ def testDisplayAndClickLockedOpen(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+ self.assertTrue(dialog.isVisible())
+
+ button = testutils.findChildren(dialog, qt.QPushButton, name="open")[0]
+ self.mouseClick(button, qt.Qt.LeftButton)
+ # open button locked, dialog is not closed
+ self.assertTrue(dialog.isVisible())
+ self.assertEqual(dialog.result(), qt.QDialog.Rejected)
+
+ def testDisplayAndClickOpen(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+ self.assertTrue(dialog.isVisible())
+ filename = _tmpDirectory + "/singleimage.edf"
+ dialog.selectFile(filename)
+ self.qWaitForPendingActions(dialog)
+
+ button = testutils.findChildren(dialog, qt.QPushButton, name="open")[0]
+ self.assertTrue(button.isEnabled())
+ self.mouseClick(button, qt.Qt.LeftButton)
+ self.assertFalse(dialog.isVisible())
+ self.assertEqual(dialog.result(), qt.QDialog.Accepted)
+
+ def testClickOnShortcut(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+
+ sidebar = testutils.findChildren(dialog, qt.QListView, name="sidebar")[0]
+ url = testutils.findChildren(dialog, qt.QLineEdit, name="url")[0]
+ browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
+ dialog.setDirectory(_tmpDirectory)
+ self.qWaitForPendingActions(dialog)
+
+ self.assertSamePath(url.text(), _tmpDirectory)
+
+ urls = sidebar.urls()
+ if len(urls) == 0:
+ self.skipTest("No sidebar path")
+ path = urls[0].path()
+ if path != "" and not os.path.exists(path):
+ self.skipTest("Sidebar path do not exists")
+
+ index = sidebar.model().index(0, 0)
+ # rect = sidebar.visualRect(index)
+ # self.mouseClick(sidebar, qt.Qt.LeftButton, pos=rect.center())
+ # Using mouse click is not working, let's use the selection API
+ sidebar.selectionModel().select(index, qt.QItemSelectionModel.ClearAndSelect)
+ self.qWaitForPendingActions(dialog)
+
+ index = browser.rootIndex()
+ if not index.isValid():
+ path = ""
+ else:
+ path = index.model().filePath(index)
+ self.assertNotSamePath(_tmpDirectory, path)
+ self.assertNotSamePath(url.text(), _tmpDirectory)
+
+ def testClickOnDetailView(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+
+ action = testutils.findChildren(dialog, qt.QAction, name="detailModeAction")[0]
+ detailModeButton = testutils.getQToolButtonFromAction(action)
+ self.mouseClick(detailModeButton, qt.Qt.LeftButton)
+ self.assertEqual(dialog.viewMode(), qt.QFileDialog.Detail)
+
+ action = testutils.findChildren(dialog, qt.QAction, name="listModeAction")[0]
+ listModeButton = testutils.getQToolButtonFromAction(action)
+ self.mouseClick(listModeButton, qt.Qt.LeftButton)
+ self.assertEqual(dialog.viewMode(), qt.QFileDialog.List)
+
+ def testClickOnBackToParentTool(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+
+ url = testutils.findChildren(dialog, qt.QLineEdit, name="url")[0]
+ action = testutils.findChildren(dialog, qt.QAction, name="toParentAction")[0]
+ toParentButton = testutils.getQToolButtonFromAction(action)
+ filename = _tmpDirectory + "/data/data.h5"
+
+ # init state
+ path = silx.io.url.DataUrl(file_path=filename, data_path="/group/image").path()
+ dialog.selectUrl(path)
+ self.qWaitForPendingActions(dialog)
+ path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/group/image").path()
+ self.assertSamePath(url.text(), path)
+ # test
+ self.mouseClick(toParentButton, qt.Qt.LeftButton)
+ self.qWaitForPendingActions(dialog)
+ path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/").path()
+ self.assertSamePath(url.text(), path)
+
+ self.mouseClick(toParentButton, qt.Qt.LeftButton)
+ self.qWaitForPendingActions(dialog)
+ self.assertSamePath(url.text(), _tmpDirectory + "/data")
+
+ self.mouseClick(toParentButton, qt.Qt.LeftButton)
+ self.qWaitForPendingActions(dialog)
+ self.assertSamePath(url.text(), _tmpDirectory)
+
+ def testClickOnBackToRootTool(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+
+ url = testutils.findChildren(dialog, qt.QLineEdit, name="url")[0]
+ action = testutils.findChildren(dialog, qt.QAction, name="toRootFileAction")[0]
+ button = testutils.getQToolButtonFromAction(action)
+ filename = _tmpDirectory + "/data.h5"
+
+ # init state
+ path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/group/image").path()
+ dialog.selectUrl(path)
+ self.qWaitForPendingActions(dialog)
+ self.assertSamePath(url.text(), path)
+ self.assertTrue(button.isEnabled())
+ # test
+ self.mouseClick(button, qt.Qt.LeftButton)
+ self.qWaitForPendingActions(dialog)
+ path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/").path()
+ self.assertSamePath(url.text(), path)
+ # self.assertFalse(button.isEnabled())
+
+ def testClickOnBackToDirectoryTool(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+
+ url = testutils.findChildren(dialog, qt.QLineEdit, name="url")[0]
+ action = testutils.findChildren(dialog, qt.QAction, name="toDirectoryAction")[0]
+ button = testutils.getQToolButtonFromAction(action)
+ filename = _tmpDirectory + "/data.h5"
+
+ # init state
+ path = silx.io.url.DataUrl(file_path=filename, data_path="/group/image").path()
+ dialog.selectUrl(path)
+ self.qWaitForPendingActions(dialog)
+ path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/group/image").path()
+ self.assertSamePath(url.text(), path)
+ self.assertTrue(button.isEnabled())
+ # test
+ self.mouseClick(button, qt.Qt.LeftButton)
+ self.qWaitForPendingActions(dialog)
+ self.assertSamePath(url.text(), _tmpDirectory)
+ self.assertFalse(button.isEnabled())
+
+ # FIXME: There is an unreleased qt.QWidget without nameObject
+ # No idea where it come from.
+ self.allowedLeakingWidgets = 1
+
+ def testClickOnHistoryTools(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+
+ url = testutils.findChildren(dialog, qt.QLineEdit, name="url")[0]
+ forwardAction = testutils.findChildren(dialog, qt.QAction, name="forwardAction")[0]
+ backwardAction = testutils.findChildren(dialog, qt.QAction, name="backwardAction")[0]
+ filename = _tmpDirectory + "/data.h5"
+
+ dialog.setDirectory(_tmpDirectory)
+ self.qWaitForPendingActions(dialog)
+ # No way to use QTest.mouseDClick with QListView, QListWidget
+ # Then we feed the history using selectPath
+ dialog.selectUrl(filename)
+ self.qWaitForPendingActions(dialog)
+ path2 = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/").path()
+ dialog.selectUrl(path2)
+ self.qWaitForPendingActions(dialog)
+ path3 = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/group").path()
+ dialog.selectUrl(path3)
+ self.qWaitForPendingActions(dialog)
+ self.assertFalse(forwardAction.isEnabled())
+ self.assertTrue(backwardAction.isEnabled())
+
+ button = testutils.getQToolButtonFromAction(backwardAction)
+ self.mouseClick(button, qt.Qt.LeftButton)
+ self.qWaitForPendingActions(dialog)
+ self.assertTrue(forwardAction.isEnabled())
+ self.assertTrue(backwardAction.isEnabled())
+ self.assertSamePath(url.text(), path2)
+
+ button = testutils.getQToolButtonFromAction(forwardAction)
+ self.mouseClick(button, qt.Qt.LeftButton)
+ self.qWaitForPendingActions(dialog)
+ self.assertFalse(forwardAction.isEnabled())
+ self.assertTrue(backwardAction.isEnabled())
+ self.assertSamePath(url.text(), path3)
+
+ def testSelectImageFromEdf(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+
+ # init state
+ filename = _tmpDirectory + "/singleimage.edf"
+ path = filename
+ dialog.selectUrl(path)
+ self.assertEqual(dialog.selectedImage().shape, (100, 100))
+ self.assertSamePath(dialog.selectedFile(), filename)
+ path = silx.io.url.DataUrl(scheme="fabio", file_path=filename).path()
+ self.assertSamePath(dialog.selectedUrl(), path)
+
+ def testSelectImageFromEdf_Activate(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+
+ # init state
+ dialog.selectUrl(_tmpDirectory)
+ self.qWaitForPendingActions(dialog)
+ browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
+ filename = _tmpDirectory + "/singleimage.edf"
+ path = silx.io.url.DataUrl(scheme="fabio", file_path=filename).path()
+ index = browser.rootIndex().model().index(filename)
+ # click
+ browser.selectIndex(index)
+ # double click
+ browser.activated.emit(index)
+ self.qWaitForPendingActions(dialog)
+ # test
+ self.assertEqual(dialog.selectedImage().shape, (100, 100))
+ self.assertSamePath(dialog.selectedFile(), filename)
+ self.assertSamePath(dialog.selectedUrl(), path)
+
+ def testSelectFrameFromEdf(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+
+ # init state
+ filename = _tmpDirectory + "/multiframe.edf"
+ path = silx.io.url.DataUrl(scheme="fabio", file_path=filename, data_slice=(1,)).path()
+ dialog.selectUrl(path)
+ # test
+ image = dialog.selectedImage()
+ self.assertEqual(image.shape, (100, 100))
+ self.assertEqual(image[0, 0], 1)
+ self.assertSamePath(dialog.selectedFile(), filename)
+ self.assertSamePath(dialog.selectedUrl(), path)
+
+ def testSelectImageFromMsk(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+
+ # init state
+ filename = _tmpDirectory + "/singleimage.msk"
+ path = silx.io.url.DataUrl(scheme="fabio", file_path=filename).path()
+ dialog.selectUrl(path)
+ # test
+ self.assertEqual(dialog.selectedImage().shape, (100, 100))
+ self.assertSamePath(dialog.selectedFile(), filename)
+ self.assertSamePath(dialog.selectedUrl(), path)
+
+ def testSelectImageFromH5(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+
+ # init state
+ filename = _tmpDirectory + "/data.h5"
+ path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/image").path()
+ dialog.selectUrl(path)
+ # test
+ self.assertEqual(dialog.selectedImage().shape, (100, 100))
+ self.assertSamePath(dialog.selectedFile(), filename)
+ self.assertSamePath(dialog.selectedUrl(), path)
+
+ def testSelectH5_Activate(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+
+ # init state
+ dialog.selectUrl(_tmpDirectory)
+ self.qWaitForPendingActions(dialog)
+ browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
+ filename = _tmpDirectory + "/data.h5"
+ path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/").path()
+ index = browser.rootIndex().model().index(filename)
+ # click
+ browser.selectIndex(index)
+ # double click
+ browser.activated.emit(index)
+ self.qWaitForPendingActions(dialog)
+ # test
+ self.assertSamePath(dialog.selectedUrl(), path)
+
+ def testSelectFrameFromH5(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+
+ # init state
+ filename = _tmpDirectory + "/data.h5"
+ path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/cube", data_slice=(1, )).path()
+ dialog.selectUrl(path)
+ # test
+ self.assertEqual(dialog.selectedImage().shape, (100, 100))
+ self.assertEqual(dialog.selectedImage()[0, 0], 1)
+ self.assertSamePath(dialog.selectedFile(), filename)
+ self.assertSamePath(dialog.selectedUrl(), path)
+
+ def testSelectSingleFrameFromH5(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+
+ # init state
+ filename = _tmpDirectory + "/data.h5"
+ path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/single_frame", data_slice=(0, )).path()
+ dialog.selectUrl(path)
+ # test
+ self.assertEqual(dialog.selectedImage().shape, (100, 100))
+ self.assertEqual(dialog.selectedImage()[0, 0], 5)
+ self.assertSamePath(dialog.selectedFile(), filename)
+ self.assertSamePath(dialog.selectedUrl(), path)
+
+ def testSelectBadFileFormat_Activate(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+
+ # init state
+ dialog.selectUrl(_tmpDirectory)
+ self.qWaitForPendingActions(dialog)
+ browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
+ filename = _tmpDirectory + "/badformat.edf"
+ index = browser.model().index(filename)
+ browser.selectIndex(index)
+ browser.activated.emit(index)
+ self.qWaitForPendingActions(dialog)
+ # test
+ self.assertSamePath(dialog.selectedUrl(), filename)
+
+ def _countSelectableItems(self, model, rootIndex):
+ selectable = 0
+ for i in range(model.rowCount(rootIndex)):
+ index = model.index(i, 0, rootIndex)
+ flags = model.flags(index)
+ isEnabled = (int(flags) & qt.Qt.ItemIsEnabled) != 0
+ if isEnabled:
+ selectable += 1
+ return selectable
+
+ def testFilterExtensions(self):
+ dialog = self.createDialog()
+ browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
+ filters = testutils.findChildren(dialog, qt.QWidget, name="fileTypeCombo")[0]
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+ dialog.selectUrl(_tmpDirectory)
+ self.qWaitForPendingActions(dialog)
+ self.assertEqual(self._countSelectableItems(browser.model(), browser.rootIndex()), 6)
+
+ codecName = fabio.edfimage.EdfImage.codec_name()
+ index = filters.indexFromCodec(codecName)
+ filters.setCurrentIndex(index)
+ filters.activated[int].emit(index)
+ self.qWait(50)
+ self.assertEqual(self._countSelectableItems(browser.model(), browser.rootIndex()), 4)
+
+ codecName = fabio.fit2dmaskimage.Fit2dMaskImage.codec_name()
+ index = filters.indexFromCodec(codecName)
+ filters.setCurrentIndex(index)
+ filters.activated[int].emit(index)
+ self.qWait(50)
+ self.assertEqual(self._countSelectableItems(browser.model(), browser.rootIndex()), 2)
+
+
+class TestImageFileDialogApi(testutils.TestCaseQt, _UtilsMixin):
+
+ def tearDown(self):
+ self._deleteDialog()
+ testutils.TestCaseQt.tearDown(self)
+
+ def testSaveRestoreState(self):
+ dialog = self.createDialog()
+ dialog.setDirectory(_tmpDirectory)
+ colormap = Colormap(normalization=Colormap.LOGARITHM)
+ dialog.setColormap(colormap)
+ self.qWaitForPendingActions(dialog)
+ state = dialog.saveState()
+ dialog = None
+
+ dialog2 = self.createDialog()
+ result = dialog2.restoreState(state)
+ self.qWaitForPendingActions(dialog2)
+ self.assertTrue(result)
+ self.assertEqual(dialog2.colormap().getNormalization(), "log")
+
+ def printState(self):
+ """
+ Print state of the ImageFileDialog.
+
+ Can be used to add or regenerate `STATE_VERSION1_QT4` or
+ `STATE_VERSION1_QT5`.
+
+ >>> ./run_tests.py -v silx.gui.dialog.test.test_imagefiledialog.TestImageFileDialogApi.printState
+ """
+ dialog = self.createDialog()
+ colormap = Colormap(normalization=Colormap.LOGARITHM)
+ dialog.setDirectory("")
+ dialog.setHistory([])
+ dialog.setColormap(colormap)
+ dialog.setSidebarUrls([])
+ state = dialog.saveState()
+ string = ""
+ strings = []
+ for i in range(state.size()):
+ d = state.data()[i]
+ if not isinstance(d, int):
+ d = ord(d)
+ if d > 0x20 and d < 0x7F:
+ string += chr(d)
+ else:
+ string += "\\x%02X" % d
+ if len(string) > 60:
+ strings.append(string)
+ string = ""
+ strings.append(string)
+ strings = ["b'%s'" % s for s in strings]
+ print()
+ print("\\\n".join(strings))
+
+ STATE_VERSION1_QT4 = b''\
+ b'\x00\x00\x00^\x00s\x00i\x00l\x00x\x00.\x00g\x00u\x00i\x00.\x00'\
+ b'd\x00i\x00a\x00l\x00o\x00g\x00.\x00I\x00m\x00a\x00g\x00e\x00F'\
+ b'\x00i\x00l\x00e\x00D\x00i\x00a\x00l\x00o\x00g\x00.\x00I\x00m\x00'\
+ b'a\x00g\x00e\x00F\x00i\x00l\x00e\x00D\x00i\x00a\x00l\x00o\x00g'\
+ b'\x00\x00\x00\x01\x00\x00\x00\x0C\x00\x00\x00\x00"\x00\x00\x00'\
+ b'\xFF\x00\x00\x00\x00\x00\x00\x00\x03\xFF\xFF\xFF\xFF\xFF\xFF\xFF'\
+ b'\xFF\xFF\xFF\xFF\xFF\x01\x00\x00\x00\x06\x01\x00\x00\x00\x01\x00'\
+ b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0C\x00'\
+ b'\x00\x00\x00}\x00\x00\x00\x0E\x00B\x00r\x00o\x00w\x00s\x00e\x00'\
+ b'r\x00\x00\x00\x01\x00\x00\x00\x0C\x00\x00\x00\x00Z\x00\x00\x00'\
+ b'\xFF\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00'\
+ b'\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'\
+ b'\x00\x00\x00\x00\x01\x90\x00\x00\x00\x04\x01\x01\x00\x00\x00\x00'\
+ b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00d\xFF\xFF\xFF\xFF\x00'\
+ b'\x00\x00\x81\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x01\x90\x00'\
+ b'\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x01\x00'\
+ b'\x00\x00\x0C\x00\x00\x00\x000\x00\x00\x00\x10\x00C\x00o\x00l\x00'\
+ b'o\x00r\x00m\x00a\x00p\x00\x00\x00\x01\x00\x00\x00\x08\x00g\x00'\
+ b'r\x00a\x00y\x01\x01\x00\x00\x00\x06\x00l\x00o\x00g'
+ """Serialized state on Qt4. Generated using :meth:`printState`"""
+
+ STATE_VERSION1_QT5 = b''\
+ b'\x00\x00\x00^\x00s\x00i\x00l\x00x\x00.\x00g\x00u\x00i\x00.\x00'\
+ b'd\x00i\x00a\x00l\x00o\x00g\x00.\x00I\x00m\x00a\x00g\x00e\x00F'\
+ b'\x00i\x00l\x00e\x00D\x00i\x00a\x00l\x00o\x00g\x00.\x00I\x00m\x00'\
+ b'a\x00g\x00e\x00F\x00i\x00l\x00e\x00D\x00i\x00a\x00l\x00o\x00g'\
+ b'\x00\x00\x00\x01\x00\x00\x00\x0C\x00\x00\x00\x00#\x00\x00\x00'\
+ b'\xFF\x00\x00\x00\x01\x00\x00\x00\x03\xFF\xFF\xFF\xFF\xFF\xFF\xFF'\
+ b'\xFF\xFF\xFF\xFF\xFF\x01\xFF\xFF\xFF\xFF\x01\x00\x00\x00\x01\x00'\
+ b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0C'\
+ b'\x00\x00\x00\x00\xAA\x00\x00\x00\x0E\x00B\x00r\x00o\x00w\x00s'\
+ b'\x00e\x00r\x00\x00\x00\x01\x00\x00\x00\x0C\x00\x00\x00\x00\x87'\
+ b'\x00\x00\x00\xFF\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00'\
+ b'\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'\
+ b'\x00\x00\x00\x00\x00\x00\x00\x01\x90\x00\x00\x00\x04\x01\x01\x00'\
+ b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00d\xFF\xFF'\
+ b'\xFF\xFF\x00\x00\x00\x81\x00\x00\x00\x00\x00\x00\x00\x04\x00\x00'\
+ b'\x00d\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00d\x00\x00\x00'\
+ b'\x01\x00\x00\x00\x00\x00\x00\x00d\x00\x00\x00\x01\x00\x00\x00'\
+ b'\x00\x00\x00\x00d\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x03'\
+ b'\xE8\x00\xFF\xFF\xFF\xFF\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00'\
+ b'\x00\x0C\x00\x00\x00\x000\x00\x00\x00\x10\x00C\x00o\x00l\x00o'\
+ b'\x00r\x00m\x00a\x00p\x00\x00\x00\x01\x00\x00\x00\x08\x00g\x00'\
+ b'r\x00a\x00y\x01\x01\x00\x00\x00\x06\x00l\x00o\x00g'
+ """Serialized state on Qt5. Generated using :meth:`printState`"""
+
+ def testAvoidRestoreRegression_Version1(self):
+ version = qt.qVersion().split(".")[0]
+ if version == "4":
+ state = self.STATE_VERSION1_QT4
+ elif version == "5":
+ state = self.STATE_VERSION1_QT5
+ else:
+ self.skipTest("Resource not available")
+
+ state = qt.QByteArray(state)
+ dialog = self.createDialog()
+ result = dialog.restoreState(state)
+ self.assertTrue(result)
+ colormap = dialog.colormap()
+ self.assertEqual(colormap.getNormalization(), "log")
+
+ def testRestoreRobusness(self):
+ """What's happen if you try to open a config file with a different
+ binding."""
+ state = qt.QByteArray(self.STATE_VERSION1_QT4)
+ dialog = self.createDialog()
+ dialog.restoreState(state)
+ state = qt.QByteArray(self.STATE_VERSION1_QT5)
+ dialog = None
+ dialog = self.createDialog()
+ dialog.restoreState(state)
+
+ def testRestoreNonExistingDirectory(self):
+ directory = os.path.join(_tmpDirectory, "dir")
+ os.mkdir(directory)
+ dialog = self.createDialog()
+ dialog.setDirectory(directory)
+ self.qWaitForPendingActions(dialog)
+ state = dialog.saveState()
+ os.rmdir(directory)
+ dialog = None
+
+ dialog2 = self.createDialog()
+ result = dialog2.restoreState(state)
+ self.assertTrue(result)
+ self.assertNotEqual(dialog2.directory(), directory)
+
+ def testHistory(self):
+ dialog = self.createDialog()
+ history = dialog.history()
+ dialog.setHistory([])
+ self.assertEqual(dialog.history(), [])
+ dialog.setHistory(history)
+ self.assertEqual(dialog.history(), history)
+
+ def testSidebarUrls(self):
+ dialog = self.createDialog()
+ urls = dialog.sidebarUrls()
+ dialog.setSidebarUrls([])
+ self.assertEqual(dialog.sidebarUrls(), [])
+ dialog.setSidebarUrls(urls)
+ self.assertEqual(dialog.sidebarUrls(), urls)
+
+ def testColomap(self):
+ dialog = self.createDialog()
+ colormap = dialog.colormap()
+ self.assertEqual(colormap.getNormalization(), "linear")
+ colormap = Colormap(normalization=Colormap.LOGARITHM)
+ dialog.setColormap(colormap)
+ self.assertEqual(colormap.getNormalization(), "log")
+
+ def testDirectory(self):
+ dialog = self.createDialog()
+ self.qWaitForPendingActions(dialog)
+ dialog.selectUrl(_tmpDirectory)
+ self.qWaitForPendingActions(dialog)
+ self.assertSamePath(dialog.directory(), _tmpDirectory)
+
+ def testBadDataType(self):
+ dialog = self.createDialog()
+ dialog.selectUrl(_tmpDirectory + "/data.h5::/complex_image")
+ self.qWaitForPendingActions(dialog)
+ self.assertIsNone(dialog._selectedData())
+
+ def testBadDataShape(self):
+ dialog = self.createDialog()
+ dialog.selectUrl(_tmpDirectory + "/data.h5::/unknown")
+ self.qWaitForPendingActions(dialog)
+ self.assertIsNone(dialog._selectedData())
+
+ def testBadDataFormat(self):
+ dialog = self.createDialog()
+ dialog.selectUrl(_tmpDirectory + "/badformat.edf")
+ self.qWaitForPendingActions(dialog)
+ self.assertIsNone(dialog._selectedData())
+
+ def testBadPath(self):
+ dialog = self.createDialog()
+ dialog.selectUrl("#$%/#$%")
+ self.qWaitForPendingActions(dialog)
+ self.assertIsNone(dialog._selectedData())
+
+ def testBadSubpath(self):
+ dialog = self.createDialog()
+ self.qWaitForPendingActions(dialog)
+
+ browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
+
+ filename = _tmpDirectory + "/data.h5"
+ url = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/group/foobar")
+ dialog.selectUrl(url.path())
+ self.qWaitForPendingActions(dialog)
+ self.assertIsNone(dialog._selectedData())
+
+ # an existing node is browsed, but the wrong path is selected
+ index = browser.rootIndex()
+ obj = index.model().data(index, role=Hdf5TreeModel.H5PY_OBJECT_ROLE)
+ self.assertEqual(obj.name, "/group")
+ url = silx.io.url.DataUrl(dialog.selectedUrl())
+ self.assertEqual(url.data_path(), "/group")
+
+ def testBadSlicingPath(self):
+ dialog = self.createDialog()
+ self.qWaitForPendingActions(dialog)
+ dialog.selectUrl(_tmpDirectory + "/data.h5::/cube[a;45,-90]")
+ self.qWaitForPendingActions(dialog)
+ self.assertIsNone(dialog._selectedData())
diff --git a/src/silx/gui/dialog/utils.py b/src/silx/gui/dialog/utils.py
new file mode 100644
index 0000000..4c48930
--- /dev/null
+++ b/src/silx/gui/dialog/utils.py
@@ -0,0 +1,99 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+This module contains utilitaries used by other dialog modules.
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "25/10/2017"
+
+import os
+import types
+
+from silx.gui import qt
+
+
+def samefile(path1, path2):
+ """Portable :func:`os.path.samepath` function.
+
+ :param str path1: A path to a file
+ :param str path2: Another path to a file
+ :rtype: bool
+ """
+ if path1 == path2:
+ return True
+ if path1 == "":
+ return False
+ if path2 == "":
+ return False
+ return os.path.samefile(path1, path2)
+
+
+def findClosestSubPath(hdf5Object, path):
+ """Find the closest existing path from the hdf5Object using a subset of the
+ provided path.
+
+ Returns None if no path found. It is possible if the path is a relative
+ path.
+
+ :param h5py.Node hdf5Object: An HDF5 node
+ :param str path: A path
+ :rtype: str
+ """
+ if path in ["", "/"]:
+ return "/"
+ names = path.split("/")
+ if path[0] == "/":
+ names.pop(0)
+ for i in range(len(names)):
+ n = len(names) - i
+ path2 = "/".join(names[0:n])
+ if path2 == "":
+ return ""
+ if path2 in hdf5Object:
+ return path2
+
+ if path[0] == "/":
+ return "/"
+ return None
+
+
+def patchToConsumeReturnKey(widget):
+ """
+ Monkey-patch a widget to consume the return key instead of propagating it
+ to the dialog.
+ """
+ assert(not hasattr(widget, "_oldKeyPressEvent"))
+
+ def keyPressEvent(self, event):
+ k = event.key()
+ result = self._oldKeyPressEvent(event)
+ if k in [qt.Qt.Key_Return, qt.Qt.Key_Enter]:
+ event.accept()
+ return result
+
+ widget._oldKeyPressEvent = widget.keyPressEvent
+ widget.keyPressEvent = types.MethodType(keyPressEvent, widget)
diff --git a/src/silx/gui/fit/BackgroundWidget.py b/src/silx/gui/fit/BackgroundWidget.py
new file mode 100644
index 0000000..7703ee1
--- /dev/null
+++ b/src/silx/gui/fit/BackgroundWidget.py
@@ -0,0 +1,534 @@
+# coding: utf-8
+#/*##########################################################################
+# Copyright (C) 2004-2021 V.A. Sole, European Synchrotron Radiation Facility
+#
+# This file is part of the PyMca X-ray Fluorescence Toolkit developed at
+# the ESRF by the Software group.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# #########################################################################*/
+"""This module provides a background configuration widget
+:class:`BackgroundWidget` and a corresponding dialog window
+:class:`BackgroundDialog`.
+
+.. image:: img/BackgroundDialog.png
+ :height: 300px
+"""
+import sys
+import numpy
+from silx.gui import qt
+from silx.gui.plot import PlotWidget
+from silx.math.fit import filters
+
+__authors__ = ["V.A. Sole", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "28/06/2017"
+
+
+class HorizontalSpacer(qt.QWidget):
+ def __init__(self, *args):
+ qt.QWidget.__init__(self, *args)
+ self.setSizePolicy(qt.QSizePolicy(qt.QSizePolicy.Expanding,
+ qt.QSizePolicy.Fixed))
+
+
+class BackgroundParamWidget(qt.QWidget):
+ """Background configuration composite widget.
+
+ Strip and snip filters parameters can be adjusted using input widgets.
+
+ Updating the widgets causes :attr:`sigBackgroundParamWidgetSignal` to
+ be emitted.
+ """
+ sigBackgroundParamWidgetSignal = qt.pyqtSignal(object)
+
+ def __init__(self, parent=None):
+ qt.QWidget.__init__(self, parent)
+
+ self.mainLayout = qt.QGridLayout(self)
+ self.mainLayout.setColumnStretch(1, 1)
+
+ # Algorithm choice ---------------------------------------------------
+ self.algorithmComboLabel = qt.QLabel(self)
+ self.algorithmComboLabel.setText("Background algorithm")
+ self.algorithmCombo = qt.QComboBox(self)
+ self.algorithmCombo.addItem("Strip")
+ self.algorithmCombo.addItem("Snip")
+ self.algorithmCombo.activated[int].connect(
+ self._algorithmComboActivated)
+
+ # Strip parameters ---------------------------------------------------
+ self.stripWidthLabel = qt.QLabel(self)
+ self.stripWidthLabel.setText("Strip Width")
+
+ self.stripWidthSpin = qt.QSpinBox(self)
+ self.stripWidthSpin.setMaximum(100)
+ self.stripWidthSpin.setMinimum(1)
+ self.stripWidthSpin.valueChanged[int].connect(self._emitSignal)
+
+ self.stripIterLabel = qt.QLabel(self)
+ self.stripIterLabel.setText("Strip Iterations")
+ self.stripIterValue = qt.QLineEdit(self)
+ validator = qt.QIntValidator(self.stripIterValue)
+ self.stripIterValue._v = validator
+ self.stripIterValue.setText("0")
+ self.stripIterValue.editingFinished[()].connect(self._emitSignal)
+ self.stripIterValue.setToolTip(
+ "Number of iterations for strip algorithm.\n" +
+ "If greater than 999, an 2nd pass of strip filter is " +
+ "applied to remove artifacts created by first pass.")
+
+ # Snip parameters ----------------------------------------------------
+ self.snipWidthLabel = qt.QLabel(self)
+ self.snipWidthLabel.setText("Snip Width")
+
+ self.snipWidthSpin = qt.QSpinBox(self)
+ self.snipWidthSpin.setMaximum(300)
+ self.snipWidthSpin.setMinimum(0)
+ self.snipWidthSpin.valueChanged[int].connect(self._emitSignal)
+
+
+ # Smoothing parameters -----------------------------------------------
+ self.smoothingFlagCheck = qt.QCheckBox(self)
+ self.smoothingFlagCheck.setText("Smoothing Width (Savitsky-Golay)")
+ self.smoothingFlagCheck.toggled.connect(self._smoothingToggled)
+
+ self.smoothingSpin = qt.QSpinBox(self)
+ self.smoothingSpin.setMinimum(3)
+ #self.smoothingSpin.setMaximum(40)
+ self.smoothingSpin.setSingleStep(2)
+ self.smoothingSpin.valueChanged[int].connect(self._emitSignal)
+
+ # Anchors ------------------------------------------------------------
+
+ self.anchorsGroup = qt.QWidget(self)
+ anchorsLayout = qt.QHBoxLayout(self.anchorsGroup)
+ anchorsLayout.setSpacing(2)
+ anchorsLayout.setContentsMargins(0, 0, 0, 0)
+
+ self.anchorsFlagCheck = qt.QCheckBox(self.anchorsGroup)
+ self.anchorsFlagCheck.setText("Use anchors")
+ self.anchorsFlagCheck.setToolTip(
+ "Define X coordinates of points that must remain fixed")
+ self.anchorsFlagCheck.stateChanged[int].connect(
+ self._anchorsToggled)
+ anchorsLayout.addWidget(self.anchorsFlagCheck)
+
+ maxnchannel = 16384 * 4 # Fixme ?
+ self.anchorsList = []
+ num_anchors = 4
+ for i in range(num_anchors):
+ anchorSpin = qt.QSpinBox(self.anchorsGroup)
+ anchorSpin.setMinimum(0)
+ anchorSpin.setMaximum(maxnchannel)
+ anchorSpin.valueChanged[int].connect(self._emitSignal)
+ anchorsLayout.addWidget(anchorSpin)
+ self.anchorsList.append(anchorSpin)
+
+ # Layout ------------------------------------------------------------
+ self.mainLayout.addWidget(self.algorithmComboLabel, 0, 0)
+ self.mainLayout.addWidget(self.algorithmCombo, 0, 2)
+ self.mainLayout.addWidget(self.stripWidthLabel, 1, 0)
+ self.mainLayout.addWidget(self.stripWidthSpin, 1, 2)
+ self.mainLayout.addWidget(self.stripIterLabel, 2, 0)
+ self.mainLayout.addWidget(self.stripIterValue, 2, 2)
+ self.mainLayout.addWidget(self.snipWidthLabel, 3, 0)
+ self.mainLayout.addWidget(self.snipWidthSpin, 3, 2)
+ self.mainLayout.addWidget(self.smoothingFlagCheck, 4, 0)
+ self.mainLayout.addWidget(self.smoothingSpin, 4, 2)
+ self.mainLayout.addWidget(self.anchorsGroup, 5, 0, 1, 4)
+
+ # Initialize interface -----------------------------------------------
+ self._setAlgorithm("strip")
+ self.smoothingFlagCheck.setChecked(False)
+ self._smoothingToggled(is_checked=False)
+ self.anchorsFlagCheck.setChecked(False)
+ self._anchorsToggled(is_checked=False)
+
+ def _algorithmComboActivated(self, algorithm_index):
+ self._setAlgorithm("strip" if algorithm_index == 0 else "snip")
+
+ def _setAlgorithm(self, algorithm):
+ """Enable/disable snip and snip input widgets, depending on the
+ chosen algorithm.
+ :param algorithm: "snip" or "strip"
+ """
+ if algorithm not in ["strip", "snip"]:
+ raise ValueError(
+ "Unknown background filter algorithm %s" % algorithm)
+
+ self.algorithm = algorithm
+ self.stripWidthSpin.setEnabled(algorithm == "strip")
+ self.stripIterValue.setEnabled(algorithm == "strip")
+ self.snipWidthSpin.setEnabled(algorithm == "snip")
+
+ def _smoothingToggled(self, is_checked):
+ """Enable/disable smoothing input widgets, emit dictionary"""
+ self.smoothingSpin.setEnabled(is_checked)
+ self._emitSignal()
+
+ def _anchorsToggled(self, is_checked):
+ """Enable/disable all spin widgets defining anchor X coordinates,
+ emit signal.
+ """
+ for anchor_spin in self.anchorsList:
+ anchor_spin.setEnabled(is_checked)
+ self._emitSignal()
+
+ def setParameters(self, ddict):
+ """Set values for all input widgets.
+
+ :param dict ddict: Input dictionary, must have the same
+ keys as the dictionary output by :meth:`getParameters`
+ """
+ if "algorithm" in ddict:
+ self._setAlgorithm(ddict["algorithm"])
+
+ if "SnipWidth" in ddict:
+ self.snipWidthSpin.setValue(int(ddict["SnipWidth"]))
+
+ if "StripWidth" in ddict:
+ self.stripWidthSpin.setValue(int(ddict["StripWidth"]))
+
+ if "StripIterations" in ddict:
+ self.stripIterValue.setText("%d" % int(ddict["StripIterations"]))
+
+ if "SmoothingFlag" in ddict:
+ self.smoothingFlagCheck.setChecked(bool(ddict["SmoothingFlag"]))
+
+ if "SmoothingWidth" in ddict:
+ self.smoothingSpin.setValue(int(ddict["SmoothingWidth"]))
+
+ if "AnchorsFlag" in ddict:
+ self.anchorsFlagCheck.setChecked(bool(ddict["AnchorsFlag"]))
+
+ if "AnchorsList" in ddict:
+ anchorslist = ddict["AnchorsList"]
+ if anchorslist in [None, 'None']:
+ anchorslist = []
+ for spin in self.anchorsList:
+ spin.setValue(0)
+
+ i = 0
+ for value in anchorslist:
+ self.anchorsList[i].setValue(int(value))
+ i += 1
+
+ def getParameters(self):
+ """Return dictionary of parameters defined in the GUI
+
+ The returned dictionary contains following values:
+
+ - *algorithm*: *"strip"* or *"snip"*
+ - *StripWidth*: width of strip iterator
+ - *StripIterations*: number of iterations
+ - *StripThreshold*: curvature parameter (currently fixed to 1.0)
+ - *SnipWidth*: width of snip algorithm
+ - *SmoothingFlag*: flag to enable/disable smoothing
+ - *SmoothingWidth*: width of Savitsky-Golay smoothing filter
+ - *AnchorsFlag*: flag to enable/disable anchors
+ - *AnchorsList*: list of anchors (X coordinates of fixed values)
+ """
+ stripitertext = self.stripIterValue.text()
+ stripiter = int(stripitertext) if len(stripitertext) else 0
+
+ return {"algorithm": self.algorithm,
+ "StripThreshold": 1.0,
+ "SnipWidth": self.snipWidthSpin.value(),
+ "StripIterations": stripiter,
+ "StripWidth": self.stripWidthSpin.value(),
+ "SmoothingFlag": self.smoothingFlagCheck.isChecked(),
+ "SmoothingWidth": self.smoothingSpin.value(),
+ "AnchorsFlag": self.anchorsFlagCheck.isChecked(),
+ "AnchorsList": [spin.value() for spin in self.anchorsList]}
+
+ def _emitSignal(self, dummy=None):
+ self.sigBackgroundParamWidgetSignal.emit(
+ {'event': 'ParametersChanged',
+ 'parameters': self.getParameters()})
+
+
+class BackgroundWidget(qt.QWidget):
+ """Background configuration widget, with a plot to preview the results.
+
+ Strip and snip filters parameters can be adjusted using input widgets,
+ and the computed backgrounds are plotted next to the original data to
+ show the result."""
+ def __init__(self, parent=None):
+ qt.QWidget.__init__(self, parent)
+ self.setWindowTitle("Strip and SNIP Configuration Window")
+ self.mainLayout = qt.QVBoxLayout(self)
+ self.mainLayout.setContentsMargins(0, 0, 0, 0)
+ self.mainLayout.setSpacing(2)
+ self.parametersWidget = BackgroundParamWidget(self)
+ self.graphWidget = PlotWidget(parent=self)
+ self.mainLayout.addWidget(self.parametersWidget)
+ self.mainLayout.addWidget(self.graphWidget)
+ self._x = None
+ self._y = None
+ self.parametersWidget.sigBackgroundParamWidgetSignal.connect(self._slot)
+
+ def getParameters(self):
+ """Return dictionary of parameters defined in the GUI
+
+ The returned dictionary contains following values:
+
+ - *algorithm*: *"strip"* or *"snip"*
+ - *StripWidth*: width of strip iterator
+ - *StripIterations*: number of iterations
+ - *StripThreshold*: strip curvature (currently fixed to 1.0)
+ - *SnipWidth*: width of snip algorithm
+ - *SmoothingFlag*: flag to enable/disable smoothing
+ - *SmoothingWidth*: width of Savitsky-Golay smoothing filter
+ - *AnchorsFlag*: flag to enable/disable anchors
+ - *AnchorsList*: list of anchors (X coordinates of fixed values)
+ """
+ return self.parametersWidget.getParameters()
+
+ def setParameters(self, ddict):
+ """Set values for all input widgets.
+
+ :param dict ddict: Input dictionary, must have the same
+ keys as the dictionary output by :meth:`getParameters`
+ """
+ return self.parametersWidget.setParameters(ddict)
+
+ def setData(self, x, y, xmin=None, xmax=None):
+ """Set data for the original curve, and _update strip and snip
+ curves accordingly.
+
+ :param x: Array or sequence of curve abscissa values
+ :param y: Array or sequence of curve ordinate values
+ :param xmin: Min value to be displayed on the X axis
+ :param xmax: Max value to be displayed on the X axis
+ """
+ self._x = x
+ self._y = y
+ self._xmin = xmin
+ self._xmax = xmax
+ self._update(resetzoom=True)
+
+ def _slot(self, ddict):
+ self._update()
+
+ def _update(self, resetzoom=False):
+ """Compute strip and snip backgrounds, update the curves
+ """
+ if self._y is None:
+ return
+
+ pars = self.getParameters()
+
+ # smoothed data
+ y = numpy.ravel(numpy.array(self._y)).astype(numpy.float64)
+ if pars["SmoothingFlag"]:
+ ysmooth = filters.savitsky_golay(y, pars['SmoothingWidth'])
+ f = [0.25, 0.5, 0.25]
+ ysmooth[1:-1] = numpy.convolve(ysmooth, f, mode=0)
+ ysmooth[0] = 0.5 * (ysmooth[0] + ysmooth[1])
+ ysmooth[-1] = 0.5 * (ysmooth[-1] + ysmooth[-2])
+ else:
+ ysmooth = y
+
+
+ # loop for anchors
+ x = self._x
+ niter = pars['StripIterations']
+ anchors_indices = []
+ if pars['AnchorsFlag'] and pars['AnchorsList'] is not None:
+ ravelled = x
+ for channel in pars['AnchorsList']:
+ if channel <= ravelled[0]:
+ continue
+ index = numpy.nonzero(ravelled >= channel)[0]
+ if len(index):
+ index = min(index)
+ if index > 0:
+ anchors_indices.append(index)
+
+ stripBackground = filters.strip(ysmooth,
+ w=pars['StripWidth'],
+ niterations=niter,
+ factor=pars['StripThreshold'],
+ anchors=anchors_indices)
+
+ if niter >= 1000:
+ # final smoothing
+ stripBackground = filters.strip(stripBackground,
+ w=1,
+ niterations=50*pars['StripWidth'],
+ factor=pars['StripThreshold'],
+ anchors=anchors_indices)
+
+ if len(anchors_indices) == 0:
+ anchors_indices = [0, len(ysmooth)-1]
+ anchors_indices.sort()
+ snipBackground = 0.0 * ysmooth
+ lastAnchor = 0
+ for anchor in anchors_indices:
+ if (anchor > lastAnchor) and (anchor < len(ysmooth)):
+ snipBackground[lastAnchor:anchor] =\
+ filters.snip1d(ysmooth[lastAnchor:anchor],
+ pars['SnipWidth'])
+ lastAnchor = anchor
+ if lastAnchor < len(ysmooth):
+ snipBackground[lastAnchor:] =\
+ filters.snip1d(ysmooth[lastAnchor:],
+ pars['SnipWidth'])
+
+ self.graphWidget.addCurve(x, y,
+ legend='Input Data',
+ replace=True,
+ resetzoom=resetzoom)
+ self.graphWidget.addCurve(x, stripBackground,
+ legend='Strip Background',
+ resetzoom=False)
+ self.graphWidget.addCurve(x, snipBackground,
+ legend='SNIP Background',
+ resetzoom=False)
+ if self._xmin is not None and self._xmax is not None:
+ self.graphWidget.getXAxis().setLimits(self._xmin, self._xmax)
+
+
+class BackgroundDialog(qt.QDialog):
+ """QDialog window featuring a :class:`BackgroundWidget`"""
+ def __init__(self, parent=None):
+ qt.QDialog.__init__(self, parent)
+ self.setWindowTitle("Strip and Snip Configuration Window")
+ self.mainLayout = qt.QVBoxLayout(self)
+ self.mainLayout.setContentsMargins(0, 0, 0, 0)
+ self.mainLayout.setSpacing(2)
+ self.parametersWidget = BackgroundWidget(self)
+ self.mainLayout.addWidget(self.parametersWidget)
+ hbox = qt.QWidget(self)
+ hboxLayout = qt.QHBoxLayout(hbox)
+ hboxLayout.setContentsMargins(0, 0, 0, 0)
+ hboxLayout.setSpacing(2)
+ self.okButton = qt.QPushButton(hbox)
+ self.okButton.setText("OK")
+ self.okButton.setAutoDefault(False)
+ self.dismissButton = qt.QPushButton(hbox)
+ self.dismissButton.setText("Cancel")
+ self.dismissButton.setAutoDefault(False)
+ hboxLayout.addWidget(HorizontalSpacer(hbox))
+ hboxLayout.addWidget(self.okButton)
+ hboxLayout.addWidget(self.dismissButton)
+ self.mainLayout.addWidget(hbox)
+ self.dismissButton.clicked.connect(self.reject)
+ self.okButton.clicked.connect(self.accept)
+
+ self.output = {}
+ """Configuration dictionary containing following fields:
+
+ - *SmoothingFlag*
+ - *SmoothingWidth*
+ - *StripWidth*
+ - *StripIterations*
+ - *StripThreshold*
+ - *SnipWidth*
+ - *AnchorsFlag*
+ - *AnchorsList*
+ """
+
+ # self.parametersWidget.parametersWidget.sigBackgroundParamWidgetSignal.connect(self.updateOutput)
+
+ # def updateOutput(self, ddict):
+ # self.output = ddict
+
+ def accept(self):
+ """Update :attr:`output`, then call :meth:`QDialog.accept`
+ """
+ self.output = self.getParameters()
+ super(BackgroundDialog, self).accept()
+
+ def sizeHint(self):
+ return qt.QSize(int(1.5*qt.QDialog.sizeHint(self).width()),
+ qt.QDialog.sizeHint(self).height())
+
+ def setData(self, x, y, xmin=None, xmax=None):
+ """See :meth:`BackgroundWidget.setData`"""
+ return self.parametersWidget.setData(x, y, xmin, xmax)
+
+ def getParameters(self):
+ """See :meth:`BackgroundWidget.getParameters`"""
+ return self.parametersWidget.getParameters()
+
+ def setParameters(self, ddict):
+ """See :meth:`BackgroundWidget.setPrintGeometry`"""
+ return self.parametersWidget.setParameters(ddict)
+
+ def setDefault(self, ddict):
+ """Alias for :meth:`setPrintGeometry`"""
+ return self.setParameters(ddict)
+
+
+def getBgDialog(parent=None, default=None, modal=True):
+ """Instantiate and return a bg configuration dialog, adapted
+ for configuring standard background theories from
+ :mod:`silx.math.fit.bgtheories`.
+
+ :return: Instance of :class:`BackgroundDialog`
+ """
+ bgd = BackgroundDialog(parent=parent)
+ # apply default to newly added pages
+ bgd.setParameters(default)
+
+ return bgd
+
+
+def main():
+ # synthetic data
+ from silx.math.fit.functions import sum_gauss
+
+ x = numpy.arange(5000)
+ # (height1, center1, fwhm1, ...) 5 peaks
+ params1 = (50, 500, 100,
+ 20, 2000, 200,
+ 50, 2250, 100,
+ 40, 3000, 75,
+ 23, 4000, 150)
+ y0 = sum_gauss(x, *params1)
+
+ # random values between [-1;1]
+ noise = 2 * numpy.random.random(5000) - 1
+ # make it +- 5%
+ noise *= 0.05
+
+ # 2 gaussians with very large fwhm, as background signal
+ actual_bg = sum_gauss(x, 15, 3500, 3000, 5, 1000, 1500)
+
+ # Add 5% random noise to gaussians and add background
+ y = y0 + numpy.average(y0) * noise + actual_bg
+
+ # Open widget
+ a = qt.QApplication(sys.argv)
+ a.lastWindowClosed.connect(a.quit)
+
+ def mySlot(ddict):
+ print(ddict)
+
+ w = BackgroundDialog()
+ w.parametersWidget.parametersWidget.sigBackgroundParamWidgetSignal.connect(mySlot)
+ w.setData(x, y)
+ w.exec()
+ #a.exec()
+
+if __name__ == "__main__":
+ main()
diff --git a/src/silx/gui/fit/FitConfig.py b/src/silx/gui/fit/FitConfig.py
new file mode 100644
index 0000000..48ebca2
--- /dev/null
+++ b/src/silx/gui/fit/FitConfig.py
@@ -0,0 +1,543 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2004-2021 V.A. Sole, European Synchrotron Radiation Facility
+#
+# This file is part of the PyMca X-ray Fluorescence Toolkit developed at
+# the ESRF by the Software group.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ######################################################################### */
+"""This module defines widgets used to build a fit configuration dialog.
+The resulting dialog widget outputs a dictionary of configuration parameters.
+"""
+from silx.gui import qt
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "30/11/2016"
+
+
+class TabsDialog(qt.QDialog):
+ """Dialog widget containing a QTabWidget :attr:`tabWidget`
+ and a buttons:
+
+ # - buttonHelp
+ - buttonDefaults
+ - buttonOk
+ - buttonCancel
+
+ This dialog defines a __len__ returning the number of tabs,
+ and an __iter__ method yielding the tab widgets.
+ """
+ def __init__(self, parent=None):
+ qt.QDialog.__init__(self, parent)
+ self.tabWidget = qt.QTabWidget(self)
+
+ layout = qt.QVBoxLayout(self)
+ layout.addWidget(self.tabWidget)
+
+ layout2 = qt.QHBoxLayout(None)
+
+ # self.buttonHelp = qt.QPushButton(self)
+ # self.buttonHelp.setText("Help")
+ # layout2.addWidget(self.buttonHelp)
+
+ self.buttonDefault = qt.QPushButton(self)
+ self.buttonDefault.setText("Undo changes")
+ layout2.addWidget(self.buttonDefault)
+
+ spacer = qt.QSpacerItem(20, 20,
+ qt.QSizePolicy.Expanding,
+ qt.QSizePolicy.Minimum)
+ layout2.addItem(spacer)
+
+ self.buttonOk = qt.QPushButton(self)
+ self.buttonOk.setText("OK")
+ layout2.addWidget(self.buttonOk)
+
+ self.buttonCancel = qt.QPushButton(self)
+ self.buttonCancel.setText("Cancel")
+ layout2.addWidget(self.buttonCancel)
+
+ layout.addLayout(layout2)
+
+ self.buttonOk.clicked.connect(self.accept)
+ self.buttonCancel.clicked.connect(self.reject)
+
+ def __len__(self):
+ """Return number of tabs"""
+ return self.tabWidget.count()
+
+ def __iter__(self):
+ """Return the next tab widget in :attr:`tabWidget` every
+ time this method is called.
+
+ :return: Tab widget
+ :rtype: QWidget
+ """
+ for widget_index in range(len(self)):
+ yield self.tabWidget.widget(widget_index)
+
+ def addTab(self, page, label):
+ """Add a new tab
+
+ :param page: Content of new page. Must be a widget with
+ a get() method returning a dictionary.
+ :param str label: Tab label
+ """
+ self.tabWidget.addTab(page, label)
+
+ def getTabLabels(self):
+ """
+ Return a list of all tab labels in :attr:`tabWidget`
+ """
+ return [self.tabWidget.tabText(i) for i in range(len(self))]
+
+
+class TabsDialogData(TabsDialog):
+ """This dialog adds a data attribute to :class:`TabsDialog`.
+
+ Data input in widgets, such as text entries or checkboxes, is stored in an
+ attribute :attr:`output` when the user clicks the OK button.
+
+ A default dictionary can be supplied when this dialog is initialized, to
+ be used as default data for :attr:`output`.
+ """
+ def __init__(self, parent=None, modal=True, default=None):
+ """
+
+ :param parent: Parent :class:`QWidget`
+ :param modal: If `True`, dialog is modal, meaning this dialog remains
+ in front of it's parent window and disables it until the user is
+ done interacting with the dialog
+ :param default: Default dictionary, used to initialize and reset
+ :attr:`output`.
+ """
+ TabsDialog.__init__(self, parent)
+ self.setModal(modal)
+ self.setWindowTitle("Fit configuration")
+
+ self.output = {}
+
+ self.default = {} if default is None else default
+
+ self.buttonDefault.clicked.connect(self._resetDefault)
+ # self.keyPressEvent(qt.Qt.Key_Enter).
+
+ def keyPressEvent(self, event):
+ """Redefining this method to ignore Enter key
+ (for some reason it activates buttonDefault callback which
+ resets all widgets)
+ """
+ if event.key() in [qt.Qt.Key_Enter, qt.Qt.Key_Return]:
+ return
+ TabsDialog.keyPressEvent(self, event)
+
+ def accept(self):
+ """When *OK* is clicked, update :attr:`output` with data from
+ various widgets
+ """
+ self.output.update(self.default)
+
+ # loop over all tab widgets (uses TabsDialog.__iter__)
+ for tabWidget in self:
+ self.output.update(tabWidget.get())
+
+ # avoid pathological None cases
+ for key in self.output.keys():
+ if self.output[key] is None:
+ if key in self.default:
+ self.output[key] = self.default[key]
+ super(TabsDialogData, self).accept()
+
+ def reject(self):
+ """When the *Cancel* button is clicked, reinitialize :attr:`output`
+ and quit
+ """
+ self.setDefault()
+ super(TabsDialogData, self).reject()
+
+ def _resetDefault(self, checked):
+ self.setDefault()
+
+ def setDefault(self, newdefault=None):
+ """Reinitialize :attr:`output` with :attr:`default` or with
+ new dictionary ``newdefault`` if provided.
+ Call :meth:`setDefault` for each tab widget, if available.
+ """
+ self.output = {}
+ if newdefault is None:
+ newdefault = self.default
+ else:
+ self.default = newdefault
+ self.output.update(newdefault)
+
+ for tabWidget in self:
+ if hasattr(tabWidget, "setDefault"):
+ tabWidget.setDefault(self.output)
+
+
+class ConstraintsPage(qt.QGroupBox):
+ """Checkable QGroupBox widget filled with QCheckBox widgets,
+ to configure the fit estimation for standard fit theories.
+ """
+ def __init__(self, parent=None, title="Set constraints"):
+ super(ConstraintsPage, self).__init__(parent)
+ self.setTitle(title)
+ self.setToolTip("Disable 'Set constraints' to remove all " +
+ "constraints on all fit parameters")
+ self.setCheckable(True)
+
+ layout = qt.QVBoxLayout(self)
+ self.setLayout(layout)
+
+ self.positiveHeightCB = qt.QCheckBox("Force positive height/area", self)
+ self.positiveHeightCB.setToolTip("Fit must find positive peaks")
+ layout.addWidget(self.positiveHeightCB)
+
+ self.positionInIntervalCB = qt.QCheckBox("Force position in interval", self)
+ self.positionInIntervalCB.setToolTip(
+ "Fit must position peak within X limits")
+ layout.addWidget(self.positionInIntervalCB)
+
+ self.positiveFwhmCB = qt.QCheckBox("Force positive FWHM", self)
+ self.positiveFwhmCB.setToolTip("Fit must find a positive FWHM")
+ layout.addWidget(self.positiveFwhmCB)
+
+ self.sameFwhmCB = qt.QCheckBox("Force same FWHM for all peaks", self)
+ self.sameFwhmCB.setToolTip("Fit must find same FWHM for all peaks")
+ layout.addWidget(self.sameFwhmCB)
+
+ self.quotedEtaCB = qt.QCheckBox("Force Eta between 0 and 1", self)
+ self.quotedEtaCB.setToolTip(
+ "Fit must find Eta between 0 and 1 for pseudo-Voigt function")
+ layout.addWidget(self.quotedEtaCB)
+
+ layout.addStretch()
+
+ self.setDefault()
+
+ def setDefault(self, default_dict=None):
+ """Set default state for all widgets.
+
+ :param default_dict: If a default config dictionary is provided as
+ a parameter, its values are used as default state."""
+ if default_dict is None:
+ default_dict = {}
+ # this one uses reverse logic: if checked, NoConstraintsFlag must be False
+ self.setChecked(
+ not default_dict.get('NoConstraintsFlag', False))
+ self.positiveHeightCB.setChecked(
+ default_dict.get('PositiveHeightAreaFlag', True))
+ self.positionInIntervalCB.setChecked(
+ default_dict.get('QuotedPositionFlag', False))
+ self.positiveFwhmCB.setChecked(
+ default_dict.get('PositiveFwhmFlag', True))
+ self.sameFwhmCB.setChecked(
+ default_dict.get('SameFwhmFlag', False))
+ self.quotedEtaCB.setChecked(
+ default_dict.get('QuotedEtaFlag', False))
+
+ def get(self):
+ """Return a dictionary of constraint flags, to be processed by the
+ :meth:`configure` method of the selected fit theory."""
+ ddict = {
+ 'NoConstraintsFlag': not self.isChecked(),
+ 'PositiveHeightAreaFlag': self.positiveHeightCB.isChecked(),
+ 'QuotedPositionFlag': self.positionInIntervalCB.isChecked(),
+ 'PositiveFwhmFlag': self.positiveFwhmCB.isChecked(),
+ 'SameFwhmFlag': self.sameFwhmCB.isChecked(),
+ 'QuotedEtaFlag': self.quotedEtaCB.isChecked(),
+ }
+ return ddict
+
+
+class SearchPage(qt.QWidget):
+ def __init__(self, parent=None):
+ super(SearchPage, self).__init__(parent)
+ layout = qt.QVBoxLayout(self)
+
+ self.manualFwhmGB = qt.QGroupBox("Define FWHM manually", self)
+ self.manualFwhmGB.setCheckable(True)
+ self.manualFwhmGB.setToolTip(
+ "If disabled, the FWHM parameter used for peak search is " +
+ "estimated based on the highest peak in the data")
+ layout.addWidget(self.manualFwhmGB)
+ # ------------ GroupBox fwhm--------------------------
+ layout2 = qt.QHBoxLayout(self.manualFwhmGB)
+ self.manualFwhmGB.setLayout(layout2)
+
+ label = qt.QLabel("Fwhm Points", self.manualFwhmGB)
+ layout2.addWidget(label)
+
+ self.fwhmPointsSpin = qt.QSpinBox(self.manualFwhmGB)
+ self.fwhmPointsSpin.setRange(0, 999999)
+ self.fwhmPointsSpin.setToolTip("Typical peak fwhm (number of data points)")
+ layout2.addWidget(self.fwhmPointsSpin)
+ # ----------------------------------------------------
+
+ self.manualScalingGB = qt.QGroupBox("Define scaling manually", self)
+ self.manualScalingGB.setCheckable(True)
+ self.manualScalingGB.setToolTip(
+ "If disabled, the Y scaling used for peak search is " +
+ "estimated automatically")
+ layout.addWidget(self.manualScalingGB)
+ # ------------ GroupBox scaling-----------------------
+ layout3 = qt.QHBoxLayout(self.manualScalingGB)
+ self.manualScalingGB.setLayout(layout3)
+
+ label = qt.QLabel("Y Scaling", self.manualScalingGB)
+ layout3.addWidget(label)
+
+ self.yScalingEntry = qt.QLineEdit(self.manualScalingGB)
+ self.yScalingEntry.setToolTip(
+ "Data values will be multiplied by this value prior to peak" +
+ " search")
+ self.yScalingEntry.setValidator(qt.QDoubleValidator(self))
+ layout3.addWidget(self.yScalingEntry)
+ # ----------------------------------------------------
+
+ # ------------------- grid layout --------------------
+ containerWidget = qt.QWidget(self)
+ layout4 = qt.QHBoxLayout(containerWidget)
+ containerWidget.setLayout(layout4)
+
+ label = qt.QLabel("Sensitivity", containerWidget)
+ layout4.addWidget(label)
+
+ self.sensitivityEntry = qt.QLineEdit(containerWidget)
+ self.sensitivityEntry.setToolTip(
+ "Peak search sensitivity threshold, expressed as a multiple " +
+ "of the standard deviation of the noise.\nMinimum value is 1 " +
+ "(to be detected, peak must be higher than the estimated noise)")
+ sensivalidator = qt.QDoubleValidator(self)
+ sensivalidator.setBottom(1.0)
+ self.sensitivityEntry.setValidator(sensivalidator)
+ layout4.addWidget(self.sensitivityEntry)
+ # ----------------------------------------------------
+ layout.addWidget(containerWidget)
+
+ self.forcePeakPresenceCB = qt.QCheckBox("Force peak presence", self)
+ self.forcePeakPresenceCB.setToolTip(
+ "If peak search algorithm is unsuccessful, place one peak " +
+ "at the maximum of the curve")
+ layout.addWidget(self.forcePeakPresenceCB)
+
+ layout.addStretch()
+
+ self.setDefault()
+
+ def setDefault(self, default_dict=None):
+ """Set default values for all widgets.
+
+ :param default_dict: If a default config dictionary is provided as
+ a parameter, its values are used as default values."""
+ if default_dict is None:
+ default_dict = {}
+ self.manualFwhmGB.setChecked(
+ not default_dict.get('AutoFwhm', True))
+ self.fwhmPointsSpin.setValue(
+ default_dict.get('FwhmPoints', 8))
+ self.sensitivityEntry.setText(
+ str(default_dict.get('Sensitivity', 1.0)))
+ self.manualScalingGB.setChecked(
+ not default_dict.get('AutoScaling', False))
+ self.yScalingEntry.setText(
+ str(default_dict.get('Yscaling', 1.0)))
+ self.forcePeakPresenceCB.setChecked(
+ default_dict.get('ForcePeakPresence', False))
+
+ def get(self):
+ """Return a dictionary of peak search parameters, to be processed by
+ the :meth:`configure` method of the selected fit theory."""
+ ddict = {
+ 'AutoFwhm': not self.manualFwhmGB.isChecked(),
+ 'FwhmPoints': self.fwhmPointsSpin.value(),
+ 'Sensitivity': safe_float(self.sensitivityEntry.text()),
+ 'AutoScaling': not self.manualScalingGB.isChecked(),
+ 'Yscaling': safe_float(self.yScalingEntry.text()),
+ 'ForcePeakPresence': self.forcePeakPresenceCB.isChecked()
+ }
+ return ddict
+
+
+class BackgroundPage(qt.QGroupBox):
+ """Background subtraction configuration, specific to fittheories
+ estimation functions."""
+ def __init__(self, parent=None,
+ title="Subtract strip background prior to estimation"):
+ super(BackgroundPage, self).__init__(parent)
+ self.setTitle(title)
+ self.setCheckable(True)
+ self.setToolTip(
+ "The strip algorithm strips away peaks to compute the " +
+ "background signal.\nAt each iteration, a sample is compared " +
+ "to the average of the two samples at a given distance in both" +
+ " directions,\n and if its value is higher than the average,"
+ "it is replaced by the average.")
+
+ layout = qt.QGridLayout(self)
+ self.setLayout(layout)
+
+ for i, label_text in enumerate(
+ ["Strip width (in samples)",
+ "Number of iterations",
+ "Strip threshold factor"]):
+ label = qt.QLabel(label_text)
+ layout.addWidget(label, i, 0)
+
+ self.stripWidthSpin = qt.QSpinBox(self)
+ self.stripWidthSpin.setToolTip(
+ "Width, in number of samples, of the strip operator")
+ self.stripWidthSpin.setRange(1, 999999)
+
+ layout.addWidget(self.stripWidthSpin, 0, 1)
+
+ self.numIterationsSpin = qt.QSpinBox(self)
+ self.numIterationsSpin.setToolTip(
+ "Number of iterations of the strip algorithm")
+ self.numIterationsSpin.setRange(1, 999999)
+ layout.addWidget(self.numIterationsSpin, 1, 1)
+
+ self.thresholdFactorEntry = qt.QLineEdit(self)
+ self.thresholdFactorEntry.setToolTip(
+ "Factor used by the strip algorithm to decide whether a sample" +
+ "value should be stripped.\nThe value must be higher than the " +
+ "average of the 2 samples at +- w times this factor.\n")
+ self.thresholdFactorEntry.setValidator(qt.QDoubleValidator(self))
+ layout.addWidget(self.thresholdFactorEntry, 2, 1)
+
+ self.smoothStripGB = qt.QGroupBox("Apply smoothing prior to strip", self)
+ self.smoothStripGB.setCheckable(True)
+ self.smoothStripGB.setToolTip(
+ "Apply a smoothing before subtracting strip background" +
+ " in fit and estimate processes")
+ smoothlayout = qt.QHBoxLayout(self.smoothStripGB)
+ label = qt.QLabel("Smoothing width (Savitsky-Golay)")
+ smoothlayout.addWidget(label)
+ self.smoothingWidthSpin = qt.QSpinBox(self)
+ self.smoothingWidthSpin.setToolTip(
+ "Width parameter for Savitsky-Golay smoothing (number of samples, must be odd)")
+ self.smoothingWidthSpin.setRange(3, 101)
+ self.smoothingWidthSpin.setSingleStep(2)
+ smoothlayout.addWidget(self.smoothingWidthSpin)
+
+ layout.addWidget(self.smoothStripGB, 3, 0, 1, 2)
+
+ layout.setRowStretch(4, 1)
+
+ self.setDefault()
+
+ def setDefault(self, default_dict=None):
+ """Set default values for all widgets.
+
+ :param default_dict: If a default config dictionary is provided as
+ a parameter, its values are used as default values."""
+ if default_dict is None:
+ default_dict = {}
+
+ self.setChecked(
+ default_dict.get('StripBackgroundFlag', True))
+
+ self.stripWidthSpin.setValue(
+ default_dict.get('StripWidth', 2))
+ self.numIterationsSpin.setValue(
+ default_dict.get('StripIterations', 5000))
+ self.thresholdFactorEntry.setText(
+ str(default_dict.get('StripThreshold', 1.0)))
+ self.smoothStripGB.setChecked(
+ default_dict.get('SmoothingFlag', False))
+ self.smoothingWidthSpin.setValue(
+ default_dict.get('SmoothingWidth', 3))
+
+ def get(self):
+ """Return a dictionary of background subtraction parameters, to be
+ processed by the :meth:`configure` method of the selected fit theory.
+ """
+ ddict = {
+ 'StripBackgroundFlag': self.isChecked(),
+ 'StripWidth': self.stripWidthSpin.value(),
+ 'StripIterations': self.numIterationsSpin.value(),
+ 'StripThreshold': safe_float(self.thresholdFactorEntry.text()),
+ 'SmoothingFlag': self.smoothStripGB.isChecked(),
+ 'SmoothingWidth': self.smoothingWidthSpin.value()
+ }
+ return ddict
+
+
+def safe_float(string_, default=1.0):
+ """Convert a string into a float.
+ If the conversion fails, return the default value.
+ """
+ try:
+ ret = float(string_)
+ except ValueError:
+ return default
+ else:
+ return ret
+
+
+def safe_int(string_, default=1):
+ """Convert a string into a integer.
+ If the conversion fails, return the default value.
+ """
+ try:
+ ret = int(float(string_))
+ except ValueError:
+ return default
+ else:
+ return ret
+
+
+def getFitConfigDialog(parent=None, default=None, modal=True):
+ """Instantiate and return a fit configuration dialog, adapted
+ for configuring standard fit theories from
+ :mod:`silx.math.fit.fittheories`.
+
+ :return: Instance of :class:`TabsDialogData` with 3 tabs:
+ :class:`ConstraintsPage`, :class:`SearchPage` and
+ :class:`BackgroundPage`
+ """
+ tdd = TabsDialogData(parent=parent, default=default)
+ tdd.addTab(ConstraintsPage(), label="Constraints")
+ tdd.addTab(SearchPage(), label="Peak search")
+ tdd.addTab(BackgroundPage(), label="Background")
+ # apply default to newly added pages
+ tdd.setDefault()
+
+ return tdd
+
+
+def main():
+ a = qt.QApplication([])
+
+ mw = qt.QMainWindow()
+ mw.show()
+
+ tdd = getFitConfigDialog(mw, default={"a": 1})
+ tdd.show()
+ tdd.exec()
+ print("TabsDialogData result: ", tdd.result())
+ print("TabsDialogData output: ", tdd.output)
+
+ a.exec()
+
+if __name__ == "__main__":
+ main()
diff --git a/src/silx/gui/fit/FitWidget.py b/src/silx/gui/fit/FitWidget.py
new file mode 100644
index 0000000..52ecafe
--- /dev/null
+++ b/src/silx/gui/fit/FitWidget.py
@@ -0,0 +1,751 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+#
+# This file is part of the PyMca X-ray Fluorescence Toolkit developed at
+# the ESRF by the Software group.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ######################################################################### */
+"""This module provides a widget designed to configure and run a fitting
+process with constraints on parameters.
+
+The main class is :class:`FitWidget`. It relies on
+:mod:`silx.math.fit.fitmanager`, which relies on :func:`silx.math.fit.leastsq`.
+
+The user can choose between functions before running the fit. These function can
+be user defined, or by default are loaded from
+:mod:`silx.math.fit.fittheories`.
+"""
+
+__authors__ = ["V.A. Sole", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "17/07/2018"
+
+import logging
+import sys
+import traceback
+
+from silx.math.fit import fittheories
+from silx.math.fit import fitmanager, functions
+from silx.gui import qt
+from .FitWidgets import (FitActionsButtons, FitStatusLines,
+ FitConfigWidget, ParametersTab)
+from .FitConfig import getFitConfigDialog
+from .BackgroundWidget import getBgDialog, BackgroundDialog
+from ...utils.deprecation import deprecated
+
+DEBUG = 0
+_logger = logging.getLogger(__name__)
+
+
+__authors__ = ["V.A. Sole", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "30/11/2016"
+
+
+class FitWidget(qt.QWidget):
+ """This widget can be used to configure, run and display results of a
+ fitting process.
+
+ The standard steps for using this widget is to initialize it, then load
+ the data to be fitted.
+
+ Optionally, you can also load user defined fit theories. If you skip this
+ step, a series of default fit functions will be presented (gaussian-like
+ functions), and you can later load your custom fit theories from an
+ external file using the GUI.
+
+ A fit theory is a fit function and its associated features:
+
+ - estimation function,
+ - list of parameter names
+ - numerical derivative algorithm
+ - configuration widget
+
+ Once the widget is up and running, the user may select a fit theory and a
+ background theory, change configuration parameters specific to the theory
+ run the estimation, set constraints on parameters and run the actual fit.
+
+ The results are displayed in a table.
+
+ .. image:: img/FitWidget.png
+ """
+ sigFitWidgetSignal = qt.Signal(object)
+ """This signal is emitted by the estimation and fit methods.
+ It carries a dictionary with two items:
+
+ - *event*: one of the following strings
+
+ - *EstimateStarted*,
+ - *FitStarted*
+ - *EstimateFinished*,
+ - *FitFinished*
+ - *EstimateFailed*
+ - *FitFailed*
+
+ - *data*: None, or fit/estimate results (see documentation for
+ :attr:`silx.math.fit.fitmanager.FitManager.fit_results`)
+ """
+
+ def __init__(self, parent=None, title=None, fitmngr=None,
+ enableconfig=True, enablestatus=True, enablebuttons=True):
+ """
+
+ :param parent: Parent widget
+ :param title: Window title
+ :param fitmngr: User defined instance of
+ :class:`silx.math.fit.fitmanager.FitManager`, or ``None``
+ :param enableconfig: If ``True``, activate widgets to modify the fit
+ configuration (select between several fit functions or background
+ functions, apply global constraints, peak search parameters…)
+ :param enablestatus: If ``True``, add a fit status widget, to display
+ a message when fit estimation is available and when fit results
+ are available, as well as a measure of the fit error.
+ :param enablebuttons: If ``True``, add buttons to run estimation and
+ fitting.
+ """
+ if title is None:
+ title = "FitWidget"
+ qt.QWidget.__init__(self, parent)
+
+ self.setWindowTitle(title)
+ layout = qt.QVBoxLayout(self)
+
+ self.fitmanager = self._setFitManager(fitmngr)
+ """Instance of :class:`FitManager`.
+ This is the underlying data model of this FitWidget.
+
+ If no custom theories are defined, the default ones from
+ :mod:`silx.math.fit.fittheories` are imported.
+ """
+
+ # reference fitmanager.configure method for direct access
+ self.configure = self.fitmanager.configure
+ self.fitconfig = self.fitmanager.fitconfig
+
+ self.configdialogs = {}
+ """This dictionary defines the fit configuration widgets
+ associated with the fit theories in :attr:`fitmanager.theories`
+
+ Keys must correspond to existing theory names, i.e. existing keys
+ in :attr:`fitmanager.theories`.
+
+ Values must be instances of QDialog widgets with an additional
+ *output* attribute, a dictionary storing configuration parameters
+ interpreted by the corresponding fit theory.
+
+ The dialog can also define a *setDefault* method to initialize the
+ widget values with values in a dictionary passed as a parameter.
+ This will be executed first.
+
+ In case the widget does not actually inherit :class:`QDialog`, it
+ must at least implement the following methods (executed in this
+ particular order):
+
+ - :meth:`show`: should cause the widget to become visible to the
+ user)
+ - :meth:`exec`: should run while the user is interacting with the
+ widget, interrupting the rest of the program. It should
+ typically end (*return*) when the user clicks an *OK*
+ or a *Cancel* button.
+ - :meth:`result`: must return ``True`` if the new configuration in
+ attribute :attr:`output` is to be accepted (user clicked *OK*),
+ or return ``False`` if :attr:`output` is to be rejected (user
+ clicked *Cancel*)
+
+ To associate a custom configuration widget with a fit theory, use
+ :meth:`associateConfigDialog`. E.g.::
+
+ fw = FitWidget()
+ my_config_widget = MyGaussianConfigWidget(parent=fw)
+ fw.associateConfigDialog(theory_name="Gaussians",
+ config_widget=my_config_widget)
+ """
+
+ self.bgconfigdialogs = {}
+ """Same as :attr:`configdialogs`, except that the widget is associated
+ with a background theory in :attr:`fitmanager.bgtheories`"""
+
+ self._associateDefaultConfigDialogs()
+
+ self.guiConfig = None
+ """Configuration widget at the top of FitWidget, to select
+ fit function, background function, and open an advanced
+ configuration dialog."""
+
+ self.guiParameters = ParametersTab(self)
+ """Table widget for display of fit parameters and constraints"""
+
+ if enableconfig:
+ self.guiConfig = FitConfigWidget(self)
+ """Function selector and configuration widget"""
+
+ self.guiConfig.FunConfigureButton.clicked.connect(
+ self.__funConfigureGuiSlot)
+ self.guiConfig.BgConfigureButton.clicked.connect(
+ self.__bgConfigureGuiSlot)
+
+ self.guiConfig.WeightCheckBox.setChecked(
+ self.fitconfig.get("WeightFlag", False))
+ self.guiConfig.WeightCheckBox.stateChanged[int].connect(self.weightEvent)
+
+ if qt.BINDING in ('PySide2', 'PyQt5'):
+ self.guiConfig.BkgComBox.activated[str].connect(self.bkgEvent)
+ self.guiConfig.FunComBox.activated[str].connect(self.funEvent)
+ else: # Qt6
+ self.guiConfig.BkgComBox.textActivated.connect(self.bkgEvent)
+ self.guiConfig.FunComBox.textActivated.connect(self.funEvent)
+
+ self._populateFunctions()
+
+ layout.addWidget(self.guiConfig)
+
+ layout.addWidget(self.guiParameters)
+
+ if enablestatus:
+ self.guistatus = FitStatusLines(self)
+ """Status bar"""
+ layout.addWidget(self.guistatus)
+
+ if enablebuttons:
+ self.guibuttons = FitActionsButtons(self)
+ """Widget with estimate, start fit and dismiss buttons"""
+ self.guibuttons.EstimateButton.clicked.connect(self.estimate)
+ self.guibuttons.EstimateButton.setEnabled(False)
+ self.guibuttons.StartFitButton.clicked.connect(self.startFit)
+ self.guibuttons.StartFitButton.setEnabled(False)
+ self.guibuttons.DismissButton.clicked.connect(self.dismiss)
+ layout.addWidget(self.guibuttons)
+
+ def _setFitManager(self, fitinstance):
+ """Initialize a :class:`FitManager` instance, to be assigned to
+ :attr:`fitmanager`, or use a custom FitManager instance.
+
+ :param fitinstance: Existing instance of FitManager, possibly
+ customized by the user, or None to load a default instance."""
+ if isinstance(fitinstance, fitmanager.FitManager):
+ # customized
+ fitmngr = fitinstance
+ else:
+ # initialize default instance
+ fitmngr = fitmanager.FitManager()
+
+ # initialize the default fitting functions in case
+ # none is present
+ if not len(fitmngr.theories):
+ fitmngr.loadtheories(fittheories)
+
+ return fitmngr
+
+ def _associateDefaultConfigDialogs(self):
+ """Fill :attr:`bgconfigdialogs` and :attr:`configdialogs` by calling
+ :meth:`associateConfigDialog` with default config dialog widgets.
+ """
+ # associate silx.gui.fit.FitConfig with all theories
+ # Users can later associate their own custom dialogs to
+ # replace the default.
+ configdialog = getFitConfigDialog(parent=self,
+ default=self.fitconfig)
+ for theory in self.fitmanager.theories:
+ self.associateConfigDialog(theory, configdialog)
+ for bgtheory in self.fitmanager.bgtheories:
+ self.associateConfigDialog(bgtheory, configdialog,
+ theory_is_background=True)
+
+ # associate silx.gui.fit.BackgroundWidget with Strip and Snip
+ bgdialog = getBgDialog(parent=self,
+ default=self.fitconfig)
+ for bgtheory in ["Strip", "Snip"]:
+ if bgtheory in self.fitmanager.bgtheories:
+ self.associateConfigDialog(bgtheory, bgdialog,
+ theory_is_background=True)
+
+ def _populateFunctions(self):
+ """Fill combo-boxes with fit theories and background theories
+ loaded by :attr:`fitmanager`.
+ Run :meth:`fitmanager.configure` to ensure the custom configuration
+ of the selected theory has been loaded into :attr:`fitconfig`"""
+ for theory_name in self.fitmanager.bgtheories:
+ self.guiConfig.BkgComBox.addItem(theory_name)
+ self.guiConfig.BkgComBox.setItemData(
+ self.guiConfig.BkgComBox.findText(theory_name),
+ self.fitmanager.bgtheories[theory_name].description,
+ qt.Qt.ToolTipRole)
+
+ for theory_name in self.fitmanager.theories:
+ self.guiConfig.FunComBox.addItem(theory_name)
+ self.guiConfig.FunComBox.setItemData(
+ self.guiConfig.FunComBox.findText(theory_name),
+ self.fitmanager.theories[theory_name].description,
+ qt.Qt.ToolTipRole)
+
+ # - activate selected fit theory (if any)
+ # - activate selected bg theory (if any)
+ configuration = self.fitmanager.configure()
+ if self.fitmanager.selectedtheory is None:
+ # take the first one by default
+ self.guiConfig.FunComBox.setCurrentIndex(1)
+ self.funEvent(list(self.fitmanager.theories.keys())[0])
+ else:
+ idx = list(self.fitmanager.theories).index(self.fitmanager.selectedtheory)
+ self.guiConfig.FunComBox.setCurrentIndex(idx + 1)
+ self.funEvent(self.fitmanager.selectedtheory)
+
+ if self.fitmanager.selectedbg is None:
+ self.guiConfig.BkgComBox.setCurrentIndex(1)
+ self.bkgEvent(list(self.fitmanager.bgtheories.keys())[0])
+ else:
+ idx = list(self.fitmanager.bgtheories).index(self.fitmanager.selectedbg)
+ self.guiConfig.BkgComBox.setCurrentIndex(idx + 1)
+ self.bkgEvent(self.fitmanager.selectedbg)
+
+ configuration.update(self.configure())
+
+ @deprecated(replacement='setData', since_version='0.3.0')
+ def setdata(self, x, y, sigmay=None, xmin=None, xmax=None):
+ self.setData(x, y, sigmay, xmin, xmax)
+
+ def setData(self, x=None, y=None, sigmay=None, xmin=None, xmax=None):
+ """Set data to be fitted.
+
+ :param x: Abscissa data. If ``None``, :attr:`xdata`` is set to
+ ``numpy.array([0.0, 1.0, 2.0, ..., len(y)-1])``
+ :type x: Sequence or numpy array or None
+ :param y: The dependant data ``y = f(x)``. ``y`` must have the same
+ shape as ``x`` if ``x`` is not ``None``.
+ :type y: Sequence or numpy array or None
+ :param sigmay: The uncertainties in the ``ydata`` array. These are
+ used as weights in the least-squares problem.
+ If ``None``, the uncertainties are assumed to be 1.
+ :type sigmay: Sequence or numpy array or None
+ :param xmin: Lower value of x values to use for fitting
+ :param xmax: Upper value of x values to use for fitting
+ """
+ if y is None:
+ self.guibuttons.EstimateButton.setEnabled(False)
+ self.guibuttons.StartFitButton.setEnabled(False)
+ else:
+ self.guibuttons.EstimateButton.setEnabled(True)
+ self.guibuttons.StartFitButton.setEnabled(True)
+ self.fitmanager.setdata(x=x, y=y, sigmay=sigmay,
+ xmin=xmin, xmax=xmax)
+ for config_dialog in self.bgconfigdialogs.values():
+ if isinstance(config_dialog, BackgroundDialog):
+ config_dialog.setData(x, y, xmin=xmin, xmax=xmax)
+
+ def associateConfigDialog(self, theory_name, config_widget,
+ theory_is_background=False):
+ """Associate an instance of custom configuration dialog widget to
+ a fit theory or to a background theory.
+
+ This adds or modifies an item in the correspondence table
+ :attr:`configdialogs` or :attr:`bgconfigdialogs`.
+
+ :param str theory_name: Name of fit theory. This must be a key of dict
+ :attr:`fitmanager.theories`
+ :param config_widget: Custom configuration widget. See documentation
+ for :attr:`configdialogs`
+ :param bool theory_is_background: If flag is *True*, add dialog to
+ :attr:`bgconfigdialogs` rather than :attr:`configdialogs`
+ (default).
+ :raise: KeyError if parameter ``theory_name`` does not match an
+ existing fit theory or background theory in :attr:`fitmanager`.
+ :raise: AttributeError if the widget does not implement the mandatory
+ methods (*show*, *exec*, *result*, *setDefault*) or the mandatory
+ attribute (*output*).
+ """
+ theories = self.fitmanager.bgtheories if theory_is_background else\
+ self.fitmanager.theories
+
+ if theory_name not in theories:
+ raise KeyError("%s does not match an existing fitmanager theory")
+
+ if config_widget is not None:
+ if (not hasattr(config_widget, "exec") and
+ not hasattr(config_widget, "exec_")):
+ raise AttributeError(
+ "Custom configuration widget must define exec or exec_")
+
+ for mandatory_attr in ["show", "result", "output"]:
+ if not hasattr(config_widget, mandatory_attr):
+ raise AttributeError(
+ "Custom configuration widget must define " +
+ "attribute or method " + mandatory_attr)
+
+ if theory_is_background:
+ self.bgconfigdialogs[theory_name] = config_widget
+ else:
+ self.configdialogs[theory_name] = config_widget
+
+ def _emitSignal(self, ddict):
+ """Emit pyqtSignal after estimation completed
+ (``ddict = {'event': 'EstimateFinished', 'data': fit_results}``)
+ and after fit completed
+ (``ddict = {'event': 'FitFinished', 'data': fit_results}``)"""
+ self.sigFitWidgetSignal.emit(ddict)
+
+ def __funConfigureGuiSlot(self):
+ """Open an advanced configuration dialog widget"""
+ self.__configureGui(dialog_type="function")
+
+ def __bgConfigureGuiSlot(self):
+ """Open an advanced configuration dialog widget"""
+ self.__configureGui(dialog_type="background")
+
+ def __configureGui(self, newconfiguration=None, dialog_type="function"):
+ """Open an advanced configuration dialog widget to get a configuration
+ dictionary, or use a supplied configuration dictionary. Call
+ :meth:`configure` with this dictionary as a parameter. Update the gui
+ accordingly. Reinitialize the fit results in the table and in
+ :attr:`fitmanager`.
+
+ :param newconfiguration: User supplied configuration dictionary. If ``None``,
+ open a dialog widget that returns a dictionary."""
+ configuration = self.configure()
+ # get new dictionary
+ if newconfiguration is None:
+ newconfiguration = self.configureDialog(configuration, dialog_type)
+ # update configuration
+ configuration.update(self.configure(**newconfiguration))
+ # set fit function theory
+ try:
+ i = 1 + \
+ list(self.fitmanager.theories.keys()).index(
+ self.fitmanager.selectedtheory)
+ self.guiConfig.FunComBox.setCurrentIndex(i)
+ self.funEvent(self.fitmanager.selectedtheory)
+ except ValueError:
+ _logger.error("Function not in list %s",
+ self.fitmanager.selectedtheory)
+ self.funEvent(list(self.fitmanager.theories.keys())[0])
+ # current background
+ try:
+ i = 1 + \
+ list(self.fitmanager.bgtheories.keys()).index(
+ self.fitmanager.selectedbg)
+ self.guiConfig.BkgComBox.setCurrentIndex(i)
+ self.bkgEvent(self.fitmanager.selectedbg)
+ except ValueError:
+ _logger.error("Background not in list %s",
+ self.fitmanager.selectedbg)
+ self.bkgEvent(list(self.fitmanager.bgtheories.keys())[0])
+
+ # update the Gui
+ self.__initialParameters()
+
+ def configureDialog(self, oldconfiguration, dialog_type="function"):
+ """Display a dialog, allowing the user to define fit configuration
+ parameters.
+
+ By default, a common dialog is used for all fit theories. But if the
+ defined a custom dialog using :meth:`associateConfigDialog`, it is
+ used instead.
+
+ :param dict oldconfiguration: Dictionary containing previous configuration
+ :param str dialog_type: "function" or "background"
+ :return: User defined parameters in a dictionary
+ """
+ newconfiguration = {}
+ newconfiguration.update(oldconfiguration)
+
+ if dialog_type == "function":
+ theory = self.fitmanager.selectedtheory
+ configdialog = self.configdialogs[theory]
+ elif dialog_type == "background":
+ theory = self.fitmanager.selectedbg
+ configdialog = self.bgconfigdialogs[theory]
+
+ # this should only happen if a user specifically associates None
+ # with a theory, to have no configuration option
+ if configdialog is None:
+ return {}
+
+ # update state of configdialog before showing it
+ if hasattr(configdialog, "setDefault"):
+ configdialog.setDefault(newconfiguration)
+ configdialog.show()
+ if hasattr(configdialog, "exec"):
+ configdialog.exec()
+ else: # Qt5 compatibility
+ configdialog.exec_()
+ if configdialog.result():
+ newconfiguration.update(configdialog.output)
+
+ return newconfiguration
+
+ def estimate(self):
+ """Run parameter estimation function then emit
+ :attr:`sigFitWidgetSignal` with a dictionary containing a status
+ message and a list of fit parameters estimations
+ in the format defined in
+ :attr:`silx.math.fit.fitmanager.FitManager.fit_results`
+
+ The emitted dictionary has an *"event"* key that can have
+ following values:
+
+ - *'EstimateStarted'*
+ - *'EstimateFailed'*
+ - *'EstimateFinished'*
+ """
+ try:
+ theory_name = self.fitmanager.selectedtheory
+ estimation_function = self.fitmanager.theories[theory_name].estimate
+ if estimation_function is not None:
+ ddict = {'event': 'EstimateStarted',
+ 'data': None}
+ self._emitSignal(ddict)
+ self.fitmanager.estimate(callback=self.fitStatus)
+ else:
+ msg = qt.QMessageBox(self)
+ msg.setIcon(qt.QMessageBox.Information)
+ text = "Function does not define a way to estimate\n"
+ text += "the initial parameters. Please, fill them\n"
+ text += "yourself in the table and press Start Fit\n"
+ msg.setText(text)
+ msg.setWindowTitle('FitWidget Message')
+ msg.exec()
+ return
+ except Exception as e: # noqa (we want to catch and report all errors)
+ _logger.warning('Estimate error: %s', traceback.format_exc())
+ msg = qt.QMessageBox(self)
+ msg.setIcon(qt.QMessageBox.Critical)
+ msg.setWindowTitle("Estimate Error")
+ msg.setText("Error on estimate: %s" % e)
+ msg.exec()
+ ddict = {
+ 'event': 'EstimateFailed',
+ 'data': None}
+ self._emitSignal(ddict)
+ return
+
+ self.guiParameters.fillFromFit(
+ self.fitmanager.fit_results, view='Fit')
+ self.guiParameters.removeAllViews(keep='Fit')
+ ddict = {
+ 'event': 'EstimateFinished',
+ 'data': self.fitmanager.fit_results}
+ self._emitSignal(ddict)
+
+ @deprecated(replacement='startFit', since_version='0.3.0')
+ def startfit(self):
+ self.startFit()
+
+ def startFit(self):
+ """Run fit, then emit :attr:`sigFitWidgetSignal` with a dictionary
+ containing a status message and a list of fit
+ parameters results in the format defined in
+ :attr:`silx.math.fit.fitmanager.FitManager.fit_results`
+
+ The emitted dictionary has an *"event"* key that can have
+ following values:
+
+ - *'FitStarted'*
+ - *'FitFailed'*
+ - *'FitFinished'*
+ """
+ self.fitmanager.fit_results = self.guiParameters.getFitResults()
+ try:
+ ddict = {'event': 'FitStarted',
+ 'data': None}
+ self._emitSignal(ddict)
+ self.fitmanager.runfit(callback=self.fitStatus)
+ except Exception as e: # noqa (we want to catch and report all errors)
+ _logger.warning('Estimate error: %s', traceback.format_exc())
+ msg = qt.QMessageBox(self)
+ msg.setIcon(qt.QMessageBox.Critical)
+ msg.setWindowTitle("Fit Error")
+ msg.setText("Error on Fit: %s" % e)
+ msg.exec()
+ ddict = {
+ 'event': 'FitFailed',
+ 'data': None
+ }
+ self._emitSignal(ddict)
+ return
+
+ self.guiParameters.fillFromFit(
+ self.fitmanager.fit_results, view='Fit')
+ self.guiParameters.removeAllViews(keep='Fit')
+ ddict = {
+ 'event': 'FitFinished',
+ 'data': self.fitmanager.fit_results
+ }
+ self._emitSignal(ddict)
+ return
+
+ def bkgEvent(self, bgtheory):
+ """Select background theory, then reinitialize parameters"""
+ bgtheory = str(bgtheory)
+ if bgtheory in self.fitmanager.bgtheories:
+ self.fitmanager.setbackground(bgtheory)
+ else:
+ functionsfile = qt.QFileDialog.getOpenFileName(
+ self, "Select python module with your function(s)", "",
+ "Python Files (*.py);;All Files (*)")
+
+ if len(functionsfile):
+ try:
+ self.fitmanager.loadbgtheories(functionsfile)
+ except ImportError:
+ qt.QMessageBox.critical(self, "ERROR",
+ "Function not imported")
+ return
+ else:
+ # empty the ComboBox
+ while self.guiConfig.BkgComBox.count() > 1:
+ self.guiConfig.BkgComBox.removeItem(1)
+ # and fill it again
+ for key in self.fitmanager.bgtheories:
+ self.guiConfig.BkgComBox.addItem(str(key))
+
+ i = 1 + \
+ list(self.fitmanager.bgtheories.keys()).index(
+ self.fitmanager.selectedbg)
+ self.guiConfig.BkgComBox.setCurrentIndex(i)
+ self.__initialParameters()
+
+ def funEvent(self, theoryname):
+ """Select a fit theory to be used for fitting. If this theory exists
+ in :attr:`fitmanager`, use it. Then, reinitialize table.
+
+ :param theoryname: Name of the fit theory to use for fitting. If this theory
+ exists in :attr:`fitmanager`, use it. Else, open a file dialog to open
+ a custom fit function definition file with
+ :meth:`fitmanager.loadtheories`.
+ """
+ theoryname = str(theoryname)
+ if theoryname in self.fitmanager.theories:
+ self.fitmanager.settheory(theoryname)
+ else:
+ # open a load file dialog
+ functionsfile = qt.QFileDialog.getOpenFileName(
+ self, "Select python module with your function(s)", "",
+ "Python Files (*.py);;All Files (*)")
+
+ if len(functionsfile):
+ try:
+ self.fitmanager.loadtheories(functionsfile)
+ except ImportError:
+ qt.QMessageBox.critical(self, "ERROR",
+ "Function not imported")
+ return
+ else:
+ # empty the ComboBox
+ while self.guiConfig.FunComBox.count() > 1:
+ self.guiConfig.FunComBox.removeItem(1)
+ # and fill it again
+ for key in self.fitmanager.theories:
+ self.guiConfig.FunComBox.addItem(str(key))
+
+ i = 1 + \
+ list(self.fitmanager.theories.keys()).index(
+ self.fitmanager.selectedtheory)
+ self.guiConfig.FunComBox.setCurrentIndex(i)
+ self.__initialParameters()
+
+ def weightEvent(self, flag):
+ """This is called when WeightCheckBox is clicked, to configure the
+ *WeightFlag* field in :attr:`fitmanager.fitconfig` and set weights
+ in the least-square problem."""
+ self.configure(WeightFlag=flag)
+ if flag:
+ self.fitmanager.enableweight()
+ else:
+ # set weights back to 1
+ self.fitmanager.disableweight()
+
+ def __initialParameters(self):
+ """Fill the fit parameters names with names of the parameters of
+ the selected background theory and the selected fit theory.
+ Initialize :attr:`fitmanager.fit_results` with these names, and
+ initialize the table with them. This creates a view called "Fit"
+ in :attr:`guiParameters`"""
+ self.fitmanager.parameter_names = []
+ self.fitmanager.fit_results = []
+ for pname in self.fitmanager.bgtheories[self.fitmanager.selectedbg].parameters:
+ self.fitmanager.parameter_names.append(pname)
+ self.fitmanager.fit_results.append({'name': pname,
+ 'estimation': 0,
+ 'group': 0,
+ 'code': 'FREE',
+ 'cons1': 0,
+ 'cons2': 0,
+ 'fitresult': 0.0,
+ 'sigma': 0.0,
+ 'xmin': None,
+ 'xmax': None})
+ if self.fitmanager.selectedtheory is not None:
+ theory = self.fitmanager.selectedtheory
+ for pname in self.fitmanager.theories[theory].parameters:
+ self.fitmanager.parameter_names.append(pname + "1")
+ self.fitmanager.fit_results.append({'name': pname + "1",
+ 'estimation': 0,
+ 'group': 1,
+ 'code': 'FREE',
+ 'cons1': 0,
+ 'cons2': 0,
+ 'fitresult': 0.0,
+ 'sigma': 0.0,
+ 'xmin': None,
+ 'xmax': None})
+
+ self.guiParameters.fillFromFit(
+ self.fitmanager.fit_results, view='Fit')
+
+ def fitStatus(self, data):
+ """Set *status* and *chisq* in status bar"""
+ if 'chisq' in data:
+ if data['chisq'] is None:
+ self.guistatus.ChisqLine.setText(" ")
+ else:
+ chisq = data['chisq']
+ self.guistatus.ChisqLine.setText("%6.2f" % chisq)
+
+ if 'status' in data:
+ status = data['status']
+ self.guistatus.StatusLine.setText(str(status))
+
+ def dismiss(self):
+ """Close FitWidget"""
+ self.close()
+
+
+if __name__ == "__main__":
+ import numpy
+
+ x = numpy.arange(1500).astype(numpy.float64)
+ constant_bg = 3.14
+
+ p = [1000, 100., 30.0,
+ 500, 300., 25.,
+ 1700, 500., 35.,
+ 750, 700., 30.0,
+ 1234, 900., 29.5,
+ 302, 1100., 30.5,
+ 75, 1300., 21.]
+ y = functions.sum_gauss(x, *p) + constant_bg
+
+ a = qt.QApplication(sys.argv)
+ w = FitWidget()
+ w.setData(x=x, y=y)
+ w.show()
+ a.exec()
diff --git a/src/silx/gui/fit/FitWidgets.py b/src/silx/gui/fit/FitWidgets.py
new file mode 100644
index 0000000..0fcc6b7
--- /dev/null
+++ b/src/silx/gui/fit/FitWidgets.py
@@ -0,0 +1,555 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2004-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ######################################################################### */
+"""Collection of widgets used to build
+:class:`silx.gui.fit.FitWidget.FitWidget`"""
+
+from collections import OrderedDict
+
+from silx.gui import qt
+from silx.gui.fit.Parameters import Parameters
+
+
+__authors__ = ["V.A. Sole", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "13/10/2016"
+
+
+class FitActionsButtons(qt.QWidget):
+ """Widget with 3 ``QPushButton``:
+
+ The buttons can be accessed as public attributes::
+
+ - ``EstimateButton``
+ - ``StartFitButton``
+ - ``DismissButton``
+
+ You will typically need to access these attributes to connect the buttons
+ to actions. For instance, if you have 3 functions ``estimate``,
+ ``runfit`` and ``dismiss``, you can connect them like this::
+
+ >>> fit_actions_buttons = FitActionsButtons()
+ >>> fit_actions_buttons.EstimateButton.clicked.connect(estimate)
+ >>> fit_actions_buttons.StartFitButton.clicked.connect(runfit)
+ >>> fit_actions_buttons.DismissButton.clicked.connect(dismiss)
+
+ """
+
+ def __init__(self, parent=None):
+ qt.QWidget.__init__(self, parent)
+
+ self.resize(234, 53)
+
+ grid_layout = qt.QGridLayout(self)
+ grid_layout.setContentsMargins(11, 11, 11, 11)
+ grid_layout.setSpacing(6)
+ layout = qt.QHBoxLayout(None)
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.setSpacing(6)
+
+ self.EstimateButton = qt.QPushButton(self)
+ self.EstimateButton.setText("Estimate")
+ layout.addWidget(self.EstimateButton)
+ spacer = qt.QSpacerItem(20, 20,
+ qt.QSizePolicy.Expanding,
+ qt.QSizePolicy.Minimum)
+ layout.addItem(spacer)
+
+ self.StartFitButton = qt.QPushButton(self)
+ self.StartFitButton.setText("Start Fit")
+ layout.addWidget(self.StartFitButton)
+ spacer_2 = qt.QSpacerItem(20, 20,
+ qt.QSizePolicy.Expanding,
+ qt.QSizePolicy.Minimum)
+ layout.addItem(spacer_2)
+
+ self.DismissButton = qt.QPushButton(self)
+ self.DismissButton.setText("Dismiss")
+ layout.addWidget(self.DismissButton)
+
+ grid_layout.addLayout(layout, 0, 0)
+
+
+class FitStatusLines(qt.QWidget):
+ """Widget with 2 greyed out write-only ``QLineEdit``.
+
+ These text widgets can be accessed as public attributes::
+
+ - ``StatusLine``
+ - ``ChisqLine``
+
+ You will typically need to access these widgets to update the displayed
+ text::
+
+ >>> fit_status_lines = FitStatusLines()
+ >>> fit_status_lines.StatusLine.setText("Ready")
+ >>> fit_status_lines.ChisqLine.setText("%6.2f" % 0.01)
+
+ """
+
+ def __init__(self, parent=None):
+ qt.QWidget.__init__(self, parent)
+
+ self.resize(535, 47)
+
+ layout = qt.QHBoxLayout(self)
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.setSpacing(6)
+
+ self.StatusLabel = qt.QLabel(self)
+ self.StatusLabel.setText("Status:")
+ layout.addWidget(self.StatusLabel)
+
+ self.StatusLine = qt.QLineEdit(self)
+ self.StatusLine.setText("Ready")
+ self.StatusLine.setReadOnly(1)
+ layout.addWidget(self.StatusLine)
+
+ self.ChisqLabel = qt.QLabel(self)
+ self.ChisqLabel.setText("Reduced chisq:")
+ layout.addWidget(self.ChisqLabel)
+
+ self.ChisqLine = qt.QLineEdit(self)
+ self.ChisqLine.setMaximumSize(qt.QSize(16000, 32767))
+ self.ChisqLine.setText("")
+ self.ChisqLine.setReadOnly(1)
+ layout.addWidget(self.ChisqLine)
+
+
+class FitConfigWidget(qt.QWidget):
+ """Widget whose purpose is to select a fit theory and a background
+ theory, load a new fit theory definition file and provide
+ a "Configure" button to open an advanced configuration dialog.
+
+ This is used in :class:`silx.gui.fit.FitWidget.FitWidget`, to offer
+ an interface to quickly modify the main parameters prior to running a fit:
+
+ - select a fitting function through :attr:`FunComBox`
+ - select a background function through :attr:`BkgComBox`
+ - open a dialog for modifying advanced parameters through
+ :attr:`FunConfigureButton`
+ """
+ def __init__(self, parent=None):
+ qt.QWidget.__init__(self, parent)
+
+ self.setWindowTitle("FitConfigGUI")
+
+ layout = qt.QGridLayout(self)
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.setSpacing(6)
+
+ self.FunLabel = qt.QLabel(self)
+ self.FunLabel.setText("Function")
+ layout.addWidget(self.FunLabel, 0, 0)
+
+ self.FunComBox = qt.QComboBox(self)
+ self.FunComBox.addItem("Add Function(s)")
+ self.FunComBox.setItemData(self.FunComBox.findText("Add Function(s)"),
+ "Load fit theories from a file",
+ qt.Qt.ToolTipRole)
+ layout.addWidget(self.FunComBox, 0, 1)
+
+ self.BkgLabel = qt.QLabel(self)
+ self.BkgLabel.setText("Background")
+ layout.addWidget(self.BkgLabel, 1, 0)
+
+ self.BkgComBox = qt.QComboBox(self)
+ self.BkgComBox.addItem("Add Background(s)")
+ self.BkgComBox.setItemData(self.BkgComBox.findText("Add Background(s)"),
+ "Load background theories from a file",
+ qt.Qt.ToolTipRole)
+ layout.addWidget(self.BkgComBox, 1, 1)
+
+ self.FunConfigureButton = qt.QPushButton(self)
+ self.FunConfigureButton.setText("Configure")
+ self.FunConfigureButton.setToolTip(
+ "Open a configuration dialog for the selected function")
+ layout.addWidget(self.FunConfigureButton, 0, 2)
+
+ self.BgConfigureButton = qt.QPushButton(self)
+ self.BgConfigureButton.setText("Configure")
+ self.BgConfigureButton.setToolTip(
+ "Open a configuration dialog for the selected background")
+ layout.addWidget(self.BgConfigureButton, 1, 2)
+
+ self.WeightCheckBox = qt.QCheckBox(self)
+ self.WeightCheckBox.setText("Weighted fit")
+ self.WeightCheckBox.setToolTip(
+ "Enable usage of weights in the least-square problem.\n Use" +
+ " the uncertainties (sigma) if provided, else use sqrt(y).")
+
+ layout.addWidget(self.WeightCheckBox, 0, 3, 2, 1)
+
+ layout.setColumnStretch(4, 1)
+
+
+class ParametersTab(qt.QTabWidget):
+ """This widget provides tabs to display and modify fit parameters. Each
+ tab contains a table with fit data such as parameter names, estimated
+ values, fit constraints, and final fit results.
+
+ The usual way to initialize the table is to fill it with the fit
+ parameters from a :class:`silx.math.fit.fitmanager.FitManager` object, after
+ the estimation process or after the final fit.
+
+ In the following example we use a :class:`ParametersTab` to display the
+ results of two separate fits::
+
+ from silx.math.fit import fittheories
+ from silx.math.fit import fitmanager
+ from silx.math.fit import functions
+ from silx.gui import qt
+ import numpy
+
+ a = qt.QApplication([])
+
+ # Create synthetic data
+ x = numpy.arange(1000)
+ y1 = functions.sum_gauss(x, 100, 400, 100)
+
+ fit = fitmanager.FitManager(x=x, y=y1)
+
+ fitfuns = fittheories.FitTheories()
+ fit.addtheory(theory="Gaussian",
+ function=functions.sum_gauss,
+ parameters=("height", "peak center", "fwhm"),
+ estimate=fitfuns.estimate_height_position_fwhm)
+ fit.settheory('Gaussian')
+ fit.configure(PositiveFwhmFlag=True,
+ PositiveHeightAreaFlag=True,
+ AutoFwhm=True,)
+
+ # Fit
+ fit.estimate()
+ fit.runfit()
+
+ # Show first fit result in a tab in our widget
+ w = ParametersTab()
+ w.show()
+ w.fillFromFit(fit.fit_results, view='Gaussians')
+
+ # new synthetic data
+ y2 = functions.sum_splitgauss(x,
+ 100, 400, 100, 40,
+ 10, 600, 50, 500,
+ 80, 850, 10, 50)
+ fit.setData(x=x, y=y2)
+
+ # Define new theory
+ fit.addtheory(theory="Asymetric gaussian",
+ function=functions.sum_splitgauss,
+ parameters=("height", "peak center", "left fwhm", "right fwhm"),
+ estimate=fitfuns.estimate_splitgauss)
+ fit.settheory('Asymetric gaussian')
+
+ # Fit
+ fit.estimate()
+ fit.runfit()
+
+ # Show first fit result in another tab in our widget
+ w.fillFromFit(fit.fit_results, view='Asymetric gaussians')
+ a.exec()
+
+ """
+
+ def __init__(self, parent=None, name="FitParameters"):
+ """
+
+ :param parent: Parent widget
+ :param name: Widget title
+ """
+ qt.QTabWidget.__init__(self, parent)
+ self.setWindowTitle(name)
+ self.setContentsMargins(0, 0, 0, 0)
+
+ self.views = OrderedDict()
+ """Dictionary of views. Keys are view names,
+ items are :class:`Parameters` widgets"""
+
+ self.latest_view = None
+ """Name of latest view"""
+
+ # the widgets/tables themselves
+ self.tables = {}
+ """Dictionary of :class:`silx.gui.fit.parameters.Parameters` objects.
+ These objects store fit results
+ """
+
+ self.setContentsMargins(10, 10, 10, 10)
+
+ def setView(self, view=None, fitresults=None):
+ """Add or update a table. Fill it with data from a fit
+
+ :param view: Tab name to be added or updated. If ``None``, use the
+ latest view.
+ :param fitresults: Fit data to be added to the table
+ :raise: KeyError if no view name specified and no latest view
+ available.
+ """
+ if view is None:
+ if self.latest_view is not None:
+ view = self.latest_view
+ else:
+ raise KeyError(
+ "No view available. You must specify a view" +
+ " name the first time you call this method."
+ )
+
+ if view in self.tables.keys():
+ table = self.tables[view]
+ else:
+ # create the parameters instance
+ self.tables[view] = Parameters(self)
+ table = self.tables[view]
+ self.views[view] = table
+ self.addTab(table, str(view))
+
+ if fitresults is not None:
+ table.fillFromFit(fitresults)
+
+ self.setCurrentWidget(self.views[view])
+ self.latest_view = view
+
+ def renameView(self, oldname=None, newname=None):
+ """Rename a view (tab)
+
+ :param oldname: Name of the view to be renamed
+ :param newname: New name of the view"""
+ error = 1
+ if newname is not None:
+ if newname not in self.views.keys():
+ if oldname in self.views.keys():
+ parameterlist = self.tables[oldname].getFitResults()
+ self.setView(view=newname, fitresults=parameterlist)
+ self.removeView(oldname)
+ error = 0
+ return error
+
+ def fillFromFit(self, fitparameterslist, view=None):
+ """Update a view with data from a fit (alias for :meth:`setView`)
+
+ :param view: Tab name to be added or updated (default: latest view)
+ :param fitparameterslist: Fit data to be added to the table
+ """
+ self.setView(view=view, fitresults=fitparameterslist)
+
+ def getFitResults(self, name=None):
+ """Call :meth:`getFitResults` for the
+ :class:`silx.gui.fit.parameters.Parameters` corresponding to the
+ latest table or to the named table (if ``name`` is not
+ ``None``). This return a list of dictionaries in the format used by
+ :class:`silx.math.fit.fitmanager.FitManager` to store fit parameter
+ results.
+
+ :param name: View name.
+ """
+ if name is None:
+ name = self.latest_view
+ return self.tables[name].getFitResults()
+
+ def removeView(self, name):
+ """Remove a view by name.
+
+ :param name: View name.
+ """
+ if name in self.views:
+ index = self.indexOf(self.tables[name])
+ self.removeTab(index)
+ index = self.indexOf(self.views[name])
+ self.removeTab(index)
+ del self.tables[name]
+ del self.views[name]
+
+ def removeAllViews(self, keep=None):
+ """Remove all views, except the one specified (argument
+ ``keep``)
+
+ :param keep: Name of the view to be kept."""
+ for view in self.tables:
+ if view != keep:
+ self.removeView(view)
+
+ def getHtmlText(self, name=None):
+ """Return the table data as HTML
+
+ :param name: View name."""
+ if name is None:
+ name = self.latest_view
+ table = self.tables[name]
+ lemon = ("#%x%x%x" % (255, 250, 205)).upper()
+ hcolor = ("#%x%x%x" % (230, 240, 249)).upper()
+ text = ""
+ text += "<nobr>"
+ text += "<table>"
+ text += "<tr>"
+ ncols = table.columnCount()
+ for l in range(ncols):
+ text += ('<td align="left" bgcolor="%s"><b>' % hcolor)
+ text += str(table.horizontalHeaderItem(l).text())
+ text += "</b></td>"
+ text += "</tr>"
+ nrows = table.rowCount()
+ for r in range(nrows):
+ text += "<tr>"
+ item = table.item(r, 0)
+ newtext = ""
+ if item is not None:
+ newtext = str(item.text())
+ if len(newtext):
+ color = "white"
+ b = "<b>"
+ else:
+ b = ""
+ color = lemon
+ try:
+ # MyQTable item has color defined
+ cc = table.item(r, 0).color
+ cc = ("#%x%x%x" % (cc.red(), cc.green(), cc.blue())).upper()
+ color = cc
+ except:
+ pass
+ for c in range(ncols):
+ item = table.item(r, c)
+ newtext = ""
+ if item is not None:
+ newtext = str(item.text())
+ if len(newtext):
+ finalcolor = color
+ else:
+ finalcolor = "white"
+ if c < 2:
+ text += ('<td align="left" bgcolor="%s">%s' %
+ (finalcolor, b))
+ else:
+ text += ('<td align="right" bgcolor="%s">%s' %
+ (finalcolor, b))
+ text += newtext
+ if len(b):
+ text += "</td>"
+ else:
+ text += "</b></td>"
+ item = table.item(r, 0)
+ newtext = ""
+ if item is not None:
+ newtext = str(item.text())
+ if len(newtext):
+ text += "</b>"
+ text += "</tr>"
+ text += "\n"
+ text += "</table>"
+ text += "</nobr>"
+ return text
+
+ def getText(self, name=None):
+ """Return the table data as CSV formatted text, using tabulation
+ characters as separators.
+
+ :param name: View name."""
+ if name is None:
+ name = self.latest_view
+ table = self.tables[name]
+ text = ""
+ ncols = table.columnCount()
+ for l in range(ncols):
+ text += (str(table.horizontalHeaderItem(l).text())) + "\t"
+ text += "\n"
+ nrows = table.rowCount()
+ for r in range(nrows):
+ for c in range(ncols):
+ newtext = ""
+ if c != 4:
+ item = table.item(r, c)
+ if item is not None:
+ newtext = str(item.text())
+ else:
+ item = table.cellWidget(r, c)
+ if item is not None:
+ newtext = str(item.currentText())
+ text += newtext + "\t"
+ text += "\n"
+ text += "\n"
+ return text
+
+
+def test():
+ from silx.math.fit import fittheories
+ from silx.math.fit import fitmanager
+ from silx.math.fit import functions
+ from silx.gui.plot.PlotWindow import PlotWindow
+ import numpy
+
+ a = qt.QApplication([])
+
+ x = numpy.arange(1000)
+ y1 = functions.sum_gauss(x, 100, 400, 100)
+
+ fit = fitmanager.FitManager(x=x, y=y1)
+
+ fitfuns = fittheories.FitTheories()
+ fit.addtheory(name="Gaussian",
+ function=functions.sum_gauss,
+ parameters=("height", "peak center", "fwhm"),
+ estimate=fitfuns.estimate_height_position_fwhm)
+ fit.settheory('Gaussian')
+ fit.configure(PositiveFwhmFlag=True,
+ PositiveHeightAreaFlag=True,
+ AutoFwhm=True,)
+
+ # Fit
+ fit.estimate()
+ fit.runfit()
+
+ w = ParametersTab()
+ w.show()
+ w.fillFromFit(fit.fit_results, view='Gaussians')
+
+ y2 = functions.sum_splitgauss(x,
+ 100, 400, 100, 40,
+ 10, 600, 50, 500,
+ 80, 850, 10, 50)
+ fit.setdata(x=x, y=y2)
+
+ # Define new theory
+ fit.addtheory(name="Asymetric gaussian",
+ function=functions.sum_splitgauss,
+ parameters=("height", "peak center", "left fwhm", "right fwhm"),
+ estimate=fitfuns.estimate_splitgauss)
+ fit.settheory('Asymetric gaussian')
+
+ # Fit
+ fit.estimate()
+ fit.runfit()
+
+ w.fillFromFit(fit.fit_results, view='Asymetric gaussians')
+
+ # Plot
+ pw = PlotWindow(control=True)
+ pw.addCurve(x, y1, "Gaussians")
+ pw.addCurve(x, y2, "Asymetric gaussians")
+ pw.show()
+
+ a.exec()
+
+
+if __name__ == "__main__":
+ test()
diff --git a/src/silx/gui/fit/Parameters.py b/src/silx/gui/fit/Parameters.py
new file mode 100644
index 0000000..daa72f3
--- /dev/null
+++ b/src/silx/gui/fit/Parameters.py
@@ -0,0 +1,882 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2004-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ######################################################################### */
+"""This module defines a table widget that is specialized in displaying fit
+parameter results and associated constraints."""
+__authors__ = ["V.A. Sole", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "25/11/2016"
+
+import sys
+from collections import OrderedDict
+
+from silx.gui import qt
+from silx.gui.widgets.TableWidget import TableWidget
+
+
+def float_else_zero(sstring):
+ """Return converted string to float. If conversion fail, return zero.
+
+ :param sstring: String to be converted
+ :return: ``float(sstrinq)`` if ``sstring`` can be converted to float
+ (e.g. ``"3.14"``), else ``0``
+ """
+ try:
+ return float(sstring)
+ except ValueError:
+ return 0
+
+
+class QComboTableItem(qt.QComboBox):
+ """:class:`qt.QComboBox` augmented with a ``sigCellChanged`` signal
+ to emit a tuple of ``(row, column)`` coordinates when the value is
+ changed.
+
+ This signal can be used to locate the modified combo box in a table.
+
+ :param row: Row number of the table cell containing this widget
+ :param col: Column number of the table cell containing this widget"""
+ sigCellChanged = qt.Signal(int, int)
+ """Signal emitted when this ``QComboBox`` is activated.
+ A ``(row, column)`` tuple is passed."""
+
+ def __init__(self, parent=None, row=None, col=None):
+ self._row = row
+ self._col = col
+ qt.QComboBox.__init__(self, parent)
+ self.activated[int].connect(self._cellChanged)
+
+ def _cellChanged(self, idx): # noqa
+ self.sigCellChanged.emit(self._row, self._col)
+
+
+class QCheckBoxItem(qt.QCheckBox):
+ """:class:`qt.QCheckBox` augmented with a ``sigCellChanged`` signal
+ to emit a tuple of ``(row, column)`` coordinates when the check box has
+ been clicked on.
+
+ This signal can be used to locate the modified check box in a table.
+
+ :param row: Row number of the table cell containing this widget
+ :param col: Column number of the table cell containing this widget"""
+ sigCellChanged = qt.Signal(int, int)
+ """Signal emitted when this ``QCheckBox`` is clicked.
+ A ``(row, column)`` tuple is passed."""
+
+ def __init__(self, parent=None, row=None, col=None):
+ self._row = row
+ self._col = col
+ qt.QCheckBox.__init__(self, parent)
+ self.clicked.connect(self._cellChanged)
+
+ def _cellChanged(self):
+ self.sigCellChanged.emit(self._row, self._col)
+
+
+class Parameters(TableWidget):
+ """:class:`TableWidget` customized to display fit results
+ and to interact with :class:`FitManager` objects.
+
+ Data and references to cell widgets are kept in a dictionary
+ attribute :attr:`parameters`.
+
+ :param parent: Parent widget
+ :param labels: Column headers. If ``None``, default headers will be used.
+ :type labels: List of strings or None
+ :param paramlist: List of fit parameters to be displayed for each fitted
+ peak.
+ :type paramlist: list[str] or None
+ """
+ def __init__(self, parent=None, paramlist=None):
+ TableWidget.__init__(self, parent)
+ self.setContentsMargins(0, 0, 0, 0)
+
+ labels = ['Parameter', 'Estimation', 'Fit Value', 'Sigma',
+ 'Constraints', 'Min/Parame', 'Max/Factor/Delta']
+ tooltips = ["Fit parameter name",
+ "Estimated value for fit parameter. You can edit this column.",
+ "Actual value for parameter, after fit",
+ "Uncertainty (same unit as the parameter)",
+ "Constraint to be applied to the parameter for fit",
+ "First parameter for constraint (name of another param or min value)",
+ "Second parameter for constraint (max value, or factor/delta)"]
+
+ self.columnKeys = ['name', 'estimation', 'fitresult',
+ 'sigma', 'code', 'val1', 'val2']
+ """This list assigns shorter keys to refer to columns than the
+ displayed labels."""
+
+ self.__configuring = False
+
+ # column headers and associated tooltips
+ self.setColumnCount(len(labels))
+
+ for i, label in enumerate(labels):
+ item = self.horizontalHeaderItem(i)
+ if item is None:
+ item = qt.QTableWidgetItem(label,
+ qt.QTableWidgetItem.Type)
+ self.setHorizontalHeaderItem(i, item)
+
+ item.setText(label)
+ if tooltips is not None:
+ item.setToolTip(tooltips[i])
+
+ # resize columns
+ for col_key in ["name", "estimation", "sigma", "val1", "val2"]:
+ col_idx = self.columnIndexByField(col_key)
+ self.resizeColumnToContents(col_idx)
+
+ # Initialize the table with one line per supplied parameter
+ paramlist = paramlist if paramlist is not None else []
+ self.parameters = OrderedDict()
+ """This attribute stores all the data in an ordered dictionary.
+ New data can be added using :meth:`newParameterLine`.
+ Existing data can be modified using :meth:`configureLine`
+
+ Keys of the dictionary are:
+
+ - 'name': parameter name
+ - 'line': line index for the parameter in the table
+ - 'estimation'
+ - 'fitresult'
+ - 'sigma'
+ - 'code': constraint code (one of the elements of
+ :attr:`code_options`)
+ - 'val1': first parameter related to constraint, formatted
+ as a string, as typed in the table
+ - 'val2': second parameter related to constraint, formatted
+ as a string, as typed in the table
+ - 'cons1': scalar representation of 'val1'
+ (e.g. when val1 is the name of a fit parameter, cons1
+ will be the line index of this parameter)
+ - 'cons2': scalar representation of 'val2'
+ - 'vmin': equal to 'val1' when 'code' is "QUOTED"
+ - 'vmax': equal to 'val2' when 'code' is "QUOTED"
+ - 'relatedto': name of related parameter when this parameter
+ is constrained to another parameter (same as 'val1')
+ - 'factor': same as 'val2' when 'code' is 'FACTOR'
+ - 'delta': same as 'val2' when 'code' is 'DELTA'
+ - 'sum': same as 'val2' when 'code' is 'SUM'
+ - 'group': group index for the parameter
+ - 'xmin': data range minimum
+ - 'xmax': data range maximum
+ """
+ for line, param in enumerate(paramlist):
+ self.newParameterLine(param, line)
+
+ self.code_options = ["FREE", "POSITIVE", "QUOTED", "FIXED",
+ "FACTOR", "DELTA", "SUM", "IGNORE", "ADD"]
+ """Possible values in the combo boxes in the 'Constraints' column.
+ """
+
+ # connect signal
+ self.cellChanged[int, int].connect(self.onCellChanged)
+
+ def newParameterLine(self, param, line):
+ """Add a line to the :class:`QTableWidget`.
+
+ Each line represents one of the fit parameters for one of
+ the fitted peaks.
+
+ :param param: Name of the fit parameter
+ :type param: str
+ :param line: 0-based line index
+ :type line: int
+ """
+ # get current number of lines
+ nlines = self.rowCount()
+ self.__configuring = True
+ if line >= nlines:
+ self.setRowCount(line + 1)
+
+ # default configuration for fit parameters
+ self.parameters[param] = OrderedDict((('line', line),
+ ('estimation', '0'),
+ ('fitresult', ''),
+ ('sigma', ''),
+ ('code', 'FREE'),
+ ('val1', ''),
+ ('val2', ''),
+ ('cons1', 0),
+ ('cons2', 0),
+ ('vmin', '0'),
+ ('vmax', '1'),
+ ('relatedto', ''),
+ ('factor', '1.0'),
+ ('delta', '0.0'),
+ ('sum', '0.0'),
+ ('group', ''),
+ ('name', param),
+ ('xmin', None),
+ ('xmax', None)))
+ self.setReadWrite(param, 'estimation')
+ self.setReadOnly(param, ['name', 'fitresult', 'sigma', 'val1', 'val2'])
+
+ # Constraint codes
+ a = []
+ for option in self.code_options:
+ a.append(option)
+
+ code_column_index = self.columnIndexByField('code')
+ cellWidget = self.cellWidget(line, code_column_index)
+ if cellWidget is None:
+ cellWidget = QComboTableItem(self, row=line,
+ col=code_column_index)
+ cellWidget.addItems(a)
+ self.setCellWidget(line, code_column_index, cellWidget)
+ cellWidget.sigCellChanged[int, int].connect(self.onCellChanged)
+ self.parameters[param]['code_item'] = cellWidget
+ self.parameters[param]['relatedto_item'] = None
+ self.__configuring = False
+
+ def columnIndexByField(self, field):
+ """
+
+ :param field: Field name (column key)
+ :return: Index of the column with this field name
+ """
+ return self.columnKeys.index(field)
+
+ def fillFromFit(self, fitresults):
+ """Fill table with values from a list of dictionaries
+ (see :attr:`silx.math.fit.fitmanager.FitManager.fit_results`)
+
+ :param fitresults: List of parameters as recorded
+ in the ``paramlist`` attribute of a :class:`FitManager` object
+ :type fitresults: list[dict]
+ """
+ self.setRowCount(len(fitresults))
+
+ # Reinitialize and fill self.parameters
+ self.parameters = OrderedDict()
+ for (line, param) in enumerate(fitresults):
+ self.newParameterLine(param['name'], line)
+
+ for param in fitresults:
+ name = param['name']
+ code = str(param['code'])
+ if code not in self.code_options:
+ # convert code from int to descriptive string
+ code = self.code_options[int(code)]
+ val1 = param['cons1']
+ val2 = param['cons2']
+ estimation = param['estimation']
+ group = param['group']
+ sigma = param['sigma']
+ fitresult = param['fitresult']
+
+ xmin = param.get('xmin')
+ xmax = param.get('xmax')
+
+ self.configureLine(name=name,
+ code=code,
+ val1=val1, val2=val2,
+ estimation=estimation,
+ fitresult=fitresult,
+ sigma=sigma,
+ group=group,
+ xmin=xmin, xmax=xmax)
+
+ def getConfiguration(self):
+ """Return ``FitManager.paramlist`` dictionary
+ encapsulated in another dictionary"""
+ return {'parameters': self.getFitResults()}
+
+ def setConfiguration(self, ddict):
+ """Fill table with values from a ``FitManager.paramlist`` dictionary
+ encapsulated in another dictionary"""
+ self.fillFromFit(ddict['parameters'])
+
+ def getFitResults(self):
+ """Return fit parameters as a list of dictionaries in the format used
+ by :class:`FitManager` (attribute ``paramlist``).
+ """
+ fitparameterslist = []
+ for param in self.parameters:
+ fitparam = {}
+ name = param
+ estimation, [code, cons1, cons2] = self.getEstimationConstraints(name)
+ buf = str(self.parameters[param]['fitresult'])
+ xmin = self.parameters[param]['xmin']
+ xmax = self.parameters[param]['xmax']
+ if len(buf):
+ fitresult = float(buf)
+ else:
+ fitresult = 0.0
+ buf = str(self.parameters[param]['sigma'])
+ if len(buf):
+ sigma = float(buf)
+ else:
+ sigma = 0.0
+ buf = str(self.parameters[param]['group'])
+ if len(buf):
+ group = float(buf)
+ else:
+ group = 0
+ fitparam['name'] = name
+ fitparam['estimation'] = estimation
+ fitparam['fitresult'] = fitresult
+ fitparam['sigma'] = sigma
+ fitparam['group'] = group
+ fitparam['code'] = code
+ fitparam['cons1'] = cons1
+ fitparam['cons2'] = cons2
+ fitparam['xmin'] = xmin
+ fitparam['xmax'] = xmax
+ fitparameterslist.append(fitparam)
+ return fitparameterslist
+
+ def onCellChanged(self, row, col):
+ """Slot called when ``cellChanged`` signal is emitted.
+ Checks the validity of the new text in the cell, then calls
+ :meth:`configureLine` to update the internal ``self.parameters``
+ dictionary.
+
+ :param row: Row number of the changed cell (0-based index)
+ :param col: Column number of the changed cell (0-based index)
+ """
+ if (col != self.columnIndexByField("code")) and (col != -1):
+ if row != self.currentRow():
+ return
+ if col != self.currentColumn():
+ return
+ if self.__configuring:
+ return
+ param = list(self.parameters)[row]
+ field = self.columnKeys[col]
+ oldvalue = self.parameters[param][field]
+ if col != 4:
+ item = self.item(row, col)
+ if item is not None:
+ newvalue = item.text()
+ else:
+ newvalue = ''
+ else:
+ # this is the combobox
+ widget = self.cellWidget(row, col)
+ newvalue = widget.currentText()
+ if self.validate(param, field, oldvalue, newvalue):
+ paramdict = {"name": param, field: newvalue}
+ self.configureLine(**paramdict)
+ else:
+ if field == 'code':
+ # New code not valid, try restoring the old one
+ index = self.code_options.index(oldvalue)
+ self.__configuring = True
+ try:
+ self.parameters[param]['code_item'].setCurrentIndex(index)
+ finally:
+ self.__configuring = False
+ else:
+ paramdict = {"name": param, field: oldvalue}
+ self.configureLine(**paramdict)
+
+ def validate(self, param, field, oldvalue, newvalue):
+ """Check validity of ``newvalue`` when a cell's value is modified.
+
+ :param param: Fit parameter name
+ :param field: Column name
+ :param oldvalue: Cell value before change attempt
+ :param newvalue: New value to be validated
+ :return: True if new cell value is valid, else False
+ """
+ if field == 'code':
+ return self.setCodeValue(param, oldvalue, newvalue)
+ # FIXME: validate() shouldn't have side effects. Move this bit to configureLine()?
+ if field == 'val1' and str(self.parameters[param]['code']) in ['DELTA', 'FACTOR', 'SUM']:
+ _, candidates = self.getRelatedCandidates(param)
+ # We expect val1 to be a fit parameter name
+ if str(newvalue) in candidates:
+ return True
+ else:
+ return False
+ # except for code, val1 and name (which is read-only and does not need
+ # validation), all fields must always be convertible to float
+ else:
+ try:
+ float(str(newvalue))
+ except ValueError:
+ return False
+ return True
+
+ def setCodeValue(self, param, oldvalue, newvalue):
+ """Update 'code' and 'relatedto' fields when code cell is
+ changed.
+
+ :param param: Fit parameter name
+ :param oldvalue: Cell value before change attempt
+ :param newvalue: New value to be validated
+ :return: ``True`` if code was successfully updated
+ """
+
+ if str(newvalue) in ['FREE', 'POSITIVE', 'QUOTED', 'FIXED']:
+ self.configureLine(name=param,
+ code=newvalue)
+ if str(oldvalue) == 'IGNORE':
+ self.freeRestOfGroup(param)
+ return True
+ elif str(newvalue) in ['FACTOR', 'DELTA', 'SUM']:
+ # I should check here that some parameter is set
+ best, candidates = self.getRelatedCandidates(param)
+ if len(candidates) == 0:
+ return False
+ self.configureLine(name=param,
+ code=newvalue,
+ relatedto=best)
+ if str(oldvalue) == 'IGNORE':
+ self.freeRestOfGroup(param)
+ return True
+
+ elif str(newvalue) == 'IGNORE':
+ # I should check if the group can be ignored
+ # for the time being I just fix all of them to ignore
+ group = int(float(str(self.parameters[param]['group'])))
+ candidates = []
+ for param in self.parameters.keys():
+ if group == int(float(str(self.parameters[param]['group']))):
+ candidates.append(param)
+ # print candidates
+ # I should check here if there is any relation to them
+ for param in candidates:
+ self.configureLine(name=param,
+ code=newvalue)
+ return True
+ elif str(newvalue) == 'ADD':
+ group = int(float(str(self.parameters[param]['group'])))
+ if group == 0:
+ # One cannot add a background group
+ return False
+ i = 0
+ for param in self.parameters:
+ if i <= int(float(str(self.parameters[param]['group']))):
+ i += 1
+ if (group == 0) and (i == 1): # FIXME: why +1?
+ i += 1
+ self.addGroup(i, group)
+ return False
+ elif str(newvalue) == 'SHOW':
+ print(self.getEstimationConstraints(param))
+ return False
+
+ def addGroup(self, newg, gtype):
+ """Add a fit parameter group with the same fit parameters as an
+ existing group.
+
+ This function is called when the user selects "ADD" in the
+ "constraints" combobox.
+
+ :param int newg: New group number
+ :param int gtype: Group number whose parameters we want to copy
+
+ """
+ newparam = []
+ # loop through parameters until we encounter group number `gtype`
+ for param in list(self.parameters):
+ paramgroup = int(float(str(self.parameters[param]['group'])))
+ # copy parameter names in group number `gtype`
+ if paramgroup == gtype:
+ # but replace `gtype` with `newg`
+ newparam.append(param.rstrip("0123456789") + "%d" % newg)
+
+ xmin = self.parameters[param]['xmin']
+ xmax = self.parameters[param]['xmax']
+
+ # Add new parameters (one table line per parameter) and configureLine each
+ # one by updating xmin and xmax to the same values as group `gtype`
+ line = len(list(self.parameters))
+ for param in newparam:
+ self.newParameterLine(param, line)
+ line += 1
+ for param in newparam:
+ self.configureLine(name=param, group=newg, xmin=xmin, xmax=xmax)
+
+ def freeRestOfGroup(self, workparam):
+ """Set ``code`` to ``"FREE"`` for all fit parameters belonging to
+ the same group as ``workparam``. This is done when the entire group
+ of parameters was previously ignored and one of them has his code
+ set to something different than ``"IGNORE"``.
+
+ :param workparam: Fit parameter name
+ """
+ if workparam in self.parameters.keys():
+ group = int(float(str(self.parameters[workparam]['group'])))
+ for param in self.parameters:
+ if param != workparam and\
+ group == int(float(str(self.parameters[param]['group']))):
+ self.configureLine(name=param,
+ code='FREE',
+ cons1=0,
+ cons2=0,
+ val1='',
+ val2='')
+
+ def getRelatedCandidates(self, workparam):
+ """If fit parameter ``workparam`` has a constraint that involves other
+ fit parameters, find possible candidates and try to guess which one
+ is the most likely.
+
+ :param workparam: Fit parameter name
+ :return: (best_candidate, possible_candidates) tuple
+ :rtype: (str, list[str])
+ """
+ candidates = []
+ for param_name in self.parameters:
+ if param_name != workparam:
+ # ignore parameters that are fixed by a constraint
+ if str(self.parameters[param_name]['code']) not in\
+ ['IGNORE', 'FACTOR', 'DELTA', 'SUM']:
+ candidates.append(param_name)
+ # take the previous one (before code cell changed) if possible
+ if str(self.parameters[workparam]['relatedto']) in candidates:
+ best = str(self.parameters[workparam]['relatedto'])
+ return best, candidates
+ # take the first with same base name (after removing numbers)
+ for param_name in candidates:
+ basename = param_name.rstrip("0123456789")
+ try:
+ pos = workparam.index(basename)
+ if pos == 0:
+ best = param_name
+ return best, candidates
+ except ValueError:
+ pass
+ # take the first
+ return candidates[0], candidates
+
+ def setReadOnly(self, parameter, fields):
+ """Make table cells read-only by setting it's flags and omitting
+ flag ``qt.Qt.ItemIsEditable``
+
+ :param parameter: Fit parameter names identifying the rows
+ :type parameter: str or list[str]
+ :param fields: Field names identifying the columns
+ :type fields: str or list[str]
+ """
+ editflags = qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled
+ self.setField(parameter, fields, editflags)
+
+ def setReadWrite(self, parameter, fields):
+ """Make table cells read-write by setting it's flags including
+ flag ``qt.Qt.ItemIsEditable``
+
+ :param parameter: Fit parameter names identifying the rows
+ :type parameter: str or list[str]
+ :param fields: Field names identifying the columns
+ :type fields: str or list[str]
+ """
+ editflags = qt.Qt.ItemIsSelectable |\
+ qt.Qt.ItemIsEnabled |\
+ qt.Qt.ItemIsEditable
+ self.setField(parameter, fields, editflags)
+
+ def setField(self, parameter, fields, edit_flags):
+ """Set text and flags in a table cell.
+
+ :param parameter: Fit parameter names identifying the rows
+ :type parameter: str or list[str]
+ :param fields: Field names identifying the columns
+ :type fields: str or list[str]
+ :param edit_flags: Flag combination, e.g::
+
+ qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled |
+ qt.Qt.ItemIsEditable
+ """
+ if isinstance(parameter, list) or \
+ isinstance(parameter, tuple):
+ paramlist = parameter
+ else:
+ paramlist = [parameter]
+ if isinstance(fields, list) or \
+ isinstance(fields, tuple):
+ fieldlist = fields
+ else:
+ fieldlist = [fields]
+
+ # Set _configuring flag to ignore cellChanged signals in
+ # self.onCellChanged
+ _oldvalue = self.__configuring
+ self.__configuring = True
+
+ # 2D loop through parameter list and field list
+ # to update their cells
+ for param in paramlist:
+ row = list(self.parameters.keys()).index(param)
+ for field in fieldlist:
+ col = self.columnIndexByField(field)
+ if field != 'code':
+ key = field + "_item"
+ item = self.item(row, col)
+ if item is None:
+ item = qt.QTableWidgetItem()
+ item.setText(self.parameters[param][field])
+ self.setItem(row, col, item)
+ else:
+ item.setText(self.parameters[param][field])
+ self.parameters[param][key] = item
+ item.setFlags(edit_flags)
+
+ # Restore previous _configuring flag
+ self.__configuring = _oldvalue
+
+ def configureLine(self, name, code=None, val1=None, val2=None,
+ sigma=None, estimation=None, fitresult=None,
+ group=None, xmin=None, xmax=None, relatedto=None,
+ cons1=None, cons2=None):
+ """This function updates values in a line of the table
+
+ :param name: Name of the parameter (serves as unique identifier for
+ a line).
+ :param code: Constraint code *FREE, FIXED, POSITIVE, DELTA, FACTOR,
+ SUM, QUOTED, IGNORE*
+ :param val1: Constraint 1 (can be the index or name of another
+ parameter for code *DELTA, FACTOR, SUM*, or a min value
+ for code *QUOTED*)
+ :param val2: Constraint 2
+ :param sigma: Standard deviation for a fit parameter
+ :param estimation: Estimated initial value for a fit parameter (used
+ as input to iterative fit)
+ :param fitresult: Final result of fit
+ :param group: Group number of a fit parameter (peak number when doing
+ multi-peak fitting, as each peak corresponds to a group
+ of several consecutive parameters)
+ :param xmin:
+ :param xmax:
+ :param relatedto: Index or name of another fit parameter
+ to which this parameter is related to (constraints)
+ :param cons1: similar meaning to ``val1``, but is always a number
+ :param cons2: similar meaning to ``val2``, but is always a number
+ :return:
+ """
+ paramlist = list(self.parameters.keys())
+
+ if name not in self.parameters:
+ raise KeyError("'%s' is not in the parameter list" % name)
+
+ # update code first, if specified
+ if code is not None:
+ code = str(code)
+ self.parameters[name]['code'] = code
+ # update combobox
+ index = self.parameters[name]['code_item'].findText(code)
+ self.parameters[name]['code_item'].setCurrentIndex(index)
+ else:
+ # set code to previous value, used later for setting val1 val2
+ code = self.parameters[name]['code']
+
+ # val1 and sigma have special formats
+ if val1 is not None:
+ fmt = None if self.parameters[name]['code'] in\
+ ['DELTA', 'FACTOR', 'SUM'] else "%8g"
+ self._updateField(name, "val1", val1, fmat=fmt)
+
+ if sigma is not None:
+ self._updateField(name, "sigma", sigma, fmat="%6.3g")
+
+ # other fields are formatted as "%8g"
+ keys_params = (("val2", val2), ("estimation", estimation),
+ ("fitresult", fitresult))
+ for key, value in keys_params:
+ if value is not None:
+ self._updateField(name, key, value, fmat="%8g")
+
+ # the rest of the parameters are treated as strings and don't need
+ # validation
+ keys_params = (("group", group), ("xmin", xmin),
+ ("xmax", xmax), ("relatedto", relatedto),
+ ("cons1", cons1), ("cons2", cons2))
+ for key, value in keys_params:
+ if value is not None:
+ self.parameters[name][key] = str(value)
+
+ # val1 and val2 have different meanings depending on the code
+ if code == 'QUOTED':
+ if val1 is not None:
+ self.parameters[name]['vmin'] = self.parameters[name]['val1']
+ else:
+ self.parameters[name]['val1'] = self.parameters[name]['vmin']
+ if val2 is not None:
+ self.parameters[name]['vmax'] = self.parameters[name]['val2']
+ else:
+ self.parameters[name]['val2'] = self.parameters[name]['vmax']
+
+ # cons1 and cons2 are scalar representations of val1 and val2
+ self.parameters[name]['cons1'] =\
+ float_else_zero(self.parameters[name]['val1'])
+ self.parameters[name]['cons2'] =\
+ float_else_zero(self.parameters[name]['val2'])
+
+ # cons1, cons2 = min(val1, val2), max(val1, val2)
+ if self.parameters[name]['cons1'] > self.parameters[name]['cons2']:
+ self.parameters[name]['cons1'], self.parameters[name]['cons2'] =\
+ self.parameters[name]['cons2'], self.parameters[name]['cons1']
+
+ elif code in ['DELTA', 'SUM', 'FACTOR']:
+ # For these codes, val1 is the fit parameter name on which the
+ # constraint depends
+ if val1 is not None and val1 in paramlist:
+ self.parameters[name]['relatedto'] = self.parameters[name]["val1"]
+
+ elif val1 is not None:
+ # val1 could be the index of the fit parameter
+ try:
+ self.parameters[name]['relatedto'] = paramlist[int(val1)]
+ except ValueError:
+ self.parameters[name]['relatedto'] = self.parameters[name]["val1"]
+
+ elif relatedto is not None:
+ # code changed, val1 not specified but relatedto specified:
+ # set val1 to relatedto (pre-fill best guess)
+ self.parameters[name]["val1"] = relatedto
+
+ # update fields "delta", "sum" or "factor"
+ key = code.lower()
+ self.parameters[name][key] = self.parameters[name]["val2"]
+
+ # FIXME: val1 is sometimes specified as an index rather than a param name
+ self.parameters[name]['val1'] = self.parameters[name]['relatedto']
+
+ # cons1 is the index of the fit parameter in the ordered dictionary
+ if self.parameters[name]['val1'] in paramlist:
+ self.parameters[name]['cons1'] =\
+ paramlist.index(self.parameters[name]['val1'])
+
+ # cons2 is the constraint value (factor, delta or sum)
+ try:
+ self.parameters[name]['cons2'] =\
+ float(str(self.parameters[name]['val2']))
+ except ValueError:
+ self.parameters[name]['cons2'] = 1.0 if code == "FACTOR" else 0.0
+
+ elif code in ['FREE', 'POSITIVE', 'IGNORE', 'FIXED']:
+ self.parameters[name]['val1'] = ""
+ self.parameters[name]['val2'] = ""
+ self.parameters[name]['cons1'] = 0
+ self.parameters[name]['cons2'] = 0
+
+ self._updateCellRWFlags(name, code)
+
+ def _updateField(self, name, field, value, fmat=None):
+ """Update field in ``self.parameters`` dictionary, if the new value
+ is valid.
+
+ :param name: Fit parameter name
+ :param field: Field name
+ :param value: New value to assign
+ :type value: String
+ :param fmat: Format string (e.g. "%8g") to be applied if value represents
+ a scalar. If ``None``, format is not modified. If ``value`` is an
+ empty string, ``fmat`` is ignored.
+ """
+ if value is not None:
+ oldvalue = self.parameters[name][field]
+ if fmat is not None:
+ newvalue = fmat % float(value) if value != "" else ""
+ else:
+ newvalue = value
+ self.parameters[name][field] = newvalue if\
+ self.validate(name, field, oldvalue, newvalue) else\
+ oldvalue
+
+ def _updateCellRWFlags(self, name, code=None):
+ """Set read-only or read-write flags in a row,
+ depending on the constraint code
+
+ :param name: Fit parameter name identifying the row
+ :param code: Constraint code, in `'FREE', 'POSITIVE', 'IGNORE',`
+ `'FIXED', 'FACTOR', 'DELTA', 'SUM', 'ADD'`
+ :return:
+ """
+ if code in ['FREE', 'POSITIVE', 'IGNORE', 'FIXED']:
+ self.setReadWrite(name, 'estimation')
+ self.setReadOnly(name, ['fitresult', 'sigma', 'val1', 'val2'])
+ else:
+ self.setReadWrite(name, ['estimation', 'val1', 'val2'])
+ self.setReadOnly(name, ['fitresult', 'sigma'])
+
+ def getEstimationConstraints(self, param):
+ """
+ Return tuple ``(estimation, constraints)`` where ``estimation`` is the
+ value in the ``estimate`` field and ``constraints`` are the relevant
+ constraints according to the active code
+ """
+ estimation = None
+ constraints = None
+ if param in self.parameters.keys():
+ buf = str(self.parameters[param]['estimation'])
+ if len(buf):
+ estimation = float(buf)
+ else:
+ estimation = 0
+ if str(self.parameters[param]['code']) in self.code_options:
+ code = self.code_options.index(
+ str(self.parameters[param]['code']))
+ else:
+ code = str(self.parameters[param]['code'])
+ cons1 = self.parameters[param]['cons1']
+ cons2 = self.parameters[param]['cons2']
+ constraints = [code, cons1, cons2]
+ return estimation, constraints
+
+
+def main(args):
+ from silx.math.fit import fittheories
+ from silx.math.fit import fitmanager
+ try:
+ from PyMca5 import PyMcaDataDir
+ except ImportError:
+ raise ImportError("This demo requires PyMca data. Install PyMca5.")
+ import numpy
+ import os
+ app = qt.QApplication(args)
+ tab = Parameters(paramlist=['Height', 'Position', 'FWHM'])
+ tab.showGrid()
+ tab.configureLine(name='Height', estimation='1234', group=0)
+ tab.configureLine(name='Position', code='FIXED', group=1)
+ tab.configureLine(name='FWHM', group=1)
+
+ y = numpy.loadtxt(os.path.join(PyMcaDataDir.PYMCA_DATA_DIR,
+ "XRFSpectrum.mca")) # FIXME
+
+ x = numpy.arange(len(y)) * 0.0502883 - 0.492773
+ fit = fitmanager.FitManager()
+ fit.setdata(x=x, y=y, xmin=20, xmax=150)
+
+ fit.loadtheories(fittheories)
+
+ fit.settheory('ahypermet')
+ fit.configure(Yscaling=1.,
+ PositiveFwhmFlag=True,
+ PositiveHeightAreaFlag=True,
+ FwhmPoints=16,
+ QuotedPositionFlag=1,
+ HypermetTails=1)
+ fit.setbackground('Linear')
+ fit.estimate()
+ fit.runfit()
+ tab.fillFromFit(fit.fit_results)
+ tab.show()
+ app.exec()
+
+if __name__ == "__main__":
+ main(sys.argv)
diff --git a/src/silx/gui/fit/__init__.py b/src/silx/gui/fit/__init__.py
new file mode 100644
index 0000000..e4fd3ab
--- /dev/null
+++ b/src/silx/gui/fit/__init__.py
@@ -0,0 +1,28 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "07/07/2016"
+
+from .FitWidget import FitWidget
diff --git a/src/silx/gui/fit/setup.py b/src/silx/gui/fit/setup.py
new file mode 100644
index 0000000..6672363
--- /dev/null
+++ b/src/silx/gui/fit/setup.py
@@ -0,0 +1,43 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "21/07/2016"
+
+
+from numpy.distutils.misc_util import Configuration
+
+
+def configuration(parent_package='', top_path=None):
+ config = Configuration('fit', parent_package, top_path)
+ config.add_subpackage('test')
+
+ return config
+
+
+if __name__ == "__main__":
+ from numpy.distutils.core import setup
+
+ setup(configuration=configuration)
diff --git a/src/silx/gui/fit/test/__init__.py b/src/silx/gui/fit/test/__init__.py
new file mode 100644
index 0000000..71128fb
--- /dev/null
+++ b/src/silx/gui/fit/test/__init__.py
@@ -0,0 +1,24 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
diff --git a/src/silx/gui/fit/test/testBackgroundWidget.py b/src/silx/gui/fit/test/testBackgroundWidget.py
new file mode 100644
index 0000000..b8570f7
--- /dev/null
+++ b/src/silx/gui/fit/test/testBackgroundWidget.py
@@ -0,0 +1,72 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+import unittest
+
+from silx.gui.utils.testutils import TestCaseQt
+
+from .. import BackgroundWidget
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "05/12/2016"
+
+
+class TestBackgroundWidget(TestCaseQt):
+ def setUp(self):
+ super(TestBackgroundWidget, self).setUp()
+ self.bgdialog = BackgroundWidget.BackgroundDialog()
+ self.bgdialog.setData(list([0, 1, 2, 3]),
+ list([0, 1, 4, 8]))
+ self.qWaitForWindowExposed(self.bgdialog)
+
+ def tearDown(self):
+ del self.bgdialog
+ super(TestBackgroundWidget, self).tearDown()
+
+ def testShow(self):
+ self.bgdialog.show()
+ self.bgdialog.hide()
+
+ def testAccept(self):
+ self.bgdialog.accept()
+ self.assertTrue(self.bgdialog.result())
+
+ def testReject(self):
+ self.bgdialog.reject()
+ self.assertFalse(self.bgdialog.result())
+
+ def testDefaultOutput(self):
+ self.bgdialog.accept()
+ output = self.bgdialog.output
+
+ for key in ["algorithm", "StripThreshold", "SnipWidth",
+ "StripIterations", "StripWidth", "SmoothingFlag",
+ "SmoothingWidth", "AnchorsFlag", "AnchorsList"]:
+ self.assertIn(key, output)
+
+ self.assertFalse(output["AnchorsFlag"])
+ self.assertEqual(output["StripWidth"], 1)
+ self.assertEqual(output["SmoothingFlag"], False)
+ self.assertEqual(output["SmoothingWidth"], 3)
diff --git a/src/silx/gui/fit/test/testFitConfig.py b/src/silx/gui/fit/test/testFitConfig.py
new file mode 100644
index 0000000..53da2dd
--- /dev/null
+++ b/src/silx/gui/fit/test/testFitConfig.py
@@ -0,0 +1,84 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for :class:`FitConfig`"""
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "05/12/2016"
+
+import unittest
+
+from silx.gui.utils.testutils import TestCaseQt
+from .. import FitConfig
+
+
+class TestFitConfig(TestCaseQt):
+ """Basic test for FitWidget"""
+
+ def setUp(self):
+ super(TestFitConfig, self).setUp()
+ self.fit_config = FitConfig.getFitConfigDialog(modal=False)
+ self.qWaitForWindowExposed(self.fit_config)
+
+ def tearDown(self):
+ del self.fit_config
+ super(TestFitConfig, self).tearDown()
+
+ def testShow(self):
+ self.fit_config.show()
+ self.fit_config.hide()
+
+ def testAccept(self):
+ self.fit_config.accept()
+ self.assertTrue(self.fit_config.result())
+
+ def testReject(self):
+ self.fit_config.reject()
+ self.assertFalse(self.fit_config.result())
+
+ def testDefaultOutput(self):
+ self.fit_config.accept()
+ output = self.fit_config.output
+
+ for key in ["AutoFwhm",
+ "PositiveHeightAreaFlag",
+ "QuotedPositionFlag",
+ "PositiveFwhmFlag",
+ "SameFwhmFlag",
+ "QuotedEtaFlag",
+ "NoConstraintsFlag",
+ "FwhmPoints",
+ "Sensitivity",
+ "Yscaling",
+ "ForcePeakPresence",
+ "StripBackgroundFlag",
+ "StripWidth",
+ "StripIterations",
+ "StripThreshold",
+ "SmoothingFlag"]:
+ self.assertIn(key, output)
+
+ self.assertTrue(output["AutoFwhm"])
+ self.assertEqual(output["StripWidth"], 2)
diff --git a/src/silx/gui/fit/test/testFitWidget.py b/src/silx/gui/fit/test/testFitWidget.py
new file mode 100644
index 0000000..abe9d89
--- /dev/null
+++ b/src/silx/gui/fit/test/testFitWidget.py
@@ -0,0 +1,124 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for :class:`FitWidget`"""
+
+import unittest
+
+from silx.gui.utils.testutils import TestCaseQt
+
+from ... import qt
+from .. import FitWidget
+
+from ....math.fit.fittheory import FitTheory
+from ....math.fit.fitmanager import FitManager
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "05/12/2016"
+
+
+class TestFitWidget(TestCaseQt):
+ """Basic test for FitWidget"""
+
+ def setUp(self):
+ super(TestFitWidget, self).setUp()
+ self.fit_widget = FitWidget()
+ self.fit_widget.show()
+ self.qWaitForWindowExposed(self.fit_widget)
+
+ def tearDown(self):
+ self.fit_widget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.fit_widget.close()
+ del self.fit_widget
+ super(TestFitWidget, self).tearDown()
+
+ def testShow(self):
+ pass
+
+ def testInteract(self):
+ self.mouseClick(self.fit_widget, qt.Qt.LeftButton)
+ self.keyClick(self.fit_widget, qt.Qt.Key_Enter)
+ self.qapp.processEvents()
+
+ def testCustomConfigWidget(self):
+ class CustomConfigWidget(qt.QDialog):
+ def __init__(self):
+ qt.QDialog.__init__(self)
+ self.setModal(True)
+ self.ok = qt.QPushButton("ok", self)
+ self.ok.clicked.connect(self.accept)
+ cancel = qt.QPushButton("cancel", self)
+ cancel.clicked.connect(self.reject)
+ layout = qt.QVBoxLayout(self)
+ layout.addWidget(self.ok)
+ layout.addWidget(cancel)
+ self.output = {"hello": "world"}
+
+ def fitfun(x, a, b):
+ return a * x + b
+
+ x = list(range(0, 100))
+ y = [fitfun(x_, 2, 3) for x_ in x]
+
+ def conf(**kw):
+ return {"spam": "eggs",
+ "hello": "world!"}
+
+ theory = FitTheory(
+ function=fitfun,
+ parameters=["a", "b"],
+ configure=conf)
+
+ fitmngr = FitManager()
+ fitmngr.setdata(x, y)
+ fitmngr.addtheory("foo", theory)
+ fitmngr.addtheory("bar", theory)
+ fitmngr.addbgtheory("spam", theory)
+
+ fw = FitWidget(fitmngr=fitmngr)
+ fw.associateConfigDialog("spam", CustomConfigWidget(),
+ theory_is_background=True)
+ fw.associateConfigDialog("foo", CustomConfigWidget())
+ fw.show()
+ self.qWaitForWindowExposed(fw)
+
+ fw.bgconfigdialogs["spam"].accept()
+ self.assertTrue(fw.bgconfigdialogs["spam"].result())
+
+ self.assertEqual(fw.bgconfigdialogs["spam"].output,
+ {"hello": "world"})
+
+ fw.bgconfigdialogs["spam"].reject()
+ self.assertFalse(fw.bgconfigdialogs["spam"].result())
+
+ fw.configdialogs["foo"].accept()
+ self.assertTrue(fw.configdialogs["foo"].result())
+
+ # todo: figure out how to click fw.configdialog.ok to close dialog
+ # open dialog
+ # self.mouseClick(fw.guiConfig.FunConfigureButton, qt.Qt.LeftButton)
+ # clove dialog
+ # self.mouseClick(fw.configdialogs["foo"].ok, qt.Qt.LeftButton)
+ # self.qapp.processEvents()
diff --git a/src/silx/gui/hdf5/Hdf5Formatter.py b/src/silx/gui/hdf5/Hdf5Formatter.py
new file mode 100644
index 0000000..6c3de41
--- /dev/null
+++ b/src/silx/gui/hdf5/Hdf5Formatter.py
@@ -0,0 +1,240 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This package provides a class sharred by widgets to format HDF5 data as
+text."""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "06/06/2018"
+
+import numpy
+
+from silx.gui import qt
+from silx.gui.data.TextFormatter import TextFormatter
+
+import h5py
+
+
+class Hdf5Formatter(qt.QObject):
+ """Formatter to convert HDF5 data to string.
+ """
+
+ formatChanged = qt.Signal()
+ """Emitted when properties of the formatter change."""
+
+ def __init__(self, parent=None, textFormatter=None):
+ """
+ Constructor
+
+ :param qt.QObject parent: Owner of the object
+ :param TextFormatter formatter: Text formatter
+ """
+ qt.QObject.__init__(self, parent)
+ if textFormatter is not None:
+ self.__formatter = textFormatter
+ else:
+ self.__formatter = TextFormatter(self)
+ self.__formatter.formatChanged.connect(self.__formatChanged)
+
+ def textFormatter(self):
+ """Returns the used text formatter
+
+ :rtype: TextFormatter
+ """
+ return self.__formatter
+
+ def setTextFormatter(self, textFormatter):
+ """Set the text formatter to be used
+
+ :param TextFormatter textFormatter: The text formatter to use
+ """
+ if textFormatter is None:
+ raise ValueError("Formatter expected but None found")
+ if self.__formatter is textFormatter:
+ return
+ self.__formatter.formatChanged.disconnect(self.__formatChanged)
+ self.__formatter = textFormatter
+ self.__formatter.formatChanged.connect(self.__formatChanged)
+ self.__formatChanged()
+
+ def __formatChanged(self):
+ self.formatChanged.emit()
+
+ def humanReadableShape(self, dataset):
+ if dataset.shape is None:
+ return "none"
+ if dataset.shape == tuple():
+ return "scalar"
+ shape = [str(i) for i in dataset.shape]
+ text = u" \u00D7 ".join(shape)
+ return text
+
+ def humanReadableValue(self, dataset):
+ if dataset.shape is None:
+ return "No data"
+
+ dtype = dataset.dtype
+ if dataset.dtype.type == numpy.void:
+ if dtype.fields is None:
+ return "Raw data"
+
+ if dataset.shape == tuple():
+ numpy_object = dataset[()]
+ text = self.__formatter.toString(numpy_object, dtype=dataset.dtype)
+ else:
+ if dataset.size < 5 and dataset.compression is None:
+ numpy_object = dataset[0:5]
+ text = self.__formatter.toString(numpy_object, dtype=dataset.dtype)
+ else:
+ dimension = len(dataset.shape)
+ if dataset.compression is not None:
+ text = "Compressed %dD data" % dimension
+ else:
+ text = "%dD data" % dimension
+ return text
+
+ def humanReadableType(self, dataset, full=False):
+ if hasattr(dataset, "dtype"):
+ dtype = dataset.dtype
+ else:
+ # Fallback...
+ dtype = type(dataset)
+ return self.humanReadableDType(dtype, full)
+
+ def humanReadableDType(self, dtype, full=False):
+ if dtype == bytes or numpy.issubdtype(dtype, numpy.string_):
+ text = "string"
+ if full:
+ text = "ASCII " + text
+ return text
+ elif dtype == str or numpy.issubdtype(dtype, numpy.unicode_):
+ text = "string"
+ if full:
+ text = "UTF-8 " + text
+ return text
+ elif dtype.type == numpy.object_:
+ ref = h5py.check_dtype(ref=dtype)
+ if ref is not None:
+ return "reference"
+ vlen = h5py.check_dtype(vlen=dtype)
+ if vlen is not None:
+ text = self.humanReadableDType(vlen, full=full)
+ if full:
+ text = "variable-length " + text
+ return text
+ return "object"
+ elif dtype.type == numpy.bool_:
+ return "bool"
+ elif dtype.type == numpy.void:
+ if dtype.fields is None:
+ return "opaque"
+ else:
+ if not full:
+ return "compound"
+ else:
+ fields = sorted(dtype.fields.items(), key=lambda e: e[1][1])
+ compound = [d[1][0] for d in fields]
+ compound = [self.humanReadableDType(d) for d in compound]
+ return "compound(%s)" % ", ".join(compound)
+ elif numpy.issubdtype(dtype, numpy.integer):
+ enumType = h5py.check_dtype(enum=dtype)
+ if enumType is not None:
+ return "enum"
+
+ text = str(dtype.newbyteorder('N'))
+ if numpy.issubdtype(dtype, numpy.floating):
+ if hasattr(numpy, "float128") and dtype == numpy.float128:
+ text = "float80"
+ if full:
+ text += " (padding 128bits)"
+ elif hasattr(numpy, "float96") and dtype == numpy.float96:
+ text = "float80"
+ if full:
+ text += " (padding 96bits)"
+
+ if full:
+ if dtype.byteorder == "<":
+ text = "Little-endian " + text
+ elif dtype.byteorder == ">":
+ text = "Big-endian " + text
+ elif dtype.byteorder == "=":
+ text = "Native " + text
+
+ dtype = dtype.newbyteorder('N')
+ return text
+
+ def humanReadableHdf5Type(self, dataset):
+ """Format the internal HDF5 type as a string"""
+ t = dataset.id.get_type()
+ class_ = t.get_class()
+ if class_ == h5py.h5t.NO_CLASS:
+ return "NO_CLASS"
+ elif class_ == h5py.h5t.INTEGER:
+ return "INTEGER"
+ elif class_ == h5py.h5t.FLOAT:
+ return "FLOAT"
+ elif class_ == h5py.h5t.TIME:
+ return "TIME"
+ elif class_ == h5py.h5t.STRING:
+ charset = t.get_cset()
+ strpad = t.get_strpad()
+ text = ""
+
+ if strpad == h5py.h5t.STR_NULLTERM:
+ text += "NULLTERM"
+ elif strpad == h5py.h5t.STR_NULLPAD:
+ text += "NULLPAD"
+ elif strpad == h5py.h5t.STR_SPACEPAD:
+ text += "SPACEPAD"
+ else:
+ text += "UNKNOWN_STRPAD"
+
+ if t.is_variable_str():
+ text += " VARIABLE"
+
+ if charset == h5py.h5t.CSET_ASCII:
+ text += " ASCII"
+ elif charset == h5py.h5t.CSET_UTF8:
+ text += " UTF8"
+ else:
+ text += " UNKNOWN_CSET"
+
+ return text + " STRING"
+ elif class_ == h5py.h5t.BITFIELD:
+ return "BITFIELD"
+ elif class_ == h5py.h5t.OPAQUE:
+ return "OPAQUE"
+ elif class_ == h5py.h5t.COMPOUND:
+ return "COMPOUND"
+ elif class_ == h5py.h5t.REFERENCE:
+ return "REFERENCE"
+ elif class_ == h5py.h5t.ENUM:
+ return "ENUM"
+ elif class_ == h5py.h5t.VLEN:
+ return "VLEN"
+ elif class_ == h5py.h5t.ARRAY:
+ return "ARRAY"
+ else:
+ return "UNKNOWN_CLASS"
diff --git a/src/silx/gui/hdf5/Hdf5HeaderView.py b/src/silx/gui/hdf5/Hdf5HeaderView.py
new file mode 100644
index 0000000..7255ce0
--- /dev/null
+++ b/src/silx/gui/hdf5/Hdf5HeaderView.py
@@ -0,0 +1,184 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "16/06/2017"
+
+
+from .. import qt
+from .Hdf5TreeModel import Hdf5TreeModel
+
+
+class Hdf5HeaderView(qt.QHeaderView):
+ """
+ Default HDF5 header
+
+ Manage auto-resize and context menu to display/hide columns
+ """
+
+ def __init__(self, orientation, parent=None):
+ """
+ Constructor
+
+ :param orientation qt.Qt.Orientation: Orientation of the header
+ :param parent qt.QWidget: Parent of the widget
+ """
+ super(Hdf5HeaderView, self).__init__(orientation, parent)
+ self.setContextMenuPolicy(qt.Qt.CustomContextMenu)
+ self.customContextMenuRequested.connect(self.__createContextMenu)
+
+ # default initialization done by QTreeView for it's own header
+ self.setSectionsClickable(True)
+ self.setSectionsMovable(True)
+ self.setDefaultAlignment(qt.Qt.AlignLeft | qt.Qt.AlignVCenter)
+ self.setStretchLastSection(True)
+
+ self.__auto_resize = True
+ self.__hide_columns_popup = True
+
+ def setModel(self, model):
+ """Override model to configure view when a model is expected
+
+ `qt.QHeaderView.setSectionResizeMode` expect already existing columns
+ to work.
+
+ :param model qt.QAbstractItemModel: A model
+ """
+ super(Hdf5HeaderView, self).setModel(model)
+ self.__updateAutoResize()
+
+ def __updateAutoResize(self):
+ """Update the view according to the state of the auto-resize"""
+ if self.__auto_resize:
+ self.setSectionResizeMode(Hdf5TreeModel.NAME_COLUMN, qt.QHeaderView.ResizeToContents)
+ self.setSectionResizeMode(Hdf5TreeModel.TYPE_COLUMN, qt.QHeaderView.ResizeToContents)
+ self.setSectionResizeMode(Hdf5TreeModel.SHAPE_COLUMN, qt.QHeaderView.ResizeToContents)
+ self.setSectionResizeMode(Hdf5TreeModel.VALUE_COLUMN, qt.QHeaderView.Interactive)
+ self.setSectionResizeMode(Hdf5TreeModel.DESCRIPTION_COLUMN, qt.QHeaderView.Interactive)
+ self.setSectionResizeMode(Hdf5TreeModel.NODE_COLUMN, qt.QHeaderView.ResizeToContents)
+ self.setSectionResizeMode(Hdf5TreeModel.LINK_COLUMN, qt.QHeaderView.ResizeToContents)
+ else:
+ self.setSectionResizeMode(Hdf5TreeModel.NAME_COLUMN, qt.QHeaderView.Interactive)
+ self.setSectionResizeMode(Hdf5TreeModel.TYPE_COLUMN, qt.QHeaderView.Interactive)
+ self.setSectionResizeMode(Hdf5TreeModel.SHAPE_COLUMN, qt.QHeaderView.Interactive)
+ self.setSectionResizeMode(Hdf5TreeModel.VALUE_COLUMN, qt.QHeaderView.Interactive)
+ self.setSectionResizeMode(Hdf5TreeModel.DESCRIPTION_COLUMN, qt.QHeaderView.Interactive)
+ self.setSectionResizeMode(Hdf5TreeModel.NODE_COLUMN, qt.QHeaderView.Interactive)
+ self.setSectionResizeMode(Hdf5TreeModel.LINK_COLUMN, qt.QHeaderView.Interactive)
+
+ def setAutoResizeColumns(self, autoResize):
+ """Enable/disable auto-resize. When auto-resized, the header take care
+ of the content of the column to set fixed size of some of them, or to
+ auto fix the size according to the content.
+
+ :param autoResize bool: Enable/disable auto-resize
+ """
+ if self.__auto_resize == autoResize:
+ return
+ self.__auto_resize = autoResize
+ self.__updateAutoResize()
+
+ def hasAutoResizeColumns(self):
+ """Is auto-resize enabled.
+
+ :rtype: bool
+ """
+ return self.__auto_resize
+
+ autoResizeColumns = qt.Property(bool, hasAutoResizeColumns, setAutoResizeColumns)
+ """Property to enable/disable auto-resize."""
+
+ def setEnableHideColumnsPopup(self, enablePopup):
+ """Enable/disable a popup to allow to hide/show each column of the
+ model.
+
+ :param bool enablePopup: Enable/disable popup to hide/show columns
+ """
+ self.__hide_columns_popup = enablePopup
+
+ def hasHideColumnsPopup(self):
+ """Is popup to hide/show columns is enabled.
+
+ :rtype: bool
+ """
+ return self.__hide_columns_popup
+
+ enableHideColumnsPopup = qt.Property(bool, hasHideColumnsPopup, setAutoResizeColumns)
+ """Property to enable/disable popup allowing to hide/show columns."""
+
+ def __genHideSectionEvent(self, column):
+ """Generate a callback which change the column visibility according to
+ the event parameter
+
+ :param int column: logical id of the column
+ :rtype: callable
+ """
+ return lambda checked: self.setSectionHidden(column, not checked)
+
+ def __createContextMenu(self, pos):
+ """Callback to create and display a context menu
+
+ :param pos qt.QPoint: Requested position for the context menu
+ """
+ if not self.__hide_columns_popup:
+ return
+
+ model = self.model()
+ if model.columnCount() > 1:
+ menu = qt.QMenu(self)
+ menu.setTitle("Display/hide columns")
+
+ action = qt.QAction("Display/hide column", self)
+ action.setEnabled(False)
+ menu.addAction(action)
+
+ for column in range(model.columnCount()):
+ if column == 0:
+ # skip the main column
+ continue
+ text = model.headerData(column, qt.Qt.Horizontal, qt.Qt.DisplayRole)
+ action = qt.QAction("%s displayed" % text, self)
+ action.setCheckable(True)
+ action.setChecked(not self.isSectionHidden(column))
+ action.toggled.connect(self.__genHideSectionEvent(column))
+ menu.addAction(action)
+
+ menu.popup(self.viewport().mapToGlobal(pos))
+
+ def setSections(self, logicalIndexes):
+ """
+ Defines order of visible sections by logical indexes.
+
+ Use `Hdf5TreeModel.NAME_COLUMN` to set the list.
+
+ :param list logicalIndexes: List of logical indexes to display
+ """
+ for pos, column_id in enumerate(logicalIndexes):
+ current_pos = self.visualIndex(column_id)
+ self.moveSection(current_pos, pos)
+ self.setSectionHidden(column_id, False)
+ for column_id in set(range(self.model().columnCount())) - set(logicalIndexes):
+ self.setSectionHidden(column_id, True)
diff --git a/src/silx/gui/hdf5/Hdf5Item.py b/src/silx/gui/hdf5/Hdf5Item.py
new file mode 100755
index 0000000..e07f835
--- /dev/null
+++ b/src/silx/gui/hdf5/Hdf5Item.py
@@ -0,0 +1,642 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "17/01/2019"
+
+
+import logging
+import collections
+import enum
+
+from .. import qt
+from .. import icons
+from . import _utils
+from .Hdf5Node import Hdf5Node
+import silx.io.utils
+from silx.gui.data.TextFormatter import TextFormatter
+from ..hdf5.Hdf5Formatter import Hdf5Formatter
+_logger = logging.getLogger(__name__)
+_formatter = TextFormatter()
+_hdf5Formatter = Hdf5Formatter(textFormatter=_formatter)
+# FIXME: The formatter should be an attribute of the Hdf5Model
+
+
+class DescriptionType(enum.Enum):
+ """List of available kind of description.
+ """
+ ERROR = "error"
+ DESCRIPTION = "description"
+ TITLE = "title"
+ PROGRAM = "program"
+ NAME = "name"
+ VALUE = "value"
+
+
+class Hdf5Item(Hdf5Node):
+ """Subclass of :class:`qt.QStandardItem` to represent an HDF5-like
+ item (dataset, file, group or link) as an element of a HDF5-like
+ tree structure.
+ """
+
+ def __init__(self, text, obj, parent, key=None, h5Class=None, linkClass=None, populateAll=False):
+ """
+ :param str text: text displayed
+ :param object obj: Pointer to a h5py-link object. See the `obj` attribute.
+ """
+ self.__obj = obj
+ self.__key = key
+ self.__h5Class = h5Class
+ self.__isBroken = obj is None and h5Class is None
+ self.__error = None
+ self.__text = text
+ self.__linkClass = linkClass
+ self.__description = None
+ self.__nx_class = None
+ Hdf5Node.__init__(self, parent, populateAll=populateAll)
+
+ def _getCanonicalName(self):
+ parent = self.parent
+ if parent is None:
+ return self.__text
+ else:
+ return "%s/%s" % (parent._getCanonicalName(), self.__text)
+
+ @property
+ def obj(self):
+ if self.__key:
+ self.__initH5Object()
+ return self.__obj
+
+ @property
+ def basename(self):
+ return self.__text
+
+ @property
+ def h5Class(self):
+ """Returns the class of the stored object.
+
+ When the object is in lazy loading, this method should be able to
+ return the type of the future loaded object. It allows to delay the
+ real load of the object.
+
+ :rtype: silx.io.utils.H5Type
+ """
+ if self.__h5Class is None and self.obj is not None:
+ self.__h5Class = silx.io.utils.get_h5_class(self.obj)
+ return self.__h5Class
+
+ @property
+ def h5pyClass(self):
+ """Returns the class of the stored object.
+
+ When the object is in lazy loading, this method should be able to
+ return the type of the future loaded object. It allows to delay the
+ real load of the object.
+
+ :rtype: h5py.File or h5py.Dataset or h5py.Group
+ """
+ type_ = self.h5Class
+ return silx.io.utils.h5type_to_h5py_class(type_)
+
+ @property
+ def linkClass(self):
+ """Returns the link class object of this node
+
+ :rtype: H5Type
+ """
+ return self.__linkClass
+
+ def isGroupObj(self):
+ """Returns true if the stored HDF5 object is a group (contains sub
+ groups or datasets).
+
+ :rtype: bool
+ """
+ if self.h5Class is None:
+ return False
+ return self.h5Class in [silx.io.utils.H5Type.GROUP, silx.io.utils.H5Type.FILE]
+
+ def isBrokenObj(self):
+ """Returns true if the stored HDF5 object is broken.
+
+ The stored object is then an h5py-like link (external or not) which
+ point to nowhere (tbhe external file is not here, the expected
+ dataset is still not on the file...)
+
+ :rtype: bool
+ """
+ return self.__isBroken
+
+ def _getFormatter(self):
+ """
+ Returns an Hdf5Formatter
+
+ :rtype: Hdf5Formatter
+ """
+ return _hdf5Formatter
+
+ def _expectedChildCount(self):
+ if self.isGroupObj():
+ return len(self.obj)
+ return 0
+
+ def __initH5Object(self):
+ """Lazy load of the HDF5 node. It is reached from the parent node
+ with the key of the node."""
+ parent_obj = self.parent.obj
+
+ try:
+ obj = parent_obj.get(self.__key)
+ except Exception as e:
+ _logger.error("Internal error while reaching HDF5 object: %s", str(e))
+ _logger.debug("Backtrace", exc_info=True)
+ try:
+ self.__obj = parent_obj.get(self.__key, getlink=True)
+ except Exception:
+ self.__obj = None
+ self.__error = e.args[0]
+ self.__isBroken = True
+ else:
+ if obj is None:
+ # that's a broken link
+ self.__obj = parent_obj.get(self.__key, getlink=True)
+
+ # TODO monkey-patch file (ask that in h5py for consistency)
+ if not hasattr(self.__obj, "name"):
+ parent_name = parent_obj.name
+ if parent_name == "/":
+ self.__obj.name = "/" + self.__key
+ else:
+ self.__obj.name = parent_name + "/" + self.__key
+ # TODO monkey-patch file (ask that in h5py for consistency)
+ if not hasattr(self.__obj, "file"):
+ self.__obj.file = parent_obj.file
+
+ class_ = silx.io.utils.get_h5_class(self.__obj)
+
+ if class_ == silx.io.utils.H5Type.EXTERNAL_LINK:
+ message = "External link broken. Path %s::%s does not exist" % (self.__obj.filename, self.__obj.path)
+ elif class_ == silx.io.utils.H5Type.SOFT_LINK:
+ message = "Soft link broken. Path %s does not exist" % (self.__obj.path)
+ else:
+ name = self.__obj.__class__.__name__.split(".")[-1].capitalize()
+ message = "%s broken" % (name)
+ self.__error = message
+ self.__isBroken = True
+ else:
+ self.__obj = obj
+ if not self.isGroupObj():
+ try:
+ # pre-fetch of the data
+ if obj.shape is None:
+ pass
+ elif obj.shape == tuple():
+ obj[()]
+ else:
+ if obj.compression is None and obj.size > 0:
+ key = tuple([0] * len(obj.shape))
+ obj[key]
+ except Exception as e:
+ _logger.debug(e, exc_info=True)
+ message = "%s broken. %s" % (self.__obj.name, e.args[0])
+ self.__error = message
+ self.__isBroken = True
+
+ self.__key = None
+
+ def _populateChild(self, populateAll=False):
+ if self.isGroupObj():
+ keys = []
+ try:
+ for name in self.obj:
+ keys.append(name)
+ except Exception:
+ lib_name = self.obj.__class__.__module__.split(".")[0]
+ _logger.error("Internal %s error. The file is corrupted.", lib_name)
+ _logger.debug("Backtrace", exc_info=True)
+ if keys == []:
+ # If the file was open in READ_ONLY we still can reach something
+ # https://github.com/silx-kit/silx/issues/2262
+ try:
+ for name in self.obj:
+ keys.append(name)
+ except Exception:
+ lib_name = self.obj.__class__.__module__.split(".")[0]
+ _logger.error("Internal %s error (second time). The file is corrupted.", lib_name)
+ _logger.debug("Backtrace", exc_info=True)
+ for name in keys:
+ try:
+ class_ = self.obj.get(name, getclass=True)
+ link = self.obj.get(name, getclass=True, getlink=True)
+ link = silx.io.utils.get_h5_class(class_=link)
+ except Exception:
+ lib_name = self.obj.__class__.__module__.split(".")[0]
+ _logger.error("Internal %s error", lib_name)
+ _logger.debug("Backtrace", exc_info=True)
+ class_ = None
+ try:
+ link = self.obj.get(name, getclass=True, getlink=True)
+ link = silx.io.utils.get_h5_class(class_=link)
+ except Exception:
+ _logger.debug("Backtrace", exc_info=True)
+ link = silx.io.utils.H5Type.HARD_LINK
+
+ h5class = None
+ if class_ is not None:
+ h5class = silx.io.utils.get_h5_class(class_=class_)
+ if h5class is None:
+ _logger.error("Class %s unsupported", class_)
+ item = Hdf5Item(text=name, obj=None, parent=self, key=name, h5Class=h5class, linkClass=link)
+ self.appendChild(item)
+
+ def hasChildren(self):
+ """Retuens true of this node have chrild.
+
+ :rtype: bool
+ """
+ if not self.isGroupObj():
+ return False
+ return Hdf5Node.hasChildren(self)
+
+ def _getDefaultIcon(self):
+ """Returns the icon displayed by the main column.
+
+ :rtype: qt.QIcon
+ """
+ # Pre-fetch the object, in case it is broken
+ obj = self.obj
+ style = qt.QApplication.style()
+ if self.__isBroken:
+ icon = style.standardIcon(qt.QStyle.SP_MessageBoxCritical)
+ return icon
+ class_ = self.h5Class
+ if class_ == silx.io.utils.H5Type.FILE:
+ return style.standardIcon(qt.QStyle.SP_FileIcon)
+ elif class_ == silx.io.utils.H5Type.GROUP:
+ return style.standardIcon(qt.QStyle.SP_DirIcon)
+ elif class_ == silx.io.utils.H5Type.SOFT_LINK:
+ return style.standardIcon(qt.QStyle.SP_DirLinkIcon)
+ elif class_ == silx.io.utils.H5Type.EXTERNAL_LINK:
+ return style.standardIcon(qt.QStyle.SP_FileLinkIcon)
+ elif class_ == silx.io.utils.H5Type.DATASET:
+ if obj.shape is None:
+ name = "item-none"
+ elif len(obj.shape) < 4:
+ name = "item-%ddim" % len(obj.shape)
+ else:
+ name = "item-ndim"
+ icon = icons.getQIcon(name)
+ return icon
+ return None
+
+ def _createTooltipAttributes(self):
+ """
+ Add key/value attributes that will be displayed in the item tooltip
+
+ :param Dict[str,str] attributeDict: Key/value attributes
+ """
+ attributeDict = collections.OrderedDict()
+
+ if self.h5Class == silx.io.utils.H5Type.DATASET:
+ attributeDict["#Title"] = "HDF5 Dataset"
+ attributeDict["Name"] = self.basename
+ attributeDict["Path"] = self.obj.name
+ attributeDict["Shape"] = self._getFormatter().humanReadableShape(self.obj)
+ attributeDict["Value"] = self._getFormatter().humanReadableValue(self.obj)
+ attributeDict["Data type"] = self._getFormatter().humanReadableType(self.obj, full=True)
+ elif self.h5Class == silx.io.utils.H5Type.GROUP:
+ attributeDict["#Title"] = "HDF5 Group"
+ if self.nexusClassName:
+ attributeDict["NX_class"] = self.nexusClassName
+ attributeDict["Name"] = self.basename
+ attributeDict["Path"] = self.obj.name
+ elif self.h5Class == silx.io.utils.H5Type.FILE:
+ attributeDict["#Title"] = "HDF5 File"
+ attributeDict["Name"] = self.basename
+ attributeDict["Path"] = "/"
+ elif self.h5Class == silx.io.utils.H5Type.EXTERNAL_LINK:
+ attributeDict["#Title"] = "HDF5 External Link"
+ attributeDict["Name"] = self.basename
+ attributeDict["Path"] = self.obj.name
+ attributeDict["Linked path"] = self.obj.path
+ attributeDict["Linked file"] = self.obj.filename
+ elif self.h5Class == silx.io.utils.H5Type.SOFT_LINK:
+ attributeDict["#Title"] = "HDF5 Soft Link"
+ attributeDict["Name"] = self.basename
+ attributeDict["Path"] = self.obj.name
+ attributeDict["Linked path"] = self.obj.path
+ else:
+ pass
+ return attributeDict
+
+ def _getDefaultTooltip(self):
+ """Returns the default tooltip
+
+ :rtype: str
+ """
+ if self.__error is not None:
+ self.obj # lazy loading of the object
+ return self.__error
+
+ attrs = self._createTooltipAttributes()
+ title = attrs.pop("#Title", None)
+ if len(attrs) > 0:
+ tooltip = _utils.htmlFromDict(attrs, title=title)
+ else:
+ tooltip = ""
+
+ return tooltip
+
+ @property
+ def nexusClassName(self):
+ """Returns the Nexus class name"""
+ if self.__nx_class is None:
+ obj = self.obj.attrs.get("NX_class", None)
+ if obj is None:
+ text = ""
+ else:
+ text = self._getFormatter().textFormatter().toString(obj)
+ text = text.strip('"')
+ # Check NX_class formatting
+ lower = text.lower()
+ formatedNX_class = ""
+ if lower.startswith('nx'):
+ formatedNX_class = 'NX' + lower[2:]
+ if lower == 'nxcansas':
+ formatedNX_class = 'NXcanSAS' # That's the only class with capital letters...
+ if text != formatedNX_class:
+ _logger.error("NX_class: '%s' is malformed (should be '%s')",
+ text,
+ formatedNX_class)
+ text = formatedNX_class
+
+ self.__nx_class = text
+ return self.__nx_class
+
+ def dataName(self, role):
+ """Data for the name column"""
+ if role == qt.Qt.TextAlignmentRole:
+ return qt.Qt.AlignTop | qt.Qt.AlignLeft
+ if role == qt.Qt.DisplayRole:
+ return self.__text
+ if role == qt.Qt.DecorationRole:
+ return self._getDefaultIcon()
+ if role == qt.Qt.ToolTipRole:
+ return self._getDefaultTooltip()
+ return None
+
+ def dataType(self, role):
+ """Data for the type column"""
+ if role == qt.Qt.DecorationRole:
+ return None
+ if role == qt.Qt.TextAlignmentRole:
+ return qt.Qt.AlignTop | qt.Qt.AlignLeft
+ if role == qt.Qt.DisplayRole:
+ if self.__error is not None:
+ return ""
+ class_ = self.h5Class
+ if self.isGroupObj():
+ text = self.nexusClassName
+ elif class_ == silx.io.utils.H5Type.DATASET:
+ text = self._getFormatter().humanReadableType(self.obj)
+ else:
+ text = ""
+ return text
+ return None
+
+ def dataShape(self, role):
+ """Data for the shape column"""
+ if role == qt.Qt.DecorationRole:
+ return None
+ if role == qt.Qt.TextAlignmentRole:
+ return qt.Qt.AlignTop | qt.Qt.AlignLeft
+ if role == qt.Qt.DisplayRole:
+ if self.__error is not None:
+ return ""
+ class_ = self.h5Class
+ if class_ != silx.io.utils.H5Type.DATASET:
+ return ""
+ return self._getFormatter().humanReadableShape(self.obj)
+ return None
+
+ def dataValue(self, role):
+ """Data for the value column"""
+ if role == qt.Qt.DecorationRole:
+ return None
+ if role == qt.Qt.TextAlignmentRole:
+ return qt.Qt.AlignTop | qt.Qt.AlignLeft
+ if role == qt.Qt.DisplayRole:
+ if self.__error is not None:
+ return ""
+ if self.h5Class != silx.io.utils.H5Type.DATASET:
+ return ""
+ return self._getFormatter().humanReadableValue(self.obj)
+ return None
+
+ _NEXUS_CLASS_TO_VALUE_CHILDREN = {
+ 'NXaperture': (
+ (DescriptionType.DESCRIPTION, 'description'),
+ ),
+ 'NXbeam_stop': (
+ (DescriptionType.DESCRIPTION, 'description'),
+ ),
+ 'NXdetector': (
+ (DescriptionType.NAME, 'local_name'),
+ (DescriptionType.DESCRIPTION, 'description')
+ ),
+ 'NXentry': (
+ (DescriptionType.TITLE, 'title'),
+ ),
+ 'NXenvironment': (
+ (DescriptionType.NAME, 'short_name'),
+ (DescriptionType.NAME, 'name'),
+ (DescriptionType.DESCRIPTION, 'description')
+ ),
+ 'NXinstrument': (
+ (DescriptionType.NAME, 'name'),
+ ),
+ 'NXlog': (
+ (DescriptionType.DESCRIPTION, 'description'),
+ ),
+ 'NXmirror': (
+ (DescriptionType.DESCRIPTION, 'description'),
+ ),
+ 'NXpositioner': (
+ (DescriptionType.NAME, 'name'),
+ ),
+ 'NXprocess': (
+ (DescriptionType.PROGRAM, 'program'),
+ ),
+ 'NXsample': (
+ (DescriptionType.TITLE, 'short_title'),
+ (DescriptionType.NAME, 'name'),
+ (DescriptionType.DESCRIPTION, 'description')
+ ),
+ 'NXsample_component': (
+ (DescriptionType.NAME, 'name'),
+ (DescriptionType.DESCRIPTION, 'description')
+ ),
+ 'NXsensor': (
+ (DescriptionType.NAME, 'short_name'),
+ (DescriptionType.NAME, 'name')
+ ),
+ 'NXsource': (
+ (DescriptionType.NAME, 'name'),
+ ), # or its 'short_name' attribute... This is not supported
+ 'NXsubentry': (
+ (DescriptionType.DESCRIPTION, 'definition'),
+ (DescriptionType.PROGRAM, 'program_name'),
+ (DescriptionType.TITLE, 'title'),
+ ),
+ }
+ """Mapping from NeXus class to child names containing data to use as value"""
+
+ def __computeDataDescription(self):
+ """Compute the data description of this item
+
+ :rtype: Tuple[kind, str]
+ """
+ if self.__isBroken or self.__error is not None:
+ self.obj # lazy loading of the object
+ return DescriptionType.ERROR, self.__error
+
+ if self.h5Class == silx.io.utils.H5Type.DATASET:
+ return DescriptionType.VALUE, self._getFormatter().humanReadableValue(self.obj)
+
+ elif self.isGroupObj() and self.nexusClassName:
+ # For NeXus groups, try to find a title or name
+ # By default, look for a title (most application definitions should have one)
+ defaultSequence = ((DescriptionType.TITLE, 'title'),)
+ sequence = self._NEXUS_CLASS_TO_VALUE_CHILDREN.get(self.nexusClassName, defaultSequence)
+ for kind, child_name in sequence:
+ for index in range(self.childCount()):
+ child = self.child(index)
+ if (isinstance(child, Hdf5Item) and
+ child.h5Class == silx.io.utils.H5Type.DATASET and
+ child.basename == child_name):
+ return kind, self._getFormatter().humanReadableValue(child.obj)
+
+ description = self.obj.attrs.get("desc", None)
+ if description is not None:
+ return DescriptionType.DESCRIPTION, description
+ else:
+ return None, None
+
+ def __getDataDescription(self):
+ """Returns a cached version of the data description
+
+ As the data description have to reach inside the HDF5 tree, the result
+ is cached. A better implementation could be to use a MRU cache, to avoid
+ to allocate too much data.
+
+ :rtype: Tuple[kind, str]
+ """
+ if self.__description is None:
+ self.__description = self.__computeDataDescription()
+ return self.__description
+
+ def dataDescription(self, role):
+ """Data for the description column"""
+ if role == qt.Qt.DecorationRole:
+ kind, _label = self.__getDataDescription()
+ if kind is not None:
+ icon = icons.getQIcon("description-%s" % kind.value)
+ return icon
+ return None
+ if role == qt.Qt.TextAlignmentRole:
+ return qt.Qt.AlignTop | qt.Qt.AlignLeft
+ if role == qt.Qt.DisplayRole:
+ _kind, label = self.__getDataDescription()
+ return label
+ if role == qt.Qt.ToolTipRole:
+ if self.__error is not None:
+ self.obj # lazy loading of the object
+ self.__initH5Object()
+ return self.__error
+ kind, label = self.__getDataDescription()
+ if label is not None:
+ return "<b>%s</b><br/>%s" % (kind.value.capitalize(), label)
+ else:
+ return ""
+ return None
+
+ def dataNode(self, role):
+ """Data for the node column"""
+ if role == qt.Qt.DecorationRole:
+ return None
+ if role == qt.Qt.TextAlignmentRole:
+ return qt.Qt.AlignTop | qt.Qt.AlignLeft
+ if role == qt.Qt.DisplayRole:
+ if self.isBrokenObj():
+ return ""
+ class_ = self.obj.__class__
+ text = class_.__name__.split(".")[-1]
+ return text
+ if role == qt.Qt.ToolTipRole:
+ class_ = self.obj.__class__
+ if class_ is None:
+ return ""
+ return "Class name: %s" % self.__class__
+ return None
+
+ def dataLink(self, role):
+ """Data for the link column
+
+ Overwrite it to implement the content of the 'link' column.
+
+ :rtype: qt.QVariant
+ """
+ if role == qt.Qt.DecorationRole:
+ return None
+ if role == qt.Qt.TextAlignmentRole:
+ return qt.Qt.AlignTop | qt.Qt.AlignLeft
+ if role == qt.Qt.DisplayRole:
+ # Mark as link
+ link = self.linkClass
+ if link is None:
+ pass
+ elif link == silx.io.utils.H5Type.HARD_LINK:
+ pass
+ elif link == silx.io.utils.H5Type.EXTERNAL_LINK:
+ return "External"
+ elif link == silx.io.utils.H5Type.SOFT_LINK:
+ return "Soft"
+ else:
+ return link.__name__
+ # Mark as external data
+ if self.h5Class == silx.io.utils.H5Type.DATASET:
+ obj = self.obj
+ if hasattr(obj, "is_virtual"):
+ if obj.is_virtual:
+ return "Virtual"
+ if hasattr(obj, "external"):
+ if obj.external:
+ return "ExtRaw"
+ return ""
+ if role == qt.Qt.ToolTipRole:
+ return None
+ return None
diff --git a/src/silx/gui/hdf5/Hdf5LoadingItem.py b/src/silx/gui/hdf5/Hdf5LoadingItem.py
new file mode 100644
index 0000000..f11d252
--- /dev/null
+++ b/src/silx/gui/hdf5/Hdf5LoadingItem.py
@@ -0,0 +1,77 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "06/07/2018"
+
+
+from .. import qt
+from .Hdf5Node import Hdf5Node
+import silx.io.utils
+
+
+class Hdf5LoadingItem(Hdf5Node):
+ """Item displayed when an Hdf5Node is loading.
+
+ At the end of the loading this item is replaced by the loaded one.
+ """
+
+ def __init__(self, text, parent, animatedIcon):
+ """Constructor"""
+ Hdf5Node.__init__(self, parent)
+ self.__text = text
+ self.__animatedIcon = animatedIcon
+ self.__animatedIcon.register(self)
+
+ @property
+ def obj(self):
+ return None
+
+ @property
+ def h5Class(self):
+ """Returns the class of the stored object.
+
+ :rtype: silx.io.utils.H5Type
+ """
+ return silx.io.utils.H5Type.FILE
+
+ def dataName(self, role):
+ if role == qt.Qt.DecorationRole:
+ return self.__animatedIcon.currentIcon()
+ if role == qt.Qt.TextAlignmentRole:
+ return qt.Qt.AlignTop | qt.Qt.AlignLeft
+ if role == qt.Qt.DisplayRole:
+ return self.__text
+ return None
+
+ def dataDescription(self, role):
+ if role == qt.Qt.DecorationRole:
+ return None
+ if role == qt.Qt.TextAlignmentRole:
+ return qt.Qt.AlignTop | qt.Qt.AlignLeft
+ if role == qt.Qt.DisplayRole:
+ return "Loading..."
+ return None
diff --git a/src/silx/gui/hdf5/Hdf5Node.py b/src/silx/gui/hdf5/Hdf5Node.py
new file mode 100644
index 0000000..be16535
--- /dev/null
+++ b/src/silx/gui/hdf5/Hdf5Node.py
@@ -0,0 +1,238 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "24/07/2018"
+
+import weakref
+
+
+class Hdf5Node(object):
+ """Abstract tree node
+
+ It provides link to the childs and to the parents, and a link to an
+ external object.
+ """
+ def __init__(self, parent=None, populateAll=False):
+ """
+ Constructor
+
+ :param Hdf5Node parent: Parent of the node, if exists, else None
+ :param bool populateAll: If true, populate all the tree node. Else
+ everything is lazy loaded.
+ """
+ self.__child = None
+ self.__parent = None
+ if parent is not None:
+ self.__parent = weakref.ref(parent)
+ if populateAll:
+ self.__child = []
+ self._populateChild(populateAll=True)
+
+ def _getCanonicalName(self):
+ parent = self.parent
+ if parent is None:
+ return "root"
+ else:
+ return "%s/?" % (parent._getCanonicalName())
+
+ @property
+ def parent(self):
+ """Parent of the node, or None if the node is a root
+
+ :rtype: Hdf5Node
+ """
+ if self.__parent is None:
+ return None
+ parent = self.__parent()
+ if parent is None:
+ self.__parent = parent
+ return parent
+
+ def setParent(self, parent):
+ """Redefine the parent of the node.
+
+ It does not set the node as the children of the new parent.
+
+ :param Hdf5Node parent: The new parent
+ """
+ if parent is None:
+ self.__parent = None
+ else:
+ self.__parent = weakref.ref(parent)
+
+ def appendChild(self, child):
+ """Append a child to the node.
+
+ It does not update the parent of the child.
+
+ :param Hdf5Node child: Child to append to the node.
+ """
+ self.__initChild()
+ self.__child.append(child)
+
+ def removeChildAtIndex(self, index):
+ """Remove a child at an index of the children list.
+
+ The child is removed and returned.
+
+ :param int index: Index in the child list.
+ :rtype: Hdf5Node
+ :raises: IndexError if list is empty or index is out of range.
+ """
+ self.__initChild()
+ return self.__child.pop(index)
+
+ def insertChild(self, index, child):
+ """
+ Insert a child at a specific index of the child list.
+
+ It does not update the parent of the child.
+
+ :param int index: Index in the child list.
+ :param Hdf5Node child: Child to insert in the child list.
+ """
+ self.__initChild()
+ self.__child.insert(index, child)
+
+ def indexOfChild(self, child):
+ """
+ Returns the index of the child in the child list of this node.
+
+ :param Hdf5Node child: Child to find
+ :raises: ValueError if the value is not present.
+ """
+ self.__initChild()
+ return self.__child.index(child)
+
+ def hasChildren(self):
+ """Returns true if the node contains children.
+
+ :rtype: bool
+ """
+ return self.childCount() > 0
+
+ def childCount(self):
+ """Returns the number of child in this node.
+
+ :rtype: int
+ """
+ if self.__child is not None:
+ return len(self.__child)
+ return self._expectedChildCount()
+
+ def child(self, index):
+ """Return the child at an expected index.
+
+ :param int index: Index of the child in the child list of the node
+ :rtype: Hdf5Node
+ """
+ self.__initChild()
+ return self.__child[index]
+
+ def __initChild(self):
+ """Init the child of the node in case the list was lazy loaded."""
+ if self.__child is None:
+ self.__child = []
+ self._populateChild()
+
+ def _expectedChildCount(self):
+ """Returns the expected count of children
+
+ :rtype: int
+ """
+ return 0
+
+ def _populateChild(self, populateAll=False):
+ """Recurse through an HDF5 structure to append groups an datasets
+ into the tree model.
+
+ Overwrite it to implement the initialisation of child of the node.
+ """
+ pass
+
+ def dataName(self, role):
+ """Data for the name column
+
+ Overwrite it to implement the content of the 'name' column.
+
+ :rtype: qt.QVariant
+ """
+ return None
+
+ def dataType(self, role):
+ """Data for the type column
+
+ Overwrite it to implement the content of the 'type' column.
+
+ :rtype: qt.QVariant
+ """
+ return None
+
+ def dataShape(self, role):
+ """Data for the shape column
+
+ Overwrite it to implement the content of the 'shape' column.
+
+ :rtype: qt.QVariant
+ """
+ return None
+
+ def dataValue(self, role):
+ """Data for the value column
+
+ Overwrite it to implement the content of the 'value' column.
+
+ :rtype: qt.QVariant
+ """
+ return None
+
+ def dataDescription(self, role):
+ """Data for the description column
+
+ Overwrite it to implement the content of the 'description' column.
+
+ :rtype: qt.QVariant
+ """
+ return None
+
+ def dataNode(self, role):
+ """Data for the node column
+
+ Overwrite it to implement the content of the 'node' column.
+
+ :rtype: qt.QVariant
+ """
+ return None
+
+ def dataLink(self, role):
+ """Data for the link column
+
+ Overwrite it to implement the content of the 'link' column.
+
+ :rtype: qt.QVariant
+ """
+ return None
diff --git a/src/silx/gui/hdf5/Hdf5TreeModel.py b/src/silx/gui/hdf5/Hdf5TreeModel.py
new file mode 100644
index 0000000..a32f7cf
--- /dev/null
+++ b/src/silx/gui/hdf5/Hdf5TreeModel.py
@@ -0,0 +1,742 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "12/03/2019"
+
+
+import os
+import logging
+import functools
+from .. import qt
+from .. import icons
+from .Hdf5Node import Hdf5Node
+from .Hdf5Item import Hdf5Item
+from .Hdf5LoadingItem import Hdf5LoadingItem
+from . import _utils
+from ... import io as silx_io
+
+_logger = logging.getLogger(__name__)
+
+
+def _createRootLabel(h5obj):
+ """
+ Create label for the very first npde of the tree.
+
+ :param h5obj: The h5py object to display in the GUI
+ :type h5obj: h5py-like object
+ :rtpye: str
+ """
+ if silx_io.is_file(h5obj):
+ label = os.path.basename(h5obj.filename)
+ else:
+ filename = os.path.basename(h5obj.file.filename)
+ path = h5obj.name
+ if path.startswith("/"):
+ path = path[1:]
+ label = "%s::%s" % (filename, path)
+ return label
+
+
+class LoadingItemRunnable(qt.QRunnable):
+ """Runner to process item loading from a file"""
+
+ class __Signals(qt.QObject):
+ """Signal holder"""
+ itemReady = qt.Signal(object, object, object)
+ runnerFinished = qt.Signal(object)
+
+ def __init__(self, filename, item):
+ """Constructor
+
+ :param LoadingItemWorker worker: Object holding data and signals
+ """
+ super(LoadingItemRunnable, self).__init__()
+ self.filename = filename
+ self.oldItem = item
+ self.signals = self.__Signals()
+
+ def setFile(self, filename, item):
+ self.filenames.append((filename, item))
+
+ @property
+ def itemReady(self):
+ return self.signals.itemReady
+
+ @property
+ def runnerFinished(self):
+ return self.signals.runnerFinished
+
+ def __loadItemTree(self, oldItem, h5obj):
+ """Create an item tree used by the GUI from an h5py object.
+
+ :param Hdf5Node oldItem: The current item displayed the GUI
+ :param h5py.File h5obj: The h5py object to display in the GUI
+ :rtpye: Hdf5Node
+ """
+ text = _createRootLabel(h5obj)
+ item = Hdf5Item(text=text, obj=h5obj, parent=oldItem.parent, populateAll=True)
+ return item
+
+ def run(self):
+ """Process the file loading. The worker is used as holder
+ of the data and the signal. The result is sent as a signal.
+ """
+ h5file = None
+ try:
+ h5file = silx_io.open(self.filename)
+ newItem = self.__loadItemTree(self.oldItem, h5file)
+ error = None
+ except IOError as e:
+ # Should be logged
+ error = e
+ newItem = None
+ if h5file is not None:
+ h5file.close()
+
+ self.itemReady.emit(self.oldItem, newItem, error)
+ self.runnerFinished.emit(self)
+
+ def autoDelete(self):
+ return True
+
+
+class Hdf5TreeModel(qt.QAbstractItemModel):
+ """Tree model storing a list of :class:`h5py.File` like objects.
+
+ The main column display the :class:`h5py.File` list and there hierarchy.
+ Other columns display information on node hierarchy.
+ """
+
+ H5PY_ITEM_ROLE = qt.Qt.UserRole
+ """Role to reach h5py item from an item index"""
+
+ H5PY_OBJECT_ROLE = qt.Qt.UserRole + 1
+ """Role to reach h5py object from an item index"""
+
+ USER_ROLE = qt.Qt.UserRole + 2
+ """Start of range of available user role for derivative models"""
+
+ NAME_COLUMN = 0
+ """Column id containing HDF5 node names"""
+
+ TYPE_COLUMN = 1
+ """Column id containing HDF5 dataset types"""
+
+ SHAPE_COLUMN = 2
+ """Column id containing HDF5 dataset shapes"""
+
+ VALUE_COLUMN = 3
+ """Column id containing HDF5 dataset values"""
+
+ DESCRIPTION_COLUMN = 4
+ """Column id containing HDF5 node description/title/message"""
+
+ NODE_COLUMN = 5
+ """Column id containing HDF5 node type"""
+
+ LINK_COLUMN = 6
+ """Column id containing HDF5 link type"""
+
+ COLUMN_IDS = [
+ NAME_COLUMN,
+ TYPE_COLUMN,
+ SHAPE_COLUMN,
+ VALUE_COLUMN,
+ DESCRIPTION_COLUMN,
+ NODE_COLUMN,
+ LINK_COLUMN,
+ ]
+ """List of logical columns available"""
+
+ sigH5pyObjectLoaded = qt.Signal(object)
+ """Emitted when a new root item was loaded and inserted to the model."""
+
+ sigH5pyObjectRemoved = qt.Signal(object)
+ """Emitted when a root item is removed from the model."""
+
+ sigH5pyObjectSynchronized = qt.Signal(object, object)
+ """Emitted when an item was synchronized."""
+
+ def __init__(self, parent=None, ownFiles=True):
+ """
+ Constructor
+
+ :param qt.QWidget parent: Parent widget
+ :param bool ownFiles: If true (default) the model will manage the files
+ life cycle when they was added using path (like DnD).
+ """
+ super(Hdf5TreeModel, self).__init__(parent)
+
+ self.header_labels = [None] * len(self.COLUMN_IDS)
+ self.header_labels[self.NAME_COLUMN] = 'Name'
+ self.header_labels[self.TYPE_COLUMN] = 'Type'
+ self.header_labels[self.SHAPE_COLUMN] = 'Shape'
+ self.header_labels[self.VALUE_COLUMN] = 'Value'
+ self.header_labels[self.DESCRIPTION_COLUMN] = 'Description'
+ self.header_labels[self.NODE_COLUMN] = 'Node'
+ self.header_labels[self.LINK_COLUMN] = 'Link'
+
+ # Create items
+ self.__root = Hdf5Node()
+ self.__fileDropEnabled = True
+ self.__fileMoveEnabled = True
+ self.__datasetDragEnabled = False
+
+ self.__animatedIcon = icons.getWaitIcon()
+ self.__animatedIcon.iconChanged.connect(self.__updateLoadingItems)
+ self.__runnerSet = set([])
+
+ # store used icons to avoid the cache to release it
+ self.__icons = []
+ self.__icons.append(icons.getQIcon("item-none"))
+ self.__icons.append(icons.getQIcon("item-0dim"))
+ self.__icons.append(icons.getQIcon("item-1dim"))
+ self.__icons.append(icons.getQIcon("item-2dim"))
+ self.__icons.append(icons.getQIcon("item-3dim"))
+ self.__icons.append(icons.getQIcon("item-ndim"))
+
+ self.__ownFiles = ownFiles
+ self.__openedFiles = []
+ """Store the list of files opened by the model itself."""
+ # FIXME: It should be managed one by one by Hdf5Item itself
+
+ # It is not possible to override the QObject destructor nor
+ # to access to the content of the Python object with the `destroyed`
+ # signal cause the Python method was already removed with the QWidget,
+ # while the QObject still exists.
+ # We use a static method plus explicit references to objects to
+ # release. The callback do not use any ref to self.
+ onDestroy = functools.partial(self._closeFileList, self.__openedFiles)
+ self.destroyed.connect(onDestroy)
+
+ @staticmethod
+ def _closeFileList(fileList):
+ """Static method to close explicit references to internal objects."""
+ _logger.debug("Clear Hdf5TreeModel")
+ for obj in fileList:
+ _logger.debug("Close file %s", obj.filename)
+ obj.close()
+ fileList[:] = []
+
+ def _closeOpened(self):
+ """Close files which was opened by this model.
+
+ File are opened by the model when it was inserted using
+ `insertFileAsync`, `insertFile`, `appendFile`."""
+ self._closeFileList(self.__openedFiles)
+
+ def __updateLoadingItems(self, icon):
+ for i in range(self.__root.childCount()):
+ item = self.__root.child(i)
+ if isinstance(item, Hdf5LoadingItem):
+ index1 = self.index(i, 0, qt.QModelIndex())
+ index2 = self.index(i, self.columnCount() - 1, qt.QModelIndex())
+ self.dataChanged.emit(index1, index2)
+
+ def __itemReady(self, oldItem, newItem, error):
+ """Called at the end of a concurent file loading, when the loading
+ item is ready. AN error is defined if an exception occured when
+ loading the newItem .
+
+ :param Hdf5Node oldItem: current displayed item
+ :param Hdf5Node newItem: item loaded, or None if error is defined
+ :param Exception error: An exception, or None if newItem is defined
+ """
+ row = self.__root.indexOfChild(oldItem)
+
+ rootIndex = qt.QModelIndex()
+ self.beginRemoveRows(rootIndex, row, row)
+ self.__root.removeChildAtIndex(row)
+ self.endRemoveRows()
+
+ if newItem is not None:
+ rootIndex = qt.QModelIndex()
+ if self.__ownFiles:
+ self.__openedFiles.append(newItem.obj)
+ self.beginInsertRows(rootIndex, row, row)
+ self.__root.insertChild(row, newItem)
+ self.endInsertRows()
+
+ if isinstance(oldItem, Hdf5LoadingItem):
+ self.sigH5pyObjectLoaded.emit(newItem.obj)
+ else:
+ self.sigH5pyObjectSynchronized.emit(oldItem.obj, newItem.obj)
+
+ # FIXME the error must be displayed
+
+ def isFileDropEnabled(self):
+ return self.__fileDropEnabled
+
+ def setFileDropEnabled(self, enabled):
+ self.__fileDropEnabled = enabled
+
+ fileDropEnabled = qt.Property(bool, isFileDropEnabled, setFileDropEnabled)
+ """Property to enable/disable file dropping in the model."""
+
+ def isDatasetDragEnabled(self):
+ return self.__datasetDragEnabled
+
+ def setDatasetDragEnabled(self, enabled):
+ self.__datasetDragEnabled = enabled
+
+ datasetDragEnabled = qt.Property(bool, isDatasetDragEnabled, setDatasetDragEnabled)
+ """Property to enable/disable drag of datasets."""
+
+ def isFileMoveEnabled(self):
+ return self.__fileMoveEnabled
+
+ def setFileMoveEnabled(self, enabled):
+ self.__fileMoveEnabled = enabled
+
+ fileMoveEnabled = qt.Property(bool, isFileMoveEnabled, setFileMoveEnabled)
+ """Property to enable/disable drag-and-drop of files to
+ change the ordering in the model."""
+
+ def supportedDropActions(self):
+ if self.__fileMoveEnabled or self.__fileDropEnabled:
+ return qt.Qt.CopyAction | qt.Qt.MoveAction
+ else:
+ return 0
+
+ def mimeTypes(self):
+ types = []
+ if self.__fileMoveEnabled or self.__datasetDragEnabled:
+ types.append(_utils.Hdf5DatasetMimeData.MIME_TYPE)
+ return types
+
+ def mimeData(self, indexes):
+ """
+ Returns an object that contains serialized items of data corresponding
+ to the list of indexes specified.
+
+ :param List[qt.QModelIndex] indexes: List of indexes
+ :rtype: qt.QMimeData
+ """
+ if len(indexes) == 0:
+ return None
+
+ indexes = [i for i in indexes if i.column() == 0]
+ if len(indexes) > 1:
+ raise NotImplementedError("Drag of multi rows is not implemented")
+ if len(indexes) == 0:
+ raise NotImplementedError("Drag of cell is not implemented")
+
+ node = self.nodeFromIndex(indexes[0])
+
+ if self.__fileMoveEnabled and node.parent is self.__root:
+ mimeData = _utils.Hdf5DatasetMimeData(node=node, isRoot=True)
+ elif self.__datasetDragEnabled:
+ mimeData = _utils.Hdf5DatasetMimeData(node=node)
+ else:
+ mimeData = None
+ return mimeData
+
+ def flags(self, index):
+ defaultFlags = qt.QAbstractItemModel.flags(self, index)
+
+ if index.isValid():
+ node = self.nodeFromIndex(index)
+ if self.__fileMoveEnabled and node.parent is self.__root:
+ # that's a root
+ return qt.Qt.ItemIsDragEnabled | defaultFlags
+ elif self.__datasetDragEnabled:
+ return qt.Qt.ItemIsDragEnabled | defaultFlags
+ return defaultFlags
+ elif self.__fileDropEnabled or self.__fileMoveEnabled:
+ return qt.Qt.ItemIsDropEnabled | defaultFlags
+ else:
+ return defaultFlags
+
+ def dropMimeData(self, mimedata, action, row, column, parentIndex):
+ if action == qt.Qt.IgnoreAction:
+ return True
+
+ if self.__fileMoveEnabled and mimedata.hasFormat(_utils.Hdf5DatasetMimeData.MIME_TYPE):
+ if mimedata.isRoot():
+ dragNode = mimedata.node()
+ parentNode = self.nodeFromIndex(parentIndex)
+ if parentNode is not dragNode.parent:
+ return False
+
+ if row == -1:
+ # append to the parent
+ row = parentNode.childCount()
+ else:
+ # insert at row
+ pass
+
+ dragNodeParent = dragNode.parent
+ sourceRow = dragNodeParent.indexOfChild(dragNode)
+ self.moveRow(parentIndex, sourceRow, parentIndex, row)
+ return True
+
+ if self.__fileDropEnabled and mimedata.hasFormat("text/uri-list"):
+
+ parentNode = self.nodeFromIndex(parentIndex)
+ if parentNode is not self.__root:
+ while(parentNode is not self.__root):
+ node = parentNode
+ parentNode = node.parent
+ row = parentNode.indexOfChild(node)
+ else:
+ if row == -1:
+ row = self.__root.childCount()
+
+ messages = []
+ for url in mimedata.urls():
+ try:
+ self.insertFileAsync(url.toLocalFile(), row)
+ row += 1
+ except IOError as e:
+ messages.append(e.args[0])
+ if len(messages) > 0:
+ title = "Error occurred when loading files"
+ message = "<html>%s:<ul><li>%s</li><ul></html>" % (title, "</li><li>".join(messages))
+ qt.QMessageBox.critical(None, title, message)
+ return True
+
+ return False
+
+ def headerData(self, section, orientation, role=qt.Qt.DisplayRole):
+ if orientation == qt.Qt.Horizontal:
+ if role in [qt.Qt.DisplayRole, qt.Qt.EditRole]:
+ return self.header_labels[section]
+ return None
+
+ def insertNode(self, row, node):
+ if row == -1:
+ row = self.__root.childCount()
+ self.beginInsertRows(qt.QModelIndex(), row, row)
+ self.__root.insertChild(row, node)
+ self.endInsertRows()
+
+ def moveRow(self, sourceParentIndex, sourceRow, destinationParentIndex, destinationRow):
+ if sourceRow == destinationRow or sourceRow == destinationRow - 1:
+ # abort move, same place
+ return
+ return self.moveRows(sourceParentIndex, sourceRow, 1, destinationParentIndex, destinationRow)
+
+ def moveRows(self, sourceParentIndex, sourceRow, count, destinationParentIndex, destinationRow):
+ self.beginMoveRows(sourceParentIndex, sourceRow, sourceRow, destinationParentIndex, destinationRow)
+ sourceNode = self.nodeFromIndex(sourceParentIndex)
+ destinationNode = self.nodeFromIndex(destinationParentIndex)
+
+ if sourceNode is destinationNode and sourceRow < destinationRow:
+ item = sourceNode.child(sourceRow)
+ destinationNode.insertChild(destinationRow, item)
+ sourceNode.removeChildAtIndex(sourceRow)
+ else:
+ item = sourceNode.removeChildAtIndex(sourceRow)
+ destinationNode.insertChild(destinationRow, item)
+
+ self.endMoveRows()
+ return True
+
+ def index(self, row, column, parent=qt.QModelIndex()):
+ try:
+ node = self.nodeFromIndex(parent)
+ return self.createIndex(row, column, node.child(row))
+ except IndexError:
+ return qt.QModelIndex()
+
+ def data(self, index, role=qt.Qt.DisplayRole):
+ node = self.nodeFromIndex(index)
+
+ if role == self.H5PY_ITEM_ROLE:
+ return node
+
+ if role == self.H5PY_OBJECT_ROLE:
+ return node.obj
+
+ if index.column() == self.NAME_COLUMN:
+ return node.dataName(role)
+ elif index.column() == self.TYPE_COLUMN:
+ return node.dataType(role)
+ elif index.column() == self.SHAPE_COLUMN:
+ return node.dataShape(role)
+ elif index.column() == self.VALUE_COLUMN:
+ return node.dataValue(role)
+ elif index.column() == self.DESCRIPTION_COLUMN:
+ return node.dataDescription(role)
+ elif index.column() == self.NODE_COLUMN:
+ return node.dataNode(role)
+ elif index.column() == self.LINK_COLUMN:
+ return node.dataLink(role)
+ else:
+ return None
+
+ def columnCount(self, parent=qt.QModelIndex()):
+ return len(self.COLUMN_IDS)
+
+ def hasChildren(self, parent=qt.QModelIndex()):
+ node = self.nodeFromIndex(parent)
+ if node is None:
+ return 0
+ return node.hasChildren()
+
+ def rowCount(self, parent=qt.QModelIndex()):
+ node = self.nodeFromIndex(parent)
+ if node is None:
+ return 0
+ return node.childCount()
+
+ def parent(self, child):
+ if not child.isValid():
+ return qt.QModelIndex()
+
+ node = self.nodeFromIndex(child)
+
+ if node is None:
+ return qt.QModelIndex()
+
+ parent = node.parent
+
+ if parent is None:
+ return qt.QModelIndex()
+
+ grandparent = parent.parent
+ if grandparent is None:
+ return qt.QModelIndex()
+ row = grandparent.indexOfChild(parent)
+
+ assert row != - 1
+ return self.createIndex(row, 0, parent)
+
+ def nodeFromIndex(self, index):
+ return index.internalPointer() if index.isValid() else self.__root
+
+ def _closeFileIfOwned(self, node):
+ """"Close the file if it was loaded from a filename or a
+ drag-and-drop"""
+ obj = node.obj
+ for f in self.__openedFiles:
+ if f is obj:
+ _logger.debug("Close file %s", obj.filename)
+ obj.close()
+ self.__openedFiles.remove(obj)
+
+ def synchronizeIndex(self, index):
+ """
+ Synchronize a file a given its index.
+
+ Basically close it and load it again.
+
+ :param qt.QModelIndex index: Index of the item to update
+ """
+ node = self.nodeFromIndex(index)
+ if node.parent is not self.__root:
+ return
+
+ filename = node.obj.filename
+ self.insertFileAsync(filename, index.row(), synchronizingNode=node)
+
+ def h5pyObjectRow(self, h5pyObject):
+ for row in range(self.__root.childCount()):
+ item = self.__root.child(row)
+ if item.obj == h5pyObject:
+ return row
+ return -1
+
+ def synchronizeH5pyObject(self, h5pyObject):
+ """
+ Synchronize a h5py object in all the tree.
+
+ Basically close it and load it again.
+
+ :param h5py.File h5pyObject: A :class:`h5py.File` object.
+ """
+ index = 0
+ while index < self.__root.childCount():
+ item = self.__root.child(index)
+ if item.obj == h5pyObject:
+ qindex = self.index(index, 0, qt.QModelIndex())
+ self.synchronizeIndex(qindex)
+ index += 1
+
+ def removeIndex(self, index):
+ """
+ Remove an item from the model using its index.
+
+ :param qt.QModelIndex index: Index of the item to remove
+ """
+ node = self.nodeFromIndex(index)
+ if node.parent != self.__root:
+ return
+ self._closeFileIfOwned(node)
+ self.beginRemoveRows(qt.QModelIndex(), index.row(), index.row())
+ self.__root.removeChildAtIndex(index.row())
+ self.endRemoveRows()
+ self.sigH5pyObjectRemoved.emit(node.obj)
+
+ def removeH5pyObject(self, h5pyObject):
+ """
+ Remove an item from the model using the holding h5py object.
+ It can remove more than one item.
+
+ :param h5py.File h5pyObject: A :class:`h5py.File` object.
+ """
+ index = 0
+ while index < self.__root.childCount():
+ item = self.__root.child(index)
+ if item.obj == h5pyObject:
+ qindex = self.index(index, 0, qt.QModelIndex())
+ self.removeIndex(qindex)
+ else:
+ index += 1
+
+ def insertH5pyObject(self, h5pyObject, text=None, row=-1):
+ """Append an HDF5 object from h5py to the tree.
+
+ :param h5pyObject: File handle/descriptor for a :class:`h5py.File`
+ or any other class of h5py file structure.
+ """
+ if text is None:
+ text = _createRootLabel(h5pyObject)
+ if row == -1:
+ row = self.__root.childCount()
+ self.insertNode(row, Hdf5Item(text=text, obj=h5pyObject, parent=self.__root))
+
+ def hasPendingOperations(self):
+ return len(self.__runnerSet) > 0
+
+ def insertFileAsync(self, filename, row=-1, synchronizingNode=None):
+ if not os.path.isfile(filename):
+ raise IOError("Filename '%s' must be a file path" % filename)
+
+ # create temporary item
+ if synchronizingNode is None:
+ text = os.path.basename(filename)
+ item = Hdf5LoadingItem(text=text, parent=self.__root, animatedIcon=self.__animatedIcon)
+ self.insertNode(row, item)
+ else:
+ item = synchronizingNode
+
+ # start loading the real one
+ runnable = LoadingItemRunnable(filename, item)
+ runnable.itemReady.connect(self.__itemReady)
+ runnable.runnerFinished.connect(self.__releaseRunner)
+ self.__runnerSet.add(runnable)
+ qt.silxGlobalThreadPool().start(runnable)
+
+ def __releaseRunner(self, runner):
+ self.__runnerSet.remove(runner)
+
+ def insertFile(self, filename, row=-1):
+ """Load a HDF5 file into the data model.
+
+ :param filename: file path.
+ """
+ try:
+ h5file = silx_io.open(filename)
+ if self.__ownFiles:
+ self.__openedFiles.append(h5file)
+ self.sigH5pyObjectLoaded.emit(h5file)
+ self.insertH5pyObject(h5file, row=row)
+ except IOError:
+ _logger.debug("File '%s' can't be read.", filename, exc_info=True)
+ raise
+
+ def clear(self):
+ """Remove all the content of the model"""
+ for _ in range(self.rowCount()):
+ qindex = self.index(0, 0, qt.QModelIndex())
+ self.removeIndex(qindex)
+
+ def appendFile(self, filename):
+ self.insertFile(filename, -1)
+
+ def indexFromH5Object(self, h5Object):
+ """Returns a model index from an h5py-like object.
+
+ :param object h5Object: An h5py-like object
+ :rtype: qt.QModelIndex
+ """
+ if h5Object is None:
+ return qt.QModelIndex()
+
+ filename = h5Object.file.filename
+
+ # Seach for the right roots
+ rootIndices = []
+ for index in range(self.rowCount(qt.QModelIndex())):
+ index = self.index(index, 0, qt.QModelIndex())
+ obj = self.data(index, Hdf5TreeModel.H5PY_OBJECT_ROLE)
+ if obj.file.filename == filename:
+ # We can have many roots with different subtree of the same
+ # root
+ rootIndices.append(index)
+
+ if len(rootIndices) == 0:
+ # No root found
+ return qt.QModelIndex()
+
+ path = h5Object.name + "/"
+ path = path.replace("//", "/")
+
+ # Search for the right node
+ found = False
+ foundIndices = []
+ for _ in range(1000 * len(rootIndices)):
+ # Avoid too much iterations, in case of recurssive links
+ if len(foundIndices) == 0:
+ if len(rootIndices) == 0:
+ # Nothing found
+ break
+ # Start fron a new root
+ foundIndices.append(rootIndices.pop(0))
+
+ obj = self.data(index, Hdf5TreeModel.H5PY_OBJECT_ROLE)
+ p = obj.name + "/"
+ p = p.replace("//", "/")
+ if path == p:
+ found = True
+ break
+
+ parentIndex = foundIndices[-1]
+ for index in range(self.rowCount(parentIndex)):
+ index = self.index(index, 0, parentIndex)
+ obj = self.data(index, Hdf5TreeModel.H5PY_OBJECT_ROLE)
+
+ p = obj.name + "/"
+ p = p.replace("//", "/")
+ if path == p:
+ foundIndices.append(index)
+ found = True
+ break
+ elif path.startswith(p):
+ foundIndices.append(index)
+ break
+ else:
+ # Nothing found, start again with another root
+ foundIndices = []
+
+ if found:
+ break
+
+ if found:
+ return foundIndices[-1]
+ return qt.QModelIndex()
diff --git a/src/silx/gui/hdf5/Hdf5TreeView.py b/src/silx/gui/hdf5/Hdf5TreeView.py
new file mode 100644
index 0000000..b276618
--- /dev/null
+++ b/src/silx/gui/hdf5/Hdf5TreeView.py
@@ -0,0 +1,269 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "30/04/2018"
+
+
+import logging
+from .. import qt
+from ...utils import weakref as silxweakref
+from .Hdf5TreeModel import Hdf5TreeModel
+from .Hdf5HeaderView import Hdf5HeaderView
+from .NexusSortFilterProxyModel import NexusSortFilterProxyModel
+from .Hdf5Item import Hdf5Item
+from . import _utils
+
+_logger = logging.getLogger(__name__)
+
+
+class Hdf5TreeView(qt.QTreeView):
+ """TreeView which allow to browse HDF5 file structure.
+
+ .. image:: img/Hdf5TreeView.png
+
+ It provides columns width auto-resizing and additional
+ signals.
+
+ The default model is a :class:`NexusSortFilterProxyModel` sourcing
+ a :class:`Hdf5TreeModel`. The :class:`Hdf5TreeModel` is reachable using
+ :meth:`findHdf5TreeModel`. The default header is :class:`Hdf5HeaderView`.
+
+ Context menu is managed by the :meth:`setContextMenuPolicy` with the value
+ Qt.CustomContextMenu. This policy must not be changed, otherwise context
+ menus will not work anymore. You can use :meth:`addContextMenuCallback` and
+ :meth:`removeContextMenuCallback` to add your custum actions according
+ to the selected objects.
+ """
+ def __init__(self, parent=None):
+ """
+ Constructor
+
+ :param parent qt.QWidget: The parent widget
+ """
+ qt.QTreeView.__init__(self, parent)
+
+ model = self.createDefaultModel()
+ self.setModel(model)
+
+ self.setHeader(Hdf5HeaderView(qt.Qt.Horizontal, self))
+ self.setSelectionBehavior(qt.QAbstractItemView.SelectRows)
+ self.sortByColumn(0, qt.Qt.AscendingOrder)
+ # optimise the rendering
+ self.setUniformRowHeights(True)
+
+ self.setIconSize(qt.QSize(16, 16))
+ self.setAcceptDrops(True)
+ self.setDragEnabled(True)
+ self.setDragDropMode(qt.QAbstractItemView.DragDrop)
+ self.showDropIndicator()
+
+ self.__context_menu_callbacks = silxweakref.WeakList()
+ self.setContextMenuPolicy(qt.Qt.CustomContextMenu)
+ self.customContextMenuRequested.connect(self._createContextMenu)
+
+ def createDefaultModel(self):
+ """Creates and returns the default model.
+
+ Inherite to custom the default model"""
+ model = Hdf5TreeModel(self)
+ proxy_model = NexusSortFilterProxyModel(self)
+ proxy_model.setSourceModel(model)
+ return proxy_model
+
+ def __removeContextMenuProxies(self, ref):
+ """Callback to remove dead proxy from the list"""
+ self.__context_menu_callbacks.remove(ref)
+
+ def _createContextMenu(self, pos):
+ """
+ Create context menu.
+
+ :param pos qt.QPoint: Position of the context menu
+ """
+ actions = []
+
+ menu = qt.QMenu(self)
+
+ hovered_index = self.indexAt(pos)
+ hovered_node = self.model().data(hovered_index, Hdf5TreeModel.H5PY_ITEM_ROLE)
+ if hovered_node is None or not isinstance(hovered_node, Hdf5Item):
+ return
+
+ hovered_object = _utils.H5Node(hovered_node)
+ event = _utils.Hdf5ContextMenuEvent(self, menu, hovered_object)
+
+ for callback in self.__context_menu_callbacks:
+ try:
+ callback(event)
+ except KeyboardInterrupt:
+ raise
+ except Exception:
+ # make sure no user callback crash the application
+ _logger.error("Error while calling callback", exc_info=True)
+ pass
+
+ if not menu.isEmpty():
+ for action in actions:
+ menu.addAction(action)
+ menu.popup(self.viewport().mapToGlobal(pos))
+
+ def addContextMenuCallback(self, callback):
+ """Register a context menu callback.
+
+ The callback will be called when a context menu is requested with the
+ treeview and the list of selected h5py objects in parameters. The
+ callback must return a list of :class:`qt.QAction` object.
+
+ Callbacks are stored as saferef. The object must store a reference by
+ itself.
+ """
+ self.__context_menu_callbacks.append(callback)
+
+ def removeContextMenuCallback(self, callback):
+ """Unregister a context menu callback"""
+ self.__context_menu_callbacks.remove(callback)
+
+ def findHdf5TreeModel(self):
+ """Find the Hdf5TreeModel from the stack of model filters.
+
+ :returns: A Hdf5TreeModel, else None
+ :rtype: Hdf5TreeModel
+ """
+ model = self.model()
+ while model is not None:
+ if isinstance(model, qt.QAbstractProxyModel):
+ model = model.sourceModel()
+ else:
+ break
+ if model is None:
+ return None
+ if isinstance(model, Hdf5TreeModel):
+ return model
+ else:
+ return None
+
+ def dragEnterEvent(self, event):
+ model = self.findHdf5TreeModel()
+ if model is not None and model.isFileDropEnabled() and event.mimeData().hasFormat("text/uri-list"):
+ self.setState(qt.QAbstractItemView.DraggingState)
+ event.accept()
+ else:
+ qt.QTreeView.dragEnterEvent(self, event)
+
+ def dragMoveEvent(self, event):
+ model = self.findHdf5TreeModel()
+ if model is not None and model.isFileDropEnabled() and event.mimeData().hasFormat("text/uri-list"):
+ event.setDropAction(qt.Qt.CopyAction)
+ event.accept()
+ else:
+ qt.QTreeView.dragMoveEvent(self, event)
+
+ def selectedH5Nodes(self, ignoreBrokenLinks=True):
+ """Returns selected h5py objects like :class:`h5py.File`,
+ :class:`h5py.Group`, :class:`h5py.Dataset` or mimicked objects.
+
+ :param ignoreBrokenLinks bool: Returns objects which are not not
+ broken links.
+ :rtype: iterator(:class:`_utils.H5Node`)
+ """
+ for index in self.selectedIndexes():
+ if index.column() != 0:
+ continue
+ item = self.model().data(index, Hdf5TreeModel.H5PY_ITEM_ROLE)
+ if item is None:
+ continue
+ if isinstance(item, Hdf5Item):
+ if ignoreBrokenLinks and item.isBrokenObj():
+ continue
+ yield _utils.H5Node(item)
+
+ def __intermediateModels(self, index):
+ """Returns intermediate models from the view model to the
+ model of the index."""
+ models = []
+ targetModel = index.model()
+ model = self.model()
+ while model is not None:
+ if model is targetModel:
+ # found
+ return models
+ models.append(model)
+ if isinstance(model, qt.QAbstractProxyModel):
+ model = model.sourceModel()
+ else:
+ break
+ raise RuntimeError("Model from the requested index is not reachable from this view")
+
+ def mapToModel(self, index):
+ """Map an index from any model reachable by the view to an index from
+ the very first model connected to the view.
+
+ :param qt.QModelIndex index: Index from the Hdf5Tree model
+ :rtype: qt.QModelIndex
+ :return: Index from the model connected to the view
+ """
+ if not index.isValid():
+ return index
+ models = self.__intermediateModels(index)
+ for model in reversed(models):
+ index = model.mapFromSource(index)
+ return index
+
+ def setSelectedH5Node(self, h5Object):
+ """
+ Select the specified node of the tree using an h5py node.
+
+ - If the item is found, parent items are expended, and then the item
+ is selected.
+ - If the item is not found, the selection do not change.
+ - A none argument allow to deselect everything
+
+ :param h5py.Node h5Object: The node to select
+ """
+ if h5Object is None:
+ self.setCurrentIndex(qt.QModelIndex())
+ return
+
+ model = self.findHdf5TreeModel()
+ index = model.indexFromH5Object(h5Object)
+ index = self.mapToModel(index)
+ if index.isValid():
+ # Update the GUI
+ i = index
+ while i.isValid():
+ self.expand(i)
+ i = i.parent()
+ self.setCurrentIndex(index)
+
+ def mousePressEvent(self, event):
+ """Override mousePressEvent to provide a consistante compatible API
+ between Qt4 and Qt5
+ """
+ super(Hdf5TreeView, self).mousePressEvent(event)
+ if event.button() != qt.Qt.LeftButton:
+ qindex = self.indexAt(event.pos())
+ self.clicked.emit(qindex)
diff --git a/src/silx/gui/hdf5/NexusSortFilterProxyModel.py b/src/silx/gui/hdf5/NexusSortFilterProxyModel.py
new file mode 100644
index 0000000..9c3533f
--- /dev/null
+++ b/src/silx/gui/hdf5/NexusSortFilterProxyModel.py
@@ -0,0 +1,224 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "29/11/2018"
+
+
+import logging
+import re
+import numpy
+from .. import qt
+from .Hdf5TreeModel import Hdf5TreeModel
+import silx.io.utils
+from silx.gui import icons
+
+
+_logger = logging.getLogger(__name__)
+
+
+class NexusSortFilterProxyModel(qt.QSortFilterProxyModel):
+ """Try to sort items according to Nexus structure. Else sort by name."""
+
+ def __init__(self, parent=None):
+ qt.QSortFilterProxyModel.__init__(self, parent)
+ self.__split = re.compile("(\\d+|\\D+)")
+ self.__iconCache = {}
+
+ def hasChildren(self, parent):
+ """Returns true if parent has any children; otherwise returns false.
+
+ :param qt.QModelIndex parent: Index of the item to check
+ :rtype: bool
+ """
+ parent = self.mapToSource(parent)
+ return self.sourceModel().hasChildren(parent)
+
+ def rowCount(self, parent):
+ """Returns the number of rows under the given parent.
+
+ :param qt.QModelIndex parent: Index of the item to check
+ :rtype: int
+ """
+ parent = self.mapToSource(parent)
+ return self.sourceModel().rowCount(parent)
+
+ def lessThan(self, sourceLeft, sourceRight):
+ """Returns True if the value of the item referred to by the given
+ index `sourceLeft` is less than the value of the item referred to by
+ the given index `sourceRight`, otherwise returns false.
+
+ :param qt.QModelIndex sourceLeft:
+ :param qt.QModelIndex sourceRight:
+ :rtype: bool
+ """
+ if sourceLeft.column() != Hdf5TreeModel.NAME_COLUMN:
+ return super(NexusSortFilterProxyModel, self).lessThan(
+ sourceLeft, sourceRight)
+
+ # Do not sort child of root (files)
+ if sourceLeft.parent() == qt.QModelIndex():
+ return sourceLeft.row() < sourceRight.row()
+
+ left = self.sourceModel().data(sourceLeft, Hdf5TreeModel.H5PY_ITEM_ROLE)
+ right = self.sourceModel().data(sourceRight, Hdf5TreeModel.H5PY_ITEM_ROLE)
+
+ if self.__isNXentry(left) and self.__isNXentry(right):
+ less = self.childDatasetLessThan(left, right, "start_time")
+ if less is not None:
+ return less
+ less = self.childDatasetLessThan(left, right, "end_time")
+ if less is not None:
+ return less
+
+ left = self.sourceModel().data(sourceLeft, qt.Qt.DisplayRole)
+ right = self.sourceModel().data(sourceRight, qt.Qt.DisplayRole)
+ return self.nameLessThan(left, right)
+
+ def __isNXentry(self, node):
+ """Returns true if the node is an NXentry"""
+ class_ = node.h5Class
+ if class_ is None or class_ != silx.io.utils.H5Type.GROUP:
+ return False
+ nxClass = node.obj.attrs.get("NX_class", None)
+ return nxClass == "NXentry"
+
+ def __isNXnode(self, node):
+ """Returns true if the node is an NX concept"""
+ if not hasattr(node, "h5Class"):
+ return False
+ class_ = node.h5Class
+ if class_ is None or class_ != silx.io.utils.H5Type.GROUP:
+ return False
+ nxClass = node.obj.attrs.get("NX_class", None)
+ return nxClass is not None
+
+ def getWordsAndNumbers(self, name):
+ """
+ Returns a list of words and integers composing the name.
+
+ An input `"aaa10bbb50.30"` will return
+ `["aaa", 10, "bbb", 50, ".", 30]`.
+
+ :param str name: A name
+ :rtype: List
+ """
+ nonSensitive = self.sortCaseSensitivity() == qt.Qt.CaseInsensitive
+ words = self.__split.findall(name)
+ result = []
+ for i in words:
+ if i[0].isdigit():
+ i = int(i)
+ elif nonSensitive:
+ i = i.lower()
+ result.append(i)
+ return result
+
+ def nameLessThan(self, left, right):
+ """Returns True if the left string is less than the right string.
+
+ Number composing the names are compared as integers, as result "name2"
+ is smaller than "name10".
+
+ :param str left: A string
+ :param str right: A string
+ :rtype: bool
+ """
+ leftList = self.getWordsAndNumbers(left)
+ rightList = self.getWordsAndNumbers(right)
+ try:
+ return leftList < rightList
+ except TypeError:
+ # Back to string comparison if list are not type consistent
+ return left < right
+
+ def childDatasetLessThan(self, left, right, childName):
+ """
+ Reach the same children name of two items and compare their values.
+
+ Returns True if the left one is smaller than the right one.
+
+ :param Hdf5Item left: An item
+ :param Hdf5Item right: An item
+ :param str childName: Name of the children to search. Returns None if
+ the children is not found.
+ :rtype: bool
+ """
+ try:
+ left_time = left.obj[childName][()]
+ right_time = right.obj[childName][()]
+ if isinstance(left_time, numpy.ndarray):
+ return left_time[0] < right_time[0]
+ return left_time < right_time
+ except KeyboardInterrupt:
+ raise
+ except Exception:
+ _logger.debug("Exception occurred", exc_info=True)
+ return None
+
+ def __createCompoundIcon(self, backgroundIcon, foregroundIcon):
+ icon = qt.QIcon()
+
+ sizes = backgroundIcon.availableSizes()
+ sizes = sorted(sizes, key=lambda s: s.height())
+ sizes = filter(lambda s: s.height() < 100, sizes)
+ sizes = list(sizes)
+ if len(sizes) > 0:
+ baseSize = sizes[-1]
+ else:
+ baseSize = qt.QSize(32, 32)
+
+ modes = [qt.QIcon.Normal, qt.QIcon.Disabled]
+ for mode in modes:
+ pixmap = qt.QPixmap(baseSize)
+ pixmap.fill(qt.Qt.transparent)
+ painter = qt.QPainter(pixmap)
+ painter.drawPixmap(0, 0, backgroundIcon.pixmap(baseSize, mode=mode))
+ painter.drawPixmap(0, 0, foregroundIcon.pixmap(baseSize, mode=mode))
+ painter.end()
+ icon.addPixmap(pixmap, mode=mode)
+
+ return icon
+
+ def __getNxIcon(self, baseIcon):
+ iconHash = baseIcon.cacheKey()
+ icon = self.__iconCache.get(iconHash, None)
+ if icon is None:
+ nxIcon = icons.getQIcon("layer-nx")
+ icon = self.__createCompoundIcon(baseIcon, nxIcon)
+ self.__iconCache[iconHash] = icon
+ return icon
+
+ def data(self, index, role=qt.Qt.DisplayRole):
+ result = super(NexusSortFilterProxyModel, self).data(index, role)
+
+ if index.column() == Hdf5TreeModel.NAME_COLUMN:
+ if role == qt.Qt.DecorationRole:
+ sourceIndex = self.mapToSource(index)
+ item = self.sourceModel().data(sourceIndex, Hdf5TreeModel.H5PY_ITEM_ROLE)
+ if self.__isNXnode(item):
+ result = self.__getNxIcon(result)
+ return result
diff --git a/src/silx/gui/hdf5/__init__.py b/src/silx/gui/hdf5/__init__.py
new file mode 100644
index 0000000..1b5a602
--- /dev/null
+++ b/src/silx/gui/hdf5/__init__.py
@@ -0,0 +1,44 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This package provides a set of Qt widgets for displaying content relative to
+HDF5 format.
+
+.. note::
+
+ This package depends on *h5py*.
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "23/09/2016"
+
+
+from .Hdf5TreeView import Hdf5TreeView # noqa
+from ._utils import H5Node
+from ._utils import Hdf5ContextMenuEvent # noqa
+from .NexusSortFilterProxyModel import NexusSortFilterProxyModel # noqa
+from .Hdf5TreeModel import Hdf5TreeModel # noqa
+
+__all__ = ['Hdf5TreeView', 'H5Node', 'Hdf5ContextMenuEvent', 'NexusSortFilterProxyModel', 'Hdf5TreeModel']
diff --git a/src/silx/gui/hdf5/_utils.py b/src/silx/gui/hdf5/_utils.py
new file mode 100644
index 0000000..8f32252
--- /dev/null
+++ b/src/silx/gui/hdf5/_utils.py
@@ -0,0 +1,461 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This package provides a set of helper class and function used by the
+package `silx.gui.hdf5` package.
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "17/01/2019"
+
+
+from html import escape
+import logging
+import os.path
+
+import silx.io.utils
+import silx.io.url
+from .. import qt
+
+_logger = logging.getLogger(__name__)
+
+
+class Hdf5ContextMenuEvent(object):
+ """Hold information provided to context menu callbacks."""
+
+ def __init__(self, source, menu, hoveredObject):
+ """
+ Constructor
+
+ :param QWidget source: Widget source
+ :param QMenu menu: Context menu which will be displayed
+ :param H5Node hoveredObject: Hovered H5 node
+ """
+ self.__source = source
+ self.__menu = menu
+ self.__hoveredObject = hoveredObject
+
+ def source(self):
+ """Source of the event
+
+ :rtype: Hdf5TreeView
+ """
+ return self.__source
+
+ def menu(self):
+ """Menu which will be displayed
+
+ :rtype: qt.QMenu
+ """
+ return self.__menu
+
+ def hoveredObject(self):
+ """Item content hovered by the mouse when the context menu was
+ requested
+
+ :rtype: H5Node
+ """
+ return self.__hoveredObject
+
+
+def htmlFromDict(dictionary, title=None):
+ """Generate a readable HTML from a dictionary
+
+ :param dict dictionary: A Dictionary
+ :rtype: str
+ """
+ result = """<html>
+ <head>
+ <style type="text/css">
+ ul { -qt-list-indent: 0; list-style: none; }
+ li > b {display: inline-block; min-width: 4em; font-weight: bold; }
+ </style>
+ </head>
+ <body>
+ """
+ if title is not None:
+ result += "<b>%s</b>" % escape(title)
+ result += "<ul>"
+ for key, value in dictionary.items():
+ result += "<li><b>%s</b>: %s</li>" % (escape(key), escape(value))
+ result += "</ul>"
+ result += "</body></html>"
+ return result
+
+
+class Hdf5DatasetMimeData(qt.QMimeData):
+ """Mimedata class to identify an internal drag and drop of a Hdf5Node."""
+
+ MIME_TYPE = "application/x-internal-h5py-dataset"
+
+ SILX_URI_TYPE = "application/x-silx-uri"
+
+ def __init__(self, node=None, dataset=None, isRoot=False):
+ qt.QMimeData.__init__(self)
+ self.__dataset = dataset
+ self.__node = node
+ self.__isRoot = isRoot
+ self.setData(self.MIME_TYPE, "".encode(encoding='utf-8'))
+ if node is not None:
+ h5Node = H5Node(node)
+ silxUrl = h5Node.url
+ self.setText(silxUrl)
+ self.setData(self.SILX_URI_TYPE, silxUrl.encode(encoding='utf-8'))
+
+ def isRoot(self):
+ return self.__isRoot
+
+ def node(self):
+ return self.__node
+
+ def dataset(self):
+ if self.__node is not None:
+ return self.__node.obj
+ return self.__dataset
+
+
+class H5Node(object):
+ """Adapter over an h5py object to provide missing informations from h5py
+ nodes, like internal node path and filename (which are not provided by
+ :mod:`h5py` for soft and external links).
+
+ It also provides an abstraction to reach node type for mimicked h5py
+ objects.
+ """
+
+ def __init__(self, h5py_item=None):
+ """Constructor
+
+ :param Hdf5Item h5py_item: An Hdf5Item
+ """
+ self.__h5py_object = h5py_item.obj
+ self.__h5py_target = None
+ self.__h5py_item = h5py_item
+
+ def __getattr__(self, name):
+ if hasattr(self.__h5py_object, name):
+ attr = getattr(self.__h5py_object, name)
+ return attr
+ raise AttributeError("H5Node has no attribute %s" % name)
+
+ def __get_target(self, obj):
+ """
+ Return the actual physical target of the provided object.
+
+ Objects can contains links in the middle of the path, this function
+ check each groups and remove this prefix in case of the link by the
+ link of the path.
+
+ :param obj: A valid h5py object (File, group or dataset)
+ :type obj: h5py.Dataset or h5py.Group or h5py.File
+ :rtype: h5py.Dataset or h5py.Group or h5py.File
+ """
+ elements = obj.name.split("/")
+ if obj.name == "/":
+ return obj
+ elif obj.name.startswith("/"):
+ elements.pop(0)
+ path = ""
+ subpath = ""
+ while len(elements) > 0:
+ e = elements.pop(0)
+ subpath = path + "/" + e
+ link = obj.parent.get(subpath, getlink=True)
+ classlink = silx.io.utils.get_h5_class(link)
+
+ if classlink == silx.io.utils.H5Type.EXTERNAL_LINK:
+ subpath = "/".join(elements)
+ external_obj = obj.parent.get(self.basename + "/" + subpath)
+ return self.__get_target(external_obj)
+ elif classlink == silx.io.utils.H5Type.SOFT_LINK:
+ # Restart from this stat
+ root_elements = link.path.split("/")
+ if link.path == "/":
+ path = ""
+ root_elements = []
+ elif link.path.startswith("/"):
+ path = ""
+ root_elements.pop(0)
+
+ for name in reversed(root_elements):
+ elements.insert(0, name)
+ else:
+ path = subpath
+
+ return obj.file[path]
+
+ @property
+ def h5py_target(self):
+ if self.__h5py_target is not None:
+ return self.__h5py_target
+ self.__h5py_target = self.__get_target(self.__h5py_object)
+ return self.__h5py_target
+
+ @property
+ def h5py_object(self):
+ """Returns the internal h5py node.
+
+ :rtype: h5py.File or h5py.Group or h5py.Dataset
+ """
+ return self.__h5py_object
+
+ @property
+ def h5type(self):
+ """Returns the node type, as an H5Type.
+
+ :rtype: H5Node
+ """
+ return silx.io.utils.get_h5_class(self.__h5py_object)
+
+ @property
+ def ntype(self):
+ """Returns the node type, as an h5py class.
+
+ :rtype:
+ :class:`h5py.File`, :class:`h5py.Group` or :class:`h5py.Dataset`
+ """
+ type_ = self.h5type
+ return silx.io.utils.h5type_to_h5py_class(type_)
+
+ @property
+ def basename(self):
+ """Returns the basename of this h5py node. It is the last identifier of
+ the path.
+
+ :rtype: str
+ """
+ return self.__h5py_object.name.split("/")[-1]
+
+ @property
+ def is_broken(self):
+ """Returns true if the node is a broken link.
+
+ :rtype: bool
+ """
+ if self.__h5py_item is None:
+ raise RuntimeError("h5py_item is not defined")
+ return self.__h5py_item.isBrokenObj()
+
+ @property
+ def local_name(self):
+ """Returns the path from the master file root to this node.
+
+ For links, this path is not equal to the h5py one.
+
+ :rtype: str
+ """
+ if self.__h5py_item is None:
+ raise RuntimeError("h5py_item is not defined")
+
+ result = []
+ item = self.__h5py_item
+ while item is not None:
+ # stop before the root item (item without parent)
+ if item.parent.parent is None:
+ name = item.obj.name
+ if name != "/":
+ result.append(item.obj.name)
+ break
+ else:
+ result.append(item.basename)
+ item = item.parent
+ if item is None:
+ raise RuntimeError("The item does not have parent holding h5py.File")
+ if result == []:
+ return "/"
+ if not result[-1].startswith("/"):
+ result.append("")
+ result.reverse()
+ name = "/".join(result)
+ return name
+
+ def __get_local_file(self):
+ """Returns the file of the root of this tree
+
+ :rtype: h5py.File
+ """
+ item = self.__h5py_item
+ while item.parent.parent is not None:
+ class_ = silx.io.utils.get_h5_class(class_=item.h5pyClass)
+ if class_ == silx.io.utils.H5Type.FILE:
+ break
+ item = item.parent
+
+ class_ = silx.io.utils.get_h5_class(class_=item.h5pyClass)
+ if class_ == silx.io.utils.H5Type.FILE:
+ return item.obj
+ else:
+ return item.obj.file
+
+ @property
+ def local_file(self):
+ """Returns the master file in which is this node.
+
+ For path containing external links, this file is not equal to the h5py
+ one.
+
+ :rtype: h5py.File
+ :raises RuntimeException: If no file are found
+ """
+ return self.__get_local_file()
+
+ @property
+ def local_filename(self):
+ """Returns the filename from the master file of this node.
+
+ For path containing external links, this path is not equal to the
+ filename provided by h5py.
+
+ :rtype: str
+ :raises RuntimeException: If no file are found
+ """
+ return self.local_file.filename
+
+ @property
+ def local_basename(self):
+ """Returns the basename from the master file root to this node.
+
+ For path containing links, this basename can be different than the
+ basename provided by h5py.
+
+ :rtype: str
+ """
+ class_ = self.__h5py_item.h5Class
+ if class_ is not None and class_ == silx.io.utils.H5Type.FILE:
+ return ""
+ return self.__h5py_item.basename
+
+ @property
+ def physical_file(self):
+ """Returns the physical file in which is this node.
+
+ .. versionadded:: 0.6
+
+ :rtype: h5py.File
+ :raises RuntimeError: If no file are found
+ """
+ class_ = silx.io.utils.get_h5_class(self.__h5py_object)
+ if class_ == silx.io.utils.H5Type.EXTERNAL_LINK:
+ # It means the link is broken
+ raise RuntimeError("No file node found")
+ if class_ == silx.io.utils.H5Type.SOFT_LINK:
+ # It means the link is broken
+ return self.local_file
+
+ physical_obj = self.h5py_target
+ return physical_obj.file
+
+ @property
+ def physical_name(self):
+ """Returns the path from the location this h5py node is physically
+ stored.
+
+ For broken links, this filename can be different from the
+ filename provided by h5py.
+
+ :rtype: str
+ """
+ class_ = silx.io.utils.get_h5_class(self.__h5py_object)
+ if class_ == silx.io.utils.H5Type.EXTERNAL_LINK:
+ # It means the link is broken
+ return self.__h5py_object.path
+ if class_ == silx.io.utils.H5Type.SOFT_LINK:
+ # It means the link is broken
+ return self.__h5py_object.path
+
+ physical_obj = self.h5py_target
+ return physical_obj.name
+
+ @property
+ def physical_filename(self):
+ """Returns the filename from the location this h5py node is physically
+ stored.
+
+ For broken links, this filename can be different from the
+ filename provided by h5py.
+
+ :rtype: str
+ """
+ class_ = silx.io.utils.get_h5_class(self.__h5py_object)
+ if class_ == silx.io.utils.H5Type.EXTERNAL_LINK:
+ # It means the link is broken
+ return self.__h5py_object.filename
+ if class_ == silx.io.utils.H5Type.SOFT_LINK:
+ # It means the link is broken
+ return self.local_file.filename
+
+ return self.physical_file.filename
+
+ @property
+ def physical_basename(self):
+ """Returns the basename from the location this h5py node is physically
+ stored.
+
+ For broken links, this basename can be different from the
+ basename provided by h5py.
+
+ :rtype: str
+ """
+ return self.physical_name.split("/")[-1]
+
+ @property
+ def data_url(self):
+ """Returns a :class:`silx.io.url.DataUrl` object identify this node in the file
+ system.
+
+ :rtype: ~silx.io.url.DataUrl
+ """
+ absolute_filename = os.path.abspath(self.local_filename)
+ return silx.io.url.DataUrl(scheme="silx",
+ file_path=absolute_filename,
+ data_path=self.local_name)
+
+ @property
+ def url(self):
+ """Returns an URL object identifying this node in the file
+ system.
+
+ This URL can be used in different ways.
+
+ .. code-block:: python
+
+ # Parsing the URL
+ import silx.io.url
+ dataurl = silx.io.url.DataUrl(item.url)
+ # dataurl provides access to URL fields
+
+ # Open a numpy array
+ import silx.io
+ dataset = silx.io.get_data(item.url)
+
+ # Open an hdf5 object (URL targetting a file or a group)
+ import silx.io
+ with silx.io.open(item.url) as h5:
+ ...your stuff...
+
+ :rtype: str
+ """
+ data_url = self.data_url
+ return data_url.path()
diff --git a/src/silx/gui/hdf5/setup.py b/src/silx/gui/hdf5/setup.py
new file mode 100644
index 0000000..786a851
--- /dev/null
+++ b/src/silx/gui/hdf5/setup.py
@@ -0,0 +1,41 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "28/09/2016"
+
+
+from numpy.distutils.misc_util import Configuration
+
+
+def configuration(parent_package='', top_path=None):
+ config = Configuration('hdf5', parent_package, top_path)
+ config.add_subpackage('test')
+ return config
+
+
+if __name__ == "__main__":
+ from numpy.distutils.core import setup
+ setup(configuration=configuration)
diff --git a/src/silx/gui/hdf5/test/__init__.py b/src/silx/gui/hdf5/test/__init__.py
new file mode 100644
index 0000000..71128fb
--- /dev/null
+++ b/src/silx/gui/hdf5/test/__init__.py
@@ -0,0 +1,24 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
diff --git a/src/silx/gui/hdf5/test/test_hdf5.py b/src/silx/gui/hdf5/test/test_hdf5.py
new file mode 100755
index 0000000..9b1b88a
--- /dev/null
+++ b/src/silx/gui/hdf5/test/test_hdf5.py
@@ -0,0 +1,1092 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test for silx.gui.hdf5 module"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "12/03/2019"
+
+
+import time
+import os
+import unittest
+import tempfile
+import numpy
+from pkg_resources import parse_version
+from contextlib import contextmanager
+from silx.gui import qt
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui import hdf5
+from silx.gui.utils.testutils import SignalListener
+from silx.io import commonh5
+import weakref
+
+import h5py
+import pytest
+
+
+h5py2_9 = parse_version(h5py.version.version) >= parse_version('2.9.0')
+
+
+@pytest.fixture(scope="class")
+def useH5File(request, tmpdir_factory):
+ tmp = tmpdir_factory.mktemp("test_hdf5")
+ request.cls.filename = os.path.join(tmp, "data.h5")
+ # create h5 data
+ with h5py.File(request.cls.filename, "w") as f:
+ g = f.create_group("arrays")
+ g.create_dataset("scalar", data=10)
+ yield
+
+
+def create_NXentry(group, name):
+ attrs = {"NX_class": "NXentry"}
+ node = commonh5.Group(name, parent=group, attrs=attrs)
+ group.add_node(node)
+ return node
+
+
+@pytest.mark.usefixtures("useH5File")
+class TestHdf5TreeModel(TestCaseQt):
+
+ def setUp(self):
+ super(TestHdf5TreeModel, self).setUp()
+
+ def waitForPendingOperations(self, model):
+ for _ in range(10):
+ if not model.hasPendingOperations():
+ break
+ self.qWait(10)
+ else:
+ raise RuntimeError("Still waiting for a pending operation")
+
+ @contextmanager
+ def h5TempFile(self):
+ # create tmp file
+ fd, tmp_name = tempfile.mkstemp(suffix=".h5")
+ os.close(fd)
+ # create h5 data
+ h5file = h5py.File(tmp_name, "w")
+ g = h5file.create_group("arrays")
+ g.create_dataset("scalar", data=10)
+ h5file.close()
+ yield tmp_name
+ # clean up
+ os.unlink(tmp_name)
+
+ def testCreate(self):
+ model = hdf5.Hdf5TreeModel()
+ self.assertIsNotNone(model)
+
+ def testAppendFilename(self):
+ model = hdf5.Hdf5TreeModel()
+ self.assertEqual(model.rowCount(qt.QModelIndex()), 0)
+ model.appendFile(self.filename)
+ self.assertEqual(model.rowCount(qt.QModelIndex()), 1)
+ # clean up
+ ref = weakref.ref(model)
+ model = None
+ self.qWaitForDestroy(ref)
+
+ def testAppendBadFilename(self):
+ model = hdf5.Hdf5TreeModel()
+ self.assertRaises(IOError, model.appendFile, "#%$")
+
+ def testInsertFilename(self):
+ try:
+ model = hdf5.Hdf5TreeModel()
+ self.assertEqual(model.rowCount(qt.QModelIndex()), 0)
+ model.insertFile(self.filename)
+ self.assertEqual(model.rowCount(qt.QModelIndex()), 1)
+ # clean up
+ index = model.index(0, 0, qt.QModelIndex())
+ h5File = model.data(index, hdf5.Hdf5TreeModel.H5PY_OBJECT_ROLE)
+ self.assertIsNotNone(h5File)
+ finally:
+ ref = weakref.ref(model)
+ model = None
+ self.qWaitForDestroy(ref)
+
+ def testInsertFilenameAsync(self):
+ try:
+ model = hdf5.Hdf5TreeModel()
+ self.assertEqual(model.rowCount(qt.QModelIndex()), 0)
+ model.insertFileAsync(self.filename)
+ index = model.index(0, 0, qt.QModelIndex())
+ self.assertIsInstance(model.nodeFromIndex(index), hdf5.Hdf5LoadingItem.Hdf5LoadingItem)
+ self.waitForPendingOperations(model)
+ index = model.index(0, 0, qt.QModelIndex())
+ self.assertIsInstance(model.nodeFromIndex(index), hdf5.Hdf5Item.Hdf5Item)
+ finally:
+ ref = weakref.ref(model)
+ model = None
+ self.qWaitForDestroy(ref)
+
+ def testInsertObject(self):
+ h5 = commonh5.File("/foo/bar/1.mock", "w")
+ model = hdf5.Hdf5TreeModel()
+ self.assertEqual(model.rowCount(qt.QModelIndex()), 0)
+ model.insertH5pyObject(h5)
+ self.assertEqual(model.rowCount(qt.QModelIndex()), 1)
+
+ def testRemoveObject(self):
+ h5 = commonh5.File("/foo/bar/1.mock", "w")
+ model = hdf5.Hdf5TreeModel()
+ self.assertEqual(model.rowCount(qt.QModelIndex()), 0)
+ model.insertH5pyObject(h5)
+ self.assertEqual(model.rowCount(qt.QModelIndex()), 1)
+ model.removeH5pyObject(h5)
+ self.assertEqual(model.rowCount(qt.QModelIndex()), 0)
+
+ def testSynchronizeObject(self):
+ h5 = h5py.File(self.filename, mode="r")
+ model = hdf5.Hdf5TreeModel()
+ model.insertH5pyObject(h5)
+ self.assertEqual(model.rowCount(qt.QModelIndex()), 1)
+ index = model.index(0, 0, qt.QModelIndex())
+ node1 = model.nodeFromIndex(index)
+ model.synchronizeH5pyObject(h5)
+ self.waitForPendingOperations(model)
+ # Now h5 was loaded from it's filename
+ # Another ref is owned by the model
+ h5.close()
+
+ index = model.index(0, 0, qt.QModelIndex())
+ node2 = model.nodeFromIndex(index)
+ self.assertIsNot(node1, node2)
+ # after sync
+ time.sleep(0.1)
+ self.qapp.processEvents()
+ time.sleep(0.1)
+ index = model.index(0, 0, qt.QModelIndex())
+ self.assertIsInstance(model.nodeFromIndex(index), hdf5.Hdf5Item.Hdf5Item)
+ # clean up
+ index = model.index(0, 0, qt.QModelIndex())
+ h5File = model.data(index, hdf5.Hdf5TreeModel.H5PY_OBJECT_ROLE)
+ self.assertIsNotNone(h5File)
+ h5File = None
+ # delete the model
+ ref = weakref.ref(model)
+ model = None
+ self.qWaitForDestroy(ref)
+
+ def testFileMoveState(self):
+ model = hdf5.Hdf5TreeModel()
+ self.assertEqual(model.isFileMoveEnabled(), True)
+ model.setFileMoveEnabled(False)
+ self.assertEqual(model.isFileMoveEnabled(), False)
+
+ def testFileDropState(self):
+ model = hdf5.Hdf5TreeModel()
+ self.assertEqual(model.isFileDropEnabled(), True)
+ model.setFileDropEnabled(False)
+ self.assertEqual(model.isFileDropEnabled(), False)
+
+ def testSupportedDrop(self):
+ model = hdf5.Hdf5TreeModel()
+ self.assertNotEqual(model.supportedDropActions(), 0)
+
+ model.setFileMoveEnabled(False)
+ model.setFileDropEnabled(False)
+ self.assertEqual(model.supportedDropActions(), 0)
+
+ model.setFileMoveEnabled(False)
+ model.setFileDropEnabled(True)
+ self.assertNotEqual(model.supportedDropActions(), 0)
+
+ model.setFileMoveEnabled(True)
+ model.setFileDropEnabled(False)
+ self.assertNotEqual(model.supportedDropActions(), 0)
+
+ def testCloseFile(self):
+ """A file inserted as a filename is open and closed internally."""
+ model = hdf5.Hdf5TreeModel()
+ self.assertEqual(model.rowCount(qt.QModelIndex()), 0)
+ model.insertFile(self.filename)
+ self.assertEqual(model.rowCount(qt.QModelIndex()), 1)
+ index = model.index(0, 0)
+ h5File = model.data(index, role=hdf5.Hdf5TreeModel.H5PY_OBJECT_ROLE)
+ model.removeIndex(index)
+ self.assertEqual(model.rowCount(qt.QModelIndex()), 0)
+ self.assertFalse(bool(h5File.id.valid), "The HDF5 file was not closed")
+
+ def testNotCloseFile(self):
+ """A file inserted as an h5py object is not open (then not closed)
+ internally."""
+ try:
+ h5File = h5py.File(self.filename, mode="r")
+ model = hdf5.Hdf5TreeModel()
+ self.assertEqual(model.rowCount(qt.QModelIndex()), 0)
+ model.insertH5pyObject(h5File)
+ self.assertEqual(model.rowCount(qt.QModelIndex()), 1)
+ index = model.index(0, 0)
+ h5File = model.data(index, role=hdf5.Hdf5TreeModel.H5PY_OBJECT_ROLE)
+ model.removeIndex(index)
+ self.assertEqual(model.rowCount(qt.QModelIndex()), 0)
+ self.assertTrue(bool(h5File.id.valid), "The HDF5 file was unexpetedly closed")
+ finally:
+ h5File.close()
+
+ def testDropExternalFile(self):
+ model = hdf5.Hdf5TreeModel()
+ mimeData = qt.QMimeData()
+ mimeData.setUrls([qt.QUrl.fromLocalFile(self.filename)])
+ model.dropMimeData(mimeData, qt.Qt.CopyAction, 0, 0, qt.QModelIndex())
+ self.assertEqual(model.rowCount(qt.QModelIndex()), 1)
+ # after sync
+ self.waitForPendingOperations(model)
+ index = model.index(0, 0, qt.QModelIndex())
+ self.assertIsInstance(model.nodeFromIndex(index), hdf5.Hdf5Item.Hdf5Item)
+ # clean up
+ index = model.index(0, 0, qt.QModelIndex())
+ h5File = model.data(index, role=hdf5.Hdf5TreeModel.H5PY_OBJECT_ROLE)
+ self.assertIsNotNone(h5File)
+ h5File = None
+ ref = weakref.ref(model)
+ model = None
+ self.qWaitForDestroy(ref)
+
+ def getRowDataAsDict(self, model, row):
+ displayed = {}
+ roles = [qt.Qt.DisplayRole, qt.Qt.DecorationRole, qt.Qt.ToolTipRole, qt.Qt.TextAlignmentRole]
+ for column in range(0, model.columnCount(qt.QModelIndex())):
+ index = model.index(0, column, qt.QModelIndex())
+ for role in roles:
+ datum = model.data(index, role)
+ displayed[column, role] = datum
+ return displayed
+
+ def getItemName(self, model, row):
+ index = model.index(row, hdf5.Hdf5TreeModel.NAME_COLUMN, qt.QModelIndex())
+ return model.data(index, qt.Qt.DisplayRole)
+
+ def testFileData(self):
+ h5 = commonh5.File("/foo/bar/1.mock", "w")
+ model = hdf5.Hdf5TreeModel()
+ model.insertH5pyObject(h5)
+ displayed = self.getRowDataAsDict(model, row=0)
+ self.assertEqual(displayed[hdf5.Hdf5TreeModel.NAME_COLUMN, qt.Qt.DisplayRole], "1.mock")
+ self.assertIsInstance(displayed[hdf5.Hdf5TreeModel.NAME_COLUMN, qt.Qt.DecorationRole], qt.QIcon)
+ self.assertEqual(displayed[hdf5.Hdf5TreeModel.TYPE_COLUMN, qt.Qt.DisplayRole], "")
+ self.assertEqual(displayed[hdf5.Hdf5TreeModel.SHAPE_COLUMN, qt.Qt.DisplayRole], "")
+ self.assertEqual(displayed[hdf5.Hdf5TreeModel.VALUE_COLUMN, qt.Qt.DisplayRole], "")
+ self.assertEqual(displayed[hdf5.Hdf5TreeModel.DESCRIPTION_COLUMN, qt.Qt.DisplayRole], None)
+ self.assertEqual(displayed[hdf5.Hdf5TreeModel.NODE_COLUMN, qt.Qt.DisplayRole], "File")
+
+ def testGroupData(self):
+ h5 = commonh5.File("/foo/bar/1.mock", "w")
+ d = h5.create_group("foo")
+ d.attrs["desc"] = "fooo"
+
+ model = hdf5.Hdf5TreeModel()
+ model.insertH5pyObject(d)
+ displayed = self.getRowDataAsDict(model, row=0)
+ self.assertEqual(displayed[hdf5.Hdf5TreeModel.NAME_COLUMN, qt.Qt.DisplayRole], "1.mock::foo")
+ self.assertIsInstance(displayed[hdf5.Hdf5TreeModel.NAME_COLUMN, qt.Qt.DecorationRole], qt.QIcon)
+ self.assertEqual(displayed[hdf5.Hdf5TreeModel.TYPE_COLUMN, qt.Qt.DisplayRole], "")
+ self.assertEqual(displayed[hdf5.Hdf5TreeModel.SHAPE_COLUMN, qt.Qt.DisplayRole], "")
+ self.assertEqual(displayed[hdf5.Hdf5TreeModel.VALUE_COLUMN, qt.Qt.DisplayRole], "")
+ self.assertEqual(displayed[hdf5.Hdf5TreeModel.DESCRIPTION_COLUMN, qt.Qt.DisplayRole], "fooo")
+ self.assertEqual(displayed[hdf5.Hdf5TreeModel.NODE_COLUMN, qt.Qt.DisplayRole], "Group")
+
+ def testDatasetData(self):
+ h5 = commonh5.File("/foo/bar/1.mock", "w")
+ value = numpy.array([1, 2, 3])
+ d = h5.create_dataset("foo", data=value)
+
+ model = hdf5.Hdf5TreeModel()
+ model.insertH5pyObject(d)
+ displayed = self.getRowDataAsDict(model, row=0)
+ self.assertEqual(displayed[hdf5.Hdf5TreeModel.NAME_COLUMN, qt.Qt.DisplayRole], "1.mock::foo")
+ self.assertIsInstance(displayed[hdf5.Hdf5TreeModel.NAME_COLUMN, qt.Qt.DecorationRole], qt.QIcon)
+ self.assertEqual(displayed[hdf5.Hdf5TreeModel.TYPE_COLUMN, qt.Qt.DisplayRole], value.dtype.name)
+ self.assertEqual(displayed[hdf5.Hdf5TreeModel.SHAPE_COLUMN, qt.Qt.DisplayRole], "3")
+ self.assertEqual(displayed[hdf5.Hdf5TreeModel.VALUE_COLUMN, qt.Qt.DisplayRole], "[1 2 3]")
+ self.assertEqual(displayed[hdf5.Hdf5TreeModel.DESCRIPTION_COLUMN, qt.Qt.DisplayRole], "[1 2 3]")
+ self.assertEqual(displayed[hdf5.Hdf5TreeModel.NODE_COLUMN, qt.Qt.DisplayRole], "Dataset")
+
+ def testDropLastAsFirst(self):
+ model = hdf5.Hdf5TreeModel()
+ h5_1 = commonh5.File("/foo/bar/1.mock", "w")
+ h5_2 = commonh5.File("/foo/bar/2.mock", "w")
+ model.insertH5pyObject(h5_1)
+ model.insertH5pyObject(h5_2)
+ self.assertEqual(self.getItemName(model, 0), "1.mock")
+ self.assertEqual(self.getItemName(model, 1), "2.mock")
+ index = model.index(1, 0, qt.QModelIndex())
+ mimeData = model.mimeData([index])
+ model.dropMimeData(mimeData, qt.Qt.MoveAction, 0, 0, qt.QModelIndex())
+ self.assertEqual(self.getItemName(model, 0), "2.mock")
+ self.assertEqual(self.getItemName(model, 1), "1.mock")
+
+ def testDropFirstAsLast(self):
+ model = hdf5.Hdf5TreeModel()
+ h5_1 = commonh5.File("/foo/bar/1.mock", "w")
+ h5_2 = commonh5.File("/foo/bar/2.mock", "w")
+ model.insertH5pyObject(h5_1)
+ model.insertH5pyObject(h5_2)
+ self.assertEqual(self.getItemName(model, 0), "1.mock")
+ self.assertEqual(self.getItemName(model, 1), "2.mock")
+ index = model.index(0, 0, qt.QModelIndex())
+ mimeData = model.mimeData([index])
+ model.dropMimeData(mimeData, qt.Qt.MoveAction, 2, 0, qt.QModelIndex())
+ self.assertEqual(self.getItemName(model, 0), "2.mock")
+ self.assertEqual(self.getItemName(model, 1), "1.mock")
+
+ def testRootParent(self):
+ model = hdf5.Hdf5TreeModel()
+ h5_1 = commonh5.File("/foo/bar/1.mock", "w")
+ model.insertH5pyObject(h5_1)
+ index = model.index(0, 0, qt.QModelIndex())
+ index = model.parent(index)
+ self.assertEqual(index, qt.QModelIndex())
+
+
+@pytest.mark.usefixtures("useH5File")
+class TestHdf5TreeModelSignals(TestCaseQt):
+
+ def setUp(self):
+ TestCaseQt.setUp(self)
+ self.model = hdf5.Hdf5TreeModel()
+ self.h5 = h5py.File(self.filename, mode='r')
+ self.model.insertH5pyObject(self.h5)
+
+ self.listener = SignalListener()
+ self.model.sigH5pyObjectLoaded.connect(self.listener.partial(signal="loaded"))
+ self.model.sigH5pyObjectRemoved.connect(self.listener.partial(signal="removed"))
+ self.model.sigH5pyObjectSynchronized.connect(self.listener.partial(signal="synchronized"))
+
+ def tearDown(self):
+ self.signals = None
+ ref = weakref.ref(self.model)
+ self.model = None
+ self.qWaitForDestroy(ref)
+ self.h5.close()
+ self.h5 = None
+ TestCaseQt.tearDown(self)
+
+ def waitForPendingOperations(self, model):
+ for _ in range(10):
+ if not model.hasPendingOperations():
+ break
+ self.qWait(10)
+ else:
+ raise RuntimeError("Still waiting for a pending operation")
+
+ def testInsert(self):
+ h5 = h5py.File(self.filename, mode='r')
+ self.model.insertH5pyObject(h5)
+ self.assertEqual(self.listener.callCount(), 0)
+
+ def testLoaded(self):
+ self.model.insertFile(self.filename)
+ self.assertEqual(self.listener.callCount(), 1)
+ self.assertEqual(self.listener.karguments(argumentName="signal")[0], "loaded")
+ self.assertIsNot(self.listener.arguments(callIndex=0)[0], self.h5)
+ self.assertEqual(self.listener.arguments(callIndex=0)[0].filename, self.filename)
+
+ def testRemoved(self):
+ self.model.removeH5pyObject(self.h5)
+ self.assertEqual(self.listener.callCount(), 1)
+ self.assertEqual(self.listener.karguments(argumentName="signal")[0], "removed")
+ self.assertIs(self.listener.arguments(callIndex=0)[0], self.h5)
+
+ def testSynchonized(self):
+ self.model.synchronizeH5pyObject(self.h5)
+ self.waitForPendingOperations(self.model)
+ self.assertEqual(self.listener.callCount(), 1)
+ self.assertEqual(self.listener.karguments(argumentName="signal")[0], "synchronized")
+ self.assertIs(self.listener.arguments(callIndex=0)[0], self.h5)
+ self.assertIsNot(self.listener.arguments(callIndex=0)[1], self.h5)
+
+
+class TestNexusSortFilterProxyModel(TestCaseQt):
+
+ def getChildNames(self, model, index):
+ count = model.rowCount(index)
+ result = []
+ for row in range(0, count):
+ itemIndex = model.index(row, hdf5.Hdf5TreeModel.NAME_COLUMN, index)
+ name = model.data(itemIndex, qt.Qt.DisplayRole)
+ result.append(name)
+ return result
+
+ def testNXentryStartTime(self):
+ """Test NXentry with start_time"""
+ model = hdf5.Hdf5TreeModel()
+ h5 = commonh5.File("/foo/bar/1.mock", "w")
+ create_NXentry(h5, "a").create_dataset("start_time", data=numpy.string_("2015"))
+ create_NXentry(h5, "b").create_dataset("start_time", data=numpy.string_("2013"))
+ create_NXentry(h5, "c").create_dataset("start_time", data=numpy.string_("2014"))
+ model.insertH5pyObject(h5)
+
+ proxy = hdf5.NexusSortFilterProxyModel()
+ proxy.setSourceModel(model)
+ proxy.sort(0, qt.Qt.DescendingOrder)
+ names = self.getChildNames(proxy, proxy.index(0, 0, qt.QModelIndex()))
+ self.assertListEqual(names, ["a", "c", "b"])
+
+ def testNXentryStartTimeInArray(self):
+ """Test NXentry with start_time"""
+ model = hdf5.Hdf5TreeModel()
+ h5 = commonh5.File("/foo/bar/1.mock", "w")
+ create_NXentry(h5, "a").create_dataset("start_time", data=numpy.array([numpy.string_("2015")]))
+ create_NXentry(h5, "b").create_dataset("start_time", data=numpy.array([numpy.string_("2013")]))
+ create_NXentry(h5, "c").create_dataset("start_time", data=numpy.array([numpy.string_("2014")]))
+ model.insertH5pyObject(h5)
+
+ proxy = hdf5.NexusSortFilterProxyModel()
+ proxy.setSourceModel(model)
+ proxy.sort(0, qt.Qt.DescendingOrder)
+ names = self.getChildNames(proxy, proxy.index(0, 0, qt.QModelIndex()))
+ self.assertListEqual(names, ["a", "c", "b"])
+
+ def testNXentryEndTimeInArray(self):
+ """Test NXentry with end_time"""
+ model = hdf5.Hdf5TreeModel()
+ h5 = commonh5.File("/foo/bar/1.mock", "w")
+ create_NXentry(h5, "a").create_dataset("end_time", data=numpy.array([numpy.string_("2015")]))
+ create_NXentry(h5, "b").create_dataset("end_time", data=numpy.array([numpy.string_("2013")]))
+ create_NXentry(h5, "c").create_dataset("end_time", data=numpy.array([numpy.string_("2014")]))
+ model.insertH5pyObject(h5)
+
+ proxy = hdf5.NexusSortFilterProxyModel()
+ proxy.setSourceModel(model)
+ proxy.sort(0, qt.Qt.DescendingOrder)
+ names = self.getChildNames(proxy, proxy.index(0, 0, qt.QModelIndex()))
+ self.assertListEqual(names, ["a", "c", "b"])
+
+ def testNXentryName(self):
+ """Test NXentry without start_time or end_time"""
+ model = hdf5.Hdf5TreeModel()
+ h5 = commonh5.File("/foo/bar/1.mock", "w")
+ create_NXentry(h5, "a")
+ create_NXentry(h5, "c")
+ create_NXentry(h5, "b")
+ model.insertH5pyObject(h5)
+
+ proxy = hdf5.NexusSortFilterProxyModel()
+ proxy.setSourceModel(model)
+ proxy.sort(0, qt.Qt.AscendingOrder)
+ names = self.getChildNames(proxy, proxy.index(0, 0, qt.QModelIndex()))
+ self.assertListEqual(names, ["a", "b", "c"])
+
+ def testStartTime(self):
+ """If it is not NXentry, start_time is not used"""
+ model = hdf5.Hdf5TreeModel()
+ h5 = commonh5.File("/foo/bar/1.mock", "w")
+ h5.create_group("a").create_dataset("start_time", data=numpy.string_("2015"))
+ h5.create_group("b").create_dataset("start_time", data=numpy.string_("2013"))
+ h5.create_group("c").create_dataset("start_time", data=numpy.string_("2014"))
+ model.insertH5pyObject(h5)
+
+ proxy = hdf5.NexusSortFilterProxyModel()
+ proxy.setSourceModel(model)
+ proxy.sort(0, qt.Qt.AscendingOrder)
+ names = self.getChildNames(proxy, proxy.index(0, 0, qt.QModelIndex()))
+ self.assertListEqual(names, ["a", "b", "c"])
+
+ def testName(self):
+ model = hdf5.Hdf5TreeModel()
+ h5 = commonh5.File("/foo/bar/1.mock", "w")
+ h5.create_group("a")
+ h5.create_group("c")
+ h5.create_group("b")
+ model.insertH5pyObject(h5)
+
+ proxy = hdf5.NexusSortFilterProxyModel()
+ proxy.setSourceModel(model)
+ proxy.sort(0, qt.Qt.AscendingOrder)
+ names = self.getChildNames(proxy, proxy.index(0, 0, qt.QModelIndex()))
+ self.assertListEqual(names, ["a", "b", "c"])
+
+ def testNumber(self):
+ model = hdf5.Hdf5TreeModel()
+ h5 = commonh5.File("/foo/bar/1.mock", "w")
+ h5.create_group("a1")
+ h5.create_group("a20")
+ h5.create_group("a3")
+ model.insertH5pyObject(h5)
+
+ proxy = hdf5.NexusSortFilterProxyModel()
+ proxy.setSourceModel(model)
+ proxy.sort(0, qt.Qt.AscendingOrder)
+ names = self.getChildNames(proxy, proxy.index(0, 0, qt.QModelIndex()))
+ self.assertListEqual(names, ["a1", "a3", "a20"])
+
+ def testMultiNumber(self):
+ model = hdf5.Hdf5TreeModel()
+ h5 = commonh5.File("/foo/bar/1.mock", "w")
+ h5.create_group("a1-1")
+ h5.create_group("a20-1")
+ h5.create_group("a3-1")
+ h5.create_group("a3-20")
+ h5.create_group("a3-3")
+ model.insertH5pyObject(h5)
+
+ proxy = hdf5.NexusSortFilterProxyModel()
+ proxy.setSourceModel(model)
+ proxy.sort(0, qt.Qt.AscendingOrder)
+ names = self.getChildNames(proxy, proxy.index(0, 0, qt.QModelIndex()))
+ self.assertListEqual(names, ["a1-1", "a3-1", "a3-3", "a3-20", "a20-1"])
+
+ def testUnconsistantTypes(self):
+ model = hdf5.Hdf5TreeModel()
+ h5 = commonh5.File("/foo/bar/1.mock", "w")
+ h5.create_group("aaa100")
+ h5.create_group("100aaa")
+ model.insertH5pyObject(h5)
+
+ proxy = hdf5.NexusSortFilterProxyModel()
+ proxy.setSourceModel(model)
+ proxy.sort(0, qt.Qt.AscendingOrder)
+ names = self.getChildNames(proxy, proxy.index(0, 0, qt.QModelIndex()))
+ self.assertListEqual(names, ["100aaa", "aaa100"])
+
+
+@pytest.fixture(scope='class')
+def useH5Model(request, tmpdir_factory):
+ # Create HDF5 files
+ tmp = tmpdir_factory.mktemp("test_hdf5")
+ filename = os.path.join(tmp, "base.h5")
+ extH5FileName = os.path.join(tmp, "base__external.h5")
+ extDatFileName = os.path.join(tmp, "base__external.dat")
+
+ externalh5 = h5py.File(extH5FileName, mode="w")
+ externalh5["target/dataset"] = 50
+ externalh5["target/link"] = h5py.SoftLink("/target/dataset")
+ externalh5["/ext/vds0"] = [0, 1]
+ externalh5["/ext/vds1"] = [2, 3]
+ externalh5.close()
+
+ numpy.array([0,1,10,10,2,3]).tofile(extDatFileName)
+
+ h5 = h5py.File(filename, mode="w")
+ h5["group/dataset"] = 50
+ h5["link/soft_link"] = h5py.SoftLink("/group/dataset")
+ h5["link/soft_link_to_group"] = h5py.SoftLink("/group")
+ h5["link/soft_link_to_link"] = h5py.SoftLink("/link/soft_link")
+ h5["link/soft_link_to_file"] = h5py.SoftLink("/")
+ h5["group/soft_link_relative"] = h5py.SoftLink("dataset")
+ h5["link/external_link"] = h5py.ExternalLink(extH5FileName, "/target/dataset")
+ h5["link/external_link_to_link"] = h5py.ExternalLink(extH5FileName, "/target/link")
+ h5["broken_link/external_broken_file"] = h5py.ExternalLink(extH5FileName + "_not_exists", "/target/link")
+ h5["broken_link/external_broken_link"] = h5py.ExternalLink(extH5FileName, "/target/not_exists")
+ h5["broken_link/soft_broken_link"] = h5py.SoftLink("/group/not_exists")
+ h5["broken_link/soft_link_to_broken_link"] = h5py.SoftLink("/group/not_exists")
+ if h5py2_9:
+ layout = h5py.VirtualLayout((2,2), dtype=int)
+ layout[0] = h5py.VirtualSource("base__external.h5", name="/ext/vds0", shape=(2,), dtype=int)
+ layout[1] = h5py.VirtualSource("base__external.h5", name="/ext/vds1", shape=(2,), dtype=int)
+ h5.create_group("/ext")
+ h5["/ext"].create_virtual_dataset("virtual", layout)
+ external = [("base__external.dat", 0, 2*8), ("base__external.dat", 4*8, 2*8)]
+ h5["/ext"].create_dataset("raw", shape=(2,2), dtype=int, external=external)
+ h5.close()
+
+ with h5py.File(filename, mode="r") as h5File:
+ # Create model
+ request.cls.model = hdf5.Hdf5TreeModel()
+ request.cls.model.insertH5pyObject(h5File)
+ yield
+ ref = weakref.ref(request.cls.model)
+ request.cls.model = None
+ TestCaseQt.qWaitForDestroy(ref)
+
+
+@pytest.mark.usefixtures('useH5Model')
+class _TestModelBase(TestCaseQt):
+ def getIndexFromPath(self, model, path):
+ """
+ :param qt.QAbstractItemModel: model
+ """
+ index = qt.QModelIndex()
+ for name in path:
+ for row in range(model.rowCount(index)):
+ i = model.index(row, 0, index)
+ label = model.data(i)
+ if label == name:
+ index = i
+ break
+ else:
+ raise RuntimeError("Path not found")
+ return index
+
+ def getH5ItemFromPath(self, model, path):
+ index = self.getIndexFromPath(model, path)
+ return model.data(index, hdf5.Hdf5TreeModel.H5PY_ITEM_ROLE)
+
+
+class TestH5Item(_TestModelBase):
+
+ def testFile(self):
+ path = ["base.h5"]
+ h5item = self.getH5ItemFromPath(self.model, path)
+
+ self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "")
+
+ def testGroup(self):
+ path = ["base.h5", "group"]
+ h5item = self.getH5ItemFromPath(self.model, path)
+
+ self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "")
+
+ def testDataset(self):
+ path = ["base.h5", "group", "dataset"]
+ h5item = self.getH5ItemFromPath(self.model, path)
+
+ self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "")
+
+ def testSoftLink(self):
+ path = ["base.h5", "link", "soft_link"]
+ h5item = self.getH5ItemFromPath(self.model, path)
+
+ self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "Soft")
+
+ def testSoftLinkToLink(self):
+ path = ["base.h5", "link", "soft_link_to_link"]
+ h5item = self.getH5ItemFromPath(self.model, path)
+
+ self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "Soft")
+
+ def testSoftLinkRelative(self):
+ path = ["base.h5", "group", "soft_link_relative"]
+ h5item = self.getH5ItemFromPath(self.model, path)
+
+ self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "Soft")
+
+ def testExternalLink(self):
+ path = ["base.h5", "link", "external_link"]
+ h5item = self.getH5ItemFromPath(self.model, path)
+
+ self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "External")
+
+ def testExternalLinkToLink(self):
+ path = ["base.h5", "link", "external_link_to_link"]
+ h5item = self.getH5ItemFromPath(self.model, path)
+
+ self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "External")
+
+ def testExternalBrokenFile(self):
+ path = ["base.h5", "broken_link", "external_broken_file"]
+ h5item = self.getH5ItemFromPath(self.model, path)
+
+ self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "External")
+
+ def testExternalBrokenLink(self):
+ path = ["base.h5", "broken_link", "external_broken_link"]
+ h5item = self.getH5ItemFromPath(self.model, path)
+
+ self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "External")
+
+ def testSoftBrokenLink(self):
+ path = ["base.h5", "broken_link", "soft_broken_link"]
+ h5item = self.getH5ItemFromPath(self.model, path)
+
+ self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "Soft")
+
+ def testSoftLinkToBrokenLink(self):
+ path = ["base.h5", "broken_link", "soft_link_to_broken_link"]
+ h5item = self.getH5ItemFromPath(self.model, path)
+
+ self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "Soft")
+
+ def testDatasetFromSoftLinkToGroup(self):
+ path = ["base.h5", "link", "soft_link_to_group", "dataset"]
+ h5item = self.getH5ItemFromPath(self.model, path)
+
+ self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "")
+
+ def testDatasetFromSoftLinkToFile(self):
+ path = ["base.h5", "link", "soft_link_to_file", "link", "soft_link_to_group", "dataset"]
+ h5item = self.getH5ItemFromPath(self.model, path)
+
+ self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "")
+
+ @pytest.mark.skipif(not h5py2_9, reason="requires h5py>=2.9")
+ def testExternalVirtual(self):
+ path = ["base.h5", "ext", "virtual"]
+ h5item = self.getH5ItemFromPath(self.model, path)
+
+ self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "Virtual")
+
+ @pytest.mark.skipif(not h5py2_9, reason="requires h5py>=2.9")
+ def testExternalRaw(self):
+ path = ["base.h5", "ext", "raw"]
+ h5item = self.getH5ItemFromPath(self.model, path)
+
+ self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "ExtRaw")
+
+
+class TestH5Node(_TestModelBase):
+
+ def getH5NodeFromPath(self, model, path):
+ item = self.getH5ItemFromPath(model, path)
+ h5node = hdf5.H5Node(item)
+ return h5node
+
+ def testFile(self):
+ path = ["base.h5"]
+ h5node = self.getH5NodeFromPath(self.model, path)
+
+ self.assertEqual(h5node.physical_filename, h5node.local_filename)
+ self.assertIn("base.h5", h5node.physical_filename)
+ self.assertEqual(h5node.physical_basename, "")
+ self.assertEqual(h5node.physical_name, "/")
+ self.assertEqual(h5node.local_basename, "")
+ self.assertEqual(h5node.local_name, "/")
+
+ def testGroup(self):
+ path = ["base.h5", "group"]
+ h5node = self.getH5NodeFromPath(self.model, path)
+
+ self.assertEqual(h5node.physical_filename, h5node.local_filename)
+ self.assertIn("base.h5", h5node.physical_filename)
+ self.assertEqual(h5node.physical_basename, "group")
+ self.assertEqual(h5node.physical_name, "/group")
+ self.assertEqual(h5node.local_basename, "group")
+ self.assertEqual(h5node.local_name, "/group")
+
+ def testDataset(self):
+ path = ["base.h5", "group", "dataset"]
+ h5node = self.getH5NodeFromPath(self.model, path)
+
+ self.assertEqual(h5node.physical_filename, h5node.local_filename)
+ self.assertIn("base.h5", h5node.physical_filename)
+ self.assertEqual(h5node.physical_basename, "dataset")
+ self.assertEqual(h5node.physical_name, "/group/dataset")
+ self.assertEqual(h5node.local_basename, "dataset")
+ self.assertEqual(h5node.local_name, "/group/dataset")
+
+ def testSoftLink(self):
+ path = ["base.h5", "link", "soft_link"]
+ h5node = self.getH5NodeFromPath(self.model, path)
+
+ self.assertEqual(h5node.physical_filename, h5node.local_filename)
+ self.assertIn("base.h5", h5node.physical_filename)
+ self.assertEqual(h5node.physical_basename, "dataset")
+ self.assertEqual(h5node.physical_name, "/group/dataset")
+ self.assertEqual(h5node.local_basename, "soft_link")
+ self.assertEqual(h5node.local_name, "/link/soft_link")
+
+ def testSoftLinkToLink(self):
+ path = ["base.h5", "link", "soft_link_to_link"]
+ h5node = self.getH5NodeFromPath(self.model, path)
+
+ self.assertEqual(h5node.physical_filename, h5node.local_filename)
+ self.assertIn("base.h5", h5node.physical_filename)
+ self.assertEqual(h5node.physical_basename, "dataset")
+ self.assertEqual(h5node.physical_name, "/group/dataset")
+ self.assertEqual(h5node.local_basename, "soft_link_to_link")
+ self.assertEqual(h5node.local_name, "/link/soft_link_to_link")
+
+ def testSoftLinkRelative(self):
+ path = ["base.h5", "group", "soft_link_relative"]
+ h5node = self.getH5NodeFromPath(self.model, path)
+
+ self.assertEqual(h5node.physical_filename, h5node.local_filename)
+ self.assertIn("base.h5", h5node.physical_filename)
+ self.assertEqual(h5node.physical_basename, "dataset")
+ self.assertEqual(h5node.physical_name, "/group/dataset")
+ self.assertEqual(h5node.local_basename, "soft_link_relative")
+ self.assertEqual(h5node.local_name, "/group/soft_link_relative")
+
+ def testExternalLink(self):
+ path = ["base.h5", "link", "external_link"]
+ h5node = self.getH5NodeFromPath(self.model, path)
+
+ self.assertNotEqual(h5node.physical_filename, h5node.local_filename)
+ self.assertIn("base.h5", h5node.local_filename)
+ self.assertIn("base__external.h5", h5node.physical_filename)
+ self.assertEqual(h5node.physical_basename, "dataset")
+ self.assertEqual(h5node.physical_name, "/target/dataset")
+ self.assertEqual(h5node.local_basename, "external_link")
+ self.assertEqual(h5node.local_name, "/link/external_link")
+
+ def testExternalLinkToLink(self):
+ path = ["base.h5", "link", "external_link_to_link"]
+ h5node = self.getH5NodeFromPath(self.model, path)
+
+ self.assertNotEqual(h5node.physical_filename, h5node.local_filename)
+ self.assertIn("base.h5", h5node.local_filename)
+ self.assertIn("base__external.h5", h5node.physical_filename)
+
+ self.assertNotEqual(h5node.physical_filename, h5node.local_filename)
+ self.assertEqual(h5node.physical_basename, "dataset")
+ self.assertEqual(h5node.physical_name, "/target/dataset")
+ self.assertEqual(h5node.local_basename, "external_link_to_link")
+ self.assertEqual(h5node.local_name, "/link/external_link_to_link")
+
+ def testExternalBrokenFile(self):
+ path = ["base.h5", "broken_link", "external_broken_file"]
+ h5node = self.getH5NodeFromPath(self.model, path)
+
+ self.assertNotEqual(h5node.physical_filename, h5node.local_filename)
+ self.assertIn("base.h5", h5node.local_filename)
+ self.assertIn("not_exists", h5node.physical_filename)
+ self.assertEqual(h5node.physical_basename, "link")
+ self.assertEqual(h5node.physical_name, "/target/link")
+ self.assertEqual(h5node.local_basename, "external_broken_file")
+ self.assertEqual(h5node.local_name, "/broken_link/external_broken_file")
+
+ def testExternalBrokenLink(self):
+ path = ["base.h5", "broken_link", "external_broken_link"]
+ h5node = self.getH5NodeFromPath(self.model, path)
+
+ self.assertNotEqual(h5node.physical_filename, h5node.local_filename)
+ self.assertIn("base.h5", h5node.local_filename)
+ self.assertIn("__external", h5node.physical_filename)
+ self.assertEqual(h5node.physical_basename, "not_exists")
+ self.assertEqual(h5node.physical_name, "/target/not_exists")
+ self.assertEqual(h5node.local_basename, "external_broken_link")
+ self.assertEqual(h5node.local_name, "/broken_link/external_broken_link")
+
+ def testSoftBrokenLink(self):
+ path = ["base.h5", "broken_link", "soft_broken_link"]
+ h5node = self.getH5NodeFromPath(self.model, path)
+
+ self.assertEqual(h5node.physical_filename, h5node.local_filename)
+ self.assertIn("base.h5", h5node.physical_filename)
+ self.assertEqual(h5node.physical_basename, "not_exists")
+ self.assertEqual(h5node.physical_name, "/group/not_exists")
+ self.assertEqual(h5node.local_basename, "soft_broken_link")
+ self.assertEqual(h5node.local_name, "/broken_link/soft_broken_link")
+
+ def testSoftLinkToBrokenLink(self):
+ path = ["base.h5", "broken_link", "soft_link_to_broken_link"]
+ h5node = self.getH5NodeFromPath(self.model, path)
+
+ self.assertEqual(h5node.physical_filename, h5node.local_filename)
+ self.assertIn("base.h5", h5node.physical_filename)
+ self.assertEqual(h5node.physical_basename, "not_exists")
+ self.assertEqual(h5node.physical_name, "/group/not_exists")
+ self.assertEqual(h5node.local_basename, "soft_link_to_broken_link")
+ self.assertEqual(h5node.local_name, "/broken_link/soft_link_to_broken_link")
+
+ def testDatasetFromSoftLinkToGroup(self):
+ path = ["base.h5", "link", "soft_link_to_group", "dataset"]
+ h5node = self.getH5NodeFromPath(self.model, path)
+
+ self.assertEqual(h5node.physical_filename, h5node.local_filename)
+ self.assertIn("base.h5", h5node.physical_filename)
+ self.assertEqual(h5node.physical_basename, "dataset")
+ self.assertEqual(h5node.physical_name, "/group/dataset")
+ self.assertEqual(h5node.local_basename, "dataset")
+ self.assertEqual(h5node.local_name, "/link/soft_link_to_group/dataset")
+
+ def testDatasetFromSoftLinkToFile(self):
+ path = ["base.h5", "link", "soft_link_to_file", "link", "soft_link_to_group", "dataset"]
+ h5node = self.getH5NodeFromPath(self.model, path)
+
+ self.assertEqual(h5node.physical_filename, h5node.local_filename)
+ self.assertIn("base.h5", h5node.physical_filename)
+ self.assertEqual(h5node.physical_basename, "dataset")
+ self.assertEqual(h5node.physical_name, "/group/dataset")
+ self.assertEqual(h5node.local_basename, "dataset")
+ self.assertEqual(h5node.local_name, "/link/soft_link_to_file/link/soft_link_to_group/dataset")
+
+ @pytest.mark.skipif(not h5py2_9, reason="requires h5py>=2.9")
+ def testExternalVirtual(self):
+ path = ["base.h5", "ext", "virtual"]
+ h5node = self.getH5NodeFromPath(self.model, path)
+
+ self.assertEqual(h5node.physical_filename, h5node.local_filename)
+ self.assertIn("base.h5", h5node.physical_filename)
+ self.assertEqual(h5node.physical_basename, "virtual")
+ self.assertEqual(h5node.physical_name, "/ext/virtual")
+ self.assertEqual(h5node.local_basename, "virtual")
+ self.assertEqual(h5node.local_name, "/ext/virtual")
+
+ @pytest.mark.skipif(not h5py2_9, reason="requires h5py>=2.9")
+ def testExternalRaw(self):
+ path = ["base.h5", "ext", "raw"]
+ h5node = self.getH5NodeFromPath(self.model, path)
+
+ self.assertEqual(h5node.physical_filename, h5node.local_filename)
+ self.assertIn("base.h5", h5node.physical_filename)
+ self.assertEqual(h5node.physical_basename, "raw")
+ self.assertEqual(h5node.physical_name, "/ext/raw")
+ self.assertEqual(h5node.local_basename, "raw")
+ self.assertEqual(h5node.local_name, "/ext/raw")
+
+
+class TestHdf5TreeView(TestCaseQt):
+ """Test to check that icons module."""
+
+ def setUp(self):
+ super(TestHdf5TreeView, self).setUp()
+
+ def testCreate(self):
+ view = hdf5.Hdf5TreeView()
+ self.assertIsNotNone(view)
+
+ def testContextMenu(self):
+ view = hdf5.Hdf5TreeView()
+ view._createContextMenu(qt.QPoint(0, 0))
+
+ def testSelection_OriginalModel(self):
+ tree = commonh5.File("/foo/bar/1.mock", "w")
+ item = tree.create_group("a/b/c/d")
+ item.create_group("e").create_group("f")
+
+ view = hdf5.Hdf5TreeView()
+ view.findHdf5TreeModel().insertH5pyObject(tree)
+ view.setSelectedH5Node(item)
+
+ selected = list(view.selectedH5Nodes())[0]
+ self.assertIs(item, selected.h5py_object)
+
+ def testSelection_Simple(self):
+ tree = commonh5.File("/foo/bar/1.mock", "w")
+ item = tree.create_group("a/b/c/d")
+ item.create_group("e").create_group("f")
+
+ model = hdf5.Hdf5TreeModel()
+ model.insertH5pyObject(tree)
+ view = hdf5.Hdf5TreeView()
+ view.setModel(model)
+ view.setSelectedH5Node(item)
+
+ selected = list(view.selectedH5Nodes())[0]
+ self.assertIs(item, selected.h5py_object)
+
+ def testSelection_NotFound(self):
+ tree2 = commonh5.File("/foo/bar/2.mock", "w")
+ tree = commonh5.File("/foo/bar/1.mock", "w")
+ item = tree.create_group("a/b/c/d")
+ item.create_group("e").create_group("f")
+
+ model = hdf5.Hdf5TreeModel()
+ model.insertH5pyObject(tree)
+ view = hdf5.Hdf5TreeView()
+ view.setModel(model)
+ view.setSelectedH5Node(tree2)
+
+ selection = list(view.selectedH5Nodes())
+ self.assertEqual(len(selection), 0)
+
+ def testSelection_ManyGroupFromSameFile(self):
+ tree = commonh5.File("/foo/bar/1.mock", "w")
+ group1 = tree.create_group("a1")
+ group2 = tree.create_group("a2")
+ group3 = tree.create_group("a3")
+ group1.create_group("b/c/d")
+ item = group2.create_group("b/c/d")
+ group3.create_group("b/c/d")
+
+ model = hdf5.Hdf5TreeModel()
+ model.insertH5pyObject(group1)
+ model.insertH5pyObject(group2)
+ model.insertH5pyObject(group3)
+ view = hdf5.Hdf5TreeView()
+ view.setModel(model)
+ view.setSelectedH5Node(item)
+
+ selected = list(view.selectedH5Nodes())[0]
+ self.assertIs(item, selected.h5py_object)
+
+ def testSelection_RootFromSubTree(self):
+ tree = commonh5.File("/foo/bar/1.mock", "w")
+ group = tree.create_group("a1")
+ group.create_group("b/c/d")
+
+ model = hdf5.Hdf5TreeModel()
+ model.insertH5pyObject(group)
+ view = hdf5.Hdf5TreeView()
+ view.setModel(model)
+ view.setSelectedH5Node(group)
+
+ selected = list(view.selectedH5Nodes())[0]
+ self.assertIs(group, selected.h5py_object)
+
+ def testSelection_FileFromSubTree(self):
+ tree = commonh5.File("/foo/bar/1.mock", "w")
+ group = tree.create_group("a1")
+ group.create_group("b").create_group("b").create_group("d")
+
+ model = hdf5.Hdf5TreeModel()
+ model.insertH5pyObject(group)
+ view = hdf5.Hdf5TreeView()
+ view.setModel(model)
+ view.setSelectedH5Node(tree)
+
+ selection = list(view.selectedH5Nodes())
+ self.assertEqual(len(selection), 0)
+
+ def testSelection_Tree(self):
+ tree1 = commonh5.File("/foo/bar/1.mock", "w")
+ tree2 = commonh5.File("/foo/bar/2.mock", "w")
+ tree3 = commonh5.File("/foo/bar/3.mock", "w")
+ tree1.create_group("a/b/c")
+ tree2.create_group("a/b/c")
+ tree3.create_group("a/b/c")
+ item = tree2
+
+ model = hdf5.Hdf5TreeModel()
+ model.insertH5pyObject(tree1)
+ model.insertH5pyObject(tree2)
+ model.insertH5pyObject(tree3)
+ view = hdf5.Hdf5TreeView()
+ view.setModel(model)
+ view.setSelectedH5Node(item)
+
+ selected = list(view.selectedH5Nodes())[0]
+ self.assertIs(item, selected.h5py_object)
+
+ def testSelection_RecurssiveLink(self):
+ """
+ Recurssive link selection
+
+ This example is not really working as expected cause commonh5 do not
+ support recurssive links.
+ But item.name == "/a/b" and the result is found.
+ """
+ tree = commonh5.File("/foo/bar/1.mock", "w")
+ group = tree.create_group("a")
+ group.add_node(commonh5.SoftLink("b", "/"))
+
+ item = tree["/a/b/a/b/a/b/a/b/a/b/a/b/a/b/a/b"]
+
+ model = hdf5.Hdf5TreeModel()
+ model.insertH5pyObject(tree)
+ view = hdf5.Hdf5TreeView()
+ view.setModel(model)
+ view.setSelectedH5Node(item)
+
+ selected = list(view.selectedH5Nodes())[0]
+ self.assertEqual(item.name, selected.h5py_object.name)
+
+ def testSelection_SelectNone(self):
+ tree = commonh5.File("/foo/bar/1.mock", "w")
+
+ model = hdf5.Hdf5TreeModel()
+ model.insertH5pyObject(tree)
+ view = hdf5.Hdf5TreeView()
+ view.setModel(model)
+ view.setSelectedH5Node(tree)
+ view.setSelectedH5Node(None)
+
+ selection = list(view.selectedH5Nodes())
+ self.assertEqual(len(selection), 0)
diff --git a/src/silx/gui/icons.py b/src/silx/gui/icons.py
new file mode 100644
index 0000000..1493b92
--- /dev/null
+++ b/src/silx/gui/icons.py
@@ -0,0 +1,425 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Set of icons for buttons.
+
+Use :func:`getQIcon` to create Qt QIcon from the name identifying an icon.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "07/01/2019"
+
+
+import os
+import logging
+import weakref
+from . import qt
+import silx.resources
+from silx.utils import weakref as silxweakref
+from silx.utils.deprecation import deprecated
+
+
+_logger = logging.getLogger(__name__)
+"""Module logger"""
+
+
+_cached_icons = None
+"""Cache loaded icons in a weak structure"""
+
+
+def getIconCache():
+ """Get access to all cached icons
+
+ :rtype: dict
+ """
+ global _cached_icons
+ if _cached_icons is None:
+ _cached_icons = weakref.WeakValueDictionary()
+ # Clean up the cache before leaving the application
+ # See https://github.com/silx-kit/silx/issues/1771
+ qt.QApplication.instance().aboutToQuit.connect(cleanIconCache)
+ return _cached_icons
+
+
+def cleanIconCache():
+ """Clean up the icon cache"""
+ _logger.debug("Clean up icon cache")
+ _cached_icons.clear()
+
+
+_supported_formats = None
+"""Order of file format extension to check"""
+
+
+class AbstractAnimatedIcon(qt.QObject):
+ """Store an animated icon.
+
+ It provides an event containing the new icon everytime it is updated."""
+
+ def __init__(self, parent=None):
+ """Constructor
+
+ :param qt.QObject parent: Parent of the QObject
+ :raises: ValueError when name is not known
+ """
+ qt.QObject.__init__(self, parent)
+
+ self.__targets = silxweakref.WeakList()
+ self.__currentIcon = None
+
+ iconChanged = qt.Signal(qt.QIcon)
+ """Signal sent with a QIcon everytime the animation changed."""
+
+ def register(self, obj):
+ """Register an object to the AnimatedIcon.
+ If no object are registred, the animation is paused.
+ Object are stored in a weaked list.
+
+ :param object obj: An object
+ """
+ if obj not in self.__targets:
+ self.__targets.append(obj)
+ self._updateState()
+
+ def unregister(self, obj):
+ """Remove the object from the registration.
+ If no object are registred the animation is paused.
+
+ :param object obj: A registered object
+ """
+ if obj in self.__targets:
+ self.__targets.remove(obj)
+ self._updateState()
+
+ def hasRegistredObjects(self):
+ """Returns true if any object is registred.
+
+ :rtype: bool
+ """
+ return len(self.__targets)
+
+ def isRegistered(self, obj):
+ """Returns true if the object is registred in the AnimatedIcon.
+
+ :param object obj: An object
+ :rtype: bool
+ """
+ return obj in self.__targets
+
+ def currentIcon(self):
+ """Returns the icon of the current frame.
+
+ :rtype: qt.QIcon
+ """
+ return self.__currentIcon
+
+ def _updateState(self):
+ """Update the object according to the connected objects."""
+ pass
+
+ def _setCurrentIcon(self, icon):
+ """Store the current icon and emit a `iconChanged` event.
+
+ :param qt.QIcon icon: The current icon
+ """
+ self.__currentIcon = icon
+ self.iconChanged.emit(self.__currentIcon)
+
+
+class MovieAnimatedIcon(AbstractAnimatedIcon):
+ """Store a looping QMovie to provide icons for each frames.
+ Provides an event with the new icon everytime the movie frame
+ is updated."""
+
+ def __init__(self, filename, parent=None):
+ """Constructor
+
+ :param str filename: An icon name to an animated format
+ :param qt.QObject parent: Parent of the QObject
+ :raises: ValueError when name is not known
+ """
+ AbstractAnimatedIcon.__init__(self, parent)
+
+ qfile = getQFile(filename)
+ self.__movie = qt.QMovie(qfile.fileName(), qt.QByteArray(), parent)
+ self.__movie.setCacheMode(qt.QMovie.CacheAll)
+ self.__movie.frameChanged.connect(self.__frameChanged)
+ self.__cacheIcons = {}
+
+ self.__movie.jumpToFrame(0)
+ self.__updateIconAtFrame(0)
+
+ def __frameChanged(self, frameId):
+ """Callback everytime the QMovie frame change
+ :param int frameId: Current frame id
+ """
+ self.__updateIconAtFrame(frameId)
+
+ def __updateIconAtFrame(self, frameId):
+ """
+ Update the current stored QIcon
+
+ :param int frameId: Current frame id
+ """
+ if frameId in self.__cacheIcons:
+ icon = self.__cacheIcons[frameId]
+ else:
+ icon = qt.QIcon(self.__movie.currentPixmap())
+ self.__cacheIcons[frameId] = icon
+ self._setCurrentIcon(icon)
+
+ def _updateState(self):
+ """Update the movie play according to internal stat of the
+ AnimatedIcon."""
+ self.__movie.setPaused(not self.hasRegistredObjects())
+
+
+class MultiImageAnimatedIcon(AbstractAnimatedIcon):
+ """Store a looping QMovie to provide icons for each frames.
+ Provides an event with the new icon everytime the movie frame
+ is updated."""
+
+ def __init__(self, filename, parent=None):
+ """Constructor
+
+ :param str filename: An icon name to an animated format
+ :param qt.QObject parent: Parent of the QObject
+ :raises: ValueError when name is not known
+ """
+ AbstractAnimatedIcon.__init__(self, parent)
+
+ self.__frames = []
+ for i in range(100):
+ try:
+ frame_filename = os.sep.join((filename, ("%02d" %i)))
+ frame_file = getQFile(frame_filename)
+ except ValueError:
+ break
+ try:
+ icon = qt.QIcon(frame_file.fileName())
+ except ValueError:
+ break
+ self.__frames.append(icon)
+
+ if len(self.__frames) == 0:
+ raise ValueError("Animated icon '%s' do not exists" % filename)
+
+ self.__frameId = -1
+ self.__timer = qt.QTimer(self)
+ self.__timer.timeout.connect(self.__increaseFrame)
+ self.__updateIconAtFrame(0)
+
+ def __increaseFrame(self):
+ """Callback called every timer timeout to change the current frame of
+ the animation
+ """
+ frameId = (self.__frameId + 1) % len(self.__frames)
+ self.__updateIconAtFrame(frameId)
+
+ def __updateIconAtFrame(self, frameId):
+ """
+ Update the current stored QIcon
+
+ :param int frameId: Current frame id
+ """
+ self.__frameId = frameId
+ icon = self.__frames[frameId]
+ self._setCurrentIcon(icon)
+
+ def _updateState(self):
+ """Update the object to wake up or sleep it according to its use."""
+ if self.hasRegistredObjects():
+ if not self.__timer.isActive():
+ self.__timer.start(100)
+ else:
+ if self.__timer.isActive():
+ self.__timer.stop()
+
+
+class AnimatedIcon(MovieAnimatedIcon):
+ """Store a looping QMovie to provide icons for each frames.
+ Provides an event with the new icon everytime the movie frame
+ is updated.
+
+ It may not be available anymore for the silx release 0.6.
+
+ .. deprecated:: 0.5
+ Use :class:`MovieAnimatedIcon` instead.
+ """
+
+ @deprecated
+ def __init__(self, filename, parent=None):
+ MovieAnimatedIcon.__init__(self, filename, parent=parent)
+
+
+def getWaitIcon():
+ """Returns a cached version of the waiting AbstractAnimatedIcon.
+
+ :rtype: AbstractAnimatedIcon
+ """
+ return getAnimatedIcon("process-working")
+
+
+def getAnimatedIcon(name):
+ """Create an AbstractAnimatedIcon from a resource name.
+
+ The resource name can be prefixed by the name of a resource directory. For
+ example "silx:foo.png" identify the resource "foo.png" from the resource
+ directory "silx".
+
+ If no prefix are specified, the file with be returned from the silx
+ resource directory with a specific path "gui/icons".
+
+ See also :func:`silx.resources.register_resource_directory`.
+
+ Try to load a mng or a gif file, then try to load a multi-image animated
+ icon.
+
+ In Qt5 mng or gif are not used, because the transparency is not very well
+ managed.
+
+ :param str name: Name of the icon, in one of the defined icons
+ in this module.
+ :return: Corresponding AbstractAnimatedIcon
+ :raises: ValueError when name is not known
+ """
+ key = name + "__anim"
+ cached_icons = getIconCache()
+ if key not in cached_icons:
+
+ qtMajorVersion = int(qt.qVersion().split(".")[0])
+ icon = None
+
+ # ignore mng and gif in Qt5
+ if qtMajorVersion != 5:
+ try:
+ icon = MovieAnimatedIcon(name)
+ except ValueError:
+ icon = None
+
+ if icon is None:
+ try:
+ icon = MultiImageAnimatedIcon(name)
+ except ValueError:
+ icon = None
+
+ if icon is None:
+ raise ValueError("Not an animated icon name: %s", name)
+
+ cached_icons[key] = icon
+ else:
+ icon = cached_icons[key]
+ return icon
+
+
+def getQIcon(name):
+ """Create a QIcon from its name.
+
+ The resource name can be prefixed by the name of a resource directory. For
+ example "silx:foo.png" identify the resource "foo.png" from the resource
+ directory "silx".
+
+ If no prefix are specified, the file with be returned from the silx
+ resource directory with a specific path "gui/icons".
+
+ See also :func:`silx.resources.register_resource_directory`.
+
+ :param str name: Name of the icon, in one of the defined icons
+ in this module.
+ :return: Corresponding QIcon
+ :raises: ValueError when name is not known
+ """
+ cached_icons = getIconCache()
+ if name not in cached_icons:
+ qfile = getQFile(name)
+ icon = qt.QIcon(qfile.fileName())
+ cached_icons[name] = icon
+ else:
+ icon = cached_icons[name]
+ return icon
+
+
+def getQPixmap(name):
+ """Create a QPixmap from its name.
+
+ The resource name can be prefixed by the name of a resource directory. For
+ example "silx:foo.png" identify the resource "foo.png" from the resource
+ directory "silx".
+
+ If no prefix are specified, the file with be returned from the silx
+ resource directory with a specific path "gui/icons".
+
+ See also :func:`silx.resources.register_resource_directory`.
+
+ :param str name: Name of the icon, in one of the defined icons
+ in this module.
+ :return: Corresponding QPixmap
+ :raises: ValueError when name is not known
+ """
+ qfile = getQFile(name)
+ return qt.QPixmap(qfile.fileName())
+
+
+def getQFile(name):
+ """Create a QFile from an icon name. Filename is found
+ according to supported Qt formats.
+
+ The resource name can be prefixed by the name of a resource directory. For
+ example "silx:foo.png" identify the resource "foo.png" from the resource
+ directory "silx".
+
+ If no prefix are specified, the file with be returned from the silx
+ resource directory with a specific path "gui/icons".
+
+ See also :func:`silx.resources.register_resource_directory`.
+
+ :param str name: Name of the icon, in one of the defined icons
+ in this module.
+ :return: Corresponding QFile
+ :rtype: qt.QFile
+ :raises: ValueError when name is not known
+ """
+ global _supported_formats
+ if _supported_formats is None:
+ _supported_formats = []
+ supported_formats = qt.supportedImageFormats()
+ order = ["mng", "gif", "svg", "png", "jpg"]
+ for format_ in order:
+ if format_ in supported_formats:
+ _supported_formats.append(format_)
+ if len(_supported_formats) == 0:
+ _logger.error("No format supported for icons")
+ else:
+ _logger.debug("Format %s supported", ", ".join(_supported_formats))
+
+ for format_ in _supported_formats:
+ format_ = str(format_)
+ filename = silx.resources._resource_filename('%s.%s' % (name, format_),
+ default_directory=os.path.join('gui', 'icons'))
+ qfile = qt.QFile(filename)
+ if qfile.exists():
+ return qfile
+ _logger.debug("File '%s' not found.", filename)
+ raise ValueError('Not an icon name: %s' % name)
diff --git a/src/silx/gui/plot/AlphaSlider.py b/src/silx/gui/plot/AlphaSlider.py
new file mode 100644
index 0000000..da55b1e
--- /dev/null
+++ b/src/silx/gui/plot/AlphaSlider.py
@@ -0,0 +1,300 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module defines slider widgets interacting with the transparency
+of an image on a :class:`PlotWidget`
+
+Classes:
+--------
+
+- :class:`BaseAlphaSlider` (abstract class)
+- :class:`NamedImageAlphaSlider`
+- :class:`ActiveImageAlphaSlider`
+
+Example:
+--------
+
+This widget can, for instance, be added to a plot toolbar.
+
+.. code-block:: python
+
+ import numpy
+ from silx.gui import qt
+ from silx.gui.plot import PlotWidget
+ from silx.gui.plot.ImageAlphaSlider import NamedImageAlphaSlider
+
+ app = qt.QApplication([])
+ pw = PlotWidget()
+
+ img0 = numpy.arange(200*150).reshape((200, 150))
+ pw.addImage(img0, legend="my background", z=0, origin=(50, 50))
+
+ x, y = numpy.meshgrid(numpy.linspace(-10, 10, 200),
+ numpy.linspace(-10, 5, 150),
+ indexing="ij")
+ img1 = numpy.asarray(numpy.sin(x * y) / (x * y),
+ dtype='float32')
+
+ pw.addImage(img1, legend="my data", z=1,
+ replace=False)
+
+ alpha_slider = NamedImageAlphaSlider(parent=pw,
+ plot=pw,
+ legend="my data")
+ alpha_slider.setOrientation(qt.Qt.Horizontal)
+
+ toolbar = qt.QToolBar("plot", pw)
+ toolbar.addWidget(alpha_slider)
+ pw.addToolBar(toolbar)
+
+ pw.show()
+ app.exec()
+
+"""
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "24/03/2017"
+
+import logging
+
+from silx.gui import qt
+
+_logger = logging.getLogger(__name__)
+
+
+class BaseAlphaSlider(qt.QSlider):
+ """Slider widget to be used in a plot toolbar to control the
+ transparency of a plot primitive (image, scatter or curve).
+
+ Internally, the slider stores its state as an integer between
+ 0 and 255. This is the value emitted by the :attr:`valueChanged`
+ signal.
+
+ The method :meth:`getAlpha` returns the corresponding opacity/alpha
+ as a float between 0. and 1. (with a step of :math:`\frac{1}{255}`).
+
+ You must subclass this class and implement :meth:`getItem`.
+ """
+ sigAlphaChanged = qt.Signal(float)
+ """Emits the alpha value when the slider's value changes,
+ as a float between 0. and 1."""
+
+ def __init__(self, parent=None, plot=None):
+ """
+
+ :param parent: Parent QWidget
+ :param plot: Parent plot widget
+ """
+ assert plot is not None
+ super(BaseAlphaSlider, self).__init__(parent)
+
+ self.plot = plot
+
+ self.setRange(0, 255)
+
+ # if already connected to an item, use its alpha as initial value
+ if self.getItem() is None:
+ self.setValue(255)
+ self.setEnabled(False)
+ else:
+ alpha = self.getItem().getAlpha()
+ self.setValue(round(255*alpha))
+
+ self.valueChanged.connect(self._valueChanged)
+
+ def getItem(self):
+ """You must implement this class to define which item
+ to work on. It must return an item that inherits
+ :class:`silx.gui.plot.items.core.AlphaMixIn`.
+
+ :return: Item on which to operate, or None
+ :rtype: :class:`silx.plot.items.Item`
+ """
+ raise NotImplementedError(
+ "BaseAlphaSlider must be subclassed to " +
+ "implement getItem()")
+
+ def getAlpha(self):
+ """Get the opacity, as a float between 0. and 1.
+
+ :return: Alpha value in [0., 1.]
+ :rtype: float
+ """
+ return self.value() / 255.
+
+ def _valueChanged(self, value):
+ self._updateItem()
+ self.sigAlphaChanged.emit(value / 255.)
+
+ def _updateItem(self):
+ """Update the item's alpha channel.
+ """
+ item = self.getItem()
+ if item is not None:
+ item.setAlpha(self.getAlpha())
+
+
+class ActiveImageAlphaSlider(BaseAlphaSlider):
+ """Slider widget to be used in a plot toolbar to control the
+ transparency of the **active image**.
+
+ :param parent: Parent QWidget
+ :param plot: Plot on which to operate
+
+ See documentation of :class:`BaseAlphaSlider`
+ """
+ def __init__(self, parent=None, plot=None):
+ """
+
+ :param parent: Parent QWidget
+ :param plot: Plot widget on which to operate
+ """
+ super(ActiveImageAlphaSlider, self).__init__(parent, plot)
+ plot.sigActiveImageChanged.connect(self._activeImageChanged)
+
+ def getItem(self):
+ return self.plot.getActiveImage()
+
+ def _activeImageChanged(self, previous, new):
+ """Activate or deactivate slider depending on presence of a new
+ active image.
+ Apply transparency value to new active image.
+
+ :param previous: Legend of previous active image, or None
+ :param new: Legend of new active image, or None
+ """
+ if new is not None and not self.isEnabled():
+ self.setEnabled(True)
+ elif new is None and self.isEnabled():
+ self.setEnabled(False)
+
+ self._updateItem()
+
+
+class NamedItemAlphaSlider(BaseAlphaSlider):
+ """Slider widget to be used in a plot toolbar to control the
+ transparency of an item (defined by its kind and legend).
+
+ :param parent: Parent QWidget
+ :param plot: Plot on which to operate
+ :param str kind: Kind of item whose transparency is to be
+ controlled: "scatter", "image" or "curve".
+ :param str legend: Legend of item whose transparency is to be
+ controlled.
+ """
+ def __init__(self, parent=None, plot=None,
+ kind=None, legend=None):
+ self._item_legend = legend
+ self._item_kind = kind
+
+ super(NamedItemAlphaSlider, self).__init__(parent, plot)
+
+ self._updateState()
+ plot.sigContentChanged.connect(self._onContentChanged)
+
+ def _onContentChanged(self, action, kind, legend):
+ if legend == self._item_legend and kind == self._item_kind:
+ if action == "add":
+ self.setEnabled(True)
+ elif action == "remove":
+ self.setEnabled(False)
+
+ def _updateState(self):
+ """Enable or disable widget based on item's availability."""
+ if self.getItem() is not None:
+ self.setEnabled(True)
+ else:
+ self.setEnabled(False)
+
+ def getItem(self):
+ """Return plot item currently associated to this widget (can be
+ a curve, an image, a scatter...)
+
+ :rtype: subclass of :class:`silx.gui.plot.items.Item`"""
+ if self._item_legend is None or self._item_kind is None:
+ return None
+ return self.plot._getItem(kind=self._item_kind,
+ legend=self._item_legend)
+
+ def setLegend(self, legend):
+ """Associate a different item (of the same kind) to the slider.
+
+ :param legend: New legend of item whose transparency is to be
+ controlled.
+ """
+ self._item_legend = legend
+ self._updateState()
+
+ def getLegend(self):
+ """Return legend of the item currently controlled by this slider.
+
+ :return: Image legend associated to the slider
+ """
+ return self._item_kind
+
+ def setItemKind(self, legend):
+ """Associate a different item (of the same kind) to the slider.
+
+ :param legend: New legend of item whose transparency is to be
+ controlled.
+ """
+ self._item_legend = legend
+ self._updateState()
+
+ def getItemKind(self):
+ """Return kind of the item currently controlled by this slider.
+
+ :return: Item kind ("image", "scatter"...)
+ :rtype: str on None
+ """
+ return self._item_kind
+
+
+class NamedImageAlphaSlider(NamedItemAlphaSlider):
+ """Slider widget to be used in a plot toolbar to control the
+ transparency of an image (defined by its legend).
+
+ :param parent: Parent QWidget
+ :param plot: Plot on which to operate
+ :param str legend: Legend of image whose transparency is to be
+ controlled.
+ """
+ def __init__(self, parent=None, plot=None, legend=None):
+ NamedItemAlphaSlider.__init__(self, parent, plot,
+ kind="image", legend=legend)
+
+
+class NamedScatterAlphaSlider(NamedItemAlphaSlider):
+ """Slider widget to be used in a plot toolbar to control the
+ transparency of a scatter (defined by its legend).
+
+ :param parent: Parent QWidget
+ :param plot: Plot on which to operate
+ :param str legend: Legend of scatter whose transparency is to be
+ controlled.
+ """
+ def __init__(self, parent=None, plot=None, legend=None):
+ NamedItemAlphaSlider.__init__(self, parent, plot,
+ kind="scatter", legend=legend)
diff --git a/src/silx/gui/plot/ColorBar.py b/src/silx/gui/plot/ColorBar.py
new file mode 100644
index 0000000..8cafc06
--- /dev/null
+++ b/src/silx/gui/plot/ColorBar.py
@@ -0,0 +1,883 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Module containing several widgets associated to a colormap.
+"""
+
+__authors__ = ["H. Payno", "T. Vincent"]
+__license__ = "MIT"
+__date__ = "24/04/2018"
+
+
+import logging
+import weakref
+import numpy
+
+from ._utils import ticklayout
+from .. import qt
+from ..qt import inspect as qt_inspect
+from silx.gui import colors
+from silx.math.colormap import LogarithmicNormalization
+
+
+_logger = logging.getLogger(__name__)
+
+
+class ColorBarWidget(qt.QWidget):
+ """Colorbar widget displaying a colormap
+
+ It uses a description of colormap as dict compatible with :class:`Plot`.
+
+ .. image:: img/linearColorbar.png
+ :width: 80px
+ :align: center
+
+ To run the following sample code, a QApplication must be initialized.
+
+ >>> from silx.gui.plot import Plot2D
+ >>> from silx.gui.plot.ColorBar import ColorBarWidget
+
+ >>> plot = Plot2D() # Create a plot widget
+ >>> plot.show()
+
+ >>> colorbar = ColorBarWidget(plot=plot, legend='Colormap') # Associate the colorbar with it
+ >>> colorbar.show()
+
+ Initializer parameters:
+
+ :param parent: See :class:`QWidget`
+ :param plot: PlotWidget the colorbar is attached to (optional)
+ :param str legend: the label to set to the colorbar
+ """
+ sigVisibleChanged = qt.Signal(bool)
+ """Emitted when the property `visible` have changed."""
+
+ def __init__(self, parent=None, plot=None, legend=None):
+ self._isConnected = False
+ self._plotRef = None
+ self._colormap = None
+ self._data = None
+
+ super(ColorBarWidget, self).__init__(parent)
+
+ self.__buildGUI()
+ self.setLegend(legend)
+ self.setPlot(plot)
+
+ def __buildGUI(self):
+ self.setLayout(qt.QHBoxLayout())
+
+ # create color scale widget
+ self._colorScale = ColorScaleBar(parent=self,
+ colormap=None)
+ self.layout().addWidget(self._colorScale)
+
+ # legend (is the right group)
+ self.legend = _VerticalLegend('', self)
+ self.layout().addWidget(self.legend)
+
+ self.layout().setSizeConstraint(qt.QLayout.SetMinAndMaxSize)
+
+ def getPlot(self):
+ """Returns the :class:`Plot` associated to this widget or None"""
+ return None if self._plotRef is None else self._plotRef()
+
+ def setPlot(self, plot):
+ """Associate a plot to the ColorBar
+
+ :param plot: the plot to associate with the colorbar.
+ If None will remove any connection with a previous plot.
+ """
+ self._disconnectPlot()
+ self._plotRef = None if plot is None else weakref.ref(plot)
+ self._connectPlot()
+
+ def _disconnectPlot(self):
+ """Disconnect from Plot signals"""
+ if self._isConnected:
+ self._isConnected = False
+ plot = self.getPlot()
+ if plot is not None and qt_inspect.isValid(plot):
+ plot.sigActiveImageChanged.disconnect(
+ self._activeImageChanged)
+ plot.sigActiveScatterChanged.disconnect(
+ self._activeScatterChanged)
+ plot.sigPlotSignal.disconnect(self._defaultColormapChanged)
+
+ def _connectPlot(self):
+ """Connect to Plot signals"""
+ plot = self.getPlot()
+ if plot is not None and not self._isConnected:
+ activeImageLegend = plot.getActiveImage(just_legend=True)
+ activeScatterLegend = plot._getActiveItem(
+ kind='scatter', just_legend=True)
+ if activeImageLegend is None and activeScatterLegend is None:
+ # Show plot default colormap
+ self._syncWithDefaultColormap()
+ elif activeImageLegend is not None: # Show active image colormap
+ self._activeImageChanged(None, activeImageLegend)
+ elif activeScatterLegend is not None: # Show active scatter colormap
+ self._activeScatterChanged(None, activeScatterLegend)
+
+ plot.sigActiveImageChanged.connect(self._activeImageChanged)
+ plot.sigActiveScatterChanged.connect(self._activeScatterChanged)
+ plot.sigPlotSignal.connect(self._defaultColormapChanged)
+ self._isConnected = True
+
+ def setVisible(self, isVisible):
+ qt.QWidget.setVisible(self, isVisible)
+ self.sigVisibleChanged.emit(isVisible)
+
+ def showEvent(self, event):
+ self._connectPlot()
+
+ def hideEvent(self, event):
+ self._disconnectPlot()
+
+ def getColormap(self):
+ """Returns the colormap displayed in the colorbar.
+
+ :rtype: ~silx.gui.colors.Colormap
+ """
+ return self.getColorScaleBar().getColormap()
+
+ def setColormap(self, colormap, data=None):
+ """Set the colormap to be displayed.
+
+ :param ~silx.gui.colors.Colormap colormap:
+ The colormap to apply on the ColorBarWidget
+ :param Union[numpy.ndarray,~silx.gui.plot.items.ColormapMixin] data:
+ The data to display or item, needed if the colormap require an autoscale
+ """
+ self._data = data
+ self.getColorScaleBar().setColormap(colormap=colormap,
+ data=data)
+ if self._colormap is not None:
+ self._colormap.sigChanged.disconnect(self._colormapHasChanged)
+ self._colormap = colormap
+ if self._colormap is not None:
+ self._colormap.sigChanged.connect(self._colormapHasChanged)
+
+ def _colormapHasChanged(self):
+ """handler of the Colormap.sigChanged signal
+ """
+ assert self._colormap is not None
+ self.setColormap(colormap=self._colormap,
+ data=self._data)
+
+ def setLegend(self, legend):
+ """Set the legend displayed along the colorbar
+
+ :param str legend: The label
+ """
+ if legend is None or legend == "":
+ self.legend.hide()
+ self.legend.setText("")
+ else:
+ assert type(legend) is str
+ self.legend.show()
+ self.legend.setText(legend)
+
+ def getLegend(self):
+ """
+ Returns the legend displayed along the colorbar
+
+ :return: return the legend displayed along the colorbar
+ :rtype: str
+ """
+ return self.legend.text()
+
+ def _activeScatterChanged(self, previous, legend):
+ """Handle plot active scatter changed"""
+ plot = self.getPlot()
+
+ # Do not handle active scatter while there is an image
+ if plot.getActiveImage() is not None:
+ return
+
+ if legend is None: # No active scatter, display no colormap
+ self.setColormap(colormap=None)
+ return
+
+ # Sync with active scatter
+ scatter = plot._getActiveItem(kind='scatter')
+
+ self.setColormap(colormap=scatter.getColormap(),
+ data=scatter)
+
+ def _activeImageChanged(self, previous, legend):
+ """Handle plot active image changed"""
+ plot = self.getPlot()
+
+ if legend is None: # No active image, try with active scatter
+ activeScatterLegend = plot._getActiveItem(
+ kind='scatter', just_legend=True)
+ # No more active image, use active scatter if any
+ self._activeScatterChanged(None, activeScatterLegend)
+ else:
+ # Sync with active image
+ image = plot.getActiveImage()
+
+ # RGB(A) image, display default colormap
+ array = image.getData(copy=False)
+ if array.ndim != 2:
+ self.setColormap(colormap=None)
+ return
+
+ # data image, sync with image colormap
+ # do we need the copy here : used in the case we are changing
+ # vmin and vmax but should have already be done by the plot
+ self.setColormap(colormap=image.getColormap(), data=image)
+
+ def _defaultColormapChanged(self, event):
+ """Handle plot default colormap changed"""
+ if event['event'] == 'defaultColormapChanged':
+ plot = self.getPlot()
+ if (plot is not None and
+ plot.getActiveImage() is None and
+ plot._getActiveItem(kind='scatter') is None):
+ # No active item, take default colormap update into account
+ self._syncWithDefaultColormap()
+
+ def _syncWithDefaultColormap(self):
+ """Update colorbar according to plot default colormap"""
+ self.setColormap(self.getPlot().getDefaultColormap())
+
+ def getColorScaleBar(self):
+ """
+
+ :return: return the :class:`ColorScaleBar` used to display ColorScale
+ and ticks"""
+ return self._colorScale
+
+
+class _VerticalLegend(qt.QLabel):
+ """Display vertically the given text
+ """
+ def __init__(self, text, parent=None):
+ """
+
+ :param text: the legend
+ :param parent: the Qt parent if any
+ """
+ qt.QLabel.__init__(self, text, parent)
+ self.setLayout(qt.QVBoxLayout())
+ self.layout().setContentsMargins(0, 0, 0, 0)
+
+ def paintEvent(self, event):
+ painter = qt.QPainter(self)
+ painter.setFont(self.font())
+
+ painter.translate(0, self.rect().height())
+ painter.rotate(270)
+ newRect = qt.QRect(0, 0, self.rect().height(), self.rect().width())
+
+ painter.drawText(newRect, qt.Qt.AlignHCenter, self.text())
+
+ fm = qt.QFontMetrics(self.font())
+ preferedHeight = fm.width(self.text())
+ preferedWidth = fm.height()
+ self.setFixedWidth(preferedWidth)
+ self.setMinimumHeight(preferedHeight)
+
+
+class ColorScaleBar(qt.QWidget):
+ """This class is making the composition of a :class:`_ColorScale` and a
+ :class:`_TickBar`.
+
+ It is the simplest widget displaying ticks and colormap gradient.
+
+ .. image:: img/colorScaleBar.png
+ :width: 150px
+ :align: center
+
+ To run the following sample code, a QApplication must be initialized.
+
+ >>> colormap = Colormap(name='gray',
+ ... norm='log',
+ ... vmin=1,
+ ... vmax=100000,
+ ... )
+ >>> colorscale = ColorScaleBar(parent=None,
+ ... colormap=colormap )
+ >>> colorscale.show()
+
+ Initializer parameters :
+
+ :param colormap: the colormap to be displayed
+ :param parent: the Qt parent if any
+ :param displayTicksValues: display the ticks value or only the '-'
+ """
+
+ _TEXT_MARGIN = 5
+ """The tick bar need a margin to display all labels at the correct place.
+ So the ColorScale should have the same margin in order for both to fit"""
+
+ def __init__(self, parent=None, colormap=None, data=None,
+ displayTicksValues=True):
+ super(ColorScaleBar, self).__init__(parent)
+
+ self.minVal = None
+ """Value set to the _minLabel"""
+ self.maxVal = None
+ """Value set to the _maxLabel"""
+
+ self.setLayout(qt.QGridLayout())
+
+ # create the left side group (ColorScale)
+ self.colorScale = _ColorScale(colormap=colormap,
+ data=data,
+ parent=self,
+ margin=ColorScaleBar._TEXT_MARGIN)
+ if colormap:
+ vmin, vmax = colormap.getColormapRange(data)
+ normalizer = colormap._getNormalizer()
+ else:
+ vmin, vmax = colors.DEFAULT_MIN_LIN, colors.DEFAULT_MAX_LIN
+ normalizer = None
+
+ self.tickbar = _TickBar(vmin=vmin,
+ vmax=vmax,
+ normalizer=normalizer,
+ parent=self,
+ displayValues=displayTicksValues,
+ margin=ColorScaleBar._TEXT_MARGIN)
+
+ self.layout().addWidget(self.tickbar, 1, 0, 1, 1, qt.Qt.AlignRight)
+ self.layout().addWidget(self.colorScale, 1, 1, qt.Qt.AlignLeft)
+
+ self.layout().setContentsMargins(0, 0, 0, 0)
+ self.layout().setSpacing(0)
+
+ # max label
+ self._maxLabel = qt.QLabel(str(1.0), parent=self)
+ self._maxLabel.setToolTip(str(0.0))
+ self.layout().addWidget(self._maxLabel, 0, 0, 1, 2, qt.Qt.AlignRight)
+
+ # min label
+ self._minLabel = qt.QLabel(str(0.0), parent=self)
+ self._minLabel.setToolTip(str(0.0))
+ self.layout().addWidget(self._minLabel, 2, 0, 1, 2, qt.Qt.AlignRight)
+
+ self.layout().setSizeConstraint(qt.QLayout.SetMinAndMaxSize)
+ self.layout().setColumnStretch(0, 1)
+ self.layout().setRowStretch(1, 1)
+
+ def getTickBar(self):
+ """
+
+ :return: the instanciation of the :class:`_TickBar`
+ """
+ return self.tickbar
+
+ def getColorScale(self):
+ """
+
+ :return: the instanciation of the :class:`_ColorScale`
+ """
+ return self.colorScale
+
+ def getColormap(self):
+ """
+
+ :returns: the colormap.
+ :rtype: :class:`.Colormap`
+ """
+ return self.colorScale.getColormap()
+
+ def setColormap(self, colormap, data=None):
+ """Set the new colormap to be displayed
+
+ :param Colormap colormap: the colormap to set
+ :param Union[numpy.ndarray,~silx.gui.plot.items.Item] data:
+ The data or item to display, needed if the colormap requires an autoscale
+ """
+ self.colorScale.setColormap(colormap, data)
+
+ if colormap is not None:
+ vmin, vmax = colormap.getColormapRange(data)
+ normalizer = colormap._getNormalizer()
+ else:
+ vmin, vmax = None, None
+ normalizer = None
+
+ self.tickbar.update(vmin=vmin,
+ vmax=vmax,
+ normalizer=normalizer)
+ self._setMinMaxLabels(vmin, vmax)
+
+ def setMinMaxVisible(self, val=True):
+ """Change visibility of the min label and the max label
+
+ :param val: if True, set the labels visible, otherwise set it not visible
+ """
+ self._minLabel.setVisible(val)
+ self._maxLabel.setVisible(val)
+
+ def _updateMinMax(self):
+ """Update the min and max label if we are in the case of the
+ configuration 'minMaxValueOnly'"""
+ if self.minVal is None:
+ text, tooltip = '', ''
+ else:
+ if self.minVal == 0 or 0 <= numpy.log10(abs(self.minVal)) < 7:
+ text = '%.7g' % self.minVal
+ else:
+ text = '%.2e' % self.minVal
+ tooltip = repr(self.minVal)
+
+ self._minLabel.setText(text)
+ self._minLabel.setToolTip(tooltip)
+
+ if self.maxVal is None:
+ text, tooltip = '', ''
+ else:
+ if self.maxVal == 0 or 0 <= numpy.log10(abs(self.maxVal)) < 7:
+ text = '%.7g' % self.maxVal
+ else:
+ text = '%.2e' % self.maxVal
+ tooltip = repr(self.maxVal)
+
+ self._maxLabel.setText(text)
+ self._maxLabel.setToolTip(tooltip)
+
+ def _setMinMaxLabels(self, minVal, maxVal):
+ """Change the value of the min and max labels to be displayed.
+
+ :param minVal: the minimal value of the TickBar (not str)
+ :param maxVal: the maximal value of the TickBar (not str)
+ """
+ # bad hack to try to display has much information as possible
+ self.minVal = minVal
+ self.maxVal = maxVal
+ self._updateMinMax()
+
+ def resizeEvent(self, event):
+ qt.QWidget.resizeEvent(self, event)
+ self._updateMinMax()
+
+
+class _ColorScale(qt.QWidget):
+ """Widget displaying the colormap colorScale.
+
+ Show matching value between the gradient color (from the colormap) at mouse
+ position and value.
+
+ .. image:: img/colorScale.png
+ :width: 20px
+ :align: center
+
+
+ To run the following sample code, a QApplication must be initialized.
+
+ >>> colormap = Colormap(name='viridis',
+ ... norm='log',
+ ... vmin=1,
+ ... vmax=100000,
+ ... )
+ >>> colorscale = ColorScale(parent=None,
+ ... colormap=colormap)
+ >>> colorscale.show()
+
+ Initializer parameters :
+
+ :param colormap: the colormap to be displayed
+ :param parent: the Qt parent if any
+ :param int margin: the top and left margin to apply.
+ :param Union[None,numpy.ndarray,~silx.gui.plot.items.ColormapMixin] data:
+ The data or item to use for getting the range for autoscale colormap.
+
+ .. warning:: Value drawing will be
+ done at the center of ticks. So if no margin is done your values
+ drawing might not be fully done for extrems values.
+ """
+
+ _NB_CONTROL_POINTS = 256
+
+ def __init__(self, colormap, parent=None, margin=5, data=None):
+ qt.QWidget.__init__(self, parent)
+ self._colormap = None
+ self.margin = margin
+ self.setColormap(colormap, data)
+
+ self.setLayout(qt.QVBoxLayout())
+ self.setSizePolicy(qt.QSizePolicy.Fixed, qt.QSizePolicy.Expanding)
+ # needed to get the mouse event without waiting for button click
+ self.setMouseTracking(True)
+ self.setMargin(margin)
+ self.setContentsMargins(0, 0, 0, 0)
+
+ self.setMinimumHeight(self._NB_CONTROL_POINTS // 2 + 2 * self.margin)
+ self.setFixedWidth(25)
+
+ def setColormap(self, colormap, data=None):
+ """Set the new colormap to be displayed
+
+ :param dict colormap: the colormap to set
+ :param Union[None,numpy.ndarray,~silx.gui.plot.items.ColormapMixin] data:
+ Optional data for which to compute colormap range.
+ """
+ self._colormap = colormap
+ self.setEnabled(colormap is not None)
+
+ if colormap is None:
+ self.vmin, self.vmax = None, None
+ else:
+ assert colormap.getNormalization() in colors.Colormap.NORMALIZATIONS
+ self.vmin, self.vmax = self._colormap.getColormapRange(data=data)
+ self._updateColorGradient()
+ self.update()
+
+ def getColormap(self):
+ """Returns the colormap
+
+ :rtype: :class:`.Colormap`
+ """
+ return None if self._colormap is None else self._colormap
+
+ def _updateColorGradient(self):
+ """Compute the color gradient"""
+ colormap = self.getColormap()
+ if colormap is None:
+ return
+
+ indices = numpy.linspace(0., 1., self._NB_CONTROL_POINTS)
+ colors = colormap.getNColors(nbColors=self._NB_CONTROL_POINTS)
+ self._gradient = qt.QLinearGradient(0, 1, 0, 0)
+ self._gradient.setCoordinateMode(qt.QGradient.StretchToDeviceMode)
+ self._gradient.setStops(
+ [(i, qt.QColor(*color)) for i, color in zip(indices, colors)]
+ )
+
+ def paintEvent(self, event):
+ """"""
+ painter = qt.QPainter(self)
+ if self.getColormap() is not None:
+ painter.setBrush(self._gradient)
+ penColor = self.palette().color(qt.QPalette.Active,
+ qt.QPalette.WindowText)
+ else:
+ penColor = self.palette().color(qt.QPalette.Disabled,
+ qt.QPalette.WindowText)
+ painter.setPen(penColor)
+
+ painter.drawRect(qt.QRect(
+ 0,
+ self.margin,
+ self.width() - 1,
+ self.height() - 2 * self.margin - 1))
+
+ def mouseMoveEvent(self, event):
+ tooltip = str(self.getValueFromRelativePosition(
+ self._getRelativePosition(event.y())))
+ qt.QToolTip.showText(event.globalPos(), tooltip, self)
+ super(_ColorScale, self).mouseMoveEvent(event)
+
+ def _getRelativePosition(self, yPixel):
+ """yPixel : pixel position into _ColorScale widget reference
+ """
+ # widgets are bottom-top referencial but we display in top-bottom referential
+ return 1. - (yPixel - self.margin) / float(self.height() - 2 * self.margin)
+
+ def getValueFromRelativePosition(self, value):
+ """Return the value in the colorMap from a relative position in the
+ ColorScaleBar (y)
+
+ :param value: float value in [0, 1]
+ :return: the value in [colormap['vmin'], colormap['vmax']]
+ """
+ colormap = self.getColormap()
+ if colormap is None:
+ return
+
+ value = numpy.clip(value, 0., 1.)
+ normalizer = colormap._getNormalizer()
+ normMin, normMax = normalizer.apply([self.vmin, self.vmax], self.vmin, self.vmax)
+
+ return normalizer.revert(
+ normMin + (normMax - normMin) * value, self.vmin, self.vmax)
+
+ def setMargin(self, margin):
+ """Define the margin to fit with a TickBar object.
+ This is needed since we can only paint on the viewport of the widget.
+ Didn't work with a simple setContentsMargins
+
+ :param int margin: the margin to apply on the top and bottom.
+ """
+ self.margin = int(margin)
+ self.update()
+
+
+class _TickBar(qt.QWidget):
+ """Bar grouping the ticks displayed
+
+ To run the following sample code, a QApplication must be initialized.
+
+ >>> bar = _TickBar(1, 1000, norm='log', parent=None, displayValues=True)
+ >>> bar.show()
+
+ .. image:: img/tickbar.png
+ :width: 40px
+ :align: center
+
+ :param int vmin: smaller value of the range of values
+ :param int vmax: higher value of the range of values
+ :param normalizer: Normalization object.
+ :param parent: the Qt parent if any
+ :param bool displayValues: if True display the values close to the tick,
+ Otherwise only signal it by '-'
+ :param int nticks: the number of tick we want to display. Should be an
+ unsigned int ot None. If None, let the Tick bar find the optimal
+ number of ticks from the tick density.
+ :param int margin: margin to set on the top and bottom
+ """
+ _WIDTH_DISP_VAL = 45
+ """widget width when displayed with ticks labels"""
+ _WIDTH_NO_DISP_VAL = 10
+ """widget width when displayed without ticks labels"""
+ _FONT_SIZE = 10
+ """font size for ticks labels"""
+ _LINE_WIDTH = 10
+ """width of the line to mark a tick"""
+
+ DEFAULT_TICK_DENSITY = 0.015
+
+ def __init__(self, vmin, vmax, normalizer, parent=None, displayValues=True,
+ nticks=None, margin=5):
+ super(_TickBar, self).__init__(parent)
+ self.margin = margin
+ self._nticks = None
+ self.ticks = ()
+ self.subTicks = ()
+ self._forcedDisplayType = None
+ self.ticksDensity = _TickBar.DEFAULT_TICK_DENSITY
+
+ self._vmin = vmin
+ self._vmax = vmax
+ self._normalizer = normalizer
+ self.displayValues = displayValues
+ self.setTicksNumber(nticks)
+
+ self.setMargin(margin)
+ self.setContentsMargins(0, 0, 0, 0)
+
+ self._resetWidth()
+
+ def setTicksValuesVisible(self, val):
+ self.displayValues = val
+ self._resetWidth()
+
+ def _resetWidth(self):
+ width = self._WIDTH_DISP_VAL if self.displayValues else self._WIDTH_NO_DISP_VAL
+ self.setFixedWidth(width)
+
+ def update(self, vmin, vmax, normalizer):
+ self._vmin = vmin
+ self._vmax = vmax
+ self._normalizer = normalizer
+ self.computeTicks()
+ qt.QWidget.update(self)
+
+ def setMargin(self, margin):
+ """Define the margin to fit with a _ColorScale object.
+ This is needed since we can only paint on the viewport of the widget
+
+ :param int margin: the margin to apply on the top and bottom.
+ """
+ self.margin = margin
+
+ def setTicksNumber(self, nticks):
+ """Set the number of ticks to display.
+
+ :param nticks: the number of tick to be display. Should be an
+ unsigned int ot None. If None, let the :class:`_TickBar` find the
+ optimal number of ticks from the tick density.
+ """
+ self._nticks = nticks
+ self.computeTicks()
+ qt.QWidget.update(self)
+
+ def setTicksDensity(self, density):
+ """If you let :class:`_TickBar` deal with the number of ticks
+ (nticks=None) then you can specify a ticks density to be displayed.
+ """
+ if density < 0.0:
+ raise ValueError('Density should be a positive value')
+ self.ticksDensity = density
+
+ def computeTicks(self):
+ """This function compute ticks values labels. It is called at each
+ update and each resize event.
+ Deal only with linear and log scale.
+ """
+ nticks = self._nticks
+ if nticks is None:
+ nticks = self._getOptimalNbTicks()
+
+ if self._vmin == self._vmax:
+ # No range: no ticks
+ self.ticks = ()
+ self.subTicks = ()
+ elif isinstance(self._normalizer, LogarithmicNormalization):
+ self._computeTicksLog(nticks)
+ else: # Fallback: use linear
+ self._computeTicksLin(nticks)
+
+ # update the form
+ font = qt.QFont()
+ font.setPixelSize(_TickBar._FONT_SIZE)
+
+ self.form = self._getFormat(font)
+
+ def _computeTicksLog(self, nticks):
+ logMin = numpy.log10(self._vmin)
+ logMax = numpy.log10(self._vmax)
+ lowBound, highBound, spacing, self._nfrac = ticklayout.niceNumbersForLog10(logMin,
+ logMax,
+ nticks)
+ self.ticks = numpy.power(10., numpy.arange(lowBound, highBound, spacing))
+ if spacing == 1:
+ self.subTicks = ticklayout.computeLogSubTicks(ticks=self.ticks,
+ lowBound=numpy.power(10., lowBound),
+ highBound=numpy.power(10., highBound))
+ else:
+ self.subTicks = []
+
+ def resizeEvent(self, event):
+ qt.QWidget.resizeEvent(self, event)
+ self.computeTicks()
+
+ def _computeTicksLin(self, nticks):
+ _min, _max, _spacing, self._nfrac = ticklayout.niceNumbers(self._vmin,
+ self._vmax,
+ nticks)
+
+ self.ticks = numpy.arange(_min, _max, _spacing)
+ self.subTicks = []
+
+ def _getOptimalNbTicks(self):
+ return max(2, int(round(self.ticksDensity * self.rect().height())))
+
+ def paintEvent(self, event):
+ painter = qt.QPainter(self)
+ font = painter.font()
+ font.setPixelSize(_TickBar._FONT_SIZE)
+ painter.setFont(font)
+
+ # paint ticks
+ for val in self.ticks:
+ self._paintTick(val, painter, majorTick=True)
+
+ # paint subticks
+ for val in self.subTicks:
+ self._paintTick(val, painter, majorTick=False)
+
+ def _getRelativePosition(self, val):
+ """Return the relative position of val according to min and max value
+ """
+ if self._normalizer is None:
+ return 0.
+ normMin, normMax, normVal = self._normalizer.apply(
+ [self._vmin, self._vmax, val],
+ self._vmin,
+ self._vmax)
+
+ if normMin == normMax:
+ return 0.
+ else:
+ return 1. - (normVal - normMin) / (normMax - normMin)
+
+ def _paintTick(self, val, painter, majorTick=True):
+ """
+
+ :param bool majorTick: if False will never draw text and will set a line
+ with a smaller width
+ """
+ fm = qt.QFontMetrics(painter.font())
+ viewportHeight = self.rect().height() - self.margin * 2 - 1
+ relativePos = self._getRelativePosition(val)
+ height = int(viewportHeight * relativePos + self.margin)
+ lineWidth = _TickBar._LINE_WIDTH
+ if majorTick is False:
+ lineWidth /= 2
+
+ painter.drawLine(qt.QLine(int(self.width() - lineWidth),
+ height,
+ self.width(),
+ height))
+
+ if self.displayValues and majorTick is True:
+ painter.drawText(qt.QPoint(0, int(height + fm.height() / 2)),
+ self.form.format(val))
+
+ def setDisplayType(self, disType):
+ """Set the type of display we want to set for ticks labels
+
+ :param str disType: The type of display we want to set. disType values
+ can be :
+
+ - 'std' for standard, meaning only a formatting on the number of
+ digits is done
+ - 'e' for scientific display
+ - None to let the _TickBar guess the best display for this kind of data.
+ """
+ if disType not in (None, 'std', 'e'):
+ raise ValueError("display type not recognized, value should be in (None, 'std', 'e'")
+ self._forcedDisplayType = disType
+
+ def _getStandardFormat(self):
+ return "{0:.%sf}" % self._nfrac
+
+ def _getFormat(self, font):
+ if self._forcedDisplayType is None:
+ return self._guessType(font)
+ elif self._forcedDisplayType == 'std':
+ return self._getStandardFormat()
+ elif self._forcedDisplayType == 'e':
+ return self._getScientificForm()
+ else:
+ err = 'Forced type for display %s is not recognized' % self._forcedDisplayType
+ raise ValueError(err)
+
+ def _getScientificForm(self):
+ return "{0:.0e}"
+
+ def _guessType(self, font):
+ """Try fo find the better format to display the tick's labels
+
+ :param QFont font: the font we want to use during the painting
+ """
+ form = self._getStandardFormat()
+
+ fm = qt.QFontMetrics(font)
+ width = 0
+ for tick in self.ticks:
+ width = max(fm.boundingRect(form.format(tick)).width(), width)
+
+ # if the length of the string are too long we are moving to scientific
+ # display
+ if width > _TickBar._WIDTH_DISP_VAL - _TickBar._LINE_WIDTH:
+ return self._getScientificForm()
+ else:
+ return form
diff --git a/src/silx/gui/plot/Colormap.py b/src/silx/gui/plot/Colormap.py
new file mode 100644
index 0000000..22fea7f
--- /dev/null
+++ b/src/silx/gui/plot/Colormap.py
@@ -0,0 +1,42 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Deprecated module providing the Colormap object
+"""
+
+__authors__ = ["T. Vincent", "H.Payno"]
+__license__ = "MIT"
+__date__ = "27/11/2020"
+
+import silx.utils.deprecation
+
+silx.utils.deprecation.deprecated_warning("Module",
+ name="silx.gui.plot.Colormap",
+ reason="moved",
+ replacement="silx.gui.colors.Colormap",
+ since_version="0.8.0",
+ only_once=True,
+ skip_backtrace_count=1)
+
+from ..colors import * # noqa
diff --git a/src/silx/gui/plot/ColormapDialog.py b/src/silx/gui/plot/ColormapDialog.py
new file mode 100644
index 0000000..7c66cb8
--- /dev/null
+++ b/src/silx/gui/plot/ColormapDialog.py
@@ -0,0 +1,43 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Deprecated module providing ColormapDialog."""
+
+from __future__ import absolute_import
+
+__authors__ = ["T. Vincent", "H.Payno"]
+__license__ = "MIT"
+__date__ = "24/04/2018"
+
+import silx.utils.deprecation
+
+silx.utils.deprecation.deprecated_warning("Module",
+ name="silx.gui.plot.ColormapDialog",
+ reason="moved",
+ replacement="silx.gui.dialog.ColormapDialog",
+ since_version="0.8.0",
+ only_once=True,
+ skip_backtrace_count=1)
+
+from ..dialog.ColormapDialog import * # noqa
diff --git a/src/silx/gui/plot/Colors.py b/src/silx/gui/plot/Colors.py
new file mode 100644
index 0000000..277e104
--- /dev/null
+++ b/src/silx/gui/plot/Colors.py
@@ -0,0 +1,90 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Color conversion function, color dictionary and colormap tools."""
+
+from __future__ import absolute_import
+
+__authors__ = ["V.A. Sole", "T. Vincent"]
+__license__ = "MIT"
+__date__ = "14/06/2018"
+
+import silx.utils.deprecation
+
+silx.utils.deprecation.deprecated_warning("Module",
+ name="silx.gui.plot.Colors",
+ reason="moved",
+ replacement="silx.gui.colors",
+ since_version="0.8.0",
+ only_once=True,
+ skip_backtrace_count=1)
+
+from ..colors import * # noqa
+
+
+@silx.utils.deprecation.deprecated(replacement='silx.gui.colors.Colormap.applyColormap')
+def applyColormapToData(data,
+ name='gray',
+ normalization='linear',
+ autoscale=True,
+ vmin=0.,
+ vmax=1.,
+ colors=None):
+ """Apply a colormap to the data and returns the RGBA image
+
+ This supports data of any dimensions (not only of dimension 2).
+ The returned array will have one more dimension (with 4 entries)
+ than the input data to store the RGBA channels
+ corresponding to each bin in the array.
+
+ :param numpy.ndarray data: The data to convert.
+ :param str name: Name of the colormap (default: 'gray').
+ :param str normalization: Colormap mapping: 'linear' or 'log'.
+ :param bool autoscale: Whether to use data min/max (True, default)
+ or [vmin, vmax] range (False).
+ :param float vmin: The minimum value of the range to use if
+ 'autoscale' is False.
+ :param float vmax: The maximum value of the range to use if
+ 'autoscale' is False.
+ :param numpy.ndarray colors: Only used if name is None.
+ Custom colormap colors as Nx3 or Nx4 RGB or RGBA arrays
+ :return: The computed RGBA image
+ :rtype: numpy.ndarray of uint8
+ """
+ colormap = Colormap(name=name,
+ normalization=normalization,
+ vmin=vmin,
+ vmax=vmax,
+ colors=colors)
+ return colormap.applyToData(data)
+
+
+@silx.utils.deprecation.deprecated(replacement='silx.gui.colors.Colormap.getSupportedColormaps')
+def getSupportedColormaps():
+ """Get the supported colormap names as a tuple of str.
+
+ The list should at least contain and start by:
+ ('gray', 'reversed gray', 'temperature', 'red', 'green', 'blue')
+ """
+ return Colormap.getSupportedColormaps()
diff --git a/src/silx/gui/plot/CompareImages.py b/src/silx/gui/plot/CompareImages.py
new file mode 100644
index 0000000..857fc79
--- /dev/null
+++ b/src/silx/gui/plot/CompareImages.py
@@ -0,0 +1,1259 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""A widget dedicated to compare 2 images.
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "23/07/2018"
+
+
+import enum
+import logging
+import numpy
+import weakref
+import collections
+import math
+
+import silx.image.bilinear
+from silx.gui import qt
+from silx.gui import plot
+from silx.gui import icons
+from silx.gui.colors import Colormap
+from silx.gui.plot import tools
+from silx.utils.weakref import WeakMethodProxy
+
+_logger = logging.getLogger(__name__)
+
+from silx.opencl import ocl
+if ocl is not None:
+ try:
+ from silx.opencl import sift
+ except ImportError:
+ # sift module is not available (e.g., in official Debian packages)
+ sift = None
+else: # No OpenCL device or no pyopencl
+ sift = None
+
+
+@enum.unique
+class VisualizationMode(enum.Enum):
+ """Enum for each visualization mode available."""
+ ONLY_A = 'a'
+ ONLY_B = 'b'
+ VERTICAL_LINE = 'vline'
+ HORIZONTAL_LINE = 'hline'
+ COMPOSITE_RED_BLUE_GRAY = "rbgchannel"
+ COMPOSITE_RED_BLUE_GRAY_NEG = "rbgnegchannel"
+ COMPOSITE_A_MINUS_B = "aminusb"
+
+
+@enum.unique
+class AlignmentMode(enum.Enum):
+ """Enum for each alignment mode available."""
+ ORIGIN = 'origin'
+ CENTER = 'center'
+ STRETCH = 'stretch'
+ AUTO = 'auto'
+
+
+AffineTransformation = collections.namedtuple("AffineTransformation",
+ ["tx", "ty", "sx", "sy", "rot"])
+"""Contains a 2D affine transformation: translation, scale and rotation"""
+
+
+class CompareImagesToolBar(qt.QToolBar):
+ """ToolBar containing specific tools to custom the configuration of a
+ :class:`CompareImages` widget
+
+ Use :meth:`setCompareWidget` to connect this toolbar to a specific
+ :class:`CompareImages` widget.
+
+ :param Union[qt.QWidget,None] parent: Parent of this widget.
+ """
+ def __init__(self, parent=None):
+ qt.QToolBar.__init__(self, parent)
+
+ self.__compareWidget = None
+
+ menu = qt.QMenu(self)
+ self.__visualizationToolButton = qt.QToolButton(self)
+ self.__visualizationToolButton.setMenu(menu)
+ self.__visualizationToolButton.setPopupMode(qt.QToolButton.InstantPopup)
+ self.addWidget(self.__visualizationToolButton)
+ self.__visualizationGroup = qt.QActionGroup(self)
+ self.__visualizationGroup.setExclusive(True)
+ self.__visualizationGroup.triggered.connect(self.__visualizationModeChanged)
+
+ icon = icons.getQIcon("compare-mode-a")
+ action = qt.QAction(icon, "Display the first image only", self)
+ action.setIconVisibleInMenu(True)
+ action.setCheckable(True)
+ action.setShortcut(qt.QKeySequence(qt.Qt.Key_A))
+ action.setProperty("mode", VisualizationMode.ONLY_A)
+ menu.addAction(action)
+ self.__aModeAction = action
+ self.__visualizationGroup.addAction(action)
+
+ icon = icons.getQIcon("compare-mode-b")
+ action = qt.QAction(icon, "Display the second image only", self)
+ action.setIconVisibleInMenu(True)
+ action.setCheckable(True)
+ action.setShortcut(qt.QKeySequence(qt.Qt.Key_B))
+ action.setProperty("mode", VisualizationMode.ONLY_B)
+ menu.addAction(action)
+ self.__bModeAction = action
+ self.__visualizationGroup.addAction(action)
+
+ icon = icons.getQIcon("compare-mode-vline")
+ action = qt.QAction(icon, "Vertical compare mode", self)
+ action.setIconVisibleInMenu(True)
+ action.setCheckable(True)
+ action.setShortcut(qt.QKeySequence(qt.Qt.Key_V))
+ action.setProperty("mode", VisualizationMode.VERTICAL_LINE)
+ menu.addAction(action)
+ self.__vlineModeAction = action
+ self.__visualizationGroup.addAction(action)
+
+ icon = icons.getQIcon("compare-mode-hline")
+ action = qt.QAction(icon, "Horizontal compare mode", self)
+ action.setIconVisibleInMenu(True)
+ action.setCheckable(True)
+ action.setShortcut(qt.QKeySequence(qt.Qt.Key_H))
+ action.setProperty("mode", VisualizationMode.HORIZONTAL_LINE)
+ menu.addAction(action)
+ self.__hlineModeAction = action
+ self.__visualizationGroup.addAction(action)
+
+ icon = icons.getQIcon("compare-mode-rb-channel")
+ action = qt.QAction(icon, "Blue/red compare mode (additive mode)", self)
+ action.setIconVisibleInMenu(True)
+ action.setCheckable(True)
+ action.setShortcut(qt.QKeySequence(qt.Qt.Key_C))
+ action.setProperty("mode", VisualizationMode.COMPOSITE_RED_BLUE_GRAY)
+ menu.addAction(action)
+ self.__brChannelModeAction = action
+ self.__visualizationGroup.addAction(action)
+
+ icon = icons.getQIcon("compare-mode-rbneg-channel")
+ action = qt.QAction(icon, "Yellow/cyan compare mode (subtractive mode)", self)
+ action.setIconVisibleInMenu(True)
+ action.setCheckable(True)
+ action.setShortcut(qt.QKeySequence(qt.Qt.Key_W))
+ action.setProperty("mode", VisualizationMode.COMPOSITE_RED_BLUE_GRAY_NEG)
+ menu.addAction(action)
+ self.__ycChannelModeAction = action
+ self.__visualizationGroup.addAction(action)
+
+ icon = icons.getQIcon("compare-mode-a-minus-b")
+ action = qt.QAction(icon, "Raw A minus B compare mode", self)
+ action.setIconVisibleInMenu(True)
+ action.setCheckable(True)
+ action.setShortcut(qt.QKeySequence(qt.Qt.Key_W))
+ action.setProperty("mode", VisualizationMode.COMPOSITE_A_MINUS_B)
+ menu.addAction(action)
+ self.__ycChannelModeAction = action
+ self.__visualizationGroup.addAction(action)
+
+ menu = qt.QMenu(self)
+ self.__alignmentToolButton = qt.QToolButton(self)
+ self.__alignmentToolButton.setMenu(menu)
+ self.__alignmentToolButton.setPopupMode(qt.QToolButton.InstantPopup)
+ self.addWidget(self.__alignmentToolButton)
+ self.__alignmentGroup = qt.QActionGroup(self)
+ self.__alignmentGroup.setExclusive(True)
+ self.__alignmentGroup.triggered.connect(self.__alignmentModeChanged)
+
+ icon = icons.getQIcon("compare-align-origin")
+ action = qt.QAction(icon, "Align images on their upper-left pixel", self)
+ action.setProperty("mode", AlignmentMode.ORIGIN)
+ action.setIconVisibleInMenu(True)
+ action.setCheckable(True)
+ self.__originAlignAction = action
+ menu.addAction(action)
+ self.__alignmentGroup.addAction(action)
+
+ icon = icons.getQIcon("compare-align-center")
+ action = qt.QAction(icon, "Center images", self)
+ action.setProperty("mode", AlignmentMode.CENTER)
+ action.setIconVisibleInMenu(True)
+ action.setCheckable(True)
+ self.__centerAlignAction = action
+ menu.addAction(action)
+ self.__alignmentGroup.addAction(action)
+
+ icon = icons.getQIcon("compare-align-stretch")
+ action = qt.QAction(icon, "Stretch the second image on the first one", self)
+ action.setProperty("mode", AlignmentMode.STRETCH)
+ action.setIconVisibleInMenu(True)
+ action.setCheckable(True)
+ self.__stretchAlignAction = action
+ menu.addAction(action)
+ self.__alignmentGroup.addAction(action)
+
+ icon = icons.getQIcon("compare-align-auto")
+ action = qt.QAction(icon, "Auto-alignment of the second image", self)
+ action.setProperty("mode", AlignmentMode.AUTO)
+ action.setIconVisibleInMenu(True)
+ action.setCheckable(True)
+ self.__autoAlignAction = action
+ menu.addAction(action)
+ if sift is None:
+ action.setEnabled(False)
+ action.setToolTip("Sift module is not available")
+ self.__alignmentGroup.addAction(action)
+
+ icon = icons.getQIcon("compare-keypoints")
+ action = qt.QAction(icon, "Display/hide alignment keypoints", self)
+ action.setCheckable(True)
+ action.triggered.connect(self.__keypointVisibilityChanged)
+ self.addAction(action)
+ self.__displayKeypoints = action
+
+ def setCompareWidget(self, widget):
+ """
+ Connect this tool bar to a specific :class:`CompareImages` widget.
+
+ :param Union[None,CompareImages] widget: The widget to connect with.
+ """
+ compareWidget = self.getCompareWidget()
+ if compareWidget is not None:
+ compareWidget.sigConfigurationChanged.disconnect(self.__updateSelectedActions)
+ compareWidget = widget
+ if compareWidget is None:
+ self.__compareWidget = None
+ else:
+ self.__compareWidget = weakref.ref(compareWidget)
+ if compareWidget is not None:
+ widget.sigConfigurationChanged.connect(self.__updateSelectedActions)
+ self.__updateSelectedActions()
+
+ def getCompareWidget(self):
+ """Returns the connected widget.
+
+ :rtype: CompareImages
+ """
+ if self.__compareWidget is None:
+ return None
+ else:
+ return self.__compareWidget()
+
+ def __updateSelectedActions(self):
+ """
+ Update the state of this tool bar according to the state of the
+ connected :class:`CompareImages` widget.
+ """
+ widget = self.getCompareWidget()
+ if widget is None:
+ return
+
+ mode = widget.getVisualizationMode()
+ action = None
+ for a in self.__visualizationGroup.actions():
+ actionMode = a.property("mode")
+ if mode == actionMode:
+ action = a
+ break
+ old = self.__visualizationGroup.blockSignals(True)
+ if action is not None:
+ # Check this action
+ action.setChecked(True)
+ else:
+ action = self.__visualizationGroup.checkedAction()
+ if action is not None:
+ # Uncheck this action
+ action.setChecked(False)
+ self.__updateVisualizationMenu()
+ self.__visualizationGroup.blockSignals(old)
+
+ mode = widget.getAlignmentMode()
+ action = None
+ for a in self.__alignmentGroup.actions():
+ actionMode = a.property("mode")
+ if mode == actionMode:
+ action = a
+ break
+ old = self.__alignmentGroup.blockSignals(True)
+ if action is not None:
+ # Check this action
+ action.setChecked(True)
+ else:
+ action = self.__alignmentGroup.checkedAction()
+ if action is not None:
+ # Uncheck this action
+ action.setChecked(False)
+ self.__updateAlignmentMenu()
+ self.__alignmentGroup.blockSignals(old)
+
+ def __visualizationModeChanged(self, selectedAction):
+ """Called when user requesting changes of the visualization mode.
+ """
+ self.__updateVisualizationMenu()
+ widget = self.getCompareWidget()
+ if widget is not None:
+ mode = selectedAction.property("mode")
+ widget.setVisualizationMode(mode)
+
+ def __updateVisualizationMenu(self):
+ """Update the state of the action containing visualization menu.
+ """
+ selectedAction = self.__visualizationGroup.checkedAction()
+ if selectedAction is not None:
+ self.__visualizationToolButton.setText(selectedAction.text())
+ self.__visualizationToolButton.setIcon(selectedAction.icon())
+ self.__visualizationToolButton.setToolTip(selectedAction.toolTip())
+ else:
+ self.__visualizationToolButton.setText("")
+ self.__visualizationToolButton.setIcon(qt.QIcon())
+ self.__visualizationToolButton.setToolTip("")
+
+ def __alignmentModeChanged(self, selectedAction):
+ """Called when user requesting changes of the alignment mode.
+ """
+ self.__updateAlignmentMenu()
+ widget = self.getCompareWidget()
+ if widget is not None:
+ mode = selectedAction.property("mode")
+ widget.setAlignmentMode(mode)
+
+ def __updateAlignmentMenu(self):
+ """Update the state of the action containing alignment menu.
+ """
+ selectedAction = self.__alignmentGroup.checkedAction()
+ if selectedAction is not None:
+ self.__alignmentToolButton.setText(selectedAction.text())
+ self.__alignmentToolButton.setIcon(selectedAction.icon())
+ self.__alignmentToolButton.setToolTip(selectedAction.toolTip())
+ else:
+ self.__alignmentToolButton.setText("")
+ self.__alignmentToolButton.setIcon(qt.QIcon())
+ self.__alignmentToolButton.setToolTip("")
+
+ def __keypointVisibilityChanged(self):
+ """Called when action managing keypoints visibility changes"""
+ widget = self.getCompareWidget()
+ if widget is not None:
+ keypointsVisible = self.__displayKeypoints.isChecked()
+ widget.setKeypointsVisible(keypointsVisible)
+
+
+class CompareImagesStatusBar(qt.QStatusBar):
+ """StatusBar containing specific information contained in a
+ :class:`CompareImages` widget
+
+ Use :meth:`setCompareWidget` to connect this toolbar to a specific
+ :class:`CompareImages` widget.
+
+ :param Union[qt.QWidget,None] parent: Parent of this widget.
+ """
+ def __init__(self, parent=None):
+ qt.QStatusBar.__init__(self, parent)
+ self.setSizeGripEnabled(False)
+ self.layout().setSpacing(0)
+ self.__compareWidget = None
+ self._label1 = qt.QLabel(self)
+ self._label1.setFrameShape(qt.QFrame.WinPanel)
+ self._label1.setFrameShadow(qt.QFrame.Sunken)
+ self._label2 = qt.QLabel(self)
+ self._label2.setFrameShape(qt.QFrame.WinPanel)
+ self._label2.setFrameShadow(qt.QFrame.Sunken)
+ self._transform = qt.QLabel(self)
+ self._transform.setFrameShape(qt.QFrame.WinPanel)
+ self._transform.setFrameShadow(qt.QFrame.Sunken)
+ self.addWidget(self._label1)
+ self.addWidget(self._label2)
+ self.addWidget(self._transform)
+ self._pos = None
+ self._updateStatusBar()
+
+ def setCompareWidget(self, widget):
+ """
+ Connect this tool bar to a specific :class:`CompareImages` widget.
+
+ :param Union[None,CompareImages] widget: The widget to connect with.
+ """
+ compareWidget = self.getCompareWidget()
+ if compareWidget is not None:
+ compareWidget.getPlot().sigPlotSignal.disconnect(self.__plotSignalReceived)
+ compareWidget.sigConfigurationChanged.disconnect(self.__dataChanged)
+ compareWidget = widget
+ if compareWidget is None:
+ self.__compareWidget = None
+ else:
+ self.__compareWidget = weakref.ref(compareWidget)
+ if compareWidget is not None:
+ compareWidget.getPlot().sigPlotSignal.connect(self.__plotSignalReceived)
+ compareWidget.sigConfigurationChanged.connect(self.__dataChanged)
+
+ def getCompareWidget(self):
+ """Returns the connected widget.
+
+ :rtype: CompareImages
+ """
+ if self.__compareWidget is None:
+ return None
+ else:
+ return self.__compareWidget()
+
+ def __plotSignalReceived(self, event):
+ """Called when old style signals at emmited from the plot."""
+ if event["event"] == "mouseMoved":
+ x, y = event["x"], event["y"]
+ self.__mouseMoved(x, y)
+
+ def __mouseMoved(self, x, y):
+ """Called when mouse move over the plot."""
+ self._pos = x, y
+ self._updateStatusBar()
+
+ def __dataChanged(self):
+ """Called when internal data from the connected widget changes."""
+ self._updateStatusBar()
+
+ def _formatData(self, data):
+ """Format pixel of an image.
+
+ It supports intensity, RGB, and RGBA.
+
+ :param Union[int,float,numpy.ndarray,str]: Value of a pixel
+ :rtype: str
+ """
+ if data is None:
+ return "No data"
+ if isinstance(data, (int, numpy.integer)):
+ return "%d" % data
+ if isinstance(data, (float, numpy.floating)):
+ return "%f" % data
+ if isinstance(data, numpy.ndarray):
+ # RGBA value
+ if data.shape == (3,):
+ return "R:%d G:%d B:%d" % (data[0], data[1], data[2])
+ elif data.shape == (4,):
+ return "R:%d G:%d B:%d A:%d" % (data[0], data[1], data[2], data[3])
+ _logger.debug("Unsupported data format %s. Cast it to string.", type(data))
+ return str(data)
+
+ def _updateStatusBar(self):
+ """Update the content of the status bar"""
+ widget = self.getCompareWidget()
+ if widget is None:
+ self._label1.setText("Image1: NA")
+ self._label2.setText("Image2: NA")
+ self._transform.setVisible(False)
+ else:
+ transform = widget.getTransformation()
+ self._transform.setVisible(transform is not None)
+ if transform is not None:
+ has_notable_translation = not numpy.isclose(transform.tx, 0.0, atol=0.01) \
+ or not numpy.isclose(transform.ty, 0.0, atol=0.01)
+ has_notable_scale = not numpy.isclose(transform.sx, 1.0, atol=0.01) \
+ or not numpy.isclose(transform.sy, 1.0, atol=0.01)
+ has_notable_rotation = not numpy.isclose(transform.rot, 0.0, atol=0.01)
+
+ strings = []
+ if has_notable_translation:
+ strings.append("Translation")
+ if has_notable_scale:
+ strings.append("Scale")
+ if has_notable_rotation:
+ strings.append("Rotation")
+ if strings == []:
+ has_translation = not numpy.isclose(transform.tx, 0.0) \
+ or not numpy.isclose(transform.ty, 0.0)
+ has_scale = not numpy.isclose(transform.sx, 1.0) \
+ or not numpy.isclose(transform.sy, 1.0)
+ has_rotation = not numpy.isclose(transform.rot, 0.0)
+ if has_translation or has_scale or has_rotation:
+ text = "No big changes"
+ else:
+ text = "No changes"
+ else:
+ text = "+".join(strings)
+ self._transform.setText("Align: " + text)
+
+ strings = []
+ if not numpy.isclose(transform.ty, 0.0):
+ strings.append("Translation x: %0.3fpx" % transform.tx)
+ if not numpy.isclose(transform.ty, 0.0):
+ strings.append("Translation y: %0.3fpx" % transform.ty)
+ if not numpy.isclose(transform.sx, 1.0):
+ strings.append("Scale x: %0.3f" % transform.sx)
+ if not numpy.isclose(transform.sy, 1.0):
+ strings.append("Scale y: %0.3f" % transform.sy)
+ if not numpy.isclose(transform.rot, 0.0):
+ strings.append("Rotation: %0.3fdeg" % (transform.rot * 180 / numpy.pi))
+ if strings == []:
+ text = "No transformation"
+ else:
+ text = "\n".join(strings)
+ self._transform.setToolTip(text)
+
+ if self._pos is None:
+ self._label1.setText("Image1: NA")
+ self._label2.setText("Image2: NA")
+ else:
+ data1, data2 = widget.getRawPixelData(self._pos[0], self._pos[1])
+ if isinstance(data1, str):
+ self._label1.setToolTip(data1)
+ text1 = "NA"
+ else:
+ self._label1.setToolTip("")
+ text1 = self._formatData(data1)
+ if isinstance(data2, str):
+ self._label2.setToolTip(data2)
+ text2 = "NA"
+ else:
+ self._label2.setToolTip("")
+ text2 = self._formatData(data2)
+ self._label1.setText("Image1: %s" % text1)
+ self._label2.setText("Image2: %s" % text2)
+
+
+class CompareImages(qt.QMainWindow):
+ """Widget providing tools to compare 2 images.
+
+ .. image:: img/CompareImages.png
+
+ :param Union[qt.QWidget,None] parent: Parent of this widget.
+ :param backend: The backend to use, in:
+ 'matplotlib' (default), 'mpl', 'opengl', 'gl', 'none'
+ or a :class:`BackendBase.BackendBase` class
+ :type backend: str or :class:`BackendBase.BackendBase`
+ """
+
+ VisualizationMode = VisualizationMode
+ """Available visualization modes"""
+
+ AlignmentMode = AlignmentMode
+ """Available alignment modes"""
+
+ sigConfigurationChanged = qt.Signal()
+ """Emitted when the configuration of the widget (visualization mode,
+ alignement mode...) have changed."""
+
+ def __init__(self, parent=None, backend=None):
+ qt.QMainWindow.__init__(self, parent)
+ self._resetZoomActive = True
+ self._colormap = Colormap()
+ """Colormap shared by all modes, except the compose images (rgb image)"""
+ self._colormapKeyPoints = Colormap('spring')
+ """Colormap used for sift keypoints"""
+
+ if parent is None:
+ self.setWindowTitle('Compare images')
+ else:
+ self.setWindowFlags(qt.Qt.Widget)
+
+ self.__transformation = None
+ self.__raw1 = None
+ self.__raw2 = None
+ self.__data1 = None
+ self.__data2 = None
+ self.__previousSeparatorPosition = None
+
+ self.__plot = plot.PlotWidget(parent=self, backend=backend)
+ self.__plot.setDefaultColormap(self._colormap)
+ self.__plot.getXAxis().setLabel('Columns')
+ self.__plot.getYAxis().setLabel('Rows')
+ if silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION == 'downward':
+ self.__plot.getYAxis().setInverted(True)
+
+ self.__plot.setKeepDataAspectRatio(True)
+ self.__plot.sigPlotSignal.connect(self.__plotSlot)
+ self.__plot.setAxesDisplayed(False)
+
+ self.setCentralWidget(self.__plot)
+
+ legend = VisualizationMode.VERTICAL_LINE.name
+ self.__plot.addXMarker(
+ 0,
+ legend=legend,
+ text='',
+ draggable=True,
+ color='blue',
+ constraint=WeakMethodProxy(self.__separatorConstraint))
+ self.__vline = self.__plot._getMarker(legend)
+
+ legend = VisualizationMode.HORIZONTAL_LINE.name
+ self.__plot.addYMarker(
+ 0,
+ legend=legend,
+ text='',
+ draggable=True,
+ color='blue',
+ constraint=WeakMethodProxy(self.__separatorConstraint))
+ self.__hline = self.__plot._getMarker(legend)
+
+ # default values
+ self.__visualizationMode = ""
+ self.__alignmentMode = ""
+ self.__keypointsVisible = True
+
+ self.setAlignmentMode(AlignmentMode.ORIGIN)
+ self.setVisualizationMode(VisualizationMode.VERTICAL_LINE)
+ self.setKeypointsVisible(False)
+
+ # Toolbars
+
+ self._createToolBars(self.__plot)
+ if self._interactiveModeToolBar is not None:
+ self.addToolBar(self._interactiveModeToolBar)
+ if self._imageToolBar is not None:
+ self.addToolBar(self._imageToolBar)
+ if self._compareToolBar is not None:
+ self.addToolBar(self._compareToolBar)
+
+ # Statusbar
+
+ self._createStatusBar(self.__plot)
+ if self._statusBar is not None:
+ self.setStatusBar(self._statusBar)
+
+ def _createStatusBar(self, plot):
+ self._statusBar = CompareImagesStatusBar(self)
+ self._statusBar.setCompareWidget(self)
+
+ def _createToolBars(self, plot):
+ """Create tool bars displayed by the widget"""
+ toolBar = tools.InteractiveModeToolBar(parent=self, plot=plot)
+ self._interactiveModeToolBar = toolBar
+ toolBar = tools.ImageToolBar(parent=self, plot=plot)
+ self._imageToolBar = toolBar
+ toolBar = CompareImagesToolBar(self)
+ toolBar.setCompareWidget(self)
+ self._compareToolBar = toolBar
+
+ def getPlot(self):
+ """Returns the plot which is used to display the images.
+
+ :rtype: silx.gui.plot.PlotWidget
+ """
+ return self.__plot
+
+ def getColormap(self):
+ """
+
+ :return: colormap used for compare image
+ :rtype: silx.gui.colors.Colormap
+ """
+ return self._colormap
+
+ def getRawPixelData(self, x, y):
+ """Return the raw pixel of each image data from axes positions.
+
+ If the coordinate is outside of the image it returns None element in
+ the tuple.
+
+ The pixel is reach from the raw data image without filter or
+ transformation. But the coordinate x and y are in the reference of the
+ current displayed mode.
+
+ :param float x: X-coordinate of the pixel in the current displayed plot
+ :param float y: Y-coordinate of the pixel in the current displayed plot
+ :return: A tuple of for each images containing pixel information. It
+ could be a scalar value or an array in case of RGB/RGBA informations.
+ It also could be a string containing information is some cases.
+ :rtype: Tuple(Union[int,float,numpy.ndarray,str],Union[int,float,numpy.ndarray,str])
+ """
+ data2 = None
+ alignmentMode = self.__alignmentMode
+ raw1, raw2 = self.__raw1, self.__raw2
+ if alignmentMode == AlignmentMode.ORIGIN:
+ x1 = x
+ y1 = y
+ x2 = x
+ y2 = y
+ elif alignmentMode == AlignmentMode.CENTER:
+ yy = max(raw1.shape[0], raw2.shape[0])
+ xx = max(raw1.shape[1], raw2.shape[1])
+ x1 = x - (xx - raw1.shape[1]) * 0.5
+ x2 = x - (xx - raw2.shape[1]) * 0.5
+ y1 = y - (yy - raw1.shape[0]) * 0.5
+ y2 = y - (yy - raw2.shape[0]) * 0.5
+ elif alignmentMode == AlignmentMode.STRETCH:
+ x1 = x
+ y1 = y
+ x2 = x * raw2.shape[1] / raw1.shape[1]
+ y2 = x * raw2.shape[1] / raw1.shape[1]
+ elif alignmentMode == AlignmentMode.AUTO:
+ x1 = x
+ y1 = y
+ # Not implemented
+ data2 = "Not implemented with sift"
+ else:
+ assert(False)
+
+ x1, y1 = int(x1), int(y1)
+ if raw1 is None or y1 < 0 or y1 >= raw1.shape[0] or x1 < 0 or x1 >= raw1.shape[1]:
+ data1 = None
+ else:
+ data1 = raw1[y1, x1]
+
+ if data2 is None:
+ x2, y2 = int(x2), int(y2)
+ if raw2 is None or y2 < 0 or y2 >= raw2.shape[0] or x2 < 0 or x2 >= raw2.shape[1]:
+ data2 = None
+ else:
+ data2 = raw2[y2, x2]
+
+ return data1, data2
+
+ def setVisualizationMode(self, mode):
+ """Set the visualization mode.
+
+ :param str mode: New visualization to display the image comparison
+ """
+ if self.__visualizationMode == mode:
+ return
+ previousMode = self.getVisualizationMode()
+ self.__visualizationMode = mode
+ mode = self.getVisualizationMode()
+ self.__vline.setVisible(mode == VisualizationMode.VERTICAL_LINE)
+ self.__hline.setVisible(mode == VisualizationMode.HORIZONTAL_LINE)
+ visModeRawDisplay = (VisualizationMode.ONLY_A,
+ VisualizationMode.ONLY_B,
+ VisualizationMode.VERTICAL_LINE,
+ VisualizationMode.HORIZONTAL_LINE)
+ updateColormap = not(previousMode in visModeRawDisplay and
+ mode in visModeRawDisplay)
+ self.__updateData(updateColormap=updateColormap)
+ self.sigConfigurationChanged.emit()
+
+ def getVisualizationMode(self):
+ """Returns the current interaction mode."""
+ return self.__visualizationMode
+
+ def setAlignmentMode(self, mode):
+ """Set the alignment mode.
+
+ :param str mode: New alignement to apply to images
+ """
+ if self.__alignmentMode == mode:
+ return
+ self.__alignmentMode = mode
+ self.__updateData(updateColormap=False)
+ self.sigConfigurationChanged.emit()
+
+ def getAlignmentMode(self):
+ """Returns the current selected alignemnt mode."""
+ return self.__alignmentMode
+
+ def setKeypointsVisible(self, isVisible):
+ """Set keypoints visibility.
+
+ :param bool isVisible: If True, keypoints are displayed (if some)
+ """
+ if self.__keypointsVisible == isVisible:
+ return
+ self.__keypointsVisible = isVisible
+ self.__updateKeyPoints()
+ self.sigConfigurationChanged.emit()
+
+ def __setDefaultAlignmentMode(self):
+ """Reset the alignemnt mode to the default value"""
+ self.setAlignmentMode(AlignmentMode.ORIGIN)
+
+ def __plotSlot(self, event):
+ """Handle events from the plot"""
+ if event['event'] in ('markerMoving', 'markerMoved'):
+ mode = self.getVisualizationMode()
+ legend = mode.name
+ if event['label'] == legend:
+ if mode == VisualizationMode.VERTICAL_LINE:
+ value = int(float(str(event['xdata'])))
+ elif mode == VisualizationMode.HORIZONTAL_LINE:
+ value = int(float(str(event['ydata'])))
+ else:
+ assert(False)
+ if self.__previousSeparatorPosition != value:
+ self.__separatorMoved(value)
+ self.__previousSeparatorPosition = value
+
+ def __separatorConstraint(self, x, y):
+ """Manage contains on the separators to clamp them inside the images."""
+ if self.__data1 is None:
+ return 0, 0
+ x = int(x)
+ if x < 0:
+ x = 0
+ elif x > self.__data1.shape[1]:
+ x = self.__data1.shape[1]
+ y = int(y)
+ if y < 0:
+ y = 0
+ elif y > self.__data1.shape[0]:
+ y = self.__data1.shape[0]
+ return x, y
+
+ def __updateSeparators(self):
+ """Redraw images according to the current state of the separators.
+ """
+ mode = self.getVisualizationMode()
+ if mode == VisualizationMode.VERTICAL_LINE:
+ pos = self.__vline.getXPosition()
+ self.__separatorMoved(pos)
+ self.__previousSeparatorPosition = pos
+ elif mode == VisualizationMode.HORIZONTAL_LINE:
+ pos = self.__hline.getYPosition()
+ self.__separatorMoved(pos)
+ self.__previousSeparatorPosition = pos
+ else:
+ self.__image1.setOrigin((0, 0))
+ self.__image2.setOrigin((0, 0))
+
+ def __separatorMoved(self, pos):
+ """Called when vertical or horizontal separators have moved.
+
+ Update the displayed images.
+ """
+ if self.__data1 is None:
+ return
+
+ mode = self.getVisualizationMode()
+ if mode == VisualizationMode.VERTICAL_LINE:
+ pos = int(pos)
+ if pos <= 0:
+ pos = 0
+ elif pos >= self.__data1.shape[1]:
+ pos = self.__data1.shape[1]
+ data1 = self.__data1[:, 0:pos]
+ data2 = self.__data2[:, pos:]
+ self.__image1.setData(data1, copy=False)
+ self.__image2.setData(data2, copy=False)
+ self.__image2.setOrigin((pos, 0))
+ elif mode == VisualizationMode.HORIZONTAL_LINE:
+ pos = int(pos)
+ if pos <= 0:
+ pos = 0
+ elif pos >= self.__data1.shape[0]:
+ pos = self.__data1.shape[0]
+ data1 = self.__data1[0:pos, :]
+ data2 = self.__data2[pos:, :]
+ self.__image1.setData(data1, copy=False)
+ self.__image2.setData(data2, copy=False)
+ self.__image2.setOrigin((0, pos))
+ else:
+ assert(False)
+
+ def setData(self, image1, image2, updateColormap=True):
+ """Set images to compare.
+
+ Images can contains floating-point or integer values, or RGB and RGBA
+ values, but should have comparable intensities.
+
+ RGB and RGBA images are provided as an array as `[width,height,channels]`
+ of usigned integer 8-bits or floating-points between 0.0 to 1.0.
+
+ :param numpy.ndarray image1: The first image
+ :param numpy.ndarray image2: The second image
+ """
+ self.__raw1 = image1
+ self.__raw2 = image2
+ self.__updateData(updateColormap=updateColormap)
+ if self.isAutoResetZoom():
+ self.__plot.resetZoom()
+
+ def setImage1(self, image1, updateColormap=True):
+ """Set image1 to be compared.
+
+ Images can contains floating-point or integer values, or RGB and RGBA
+ values, but should have comparable intensities.
+
+ RGB and RGBA images are provided as an array as `[width,height,channels]`
+ of usigned integer 8-bits or floating-points between 0.0 to 1.0.
+
+ :param numpy.ndarray image1: The first image
+ """
+ self.__raw1 = image1
+ self.__updateData(updateColormap=updateColormap)
+ if self.isAutoResetZoom():
+ self.__plot.resetZoom()
+
+ def setImage2(self, image2, updateColormap=True):
+ """Set image2 to be compared.
+
+ Images can contains floating-point or integer values, or RGB and RGBA
+ values, but should have comparable intensities.
+
+ RGB and RGBA images are provided as an array as `[width,height,channels]`
+ of usigned integer 8-bits or floating-points between 0.0 to 1.0.
+
+ :param numpy.ndarray image2: The second image
+ """
+ self.__raw2 = image2
+ self.__updateData(updateColormap=updateColormap)
+ if self.isAutoResetZoom():
+ self.__plot.resetZoom()
+
+ def __updateKeyPoints(self):
+ """Update the displayed keypoints using cached keypoints.
+ """
+ if self.__keypointsVisible:
+ data = self.__matching_keypoints
+ else:
+ data = [], [], []
+ self.__plot.addScatter(x=data[0],
+ y=data[1],
+ z=1,
+ value=data[2],
+ colormap=self._colormapKeyPoints,
+ legend="keypoints")
+
+ def __updateData(self, updateColormap):
+ """Compute aligned image when the alignment mode changes.
+
+ This function cache input images which are used when
+ vertical/horizontal separators moves.
+ """
+ raw1, raw2 = self.__raw1, self.__raw2
+ if raw1 is None or raw2 is None:
+ return
+
+ alignmentMode = self.getAlignmentMode()
+ self.__transformation = None
+
+ if alignmentMode == AlignmentMode.ORIGIN:
+ yy = max(raw1.shape[0], raw2.shape[0])
+ xx = max(raw1.shape[1], raw2.shape[1])
+ size = yy, xx
+ data1 = self.__createMarginImage(raw1, size, transparent=True)
+ data2 = self.__createMarginImage(raw2, size, transparent=True)
+ self.__matching_keypoints = [0.0], [0.0], [1.0]
+ elif alignmentMode == AlignmentMode.CENTER:
+ yy = max(raw1.shape[0], raw2.shape[0])
+ xx = max(raw1.shape[1], raw2.shape[1])
+ size = yy, xx
+ data1 = self.__createMarginImage(raw1, size, transparent=True, center=True)
+ data2 = self.__createMarginImage(raw2, size, transparent=True, center=True)
+ self.__matching_keypoints = ([data1.shape[1] // 2],
+ [data1.shape[0] // 2],
+ [1.0])
+ elif alignmentMode == AlignmentMode.STRETCH:
+ data1 = raw1
+ data2 = self.__rescaleImage(raw2, data1.shape)
+ self.__matching_keypoints = ([0, data1.shape[1], data1.shape[1], 0],
+ [0, 0, data1.shape[0], data1.shape[0]],
+ [1.0, 1.0, 1.0, 1.0])
+ elif alignmentMode == AlignmentMode.AUTO:
+ # TODO: sift implementation do not support RGBA images
+ yy = max(raw1.shape[0], raw2.shape[0])
+ xx = max(raw1.shape[1], raw2.shape[1])
+ size = yy, xx
+ data1 = self.__createMarginImage(raw1, size)
+ data2 = self.__createMarginImage(raw2, size)
+ self.__matching_keypoints = [0.0], [0.0], [1.0]
+ try:
+ data1, data2 = self.__createSiftData(data1, data2)
+ if data2 is None:
+ raise ValueError("Unexpected None value")
+ except Exception as e:
+ # TODO: Display it on the GUI
+ _logger.error(e)
+ self.__setDefaultAlignmentMode()
+ return
+ else:
+ assert(False)
+
+ mode = self.getVisualizationMode()
+ if mode == VisualizationMode.COMPOSITE_RED_BLUE_GRAY_NEG:
+ data1 = self.__composeImage(data1, data2, mode)
+ data2 = numpy.empty((0, 0))
+ elif mode == VisualizationMode.COMPOSITE_RED_BLUE_GRAY:
+ data1 = self.__composeImage(data1, data2, mode)
+ data2 = numpy.empty((0, 0))
+ elif mode == VisualizationMode.COMPOSITE_A_MINUS_B:
+ data1 = self.__composeImage(data1, data2, mode)
+ data2 = numpy.empty((0, 0))
+ elif mode == VisualizationMode.ONLY_A:
+ data2 = numpy.empty((0, 0))
+ elif mode == VisualizationMode.ONLY_B:
+ data1 = numpy.empty((0, 0))
+
+ self.__data1, self.__data2 = data1, data2
+ self.__plot.addImage(data1, z=0, legend="image1", resetzoom=False)
+ self.__plot.addImage(data2, z=0, legend="image2", resetzoom=False)
+ self.__image1 = self.__plot.getImage("image1")
+ self.__image2 = self.__plot.getImage("image2")
+ self.__updateKeyPoints()
+
+ # Set the separator into the middle
+ if self.__previousSeparatorPosition is None:
+ value = self.__data1.shape[1] // 2
+ self.__vline.setPosition(value, 0)
+ value = self.__data1.shape[0] // 2
+ self.__hline.setPosition(0, value)
+ self.__updateSeparators()
+ if updateColormap:
+ self.__updateColormap()
+
+ def __updateColormap(self):
+ # TODO: The colormap histogram will still be wrong
+ mode1 = self.__getImageMode(self.__data1)
+ mode2 = self.__getImageMode(self.__data2)
+ if mode1 == "intensity" and mode1 == mode2:
+ if self.__data1.size == 0:
+ vmin = self.__data2.min()
+ vmax = self.__data2.max()
+ elif self.__data2.size == 0:
+ vmin = self.__data1.min()
+ vmax = self.__data1.max()
+ else:
+ vmin = min(self.__data1.min(), self.__data2.min())
+ vmax = max(self.__data1.max(), self.__data2.max())
+ colormap = self.getColormap()
+ colormap.setVRange(vmin=vmin, vmax=vmax)
+ self.__image1.setColormap(colormap)
+ self.__image2.setColormap(colormap)
+
+ def __getImageMode(self, image):
+ """Returns a value identifying the way the image is stored in the
+ array.
+
+ :param numpy.ndarray image: Image to check
+ :rtype: str
+ """
+ if len(image.shape) == 2:
+ return "intensity"
+ elif len(image.shape) == 3:
+ if image.shape[2] == 3:
+ return "rgb"
+ elif image.shape[2] == 4:
+ return "rgba"
+ raise TypeError("'image' argument is not an image.")
+
+ def __rescaleImage(self, image, shape):
+ """Rescale an image to the requested shape.
+
+ :rtype: numpy.ndarray
+ """
+ mode = self.__getImageMode(image)
+ if mode == "intensity":
+ data = self.__rescaleArray(image, shape)
+ elif mode == "rgb":
+ data = numpy.empty((shape[0], shape[1], 3), dtype=image.dtype)
+ for c in range(3):
+ data[:, :, c] = self.__rescaleArray(image[:, :, c], shape)
+ elif mode == "rgba":
+ data = numpy.empty((shape[0], shape[1], 4), dtype=image.dtype)
+ for c in range(4):
+ data[:, :, c] = self.__rescaleArray(image[:, :, c], shape)
+ return data
+
+ def __composeImage(self, data1, data2, mode):
+ """Returns an RBG image containing composition of data1 and data2 in 2
+ different channels
+
+ :param numpy.ndarray data1: First image
+ :param numpy.ndarray data1: Second image
+ :param VisualizationMode mode: Composition mode.
+ :rtype: numpy.ndarray
+ """
+ assert(data1.shape[0:2] == data2.shape[0:2])
+ if mode == VisualizationMode.COMPOSITE_A_MINUS_B:
+ # TODO: this calculation has no interest of generating a 'composed'
+ # rgb image, this could be moved in an other function or doc
+ # should be modified
+ _type = data1.dtype
+ result = data1.astype(numpy.float64) - data2.astype(numpy.float64)
+ return result
+ mode1 = self.__getImageMode(data1)
+ if mode1 in ["rgb", "rgba"]:
+ intensity1 = self.__luminosityImage(data1)
+ vmin1, vmax1 = 0.0, 1.0
+ else:
+ intensity1 = data1
+ vmin1, vmax1 = data1.min(), data1.max()
+
+ mode2 = self.__getImageMode(data2)
+ if mode2 in ["rgb", "rgba"]:
+ intensity2 = self.__luminosityImage(data2)
+ vmin2, vmax2 = 0.0, 1.0
+ else:
+ intensity2 = data2
+ vmin2, vmax2 = data2.min(), data2.max()
+
+ vmin, vmax = min(vmin1, vmin2) * 1.0, max(vmax1, vmax2) * 1.0
+ shape = data1.shape
+ result = numpy.empty((shape[0], shape[1], 3), dtype=numpy.uint8)
+ a = (intensity1 - vmin) * (1.0 / (vmax - vmin)) * 255.0
+ b = (intensity2 - vmin) * (1.0 / (vmax - vmin)) * 255.0
+ if mode == VisualizationMode.COMPOSITE_RED_BLUE_GRAY:
+ result[:, :, 0] = a
+ result[:, :, 1] = (a + b) / 2
+ result[:, :, 2] = b
+ elif mode == VisualizationMode.COMPOSITE_RED_BLUE_GRAY_NEG:
+ result[:, :, 0] = 255 - b
+ result[:, :, 1] = 255 - (a + b) / 2
+ result[:, :, 2] = 255 - a
+ return result
+
+ def __luminosityImage(self, image):
+ """Returns the luminosity channel from an RBG(A) image.
+ The alpha channel is ignored.
+
+ :rtype: numpy.ndarray
+ """
+ mode = self.__getImageMode(image)
+ assert(mode in ["rgb", "rgba"])
+ is_uint8 = image.dtype.type == numpy.uint8
+ # luminosity
+ image = 0.21 * image[..., 0] + 0.72 * image[..., 1] + 0.07 * image[..., 2]
+ if is_uint8:
+ image = image / 255.0
+ return image
+
+ def __rescaleArray(self, image, shape):
+ """Rescale a 2D array to the requested shape.
+
+ :rtype: numpy.ndarray
+ """
+ y, x = numpy.ogrid[:shape[0], :shape[1]]
+ y, x = y * 1.0 * (image.shape[0] - 1) / (shape[0] - 1), x * 1.0 * (image.shape[1] - 1) / (shape[1] - 1)
+ b = silx.image.bilinear.BilinearImage(image)
+ # TODO: could be optimized using strides
+ x2d = numpy.zeros_like(y) + x
+ y2d = numpy.zeros_like(x) + y
+ result = b.map_coordinates((y2d, x2d))
+ return result
+
+ def __createMarginImage(self, image, size, transparent=False, center=False):
+ """Returns a new image with margin to respect the requested size.
+
+ :rtype: numpy.ndarray
+ """
+ assert(image.shape[0] <= size[0])
+ assert(image.shape[1] <= size[1])
+ if image.shape == size:
+ return image
+ mode = self.__getImageMode(image)
+
+ if center:
+ pos0 = size[0] // 2 - image.shape[0] // 2
+ pos1 = size[1] // 2 - image.shape[1] // 2
+ else:
+ pos0, pos1 = 0, 0
+
+ if mode == "intensity":
+ data = numpy.zeros(size, dtype=image.dtype)
+ data[pos0:pos0 + image.shape[0], pos1:pos1 + image.shape[1]] = image
+ # TODO: It is maybe possible to put NaN on the margin
+ else:
+ if transparent:
+ data = numpy.zeros((size[0], size[1], 4), dtype=numpy.uint8)
+ else:
+ data = numpy.zeros((size[0], size[1], 3), dtype=numpy.uint8)
+ depth = min(data.shape[2], image.shape[2])
+ data[pos0:pos0 + image.shape[0], pos1:pos1 + image.shape[1], 0:depth] = image[:, :, 0:depth]
+ if transparent and depth == 3:
+ data[pos0:pos0 + image.shape[0], pos1:pos1 + image.shape[1], 3] = 255
+ return data
+
+ def __toAffineTransformation(self, sift_result):
+ """Returns an affine transformation from the sift result.
+
+ :param dict sift_result: Result of sift when using `all_result=True`
+ :rtype: AffineTransformation
+ """
+ offset = sift_result["offset"]
+ matrix = sift_result["matrix"]
+
+ tx = offset[0]
+ ty = offset[1]
+ a = matrix[0, 0]
+ b = matrix[0, 1]
+ c = matrix[1, 0]
+ d = matrix[1, 1]
+ rot = math.atan2(-b, a)
+ sx = (-1.0 if a < 0 else 1.0) * math.sqrt(a**2 + b**2)
+ sy = (-1.0 if d < 0 else 1.0) * math.sqrt(c**2 + d**2)
+ return AffineTransformation(tx, ty, sx, sy, rot)
+
+ def getTransformation(self):
+ """Retuns the affine transformation applied to the second image to align
+ it to the first image.
+
+ This result is only valid for sift alignment.
+
+ :rtype: Union[None,AffineTransformation]
+ """
+ return self.__transformation
+
+ def __createSiftData(self, image, second_image):
+ """Generate key points and aligned images from 2 images.
+
+ If no keypoints matches, unaligned data are anyway returns.
+
+ :rtype: Tuple(numpy.ndarray,numpy.ndarray)
+ """
+ devicetype = "GPU"
+
+ # Compute base image
+ sift_ocl = sift.SiftPlan(template=image, devicetype=devicetype)
+ keypoints = sift_ocl(image)
+
+ # Check image compatibility
+ second_keypoints = sift_ocl(second_image)
+ mp = sift.MatchPlan()
+ match = mp(keypoints, second_keypoints)
+ _logger.info("Number of Keypoints within image 1: %i" % keypoints.size)
+ _logger.info(" within image 2: %i" % second_keypoints.size)
+
+ self.__matching_keypoints = (match[:].x[:, 0],
+ match[:].y[:, 0],
+ match[:].scale[:, 0])
+ matching_keypoints = match.shape[0]
+ _logger.info("Matching keypoints: %i" % matching_keypoints)
+ if matching_keypoints == 0:
+ return image, second_image
+
+ # TODO: Problem here is we have to compute 2 time sift
+ # The first time to extract matching keypoints, second time
+ # to extract the aligned image.
+
+ # Normalize the second image
+ sa = sift.LinearAlign(image, devicetype=devicetype)
+ data1 = image
+ # TODO: Create a sift issue: if data1 is RGB and data2 intensity
+ # it returns None, while extracting manually keypoints (above) works
+ result = sa.align(second_image, return_all=True)
+ data2 = result["result"]
+ self.__transformation = self.__toAffineTransformation(result)
+ return data1, data2
+
+ def setAutoResetZoom(self, activate=True):
+ """
+
+ :param bool activate: True if we want to activate the automatic
+ plot reset zoom when setting images.
+ """
+ self._resetZoomActive = activate
+
+ def isAutoResetZoom(self):
+ """
+
+ :return: True if the automatic call to resetzoom is activated
+ :rtype: bool
+ """
+ return self._resetZoomActive
diff --git a/src/silx/gui/plot/ComplexImageView.py b/src/silx/gui/plot/ComplexImageView.py
new file mode 100644
index 0000000..4eee3b0
--- /dev/null
+++ b/src/silx/gui/plot/ComplexImageView.py
@@ -0,0 +1,518 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a widget to view 2D complex data.
+
+The :class:`ComplexImageView` widget is dedicated to visualize a single 2D dataset
+of complex data.
+"""
+
+from __future__ import absolute_import
+
+__authors__ = ["Vincent Favre-Nicolin", "T. Vincent"]
+__license__ = "MIT"
+__date__ = "24/04/2018"
+
+
+import logging
+import collections
+import numpy
+
+from ...utils.deprecation import deprecated
+from .. import qt, icons
+from .PlotWindow import Plot2D
+from . import items
+from .items import ImageComplexData
+from silx.gui.widgets.FloatEdit import FloatEdit
+
+_logger = logging.getLogger(__name__)
+
+
+# Widgets
+
+class _AmplitudeRangeDialog(qt.QDialog):
+ """QDialog asking for the amplitude range to display."""
+
+ sigRangeChanged = qt.Signal(tuple)
+ """Signal emitted when the range has changed.
+
+ It provides the new range as a 2-tuple: (max, delta)
+ """
+
+ def __init__(self,
+ parent=None,
+ amplitudeRange=None,
+ displayedRange=(None, 2)):
+ super(_AmplitudeRangeDialog, self).__init__(parent)
+ self.setWindowTitle('Set Displayed Amplitude Range')
+
+ if amplitudeRange is not None:
+ amplitudeRange = min(amplitudeRange), max(amplitudeRange)
+ self._amplitudeRange = amplitudeRange
+ self._defaultDisplayedRange = displayedRange
+
+ layout = qt.QFormLayout()
+ self.setLayout(layout)
+
+ if self._amplitudeRange is not None:
+ min_, max_ = self._amplitudeRange
+ layout.addRow(
+ qt.QLabel('Data Amplitude Range: [%g, %g]' % (min_, max_)))
+
+ self._maxLineEdit = FloatEdit(parent=self)
+ self._maxLineEdit.validator().setBottom(0.)
+ self._maxLineEdit.setAlignment(qt.Qt.AlignRight)
+
+ self._maxLineEdit.editingFinished.connect(self._rangeUpdated)
+ layout.addRow('Displayed Max.:', self._maxLineEdit)
+
+ self._autoscale = qt.QCheckBox('autoscale')
+ self._autoscale.toggled.connect(self._autoscaleCheckBoxToggled)
+ layout.addRow('', self._autoscale)
+
+ self._deltaLineEdit = FloatEdit(parent=self)
+ self._deltaLineEdit.validator().setBottom(1.)
+ self._deltaLineEdit.setAlignment(qt.Qt.AlignRight)
+ self._deltaLineEdit.editingFinished.connect(self._rangeUpdated)
+ layout.addRow('Displayed delta (log10 unit):', self._deltaLineEdit)
+
+ buttons = qt.QDialogButtonBox(self)
+ buttons.addButton(qt.QDialogButtonBox.Ok)
+ buttons.addButton(qt.QDialogButtonBox.Cancel)
+ buttons.accepted.connect(self.accept)
+ buttons.rejected.connect(self.reject)
+ layout.addRow(buttons)
+
+ # Set dialog from default values
+ self._resetDialogToDefault()
+
+ self.rejected.connect(self._handleRejected)
+
+ def _resetDialogToDefault(self):
+ """Set Widgets of the dialog from range information
+ """
+ max_, delta = self._defaultDisplayedRange
+
+ if max_ is not None: # Not in autoscale
+ displayedMax = max_
+ elif self._amplitudeRange is not None: # Autoscale with data
+ displayedMax = self._amplitudeRange[1]
+ else: # Autoscale without data
+ displayedMax = ''
+ if displayedMax == "":
+ self._maxLineEdit.setText("")
+ else:
+ self._maxLineEdit.setValue(displayedMax)
+ self._maxLineEdit.setEnabled(max_ is not None)
+
+ self._deltaLineEdit.setValue(delta)
+
+ self._autoscale.setChecked(self._defaultDisplayedRange[0] is None)
+
+ def getRangeInfo(self):
+ """Returns the current range as a 2-tuple (max, delta (in log10))"""
+ if self._autoscale.isChecked():
+ max_ = None
+ else:
+ maxStr = self._maxLineEdit.text()
+ max_ = self._maxLineEdit.value() if maxStr else None
+ return max_, self._deltaLineEdit.value() if self._deltaLineEdit.text() else 2
+
+ def _handleRejected(self):
+ """Reset range info to default when rejected"""
+ self._resetDialogToDefault()
+ self._rangeUpdated()
+
+ def _rangeUpdated(self):
+ """Handle QLineEdit editing finised"""
+ self.sigRangeChanged.emit(self.getRangeInfo())
+
+ def _autoscaleCheckBoxToggled(self, checked):
+ """Handle autoscale checkbox state changes"""
+ if checked: # Use default values
+ if self._amplitudeRange is None:
+ max_ = ''
+ else:
+ max_ = self._amplitudeRange[1]
+ if max_ == "":
+ self._maxLineEdit.setText("")
+ else:
+ self._maxLineEdit.setValue(max_)
+ self._maxLineEdit.setEnabled(not checked)
+ self._rangeUpdated()
+
+
+class _ComplexDataToolButton(qt.QToolButton):
+ """QToolButton providing choices of complex data visualization modes
+
+ :param parent: See :class:`QToolButton`
+ :param plot: The :class:`ComplexImageView` to control
+ """
+
+ _MODES = collections.OrderedDict([
+ (ImageComplexData.ComplexMode.ABSOLUTE, ('math-amplitude', 'Amplitude')),
+ (ImageComplexData.ComplexMode.SQUARE_AMPLITUDE,
+ ('math-square-amplitude', 'Square amplitude')),
+ (ImageComplexData.ComplexMode.PHASE, ('math-phase', 'Phase')),
+ (ImageComplexData.ComplexMode.REAL, ('math-real', 'Real part')),
+ (ImageComplexData.ComplexMode.IMAGINARY,
+ ('math-imaginary', 'Imaginary part')),
+ (ImageComplexData.ComplexMode.AMPLITUDE_PHASE,
+ ('math-phase-color', 'Amplitude and Phase')),
+ (ImageComplexData.ComplexMode.LOG10_AMPLITUDE_PHASE,
+ ('math-phase-color-log', 'Log10(Amp.) and Phase'))
+ ])
+
+ _RANGE_DIALOG_TEXT = 'Set Amplitude Range...'
+
+ def __init__(self, parent=None, plot=None):
+ super(_ComplexDataToolButton, self).__init__(parent=parent)
+
+ assert plot is not None
+ self._plot2DComplex = plot
+
+ menu = qt.QMenu(self)
+ menu.triggered.connect(self._triggered)
+ self.setMenu(menu)
+
+ for mode, info in self._MODES.items():
+ icon, text = info
+ action = qt.QAction(icons.getQIcon(icon), text, self)
+ action.setData(mode)
+ action.setIconVisibleInMenu(True)
+ menu.addAction(action)
+
+ self._rangeDialogAction = qt.QAction(self)
+ self._rangeDialogAction.setText(self._RANGE_DIALOG_TEXT)
+ menu.addAction(self._rangeDialogAction)
+
+ self.setPopupMode(qt.QToolButton.InstantPopup)
+
+ self._modeChanged(self._plot2DComplex.getComplexMode())
+ self._plot2DComplex.sigVisualizationModeChanged.connect(
+ self._modeChanged)
+
+ def _modeChanged(self, mode):
+ """Handle change of visualization modes"""
+ icon, text = self._MODES[mode]
+ self.setIcon(icons.getQIcon(icon))
+ self.setToolTip('Display the ' + text.lower())
+ self._rangeDialogAction.setEnabled(
+ mode == ImageComplexData.ComplexMode.LOG10_AMPLITUDE_PHASE)
+
+ def _triggered(self, action):
+ """Handle triggering of menu actions"""
+ actionText = action.text()
+
+ if actionText == self._RANGE_DIALOG_TEXT: # Show dialog
+ # Get amplitude range
+ data = self._plot2DComplex.getData(copy=False)
+
+ if data.size > 0:
+ absolute = numpy.absolute(data)
+ dataRange = (numpy.nanmin(absolute), numpy.nanmax(absolute))
+ else:
+ dataRange = None
+
+ # Show dialog
+ dialog = _AmplitudeRangeDialog(
+ parent=self,
+ amplitudeRange=dataRange,
+ displayedRange=self._plot2DComplex._getAmplitudeRangeInfo())
+ dialog.sigRangeChanged.connect(self._rangeChanged)
+ dialog.exec()
+ dialog.sigRangeChanged.disconnect(self._rangeChanged)
+
+ else: # update mode
+ mode = action.data()
+ if isinstance(mode, ImageComplexData.ComplexMode):
+ self._plot2DComplex.setComplexMode(mode)
+
+ def _rangeChanged(self, range_):
+ """Handle updates of range in the dialog"""
+ self._plot2DComplex._setAmplitudeRangeInfo(*range_)
+
+
+class ComplexImageView(qt.QWidget):
+ """Display an image of complex data and allow to choose the visualization.
+
+ :param parent: See :class:`QMainWindow`
+ """
+
+ ComplexMode = ImageComplexData.ComplexMode
+ """Complex Modes enumeration"""
+
+ sigDataChanged = qt.Signal()
+ """Signal emitted when data has changed."""
+
+ sigVisualizationModeChanged = qt.Signal(object)
+ """Signal emitted when the visualization mode has changed.
+
+ It provides the new visualization mode.
+ """
+
+ def __init__(self, parent=None):
+ super(ComplexImageView, self).__init__(parent)
+ if parent is None:
+ self.setWindowTitle('ComplexImageView')
+
+ self._plot2D = Plot2D(self)
+
+ layout = qt.QHBoxLayout(self)
+ layout.setSpacing(0)
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.addWidget(self._plot2D)
+ self.setLayout(layout)
+
+ # Create and add image to the plot
+ self._plotImage = ImageComplexData()
+ self._plotImage.setName('__ComplexImageView__complex_image__')
+ self._plotImage.sigItemChanged.connect(self._itemChanged)
+ self._plot2D.addItem(self._plotImage)
+ self._plot2D.setActiveImage(self._plotImage.getName())
+
+ toolBar = qt.QToolBar('Complex', self)
+ toolBar.addWidget(
+ _ComplexDataToolButton(parent=self, plot=self))
+
+ self._plot2D.insertToolBar(self._plot2D.getProfileToolbar(), toolBar)
+
+ def _itemChanged(self, event):
+ """Handle item changed signal"""
+ if event is items.ItemChangedType.DATA:
+ self.sigDataChanged.emit()
+ elif event is items.ItemChangedType.VISUALIZATION_MODE:
+ mode = self.getComplexMode()
+ self.sigVisualizationModeChanged.emit(mode)
+
+ def getPlot(self):
+ """Return the PlotWidget displaying the data"""
+ return self._plot2D
+
+ def setData(self, data=None, copy=True):
+ """Set the complex data to display.
+
+ :param numpy.ndarray data: 2D complex data
+ :param bool copy: True (default) to copy the data,
+ False to use provided data (do not modify!).
+ """
+ if data is None:
+ data = numpy.zeros((0, 0), dtype=numpy.complex64)
+
+ previousData = self._plotImage.getComplexData(copy=False)
+
+ self._plotImage.setData(data, copy=copy)
+
+ if previousData.shape != data.shape:
+ self.getPlot().resetZoom()
+
+ def getData(self, copy=True):
+ """Get the currently displayed complex data.
+
+ :param bool copy: True (default) to return a copy of the data,
+ False to return internal data (do not modify!).
+ :return: The complex data array.
+ :rtype: numpy.ndarray of complex with 2 dimensions
+ """
+ return self._plotImage.getComplexData(copy=copy)
+
+ def getDisplayedData(self, copy=True):
+ """Returns the displayed data depending on the visualization mode
+
+ WARNING: The returned data can be a uint8 RGBA image
+
+ :param bool copy: True (default) to return a copy of the data,
+ False to return internal data (do not modify!)
+ :rtype: numpy.ndarray of float with 2 dims or RGBA image (uint8).
+ """
+ mode = self.getComplexMode()
+ if mode in (self.ComplexMode.AMPLITUDE_PHASE,
+ self.ComplexMode.LOG10_AMPLITUDE_PHASE):
+ return self._plotImage.getRgbaImageData(copy=copy)
+ else:
+ return self._plotImage.getData(copy=copy)
+
+ # Backward compatibility
+
+ Mode = ComplexMode
+
+ @classmethod
+ @deprecated(replacement='supportedComplexModes', since_version='0.11.0')
+ def getSupportedVisualizationModes(cls):
+ return cls.supportedComplexModes()
+
+ @deprecated(replacement='setComplexMode', since_version='0.11.0')
+ def setVisualizationMode(self, mode):
+ return self.setComplexMode(mode)
+
+ @deprecated(replacement='getComplexMode', since_version='0.11.0')
+ def getVisualizationMode(self):
+ return self.getComplexMode()
+
+ # Image item proxy
+
+ @staticmethod
+ def supportedComplexModes():
+ """Returns the supported visualization modes.
+
+ Supported visualization modes are:
+
+ - amplitude: The absolute value provided by numpy.absolute
+ - phase: The phase (or argument) provided by numpy.angle
+ - real: Real part
+ - imaginary: Imaginary part
+ - amplitude_phase: Color-coded phase with amplitude as alpha.
+ - log10_amplitude_phase:
+ Color-coded phase with log10(amplitude) as alpha.
+
+ :rtype: List[ComplexMode]
+ """
+ return ImageComplexData.supportedComplexModes()
+
+ def setComplexMode(self, mode):
+ """Set the mode of visualization of the complex data.
+
+ See :meth:`supportedComplexModes` for the list of
+ supported modes.
+
+ How-to change visualization mode::
+
+ widget = ComplexImageView()
+ widget.setComplexMode(ComplexImageView.ComplexMode.PHASE)
+ # or
+ widget.setComplexMode('phase')
+
+ :param Unions[ComplexMode,str] mode: The mode to use.
+ """
+ self._plotImage.setComplexMode(mode)
+
+ def getComplexMode(self):
+ """Get the current visualization mode of the complex data.
+
+ :rtype: ComplexMode
+ """
+ return self._plotImage.getComplexMode()
+
+ def _setAmplitudeRangeInfo(self, max_=None, delta=2):
+ """Set the amplitude range to display for 'log10_amplitude_phase' mode.
+
+ :param max_: Max of the amplitude range.
+ If None it autoscales to data max.
+ :param float delta: Delta range in log10 to display
+ """
+ self._plotImage._setAmplitudeRangeInfo(max_, delta)
+
+ def _getAmplitudeRangeInfo(self):
+ """Returns the amplitude range to use for 'log10_amplitude_phase' mode.
+
+ :return: (max, delta), if max is None, then it autoscales to data max
+ :rtype: 2-tuple"""
+ return self._plotImage._getAmplitudeRangeInfo()
+
+ def setColormap(self, colormap, mode=None):
+ """Set the colormap to use for amplitude, phase, real or imaginary.
+
+ WARNING: This colormap is not used when displaying both
+ amplitude and phase.
+
+ :param ~silx.gui.colors.Colormap colormap: The colormap
+ :param ComplexMode mode: If specified, set the colormap of this specific mode
+ """
+ self._plotImage.setColormap(colormap, mode)
+
+ def getColormap(self, mode=None):
+ """Returns the colormap used to display the data.
+
+ :param ComplexMode mode: If specified, set the colormap of this specific mode
+ :rtype: ~silx.gui.colors.Colormap
+ """
+ return self._plotImage.getColormap(mode=mode)
+
+ def getOrigin(self):
+ """Returns the offset from origin at which to display the image.
+
+ :rtype: 2-tuple of float
+ """
+ return self._plotImage.getOrigin()
+
+ def setOrigin(self, origin):
+ """Set the offset from origin at which to display the image.
+
+ :param origin: (ox, oy) Offset from origin
+ :type origin: float or 2-tuple of float
+ """
+ self._plotImage.setOrigin(origin)
+
+ def getScale(self):
+ """Returns the scale of the image in data coordinates.
+
+ :rtype: 2-tuple of float
+ """
+ return self._plotImage.getScale()
+
+ def setScale(self, scale):
+ """Set the scale of the image
+
+ :param scale: (sx, sy) Scale of the image
+ :type scale: float or 2-tuple of float
+ """
+ self._plotImage.setScale(scale)
+
+ # PlotWidget API proxy
+
+ def getXAxis(self):
+ """Returns the X axis
+
+ :rtype: :class:`.items.Axis`
+ """
+ return self.getPlot().getXAxis()
+
+ def getYAxis(self):
+ """Returns an Y axis
+
+ :rtype: :class:`.items.Axis`
+ """
+ return self.getPlot().getYAxis(axis='left')
+
+ def getGraphTitle(self):
+ """Return the plot main title as a str."""
+ return self.getPlot().getGraphTitle()
+
+ def setGraphTitle(self, title=""):
+ """Set the plot main title.
+
+ :param str title: Main title of the plot (default: '')
+ """
+ self.getPlot().setGraphTitle(title)
+
+ def setKeepDataAspectRatio(self, flag):
+ """Set whether the plot keeps data aspect ratio or not.
+
+ :param bool flag: True to respect data aspect ratio
+ """
+ self.getPlot().setKeepDataAspectRatio(flag)
+
+ def isKeepDataAspectRatio(self):
+ """Returns whether the plot is keeping data aspect ratio or not."""
+ return self.getPlot().isKeepDataAspectRatio()
diff --git a/src/silx/gui/plot/CurvesROIWidget.py b/src/silx/gui/plot/CurvesROIWidget.py
new file mode 100644
index 0000000..132d398
--- /dev/null
+++ b/src/silx/gui/plot/CurvesROIWidget.py
@@ -0,0 +1,1581 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+Widget to handle regions of interest (:class:`ROI`) on curves displayed in a
+:class:`PlotWindow`.
+
+This widget is meant to work with :class:`PlotWindow`.
+"""
+
+__authors__ = ["V.A. Sole", "T. Vincent", "H. Payno"]
+__license__ = "MIT"
+__date__ = "13/03/2018"
+
+from collections import OrderedDict
+import logging
+import os
+import sys
+import functools
+import numpy
+from silx.io import dictdump
+from silx.utils import deprecation
+from silx.utils.weakref import WeakMethodProxy
+from silx.utils.proxy import docstring
+from .. import icons, qt
+from silx.math.combo import min_max
+import weakref
+from silx.gui.widgets.TableWidget import TableWidget
+from . import items
+from .items.roi import _RegionOfInterestBase
+
+
+_logger = logging.getLogger(__name__)
+
+
+class CurvesROIWidget(qt.QWidget):
+ """
+ Widget displaying a table of ROI information.
+
+ Implements also the following behavior:
+
+ * if the roiTable has no ROI when showing create the default ICR one
+
+ :param parent: See :class:`QWidget`
+ :param str name: The title of this widget
+ """
+
+ sigROIWidgetSignal = qt.Signal(object)
+ """Signal of ROIs modifications.
+
+ Modification information if given as a dict with an 'event' key
+ providing the type of events.
+
+ Type of events:
+
+ - AddROI, DelROI, LoadROI and ResetROI with keys: 'roilist', 'roidict'
+ - selectionChanged with keys: 'row', 'col' 'roi', 'key', 'colheader',
+ 'rowheader'
+ """
+
+ sigROISignal = qt.Signal(object)
+
+ def __init__(self, parent=None, name=None, plot=None):
+ super(CurvesROIWidget, self).__init__(parent)
+ if name is not None:
+ self.setWindowTitle(name)
+ self.__lastSigROISignal = None
+ """Store the last value emitted for the sigRoiSignal. In the case the
+ active curve change we need to add this extra step in order to make
+ sure we won't send twice the sigROISignal.
+ This come from the fact sigROISignal is connected to the
+ activeROIChanged signal which is emitted when raw and net counts
+ values are changing but are not embed in the sigROISignal.
+ """
+ assert plot is not None
+ self._plotRef = weakref.ref(plot)
+ self._showAllMarkers = False
+ self.currentROI = None
+
+ layout = qt.QVBoxLayout(self)
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.setSpacing(0)
+
+ self.headerLabel = qt.QLabel(self)
+ self.headerLabel.setAlignment(qt.Qt.AlignHCenter)
+ self.setHeader()
+ layout.addWidget(self.headerLabel)
+
+ widgetAllCheckbox = qt.QWidget(parent=self)
+ self._showAllCheckBox = qt.QCheckBox("show all ROI",
+ parent=widgetAllCheckbox)
+ widgetAllCheckbox.setLayout(qt.QHBoxLayout())
+ spacer = qt.QWidget(parent=widgetAllCheckbox)
+ spacer.setSizePolicy(qt.QSizePolicy.Expanding, qt.QSizePolicy.Fixed)
+ widgetAllCheckbox.layout().addWidget(spacer)
+ widgetAllCheckbox.layout().addWidget(self._showAllCheckBox)
+ layout.addWidget(widgetAllCheckbox)
+
+ self.roiTable = ROITable(self, plot=plot)
+ rheight = self.roiTable.horizontalHeader().sizeHint().height()
+ self.roiTable.setMinimumHeight(4 * rheight)
+ layout.addWidget(self.roiTable)
+ self._roiFileDir = qt.QDir.home().absolutePath()
+ self._showAllCheckBox.toggled.connect(self.roiTable.showAllMarkers)
+
+ hbox = qt.QWidget(self)
+ hboxlayout = qt.QHBoxLayout(hbox)
+ hboxlayout.setContentsMargins(0, 0, 0, 0)
+ hboxlayout.setSpacing(0)
+
+ hboxlayout.addStretch(0)
+
+ self.addButton = qt.QPushButton(hbox)
+ self.addButton.setText("Add ROI")
+ self.addButton.setToolTip('Create a new ROI')
+ self.delButton = qt.QPushButton(hbox)
+ self.delButton.setText("Delete ROI")
+ self.addButton.setToolTip('Remove the selected ROI')
+ self.resetButton = qt.QPushButton(hbox)
+ self.resetButton.setText("Reset")
+ self.addButton.setToolTip('Clear all created ROIs. We only let the '
+ 'default ROI')
+
+ hboxlayout.addWidget(self.addButton)
+ hboxlayout.addWidget(self.delButton)
+ hboxlayout.addWidget(self.resetButton)
+
+ hboxlayout.addStretch(0)
+
+ self.loadButton = qt.QPushButton(hbox)
+ self.loadButton.setText("Load")
+ self.loadButton.setToolTip('Load ROIs from a .ini file')
+ self.saveButton = qt.QPushButton(hbox)
+ self.saveButton.setText("Save")
+ self.loadButton.setToolTip('Save ROIs to a .ini file')
+ hboxlayout.addWidget(self.loadButton)
+ hboxlayout.addWidget(self.saveButton)
+ layout.setStretchFactor(self.headerLabel, 0)
+ layout.setStretchFactor(self.roiTable, 1)
+ layout.setStretchFactor(hbox, 0)
+
+ layout.addWidget(hbox)
+
+ # Signal / Slot connections
+ self.addButton.clicked.connect(self._add)
+ self.delButton.clicked.connect(self._del)
+ self.resetButton.clicked.connect(self._reset)
+
+ self.loadButton.clicked.connect(self._load)
+ self.saveButton.clicked.connect(self._save)
+
+ self.roiTable.activeROIChanged.connect(self._emitCurrentROISignal)
+
+ self._isConnected = False # True if connected to plot signals
+ self._isInit = False
+
+ # expose API
+ self.getROIListAndDict = self.roiTable.getROIListAndDict
+
+ def getPlotWidget(self):
+ """Returns the associated PlotWidget or None
+
+ :rtype: Union[~silx.gui.plot.PlotWidget,None]
+ """
+ return None if self._plotRef is None else self._plotRef()
+
+ def showEvent(self, event):
+ self._visibilityChangedHandler(visible=True)
+ qt.QWidget.showEvent(self, event)
+
+ @property
+ def roiFileDir(self):
+ """The directory from which to load/save ROI from/to files."""
+ if not os.path.isdir(self._roiFileDir):
+ self._roiFileDir = qt.QDir.home().absolutePath()
+ return self._roiFileDir
+
+ @roiFileDir.setter
+ def roiFileDir(self, roiFileDir):
+ self._roiFileDir = str(roiFileDir)
+
+ def setRois(self, rois, order=None):
+ return self.roiTable.setRois(rois, order)
+
+ def getRois(self, order=None):
+ return self.roiTable.getRois(order)
+
+ def setMiddleROIMarkerFlag(self, flag=True):
+ return self.roiTable.setMiddleROIMarkerFlag(flag)
+
+ def _add(self):
+ """Add button clicked handler"""
+ def getNextRoiName():
+ rois = self.roiTable.getRois(order=None)
+ roisNames = []
+ [roisNames.append(roiName) for roiName in rois]
+ nrois = len(rois)
+ if nrois == 0:
+ return "ICR"
+ else:
+ i = 1
+ newroi = "newroi %d" % i
+ while newroi in roisNames:
+ i += 1
+ newroi = "newroi %d" % i
+ return newroi
+ roi = ROI(name=getNextRoiName())
+
+ if roi.getName() == "ICR":
+ roi.setType("Default")
+ else:
+ roi.setType(self.getPlotWidget().getXAxis().getLabel())
+
+ xmin, xmax = self.getPlotWidget().getXAxis().getLimits()
+ fromdata = xmin + 0.25 * (xmax - xmin)
+ todata = xmin + 0.75 * (xmax - xmin)
+ if roi.isICR():
+ fromdata, dummy0, todata, dummy1 = self._getAllLimits()
+ roi.setFrom(fromdata)
+ roi.setTo(todata)
+ self.roiTable.addRoi(roi)
+
+ # back compatibility pymca roi signals
+ ddict = {}
+ ddict['event'] = "AddROI"
+ ddict['roilist'] = self.roiTable.roidict.values()
+ ddict['roidict'] = self.roiTable.roidict
+ self.sigROIWidgetSignal.emit(ddict)
+ # end back compatibility pymca roi signals
+
+ def _del(self):
+ """Delete button clicked handler"""
+ self.roiTable.deleteActiveRoi()
+
+ # back compatibility pymca roi signals
+ ddict = {}
+ ddict['event'] = "DelROI"
+ ddict['roilist'] = self.roiTable.roidict.values()
+ ddict['roidict'] = self.roiTable.roidict
+ self.sigROIWidgetSignal.emit(ddict)
+ # end back compatibility pymca roi signals
+
+ def _reset(self):
+ """Reset button clicked handler"""
+ self.roiTable.clear()
+ old = self.blockSignals(True) # avoid several sigROISignal emission
+ self._add()
+ self.blockSignals(old)
+
+ # back compatibility pymca roi signals
+ ddict = {}
+ ddict['event'] = "ResetROI"
+ ddict['roilist'] = self.roiTable.roidict.values()
+ ddict['roidict'] = self.roiTable.roidict
+ self.sigROIWidgetSignal.emit(ddict)
+ # end back compatibility pymca roi signals
+
+ def _load(self):
+ """Load button clicked handler"""
+ dialog = qt.QFileDialog(self)
+ dialog.setNameFilters(
+ ['INI File *.ini', 'JSON File *.json', 'All *.*'])
+ dialog.setFileMode(qt.QFileDialog.ExistingFile)
+ dialog.setDirectory(self.roiFileDir)
+ if not dialog.exec():
+ dialog.close()
+ return
+
+ # pyflakes bug http://bugs.debian.org/cgi-bin/bugreport.cgi?bug=666494
+ outputFile = dialog.selectedFiles()[0]
+ dialog.close()
+
+ self.roiFileDir = os.path.dirname(outputFile)
+ self.roiTable.load(outputFile)
+
+ # back compatibility pymca roi signals
+ ddict = {}
+ ddict['event'] = "LoadROI"
+ ddict['roilist'] = self.roiTable.roidict.values()
+ ddict['roidict'] = self.roiTable.roidict
+ self.sigROIWidgetSignal.emit(ddict)
+ # end back compatibility pymca roi signals
+
+ def load(self, filename):
+ """Load ROI widget information from a file storing a dict of ROI.
+
+ :param str filename: The file from which to load ROI
+ """
+ self.roiTable.load(filename)
+
+ def _save(self):
+ """Save button clicked handler"""
+ dialog = qt.QFileDialog(self)
+ dialog.setNameFilters(['INI File *.ini', 'JSON File *.json'])
+ dialog.setFileMode(qt.QFileDialog.AnyFile)
+ dialog.setAcceptMode(qt.QFileDialog.AcceptSave)
+ dialog.setDirectory(self.roiFileDir)
+ if not dialog.exec():
+ dialog.close()
+ return
+
+ outputFile = dialog.selectedFiles()[0]
+ extension = '.' + dialog.selectedNameFilter().split('.')[-1]
+ dialog.close()
+
+ if not outputFile.endswith(extension):
+ outputFile += extension
+
+ if os.path.exists(outputFile):
+ try:
+ os.remove(outputFile)
+ except IOError:
+ msg = qt.QMessageBox(self)
+ msg.setIcon(qt.QMessageBox.Critical)
+ msg.setText("Input Output Error: %s" % (sys.exc_info()[1]))
+ msg.exec()
+ return
+ self.roiFileDir = os.path.dirname(outputFile)
+ self.save(outputFile)
+
+ def save(self, filename):
+ """Save current ROIs of the widget as a dict of ROI to a file.
+
+ :param str filename: The file to which to save the ROIs
+ """
+ self.roiTable.save(filename)
+
+ def setHeader(self, text='ROIs'):
+ """Set the header text of this widget"""
+ self.headerLabel.setText("<b>%s<\b>" % text)
+
+ @deprecation.deprecated(replacement="calculateRois",
+ reason="CamelCase convention",
+ since_version="0.7")
+ def calculateROIs(self, *args, **kw):
+ self.calculateRois(*args, **kw)
+
+ def calculateRois(self, roiList=None, roiDict=None):
+ """Compute ROI information"""
+ return self.roiTable.calculateRois()
+
+ def showAllMarkers(self, _show=True):
+ self.roiTable.showAllMarkers(_show)
+
+ def _getAllLimits(self):
+ """Retrieve the limits based on the curves."""
+ plot = self.getPlotWidget()
+ curves = () if plot is None else plot.getAllCurves()
+ if not curves:
+ return 1.0, 1.0, 100., 100.
+
+ xmin, ymin = None, None
+ xmax, ymax = None, None
+
+ for curve in curves:
+ x = curve.getXData(copy=False)
+ y = curve.getYData(copy=False)
+ if xmin is None:
+ xmin = x.min()
+ else:
+ xmin = min(xmin, x.min())
+ if xmax is None:
+ xmax = x.max()
+ else:
+ xmax = max(xmax, x.max())
+ if ymin is None:
+ ymin = y.min()
+ else:
+ ymin = min(ymin, y.min())
+ if ymax is None:
+ ymax = y.max()
+ else:
+ ymax = max(ymax, y.max())
+
+ return xmin, ymin, xmax, ymax
+
+ def showEvent(self, event):
+ self._visibilityChangedHandler(visible=True)
+ qt.QWidget.showEvent(self, event)
+
+ def hideEvent(self, event):
+ self._visibilityChangedHandler(visible=False)
+ qt.QWidget.hideEvent(self, event)
+
+ def _visibilityChangedHandler(self, visible):
+ """Handle widget's visibility updates.
+
+ It is connected to plot signals only when visible.
+ """
+ if visible:
+ # if no ROI existing yet, add the default one
+ if self.roiTable.rowCount() == 0:
+ old = self.blockSignals(True) # avoid several sigROISignal emission
+ self._add()
+ self.blockSignals(old)
+ self.calculateRois()
+
+ def fillFromROIDict(self, *args, **kwargs):
+ self.roiTable.fillFromROIDict(*args, **kwargs)
+
+ def _emitCurrentROISignal(self):
+ ddict = {}
+ ddict['event'] = "currentROISignal"
+ if self.roiTable.activeRoi is not None:
+ ddict['ROI'] = self.roiTable.activeRoi.toDict()
+ ddict['current'] = self.roiTable.activeRoi.getName()
+ else:
+ ddict['current'] = None
+
+ if self.__lastSigROISignal != ddict:
+ self.__lastSigROISignal = ddict
+ self.sigROISignal.emit(ddict)
+
+ @property
+ def currentRoi(self):
+ return self.roiTable.activeRoi
+
+
+class _FloatItem(qt.QTableWidgetItem):
+ """
+ Simple QTableWidgetItem overloading the < operator to deal with ordering
+ """
+ def __init__(self):
+ qt.QTableWidgetItem.__init__(self, type=qt.QTableWidgetItem.Type)
+
+ def __lt__(self, other):
+ if self.text() in ('', ROITable.INFO_NOT_FOUND):
+ return False
+ if other.text() in ('', ROITable.INFO_NOT_FOUND):
+ return True
+ return float(self.text()) < float(other.text())
+
+
+class ROITable(TableWidget):
+ """Table widget displaying ROI information.
+
+ See :class:`QTableWidget` for constructor arguments.
+
+ Behavior: listen at the active curve changed only when the widget is
+ visible. Otherwise won't compute the row and net counts...
+ """
+
+ activeROIChanged = qt.Signal()
+ """Signal emitted when the active roi changed or when the value of the
+ active roi are changing"""
+
+ COLUMNS_INDEX = OrderedDict([
+ ('ID', 0),
+ ('ROI', 1),
+ ('Type', 2),
+ ('From', 3),
+ ('To', 4),
+ ('Raw Counts', 5),
+ ('Net Counts', 6),
+ ('Raw Area', 7),
+ ('Net Area', 8),
+ ])
+
+ COLUMNS = list(COLUMNS_INDEX.keys())
+
+ INFO_NOT_FOUND = '????????'
+
+ def __init__(self, parent=None, plot=None, rois=None):
+ super(ROITable, self).__init__(parent)
+ self._showAllMarkers = False
+ self._userIsEditingRoi = False
+ """bool used to avoid conflict when editing the ROI object"""
+ self._isConnected = False
+ self._roiToItems = {}
+ self._roiDict = {}
+ """dict of ROI object. Key is ROi id, value is the ROI object"""
+ self._markersHandler = _RoiMarkerManager()
+
+ """
+ Associate for each marker legend used when the `_showAllMarkers` option
+ is active a roi.
+ """
+ self.setColumnCount(len(self.COLUMNS))
+ self.setPlot(plot)
+ self.__setTooltip()
+ self.setSortingEnabled(True)
+ self.itemChanged.connect(self._itemChanged)
+
+ @property
+ def roidict(self):
+ return self._getRoiDict()
+
+ @property
+ def activeRoi(self):
+ return self._markersHandler._activeRoi
+
+ def _getRoiDict(self):
+ ddict = {}
+ for id in self._roiDict:
+ ddict[self._roiDict[id].getName()] = self._roiDict[id]
+ return ddict
+
+ def clear(self):
+ """
+ .. note:: clear the interface only. keep the roidict...
+ """
+ self._markersHandler.clear()
+ self._roiToItems = {}
+ self._roiDict = {}
+
+ qt.QTableWidget.clear(self)
+ self.setRowCount(0)
+ self.setHorizontalHeaderLabels(self.COLUMNS)
+ header = self.horizontalHeader()
+ header.setSectionResizeMode(qt.QHeaderView.ResizeToContents)
+ self.sortByColumn(0, qt.Qt.AscendingOrder)
+ self.hideColumn(self.COLUMNS_INDEX['ID'])
+
+ def setPlot(self, plot):
+ self.clear()
+ self.plot = plot
+
+ def __setTooltip(self):
+ self.horizontalHeaderItem(self.COLUMNS_INDEX['ROI']).setToolTip(
+ 'Region of interest identifier')
+ self.horizontalHeaderItem(self.COLUMNS_INDEX['Type']).setToolTip(
+ 'Type of the ROI')
+ self.horizontalHeaderItem(self.COLUMNS_INDEX['From']).setToolTip(
+ 'X-value of the min point')
+ self.horizontalHeaderItem(self.COLUMNS_INDEX['To']).setToolTip(
+ 'X-value of the max point')
+ self.horizontalHeaderItem(self.COLUMNS_INDEX['Raw Counts']).setToolTip(
+ 'Estimation of the integral between y=0 and the selected curve')
+ self.horizontalHeaderItem(self.COLUMNS_INDEX['Net Counts']).setToolTip(
+ 'Estimation of the integral between the segment [maxPt, minPt] '
+ 'and the selected curve')
+
+ def setRois(self, rois, order=None):
+ """Set the ROIs by providing a dictionary of ROI information.
+
+ The dictionary keys are the ROI names.
+ Each value is a sub-dictionary of ROI info with the following fields:
+
+ - ``"from"``: x coordinate of the left limit, as a float
+ - ``"to"``: x coordinate of the right limit, as a float
+ - ``"type"``: type of ROI, as a string (e.g "channels", "energy")
+
+
+ :param roidict: Dictionary of ROIs
+ :param str order: Field used for ordering the ROIs.
+ One of "from", "to", "type".
+ None (default) for no ordering, or same order as specified
+ in parameter ``roidict`` if provided as an OrderedDict.
+ """
+ assert order in [None, "from", "to", "type"]
+ self.clear()
+
+ # backward compatibility since 0.10.0
+ if isinstance(rois, dict):
+ for roiName, roi in rois.items():
+ if isinstance(roi, ROI):
+ _roi = roi
+ else:
+ roi['name'] = roiName
+ _roi = ROI._fromDict(roi)
+ self.addRoi(_roi)
+ else:
+ for roi in rois:
+ assert isinstance(roi, ROI)
+ self.addRoi(roi)
+ self._updateMarkers()
+
+ def addRoi(self, roi):
+ """
+
+ :param :class:`ROI` roi: roi to add to the table
+ """
+ assert isinstance(roi, ROI)
+ self._getItem(name='ID', row=None, roi=roi)
+ self._roiDict[roi.getID()] = roi
+ self._markersHandler.add(roi, _RoiMarkerHandler(roi, self.plot))
+ self._updateRoiInfo(roi.getID())
+ callback = functools.partial(WeakMethodProxy(self._updateRoiInfo),
+ roi.getID())
+ roi.sigChanged.connect(callback)
+ # set it as the active one
+ self.setActiveRoi(roi)
+
+ def _getItem(self, name, row, roi):
+ if row:
+ item = self.item(row, self.COLUMNS_INDEX[name])
+ else:
+ item = None
+ if item:
+ return item
+ else:
+ if name == 'ID':
+ assert roi
+ if roi.getID() in self._roiToItems:
+ return self._roiToItems[roi.getID()]
+ else:
+ # create a new row
+ row = self.rowCount()
+ self.setRowCount(self.rowCount() + 1)
+ item = qt.QTableWidgetItem(str(roi.getID()),
+ type=qt.QTableWidgetItem.Type)
+ self._roiToItems[roi.getID()] = item
+ elif name == 'ROI':
+ item = qt.QTableWidgetItem(roi.getName() if roi else '',
+ type=qt.QTableWidgetItem.Type)
+ if roi.getName().upper() in ('ICR', 'DEFAULT'):
+ item.setFlags(qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled)
+ else:
+ item.setFlags(qt.Qt.ItemIsSelectable |
+ qt.Qt.ItemIsEnabled |
+ qt.Qt.ItemIsEditable)
+ elif name == 'Type':
+ item = qt.QTableWidgetItem(type=qt.QTableWidgetItem.Type)
+ item.setFlags((qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled))
+ elif name in ('To', 'From'):
+ item = _FloatItem()
+ if roi.getName().upper() in ('ICR', 'DEFAULT'):
+ item.setFlags(qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled)
+ else:
+ item.setFlags(qt.Qt.ItemIsSelectable |
+ qt.Qt.ItemIsEnabled |
+ qt.Qt.ItemIsEditable)
+ elif name in ('Raw Counts', 'Net Counts', 'Raw Area', 'Net Area'):
+ item = _FloatItem()
+ item.setFlags((qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled))
+ else:
+ raise ValueError('item type not recognized')
+
+ self.setItem(row, self.COLUMNS_INDEX[name], item)
+ return item
+
+ def _itemChanged(self, item):
+ def getRoi():
+ IDItem = self.item(item.row(), self.COLUMNS_INDEX['ID'])
+ assert IDItem
+ id = int(IDItem.text())
+ assert id in self._roiDict
+ roi = self._roiDict[id]
+ return roi
+
+ def signalChanged(roi):
+ if self.activeRoi and roi.getID() == self.activeRoi.getID():
+ self.activeROIChanged.emit()
+
+ self._userIsEditingRoi = True
+ if item.column() in (self.COLUMNS_INDEX['To'], self.COLUMNS_INDEX['From']):
+ roi = getRoi()
+
+ if item.text() not in ('', self.INFO_NOT_FOUND):
+ try:
+ value = float(item.text())
+ except ValueError:
+ value = 0
+ changed = False
+ if item.column() == self.COLUMNS_INDEX['To']:
+ if value != roi.getTo():
+ roi.setTo(value)
+ changed = True
+ else:
+ assert(item.column() == self.COLUMNS_INDEX['From'])
+ if value != roi.getFrom():
+ roi.setFrom(value)
+ changed = True
+ if changed:
+ self._updateMarker(roi.getName())
+ signalChanged(roi)
+
+ if item.column() is self.COLUMNS_INDEX['ROI']:
+ roi = getRoi()
+ if roi.getName() != item.text():
+ roi.setName(item.text())
+ self._markersHandler.getMarkerHandler(roi.getID()).updateTexts()
+ signalChanged(roi)
+
+ self._userIsEditingRoi = False
+
+ def deleteActiveRoi(self):
+ """
+ remove the current active roi
+ """
+ activeItems = self.selectedItems()
+ if len(activeItems) == 0:
+ return
+ old = self.blockSignals(True) # avoid several emission of sigROISignal
+ roiToRm = set()
+ for item in activeItems:
+ row = item.row()
+ itemID = self.item(row, self.COLUMNS_INDEX['ID'])
+ roiToRm.add(self._roiDict[int(itemID.text())])
+ [self.removeROI(roi) for roi in roiToRm]
+ self.blockSignals(old)
+ self.setActiveRoi(None)
+
+ def removeROI(self, roi):
+ """
+ remove the requested roi
+
+ :param str name: the name of the roi to remove from the table
+ """
+ if roi and roi.getID() in self._roiToItems:
+ item = self._roiToItems[roi.getID()]
+ self.removeRow(item.row())
+ del self._roiToItems[roi.getID()]
+
+ assert roi.getID() in self._roiDict
+ del self._roiDict[roi.getID()]
+ self._markersHandler.remove(roi)
+
+ callback = functools.partial(WeakMethodProxy(self._updateRoiInfo),
+ roi.getID())
+ roi.sigChanged.connect(callback)
+
+ def setActiveRoi(self, roi):
+ """
+ Define the given roi as the active one.
+
+ .. warning:: this roi should already be registred / added to the table
+
+ :param :class:`ROI` roi: the roi to defined as active
+ """
+ if roi is None:
+ self.clearSelection()
+ self._markersHandler.setActiveRoi(None)
+ self.activeROIChanged.emit()
+ else:
+ assert isinstance(roi, ROI)
+ if roi and roi.getID() in self._roiToItems.keys():
+ # avoid several call back to setActiveROI
+ old = self.blockSignals(True)
+ self.selectRow(self._roiToItems[roi.getID()].row())
+ self.blockSignals(old)
+ self._markersHandler.setActiveRoi(roi)
+ self.activeROIChanged.emit()
+
+ def _updateRoiInfo(self, roiID):
+ if self._userIsEditingRoi is True:
+ return
+ if roiID not in self._roiDict:
+ return
+ roi = self._roiDict[roiID]
+ if roi.isICR():
+ activeCurve = self.plot.getActiveCurve()
+ if activeCurve:
+ xData = activeCurve.getXData()
+ if len(xData) > 0:
+ min, max = min_max(xData)
+ roi.blockSignals(True)
+ roi.setFrom(min)
+ roi.setTo(max)
+ roi.blockSignals(False)
+
+ itemID = self._getItem(name='ID', roi=roi, row=None)
+ itemName = self._getItem(name='ROI', row=itemID.row(), roi=roi)
+ itemName.setText(roi.getName())
+
+ itemType = self._getItem(name='Type', row=itemID.row(), roi=roi)
+ itemType.setText(roi.getType() or self.INFO_NOT_FOUND)
+
+ itemFrom = self._getItem(name='From', row=itemID.row(), roi=roi)
+ fromdata = str(roi.getFrom()) if roi.getFrom() is not None else self.INFO_NOT_FOUND
+ itemFrom.setText(fromdata)
+
+ itemTo = self._getItem(name='To', row=itemID.row(), roi=roi)
+ todata = str(roi.getTo()) if roi.getTo() is not None else self.INFO_NOT_FOUND
+ itemTo.setText(todata)
+
+ rawCounts, netCounts = roi.computeRawAndNetCounts(
+ curve=self.plot.getActiveCurve(just_legend=False))
+ itemRawCounts = self._getItem(name='Raw Counts', row=itemID.row(),
+ roi=roi)
+ rawCounts = str(rawCounts) if rawCounts is not None else self.INFO_NOT_FOUND
+ itemRawCounts.setText(rawCounts)
+
+ itemNetCounts = self._getItem(name='Net Counts', row=itemID.row(),
+ roi=roi)
+ netCounts = str(netCounts) if netCounts is not None else self.INFO_NOT_FOUND
+ itemNetCounts.setText(netCounts)
+
+ rawArea, netArea = roi.computeRawAndNetArea(
+ curve=self.plot.getActiveCurve(just_legend=False))
+ itemRawArea = self._getItem(name='Raw Area', row=itemID.row(),
+ roi=roi)
+ rawArea = str(rawArea) if rawArea is not None else self.INFO_NOT_FOUND
+ itemRawArea.setText(rawArea)
+
+ itemNetArea = self._getItem(name='Net Area', row=itemID.row(),
+ roi=roi)
+ netArea = str(netArea) if netArea is not None else self.INFO_NOT_FOUND
+ itemNetArea.setText(netArea)
+
+ if self.activeRoi and roi.getID() == self.activeRoi.getID():
+ self.activeROIChanged.emit()
+
+ def currentChanged(self, current, previous):
+ if previous and current.row() != previous.row() and current.row() >= 0:
+ roiItem = self.item(current.row(),
+ self.COLUMNS_INDEX['ID'])
+
+ assert roiItem
+ self.setActiveRoi(self._roiDict[int(roiItem.text())])
+ self._markersHandler.updateAllMarkers()
+ qt.QTableWidget.currentChanged(self, current, previous)
+
+ @deprecation.deprecated(reason="Removed",
+ replacement="roidict and roidict.values()",
+ since_version="0.10.0")
+ def getROIListAndDict(self):
+ """
+
+ :return: the list of roi objects and the dictionary of roi name to roi
+ object.
+ """
+ roidict = self._roiDict
+ return list(roidict.values()), roidict
+
+ def calculateRois(self, roiList=None, roiDict=None):
+ """
+ Update values of all registred rois (raw and net counts in particular)
+
+ :param roiList: deprecated parameter
+ :param roiDict: deprecated parameter
+ """
+ if roiDict:
+ deprecation.deprecated_warning(name='roiDict', type_='Parameter',
+ reason='Unused parameter',
+ since_version="0.10.0")
+ if roiList:
+ deprecation.deprecated_warning(name='roiList', type_='Parameter',
+ reason='Unused parameter',
+ since_version="0.10.0")
+
+ for roiID in self._roiDict:
+ self._updateRoiInfo(roiID)
+
+ def _updateMarker(self, roiID):
+ """Make sure the marker of the given roi name is updated"""
+ if self._showAllMarkers or (self.activeRoi
+ and self.activeRoi.getName() == roiID):
+ self._updateMarkers()
+
+ def _updateMarkers(self):
+ if self._showAllMarkers is True:
+ self._markersHandler.updateMarkers()
+ else:
+ if not self.activeRoi or not self.plot:
+ return
+ assert isinstance(self.activeRoi, ROI)
+ markerHandler = self._markersHandler.getMarkerHandler(self.activeRoi.getID())
+ if markerHandler is not None:
+ markerHandler.updateMarkers()
+
+ def getRois(self, order):
+ """
+ Return the currently defined ROIs, as an ordered dict.
+
+ The dictionary keys are the ROI names.
+ Each value is a :class:`ROI` object..
+
+ :param order: Field used for ordering the ROIs.
+ One of "from", "to", "type", "netcounts", "rawcounts".
+ None (default) to get the same order as displayed in the widget.
+ :return: Ordered dictionary of ROI information
+ """
+
+ if order is None or order.lower() == "none":
+ ordered_roilist = list(self._roiDict.values())
+ res = OrderedDict([(roi.getName(), self._roiDict[roi.getID()]) for roi in ordered_roilist])
+ else:
+ assert order in ["from", "to", "type", "netcounts", "rawcounts"]
+ ordered_roilist = sorted(self._roiDict.keys(),
+ key=lambda roi_id: self._roiDict[roi_id].get(order))
+ res = OrderedDict([(roi.getName(), self._roiDict[id]) for id in ordered_roilist])
+
+ return res
+
+ def save(self, filename):
+ """
+ Save current ROIs of the widget as a dict of ROI to a file.
+
+ :param str filename: The file to which to save the ROIs
+ """
+ roilist = []
+ roidict = {}
+ for roiID, roi in self._roiDict.items():
+ roilist.append(roi.toDict())
+ roidict[roi.getName()] = roi.toDict()
+ datadict = {'ROI': {'roilist': roilist, 'roidict': roidict}}
+ dictdump.dump(datadict, filename)
+
+ def load(self, filename):
+ """
+ Load ROI widget information from a file storing a dict of ROI.
+
+ :param str filename: The file from which to load ROI
+ """
+ roisDict = dictdump.load(filename)
+ rois = []
+
+ # Remove rawcounts and netcounts from ROIs
+ for roiDict in roisDict['ROI']['roidict'].values():
+ roiDict.pop('rawcounts', None)
+ roiDict.pop('netcounts', None)
+ rois.append(ROI._fromDict(roiDict))
+
+ self.setRois(rois)
+
+ def showAllMarkers(self, _show=True):
+ """
+
+ :param bool _show: if true show all the markers of all the ROIs
+ boundaries otherwise will only show the one of
+ the active ROI.
+ """
+ self._markersHandler.setShowAllMarkers(_show)
+
+ def setMiddleROIMarkerFlag(self, flag=True):
+ """
+ Activate or deactivate middle marker.
+
+ This allows shifting both min and max limits at once, by dragging
+ a marker located in the middle.
+
+ :param bool flag: True to activate middle ROI marker
+ """
+ self._markersHandler._middleROIMarkerFlag = flag
+
+ def _handleROIMarkerEvent(self, ddict):
+ """Handle plot signals related to marker events."""
+ if ddict['event'] == 'markerMoved':
+ label = ddict['label']
+ roiID = self._markersHandler.getRoiID(markerID=label)
+ if roiID is not None:
+ # avoid several emission of sigROISignal
+ old = self.blockSignals(True)
+ self._markersHandler.changePosition(markerID=label,
+ x=ddict['x'])
+ self.blockSignals(old)
+ self._updateRoiInfo(roiID)
+
+ def showEvent(self, event):
+ self._visibilityChangedHandler(visible=True)
+ qt.QWidget.showEvent(self, event)
+
+ def hideEvent(self, event):
+ self._visibilityChangedHandler(visible=False)
+ qt.QWidget.hideEvent(self, event)
+
+ def _visibilityChangedHandler(self, visible):
+ """Handle widget's visibility updates.
+
+ It is connected to plot signals only when visible.
+ """
+ if visible:
+ assert self.plot
+ if self._isConnected is False:
+ self.plot.sigPlotSignal.connect(self._handleROIMarkerEvent)
+ self.plot.sigActiveCurveChanged.connect(self._activeCurveChanged)
+ self._isConnected = True
+ self.calculateRois()
+ else:
+ if self._isConnected:
+ self.plot.sigPlotSignal.disconnect(self._handleROIMarkerEvent)
+ self.plot.sigActiveCurveChanged.disconnect(self._activeCurveChanged)
+ self._isConnected = False
+
+ def _activeCurveChanged(self, curve):
+ self.calculateRois()
+
+ def setCountsVisible(self, visible):
+ """
+ Display the columns relative to areas or not
+
+ :param bool visible: True if the columns 'Raw Area' and 'Net Area'
+ should be visible.
+ """
+ if visible is True:
+ self.showColumn(self.COLUMNS_INDEX['Raw Counts'])
+ self.showColumn(self.COLUMNS_INDEX['Net Counts'])
+ else:
+ self.hideColumn(self.COLUMNS_INDEX['Raw Counts'])
+ self.hideColumn(self.COLUMNS_INDEX['Net Counts'])
+
+ def setAreaVisible(self, visible):
+ """
+ Display the columns relative to areas or not
+
+ :param bool visible: True if the columns 'Raw Area' and 'Net Area'
+ should be visible.
+ """
+ if visible is True:
+ self.showColumn(self.COLUMNS_INDEX['Raw Area'])
+ self.showColumn(self.COLUMNS_INDEX['Net Area'])
+ else:
+ self.hideColumn(self.COLUMNS_INDEX['Raw Area'])
+ self.hideColumn(self.COLUMNS_INDEX['Net Area'])
+
+ def fillFromROIDict(self, roilist=(), roidict=None, currentroi=None):
+ """
+ This function API is kept for compatibility.
+ But `setRois` should be preferred.
+
+ Set the ROIs by providing a list of ROI names and a dictionary
+ of ROI information for each ROI.
+ The ROI names must match an existing dictionary key.
+ The name list is used to provide an order for the ROIs.
+ The dictionary's values are sub-dictionaries containing 3
+ mandatory fields:
+
+ - ``"from"``: x coordinate of the left limit, as a float
+ - ``"to"``: x coordinate of the right limit, as a float
+ - ``"type"``: type of ROI, as a string (e.g "channels", "energy")
+
+ :param roilist: List of ROI names (keys of roidict)
+ :type roilist: List
+ :param dict roidict: Dict of ROI information
+ :param currentroi: Name of the selected ROI or None (no selection)
+ """
+ if roidict is not None:
+ self.setRois(roidict)
+ else:
+ self.setRois(roilist)
+ if currentroi:
+ self.setActiveRoi(currentroi)
+
+
+_indexNextROI = 0
+
+
+class ROI(_RegionOfInterestBase):
+ """The Region Of Interest is defined by:
+
+ - A name
+ - A type. The type is the label of the x axis. This can be used to apply or
+ not some ROI to a curve and do some post processing.
+ - The x coordinate of the left limit (fromdata)
+ - The x coordinate of the right limit (todata)
+
+ :param str: name of the ROI
+ :param fromdata: left limit of the roi
+ :param todata: right limit of the roi
+ :param type: type of the ROI
+ """
+
+ sigChanged = qt.Signal()
+ """Signal emitted when the ROI is edited"""
+
+ def __init__(self, name, fromdata=None, todata=None, type_=None):
+ _RegionOfInterestBase.__init__(self)
+ self.setName(name)
+ global _indexNextROI
+ self._id = _indexNextROI
+ _indexNextROI += 1
+
+ self._fromdata = fromdata
+ self._todata = todata
+ self._type = type_ or 'Default'
+
+ self.sigItemChanged.connect(self.__itemChanged)
+
+ def __itemChanged(self, event):
+ """Handle name change"""
+ if event == items.ItemChangedType.NAME:
+ self.sigChanged.emit()
+
+ def getID(self):
+ """
+
+ :return int: the unique ID of the ROI
+ """
+ return self._id
+
+ def setType(self, type_):
+ """
+
+ :param str type_:
+ """
+ if self._type != type_:
+ self._type = type_
+ self.sigChanged.emit()
+
+ def getType(self):
+ """
+
+ :return str: the type of the ROI.
+ """
+ return self._type
+
+ def setFrom(self, frm):
+ """
+
+ :param frm: set x coordinate of the left limit
+ """
+ if self._fromdata != frm:
+ self._fromdata = frm
+ self.sigChanged.emit()
+
+ def getFrom(self):
+ """
+
+ :return: x coordinate of the left limit
+ """
+ return self._fromdata
+
+ def setTo(self, to):
+ """
+
+ :param to: x coordinate of the right limit
+ """
+ if self._todata != to:
+ self._todata = to
+ self.sigChanged.emit()
+
+ def getTo(self):
+ """
+
+ :return: x coordinate of the right limit
+ """
+ return self._todata
+
+ def getMiddle(self):
+ """
+
+ :return: middle position between 'from' and 'to' values
+ """
+ return 0.5 * (self.getFrom() + self.getTo())
+
+ def toDict(self):
+ """
+
+ :return: dict containing the roi parameters
+ """
+ ddict = {
+ 'type': self._type,
+ 'name': self.getName(),
+ 'from': self._fromdata,
+ 'to': self._todata,
+ }
+ if hasattr(self, '_extraInfo'):
+ ddict.update(self._extraInfo)
+ return ddict
+
+ @staticmethod
+ def _fromDict(dic):
+ assert 'name' in dic
+ roi = ROI(name=dic['name'])
+ roi._extraInfo = {}
+ for key in dic:
+ if key == 'from':
+ roi.setFrom(dic['from'])
+ elif key == 'to':
+ roi.setTo(dic['to'])
+ elif key == 'type':
+ roi.setType(dic['type'])
+ else:
+ roi._extraInfo[key] = dic[key]
+
+ return roi
+
+ def isICR(self):
+ """
+
+ :return: True if the ROI is the `ICR`
+ """
+ return self.getName() == 'ICR'
+
+ def computeRawAndNetCounts(self, curve):
+ """Compute the Raw and net counts in the ROI for the given curve.
+
+ - Raw count: Points values sum of the curve in the defined Region Of
+ Interest.
+
+ .. image:: img/rawCounts.png
+
+ - Net count: Raw counts minus background
+
+ .. image:: img/netCounts.png
+
+ :param CurveItem curve:
+ :return tuple: rawCount, netCount
+ """
+ assert isinstance(curve, items.Curve) or curve is None
+
+ if curve is None:
+ return None, None
+
+ x = curve.getXData(copy=False)
+ y = curve.getYData(copy=False)
+
+ idx = numpy.nonzero((self._fromdata <= x) &
+ (x <= self._todata))[0]
+ if len(idx):
+ xw = x[idx]
+ yw = y[idx]
+ rawCounts = yw.sum(dtype=numpy.float64)
+ deltaX = xw[-1] - xw[0]
+ deltaY = yw[-1] - yw[0]
+ if deltaX > 0.0:
+ slope = (deltaY / deltaX)
+ background = yw[0] + slope * (xw - xw[0])
+ netCounts = (rawCounts -
+ background.sum(dtype=numpy.float64))
+ else:
+ netCounts = 0.0
+ else:
+ rawCounts = 0.0
+ netCounts = 0.0
+ return rawCounts, netCounts
+
+ def computeRawAndNetArea(self, curve):
+ """Compute the Raw and net counts in the ROI for the given curve.
+
+ - Raw area: integral of the curve between the min ROI point and the
+ max ROI point to the y = 0 line.
+
+ .. image:: img/rawArea.png
+
+ - Net area: Raw counts minus background
+
+ .. image:: img/netArea.png
+
+ :param CurveItem curve:
+ :return tuple: rawArea, netArea
+ """
+ assert isinstance(curve, items.Curve) or curve is None
+
+ if curve is None:
+ return None, None
+
+ x = curve.getXData(copy=False)
+ y = curve.getYData(copy=False)
+
+ y = y[(x >= self._fromdata) & (x <= self._todata)]
+ x = x[(x >= self._fromdata) & (x <= self._todata)]
+
+ if x.size == 0:
+ return 0.0, 0.0
+
+ rawArea = numpy.trapz(y, x=x)
+ # to speed up and avoid an intersection calculation we are taking the
+ # closest index to the ROI
+ closestXLeftIndex = (numpy.abs(x - self.getFrom())).argmin()
+ closestXRightIndex = (numpy.abs(x - self.getTo())).argmin()
+ yBackground = y[closestXLeftIndex], y[closestXRightIndex]
+ background = numpy.trapz(yBackground, x=x)
+ netArea = rawArea - background
+ return rawArea, netArea
+
+ @docstring(_RegionOfInterestBase)
+ def contains(self, position):
+ return self._fromdata <= position[0] <= self._todata
+
+
+class _RoiMarkerManager(object):
+ """
+ Deal with all the ROI markers
+ """
+ def __init__(self):
+ self._roiMarkerHandlers = {}
+ self._middleROIMarkerFlag = False
+ self._showAllMarkers = False
+ self._activeRoi = None
+
+ def setActiveRoi(self, roi):
+ self._activeRoi = roi
+ self.updateAllMarkers()
+
+ def setShowAllMarkers(self, show):
+ if show != self._showAllMarkers:
+ self._showAllMarkers = show
+ self.updateAllMarkers()
+
+ def add(self, roi, markersHandler):
+ assert isinstance(roi, ROI)
+ assert isinstance(markersHandler, _RoiMarkerHandler)
+ if roi.getID() in self._roiMarkerHandlers:
+ raise ValueError('roi with the same ID already existing')
+ else:
+ self._roiMarkerHandlers[roi.getID()] = markersHandler
+
+ def getMarkerHandler(self, roiID):
+ if roiID in self._roiMarkerHandlers:
+ return self._roiMarkerHandlers[roiID]
+ else:
+ return None
+
+ def clear(self):
+ roisHandler = list(self._roiMarkerHandlers.values())
+ for roiHandler in roisHandler:
+ self.remove(roiHandler.roi)
+
+ def remove(self, roi):
+ if roi is None:
+ return
+ assert isinstance(roi, ROI)
+ if roi.getID() in self._roiMarkerHandlers:
+ self._roiMarkerHandlers[roi.getID()].clear()
+ del self._roiMarkerHandlers[roi.getID()]
+
+ def hasMarker(self, markerID):
+ assert type(markerID) is str
+ return self.getMarker(markerID) is not None
+
+ def changePosition(self, markerID, x):
+ markerHandler = self.getMarker(markerID)
+ if markerHandler is None:
+ raise ValueError('Marker %s not register' % markerID)
+ markerHandler.changePosition(markerID=markerID, x=x)
+
+ def updateMarker(self, markerID):
+ markerHandler = self.getMarker(markerID)
+ if markerHandler is None:
+ raise ValueError('Marker %s not register' % markerID)
+ roiID = self.getRoiID(markerID)
+ visible = (self._activeRoi and self._activeRoi.getID() == roiID) or self._showAllMarkers is True
+ markerHandler.setVisible(visible)
+ markerHandler.updateAllMarkers()
+
+ def updateRoiMarkers(self, roiID):
+ if roiID in self._roiMarkerHandlers:
+ visible = ((self._activeRoi and self._activeRoi.getID() == roiID)
+ or self._showAllMarkers is True)
+ _roi = self._roiMarkerHandlers[roiID]._roi()
+ if _roi and not _roi.isICR():
+ self._roiMarkerHandlers[roiID].showMiddleMarker(self._middleROIMarkerFlag)
+ self._roiMarkerHandlers[roiID].setVisible(visible)
+ self._roiMarkerHandlers[roiID].updateMarkers()
+
+ def getMarker(self, markerID):
+ assert type(markerID) is str
+ for marker in list(self._roiMarkerHandlers.values()):
+ if marker.hasMarker(markerID):
+ return marker
+
+ def updateMarkers(self):
+ for markerHandler in list(self._roiMarkerHandlers.values()):
+ markerHandler.updateMarkers()
+
+ def getRoiID(self, markerID):
+ for roiID, markerHandler in self._roiMarkerHandlers.items():
+ if markerHandler.hasMarker(markerID):
+ return roiID
+ return None
+
+ def setShowMiddleMarkers(self, show):
+ self._middleROIMarkerFlag = show
+ self._roiMarkerHandlers.updateAllMarkers()
+
+ def updateAllMarkers(self):
+ for roiID in self._roiMarkerHandlers:
+ self.updateRoiMarkers(roiID)
+
+ def getVisibleRois(self):
+ res = {}
+ for roiID, roiHandler in self._roiMarkerHandlers.items():
+ markers = (roiHandler.getMarker('min'), roiHandler.getMarker('max'),
+ roiHandler.getMarker('middle'))
+ for marker in markers:
+ if marker.isVisible():
+ if roiID not in res:
+ res[roiID] = []
+ res[roiID].append(marker)
+ return res
+
+
+class _RoiMarkerHandler(object):
+ """Used to deal with ROI markers used in ROITable"""
+ def __init__(self, roi, plot):
+ assert roi and isinstance(roi, ROI)
+ assert plot
+
+ self._roi = weakref.ref(roi)
+ self._plot = weakref.ref(plot)
+ self._draggable = False if roi.isICR() else True
+ self._color = 'black' if roi.isICR() else 'blue'
+ self._displayMidMarker = False
+ self._visible = True
+
+ @property
+ def draggable(self):
+ return self._draggable
+
+ @property
+ def plot(self):
+ return self._plot()
+
+ def clear(self):
+ if self.plot and self.roi:
+ self.plot.removeMarker(self._markerID('min'))
+ self.plot.removeMarker(self._markerID('max'))
+ self.plot.removeMarker(self._markerID('middle'))
+
+ @property
+ def roi(self):
+ return self._roi()
+
+ def setVisible(self, visible):
+ if visible != self._visible:
+ self._visible = visible
+ self.updateMarkers()
+
+ def showMiddleMarker(self, visible):
+ if self.draggable is False and visible is True:
+ _logger.warning("ROI is not draggable. Won't display middle marker")
+ return
+ self._displayMidMarker = visible
+ self.getMarker('middle').setVisible(self._displayMidMarker)
+
+ def updateMarkers(self):
+ if self.roi is None:
+ return
+ self._updateMinMarkerPos()
+ self._updateMaxMarkerPos()
+ self._updateMiddleMarkerPos()
+
+ def _updateMinMarkerPos(self):
+ self.getMarker('min').setPosition(x=self.roi.getFrom(), y=None)
+ self.getMarker('min').setVisible(self._visible)
+
+ def _updateMaxMarkerPos(self):
+ self.getMarker('max').setPosition(x=self.roi.getTo(), y=None)
+ self.getMarker('max').setVisible(self._visible)
+
+ def _updateMiddleMarkerPos(self):
+ self.getMarker('middle').setPosition(x=self.roi.getMiddle(), y=None)
+ self.getMarker('middle').setVisible(self._displayMidMarker and self._visible)
+
+ def getMarker(self, markerType):
+ if self.plot is None:
+ return None
+ assert markerType in ('min', 'max', 'middle')
+ if self.plot._getMarker(self._markerID(markerType)) is None:
+ assert self.roi
+ if markerType == 'min':
+ val = self.roi.getFrom()
+ elif markerType == 'max':
+ val = self.roi.getTo()
+ else:
+ val = self.roi.getMiddle()
+
+ _color = self._color
+ if markerType == 'middle':
+ _color = 'yellow'
+ self.plot.addXMarker(val,
+ legend=self._markerID(markerType),
+ text=self.getMarkerName(markerType),
+ color=_color,
+ draggable=self.draggable)
+ return self.plot._getMarker(self._markerID(markerType))
+
+ def _markerID(self, markerType):
+ assert markerType in ('min', 'max', 'middle')
+ assert self.roi
+ return '_'.join((str(self.roi.getID()), markerType))
+
+ def getMarkerName(self, markerType):
+ assert markerType in ('min', 'max', 'middle')
+ assert self.roi
+ return ' '.join((self.roi.getName(), markerType))
+
+ def updateTexts(self):
+ self.getMarker('min').setText(self.getMarkerName('min'))
+ self.getMarker('max').setText(self.getMarkerName('max'))
+ self.getMarker('middle').setText(self.getMarkerName('middle'))
+
+ def changePosition(self, markerID, x):
+ assert self.hasMarker(markerID)
+ markerType = self._getMarkerType(markerID)
+ assert markerType is not None
+ if self.roi is None:
+ return
+ if markerType == 'min':
+ self.roi.setFrom(x)
+ self._updateMiddleMarkerPos()
+ elif markerType == 'max':
+ self.roi.setTo(x)
+ self._updateMiddleMarkerPos()
+ else:
+ delta = x - 0.5 * (self.roi.getFrom() + self.roi.getTo())
+ self.roi.setFrom(self.roi.getFrom() + delta)
+ self.roi.setTo(self.roi.getTo() + delta)
+ self._updateMinMarkerPos()
+ self._updateMaxMarkerPos()
+
+ def hasMarker(self, marker):
+ return marker in (self._markerID('min'),
+ self._markerID('max'),
+ self._markerID('middle'))
+
+ def _getMarkerType(self, markerID):
+ if markerID.endswith('_min'):
+ return 'min'
+ elif markerID.endswith('_max'):
+ return 'max'
+ elif markerID.endswith('_middle'):
+ return 'middle'
+ else:
+ return None
+
+
+class CurvesROIDockWidget(qt.QDockWidget):
+ """QDockWidget with a :class:`CurvesROIWidget` connected to a PlotWindow.
+
+ It makes the link between the :class:`CurvesROIWidget` and the PlotWindow.
+
+ :param parent: See :class:`QDockWidget`
+ :param plot: :class:`.PlotWindow` instance on which to operate
+ :param name: See :class:`QDockWidget`
+ """
+ sigROISignal = qt.Signal(object)
+ """Deprecated signal for backward compatibility with silx < 0.7.
+ Prefer connecting directly to :attr:`CurvesRoiWidget.sigRoiSignal`
+ """
+
+ def __init__(self, parent=None, plot=None, name=None):
+ super(CurvesROIDockWidget, self).__init__(name, parent)
+
+ assert plot is not None
+ self.plot = plot
+ self.roiWidget = CurvesROIWidget(self, name, plot=plot)
+ """Main widget of type :class:`CurvesROIWidget`"""
+
+ # convenience methods to offer a simpler API allowing to ignore
+ # the details of the underlying implementation
+ # (ALL DEPRECATED)
+ self.calculateROIs = self.calculateRois = self.roiWidget.calculateRois
+ self.setRois = self.roiWidget.setRois
+ self.getRois = self.roiWidget.getRois
+
+ self.roiWidget.sigROISignal.connect(self._forwardSigROISignal)
+
+ self.layout().setContentsMargins(0, 0, 0, 0)
+ self.setWidget(self.roiWidget)
+
+ self.setAreaVisible = self.roiWidget.roiTable.setAreaVisible
+ self.setCountsVisible = self.roiWidget.roiTable.setCountsVisible
+
+ def _forwardSigROISignal(self, ddict):
+ # emit deprecated signal for backward compatibility (silx < 0.7)
+ self.sigROISignal.emit(ddict)
+
+ def toggleViewAction(self):
+ """Returns a checkable action that shows or closes this widget.
+
+ See :class:`QMainWindow`.
+ """
+ action = super(CurvesROIDockWidget, self).toggleViewAction()
+ action.setIcon(icons.getQIcon('plot-roi'))
+ return action
+
+ def showEvent(self, event):
+ """Make sure this widget is raised when it is shown
+ (when it is first created as a tab in PlotWindow or when it is shown
+ again after hiding).
+ """
+ self.raise_()
+ qt.QDockWidget.showEvent(self, event)
+
+ @property
+ def currentROI(self):
+ return self.roiWidget.currentRoi
diff --git a/src/silx/gui/plot/ImageStack.py b/src/silx/gui/plot/ImageStack.py
new file mode 100644
index 0000000..1588a31
--- /dev/null
+++ b/src/silx/gui/plot/ImageStack.py
@@ -0,0 +1,640 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2020-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Image stack view with data prefetch capabilty."""
+
+__authors__ = ["H. Payno"]
+__license__ = "MIT"
+__date__ = "04/03/2019"
+
+
+from silx.gui import icons, qt
+from silx.gui.plot import Plot2D
+from silx.gui.utils import concurrent
+from silx.io.url import DataUrl
+from silx.io.utils import get_data
+from collections import OrderedDict
+from silx.gui.widgets.FrameBrowser import HorizontalSliderWithBrowser
+import time
+import threading
+import typing
+import logging
+
+_logger = logging.getLogger(__name__)
+
+
+class _PlotWithWaitingLabel(qt.QWidget):
+ """Image plot widget with an overlay 'waiting' status.
+ """
+
+ class AnimationThread(threading.Thread):
+ def __init__(self, label):
+ self.running = True
+ self._label = label
+ self.animated_icon = icons.getWaitIcon()
+ self.animated_icon.register(self._label)
+ super(_PlotWithWaitingLabel.AnimationThread, self).__init__()
+
+ def run(self):
+ while self.running:
+ time.sleep(0.05)
+ icon = self.animated_icon.currentIcon()
+ self.future_result = concurrent.submitToQtMainThread(
+ self._label.setPixmap, icon.pixmap(30, state=qt.QIcon.On))
+
+ def stop(self):
+ """Stop the update thread"""
+ if self.running:
+ self.animated_icon.unregister(self._label)
+ self.running = False
+ self.join(2)
+
+ def __init__(self, parent):
+ super(_PlotWithWaitingLabel, self).__init__(parent=parent)
+ self._autoResetZoom = True
+ layout = qt.QStackedLayout(self)
+ layout.setStackingMode(qt.QStackedLayout.StackAll)
+
+ self._waiting_label = qt.QLabel(parent=self)
+ self._waiting_label.setAlignment(qt.Qt.AlignHCenter | qt.Qt.AlignVCenter)
+ layout.addWidget(self._waiting_label)
+
+ self._plot = Plot2D(parent=self)
+ layout.addWidget(self._plot)
+
+ self.updateThread = _PlotWithWaitingLabel.AnimationThread(self._waiting_label)
+ self.updateThread.start()
+
+ def close(self) -> bool:
+ super(_PlotWithWaitingLabel, self).close()
+ self.stopUpdateThread()
+
+ def stopUpdateThread(self):
+ self.updateThread.stop()
+
+ def setAutoResetZoom(self, reset):
+ """
+ Should we reset the zoom when adding an image (eq. when browsing)
+
+ :param bool reset:
+ """
+ self._autoResetZoom = reset
+ if self._autoResetZoom:
+ self._plot.resetZoom()
+
+ def isAutoResetZoom(self):
+ """
+
+ :return: True if a reset is done when the image change
+ :rtype: bool
+ """
+ return self._autoResetZoom
+
+ def setWaiting(self, activate=True):
+ if activate is True:
+ self._plot.clear()
+ self._waiting_label.show()
+ else:
+ self._waiting_label.hide()
+
+ def setData(self, data):
+ self.setWaiting(activate=False)
+ self._plot.addImage(data=data, resetzoom=self._autoResetZoom)
+
+ def clear(self):
+ self._plot.clear()
+ self.setWaiting(False)
+
+ def getPlotWidget(self):
+ return self._plot
+
+
+class _HorizontalSlider(HorizontalSliderWithBrowser):
+
+ sigCurrentUrlIndexChanged = qt.Signal(int)
+
+ def __init__(self, parent):
+ super(_HorizontalSlider, self).__init__(parent=parent)
+ # connect signal / slot
+ self.valueChanged.connect(self._urlChanged)
+
+ def setUrlIndex(self, index):
+ self.setValue(index)
+ self.sigCurrentUrlIndexChanged.emit(index)
+
+ def _urlChanged(self, value):
+ self.sigCurrentUrlIndexChanged.emit(value)
+
+
+class UrlList(qt.QWidget):
+ """List of URLs the user to select an URL"""
+
+ sigCurrentUrlChanged = qt.Signal(str)
+ """Signal emitted when the active/current url change"""
+
+ def __init__(self, parent=None):
+ super(UrlList, self).__init__(parent)
+ self.setLayout(qt.QVBoxLayout())
+ self.layout().setSpacing(0)
+ self.layout().setContentsMargins(0, 0, 0, 0)
+ self._listWidget = qt.QListWidget(parent=self)
+ self.layout().addWidget(self._listWidget)
+
+ # connect signal / Slot
+ self._listWidget.currentItemChanged.connect(self._notifyCurrentUrlChanged)
+
+ # expose API
+ self.currentItem = self._listWidget.currentItem
+
+ def setUrls(self, urls: list) -> None:
+ url_names = []
+ [url_names.append(url.path()) for url in urls]
+ self._listWidget.addItems(url_names)
+
+ def _notifyCurrentUrlChanged(self, current, previous):
+ if current is None:
+ pass
+ else:
+ self.sigCurrentUrlChanged.emit(current.text())
+
+ def setUrl(self, url: DataUrl) -> None:
+ assert isinstance(url, DataUrl)
+ sel_items = self._listWidget.findItems(url.path(), qt.Qt.MatchExactly)
+ if sel_items is None:
+ _logger.warning(url.path(), ' is not registered in the list.')
+ elif len(sel_items) > 0:
+ item = sel_items[0]
+ self._listWidget.setCurrentItem(item)
+ self.sigCurrentUrlChanged.emit(item.text())
+
+ def clear(self):
+ self._listWidget.clear()
+
+
+class _ToggleableUrlSelectionTable(qt.QWidget):
+
+ _BUTTON_ICON = qt.QStyle.SP_ToolBarHorizontalExtensionButton # noqa
+
+ sigCurrentUrlChanged = qt.Signal(str)
+ """Signal emitted when the active/current url change"""
+
+ def __init__(self, parent=None) -> None:
+ qt.QWidget.__init__(self, parent)
+ self.setLayout(qt.QGridLayout())
+ self._toggleButton = qt.QPushButton(parent=self)
+ self.layout().addWidget(self._toggleButton, 0, 2, 1, 1)
+ self._toggleButton.setSizePolicy(qt.QSizePolicy.Fixed,
+ qt.QSizePolicy.Fixed)
+
+ self._urlsTable = UrlList(parent=self)
+ self.layout().addWidget(self._urlsTable, 1, 1, 1, 2)
+
+ # set up
+ self._setButtonIcon(show=True)
+
+ # Signal / slot connection
+ self._toggleButton.clicked.connect(self.toggleUrlSelectionTable)
+ self._urlsTable.sigCurrentUrlChanged.connect(self._propagateSignal)
+
+ # expose API
+ self.setUrls = self._urlsTable.setUrls
+ self.setUrl = self._urlsTable.setUrl
+ self.currentItem = self._urlsTable.currentItem
+
+ def toggleUrlSelectionTable(self):
+ visible = not self.urlSelectionTableIsVisible()
+ self._setButtonIcon(show=visible)
+ self._urlsTable.setVisible(visible)
+
+ def _setButtonIcon(self, show):
+ style = qt.QApplication.instance().style()
+ # return a QIcon
+ icon = style.standardIcon(self._BUTTON_ICON)
+ if show is False:
+ pixmap = icon.pixmap(32, 32).transformed(qt.QTransform().scale(-1, 1))
+ icon = qt.QIcon(pixmap)
+ self._toggleButton.setIcon(icon)
+
+ def urlSelectionTableIsVisible(self):
+ return self._urlsTable.isVisible()
+
+ def _propagateSignal(self, url):
+ self.sigCurrentUrlChanged.emit(url)
+
+ def clear(self):
+ self._urlsTable.clear()
+
+
+class UrlLoader(qt.QThread):
+ """
+ Thread use to load DataUrl
+ """
+ def __init__(self, parent, url):
+ super(UrlLoader, self).__init__(parent=parent)
+ assert isinstance(url, DataUrl)
+ self.url = url
+ self.data = None
+
+ def run(self):
+ try:
+ self.data = get_data(self.url)
+ except IOError:
+ self.data = None
+
+
+class ImageStack(qt.QMainWindow):
+ """Widget loading on the fly images contained the given urls.
+
+ It prefetches images close to the displayed one.
+ """
+
+ N_PRELOAD = 10
+
+ sigLoaded = qt.Signal(str)
+ """Signal emitted when new data is available"""
+
+ sigCurrentUrlChanged = qt.Signal(str)
+ """Signal emitted when the current url change"""
+
+ def __init__(self, parent=None) -> None:
+ super(ImageStack, self).__init__(parent)
+ self.__n_prefetch = ImageStack.N_PRELOAD
+ self._loadingThreads = []
+ self.setWindowFlags(qt.Qt.Widget)
+ self._current_url = None
+ self._url_loader = UrlLoader
+ "class to instantiate for loading urls"
+
+ # main widget
+ self._plot = _PlotWithWaitingLabel(parent=self)
+ self._plot.setAttribute(qt.Qt.WA_DeleteOnClose, True)
+ self.setWindowTitle("Image stack")
+ self.setCentralWidget(self._plot)
+
+ # dock widget: url table
+ self._tableDockWidget = qt.QDockWidget(parent=self)
+ self._urlsTable = _ToggleableUrlSelectionTable(parent=self)
+ self._tableDockWidget.setWidget(self._urlsTable)
+ self._tableDockWidget.setFeatures(qt.QDockWidget.DockWidgetMovable)
+ self.addDockWidget(qt.Qt.RightDockWidgetArea, self._tableDockWidget)
+ # dock widget: qslider
+ self._sliderDockWidget = qt.QDockWidget(parent=self)
+ self._slider = _HorizontalSlider(parent=self)
+ self._sliderDockWidget.setWidget(self._slider)
+ self.addDockWidget(qt.Qt.BottomDockWidgetArea, self._sliderDockWidget)
+ self._sliderDockWidget.setFeatures(qt.QDockWidget.DockWidgetMovable)
+
+ self.reset()
+
+ # connect signal / slot
+ self._urlsTable.sigCurrentUrlChanged.connect(self.setCurrentUrl)
+ self._slider.sigCurrentUrlIndexChanged.connect(self.setCurrentUrlIndex)
+
+ def close(self) -> bool:
+ self._freeLoadingThreads()
+ self._plot.close()
+ super(ImageStack, self).close()
+
+ def setUrlLoaderClass(self, urlLoader: typing.Type[UrlLoader]) -> None:
+ """
+
+ :param urlLoader: define the class to call for loading urls.
+ warning: this should be a class object and not a
+ class instance.
+ """
+ assert isinstance(urlLoader, type(UrlLoader))
+ self._url_loader = urlLoader
+
+ def getUrlLoaderClass(self):
+ """
+
+ :return: class to instantiate for loading urls
+ :rtype: typing.Type[UrlLoader]
+ """
+ return self._url_loader
+
+ def _freeLoadingThreads(self):
+ for thread in self._loadingThreads:
+ thread.blockSignals(True)
+ thread.wait(5)
+ self._loadingThreads.clear()
+
+ def getPlotWidget(self) -> Plot2D:
+ """
+ Returns the PlotWidget contained in this window
+
+ :return: PlotWidget contained in this window
+ :rtype: Plot2D
+ """
+ return self._plot.getPlotWidget()
+
+ def reset(self) -> None:
+ """Clear the plot and remove any link to url"""
+ self._freeLoadingThreads()
+ self._urls = None
+ self._urlIndexes = None
+ self._urlData = OrderedDict({})
+ self._current_url = None
+ self._plot.clear()
+ self._urlsTable.clear()
+ self._slider.setMaximum(-1)
+
+ def _preFetch(self, urls: list) -> None:
+ """Pre-fetch the given urls if necessary
+
+ :param urls: list of DataUrl to prefetch
+ :type: list
+ """
+ for url in urls:
+ if url.path() not in self._urlData:
+ self._load(url)
+
+ def _load(self, url):
+ """
+ Launch background load of a DataUrl
+
+ :param url:
+ :type: DataUrl
+ """
+ assert isinstance(url, DataUrl)
+ url_path = url.path()
+ assert url_path in self._urlIndexes
+ loader = self._url_loader(parent=self, url=url)
+ loader.finished.connect(self._urlLoaded, qt.Qt.QueuedConnection)
+ self._loadingThreads.append(loader)
+ loader.start()
+
+ def _urlLoaded(self) -> None:
+ """
+
+ :param url: restul of DataUrl.path() function
+ :return:
+ """
+ sender = self.sender()
+ assert isinstance(sender, UrlLoader)
+ url = sender.url.path()
+ if url in self._urlIndexes:
+ self._urlData[url] = sender.data
+ if self.getCurrentUrl().path() == url:
+ self._plot.setData(self._urlData[url])
+ if sender in self._loadingThreads:
+ self._loadingThreads.remove(sender)
+ self.sigLoaded.emit(url)
+
+ def setNPrefetch(self, n: int) -> None:
+ """
+ Define the number of url to prefetch around
+
+ :param int n: number of url to prefetch on left and right sides.
+ In total n*2 DataUrl will be prefetch
+ """
+ self.__n_prefetch = n
+ current_url = self.getCurrentUrl()
+ if current_url is not None:
+ self.setCurrentUrl(current_url)
+
+ def getNPrefetch(self) -> int:
+ """
+
+ :return: number of url to prefetch on left and right sides. In total
+ will load 2* NPrefetch DataUrls
+ """
+ return self.__n_prefetch
+
+ def setUrls(self, urls: list) -> None:
+ """list of urls within an index. Warning: urls should contain an image
+ compatible with the silx.gui.plot.Plot class
+
+ :param urls: urls we want to set in the stack. Key is the index
+ (position in the stack), value is the DataUrl
+ :type: list
+ """
+ def createUrlIndexes():
+ indexes = OrderedDict()
+ for index, url in enumerate(urls):
+ indexes[index] = url
+ return indexes
+
+ urls_with_indexes = createUrlIndexes()
+ urlsToIndex = self._urlsToIndex(urls_with_indexes)
+ self.reset()
+ self._urls = urls_with_indexes
+ self._urlIndexes = urlsToIndex
+
+ old_url_table = self._urlsTable.blockSignals(True)
+ self._urlsTable.setUrls(urls=list(self._urls.values()))
+ self._urlsTable.blockSignals(old_url_table)
+
+ old_slider = self._slider.blockSignals(True)
+ self._slider.setMinimum(0)
+ self._slider.setMaximum(len(self._urls) - 1)
+ self._slider.blockSignals(old_slider)
+
+ if self.getCurrentUrl() in self._urls:
+ self.setCurrentUrl(self.getCurrentUrl())
+ else:
+ if len(self._urls.keys()) > 0:
+ first_url = self._urls[list(self._urls.keys())[0]]
+ self.setCurrentUrl(first_url)
+
+ def getUrls(self) -> tuple:
+ """
+
+ :return: tuple of urls
+ :rtype: tuple
+ """
+ return tuple(self._urlIndexes.keys())
+
+ def _getNextUrl(self, url: DataUrl) -> typing.Union[None, DataUrl]:
+ """
+ return the next url in the stack
+
+ :param url: url for which we want the next url
+ :type: DataUrl
+ :return: next url in the stack or None if `url` is the last one
+ :rtype: Union[None, DataUrl]
+ """
+ assert isinstance(url, DataUrl)
+ if self._urls is None:
+ return None
+ else:
+ index = self._urlIndexes[url.path()]
+ indexes = list(self._urls.keys())
+ res = list(filter(lambda x: x > index, indexes))
+ if len(res) == 0:
+ return None
+ else:
+ return self._urls[res[0]]
+
+ def _getPreviousUrl(self, url: DataUrl) -> typing.Union[None, DataUrl]:
+ """
+ return the previous url in the stack
+
+ :param url: url for which we want the previous url
+ :type: DataUrl
+ :return: next url in the stack or None if `url` is the last one
+ :rtype: Union[None, DataUrl]
+ """
+ if self._urls is None:
+ return None
+ else:
+ index = self._urlIndexes[url.path()]
+ indexes = list(self._urls.keys())
+ res = list(filter(lambda x: x < index, indexes))
+ if len(res) == 0:
+ return None
+ else:
+ return self._urls[res[-1]]
+
+ def _getNNextUrls(self, n: int, url: DataUrl) -> list:
+ """
+ Deduce the next urls in the stack after `url`
+
+ :param n: the number of url store after `url`
+ :type: int
+ :param url: url for which we want n next url
+ :type: DataUrl
+ :return: list of next urls.
+ :rtype: list
+ """
+ res = []
+ next_free = self._getNextUrl(url=url)
+ while len(res) < n and next_free is not None:
+ assert isinstance(next_free, DataUrl)
+ res.append(next_free)
+ next_free = self._getNextUrl(res[-1])
+ return res
+
+ def _getNPreviousUrls(self, n: int, url: DataUrl):
+ """
+ Deduce the previous urls in the stack after `url`
+
+ :param n: the number of url store after `url`
+ :type: int
+ :param url: url for which we want n previous url
+ :type: DataUrl
+ :return: list of previous urls.
+ :rtype: list
+ """
+ res = []
+ next_free = self._getPreviousUrl(url=url)
+ while len(res) < n and next_free is not None:
+ res.insert(0, next_free)
+ next_free = self._getPreviousUrl(res[0])
+ return res
+
+ def setCurrentUrlIndex(self, index: int):
+ """
+ Define the url to be displayed
+
+ :param index: url to be displayed
+ :type: int
+ """
+ if index < 0:
+ return
+ if self._urls is None:
+ return
+ elif index >= len(self._urls):
+ raise ValueError('requested index out of bounds')
+ else:
+ return self.setCurrentUrl(self._urls[index])
+
+ def setCurrentUrl(self, url: typing.Union[DataUrl, str]) -> None:
+ """
+ Define the url to be displayed
+
+ :param url: url to be displayed
+ :type: DataUrl
+ """
+ assert isinstance(url, (DataUrl, str))
+ if isinstance(url, str):
+ url = DataUrl(path=url)
+ if url != self._current_url:
+ self._current_url = url
+ self.sigCurrentUrlChanged.emit(url.path())
+
+ old_url_table = self._urlsTable.blockSignals(True)
+ old_slider = self._slider.blockSignals(True)
+
+ self._urlsTable.setUrl(url)
+ self._slider.setUrlIndex(self._urlIndexes[url.path()])
+ if self._current_url is None:
+ self._plot.clear()
+ else:
+ if self._current_url.path() in self._urlData:
+ self._plot.setData(self._urlData[url.path()])
+ else:
+ self._load(url)
+ self._notifyLoading()
+ self._preFetch(self._getNNextUrls(self.__n_prefetch, url))
+ self._preFetch(self._getNPreviousUrls(self.__n_prefetch, url))
+ self._urlsTable.blockSignals(old_url_table)
+ self._slider.blockSignals(old_slider)
+
+ def getCurrentUrl(self) -> typing.Union[None, DataUrl]:
+ """
+
+ :return: url currently displayed
+ :rtype: Union[None, DataUrl]
+ """
+ return self._current_url
+
+ def getCurrentUrlIndex(self) -> typing.Union[None, int]:
+ """
+
+ :return: index of the url currently displayed
+ :rtype: Union[None, int]
+ """
+ if self._current_url is None:
+ return None
+ else:
+ return self._urlIndexes[self._current_url.path()]
+
+ @staticmethod
+ def _urlsToIndex(urls):
+ """util, return a dictionary with url as key and index as value"""
+ res = {}
+ for index, url in urls.items():
+ res[url.path()] = index
+ return res
+
+ def _notifyLoading(self):
+ """display a simple image of loading..."""
+ self._plot.setWaiting(activate=True)
+
+ def setAutoResetZoom(self, reset):
+ """
+ Should we reset the zoom when adding an image (eq. when browsing)
+
+ :param bool reset:
+ """
+ self._plot.setAutoResetZoom(reset)
+
+ def isAutoResetZoom(self) -> bool:
+ """
+
+ :return: True if a reset is done when the image change
+ :rtype: bool
+ """
+ return self._plot.isAutoResetZoom()
diff --git a/src/silx/gui/plot/ImageView.py b/src/silx/gui/plot/ImageView.py
new file mode 100644
index 0000000..f8b830a
--- /dev/null
+++ b/src/silx/gui/plot/ImageView.py
@@ -0,0 +1,1057 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""QWidget displaying a 2D image with histograms on its sides.
+
+The :class:`ImageView` implements this widget, and
+:class:`ImageViewMainWindow` provides a main window with additional toolbar
+and status bar.
+
+Basic usage of :class:`ImageView` is through the following methods:
+
+- :meth:`ImageView.getColormap`, :meth:`ImageView.setColormap` to update the
+ default colormap to use and update the currently displayed image.
+- :meth:`ImageView.setImage` to update the displayed image.
+
+For an example of use, see `imageview.py` in :ref:`sample-code`.
+"""
+
+from __future__ import division
+
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "26/04/2018"
+
+
+import logging
+import numpy
+import collections
+from typing import Union
+
+import silx
+from .. import qt
+from .. import colors
+from .. import icons
+
+from . import items, PlotWindow, PlotWidget, actions
+from ..colors import Colormap
+from ..colors import cursorColorForColormap
+from .tools import LimitsToolBar
+from .Profile import ProfileToolBar
+from ...utils.proxy import docstring
+from ...utils.deprecation import deprecated
+from ...utils.enum import Enum
+from .tools.RadarView import RadarView
+from .utils.axis import SyncAxes
+from ..utils import blockSignals
+from . import _utils
+from .tools.profile import manager
+from .tools.profile import rois
+from .actions import PlotAction
+
+_logger = logging.getLogger(__name__)
+
+
+ProfileSumResult = collections.namedtuple("ProfileResult",
+ ["dataXRange", "dataYRange",
+ 'histoH', 'histoHRange',
+ 'histoV', 'histoVRange',
+ "xCoords", "xData",
+ "yCoords", "yData"])
+
+
+def computeProfileSumOnRange(imageItem, xRange, yRange, cache=None):
+ """
+ Compute a full vertical and horizontal profile on an image item using a
+ a range in the plot referential.
+
+ Optionally takes a previous computed result to be able to skip the
+ computation.
+
+ :rtype: ProfileSumResult
+ """
+ data = imageItem.getValueData(copy=False)
+ origin = imageItem.getOrigin()
+ scale = imageItem.getScale()
+ height, width = data.shape
+
+ xMin, xMax = xRange
+ yMin, yMax = yRange
+
+ # Convert plot area limits to image coordinates
+ # and work in image coordinates (i.e., in pixels)
+ xMin = int((xMin - origin[0]) / scale[0])
+ xMax = int((xMax - origin[0]) / scale[0])
+ yMin = int((yMin - origin[1]) / scale[1])
+ yMax = int((yMax - origin[1]) / scale[1])
+
+ if (xMin >= width or xMax < 0 or
+ yMin >= height or yMax < 0):
+ return None
+
+ # The image is at least partly in the plot area
+ # Get the visible bounds in image coords (i.e., in pixels)
+ subsetXMin = 0 if xMin < 0 else xMin
+ subsetXMax = (width if xMax >= width else xMax) + 1
+ subsetYMin = 0 if yMin < 0 else yMin
+ subsetYMax = (height if yMax >= height else yMax) + 1
+
+ if cache is not None:
+ if ((subsetXMin, subsetXMax) == cache.dataXRange and
+ (subsetYMin, subsetYMax) == cache.dataYRange):
+ # The visible area of data is the same
+ return cache
+
+ # Rebuild histograms for visible area
+ visibleData = data[subsetYMin:subsetYMax,
+ subsetXMin:subsetXMax]
+ histoHVisibleData = numpy.nansum(visibleData, axis=0)
+ histoVVisibleData = numpy.nansum(visibleData, axis=1)
+ histoHMin = numpy.nanmin(histoHVisibleData)
+ histoHMax = numpy.nanmax(histoHVisibleData)
+ histoVMin = numpy.nanmin(histoVVisibleData)
+ histoVMax = numpy.nanmax(histoVVisibleData)
+
+ # Convert to histogram curve and update plots
+ # Taking into account origin and scale
+ coords = numpy.arange(2 * histoHVisibleData.size)
+ xCoords = (coords + 1) // 2 + subsetXMin
+ xCoords = origin[0] + scale[0] * xCoords
+ xData = numpy.take(histoHVisibleData, coords // 2)
+ coords = numpy.arange(2 * histoVVisibleData.size)
+ yCoords = (coords + 1) // 2 + subsetYMin
+ yCoords = origin[1] + scale[1] * yCoords
+ yData = numpy.take(histoVVisibleData, coords // 2)
+
+ result = ProfileSumResult(
+ dataXRange=(subsetXMin, subsetXMax),
+ dataYRange=(subsetYMin, subsetYMax),
+ histoH=histoHVisibleData,
+ histoHRange=(histoHMin, histoHMax),
+ histoV=histoVVisibleData,
+ histoVRange=(histoVMin, histoVMax),
+ xCoords=xCoords,
+ xData=xData,
+ yCoords=yCoords,
+ yData=yData)
+
+ return result
+
+
+class _SideHistogram(PlotWidget):
+ """
+ Widget displaying one of the side profile of the ImageView.
+
+ Implement ProfileWindow
+ """
+
+ sigClose = qt.Signal()
+
+ sigMouseMoved = qt.Signal(float, float)
+
+ def __init__(self, parent=None, backend=None, direction=qt.Qt.Horizontal):
+ super(_SideHistogram, self).__init__(parent=parent, backend=backend)
+ self._direction = direction
+ self.sigPlotSignal.connect(self._plotEvents)
+ self._color = "blue"
+ self.__profile = None
+ self.__profileSum = None
+
+ def _plotEvents(self, eventDict):
+ """Callback for horizontal histogram plot events."""
+ if eventDict['event'] == 'mouseMoved':
+ self.sigMouseMoved.emit(eventDict['x'], eventDict['y'])
+
+ def setProfileColor(self, color):
+ self._color = color
+
+ def setProfileSum(self, result):
+ self.__profileSum = result
+ if self.__profile is None:
+ self.__drawProfileSum()
+
+ def prepareWidget(self, roi):
+ """Implements `ProfileWindow`"""
+ pass
+
+ def setRoiProfile(self, roi):
+ """Implements `ProfileWindow`"""
+ if roi is None:
+ return
+ self._roiColor = colors.rgba(roi.getColor())
+
+ def getProfile(self):
+ """Implements `ProfileWindow`"""
+ return self.__profile
+
+ def setProfile(self, data):
+ """Implements `ProfileWindow`"""
+ self.__profile = data
+ if data is None:
+ self.__drawProfileSum()
+ else:
+ self.__drawProfile()
+
+ def __drawProfileSum(self):
+ """Only draw the profile sum on the plot.
+
+ Other elements are removed
+ """
+ profileSum = self.__profileSum
+
+ try:
+ self.removeCurve('profile')
+ except Exception:
+ pass
+
+ if profileSum is None:
+ try:
+ self.removeCurve('profilesum')
+ except Exception:
+ pass
+ return
+
+ if self._direction == qt.Qt.Horizontal:
+ xx, yy = profileSum.xCoords, profileSum.xData
+ elif self._direction == qt.Qt.Vertical:
+ xx, yy = profileSum.yData, profileSum.yCoords
+ else:
+ assert False
+
+ self.addCurve(xx, yy,
+ xlabel='', ylabel='',
+ legend="profilesum",
+ color=self._color,
+ linestyle='-',
+ selectable=False,
+ resetzoom=False)
+
+ self.__updateLimits()
+
+ def __drawProfile(self):
+ """Only draw the profile on the plot.
+
+ Other elements are removed
+ """
+ profile = self.__profile
+
+ try:
+ self.removeCurve('profilesum')
+ except Exception:
+ pass
+
+ if profile is None:
+ try:
+ self.removeCurve('profile')
+ except Exception:
+ pass
+ self.setProfileSum(self.__profileSum)
+ return
+
+ if self._direction == qt.Qt.Horizontal:
+ xx, yy = profile.coords, profile.profile
+ elif self._direction == qt.Qt.Vertical:
+ xx, yy = profile.profile, profile.coords
+ else:
+ assert False
+
+ self.addCurve(xx,
+ yy,
+ legend="profile",
+ color=self._roiColor,
+ resetzoom=False)
+
+ self.__updateLimits()
+
+ def __updateLimits(self):
+ if self.__profile:
+ data = self.__profile.profile
+ vMin = numpy.nanmin(data)
+ vMax = numpy.nanmax(data)
+ elif self.__profileSum is not None:
+ if self._direction == qt.Qt.Horizontal:
+ vMin, vMax = self.__profileSum.histoHRange
+ elif self._direction == qt.Qt.Vertical:
+ vMin, vMax = self.__profileSum.histoVRange
+ else:
+ assert False
+ else:
+ vMin, vMax = 0, 0
+
+ # Tune the result using the data margins
+ margins = self.getDataMargins()
+ if self._direction == qt.Qt.Horizontal:
+ _, _, vMin, vMax = _utils.addMarginsToLimits(margins, False, False, 0, 0, vMin, vMax)
+ elif self._direction == qt.Qt.Vertical:
+ vMin, vMax, _, _ = _utils.addMarginsToLimits(margins, False, False, vMin, vMax, 0, 0)
+ else:
+ assert False
+
+ if self._direction == qt.Qt.Horizontal:
+ dataAxis = self.getYAxis()
+ elif self._direction == qt.Qt.Vertical:
+ dataAxis = self.getXAxis()
+ else:
+ assert False
+
+ with blockSignals(dataAxis):
+ dataAxis.setLimits(vMin, vMax)
+
+
+class ShowSideHistogramsAction(PlotAction):
+ """QAction to change visibility of side histogram of a :class:`.ImageView`.
+
+ :param plot: :class:`.ImageView` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ super(ShowSideHistogramsAction, self).__init__(
+ plot, icon='side-histograms', text='Show/hide side histograms',
+ tooltip='Show/hide side histogram',
+ triggered=self._actionTriggered,
+ checkable=True, parent=parent)
+
+ def _actionTriggered(self, checked=False):
+ if self.plot.isSideHistogramDisplayed() != checked:
+ self.plot.setSideHistogramDisplayed(checked)
+
+
+class AggregationModeAction(qt.QWidgetAction):
+ """Action providing few filters to the image"""
+
+ sigAggregationModeChanged = qt.Signal()
+
+ def __init__(self, parent):
+ qt.QWidgetAction.__init__(self, parent)
+
+ toolButton = qt.QToolButton(parent)
+
+ filterAction = qt.QAction(self)
+ filterAction.setText("No filter")
+ filterAction.setCheckable(True)
+ filterAction.setChecked(True)
+ filterAction.setProperty("aggregation", items.ImageDataAggregated.Aggregation.NONE)
+ densityNoFilterAction = filterAction
+
+ filterAction = qt.QAction(self)
+ filterAction.setText("Max filter")
+ filterAction.setCheckable(True)
+ filterAction.setProperty("aggregation", items.ImageDataAggregated.Aggregation.MAX)
+ densityMaxFilterAction = filterAction
+
+ filterAction = qt.QAction(self)
+ filterAction.setText("Mean filter")
+ filterAction.setCheckable(True)
+ filterAction.setProperty("aggregation", items.ImageDataAggregated.Aggregation.MEAN)
+ densityMeanFilterAction = filterAction
+
+ filterAction = qt.QAction(self)
+ filterAction.setText("Min filter")
+ filterAction.setCheckable(True)
+ filterAction.setProperty("aggregation", items.ImageDataAggregated.Aggregation.MIN)
+ densityMinFilterAction = filterAction
+
+ densityGroup = qt.QActionGroup(self)
+ densityGroup.setExclusive(True)
+ densityGroup.addAction(densityNoFilterAction)
+ densityGroup.addAction(densityMaxFilterAction)
+ densityGroup.addAction(densityMeanFilterAction)
+ densityGroup.addAction(densityMinFilterAction)
+ densityGroup.triggered.connect(self._aggregationModeChanged)
+ self.__densityGroup = densityGroup
+
+ filterMenu = qt.QMenu(toolButton)
+ filterMenu.addAction(densityNoFilterAction)
+ filterMenu.addAction(densityMaxFilterAction)
+ filterMenu.addAction(densityMeanFilterAction)
+ filterMenu.addAction(densityMinFilterAction)
+
+ toolButton.setPopupMode(qt.QToolButton.InstantPopup)
+ toolButton.setMenu(filterMenu)
+ toolButton.setText("Data filters")
+ toolButton.setToolTip("Enable/disable filter on the image")
+ icon = icons.getQIcon("aggregation-mode")
+ toolButton.setIcon(icon)
+ toolButton.setText("Pixel aggregation filter")
+
+ self.setDefaultWidget(toolButton)
+
+ def _aggregationModeChanged(self):
+ self.sigAggregationModeChanged.emit()
+
+ def setAggregationMode(self, mode):
+ """Set an Aggregated enum from ImageDataAggregated"""
+ for a in self.__densityGroup.actions():
+ if a.property("aggregation") is mode:
+ a.setChecked(True)
+
+ def getAggregationMode(self):
+ """Returns an Aggregated enum from ImageDataAggregated"""
+ densityAction = self.__densityGroup.checkedAction()
+ if densityAction is None:
+ return items.ImageDataAggregated.Aggregation.NONE
+ return densityAction.property("aggregation")
+
+
+class ImageView(PlotWindow):
+ """Display a single image with horizontal and vertical histograms.
+
+ Use :meth:`setImage` to control the displayed image.
+ This class also provides the :class:`silx.gui.plot.Plot` API.
+
+ The :class:`ImageView` inherits from :class:`.PlotWindow` (which provides
+ the toolbars) and also exposes :class:`.PlotWidget` API for further
+ plot control (plot title, axes labels, aspect ratio, ...).
+
+ :param parent: The parent of this widget or None.
+ :param backend: The backend to use for the plot (default: matplotlib).
+ See :class:`.PlotWidget` for the list of supported backend.
+ :type backend: str or :class:`BackendBase.BackendBase`
+ """
+
+ HISTOGRAMS_COLOR = 'blue'
+ """Color to use for the side histograms."""
+
+ HISTOGRAMS_HEIGHT = 200
+ """Height in pixels of the side histograms."""
+
+ IMAGE_MIN_SIZE = 200
+ """Minimum size in pixels of the image area."""
+
+ # Qt signals
+ valueChanged = qt.Signal(float, float, object)
+ """Signals that the data value under the cursor has changed.
+
+ It provides: row, column, data value.
+
+ When the cursor is over an histogram, either row or column is Nan
+ and the provided data value is the histogram value
+ (i.e., the sum along the corresponding row/column).
+ Row and columns are either Nan or integer values.
+ """
+
+ class ProfileWindowBehavior(Enum):
+ """ImageView's profile window behavior options"""
+
+ POPUP = 'popup'
+ """All profiles are displayed in pop-up windows"""
+
+ EMBEDDED = 'embedded'
+ """Horizontal, vertical and cross profiles are displayed in
+ sides widgets, others are displayed in pop-up windows.
+ """
+
+ def __init__(self, parent=None, backend=None):
+ self._imageLegend = '__ImageView__image' + str(id(self))
+ self._cache = None # Store currently visible data information
+
+ super(ImageView, self).__init__(parent=parent, backend=backend,
+ resetzoom=True, autoScale=False,
+ logScale=False, grid=False,
+ curveStyle=False, colormap=True,
+ aspectRatio=True, yInverted=True,
+ copy=True, save=True, print_=True,
+ control=False, position=False,
+ roi=False, mask=True)
+
+ # Enable mask synchronisation to use it in profiles
+ maskToolsWidget = self.getMaskToolsDockWidget().widget()
+ maskToolsWidget.setItemMaskUpdated(True)
+
+ self.__showSideHistogramsAction = ShowSideHistogramsAction(self, self)
+ self.__showSideHistogramsAction.setChecked(True)
+
+ self.__aggregationModeAction = AggregationModeAction(self)
+ self.__aggregationModeAction.sigAggregationModeChanged.connect(self._aggregationModeChanged)
+
+ if parent is None:
+ self.setWindowTitle('ImageView')
+
+ if silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION == 'downward':
+ self.getYAxis().setInverted(True)
+
+ self._initWidgets(backend)
+
+ toolBar = self.toolBar()
+ toolBar.addAction(self.__showSideHistogramsAction)
+ toolBar.addAction(self.__aggregationModeAction)
+
+ self.__profileWindowBehavior = self.ProfileWindowBehavior.POPUP
+ self.__profile = ProfileToolBar(plot=self)
+ self.addToolBar(self.__profile)
+
+ def _initWidgets(self, backend):
+ """Set-up layout and plots."""
+ self._histoHPlot = _SideHistogram(backend=backend, parent=self, direction=qt.Qt.Horizontal)
+ widgetHandle = self._histoHPlot.getWidgetHandle()
+ widgetHandle.setMinimumHeight(self.HISTOGRAMS_HEIGHT)
+ widgetHandle.setMaximumHeight(self.HISTOGRAMS_HEIGHT)
+ self._histoHPlot.setInteractiveMode('zoom')
+ self._histoHPlot.setDataMargins(0., 0., 0.1, 0.1)
+ self._histoHPlot.sigMouseMoved.connect(self._mouseMovedOnHistoH)
+ self._histoHPlot.setProfileColor(self.HISTOGRAMS_COLOR)
+
+ self._histoVPlot = _SideHistogram(backend=backend, parent=self, direction=qt.Qt.Vertical)
+ widgetHandle = self._histoVPlot.getWidgetHandle()
+ widgetHandle.setMinimumWidth(self.HISTOGRAMS_HEIGHT)
+ widgetHandle.setMaximumWidth(self.HISTOGRAMS_HEIGHT)
+ self._histoVPlot.setInteractiveMode('zoom')
+ self._histoVPlot.setDataMargins(0.1, 0.1, 0., 0.)
+ self._histoVPlot.sigMouseMoved.connect(self._mouseMovedOnHistoV)
+ self._histoVPlot.setProfileColor(self.HISTOGRAMS_COLOR)
+
+ self.setPanWithArrowKeys(True)
+ self.setInteractiveMode('zoom') # Color set in setColormap
+ self.sigPlotSignal.connect(self._imagePlotCB)
+ self.sigActiveImageChanged.connect(self._activeImageChangedSlot)
+
+ self._radarView = RadarView(parent=self)
+ self._radarView.setPlotWidget(self)
+
+ self.__syncXAxis = SyncAxes([self.getXAxis(), self._histoHPlot.getXAxis()])
+ self.__syncYAxis = SyncAxes([self.getYAxis(), self._histoVPlot.getYAxis()])
+
+ self.__setCentralWidget()
+
+ def __setCentralWidget(self):
+ """Set central widget with all its content"""
+ layout = qt.QGridLayout()
+ layout.addWidget(self.getWidgetHandle(), 0, 0)
+ layout.addWidget(self._histoVPlot, 0, 1)
+ layout.addWidget(self._histoHPlot, 1, 0)
+ layout.addWidget(self._radarView, 1, 1, 1, 2)
+ layout.addWidget(self.getColorBarWidget(), 0, 2)
+
+ self._radarView.setMinimumWidth(self.IMAGE_MIN_SIZE)
+ self._radarView.setMinimumHeight(self.HISTOGRAMS_HEIGHT)
+ self._histoHPlot.setMinimumWidth(self.IMAGE_MIN_SIZE)
+ self._histoVPlot.setMinimumHeight(self.HISTOGRAMS_HEIGHT)
+
+ layout.setColumnStretch(0, 1)
+ layout.setColumnStretch(1, 0)
+ layout.setRowStretch(0, 1)
+ layout.setRowStretch(1, 0)
+
+ layout.setSpacing(0)
+ layout.setContentsMargins(0, 0, 0, 0)
+
+ centralWidget = qt.QWidget(self)
+ centralWidget.setLayout(layout)
+ self.setCentralWidget(centralWidget)
+
+ @docstring(PlotWidget)
+ def setBackend(self, backend):
+ # Use PlotWidget here since we override PlotWindow behavior
+ PlotWidget.setBackend(self, backend)
+ self.__setCentralWidget()
+
+ def _dirtyCache(self):
+ self._cache = None
+
+ def getAggregationModeAction(self):
+ return self.__aggregationModeAction
+
+ def _aggregationModeChanged(self):
+ item = self._getItem("image", self._imageLegend)
+ if item is None:
+ return
+ aggregationMode = self.__aggregationModeAction.getAggregationMode()
+ if aggregationMode is not None and isinstance(item, items.ImageDataAggregated):
+ item.setAggregationMode(aggregationMode)
+ else:
+ # It means the item type have to be changed
+ self.removeImage(self._imageLegend)
+ image = item.getData(copy=False)
+ if image is None:
+ return
+ origin = item.getOrigin()
+ scale = item.getScale()
+ self.setImage(image, origin, scale, copy=False, resetzoom=False)
+
+ def getShowSideHistogramsAction(self):
+ return self.__showSideHistogramsAction
+
+ def setSideHistogramDisplayed(self, show):
+ """Display or not the side histograms"""
+ if self.isSideHistogramDisplayed() == show:
+ return
+ self._histoHPlot.setVisible(show)
+ self._histoVPlot.setVisible(show)
+ self._radarView.setVisible(show)
+ self.__showSideHistogramsAction.setChecked(show)
+ if show:
+ # Probably have to be computed
+ self._updateHistograms()
+
+ def isSideHistogramDisplayed(self):
+ """True if the side histograms are displayed"""
+ return self._histoHPlot.isVisible()
+
+ def _updateHistograms(self):
+ """Update histograms content using current active image."""
+ if not self.isSideHistogramDisplayed():
+ # The histogram computation can be skipped
+ return
+
+ activeImage = self.getActiveImage()
+ if activeImage is not None:
+ xRange = self.getXAxis().getLimits()
+ yRange = self.getYAxis().getLimits()
+ result = computeProfileSumOnRange(activeImage, xRange, yRange, self._cache)
+ self._cache = result
+ self._histoHPlot.setProfileSum(result)
+ self._histoVPlot.setProfileSum(result)
+
+ # Plots event listeners
+
+ def _imagePlotCB(self, eventDict):
+ """Callback for imageView plot events."""
+ if eventDict['event'] == 'mouseMoved':
+ activeImage = self.getActiveImage()
+ if activeImage is not None:
+ data = activeImage.getData(copy=False)
+ height, width = data.shape[0:2]
+
+ # Get corresponding coordinate in image
+ origin = activeImage.getOrigin()
+ scale = activeImage.getScale()
+ if (eventDict['x'] >= origin[0] and
+ eventDict['y'] >= origin[1]):
+ x = int((eventDict['x'] - origin[0]) / scale[0])
+ y = int((eventDict['y'] - origin[1]) / scale[1])
+
+ if x >= 0 and x < width and y >= 0 and y < height:
+ self.valueChanged.emit(float(x), float(y),
+ data[y][x])
+
+ elif eventDict['event'] == 'limitsChanged':
+ self._updateHistograms()
+
+ def _mouseMovedOnHistoH(self, x, y):
+ if self._cache is None:
+ return
+ activeImage = self.getActiveImage()
+ if activeImage is None:
+ return
+
+ xOrigin = activeImage.getOrigin()[0]
+ xScale = activeImage.getScale()[0]
+
+ minValue = xOrigin + xScale * self._cache.dataXRange[0]
+
+ if x >= minValue:
+ data = self._cache.histoH
+ column = int((x - minValue) / xScale)
+ if column >= 0 and column < data.shape[0]:
+ self.valueChanged.emit(
+ float('nan'),
+ float(column + self._cache.dataXRange[0]),
+ data[column])
+
+ def _mouseMovedOnHistoV(self, x, y):
+ if self._cache is None:
+ return
+ activeImage = self.getActiveImage()
+ if activeImage is None:
+ return
+
+ yOrigin = activeImage.getOrigin()[1]
+ yScale = activeImage.getScale()[1]
+
+ minValue = yOrigin + yScale * self._cache.dataYRange[0]
+
+ if y >= minValue:
+ data = self._cache.histoV
+ row = int((y - minValue) / yScale)
+ if row >= 0 and row < data.shape[0]:
+ self.valueChanged.emit(
+ float(row + self._cache.dataYRange[0]),
+ float('nan'),
+ data[row])
+
+ def _activeImageChangedSlot(self, previous, legend):
+ """Handle Plot active image change.
+
+ Resets side histograms cache
+ """
+ self._dirtyCache()
+ self._updateHistograms()
+
+ def setProfileWindowBehavior(self, behavior: Union[str, ProfileWindowBehavior]):
+ """Set where profile widgets are displayed.
+
+ :param ProfileWindowBehavior behavior:
+ - 'popup': All profiles are displayed in pop-up windows
+ - 'embedded': Horizontal, vertical and cross profiles are displayed in
+ sides widgets, others are displayed in pop-up windows.
+ """
+ behavior = self.ProfileWindowBehavior.from_value(behavior)
+ if behavior is not self.getProfileWindowBehavior():
+ manager = self.__profile.getProfileManager()
+ manager.clearProfile()
+ manager.requestUpdateAllProfile()
+
+ if behavior is self.ProfileWindowBehavior.EMBEDDED:
+ horizontalProfileWindow = self._histoHPlot
+ verticalProfileWindow = self._histoVPlot
+ else:
+ horizontalProfileWindow = None
+ verticalProfileWindow = None
+
+ manager.setSpecializedProfileWindow(
+ rois.ProfileImageHorizontalLineROI, horizontalProfileWindow
+ )
+ manager.setSpecializedProfileWindow(
+ rois.ProfileImageVerticalLineROI, verticalProfileWindow
+ )
+ self.__profileWindowBehavior = behavior
+
+ def getProfileWindowBehavior(self) -> ProfileWindowBehavior:
+ """Returns current profile display behavior.
+
+ See :meth:`setProfileWindowBehavior` and :class:`ProfileWindowBehavior`
+ """
+ return self.__profileWindowBehavior
+
+ def getProfileToolBar(self):
+ """"Returns profile tools attached to this plot.
+
+ :rtype: silx.gui.plot.PlotTools.ProfileToolBar
+ """
+ return self.__profile
+
+ @property
+ @deprecated(replacement="getProfileToolBar()")
+ def profile(self):
+ return self.getProfileToolBar()
+
+ def getHistogram(self, axis):
+ """Return the histogram and corresponding row or column extent.
+
+ The returned value when an histogram is available is a dict with keys:
+
+ - 'data': numpy array of the histogram values.
+ - 'extent': (start, end) row or column index.
+ end index is not included in the histogram.
+
+ :param str axis: 'x' for horizontal, 'y' for vertical
+ :return: The histogram and its extent as a dict or None.
+ :rtype: dict
+ """
+ assert axis in ('x', 'y')
+ if self._cache is None:
+ return None
+ else:
+ if axis == 'x':
+ return dict(
+ data=numpy.array(self._cache.histoH, copy=True),
+ extent=self._cache.dataXRange)
+ else:
+ return dict(
+ data=numpy.array(self._cache.histoV, copy=True),
+ extent=(self._cache.dataYRange))
+
+ def radarView(self):
+ """Get the lower right radarView widget."""
+ return self._radarView
+
+ def setRadarView(self, radarView):
+ """Change the lower right radarView widget.
+
+ :param RadarView radarView: Widget subclassing RadarView to replace
+ the lower right corner widget.
+ """
+ self._radarView = radarView
+ self._radarView.setPlotWidget(self)
+ self.centralWidget().layout().addWidget(self._radarView, 1, 1)
+
+ # High-level API
+
+ def getColormap(self):
+ """Get the default colormap description.
+
+ :return: A description of the current colormap.
+ See :meth:`setColormap` for details.
+ :rtype: dict
+ """
+ return self.getDefaultColormap()
+
+ def setColormap(self, colormap=None, normalization=None,
+ autoscale=None, vmin=None, vmax=None, colors=None):
+ """Set the default colormap and update active image.
+
+ Parameters that are not provided are taken from the current colormap.
+
+ The colormap parameter can also be a dict with the following keys:
+
+ - *name*: string. The colormap to use:
+ 'gray', 'reversed gray', 'temperature', 'red', 'green', 'blue'.
+ - *normalization*: string. The mapping to use for the colormap:
+ either 'linear' or 'log'.
+ - *autoscale*: bool. Whether to use autoscale (True)
+ or range provided by keys 'vmin' and 'vmax' (False).
+ - *vmin*: float. The minimum value of the range to use if 'autoscale'
+ is False.
+ - *vmax*: float. The maximum value of the range to use if 'autoscale'
+ is False.
+ - *colors*: optional. Nx3 or Nx4 array of float in [0, 1] or uint8.
+ List of RGB or RGBA colors to use (only if name is None)
+
+ :param colormap: Name of the colormap in
+ 'gray', 'reversed gray', 'temperature', 'red', 'green', 'blue'.
+ Or the description of the colormap as a dict.
+ :type colormap: dict or str.
+ :param str normalization: Colormap mapping: 'linear' or 'log'.
+ :param bool autoscale: Whether to use autoscale (True)
+ or [vmin, vmax] range (False).
+ :param float vmin: The minimum value of the range to use if
+ 'autoscale' is False.
+ :param float vmax: The maximum value of the range to use if
+ 'autoscale' is False.
+ :param numpy.ndarray colors: Only used if name is None.
+ Custom colormap colors as Nx3 or Nx4 RGB or RGBA arrays
+ """
+ cmap = self.getDefaultColormap()
+
+ if isinstance(colormap, Colormap):
+ # Replace colormap
+ cmap = colormap
+
+ self.setDefaultColormap(cmap)
+
+ # Update active image colormap
+ activeImage = self.getActiveImage()
+ if isinstance(activeImage, items.ColormapMixIn):
+ activeImage.setColormap(cmap)
+
+ elif isinstance(colormap, dict):
+ # Support colormap parameter as a dict
+ assert normalization is None
+ assert autoscale is None
+ assert vmin is None
+ assert vmax is None
+ assert colors is None
+ cmap._setFromDict(colormap)
+
+ else:
+ if colormap is not None:
+ cmap.setName(colormap)
+ if normalization is not None:
+ cmap.setNormalization(normalization)
+ if autoscale:
+ cmap.setVRange(None, None)
+ else:
+ if vmin is not None:
+ cmap.setVMin(vmin)
+ if vmax is not None:
+ cmap.setVMax(vmax)
+ if colors is not None:
+ cmap.setColormapLUT(colors)
+
+ cursorColor = cursorColorForColormap(cmap.getName())
+ self.setInteractiveMode('zoom', color=cursorColor)
+
+ def setImage(self, image, origin=(0, 0), scale=(1., 1.),
+ copy=True, reset=None, resetzoom=True):
+ """Set the image to display.
+
+ :param image: A 2D array representing the image or None to empty plot.
+ :type image: numpy.ndarray-like with 2 dimensions or None.
+ :param origin: The (x, y) position of the origin of the image.
+ Default: (0, 0).
+ The origin is the lower left corner of the image when
+ the Y axis is not inverted.
+ :type origin: Tuple of 2 floats: (origin x, origin y).
+ :param scale: The scale factor to apply to the image on X and Y axes.
+ Default: (1, 1).
+ It is the size of a pixel in the coordinates of the axes.
+ Scales must be positive numbers.
+ :type scale: Tuple of 2 floats: (scale x, scale y).
+ :param bool copy: Whether to copy image data (default) or not.
+ :param bool reset: Deprecated. Alias for `resetzoom`.
+ :param bool resetzoom: Whether to reset zoom and ROI (default) or not.
+ """
+ self._dirtyCache()
+
+ if reset is not None:
+ resetzoom = reset
+
+ assert len(origin) == 2
+ assert len(scale) == 2
+ assert scale[0] > 0
+ assert scale[1] > 0
+
+ if image is None:
+ self.remove(self._imageLegend, kind='image')
+ return
+
+ data = numpy.array(image, order='C', copy=copy)
+ if data.size == 0:
+ self.remove(self._imageLegend, kind='image')
+ return
+
+ assert data.ndim == 2 or (data.ndim == 3 and data.shape[2] in (3, 4))
+
+ aggregation = self.getAggregationModeAction().getAggregationMode()
+ if data.ndim != 2 and aggregation is not None:
+ # RGB/A with aggregation is not supported
+ aggregation = items.ImageDataAggregated.Aggregation.NONE
+
+ if aggregation is items.ImageDataAggregated.Aggregation.NONE:
+ self.addImage(data,
+ legend=self._imageLegend,
+ origin=origin, scale=scale,
+ colormap=self.getColormap(),
+ resetzoom=False)
+ else:
+ item = self._getItem("image", self._imageLegend)
+ if isinstance(item, items.ImageDataAggregated):
+ item.setData(data)
+ item.setOrigin(origin)
+ item.setScale(scale)
+ else:
+ if isinstance(item, items.ImageDataAggregated):
+ imageItem = item
+ wasCreated = False
+ else:
+ if item is not None:
+ self.removeImage(self._imageLegend)
+ imageItem = items.ImageDataAggregated()
+ imageItem.setName(self._imageLegend)
+ imageItem.setColormap(self.getColormap())
+ wasCreated = True
+ imageItem.setData(data)
+ imageItem.setOrigin(origin)
+ imageItem.setScale(scale)
+ imageItem.setAggregationMode(aggregation)
+ if wasCreated:
+ self.addItem(imageItem)
+
+ self.setActiveImage(self._imageLegend)
+ self._updateHistograms()
+ if resetzoom:
+ self.resetZoom()
+
+
+# ImageViewMainWindow #########################################################
+
+class ImageViewMainWindow(ImageView):
+ """:class:`ImageView` with additional toolbars
+
+ Adds extra toolbar and a status bar to :class:`ImageView`.
+ """
+ def __init__(self, parent=None, backend=None):
+ self._dataInfo = None
+ super(ImageViewMainWindow, self).__init__(parent, backend)
+ self.setWindowFlags(qt.Qt.Window)
+
+ self.getXAxis().setLabel('X')
+ self.getYAxis().setLabel('Y')
+ self.setGraphTitle('Image')
+
+ # Add toolbars and status bar
+ self.addToolBar(qt.Qt.BottomToolBarArea, LimitsToolBar(plot=self))
+
+ menu = self.menuBar().addMenu('File')
+ menu.addAction(self.getOutputToolBar().getSaveAction())
+ menu.addAction(self.getOutputToolBar().getPrintAction())
+ menu.addSeparator()
+ action = menu.addAction('Quit')
+ action.triggered[bool].connect(qt.QApplication.instance().quit)
+
+ menu = self.menuBar().addMenu('Edit')
+ menu.addAction(self.getOutputToolBar().getCopyAction())
+ menu.addSeparator()
+ menu.addAction(self.getResetZoomAction())
+ menu.addAction(self.getColormapAction())
+ menu.addAction(actions.control.KeepAspectRatioAction(self, self))
+ menu.addAction(actions.control.YAxisInvertedAction(self, self))
+ menu.addAction(self.getShowSideHistogramsAction())
+
+ self.__profileMenu = self.menuBar().addMenu('Profile')
+ self.__updateProfileMenu()
+
+ # Connect to ImageView's signal
+ self.valueChanged.connect(self._statusBarSlot)
+
+ def __updateProfileMenu(self):
+ """Update actions available in 'Profile' menu"""
+ profile = self.getProfileToolBar()
+ self.__profileMenu.clear()
+ self.__profileMenu.addAction(profile.hLineAction)
+ self.__profileMenu.addAction(profile.vLineAction)
+ self.__profileMenu.addAction(profile.crossAction)
+ self.__profileMenu.addAction(profile.lineAction)
+ self.__profileMenu.addAction(profile.clearAction)
+
+ def _formatValueToString(self, value):
+ try:
+ if isinstance(value, numpy.ndarray):
+ if len(value) == 4:
+ return "RGBA: %.3g, %.3g, %.3g, %.3g" % (value[0], value[1], value[2], value[3])
+ elif len(value) == 3:
+ return "RGB: %.3g, %.3g, %.3g" % (value[0], value[1], value[2])
+ else:
+ return "Value: %g" % value
+ except Exception:
+ _logger.error("Error while formatting pixel value", exc_info=True)
+ pass
+ return "Value: %s" % value
+
+ def _statusBarSlot(self, row, column, value):
+ """Update status bar with coordinates/value from plots."""
+ if numpy.isnan(row):
+ msg = 'Column: %d, Sum: %g' % (int(column), value)
+ elif numpy.isnan(column):
+ msg = 'Row: %d, Sum: %g' % (int(row), value)
+ else:
+ msg_value = self._formatValueToString(value)
+ msg = 'Position: (%d, %d), %s' % (int(row), int(column), msg_value)
+ if self._dataInfo is not None:
+ msg = self._dataInfo + ', ' + msg
+
+ self.statusBar().showMessage(msg)
+
+ @docstring(ImageView)
+ def setProfileWindowBehavior(self, behavior: str):
+ super().setProfileWindowBehavior(behavior)
+ self.__updateProfileMenu()
+
+ @docstring(ImageView)
+ def setImage(self, image, *args, **kwargs):
+ if hasattr(image, 'dtype') and hasattr(image, 'shape'):
+ assert image.ndim == 2 or (image.ndim == 3 and image.shape[2] in (3, 4))
+ height, width = image.shape[0:2]
+ dataInfo = 'Data: %dx%d (%s)' % (width, height, str(image.dtype))
+ else:
+ dataInfo = None
+
+ if self._dataInfo != dataInfo:
+ self._dataInfo = dataInfo
+ self.statusBar().showMessage(self._dataInfo)
+
+ # Set the new image in ImageView widget
+ super(ImageViewMainWindow, self).setImage(image, *args, **kwargs)
diff --git a/src/silx/gui/plot/Interaction.py b/src/silx/gui/plot/Interaction.py
new file mode 100644
index 0000000..6213889
--- /dev/null
+++ b/src/silx/gui/plot/Interaction.py
@@ -0,0 +1,350 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2014-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides an implementation of state machines for interaction.
+
+Sample code of a state machine with two states ('idle' and 'active')
+with transitions on left button press/release:
+
+.. code-block:: python
+
+ from silx.gui.plot.Interaction import *
+
+ class SampleStateMachine(StateMachine):
+
+ class Idle(State):
+ def onPress(self, x, y, btn):
+ if btn == LEFT_BTN:
+ self.goto('active')
+
+ class Active(State):
+ def enterState(self):
+ print('Enabled') # Handle enter active state here
+
+ def leaveState(self):
+ print('Disabled') # Handle leave active state here
+
+ def onRelease(self, x, y, btn):
+ if btn == LEFT_BTN:
+ self.goto('idle')
+
+ def __init__(self):
+ # State machine has 2 states
+ states = {
+ 'idle': SampleStateMachine.Idle,
+ 'active': SampleStateMachine.Active
+ }
+ super(TwoStates, self).__init__(states, 'idle')
+ # idle is the initial state
+
+ stateMachine = SampleStateMachine()
+
+ # Triggers a transition to the Active state:
+ stateMachine.handleEvent('press', 0, 0, LEFT_BTN)
+
+ # Triggers a transition to the Idle state:
+ stateMachine.handleEvent('release', 0, 0, LEFT_BTN)
+
+See :class:`ClickOrDrag` for another example of a state machine.
+
+See `Renaud Blanch, Michel Beaudouin-Lafon.
+Programming Rich Interactions using the Hierarchical State Machine Toolkit.
+In Proceedings of AVI 2006. p 51-58.
+<http://iihm.imag.fr/en/publication/BB06a/>`_
+for a discussion of using (hierarchical) state machines for interaction.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "18/02/2016"
+
+
+import weakref
+
+
+# state machine ###############################################################
+
+class State(object):
+ """Base class for the states of a state machine.
+
+ This class is meant to be subclassed.
+ """
+
+ def __init__(self, machine):
+ """State instances should be created by the :class:`StateMachine`.
+
+ They are not intended to be used outside this context.
+
+ :param machine: The state machine instance this state belongs to.
+ :type machine: StateMachine
+ """
+ self._machineRef = weakref.ref(machine) # Prevent cyclic reference
+
+ @property
+ def machine(self):
+ """The state machine this state belongs to.
+
+ Useful to access data or methods that are shared across states.
+ """
+ machine = self._machineRef()
+ if machine is not None:
+ return machine
+ else:
+ raise RuntimeError("Associated StateMachine is not valid")
+
+ def goto(self, state, *args, **kwargs):
+ """Performs a transition to a new state.
+
+ Extra arguments are passed to the :meth:`enterState` method of the
+ new state.
+
+ :param str state: The name of the state to go to.
+ """
+ self.machine._goto(state, *args, **kwargs)
+
+ def enterState(self, *args, **kwargs):
+ """Called when the state machine enters this state.
+
+ Arguments are those provided to the :meth:`goto` method that
+ triggered the transition to this state.
+ """
+ pass
+
+ def leaveState(self):
+ """Called when the state machine leaves this state
+ (i.e., when :meth:`goto` is called).
+ """
+ pass
+
+ def validate(self):
+ """Called externally to validate the current interaction in case of a
+ creation.
+ """
+ pass
+
+class StateMachine(object):
+ """State machine controller.
+
+ This is the entry point of a state machine.
+ It is in charge of dispatching received event and handling the
+ current active state.
+ """
+
+ def __init__(self, states, initState, *args, **kwargs):
+ """Create a state machine controller with an initial state.
+
+ Extra arguments are passed to the :meth:`enterState` method
+ of the initState.
+
+ :param states: All states of the state machine
+ :type states: dict of: {str name: State subclass}
+ :param str initState: Key of the initial state in states
+ """
+ self.states = states
+
+ self.state = self.states[initState](self)
+ self.state.enterState(*args, **kwargs)
+
+ def _goto(self, state, *args, **kwargs):
+ self.state.leaveState()
+ self.state = self.states[state](self)
+ self.state.enterState(*args, **kwargs)
+
+ def handleEvent(self, eventName, *args, **kwargs):
+ """Process an event with the state machine.
+
+ This method looks up for an event handler in the current state
+ and then in the :class:`StateMachine` instance.
+ Handler are looked up as 'onEventName' method.
+ If a handler is found, it is called with the provided extra
+ arguments, and this method returns the return value of the
+ handler.
+ If no handler is found, this method returns None.
+
+ :param str eventName: Name of the event to handle
+ :returns: The return value of the handler or None
+ """
+ handlerName = 'on' + eventName[0].upper() + eventName[1:]
+ try:
+ handler = getattr(self.state, handlerName)
+ except AttributeError:
+ try:
+ handler = getattr(self, handlerName)
+ except AttributeError:
+ handler = None
+ if handler is not None:
+ return handler(*args, **kwargs)
+
+ def validate(self):
+ """Called externally to validate the current interaction in case of a
+ creation.
+ """
+ self.state.validate()
+
+
+# clickOrDrag #################################################################
+
+LEFT_BTN = 'left'
+"""Left mouse button."""
+
+RIGHT_BTN = 'right'
+"""Right mouse button."""
+
+MIDDLE_BTN = 'middle'
+"""Middle mouse button."""
+
+
+class ClickOrDrag(StateMachine):
+ """State machine for left and right click and left drag interaction.
+
+ It is intended to be used through subclassing by overriding
+ :meth:`click`, :meth:`beginDrag`, :meth:`drag` and :meth:`endDrag`.
+
+ :param Set[str] clickButtons: Set of buttons that provides click interaction
+ :param Set[str] dragButtons: Set of buttons that provides drag interaction
+ """
+
+ DRAG_THRESHOLD_SQUARE_DIST = 5 ** 2
+
+ class Idle(State):
+ def onPress(self, x, y, btn):
+ if btn in self.machine.dragButtons:
+ self.goto('clickOrDrag', x, y, btn)
+ return True
+ elif btn in self.machine.clickButtons:
+ self.goto('click', x, y, btn)
+ return True
+
+ class Click(State):
+ def enterState(self, x, y, btn):
+ self.initPos = x, y
+ self.button = btn
+
+ def onMove(self, x, y):
+ dx2 = (x - self.initPos[0]) ** 2
+ dy2 = (y - self.initPos[1]) ** 2
+ if (dx2 + dy2) >= self.machine.DRAG_THRESHOLD_SQUARE_DIST:
+ self.goto('idle')
+
+ def onRelease(self, x, y, btn):
+ if btn == self.button:
+ self.machine.click(x, y, btn)
+ self.goto('idle')
+
+ class ClickOrDrag(State):
+ def enterState(self, x, y, btn):
+ self.initPos = x, y
+ self.button = btn
+
+ def onMove(self, x, y):
+ dx2 = (x - self.initPos[0]) ** 2
+ dy2 = (y - self.initPos[1]) ** 2
+ if (dx2 + dy2) >= self.machine.DRAG_THRESHOLD_SQUARE_DIST:
+ self.goto('drag', self.initPos, (x, y), self.button)
+
+ def onRelease(self, x, y, btn):
+ if btn == self.button:
+ if btn in self.machine.clickButtons:
+ self.machine.click(x, y, btn)
+ self.goto('idle')
+
+ class Drag(State):
+ def enterState(self, initPos, curPos, btn):
+ self.initPos = initPos
+ self.button = btn
+ self.machine.beginDrag(*initPos, btn)
+ self.machine.drag(*curPos, btn)
+
+ def onMove(self, x, y):
+ self.machine.drag(x, y, self.button)
+
+ def onRelease(self, x, y, btn):
+ if btn == self.button:
+ self.machine.endDrag(self.initPos, (x, y), btn)
+ self.goto('idle')
+
+ def __init__(self,
+ clickButtons=(LEFT_BTN, RIGHT_BTN),
+ dragButtons=(LEFT_BTN,)):
+ states = {
+ 'idle': self.Idle,
+ 'click': self.Click,
+ 'clickOrDrag': self.ClickOrDrag,
+ 'drag': self.Drag
+ }
+ self.__clickButtons = set(clickButtons)
+ self.__dragButtons = set(dragButtons)
+ super(ClickOrDrag, self).__init__(states, 'idle')
+
+ clickButtons = property(lambda self: self.__clickButtons,
+ doc="Buttons with click interaction (Set[int])")
+
+ dragButtons = property(lambda self: self.__dragButtons,
+ doc="Buttons with drag interaction (Set[int])")
+
+ def click(self, x, y, btn):
+ """Called upon a button supporting click.
+
+ Override in subclass.
+
+ :param int x: X mouse position in pixels.
+ :param int y: Y mouse position in pixels.
+ :param str btn: The mouse button which was clicked.
+ """
+ pass
+
+ def beginDrag(self, x, y, btn):
+ """Called at the beginning of a drag gesture with mouse button pressed.
+
+ Override in subclass.
+
+ :param int x: X mouse position in pixels.
+ :param int y: Y mouse position in pixels.
+ :param str btn: The mouse button for which a drag is starting.
+ """
+ pass
+
+ def drag(self, x, y, btn):
+ """Called on mouse moved during a drag gesture.
+
+ Override in subclass.
+
+ :param int x: X mouse position in pixels.
+ :param int y: Y mouse position in pixels.
+ :param str btn: The mouse button for which a drag is in progress.
+ """
+ pass
+
+ def endDrag(self, startPoint, endPoint, btn):
+ """Called at the end of a drag gesture when the mouse button is released.
+
+ Override in subclass.
+
+ :param List[int] startPoint:
+ (x, y) mouse position in pixels at the beginning of the drag.
+ :param List[int] endPoint:
+ (x, y) mouse position in pixels at the end of the drag.
+ :param str btn: The mouse button for which a drag is done.
+ """
+ pass
diff --git a/src/silx/gui/plot/ItemsSelectionDialog.py b/src/silx/gui/plot/ItemsSelectionDialog.py
new file mode 100644
index 0000000..c0504b0
--- /dev/null
+++ b/src/silx/gui/plot/ItemsSelectionDialog.py
@@ -0,0 +1,286 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a dialog widget to select plot items.
+
+.. autoclass:: ItemsSelectionDialog
+
+"""
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "28/06/2017"
+
+import logging
+
+from silx.gui import qt
+from silx.gui.plot.PlotWidget import PlotWidget
+
+_logger = logging.getLogger(__name__)
+
+
+class KindsSelector(qt.QListWidget):
+ """List widget allowing to select plot item kinds
+ ("curve", "scatter", "image"...)
+ """
+ sigSelectedKindsChanged = qt.Signal(list)
+
+ def __init__(self, parent=None, kinds=None):
+ """
+
+ :param parent: Parent QWidget or None
+ :param tuple(str) kinds: Sequence of kinds. If None, the default
+ behavior is to provide a checkbox for all possible item kinds.
+ """
+ qt.QListWidget.__init__(self, parent)
+
+ self.plot_item_kinds = []
+
+ self.setAvailableKinds(kinds if kinds is not None else PlotWidget.ITEM_KINDS)
+
+ self.setSelectionMode(qt.QAbstractItemView.ExtendedSelection)
+ self.selectAll()
+
+ self.itemSelectionChanged.connect(self.emitSigKindsSelectionChanged)
+
+ def emitSigKindsSelectionChanged(self):
+ self.sigSelectedKindsChanged.emit(self.selectedKinds)
+
+ @property
+ def selectedKinds(self):
+ """Tuple of all selected kinds (as strings)."""
+ # check for updates when self.itemSelectionChanged
+ return [item.text() for item in self.selectedItems()]
+
+ def setAvailableKinds(self, kinds):
+ """Set a list of kinds to be displayed.
+
+ :param list[str] kinds: Sequence of kinds
+ """
+ self.plot_item_kinds = kinds
+
+ self.clear()
+ for kind in self.plot_item_kinds:
+ item = qt.QListWidgetItem(self)
+ item.setText(kind)
+ self.addItem(item)
+
+ def selectAll(self):
+ """Select all available kinds."""
+ if self.selectionMode() in [qt.QAbstractItemView.SingleSelection,
+ qt.QAbstractItemView.NoSelection]:
+ raise RuntimeError("selectAll requires a multiple selection mode")
+ for i in range(self.count()):
+ self.item(i).setSelected(True)
+
+
+class PlotItemsSelector(qt.QTableWidget):
+ """Table widget displaying the legend and kind of all
+ plot items corresponding to a list of specified kinds.
+
+ Selected plot items are provided as property :attr:`selectedPlotItems`.
+ You can be warned of selection changes by listening to signal
+ :attr:`itemSelectionChanged`.
+ """
+ def __init__(self, parent=None, plot=None):
+ if plot is None or not isinstance(plot, PlotWidget):
+ raise AttributeError("parameter plot is required")
+ self.plot = plot
+ """:class:`PlotWidget` instance"""
+
+ self.plot_item_kinds = None
+ """List of plot item kinds (strings)"""
+
+ qt.QTableWidget.__init__(self, parent)
+
+ self.setColumnCount(2)
+
+ self.setSelectionBehavior(qt.QTableWidget.SelectRows)
+
+ def _clear(self):
+ self.clear()
+ self.setHorizontalHeaderLabels(["legend", "type"])
+
+ def setAllKindsFilter(self):
+ """Display all kinds of plot items."""
+ self.setKindsFilter(PlotWidget.ITEM_KINDS)
+
+ def setKindsFilter(self, kinds):
+ """Set list of all kinds of plot items to be displayed.
+
+ :param list[str] kinds: Sequence of kinds
+ """
+ if not set(kinds) <= set(PlotWidget.ITEM_KINDS):
+ raise KeyError("Illegal plot item kinds: %s" %
+ set(kinds) - set(PlotWidget.ITEM_KINDS))
+ self.plot_item_kinds = kinds
+
+ self.updatePlotItems()
+
+ def updatePlotItems(self):
+ self._clear()
+
+ # respect order of kinds as set in method setKindsFilter
+ itemsAndKind = []
+ for kind in self.plot_item_kinds:
+ itemClasses = self.plot._KIND_TO_CLASSES[kind]
+ for item in self.plot.getItems():
+ if isinstance(item, itemClasses) and item.isVisible():
+ itemsAndKind.append((item, kind))
+
+ self.setRowCount(len(itemsAndKind))
+
+ for index, (item, kind) in enumerate(itemsAndKind):
+ legend_twitem = qt.QTableWidgetItem(item.getName())
+ self.setItem(index, 0, legend_twitem)
+
+ kind_twitem = qt.QTableWidgetItem(kind)
+ self.setItem(index, 1, kind_twitem)
+
+ @property
+ def selectedPlotItems(self):
+ """List of all selected items"""
+ selection_model = self.selectionModel()
+ selected_rows_idx = selection_model.selectedRows()
+ selected_rows = [idx.row() for idx in selected_rows_idx]
+
+ items = []
+ for row in selected_rows:
+ legend = self.item(row, 0).text()
+ kind = self.item(row, 1).text()
+ item = self.plot._getItem(kind, legend)
+ if item is not None:
+ items.append(item)
+
+ return items
+
+
+class ItemsSelectionDialog(qt.QDialog):
+ """This widget is a modal dialog allowing to select one or more plot
+ items, in a table displaying their legend and kind.
+
+ Public methods:
+
+ - :meth:`getSelectedItems`
+ - :meth:`setAvailableKinds`
+ - :meth:`setItemsSelectionMode`
+
+ This widget inherits QDialog and therefore implements the usual
+ dialog methods, e.g. :meth:`exec`.
+
+ A trivial usage example would be::
+
+ isd = ItemsSelectionDialog(plot=my_plot_widget)
+ isd.setItemsSelectionMode(qt.QTableWidget.SingleSelection)
+ result = isd.exec()
+ if result:
+ for item in isd.getSelectedItems():
+ print(item.getName(), type(item))
+ else:
+ print("Selection cancelled")
+ """
+ def __init__(self, parent=None, plot=None):
+ if plot is None or not isinstance(plot, PlotWidget):
+ raise AttributeError("parameter plot is required")
+ qt.QDialog.__init__(self, parent)
+
+ self.setWindowTitle("Plot items selector")
+
+ kind_selector_label = qt.QLabel("Filter item kinds:", self)
+ item_selector_label = qt.QLabel("Select items:", self)
+
+ self.kind_selector = KindsSelector(self)
+ self.kind_selector.setToolTip(
+ "select one or more item kinds to show them in the item list")
+
+ self.item_selector = PlotItemsSelector(self, plot)
+ self.item_selector.setToolTip("select items")
+
+ self.item_selector.setKindsFilter(self.kind_selector.selectedKinds)
+ self.kind_selector.sigSelectedKindsChanged.connect(
+ self.item_selector.setKindsFilter
+ )
+
+ okb = qt.QPushButton("OK", self)
+ okb.clicked.connect(self.accept)
+
+ cancelb = qt.QPushButton("Cancel", self)
+ cancelb.clicked.connect(self.reject)
+
+ layout = qt.QGridLayout(self)
+ layout.addWidget(kind_selector_label, 0, 0)
+ layout.addWidget(item_selector_label, 0, 1)
+ layout.addWidget(self.kind_selector, 1, 0)
+ layout.addWidget(self.item_selector, 1, 1)
+ layout.addWidget(okb, 2, 0)
+ layout.addWidget(cancelb, 2, 1)
+
+ self.setLayout(layout)
+
+ def getSelectedItems(self):
+ """Return a list of selected plot items
+
+ :return: List of selected plot items
+ :rtype: list[silx.gui.plot.items.Item]"""
+ return self.item_selector.selectedPlotItems
+
+ def setAvailableKinds(self, kinds):
+ """Set a list of kinds to be displayed.
+
+ :param list[str] kinds: Sequence of kinds
+ """
+ self.kind_selector.setAvailableKinds(kinds)
+
+ def selectAllKinds(self):
+ self.kind_selector.selectAll()
+
+ def setItemsSelectionMode(self, mode):
+ """Set selection mode for plot item (single item selection,
+ multiple...).
+
+ :param mode: One of :class:`QTableWidget` selection modes
+ """
+ if mode == self.item_selector.SingleSelection:
+ self.item_selector.setToolTip(
+ "Select one item by clicking on it.")
+ elif mode == self.item_selector.MultiSelection:
+ self.item_selector.setToolTip(
+ "Select one or more items by clicking with the left mouse"
+ " button.\nYou can unselect items by clicking them again.\n"
+ "Multiple items can be toggled by dragging the mouse over them.")
+ elif mode == self.item_selector.ExtendedSelection:
+ self.item_selector.setToolTip(
+ "Select one or more items. You can select multiple items "
+ "by keeping the Ctrl key pushed when clicking.\nYou can "
+ "select a range of items by clicking on the first and "
+ "last while keeping the Shift key pushed.")
+ elif mode == self.item_selector.ContiguousSelection:
+ self.item_selector.setToolTip(
+ "Select one item by clicking on it. If you press the Shift"
+ " key while clicking on a second item,\nall items between "
+ "the two will be selected.")
+ elif mode == self.item_selector.NoSelection:
+ raise ValueError("The NoSelection mode is not allowed "
+ "in this context.")
+ self.item_selector.setSelectionMode(mode)
diff --git a/src/silx/gui/plot/LegendSelector.py b/src/silx/gui/plot/LegendSelector.py
new file mode 100755
index 0000000..d439387
--- /dev/null
+++ b/src/silx/gui/plot/LegendSelector.py
@@ -0,0 +1,1039 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Widget displaying curves legends and allowing to operate on curves.
+
+This widget is meant to work with :class:`PlotWindow`.
+"""
+
+__authors__ = ["V.A. Sole", "T. Rueter", "T. Vincent"]
+__license__ = "MIT"
+__data__ = "16/10/2017"
+
+
+import logging
+import weakref
+
+import numpy
+
+from .. import qt, colors
+from ..widgets.LegendIconWidget import LegendIconWidget
+from . import items
+
+
+_logger = logging.getLogger(__name__)
+
+
+class LegendIcon(LegendIconWidget):
+ """Object displaying a curve linestyle and symbol.
+
+ :param QWidget parent: See :class:`QWidget`
+ :param Union[~silx.gui.plot.items.Curve,None] curve:
+ Curve with which to synchronize
+ """
+
+ def __init__(self, parent=None, curve=None):
+ super(LegendIcon, self).__init__(parent)
+ self._curveRef = None
+ self.setCurve(curve)
+
+ def getCurve(self):
+ """Returns curve associated to this widget
+
+ :rtype: Union[~silx.gui.plot.items.Curve,None]
+ """
+ return None if self._curveRef is None else self._curveRef()
+
+ def setCurve(self, curve):
+ """Set the curve with which to synchronize this widget.
+
+ :param curve: Union[~silx.gui.plot.items.Curve,None]
+ """
+ assert curve is None or isinstance(curve, items.Curve)
+
+ previousCurve = self.getCurve()
+ if curve == previousCurve:
+ return
+
+ if previousCurve is not None:
+ previousCurve.sigItemChanged.disconnect(self._curveChanged)
+
+ self._curveRef = None if curve is None else weakref.ref(curve)
+
+ if curve is not None:
+ curve.sigItemChanged.connect(self._curveChanged)
+
+ self._update()
+
+ def _update(self):
+ """Update widget according to current curve state.
+ """
+ curve = self.getCurve()
+ if curve is None:
+ _logger.error('Curve no more exists')
+ self.setEnabled(False)
+ return
+
+ style = curve.getCurrentStyle()
+
+ self.setEnabled(curve.isVisible())
+ self.setSymbol(style.getSymbol())
+ self.setLineWidth(style.getLineWidth())
+ self.setLineStyle(style.getLineStyle())
+
+ color = style.getColor()
+ if numpy.array(color, copy=False).ndim != 1:
+ # array of colors, use transparent black
+ color = 0., 0., 0., 0.
+ color = colors.rgba(color) # Make sure it is float in [0, 1]
+ alpha = curve.getAlpha()
+ color = qt.QColor.fromRgbF(
+ color[0], color[1], color[2], color[3] * alpha)
+ self.setLineColor(color)
+ self.setSymbolColor(color)
+ self.update() # TODO this should not be needed
+
+ def _curveChanged(self, event):
+ """Handle update of curve item
+
+ :param event: Kind of change
+ """
+ if event in (items.ItemChangedType.VISIBLE,
+ items.ItemChangedType.SYMBOL,
+ items.ItemChangedType.SYMBOL_SIZE,
+ items.ItemChangedType.LINE_WIDTH,
+ items.ItemChangedType.LINE_STYLE,
+ items.ItemChangedType.COLOR,
+ items.ItemChangedType.ALPHA,
+ items.ItemChangedType.HIGHLIGHTED,
+ items.ItemChangedType.HIGHLIGHTED_STYLE):
+ self._update()
+
+
+class LegendModel(qt.QAbstractListModel):
+ """Data model of curve legends.
+
+ It holds the information of the curve:
+
+ - color
+ - line width
+ - line style
+ - visibility of the lines
+ - symbol
+ - visibility of the symbols
+ """
+ iconColorRole = qt.Qt.UserRole + 0
+ iconLineWidthRole = qt.Qt.UserRole + 1
+ iconLineStyleRole = qt.Qt.UserRole + 2
+ showLineRole = qt.Qt.UserRole + 3
+ iconSymbolRole = qt.Qt.UserRole + 4
+ showSymbolRole = qt.Qt.UserRole + 5
+
+ def __init__(self, legendList=None, parent=None):
+ super(LegendModel, self).__init__(parent)
+ if legendList is None:
+ legendList = []
+ self.legendList = []
+ self.insertLegendList(0, legendList)
+ self._palette = qt.QPalette()
+
+ def __getitem__(self, idx):
+ if idx >= len(self.legendList):
+ raise IndexError('list index out of range')
+ return self.legendList[idx]
+
+ def rowCount(self, modelIndex=None):
+ return len(self.legendList)
+
+ def flags(self, index):
+ return (qt.Qt.ItemIsEditable |
+ qt.Qt.ItemIsEnabled |
+ qt.Qt.ItemIsSelectable)
+
+ def data(self, modelIndex, role):
+ if modelIndex.isValid:
+ idx = modelIndex.row()
+ else:
+ return None
+ if idx >= len(self.legendList):
+ raise IndexError('list index out of range')
+
+ item = self.legendList[idx]
+ isActive = item[1].get("active", False)
+ if role == qt.Qt.DisplayRole:
+ # Data to be rendered in the form of text
+ legend = str(item[0])
+ return legend
+ elif role == qt.Qt.SizeHintRole:
+ # size = qt.QSize(200,50)
+ _logger.warning('LegendModel -- size hint role not implemented')
+ return qt.QSize()
+ elif role == qt.Qt.TextAlignmentRole:
+ alignment = qt.Qt.AlignVCenter | qt.Qt.AlignLeft
+ return alignment
+ elif role == qt.Qt.BackgroundRole:
+ # Background color, must be QBrush
+ if isActive:
+ brush = self._palette.brush(qt.QPalette.Normal, qt.QPalette.Highlight)
+ elif idx % 2:
+ brush = qt.QBrush(qt.QColor(240, 240, 240))
+ else:
+ brush = qt.QBrush(qt.Qt.white)
+ return brush
+ elif role == qt.Qt.ForegroundRole:
+ # ForegroundRole color, must be QBrush
+ if isActive:
+ brush = self._palette.brush(qt.QPalette.Normal, qt.QPalette.HighlightedText)
+ else:
+ brush = self._palette.brush(qt.QPalette.Normal, qt.QPalette.WindowText)
+ return brush
+ elif role == qt.Qt.CheckStateRole:
+ return bool(item[2]) # item[2] == True
+ elif role == qt.Qt.ToolTipRole or role == qt.Qt.StatusTipRole:
+ return ''
+ elif role == self.iconColorRole:
+ return item[1]['color']
+ elif role == self.iconLineWidthRole:
+ return item[1]['linewidth']
+ elif role == self.iconLineStyleRole:
+ return item[1]['linestyle']
+ elif role == self.iconSymbolRole:
+ return item[1]['symbol']
+ elif role == self.showLineRole:
+ return item[3]
+ elif role == self.showSymbolRole:
+ return item[4]
+ else:
+ _logger.info('Unkown role requested: %s', str(role))
+ return None
+
+ def setData(self, modelIndex, value, role):
+ if modelIndex.isValid:
+ idx = modelIndex.row()
+ else:
+ return None
+ if idx >= len(self.legendList):
+ # raise IndexError('list index out of range')
+ _logger.warning(
+ 'setData -- List index out of range, idx: %d', idx)
+ return None
+
+ item = self.legendList[idx]
+ try:
+ if role == qt.Qt.DisplayRole:
+ # Set legend
+ item[0] = str(value)
+ elif role == self.iconColorRole:
+ item[1]['color'] = qt.QColor(value)
+ elif role == self.iconLineWidthRole:
+ item[1]['linewidth'] = int(value)
+ elif role == self.iconLineStyleRole:
+ item[1]['linestyle'] = str(value)
+ elif role == self.iconSymbolRole:
+ item[1]['symbol'] = str(value)
+ elif role == qt.Qt.CheckStateRole:
+ item[2] = value
+ elif role == self.showLineRole:
+ item[3] = value
+ elif role == self.showSymbolRole:
+ item[4] = value
+ except ValueError:
+ _logger.warning('Conversion failed:\n\tvalue: %s\n\trole: %s',
+ str(value), str(role))
+ # Can that be right? Read docs again..
+ self.dataChanged.emit(modelIndex, modelIndex)
+ return True
+
+ def insertLegendList(self, row, llist):
+ """
+ :param int row: Determines after which row the items are inserted
+ :param llist: Carries the new legend information
+ :type llist: List
+ """
+ modelIndex = self.createIndex(row, 0)
+ count = len(llist)
+ super(LegendModel, self).beginInsertRows(modelIndex,
+ row,
+ row + count)
+ head = self.legendList[0:row]
+ tail = self.legendList[row:]
+ new = []
+ for (legend, icon) in llist:
+ linestyle = icon.get('linestyle', None)
+ if LegendIconWidget.isEmptyLineStyle(linestyle):
+ # Curve had no line, give it one and hide it
+ # So when toggle line, it will display a solid line
+ showLine = False
+ icon['linestyle'] = '-'
+ else:
+ showLine = True
+
+ symbol = icon.get('symbol', None)
+ if LegendIconWidget.isEmptySymbol(symbol):
+ # Curve had no symbol, give it one and hide it
+ # So when toggle symbol, it will display 'o'
+ showSymbol = False
+ icon['symbol'] = 'o'
+ else:
+ showSymbol = True
+
+ selected = icon.get('selected', True)
+ item = [legend,
+ icon,
+ selected,
+ showLine,
+ showSymbol]
+ new.append(item)
+ self.legendList = head + new + tail
+ super(LegendModel, self).endInsertRows()
+ return True
+
+ def insertRows(self, row, count, modelIndex=qt.QModelIndex()):
+ raise NotImplementedError('Use LegendModel.insertLegendList instead')
+
+ def removeRow(self, row):
+ return self.removeRows(row, 1)
+
+ def removeRows(self, row, count, modelIndex=qt.QModelIndex()):
+ length = len(self.legendList)
+ if length == 0:
+ # Nothing to do..
+ return True
+ if row < 0 or row >= length:
+ raise IndexError('Index out of range -- ' +
+ 'idx: %d, len: %d' % (row, length))
+ if count == 0:
+ return False
+ super(LegendModel, self).beginRemoveRows(modelIndex,
+ row,
+ row + count)
+ del(self.legendList[row:row + count])
+ super(LegendModel, self).endRemoveRows()
+ return True
+
+ def setEditor(self, event, editor):
+ """
+ :param str event: String that identifies the editor
+ :param editor: Widget used to change data in the underlying model
+ :type editor: QWidget
+ """
+ if event not in self.eventList:
+ raise ValueError('setEditor -- Event must be in %s' %
+ str(self.eventList))
+ self.editorDict[event] = editor
+
+
+class LegendListItemWidget(qt.QItemDelegate):
+ """Object displaying a single item (i.e., a row) in the list."""
+
+ # Notice: LegendListItem does NOT inherit
+ # from QObject, it cannot emit signals!
+
+ def __init__(self, parent=None, itemType=0):
+ super(LegendListItemWidget, self).__init__(parent)
+
+ # Dictionary to render checkboxes
+ self.cbDict = {}
+ self.labelDict = {}
+ self.iconDict = {}
+
+ # Keep checkbox and legend to get sizeHint
+ self.checkbox = qt.QCheckBox()
+ self.legend = qt.QLabel()
+ self.icon = LegendIcon()
+
+ # Context Menu and Editors
+ self.contextMenu = None
+
+ def paint(self, painter, option, modelIndex):
+ """
+ Here be docs..
+
+ :param QPainter painter:
+ :param QStyleOptionViewItem option:
+ :param QModelIndex modelIndex:
+ """
+ painter.save()
+ rect = option.rect
+
+ # Calculate the icon rectangle
+ iconSize = self.icon.sizeHint()
+ # Calculate icon position
+ x = rect.left() + 2
+ y = rect.top() + int(.5 * (rect.height() - iconSize.height()))
+ iconRect = qt.QRect(qt.QPoint(x, y), iconSize)
+
+ # Calculate label rectangle
+ legendSize = qt.QSize(rect.width() - iconSize.width() - 30,
+ rect.height())
+ # Calculate label position
+ x = rect.left() + iconRect.width()
+ y = rect.top()
+ labelRect = qt.QRect(qt.QPoint(x, y), legendSize)
+ labelRect.translate(qt.QPoint(10, 0))
+
+ # Calculate the checkbox rectangle
+ x = rect.right() - 30
+ y = rect.top()
+ chBoxRect = qt.QRect(qt.QPoint(x, y), rect.bottomRight())
+
+ # Remember the rectangles
+ idx = modelIndex.row()
+ self.cbDict[idx] = chBoxRect
+ self.iconDict[idx] = iconRect
+ self.labelDict[idx] = labelRect
+
+ # Draw background first!
+ if option.state & qt.QStyle.State_MouseOver:
+ backgroundBrush = option.palette.highlight()
+ else:
+ backgroundBrush = modelIndex.data(qt.Qt.BackgroundRole)
+ painter.fillRect(rect, backgroundBrush)
+
+ # Draw label
+ legendText = modelIndex.data(qt.Qt.DisplayRole)
+ textBrush = modelIndex.data(qt.Qt.ForegroundRole)
+ textAlign = modelIndex.data(qt.Qt.TextAlignmentRole)
+ painter.setBrush(textBrush)
+ painter.setFont(self.legend.font())
+ painter.setPen(textBrush.color())
+ painter.drawText(labelRect, textAlign, legendText)
+
+ # Draw icon
+ iconColor = modelIndex.data(LegendModel.iconColorRole)
+ iconLineWidth = modelIndex.data(LegendModel.iconLineWidthRole)
+ iconLineStyle = modelIndex.data(LegendModel.iconLineStyleRole)
+ iconSymbol = modelIndex.data(LegendModel.iconSymbolRole)
+ icon = LegendIcon()
+ icon.resize(iconRect.size())
+ icon.move(iconRect.topRight())
+ icon.showSymbol = modelIndex.data(LegendModel.showSymbolRole)
+ icon.showLine = modelIndex.data(LegendModel.showLineRole)
+ icon.setSymbolColor(iconColor)
+ icon.setLineColor(iconColor)
+ icon.setLineWidth(iconLineWidth)
+ icon.setLineStyle(iconLineStyle)
+ icon.setSymbol(iconSymbol)
+ icon.symbolOutlineBrush = backgroundBrush
+ icon.paint(painter, iconRect, option.palette)
+
+ # Draw the checkbox
+ if modelIndex.data(qt.Qt.CheckStateRole):
+ checkState = qt.Qt.Checked
+ else:
+ checkState = qt.Qt.Unchecked
+
+ self.drawCheck(
+ painter, qt.QStyleOptionViewItem(), chBoxRect, checkState)
+
+ painter.restore()
+
+ def editorEvent(self, event, model, option, modelIndex):
+ # From the docs:
+ # Mouse events are sent to editorEvent()
+ # even if they don't start editing of the item.
+ if event.button() == qt.Qt.RightButton and self.contextMenu:
+ self.contextMenu.exec(event.globalPos(), modelIndex)
+ return True
+ elif event.button() == qt.Qt.LeftButton:
+ # Check if checkbox was clicked
+ idx = modelIndex.row()
+ cbRect = self.cbDict[idx]
+ if cbRect.contains(event.pos()):
+ # Toggle checkbox
+ model.setData(modelIndex,
+ not modelIndex.data(qt.Qt.CheckStateRole),
+ qt.Qt.CheckStateRole)
+ event.ignore()
+ return True
+ else:
+ return super(LegendListItemWidget, self).editorEvent(
+ event, model, option, modelIndex)
+
+ def createEditor(self, parent, option, idx):
+ _logger.info('### Editor request ###')
+
+ def sizeHint(self, option, idx):
+ # return qt.QSize(68,24)
+ iconSize = self.icon.sizeHint()
+ legendSize = self.legend.sizeHint()
+ checkboxSize = self.checkbox.sizeHint()
+ height = max([iconSize.height(),
+ legendSize.height(),
+ checkboxSize.height()]) + 4
+ width = iconSize.width() + legendSize.width() + checkboxSize.width()
+ return qt.QSize(width, height)
+
+
+class LegendListView(qt.QListView):
+ """Widget displaying a list of curve legends, line style and symbol."""
+
+ sigLegendSignal = qt.Signal(object)
+ """Signal emitting a dict when an action is triggered by the user."""
+
+ __mouseClickedEvent = 'mouseClicked'
+ __checkBoxClickedEvent = 'checkBoxClicked'
+ __legendClickedEvent = 'legendClicked'
+
+ def __init__(self, parent=None, model=None, contextMenu=None):
+ super(LegendListView, self).__init__(parent)
+ self.__lastButton = None
+ self.__lastClickPos = None
+ self.__lastModelIdx = None
+ # Set default delegate
+ self.setItemDelegate(LegendListItemWidget())
+ # Set default editors
+ # self.setSizePolicy(qt.QSizePolicy.MinimumExpanding,
+ # qt.QSizePolicy.MinimumExpanding)
+ # Set edit triggers by hand using self.edit(QModelIndex)
+ # in mousePressEvent (better to control than signals)
+ self.setEditTriggers(qt.QAbstractItemView.NoEditTriggers)
+
+ # Control layout
+ # self.setBatchSize(2)
+ # self.setLayoutMode(qt.QListView.Batched)
+ # self.setFlow(qt.QListView.LeftToRight)
+
+ # Control selection
+ self.setSelectionMode(qt.QAbstractItemView.NoSelection)
+
+ if model is None:
+ model = LegendModel(parent=self)
+ self.setModel(model)
+ self.setContextMenu(contextMenu)
+
+ def setLegendList(self, legendList, row=None):
+ if row is not None:
+ model = self.model()
+ model.insertLegendList(row, legendList)
+ elif len(legendList) != self.model().rowCount():
+ self.clear()
+ model = self.model()
+ model.insertLegendList(0, legendList)
+ else:
+ model = self.model()
+ for i, (new_legend, icon) in enumerate(legendList):
+ modelIndex = model.index(i)
+ legend = str(modelIndex.data(qt.Qt.DisplayRole))
+ if new_legend != legend:
+ model.setData(modelIndex, new_legend, qt.Qt.DisplayRole)
+
+ color = modelIndex.data(LegendModel.iconColorRole)
+ new_color = icon.get('color', None)
+ if new_color != color:
+ model.setData(modelIndex, new_color, LegendModel.iconColorRole)
+
+ linewidth = modelIndex.data(LegendModel.iconLineWidthRole)
+ new_linewidth = icon.get('linewidth', 1.0)
+ if new_linewidth != linewidth:
+ model.setData(modelIndex, new_linewidth, LegendModel.iconLineWidthRole)
+
+ linestyle = modelIndex.data(LegendModel.iconLineStyleRole)
+ new_linestyle = icon.get('linestyle', None)
+ visible = not LegendIconWidget.isEmptyLineStyle(new_linestyle)
+ model.setData(modelIndex, visible, LegendModel.showLineRole)
+ if new_linestyle != linestyle:
+ model.setData(modelIndex, new_linestyle, LegendModel.iconLineStyleRole)
+
+ symbol = modelIndex.data(LegendModel.iconSymbolRole)
+ new_symbol = icon.get('symbol', None)
+ visible = not LegendIconWidget.isEmptySymbol(new_symbol)
+ model.setData(modelIndex, visible, LegendModel.showSymbolRole)
+ if new_symbol != symbol:
+ model.setData(modelIndex, new_symbol, LegendModel.iconSymbolRole)
+
+ selected = modelIndex.data(qt.Qt.CheckStateRole)
+ new_selected = icon.get('selected', True)
+ if new_selected != selected:
+ model.setData(modelIndex, new_selected, qt.Qt.CheckStateRole)
+ _logger.debug('LegendListView.setLegendList(legendList) finished')
+
+ def clear(self):
+ model = self.model()
+ model.removeRows(0, model.rowCount())
+ _logger.debug('LegendListView.clear() finished')
+
+ def setContextMenu(self, contextMenu=None):
+ delegate = self.itemDelegate()
+ if isinstance(delegate, LegendListItemWidget) and self.model():
+ if contextMenu is None:
+ delegate.contextMenu = LegendListContextMenu(self.model())
+ delegate.contextMenu.sigContextMenu.connect(
+ self._contextMenuSlot)
+ else:
+ delegate.contextMenu = contextMenu
+
+ def __getitem__(self, idx):
+ model = self.model()
+ try:
+ item = model[idx]
+ except ValueError:
+ item = None
+ return item
+
+ def _contextMenuSlot(self, ddict):
+ self.sigLegendSignal.emit(ddict)
+
+ def mousePressEvent(self, event):
+ self.__lastButton = event.button()
+ self.__lastPosition = event.pos()
+ super(LegendListView, self).mousePressEvent(event)
+ # call _handleMouseClick after editing was handled
+ # If right click (context menu) is aborted, no
+ # signal is emitted..
+ self._handleMouseClick(self.indexAt(self.__lastPosition))
+
+ def mouseDoubleClickEvent(self, event):
+ self.__lastButton = event.button()
+ self.__lastPosition = event.pos()
+ super(LegendListView, self).mouseDoubleClickEvent(event)
+ # call _handleMouseClick after editing was handled
+ # If right click (context menu) is aborted, no
+ # signal is emitted..
+ self._handleMouseClick(self.indexAt(self.__lastPosition))
+
+ def mouseMoveEvent(self, event):
+ # LegendListView.mouseMoveEvent is overwritten
+ # to suppress unwanted behavior in the delegate.
+ pass
+
+ def mouseReleaseEvent(self, event):
+ # LegendListView.mouseReleaseEvent is overwritten
+ # to subpress unwanted behavior in the delegate.
+ pass
+
+ def _handleMouseClick(self, modelIndex):
+ """
+ Distinguish between mouse click on Legend
+ and mouse click on CheckBox by setting the
+ currentCheckState attribute in LegendListItem.
+
+ Emits signal sigLegendSignal(ddict)
+
+ :param QModelIndex modelIndex: index of the clicked item
+ """
+ _logger.debug('self._handleMouseClick called')
+ if self.__lastButton not in [qt.Qt.LeftButton,
+ qt.Qt.RightButton]:
+ return
+ if not modelIndex.isValid():
+ _logger.debug('_handleMouseClick -- Invalid QModelIndex')
+ return
+ # model = self.model()
+ idx = modelIndex.row()
+
+ delegate = self.itemDelegate()
+ cbClicked = False
+ if isinstance(delegate, LegendListItemWidget):
+ for cbRect in delegate.cbDict.values():
+ if cbRect.contains(self.__lastPosition):
+ cbClicked = True
+ break
+
+ # TODO: Check for doubleclicks on legend/icon and spawn editors
+
+ ddict = {
+ 'legend': str(modelIndex.data(qt.Qt.DisplayRole)),
+ 'icon': {
+ 'linewidth': str(modelIndex.data(
+ LegendModel.iconLineWidthRole)),
+ 'linestyle': str(modelIndex.data(
+ LegendModel.iconLineStyleRole)),
+ 'symbol': str(modelIndex.data(LegendModel.iconSymbolRole))
+ },
+ 'selected': modelIndex.data(qt.Qt.CheckStateRole),
+ 'type': str(modelIndex.data())
+ }
+ if self.__lastButton == qt.Qt.RightButton:
+ _logger.debug('Right clicked')
+ ddict['button'] = "right"
+ ddict['event'] = self.__mouseClickedEvent
+ elif cbClicked:
+ _logger.debug('CheckBox clicked')
+ ddict['button'] = "left"
+ ddict['event'] = self.__checkBoxClickedEvent
+ else:
+ _logger.debug('Legend clicked')
+ ddict['button'] = "left"
+ ddict['event'] = self.__legendClickedEvent
+ _logger.debug(' idx: %d\n ddict: %s', idx, str(ddict))
+ self.sigLegendSignal.emit(ddict)
+
+
+class LegendListContextMenu(qt.QMenu):
+ """Contextual menu associated to items in a :class:`LegendListView`."""
+
+ sigContextMenu = qt.Signal(object)
+ """Signal emitting a dict upon contextual menu actions."""
+
+ def __init__(self, model):
+ super(LegendListContextMenu, self).__init__(parent=None)
+ self.model = model
+
+ self.addAction('Set Active', self.setActiveAction)
+ self.addAction('Map to left', self.mapToLeftAction)
+ self.addAction('Map to right', self.mapToRightAction)
+
+ self._pointsAction = self.addAction(
+ 'Points', self.togglePointsAction)
+ self._pointsAction.setCheckable(True)
+
+ self._linesAction = self.addAction('Lines', self.toggleLinesAction)
+ self._linesAction.setCheckable(True)
+
+ self.addAction('Remove curve', self.removeItemAction)
+ self.addAction('Rename curve', self.renameItemAction)
+
+ def exec(self, pos, idx):
+ self.__currentIdx = idx
+
+ # Set checkable action state
+ modelIndex = self.currentIdx()
+ self._pointsAction.setChecked(
+ modelIndex.data(LegendModel.showSymbolRole))
+ self._linesAction.setChecked(
+ modelIndex.data(LegendModel.showLineRole))
+
+ super(LegendListContextMenu, self).popup(pos)
+
+ def exec_(self, pos, idx): # Qt5-like compatibility
+ return self.exec(pos, idx)
+
+ def currentIdx(self):
+ return self.__currentIdx
+
+ def mapToLeftAction(self):
+ _logger.debug('LegendListContextMenu.mapToLeftAction called')
+ modelIndex = self.currentIdx()
+ legend = str(modelIndex.data(qt.Qt.DisplayRole))
+ ddict = {
+ 'legend': legend,
+ 'label': legend,
+ 'selected': modelIndex.data(qt.Qt.CheckStateRole),
+ 'type': str(modelIndex.data()),
+ 'event': "mapToLeft"
+ }
+ self.sigContextMenu.emit(ddict)
+
+ def mapToRightAction(self):
+ _logger.debug('LegendListContextMenu.mapToRightAction called')
+ modelIndex = self.currentIdx()
+ legend = str(modelIndex.data(qt.Qt.DisplayRole))
+ ddict = {
+ 'legend': legend,
+ 'label': legend,
+ 'selected': modelIndex.data(qt.Qt.CheckStateRole),
+ 'type': str(modelIndex.data()),
+ 'event': "mapToRight"
+ }
+ self.sigContextMenu.emit(ddict)
+
+ def removeItemAction(self):
+ _logger.debug('LegendListContextMenu.removeCurveAction called')
+ modelIndex = self.currentIdx()
+ legend = str(modelIndex.data(qt.Qt.DisplayRole))
+ ddict = {
+ 'legend': legend,
+ 'label': legend,
+ 'selected': modelIndex.data(qt.Qt.CheckStateRole),
+ 'type': str(modelIndex.data()),
+ 'event': "removeCurve"
+ }
+ self.model.removeRow(modelIndex.row())
+ self.sigContextMenu.emit(ddict)
+
+ def renameItemAction(self):
+ _logger.debug('LegendListContextMenu.renameCurveAction called')
+ modelIndex = self.currentIdx()
+ legend = str(modelIndex.data(qt.Qt.DisplayRole))
+ ddict = {
+ 'legend': legend,
+ 'label': legend,
+ 'selected': modelIndex.data(qt.Qt.CheckStateRole),
+ 'type': str(modelIndex.data()),
+ 'event': "renameCurve"
+ }
+ self.sigContextMenu.emit(ddict)
+
+ def toggleLinesAction(self):
+ modelIndex = self.currentIdx()
+ legend = str(modelIndex.data(qt.Qt.DisplayRole))
+ ddict = {
+ 'legend': legend,
+ 'label': legend,
+ 'selected': modelIndex.data(qt.Qt.CheckStateRole),
+ 'type': str(modelIndex.data()),
+ }
+ linestyle = modelIndex.data(LegendModel.iconLineStyleRole)
+ visible = not modelIndex.data(LegendModel.showLineRole)
+ _logger.debug('toggleLinesAction -- lines visible: %s', str(visible))
+ ddict['event'] = "toggleLine"
+ ddict['line'] = visible
+ ddict['linestyle'] = linestyle if visible else ''
+ self.model.setData(modelIndex, visible, LegendModel.showLineRole)
+ self.sigContextMenu.emit(ddict)
+
+ def togglePointsAction(self):
+ modelIndex = self.currentIdx()
+ legend = str(modelIndex.data(qt.Qt.DisplayRole))
+ ddict = {
+ 'legend': legend,
+ 'label': legend,
+ 'selected': modelIndex.data(qt.Qt.CheckStateRole),
+ 'type': str(modelIndex.data()),
+ }
+ flag = modelIndex.data(LegendModel.showSymbolRole)
+ symbol = modelIndex.data(LegendModel.iconSymbolRole)
+ visible = not flag or LegendIconWidget.isEmptySymbol(symbol)
+ _logger.debug(
+ 'togglePointsAction -- Symbols visible: %s', str(visible))
+
+ ddict['event'] = "togglePoints"
+ ddict['points'] = visible
+ ddict['symbol'] = symbol if visible else ''
+ self.model.setData(modelIndex, visible, LegendModel.showSymbolRole)
+ self.sigContextMenu.emit(ddict)
+
+ def setActiveAction(self):
+ modelIndex = self.currentIdx()
+ legend = str(modelIndex.data(qt.Qt.DisplayRole))
+ _logger.debug('setActiveAction -- active curve: %s', legend)
+ ddict = {
+ 'legend': legend,
+ 'label': legend,
+ 'selected': modelIndex.data(qt.Qt.CheckStateRole),
+ 'type': str(modelIndex.data()),
+ 'event': "setActiveCurve",
+ }
+ self.sigContextMenu.emit(ddict)
+
+
+class RenameCurveDialog(qt.QDialog):
+ """Dialog box to input the name of a curve."""
+
+ def __init__(self, parent=None, current="", curves=()):
+ super(RenameCurveDialog, self).__init__(parent)
+ self.setWindowTitle("Rename Curve %s" % current)
+ self.curves = curves
+ layout = qt.QVBoxLayout(self)
+ self.lineEdit = qt.QLineEdit(self)
+ self.lineEdit.setText(current)
+ self.hbox = qt.QWidget(self)
+ self.hboxLayout = qt.QHBoxLayout(self.hbox)
+ self.hboxLayout.addStretch(1)
+ self.okButton = qt.QPushButton(self.hbox)
+ self.okButton.setText('OK')
+ self.hboxLayout.addWidget(self.okButton)
+ self.cancelButton = qt.QPushButton(self.hbox)
+ self.cancelButton.setText('Cancel')
+ self.hboxLayout.addWidget(self.cancelButton)
+ self.hboxLayout.addStretch(1)
+ layout.addWidget(self.lineEdit)
+ layout.addWidget(self.hbox)
+ self.okButton.clicked.connect(self.preAccept)
+ self.cancelButton.clicked.connect(self.reject)
+
+ def preAccept(self):
+ text = str(self.lineEdit.text())
+ addedText = ""
+ if len(text):
+ if text not in self.curves:
+ self.accept()
+ return
+ else:
+ addedText = "Curve already exists."
+ text = "Invalid Curve Name"
+ msg = qt.QMessageBox(self)
+ msg.setIcon(qt.QMessageBox.Critical)
+ msg.setWindowTitle(text)
+ text += "\n%s" % addedText
+ msg.setText(text)
+ msg.exec()
+
+ def getText(self):
+ return str(self.lineEdit.text())
+
+
+class LegendsDockWidget(qt.QDockWidget):
+ """QDockWidget with a :class:`LegendSelector` connected to a PlotWindow.
+
+ It makes the link between the LegendListView widget and the PlotWindow.
+
+ :param parent: See :class:`QDockWidget`
+ :param plot: :class:`.PlotWindow` instance on which to operate
+ """
+
+ def __init__(self, parent=None, plot=None):
+ assert plot is not None
+ self._plotRef = weakref.ref(plot)
+ self._isConnected = False # True if widget connected to plot signals
+
+ super(LegendsDockWidget, self).__init__("Legends", parent)
+
+ self._legendWidget = LegendListView()
+
+ self.layout().setContentsMargins(0, 0, 0, 0)
+ self.setWidget(self._legendWidget)
+
+ self.visibilityChanged.connect(
+ self._visibilityChangedHandler)
+
+ self._legendWidget.sigLegendSignal.connect(self._legendSignalHandler)
+
+ @property
+ def plot(self):
+ """The :class:`.PlotWindow` this widget is attached to."""
+ return self._plotRef()
+
+ def renameCurve(self, oldLegend, newLegend):
+ """Change the name of a curve using remove and addCurve
+
+ :param str oldLegend: The legend of the curve to be changed
+ :param str newLegend: The new legend of the curve
+ """
+ is_active = self.plot.getActiveCurve(just_legend=True) == oldLegend
+ curve = self.plot.getCurve(oldLegend)
+ self.plot.remove(oldLegend, kind='curve')
+ self.plot.addCurve(curve.getXData(copy=False),
+ curve.getYData(copy=False),
+ legend=newLegend,
+ info=curve.getInfo(),
+ color=curve.getColor(),
+ symbol=curve.getSymbol(),
+ linewidth=curve.getLineWidth(),
+ linestyle=curve.getLineStyle(),
+ xlabel=curve.getXLabel(),
+ ylabel=curve.getYLabel(),
+ xerror=curve.getXErrorData(copy=False),
+ yerror=curve.getYErrorData(copy=False),
+ z=curve.getZValue(),
+ selectable=curve.isSelectable(),
+ fill=curve.isFill(),
+ resetzoom=False)
+ if is_active:
+ self.plot.setActiveCurve(newLegend)
+
+ def _legendSignalHandler(self, ddict):
+ """Handles events from the LegendListView signal"""
+ _logger.debug("Legend signal ddict = %s", str(ddict))
+
+ if ddict['event'] == "legendClicked":
+ if ddict['button'] == "left":
+ self.plot.setActiveCurve(ddict['legend'])
+
+ elif ddict['event'] == "removeCurve":
+ self.plot.removeCurve(ddict['legend'])
+
+ elif ddict['event'] == "renameCurve":
+ curveList = self.plot.getAllCurves(just_legend=True)
+ oldLegend = ddict['legend']
+ dialog = RenameCurveDialog(self.plot, oldLegend, curveList)
+ ret = dialog.exec()
+ if ret:
+ newLegend = dialog.getText()
+ self.renameCurve(oldLegend, newLegend)
+
+ elif ddict['event'] == "setActiveCurve":
+ self.plot.setActiveCurve(ddict['legend'])
+
+ elif ddict['event'] == "checkBoxClicked":
+ self.plot.hideCurve(ddict['legend'], not ddict['selected'])
+
+ elif ddict['event'] in ["mapToRight", "mapToLeft"]:
+ legend = ddict['legend']
+ curve = self.plot.getCurve(legend)
+ yaxis = 'right' if ddict['event'] == 'mapToRight' else 'left'
+ self.plot.addCurve(x=curve.getXData(copy=False),
+ y=curve.getYData(copy=False),
+ legend=curve.getName(),
+ info=curve.getInfo(),
+ yaxis=yaxis)
+
+ elif ddict['event'] == "togglePoints":
+ legend = ddict['legend']
+ curve = self.plot.getCurve(legend)
+ symbol = ddict['symbol'] if ddict['points'] else ''
+ self.plot.addCurve(x=curve.getXData(copy=False),
+ y=curve.getYData(copy=False),
+ legend=curve.getName(),
+ info=curve.getInfo(),
+ symbol=symbol)
+
+ elif ddict['event'] == "toggleLine":
+ legend = ddict['legend']
+ curve = self.plot.getCurve(legend)
+ linestyle = ddict['linestyle'] if ddict['line'] else ''
+ self.plot.addCurve(x=curve.getXData(copy=False),
+ y=curve.getYData(copy=False),
+ legend=curve.getName(),
+ info=curve.getInfo(),
+ linestyle=linestyle)
+
+ else:
+ _logger.debug("unhandled event %s", str(ddict['event']))
+
+ def updateLegends(self, *args):
+ """Sync the LegendSelector widget displayed info with the plot.
+ """
+ legendList = []
+ for curve in self.plot.getAllCurves(withhidden=True):
+ legend = curve.getName()
+ # Use active color if curve is active
+ isActive = legend == self.plot.getActiveCurve(just_legend=True)
+ style = curve.getCurrentStyle()
+ color = style.getColor()
+ if numpy.array(color, copy=False).ndim != 1:
+ # array of colors, use transparent black
+ color = 0., 0., 0., 0.
+
+ curveInfo = {
+ 'color': qt.QColor.fromRgbF(*color),
+ 'linewidth': style.getLineWidth(),
+ 'linestyle': style.getLineStyle(),
+ 'symbol': style.getSymbol(),
+ 'selected': not self.plot.isCurveHidden(legend),
+ 'active': isActive}
+ legendList.append((legend, curveInfo))
+
+ self._legendWidget.setLegendList(legendList)
+
+ def _visibilityChangedHandler(self, visible):
+ if visible:
+ self.updateLegends()
+ if not self._isConnected:
+ self.plot.sigContentChanged.connect(self.updateLegends)
+ self.plot.sigActiveCurveChanged.connect(self.updateLegends)
+ self._isConnected = True
+ else:
+ if self._isConnected:
+ self.plot.sigContentChanged.disconnect(self.updateLegends)
+ self.plot.sigActiveCurveChanged.disconnect(self.updateLegends)
+ self._isConnected = False
+
+ def showEvent(self, event):
+ """Make sure this widget is raised when it is shown
+ (when it is first created as a tab in PlotWindow or when it is shown
+ again after hiding).
+ """
+ self.raise_()
diff --git a/src/silx/gui/plot/LimitsHistory.py b/src/silx/gui/plot/LimitsHistory.py
new file mode 100644
index 0000000..a323548
--- /dev/null
+++ b/src/silx/gui/plot/LimitsHistory.py
@@ -0,0 +1,83 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides handling of :class:`PlotWidget` limits history.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "19/07/2017"
+
+
+from .. import qt
+
+
+class LimitsHistory(qt.QObject):
+ """Class handling history of limits of a :class:`PlotWidget`.
+
+ :param PlotWidget parent: The plot widget this object is bound to.
+ """
+
+ def __init__(self, parent):
+ self._history = []
+ super(LimitsHistory, self).__init__(parent)
+ self.setParent(parent)
+
+ def setParent(self, parent):
+ """See :meth:`QObject.setParent`.
+
+ :param PlotWidget parent: The PlotWidget this object is bound to.
+ """
+ self.clear() # Clear history when changing parent
+ super(LimitsHistory, self).setParent(parent)
+
+ def push(self):
+ """Append current limits to the history."""
+ plot = self.parent()
+ xmin, xmax = plot.getXAxis().getLimits()
+ ymin, ymax = plot.getYAxis(axis='left').getLimits()
+ y2min, y2max = plot.getYAxis(axis='right').getLimits()
+ self._history.append((xmin, xmax, ymin, ymax, y2min, y2max))
+
+ def pop(self):
+ """Restore previously limits stored in the history.
+
+ :return: True if limits were restored, False if history was empty.
+ :rtype: bool
+ """
+ plot = self.parent()
+ if self._history:
+ limits = self._history.pop(-1)
+ plot.setLimits(*limits)
+ return True
+ else:
+ plot.resetZoom()
+ return False
+
+ def clear(self):
+ """Clear stored limits states."""
+ self._history = []
+
+ def __len__(self):
+ return len(self._history)
diff --git a/src/silx/gui/plot/MaskToolsWidget.py b/src/silx/gui/plot/MaskToolsWidget.py
new file mode 100644
index 0000000..522be48
--- /dev/null
+++ b/src/silx/gui/plot/MaskToolsWidget.py
@@ -0,0 +1,919 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Widget providing a set of tools to draw masks on a PlotWidget.
+
+This widget is meant to work with :class:`silx.gui.plot.PlotWidget`.
+
+- :class:`ImageMask`: Handle mask bitmap update and history
+- :class:`MaskToolsWidget`: GUI for :class:`Mask`
+- :class:`MaskToolsDockWidget`: DockWidget to integrate in :class:`PlotWindow`
+"""
+from __future__ import division
+
+__authors__ = ["T. Vincent", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "08/12/2020"
+
+import os
+import sys
+import numpy
+import logging
+import collections
+import h5py
+
+from silx.image import shapes
+from silx.io.utils import NEXUS_HDF5_EXT, is_dataset
+from silx.gui.dialog.DatasetDialog import DatasetDialog
+
+from ._BaseMaskToolsWidget import BaseMask, BaseMaskToolsWidget, BaseMaskToolsDockWidget
+from . import items
+from ..colors import cursorColorForColormap, rgba
+from .. import qt
+from ..utils import LockReentrant
+
+from silx.third_party.EdfFile import EdfFile
+from silx.third_party.TiffIO import TiffIO
+
+import fabio
+
+_logger = logging.getLogger(__name__)
+
+_HDF5_EXT_STR = ' '.join(['*' + ext for ext in NEXUS_HDF5_EXT])
+
+
+def _selectDataset(filename, mode=DatasetDialog.SaveMode):
+ """Open a dialog to prompt the user to select a dataset in
+ a hdf5 file.
+
+ :param str filename: name of an existing HDF5 file
+ :param mode: DatasetDialog.SaveMode or DatasetDialog.LoadMode
+ :rtype: str
+ :return: Name of selected dataset
+ """
+ dialog = DatasetDialog()
+ dialog.addFile(filename)
+ dialog.setWindowTitle("Select a 2D dataset")
+ dialog.setMode(mode)
+ if not dialog.exec():
+ return None
+ return dialog.getSelectedDataUrl().data_path()
+
+
+class ImageMask(BaseMask):
+ """A 2D mask field with update operations.
+
+ Coords follows (row, column) convention and are in mask array coords.
+
+ This is meant for internal use by :class:`MaskToolsWidget`.
+ """
+
+ def __init__(self, image=None):
+ """
+
+ :param image: :class:`silx.gui.plot.items.ImageBase` instance
+ """
+ BaseMask.__init__(self, image)
+ self.reset(shape=(0, 0)) # Init the mask with a 2D shape
+
+ def getDataValues(self):
+ """Return image data as a 2D or 3D array (if it is a RGBA image).
+
+ :rtype: 2D or 3D numpy.ndarray
+ """
+ return self._dataItem.getData(copy=False)
+
+ def save(self, filename, kind):
+ """Save current mask in a file
+
+ :param str filename: The file where to save to mask
+ :param str kind: The kind of file to save in 'edf', 'tif', 'npy', 'h5'
+ or 'msk' (if FabIO is installed)
+ :raise Exception: Raised if the file writing fail
+ """
+ if kind == 'edf':
+ edfFile = EdfFile(filename, access="w+")
+ header = {"program_name": "silx-mask", "masked_value": "nonzero"}
+ edfFile.WriteImage(header, self.getMask(copy=False), Append=0)
+
+ elif kind == 'tif':
+ tiffFile = TiffIO(filename, mode='w')
+ tiffFile.writeImage(self.getMask(copy=False), software='silx')
+
+ elif kind == 'npy':
+ try:
+ numpy.save(filename, self.getMask(copy=False))
+ except IOError:
+ raise RuntimeError("Mask file can't be written")
+
+ elif ("." + kind) in NEXUS_HDF5_EXT:
+ self._saveToHdf5(filename, self.getMask(copy=False))
+
+ elif kind == 'msk':
+ try:
+ data = self.getMask(copy=False)
+ image = fabio.fabioimage.FabioImage(data=data)
+ image = image.convert(fabio.fit2dmaskimage.Fit2dMaskImage)
+ image.save(filename)
+ except Exception:
+ _logger.debug("Backtrace", exc_info=True)
+ raise RuntimeError("Mask file can't be written")
+ else:
+ raise ValueError("Format '%s' is not supported" % kind)
+
+ @staticmethod
+ def _saveToHdf5(filename, mask):
+ """Save a mask array to a HDF5 file.
+
+ :param str filename: name of an existing HDF5 file
+ :param numpy.ndarray mask: Mask array.
+ :returns: True if operation succeeded, False otherwise.
+ """
+ if not os.path.exists(filename):
+ # create new file
+ with h5py.File(filename, "w") as _h5f:
+ pass
+ dataPath = _selectDataset(filename)
+ if dataPath is None:
+ return False
+ with h5py.File(filename, "a") as h5f:
+ existing_ds = h5f.get(dataPath)
+ if existing_ds is not None:
+ reply = qt.QMessageBox.question(
+ None,
+ "Confirm overwrite",
+ "Do you want to overwrite an existing dataset?",
+ qt.QMessageBox.Yes | qt.QMessageBox.No)
+ if reply != qt.QMessageBox.Yes:
+ return False
+ del h5f[dataPath]
+ try:
+ h5f.create_dataset(dataPath, data=mask)
+ except Exception:
+ return False
+ return True
+
+ # Drawing operations
+ def updateRectangle(self, level, row, col, height, width, mask=True):
+ """Mask/Unmask a rectangle of the given mask level.
+
+ :param int level: Mask level to update.
+ :param int row: Starting row of the rectangle
+ :param int col: Starting column of the rectangle
+ :param int height:
+ :param int width:
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ assert 0 < level < 256
+ if row + height <= 0 or col + width <= 0:
+ return # Rectangle outside image, avoid negative indices
+ selection = self._mask[max(0, row):row + height + 1,
+ max(0, col):col + width + 1]
+ if mask:
+ selection[:,:] = level
+ else:
+ selection[selection == level] = 0
+ self._notify()
+
+ def updatePolygon(self, level, vertices, mask=True):
+ """Mask/Unmask a polygon of the given mask level.
+
+ :param int level: Mask level to update.
+ :param vertices: Nx2 array of polygon corners as (row, col)
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ fill = shapes.polygon_fill_mask(vertices, self._mask.shape)
+ if mask:
+ self._mask[fill != 0] = level
+ else:
+ self._mask[numpy.logical_and(fill != 0,
+ self._mask == level)] = 0
+ self._notify()
+
+ def updatePoints(self, level, rows, cols, mask=True):
+ """Mask/Unmask points with given coordinates.
+
+ :param int level: Mask level to update.
+ :param rows: Rows of selected points
+ :type rows: 1D numpy.ndarray
+ :param cols: Columns of selected points
+ :type cols: 1D numpy.ndarray
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ valid = numpy.logical_and(
+ numpy.logical_and(rows >= 0, cols >= 0),
+ numpy.logical_and(rows < self._mask.shape[0],
+ cols < self._mask.shape[1]))
+ rows, cols = rows[valid], cols[valid]
+
+ if mask:
+ self._mask[rows, cols] = level
+ else:
+ inMask = self._mask[rows, cols] == level
+ self._mask[rows[inMask], cols[inMask]] = 0
+ self._notify()
+
+ def updateDisk(self, level, crow, ccol, radius, mask=True):
+ """Mask/Unmask a disk of the given mask level.
+
+ :param int level: Mask level to update.
+ :param int crow: Disk center row.
+ :param int ccol: Disk center column.
+ :param float radius: Radius of the disk in mask array unit
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ rows, cols = shapes.circle_fill(crow, ccol, radius)
+ self.updatePoints(level, rows, cols, mask)
+
+ def updateEllipse(self, level, crow, ccol, radius_r, radius_c, mask=True):
+ """Mask/Unmask an ellipse of the given mask level.
+
+ :param int level: Mask level to update.
+ :param int crow: Row of the center of the ellipse
+ :param int ccol: Column of the center of the ellipse
+ :param float radius_r: Radius of the ellipse in the row
+ :param float radius_c: Radius of the ellipse in the column
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ rows, cols = shapes.ellipse_fill(crow, ccol, radius_r, radius_c)
+ self.updatePoints(level, rows, cols, mask)
+
+ def updateLine(self, level, row0, col0, row1, col1, width, mask=True):
+ """Mask/Unmask a line of the given mask level.
+
+ :param int level: Mask level to update.
+ :param int row0: Row of the starting point.
+ :param int col0: Column of the starting point.
+ :param int row1: Row of the end point.
+ :param int col1: Column of the end point.
+ :param int width: Width of the line in mask array unit.
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ rows, cols = shapes.draw_line(row0, col0, row1, col1, width)
+ self.updatePoints(level, rows, cols, mask)
+
+
+class MaskToolsWidget(BaseMaskToolsWidget):
+ """Widget with tools for drawing mask on an image in a PlotWidget."""
+
+ _maxLevelNumber = 255
+
+ def __init__(self, parent=None, plot=None):
+ super(MaskToolsWidget, self).__init__(parent, plot,
+ mask=ImageMask())
+ self._origin = (0., 0.) # Mask origin in plot
+ self._scale = (1., 1.) # Mask scale in plot
+ self._z = 1 # Mask layer in plot
+ self._data = numpy.zeros((0, 0), dtype=numpy.uint8) # Store image
+
+ self.__itemMaskUpdatedLock = LockReentrant()
+ self.__itemMaskUpdated = False
+
+ def __maskStateChanged(self) -> None:
+ """Handle mask commit to update item mask"""
+ item = self._mask.getDataItem()
+ if item is not None:
+ with self.__itemMaskUpdatedLock:
+ item.setMaskData(self._mask.getMask(copy=True), copy=False)
+
+ def setItemMaskUpdated(self, enabled: bool) -> None:
+ """Toggle item mask and mask tool synchronisation.
+
+ :param bool enabled: True to synchronise. Default: False
+ """
+ enabled = bool(enabled)
+ if enabled != self.__itemMaskUpdated:
+ if self.__itemMaskUpdated:
+ self._mask.sigStateChanged.disconnect(self.__maskStateChanged)
+ self.__itemMaskUpdated = enabled
+ if self.__itemMaskUpdated:
+ # Synchronize item and tool mask
+ self._setMaskedImage(self._mask.getDataItem())
+ self._mask.sigStateChanged.connect(self.__maskStateChanged)
+
+ def isItemMaskUpdated(self) -> bool:
+ """Returns whether or not item and mask tool masks are synchronised.
+
+ :rtype: bool
+ """
+ return self.__itemMaskUpdated
+
+ def setSelectionMask(self, mask, copy=True):
+ """Set the mask to a new array.
+
+ :param numpy.ndarray mask:
+ The array to use for the mask or None to reset the mask.
+ :type mask: numpy.ndarray of uint8 of dimension 2, C-contiguous.
+ Array of other types are converted.
+ :param bool copy: True (the default) to copy the array,
+ False to use it as is if possible.
+ :return: None if failed, shape of mask as 2-tuple if successful.
+ The mask can be cropped or padded to fit active image,
+ the returned shape is that of the active image.
+ """
+ if mask is None:
+ self.resetSelectionMask()
+ return self._data.shape[:2]
+
+ mask = numpy.array(mask, copy=False, dtype=numpy.uint8)
+ if len(mask.shape) != 2:
+ _logger.error('Not an image, shape: %d', len(mask.shape))
+ return None
+
+ # Handle mask with single level
+ if self.multipleMasks() == 'single':
+ mask = numpy.array(mask != 0, dtype=numpy.uint8)
+
+ # if mask has not changed, do nothing
+ if numpy.array_equal(mask, self.getSelectionMask()):
+ return mask.shape
+
+ if self._data.shape[0:2] == (0, 0) or mask.shape == self._data.shape[0:2]:
+ self._mask.setMask(mask, copy=copy)
+ self._mask.commit()
+ return mask.shape
+ else:
+ _logger.warning('Mask has not the same size as current image.'
+ ' Mask will be cropped or padded to fit image'
+ ' dimensions. %s != %s',
+ str(mask.shape), str(self._data.shape))
+ resizedMask = numpy.zeros(self._data.shape[0:2],
+ dtype=numpy.uint8)
+ height = min(self._data.shape[0], mask.shape[0])
+ width = min(self._data.shape[1], mask.shape[1])
+ resizedMask[:height,:width] = mask[:height,:width]
+ self._mask.setMask(resizedMask, copy=False)
+ self._mask.commit()
+ return resizedMask.shape
+
+ # Handle mask refresh on the plot
+ def _updatePlotMask(self):
+ """Update mask image in plot"""
+ mask = self.getSelectionMask(copy=False)
+ if mask is not None:
+ # get the mask from the plot
+ maskItem = self.plot.getImage(self._maskName)
+ mustBeAdded = maskItem is None
+ if mustBeAdded:
+ maskItem = items.MaskImageData()
+ maskItem.setName(self._maskName)
+ # update the items
+ maskItem.setData(mask, copy=False)
+ maskItem.setColormap(self._colormap)
+ maskItem.setOrigin(self._origin)
+ maskItem.setScale(self._scale)
+ maskItem.setZValue(self._z)
+
+ if mustBeAdded:
+ self.plot.addItem(maskItem)
+
+ elif self.plot.getImage(self._maskName):
+ self.plot.remove(self._maskName, kind='image')
+
+ def showEvent(self, event):
+ try:
+ self.plot.sigActiveImageChanged.disconnect(
+ self._activeImageChangedAfterCare)
+ except (RuntimeError, TypeError):
+ pass
+
+ # Sync with current active image
+ self._setMaskedImage(self.plot.getActiveImage())
+ self.plot.sigActiveImageChanged.connect(self._activeImageChanged)
+
+ def hideEvent(self, event):
+ try:
+ self.plot.sigActiveImageChanged.disconnect(
+ self._activeImageChanged)
+ except (RuntimeError, TypeError):
+ pass
+
+ image = self.getMaskedItem()
+ if image is not None:
+ try:
+ image.sigItemChanged.disconnect(self.__imageChanged)
+ except (RuntimeError, TypeError):
+ pass # TODO should not happen
+
+ if self.isMaskInteractionActivated():
+ # Disable drawing tool
+ self.browseAction.trigger()
+
+ if self.isItemMaskUpdated(): # No "after-care"
+ self._data = numpy.zeros((0, 0), dtype=numpy.uint8)
+ self._mask.setDataItem(None)
+ self._mask.reset()
+
+ if self.plot.getImage(self._maskName):
+ self.plot.remove(self._maskName, kind='image')
+
+ elif self.getSelectionMask(copy=False) is not None:
+ self.plot.sigActiveImageChanged.connect(
+ self._activeImageChangedAfterCare)
+
+ def _activeImageChanged(self, previous, current):
+ """Reacts upon active image change.
+
+ Only handle change of active image items here.
+ """
+ if previous != current:
+ image = self.plot.getActiveImage()
+ if image is not None and image.getName() == self._maskName:
+ image = None # Active image is the mask
+ self._setMaskedImage(image)
+
+ def _setOverlayColorForImage(self, image):
+ """Set the color of overlay adapted to image
+
+ :param image: :class:`.items.ImageBase` object to set color for.
+ """
+ if isinstance(image, items.ColormapMixIn):
+ colormap = image.getColormap()
+ self._defaultOverlayColor = rgba(
+ cursorColorForColormap(colormap['name']))
+ else:
+ self._defaultOverlayColor = rgba('black')
+
+ def _activeImageChangedAfterCare(self, *args):
+ """Check synchro of active image and mask when mask widget is hidden.
+
+ If active image has no more the same size as the mask, the mask is
+ removed, otherwise it is adjusted to origin, scale and z.
+ """
+ activeImage = self.plot.getActiveImage()
+ if activeImage is None or activeImage.getName() == self._maskName:
+ # No active image or active image is the mask...
+ self._data = numpy.zeros((0, 0), dtype=numpy.uint8)
+ self._mask.setDataItem(None)
+ self._mask.reset()
+
+ if self.plot.getImage(self._maskName):
+ self.plot.remove(self._maskName, kind='image')
+
+ self.plot.sigActiveImageChanged.disconnect(
+ self._activeImageChangedAfterCare)
+ else:
+ self._setOverlayColorForImage(activeImage)
+ self._setMaskColors(self.levelSpinBox.value(),
+ self.transparencySlider.value() /
+ self.transparencySlider.maximum())
+
+ self._origin = activeImage.getOrigin()
+ self._scale = activeImage.getScale()
+ self._z = activeImage.getZValue() + 1
+ self._data = activeImage.getData(copy=False)
+ if self._data.shape[:2] != self._mask.getMask(copy=False).shape:
+ # Image has not the same size, remove mask and stop listening
+ if self.plot.getImage(self._maskName):
+ self.plot.remove(self._maskName, kind='image')
+
+ self.plot.sigActiveImageChanged.disconnect(
+ self._activeImageChangedAfterCare)
+ else:
+ # Refresh in case origin, scale, z changed
+ self._mask.setDataItem(activeImage)
+ self._updatePlotMask()
+
+ def _setMaskedImage(self, image):
+ """Change the image that is used a reference to author the mask"""
+ previous = self.getMaskedItem()
+ if previous is not None and self.isVisible():
+ # Disconnect from previous image
+ try:
+ previous.sigItemChanged.disconnect(self.__imageChanged)
+ except (RuntimeError, TypeError):
+ pass # TODO fixme should not happen
+
+ # Set the image
+ self._mask.setDataItem(image)
+
+ if image is None: # No image, disable mask
+ self.setEnabled(False)
+
+ self._data = numpy.zeros((0, 0), dtype=numpy.uint8)
+ self._mask.reset()
+ self._mask.commit()
+
+ self._updateInteractiveMode()
+
+ else: # Update and connect to image's sigItemChanged
+ if self.isItemMaskUpdated():
+ if image.getMaskData(copy=False) is None:
+ # Image item has no mask: use current mask from the tool
+ image.setMaskData(
+ self.getSelectionMask(copy=False), copy=True)
+ else: # Image item has a mask: set it in tool
+ self.setSelectionMask(
+ image.getMaskData(copy=False), copy=True)
+ self._mask.resetHistory()
+ self.__imageUpdated()
+ if self.isVisible():
+ image.sigItemChanged.connect(self.__imageChanged)
+
+ def __imageChanged(self, event):
+ """Reacts upon image item changes"""
+ image = self._mask.getDataItem()
+ if image is None:
+ _logger.error("Mask is not attached to an image")
+ return
+
+ if event in (items.ItemChangedType.COLORMAP,
+ items.ItemChangedType.DATA,
+ items.ItemChangedType.POSITION,
+ items.ItemChangedType.SCALE,
+ items.ItemChangedType.VISIBLE,
+ items.ItemChangedType.ZVALUE):
+ self.__imageUpdated()
+
+ elif (event == items.ItemChangedType.MASK and
+ self.isItemMaskUpdated() and
+ not self.__itemMaskUpdatedLock.locked()):
+ # Update mask from the image item unless mask tool is updating it
+ self.setSelectionMask(image.getMaskData(copy=False), copy=True)
+
+ def __imageUpdated(self):
+ """Synchronize mask with current state of the image"""
+ image = self._mask.getDataItem()
+ if image is None:
+ _logger.error("No active image while expecting one")
+ return
+
+ self._setOverlayColorForImage(image)
+
+ self._setMaskColors(self.levelSpinBox.value(),
+ self.transparencySlider.value() /
+ self.transparencySlider.maximum())
+
+ self._origin = image.getOrigin()
+ self._scale = image.getScale()
+ self._z = image.getZValue() + 1
+ self._data = image.getData(copy=False)
+ self._mask.setDataItem(image)
+ if self._data.shape[:2] != self._mask.getMask(copy=False).shape:
+ self._mask.reset(self._data.shape[:2])
+ self._mask.commit()
+ else:
+ # Refresh in case origin, scale, z changed
+ self._updatePlotMask()
+
+ # Visible and with data
+ self.setEnabled(image.isVisible() and self._data.size != 0)
+
+ # Threshold tools only available for data with colormap
+ self.thresholdGroup.setEnabled(self._data.ndim == 2)
+
+ self._updateInteractiveMode()
+
+ # Handle whole mask operations
+ def load(self, filename):
+ """Load a mask from an image file.
+
+ :param str filename: File name from which to load the mask
+ :raise Exception: An exception in case of failure
+ :raise RuntimeWarning: In case the mask was applied but with some
+ import changes to notice
+ """
+ _, extension = os.path.splitext(filename)
+ extension = extension.lower()[1:]
+
+ if extension == "npy":
+ try:
+ mask = numpy.load(filename)
+ except IOError:
+ _logger.error("Can't load filename '%s'", filename)
+ _logger.debug("Backtrace", exc_info=True)
+ raise RuntimeError('File "%s" is not a numpy file.', filename)
+ elif extension in ["tif", "tiff"]:
+ try:
+ image = TiffIO(filename, mode="r")
+ mask = image.getImage(0)
+ except Exception as e:
+ _logger.error("Can't load filename %s", filename)
+ _logger.debug("Backtrace", exc_info=True)
+ raise e
+ elif extension == "edf":
+ try:
+ mask = EdfFile(filename, access='r').GetData(0)
+ except Exception as e:
+ _logger.error("Can't load filename %s", filename)
+ _logger.debug("Backtrace", exc_info=True)
+ raise e
+ elif extension == "msk":
+ try:
+ mask = fabio.open(filename).data
+ except Exception as e:
+ _logger.error("Can't load fit2d mask file")
+ _logger.debug("Backtrace", exc_info=True)
+ raise e
+ elif ("." + extension) in NEXUS_HDF5_EXT:
+ mask = self._loadFromHdf5(filename)
+ if mask is None:
+ raise IOError("Could not load mask from HDF5 dataset")
+ else:
+ msg = "Extension '%s' is not supported."
+ raise RuntimeError(msg % extension)
+
+ effectiveMaskShape = self.setSelectionMask(mask, copy=False)
+ if effectiveMaskShape is None:
+ return
+ if mask.shape != effectiveMaskShape:
+ msg = 'Mask was resized from %s to %s'
+ msg = msg % (str(mask.shape), str(effectiveMaskShape))
+ raise RuntimeWarning(msg)
+
+ def _loadMask(self):
+ """Open load mask dialog"""
+ dialog = qt.QFileDialog(self)
+ dialog.setWindowTitle("Load Mask")
+ dialog.setModal(1)
+
+ extensions = collections.OrderedDict()
+ extensions["EDF files"] = "*.edf"
+ extensions["TIFF files"] = "*.tif *.tiff"
+ extensions["NumPy binary files"] = "*.npy"
+ extensions["HDF5 files"] = _HDF5_EXT_STR
+ # Fit2D mask is displayed anyway fabio is here or not
+ # to show to the user that the option exists
+ extensions["Fit2D mask files"] = "*.msk"
+
+ filters = []
+ filters.append("All supported files (%s)" % " ".join(extensions.values()))
+ for name, extension in extensions.items():
+ filters.append("%s (%s)" % (name, extension))
+ filters.append("All files (*)")
+
+ dialog.setNameFilters(filters)
+ dialog.setFileMode(qt.QFileDialog.ExistingFile)
+ dialog.setDirectory(self.maskFileDir)
+ if not dialog.exec():
+ dialog.close()
+ return
+
+ filename = dialog.selectedFiles()[0]
+ dialog.close()
+
+ # Update the directory according to the user selection
+ self.maskFileDir = os.path.dirname(filename)
+
+ try:
+ self.load(filename)
+ except RuntimeWarning as e:
+ message = e.args[0]
+ msg = qt.QMessageBox(self)
+ msg.setIcon(qt.QMessageBox.Warning)
+ msg.setText("Mask loaded but an operation was applied.\n" + message)
+ msg.exec()
+ except Exception as e:
+ message = e.args[0]
+ msg = qt.QMessageBox(self)
+ msg.setIcon(qt.QMessageBox.Critical)
+ msg.setText("Cannot load mask from file. " + message)
+ msg.exec()
+
+ @staticmethod
+ def _loadFromHdf5(filename):
+ """Load a mask array from a HDF5 file.
+
+ :param str filename: name of an existing HDF5 file
+ :returns: A mask as a numpy array, or None if the interactive dialog
+ was cancelled
+ """
+ dataPath = _selectDataset(filename, mode=DatasetDialog.LoadMode)
+ if dataPath is None:
+ return None
+
+ with h5py.File(filename, "r") as h5f:
+ dataset = h5f.get(dataPath)
+ if not is_dataset(dataset):
+ raise IOError("%s is not a dataset" % dataPath)
+ mask = dataset[()]
+ return mask
+
+ def _saveMask(self):
+ """Open Save mask dialog"""
+ dialog = qt.QFileDialog(self)
+ dialog.setWindowTitle("Save Mask")
+ dialog.setOption(dialog.DontUseNativeDialog)
+ dialog.setModal(1)
+ hdf5Filter = 'HDF5 (%s)' % _HDF5_EXT_STR
+ filters = [
+ 'EDF (*.edf)',
+ 'TIFF (*.tif)',
+ 'NumPy binary file (*.npy)',
+ hdf5Filter,
+ # Fit2D mask is displayed anyway fabio is here or not
+ # to show to the user that the option exists
+ 'Fit2D mask (*.msk)',
+ ]
+ dialog.setNameFilters(filters)
+ dialog.setFileMode(qt.QFileDialog.AnyFile)
+ dialog.setAcceptMode(qt.QFileDialog.AcceptSave)
+ dialog.setDirectory(self.maskFileDir)
+
+ def onFilterSelection(filt_):
+ # disable overwrite confirmation for HDF5,
+ # because we append the data to existing files
+ if filt_ == hdf5Filter:
+ dialog.setOption(dialog.DontConfirmOverwrite)
+ else:
+ dialog.setOption(dialog.DontConfirmOverwrite, False)
+
+ dialog.filterSelected.connect(onFilterSelection)
+ if not dialog.exec():
+ dialog.close()
+ return
+
+ nameFilter = dialog.selectedNameFilter()
+ filename = dialog.selectedFiles()[0]
+ dialog.close()
+
+ if "HDF5" in nameFilter:
+ has_allowed_ext = False
+ for ext in NEXUS_HDF5_EXT:
+ if (len(filename) > len(ext) and
+ filename[-len(ext):].lower() == ext.lower()):
+ has_allowed_ext = True
+ extension = ext
+ if not has_allowed_ext:
+ extension = ".h5"
+ filename += ".h5"
+ else:
+ # convert filter name to extension name with the .
+ extension = nameFilter.split()[-1][2:-1]
+ if not filename.lower().endswith(extension):
+ filename += extension
+
+ if os.path.exists(filename) and "HDF5" not in nameFilter:
+ try:
+ os.remove(filename)
+ except IOError as e:
+ msg = qt.QMessageBox(self)
+ msg.setWindowTitle("Removing existing file")
+ msg.setIcon(qt.QMessageBox.Critical)
+
+ if hasattr(e, "strerror"):
+ strerror = e.strerror
+ else:
+ strerror = sys.exc_info()[1]
+ msg.setText("Cannot save.\n"
+ "Input Output Error: %s" % strerror)
+ msg.exec()
+ return
+
+ # Update the directory according to the user selection
+ self.maskFileDir = os.path.dirname(filename)
+
+ try:
+ self.save(filename, extension[1:])
+ except Exception as e:
+ msg = qt.QMessageBox(self)
+ msg.setWindowTitle("Saving mask file")
+ msg.setIcon(qt.QMessageBox.Critical)
+
+ if hasattr(e, "strerror"):
+ strerror = e.strerror
+ else:
+ strerror = sys.exc_info()[1]
+ msg.setText("Cannot save file %s\n%s" % (filename, strerror))
+ msg.exec()
+
+ def resetSelectionMask(self):
+ """Reset the mask"""
+ self._mask.reset(shape=self._data.shape[:2])
+ self._mask.commit()
+
+ def _plotDrawEvent(self, event):
+ """Handle draw events from the plot"""
+ if (self._drawingMode is None or
+ event['event'] not in ('drawingProgress', 'drawingFinished')):
+ return
+
+ if not len(self._data):
+ return
+
+ level = self.levelSpinBox.value()
+
+ if self._drawingMode == 'rectangle':
+ if event['event'] == 'drawingFinished':
+ # Convert from plot to array coords
+ doMask = self._isMasking()
+ ox, oy = self._origin
+ sx, sy = self._scale
+
+ height = int(abs(event['height'] / sy))
+ width = int(abs(event['width'] / sx))
+
+ row = int((event['y'] - oy) / sy)
+ if sy < 0:
+ row -= height
+
+ col = int((event['x'] - ox) / sx)
+ if sx < 0:
+ col -= width
+
+ self._mask.updateRectangle(
+ level,
+ row=row,
+ col=col,
+ height=height,
+ width=width,
+ mask=doMask)
+ self._mask.commit()
+
+ elif self._drawingMode == 'ellipse':
+ if event['event'] == 'drawingFinished':
+ doMask = self._isMasking()
+ # Convert from plot to array coords
+ center = (event['points'][0] - self._origin) / self._scale
+ size = event['points'][1] / self._scale
+ center = center.astype(numpy.int64) # (row, col)
+ self._mask.updateEllipse(level, center[1], center[0], size[1], size[0], doMask)
+ self._mask.commit()
+
+ elif self._drawingMode == 'polygon':
+ if event['event'] == 'drawingFinished':
+ doMask = self._isMasking()
+ # Convert from plot to array coords
+ vertices = (event['points'] - self._origin) / self._scale
+ vertices = vertices.astype(numpy.int64)[:, (1, 0)] # (row, col)
+ self._mask.updatePolygon(level, vertices, doMask)
+ self._mask.commit()
+
+ elif self._drawingMode == 'pencil':
+ doMask = self._isMasking()
+ # convert from plot to array coords
+ col, row = (event['points'][-1] - self._origin) / self._scale
+ col, row = int(col), int(row)
+ brushSize = self._getPencilWidth()
+
+ if self._lastPencilPos != (row, col):
+ if self._lastPencilPos is not None:
+ # Draw the line
+ self._mask.updateLine(
+ level,
+ self._lastPencilPos[0], self._lastPencilPos[1],
+ row, col,
+ brushSize,
+ doMask)
+
+ # Draw the very first, or last point
+ self._mask.updateDisk(level, row, col, brushSize / 2., doMask)
+
+ if event['event'] == 'drawingFinished':
+ self._mask.commit()
+ self._lastPencilPos = None
+ else:
+ self._lastPencilPos = row, col
+ else:
+ _logger.error("Drawing mode %s unsupported", self._drawingMode)
+
+ def _loadRangeFromColormapTriggered(self):
+ """Set range from active image colormap range"""
+ activeImage = self.plot.getActiveImage()
+ if (isinstance(activeImage, items.ColormapMixIn) and
+ activeImage.getName() != self._maskName):
+ # Update thresholds according to colormap
+ colormap = activeImage.getColormap()
+ if colormap['autoscale']:
+ min_ = numpy.nanmin(activeImage.getData(copy=False))
+ max_ = numpy.nanmax(activeImage.getData(copy=False))
+ else:
+ min_, max_ = colormap['vmin'], colormap['vmax']
+ self.minLineEdit.setText(str(min_))
+ self.maxLineEdit.setText(str(max_))
+
+
+class MaskToolsDockWidget(BaseMaskToolsDockWidget):
+ """:class:`MaskToolsWidget` embedded in a QDockWidget.
+
+ For integration in a :class:`PlotWindow`.
+
+ :param parent: See :class:`QDockWidget`
+ :param plot: The PlotWidget this widget is operating on
+ :paran str name: The title of this widget
+ """
+
+ def __init__(self, parent=None, plot=None, name='Mask'):
+ widget = MaskToolsWidget(plot=plot)
+ super(MaskToolsDockWidget, self).__init__(parent, name, widget)
diff --git a/src/silx/gui/plot/PlotActions.py b/src/silx/gui/plot/PlotActions.py
new file mode 100644
index 0000000..dd16221
--- /dev/null
+++ b/src/silx/gui/plot/PlotActions.py
@@ -0,0 +1,67 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Depracted module linking old PlotAction with the actions.xxx"""
+
+
+__author__ = ["V.A. Sole", "T. Vincent"]
+__license__ = "MIT"
+__date__ = "01/06/2017"
+
+from silx.utils.deprecation import deprecated_warning
+
+deprecated_warning(type_='module',
+ name=__file__,
+ reason='PlotActions refactoring',
+ replacement='plot.actions',
+ since_version='0.6')
+
+from .actions import PlotAction
+
+from .actions.io import CopyAction
+from .actions.io import PrintAction
+from .actions.io import SaveAction
+
+from .actions.control import ColormapAction
+from .actions.control import CrosshairAction
+from .actions.control import CurveStyleAction
+from .actions.control import GridAction
+from .actions.control import KeepAspectRatioAction
+from .actions.control import PanWithArrowKeysAction
+from .actions.control import ResetZoomAction
+from .actions.control import XAxisAutoScaleAction
+from .actions.control import XAxisLogarithmicAction
+from .actions.control import YAxisAutoScaleAction
+from .actions.control import YAxisLogarithmicAction
+from .actions.control import YAxisInvertedAction
+from .actions.control import ZoomInAction
+from .actions.control import ZoomOutAction
+
+from .actions.medfilt import MedianFilter1DAction
+from .actions.medfilt import MedianFilter2DAction
+from .actions.medfilt import MedianFilterAction
+
+from .actions.histogram import PixelIntensitiesHistoAction
+
+from .actions.fit import FitAction
diff --git a/src/silx/gui/plot/PlotEvents.py b/src/silx/gui/plot/PlotEvents.py
new file mode 100644
index 0000000..83f253c
--- /dev/null
+++ b/src/silx/gui/plot/PlotEvents.py
@@ -0,0 +1,166 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Functions to prepare events to be sent to Plot callback."""
+
+__author__ = ["V.A. Sole", "T. Vincent"]
+__license__ = "MIT"
+__date__ = "18/02/2016"
+
+
+import numpy as np
+
+
+def prepareDrawingSignal(event, type_, points, parameters=None):
+ """See Plot documentation for content of events"""
+ assert event in ('drawingProgress', 'drawingFinished')
+
+ if parameters is None:
+ parameters = {}
+
+ eventDict = {}
+ eventDict['event'] = event
+ eventDict['type'] = type_
+ points = np.array(points, dtype=np.float32)
+ points.shape = -1, 2
+ eventDict['points'] = points
+ eventDict['xdata'] = points[:, 0]
+ eventDict['ydata'] = points[:, 1]
+ if type_ in ('rectangle',):
+ eventDict['x'] = eventDict['xdata'].min()
+ eventDict['y'] = eventDict['ydata'].min()
+ eventDict['width'] = eventDict['xdata'].max() - eventDict['x']
+ eventDict['height'] = eventDict['ydata'].max() - eventDict['y']
+ eventDict['parameters'] = parameters.copy()
+ return eventDict
+
+
+def prepareMouseSignal(eventType, button, xData, yData, xPixel, yPixel):
+ """See Plot documentation for content of events"""
+ assert eventType in ('mouseMoved', 'mouseClicked', 'mouseDoubleClicked')
+ assert button in (None, 'left', 'middle', 'right')
+
+ return {'event': eventType,
+ 'x': xData,
+ 'y': yData,
+ 'xpixel': xPixel,
+ 'ypixel': yPixel,
+ 'button': button}
+
+
+def prepareHoverSignal(label, type_, posData, posPixel, draggable, selectable):
+ """See Plot documentation for content of events"""
+ return {'event': 'hover',
+ 'label': label,
+ 'type': type_,
+ 'x': posData[0],
+ 'y': posData[1],
+ 'xpixel': posPixel[0],
+ 'ypixel': posPixel[1],
+ 'draggable': draggable,
+ 'selectable': selectable}
+
+
+def prepareMarkerSignal(eventType, button, label, type_,
+ draggable, selectable,
+ posDataMarker,
+ posPixelCursor=None, posDataCursor=None):
+ """See Plot documentation for content of events"""
+ if eventType == 'markerClicked':
+ assert posPixelCursor is not None
+ assert posDataCursor is None
+
+ posDataCursor = list(posDataMarker)
+ if hasattr(posDataCursor[0], "__len__"):
+ posDataCursor[0] = posDataCursor[0][-1]
+ if hasattr(posDataCursor[1], "__len__"):
+ posDataCursor[1] = posDataCursor[1][-1]
+
+ elif eventType == 'markerMoving':
+ assert posPixelCursor is not None
+ assert posDataCursor is not None
+
+ elif eventType == 'markerMoved':
+ assert posPixelCursor is None
+ assert posDataCursor is None
+
+ posDataCursor = posDataMarker
+ else:
+ raise NotImplementedError("Unknown event type {0}".format(eventType))
+
+ eventDict = {'event': eventType,
+ 'button': button,
+ 'label': label,
+ 'type': type_,
+ 'x': posDataCursor[0],
+ 'y': posDataCursor[1],
+ 'xdata': posDataMarker[0],
+ 'ydata': posDataMarker[1],
+ 'draggable': draggable,
+ 'selectable': selectable}
+
+ if eventType in ('markerMoving', 'markerClicked'):
+ eventDict['xpixel'] = posPixelCursor[0]
+ eventDict['ypixel'] = posPixelCursor[1]
+
+ return eventDict
+
+
+def prepareImageSignal(button, label, type_, col, row,
+ x, y, xPixel, yPixel):
+ """See Plot documentation for content of events"""
+ return {'event': 'imageClicked',
+ 'button': button,
+ 'label': label,
+ 'type': type_,
+ 'col': col,
+ 'row': row,
+ 'x': x,
+ 'y': y,
+ 'xpixel': xPixel,
+ 'ypixel': yPixel}
+
+
+def prepareCurveSignal(button, label, type_, xData, yData,
+ x, y, xPixel, yPixel):
+ """See Plot documentation for content of events"""
+ return {'event': 'curveClicked',
+ 'button': button,
+ 'label': label,
+ 'type': type_,
+ 'xdata': xData,
+ 'ydata': yData,
+ 'x': x,
+ 'y': y,
+ 'xpixel': xPixel,
+ 'ypixel': yPixel}
+
+
+def prepareLimitsChangedSignal(sourceObj, xRange, yRange, y2Range):
+ """See Plot documentation for content of events"""
+ return {'event': 'limitsChanged',
+ 'source': id(sourceObj),
+ 'xdata': xRange,
+ 'ydata': yRange,
+ 'y2data': y2Range}
diff --git a/src/silx/gui/plot/PlotInteraction.py b/src/silx/gui/plot/PlotInteraction.py
new file mode 100644
index 0000000..6ebe6b1
--- /dev/null
+++ b/src/silx/gui/plot/PlotInteraction.py
@@ -0,0 +1,1746 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2014-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Implementation of the interaction for the :class:`Plot`."""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "15/02/2019"
+
+
+import math
+import numpy
+import time
+import weakref
+
+from .. import colors
+from .. import qt
+from . import items
+from .Interaction import (ClickOrDrag, LEFT_BTN, RIGHT_BTN, MIDDLE_BTN,
+ State, StateMachine)
+from .PlotEvents import (prepareCurveSignal, prepareDrawingSignal,
+ prepareHoverSignal, prepareImageSignal,
+ prepareMarkerSignal, prepareMouseSignal)
+
+from .backends.BackendBase import (CURSOR_POINTING, CURSOR_SIZE_HOR,
+ CURSOR_SIZE_VER, CURSOR_SIZE_ALL)
+
+from ._utils import (FLOAT32_SAFE_MIN, FLOAT32_MINPOS, FLOAT32_SAFE_MAX,
+ applyZoomToPlot)
+
+
+# Base class ##################################################################
+
+class _PlotInteraction(object):
+ """Base class for interaction handler.
+
+ It provides a weakref to the plot and methods to set/reset overlay.
+ """
+ def __init__(self, plot):
+ """Init.
+
+ :param plot: The plot to apply modifications to.
+ """
+ self._needReplot = False
+ self._selectionAreas = set()
+ self._plot = weakref.ref(plot) # Avoid cyclic-ref
+
+ @property
+ def plot(self):
+ plot = self._plot()
+ assert plot is not None
+ return plot
+
+ def setSelectionArea(self, points, fill, color, name='', shape='polygon'):
+ """Set a polygon selection area overlaid on the plot.
+ Multiple simultaneous areas are supported through the name parameter.
+
+ :param points: The 2D coordinates of the points of the polygon
+ :type points: An iterable of (x, y) coordinates
+ :param str fill: The fill mode: 'hatch', 'solid' or 'none'
+ :param color: RGBA color to use or None to disable display
+ :type color: list or tuple of 4 float in the range [0, 1]
+ :param name: The key associated with this selection area
+ :param str shape: Shape of the area in 'polygon', 'polylines'
+ """
+ assert shape in ('polygon', 'polylines')
+
+ if color is None:
+ return
+
+ points = numpy.asarray(points)
+
+ # TODO Not very nice, but as is for now
+ legend = '__SELECTION_AREA__' + name
+
+ fill = fill != 'none' # TODO not very nice either
+
+ greyed = colors.greyed(color)[0]
+ if greyed < 0.5:
+ color2 = "white"
+ else:
+ color2 = "black"
+
+ self.plot.addShape(points[:, 0], points[:, 1], legend=legend,
+ replace=False,
+ shape=shape, fill=fill,
+ color=color, linebgcolor=color2, linestyle="--",
+ overlay=True)
+
+ self._selectionAreas.add(legend)
+
+ def resetSelectionArea(self):
+ """Remove all selection areas set by setSelectionArea."""
+ for legend in self._selectionAreas:
+ self.plot.remove(legend, kind='item')
+ self._selectionAreas = set()
+
+
+# Zoom/Pan ####################################################################
+
+class _ZoomOnWheel(ClickOrDrag, _PlotInteraction):
+ """:class:`ClickOrDrag` state machine with zooming on mouse wheel.
+
+ Base class for :class:`Pan` and :class:`Zoom`
+ """
+
+ _DOUBLE_CLICK_TIMEOUT = 0.4
+
+ class Idle(ClickOrDrag.Idle):
+ def onWheel(self, x, y, angle):
+ scaleF = 1.1 if angle > 0 else 1. / 1.1
+ applyZoomToPlot(self.machine.plot, scaleF, (x, y))
+
+ def click(self, x, y, btn):
+ """Handle clicks by sending events
+
+ :param int x: Mouse X position in pixels
+ :param int y: Mouse Y position in pixels
+ :param btn: Clicked mouse button
+ """
+ if btn == LEFT_BTN:
+ lastClickTime, lastClickPos = self._lastClick
+
+ # Signal mouse double clicked event first
+ if (time.time() - lastClickTime) <= self._DOUBLE_CLICK_TIMEOUT:
+ # Use position of first click
+ eventDict = prepareMouseSignal('mouseDoubleClicked', 'left',
+ *lastClickPos)
+ self.plot.notify(**eventDict)
+
+ self._lastClick = 0., None
+ else:
+ # Signal mouse clicked event
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+ eventDict = prepareMouseSignal('mouseClicked', 'left',
+ dataPos[0], dataPos[1],
+ x, y)
+ self.plot.notify(**eventDict)
+
+ self._lastClick = time.time(), (dataPos[0], dataPos[1], x, y)
+
+ elif btn == RIGHT_BTN:
+ # Signal mouse clicked event
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+ eventDict = prepareMouseSignal('mouseClicked', 'right',
+ dataPos[0], dataPos[1],
+ x, y)
+ self.plot.notify(**eventDict)
+
+ def __init__(self, plot, **kwargs):
+ """Init.
+
+ :param plot: The plot to apply modifications to.
+ """
+ self._lastClick = 0., None
+
+ _PlotInteraction.__init__(self, plot)
+ ClickOrDrag.__init__(self, **kwargs)
+
+
+# Pan #########################################################################
+
+class Pan(_ZoomOnWheel):
+ """Pan plot content and zoom on wheel state machine."""
+
+ def _pixelToData(self, x, y):
+ xData, yData = self.plot.pixelToData(x, y)
+ _, y2Data = self.plot.pixelToData(x, y, axis='right')
+ return xData, yData, y2Data
+
+ def beginDrag(self, x, y, btn):
+ self._previousDataPos = self._pixelToData(x, y)
+
+ def drag(self, x, y, btn):
+ xData, yData, y2Data = self._pixelToData(x, y)
+ lastX, lastY, lastY2 = self._previousDataPos
+
+ xMin, xMax = self.plot.getXAxis().getLimits()
+ yMin, yMax = self.plot.getYAxis().getLimits()
+ y2Min, y2Max = self.plot.getYAxis(axis='right').getLimits()
+
+ if self.plot.getXAxis()._isLogarithmic():
+ try:
+ dx = math.log10(xData) - math.log10(lastX)
+ newXMin = pow(10., (math.log10(xMin) - dx))
+ newXMax = pow(10., (math.log10(xMax) - dx))
+ except (ValueError, OverflowError):
+ newXMin, newXMax = xMin, xMax
+
+ # Makes sure both values stays in positive float32 range
+ if newXMin < FLOAT32_MINPOS or newXMax > FLOAT32_SAFE_MAX:
+ newXMin, newXMax = xMin, xMax
+ else:
+ dx = xData - lastX
+ newXMin, newXMax = xMin - dx, xMax - dx
+
+ # Makes sure both values stays in float32 range
+ if newXMin < FLOAT32_SAFE_MIN or newXMax > FLOAT32_SAFE_MAX:
+ newXMin, newXMax = xMin, xMax
+
+ if self.plot.getYAxis()._isLogarithmic():
+ try:
+ dy = math.log10(yData) - math.log10(lastY)
+ newYMin = pow(10., math.log10(yMin) - dy)
+ newYMax = pow(10., math.log10(yMax) - dy)
+
+ dy2 = math.log10(y2Data) - math.log10(lastY2)
+ newY2Min = pow(10., math.log10(y2Min) - dy2)
+ newY2Max = pow(10., math.log10(y2Max) - dy2)
+ except (ValueError, OverflowError):
+ newYMin, newYMax = yMin, yMax
+ newY2Min, newY2Max = y2Min, y2Max
+
+ # Makes sure y and y2 stays in positive float32 range
+ if (newYMin < FLOAT32_MINPOS or newYMax > FLOAT32_SAFE_MAX or
+ newY2Min < FLOAT32_MINPOS or newY2Max > FLOAT32_SAFE_MAX):
+ newYMin, newYMax = yMin, yMax
+ newY2Min, newY2Max = y2Min, y2Max
+ else:
+ dy = yData - lastY
+ dy2 = y2Data - lastY2
+ newYMin, newYMax = yMin - dy, yMax - dy
+ newY2Min, newY2Max = y2Min - dy2, y2Max - dy2
+
+ # Makes sure y and y2 stays in float32 range
+ if (newYMin < FLOAT32_SAFE_MIN or
+ newYMax > FLOAT32_SAFE_MAX or
+ newY2Min < FLOAT32_SAFE_MIN or
+ newY2Max > FLOAT32_SAFE_MAX):
+ newYMin, newYMax = yMin, yMax
+ newY2Min, newY2Max = y2Min, y2Max
+
+ self.plot.setLimits(newXMin, newXMax,
+ newYMin, newYMax,
+ newY2Min, newY2Max)
+
+ self._previousDataPos = self._pixelToData(x, y)
+
+ def endDrag(self, startPos, endPos, btn):
+ del self._previousDataPos
+
+ def cancel(self):
+ pass
+
+
+# Zoom ########################################################################
+
+class Zoom(_ZoomOnWheel):
+ """Zoom-in/out state machine.
+
+ Zoom-in on selected area, zoom-out on right click,
+ and zoom on mouse wheel.
+ """
+
+ SURFACE_THRESHOLD = 5
+
+ def __init__(self, plot, color):
+ self.color = color
+
+ super(Zoom, self).__init__(plot)
+ self.plot.getLimitsHistory().clear()
+
+ def _areaWithAspectRatio(self, x0, y0, x1, y1):
+ _plotLeft, _plotTop, plotW, plotH = self.plot.getPlotBoundsInPixels()
+
+ areaX0, areaY0, areaX1, areaY1 = x0, y0, x1, y1
+
+ if plotH != 0.:
+ plotRatio = plotW / float(plotH)
+ width, height = math.fabs(x1 - x0), math.fabs(y1 - y0)
+
+ if height != 0. and width != 0.:
+ if width / height > plotRatio:
+ areaHeight = width / plotRatio
+ areaX0, areaX1 = x0, x1
+ center = 0.5 * (y0 + y1)
+ areaY0 = center - numpy.sign(y1 - y0) * 0.5 * areaHeight
+ areaY1 = center + numpy.sign(y1 - y0) * 0.5 * areaHeight
+ else:
+ areaWidth = height * plotRatio
+ areaY0, areaY1 = y0, y1
+ center = 0.5 * (x0 + x1)
+ areaX0 = center - numpy.sign(x1 - x0) * 0.5 * areaWidth
+ areaX1 = center + numpy.sign(x1 - x0) * 0.5 * areaWidth
+
+ return areaX0, areaY0, areaX1, areaY1
+
+ def beginDrag(self, x, y, btn):
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+ self.x0, self.y0 = x, y
+
+ def drag(self, x1, y1, btn):
+ if self.color is None:
+ return # Do not draw zoom area
+
+ dataPos = self.plot.pixelToData(x1, y1)
+ assert dataPos is not None
+
+ if self.plot.isKeepDataAspectRatio():
+ area = self._areaWithAspectRatio(self.x0, self.y0, x1, y1)
+ areaX0, areaY0, areaX1, areaY1 = area
+ areaPoints = ((areaX0, areaY0),
+ (areaX1, areaY0),
+ (areaX1, areaY1),
+ (areaX0, areaY1))
+ areaPoints = numpy.array([self.plot.pixelToData(
+ x, y, check=False) for (x, y) in areaPoints])
+
+ if self.color != 'video inverted':
+ areaColor = list(self.color)
+ areaColor[3] *= 0.25
+ else:
+ areaColor = [1., 1., 1., 1.]
+
+ self.setSelectionArea(areaPoints,
+ fill='none',
+ color=areaColor,
+ name="zoomedArea")
+
+ corners = ((self.x0, self.y0),
+ (self.x0, y1),
+ (x1, y1),
+ (x1, self.y0))
+ corners = numpy.array([self.plot.pixelToData(x, y, check=False)
+ for (x, y) in corners])
+
+ self.setSelectionArea(corners, fill='none', color=self.color)
+
+ def _zoom(self, x0, y0, x1, y1):
+ """Zoom to the rectangle view x0,y0 x1,y1.
+ """
+ startPos = x0, y0
+ endPos = x1, y1
+
+ # Store current zoom state in stack
+ self.plot.getLimitsHistory().push()
+
+ if self.plot.isKeepDataAspectRatio():
+ x0, y0, x1, y1 = self._areaWithAspectRatio(x0, y0, x1, y1)
+
+ # Convert to data space and set limits
+ x0, y0 = self.plot.pixelToData(x0, y0, check=False)
+
+ dataPos = self.plot.pixelToData(
+ startPos[0], startPos[1], axis="right", check=False)
+ y2_0 = dataPos[1]
+
+ x1, y1 = self.plot.pixelToData(x1, y1, check=False)
+
+ dataPos = self.plot.pixelToData(
+ endPos[0], endPos[1], axis="right", check=False)
+ y2_1 = dataPos[1]
+
+ xMin, xMax = min(x0, x1), max(x0, x1)
+ yMin, yMax = min(y0, y1), max(y0, y1)
+ y2Min, y2Max = min(y2_0, y2_1), max(y2_0, y2_1)
+
+ self.plot.setLimits(xMin, xMax, yMin, yMax, y2Min, y2Max)
+
+ def endDrag(self, startPos, endPos, btn):
+ x0, y0 = startPos
+ x1, y1 = endPos
+
+ if abs(x0 - x1) * abs(y0 - y1) >= self.SURFACE_THRESHOLD:
+ # Avoid empty zoom area
+ self._zoom(x0, y0, x1, y1)
+
+ self.resetSelectionArea()
+
+ def cancel(self):
+ if isinstance(self.state, self.states['drag']):
+ self.resetSelectionArea()
+
+
+# Select ######################################################################
+
+class Select(StateMachine, _PlotInteraction):
+ """Base class for drawing selection areas."""
+
+ def __init__(self, plot, parameters, states, state):
+ """Init a state machine.
+
+ :param plot: The plot to apply changes to.
+ :param dict parameters: A dict of parameters such as color.
+ :param dict states: The states of the state machine.
+ :param str state: The name of the initial state.
+ """
+ _PlotInteraction.__init__(self, plot)
+ self.parameters = parameters
+ StateMachine.__init__(self, states, state)
+
+ def onWheel(self, x, y, angle):
+ scaleF = 1.1 if angle > 0 else 1. / 1.1
+ applyZoomToPlot(self.plot, scaleF, (x, y))
+
+ @property
+ def color(self):
+ return self.parameters.get('color', None)
+
+
+class SelectPolygon(Select):
+ """Drawing selection polygon area state machine."""
+
+ DRAG_THRESHOLD_DIST = 4
+
+ class Idle(State):
+ def onPress(self, x, y, btn):
+ if btn == LEFT_BTN:
+ self.goto('select', x, y)
+ return True
+
+ class Select(State):
+ def enterState(self, x, y):
+ dataPos = self.machine.plot.pixelToData(x, y)
+ assert dataPos is not None
+ self._firstPos = dataPos
+ self.points = [dataPos, dataPos]
+
+ self.updateFirstPoint()
+
+ def updateFirstPoint(self):
+ """Update drawing first point, using self._firstPos"""
+ x, y = self.machine.plot.dataToPixel(*self._firstPos, check=False)
+
+ offset = self.machine.getDragThreshold()
+ points = [(x - offset, y - offset),
+ (x - offset, y + offset),
+ (x + offset, y + offset),
+ (x + offset, y - offset)]
+ points = [self.machine.plot.pixelToData(xpix, ypix, check=False)
+ for xpix, ypix in points]
+ self.machine.setSelectionArea(points, fill=None,
+ color=self.machine.color,
+ name='first_point')
+
+ def updateSelectionArea(self):
+ """Update drawing selection area using self.points"""
+ self.machine.setSelectionArea(self.points,
+ fill='hatch',
+ color=self.machine.color)
+ eventDict = prepareDrawingSignal('drawingProgress',
+ 'polygon',
+ self.points,
+ self.machine.parameters)
+ self.machine.plot.notify(**eventDict)
+
+ def validate(self):
+ if len(self.points) > 2:
+ self.closePolygon()
+ else:
+ # It would be nice to have a cancel event.
+ # The plot is not aware that the interaction was cancelled
+ self.machine.cancel()
+
+ def closePolygon(self):
+ self.machine.resetSelectionArea()
+ self.points[-1] = self.points[0]
+ eventDict = prepareDrawingSignal('drawingFinished',
+ 'polygon',
+ self.points,
+ self.machine.parameters)
+ self.machine.plot.notify(**eventDict)
+ self.goto('idle')
+
+ def onWheel(self, x, y, angle):
+ self.machine.onWheel(x, y, angle)
+ self.updateFirstPoint()
+
+ def onRelease(self, x, y, btn):
+ if btn == LEFT_BTN:
+ # checking if the position is close to the first point
+ # if yes : closing the "loop"
+ firstPos = self.machine.plot.dataToPixel(*self._firstPos,
+ check=False)
+ dx, dy = abs(firstPos[0] - x), abs(firstPos[1] - y)
+
+ threshold = self.machine.getDragThreshold()
+
+ # Only allow to close polygon after first point
+ if len(self.points) > 2 and dx <= threshold and dy <= threshold:
+ self.closePolygon()
+ return False
+
+ # Update polygon last point not too close to previous one
+ dataPos = self.machine.plot.pixelToData(x, y)
+ assert dataPos is not None
+ self.updateSelectionArea()
+
+ # checking that the new points isnt the same (within range)
+ # of the previous one
+ # This has to be done because sometimes the mouse release event
+ # is caught right after entering the Select state (i.e : press
+ # in Idle state, but with a slightly different position that
+ # the mouse press. So we had the two first vertices that were
+ # almost identical.
+ previousPos = self.machine.plot.dataToPixel(*self.points[-2],
+ check=False)
+ dx, dy = abs(previousPos[0] - x), abs(previousPos[1] - y)
+ if dx >= threshold or dy >= threshold:
+ self.points.append(dataPos)
+ else:
+ self.points[-1] = dataPos
+
+ return True
+ return False
+
+ def onMove(self, x, y):
+ firstPos = self.machine.plot.dataToPixel(*self._firstPos,
+ check=False)
+ dx, dy = abs(firstPos[0] - x), abs(firstPos[1] - y)
+ threshold = self.machine.getDragThreshold()
+
+ if dx <= threshold and dy <= threshold:
+ x, y = firstPos # Snap to first point
+
+ dataPos = self.machine.plot.pixelToData(x, y)
+ assert dataPos is not None
+ self.points[-1] = dataPos
+ self.updateSelectionArea()
+
+ def __init__(self, plot, parameters):
+ states = {
+ 'idle': SelectPolygon.Idle,
+ 'select': SelectPolygon.Select
+ }
+ super(SelectPolygon, self).__init__(plot, parameters,
+ states, 'idle')
+
+ def cancel(self):
+ if isinstance(self.state, self.states['select']):
+ self.resetSelectionArea()
+
+ def getDragThreshold(self):
+ """Return dragging ratio with device to pixel ratio applied.
+
+ :rtype: float
+ """
+ ratio = self.plot.window().windowHandle().devicePixelRatio()
+ return self.DRAG_THRESHOLD_DIST * ratio
+
+
+class Select2Points(Select):
+ """Base class for drawing selection based on 2 input points."""
+ class Idle(State):
+ def onPress(self, x, y, btn):
+ if btn == LEFT_BTN:
+ self.goto('start', x, y)
+ return True
+
+ class Start(State):
+ def enterState(self, x, y):
+ self.machine.beginSelect(x, y)
+
+ def onMove(self, x, y):
+ self.goto('select', x, y)
+
+ def onRelease(self, x, y, btn):
+ if btn == LEFT_BTN:
+ self.goto('select', x, y)
+ return True
+
+ class Select(State):
+ def enterState(self, x, y):
+ self.onMove(x, y)
+
+ def onMove(self, x, y):
+ self.machine.select(x, y)
+
+ def onRelease(self, x, y, btn):
+ if btn == LEFT_BTN:
+ self.machine.endSelect(x, y)
+ self.goto('idle')
+
+ def __init__(self, plot, parameters):
+ states = {
+ 'idle': Select2Points.Idle,
+ 'start': Select2Points.Start,
+ 'select': Select2Points.Select
+ }
+ super(Select2Points, self).__init__(plot, parameters,
+ states, 'idle')
+
+ def beginSelect(self, x, y):
+ pass
+
+ def select(self, x, y):
+ pass
+
+ def endSelect(self, x, y):
+ pass
+
+ def cancelSelect(self):
+ pass
+
+ def cancel(self):
+ if isinstance(self.state, self.states['select']):
+ self.cancelSelect()
+
+
+class SelectEllipse(Select2Points):
+ """Drawing ellipse selection area state machine."""
+ def beginSelect(self, x, y):
+ self.center = self.plot.pixelToData(x, y)
+ assert self.center is not None
+
+ def _getEllipseSize(self, pointInEllipse):
+ """
+ Returns the size from the center to the bounding box of the ellipse.
+
+ :param Tuple[float,float] pointInEllipse: A point of the ellipse
+ :rtype: Tuple[float,float]
+ """
+ x = abs(self.center[0] - pointInEllipse[0])
+ y = abs(self.center[1] - pointInEllipse[1])
+ if x == 0 or y == 0:
+ return x, y
+ # Ellipse definitions
+ # e: eccentricity
+ # a: length fron center to bounding box width
+ # b: length fron center to bounding box height
+ # Equations
+ # (1) b < a
+ # (2) For x,y a point in the ellipse: x^2/a^2 + y^2/b^2 = 1
+ # (3) b = a * sqrt(1-e^2)
+ # (4) e = sqrt(a^2 - b^2) / a
+
+ # The eccentricity of the ellipse defined by a,b=x,y is the same
+ # as the one we are searching for.
+ swap = x < y
+ if swap:
+ x, y = y, x
+ e = math.sqrt(x**2 - y**2) / x
+ # From (2) using (3) to replace b
+ # a^2 = x^2 + y^2 / (1-e^2)
+ a = math.sqrt(x**2 + y**2 / (1.0 - e**2))
+ b = a * math.sqrt(1 - e**2)
+ if swap:
+ a, b = b, a
+ return a, b
+
+ def select(self, x, y):
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+ width, height = self._getEllipseSize(dataPos)
+
+ # Circle used for circle preview
+ nbpoints = 27.
+ angles = numpy.arange(nbpoints) * numpy.pi * 2.0 / nbpoints
+ circleShape = numpy.array((numpy.cos(angles) * width,
+ numpy.sin(angles) * height)).T
+ circleShape += numpy.array(self.center)
+
+ self.setSelectionArea(circleShape,
+ shape="polygon",
+ fill='hatch',
+ color=self.color)
+
+ eventDict = prepareDrawingSignal('drawingProgress',
+ 'ellipse',
+ (self.center, (width, height)),
+ self.parameters)
+ self.plot.notify(**eventDict)
+
+ def endSelect(self, x, y):
+ self.resetSelectionArea()
+
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+ width, height = self._getEllipseSize(dataPos)
+
+ eventDict = prepareDrawingSignal('drawingFinished',
+ 'ellipse',
+ (self.center, (width, height)),
+ self.parameters)
+ self.plot.notify(**eventDict)
+
+ def cancelSelect(self):
+ self.resetSelectionArea()
+
+
+class SelectRectangle(Select2Points):
+ """Drawing rectangle selection area state machine."""
+ def beginSelect(self, x, y):
+ self.startPt = self.plot.pixelToData(x, y)
+ assert self.startPt is not None
+
+ def select(self, x, y):
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+
+ self.setSelectionArea((self.startPt,
+ (self.startPt[0], dataPos[1]),
+ dataPos,
+ (dataPos[0], self.startPt[1])),
+ fill='hatch',
+ color=self.color)
+
+ eventDict = prepareDrawingSignal('drawingProgress',
+ 'rectangle',
+ (self.startPt, dataPos),
+ self.parameters)
+ self.plot.notify(**eventDict)
+
+ def endSelect(self, x, y):
+ self.resetSelectionArea()
+
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+
+ eventDict = prepareDrawingSignal('drawingFinished',
+ 'rectangle',
+ (self.startPt, dataPos),
+ self.parameters)
+ self.plot.notify(**eventDict)
+
+ def cancelSelect(self):
+ self.resetSelectionArea()
+
+
+class SelectLine(Select2Points):
+ """Drawing line selection area state machine."""
+ def beginSelect(self, x, y):
+ self.startPt = self.plot.pixelToData(x, y)
+ assert self.startPt is not None
+
+ def select(self, x, y):
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+
+ self.setSelectionArea((self.startPt, dataPos),
+ fill='hatch',
+ color=self.color)
+
+ eventDict = prepareDrawingSignal('drawingProgress',
+ 'line',
+ (self.startPt, dataPos),
+ self.parameters)
+ self.plot.notify(**eventDict)
+
+ def endSelect(self, x, y):
+ self.resetSelectionArea()
+
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+
+ eventDict = prepareDrawingSignal('drawingFinished',
+ 'line',
+ (self.startPt, dataPos),
+ self.parameters)
+ self.plot.notify(**eventDict)
+
+ def cancelSelect(self):
+ self.resetSelectionArea()
+
+
+class Select1Point(Select):
+ """Base class for drawing selection area based on one input point."""
+ class Idle(State):
+ def onPress(self, x, y, btn):
+ if btn == LEFT_BTN:
+ self.goto('select', x, y)
+ return True
+
+ class Select(State):
+ def enterState(self, x, y):
+ self.onMove(x, y)
+
+ def onMove(self, x, y):
+ self.machine.select(x, y)
+
+ def onRelease(self, x, y, btn):
+ if btn == LEFT_BTN:
+ self.machine.endSelect(x, y)
+ self.goto('idle')
+
+ def onWheel(self, x, y, angle):
+ self.machine.onWheel(x, y, angle) # Call select default wheel
+ self.machine.select(x, y)
+
+ def __init__(self, plot, parameters):
+ states = {
+ 'idle': Select1Point.Idle,
+ 'select': Select1Point.Select
+ }
+ super(Select1Point, self).__init__(plot, parameters, states, 'idle')
+
+ def select(self, x, y):
+ pass
+
+ def endSelect(self, x, y):
+ pass
+
+ def cancelSelect(self):
+ pass
+
+ def cancel(self):
+ if isinstance(self.state, self.states['select']):
+ self.cancelSelect()
+
+
+class SelectHLine(Select1Point):
+ """Drawing a horizontal line selection area state machine."""
+ def _hLine(self, y):
+ """Return points in data coords of the segment visible in the plot.
+
+ Supports non-orthogonal axes.
+ """
+ left, _top, width, _height = self.plot.getPlotBoundsInPixels()
+
+ dataPos1 = self.plot.pixelToData(left, y, check=False)
+ dataPos2 = self.plot.pixelToData(left + width, y, check=False)
+ return dataPos1, dataPos2
+
+ def select(self, x, y):
+ points = self._hLine(y)
+ self.setSelectionArea(points, fill='hatch', color=self.color)
+
+ eventDict = prepareDrawingSignal('drawingProgress',
+ 'hline',
+ points,
+ self.parameters)
+ self.plot.notify(**eventDict)
+
+ def endSelect(self, x, y):
+ self.resetSelectionArea()
+
+ eventDict = prepareDrawingSignal('drawingFinished',
+ 'hline',
+ self._hLine(y),
+ self.parameters)
+ self.plot.notify(**eventDict)
+
+ def cancelSelect(self):
+ self.resetSelectionArea()
+
+
+class SelectVLine(Select1Point):
+ """Drawing a vertical line selection area state machine."""
+ def _vLine(self, x):
+ """Return points in data coords of the segment visible in the plot.
+
+ Supports non-orthogonal axes.
+ """
+ _left, top, _width, height = self.plot.getPlotBoundsInPixels()
+
+ dataPos1 = self.plot.pixelToData(x, top, check=False)
+ dataPos2 = self.plot.pixelToData(x, top + height, check=False)
+ return dataPos1, dataPos2
+
+ def select(self, x, y):
+ points = self._vLine(x)
+ self.setSelectionArea(points, fill='hatch', color=self.color)
+
+ eventDict = prepareDrawingSignal('drawingProgress',
+ 'vline',
+ points,
+ self.parameters)
+ self.plot.notify(**eventDict)
+
+ def endSelect(self, x, y):
+ self.resetSelectionArea()
+
+ eventDict = prepareDrawingSignal('drawingFinished',
+ 'vline',
+ self._vLine(x),
+ self.parameters)
+ self.plot.notify(**eventDict)
+
+ def cancelSelect(self):
+ self.resetSelectionArea()
+
+
+class DrawFreeHand(Select):
+ """Interaction for drawing pencil. It display the preview of the pencil
+ before pressing the mouse.
+ """
+
+ class Idle(State):
+ def onPress(self, x, y, btn):
+ if btn == LEFT_BTN:
+ self.goto('select', x, y)
+ return True
+
+ def onMove(self, x, y):
+ self.machine.updatePencilShape(x, y)
+
+ def onLeave(self):
+ self.machine.cancel()
+
+ class Select(State):
+ def enterState(self, x, y):
+ self.__isOut = False
+ self.machine.setFirstPoint(x, y)
+
+ def onMove(self, x, y):
+ self.machine.updatePencilShape(x, y)
+ self.machine.select(x, y)
+
+ def onRelease(self, x, y, btn):
+ if btn == LEFT_BTN:
+ if self.__isOut:
+ self.machine.resetSelectionArea()
+ self.machine.endSelect(x, y)
+ self.goto('idle')
+
+ def onEnter(self):
+ self.__isOut = False
+
+ def onLeave(self):
+ self.__isOut = True
+
+ def __init__(self, plot, parameters):
+ # Circle used for pencil preview
+ angle = numpy.arange(13.) * numpy.pi * 2.0 / 13.
+ size = parameters.get('width', 1.) * 0.5
+ self._circle = size * numpy.array((numpy.cos(angle),
+ numpy.sin(angle))).T
+
+ states = {
+ 'idle': DrawFreeHand.Idle,
+ 'select': DrawFreeHand.Select
+ }
+ super(DrawFreeHand, self).__init__(plot, parameters, states, 'idle')
+
+ @property
+ def width(self):
+ return self.parameters.get('width', None)
+
+ def setFirstPoint(self, x, y):
+ self._points = []
+ self.select(x, y)
+
+ def updatePencilShape(self, x, y):
+ center = self.plot.pixelToData(x, y, check=False)
+ assert center is not None
+
+ polygon = center + self._circle
+
+ self.setSelectionArea(polygon, fill='none', color=self.color)
+
+ def select(self, x, y):
+ pos = self.plot.pixelToData(x, y, check=False)
+ if len(self._points) > 0:
+ if self._points[-1] == pos:
+ # Skip same points
+ return
+ self._points.append(pos)
+ eventDict = prepareDrawingSignal('drawingProgress',
+ 'polylines',
+ self._points,
+ self.parameters)
+ self.plot.notify(**eventDict)
+
+ def endSelect(self, x, y):
+ pos = self.plot.pixelToData(x, y, check=False)
+ if len(self._points) > 0:
+ if self._points[-1] != pos:
+ # Append if different
+ self._points.append(pos)
+
+ eventDict = prepareDrawingSignal('drawingFinished',
+ 'polylines',
+ self._points,
+ self.parameters)
+ self.plot.notify(**eventDict)
+ self._points = None
+
+ def cancelSelect(self):
+ self.resetSelectionArea()
+
+ def cancel(self):
+ self.resetSelectionArea()
+
+
+class SelectFreeLine(ClickOrDrag, _PlotInteraction):
+ """Base class for drawing free lines with tools such as pencil."""
+
+ def __init__(self, plot, parameters):
+ """Init a state machine.
+
+ :param plot: The plot to apply changes to.
+ :param dict parameters: A dict of parameters such as color.
+ """
+ # self.DRAG_THRESHOLD_SQUARE_DIST = 1 # Disable first move threshold
+ self._points = []
+ ClickOrDrag.__init__(self)
+ _PlotInteraction.__init__(self, plot)
+ self.parameters = parameters
+
+ def onWheel(self, x, y, angle):
+ scaleF = 1.1 if angle > 0 else 1. / 1.1
+ applyZoomToPlot(self.plot, scaleF, (x, y))
+
+ @property
+ def color(self):
+ return self.parameters.get('color', None)
+
+ def click(self, x, y, btn):
+ if btn == LEFT_BTN:
+ self._processEvent(x, y, isLast=True)
+
+ def beginDrag(self, x, y, btn):
+ self._processEvent(x, y, isLast=False)
+
+ def drag(self, x, y, btn):
+ self._processEvent(x, y, isLast=False)
+
+ def endDrag(self, startPos, endPos, btn):
+ x, y = endPos
+ self._processEvent(x, y, isLast=True)
+
+ def cancel(self):
+ self.resetSelectionArea()
+ self._points = []
+
+ def _processEvent(self, x, y, isLast):
+ dataPos = self.plot.pixelToData(x, y, check=False)
+ isNewPoint = not self._points or dataPos != self._points[-1]
+
+ if isNewPoint:
+ self._points.append(dataPos)
+
+ if isNewPoint or isLast:
+ eventDict = prepareDrawingSignal(
+ 'drawingFinished' if isLast else 'drawingProgress',
+ 'polylines',
+ self._points,
+ self.parameters)
+ self.plot.notify(**eventDict)
+
+ if not isLast:
+ self.setSelectionArea(self._points, fill='none', color=self.color,
+ shape='polylines')
+ else:
+ self.cancel()
+
+
+# ItemInteraction #############################################################
+
+class ItemsInteraction(ClickOrDrag, _PlotInteraction):
+ """Interaction with items (markers, curves and images).
+
+ This class provides selection and dragging of plot primitives
+ that support those interaction.
+ It is also meant to be combined with the zoom interaction.
+ """
+
+ class Idle(ClickOrDrag.Idle):
+ def __init__(self, *args, **kw):
+ super(ItemsInteraction.Idle, self).__init__(*args, **kw)
+ self._hoverMarker = None
+
+ def onWheel(self, x, y, angle):
+ scaleF = 1.1 if angle > 0 else 1. / 1.1
+ applyZoomToPlot(self.machine.plot, scaleF, (x, y))
+
+ def onMove(self, x, y):
+ marker = self.machine.plot._getMarkerAt(x, y)
+
+ if marker is not None:
+ dataPos = self.machine.plot.pixelToData(x, y)
+ assert dataPos is not None
+ eventDict = prepareHoverSignal(
+ marker.getName(), 'marker',
+ dataPos, (x, y),
+ marker.isDraggable(),
+ marker.isSelectable())
+ self.machine.plot.notify(**eventDict)
+
+ if marker != self._hoverMarker:
+ self._hoverMarker = marker
+
+ if marker is None:
+ self.machine.plot.setGraphCursorShape()
+
+ elif marker.isDraggable():
+ if isinstance(marker, items.YMarker):
+ self.machine.plot.setGraphCursorShape(CURSOR_SIZE_VER)
+ elif isinstance(marker, items.XMarker):
+ self.machine.plot.setGraphCursorShape(CURSOR_SIZE_HOR)
+ else:
+ self.machine.plot.setGraphCursorShape(CURSOR_SIZE_ALL)
+
+ elif marker.isSelectable():
+ self.machine.plot.setGraphCursorShape(CURSOR_POINTING)
+ else:
+ self.machine.plot.setGraphCursorShape()
+
+ return True
+
+ def __init__(self, plot):
+ self._pan = Pan(plot)
+
+ _PlotInteraction.__init__(self, plot)
+ ClickOrDrag.__init__(self,
+ clickButtons=(LEFT_BTN, RIGHT_BTN),
+ dragButtons=(LEFT_BTN, MIDDLE_BTN))
+
+ def click(self, x, y, btn):
+ """Handle mouse click
+
+ :param x: X position of the mouse in pixels
+ :param y: Y position of the mouse in pixels
+ :param btn: Pressed button id
+ :return: True if click is catched by an item, False otherwise
+ """
+ # Signal mouse clicked event
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+ eventDict = prepareMouseSignal('mouseClicked', btn,
+ dataPos[0], dataPos[1],
+ x, y)
+ self.plot.notify(**eventDict)
+
+ eventDict = self._handleClick(x, y, btn)
+ if eventDict is not None:
+ self.plot.notify(**eventDict)
+
+ def _handleClick(self, x, y, btn):
+ """Perform picking and prepare event if click is handled here
+
+ :param x: X position of the mouse in pixels
+ :param y: Y position of the mouse in pixels
+ :param btn: Pressed button id
+ :return: event description to send of None if not handling event.
+ :rtype: dict or None
+ """
+
+ if btn == LEFT_BTN:
+ result = self.plot._pickTopMost(x, y, lambda i: i.isSelectable())
+ if result is None:
+ return None
+
+ item = result.getItem()
+
+ if isinstance(item, items.MarkerBase):
+ xData, yData = item.getPosition()
+ if xData is None:
+ xData = [0, 1]
+ if yData is None:
+ yData = [0, 1]
+
+ eventDict = prepareMarkerSignal('markerClicked',
+ 'left',
+ item.getName(),
+ 'marker',
+ item.isDraggable(),
+ item.isSelectable(),
+ (xData, yData),
+ (x, y), None)
+ return eventDict
+
+ elif isinstance(item, items.Curve):
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+
+ xData = item.getXData(copy=False)
+ yData = item.getYData(copy=False)
+
+ indices = result.getIndices(copy=False)
+ eventDict = prepareCurveSignal('left',
+ item.getName(),
+ 'curve',
+ xData[indices],
+ yData[indices],
+ dataPos[0], dataPos[1],
+ x, y)
+ return eventDict
+
+ elif isinstance(item, items.ImageBase):
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+
+ indices = result.getIndices(copy=False)
+ row, column = indices[0][0], indices[1][0]
+ eventDict = prepareImageSignal('left',
+ item.getName(),
+ 'image',
+ column, row,
+ dataPos[0], dataPos[1],
+ x, y)
+ return eventDict
+
+ return None
+
+ def _signalMarkerMovingEvent(self, eventType, marker, x, y):
+ assert marker is not None
+
+ xData, yData = marker.getPosition()
+ if xData is None:
+ xData = [0, 1]
+ if yData is None:
+ yData = [0, 1]
+
+ posDataCursor = self.plot.pixelToData(x, y)
+ assert posDataCursor is not None
+
+ eventDict = prepareMarkerSignal(eventType,
+ 'left',
+ marker.getName(),
+ 'marker',
+ marker.isDraggable(),
+ marker.isSelectable(),
+ (xData, yData),
+ (x, y),
+ posDataCursor)
+ self.plot.notify(**eventDict)
+
+ @staticmethod
+ def __isDraggableItem(item):
+ return isinstance(item, items.DraggableMixIn) and item.isDraggable()
+
+ def __terminateDrag(self):
+ """Finalize a drag operation by reseting to initial state"""
+ self.plot.setGraphCursorShape()
+ self.draggedItemRef = None
+
+ def beginDrag(self, x, y, btn):
+ """Handle begining of drag interaction
+
+ :param x: X position of the mouse in pixels
+ :param y: Y position of the mouse in pixels
+ :param str btn: The mouse button for which a drag is starting.
+ :return: True if drag is catched by an item, False otherwise
+ """
+ if btn == LEFT_BTN:
+ self._lastPos = self.plot.pixelToData(x, y)
+ assert self._lastPos is not None
+
+ result = self.plot._pickTopMost(x, y, self.__isDraggableItem)
+ item = result.getItem() if result is not None else None
+
+ self.draggedItemRef = None if item is None else weakref.ref(item)
+
+ if item is None:
+ self.__terminateDrag()
+ return False
+
+ if isinstance(item, items.MarkerBase):
+ self._signalMarkerMovingEvent('markerMoving', item, x, y)
+ item._startDrag()
+
+ return True
+ elif btn == MIDDLE_BTN:
+ self._pan.beginDrag(x, y, btn)
+ return True
+
+ def drag(self, x, y, btn):
+ if btn == LEFT_BTN:
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+
+ item = None if self.draggedItemRef is None else self.draggedItemRef()
+ if item is not None:
+ item.drag(self._lastPos, dataPos)
+
+ if isinstance(item, items.MarkerBase):
+ self._signalMarkerMovingEvent('markerMoving', item, x, y)
+
+ self._lastPos = dataPos
+ elif btn == MIDDLE_BTN:
+ self._pan.drag(x, y, btn)
+
+ def endDrag(self, startPos, endPos, btn):
+ if btn == LEFT_BTN:
+ item = None if self.draggedItemRef is None else self.draggedItemRef()
+ if isinstance(item, items.MarkerBase):
+ posData = list(item.getPosition())
+ if posData[0] is None:
+ posData[0] = 1.
+ if posData[1] is None:
+ posData[1] = 1.
+
+ eventDict = prepareMarkerSignal(
+ 'markerMoved',
+ 'left',
+ item.getLegend(),
+ 'marker',
+ item.isDraggable(),
+ item.isSelectable(),
+ posData)
+ self.plot.notify(**eventDict)
+ item._endDrag()
+
+ self.__terminateDrag()
+ elif btn == MIDDLE_BTN:
+ self._pan.endDrag(startPos, endPos, btn)
+
+ def cancel(self):
+ self._pan.cancel()
+ self.__terminateDrag()
+
+
+class ItemsInteractionForCombo(ItemsInteraction):
+ """Interaction with items to combine through :class:`FocusManager`.
+ """
+
+ class Idle(ItemsInteraction.Idle):
+ @staticmethod
+ def __isItemSelectableOrDraggable(item):
+ return (item.isSelectable() or (
+ isinstance(item, items.DraggableMixIn) and item.isDraggable()))
+
+ def onPress(self, x, y, btn):
+ if btn == LEFT_BTN:
+ result = self.machine.plot._pickTopMost(
+ x, y, self.__isItemSelectableOrDraggable)
+ if result is not None: # Request focus and handle interaction
+ self.goto('clickOrDrag', x, y, btn)
+ return True
+ else: # Do not request focus
+ return False
+ else:
+ return super().onPress(x, y, btn)
+
+
+# FocusManager ################################################################
+
+class FocusManager(StateMachine):
+ """Manages focus across multiple event handlers
+
+ On press an event handler can acquire focus.
+ By default it looses focus when all buttons are released.
+ """
+ class Idle(State):
+ def onPress(self, x, y, btn):
+ if btn == LEFT_BTN:
+ for eventHandler in self.machine.eventHandlers:
+ requestFocus = eventHandler.handleEvent('press', x, y, btn)
+ if requestFocus:
+ self.goto('focus', eventHandler, btn)
+ break
+
+ def _processEvent(self, *args):
+ for eventHandler in self.machine.eventHandlers:
+ consumeEvent = eventHandler.handleEvent(*args)
+ if consumeEvent:
+ break
+
+ def onMove(self, x, y):
+ self._processEvent('move', x, y)
+
+ def onRelease(self, x, y, btn):
+ if btn == LEFT_BTN:
+ self._processEvent('release', x, y, btn)
+
+ def onWheel(self, x, y, angle):
+ self._processEvent('wheel', x, y, angle)
+
+ class Focus(State):
+ def enterState(self, eventHandler, btn):
+ self.eventHandler = eventHandler
+ self.focusBtns = {btn}
+
+ def validate(self):
+ self.eventHandler.validate()
+ self.goto('idle')
+
+ def onPress(self, x, y, btn):
+ if btn == LEFT_BTN:
+ self.focusBtns.add(btn)
+ self.eventHandler.handleEvent('press', x, y, btn)
+
+ def onMove(self, x, y):
+ self.eventHandler.handleEvent('move', x, y)
+
+ def onRelease(self, x, y, btn):
+ if btn == LEFT_BTN:
+ self.focusBtns.discard(btn)
+ requestFocus = self.eventHandler.handleEvent('release', x, y, btn)
+ if len(self.focusBtns) == 0 and not requestFocus:
+ self.goto('idle')
+
+ def onWheel(self, x, y, angleInDegrees):
+ self.eventHandler.handleEvent('wheel', x, y, angleInDegrees)
+
+ def __init__(self, eventHandlers=()):
+ self.eventHandlers = list(eventHandlers)
+
+ states = {
+ 'idle': FocusManager.Idle,
+ 'focus': FocusManager.Focus
+ }
+ super(FocusManager, self).__init__(states, 'idle')
+
+ def cancel(self):
+ for handler in self.eventHandlers:
+ handler.cancel()
+
+
+class ZoomAndSelect(ItemsInteraction):
+ """Combine Zoom and ItemInteraction state machine.
+
+ :param plot: The Plot to which this interaction is attached
+ :param color: The color to use for the zoom area bounding box
+ """
+
+ def __init__(self, plot, color):
+ super(ZoomAndSelect, self).__init__(plot)
+ self._zoom = Zoom(plot, color)
+ self._doZoom = False
+
+ @property
+ def color(self):
+ """Color of the zoom area"""
+ return self._zoom.color
+
+ def click(self, x, y, btn):
+ """Handle mouse click
+
+ :param x: X position of the mouse in pixels
+ :param y: Y position of the mouse in pixels
+ :param btn: Pressed button id
+ :return: True if click is catched by an item, False otherwise
+ """
+ eventDict = self._handleClick(x, y, btn)
+
+ if eventDict is not None:
+ # Signal mouse clicked event
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+ clickedEventDict = prepareMouseSignal('mouseClicked', btn,
+ dataPos[0], dataPos[1],
+ x, y)
+ self.plot.notify(**clickedEventDict)
+
+ self.plot.notify(**eventDict)
+
+ else:
+ self._zoom.click(x, y, btn)
+
+ def beginDrag(self, x, y, btn):
+ """Handle start drag and switching between zoom and item drag.
+
+ :param x: X position in pixels
+ :param y: Y position in pixels
+ :param str btn: The mouse button for which a drag is starting.
+ """
+ self._doZoom = not super(ZoomAndSelect, self).beginDrag(x, y, btn)
+ if self._doZoom:
+ self._zoom.beginDrag(x, y, btn)
+
+ def drag(self, x, y, btn):
+ """Handle drag, eventually forwarding to zoom.
+
+ :param x: X position in pixels
+ :param y: Y position in pixels
+ :param str btn: The mouse button for which a drag is in progress.
+ """
+ if self._doZoom:
+ return self._zoom.drag(x, y, btn)
+ else:
+ return super(ZoomAndSelect, self).drag(x, y, btn)
+
+ def endDrag(self, startPos, endPos, btn):
+ """Handle end of drag, eventually forwarding to zoom.
+
+ :param startPos: (x, y) position at the beginning of the drag
+ :param endPos: (x, y) position at the end of the drag
+ :param str btn: The mouse button for which a drag is done.
+ """
+ if self._doZoom:
+ return self._zoom.endDrag(startPos, endPos, btn)
+ else:
+ return super(ZoomAndSelect, self).endDrag(startPos, endPos, btn)
+
+
+class PanAndSelect(ItemsInteraction):
+ """Combine Pan and ItemInteraction state machine.
+
+ :param plot: The Plot to which this interaction is attached
+ """
+
+ def __init__(self, plot):
+ super(PanAndSelect, self).__init__(plot)
+ self._pan = Pan(plot)
+ self._doPan = False
+
+ def click(self, x, y, btn):
+ """Handle mouse click
+
+ :param x: X position of the mouse in pixels
+ :param y: Y position of the mouse in pixels
+ :param btn: Pressed button id
+ :return: True if click is catched by an item, False otherwise
+ """
+ eventDict = self._handleClick(x, y, btn)
+
+ if eventDict is not None:
+ # Signal mouse clicked event
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+ clickedEventDict = prepareMouseSignal('mouseClicked', btn,
+ dataPos[0], dataPos[1],
+ x, y)
+ self.plot.notify(**clickedEventDict)
+
+ self.plot.notify(**eventDict)
+
+ else:
+ self._pan.click(x, y, btn)
+
+ def beginDrag(self, x, y, btn):
+ """Handle start drag and switching between zoom and item drag.
+
+ :param x: X position in pixels
+ :param y: Y position in pixels
+ :param str btn: The mouse button for which a drag is starting.
+ """
+ self._doPan = not super(PanAndSelect, self).beginDrag(x, y, btn)
+ if self._doPan:
+ self._pan.beginDrag(x, y, btn)
+
+ def drag(self, x, y, btn):
+ """Handle drag, eventually forwarding to zoom.
+
+ :param x: X position in pixels
+ :param y: Y position in pixels
+ :param str btn: The mouse button for which a drag is in progress.
+ """
+ if self._doPan:
+ return self._pan.drag(x, y, btn)
+ else:
+ return super(PanAndSelect, self).drag(x, y, btn)
+
+ def endDrag(self, startPos, endPos, btn):
+ """Handle end of drag, eventually forwarding to zoom.
+
+ :param startPos: (x, y) position at the beginning of the drag
+ :param endPos: (x, y) position at the end of the drag
+ :param str btn: The mouse button for which a drag is done.
+ """
+ if self._doPan:
+ return self._pan.endDrag(startPos, endPos, btn)
+ else:
+ return super(PanAndSelect, self).endDrag(startPos, endPos, btn)
+
+
+# Interaction mode control ####################################################
+
+# Mapping of draw modes: event handler
+_DRAW_MODES = {
+ 'polygon': SelectPolygon,
+ 'rectangle': SelectRectangle,
+ 'ellipse': SelectEllipse,
+ 'line': SelectLine,
+ 'vline': SelectVLine,
+ 'hline': SelectHLine,
+ 'polylines': SelectFreeLine,
+ 'pencil': DrawFreeHand,
+ }
+
+
+class DrawMode(FocusManager):
+ """Interactive mode for draw and select"""
+
+ def __init__(self, plot, shape, label, color, width):
+ eventHandlerClass = _DRAW_MODES[shape]
+ parameters = {
+ 'shape': shape,
+ 'label': label,
+ 'color': color,
+ 'width': width,
+ }
+ super().__init__((
+ Pan(plot, clickButtons=(), dragButtons=(MIDDLE_BTN,)),
+ eventHandlerClass(plot, parameters)))
+
+ def getDescription(self):
+ """Returns the dict describing this interactive mode"""
+ params = self.eventHandlers[1].parameters.copy()
+ params['mode'] = 'draw'
+ return params
+
+
+class DrawSelectMode(FocusManager):
+ """Interactive mode for draw and select"""
+
+ def __init__(self, plot, shape, label, color, width):
+ eventHandlerClass = _DRAW_MODES[shape]
+ self._pan = Pan(plot)
+ self._panStart = None
+ parameters = {
+ 'shape': shape,
+ 'label': label,
+ 'color': color,
+ 'width': width,
+ }
+ super().__init__((
+ ItemsInteractionForCombo(plot),
+ eventHandlerClass(plot, parameters)))
+
+ def handleEvent(self, eventName, *args, **kwargs):
+ # Hack to add pan interaction to select-draw
+ # See issue Refactor PlotWidget interaction #3292
+ if eventName == 'press' and args[2] == MIDDLE_BTN:
+ self._panStart = args[:2]
+ self._pan.beginDrag(*args)
+ return # Consume middle click events
+ elif eventName == 'release' and args[2] == MIDDLE_BTN:
+ self._panStart = None
+ self._pan.endDrag(self._panStart, args[:2], MIDDLE_BTN)
+ return # Consume middle click events
+ elif self._panStart is not None and eventName == 'move':
+ x, y = args[:2]
+ self._pan.drag(x, y, MIDDLE_BTN)
+
+ super().handleEvent(eventName, *args, **kwargs)
+
+ def getDescription(self):
+ """Returns the dict describing this interactive mode"""
+ params = self.eventHandlers[1].parameters.copy()
+ params['mode'] = 'select-draw'
+ return params
+
+
+class PlotInteraction(object):
+ """Proxy to currently use state machine for interaction.
+
+ This allows to switch interactive mode.
+
+ :param plot: The :class:`Plot` to apply interaction to
+ """
+
+ _DRAW_MODES = {
+ 'polygon': SelectPolygon,
+ 'rectangle': SelectRectangle,
+ 'ellipse': SelectEllipse,
+ 'line': SelectLine,
+ 'vline': SelectVLine,
+ 'hline': SelectHLine,
+ 'polylines': SelectFreeLine,
+ 'pencil': DrawFreeHand,
+ }
+
+ def __init__(self, plot):
+ self._plot = weakref.ref(plot) # Avoid cyclic-ref
+
+ self.zoomOnWheel = True
+ """True to enable zoom on wheel, False otherwise."""
+
+ # Default event handler
+ self._eventHandler = ItemsInteraction(plot)
+
+ def getInteractiveMode(self):
+ """Returns the current interactive mode as a dict.
+
+ The returned dict contains at least the key 'mode'.
+ Mode can be: 'draw', 'pan', 'select', 'select-draw', 'zoom'.
+ It can also contains extra keys (e.g., 'color') specific to a mode
+ as provided to :meth:`setInteractiveMode`.
+ """
+ if isinstance(self._eventHandler, ZoomAndSelect):
+ return {'mode': 'zoom', 'color': self._eventHandler.color}
+
+ elif isinstance(self._eventHandler, (DrawMode, DrawSelectMode)):
+ return self._eventHandler.getDescription()
+
+ elif isinstance(self._eventHandler, PanAndSelect):
+ return {'mode': 'pan'}
+
+ else:
+ return {'mode': 'select'}
+
+ def validate(self):
+ """Validate the current interaction if possible
+
+ If was designed to close the polygon interaction.
+ """
+ self._eventHandler.validate()
+
+ def setInteractiveMode(self, mode, color='black',
+ shape='polygon', label=None, width=None):
+ """Switch the interactive mode.
+
+ :param str mode: The name of the interactive mode.
+ In 'draw', 'pan', 'select', 'select-draw', 'zoom'.
+ :param color: Only for 'draw' and 'zoom' modes.
+ Color to use for drawing selection area. Default black.
+ If None, selection area is not drawn.
+ :type color: Color description: The name as a str or
+ a tuple of 4 floats or None.
+ :param str shape: Only for 'draw' mode. The kind of shape to draw.
+ In 'polygon', 'rectangle', 'line', 'vline', 'hline',
+ 'polylines'.
+ Default is 'polygon'.
+ :param str label: Only for 'draw' mode.
+ :param float width: Width of the pencil. Only for draw pencil mode.
+ """
+ assert mode in ('draw', 'pan', 'select', 'select-draw', 'zoom')
+
+ plot = self._plot()
+ assert plot is not None
+
+ if isinstance(color, numpy.ndarray) or color not in (None, 'video inverted'):
+ color = colors.rgba(color)
+
+ if mode in ('draw', 'select-draw'):
+ self._eventHandler.cancel()
+ handlerClass = DrawMode if mode == 'draw' else DrawSelectMode
+ self._eventHandler = handlerClass(plot, shape, label, color, width)
+
+ elif mode == 'pan':
+ # Ignores color, shape and label
+ self._eventHandler.cancel()
+ self._eventHandler = PanAndSelect(plot)
+
+ elif mode == 'zoom':
+ # Ignores shape and label
+ self._eventHandler.cancel()
+ self._eventHandler = ZoomAndSelect(plot, color)
+
+ else: # Default mode: interaction with plot objects
+ # Ignores color, shape and label
+ self._eventHandler.cancel()
+ self._eventHandler = ItemsInteraction(plot)
+
+ def handleEvent(self, event, *args, **kwargs):
+ """Forward event to current interactive mode state machine."""
+ if not self.zoomOnWheel and event == 'wheel':
+ return # Discard wheel events
+ self._eventHandler.handleEvent(event, *args, **kwargs)
diff --git a/src/silx/gui/plot/PlotToolButtons.py b/src/silx/gui/plot/PlotToolButtons.py
new file mode 100644
index 0000000..3970896
--- /dev/null
+++ b/src/silx/gui/plot/PlotToolButtons.py
@@ -0,0 +1,592 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a set of QToolButton to use with
+:class:`~silx.gui.plot.PlotWidget`.
+
+The following QToolButton are available:
+
+- :class:`.AspectToolButton`
+- :class:`.YAxisOriginToolButton`
+- :class:`.ProfileToolButton`
+- :class:`.SymbolToolButton`
+
+"""
+
+__authors__ = ["V. Valls", "H. Payno"]
+__license__ = "MIT"
+__date__ = "27/06/2017"
+
+
+import functools
+import logging
+import weakref
+
+from .. import icons
+from .. import qt
+from ... import config
+
+from .items import SymbolMixIn, Scatter
+
+
+_logger = logging.getLogger(__name__)
+
+
+class PlotToolButton(qt.QToolButton):
+ """A QToolButton connected to a :class:`~silx.gui.plot.PlotWidget`.
+ """
+
+ def __init__(self, parent=None, plot=None):
+ super(PlotToolButton, self).__init__(parent)
+ self._plotRef = None
+ if plot is not None:
+ self.setPlot(plot)
+
+ def plot(self):
+ """
+ Returns the plot connected to the widget.
+ """
+ return None if self._plotRef is None else self._plotRef()
+
+ def setPlot(self, plot):
+ """
+ Set the plot connected to the widget
+
+ :param plot: :class:`.PlotWidget` instance on which to operate.
+ """
+ previousPlot = self.plot()
+
+ if previousPlot is plot:
+ return
+ if previousPlot is not None:
+ self._disconnectPlot(previousPlot)
+
+ if plot is None:
+ self._plotRef = None
+ else:
+ self._plotRef = weakref.ref(plot)
+ self._connectPlot(plot)
+
+ def _connectPlot(self, plot):
+ """
+ Called when the plot is connected to the widget
+
+ :param plot: :class:`.PlotWidget` instance
+ """
+ pass
+
+ def _disconnectPlot(self, plot):
+ """
+ Called when the plot is disconnected from the widget
+
+ :param plot: :class:`.PlotWidget` instance
+ """
+ pass
+
+
+class AspectToolButton(PlotToolButton):
+ """Tool button to switch keep aspect ratio of a plot"""
+
+ STATE = None
+ """Lazy loaded states used to feed AspectToolButton"""
+
+ def __init__(self, parent=None, plot=None):
+ if self.STATE is None:
+ self.STATE = {}
+ # dont keep ratio
+ self.STATE[False, "icon"] = icons.getQIcon('shape-ellipse-solid')
+ self.STATE[False, "state"] = "Aspect ratio is not kept"
+ self.STATE[False, "action"] = "Do no keep data aspect ratio"
+ # keep ratio
+ self.STATE[True, "icon"] = icons.getQIcon('shape-circle-solid')
+ self.STATE[True, "state"] = "Aspect ratio is kept"
+ self.STATE[True, "action"] = "Keep data aspect ratio"
+
+ super(AspectToolButton, self).__init__(parent=parent, plot=plot)
+
+ keepAction = self._createAction(True)
+ keepAction.triggered.connect(self.keepDataAspectRatio)
+ keepAction.setIconVisibleInMenu(True)
+
+ dontKeepAction = self._createAction(False)
+ dontKeepAction.triggered.connect(self.dontKeepDataAspectRatio)
+ dontKeepAction.setIconVisibleInMenu(True)
+
+ menu = qt.QMenu(self)
+ menu.addAction(keepAction)
+ menu.addAction(dontKeepAction)
+ self.setMenu(menu)
+ self.setPopupMode(qt.QToolButton.InstantPopup)
+
+ def _createAction(self, keepAspectRatio):
+ icon = self.STATE[keepAspectRatio, "icon"]
+ text = self.STATE[keepAspectRatio, "action"]
+ return qt.QAction(icon, text, self)
+
+ def _connectPlot(self, plot):
+ plot.sigSetKeepDataAspectRatio.connect(self._keepDataAspectRatioChanged)
+ self._keepDataAspectRatioChanged(plot.isKeepDataAspectRatio())
+
+ def _disconnectPlot(self, plot):
+ plot.sigSetKeepDataAspectRatio.disconnect(self._keepDataAspectRatioChanged)
+
+ def keepDataAspectRatio(self):
+ """Configure the plot to keep the aspect ratio"""
+ plot = self.plot()
+ if plot is not None:
+ # This will trigger _keepDataAspectRatioChanged
+ plot.setKeepDataAspectRatio(True)
+
+ def dontKeepDataAspectRatio(self):
+ """Configure the plot to not keep the aspect ratio"""
+ plot = self.plot()
+ if plot is not None:
+ # This will trigger _keepDataAspectRatioChanged
+ plot.setKeepDataAspectRatio(False)
+
+ def _keepDataAspectRatioChanged(self, aspectRatio):
+ """Handle Plot set keep aspect ratio signal"""
+ icon, toolTip = self.STATE[aspectRatio, "icon"], self.STATE[aspectRatio, "state"]
+ self.setIcon(icon)
+ self.setToolTip(toolTip)
+
+
+class YAxisOriginToolButton(PlotToolButton):
+ """Tool button to switch the Y axis orientation of a plot."""
+
+ STATE = None
+ """Lazy loaded states used to feed YAxisOriginToolButton"""
+
+ def __init__(self, parent=None, plot=None):
+ if self.STATE is None:
+ self.STATE = {}
+ # is down
+ self.STATE[False, "icon"] = icons.getQIcon('plot-ydown')
+ self.STATE[False, "state"] = "Y-axis is oriented downward"
+ self.STATE[False, "action"] = "Orient Y-axis downward"
+ # keep ration
+ self.STATE[True, "icon"] = icons.getQIcon('plot-yup')
+ self.STATE[True, "state"] = "Y-axis is oriented upward"
+ self.STATE[True, "action"] = "Orient Y-axis upward"
+
+ super(YAxisOriginToolButton, self).__init__(parent=parent, plot=plot)
+
+ upwardAction = self._createAction(True)
+ upwardAction.triggered.connect(self.setYAxisUpward)
+ upwardAction.setIconVisibleInMenu(True)
+
+ downwardAction = self._createAction(False)
+ downwardAction.triggered.connect(self.setYAxisDownward)
+ downwardAction.setIconVisibleInMenu(True)
+
+ menu = qt.QMenu(self)
+ menu.addAction(upwardAction)
+ menu.addAction(downwardAction)
+ self.setMenu(menu)
+ self.setPopupMode(qt.QToolButton.InstantPopup)
+
+ def _createAction(self, isUpward):
+ icon = self.STATE[isUpward, "icon"]
+ text = self.STATE[isUpward, "action"]
+ return qt.QAction(icon, text, self)
+
+ def _connectPlot(self, plot):
+ yAxis = plot.getYAxis()
+ yAxis.sigInvertedChanged.connect(self._yAxisInvertedChanged)
+ self._yAxisInvertedChanged(yAxis.isInverted())
+
+ def _disconnectPlot(self, plot):
+ plot.getYAxis().sigInvertedChanged.disconnect(self._yAxisInvertedChanged)
+
+ def setYAxisUpward(self):
+ """Configure the plot to use y-axis upward"""
+ plot = self.plot()
+ if plot is not None:
+ # This will trigger _yAxisInvertedChanged
+ plot.getYAxis().setInverted(False)
+
+ def setYAxisDownward(self):
+ """Configure the plot to use y-axis downward"""
+ plot = self.plot()
+ if plot is not None:
+ # This will trigger _yAxisInvertedChanged
+ plot.getYAxis().setInverted(True)
+
+ def _yAxisInvertedChanged(self, inverted):
+ """Handle Plot set y axis inverted signal"""
+ isUpward = not inverted
+ icon, toolTip = self.STATE[isUpward, "icon"], self.STATE[isUpward, "state"]
+ self.setIcon(icon)
+ self.setToolTip(toolTip)
+
+
+class ProfileOptionToolButton(PlotToolButton):
+ """Button to define option on the profile"""
+ sigMethodChanged = qt.Signal(str)
+
+ def __init__(self, parent=None, plot=None):
+ PlotToolButton.__init__(self, parent=parent, plot=plot)
+
+ self.STATE = {}
+ # is down
+ self.STATE['sum', "icon"] = icons.getQIcon('math-sigma')
+ self.STATE['sum', "state"] = "Compute profile sum"
+ self.STATE['sum', "action"] = "Compute profile sum"
+ # keep ration
+ self.STATE['mean', "icon"] = icons.getQIcon('math-mean')
+ self.STATE['mean', "state"] = "Compute profile mean"
+ self.STATE['mean', "action"] = "Compute profile mean"
+
+ self.sumAction = self._createAction('sum')
+ self.sumAction.triggered.connect(self.setSum)
+ self.sumAction.setIconVisibleInMenu(True)
+ self.sumAction.setCheckable(True)
+ self.sumAction.setChecked(True)
+
+ self.meanAction = self._createAction('mean')
+ self.meanAction.triggered.connect(self.setMean)
+ self.meanAction.setIconVisibleInMenu(True)
+ self.meanAction.setCheckable(True)
+
+ menu = qt.QMenu(self)
+ menu.addAction(self.sumAction)
+ menu.addAction(self.meanAction)
+ self.setMenu(menu)
+ self.setPopupMode(qt.QToolButton.InstantPopup)
+ self._method = 'mean'
+ self._update()
+
+ def _createAction(self, method):
+ icon = self.STATE[method, "icon"]
+ text = self.STATE[method, "action"]
+ return qt.QAction(icon, text, self)
+
+ def setSum(self):
+ self.setMethod('sum')
+
+ def _update(self):
+ icon = self.STATE[self._method, "icon"]
+ toolTip = self.STATE[self._method, "state"]
+ self.setIcon(icon)
+ self.setToolTip(toolTip)
+ self.sumAction.setChecked(self._method == "sum")
+ self.meanAction.setChecked(self._method == "mean")
+
+ def setMean(self):
+ self.setMethod('mean')
+
+ def setMethod(self, method):
+ """Set the method to use.
+
+ :param str method: Either 'sum' or 'mean'
+ """
+ if method != self._method:
+ if method in ('sum', 'mean'):
+ self._method = method
+ self.sigMethodChanged.emit(self._method)
+ self._update()
+ else:
+ _logger.warning(
+ "Unsupported method '%s'. Setting ignored.", method)
+
+ def getMethod(self):
+ """Returns the current method in use (See :meth:`setMethod`).
+
+ :rtype: str
+ """
+ return self._method
+
+
+class ProfileToolButton(PlotToolButton):
+ """Button used in Profile3DToolbar to switch between 2D profile
+ and 1D profile."""
+ STATE = None
+ """Lazy loaded states used to feed ProfileToolButton"""
+
+ sigDimensionChanged = qt.Signal(int)
+
+ def __init__(self, parent=None, plot=None):
+ if self.STATE is None:
+ self.STATE = {
+ (1, "icon"): icons.getQIcon('profile1D'),
+ (1, "state"): "1D profile is computed on visible image",
+ (1, "action"): "1D profile on visible image",
+ (2, "icon"): icons.getQIcon('profile2D'),
+ (2, "state"): "2D profile is computed, one 1D profile for each image in the stack",
+ (2, "action"): "2D profile on image stack"}
+ # Compute 1D profile
+ # Compute 2D profile
+
+ super(ProfileToolButton, self).__init__(parent=parent, plot=plot)
+
+ self._dimension = 1
+
+ profile1DAction = self._createAction(1)
+ profile1DAction.triggered.connect(self.computeProfileIn1D)
+ profile1DAction.setIconVisibleInMenu(True)
+ profile1DAction.setCheckable(True)
+ profile1DAction.setChecked(True)
+ self._profile1DAction = profile1DAction
+
+ profile2DAction = self._createAction(2)
+ profile2DAction.triggered.connect(self.computeProfileIn2D)
+ profile2DAction.setIconVisibleInMenu(True)
+ profile2DAction.setCheckable(True)
+ self._profile2DAction = profile2DAction
+
+ menu = qt.QMenu(self)
+ menu.addAction(profile1DAction)
+ menu.addAction(profile2DAction)
+ self.setMenu(menu)
+ self.setPopupMode(qt.QToolButton.InstantPopup)
+ menu.setTitle('Select profile dimension')
+ self.computeProfileIn1D()
+
+ def _createAction(self, profileDimension):
+ icon = self.STATE[profileDimension, "icon"]
+ text = self.STATE[profileDimension, "action"]
+ return qt.QAction(icon, text, self)
+
+ def _profileDimensionChanged(self, profileDimension):
+ """Update icon in toolbar, emit number of dimensions for profile"""
+ self.setIcon(self.STATE[profileDimension, "icon"])
+ self.setToolTip(self.STATE[profileDimension, "state"])
+ self._dimension = profileDimension
+ self.sigDimensionChanged.emit(profileDimension)
+ self._profile1DAction.setChecked(profileDimension == 1)
+ self._profile2DAction.setChecked(profileDimension == 2)
+
+ def computeProfileIn1D(self):
+ self._profileDimensionChanged(1)
+
+ def computeProfileIn2D(self):
+ self._profileDimensionChanged(2)
+
+ def setDimension(self, dimension):
+ """Set the selected dimension"""
+ assert dimension in [1, 2]
+ if self._dimension == dimension:
+ return
+ if dimension == 1:
+ self.computeProfileIn1D()
+ elif dimension == 2:
+ self.computeProfileIn2D()
+ else:
+ _logger.warning("Unsupported dimension '%s'. Setting ignored.", dimension)
+
+ def getDimension(self):
+ """Get the selected dimension.
+
+ :rtype: int (1 or 2)
+ """
+ return self._dimension
+
+
+class _SymbolToolButtonBase(PlotToolButton):
+ """Base class for PlotToolButton setting marker and size.
+
+ :param parent: See QWidget
+ :param plot: The `~silx.gui.plot.PlotWidget` to control
+ """
+
+ def __init__(self, parent=None, plot=None):
+ super(_SymbolToolButtonBase, self).__init__(parent=parent, plot=plot)
+
+ def _addSizeSliderToMenu(self, menu):
+ """Add a slider to set size to the given menu
+
+ :param QMenu menu:
+ """
+ slider = qt.QSlider(qt.Qt.Horizontal)
+ slider.setRange(1, 20)
+ slider.setValue(int(config.DEFAULT_PLOT_SYMBOL_SIZE))
+ slider.setTracking(False)
+ slider.valueChanged.connect(self._sizeChanged)
+ widgetAction = qt.QWidgetAction(menu)
+ widgetAction.setDefaultWidget(slider)
+ menu.addAction(widgetAction)
+
+ def _addSymbolsToMenu(self, menu):
+ """Add symbols to the given menu
+
+ :param QMenu menu:
+ """
+ for marker, name in zip(SymbolMixIn.getSupportedSymbols(),
+ SymbolMixIn.getSupportedSymbolNames()):
+ action = qt.QAction(name, menu)
+ action.setCheckable(False)
+ action.triggered.connect(
+ functools.partial(self._markerChanged, marker))
+ menu.addAction(action)
+
+ def _sizeChanged(self, value):
+ """Manage slider value changed
+
+ :param int value: Marker size
+ """
+ plot = self.plot()
+ if plot is None:
+ return
+
+ for item in plot.getItems():
+ if isinstance(item, SymbolMixIn):
+ item.setSymbolSize(value)
+
+ def _markerChanged(self, marker):
+ """Manage change of marker.
+
+ :param str marker: Letter describing the marker
+ """
+ plot = self.plot()
+ if plot is None:
+ return
+
+ for item in plot.getItems():
+ if isinstance(item, SymbolMixIn):
+ item.setSymbol(marker)
+
+
+class SymbolToolButton(_SymbolToolButtonBase):
+ """A tool button with a drop-down menu to control symbol size and marker.
+
+ :param parent: See QWidget
+ :param plot: The `~silx.gui.plot.PlotWidget` to control
+ """
+
+ def __init__(self, parent=None, plot=None):
+ super(SymbolToolButton, self).__init__(parent=parent, plot=plot)
+
+ self.setToolTip('Set symbol size and marker')
+ self.setIcon(icons.getQIcon('plot-symbols'))
+
+ menu = qt.QMenu(self)
+ self._addSizeSliderToMenu(menu)
+ menu.addSeparator()
+ self._addSymbolsToMenu(menu)
+
+ self.setMenu(menu)
+ self.setPopupMode(qt.QToolButton.InstantPopup)
+
+
+class ScatterVisualizationToolButton(_SymbolToolButtonBase):
+ """QToolButton to select the visualization mode of scatter plot
+
+ :param parent: See QWidget
+ :param plot: The `~silx.gui.plot.PlotWidget` to control
+ """
+
+ def __init__(self, parent=None, plot=None):
+ super(ScatterVisualizationToolButton, self).__init__(
+ parent=parent, plot=plot)
+
+ self.setToolTip(
+ 'Set scatter visualization mode, symbol marker and size')
+ self.setIcon(icons.getQIcon('eye'))
+
+ menu = qt.QMenu(self)
+
+ # Add visualization modes
+
+ for mode in Scatter.supportedVisualizations():
+ if mode is not Scatter.Visualization.BINNED_STATISTIC:
+ name = mode.value.capitalize()
+ action = qt.QAction(name, menu)
+ action.setCheckable(False)
+ action.triggered.connect(
+ functools.partial(self._visualizationChanged, mode, None))
+ menu.addAction(action)
+
+ if Scatter.Visualization.BINNED_STATISTIC in Scatter.supportedVisualizations():
+ reductions = Scatter.supportedVisualizationParameterValues(
+ Scatter.VisualizationParameter.BINNED_STATISTIC_FUNCTION)
+ if reductions:
+ submenu = menu.addMenu('Binned Statistic')
+ for reduction in reductions:
+ name = reduction.capitalize()
+ action = qt.QAction(name, menu)
+ action.setCheckable(False)
+ action.triggered.connect(functools.partial(
+ self._visualizationChanged,
+ Scatter.Visualization.BINNED_STATISTIC,
+ {Scatter.VisualizationParameter.BINNED_STATISTIC_FUNCTION: reduction}))
+ submenu.addAction(action)
+
+ submenu.addSeparator()
+ binsmenu = submenu.addMenu('N Bins')
+
+ slider = qt.QSlider(qt.Qt.Horizontal)
+ slider.setRange(10, 1000)
+ slider.setValue(100)
+ slider.setTracking(False)
+ slider.valueChanged.connect(self._binningChanged)
+ widgetAction = qt.QWidgetAction(binsmenu)
+ widgetAction.setDefaultWidget(slider)
+ binsmenu.addAction(widgetAction)
+
+ menu.addSeparator()
+
+ submenu = menu.addMenu(icons.getQIcon('plot-symbols'), "Symbol")
+ self._addSymbolsToMenu(submenu)
+
+ submenu = menu.addMenu(icons.getQIcon('plot-symbols'), "Symbol Size")
+ self._addSizeSliderToMenu(submenu)
+
+ self.setMenu(menu)
+ self.setPopupMode(qt.QToolButton.InstantPopup)
+
+ def _visualizationChanged(self, mode, parameters=None):
+ """Handle change of visualization mode.
+
+ :param ScatterVisualizationMixIn.Visualization mode:
+ The visualization mode to use for scatter
+ :param Union[dict,None] parameters:
+ Dict of VisualizationParameter: parameter_value to set
+ with the visualization.
+ """
+ plot = self.plot()
+ if plot is None:
+ return
+
+ for item in plot.getItems():
+ if isinstance(item, Scatter):
+ if parameters:
+ for parameter, value in parameters.items():
+ item.setVisualizationParameter(parameter, value)
+ item.setVisualization(mode)
+
+ def _binningChanged(self, value):
+ """Handle change of binning.
+
+ :param int value: The number of bin on each dimension.
+ """
+ plot = self.plot()
+ if plot is None:
+ return
+
+ for item in plot.getItems():
+ if isinstance(item, Scatter):
+ item.setVisualizationParameter(
+ Scatter.VisualizationParameter.BINNED_STATISTIC_SHAPE,
+ (value, value))
+ item.setVisualization(Scatter.Visualization.BINNED_STATISTIC)
diff --git a/src/silx/gui/plot/PlotTools.py b/src/silx/gui/plot/PlotTools.py
new file mode 100644
index 0000000..5929473
--- /dev/null
+++ b/src/silx/gui/plot/PlotTools.py
@@ -0,0 +1,43 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Set of widgets to associate with a :class:'PlotWidget'.
+"""
+
+from __future__ import absolute_import
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "01/03/2018"
+
+
+from ...utils.deprecation import deprecated_warning
+
+deprecated_warning(type_='module',
+ name=__file__,
+ reason='Plot tools refactoring',
+ replacement='silx.gui.plot.tools',
+ since_version='0.8')
+
+from .tools import PositionInfo, LimitsToolBar # noqa
diff --git a/src/silx/gui/plot/PlotWidget.py b/src/silx/gui/plot/PlotWidget.py
new file mode 100755
index 0000000..6cb5ef5
--- /dev/null
+++ b/src/silx/gui/plot/PlotWidget.py
@@ -0,0 +1,3628 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+# ###########################################################################*/
+"""Qt widget providing plot API for 1D and 2D data.
+
+The :class:`PlotWidget` implements the plot API initially provided in PyMca.
+"""
+
+from __future__ import division
+
+
+__authors__ = ["V.A. Sole", "T. Vincent"]
+__license__ = "MIT"
+__date__ = "21/12/2018"
+
+import logging
+
+_logger = logging.getLogger(__name__)
+
+
+from collections import OrderedDict, namedtuple
+from contextlib import contextmanager
+import datetime as dt
+import itertools
+import typing
+import warnings
+
+import numpy
+
+import silx
+from silx.utils.weakref import WeakMethodProxy
+from silx.utils.property import classproperty
+from silx.utils.deprecation import deprecated, deprecated_warning
+try:
+ # Import matplotlib now to init matplotlib our way
+ import silx.gui.utils.matplotlib # noqa
+except ImportError:
+ _logger.debug("matplotlib not available")
+
+from ..colors import Colormap
+from .. import colors
+from . import PlotInteraction
+from . import PlotEvents
+from .LimitsHistory import LimitsHistory
+from . import _utils
+
+from . import items
+from .items.curve import CurveStyle
+from .items.axis import TickMode # noqa
+
+from .. import qt
+from ._utils.panzoom import ViewConstraints
+from ...gui.plot._utils.dtime_ticklayout import timestamp
+
+
+
+_COLORDICT = colors.COLORDICT
+_COLORLIST = silx.config.DEFAULT_PLOT_CURVE_COLORS
+
+"""
+Object returned when requesting the data range.
+"""
+_PlotDataRange = namedtuple('PlotDataRange',
+ ['x', 'y', 'yright'])
+
+
+class _PlotWidgetSelection(qt.QObject):
+ """Object managing a :class:`PlotWidget` selection.
+
+ It is a wrapper over :class:`PlotWidget`'s active items API.
+
+ :param PlotWidget parent:
+ """
+
+ sigCurrentItemChanged = qt.Signal(object, object)
+ """This signal is emitted whenever the current item changes.
+
+ It provides the current and previous items.
+ """
+
+ sigSelectedItemsChanged = qt.Signal()
+ """Signal emitted whenever the list of selected items changes."""
+
+ def __init__(self, parent):
+ assert isinstance(parent, PlotWidget)
+ super(_PlotWidgetSelection, self).__init__(parent=parent)
+
+ # Init history
+ self.__history = [ # Store active items from most recent to oldest
+ item for item in (parent.getActiveCurve(),
+ parent.getActiveImage(),
+ parent.getActiveScatter())
+ if item is not None]
+
+ self.__current = self.__mostRecentActiveItem()
+
+ parent.sigActiveImageChanged.connect(self._activeImageChanged)
+ parent.sigActiveCurveChanged.connect(self._activeCurveChanged)
+ parent.sigActiveScatterChanged.connect(self._activeScatterChanged)
+
+ def __mostRecentActiveItem(self) -> typing.Optional[items.Item]:
+ """Returns most recent active item."""
+ return self.__history[0] if len(self.__history) >= 1 else None
+
+ def getSelectedItems(self) -> typing.Tuple[items.Item]:
+ """Returns the list of currently selected items in the :class:`PlotWidget`.
+
+ The list is given from most recently current item to oldest one."""
+ plot = self.parent()
+ if plot is None:
+ return ()
+
+ active = tuple(self.__history)
+
+ current = self.getCurrentItem()
+ if current is not None and current not in active:
+ # Current might not be an active item, if so add it
+ active = (current,) + active
+
+ return active
+
+ def getCurrentItem(self) -> typing.Optional[items.Item]:
+ """Returns the current item in the :class:`PlotWidget` or None. """
+ return self.__current
+
+ def setCurrentItem(self, item: typing.Optional[items.Item]):
+ """Set the current item in the :class:`PlotWidget`.
+
+ :param item:
+ The new item to select or None to clear the selection.
+ :raise ValueError: If the item is not the :class:`PlotWidget`
+ """
+ previous = self.getCurrentItem()
+ if previous is item:
+ return
+
+ previousSelected = self.getSelectedItems()
+
+ if item is None:
+ self.__current = None
+
+ # Reset all PlotWidget active items
+ plot = self.parent()
+ if plot is not None:
+ for kind in PlotWidget._ACTIVE_ITEM_KINDS:
+ if plot._getActiveItem(kind) is not None:
+ plot._setActiveItem(kind, None)
+
+ elif isinstance(item, items.Item):
+ plot = self.parent()
+ if plot is None or item.getPlot() is not plot:
+ raise ValueError(
+ "Item is not in the PlotWidget: %s" % str(item))
+ self.__current = item
+
+ kind = plot._itemKind(item)
+
+ # Clean-up history to be safe
+ self.__history = [item for item in self.__history
+ if PlotWidget._itemKind(item) != kind]
+
+ # Sync active item if needed
+ if (kind in plot._ACTIVE_ITEM_KINDS and
+ item is not plot._getActiveItem(kind)):
+ plot._setActiveItem(kind, item.getName())
+ else:
+ raise ValueError("Not an Item: %s" % str(item))
+
+ self.sigCurrentItemChanged.emit(previous, item)
+
+ if previousSelected != self.getSelectedItems():
+ self.sigSelectedItemsChanged.emit()
+
+ def __activeItemChanged(self,
+ kind: str,
+ previous: typing.Optional[str],
+ legend: typing.Optional[str]):
+ """Set current item from kind and legend"""
+ if previous == legend:
+ return # No-op for update of item
+
+ plot = self.parent()
+ if plot is None:
+ return
+
+ previousSelected = self.getSelectedItems()
+
+ # Remove items of this kind from the history
+ self.__history = [item for item in self.__history
+ if PlotWidget._itemKind(item) != kind]
+
+ # Retrieve current item
+ if legend is None: # Use most recent active item
+ currentItem = self.__mostRecentActiveItem()
+ else:
+ currentItem = plot._getItem(kind=kind, legend=legend)
+ if currentItem is None: # Fallback in case something went wrong
+ currentItem = self.__mostRecentActiveItem()
+
+ # Update history
+ if currentItem is not None:
+ while currentItem in self.__history:
+ self.__history.remove(currentItem)
+ self.__history.insert(0, currentItem)
+
+ if currentItem != self.__current:
+ previousItem = self.__current
+ self.__current = currentItem
+ self.sigCurrentItemChanged.emit(previousItem, currentItem)
+
+ if previousSelected != self.getSelectedItems():
+ self.sigSelectedItemsChanged.emit()
+
+ def _activeImageChanged(self, previous, current):
+ """Handle active image change"""
+ self.__activeItemChanged('image', previous, current)
+
+ def _activeCurveChanged(self, previous, current):
+ """Handle active curve change"""
+ self.__activeItemChanged('curve', previous, current)
+
+ def _activeScatterChanged(self, previous, current):
+ """Handle active scatter change"""
+ self.__activeItemChanged('scatter', previous, current)
+
+
+class PlotWidget(qt.QMainWindow):
+ """Qt Widget providing a 1D/2D plot.
+
+ This widget is a QMainWindow.
+ This class implements the plot API initially provided in PyMca.
+
+ Supported backends:
+
+ - 'matplotlib' and 'mpl': Matplotlib with Qt.
+ - 'opengl' and 'gl': OpenGL backend (requires PyOpenGL and OpenGL >= 2.1)
+ - 'none': No backend, to run headless for testing purpose.
+
+ :param parent: The parent of this widget or None (default).
+ :param backend: The backend to use, in:
+ 'matplotlib' (default), 'mpl', 'opengl', 'gl', 'none'
+ or a :class:`BackendBase.BackendBase` class
+ :type backend: str or :class:`BackendBase.BackendBase`
+ """
+
+ # TODO: Can be removed for silx 0.10
+ @classproperty
+ @deprecated(replacement="silx.config.DEFAULT_PLOT_BACKEND", since_version="0.8", skip_backtrace_count=2)
+ def DEFAULT_BACKEND(self):
+ """Class attribute setting the default backend for all instances."""
+ return silx.config.DEFAULT_PLOT_BACKEND
+
+ colorList = _COLORLIST
+ colorDict = _COLORDICT
+
+ sigPlotSignal = qt.Signal(object)
+ """Signal for all events of the plot.
+
+ The signal information is provided as a dict.
+ See the :ref:`plot signal documentation page <plot_signal>` for
+ information about the content of the dict
+ """
+
+ sigSetKeepDataAspectRatio = qt.Signal(bool)
+ """Signal emitted when plot keep aspect ratio has changed"""
+
+ sigSetGraphGrid = qt.Signal(str)
+ """Signal emitted when plot grid has changed"""
+
+ sigSetGraphCursor = qt.Signal(bool)
+ """Signal emitted when plot crosshair cursor has changed"""
+
+ sigSetPanWithArrowKeys = qt.Signal(bool)
+ """Signal emitted when pan with arrow keys has changed"""
+
+ _sigAxesVisibilityChanged = qt.Signal(bool)
+ """Signal emitted when the axes visibility changed"""
+
+ sigContentChanged = qt.Signal(str, str, str)
+ """Signal emitted when the content of the plot is changed.
+
+ It provides the following information:
+
+ - action: The change of the plot: 'add' or 'remove'
+ - kind: The kind of primitive changed:
+ 'curve', 'image', 'scatter', 'histogram', 'item' or 'marker'
+ - legend: The legend of the primitive changed.
+ """
+
+ sigActiveCurveChanged = qt.Signal(object, object)
+ """Signal emitted when the active curve has changed.
+
+ It provides the following information:
+
+ - previous: The legend of the previous active curve or None
+ - legend: The legend of the new active curve or None if no curve is active
+ """
+
+ sigActiveImageChanged = qt.Signal(object, object)
+ """Signal emitted when the active image has changed.
+
+ It provides the following information:
+
+ - previous: The legend of the previous active image or None
+ - legend: The legend of the new active image or None if no image is active
+ """
+
+ sigActiveScatterChanged = qt.Signal(object, object)
+ """Signal emitted when the active Scatter has changed.
+
+ It provides the following information:
+
+ - previous: The legend of the previous active scatter or None
+ - legend: The legend of the new active image or None if no image is active
+ """
+
+ sigInteractiveModeChanged = qt.Signal(object)
+ """Signal emitted when the interactive mode has changed
+
+ It provides the source as passed to :meth:`setInteractiveMode`.
+ """
+
+ sigItemAdded = qt.Signal(items.Item)
+ """Signal emitted when an item was just added to the plot
+
+ It provides the added item.
+ """
+
+ sigItemAboutToBeRemoved = qt.Signal(items.Item)
+ """Signal emitted right before an item is removed from the plot.
+
+ It provides the item that will be removed.
+ """
+
+ sigItemRemoved = qt.Signal(items.Item)
+ """Signal emitted right after an item was removed from the plot.
+
+ It provides the item that was removed.
+ """
+
+ sigVisibilityChanged = qt.Signal(bool)
+ """Signal emitted when the widget becomes visible (or invisible).
+ This happens when the widget is hidden or shown.
+
+ It provides the visible state.
+ """
+
+ _sigDefaultContextMenu = qt.Signal(qt.QMenu)
+ """Signal emitted when the default context menu of the plot is feed.
+
+ It provides the menu which will be displayed.
+ """
+
+ def __init__(self, parent=None, backend=None):
+ self._autoreplot = False
+ self._dirty = False
+ self._cursorInPlot = False
+ self.__muteActiveItemChanged = False
+
+ self._panWithArrowKeys = True
+ self._viewConstrains = None
+
+ super(PlotWidget, self).__init__(parent)
+ if parent is not None:
+ # behave as a widget
+ self.setWindowFlags(qt.Qt.Widget)
+ else:
+ self.setWindowTitle('PlotWidget')
+
+ # Init the backend
+ self._backend = self.__getBackendClass(backend)(self, self)
+
+ self.setCallback() # set _callback
+
+ # Items handling
+ self._content = OrderedDict()
+ self._contentToUpdate = [] # Used as an OrderedSet
+
+ self._dataRange = None
+
+ # line types
+ self._styleList = ['-', '--', '-.', ':']
+ self._colorIndex = 0
+ self._styleIndex = 0
+
+ self._activeCurveSelectionMode = "atmostone"
+ self._activeCurveStyle = CurveStyle(color='#000000')
+ self._activeLegend = {'curve': None, 'image': None,
+ 'scatter': None}
+
+ # plot colors (updated later to sync backend)
+ self._foregroundColor = 0., 0., 0., 1.
+ self._gridColor = .7, .7, .7, 1.
+ self._backgroundColor = 1., 1., 1., 1.
+ self._dataBackgroundColor = None
+
+ # default properties
+ self._cursorConfiguration = None
+
+ self._xAxis = items.XAxis(self)
+ self._yAxis = items.YAxis(self)
+ self._yRightAxis = items.YRightAxis(self, self._yAxis)
+
+ self._grid = None
+ self._graphTitle = ''
+ self.__graphCursorShape = 'default'
+
+ # Set axes margins
+ self.__axesDisplayed = True
+ self.__axesMargins = 0., 0., 0., 0.
+ self.setAxesMargins(.15, .1, .1, .15)
+
+ self.setGraphTitle()
+ self.setGraphXLabel()
+ self.setGraphYLabel()
+ self.setGraphYLabel('', axis='right')
+
+ self.setDefaultColormap() # Init default colormap
+
+ self.setDefaultPlotPoints(silx.config.DEFAULT_PLOT_CURVE_SYMBOL_MODE)
+ self.setDefaultPlotLines(True)
+
+ self._limitsHistory = LimitsHistory(self)
+
+ self._eventHandler = PlotInteraction.PlotInteraction(self)
+ self._eventHandler.setInteractiveMode('zoom', color=(0., 0., 0., 1.))
+ self._previousDefaultMode = "zoom", True
+
+ self._pressedButtons = [] # Currently pressed mouse buttons
+
+ self._defaultDataMargins = (0., 0., 0., 0.)
+
+ # Only activate autoreplot at the end
+ # This avoids errors when loaded in Qt designer
+ self._dirty = False
+ self._autoreplot = True
+
+ widget = self.getWidgetHandle()
+ if widget is not None:
+ self.setCentralWidget(widget)
+ else:
+ _logger.info("PlotWidget backend does not support widget")
+
+ self.setFocusPolicy(qt.Qt.StrongFocus)
+ self.setFocus(qt.Qt.OtherFocusReason)
+
+ # Set default limits
+ self.setGraphXLimits(0., 100.)
+ self.setGraphYLimits(0., 100., axis='right')
+ self.setGraphYLimits(0., 100., axis='left')
+
+ # Sync backend colors with default ones
+ self._foregroundColorsUpdated()
+ self._backgroundColorsUpdated()
+
+ # selection handling
+ self.__selection = None
+
+ def __getBackendClass(self, backend):
+ """Returns backend class corresponding to backend.
+
+ If multiple backends are provided, the first available one is used.
+
+ :param Union[str,BackendBase,List[Union[str,BackendBase]]] backend:
+ The name of the backend or its class or an iterable of those.
+ :rtype: BackendBase
+ :raise ValueError: In case the backend is not supported
+ :raise RuntimeError: If a backend is not available
+ """
+ if backend is None:
+ backend = silx.config.DEFAULT_PLOT_BACKEND
+
+ if callable(backend):
+ return backend
+
+ elif isinstance(backend, str):
+ backend = backend.lower()
+ if backend in ('matplotlib', 'mpl'):
+ try:
+ from .backends.BackendMatplotlib import \
+ BackendMatplotlibQt as backendClass
+ except ImportError:
+ _logger.debug("Backtrace", exc_info=True)
+ raise RuntimeError("matplotlib backend is not available")
+
+ elif backend in ('gl', 'opengl'):
+ from ..utils.glutils import isOpenGLAvailable
+ checkOpenGL = isOpenGLAvailable(version=(2, 1), runtimeCheck=False)
+ if not checkOpenGL:
+ _logger.debug("OpenGL check failed")
+ raise RuntimeError(
+ "OpenGL backend is not available: %s" % checkOpenGL.error)
+
+ try:
+ from .backends.BackendOpenGL import \
+ BackendOpenGL as backendClass
+ except ImportError:
+ _logger.debug("Backtrace", exc_info=True)
+ raise RuntimeError("OpenGL backend is not available")
+
+ elif backend == 'none':
+ from .backends.BackendBase import BackendBase as backendClass
+
+ else:
+ raise ValueError("Backend not supported %s" % backend)
+
+ return backendClass
+
+ elif isinstance(backend, (tuple, list)):
+ for b in backend:
+ try:
+ return self.__getBackendClass(b)
+ except RuntimeError:
+ pass
+ else: # No backend was found
+ raise RuntimeError("None of the request backends are available")
+
+ raise ValueError("Backend not supported %s" % str(backend))
+
+ def selection(self):
+ """Returns the selection hander"""
+ if self.__selection is None: # Lazy initialization
+ self.__selection = _PlotWidgetSelection(parent=self)
+ return self.__selection
+
+ # TODO: Can be removed for silx 0.10
+ @staticmethod
+ @deprecated(replacement="silx.config.DEFAULT_PLOT_BACKEND", since_version="0.8", skip_backtrace_count=2)
+ def setDefaultBackend(backend):
+ """Set system wide default plot backend.
+
+ .. versionadded:: 0.6
+
+ :param backend: The backend to use, in:
+ 'matplotlib' (default), 'mpl', 'opengl', 'gl', 'none'
+ or a :class:`BackendBase.BackendBase` class
+ """
+ silx.config.DEFAULT_PLOT_BACKEND = backend
+
+ def setBackend(self, backend):
+ """Set the backend to use for rendering.
+
+ Supported backends:
+
+ - 'matplotlib' and 'mpl': Matplotlib with Qt.
+ - 'opengl' and 'gl': OpenGL backend (requires PyOpenGL and OpenGL >= 2.1)
+ - 'none': No backend, to run headless for testing purpose.
+
+ :param Union[str,BackendBase,List[Union[str,BackendBase]]] backend:
+ The backend to use, in:
+ 'matplotlib' (default), 'mpl', 'opengl', 'gl', 'none',
+ a :class:`BackendBase.BackendBase` class.
+ If multiple backends are provided, the first available one is used.
+ :raises ValueError: Unsupported backend descriptor
+ :raises RuntimeError: Error while loading a backend
+ """
+ backend = self.__getBackendClass(backend)(self, self)
+
+ # First save state that is stored in the backend
+ xaxis = self.getXAxis()
+ xmin, xmax = xaxis.getLimits()
+ ymin, ymax = self.getYAxis(axis='left').getLimits()
+ y2min, y2max = self.getYAxis(axis='right').getLimits()
+ isKeepDataAspectRatio = self.isKeepDataAspectRatio()
+ xTimeZone = xaxis.getTimeZone()
+ isXAxisTimeSeries = xaxis.getTickMode() == TickMode.TIME_SERIES
+
+ isYAxisInverted = self.getYAxis().isInverted()
+
+ # Remove all items from previous backend
+ for item in self.getItems():
+ item._removeBackendRenderer(self._backend)
+
+ # Switch backend
+ self._backend = backend
+ widget = self._backend.getWidgetHandle()
+ self.setCentralWidget(widget)
+ if widget is None:
+ _logger.info("PlotWidget backend does not support widget")
+
+ # Mark as newly dirty
+ self._dirty = False
+ self._setDirtyPlot()
+
+ # Synchronize/restore state
+ self._foregroundColorsUpdated()
+ self._backgroundColorsUpdated()
+
+ self._backend.setGraphCursorShape(self.getGraphCursorShape())
+ crosshairConfig = self.getGraphCursor()
+ if crosshairConfig is None:
+ self._backend.setGraphCursor(False, 'black', 1, '-')
+ else:
+ self._backend.setGraphCursor(True, *crosshairConfig)
+
+ self._backend.setGraphTitle(self.getGraphTitle())
+ self._backend.setGraphGrid(self.getGraphGrid())
+ if self.isAxesDisplayed():
+ self._backend.setAxesMargins(*self.getAxesMargins())
+ else:
+ self._backend.setAxesMargins(0., 0., 0., 0.)
+
+ # Set axes
+ xaxis = self.getXAxis()
+ self._backend.setGraphXLabel(xaxis.getLabel())
+ self._backend.setXAxisTimeZone(xTimeZone)
+ self._backend.setXAxisTimeSeries(isXAxisTimeSeries)
+ self._backend.setXAxisLogarithmic(
+ xaxis.getScale() == items.Axis.LOGARITHMIC)
+
+ for axis in ('left', 'right'):
+ self._backend.setGraphYLabel(self.getYAxis(axis).getLabel(), axis)
+ self._backend.setYAxisInverted(isYAxisInverted)
+ self._backend.setYAxisLogarithmic(
+ self.getYAxis().getScale() == items.Axis.LOGARITHMIC)
+
+ # Finally restore aspect ratio and limits
+ self._backend.setKeepDataAspectRatio(isKeepDataAspectRatio)
+ self.setLimits(xmin, xmax, ymin, ymax, y2min, y2max)
+
+ # Mark all items for update with new backend
+ for item in self.getItems():
+ item._updated()
+
+ def getBackend(self):
+ """Returns the backend currently used by :class:`PlotWidget`.
+
+ :rtype: ~silx.gui.plot.backend.BackendBase.BackendBase
+ """
+ return self._backend
+
+ def _getDirtyPlot(self):
+ """Return the plot dirty flag.
+
+ If False, the plot has not changed since last replot.
+ If True, the full plot need to be redrawn.
+ If 'overlay', only the overlay has changed since last replot.
+
+ It can be accessed by backend to check the dirty state.
+
+ :return: False, True, 'overlay'
+ """
+ return self._dirty
+
+ # Default Qt context menu
+
+ def contextMenuEvent(self, event):
+ """Override QWidget.contextMenuEvent to implement the context menu"""
+ menu = qt.QMenu(self)
+ from .actions.control import ZoomBackAction # Avoid cyclic import
+ zoomBackAction = ZoomBackAction(plot=self, parent=menu)
+ menu.addAction(zoomBackAction)
+
+ mode = self.getInteractiveMode()
+ if "shape" in mode and mode["shape"] == "polygon":
+ from .actions.control import ClosePolygonInteractionAction # Avoid cyclic import
+ action = ClosePolygonInteractionAction(plot=self, parent=menu)
+ menu.addAction(action)
+
+ self._sigDefaultContextMenu.emit(menu)
+
+ # Make sure the plot is updated, especially when the plot is in
+ # draw interaction mode
+ menu.aboutToHide.connect(self.__simulateMouseMove)
+
+ menu.exec(event.globalPos())
+
+ def _setDirtyPlot(self, overlayOnly=False):
+ """Mark the plot as needing redraw
+
+ :param bool overlayOnly: True to redraw only the overlay,
+ False to redraw everything
+ """
+ wasDirty = self._dirty
+
+ if not self._dirty and overlayOnly:
+ self._dirty = 'overlay'
+ else:
+ self._dirty = True
+
+ if self._autoreplot and not wasDirty and self.isVisible():
+ self._backend.postRedisplay()
+
+ def _foregroundColorsUpdated(self):
+ """Handle change of foreground/grid color"""
+ if self._gridColor is None:
+ gridColor = self._foregroundColor
+ else:
+ gridColor = self._gridColor
+ self._backend.setForegroundColors(
+ self._foregroundColor, gridColor)
+ self._setDirtyPlot()
+
+ def getForegroundColor(self):
+ """Returns the RGBA colors used to display the foreground of this widget
+
+ :rtype: qt.QColor
+ """
+ return qt.QColor.fromRgbF(*self._foregroundColor)
+
+ def setForegroundColor(self, color):
+ """Set the foreground color of this widget.
+
+ :param Union[List[int],List[float],QColor] color:
+ The new RGB(A) color.
+ """
+ color = colors.rgba(color)
+ if self._foregroundColor != color:
+ self._foregroundColor = color
+ self._foregroundColorsUpdated()
+
+ def getGridColor(self):
+ """Returns the RGBA colors used to display the grid lines
+
+ An invalid QColor is returned if there is no grid color,
+ in which case the foreground color is used.
+
+ :rtype: qt.QColor
+ """
+ if self._gridColor is None:
+ return qt.QColor() # An invalid color
+ else:
+ return qt.QColor.fromRgbF(*self._gridColor)
+
+ def setGridColor(self, color):
+ """Set the grid lines color
+
+ :param Union[List[int],List[float],QColor,None] color:
+ The new RGB(A) color.
+ """
+ if isinstance(color, qt.QColor) and not color.isValid():
+ color = None
+ if color is not None:
+ color = colors.rgba(color)
+ if self._gridColor != color:
+ self._gridColor = color
+ self._foregroundColorsUpdated()
+
+ def _backgroundColorsUpdated(self):
+ """Handle change of background/data background color"""
+ if self._dataBackgroundColor is None:
+ dataBGColor = self._backgroundColor
+ else:
+ dataBGColor = self._dataBackgroundColor
+ self._backend.setBackgroundColors(
+ self._backgroundColor, dataBGColor)
+ self._setDirtyPlot()
+
+ def getBackgroundColor(self):
+ """Returns the RGBA colors used to display the background of this widget.
+
+ :rtype: qt.QColor
+ """
+ return qt.QColor.fromRgbF(*self._backgroundColor)
+
+ def setBackgroundColor(self, color):
+ """Set the background color of this widget.
+
+ :param Union[List[int],List[float],QColor] color:
+ The new RGB(A) color.
+ """
+ color = colors.rgba(color)
+ if self._backgroundColor != color:
+ self._backgroundColor = color
+ self._backgroundColorsUpdated()
+
+ def getDataBackgroundColor(self):
+ """Returns the RGBA colors used to display the background of the plot
+ view displaying the data.
+
+ An invalid QColor is returned if there is no data background color.
+
+ :rtype: qt.QColor
+ """
+ if self._dataBackgroundColor is None:
+ # An invalid color
+ return qt.QColor()
+ else:
+ return qt.QColor.fromRgbF(*self._dataBackgroundColor)
+
+ def setDataBackgroundColor(self, color):
+ """Set the background color of the plot area.
+
+ Set to None or an invalid QColor to use the background color.
+
+ :param Union[List[int],List[float],QColor,None] color:
+ The new RGB(A) color.
+ """
+ if isinstance(color, qt.QColor) and not color.isValid():
+ color = None
+ if color is not None:
+ color = colors.rgba(color)
+ if self._dataBackgroundColor != color:
+ self._dataBackgroundColor = color
+ self._backgroundColorsUpdated()
+
+ dataBackgroundColor = qt.Property(
+ qt.QColor, getDataBackgroundColor, setDataBackgroundColor
+ )
+
+ backgroundColor = qt.Property(qt.QColor, getBackgroundColor, setBackgroundColor)
+
+ foregroundColor = qt.Property(qt.QColor, getForegroundColor, setForegroundColor)
+
+ gridColor = qt.Property(qt.QColor, getGridColor, setGridColor)
+
+ def showEvent(self, event):
+ if self._autoreplot and self._dirty:
+ self._backend.postRedisplay()
+ super(PlotWidget, self).showEvent(event)
+ self.sigVisibilityChanged.emit(True)
+
+ def hideEvent(self, event):
+ super(PlotWidget, self).hideEvent(event)
+ self.sigVisibilityChanged.emit(False)
+
+ def _invalidateDataRange(self):
+ """
+ Notifies this PlotWidget instance that the range has changed
+ and will have to be recomputed.
+ """
+ self._dataRange = None
+
+ def _updateDataRange(self):
+ """
+ Recomputes the range of the data displayed on this PlotWidget.
+ """
+ xMin = yMinLeft = yMinRight = float('nan')
+ xMax = yMaxLeft = yMaxRight = float('nan')
+
+ for item in self.getItems():
+ if item.isVisible():
+ bounds = item.getBounds()
+ if bounds is not None:
+ with warnings.catch_warnings():
+ warnings.simplefilter('ignore', category=RuntimeWarning)
+ # Ignore All-NaN slice encountered
+ xMin = numpy.nanmin([xMin, bounds[0]])
+ xMax = numpy.nanmax([xMax, bounds[1]])
+ # Take care of right axis
+ if (isinstance(item, items.YAxisMixIn) and
+ item.getYAxis() == 'right'):
+ with warnings.catch_warnings():
+ warnings.simplefilter('ignore', category=RuntimeWarning)
+ # Ignore All-NaN slice encountered
+ yMinRight = numpy.nanmin([yMinRight, bounds[2]])
+ yMaxRight = numpy.nanmax([yMaxRight, bounds[3]])
+ else:
+ with warnings.catch_warnings():
+ warnings.simplefilter('ignore', category=RuntimeWarning)
+ # Ignore All-NaN slice encountered
+ yMinLeft = numpy.nanmin([yMinLeft, bounds[2]])
+ yMaxLeft = numpy.nanmax([yMaxLeft, bounds[3]])
+
+ def lGetRange(x, y):
+ return None if numpy.isnan(x) and numpy.isnan(y) else (x, y)
+ xRange = lGetRange(xMin, xMax)
+ yLeftRange = lGetRange(yMinLeft, yMaxLeft)
+ yRightRange = lGetRange(yMinRight, yMaxRight)
+
+ self._dataRange = _PlotDataRange(x=xRange,
+ y=yLeftRange,
+ yright=yRightRange)
+
+ def getDataRange(self):
+ """
+ Returns this PlotWidget's data range.
+
+ :return: a namedtuple with the following members :
+ x, y (left y axis), yright. Each member is a tuple (min, max)
+ or None if no data is associated with the axis.
+ :rtype: namedtuple
+ """
+ if self._dataRange is None:
+ self._updateDataRange()
+ return self._dataRange
+
+ # Content management
+
+ _KIND_TO_CLASSES = {
+ 'curve': (items.Curve,),
+ 'image': (items.ImageBase,),
+ 'scatter': (items.Scatter,),
+ 'marker': (items.MarkerBase,),
+ 'item': (items.Shape,
+ items.BoundingRect,
+ items.XAxisExtent,
+ items.YAxisExtent),
+ 'histogram': (items.Histogram,),
+ }
+ """Mapping kind to item classes of this kind"""
+
+ @classmethod
+ def _itemKind(cls, item):
+ """Returns the "kind" of a given item
+
+ :param Item item: The item get the kind
+ :rtype: str
+ """
+ for kind, itemClasses in cls._KIND_TO_CLASSES.items():
+ if isinstance(item, itemClasses):
+ return kind
+ raise ValueError('Unsupported item type %s' % type(item))
+
+ def _notifyContentChanged(self, item):
+ self.notify('contentChanged', action='add',
+ kind=self._itemKind(item), legend=item.getName())
+
+ def _itemRequiresUpdate(self, item):
+ """Called by items in the plot for asynchronous update
+
+ :param Item item: The item that required update
+ """
+ assert item.getPlot() == self
+ # Put item at the end of the list
+ if item in self._contentToUpdate:
+ self._contentToUpdate.remove(item)
+ self._contentToUpdate.append(item)
+ self._setDirtyPlot(overlayOnly=item.isOverlay())
+
+ def addItem(self, item=None, *args, **kwargs):
+ """Add an item to the plot content.
+
+ :param ~silx.gui.plot.items.Item item: The item to add.
+ :raises ValueError: If item is already in the plot.
+ """
+ if not isinstance(item, items.Item):
+ deprecated_warning(
+ 'Function',
+ 'addItem',
+ replacement='addShape',
+ since_version='0.13')
+ if item is None and not args: # Only kwargs
+ return self.addShape(**kwargs)
+ else:
+ return self.addShape(item, *args, **kwargs)
+
+ assert not args and not kwargs
+ if item in self.getItems():
+ raise ValueError('Item already in the plot')
+
+ # Add item to plot
+ self._content[(item.getName(), self._itemKind(item))] = item
+ item._setPlot(self)
+ self._itemRequiresUpdate(item)
+ if isinstance(item, items.DATA_ITEMS):
+ self._invalidateDataRange() # TODO handle this automatically
+
+ self._notifyContentChanged(item)
+ self.sigItemAdded.emit(item)
+
+ def removeItem(self, item):
+ """Remove the item from the plot.
+
+ :param ~silx.gui.plot.items.Item item: Item to remove from the plot.
+ :raises ValueError: If item is not in the plot.
+ """
+ if not isinstance(item, items.Item): # Previous method usage
+ deprecated_warning(
+ 'Function',
+ 'removeItem',
+ replacement='remove(legend, kind="item")',
+ since_version='0.13')
+ if item is None:
+ return
+ self.remove(item, kind='item')
+ return
+
+ if item not in self.getItems():
+ raise ValueError('Item not in the plot')
+
+ self.sigItemAboutToBeRemoved.emit(item)
+
+ kind = self._itemKind(item)
+
+ if kind in self._ACTIVE_ITEM_KINDS:
+ if self._getActiveItem(kind) == item:
+ # Reset active item
+ self._setActiveItem(kind, None)
+
+ # Remove item from plot
+ self._content.pop((item.getName(), kind))
+ if item in self._contentToUpdate:
+ self._contentToUpdate.remove(item)
+ if item.isVisible():
+ self._setDirtyPlot(overlayOnly=item.isOverlay())
+ if item.getBounds() is not None:
+ self._invalidateDataRange()
+ item._removeBackendRenderer(self._backend)
+ item._setPlot(None)
+
+ if (kind == 'curve' and not self.getAllCurves(just_legend=True,
+ withhidden=True)):
+ self._resetColorAndStyle()
+
+ self.sigItemRemoved.emit(item)
+
+ self.notify('contentChanged', action='remove',
+ kind=kind, legend=item.getName())
+
+ def discardItem(self, item) -> bool:
+ """Remove the item from the plot.
+
+ Same as :meth:`removeItem` but do not raise an exception.
+
+ :param ~silx.gui.plot.items.Item item: Item to remove from the plot.
+ :returns: True if the item was present, False otherwise.
+ """
+ try:
+ self.removeItem(item)
+ except ValueError:
+ return False
+ else:
+ return True
+
+ @deprecated(replacement='addItem', since_version='0.13')
+ def _add(self, item):
+ return self.addItem(item)
+
+ @deprecated(replacement='removeItem', since_version='0.13')
+ def _remove(self, item):
+ return self.removeItem(item)
+
+ def getItems(self):
+ """Returns the list of items in the plot
+
+ :rtype: List[silx.gui.plot.items.Item]
+ """
+ return tuple(self._content.values())
+
+ @contextmanager
+ def _muteActiveItemChangedSignal(self):
+ self.__muteActiveItemChanged = True
+ yield
+ self.__muteActiveItemChanged = False
+
+ # Add
+
+ # add * input arguments management:
+ # If an arg is set, then use it.
+ # Else:
+ # If a curve with the same legend exists, then use its arg value
+ # Else, use a default value.
+ # Store used value.
+ # This value is used when curve is updated either internally or by user.
+
+ def addCurve(self, x, y, legend=None, info=None,
+ replace=False,
+ color=None, symbol=None,
+ linewidth=None, linestyle=None,
+ xlabel=None, ylabel=None, yaxis=None,
+ xerror=None, yerror=None, z=None, selectable=None,
+ fill=None, resetzoom=True,
+ histogram=None, copy=True,
+ baseline=None):
+ """Add a 1D curve given by x an y to the graph.
+
+ Curves are uniquely identified by their legend.
+ To add multiple curves, call :meth:`addCurve` multiple times with
+ different legend argument.
+ To replace an existing curve, call :meth:`addCurve` with the
+ existing curve legend.
+ If you want to display the curve values as an histogram see the
+ histogram parameter or :meth:`addHistogram`.
+
+ When curve parameters are not provided, if a curve with the
+ same legend is displayed in the plot, its parameters are used.
+
+ :param numpy.ndarray x: The data corresponding to the x coordinates.
+ If you attempt to plot an histogram you can set edges values in x.
+ In this case len(x) = len(y) + 1.
+ If x contains datetime objects the XAxis tickMode is set to
+ TickMode.TIME_SERIES.
+ :param numpy.ndarray y: The data corresponding to the y coordinates
+ :param str legend: The legend to be associated to the curve (or None)
+ :param info: User-defined information associated to the curve
+ :param bool replace: True to delete already existing curves
+ (the default is False)
+ :param color: color(s) to be used
+ :type color: str ("#RRGGBB") or (npoints, 4) unsigned byte array or
+ one of the predefined color names defined in colors.py
+ :param str symbol: Symbol to be drawn at each (x, y) position::
+
+ - 'o' circle
+ - '.' point
+ - ',' pixel
+ - '+' cross
+ - 'x' x-cross
+ - 'd' diamond
+ - 's' square
+ - None (the default) to use default symbol
+
+ :param float linewidth: The width of the curve in pixels (Default: 1).
+ :param str linestyle: Type of line::
+
+ - ' ' no line
+ - '-' solid line
+ - '--' dashed line
+ - '-.' dash-dot line
+ - ':' dotted line
+ - None (the default) to use default line style
+
+ :param str xlabel: Label to show on the X axis when the curve is active
+ or None to keep default axis label.
+ :param str ylabel: Label to show on the Y axis when the curve is active
+ or None to keep default axis label.
+ :param str yaxis: The Y axis this curve is attached to.
+ Either 'left' (the default) or 'right'
+ :param xerror: Values with the uncertainties on the x values
+ :type xerror: A float, or a numpy.ndarray of float32.
+ If it is an array, it can either be a 1D array of
+ same length as the data or a 2D array with 2 rows
+ of same length as the data: row 0 for positive errors,
+ row 1 for negative errors.
+ :param yerror: Values with the uncertainties on the y values
+ :type yerror: A float, or a numpy.ndarray of float32. See xerror.
+ :param int z: Layer on which to draw the curve (default: 1)
+ This allows to control the overlay.
+ :param bool selectable: Indicate if the curve can be selected.
+ (Default: True)
+ :param bool fill: True to fill the curve, False otherwise (default).
+ :param bool resetzoom: True (the default) to reset the zoom.
+ :param str histogram: if not None then the curve will be draw as an
+ histogram. The step for each values of the curve can be set to the
+ left, center or right of the original x curve values.
+ If histogram is not None and len(x) == len(y)+1 then x is directly
+ take as edges of the histogram.
+ Type of histogram::
+
+ - None (default)
+ - 'left'
+ - 'right'
+ - 'center'
+ :param bool copy: True make a copy of the data (default),
+ False to use provided arrays.
+ :param baseline: curve baseline
+ :type: Union[None,float,numpy.ndarray]
+ :returns: The key string identify this curve
+ """
+ # This is an histogram, use addHistogram
+ if histogram is not None:
+ histoLegend = self.addHistogram(histogram=y,
+ edges=x,
+ legend=legend,
+ color=color,
+ fill=fill,
+ align=histogram,
+ copy=copy)
+ histo = self.getHistogram(histoLegend)
+
+ histo.setInfo(info)
+ if linewidth is not None:
+ histo.setLineWidth(linewidth)
+ if linestyle is not None:
+ histo.setLineStyle(linestyle)
+ if xlabel is not None:
+ _logger.warning(
+ 'addCurve: Histogram does not support xlabel argument')
+ if ylabel is not None:
+ _logger.warning(
+ 'addCurve: Histogram does not support ylabel argument')
+ if yaxis is not None:
+ histo.setYAxis(yaxis)
+ if z is not None:
+ histo.setZValue(z)
+ if selectable is not None:
+ _logger.warning(
+ 'addCurve: Histogram does not support selectable argument')
+
+ return
+
+ legend = 'Unnamed curve 1.1' if legend is None else str(legend)
+
+ # Check if curve was previously active
+ wasActive = self.getActiveCurve(just_legend=True) == legend
+
+ if replace:
+ self._resetColorAndStyle()
+
+ # Create/Update curve object
+ curve = self.getCurve(legend)
+ mustBeAdded = curve is None
+ if curve is None:
+ # No previous curve, create a default one and add it to the plot
+ curve = items.Curve() if histogram is None else items.Histogram()
+ curve.setName(legend)
+ # Set default color, linestyle and symbol
+ default_color, default_linestyle = self._getColorAndStyle()
+ curve.setColor(default_color)
+ curve.setLineStyle(default_linestyle)
+ curve.setSymbol(self._defaultPlotPoints)
+ curve._setBaseline(baseline=baseline)
+
+ # Do not emit sigActiveCurveChanged,
+ # it will be sent once with _setActiveItem
+ with self._muteActiveItemChangedSignal():
+ # Override previous/default values with provided ones
+ curve.setInfo(info)
+ if color is not None:
+ curve.setColor(color)
+ if symbol is not None:
+ curve.setSymbol(symbol)
+ if linewidth is not None:
+ curve.setLineWidth(linewidth)
+ if linestyle is not None:
+ curve.setLineStyle(linestyle)
+ if xlabel is not None:
+ curve._setXLabel(xlabel)
+ if ylabel is not None:
+ curve._setYLabel(ylabel)
+ if yaxis is not None:
+ curve.setYAxis(yaxis)
+ if z is not None:
+ curve.setZValue(z)
+ if selectable is not None:
+ curve._setSelectable(selectable)
+ if fill is not None:
+ curve.setFill(fill)
+
+ # Set curve data
+ # If errors not provided, reuse previous ones
+ # TODO: Issue if size of data change but not that of errors
+ if xerror is None:
+ xerror = curve.getXErrorData(copy=False)
+ if yerror is None:
+ yerror = curve.getYErrorData(copy=False)
+
+ # Convert x to timestamps so that the internal representation
+ # remains floating points. The user is expected to set the axis'
+ # tickMode to TickMode.TIME_SERIES and, if necessary, set the axis
+ # to the correct time zone.
+ if len(x) > 0 and isinstance(x[0], dt.datetime):
+ x = [timestamp(d) for d in x]
+
+ curve.setData(x, y, xerror, yerror, baseline=baseline, copy=copy)
+
+ if replace: # Then remove all other curves
+ for c in self.getAllCurves(withhidden=True):
+ if c is not curve:
+ self.removeItem(c)
+
+ if mustBeAdded:
+ self.addItem(curve)
+ else:
+ self._notifyContentChanged(curve)
+
+ if wasActive:
+ self.setActiveCurve(curve.getName())
+ elif self.getActiveCurveSelectionMode() == "legacy":
+ if self.getActiveCurve(just_legend=True) is None:
+ if len(self.getAllCurves(just_legend=True,
+ withhidden=False)) == 1:
+ if curve.isVisible():
+ self.setActiveCurve(curve.getName())
+
+ if resetzoom:
+ # We ask for a zoom reset in order to handle the plot scaling
+ # if the user does not want that, autoscale of the different
+ # axes has to be set to off.
+ self.resetZoom()
+
+ return legend
+
+ def addHistogram(self,
+ histogram,
+ edges,
+ legend=None,
+ color=None,
+ fill=None,
+ align='center',
+ resetzoom=True,
+ copy=True,
+ z=None,
+ baseline=None):
+ """Add an histogram to the graph.
+
+ This is NOT computing the histogram, this method takes as parameter
+ already computed histogram values.
+
+ Histogram are uniquely identified by their legend.
+ To add multiple histograms, call :meth:`addHistogram` multiple times
+ with different legend argument.
+
+ When histogram parameters are not provided, if an histogram with the
+ same legend is displayed in the plot, its parameters are used.
+
+ :param numpy.ndarray histogram: The values of the histogram.
+ :param numpy.ndarray edges:
+ The bin edges of the histogram.
+ If histogram and edges have the same length, the bin edges
+ are computed according to the align parameter.
+ :param str legend:
+ The legend to be associated to the histogram (or None)
+ :param color: color to be used
+ :type color: str ("#RRGGBB") or RGB unsigned byte array or
+ one of the predefined color names defined in colors.py
+ :param bool fill: True to fill the curve, False otherwise (default).
+ :param str align:
+ In case histogram values and edges have the same length N,
+ the N+1 bin edges are computed according to the alignment in:
+ 'center' (default), 'left', 'right'.
+ :param bool resetzoom: True (the default) to reset the zoom.
+ :param bool copy: True make a copy of the data (default),
+ False to use provided arrays.
+ :param int z: Layer on which to draw the histogram
+ :param baseline: histogram baseline
+ :type: Union[None,float,numpy.ndarray]
+ :returns: The key string identify this histogram
+ """
+ legend = 'Unnamed histogram' if legend is None else str(legend)
+
+ # Create/Update histogram object
+ histo = self.getHistogram(legend)
+ mustBeAdded = histo is None
+ if histo is None:
+ # No previous histogram, create a default one and
+ # add it to the plot
+ histo = items.Histogram()
+ histo.setName(legend)
+ histo.setColor(self._getColorAndStyle()[0])
+
+ # Override previous/default values with provided ones
+ if color is not None:
+ histo.setColor(color)
+ if fill is not None:
+ histo.setFill(fill)
+ if z is not None:
+ histo.setZValue(z=z)
+
+ # Set histogram data
+ histo.setData(histogram=histogram, edges=edges, baseline=baseline,
+ align=align, copy=copy)
+
+ if mustBeAdded:
+ self.addItem(histo)
+ else:
+ self._notifyContentChanged(histo)
+
+ if resetzoom:
+ # We ask for a zoom reset in order to handle the plot scaling
+ # if the user does not want that, autoscale of the different
+ # axes has to be set to off.
+ self.resetZoom()
+
+ return legend
+
+ def addImage(self, data, legend=None, info=None,
+ replace=False,
+ z=None,
+ selectable=None, draggable=None,
+ colormap=None, pixmap=None,
+ xlabel=None, ylabel=None,
+ origin=None, scale=None,
+ resetzoom=True, copy=True):
+ """Add a 2D dataset or an image to the plot.
+
+ It displays either an array of data using a colormap or a RGB(A) image.
+
+ Images are uniquely identified by their legend.
+ To add multiple images, call :meth:`addImage` multiple times with
+ different legend argument.
+ To replace/update an existing image, call :meth:`addImage` with the
+ existing image legend.
+
+ When image parameters are not provided, if an image with the
+ same legend is displayed in the plot, its parameters are used.
+
+ :param numpy.ndarray data:
+ (nrows, ncolumns) data or
+ (nrows, ncolumns, RGBA) ubyte array
+ Note: boolean values are converted to int8.
+ :param str legend: The legend to be associated to the image (or None)
+ :param info: User-defined information associated to the image
+ :param bool replace:
+ True to delete already existing images (Default: False).
+ :param int z: Layer on which to draw the image (default: 0)
+ This allows to control the overlay.
+ :param bool selectable: Indicate if the image can be selected.
+ (default: False)
+ :param bool draggable: Indicate if the image can be moved.
+ (default: False)
+ :param colormap: Colormap object to use (or None).
+ This is ignored if data is a RGB(A) image.
+ :type colormap: Union[~silx.gui.colors.Colormap, dict]
+ :param pixmap: Pixmap representation of the data (if any)
+ :type pixmap: (nrows, ncolumns, RGBA) ubyte array or None (default)
+ :param str xlabel: X axis label to show when this curve is active,
+ or None to keep default axis label.
+ :param str ylabel: Y axis label to show when this curve is active,
+ or None to keep default axis label.
+ :param origin: (origin X, origin Y) of the data.
+ It is possible to pass a single float if both
+ coordinates are equal.
+ Default: (0., 0.)
+ :type origin: float or 2-tuple of float
+ :param scale: (scale X, scale Y) of the data.
+ It is possible to pass a single float if both
+ coordinates are equal.
+ Default: (1., 1.)
+ :type scale: float or 2-tuple of float
+ :param bool resetzoom: True (the default) to reset the zoom.
+ :param bool copy: True make a copy of the data (default),
+ False to use provided arrays.
+ :returns: The key string identify this image
+ """
+ legend = "Unnamed Image 1.1" if legend is None else str(legend)
+
+ # Check if image was previously active
+ wasActive = self.getActiveImage(just_legend=True) == legend
+
+ data = numpy.array(data, copy=False)
+ assert data.ndim in (2, 3)
+
+ image = self.getImage(legend)
+ if image is not None and image.getData(copy=False).ndim != data.ndim:
+ # Update a data image with RGBA image or the other way around:
+ # Remove previous image
+ # In this case, we don't retrieve defaults from the previous image
+ self.removeItem(image)
+ image = None
+
+ mustBeAdded = image is None
+ if image is None:
+ # No previous image, create a default one and add it to the plot
+ if data.ndim == 2:
+ image = items.ImageData()
+ image.setColormap(self.getDefaultColormap())
+ else:
+ image = items.ImageRgba()
+ image.setName(legend)
+
+ # Do not emit sigActiveImageChanged,
+ # it will be sent once with _setActiveItem
+ with self._muteActiveItemChangedSignal():
+ # Override previous/default values with provided ones
+ image.setInfo(info)
+ if origin is not None:
+ image.setOrigin(origin)
+ if scale is not None:
+ image.setScale(scale)
+ if z is not None:
+ image.setZValue(z)
+ if selectable is not None:
+ image._setSelectable(selectable)
+ if draggable is not None:
+ image._setDraggable(draggable)
+ if colormap is not None and isinstance(image, items.ColormapMixIn):
+ if isinstance(colormap, dict):
+ image.setColormap(Colormap._fromDict(colormap))
+ else:
+ assert isinstance(colormap, Colormap)
+ image.setColormap(colormap)
+ if xlabel is not None:
+ image._setXLabel(xlabel)
+ if ylabel is not None:
+ image._setYLabel(ylabel)
+
+ if data.ndim == 2:
+ image.setData(data, alternative=pixmap, copy=copy)
+ else: # RGB(A) image
+ if pixmap is not None:
+ _logger.warning(
+ 'addImage: pixmap argument ignored when data is RGB(A)')
+ image.setData(data, copy=copy)
+
+ if replace:
+ for img in self.getAllImages():
+ if img is not image:
+ self.removeItem(img)
+
+ if mustBeAdded:
+ self.addItem(image)
+ else:
+ self._notifyContentChanged(image)
+
+ if len(self.getAllImages()) == 1 or wasActive:
+ self.setActiveImage(legend)
+
+ if resetzoom:
+ # We ask for a zoom reset in order to handle the plot scaling
+ # if the user does not want that, autoscale of the different
+ # axes has to be set to off.
+ self.resetZoom()
+
+ return legend
+
+ def addScatter(self, x, y, value, legend=None, colormap=None,
+ info=None, symbol=None, xerror=None, yerror=None,
+ z=None, copy=True):
+ """Add a (x, y, value) scatter to the graph.
+
+ Scatters are uniquely identified by their legend.
+ To add multiple scatters, call :meth:`addScatter` multiple times with
+ different legend argument.
+ To replace/update an existing scatter, call :meth:`addScatter` with the
+ existing scatter legend.
+
+ When scatter parameters are not provided, if a scatter with the
+ same legend is displayed in the plot, its parameters are used.
+
+ :param numpy.ndarray x: The data corresponding to the x coordinates.
+ :param numpy.ndarray y: The data corresponding to the y coordinates
+ :param numpy.ndarray value: The data value associated with each point
+ :param str legend: The legend to be associated to the scatter (or None)
+ :param ~silx.gui.colors.Colormap colormap:
+ Colormap object to be used for the scatter (or None)
+ :param info: User-defined information associated to the curve
+ :param str symbol: Symbol to be drawn at each (x, y) position::
+
+ - 'o' circle
+ - '.' point
+ - ',' pixel
+ - '+' cross
+ - 'x' x-cross
+ - 'd' diamond
+ - 's' square
+ - None (the default) to use default symbol
+
+ :param xerror: Values with the uncertainties on the x values
+ :type xerror: A float, or a numpy.ndarray of float32.
+ If it is an array, it can either be a 1D array of
+ same length as the data or a 2D array with 2 rows
+ of same length as the data: row 0 for positive errors,
+ row 1 for negative errors.
+ :param yerror: Values with the uncertainties on the y values
+ :type yerror: A float, or a numpy.ndarray of float32. See xerror.
+ :param int z: Layer on which to draw the scatter (default: 1)
+ This allows to control the overlay.
+
+ :param bool copy: True make a copy of the data (default),
+ False to use provided arrays.
+ :returns: The key string identify this scatter
+ """
+ legend = 'Unnamed scatter 1.1' if legend is None else str(legend)
+
+ # Check if scatter was previously active
+ wasActive = self._getActiveItem(kind='scatter',
+ just_legend=True) == legend
+
+ # Create/Update curve object
+ scatter = self._getItem(kind='scatter', legend=legend)
+ mustBeAdded = scatter is None
+ if scatter is None:
+ # No previous scatter, create a default one and add it to the plot
+ scatter = items.Scatter()
+ scatter.setName(legend)
+ scatter.setColormap(self.getDefaultColormap())
+
+ # Do not emit sigActiveScatterChanged,
+ # it will be sent once with _setActiveItem
+ with self._muteActiveItemChangedSignal():
+ # Override previous/default values with provided ones
+ scatter.setInfo(info)
+ if symbol is not None:
+ scatter.setSymbol(symbol)
+ if z is not None:
+ scatter.setZValue(z)
+ if colormap is not None:
+ if isinstance(colormap, dict):
+ scatter.setColormap(Colormap._fromDict(colormap))
+ else:
+ assert isinstance(colormap, Colormap)
+ scatter.setColormap(colormap)
+
+ # Set scatter data
+ # If errors not provided, reuse previous ones
+ if xerror is None:
+ xerror = scatter.getXErrorData(copy=False)
+ if xerror is not None and len(xerror) != len(x):
+ xerror = None
+ if yerror is None:
+ yerror = scatter.getYErrorData(copy=False)
+ if yerror is not None and len(yerror) != len(y):
+ yerror = None
+
+ scatter.setData(x, y, value, xerror, yerror, copy=copy)
+
+ if mustBeAdded:
+ self.addItem(scatter)
+ else:
+ self._notifyContentChanged(scatter)
+
+ scatters = [item for item in self.getItems()
+ if isinstance(item, items.Scatter) and item.isVisible()]
+ if len(scatters) == 1 or wasActive:
+ self._setActiveItem('scatter', scatter.getName())
+
+ return legend
+
+ def addShape(self, xdata, ydata, legend=None, info=None,
+ replace=False,
+ shape="polygon", color='black', fill=True,
+ overlay=False, z=None, linestyle="-", linewidth=1.0,
+ linebgcolor=None):
+ """Add an item (i.e. a shape) to the plot.
+
+ Items are uniquely identified by their legend.
+ To add multiple items, call :meth:`addItem` multiple times with
+ different legend argument.
+ To replace/update an existing item, call :meth:`addItem` with the
+ existing item legend.
+
+ :param numpy.ndarray xdata: The X coords of the points of the shape
+ :param numpy.ndarray ydata: The Y coords of the points of the shape
+ :param str legend: The legend to be associated to the item
+ :param info: User-defined information associated to the item
+ :param bool replace: True (default) to delete already existing images
+ :param str shape: Type of item to be drawn in
+ hline, polygon (the default), rectangle, vline,
+ polylines
+ :param str color: Color of the item, e.g., 'blue', 'b', '#FF0000'
+ (Default: 'black')
+ :param bool fill: True (the default) to fill the shape
+ :param bool overlay: True if item is an overlay (Default: False).
+ This allows for rendering optimization if this
+ item is changed often.
+ :param int z: Layer on which to draw the item (default: 2)
+ :param str linestyle: Style of the line.
+ Only relevant for line markers where X or Y is None.
+ Value in:
+
+ - ' ' no line
+ - '-' solid line
+ - '--' dashed line
+ - '-.' dash-dot line
+ - ':' dotted line
+ :param float linewidth: Width of the line.
+ Only relevant for line markers where X or Y is None.
+ :param str linebgcolor: Background color of the line, e.g., 'blue', 'b',
+ '#FF0000'. It is used to draw dotted line using a second color.
+ :returns: The key string identify this item
+ """
+ # expected to receive the same parameters as the signal
+
+ legend = "Unnamed Item 1.1" if legend is None else str(legend)
+
+ z = int(z) if z is not None else 2
+
+ if replace:
+ self.remove(kind='item')
+ else:
+ self.remove(legend, kind='item')
+
+ item = items.Shape(shape)
+ item.setName(legend)
+ item.setInfo(info)
+ item.setColor(color)
+ item.setFill(fill)
+ item.setOverlay(overlay)
+ item.setZValue(z)
+ item.setPoints(numpy.array((xdata, ydata)).T)
+ item.setLineStyle(linestyle)
+ item.setLineWidth(linewidth)
+ item.setLineBgColor(linebgcolor)
+
+ self.addItem(item)
+
+ return legend
+
+ def addXMarker(self, x, legend=None,
+ text=None,
+ color=None,
+ selectable=False,
+ draggable=False,
+ constraint=None,
+ yaxis='left'):
+ """Add a vertical line marker to the plot.
+
+ Markers are uniquely identified by their legend.
+ As opposed to curves, images and items, two calls to
+ :meth:`addXMarker` without legend argument adds two markers with
+ different identifying legends.
+
+ :param x: Position of the marker on the X axis in data coordinates
+ :type x: Union[None, float]
+ :param str legend: Legend associated to the marker to identify it
+ :param str text: Text to display on the marker.
+ :param str color: Color of the marker, e.g., 'blue', 'b', '#FF0000'
+ (Default: 'black')
+ :param bool selectable: Indicate if the marker can be selected.
+ (default: False)
+ :param bool draggable: Indicate if the marker can be moved.
+ (default: False)
+ :param constraint: A function filtering marker displacement by
+ dragging operations or None for no filter.
+ This function is called each time a marker is
+ moved.
+ This parameter is only used if draggable is True.
+ :type constraint: None or a callable that takes the coordinates of
+ the current cursor position in the plot as input
+ and that returns the filtered coordinates.
+ :param str yaxis: The Y axis this marker belongs to in: 'left', 'right'
+ :return: The key string identify this marker
+ """
+ return self._addMarker(x=x, y=None, legend=legend,
+ text=text, color=color,
+ selectable=selectable, draggable=draggable,
+ symbol=None, constraint=constraint,
+ yaxis=yaxis)
+
+ def addYMarker(self, y,
+ legend=None,
+ text=None,
+ color=None,
+ selectable=False,
+ draggable=False,
+ constraint=None,
+ yaxis='left'):
+ """Add a horizontal line marker to the plot.
+
+ Markers are uniquely identified by their legend.
+ As opposed to curves, images and items, two calls to
+ :meth:`addYMarker` without legend argument adds two markers with
+ different identifying legends.
+
+ :param float y: Position of the marker on the Y axis in data
+ coordinates
+ :param str legend: Legend associated to the marker to identify it
+ :param str text: Text to display next to the marker.
+ :param str color: Color of the marker, e.g., 'blue', 'b', '#FF0000'
+ (Default: 'black')
+ :param bool selectable: Indicate if the marker can be selected.
+ (default: False)
+ :param bool draggable: Indicate if the marker can be moved.
+ (default: False)
+ :param constraint: A function filtering marker displacement by
+ dragging operations or None for no filter.
+ This function is called each time a marker is
+ moved.
+ This parameter is only used if draggable is True.
+ :type constraint: None or a callable that takes the coordinates of
+ the current cursor position in the plot as input
+ and that returns the filtered coordinates.
+ :param str yaxis: The Y axis this marker belongs to in: 'left', 'right'
+ :return: The key string identify this marker
+ """
+ return self._addMarker(x=None, y=y, legend=legend,
+ text=text, color=color,
+ selectable=selectable, draggable=draggable,
+ symbol=None, constraint=constraint,
+ yaxis=yaxis)
+
+ def addMarker(self, x, y, legend=None,
+ text=None,
+ color=None,
+ selectable=False,
+ draggable=False,
+ symbol='+',
+ constraint=None,
+ yaxis='left'):
+ """Add a point marker to the plot.
+
+ Markers are uniquely identified by their legend.
+ As opposed to curves, images and items, two calls to
+ :meth:`addMarker` without legend argument adds two markers with
+ different identifying legends.
+
+ :param float x: Position of the marker on the X axis in data
+ coordinates
+ :param float y: Position of the marker on the Y axis in data
+ coordinates
+ :param str legend: Legend associated to the marker to identify it
+ :param str text: Text to display next to the marker
+ :param str color: Color of the marker, e.g., 'blue', 'b', '#FF0000'
+ (Default: 'black')
+ :param bool selectable: Indicate if the marker can be selected.
+ (default: False)
+ :param bool draggable: Indicate if the marker can be moved.
+ (default: False)
+ :param str symbol: Symbol representing the marker in::
+
+ - 'o' circle
+ - '.' point
+ - ',' pixel
+ - '+' cross (the default)
+ - 'x' x-cross
+ - 'd' diamond
+ - 's' square
+
+ :param constraint: A function filtering marker displacement by
+ dragging operations or None for no filter.
+ This function is called each time a marker is
+ moved.
+ This parameter is only used if draggable is True.
+ :type constraint: None or a callable that takes the coordinates of
+ the current cursor position in the plot as input
+ and that returns the filtered coordinates.
+ :param str yaxis: The Y axis this marker belongs to in: 'left', 'right'
+ :return: The key string identify this marker
+ """
+ if x is None:
+ xmin, xmax = self._xAxis.getLimits()
+ x = 0.5 * (xmax + xmin)
+
+ if y is None:
+ ymin, ymax = self._yAxis.getLimits()
+ y = 0.5 * (ymax + ymin)
+
+ return self._addMarker(x=x, y=y, legend=legend,
+ text=text, color=color,
+ selectable=selectable, draggable=draggable,
+ symbol=symbol, constraint=constraint,
+ yaxis=yaxis)
+
+ def _addMarker(self, x, y, legend,
+ text, color,
+ selectable, draggable,
+ symbol, constraint,
+ yaxis=None):
+ """Common method for adding point, vline and hline marker.
+
+ See :meth:`addMarker` for argument documentation.
+ """
+ assert (x, y) != (None, None)
+
+ if legend is None: # Find an unused legend
+ markerLegends = [item.getName() for item in self.getItems()
+ if isinstance(item, items.MarkerBase)]
+ for index in itertools.count():
+ legend = "Unnamed Marker %d" % index
+ if legend not in markerLegends:
+ break # Keep this legend
+ legend = str(legend)
+
+ if x is None:
+ markerClass = items.YMarker
+ elif y is None:
+ markerClass = items.XMarker
+ else:
+ markerClass = items.Marker
+
+ # Create/Update marker object
+ marker = self._getMarker(legend)
+ if marker is not None and not isinstance(marker, markerClass):
+ _logger.warning('Adding marker with same legend'
+ ' but different type replaces it')
+ self.removeItem(marker)
+ marker = None
+
+ mustBeAdded = marker is None
+ if marker is None:
+ # No previous marker, create one
+ marker = markerClass()
+ marker.setName(legend)
+
+ if text is not None:
+ marker.setText(text)
+ if color is not None:
+ marker.setColor(color)
+ if selectable is not None:
+ marker._setSelectable(selectable)
+ if draggable is not None:
+ marker._setDraggable(draggable)
+ if symbol is not None:
+ marker.setSymbol(symbol)
+ marker.setYAxis(yaxis)
+
+ # TODO to improve, but this ensure constraint is applied
+ marker.setPosition(x, y)
+ if constraint is not None:
+ marker._setConstraint(constraint)
+ marker.setPosition(x, y)
+
+ if mustBeAdded:
+ self.addItem(marker)
+ else:
+ self._notifyContentChanged(marker)
+
+ return legend
+
+ # Hide
+
+ def isCurveHidden(self, legend):
+ """Returns True if the curve associated to legend is hidden, else False
+
+ :param str legend: The legend key identifying the curve
+ :return: True if the associated curve is hidden, False otherwise
+ """
+ curve = self._getItem('curve', legend)
+ return curve is not None and not curve.isVisible()
+
+ def hideCurve(self, legend, flag=True):
+ """Show/Hide the curve associated to legend.
+
+ Even when hidden, the curve is kept in the list of curves.
+
+ :param str legend: The legend associated to the curve to be hidden
+ :param bool flag: True (default) to hide the curve, False to show it
+ """
+ curve = self._getItem('curve', legend)
+ if curve is None:
+ _logger.warning('Curve not in plot: %s', legend)
+ return
+
+ isVisible = not flag
+ if isVisible != curve.isVisible():
+ curve.setVisible(isVisible)
+
+ # Remove
+
+ ITEM_KINDS = 'curve', 'image', 'scatter', 'item', 'marker', 'histogram'
+ """List of supported kind of items in the plot."""
+
+ _ACTIVE_ITEM_KINDS = 'curve', 'scatter', 'image'
+ """List of item's kind which have a active item."""
+
+ def remove(self, legend=None, kind=ITEM_KINDS):
+ """Remove one or all element(s) of the given legend and kind.
+
+ Examples:
+
+ - ``remove()`` clears the plot
+ - ``remove(kind='curve')`` removes all curves from the plot
+ - ``remove('myCurve', kind='curve')`` removes the curve with
+ legend 'myCurve' from the plot.
+ - ``remove('myImage, kind='image')`` removes the image with
+ legend 'myImage' from the plot.
+ - ``remove('myImage')`` removes elements (for instance curve, image,
+ item and marker) with legend 'myImage'.
+
+ :param str legend: The legend associated to the element to remove,
+ or None to remove
+ :param kind: The kind of elements to remove from the plot.
+ See :attr:`ITEM_KINDS`.
+ By default, it removes all kind of elements.
+ :type kind: str or tuple of str to specify multiple kinds.
+ """
+ if kind == 'all': # Replace all by tuple of all kinds
+ kind = self.ITEM_KINDS
+
+ if kind in self.ITEM_KINDS: # Kind is a str, make it a tuple
+ kind = (kind,)
+
+ for aKind in kind:
+ assert aKind in self.ITEM_KINDS
+
+ if legend is None: # This is a clear
+ # Clear each given kind
+ for aKind in kind:
+ for item in self.getItems():
+ if (isinstance(item, self._KIND_TO_CLASSES[aKind]) and
+ item.getPlot() is self): # Make sure item is still in the plot
+ self.removeItem(item)
+
+ else: # This is removing a single element
+ # Remove each given kind
+ for aKind in kind:
+ item = self._getItem(aKind, legend)
+ if item is not None:
+ self.removeItem(item)
+
+ def removeCurve(self, legend):
+ """Remove the curve associated to legend from the graph.
+
+ :param str legend: The legend associated to the curve to be deleted
+ """
+ if legend is None:
+ return
+ self.remove(legend, kind='curve')
+
+ def removeImage(self, legend):
+ """Remove the image associated to legend from the graph.
+
+ :param str legend: The legend associated to the image to be deleted
+ """
+ if legend is None:
+ return
+ self.remove(legend, kind='image')
+
+ def removeMarker(self, legend):
+ """Remove the marker associated to legend from the graph.
+
+ :param str legend: The legend associated to the marker to be deleted
+ """
+ if legend is None:
+ return
+ self.remove(legend, kind='marker')
+
+ # Clear
+
+ def clear(self):
+ """Remove everything from the plot."""
+ for item in self.getItems():
+ if item.getPlot() is self: # Make sure item is still in the plot
+ self.removeItem(item)
+
+ def clearCurves(self):
+ """Remove all the curves from the plot."""
+ self.remove(kind='curve')
+
+ def clearImages(self):
+ """Remove all the images from the plot."""
+ self.remove(kind='image')
+
+ def clearItems(self):
+ """Remove all the items from the plot. """
+ self.remove(kind='item')
+
+ def clearMarkers(self):
+ """Remove all the markers from the plot."""
+ self.remove(kind='marker')
+
+ # Interaction
+
+ def getGraphCursor(self):
+ """Returns the state of the crosshair cursor.
+
+ See :meth:`setGraphCursor`.
+
+ :return: None if the crosshair cursor is not active,
+ else a tuple (color, linewidth, linestyle).
+ """
+ return self._cursorConfiguration
+
+ def setGraphCursor(self, flag=False, color='black',
+ linewidth=1, linestyle='-'):
+ """Toggle the display of a crosshair cursor and set its attributes.
+
+ :param bool flag: Toggle the display of a crosshair cursor.
+ The crosshair cursor is hidden by default.
+ :param color: The color to use for the crosshair.
+ :type color: A string (either a predefined color name in colors.py
+ or "#RRGGBB")) or a 4 columns unsigned byte array
+ (Default: black).
+ :param int linewidth: The width of the lines of the crosshair
+ (Default: 1).
+ :param str linestyle: Type of line::
+
+ - ' ' no line
+ - '-' solid line (the default)
+ - '--' dashed line
+ - '-.' dash-dot line
+ - ':' dotted line
+ """
+ if flag:
+ self._cursorConfiguration = color, linewidth, linestyle
+ else:
+ self._cursorConfiguration = None
+
+ self._backend.setGraphCursor(flag=flag, color=color,
+ linewidth=linewidth, linestyle=linestyle)
+ self._setDirtyPlot()
+ self.notify('setGraphCursor',
+ state=self._cursorConfiguration is not None)
+
+ def pan(self, direction, factor=0.1):
+ """Pan the graph in the given direction by the given factor.
+
+ Warning: Pan of right Y axis not implemented!
+
+ :param str direction: One of 'up', 'down', 'left', 'right'.
+ :param float factor: Proportion of the range used to pan the graph.
+ Must be strictly positive.
+ """
+ assert direction in ('up', 'down', 'left', 'right')
+ assert factor > 0.
+
+ if direction in ('left', 'right'):
+ xFactor = factor if direction == 'right' else - factor
+ xMin, xMax = self._xAxis.getLimits()
+
+ xMin, xMax = _utils.applyPan(xMin, xMax, xFactor,
+ self._xAxis.getScale() == self._xAxis.LOGARITHMIC)
+ self._xAxis.setLimits(xMin, xMax)
+
+ else: # direction in ('up', 'down')
+ sign = -1. if self._yAxis.isInverted() else 1.
+ yFactor = sign * (factor if direction == 'up' else -factor)
+ yMin, yMax = self._yAxis.getLimits()
+ yIsLog = self._yAxis.getScale() == self._yAxis.LOGARITHMIC
+
+ yMin, yMax = _utils.applyPan(yMin, yMax, yFactor, yIsLog)
+ self._yAxis.setLimits(yMin, yMax)
+
+ y2Min, y2Max = self._yRightAxis.getLimits()
+
+ y2Min, y2Max = _utils.applyPan(y2Min, y2Max, yFactor, yIsLog)
+ self._yRightAxis.setLimits(y2Min, y2Max)
+
+ # Active Curve/Image
+
+ def isActiveCurveHandling(self):
+ """Returns True if active curve selection is enabled.
+
+ :rtype: bool
+ """
+ return self.getActiveCurveSelectionMode() != 'none'
+
+ def setActiveCurveHandling(self, flag=True):
+ """Enable/Disable active curve selection.
+
+ :param bool flag: True to enable 'atmostone' active curve selection,
+ False to disable active curve selection.
+ """
+ self.setActiveCurveSelectionMode('atmostone' if flag else 'none')
+
+ def getActiveCurveStyle(self):
+ """Returns the current style applied to active curve
+
+ :rtype: CurveStyle
+ """
+ return self._activeCurveStyle
+
+ def setActiveCurveStyle(self,
+ color=None,
+ linewidth=None,
+ linestyle=None,
+ symbol=None,
+ symbolsize=None):
+ """Set the style of active curve
+
+ :param color: Color
+ :param Union[str,None] linestyle: Style of the line
+ :param Union[float,None] linewidth: Width of the line
+ :param Union[str,None] symbol: Symbol of the markers
+ :param Union[float,None] symbolsize: Size of the symbols
+ """
+ self._activeCurveStyle = CurveStyle(color=color,
+ linewidth=linewidth,
+ linestyle=linestyle,
+ symbol=symbol,
+ symbolsize=symbolsize)
+ curve = self.getActiveCurve()
+ if curve is not None:
+ curve.setHighlightedStyle(self.getActiveCurveStyle())
+
+ @deprecated(replacement="getActiveCurveStyle", since_version="0.9")
+ def getActiveCurveColor(self):
+ """Get the color used to display the currently active curve.
+
+ See :meth:`setActiveCurveColor`.
+ """
+ return self._activeCurveStyle.getColor()
+
+ @deprecated(replacement="setActiveCurveStyle", since_version="0.9")
+ def setActiveCurveColor(self, color="#000000"):
+ """Set the color to use to display the currently active curve.
+
+ :param str color: Color of the active curve,
+ e.g., 'blue', 'b', '#FF0000' (Default: 'black')
+ """
+ if color is None:
+ color = "black"
+ if color in self.colorDict:
+ color = self.colorDict[color]
+ self.setActiveCurveStyle(color=color)
+
+ def getActiveCurve(self, just_legend=False):
+ """Return the currently active curve.
+
+ It returns None in case of not having an active curve.
+
+ :param bool just_legend: True to get the legend of the curve,
+ False (the default) to get the curve data
+ and info.
+ :return: Active curve's legend or corresponding
+ :class:`.items.Curve`
+ :rtype: str or :class:`.items.Curve` or None
+ """
+ if not self.isActiveCurveHandling():
+ return None
+
+ return self._getActiveItem(kind='curve', just_legend=just_legend)
+
+ def setActiveCurve(self, legend):
+ """Make the curve associated to legend the active curve.
+
+ :param legend: The legend associated to the curve
+ or None to have no active curve.
+ :type legend: str or None
+ """
+ if not self.isActiveCurveHandling():
+ return
+ if legend is None and self.getActiveCurveSelectionMode() == "legacy":
+ _logger.info(
+ 'setActiveCurve(None) ignored due to active curve selection mode')
+ return
+
+ return self._setActiveItem(kind='curve', legend=legend)
+
+ def setActiveCurveSelectionMode(self, mode):
+ """Sets the current selection mode.
+
+ :param str mode: The active curve selection mode to use.
+ It can be: 'legacy', 'atmostone' or 'none'.
+ """
+ assert mode in ('legacy', 'atmostone', 'none')
+
+ if mode != self._activeCurveSelectionMode:
+ self._activeCurveSelectionMode = mode
+ if mode == 'none': # reset active curve
+ self._setActiveItem(kind='curve', legend=None)
+
+ elif mode == 'legacy' and self.getActiveCurve() is None:
+ # Select an active curve
+ curves = self.getAllCurves(just_legend=False,
+ withhidden=False)
+ if len(curves) == 1:
+ if curves[0].isVisible():
+ self.setActiveCurve(curves[0].getName())
+
+ def getActiveCurveSelectionMode(self):
+ """Returns the current selection mode.
+
+ It can be "atmostone", "legacy" or "none".
+
+ :rtype: str
+ """
+ return self._activeCurveSelectionMode
+
+ def getActiveImage(self, just_legend=False):
+ """Returns the currently active image.
+
+ It returns None in case of not having an active image.
+
+ :param bool just_legend: True to get the legend of the image,
+ False (the default) to get the image data
+ and info.
+ :return: Active image's legend or corresponding image object
+ :rtype: str, :class:`.items.ImageData`, :class:`.items.ImageRgba`
+ or None
+ """
+ return self._getActiveItem(kind='image', just_legend=just_legend)
+
+ def setActiveImage(self, legend):
+ """Make the image associated to legend the active image.
+
+ :param str legend: The legend associated to the image
+ or None to have no active image.
+ """
+ return self._setActiveItem(kind='image', legend=legend)
+
+ def getActiveScatter(self, just_legend=False):
+ """Returns the currently active scatter.
+
+ It returns None in case of not having an active scatter.
+
+ :param bool just_legend: True to get the legend of the scatter,
+ False (the default) to get the scatter data
+ and info.
+ :return: Active scatter's legend or corresponding scatter object
+ :rtype: str, :class:`.items.Scatter` or None
+ """
+ return self._getActiveItem(kind='scatter', just_legend=just_legend)
+
+ def setActiveScatter(self, legend):
+ """Make the scatter associated to legend the active scatter.
+
+ :param str legend: The legend associated to the scatter
+ or None to have no active scatter.
+ """
+ return self._setActiveItem(kind='scatter', legend=legend)
+
+ def _getActiveItem(self, kind, just_legend=False):
+ """Return the currently active item of that kind if any
+
+ :param str kind: Type of item: 'curve', 'scatter' or 'image'
+ :param bool just_legend: True to get the legend,
+ False (default) to get the item
+ :return: legend or item or None if no active item
+ """
+ assert kind in self._ACTIVE_ITEM_KINDS
+
+ if self._activeLegend[kind] is None:
+ return None
+
+ item = self._getItem(kind, self._activeLegend[kind])
+ if item is None:
+ return None
+
+ return item.getName() if just_legend else item
+
+ def _setActiveItem(self, kind, legend):
+ """Make the curve associated to legend the active curve.
+
+ :param str kind: Type of item: 'curve' or 'image'
+ :param legend: The legend associated to the curve
+ or None to have no active curve.
+ :type legend: str or None
+ """
+ assert kind in self._ACTIVE_ITEM_KINDS
+
+ xLabel = None
+ yLabel = None
+ yRightLabel = None
+
+ oldActiveItem = self._getActiveItem(kind=kind)
+
+ if oldActiveItem is not None: # Stop listening previous active image
+ oldActiveItem.sigItemChanged.disconnect(self._activeItemChanged)
+
+ # Curve specific: Reset highlight of previous active curve
+ if kind == 'curve' and oldActiveItem is not None:
+ oldActiveItem.setHighlighted(False)
+
+ if legend is None:
+ self._activeLegend[kind] = None
+ else:
+ legend = str(legend)
+ item = self._getItem(kind, legend)
+ if item is None:
+ _logger.warning("This %s does not exist: %s", kind, legend)
+ self._activeLegend[kind] = None
+ else:
+ self._activeLegend[kind] = legend
+
+ # Curve specific: handle highlight
+ if kind == 'curve':
+ item.setHighlightedStyle(self.getActiveCurveStyle())
+ item.setHighlighted(True)
+
+ if isinstance(item, items.LabelsMixIn):
+ if item.getXLabel() is not None:
+ xLabel = item.getXLabel()
+ if item.getYLabel() is not None:
+ if (isinstance(item, items.YAxisMixIn) and
+ item.getYAxis() == 'right'):
+ yRightLabel = item.getYLabel()
+ else:
+ yLabel = item.getYLabel()
+
+ # Start listening new active item
+ item.sigItemChanged.connect(self._activeItemChanged)
+
+ # Store current labels and update plot
+ self._xAxis._setCurrentLabel(xLabel)
+ self._yAxis._setCurrentLabel(yLabel)
+ self._yRightAxis._setCurrentLabel(yRightLabel)
+
+ self._setDirtyPlot()
+
+ activeLegend = self._activeLegend[kind]
+ if oldActiveItem is not None or activeLegend is not None:
+ if oldActiveItem is None:
+ oldActiveLegend = None
+ else:
+ oldActiveLegend = oldActiveItem.getName()
+ self.notify(
+ 'active' + kind[0].upper() + kind[1:] + 'Changed',
+ updated=oldActiveLegend != activeLegend,
+ previous=oldActiveLegend,
+ legend=activeLegend)
+
+ return activeLegend
+
+ def _activeItemChanged(self, type_):
+ """Listen for active item changed signal and broadcast signal
+
+ :param item.ItemChangedType type_: The type of item change
+ """
+ if not self.__muteActiveItemChanged:
+ item = self.sender()
+ if item is not None:
+ kind = self._itemKind(item)
+ self.notify(
+ 'active' + kind[0].upper() + kind[1:] + 'Changed',
+ updated=False,
+ previous=item.getName(),
+ legend=item.getName())
+
+ # Getters
+
+ def getAllCurves(self, just_legend=False, withhidden=False):
+ """Returns all curves legend or info and data.
+
+ It returns an empty list in case of not having any curve.
+
+ If just_legend is False, it returns a list of :class:`items.Curve`
+ objects describing the curves.
+ If just_legend is True, it returns a list of curves' legend.
+
+ :param bool just_legend: True to get the legend of the curves,
+ False (the default) to get the curves' data
+ and info.
+ :param bool withhidden: False (default) to skip hidden curves.
+ :return: list of curves' legend or :class:`.items.Curve`
+ :rtype: list of str or list of :class:`.items.Curve`
+ """
+ curves = [item for item in self.getItems() if
+ isinstance(item, items.Curve) and
+ (withhidden or item.isVisible())]
+ return [curve.getName() for curve in curves] if just_legend else curves
+
+ def getCurve(self, legend=None):
+ """Get the object describing a specific curve.
+
+ It returns None in case no matching curve is found.
+
+ :param str legend:
+ The legend identifying the curve.
+ If not provided or None (the default), the active curve is returned
+ or if there is no active curve, the latest updated curve that is
+ not hidden is returned if there are curves in the plot.
+ :return: None or :class:`.items.Curve` object
+ """
+ return self._getItem(kind='curve', legend=legend)
+
+ def getAllImages(self, just_legend=False):
+ """Returns all images legend or objects.
+
+ It returns an empty list in case of not having any image.
+
+ If just_legend is False, it returns a list of :class:`items.ImageBase`
+ objects describing the images.
+ If just_legend is True, it returns a list of legends.
+
+ :param bool just_legend: True to get the legend of the images,
+ False (the default) to get the images'
+ object.
+ :return: list of images' legend or :class:`.items.ImageBase`
+ :rtype: list of str or list of :class:`.items.ImageBase`
+ """
+ images = [item for item in self.getItems()
+ if isinstance(item, items.ImageBase)]
+ return [image.getName() for image in images] if just_legend else images
+
+ def getImage(self, legend=None):
+ """Get the object describing a specific image.
+
+ It returns None in case no matching image is found.
+
+ :param str legend:
+ The legend identifying the image.
+ If not provided or None (the default), the active image is returned
+ or if there is no active image, the latest updated image
+ is returned if there are images in the plot.
+ :return: None or :class:`.items.ImageBase` object
+ """
+ return self._getItem(kind='image', legend=legend)
+
+ def getScatter(self, legend=None):
+ """Get the object describing a specific scatter.
+
+ It returns None in case no matching scatter is found.
+
+ :param str legend:
+ The legend identifying the scatter.
+ If not provided or None (the default), the active scatter is
+ returned or if there is no active scatter, the latest updated
+ scatter is returned if there are scatters in the plot.
+ :return: None or :class:`.items.Scatter` object
+ """
+ return self._getItem(kind='scatter', legend=legend)
+
+ def getHistogram(self, legend=None):
+ """Get the object describing a specific histogram.
+
+ It returns None in case no matching histogram is found.
+
+ :param str legend:
+ The legend identifying the histogram.
+ If not provided or None (the default), the latest updated scatter
+ is returned if there are histograms in the plot.
+ :return: None or :class:`.items.Histogram` object
+ """
+ return self._getItem(kind='histogram', legend=legend)
+
+ @deprecated(replacement='getItems', since_version='0.13')
+ def _getItems(self, kind=ITEM_KINDS, just_legend=False, withhidden=False):
+ """Retrieve all items of a kind in the plot
+
+ :param kind: The kind of elements to retrieve from the plot.
+ See :attr:`ITEM_KINDS`.
+ By default, it removes all kind of elements.
+ :type kind: str or tuple of str to specify multiple kinds.
+ :param str kind: Type of item: 'curve' or 'image'
+ :param bool just_legend: True to get the legend of the curves,
+ False (the default) to get the curves' data
+ and info.
+ :param bool withhidden: False (default) to skip hidden curves.
+ :return: list of legends or item objects
+ """
+ if kind == 'all': # Replace all by tuple of all kinds
+ kind = self.ITEM_KINDS
+
+ if kind in self.ITEM_KINDS: # Kind is a str, make it a tuple
+ kind = (kind,)
+
+ for aKind in kind:
+ assert aKind in self.ITEM_KINDS
+
+ output = []
+ for item in self.getItems():
+ type_ = self._itemKind(item)
+ if type_ in kind and (withhidden or item.isVisible()):
+ output.append(item.getName() if just_legend else item)
+ return output
+
+ def _getItem(self, kind, legend=None):
+ """Get an item from the plot: either an image or a curve.
+
+ Returns None if no match found.
+
+ :param str kind: Type of item to retrieve,
+ see :attr:`ITEM_KINDS`.
+ :param str legend: Legend of the item or
+ None to get active or last item
+ :return: Object describing the item or None
+ """
+ assert kind in self.ITEM_KINDS
+
+ if legend is not None:
+ return self._content.get((legend, kind), None)
+ else:
+ if kind in self._ACTIVE_ITEM_KINDS:
+ item = self._getActiveItem(kind=kind)
+ if item is not None: # Return active item if available
+ return item
+ # Return last visible item if any
+ itemClasses = self._KIND_TO_CLASSES[kind]
+ allItems = [item for item in self.getItems()
+ if isinstance(item, itemClasses) and item.isVisible()]
+ return allItems[-1] if allItems else None
+
+ # Limits
+
+ def _notifyLimitsChanged(self, emitSignal=True):
+ """Send an event when plot area limits are changed."""
+ xRange = self._xAxis.getLimits()
+ yRange = self._yAxis.getLimits()
+ y2Range = self._yRightAxis.getLimits()
+ if emitSignal:
+ axes = self.getXAxis(), self.getYAxis(), self.getYAxis(axis="right")
+ ranges = xRange, yRange, y2Range
+ for axis, limits in zip(axes, ranges):
+ axis.sigLimitsChanged.emit(*limits)
+ event = PlotEvents.prepareLimitsChangedSignal(
+ id(self.getWidgetHandle()), xRange, yRange, y2Range)
+ self.notify(**event)
+
+ def getLimitsHistory(self):
+ """Returns the object handling the history of limits of the plot"""
+ return self._limitsHistory
+
+ def getGraphXLimits(self):
+ """Get the graph X (bottom) limits.
+
+ :return: Minimum and maximum values of the X axis
+ """
+ return self._backend.getGraphXLimits()
+
+ def setGraphXLimits(self, xmin, xmax):
+ """Set the graph X (bottom) limits.
+
+ :param float xmin: minimum bottom axis value
+ :param float xmax: maximum bottom axis value
+ """
+ self._xAxis.setLimits(xmin, xmax)
+
+ def getGraphYLimits(self, axis='left'):
+ """Get the graph Y limits.
+
+ :param str axis: The axis for which to get the limits:
+ Either 'left' or 'right'
+ :return: Minimum and maximum values of the X axis
+ """
+ assert axis in ('left', 'right')
+ yAxis = self._yAxis if axis == 'left' else self._yRightAxis
+ return yAxis.getLimits()
+
+ def setGraphYLimits(self, ymin, ymax, axis='left'):
+ """Set the graph Y limits.
+
+ :param float ymin: minimum bottom axis value
+ :param float ymax: maximum bottom axis value
+ :param str axis: The axis for which to get the limits:
+ Either 'left' or 'right'
+ """
+ assert axis in ('left', 'right')
+ yAxis = self._yAxis if axis == 'left' else self._yRightAxis
+ return yAxis.setLimits(ymin, ymax)
+
+ def setLimits(self, xmin, xmax, ymin, ymax, y2min=None, y2max=None):
+ """Set the limits of the X and Y axes at once.
+
+ If y2min or y2max is None, the right Y axis limits are not updated.
+
+ :param float xmin: minimum bottom axis value
+ :param float xmax: maximum bottom axis value
+ :param float ymin: minimum left axis value
+ :param float ymax: maximum left axis value
+ :param float y2min: minimum right axis value or None (the default)
+ :param float y2max: maximum right axis value or None (the default)
+ """
+ # Deal with incorrect values
+ axis = self.getXAxis()
+ xmin, xmax = axis._checkLimits(xmin, xmax)
+ axis = self.getYAxis()
+ ymin, ymax = axis._checkLimits(ymin, ymax)
+
+ if y2min is None or y2max is None:
+ # if one limit is None, both are ignored
+ y2min, y2max = None, None
+ else:
+ axis = self.getYAxis(axis="right")
+ y2min, y2max = axis._checkLimits(y2min, y2max)
+
+ if self._viewConstrains:
+ view = self._viewConstrains.normalize(xmin, xmax, ymin, ymax)
+ xmin, xmax, ymin, ymax = view
+
+ self._backend.setLimits(xmin, xmax, ymin, ymax, y2min, y2max)
+ self._setDirtyPlot()
+ self._notifyLimitsChanged()
+
+ def _getViewConstraints(self):
+ """Return the plot object managing constaints on the plot view.
+
+ :rtype: ViewConstraints
+ """
+ if self._viewConstrains is None:
+ self._viewConstrains = ViewConstraints()
+ return self._viewConstrains
+
+ # Title and labels
+
+ def getGraphTitle(self):
+ """Return the plot main title as a str."""
+ return self._graphTitle
+
+ def setGraphTitle(self, title=""):
+ """Set the plot main title.
+
+ :param str title: Main title of the plot (default: '')
+ """
+ self._graphTitle = str(title)
+ self._backend.setGraphTitle(title)
+ self._setDirtyPlot()
+
+ def getGraphXLabel(self):
+ """Return the current X axis label as a str."""
+ return self._xAxis.getLabel()
+
+ def setGraphXLabel(self, label="X"):
+ """Set the plot X axis label.
+
+ The provided label can be temporarily replaced by the X label of the
+ active curve if any.
+
+ :param str label: The X axis label (default: 'X')
+ """
+ self._xAxis.setLabel(label)
+
+ def getGraphYLabel(self, axis='left'):
+ """Return the current Y axis label as a str.
+
+ :param str axis: The Y axis for which to get the label (left or right)
+ """
+ assert axis in ('left', 'right')
+ yAxis = self._yAxis if axis == 'left' else self._yRightAxis
+ return yAxis.getLabel()
+
+ def setGraphYLabel(self, label="Y", axis='left'):
+ """Set the plot Y axis label.
+
+ The provided label can be temporarily replaced by the Y label of the
+ active curve if any.
+
+ :param str label: The Y axis label (default: 'Y')
+ :param str axis: The Y axis for which to set the label (left or right)
+ """
+ assert axis in ('left', 'right')
+ yAxis = self._yAxis if axis == 'left' else self._yRightAxis
+ return yAxis.setLabel(label)
+
+ # Axes
+
+ def getXAxis(self):
+ """Returns the X axis
+
+ .. versionadded:: 0.6
+
+ :rtype: :class:`.items.Axis`
+ """
+ return self._xAxis
+
+ def getYAxis(self, axis="left"):
+ """Returns an Y axis
+
+ .. versionadded:: 0.6
+
+ :param str axis: The Y axis to return
+ ('left' or 'right').
+ :rtype: :class:`.items.Axis`
+ """
+ assert(axis in ["left", "right"])
+ return self._yAxis if axis == "left" else self._yRightAxis
+
+ def setAxesDisplayed(self, displayed: bool):
+ """Display or not the axes.
+
+ :param bool displayed: If `True` axes are displayed. If `False` axes
+ are not anymore visible and the margin used for them is removed.
+ """
+ if displayed != self.__axesDisplayed:
+ self.__axesDisplayed = displayed
+ if displayed:
+ self._backend.setAxesMargins(*self.__axesMargins)
+ else:
+ self._backend.setAxesMargins(0., 0., 0., 0.)
+ self._setDirtyPlot()
+ self._sigAxesVisibilityChanged.emit(displayed)
+
+ def isAxesDisplayed(self) -> bool:
+ """Returns whether or not axes are currently displayed
+
+ :rtype: bool
+ """
+ return self.__axesDisplayed
+
+ def setAxesMargins(
+ self, left: float, top: float, right: float, bottom: float):
+ """Set ratios of margins surrounding data plot area.
+
+ All ratios must be within [0., 1.].
+ Sums of ratios of opposed side must be < 1.
+
+ :param float left: Left-side margin ratio.
+ :param float top: Top margin ratio
+ :param float right: Right-side margin ratio
+ :param float bottom: Bottom margin ratio
+ :raises ValueError:
+ """
+ for value in (left, top, right, bottom):
+ if value < 0. or value > 1.:
+ raise ValueError("Margin ratios must be within [0., 1.]")
+ if left + right >= 1. or top + bottom >= 1.:
+ raise ValueError("Sum of ratios of opposed sides >= 1")
+ margins = left, top, right, bottom
+
+ if margins != self.__axesMargins:
+ self.__axesMargins = margins
+ if self.isAxesDisplayed(): # Only apply if axes are displayed
+ self._backend.setAxesMargins(*margins)
+ self._setDirtyPlot()
+
+ def getAxesMargins(self):
+ """Returns ratio of margins surrounding data plot area.
+
+ :return: (left, top, right, bottom)
+ :rtype: List[float]
+ """
+ return self.__axesMargins
+
+ def setYAxisInverted(self, flag=True):
+ """Set the Y axis orientation.
+
+ :param bool flag: True for Y axis going from top to bottom,
+ False for Y axis going from bottom to top
+ """
+ self._yAxis.setInverted(flag)
+
+ def isYAxisInverted(self):
+ """Return True if Y axis goes from top to bottom, False otherwise."""
+ return self._yAxis.isInverted()
+
+ def isXAxisLogarithmic(self):
+ """Return True if X axis scale is logarithmic, False if linear."""
+ return self._xAxis._isLogarithmic()
+
+ def setXAxisLogarithmic(self, flag):
+ """Set the bottom X axis scale (either linear or logarithmic).
+
+ :param bool flag: True to use a logarithmic scale, False for linear.
+ """
+ self._xAxis._setLogarithmic(flag)
+
+ def isYAxisLogarithmic(self):
+ """Return True if Y axis scale is logarithmic, False if linear."""
+ return self._yAxis._isLogarithmic()
+
+ def setYAxisLogarithmic(self, flag):
+ """Set the Y axes scale (either linear or logarithmic).
+
+ :param bool flag: True to use a logarithmic scale, False for linear.
+ """
+ self._yAxis._setLogarithmic(flag)
+
+ def isXAxisAutoScale(self):
+ """Return True if X axis is automatically adjusting its limits."""
+ return self._xAxis.isAutoScale()
+
+ def setXAxisAutoScale(self, flag=True):
+ """Set the X axis limits adjusting behavior of :meth:`resetZoom`.
+
+ :param bool flag: True to resize limits automatically,
+ False to disable it.
+ """
+ self._xAxis.setAutoScale(flag)
+
+ def isYAxisAutoScale(self):
+ """Return True if Y axes are automatically adjusting its limits."""
+ return self._yAxis.isAutoScale()
+
+ def setYAxisAutoScale(self, flag=True):
+ """Set the Y axis limits adjusting behavior of :meth:`resetZoom`.
+
+ :param bool flag: True to resize limits automatically,
+ False to disable it.
+ """
+ self._yAxis.setAutoScale(flag)
+
+ def isKeepDataAspectRatio(self):
+ """Returns whether the plot is keeping data aspect ratio or not."""
+ return self._backend.isKeepDataAspectRatio()
+
+ def setKeepDataAspectRatio(self, flag=True):
+ """Set whether the plot keeps data aspect ratio or not.
+
+ :param bool flag: True to respect data aspect ratio
+ """
+ flag = bool(flag)
+ if flag == self.isKeepDataAspectRatio():
+ return
+ self._backend.setKeepDataAspectRatio(flag=flag)
+ self._setDirtyPlot()
+ self._forceResetZoom()
+ self.notify('setKeepDataAspectRatio', state=flag)
+
+ def getGraphGrid(self):
+ """Return the current grid mode, either None, 'major' or 'both'.
+
+ See :meth:`setGraphGrid`.
+ """
+ return self._grid
+
+ def setGraphGrid(self, which=True):
+ """Set the type of grid to display.
+
+ :param which: None or False to disable the grid,
+ 'major' or True for grid on major ticks (the default),
+ 'both' for grid on both major and minor ticks.
+ :type which: str of bool
+ """
+ assert which in (None, True, False, 'both', 'major')
+ if not which:
+ which = None
+ elif which is True:
+ which = 'major'
+ self._grid = which
+ self._backend.setGraphGrid(which)
+ self._setDirtyPlot()
+ self.notify('setGraphGrid', which=str(which))
+
+ # Defaults
+
+ def isDefaultPlotPoints(self):
+ """Return True if the default Curve symbol is set and False if not."""
+ return self._defaultPlotPoints == silx.config.DEFAULT_PLOT_SYMBOL
+
+ def setDefaultPlotPoints(self, flag):
+ """Set the default symbol of all curves.
+
+ When called, this reset the symbol of all existing curves.
+
+ :param bool flag: True to use 'o' as the default curve symbol,
+ False to use no symbol.
+ """
+ self._defaultPlotPoints = silx.config.DEFAULT_PLOT_SYMBOL if flag else ''
+
+ # Reset symbol of all curves
+ curves = self.getAllCurves(just_legend=False, withhidden=True)
+
+ if curves:
+ for curve in curves:
+ curve.setSymbol(self._defaultPlotPoints)
+
+ def isDefaultPlotLines(self):
+ """Return True for line as default line style, False for no line."""
+ return self._plotLines
+
+ def setDefaultPlotLines(self, flag):
+ """Toggle the use of lines as the default curve line style.
+
+ :param bool flag: True to use a line as the default line style,
+ False to use no line as the default line style.
+ """
+ self._plotLines = bool(flag)
+
+ linestyle = '-' if self._plotLines else ' '
+
+ # Reset linestyle of all curves
+ curves = self.getAllCurves(withhidden=True)
+
+ if curves:
+ for curve in curves:
+ curve.setLineStyle(linestyle)
+
+ def getDefaultColormap(self):
+ """Return the default colormap used by :meth:`addImage`.
+
+ :rtype: ~silx.gui.colors.Colormap
+ """
+ return self._defaultColormap
+
+ def setDefaultColormap(self, colormap=None):
+ """Set the default colormap used by :meth:`addImage`.
+
+ Setting the default colormap do not change any currently displayed
+ image.
+ It only affects future calls to :meth:`addImage` without the colormap
+ parameter.
+
+ :param ~silx.gui.colors.Colormap colormap:
+ The description of the default colormap, or
+ None to set the colormap to a linear
+ autoscale gray colormap.
+ """
+ if colormap is None:
+ colormap = Colormap(name=silx.config.DEFAULT_COLORMAP_NAME,
+ normalization='linear',
+ vmin=None,
+ vmax=None)
+ if isinstance(colormap, dict):
+ self._defaultColormap = Colormap._fromDict(colormap)
+ else:
+ assert isinstance(colormap, Colormap)
+ self._defaultColormap = colormap
+ self.notify('defaultColormapChanged')
+
+ @staticmethod
+ def getSupportedColormaps():
+ """Get the supported colormap names as a tuple of str.
+
+ The list contains at least:
+ ('gray', 'reversed gray', 'temperature', 'red', 'green', 'blue',
+ 'magma', 'inferno', 'plasma', 'viridis')
+ """
+ return Colormap.getSupportedColormaps()
+
+ def _resetColorAndStyle(self):
+ self._colorIndex = 0
+ self._styleIndex = 0
+
+ def _getColorAndStyle(self):
+ color = self.colorList[self._colorIndex]
+ style = self._styleList[self._styleIndex]
+
+ # Loop over color and then styles
+ self._colorIndex += 1
+ if self._colorIndex >= len(self.colorList):
+ self._colorIndex = 0
+ self._styleIndex = (self._styleIndex + 1) % len(self._styleList)
+
+ # If color is the one of active curve, take the next one
+ if colors.rgba(color) == self.getActiveCurveStyle().getColor():
+ color, style = self._getColorAndStyle()
+
+ if not self._plotLines:
+ style = ' '
+
+ return color, style
+
+ # Misc.
+
+ def getWidgetHandle(self):
+ """Return the widget the plot is displayed in.
+
+ This widget is owned by the backend.
+ """
+ return self._backend.getWidgetHandle()
+
+ def notify(self, event, **kwargs):
+ """Send an event to the listeners and send signals.
+
+ Event are passed to the registered callback as a dict with an 'event'
+ key for backward compatibility with PyMca.
+
+ :param str event: The type of event
+ :param kwargs: The information of the event.
+ """
+ eventDict = kwargs.copy()
+ eventDict['event'] = event
+ self.sigPlotSignal.emit(eventDict)
+
+ if event == 'setKeepDataAspectRatio':
+ self.sigSetKeepDataAspectRatio.emit(kwargs['state'])
+ elif event == 'setGraphGrid':
+ self.sigSetGraphGrid.emit(kwargs['which'])
+ elif event == 'setGraphCursor':
+ self.sigSetGraphCursor.emit(kwargs['state'])
+ elif event == 'contentChanged':
+ self.sigContentChanged.emit(
+ kwargs['action'], kwargs['kind'], kwargs['legend'])
+ elif event == 'activeCurveChanged':
+ self.sigActiveCurveChanged.emit(
+ kwargs['previous'], kwargs['legend'])
+ elif event == 'activeImageChanged':
+ self.sigActiveImageChanged.emit(
+ kwargs['previous'], kwargs['legend'])
+ elif event == 'activeScatterChanged':
+ self.sigActiveScatterChanged.emit(
+ kwargs['previous'], kwargs['legend'])
+ elif event == 'interactiveModeChanged':
+ self.sigInteractiveModeChanged.emit(kwargs['source'])
+
+ eventDict = kwargs.copy()
+ eventDict['event'] = event
+ self._callback(eventDict)
+
+ def setCallback(self, callbackFunction=None):
+ """Attach a listener to the backend.
+
+ Limitation: Only one listener at a time.
+
+ :param callbackFunction: function accepting a dictionary as input
+ to handle the graph events
+ If None (default), use a default listener.
+ """
+ # TODO allow multiple listeners
+ # allow register listener by event type
+ if callbackFunction is None:
+ callbackFunction = WeakMethodProxy(self.graphCallback)
+ self._callback = callbackFunction
+
+ def graphCallback(self, ddict=None):
+ """This callback is going to receive all the events from the plot.
+
+ Those events will consist on a dictionary and among the dictionary
+ keys the key 'event' is mandatory to describe the type of event.
+ This default implementation only handles setting the active curve.
+ """
+
+ if ddict is None:
+ ddict = {}
+ _logger.debug("Received dict keys = %s", str(ddict.keys()))
+ _logger.debug(str(ddict))
+ if ddict['event'] in ["legendClicked", "curveClicked"]:
+ if ddict['button'] == "left":
+ self.setActiveCurve(ddict['label'])
+ qt.QToolTip.showText(self.cursor().pos(), ddict['label'])
+ elif ddict['event'] == 'mouseClicked' and ddict['button'] == 'left':
+ self.setActiveCurve(None)
+
+ def saveGraph(self, filename, fileFormat=None, dpi=None):
+ """Save a snapshot of the plot.
+
+ Supported file formats depends on the backend in use.
+ The following file formats are always supported: "png", "svg".
+ The matplotlib backend supports more formats:
+ "pdf", "ps", "eps", "tiff", "jpeg", "jpg".
+
+ :param filename: Destination
+ :type filename: str, StringIO or BytesIO
+ :param str fileFormat: String specifying the format
+ :return: False if cannot save the plot, True otherwise
+ """
+ if fileFormat is None:
+ if not hasattr(filename, 'lower'):
+ _logger.warning(
+ 'saveGraph cancelled, cannot define file format.')
+ return False
+ else:
+ fileFormat = (filename.split(".")[-1]).lower()
+
+ supportedFormats = ("png", "svg", "pdf", "ps", "eps",
+ "tif", "tiff", "jpeg", "jpg")
+
+ if fileFormat not in supportedFormats:
+ _logger.warning('Unsupported format %s', fileFormat)
+ return False
+ else:
+ self._backend.saveGraph(filename,
+ fileFormat=fileFormat,
+ dpi=dpi)
+ return True
+
+ def getDataMargins(self):
+ """Get the default data margin ratios, see :meth:`setDataMargins`.
+
+ :return: The margin ratios for each side (xMin, xMax, yMin, yMax).
+ :rtype: A 4-tuple of floats.
+ """
+ return self._defaultDataMargins
+
+ def setDataMargins(self, xMinMargin=0., xMaxMargin=0.,
+ yMinMargin=0., yMaxMargin=0.):
+ """Set the default data margins to use in :meth:`resetZoom`.
+
+ Set the default ratios of margins (as floats) to add around the data
+ inside the plot area for each side.
+ """
+ self._defaultDataMargins = (xMinMargin, xMaxMargin,
+ yMinMargin, yMaxMargin)
+
+ def getAutoReplot(self):
+ """Return True if replot is automatically handled, False otherwise.
+
+ See :meth`setAutoReplot`.
+ """
+ return self._autoreplot
+
+ def setAutoReplot(self, autoreplot=True):
+ """Set automatic replot mode.
+
+ When enabled, the plot is redrawn automatically when changed.
+ When disabled, the plot is not redrawn when its content change.
+ Instead, it :meth:`replot` must be called.
+
+ :param bool autoreplot: True to enable it (default),
+ False to disable it.
+ """
+ self._autoreplot = bool(autoreplot)
+
+ # If the plot is dirty before enabling autoreplot,
+ # then _backend.postRedisplay will never be called from _setDirtyPlot
+ if self._autoreplot and self._getDirtyPlot():
+ self._backend.postRedisplay()
+
+ @contextmanager
+ def _paintContext(self):
+ """This context MUST surround backend rendering.
+
+ It is in charge of performing required PlotWidget operations
+ """
+ for item in self._contentToUpdate:
+ item._update(self._backend)
+
+ self._contentToUpdate = []
+ yield
+ self._dirty = False # reset dirty flag
+
+ def replot(self):
+ """Request to draw the plot."""
+ self._backend.replot()
+
+ def _forceResetZoom(self, dataMargins=None):
+ """Reset the plot limits to the bounds of the data and redraw the plot.
+
+ This method forces a reset zoom and does not check axis autoscale.
+
+ Extra margins can be added around the data inside the plot area
+ (see :meth:`setDataMargins`).
+ Margins are given as one ratio of the data range per limit of the
+ data (xMin, xMax, yMin and yMax limits).
+ For log scale, extra margins are applied in log10 of the data.
+
+ :param dataMargins: Ratios of margins to add around the data inside
+ the plot area for each side (default: no margins).
+ :type dataMargins: A 4-tuple of float as (xMin, xMax, yMin, yMax).
+ """
+ if dataMargins is None:
+ dataMargins = self._defaultDataMargins
+
+ # Get data range
+ ranges = self.getDataRange()
+ xmin, xmax = (1., 100.) if ranges.x is None else ranges.x
+ ymin, ymax = (1., 100.) if ranges.y is None else ranges.y
+ if ranges.yright is None:
+ ymin2, ymax2 = ymin, ymax
+ else:
+ ymin2, ymax2 = ranges.yright
+ if ranges.y is None:
+ ymin, ymax = ranges.yright
+
+ # Add margins around data inside the plot area
+ newLimits = list(_utils.addMarginsToLimits(
+ dataMargins,
+ self._xAxis._isLogarithmic(),
+ self._yAxis._isLogarithmic(),
+ xmin, xmax, ymin, ymax, ymin2, ymax2))
+
+ if self.isKeepDataAspectRatio():
+ # Use limits with margins to keep ratio
+ xmin, xmax, ymin, ymax = newLimits[:4]
+
+ # Compute bbox wth figure aspect ratio
+ plotWidth, plotHeight = self.getPlotBoundsInPixels()[2:]
+ if plotWidth > 0 and plotHeight > 0:
+ plotRatio = plotHeight / plotWidth
+ dataRatio = (ymax - ymin) / (xmax - xmin)
+ if dataRatio < plotRatio:
+ # Increase y range
+ ycenter = 0.5 * (ymax + ymin)
+ yrange = (xmax - xmin) * plotRatio
+ newLimits[2] = ycenter - 0.5 * yrange
+ newLimits[3] = ycenter + 0.5 * yrange
+
+ elif dataRatio > plotRatio:
+ # Increase x range
+ xcenter = 0.5 * (xmax + xmin)
+ xrange_ = (ymax - ymin) / plotRatio
+ newLimits[0] = xcenter - 0.5 * xrange_
+ newLimits[1] = xcenter + 0.5 * xrange_
+
+ self.setLimits(*newLimits)
+
+ def resetZoom(self, dataMargins=None):
+ """Reset the plot limits to the bounds of the data and redraw the plot.
+
+ It automatically scale limits of axes that are in autoscale mode
+ (see :meth:`getXAxis`, :meth:`getYAxis` and :meth:`Axis.setAutoScale`).
+ It keeps current limits on axes that are not in autoscale mode.
+
+ Extra margins can be added around the data inside the plot area
+ (see :meth:`setDataMargins`).
+ Margins are given as one ratio of the data range per limit of the
+ data (xMin, xMax, yMin and yMax limits).
+ For log scale, extra margins are applied in log10 of the data.
+
+ :param dataMargins: Ratios of margins to add around the data inside
+ the plot area for each side (default: no margins).
+ :type dataMargins: A 4-tuple of float as (xMin, xMax, yMin, yMax).
+ """
+ xLimits = self._xAxis.getLimits()
+ yLimits = self._yAxis.getLimits()
+ y2Limits = self._yRightAxis.getLimits()
+
+ xAuto = self._xAxis.isAutoScale()
+ yAuto = self._yAxis.isAutoScale()
+
+ # With log axes, autoscale if limits are <= 0
+ # This avoids issues with toggling log scale with matplotlib 2.1.0
+ if self._xAxis.getScale() == self._xAxis.LOGARITHMIC and xLimits[0] <= 0:
+ xAuto = True
+ if self._yAxis.getScale() == self._yAxis.LOGARITHMIC and (yLimits[0] <= 0 or y2Limits[0] <= 0):
+ yAuto = True
+
+ if not xAuto and not yAuto:
+ _logger.debug("Nothing to autoscale")
+ else: # Some axes to autoscale
+ self._forceResetZoom(dataMargins=dataMargins)
+
+ # Restore limits for axis not in autoscale
+ if not xAuto and yAuto:
+ self.setGraphXLimits(*xLimits)
+ elif xAuto and not yAuto:
+ if y2Limits is not None:
+ self.setGraphYLimits(
+ y2Limits[0], y2Limits[1], axis='right')
+ if yLimits is not None:
+ self.setGraphYLimits(yLimits[0], yLimits[1], axis='left')
+
+ if (xLimits != self._xAxis.getLimits() or
+ yLimits != self._yAxis.getLimits() or
+ y2Limits != self._yRightAxis.getLimits()):
+ self._notifyLimitsChanged()
+
+ # Coord conversion
+
+ def dataToPixel(self, x=None, y=None, axis="left", check=True):
+ """Convert a position in data coordinates to a position in pixels.
+
+ :param float x: The X coordinate in data space. If None (default)
+ the middle position of the displayed data is used.
+ :param float y: The Y coordinate in data space. If None (default)
+ the middle position of the displayed data is used.
+ :param str axis: The Y axis to use for the conversion
+ ('left' or 'right').
+ :param bool check: True to return None if outside displayed area,
+ False to convert to pixels anyway
+ :returns: The corresponding position in pixels or
+ None if the data position is not in the displayed area and
+ check is True.
+ :rtype: A tuple of 2 floats: (xPixel, yPixel) or None.
+ """
+ assert axis in ("left", "right")
+
+ xmin, xmax = self._xAxis.getLimits()
+ yAxis = self.getYAxis(axis=axis)
+ ymin, ymax = yAxis.getLimits()
+
+ if x is None:
+ x = 0.5 * (xmax + xmin)
+ if y is None:
+ y = 0.5 * (ymax + ymin)
+
+ if check:
+ if x > xmax or x < xmin:
+ return None
+
+ if y > ymax or y < ymin:
+ return None
+
+ return self._backend.dataToPixel(x, y, axis=axis)
+
+ def pixelToData(self, x, y, axis="left", check=False):
+ """Convert a position in pixels to a position in data coordinates.
+
+ :param float x: The X coordinate in pixels. If None (default)
+ the center of the widget is used.
+ :param float y: The Y coordinate in pixels. If None (default)
+ the center of the widget is used.
+ :param str axis: The Y axis to use for the conversion
+ ('left' or 'right').
+ :param bool check: Toggle checking if pixel is in plot area.
+ If False, this method never returns None.
+ :returns: The corresponding position in data space or
+ None if the pixel position is not in the plot area.
+ :rtype: A tuple of 2 floats: (xData, yData) or None.
+ """
+ assert axis in ("left", "right")
+
+ if x is None:
+ x = self.width() // 2
+ if y is None:
+ y = self.height() // 2
+
+ if check:
+ left, top, width, height = self.getPlotBoundsInPixels()
+ if not (left <= x <= left + width and top <= y <= top + height):
+ return None
+
+ return self._backend.pixelToData(x, y, axis)
+
+ def getPlotBoundsInPixels(self):
+ """Plot area bounds in widget coordinates in pixels.
+
+ :return: bounds as a 4-tuple of int: (left, top, width, height)
+ """
+ return self._backend.getPlotBoundsInPixels()
+
+ # Interaction support
+
+ def getGraphCursorShape(self):
+ """Returns the current cursor shape.
+
+ :rtype: str
+ """
+ return self.__graphCursorShape
+
+ def setGraphCursorShape(self, cursor=None):
+ """Set the cursor shape.
+
+ :param str cursor: Name of the cursor shape
+ """
+ self.__graphCursorShape = cursor
+ self._backend.setGraphCursorShape(cursor)
+
+ @deprecated(replacement='getItems', since_version='0.13')
+ def _getAllMarkers(self, just_legend=False):
+ markers = [item for item in self.getItems() if isinstance(item, items.MarkerBase)]
+ if just_legend:
+ return [marker.getName() for marker in markers]
+ else:
+ return markers
+
+ def _getMarkerAt(self, x, y):
+ """Return the most interactive marker at a location, else None
+
+ :param float x: X position in pixels
+ :param float y: Y position in pixels
+ :rtype: None of marker object
+ """
+ def checkDraggable(item):
+ return isinstance(item, items.MarkerBase) and item.isDraggable()
+ def checkSelectable(item):
+ return isinstance(item, items.MarkerBase) and item.isSelectable()
+ def check(item):
+ return isinstance(item, items.MarkerBase)
+
+ result = self._pickTopMost(x, y, checkDraggable)
+ if not result:
+ result = self._pickTopMost(x, y, checkSelectable)
+ if not result:
+ result = self._pickTopMost(x, y, check)
+ marker = result.getItem() if result is not None else None
+ return marker
+
+ def _getMarker(self, legend=None):
+ """Get the object describing a specific marker.
+
+ It returns None in case no matching marker is found
+
+ :param str legend: The legend of the marker to retrieve
+ :rtype: None of marker object
+ """
+ return self._getItem(kind='marker', legend=legend)
+
+ def pickItems(self, x, y, condition=None):
+ """Generator of picked items in the plot at given position.
+
+ Items are returned from front to back.
+
+ :param float x: X position in pixels
+ :param float y: Y position in pixels
+ :param callable condition:
+ Callable taking an item as input and returning False for items
+ to skip during picking. If None (default) no item is skipped.
+ :return: Iterable of :class:`PickingResult` objects at picked position.
+ Items are ordered from front to back.
+ """
+ for item in reversed(self._backend.getItemsFromBackToFront(condition=condition)):
+ result = item.pick(x, y)
+ if result is not None:
+ yield result
+
+ def _pickTopMost(self, x, y, condition=None):
+ """Returns top-most picked item in the plot at given position.
+
+ Items are checked from front to back.
+
+ :param float x: X position in pixels
+ :param float y: Y position in pixels
+ :param callable condition:
+ Callable taking an item as input and returning False for items
+ to skip during picking. If None (default) no item is skipped.
+ :return: :class:`PickingResult` object at picked position.
+ If no item is picked, it returns None
+ :rtype: Union[None,PickingResult]
+ """
+ for result in self.pickItems(x, y, condition):
+ return result
+ return None
+
+ # User event handling #
+
+ def _isPositionInPlotArea(self, x, y):
+ """Project position in pixel to the closest point in the plot area
+
+ :param float x: X coordinate in widget coordinate (in pixel)
+ :param float y: Y coordinate in widget coordinate (in pixel)
+ :return: (x, y) in widget coord (in pixel) in the plot area
+ """
+ left, top, width, height = self.getPlotBoundsInPixels()
+ xPlot = numpy.clip(x, left, left + width)
+ yPlot = numpy.clip(y, top, top + height)
+ return xPlot, yPlot
+
+ def onMousePress(self, xPixel, yPixel, btn):
+ """Handle mouse press event.
+
+ :param float xPixel: X mouse position in pixels
+ :param float yPixel: Y mouse position in pixels
+ :param str btn: Mouse button in 'left', 'middle', 'right'
+ """
+ if self._isPositionInPlotArea(xPixel, yPixel) == (xPixel, yPixel):
+ self._pressedButtons.append(btn)
+ self._eventHandler.handleEvent('press', xPixel, yPixel, btn)
+
+ def onMouseMove(self, xPixel, yPixel):
+ """Handle mouse move event.
+
+ :param float xPixel: X mouse position in pixels
+ :param float yPixel: Y mouse position in pixels
+ """
+ inXPixel, inYPixel = self._isPositionInPlotArea(xPixel, yPixel)
+ isCursorInPlot = inXPixel == xPixel and inYPixel == yPixel
+
+ if self._cursorInPlot != isCursorInPlot:
+ self._cursorInPlot = isCursorInPlot
+ self._eventHandler.handleEvent(
+ 'enter' if self._cursorInPlot else 'leave')
+
+ if isCursorInPlot:
+ # Signal mouse move event
+ dataPos = self.pixelToData(inXPixel, inYPixel)
+ assert dataPos is not None
+
+ btn = self._pressedButtons[-1] if self._pressedButtons else None
+ event = PlotEvents.prepareMouseSignal(
+ 'mouseMoved', btn, dataPos[0], dataPos[1], xPixel, yPixel)
+ self.notify(**event)
+
+ # Either button was pressed in the plot or cursor is in the plot
+ if isCursorInPlot or self._pressedButtons:
+ self._eventHandler.handleEvent('move', inXPixel, inYPixel)
+
+ def onMouseRelease(self, xPixel, yPixel, btn):
+ """Handle mouse release event.
+
+ :param float xPixel: X mouse position in pixels
+ :param float yPixel: Y mouse position in pixels
+ :param str btn: Mouse button in 'left', 'middle', 'right'
+ """
+ try:
+ self._pressedButtons.remove(btn)
+ except ValueError:
+ pass
+ else:
+ xPixel, yPixel = self._isPositionInPlotArea(xPixel, yPixel)
+ self._eventHandler.handleEvent('release', xPixel, yPixel, btn)
+
+ def onMouseWheel(self, xPixel, yPixel, angleInDegrees):
+ """Handle mouse wheel event.
+
+ :param float xPixel: X mouse position in pixels
+ :param float yPixel: Y mouse position in pixels
+ :param float angleInDegrees: Angle corresponding to wheel motion.
+ Positive for movement away from the user,
+ negative for movement toward the user.
+ """
+ if self._isPositionInPlotArea(xPixel, yPixel) == (xPixel, yPixel):
+ self._eventHandler.handleEvent(
+ 'wheel', xPixel, yPixel, angleInDegrees)
+
+ def onMouseLeaveWidget(self):
+ """Handle mouse leave widget event."""
+ if self._cursorInPlot:
+ self._cursorInPlot = False
+ self._eventHandler.handleEvent('leave')
+
+ # Interaction modes #
+
+ def getInteractiveMode(self):
+ """Returns the current interactive mode as a dict.
+
+ The returned dict contains at least the key 'mode'.
+ Mode can be: 'draw', 'pan', 'select', 'select-draw', 'zoom'.
+ It can also contains extra keys (e.g., 'color') specific to a mode
+ as provided to :meth:`setInteractiveMode`.
+ """
+ return self._eventHandler.getInteractiveMode()
+
+ def resetInteractiveMode(self):
+ """Reset the interactive mode to use the previous basic interactive
+ mode used.
+
+ It can be one of "zoom" or "pan".
+ """
+ mode, zoomOnWheel = self._previousDefaultMode
+ self.setInteractiveMode(mode=mode, zoomOnWheel=zoomOnWheel)
+
+ def setInteractiveMode(self, mode, color='black',
+ shape='polygon', label=None,
+ zoomOnWheel=True, source=None, width=None):
+ """Switch the interactive mode.
+
+ :param str mode: The name of the interactive mode.
+ In 'draw', 'pan', 'select', 'select-draw', 'zoom'.
+ :param color: Only for 'draw' and 'zoom' modes.
+ Color to use for drawing selection area. Default black.
+ :type color: Color description: The name as a str or
+ a tuple of 4 floats.
+ :param str shape: Only for 'draw' mode. The kind of shape to draw.
+ In 'polygon', 'rectangle', 'line', 'vline', 'hline',
+ 'freeline'.
+ Default is 'polygon'.
+ :param str label: Only for 'draw' mode, sent in drawing events.
+ :param bool zoomOnWheel: Toggle zoom on wheel support
+ :param source: A user-defined object (typically the caller object)
+ that will be send in the interactiveModeChanged event,
+ to identify which object required a mode change.
+ Default: None
+ :param float width: Width of the pencil. Only for draw pencil mode.
+ """
+ self._eventHandler.setInteractiveMode(mode, color, shape, label, width)
+ self._eventHandler.zoomOnWheel = zoomOnWheel
+ if mode in ["pan", "zoom"]:
+ self._previousDefaultMode = mode, zoomOnWheel
+
+ self.notify(
+ 'interactiveModeChanged', source=source)
+
+ # Panning with arrow keys
+
+ def isPanWithArrowKeys(self):
+ """Returns whether or not panning the graph with arrow keys is enabled.
+
+ See :meth:`setPanWithArrowKeys`.
+ """
+ return self._panWithArrowKeys
+
+ def setPanWithArrowKeys(self, pan=False):
+ """Enable/Disable panning the graph with arrow keys.
+
+ This grabs the keyboard.
+
+ :param bool pan: True to enable panning, False to disable.
+ """
+ pan = bool(pan)
+ panHasChanged = self._panWithArrowKeys != pan
+
+ self._panWithArrowKeys = pan
+ if not self._panWithArrowKeys:
+ self.setFocusPolicy(qt.Qt.NoFocus)
+ else:
+ self.setFocusPolicy(qt.Qt.StrongFocus)
+ self.setFocus(qt.Qt.OtherFocusReason)
+
+ if panHasChanged:
+ self.sigSetPanWithArrowKeys.emit(pan)
+
+ # Dict to convert Qt arrow key code to direction str.
+ _ARROWS_TO_PAN_DIRECTION = {
+ qt.Qt.Key_Left: 'left',
+ qt.Qt.Key_Right: 'right',
+ qt.Qt.Key_Up: 'up',
+ qt.Qt.Key_Down: 'down'
+ }
+
+ def __simulateMouseMove(self):
+ qapp = qt.QApplication.instance()
+ event = qt.QMouseEvent(
+ qt.QEvent.MouseMove,
+ self.getWidgetHandle().mapFromGlobal(qt.QCursor.pos()),
+ qt.Qt.NoButton,
+ qapp.mouseButtons(),
+ qapp.keyboardModifiers())
+ qapp.sendEvent(self.getWidgetHandle(), event)
+
+ def keyPressEvent(self, event):
+ """Key event handler handling panning on arrow keys.
+
+ Overrides base class implementation.
+ """
+ key = event.key()
+ if self._panWithArrowKeys and key in self._ARROWS_TO_PAN_DIRECTION:
+ self.pan(self._ARROWS_TO_PAN_DIRECTION[key], factor=0.1)
+
+ # Send a mouse move event to the plot widget to take into account
+ # that even if mouse didn't move on the screen, it moved relative
+ # to the plotted data.
+ self.__simulateMouseMove()
+ else:
+ # Only call base class implementation when key is not handled.
+ # See QWidget.keyPressEvent for details.
+ super(PlotWidget, self).keyPressEvent(event)
diff --git a/src/silx/gui/plot/PlotWindow.py b/src/silx/gui/plot/PlotWindow.py
new file mode 100644
index 0000000..0349585
--- /dev/null
+++ b/src/silx/gui/plot/PlotWindow.py
@@ -0,0 +1,993 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""A :class:`.PlotWidget` with additional toolbars.
+
+The :class:`PlotWindow` is a subclass of :class:`.PlotWidget`.
+"""
+
+__authors__ = ["V.A. Sole", "T. Vincent"]
+__license__ = "MIT"
+__date__ = "12/04/2019"
+
+try:
+ from collections import abc
+except ImportError: # Python2 support
+ import collections as abc
+import logging
+import weakref
+
+import silx
+from silx.utils.weakref import WeakMethodProxy
+from silx.utils.deprecation import deprecated
+from silx.utils.proxy import docstring
+
+from . import PlotWidget
+from . import actions
+from . import items
+from .actions import medfilt as actions_medfilt
+from .actions import fit as actions_fit
+from .actions import control as actions_control
+from .actions import histogram as actions_histogram
+from . import PlotToolButtons
+from . import tools
+from .Profile import ProfileToolBar
+from .LegendSelector import LegendsDockWidget
+from .CurvesROIWidget import CurvesROIDockWidget
+from .MaskToolsWidget import MaskToolsDockWidget
+from .StatsWidget import BasicStatsWidget
+from .ColorBar import ColorBarWidget
+try:
+ from ..console import IPythonDockWidget
+except ImportError:
+ IPythonDockWidget = None
+
+from .. import qt
+
+
+_logger = logging.getLogger(__name__)
+
+
+class PlotWindow(PlotWidget):
+ """Qt Widget providing a 1D/2D plot area and additional tools.
+
+ This widgets inherits from :class:`.PlotWidget` and provides its plot API.
+
+ Initialiser parameters:
+
+ :param parent: The parent of this widget or None.
+ :param backend: The backend to use for the plot (default: matplotlib).
+ See :class:`.PlotWidget` for the list of supported backend.
+ :type backend: str or :class:`BackendBase.BackendBase`
+ :param bool resetzoom: Toggle visibility of reset zoom action.
+ :param bool autoScale: Toggle visibility of axes autoscale actions.
+ :param bool logScale: Toggle visibility of axes log scale actions.
+ :param bool grid: Toggle visibility of grid mode action.
+ :param bool curveStyle: Toggle visibility of curve style action.
+ :param bool colormap: Toggle visibility of colormap action.
+ :param bool aspectRatio: Toggle visibility of aspect ratio button.
+ :param bool yInverted: Toggle visibility of Y axis direction button.
+ :param bool copy: Toggle visibility of copy action.
+ :param bool save: Toggle visibility of save action.
+ :param bool print_: Toggle visibility of print action.
+ :param bool control: True to display an Options button with a sub-menu
+ to show legends, toggle crosshair and pan with arrows.
+ (Default: False)
+ :param position: True to display widget with (x, y) mouse position
+ (Default: False).
+ It also supports a list of (name, funct(x, y)->value)
+ to customize the displayed values.
+ See :class:`~silx.gui.plot.tools.PositionInfo`.
+ :param bool roi: Toggle visibilty of ROI action.
+ :param bool mask: Toggle visibilty of mask action.
+ :param bool fit: Toggle visibilty of fit action.
+ """
+
+ def __init__(self, parent=None, backend=None,
+ resetzoom=True, autoScale=True, logScale=True, grid=True,
+ curveStyle=True, colormap=True,
+ aspectRatio=True, yInverted=True,
+ copy=True, save=True, print_=True,
+ control=False, position=False,
+ roi=True, mask=True, fit=False):
+ super(PlotWindow, self).__init__(parent=parent, backend=backend)
+ if parent is None:
+ self.setWindowTitle('PlotWindow')
+
+ self._dockWidgets = []
+
+ # lazy loaded dock widgets
+ self._legendsDockWidget = None
+ self._curvesROIDockWidget = None
+ self._maskToolsDockWidget = None
+ self._consoleDockWidget = None
+ self._statsDockWidget = None
+
+ # Create color bar, hidden by default for backward compatibility
+ self._colorbar = ColorBarWidget(parent=self, plot=self)
+
+ # Init actions
+ self.group = qt.QActionGroup(self)
+ self.group.setExclusive(False)
+
+ self.resetZoomAction = self.group.addAction(
+ actions.control.ResetZoomAction(self, parent=self))
+ self.resetZoomAction.setVisible(resetzoom)
+ self.addAction(self.resetZoomAction)
+
+ self.zoomInAction = actions.control.ZoomInAction(self, parent=self)
+ self.addAction(self.zoomInAction)
+
+ self.zoomOutAction = actions.control.ZoomOutAction(self, parent=self)
+ self.addAction(self.zoomOutAction)
+
+ self.xAxisAutoScaleAction = self.group.addAction(
+ actions.control.XAxisAutoScaleAction(self, parent=self))
+ self.xAxisAutoScaleAction.setVisible(autoScale)
+ self.addAction(self.xAxisAutoScaleAction)
+
+ self.yAxisAutoScaleAction = self.group.addAction(
+ actions.control.YAxisAutoScaleAction(self, parent=self))
+ self.yAxisAutoScaleAction.setVisible(autoScale)
+ self.addAction(self.yAxisAutoScaleAction)
+
+ self.xAxisLogarithmicAction = self.group.addAction(
+ actions.control.XAxisLogarithmicAction(self, parent=self))
+ self.xAxisLogarithmicAction.setVisible(logScale)
+ self.addAction(self.xAxisLogarithmicAction)
+
+ self.yAxisLogarithmicAction = self.group.addAction(
+ actions.control.YAxisLogarithmicAction(self, parent=self))
+ self.yAxisLogarithmicAction.setVisible(logScale)
+ self.addAction(self.yAxisLogarithmicAction)
+
+ self.gridAction = self.group.addAction(
+ actions.control.GridAction(self, gridMode='both', parent=self))
+ self.gridAction.setVisible(grid)
+ self.addAction(self.gridAction)
+
+ self.curveStyleAction = self.group.addAction(
+ actions.control.CurveStyleAction(self, parent=self))
+ self.curveStyleAction.setVisible(curveStyle)
+ self.addAction(self.curveStyleAction)
+
+ self.colormapAction = self.group.addAction(
+ actions.control.ColormapAction(self, parent=self))
+ self.colormapAction.setVisible(colormap)
+ self.addAction(self.colormapAction)
+
+ self.colorbarAction = self.group.addAction(
+ actions_control.ColorBarAction(self, parent=self))
+ self.colorbarAction.setVisible(False)
+ self.addAction(self.colorbarAction)
+ self._colorbar.setVisible(False)
+
+ self.keepDataAspectRatioButton = PlotToolButtons.AspectToolButton(
+ parent=self, plot=self)
+ self.keepDataAspectRatioButton.setVisible(aspectRatio)
+
+ self.yAxisInvertedButton = PlotToolButtons.YAxisOriginToolButton(
+ parent=self, plot=self)
+ self.yAxisInvertedButton.setVisible(yInverted)
+
+ self.group.addAction(self.getRoiAction())
+ self.getRoiAction().setVisible(roi)
+
+ self.group.addAction(self.getMaskAction())
+ self.getMaskAction().setVisible(mask)
+
+ self._intensityHistoAction = self.group.addAction(
+ actions_histogram.PixelIntensitiesHistoAction(self, parent=self))
+ self._intensityHistoAction.setVisible(False)
+
+ self._medianFilter2DAction = self.group.addAction(
+ actions_medfilt.MedianFilter2DAction(self, parent=self))
+ self._medianFilter2DAction.setVisible(False)
+
+ self._medianFilter1DAction = self.group.addAction(
+ actions_medfilt.MedianFilter1DAction(self, parent=self))
+ self._medianFilter1DAction.setVisible(False)
+
+ self.fitAction = self.group.addAction(actions_fit.FitAction(self, parent=self))
+ self.fitAction.setVisible(fit)
+ self.addAction(self.fitAction)
+
+ # lazy loaded actions needed by the controlButton menu
+ self._consoleAction = None
+ self._statsAction = None
+ self._panWithArrowKeysAction = None
+ self._crosshairAction = None
+
+ # Make colorbar background white
+ self._colorbar.setAutoFillBackground(True)
+ self._sigAxesVisibilityChanged.connect(self._updateColorBarBackground)
+ self._updateColorBarBackground()
+
+ if control: # Create control button only if requested
+ self.controlButton = qt.QToolButton()
+ self.controlButton.setText("Options")
+ self.controlButton.setToolButtonStyle(qt.Qt.ToolButtonTextBesideIcon)
+ self.controlButton.setAutoRaise(True)
+ self.controlButton.setPopupMode(qt.QToolButton.InstantPopup)
+ menu = qt.QMenu(self)
+ menu.aboutToShow.connect(self._customControlButtonMenu)
+ self.controlButton.setMenu(menu)
+
+ self._positionWidget = None
+ if position: # Add PositionInfo widget to the bottom of the plot
+ if isinstance(position, abc.Iterable):
+ # Use position as a set of converters
+ converters = position
+ else:
+ converters = None
+ self._positionWidget = tools.PositionInfo(
+ plot=self, converters=converters)
+ # Set a snapping mode that is consistent with legacy one
+ self._positionWidget.setSnappingMode(
+ tools.PositionInfo.SNAPPING_CROSSHAIR |
+ tools.PositionInfo.SNAPPING_ACTIVE_ONLY |
+ tools.PositionInfo.SNAPPING_SYMBOLS_ONLY |
+ tools.PositionInfo.SNAPPING_CURVE |
+ tools.PositionInfo.SNAPPING_SCATTER)
+
+ self.__setCentralWidget()
+
+ # Creating the toolbar also create actions for toolbuttons
+ self._interactiveModeToolBar = tools.InteractiveModeToolBar(
+ parent=self, plot=self)
+ self.addToolBar(self._interactiveModeToolBar)
+
+ self._toolbar = self._createToolBar(title='Plot', parent=self)
+ self.addToolBar(self._toolbar)
+
+ self._outputToolBar = tools.OutputToolBar(parent=self, plot=self)
+ self._outputToolBar.getCopyAction().setVisible(copy)
+ self._outputToolBar.getSaveAction().setVisible(save)
+ self._outputToolBar.getPrintAction().setVisible(print_)
+ self.addToolBar(self._outputToolBar)
+
+ # Activate shortcuts in PlotWindow widget:
+ for toolbar in (self._interactiveModeToolBar, self._outputToolBar):
+ for action in toolbar.actions():
+ self.addAction(action)
+
+ def __setCentralWidget(self):
+ """Set central widget to host plot backend, colorbar, and bottom bar"""
+ gridLayout = qt.QGridLayout()
+ gridLayout.setSpacing(0)
+ gridLayout.setContentsMargins(0, 0, 0, 0)
+ gridLayout.addWidget(self.getWidgetHandle(), 0, 0)
+ gridLayout.addWidget(self._colorbar, 0, 1)
+ gridLayout.setRowStretch(0, 1)
+ gridLayout.setColumnStretch(0, 1)
+ centralWidget = qt.QWidget(self)
+ centralWidget.setLayout(gridLayout)
+
+ if hasattr(self, "controlButton") or self._positionWidget is not None:
+ hbox = qt.QHBoxLayout()
+ hbox.setContentsMargins(0, 0, 0, 0)
+
+ if hasattr(self, "controlButton"):
+ hbox.addWidget(self.controlButton)
+
+ if self._positionWidget is not None:
+ hbox.addWidget(self._positionWidget)
+
+ hbox.addStretch(1)
+ bottomBar = qt.QWidget(centralWidget)
+ bottomBar.setLayout(hbox)
+
+ gridLayout.addWidget(bottomBar, 1, 0, 1, -1)
+
+ self.setCentralWidget(centralWidget)
+
+ @docstring(PlotWidget)
+ def setBackend(self, backend):
+ super(PlotWindow, self).setBackend(backend)
+ self.__setCentralWidget() # Recreate PlotWindow's central widget
+
+ @docstring(PlotWidget)
+ def setBackgroundColor(self, color):
+ super(PlotWindow, self).setBackgroundColor(color)
+ self._updateColorBarBackground()
+
+ @docstring(PlotWidget)
+ def setDataBackgroundColor(self, color):
+ super(PlotWindow, self).setDataBackgroundColor(color)
+ self._updateColorBarBackground()
+
+ @docstring(PlotWidget)
+ def setForegroundColor(self, color):
+ super(PlotWindow, self).setForegroundColor(color)
+ self._updateColorBarBackground()
+
+ def _updateColorBarBackground(self):
+ """Update the colorbar background according to the state of the plot"""
+ if self.isAxesDisplayed():
+ color = self.getBackgroundColor()
+ else:
+ color = self.getDataBackgroundColor()
+ if not color.isValid():
+ # If no color defined, use the background one
+ color = self.getBackgroundColor()
+
+ foreground = self.getForegroundColor()
+
+ palette = self._colorbar.palette()
+ palette.setColor(qt.QPalette.Window, color)
+ palette.setColor(qt.QPalette.WindowText, foreground)
+ palette.setColor(qt.QPalette.Text, foreground)
+ self._colorbar.setPalette(palette)
+
+ def getInteractiveModeToolBar(self):
+ """Returns QToolBar controlling interactive mode.
+
+ :rtype: QToolBar
+ """
+ return self._interactiveModeToolBar
+
+ def getOutputToolBar(self):
+ """Returns QToolBar containing save, copy and print actions
+
+ :rtype: QToolBar
+ """
+ return self._outputToolBar
+
+ @property
+ @deprecated(replacement="getPositionInfoWidget()", since_version="0.8.0")
+ def positionWidget(self):
+ return self.getPositionInfoWidget()
+
+ def getPositionInfoWidget(self):
+ """Returns the widget displaying current cursor position information
+
+ :rtype: ~silx.gui.plot.tools.PositionInfo
+ """
+ return self._positionWidget
+
+ def getSelectionMask(self):
+ """Return the current mask handled by :attr:`maskToolsDockWidget`.
+
+ :return: The array of the mask with dimension of the 'active' image.
+ If there is no active image, an empty array is returned.
+ :rtype: 2D numpy.ndarray of uint8
+ """
+ return self.getMaskToolsDockWidget().getSelectionMask()
+
+ def setSelectionMask(self, mask):
+ """Set the mask handled by :attr:`maskToolsDockWidget`.
+
+ If the provided mask has not the same dimension as the 'active'
+ image, it will by cropped or padded.
+
+ :param mask: The array to use for the mask.
+ :type mask: numpy.ndarray of uint8 of dimension 2, C-contiguous.
+ Array of other types are converted.
+ :return: True if success, False if failed
+ """
+ return bool(self.getMaskToolsDockWidget().setSelectionMask(mask))
+
+ def _toggleConsoleVisibility(self, isChecked=False):
+ """Create IPythonDockWidget if needed,
+ show it or hide it."""
+ # create widget if needed (first call)
+ if self._consoleDockWidget is None:
+ available_vars = {"plt": weakref.proxy(self)}
+ banner = "The variable 'plt' is available. Use the 'whos' "
+ banner += "and 'help(plt)' commands for more information.\n\n"
+ self._consoleDockWidget = IPythonDockWidget(
+ available_vars=available_vars,
+ custom_banner=banner,
+ parent=self)
+ self.addTabbedDockWidget(self._consoleDockWidget)
+ # self._consoleDockWidget.setVisible(True)
+ self._consoleDockWidget.toggleViewAction().toggled.connect(
+ self.getConsoleAction().setChecked)
+
+ self._consoleDockWidget.setVisible(isChecked)
+
+ def _toggleStatsVisibility(self, isChecked=False):
+ self.getStatsWidget().parent().setVisible(isChecked)
+
+ def _createToolBar(self, title, parent):
+ """Create a QToolBar from the QAction of the PlotWindow.
+
+ :param str title: The title of the QMenu
+ :param qt.QWidget parent: See :class:`QToolBar`
+ """
+ toolbar = qt.QToolBar(title, parent)
+
+ # Order widgets with actions
+ objects = self.group.actions()
+
+ # Add push buttons to list
+ index = objects.index(self.colormapAction)
+ objects.insert(index + 1, self.keepDataAspectRatioButton)
+ objects.insert(index + 2, self.yAxisInvertedButton)
+
+ for obj in objects:
+ if isinstance(obj, qt.QAction):
+ toolbar.addAction(obj)
+ else:
+ # Add action for toolbutton in order to allow changing
+ # visibility (see doc QToolBar.addWidget doc)
+ if obj is self.keepDataAspectRatioButton:
+ self.keepDataAspectRatioAction = toolbar.addWidget(obj)
+ elif obj is self.yAxisInvertedButton:
+ self.yAxisInvertedAction = toolbar.addWidget(obj)
+ else:
+ raise RuntimeError()
+ return toolbar
+
+ def toolBar(self):
+ """Return a QToolBar from the QAction of the PlotWindow.
+ """
+ return self._toolbar
+
+ def menu(self, title='Plot', parent=None):
+ """Return a QMenu from the QAction of the PlotWindow.
+
+ :param str title: The title of the QMenu
+ :param parent: See :class:`QMenu`
+ """
+ menu = qt.QMenu(title, parent)
+ for action in self.group.actions():
+ menu.addAction(action)
+ return menu
+
+ def _customControlButtonMenu(self):
+ """Display Options button sub-menu."""
+ controlMenu = self.controlButton.menu()
+ controlMenu.clear()
+ controlMenu.addAction(self.getLegendsDockWidget().toggleViewAction())
+ controlMenu.addAction(self.getRoiAction())
+ controlMenu.addAction(self.getStatsAction())
+ controlMenu.addAction(self.getMaskAction())
+ controlMenu.addAction(self.getConsoleAction())
+
+ controlMenu.addSeparator()
+ controlMenu.addAction(self.getCrosshairAction())
+ controlMenu.addAction(self.getPanWithArrowKeysAction())
+
+ def addTabbedDockWidget(self, dock_widget):
+ """Add a dock widget as a new tab if there are already dock widgets
+ in the plot. When the first tab is added, the area is chosen
+ depending on the plot geometry:
+ if the window is much wider than it is high, the right dock area
+ is used, else the bottom dock area is used.
+
+ :param dock_widget: Instance of :class:`QDockWidget` to be added.
+ """
+ if dock_widget not in self._dockWidgets:
+ self._dockWidgets.append(dock_widget)
+ if len(self._dockWidgets) == 1:
+ # The first created dock widget must be added to a Widget area
+ width = self.centralWidget().width()
+ height = self.centralWidget().height()
+ if width > (1.25 * height):
+ area = qt.Qt.RightDockWidgetArea
+ else:
+ area = qt.Qt.BottomDockWidgetArea
+ self.addDockWidget(area, dock_widget)
+ else:
+ # Other dock widgets are added as tabs to the same widget area
+ self.tabifyDockWidget(self._dockWidgets[0],
+ dock_widget)
+
+ def removeDockWidget(self, dockwidget):
+ """Removes the *dockwidget* from the main window layout and hides it.
+
+ Note that the *dockwidget* is *not* deleted.
+
+ :param QDockWidget dockwidget:
+ """
+ if dockwidget in self._dockWidgets:
+ self._dockWidgets.remove(dockwidget)
+ super(PlotWindow, self).removeDockWidget(dockwidget)
+
+ def _handleFirstDockWidgetShow(self, visible):
+ """Handle QDockWidget.visibilityChanged
+
+ It calls :meth:`addTabbedDockWidget` for the `sender` widget.
+ This allows to call `addTabbedDockWidget` lazily.
+
+ It disconnect itself from the signal once done.
+
+ :param bool visible:
+ """
+ if visible:
+ dockWidget = self.sender()
+ dockWidget.visibilityChanged.disconnect(
+ self._handleFirstDockWidgetShow)
+ self.addTabbedDockWidget(dockWidget)
+
+ def getColorBarWidget(self):
+ """Returns the embedded :class:`ColorBarWidget` widget.
+
+ :rtype: ColorBarWidget
+ """
+ return self._colorbar
+
+ # getters for dock widgets
+
+ def getLegendsDockWidget(self):
+ """DockWidget with Legend panel"""
+ if self._legendsDockWidget is None:
+ self._legendsDockWidget = LegendsDockWidget(plot=self)
+ self._legendsDockWidget.hide()
+ self._legendsDockWidget.visibilityChanged.connect(
+ self._handleFirstDockWidgetShow)
+ return self._legendsDockWidget
+
+ def getCurvesRoiDockWidget(self):
+ # Undocumented for a "soft deprecation" in version 0.7.0
+ # (still used internally for lazy loading)
+ if self._curvesROIDockWidget is None:
+ self._curvesROIDockWidget = CurvesROIDockWidget(
+ plot=self, name='Regions Of Interest')
+ self._curvesROIDockWidget.hide()
+ self._curvesROIDockWidget.visibilityChanged.connect(
+ self._handleFirstDockWidgetShow)
+ return self._curvesROIDockWidget
+
+ def getCurvesRoiWidget(self):
+ """Return the :class:`CurvesROIWidget`.
+
+ :class:`silx.gui.plot.CurvesROIWidget.CurvesROIWidget` offers a getter
+ and a setter for the ROI data:
+
+ - :meth:`CurvesROIWidget.getRois`
+ - :meth:`CurvesROIWidget.setRois`
+ """
+ return self.getCurvesRoiDockWidget().roiWidget
+
+ def getMaskToolsDockWidget(self):
+ """DockWidget with image mask panel (lazy-loaded)."""
+ if self._maskToolsDockWidget is None:
+ self._maskToolsDockWidget = MaskToolsDockWidget(
+ plot=self, name='Mask')
+ self._maskToolsDockWidget.hide()
+ self._maskToolsDockWidget.visibilityChanged.connect(
+ self._handleFirstDockWidgetShow)
+ return self._maskToolsDockWidget
+
+ def getStatsWidget(self):
+ """Returns a BasicStatsWidget connected to this plot
+
+ :rtype: BasicStatsWidget
+ """
+ if self._statsDockWidget is None:
+ self._statsDockWidget = qt.QDockWidget()
+ self._statsDockWidget.setWindowTitle("Curves stats")
+ self._statsDockWidget.layout().setContentsMargins(0, 0, 0, 0)
+ statsWidget = BasicStatsWidget(parent=self, plot=self)
+ self._statsDockWidget.setWidget(statsWidget)
+ statsWidget.sigVisibilityChanged.connect(
+ self.getStatsAction().setChecked)
+ self._statsDockWidget.hide()
+ self._statsDockWidget.visibilityChanged.connect(
+ self._handleFirstDockWidgetShow)
+ return self._statsDockWidget.widget()
+
+ # getters for actions
+ @property
+ @deprecated(replacement="getInteractiveModeToolBar().getZoomModeAction()",
+ since_version="0.8.0")
+ def zoomModeAction(self):
+ return self.getInteractiveModeToolBar().getZoomModeAction()
+
+ @property
+ @deprecated(replacement="getInteractiveModeToolBar().getPanModeAction()",
+ since_version="0.8.0")
+ def panModeAction(self):
+ return self.getInteractiveModeToolBar().getPanModeAction()
+
+ def getConsoleAction(self):
+ """QAction handling the IPython console activation.
+
+ By default, it is connected to a method that initializes the
+ console widget the first time the user clicks the "Console" menu
+ button. The following clicks, after initialization is done,
+ will toggle the visibility of the console widget.
+
+ :rtype: QAction
+ """
+ if self._consoleAction is None:
+ self._consoleAction = qt.QAction('Console', self)
+ self._consoleAction.setCheckable(True)
+ if IPythonDockWidget is not None:
+ self._consoleAction.toggled.connect(self._toggleConsoleVisibility)
+ else:
+ self._consoleAction.setEnabled(False)
+ return self._consoleAction
+
+ def getCrosshairAction(self):
+ """Action toggling crosshair cursor mode.
+
+ :rtype: actions.PlotAction
+ """
+ if self._crosshairAction is None:
+ self._crosshairAction = actions.control.CrosshairAction(self, color='red')
+ return self._crosshairAction
+
+ def getMaskAction(self):
+ """QAction toggling image mask dock widget
+
+ :rtype: QAction
+ """
+ return self.getMaskToolsDockWidget().toggleViewAction()
+
+ def getPanWithArrowKeysAction(self):
+ """Action toggling pan with arrow keys.
+
+ :rtype: actions.PlotAction
+ """
+ if self._panWithArrowKeysAction is None:
+ self._panWithArrowKeysAction = actions.control.PanWithArrowKeysAction(self)
+ return self._panWithArrowKeysAction
+
+ def getStatsAction(self):
+ if self._statsAction is None:
+ self._statsAction = qt.QAction('Curves stats', self)
+ self._statsAction.setCheckable(True)
+ self._statsAction.setChecked(self.getStatsWidget().parent().isVisible())
+ self._statsAction.toggled.connect(self._toggleStatsVisibility)
+ return self._statsAction
+
+ def getRoiAction(self):
+ """QAction toggling curve ROI dock widget
+
+ :rtype: QAction
+ """
+ return self.getCurvesRoiDockWidget().toggleViewAction()
+
+ def getResetZoomAction(self):
+ """Action resetting the zoom
+
+ :rtype: actions.PlotAction
+ """
+ return self.resetZoomAction
+
+ def getZoomInAction(self):
+ """Action to zoom in
+
+ :rtype: actions.PlotAction
+ """
+ return self.zoomInAction
+
+ def getZoomOutAction(self):
+ """Action to zoom out
+
+ :rtype: actions.PlotAction
+ """
+ return self.zoomOutAction
+
+ def getXAxisAutoScaleAction(self):
+ """Action to toggle the X axis autoscale on zoom reset
+
+ :rtype: actions.PlotAction
+ """
+ return self.xAxisAutoScaleAction
+
+ def getYAxisAutoScaleAction(self):
+ """Action to toggle the Y axis autoscale on zoom reset
+
+ :rtype: actions.PlotAction
+ """
+ return self.yAxisAutoScaleAction
+
+ def getXAxisLogarithmicAction(self):
+ """Action to toggle logarithmic X axis
+
+ :rtype: actions.PlotAction
+ """
+ return self.xAxisLogarithmicAction
+
+ def getYAxisLogarithmicAction(self):
+ """Action to toggle logarithmic Y axis
+
+ :rtype: actions.PlotAction
+ """
+ return self.yAxisLogarithmicAction
+
+ def getGridAction(self):
+ """Action to toggle the grid visibility in the plot
+
+ :rtype: actions.PlotAction
+ """
+ return self.gridAction
+
+ def getCurveStyleAction(self):
+ """Action to change curve line and markers styles
+
+ :rtype: actions.PlotAction
+ """
+ return self.curveStyleAction
+
+ def getColormapAction(self):
+ """Action open a colormap dialog to change active image
+ and default colormap.
+
+ :rtype: actions.PlotAction
+ """
+ return self.colormapAction
+
+ def getKeepDataAspectRatioButton(self):
+ """Button to toggle aspect ratio preservation
+
+ :rtype: PlotToolButtons.AspectToolButton
+ """
+ return self.keepDataAspectRatioButton
+
+ def getKeepDataAspectRatioAction(self):
+ """Action associated to keepDataAspectRatioButton.
+ Use this to change the visibility of keepDataAspectRatioButton in the
+ toolbar (See :meth:`QToolBar.addWidget` documentation).
+
+ :rtype: actions.PlotAction
+ """
+ return self.keepDataAspectRatioAction
+
+ def getYAxisInvertedButton(self):
+ """Button to switch the Y axis orientation
+
+ :rtype: PlotToolButtons.YAxisOriginToolButton
+ """
+ return self.yAxisInvertedButton
+
+ def getYAxisInvertedAction(self):
+ """Action associated to yAxisInvertedButton.
+ Use this to change the visibility yAxisInvertedButton in the toolbar.
+ (See :meth:`QToolBar.addWidget` documentation).
+
+ :rtype: actions.PlotAction
+ """
+ return self.yAxisInvertedAction
+
+ def getIntensityHistogramAction(self):
+ """Action toggling the histogram intensity Plot widget
+
+ :rtype: actions.PlotAction
+ """
+ return self._intensityHistoAction
+
+ def getCopyAction(self):
+ """Action to copy plot snapshot to clipboard
+
+ :rtype: actions.PlotAction
+ """
+ return self.getOutputToolBar().getCopyAction()
+
+ def getSaveAction(self):
+ """Action to save plot
+
+ :rtype: actions.PlotAction
+ """
+ return self.getOutputToolBar().getSaveAction()
+
+ def getPrintAction(self):
+ """Action to print plot
+
+ :rtype: actions.PlotAction
+ """
+ return self.getOutputToolBar().getPrintAction()
+
+ def getFitAction(self):
+ """Action to fit selected curve
+
+ :rtype: actions.PlotAction
+ """
+ return self.fitAction
+
+ def getMedianFilter1DAction(self):
+ """Action toggling the 1D median filter
+
+ :rtype: actions.PlotAction
+ """
+ return self._medianFilter1DAction
+
+ def getMedianFilter2DAction(self):
+ """Action toggling the 2D median filter
+
+ :rtype: actions.PlotAction
+ """
+ return self._medianFilter2DAction
+
+ def getColorBarAction(self):
+ """Action toggling the colorbar show/hide action
+
+ .. warning:: to show/hide the plot colorbar call directly the ColorBar
+ widget using getColorBarWidget()
+
+ :rtype: actions.PlotAction
+ """
+ return self.colorbarAction
+
+
+class Plot1D(PlotWindow):
+ """PlotWindow with tools specific for curves.
+
+ This widgets provides the plot API of :class:`.PlotWidget`.
+
+ :param parent: The parent of this widget
+ :param backend: The backend to use for the plot (default: matplotlib).
+ See :class:`.PlotWidget` for the list of supported backend.
+ :type backend: str or :class:`BackendBase.BackendBase`
+ """
+
+ def __init__(self, parent=None, backend=None):
+ super(Plot1D, self).__init__(parent=parent, backend=backend,
+ resetzoom=True, autoScale=True,
+ logScale=True, grid=True,
+ curveStyle=True, colormap=False,
+ aspectRatio=False, yInverted=False,
+ copy=True, save=True, print_=True,
+ control=True, position=True,
+ roi=True, mask=False, fit=True)
+ if parent is None:
+ self.setWindowTitle('Plot1D')
+ self.getXAxis().setLabel('X')
+ self.getYAxis().setLabel('Y')
+ action = self.getFitAction()
+ action.setXRangeUpdatedOnZoom(True)
+ action.setFittedItemUpdatedFromActiveCurve(True)
+
+
+class Plot2D(PlotWindow):
+ """PlotWindow with a toolbar specific for images.
+
+ This widgets provides the plot API of :~:`.PlotWidget`.
+
+ :param parent: The parent of this widget
+ :param backend: The backend to use for the plot (default: matplotlib).
+ See :class:`.PlotWidget` for the list of supported backend.
+ :type backend: str or :class:`BackendBase.BackendBase`
+ """
+
+ def __init__(self, parent=None, backend=None):
+ # List of information to display at the bottom of the plot
+ posInfo = [
+ ('X', lambda x, y: x),
+ ('Y', lambda x, y: y),
+ ('Data', WeakMethodProxy(self._getImageValue)),
+ ('Dims', WeakMethodProxy(self._getImageDims)),
+ ]
+
+ super(Plot2D, self).__init__(parent=parent, backend=backend,
+ resetzoom=True, autoScale=False,
+ logScale=False, grid=False,
+ curveStyle=False, colormap=True,
+ aspectRatio=True, yInverted=True,
+ copy=True, save=True, print_=True,
+ control=False, position=posInfo,
+ roi=False, mask=True)
+ if parent is None:
+ self.setWindowTitle('Plot2D')
+ self.getXAxis().setLabel('Columns')
+ self.getYAxis().setLabel('Rows')
+
+ if silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION == 'downward':
+ self.getYAxis().setInverted(True)
+
+ self.profile = ProfileToolBar(plot=self)
+ self.addToolBar(self.profile)
+
+ self.colorbarAction.setVisible(True)
+ self.getColorBarWidget().setVisible(True)
+
+ # Put colorbar action after colormap action
+ actions = self.toolBar().actions()
+ for action in actions:
+ if action is self.getColormapAction():
+ break
+
+ self.sigActiveImageChanged.connect(self.__activeImageChanged)
+
+ def __activeImageChanged(self, previous, legend):
+ """Handle change of active image
+
+ :param Union[str,None] previous: Legend of previous active image
+ :param Union[str,None] legend: Legend of current active image
+ """
+ if previous is not None:
+ item = self.getImage(previous)
+ if item is not None:
+ item.sigItemChanged.disconnect(self.__imageChanged)
+
+ if legend is not None:
+ item = self.getImage(legend)
+ item.sigItemChanged.connect(self.__imageChanged)
+
+ positionInfo = self.getPositionInfoWidget()
+ if positionInfo is not None:
+ positionInfo.updateInfo()
+
+ def __imageChanged(self, event):
+ """Handle update of active image item
+
+ :param event: Type of changed event
+ """
+ if event == items.ItemChangedType.DATA:
+ positionInfo = self.getPositionInfoWidget()
+ if positionInfo is not None:
+ positionInfo.updateInfo()
+
+ def _getImageValue(self, x, y):
+ """Get status bar value of top most image at position (x, y)
+
+ :param float x: X position in plot coordinates
+ :param float y: Y position in plot coordinates
+ :return: The value at that point or '-'
+ """
+ pickedMask = None
+ for picked in self.pickItems(
+ *self.dataToPixel(x, y, check=False),
+ lambda item: isinstance(item, items.ImageBase)):
+ if isinstance(picked.getItem(), items.MaskImageData):
+ if pickedMask is None: # Use top-most if many masks
+ pickedMask = picked
+ else:
+ image = picked.getItem()
+
+ indices = picked.getIndices(copy=False)
+ if indices is not None:
+ row, col = indices[0][0], indices[1][0]
+ value = image.getData(copy=False)[row, col]
+
+ if pickedMask is not None: # Check if masked
+ maskItem = pickedMask.getItem()
+ indices = pickedMask.getIndices()
+ row, col = indices[0][0], indices[1][0]
+ if maskItem.getData(copy=False)[row, col] != 0:
+ return value, "Masked"
+ return value
+
+ return '-' # No image picked
+
+ def _getImageDims(self, *args):
+ activeImage = self.getActiveImage()
+ if (activeImage is not None and
+ activeImage.getData(copy=False) is not None):
+ dims = activeImage.getData(copy=False).shape[1::-1]
+ return 'x'.join(str(dim) for dim in dims)
+ else:
+ return '-'
+
+ def getProfileToolbar(self):
+ """Profile tools attached to this plot
+
+ See :class:`silx.gui.plot.Profile.ProfileToolBar`
+ """
+ return self.profile
+
+ @deprecated(replacement="getProfilePlot", since_version="0.5.0")
+ def getProfileWindow(self):
+ return self.getProfilePlot()
+
+ def getProfilePlot(self):
+ """Return plot window used to display profile curve.
+
+ :return: :class:`Plot1D`
+ """
+ return self.profile.getProfilePlot()
diff --git a/src/silx/gui/plot/PrintPreviewToolButton.py b/src/silx/gui/plot/PrintPreviewToolButton.py
new file mode 100644
index 0000000..30967e4
--- /dev/null
+++ b/src/silx/gui/plot/PrintPreviewToolButton.py
@@ -0,0 +1,388 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+This modules provides tool buttons to send the content of a plot to a
+print preview page.
+The plot content can then be moved on the page and resized prior to printing.
+
+Classes
+-------
+
+- :class:`PrintPreviewToolButton`
+- :class:`SingletonPrintPreviewToolButton`
+
+Examples
+--------
+
+Simple example
+++++++++++++++
+
+.. code-block:: python
+
+ from silx.gui import qt
+ from silx.gui.plot import PlotWidget
+ from silx.gui.plot.PrintPreviewToolButton import PrintPreviewToolButton
+ import numpy
+
+ app = qt.QApplication([])
+
+ pw = PlotWidget()
+ toolbar = qt.QToolBar(pw)
+ toolbutton = PrintPreviewToolButton(parent=toolbar, plot=pw)
+ pw.addToolBar(toolbar)
+ toolbar.addWidget(toolbutton)
+ pw.show()
+
+ x = numpy.arange(1000)
+ y = x / numpy.sin(x)
+ pw.addCurve(x, y)
+
+ app.exec()
+
+Singleton example
++++++++++++++++++
+
+This example illustrates how to print the content of several different
+plots on the same page. The plots all instantiate a
+:class:`SingletonPrintPreviewToolButton`, which relies on a singleton widget
+(:class:`silx.gui.widgets.PrintPreview.SingletonPrintPreviewDialog`).
+
+.. image:: img/printPreviewMultiPlot.png
+
+.. code-block:: python
+
+ from silx.gui import qt
+ from silx.gui.plot import PlotWidget
+ from silx.gui.plot.PrintPreviewToolButton import SingletonPrintPreviewToolButton
+ import numpy
+
+ app = qt.QApplication([])
+
+ plot_widgets = []
+
+ for i in range(3):
+ pw = PlotWidget()
+ toolbar = qt.QToolBar(pw)
+ toolbutton = SingletonPrintPreviewToolButton(parent=toolbar,
+ plot=pw)
+ pw.addToolBar(toolbar)
+ toolbar.addWidget(toolbutton)
+ pw.show()
+ plot_widgets.append(pw)
+
+ x = numpy.arange(1000)
+
+ plot_widgets[0].addCurve(x, numpy.sin(x * 2 * numpy.pi / 1000))
+ plot_widgets[1].addCurve(x, numpy.cos(x * 2 * numpy.pi / 1000))
+ plot_widgets[2].addCurve(x, numpy.tan(x * 2 * numpy.pi / 1000))
+
+ app.exec()
+
+"""
+from __future__ import absolute_import
+
+import logging
+from io import StringIO
+
+from .. import qt
+from .. import icons
+from . import PlotWidget
+from ..widgets.PrintPreview import PrintPreviewDialog, SingletonPrintPreviewDialog
+from ..widgets.PrintGeometryDialog import PrintGeometryDialog
+from silx.utils.deprecation import deprecated
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "20/12/2018"
+
+_logger = logging.getLogger(__name__)
+# _logger.setLevel(logging.DEBUG)
+
+
+class PrintPreviewToolButton(qt.QToolButton):
+ """QToolButton to open a :class:`PrintPreviewDialog` (if not already open)
+ and add the current plot to its page to be printed.
+
+ :param parent: See :class:`QAction`
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ """
+ def __init__(self, parent=None, plot=None):
+ super(PrintPreviewToolButton, self).__init__(parent)
+
+ if not isinstance(plot, PlotWidget):
+ raise TypeError("plot parameter must be a PlotWidget")
+ self._plot = plot
+
+ self.setIcon(icons.getQIcon('document-print'))
+
+ printGeomAction = qt.QAction("Print geometry", self)
+ printGeomAction.setToolTip("Define a print geometry prior to sending "
+ "the plot to the print preview dialog")
+ printGeomAction.setIcon(icons.getQIcon('shape-rectangle'))
+ printGeomAction.triggered.connect(self._setPrintConfiguration)
+
+ printPreviewAction = qt.QAction("Print preview", self)
+ printPreviewAction.setToolTip("Send plot to the print preview dialog")
+ printPreviewAction.setIcon(icons.getQIcon('document-print'))
+ printPreviewAction.triggered.connect(self._plotToPrintPreview)
+
+ menu = qt.QMenu(self)
+ menu.addAction(printGeomAction)
+ menu.addAction(printPreviewAction)
+ self.setMenu(menu)
+ self.setPopupMode(qt.QToolButton.InstantPopup)
+
+ self._printPreviewDialog = None
+ self._printConfigurationDialog = None
+
+ self._printGeometry = {"xOffset": 0.1,
+ "yOffset": 0.1,
+ "width": 0.9,
+ "height": 0.9,
+ "units": "page",
+ "keepAspectRatio": True}
+
+ @property
+ def printPreviewDialog(self):
+ """Lazy loaded :class:`PrintPreviewDialog`"""
+ # if changes are made here, don't forget making them in
+ # SingletonPrintPreviewToolButton.printPreviewDialog as well
+ if self._printPreviewDialog is None:
+ self._printPreviewDialog = PrintPreviewDialog(self.parent())
+ return self._printPreviewDialog
+
+ def getTitle(self):
+ """Implement this method to fetch the title in the plot.
+
+ :return: Title to be printed above the plot, or None (no title added)
+ :rtype: str or None
+ """
+ return None
+
+ def getCommentAndPosition(self):
+ """Implement this method to fetch the legend to be printed below the
+ figure and its position.
+
+ :return: Legend to be printed below the figure and its position:
+ "CENTER", "LEFT" or "RIGHT"
+ :rtype: (str, str) or (None, None)
+ """
+ return None, None
+
+ @property
+ @deprecated(since_version="0.10",
+ replacement="getPlot()")
+ def plot(self):
+ return self._plot
+
+ def getPlot(self):
+ """Return the :class:`.PlotWidget` associated with this tool button.
+
+ :rtype: :class:`.PlotWidget`
+ """
+ return self._plot
+
+ def _plotToPrintPreview(self):
+ """Grab the plot widget and send it to the print preview dialog.
+ Make sure the print preview dialog is shown and raised."""
+ if not self.printPreviewDialog.ensurePrinterIsSet():
+ return
+
+ comment, commentPosition = self.getCommentAndPosition()
+
+ if qt.HAS_SVG:
+ svgRenderer, viewBox = self._getSvgRendererAndViewbox()
+ self.printPreviewDialog.addSvgItem(svgRenderer,
+ title=self.getTitle(),
+ comment=comment,
+ commentPosition=commentPosition,
+ viewBox=viewBox,
+ keepRatio=self._printGeometry["keepAspectRatio"])
+ else:
+ _logger.warning("Missing QtSvg library, using a raster image")
+ pixmap = self._plot.centralWidget().grab()
+ self.printPreviewDialog.addPixmap(pixmap,
+ title=self.getTitle(),
+ comment=comment,
+ commentPosition=commentPosition)
+ self.printPreviewDialog.show()
+ self.printPreviewDialog.raise_()
+
+ def _getSvgRendererAndViewbox(self):
+ """Return a SVG renderer displaying the plot and its viewbox
+ (interactively specified by the user the first time this is called).
+
+ The size of the renderer is adjusted to the printer configuration
+ and to the geometry configuration (width, height, ratio) specified
+ by the user."""
+ imgData = StringIO()
+ assert self._plot.saveGraph(imgData, fileFormat="svg"), \
+ "Unable to save graph"
+ imgData.flush()
+ imgData.seek(0)
+ svgData = imgData.read()
+
+ svgRenderer = qt.QSvgRenderer()
+
+ viewbox = self._getViewBox()
+
+ svgRenderer.setViewBox(viewbox)
+
+ xml_stream = qt.QXmlStreamReader(svgData.encode(errors="replace"))
+
+ # This is for PyMca compatibility, to share a print preview with PyMca plots
+ svgRenderer._viewBox = viewbox
+ svgRenderer._svgRawData = svgData.encode(errors="replace")
+ svgRenderer._svgRendererData = xml_stream
+
+ if not svgRenderer.load(xml_stream):
+ raise RuntimeError("Cannot interpret svg data")
+
+ return svgRenderer, viewbox
+
+ def _getViewBox(self):
+ """
+ """
+ printer = self.printPreviewDialog.printer
+ dpix = printer.logicalDpiX()
+ dpiy = printer.logicalDpiY()
+ availableWidth = printer.width()
+ availableHeight = printer.height()
+
+ config = self._printGeometry
+ width = config['width']
+ height = config['height']
+ xOffset = config['xOffset']
+ yOffset = config['yOffset']
+ units = config['units']
+ keepAspectRatio = config['keepAspectRatio']
+ aspectRatio = self._getPlotAspectRatio()
+
+ # convert the offsets to dots
+ if units.lower() in ['inch', 'inches']:
+ xOffset = xOffset * dpix
+ yOffset = yOffset * dpiy
+ if width is not None:
+ width = width * dpix
+ if height is not None:
+ height = height * dpiy
+ elif units.lower() in ['cm', 'centimeters']:
+ xOffset = (xOffset / 2.54) * dpix
+ yOffset = (yOffset / 2.54) * dpiy
+ if width is not None:
+ width = (width / 2.54) * dpix
+ if height is not None:
+ height = (height / 2.54) * dpiy
+ else:
+ # page units
+ xOffset = availableWidth * xOffset
+ yOffset = availableHeight * yOffset
+ if width is not None:
+ width = availableWidth * width
+ if height is not None:
+ height = availableHeight * height
+
+ availableWidth -= xOffset
+ availableHeight -= yOffset
+
+ if width is not None:
+ if (availableWidth + 0.1) < width:
+ txt = "Available width %f is less than requested width %f" % \
+ (availableWidth, width)
+ raise ValueError(txt)
+ if height is not None:
+ if (availableHeight + 0.1) < height:
+ txt = "Available height %f is less than requested height %f" % \
+ (availableHeight, height)
+ raise ValueError(txt)
+
+ if keepAspectRatio:
+ bodyWidth = width or availableWidth
+ bodyHeight = bodyWidth * aspectRatio
+
+ if bodyHeight > availableHeight:
+ bodyHeight = availableHeight
+ bodyWidth = bodyHeight / aspectRatio
+
+ else:
+ bodyWidth = width or availableWidth
+ bodyHeight = height or availableHeight
+
+ return qt.QRectF(xOffset,
+ yOffset,
+ bodyWidth,
+ bodyHeight)
+
+ def _setPrintConfiguration(self):
+ """Open a dialog to prompt the user to adjust print
+ geometry parameters."""
+ self.printPreviewDialog.ensurePrinterIsSet()
+ if self._printConfigurationDialog is None:
+ self._printConfigurationDialog = PrintGeometryDialog(self.parent())
+
+ self._printConfigurationDialog.setPrintGeometry(self._printGeometry)
+ if self._printConfigurationDialog.exec():
+ self._printGeometry = self._printConfigurationDialog.getPrintGeometry()
+
+ def _getPlotAspectRatio(self):
+ widget = self._plot.centralWidget()
+ graphWidth = float(widget.width())
+ graphHeight = float(widget.height())
+ return graphHeight / graphWidth
+
+
+class SingletonPrintPreviewToolButton(PrintPreviewToolButton):
+ """This class is similar to its parent class :class:`PrintPreviewToolButton`
+ but it uses a singleton print preview widget.
+
+ This allows for several plots to send their content to the
+ same print page, and for users to arrange them."""
+ def __init__(self, parent=None, plot=None):
+ PrintPreviewToolButton.__init__(self, parent, plot)
+
+ @property
+ def printPreviewDialog(self):
+ if self._printPreviewDialog is None:
+ self._printPreviewDialog = SingletonPrintPreviewDialog(self.parent())
+ return self._printPreviewDialog
+
+
+if __name__ == '__main__':
+ import numpy
+ app = qt.QApplication([])
+
+ pw = PlotWidget()
+ toolbar = qt.QToolBar(pw)
+ toolbutton = PrintPreviewToolButton(parent=toolbar,
+ plot=pw)
+ pw.addToolBar(toolbar)
+ toolbar.addWidget(toolbutton)
+ pw.show()
+
+ x = numpy.arange(1000)
+ y = x / numpy.sin(x)
+ pw.addCurve(x, y)
+
+ app.exec()
diff --git a/src/silx/gui/plot/Profile.py b/src/silx/gui/plot/Profile.py
new file mode 100644
index 0000000..7565155
--- /dev/null
+++ b/src/silx/gui/plot/Profile.py
@@ -0,0 +1,352 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Utility functions, toolbars and actions to create profile on images
+and stacks of images"""
+
+
+__authors__ = ["V.A. Sole", "T. Vincent", "P. Knobel", "H. Payno"]
+__license__ = "MIT"
+__date__ = "12/04/2019"
+
+
+import weakref
+
+from .. import qt
+from . import actions
+from .tools.profile import core
+from .tools.profile import manager
+from .tools.profile import rois
+from silx.gui.widgets.MultiModeAction import MultiModeAction
+
+from silx.utils.deprecation import deprecated
+from silx.utils.deprecation import deprecated_warning
+from .tools import roi as roi_mdl
+from silx.gui.plot import items
+
+
+@deprecated(replacement="silx.gui.plot.tools.profile.createProfile", since_version="0.13.0")
+def createProfile(roiInfo, currentData, origin, scale, lineWidth, method):
+ return core.createProfile(roiInfo, currentData, origin,
+ scale, lineWidth, method)
+
+
+class _CustomProfileManager(manager.ProfileManager):
+ """This custom profile manager uses a single predefined profile window
+ if it is specified. Else the behavior is the same as the default
+ ProfileManager """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.__profileWindow = None
+ self.__specializedProfileWindows = {}
+
+ def setSpecializedProfileWindow(self, roiClass, profileWindow):
+ """Set a profile window for a given class or ROI.
+
+ Setting profileWindow to None removes the roiClass from the list.
+
+ :param roiClass:
+ :param profileWindow:
+ """
+ if profileWindow is None:
+ self.__specializedProfileWindows.pop(roiClass, None)
+ else:
+ self.__specializedProfileWindows[roiClass] = profileWindow
+
+ def setProfileWindow(self, profileWindow):
+ self.__profileWindow = profileWindow
+
+ def createProfileWindow(self, plot, roi):
+ for roiClass, specializedProfileWindow in self.__specializedProfileWindows.items():
+ if isinstance(roi, roiClass):
+ return specializedProfileWindow
+
+ if self.__profileWindow is not None:
+ return self.__profileWindow
+ else:
+ return super(_CustomProfileManager, self).createProfileWindow(plot, roi)
+
+ def clearProfileWindow(self, profileWindow):
+ for specializedProfileWindow in self.__specializedProfileWindows.values():
+ if profileWindow is specializedProfileWindow:
+ profileWindow.setProfile(None)
+ return
+
+ if self.__profileWindow is not None:
+ self.__profileWindow.setProfile(None)
+ else:
+ return super(_CustomProfileManager, self).clearProfileWindow(profileWindow)
+
+
+class ProfileToolBar(qt.QToolBar):
+ """QToolBar providing profile tools operating on a :class:`PlotWindow`.
+
+ Attributes:
+
+ - plot: Associated :class:`PlotWindow` on which the profile line is drawn.
+ - actionGroup: :class:`QActionGroup` of available actions.
+
+ To run the following sample code, a QApplication must be initialized.
+ First, create a PlotWindow and add a :class:`ProfileToolBar`.
+
+ >>> from silx.gui.plot import PlotWindow
+ >>> from silx.gui.plot.Profile import ProfileToolBar
+
+ >>> plot = PlotWindow() # Create a PlotWindow
+ >>> toolBar = ProfileToolBar(plot=plot) # Create a profile toolbar
+ >>> plot.addToolBar(toolBar) # Add it to plot
+ >>> plot.show() # To display the PlotWindow with the profile toolbar
+
+ :param plot: :class:`PlotWindow` instance on which to operate.
+ :param profileWindow: Plot widget instance where to
+ display the profile curve or None to create one.
+ :param str title: See :class:`QToolBar`.
+ :param parent: See :class:`QToolBar`.
+ """
+
+ def __init__(self, parent=None, plot=None, profileWindow=None,
+ title=None):
+ super(ProfileToolBar, self).__init__(title, parent)
+ assert plot is not None
+
+ if title is not None:
+ deprecated_warning("Attribute",
+ name="title",
+ reason="removed",
+ since_version="0.13.0",
+ only_once=True,
+ skip_backtrace_count=1)
+
+ self._plotRef = weakref.ref(plot)
+
+ # If a profileWindow is defined,
+ # It will be used to display all the profiles
+ self._manager = self.createProfileManager(self, plot)
+ self._manager.setProfileWindow(profileWindow)
+ self._manager.setDefaultColorFromCursorColor(True)
+ self._manager.setItemType(image=True)
+ self._manager.setActiveItemTracking(True)
+
+ # Actions
+ self._browseAction = actions.mode.ZoomModeAction(plot, parent=self)
+ self._browseAction.setVisible(False)
+ self.freeLineAction = None
+ self._createProfileActions()
+ self._editor = self._manager.createEditorAction(self)
+
+ # ActionGroup
+ self.actionGroup = qt.QActionGroup(self)
+ self.actionGroup.addAction(self._browseAction)
+ self.actionGroup.addAction(self.hLineAction)
+ self.actionGroup.addAction(self.vLineAction)
+ self.actionGroup.addAction(self.lineAction)
+ self.actionGroup.addAction(self._editor)
+
+ modes = MultiModeAction(self)
+ modes.addAction(self.hLineAction)
+ modes.addAction(self.vLineAction)
+ modes.addAction(self.lineAction)
+ if self.freeLineAction is not None:
+ modes.addAction(self.freeLineAction)
+ modes.addAction(self.crossAction)
+ self.__multiAction = modes
+
+ # Add actions to ToolBar
+ self.addAction(self._browseAction)
+ self.addAction(modes)
+ self.addAction(self._editor)
+ self.addAction(self.clearAction)
+
+ plot.sigActiveImageChanged.connect(self._activeImageChanged)
+ self._activeImageChanged()
+
+ def createProfileManager(self, parent, plot):
+ return _CustomProfileManager(parent, plot)
+
+ def _createProfileActions(self):
+ self.hLineAction = self._manager.createProfileAction(rois.ProfileImageHorizontalLineROI, self)
+ self.vLineAction = self._manager.createProfileAction(rois.ProfileImageVerticalLineROI, self)
+ self.lineAction = self._manager.createProfileAction(rois.ProfileImageLineROI, self)
+ self.freeLineAction = self._manager.createProfileAction(rois.ProfileImageDirectedLineROI, self)
+ self.crossAction = self._manager.createProfileAction(rois.ProfileImageCrossROI, self)
+ self.clearAction = self._manager.createClearAction(self)
+
+ def getPlotWidget(self):
+ """The :class:`.PlotWidget` associated to the toolbar."""
+ return self._plotRef()
+
+ @property
+ @deprecated(since_version="0.13.0", replacement="getPlotWidget()")
+ def plot(self):
+ return self.getPlotWidget()
+
+ def _setRoiActionEnabled(self, itemKind, enabled):
+ for action in self.__multiAction.getMenu().actions():
+ if not isinstance(action, roi_mdl.CreateRoiModeAction):
+ continue
+ roiClass = action.getRoiClass()
+ if issubclass(itemKind, roiClass.ITEM_KIND):
+ action.setEnabled(enabled)
+
+ def _activeImageChanged(self, previous=None, legend=None):
+ """Handle active image change to toggle actions"""
+ if legend is None:
+ self._setRoiActionEnabled(items.ImageStack, False)
+ self._setRoiActionEnabled(items.ImageBase, False)
+ else:
+ plot = self.getPlotWidget()
+ image = plot.getActiveImage()
+ # Disable for empty image
+ enabled = image.getData(copy=False).size > 0
+ self._setRoiActionEnabled(type(image), enabled)
+
+ @property
+ @deprecated(since_version="0.6.0")
+ def browseAction(self):
+ return self._browseAction
+
+ @property
+ @deprecated(replacement="getProfilePlot", since_version="0.5.0")
+ def profileWindow(self):
+ return self.getProfilePlot()
+
+ def getProfileManager(self):
+ """Return the manager of the profiles.
+
+ :rtype: ProfileManager
+ """
+ return self._manager
+
+ @deprecated(since_version="0.13.0")
+ def getProfilePlot(self):
+ """Return plot widget in which the profile curve or the
+ profile image is plotted.
+ """
+ window = self.getProfileMainWindow()
+ if window is None:
+ return None
+ return window.getCurrentPlotWidget()
+
+ @deprecated(replacement="getProfileManager().getCurrentRoi().getProfileWindow()", since_version="0.13.0")
+ def getProfileMainWindow(self):
+ """Return window containing the profile curve widget.
+
+ This can return None if no profile was computed.
+ """
+ roi = self._manager.getCurrentRoi()
+ if roi is None:
+ return None
+ return roi.getProfileWindow()
+
+ @property
+ @deprecated(since_version="0.13.0")
+ def overlayColor(self):
+ """This method does nothing anymore. But could be implemented if needed.
+
+ It was used to set color to use for the ROI.
+
+ If set to None (the default), the overlay color is adapted to the
+ active image colormap and changes if the active image colormap changes.
+ """
+ pass
+
+ @overlayColor.setter
+ @deprecated(since_version="0.13.0")
+ def overlayColor(self, color):
+ """This method does nothing anymore. But could be implemented if needed.
+ """
+ pass
+
+ def clearProfile(self):
+ """Remove profile curve and profile area."""
+ self._manager.clearProfile()
+
+ @deprecated(since_version="0.13.0")
+ def updateProfile(self):
+ """This method does nothing anymore. But could be implemented if needed.
+
+ It was used to update the displayed profile and profile ROI.
+
+ This uses the current active image of the plot and the current ROI.
+ """
+ pass
+
+ @deprecated(replacement="clearProfile()", since_version="0.13.0")
+ def hideProfileWindow(self):
+ """Hide profile window.
+ """
+ self.clearProfile()
+
+ @deprecated(since_version="0.13.0")
+ def setProfileMethod(self, method):
+ assert method in ('sum', 'mean')
+ roi = self._manager.getCurrentRoi()
+ if roi is None:
+ raise RuntimeError("No profile ROI selected")
+ roi.setProfileMethod(method)
+
+ @deprecated(since_version="0.13.0")
+ def getProfileMethod(self):
+ roi = self._manager.getCurrentRoi()
+ if roi is None:
+ raise RuntimeError("No profile ROI selected")
+ return roi.getProfileMethod()
+
+ @deprecated(since_version="0.13.0")
+ def getProfileOptionToolAction(self):
+ return self._editor
+
+
+class Profile3DToolBar(ProfileToolBar):
+ def __init__(self, parent=None, stackview=None,
+ title=None):
+ """QToolBar providing profile tools for an image or a stack of images.
+
+ :param parent: the parent QWidget
+ :param stackview: :class:`StackView` instance on which to operate.
+ :param str title: See :class:`QToolBar`.
+ :param parent: See :class:`QToolBar`.
+ """
+ # TODO: add param profileWindow (specify the plot used for profiles)
+ super(Profile3DToolBar, self).__init__(parent=parent,
+ plot=stackview.getPlotWidget())
+
+ if title is not None:
+ deprecated_warning("Attribute",
+ name="title",
+ reason="removed",
+ since_version="0.13.0",
+ only_once=True,
+ skip_backtrace_count=1)
+
+ self.stackView = stackview
+ """:class:`StackView` instance"""
+
+ def _createProfileActions(self):
+ self.hLineAction = self._manager.createProfileAction(rois.ProfileImageStackHorizontalLineROI, self)
+ self.vLineAction = self._manager.createProfileAction(rois.ProfileImageStackVerticalLineROI, self)
+ self.lineAction = self._manager.createProfileAction(rois.ProfileImageStackLineROI, self)
+ self.crossAction = self._manager.createProfileAction(rois.ProfileImageStackCrossROI, self)
+ self.clearAction = self._manager.createClearAction(self)
diff --git a/src/silx/gui/plot/ProfileMainWindow.py b/src/silx/gui/plot/ProfileMainWindow.py
new file mode 100644
index 0000000..ce56cfd
--- /dev/null
+++ b/src/silx/gui/plot/ProfileMainWindow.py
@@ -0,0 +1,110 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module contains a QMainWindow class used to display profile plots.
+"""
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "21/02/2017"
+
+import silx.utils.deprecation
+from silx.gui import qt
+from .tools.profile.manager import ProfileWindow
+
+silx.utils.deprecation.deprecated_warning("Module",
+ name="silx.gui.plot.ProfileMainWindow",
+ reason="moved",
+ replacement="silx.gui.plot.tools.profile.manager.ProfileWindow",
+ since_version="0.13.0",
+ only_once=True,
+ skip_backtrace_count=1)
+
+class ProfileMainWindow(ProfileWindow):
+ """QMainWindow providing 2 plot widgets specialized in
+ 1D and 2D plotting, with different toolbars.
+
+ Only one of the plots is visible at any given time.
+
+ :param qt.QWidget parent: The parent of this widget or None (default).
+ :param Union[str,Class] backend: The backend to use, in:
+ 'matplotlib' (default), 'mpl', 'opengl', 'gl', 'none'
+ or a :class:`BackendBase.BackendBase` class
+ """
+
+ sigProfileDimensionsChanged = qt.Signal(int)
+ """This signal is emitted when :meth:`setProfileDimensions` is called.
+ It carries the number of dimensions for the profile data (1 or 2).
+ It can be used to be notified that the profile plot widget has changed.
+
+ Note: This signal should be removed.
+ """
+
+ sigProfileMethodChanged = qt.Signal(str)
+ """Emitted when the method to compute the profile changed (for now can be
+ sum or mean)
+
+ Note: This signal should be removed.
+ """
+
+ def __init__(self, parent=None, backend=None):
+ ProfileWindow.__init__(self, parent=parent, backend=backend)
+ # by default, profile is assumed to be a 1D curve
+ self._profileType = None
+
+ def setProfileType(self, profileType):
+ """Set which profile plot widget (1D or 2D) is to be used
+
+ Note: This method should be removed.
+
+ :param str profileType: Type of profile data,
+ "1D" for a curve or "2D" for an image
+ """
+ self._profileType = profileType
+ if self._profileType == "1D":
+ self._showPlot1D()
+ elif self._profileType == "2D":
+ self._showPlot2D()
+ else:
+ raise ValueError("Profile type must be '1D' or '2D'")
+ self.sigProfileDimensionsChanged.emit(profileType)
+
+ def getPlot(self):
+ """Return the profile plot widget which is currently in use.
+ This can be the 2D profile plot or the 1D profile plot.
+
+ Note: This method should be removed.
+ """
+ return self.getCurrentPlotWidget()
+
+ def setProfileMethod(self, method):
+ """
+ Note: This method should be removed.
+
+ :param str method: method to manage the 'width' in the profile
+ (computing mean or sum).
+ """
+ assert method in ('sum', 'mean')
+ self._method = method
+ self.sigProfileMethodChanged.emit(self._method)
diff --git a/src/silx/gui/plot/ROIStatsWidget.py b/src/silx/gui/plot/ROIStatsWidget.py
new file mode 100644
index 0000000..32a1395
--- /dev/null
+++ b/src/silx/gui/plot/ROIStatsWidget.py
@@ -0,0 +1,780 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides widget for displaying statistics relative to a
+Region of interest and an item
+"""
+
+
+__authors__ = ["H. Payno"]
+__license__ = "MIT"
+__date__ = "22/07/2019"
+
+
+from contextlib import contextmanager
+from silx.gui import qt
+from silx.gui import icons
+from silx.gui.plot.StatsWidget import _StatsWidgetBase, StatsTable, _Container
+from silx.gui.plot.StatsWidget import UpdateModeWidget, UpdateMode
+from silx.gui.widgets.TableWidget import TableWidget
+from silx.gui.plot.items.roi import RegionOfInterest
+from silx.gui.plot import items as plotitems
+from silx.gui.plot.items.core import ItemChangedType
+from silx.gui.plot3d import items as plot3ditems
+from silx.gui.plot.CurvesROIWidget import ROI
+from silx.gui.plot import stats as statsmdl
+from collections import OrderedDict
+from silx.utils.proxy import docstring
+import silx.gui.plot.items.marker
+import silx.gui.plot.items.shape
+import functools
+import logging
+
+_logger = logging.getLogger(__name__)
+
+
+class _GetROIItemCoupleDialog(qt.QDialog):
+ """
+ Dialog used to know which plot item and which roi he wants
+ """
+ _COMPATIBLE_KINDS = ('curve', 'image', 'scatter', 'histogram')
+
+ def __init__(self, parent=None, plot=None, rois=None):
+ qt.QDialog.__init__(self, parent=parent)
+ assert plot is not None
+ assert rois is not None
+ self._plot = plot
+ self._rois = rois
+
+ self.setLayout(qt.QVBoxLayout())
+
+ # define the selection widget
+ self._selection_widget = qt.QWidget()
+ self._selection_widget.setLayout(qt.QHBoxLayout())
+ self._kindCB = qt.QComboBox(parent=self)
+ self._selection_widget.layout().addWidget(self._kindCB)
+ self._itemCB = qt.QComboBox(parent=self)
+ self._selection_widget.layout().addWidget(self._itemCB)
+ self._roiCB = qt.QComboBox(parent=self)
+ self._selection_widget.layout().addWidget(self._roiCB)
+ self.layout().addWidget(self._selection_widget)
+
+ # define modal buttons
+ types = qt.QDialogButtonBox.Ok | qt.QDialogButtonBox.Cancel
+ self._buttonsModal = qt.QDialogButtonBox(parent=self)
+ self._buttonsModal.setStandardButtons(types)
+ self.layout().addWidget(self._buttonsModal)
+ self._buttonsModal.accepted.connect(self.accept)
+ self._buttonsModal.rejected.connect(self.reject)
+
+ # connect signal / slot
+ self._kindCB.currentIndexChanged.connect(self._updateValidItemAndRoi)
+
+ def _getCompatibleRois(self, kind):
+ """Return compatible rois for the given item kind"""
+ def is_compatible(roi, kind):
+ if isinstance(roi, RegionOfInterest):
+ return kind in ('image', 'scatter')
+ elif isinstance(roi, ROI):
+ return kind in ('curve', 'histogram')
+ else:
+ raise ValueError('kind not managed')
+ return list(filter(lambda x: is_compatible(x, kind), self._rois))
+
+ def exec(self):
+ self._kindCB.clear()
+ self._itemCB.clear()
+ # filter kind without any items
+ self._valid_kinds = {}
+ # key is item type, value kinds
+ self._valid_rois = {}
+ # key is item type, value rois
+ self._kind_name_to_roi = {}
+ # key is (kind, roi name) value is roi
+ self._kind_name_to_item = {}
+ # key is (kind, legend name) value is item
+ for kind in _GetROIItemCoupleDialog._COMPATIBLE_KINDS:
+ def getItems(kind):
+ output = []
+ for item in self._plot.getItems():
+ type_ = self._plot._itemKind(item)
+ if type_ in kind and item.isVisible():
+ output.append(item)
+ return output
+
+ items = getItems(kind=kind)
+ rois = self._getCompatibleRois(kind=kind)
+ if len(items) > 0 and len(rois) > 0:
+ self._valid_kinds[kind] = items
+ self._valid_rois[kind] = rois
+ for roi in rois:
+ name = roi.getName()
+ self._kind_name_to_roi[(kind, name)] = roi
+ for item in items:
+ self._kind_name_to_item[(kind, item.getLegend())] = item
+
+ # filter roi according to kinds
+ if len(self._valid_kinds) == 0:
+ _logger.warning('no couple item/roi detected for displaying stats')
+ return self.reject()
+
+ for kind in self._valid_kinds:
+ self._kindCB.addItem(kind)
+ self._updateValidItemAndRoi()
+
+ return qt.QDialog.exec(self)
+
+ def exec_(self): # Qt5 compatibility
+ return self.exec()
+
+ def _updateValidItemAndRoi(self, *args, **kwargs):
+ self._itemCB.clear()
+ self._roiCB.clear()
+ kind = self._kindCB.currentText()
+ for roi in self._valid_rois[kind]:
+ self._roiCB.addItem(roi.getName())
+ for item in self._valid_kinds[kind]:
+ self._itemCB.addItem(item.getLegend())
+
+ def getROI(self):
+ kind = self._kindCB.currentText()
+ roi_name = self._roiCB.currentText()
+ return self._kind_name_to_roi[(kind, roi_name)]
+
+ def getItem(self):
+ kind = self._kindCB.currentText()
+ item_name = self._itemCB.currentText()
+ return self._kind_name_to_item[(kind, item_name)]
+
+
+class ROIStatsItemHelper(object):
+ """Item utils to associate a plot item and a roi
+
+ Display on one row statistics regarding the couple
+ (Item (plot item) / roi).
+
+ :param Item plot_item: item for which we want statistics
+ :param Union[ROI,RegionOfInterest]: region of interest to use for
+ statistics.
+ """
+ def __init__(self, plot_item, roi):
+ self._plot_item = plot_item
+ self._roi = roi
+
+ @property
+ def roi(self):
+ """roi"""
+ return self._roi
+
+ def roi_name(self):
+ if isinstance(self._roi, ROI):
+ return self._roi.getName()
+ elif isinstance(self._roi, RegionOfInterest):
+ return self._roi.getName()
+ else:
+ raise TypeError('Unmanaged roi type')
+
+ @property
+ def roi_kind(self):
+ """roi class"""
+ return self._roi.__class__
+
+ # TODO: should call a util function from the wrapper ?
+ def item_kind(self):
+ """item kind"""
+ if isinstance(self._plot_item, plotitems.Curve):
+ return 'curve'
+ elif isinstance(self._plot_item, plotitems.ImageData):
+ return 'image'
+ elif isinstance(self._plot_item, plotitems.Scatter):
+ return 'scatter'
+ elif isinstance(self._plot_item, plotitems.Histogram):
+ return 'histogram'
+ elif isinstance(self._plot_item, (plot3ditems.ImageData,
+ plot3ditems.ScalarField3D)):
+ return 'image'
+ elif isinstance(self._plot_item, (plot3ditems.Scatter2D,
+ plot3ditems.Scatter3D)):
+ return 'scatter'
+
+ @property
+ def item_legend(self):
+ """legend of the plot Item"""
+ return self._plot_item.getLegend()
+
+ def id_key(self):
+ """unique key to represent the couple (item, roi)"""
+ return (self.item_kind(), self.item_legend, self.roi_kind,
+ self.roi_name())
+
+
+class _StatsROITable(_StatsWidgetBase, TableWidget):
+ """
+ Table sued to display some statistics regarding a couple (item/roi)
+ """
+ _LEGEND_HEADER_DATA = 'legend'
+
+ _KIND_HEADER_DATA = 'kind'
+
+ _ROI_HEADER_DATA = 'roi'
+
+ sigUpdateModeChanged = qt.Signal(object)
+ """Signal emitted when the update mode changed"""
+
+ def __init__(self, parent, plot):
+ TableWidget.__init__(self, parent)
+ _StatsWidgetBase.__init__(self, statsOnVisibleData=False,
+ displayOnlyActItem=False)
+ self.__region_edition_callback = {}
+ """We need to keep trace of the roi signals connection because
+ the roi emits the sigChanged during roi edition"""
+ self._items = {}
+ self.setRowCount(0)
+ self.setColumnCount(3)
+
+ # Init headers
+ headerItem = qt.QTableWidgetItem(self._LEGEND_HEADER_DATA.title())
+ headerItem.setData(qt.Qt.UserRole, self._LEGEND_HEADER_DATA)
+ self.setHorizontalHeaderItem(0, headerItem)
+ headerItem = qt.QTableWidgetItem(self._KIND_HEADER_DATA.title())
+ headerItem.setData(qt.Qt.UserRole, self._KIND_HEADER_DATA)
+ self.setHorizontalHeaderItem(1, headerItem)
+ headerItem = qt.QTableWidgetItem(self._ROI_HEADER_DATA.title())
+ headerItem.setData(qt.Qt.UserRole, self._ROI_HEADER_DATA)
+ self.setHorizontalHeaderItem(2, headerItem)
+
+ self.setSortingEnabled(True)
+ self.setPlot(plot)
+
+ self.__plotItemToItems = {}
+ """Key is plotItem, values is list of __RoiStatsItemWidget"""
+ self.__roiToItems = {}
+ """Key is roi, values is list of __RoiStatsItemWidget"""
+ self.__roisKeyToRoi = {}
+
+ def add(self, item):
+ assert isinstance(item, ROIStatsItemHelper)
+ if item.id_key() in self._items:
+ _logger.warning("Item %s is already present", item.id_key())
+ return None
+ self._items[item.id_key()] = item
+ self._addItem(item)
+ return item
+
+ def _addItem(self, item):
+ """
+ Add a _RoiStatsItemWidget item to the table.
+
+ :param item:
+ :return: True if successfully added.
+ """
+ if not isinstance(item, ROIStatsItemHelper):
+ # skipped because also receive all new plot item (Marker...) that
+ # we don't want to manage in this case.
+ return
+ # plotItem = item.getItem()
+ # roi = item.getROI()
+ kind = item.item_kind()
+ if kind not in statsmdl.BASIC_COMPATIBLE_KINDS:
+ _logger.info("Item has not a supported type: %s", item)
+ return False
+
+ # register the roi and the kind
+ self._registerPlotItem(item)
+ self._registerROI(item)
+
+ # Prepare table items
+ tableItems = [
+ qt.QTableWidgetItem(), # Legend
+ qt.QTableWidgetItem(), # Kind
+ qt.QTableWidgetItem()] # roi
+
+ for column in range(3, self.columnCount()):
+ header = self.horizontalHeaderItem(column)
+ name = header.data(qt.Qt.UserRole)
+
+ formatter = self._statsHandler.formatters[name]
+ if formatter:
+ tableItem = formatter.tabWidgetItemClass()
+ else:
+ tableItem = qt.QTableWidgetItem()
+
+ tooltip = self._statsHandler.stats[name].getToolTip(kind=kind)
+ if tooltip is not None:
+ tableItem.setToolTip(tooltip)
+
+ tableItems.append(tableItem)
+
+ # Disable sorting while adding table items
+ with self._disableSorting():
+ # Add a row to the table
+ self.setRowCount(self.rowCount() + 1)
+
+ # Add table items to the last row
+ row = self.rowCount() - 1
+ for column, tableItem in enumerate(tableItems):
+ tableItem.setData(qt.Qt.UserRole, _Container(item))
+ tableItem.setFlags(
+ qt.Qt.ItemIsEnabled | qt.Qt.ItemIsSelectable)
+ self.setItem(row, column, tableItem)
+
+ # Update table items content
+ self._updateStats(item, data_changed=True)
+
+ # Listen for item changes
+ # Using queued connection to avoid issue with sender
+ # being that of the signal calling the signal
+ item._plot_item.sigItemChanged.connect(self._plotItemChanged,
+ qt.Qt.QueuedConnection)
+ return True
+
+ def _removeAllItems(self):
+ for row in range(self.rowCount()):
+ tableItem = self.item(row, 0)
+ # item = self._tableItemToItem(tableItem)
+ # item.sigItemChanged.disconnect(self._plotItemChanged)
+ self.clearContents()
+ self.setRowCount(0)
+
+ def clear(self):
+ self._removeAllItems()
+
+ def setStats(self, statsHandler):
+ """Set which stats to display and the associated formatting.
+
+ :param StatsHandler statsHandler:
+ Set the statistics to be displayed and how to format them using
+ """
+ self._removeAllItems()
+ _StatsWidgetBase.setStats(self, statsHandler)
+
+ self.setRowCount(0)
+ self.setColumnCount(len(self._statsHandler.stats) + 3) # + legend, kind and roi # noqa
+
+ for index, stat in enumerate(self._statsHandler.stats.values()):
+ headerItem = qt.QTableWidgetItem(stat.name.capitalize())
+ headerItem.setData(qt.Qt.UserRole, stat.name)
+ if stat.description is not None:
+ headerItem.setToolTip(stat.description)
+ self.setHorizontalHeaderItem(3 + index, headerItem)
+
+ horizontalHeader = self.horizontalHeader()
+ horizontalHeader.setSectionResizeMode(qt.QHeaderView.ResizeToContents)
+
+ self._updateItemObserve()
+
+ def _updateItemObserve(self, *args):
+ pass
+
+ def _dataChanged(self, item):
+ pass
+
+ def _updateStats(self, item, data_changed=False, roi_changed=False):
+ assert isinstance(item, ROIStatsItemHelper)
+ plotItem = item._plot_item
+ roi = item._roi
+ if item is None:
+ return
+ plot = self.getPlot()
+ if plot is None:
+ _logger.info("Plot not available")
+ return
+
+ row = self._itemToRow(item)
+ if row is None:
+ _logger.error("This item is not in the table: %s", str(item))
+ return
+
+ statsHandler = self.getStatsHandler()
+ if statsHandler is not None:
+ stats = statsHandler.calculate(plotItem, plot,
+ onlimits=self._statsOnVisibleData,
+ roi=roi, data_changed=data_changed,
+ roi_changed=roi_changed)
+ else:
+ stats = {}
+
+ with self._disableSorting():
+ for name, tableItem in self._itemToTableItems(item).items():
+ if name == self._LEGEND_HEADER_DATA:
+ text = self._plotWrapper.getLabel(plotItem)
+ tableItem.setText(text)
+ elif name == self._KIND_HEADER_DATA:
+ tableItem.setText(self._plotWrapper.getKind(plotItem))
+ elif name == self._ROI_HEADER_DATA:
+ name = roi.getName()
+ tableItem.setText(name)
+ else:
+ value = stats.get(name)
+ if value is None:
+ _logger.error("Value not found for: %s", name)
+ tableItem.setText('-')
+ else:
+ tableItem.setText(str(value))
+
+ @contextmanager
+ def _disableSorting(self):
+ """Context manager that disables table sorting
+
+ Previous state is restored when leaving
+ """
+ sorting = self.isSortingEnabled()
+ if sorting:
+ self.setSortingEnabled(False)
+ yield
+ if sorting:
+ self.setSortingEnabled(sorting)
+
+ def _itemToRow(self, item):
+ """Find the row corresponding to a plot item
+
+ :param item: The plot item
+ :return: The corresponding row index
+ :rtype: Union[int,None]
+ """
+ for row in range(self.rowCount()):
+ tableItem = self.item(row, 0)
+ if self._tableItemToItem(tableItem) == item:
+ return row
+ return None
+
+ def _tableItemToItem(self, tableItem):
+ """Find the plot item corresponding to a table item
+
+ :param QTableWidgetItem tableItem:
+ :rtype: QObject
+ """
+ container = tableItem.data(qt.Qt.UserRole)
+ return container()
+
+ def _itemToTableItems(self, item):
+ """Find all table items corresponding to a plot item
+
+ :param item: The plot item
+ :return: An ordered dict of column name to QTableWidgetItem mapping
+ for the given plot item.
+ :rtype: OrderedDict
+ """
+ result = OrderedDict()
+ row = self._itemToRow(item)
+ if row is not None:
+ for column in range(self.columnCount()):
+ tableItem = self.item(row, column)
+ if self._tableItemToItem(tableItem) != item:
+ _logger.error("Table item/plot item mismatch")
+ else:
+ header = self.horizontalHeaderItem(column)
+ name = header.data(qt.Qt.UserRole)
+ result[name] = tableItem
+ return result
+
+ def _plotItemToItems(self, plotItem):
+ """Return all _RoiStatsItemWidget associated to the plotItem
+ Needed for updating on itemChanged signal
+ """
+ if plotItem in self.__plotItemToItems:
+ return []
+ else:
+ return self.__plotItemToItems[plotItem]
+
+ def _registerPlotItem(self, item):
+ if item._plot_item not in self.__plotItemToItems:
+ self.__plotItemToItems[item._plot_item] = set()
+ self.__plotItemToItems[item._plot_item].add(item)
+
+ def _roiToItems(self, roi):
+ """Return all _RoiStatsItemWidget associated to the roi
+ Needed for updating on roiChanged signal
+ """
+ if roi in self.__roiToItems:
+ return []
+ else:
+ return self.__roiToItems[roi]
+
+ def _registerROI(self, item):
+ if item._roi not in self.__roiToItems:
+ self.__roiToItems[item._roi] = set()
+ # TODO: normalize also sig name
+ if isinstance(item._roi, RegionOfInterest):
+ # item connection within sigRegionChanged should only be
+ # stopped during the region edition
+ self.__region_edition_callback[item._roi] = functools.partial(
+ self._updateAllStats, False, True)
+ item._roi.sigRegionChanged.connect(self.__region_edition_callback[item._roi])
+ item._roi.sigEditingStarted.connect(functools.partial(
+ self._startFiltering, item._roi))
+ item._roi.sigEditingFinished.connect(functools.partial(
+ self._endFiltering, item._roi))
+ else:
+ item._roi.sigChanged.connect(functools.partial(
+ self._updateAllStats, False, True))
+ self.__roiToItems[item._roi].add(item)
+
+ def _startFiltering(self, roi):
+ roi.sigRegionChanged.disconnect(self.__region_edition_callback[roi])
+
+ def _endFiltering(self, roi):
+ roi.sigRegionChanged.connect(self.__region_edition_callback[roi])
+ self._updateAllStats(roi_changed=True)
+
+ def unregisterROI(self, roi):
+ if roi in self.__roiToItems:
+ del self.__roiToItems[roi]
+ if isinstance(roi, RegionOfInterest):
+ roi.sigRegionEditionStarted.disconnect(functools.partial(
+ self._startFiltering, roi))
+ roi.sigRegionEditionFinished.disconnect(functools.partial(
+ self._startFiltering, roi))
+ try:
+ roi.sigRegionChanged.disconnect(self._updateAllStats)
+ except:
+ pass
+ else:
+ roi.sigChanged.disconnect(self._updateAllStats)
+
+ def _plotItemChanged(self, event):
+ """Handle modifications of the items.
+
+ :param event:
+ """
+ if event is ItemChangedType.DATA:
+ if self.getUpdateMode() is UpdateMode.MANUAL:
+ return
+ if self._skipPlotItemChangedEvent(event) is True:
+ return
+ else:
+ sender = self.sender()
+ for item in self.__plotItemToItems[sender]:
+ # TODO: get all concerned items
+ self._updateStats(item, data_changed=True)
+ # deal with stat items visibility
+ if event is ItemChangedType.VISIBLE:
+ if len(self._itemToTableItems(item).items()) > 0:
+ item_0 = list(self._itemToTableItems(item).values())[0]
+ row_index = item_0.row()
+ self.setRowHidden(row_index, not item.isVisible())
+
+ def _removeItem(self, itemKey):
+ if isinstance(itemKey, (silx.gui.plot.items.marker.Marker,
+ silx.gui.plot.items.shape.Shape)):
+ return
+ if itemKey not in self._items:
+ _logger.warning('key not recognized. Won\'t remove any item')
+ return
+ item = self._items[itemKey]
+ row = self._itemToRow(item)
+ if row is None:
+ kind = self._plotWrapper.getKind(item)
+ if kind in statsmdl.BASIC_COMPATIBLE_KINDS:
+ _logger.error("Removing item that is not in table: %s", str(item))
+ return
+ item._plot_item.sigItemChanged.disconnect(self._plotItemChanged)
+ self.removeRow(row)
+ del self._items[itemKey]
+
+ def _updateAllStats(self, is_request=False, roi_changed=False):
+ """Update stats for all rows in the table
+
+ :param bool is_request: True if come from a manual request
+ """
+ if (self.getUpdateMode() is UpdateMode.MANUAL and
+ not is_request and not roi_changed):
+ return
+
+ with self._disableSorting():
+ for row in range(self.rowCount()):
+ tableItem = self.item(row, 0)
+ item = self._tableItemToItem(tableItem)
+ self._updateStats(item, roi_changed=roi_changed,
+ data_changed=is_request)
+
+ def _plotCurrentChanged(self, *args):
+ pass
+
+ def _getRoi(self, kind, name):
+ """return the roi fitting the requirement kind, name. This information
+ is enough to be sure it is unique (in the widget)"""
+ for roi in self.__roiToItems:
+ roiName = roi.getName()
+ if isinstance(roi, kind) and name == roiName:
+ return roi
+ return None
+
+ def _getPlotItem(self, kind, legend):
+ """return the plotItem fitting the requirement kind, legend.
+ This information is enough to be sure it is unique (in the widget)"""
+ for plotItem in self.__plotItemToItems:
+ if legend == plotItem.getLegend() and self._plotWrapper.getKind(plotItem) == kind:
+ return plotItem
+ return None
+
+
+class ROIStatsWidget(qt.QMainWindow):
+ """
+ Widget used to define stats item for a couple(roi, plotItem).
+ Stats will be computing on a given item (curve, image...) in the given
+ region of interest.
+
+ It also provide an interface for adding and removing items.
+
+ .. snapshotqt:: img/ROIStatsWidget.png
+ :width: 300px
+ :align: center
+
+ from silx.gui import qt
+ from silx.gui.plot import Plot2D
+ from silx.gui.plot.ROIStatsWidget import ROIStatsWidget
+ from silx.gui.plot.items.roi import RectangleROI
+ import numpy
+ plot = Plot2D()
+ plot.addImage(numpy.arange(10000).reshape(100, 100), legend='img')
+ plot.show()
+ rectangleROI = RectangleROI()
+ rectangleROI.setGeometry(origin=(0, 100), size=(20, 20))
+ rectangleROI.setName('Initial ROI')
+ widget = ROIStatsWidget(plot=plot)
+ widget.setStats([('sum', numpy.sum), ('mean', numpy.mean)])
+ widget.registerROI(rectangleROI)
+ widget.addItem(roi=rectangleROI, plotItem=plot.getImage('img'))
+ widget.show()
+
+ :param Union[qt.QWidget,None] parent: parent qWidget
+ :param PlotWindow plot: plot widget containing the items
+ :param stats: stats to display
+ :param tuple rois: tuple of rois to manage
+ """
+
+ def __init__(self, parent=None, plot=None, stats=None, rois=None):
+ qt.QMainWindow.__init__(self, parent)
+
+ toolbar = qt.QToolBar(self)
+ icon = icons.getQIcon('add')
+ self._rois = list(rois) if rois is not None else []
+ self._addAction = qt.QAction(icon, 'add item/roi', toolbar)
+ self._addAction.triggered.connect(self._addRoiStatsItem)
+ icon = icons.getQIcon('rm')
+ self._removeAction = qt.QAction(icon, 'remove item/roi', toolbar)
+ self._removeAction.triggered.connect(self._removeCurrentRow)
+
+ toolbar.addAction(self._addAction)
+ toolbar.addAction(self._removeAction)
+ self.addToolBar(toolbar)
+
+ self._plot = plot
+ self._statsROITable = _StatsROITable(parent=self, plot=self._plot)
+ self.setStats(stats=stats)
+ self.setCentralWidget(self._statsROITable)
+ self.setWindowFlags(qt.Qt.Widget)
+
+ # expose API
+ self._setUpdateMode = self._statsROITable.setUpdateMode
+ self._updateAllStats = self._statsROITable._updateAllStats
+
+ # setup
+ self._statsROITable.setSelectionBehavior(qt.QTableWidget.SelectRows)
+
+ def registerROI(self, roi):
+ """For now there is no direct link between roi and plot. That is why
+ we need to add/register them to be able to associate them"""
+ self._rois.append(roi)
+
+ def setPlot(self, plot):
+ """Define the plot to interact with
+
+ :param Union[PlotWidget,SceneWidget,None] plot:
+ The plot containing the items on which statistics are applied
+ """
+ self._plot = plot
+
+ def getPlot(self):
+ return self._plot
+
+ @docstring(_StatsROITable)
+ def setStats(self, stats):
+ if stats is not None:
+ self._statsROITable.setStats(statsHandler=stats)
+
+ @docstring(_StatsROITable)
+ def getStatsHandler(self):
+ """
+
+ :return:
+ """
+ return self._statsROITable.getStatsHandler()
+
+ def _addRoiStatsItem(self):
+ """Ask the user what couple ROI / item he want to display"""
+ dialog = _GetROIItemCoupleDialog(parent=self, plot=self._plot,
+ rois=self._rois)
+ if dialog.exec():
+ self.addItem(roi=dialog.getROI(), plotItem=dialog.getItem())
+
+ def addItem(self, plotItem, roi):
+ """
+ Add a row of statitstic regarding the couple (plotItem, roi)
+
+ :param Item plotItem: item to use for statistics
+ :param roi: region of interest to limit the statistic.
+ :type: Union[ROI, RegionOfInterest]
+ :return: None of failed to add the item
+ :rtype: Union[None,ROIStatsItemHelper]
+ """
+ statsItem = ROIStatsItemHelper(roi=roi, plot_item=plotItem)
+ return self._statsROITable.add(item=statsItem)
+
+ def removeItem(self, plotItem, roi):
+ """
+ Remove the row associated to the couple (plotItem, roi)
+
+ :param Item plotItem: item to use for statistics
+ :param roi: region of interest to limit the statistic.
+ :type: Union[ROI,RegionOfInterest]
+ """
+ statsItem = ROIStatsItemHelper(roi=roi, plot_item=plotItem)
+ self._statsROITable._removeItem(itemKey=statsItem.id_key())
+
+ def _removeCurrentRow(self):
+ def is1DKind(kind):
+ if kind in ('curve', 'histogram', 'scatter'):
+ return True
+ else:
+ return False
+
+ currentRow = self._statsROITable.currentRow()
+ item_kind = self._statsROITable.item(currentRow, 1).text()
+ item_legend = self._statsROITable.item(currentRow, 0).text()
+
+ roi_name = self._statsROITable.item(currentRow, 2).text()
+ roi_kind = ROI if is1DKind(item_kind) else RegionOfInterest
+ roi = self._statsROITable._getRoi(kind=roi_kind, name=roi_name)
+ if roi is None:
+ _logger.warning('failed to retrieve the roi you want to remove')
+ return False
+ plot_item = self._statsROITable._getPlotItem(kind=item_kind,
+ legend=item_legend)
+ if plot_item is None:
+ _logger.warning('failed to retrieve the plot item you want to'
+ 'remove')
+ return False
+ return self.removeItem(plotItem=plot_item, roi=roi)
diff --git a/src/silx/gui/plot/ScatterMaskToolsWidget.py b/src/silx/gui/plot/ScatterMaskToolsWidget.py
new file mode 100644
index 0000000..c242dfc
--- /dev/null
+++ b/src/silx/gui/plot/ScatterMaskToolsWidget.py
@@ -0,0 +1,621 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Widget providing a set of tools to draw masks on a PlotWidget.
+
+This widget is meant to work with a modified :class:`silx.gui.plot.PlotWidget`
+
+- :class:`ScatterMask`: Handle scatter mask update and history
+- :class:`ScatterMaskToolsWidget`: GUI for :class:`ScatterMask`
+- :class:`ScatterMaskToolsDockWidget`: DockWidget to integrate in :class:`PlotWindow`
+"""
+
+from __future__ import division
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "15/02/2019"
+
+
+import math
+import logging
+import os
+import numpy
+import sys
+
+from .. import qt
+from ...math.combo import min_max
+from ...image import shapes
+
+from .items import ItemChangedType, Scatter
+from ._BaseMaskToolsWidget import BaseMask, BaseMaskToolsWidget, BaseMaskToolsDockWidget
+from ..colors import cursorColorForColormap, rgba
+
+
+_logger = logging.getLogger(__name__)
+
+
+class ScatterMask(BaseMask):
+ """A 1D mask for scatter data.
+ """
+ def __init__(self, scatter=None):
+ """
+
+ :param scatter: :class:`silx.gui.plot.items.Scatter` instance
+ """
+ BaseMask.__init__(self, scatter)
+
+ def _getXY(self):
+ x = self._dataItem.getXData(copy=False)
+ y = self._dataItem.getYData(copy=False)
+ return x, y
+
+ def getDataValues(self):
+ """Return scatter data values as a 1D array.
+
+ :rtype: 1D numpy.ndarray
+ """
+ return self._dataItem.getValueData(copy=False)
+
+ def save(self, filename, kind):
+ if kind == 'npy':
+ try:
+ numpy.save(filename, self.getMask(copy=False))
+ except IOError:
+ raise RuntimeError("Mask file can't be written")
+ elif kind in ["csv", "txt"]:
+ try:
+ numpy.savetxt(filename, self.getMask(copy=False))
+ except IOError:
+ raise RuntimeError("Mask file can't be written")
+
+ def updatePoints(self, level, indices, mask=True):
+ """Mask/Unmask points with given indices.
+
+ :param int level: Mask level to update.
+ :param indices: Sequence or 1D array of indices of points to be
+ updated
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ if mask:
+ self._mask[indices] = level
+ else:
+ # unmask only where mask level is the specified value
+ indices_stencil = numpy.zeros_like(self._mask, dtype=bool)
+ indices_stencil[indices] = True
+ self._mask[numpy.logical_and(self._mask == level, indices_stencil)] = 0
+ self._notify()
+
+ # update shapes
+ def updatePolygon(self, level, vertices, mask=True):
+ """Mask/Unmask a polygon of the given mask level.
+
+ :param int level: Mask level to update.
+ :param vertices: Nx2 array of polygon corners as (y, x) or (row, col)
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ polygon = shapes.Polygon(vertices)
+ x, y = self._getXY()
+
+ # TODO: this could be optimized if necessary
+ indices_in_polygon = [idx for idx in range(len(x)) if
+ polygon.is_inside(y[idx], x[idx])]
+
+ self.updatePoints(level, indices_in_polygon, mask)
+
+ def updateRectangle(self, level, y, x, height, width, mask=True):
+ """Mask/Unmask data inside a rectangle
+
+ :param int level: Mask level to update.
+ :param float y: Y coordinate of bottom left corner of the rectangle
+ :param float x: X coordinate of bottom left corner of the rectangle
+ :param float height:
+ :param float width:
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ vertices = [(y, x),
+ (y + height, x),
+ (y + height, x + width),
+ (y, x + width)]
+ self.updatePolygon(level, vertices, mask)
+
+ def updateDisk(self, level, cy, cx, radius, mask=True):
+ """Mask/Unmask a disk of the given mask level.
+
+ :param int level: Mask level to update.
+ :param float cy: Disk center (y).
+ :param float cx: Disk center (x).
+ :param float radius: Radius of the disk in mask array unit
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ x, y = self._getXY()
+ stencil = (y - cy)**2 + (x - cx)**2 < radius**2
+ self.updateStencil(level, stencil, mask)
+
+ def updateEllipse(self, level, crow, ccol, radius_r, radius_c, mask=True):
+ """Mask/Unmask an ellipse of the given mask level.
+
+ :param int level: Mask level to update.
+ :param int crow: Row of the center of the ellipse
+ :param int ccol: Column of the center of the ellipse
+ :param float radius_r: Radius of the ellipse in the row
+ :param float radius_c: Radius of the ellipse in the column
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ def is_inside(px, py):
+ return (px - ccol)**2 / radius_c**2 + (py - crow)**2 / radius_r**2 <= 1.0
+ x, y = self._getXY()
+ indices_inside = [idx for idx in range(len(x)) if is_inside(x[idx], y[idx])]
+ self.updatePoints(level, indices_inside, mask)
+
+ def updateLine(self, level, y0, x0, y1, x1, width, mask=True):
+ """Mask/Unmask points inside a rectangle defined by a line (two
+ end points) and a width.
+
+ :param int level: Mask level to update.
+ :param float y0: Row of the starting point.
+ :param float x0: Column of the starting point.
+ :param float row1: Row of the end point.
+ :param float col1: Column of the end point.
+ :param float width: Width of the line.
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ # theta is the angle between the horizontal and the line
+ theta = math.atan((y1 - y0) / (x1 - x0)) if x1 - x0 else 0
+ w_over_2_sin_theta = width / 2. * math.sin(theta)
+ w_over_2_cos_theta = width / 2. * math.cos(theta)
+
+ vertices = [(y0 - w_over_2_cos_theta, x0 + w_over_2_sin_theta),
+ (y0 + w_over_2_cos_theta, x0 - w_over_2_sin_theta),
+ (y1 + w_over_2_cos_theta, x1 - w_over_2_sin_theta),
+ (y1 - w_over_2_cos_theta, x1 + w_over_2_sin_theta)]
+
+ self.updatePolygon(level, vertices, mask)
+
+
+class ScatterMaskToolsWidget(BaseMaskToolsWidget):
+ """Widget with tools for masking data points on a scatter in a
+ :class:`PlotWidget`."""
+
+ def __init__(self, parent=None, plot=None):
+ super(ScatterMaskToolsWidget, self).__init__(parent, plot,
+ mask=ScatterMask())
+ self._z = 2 # Mask layer in plot
+ self._data_scatter = None
+ """plot Scatter item for data"""
+
+ self._data_extent = None
+ """Maximum extent of the data i.e., max(xMax-xMin, yMax-yMin)"""
+
+ self._mask_scatter = None
+ """plot Scatter item for representing the mask"""
+
+ def setSelectionMask(self, mask, copy=True):
+ """Set the mask to a new array.
+
+ :param numpy.ndarray mask:
+ The array to use for the mask or None to reset the mask.
+ :type mask: numpy.ndarray of uint8, C-contiguous.
+ Array of other types are converted.
+ :param bool copy: True (the default) to copy the array,
+ False to use it as is if possible.
+ :return: None if failed, shape of mask as 1-tuple if successful.
+ The mask can be cropped or padded to fit active scatter,
+ the returned shape is that of the scatter data.
+ """
+ if self._data_scatter is None:
+ # this can happen if the mask tools widget has never been shown
+ self._data_scatter = self.plot._getActiveItem(kind="scatter")
+ if self._data_scatter is None:
+ return None
+ self._adjustColorAndBrushSize(self._data_scatter)
+
+ if mask is None:
+ self.resetSelectionMask()
+ return self._data_scatter.getXData(copy=False).shape
+
+ mask = numpy.array(mask, copy=False, dtype=numpy.uint8)
+
+ if self._data_scatter.getXData(copy=False).shape == (0,) \
+ or mask.shape == self._data_scatter.getXData(copy=False).shape:
+ self._mask.setMask(mask, copy=copy)
+ self._mask.commit()
+ return mask.shape
+ else:
+ raise ValueError("Mask does not have the same shape as the data")
+
+ # Handle mask refresh on the plot
+
+ def _updatePlotMask(self):
+ """Update mask image in plot"""
+ mask = self.getSelectionMask(copy=False)
+ if mask is not None:
+ self.plot.addScatter(self._data_scatter.getXData(),
+ self._data_scatter.getYData(),
+ mask,
+ legend=self._maskName,
+ colormap=self._colormap,
+ z=self._z)
+ self._mask_scatter = self.plot._getItem(kind="scatter",
+ legend=self._maskName)
+ self._mask_scatter.setSymbolSize(
+ self._data_scatter.getSymbolSize() + 2.0)
+ self._mask_scatter.sigItemChanged.connect(self.__maskScatterChanged)
+ elif self.plot._getItem(kind="scatter",
+ legend=self._maskName) is not None:
+ self.plot.remove(self._maskName, kind='scatter')
+
+ def __maskScatterChanged(self, event):
+ """Handles update of mask scatter"""
+ if (event is ItemChangedType.VISUALIZATION_MODE and
+ self._mask_scatter is not None):
+ self._mask_scatter.setVisualization(Scatter.Visualization.POINTS)
+
+ # track widget visibility and plot active image changes
+
+ def showEvent(self, event):
+ try:
+ self.plot.sigActiveScatterChanged.disconnect(
+ self._activeScatterChangedAfterCare)
+ except (RuntimeError, TypeError):
+ pass
+ self._activeScatterChanged(None, None) # Init mask + enable/disable widget
+ self.plot.sigActiveScatterChanged.connect(self._activeScatterChanged)
+
+ def hideEvent(self, event):
+ try:
+ # if the method is not connected this raises a TypeError and there is no way
+ # to know the connected slots
+ self.plot.sigActiveScatterChanged.disconnect(self._activeScatterChanged)
+ except (RuntimeError, TypeError):
+ _logger.info(sys.exc_info()[1])
+ if not self.browseAction.isChecked():
+ self.browseAction.trigger() # Disable drawing tool
+
+ if self.getSelectionMask(copy=False) is not None:
+ self.plot.sigActiveScatterChanged.connect(
+ self._activeScatterChangedAfterCare)
+
+ def _adjustColorAndBrushSize(self, activeScatter):
+ colormap = activeScatter.getColormap()
+ self._defaultOverlayColor = rgba(cursorColorForColormap(colormap['name']))
+ self._setMaskColors(self.levelSpinBox.value(),
+ self.transparencySlider.value() /
+ self.transparencySlider.maximum())
+ self._z = activeScatter.getZValue() + 1
+ self._data_scatter = activeScatter
+
+ # Adjust brush size to data range
+ xData = self._data_scatter.getXData(copy=False)
+ yData = self._data_scatter.getYData(copy=False)
+ # Adjust brush size to data range
+ if xData.size > 0 and yData.size > 0:
+ xMin, xMax = min_max(xData)
+ yMin, yMax = min_max(yData)
+ self._data_extent = max(xMax - xMin, yMax - yMin)
+ else:
+ self._data_extent = None
+
+ def _activeScatterChangedAfterCare(self, previous, next):
+ """Check synchro of active scatter and mask when mask widget is hidden.
+
+ If active image has no more the same size as the mask, the mask is
+ removed, otherwise it is adjusted to z.
+ """
+ # check that content changed was the active scatter
+ activeScatter = self.plot._getActiveItem(kind="scatter")
+
+ if activeScatter is None or activeScatter.getName() == self._maskName:
+ # No active scatter or active scatter is the mask...
+ self.plot.sigActiveScatterChanged.disconnect(
+ self._activeScatterChangedAfterCare)
+ self._data_extent = None
+ self._data_scatter = None
+
+ else:
+ self._adjustColorAndBrushSize(activeScatter)
+
+ if self._data_scatter.getXData(copy=False).shape != self._mask.getMask(copy=False).shape:
+ # scatter has not the same size, remove mask and stop listening
+ if self.plot._getItem(kind="scatter", legend=self._maskName):
+ self.plot.remove(self._maskName, kind='scatter')
+
+ self.plot.sigActiveScatterChanged.disconnect(
+ self._activeScatterChangedAfterCare)
+ self._data_extent = None
+ self._data_scatter = None
+
+ else:
+ # Refresh in case z changed
+ self._mask.setDataItem(self._data_scatter)
+ self._updatePlotMask()
+
+ def _activeScatterChanged(self, previous, next):
+ """Update widget and mask according to active scatter changes"""
+ activeScatter = self.plot._getActiveItem(kind="scatter")
+
+ if activeScatter is None or activeScatter.getName() == self._maskName:
+ # No active scatter or active scatter is the mask...
+ self.setEnabled(False)
+
+ self._data_scatter = None
+ self._data_extent = None
+ self._mask.reset()
+ self._mask.commit()
+
+ else: # There is an active scatter
+ self.setEnabled(True)
+ self._adjustColorAndBrushSize(activeScatter)
+
+ self._mask.setDataItem(self._data_scatter)
+ if self._data_scatter.getXData(copy=False).shape != self._mask.getMask(copy=False).shape:
+ self._mask.reset(self._data_scatter.getXData(copy=False).shape)
+ self._mask.commit()
+ else:
+ # Refresh in case z changed
+ self._updatePlotMask()
+
+ self._updateInteractiveMode()
+
+ # Handle whole mask operations
+
+ def load(self, filename):
+ """Load a mask from an image file.
+
+ :param str filename: File name from which to load the mask
+ :raise Exception: An exception in case of failure
+ :raise RuntimeWarning: In case the mask was applied but with some
+ import changes to notice
+ """
+ _, extension = os.path.splitext(filename)
+ extension = extension.lower()[1:]
+ if extension == "npy":
+ try:
+ mask = numpy.load(filename)
+ except IOError:
+ _logger.error("Can't load filename '%s'", filename)
+ _logger.debug("Backtrace", exc_info=True)
+ raise RuntimeError('File "%s" is not a numpy file.',
+ filename)
+ elif extension in ["txt", "csv"]:
+ try:
+ mask = numpy.loadtxt(filename)
+ except IOError:
+ _logger.error("Can't load filename '%s'", filename)
+ _logger.debug("Backtrace", exc_info=True)
+ raise RuntimeError('File "%s" is not a numpy txt file.',
+ filename)
+ else:
+ msg = "Extension '%s' is not supported."
+ raise RuntimeError(msg % extension)
+
+ self.setSelectionMask(mask, copy=False)
+
+ def _loadMask(self):
+ """Open load mask dialog"""
+ dialog = qt.QFileDialog(self)
+ dialog.setWindowTitle("Load Mask")
+ dialog.setModal(1)
+ filters = [
+ 'NumPy binary file (*.npy)',
+ 'CSV text file (*.csv)',
+ ]
+ dialog.setNameFilters(filters)
+ dialog.setFileMode(qt.QFileDialog.ExistingFile)
+ dialog.setDirectory(self.maskFileDir)
+ if not dialog.exec():
+ dialog.close()
+ return
+
+ filename = dialog.selectedFiles()[0]
+ dialog.close()
+
+ # Update the directory according to the user selection
+ self.maskFileDir = os.path.dirname(filename)
+
+ try:
+ self.load(filename)
+ # except RuntimeWarning as e:
+ # message = e.args[0]
+ # msg = qt.QMessageBox(self)
+ # msg.setIcon(qt.QMessageBox.Warning)
+ # msg.setText("Mask loaded but an operation was applied.\n" + message)
+ # msg.exec()
+ except Exception as e:
+ message = e.args[0]
+ msg = qt.QMessageBox(self)
+ msg.setIcon(qt.QMessageBox.Critical)
+ msg.setText("Cannot load mask from file. " + message)
+ msg.exec()
+
+ def _saveMask(self):
+ """Open Save mask dialog"""
+ dialog = qt.QFileDialog(self)
+ dialog.setWindowTitle("Save Mask")
+ dialog.setModal(1)
+ filters = [
+ 'NumPy binary file (*.npy)',
+ 'CSV text file (*.csv)',
+ ]
+ dialog.setNameFilters(filters)
+ dialog.setFileMode(qt.QFileDialog.AnyFile)
+ dialog.setAcceptMode(qt.QFileDialog.AcceptSave)
+ dialog.setDirectory(self.maskFileDir)
+ if not dialog.exec():
+ dialog.close()
+ return
+
+ # convert filter name to extension name with the .
+ extension = dialog.selectedNameFilter().split()[-1][2:-1]
+ filename = dialog.selectedFiles()[0]
+ dialog.close()
+
+ if not filename.lower().endswith(extension):
+ filename += extension
+
+ if os.path.exists(filename):
+ try:
+ os.remove(filename)
+ except IOError as e:
+ msg = qt.QMessageBox(self)
+ msg.setWindowTitle("Removing existing file")
+ msg.setIcon(qt.QMessageBox.Critical)
+
+ if hasattr(e, "strerror"):
+ strerror = e.strerror
+ else:
+ strerror = sys.exc_info()[1]
+ msg.setText("Cannot save.\n"
+ "Input Output Error: %s" % strerror)
+ msg.exec()
+ return
+
+ # Update the directory according to the user selection
+ self.maskFileDir = os.path.dirname(filename)
+
+ try:
+ self.save(filename, extension[1:])
+ except Exception as e:
+ msg = qt.QMessageBox(self)
+ msg.setWindowTitle("Saving mask file")
+ msg.setIcon(qt.QMessageBox.Critical)
+
+ if hasattr(e, "strerror"):
+ strerror = e.strerror
+ else:
+ strerror = sys.exc_info()[1]
+ msg.setText("Cannot save file %s\n%s" % (filename, strerror))
+ msg.exec()
+
+ def resetSelectionMask(self):
+ """Reset the mask"""
+ self._mask.reset(
+ shape=self._data_scatter.getXData(copy=False).shape)
+ self._mask.commit()
+
+ def _getPencilWidth(self):
+ """Returns the width of the pencil to use in data coordinates`
+
+ :rtype: float
+ """
+ width = super(ScatterMaskToolsWidget, self)._getPencilWidth()
+ if self._data_extent is not None:
+ width *= 0.01 * self._data_extent
+ return width
+
+ def _plotDrawEvent(self, event):
+ """Handle draw events from the plot"""
+ if (self._drawingMode is None or
+ event['event'] not in ('drawingProgress', 'drawingFinished')):
+ return
+
+ if not len(self._data_scatter.getXData(copy=False)):
+ return
+
+ level = self.levelSpinBox.value()
+
+ if self._drawingMode == 'rectangle':
+ if event['event'] == 'drawingFinished':
+ doMask = self._isMasking()
+
+ self._mask.updateRectangle(
+ level,
+ y=event['y'],
+ x=event['x'],
+ height=abs(event['height']),
+ width=abs(event['width']),
+ mask=doMask)
+ self._mask.commit()
+
+ elif self._drawingMode == 'ellipse':
+ if event['event'] == 'drawingFinished':
+ doMask = self._isMasking()
+ center = event['points'][0]
+ size = event['points'][1]
+ self._mask.updateEllipse(level, center[1], center[0],
+ size[1], size[0], doMask)
+ self._mask.commit()
+
+ elif self._drawingMode == 'polygon':
+ if event['event'] == 'drawingFinished':
+ doMask = self._isMasking()
+ vertices = event['points']
+ vertices = vertices[:, (1, 0)] # (y, x)
+ self._mask.updatePolygon(level, vertices, doMask)
+ self._mask.commit()
+
+ elif self._drawingMode == 'pencil':
+ doMask = self._isMasking()
+ # convert from plot to array coords
+ x, y = event['points'][-1]
+
+ brushSize = self._getPencilWidth()
+
+ if self._lastPencilPos != (y, x):
+ if self._lastPencilPos is not None:
+ # Draw the line
+ self._mask.updateLine(
+ level,
+ self._lastPencilPos[0], self._lastPencilPos[1],
+ y, x,
+ brushSize,
+ doMask)
+
+ # Draw the very first, or last point
+ self._mask.updateDisk(level, y, x, brushSize / 2., doMask)
+
+ if event['event'] == 'drawingFinished':
+ self._mask.commit()
+ self._lastPencilPos = None
+ else:
+ self._lastPencilPos = y, x
+ else:
+ _logger.error("Drawing mode %s unsupported", self._drawingMode)
+
+ def _loadRangeFromColormapTriggered(self):
+ """Set range from active scatter colormap range"""
+ if self._data_scatter is not None:
+ # Update thresholds according to colormap
+ colormap = self._data_scatter.getColormap()
+ if colormap['autoscale']:
+ min_ = numpy.nanmin(self._data_scatter.getValueData(copy=False))
+ max_ = numpy.nanmax(self._data_scatter.getValueData(copy=False))
+ else:
+ min_, max_ = colormap['vmin'], colormap['vmax']
+ self.minLineEdit.setText(str(min_))
+ self.maxLineEdit.setText(str(max_))
+
+
+class ScatterMaskToolsDockWidget(BaseMaskToolsDockWidget):
+ """:class:`ScatterMaskToolsWidget` embedded in a QDockWidget.
+
+ For integration in a :class:`PlotWindow`.
+
+ :param parent: See :class:`QDockWidget`
+ :param plot: The PlotWidget this widget is operating on
+ :paran str name: The title of this widget
+ """
+ def __init__(self, parent=None, plot=None, name='Mask'):
+ widget = ScatterMaskToolsWidget(plot=plot)
+ super(ScatterMaskToolsDockWidget, self).__init__(parent, name, widget)
diff --git a/src/silx/gui/plot/ScatterView.py b/src/silx/gui/plot/ScatterView.py
new file mode 100644
index 0000000..d3fd2e0
--- /dev/null
+++ b/src/silx/gui/plot/ScatterView.py
@@ -0,0 +1,404 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""A widget dedicated to display scatter plots
+
+It is based on a :class:`~silx.gui.plot.PlotWidget` with additional tools
+for scatter plots.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "14/06/2018"
+
+
+import logging
+import weakref
+
+import numpy
+
+from . import items
+from . import PlotWidget
+from . import tools
+from .actions import histogram as actions_histogram
+from .tools.profile import ScatterProfileToolBar
+from .ColorBar import ColorBarWidget
+from .ScatterMaskToolsWidget import ScatterMaskToolsWidget
+
+from ..widgets.BoxLayoutDockWidget import BoxLayoutDockWidget
+from .. import qt, icons
+from ...utils.proxy import docstring
+from ...utils.weakref import WeakMethodProxy
+
+
+_logger = logging.getLogger(__name__)
+
+
+class ScatterView(qt.QMainWindow):
+ """Main window with a PlotWidget and tools specific for scatter plots.
+
+ :param parent: The parent of this widget
+ :param backend: The backend to use for the plot (default: matplotlib).
+ See :class:`~silx.gui.plot.PlotWidget` for the list of supported backend.
+ :type backend: Union[str,~silx.gui.plot.backends.BackendBase.BackendBase]
+ """
+
+ _SCATTER_LEGEND = ' '
+ """Legend used for the scatter item"""
+
+ def __init__(self, parent=None, backend=None):
+ super(ScatterView, self).__init__(parent=parent)
+ if parent is not None:
+ # behave as a widget
+ self.setWindowFlags(qt.Qt.Widget)
+ else:
+ self.setWindowTitle('ScatterView')
+
+ # Create plot widget
+ plot = PlotWidget(parent=self, backend=backend)
+ self._plot = weakref.ref(plot)
+
+ # Add an empty scatter
+ self.__createEmptyScatter()
+
+ # Create colorbar widget with white background
+ self._colorbar = ColorBarWidget(parent=self, plot=plot)
+ self._colorbar.setAutoFillBackground(True)
+ palette = self._colorbar.palette()
+ palette.setColor(qt.QPalette.Window, qt.Qt.white)
+ self._colorbar.setPalette(palette)
+
+ # Create PositionInfo widget
+ self.__lastPickingPos = None
+ self.__pickingCache = None
+ self._positionInfo = tools.PositionInfo(
+ plot=plot,
+ converters=(('X', WeakMethodProxy(self._getPickedX)),
+ ('Y', WeakMethodProxy(self._getPickedY)),
+ ('Data', WeakMethodProxy(self._getPickedValue)),
+ ('Index', WeakMethodProxy(self._getPickedIndex))))
+
+ # Combine plot, position info and colorbar into central widget
+ gridLayout = qt.QGridLayout()
+ gridLayout.setSpacing(0)
+ gridLayout.setContentsMargins(0, 0, 0, 0)
+ gridLayout.addWidget(plot, 0, 0)
+ gridLayout.addWidget(self._colorbar, 0, 1)
+ gridLayout.addWidget(self._positionInfo, 1, 0, 1, -1)
+ gridLayout.setRowStretch(0, 1)
+ gridLayout.setColumnStretch(0, 1)
+ centralWidget = qt.QWidget(self)
+ centralWidget.setLayout(gridLayout)
+ self.setCentralWidget(centralWidget)
+
+ # Create mask tool dock widget
+ self._maskToolsWidget = ScatterMaskToolsWidget(parent=self, plot=plot)
+ self._maskDock = BoxLayoutDockWidget()
+ self._maskDock.setWindowTitle('Scatter Mask')
+ self._maskDock.setWidget(self._maskToolsWidget)
+ self._maskDock.setVisible(False)
+ self.addDockWidget(qt.Qt.BottomDockWidgetArea, self._maskDock)
+
+ self._maskAction = self._maskDock.toggleViewAction()
+ self._maskAction.setIcon(icons.getQIcon('image-mask'))
+ self._maskAction.setToolTip("Display/hide mask tools")
+
+ self._intensityHistoAction = actions_histogram.PixelIntensitiesHistoAction(plot=plot, parent=self)
+
+ # Create toolbars
+ self._interactiveModeToolBar = tools.InteractiveModeToolBar(
+ parent=self, plot=plot)
+
+ self._scatterToolBar = tools.ScatterToolBar(
+ parent=self, plot=plot)
+ self._scatterToolBar.addAction(self._maskAction)
+ self._scatterToolBar.addAction(self._intensityHistoAction)
+
+ self._profileToolBar = ScatterProfileToolBar(parent=self, plot=plot)
+
+ self._outputToolBar = tools.OutputToolBar(parent=self, plot=plot)
+
+ # Activate shortcuts in PlotWindow widget:
+ for toolbar in (self._interactiveModeToolBar,
+ self._scatterToolBar,
+ self._profileToolBar,
+ self._outputToolBar):
+ self.addToolBar(toolbar)
+ for action in toolbar.actions():
+ self.addAction(action)
+
+
+ def __createEmptyScatter(self):
+ """Create an empty scatter item that is used to display the data
+
+ :rtype: ~silx.gui.plot.items.Scatter
+ """
+ plot = self.getPlotWidget()
+ plot.addScatter(x=(), y=(), value=(), legend=self._SCATTER_LEGEND)
+ scatter = plot._getItem(
+ kind='scatter', legend=self._SCATTER_LEGEND)
+ # Profile is not selectable,
+ # so it does not interfere with profile interaction
+ scatter._setSelectable(False)
+ return scatter
+
+ def _pickScatterData(self, x, y):
+ """Get data and index and value of top most scatter plot at position (x, y)
+
+ :param float x: X position in plot coordinates
+ :param float y: Y position in plot coordinates
+ :return: The data index and value at that point or None
+ """
+ pickingPos = x, y
+ if self.__lastPickingPos != pickingPos:
+ self.__pickingCache = None
+ self.__lastPickingPos = pickingPos
+
+ plot = self.getPlotWidget()
+ if plot is not None:
+ pixelPos = plot.dataToPixel(x, y)
+ if pixelPos is not None:
+ # Start from top-most item
+ result = plot._pickTopMost(
+ pixelPos[0], pixelPos[1],
+ lambda item: isinstance(item, items.Scatter))
+ if result is not None:
+ item = result.getItem()
+ if item.getVisualization() is items.Scatter.Visualization.BINNED_STATISTIC:
+ # Get highest index of closest points
+ selected = result.getIndices(copy=False)[::-1]
+ dataIndex = selected[numpy.argmin(
+ (item.getXData(copy=False)[selected] - x)**2 +
+ (item.getYData(copy=False)[selected] - y)**2)]
+ else:
+ # Get last index
+ # with matplotlib it should be the top-most point
+ dataIndex = result.getIndices(copy=False)[-1]
+ self.__pickingCache = (
+ dataIndex,
+ item.getXData(copy=False)[dataIndex],
+ item.getYData(copy=False)[dataIndex],
+ item.getValueData(copy=False)[dataIndex])
+
+ return self.__pickingCache
+
+ def _getPickedIndex(self, x, y):
+ """Get data index of top most scatter plot at position (x, y)
+
+ :param float x: X position in plot coordinates
+ :param float y: Y position in plot coordinates
+ :return: The data index at that point or '-'
+ """
+ picking = self._pickScatterData(x, y)
+ return '-' if picking is None else picking[0]
+
+ def _getPickedX(self, x, y):
+ """Returns X position snapped to scatter plot when close enough
+
+ :param float x:
+ :param float y:
+ :rtype: float
+ """
+ picking = self._pickScatterData(x, y)
+ return x if picking is None else picking[1]
+
+ def _getPickedY(self, x, y):
+ """Returns Y position snapped to scatter plot when close enough
+
+ :param float x:
+ :param float y:
+ :rtype: float
+ """
+ picking = self._pickScatterData(x, y)
+ return y if picking is None else picking[2]
+
+ def _getPickedValue(self, x, y):
+ """Get data value of top most scatter plot at position (x, y)
+
+ :param float x: X position in plot coordinates
+ :param float y: Y position in plot coordinates
+ :return: The data value at that point or '-'
+ """
+ picking = self._pickScatterData(x, y)
+ return '-' if picking is None else picking[3]
+
+ def _mouseInPlotArea(self, x, y):
+ """Clip mouse coordinates to plot area coordinates
+
+ :param float x: X position in pixels
+ :param float y: Y position in pixels
+ :return: (x, y) in data coordinates
+ """
+ plot = self.getPlotWidget()
+ left, top, width, height = plot.getPlotBoundsInPixels()
+ xPlot = numpy.clip(x, left, left + width - 1)
+ yPlot = numpy.clip(y, top, top + height - 1)
+ return xPlot, yPlot
+
+ def getPlotWidget(self):
+ """Returns the :class:`~silx.gui.plot.PlotWidget` this window is based on.
+
+ :rtype: ~silx.gui.plot.PlotWidget
+ """
+ return self._plot()
+
+ def getPositionInfoWidget(self):
+ """Returns the widget display mouse coordinates information.
+
+ :rtype: ~silx.gui.plot.tools.PositionInfo
+ """
+ return self._positionInfo
+
+ def getMaskToolsWidget(self):
+ """Returns the widget controlling mask drawing
+
+ :rtype: ~silx.gui.plot.ScatterMaskToolsWidget
+ """
+ return self._maskToolsWidget
+
+ def getInteractiveModeToolBar(self):
+ """Returns QToolBar controlling interactive mode.
+
+ :rtype: ~silx.gui.plot.tools.InteractiveModeToolBar
+ """
+ return self._interactiveModeToolBar
+
+ def getScatterToolBar(self):
+ """Returns QToolBar providing scatter plot tools.
+
+ :rtype: ~silx.gui.plot.tools.ScatterToolBar
+ """
+ return self._scatterToolBar
+
+ def getScatterProfileToolBar(self):
+ """Returns QToolBar providing scatter profile tools.
+
+ :rtype: ~silx.gui.plot.tools.profile.ScatterProfileToolBar
+ """
+ return self._profileToolBar
+
+ def getOutputToolBar(self):
+ """Returns QToolBar containing save, copy and print actions
+
+ :rtype: ~silx.gui.plot.tools.OutputToolBar
+ """
+ return self._outputToolBar
+
+ def setColormap(self, colormap=None):
+ """Set the colormap for the displayed scatter and the
+ default plot colormap.
+
+ :param ~silx.gui.colors.Colormap colormap:
+ The description of the colormap.
+ """
+ self.getScatterItem().setColormap(colormap)
+ # Resilient to call to PlotWidget API (e.g., clear)
+ self.getPlotWidget().setDefaultColormap(colormap)
+
+ def getColormap(self):
+ """Return the colormap object in use.
+
+ :return: Colormap currently in use
+ :rtype: ~silx.gui.colors.Colormap
+ """
+ return self.getScatterItem().getColormap()
+
+ # Control displayed scatter plot
+
+ def setData(self, x, y, value, xerror=None, yerror=None, alpha=None, copy=True):
+ """Set the data of the scatter plot.
+
+ To reset the scatter plot, set x, y and value to None.
+
+ :param Union[numpy.ndarray,None] x: X coordinates.
+ :param Union[numpy.ndarray,None] y: Y coordinates.
+ :param Union[numpy.ndarray,None] value:
+ The data corresponding to the value of the data points.
+ :param xerror: Values with the uncertainties on the x values.
+ If it is an array, it can either be a 1D array of
+ same length as the data or a 2D array with 2 rows
+ of same length as the data: row 0 for positive errors,
+ row 1 for negative errors.
+ :type xerror: A float, or a numpy.ndarray of float32.
+
+ :param yerror: Values with the uncertainties on the y values
+ :type yerror: A float, or a numpy.ndarray of float32. See xerror.
+ :param alpha: Values with the transparency (between 0 and 1)
+ :type alpha: A float, or a numpy.ndarray of float32
+ :param bool copy: True make a copy of the data (default),
+ False to use provided arrays.
+ """
+ x = () if x is None else x
+ y = () if y is None else y
+ value = () if value is None else value
+
+ self.getScatterItem().setData(
+ x=x, y=y, value=value, xerror=xerror, yerror=yerror, alpha=alpha, copy=copy)
+
+ @docstring(items.Scatter)
+ def getData(self, *args, **kwargs):
+ return self.getScatterItem().getData(*args, **kwargs)
+
+ def getScatterItem(self):
+ """Returns the plot item displaying the scatter data.
+
+ This allows to set the style of the displayed scatter.
+
+ :rtype: ~silx.gui.plot.items.Scatter
+ """
+ plot = self.getPlotWidget()
+ scatter = plot._getItem(kind='scatter', legend=self._SCATTER_LEGEND)
+ if scatter is None: # Resilient to call to PlotWidget API (e.g., clear)
+ scatter = self.__createEmptyScatter()
+ return scatter
+
+ # Convenient proxies
+
+ @docstring(PlotWidget)
+ def getXAxis(self, *args, **kwargs):
+ return self.getPlotWidget().getXAxis(*args, **kwargs)
+
+ @docstring(PlotWidget)
+ def getYAxis(self, *args, **kwargs):
+ return self.getPlotWidget().getYAxis(*args, **kwargs)
+
+ @docstring(PlotWidget)
+ def setGraphTitle(self, *args, **kwargs):
+ return self.getPlotWidget().setGraphTitle(*args, **kwargs)
+
+ @docstring(PlotWidget)
+ def getGraphTitle(self, *args, **kwargs):
+ return self.getPlotWidget().getGraphTitle(*args, **kwargs)
+
+ @docstring(PlotWidget)
+ def resetZoom(self, *args, **kwargs):
+ return self.getPlotWidget().resetZoom(*args, **kwargs)
+
+ @docstring(ScatterMaskToolsWidget)
+ def getSelectionMask(self, *args, **kwargs):
+ return self.getMaskToolsWidget().getSelectionMask(*args, **kwargs)
+
+ @docstring(ScatterMaskToolsWidget)
+ def setSelectionMask(self, *args, **kwargs):
+ return self.getMaskToolsWidget().setSelectionMask(*args, **kwargs)
diff --git a/src/silx/gui/plot/StackView.py b/src/silx/gui/plot/StackView.py
new file mode 100644
index 0000000..56793d7
--- /dev/null
+++ b/src/silx/gui/plot/StackView.py
@@ -0,0 +1,1254 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""QWidget displaying a 3D volume as a stack of 2D images.
+
+The :class:`StackView` class implements this widget.
+
+Basic usage of :class:`StackView` is through the following methods:
+
+- :meth:`StackView.getColormap`, :meth:`StackView.setColormap` to update the
+ default colormap to use and update the currently displayed image.
+- :meth:`StackView.setStack` to update the displayed image.
+
+The :class:`StackView` uses :class:`PlotWindow` and also
+exposes a subset of the :class:`silx.gui.plot.Plot` API for further control
+(plot title, axes labels, ...).
+
+The :class:`StackViewMainWindow` class implements a widget that adds a status
+bar displaying the 3D index and the value under the mouse cursor.
+
+Example::
+
+ import numpy
+ import sys
+ from silx.gui import qt
+ from silx.gui.plot.StackView import StackViewMainWindow
+
+
+ app = qt.QApplication(sys.argv[1:])
+
+ # synthetic data, stack of 100 images of size 200x300
+ mystack = numpy.fromfunction(
+ lambda i, j, k: numpy.sin(i/15.) + numpy.cos(j/4.) + 2 * numpy.sin(k/6.),
+ (100, 200, 300)
+ )
+
+
+ sv = StackViewMainWindow()
+ sv.setColormap("jet", autoscale=True)
+ sv.setStack(mystack)
+ sv.setLabels(["1st dim (0-99)", "2nd dim (0-199)",
+ "3rd dim (0-299)"])
+ sv.show()
+
+ app.exec()
+
+"""
+
+__authors__ = ["P. Knobel", "H. Payno"]
+__license__ = "MIT"
+__date__ = "10/10/2018"
+
+import numpy
+import logging
+
+import silx
+from silx.gui import qt
+from .. import icons
+from . import items, PlotWindow, actions
+from .items.image import ImageStack
+from ..colors import Colormap
+from ..colors import cursorColorForColormap
+from .tools import LimitsToolBar
+from .Profile import Profile3DToolBar
+from ..widgets.FrameBrowser import HorizontalSliderWithBrowser
+
+from silx.gui.plot.actions import control as actions_control
+from silx.gui.plot.actions import io as silx_io
+from silx.io.nxdata import save_NXdata
+from silx.utils.array_like import DatasetView, ListOfImages
+from silx.math import calibration
+from silx.utils.deprecation import deprecated_warning
+from silx.utils.deprecation import deprecated
+
+import h5py
+from silx.io.utils import is_dataset
+
+_logger = logging.getLogger(__name__)
+
+
+class StackView(qt.QMainWindow):
+ """Stack view widget, to display and browse through stack of
+ images.
+
+ The profile tool can be switched to "3D" mode, to compute the profile
+ on each image of the stack (not only the active image currently displayed)
+ and display the result as a slice.
+
+ :param QWidget parent: the Qt parent, or None
+ :param backend: The backend to use for the plot (default: matplotlib).
+ See :class:`.PlotWidget` for the list of supported backend.
+ :type backend: str or :class:`BackendBase.BackendBase`
+ :param bool resetzoom: Toggle visibility of reset zoom action.
+ :param bool autoScale: Toggle visibility of axes autoscale actions.
+ :param bool logScale: Toggle visibility of axes log scale actions.
+ :param bool grid: Toggle visibility of grid mode action.
+ :param bool colormap: Toggle visibility of colormap action.
+ :param bool aspectRatio: Toggle visibility of aspect ratio button.
+ :param bool yInverted: Toggle visibility of Y axis direction button.
+ :param bool copy: Toggle visibility of copy action.
+ :param bool save: Toggle visibility of save action.
+ :param bool print_: Toggle visibility of print action.
+ :param bool control: True to display an Options button with a sub-menu
+ to show legends, toggle crosshair and pan with arrows.
+ (Default: False)
+ :param position: True to display widget with (x, y) mouse position
+ (Default: False).
+ It also supports a list of (name, funct(x, y)->value)
+ to customize the displayed values.
+ See :class:`silx.gui.plot.PlotTools.PositionInfo`.
+ :param bool mask: Toggle visibilty of mask action.
+ """
+ # Qt signals
+ valueChanged = qt.Signal(object, object, object)
+ """Signals that the data value under the cursor has changed.
+
+ It provides: row, column, data value.
+ """
+
+ sigPlaneSelectionChanged = qt.Signal(int)
+ """Signal emitted when there is a change is perspective/displayed axes.
+
+ It provides the perspective as an integer, with the following meaning:
+
+ - 0: axis Y is the 2nd dimension, axis X is the 3rd dimension
+ - 1: axis Y is the 1st dimension, axis X is the 3rd dimension
+ - 2: axis Y is the 1st dimension, axis X is the 2nd dimension
+ """
+
+ sigStackChanged = qt.Signal(int)
+ """Signal emitted when the stack is changed.
+ This happens when a new volume is loaded, or when the current volume
+ is transposed (change in perspective).
+
+ The signal provides the size (number of pixels) of the stack.
+ This will be 0 if the stack is cleared, else it will be a positive
+ integer.
+ """
+
+ sigFrameChanged = qt.Signal(int)
+ """Signal emitter when the frame number has changed.
+
+ This signal provides the current frame number.
+ """
+
+ IMAGE_STACK_FILTER_NXDATA = 'Stack of images as NXdata (%s)' % silx_io._NEXUS_HDF5_EXT_STR
+
+
+ def __init__(self, parent=None, resetzoom=True, backend=None,
+ autoScale=False, logScale=False, grid=False,
+ colormap=True, aspectRatio=True, yinverted=True,
+ copy=True, save=True, print_=True, control=False,
+ position=None, mask=True):
+ qt.QMainWindow.__init__(self, parent)
+ if parent is not None:
+ # behave as a widget
+ self.setWindowFlags(qt.Qt.Widget)
+ else:
+ self.setWindowTitle('StackView')
+
+ self._stack = None
+ """Loaded stack, as a 3D array, a 3D dataset or a list of 2D arrays."""
+ self.__transposed_view = None
+ """View on :attr:`_stack` with the axes sorted, to have
+ the orthogonal dimension first"""
+ self._perspective = 0
+ """Orthogonal dimension (depth) in :attr:`_stack`"""
+
+ self._stackItem = ImageStack()
+ """Hold the item displaying the stack"""
+ imageLegend = '__StackView__image' + str(id(self))
+ self._stackItem.setName(imageLegend)
+
+ self.__autoscaleCmap = False
+ """Flag to disable/enable colormap auto-scaling
+ based on the min/max values of the entire 3D volume"""
+ self.__dimensionsLabels = ["Dimension 0", "Dimension 1",
+ "Dimension 2"]
+ """These labels are displayed on the X and Y axes.
+ :meth:`setLabels` updates this attribute."""
+
+ self._first_stack_dimension = 0
+ """Used for dimension labels and combobox"""
+
+ self._titleCallback = self._defaultTitleCallback
+ """Function returning the plot title based on the frame index.
+ It can be set to a custom function using :meth:`setTitleCallback`"""
+
+ self.calibrations3D = (calibration.NoCalibration(),
+ calibration.NoCalibration(),
+ calibration.NoCalibration())
+
+ central_widget = qt.QWidget(self)
+
+ self._plot = PlotWindow(parent=central_widget, backend=backend,
+ resetzoom=resetzoom, autoScale=autoScale,
+ logScale=logScale, grid=grid,
+ curveStyle=False, colormap=colormap,
+ aspectRatio=aspectRatio, yInverted=yinverted,
+ copy=copy, save=save, print_=print_,
+ control=control, position=position,
+ roi=False, mask=mask)
+ self._plot.addItem(self._stackItem)
+ self._plot.getIntensityHistogramAction().setVisible(True)
+ self.sigInteractiveModeChanged = self._plot.sigInteractiveModeChanged
+ self.sigActiveImageChanged = self._plot.sigActiveImageChanged
+ self.sigPlotSignal = self._plot.sigPlotSignal
+
+ if silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION == 'downward':
+ self._plot.getYAxis().setInverted(True)
+
+ self._addColorBarAction()
+
+ self._profileToolBar = Profile3DToolBar(parent=self._plot,
+ stackview=self)
+ self._plot.addToolBar(self._profileToolBar)
+ self._plot.getXAxis().setLabel('Columns')
+ self._plot.getYAxis().setLabel('Rows')
+ self._plot.sigPlotSignal.connect(self._plotCallback)
+ self._plot.getSaveAction().setFileFilter('image', self.IMAGE_STACK_FILTER_NXDATA, func=self._saveImageStack, appendToFile=True)
+
+ self.__planeSelection = PlanesWidget(self._plot)
+ self.__planeSelection.sigPlaneSelectionChanged.connect(self.setPerspective)
+
+ self._browser_label = qt.QLabel("Image index (Dim0):")
+
+ self._browser = HorizontalSliderWithBrowser(central_widget)
+ self._browser.setRange(0, 0)
+ self._browser.valueChanged[int].connect(self.__updateFrameNumber)
+ self._browser.setEnabled(False)
+
+ layout = qt.QGridLayout()
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.addWidget(self._plot, 0, 0, 1, 3)
+ layout.addWidget(self.__planeSelection, 1, 0)
+ layout.addWidget(self._browser_label, 1, 1)
+ layout.addWidget(self._browser, 1, 2)
+
+ central_widget.setLayout(layout)
+ self.setCentralWidget(central_widget)
+
+ # clear profile lines when the perspective changes (plane browsed changed)
+ self.__planeSelection.sigPlaneSelectionChanged.connect(
+ self._profileToolBar.clearProfile)
+
+ def _saveImageStack(self, plot, filename, nameFilter):
+ """Save all images from the stack into a volume.
+
+ :param str filename: The name of the file to write
+ :param str nameFilter: The selected name filter
+ :return: False if format is not supported or save failed,
+ True otherwise.
+ :raises: ValueError if nameFilter is invalid
+ """
+ if not nameFilter == self.IMAGE_STACK_FILTER_NXDATA:
+ raise ValueError('Wrong callback')
+ entryPath = silx_io.SaveAction._selectWriteableOutputGroup(filename, parent=self)
+ if entryPath is None:
+ return False
+ return save_NXdata(filename,
+ nxentry_name=entryPath,
+ signal=self.getStack(copy=False, returnNumpyArray=True)[0],
+ signal_name="image_stack")
+
+ def _addColorBarAction(self):
+ self._plot.getColorBarWidget().setVisible(True)
+ actions = self._plot.toolBar().actions()
+ for index, action in enumerate(actions):
+ if action is self._plot.getColormapAction():
+ break
+ self._colorbarAction = actions_control.ColorBarAction(self._plot, self._plot)
+ self._plot.toolBar().insertAction(actions[index + 1], self._colorbarAction)
+
+ def _plotCallback(self, eventDict):
+ """Callback for plot events.
+
+ Emit :attr:`valueChanged` signal, with (x, y, value) tuple of the
+ cursor location in the plot."""
+ if eventDict['event'] == 'mouseMoved':
+ activeImage = self.getActiveImage()
+ if activeImage is not None:
+ data = activeImage.getData()
+ height, width = data.shape
+
+ # Get corresponding coordinate in image
+ origin = activeImage.getOrigin()
+ scale = activeImage.getScale()
+ x = int((eventDict['x'] - origin[0]) / scale[0])
+ y = int((eventDict['y'] - origin[1]) / scale[1])
+
+ if 0 <= x < width and 0 <= y < height:
+ self.valueChanged.emit(float(x), float(y),
+ data[y][x])
+ else:
+ self.valueChanged.emit(float(x), float(y),
+ None)
+
+ def getPerspective(self):
+ """Returns the index of the dimension the stack is browsed with
+
+ Possible values are: 0, 1, or 2.
+
+ :rtype: int
+ """
+ return self._perspective
+
+ def setPerspective(self, perspective):
+ """Set the index of the dimension the stack is browsed with:
+
+ - slice plane Dim1-Dim2: perspective 0
+ - slice plane Dim0-Dim2: perspective 1
+ - slice plane Dim0-Dim1: perspective 2
+
+ :param int perspective: Orthogonal dimension number (0, 1, or 2)
+ """
+ if perspective == self._perspective:
+ return
+ else:
+ if perspective > 2 or perspective < 0:
+ raise ValueError(
+ "Perspective must be 0, 1 or 2, not %s" % perspective)
+
+ self._perspective = int(perspective)
+ self.__createTransposedView()
+ self.__updateFrameNumber(self._browser.value())
+ self._plot.resetZoom()
+ self.__updatePlotLabels()
+ self._updateTitle()
+ self._browser_label.setText("Image index (Dim%d):" %
+ (self._first_stack_dimension + perspective))
+
+ self.sigPlaneSelectionChanged.emit(perspective)
+ self.sigStackChanged.emit(self._stack.size if
+ self._stack is not None else 0)
+ self.__planeSelection.sigPlaneSelectionChanged.disconnect(self.setPerspective)
+ self.__planeSelection.setPerspective(self._perspective)
+ self.__planeSelection.sigPlaneSelectionChanged.connect(self.setPerspective)
+
+ def __updatePlotLabels(self):
+ """Update plot axes labels depending on perspective"""
+ y, x = (1, 2) if self._perspective == 0 else \
+ (0, 2) if self._perspective == 1 else (0, 1)
+ self.setGraphXLabel(self.__dimensionsLabels[x])
+ self.setGraphYLabel(self.__dimensionsLabels[y])
+
+ def __createTransposedView(self):
+ """Create the new view on the stack depending on the perspective
+ (set orthogonal axis browsed on the viewer as first dimension)
+ """
+ assert self._stack is not None
+ assert 0 <= self._perspective < 3
+
+ # ensure we have the stack encapsulated in an array-like object
+ # having a transpose() method
+ if isinstance(self._stack, numpy.ndarray):
+ self.__transposed_view = self._stack
+
+ elif is_dataset(self._stack) or isinstance(self._stack, DatasetView):
+ self.__transposed_view = DatasetView(self._stack)
+
+ elif isinstance(self._stack, ListOfImages):
+ self.__transposed_view = ListOfImages(self._stack)
+
+ # transpose the array-like object if necessary
+ if self._perspective == 1:
+ self.__transposed_view = self.__transposed_view.transpose((1, 0, 2))
+ elif self._perspective == 2:
+ self.__transposed_view = self.__transposed_view.transpose((2, 0, 1))
+
+ self._browser.setRange(0, self.__transposed_view.shape[0] - 1)
+ self._browser.setValue(0)
+
+ # Update the item structure
+ self._stackItem.setStackData(self.__transposed_view, 0, copy=False)
+ self._stackItem.setColormap(self.getColormap())
+ self._stackItem.setOrigin(self._getImageOrigin())
+ self._stackItem.setScale(self._getImageScale())
+
+ def __updateFrameNumber(self, index):
+ """Update the current image.
+
+ :param index: index of the frame to be displayed
+ """
+ if self.__transposed_view is None:
+ # no data set
+ return
+
+ self._stackItem.setStackPosition(index)
+
+ self._updateTitle()
+ self.sigFrameChanged.emit(index)
+
+ def _set3DScaleAndOrigin(self, calibrations):
+ """Set scale and origin for all 3 axes, to be used when plotting
+ an image.
+
+ See setStack for parameter documentation
+ """
+ if calibrations is None:
+ self.calibrations3D = (calibration.NoCalibration(),
+ calibration.NoCalibration(),
+ calibration.NoCalibration())
+ else:
+ self.calibrations3D = []
+ for i, calib in enumerate(calibrations):
+ if hasattr(calib, "__len__") and len(calib) == 2:
+ calib = calibration.LinearCalibration(calib[0], calib[1])
+ elif calib is None:
+ calib = calibration.NoCalibration()
+ elif not isinstance(calib, calibration.AbstractCalibration):
+ raise TypeError("calibration must be a 2-tuple, None or" +
+ " an instance of an AbstractCalibration " +
+ "subclass")
+ elif not calib.is_affine():
+ _logger.warning(
+ "Calibration for dimension %d is not linear, "
+ "it will be ignored for scaling the graph axes.",
+ i)
+ self.calibrations3D.append(calib)
+
+ def getCalibrations(self, order='array'):
+ """Returns currently used calibrations for each axis
+
+ Returned calibrations might differ from the ones that were set as
+ non-linear calibrations used for image axes are temporarily ignored.
+
+ :param str order:
+ 'array' to sort calibrations as data array (dim0, dim1, dim2),
+ 'axes' to sort calibrations as currently selected x, y and z axes.
+ :return: Calibrations ordered depending on order
+ :rtype: List[~silx.math.calibration.AbstractCalibration]
+ """
+ assert order in ('array', 'axes')
+ calibs = []
+
+ # filter out non-linear calibration for graph axes
+ for index, calib in enumerate(self.calibrations3D):
+ if index != self._perspective and not calib.is_affine():
+ calib = calibration.NoCalibration()
+ calibs.append(calib)
+
+ if order == 'axes': # Move 'z' axis to the end
+ xy_dims = [d for d in (0, 1, 2) if d != self._perspective]
+ calibs = [calibs[max(xy_dims)],
+ calibs[min(xy_dims)],
+ calibs[self._perspective]]
+
+ return tuple(calibs)
+
+ def _getImageScale(self):
+ """
+ :return: 2-tuple (XScale, YScale) for current image view
+ """
+ xcalib, ycalib, _zcalib = self.getCalibrations(order='axes')
+ return xcalib.get_slope(), ycalib.get_slope()
+
+ def _getImageOrigin(self):
+ """
+ :return: 2-tuple (XOrigin, YOrigin) for current image view
+ """
+ xcalib, ycalib, _zcalib = self.getCalibrations(order='axes')
+ return xcalib(0), ycalib(0)
+
+ def _getImageZ(self, index):
+ """
+ :param idx: 0-based image index in the stack
+ :return: calibrated Z value corresponding to the image idx
+ """
+ _xcalib, _ycalib, zcalib = self.getCalibrations(order='axes')
+ return zcalib(index)
+
+ def _updateTitle(self):
+ frame_idx = self._browser.value()
+ self._plot.setGraphTitle(self._titleCallback(frame_idx))
+
+ def _defaultTitleCallback(self, index):
+ return "Image z=%g" % self._getImageZ(index)
+
+ # public API, stack specific methods
+ def setStack(self, stack, perspective=None, reset=True, calibrations=None):
+ """Set the 3D stack.
+
+ The perspective parameter is used to define which dimension of the 3D
+ array is to be used as frame index. The lowest remaining dimension
+ number is the row index of the displayed image (Y axis), and the highest
+ remaining dimension is the column index (X axis).
+
+ :param stack: 3D stack, or `None` to clear plot.
+ :type stack: 3D numpy.ndarray, or 3D h5py.Dataset, or list/tuple of 2D
+ numpy arrays, or None.
+ :param int perspective: Dimension for the frame index: 0, 1 or 2.
+ Use ``None`` to keep the current perspective (default).
+ :param bool reset: Whether to reset zoom or not.
+ :param calibrations: Sequence of 3 calibration objects for each axis.
+ These objects can be a subclass of :class:`AbstractCalibration`,
+ or 2-tuples *(a, b)* where *a* is the y-intercept and *b* is the
+ slope of a linear calibration (:math:`x \\mapsto a + b x`)
+ """
+ if stack is None:
+ self.clear()
+ self.sigStackChanged.emit(0)
+ return
+
+ self._set3DScaleAndOrigin(calibrations)
+
+ # stack as list of 2D arrays: must be converted into an array_like
+ if not isinstance(stack, numpy.ndarray):
+ if not is_dataset(stack):
+ try:
+ assert hasattr(stack, "__len__")
+ for img in stack:
+ assert hasattr(img, "shape")
+ assert len(img.shape) == 2
+ except AssertionError:
+ raise ValueError(
+ "Stack must be a 3D array/dataset or a list of " +
+ "2D arrays.")
+ stack = ListOfImages(stack)
+
+ assert len(stack.shape) == 3, "data must be 3D"
+
+ self._stack = stack
+ self.__createTransposedView()
+
+ perspective_changed = False
+ if perspective not in [None, self._perspective]:
+ perspective_changed = True
+ self.setPerspective(perspective)
+
+ if self.__autoscaleCmap:
+ self.scaleColormapRangeToStack()
+
+ # init plot
+ self._stackItem.setStackData(self.__transposed_view, 0, copy=False)
+ self._stackItem.setColormap(self.getColormap())
+ self._stackItem.setOrigin(self._getImageOrigin())
+ self._stackItem.setScale(self._getImageScale())
+ self._stackItem.setVisible(True)
+
+ # Put back the item in the plot in case it was cleared
+ exists = self._plot.getImage(self._stackItem.getName())
+ if exists is None:
+ self._plot.addItem(self._stackItem)
+
+ self._plot.setActiveImage(self._stackItem.getName())
+ self.__updatePlotLabels()
+ self._updateTitle()
+
+ if reset:
+ self._plot.resetZoom()
+
+ # enable and init browser
+ self._browser.setEnabled(True)
+
+ if not perspective_changed: # avoid double signal (see self.setPerspective)
+ self.sigStackChanged.emit(stack.size)
+
+ def getStack(self, copy=True, returnNumpyArray=False):
+ """Get the original stack, as a 3D array or dataset.
+
+ The output has the form: [data, params]
+ where params is a dictionary containing display parameters.
+
+ :param bool copy: If True (default), then the object is copied
+ and returned as a numpy array.
+ Else, a reference to original data is returned, if possible.
+ If the original data is not a numpy array and parameter
+ returnNumpyArray is True, a copy will be made anyway.
+ :param bool returnNumpyArray: If True, the returned object is
+ guaranteed to be a numpy array.
+ :return: 3D stack and parameters.
+ :rtype: (numpy.ndarray, dict)
+ """
+ if self._stack is None:
+ return None
+
+ image = self._stackItem
+ colormap = image.getColormap()
+
+ params = {
+ 'info': image.getInfo(),
+ 'origin': image.getOrigin(),
+ 'scale': image.getScale(),
+ 'z': image.getZValue(),
+ 'selectable': image.isSelectable(),
+ 'draggable': image.isDraggable(),
+ 'colormap': colormap,
+ 'xlabel': image.getXLabel(),
+ 'ylabel': image.getYLabel(),
+ }
+ if returnNumpyArray or copy:
+ return numpy.array(self._stack, copy=copy), params
+
+ # if a list of 2D arrays was cast into a ListOfImages,
+ # return the original list
+ if isinstance(self._stack, ListOfImages):
+ return self._stack.images, params
+
+ return self._stack, params
+
+ def getCurrentView(self, copy=True, returnNumpyArray=False):
+ """Get the stack, as it is currently displayed.
+
+ The first index of the returned stack is always the frame
+ index. If the perspective has been changed in the widget since the
+ data was first loaded, this will be reflected in the order of the
+ dimensions of the returned object.
+
+ The output has the form: [data, params]
+ where params is a dictionary containing display parameters.
+
+ :param bool copy: If True (default), then the object is copied
+ and returned as a numpy array.
+ Else, a reference to original data is returned, if possible.
+ If the original data is not a numpy array and parameter
+ `returnNumpyArray` is `True`, a copy will be made anyway.
+ :param bool returnNumpyArray: If `True`, the returned object is
+ guaranteed to be a numpy array.
+ :return: 3D stack and parameters.
+ :rtype: (numpy.ndarray, dict)
+ """
+ image = self.getActiveImage()
+ if image is None:
+ return None
+
+ if isinstance(image, items.ColormapMixIn):
+ colormap = image.getColormap()
+ else:
+ colormap = None
+
+ params = {
+ 'info': image.getInfo(),
+ 'origin': image.getOrigin(),
+ 'scale': image.getScale(),
+ 'z': image.getZValue(),
+ 'selectable': image.isSelectable(),
+ 'draggable': image.isDraggable(),
+ 'colormap': colormap,
+ 'xlabel': image.getXLabel(),
+ 'ylabel': image.getYLabel(),
+ }
+ if returnNumpyArray or copy:
+ return numpy.array(self.__transposed_view, copy=copy), params
+ return self.__transposed_view, params
+
+ def setFrameNumber(self, number):
+ """Set the frame selection to a specific value
+
+ :param int number: Number of the frame
+ """
+ self._browser.setValue(number)
+
+ def getFrameNumber(self):
+ """Set the frame selection to a specific value
+
+ :return: Index of currently displayed frame
+ :rtype: int
+ """
+ return self._browser.value()
+
+ def setFirstStackDimension(self, first_stack_dimension):
+ """When viewing the last 3 dimensions of an n-D array (n>3), you can
+ use this method to change the text in the combobox.
+
+ For instance, for a 7-D array, first stack dim is 4, so the default
+ "Dim1-Dim2" text should be replaced with "Dim5-Dim6" (dimensions
+ numbers are 0-based).
+
+ :param int first_stack_dim: First stack dimension (n-3) when viewing the
+ last 3 dimensions of an n-D array.
+ """
+ old_state = self.__planeSelection.blockSignals(True)
+ self.__planeSelection.setFirstStackDimension(first_stack_dimension)
+ self.__planeSelection.blockSignals(old_state)
+ self._first_stack_dimension = first_stack_dimension
+ self._browser_label.setText("Image index (Dim%d):" % first_stack_dimension)
+
+ def setTitleCallback(self, callback):
+ """Set a user defined function to generate the plot title based on the
+ image/frame index.
+
+ The callback function must accept an integer as a its first positional
+ parameter and must not require any other mandatory parameter.
+ It must return a string.
+
+ To switch back the default behavior, you can pass ``None``::
+
+ mystackview.setTitleCallback(None)
+
+ To have no title, pass a function that returns an empty string::
+
+ mystackview.setTitleCallback(lambda idx: "")
+
+ :param callback: Callback function generating the stack title based
+ on the frame number.
+ """
+
+ if callback is None:
+ self._titleCallback = self._defaultTitleCallback
+ elif callable(callback):
+ self._titleCallback = callback
+ else:
+ raise TypeError("Provided callback is not callable")
+ self._updateTitle()
+
+ def clear(self):
+ """Clear the widget:
+
+ - clear the plot
+ - clear the loaded data volume
+ """
+ self._stack = None
+ self.__transposed_view = None
+ self._perspective = 0
+ self._browser.setEnabled(False)
+ # reset browser range
+ self._browser.setRange(0, 0)
+ self._plot.clear()
+
+ def setLabels(self, labels=None):
+ """Set the labels to be displayed on the plot axes.
+
+ You must provide a sequence of 3 strings, corresponding to the 3
+ dimensions of the original data volume.
+ The proper label will automatically be selected for each plot axis
+ when the volume is rotated (when different axes are selected as the
+ X and Y axes).
+
+ :param List[str] labels: 3 labels corresponding to the 3 dimensions
+ of the data volumes.
+ """
+
+ default_labels = ["Dimension %d" % self._first_stack_dimension,
+ "Dimension %d" % (self._first_stack_dimension + 1),
+ "Dimension %d" % (self._first_stack_dimension + 2)]
+ if labels is None:
+ new_labels = default_labels
+ else:
+ # filter-out None
+ new_labels = []
+ for i, label in enumerate(labels):
+ new_labels.append(label or default_labels[i])
+
+ self.__dimensionsLabels = new_labels
+ self.__updatePlotLabels()
+
+ def getLabels(self):
+ """Return dimension labels displayed on the plot axes
+
+ :return: List of three strings corresponding to the 3 dimensions
+ of the stack: (name_dim0, name_dim1, name_dim2)
+ """
+ return self.__dimensionsLabels
+
+ def getColormap(self):
+ """Get the current colormap description.
+
+ :return: A description of the current colormap.
+ See :meth:`setColormap` for details.
+ :rtype: dict
+ """
+ # "default" colormap used by addImage when image is added without
+ # specifying a special colormap
+ return self._plot.getDefaultColormap()
+
+ def scaleColormapRangeToStack(self):
+ """Scale colormap range according to current stack data.
+
+ If no stack has been set through :meth:`setStack`, this has no effect.
+
+ The range scaling mode is given by current :class:`Colormap`'s
+ :meth:`Colormap.getAutoscaleMode`.
+ """
+ stack = self.getStack(copy=False, returnNumpyArray=True)
+ if stack is None:
+ return # No-op
+
+ colormap = self.getColormap()
+ vmin, vmax = colormap.getColormapRange(data=stack[0])
+ colormap.setVRange(vmin=vmin, vmax=vmax)
+
+ def setColormap(self, colormap=None, normalization=None,
+ autoscale=None, vmin=None, vmax=None, colors=None):
+ """Set the colormap and update active image.
+
+ Parameters that are not provided are taken from the current colormap.
+
+ The colormap parameter can also be a dict with the following keys:
+
+ - *name*: string. The colormap to use:
+ 'gray', 'reversed gray', 'temperature', 'red', 'green', 'blue'.
+ - *normalization*: string. The mapping to use for the colormap:
+ either 'linear' or 'log'.
+ - *autoscale*: bool. Whether to use autoscale (True) or range
+ provided by keys
+ 'vmin' and 'vmax' (False).
+ - *vmin*: float. The minimum value of the range to use if 'autoscale'
+ is False.
+ - *vmax*: float. The maximum value of the range to use if 'autoscale'
+ is False.
+ - *colors*: optional. Nx3 or Nx4 array of float in [0, 1] or uint8.
+ List of RGB or RGBA colors to use (only if name is None)
+
+ :param colormap: Name of the colormap in
+ 'gray', 'reversed gray', 'temperature', 'red', 'green', 'blue'.
+ Or a :class`.Colormap` object.
+ :type colormap: dict or str.
+ :param str normalization: Colormap mapping: 'linear' or 'log'.
+ :param bool autoscale: Whether to use autoscale or [vmin, vmax] range.
+ Default value of autoscale is False. This option is not compatible
+ with h5py datasets.
+ :param float vmin: The minimum value of the range to use if
+ 'autoscale' is False.
+ :param float vmax: The maximum value of the range to use if
+ 'autoscale' is False.
+ :param numpy.ndarray colors: Only used if name is None.
+ Custom colormap colors as Nx3 or Nx4 RGB or RGBA arrays
+ """
+ # if is a colormap object or a dictionary
+ if isinstance(colormap, Colormap) or isinstance(colormap, dict):
+ # Support colormap parameter as a dict
+ errmsg = "If colormap is provided as a Colormap object, all other parameters"
+ errmsg += " must not be specified when calling setColormap"
+ assert normalization is None, errmsg
+ assert autoscale is None, errmsg
+ assert vmin is None, errmsg
+ assert vmax is None, errmsg
+ assert colors is None, errmsg
+
+ if isinstance(colormap, dict):
+ reason = 'colormap parameter should now be an object'
+ replacement = 'Colormap()'
+ since_version = '0.6'
+ deprecated_warning(type_='function',
+ name='setColormap',
+ reason=reason,
+ replacement=replacement,
+ since_version=since_version)
+ _colormap = Colormap._fromDict(colormap)
+ else:
+ _colormap = colormap
+ else:
+ norm = normalization if normalization is not None else 'linear'
+ name = colormap if colormap is not None else 'gray'
+ _colormap = Colormap(name=name,
+ normalization=norm,
+ vmin=vmin,
+ vmax=vmax,
+ colors=colors)
+
+ if autoscale is not None:
+ deprecated_warning(
+ type_='function',
+ name='setColormap',
+ reason='autoscale argument is replaced by a method',
+ replacement='scaleColormapRangeToStack',
+ since_version='0.14')
+ self.__autoscaleCmap = bool(autoscale)
+
+ cursorColor = cursorColorForColormap(_colormap.getName())
+ self._plot.setInteractiveMode('zoom', color=cursorColor)
+
+ self._plot.setDefaultColormap(_colormap)
+
+ # Update active image colormap
+ activeImage = self.getActiveImage()
+ if isinstance(activeImage, items.ColormapMixIn):
+ activeImage.setColormap(self.getColormap())
+
+ if self.__autoscaleCmap:
+ # scaleColormapRangeToStack needs to be called **after**
+ # setDefaultColormap so getColormap returns the right colormap
+ self.scaleColormapRangeToStack()
+
+
+ @deprecated(replacement="getPlotWidget", since_version="0.13")
+ def getPlot(self):
+ return self.getPlotWidget()
+
+ def getPlotWidget(self):
+ """Return the :class:`PlotWidget`.
+
+ This gives access to advanced plot configuration options.
+ Be warned that modifying the plot can cause issues, and some changes
+ you make to the plot could be overwritten by the :class:`StackView`
+ widget's internal methods and callbacks.
+
+ :return: instance of :class:`PlotWidget` used in widget
+ """
+ return self._plot
+
+ def setOptionVisible(self, isVisible):
+ """
+ Set the visibility of the browsing options.
+
+ :param bool isVisible: True to have the options visible, else False
+ """
+ self._browser.setVisible(isVisible)
+ self.__planeSelection.setVisible(isVisible)
+
+ # proxies to PlotWidget or PlotWindow methods
+ def getProfileToolbar(self):
+ """Profile tools attached to this plot
+ """
+ return self._profileToolBar
+
+ def getGraphTitle(self):
+ """Return the plot main title as a str.
+ """
+ return self._plot.getGraphTitle()
+
+ def setGraphTitle(self, title=""):
+ """Set the plot main title.
+
+ :param str title: Main title of the plot (default: '')
+ """
+ return self._plot.setGraphTitle(title)
+
+ def getGraphXLabel(self):
+ """Return the current horizontal axis label as a str.
+ """
+ return self._plot.getXAxis().getLabel()
+
+ def setGraphXLabel(self, label=None):
+ """Set the plot horizontal axis label.
+
+ :param str label: The horizontal axis label
+ """
+ if label is None:
+ label = self.__dimensionsLabels[1 if self._perspective == 2 else 2]
+ self._plot.getXAxis().setLabel(label)
+
+ def getGraphYLabel(self, axis='left'):
+ """Return the current vertical axis label as a str.
+
+ :param str axis: The Y axis for which to get the label (left or right)
+ """
+ return self._plot.getYAxis().getLabel(axis)
+
+ def setGraphYLabel(self, label=None, axis='left'):
+ """Set the vertical axis label on the plot.
+
+ :param str label: The Y axis label
+ :param str axis: The Y axis for which to set the label (left or right)
+ """
+ if label is None:
+ label = self.__dimensionsLabels[1 if self._perspective == 0 else 0]
+ self._plot.getYAxis(axis=axis).setLabel(label)
+
+ def resetZoom(self):
+ """Reset the plot limits to the bounds of the data and redraw the plot.
+
+ This method is a simple proxy to the legacy :class:`PlotWidget` method
+ of the same name. Using the object oriented approach is now
+ preferred::
+
+ stackview.getPlot().resetZoom()
+ """
+ self._plot.resetZoom()
+
+ def setYAxisInverted(self, flag=True):
+ """Set the Y axis orientation.
+
+ This method is a simple proxy to the legacy :class:`PlotWidget` method
+ of the same name. Using the object oriented approach is now
+ preferred::
+
+ stackview.getPlot().setYAxisInverted(flag)
+
+ :param bool flag: True for Y axis going from top to bottom,
+ False for Y axis going from bottom to top
+ """
+ self._plot.setYAxisInverted(flag)
+
+ def isYAxisInverted(self):
+ """Return True if Y axis goes from top to bottom, False otherwise.
+
+ This method is a simple proxy to the legacy :class:`PlotWidget` method
+ of the same name. Using the object oriented approach is now
+ preferred::
+
+ stackview.getPlot().isYAxisInverted()"""
+ return self._plot.isYAxisInverted()
+
+ def getSupportedColormaps(self):
+ """Get the supported colormap names as a tuple of str.
+
+ The list should at least contain and start by:
+ ('gray', 'reversed gray', 'temperature', 'red', 'green', 'blue')
+
+ This method is a simple proxy to the legacy :class:`PlotWidget` method
+ of the same name. Using the object oriented approach is now
+ preferred::
+
+ stackview.getPlot().getSupportedColormaps()
+ """
+ return self._plot.getSupportedColormaps()
+
+ def isKeepDataAspectRatio(self):
+ """Returns whether the plot is keeping data aspect ratio or not.
+
+ This method is a simple proxy to the legacy :class:`PlotWidget` method
+ of the same name. Using the object oriented approach is now
+ preferred::
+
+ stackview.getPlot().isKeepDataAspectRatio()"""
+ return self._plot.isKeepDataAspectRatio()
+
+ def setKeepDataAspectRatio(self, flag=True):
+ """Set whether the plot keeps data aspect ratio or not.
+
+ This method is a simple proxy to the legacy :class:`PlotWidget` method
+ of the same name. Using the object oriented approach is now
+ preferred::
+
+ stackview.getPlot().setKeepDataAspectRatio(flag)
+
+ :param bool flag: True to respect data aspect ratio
+ """
+ self._plot.setKeepDataAspectRatio(flag)
+
+ # kind of private methods, but needed by Profile
+ def getActiveImage(self, just_legend=False):
+ """Returns the stack image object.
+ """
+ if just_legend:
+ return self._stackItem.getName()
+ return self._stackItem
+
+ def getColorBarAction(self):
+ """Returns the action managing the visibility of the colorbar.
+
+ .. warning:: to show/hide the plot colorbar call directly the ColorBar
+ widget using getColorBarWidget()
+
+ :rtype: QAction
+ """
+ return self._colorbarAction
+
+ def remove(self, legend=None,
+ kind=('curve', 'image', 'item', 'marker')):
+ """See :meth:`Plot.Plot.remove`"""
+ self._plot.remove(legend, kind)
+
+ def setInteractiveMode(self, *args, **kwargs):
+ """
+ See :meth:`Plot.Plot.setInteractiveMode`
+ """
+ self._plot.setInteractiveMode(*args, **kwargs)
+
+ @deprecated(replacement="addShape", since_version="0.13")
+ def addItem(self, *args, **kwargs):
+ self.addShape(*args, **kwargs)
+
+ def addShape(self, *args, **kwargs):
+ """
+ See :meth:`Plot.Plot.addShape`
+ """
+ self._plot.addShape(*args, **kwargs)
+
+
+class PlanesWidget(qt.QWidget):
+ """Widget for the plane/perspective selection
+
+ :param parent: the parent QWidget
+ """
+ sigPlaneSelectionChanged = qt.Signal(int)
+
+ def __init__(self, parent):
+ super(PlanesWidget, self).__init__(parent)
+
+ self.setSizePolicy(qt.QSizePolicy.Minimum, qt.QSizePolicy.Minimum)
+ layout0 = qt.QHBoxLayout()
+ self.setLayout(layout0)
+ layout0.setContentsMargins(0, 0, 0, 0)
+
+ layout0.addWidget(qt.QLabel("Axes selection:"))
+
+ # By default, the first dimension (dim0) is the frame index/depth/z,
+ # the second dimension is the image row number/y axis
+ # and the third dimension is the image column index/x axis
+
+ # 1
+ # | 0
+ # |/__2
+ self.qcbAxisSelection = qt.QComboBox(self)
+ self._setCBChoices(first_stack_dimension=0)
+ self.qcbAxisSelection.currentIndexChanged[int].connect(
+ self.__planeSelectionChanged)
+
+ layout0.addWidget(self.qcbAxisSelection)
+
+ def __planeSelectionChanged(self, idx):
+ """Callback function when the combobox selection changes
+
+ idx is the dimension number orthogonal to the slice plane,
+ following the convention:
+
+ - slice plane Dim1-Dim2: perspective 0
+ - slice plane Dim0-Dim2: perspective 1
+ - slice plane Dim0-Dim1: perspective 2
+ """
+ self.sigPlaneSelectionChanged.emit(idx)
+
+ def _setCBChoices(self, first_stack_dimension):
+ self.qcbAxisSelection.clear()
+
+ dim1dim2 = 'Dim%d-Dim%d' % (first_stack_dimension + 1,
+ first_stack_dimension + 2)
+ dim0dim2 = 'Dim%d-Dim%d' % (first_stack_dimension,
+ first_stack_dimension + 2)
+ dim0dim1 = 'Dim%d-Dim%d' % (first_stack_dimension,
+ first_stack_dimension + 1)
+
+ self.qcbAxisSelection.addItem(icons.getQIcon("cube-front"), dim1dim2)
+ self.qcbAxisSelection.addItem(icons.getQIcon("cube-bottom"), dim0dim2)
+ self.qcbAxisSelection.addItem(icons.getQIcon("cube-left"), dim0dim1)
+
+ def setFirstStackDimension(self, first_stack_dim):
+ """When viewing the last 3 dimensions of an n-D array (n>3), you can
+ use this method to change the text in the combobox.
+
+ For instance, for a 7-D array, first stack dim is 4, so the default
+ "Dim1-Dim2" text should be replaced with "Dim5-Dim6" (dimensions
+ numbers are 0-based).
+
+ :param int first_stack_dim: First stack dimension (n-3) when viewing the
+ last 3 dimensions of an n-D array.
+ """
+ self._setCBChoices(first_stack_dim)
+
+ def setPerspective(self, perspective):
+ """Update the combobox selection.
+
+ - slice plane Dim1-Dim2: perspective 0
+ - slice plane Dim0-Dim2: perspective 1
+ - slice plane Dim0-Dim1: perspective 2
+
+ :param perspective: Orthogonal dimension number (0, 1, or 2)
+ """
+ self.qcbAxisSelection.setCurrentIndex(perspective)
+
+
+class StackViewMainWindow(StackView):
+ """This class is a :class:`StackView` with a menu, an additional toolbar
+ to set the plot limits, and a status bar to display the value and 3D
+ index of the data samples hovered by the mouse cursor.
+
+ :param QWidget parent: Parent widget, or None
+ """
+ def __init__(self, parent=None):
+ self._dataInfo = None
+ super(StackViewMainWindow, self).__init__(parent)
+ self.setWindowFlags(qt.Qt.Window)
+
+ # Add toolbars and status bar
+ self.addToolBar(qt.Qt.BottomToolBarArea,
+ LimitsToolBar(plot=self._plot))
+
+ self.statusBar()
+
+ menu = self.menuBar().addMenu('File')
+ menu.addAction(self._plot.getOutputToolBar().getSaveAction())
+ menu.addAction(self._plot.getOutputToolBar().getPrintAction())
+ menu.addSeparator()
+ action = menu.addAction('Quit')
+ action.triggered[bool].connect(qt.QApplication.instance().quit)
+
+ menu = self.menuBar().addMenu('Edit')
+ menu.addAction(self._plot.getOutputToolBar().getCopyAction())
+ menu.addSeparator()
+ menu.addAction(self._plot.getResetZoomAction())
+ menu.addAction(self._plot.getColormapAction())
+ menu.addAction(self.getColorBarAction())
+
+ menu.addAction(actions.control.KeepAspectRatioAction(self._plot, self))
+ menu.addAction(actions.control.YAxisInvertedAction(self._plot, self))
+
+ menu = self.menuBar().addMenu('Profile')
+ profileToolBar = self._profileToolBar
+ menu.addAction(profileToolBar.hLineAction)
+ menu.addAction(profileToolBar.vLineAction)
+ menu.addAction(profileToolBar.lineAction)
+ menu.addAction(profileToolBar.crossAction)
+ menu.addSeparator()
+ menu.addAction(profileToolBar._editor)
+ menu.addSeparator()
+ menu.addAction(profileToolBar.clearAction)
+
+ # Connect to StackView's signal
+ self.valueChanged.connect(self._statusBarSlot)
+
+ def _statusBarSlot(self, x, y, value):
+ """Update status bar with coordinates/value from plots."""
+ # todo (after implementing calibration):
+ # - use floats for (x, y, z)
+ # - display both indices (dim0, dim1, dim2) and (x, y, z)
+ msg = "Cursor out of range"
+ if x is not None and y is not None:
+ img_idx = self._browser.value()
+
+ if self._perspective == 0:
+ dim0, dim1, dim2 = img_idx, int(y), int(x)
+ elif self._perspective == 1:
+ dim0, dim1, dim2 = int(y), img_idx, int(x)
+ elif self._perspective == 2:
+ dim0, dim1, dim2 = int(y), int(x), img_idx
+
+ msg = 'Position: (%d, %d, %d)' % (dim0, dim1, dim2)
+ if value is not None:
+ msg += ', Value: %g' % value
+ if self._dataInfo is not None:
+ msg = self._dataInfo + ', ' + msg
+
+ self.statusBar().showMessage(msg)
+
+ def setStack(self, stack, *args, **kwargs):
+ """Set the displayed stack.
+
+ See :meth:`StackView.setStack` for details.
+ """
+ if hasattr(stack, 'dtype') and hasattr(stack, 'shape'):
+ assert len(stack.shape) == 3
+ nframes, height, width = stack.shape
+ self._dataInfo = 'Data: %dx%dx%d (%s)' % (nframes, height, width,
+ str(stack.dtype))
+ self.statusBar().showMessage(self._dataInfo)
+ else:
+ self._dataInfo = None
+
+ # Set the new stack in StackView widget
+ super(StackViewMainWindow, self).setStack(stack, *args, **kwargs)
+ self.setStatusBar(None)
diff --git a/src/silx/gui/plot/StatsWidget.py b/src/silx/gui/plot/StatsWidget.py
new file mode 100644
index 0000000..00f78d0
--- /dev/null
+++ b/src/silx/gui/plot/StatsWidget.py
@@ -0,0 +1,1658 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+Module containing widgets displaying stats from items of a plot.
+"""
+
+__authors__ = ["H. Payno"]
+__license__ = "MIT"
+__date__ = "24/07/2018"
+
+
+from collections import OrderedDict
+from contextlib import contextmanager
+import logging
+import weakref
+import functools
+import numpy
+import enum
+from silx.utils.proxy import docstring
+from silx.utils.enum import Enum as _Enum
+from silx.gui import qt
+from silx.gui import icons
+from silx.gui.plot import stats as statsmdl
+from silx.gui.widgets.TableWidget import TableWidget
+from silx.gui.plot.stats.statshandler import StatsHandler, StatFormatter
+from silx.gui.plot.items.core import ItemChangedType
+from silx.gui.widgets.FlowLayout import FlowLayout
+from . import PlotWidget
+from . import items as plotitems
+
+
+_logger = logging.getLogger(__name__)
+
+
+@enum.unique
+class UpdateMode(_Enum):
+ AUTO = 'auto'
+ MANUAL = 'manual'
+
+
+# Helper class to handle specific calls to PlotWidget and SceneWidget
+
+
+class _Wrapper(qt.QObject):
+ """Base class for connection with PlotWidget and SceneWidget.
+
+ This class is used when no PlotWidget or SceneWidget is connected.
+
+ :param plot: The plot to be used
+ """
+
+ sigItemAdded = qt.Signal(object)
+ """Signal emitted when a new item is added.
+
+ It provides the added item.
+ """
+
+ sigItemRemoved = qt.Signal(object)
+ """Signal emitted when an item is (about to be) removed.
+
+ It provides the removed item.
+ """
+
+ sigCurrentChanged = qt.Signal(object)
+ """Signal emitted when the current item has changed.
+
+ It provides the current item.
+ """
+
+ sigVisibleDataChanged = qt.Signal()
+ """Signal emitted when the visible data area has changed"""
+
+ def __init__(self, plot=None):
+ super(_Wrapper, self).__init__(parent=None)
+ self._plotRef = None if plot is None else weakref.ref(plot)
+
+ def getPlot(self):
+ """Returns the plot attached to this widget"""
+ return None if self._plotRef is None else self._plotRef()
+
+ def getItems(self):
+ """Returns the list of items in the plot
+
+ :rtype: List[object]
+ """
+ return ()
+
+ def getSelectedItems(self):
+ """Returns the list of selected items in the plot
+
+ :rtype: List[object]
+ """
+ return ()
+
+ def setCurrentItem(self, item):
+ """Set the current/active item in the plot
+
+ :param item: The plot item to set as active/current
+ """
+ pass
+
+ def getLabel(self, item):
+ """Returns the label of the given item.
+
+ :param item:
+ :rtype: str
+ """
+ return ''
+
+ def getKind(self, item):
+ """Returns the kind of an item or None if not supported
+
+ :param item:
+ :rtype: Union[str,None]
+ """
+ return None
+
+
+class _PlotWidgetWrapper(_Wrapper):
+ """Class handling PlotWidget specific calls and signal connections
+
+ See :class:`._Wrapper` for documentation
+
+ :param PlotWidget plot:
+ """
+
+ def __init__(self, plot):
+ assert isinstance(plot, PlotWidget)
+ super(_PlotWidgetWrapper, self).__init__(plot)
+ plot.sigItemAdded.connect(self.sigItemAdded.emit)
+ plot.sigItemAboutToBeRemoved.connect(self.sigItemRemoved.emit)
+ plot.sigActiveCurveChanged.connect(self._activeCurveChanged)
+ plot.sigActiveImageChanged.connect(self._activeImageChanged)
+ plot.sigActiveScatterChanged.connect(self._activeScatterChanged)
+ plot.sigPlotSignal.connect(self._limitsChanged)
+
+ def _activeChanged(self, kind):
+ """Handle change of active curve/image/scatter"""
+ plot = self.getPlot()
+ if plot is not None:
+ item = plot._getActiveItem(kind=kind)
+ if item is None or self.getKind(item) is not None:
+ self.sigCurrentChanged.emit(item)
+
+ def _activeCurveChanged(self, previous, current):
+ self._activeChanged(kind='curve')
+
+ def _activeImageChanged(self, previous, current):
+ self._activeChanged(kind='image')
+
+ def _activeScatterChanged(self, previous, current):
+ self._activeChanged(kind='scatter')
+
+ def _limitsChanged(self, event):
+ """Handle change of plot area limits."""
+ if event['event'] == 'limitsChanged':
+ self.sigVisibleDataChanged.emit()
+
+ def getItems(self):
+ plot = self.getPlot()
+ if plot is None:
+ return ()
+ else:
+ return [item for item in plot.getItems() if item.isVisible()]
+
+ def getSelectedItems(self):
+ plot = self.getPlot()
+ items = []
+ if plot is not None:
+ for kind in plot._ACTIVE_ITEM_KINDS:
+ item = plot._getActiveItem(kind=kind)
+ if item is not None:
+ items.append(item)
+ return tuple(items)
+
+ def setCurrentItem(self, item):
+ plot = self.getPlot()
+ if plot is not None:
+ kind = self.getKind(item)
+ if kind in plot._ACTIVE_ITEM_KINDS:
+ if plot._getActiveItem(kind) != item:
+ plot._setActiveItem(kind, item.getName())
+
+ def getLabel(self, item):
+ return item.getName()
+
+ def getKind(self, item):
+ if isinstance(item, plotitems.Curve):
+ return 'curve'
+ elif isinstance(item, plotitems.ImageData):
+ return 'image'
+ elif isinstance(item, plotitems.Scatter):
+ return 'scatter'
+ elif isinstance(item, plotitems.Histogram):
+ return 'histogram'
+ else:
+ return None
+
+
+class _SceneWidgetWrapper(_Wrapper):
+ """Class handling SceneWidget specific calls and signal connections
+
+ See :class:`._Wrapper` for documentation
+
+ :param SceneWidget plot:
+ """
+
+ def __init__(self, plot):
+ # Lazy-import to avoid circular imports
+ from ..plot3d.SceneWidget import SceneWidget
+
+ assert isinstance(plot, SceneWidget)
+ super(_SceneWidgetWrapper, self).__init__(plot)
+ plot.getSceneGroup().sigItemAdded.connect(self.sigItemAdded)
+ plot.getSceneGroup().sigItemRemoved.connect(self.sigItemRemoved)
+ plot.selection().sigCurrentChanged.connect(self._currentChanged)
+ # sigVisibleDataChanged is never emitted
+
+ def _currentChanged(self, current, previous):
+ self.sigCurrentChanged.emit(current)
+
+ def getItems(self):
+ plot = self.getPlot()
+ return () if plot is None else tuple(plot.getSceneGroup().visit())
+
+ def getSelectedItems(self):
+ plot = self.getPlot()
+ return () if plot is None else (plot.selection().getCurrentItem(),)
+
+ def setCurrentItem(self, item):
+ plot = self.getPlot()
+ if plot is not None:
+ plot.selection().setCurrentItem(item)
+
+ def getLabel(self, item):
+ return item.getLabel()
+
+ def getKind(self, item):
+ from ..plot3d import items as plot3ditems
+
+ if isinstance(item, (plot3ditems.ImageData,
+ plot3ditems.ScalarField3D)):
+ return 'image'
+ elif isinstance(item, (plot3ditems.Scatter2D,
+ plot3ditems.Scatter3D)):
+ return 'scatter'
+ else:
+ return None
+
+
+class _ScalarFieldViewWrapper(_Wrapper):
+ """Class handling ScalarFieldView specific calls and signal connections
+
+ See :class:`._Wrapper` for documentation
+
+ :param SceneWidget plot:
+ """
+
+ def __init__(self, plot):
+ # Lazy-import to avoid circular imports
+ from ..plot3d.ScalarFieldView import ScalarFieldView
+ from ..plot3d.items import ScalarField3D
+
+ assert isinstance(plot, ScalarFieldView)
+ super(_ScalarFieldViewWrapper, self).__init__(plot)
+ self._item = ScalarField3D()
+ self._dataChanged()
+ plot.sigDataChanged.connect(self._dataChanged)
+ # sigItemAdded, sigItemRemoved, sigVisibleDataChanged are never emitted
+
+ def _dataChanged(self):
+ plot = self.getPlot()
+ if plot is not None:
+ self._item.setData(plot.getData(copy=False), copy=False)
+ self.sigCurrentChanged.emit(self._item)
+
+ def getItems(self):
+ plot = self.getPlot()
+ return () if plot is None else (self._item,)
+
+ def getSelectedItems(self):
+ return self.getItems()
+
+ def setCurrentItem(self, item):
+ pass
+
+ def getLabel(self, item):
+ return 'Data'
+
+ def getKind(self, item):
+ return 'image'
+
+
+class _Container(object):
+ """Class to contain a plot item.
+
+ This is apparently needed for compatibility with PySide2,
+
+ :param QObject obj:
+ """
+ def __init__(self, obj):
+ self._obj = obj
+
+ def __call__(self):
+ return self._obj
+
+
+class _StatsWidgetBase(object):
+ """
+ Base class for all widgets which want to display statistics
+ """
+
+ def __init__(self, statsOnVisibleData, displayOnlyActItem):
+ self._displayOnlyActItem = displayOnlyActItem
+ self._statsOnVisibleData = statsOnVisibleData
+ self._statsHandler = None
+ self._updateMode = UpdateMode.AUTO
+
+ self.__default_skipped_events = (
+ ItemChangedType.ALPHA,
+ ItemChangedType.COLOR,
+ ItemChangedType.COLORMAP,
+ ItemChangedType.SYMBOL,
+ ItemChangedType.SYMBOL_SIZE,
+ ItemChangedType.LINE_WIDTH,
+ ItemChangedType.LINE_STYLE,
+ ItemChangedType.LINE_BG_COLOR,
+ ItemChangedType.FILL,
+ ItemChangedType.HIGHLIGHTED_COLOR,
+ ItemChangedType.HIGHLIGHTED_STYLE,
+ ItemChangedType.TEXT,
+ ItemChangedType.OVERLAY,
+ ItemChangedType.VISUALIZATION_MODE,
+ )
+
+ self._plotWrapper = _Wrapper()
+ self._dealWithPlotConnection(create=True)
+
+ def setPlot(self, plot):
+ """Define the plot to interact with
+
+ :param Union[PlotWidget,SceneWidget,None] plot:
+ The plot containing the items on which statistics are applied
+ """
+ try:
+ import OpenGL
+ except ImportError:
+ has_opengl = False
+ else:
+ has_opengl = True
+ from ..plot3d.SceneWidget import SceneWidget # Lazy import
+ self._dealWithPlotConnection(create=False)
+ self.clear()
+ if plot is None:
+ self._plotWrapper = _Wrapper()
+ elif isinstance(plot, PlotWidget):
+ self._plotWrapper = _PlotWidgetWrapper(plot)
+ else:
+ if has_opengl is True:
+ if isinstance(plot, SceneWidget):
+ self._plotWrapper = _SceneWidgetWrapper(plot)
+ else: # Expect a ScalarFieldView
+ self._plotWrapper = _ScalarFieldViewWrapper(plot)
+ else:
+ _logger.warning('OpenGL not installed, %s not managed' % ('SceneWidget qnd ScalarFieldView'))
+ self._dealWithPlotConnection(create=True)
+
+ def setStats(self, statsHandler):
+ """Set which stats to display and the associated formatting.
+
+ :param StatsHandler statsHandler:
+ Set the statistics to be displayed and how to format them using
+ """
+ if statsHandler is None:
+ statsHandler = StatsHandler(statFormatters=())
+ elif isinstance(statsHandler, (list, tuple)):
+ statsHandler = StatsHandler(statsHandler)
+ assert isinstance(statsHandler, StatsHandler)
+
+ self._statsHandler = statsHandler
+
+ def getStatsHandler(self):
+ """Returns the :class:`StatsHandler` in use.
+
+ :rtype: StatsHandler
+ """
+ return self._statsHandler
+
+ def getPlot(self):
+ """Returns the plot attached to this widget
+
+ :rtype: Union[PlotWidget,SceneWidget,None]
+ """
+ return self._plotWrapper.getPlot()
+
+ def _dealWithPlotConnection(self, create=True):
+ """Manage connection to plot signals
+
+ Note: connection on Item are managed by _addItem and _removeItem methods
+ """
+ connections = [] # List of (signal, slot) to connect/disconnect
+ if self._statsOnVisibleData:
+ connections.append(
+ (self._plotWrapper.sigVisibleDataChanged, self._updateAllStats))
+
+ if self._displayOnlyActItem:
+ connections.append(
+ (self._plotWrapper.sigCurrentChanged, self._updateCurrentItem))
+ else:
+ connections += [
+ (self._plotWrapper.sigItemAdded, self._addItem),
+ (self._plotWrapper.sigItemRemoved, self._removeItem),
+ (self._plotWrapper.sigCurrentChanged, self._plotCurrentChanged)]
+
+ for signal, slot in connections:
+ if create:
+ signal.connect(slot)
+ else:
+ signal.disconnect(slot)
+
+ def _updateItemObserve(self, *args):
+ """Reload table depending on mode"""
+ raise NotImplementedError('Base class')
+
+ def _updateCurrentItem(self, *args):
+ """specific callback for the sigCurrentChanged and with the
+ _displayOnlyActItem option."""
+ raise NotImplementedError('Base class')
+
+ def _updateStats(self, item, data_changed=False, roi_changed=False):
+ """Update displayed information for given plot item
+
+ :param item: The plot item
+ :param bool data_changed: is the item data changed.
+ :param bool roi_changed: is the associated roi changed.
+ """
+ raise NotImplementedError('Base class')
+
+ def _updateAllStats(self):
+ """Update stats for all rows in the table"""
+ raise NotImplementedError('Base class')
+
+ def setDisplayOnlyActiveItem(self, displayOnlyActItem):
+ """Toggle display off all items or only the active/selected one
+
+ :param bool displayOnlyActItem:
+ True if we want to only show active item
+ """
+ self._displayOnlyActItem = displayOnlyActItem
+
+ def setStatsOnVisibleData(self, b):
+ """Toggle computation of statistics on whole data or only visible ones.
+
+ .. warning:: When visible data is activated we will process to a simple
+ filtering of visible data by the user. The filtering is a
+ simple data sub-sampling. No interpolation is made to fit
+ data to boundaries.
+
+ :param bool b: True if we want to apply statistics only on visible data
+ """
+ if self._statsOnVisibleData != b:
+ self._dealWithPlotConnection(create=False)
+ self._statsOnVisibleData = b
+ self._dealWithPlotConnection(create=True)
+ self._updateAllStats()
+
+ def _addItem(self, item):
+ """Add a plot item to the table
+
+ If item is not supported, it is ignored.
+
+ :param item: The plot item
+ :returns: True if the item is added to the widget.
+ :rtype: bool
+ """
+ raise NotImplementedError('Base class')
+
+ def _removeItem(self, item):
+ """Remove table items corresponding to given plot item from the table.
+
+ :param item: The plot item
+ """
+ raise NotImplementedError('Base class')
+
+ def _plotCurrentChanged(self, current):
+ """Handle change of current item and update selection in table
+
+ :param current:
+ """
+ raise NotImplementedError('Base class')
+
+ def clear(self):
+ """clear GUI"""
+ pass
+
+ def _skipPlotItemChangedEvent(self, event):
+ """
+
+ :param ItemChangedtype event: event to filter or not
+ :return: True if we want to ignore this ItemChangedtype
+ :rtype: bool
+ """
+ return event in self.__default_skipped_events
+
+ def setUpdateMode(self, mode):
+ """Set the way to update the displayed statistics.
+
+ :param mode: mode requested for update
+ :type mode: Union[str,UpdateMode]
+ """
+ mode = UpdateMode.from_value(mode)
+ if mode != self._updateMode:
+ self._updateMode = mode
+ self._updateModeHasChanged()
+
+ def getUpdateMode(self):
+ """Returns update mode (See :meth:`setUpdateMode`).
+
+ :return: update mode
+ :rtype: UpdateMode
+ """
+ return self._updateMode
+
+ def _updateModeHasChanged(self):
+ """callback when the update mode has changed"""
+ pass
+
+
+class StatsTable(_StatsWidgetBase, TableWidget):
+ """
+ TableWidget displaying for each items contained by the Plot some
+ information:
+
+ * legend
+ * minimal value
+ * maximal value
+ * standard deviation (std)
+
+ :param QWidget parent: The widget's parent.
+ :param Union[PlotWidget,SceneWidget] plot:
+ :class:`PlotWidget` or :class:`SceneWidget` instance on which to operate
+ """
+
+ _LEGEND_HEADER_DATA = 'legend'
+ _KIND_HEADER_DATA = 'kind'
+
+ sigUpdateModeChanged = qt.Signal(object)
+ """Signal emitted when the update mode changed"""
+
+ def __init__(self, parent=None, plot=None):
+ TableWidget.__init__(self, parent)
+ _StatsWidgetBase.__init__(self, statsOnVisibleData=False,
+ displayOnlyActItem=False)
+
+ # Init for _displayOnlyActItem == False
+ assert self._displayOnlyActItem is False
+ self.setSelectionBehavior(qt.QAbstractItemView.SelectRows)
+ self.setSelectionMode(qt.QAbstractItemView.SingleSelection)
+ self.currentItemChanged.connect(self._currentItemChanged)
+
+ self.setRowCount(0)
+ self.setColumnCount(2)
+
+ # Init headers
+ headerItem = qt.QTableWidgetItem(self._LEGEND_HEADER_DATA.title())
+ headerItem.setData(qt.Qt.UserRole, self._LEGEND_HEADER_DATA)
+ self.setHorizontalHeaderItem(0, headerItem)
+ headerItem = qt.QTableWidgetItem(self._KIND_HEADER_DATA.title())
+ headerItem.setData(qt.Qt.UserRole, self._KIND_HEADER_DATA)
+ self.setHorizontalHeaderItem(1, headerItem)
+
+ self.setSortingEnabled(True)
+ self.setPlot(plot)
+
+ @contextmanager
+ def _disableSorting(self):
+ """Context manager that disables table sorting
+
+ Previous state is restored when leaving
+ """
+ sorting = self.isSortingEnabled()
+ if sorting:
+ self.setSortingEnabled(False)
+ yield
+ if sorting:
+ self.setSortingEnabled(sorting)
+
+ def setStats(self, statsHandler):
+ """Set which stats to display and the associated formatting.
+
+ :param StatsHandler statsHandler:
+ Set the statistics to be displayed and how to format them using
+ """
+ self._removeAllItems()
+ _StatsWidgetBase.setStats(self, statsHandler)
+
+ self.setRowCount(0)
+ self.setColumnCount(len(self._statsHandler.stats) + 2) # + legend and kind
+
+ for index, stat in enumerate(self._statsHandler.stats.values()):
+ headerItem = qt.QTableWidgetItem(stat.name.capitalize())
+ headerItem.setData(qt.Qt.UserRole, stat.name)
+ if stat.description is not None:
+ headerItem.setToolTip(stat.description)
+ self.setHorizontalHeaderItem(2 + index, headerItem)
+
+ horizontalHeader = self.horizontalHeader()
+ horizontalHeader.setSectionResizeMode(qt.QHeaderView.ResizeToContents)
+
+ self._updateItemObserve()
+
+ def setPlot(self, plot):
+ """Define the plot to interact with
+
+ :param Union[PlotWidget,SceneWidget,None] plot:
+ The plot containing the items on which statistics are applied
+ """
+ _StatsWidgetBase.setPlot(self, plot)
+ self._updateItemObserve()
+
+ def clear(self):
+ """Define the plot to interact with
+
+ :param Union[PlotWidget,SceneWidget,None] plot:
+ The plot containing the items on which statistics are applied
+ """
+ self._removeAllItems()
+
+ def _updateItemObserve(self, *args):
+ """Reload table depending on mode"""
+ self._removeAllItems()
+
+ # Get selected or all items from the plot
+ if self._displayOnlyActItem: # Only selected
+ items = self._plotWrapper.getSelectedItems()
+ else: # All items
+ items = self._plotWrapper.getItems()
+
+ # Add items to the plot
+ for item in items:
+ self._addItem(item)
+
+ def _updateCurrentItem(self, *args):
+ """specific callback for the sigCurrentChanged and with the
+ _displayOnlyActItem option.
+
+ Behavior: create the tableItems if does not exists.
+ If exists, update it only when we are in 'auto' mode"""
+ if self.getUpdateMode() is UpdateMode.MANUAL:
+ # when sigCurrentChanged is giving the current item
+ if len(args) > 0 and isinstance(args[0], (plotitems.Curve, plotitems.Histogram, plotitems.ImageData, plotitems.Scatter)):
+ item = args[0]
+ tableItems = self._itemToTableItems(item)
+ # if the table does not exists yet
+ if len(tableItems) == 0:
+ self._updateItemObserve()
+ else:
+ # in this case no current item
+ self._updateItemObserve(args)
+ else:
+ # auto mode
+ self._updateItemObserve(args)
+
+ def _plotCurrentChanged(self, current):
+ """Handle change of current item and update selection in table
+
+ :param current:
+ """
+ row = self._itemToRow(current)
+ if row is None:
+ if self.currentRow() >= 0:
+ self.setCurrentCell(-1, -1)
+ elif row != self.currentRow():
+ self.setCurrentCell(row, 0)
+
+ def _tableItemToItem(self, tableItem):
+ """Find the plot item corresponding to a table item
+
+ :param QTableWidgetItem tableItem:
+ :rtype: QObject
+ """
+ container = tableItem.data(qt.Qt.UserRole)
+ return container()
+
+ def _itemToRow(self, item):
+ """Find the row corresponding to a plot item
+
+ :param item: The plot item
+ :return: The corresponding row index
+ :rtype: Union[int,None]
+ """
+ for row in range(self.rowCount()):
+ tableItem = self.item(row, 0)
+ if self._tableItemToItem(tableItem) == item:
+ return row
+ return None
+
+ def _itemToTableItems(self, item):
+ """Find all table items corresponding to a plot item
+
+ :param item: The plot item
+ :return: An ordered dict of column name to QTableWidgetItem mapping
+ for the given plot item.
+ :rtype: OrderedDict
+ """
+ result = OrderedDict()
+ row = self._itemToRow(item)
+ if row is not None:
+ for column in range(self.columnCount()):
+ tableItem = self.item(row, column)
+ if self._tableItemToItem(tableItem) != item:
+ _logger.error("Table item/plot item mismatch")
+ else:
+ header = self.horizontalHeaderItem(column)
+ name = header.data(qt.Qt.UserRole)
+ result[name] = tableItem
+ return result
+
+ def _plotItemChanged(self, event):
+ """Handle modifications of the items.
+
+ :param event:
+ """
+ if self.getUpdateMode() is UpdateMode.MANUAL:
+ return
+ if self._skipPlotItemChangedEvent(event) is True:
+ return
+ else:
+ item = self.sender()
+ self._updateStats(item, data_changed=True)
+ # deal with stat items visibility
+ if event is ItemChangedType.VISIBLE:
+ if len(self._itemToTableItems(item).items()) > 0:
+ item_0 = list(self._itemToTableItems(item).values())[0]
+ row_index = item_0.row()
+ self.setRowHidden(row_index, not item.isVisible())
+
+ def _addItem(self, item):
+ """Add a plot item to the table
+
+ If item is not supported, it is ignored.
+
+ :param item: The plot item
+ :returns: True if the item is added to the widget.
+ :rtype: bool
+ """
+ if self._itemToRow(item) is not None:
+ _logger.info("Item already present in the table")
+ self._updateStats(item)
+ return True
+
+ kind = self._plotWrapper.getKind(item)
+ if kind not in statsmdl.BASIC_COMPATIBLE_KINDS:
+ _logger.info("Item has not a supported type: %s", item)
+ return False
+
+ # Prepare table items
+ tableItems = [
+ qt.QTableWidgetItem(), # Legend
+ qt.QTableWidgetItem()] # Kind
+
+ for column in range(2, self.columnCount()):
+ header = self.horizontalHeaderItem(column)
+ name = header.data(qt.Qt.UserRole)
+
+ formatter = self._statsHandler.formatters[name]
+ if formatter:
+ tableItem = formatter.tabWidgetItemClass()
+ else:
+ tableItem = qt.QTableWidgetItem()
+
+ tooltip = self._statsHandler.stats[name].getToolTip(kind=kind)
+ if tooltip is not None:
+ tableItem.setToolTip(tooltip)
+
+ tableItems.append(tableItem)
+
+ # Disable sorting while adding table items
+ with self._disableSorting():
+ # Add a row to the table
+ self.setRowCount(self.rowCount() + 1)
+
+ # Add table items to the last row
+ row = self.rowCount() - 1
+ for column, tableItem in enumerate(tableItems):
+ tableItem.setData(qt.Qt.UserRole, _Container(item))
+ tableItem.setFlags(
+ qt.Qt.ItemIsEnabled | qt.Qt.ItemIsSelectable)
+ self.setItem(row, column, tableItem)
+
+ # Update table items content
+ self._updateStats(item, data_changed=True)
+
+ # Listen for item changes
+ # Using queued connection to avoid issue with sender
+ # being that of the signal calling the signal
+ item.sigItemChanged.connect(self._plotItemChanged,
+ qt.Qt.QueuedConnection)
+
+ return True
+
+ def _removeItem(self, item):
+ """Remove table items corresponding to given plot item from the table.
+
+ :param item: The plot item
+ """
+ row = self._itemToRow(item)
+ if row is None:
+ kind = self._plotWrapper.getKind(item)
+ if kind in statsmdl.BASIC_COMPATIBLE_KINDS:
+ _logger.error("Removing item that is not in table: %s", str(item))
+ return
+ item.sigItemChanged.disconnect(self._plotItemChanged)
+ self.removeRow(row)
+
+ def _removeAllItems(self):
+ """Remove content of the table"""
+ for row in range(self.rowCount()):
+ tableItem = self.item(row, 0)
+ item = self._tableItemToItem(tableItem)
+ item.sigItemChanged.disconnect(self._plotItemChanged)
+ self.clearContents()
+ self.setRowCount(0)
+
+ def _updateStats(self, item, data_changed=False, roi_changed=False):
+ """Update displayed information for given plot item
+
+ :param item: The plot item
+ :param bool data_changed: is the item data changed.
+ :param bool roi_changed: is the associated roi changed.
+ """
+ if item is None:
+ return
+ plot = self.getPlot()
+ if plot is None:
+ _logger.info("Plot not available")
+ return
+
+ row = self._itemToRow(item)
+ if row is None:
+ _logger.error("This item is not in the table: %s", str(item))
+ return
+
+ statsHandler = self.getStatsHandler()
+ if statsHandler is not None:
+ # _updateStats is call when the plot visible area change.
+ # to force stats update we consider roi changed
+ if self._statsOnVisibleData:
+ roi_changed = True
+ else:
+ roi_changed = False
+ stats = statsHandler.calculate(
+ item, plot, self._statsOnVisibleData,
+ data_changed=data_changed, roi_changed=roi_changed)
+ else:
+ stats = {}
+
+ with self._disableSorting():
+ for name, tableItem in self._itemToTableItems(item).items():
+ if name == self._LEGEND_HEADER_DATA:
+ text = self._plotWrapper.getLabel(item)
+ tableItem.setText(text)
+ elif name == self._KIND_HEADER_DATA:
+ tableItem.setText(self._plotWrapper.getKind(item))
+ else:
+ value = stats.get(name)
+ if value is None:
+ _logger.error("Value not found for: %s", name)
+ tableItem.setText('-')
+ else:
+ tableItem.setText(str(value))
+
+ def _updateAllStats(self, is_request=False):
+ """Update stats for all rows in the table
+
+ :param bool is_request: True if come from a manual request
+ """
+ if self.getUpdateMode() is UpdateMode.MANUAL and not is_request:
+ return
+ with self._disableSorting():
+ for row in range(self.rowCount()):
+ tableItem = self.item(row, 0)
+ item = self._tableItemToItem(tableItem)
+ self._updateStats(item, data_changed=is_request)
+
+ def _currentItemChanged(self, current, previous):
+ """Handle change of selection in table and sync plot selection
+
+ :param QTableWidgetItem current:
+ :param QTableWidgetItem previous:
+ """
+ if current and current.row() >= 0:
+ item = self._tableItemToItem(current)
+ self._plotWrapper.setCurrentItem(item)
+
+ def setDisplayOnlyActiveItem(self, displayOnlyActItem):
+ """Toggle display off all items or only the active/selected one
+
+ :param bool displayOnlyActItem:
+ True if we want to only show active item
+ """
+ if self._displayOnlyActItem == displayOnlyActItem:
+ return
+ self._dealWithPlotConnection(create=False)
+ if not self._displayOnlyActItem:
+ self.currentItemChanged.disconnect(self._currentItemChanged)
+
+ _StatsWidgetBase.setDisplayOnlyActiveItem(self, displayOnlyActItem)
+
+ self._updateItemObserve()
+ self._dealWithPlotConnection(create=True)
+
+ if not self._displayOnlyActItem:
+ self.currentItemChanged.connect(self._currentItemChanged)
+ self.setSelectionMode(qt.QAbstractItemView.SingleSelection)
+ else:
+ self.setSelectionMode(qt.QAbstractItemView.NoSelection)
+
+ def _updateModeHasChanged(self):
+ self.sigUpdateModeChanged.emit(self._updateMode)
+
+
+class UpdateModeWidget(qt.QWidget):
+ """Widget used to select the mode of update"""
+ sigUpdateModeChanged = qt.Signal(object)
+ """signal emitted when the mode for update changed"""
+ sigUpdateRequested = qt.Signal()
+ """signal emitted when an manual request for example is activate"""
+
+ def __init__(self, parent=None):
+ qt.QWidget.__init__(self, parent)
+ self.setLayout(qt.QHBoxLayout())
+ self._buttonGrp = qt.QButtonGroup(parent=self)
+ self._buttonGrp.setExclusive(True)
+
+ spacer = qt.QSpacerItem(20, 20,
+ qt.QSizePolicy.Expanding,
+ qt.QSizePolicy.Minimum)
+ self.layout().addItem(spacer)
+
+ self._autoRB = qt.QRadioButton('auto', parent=self)
+ self.layout().addWidget(self._autoRB)
+ self._buttonGrp.addButton(self._autoRB)
+
+ self._manualRB = qt.QRadioButton('manual', parent=self)
+ self.layout().addWidget(self._manualRB)
+ self._buttonGrp.addButton(self._manualRB)
+ self._manualRB.setChecked(True)
+
+ refresh_icon = icons.getQIcon('view-refresh')
+ self._updatePB = qt.QPushButton(refresh_icon, '', parent=self)
+ self.layout().addWidget(self._updatePB)
+
+ # connect signal / SLOT
+ self._updatePB.clicked.connect(self._updateRequested)
+ self._manualRB.toggled.connect(self._manualButtonToggled)
+ self._autoRB.toggled.connect(self._autoButtonToggled)
+
+ def _manualButtonToggled(self, checked):
+ if checked:
+ self.setUpdateMode(UpdateMode.MANUAL)
+ self.sigUpdateModeChanged.emit(self.getUpdateMode())
+
+ def _autoButtonToggled(self, checked):
+ if checked:
+ self.setUpdateMode(UpdateMode.AUTO)
+ self.sigUpdateModeChanged.emit(self.getUpdateMode())
+
+ def _updateRequested(self):
+ if self.getUpdateMode() is UpdateMode.MANUAL:
+ self.sigUpdateRequested.emit()
+
+ def setUpdateMode(self, mode):
+ """Set the way to update the displayed statistics.
+
+ :param mode: mode requested for update
+ :type mode: Union[str,UpdateMode]
+ """
+ mode = UpdateMode.from_value(mode)
+
+ if mode is UpdateMode.AUTO:
+ if not self._autoRB.isChecked():
+ self._autoRB.setChecked(True)
+ elif mode is UpdateMode.MANUAL:
+ if not self._manualRB.isChecked():
+ self._manualRB.setChecked(True)
+ else:
+ raise ValueError('mode', mode, 'is not recognized')
+
+ def getUpdateMode(self):
+ """Returns update mode (See :meth:`setUpdateMode`).
+
+ :return: the active update mode
+ :rtype: UpdateMode
+ """
+ if self._manualRB.isChecked():
+ return UpdateMode.MANUAL
+ elif self._autoRB.isChecked():
+ return UpdateMode.AUTO
+ else:
+ raise RuntimeError("No mode selected")
+
+ def showRadioButtons(self, show):
+ """show / hide the QRadioButtons
+
+ :param bool show: if True make RadioButton visible
+ """
+ self._autoRB.setVisible(show)
+ self._manualRB.setVisible(show)
+
+
+class _OptionsWidget(qt.QToolBar):
+
+ def __init__(self, parent=None, updateMode=None, displayOnlyActItem=False):
+ assert updateMode is not None
+ qt.QToolBar.__init__(self, parent)
+ self.setIconSize(qt.QSize(16, 16))
+
+ action = qt.QAction(self)
+ action.setIcon(icons.getQIcon("stats-active-items"))
+ action.setText("Active items only")
+ action.setToolTip("Display stats for active items only.")
+ action.setCheckable(True)
+ action.setChecked(displayOnlyActItem)
+ self.__displayActiveItems = action
+
+ action = qt.QAction(self)
+ action.setIcon(icons.getQIcon("stats-whole-items"))
+ action.setText("All items")
+ action.setToolTip("Display stats for all available items.")
+ action.setCheckable(True)
+ self.__displayWholeItems = action
+
+ action = qt.QAction(self)
+ action.setIcon(icons.getQIcon("stats-visible-data"))
+ action.setText("Use the visible data range")
+ action.setToolTip("Use the visible data range.<br/>"
+ "If activated the data is filtered to only use"
+ "visible data of the plot."
+ "The filtering is a data sub-sampling."
+ "No interpolation is made to fit data to"
+ "boundaries.")
+ action.setCheckable(True)
+ self.__useVisibleData = action
+
+ action = qt.QAction(self)
+ action.setIcon(icons.getQIcon("stats-whole-data"))
+ action.setText("Use the full data range")
+ action.setToolTip("Use the full data range.")
+ action.setCheckable(True)
+ action.setChecked(True)
+ self.__useWholeData = action
+
+ self.addAction(self.__displayWholeItems)
+ self.addAction(self.__displayActiveItems)
+ self.addSeparator()
+ self.addAction(self.__useVisibleData)
+ self.addAction(self.__useWholeData)
+
+ self.itemSelection = qt.QActionGroup(self)
+ self.itemSelection.setExclusive(True)
+ self.itemSelection.addAction(self.__displayActiveItems)
+ self.itemSelection.addAction(self.__displayWholeItems)
+
+ self.dataRangeSelection = qt.QActionGroup(self)
+ self.dataRangeSelection.setExclusive(True)
+ self.dataRangeSelection.addAction(self.__useWholeData)
+ self.dataRangeSelection.addAction(self.__useVisibleData)
+
+ self.__updateStatsAction = qt.QAction(self)
+ self.__updateStatsAction.setIcon(icons.getQIcon("view-refresh"))
+ self.__updateStatsAction.setText("update statistics")
+ self.__updateStatsAction.setToolTip("update statistics")
+ self.__updateStatsAction.setCheckable(False)
+ self._updateStatsSep = self.addSeparator()
+ self.addAction(self.__updateStatsAction)
+
+ self._setUpdateMode(mode=updateMode)
+
+ # expose API
+ self.sigUpdateStats = self.__updateStatsAction.triggered
+
+ def isActiveItemMode(self):
+ return self.itemSelection.checkedAction() is self.__displayActiveItems
+
+ def setDisplayActiveItems(self, only_active):
+ self.__displayActiveItems.setChecked(only_active)
+ self.__displayWholeItems.setChecked(not only_active)
+
+ def isVisibleDataRangeMode(self):
+ return self.dataRangeSelection.checkedAction() is self.__useVisibleData
+
+ def setVisibleDataRangeModeEnabled(self, enabled):
+ """Enable/Disable the visible data range mode
+
+ :param bool enabled: True to allow user to choose
+ stats on visible data
+ """
+ self.__useVisibleData.setEnabled(enabled)
+ if not enabled:
+ self.__useWholeData.setChecked(True)
+
+ def _setUpdateMode(self, mode):
+ self.__updateStatsAction.setVisible(mode == UpdateMode.MANUAL)
+ self._updateStatsSep.setVisible(mode == UpdateMode.MANUAL)
+
+ def getUpdateStatsAction(self):
+ """
+
+ :return: the action for the automatic mode
+ :rtype: QAction
+ """
+ return self.__updateStatsAction
+
+
+class StatsWidget(qt.QWidget):
+ """
+ Widget displaying a set of :class:`Stat` to be displayed on a
+ :class:`StatsTable` and to be apply on items contained in the :class:`Plot`
+ Also contains options to:
+
+ * compute statistics on all the data or on visible data only
+ * show statistics of all items or only the active one
+
+ :param QWidget parent: Qt parent
+ :param Union[PlotWidget,SceneWidget] plot:
+ The plot containing items on which we want statistics.
+ :param StatsHandler stats:
+ Set the statistics to be displayed and how to format them using
+ """
+
+ sigVisibilityChanged = qt.Signal(bool)
+ """Signal emitted when the visibility of this widget changes.
+
+ It Provides the visibility of the widget.
+ """
+
+ NUMBER_FORMAT = '{0:.3f}'
+
+ def __init__(self, parent=None, plot=None, stats=None):
+ qt.QWidget.__init__(self, parent)
+ self.setLayout(qt.QVBoxLayout())
+ self.layout().setContentsMargins(0, 0, 0, 0)
+ self._options = _OptionsWidget(parent=self, updateMode=UpdateMode.MANUAL)
+ self.layout().addWidget(self._options)
+ self._statsTable = StatsTable(parent=self, plot=plot)
+ self._statsTable.setDisplayOnlyActiveItem(self._options.isActiveItemMode())
+ self._options._setUpdateMode(mode=self._statsTable.getUpdateMode())
+ self.setStats(stats)
+
+ self.layout().addWidget(self._statsTable)
+
+ old = self._statsTable.blockSignals(True)
+ self._options.itemSelection.triggered.connect(
+ self._optSelectionChanged)
+ self._options.dataRangeSelection.triggered.connect(
+ self._optDataRangeChanged)
+ self._optDataRangeChanged()
+ self._statsTable.blockSignals(old)
+
+ self._statsTable.sigUpdateModeChanged.connect(self._options._setUpdateMode)
+ callback = functools.partial(self._getStatsTable()._updateAllStats, is_request=True)
+ self._options.sigUpdateStats.connect(callback)
+
+ def _getStatsTable(self):
+ """Returns the :class:`StatsTable` used by this widget.
+
+ :rtype: StatsTable
+ """
+ return self._statsTable
+
+ def showEvent(self, event):
+ self.sigVisibilityChanged.emit(True)
+ qt.QWidget.showEvent(self, event)
+
+ def hideEvent(self, event):
+ self.sigVisibilityChanged.emit(False)
+ qt.QWidget.hideEvent(self, event)
+
+ def _optSelectionChanged(self, action=None):
+ self._getStatsTable().setDisplayOnlyActiveItem(
+ self._options.isActiveItemMode())
+
+ def _optDataRangeChanged(self, action=None):
+ self._getStatsTable().setStatsOnVisibleData(
+ self._options.isVisibleDataRangeMode())
+
+ # Proxy methods
+
+ @docstring(StatsTable)
+ def setStats(self, statsHandler):
+ return self._getStatsTable().setStats(statsHandler=statsHandler)
+
+ @docstring(StatsTable)
+ def setPlot(self, plot):
+ self._options.setVisibleDataRangeModeEnabled(
+ plot is None or isinstance(plot, PlotWidget))
+ return self._getStatsTable().setPlot(plot=plot)
+
+ @docstring(StatsTable)
+ def getPlot(self):
+ return self._getStatsTable().getPlot()
+
+ @docstring(StatsTable)
+ def setDisplayOnlyActiveItem(self, displayOnlyActItem):
+ old = self._options.blockSignals(True)
+ # update the options
+ self._options.setDisplayActiveItems(displayOnlyActItem)
+ self._options.blockSignals(old)
+ return self._getStatsTable().setDisplayOnlyActiveItem(
+ displayOnlyActItem=displayOnlyActItem)
+
+ @docstring(StatsTable)
+ def setStatsOnVisibleData(self, b):
+ return self._getStatsTable().setStatsOnVisibleData(b=b)
+
+ @docstring(StatsTable)
+ def getUpdateMode(self):
+ return self._statsTable.getUpdateMode()
+
+ @docstring(StatsTable)
+ def setUpdateMode(self, mode):
+ self._statsTable.setUpdateMode(mode)
+
+
+DEFAULT_STATS = StatsHandler((
+ (statsmdl.StatMin(), StatFormatter()),
+ statsmdl.StatCoordMin(),
+ (statsmdl.StatMax(), StatFormatter()),
+ statsmdl.StatCoordMax(),
+ statsmdl.StatCOM(),
+ (('mean', numpy.mean), StatFormatter()),
+ (('std', numpy.std), StatFormatter()),
+))
+
+
+class BasicStatsWidget(StatsWidget):
+ """
+ Widget defining a simple set of :class:`Stat` to be displayed on a
+ :class:`StatsWidget`.
+
+ :param QWidget parent: Qt parent
+ :param PlotWidget plot:
+ The plot containing items on which we want statistics.
+ :param StatsHandler stats:
+ Set the statistics to be displayed and how to format them using
+
+ .. snapshotqt:: img/BasicStatsWidget.png
+ :width: 300px
+ :align: center
+
+ from silx.gui.plot import Plot1D
+ from silx.gui.plot.StatsWidget import BasicStatsWidget
+
+ plot = Plot1D()
+ x = range(100)
+ y = x
+ plot.addCurve(x, y, legend='curve_0')
+ plot.setActiveCurve('curve_0')
+
+ widget = BasicStatsWidget(plot=plot)
+ widget.show()
+ """
+ def __init__(self, parent=None, plot=None):
+ StatsWidget.__init__(self, parent=parent, plot=plot,
+ stats=DEFAULT_STATS)
+
+
+class _BaseLineStatsWidget(_StatsWidgetBase, qt.QWidget):
+ """
+ Widget made to display stats into a QLayout with couple (QLabel, QLineEdit)
+ created for each stats.
+ The layout can be defined prior of adding any statistic.
+
+ :param QWidget parent: Qt parent
+ :param Union[PlotWidget,SceneWidget] plot:
+ The plot containing items on which we want statistics.
+ :param str kind: the kind of plotitems we want to display
+ :param StatsHandler stats:
+ Set the statistics to be displayed and how to format them using
+ :param bool statsOnVisibleData: compute statistics for the whole data or
+ only visible ones.
+ """
+
+ sigUpdateModeChanged = qt.Signal(object)
+ """Signal emitted when the update mode changed"""
+
+ def __init__(self, parent=None, plot=None, kind='curve', stats=None,
+ statsOnVisibleData=False):
+ self._item_kind = kind
+ """The item displayed"""
+ self._statQlineEdit = {}
+ """list of legends actually displayed"""
+ self._n_statistics_per_line = 4
+ """number of statistics displayed per line in the grid layout"""
+ qt.QWidget.__init__(self, parent)
+ _StatsWidgetBase.__init__(self,
+ statsOnVisibleData=statsOnVisibleData,
+ displayOnlyActItem=True)
+ self.setLayout(self._createLayout())
+ self.setPlot(plot)
+ if stats is not None:
+ self.setStats(stats)
+
+ def _addItemForStatistic(self, statistic):
+ assert isinstance(statistic, statsmdl.StatBase)
+ assert statistic.name in self._statsHandler.stats
+
+ self.layout().setSpacing(2)
+ self.layout().setContentsMargins(2, 2, 2, 2)
+
+ if isinstance(self.layout(), qt.QGridLayout):
+ parent = self
+ else:
+ widget = qt.QWidget(parent=self)
+ parent = widget
+
+ qLabel = qt.QLabel(statistic.name + ':', parent=parent)
+ qLineEdit = qt.QLineEdit('', parent=parent)
+ qLineEdit.setReadOnly(True)
+
+ self._addStatsWidgetsToLayout(qLabel=qLabel, qLineEdit=qLineEdit)
+ self._statQlineEdit[statistic.name] = qLineEdit
+
+ def setPlot(self, plot):
+ """Define the plot to interact with
+
+ :param Union[PlotWidget,SceneWidget,None] plot:
+ The plot containing the items on which statistics are applied
+ """
+ _StatsWidgetBase.setPlot(self, plot)
+ self._updateAllStats()
+
+ def _addStatsWidgetsToLayout(self, qLabel, qLineEdit):
+ raise NotImplementedError('Base class')
+
+ def setStats(self, statsHandler):
+ """Set which stats to display and the associated formatting.
+
+ :param StatsHandler statsHandler:
+ Set the statistics to be displayed and how to format them using
+ """
+ _StatsWidgetBase.setStats(self, statsHandler)
+ for statName, stat in list(self._statsHandler.stats.items()):
+ self._addItemForStatistic(stat)
+ self._updateAllStats()
+
+ def _activeItemChanged(self, kind, previous, current):
+ if self.getUpdateMode() is UpdateMode.MANUAL:
+ return
+ if kind == self._item_kind:
+ self._updateAllStats()
+
+ def _updateAllStats(self):
+ plot = self.getPlot()
+ if plot is not None:
+ _items = self._plotWrapper.getSelectedItems()
+
+ def kind_filter(_item):
+ return self._plotWrapper.getKind(_item) == self.getKind()
+ items = list(filter(kind_filter, _items))
+ assert len(items) in (0, 1)
+ if len(items) == 1:
+ self._setItem(items[0])
+
+ def setKind(self, kind):
+ """Change the kind of active item to display
+ :param str kind: kind of item to display information for ('curve' ...)
+ """
+ if self._item_kind != kind:
+ self._item_kind = kind
+ self._updateItemObserve()
+
+ def getKind(self):
+ """
+ :return: kind of item we want to compute statistic for
+ :rtype: str
+ """
+ return self._item_kind
+
+ def _setItem(self, item, data_changed=True):
+ if item is None:
+ for stat_name, stat_widget in self._statQlineEdit.items():
+ stat_widget.setText('')
+ elif (self._statsHandler is not None and len(
+ self._statsHandler.stats) > 0):
+ plot = self.getPlot()
+ if plot is not None:
+ statsValDict = self._statsHandler.calculate(item,
+ plot,
+ self._statsOnVisibleData,
+ data_changed=data_changed)
+ for statName, statVal in list(statsValDict.items()):
+ self._statQlineEdit[statName].setText(statVal)
+
+ def _updateItemObserve(self, *argv):
+ if self.getUpdateMode() is UpdateMode.MANUAL:
+ return
+ assert self._displayOnlyActItem
+ _items = self._plotWrapper.getSelectedItems()
+
+ def kind_filter(_item):
+ return self._plotWrapper.getKind(_item) == self.getKind()
+ items = list(filter(kind_filter, _items))
+ assert len(items) in (0, 1)
+ _item = items[0] if len(items) == 1 else None
+ self._setItem(_item, data_changed=True)
+
+ def _updateCurrentItem(self):
+ self._updateItemObserve()
+
+ def _createLayout(self):
+ """create an instance of the main QLayout"""
+ raise NotImplementedError('Base class')
+
+ def _addItem(self, item):
+ raise NotImplementedError('Display only the active item')
+
+ def _removeItem(self, item):
+ raise NotImplementedError('Display only the active item')
+
+ def _plotCurrentChanged(self, current):
+ raise NotImplementedError('Display only the active item')
+
+ def _updateModeHasChanged(self):
+ self.sigUpdateModeChanged.emit(self._updateMode)
+
+
+class _BasicLineStatsWidget(_BaseLineStatsWidget):
+ def __init__(self, parent=None, plot=None, kind='curve',
+ stats=DEFAULT_STATS, statsOnVisibleData=False):
+ _BaseLineStatsWidget.__init__(self, parent=parent, kind=kind,
+ plot=plot, stats=stats,
+ statsOnVisibleData=statsOnVisibleData)
+
+ def _createLayout(self):
+ return FlowLayout()
+
+ def _addStatsWidgetsToLayout(self, qLabel, qLineEdit):
+ # create a mother widget to make sure both qLabel & qLineEdit will
+ # always be displayed side by side
+ widget = qt.QWidget(parent=self)
+ widget.setLayout(qt.QHBoxLayout())
+ widget.layout().setSpacing(0)
+ widget.layout().setContentsMargins(0, 0, 0, 0)
+
+ widget.layout().addWidget(qLabel)
+ widget.layout().addWidget(qLineEdit)
+
+ self.layout().addWidget(widget)
+
+ def _addOptionsWidget(self, widget):
+ self.layout().addWidget(widget)
+
+
+class BasicLineStatsWidget(qt.QWidget):
+ """
+ Widget defining a simple set of :class:`Stat` to be displayed on a
+ :class:`LineStatsWidget`.
+
+ :param QWidget parent: Qt parent
+ :param Union[PlotWidget,SceneWidget] plot:
+ The plot containing items on which we want statistics.
+ :param str kind: the kind of plotitems we want to display
+ :param StatsHandler stats:
+ Set the statistics to be displayed and how to format them using
+ :param bool statsOnVisibleData: compute statistics for the whole data or
+ only visible ones.
+ """
+ def __init__(self, parent=None, plot=None, kind='curve',
+ stats=DEFAULT_STATS, statsOnVisibleData=False):
+ qt.QWidget.__init__(self, parent)
+ self.setLayout(qt.QHBoxLayout())
+ self.layout().setSpacing(0)
+ self.layout().setContentsMargins(0, 0, 0, 0)
+ self._lineStatsWidget = _BasicLineStatsWidget(parent=self, plot=plot,
+ kind=kind, stats=stats,
+ statsOnVisibleData=statsOnVisibleData)
+ self.layout().addWidget(self._lineStatsWidget)
+
+ self._options = UpdateModeWidget()
+ self._options.setUpdateMode(self._lineStatsWidget.getUpdateMode())
+ self._options.showRadioButtons(False)
+ self.layout().addWidget(self._options)
+
+ # connect Signal ? SLOT
+ self._lineStatsWidget.sigUpdateModeChanged.connect(self._options.setUpdateMode)
+ self._options.sigUpdateModeChanged.connect(self._lineStatsWidget.setUpdateMode)
+ self._options.sigUpdateRequested.connect(self._lineStatsWidget._updateAllStats)
+
+ def showControl(self, visible):
+ self._options.setVisible(visible)
+
+ # Proxy methods
+
+ @docstring(_BasicLineStatsWidget)
+ def setUpdateMode(self, mode):
+ self._lineStatsWidget.setUpdateMode(mode=mode)
+
+ @docstring(_BasicLineStatsWidget)
+ def getUpdateMode(self):
+ return self._lineStatsWidget.getUpdateMode()
+
+ @docstring(_BasicLineStatsWidget)
+ def setPlot(self, plot):
+ self._lineStatsWidget.setPlot(plot=plot)
+
+ @docstring(_BasicLineStatsWidget)
+ def setStats(self, statsHandler):
+ self._lineStatsWidget.setStats(statsHandler=statsHandler)
+
+ @docstring(_BasicLineStatsWidget)
+ def setKind(self, kind):
+ self._lineStatsWidget.setKind(kind=kind)
+
+ @docstring(_BasicLineStatsWidget)
+ def getKind(self):
+ return self._lineStatsWidget.getKind()
+
+ @docstring(_BasicLineStatsWidget)
+ def setStatsOnVisibleData(self, b):
+ self._lineStatsWidget.setStatsOnVisibleData(b)
+
+ @docstring(UpdateModeWidget)
+ def showRadioButtons(self, show):
+ self._options.showRadioButtons(show=show)
+
+
+class _BasicGridStatsWidget(_BaseLineStatsWidget):
+ def __init__(self, parent=None, plot=None, kind='curve',
+ stats=DEFAULT_STATS, statsOnVisibleData=False,
+ statsPerLine=4):
+ _BaseLineStatsWidget.__init__(self, parent=parent, kind=kind,
+ plot=plot, stats=stats,
+ statsOnVisibleData=statsOnVisibleData)
+ self._n_statistics_per_line = statsPerLine
+
+ def _addStatsWidgetsToLayout(self, qLabel, qLineEdit):
+ column = len(self._statQlineEdit) % self._n_statistics_per_line
+ row = len(self._statQlineEdit) // self._n_statistics_per_line
+ self.layout().addWidget(qLabel, row, column * 2)
+ self.layout().addWidget(qLineEdit, row, column * 2 + 1)
+
+ def _createLayout(self):
+ return qt.QGridLayout()
+
+
+class BasicGridStatsWidget(qt.QWidget):
+ """
+ pymca design like widget
+
+ :param QWidget parent: Qt parent
+ :param Union[PlotWidget,SceneWidget] plot:
+ The plot containing items on which we want statistics.
+ :param StatsHandler stats:
+ Set the statistics to be displayed and how to format them using
+ :param str kind: the kind of plotitems we want to display
+ :param bool statsOnVisibleData: compute statistics for the whole data or
+ only visible ones.
+ :param int statsPerLine: number of statistic to be displayed per line
+
+ .. snapshotqt:: img/BasicGridStatsWidget.png
+ :width: 600px
+ :align: center
+
+ from silx.gui.plot import Plot1D
+ from silx.gui.plot.StatsWidget import BasicGridStatsWidget
+
+ plot = Plot1D()
+ x = range(100)
+ y = x
+ plot.addCurve(x, y, legend='curve_0')
+ plot.setActiveCurve('curve_0')
+
+ widget = BasicGridStatsWidget(plot=plot, kind='curve')
+ widget.show()
+ """
+
+ def __init__(self, parent=None, plot=None, kind='curve',
+ stats=DEFAULT_STATS, statsOnVisibleData=False):
+ qt.QWidget.__init__(self, parent)
+ self.setLayout(qt.QVBoxLayout())
+ self.layout().setSpacing(0)
+ self.layout().setContentsMargins(0, 0, 0, 0)
+
+ self._options = UpdateModeWidget()
+ self._options.showRadioButtons(False)
+ self.layout().addWidget(self._options)
+
+ self._lineStatsWidget = _BasicGridStatsWidget(parent=self, plot=plot,
+ kind=kind, stats=stats,
+ statsOnVisibleData=statsOnVisibleData)
+ self.layout().addWidget(self._lineStatsWidget)
+
+ # tune options
+ self._options.setUpdateMode(self._lineStatsWidget.getUpdateMode())
+
+ # connect Signal ? SLOT
+ self._lineStatsWidget.sigUpdateModeChanged.connect(self._options.setUpdateMode)
+ self._options.sigUpdateModeChanged.connect(self._lineStatsWidget.setUpdateMode)
+ self._options.sigUpdateRequested.connect(self._lineStatsWidget._updateAllStats)
+
+ def showControl(self, visible):
+ self._options.setVisible(visible)
+
+ @docstring(_BasicGridStatsWidget)
+ def setUpdateMode(self, mode):
+ self._lineStatsWidget.setUpdateMode(mode=mode)
+
+ @docstring(_BasicGridStatsWidget)
+ def getUpdateMode(self):
+ return self._lineStatsWidget.getUpdateMode()
+
+ @docstring(_BasicGridStatsWidget)
+ def setPlot(self, plot):
+ self._lineStatsWidget.setPlot(plot=plot)
+
+ @docstring(_BasicGridStatsWidget)
+ def setStats(self, statsHandler):
+ self._lineStatsWidget.setStats(statsHandler=statsHandler)
+
+ @docstring(_BasicGridStatsWidget)
+ def setKind(self, kind):
+ self._lineStatsWidget.setKind(kind=kind)
+
+ @docstring(_BasicGridStatsWidget)
+ def getKind(self):
+ return self._lineStatsWidget.getKind()
+
+ @docstring(_BasicGridStatsWidget)
+ def setStatsOnVisibleData(self, b):
+ self._lineStatsWidget.setStatsOnVisibleData(b)
+
+ @docstring(UpdateModeWidget)
+ def showRadioButtons(self, show):
+ self._options.showRadioButtons(show=show)
diff --git a/src/silx/gui/plot/_BaseMaskToolsWidget.py b/src/silx/gui/plot/_BaseMaskToolsWidget.py
new file mode 100644
index 0000000..407ab11
--- /dev/null
+++ b/src/silx/gui/plot/_BaseMaskToolsWidget.py
@@ -0,0 +1,1282 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module is a collection of base classes used in modules
+:mod:`.MaskToolsWidget` (images) and :mod:`.ScatterMaskToolsWidget`
+"""
+from __future__ import division
+
+__authors__ = ["T. Vincent", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "08/12/2020"
+
+import os
+import weakref
+
+import numpy
+
+from silx.gui import qt, icons
+from silx.gui.widgets.FloatEdit import FloatEdit
+from silx.gui.colors import Colormap
+from silx.gui.colors import rgba
+from .actions.mode import PanModeAction
+
+
+class BaseMask(qt.QObject):
+ """Base class for :class:`ImageMask` and :class:`ScatterMask`
+
+ A mask field with update operations.
+
+ A mask is an array of the same shape as some underlying data. The mask
+ array stores integer values in the range 0-255, to allow for 254 levels
+ of mask (value 0 is reserved for unmasked data).
+
+ The mask is updated using spatial selection methods: data located inside
+ a selected area is masked with a specified mask level.
+
+ """
+
+ sigChanged = qt.Signal()
+ """Signal emitted when the mask has changed"""
+
+ sigStateChanged = qt.Signal()
+ """Signal emitted for each mask commit/undo/redo operation"""
+
+ sigUndoable = qt.Signal(bool)
+ """Signal emitted when undo becomes possible/impossible"""
+
+ sigRedoable = qt.Signal(bool)
+ """Signal emitted when redo becomes possible/impossible"""
+
+ def __init__(self, dataItem=None):
+ self.historyDepth = 10
+ """Maximum number of operation stored in history list for undo"""
+ # Init lists for undo/redo
+ self._history = []
+ self._redo = []
+
+ # Store the mask
+ self._mask = numpy.array((), dtype=numpy.uint8)
+
+ # Store the plot item to be masked
+ self._dataItem = None
+ if dataItem is not None:
+ self.setDataItem(dataItem)
+ self.reset(self.getDataValues().shape)
+ super(BaseMask, self).__init__()
+
+ def setDataItem(self, item):
+ """Set a data item
+
+ :param item: A plot item, subclass of :class:`silx.gui.plot.items.Item`
+ :return:
+ """
+ self._dataItem = item
+
+ def getDataItem(self):
+ """Returns current plot item the mask is on.
+
+ :rtype: Union[~silx.gui.plot.items.Item,None]
+ """
+ return self._dataItem
+
+ def getDataValues(self):
+ """Return data values, as a numpy array with the same shape
+ as the mask.
+
+ This method must be implemented in a subclass, as the way of
+ accessing data depends on the data item passed to :meth:`setDataItem`
+
+ :return: Data values associated with the data item.
+ :rtype: numpy.ndarray
+ """
+ raise NotImplementedError("To be implemented in subclass")
+
+ def _notify(self):
+ """Notify of mask change."""
+ self.sigChanged.emit()
+
+ def getMask(self, copy=True):
+ """Get the current mask as a numpy array.
+
+ :param bool copy: True (default) to get a copy of the mask.
+ If False, the returned array MUST not be modified.
+ :return: The array of the mask with dimension of the data to be masked.
+ :rtype: numpy.ndarray of uint8
+ """
+ return numpy.array(self._mask, copy=copy)
+
+ def setMask(self, mask, copy=True):
+ """Set the mask to a new array.
+
+ :param numpy.ndarray mask: The array to use for the mask.
+ :type mask: numpy.ndarray of uint8, C-contiguous.
+ Array of other types are converted.
+ :param bool copy: True (the default) to copy the array,
+ False to use it as is if possible.
+ """
+ self._mask = numpy.array(mask, copy=copy, order='C', dtype=numpy.uint8)
+ self._notify()
+
+ # History control
+ def resetHistory(self):
+ """Reset history"""
+ self._history = [numpy.array(self._mask, copy=True)]
+ self._redo = []
+ self.sigUndoable.emit(False)
+ self.sigRedoable.emit(False)
+
+ def commit(self):
+ """Append the current mask to history if changed"""
+ if (not self._history or self._redo or
+ not numpy.array_equal(self._mask, self._history[-1])):
+ if self._redo:
+ self._redo = [] # Reset redo as a new action as been performed
+ self.sigRedoable[bool].emit(False)
+
+ while len(self._history) >= self.historyDepth:
+ self._history.pop(0)
+ self._history.append(numpy.array(self._mask, copy=True))
+
+ if len(self._history) == 2:
+ self.sigUndoable.emit(True)
+ self.sigStateChanged.emit()
+
+ def undo(self):
+ """Restore previous mask if any"""
+ if len(self._history) > 1:
+ self._redo.append(self._history.pop())
+ self._mask = numpy.array(self._history[-1], copy=True)
+ self._notify() # Do not store this change in history
+
+ if len(self._redo) == 1: # First redo
+ self.sigRedoable.emit(True)
+ if len(self._history) == 1: # Last value in history
+ self.sigUndoable.emit(False)
+ self.sigStateChanged.emit()
+
+ def redo(self):
+ """Restore previously undone modification if any"""
+ if self._redo:
+ self._mask = self._redo.pop()
+ self._history.append(numpy.array(self._mask, copy=True))
+ self._notify()
+
+ if not self._redo: # No more redo
+ self.sigRedoable.emit(False)
+ if len(self._history) == 2: # Something to undo
+ self.sigUndoable.emit(True)
+ self.sigStateChanged.emit()
+
+ # Whole mask operations
+
+ def clear(self, level):
+ """Set all values of the given mask level to 0.
+
+ :param int level: Value of the mask to set to 0.
+ """
+ assert 0 < level < 256
+ self._mask[self._mask == level] = 0
+ self._notify()
+
+ def invert(self, level):
+ """Invert mask of the given mask level.
+
+ 0 values become level and level values become 0.
+
+ :param int level: The level to invert.
+ """
+ assert 0 < level < 256
+ masked = self._mask == level
+ self._mask[self._mask == 0] = level
+ self._mask[masked] = 0
+ self._notify()
+
+ def reset(self, shape=None):
+ """Reset the mask to zero and change its shape.
+
+ :param shape: Shape of the new mask with the correct dimensionality
+ with regards to the data dimensionality,
+ or None to have an empty mask
+ :type shape: tuple of int
+ """
+ if shape is None:
+ # assume dimensionality never changes
+ shape = (0,) * len(self._mask.shape) # empty array
+ shapeChanged = (shape != self._mask.shape)
+ self._mask = numpy.zeros(shape, dtype=numpy.uint8)
+ if shapeChanged:
+ self.resetHistory()
+
+ self._notify()
+
+ # To be implemented
+ def save(self, filename, kind):
+ """Save current mask in a file
+
+ :param str filename: The file where to save to mask
+ :param str kind: The kind of file to save (e.g 'npy')
+ :raise Exception: Raised if the file writing fail
+ """
+ raise NotImplementedError("To be implemented in subclass")
+
+ # update thresholds
+ def updateStencil(self, level, stencil, mask=True):
+ """Mask/Unmask points from boolean mask: all elements that are True
+ in the boolean mask are set to ``level`` (if ``mask=True``) or 0
+ (if ``mask=False``)
+
+ :param int level: Mask level to update.
+ :param stencil: Boolean mask.
+ :type stencil: numpy.array of same dimension as the mask
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ if mask:
+ self._mask[stencil] = level
+ else:
+ self._mask[numpy.logical_and(self._mask == level, stencil)] = 0
+ self._notify()
+
+ def updateBelowThreshold(self, level, threshold, mask=True):
+ """Mask/unmask all points whose values are below a threshold.
+
+ :param int level:
+ :param float threshold: Threshold
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ self.updateStencil(level,
+ self.getDataValues() < threshold,
+ mask)
+
+ def updateBetweenThresholds(self, level, min_, max_, mask=True):
+ """Mask/unmask all points whose values are in a range.
+
+ :param int level:
+ :param float min_: Lower threshold
+ :param float max_: Upper threshold
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ stencil = numpy.logical_and(min_ <= self.getDataValues(),
+ self.getDataValues() <= max_)
+ self.updateStencil(level, stencil, mask)
+
+ def updateAboveThreshold(self, level, threshold, mask=True):
+ """Mask/unmask all points whose values are above a threshold.
+
+ :param int level: Mask level to update.
+ :param float threshold: Threshold.
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ self.updateStencil(level,
+ self.getDataValues() > threshold,
+ mask)
+
+ def updateNotFinite(self, level, mask=True):
+ """Mask/unmask all points whose values are not finite.
+
+ :param int level: Mask level to update.
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ self.updateStencil(level,
+ numpy.logical_not(numpy.isfinite(self.getDataValues())),
+ mask)
+
+ # Drawing operations:
+ def updateRectangle(self, level, row, col, height, width, mask=True):
+ """Mask/Unmask data inside a rectangle, with the given mask level.
+
+ :param int level: Mask level to update, in range 1-255.
+ :param row: Starting row/y of the rectangle
+ :param col: Starting column/x of the rectangle
+ :param height:
+ :param width:
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ raise NotImplementedError("To be implemented in subclass")
+
+ def updatePolygon(self, level, vertices, mask=True):
+ """Mask/Unmask data inside a polygon, with the given mask level.
+
+ :param int level: Mask level to update.
+ :param vertices: Nx2 array of polygon corners as (row, col) / (y, x)
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ raise NotImplementedError("To be implemented in subclass")
+
+ def updatePoints(self, level, rows, cols, mask=True):
+ """Mask/Unmask points with given coordinates.
+
+ :param int level: Mask level to update.
+ :param rows: Rows/ordinates (y) of selected points
+ :type rows: 1D numpy.ndarray
+ :param cols: Columns/abscissa (x) of selected points
+ :type cols: 1D numpy.ndarray
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ raise NotImplementedError("To be implemented in subclass")
+
+ def updateDisk(self, level, crow, ccol, radius, mask=True):
+ """Mask/Unmask data located inside a dick of the given mask level.
+
+ :param int level: Mask level to update.
+ :param crow: Disk center row/ordinate (y).
+ :param ccol: Disk center column/abscissa.
+ :param float radius: Radius of the disk in mask array unit
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ raise NotImplementedError("To be implemented in subclass")
+
+ def updateEllipse(self, level, crow, ccol, radius_r, radius_c, mask=True):
+ """Mask/Unmask a disk of the given mask level.
+
+ :param int level: Mask level to update.
+ :param int crow: Row of the center of the ellipse
+ :param int ccol: Column of the center of the ellipse
+ :param float radius_r: Radius of the ellipse in the row
+ :param float radius_c: Radius of the ellipse in the column
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ raise NotImplementedError("To be implemented in subclass")
+
+ def updateLine(self, level, row0, col0, row1, col1, width, mask=True):
+ """Mask/Unmask a line of the given mask level.
+
+ :param int level: Mask level to update.
+ :param row0: Row/y of the starting point.
+ :param col0: Column/x of the starting point.
+ :param row1: Row/y of the end point.
+ :param col1: Column/x of the end point.
+ :param width: Width of the line in mask array unit.
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ raise NotImplementedError("To be implemented in subclass")
+
+
+class BaseMaskToolsWidget(qt.QWidget):
+ """Base class for :class:`MaskToolsWidget` (image mask) and
+ :class:`scatterMaskToolsWidget`"""
+
+ sigMaskChanged = qt.Signal()
+ _maxLevelNumber = 255
+
+ def __init__(self, parent=None, plot=None, mask=None):
+ """
+
+ :param parent: Parent QWidget
+ :param plot: Plot widget on which to operate
+ :param mask: Instance of subclass of :class:`BaseMask`
+ (e.g. :class:`ImageMask`)
+ """
+ super(BaseMaskToolsWidget, self).__init__(parent)
+ # register if the user as force a color for the corresponding mask level
+ self._defaultColors = numpy.ones((self._maxLevelNumber + 1), dtype=bool)
+ # overlays colors set by the user
+ self._overlayColors = numpy.zeros((self._maxLevelNumber + 1, 3), dtype=numpy.float32)
+
+ # as parent have to be the first argument of the widget to fit
+ # QtDesigner need but here plot can't be None by default.
+ assert plot is not None
+ self._plotRef = weakref.ref(plot)
+ self._maskName = '__MASK_TOOLS_%d' % id(self) # Legend of the mask
+
+ self._colormap = Colormap(normalization='linear',
+ vmin=0,
+ vmax=self._maxLevelNumber)
+ self._defaultOverlayColor = rgba('gray') # Color of the mask
+ self._setMaskColors(1, 0.5) # Set the colormap LUT
+
+ if not isinstance(mask, BaseMask):
+ raise TypeError("mask is not an instance of BaseMask")
+ self._mask = mask
+
+ self._mask.sigChanged.connect(self._updatePlotMask)
+ self._mask.sigChanged.connect(self._emitSigMaskChanged)
+
+ self._drawingMode = None # Store current drawing mode
+ self._lastPencilPos = None
+ self._multipleMasks = 'exclusive'
+
+ self._maskFileDir = qt.QDir.home().absolutePath()
+ self.plot.sigInteractiveModeChanged.connect(
+ self._interactiveModeChanged)
+
+ self._initWidgets()
+
+ def _emitSigMaskChanged(self):
+ """Notify mask changes"""
+ self.sigMaskChanged.emit()
+
+ def getMaskedItem(self):
+ """Returns the item that is currently being masked
+
+ :rtype: Union[~silx.gui.plot.items.Item,None]
+ """
+ return self._mask.getDataItem()
+
+ def getSelectionMask(self, copy=True):
+ """Get the current mask as a numpy array.
+
+ :param bool copy: True (default) to get a copy of the mask.
+ If False, the returned array MUST not be modified.
+ :return: The mask (as an array of uint8) with dimension of
+ the 'active' plot item.
+ If there is no active image or scatter, it returns None.
+ :rtype: Union[numpy.ndarray,None]
+ """
+ mask = self._mask.getMask(copy=copy)
+ return None if mask.size == 0 else mask
+
+ def setSelectionMask(self, mask):
+ """Set the mask: Must be implemented in subclass"""
+ raise NotImplementedError()
+
+ def resetSelectionMask(self):
+ """Reset the mask: Must be implemented in subclass"""
+ raise NotImplementedError()
+
+ def multipleMasks(self):
+ """Return the current mode of multiple masks support.
+
+ See :meth:`setMultipleMasks`
+ """
+ return self._multipleMasks
+
+ def setMultipleMasks(self, mode):
+ """Set the mode of multiple masks support.
+
+ Available modes:
+
+ - 'single': Edit a single level of mask
+ - 'exclusive': Supports to 256 levels of non overlapping masks
+
+ :param str mode: The mode to use
+ """
+ assert mode in ('exclusive', 'single')
+ if mode != self._multipleMasks:
+ self._multipleMasks = mode
+ self._levelWidget.setVisible(self._multipleMasks != 'single')
+ self._clearAllBtn.setVisible(self._multipleMasks != 'single')
+
+ def setMaskFileDirectory(self, path):
+ """Set the default directory to use by load/save GUI tools
+
+ The directory is also updated by the user, if he change the location
+ of the dialog.
+ """
+ self.maskFileDir = path
+
+ def getMaskFileDirectory(self):
+ """Get the default directory used by load/save GUI tools"""
+ return self.maskFileDir
+
+ @property
+ def maskFileDir(self):
+ """The directory from which to load/save mask from/to files."""
+ if not os.path.isdir(self._maskFileDir):
+ self._maskFileDir = qt.QDir.home().absolutePath()
+ return self._maskFileDir
+
+ @maskFileDir.setter
+ def maskFileDir(self, maskFileDir):
+ self._maskFileDir = str(maskFileDir)
+
+ @property
+ def plot(self):
+ """The :class:`.PlotWindow` this widget is attached to."""
+ plot = self._plotRef()
+ if plot is None:
+ raise RuntimeError(
+ 'Mask widget attached to a PlotWidget that no longer exists')
+ return plot
+
+ def setDirection(self, direction=qt.QBoxLayout.LeftToRight):
+ """Set the direction of the layout of the widget
+
+ :param direction: QBoxLayout direction
+ """
+ self.layout().setDirection(direction)
+
+ def _initWidgets(self):
+ """Create widgets"""
+ layout = qt.QBoxLayout(qt.QBoxLayout.LeftToRight)
+ layout.addWidget(self._initMaskGroupBox())
+ layout.addWidget(self._initDrawGroupBox())
+ layout.addWidget(self._initThresholdGroupBox())
+ layout.addWidget(self._initOtherToolsGroupBox())
+ layout.addStretch(1)
+ self.setLayout(layout)
+
+ @staticmethod
+ def _hboxWidget(*widgets, **kwargs):
+ """Place widgets in widget with horizontal layout
+
+ :param widgets: Widgets to position horizontally
+ :param bool stretch: True for trailing stretch (default),
+ False for no trailing stretch
+ :return: A QWidget with a QHBoxLayout
+ """
+ stretch = kwargs.get('stretch', True)
+
+ layout = qt.QHBoxLayout()
+ layout.setContentsMargins(0, 0, 0, 0)
+ for widget in widgets:
+ layout.addWidget(widget)
+ if stretch:
+ layout.addStretch(1)
+ widget = qt.QWidget()
+ widget.setLayout(layout)
+ return widget
+
+ def _initTransparencyWidget(self):
+ """ Init the mask transparency widget """
+ transparencyWidget = qt.QWidget(parent=self)
+ grid = qt.QGridLayout()
+ grid.setContentsMargins(0, 0, 0, 0)
+ self.transparencySlider = qt.QSlider(qt.Qt.Horizontal, parent=transparencyWidget)
+ self.transparencySlider.setRange(3, 10)
+ self.transparencySlider.setValue(8)
+ self.transparencySlider.setToolTip(
+ 'Set the transparency of the mask display')
+ self.transparencySlider.valueChanged.connect(self._updateColors)
+ grid.addWidget(qt.QLabel('Display:', parent=transparencyWidget), 0, 0)
+ grid.addWidget(self.transparencySlider, 0, 1, 1, 3)
+ grid.addWidget(qt.QLabel('<small><b>Transparent</b></small>', parent=transparencyWidget), 1, 1)
+ grid.addWidget(qt.QLabel('<small><b>Opaque</b></small>', parent=transparencyWidget), 1, 3)
+ transparencyWidget.setLayout(grid)
+ return transparencyWidget
+
+ def _initMaskGroupBox(self):
+ """Init general mask operation widgets"""
+
+ # Mask level
+ self.levelSpinBox = qt.QSpinBox()
+ self.levelSpinBox.setRange(1, self._maxLevelNumber)
+ self.levelSpinBox.setToolTip(
+ 'Choose which mask level is edited.\n'
+ 'A mask can have up to 255 non-overlapping levels.')
+ self.levelSpinBox.valueChanged[int].connect(self._updateColors)
+ self._levelWidget = self._hboxWidget(qt.QLabel('Mask level:'),
+ self.levelSpinBox)
+ # Transparency
+ self._transparencyWidget = self._initTransparencyWidget()
+
+ style = qt.QApplication.style()
+
+ def getIcon(*identifiyers):
+ for i in identifiyers:
+ if isinstance(i, str):
+ if qt.QIcon.hasThemeIcon(i):
+ return qt.QIcon.fromTheme(i)
+ elif isinstance(i, qt.QIcon):
+ return i
+ else:
+ return style.standardIcon(i)
+ return qt.QIcon()
+
+ undoAction = qt.QAction(self)
+ undoAction.setText('Undo')
+ icon = getIcon("edit-undo", qt.QStyle.SP_ArrowBack)
+ undoAction.setIcon(icon)
+ undoAction.setShortcut(qt.QKeySequence.Undo)
+ undoAction.setToolTip('Undo last mask change <b>%s</b>' %
+ undoAction.shortcut().toString())
+ self._mask.sigUndoable.connect(undoAction.setEnabled)
+ undoAction.triggered.connect(self._mask.undo)
+
+ redoAction = qt.QAction(self)
+ redoAction.setText('Redo')
+ icon = getIcon("edit-redo", qt.QStyle.SP_ArrowForward)
+ redoAction.setIcon(icon)
+ redoAction.setShortcut(qt.QKeySequence.Redo)
+ redoAction.setToolTip('Redo last undone mask change <b>%s</b>' %
+ redoAction.shortcut().toString())
+ self._mask.sigRedoable.connect(redoAction.setEnabled)
+ redoAction.triggered.connect(self._mask.redo)
+
+ loadAction = qt.QAction(self)
+ loadAction.setText('Load...')
+ icon = icons.getQIcon("document-open")
+ loadAction.setIcon(icon)
+ loadAction.setToolTip('Load mask from file')
+ loadAction.triggered.connect(self._loadMask)
+
+ saveAction = qt.QAction(self)
+ saveAction.setText('Save...')
+ icon = icons.getQIcon("document-save")
+ saveAction.setIcon(icon)
+ saveAction.setToolTip('Save mask to file')
+ saveAction.triggered.connect(self._saveMask)
+
+ invertAction = qt.QAction(self)
+ invertAction.setText('Invert')
+ icon = icons.getQIcon("mask-invert")
+ invertAction.setIcon(icon)
+ invertAction.setShortcut(qt.Qt.CTRL + qt.Qt.Key_I)
+ invertAction.setToolTip('Invert current mask <b>%s</b>' %
+ invertAction.shortcut().toString())
+ invertAction.triggered.connect(self._handleInvertMask)
+
+ clearAction = qt.QAction(self)
+ clearAction.setText('Clear')
+ icon = icons.getQIcon("mask-clear")
+ clearAction.setIcon(icon)
+ clearAction.setShortcut(qt.QKeySequence.Delete)
+ clearAction.setToolTip('Clear current mask level <b>%s</b>' %
+ clearAction.shortcut().toString())
+ clearAction.triggered.connect(self._handleClearMask)
+
+ clearAllAction = qt.QAction(self)
+ clearAllAction.setText('Clear all')
+ icon = icons.getQIcon("mask-clear-all")
+ clearAllAction.setIcon(icon)
+ clearAllAction.setToolTip('Clear all mask levels')
+ clearAllAction.triggered.connect(self.resetSelectionMask)
+
+ # Buttons group
+ margin1 = qt.QWidget(self)
+ margin1.setMinimumWidth(6)
+ margin2 = qt.QWidget(self)
+ margin2.setMinimumWidth(6)
+
+ actions = (loadAction, saveAction, margin1,
+ undoAction, redoAction, margin2,
+ invertAction, clearAction, clearAllAction)
+ widgets = []
+ for action in actions:
+ if isinstance(action, qt.QWidget):
+ widgets.append(action)
+ continue
+ btn = qt.QToolButton()
+ btn.setDefaultAction(action)
+ widgets.append(btn)
+ if action is clearAllAction:
+ self._clearAllBtn = btn
+ container = self._hboxWidget(*widgets)
+ container.layout().setSpacing(1)
+
+ layout = qt.QVBoxLayout()
+ layout.addWidget(container)
+ layout.addWidget(self._levelWidget)
+ layout.addWidget(self._transparencyWidget)
+ layout.addStretch(1)
+
+ maskGroup = qt.QGroupBox('Mask')
+ maskGroup.setLayout(layout)
+ return maskGroup
+
+ def isMaskInteractionActivated(self):
+ """Returns true if any mask interaction is activated"""
+ return self.drawActionGroup.checkedAction() is not None
+
+ def _initDrawGroupBox(self):
+ """Init drawing tools widgets"""
+ layout = qt.QVBoxLayout()
+
+ self.browseAction = PanModeAction(self.plot, self.plot)
+ self.addAction(self.browseAction)
+
+ # Draw tools
+ self.rectAction = qt.QAction(icons.getQIcon('shape-rectangle'),
+ 'Rectangle selection',
+ self)
+ self.rectAction.setToolTip(
+ 'Rectangle selection tool: (Un)Mask a rectangular region <b>R</b>')
+ self.rectAction.setShortcut(qt.QKeySequence(qt.Qt.Key_R))
+ self.rectAction.setCheckable(True)
+ self.rectAction.triggered.connect(self._activeRectMode)
+ self.addAction(self.rectAction)
+
+ self.ellipseAction = qt.QAction(icons.getQIcon('shape-ellipse'),
+ 'Circle selection',
+ self)
+ self.ellipseAction.setToolTip(
+ 'Rectangle selection tool: (Un)Mask a circle region <b>R</b>')
+ self.ellipseAction.setShortcut(qt.QKeySequence(qt.Qt.Key_R))
+ self.ellipseAction.setCheckable(True)
+ self.ellipseAction.triggered.connect(self._activeEllipseMode)
+ self.addAction(self.ellipseAction)
+
+ self.polygonAction = qt.QAction(icons.getQIcon('shape-polygon'),
+ 'Polygon selection',
+ self)
+ self.polygonAction.setShortcut(qt.QKeySequence(qt.Qt.Key_S))
+ self.polygonAction.setToolTip(
+ 'Polygon selection tool: (Un)Mask a polygonal region <b>S</b><br>'
+ 'Left-click to place new polygon corners<br>'
+ 'Left-click on first corner to close the polygon')
+ self.polygonAction.setCheckable(True)
+ self.polygonAction.triggered.connect(self._activePolygonMode)
+ self.addAction(self.polygonAction)
+
+ self.pencilAction = qt.QAction(icons.getQIcon('draw-pencil'),
+ 'Pencil tool',
+ self)
+ self.pencilAction.setShortcut(qt.QKeySequence(qt.Qt.Key_P))
+ self.pencilAction.setToolTip(
+ 'Pencil tool: (Un)Mask using a pencil <b>P</b>')
+ self.pencilAction.setCheckable(True)
+ self.pencilAction.triggered.connect(self._activePencilMode)
+ self.addAction(self.pencilAction)
+
+ self.drawActionGroup = qt.QActionGroup(self)
+ self.drawActionGroup.setExclusive(True)
+ self.drawActionGroup.addAction(self.rectAction)
+ self.drawActionGroup.addAction(self.ellipseAction)
+ self.drawActionGroup.addAction(self.polygonAction)
+ self.drawActionGroup.addAction(self.pencilAction)
+
+ actions = (self.browseAction, self.rectAction, self.ellipseAction,
+ self.polygonAction, self.pencilAction)
+ drawButtons = []
+ for action in actions:
+ btn = qt.QToolButton()
+ btn.setDefaultAction(action)
+ drawButtons.append(btn)
+ container = self._hboxWidget(*drawButtons)
+ layout.addWidget(container)
+
+ # Mask/Unmask radio buttons
+ maskRadioBtn = qt.QRadioButton('Mask')
+ maskRadioBtn.setToolTip(
+ 'Drawing masks with current level. Press <b>Ctrl</b> to unmask')
+ maskRadioBtn.setChecked(True)
+
+ unmaskRadioBtn = qt.QRadioButton('Unmask')
+ unmaskRadioBtn.setToolTip(
+ 'Drawing unmasks with current level. Press <b>Ctrl</b> to mask')
+
+ self.maskStateGroup = qt.QButtonGroup()
+ self.maskStateGroup.addButton(maskRadioBtn, 1)
+ self.maskStateGroup.addButton(unmaskRadioBtn, 0)
+
+ self.maskStateWidget = self._hboxWidget(maskRadioBtn, unmaskRadioBtn)
+ layout.addWidget(self.maskStateWidget)
+
+ self.maskStateWidget.setHidden(True)
+
+ # Pencil settings
+ self.pencilSetting = self._createPencilSettings(None)
+ self.pencilSetting.setVisible(False)
+ layout.addWidget(self.pencilSetting)
+
+ layout.addStretch(1)
+
+ drawGroup = qt.QGroupBox('Draw tools')
+ drawGroup.setLayout(layout)
+ return drawGroup
+
+ def _createPencilSettings(self, parent=None):
+ pencilSetting = qt.QWidget(parent)
+
+ self.pencilSpinBox = qt.QSpinBox(parent=pencilSetting)
+ self.pencilSpinBox.setRange(1, 1024)
+ pencilToolTip = """Set pencil drawing tool size in pixels of the image
+ on which to make the mask."""
+ self.pencilSpinBox.setToolTip(pencilToolTip)
+
+ self.pencilSlider = qt.QSlider(qt.Qt.Horizontal, parent=pencilSetting)
+ self.pencilSlider.setRange(1, 50)
+ self.pencilSlider.setToolTip(pencilToolTip)
+
+ pencilLabel = qt.QLabel('Pencil size:', parent=pencilSetting)
+
+ layout = qt.QGridLayout()
+ layout.addWidget(pencilLabel, 0, 0)
+ layout.addWidget(self.pencilSpinBox, 0, 1)
+ layout.addWidget(self.pencilSlider, 1, 1)
+ pencilSetting.setLayout(layout)
+
+ self.pencilSpinBox.valueChanged.connect(self._pencilWidthChanged)
+ self.pencilSlider.valueChanged.connect(self._pencilWidthChanged)
+
+ return pencilSetting
+
+ def _initThresholdGroupBox(self):
+ """Init thresholding widgets"""
+
+ self.belowThresholdAction = qt.QAction(icons.getQIcon('plot-roi-below'),
+ 'Mask below threshold',
+ self)
+ self.belowThresholdAction.setToolTip(
+ 'Mask image where values are below given threshold')
+ self.belowThresholdAction.setCheckable(True)
+ self.belowThresholdAction.setChecked(True)
+
+ self.betweenThresholdAction = qt.QAction(icons.getQIcon('plot-roi-between'),
+ 'Mask within range',
+ self)
+ self.betweenThresholdAction.setToolTip(
+ 'Mask image where values are within given range')
+ self.betweenThresholdAction.setCheckable(True)
+
+ self.aboveThresholdAction = qt.QAction(icons.getQIcon('plot-roi-above'),
+ 'Mask above threshold',
+ self)
+ self.aboveThresholdAction.setToolTip(
+ 'Mask image where values are above given threshold')
+ self.aboveThresholdAction.setCheckable(True)
+
+ self.thresholdActionGroup = qt.QActionGroup(self)
+ self.thresholdActionGroup.setExclusive(True)
+ self.thresholdActionGroup.addAction(self.belowThresholdAction)
+ self.thresholdActionGroup.addAction(self.betweenThresholdAction)
+ self.thresholdActionGroup.addAction(self.aboveThresholdAction)
+ self.thresholdActionGroup.triggered.connect(
+ self._thresholdActionGroupTriggered)
+
+ self.loadColormapRangeAction = qt.QAction(icons.getQIcon('view-refresh'),
+ 'Set min-max from colormap',
+ self)
+ self.loadColormapRangeAction.setToolTip(
+ 'Set min and max values from current colormap range')
+ self.loadColormapRangeAction.setCheckable(False)
+ self.loadColormapRangeAction.triggered.connect(
+ self._loadRangeFromColormapTriggered)
+
+ widgets = []
+ for action in self.thresholdActionGroup.actions():
+ btn = qt.QToolButton()
+ btn.setDefaultAction(action)
+ widgets.append(btn)
+
+ spacer = qt.QWidget(parent=self)
+ spacer.setSizePolicy(qt.QSizePolicy.Expanding,
+ qt.QSizePolicy.Preferred)
+ widgets.append(spacer)
+
+ loadColormapRangeBtn = qt.QToolButton()
+ loadColormapRangeBtn.setDefaultAction(self.loadColormapRangeAction)
+ widgets.append(loadColormapRangeBtn)
+
+ toolBar = self._hboxWidget(*widgets, stretch=False)
+
+ config = qt.QGridLayout()
+ config.setContentsMargins(0, 0, 0, 0)
+
+ self.minLineLabel = qt.QLabel("Min:", self)
+ self.minLineEdit = FloatEdit(self, value=0)
+ config.addWidget(self.minLineLabel, 0, 0)
+ config.addWidget(self.minLineEdit, 0, 1)
+
+ self.maxLineLabel = qt.QLabel("Max:", self)
+ self.maxLineEdit = FloatEdit(self, value=0)
+ config.addWidget(self.maxLineLabel, 1, 0)
+ config.addWidget(self.maxLineEdit, 1, 1)
+
+ self.applyMaskBtn = qt.QPushButton('Apply mask')
+ self.applyMaskBtn.clicked.connect(self._maskBtnClicked)
+
+ layout = qt.QVBoxLayout()
+ layout.addWidget(toolBar)
+ layout.addLayout(config)
+ layout.addWidget(self.applyMaskBtn)
+ layout.addStretch(1)
+
+ self.thresholdGroup = qt.QGroupBox('Threshold')
+ self.thresholdGroup.setLayout(layout)
+
+ # Init widget state
+ self._thresholdActionGroupTriggered(self.belowThresholdAction)
+ return self.thresholdGroup
+
+ # track widget visibility and plot active image changes
+
+ def _initOtherToolsGroupBox(self):
+ layout = qt.QVBoxLayout()
+
+ self.maskNanBtn = qt.QPushButton('Mask not finite values')
+ self.maskNanBtn.setToolTip('Mask Not a Number and infinite values')
+ self.maskNanBtn.clicked.connect(self._maskNotFiniteBtnClicked)
+ layout.addWidget(self.maskNanBtn)
+ layout.addStretch(1)
+
+ self.otherToolGroup = qt.QGroupBox('Other tools')
+ self.otherToolGroup.setLayout(layout)
+ return self.otherToolGroup
+
+ def changeEvent(self, event):
+ """Reset drawing action when disabling widget"""
+ if (event.type() == qt.QEvent.EnabledChange and
+ not self.isEnabled() and
+ self.drawActionGroup.checkedAction()):
+ # Disable drawing tool by setting interaction to zoom
+ self.browseAction.trigger()
+
+ def save(self, filename, kind):
+ """Save current mask in a file
+
+ :param str filename: The file where to save to mask
+ :param str kind: The kind of file to save in 'edf', 'tif', 'npy'
+ :raise Exception: Raised if the process fails
+ """
+ self._mask.save(filename, kind)
+
+ def getCurrentMaskColor(self):
+ """Returns the color of the current selected level.
+
+ :rtype: A tuple or a python array
+ """
+ currentLevel = self.levelSpinBox.value()
+ if self._defaultColors[currentLevel]:
+ return self._defaultOverlayColor
+ else:
+ return self._overlayColors[currentLevel].tolist()
+
+ def _setMaskColors(self, level, alpha):
+ """Set-up the mask colormap to highlight current mask level.
+
+ :param int level: The mask level to highlight
+ :param float alpha: Alpha level of mask in [0., 1.]
+ """
+ assert 0 < level <= self._maxLevelNumber
+
+ colors = numpy.empty((self._maxLevelNumber + 1, 4), dtype=numpy.float32)
+
+ # Set color
+ colors[:,:3] = self._defaultOverlayColor[:3]
+
+ # check if some colors has been directly set by the user
+ mask = numpy.equal(self._defaultColors, False)
+ colors[mask,:3] = self._overlayColors[mask,:3]
+
+ # Set alpha
+ colors[:, -1] = alpha / 2.
+
+ # Set highlighted level color
+ colors[level, 3] = alpha
+
+ # Set no mask level
+ colors[0] = (0., 0., 0., 0.)
+
+ self._colormap.setColormapLUT(colors)
+
+ def resetMaskColors(self, level=None):
+ """Reset the mask color at the given level to be defaultColors
+
+ :param level:
+ The index of the mask for which we want to reset the color.
+ If none we will reset color for all masks.
+ """
+ if level is None:
+ self._defaultColors[level] = True
+ else:
+ self._defaultColors[:] = True
+
+ self._updateColors()
+
+ def setMaskColors(self, rgb, level=None):
+ """Set the masks color
+
+ :param rgb: The rgb color
+ :param level:
+ The index of the mask for which we want to change the color.
+ If none set this color for all the masks
+ """
+ rgb = rgba(rgb)[0:3]
+ if level is None:
+ self._overlayColors[:] = rgb
+ self._defaultColors[:] = False
+ else:
+ self._overlayColors[level] = rgb
+ self._defaultColors[level] = False
+
+ self._updateColors()
+
+ def getMaskColors(self):
+ """masks colors getter"""
+ return self._overlayColors
+
+ def _updateColors(self, *args):
+ """Rebuild mask colormap when selected level or transparency change"""
+ self._setMaskColors(self.levelSpinBox.value(),
+ self.transparencySlider.value() /
+ self.transparencySlider.maximum())
+ self._updatePlotMask()
+ self._updateInteractiveMode()
+
+ def _pencilWidthChanged(self, width):
+
+ old = self.pencilSpinBox.blockSignals(True)
+ try:
+ self.pencilSpinBox.setValue(width)
+ finally:
+ self.pencilSpinBox.blockSignals(old)
+
+ old = self.pencilSlider.blockSignals(True)
+ try:
+ self.pencilSlider.setValue(width)
+ finally:
+ self.pencilSlider.blockSignals(old)
+ self._updateInteractiveMode()
+
+ def _updateInteractiveMode(self):
+ """Update the current mode to the same if some cached data have to be
+ updated. It is the case for the color for example.
+ """
+ if self._drawingMode == 'rectangle':
+ self._activeRectMode()
+ elif self._drawingMode == 'ellipse':
+ self._activeEllipseMode()
+ elif self._drawingMode == 'polygon':
+ self._activePolygonMode()
+ elif self._drawingMode == 'pencil':
+ self._activePencilMode()
+
+ def _handleClearMask(self):
+ """Handle clear button clicked: reset current level mask"""
+ self._mask.clear(self.levelSpinBox.value())
+ self._mask.commit()
+
+ def _handleInvertMask(self):
+ """Invert the current mask level selection."""
+ self._mask.invert(self.levelSpinBox.value())
+ self._mask.commit()
+
+ # Handle drawing tools UI events
+
+ def _interactiveModeChanged(self, source):
+ """Handle plot interactive mode changed:
+
+ If changed from elsewhere, disable drawing tool
+ """
+ if source is not self:
+ self.pencilAction.setChecked(False)
+ self.rectAction.setChecked(False)
+ self.polygonAction.setChecked(False)
+ self._releaseDrawingMode()
+ self._updateDrawingModeWidgets()
+
+ def _releaseDrawingMode(self):
+ """Release the drawing mode if is was used"""
+ if self._drawingMode is None:
+ return
+ self.plot.sigPlotSignal.disconnect(self._plotDrawEvent)
+ self._drawingMode = None
+
+ def _activeRectMode(self):
+ """Handle rect action mode triggering"""
+ self._releaseDrawingMode()
+ self._drawingMode = 'rectangle'
+ self.plot.sigPlotSignal.connect(self._plotDrawEvent)
+ color = self.getCurrentMaskColor()
+ self.plot.setInteractiveMode(
+ 'draw', shape='rectangle', source=self, color=color)
+ self._updateDrawingModeWidgets()
+
+ def _activeEllipseMode(self):
+ """Handle circle action mode triggering"""
+ self._releaseDrawingMode()
+ self._drawingMode = 'ellipse'
+ self.plot.sigPlotSignal.connect(self._plotDrawEvent)
+ color = self.getCurrentMaskColor()
+ self.plot.setInteractiveMode(
+ 'draw', shape='ellipse', source=self, color=color)
+ self._updateDrawingModeWidgets()
+
+ def _activePolygonMode(self):
+ """Handle polygon action mode triggering"""
+ self._releaseDrawingMode()
+ self._drawingMode = 'polygon'
+ self.plot.sigPlotSignal.connect(self._plotDrawEvent)
+ color = self.getCurrentMaskColor()
+ self.plot.setInteractiveMode('draw', shape='polygon', source=self, color=color)
+ self._updateDrawingModeWidgets()
+
+ def _getPencilWidth(self):
+ """Returns the width of the pencil to use in data coordinates`
+
+ :rtype: float
+ """
+ return self.pencilSpinBox.value()
+
+ def _activePencilMode(self):
+ """Handle pencil action mode triggering"""
+ self._releaseDrawingMode()
+ self._drawingMode = 'pencil'
+ self.plot.sigPlotSignal.connect(self._plotDrawEvent)
+ color = self.getCurrentMaskColor()
+ width = self._getPencilWidth()
+ self.plot.setInteractiveMode(
+ 'draw', shape='pencil', source=self, color=color, width=width)
+ self._updateDrawingModeWidgets()
+
+ def _updateDrawingModeWidgets(self):
+ self.maskStateWidget.setVisible(self._drawingMode is not None)
+ self.pencilSetting.setVisible(self._drawingMode == 'pencil')
+
+ # Handle plot drawing events
+
+ def _isMasking(self):
+ """Returns true if the tool is used for masking, else it is used for
+ unmasking.
+
+ :rtype: bool"""
+ # First draw event, use current modifiers for all draw sequence
+ doMask = (self.maskStateGroup.checkedId() == 1)
+ if qt.QApplication.keyboardModifiers() & qt.Qt.ControlModifier:
+ doMask = not doMask
+ return doMask
+
+ # Handle threshold UI events
+
+ def _thresholdActionGroupTriggered(self, triggeredAction):
+ """Threshold action group listener."""
+ if triggeredAction is self.belowThresholdAction:
+ self.minLineLabel.setVisible(True)
+ self.maxLineLabel.setVisible(False)
+ self.minLineEdit.setVisible(True)
+ self.maxLineEdit.setVisible(False)
+ self.applyMaskBtn.setText("Mask below")
+ elif triggeredAction is self.betweenThresholdAction:
+ self.minLineLabel.setVisible(True)
+ self.maxLineLabel.setVisible(True)
+ self.minLineEdit.setVisible(True)
+ self.maxLineEdit.setVisible(True)
+ self.applyMaskBtn.setText("Mask between")
+ elif triggeredAction is self.aboveThresholdAction:
+ self.minLineLabel.setVisible(False)
+ self.maxLineLabel.setVisible(True)
+ self.minLineEdit.setVisible(False)
+ self.maxLineEdit.setVisible(True)
+ self.applyMaskBtn.setText("Mask above")
+ self.applyMaskBtn.setToolTip(triggeredAction.toolTip())
+
+ def _maskBtnClicked(self):
+ if self.belowThresholdAction.isChecked():
+ if self.minLineEdit.text():
+ self._mask.updateBelowThreshold(self.levelSpinBox.value(),
+ self.minLineEdit.value())
+ self._mask.commit()
+
+ elif self.betweenThresholdAction.isChecked():
+ if self.minLineEdit.text() and self.maxLineEdit.text():
+ min_ = self.minLineEdit.value()
+ max_ = self.maxLineEdit.value()
+ self._mask.updateBetweenThresholds(self.levelSpinBox.value(),
+ min_, max_)
+ self._mask.commit()
+
+ elif self.aboveThresholdAction.isChecked():
+ if self.maxLineEdit.text():
+ max_ = float(self.maxLineEdit.value())
+ self._mask.updateAboveThreshold(self.levelSpinBox.value(),
+ max_)
+ self._mask.commit()
+
+ def _maskNotFiniteBtnClicked(self):
+ """Handle not finite mask button clicked: mask NaNs and inf"""
+ self._mask.updateNotFinite(
+ self.levelSpinBox.value())
+ self._mask.commit()
+
+
+class BaseMaskToolsDockWidget(qt.QDockWidget):
+ """Base class for :class:`MaskToolsWidget` and
+ :class:`ScatterMaskToolsWidget`.
+
+ For integration in a :class:`PlotWindow`.
+
+ :param parent: See :class:`QDockWidget`
+ :paran str name: The title of this widget
+ """
+
+ sigMaskChanged = qt.Signal()
+
+ def __init__(self, parent=None, name='Mask', widget=None):
+ super(BaseMaskToolsDockWidget, self).__init__(parent)
+ self.setWindowTitle(name)
+
+ if not isinstance(widget, BaseMaskToolsWidget):
+ raise TypeError("BaseMaskToolsDockWidget requires a MaskToolsWidget")
+ self.setWidget(widget)
+ self.widget().sigMaskChanged.connect(self._emitSigMaskChanged)
+
+ self.layout().setContentsMargins(0, 0, 0, 0)
+ self.dockLocationChanged.connect(self._dockLocationChanged)
+ self.topLevelChanged.connect(self._topLevelChanged)
+
+ def _emitSigMaskChanged(self):
+ """Notify mask changes"""
+ # must be connected to self.widget().sigMaskChanged in child class
+ self.sigMaskChanged.emit()
+
+ def getSelectionMask(self, copy=True):
+ """Get the current mask as a 2D array.
+
+ :param bool copy: True (default) to get a copy of the mask.
+ If False, the returned array MUST not be modified.
+ :return: The array of the mask with dimension of the 'active' image.
+ If there is no active image, an empty array is returned.
+ :rtype: 2D numpy.ndarray of uint8
+ """
+ return self.widget().getSelectionMask(copy=copy)
+
+ def setSelectionMask(self, mask, copy=True):
+ """Set the mask to a new array.
+
+ :param numpy.ndarray mask: The array to use for the mask.
+ :type mask: numpy.ndarray of uint8 of dimension 2, C-contiguous.
+ Array of other types are converted.
+ :param bool copy: True (the default) to copy the array,
+ False to use it as is if possible.
+ :return: None if failed, shape of mask as 2-tuple if successful.
+ The mask can be cropped or padded to fit active image,
+ the returned shape is that of the active image.
+ """
+ return self.widget().setSelectionMask(mask, copy=copy)
+
+ def resetSelectionMask(self):
+ """Reset the mask to an array of zeros with the shape of the
+ current data."""
+ self.widget().resetSelectionMask()
+
+ def toggleViewAction(self):
+ """Returns a checkable action that shows or closes this widget.
+
+ See :class:`QMainWindow`.
+ """
+ action = super(BaseMaskToolsDockWidget, self).toggleViewAction()
+ action.setIcon(icons.getQIcon('image-mask'))
+ action.setToolTip("Display/hide mask tools")
+ return action
+
+ def _dockLocationChanged(self, area):
+ if area in (qt.Qt.LeftDockWidgetArea, qt.Qt.RightDockWidgetArea):
+ direction = qt.QBoxLayout.TopToBottom
+ else:
+ direction = qt.QBoxLayout.LeftToRight
+ self.widget().setDirection(direction)
+
+ def _topLevelChanged(self, topLevel):
+ if topLevel:
+ self.widget().setDirection(qt.QBoxLayout.LeftToRight)
+ self.resize(self.widget().minimumSize())
+ self.adjustSize()
+
+ def showEvent(self, event):
+ """Make sure this widget is raised when it is shown
+ (when it is first created as a tab in PlotWindow or when it is shown
+ again after hiding).
+ """
+ self.raise_()
diff --git a/src/silx/gui/plot/__init__.py b/src/silx/gui/plot/__init__.py
new file mode 100644
index 0000000..3a141b3
--- /dev/null
+++ b/src/silx/gui/plot/__init__.py
@@ -0,0 +1,71 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This package provides a set of Qt widgets for plotting curves and images.
+
+The plotting API is inherited from the `PyMca <http://pymca.sourceforge.net/>`_
+plot API and is mostly compatible with it.
+
+Those widgets supports interaction (e.g., zoom, pan, selections).
+
+List of Qt widgets:
+
+.. currentmodule:: silx.gui.plot
+
+- :mod:`.PlotWidget`: A widget displaying a single plot.
+- :mod:`.PlotWindow`: A :mod:`.PlotWidget` with a configurable set of tools.
+- :class:`.Plot1D`: A widget with tools for curves.
+- :class:`.Plot2D`: A widget with tools for images.
+- :class:`.ScatterView`: A widget with tools for scatter plot.
+- :class:`.ImageView`: A widget with tools for images and a side histogram.
+- :class:`.StackView`: A widget with tools for a stack of images.
+
+By default, those widget are using matplotlib_.
+They can optionally use a faster OpenGL-based rendering (beta feature),
+which is enabled by setting the ``backend`` argument to ``'gl'``
+when creating the widgets (See :class:`.PlotWidget`).
+
+.. note::
+
+ This package depends on matplotlib_.
+ The OpenGL backend further depends on
+ `PyOpenGL <http://pyopengl.sourceforge.net/>`_ and OpenGL >= 2.1.
+
+.. _matplotlib: http://matplotlib.org/
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "03/05/2017"
+
+
+from .PlotWidget import PlotWidget # noqa
+from .PlotWindow import PlotWindow, Plot1D, Plot2D # noqa
+from .items.axis import TickMode
+from .ImageView import ImageView # noqa
+from .StackView import StackView # noqa
+from .ScatterView import ScatterView # noqa
+
+__all__ = ['ImageView', 'PlotWidget', 'PlotWindow', 'Plot1D', 'Plot2D',
+ 'StackView', 'ScatterView', 'TickMode']
diff --git a/src/silx/gui/plot/_utils/__init__.py b/src/silx/gui/plot/_utils/__init__.py
new file mode 100644
index 0000000..ed87b18
--- /dev/null
+++ b/src/silx/gui/plot/_utils/__init__.py
@@ -0,0 +1,92 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Miscellaneous utility functions for the Plot"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "21/03/2017"
+
+
+import numpy
+
+from .panzoom import FLOAT32_SAFE_MIN, FLOAT32_MINPOS, FLOAT32_SAFE_MAX
+from .panzoom import applyZoomToPlot, applyPan, checkAxisLimits
+
+
+def addMarginsToLimits(margins, isXLog, isYLog,
+ xMin, xMax, yMin, yMax, y2Min=None, y2Max=None):
+ """Returns updated limits by extending them with margins.
+
+ :param margins: The ratio of the margins to add or None for no margins.
+ :type margins: A 4-tuple of floats as
+ (xMinMargin, xMaxMargin, yMinMargin, yMaxMargin)
+
+ :return: The updated limits
+ :rtype: tuple of 4 or 6 floats: Either (xMin, xMax, yMin, yMax) or
+ (xMin, xMax, yMin, yMax, y2Min, y2Max) if y2Min and y2Max
+ are provided.
+ """
+ if margins is not None:
+ xMinMargin, xMaxMargin, yMinMargin, yMaxMargin = margins
+
+ if not isXLog:
+ xRange = xMax - xMin
+ xMin -= xMinMargin * xRange
+ xMax += xMaxMargin * xRange
+
+ elif xMin > 0. and xMax > 0.: # Log scale
+ # Do not apply margins if limits < 0
+ xMinLog, xMaxLog = numpy.log10(xMin), numpy.log10(xMax)
+ xRangeLog = xMaxLog - xMinLog
+ xMin = pow(10., xMinLog - xMinMargin * xRangeLog)
+ xMax = pow(10., xMaxLog + xMaxMargin * xRangeLog)
+
+ if not isYLog:
+ yRange = yMax - yMin
+ yMin -= yMinMargin * yRange
+ yMax += yMaxMargin * yRange
+ elif yMin > 0. and yMax > 0.: # Log scale
+ # Do not apply margins if limits < 0
+ yMinLog, yMaxLog = numpy.log10(yMin), numpy.log10(yMax)
+ yRangeLog = yMaxLog - yMinLog
+ yMin = pow(10., yMinLog - yMinMargin * yRangeLog)
+ yMax = pow(10., yMaxLog + yMaxMargin * yRangeLog)
+
+ if y2Min is not None and y2Max is not None:
+ if not isYLog:
+ yRange = y2Max - y2Min
+ y2Min -= yMinMargin * yRange
+ y2Max += yMaxMargin * yRange
+ elif y2Min > 0. and y2Max > 0.: # Log scale
+ # Do not apply margins if limits < 0
+ yMinLog, yMaxLog = numpy.log10(y2Min), numpy.log10(y2Max)
+ yRangeLog = yMaxLog - yMinLog
+ y2Min = pow(10., yMinLog - yMinMargin * yRangeLog)
+ y2Max = pow(10., yMaxLog + yMaxMargin * yRangeLog)
+
+ if y2Min is None or y2Max is None:
+ return xMin, xMax, yMin, yMax
+ else:
+ return xMin, xMax, yMin, yMax, y2Min, y2Max
diff --git a/src/silx/gui/plot/_utils/delaunay.py b/src/silx/gui/plot/_utils/delaunay.py
new file mode 100644
index 0000000..49ad05f
--- /dev/null
+++ b/src/silx/gui/plot/_utils/delaunay.py
@@ -0,0 +1,62 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Wrapper over Delaunay implementation"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "02/05/2019"
+
+
+import logging
+import sys
+
+import numpy
+
+
+_logger = logging.getLogger(__name__)
+
+
+def delaunay(x, y):
+ """Returns Delaunay instance for x, y points
+
+ :param numpy.ndarray x:
+ :param numpy.ndarray y:
+ :rtype: Union[None,scipy.spatial.Delaunay]
+ """
+ # Lazy-loading of Delaunay
+ try:
+ from scipy.spatial import Delaunay as _Delaunay
+ except ImportError: # Fallback using local Delaunay
+ from silx.third_party.scipy_spatial import Delaunay as _Delaunay
+
+ points = numpy.array((x, y)).T
+ try:
+ delaunay = _Delaunay(points)
+ except (RuntimeError, ValueError):
+ _logger.error("Delaunay tesselation failed: %s",
+ sys.exc_info()[1])
+ delaunay = None
+
+ return delaunay
diff --git a/src/silx/gui/plot/_utils/dtime_ticklayout.py b/src/silx/gui/plot/_utils/dtime_ticklayout.py
new file mode 100644
index 0000000..ebf775b
--- /dev/null
+++ b/src/silx/gui/plot/_utils/dtime_ticklayout.py
@@ -0,0 +1,442 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2014-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module implements date-time labels layout on graph axes."""
+
+from __future__ import absolute_import, division, unicode_literals
+
+__authors__ = ["P. Kenter"]
+__license__ = "MIT"
+__date__ = "04/04/2018"
+
+
+import datetime as dt
+import enum
+import logging
+import math
+import time
+
+import dateutil.tz
+
+from dateutil.relativedelta import relativedelta
+
+from .ticklayout import niceNumGeneric
+
+_logger = logging.getLogger(__name__)
+
+
+MICROSECONDS_PER_SECOND = 1000000
+SECONDS_PER_MINUTE = 60
+SECONDS_PER_HOUR = 60 * SECONDS_PER_MINUTE
+SECONDS_PER_DAY = 24 * SECONDS_PER_HOUR
+SECONDS_PER_YEAR = 365.25 * SECONDS_PER_DAY
+SECONDS_PER_MONTH_AVERAGE = SECONDS_PER_YEAR / 12 # Seconds per average month
+
+
+# No dt.timezone in Python 2.7 so we use dateutil.tz.tzutc
+_EPOCH = dt.datetime(1970, 1, 1, tzinfo=dateutil.tz.tzutc())
+
+def timestamp(dtObj):
+ """ Returns POSIX timestamp of a datetime objects.
+
+ If the dtObj object has a timestamp() method (python 3.3), this is
+ used. Otherwise (e.g. python 2.7) it is calculated here.
+
+ The POSIX timestamp is a floating point value of the number of seconds
+ since the start of an epoch (typically 1970-01-01). For details see:
+ https://docs.python.org/3/library/datetime.html#datetime.datetime.timestamp
+
+ :param datetime.datetime dtObj: date-time representation.
+ :return: POSIX timestamp
+ :rtype: float
+ """
+ if hasattr(dtObj, "timestamp"):
+ return dtObj.timestamp()
+ else:
+ # Back ported from Python 3.5
+ if dtObj.tzinfo is None:
+ return time.mktime((dtObj.year, dtObj.month, dtObj.day,
+ dtObj.hour, dtObj.minute, dtObj.second,
+ -1, -1, -1)) + dtObj.microsecond / 1e6
+ else:
+ return (dtObj - _EPOCH).total_seconds()
+
+
+@enum.unique
+class DtUnit(enum.Enum):
+ YEARS = 0
+ MONTHS = 1
+ DAYS = 2
+ HOURS = 3
+ MINUTES = 4
+ SECONDS = 5
+ MICRO_SECONDS = 6 # a fraction of a second
+
+
+def getDateElement(dateTime, unit):
+ """ Picks the date element with the unit from the dateTime
+
+ E.g. getDateElement(datetime(1970, 5, 6), DtUnit.Day) will return 6
+
+ :param datetime dateTime: date/time to pick from
+ :param DtUnit unit: The unit describing the date element.
+ """
+ if unit == DtUnit.YEARS:
+ return dateTime.year
+ elif unit == DtUnit.MONTHS:
+ return dateTime.month
+ elif unit == DtUnit.DAYS:
+ return dateTime.day
+ elif unit == DtUnit.HOURS:
+ return dateTime.hour
+ elif unit == DtUnit.MINUTES:
+ return dateTime.minute
+ elif unit == DtUnit.SECONDS:
+ return dateTime.second
+ elif unit == DtUnit.MICRO_SECONDS:
+ return dateTime.microsecond
+ else:
+ raise ValueError("Unexpected DtUnit: {}".format(unit))
+
+
+def setDateElement(dateTime, value, unit):
+ """ Returns a copy of dateTime with the tickStep unit set to value
+
+ :param datetime.datetime: date time object
+ :param int value: value to set
+ :param DtUnit unit: unit
+ :return: datetime.datetime
+ """
+ intValue = int(value)
+ _logger.debug("setDateElement({}, {} (int={}), {})"
+ .format(dateTime, value, intValue, unit))
+
+ year = dateTime.year
+ month = dateTime.month
+ day = dateTime.day
+ hour = dateTime.hour
+ minute = dateTime.minute
+ second = dateTime.second
+ microsecond = dateTime.microsecond
+
+ if unit == DtUnit.YEARS:
+ year = intValue
+ elif unit == DtUnit.MONTHS:
+ month = intValue
+ elif unit == DtUnit.DAYS:
+ day = intValue
+ elif unit == DtUnit.HOURS:
+ hour = intValue
+ elif unit == DtUnit.MINUTES:
+ minute = intValue
+ elif unit == DtUnit.SECONDS:
+ second = intValue
+ elif unit == DtUnit.MICRO_SECONDS:
+ microsecond = intValue
+ else:
+ raise ValueError("Unexpected DtUnit: {}".format(unit))
+
+ _logger.debug("creating date time {}"
+ .format((year, month, day, hour, minute, second, microsecond)))
+
+ return dt.datetime(year, month, day, hour, minute, second, microsecond,
+ tzinfo=dateTime.tzinfo)
+
+
+
+def roundToElement(dateTime, unit):
+ """ Returns a copy of dateTime rounded to given unit
+
+ :param datetime.datetime: date time object
+ :param DtUnit unit: unit
+ :return: datetime.datetime
+ """
+ year = dateTime.year
+ month = dateTime.month
+ day = dateTime.day
+ hour = dateTime.hour
+ minute = dateTime.minute
+ second = dateTime.second
+ microsecond = dateTime.microsecond
+
+ if unit.value < DtUnit.YEARS.value:
+ pass # Never round years
+ if unit.value < DtUnit.MONTHS.value:
+ month = 1
+ if unit.value < DtUnit.DAYS.value:
+ day = 1
+ if unit.value < DtUnit.HOURS.value:
+ hour = 0
+ if unit.value < DtUnit.MINUTES.value:
+ minute = 0
+ if unit.value < DtUnit.SECONDS.value:
+ second = 0
+ if unit.value < DtUnit.MICRO_SECONDS.value:
+ microsecond = 0
+
+ result = dt.datetime(year, month, day, hour, minute, second, microsecond,
+ tzinfo=dateTime.tzinfo)
+
+ return result
+
+
+def addValueToDate(dateTime, value, unit):
+ """ Adds a value with unit to a dateTime.
+
+ Uses dateutil.relativedelta.relativedelta from the standard library to do
+ the actual math. This function doesn't allow for fractional month or years,
+ so month and year are truncated to integers before adding.
+
+ :param datetime dateTime: date time
+ :param float value: value to be added
+ :param DtUnit unit: of the value
+ :return:
+ """
+ #logger.debug("addValueToDate({}, {}, {})".format(dateTime, value, unit))
+
+ if unit == DtUnit.YEARS:
+ intValue = int(value) # floats not implemented in relativeDelta(years)
+ return dateTime + relativedelta(years=intValue)
+ elif unit == DtUnit.MONTHS:
+ intValue = int(value) # floats not implemented in relativeDelta(mohths)
+ return dateTime + relativedelta(months=intValue)
+ elif unit == DtUnit.DAYS:
+ return dateTime + relativedelta(days=value)
+ elif unit == DtUnit.HOURS:
+ return dateTime + relativedelta(hours=value)
+ elif unit == DtUnit.MINUTES:
+ return dateTime + relativedelta(minutes=value)
+ elif unit == DtUnit.SECONDS:
+ return dateTime + relativedelta(seconds=value)
+ elif unit == DtUnit.MICRO_SECONDS:
+ return dateTime + relativedelta(microseconds=value)
+ else:
+ raise ValueError("Unexpected DtUnit: {}".format(unit))
+
+
+def bestUnit(durationInSeconds):
+ """ Gets the best tick spacing given a duration in seconds.
+
+ :param durationInSeconds: time span duration in seconds
+ :return: DtUnit enumeration.
+ """
+
+ # Based on; https://stackoverflow.com/a/2144398/
+ # If the duration is longer than two years the tick spacing will be in
+ # years. Else, if the duration is longer than two months, the spacing will
+ # be in months, Etcetera.
+ #
+ # This factor differs per unit. As a baseline it is 2, but for instance,
+ # for Months this needs to be higher (3>), This because it is impossible to
+ # have partial months so the tick spacing is always at least 1 month. A
+ # duration of two months would result in two ticks, which is too few.
+ # months would then results
+
+ if durationInSeconds > SECONDS_PER_YEAR * 3:
+ return (durationInSeconds / SECONDS_PER_YEAR, DtUnit.YEARS)
+ elif durationInSeconds > SECONDS_PER_MONTH_AVERAGE * 3:
+ return (durationInSeconds / SECONDS_PER_MONTH_AVERAGE, DtUnit.MONTHS)
+ elif durationInSeconds > SECONDS_PER_DAY * 2:
+ return (durationInSeconds / SECONDS_PER_DAY, DtUnit.DAYS)
+ elif durationInSeconds > SECONDS_PER_HOUR * 2:
+ return (durationInSeconds / SECONDS_PER_HOUR, DtUnit.HOURS)
+ elif durationInSeconds > SECONDS_PER_MINUTE * 2:
+ return (durationInSeconds / SECONDS_PER_MINUTE, DtUnit.MINUTES)
+ elif durationInSeconds > 1 * 2:
+ return (durationInSeconds, DtUnit.SECONDS)
+ else:
+ return (durationInSeconds * MICROSECONDS_PER_SECOND,
+ DtUnit.MICRO_SECONDS)
+
+
+NICE_DATE_VALUES = {
+ DtUnit.YEARS: [1, 2, 5, 10],
+ DtUnit.MONTHS: [1, 2, 3, 4, 6, 12],
+ DtUnit.DAYS: [1, 2, 3, 7, 14, 28],
+ DtUnit.HOURS: [1, 2, 3, 4, 6, 12],
+ DtUnit.MINUTES: [1, 2, 3, 5, 10, 15, 30],
+ DtUnit.SECONDS: [1, 2, 3, 5, 10, 15, 30],
+ DtUnit.MICRO_SECONDS : [1.0, 2.0, 5.0, 10.0], # floats for microsec
+}
+
+
+def bestFormatString(spacing, unit):
+ """ Finds the best format string given the spacing and DtUnit.
+
+ If the spacing is a fractional number < 1 the format string will take this
+ into account
+
+ :param spacing: spacing between ticks
+ :param DtUnit unit:
+ :return: Format string for use in strftime
+ :rtype: str
+ """
+ isSmall = spacing < 1
+
+ if unit == DtUnit.YEARS:
+ return "%Y-m" if isSmall else "%Y"
+ elif unit == DtUnit.MONTHS:
+ return "%Y-%m-%d" if isSmall else "%Y-%m"
+ elif unit == DtUnit.DAYS:
+ return "%H:%M" if isSmall else "%Y-%m-%d"
+ elif unit == DtUnit.HOURS:
+ return "%H:%M" if isSmall else "%H:%M"
+ elif unit == DtUnit.MINUTES:
+ return "%H:%M:%S" if isSmall else "%H:%M"
+ elif unit == DtUnit.SECONDS:
+ return "%S.%f" if isSmall else "%H:%M:%S"
+ elif unit == DtUnit.MICRO_SECONDS:
+ return "%S.%f"
+ else:
+ raise ValueError("Unexpected DtUnit: {}".format(unit))
+
+
+def niceDateTimeElement(value, unit, isRound=False):
+ """ Uses the Nice Numbers algorithm to determine a nice value.
+
+ The fractions are optimized for the unit of the date element.
+ """
+
+ niceValues = NICE_DATE_VALUES[unit]
+ elemValue = niceNumGeneric(value, niceValues, isRound=isRound)
+
+ if unit == DtUnit.YEARS or unit == DtUnit.MONTHS:
+ elemValue = max(1, int(elemValue))
+
+ return elemValue
+
+
+def findStartDate(dMin, dMax, nTicks):
+ """ Rounds a date down to the nearest nice number of ticks
+ """
+ assert dMax >= dMin, \
+ "dMin ({}) should come before dMax ({})".format(dMin, dMax)
+
+ if dMin == dMax:
+ # Fallback when range is smaller than microsecond resolution
+ return dMin, 1, DtUnit.MICRO_SECONDS
+
+ delta = dMax - dMin
+ lengthSec = delta.total_seconds()
+ _logger.debug("findStartDate: {}, {} (duration = {} sec, {} days)"
+ .format(dMin, dMax, lengthSec, lengthSec / SECONDS_PER_DAY))
+
+ length, unit = bestUnit(lengthSec)
+ niceLength = niceDateTimeElement(length, unit)
+
+ _logger.debug("Length: {:8.3f} {} (nice = {})"
+ .format(length, unit.name, niceLength))
+
+ niceSpacing = niceDateTimeElement(niceLength / nTicks, unit, isRound=True)
+
+ _logger.debug("Spacing: {:8.3f} {} (nice = {})"
+ .format(niceLength / nTicks, unit.name, niceSpacing))
+
+ dVal = getDateElement(dMin, unit)
+
+ if unit == DtUnit.MONTHS: # TODO: better rounding?
+ niceVal = math.floor((dVal-1) / niceSpacing) * niceSpacing + 1
+ elif unit == DtUnit.DAYS:
+ niceVal = math.floor((dVal-1) / niceSpacing) * niceSpacing + 1
+ else:
+ niceVal = math.floor(dVal / niceSpacing) * niceSpacing
+
+ _logger.debug("StartValue: dVal = {}, niceVal: {} ({})"
+ .format(dVal, niceVal, unit.name))
+
+ startDate = roundToElement(dMin, unit)
+ startDate = setDateElement(startDate, niceVal, unit)
+
+ return startDate, niceSpacing, unit
+
+
+def dateRange(dMin, dMax, step, unit, includeFirstBeyond = False):
+ """ Generates a range of dates
+
+ :param datetime dMin: start date
+ :param datetime dMax: end date
+ :param int step: the step size
+ :param DtUnit unit: the unit of the step size
+ :param bool includeFirstBeyond: if True the first date later than dMax will
+ be included in the range. If False (the default), the last generated
+ datetime will always be smaller than dMax.
+ :return:
+ """
+ if (unit == DtUnit.YEARS or unit == DtUnit.MONTHS or
+ unit == DtUnit.MICRO_SECONDS):
+ # No support for fractional month or year and resolution is microsecond
+ # In those cases, make sure the step is at least 1
+ step = max(1, step)
+ else:
+ assert step > 0, "tickstep is 0"
+
+ dateTime = dMin
+ while dateTime < dMax:
+ yield dateTime
+ dateTime = addValueToDate(dateTime, step, unit)
+
+ if includeFirstBeyond:
+ yield dateTime
+
+
+
+def calcTicks(dMin, dMax, nTicks):
+ """Returns tick positions.
+
+ :param datetime.datetime dMin: The min value on the axis
+ :param datetime.datetime dMax: The max value on the axis
+ :param int nTicks: The target number of ticks. The actual number of found
+ ticks may differ.
+ :returns: (list of datetimes, DtUnit) tuple
+ """
+ _logger.debug("Calc calcTicks({}, {}, nTicks={})"
+ .format(dMin, dMax, nTicks))
+
+ startDate, niceSpacing, unit = findStartDate(dMin, dMax, nTicks)
+
+ result = []
+ for d in dateRange(startDate, dMax, niceSpacing, unit,
+ includeFirstBeyond=True):
+ result.append(d)
+
+ assert result[0] <= dMin, \
+ "First nice date ({}) should be <= dMin {}".format(result[0], dMin)
+
+ assert result[-1] >= dMax, \
+ "Last nice date ({}) should be >= dMax {}".format(result[-1], dMax)
+
+ return result, niceSpacing, unit
+
+
+def calcTicksAdaptive(dMin, dMax, axisLength, tickDensity):
+ """ Calls calcTicks with a variable number of ticks, depending on axisLength
+ """
+ # At least 2 ticks
+ nticks = max(2, int(round(tickDensity * axisLength)))
+ return calcTicks(dMin, dMax, nticks)
+
+
+
+
+
diff --git a/src/silx/gui/plot/_utils/panzoom.py b/src/silx/gui/plot/_utils/panzoom.py
new file mode 100644
index 0000000..77efd10
--- /dev/null
+++ b/src/silx/gui/plot/_utils/panzoom.py
@@ -0,0 +1,325 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Functions to apply pan and zoom on a Plot"""
+
+__authors__ = ["T. Vincent", "V. Valls"]
+__license__ = "MIT"
+__date__ = "08/08/2017"
+
+
+import logging
+import math
+import numpy
+
+
+_logger = logging.getLogger(__name__)
+
+
+# Float 32 info ###############################################################
+# Using min/max value below limits of float32
+# so operation with such value (e.g., max - min) do not overflow
+
+FLOAT32_SAFE_MIN = -1e37
+FLOAT32_MINPOS = numpy.finfo(numpy.float32).tiny
+FLOAT32_SAFE_MAX = 1e37
+# TODO double support
+
+
+def checkAxisLimits(vmin, vmax, isLog: bool=False, name: str=""):
+ """Makes sure axis range is not empty and within supported range.
+
+ :param float vmin: Min axis value
+ :param float vmax: Max axis value
+ :return: (min, max) making sure min < max
+ :rtype: 2-tuple of float
+ """
+ min_ = FLOAT32_MINPOS if isLog else FLOAT32_SAFE_MIN
+ vmax = numpy.clip(vmax, min_, FLOAT32_SAFE_MAX)
+ vmin = numpy.clip(vmin, min_, FLOAT32_SAFE_MAX)
+
+ if vmax < vmin:
+ _logger.debug('%s axis: max < min, inverting limits.', name)
+ vmin, vmax = vmax, vmin
+ elif vmax == vmin:
+ _logger.debug('%s axis: max == min, expanding limits.', name)
+ if vmin == 0.:
+ vmin, vmax = -0.1, 0.1
+ elif vmin < 0:
+ vmax *= 0.9
+ vmin = max(vmin * 1.1, FLOAT32_SAFE_MIN) # Clip to range
+ else: # vmin > 0
+ vmax = min(vmin * 1.1, FLOAT32_SAFE_MAX) # Clip to range
+ vmin *= 0.9
+
+ return vmin, vmax
+
+
+def scale1DRange(min_, max_, center, scale, isLog):
+ """Scale a 1D range given a scale factor and an center point.
+
+ Keeps the values in a smaller range than float32.
+
+ :param float min_: The current min value of the range.
+ :param float max_: The current max value of the range.
+ :param float center: The center of the zoom (i.e., invariant point).
+ :param float scale: The scale to use for zoom
+ :param bool isLog: Whether using log scale or not.
+ :return: The zoomed range.
+ :rtype: tuple of 2 floats: (min, max)
+ """
+ if isLog:
+ # Min and center can be < 0 when
+ # autoscale is off and switch to log scale
+ # max_ < 0 should not happen
+ min_ = numpy.log10(min_) if min_ > 0. else FLOAT32_MINPOS
+ center = numpy.log10(center) if center > 0. else FLOAT32_MINPOS
+ max_ = numpy.log10(max_) if max_ > 0. else FLOAT32_MINPOS
+
+ if min_ == max_:
+ return min_, max_
+
+ offset = (center - min_) / (max_ - min_)
+ range_ = (max_ - min_) / scale
+ newMin = center - offset * range_
+ newMax = center + (1. - offset) * range_
+
+ if isLog:
+ # No overflow as exponent is log10 of a float32
+ newMin = pow(10., newMin)
+ newMax = pow(10., newMax)
+ newMin = numpy.clip(newMin, FLOAT32_MINPOS, FLOAT32_SAFE_MAX)
+ newMax = numpy.clip(newMax, FLOAT32_MINPOS, FLOAT32_SAFE_MAX)
+ else:
+ newMin = numpy.clip(newMin, FLOAT32_SAFE_MIN, FLOAT32_SAFE_MAX)
+ newMax = numpy.clip(newMax, FLOAT32_SAFE_MIN, FLOAT32_SAFE_MAX)
+ return newMin, newMax
+
+
+def applyZoomToPlot(plot, scaleF, center=None):
+ """Zoom in/out plot given a scale and a center point.
+
+ :param plot: The plot on which to apply zoom.
+ :param float scaleF: Scale factor of zoom.
+ :param center: (x, y) coords in pixel coordinates of the zoom center.
+ :type center: 2-tuple of float
+ """
+ xMin, xMax = plot.getXAxis().getLimits()
+ yMin, yMax = plot.getYAxis().getLimits()
+
+ if center is None:
+ left, top, width, height = plot.getPlotBoundsInPixels()
+ cx, cy = left + width // 2, top + height // 2
+ else:
+ cx, cy = center
+
+ dataCenterPos = plot.pixelToData(cx, cy)
+ assert dataCenterPos is not None
+
+ xMin, xMax = scale1DRange(xMin, xMax, dataCenterPos[0], scaleF,
+ plot.getXAxis()._isLogarithmic())
+
+ yMin, yMax = scale1DRange(yMin, yMax, dataCenterPos[1], scaleF,
+ plot.getYAxis()._isLogarithmic())
+
+ dataPos = plot.pixelToData(cx, cy, axis="right")
+ assert dataPos is not None
+ y2Center = dataPos[1]
+ y2Min, y2Max = plot.getYAxis(axis="right").getLimits()
+ y2Min, y2Max = scale1DRange(y2Min, y2Max, y2Center, scaleF,
+ plot.getYAxis()._isLogarithmic())
+
+ plot.setLimits(xMin, xMax, yMin, yMax, y2Min, y2Max)
+
+
+def applyPan(min_, max_, panFactor, isLog10):
+ """Returns a new range with applied panning.
+
+ Moves the range according to panFactor.
+ If isLog10 is True, converts to log10 before moving.
+
+ :param float min_: Min value of the data range to pan.
+ :param float max_: Max value of the data range to pan.
+ Must be >= min.
+ :param float panFactor: Signed proportion of the range to use for pan.
+ :param bool isLog10: True if log10 scale, False if linear scale.
+ :return: New min and max value with pan applied.
+ :rtype: 2-tuple of float.
+ """
+ if isLog10 and min_ > 0.:
+ # Negative range and log scale can happen with matplotlib
+ logMin, logMax = math.log10(min_), math.log10(max_)
+ logOffset = panFactor * (logMax - logMin)
+ newMin = pow(10., logMin + logOffset)
+ newMax = pow(10., logMax + logOffset)
+
+ # Takes care of out-of-range values
+ if newMin > 0. and newMax < float('inf'):
+ min_, max_ = newMin, newMax
+
+ else:
+ offset = panFactor * (max_ - min_)
+ newMin, newMax = min_ + offset, max_ + offset
+
+ # Takes care of out-of-range values
+ if newMin > - float('inf') and newMax < float('inf'):
+ min_, max_ = newMin, newMax
+ return min_, max_
+
+
+class _Unset(object):
+ """To be able to have distinction between None and unset"""
+ pass
+
+
+class ViewConstraints(object):
+ """
+ Store constraints applied on the view box and compute the resulting view box.
+ """
+
+ def __init__(self):
+ self._min = [None, None]
+ self._max = [None, None]
+ self._minRange = [None, None]
+ self._maxRange = [None, None]
+
+ def update(self, xMin=_Unset, xMax=_Unset,
+ yMin=_Unset, yMax=_Unset,
+ minXRange=_Unset, maxXRange=_Unset,
+ minYRange=_Unset, maxYRange=_Unset):
+ """
+ Update the constraints managed by the object
+
+ The constraints are the same as the ones provided by PyQtGraph.
+
+ :param float xMin: Minimum allowed x-axis value.
+ (default do not change the stat, None remove the constraint)
+ :param float xMax: Maximum allowed x-axis value.
+ (default do not change the stat, None remove the constraint)
+ :param float yMin: Minimum allowed y-axis value.
+ (default do not change the stat, None remove the constraint)
+ :param float yMax: Maximum allowed y-axis value.
+ (default do not change the stat, None remove the constraint)
+ :param float minXRange: Minimum allowed left-to-right span across the
+ view (default do not change the stat, None remove the constraint)
+ :param float maxXRange: Maximum allowed left-to-right span across the
+ view (default do not change the stat, None remove the constraint)
+ :param float minYRange: Minimum allowed top-to-bottom span across the
+ view (default do not change the stat, None remove the constraint)
+ :param float maxYRange: Maximum allowed top-to-bottom span across the
+ view (default do not change the stat, None remove the constraint)
+ :return: True if the constraints was changed
+ """
+ updated = False
+
+ minRange = [minXRange, minYRange]
+ maxRange = [maxXRange, maxYRange]
+ minPos = [xMin, yMin]
+ maxPos = [xMax, yMax]
+
+ for axis in range(2):
+
+ value = minPos[axis]
+ if value is not _Unset and value != self._min[axis]:
+ self._min[axis] = value
+ updated = True
+
+ value = maxPos[axis]
+ if value is not _Unset and value != self._max[axis]:
+ self._max[axis] = value
+ updated = True
+
+ value = minRange[axis]
+ if value is not _Unset and value != self._minRange[axis]:
+ self._minRange[axis] = value
+ updated = True
+
+ value = maxRange[axis]
+ if value is not _Unset and value != self._maxRange[axis]:
+ self._maxRange[axis] = value
+ updated = True
+
+ # Sanity checks
+
+ for axis in range(2):
+ if self._maxRange[axis] is not None and self._min[axis] is not None and self._max[axis] is not None:
+ # max range cannot be larger than bounds
+ diff = self._max[axis] - self._min[axis]
+ self._maxRange[axis] = min(self._maxRange[axis], diff)
+ updated = True
+
+ return updated
+
+ def normalize(self, xMin, xMax, yMin, yMax, allow_scaling=True):
+ """Normalize a view range defined by x and y corners using predefined
+ containts.
+
+ :param float xMin: Min position of the x-axis
+ :param float xMax: Max position of the x-axis
+ :param float yMin: Min position of the y-axis
+ :param float yMax: Max position of the y-axis
+ :param bool allow_scaling: Allow or not to apply scaling for the
+ normalization. Used according to the interaction mode.
+ :return: A normalized tuple of (xMin, xMax, yMin, yMax)
+ """
+ viewRange = [[xMin, xMax], [yMin, yMax]]
+
+ for axis in range(2):
+ # clamp xRange and yRange
+ if allow_scaling:
+ diff = viewRange[axis][1] - viewRange[axis][0]
+ delta = None
+ if self._maxRange[axis] is not None and diff > self._maxRange[axis]:
+ delta = self._maxRange[axis] - diff
+ elif self._minRange[axis] is not None and diff < self._minRange[axis]:
+ delta = self._minRange[axis] - diff
+ if delta is not None:
+ viewRange[axis][0] -= delta * 0.5
+ viewRange[axis][1] += delta * 0.5
+
+ # clamp min and max positions
+ outMin = self._min[axis] is not None and viewRange[axis][0] < self._min[axis]
+ outMax = self._max[axis] is not None and viewRange[axis][1] > self._max[axis]
+
+ if outMin and outMax:
+ if allow_scaling:
+ # we can clamp both sides
+ viewRange[axis][0] = self._min[axis]
+ viewRange[axis][1] = self._max[axis]
+ else:
+ # center the result
+ delta = viewRange[axis][1] - viewRange[axis][0]
+ mid = self._min[axis] + self._max[axis] - self._min[axis]
+ viewRange[axis][0] = mid - delta
+ viewRange[axis][1] = mid + delta
+ elif outMin:
+ delta = self._min[axis] - viewRange[axis][0]
+ viewRange[axis][0] += delta
+ viewRange[axis][1] += delta
+ elif outMax:
+ delta = self._max[axis] - viewRange[axis][1]
+ viewRange[axis][0] += delta
+ viewRange[axis][1] += delta
+
+ return viewRange[0][0], viewRange[0][1], viewRange[1][0], viewRange[1][1]
diff --git a/src/silx/gui/plot/_utils/setup.py b/src/silx/gui/plot/_utils/setup.py
new file mode 100644
index 0000000..0271745
--- /dev/null
+++ b/src/silx/gui/plot/_utils/setup.py
@@ -0,0 +1,42 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "21/03/2017"
+
+
+from numpy.distutils.misc_util import Configuration
+
+
+def configuration(parent_package='', top_path=None):
+ config = Configuration('_utils', parent_package, top_path)
+ config.add_subpackage('test')
+ return config
+
+
+if __name__ == "__main__":
+ from numpy.distutils.core import setup
+
+ setup(configuration=configuration)
diff --git a/src/silx/gui/plot/_utils/test/__init__.py b/src/silx/gui/plot/_utils/test/__init__.py
new file mode 100644
index 0000000..3ad225d
--- /dev/null
+++ b/src/silx/gui/plot/_utils/test/__init__.py
@@ -0,0 +1,24 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
diff --git a/src/silx/gui/plot/_utils/test/test_dtime_ticklayout.py b/src/silx/gui/plot/_utils/test/test_dtime_ticklayout.py
new file mode 100644
index 0000000..8d35acf
--- /dev/null
+++ b/src/silx/gui/plot/_utils/test/test_dtime_ticklayout.py
@@ -0,0 +1,79 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+
+from __future__ import absolute_import, division, unicode_literals
+
+__authors__ = ["P. Kenter"]
+__license__ = "MIT"
+__date__ = "06/04/2018"
+
+
+import datetime as dt
+import unittest
+
+
+from silx.gui.plot._utils.dtime_ticklayout import (
+ calcTicks, DtUnit, SECONDS_PER_YEAR)
+
+
+class TestTickLayout(unittest.TestCase):
+ """Test ticks layout algorithms"""
+
+ def testSmallMonthlySpacing(self):
+ """ Tests a range that did result in a spacing of less than 1 month.
+ It is impossible to add fractional month so the unit must be in days
+ """
+ from dateutil import parser
+ d1 = parser.parse("2017-01-03 13:15:06.000044")
+ d2 = parser.parse("2017-03-08 09:16:16.307584")
+ _ticks, _units, spacing = calcTicks(d1, d2, nTicks=4)
+
+ self.assertEqual(spacing, DtUnit.DAYS)
+
+
+ def testNoCrash(self):
+ """ Creates many combinations of and number-of-ticks and end-dates;
+ tests that it doesn't give an exception and returns a reasonable number
+ of ticks.
+ """
+ d1 = dt.datetime(2017, 1, 3, 13, 15, 6, 44)
+
+ value = 100e-6 # Start at 100 micro sec range.
+
+ while value <= 200 * SECONDS_PER_YEAR:
+
+ d2 = d1 + dt.timedelta(microseconds=value*1e6) # end date range
+
+ for numTicks in range(2, 12):
+ ticks, _, _ = calcTicks(d1, d2, numTicks)
+
+ margin = 2.5
+ self.assertTrue(
+ numTicks/margin <= len(ticks) <= numTicks*margin,
+ "Condition {} <= {} <= {} failed for # ticks={} and d2={}:"
+ .format(numTicks/margin, len(ticks), numTicks * margin,
+ numTicks, d2))
+
+ value = value * 1.5 # let date period grow exponentially
diff --git a/src/silx/gui/plot/_utils/test/test_ticklayout.py b/src/silx/gui/plot/_utils/test/test_ticklayout.py
new file mode 100644
index 0000000..884b71b
--- /dev/null
+++ b/src/silx/gui/plot/_utils/test/test_ticklayout.py
@@ -0,0 +1,81 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+
+from __future__ import absolute_import, division, unicode_literals
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "17/01/2018"
+
+
+import unittest
+import numpy
+
+from silx.utils.testutils import ParametricTestCase
+
+from silx.gui.plot._utils import ticklayout
+
+
+class TestTickLayout(ParametricTestCase):
+ """Test ticks layout algorithms"""
+
+ def testTicks(self):
+ """Test of :func:`ticks`"""
+ tests = { # (vmin, vmax): ref_ticks
+ (1., 1.): (1.,),
+ (0.5, 10.5): (2.0, 4.0, 6.0, 8.0, 10.0),
+ (0.001, 0.005): (0.001, 0.002, 0.003, 0.004, 0.005)
+ }
+
+ for (vmin, vmax), ref_ticks in tests.items():
+ with self.subTest(vmin=vmin, vmax=vmax):
+ ticks, labels = ticklayout.ticks(vmin, vmax)
+ self.assertTrue(numpy.allclose(ticks, ref_ticks))
+
+ def testNiceNumbers(self):
+ """Minimalistic tests of :func:`niceNumbers`"""
+ tests = { # (vmin, vmax): ref_ticks
+ (0.5, 10.5): (0.0, 12.0, 2.0, 0),
+ (10000., 10000.5): (10000.0, 10000.5, 0.1, 1),
+ (0.001, 0.005): (0.001, 0.005, 0.001, 3)
+ }
+
+ for (vmin, vmax), ref_ticks in tests.items():
+ with self.subTest(vmin=vmin, vmax=vmax):
+ ticks = ticklayout.niceNumbers(vmin, vmax)
+ self.assertEqual(ticks, ref_ticks)
+
+ def testNiceNumbersLog(self):
+ """Minimalistic tests of :func:`niceNumbersForLog10`"""
+ tests = { # (log10(min), log10(max): ref_ticks
+ (0., 3.): (0, 3, 1, 0),
+ (-3., 3): (-3, 3, 1, 0),
+ (-32., 0.): (-36, 0, 6, 0)
+ }
+
+ for (vmin, vmax), ref_ticks in tests.items():
+ with self.subTest(vmin=vmin, vmax=vmax):
+ ticks = ticklayout.niceNumbersForLog10(vmin, vmax)
+ self.assertEqual(ticks, ref_ticks)
diff --git a/src/silx/gui/plot/_utils/ticklayout.py b/src/silx/gui/plot/_utils/ticklayout.py
new file mode 100644
index 0000000..c9fd3e6
--- /dev/null
+++ b/src/silx/gui/plot/_utils/ticklayout.py
@@ -0,0 +1,267 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2014-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module implements labels layout on graph axes."""
+
+from __future__ import absolute_import, division, unicode_literals
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "18/10/2016"
+
+
+import math
+
+
+# utils #######################################################################
+
+def numberOfDigits(tickSpacing):
+ """Returns the number of digits to display for text label.
+
+ :param float tickSpacing: Step between ticks in data space.
+ :return: Number of digits to show for labels.
+ :rtype: int
+ """
+ nfrac = int(-math.floor(math.log10(tickSpacing)))
+ if nfrac < 0:
+ nfrac = 0
+ return nfrac
+
+
+# Nice Numbers ################################################################
+
+# This is the original niceNum implementation. For the date time ticks a more
+# generic implementation was needed.
+#
+# def _niceNum(value, isRound=False):
+# expvalue = math.floor(math.log10(value))
+# frac = value/pow(10., expvalue)
+# if isRound:
+# if frac < 1.5:
+# nicefrac = 1.
+# elif frac < 3.: # In niceNumGeneric this is (2+5)/2 = 3.5
+# nicefrac = 2.
+# elif frac < 7.:
+# nicefrac = 5. # In niceNumGeneric this is (5+10)/2 = 7.5
+# else:
+# nicefrac = 10.
+# else:
+# if frac <= 1.:
+# nicefrac = 1.
+# elif frac <= 2.:
+# nicefrac = 2.
+# elif frac <= 5.:
+# nicefrac = 5.
+# else:
+# nicefrac = 10.
+# return nicefrac * pow(10., expvalue)
+
+
+def niceNumGeneric(value, niceFractions=None, isRound=False):
+ """ A more generic implementation of the _niceNum function
+
+ Allows the user to specify the fractions instead of using a hardcoded
+ list of [1, 2, 5, 10.0].
+ """
+ if value == 0:
+ return value
+
+ if niceFractions is None: # Use default values
+ niceFractions = 1., 2., 5., 10.
+ roundFractions = (1.5, 3., 7., 10.) if isRound else niceFractions
+
+ else:
+ roundFractions = list(niceFractions)
+ if isRound:
+ # Take the average with the next element. The last remains the same.
+ for i in range(len(roundFractions) - 1):
+ roundFractions[i] = (niceFractions[i] + niceFractions[i+1]) / 2
+
+ highest = niceFractions[-1]
+ value = float(value)
+
+ expvalue = math.floor(math.log(value, highest))
+ frac = value / pow(highest, expvalue)
+
+ for niceFrac, roundFrac in zip(niceFractions, roundFractions):
+ if frac <= roundFrac:
+ return niceFrac * pow(highest, expvalue)
+
+ # should not come here
+ assert False, "should not come here"
+
+
+def niceNumbers(vMin, vMax, nTicks=5):
+ """Returns tick positions.
+
+ This function implements graph labels layout using nice numbers
+ by Paul Heckbert from "Graphics Gems", Academic Press, 1990.
+ See `C code <http://tog.acm.org/resources/GraphicsGems/gems/Label.c>`_.
+
+ :param float vMin: The min value on the axis
+ :param float vMax: The max value on the axis
+ :param int nTicks: The number of ticks to position
+ :returns: min, max, increment value of tick positions and
+ number of fractional digit to show
+ :rtype: tuple
+ """
+ vrange = niceNumGeneric(vMax - vMin, isRound=False)
+ spacing = niceNumGeneric(vrange / nTicks, isRound=True)
+ graphmin = math.floor(vMin / spacing) * spacing
+ graphmax = math.ceil(vMax / spacing) * spacing
+ nfrac = numberOfDigits(spacing)
+ return graphmin, graphmax, spacing, nfrac
+
+
+def _frange(start, stop, step):
+ """range for float (including stop)."""
+ assert step >= 0.
+ while start <= stop:
+ yield start
+ start += step
+
+
+def ticks(vMin, vMax, nbTicks=5):
+ """Returns tick positions and labels using nice numbers algorithm.
+
+ This enforces ticks to be within [vMin, vMax] range.
+ It returns at least 1 tick (when vMin == vMax).
+
+ :param float vMin: The min value on the axis
+ :param float vMax: The max value on the axis
+ :param int nbTicks: The number of ticks to position
+ :returns: tick positions and corresponding text labels
+ :rtype: 2-tuple: list of float, list of string
+ """
+ assert vMin <= vMax
+ if vMin == vMax:
+ positions = [vMin]
+ nfrac = 0
+
+ else:
+ start, end, step, nfrac = niceNumbers(vMin, vMax, nbTicks)
+ positions = [t for t in _frange(start, end, step) if vMin <= t <= vMax]
+
+ # Makes sure there is at least 2 ticks
+ if len(positions) < 2:
+ positions = [vMin, vMax]
+ nfrac = numberOfDigits(vMax - vMin)
+
+ # Generate labels
+ format_ = '%g' if nfrac == 0 else '%.{}f'.format(nfrac)
+ labels = [format_ % tick for tick in positions]
+ return positions, labels
+
+
+def niceNumbersAdaptative(vMin, vMax, axisLength, tickDensity):
+ """Returns tick positions using :func:`niceNumbers` and a
+ density of ticks.
+
+ axisLength and tickDensity are based on the same unit (e.g., pixel).
+
+ :param float vMin: The min value on the axis
+ :param float vMax: The max value on the axis
+ :param float axisLength: The length of the axis.
+ :param float tickDensity: The density of ticks along the axis.
+ :returns: min, max, increment value of tick positions and
+ number of fractional digit to show
+ :rtype: tuple
+ """
+ # At least 2 ticks
+ nticks = max(2, int(round(tickDensity * axisLength)))
+ tickmin, tickmax, step, nfrac = niceNumbers(vMin, vMax, nticks)
+
+ return tickmin, tickmax, step, nfrac
+
+
+# Nice Numbers for log scale ##################################################
+
+def niceNumbersForLog10(minLog, maxLog, nTicks=5):
+ """Return tick positions for logarithmic scale
+
+ :param float minLog: log10 of the min value on the axis
+ :param float maxLog: log10 of the max value on the axis
+ :param int nTicks: The number of ticks to position
+ :returns: log10 of min, max, increment value of tick positions and
+ number of fractional digit to show
+ :rtype: tuple of int
+ """
+ graphminlog = math.floor(minLog)
+ graphmaxlog = math.ceil(maxLog)
+ rangelog = graphmaxlog - graphminlog
+
+ if rangelog <= nTicks:
+ spacing = 1.
+ else:
+ spacing = math.floor(rangelog / nTicks)
+
+ graphminlog = math.floor(graphminlog / spacing) * spacing
+ graphmaxlog = math.ceil(graphmaxlog / spacing) * spacing
+
+ nfrac = numberOfDigits(spacing)
+
+ return int(graphminlog), int(graphmaxlog), int(spacing), nfrac
+
+
+def niceNumbersAdaptativeForLog10(vMin, vMax, axisLength, tickDensity):
+ """Returns tick positions using :func:`niceNumbers` and a
+ density of ticks.
+
+ axisLength and tickDensity are based on the same unit (e.g., pixel).
+
+ :param float vMin: The min value on the axis
+ :param float vMax: The max value on the axis
+ :param float axisLength: The length of the axis.
+ :param float tickDensity: The density of ticks along the axis.
+ :returns: log10 of min, max, increment value of tick positions and
+ number of fractional digit to show
+ :rtype: tuple
+ """
+ # At least 2 ticks
+ nticks = max(2, int(round(tickDensity * axisLength)))
+ tickmin, tickmax, step, nfrac = niceNumbersForLog10(vMin, vMax, nticks)
+
+ return tickmin, tickmax, step, nfrac
+
+
+def computeLogSubTicks(ticks, lowBound, highBound):
+ """Return the sub ticks for the log scale for all given ticks if subtick
+ is in [lowBound, highBound]
+
+ :param ticks: log10 of the ticks
+ :param lowBound: the lower boundary of ticks
+ :param highBound: the higher boundary of ticks
+ :return: all the sub ticks contained in ticks (log10)
+ """
+ if len(ticks) < 1:
+ return []
+
+ res = []
+ for logPos in ticks:
+ dataOrigPos = logPos
+ for index in range(2, 10):
+ dataPos = dataOrigPos * index
+ if lowBound <= dataPos <= highBound:
+ res.append(dataPos)
+ return res
diff --git a/src/silx/gui/plot/actions/PlotAction.py b/src/silx/gui/plot/actions/PlotAction.py
new file mode 100644
index 0000000..2983775
--- /dev/null
+++ b/src/silx/gui/plot/actions/PlotAction.py
@@ -0,0 +1,78 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+The class :class:`.PlotAction` help the creation of a qt.QAction associated
+with a :class:`.PlotWidget`.
+"""
+
+from __future__ import division
+
+
+__authors__ = ["V.A. Sole", "T. Vincent", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "03/01/2018"
+
+
+import weakref
+from silx.gui import icons
+from silx.gui import qt
+
+
+class PlotAction(qt.QAction):
+ """Base class for QAction that operates on a PlotWidget.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate.
+ :param icon: QIcon or str name of icon to use
+ :param str text: The name of this action to be used for menu label
+ :param str tooltip: The text of the tooltip
+ :param triggered: The callback to connect to the action's triggered
+ signal or None for no callback.
+ :param bool checkable: True for checkable action, False otherwise (default)
+ :param parent: See :class:`QAction`.
+ """
+
+ def __init__(self, plot, icon, text, tooltip=None,
+ triggered=None, checkable=False, parent=None):
+ assert plot is not None
+ self._plotRef = weakref.ref(plot)
+
+ if not isinstance(icon, qt.QIcon):
+ # Try with icon as a string and load corresponding icon
+ icon = icons.getQIcon(icon)
+
+ super(PlotAction, self).__init__(icon, text, parent)
+
+ if tooltip is not None:
+ self.setToolTip(tooltip)
+
+ self.setCheckable(checkable)
+
+ if triggered is not None:
+ self.triggered[bool].connect(triggered)
+
+ @property
+ def plot(self):
+ """The :class:`.PlotWidget` this action group is controlling."""
+ return self._plotRef()
diff --git a/src/silx/gui/plot/actions/PlotToolAction.py b/src/silx/gui/plot/actions/PlotToolAction.py
new file mode 100644
index 0000000..fbb0b0f
--- /dev/null
+++ b/src/silx/gui/plot/actions/PlotToolAction.py
@@ -0,0 +1,150 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+The class :class:`.PlotToolAction` help the creation of a qt.QAction associating
+a tool window with a :class:`.PlotWidget`.
+"""
+
+from __future__ import division
+
+
+__authors__ = ["V.A. Sole", "T. Vincent", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "10/10/2018"
+
+
+import weakref
+
+from .PlotAction import PlotAction
+from silx.gui import qt
+
+
+class PlotToolAction(PlotAction):
+ """Base class for QAction that maintain a tool window operating on a
+ PlotWidget."""
+
+ def __init__(self, plot, icon, text, tooltip=None,
+ triggered=None, checkable=False, parent=None):
+ PlotAction.__init__(self,
+ plot=plot,
+ icon=icon,
+ text=text,
+ tooltip=tooltip,
+ triggered=self._triggered,
+ parent=parent,
+ checkable=True)
+ self._previousGeometry = None
+ self._toolWindow = None
+
+ def _triggered(self, checked):
+ """Update the plot of the histogram visibility status
+
+ :param bool checked: status of the action button
+ """
+ self._setToolWindowVisible(checked)
+
+ def _setToolWindowVisible(self, visible):
+ """Set the tool window visible or hidden."""
+ tool = self._getToolWindow()
+ if tool.isVisible() == visible:
+ # Nothing to do
+ return
+
+ if visible:
+ self._connectPlot(tool)
+ tool.show()
+ if self._previousGeometry is not None:
+ # Restore the geometry
+ tool.setGeometry(self._previousGeometry)
+ else:
+ self._disconnectPlot(tool)
+ # Save the geometry
+ self._previousGeometry = tool.geometry()
+ tool.hide()
+
+ def _connectPlot(self, window):
+ """Called if the tool is visible and have to be updated according to
+ event of the plot.
+
+ :param qt.QWidget window: The tool window
+ """
+ pass
+
+ def _disconnectPlot(self, window):
+ """Called if the tool is not visible and dont have anymore to be updated
+ according to event of the plot.
+
+ :param qt.QWidget window: The tool window
+ """
+ pass
+
+ def _isWindowInUse(self):
+ """Returns true if the tool window is currently in use."""
+ if not self.isChecked():
+ return False
+ return self._toolWindow is not None
+
+ def _ownerVisibilityChanged(self, isVisible):
+ """Called when the visibility of the parent of the tool window changes
+
+ :param bool isVisible: True if the parent became visible
+ """
+ if self._isWindowInUse():
+ self._setToolWindowVisible(isVisible)
+
+ def eventFilter(self, qobject, event):
+ """Observe when the close event is emitted then
+ simply uncheck the action button
+
+ :param qobject: the object observe
+ :param event: the event received by qobject
+ """
+ if event.type() == qt.QEvent.Close:
+ if self._toolWindow is not None:
+ window = self._toolWindow()
+ self._previousGeometry = window.geometry()
+ window.hide()
+ self.setChecked(False)
+
+ return PlotAction.eventFilter(self, qobject, event)
+
+ def _getToolWindow(self):
+ """Returns the window containing the tool.
+
+ It uses lazy loading to create this tool..
+ """
+ if self._toolWindow is None:
+ window = self._createToolWindow()
+ if self._previousGeometry is not None:
+ window.setGeometry(self._previousGeometry)
+ window.installEventFilter(self)
+ plot = self.plot
+ plot.sigVisibilityChanged.connect(self._ownerVisibilityChanged)
+ self._toolWindow = weakref.ref(window)
+ return self._toolWindow()
+
+ def _createToolWindow(self):
+ """Create the tool window managing the plot."""
+ raise NotImplementedError()
diff --git a/src/silx/gui/plot/actions/__init__.py b/src/silx/gui/plot/actions/__init__.py
new file mode 100644
index 0000000..930c728
--- /dev/null
+++ b/src/silx/gui/plot/actions/__init__.py
@@ -0,0 +1,42 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This package provides a set of QAction to use with
+:class:`~silx.gui.plot.PlotWidget`
+
+Those actions are useful to add menu items or toolbar items
+that interact with a :class:`~silx.gui.plot.PlotWidget`.
+
+It provides a base class used to define new plot actions:
+:class:`~silx.gui.plot.actions.PlotAction`.
+"""
+
+__authors__ = ["H. Payno"]
+__license__ = "MIT"
+__date__ = "16/08/2017"
+
+from .PlotAction import PlotAction
+from . import control
+from . import mode
+from . import io
diff --git a/src/silx/gui/plot/actions/control.py b/src/silx/gui/plot/actions/control.py
new file mode 100755
index 0000000..439985e
--- /dev/null
+++ b/src/silx/gui/plot/actions/control.py
@@ -0,0 +1,694 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+:mod:`silx.gui.plot.actions.control` provides a set of QAction relative to control
+of a :class:`.PlotWidget`.
+
+The following QAction are available:
+
+- :class:`ColormapAction`
+- :class:`CrosshairAction`
+- :class:`CurveStyleAction`
+- :class:`GridAction`
+- :class:`KeepAspectRatioAction`
+- :class:`PanWithArrowKeysAction`
+- :class:`ResetZoomAction`
+- :class:`ShowAxisAction`
+- :class:`XAxisLogarithmicAction`
+- :class:`XAxisAutoScaleAction`
+- :class:`YAxisInvertedAction`
+- :class:`YAxisLogarithmicAction`
+- :class:`YAxisAutoScaleAction`
+- :class:`ZoomBackAction`
+- :class:`ZoomInAction`
+- :class:`ZoomOutAction`
+"""
+
+from __future__ import division
+
+__authors__ = ["V.A. Sole", "T. Vincent", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "27/11/2020"
+
+from . import PlotAction
+import logging
+from silx.gui.plot import items
+from silx.gui.plot._utils import applyZoomToPlot as _applyZoomToPlot
+from silx.gui import qt
+from silx.gui import icons
+
+_logger = logging.getLogger(__name__)
+
+
+class ResetZoomAction(PlotAction):
+ """QAction controlling reset zoom on a :class:`.PlotWidget`.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ super(ResetZoomAction, self).__init__(
+ plot, icon='zoom-original', text='Reset Zoom',
+ tooltip='Auto-scale the graph',
+ triggered=self._actionTriggered,
+ checkable=False, parent=parent)
+ self._autoscaleChanged(True)
+ plot.getXAxis().sigAutoScaleChanged.connect(self._autoscaleChanged)
+ plot.getYAxis().sigAutoScaleChanged.connect(self._autoscaleChanged)
+
+ def _autoscaleChanged(self, enabled):
+ xAxis = self.plot.getXAxis()
+ yAxis = self.plot.getYAxis()
+ self.setEnabled(xAxis.isAutoScale() or yAxis.isAutoScale())
+
+ if xAxis.isAutoScale() and yAxis.isAutoScale():
+ tooltip = 'Auto-scale the graph'
+ elif xAxis.isAutoScale(): # And not Y axis
+ tooltip = 'Auto-scale the x-axis of the graph only'
+ elif yAxis.isAutoScale(): # And not X axis
+ tooltip = 'Auto-scale the y-axis of the graph only'
+ else: # no axis in autoscale
+ tooltip = 'Auto-scale the graph'
+ self.setToolTip(tooltip)
+
+ def _actionTriggered(self, checked=False):
+ self.plot.resetZoom()
+
+
+class ZoomBackAction(PlotAction):
+ """QAction performing a zoom-back in :class:`.PlotWidget` limits history.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ super(ZoomBackAction, self).__init__(
+ plot, icon='zoom-back', text='Zoom Back',
+ tooltip='Zoom back the plot',
+ triggered=self._actionTriggered,
+ checkable=False, parent=parent)
+ self.setShortcutContext(qt.Qt.WidgetShortcut)
+
+ def _actionTriggered(self, checked=False):
+ self.plot.getLimitsHistory().pop()
+
+
+class ZoomInAction(PlotAction):
+ """QAction performing a zoom-in on a :class:`.PlotWidget`.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ super(ZoomInAction, self).__init__(
+ plot, icon='zoom-in', text='Zoom In',
+ tooltip='Zoom in the plot',
+ triggered=self._actionTriggered,
+ checkable=False, parent=parent)
+ self.setShortcut(qt.QKeySequence.ZoomIn)
+ self.setShortcutContext(qt.Qt.WidgetShortcut)
+
+ def _actionTriggered(self, checked=False):
+ _applyZoomToPlot(self.plot, 1.1)
+
+
+class ZoomOutAction(PlotAction):
+ """QAction performing a zoom-out on a :class:`.PlotWidget`.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ super(ZoomOutAction, self).__init__(
+ plot, icon='zoom-out', text='Zoom Out',
+ tooltip='Zoom out the plot',
+ triggered=self._actionTriggered,
+ checkable=False, parent=parent)
+ self.setShortcut(qt.QKeySequence.ZoomOut)
+ self.setShortcutContext(qt.Qt.WidgetShortcut)
+
+ def _actionTriggered(self, checked=False):
+ _applyZoomToPlot(self.plot, 1. / 1.1)
+
+
+class XAxisAutoScaleAction(PlotAction):
+ """QAction controlling X axis autoscale on a :class:`.PlotWidget`.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ super(XAxisAutoScaleAction, self).__init__(
+ plot, icon='plot-xauto', text='X Autoscale',
+ tooltip='Enable x-axis auto-scale when checked.\n'
+ 'If unchecked, x-axis does not change when reseting zoom.',
+ triggered=self._actionTriggered,
+ checkable=True, parent=parent)
+ self.setChecked(plot.getXAxis().isAutoScale())
+ plot.getXAxis().sigAutoScaleChanged.connect(self.setChecked)
+
+ def _actionTriggered(self, checked=False):
+ self.plot.getXAxis().setAutoScale(checked)
+ if checked:
+ self.plot.resetZoom()
+
+
+class YAxisAutoScaleAction(PlotAction):
+ """QAction controlling Y axis autoscale on a :class:`.PlotWidget`.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ super(YAxisAutoScaleAction, self).__init__(
+ plot, icon='plot-yauto', text='Y Autoscale',
+ tooltip='Enable y-axis auto-scale when checked.\n'
+ 'If unchecked, y-axis does not change when reseting zoom.',
+ triggered=self._actionTriggered,
+ checkable=True, parent=parent)
+ self.setChecked(plot.getYAxis().isAutoScale())
+ plot.getYAxis().sigAutoScaleChanged.connect(self.setChecked)
+
+ def _actionTriggered(self, checked=False):
+ self.plot.getYAxis().setAutoScale(checked)
+ if checked:
+ self.plot.resetZoom()
+
+
+class XAxisLogarithmicAction(PlotAction):
+ """QAction controlling X axis log scale on a :class:`.PlotWidget`.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ super(XAxisLogarithmicAction, self).__init__(
+ plot, icon='plot-xlog', text='X Log. scale',
+ tooltip='Logarithmic x-axis when checked',
+ triggered=self._actionTriggered,
+ checkable=True, parent=parent)
+ self.axis = plot.getXAxis()
+ self.setChecked(self.axis.getScale() == self.axis.LOGARITHMIC)
+ self.axis.sigScaleChanged.connect(self._setCheckedIfLogScale)
+
+ def _setCheckedIfLogScale(self, scale):
+ self.setChecked(scale == self.axis.LOGARITHMIC)
+
+ def _actionTriggered(self, checked=False):
+ scale = self.axis.LOGARITHMIC if checked else self.axis.LINEAR
+ self.axis.setScale(scale)
+
+
+class YAxisLogarithmicAction(PlotAction):
+ """QAction controlling Y axis log scale on a :class:`.PlotWidget`.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ super(YAxisLogarithmicAction, self).__init__(
+ plot, icon='plot-ylog', text='Y Log. scale',
+ tooltip='Logarithmic y-axis when checked',
+ triggered=self._actionTriggered,
+ checkable=True, parent=parent)
+ self.axis = plot.getYAxis()
+ self.setChecked(self.axis.getScale() == self.axis.LOGARITHMIC)
+ self.axis.sigScaleChanged.connect(self._setCheckedIfLogScale)
+
+ def _setCheckedIfLogScale(self, scale):
+ self.setChecked(scale == self.axis.LOGARITHMIC)
+
+ def _actionTriggered(self, checked=False):
+ scale = self.axis.LOGARITHMIC if checked else self.axis.LINEAR
+ self.axis.setScale(scale)
+
+
+class GridAction(PlotAction):
+ """QAction controlling grid mode on a :class:`.PlotWidget`.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param str gridMode: The grid mode to use in 'both', 'major'.
+ See :meth:`.PlotWidget.setGraphGrid`
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, gridMode='both', parent=None):
+ assert gridMode in ('both', 'major')
+ self._gridMode = gridMode
+
+ super(GridAction, self).__init__(
+ plot, icon='plot-grid', text='Grid',
+ tooltip='Toggle grid (on/off)',
+ triggered=self._actionTriggered,
+ checkable=True, parent=parent)
+ self.setChecked(plot.getGraphGrid() is not None)
+ plot.sigSetGraphGrid.connect(self._gridChanged)
+
+ def _gridChanged(self, which):
+ """Slot listening for PlotWidget grid mode change."""
+ self.setChecked(which != 'None')
+
+ def _actionTriggered(self, checked=False):
+ self.plot.setGraphGrid(self._gridMode if checked else None)
+
+
+class CurveStyleAction(PlotAction):
+ """QAction controlling curve style on a :class:`.PlotWidget`.
+
+ It changes the default line and markers style which updates all
+ curves on the plot.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ super(CurveStyleAction, self).__init__(
+ plot, icon='plot-toggle-points', text='Curve style',
+ tooltip='Change curve line and markers style',
+ triggered=self._actionTriggered,
+ checkable=False, parent=parent)
+
+ def _actionTriggered(self, checked=False):
+ currentState = (self.plot.isDefaultPlotLines(),
+ self.plot.isDefaultPlotPoints())
+
+ if currentState == (False, False):
+ newState = True, False
+ else:
+ # line only, line and symbol, symbol only
+ states = (True, False), (True, True), (False, True)
+ newState = states[(states.index(currentState) + 1) % 3]
+
+ self.plot.setDefaultPlotLines(newState[0])
+ self.plot.setDefaultPlotPoints(newState[1])
+
+
+class ColormapAction(PlotAction):
+ """QAction opening a ColormapDialog to update the colormap.
+
+ Both the active image colormap and the default colormap are updated.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ self._dialog = None # To store an instance of ColormapDialog
+ super(ColormapAction, self).__init__(
+ plot, icon='colormap', text='Colormap',
+ tooltip="Change colormap",
+ triggered=self._actionTriggered,
+ checkable=True, parent=parent)
+ self.plot.sigActiveImageChanged.connect(self._updateColormap)
+ self.plot.sigActiveScatterChanged.connect(self._updateColormap)
+
+ def setColorDialog(self, colorDialog):
+ """Set a specific color dialog instead of using the default dialog."""
+ assert(colorDialog is not None)
+ assert(self._dialog is None)
+ self._dialog = colorDialog
+ self._dialog.visibleChanged.connect(self._dialogVisibleChanged)
+ self.setChecked(self._dialog.isVisible())
+
+ @staticmethod
+ def _createDialog(parent):
+ """Create the dialog if not already existing
+
+ :parent QWidget parent: Parent of the new colormap
+ :rtype: ColormapDialog
+ """
+ from silx.gui.dialog.ColormapDialog import ColormapDialog
+ dialog = ColormapDialog(parent=parent)
+ dialog.setModal(False)
+ return dialog
+
+ def _actionTriggered(self, checked=False):
+ """Create a cmap dialog and update active image and default cmap."""
+ if self._dialog is None:
+ self._dialog = self._createDialog(self.plot)
+ self._dialog.visibleChanged.connect(self._dialogVisibleChanged)
+
+ # Run the dialog listening to colormap change
+ if checked is True:
+ self._updateColormap()
+ self._dialog.show()
+ else:
+ self._dialog.hide()
+
+ def _dialogVisibleChanged(self, isVisible):
+ self.setChecked(isVisible)
+
+ def _updateColormap(self):
+ if self._dialog is None:
+ return
+ image = self.plot.getActiveImage()
+
+ if isinstance(image, items.ColormapMixIn):
+ # Set dialog from active image
+ colormap = image.getColormap()
+ # Set histogram and range if any
+ self._dialog.setItem(image)
+
+ else:
+ # No active image or active image is RGBA,
+ # Check for active scatter plot
+ scatter = self.plot._getActiveItem(kind='scatter')
+ if scatter is not None:
+ colormap = scatter.getColormap()
+ self._dialog.setItem(scatter)
+
+ else:
+ # No active data image nor scatter,
+ # set dialog from default info
+ colormap = self.plot.getDefaultColormap()
+ # Reset histogram and range if any
+ self._dialog.setData(None)
+
+ self._dialog.setColormap(colormap)
+
+
+class ColorBarAction(PlotAction):
+ """QAction opening the ColorBarWidget of the specified plot.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ self._dialog = None # To store an instance of ColorBar
+ super(ColorBarAction, self).__init__(
+ plot, icon='colorbar', text='Colorbar',
+ tooltip="Show/Hide the colorbar",
+ triggered=self._actionTriggered,
+ checkable=True, parent=parent)
+ colorBarWidget = self.plot.getColorBarWidget()
+ old = self.blockSignals(True)
+ self.setChecked(colorBarWidget.isVisibleTo(self.plot))
+ self.blockSignals(old)
+ colorBarWidget.sigVisibleChanged.connect(self._widgetVisibleChanged)
+
+ def _widgetVisibleChanged(self, isVisible):
+ """Callback when the colorbar `visible` property change."""
+ if self.isChecked() == isVisible:
+ return
+ self.setChecked(isVisible)
+
+ def _actionTriggered(self, checked=False):
+ """Create a cmap dialog and update active image and default cmap."""
+ colorBarWidget = self.plot.getColorBarWidget()
+ if not colorBarWidget.isHidden() == checked:
+ return
+ self.plot.getColorBarWidget().setVisible(checked)
+
+
+class KeepAspectRatioAction(PlotAction):
+ """QAction controlling aspect ratio on a :class:`.PlotWidget`.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ # Uses two images for checked/unchecked states
+ self._states = {
+ False: (icons.getQIcon('shape-circle-solid'),
+ "Keep data aspect ratio"),
+ True: (icons.getQIcon('shape-ellipse-solid'),
+ "Do no keep data aspect ratio")
+ }
+
+ icon, tooltip = self._states[plot.isKeepDataAspectRatio()]
+ super(KeepAspectRatioAction, self).__init__(
+ plot,
+ icon=icon,
+ text='Toggle keep aspect ratio',
+ tooltip=tooltip,
+ triggered=self._actionTriggered,
+ checkable=False,
+ parent=parent)
+ plot.sigSetKeepDataAspectRatio.connect(
+ self._keepDataAspectRatioChanged)
+
+ def _keepDataAspectRatioChanged(self, aspectRatio):
+ """Handle Plot set keep aspect ratio signal"""
+ icon, tooltip = self._states[aspectRatio]
+ self.setIcon(icon)
+ self.setToolTip(tooltip)
+
+ def _actionTriggered(self, checked=False):
+ # This will trigger _keepDataAspectRatioChanged
+ self.plot.setKeepDataAspectRatio(not self.plot.isKeepDataAspectRatio())
+
+
+class YAxisInvertedAction(PlotAction):
+ """QAction controlling Y orientation on a :class:`.PlotWidget`.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ # Uses two images for checked/unchecked states
+ self._states = {
+ False: (icons.getQIcon('plot-ydown'),
+ "Orient Y axis downward"),
+ True: (icons.getQIcon('plot-yup'),
+ "Orient Y axis upward"),
+ }
+
+ icon, tooltip = self._states[plot.getYAxis().isInverted()]
+ super(YAxisInvertedAction, self).__init__(
+ plot,
+ icon=icon,
+ text='Invert Y Axis',
+ tooltip=tooltip,
+ triggered=self._actionTriggered,
+ checkable=False,
+ parent=parent)
+ plot.getYAxis().sigInvertedChanged.connect(self._yAxisInvertedChanged)
+
+ def _yAxisInvertedChanged(self, inverted):
+ """Handle Plot set y axis inverted signal"""
+ icon, tooltip = self._states[inverted]
+ self.setIcon(icon)
+ self.setToolTip(tooltip)
+
+ def _actionTriggered(self, checked=False):
+ # This will trigger _yAxisInvertedChanged
+ yAxis = self.plot.getYAxis()
+ yAxis.setInverted(not yAxis.isInverted())
+
+
+class CrosshairAction(PlotAction):
+ """QAction toggling crosshair cursor on a :class:`.PlotWidget`.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param str color: Color to use to draw the crosshair
+ :param int linewidth: Width of the crosshair cursor
+ :param str linestyle: Style of line. See :meth:`.Plot.setGraphCursor`
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, color='black', linewidth=1, linestyle='-',
+ parent=None):
+ self.color = color
+ """Color used to draw the crosshair (str)."""
+
+ self.linewidth = linewidth
+ """Width of the crosshair cursor (int)."""
+
+ self.linestyle = linestyle
+ """Style of line of the cursor (str)."""
+
+ super(CrosshairAction, self).__init__(
+ plot, icon='crosshair', text='Crosshair Cursor',
+ tooltip='Enable crosshair cursor when checked',
+ triggered=self._actionTriggered,
+ checkable=True, parent=parent)
+ self.setChecked(plot.getGraphCursor() is not None)
+ plot.sigSetGraphCursor.connect(self.setChecked)
+
+ def _actionTriggered(self, checked=False):
+ self.plot.setGraphCursor(checked,
+ color=self.color,
+ linestyle=self.linestyle,
+ linewidth=self.linewidth)
+
+
+class PanWithArrowKeysAction(PlotAction):
+ """QAction toggling pan with arrow keys on a :class:`.PlotWidget`.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+
+ super(PanWithArrowKeysAction, self).__init__(
+ plot, icon='arrow-keys', text='Pan with arrow keys',
+ tooltip='Enable pan with arrow keys when checked',
+ triggered=self._actionTriggered,
+ checkable=True, parent=parent)
+ self.setChecked(plot.isPanWithArrowKeys())
+ plot.sigSetPanWithArrowKeys.connect(self.setChecked)
+
+ def _actionTriggered(self, checked=False):
+ self.plot.setPanWithArrowKeys(checked)
+
+
+class ShowAxisAction(PlotAction):
+ """QAction controlling axis visibility on a :class:`.PlotWidget`.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ tooltip = 'Show plot axis when checked, otherwise hide them'
+ PlotAction.__init__(self,
+ plot,
+ icon='axis',
+ text='show axis',
+ tooltip=tooltip,
+ triggered=self._actionTriggered,
+ checkable=True,
+ parent=parent)
+ self.setChecked(self.plot.isAxesDisplayed())
+ plot._sigAxesVisibilityChanged.connect(self.setChecked)
+
+ def _actionTriggered(self, checked=False):
+ self.plot.setAxesDisplayed(checked)
+
+
+class ClosePolygonInteractionAction(PlotAction):
+ """QAction controlling closure of a polygon in draw interaction mode
+ if the :class:`.PlotWidget`.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ tooltip = 'Close the current polygon drawn'
+ PlotAction.__init__(self,
+ plot,
+ icon='add-shape-polygon',
+ text='Close the polygon',
+ tooltip=tooltip,
+ triggered=self._actionTriggered,
+ checkable=True,
+ parent=parent)
+ self.plot.sigInteractiveModeChanged.connect(self._modeChanged)
+ self._modeChanged(None)
+
+ def _modeChanged(self, source):
+ mode = self.plot.getInteractiveMode()
+ enabled = "shape" in mode and mode["shape"] == "polygon"
+ self.setEnabled(enabled)
+
+ def _actionTriggered(self, checked=False):
+ self.plot._eventHandler.validate()
+
+
+class OpenGLAction(PlotAction):
+ """QAction controlling rendering of a :class:`.PlotWidget`.
+
+ For now it can enable or not the OpenGL backend.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ # Uses two images for checked/unchecked states
+ self._states = {
+ "opengl": (icons.getQIcon('backend-opengl'),
+ "OpenGL rendering (fast)\nClick to disable OpenGL"),
+ "matplotlib": (icons.getQIcon('backend-opengl'),
+ "Matplotlib rendering (safe)\nClick to enable OpenGL"),
+ "unknown": (icons.getQIcon('backend-opengl'),
+ "Custom rendering")
+ }
+
+ name = self._getBackendName(plot)
+ self.__state = name
+ icon, tooltip = self._states[name]
+ super(OpenGLAction, self).__init__(
+ plot,
+ icon=icon,
+ text='Enable/disable OpenGL rendering',
+ tooltip=tooltip,
+ triggered=self._actionTriggered,
+ checkable=True,
+ parent=parent)
+
+ def _backendUpdated(self):
+ name = self._getBackendName(self.plot)
+ self.__state = name
+ icon, tooltip = self._states[name]
+ self.setIcon(icon)
+ self.setToolTip(tooltip)
+ self.setChecked(name == "opengl")
+
+ def _getBackendName(self, plot):
+ backend = plot.getBackend()
+ name = type(backend).__name__.lower()
+ if "opengl" in name:
+ return "opengl"
+ elif "matplotlib" in name:
+ return "matplotlib"
+ else:
+ return "unknown"
+
+ def _actionTriggered(self, checked=False):
+ plot = self.plot
+ name = self._getBackendName(self.plot)
+ if self.__state != name:
+ # THere is no event to know the backend was updated
+ # So here we check if there is a mismatch between the displayed state
+ # and the real state of the widget
+ self._backendUpdated()
+ return
+ if name != "opengl":
+ from silx.gui.utils import glutils
+ result = glutils.isOpenGLAvailable()
+ if not result:
+ qt.QMessageBox.critical(plot, "OpenGL rendering not available", result.error)
+ # Uncheck if needed
+ self._backendUpdated()
+ return
+ plot.setBackend("opengl")
+ else:
+ plot.setBackend("matplotlib")
+ self._backendUpdated()
diff --git a/src/silx/gui/plot/actions/fit.py b/src/silx/gui/plot/actions/fit.py
new file mode 100644
index 0000000..e130b24
--- /dev/null
+++ b/src/silx/gui/plot/actions/fit.py
@@ -0,0 +1,485 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+:mod:`silx.gui.plot.actions.fit` module provides actions relative to fit.
+
+The following QAction are available:
+
+- :class:`.FitAction`
+
+.. autoclass:`.FitAction`
+"""
+
+from __future__ import division
+
+__authors__ = ["V.A. Sole", "T. Vincent", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "10/10/2018"
+
+import logging
+import sys
+import weakref
+import numpy
+
+from .PlotToolAction import PlotToolAction
+from .. import items
+from ....utils.deprecation import deprecated
+from silx.gui import qt
+from silx.gui.plot.ItemsSelectionDialog import ItemsSelectionDialog
+
+_logger = logging.getLogger(__name__)
+
+
+def _getUniqueCurveOrHistogram(plot):
+ """Returns unique :class:`Curve` or :class:`Histogram` in a `PlotWidget`.
+
+ If there is an active curve, returns it, else return curve or histogram
+ only if alone in the plot.
+
+ :param PlotWidget plot:
+ :rtype: Union[None,~silx.gui.plot.items.Curve,~silx.gui.plot.items.Histogram]
+ """
+ curve = plot.getActiveCurve()
+ if curve is not None:
+ return curve
+
+ visibleItems = [item for item in plot.getItems() if item.isVisible()]
+ histograms = [item for item in visibleItems
+ if isinstance(item, items.Histogram)]
+ curves = [item for item in visibleItems
+ if isinstance(item, items.Curve)]
+
+ if len(histograms) == 1 and len(curves) == 0:
+ return histograms[0]
+ elif len(curves) == 1 and len(histograms) == 0:
+ return curves[0]
+ else:
+ return None
+
+
+class _FitItemSelector(qt.QObject):
+ """
+ :class:`PlotWidget` observer that emits signal when fit selection changes.
+
+ Track active curve or unique curve or histogram.
+ """
+
+ sigCurrentItemChanged = qt.Signal(object)
+ """Signal emitted when the item to fit has changed"""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.__plotWidgetRef = None
+ self.__currentItem = None
+
+ def getCurrentItem(self):
+ """Return currently selected item
+
+ :rtype: Union[Item,None]
+ """
+ return self.__currentItem
+
+ def getPlotWidget(self):
+ """Return currently attached :class:`PlotWidget`
+
+ :rtype: Union[PlotWidget,None]
+ """
+ return None if self.__plotWidgetRef is None else self.__plotWidgetRef()
+
+ def setPlotWidget(self, plotWidget):
+ """Set the :class:`PlotWidget` for which to track changes
+
+ :param Union[PlotWidget,None] plotWidget:
+ The :class:`PlotWidget` to observe
+ """
+ # disconnect from previous plot
+ previousPlotWidget = self.getPlotWidget()
+ if previousPlotWidget is not None:
+ previousPlotWidget.sigItemAdded.disconnect(
+ self.__plotWidgetUpdated)
+ previousPlotWidget.sigItemRemoved.disconnect(
+ self.__plotWidgetUpdated)
+ previousPlotWidget.sigActiveCurveChanged.disconnect(
+ self.__plotWidgetUpdated)
+
+ if plotWidget is None:
+ self.__plotWidgetRef = None
+ self.__setCurrentItem(None)
+ return
+ self.__plotWidgetRef = weakref.ref(plotWidget, self.__plotDeleted)
+
+ # connect to new plot
+ plotWidget.sigItemAdded.connect(self.__plotWidgetUpdated)
+ plotWidget.sigItemRemoved.connect(self.__plotWidgetUpdated)
+ plotWidget.sigActiveCurveChanged.connect(self.__plotWidgetUpdated)
+ self.__plotWidgetUpdated()
+
+ def __plotDeleted(self):
+ """Handle deletion of PlotWidget"""
+ self.__setCurrentItem(None)
+
+ def __plotWidgetUpdated(self, *args, **kwargs):
+ """Handle updates of PlotWidget content"""
+ plotWidget = self.getPlotWidget()
+ if plotWidget is None:
+ return
+ self.__setCurrentItem(_getUniqueCurveOrHistogram(plotWidget))
+
+ def __setCurrentItem(self, item):
+ """Handle change of current item"""
+ if sys.is_finalizing():
+ return
+
+ previousItem = self.getCurrentItem()
+ if item != previousItem:
+ if previousItem is not None:
+ previousItem.sigItemChanged.disconnect(self.__itemUpdated)
+
+ self.__currentItem = item
+
+ if self.__currentItem is not None:
+ self.__currentItem.sigItemChanged.connect(self.__itemUpdated)
+ self.sigCurrentItemChanged.emit(self.__currentItem)
+
+ def __itemUpdated(self, event):
+ """Handle change on current item"""
+ if event == items.ItemChangedType.DATA:
+ self.sigCurrentItemChanged.emit(self.__currentItem)
+
+
+class FitAction(PlotToolAction):
+ """QAction to open a :class:`FitWidget` and set its data to the
+ active curve if any, or to the first curve.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ self.__item = None
+ self.__activeCurveSynchroEnabled = False
+ self.__range = 0, 1
+ self.__rangeAutoUpdate = False
+ self.__x, self.__y = None, None # Data to fit
+ self.__curveParams = {} # Store curve parameters to use for fit result
+ self.__legend = None
+
+ super(FitAction, self).__init__(
+ plot, icon='math-fit', text='Fit curve',
+ tooltip='Open a fit dialog',
+ parent=parent)
+
+ self.__fitItemSelector = _FitItemSelector()
+ self.__fitItemSelector.sigCurrentItemChanged.connect(
+ self._setFittedItem)
+
+
+ @property
+ @deprecated(replacement='getXRange()[0]', since_version='0.13.0')
+ def xmin(self):
+ return self.getXRange()[0]
+
+ @property
+ @deprecated(replacement='getXRange()[1]', since_version='0.13.0')
+ def xmax(self):
+ return self.getXRange()[1]
+
+ @property
+ @deprecated(replacement='getXData()', since_version='0.13.0')
+ def x(self):
+ return self.getXData()
+
+ @property
+ @deprecated(replacement='getYData()', since_version='0.13.0')
+ def y(self):
+ return self.getYData()
+
+ @property
+ @deprecated(since_version='0.13.0')
+ def xlabel(self):
+ return self.__curveParams.get('xlabel', None)
+
+ @property
+ @deprecated(since_version='0.13.0')
+ def ylabel(self):
+ return self.__curveParams.get('ylabel', None)
+
+ @property
+ @deprecated(since_version='0.13.0')
+ def legend(self):
+ return self.__legend
+
+ def _createToolWindow(self):
+ # import done here rather than at module level to avoid circular import
+ # FitWidget -> BackgroundWidget -> PlotWindow -> actions -> fit -> FitWidget
+ from ...fit.FitWidget import FitWidget
+
+ window = FitWidget(parent=self.plot)
+ window.setWindowFlags(qt.Qt.Dialog)
+ window.sigFitWidgetSignal.connect(self.handle_signal)
+ return window
+
+ def _connectPlot(self, window):
+ if self.isXRangeUpdatedOnZoom():
+ self.__setAutoXRangeEnabled(True)
+ else:
+ plot = self.plot
+ if plot is None:
+ _logger.error("No associated PlotWidget")
+ return
+ self._setXRange(*plot.getXAxis().getLimits())
+
+ if self.isFittedItemUpdatedFromActiveCurve():
+ self.__setFittedItemAutoUpdateEnabled(True)
+ else:
+ # Wait for the next iteration, else the plot is not yet initialized
+ # No curve available
+ qt.QTimer.singleShot(10, self._initFit)
+
+ def _disconnectPlot(self, window):
+ if self.isXRangeUpdatedOnZoom():
+ self.__setAutoXRangeEnabled(False)
+
+ if self.isFittedItemUpdatedFromActiveCurve():
+ self.__setFittedItemAutoUpdateEnabled(False)
+
+ def _initFit(self):
+ plot = self.plot
+ if plot is None:
+ _logger.error("No associated PlotWidget")
+ return
+
+ item = _getUniqueCurveOrHistogram(plot)
+ if item is None:
+ # ambiguous case, we need to ask which plot item to fit
+ isd = ItemsSelectionDialog(parent=plot, plot=plot)
+ isd.setWindowTitle("Select item to be fitted")
+ isd.setItemsSelectionMode(qt.QTableWidget.SingleSelection)
+ isd.setAvailableKinds(["curve", "histogram"])
+ isd.selectAllKinds()
+
+ if not isd.exec(): # Cancel
+ self._getToolWindow().setVisible(False)
+ else:
+ selectedItems = isd.getSelectedItems()
+ item = selectedItems[0] if len(selectedItems) == 1 else None
+
+ self._setXRange(*plot.getXAxis().getLimits())
+ self._setFittedItem(item)
+
+ def __updateFitWidget(self):
+ """Update the data/range used by the FitWidget"""
+ fitWidget = self._getToolWindow()
+
+ item = self._getFittedItem()
+ xdata = self.getXData(copy=False)
+ ydata = self.getYData(copy=False)
+ if item is None or xdata is None or ydata is None:
+ fitWidget.setData(y=None)
+ fitWidget.setWindowTitle("No curve selected")
+
+ else:
+ xmin, xmax = self.getXRange()
+ fitWidget.setData(
+ xdata, ydata, xmin=xmin, xmax=xmax)
+ fitWidget.setWindowTitle(
+ "Fitting " + item.getName() +
+ " on x range %f-%f" % (xmin, xmax))
+
+ # X Range management
+
+ def getXRange(self):
+ """Returns the range on the X axis on which to perform the fit."""
+ return self.__range
+
+ def _setXRange(self, xmin, xmax):
+ """Set the range on which the fit is done.
+
+ :param float xmin:
+ :param float xmax:
+ """
+ range_ = float(xmin), float(xmax)
+ if self.__range != range_:
+ self.__range = range_
+ self.__updateFitWidget()
+
+ def __setAutoXRangeEnabled(self, enabled):
+ """Implement the change of update mode of the X range.
+
+ :param bool enabled:
+ """
+ plot = self.plot
+ if plot is None:
+ _logger.error("No associated PlotWidget")
+ return
+
+ if enabled:
+ self._setXRange(*plot.getXAxis().getLimits())
+ plot.getXAxis().sigLimitsChanged.connect(self._setXRange)
+ else:
+ plot.getXAxis().sigLimitsChanged.disconnect(self._setXRange)
+
+ def setXRangeUpdatedOnZoom(self, enabled):
+ """Set whether or not to update the X range on zoom change.
+
+ :param bool enabled:
+ """
+ if enabled != self.__rangeAutoUpdate:
+ self.__rangeAutoUpdate = enabled
+ if self._getToolWindow().isVisible():
+ self.__setAutoXRangeEnabled(enabled)
+
+ def isXRangeUpdatedOnZoom(self):
+ """Returns the current mode of fitted data X range update.
+
+ :rtype: bool
+ """
+ return self.__rangeAutoUpdate
+
+ # Fitted item update
+
+ def getXData(self, copy=True):
+ """Returns the X data used for the fit or None if undefined.
+
+ :param bool copy:
+ True to get a copy of the data, False to get the internal data.
+ :rtype: Union[numpy.ndarray,None]
+ """
+ return None if self.__x is None else numpy.array(self.__x, copy=copy)
+
+ def getYData(self, copy=True):
+ """Returns the Y data used for the fit or None if undefined.
+
+ :param bool copy:
+ True to get a copy of the data, False to get the internal data.
+ :rtype: Union[numpy.ndarray,None]
+ """
+ return None if self.__y is None else numpy.array(self.__y, copy=copy)
+
+ def _getFittedItem(self):
+ """Returns the current item used for the fit
+
+ :rtype: Union[~silx.gui.plot.items.Curve,~silx.gui.plot.items.Histogram,None]
+ """
+ return self.__item
+
+ def _setFittedItem(self, item):
+ """Set the curve to use for fitting.
+
+ :param Union[~silx.gui.plot.items.Curve,~silx.gui.plot.items.Histogram,None] item:
+ """
+ plot = self.plot
+ if plot is None:
+ _logger.error("No associated PlotWidget")
+
+ if plot is None or item is None:
+ self.__item = None
+ self.__curveParams = {}
+ self.__updateFitWidget()
+ return
+
+ axis = item.getYAxis() if isinstance(item, items.YAxisMixIn) else 'left'
+ self.__curveParams = {
+ 'yaxis': axis,
+ 'xlabel': plot.getXAxis().getLabel(),
+ 'ylabel': plot.getYAxis(axis).getLabel(),
+ }
+ self.__legend = item.getName()
+
+ if isinstance(item, items.Histogram):
+ bin_edges = item.getBinEdgesData(copy=False)
+ # take the middle coordinate between adjacent bin edges
+ self.__x = (bin_edges[1:] + bin_edges[:-1]) / 2
+ self.__y = item.getValueData(copy=False)
+ # else take the active curve, or else the unique curve
+ elif isinstance(item, items.Curve):
+ self.__x = item.getXData(copy=False)
+ self.__y = item.getYData(copy=False)
+
+ self.__item = item
+ self.__updateFitWidget()
+
+ def __setFittedItemAutoUpdateEnabled(self, enabled):
+ """Implement the change of fitted item update mode
+
+ :param bool enabled:
+ """
+ plot = self.plot
+ if plot is None:
+ _logger.error("No associated PlotWidget")
+ return
+
+ self.__fitItemSelector.setPlotWidget(self.plot if enabled else None)
+
+ def setFittedItemUpdatedFromActiveCurve(self, enabled):
+ """Toggle fitted data synchronization with plot active curve.
+
+ :param bool enabled:
+ """
+ enabled = bool(enabled)
+ if enabled != self.__activeCurveSynchroEnabled:
+ self.__activeCurveSynchroEnabled = enabled
+ if self._getToolWindow().isVisible():
+ self.__setFittedItemAutoUpdateEnabled(enabled)
+
+ def isFittedItemUpdatedFromActiveCurve(self):
+ """Returns True if fitted data is synchronized with plot.
+
+ :rtype: bool
+ """
+ return self.__activeCurveSynchroEnabled
+
+ # Handle fit completed
+
+ def handle_signal(self, ddict):
+ xdata = self.getXData(copy=False)
+ if xdata is None:
+ _logger.error("No reference data to display fit result for")
+ return
+
+ xmin, xmax = self.getXRange()
+ x_fit = xdata[xmin <= xdata]
+ x_fit = x_fit[x_fit <= xmax]
+ fit_legend = "Fit <%s>" % self.__legend
+ fit_curve = self.plot.getCurve(fit_legend)
+
+ if ddict["event"] == "FitFinished":
+ fit_widget = self._getToolWindow()
+ if fit_widget is None:
+ return
+ y_fit = fit_widget.fitmanager.gendata()
+ if fit_curve is None:
+ self.plot.addCurve(x_fit, y_fit,
+ fit_legend,
+ resetzoom=False,
+ **self.__curveParams)
+ else:
+ fit_curve.setData(x_fit, y_fit)
+ fit_curve.setVisible(True)
+ fit_curve.setYAxis(self.__curveParams.get('yaxis', 'left'))
+
+ if ddict["event"] in ["FitStarted", "FitFailed"]:
+ if fit_curve is not None:
+ fit_curve.setVisible(False)
diff --git a/src/silx/gui/plot/actions/histogram.py b/src/silx/gui/plot/actions/histogram.py
new file mode 100644
index 0000000..be9f5a7
--- /dev/null
+++ b/src/silx/gui/plot/actions/histogram.py
@@ -0,0 +1,542 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+:mod:`silx.gui.plot.actions.histogram` provides actions relative to histograms
+for :class:`.PlotWidget`.
+
+The following QAction are available:
+
+- :class:`PixelIntensitiesHistoAction`
+"""
+
+from __future__ import division
+
+__authors__ = ["V.A. Sole", "T. Vincent", "P. Knobel"]
+__date__ = "01/12/2020"
+__license__ = "MIT"
+
+from typing import Optional, Tuple
+import numpy
+import logging
+import weakref
+
+from .PlotToolAction import PlotToolAction
+
+from silx.math.histogram import Histogramnd
+from silx.math.combo import min_max
+from silx.gui import qt
+from silx.gui.plot import items
+from silx.gui.widgets.ElidedLabel import ElidedLabel
+from silx.gui.widgets.RangeSlider import RangeSlider
+from silx.utils.deprecation import deprecated
+
+_logger = logging.getLogger(__name__)
+
+
+class _ElidedLabel(ElidedLabel):
+ """QLabel with a default size larger than what is displayed."""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.setTextInteractionFlags(qt.Qt.TextSelectableByMouse)
+
+ def sizeHint(self):
+ hint = super().sizeHint()
+ nbchar = max(len(self.getText()), 12)
+ width = self.fontMetrics().boundingRect('#' * nbchar).width()
+ return qt.QSize(max(hint.width(), width), hint.height())
+
+
+class _StatWidget(qt.QWidget):
+ """Widget displaying a name and a value
+
+ :param parent:
+ :param name:
+ """
+
+ def __init__(self, parent=None, name: str=''):
+ super().__init__(parent)
+ layout = qt.QHBoxLayout(self)
+ layout.setContentsMargins(0, 0, 0, 0)
+
+ keyWidget = qt.QLabel(parent=self)
+ keyWidget.setText("<b>" + name.capitalize() + ":<b>")
+ layout.addWidget(keyWidget)
+ self.__valueWidget = _ElidedLabel(parent=self)
+ self.__valueWidget.setText("-")
+ self.__valueWidget.setTextInteractionFlags(
+ qt.Qt.TextSelectableByMouse | qt.Qt.TextSelectableByKeyboard)
+ layout.addWidget(self.__valueWidget)
+
+ def setValue(self, value: Optional[float]):
+ """Set the displayed value
+
+ :param value:
+ """
+ self.__valueWidget.setText(
+ "-" if value is None else "{:.5g}".format(value))
+
+
+class _IntEdit(qt.QLineEdit):
+ """QLineEdit for integers with a default value and update on validation.
+
+ :param QWidget parent:
+ """
+
+ sigValueChanged = qt.Signal(int)
+ """Signal emitted when the value has changed (on editing finished)"""
+
+ def __init__(self, parent=None):
+ super().__init__(parent)
+ self.__value = None
+ self.setAlignment(qt.Qt.AlignRight)
+ validator = qt.QIntValidator()
+ self.setValidator(validator)
+ validator.bottomChanged.connect(self.__updateSize)
+ validator.topChanged.connect(self.__updateSize)
+ self.__updateSize()
+
+ self.textEdited.connect(self.__textEdited)
+
+ def __updateSize(self, *args):
+ """Update widget's maximum size according to bounds"""
+ bottom, top = self.getRange()
+ nbchar = max(len(str(bottom)), len(str(top)))
+ font = self.font()
+ font.setStyle(qt.QFont.StyleItalic)
+ fontMetrics = qt.QFontMetrics(font)
+ self.setMaximumWidth(
+ fontMetrics.boundingRect('0' * (nbchar + 1)).width()
+ )
+ self.setMaxLength(nbchar)
+
+ def __textEdited(self, _):
+ if self.font().style() != qt.QFont.StyleItalic:
+ font = self.font()
+ font.setStyle(qt.QFont.StyleItalic)
+ self.setFont(font)
+
+ # Use events rather than editingFinished to also trigger with empty text
+
+ def focusOutEvent(self, event):
+ self.__commitValue()
+ return super().focusOutEvent(event)
+
+ def keyPressEvent(self, event):
+ if event.key() in (qt.Qt.Key_Enter, qt.Qt.Key_Return):
+ self.__commitValue()
+ return super().keyPressEvent(event)
+
+ def __commitValue(self):
+ """Update the value returned by :meth:`getValue`"""
+ value = self.getCurrentValue()
+ if value is None:
+ value = self.getDefaultValue()
+ if value is None:
+ return # No value, keep previous one
+
+ if self.font().style() != qt.QFont.StyleNormal:
+ font = self.font()
+ font.setStyle(qt.QFont.StyleNormal)
+ self.setFont(font)
+
+ if value != self.__value:
+ self.__value = value
+ self.sigValueChanged.emit(value)
+
+ def getValue(self) -> Optional[int]:
+ """Return current value (None if never set)."""
+ return self.__value
+
+ def setRange(self, bottom: int, top: int):
+ """Set the range of valid values"""
+ self.validator().setRange(bottom, top)
+
+ def getRange(self) -> Tuple[int, int]:
+ """Returns the current range of valid values
+
+ :returns: (bottom, top)
+ """
+ return self.validator().bottom(), self.validator().top()
+
+ def __validate(self, value: int, extend_range: bool):
+ """Ensure value is in range
+
+ :param int value:
+ :param bool extend_range:
+ True to extend range if needed.
+ False to clip value if needed.
+ """
+ if extend_range:
+ bottom, top = self.getRange()
+ self.setRange(min(value, bottom), max(value, top))
+ return numpy.clip(value, *self.getRange())
+
+ def setDefaultValue(self, value: int, extend_range: bool=False):
+ """Set default value when QLineEdit is empty
+
+ :param int value:
+ :param bool extend_range:
+ True to extend range if needed.
+ False to clip value if needed
+ """
+ self.setPlaceholderText(str(self.__validate(value, extend_range)))
+ if self.getCurrentValue() is None:
+ self.__commitValue()
+
+ def getDefaultValue(self) -> Optional[int]:
+ """Return the default value or the bottom one if not set"""
+ try:
+ return int(self.placeholderText())
+ except ValueError:
+ return None
+
+ def setCurrentValue(self, value: int, extend_range: bool=False):
+ """Set the currently displayed value
+
+ :param int value:
+ :param bool extend_range:
+ True to extend range if needed.
+ False to clip value if needed
+ """
+ self.setText(str(self.__validate(value, extend_range)))
+ self.__commitValue()
+
+ def getCurrentValue(self) -> Optional[int]:
+ """Returns the displayed value or None if not correct"""
+ try:
+ return int(self.text())
+ except ValueError:
+ return None
+
+
+class HistogramWidget(qt.QWidget):
+ """Widget displaying a histogram and some statistic indicators"""
+
+ _SUPPORTED_ITEM_CLASS = items.ImageBase, items.Scatter
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.setWindowTitle('Histogram')
+
+ self.__itemRef = None # weakref on the item to track
+
+ layout = qt.QVBoxLayout(self)
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.setSpacing(0)
+
+ # Plot
+ # Lazy import to avoid circular dependencies
+ from silx.gui.plot.PlotWindow import Plot1D
+ self.__plot = Plot1D(self)
+ layout.addWidget(self.__plot)
+
+ self.__plot.setDataMargins(0.1, 0.1, 0.1, 0.1)
+ self.__plot.getXAxis().setLabel("Value")
+ self.__plot.getYAxis().setLabel("Count")
+ posInfo = self.__plot.getPositionInfoWidget()
+ posInfo.setSnappingMode(posInfo.SNAPPING_CURVE)
+
+ # Histogram controls
+ controlsWidget = qt.QWidget(self)
+ layout.addWidget(controlsWidget)
+ controlsLayout = qt.QHBoxLayout(controlsWidget)
+ controlsLayout.setContentsMargins(4, 4, 4, 4)
+
+ controlsLayout.addWidget(qt.QLabel("<b>Histogram:<b>"))
+ controlsLayout.addWidget(qt.QLabel("N. bins:"))
+ self.__nbinsLineEdit = _IntEdit(self)
+ self.__nbinsLineEdit.setRange(2, 9999)
+ self.__nbinsLineEdit.sigValueChanged.connect(
+ self.__updateHistogramFromControls)
+ controlsLayout.addWidget(self.__nbinsLineEdit)
+ self.__rangeLabel = qt.QLabel("Range:")
+ controlsLayout.addWidget(self.__rangeLabel)
+ self.__rangeSlider = RangeSlider(parent=self)
+ self.__rangeSlider.sigValueChanged.connect(
+ self.__updateHistogramFromControls)
+ self.__rangeSlider.sigValueChanged.connect(self.__rangeChanged)
+ controlsLayout.addWidget(self.__rangeSlider)
+ controlsLayout.addStretch(1)
+
+ # Stats display
+ statsWidget = qt.QWidget(self)
+ layout.addWidget(statsWidget)
+ statsLayout = qt.QHBoxLayout(statsWidget)
+ statsLayout.setContentsMargins(4, 4, 4, 4)
+
+ self.__statsWidgets = dict(
+ (name, _StatWidget(parent=statsWidget, name=name))
+ for name in ("min", "max", "mean", "std", "sum"))
+
+ for widget in self.__statsWidgets.values():
+ statsLayout.addWidget(widget)
+ statsLayout.addStretch(1)
+
+ def getPlotWidget(self):
+ """Returns :class:`PlotWidget` use to display the histogram"""
+ return self.__plot
+
+ def resetZoom(self):
+ """Reset PlotWidget zoom"""
+ self.getPlotWidget().resetZoom()
+
+ def reset(self):
+ """Clear displayed information"""
+ self.getPlotWidget().clear()
+ self.setStatistics()
+
+ def getItem(self) -> Optional[items.Item]:
+ """Returns item used to display histogram and statistics."""
+ return None if self.__itemRef is None else self.__itemRef()
+
+ def setItem(self, item: Optional[items.Item]):
+ """Set item from which to display histogram and statistics.
+
+ :param item:
+ """
+ previous = self.getItem()
+ if previous is not None:
+ previous.sigItemChanged.disconnect(self.__itemChanged)
+
+ self.__itemRef = None if item is None else weakref.ref(item)
+ if item is not None:
+ if isinstance(item, self._SUPPORTED_ITEM_CLASS):
+ # Only listen signal for supported items
+ item.sigItemChanged.connect(self.__itemChanged)
+ self._updateFromItem()
+
+ def __itemChanged(self, event):
+ """Handle update of the item"""
+ if event in (items.ItemChangedType.DATA, items.ItemChangedType.MASK):
+ self._updateFromItem()
+
+ def __updateHistogramFromControls(self, *args):
+ """Handle udates coming from histogram control widgets"""
+
+ hist = self.getHistogram(copy=False)
+ if hist is not None:
+ count, edges = hist
+ if (len(count) == self.__nbinsLineEdit.getValue() and
+ (edges[0], edges[-1]) == self.__rangeSlider.getValues()):
+ return # Nothing has changed
+
+ self._updateFromItem()
+
+ def __rangeChanged(self, first, second):
+ """Handle change of histogram range from the range slider"""
+ tooltip = "Histogram range:\n[%g, %g]" % (first, second)
+ self.__rangeSlider.setToolTip(tooltip)
+ self.__rangeLabel.setToolTip(tooltip)
+
+ def _updateFromItem(self):
+ """Update histogram and stats from the item"""
+ item = self.getItem()
+
+ if item is None:
+ self.reset()
+ return
+
+ if not isinstance(item, self._SUPPORTED_ITEM_CLASS):
+ _logger.error("Unsupported item", item)
+ self.reset()
+ return
+
+ # Compute histogram and stats
+ array = item.getValueData(copy=False)
+
+ if array.size == 0:
+ self.reset()
+ return
+
+ xmin, xmax = min_max(array, min_positive=False, finite=True)
+ if xmin is None or xmax is None: # All not finite data
+ self.reset()
+ return
+ guessed_nbins = min(1024, int(numpy.sqrt(array.size)))
+
+ # bad hack: get 256 bins in the case we have a B&W
+ if numpy.issubdtype(array.dtype, numpy.integer):
+ if guessed_nbins > xmax - xmin:
+ guessed_nbins = xmax - xmin
+ guessed_nbins = max(2, guessed_nbins)
+
+ # Set default nbins
+ self.__nbinsLineEdit.setDefaultValue(guessed_nbins, extend_range=True)
+ # Set slider range: do not keep the range value, but the relative pos.
+ previousPositions = self.__rangeSlider.getPositions()
+ if xmin == xmax: # Enlarge range is none
+ if xmin == 0:
+ range_ = -0.01, 0.01
+ else:
+ range_ = sorted((xmin * .99, xmin * 1.01))
+ else:
+ range_ = xmin, xmax
+
+ self.__rangeSlider.setRange(*range_)
+ self.__rangeSlider.setPositions(*previousPositions)
+
+ histogram = Histogramnd(
+ array.ravel().astype(numpy.float32),
+ n_bins=max(2, self.__nbinsLineEdit.getValue()),
+ histo_range=self.__rangeSlider.getValues(),
+ )
+ if len(histogram.edges) != 1:
+ _logger.error("Error while computing the histogram")
+ self.reset()
+ return
+
+ self.setHistogram(histogram.histo, histogram.edges[0])
+ self.resetZoom()
+ self.setStatistics(
+ min_=xmin,
+ max_=xmax,
+ mean=numpy.nanmean(array),
+ std=numpy.nanstd(array),
+ sum_=numpy.nansum(array))
+
+ def setHistogram(self, histogram, edges):
+ """Set displayed histogram
+
+ :param histogram: Bin values (N)
+ :param edges: Bin edges (N+1)
+ """
+ # Only useful if setHistogram is called directly
+ # TODO
+ #nbins = len(histogram)
+ #if nbins != self.__nbinsLineEdit.getDefaultValue():
+ # self.__nbinsLineEdit.setValue(nbins, extend_range=True)
+ #self.__rangeSlider.setValues(edges[0], edges[-1])
+
+ self.getPlotWidget().addHistogram(
+ histogram=histogram,
+ edges=edges,
+ legend='histogram',
+ fill=True,
+ color='#66aad7',
+ resetzoom=False)
+
+ def getHistogram(self, copy: bool=True):
+ """Returns currently displayed histogram.
+
+ :param copy: True to get a copy,
+ False to get internal representation (Do not modify!)
+ :return: (histogram, edges) or None
+ """
+ for item in self.getPlotWidget().getItems():
+ if item.getName() == 'histogram':
+ return (item.getValueData(copy=copy),
+ item.getBinEdgesData(copy=copy))
+ else:
+ return None
+
+ def setStatistics(self,
+ min_: Optional[float] = None,
+ max_: Optional[float] = None,
+ mean: Optional[float] = None,
+ std: Optional[float] = None,
+ sum_: Optional[float] = None):
+ """Set displayed statistic indicators."""
+ self.__statsWidgets['min'].setValue(min_)
+ self.__statsWidgets['max'].setValue(max_)
+ self.__statsWidgets['mean'].setValue(mean)
+ self.__statsWidgets['std'].setValue(std)
+ self.__statsWidgets['sum'].setValue(sum_)
+
+
+class PixelIntensitiesHistoAction(PlotToolAction):
+ """QAction to plot the pixels intensities diagram
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ PlotToolAction.__init__(self,
+ plot,
+ icon='pixel-intensities',
+ text='pixels intensity',
+ tooltip='Compute image intensity distribution',
+ parent=parent)
+
+ def _connectPlot(self, window):
+ plot = self.plot
+ if plot is not None:
+ selection = plot.selection()
+ selection.sigSelectedItemsChanged.connect(self._selectedItemsChanged)
+ self._updateSelectedItem()
+
+ PlotToolAction._connectPlot(self, window)
+
+ def _disconnectPlot(self, window):
+ plot = self.plot
+ if plot is not None:
+ selection = self.plot.selection()
+ selection.sigSelectedItemsChanged.disconnect(self._selectedItemsChanged)
+
+ PlotToolAction._disconnectPlot(self, window)
+ self.getHistogramWidget().setItem(None)
+
+ def _updateSelectedItem(self):
+ """Synchronises selected item with plot widget."""
+ plot = self.plot
+ if plot is not None:
+ selected = plot.selection().getSelectedItems()
+ # Give priority to image over scatter
+ for klass in (items.ImageBase, items.Scatter):
+ for item in selected:
+ if isinstance(item, klass):
+ # Found a matching item, use it
+ self.getHistogramWidget().setItem(item)
+ return
+ self.getHistogramWidget().setItem(None)
+
+ def _selectedItemsChanged(self):
+ if self._isWindowInUse():
+ self._updateSelectedItem()
+
+ @deprecated(since_version='0.15.0')
+ def computeIntensityDistribution(self):
+ self.getHistogramWidget()._updateFromItem()
+
+ def getHistogramWidget(self):
+ """Returns the widget displaying the histogram"""
+ return self._getToolWindow()
+
+ @deprecated(since_version='0.15.0',
+ replacement='getHistogramWidget().getPlotWidget()')
+ def getHistogramPlotWidget(self):
+ return self._getToolWindow().getPlotWidget()
+
+ def _createToolWindow(self):
+ return HistogramWidget(self.plot, qt.Qt.Window)
+
+ def getHistogram(self) -> Optional[numpy.ndarray]:
+ """Return the last computed histogram
+
+ :return: the histogram displayed in the HistogramWidget
+ """
+ histogram = self.getHistogramWidget().getHistogram()
+ return None if histogram is None else histogram[0]
diff --git a/src/silx/gui/plot/actions/io.py b/src/silx/gui/plot/actions/io.py
new file mode 100644
index 0000000..7f4edd3
--- /dev/null
+++ b/src/silx/gui/plot/actions/io.py
@@ -0,0 +1,819 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+:mod:`silx.gui.plot.actions.io` provides a set of QAction relative of inputs
+and outputs for a :class:`.PlotWidget`.
+
+The following QAction are available:
+
+- :class:`CopyAction`
+- :class:`PrintAction`
+- :class:`SaveAction`
+"""
+
+from __future__ import division
+
+__authors__ = ["V.A. Sole", "T. Vincent", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "25/09/2020"
+
+from . import PlotAction
+from silx.io.utils import save1D, savespec, NEXUS_HDF5_EXT
+from silx.io.nxdata import save_NXdata
+import logging
+import sys
+import os.path
+from collections import OrderedDict
+import traceback
+import numpy
+from silx.utils.deprecation import deprecated
+from silx.gui import qt, printer
+from silx.gui.dialog.GroupDialog import GroupDialog
+from silx.third_party.EdfFile import EdfFile
+from silx.third_party.TiffIO import TiffIO
+from ...utils.image import convertArrayToQImage
+if sys.version_info[0] == 3:
+ from io import BytesIO
+else:
+ import cStringIO as _StringIO
+ BytesIO = _StringIO.StringIO
+
+_logger = logging.getLogger(__name__)
+
+_NEXUS_HDF5_EXT_STR = ' '.join(['*' + ext for ext in NEXUS_HDF5_EXT])
+
+
+def selectOutputGroup(h5filename):
+ """Open a dialog to prompt the user to select a group in
+ which to output data.
+
+ :param str h5filename: name of an existing HDF5 file
+ :rtype: str
+ :return: Name of output group, or None if the dialog was cancelled
+ """
+ dialog = GroupDialog()
+ dialog.addFile(h5filename)
+ dialog.setWindowTitle("Select an output group")
+ if not dialog.exec():
+ return None
+ return dialog.getSelectedDataUrl().data_path()
+
+
+class SaveAction(PlotAction):
+ """QAction for saving Plot content.
+
+ It opens a Save as... dialog.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate.
+ :param parent: See :class:`QAction`.
+ """
+
+ SNAPSHOT_FILTER_SVG = 'Plot Snapshot as SVG (*.svg)'
+ SNAPSHOT_FILTER_PNG = 'Plot Snapshot as PNG (*.png)'
+
+ DEFAULT_ALL_FILTERS = (SNAPSHOT_FILTER_PNG, SNAPSHOT_FILTER_SVG)
+
+ # Dict of curve filters with CSV-like format
+ # Using ordered dict to guarantee filters order
+ # Note: '%.18e' is numpy.savetxt default format
+ CURVE_FILTERS_TXT = OrderedDict((
+ ('Curve as Raw ASCII (*.txt)',
+ {'fmt': '%.18e', 'delimiter': ' ', 'header': False}),
+ ('Curve as ";"-separated CSV (*.csv)',
+ {'fmt': '%.18e', 'delimiter': ';', 'header': True}),
+ ('Curve as ","-separated CSV (*.csv)',
+ {'fmt': '%.18e', 'delimiter': ',', 'header': True}),
+ ('Curve as tab-separated CSV (*.csv)',
+ {'fmt': '%.18e', 'delimiter': '\t', 'header': True}),
+ ('Curve as OMNIC CSV (*.csv)',
+ {'fmt': '%.7E', 'delimiter': ',', 'header': False}),
+ ('Curve as SpecFile (*.dat)',
+ {'fmt': '%.10g', 'delimiter': '', 'header': False})
+ ))
+
+ CURVE_FILTER_NPY = 'Curve as NumPy binary file (*.npy)'
+
+ CURVE_FILTER_NXDATA = 'Curve as NXdata (%s)' % _NEXUS_HDF5_EXT_STR
+
+ DEFAULT_CURVE_FILTERS = list(CURVE_FILTERS_TXT.keys()) + [
+ CURVE_FILTER_NPY, CURVE_FILTER_NXDATA]
+
+ DEFAULT_ALL_CURVES_FILTERS = ("All curves as SpecFile (*.dat)",)
+
+ IMAGE_FILTER_EDF = 'Image data as EDF (*.edf)'
+ IMAGE_FILTER_TIFF = 'Image data as TIFF (*.tif)'
+ IMAGE_FILTER_NUMPY = 'Image data as NumPy binary file (*.npy)'
+ IMAGE_FILTER_ASCII = 'Image data as ASCII (*.dat)'
+ IMAGE_FILTER_CSV_COMMA = 'Image data as ,-separated CSV (*.csv)'
+ IMAGE_FILTER_CSV_SEMICOLON = 'Image data as ;-separated CSV (*.csv)'
+ IMAGE_FILTER_CSV_TAB = 'Image data as tab-separated CSV (*.csv)'
+ IMAGE_FILTER_RGB_PNG = 'Image as PNG (*.png)'
+ IMAGE_FILTER_NXDATA = 'Image as NXdata (%s)' % _NEXUS_HDF5_EXT_STR
+
+ DEFAULT_IMAGE_FILTERS = (IMAGE_FILTER_EDF,
+ IMAGE_FILTER_TIFF,
+ IMAGE_FILTER_NUMPY,
+ IMAGE_FILTER_ASCII,
+ IMAGE_FILTER_CSV_COMMA,
+ IMAGE_FILTER_CSV_SEMICOLON,
+ IMAGE_FILTER_CSV_TAB,
+ IMAGE_FILTER_RGB_PNG,
+ IMAGE_FILTER_NXDATA)
+
+ SCATTER_FILTER_NXDATA = 'Scatter as NXdata (%s)' % _NEXUS_HDF5_EXT_STR
+ DEFAULT_SCATTER_FILTERS = (SCATTER_FILTER_NXDATA,)
+
+ # filters for which we don't want an "overwrite existing file" warning
+ DEFAULT_APPEND_FILTERS = (CURVE_FILTER_NXDATA, IMAGE_FILTER_NXDATA,
+ SCATTER_FILTER_NXDATA)
+
+ def __init__(self, plot, parent=None):
+ self._filters = {
+ 'all': OrderedDict(),
+ 'curve': OrderedDict(),
+ 'curves': OrderedDict(),
+ 'image': OrderedDict(),
+ 'scatter': OrderedDict()}
+
+ self._appendFilters = list(self.DEFAULT_APPEND_FILTERS)
+
+ # Initialize filters
+ for nameFilter in self.DEFAULT_ALL_FILTERS:
+ self.setFileFilter(
+ dataKind='all', nameFilter=nameFilter, func=self._saveSnapshot)
+
+ for nameFilter in self.DEFAULT_CURVE_FILTERS:
+ self.setFileFilter(
+ dataKind='curve', nameFilter=nameFilter, func=self._saveCurve)
+
+ for nameFilter in self.DEFAULT_ALL_CURVES_FILTERS:
+ self.setFileFilter(
+ dataKind='curves', nameFilter=nameFilter, func=self._saveCurves)
+
+ for nameFilter in self.DEFAULT_IMAGE_FILTERS:
+ self.setFileFilter(
+ dataKind='image', nameFilter=nameFilter, func=self._saveImage)
+
+ for nameFilter in self.DEFAULT_SCATTER_FILTERS:
+ self.setFileFilter(
+ dataKind='scatter', nameFilter=nameFilter, func=self._saveScatter)
+
+ super(SaveAction, self).__init__(
+ plot, icon='document-save', text='Save as...',
+ tooltip='Save curve/image/plot snapshot dialog',
+ triggered=self._actionTriggered,
+ checkable=False, parent=parent)
+ self.setShortcut(qt.QKeySequence.Save)
+ self.setShortcutContext(qt.Qt.WidgetShortcut)
+
+ @staticmethod
+ def _errorMessage(informativeText='', parent=None):
+ """Display an error message."""
+ # TODO issue with QMessageBox size fixed and too small
+ msg = qt.QMessageBox(parent)
+ msg.setIcon(qt.QMessageBox.Critical)
+ msg.setInformativeText(informativeText + ' ' + str(sys.exc_info()[1]))
+ msg.setDetailedText(traceback.format_exc())
+ msg.exec()
+
+ def _saveSnapshot(self, plot, filename, nameFilter):
+ """Save a snapshot of the :class:`PlotWindow` widget.
+
+ :param str filename: The name of the file to write
+ :param str nameFilter: The selected name filter
+ :return: False if format is not supported or save failed,
+ True otherwise.
+ """
+ if nameFilter == self.SNAPSHOT_FILTER_PNG:
+ fileFormat = 'png'
+ elif nameFilter == self.SNAPSHOT_FILTER_SVG:
+ fileFormat = 'svg'
+ else: # Format not supported
+ _logger.error(
+ 'Saving plot snapshot failed: format not supported')
+ return False
+
+ plot.saveGraph(filename, fileFormat=fileFormat)
+ return True
+
+ def _getAxesLabels(self, item):
+ # If curve has no associated label, get the default from the plot
+ xlabel = item.getXLabel() or self.plot.getXAxis().getLabel()
+ ylabel = item.getYLabel() or self.plot.getYAxis().getLabel()
+ return xlabel, ylabel
+
+ def _get1dData(self, item):
+ "provide xdata, [ydata], xlabel, [ylabel] and manages error bars"
+ xlabel, ylabel = self._getAxesLabels(item)
+ x_data = item.getXData(copy=False)
+ y_data = item.getYData(copy=False)
+ x_err = item.getXErrorData(copy=False)
+ y_err = item.getYErrorData(copy=False)
+ labels = [ylabel]
+ data = [y_data]
+
+ if x_err is not None:
+ if numpy.isscalar(x_err):
+ data.append(numpy.zeros_like(y_data) + x_err)
+ labels.append(xlabel + "_errors")
+ elif x_err.ndim == 1:
+ data.append(x_err)
+ labels.append(xlabel + "_errors")
+ elif x_err.ndim == 2:
+ data.append(x_err[0])
+ labels.append(xlabel + "_errors_below")
+ data.append(x_err[1])
+ labels.append(xlabel + "_errors_above")
+
+ if y_err is not None:
+ if numpy.isscalar(y_err):
+ data.append(numpy.zeros_like(y_data) + y_err)
+ labels.append(ylabel + "_errors")
+ elif y_err.ndim == 1:
+ data.append(y_err)
+ labels.append(ylabel + "_errors")
+ elif y_err.ndim == 2:
+ data.append(y_err[0])
+ labels.append(ylabel + "_errors_below")
+ data.append(y_err[1])
+ labels.append(ylabel + "_errors_above")
+ return x_data, data, xlabel, labels
+
+ @staticmethod
+ def _selectWriteableOutputGroup(filename, parent):
+ if os.path.exists(filename) and os.path.isfile(filename) \
+ and os.access(filename, os.W_OK):
+ entryPath = selectOutputGroup(filename)
+ if entryPath is None:
+ _logger.info("Save operation cancelled")
+ return None
+ return entryPath
+ elif not os.path.exists(filename):
+ # create new entry in new file
+ return "/entry"
+ else:
+ SaveAction._errorMessage('Save failed (file access issue)\n', parent=parent)
+ return None
+
+ def _saveCurveAsNXdata(self, curve, filename):
+ entryPath = self._selectWriteableOutputGroup(filename, parent=self.plot)
+ if entryPath is None:
+ return False
+
+ xlabel, ylabel = self._getAxesLabels(curve)
+
+ return save_NXdata(
+ filename,
+ nxentry_name=entryPath,
+ signal=curve.getYData(copy=False),
+ axes=[curve.getXData(copy=False)],
+ signal_name="y",
+ axes_names=["x"],
+ signal_long_name=ylabel,
+ axes_long_names=[xlabel],
+ signal_errors=curve.getYErrorData(copy=False),
+ axes_errors=[curve.getXErrorData(copy=True)],
+ title=self.plot.getGraphTitle())
+
+ def _saveCurve(self, plot, filename, nameFilter):
+ """Save a curve from the plot.
+
+ :param str filename: The name of the file to write
+ :param str nameFilter: The selected name filter
+ :return: False if format is not supported or save failed,
+ True otherwise.
+ """
+ if nameFilter not in self.DEFAULT_CURVE_FILTERS:
+ return False
+
+ # Check if a curve is to be saved
+ curve = plot.getActiveCurve()
+ # before calling _saveCurve, if there is no selected curve, we
+ # make sure there is only one curve on the graph
+ if curve is None:
+ curves = plot.getAllCurves()
+ if not curves:
+ self._errorMessage("No curve to be saved", parent=self.plot)
+ return False
+ curve = curves[0]
+
+ if nameFilter in self.CURVE_FILTERS_TXT:
+ filter_ = self.CURVE_FILTERS_TXT[nameFilter]
+ fmt = filter_['fmt']
+ csvdelim = filter_['delimiter']
+ autoheader = filter_['header']
+ else:
+ # .npy or nxdata
+ fmt, csvdelim, autoheader = ("", "", False)
+
+ if nameFilter == self.CURVE_FILTER_NXDATA:
+ return self._saveCurveAsNXdata(curve, filename)
+
+ xdata, data, xlabel, labels = self._get1dData(curve)
+
+ try:
+ save1D(filename,
+ xdata, data,
+ xlabel, labels,
+ fmt=fmt, csvdelim=csvdelim,
+ autoheader=autoheader)
+ except IOError:
+ self._errorMessage('Save failed\n', parent=self.plot)
+ return False
+
+ return True
+
+ def _saveCurves(self, plot, filename, nameFilter):
+ """Save all curves from the plot.
+
+ :param str filename: The name of the file to write
+ :param str nameFilter: The selected name filter
+ :return: False if format is not supported or save failed,
+ True otherwise.
+ """
+ if nameFilter not in self.DEFAULT_ALL_CURVES_FILTERS:
+ return False
+
+ curves = plot.getAllCurves()
+ if not curves:
+ self._errorMessage("No curves to be saved", parent=self.plot)
+ return False
+
+ curve = curves[0]
+ scanno = 1
+ try:
+ xdata, data, xlabel, labels = self._get1dData(curve)
+
+ specfile = savespec(filename,
+ xdata, data,
+ xlabel, labels,
+ fmt="%.7g", scan_number=1, mode="w",
+ write_file_header=True,
+ close_file=False)
+ except IOError:
+ self._errorMessage('Save failed\n', parent=self.plot)
+ return False
+
+ for curve in curves[1:]:
+ try:
+ scanno += 1
+ xdata, data, xlabel, labels = self._get1dData(curve)
+ specfile = savespec(specfile,
+ xdata, data,
+ xlabel, labels,
+ fmt="%.7g", scan_number=scanno,
+ write_file_header=False,
+ close_file=False)
+ except IOError:
+ self._errorMessage('Save failed\n', parent=self.plot)
+ return False
+ specfile.close()
+
+ return True
+
+ def _saveImage(self, plot, filename, nameFilter):
+ """Save an image from the plot.
+
+ :param str filename: The name of the file to write
+ :param str nameFilter: The selected name filter
+ :return: False if format is not supported or save failed,
+ True otherwise.
+ """
+ if nameFilter not in self.DEFAULT_IMAGE_FILTERS:
+ return False
+
+ image = plot.getActiveImage()
+ if image is None:
+ qt.QMessageBox.warning(
+ plot, "No Data", "No image to be saved")
+ return False
+
+ data = image.getData(copy=False)
+
+ # TODO Use silx.io for writing files
+ if nameFilter == self.IMAGE_FILTER_EDF:
+ edfFile = EdfFile(filename, access="w+")
+ edfFile.WriteImage({}, data, Append=0)
+ return True
+
+ elif nameFilter == self.IMAGE_FILTER_TIFF:
+ tiffFile = TiffIO(filename, mode='w')
+ tiffFile.writeImage(data, software='silx')
+ return True
+
+ elif nameFilter == self.IMAGE_FILTER_NUMPY:
+ try:
+ numpy.save(filename, data)
+ except IOError:
+ self._errorMessage('Save failed\n', parent=self.plot)
+ return False
+ return True
+
+ elif nameFilter == self.IMAGE_FILTER_NXDATA:
+ entryPath = self._selectWriteableOutputGroup(filename, parent=self.plot)
+ if entryPath is None:
+ return False
+ xorigin, yorigin = image.getOrigin()
+ xscale, yscale = image.getScale()
+ xaxis = xorigin + xscale * numpy.arange(data.shape[1])
+ yaxis = yorigin + yscale * numpy.arange(data.shape[0])
+ xlabel, ylabel = self._getAxesLabels(image)
+ interpretation = "image" if len(data.shape) == 2 else "rgba-image"
+
+ return save_NXdata(filename,
+ nxentry_name=entryPath,
+ signal=data,
+ axes=[yaxis, xaxis],
+ signal_name="image",
+ axes_names=["y", "x"],
+ axes_long_names=[ylabel, xlabel],
+ title=plot.getGraphTitle(),
+ interpretation=interpretation)
+
+ elif nameFilter in (self.IMAGE_FILTER_ASCII,
+ self.IMAGE_FILTER_CSV_COMMA,
+ self.IMAGE_FILTER_CSV_SEMICOLON,
+ self.IMAGE_FILTER_CSV_TAB):
+ csvdelim, filetype = {
+ self.IMAGE_FILTER_ASCII: (' ', 'txt'),
+ self.IMAGE_FILTER_CSV_COMMA: (',', 'csv'),
+ self.IMAGE_FILTER_CSV_SEMICOLON: (';', 'csv'),
+ self.IMAGE_FILTER_CSV_TAB: ('\t', 'csv'),
+ }[nameFilter]
+
+ height, width = data.shape
+ rows, cols = numpy.mgrid[0:height, 0:width]
+ try:
+ save1D(filename, rows.ravel(), (cols.ravel(), data.ravel()),
+ filetype=filetype,
+ xlabel='row',
+ ylabels=['column', 'value'],
+ csvdelim=csvdelim,
+ autoheader=True)
+
+ except IOError:
+ self._errorMessage('Save failed\n', parent=self.plot)
+ return False
+ return True
+
+ elif nameFilter == self.IMAGE_FILTER_RGB_PNG:
+ # Get displayed image
+ rgbaImage = image.getRgbaImageData(copy=False)
+ # Convert RGB QImage
+ qimage = convertArrayToQImage(rgbaImage[:, :, :3])
+
+ if qimage.save(filename, 'PNG'):
+ return True
+ else:
+ _logger.error('Failed to save image as %s', filename)
+ qt.QMessageBox.critical(
+ self.parent(),
+ 'Save image as',
+ 'Failed to save image')
+
+ return False
+
+ def _saveScatter(self, plot, filename, nameFilter):
+ """Save an image from the plot.
+
+ :param str filename: The name of the file to write
+ :param str nameFilter: The selected name filter
+ :return: False if format is not supported or save failed,
+ True otherwise.
+ """
+ if nameFilter not in self.DEFAULT_SCATTER_FILTERS:
+ return False
+
+ if nameFilter == self.SCATTER_FILTER_NXDATA:
+ entryPath = self._selectWriteableOutputGroup(filename, parent=self.plot)
+ if entryPath is None:
+ return False
+ scatter = plot.getScatter()
+
+ x = scatter.getXData(copy=False)
+ y = scatter.getYData(copy=False)
+ z = scatter.getValueData(copy=False)
+
+ xerror = scatter.getXErrorData(copy=False)
+ if isinstance(xerror, float):
+ xerror = xerror * numpy.ones(x.shape, dtype=numpy.float32)
+
+ yerror = scatter.getYErrorData(copy=False)
+ if isinstance(yerror, float):
+ yerror = yerror * numpy.ones(x.shape, dtype=numpy.float32)
+
+ xlabel = plot.getGraphXLabel()
+ ylabel = plot.getGraphYLabel()
+
+ return save_NXdata(
+ filename,
+ nxentry_name=entryPath,
+ signal=z,
+ axes=[x, y],
+ signal_name="values",
+ axes_names=["x", "y"],
+ axes_long_names=[xlabel, ylabel],
+ axes_errors=[xerror, yerror],
+ title=plot.getGraphTitle())
+
+ def setFileFilter(self, dataKind, nameFilter, func, index=None, appendToFile=False):
+ """Set a name filter to add/replace a file format support
+
+ :param str dataKind:
+ The kind of data for which the provided filter is valid.
+ One of: 'all', 'curve', 'curves', 'image', 'scatter'
+ :param str nameFilter: The name filter in the QFileDialog.
+ See :meth:`QFileDialog.setNameFilters`.
+ :param callable func: The function to call to perform saving.
+ Expected signature is:
+ bool func(PlotWidget plot, str filename, str nameFilter)
+ :param bool appendToFile: True to append the data into the selected
+ file.
+ :param integer index: Index of the filter in the final list (or None)
+ """
+ assert dataKind in ('all', 'curve', 'curves', 'image', 'scatter')
+
+ if appendToFile:
+ self._appendFilters.append(nameFilter)
+
+ # first append or replace the new filter to prevent colissions
+ self._filters[dataKind][nameFilter] = func
+ if index is None:
+ # we are already done
+ return
+
+ # get the current ordered list of keys
+ keyList = list(self._filters[dataKind].keys())
+
+ # deal with negative indices
+ if index < 0:
+ index = len(keyList) + index
+ if index < 0:
+ index = 0
+
+ if index >= len(keyList):
+ # nothing to be done, already at the end
+ txt = 'Requested index %d impossible, already at the end' % index
+ _logger.info(txt)
+ return
+
+ # get the new ordered list
+ oldIndex = keyList.index(nameFilter)
+ del keyList[oldIndex]
+ keyList.insert(index, nameFilter)
+
+ # build the new filters
+ newFilters = OrderedDict()
+ for key in keyList:
+ newFilters[key] = self._filters[dataKind][key]
+
+ # and update the filters
+ self._filters[dataKind] = newFilters
+ return
+
+ def getFileFilters(self, dataKind):
+ """Returns the nameFilter and associated function for a kind of data.
+
+ :param str dataKind:
+ The kind of data for which the provided filter is valid.
+ On of: 'all', 'curve', 'curves', 'image', 'scatter'
+ :return: {nameFilter: function} associations.
+ :rtype: collections.OrderedDict
+ """
+ assert dataKind in ('all', 'curve', 'curves', 'image', 'scatter')
+
+ return self._filters[dataKind].copy()
+
+ def _actionTriggered(self, checked=False):
+ """Handle save action."""
+ # Set-up filters
+ filters = OrderedDict()
+
+ # Add image filters if there is an active image
+ if self.plot.getActiveImage() is not None:
+ filters.update(self._filters['image'].items())
+
+ # Add curve filters if there is a curve to save
+ if (self.plot.getActiveCurve() is not None or
+ len(self.plot.getAllCurves()) == 1):
+ filters.update(self._filters['curve'].items())
+ if len(self.plot.getAllCurves()) >= 1:
+ filters.update(self._filters['curves'].items())
+
+ # Add scatter filters if there is a scatter
+ # todo: CSV
+ if self.plot.getScatter() is not None:
+ filters.update(self._filters['scatter'].items())
+
+ filters.update(self._filters['all'].items())
+
+ # Create and run File dialog
+ dialog = qt.QFileDialog(self.plot)
+ dialog.setOption(dialog.DontUseNativeDialog)
+ dialog.setWindowTitle("Output File Selection")
+ dialog.setModal(1)
+ dialog.setNameFilters(list(filters.keys()))
+
+ dialog.setFileMode(dialog.AnyFile)
+ dialog.setAcceptMode(dialog.AcceptSave)
+
+ def onFilterSelection(filt_):
+ # disable overwrite confirmation for NXdata types,
+ # because we append the data to existing files
+ if filt_ in self._appendFilters:
+ dialog.setOption(dialog.DontConfirmOverwrite)
+ else:
+ dialog.setOption(dialog.DontConfirmOverwrite, False)
+
+ dialog.filterSelected.connect(onFilterSelection)
+
+ if not dialog.exec():
+ return False
+
+ nameFilter = dialog.selectedNameFilter()
+ filename = dialog.selectedFiles()[0]
+ dialog.close()
+
+ if '(' in nameFilter and ')' == nameFilter.strip()[-1]:
+ # Check for correct file extension
+ # Extract file extensions as .something
+ extensions = [ext[ext.find('.'):] for ext in
+ nameFilter[nameFilter.find('(') + 1:-1].split()]
+ for ext in extensions:
+ if (len(filename) > len(ext) and
+ filename[-len(ext):].lower() == ext.lower()):
+ break
+ else: # filename has no extension supported in nameFilter, add one
+ if len(extensions) >= 1:
+ filename += extensions[0]
+
+ # Handle save
+ func = filters.get(nameFilter, None)
+ if func is not None:
+ return func(self.plot, filename, nameFilter)
+ else:
+ _logger.error('Unsupported file filter: %s', nameFilter)
+ return False
+
+
+def _plotAsPNG(plot):
+ """Save a :class:`Plot` as PNG and return the payload.
+
+ :param plot: The :class:`Plot` to save
+ """
+ pngFile = BytesIO()
+ plot.saveGraph(pngFile, fileFormat='png')
+ pngFile.flush()
+ pngFile.seek(0)
+ data = pngFile.read()
+ pngFile.close()
+ return data
+
+
+class PrintAction(PlotAction):
+ """QAction for printing the plot.
+
+ It opens a Print dialog.
+
+ Current implementation print a bitmap of the plot area and not vector
+ graphics, so printing quality is not great.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate.
+ :param parent: See :class:`QAction`.
+ """
+
+ def __init__(self, plot, parent=None):
+ super(PrintAction, self).__init__(
+ plot, icon='document-print', text='Print...',
+ tooltip='Open print dialog',
+ triggered=self.printPlot,
+ checkable=False, parent=parent)
+ self.setShortcut(qt.QKeySequence.Print)
+ self.setShortcutContext(qt.Qt.WidgetShortcut)
+
+ def getPrinter(self):
+ """The QPrinter instance used by the PrintAction.
+
+ :rtype: QPrinter
+ """
+ return printer.getDefaultPrinter()
+
+ @property
+ @deprecated(replacement="getPrinter()", since_version="0.8.0")
+ def printer(self):
+ return self.getPrinter()
+
+ def printPlotAsWidget(self):
+ """Open the print dialog and print the plot.
+
+ Use :meth:`QWidget.render` to print the plot
+
+ :return: True if successful
+ """
+ dialog = qt.QPrintDialog(self.getPrinter(), self.plot)
+ dialog.setWindowTitle('Print Plot')
+ if not dialog.exec():
+ return False
+
+ # Print a snapshot of the plot widget at the top of the page
+ widget = self.plot.centralWidget()
+
+ painter = qt.QPainter()
+ if not painter.begin(self.getPrinter()):
+ return False
+
+ pageRect = self.getPrinter().pageRect(qt.QPrinter.DevicePixel)
+ xScale = pageRect.width() / widget.width()
+ yScale = pageRect.height() / widget.height()
+ scale = min(xScale, yScale)
+
+ painter.translate(pageRect.width() / 2., 0.)
+ painter.scale(scale, scale)
+ painter.translate(-widget.width() / 2., 0.)
+ widget.render(painter)
+ painter.end()
+
+ return True
+
+ def printPlot(self):
+ """Open the print dialog and print the plot.
+
+ Use :meth:`Plot.saveGraph` to print the plot.
+
+ :return: True if successful
+ """
+ # Init printer and start printer dialog
+ dialog = qt.QPrintDialog(self.getPrinter(), self.plot)
+ dialog.setWindowTitle('Print Plot')
+ if not dialog.exec():
+ return False
+
+ # Save Plot as PNG and make a pixmap from it with default dpi
+ pngData = _plotAsPNG(self.plot)
+
+ pixmap = qt.QPixmap()
+ pixmap.loadFromData(pngData, 'png')
+
+ pageRect = self.getPrinter().pageRect(qt.QPrinter.DevicePixel)
+ xScale = pageRect.width() / pixmap.width()
+ yScale = pageRect.height() / pixmap.height()
+ scale = min(xScale, yScale)
+
+ # Draw pixmap with painter
+ painter = qt.QPainter()
+ if not painter.begin(self.getPrinter()):
+ return False
+
+ painter.drawPixmap(0, 0,
+ pixmap.width() * scale,
+ pixmap.height() * scale,
+ pixmap)
+ painter.end()
+
+ return True
+
+
+class CopyAction(PlotAction):
+ """QAction to copy :class:`.PlotWidget` content to clipboard.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ super(CopyAction, self).__init__(
+ plot, icon='edit-copy', text='Copy plot',
+ tooltip='Copy a snapshot of the plot into the clipboard',
+ triggered=self.copyPlot,
+ checkable=False, parent=parent)
+ self.setShortcut(qt.QKeySequence.Copy)
+ self.setShortcutContext(qt.Qt.WidgetShortcut)
+
+ def copyPlot(self):
+ """Copy plot content to the clipboard as a bitmap."""
+ # Save Plot as PNG and make a QImage from it with default dpi
+ pngData = _plotAsPNG(self.plot)
+ image = qt.QImage.fromData(pngData, 'png')
+ qt.QApplication.clipboard().setImage(image)
diff --git a/src/silx/gui/plot/actions/medfilt.py b/src/silx/gui/plot/actions/medfilt.py
new file mode 100644
index 0000000..f86a377
--- /dev/null
+++ b/src/silx/gui/plot/actions/medfilt.py
@@ -0,0 +1,147 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+:mod:`silx.gui.plot.actions.medfilt` provides a set of QAction to apply filter
+on data contained in a :class:`.PlotWidget`.
+
+The following QAction are available:
+
+- :class:`MedianFilterAction`
+- :class:`MedianFilter1DAction`
+- :class:`MedianFilter2DAction`
+
+"""
+
+from __future__ import division
+
+__authors__ = ["V.A. Sole", "T. Vincent", "P. Knobel"]
+__license__ = "MIT"
+
+__date__ = "10/10/2018"
+
+from .PlotToolAction import PlotToolAction
+from silx.gui.widgets.MedianFilterDialog import MedianFilterDialog
+from silx.math.medianfilter import medfilt2d
+import logging
+
+_logger = logging.getLogger(__name__)
+
+
+class MedianFilterAction(PlotToolAction):
+ """QAction to plot the pixels intensities diagram
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ PlotToolAction.__init__(self,
+ plot,
+ icon='median-filter',
+ text='median filter',
+ tooltip='Apply a median filter on the image',
+ parent=parent)
+ self._originalImage = None
+ self._legend = None
+ self._filteredImage = None
+
+ def _createToolWindow(self):
+ popup = MedianFilterDialog(parent=self.plot)
+ popup.sigFilterOptChanged.connect(self._updateFilter)
+ return popup
+
+ def _connectPlot(self, window):
+ PlotToolAction._connectPlot(self, window)
+ self.plot.sigActiveImageChanged.connect(self._updateActiveImage)
+ self._updateActiveImage()
+
+ def _disconnectPlot(self, window):
+ PlotToolAction._disconnectPlot(self, window)
+ self.plot.sigActiveImageChanged.disconnect(self._updateActiveImage)
+
+ def _updateActiveImage(self):
+ """Set _activeImageLegend and _originalImage from the active image"""
+ self._activeImageLegend = self.plot.getActiveImage(just_legend=True)
+ if self._activeImageLegend is None:
+ self._originalImage = None
+ self._legend = None
+ else:
+ self._originalImage = self.plot.getImage(self._activeImageLegend).getData(copy=False)
+ self._legend = self.plot.getImage(self._activeImageLegend).getName()
+
+ def _updateFilter(self, kernelWidth, conditional=False):
+ if self._originalImage is None:
+ return
+
+ self.plot.sigActiveImageChanged.disconnect(self._updateActiveImage)
+ filteredImage = self._computeFilteredImage(kernelWidth, conditional)
+ self.plot.addImage(data=filteredImage,
+ legend=self._legend,
+ replace=True)
+ self.plot.sigActiveImageChanged.connect(self._updateActiveImage)
+
+ def _computeFilteredImage(self, kernelWidth, conditional):
+ raise NotImplementedError('MedianFilterAction is a an abstract class')
+
+ def getFilteredImage(self):
+ """
+ :return: the image with the median filter apply on"""
+ return self._filteredImage
+
+
+class MedianFilter1DAction(MedianFilterAction):
+ """Define the MedianFilterAction for 1D
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+ def __init__(self, plot, parent=None):
+ MedianFilterAction.__init__(self,
+ plot,
+ parent=parent)
+
+ def _computeFilteredImage(self, kernelWidth, conditional):
+ assert(self.plot is not None)
+ return medfilt2d(self._originalImage,
+ (kernelWidth, 1),
+ conditional)
+
+
+class MedianFilter2DAction(MedianFilterAction):
+ """Define the MedianFilterAction for 2D
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+ def __init__(self, plot, parent=None):
+ MedianFilterAction.__init__(self,
+ plot,
+ parent=parent)
+
+ def _computeFilteredImage(self, kernelWidth, conditional):
+ assert(self.plot is not None)
+ return medfilt2d(self._originalImage,
+ (kernelWidth, kernelWidth),
+ conditional)
diff --git a/src/silx/gui/plot/actions/mode.py b/src/silx/gui/plot/actions/mode.py
new file mode 100644
index 0000000..ee05256
--- /dev/null
+++ b/src/silx/gui/plot/actions/mode.py
@@ -0,0 +1,104 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+:mod:`silx.gui.plot.actions.mode` provides a set of QAction relative to mouse
+mode of a :class:`.PlotWidget`.
+
+The following QAction are available:
+
+- :class:`ZoomModeAction`
+- :class:`PanModeAction`
+"""
+
+from __future__ import division
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "16/08/2017"
+
+from . import PlotAction
+import logging
+
+_logger = logging.getLogger(__name__)
+
+
+class ZoomModeAction(PlotAction):
+ """QAction controlling the zoom mode of a :class:`.PlotWidget`.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ super(ZoomModeAction, self).__init__(
+ plot, icon='zoom', text='Zoom mode',
+ tooltip='Zoom in or out',
+ triggered=self._actionTriggered,
+ checkable=True, parent=parent)
+ # Listen to mode change
+ self.plot.sigInteractiveModeChanged.connect(self._modeChanged)
+ # Init the state
+ self._modeChanged(None)
+
+ def _modeChanged(self, source):
+ modeDict = self.plot.getInteractiveMode()
+ old = self.blockSignals(True)
+ self.setChecked(modeDict["mode"] == "zoom")
+ self.blockSignals(old)
+
+ def _actionTriggered(self, checked=False):
+ plot = self.plot
+ if plot is not None:
+ plot.setInteractiveMode('zoom', source=self)
+
+
+class PanModeAction(PlotAction):
+ """QAction controlling the pan mode of a :class:`.PlotWidget`.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ super(PanModeAction, self).__init__(
+ plot, icon='pan', text='Pan mode',
+ tooltip='Pan the view',
+ triggered=self._actionTriggered,
+ checkable=True, parent=parent)
+ # Listen to mode change
+ self.plot.sigInteractiveModeChanged.connect(self._modeChanged)
+ # Init the state
+ self._modeChanged(None)
+
+ def _modeChanged(self, source):
+ modeDict = self.plot.getInteractiveMode()
+ old = self.blockSignals(True)
+ self.setChecked(modeDict["mode"] == "pan")
+ self.blockSignals(old)
+
+ def _actionTriggered(self, checked=False):
+ plot = self.plot
+ if plot is not None:
+ plot.setInteractiveMode('pan', source=self)
diff --git a/src/silx/gui/plot/backends/BackendBase.py b/src/silx/gui/plot/backends/BackendBase.py
new file mode 100755
index 0000000..1e86807
--- /dev/null
+++ b/src/silx/gui/plot/backends/BackendBase.py
@@ -0,0 +1,568 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""Base class for Plot backends.
+
+It documents the Plot backend API.
+
+This API is a simplified version of PyMca PlotBackend API.
+"""
+
+__authors__ = ["V.A. Sole", "T. Vincent"]
+__license__ = "MIT"
+__date__ = "21/12/2018"
+
+import weakref
+from ... import qt
+
+
+# Names for setCursor
+CURSOR_DEFAULT = 'default'
+CURSOR_POINTING = 'pointing'
+CURSOR_SIZE_HOR = 'size horizontal'
+CURSOR_SIZE_VER = 'size vertical'
+CURSOR_SIZE_ALL = 'size all'
+
+
+class BackendBase(object):
+ """Class defining the API a backend of the Plot should provide."""
+
+ def __init__(self, plot, parent=None):
+ """Init.
+
+ :param Plot plot: The Plot this backend is attached to
+ :param parent: The parent widget of the plot widget.
+ """
+ self.__xLimits = 1., 100.
+ self.__yLimits = {'left': (1., 100.), 'right': (1., 100.)}
+ self.__yAxisInverted = False
+ self.__keepDataAspectRatio = False
+ self.__xAxisTimeSeries = False
+ self._xAxisTimeZone = None
+ # Store a weakref to get access to the plot state.
+ self._setPlot(plot)
+
+ @property
+ def _plot(self):
+ """The plot this backend is attached to."""
+ if self._plotRef is None:
+ raise RuntimeError('This backend is not attached to a Plot')
+
+ plot = self._plotRef()
+ if plot is None:
+ raise RuntimeError('This backend is no more attached to a Plot')
+ return plot
+
+ def _setPlot(self, plot):
+ """Allow to set plot after init.
+
+ Use with caution, basically **immediately** after init.
+ """
+ self._plotRef = weakref.ref(plot)
+
+ # Add methods
+
+ def addCurve(self, x, y,
+ color, symbol, linewidth, linestyle,
+ yaxis,
+ xerror, yerror,
+ fill, alpha, symbolsize, baseline):
+ """Add a 1D curve given by x an y to the graph.
+
+ :param numpy.ndarray x: The data corresponding to the x axis
+ :param numpy.ndarray y: The data corresponding to the y axis
+ :param color: color(s) to be used
+ :type color: string ("#RRGGBB") or (npoints, 4) unsigned byte array or
+ one of the predefined color names defined in colors.py
+ :param str symbol: Symbol to be drawn at each (x, y) position::
+
+ - ' ' or '' no symbol
+ - 'o' circle
+ - '.' point
+ - ',' pixel
+ - '+' cross
+ - 'x' x-cross
+ - 'd' diamond
+ - 's' square
+
+ :param float linewidth: The width of the curve in pixels
+ :param str linestyle: Type of line::
+
+ - ' ' or '' no line
+ - '-' solid line
+ - '--' dashed line
+ - '-.' dash-dot line
+ - ':' dotted line
+
+ :param str yaxis: The Y axis this curve belongs to in: 'left', 'right'
+ :param xerror: Values with the uncertainties on the x values
+ :type xerror: numpy.ndarray or None
+ :param yerror: Values with the uncertainties on the y values
+ :type yerror: numpy.ndarray or None
+ :param bool fill: True to fill the curve, False otherwise
+ :param float alpha: Curve opacity, as a float in [0., 1.]
+ :param float symbolsize: Size of the symbol (if any) drawn
+ at each (x, y) position.
+ :returns: The handle used by the backend to univocally access the curve
+ """
+ return object()
+
+ def addImage(self, data,
+ origin, scale,
+ colormap, alpha):
+ """Add an image to the plot.
+
+ :param numpy.ndarray data: (nrows, ncolumns) data or
+ (nrows, ncolumns, RGBA) ubyte array
+ :param origin: (origin X, origin Y) of the data.
+ Default: (0., 0.)
+ :type origin: 2-tuple of float
+ :param scale: (scale X, scale Y) of the data.
+ Default: (1., 1.)
+ :type scale: 2-tuple of float
+ :param ~silx.gui.colors.Colormap colormap: Colormap object to use.
+ Ignored if data is RGB(A).
+ :param float alpha: Opacity of the image, as a float in range [0, 1].
+ :returns: The handle used by the backend to univocally access the image
+ """
+ return object()
+
+ def addTriangles(self, x, y, triangles,
+ color, alpha):
+ """Add a set of triangles.
+
+ :param numpy.ndarray x: The data corresponding to the x axis
+ :param numpy.ndarray y: The data corresponding to the y axis
+ :param numpy.ndarray triangles: The indices to make triangles
+ as a (Ntriangle, 3) array
+ :param numpy.ndarray color: color(s) as (npoints, 4) array
+ :param float alpha: Opacity as a float in [0., 1.]
+ :returns: The triangles' unique identifier used by the backend
+ """
+ return object()
+
+ def addShape(self, x, y, shape, color, fill, overlay,
+ linestyle, linewidth, linebgcolor):
+ """Add an item (i.e. a shape) to the plot.
+
+ :param numpy.ndarray x: The X coords of the points of the shape
+ :param numpy.ndarray y: The Y coords of the points of the shape
+ :param str shape: Type of item to be drawn in
+ hline, polygon, rectangle, vline, polylines
+ :param str color: Color of the item
+ :param bool fill: True to fill the shape
+ :param bool overlay: True if item is an overlay, False otherwise
+ :param str linestyle: Style of the line.
+ Only relevant for line markers where X or Y is None.
+ Value in:
+
+ - ' ' no line
+ - '-' solid line
+ - '--' dashed line
+ - '-.' dash-dot line
+ - ':' dotted line
+ :param float linewidth: Width of the line.
+ Only relevant for line markers where X or Y is None.
+ :param str linebgcolor: Background color of the line, e.g., 'blue', 'b',
+ '#FF0000'. It is used to draw dotted line using a second color.
+ :returns: The handle used by the backend to univocally access the item
+ """
+ return object()
+
+ def addMarker(self, x, y, text, color,
+ symbol, linestyle, linewidth, constraint, yaxis):
+ """Add a point, vertical line or horizontal line marker to the plot.
+
+ :param float x: Horizontal position of the marker in graph coordinates.
+ If None, the marker is a horizontal line.
+ :param float y: Vertical position of the marker in graph coordinates.
+ If None, the marker is a vertical line.
+ :param str text: Text associated to the marker (or None for no text)
+ :param str color: Color to be used for instance 'blue', 'b', '#FF0000'
+ :param str symbol: Symbol representing the marker.
+ Only relevant for point markers where X and Y are not None.
+ Value in:
+
+ - 'o' circle
+ - '.' point
+ - ',' pixel
+ - '+' cross
+ - 'x' x-cross
+ - 'd' diamond
+ - 's' square
+ :param str linestyle: Style of the line.
+ Only relevant for line markers where X or Y is None.
+ Value in:
+
+ - ' ' no line
+ - '-' solid line
+ - '--' dashed line
+ - '-.' dash-dot line
+ - ':' dotted line
+ :param float linewidth: Width of the line.
+ Only relevant for line markers where X or Y is None.
+ :param constraint: A function filtering marker displacement by
+ dragging operations or None for no filter.
+ This function is called each time a marker is
+ moved.
+ :type constraint: None or a callable that takes the coordinates of
+ the current cursor position in the plot as input
+ and that returns the filtered coordinates.
+ :param str yaxis: The Y axis this marker belongs to in: 'left', 'right'
+ :return: Handle used by the backend to univocally access the marker
+ """
+ return object()
+
+ # Remove methods
+
+ def remove(self, item):
+ """Remove an existing item from the plot.
+
+ :param item: A backend specific item handle returned by a add* method
+ """
+ pass
+
+ # Interaction methods
+
+ def setGraphCursorShape(self, cursor):
+ """Set the cursor shape.
+
+ To override in interactive backends.
+
+ :param str cursor: Name of the cursor shape or None
+ """
+ pass
+
+ def setGraphCursor(self, flag, color, linewidth, linestyle):
+ """Toggle the display of a crosshair cursor and set its attributes.
+
+ To override in interactive backends.
+
+ :param bool flag: Toggle the display of a crosshair cursor.
+ :param color: The color to use for the crosshair.
+ :type color: A string (either a predefined color name in colors.py
+ or "#RRGGBB")) or a 4 columns unsigned byte array.
+ :param int linewidth: The width of the lines of the crosshair.
+ :param linestyle: Type of line::
+
+ - ' ' no line
+ - '-' solid line
+ - '--' dashed line
+ - '-.' dash-dot line
+ - ':' dotted line
+
+ :type linestyle: None or one of the predefined styles.
+ """
+ pass
+
+ def getItemsFromBackToFront(self, condition=None):
+ """Returns the list of plot items order as rendered by the backend.
+
+ This is the order used for rendering.
+ By default, it takes into account overlays, z value and order of addition of items,
+ but backends can override it.
+
+ :param callable condition:
+ Callable taking an item as input and returning False for items to skip.
+ If None (default), no item is skipped.
+ :rtype: List[~silx.gui.plot.items.Item]
+ """
+ # Sort items: Overlays first, then others
+ # and in each category ordered by z and then by order of addition
+ # as content keeps this order.
+ content = self._plot.getItems()
+ if condition is not None:
+ content = [item for item in content if condition(item)]
+
+ return sorted(
+ content,
+ key=lambda i: ((1 if i.isOverlay() else 0), i.getZValue()))
+
+ def pickItem(self, x, y, item):
+ """Return picked indices if any, or None.
+
+ :param float x: The x pixel coord where to pick.
+ :param float y: The y pixel coord where to pick.
+ :param item: A backend item created with add* methods.
+ :return: None if item was not picked, else returns
+ picked indices information.
+ :rtype: Union[None,List]
+ """
+ return None
+
+ # Update curve
+
+ def setCurveColor(self, curve, color):
+ """Set the color of a curve.
+
+ :param curve: The curve handle
+ :param str color: The color to use.
+ """
+ pass
+
+ # Misc.
+
+ def getWidgetHandle(self):
+ """Return the widget this backend is drawing to."""
+ return None
+
+ def postRedisplay(self):
+ """Trigger backend update and repaint."""
+ self.replot()
+
+ def replot(self):
+ """Redraw the plot."""
+ with self._plot._paintContext():
+ pass
+
+ def saveGraph(self, fileName, fileFormat, dpi):
+ """Save the graph to a file (or a StringIO)
+
+ At least "png", "svg" are supported.
+
+ :param fileName: Destination
+ :type fileName: String or StringIO or BytesIO
+ :param str fileFormat: String specifying the format
+ :param int dpi: The resolution to use or None.
+ """
+ pass
+
+ # Graph labels
+
+ def setGraphTitle(self, title):
+ """Set the main title of the plot.
+
+ :param str title: Title associated to the plot
+ """
+ pass
+
+ def setGraphXLabel(self, label):
+ """Set the X axis label.
+
+ :param str label: label associated to the plot bottom X axis
+ """
+ pass
+
+ def setGraphYLabel(self, label, axis):
+ """Set the left Y axis label.
+
+ :param str label: label associated to the plot left Y axis
+ :param str axis: The axis for which to get the limits: left or right
+ """
+ pass
+
+ # Graph limits
+
+ def setLimits(self, xmin, xmax, ymin, ymax, y2min=None, y2max=None):
+ """Set the limits of the X and Y axes at once.
+
+ :param float xmin: minimum bottom axis value
+ :param float xmax: maximum bottom axis value
+ :param float ymin: minimum left axis value
+ :param float ymax: maximum left axis value
+ :param float y2min: minimum right axis value
+ :param float y2max: maximum right axis value
+ """
+ self.__xLimits = xmin, xmax
+ self.__yLimits['left'] = ymin, ymax
+ if y2min is not None and y2max is not None:
+ self.__yLimits['right'] = y2min, y2max
+
+ def getGraphXLimits(self):
+ """Get the graph X (bottom) limits.
+
+ :return: Minimum and maximum values of the X axis
+ """
+ return self.__xLimits
+
+ def setGraphXLimits(self, xmin, xmax):
+ """Set the limits of X axis.
+
+ :param float xmin: minimum bottom axis value
+ :param float xmax: maximum bottom axis value
+ """
+ self.__xLimits = xmin, xmax
+
+ def getGraphYLimits(self, axis):
+ """Get the graph Y (left) limits.
+
+ :param str axis: The axis for which to get the limits: left or right
+ :return: Minimum and maximum values of the Y axis
+ """
+ return self.__yLimits[axis]
+
+ def setGraphYLimits(self, ymin, ymax, axis):
+ """Set the limits of the Y axis.
+
+ :param float ymin: minimum left axis value
+ :param float ymax: maximum left axis value
+ :param str axis: The axis for which to get the limits: left or right
+ """
+ self.__yLimits[axis] = ymin, ymax
+
+ # Graph axes
+
+
+ def getXAxisTimeZone(self):
+ """Returns tzinfo that is used if the X-Axis plots date-times.
+
+ None means the datetimes are interpreted as local time.
+
+ :rtype: datetime.tzinfo of None.
+ """
+ return self._xAxisTimeZone
+
+ def setXAxisTimeZone(self, tz):
+ """Sets tzinfo that is used if the X-Axis plots date-times.
+
+ Use None to let the datetimes be interpreted as local time.
+
+ :rtype: datetime.tzinfo of None.
+ """
+ self._xAxisTimeZone = tz
+
+ def isXAxisTimeSeries(self):
+ """Return True if the X-axis scale shows datetime objects.
+
+ :rtype: bool
+ """
+ return self.__xAxisTimeSeries
+
+ def setXAxisTimeSeries(self, isTimeSeries):
+ """Set whether the X-axis is a time series
+
+ :param bool flag: True to switch to time series, False for regular axis.
+ """
+ self.__xAxisTimeSeries = bool(isTimeSeries)
+
+ def setXAxisLogarithmic(self, flag):
+ """Set the X axis scale between linear and log.
+
+ :param bool flag: If True, the bottom axis will use a log scale
+ """
+ pass
+
+ def setYAxisLogarithmic(self, flag):
+ """Set the Y axis scale between linear and log.
+
+ :param bool flag: If True, the left axis will use a log scale
+ """
+ pass
+
+ def setYAxisInverted(self, flag):
+ """Invert the Y axis.
+
+ :param bool flag: If True, put the vertical axis origin on the top
+ """
+ self.__yAxisInverted = bool(flag)
+
+ def isYAxisInverted(self):
+ """Return True if left Y axis is inverted, False otherwise."""
+ return self.__yAxisInverted
+
+ def isKeepDataAspectRatio(self):
+ """Returns whether the plot is keeping data aspect ratio or not."""
+ return self.__keepDataAspectRatio
+
+ def setKeepDataAspectRatio(self, flag):
+ """Set whether to keep data aspect ratio or not.
+
+ :param flag: True to respect data aspect ratio
+ :type flag: Boolean, default True
+ """
+ self.__keepDataAspectRatio = bool(flag)
+
+ def setGraphGrid(self, which):
+ """Set grid.
+
+ :param which: None to disable grid, 'major' for major grid,
+ 'both' for major and minor grid
+ """
+ pass
+
+ # Data <-> Pixel coordinates conversion
+
+ def dataToPixel(self, x, y, axis):
+ """Convert a position in data space to a position in pixels
+ in the widget.
+
+ :param float x: The X coordinate in data space.
+ :param float y: The Y coordinate in data space.
+ :param str axis: The Y axis to use for the conversion
+ ('left' or 'right').
+ :returns: The corresponding position in pixels or
+ None if the data position is not in the displayed area.
+ :rtype: A tuple of 2 floats: (xPixel, yPixel) or None.
+ """
+ raise NotImplementedError()
+
+ def pixelToData(self, x, y, axis):
+ """Convert a position in pixels in the widget to a position in
+ the data space.
+
+ :param float x: The X coordinate in pixels.
+ :param float y: The Y coordinate in pixels.
+ :param str axis: The Y axis to use for the conversion
+ ('left' or 'right').
+ :returns: The corresponding position in data space or
+ None if the pixel position is not in the plot area.
+ :rtype: A tuple of 2 floats: (xData, yData) or None.
+ """
+ raise NotImplementedError()
+
+ def getPlotBoundsInPixels(self):
+ """Plot area bounds in widget coordinates in pixels.
+
+ :return: bounds as a 4-tuple of int: (left, top, width, height)
+ """
+ raise NotImplementedError()
+
+ def setAxesMargins(self, left: float, top: float, right: float, bottom: float):
+ """Set the size of plot margins as ratios.
+
+ Values are expected in [0., 1.]
+
+ :param float left:
+ :param float top:
+ :param float right:
+ :param float bottom:
+ """
+ pass
+
+ def setForegroundColors(self, foregroundColor, gridColor):
+ """Set foreground and grid colors used to display this widget.
+
+ :param List[float] foregroundColor: RGBA foreground color of the widget
+ :param List[float] gridColor: RGBA grid color of the data view
+ """
+ pass
+
+ def setBackgroundColors(self, backgroundColor, dataBackgroundColor):
+ """Set background colors used to display this widget.
+
+ :param List[float] backgroundColor: RGBA background color of the widget
+ :param Union[Tuple[float],None] dataBackgroundColor:
+ RGBA background color of the data view
+ """
+ pass
diff --git a/src/silx/gui/plot/backends/BackendMatplotlib.py b/src/silx/gui/plot/backends/BackendMatplotlib.py
new file mode 100755
index 0000000..7fe4ec0
--- /dev/null
+++ b/src/silx/gui/plot/backends/BackendMatplotlib.py
@@ -0,0 +1,1557 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Matplotlib Plot backend."""
+
+from __future__ import division
+
+__authors__ = ["V.A. Sole", "T. Vincent, H. Payno"]
+__license__ = "MIT"
+__date__ = "21/12/2018"
+
+
+import logging
+import datetime as dt
+from typing import Tuple
+import numpy
+
+from pkg_resources import parse_version as _parse_version
+
+
+_logger = logging.getLogger(__name__)
+
+
+from ... import qt
+
+# First of all init matplotlib and set its backend
+from ...utils.matplotlib import FigureCanvasQTAgg
+import matplotlib
+from matplotlib.container import Container
+from matplotlib.figure import Figure
+from matplotlib.patches import Rectangle, Polygon
+from matplotlib.image import AxesImage
+from matplotlib.backend_bases import MouseEvent
+from matplotlib.lines import Line2D
+from matplotlib.text import Text
+from matplotlib.collections import PathCollection, LineCollection
+from matplotlib.ticker import Formatter, ScalarFormatter, Locator
+from matplotlib.tri import Triangulation
+from matplotlib.collections import TriMesh
+from matplotlib import path as mpath
+
+from . import BackendBase
+from .. import items
+from .._utils import FLOAT32_MINPOS
+from .._utils.dtime_ticklayout import calcTicks, bestFormatString, timestamp
+
+_PATCH_LINESTYLE = {
+ "-": 'solid',
+ "--": 'dashed',
+ '-.': 'dashdot',
+ ':': 'dotted',
+ '': "solid",
+ None: "solid",
+}
+"""Patches do not uses the same matplotlib syntax"""
+
+_MARKER_PATHS = {}
+"""Store cached extra marker paths"""
+
+_SPECIAL_MARKERS = {
+ 'tickleft': 0,
+ 'tickright': 1,
+ 'tickup': 2,
+ 'tickdown': 3,
+ 'caretleft': 4,
+ 'caretright': 5,
+ 'caretup': 6,
+ 'caretdown': 7,
+}
+
+
+def normalize_linestyle(linestyle):
+ """Normalize known old-style linestyle, else return the provided value."""
+ return _PATCH_LINESTYLE.get(linestyle, linestyle)
+
+def get_path_from_symbol(symbol):
+ """Get the path representation of a symbol, else None if
+ it is not provided.
+
+ :param str symbol: Symbol description used by silx
+ :rtype: Union[None,matplotlib.path.Path]
+ """
+ if symbol == u'\u2665':
+ path = _MARKER_PATHS.get(symbol, None)
+ if path is not None:
+ return path
+ vertices = numpy.array([
+ [0,-99],
+ [31,-73], [47,-55], [55,-46],
+ [63,-37], [94,-2], [94,33],
+ [94,69], [71,89], [47,89],
+ [24,89], [8,74], [0,58],
+ [-8,74], [-24,89], [-47,89],
+ [-71,89], [-94,69], [-94,33],
+ [-94,-2], [-63,-37], [-55,-46],
+ [-47,-55], [-31,-73], [0,-99],
+ [0,-99]])
+ codes = [mpath.Path.CURVE4] * len(vertices)
+ codes[0] = mpath.Path.MOVETO
+ codes[-1] = mpath.Path.CLOSEPOLY
+ path = mpath.Path(vertices, codes)
+ _MARKER_PATHS[symbol] = path
+ return path
+ return None
+
+class NiceDateLocator(Locator):
+ """
+ Matplotlib Locator that uses Nice Numbers algorithm (adapted to dates)
+ to find the tick locations. This results in the same number behaviour
+ as when using the silx Open GL backend.
+
+ Expects the data to be posix timestampes (i.e. seconds since 1970)
+ """
+ def __init__(self, numTicks=5, tz=None):
+ """
+ :param numTicks: target number of ticks
+ :param datetime.tzinfo tz: optional time zone. None is local time.
+ """
+ super(NiceDateLocator, self).__init__()
+ self.numTicks = numTicks
+
+ self._spacing = None
+ self._unit = None
+ self.tz = tz
+
+ @property
+ def spacing(self):
+ """ The current spacing. Will be updated when new tick value are made"""
+ return self._spacing
+
+ @property
+ def unit(self):
+ """ The current DtUnit. Will be updated when new tick value are made"""
+ return self._unit
+
+ def __call__(self):
+ """Return the locations of the ticks"""
+ vmin, vmax = self.axis.get_view_interval()
+ return self.tick_values(vmin, vmax)
+
+ def tick_values(self, vmin, vmax):
+ """ Calculates tick values
+ """
+ if vmax < vmin:
+ vmin, vmax = vmax, vmin
+
+ # vmin and vmax should be timestamps (i.e. seconds since 1 Jan 1970)
+ dtMin = dt.datetime.fromtimestamp(vmin, tz=self.tz)
+ dtMax = dt.datetime.fromtimestamp(vmax, tz=self.tz)
+ dtTicks, self._spacing, self._unit = \
+ calcTicks(dtMin, dtMax, self.numTicks)
+
+ # Convert datetime back to time stamps.
+ ticks = [timestamp(dtTick) for dtTick in dtTicks]
+ return ticks
+
+
+class NiceAutoDateFormatter(Formatter):
+ """
+ Matplotlib FuncFormatter that is linked to a NiceDateLocator and gives the
+ best possible formats given the locators current spacing an date unit.
+ """
+
+ def __init__(self, locator, tz=None):
+ """
+ :param niceDateLocator: a NiceDateLocator object
+ :param datetime.tzinfo tz: optional time zone. None is local time.
+ """
+ super(NiceAutoDateFormatter, self).__init__()
+ self.locator = locator
+ self.tz = tz
+
+ @property
+ def formatString(self):
+ if self.locator.spacing is None or self.locator.unit is None:
+ # Locator has no spacing or units yet. Return elaborate fmtString
+ return "Y-%m-%d %H:%M:%S"
+ else:
+ return bestFormatString(self.locator.spacing, self.locator.unit)
+
+ def __call__(self, x, pos=None):
+ """Return the format for tick val *x* at position *pos*
+ Expects x to be a POSIX timestamp (seconds since 1 Jan 1970)
+ """
+ dateTime = dt.datetime.fromtimestamp(x, tz=self.tz)
+ tickStr = dateTime.strftime(self.formatString)
+ return tickStr
+
+
+class _PickableContainer(Container):
+ """Artists container with a :meth:`contains` method"""
+
+ def __init__(self, *args, **kwargs):
+ Container.__init__(self, *args, **kwargs)
+ self.__zorder = None
+
+ @property
+ def axes(self):
+ """Mimin Artist.axes"""
+ for child in self.get_children():
+ if hasattr(child, 'axes'):
+ return child.axes
+ return None
+
+ def draw(self, *args, **kwargs):
+ """artist-like draw to broadcast draw to children"""
+ for child in self.get_children():
+ child.draw(*args, **kwargs)
+
+ def get_zorder(self):
+ """Mimic Artist.get_zorder"""
+ return self.__zorder
+
+ def set_zorder(self, z):
+ """Mimic Artist.set_zorder to broadcast to children"""
+ if z != self.__zorder:
+ self.__zorder = z
+ for child in self.get_children():
+ child.set_zorder(z)
+
+ def contains(self, mouseevent):
+ """Mimic Artist.contains, and call it on all children.
+
+ :param mouseevent:
+ :return: Picking status and associated information as a dict
+ :rtype: (bool,dict)
+ """
+ # Goes through children from front to back and return first picked one.
+ for child in reversed(self.get_children()):
+ picked, info = child.contains(mouseevent)
+ if picked:
+ return picked, info
+ return False, {}
+
+
+class _TextWithOffset(Text):
+ """Text object which can be displayed at a specific position
+ of the plot, but with a pixel offset"""
+
+ def __init__(self, *args, **kwargs):
+ Text.__init__(self, *args, **kwargs)
+ self.pixel_offset = (0, 0)
+ self.__cache = None
+
+ def draw(self, renderer):
+ self.__cache = None
+ return Text.draw(self, renderer)
+
+ def __get_xy(self):
+ if self.__cache is not None:
+ return self.__cache
+
+ align = self.get_horizontalalignment()
+ if align == "left":
+ xoffset = self.pixel_offset[0]
+ elif align == "right":
+ xoffset = -self.pixel_offset[0]
+ else:
+ xoffset = 0
+
+ align = self.get_verticalalignment()
+ if align == "top":
+ yoffset = -self.pixel_offset[1]
+ elif align == "bottom":
+ yoffset = self.pixel_offset[1]
+ else:
+ yoffset = 0
+
+ trans = self.get_transform()
+ x = super(_TextWithOffset, self).convert_xunits(self._x)
+ y = super(_TextWithOffset, self).convert_xunits(self._y)
+ pos = x, y
+
+ try:
+ invtrans = trans.inverted()
+ except numpy.linalg.LinAlgError:
+ # Cannot inverse transform, fallback: pos without offset
+ self.__cache = None
+ return pos
+
+ proj = trans.transform_point(pos)
+ proj = proj + numpy.array((xoffset, yoffset))
+ pos = invtrans.transform_point(proj)
+ self.__cache = pos
+ return pos
+
+ def convert_xunits(self, x):
+ """Return the pixel position of the annotated point."""
+ return self.__get_xy()[0]
+
+ def convert_yunits(self, y):
+ """Return the pixel position of the annotated point."""
+ return self.__get_xy()[1]
+
+
+class _MarkerContainer(_PickableContainer):
+ """Marker artists container supporting draw/remove and text position update
+
+ :param artists:
+ Iterable with either one Line2D or a Line2D and a Text.
+ The use of an iterable if enforced by Container being
+ a subclass of tuple that defines a specific __new__.
+ :param x: X coordinate of the marker (None for horizontal lines)
+ :param y: Y coordinate of the marker (None for vertical lines)
+ """
+
+ def __init__(self, artists, symbol, x, y, yAxis):
+ self.line = artists[0]
+ self.text = artists[1] if len(artists) > 1 else None
+ self.symbol = symbol
+ self.x = x
+ self.y = y
+ self.yAxis = yAxis
+
+ _PickableContainer.__init__(self, artists)
+
+ def draw(self, *args, **kwargs):
+ """artist-like draw to broadcast draw to line and text"""
+ self.line.draw(*args, **kwargs)
+ if self.text is not None:
+ self.text.draw(*args, **kwargs)
+
+ def updateMarkerText(self, xmin, xmax, ymin, ymax, yinverted):
+ """Update marker text position and visibility according to plot limits
+
+ :param xmin: X axis lower limit
+ :param xmax: X axis upper limit
+ :param ymin: Y axis lower limit
+ :param ymax: Y axis upper limit
+ :param yinverted: True if the y axis is inverted
+ """
+ if self.text is not None:
+ visible = ((self.x is None or xmin <= self.x <= xmax) and
+ (self.y is None or ymin <= self.y <= ymax))
+ self.text.set_visible(visible)
+
+ if self.x is not None and self.y is not None:
+ if self.symbol is None:
+ valign = 'baseline'
+ else:
+ if yinverted:
+ valign = 'bottom'
+ else:
+ valign = 'top'
+ self.text.set_verticalalignment(valign)
+
+ elif self.y is None: # vertical line
+ # Always display it on top
+ center = (ymax + ymin) * 0.5
+ pos = (ymax - ymin) * 0.5 * 0.99
+ if yinverted:
+ pos = -pos
+ self.text.set_y(center + pos)
+
+ elif self.x is None: # Horizontal line
+ delta = abs(xmax - xmin)
+ if xmin > xmax:
+ xmax = xmin
+ xmax -= 0.005 * delta
+ self.text.set_x(xmax)
+
+ def contains(self, mouseevent):
+ """Mimic Artist.contains, and call it on the line Artist.
+
+ :param mouseevent:
+ :return: Picking status and associated information as a dict
+ :rtype: (bool,dict)
+ """
+ return self.line.contains(mouseevent)
+
+
+class _DoubleColoredLinePatch(matplotlib.patches.Patch):
+ """Matplotlib patch to display any patch using double color."""
+
+ def __init__(self, patch):
+ super(_DoubleColoredLinePatch, self).__init__()
+ self.__patch = patch
+ self.linebgcolor = None
+
+ def __getattr__(self, name):
+ return getattr(self.__patch, name)
+
+ def draw(self, renderer):
+ oldLineStype = self.__patch.get_linestyle()
+ if self.linebgcolor is not None and oldLineStype != "solid":
+ oldLineColor = self.__patch.get_edgecolor()
+ oldHatch = self.__patch.get_hatch()
+ self.__patch.set_linestyle("solid")
+ self.__patch.set_edgecolor(self.linebgcolor)
+ self.__patch.set_hatch(None)
+ self.__patch.draw(renderer)
+ self.__patch.set_linestyle(oldLineStype)
+ self.__patch.set_edgecolor(oldLineColor)
+ self.__patch.set_hatch(oldHatch)
+ self.__patch.draw(renderer)
+
+ def set_transform(self, transform):
+ self.__patch.set_transform(transform)
+
+ def get_path(self):
+ return self.__patch.get_path()
+
+ def contains(self, mouseevent, radius=None):
+ return self.__patch.contains(mouseevent, radius)
+
+ def contains_point(self, point, radius=None):
+ return self.__patch.contains_point(point, radius)
+
+
+class Image(AxesImage):
+ """An AxesImage with a fast path for uint8 RGBA images.
+
+ :param List[float] silx_origin: (ox, oy) Offset of the image.
+ :param List[float] silx_scale: (sx, sy) Scale of the image.
+ """
+
+ def __init__(self, *args,
+ silx_origin=(0., 0.),
+ silx_scale=(1., 1.),
+ **kwargs):
+ super().__init__(*args, **kwargs)
+ self.__silx_origin = silx_origin
+ self.__silx_scale = silx_scale
+
+ def contains(self, mouseevent):
+ """Overridden to fill 'ind' with row and column"""
+ inside, info = super().contains(mouseevent)
+ if inside:
+ x, y = mouseevent.xdata, mouseevent.ydata
+ ox, oy = self.__silx_origin
+ sx, sy = self.__silx_scale
+ height, width = self.get_size()
+ column = numpy.clip(int((x - ox) / sx), 0, width - 1)
+ row = numpy.clip(int((y - oy) / sy), 0, height - 1)
+ info['ind'] = (row,), (column,)
+ return inside, info
+
+ def set_data(self, A):
+ """Overridden to add a fast path for RGBA unit8 images"""
+ A = numpy.array(A, copy=False)
+ if A.ndim != 3 or A.shape[2] != 4 or A.dtype != numpy.uint8:
+ super(Image, self).set_data(A)
+ else:
+ # Call AxesImage.set_data with small data to set attributes
+ super(Image, self).set_data(numpy.zeros((2, 2, 4), dtype=A.dtype))
+ self._A = A # Override stored data
+
+
+class BackendMatplotlib(BackendBase.BackendBase):
+ """Base class for Matplotlib backend without a FigureCanvas.
+
+ For interactive on screen plot, see :class:`BackendMatplotlibQt`.
+
+ See :class:`BackendBase.BackendBase` for public API documentation.
+ """
+
+ def __init__(self, plot, parent=None):
+ super(BackendMatplotlib, self).__init__(plot, parent)
+
+ # matplotlib is handling keep aspect ratio at draw time
+ # When keep aspect ratio is on, and one changes the limits and
+ # ask them *before* next draw has been performed he will get the
+ # limits without applying keep aspect ratio.
+ # This attribute is used to ensure consistent values returned
+ # when getting the limits at the expense of a replot
+ self._dirtyLimits = True
+ self._axesDisplayed = True
+ self._matplotlibVersion = _parse_version(matplotlib.__version__)
+
+ self.fig = Figure()
+ self.fig.set_facecolor("w")
+
+ self.ax = self.fig.add_axes([.15, .15, .75, .75], label="left")
+ self.ax2 = self.ax.twinx()
+ self.ax2.set_label("right")
+ # Make sure background of Axes is displayed
+ self.ax2.patch.set_visible(False)
+ self.ax.patch.set_visible(True)
+
+ # Set axis zorder=0.5 so grid is displayed at 0.5
+ self.ax.set_axisbelow(True)
+
+ # disable the use of offsets
+ try:
+ axes = [
+ self.ax.get_yaxis().get_major_formatter(),
+ self.ax.get_xaxis().get_major_formatter(),
+ self.ax2.get_yaxis().get_major_formatter(),
+ self.ax2.get_xaxis().get_major_formatter(),
+ ]
+ for axis in axes:
+ axis.set_useOffset(False)
+ axis.set_scientific(False)
+ except:
+ _logger.warning('Cannot disabled axes offsets in %s '
+ % matplotlib.__version__)
+
+ self.ax2.set_autoscaley_on(True)
+
+ # this works but the figure color is left
+ if self._matplotlibVersion < _parse_version('2'):
+ self.ax.set_axis_bgcolor('none')
+ else:
+ self.ax.set_facecolor('none')
+ self.fig.sca(self.ax)
+
+ self._background = None
+
+ self._colormaps = {}
+
+ self._graphCursor = tuple()
+
+ self._enableAxis('right', False)
+ self._isXAxisTimeSeries = False
+
+ def getItemsFromBackToFront(self, condition=None):
+ """Order as BackendBase + take into account matplotlib Axes structure"""
+ def axesOrder(item):
+ if item.isOverlay():
+ return 2
+ elif isinstance(item, items.YAxisMixIn) and item.getYAxis() == 'right':
+ return 1
+ else:
+ return 0
+
+ return sorted(
+ BackendBase.BackendBase.getItemsFromBackToFront(
+ self, condition=condition),
+ key=axesOrder)
+
+ def _overlayItems(self):
+ """Generator of backend renderer for overlay items"""
+ for item in self._plot.getItems():
+ if (item.isOverlay() and
+ item.isVisible() and
+ item._backendRenderer is not None):
+ yield item._backendRenderer
+
+ def _hasOverlays(self):
+ """Returns whether there is an overlay layer or not.
+
+ The overlay layers contains overlay items and the crosshair.
+
+ :rtype: bool
+ """
+ if self._graphCursor:
+ return True # There is the crosshair
+
+ for item in self._overlayItems():
+ return True # There is at least one overlay item
+ return False
+
+ # Add methods
+
+ def _getMarkerFromSymbol(self, symbol):
+ """Returns a marker that can be displayed by matplotlib.
+
+ :param str symbol: A symbol description used by silx
+ :rtype: Union[str,int,matplotlib.path.Path]
+ """
+ path = get_path_from_symbol(symbol)
+ if path is not None:
+ return path
+ num = _SPECIAL_MARKERS.get(symbol, None)
+ if num is not None:
+ return num
+ # This symbol must be supported by matplotlib
+ return symbol
+
+ def addCurve(self, x, y,
+ color, symbol, linewidth, linestyle,
+ yaxis,
+ xerror, yerror,
+ fill, alpha, symbolsize, baseline):
+ for parameter in (x, y, color, symbol, linewidth, linestyle,
+ yaxis, fill, alpha, symbolsize):
+ assert parameter is not None
+ assert yaxis in ('left', 'right')
+
+ if (len(color) == 4 and
+ type(color[3]) in [type(1), numpy.uint8, numpy.int8]):
+ color = numpy.array(color, dtype=numpy.float64) / 255.
+
+ if yaxis == "right":
+ axes = self.ax2
+ self._enableAxis("right", True)
+ else:
+ axes = self.ax
+
+ pickradius = 3
+
+ artists = [] # All the artists composing the curve
+
+ # First add errorbars if any so they are behind the curve
+ if xerror is not None or yerror is not None:
+ if hasattr(color, 'dtype') and len(color) == len(x):
+ errorbarColor = 'k'
+ else:
+ errorbarColor = color
+
+ # Nx1 error array deprecated in matplotlib >=3.1 (removed in 3.3)
+ if (isinstance(xerror, numpy.ndarray) and xerror.ndim == 2 and
+ xerror.shape[1] == 1):
+ xerror = numpy.ravel(xerror)
+ if (isinstance(yerror, numpy.ndarray) and yerror.ndim == 2 and
+ yerror.shape[1] == 1):
+ yerror = numpy.ravel(yerror)
+
+ errorbars = axes.errorbar(x, y,
+ xerr=xerror, yerr=yerror,
+ linestyle=' ', color=errorbarColor)
+ artists += list(errorbars.get_children())
+
+ if hasattr(color, 'dtype') and len(color) == len(x):
+ # scatter plot
+ if color.dtype not in [numpy.float32, numpy.float64]:
+ actualColor = color / 255.
+ else:
+ actualColor = color
+
+ if linestyle not in ["", " ", None]:
+ # scatter plot with an actual line ...
+ # we need to assign a color ...
+ curveList = axes.plot(x, y,
+ linestyle=linestyle,
+ color=actualColor[0],
+ linewidth=linewidth,
+ picker=True,
+ pickradius=pickradius,
+ marker=None)
+ artists += list(curveList)
+
+ marker = self._getMarkerFromSymbol(symbol)
+ scatter = axes.scatter(x, y,
+ color=actualColor,
+ marker=marker,
+ picker=True,
+ pickradius=pickradius,
+ s=symbolsize**2)
+ artists.append(scatter)
+
+ if fill:
+ if baseline is None:
+ _baseline = FLOAT32_MINPOS
+ else:
+ _baseline = baseline
+ artists.append(axes.fill_between(
+ x, _baseline, y, facecolor=actualColor[0], linestyle=''))
+
+ else: # Curve
+ curveList = axes.plot(x, y,
+ linestyle=linestyle,
+ color=color,
+ linewidth=linewidth,
+ marker=symbol,
+ picker=True,
+ pickradius=pickradius,
+ markersize=symbolsize)
+ artists += list(curveList)
+
+ if fill:
+ if baseline is None:
+ _baseline = FLOAT32_MINPOS
+ else:
+ _baseline = baseline
+ artists.append(
+ axes.fill_between(x, _baseline, y, facecolor=color))
+
+ for artist in artists:
+ if alpha < 1:
+ artist.set_alpha(alpha)
+
+ return _PickableContainer(artists)
+
+ def addImage(self, data, origin, scale, colormap, alpha):
+ # Non-uniform image
+ # http://wiki.scipy.org/Cookbook/Histograms
+ # Non-linear axes
+ # http://stackoverflow.com/questions/11488800/non-linear-axes-for-imshow-in-matplotlib
+ for parameter in (data, origin, scale):
+ assert parameter is not None
+
+ origin = float(origin[0]), float(origin[1])
+ scale = float(scale[0]), float(scale[1])
+ height, width = data.shape[0:2]
+
+ # All image are shown as RGBA image
+ image = Image(self.ax,
+ interpolation='nearest',
+ picker=True,
+ origin='lower',
+ silx_origin=origin,
+ silx_scale=scale)
+
+ if alpha < 1:
+ image.set_alpha(alpha)
+
+ # Set image extent
+ xmin = origin[0]
+ xmax = xmin + scale[0] * width
+ if scale[0] < 0.:
+ xmin, xmax = xmax, xmin
+
+ ymin = origin[1]
+ ymax = ymin + scale[1] * height
+ if scale[1] < 0.:
+ ymin, ymax = ymax, ymin
+
+ image.set_extent((xmin, xmax, ymin, ymax))
+
+ # Set image data
+ if scale[0] < 0. or scale[1] < 0.:
+ # For negative scale, step by -1
+ xstep = 1 if scale[0] >= 0. else -1
+ ystep = 1 if scale[1] >= 0. else -1
+ data = data[::ystep, ::xstep]
+
+ if data.ndim == 2: # Data image, convert to RGBA image
+ data = colormap.applyToData(data)
+ elif data.dtype == numpy.uint16:
+ # Normalize uint16 data to have a similar behavior as opengl backend
+ data = data.astype(numpy.float32)
+ data /= 65535
+
+ image.set_data(data)
+ self.ax.add_artist(image)
+ return image
+
+ def addTriangles(self, x, y, triangles, color, alpha):
+ for parameter in (x, y, triangles, color, alpha):
+ assert parameter is not None
+
+ color = numpy.array(color, copy=False)
+ assert color.ndim == 2 and len(color) == len(x)
+
+ if color.dtype not in [numpy.float32, numpy.float64]:
+ color = color.astype(numpy.float32) / 255.
+
+ collection = TriMesh(
+ Triangulation(x, y, triangles),
+ alpha=alpha,
+ pickradius=0) # 0 enables picking on filled triangle
+ collection.set_color(color)
+ self.ax.add_collection(collection)
+
+ return collection
+
+ def addShape(self, x, y, shape, color, fill, overlay,
+ linestyle, linewidth, linebgcolor):
+ if (linebgcolor is not None and
+ shape not in ('rectangle', 'polygon', 'polylines')):
+ _logger.warning(
+ 'linebgcolor not implemented for %s with matplotlib backend',
+ shape)
+ xView = numpy.array(x, copy=False)
+ yView = numpy.array(y, copy=False)
+
+ linestyle = normalize_linestyle(linestyle)
+
+ if shape == "line":
+ item = self.ax.plot(x, y, color=color,
+ linestyle=linestyle, linewidth=linewidth,
+ marker=None)[0]
+
+ elif shape == "hline":
+ if hasattr(y, "__len__"):
+ y = y[-1]
+ item = self.ax.axhline(y, color=color,
+ linestyle=linestyle, linewidth=linewidth)
+
+ elif shape == "vline":
+ if hasattr(x, "__len__"):
+ x = x[-1]
+ item = self.ax.axvline(x, color=color,
+ linestyle=linestyle, linewidth=linewidth)
+
+ elif shape == 'rectangle':
+ xMin = numpy.nanmin(xView)
+ xMax = numpy.nanmax(xView)
+ yMin = numpy.nanmin(yView)
+ yMax = numpy.nanmax(yView)
+ w = xMax - xMin
+ h = yMax - yMin
+ item = Rectangle(xy=(xMin, yMin),
+ width=w,
+ height=h,
+ fill=False,
+ color=color,
+ linestyle=linestyle,
+ linewidth=linewidth)
+ if fill:
+ item.set_hatch('.')
+
+ if linestyle != "solid" and linebgcolor is not None:
+ item = _DoubleColoredLinePatch(item)
+ item.linebgcolor = linebgcolor
+
+ self.ax.add_patch(item)
+
+ elif shape in ('polygon', 'polylines'):
+ points = numpy.array((xView, yView)).T
+ if shape == 'polygon':
+ closed = True
+ else: # shape == 'polylines'
+ closed = numpy.all(numpy.equal(points[0], points[-1]))
+ item = Polygon(points,
+ closed=closed,
+ fill=False,
+ color=color,
+ linestyle=linestyle,
+ linewidth=linewidth)
+ if fill and shape == 'polygon':
+ item.set_hatch('/')
+
+ if linestyle != "solid" and linebgcolor is not None:
+ item = _DoubleColoredLinePatch(item)
+ item.linebgcolor = linebgcolor
+
+ self.ax.add_patch(item)
+
+ else:
+ raise NotImplementedError("Unsupported item shape %s" % shape)
+
+ if overlay:
+ item.set_animated(True)
+
+ return item
+
+ def addMarker(self, x, y, text, color,
+ symbol, linestyle, linewidth, constraint, yaxis):
+ textArtist = None
+
+ xmin, xmax = self.getGraphXLimits()
+ ymin, ymax = self.getGraphYLimits(axis=yaxis)
+
+ if yaxis == 'left':
+ ax = self.ax
+ elif yaxis == 'right':
+ ax = self.ax2
+ else:
+ assert(False)
+
+ marker = self._getMarkerFromSymbol(symbol)
+ if x is not None and y is not None:
+ line = ax.plot(x, y,
+ linestyle=" ",
+ color=color,
+ marker=marker,
+ markersize=10.)[-1]
+
+ if text is not None:
+ textArtist = _TextWithOffset(x, y, text,
+ color=color,
+ horizontalalignment='left')
+ if symbol is not None:
+ textArtist.pixel_offset = 10, 3
+ elif x is not None:
+ line = ax.axvline(x,
+ color=color,
+ linewidth=linewidth,
+ linestyle=linestyle)
+ if text is not None:
+ # Y position will be updated in updateMarkerText call
+ textArtist = _TextWithOffset(x, 1., text,
+ color=color,
+ horizontalalignment='left',
+ verticalalignment='top')
+ textArtist.pixel_offset = 5, 3
+ elif y is not None:
+ line = ax.axhline(y,
+ color=color,
+ linewidth=linewidth,
+ linestyle=linestyle)
+
+ if text is not None:
+ # X position will be updated in updateMarkerText call
+ textArtist = _TextWithOffset(1., y, text,
+ color=color,
+ horizontalalignment='right',
+ verticalalignment='top')
+ textArtist.pixel_offset = 5, 3
+ else:
+ raise RuntimeError('A marker must at least have one coordinate')
+
+ line.set_picker(True)
+ line.set_pickradius(5)
+
+ # All markers are overlays
+ line.set_animated(True)
+ if textArtist is not None:
+ ax.add_artist(textArtist)
+ textArtist.set_animated(True)
+
+ artists = [line] if textArtist is None else [line, textArtist]
+ container = _MarkerContainer(artists, symbol, x, y, yaxis)
+ container.updateMarkerText(xmin, xmax, ymin, ymax, self.isYAxisInverted())
+
+ return container
+
+ def _updateMarkers(self):
+ xmin, xmax = self.ax.get_xbound()
+ ymin1, ymax1 = self.ax.get_ybound()
+ ymin2, ymax2 = self.ax2.get_ybound()
+ yinverted = self.isYAxisInverted()
+ for item in self._overlayItems():
+ if isinstance(item, _MarkerContainer):
+ if item.yAxis == 'left':
+ item.updateMarkerText(xmin, xmax, ymin1, ymax1, yinverted)
+ else:
+ item.updateMarkerText(xmin, xmax, ymin2, ymax2, yinverted)
+
+ # Remove methods
+
+ def remove(self, item):
+ try:
+ item.remove()
+ except ValueError:
+ pass # Already removed e.g., in set[X|Y]AxisLogarithmic
+
+ # Interaction methods
+
+ def setGraphCursor(self, flag, color, linewidth, linestyle):
+ if flag:
+ lineh = self.ax.axhline(
+ self.ax.get_ybound()[0], visible=False, color=color,
+ linewidth=linewidth, linestyle=linestyle)
+ lineh.set_animated(True)
+
+ linev = self.ax.axvline(
+ self.ax.get_xbound()[0], visible=False, color=color,
+ linewidth=linewidth, linestyle=linestyle)
+ linev.set_animated(True)
+
+ self._graphCursor = lineh, linev
+ else:
+ if self._graphCursor:
+ lineh, linev = self._graphCursor
+ lineh.remove()
+ linev.remove()
+ self._graphCursor = tuple()
+
+ # Active curve
+
+ def setCurveColor(self, curve, color):
+ # Store Line2D and PathCollection
+ for artist in curve.get_children():
+ if isinstance(artist, (Line2D, LineCollection)):
+ artist.set_color(color)
+ elif isinstance(artist, PathCollection):
+ artist.set_facecolors(color)
+ artist.set_edgecolors(color)
+ else:
+ _logger.warning(
+ 'setActiveCurve ignoring artist %s', str(artist))
+
+ # Misc.
+
+ def getWidgetHandle(self):
+ return self.fig.canvas
+
+ def _enableAxis(self, axis, flag=True):
+ """Show/hide Y axis
+
+ :param str axis: Axis name: 'left' or 'right'
+ :param bool flag: Default, True
+ """
+ assert axis in ('right', 'left')
+ axes = self.ax2 if axis == 'right' else self.ax
+ axes.get_yaxis().set_visible(flag)
+
+ def replot(self):
+ """Do not perform rendering.
+
+ Override in subclass to actually draw something.
+ """
+ with self._plot._paintContext():
+ self._replot()
+
+ def _replot(self):
+ """Call from subclass :meth:`replot` to handle updates"""
+ # TODO images, markers? scatter plot? move in remove?
+ # Right Y axis only support curve for now
+ # Hide right Y axis if no line is present
+ self._dirtyLimits = False
+ if not self.ax2.lines:
+ self._enableAxis('right', False)
+
+ def _drawOverlays(self):
+ """Draw overlays if any."""
+ def condition(item):
+ return (item.isVisible() and
+ item._backendRenderer is not None and
+ item.isOverlay())
+
+ for item in self.getItemsFromBackToFront(condition=condition):
+ if (isinstance(item, items.YAxisMixIn) and
+ item.getYAxis() == 'right'):
+ axes = self.ax2
+ else:
+ axes = self.ax
+ axes.draw_artist(item._backendRenderer)
+
+ for item in self._graphCursor:
+ self.ax.draw_artist(item)
+
+ def updateZOrder(self):
+ """Reorder all items with z order from 0 to 1"""
+ items = self.getItemsFromBackToFront(
+ lambda item: item.isVisible() and item._backendRenderer is not None)
+ count = len(items)
+ for index, item in enumerate(items):
+ if item.getZValue() < 0.5:
+ # Make sure matplotlib z order is below the grid (with z=0.5)
+ zorder = 0.5 * index / count
+ else: # Make sure matplotlib z order is above the grid (> 0.5)
+ zorder = 1. + index / count
+ if zorder != item._backendRenderer.get_zorder():
+ item._backendRenderer.set_zorder(zorder)
+
+ def saveGraph(self, fileName, fileFormat, dpi):
+ self.updateZOrder()
+
+ # fileName can be also a StringIO or file instance
+ if dpi is not None:
+ self.fig.savefig(fileName, format=fileFormat, dpi=dpi)
+ else:
+ self.fig.savefig(fileName, format=fileFormat)
+ self._plot._setDirtyPlot()
+
+ # Graph labels
+
+ def setGraphTitle(self, title):
+ self.ax.set_title(title)
+
+ def setGraphXLabel(self, label):
+ self.ax.set_xlabel(label)
+
+ def setGraphYLabel(self, label, axis):
+ axes = self.ax if axis == 'left' else self.ax2
+ axes.set_ylabel(label)
+
+ # Graph limits
+
+ def setLimits(self, xmin, xmax, ymin, ymax, y2min=None, y2max=None):
+ # Let matplotlib taking care of keep aspect ratio if any
+ self._dirtyLimits = True
+ self.ax.set_xlim(min(xmin, xmax), max(xmin, xmax))
+
+ if y2min is not None and y2max is not None:
+ if not self.isYAxisInverted():
+ self.ax2.set_ylim(min(y2min, y2max), max(y2min, y2max))
+ else:
+ self.ax2.set_ylim(max(y2min, y2max), min(y2min, y2max))
+
+ if not self.isYAxisInverted():
+ self.ax.set_ylim(min(ymin, ymax), max(ymin, ymax))
+ else:
+ self.ax.set_ylim(max(ymin, ymax), min(ymin, ymax))
+
+ self._updateMarkers()
+
+ def getGraphXLimits(self):
+ if self._dirtyLimits and self.isKeepDataAspectRatio():
+ self.ax.apply_aspect()
+ self.ax2.apply_aspect()
+ self._dirtyLimits = False
+ return self.ax.get_xbound()
+
+ def setGraphXLimits(self, xmin, xmax):
+ self._dirtyLimits = True
+ self.ax.set_xlim(min(xmin, xmax), max(xmin, xmax))
+ self._updateMarkers()
+
+ def getGraphYLimits(self, axis):
+ assert axis in ('left', 'right')
+ ax = self.ax2 if axis == 'right' else self.ax
+
+ if not ax.get_visible():
+ return None
+
+ if self._dirtyLimits and self.isKeepDataAspectRatio():
+ self.ax.apply_aspect()
+ self.ax2.apply_aspect()
+ self._dirtyLimits = False
+
+ return ax.get_ybound()
+
+ def setGraphYLimits(self, ymin, ymax, axis):
+ ax = self.ax2 if axis == 'right' else self.ax
+ if ymax < ymin:
+ ymin, ymax = ymax, ymin
+ self._dirtyLimits = True
+
+ if self.isKeepDataAspectRatio():
+ # matplotlib keeps limits of shared axis when keeping aspect ratio
+ # So x limits are kept when changing y limits....
+ # Change x limits first by taking into account aspect ratio
+ # and then change y limits.. so matplotlib does not need
+ # to make change (to y) to keep aspect ratio
+ xmin, xmax = ax.get_xbound()
+ curYMin, curYMax = ax.get_ybound()
+
+ newXRange = (xmax - xmin) * (ymax - ymin) / (curYMax - curYMin)
+ xcenter = 0.5 * (xmin + xmax)
+ ax.set_xlim(xcenter - 0.5 * newXRange, xcenter + 0.5 * newXRange)
+
+ if not self.isYAxisInverted():
+ ax.set_ylim(ymin, ymax)
+ else:
+ ax.set_ylim(ymax, ymin)
+
+ self._updateMarkers()
+
+ # Graph axes
+
+ def setXAxisTimeZone(self, tz):
+ super(BackendMatplotlib, self).setXAxisTimeZone(tz)
+
+ # Make new formatter and locator with the time zone.
+ self.setXAxisTimeSeries(self.isXAxisTimeSeries())
+
+ def isXAxisTimeSeries(self):
+ return self._isXAxisTimeSeries
+
+ def setXAxisTimeSeries(self, isTimeSeries):
+ self._isXAxisTimeSeries = isTimeSeries
+ if self._isXAxisTimeSeries:
+ # We can't use a matplotlib.dates.DateFormatter because it expects
+ # the data to be in datetimes. Silx works internally with
+ # timestamps (floats).
+ locator = NiceDateLocator(tz=self.getXAxisTimeZone())
+ self.ax.xaxis.set_major_locator(locator)
+ self.ax.xaxis.set_major_formatter(
+ NiceAutoDateFormatter(locator, tz=self.getXAxisTimeZone()))
+ else:
+ try:
+ scalarFormatter = ScalarFormatter(useOffset=False)
+ except:
+ _logger.warning('Cannot disabled axes offsets in %s ' %
+ matplotlib.__version__)
+ scalarFormatter = ScalarFormatter()
+ self.ax.xaxis.set_major_formatter(scalarFormatter)
+
+ def setXAxisLogarithmic(self, flag):
+ # Workaround for matplotlib 2.1.0 when one tries to set an axis
+ # to log scale with both limits <= 0
+ # In this case a draw with positive limits is needed first
+ if flag and self._matplotlibVersion >= _parse_version('2.1.0'):
+ xlim = self.ax.get_xlim()
+ if xlim[0] <= 0 and xlim[1] <= 0:
+ self.ax.set_xlim(1, 10)
+ self.draw()
+
+ self.ax2.set_xscale('log' if flag else 'linear')
+ self.ax.set_xscale('log' if flag else 'linear')
+
+ def setYAxisLogarithmic(self, flag):
+ # Workaround for matplotlib 2.0 issue with negative bounds
+ # before switching to log scale
+ if flag and self._matplotlibVersion >= _parse_version('2.0.0'):
+ redraw = False
+ for axis, dataRangeIndex in ((self.ax, 1), (self.ax2, 2)):
+ ylim = axis.get_ylim()
+ if ylim[0] <= 0 or ylim[1] <= 0:
+ dataRange = self._plot.getDataRange()[dataRangeIndex]
+ if dataRange is None:
+ dataRange = 1, 100 # Fallback
+ axis.set_ylim(*dataRange)
+ redraw = True
+ if redraw:
+ self.draw()
+
+ self.ax2.set_yscale('log' if flag else 'linear')
+ self.ax.set_yscale('log' if flag else 'linear')
+
+ def setYAxisInverted(self, flag):
+ if self.ax.yaxis_inverted() != bool(flag):
+ self.ax.invert_yaxis()
+ self._updateMarkers()
+
+ def isYAxisInverted(self):
+ return self.ax.yaxis_inverted()
+
+ def isKeepDataAspectRatio(self):
+ return self.ax.get_aspect() in (1.0, 'equal')
+
+ def setKeepDataAspectRatio(self, flag):
+ self.ax.set_aspect(1.0 if flag else 'auto')
+ self.ax2.set_aspect(1.0 if flag else 'auto')
+
+ def setGraphGrid(self, which):
+ self.ax.grid(False, which='both') # Disable all grid first
+ if which is not None:
+ self.ax.grid(True, which=which)
+
+ # Data <-> Pixel coordinates conversion
+
+ def _getDevicePixelRatio(self) -> float:
+ """Compatibility wrapper for devicePixelRatioF"""
+ return 1.
+
+ def _mplToQtPosition(self, x: float, y: float) -> Tuple[float, float]:
+ """Convert matplotlib "display" space coord to Qt widget logical pixel
+ """
+ ratio = self._getDevicePixelRatio()
+ # Convert from matplotlib origin (bottom) to Qt origin (top)
+ # and apply device pixel ratio
+ return x / ratio, (self.fig.get_window_extent().height - y) / ratio
+
+ def _qtToMplPosition(self, x: float, y: float) -> Tuple[float, float]:
+ """Convert Qt widget logical pixel to matplotlib "display" space coord
+ """
+ ratio = self._getDevicePixelRatio()
+ # Apply device pixel ration and
+ # convert from Qt origin (top) to matplotlib origin (bottom)
+ return x * ratio, self.fig.get_window_extent().height - (y * ratio)
+
+ def dataToPixel(self, x, y, axis):
+ ax = self.ax2 if axis == "right" else self.ax
+ displayPos = ax.transData.transform_point((x, y)).transpose()
+ return self._mplToQtPosition(*displayPos)
+
+ def pixelToData(self, x, y, axis):
+ ax = self.ax2 if axis == "right" else self.ax
+ displayPos = self._qtToMplPosition(x, y)
+ return tuple(ax.transData.inverted().transform_point(displayPos))
+
+ def getPlotBoundsInPixels(self):
+ bbox = self.ax.get_window_extent()
+ # Warning this is not returning int...
+ ratio = self._getDevicePixelRatio()
+ return tuple(int(value / ratio) for value in (
+ bbox.xmin,
+ self.fig.get_window_extent().height - bbox.ymax,
+ bbox.width,
+ bbox.height))
+
+ def setAxesMargins(self, left: float, top: float, right: float, bottom: float):
+ width, height = 1. - left - right, 1. - top - bottom
+ position = left, bottom, width, height
+
+ # Toggle display of axes and viewbox rect
+ isFrameOn = position != (0., 0., 1., 1.)
+ self.ax.set_frame_on(isFrameOn)
+ self.ax2.set_frame_on(isFrameOn)
+
+ self.ax.set_position(position)
+ self.ax2.set_position(position)
+
+ self._synchronizeBackgroundColors()
+ self._synchronizeForegroundColors()
+ self._plot._setDirtyPlot()
+
+ def _synchronizeBackgroundColors(self):
+ backgroundColor = self._plot.getBackgroundColor().getRgbF()
+
+ dataBackgroundColor = self._plot.getDataBackgroundColor()
+ if dataBackgroundColor.isValid():
+ dataBackgroundColor = dataBackgroundColor.getRgbF()
+ else:
+ dataBackgroundColor = backgroundColor
+
+ if self.ax.get_frame_on():
+ self.fig.patch.set_facecolor(backgroundColor)
+ if self._matplotlibVersion < _parse_version('2'):
+ self.ax.set_axis_bgcolor(dataBackgroundColor)
+ else:
+ self.ax.set_facecolor(dataBackgroundColor)
+ else:
+ self.fig.patch.set_facecolor(dataBackgroundColor)
+
+ def _synchronizeForegroundColors(self):
+ foregroundColor = self._plot.getForegroundColor().getRgbF()
+
+ gridColor = self._plot.getGridColor()
+ if gridColor.isValid():
+ gridColor = gridColor.getRgbF()
+ else:
+ gridColor = foregroundColor
+
+ for axes in (self.ax, self.ax2):
+ if axes.get_frame_on():
+ axes.spines['bottom'].set_color(foregroundColor)
+ axes.spines['top'].set_color(foregroundColor)
+ axes.spines['right'].set_color(foregroundColor)
+ axes.spines['left'].set_color(foregroundColor)
+ axes.tick_params(axis='x', colors=foregroundColor)
+ axes.tick_params(axis='y', colors=foregroundColor)
+ axes.yaxis.label.set_color(foregroundColor)
+ axes.xaxis.label.set_color(foregroundColor)
+ axes.title.set_color(foregroundColor)
+
+ for line in axes.get_xgridlines():
+ line.set_color(gridColor)
+
+ for line in axes.get_ygridlines():
+ line.set_color(gridColor)
+ # axes.grid().set_markeredgecolor(gridColor)
+
+ def setBackgroundColors(self, backgroundColor, dataBackgroundColor):
+ self._synchronizeBackgroundColors()
+
+ def setForegroundColors(self, foregroundColor, gridColor):
+ self._synchronizeForegroundColors()
+
+
+class BackendMatplotlibQt(FigureCanvasQTAgg, BackendMatplotlib):
+ """QWidget matplotlib backend using a QtAgg canvas.
+
+ It adds fast overlay drawing and mouse event management.
+ """
+
+ _sigPostRedisplay = qt.Signal()
+ """Signal handling automatic asynchronous replot"""
+
+ def __init__(self, plot, parent=None):
+ BackendMatplotlib.__init__(self, plot, parent)
+ FigureCanvasQTAgg.__init__(self, self.fig)
+ self.setParent(parent)
+
+ self._limitsBeforeResize = None
+
+ FigureCanvasQTAgg.setSizePolicy(
+ self, qt.QSizePolicy.Expanding, qt.QSizePolicy.Expanding)
+ FigureCanvasQTAgg.updateGeometry(self)
+
+ # Make postRedisplay asynchronous using Qt signal
+ self._sigPostRedisplay.connect(
+ self.__deferredReplot, qt.Qt.QueuedConnection)
+
+ self._picked = None
+
+ self.mpl_connect('button_press_event', self._onMousePress)
+ self.mpl_connect('button_release_event', self._onMouseRelease)
+ self.mpl_connect('motion_notify_event', self._onMouseMove)
+ self.mpl_connect('scroll_event', self._onMouseWheel)
+
+ def postRedisplay(self):
+ self._sigPostRedisplay.emit()
+
+ def __deferredReplot(self):
+ # Since this is deferred, makes sure it is still needed
+ plot = self._plotRef()
+ if (plot is not None and
+ plot._getDirtyPlot() and
+ plot.getBackend() is self):
+ self.replot()
+
+ def _getDevicePixelRatio(self) -> float:
+ """Compatibility wrapper for devicePixelRatioF"""
+ if hasattr(self, 'devicePixelRatioF'):
+ ratio = self.devicePixelRatioF()
+ else: # Qt < 5.6 compatibility
+ ratio = float(self.devicePixelRatio())
+ # Safety net: avoid returning 0
+ return ratio if ratio != 0. else 1.
+
+ # Mouse event forwarding
+
+ _MPL_TO_PLOT_BUTTONS = {1: 'left', 2: 'middle', 3: 'right'}
+
+ def _onMousePress(self, event):
+ button = self._MPL_TO_PLOT_BUTTONS.get(event.button, None)
+ if button is not None:
+ x, y = self._mplToQtPosition(event.x, event.y)
+ self._plot.onMousePress(int(x), int(y), button)
+
+ def _onMouseMove(self, event):
+ x, y = self._mplToQtPosition(event.x, event.y)
+ if self._graphCursor:
+ position = self._plot.pixelToData(
+ x, y, axis='left', check=True)
+ lineh, linev = self._graphCursor
+ if position is not None:
+ linev.set_visible(True)
+ linev.set_xdata((position[0], position[0]))
+ lineh.set_visible(True)
+ lineh.set_ydata((position[1], position[1]))
+ self._plot._setDirtyPlot(overlayOnly=True)
+ elif lineh.get_visible():
+ lineh.set_visible(False)
+ linev.set_visible(False)
+ self._plot._setDirtyPlot(overlayOnly=True)
+ # onMouseMove must trigger replot if dirty flag is raised
+
+ self._plot.onMouseMove(int(x), int(y))
+
+ def _onMouseRelease(self, event):
+ button = self._MPL_TO_PLOT_BUTTONS.get(event.button, None)
+ if button is not None:
+ x, y = self._mplToQtPosition(event.x, event.y)
+ self._plot.onMouseRelease(int(x), int(y), button)
+
+ def _onMouseWheel(self, event):
+ x, y = self._mplToQtPosition(event.x, event.y)
+ self._plot.onMouseWheel(int(x), int(y), event.step)
+
+ def leaveEvent(self, event):
+ """QWidget event handler"""
+ try:
+ plot = self._plot
+ except RuntimeError:
+ pass
+ else:
+ plot.onMouseLeaveWidget()
+
+ # picking
+
+ def pickItem(self, x, y, item):
+ xDisplay, yDisplay = self._qtToMplPosition(x, y)
+ mouseEvent = MouseEvent(
+ 'button_press_event', self, int(xDisplay), int(yDisplay))
+ # Override axes and data position with the axes
+ mouseEvent.inaxes = item.axes
+ mouseEvent.xdata, mouseEvent.ydata = self.pixelToData(
+ x, y, axis='left' if item.axes is self.ax else 'right')
+ picked, info = item.contains(mouseEvent)
+
+ if not picked:
+ return None
+
+ elif isinstance(item, TriMesh):
+ # Convert selected triangle to data point indices
+ triangulation = item._triangulation
+ indices = triangulation.get_masked_triangles()[info['ind'][0]]
+
+ # Sort picked triangle points by distance to mouse
+ # from furthest to closest to put closest point last
+ # This is to be somewhat consistent with last scatter point
+ # being the top one.
+ xdata, ydata = self.pixelToData(x, y, axis='left')
+ dists = ((triangulation.x[indices] - xdata) ** 2 +
+ (triangulation.y[indices] - ydata) ** 2)
+ return indices[numpy.flip(numpy.argsort(dists), axis=0)]
+
+ else: # Returns indices if any
+ return info.get('ind', ())
+
+ # replot control
+
+ def resizeEvent(self, event):
+ # Store current limits
+ self._limitsBeforeResize = (
+ self.ax.get_xbound(), self.ax.get_ybound(), self.ax2.get_ybound())
+
+ FigureCanvasQTAgg.resizeEvent(self, event)
+ if self.isKeepDataAspectRatio() or self._hasOverlays():
+ # This is needed with matplotlib 1.5.x and 2.0.x
+ self._plot._setDirtyPlot()
+
+ def draw(self):
+ """Overload draw
+
+ It performs a full redraw (including overlays) of the plot.
+ It also resets background and emit limits changed signal.
+
+ This is directly called by matplotlib for widget resize.
+ """
+ self.updateZOrder()
+
+ # Starting with mpl 2.1.0, toggling autoscale raises a ValueError
+ # in some situations. See #1081, #1136, #1163,
+ if self._matplotlibVersion >= _parse_version("2.0.0"):
+ try:
+ FigureCanvasQTAgg.draw(self)
+ except ValueError as err:
+ _logger.debug(
+ "ValueError caught while calling FigureCanvasQTAgg.draw: "
+ "'%s'", err)
+ else:
+ FigureCanvasQTAgg.draw(self)
+
+ if self._hasOverlays():
+ # Save background
+ self._background = self.copy_from_bbox(self.fig.bbox)
+ else:
+ self._background = None # Reset background
+
+ # Check if limits changed due to a resize of the widget
+ if self._limitsBeforeResize is not None:
+ xLimits, yLimits, yRightLimits = self._limitsBeforeResize
+ self._limitsBeforeResize = None
+
+ if (xLimits != self.ax.get_xbound() or
+ yLimits != self.ax.get_ybound()):
+ self._updateMarkers()
+
+ if xLimits != self.ax.get_xbound():
+ self._plot.getXAxis()._emitLimitsChanged()
+ if yLimits != self.ax.get_ybound():
+ self._plot.getYAxis(axis='left')._emitLimitsChanged()
+ if yRightLimits != self.ax2.get_ybound():
+ self._plot.getYAxis(axis='right')._emitLimitsChanged()
+
+ self._drawOverlays()
+
+ def replot(self):
+ with self._plot._paintContext():
+ BackendMatplotlib._replot(self)
+
+ dirtyFlag = self._plot._getDirtyPlot()
+
+ if dirtyFlag == 'overlay':
+ # Only redraw overlays using fast rendering path
+ if self._background is None:
+ self._background = self.copy_from_bbox(self.fig.bbox)
+ self.restore_region(self._background)
+ self._drawOverlays()
+ self.blit(self.fig.bbox)
+
+ elif dirtyFlag: # Need full redraw
+ self.draw()
+
+ # Workaround issue of rendering overlays with some matplotlib versions
+ if (_parse_version('1.5') <= self._matplotlibVersion < _parse_version('2.1') and
+ not hasattr(self, '_firstReplot')):
+ self._firstReplot = False
+ if self._hasOverlays():
+ qt.QTimer.singleShot(0, self.draw) # Request async draw
+
+ # cursor
+
+ _QT_CURSORS = {
+ BackendBase.CURSOR_DEFAULT: qt.Qt.ArrowCursor,
+ BackendBase.CURSOR_POINTING: qt.Qt.PointingHandCursor,
+ BackendBase.CURSOR_SIZE_HOR: qt.Qt.SizeHorCursor,
+ BackendBase.CURSOR_SIZE_VER: qt.Qt.SizeVerCursor,
+ BackendBase.CURSOR_SIZE_ALL: qt.Qt.SizeAllCursor,
+ }
+
+ def setGraphCursorShape(self, cursor):
+ if cursor is None:
+ FigureCanvasQTAgg.unsetCursor(self)
+ else:
+ cursor = self._QT_CURSORS[cursor]
+ FigureCanvasQTAgg.setCursor(self, qt.QCursor(cursor))
diff --git a/src/silx/gui/plot/backends/BackendOpenGL.py b/src/silx/gui/plot/backends/BackendOpenGL.py
new file mode 100755
index 0000000..f1a12af
--- /dev/null
+++ b/src/silx/gui/plot/backends/BackendOpenGL.py
@@ -0,0 +1,1420 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2014-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""OpenGL Plot backend."""
+
+from __future__ import division
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "21/12/2018"
+
+import logging
+import weakref
+
+import numpy
+
+from .. import items
+from .._utils import FLOAT32_MINPOS
+from . import BackendBase
+from ... import colors
+from ... import qt
+
+from ..._glutils import gl
+from ... import _glutils as glu
+from . import glutils
+from .glutils.PlotImageFile import saveImageToFile
+
+_logger = logging.getLogger(__name__)
+
+
+# TODO idea: BackendQtMixIn class to share code between mpl and gl
+# TODO check if OpenGL is available
+# TODO make an off-screen mesa backend
+
+# Content #####################################################################
+
+class _ShapeItem(dict):
+ def __init__(self, x, y, shape, color, fill, overlay,
+ linestyle, linewidth, linebgcolor):
+ super(_ShapeItem, self).__init__()
+
+ if shape not in ('polygon', 'rectangle', 'line',
+ 'vline', 'hline', 'polylines'):
+ raise NotImplementedError("Unsupported shape {0}".format(shape))
+
+ x = numpy.array(x, copy=False)
+ y = numpy.array(y, copy=False)
+
+ if shape == 'rectangle':
+ xMin, xMax = x
+ x = numpy.array((xMin, xMin, xMax, xMax))
+ yMin, yMax = y
+ y = numpy.array((yMin, yMax, yMax, yMin))
+
+ # Ignore fill for polylines to mimic matplotlib
+ fill = fill if shape != 'polylines' else False
+
+ self.update({
+ 'shape': shape,
+ 'color': colors.rgba(color),
+ 'fill': 'hatch' if fill else None,
+ 'x': x,
+ 'y': y,
+ 'linestyle': linestyle,
+ 'linewidth': linewidth,
+ 'linebgcolor': linebgcolor,
+ })
+
+
+class _MarkerItem(dict):
+ def __init__(self, x, y, text, color,
+ symbol, linestyle, linewidth, constraint, yaxis):
+ super(_MarkerItem, self).__init__()
+
+ if symbol is None:
+ symbol = '+'
+
+ # Apply constraint to provided position
+ isConstraint = (constraint is not None and
+ x is not None and y is not None)
+ if isConstraint:
+ x, y = constraint(x, y)
+
+ self.update({
+ 'x': x,
+ 'y': y,
+ 'text': text,
+ 'color': colors.rgba(color),
+ 'constraint': constraint if isConstraint else None,
+ 'symbol': symbol,
+ 'linestyle': linestyle,
+ 'linewidth': linewidth,
+ 'yaxis': yaxis,
+ })
+
+
+# shaders #####################################################################
+
+_baseVertShd = """
+ attribute vec2 position;
+ uniform mat4 matrix;
+ uniform bvec2 isLog;
+
+ const float oneOverLog10 = 0.43429448190325176;
+
+ void main(void) {
+ vec2 posTransformed = position;
+ if (isLog.x) {
+ posTransformed.x = oneOverLog10 * log(position.x);
+ }
+ if (isLog.y) {
+ posTransformed.y = oneOverLog10 * log(position.y);
+ }
+ gl_Position = matrix * vec4(posTransformed, 0.0, 1.0);
+ }
+ """
+
+_baseFragShd = """
+ uniform vec4 color;
+ uniform int hatchStep;
+ uniform float tickLen;
+
+ void main(void) {
+ if (tickLen != 0.) {
+ if (mod((gl_FragCoord.x + gl_FragCoord.y) / tickLen, 2.) < 1.) {
+ gl_FragColor = color;
+ } else {
+ discard;
+ }
+ } else if (hatchStep == 0 ||
+ mod(gl_FragCoord.x - gl_FragCoord.y, float(hatchStep)) == 0.) {
+ gl_FragColor = color;
+ } else {
+ discard;
+ }
+ }
+ """
+
+_texVertShd = """
+ attribute vec2 position;
+ attribute vec2 texCoords;
+ uniform mat4 matrix;
+
+ varying vec2 coords;
+
+ void main(void) {
+ gl_Position = matrix * vec4(position, 0.0, 1.0);
+ coords = texCoords;
+ }
+ """
+
+_texFragShd = """
+ uniform sampler2D tex;
+
+ varying vec2 coords;
+
+ void main(void) {
+ gl_FragColor = texture2D(tex, coords);
+ gl_FragColor.a = 1.0;
+ }
+ """
+
+# BackendOpenGL ###############################################################
+
+
+class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
+ """OpenGL-based Plot backend.
+
+ WARNINGS:
+ Unless stated otherwise, this API is NOT thread-safe and MUST be
+ called from the main thread.
+ When numpy arrays are passed as arguments to the API (through
+ :func:`addCurve` and :func:`addImage`), they are copied only if
+ required.
+ So, the caller should not modify these arrays afterwards.
+ """
+
+ def __init__(self, plot, parent=None, f=qt.Qt.WindowFlags()):
+ glu.OpenGLWidget.__init__(self, parent,
+ alphaBufferSize=8,
+ depthBufferSize=0,
+ stencilBufferSize=0,
+ version=(2, 1),
+ f=f)
+ BackendBase.BackendBase.__init__(self, plot, parent)
+
+ self._backgroundColor = 1., 1., 1., 1.
+ self._dataBackgroundColor = 1., 1., 1., 1.
+
+ self.matScreenProj = glutils.mat4Identity()
+
+ self._progBase = glu.Program(
+ _baseVertShd, _baseFragShd, attrib0='position')
+ self._progTex = glu.Program(
+ _texVertShd, _texFragShd, attrib0='position')
+ self._plotFBOs = weakref.WeakKeyDictionary()
+
+ self._keepDataAspectRatio = False
+
+ self._crosshairCursor = None
+ self._mousePosInPixels = None
+
+ self._glGarbageCollector = []
+
+ self._plotFrame = glutils.GLPlotFrame2D(
+ foregroundColor=(0., 0., 0., 1.),
+ gridColor=(.7, .7, .7, 1.),
+ marginRatios=(.15, .1, .1, .15))
+ self._plotFrame.size = ( # Init size with size int
+ int(self.getDevicePixelRatio() * 640),
+ int(self.getDevicePixelRatio() * 480))
+
+ self.setAutoFillBackground(False)
+ self.setMouseTracking(True)
+
+ # QWidget
+
+ _MOUSE_BTNS = {1: 'left', 2: 'right', 4: 'middle'}
+
+ def sizeHint(self):
+ return qt.QSize(8 * 80, 6 * 80) # Mimic MatplotlibBackend
+
+ def mousePressEvent(self, event):
+ if event.button() not in self._MOUSE_BTNS:
+ return super(BackendOpenGL, self).mousePressEvent(event)
+ self._plot.onMousePress(
+ event.x(), event.y(), self._MOUSE_BTNS[event.button()])
+ event.accept()
+
+ def mouseMoveEvent(self, event):
+ qtPos = event.x(), event.y()
+
+ previousMousePosInPixels = self._mousePosInPixels
+ if qtPos == self._mouseInPlotArea(*qtPos):
+ devicePixelRatio = self.getDevicePixelRatio()
+ devicePos = qtPos[0] * devicePixelRatio, qtPos[1] * devicePixelRatio
+ self._mousePosInPixels = devicePos # Mouse in plot area
+ else:
+ self._mousePosInPixels = None # Mouse outside plot area
+
+ if (self._crosshairCursor is not None and
+ previousMousePosInPixels != self._mousePosInPixels):
+ # Avoid replot when cursor remains outside plot area
+ self._plot._setDirtyPlot(overlayOnly=True)
+
+ self._plot.onMouseMove(*qtPos)
+ event.accept()
+
+ def mouseReleaseEvent(self, event):
+ if event.button() not in self._MOUSE_BTNS:
+ return super(BackendOpenGL, self).mouseReleaseEvent(event)
+ self._plot.onMouseRelease(
+ event.x(), event.y(), self._MOUSE_BTNS[event.button()])
+ event.accept()
+
+ def wheelEvent(self, event):
+ delta = event.angleDelta().y()
+ angleInDegrees = delta / 8.
+ if qt.BINDING == "PySide6":
+ x, y = event.position().x(), event.position().y()
+ else:
+ x, y = event.x(), event.y()
+ self._plot.onMouseWheel(x, y, angleInDegrees)
+ event.accept()
+
+ def leaveEvent(self, _):
+ self._plot.onMouseLeaveWidget()
+
+ # OpenGLWidget API
+
+ def initializeGL(self):
+ gl.testGL()
+
+ gl.glClearStencil(0)
+
+ gl.glEnable(gl.GL_BLEND)
+ # gl.glBlendFunc(gl.GL_SRC_ALPHA, gl.GL_ONE_MINUS_SRC_ALPHA)
+ gl.glBlendFuncSeparate(gl.GL_SRC_ALPHA,
+ gl.GL_ONE_MINUS_SRC_ALPHA,
+ gl.GL_ONE,
+ gl.GL_ONE)
+
+ # For lines
+ gl.glHint(gl.GL_LINE_SMOOTH_HINT, gl.GL_NICEST)
+
+ # For points
+ gl.glEnable(gl.GL_VERTEX_PROGRAM_POINT_SIZE) # OpenGL 2
+ gl.glEnable(gl.GL_POINT_SPRITE) # OpenGL 2
+ # gl.glEnable(gl.GL_PROGRAM_POINT_SIZE)
+
+ def _paintDirectGL(self):
+ self._renderPlotAreaGL()
+ self._plotFrame.render()
+ self._renderOverlayGL()
+
+ def _paintFBOGL(self):
+ context = glu.Context.getCurrent()
+ plotFBOTex = self._plotFBOs.get(context)
+ if (self._plot._getDirtyPlot() or self._plotFrame.isDirty or
+ plotFBOTex is None):
+ self._plotVertices = (
+ # Vertex coordinates
+ numpy.array(((-1., -1.), (1., -1.), (-1., 1.), (1., 1.)),
+ dtype=numpy.float32),
+ # Texture coordinates
+ numpy.array(((0., 0.), (1., 0.), (0., 1.), (1., 1.)),
+ dtype=numpy.float32))
+ if plotFBOTex is None or \
+ plotFBOTex.shape[1] != self._plotFrame.size[0] or \
+ plotFBOTex.shape[0] != self._plotFrame.size[1]:
+ if plotFBOTex is not None:
+ plotFBOTex.discard()
+ plotFBOTex = glu.FramebufferTexture(
+ gl.GL_RGBA,
+ shape=(self._plotFrame.size[1],
+ self._plotFrame.size[0]),
+ minFilter=gl.GL_NEAREST,
+ magFilter=gl.GL_NEAREST,
+ wrap=(gl.GL_CLAMP_TO_EDGE,
+ gl.GL_CLAMP_TO_EDGE))
+ self._plotFBOs[context] = plotFBOTex
+
+ with plotFBOTex:
+ gl.glClearColor(*self._backgroundColor)
+ gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_STENCIL_BUFFER_BIT)
+ self._renderPlotAreaGL()
+ self._plotFrame.render()
+
+ # Render plot in screen coords
+ gl.glViewport(0, 0, self._plotFrame.size[0], self._plotFrame.size[1])
+
+ self._progTex.use()
+ texUnit = 0
+
+ gl.glUniform1i(self._progTex.uniforms['tex'], texUnit)
+ gl.glUniformMatrix4fv(self._progTex.uniforms['matrix'], 1, gl.GL_TRUE,
+ glutils.mat4Identity().astype(numpy.float32))
+
+ gl.glEnableVertexAttribArray(self._progTex.attributes['position'])
+ gl.glVertexAttribPointer(self._progTex.attributes['position'],
+ 2,
+ gl.GL_FLOAT,
+ gl.GL_FALSE,
+ 0,
+ self._plotVertices[0])
+
+ gl.glEnableVertexAttribArray(self._progTex.attributes['texCoords'])
+ gl.glVertexAttribPointer(self._progTex.attributes['texCoords'],
+ 2,
+ gl.GL_FLOAT,
+ gl.GL_FALSE,
+ 0,
+ self._plotVertices[1])
+
+ with plotFBOTex.texture:
+ gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, len(self._plotVertices[0]))
+
+ self._renderOverlayGL()
+
+ def paintGL(self):
+ plot = self._plotRef()
+ if plot is None:
+ return
+
+ with plot._paintContext():
+ with glu.Context.current(self.context()):
+ # Release OpenGL resources
+ for item in self._glGarbageCollector:
+ item.discard()
+ self._glGarbageCollector = []
+
+ gl.glClearColor(*self._backgroundColor)
+ gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_STENCIL_BUFFER_BIT)
+
+ # Check if window is large enough
+ if self._plotFrame.plotSize <= (2, 2):
+ return
+
+ # Sync plot frame with window
+ self._plotFrame.devicePixelRatio = self.getDevicePixelRatio()
+ # self._paintDirectGL()
+ self._paintFBOGL()
+
+ def _renderItems(self, overlay=False):
+ """Render items according to :class:`PlotWidget` order
+
+ Note: Scissor test should already be set.
+
+ :param bool overlay:
+ False (the default) to render item that are not overlays.
+ True to render items that are overlays.
+ """
+ # Values that are often used
+ plotWidth, plotHeight = self._plotFrame.plotSize
+ isXLog = self._plotFrame.xAxis.isLog
+ isYLog = self._plotFrame.yAxis.isLog
+ isYInverted = self._plotFrame.isYAxisInverted
+
+ # Used by marker rendering
+ labels = []
+ pixelOffset = 3
+
+ context = glutils.RenderContext(
+ isXLog=isXLog, isYLog=isYLog, dpi=self.getDotsPerInch())
+
+ for plotItem in self.getItemsFromBackToFront(
+ condition=lambda i: i.isVisible() and i.isOverlay() == overlay):
+ if plotItem._backendRenderer is None:
+ continue
+
+ item = plotItem._backendRenderer
+
+ if isinstance(item, glutils.GLPlotItem): # Render data items
+ gl.glViewport(self._plotFrame.margins.left,
+ self._plotFrame.margins.bottom,
+ plotWidth, plotHeight)
+ # Set matrix
+ if item.yaxis == 'right':
+ context.matrix = self._plotFrame.transformedDataY2ProjMat
+ else:
+ context.matrix = self._plotFrame.transformedDataProjMat
+ item.render(context)
+
+ elif isinstance(item, _ShapeItem): # Render shape items
+ gl.glViewport(0, 0, self._plotFrame.size[0], self._plotFrame.size[1])
+
+ if ((isXLog and numpy.min(item['x']) < FLOAT32_MINPOS) or
+ (isYLog and numpy.min(item['y']) < FLOAT32_MINPOS)):
+ # Ignore items <= 0. on log axes
+ continue
+
+ if item['shape'] == 'hline':
+ width = self._plotFrame.size[0]
+ _, yPixel = self._plotFrame.dataToPixel(
+ 0.5 * sum(self._plotFrame.dataRanges[0]),
+ item['y'],
+ axis='left')
+ subShapes = [numpy.array(((0., yPixel), (width, yPixel)),
+ dtype=numpy.float32)]
+
+ elif item['shape'] == 'vline':
+ xPixel, _ = self._plotFrame.dataToPixel(
+ item['x'],
+ 0.5 * sum(self._plotFrame.dataRanges[1]),
+ axis='left')
+ height = self._plotFrame.size[1]
+ subShapes = [numpy.array(((xPixel, 0), (xPixel, height)),
+ dtype=numpy.float32)]
+
+ else:
+ # Split sub-shapes at not finite values
+ splits = numpy.nonzero(numpy.logical_not(numpy.logical_and(
+ numpy.isfinite(item['x']), numpy.isfinite(item['y']))))[0]
+ splits = numpy.concatenate(([-1], splits, [len(item['x'])]))
+ subShapes = []
+ for begin, end in zip(splits[:-1] + 1, splits[1:]):
+ if end > begin:
+ subShapes.append(numpy.array([
+ self._plotFrame.dataToPixel(x, y, axis='left')
+ for (x, y) in zip(item['x'][begin:end], item['y'][begin:end])]))
+
+ for points in subShapes: # Draw each sub-shape
+ # Draw the fill
+ if (item['fill'] is not None and
+ item['shape'] not in ('hline', 'vline')):
+ self._progBase.use()
+ gl.glUniformMatrix4fv(
+ self._progBase.uniforms['matrix'], 1, gl.GL_TRUE,
+ self.matScreenProj.astype(numpy.float32))
+ gl.glUniform2i(self._progBase.uniforms['isLog'], False, False)
+ gl.glUniform1f(self._progBase.uniforms['tickLen'], 0.)
+
+ shape2D = glutils.FilledShape2D(
+ points, style=item['fill'], color=item['color'])
+ shape2D.render(
+ posAttrib=self._progBase.attributes['position'],
+ colorUnif=self._progBase.uniforms['color'],
+ hatchStepUnif=self._progBase.uniforms['hatchStep'])
+
+ # Draw the stroke
+ if item['linestyle'] not in ('', ' ', None):
+ if item['shape'] != 'polylines':
+ # close the polyline
+ points = numpy.append(points,
+ numpy.atleast_2d(points[0]), axis=0)
+
+ lines = glutils.GLLines2D(
+ points[:, 0], points[:, 1],
+ style=item['linestyle'],
+ color=item['color'],
+ dash2ndColor=item['linebgcolor'],
+ width=item['linewidth'])
+ context.matrix = self.matScreenProj
+ lines.render(context)
+
+ elif isinstance(item, _MarkerItem):
+ gl.glViewport(0, 0, self._plotFrame.size[0], self._plotFrame.size[1])
+
+ xCoord, yCoord, yAxis = item['x'], item['y'], item['yaxis']
+
+ if ((isXLog and xCoord is not None and xCoord <= 0) or
+ (isYLog and yCoord is not None and yCoord <= 0)):
+ # Do not render markers with negative coords on log axis
+ continue
+
+ color = item['color']
+ intensity = color[0] * 0.299 + color[1] * 0.587 + color[2] * 0.114
+ bgColor = (1., 1., 1., 0.5) if intensity <= 0.5 else (0., 0., 0., 0.5)
+ if xCoord is None or yCoord is None:
+ if xCoord is None: # Horizontal line in data space
+ pixelPos = self._plotFrame.dataToPixel(
+ 0.5 * sum(self._plotFrame.dataRanges[0]),
+ yCoord,
+ axis=yAxis)
+
+ if item['text'] is not None:
+ x = self._plotFrame.size[0] - \
+ self._plotFrame.margins.right - pixelOffset
+ y = pixelPos[1] - pixelOffset
+ label = glutils.Text2D(
+ item['text'], x, y,
+ color=item['color'],
+ bgColor=bgColor,
+ align=glutils.RIGHT,
+ valign=glutils.BOTTOM,
+ devicePixelRatio=self.getDevicePixelRatio())
+ labels.append(label)
+
+ width = self._plotFrame.size[0]
+ lines = glutils.GLLines2D(
+ (0, width), (pixelPos[1], pixelPos[1]),
+ style=item['linestyle'],
+ color=item['color'],
+ width=item['linewidth'])
+ context.matrix = self.matScreenProj
+ lines.render(context)
+
+ else: # yCoord is None: vertical line in data space
+ yRange = self._plotFrame.dataRanges[1 if yAxis == 'left' else 2]
+ pixelPos = self._plotFrame.dataToPixel(
+ xCoord, 0.5 * sum(yRange), axis=yAxis)
+
+ if item['text'] is not None:
+ x = pixelPos[0] + pixelOffset
+ y = self._plotFrame.margins.top + pixelOffset
+ label = glutils.Text2D(
+ item['text'], x, y,
+ color=item['color'],
+ bgColor=bgColor,
+ align=glutils.LEFT,
+ valign=glutils.TOP,
+ devicePixelRatio=self.getDevicePixelRatio())
+ labels.append(label)
+
+ height = self._plotFrame.size[1]
+ lines = glutils.GLLines2D(
+ (pixelPos[0], pixelPos[0]), (0, height),
+ style=item['linestyle'],
+ color=item['color'],
+ width=item['linewidth'])
+ context.matrix = self.matScreenProj
+ lines.render(context)
+
+ else:
+ xmin, xmax = self._plot.getXAxis().getLimits()
+ ymin, ymax = self._plot.getYAxis(axis=yAxis).getLimits()
+ if not xmin < xCoord < xmax or not ymin < yCoord < ymax:
+ # Do not render markers outside visible plot area
+ continue
+ pixelPos = self._plotFrame.dataToPixel(
+ xCoord, yCoord, axis=yAxis)
+
+ if isYInverted:
+ valign = glutils.BOTTOM
+ vPixelOffset = -pixelOffset
+ else:
+ valign = glutils.TOP
+ vPixelOffset = pixelOffset
+
+ if item['text'] is not None:
+ x = pixelPos[0] + pixelOffset
+ y = pixelPos[1] + vPixelOffset
+ label = glutils.Text2D(
+ item['text'], x, y,
+ color=item['color'],
+ bgColor=bgColor,
+ align=glutils.LEFT,
+ valign=valign,
+ devicePixelRatio=self.getDevicePixelRatio())
+ labels.append(label)
+
+ # For now simple implementation: using a curve for each marker
+ # Should pack all markers to a single set of points
+ markerCurve = glutils.GLPlotCurve2D(
+ numpy.array((pixelPos[0],), dtype=numpy.float64),
+ numpy.array((pixelPos[1],), dtype=numpy.float64),
+ marker=item['symbol'],
+ markerColor=item['color'],
+ markerSize=11)
+
+ context = glutils.RenderContext(
+ matrix=self.matScreenProj,
+ isXLog=False,
+ isYLog=False,
+ dpi=self.getDotsPerInch())
+ markerCurve.render(context)
+ markerCurve.discard()
+
+ else:
+ _logger.error('Unsupported item: %s', str(item))
+ continue
+
+ # Render marker labels
+ gl.glViewport(0, 0, self._plotFrame.size[0], self._plotFrame.size[1])
+ for label in labels:
+ label.render(self.matScreenProj)
+
+ def _renderOverlayGL(self):
+ """Render overlay layer: overlay items and crosshair."""
+ plotWidth, plotHeight = self._plotFrame.plotSize
+
+ # Scissor to plot area
+ gl.glScissor(self._plotFrame.margins.left,
+ self._plotFrame.margins.bottom,
+ plotWidth, plotHeight)
+ gl.glEnable(gl.GL_SCISSOR_TEST)
+
+ self._renderItems(overlay=True)
+
+ # Render crosshair cursor
+ if self._crosshairCursor is not None and self._mousePosInPixels is not None:
+ self._progBase.use()
+ gl.glUniform2i(self._progBase.uniforms['isLog'], False, False)
+ gl.glUniform1f(self._progBase.uniforms['tickLen'], 0.)
+ posAttrib = self._progBase.attributes['position']
+ matrixUnif = self._progBase.uniforms['matrix']
+ colorUnif = self._progBase.uniforms['color']
+ hatchStepUnif = self._progBase.uniforms['hatchStep']
+
+ gl.glViewport(0, 0, self._plotFrame.size[0], self._plotFrame.size[1])
+
+ gl.glUniformMatrix4fv(matrixUnif, 1, gl.GL_TRUE,
+ self.matScreenProj.astype(numpy.float32))
+
+ color, lineWidth = self._crosshairCursor
+ gl.glUniform4f(colorUnif, *color)
+ gl.glUniform1i(hatchStepUnif, 0)
+
+ xPixel, yPixel = self._mousePosInPixels
+ xPixel, yPixel = xPixel + 0.5, yPixel + 0.5
+ vertices = numpy.array(((0., yPixel),
+ (self._plotFrame.size[0], yPixel),
+ (xPixel, 0.),
+ (xPixel, self._plotFrame.size[1])),
+ dtype=numpy.float32)
+
+ gl.glEnableVertexAttribArray(posAttrib)
+ gl.glVertexAttribPointer(posAttrib,
+ 2,
+ gl.GL_FLOAT,
+ gl.GL_FALSE,
+ 0, vertices)
+ gl.glLineWidth(lineWidth)
+ gl.glDrawArrays(gl.GL_LINES, 0, len(vertices))
+
+ gl.glDisable(gl.GL_SCISSOR_TEST)
+
+ def _renderPlotAreaGL(self):
+ """Render base layer of plot area.
+
+ It renders the background, grid and items except overlays
+ """
+ plotWidth, plotHeight = self._plotFrame.plotSize
+
+ gl.glScissor(self._plotFrame.margins.left,
+ self._plotFrame.margins.bottom,
+ plotWidth, plotHeight)
+ gl.glEnable(gl.GL_SCISSOR_TEST)
+
+ if self._dataBackgroundColor != self._backgroundColor:
+ gl.glClearColor(*self._dataBackgroundColor)
+ gl.glClear(gl.GL_COLOR_BUFFER_BIT)
+
+ self._plotFrame.renderGrid()
+
+ # Matrix
+ trBounds = self._plotFrame.transformedDataRanges
+ if trBounds.x[0] != trBounds.x[1] and trBounds.y[0] != trBounds.y[1]:
+ # Do rendering of items
+ self._renderItems(overlay=False)
+
+ gl.glDisable(gl.GL_SCISSOR_TEST)
+
+ def resizeGL(self, width, height):
+ if width == 0 or height == 0: # Do not resize
+ return
+
+ self._plotFrame.size = (
+ int(self.getDevicePixelRatio() * width),
+ int(self.getDevicePixelRatio() * height))
+
+ self.matScreenProj = glutils.mat4Ortho(
+ 0, self._plotFrame.size[0],
+ self._plotFrame.size[1], 0,
+ 1, -1)
+
+ # Store current ranges
+ previousXRange = self.getGraphXLimits()
+ previousYRange = self.getGraphYLimits(axis='left')
+ previousYRightRange = self.getGraphYLimits(axis='right')
+
+ (xMin, xMax), (yMin, yMax), (y2Min, y2Max) = \
+ self._plotFrame.dataRanges
+ self.setLimits(xMin, xMax, yMin, yMax, y2Min, y2Max)
+
+ # If plot range has changed, then emit signal
+ if previousXRange != self.getGraphXLimits():
+ self._plot.getXAxis()._emitLimitsChanged()
+ if previousYRange != self.getGraphYLimits(axis='left'):
+ self._plot.getYAxis(axis='left')._emitLimitsChanged()
+ if previousYRightRange != self.getGraphYLimits(axis='right'):
+ self._plot.getYAxis(axis='right')._emitLimitsChanged()
+
+ # Add methods
+
+ @staticmethod
+ def _castArrayTo(v):
+ """Returns best floating type to cast the array to.
+
+ :param numpy.ndarray v: Array to cast
+ :rtype: numpy.dtype
+ :raise ValueError: If dtype is not supported
+ """
+ if numpy.issubdtype(v.dtype, numpy.floating):
+ return numpy.float32 if v.itemsize <= 4 else numpy.float64
+ elif numpy.issubdtype(v.dtype, numpy.integer):
+ return numpy.float32 if v.itemsize <= 2 else numpy.float64
+ else:
+ raise ValueError('Unsupported data type')
+
+ def addCurve(self, x, y,
+ color, symbol, linewidth, linestyle,
+ yaxis,
+ xerror, yerror,
+ fill, alpha, symbolsize, baseline):
+ for parameter in (x, y, color, symbol, linewidth, linestyle,
+ yaxis, fill, symbolsize):
+ assert parameter is not None
+ assert yaxis in ('left', 'right')
+
+ # Convert input data
+ x = numpy.array(x, copy=False)
+ y = numpy.array(y, copy=False)
+
+ # Check if float32 is enough
+ if (self._castArrayTo(x) is numpy.float32 and
+ self._castArrayTo(y) is numpy.float32):
+ dtype = numpy.float32
+ else:
+ dtype = numpy.float64
+
+ x = numpy.array(x, dtype=dtype, copy=False, order='C')
+ y = numpy.array(y, dtype=dtype, copy=False, order='C')
+
+ # Convert errors to float32
+ if xerror is not None:
+ xerror = numpy.array(
+ xerror, dtype=numpy.float32, copy=False, order='C')
+ if yerror is not None:
+ yerror = numpy.array(
+ yerror, dtype=numpy.float32, copy=False, order='C')
+
+ # Handle axes log scale: convert data
+
+ if self._plotFrame.xAxis.isLog:
+ logX = numpy.log10(x)
+
+ if xerror is not None:
+ # Transform xerror so that
+ # log10(x) +/- xerror' = log10(x +/- xerror)
+ if hasattr(xerror, 'shape') and len(xerror.shape) == 2:
+ xErrorMinus, xErrorPlus = xerror[0], xerror[1]
+ else:
+ xErrorMinus, xErrorPlus = xerror, xerror
+ with numpy.errstate(divide='ignore', invalid='ignore'):
+ # Ignore divide by zero, invalid value encountered in log10
+ xErrorMinus = logX - numpy.log10(x - xErrorMinus)
+ xErrorPlus = numpy.log10(x + xErrorPlus) - logX
+ xerror = numpy.array((xErrorMinus, xErrorPlus),
+ dtype=numpy.float32)
+
+ x = logX
+
+ isYLog = (yaxis == 'left' and self._plotFrame.yAxis.isLog) or (
+ yaxis == 'right' and self._plotFrame.y2Axis.isLog)
+
+ if isYLog:
+ logY = numpy.log10(y)
+
+ if yerror is not None:
+ # Transform yerror so that
+ # log10(y) +/- yerror' = log10(y +/- yerror)
+ if hasattr(yerror, 'shape') and len(yerror.shape) == 2:
+ yErrorMinus, yErrorPlus = yerror[0], yerror[1]
+ else:
+ yErrorMinus, yErrorPlus = yerror, yerror
+ with numpy.errstate(divide='ignore', invalid='ignore'):
+ # Ignore divide by zero, invalid value encountered in log10
+ yErrorMinus = logY - numpy.log10(y - yErrorMinus)
+ yErrorPlus = numpy.log10(y + yErrorPlus) - logY
+ yerror = numpy.array((yErrorMinus, yErrorPlus),
+ dtype=numpy.float32)
+
+ y = logY
+
+ # TODO check if need more filtering of error (e.g., clip to positive)
+
+ # TODO check and improve this
+ if (len(color) == 4 and
+ type(color[3]) in [type(1), numpy.uint8, numpy.int8]):
+ color = numpy.array(color, dtype=numpy.float32) / 255.
+
+ if isinstance(color, numpy.ndarray) and color.ndim == 2:
+ colorArray = color
+ color = None
+ else:
+ colorArray = None
+ color = colors.rgba(color)
+
+ if alpha < 1.: # Apply image transparency
+ if colorArray is not None and colorArray.shape[1] == 4:
+ # multiply alpha channel
+ colorArray[:, 3] = colorArray[:, 3] * alpha
+ if color is not None:
+ color = color[0], color[1], color[2], color[3] * alpha
+
+ fillColor = None
+ if fill is True:
+ fillColor = color
+ curve = glutils.GLPlotCurve2D(
+ x, y, colorArray,
+ xError=xerror,
+ yError=yerror,
+ lineStyle=linestyle,
+ lineColor=color,
+ lineWidth=linewidth,
+ marker=symbol,
+ markerColor=color,
+ markerSize=symbolsize,
+ fillColor=fillColor,
+ baseline=baseline,
+ isYLog=isYLog)
+ curve.yaxis = 'left' if yaxis is None else yaxis
+
+ if yaxis == "right":
+ self._plotFrame.isY2Axis = True
+
+ return curve
+
+ def addImage(self, data,
+ origin, scale,
+ colormap, alpha):
+ for parameter in (data, origin, scale):
+ assert parameter is not None
+
+ if data.ndim == 2:
+ # Ensure array is contiguous and eventually convert its type
+ dtypes = [dtype for dtype in (
+ numpy.float32, numpy.float16, numpy.uint8, numpy.uint16)
+ if glu.isSupportedGLType(dtype)]
+ if data.dtype in dtypes:
+ data = numpy.array(data, copy=False, order='C')
+ else:
+ _logger.info(
+ 'addImage: Convert %s data to float32', str(data.dtype))
+ data = numpy.array(data, dtype=numpy.float32, order='C')
+
+ normalization = colormap.getNormalization()
+ if normalization in glutils.GLPlotColormap.SUPPORTED_NORMALIZATIONS:
+ # Fast path applying colormap on the GPU
+ cmapRange = colormap.getColormapRange(data=data)
+ colormapLut = colormap.getNColors(nbColors=256)
+ gamma = colormap.getGammaNormalizationParameter()
+ nanColor = colors.rgba(colormap.getNaNColor())
+
+ image = glutils.GLPlotColormap(
+ data,
+ origin,
+ scale,
+ colormapLut,
+ normalization,
+ gamma,
+ cmapRange,
+ alpha,
+ nanColor)
+
+ else: # Fallback applying colormap on CPU
+ rgba = colormap.applyToData(data)
+ image = glutils.GLPlotRGBAImage(rgba, origin, scale, alpha)
+
+ elif len(data.shape) == 3:
+ # For RGB, RGBA data
+ assert data.shape[2] in (3, 4)
+
+ if numpy.issubdtype(data.dtype, numpy.floating):
+ data = numpy.array(data, dtype=numpy.float32, copy=False)
+ elif data.dtype in [numpy.uint8, numpy.uint16]:
+ pass
+ elif numpy.issubdtype(data.dtype, numpy.integer):
+ data = numpy.array(data, dtype=numpy.uint8, copy=False)
+ else:
+ raise ValueError('Unsupported data type')
+
+ image = glutils.GLPlotRGBAImage(data, origin, scale, alpha)
+
+ else:
+ raise RuntimeError("Unsupported data shape {0}".format(data.shape))
+
+ # TODO is this needed?
+ if self._plotFrame.xAxis.isLog and image.xMin <= 0.:
+ raise RuntimeError(
+ 'Cannot add image with X <= 0 with X axis log scale')
+ if self._plotFrame.yAxis.isLog and image.yMin <= 0.:
+ raise RuntimeError(
+ 'Cannot add image with Y <= 0 with Y axis log scale')
+
+ return image
+
+ def addTriangles(self, x, y, triangles,
+ color, alpha):
+ # Handle axes log scale: convert data
+ if self._plotFrame.xAxis.isLog:
+ x = numpy.log10(x)
+ if self._plotFrame.yAxis.isLog:
+ y = numpy.log10(y)
+
+ triangles = glutils.GLPlotTriangles(x, y, color, triangles, alpha)
+
+ return triangles
+
+ def addShape(self, x, y, shape, color, fill, overlay,
+ linestyle, linewidth, linebgcolor):
+ x = numpy.array(x, copy=False)
+ y = numpy.array(y, copy=False)
+
+ # TODO is this needed?
+ if self._plotFrame.xAxis.isLog and x.min() <= 0.:
+ raise RuntimeError(
+ 'Cannot add item with X <= 0 with X axis log scale')
+ if self._plotFrame.yAxis.isLog and y.min() <= 0.:
+ raise RuntimeError(
+ 'Cannot add item with Y <= 0 with Y axis log scale')
+
+ return _ShapeItem(x, y, shape, color, fill, overlay,
+ linestyle, linewidth, linebgcolor)
+
+ def addMarker(self, x, y, text, color,
+ symbol, linestyle, linewidth, constraint, yaxis):
+ return _MarkerItem(x, y, text, color,
+ symbol, linestyle, linewidth, constraint, yaxis)
+
+ # Remove methods
+
+ def remove(self, item):
+ if isinstance(item, glutils.GLPlotItem):
+ if item.yaxis == 'right':
+ # Check if some curves remains on the right Y axis
+ y2AxisItems = (item for item in self._plot.getItems()
+ if isinstance(item, items.YAxisMixIn) and
+ item.getYAxis() == 'right')
+ self._plotFrame.isY2Axis = next(y2AxisItems, None) is not None
+
+ if item.isInitialized():
+ self._glGarbageCollector.append(item)
+
+ elif isinstance(item, (_MarkerItem, _ShapeItem)):
+ pass # No-op
+
+ else:
+ _logger.error('Unsupported item: %s', str(item))
+
+ # Interaction methods
+
+ _QT_CURSORS = {
+ BackendBase.CURSOR_DEFAULT: qt.Qt.ArrowCursor,
+ BackendBase.CURSOR_POINTING: qt.Qt.PointingHandCursor,
+ BackendBase.CURSOR_SIZE_HOR: qt.Qt.SizeHorCursor,
+ BackendBase.CURSOR_SIZE_VER: qt.Qt.SizeVerCursor,
+ BackendBase.CURSOR_SIZE_ALL: qt.Qt.SizeAllCursor,
+ }
+
+ def setGraphCursorShape(self, cursor):
+ if cursor is None:
+ super(BackendOpenGL, self).unsetCursor()
+ else:
+ cursor = self._QT_CURSORS[cursor]
+ super(BackendOpenGL, self).setCursor(qt.QCursor(cursor))
+
+ def setGraphCursor(self, flag, color, linewidth, linestyle):
+ if linestyle != '-':
+ _logger.warning(
+ "BackendOpenGL.setGraphCursor linestyle parameter ignored")
+
+ if flag:
+ color = colors.rgba(color)
+ crosshairCursor = color, linewidth
+ else:
+ crosshairCursor = None
+
+ if crosshairCursor != self._crosshairCursor:
+ self._crosshairCursor = crosshairCursor
+
+ _PICK_OFFSET = 3 # Offset in pixel used for picking
+
+ def _mouseInPlotArea(self, x, y):
+ """Returns closest visible position in the plot.
+
+ This is performed in Qt widget pixel, not device pixel.
+
+ :param float x: X coordinate in Qt widget pixel
+ :param float y: Y coordinate in Qt widget pixel
+ :return: (x, y) closest point in the plot.
+ :rtype: List[float]
+ """
+ left, top, width, height = self.getPlotBoundsInPixels()
+ return (numpy.clip(x, left, left + width - 1), # TODO -1?
+ numpy.clip(y, top, top + height - 1))
+
+ def __pickCurves(self, item, x, y):
+ """Perform picking on a curve item.
+
+ :param GLPlotCurve2D item:
+ :param float x: X position of the mouse in widget coordinates
+ :param float y: Y position of the mouse in widget coordinates
+ :return: List of indices of picked points or None if not picked
+ :rtype: Union[List[int],None]
+ """
+ offset = self._PICK_OFFSET
+ if item.marker is not None:
+ # Convert markerSize from points to qt pixels
+ qtDpi = self.getDotsPerInch() / self.getDevicePixelRatio()
+ size = item.markerSize / 72. * qtDpi
+ offset = max(size / 2., offset)
+ if item.lineStyle is not None:
+ # Convert line width from points to qt pixels
+ qtDpi = self.getDotsPerInch() / self.getDevicePixelRatio()
+ lineWidth = item.lineWidth / 72. * qtDpi
+ offset = max(lineWidth / 2., offset)
+
+ inAreaPos = self._mouseInPlotArea(x - offset, y - offset)
+ dataPos = self._plot.pixelToData(inAreaPos[0], inAreaPos[1],
+ axis=item.yaxis, check=True)
+ if dataPos is None:
+ return None
+ xPick0, yPick0 = dataPos
+
+ inAreaPos = self._mouseInPlotArea(x + offset, y + offset)
+ dataPos = self._plot.pixelToData(inAreaPos[0], inAreaPos[1],
+ axis=item.yaxis, check=True)
+ if dataPos is None:
+ return None
+ xPick1, yPick1 = dataPos
+
+ if xPick0 < xPick1:
+ xPickMin, xPickMax = xPick0, xPick1
+ else:
+ xPickMin, xPickMax = xPick1, xPick0
+
+ if yPick0 < yPick1:
+ yPickMin, yPickMax = yPick0, yPick1
+ else:
+ yPickMin, yPickMax = yPick1, yPick0
+
+ # Apply log scale if axis is log
+ if self._plotFrame.xAxis.isLog:
+ xPickMin = numpy.log10(xPickMin)
+ xPickMax = numpy.log10(xPickMax)
+
+ if (item.yaxis == 'left' and self._plotFrame.yAxis.isLog) or (
+ item.yaxis == 'right' and self._plotFrame.y2Axis.isLog):
+ yPickMin = numpy.log10(yPickMin)
+ yPickMax = numpy.log10(yPickMax)
+
+ return item.pick(xPickMin, yPickMin,
+ xPickMax, yPickMax)
+
+ def pickItem(self, x, y, item):
+ # Picking is performed in Qt widget pixels not device pixels
+ dataPos = self._plot.pixelToData(x, y, axis='left', check=True)
+ if dataPos is None:
+ return None # Outside plot area
+
+ if item is None:
+ _logger.error("No item provided for picking")
+ return None
+
+ # Pick markers
+ if isinstance(item, _MarkerItem):
+ yaxis = item['yaxis']
+ pixelPos = self._plot.dataToPixel(
+ item['x'], item['y'], axis=yaxis, check=False)
+ if pixelPos is None:
+ return None # negative coord on a log axis
+
+ if item['x'] is None: # Horizontal line
+ pt1 = self._plot.pixelToData(
+ x, y - self._PICK_OFFSET, axis=yaxis, check=False)
+ pt2 = self._plot.pixelToData(
+ x, y + self._PICK_OFFSET, axis=yaxis, check=False)
+ isPicked = (min(pt1[1], pt2[1]) <= item['y'] <=
+ max(pt1[1], pt2[1]))
+
+ elif item['y'] is None: # Vertical line
+ pt1 = self._plot.pixelToData(
+ x - self._PICK_OFFSET, y, axis=yaxis, check=False)
+ pt2 = self._plot.pixelToData(
+ x + self._PICK_OFFSET, y, axis=yaxis, check=False)
+ isPicked = (min(pt1[0], pt2[0]) <= item['x'] <=
+ max(pt1[0], pt2[0]))
+
+ else:
+ isPicked = (
+ numpy.fabs(x - pixelPos[0]) <= self._PICK_OFFSET and
+ numpy.fabs(y - pixelPos[1]) <= self._PICK_OFFSET)
+
+ return (0,) if isPicked else None
+
+ # Pick image, curve, triangles
+ elif isinstance(item, glutils.GLPlotItem):
+ if isinstance(item, glutils.GLPlotCurve2D):
+ return self.__pickCurves(item, x, y)
+ else:
+ return item.pick(*dataPos) # Might be None
+
+ # Update curve
+
+ def setCurveColor(self, curve, color):
+ pass # TODO
+
+ # Misc.
+
+ def getWidgetHandle(self):
+ return self
+
+ def postRedisplay(self):
+ self.update()
+
+ def replot(self):
+ self.update() # async redraw
+
+ def saveGraph(self, fileName, fileFormat, dpi):
+ if dpi is not None:
+ _logger.warning("saveGraph ignores dpi parameter")
+
+ if fileFormat not in ['png', 'ppm', 'svg', 'tiff']:
+ raise NotImplementedError('Unsupported format: %s' % fileFormat)
+
+ if not self.isValid():
+ _logger.error('OpenGL 2.1 not available, cannot save OpenGL image')
+ width, height = self._plotFrame.size
+ data = numpy.zeros((height, width, 3), dtype=numpy.uint8)
+ else:
+ self.makeCurrent()
+
+ data = numpy.empty(
+ (self._plotFrame.size[1], self._plotFrame.size[0], 3),
+ dtype=numpy.uint8, order='C')
+
+ context = self.context()
+ framebufferTexture = self._plotFBOs.get(context)
+ if framebufferTexture is None:
+ # Fallback, supports direct rendering mode: _paintDirectGL
+ # might have issues as it can read on-screen framebuffer
+ fboName = self.defaultFramebufferObject()
+ width, height = self._plotFrame.size
+ else:
+ fboName = framebufferTexture.name
+ height, width = framebufferTexture.shape
+
+ previousFramebuffer = gl.glGetInteger(gl.GL_FRAMEBUFFER_BINDING)
+ gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fboName)
+ gl.glPixelStorei(gl.GL_PACK_ALIGNMENT, 1)
+ gl.glReadPixels(0, 0, width, height,
+ gl.GL_RGB, gl.GL_UNSIGNED_BYTE, data)
+ gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, previousFramebuffer)
+
+ # glReadPixels gives bottom to top,
+ # while images are stored as top to bottom
+ data = numpy.flipud(data)
+
+ # fileName is either a file-like object or a str
+ saveImageToFile(data, fileName, fileFormat)
+
+ # Graph labels
+
+ def setGraphTitle(self, title):
+ self._plotFrame.title = title
+
+ def setGraphXLabel(self, label):
+ self._plotFrame.xAxis.title = label
+
+ def setGraphYLabel(self, label, axis):
+ if axis == 'left':
+ self._plotFrame.yAxis.title = label
+ else: # right axis
+ self._plotFrame.y2Axis.title = label
+
+ # Graph limits
+
+ def _setDataRanges(self, xlim=None, ylim=None, y2lim=None):
+ """Set the visible range of data in the plot frame.
+
+ This clips the ranges to possible values (takes care of float32
+ range + positive range for log).
+ This also takes care of non-orthogonal axes.
+
+ This should be moved to PlotFrame.
+ """
+ # Update axes range with a clipped range if too wide
+ self._plotFrame.setDataRanges(xlim, ylim, y2lim)
+
+ def _ensureAspectRatio(self, keepDim=None):
+ """Update plot bounds in order to keep aspect ratio.
+
+ Warning: keepDim on right Y axis is not implemented !
+
+ :param str keepDim: The dimension to maintain: 'x', 'y' or None.
+ If None (the default), the dimension with the largest range.
+ """
+ plotWidth, plotHeight = self._plotFrame.plotSize
+ if plotWidth <= 2 or plotHeight <= 2:
+ return
+
+ if keepDim is None:
+ ranges = self._plot.getDataRange()
+ if (ranges.y is not None and
+ ranges.x is not None and
+ (ranges.y[1] - ranges.y[0]) != 0.):
+ dataRatio = (ranges.x[1] - ranges.x[0]) / float(ranges.y[1] - ranges.y[0])
+ plotRatio = plotWidth / float(plotHeight) # Test != 0 before
+
+ keepDim = 'x' if dataRatio > plotRatio else 'y'
+ else: # Limit case
+ keepDim = 'x'
+
+ (xMin, xMax), (yMin, yMax), (y2Min, y2Max) = \
+ self._plotFrame.dataRanges
+ if keepDim == 'y':
+ dataW = (yMax - yMin) * plotWidth / float(plotHeight)
+ xCenter = 0.5 * (xMin + xMax)
+ xMin = xCenter - 0.5 * dataW
+ xMax = xCenter + 0.5 * dataW
+ elif keepDim == 'x':
+ dataH = (xMax - xMin) * plotHeight / float(plotWidth)
+ yCenter = 0.5 * (yMin + yMax)
+ yMin = yCenter - 0.5 * dataH
+ yMax = yCenter + 0.5 * dataH
+ y2Center = 0.5 * (y2Min + y2Max)
+ y2Min = y2Center - 0.5 * dataH
+ y2Max = y2Center + 0.5 * dataH
+ else:
+ raise RuntimeError('Unsupported dimension to keep: %s' % keepDim)
+
+ # Update plot frame bounds
+ self._setDataRanges(xlim=(xMin, xMax),
+ ylim=(yMin, yMax),
+ y2lim=(y2Min, y2Max))
+
+ def _setPlotBounds(self, xRange=None, yRange=None, y2Range=None,
+ keepDim=None):
+ # Update axes range with a clipped range if too wide
+ self._setDataRanges(xlim=xRange,
+ ylim=yRange,
+ y2lim=y2Range)
+
+ # Keep data aspect ratio
+ if self.isKeepDataAspectRatio():
+ self._ensureAspectRatio(keepDim)
+
+ def setLimits(self, xmin, xmax, ymin, ymax, y2min=None, y2max=None):
+ assert xmin < xmax
+ assert ymin < ymax
+
+ if y2min is None or y2max is None:
+ y2Range = None
+ else:
+ assert y2min < y2max
+ y2Range = y2min, y2max
+ self._setPlotBounds((xmin, xmax), (ymin, ymax), y2Range)
+
+ def getGraphXLimits(self):
+ return self._plotFrame.dataRanges.x
+
+ def setGraphXLimits(self, xmin, xmax):
+ assert xmin < xmax
+ self._setPlotBounds(xRange=(xmin, xmax), keepDim='x')
+
+ def getGraphYLimits(self, axis):
+ assert axis in ("left", "right")
+ if axis == "left":
+ return self._plotFrame.dataRanges.y
+ else:
+ return self._plotFrame.dataRanges.y2
+
+ def setGraphYLimits(self, ymin, ymax, axis):
+ assert ymin < ymax
+ assert axis in ("left", "right")
+
+ if axis == "left":
+ self._setPlotBounds(yRange=(ymin, ymax), keepDim='y')
+ else:
+ self._setPlotBounds(y2Range=(ymin, ymax), keepDim='y')
+
+ # Graph axes
+
+ def getXAxisTimeZone(self):
+ return self._plotFrame.xAxis.timeZone
+
+ def setXAxisTimeZone(self, tz):
+ self._plotFrame.xAxis.timeZone = tz
+
+ def isXAxisTimeSeries(self):
+ return self._plotFrame.xAxis.isTimeSeries
+
+ def setXAxisTimeSeries(self, isTimeSeries):
+ self._plotFrame.xAxis.isTimeSeries = isTimeSeries
+
+ def setXAxisLogarithmic(self, flag):
+ if flag != self._plotFrame.xAxis.isLog:
+ if flag and self._keepDataAspectRatio:
+ _logger.warning(
+ "KeepDataAspectRatio is ignored with log axes")
+
+ self._plotFrame.xAxis.isLog = flag
+
+ def setYAxisLogarithmic(self, flag):
+ if (flag != self._plotFrame.yAxis.isLog or
+ flag != self._plotFrame.y2Axis.isLog):
+ if flag and self._keepDataAspectRatio:
+ _logger.warning(
+ "KeepDataAspectRatio is ignored with log axes")
+
+ self._plotFrame.yAxis.isLog = flag
+ self._plotFrame.y2Axis.isLog = flag
+
+ def setYAxisInverted(self, flag):
+ if flag != self._plotFrame.isYAxisInverted:
+ self._plotFrame.isYAxisInverted = flag
+
+ def isYAxisInverted(self):
+ return self._plotFrame.isYAxisInverted
+
+ def isKeepDataAspectRatio(self):
+ if self._plotFrame.xAxis.isLog or self._plotFrame.yAxis.isLog:
+ return False
+ else:
+ return self._keepDataAspectRatio
+
+ def setKeepDataAspectRatio(self, flag):
+ if flag and (self._plotFrame.xAxis.isLog or
+ self._plotFrame.yAxis.isLog):
+ _logger.warning("KeepDataAspectRatio is ignored with log axes")
+
+ self._keepDataAspectRatio = flag
+
+ def setGraphGrid(self, which):
+ assert which in (None, 'major', 'both')
+ self._plotFrame.grid = which is not None # TODO True grid support
+
+ # Data <-> Pixel coordinates conversion
+
+ def dataToPixel(self, x, y, axis):
+ result = self._plotFrame.dataToPixel(x, y, axis)
+ if result is None:
+ return None
+ else:
+ devicePixelRatio = self.getDevicePixelRatio()
+ return tuple(value/devicePixelRatio for value in result)
+
+ def pixelToData(self, x, y, axis):
+ devicePixelRatio = self.getDevicePixelRatio()
+ return self._plotFrame.pixelToData(
+ x * devicePixelRatio, y * devicePixelRatio, axis)
+
+ def getPlotBoundsInPixels(self):
+ devicePixelRatio = self.getDevicePixelRatio()
+ return tuple(int(value / devicePixelRatio)
+ for value in self._plotFrame.plotOrigin + self._plotFrame.plotSize)
+
+ def setAxesMargins(self, left: float, top: float, right: float, bottom: float):
+ self._plotFrame.marginRatios = left, top, right, bottom
+
+ def setForegroundColors(self, foregroundColor, gridColor):
+ self._plotFrame.foregroundColor = foregroundColor
+ self._plotFrame.gridColor = gridColor
+
+ def setBackgroundColors(self, backgroundColor, dataBackgroundColor):
+ self._backgroundColor = backgroundColor
+ self._dataBackgroundColor = dataBackgroundColor
diff --git a/src/silx/gui/plot/backends/__init__.py b/src/silx/gui/plot/backends/__init__.py
new file mode 100644
index 0000000..966d9df
--- /dev/null
+++ b/src/silx/gui/plot/backends/__init__.py
@@ -0,0 +1,29 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This package implements the backend of the Plot."""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "21/03/2017"
diff --git a/src/silx/gui/plot/backends/glutils/GLPlotCurve.py b/src/silx/gui/plot/backends/glutils/GLPlotCurve.py
new file mode 100644
index 0000000..e4667b4
--- /dev/null
+++ b/src/silx/gui/plot/backends/glutils/GLPlotCurve.py
@@ -0,0 +1,1380 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2014-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""
+This module provides classes to render 2D lines and scatter plots
+"""
+
+from __future__ import division
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "03/04/2017"
+
+
+import math
+import logging
+
+import numpy
+
+from silx.math.combo import min_max
+
+from ...._glutils import gl
+from ...._glutils import Program, vertexBuffer, VertexBufferAttrib
+from .GLSupport import buildFillMaskIndices, mat4Identity, mat4Translate
+from .GLPlotImage import GLPlotItem
+
+
+_logger = logging.getLogger(__name__)
+
+
+_MPL_NONES = None, 'None', '', ' '
+"""Possible values for None"""
+
+
+def _notNaNSlices(array, length=1):
+ """Returns slices of none NaN values in the array.
+
+ :param numpy.ndarray array: 1D array from which to get slices
+ :param int length: Slices shorter than length gets discarded
+ :return: Array of (start, end) slice indices
+ :rtype: numpy.ndarray
+ """
+ isnan = numpy.isnan(numpy.array(array, copy=False).reshape(-1))
+ notnan = numpy.logical_not(isnan)
+ start = numpy.where(numpy.logical_and(isnan[:-1], notnan[1:]))[0] + 1
+ if notnan[0]:
+ start = numpy.append(0, start)
+ end = numpy.where(numpy.logical_and(notnan[:-1], isnan[1:]))[0] + 1
+ if notnan[-1]:
+ end = numpy.append(end, len(array))
+ slices = numpy.transpose((start, end))
+ if length > 1:
+ # discard slices with less than length values
+ slices = slices[numpy.diff(slices, axis=1).ravel() >= length]
+ return slices
+
+
+# fill ########################################################################
+
+class _Fill2D(object):
+ """Object rendering curve filling as polygons
+
+ :param numpy.ndarray xData: X coordinates of points
+ :param numpy.ndarray yData: Y coordinates of points
+ :param float baseline: Y value of the 'bottom' of the fill.
+ 0 for linear Y scale, -38 for log Y scale
+ :param List[float] color: RGBA color as 4 float in [0, 1]
+ :param List[float] offset: Translation of coordinates (ox, oy)
+ """
+
+ _PROGRAM = Program(
+ vertexShader="""
+ #version 120
+
+ uniform mat4 matrix;
+ attribute float xPos;
+ attribute float yPos;
+
+ void main(void) {
+ gl_Position = matrix * vec4(xPos, yPos, 0.0, 1.0);
+ }
+ """,
+ fragmentShader="""
+ #version 120
+
+ uniform vec4 color;
+
+ void main(void) {
+ gl_FragColor = color;
+ }
+ """,
+ attrib0='xPos')
+
+ def __init__(self, xData=None, yData=None,
+ baseline=0,
+ color=(0., 0., 0., 1.),
+ offset=(0., 0.)):
+ self.xData = xData
+ self.yData = yData
+ self._xFillVboData = None
+ self._yFillVboData = None
+ self.color = color
+ self.offset = offset
+
+ # Offset baseline
+ self.baseline = baseline - self.offset[1]
+
+ def prepare(self):
+ """Rendering preparation: build indices and bounding box vertices"""
+ if (self._xFillVboData is None and
+ self.xData is not None and self.yData is not None):
+
+ # Get slices of not NaN values longer than 1 element
+ isnan = numpy.logical_or(numpy.isnan(self.xData), numpy.isnan(self.yData))
+ notnan = numpy.logical_not(isnan)
+ start = numpy.where(numpy.logical_and(isnan[:-1], notnan[1:]))[0] + 1
+ if notnan[0]:
+ start = numpy.append(0, start)
+ end = numpy.where(numpy.logical_and(notnan[:-1], isnan[1:]))[0] + 1
+ if notnan[-1]:
+ end = numpy.append(end, len(isnan))
+ slices = numpy.transpose((start, end))
+ # discard slices with less than length values
+ slices = slices[numpy.diff(slices, axis=1).reshape(-1) >= 2]
+
+ # Number of points: slice + 2 * leading and trailing points
+ # Twice leading and trailing points to produce degenerated triangles
+ nbPoints = numpy.sum(numpy.diff(slices, axis=1)) * 2 + 4 * len(slices)
+ points = numpy.empty((nbPoints, 2), dtype=numpy.float32)
+
+ offset = 0
+ # invert baseline for filling
+ new_y_data = numpy.append(self.yData, self.baseline)
+ for start, end in slices:
+ # Duplicate first point for connecting degenerated triangle
+ points[offset:offset+2] = self.xData[start], new_y_data[start]
+
+ # 2nd point of the polygon is last point
+ points[offset+2] = self.xData[start], self.baseline[start]
+
+ indices = numpy.append(numpy.arange(start, end),
+ numpy.arange(len(self.xData) + end-1, len(self.xData) + start-1, -1))
+ indices = indices[buildFillMaskIndices(len(indices))]
+
+ points[offset+3:offset+3+len(indices), 0] = self.xData[indices % len(self.xData)]
+ points[offset+3:offset+3+len(indices), 1] = new_y_data[indices]
+
+ # Duplicate last point for connecting degenerated triangle
+ points[offset+3+len(indices)] = points[offset+3+len(indices)-1]
+
+ offset += len(indices) + 4
+
+ self._xFillVboData, self._yFillVboData = vertexBuffer(points.T)
+
+ def render(self, context):
+ """Perform rendering
+
+ :param RenderContext context:
+ """
+ self.prepare()
+
+ if self._xFillVboData is None:
+ return # Nothing to display
+
+ self._PROGRAM.use()
+
+ gl.glUniformMatrix4fv(
+ self._PROGRAM.uniforms['matrix'], 1, gl.GL_TRUE,
+ numpy.dot(context.matrix,
+ mat4Translate(*self.offset)).astype(numpy.float32))
+
+ gl.glUniform4f(self._PROGRAM.uniforms['color'], *self.color)
+
+ xPosAttrib = self._PROGRAM.attributes['xPos']
+ yPosAttrib = self._PROGRAM.attributes['yPos']
+
+ gl.glEnableVertexAttribArray(xPosAttrib)
+ self._xFillVboData.setVertexAttrib(xPosAttrib)
+
+ gl.glEnableVertexAttribArray(yPosAttrib)
+ self._yFillVboData.setVertexAttrib(yPosAttrib)
+
+ # Prepare fill mask
+ gl.glEnable(gl.GL_STENCIL_TEST)
+ gl.glStencilMask(1)
+ gl.glStencilFunc(gl.GL_ALWAYS, 1, 1)
+ gl.glStencilOp(gl.GL_INVERT, gl.GL_INVERT, gl.GL_INVERT)
+ gl.glColorMask(gl.GL_FALSE, gl.GL_FALSE, gl.GL_FALSE, gl.GL_FALSE)
+ gl.glDepthMask(gl.GL_FALSE)
+
+ gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, self._xFillVboData.size)
+
+ gl.glStencilFunc(gl.GL_EQUAL, 1, 1)
+ # Reset stencil while drawing
+ gl.glStencilOp(gl.GL_ZERO, gl.GL_ZERO, gl.GL_ZERO)
+ gl.glColorMask(gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE)
+ gl.glDepthMask(gl.GL_TRUE)
+
+ # Draw directly in NDC
+ gl.glUniformMatrix4fv(self._PROGRAM.uniforms['matrix'], 1, gl.GL_TRUE,
+ mat4Identity().astype(numpy.float32))
+
+ # NDC vertices
+ gl.glVertexAttribPointer(
+ xPosAttrib, 1, gl.GL_FLOAT, gl.GL_FALSE, 0,
+ numpy.array((-1., -1., 1., 1.), dtype=numpy.float32))
+ gl.glVertexAttribPointer(
+ yPosAttrib, 1, gl.GL_FLOAT, gl.GL_FALSE, 0,
+ numpy.array((-1., 1., -1., 1.), dtype=numpy.float32))
+
+ gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, 4)
+
+ gl.glDisable(gl.GL_STENCIL_TEST)
+
+ def discard(self):
+ """Release VBOs"""
+ if self.isInitialized():
+ self._xFillVboData.vbo.discard()
+
+ self._xFillVboData = None
+ self._yFillVboData = None
+
+ def isInitialized(self):
+ return self._xFillVboData is not None
+
+
+# line ########################################################################
+
+SOLID, DASHED, DASHDOT, DOTTED = '-', '--', '-.', ':'
+
+
+class GLLines2D(object):
+ """Object rendering curve as a polyline
+
+ :param xVboData: X coordinates VBO
+ :param yVboData: Y coordinates VBO
+ :param colorVboData: VBO of colors
+ :param distVboData: VBO of distance along the polyline
+ :param str style: Line style in: '-', '--', '-.', ':'
+ :param List[float] color: RGBA color as 4 float in [0, 1]
+ :param float width: Line width
+ :param float dashPeriod: Period of dashes
+ :param drawMode: OpenGL drawing mode
+ :param List[float] offset: Translation of coordinates (ox, oy)
+ """
+
+ STYLES = SOLID, DASHED, DASHDOT, DOTTED
+ """Supported line styles"""
+
+ _SOLID_PROGRAM = Program(
+ vertexShader="""
+ #version 120
+
+ uniform mat4 matrix;
+ attribute float xPos;
+ attribute float yPos;
+ attribute vec4 color;
+
+ varying vec4 vColor;
+
+ void main(void) {
+ gl_Position = matrix * vec4(xPos, yPos, 0., 1.) ;
+ vColor = color;
+ }
+ """,
+ fragmentShader="""
+ #version 120
+
+ varying vec4 vColor;
+
+ void main(void) {
+ gl_FragColor = vColor;
+ }
+ """,
+ attrib0='xPos')
+
+ # Limitation: Dash using an estimate of distance in screen coord
+ # to avoid computing distance when viewport is resized
+ # results in inequal dashes when viewport aspect ratio is far from 1
+ _DASH_PROGRAM = Program(
+ vertexShader="""
+ #version 120
+
+ uniform mat4 matrix;
+ uniform vec2 halfViewportSize;
+ attribute float xPos;
+ attribute float yPos;
+ attribute vec4 color;
+ attribute float distance;
+
+ varying float vDist;
+ varying vec4 vColor;
+
+ void main(void) {
+ gl_Position = matrix * vec4(xPos, yPos, 0., 1.);
+ //Estimate distance in pixels
+ vec2 probe = vec2(matrix * vec4(1., 1., 0., 0.)) *
+ halfViewportSize;
+ float pixelPerDataEstimate = length(probe)/sqrt(2.);
+ vDist = distance * pixelPerDataEstimate;
+ vColor = color;
+ }
+ """,
+ fragmentShader="""
+ #version 120
+
+ /* Dashes: [0, x], [y, z]
+ Dash period: w */
+ uniform vec4 dash;
+ uniform vec4 dash2ndColor;
+
+ varying float vDist;
+ varying vec4 vColor;
+
+ void main(void) {
+ float dist = mod(vDist, dash.w);
+ if ((dist > dash.x && dist < dash.y) || dist > dash.z) {
+ if (dash2ndColor.a == 0.) {
+ discard; // Discard full transparent bg color
+ } else {
+ gl_FragColor = dash2ndColor;
+ }
+ } else {
+ gl_FragColor = vColor;
+ }
+ }
+ """,
+ attrib0='xPos')
+
+ def __init__(self, xVboData=None, yVboData=None,
+ colorVboData=None, distVboData=None,
+ style=SOLID, color=(0., 0., 0., 1.), dash2ndColor=None,
+ width=1, dashPeriod=10., drawMode=None,
+ offset=(0., 0.)):
+ if (xVboData is not None and
+ not isinstance(xVboData, VertexBufferAttrib)):
+ xVboData = numpy.array(xVboData, copy=False, dtype=numpy.float32)
+ self.xVboData = xVboData
+
+ if (yVboData is not None and
+ not isinstance(yVboData, VertexBufferAttrib)):
+ yVboData = numpy.array(yVboData, copy=False, dtype=numpy.float32)
+ self.yVboData = yVboData
+
+ # Compute distances if not given while providing numpy array coordinates
+ if (isinstance(self.xVboData, numpy.ndarray) and
+ isinstance(self.yVboData, numpy.ndarray) and
+ distVboData is None):
+ distVboData = distancesFromArrays(self.xVboData, self.yVboData)
+
+ if (distVboData is not None and
+ not isinstance(distVboData, VertexBufferAttrib)):
+ distVboData = numpy.array(
+ distVboData, copy=False, dtype=numpy.float32)
+ self.distVboData = distVboData
+
+ if colorVboData is not None:
+ assert isinstance(colorVboData, VertexBufferAttrib)
+ self.colorVboData = colorVboData
+ self.useColorVboData = colorVboData is not None
+
+ self.color = color
+ self.dash2ndColor = dash2ndColor
+ self.width = width
+ self._style = None
+ self.style = style
+ self.dashPeriod = dashPeriod
+ self.offset = offset
+
+ self._drawMode = drawMode if drawMode is not None else gl.GL_LINE_STRIP
+
+ @property
+ def style(self):
+ """Line style (Union[str,None])"""
+ return self._style
+
+ @style.setter
+ def style(self, style):
+ if style in _MPL_NONES:
+ self._style = None
+ else:
+ assert style in self.STYLES
+ self._style = style
+
+ @classmethod
+ def init(cls):
+ """OpenGL context initialization"""
+ gl.glHint(gl.GL_LINE_SMOOTH_HINT, gl.GL_NICEST)
+
+ def render(self, context):
+ """Perform rendering
+
+ :param RenderContext context:
+ """
+ width = self.width / 72. * context.dpi
+
+ style = self.style
+ if style is None:
+ return
+
+ elif style == SOLID:
+ program = self._SOLID_PROGRAM
+ program.use()
+
+ else: # DASHED, DASHDOT, DOTTED
+ program = self._DASH_PROGRAM
+ program.use()
+
+ x, y, viewWidth, viewHeight = gl.glGetFloatv(gl.GL_VIEWPORT)
+ gl.glUniform2f(program.uniforms['halfViewportSize'],
+ 0.5 * viewWidth, 0.5 * viewHeight)
+
+ dashPeriod = self.dashPeriod * width
+ if self.style == DOTTED:
+ dash = (0.2 * dashPeriod,
+ 0.5 * dashPeriod,
+ 0.7 * dashPeriod,
+ dashPeriod)
+ elif self.style == DASHDOT:
+ dash = (0.3 * dashPeriod,
+ 0.5 * dashPeriod,
+ 0.6 * dashPeriod,
+ dashPeriod)
+ else:
+ dash = (0.5 * dashPeriod,
+ dashPeriod,
+ dashPeriod,
+ dashPeriod)
+
+ gl.glUniform4f(program.uniforms['dash'], *dash)
+
+ if self.dash2ndColor is None:
+ # Use fully transparent color which gets discarded in shader
+ dash2ndColor = (0., 0., 0., 0.)
+ else:
+ dash2ndColor = self.dash2ndColor
+ gl.glUniform4f(program.uniforms['dash2ndColor'], *dash2ndColor)
+
+ distAttrib = program.attributes['distance']
+ gl.glEnableVertexAttribArray(distAttrib)
+ if isinstance(self.distVboData, VertexBufferAttrib):
+ self.distVboData.setVertexAttrib(distAttrib)
+ else:
+ gl.glVertexAttribPointer(distAttrib,
+ 1,
+ gl.GL_FLOAT,
+ False,
+ 0,
+ self.distVboData)
+
+ if width != 1:
+ gl.glEnable(gl.GL_LINE_SMOOTH)
+
+ matrix = numpy.dot(context.matrix,
+ mat4Translate(*self.offset)).astype(numpy.float32)
+ gl.glUniformMatrix4fv(program.uniforms['matrix'],
+ 1, gl.GL_TRUE, matrix)
+
+ colorAttrib = program.attributes['color']
+ if self.useColorVboData and self.colorVboData is not None:
+ gl.glEnableVertexAttribArray(colorAttrib)
+ self.colorVboData.setVertexAttrib(colorAttrib)
+ else:
+ gl.glDisableVertexAttribArray(colorAttrib)
+ gl.glVertexAttrib4f(colorAttrib, *self.color)
+
+ xPosAttrib = program.attributes['xPos']
+ gl.glEnableVertexAttribArray(xPosAttrib)
+ if isinstance(self.xVboData, VertexBufferAttrib):
+ self.xVboData.setVertexAttrib(xPosAttrib)
+ else:
+ gl.glVertexAttribPointer(xPosAttrib,
+ 1,
+ gl.GL_FLOAT,
+ False,
+ 0,
+ self.xVboData)
+
+ yPosAttrib = program.attributes['yPos']
+ gl.glEnableVertexAttribArray(yPosAttrib)
+ if isinstance(self.yVboData, VertexBufferAttrib):
+ self.yVboData.setVertexAttrib(yPosAttrib)
+ else:
+ gl.glVertexAttribPointer(yPosAttrib,
+ 1,
+ gl.GL_FLOAT,
+ False,
+ 0,
+ self.yVboData)
+
+ gl.glLineWidth(width)
+ gl.glDrawArrays(self._drawMode, 0, self.xVboData.size)
+
+ gl.glDisable(gl.GL_LINE_SMOOTH)
+
+
+def distancesFromArrays(xData, yData):
+ """Returns distances between each points
+
+ :param numpy.ndarray xData: X coordinate of points
+ :param numpy.ndarray yData: Y coordinate of points
+ :rtype: numpy.ndarray
+ """
+ # Split array into sub-shapes at not finite points
+ splits = numpy.nonzero(numpy.logical_not(numpy.logical_and(
+ numpy.isfinite(xData), numpy.isfinite(yData))))[0]
+ splits = numpy.concatenate(([-1], splits, [len(xData) - 1]))
+
+ # Compute distance independently for each sub-shapes,
+ # putting not finite points as last points of sub-shapes
+ distances = []
+ for begin, end in zip(splits[:-1] + 1, splits[1:] + 1):
+ if begin == end: # Empty shape
+ continue
+ elif end - begin == 1: # Single element
+ distances.append([0])
+ else:
+ deltas = numpy.dstack((
+ numpy.ediff1d(xData[begin:end], to_begin=numpy.float32(0.)),
+ numpy.ediff1d(yData[begin:end], to_begin=numpy.float32(0.))))[0]
+ distances.append(
+ numpy.cumsum(numpy.sqrt(numpy.sum(deltas ** 2, axis=1))))
+ return numpy.concatenate(distances)
+
+
+# points ######################################################################
+
+DIAMOND, CIRCLE, SQUARE, PLUS, X_MARKER, POINT, PIXEL, ASTERISK = \
+ 'd', 'o', 's', '+', 'x', '.', ',', '*'
+
+H_LINE, V_LINE, HEART = '_', '|', u'\u2665'
+
+TICK_LEFT = "tickleft"
+TICK_RIGHT = "tickright"
+TICK_UP = "tickup"
+TICK_DOWN = "tickdown"
+CARET_LEFT = "caretleft"
+CARET_RIGHT = "caretright"
+CARET_UP = "caretup"
+CARET_DOWN = "caretdown"
+
+
+class _Points2D(object):
+ """Object rendering curve markers
+
+ :param xVboData: X coordinates VBO
+ :param yVboData: Y coordinates VBO
+ :param colorVboData: VBO of colors
+ :param str marker: Kind of symbol to use, see :attr:`MARKERS`.
+ :param List[float] color: RGBA color as 4 float in [0, 1]
+ :param float size: Marker size
+ :param List[float] offset: Translation of coordinates (ox, oy)
+ """
+
+ MARKERS = (DIAMOND, CIRCLE, SQUARE, PLUS, X_MARKER, POINT, PIXEL, ASTERISK,
+ H_LINE, V_LINE, HEART, TICK_LEFT, TICK_RIGHT, TICK_UP, TICK_DOWN,
+ CARET_LEFT, CARET_RIGHT, CARET_UP, CARET_DOWN)
+ """List of supported markers"""
+
+ _VERTEX_SHADER = """
+ #version 120
+
+ uniform mat4 matrix;
+ uniform int transform;
+ uniform float size;
+ attribute float xPos;
+ attribute float yPos;
+ attribute vec4 color;
+
+ varying vec4 vColor;
+
+ void main(void) {
+ gl_Position = matrix * vec4(xPos, yPos, 0., 1.);
+ vColor = color;
+ gl_PointSize = size;
+ }
+ """
+
+ _FRAGMENT_SHADER_SYMBOLS = {
+ DIAMOND: """
+ float alphaSymbol(vec2 coord, float size) {
+ vec2 centerCoord = abs(coord - vec2(0.5, 0.5));
+ float f = centerCoord.x + centerCoord.y;
+ return clamp(size * (0.5 - f), 0.0, 1.0);
+ }
+ """,
+ CIRCLE: """
+ float alphaSymbol(vec2 coord, float size) {
+ float radius = 0.5;
+ float r = distance(coord, vec2(0.5, 0.5));
+ return clamp(size * (radius - r), 0.0, 1.0);
+ }
+ """,
+ SQUARE: """
+ float alphaSymbol(vec2 coord, float size) {
+ return 1.0;
+ }
+ """,
+ PLUS: """
+ float alphaSymbol(vec2 coord, float size) {
+ vec2 d = abs(size * (coord - vec2(0.5, 0.5)));
+ if (min(d.x, d.y) < 0.5) {
+ return 1.0;
+ } else {
+ return 0.0;
+ }
+ }
+ """,
+ X_MARKER: """
+ float alphaSymbol(vec2 coord, float size) {
+ vec2 pos = floor(size * coord) + 0.5;
+ vec2 d_x = abs(pos.x + vec2(- pos.y, pos.y - size));
+ if (min(d_x.x, d_x.y) <= 0.5) {
+ return 1.0;
+ } else {
+ return 0.0;
+ }
+ }
+ """,
+ ASTERISK: """
+ float alphaSymbol(vec2 coord, float size) {
+ /* Combining +, x and circle */
+ vec2 d_plus = abs(size * (coord - vec2(0.5, 0.5)));
+ vec2 pos = floor(size * coord) + 0.5;
+ vec2 d_x = abs(pos.x + vec2(- pos.y, pos.y - size));
+ if (min(d_plus.x, d_plus.y) < 0.5) {
+ return 1.0;
+ } else if (min(d_x.x, d_x.y) <= 0.5) {
+ float r = distance(coord, vec2(0.5, 0.5));
+ return clamp(size * (0.5 - r), 0.0, 1.0);
+ } else {
+ return 0.0;
+ }
+ }
+ """,
+ H_LINE: """
+ float alphaSymbol(vec2 coord, float size) {
+ float dy = abs(size * (coord.y - 0.5));
+ if (dy < 0.5) {
+ return 1.0;
+ } else {
+ return 0.0;
+ }
+ }
+ """,
+ V_LINE: """
+ float alphaSymbol(vec2 coord, float size) {
+ float dx = abs(size * (coord.x - 0.5));
+ if (dx < 0.5) {
+ return 1.0;
+ } else {
+ return 0.0;
+ }
+ }
+ """,
+ HEART: """
+ float alphaSymbol(vec2 coord, float size) {
+ coord = (coord - 0.5) * 2.;
+ coord *= 0.75;
+ coord.y += 0.25;
+ float a = atan(coord.x,-coord.y)/3.141593;
+ float r = length(coord);
+ float h = abs(a);
+ float d = (13.0*h - 22.0*h*h + 10.0*h*h*h)/(6.0-5.0*h);
+ float res = clamp(r-d, 0., 1.);
+ // antialiasing
+ res = smoothstep(0.1, 0.001, res);
+ return res;
+ }
+ """,
+ TICK_LEFT: """
+ float alphaSymbol(vec2 coord, float size) {
+ coord = size * (coord - 0.5);
+ float dy = abs(coord.y);
+ if (dy < 0.5 && coord.x < 0.5) {
+ return 1.0;
+ } else {
+ return 0.0;
+ }
+ }
+ """,
+ TICK_RIGHT: """
+ float alphaSymbol(vec2 coord, float size) {
+ coord = size * (coord - 0.5);
+ float dy = abs(coord.y);
+ if (dy < 0.5 && coord.x > -0.5) {
+ return 1.0;
+ } else {
+ return 0.0;
+ }
+ }
+ """,
+ TICK_UP: """
+ float alphaSymbol(vec2 coord, float size) {
+ coord = size * (coord - 0.5);
+ float dx = abs(coord.x);
+ if (dx < 0.5 && coord.y < 0.5) {
+ return 1.0;
+ } else {
+ return 0.0;
+ }
+ }
+ """,
+ TICK_DOWN: """
+ float alphaSymbol(vec2 coord, float size) {
+ coord = size * (coord - 0.5);
+ float dx = abs(coord.x);
+ if (dx < 0.5 && coord.y > -0.5) {
+ return 1.0;
+ } else {
+ return 0.0;
+ }
+ }
+ """,
+ CARET_LEFT: """
+ float alphaSymbol(vec2 coord, float size) {
+ coord = size * (coord - 0.5);
+ float d = abs(coord.x) - abs(coord.y);
+ if (d >= -0.1 && coord.x > 0.5) {
+ return smoothstep(-0.1, 0.1, d);
+ } else {
+ return 0.0;
+ }
+ }
+ """,
+ CARET_RIGHT: """
+ float alphaSymbol(vec2 coord, float size) {
+ coord = size * (coord - 0.5);
+ float d = abs(coord.x) - abs(coord.y);
+ if (d >= -0.1 && coord.x < 0.5) {
+ return smoothstep(-0.1, 0.1, d);
+ } else {
+ return 0.0;
+ }
+ }
+ """,
+ CARET_UP: """
+ float alphaSymbol(vec2 coord, float size) {
+ coord = size * (coord - 0.5);
+ float d = abs(coord.y) - abs(coord.x);
+ if (d >= -0.1 && coord.y > 0.5) {
+ return smoothstep(-0.1, 0.1, d);
+ } else {
+ return 0.0;
+ }
+ }
+ """,
+ CARET_DOWN: """
+ float alphaSymbol(vec2 coord, float size) {
+ coord = size * (coord - 0.5);
+ float d = abs(coord.y) - abs(coord.x);
+ if (d >= -0.1 && coord.y < 0.5) {
+ return smoothstep(-0.1, 0.1, d);
+ } else {
+ return 0.0;
+ }
+ }
+ """,
+ }
+
+ _FRAGMENT_SHADER_TEMPLATE = """
+ #version 120
+
+ uniform float size;
+
+ varying vec4 vColor;
+
+ %s
+
+ void main(void) {
+ float alpha = alphaSymbol(gl_PointCoord, size);
+ if (alpha <= 0.0) {
+ discard;
+ } else {
+ gl_FragColor = vec4(vColor.rgb, alpha * clamp(vColor.a, 0.0, 1.0));
+ }
+ }
+ """
+
+ _PROGRAMS = {}
+
+ def __init__(self, xVboData=None, yVboData=None, colorVboData=None,
+ marker=SQUARE, color=(0., 0., 0., 1.), size=7,
+ offset=(0., 0.)):
+ self.color = color
+ self._marker = None
+ self.marker = marker
+ self.size = size
+ self.offset = offset
+
+ self.xVboData = xVboData
+ self.yVboData = yVboData
+ self.colorVboData = colorVboData
+ self.useColorVboData = colorVboData is not None
+
+ @property
+ def marker(self):
+ """Symbol used to display markers (str)"""
+ return self._marker
+
+ @marker.setter
+ def marker(self, marker):
+ if marker in _MPL_NONES:
+ self._marker = None
+ else:
+ assert marker in self.MARKERS
+ self._marker = marker
+
+ @classmethod
+ def _getProgram(cls, marker):
+ """On-demand shader program creation."""
+ if marker == PIXEL:
+ marker = SQUARE
+ elif marker == POINT:
+ marker = CIRCLE
+
+ if marker not in cls._PROGRAMS:
+ cls._PROGRAMS[marker] = Program(
+ vertexShader=cls._VERTEX_SHADER,
+ fragmentShader=(cls._FRAGMENT_SHADER_TEMPLATE %
+ cls._FRAGMENT_SHADER_SYMBOLS[marker]),
+ attrib0='xPos')
+
+ return cls._PROGRAMS[marker]
+
+ @classmethod
+ def init(cls):
+ """OpenGL context initialization"""
+ version = gl.glGetString(gl.GL_VERSION)
+ majorVersion = int(version[0])
+ assert majorVersion >= 2
+ gl.glEnable(gl.GL_VERTEX_PROGRAM_POINT_SIZE) # OpenGL 2
+ gl.glEnable(gl.GL_POINT_SPRITE) # OpenGL 2
+ if majorVersion >= 3: # OpenGL 3
+ gl.glEnable(gl.GL_PROGRAM_POINT_SIZE)
+
+ def render(self, context):
+ """Perform rendering
+
+ :param RenderContext context:
+ """
+ if self.marker is None:
+ return
+
+ program = self._getProgram(self.marker)
+ program.use()
+
+ matrix = numpy.dot(context.matrix,
+ mat4Translate(*self.offset)).astype(numpy.float32)
+ gl.glUniformMatrix4fv(program.uniforms['matrix'], 1, gl.GL_TRUE, matrix)
+
+ if self.marker == PIXEL:
+ size = 1
+ elif self.marker == POINT:
+ size = math.ceil(0.5 * self.size) + 1 # Mimic Matplotlib point
+ else:
+ size = self.size
+ size = size / 72. * context.dpi
+
+ if self.marker in (PLUS, H_LINE, V_LINE,
+ TICK_LEFT, TICK_RIGHT, TICK_UP, TICK_DOWN):
+ # Convert to nearest odd number
+ size = size // 2 * 2 + 1.
+
+ gl.glUniform1f(program.uniforms['size'], size)
+ # gl.glPointSize(self.size)
+
+ cAttrib = program.attributes['color']
+ if self.useColorVboData and self.colorVboData is not None:
+ gl.glEnableVertexAttribArray(cAttrib)
+ self.colorVboData.setVertexAttrib(cAttrib)
+ else:
+ gl.glDisableVertexAttribArray(cAttrib)
+ gl.glVertexAttrib4f(cAttrib, *self.color)
+
+ xAttrib = program.attributes['xPos']
+ gl.glEnableVertexAttribArray(xAttrib)
+ self.xVboData.setVertexAttrib(xAttrib)
+
+ yAttrib = program.attributes['yPos']
+ gl.glEnableVertexAttribArray(yAttrib)
+ self.yVboData.setVertexAttrib(yAttrib)
+
+ gl.glDrawArrays(gl.GL_POINTS, 0, self.xVboData.size)
+
+ gl.glUseProgram(0)
+
+
+# error bars ##################################################################
+
+class _ErrorBars(object):
+ """Display errors bars.
+
+ This is using its own VBO as opposed to fill/points/lines.
+ There is no picking on error bars.
+
+ It uses 2 vertices per error bars and uses :class:`GLLines2D` to
+ render error bars and :class:`_Points2D` to render the ends.
+
+ :param numpy.ndarray xData: X coordinates of the data.
+ :param numpy.ndarray yData: Y coordinates of the data.
+ :param xError: The absolute error on the X axis.
+ :type xError: A float, or a numpy.ndarray of float32.
+ If it is an array, it can either be a 1D array of
+ same length as the data or a 2D array with 2 rows
+ of same length as the data: row 0 for negative errors,
+ row 1 for positive errors.
+ :param yError: The absolute error on the Y axis.
+ :type yError: A float, or a numpy.ndarray of float32. See xError.
+ :param float xMin: The min X value already computed by GLPlotCurve2D.
+ :param float yMin: The min Y value already computed by GLPlotCurve2D.
+ :param List[float] color: RGBA color as 4 float in [0, 1]
+ :param List[float] offset: Translation of coordinates (ox, oy)
+ """
+
+ def __init__(self, xData, yData, xError, yError,
+ xMin, yMin,
+ color=(0., 0., 0., 1.),
+ offset=(0., 0.)):
+ self._attribs = None
+ self._xMin, self._yMin = xMin, yMin
+ self.offset = offset
+
+ if xError is not None or yError is not None:
+ self._xData = numpy.array(
+ xData, order='C', dtype=numpy.float32, copy=False)
+ self._yData = numpy.array(
+ yData, order='C', dtype=numpy.float32, copy=False)
+
+ # This also works if xError, yError is a float/int
+ self._xError = numpy.array(
+ xError, order='C', dtype=numpy.float32, copy=False)
+ self._yError = numpy.array(
+ yError, order='C', dtype=numpy.float32, copy=False)
+ else:
+ self._xData, self._yData = None, None
+ self._xError, self._yError = None, None
+
+ self._lines = GLLines2D(
+ None, None, color=color, drawMode=gl.GL_LINES, offset=offset)
+ self._xErrPoints = _Points2D(
+ None, None, color=color, marker=V_LINE, offset=offset)
+ self._yErrPoints = _Points2D(
+ None, None, color=color, marker=H_LINE, offset=offset)
+
+ def _buildVertices(self):
+ """Generates error bars vertices"""
+ nbLinesPerDataPts = (0 if self._xError is None else 2) + \
+ (0 if self._yError is None else 2)
+
+ nbDataPts = len(self._xData)
+
+ # interleave coord+error, coord-error.
+ # xError vertices first if any, then yError vertices if any.
+ xCoords = numpy.empty(nbDataPts * nbLinesPerDataPts * 2,
+ dtype=numpy.float32)
+ yCoords = numpy.empty(nbDataPts * nbLinesPerDataPts * 2,
+ dtype=numpy.float32)
+
+ if self._xError is not None: # errors on the X axis
+ if len(self._xError.shape) == 2:
+ xErrorMinus, xErrorPlus = self._xError[0], self._xError[1]
+ else:
+ # numpy arrays of len 1 or len(xData)
+ xErrorMinus, xErrorPlus = self._xError, self._xError
+
+ # Interleave vertices for xError
+ endXError = 4 * nbDataPts
+ with numpy.errstate(invalid="ignore"):
+ xCoords[0:endXError-3:4] = self._xData + xErrorPlus
+ xCoords[1:endXError-2:4] = self._xData
+ xCoords[2:endXError-1:4] = self._xData
+ with numpy.errstate(invalid="ignore"):
+ xCoords[3:endXError:4] = self._xData - xErrorMinus
+
+ yCoords[0:endXError-3:4] = self._yData
+ yCoords[1:endXError-2:4] = self._yData
+ yCoords[2:endXError-1:4] = self._yData
+ yCoords[3:endXError:4] = self._yData
+
+ else:
+ endXError = 0
+
+ if self._yError is not None: # errors on the Y axis
+ if len(self._yError.shape) == 2:
+ yErrorMinus, yErrorPlus = self._yError[0], self._yError[1]
+ else:
+ # numpy arrays of len 1 or len(yData)
+ yErrorMinus, yErrorPlus = self._yError, self._yError
+
+ # Interleave vertices for yError
+ xCoords[endXError::4] = self._xData
+ xCoords[endXError+1::4] = self._xData
+ xCoords[endXError+2::4] = self._xData
+ xCoords[endXError+3::4] = self._xData
+
+ with numpy.errstate(invalid="ignore"):
+ yCoords[endXError::4] = self._yData + yErrorPlus
+ yCoords[endXError+1::4] = self._yData
+ yCoords[endXError+2::4] = self._yData
+ with numpy.errstate(invalid="ignore"):
+ yCoords[endXError+3::4] = self._yData - yErrorMinus
+
+ return xCoords, yCoords
+
+ def prepare(self):
+ """Rendering preparation: build indices and bounding box vertices"""
+ if self._xData is None:
+ return
+
+ if self._attribs is None:
+ xCoords, yCoords = self._buildVertices()
+
+ xAttrib, yAttrib = vertexBuffer((xCoords, yCoords))
+ self._attribs = xAttrib, yAttrib
+
+ self._lines.xVboData = xAttrib
+ self._lines.yVboData = yAttrib
+
+ # Set xError points using the same VBO as lines
+ self._xErrPoints.xVboData = xAttrib.copy()
+ self._xErrPoints.xVboData.size //= 2
+ self._xErrPoints.yVboData = yAttrib.copy()
+ self._xErrPoints.yVboData.size //= 2
+
+ # Set yError points using the same VBO as lines
+ self._yErrPoints.xVboData = xAttrib.copy()
+ self._yErrPoints.xVboData.size //= 2
+ self._yErrPoints.xVboData.offset += (xAttrib.itemsize *
+ xAttrib.size // 2)
+ self._yErrPoints.yVboData = yAttrib.copy()
+ self._yErrPoints.yVboData.size //= 2
+ self._yErrPoints.yVboData.offset += (yAttrib.itemsize *
+ yAttrib.size // 2)
+
+ def render(self, context):
+ """Perform rendering
+
+ :param RenderContext context:
+ """
+ self.prepare()
+
+ if self._attribs is not None:
+ self._lines.render(context)
+ self._xErrPoints.render(context)
+ self._yErrPoints.render(context)
+
+ def discard(self):
+ """Release VBOs"""
+ if self.isInitialized():
+ self._lines.xVboData, self._lines.yVboData = None, None
+ self._xErrPoints.xVboData, self._xErrPoints.yVboData = None, None
+ self._yErrPoints.xVboData, self._yErrPoints.yVboData = None, None
+ self._attribs[0].vbo.discard()
+ self._attribs = None
+
+ def isInitialized(self):
+ return self._attribs is not None
+
+
+# curves ######################################################################
+
+def _proxyProperty(*componentsAttributes):
+ """Create a property to access an attribute of attribute(s).
+ Useful for composition.
+ Supports multiple components this way:
+ getter returns the first found, setter sets all
+ """
+ def getter(self):
+ for compName, attrName in componentsAttributes:
+ try:
+ component = getattr(self, compName)
+ except AttributeError:
+ pass
+ else:
+ return getattr(component, attrName)
+
+ def setter(self, value):
+ for compName, attrName in componentsAttributes:
+ component = getattr(self, compName)
+ setattr(component, attrName, value)
+ return property(getter, setter)
+
+
+class GLPlotCurve2D(GLPlotItem):
+ def __init__(self, xData, yData, colorData=None,
+ xError=None, yError=None,
+ lineStyle=SOLID,
+ lineColor=(0., 0., 0., 1.),
+ lineWidth=1,
+ lineDashPeriod=20,
+ marker=SQUARE,
+ markerColor=(0., 0., 0., 1.),
+ markerSize=7,
+ fillColor=None,
+ baseline=None,
+ isYLog=False):
+ super().__init__()
+ self.colorData = colorData
+
+ # Compute x bounds
+ if xError is None:
+ self.xMin, self.xMax = min_max(xData, min_positive=False)
+ else:
+ # Takes the error into account
+ if hasattr(xError, 'shape') and len(xError.shape) == 2:
+ xErrorMinus, xErrorPlus = xError[0], xError[1]
+ else:
+ xErrorMinus, xErrorPlus = xError, xError
+ self.xMin = numpy.nanmin(xData - xErrorMinus)
+ self.xMax = numpy.nanmax(xData + xErrorPlus)
+
+ # Compute y bounds
+ if yError is None:
+ self.yMin, self.yMax = min_max(yData, min_positive=False)
+ else:
+ # Takes the error into account
+ if hasattr(yError, 'shape') and len(yError.shape) == 2:
+ yErrorMinus, yErrorPlus = yError[0], yError[1]
+ else:
+ yErrorMinus, yErrorPlus = yError, yError
+ self.yMin = numpy.nanmin(yData - yErrorMinus)
+ self.yMax = numpy.nanmax(yData + yErrorPlus)
+
+ # Handle data offset
+ if xData.itemsize > 4 or yData.itemsize > 4: # Use normalization
+ # offset data, do not offset error as it is relative
+ self.offset = self.xMin, self.yMin
+ with numpy.errstate(invalid="ignore"):
+ self.xData = (xData - self.offset[0]).astype(numpy.float32)
+ self.yData = (yData - self.offset[1]).astype(numpy.float32)
+
+ else: # float32
+ self.offset = 0., 0.
+ self.xData = xData
+ self.yData = yData
+ if fillColor is not None:
+ def deduce_baseline(baseline):
+ if baseline is None:
+ _baseline = 0
+ else:
+ _baseline = baseline
+ if not isinstance(_baseline, numpy.ndarray):
+ _baseline = numpy.repeat(_baseline,
+ len(self.xData))
+ if isYLog is True:
+ with numpy.errstate(divide='ignore', invalid='ignore'):
+ log_val = numpy.log10(_baseline)
+ _baseline = numpy.where(_baseline>0.0, log_val, -38)
+ return _baseline
+
+ _baseline = deduce_baseline(baseline)
+
+ # Use different baseline depending of Y log scale
+ self.fill = _Fill2D(self.xData, self.yData,
+ baseline=_baseline,
+ color=fillColor,
+ offset=self.offset)
+ else:
+ self.fill = None
+
+ self._errorBars = _ErrorBars(self.xData, self.yData,
+ xError, yError,
+ self.xMin, self.yMin,
+ offset=self.offset)
+
+ self.lines = GLLines2D()
+ self.lines.style = lineStyle
+ self.lines.color = lineColor
+ self.lines.width = lineWidth
+ self.lines.dashPeriod = lineDashPeriod
+ self.lines.offset = self.offset
+
+ self.points = _Points2D()
+ self.points.marker = marker
+ self.points.color = markerColor
+ self.points.size = markerSize
+ self.points.offset = self.offset
+
+ xVboData = _proxyProperty(('lines', 'xVboData'), ('points', 'xVboData'))
+
+ yVboData = _proxyProperty(('lines', 'yVboData'), ('points', 'yVboData'))
+
+ colorVboData = _proxyProperty(('lines', 'colorVboData'),
+ ('points', 'colorVboData'))
+
+ useColorVboData = _proxyProperty(('lines', 'useColorVboData'),
+ ('points', 'useColorVboData'))
+
+ distVboData = _proxyProperty(('lines', 'distVboData'))
+
+ lineStyle = _proxyProperty(('lines', 'style'))
+
+ lineColor = _proxyProperty(('lines', 'color'))
+
+ lineWidth = _proxyProperty(('lines', 'width'))
+
+ lineDashPeriod = _proxyProperty(('lines', 'dashPeriod'))
+
+ marker = _proxyProperty(('points', 'marker'))
+
+ markerColor = _proxyProperty(('points', 'color'))
+
+ markerSize = _proxyProperty(('points', 'size'))
+
+ @classmethod
+ def init(cls):
+ """OpenGL context initialization"""
+ GLLines2D.init()
+ _Points2D.init()
+
+ def prepare(self):
+ """Rendering preparation: build indices and bounding box vertices"""
+ if self.xVboData is None:
+ xAttrib, yAttrib, cAttrib, dAttrib = None, None, None, None
+ if self.lineStyle in (DASHED, DASHDOT, DOTTED):
+ dists = distancesFromArrays(self.xData, self.yData)
+ if self.colorData is None:
+ xAttrib, yAttrib, dAttrib = vertexBuffer(
+ (self.xData, self.yData, dists))
+ else:
+ xAttrib, yAttrib, cAttrib, dAttrib = vertexBuffer(
+ (self.xData, self.yData, self.colorData, dists))
+ elif self.colorData is None:
+ xAttrib, yAttrib = vertexBuffer((self.xData, self.yData))
+ else:
+ xAttrib, yAttrib, cAttrib = vertexBuffer(
+ (self.xData, self.yData, self.colorData))
+
+ self.xVboData = xAttrib
+ self.yVboData = yAttrib
+ self.distVboData = dAttrib
+
+ if cAttrib is not None and self.colorData.dtype.kind == 'u':
+ cAttrib.normalization = True # Normalize uint to [0, 1]
+ self.colorVboData = cAttrib
+ self.useColorVboData = cAttrib is not None
+
+ def render(self, context):
+ """Perform rendering
+
+ :param RenderContext context: Rendering information
+ """
+ self.prepare()
+ if self.fill is not None:
+ self.fill.render(context)
+ self._errorBars.render(context)
+ self.lines.render(context)
+ self.points.render(context)
+
+ def discard(self):
+ """Release VBOs"""
+ if self.xVboData is not None:
+ self.xVboData.vbo.discard()
+
+ self.xVboData = None
+ self.yVboData = None
+ self.colorVboData = None
+ self.distVboData = None
+
+ self._errorBars.discard()
+ if self.fill is not None:
+ self.fill.discard()
+
+ def isInitialized(self):
+ return (self.xVboData is not None or
+ self._errorBars.isInitialized() or
+ (self.fill is not None and self.fill.isInitialized()))
+
+ def pick(self, xPickMin, yPickMin, xPickMax, yPickMax):
+ """Perform picking on the curve according to its rendering.
+
+ The picking area is [xPickMin, xPickMax], [yPickMin, yPickMax].
+
+ In case a segment between 2 points with indices i, i+1 is picked,
+ only its lower index end point (i.e., i) is added to the result.
+ In case an end point with index i is picked it is added to the result,
+ and the segment [i-1, i] is not tested for picking.
+
+ :return: The indices of the picked data
+ :rtype: Union[List[int],None]
+ """
+ if (self.marker is None and self.lineStyle is None) or \
+ self.xMin > xPickMax or xPickMin > self.xMax or \
+ self.yMin > yPickMax or yPickMin > self.yMax:
+ return None
+
+ # offset picking bounds
+ xPickMin = xPickMin - self.offset[0]
+ xPickMax = xPickMax - self.offset[0]
+ yPickMin = yPickMin - self.offset[1]
+ yPickMax = yPickMax - self.offset[1]
+
+ if self.lineStyle is not None:
+ # Using Cohen-Sutherland algorithm for line clipping
+ with numpy.errstate(invalid='ignore'): # Ignore NaN comparison warnings
+ codes = ((self.yData > yPickMax) << 3) | \
+ ((self.yData < yPickMin) << 2) | \
+ ((self.xData > xPickMax) << 1) | \
+ (self.xData < xPickMin)
+
+ notNaN = numpy.logical_not(numpy.logical_or(
+ numpy.isnan(self.xData), numpy.isnan(self.yData)))
+
+ # Add all points that are inside the picking area
+ indices = numpy.nonzero(
+ numpy.logical_and(codes == 0, notNaN))[0].tolist()
+
+ # Segment that might cross the area with no end point inside it
+ segToTestIdx = numpy.nonzero((codes[:-1] != 0) &
+ (codes[1:] != 0) &
+ ((codes[:-1] & codes[1:]) == 0))[0]
+
+ TOP, BOTTOM, RIGHT, LEFT = (1 << 3), (1 << 2), (1 << 1), (1 << 0)
+
+ for index in segToTestIdx:
+ if index not in indices:
+ x0, y0 = self.xData[index], self.yData[index]
+ x1, y1 = self.xData[index + 1], self.yData[index + 1]
+ code1 = codes[index + 1]
+
+ # check for crossing with horizontal bounds
+ # y0 == y1 is a never event:
+ # => pt0 and pt1 in same vertical area are not in segToTest
+ if code1 & TOP:
+ x = x0 + (x1 - x0) * (yPickMax - y0) / (y1 - y0)
+ elif code1 & BOTTOM:
+ x = x0 + (x1 - x0) * (yPickMin - y0) / (y1 - y0)
+ else:
+ x = None # No horizontal bounds intersection test
+
+ if x is not None and xPickMin <= x <= xPickMax:
+ # Intersection
+ indices.append(index)
+
+ else:
+ # check for crossing with vertical bounds
+ # x0 == x1 is a never event (see remark for y)
+ if code1 & RIGHT:
+ y = y0 + (y1 - y0) * (xPickMax - x0) / (x1 - x0)
+ elif code1 & LEFT:
+ y = y0 + (y1 - y0) * (xPickMin - x0) / (x1 - x0)
+ else:
+ y = None # No vertical bounds intersection test
+
+ if y is not None and yPickMin <= y <= yPickMax:
+ # Intersection
+ indices.append(index)
+
+ indices.sort()
+
+ else:
+ with numpy.errstate(invalid='ignore'): # Ignore NaN comparison warnings
+ indices = numpy.nonzero((self.xData >= xPickMin) &
+ (self.xData <= xPickMax) &
+ (self.yData >= yPickMin) &
+ (self.yData <= yPickMax))[0].tolist()
+
+ return tuple(indices) if len(indices) > 0 else None
diff --git a/src/silx/gui/plot/backends/glutils/GLPlotFrame.py b/src/silx/gui/plot/backends/glutils/GLPlotFrame.py
new file mode 100644
index 0000000..1fccb02
--- /dev/null
+++ b/src/silx/gui/plot/backends/glutils/GLPlotFrame.py
@@ -0,0 +1,1210 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2014-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""
+This modules provides the rendering of plot titles, axes and grid.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "03/04/2017"
+
+
+# TODO
+# keep aspect ratio managed here?
+# smarter dirty flag handling?
+
+import datetime as dt
+import math
+import weakref
+import logging
+from collections import namedtuple
+
+import numpy
+
+from ...._glutils import gl, Program
+from ..._utils import checkAxisLimits, FLOAT32_MINPOS
+from .GLSupport import mat4Ortho
+from .GLText import Text2D, CENTER, BOTTOM, TOP, LEFT, RIGHT, ROTATE_270
+from ..._utils.ticklayout import niceNumbersAdaptative, niceNumbersForLog10
+from ..._utils.dtime_ticklayout import calcTicksAdaptive, bestFormatString
+from ..._utils.dtime_ticklayout import timestamp
+
+_logger = logging.getLogger(__name__)
+
+
+# PlotAxis ####################################################################
+
+class PlotAxis(object):
+ """Represents a 1D axis of the plot.
+ This class is intended to be used with :class:`GLPlotFrame`.
+ """
+
+ def __init__(self, plotFrame,
+ tickLength=(0., 0.),
+ foregroundColor=(0., 0., 0., 1.0),
+ labelAlign=CENTER, labelVAlign=CENTER,
+ titleAlign=CENTER, titleVAlign=CENTER,
+ titleRotate=0, titleOffset=(0., 0.)):
+ self._ticks = None
+
+ self._plotFrameRef = weakref.ref(plotFrame)
+
+ self._isDateTime = False
+ self._timeZone = None
+ self._isLog = False
+ self._dataRange = 1., 100.
+ self._displayCoords = (0., 0.), (1., 0.)
+ self._title = ''
+
+ self._tickLength = tickLength
+ self._foregroundColor = foregroundColor
+ self._labelAlign = labelAlign
+ self._labelVAlign = labelVAlign
+ self._titleAlign = titleAlign
+ self._titleVAlign = titleVAlign
+ self._titleRotate = titleRotate
+ self._titleOffset = titleOffset
+
+ @property
+ def dataRange(self):
+ """The range of the data represented on the axis as a tuple
+ of 2 floats: (min, max)."""
+ return self._dataRange
+
+ @dataRange.setter
+ def dataRange(self, dataRange):
+ assert len(dataRange) == 2
+ assert dataRange[0] <= dataRange[1]
+ dataRange = float(dataRange[0]), float(dataRange[1])
+
+ if dataRange != self._dataRange:
+ self._dataRange = dataRange
+ self._dirtyTicks()
+
+ @property
+ def isLog(self):
+ """Whether the axis is using a log10 scale or not as a bool."""
+ return self._isLog
+
+ @isLog.setter
+ def isLog(self, isLog):
+ isLog = bool(isLog)
+ if isLog != self._isLog:
+ self._isLog = isLog
+ self._dirtyTicks()
+
+ @property
+ def timeZone(self):
+ """Returnss datetime.tzinfo that is used if this axis plots date times."""
+ return self._timeZone
+
+ @timeZone.setter
+ def timeZone(self, tz):
+ """Sets dateetime.tzinfo that is used if this axis plots date times."""
+ self._timeZone = tz
+ self._dirtyTicks()
+
+ @property
+ def isTimeSeries(self):
+ """Whether the axis is showing floats as datetime objects"""
+ return self._isDateTime
+
+ @isTimeSeries.setter
+ def isTimeSeries(self, isTimeSeries):
+ isTimeSeries = bool(isTimeSeries)
+ if isTimeSeries != self._isDateTime:
+ self._isDateTime = isTimeSeries
+ self._dirtyTicks()
+
+ @property
+ def displayCoords(self):
+ """The coordinates of the start and end points of the axis
+ in display space (i.e., in pixels) as a tuple of 2 tuples of
+ 2 floats: ((x0, y0), (x1, y1)).
+ """
+ return self._displayCoords
+
+ @displayCoords.setter
+ def displayCoords(self, displayCoords):
+ assert len(displayCoords) == 2
+ assert len(displayCoords[0]) == 2
+ assert len(displayCoords[1]) == 2
+ displayCoords = tuple(displayCoords[0]), tuple(displayCoords[1])
+ if displayCoords != self._displayCoords:
+ self._displayCoords = displayCoords
+ self._dirtyTicks()
+
+ @property
+ def devicePixelRatio(self):
+ """Returns the ratio between qt pixels and device pixels."""
+ plotFrame = self._plotFrameRef()
+ return plotFrame.devicePixelRatio if plotFrame is not None else 1.
+
+ @property
+ def title(self):
+ """The text label associated with this axis as a str in latin-1."""
+ return self._title
+
+ @title.setter
+ def title(self, title):
+ if title != self._title:
+ self._title = title
+ self._dirtyPlotFrame()
+
+ @property
+ def titleOffset(self):
+ """Title offset in pixels (x: int, y: int)"""
+ return self._titleOffset
+
+ @titleOffset.setter
+ def titleOffset(self, offset):
+ if offset != self._titleOffset:
+ self._titleOffset = offset
+ self._dirtyTicks()
+
+ @property
+ def foregroundColor(self):
+ """Color used for frame and labels"""
+ return self._foregroundColor
+
+ @foregroundColor.setter
+ def foregroundColor(self, color):
+ """Color used for frame and labels"""
+ assert len(color) == 4, \
+ "foregroundColor must have length 4, got {}".format(len(self._foregroundColor))
+ if self._foregroundColor != color:
+ self._foregroundColor = color
+ self._dirtyTicks()
+
+ @property
+ def ticks(self):
+ """Ticks as tuples: ((x, y) in display, dataPos, textLabel)."""
+ if self._ticks is None:
+ self._ticks = tuple(self._ticksGenerator())
+ return self._ticks
+
+ def getVerticesAndLabels(self):
+ """Create the list of vertices for axis and associated text labels.
+
+ :returns: A tuple: List of 2D line vertices, List of Text2D labels.
+ """
+ vertices = list(self.displayCoords) # Add start and end points
+ labels = []
+ tickLabelsSize = [0., 0.]
+
+ xTickLength, yTickLength = self._tickLength
+ xTickLength *= self.devicePixelRatio
+ yTickLength *= self.devicePixelRatio
+ for (xPixel, yPixel), dataPos, text in self.ticks:
+ if text is None:
+ tickScale = 0.5
+ else:
+ tickScale = 1.
+
+ label = Text2D(text=text,
+ color=self._foregroundColor,
+ x=xPixel - xTickLength,
+ y=yPixel - yTickLength,
+ align=self._labelAlign,
+ valign=self._labelVAlign,
+ devicePixelRatio=self.devicePixelRatio)
+
+ width, height = label.size
+ if width > tickLabelsSize[0]:
+ tickLabelsSize[0] = width
+ if height > tickLabelsSize[1]:
+ tickLabelsSize[1] = height
+
+ labels.append(label)
+
+ vertices.append((xPixel, yPixel))
+ vertices.append((xPixel + tickScale * xTickLength,
+ yPixel + tickScale * yTickLength))
+
+ (x0, y0), (x1, y1) = self.displayCoords
+ xAxisCenter = 0.5 * (x0 + x1)
+ yAxisCenter = 0.5 * (y0 + y1)
+
+ xOffset, yOffset = self.titleOffset
+
+ # Adaptative title positioning:
+ # tickNorm = math.sqrt(xTickLength ** 2 + yTickLength ** 2)
+ # xOffset = -tickLabelsSize[0] * xTickLength / tickNorm
+ # xOffset -= 3 * xTickLength
+ # yOffset = -tickLabelsSize[1] * yTickLength / tickNorm
+ # yOffset -= 3 * yTickLength
+
+ axisTitle = Text2D(text=self.title,
+ color=self._foregroundColor,
+ x=xAxisCenter + xOffset,
+ y=yAxisCenter + yOffset,
+ align=self._titleAlign,
+ valign=self._titleVAlign,
+ rotate=self._titleRotate,
+ devicePixelRatio=self.devicePixelRatio)
+ labels.append(axisTitle)
+
+ return vertices, labels
+
+ def _dirtyPlotFrame(self):
+ """Dirty parent GLPlotFrame"""
+ plotFrame = self._plotFrameRef()
+ if plotFrame is not None:
+ plotFrame._dirty()
+
+ def _dirtyTicks(self):
+ """Mark ticks as dirty and notify listener (i.e., background)."""
+ self._ticks = None
+ self._dirtyPlotFrame()
+
+ @staticmethod
+ def _frange(start, stop, step):
+ """range for float (including stop)."""
+ while start <= stop:
+ yield start
+ start += step
+
+ def _ticksGenerator(self):
+ """Generator of ticks as tuples:
+ ((x, y) in display, dataPos, textLabel).
+ """
+ dataMin, dataMax = self.dataRange
+ if self.isLog and dataMin <= 0.:
+ _logger.warning(
+ 'Getting ticks while isLog=True and dataRange[0]<=0.')
+ dataMin = 1.
+ if dataMax < dataMin:
+ dataMax = 1.
+
+ if dataMin != dataMax: # data range is not null
+ (x0, y0), (x1, y1) = self.displayCoords
+
+ if self.isLog:
+
+ if self.isTimeSeries:
+ _logger.warning("Time series not implemented for log-scale")
+
+ logMin, logMax = math.log10(dataMin), math.log10(dataMax)
+ tickMin, tickMax, step, _ = niceNumbersForLog10(logMin, logMax)
+
+ xScale = (x1 - x0) / (logMax - logMin)
+ yScale = (y1 - y0) / (logMax - logMin)
+
+ for logPos in self._frange(tickMin, tickMax, step):
+ if logMin <= logPos <= logMax:
+ dataPos = 10 ** logPos
+ xPixel = x0 + (logPos - logMin) * xScale
+ yPixel = y0 + (logPos - logMin) * yScale
+ text = '1e%+03d' % logPos
+ yield ((xPixel, yPixel), dataPos, text)
+
+ if step == 1:
+ ticks = list(self._frange(tickMin, tickMax, step))[:-1]
+ for logPos in ticks:
+ dataOrigPos = 10 ** logPos
+ for index in range(2, 10):
+ dataPos = dataOrigPos * index
+ if dataMin <= dataPos <= dataMax:
+ logSubPos = math.log10(dataPos)
+ xPixel = x0 + (logSubPos - logMin) * xScale
+ yPixel = y0 + (logSubPos - logMin) * yScale
+ yield ((xPixel, yPixel), dataPos, None)
+
+ else:
+ xScale = (x1 - x0) / (dataMax - dataMin)
+ yScale = (y1 - y0) / (dataMax - dataMin)
+
+ nbPixels = math.sqrt(pow(x1 - x0, 2) + pow(y1 - y0, 2)) / self.devicePixelRatio
+
+ # Density of 1.3 label per 92 pixels
+ # i.e., 1.3 label per inch on a 92 dpi screen
+ tickDensity = 1.3 / 92
+
+ if not self.isTimeSeries:
+ tickMin, tickMax, step, nbFrac = niceNumbersAdaptative(
+ dataMin, dataMax, nbPixels, tickDensity)
+
+ for dataPos in self._frange(tickMin, tickMax, step):
+ if dataMin <= dataPos <= dataMax:
+ xPixel = x0 + (dataPos - dataMin) * xScale
+ yPixel = y0 + (dataPos - dataMin) * yScale
+
+ if nbFrac == 0:
+ text = '%g' % dataPos
+ else:
+ text = ('%.' + str(nbFrac) + 'f') % dataPos
+ yield ((xPixel, yPixel), dataPos, text)
+ else:
+ # Time series
+ dtMin = dt.datetime.fromtimestamp(dataMin, tz=self.timeZone)
+ dtMax = dt.datetime.fromtimestamp(dataMax, tz=self.timeZone)
+
+ tickDateTimes, spacing, unit = calcTicksAdaptive(
+ dtMin, dtMax, nbPixels, tickDensity)
+
+ for tickDateTime in tickDateTimes:
+ if dtMin <= tickDateTime <= dtMax:
+
+ dataPos = timestamp(tickDateTime)
+ xPixel = x0 + (dataPos - dataMin) * xScale
+ yPixel = y0 + (dataPos - dataMin) * yScale
+
+ fmtStr = bestFormatString(spacing, unit)
+ text = tickDateTime.strftime(fmtStr)
+
+ yield ((xPixel, yPixel), dataPos, text)
+
+
+# GLPlotFrame #################################################################
+
+class GLPlotFrame(object):
+ """Base class for rendering a 2D frame surrounded by axes."""
+
+ _TICK_LENGTH_IN_PIXELS = 5
+ _LINE_WIDTH = 1
+
+ _SHADERS = {
+ 'vertex': """
+ attribute vec2 position;
+ uniform mat4 matrix;
+
+ void main(void) {
+ gl_Position = matrix * vec4(position, 0.0, 1.0);
+ }
+ """,
+ 'fragment': """
+ uniform vec4 color;
+ uniform float tickFactor; /* = 1./tickLength or 0. for solid line */
+
+ void main(void) {
+ if (mod(tickFactor * (gl_FragCoord.x + gl_FragCoord.y), 2.) < 1.) {
+ gl_FragColor = color;
+ } else {
+ discard;
+ }
+ }
+ """
+ }
+
+ _Margins = namedtuple('Margins', ('left', 'right', 'top', 'bottom'))
+
+ # Margins used when plot frame is not displayed
+ _NoDisplayMargins = _Margins(0, 0, 0, 0)
+
+ def __init__(self, marginRatios, foregroundColor, gridColor):
+ """
+ :param List[float] marginRatios:
+ The ratios of margins around plot area for axis and labels.
+ (left, top, right, bottom) as float in [0., 1.]
+ :param foregroundColor: color used for the frame and labels.
+ :type foregroundColor: tuple with RGBA values ranging from 0.0 to 1.0
+ :param gridColor: color used for grid lines.
+ :type gridColor: tuple RGBA with RGBA values ranging from 0.0 to 1.0
+ """
+ self._renderResources = None
+
+ self.__marginRatios = marginRatios
+ self.__marginsCache = None
+
+ self._foregroundColor = foregroundColor
+ self._gridColor = gridColor
+
+ self.axes = [] # List of PlotAxis to be updated by subclasses
+
+ self._grid = False
+ self._size = 0., 0.
+ self._title = ''
+
+ self._devicePixelRatio = 1.
+
+ @property
+ def isDirty(self):
+ """True if it need to refresh graphic rendering, False otherwise."""
+ return self._renderResources is None
+
+ GRID_NONE = 0
+ GRID_MAIN_TICKS = 1
+ GRID_SUB_TICKS = 2
+ GRID_ALL_TICKS = (GRID_MAIN_TICKS + GRID_SUB_TICKS)
+
+ @property
+ def foregroundColor(self):
+ """Color used for frame and labels"""
+ return self._foregroundColor
+
+ @foregroundColor.setter
+ def foregroundColor(self, color):
+ """Color used for frame and labels"""
+ assert len(color) == 4, \
+ "foregroundColor must have length 4, got {}".format(len(self._foregroundColor))
+ if self._foregroundColor != color:
+ self._foregroundColor = color
+ for axis in self.axes:
+ axis.foregroundColor = color
+ self._dirty()
+
+ @property
+ def gridColor(self):
+ """Color used for frame and labels"""
+ return self._gridColor
+
+ @gridColor.setter
+ def gridColor(self, color):
+ """Color used for frame and labels"""
+ assert len(color) == 4, \
+ "gridColor must have length 4, got {}".format(len(self._gridColor))
+ if self._gridColor != color:
+ self._gridColor = color
+ self._dirty()
+
+ @property
+ def marginRatios(self):
+ """Plot margin ratios: (left, top, right, bottom) as 4 float in [0, 1].
+ """
+ return self.__marginRatios
+
+ @marginRatios.setter
+ def marginRatios(self, ratios):
+ ratios = tuple(float(v) for v in ratios)
+ assert len(ratios) == 4
+ for value in ratios:
+ assert 0. <= value <= 1.
+ assert ratios[0] + ratios[2] < 1.
+ assert ratios[1] + ratios[3] < 1.
+
+ if self.__marginRatios != ratios:
+ self.__marginRatios = ratios
+ self.__marginsCache = None # Clear cached margins
+ self._dirty()
+
+ @property
+ def margins(self):
+ """Margins in pixels around the plot."""
+ if self.__marginsCache is None:
+ width, height = self.size
+ left, top, right, bottom = self.marginRatios
+ self.__marginsCache = self._Margins(
+ left=int(left*width),
+ right=int(right*width),
+ top=int(top*height),
+ bottom=int(bottom*height))
+ return self.__marginsCache
+
+ @property
+ def devicePixelRatio(self):
+ return self._devicePixelRatio
+
+ @devicePixelRatio.setter
+ def devicePixelRatio(self, ratio):
+ if ratio != self._devicePixelRatio:
+ self._devicePixelRatio = ratio
+ self._dirty()
+
+ @property
+ def grid(self):
+ """Grid display mode:
+ - 0: No grid.
+ - 1: Grid on main ticks.
+ - 2: Grid on sub-ticks for log scale axes.
+ - 3: Grid on main and sub ticks."""
+ return self._grid
+
+ @grid.setter
+ def grid(self, grid):
+ assert grid in (self.GRID_NONE, self.GRID_MAIN_TICKS,
+ self.GRID_SUB_TICKS, self.GRID_ALL_TICKS)
+ if grid != self._grid:
+ self._grid = grid
+ self._dirty()
+
+ @property
+ def size(self):
+ """Size in device pixels of the plot area including margins."""
+ return self._size
+
+ @size.setter
+ def size(self, size):
+ assert len(size) == 2
+ size = tuple(size)
+ if size != self._size:
+ self._size = size
+ self.__marginsCache = None # Clear cached margins
+ self._dirty()
+
+ @property
+ def plotOrigin(self):
+ """Plot area origin (left, top) in widget coordinates in pixels."""
+ return self.margins.left, self.margins.top
+
+ @property
+ def plotSize(self):
+ """Plot area size (width, height) in pixels."""
+ w, h = self.size
+ w -= self.margins.left + self.margins.right
+ h -= self.margins.top + self.margins.bottom
+ return w, h
+
+ @property
+ def title(self):
+ """Main title as a str in latin-1."""
+ return self._title
+
+ @title.setter
+ def title(self, title):
+ if title != self._title:
+ self._title = title
+ self._dirty()
+
+ # In-place update
+ # if self._renderResources is not None:
+ # self._renderResources[-1][-1].text = title
+
+ def _dirty(self):
+ # When Text2D require discard we need to handle it
+ self._renderResources = None
+
+ def _buildGridVertices(self):
+ if self._grid == self.GRID_NONE:
+ return []
+
+ elif self._grid == self.GRID_MAIN_TICKS:
+ def test(text):
+ return text is not None
+ elif self._grid == self.GRID_SUB_TICKS:
+ def test(text):
+ return text is None
+ elif self._grid == self.GRID_ALL_TICKS:
+ def test(_):
+ return True
+ else:
+ logging.warning('Wrong grid mode: %d' % self._grid)
+ return []
+
+ return self._buildGridVerticesWithTest(test)
+
+ def _buildGridVerticesWithTest(self, test):
+ """Override in subclass to generate grid vertices"""
+ return []
+
+ def _buildVerticesAndLabels(self):
+ # To fill with copy of axes lists
+ vertices = []
+ labels = []
+
+ for axis in self.axes:
+ axisVertices, axisLabels = axis.getVerticesAndLabels()
+ vertices += axisVertices
+ labels += axisLabels
+
+ vertices = numpy.array(vertices, dtype=numpy.float32)
+
+ # Add main title
+ xTitle = (self.size[0] + self.margins.left -
+ self.margins.right) // 2
+ yTitle = self.margins.top - self._TICK_LENGTH_IN_PIXELS
+ labels.append(Text2D(text=self.title,
+ color=self._foregroundColor,
+ x=xTitle,
+ y=yTitle,
+ align=CENTER,
+ valign=BOTTOM,
+ devicePixelRatio=self.devicePixelRatio))
+
+ # grid
+ gridVertices = numpy.array(self._buildGridVertices(),
+ dtype=numpy.float32)
+
+ self._renderResources = (vertices, gridVertices, labels)
+
+ _program = Program(
+ _SHADERS['vertex'], _SHADERS['fragment'], attrib0='position')
+
+ def render(self):
+ if self.margins == self._NoDisplayMargins:
+ return
+
+ if self._renderResources is None:
+ self._buildVerticesAndLabels()
+ vertices, gridVertices, labels = self._renderResources
+
+ width, height = self.size
+ matProj = mat4Ortho(0, width, height, 0, 1, -1)
+
+ gl.glViewport(0, 0, width, height)
+
+ prog = self._program
+ prog.use()
+
+ gl.glLineWidth(self._LINE_WIDTH)
+
+ gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE,
+ matProj.astype(numpy.float32))
+ gl.glUniform4f(prog.uniforms['color'], *self._foregroundColor)
+ gl.glUniform1f(prog.uniforms['tickFactor'], 0.)
+
+ gl.glEnableVertexAttribArray(prog.attributes['position'])
+ gl.glVertexAttribPointer(prog.attributes['position'],
+ 2,
+ gl.GL_FLOAT,
+ gl.GL_FALSE,
+ 0, vertices)
+
+ gl.glDrawArrays(gl.GL_LINES, 0, len(vertices))
+
+ for label in labels:
+ label.render(matProj)
+
+ def renderGrid(self):
+ if self._grid == self.GRID_NONE:
+ return
+
+ if self._renderResources is None:
+ self._buildVerticesAndLabels()
+ vertices, gridVertices, labels = self._renderResources
+
+ width, height = self.size
+ matProj = mat4Ortho(0, width, height, 0, 1, -1)
+
+ gl.glViewport(0, 0, width, height)
+
+ prog = self._program
+ prog.use()
+
+ gl.glLineWidth(self._LINE_WIDTH)
+ gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE,
+ matProj.astype(numpy.float32))
+ gl.glUniform4f(prog.uniforms['color'], *self._gridColor)
+ gl.glUniform1f(prog.uniforms['tickFactor'], 0.) # 1/2.) # 1/tickLen
+
+ gl.glEnableVertexAttribArray(prog.attributes['position'])
+ gl.glVertexAttribPointer(prog.attributes['position'],
+ 2,
+ gl.GL_FLOAT,
+ gl.GL_FALSE,
+ 0, gridVertices)
+
+ gl.glDrawArrays(gl.GL_LINES, 0, len(gridVertices))
+
+
+# GLPlotFrame2D ###############################################################
+
+class GLPlotFrame2D(GLPlotFrame):
+ def __init__(self, marginRatios, foregroundColor, gridColor):
+ """
+ :param List[float] marginRatios:
+ The ratios of margins around plot area for axis and labels.
+ (left, top, right, bottom) as float in [0., 1.]
+ :param foregroundColor: color used for the frame and labels.
+ :type foregroundColor: tuple with RGBA values ranging from 0.0 to 1.0
+ :param gridColor: color used for grid lines.
+ :type gridColor: tuple RGBA with RGBA values ranging from 0.0 to 1.0
+
+ """
+ super(GLPlotFrame2D, self).__init__(marginRatios, foregroundColor, gridColor)
+ self.axes.append(PlotAxis(self,
+ tickLength=(0., -5.),
+ foregroundColor=self._foregroundColor,
+ labelAlign=CENTER, labelVAlign=TOP,
+ titleAlign=CENTER, titleVAlign=TOP,
+ titleRotate=0))
+
+ self._x2AxisCoords = ()
+
+ self.axes.append(PlotAxis(self,
+ tickLength=(5., 0.),
+ foregroundColor=self._foregroundColor,
+ labelAlign=RIGHT, labelVAlign=CENTER,
+ titleAlign=CENTER, titleVAlign=BOTTOM,
+ titleRotate=ROTATE_270))
+
+ self._y2Axis = PlotAxis(self,
+ tickLength=(-5., 0.),
+ foregroundColor=self._foregroundColor,
+ labelAlign=LEFT, labelVAlign=CENTER,
+ titleAlign=CENTER, titleVAlign=TOP,
+ titleRotate=ROTATE_270)
+
+ self._isYAxisInverted = False
+
+ self._dataRanges = {
+ 'x': (1., 100.), 'y': (1., 100.), 'y2': (1., 100.)}
+
+ self._baseVectors = (1., 0.), (0., 1.)
+
+ self._transformedDataRanges = None
+ self._transformedDataProjMat = None
+ self._transformedDataY2ProjMat = None
+
+ def _dirty(self):
+ super(GLPlotFrame2D, self)._dirty()
+ self._transformedDataRanges = None
+ self._transformedDataProjMat = None
+ self._transformedDataY2ProjMat = None
+
+ @property
+ def isDirty(self):
+ """True if it need to refresh graphic rendering, False otherwise."""
+ return (super(GLPlotFrame2D, self).isDirty or
+ self._transformedDataRanges is None or
+ self._transformedDataProjMat is None or
+ self._transformedDataY2ProjMat is None)
+
+ @property
+ def xAxis(self):
+ return self.axes[0]
+
+ @property
+ def yAxis(self):
+ return self.axes[1]
+
+ @property
+ def y2Axis(self):
+ return self._y2Axis
+
+ @property
+ def isY2Axis(self):
+ """Whether to display the left Y axis or not."""
+ return len(self.axes) == 3
+
+ @isY2Axis.setter
+ def isY2Axis(self, isY2Axis):
+ if isY2Axis != self.isY2Axis:
+ if isY2Axis:
+ self.axes.append(self._y2Axis)
+ else:
+ self.axes = self.axes[:2]
+
+ self._dirty()
+
+ @property
+ def isYAxisInverted(self):
+ """Whether Y axes are inverted or not as a bool."""
+ return self._isYAxisInverted
+
+ @isYAxisInverted.setter
+ def isYAxisInverted(self, value):
+ value = bool(value)
+ if value != self._isYAxisInverted:
+ self._isYAxisInverted = value
+ self._dirty()
+
+ DEFAULT_BASE_VECTORS = (1., 0.), (0., 1.)
+ """Values of baseVectors for orthogonal axes."""
+
+ @property
+ def baseVectors(self):
+ """Coordinates of the X and Y axes in the orthogonal plot coords.
+
+ Raises ValueError if corresponding matrix is singular.
+
+ 2 tuples of 2 floats: (xx, xy), (yx, yy)
+ """
+ return self._baseVectors
+
+ @baseVectors.setter
+ def baseVectors(self, baseVectors):
+ self._dirty()
+
+ (xx, xy), (yx, yy) = baseVectors
+ vectors = (float(xx), float(xy)), (float(yx), float(yy))
+
+ det = (vectors[0][0] * vectors[1][1] - vectors[1][0] * vectors[0][1])
+ if det == 0.:
+ raise ValueError("Singular matrix for base vectors: " +
+ str(vectors))
+
+ if vectors != self._baseVectors:
+ self._baseVectors = vectors
+ self._dirty()
+
+ def _updateTitleOffset(self):
+ """Update axes title offset according to margins"""
+ margins = self.margins
+ self.xAxis.titleOffset = 0, margins.bottom // 2
+ self.yAxis.titleOffset = -3 * margins.left // 4, 0
+ self.y2Axis.titleOffset = 3 * margins.right // 4, 0
+
+ # Override size and marginRatios setters to update titleOffsets
+ @GLPlotFrame.size.setter
+ def size(self, size):
+ GLPlotFrame.size.fset(self, size)
+ self._updateTitleOffset()
+
+ @GLPlotFrame.marginRatios.setter
+ def marginRatios(self, ratios):
+ GLPlotFrame.marginRatios.fset(self, ratios)
+ self._updateTitleOffset()
+
+ @property
+ def dataRanges(self):
+ """Ranges of data visible in the plot on x, y and y2 axes.
+
+ This is different to the axes range when axes are not orthogonal.
+
+ Type: ((xMin, xMax), (yMin, yMax), (y2Min, y2Max))
+ """
+ return self._DataRanges(self._dataRanges['x'],
+ self._dataRanges['y'],
+ self._dataRanges['y2'])
+
+ def setDataRanges(self, x=None, y=None, y2=None):
+ """Set data range over each axes.
+
+ The provided ranges are clipped to possible values
+ (i.e., 32 float range + positive range for log scale).
+
+ :param x: (min, max) data range over X axis
+ :param y: (min, max) data range over Y axis
+ :param y2: (min, max) data range over Y2 axis
+ """
+ if x is not None:
+ self._dataRanges['x'] = checkAxisLimits(
+ x[0], x[1], self.xAxis.isLog, name='x')
+
+ if y is not None:
+ self._dataRanges['y'] = checkAxisLimits(
+ y[0], y[1], self.yAxis.isLog, name='y')
+
+ if y2 is not None:
+ self._dataRanges['y2'] = checkAxisLimits(
+ y2[0], y2[1], self.y2Axis.isLog, name='y2')
+
+ self.xAxis.dataRange = self._dataRanges['x']
+ self.yAxis.dataRange = self._dataRanges['y']
+ self.y2Axis.dataRange = self._dataRanges['y2']
+
+ _DataRanges = namedtuple('dataRanges', ('x', 'y', 'y2'))
+
+ @property
+ def transformedDataRanges(self):
+ """Bounds of the displayed area in transformed data coordinates
+ (i.e., log scale applied if any as well as skew)
+
+ 3-tuple of 2-tuple (min, max) for each axis: x, y, y2.
+ """
+ if self._transformedDataRanges is None:
+ (xMin, xMax), (yMin, yMax), (y2Min, y2Max) = self.dataRanges
+
+ if self.xAxis.isLog:
+ try:
+ xMin = math.log10(xMin)
+ except ValueError:
+ _logger.info('xMin: warning log10(%f)', xMin)
+ xMin = 0.
+ try:
+ xMax = math.log10(xMax)
+ except ValueError:
+ _logger.info('xMax: warning log10(%f)', xMax)
+ xMax = 0.
+
+ if self.yAxis.isLog:
+ try:
+ yMin = math.log10(yMin)
+ except ValueError:
+ _logger.info('yMin: warning log10(%f)', yMin)
+ yMin = 0.
+ try:
+ yMax = math.log10(yMax)
+ except ValueError:
+ _logger.info('yMax: warning log10(%f)', yMax)
+ yMax = 0.
+
+ try:
+ y2Min = math.log10(y2Min)
+ except ValueError:
+ _logger.info('yMin: warning log10(%f)', y2Min)
+ y2Min = 0.
+ try:
+ y2Max = math.log10(y2Max)
+ except ValueError:
+ _logger.info('yMax: warning log10(%f)', y2Max)
+ y2Max = 0.
+
+ self._transformedDataRanges = self._DataRanges(
+ (xMin, xMax), (yMin, yMax), (y2Min, y2Max))
+
+ return self._transformedDataRanges
+
+ @property
+ def transformedDataProjMat(self):
+ """Orthographic projection matrix for rendering transformed data
+
+ :type: numpy.matrix
+ """
+ if self._transformedDataProjMat is None:
+ xMin, xMax = self.transformedDataRanges.x
+ yMin, yMax = self.transformedDataRanges.y
+
+ if self.isYAxisInverted:
+ mat = mat4Ortho(xMin, xMax, yMax, yMin, 1, -1)
+ else:
+ mat = mat4Ortho(xMin, xMax, yMin, yMax, 1, -1)
+ self._transformedDataProjMat = mat
+
+ return self._transformedDataProjMat
+
+ @property
+ def transformedDataY2ProjMat(self):
+ """Orthographic projection matrix for rendering transformed data
+ for the 2nd Y axis
+
+ :type: numpy.matrix
+ """
+ if self._transformedDataY2ProjMat is None:
+ xMin, xMax = self.transformedDataRanges.x
+ y2Min, y2Max = self.transformedDataRanges.y2
+
+ if self.isYAxisInverted:
+ mat = mat4Ortho(xMin, xMax, y2Max, y2Min, 1, -1)
+ else:
+ mat = mat4Ortho(xMin, xMax, y2Min, y2Max, 1, -1)
+ self._transformedDataY2ProjMat = mat
+
+ return self._transformedDataY2ProjMat
+
+ def dataToPixel(self, x, y, axis='left'):
+ """Convert data coordinate to widget pixel coordinate.
+ """
+ assert axis in ('left', 'right')
+
+ trBounds = self.transformedDataRanges
+
+ if self.xAxis.isLog:
+ if x < FLOAT32_MINPOS:
+ return None
+ xDataTr = math.log10(x)
+ else:
+ xDataTr = x
+
+ if self.yAxis.isLog:
+ if y < FLOAT32_MINPOS:
+ return None
+ yDataTr = math.log10(y)
+ else:
+ yDataTr = y
+
+ # Non-orthogonal axes
+ if self.baseVectors != self.DEFAULT_BASE_VECTORS:
+ (xx, xy), (yx, yy) = self.baseVectors
+ skew_mat = numpy.array(((xx, yx), (xy, yy)))
+
+ coords = numpy.dot(skew_mat, numpy.array((xDataTr, yDataTr)))
+ xDataTr, yDataTr = coords
+
+ plotWidth, plotHeight = self.plotSize
+
+ xPixel = int(self.margins.left +
+ plotWidth * (xDataTr - trBounds.x[0]) /
+ (trBounds.x[1] - trBounds.x[0]))
+
+ usedAxis = trBounds.y if axis == "left" else trBounds.y2
+ yOffset = (plotHeight * (yDataTr - usedAxis[0]) /
+ (usedAxis[1] - usedAxis[0]))
+
+ if self.isYAxisInverted:
+ yPixel = int(self.margins.top + yOffset)
+ else:
+ yPixel = int(self.size[1] - self.margins.bottom - yOffset)
+
+ return xPixel, yPixel
+
+ def pixelToData(self, x, y, axis="left"):
+ """Convert pixel position to data coordinates.
+
+ :param float x: X coord
+ :param float y: Y coord
+ :param str axis: Y axis to use in ('left', 'right')
+ :return: (x, y) position in data coords
+ """
+ assert axis in ("left", "right")
+
+ plotWidth, plotHeight = self.plotSize
+
+ trBounds = self.transformedDataRanges
+
+ xData = (x - self.margins.left + 0.5) / float(plotWidth)
+ xData = trBounds.x[0] + xData * (trBounds.x[1] - trBounds.x[0])
+
+ usedAxis = trBounds.y if axis == "left" else trBounds.y2
+ if self.isYAxisInverted:
+ yData = (y - self.margins.top + 0.5) / float(plotHeight)
+ yData = usedAxis[0] + yData * (usedAxis[1] - usedAxis[0])
+ else:
+ yData = self.size[1] - self.margins.bottom - y - 0.5
+ yData /= float(plotHeight)
+ yData = usedAxis[0] + yData * (usedAxis[1] - usedAxis[0])
+
+ # non-orthogonal axis
+ if self.baseVectors != self.DEFAULT_BASE_VECTORS:
+ (xx, xy), (yx, yy) = self.baseVectors
+ skew_mat = numpy.array(((xx, yx), (xy, yy)))
+ skew_mat = numpy.linalg.inv(skew_mat)
+
+ coords = numpy.dot(skew_mat, numpy.array((xData, yData)))
+ xData, yData = coords
+
+ if self.xAxis.isLog:
+ xData = pow(10, xData)
+ if self.yAxis.isLog:
+ yData = pow(10, yData)
+
+ return xData, yData
+
+ def _buildGridVerticesWithTest(self, test):
+ vertices = []
+
+ if self.baseVectors == self.DEFAULT_BASE_VECTORS:
+ for axis in self.axes:
+ for (xPixel, yPixel), data, text in axis.ticks:
+ if test(text):
+ vertices.append((xPixel, yPixel))
+ if axis == self.xAxis:
+ vertices.append((xPixel, self.margins.top))
+ elif axis == self.yAxis:
+ vertices.append((self.size[0] - self.margins.right,
+ yPixel))
+ else: # axis == self.y2Axis
+ vertices.append((self.margins.left, yPixel))
+
+ else:
+ # Get plot corners in data coords
+ plotLeft, plotTop = self.plotOrigin
+ plotWidth, plotHeight = self.plotSize
+
+ corners = [(plotLeft, plotTop),
+ (plotLeft, plotTop + plotHeight),
+ (plotLeft + plotWidth, plotTop + plotHeight),
+ (plotLeft + plotWidth, plotTop)]
+
+ for axis in self.axes:
+ if axis == self.xAxis:
+ cornersInData = numpy.array([
+ self.pixelToData(x, y) for (x, y) in corners])
+ borders = ((cornersInData[0], cornersInData[3]), # top
+ (cornersInData[1], cornersInData[0]), # left
+ (cornersInData[3], cornersInData[2])) # right
+
+ for (xPixel, yPixel), data, text in axis.ticks:
+ if test(text):
+ for (x0, y0), (x1, y1) in borders:
+ if min(x0, x1) <= data < max(x0, x1):
+ yIntersect = (data - x0) * \
+ (y1 - y0) / (x1 - x0) + y0
+
+ pixelPos = self.dataToPixel(
+ data, yIntersect)
+ if pixelPos is not None:
+ vertices.append((xPixel, yPixel))
+ vertices.append(pixelPos)
+ break # Stop at first intersection
+
+ else: # y or y2 axes
+ if axis == self.yAxis:
+ axis_name = 'left'
+ cornersInData = numpy.array([
+ self.pixelToData(x, y) for (x, y) in corners])
+ borders = (
+ (cornersInData[3], cornersInData[2]), # right
+ (cornersInData[0], cornersInData[3]), # top
+ (cornersInData[2], cornersInData[1])) # bottom
+
+ else: # axis == self.y2Axis
+ axis_name = 'right'
+ corners = numpy.array([self.pixelToData(
+ x, y, axis='right') for (x, y) in corners])
+ borders = (
+ (cornersInData[1], cornersInData[0]), # left
+ (cornersInData[0], cornersInData[3]), # top
+ (cornersInData[2], cornersInData[1])) # bottom
+
+ for (xPixel, yPixel), data, text in axis.ticks:
+ if test(text):
+ for (x0, y0), (x1, y1) in borders:
+ if min(y0, y1) <= data < max(y0, y1):
+ xIntersect = (data - y0) * \
+ (x1 - x0) / (y1 - y0) + x0
+
+ pixelPos = self.dataToPixel(
+ xIntersect, data, axis=axis_name)
+ if pixelPos is not None:
+ vertices.append((xPixel, yPixel))
+ vertices.append(pixelPos)
+ break # Stop at first intersection
+
+ return vertices
+
+ def _buildVerticesAndLabels(self):
+ width, height = self.size
+
+ xCoords = (self.margins.left - 0.5,
+ width - self.margins.right + 0.5)
+ yCoords = (height - self.margins.bottom + 0.5,
+ self.margins.top - 0.5)
+
+ self.axes[0].displayCoords = ((xCoords[0], yCoords[0]),
+ (xCoords[1], yCoords[0]))
+
+ self._x2AxisCoords = ((xCoords[0], yCoords[1]),
+ (xCoords[1], yCoords[1]))
+
+ if self.isYAxisInverted:
+ # Y axes are inverted, axes coordinates are inverted
+ yCoords = yCoords[1], yCoords[0]
+
+ self.axes[1].displayCoords = ((xCoords[0], yCoords[0]),
+ (xCoords[0], yCoords[1]))
+
+ self._y2Axis.displayCoords = ((xCoords[1], yCoords[0]),
+ (xCoords[1], yCoords[1]))
+
+ super(GLPlotFrame2D, self)._buildVerticesAndLabels()
+
+ vertices, gridVertices, labels = self._renderResources
+
+ # Adds vertices for borders without axis
+ extraVertices = []
+ extraVertices += self._x2AxisCoords
+ if not self.isY2Axis:
+ extraVertices += self._y2Axis.displayCoords
+
+ extraVertices = numpy.array(
+ extraVertices, copy=False, dtype=numpy.float32)
+ vertices = numpy.append(vertices, extraVertices, axis=0)
+
+ self._renderResources = (vertices, gridVertices, labels)
+
+ @property
+ def foregroundColor(self):
+ """Color used for frame and labels"""
+ return self._foregroundColor
+
+ @foregroundColor.setter
+ def foregroundColor(self, color):
+ """Color used for frame and labels"""
+ assert len(color) == 4, \
+ "foregroundColor must have length 4, got {}".format(len(self._foregroundColor))
+ if self._foregroundColor != color:
+ self._y2Axis.foregroundColor = color
+ GLPlotFrame.foregroundColor.fset(self, color) # call parent property
diff --git a/src/silx/gui/plot/backends/glutils/GLPlotImage.py b/src/silx/gui/plot/backends/glutils/GLPlotImage.py
new file mode 100644
index 0000000..3ad94b9
--- /dev/null
+++ b/src/silx/gui/plot/backends/glutils/GLPlotImage.py
@@ -0,0 +1,756 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2014-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""
+This module provides a class to render 2D array as a colormap or RGB(A) image
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "03/04/2017"
+
+
+import math
+import numpy
+
+from silx.math.combo import min_max
+
+from ...._glutils import gl, Program, Texture
+from ..._utils import FLOAT32_MINPOS
+from .GLSupport import mat4Translate, mat4Scale
+from .GLTexture import Image
+from .GLPlotItem import GLPlotItem
+
+
+class _GLPlotData2D(GLPlotItem):
+ def __init__(self, data, origin, scale):
+ super().__init__()
+ self.data = data
+ assert len(origin) == 2
+ self.origin = tuple(origin)
+ assert len(scale) == 2
+ self.scale = tuple(scale)
+
+ def pick(self, x, y):
+ if self.xMin <= x <= self.xMax and self.yMin <= y <= self.yMax:
+ ox, oy = self.origin
+ sx, sy = self.scale
+ col = int((x - ox) / sx)
+ row = int((y - oy) / sy)
+ return (row,), (col,)
+ else:
+ return None
+
+ @property
+ def xMin(self):
+ ox, sx = self.origin[0], self.scale[0]
+ return ox if sx >= 0. else ox + sx * self.data.shape[1]
+
+ @property
+ def yMin(self):
+ oy, sy = self.origin[1], self.scale[1]
+ return oy if sy >= 0. else oy + sy * self.data.shape[0]
+
+ @property
+ def xMax(self):
+ ox, sx = self.origin[0], self.scale[0]
+ return ox + sx * self.data.shape[1] if sx >= 0. else ox
+
+ @property
+ def yMax(self):
+ oy, sy = self.origin[1], self.scale[1]
+ return oy + sy * self.data.shape[0] if sy >= 0. else oy
+
+
+class GLPlotColormap(_GLPlotData2D):
+
+ _SHADERS = {
+ 'linear': {
+ 'vertex': """
+ #version 120
+
+ uniform mat4 matrix;
+ attribute vec2 texCoords;
+ attribute vec2 position;
+
+ varying vec2 coords;
+
+ void main(void) {
+ coords = texCoords;
+ gl_Position = matrix * vec4(position, 0.0, 1.0);
+ }
+ """,
+ 'fragTransform': """
+ vec2 textureCoords(void) {
+ return coords;
+ }
+ """},
+
+ 'log': {
+ 'vertex': """
+ #version 120
+
+ attribute vec2 position;
+ uniform mat4 matrix;
+ uniform mat4 matOffset;
+ uniform bvec2 isLog;
+
+ varying vec2 coords;
+
+ const float oneOverLog10 = 0.43429448190325176;
+
+ void main(void) {
+ vec4 dataPos = matOffset * vec4(position, 0.0, 1.0);
+ if (isLog.x) {
+ dataPos.x = oneOverLog10 * log(dataPos.x);
+ }
+ if (isLog.y) {
+ dataPos.y = oneOverLog10 * log(dataPos.y);
+ }
+ coords = dataPos.xy;
+ gl_Position = matrix * dataPos;
+ }
+ """,
+ 'fragTransform': """
+ uniform bvec2 isLog;
+ uniform vec2 bounds_oneOverRange;
+ uniform vec2 bounds_originOverRange;
+
+ vec2 textureCoords(void) {
+ vec2 pos = coords;
+ if (isLog.x) {
+ pos.x = pow(10., coords.x);
+ }
+ if (isLog.y) {
+ pos.y = pow(10., coords.y);
+ }
+ return pos * bounds_oneOverRange - bounds_originOverRange;
+ // TODO texture coords in range different from [0, 1]
+ }
+ """},
+
+ 'fragment': """
+ #version 120
+
+ /* isnan declaration for compatibility with GLSL 1.20 */
+ bool isnan(float value) {
+ return (value != value);
+ }
+
+ uniform sampler2D data;
+ uniform sampler2D cmap_texture;
+ uniform int cmap_normalization;
+ uniform float cmap_parameter;
+ uniform float cmap_min;
+ uniform float cmap_oneOverRange;
+ uniform float alpha;
+ uniform vec4 nancolor;
+
+ varying vec2 coords;
+
+ %s
+
+ const float oneOverLog10 = 0.43429448190325176;
+
+ void main(void) {
+ float data = texture2D(data, textureCoords()).r;
+ float value = data;
+ if (cmap_normalization == 1) { /*Logarithm mapping*/
+ if (value > 0.) {
+ value = clamp(cmap_oneOverRange *
+ (oneOverLog10 * log(value) - cmap_min),
+ 0., 1.);
+ } else {
+ value = 0.;
+ }
+ } else if (cmap_normalization == 2) { /*Square root mapping*/
+ if (value >= 0.) {
+ value = clamp(cmap_oneOverRange * (sqrt(value) - cmap_min),
+ 0., 1.);
+ } else {
+ value = 0.;
+ }
+ } else if (cmap_normalization == 3) { /*Gamma correction mapping*/
+ value = pow(
+ clamp(cmap_oneOverRange * (value - cmap_min), 0., 1.),
+ cmap_parameter);
+ } else if (cmap_normalization == 4) { /* arcsinh mapping */
+ /* asinh = log(x + sqrt(x*x + 1) for compatibility with GLSL 1.20 */
+ value = clamp(cmap_oneOverRange * (log(value + sqrt(value*value + 1.0)) - cmap_min), 0., 1.);
+ } else { /*Linear mapping and fallback*/
+ value = clamp(cmap_oneOverRange * (value - cmap_min), 0., 1.);
+ }
+
+ if (isnan(data)) {
+ gl_FragColor = nancolor;
+ } else {
+ gl_FragColor = texture2D(cmap_texture, vec2(value, 0.5));
+ }
+ gl_FragColor.a *= alpha;
+ }
+ """
+ }
+
+ _DATA_TEX_UNIT = 0
+ _CMAP_TEX_UNIT = 1
+
+ _INTERNAL_FORMATS = {
+ numpy.dtype(numpy.float32): gl.GL_R32F,
+ numpy.dtype(numpy.float16): gl.GL_R16F,
+ # Use normalized integer for unsigned int formats
+ numpy.dtype(numpy.uint16): gl.GL_R16,
+ numpy.dtype(numpy.uint8): gl.GL_R8,
+ }
+
+ _linearProgram = Program(_SHADERS['linear']['vertex'],
+ _SHADERS['fragment'] %
+ _SHADERS['linear']['fragTransform'],
+ attrib0='position')
+
+ _logProgram = Program(_SHADERS['log']['vertex'],
+ _SHADERS['fragment'] %
+ _SHADERS['log']['fragTransform'],
+ attrib0='position')
+
+ SUPPORTED_NORMALIZATIONS = 'linear', 'log', 'sqrt', 'gamma', 'arcsinh'
+
+ def __init__(self, data, origin, scale,
+ colormap, normalization='linear', gamma=0., cmapRange=None,
+ alpha=1.0, nancolor=(1., 1., 1., 0.)):
+ """Create a 2D colormap
+
+ :param data: The 2D scalar data array to display
+ :type data: numpy.ndarray with 2 dimensions (dtype=numpy.float32)
+ :param origin: (x, y) coordinates of the origin of the data array
+ :type origin: 2-tuple of floats.
+ :param scale: (sx, sy) scale factors of the data array.
+ This is the size of a data pixel in plot data space.
+ :type scale: 2-tuple of floats.
+ :param str colormap: Name of the colormap to use
+ TODO: Accept a 1D scalar array as the colormap
+ :param str normalization: The colormap normalization.
+ One of: 'linear', 'log', 'sqrt', 'gamma'
+ ;param float gamma: The gamma parameter (for 'gamma' normalization)
+ :param cmapRange: The range of colormap or None for autoscale colormap
+ For logarithmic colormap, the range is in the untransformed data
+ TODO: check consistency with matplotlib
+ :type cmapRange: (float, float) or None
+ :param float alpha: Opacity from 0 (transparent) to 1 (opaque)
+ :param nancolor: RGBA color for Not-A-Number values
+ :type nancolor: 4-tuple of float in [0., 1.]
+ """
+ assert data.dtype in self._INTERNAL_FORMATS
+ assert normalization in self.SUPPORTED_NORMALIZATIONS
+
+ super(GLPlotColormap, self).__init__(data, origin, scale)
+ self.colormap = numpy.array(colormap, copy=False)
+ self.normalization = normalization
+ self.gamma = gamma
+ self._cmapRange = (1., 10.) # Colormap range
+ self.cmapRange = cmapRange # Update _cmapRange
+ self._alpha = numpy.clip(alpha, 0., 1.)
+ self._nancolor = numpy.clip(nancolor, 0., 1.)
+
+ self._cmap_texture = None
+ self._texture = None
+ self._textureIsDirty = False
+
+ def discard(self):
+ if self._cmap_texture is not None:
+ self._cmap_texture.discard()
+ self._cmap_texture = None
+
+ if self._texture is not None:
+ self._texture.discard()
+ self._texture = None
+ self._textureIsDirty = False
+
+ def isInitialized(self):
+ return (self._cmap_texture is not None or
+ self._texture is not None)
+
+ @property
+ def cmapRange(self):
+ if self.normalization == 'log':
+ assert self._cmapRange[0] > 0. and self._cmapRange[1] > 0.
+ elif self.normalization == 'sqrt':
+ assert self._cmapRange[0] >= 0. and self._cmapRange[1] >= 0.
+ return self._cmapRange
+
+ @cmapRange.setter
+ def cmapRange(self, cmapRange):
+ assert len(cmapRange) == 2
+ assert cmapRange[0] <= cmapRange[1]
+ self._cmapRange = float(cmapRange[0]), float(cmapRange[1])
+
+ @property
+ def alpha(self):
+ return self._alpha
+
+ def updateData(self, data):
+ assert data.dtype in self._INTERNAL_FORMATS
+ oldData = self.data
+ self.data = data
+
+ if self._texture is not None:
+ if (self.data.shape != oldData.shape or
+ self.data.dtype != oldData.dtype):
+ self.discard()
+ else:
+ self._textureIsDirty = True
+
+ def prepare(self):
+ if self._cmap_texture is None:
+ # TODO share cmap texture accross Images
+ # put all cmaps in one texture
+ colormap = numpy.empty((16, 256, self.colormap.shape[1]),
+ dtype=self.colormap.dtype)
+ colormap[:] = self.colormap
+ format_ = gl.GL_RGBA if colormap.shape[-1] == 4 else gl.GL_RGB
+ self._cmap_texture = Texture(internalFormat=format_,
+ data=colormap,
+ format_=format_,
+ texUnit=self._CMAP_TEX_UNIT,
+ minFilter=gl.GL_NEAREST,
+ magFilter=gl.GL_NEAREST,
+ wrap=(gl.GL_CLAMP_TO_EDGE,
+ gl.GL_CLAMP_TO_EDGE))
+ self._cmap_texture.prepare()
+
+ if self._texture is None:
+ internalFormat = self._INTERNAL_FORMATS[self.data.dtype]
+
+ self._texture = Image(internalFormat,
+ self.data,
+ format_=gl.GL_RED,
+ texUnit=self._DATA_TEX_UNIT)
+ elif self._textureIsDirty:
+ self._textureIsDirty = True
+ self._texture.updateAll(format_=gl.GL_RED, data=self.data)
+
+ def _setCMap(self, prog):
+ dataMin, dataMax = self.cmapRange # If log, it is stricly positive
+ param = 0.
+
+ if self.data.dtype in (numpy.uint16, numpy.uint8):
+ # Using unsigned int as normalized integer in OpenGL
+ # So normalize range
+ maxInt = float(numpy.iinfo(self.data.dtype).max)
+ dataMin, dataMax = dataMin / maxInt, dataMax / maxInt
+
+ if self.normalization == 'log':
+ dataMin = math.log10(dataMin)
+ dataMax = math.log10(dataMax)
+ normID = 1
+ elif self.normalization == 'sqrt':
+ dataMin = math.sqrt(dataMin)
+ dataMax = math.sqrt(dataMax)
+ normID = 2
+ elif self.normalization == 'gamma':
+ # Keep dataMin, dataMax as is
+ param = self.gamma
+ normID = 3
+ elif self.normalization == 'arcsinh':
+ dataMin = numpy.arcsinh(dataMin)
+ dataMax = numpy.arcsinh(dataMax)
+ normID = 4
+ else: # Linear and fallback
+ normID = 0
+
+ gl.glUniform1i(prog.uniforms['cmap_texture'],
+ self._cmap_texture.texUnit)
+ gl.glUniform1i(prog.uniforms['cmap_normalization'], normID)
+ gl.glUniform1f(prog.uniforms['cmap_parameter'], param)
+ gl.glUniform1f(prog.uniforms['cmap_min'], dataMin)
+ if dataMax > dataMin:
+ oneOverRange = 1. / (dataMax - dataMin)
+ else:
+ oneOverRange = 0. # Fall-back
+ gl.glUniform1f(prog.uniforms['cmap_oneOverRange'], oneOverRange)
+
+ gl.glUniform4f(prog.uniforms['nancolor'], *self._nancolor)
+
+ self._cmap_texture.bind()
+
+ def _renderLinear(self, context):
+ """Perform rendering when both axes have linear scales
+
+ :param RenderContext context: Rendering information
+ """
+ self.prepare()
+
+ prog = self._linearProgram
+ prog.use()
+
+ gl.glUniform1i(prog.uniforms['data'], self._DATA_TEX_UNIT)
+
+ mat = numpy.dot(numpy.dot(context.matrix,
+ mat4Translate(*self.origin)),
+ mat4Scale(*self.scale))
+ gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE,
+ mat.astype(numpy.float32))
+
+ gl.glUniform1f(prog.uniforms['alpha'], self.alpha)
+
+ self._setCMap(prog)
+
+ self._texture.render(prog.attributes['position'],
+ prog.attributes['texCoords'],
+ self._DATA_TEX_UNIT)
+
+ def _renderLog10(self, context):
+ """Perform rendering when one axis has log scale
+
+ :param RenderContext context: Rendering information
+ """
+ xMin, yMin = self.xMin, self.yMin
+ if ((context.isXLog and xMin < FLOAT32_MINPOS) or
+ (context.isYLog and yMin < FLOAT32_MINPOS)):
+ # Do not render images that are partly or totally <= 0
+ return
+
+ self.prepare()
+
+ prog = self._logProgram
+ prog.use()
+
+ ox, oy = self.origin
+
+ gl.glUniform1i(prog.uniforms['data'], self._DATA_TEX_UNIT)
+
+ gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE,
+ context.matrix.astype(numpy.float32))
+ mat = numpy.dot(mat4Translate(ox, oy), mat4Scale(*self.scale))
+ gl.glUniformMatrix4fv(prog.uniforms['matOffset'], 1, gl.GL_TRUE,
+ mat.astype(numpy.float32))
+
+ gl.glUniform2i(prog.uniforms['isLog'], context.isXLog, context.isYLog)
+
+ ex = ox + self.scale[0] * self.data.shape[1]
+ ey = oy + self.scale[1] * self.data.shape[0]
+
+ xOneOverRange = 1. / (ex - ox)
+ yOneOverRange = 1. / (ey - oy)
+ gl.glUniform2f(prog.uniforms['bounds_originOverRange'],
+ ox * xOneOverRange, oy * yOneOverRange)
+ gl.glUniform2f(prog.uniforms['bounds_oneOverRange'],
+ xOneOverRange, yOneOverRange)
+
+ gl.glUniform1f(prog.uniforms['alpha'], self.alpha)
+
+ self._setCMap(prog)
+
+ try:
+ tiles = self._texture.tiles
+ except AttributeError:
+ raise RuntimeError("No texture, discard has already been called")
+ if len(tiles) > 1:
+ raise NotImplementedError(
+ "Image over multiple textures not supported with log scale")
+
+ texture, vertices, info = tiles[0]
+
+ texture.bind(self._DATA_TEX_UNIT)
+
+ posAttrib = prog.attributes['position']
+ stride = vertices.shape[-1] * vertices.itemsize
+ gl.glEnableVertexAttribArray(posAttrib)
+ gl.glVertexAttribPointer(posAttrib,
+ 2,
+ gl.GL_FLOAT,
+ gl.GL_FALSE,
+ stride, vertices)
+
+ gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, len(vertices))
+
+ def render(self, context):
+ """Perform rendering
+
+ :param RenderContext context: Rendering information
+ """
+ if any((context.isXLog, context.isYLog)):
+ self._renderLog10(context)
+ else:
+ self._renderLinear(context)
+
+ # Unbind colormap texture
+ gl.glActiveTexture(gl.GL_TEXTURE0 + self._cmap_texture.texUnit)
+ gl.glBindTexture(self._cmap_texture.target, 0)
+
+
+# image #######################################################################
+
+class GLPlotRGBAImage(_GLPlotData2D):
+
+ _SHADERS = {
+ 'linear': {
+ 'vertex': """
+ #version 120
+
+ attribute vec2 position;
+ attribute vec2 texCoords;
+ uniform mat4 matrix;
+
+ varying vec2 coords;
+
+ void main(void) {
+ gl_Position = matrix * vec4(position, 0.0, 1.0);
+ coords = texCoords;
+ }
+ """,
+ 'fragment': """
+ #version 120
+
+ uniform sampler2D tex;
+ uniform float alpha;
+
+ varying vec2 coords;
+
+ void main(void) {
+ gl_FragColor = texture2D(tex, coords);
+ gl_FragColor.a *= alpha;
+ }
+ """},
+
+ 'log': {
+ 'vertex': """
+ #version 120
+
+ attribute vec2 position;
+ uniform mat4 matrix;
+ uniform mat4 matOffset;
+ uniform bvec2 isLog;
+
+ varying vec2 coords;
+
+ const float oneOverLog10 = 0.43429448190325176;
+
+ void main(void) {
+ vec4 dataPos = matOffset * vec4(position, 0.0, 1.0);
+ if (isLog.x) {
+ dataPos.x = oneOverLog10 * log(dataPos.x);
+ }
+ if (isLog.y) {
+ dataPos.y = oneOverLog10 * log(dataPos.y);
+ }
+ coords = dataPos.xy;
+ gl_Position = matrix * dataPos;
+ }
+ """,
+ 'fragment': """
+ #version 120
+
+ uniform sampler2D tex;
+ uniform bvec2 isLog;
+ uniform vec2 bounds_oneOverRange;
+ uniform vec2 bounds_originOverRange;
+ uniform float alpha;
+
+ varying vec2 coords;
+
+ vec2 textureCoords(void) {
+ vec2 pos = coords;
+ if (isLog.x) {
+ pos.x = pow(10., coords.x);
+ }
+ if (isLog.y) {
+ pos.y = pow(10., coords.y);
+ }
+ return pos * bounds_oneOverRange - bounds_originOverRange;
+ // TODO texture coords in range different from [0, 1]
+ }
+
+ void main(void) {
+ gl_FragColor = texture2D(tex, textureCoords());
+ gl_FragColor.a *= alpha;
+ }
+ """}
+ }
+
+ _DATA_TEX_UNIT = 0
+
+ _SUPPORTED_DTYPES = (numpy.dtype(numpy.float32),
+ numpy.dtype(numpy.uint8),
+ numpy.dtype(numpy.uint16))
+
+ _linearProgram = Program(_SHADERS['linear']['vertex'],
+ _SHADERS['linear']['fragment'],
+ attrib0='position')
+
+ _logProgram = Program(_SHADERS['log']['vertex'],
+ _SHADERS['log']['fragment'],
+ attrib0='position')
+
+ def __init__(self, data, origin, scale, alpha):
+ """Create a 2D RGB(A) image from data
+
+ :param data: The 2D image data array to display
+ :type data: numpy.ndarray with 3 dimensions
+ (dtype=numpy.uint8 or numpy.float32)
+ :param origin: (x, y) coordinates of the origin of the data array
+ :type origin: 2-tuple of floats.
+ :param scale: (sx, sy) scale factors of the data array.
+ This is the size of a data pixel in plot data space.
+ :type scale: 2-tuple of floats.
+ :param float alpha: Opacity from 0 (transparent) to 1 (opaque)
+ """
+ assert data.dtype in self._SUPPORTED_DTYPES
+ super(GLPlotRGBAImage, self).__init__(data, origin, scale)
+ self._texture = None
+ self._textureIsDirty = False
+ self._alpha = numpy.clip(alpha, 0., 1.)
+
+ @property
+ def alpha(self):
+ return self._alpha
+
+ def discard(self):
+ if self.isInitialized():
+ self._texture.discard()
+ self._texture = None
+ self._textureIsDirty = False
+
+ def isInitialized(self):
+ return self._texture is not None
+
+ def updateData(self, data):
+ assert data.dtype in self._SUPPORTED_DTYPES
+ oldData = self.data
+ self.data = data
+
+ if self._texture is not None:
+ if self.data.shape != oldData.shape:
+ self.discard()
+ else:
+ self._textureIsDirty = True
+
+ def prepare(self):
+ if self._texture is None:
+ formatName = 'GL_RGBA' if self.data.shape[2] == 4 else 'GL_RGB'
+ format_ = getattr(gl, formatName)
+
+ if self.data.dtype == numpy.uint16:
+ formatName += '16' # Use sized internal format for uint16
+ internalFormat = getattr(gl, formatName)
+
+ self._texture = Image(internalFormat,
+ self.data,
+ format_=format_,
+ texUnit=self._DATA_TEX_UNIT)
+ elif self._textureIsDirty:
+ self._textureIsDirty = False
+
+ # We should check that internal format is the same
+ format_ = gl.GL_RGBA if self.data.shape[2] == 4 else gl.GL_RGB
+ self._texture.updateAll(format_=format_, data=self.data)
+
+ def _renderLinear(self, context):
+ """Perform rendering with both axes having linear scales
+
+ :param RenderContext context: Rendering information
+ """
+ self.prepare()
+
+ prog = self._linearProgram
+ prog.use()
+
+ gl.glUniform1i(prog.uniforms['tex'], self._DATA_TEX_UNIT)
+
+ mat = numpy.dot(numpy.dot(context.matrix, mat4Translate(*self.origin)),
+ mat4Scale(*self.scale))
+ gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE,
+ mat.astype(numpy.float32))
+
+ gl.glUniform1f(prog.uniforms['alpha'], self.alpha)
+
+ self._texture.render(prog.attributes['position'],
+ prog.attributes['texCoords'],
+ self._DATA_TEX_UNIT)
+
+ def _renderLog(self, context):
+ """Perform rendering with axes having log scale
+
+ :param RenderContext context: Rendering information
+ """
+ self.prepare()
+
+ prog = self._logProgram
+ prog.use()
+
+ ox, oy = self.origin
+
+ gl.glUniform1i(prog.uniforms['tex'], self._DATA_TEX_UNIT)
+
+ gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE,
+ context.matrix.astype(numpy.float32))
+ mat = numpy.dot(mat4Translate(ox, oy), mat4Scale(*self.scale))
+ gl.glUniformMatrix4fv(prog.uniforms['matOffset'], 1, gl.GL_TRUE,
+ mat.astype(numpy.float32))
+
+ gl.glUniform2i(prog.uniforms['isLog'], context.isXLog, context.isYLog)
+
+ gl.glUniform1f(prog.uniforms['alpha'], self.alpha)
+
+ ex = ox + self.scale[0] * self.data.shape[1]
+ ey = oy + self.scale[1] * self.data.shape[0]
+
+ xOneOverRange = 1. / (ex - ox)
+ yOneOverRange = 1. / (ey - oy)
+ gl.glUniform2f(prog.uniforms['bounds_originOverRange'],
+ ox * xOneOverRange, oy * yOneOverRange)
+ gl.glUniform2f(prog.uniforms['bounds_oneOverRange'],
+ xOneOverRange, yOneOverRange)
+
+ try:
+ tiles = self._texture.tiles
+ except AttributeError:
+ raise RuntimeError("No texture, discard has already been called")
+ if len(tiles) > 1:
+ raise NotImplementedError(
+ "Image over multiple textures not supported with log scale")
+
+ texture, vertices, info = tiles[0]
+
+ texture.bind(self._DATA_TEX_UNIT)
+
+ posAttrib = prog.attributes['position']
+ stride = vertices.shape[-1] * vertices.itemsize
+ gl.glEnableVertexAttribArray(posAttrib)
+ gl.glVertexAttribPointer(posAttrib,
+ 2,
+ gl.GL_FLOAT,
+ gl.GL_FALSE,
+ stride, vertices)
+
+ gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, len(vertices))
+
+ def render(self, context):
+ """Perform rendering
+
+ :param RenderContext context: Rendering information
+ """
+ if any((context.isXLog, context.isYLog)):
+ self._renderLog(context)
+ else:
+ self._renderLinear(context)
diff --git a/src/silx/gui/plot/backends/glutils/GLPlotItem.py b/src/silx/gui/plot/backends/glutils/GLPlotItem.py
new file mode 100644
index 0000000..ae13091
--- /dev/null
+++ b/src/silx/gui/plot/backends/glutils/GLPlotItem.py
@@ -0,0 +1,99 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2020-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""
+This module provides a base class for PlotWidget OpenGL backend primitives
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "02/07/2020"
+
+
+class RenderContext:
+ """Context with which to perform OpenGL rendering.
+
+ :param numpy.ndarray matrix: 4x4 transform matrix to use for rendering
+ :param bool isXLog: Whether X axis is log scale or not
+ :param bool isYLog: Whether Y axis is log scale or not
+ :param float dpi: Number of device pixels per inch
+ """
+
+ def __init__(self, matrix=None, isXLog=False, isYLog=False, dpi=96.):
+ self.matrix = matrix
+ """Current transformation matrix"""
+
+ self.__isXLog = isXLog
+ self.__isYLog = isYLog
+ self.__dpi = dpi
+
+ @property
+ def isXLog(self):
+ """True if X axis is using log scale"""
+ return self.__isXLog
+
+ @property
+ def isYLog(self):
+ """True if Y axis is using log scale"""
+ return self.__isYLog
+
+ @property
+ def dpi(self):
+ """Number of device pixels per inch"""
+ return self.__dpi
+
+
+class GLPlotItem:
+ """Base class for primitives used in the PlotWidget OpenGL backend"""
+
+ def __init__(self):
+ self.yaxis = 'left'
+ "YAxis this item is attached to (either 'left' or 'right')"
+
+ def pick(self, x, y):
+ """Perform picking at given position.
+
+ :param float x: X coordinate in plot data frame of reference
+ :param float y: Y coordinate in plot data frame of reference
+ :returns:
+ Result of picking as a list of indices or None if nothing picked
+ :rtype: Union[List[int],None]
+ """
+ return None
+
+ def render(self, context):
+ """Performs OpenGL rendering of the item.
+
+ :param RenderContext context: Rendering context information
+ """
+ pass
+
+ def discard(self):
+ """Discards OpenGL resources this item has created."""
+ pass
+
+ def isInitialized(self) -> bool:
+ """Returns True if resources where initialized and requires `discard`.
+ """
+ return True
diff --git a/src/silx/gui/plot/backends/glutils/GLPlotTriangles.py b/src/silx/gui/plot/backends/glutils/GLPlotTriangles.py
new file mode 100644
index 0000000..fbe9e02
--- /dev/null
+++ b/src/silx/gui/plot/backends/glutils/GLPlotTriangles.py
@@ -0,0 +1,197 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2019-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""
+This module provides a class to render a set of 2D triangles
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "03/04/2017"
+
+
+import ctypes
+
+import numpy
+
+from .....math.combo import min_max
+from .... import _glutils as glutils
+from ...._glutils import gl
+from .GLPlotItem import GLPlotItem
+
+
+class GLPlotTriangles(GLPlotItem):
+ """Handle rendering of a set of colored triangles"""
+
+ _PROGRAM = glutils.Program(
+ vertexShader="""
+ #version 120
+
+ uniform mat4 matrix;
+ attribute float xPos;
+ attribute float yPos;
+ attribute vec4 color;
+
+ varying vec4 vColor;
+
+ void main(void) {
+ gl_Position = matrix * vec4(xPos, yPos, 0.0, 1.0);
+ vColor = color;
+ }
+ """,
+ fragmentShader="""
+ #version 120
+
+ uniform float alpha;
+ varying vec4 vColor;
+
+ void main(void) {
+ gl_FragColor = vColor;
+ gl_FragColor.a *= alpha;
+ }
+ """,
+ attrib0='xPos')
+
+ def __init__(self, x, y, color, triangles, alpha=1.):
+ """
+
+ :param numpy.ndarray x: X coordinates of triangle corners
+ :param numpy.ndarray y: Y coordinates of triangle corners
+ :param numpy.ndarray color: color for each point
+ :param numpy.ndarray triangles: (N, 3) array of indices of triangles
+ :param float alpha: Opacity in [0, 1]
+ """
+ super().__init__()
+ # Check and convert input data
+ x = numpy.ravel(numpy.array(x, dtype=numpy.float32))
+ y = numpy.ravel(numpy.array(y, dtype=numpy.float32))
+ color = numpy.array(color, copy=False)
+ # Cast to uint32
+ triangles = numpy.array(triangles, copy=False, dtype=numpy.uint32)
+
+ assert x.size == y.size
+ assert x.size == len(color)
+ assert color.ndim == 2 and color.shape[1] in (3, 4)
+ if numpy.issubdtype(color.dtype, numpy.floating):
+ color = numpy.array(color, dtype=numpy.float32, copy=False)
+ elif numpy.issubdtype(color.dtype, numpy.integer):
+ color = numpy.array(color, dtype=numpy.uint8, copy=False)
+ else:
+ raise ValueError('Unsupported color type')
+ assert triangles.ndim == 2 and triangles.shape[1] == 3
+
+ self.__x_y_color = x, y, color
+ self.xMin, self.xMax = min_max(x, finite=True)
+ self.yMin, self.yMax = min_max(y, finite=True)
+ self.__triangles = triangles
+ self.__alpha = numpy.clip(float(alpha), 0., 1.)
+ self.__vbos = None
+ self.__indicesVbo = None
+ self.__picking_triangles = None
+
+ def pick(self, x, y):
+ """Perform picking
+
+ :param float x: X coordinates in plot data frame
+ :param float y: Y coordinates in plot data frame
+ :return: List of picked data point indices
+ :rtype: Union[List[int],None]
+ """
+ if (x < self.xMin or x > self.xMax or
+ y < self.yMin or y > self.yMax):
+ return None
+
+ xPts, yPts = self.__x_y_color[:2]
+ if self.__picking_triangles is None:
+ self.__picking_triangles = numpy.zeros(
+ self.__triangles.shape + (3,), dtype=numpy.float32)
+ self.__picking_triangles[:, :, 0] = xPts[self.__triangles]
+ self.__picking_triangles[:, :, 1] = yPts[self.__triangles]
+
+ segment = numpy.array(((x, y, -1), (x, y, 1)), dtype=numpy.float32)
+ # Picked triangle indices
+ indices = glutils.segmentTrianglesIntersection(
+ segment, self.__picking_triangles)[0]
+ # Point indices
+ indices = numpy.unique(numpy.ravel(self.__triangles[indices]))
+
+ # Sorted from furthest to closest point
+ dists = (xPts[indices] - x) ** 2 + (yPts[indices] - y) ** 2
+ indices = indices[numpy.flip(numpy.argsort(dists), axis=0)]
+
+ return tuple(indices) if len(indices) > 0 else None
+
+ def discard(self):
+ """Release resources on the GPU"""
+ if self.isInitialized():
+ self.__vbos[0].vbo.discard()
+ self.__vbos = None
+ self.__indicesVbo.discard()
+ self.__indicesVbo = None
+
+ def isInitialized(self):
+ return self.__vbos is not None
+
+ def prepare(self):
+ """Allocate resources on the GPU"""
+ if self.__vbos is None:
+ self.__vbos = glutils.vertexBuffer(self.__x_y_color)
+ # Normalization is need for color
+ self.__vbos[-1].normalization = True
+
+ if self.__indicesVbo is None:
+ self.__indicesVbo = glutils.VertexBuffer(
+ numpy.ravel(self.__triangles),
+ usage=gl.GL_STATIC_DRAW,
+ target=gl.GL_ELEMENT_ARRAY_BUFFER)
+
+ def render(self, context):
+ """Perform rendering
+
+ :param RenderContext context: Rendering information
+ """
+ self.prepare()
+
+ if self.__vbos is None or self.__indicesVbo is None:
+ return # Nothing to display
+
+ self._PROGRAM.use()
+
+ gl.glUniformMatrix4fv(self._PROGRAM.uniforms['matrix'],
+ 1,
+ gl.GL_TRUE,
+ context.matrix.astype(numpy.float32))
+
+ gl.glUniform1f(self._PROGRAM.uniforms['alpha'], self.__alpha)
+
+ for index, name in enumerate(('xPos', 'yPos', 'color')):
+ attr = self._PROGRAM.attributes[name]
+ gl.glEnableVertexAttribArray(attr)
+ self.__vbos[index].setVertexAttrib(attr)
+
+ with self.__indicesVbo:
+ gl.glDrawElements(gl.GL_TRIANGLES,
+ self.__triangles.size,
+ glutils.numpyToGLType(self.__triangles.dtype),
+ ctypes.c_void_p(0))
diff --git a/src/silx/gui/plot/backends/glutils/GLSupport.py b/src/silx/gui/plot/backends/glutils/GLSupport.py
new file mode 100644
index 0000000..da6dffa
--- /dev/null
+++ b/src/silx/gui/plot/backends/glutils/GLSupport.py
@@ -0,0 +1,158 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2014-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""
+This module provides convenient classes and functions for OpenGL rendering.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "03/04/2017"
+
+
+import numpy
+
+from ...._glutils import gl
+
+
+def buildFillMaskIndices(nIndices, dtype=None):
+ """Returns triangle strip indices for rendering a filled polygon mask
+
+ :param int nIndices: Number of points
+ :param Union[numpy.dtype,None] dtype:
+ If specified the dtype of the returned indices array
+ :return: 1D array of indices constructing a triangle strip
+ :rtype: numpy.ndarray
+ """
+ if dtype is None:
+ if nIndices <= numpy.iinfo(numpy.uint16).max + 1:
+ dtype = numpy.uint16
+ else:
+ dtype = numpy.uint32
+
+ lastIndex = nIndices - 1
+ splitIndex = lastIndex // 2 + 1
+ indices = numpy.empty(nIndices, dtype=dtype)
+ indices[::2] = numpy.arange(0, splitIndex, step=1, dtype=dtype)
+ indices[1::2] = numpy.arange(lastIndex, splitIndex - 1, step=-1,
+ dtype=dtype)
+ return indices
+
+
+class FilledShape2D(object):
+ _NO_HATCH = 0
+ _HATCH_STEP = 20
+
+ def __init__(self, points, style='solid', color=(0., 0., 0., 1.)):
+ self.vertices = numpy.array(points, dtype=numpy.float32, copy=False)
+ self._indices = buildFillMaskIndices(len(self.vertices))
+
+ tVertex = numpy.transpose(self.vertices)
+ xMin, xMax = min(tVertex[0]), max(tVertex[0])
+ yMin, yMax = min(tVertex[1]), max(tVertex[1])
+ self.bboxVertices = numpy.array(((xMin, yMin), (xMin, yMax),
+ (xMax, yMin), (xMax, yMax)),
+ dtype=numpy.float32)
+ self._xMin, self._xMax = xMin, xMax
+ self._yMin, self._yMax = yMin, yMax
+
+ self.style = style
+ self.color = color
+
+ def render(self, posAttrib, colorUnif, hatchStepUnif):
+ assert self.style in ('hatch', 'solid')
+ gl.glUniform4f(colorUnif, *self.color)
+ step = self._HATCH_STEP if self.style == 'hatch' else self._NO_HATCH
+ gl.glUniform1i(hatchStepUnif, step)
+
+ # Prepare fill mask
+ gl.glEnableVertexAttribArray(posAttrib)
+ gl.glVertexAttribPointer(posAttrib,
+ 2,
+ gl.GL_FLOAT,
+ gl.GL_FALSE,
+ 0, self.vertices)
+
+ gl.glEnable(gl.GL_STENCIL_TEST)
+ gl.glStencilMask(1)
+ gl.glStencilFunc(gl.GL_ALWAYS, 1, 1)
+ gl.glStencilOp(gl.GL_INVERT, gl.GL_INVERT, gl.GL_INVERT)
+ gl.glColorMask(gl.GL_FALSE, gl.GL_FALSE, gl.GL_FALSE, gl.GL_FALSE)
+ gl.glDepthMask(gl.GL_FALSE)
+
+ gl.glDrawElements(gl.GL_TRIANGLE_STRIP, len(self._indices),
+ gl.GL_UNSIGNED_SHORT, self._indices)
+
+ gl.glStencilFunc(gl.GL_EQUAL, 1, 1)
+ # Reset stencil while drawing
+ gl.glStencilOp(gl.GL_ZERO, gl.GL_ZERO, gl.GL_ZERO)
+ gl.glColorMask(gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE)
+ gl.glDepthMask(gl.GL_TRUE)
+
+ gl.glVertexAttribPointer(posAttrib,
+ 2,
+ gl.GL_FLOAT,
+ gl.GL_FALSE,
+ 0, self.bboxVertices)
+ gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, len(self.bboxVertices))
+
+ gl.glDisable(gl.GL_STENCIL_TEST)
+
+
+# matrix ######################################################################
+
+def mat4Ortho(left, right, bottom, top, near, far):
+ """Orthographic projection matrix (row-major)"""
+ return numpy.array((
+ (2./(right - left), 0., 0., -(right+left)/float(right-left)),
+ (0., 2./(top - bottom), 0., -(top+bottom)/float(top-bottom)),
+ (0., 0., -2./(far-near), -(far+near)/float(far-near)),
+ (0., 0., 0., 1.)), dtype=numpy.float64)
+
+
+def mat4Translate(x=0., y=0., z=0.):
+ """Translation matrix (row-major)"""
+ return numpy.array((
+ (1., 0., 0., x),
+ (0., 1., 0., y),
+ (0., 0., 1., z),
+ (0., 0., 0., 1.)), dtype=numpy.float64)
+
+
+def mat4Scale(sx=1., sy=1., sz=1.):
+ """Scale matrix (row-major)"""
+ return numpy.array((
+ (sx, 0., 0., 0.),
+ (0., sy, 0., 0.),
+ (0., 0., sz, 0.),
+ (0., 0., 0., 1.)), dtype=numpy.float64)
+
+
+def mat4Identity():
+ """Identity matrix"""
+ return numpy.array((
+ (1., 0., 0., 0.),
+ (0., 1., 0., 0.),
+ (0., 0., 1., 0.),
+ (0., 0., 0., 1.)), dtype=numpy.float64)
diff --git a/src/silx/gui/plot/backends/glutils/GLText.py b/src/silx/gui/plot/backends/glutils/GLText.py
new file mode 100644
index 0000000..d6ae6fa
--- /dev/null
+++ b/src/silx/gui/plot/backends/glutils/GLText.py
@@ -0,0 +1,287 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2014-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""
+This module provides minimalistic text support for OpenGL.
+It provides Latin-1 (ISO8859-1) characters for one monospace font at one size.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "03/04/2017"
+
+
+from collections import OrderedDict
+import weakref
+
+import numpy
+
+from ...._glutils import font, gl, Context, Program, Texture
+from .GLSupport import mat4Translate
+
+
+# TODO: Font should be configurable by the main program: using mpl.rcParams?
+
+
+class _Cache(object):
+ """LRU (Least Recent Used) cache.
+
+ :param int maxsize: Maximum number of (key, value) pairs in the cache
+ :param callable callback:
+ Called when a (key, value) pair is removed from the cache.
+ It must take 2 arguments: key and value.
+ """
+
+ def __init__(self, maxsize=128, callback=None):
+ self._maxsize = int(maxsize)
+ self._callback = callback
+ self._cache = OrderedDict()
+
+ def __contains__(self, item):
+ return item in self._cache
+
+ def __getitem__(self, key):
+ if key in self._cache:
+ # Remove/add key from ordered dict to store last access info
+ value = self._cache.pop(key)
+ self._cache[key] = value
+ return value
+ else:
+ raise KeyError
+
+ def __setitem__(self, key, value):
+ """Add a key, value pair to the cache.
+
+ :param key: The key to set
+ :param value: The corresponding value
+ """
+ if key not in self._cache and len(self._cache) >= self._maxsize:
+ removedKey, removedValue = self._cache.popitem(last=False)
+ if self._callback is not None:
+ self._callback(removedKey, removedValue)
+ self._cache[key] = value
+
+
+# Text2D ######################################################################
+
+LEFT, CENTER, RIGHT = 'left', 'center', 'right'
+TOP, BASELINE, BOTTOM = 'top', 'baseline', 'bottom'
+ROTATE_90, ROTATE_180, ROTATE_270 = 90, 180, 270
+
+
+class Text2D(object):
+
+ _SHADERS = {
+ 'vertex': """
+ #version 120
+
+ attribute vec2 position;
+ attribute vec2 texCoords;
+ uniform mat4 matrix;
+
+ varying vec2 vCoords;
+
+ void main(void) {
+ gl_Position = matrix * vec4(position, 0.0, 1.0);
+ vCoords = texCoords;
+ }
+ """,
+ 'fragment': """
+ #version 120
+
+ uniform sampler2D texText;
+ uniform vec4 color;
+ uniform vec4 bgColor;
+
+ varying vec2 vCoords;
+
+ void main(void) {
+ gl_FragColor = mix(bgColor, color, texture2D(texText, vCoords).r);
+ }
+ """
+ }
+
+ _TEX_COORDS = numpy.array(((0., 0.), (1., 0.), (0., 1.), (1., 1.)),
+ dtype=numpy.float32).ravel()
+
+ _program = Program(_SHADERS['vertex'],
+ _SHADERS['fragment'],
+ attrib0='position')
+
+ # Discard texture objects when removed from the cache
+ _textures = weakref.WeakKeyDictionary()
+ """Cache already created textures"""
+
+ _sizes = _Cache()
+ """Cache already computed sizes"""
+
+ def __init__(self, text, x=0, y=0,
+ color=(0., 0., 0., 1.),
+ bgColor=None,
+ align=LEFT, valign=BASELINE,
+ rotate=0,
+ devicePixelRatio= 1.):
+ self.devicePixelRatio = devicePixelRatio
+ self._vertices = None
+ self._text = text
+ self.x = x
+ self.y = y
+ self.color = color
+ self.bgColor = bgColor
+
+ if align not in (LEFT, CENTER, RIGHT):
+ raise ValueError(
+ "Horizontal alignment not supported: {0}".format(align))
+ self._align = align
+
+ if valign not in (TOP, CENTER, BASELINE, BOTTOM):
+ raise ValueError(
+ "Vertical alignment not supported: {0}".format(valign))
+ self._valign = valign
+
+ self._rotate = numpy.radians(rotate)
+
+ def _getTexture(self, text, devicePixelRatio):
+ # Retrieve/initialize texture cache for current context
+ textureKey = text, devicePixelRatio
+
+ context = Context.getCurrent()
+ if context not in self._textures:
+ self._textures[context] = _Cache(
+ callback=lambda key, value: value[0].discard())
+ textures = self._textures[context]
+
+ if textureKey not in textures:
+ image, offset = font.rasterText(
+ text,
+ font.getDefaultFontFamily(),
+ devicePixelRatio=self.devicePixelRatio)
+ if textureKey not in self._sizes:
+ self._sizes[textureKey] = image.shape[1], image.shape[0]
+
+ texture = Texture(
+ gl.GL_RED,
+ data=image,
+ minFilter=gl.GL_NEAREST,
+ magFilter=gl.GL_NEAREST,
+ wrap=(gl.GL_CLAMP_TO_EDGE,
+ gl.GL_CLAMP_TO_EDGE))
+ texture.prepare()
+ textures[textureKey] = texture, offset
+
+ return textures[textureKey]
+
+ @property
+ def text(self):
+ return self._text
+
+ @property
+ def size(self):
+ textureKey = self.text, self.devicePixelRatio
+ if textureKey not in self._sizes:
+ image, offset = font.rasterText(
+ self.text,
+ font.getDefaultFontFamily(),
+ devicePixelRatio=self.devicePixelRatio)
+ self._sizes[textureKey] = image.shape[1], image.shape[0]
+ return self._sizes[textureKey]
+
+ def getVertices(self, offset, shape):
+ height, width = shape
+
+ if self._align == LEFT:
+ xOrig = 0
+ elif self._align == RIGHT:
+ xOrig = - width
+ else: # CENTER
+ xOrig = - width // 2
+
+ if self._valign == BASELINE:
+ yOrig = - offset
+ elif self._valign == TOP:
+ yOrig = 0
+ elif self._valign == BOTTOM:
+ yOrig = - height
+ else: # CENTER
+ yOrig = - height // 2
+
+ vertices = numpy.array((
+ (xOrig, yOrig),
+ (xOrig + width, yOrig),
+ (xOrig, yOrig + height),
+ (xOrig + width, yOrig + height)), dtype=numpy.float32)
+
+ cos, sin = numpy.cos(self._rotate), numpy.sin(self._rotate)
+ vertices = numpy.ascontiguousarray(numpy.transpose(numpy.array((
+ cos * vertices[:, 0] - sin * vertices[:, 1],
+ sin * vertices[:, 0] + cos * vertices[:, 1]),
+ dtype=numpy.float32)))
+
+ return vertices
+
+ def render(self, matrix):
+ if not self.text:
+ return
+
+ prog = self._program
+ prog.use()
+
+ texUnit = 0
+ texture, offset = self._getTexture(self.text, self.devicePixelRatio)
+
+ gl.glUniform1i(prog.uniforms['texText'], texUnit)
+
+ mat = numpy.dot(matrix, mat4Translate(int(self.x), int(self.y)))
+ gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE,
+ mat.astype(numpy.float32))
+
+ gl.glUniform4f(prog.uniforms['color'], *self.color)
+ if self.bgColor is not None:
+ bgColor = self.bgColor
+ else:
+ bgColor = self.color[0], self.color[1], self.color[2], 0.
+ gl.glUniform4f(prog.uniforms['bgColor'], *bgColor)
+
+ vertices = self.getVertices(offset, texture.shape)
+
+ posAttrib = prog.attributes['position']
+ gl.glEnableVertexAttribArray(posAttrib)
+ gl.glVertexAttribPointer(posAttrib,
+ 2,
+ gl.GL_FLOAT,
+ gl.GL_FALSE,
+ 0,
+ vertices)
+
+ texAttrib = prog.attributes['texCoords']
+ gl.glEnableVertexAttribArray(texAttrib)
+ gl.glVertexAttribPointer(texAttrib,
+ 2,
+ gl.GL_FLOAT,
+ gl.GL_FALSE,
+ 0,
+ self._TEX_COORDS)
+
+ with texture:
+ gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, 4)
diff --git a/src/silx/gui/plot/backends/glutils/GLTexture.py b/src/silx/gui/plot/backends/glutils/GLTexture.py
new file mode 100644
index 0000000..37fbdd0
--- /dev/null
+++ b/src/silx/gui/plot/backends/glutils/GLTexture.py
@@ -0,0 +1,241 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2014-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""This module provides classes wrapping OpenGL texture."""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "03/04/2017"
+
+
+from ctypes import c_void_p
+import logging
+
+import numpy
+
+from ...._glutils import gl, Texture, numpyToGLType
+
+
+_logger = logging.getLogger(__name__)
+
+
+def _checkTexture2D(internalFormat, shape,
+ format_=None, type_=gl.GL_FLOAT, border=0):
+ """Check if texture size with provided parameters is supported
+
+ :rtype: bool
+ """
+ height, width = shape
+ gl.glTexImage2D(gl.GL_PROXY_TEXTURE_2D, 0, internalFormat,
+ width, height, border,
+ format_ or internalFormat,
+ type_, c_void_p(0))
+ width = gl.glGetTexLevelParameteriv(
+ gl.GL_PROXY_TEXTURE_2D, 0, gl.GL_TEXTURE_WIDTH)
+ return bool(width)
+
+
+MIN_TEXTURE_SIZE = 64
+
+
+def _getMaxSquareTexture2DSize(internalFormat=gl.GL_RGBA,
+ format_=None,
+ type_=gl.GL_FLOAT,
+ border=0):
+ """Returns a supported size for a corresponding square texture
+
+ :returns: GL_MAX_TEXTURE_SIZE or a smaller supported size (not optimal)
+ :rtype: int
+ """
+ # Is this useful?
+ maxTexSize = gl.glGetIntegerv(gl.GL_MAX_TEXTURE_SIZE)
+ while maxTexSize > MIN_TEXTURE_SIZE and \
+ not _checkTexture2D(internalFormat, (maxTexSize, maxTexSize),
+ format_, type_, border):
+ maxTexSize //= 2
+ return max(MIN_TEXTURE_SIZE, maxTexSize)
+
+
+class Image(object):
+ """Image of any size eventually using multiple textures or larger texture
+ """
+
+ _WRAP = (gl.GL_CLAMP_TO_EDGE, gl.GL_CLAMP_TO_EDGE)
+ _MIN_FILTER = gl.GL_NEAREST
+ _MAG_FILTER = gl.GL_NEAREST
+
+ def __init__(self, internalFormat, data, format_=None, texUnit=0):
+ self.internalFormat = internalFormat
+ self.height, self.width = data.shape[0:2]
+ type_ = numpyToGLType(data.dtype)
+
+ if _checkTexture2D(internalFormat, data.shape[0:2], format_, type_):
+ texture = Texture(internalFormat,
+ data,
+ format_,
+ texUnit=texUnit,
+ minFilter=self._MIN_FILTER,
+ magFilter=self._MAG_FILTER,
+ wrap=self._WRAP)
+ texture.prepare()
+ vertices = numpy.array((
+ (0., 0., 0., 0.),
+ (self.width, 0., 1., 0.),
+ (0., self.height, 0., 1.),
+ (self.width, self.height, 1., 1.)), dtype=numpy.float32)
+ self.tiles = ((texture, vertices,
+ {'xOrigData': 0, 'yOrigData': 0,
+ 'wData': self.width, 'hData': self.height}),)
+
+ else:
+ # Handle dimension too large: make tiles
+ maxTexSize = _getMaxSquareTexture2DSize(internalFormat,
+ format_, type_)
+
+ nCols = (self.width+maxTexSize-1) // maxTexSize
+ colWidths = [self.width // nCols] * nCols
+ colWidths[-1] += self.width % nCols
+
+ nRows = (self.height+maxTexSize-1) // maxTexSize
+ rowHeights = [self.height//nRows] * nRows
+ rowHeights[-1] += self.height % nRows
+
+ tiles = []
+ yOrig = 0
+ for hData in rowHeights:
+ xOrig = 0
+ for wData in colWidths:
+ if (hData < MIN_TEXTURE_SIZE or wData < MIN_TEXTURE_SIZE) \
+ and not _checkTexture2D(internalFormat,
+ (hData, wData),
+ format_,
+ type_):
+ # Ensure texture size is at least MIN_TEXTURE_SIZE
+ tH = max(hData, MIN_TEXTURE_SIZE)
+ tW = max(wData, MIN_TEXTURE_SIZE)
+
+ uMax, vMax = float(wData)/tW, float(hData)/tH
+
+ # TODO issue with type_ and alignment
+ texture = Texture(internalFormat,
+ data=None,
+ format_=format_,
+ shape=(tH, tW),
+ texUnit=texUnit,
+ minFilter=self._MIN_FILTER,
+ magFilter=self._MAG_FILTER,
+ wrap=self._WRAP)
+ # TODO handle unpack
+ texture.update(format_,
+ data[yOrig:yOrig+hData,
+ xOrig:xOrig+wData])
+ # texture.update(format_, type_, data,
+ # width=wData, height=hData,
+ # unpackRowLength=width,
+ # unpackSkipPixels=xOrig,
+ # unpackSkipRows=yOrig)
+ else:
+ uMax, vMax = 1, 1
+ # TODO issue with type_ and unpacking tiles
+ # TODO idea to handle unpack: use array strides
+ # As it is now, it will make a copy
+ texture = Texture(internalFormat,
+ data[yOrig:yOrig+hData,
+ xOrig:xOrig+wData],
+ format_,
+ texUnit=texUnit,
+ minFilter=self._MIN_FILTER,
+ magFilter=self._MAG_FILTER,
+ wrap=self._WRAP)
+ # TODO
+ # unpackRowLength=width,
+ # unpackSkipPixels=xOrig,
+ # unpackSkipRows=yOrig)
+ vertices = numpy.array((
+ (xOrig, yOrig, 0., 0.),
+ (xOrig + wData, yOrig, uMax, 0.),
+ (xOrig, yOrig + hData, 0., vMax),
+ (xOrig + wData, yOrig + hData, uMax, vMax)),
+ dtype=numpy.float32)
+ texture.prepare()
+ tiles.append((texture, vertices,
+ {'xOrigData': xOrig, 'yOrigData': yOrig,
+ 'wData': wData, 'hData': hData}))
+ xOrig += wData
+ yOrig += hData
+ self.tiles = tuple(tiles)
+
+ def discard(self):
+ for texture, vertices, _ in self.tiles:
+ texture.discard()
+ del self.tiles
+
+ def updateAll(self, format_, data, texUnit=0):
+ if not hasattr(self, 'tiles'):
+ raise RuntimeError("No texture, discard has already been called")
+
+ assert data.shape[:2] == (self.height, self.width)
+ if len(self.tiles) == 1:
+ self.tiles[0][0].update(format_, data, texUnit=texUnit)
+ else:
+ for texture, _, info in self.tiles:
+ yOrig, xOrig = info['yOrigData'], info['xOrigData']
+ height, width = info['hData'], info['wData']
+ texture.update(format_,
+ data[yOrig:yOrig+height, xOrig:xOrig+width],
+ texUnit=texUnit)
+ texture.prepare()
+ # TODO check
+ # width=info['wData'], height=info['hData'],
+ # texUnit=texUnit, unpackAlign=unpackAlign,
+ # unpackRowLength=self.width,
+ # unpackSkipPixels=info['xOrigData'],
+ # unpackSkipRows=info['yOrigData'])
+
+ def render(self, posAttrib, texAttrib, texUnit=0):
+ try:
+ tiles = self.tiles
+ except AttributeError:
+ raise RuntimeError("No texture, discard has already been called")
+
+ for texture, vertices, _ in tiles:
+ texture.bind(texUnit)
+
+ stride = vertices.shape[-1] * vertices.itemsize
+ gl.glEnableVertexAttribArray(posAttrib)
+ gl.glVertexAttribPointer(posAttrib,
+ 2,
+ gl.GL_FLOAT,
+ gl.GL_FALSE,
+ stride, vertices)
+
+ texCoordsPtr = c_void_p(vertices.ctypes.data +
+ 2 * vertices.itemsize)
+ gl.glEnableVertexAttribArray(texAttrib)
+ gl.glVertexAttribPointer(texAttrib,
+ 2,
+ gl.GL_FLOAT,
+ gl.GL_FALSE,
+ stride, texCoordsPtr)
+ gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, len(vertices))
diff --git a/src/silx/gui/plot/backends/glutils/PlotImageFile.py b/src/silx/gui/plot/backends/glutils/PlotImageFile.py
new file mode 100644
index 0000000..5fb6853
--- /dev/null
+++ b/src/silx/gui/plot/backends/glutils/PlotImageFile.py
@@ -0,0 +1,153 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2014-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""Function to save an image to a file."""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "03/04/2017"
+
+
+import base64
+import struct
+import sys
+import zlib
+
+
+# Image writer ################################################################
+
+def convertRGBDataToPNG(data):
+ """Convert a RGB bitmap to PNG.
+
+ It only supports RGB bitmap with one byte per channel stored as a 3D array.
+ See `Definitive Guide <http://www.libpng.org/pub/png/book/>`_ and
+ `Specification <http://www.libpng.org/pub/png/spec/1.2/>`_ for details.
+
+ :param data: A 3D array (h, w, rgb) storing an RGB image
+ :type data: numpy.ndarray of unsigned bytes
+ :returns: The PNG encoded data
+ :rtype: bytes
+ """
+ height, width = data.shape[0], data.shape[1]
+ depth = 8 # 8 bit per channel
+ colorType = 2 # 'truecolor' = RGB
+ interlace = 0 # No
+
+ IHDRdata = struct.pack(">ccccIIBBBBB", b'I', b'H', b'D', b'R',
+ width, height, depth, colorType,
+ 0, 0, interlace)
+
+ # Add filter 'None' before each scanline
+ preparedData = b'\x00' + b'\x00'.join(line.tobytes() for line in data)
+ compressedData = zlib.compress(preparedData, 8)
+
+ IDATdata = struct.pack("cccc", b'I', b'D', b'A', b'T')
+ IDATdata += compressedData
+
+ return b''.join([
+ b'\x89PNG\r\n\x1a\n', # PNG signature
+ # IHDR chunk: Image Header
+ struct.pack(">I", 13), # length
+ IHDRdata,
+ struct.pack(">I", zlib.crc32(IHDRdata) & 0xffffffff), # CRC
+ # IDAT chunk: Payload
+ struct.pack(">I", len(compressedData)),
+ IDATdata,
+ struct.pack(">I", zlib.crc32(IDATdata) & 0xffffffff), # CRC
+ b'\x00\x00\x00\x00IEND\xaeB`\x82' # IEND chunk: footer
+ ])
+
+
+def saveImageToFile(data, fileNameOrObj, fileFormat):
+ """Save a RGB image to a file.
+
+ :param data: A 3D array (h, w, 3) storing an RGB image.
+ :type data: numpy.ndarray with of unsigned bytes.
+ :param fileNameOrObj: Filename or object to use to write the image.
+ :type fileNameOrObj: A str or a 'file-like' object with a 'write' method.
+ :param str fileFormat: The type of the file in: 'png', 'ppm', 'svg', 'tiff'.
+ """
+ assert len(data.shape) == 3
+ assert data.shape[2] == 3
+ assert fileFormat in ('png', 'ppm', 'svg', 'tiff')
+
+ if not hasattr(fileNameOrObj, 'write'):
+ if sys.version_info < (3, ):
+ fileObj = open(fileNameOrObj, "wb")
+ else:
+ if fileFormat in ('png', 'ppm', 'tiff'):
+ # Open in binary mode
+ fileObj = open(fileNameOrObj, 'wb')
+ else:
+ fileObj = open(fileNameOrObj, 'w', newline='')
+ else: # Use as a file-like object
+ fileObj = fileNameOrObj
+
+ if fileFormat == 'svg':
+ height, width = data.shape[:2]
+ base64Data = base64.b64encode(convertRGBDataToPNG(data))
+
+ fileObj.write(
+ '<?xml version="1.0" encoding="UTF-8" standalone="no"?>\n')
+ fileObj.write('<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"\n')
+ fileObj.write(
+ ' "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">\n')
+ fileObj.write('<svg xmlns:xlink="http://www.w3.org/1999/xlink"\n')
+ fileObj.write(' xmlns="http://www.w3.org/2000/svg"\n')
+ fileObj.write(' version="1.1"\n')
+ fileObj.write(' width="%d"\n' % width)
+ fileObj.write(' height="%d">\n' % height)
+ fileObj.write(' <image xlink:href="data:image/png;base64,')
+ fileObj.write(base64Data.decode('ascii'))
+ fileObj.write('"\n')
+ fileObj.write(' x="0"\n')
+ fileObj.write(' y="0"\n')
+ fileObj.write(' width="%d"\n' % width)
+ fileObj.write(' height="%d"\n' % height)
+ fileObj.write(' id="image" />\n')
+ fileObj.write('</svg>')
+
+ elif fileFormat == 'ppm':
+ height, width = data.shape[:2]
+
+ fileObj.write(b'P6\n')
+ fileObj.write(b'%d %d\n' % (width, height))
+ fileObj.write(b'255\n')
+ fileObj.write(data.tobytes())
+
+ elif fileFormat == 'png':
+ fileObj.write(convertRGBDataToPNG(data))
+
+ elif fileFormat == 'tiff':
+ if fileObj == fileNameOrObj:
+ raise NotImplementedError(
+ 'Save TIFF to a file-like object not implemented')
+
+ from silx.third_party.TiffIO import TiffIO
+
+ tif = TiffIO(fileNameOrObj, mode='wb+')
+ tif.writeImage(data, info={'Title': 'OpenGL Plot Snapshot'})
+
+ if fileObj != fileNameOrObj:
+ fileObj.close()
diff --git a/src/silx/gui/plot/backends/glutils/__init__.py b/src/silx/gui/plot/backends/glutils/__init__.py
new file mode 100644
index 0000000..f87d7c1
--- /dev/null
+++ b/src/silx/gui/plot/backends/glutils/__init__.py
@@ -0,0 +1,46 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2014-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""This module provides convenient classes for the OpenGL rendering backend.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "03/04/2017"
+
+
+import logging
+
+
+_logger = logging.getLogger(__name__)
+
+
+from .GLPlotCurve import * # noqa
+from .GLPlotFrame import * # noqa
+from .GLPlotImage import * # noqa
+from .GLPlotItem import GLPlotItem, RenderContext # noqa
+from .GLPlotTriangles import GLPlotTriangles # noqa
+from .GLSupport import * # noqa
+from .GLText import * # noqa
+from .GLTexture import * # noqa
diff --git a/src/silx/gui/plot/items/__init__.py b/src/silx/gui/plot/items/__init__.py
new file mode 100644
index 0000000..0fe29c2
--- /dev/null
+++ b/src/silx/gui/plot/items/__init__.py
@@ -0,0 +1,53 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This package provides classes that describes :class:`.PlotWidget` content.
+
+Instances of those classes are returned by :class:`.PlotWidget` methods that give
+access to its content such as :meth:`.PlotWidget.getCurve`, :meth:`.PlotWidget.getImage`.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "22/06/2017"
+
+from .core import (Item, DataItem, # noqa
+ LabelsMixIn, DraggableMixIn, ColormapMixIn, # noqa
+ SymbolMixIn, ColorMixIn, YAxisMixIn, FillMixIn, # noqa
+ AlphaMixIn, LineMixIn, ScatterVisualizationMixIn, # noqa
+ ComplexMixIn, ItemChangedType, PointsBase) # noqa
+from .complex import ImageComplexData # noqa
+from .curve import Curve, CurveStyle # noqa
+from .histogram import Histogram # noqa
+from .image import ImageBase, ImageData, ImageDataBase, ImageRgba, ImageStack, MaskImageData # noqa
+from .image_aggregated import ImageDataAggregated # noqa
+from .shape import Shape, BoundingRect, XAxisExtent, YAxisExtent # noqa
+from .scatter import Scatter # noqa
+from .marker import MarkerBase, Marker, XMarker, YMarker # noqa
+from .axis import Axis, XAxis, YAxis, YRightAxis
+
+DATA_ITEMS = (ImageComplexData, Curve, Histogram, ImageBase, Scatter,
+ BoundingRect, XAxisExtent, YAxisExtent)
+"""Classes of items representing data and to consider to compute data bounds.
+"""
diff --git a/src/silx/gui/plot/items/_arc_roi.py b/src/silx/gui/plot/items/_arc_roi.py
new file mode 100644
index 0000000..23416ec
--- /dev/null
+++ b/src/silx/gui/plot/items/_arc_roi.py
@@ -0,0 +1,878 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides Arc ROI item for the :class:`~silx.gui.plot.PlotWidget`.
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "28/06/2018"
+
+import logging
+import numpy
+
+from ... import utils
+from .. import items
+from ...colors import rgba
+from ....utils.proxy import docstring
+from ._roi_base import HandleBasedROI
+from ._roi_base import InteractionModeMixIn
+from ._roi_base import RoiInteractionMode
+
+
+logger = logging.getLogger(__name__)
+
+
+class _ArcGeometry:
+ """
+ Non-mutable object to store the geometry of the arc ROI.
+
+ The aim is is to switch between consistent state without dealing with
+ intermediate values.
+ """
+ def __init__(self, center, startPoint, endPoint, radius,
+ weight, startAngle, endAngle, closed=False):
+ """Constructor for a consistent arc geometry.
+
+ There is also specific class method to create different kind of arc
+ geometry.
+ """
+ self.center = center
+ self.startPoint = startPoint
+ self.endPoint = endPoint
+ self.radius = radius
+ self.weight = weight
+ self.startAngle = startAngle
+ self.endAngle = endAngle
+ self._closed = closed
+
+ @classmethod
+ def createEmpty(cls):
+ """Create an arc geometry from an empty shape
+ """
+ zero = numpy.array([0, 0])
+ return cls(zero, zero.copy(), zero.copy(), 0, 0, 0, 0)
+
+ @classmethod
+ def createRect(cls, startPoint, endPoint, weight):
+ """Create an arc geometry from a definition of a rectangle
+ """
+ return cls(None, startPoint, endPoint, None, weight, None, None, False)
+
+ @classmethod
+ def createCircle(cls, center, startPoint, endPoint, radius,
+ weight, startAngle, endAngle):
+ """Create an arc geometry from a definition of a circle
+ """
+ return cls(center, startPoint, endPoint, radius,
+ weight, startAngle, endAngle, True)
+
+ def withWeight(self, weight):
+ """Return a new geometry based on this object, with a specific weight
+ """
+ return _ArcGeometry(self.center, self.startPoint, self.endPoint,
+ self.radius, weight,
+ self.startAngle, self.endAngle, self._closed)
+
+ def withRadius(self, radius):
+ """Return a new geometry based on this object, with a specific radius.
+
+ The weight and the center is conserved.
+ """
+ startPoint = self.center + (self.startPoint - self.center) / self.radius * radius
+ endPoint = self.center + (self.endPoint - self.center) / self.radius * radius
+ return _ArcGeometry(self.center, startPoint, endPoint,
+ radius, self.weight,
+ self.startAngle, self.endAngle, self._closed)
+
+ def withStartAngle(self, startAngle):
+ """Return a new geometry based on this object, with a specific start angle
+ """
+ vector = numpy.array([numpy.cos(startAngle), numpy.sin(startAngle)])
+ startPoint = self.center + vector * self.radius
+
+ # Never add more than 180 to maintain coherency
+ deltaAngle = startAngle - self.startAngle
+ if deltaAngle > numpy.pi:
+ deltaAngle -= numpy.pi * 2
+ elif deltaAngle < -numpy.pi:
+ deltaAngle += numpy.pi * 2
+
+ startAngle = self.startAngle + deltaAngle
+ return _ArcGeometry(
+ self.center,
+ startPoint,
+ self.endPoint,
+ self.radius,
+ self.weight,
+ startAngle,
+ self.endAngle,
+ self._closed,
+ )
+
+ def withEndAngle(self, endAngle):
+ """Return a new geometry based on this object, with a specific end angle
+ """
+ vector = numpy.array([numpy.cos(endAngle), numpy.sin(endAngle)])
+ endPoint = self.center + vector * self.radius
+
+ # Never add more than 180 to maintain coherency
+ deltaAngle = endAngle - self.endAngle
+ if deltaAngle > numpy.pi:
+ deltaAngle -= numpy.pi * 2
+ elif deltaAngle < -numpy.pi:
+ deltaAngle += numpy.pi * 2
+
+ endAngle = self.endAngle + deltaAngle
+ return _ArcGeometry(
+ self.center,
+ self.startPoint,
+ endPoint,
+ self.radius,
+ self.weight,
+ self.startAngle,
+ endAngle,
+ self._closed,
+ )
+
+ def translated(self, dx, dy):
+ """Return the translated geometry by dx, dy"""
+ delta = numpy.array([dx, dy])
+ center = None if self.center is None else self.center + delta
+ startPoint = None if self.startPoint is None else self.startPoint + delta
+ endPoint = None if self.endPoint is None else self.endPoint + delta
+ return _ArcGeometry(center, startPoint, endPoint,
+ self.radius, self.weight,
+ self.startAngle, self.endAngle, self._closed)
+
+ def getKind(self):
+ """Returns the kind of shape defined"""
+ if self.center is None:
+ return "rect"
+ elif numpy.isnan(self.startAngle):
+ return "point"
+ elif self.isClosed():
+ if self.weight <= 0 or self.weight * 0.5 >= self.radius:
+ return "circle"
+ else:
+ return "donut"
+ else:
+ if self.weight * 0.5 < self.radius:
+ return "arc"
+ else:
+ return "camembert"
+
+ def isClosed(self):
+ """Returns True if the geometry is a circle like"""
+ if self._closed is not None:
+ return self._closed
+ delta = numpy.abs(self.endAngle - self.startAngle)
+ self._closed = numpy.isclose(delta, numpy.pi * 2)
+ return self._closed
+
+ def __str__(self):
+ return str((self.center,
+ self.startPoint,
+ self.endPoint,
+ self.radius,
+ self.weight,
+ self.startAngle,
+ self.endAngle,
+ self._closed))
+
+
+class ArcROI(HandleBasedROI, items.LineMixIn, InteractionModeMixIn):
+ """A ROI identifying an arc of a circle with a width.
+
+ This ROI provides
+ - 3 handle to control the curvature
+ - 1 handle to control the weight
+ - 1 anchor to translate the shape.
+ """
+
+ ICON = 'add-shape-arc'
+ NAME = 'arc ROI'
+ SHORT_NAME = "arc"
+ """Metadata for this kind of ROI"""
+
+ _plotShape = "line"
+ """Plot shape which is used for the first interaction"""
+
+ ThreePointMode = RoiInteractionMode("3 points", "Provides 3 points to define the main radius circle")
+ PolarMode = RoiInteractionMode("Polar", "Provides anchors to edit the ROI in polar coords")
+ # FIXME: MoveMode was designed cause there is too much anchors
+ # FIXME: It would be good replace it by a dnd on the shape
+ MoveMode = RoiInteractionMode("Translation", "Provides anchors to only move the ROI")
+
+ def __init__(self, parent=None):
+ HandleBasedROI.__init__(self, parent=parent)
+ items.LineMixIn.__init__(self)
+ InteractionModeMixIn.__init__(self)
+
+ self._geometry = _ArcGeometry.createEmpty()
+ self._handleLabel = self.addLabelHandle()
+
+ self._handleStart = self.addHandle()
+ self._handleMid = self.addHandle()
+ self._handleEnd = self.addHandle()
+ self._handleWeight = self.addHandle()
+ self._handleWeight._setConstraint(self._arcCurvatureMarkerConstraint)
+ self._handleMove = self.addTranslateHandle()
+
+ shape = items.Shape("polygon")
+ shape.setPoints([[0, 0], [0, 0]])
+ shape.setColor(rgba(self.getColor()))
+ shape.setFill(False)
+ shape.setOverlay(True)
+ shape.setLineStyle(self.getLineStyle())
+ shape.setLineWidth(self.getLineWidth())
+ self.__shape = shape
+ self.addItem(shape)
+
+ self._initInteractionMode(self.ThreePointMode)
+ self._interactiveModeUpdated(self.ThreePointMode)
+
+ def availableInteractionModes(self):
+ """Returns the list of available interaction modes
+
+ :rtype: List[RoiInteractionMode]
+ """
+ return [self.ThreePointMode, self.PolarMode, self.MoveMode]
+
+ def _interactiveModeUpdated(self, modeId):
+ """Set the interaction mode.
+
+ :param RoiInteractionMode modeId:
+ """
+ if modeId is self.ThreePointMode:
+ self._handleStart.setSymbol("s")
+ self._handleMid.setSymbol("s")
+ self._handleEnd.setSymbol("s")
+ self._handleWeight.setSymbol("d")
+ self._handleMove.setSymbol("+")
+ elif modeId is self.PolarMode:
+ self._handleStart.setSymbol("o")
+ self._handleMid.setSymbol("o")
+ self._handleEnd.setSymbol("o")
+ self._handleWeight.setSymbol("d")
+ self._handleMove.setSymbol("+")
+ elif modeId is self.MoveMode:
+ self._handleStart.setSymbol("")
+ self._handleMid.setSymbol("+")
+ self._handleEnd.setSymbol("")
+ self._handleWeight.setSymbol("")
+ self._handleMove.setSymbol("+")
+ else:
+ assert False
+ if self._geometry.isClosed():
+ if modeId != self.MoveMode:
+ self._handleStart.setSymbol("x")
+ self._handleEnd.setSymbol("x")
+ self._updateHandles()
+
+ def _updated(self, event=None, checkVisibility=True):
+ if event == items.ItemChangedType.VISIBLE:
+ self._updateItemProperty(event, self, self.__shape)
+ super(ArcROI, self)._updated(event, checkVisibility)
+
+ def _updatedStyle(self, event, style):
+ super(ArcROI, self)._updatedStyle(event, style)
+ self.__shape.setColor(style.getColor())
+ self.__shape.setLineStyle(style.getLineStyle())
+ self.__shape.setLineWidth(style.getLineWidth())
+
+ def setFirstShapePoints(self, points):
+ """"Initialize the ROI using the points from the first interaction.
+
+ This interaction is constrained by the plot API and only supports few
+ shapes.
+ """
+ # The first shape is a line
+ point0 = points[0]
+ point1 = points[1]
+
+ # Compute a non collinear point for the curvature
+ center = (point1 + point0) * 0.5
+ normal = point1 - center
+ normal = numpy.array((normal[1], -normal[0]))
+ defaultCurvature = numpy.pi / 5.0
+ weightCoef = 0.20
+ mid = center - normal * defaultCurvature
+ distance = numpy.linalg.norm(point0 - point1)
+ weight = distance * weightCoef
+
+ geometry = self._createGeometryFromControlPoints(point0, mid, point1, weight)
+ self._geometry = geometry
+ self._updateHandles()
+
+ def _updateText(self, text):
+ self._handleLabel.setText(text)
+
+ def _updateMidHandle(self):
+ """Keep the same geometry, but update the location of the control
+ points.
+
+ So calling this function do not trigger sigRegionChanged.
+ """
+ geometry = self._geometry
+
+ if geometry.isClosed():
+ start = numpy.array(self._handleStart.getPosition())
+ midPos = geometry.center + geometry.center - start
+ else:
+ if geometry.center is None:
+ midPos = geometry.startPoint * 0.5 + geometry.endPoint * 0.5
+ else:
+ midAngle = geometry.startAngle * 0.5 + geometry.endAngle * 0.5
+ vector = numpy.array([numpy.cos(midAngle), numpy.sin(midAngle)])
+ midPos = geometry.center + geometry.radius * vector
+
+ with utils.blockSignals(self._handleMid):
+ self._handleMid.setPosition(*midPos)
+
+ def _updateWeightHandle(self):
+ geometry = self._geometry
+ if geometry.center is None:
+ # rectangle
+ center = (geometry.startPoint + geometry.endPoint) * 0.5
+ normal = geometry.endPoint - geometry.startPoint
+ normal = numpy.array((normal[1], -normal[0]))
+ distance = numpy.linalg.norm(normal)
+ if distance != 0:
+ normal = normal / distance
+ weightPos = center + normal * geometry.weight * 0.5
+ else:
+ if geometry.isClosed():
+ midAngle = geometry.startAngle + numpy.pi * 0.5
+ elif geometry.center is not None:
+ midAngle = (geometry.startAngle + geometry.endAngle) * 0.5
+ vector = numpy.array([numpy.cos(midAngle), numpy.sin(midAngle)])
+ weightPos = geometry.center + (geometry.radius + geometry.weight * 0.5) * vector
+
+ with utils.blockSignals(self._handleWeight):
+ self._handleWeight.setPosition(*weightPos)
+
+ def _getWeightFromHandle(self, weightPos):
+ geometry = self._geometry
+ if geometry.center is None:
+ # rectangle
+ center = (geometry.startPoint + geometry.endPoint) * 0.5
+ return numpy.linalg.norm(center - weightPos) * 2
+ else:
+ distance = numpy.linalg.norm(geometry.center - weightPos)
+ return abs(distance - geometry.radius) * 2
+
+ def _updateHandles(self):
+ geometry = self._geometry
+ with utils.blockSignals(self._handleStart):
+ self._handleStart.setPosition(*geometry.startPoint)
+ with utils.blockSignals(self._handleEnd):
+ self._handleEnd.setPosition(*geometry.endPoint)
+
+ self._updateMidHandle()
+ self._updateWeightHandle()
+ self._updateShape()
+
+ def _updateCurvature(self, start, mid, end, updateCurveHandles, checkClosed=False, updateStart=False):
+ """Update the curvature using 3 control points in the curve
+
+ :param bool updateCurveHandles: If False curve handles are already at
+ the right location
+ """
+ if checkClosed:
+ closed = self._isCloseInPixel(start, end)
+ else:
+ closed = self._geometry.isClosed()
+ if closed:
+ if updateStart:
+ start = end
+ else:
+ end = start
+
+ if updateCurveHandles:
+ with utils.blockSignals(self._handleStart):
+ self._handleStart.setPosition(*start)
+ with utils.blockSignals(self._handleMid):
+ self._handleMid.setPosition(*mid)
+ with utils.blockSignals(self._handleEnd):
+ self._handleEnd.setPosition(*end)
+
+ weight = self._geometry.weight
+ geometry = self._createGeometryFromControlPoints(start, mid, end, weight, closed=closed)
+ self._geometry = geometry
+
+ self._updateWeightHandle()
+ self._updateShape()
+
+ def _updateCloseInAngle(self, geometry, updateStart):
+ azim = numpy.abs(geometry.endAngle - geometry.startAngle)
+ if numpy.pi < azim < 3 * numpy.pi:
+ closed = self._isCloseInPixel(geometry.startPoint, geometry.endPoint)
+ geometry._closed = closed
+ if closed:
+ sign = 1 if geometry.startAngle < geometry.endAngle else -1
+ if updateStart:
+ geometry.startPoint = geometry.endPoint
+ geometry.startAngle = geometry.endAngle - sign * 2*numpy.pi
+ else:
+ geometry.endPoint = geometry.startPoint
+ geometry.endAngle = geometry.startAngle + sign * 2*numpy.pi
+
+ def handleDragUpdated(self, handle, origin, previous, current):
+ modeId = self.getInteractionMode()
+ if handle is self._handleStart:
+ if modeId is self.ThreePointMode:
+ mid = numpy.array(self._handleMid.getPosition())
+ end = numpy.array(self._handleEnd.getPosition())
+ self._updateCurvature(
+ current, mid, end, checkClosed=True, updateStart=True,
+ updateCurveHandles=False
+ )
+ elif modeId is self.PolarMode:
+ v = current - self._geometry.center
+ startAngle = numpy.angle(complex(v[0], v[1]))
+ geometry = self._geometry.withStartAngle(startAngle)
+ self._updateCloseInAngle(geometry, updateStart=True)
+ self._geometry = geometry
+ self._updateHandles()
+ elif handle is self._handleMid:
+ if modeId is self.ThreePointMode:
+ if self._geometry.isClosed():
+ radius = numpy.linalg.norm(self._geometry.center - current)
+ self._geometry = self._geometry.withRadius(radius)
+ self._updateHandles()
+ else:
+ start = numpy.array(self._handleStart.getPosition())
+ end = numpy.array(self._handleEnd.getPosition())
+ self._updateCurvature(start, current, end, updateCurveHandles=False)
+ elif modeId is self.PolarMode:
+ radius = numpy.linalg.norm(self._geometry.center - current)
+ self._geometry = self._geometry.withRadius(radius)
+ self._updateHandles()
+ elif modeId is self.MoveMode:
+ delta = current - previous
+ self.translate(*delta)
+ elif handle is self._handleEnd:
+ if modeId is self.ThreePointMode:
+ start = numpy.array(self._handleStart.getPosition())
+ mid = numpy.array(self._handleMid.getPosition())
+ self._updateCurvature(
+ start, mid, current, checkClosed=True, updateStart=False,
+ updateCurveHandles=False
+ )
+ elif modeId is self.PolarMode:
+ v = current - self._geometry.center
+ endAngle = numpy.angle(complex(v[0], v[1]))
+ geometry = self._geometry.withEndAngle(endAngle)
+ self._updateCloseInAngle(geometry, updateStart=False)
+ self._geometry = geometry
+ self._updateHandles()
+ elif handle is self._handleWeight:
+ weight = self._getWeightFromHandle(current)
+ self._geometry = self._geometry.withWeight(weight)
+ self._updateShape()
+ elif handle is self._handleMove:
+ delta = current - previous
+ self.translate(*delta)
+
+ def _isCloseInPixel(self, point1, point2):
+ manager = self.parent()
+ if manager is None:
+ return False
+ plot = manager.parent()
+ if plot is None:
+ return False
+ point1 = plot.dataToPixel(*point1)
+ if point1 is None:
+ return False
+ point2 = plot.dataToPixel(*point2)
+ if point2 is None:
+ return False
+ return abs(point1[0] - point2[0]) + abs(point1[1] - point2[1]) < 15
+
+ def _normalizeGeometry(self):
+ """Keep the same phisical geometry, but with normalized parameters.
+ """
+ geometry = self._geometry
+ if geometry.weight * 0.5 >= geometry.radius:
+ radius = (geometry.weight * 0.5 + geometry.radius) * 0.5
+ geometry = geometry.withRadius(radius)
+ geometry = geometry.withWeight(radius * 2)
+ self._geometry = geometry
+ return True
+ return False
+
+ def handleDragFinished(self, handle, origin, current):
+ modeId = self.getInteractionMode()
+ if handle in [self._handleStart, self._handleMid, self._handleEnd]:
+ if modeId is self.ThreePointMode:
+ self._normalizeGeometry()
+ self._updateHandles()
+
+ if self._geometry.isClosed():
+ if modeId is self.MoveMode:
+ self._handleStart.setSymbol("")
+ self._handleEnd.setSymbol("")
+ else:
+ self._handleStart.setSymbol("x")
+ self._handleEnd.setSymbol("x")
+ else:
+ if modeId is self.ThreePointMode:
+ self._handleStart.setSymbol("s")
+ self._handleEnd.setSymbol("s")
+ elif modeId is self.PolarMode:
+ self._handleStart.setSymbol("o")
+ self._handleEnd.setSymbol("o")
+ if modeId is self.MoveMode:
+ self._handleStart.setSymbol("")
+ self._handleEnd.setSymbol("")
+
+ def _createGeometryFromControlPoints(self, start, mid, end, weight, closed=None):
+ """Returns the geometry of the object"""
+ if closed or (closed is None and numpy.allclose(start, end)):
+ # Special arc: It's a closed circle
+ center = (start + mid) * 0.5
+ radius = numpy.linalg.norm(start - center)
+ v = start - center
+ startAngle = numpy.angle(complex(v[0], v[1]))
+ endAngle = startAngle + numpy.pi * 2.0
+ return _ArcGeometry.createCircle(
+ center, start, end, radius, weight, startAngle, endAngle
+ )
+
+ elif numpy.linalg.norm(numpy.cross(mid - start, end - start)) < 1e-5:
+ # Degenerated arc, it's a rectangle
+ return _ArcGeometry.createRect(start, end, weight)
+ else:
+ center, radius = self._circleEquation(start, mid, end)
+ v = start - center
+ startAngle = numpy.angle(complex(v[0], v[1]))
+ v = mid - center
+ midAngle = numpy.angle(complex(v[0], v[1]))
+ v = end - center
+ endAngle = numpy.angle(complex(v[0], v[1]))
+
+ # Is it clockwise or anticlockwise
+ relativeMid = (endAngle - midAngle + 2 * numpy.pi) % (2 * numpy.pi)
+ relativeEnd = (endAngle - startAngle + 2 * numpy.pi) % (2 * numpy.pi)
+ if relativeMid < relativeEnd:
+ if endAngle < startAngle:
+ endAngle += 2 * numpy.pi
+ else:
+ if endAngle > startAngle:
+ endAngle -= 2 * numpy.pi
+
+ return _ArcGeometry(center, start, end,
+ radius, weight, startAngle, endAngle)
+
+ def _createShapeFromGeometry(self, geometry):
+ kind = geometry.getKind()
+ if kind == "rect":
+ # It is not an arc
+ # but we can display it as an intermediate shape
+ normal = geometry.endPoint - geometry.startPoint
+ normal = numpy.array((normal[1], -normal[0]))
+ distance = numpy.linalg.norm(normal)
+ if distance != 0:
+ normal /= distance
+ points = numpy.array([
+ geometry.startPoint + normal * geometry.weight * 0.5,
+ geometry.endPoint + normal * geometry.weight * 0.5,
+ geometry.endPoint - normal * geometry.weight * 0.5,
+ geometry.startPoint - normal * geometry.weight * 0.5])
+ elif kind == "point":
+ # It is not an arc
+ # but we can display it as an intermediate shape
+ # NOTE: At least 2 points are expected
+ points = numpy.array([geometry.startPoint, geometry.startPoint])
+ elif kind == "circle":
+ outerRadius = geometry.radius + geometry.weight * 0.5
+ angles = numpy.linspace(0, 2 * numpy.pi, num=50)
+ # It's a circle
+ points = []
+ numpy.append(angles, angles[-1])
+ for angle in angles:
+ direction = numpy.array([numpy.cos(angle), numpy.sin(angle)])
+ points.append(geometry.center + direction * outerRadius)
+ points = numpy.array(points)
+ elif kind == "donut":
+ innerRadius = geometry.radius - geometry.weight * 0.5
+ outerRadius = geometry.radius + geometry.weight * 0.5
+ angles = numpy.linspace(0, 2 * numpy.pi, num=50)
+ # It's a donut
+ points = []
+ # NOTE: NaN value allow to create 2 separated circle shapes
+ # using a single plot item. It's a kind of cheat
+ points.append(numpy.array([float("nan"), float("nan")]))
+ for angle in angles:
+ direction = numpy.array([numpy.cos(angle), numpy.sin(angle)])
+ points.insert(0, geometry.center + direction * innerRadius)
+ points.append(geometry.center + direction * outerRadius)
+ points.append(numpy.array([float("nan"), float("nan")]))
+ points = numpy.array(points)
+ else:
+ innerRadius = geometry.radius - geometry.weight * 0.5
+ outerRadius = geometry.radius + geometry.weight * 0.5
+
+ delta = 0.1 if geometry.endAngle >= geometry.startAngle else -0.1
+ if geometry.startAngle == geometry.endAngle:
+ # Degenerated, it's a line (single radius)
+ angle = geometry.startAngle
+ direction = numpy.array([numpy.cos(angle), numpy.sin(angle)])
+ points = []
+ points.append(geometry.center + direction * innerRadius)
+ points.append(geometry.center + direction * outerRadius)
+ return numpy.array(points)
+
+ angles = numpy.arange(geometry.startAngle, geometry.endAngle, delta)
+ if angles[-1] != geometry.endAngle:
+ angles = numpy.append(angles, geometry.endAngle)
+
+ if kind == "camembert":
+ # It's a part of camembert
+ points = []
+ points.append(geometry.center)
+ points.append(geometry.startPoint)
+ delta = 0.1 if geometry.endAngle >= geometry.startAngle else -0.1
+ for angle in angles:
+ direction = numpy.array([numpy.cos(angle), numpy.sin(angle)])
+ points.append(geometry.center + direction * outerRadius)
+ points.append(geometry.endPoint)
+ points.append(geometry.center)
+ elif kind == "arc":
+ # It's a part of donut
+ points = []
+ points.append(geometry.startPoint)
+ for angle in angles:
+ direction = numpy.array([numpy.cos(angle), numpy.sin(angle)])
+ points.insert(0, geometry.center + direction * innerRadius)
+ points.append(geometry.center + direction * outerRadius)
+ points.insert(0, geometry.endPoint)
+ points.append(geometry.endPoint)
+ else:
+ assert False
+
+ points = numpy.array(points)
+
+ return points
+
+ def _updateShape(self):
+ geometry = self._geometry
+ points = self._createShapeFromGeometry(geometry)
+ self.__shape.setPoints(points)
+
+ index = numpy.nanargmin(points[:, 1])
+ pos = points[index]
+ with utils.blockSignals(self._handleLabel):
+ self._handleLabel.setPosition(pos[0], pos[1])
+
+ if geometry.center is None:
+ movePos = geometry.startPoint * 0.34 + geometry.endPoint * 0.66
+ else:
+ movePos = geometry.center
+
+ with utils.blockSignals(self._handleMove):
+ self._handleMove.setPosition(*movePos)
+
+ self.sigRegionChanged.emit()
+
+ def getGeometry(self):
+ """Returns a tuple containing the geometry of this ROI
+
+ It is a symmetric function of :meth:`setGeometry`.
+
+ If `startAngle` is smaller than `endAngle` the rotation is clockwise,
+ else the rotation is anticlockwise.
+
+ :rtype: Tuple[numpy.ndarray,float,float,float,float]
+ :raise ValueError: In case the ROI can't be represented as section of
+ a circle
+ """
+ geometry = self._geometry
+ if geometry.center is None:
+ raise ValueError("This ROI can't be represented as a section of circle")
+ return geometry.center, self.getInnerRadius(), self.getOuterRadius(), geometry.startAngle, geometry.endAngle
+
+ def isClosed(self):
+ """Returns true if the arc is a closed shape, like a circle or a donut.
+
+ :rtype: bool
+ """
+ return self._geometry.isClosed()
+
+ def getCenter(self):
+ """Returns the center of the circle used to draw arcs of this ROI.
+
+ This center is usually outside the the shape itself.
+
+ :rtype: numpy.ndarray
+ """
+ return self._geometry.center
+
+ def getStartAngle(self):
+ """Returns the angle of the start of the section of this ROI (in radian).
+
+ If `startAngle` is smaller than `endAngle` the rotation is clockwise,
+ else the rotation is anticlockwise.
+
+ :rtype: float
+ """
+ return self._geometry.startAngle
+
+ def getEndAngle(self):
+ """Returns the angle of the end of the section of this ROI (in radian).
+
+ If `startAngle` is smaller than `endAngle` the rotation is clockwise,
+ else the rotation is anticlockwise.
+
+ :rtype: float
+ """
+ return self._geometry.endAngle
+
+ def getInnerRadius(self):
+ """Returns the radius of the smaller arc used to draw this ROI.
+
+ :rtype: float
+ """
+ geometry = self._geometry
+ radius = geometry.radius - geometry.weight * 0.5
+ if radius < 0:
+ radius = 0
+ return radius
+
+ def getOuterRadius(self):
+ """Returns the radius of the bigger arc used to draw this ROI.
+
+ :rtype: float
+ """
+ geometry = self._geometry
+ radius = geometry.radius + geometry.weight * 0.5
+ return radius
+
+ def setGeometry(self, center, innerRadius, outerRadius, startAngle, endAngle):
+ """
+ Set the geometry of this arc.
+
+ :param numpy.ndarray center: Center of the circle.
+ :param float innerRadius: Radius of the smaller arc of the section.
+ :param float outerRadius: Weight of the bigger arc of the section.
+ It have to be bigger than `innerRadius`
+ :param float startAngle: Location of the start of the section (in radian)
+ :param float endAngle: Location of the end of the section (in radian).
+ If `startAngle` is smaller than `endAngle` the rotation is clockwise,
+ else the rotation is anticlockwise.
+ """
+ if innerRadius > outerRadius:
+ logger.error("inner radius larger than outer radius")
+ innerRadius, outerRadius = outerRadius, innerRadius
+ center = numpy.array(center)
+ radius = (innerRadius + outerRadius) * 0.5
+ weight = outerRadius - innerRadius
+
+ vector = numpy.array([numpy.cos(startAngle), numpy.sin(startAngle)])
+ startPoint = center + vector * radius
+ vector = numpy.array([numpy.cos(endAngle), numpy.sin(endAngle)])
+ endPoint = center + vector * radius
+
+ geometry = _ArcGeometry(center, startPoint, endPoint,
+ radius, weight,
+ startAngle, endAngle, closed=None)
+ self._geometry = geometry
+ self._updateHandles()
+
+ @docstring(HandleBasedROI)
+ def contains(self, position):
+ # first check distance, fastest
+ center = self.getCenter()
+ distance = numpy.sqrt((position[1] - center[1]) ** 2 + ((position[0] - center[0])) ** 2)
+ is_in_distance = self.getInnerRadius() <= distance <= self.getOuterRadius()
+ if not is_in_distance:
+ return False
+ rel_pos = position[1] - center[1], position[0] - center[0]
+ angle = numpy.arctan2(*rel_pos)
+ # angle is inside [-pi, pi]
+
+ # Normalize the start angle between [-pi, pi]
+ # with a positive angle range
+ start_angle = self.getStartAngle()
+ end_angle = self.getEndAngle()
+ azim_range = end_angle - start_angle
+ if azim_range < 0:
+ start_angle = end_angle
+ azim_range = -azim_range
+ start_angle = numpy.mod(start_angle + numpy.pi, 2 * numpy.pi) - numpy.pi
+
+ if angle < start_angle:
+ angle += 2 * numpy.pi
+ return start_angle <= angle <= start_angle + azim_range
+
+ def translate(self, x, y):
+ self._geometry = self._geometry.translated(x, y)
+ self._updateHandles()
+
+ def _arcCurvatureMarkerConstraint(self, x, y):
+ """Curvature marker remains on perpendicular bisector"""
+ geometry = self._geometry
+ if geometry.center is None:
+ center = (geometry.startPoint + geometry.endPoint) * 0.5
+ vector = geometry.startPoint - geometry.endPoint
+ vector = numpy.array((vector[1], -vector[0]))
+ vdist = numpy.linalg.norm(vector)
+ if vdist != 0:
+ normal = numpy.array((vector[1], -vector[0])) / vdist
+ else:
+ normal = numpy.array((0, 0))
+ else:
+ if geometry.isClosed():
+ midAngle = geometry.startAngle + numpy.pi * 0.5
+ else:
+ midAngle = (geometry.startAngle + geometry.endAngle) * 0.5
+ normal = numpy.array([numpy.cos(midAngle), numpy.sin(midAngle)])
+ center = geometry.center
+ dist = numpy.dot(normal, (numpy.array((x, y)) - center))
+ dist = numpy.clip(dist, geometry.radius, geometry.radius * 2)
+ x, y = center + dist * normal
+ return x, y
+
+ @staticmethod
+ def _circleEquation(pt1, pt2, pt3):
+ """Circle equation from 3 (x, y) points
+
+ :return: Position of the center of the circle and the radius
+ :rtype: Tuple[Tuple[float,float],float]
+ """
+ x, y, z = complex(*pt1), complex(*pt2), complex(*pt3)
+ w = z - x
+ w /= y - x
+ c = (x - y) * (w - abs(w) ** 2) / 2j / w.imag - x
+ return numpy.array((-c.real, -c.imag)), abs(c + x)
+
+ def __str__(self):
+ try:
+ center, innerRadius, outerRadius, startAngle, endAngle = self.getGeometry()
+ params = center[0], center[1], innerRadius, outerRadius, startAngle, endAngle
+ params = 'center: %f %f; radius: %f %f; angles: %f %f' % params
+ except ValueError:
+ params = "invalid"
+ return "%s(%s)" % (self.__class__.__name__, params)
diff --git a/src/silx/gui/plot/items/_pick.py b/src/silx/gui/plot/items/_pick.py
new file mode 100644
index 0000000..8c8e781
--- /dev/null
+++ b/src/silx/gui/plot/items/_pick.py
@@ -0,0 +1,72 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2019-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides classes supporting item picking."""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "04/06/2019"
+
+import numpy
+
+
+class PickingResult(object):
+ """Class to access picking information in a :class:`PlotWidget`"""
+
+ def __init__(self, item, indices=None):
+ """Init
+
+ :param item: The picked item
+ :param numpy.ndarray indices: Array-like of indices of picked data.
+ Either 1D or 2D with dim0: data dimension and dim1: indices.
+ No copy is made.
+ """
+ self._item = item
+
+ if indices is None or len(indices) == 0:
+ self._indices = None
+ else:
+ # Indices is set to None if indices array is empty
+ indices = numpy.array(indices, copy=False, dtype=numpy.int64)
+ self._indices = None if indices.size == 0 else indices
+
+ def getItem(self):
+ """Returns the item this results corresponds to."""
+ return self._item
+
+ def getIndices(self, copy=True):
+ """Returns indices of picked data.
+
+ If data is 1D, it returns a numpy.ndarray, otherwise
+ it returns a tuple with as many numpy.ndarray as there are
+ dimensions in the data.
+
+ :param bool copy: True (default) to get a copy,
+ False to return internal arrays
+ :rtype: Union[None,numpy.ndarray,List[numpy.ndarray]]
+ """
+ if self._indices is None:
+ return None
+ indices = numpy.array(self._indices, copy=copy)
+ return indices if indices.ndim == 1 else tuple(indices)
diff --git a/src/silx/gui/plot/items/_roi_base.py b/src/silx/gui/plot/items/_roi_base.py
new file mode 100644
index 0000000..3eb6cf4
--- /dev/null
+++ b/src/silx/gui/plot/items/_roi_base.py
@@ -0,0 +1,835 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides base components to create ROI item for
+the :class:`~silx.gui.plot.PlotWidget`.
+
+.. inheritance-diagram::
+ silx.gui.plot.items.roi
+ :parts: 1
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "28/06/2018"
+
+
+import logging
+import numpy
+import weakref
+
+from ....utils.weakref import WeakList
+from ... import qt
+from .. import items
+from ..items import core
+from ...colors import rgba
+import silx.utils.deprecation
+from ....utils.proxy import docstring
+
+
+logger = logging.getLogger(__name__)
+
+
+class _RegionOfInterestBase(qt.QObject):
+ """Base class of 1D and 2D region of interest
+
+ :param QObject parent: See QObject
+ :param str name: The name of the ROI
+ """
+
+ sigAboutToBeRemoved = qt.Signal()
+ """Signal emitted just before this ROI is removed from its manager."""
+
+ sigItemChanged = qt.Signal(object)
+ """Signal emitted when item has changed.
+
+ It provides a flag describing which property of the item has changed.
+ See :class:`ItemChangedType` for flags description.
+ """
+
+ def __init__(self, parent=None):
+ qt.QObject.__init__(self, parent=parent)
+ self.__name = ''
+
+ def getName(self):
+ """Returns the name of the ROI
+
+ :return: name of the region of interest
+ :rtype: str
+ """
+ return self.__name
+
+ def setName(self, name):
+ """Set the name of the ROI
+
+ :param str name: name of the region of interest
+ """
+ name = str(name)
+ if self.__name != name:
+ self.__name = name
+ self._updated(items.ItemChangedType.NAME)
+
+ def _updated(self, event=None, checkVisibility=True):
+ """Implement Item mix-in update method by updating the plot items
+
+ See :class:`~silx.gui.plot.items.Item._updated`
+ """
+ self.sigItemChanged.emit(event)
+
+ def contains(self, position):
+ """Returns True if the `position` is in this ROI.
+
+ :param tuple[float,float] position: position to check
+ :return: True if the value / point is consider to be in the region of
+ interest.
+ :rtype: bool
+ """
+ return False # Override in subclass to perform actual test
+
+
+class RoiInteractionMode(object):
+ """Description of an interaction mode.
+
+ An interaction mode provide a specific kind of interaction for a ROI.
+ A ROI can implement many interaction.
+ """
+
+ def __init__(self, label, description=None):
+ self._label = label
+ self._description = description
+
+ @property
+ def label(self):
+ return self._label
+
+ @property
+ def description(self):
+ return self._description
+
+
+class InteractionModeMixIn(object):
+ """Mix in feature which can be implemented by a ROI object.
+
+ This provides user interaction to switch between different
+ interaction mode to edit the ROI.
+
+ This ROI modes have to be described using `RoiInteractionMode`,
+ and taken into account during interation with handles.
+ """
+
+ sigInteractionModeChanged = qt.Signal(object)
+
+ def __init__(self):
+ self.__modeId = None
+
+ def _initInteractionMode(self, modeId):
+ """Set the mode without updating anything.
+
+ Must be one of the returned :meth:`availableInteractionModes`.
+
+ :param RoiInteractionMode modeId: Mode to use
+ """
+ self.__modeId = modeId
+
+ def availableInteractionModes(self):
+ """Returns the list of available interaction modes
+
+ Must be implemented when inherited to provide all available modes.
+
+ :rtype: List[RoiInteractionMode]
+ """
+ raise NotImplementedError()
+
+ def setInteractionMode(self, modeId):
+ """Set the interaction mode.
+
+ :param RoiInteractionMode modeId: Mode to use
+ """
+ self.__modeId = modeId
+ self._interactiveModeUpdated(modeId)
+ self.sigInteractionModeChanged.emit(modeId)
+
+ def _interactiveModeUpdated(self, modeId):
+ """Called directly after an update of the mode.
+
+ The signal `sigInteractionModeChanged` is triggered after this
+ call.
+
+ Must be implemented when inherited to take care of the change.
+ """
+ raise NotImplementedError()
+
+ def getInteractionMode(self):
+ """Returns the interaction mode.
+
+ Must be one of the returned :meth:`availableInteractionModes`.
+
+ :rtype: RoiInteractionMode
+ """
+ return self.__modeId
+
+
+class RegionOfInterest(_RegionOfInterestBase, core.HighlightedMixIn):
+ """Object describing a region of interest in a plot.
+
+ :param QObject parent:
+ The RegionOfInterestManager that created this object
+ """
+
+ _DEFAULT_LINEWIDTH = 1.
+ """Default line width of the curve"""
+
+ _DEFAULT_LINESTYLE = '-'
+ """Default line style of the curve"""
+
+ _DEFAULT_HIGHLIGHT_STYLE = items.CurveStyle(linewidth=2)
+ """Default highlight style of the item"""
+
+ ICON, NAME, SHORT_NAME = None, None, None
+ """Metadata to describe the ROI in labels, tooltips and widgets
+
+ Should be set by inherited classes to custom the ROI manager widget.
+ """
+
+ sigRegionChanged = qt.Signal()
+ """Signal emitted everytime the shape or position of the ROI changes"""
+
+ sigEditingStarted = qt.Signal()
+ """Signal emitted when the user start editing the roi"""
+
+ sigEditingFinished = qt.Signal()
+ """Signal emitted when the region edition is finished. During edition
+ sigEditionChanged will be emitted several times and
+ sigRegionEditionFinished only at end"""
+
+ def __init__(self, parent=None):
+ # Avoid circular dependency
+ from ..tools import roi as roi_tools
+ assert parent is None or isinstance(parent, roi_tools.RegionOfInterestManager)
+ _RegionOfInterestBase.__init__(self, parent)
+ core.HighlightedMixIn.__init__(self)
+ self._color = rgba('red')
+ self._editable = False
+ self._selectable = False
+ self._focusProxy = None
+ self._visible = True
+ self._child = WeakList()
+
+ def _connectToPlot(self, plot):
+ """Called after connection to a plot"""
+ for item in self.getItems():
+ # This hack is needed to avoid reentrant call from _disconnectFromPlot
+ # to the ROI manager. It also speed up the item tests in _itemRemoved
+ item._roiGroup = True
+ plot.addItem(item)
+
+ def _disconnectFromPlot(self, plot):
+ """Called before disconnection from a plot"""
+ for item in self.getItems():
+ # The item could be already be removed by the plot
+ if item.getPlot() is not None:
+ del item._roiGroup
+ plot.removeItem(item)
+
+ def _setItemName(self, item):
+ """Helper to generate a unique id to a plot item"""
+ legend = "__ROI-%d__%d" % (id(self), id(item))
+ item.setName(legend)
+
+ def setParent(self, parent):
+ """Set the parent of the RegionOfInterest
+
+ :param Union[None,RegionOfInterestManager] parent: The new parent
+ """
+ # Avoid circular dependency
+ from ..tools import roi as roi_tools
+ if (parent is not None and not isinstance(parent, roi_tools.RegionOfInterestManager)):
+ raise ValueError('Unsupported parent')
+
+ previousParent = self.parent()
+ if previousParent is not None:
+ previousPlot = previousParent.parent()
+ if previousPlot is not None:
+ self._disconnectFromPlot(previousPlot)
+ super(RegionOfInterest, self).setParent(parent)
+ if parent is not None:
+ plot = parent.parent()
+ if plot is not None:
+ self._connectToPlot(plot)
+
+ def addItem(self, item):
+ """Add an item to the set of this ROI children.
+
+ This item will be added and removed to the plot used by the ROI.
+
+ If the ROI is already part of a plot, the item will also be added to
+ the plot.
+
+ It the item do not have a name already, a unique one is generated to
+ avoid item collision in the plot.
+
+ :param silx.gui.plot.items.Item item: A plot item
+ """
+ assert item is not None
+ self._child.append(item)
+ if item.getName() == '':
+ self._setItemName(item)
+ manager = self.parent()
+ if manager is not None:
+ plot = manager.parent()
+ if plot is not None:
+ item._roiGroup = True
+ plot.addItem(item)
+
+ def removeItem(self, item):
+ """Remove an item from this ROI children.
+
+ If the item is part of a plot it will be removed too.
+
+ :param silx.gui.plot.items.Item item: A plot item
+ """
+ assert item is not None
+ self._child.remove(item)
+ plot = item.getPlot()
+ if plot is not None:
+ del item._roiGroup
+ plot.removeItem(item)
+
+ def getItems(self):
+ """Returns the list of PlotWidget items of this RegionOfInterest.
+
+ :rtype: List[~silx.gui.plot.items.Item]
+ """
+ return tuple(self._child)
+
+ @classmethod
+ def _getShortName(cls):
+ """Return an human readable kind of ROI
+
+ :rtype: str
+ """
+ if hasattr(cls, "SHORT_NAME"):
+ name = cls.SHORT_NAME
+ if name is None:
+ name = cls.__name__
+ return name
+
+ def getColor(self):
+ """Returns the color of this ROI
+
+ :rtype: QColor
+ """
+ return qt.QColor.fromRgbF(*self._color)
+
+ def setColor(self, color):
+ """Set the color used for this ROI.
+
+ :param color: The color to use for ROI shape as
+ either a color name, a QColor, a list of uint8 or float in [0, 1].
+ """
+ color = rgba(color)
+ if color != self._color:
+ self._color = color
+ self._updated(items.ItemChangedType.COLOR)
+
+ @silx.utils.deprecation.deprecated(reason='API modification',
+ replacement='getName()',
+ since_version=0.12)
+ def getLabel(self):
+ """Returns the label displayed for this ROI.
+
+ :rtype: str
+ """
+ return self.getName()
+
+ @silx.utils.deprecation.deprecated(reason='API modification',
+ replacement='setName(name)',
+ since_version=0.12)
+ def setLabel(self, label):
+ """Set the label displayed with this ROI.
+
+ :param str label: The text label to display
+ """
+ self.setName(name=label)
+
+ def isEditable(self):
+ """Returns whether the ROI is editable by the user or not.
+
+ :rtype: bool
+ """
+ return self._editable
+
+ def setEditable(self, editable):
+ """Set whether the ROI can be changed interactively.
+
+ :param bool editable: True to allow edition by the user,
+ False to disable.
+ """
+ editable = bool(editable)
+ if self._editable != editable:
+ self._editable = editable
+ self._updated(items.ItemChangedType.EDITABLE)
+
+ def isSelectable(self):
+ """Returns whether the ROI is selectable by the user or not.
+
+ :rtype: bool
+ """
+ return self._selectable
+
+ def setSelectable(self, selectable):
+ """Set whether the ROI can be selected interactively.
+
+ :param bool selectable: True to allow selection by the user,
+ False to disable.
+ """
+ selectable = bool(selectable)
+ if self._selectable != selectable:
+ self._selectable = selectable
+ self._updated(items.ItemChangedType.SELECTABLE)
+
+ def getFocusProxy(self):
+ """Returns the ROI which have to be selected when this ROI is selected,
+ else None if no proxy specified.
+
+ :rtype: RegionOfInterest
+ """
+ proxy = self._focusProxy
+ if proxy is None:
+ return None
+ proxy = proxy()
+ if proxy is None:
+ self._focusProxy = None
+ return proxy
+
+ def setFocusProxy(self, roi):
+ """Set the real ROI which will be selected when this ROI is selected,
+ else None to remove the proxy already specified.
+
+ :param RegionOfInterest roi: A ROI
+ """
+ if roi is not None:
+ self._focusProxy = weakref.ref(roi)
+ else:
+ self._focusProxy = None
+
+ def isVisible(self):
+ """Returns whether the ROI is visible in the plot.
+
+ .. note::
+ This does not take into account whether or not the plot
+ widget itself is visible (unlike :meth:`QWidget.isVisible` which
+ checks the visibility of all its parent widgets up to the window)
+
+ :rtype: bool
+ """
+ return self._visible
+
+ def setVisible(self, visible):
+ """Set whether the plot items associated with this ROI are
+ visible in the plot.
+
+ :param bool visible: True to show the ROI in the plot, False to
+ hide it.
+ """
+ visible = bool(visible)
+ if self._visible != visible:
+ self._visible = visible
+ self._updated(items.ItemChangedType.VISIBLE)
+
+ @classmethod
+ def showFirstInteractionShape(cls):
+ """Returns True if the shape created by the first interaction and
+ managed by the plot have to be visible.
+
+ :rtype: bool
+ """
+ return False
+
+ @classmethod
+ def getFirstInteractionShape(cls):
+ """Returns the shape kind which will be used by the very first
+ interaction with the plot.
+
+ This interactions are hardcoded inside the plot
+
+ :rtype: str
+ """
+ return cls._plotShape
+
+ def setFirstShapePoints(self, points):
+ """"Initialize the ROI using the points from the first interaction.
+
+ This interaction is constrained by the plot API and only supports few
+ shapes.
+ """
+ raise NotImplementedError()
+
+ def creationStarted(self):
+ """"Called when the ROI creation interaction was started.
+ """
+ pass
+
+ def creationFinalized(self):
+ """"Called when the ROI creation interaction was finalized.
+ """
+ pass
+
+ def _updateItemProperty(self, event, source, destination):
+ """Update the item property of a destination from an item source.
+
+ :param items.ItemChangedType event: Property type to update
+ :param silx.gui.plot.items.Item source: The reference for the data
+ :param event Union[Item,List[Item]] destination: The item(s) to update
+ """
+ if not isinstance(destination, (list, tuple)):
+ destination = [destination]
+ if event == items.ItemChangedType.NAME:
+ value = source.getName()
+ for d in destination:
+ d.setName(value)
+ elif event == items.ItemChangedType.EDITABLE:
+ value = source.isEditable()
+ for d in destination:
+ d.setEditable(value)
+ elif event == items.ItemChangedType.SELECTABLE:
+ value = source.isSelectable()
+ for d in destination:
+ d._setSelectable(value)
+ elif event == items.ItemChangedType.COLOR:
+ value = rgba(source.getColor())
+ for d in destination:
+ d.setColor(value)
+ elif event == items.ItemChangedType.LINE_STYLE:
+ value = self.getLineStyle()
+ for d in destination:
+ d.setLineStyle(value)
+ elif event == items.ItemChangedType.LINE_WIDTH:
+ value = self.getLineWidth()
+ for d in destination:
+ d.setLineWidth(value)
+ elif event == items.ItemChangedType.SYMBOL:
+ value = self.getSymbol()
+ for d in destination:
+ d.setSymbol(value)
+ elif event == items.ItemChangedType.SYMBOL_SIZE:
+ value = self.getSymbolSize()
+ for d in destination:
+ d.setSymbolSize(value)
+ elif event == items.ItemChangedType.VISIBLE:
+ value = self.isVisible()
+ for d in destination:
+ d.setVisible(value)
+ else:
+ assert False
+
+ def _updated(self, event=None, checkVisibility=True):
+ if event == items.ItemChangedType.HIGHLIGHTED:
+ style = self.getCurrentStyle()
+ self._updatedStyle(event, style)
+ else:
+ styleEvents = [items.ItemChangedType.COLOR,
+ items.ItemChangedType.LINE_STYLE,
+ items.ItemChangedType.LINE_WIDTH,
+ items.ItemChangedType.SYMBOL,
+ items.ItemChangedType.SYMBOL_SIZE]
+ if self.isHighlighted():
+ styleEvents.append(items.ItemChangedType.HIGHLIGHTED_STYLE)
+
+ if event in styleEvents:
+ style = self.getCurrentStyle()
+ self._updatedStyle(event, style)
+
+ super(RegionOfInterest, self)._updated(event, checkVisibility)
+
+ def _updatedStyle(self, event, style):
+ """Called when the current displayed style of the ROI was changed.
+
+ :param event: The event responsible of the change of the style
+ :param items.CurveStyle style: The current style
+ """
+ pass
+
+ def getCurrentStyle(self):
+ """Returns the current curve style.
+
+ Curve style depends on curve highlighting
+
+ :rtype: CurveStyle
+ """
+ baseColor = rgba(self.getColor())
+ if isinstance(self, core.LineMixIn):
+ baseLinestyle = self.getLineStyle()
+ baseLinewidth = self.getLineWidth()
+ else:
+ baseLinestyle = self._DEFAULT_LINESTYLE
+ baseLinewidth = self._DEFAULT_LINEWIDTH
+ if isinstance(self, core.SymbolMixIn):
+ baseSymbol = self.getSymbol()
+ baseSymbolsize = self.getSymbolSize()
+ else:
+ baseSymbol = 'o'
+ baseSymbolsize = 1
+
+ if self.isHighlighted():
+ style = self.getHighlightedStyle()
+ color = style.getColor()
+ linestyle = style.getLineStyle()
+ linewidth = style.getLineWidth()
+ symbol = style.getSymbol()
+ symbolsize = style.getSymbolSize()
+
+ return items.CurveStyle(
+ color=baseColor if color is None else color,
+ linestyle=baseLinestyle if linestyle is None else linestyle,
+ linewidth=baseLinewidth if linewidth is None else linewidth,
+ symbol=baseSymbol if symbol is None else symbol,
+ symbolsize=baseSymbolsize if symbolsize is None else symbolsize)
+ else:
+ return items.CurveStyle(color=baseColor,
+ linestyle=baseLinestyle,
+ linewidth=baseLinewidth,
+ symbol=baseSymbol,
+ symbolsize=baseSymbolsize)
+
+ def _editingStarted(self):
+ assert self._editable is True
+ self.sigEditingStarted.emit()
+
+ def _editingFinished(self):
+ self.sigEditingFinished.emit()
+
+
+class HandleBasedROI(RegionOfInterest):
+ """Manage a ROI based on a set of handles"""
+
+ def __init__(self, parent=None):
+ RegionOfInterest.__init__(self, parent=parent)
+ self._handles = []
+ self._posOrigin = None
+ self._posPrevious = None
+
+ def addUserHandle(self, item=None):
+ """
+ Add a new free handle to the ROI.
+
+ This handle do nothing. It have to be managed by the ROI
+ implementing this class.
+
+ :param Union[None,silx.gui.plot.items.Marker] item: The new marker to
+ add, else None to create a default marker.
+ :rtype: silx.gui.plot.items.Marker
+ """
+ return self.addHandle(item, role="user")
+
+ def addLabelHandle(self, item=None):
+ """
+ Add a new label handle to the ROI.
+
+ This handle is not draggable nor selectable.
+
+ It is displayed without symbol, but it is always visible anyway
+ the ROI is editable, in order to display text.
+
+ :param Union[None,silx.gui.plot.items.Marker] item: The new marker to
+ add, else None to create a default marker.
+ :rtype: silx.gui.plot.items.Marker
+ """
+ return self.addHandle(item, role="label")
+
+ def addTranslateHandle(self, item=None):
+ """
+ Add a new translate handle to the ROI.
+
+ Dragging translate handles affect the position position of the ROI
+ but not the shape itself.
+
+ :param Union[None,silx.gui.plot.items.Marker] item: The new marker to
+ add, else None to create a default marker.
+ :rtype: silx.gui.plot.items.Marker
+ """
+ return self.addHandle(item, role="translate")
+
+ def addHandle(self, item=None, role="default"):
+ """
+ Add a new handle to the ROI.
+
+ Dragging handles while affect the position or the shape of the
+ ROI.
+
+ :param Union[None,silx.gui.plot.items.Marker] item: The new marker to
+ add, else None to create a default marker.
+ :rtype: silx.gui.plot.items.Marker
+ """
+ if item is None:
+ item = items.Marker()
+ color = rgba(self.getColor())
+ color = self._computeHandleColor(color)
+ item.setColor(color)
+ if role == "default":
+ item.setSymbol("s")
+ elif role == "user":
+ pass
+ elif role == "translate":
+ item.setSymbol("+")
+ elif role == "label":
+ item.setSymbol("")
+
+ if role == "user":
+ pass
+ elif role == "label":
+ item._setSelectable(False)
+ item._setDraggable(False)
+ item.setVisible(True)
+ else:
+ self.__updateEditable(item, self.isEditable(), remove=False)
+ item._setSelectable(False)
+
+ self._handles.append((item, role))
+ self.addItem(item)
+ return item
+
+ def removeHandle(self, handle):
+ data = [d for d in self._handles if d[0] is handle][0]
+ self._handles.remove(data)
+ role = data[1]
+ if role not in ["user", "label"]:
+ if self.isEditable():
+ self.__updateEditable(handle, False)
+ self.removeItem(handle)
+
+ def getHandles(self):
+ """Returns the list of handles of this HandleBasedROI.
+
+ :rtype: List[~silx.gui.plot.items.Marker]
+ """
+ return tuple(data[0] for data in self._handles)
+
+ def _updated(self, event=None, checkVisibility=True):
+ """Implement Item mix-in update method by updating the plot items
+
+ See :class:`~silx.gui.plot.items.Item._updated`
+ """
+ if event == items.ItemChangedType.NAME:
+ self._updateText(self.getName())
+ elif event == items.ItemChangedType.VISIBLE:
+ for item, role in self._handles:
+ visible = self.isVisible()
+ editionVisible = visible and self.isEditable()
+ if role not in ["user", "label"]:
+ item.setVisible(editionVisible)
+ else:
+ item.setVisible(visible)
+ elif event == items.ItemChangedType.EDITABLE:
+ for item, role in self._handles:
+ editable = self.isEditable()
+ if role not in ["user", "label"]:
+ self.__updateEditable(item, editable)
+ super(HandleBasedROI, self)._updated(event, checkVisibility)
+
+ def _updatedStyle(self, event, style):
+ super(HandleBasedROI, self)._updatedStyle(event, style)
+
+ # Update color of shape items in the plot
+ color = rgba(self.getColor())
+ handleColor = self._computeHandleColor(color)
+ for item, role in self._handles:
+ if role == 'user':
+ pass
+ elif role == 'label':
+ item.setColor(color)
+ else:
+ item.setColor(handleColor)
+
+ def __updateEditable(self, handle, editable, remove=True):
+ # NOTE: visibility change emit a position update event
+ handle.setVisible(editable and self.isVisible())
+ handle._setDraggable(editable)
+ if editable:
+ handle.sigDragStarted.connect(self._handleEditingStarted)
+ handle.sigItemChanged.connect(self._handleEditingUpdated)
+ handle.sigDragFinished.connect(self._handleEditingFinished)
+ else:
+ if remove:
+ handle.sigDragStarted.disconnect(self._handleEditingStarted)
+ handle.sigItemChanged.disconnect(self._handleEditingUpdated)
+ handle.sigDragFinished.disconnect(self._handleEditingFinished)
+
+ def _handleEditingStarted(self):
+ super(HandleBasedROI, self)._editingStarted()
+ handle = self.sender()
+ self._posOrigin = numpy.array(handle.getPosition())
+ self._posPrevious = numpy.array(self._posOrigin)
+ self.handleDragStarted(handle, self._posOrigin)
+
+ def _handleEditingUpdated(self):
+ if self._posOrigin is None:
+ # Avoid to handle events when visibility change
+ return
+ handle = self.sender()
+ current = numpy.array(handle.getPosition())
+ self.handleDragUpdated(handle, self._posOrigin, self._posPrevious, current)
+ self._posPrevious = current
+
+ def _handleEditingFinished(self):
+ handle = self.sender()
+ current = numpy.array(handle.getPosition())
+ self.handleDragFinished(handle, self._posOrigin, current)
+ self._posPrevious = None
+ self._posOrigin = None
+ super(HandleBasedROI, self)._editingFinished()
+
+ def isHandleBeingDragged(self):
+ """Returns True if one of the handles is currently being dragged.
+
+ :rtype: bool
+ """
+ return self._posOrigin is not None
+
+ def handleDragStarted(self, handle, origin):
+ """Called when an handler drag started"""
+ pass
+
+ def handleDragUpdated(self, handle, origin, previous, current):
+ """Called when an handle drag position changed"""
+ pass
+
+ def handleDragFinished(self, handle, origin, current):
+ """Called when an handle drag finished"""
+ pass
+
+ def _computeHandleColor(self, color):
+ """Returns the anchor color from the base ROI color
+
+ :param Union[numpy.array,Tuple,List]: color
+ :rtype: Union[numpy.array,Tuple,List]
+ """
+ return color[:3] + (0.5,)
+
+ def _updateText(self, text):
+ """Update the text displayed by this ROI
+
+ :param str text: A text
+ """
+ pass
diff --git a/src/silx/gui/plot/items/axis.py b/src/silx/gui/plot/items/axis.py
new file mode 100644
index 0000000..c73323e
--- /dev/null
+++ b/src/silx/gui/plot/items/axis.py
@@ -0,0 +1,560 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides the class for axes of the :class:`PlotWidget`.
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "22/11/2018"
+
+import datetime as dt
+import enum
+import logging
+
+import dateutil.tz
+import numpy
+
+from ... import qt
+from .. import _utils
+
+
+_logger = logging.getLogger(__name__)
+
+
+class TickMode(enum.Enum):
+ """Determines if ticks are regular number or datetimes."""
+ DEFAULT = 0 # Ticks are regular numbers
+ TIME_SERIES = 1 # Ticks are datetime objects
+
+
+class Axis(qt.QObject):
+ """This class describes and controls a plot axis.
+
+ Note: This is an abstract class.
+ """
+ # States are half-stored on the backend of the plot, and half-stored on this
+ # object.
+ # TODO It would be good to store all the states of an axis in this object.
+ # i.e. vmin and vmax
+
+ LINEAR = "linear"
+ """Constant defining a linear scale"""
+
+ LOGARITHMIC = "log"
+ """Constant defining a logarithmic scale"""
+
+ _SCALES = set([LINEAR, LOGARITHMIC])
+
+ sigInvertedChanged = qt.Signal(bool)
+ """Signal emitted when axis orientation has changed"""
+
+ sigScaleChanged = qt.Signal(str)
+ """Signal emitted when axis scale has changed"""
+
+ _sigLogarithmicChanged = qt.Signal(bool)
+ """Signal emitted when axis scale has changed to or from logarithmic"""
+
+ sigAutoScaleChanged = qt.Signal(bool)
+ """Signal emitted when axis autoscale has changed"""
+
+ sigLimitsChanged = qt.Signal(float, float)
+ """Signal emitted when axis limits have changed"""
+
+ def __init__(self, plot):
+ """Constructor
+
+ :param silx.gui.plot.PlotWidget.PlotWidget plot: Parent plot of this
+ axis
+ """
+ qt.QObject.__init__(self, parent=plot)
+ self._scale = self.LINEAR
+ self._isAutoScale = True
+ # Store default labels provided to setGraph[X|Y]Label
+ self._defaultLabel = ''
+ # Store currently displayed labels
+ # Current label can differ from input one with active curve handling
+ self._currentLabel = ''
+
+ def _getPlot(self):
+ """Returns the PlotWidget this Axis belongs to.
+
+ :rtype: PlotWidget
+ """
+ plot = self.parent()
+ if plot is None:
+ raise RuntimeError("Axis no longer attached to a PlotWidget")
+ return plot
+
+ def _getBackend(self):
+ """Returns the backend
+
+ :rtype: BackendBase
+ """
+ return self._getPlot()._backend
+
+ def getLimits(self):
+ """Get the limits of this axis.
+
+ :return: Minimum and maximum values of this axis as tuple
+ """
+ return self._internalGetLimits()
+
+ def setLimits(self, vmin, vmax):
+ """Set this axis limits.
+
+ :param float vmin: minimum axis value
+ :param float vmax: maximum axis value
+ """
+ vmin, vmax = self._checkLimits(vmin, vmax)
+ if self.getLimits() == (vmin, vmax):
+ return
+
+ self._internalSetLimits(vmin, vmax)
+ self._getPlot()._setDirtyPlot()
+
+ self._emitLimitsChanged()
+
+ def _emitLimitsChanged(self):
+ """Emit axis sigLimitsChanged and PlotWidget limitsChanged event"""
+ vmin, vmax = self.getLimits()
+ self.sigLimitsChanged.emit(vmin, vmax)
+ self._getPlot()._notifyLimitsChanged(emitSignal=False)
+
+ def _checkLimits(self, vmin, vmax):
+ """Makes sure axis range is not empty and within supported range.
+
+ :param float vmin: Min axis value
+ :param float vmax: Max axis value
+ :return: (min, max) making sure min < max
+ :rtype: 2-tuple of float
+ """
+ return _utils.checkAxisLimits(
+ vmin, vmax, isLog=self._isLogarithmic(), name=self._defaultLabel)
+
+ def isInverted(self):
+ """Return True if the axis is inverted (top to bottom for the y-axis),
+ False otherwise. It is always False for the X axis.
+
+ :rtype: bool
+ """
+ return False
+
+ def setInverted(self, isInverted):
+ """Set the axis orientation.
+
+ This is only available for the Y axis.
+
+ :param bool flag: True for Y axis going from top to bottom,
+ False for Y axis going from bottom to top
+ """
+ if isInverted == self.isInverted():
+ return
+ raise NotImplementedError()
+
+ def getLabel(self):
+ """Return the current displayed label of this axis.
+
+ :param str axis: The Y axis for which to get the label (left or right)
+ :rtype: str
+ """
+ return self._currentLabel
+
+ def setLabel(self, label):
+ """Set the label displayed on the plot for this axis.
+
+ The provided label can be temporarily replaced by the label of the
+ active curve if any.
+
+ :param str label: The axis label
+ """
+ self._defaultLabel = label
+ self._setCurrentLabel(label)
+ self._getPlot()._setDirtyPlot()
+
+ def _setCurrentLabel(self, label):
+ """Define the label currently displayed.
+
+ If the label is None or empty the default label is used.
+
+ :param str label: Currently displayed label
+ """
+ if label is None or label == '':
+ label = self._defaultLabel
+ if label is None:
+ label = ''
+ self._currentLabel = label
+ self._internalSetCurrentLabel(label)
+
+ def getScale(self):
+ """Return the name of the scale used by this axis.
+
+ :rtype: str
+ """
+ return self._scale
+
+ def setScale(self, scale):
+ """Set the scale to be used by this axis.
+
+ :param str scale: Name of the scale ("log", or "linear")
+ """
+ assert(scale in self._SCALES)
+ if self._scale == scale:
+ return
+
+ # For the backward compatibility signal
+ emitLog = self._scale == self.LOGARITHMIC or scale == self.LOGARITHMIC
+
+ self._scale = scale
+
+ # TODO hackish way of forcing update of curves and images
+ plot = self._getPlot()
+ for item in plot.getItems():
+ item._updated()
+ plot._invalidateDataRange()
+
+ if scale == self.LOGARITHMIC:
+ self._internalSetLogarithmic(True)
+ elif scale == self.LINEAR:
+ self._internalSetLogarithmic(False)
+ else:
+ raise ValueError("Scale %s unsupported" % scale)
+
+ plot._forceResetZoom()
+
+ self.sigScaleChanged.emit(self._scale)
+ if emitLog:
+ self._sigLogarithmicChanged.emit(self._scale == self.LOGARITHMIC)
+
+ def _isLogarithmic(self):
+ """Return True if this axis scale is logarithmic, False if linear.
+
+ :rtype: bool
+ """
+ return self._scale == self.LOGARITHMIC
+
+ def _setLogarithmic(self, flag):
+ """Set the scale of this axes (either linear or logarithmic).
+
+ :param bool flag: True to use a logarithmic scale, False for linear.
+ """
+ flag = bool(flag)
+ self.setScale(self.LOGARITHMIC if flag else self.LINEAR)
+
+ def getTimeZone(self):
+ """Sets tzinfo that is used if this axis plots date times.
+
+ None means the datetimes are interpreted as local time.
+
+ :rtype: datetime.tzinfo of None.
+ """
+ raise NotImplementedError()
+
+ def setTimeZone(self, tz):
+ """Sets tzinfo that is used if this axis' tickMode is TIME_SERIES
+
+ The tz must be a descendant of the datetime.tzinfo class, "UTC" or None.
+ Use None to let the datetimes be interpreted as local time.
+ Use the string "UTC" to let the date datetimes be in UTC time.
+
+ :param tz: datetime.tzinfo, "UTC" or None.
+ """
+ raise NotImplementedError()
+
+ def getTickMode(self):
+ """Determines if axis ticks are number or datetimes.
+
+ :rtype: TickMode enum.
+ """
+ raise NotImplementedError()
+
+ def setTickMode(self, tickMode):
+ """Determines if axis ticks are number or datetimes.
+
+ :param TickMode tickMode: tick mode enum.
+ """
+ raise NotImplementedError()
+
+ def isAutoScale(self):
+ """Return True if axis is automatically adjusting its limits.
+
+ :rtype: bool
+ """
+ return self._isAutoScale
+
+ def setAutoScale(self, flag=True):
+ """Set the axis limits adjusting behavior of :meth:`resetZoom`.
+
+ :param bool flag: True to resize limits automatically,
+ False to disable it.
+ """
+ self._isAutoScale = bool(flag)
+ self.sigAutoScaleChanged.emit(self._isAutoScale)
+
+ def _setLimitsConstraints(self, minPos=None, maxPos=None):
+ raise NotImplementedError()
+
+ def setLimitsConstraints(self, minPos=None, maxPos=None):
+ """
+ Set a constraint on the position of the axes.
+
+ :param float minPos: Minimum allowed axis value.
+ :param float maxPos: Maximum allowed axis value.
+ :return: True if the constaints was updated
+ :rtype: bool
+ """
+ updated = self._setLimitsConstraints(minPos, maxPos)
+ if updated:
+ plot = self._getPlot()
+ xMin, xMax = plot.getXAxis().getLimits()
+ yMin, yMax = plot.getYAxis().getLimits()
+ y2Min, y2Max = plot.getYAxis('right').getLimits()
+ plot.setLimits(xMin, xMax, yMin, yMax, y2Min, y2Max)
+ return updated
+
+ def _setRangeConstraints(self, minRange=None, maxRange=None):
+ raise NotImplementedError()
+
+ def setRangeConstraints(self, minRange=None, maxRange=None):
+ """
+ Set a constraint on the position of the axes.
+
+ :param float minRange: Minimum allowed left-to-right span across the
+ view
+ :param float maxRange: Maximum allowed left-to-right span across the
+ view
+ :return: True if the constaints was updated
+ :rtype: bool
+ """
+ updated = self._setRangeConstraints(minRange, maxRange)
+ if updated:
+ plot = self._getPlot()
+ xMin, xMax = plot.getXAxis().getLimits()
+ yMin, yMax = plot.getYAxis().getLimits()
+ y2Min, y2Max = plot.getYAxis('right').getLimits()
+ plot.setLimits(xMin, xMax, yMin, yMax, y2Min, y2Max)
+ return updated
+
+
+class XAxis(Axis):
+ """Axis class defining primitives for the X axis"""
+
+ # TODO With some changes on the backend, it will be able to remove all this
+ # specialised implementations (prefixel by '_internal')
+
+ def getTimeZone(self):
+ return self._getBackend().getXAxisTimeZone()
+
+ def setTimeZone(self, tz):
+ if isinstance(tz, str) and tz.upper() == "UTC":
+ tz = dateutil.tz.tzutc()
+ elif not(tz is None or isinstance(tz, dt.tzinfo)):
+ raise TypeError("tz must be a dt.tzinfo object, None or 'UTC'.")
+
+ self._getBackend().setXAxisTimeZone(tz)
+ self._getPlot()._setDirtyPlot()
+
+ def getTickMode(self):
+ if self._getBackend().isXAxisTimeSeries():
+ return TickMode.TIME_SERIES
+ else:
+ return TickMode.DEFAULT
+
+ def setTickMode(self, tickMode):
+ if tickMode == TickMode.DEFAULT:
+ self._getBackend().setXAxisTimeSeries(False)
+ elif tickMode == TickMode.TIME_SERIES:
+ self._getBackend().setXAxisTimeSeries(True)
+ else:
+ raise ValueError("Unexpected TickMode: {}".format(tickMode))
+
+ def _internalSetCurrentLabel(self, label):
+ self._getBackend().setGraphXLabel(label)
+
+ def _internalGetLimits(self):
+ return self._getBackend().getGraphXLimits()
+
+ def _internalSetLimits(self, xmin, xmax):
+ self._getBackend().setGraphXLimits(xmin, xmax)
+
+ def _internalSetLogarithmic(self, flag):
+ self._getBackend().setXAxisLogarithmic(flag)
+
+ def _setLimitsConstraints(self, minPos=None, maxPos=None):
+ constrains = self._getPlot()._getViewConstraints()
+ updated = constrains.update(xMin=minPos, xMax=maxPos)
+ return updated
+
+ def _setRangeConstraints(self, minRange=None, maxRange=None):
+ constrains = self._getPlot()._getViewConstraints()
+ updated = constrains.update(minXRange=minRange, maxXRange=maxRange)
+ return updated
+
+
+class YAxis(Axis):
+ """Axis class defining primitives for the Y axis"""
+
+ # TODO With some changes on the backend, it will be able to remove all this
+ # specialised implementations (prefixel by '_internal')
+
+ def _internalSetCurrentLabel(self, label):
+ self._getBackend().setGraphYLabel(label, axis='left')
+
+ def _internalGetLimits(self):
+ return self._getBackend().getGraphYLimits(axis='left')
+
+ def _internalSetLimits(self, ymin, ymax):
+ self._getBackend().setGraphYLimits(ymin, ymax, axis='left')
+
+ def _internalSetLogarithmic(self, flag):
+ self._getBackend().setYAxisLogarithmic(flag)
+
+ def setInverted(self, flag=True):
+ """Set the axis orientation.
+
+ This is only available for the Y axis.
+
+ :param bool flag: True for Y axis going from top to bottom,
+ False for Y axis going from bottom to top
+ """
+ flag = bool(flag)
+ if self.isInverted() == flag:
+ return
+ self._getBackend().setYAxisInverted(flag)
+ self._getPlot()._setDirtyPlot()
+ self.sigInvertedChanged.emit(flag)
+
+ def isInverted(self):
+ """Return True if the axis is inverted (top to bottom for the y-axis),
+ False otherwise. It is always False for the X axis.
+
+ :rtype: bool
+ """
+ return self._getBackend().isYAxisInverted()
+
+ def _setLimitsConstraints(self, minPos=None, maxPos=None):
+ constrains = self._getPlot()._getViewConstraints()
+ updated = constrains.update(yMin=minPos, yMax=maxPos)
+ return updated
+
+ def _setRangeConstraints(self, minRange=None, maxRange=None):
+ constrains = self._getPlot()._getViewConstraints()
+ updated = constrains.update(minYRange=minRange, maxYRange=maxRange)
+ return updated
+
+
+class YRightAxis(Axis):
+ """Proxy axis for the secondary Y axes. It manages it own label and limit
+ but share the some state like scale and direction with the main axis."""
+
+ # TODO With some changes on the backend, it will be able to remove all this
+ # specialised implementations (prefixel by '_internal')
+
+ def __init__(self, plot, mainAxis):
+ """Constructor
+
+ :param silx.gui.plot.PlotWidget.PlotWidget plot: Parent plot of this
+ axis
+ :param Axis mainAxis: Axis which sharing state with this axis
+ """
+ Axis.__init__(self, plot)
+ self.__mainAxis = mainAxis
+
+ @property
+ def sigInvertedChanged(self):
+ """Signal emitted when axis orientation has changed"""
+ return self.__mainAxis.sigInvertedChanged
+
+ @property
+ def sigScaleChanged(self):
+ """Signal emitted when axis scale has changed"""
+ return self.__mainAxis.sigScaleChanged
+
+ @property
+ def _sigLogarithmicChanged(self):
+ """Signal emitted when axis scale has changed to or from logarithmic"""
+ return self.__mainAxis._sigLogarithmicChanged
+
+ @property
+ def sigAutoScaleChanged(self):
+ """Signal emitted when axis autoscale has changed"""
+ return self.__mainAxis.sigAutoScaleChanged
+
+ def _internalSetCurrentLabel(self, label):
+ self._getBackend().setGraphYLabel(label, axis='right')
+
+ def _internalGetLimits(self):
+ return self._getBackend().getGraphYLimits(axis='right')
+
+ def _internalSetLimits(self, ymin, ymax):
+ self._getBackend().setGraphYLimits(ymin, ymax, axis='right')
+
+ def setInverted(self, flag=True):
+ """Set the Y axis orientation.
+
+ :param bool flag: True for Y axis going from top to bottom,
+ False for Y axis going from bottom to top
+ """
+ return self.__mainAxis.setInverted(flag)
+
+ def isInverted(self):
+ """Return True if Y axis goes from top to bottom, False otherwise."""
+ return self.__mainAxis.isInverted()
+
+ def getScale(self):
+ """Return the name of the scale used by this axis.
+
+ :rtype: str
+ """
+ return self.__mainAxis.getScale()
+
+ def setScale(self, scale):
+ """Set the scale to be used by this axis.
+
+ :param str scale: Name of the scale ("log", or "linear")
+ """
+ self.__mainAxis.setScale(scale)
+
+ def _isLogarithmic(self):
+ """Return True if Y axis scale is logarithmic, False if linear."""
+ return self.__mainAxis._isLogarithmic()
+
+ def _setLogarithmic(self, flag):
+ """Set the Y axes scale (either linear or logarithmic).
+
+ :param bool flag: True to use a logarithmic scale, False for linear.
+ """
+ return self.__mainAxis._setLogarithmic(flag)
+
+ def isAutoScale(self):
+ """Return True if Y axes are automatically adjusting its limits."""
+ return self.__mainAxis.isAutoScale()
+
+ def setAutoScale(self, flag=True):
+ """Set the Y axis limits adjusting behavior of :meth:`PlotWidget.resetZoom`.
+
+ :param bool flag: True to resize limits automatically,
+ False to disable it.
+ """
+ return self.__mainAxis.setAutoScale(flag)
diff --git a/src/silx/gui/plot/items/complex.py b/src/silx/gui/plot/items/complex.py
new file mode 100644
index 0000000..abb64ad
--- /dev/null
+++ b/src/silx/gui/plot/items/complex.py
@@ -0,0 +1,386 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides the :class:`ImageComplexData` of the :class:`Plot`.
+"""
+
+from __future__ import absolute_import
+
+__authors__ = ["Vincent Favre-Nicolin", "T. Vincent"]
+__license__ = "MIT"
+__date__ = "14/06/2018"
+
+
+import logging
+
+import numpy
+
+from ....utils.proxy import docstring
+from ....utils.deprecation import deprecated
+from ...colors import Colormap
+from .core import ColormapMixIn, ComplexMixIn, ItemChangedType
+from .image import ImageBase
+
+
+_logger = logging.getLogger(__name__)
+
+
+# Complex colormap functions
+
+def _phase2rgb(colormap, data):
+ """Creates RGBA image with colour-coded phase.
+
+ :param Colormap colormap: The colormap to use
+ :param numpy.ndarray data: The data to convert
+ :return: Array of RGBA colors
+ :rtype: numpy.ndarray
+ """
+ if data.size == 0:
+ return numpy.zeros((0, 0, 4), dtype=numpy.uint8)
+
+ phase = numpy.angle(data)
+ return colormap.applyToData(phase)
+
+
+def _complex2rgbalog(phaseColormap, data, amin=0., dlogs=2, smax=None):
+ """Returns RGBA colors: colour-coded phases and log10(amplitude) in alpha.
+
+ :param Colormap phaseColormap: Colormap to use for the phase
+ :param numpy.ndarray data: the complex data array to convert to RGBA
+ :param float amin: the minimum value for the alpha channel
+ :param float dlogs: amplitude range displayed, in log10 units
+ :param float smax:
+ if specified, all values above max will be displayed with an alpha=1
+ """
+ if data.size == 0:
+ return numpy.zeros((0, 0, 4), dtype=numpy.uint8)
+
+ rgba = _phase2rgb(phaseColormap, data)
+ sabs = numpy.absolute(data)
+ if smax is not None:
+ sabs[sabs > smax] = smax
+ a = numpy.log10(sabs + 1e-20)
+ a -= a.max() - dlogs # display dlogs orders of magnitude
+ rgba[..., 3] = 255 * (amin + a / dlogs * (1 - amin) * (a > 0))
+ return rgba
+
+
+def _complex2rgbalin(phaseColormap, data, gamma=1.0, smax=None):
+ """Returns RGBA colors: colour-coded phase and linear amplitude in alpha.
+
+ :param Colormap phaseColormap: Colormap to use for the phase
+ :param numpy.ndarray data:
+ :param float gamma: Optional exponent gamma applied to the amplitude
+ :param float smax:
+ """
+ if data.size == 0:
+ return numpy.zeros((0, 0, 4), dtype=numpy.uint8)
+
+ rgba = _phase2rgb(phaseColormap, data)
+ a = numpy.absolute(data)
+ if smax is not None:
+ a[a > smax] = smax
+ a /= a.max()
+ rgba[..., 3] = 255 * a**gamma
+ return rgba
+
+
+class ImageComplexData(ImageBase, ColormapMixIn, ComplexMixIn):
+ """Specific plot item to force colormap when using complex colormap.
+
+ This is returning the specific colormap when displaying
+ colored phase + amplitude.
+ """
+
+ _SUPPORTED_COMPLEX_MODES = (
+ ComplexMixIn.ComplexMode.ABSOLUTE,
+ ComplexMixIn.ComplexMode.PHASE,
+ ComplexMixIn.ComplexMode.REAL,
+ ComplexMixIn.ComplexMode.IMAGINARY,
+ ComplexMixIn.ComplexMode.AMPLITUDE_PHASE,
+ ComplexMixIn.ComplexMode.LOG10_AMPLITUDE_PHASE,
+ ComplexMixIn.ComplexMode.SQUARE_AMPLITUDE)
+ """Overrides supported ComplexMode"""
+
+ def __init__(self):
+ ImageBase.__init__(self, numpy.zeros((0, 0), dtype=numpy.complex64))
+ ColormapMixIn.__init__(self)
+ ComplexMixIn.__init__(self)
+ self._dataByModesCache = {}
+ self._amplitudeRangeInfo = None, 2
+
+ # Use default from ColormapMixIn
+ colormap = super(ImageComplexData, self).getColormap()
+
+ phaseColormap = Colormap(
+ name='hsv',
+ vmin=-numpy.pi,
+ vmax=numpy.pi)
+
+ self._colormaps = { # Default colormaps for all modes
+ self.ComplexMode.ABSOLUTE: colormap,
+ self.ComplexMode.PHASE: phaseColormap,
+ self.ComplexMode.REAL: colormap,
+ self.ComplexMode.IMAGINARY: colormap,
+ self.ComplexMode.AMPLITUDE_PHASE: phaseColormap,
+ self.ComplexMode.LOG10_AMPLITUDE_PHASE: phaseColormap,
+ self.ComplexMode.SQUARE_AMPLITUDE: colormap,
+ }
+
+ def _addBackendRenderer(self, backend):
+ """Update backend renderer"""
+ plot = self.getPlot()
+ assert plot is not None
+ if not self._isPlotLinear(plot):
+ # Do not render with non linear scales
+ return None
+
+ mode = self.getComplexMode()
+ if mode in (self.ComplexMode.AMPLITUDE_PHASE,
+ self.ComplexMode.LOG10_AMPLITUDE_PHASE):
+ # For those modes, compute RGBA image here
+ colormap = None
+ data = self.getRgbaImageData(copy=False)
+ else:
+ colormap = self.getColormap()
+ if colormap.isAutoscale():
+ # Avoid backend to compute autoscale: use item cache
+ colormap = colormap.copy()
+ colormap.setVRange(*colormap.getColormapRange(self))
+
+ data = self.getData(copy=False)
+
+ if data.size == 0:
+ return None # No data to display
+
+ return backend.addImage(data,
+ origin=self.getOrigin(),
+ scale=self.getScale(),
+ colormap=colormap,
+ alpha=self.getAlpha())
+
+ @docstring(ComplexMixIn)
+ def setComplexMode(self, mode):
+ changed = super(ImageComplexData, self).setComplexMode(mode)
+ if changed:
+ self._valueDataChanged()
+
+ # Backward compatibility
+ self._updated(ItemChangedType.VISUALIZATION_MODE)
+
+ # Update ColormapMixIn colormap
+ colormap = self._colormaps[self.getComplexMode()]
+ if colormap is not super(ImageComplexData, self).getColormap():
+ super(ImageComplexData, self).setColormap(colormap)
+
+ # Send data updated as value returned by getData has changed
+ self._updated(ItemChangedType.DATA)
+ return changed
+
+ def _setAmplitudeRangeInfo(self, max_=None, delta=2):
+ """Set the amplitude range to display for 'log10_amplitude_phase' mode.
+
+ :param max_: Max of the amplitude range.
+ If None it autoscales to data max.
+ :param float delta: Delta range in log10 to display
+ """
+ self._amplitudeRangeInfo = max_, float(delta)
+ self._updated(ItemChangedType.VISUALIZATION_MODE)
+
+ def _getAmplitudeRangeInfo(self):
+ """Returns the amplitude range to use for 'log10_amplitude_phase' mode.
+
+ :return: (max, delta), if max is None, then it autoscales to data max
+ :rtype: 2-tuple"""
+ return self._amplitudeRangeInfo
+
+ def setColormap(self, colormap, mode=None):
+ """Set the colormap for this specific mode.
+
+ :param ~silx.gui.colors.Colormap colormap: The colormap
+ :param Union[ComplexMode,str] mode:
+ If specified, set the colormap of this specific mode.
+ Default: current mode.
+ """
+ if mode is None:
+ mode = self.getComplexMode()
+ else:
+ mode = self.ComplexMode.from_value(mode)
+
+ self._colormaps[mode] = colormap
+ if mode is self.getComplexMode():
+ super(ImageComplexData, self).setColormap(colormap)
+ else:
+ self._updated(ItemChangedType.COLORMAP)
+
+ def getColormap(self, mode=None):
+ """Get the colormap for the (current) mode.
+
+ :param Union[ComplexMode,str] mode:
+ If specified, get the colormap of this specific mode.
+ Default: current mode.
+ :rtype: ~silx.gui.colors.Colormap
+ """
+ if mode is None:
+ mode = self.getComplexMode()
+ else:
+ mode = self.ComplexMode.from_value(mode)
+
+ return self._colormaps[mode]
+
+ def setData(self, data, copy=True):
+ """"Set the image complex data
+
+ :param numpy.ndarray data: 2D array of complex with 2 dimensions (h, w)
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ """
+ data = numpy.array(data, copy=copy)
+ assert data.ndim == 2
+ if not numpy.issubdtype(data.dtype, numpy.complexfloating):
+ _logger.warning(
+ 'Image is not complex, converting it to complex to plot it.')
+ data = numpy.array(data, dtype=numpy.complex64)
+
+ # Compute current mode data and set colormap data
+ mode = self.getComplexMode()
+ dataForMode = self.__convertComplexData(data, self.getComplexMode())
+ self._dataByModesCache = {mode: dataForMode}
+
+ super().setData(data)
+
+ def _updated(self, event=None, checkVisibility=True):
+ # Synchronizes colormapped data if changed
+ # ItemChangedType.COMPLEX_MODE triggers ItemChangedType.DATA
+ # No need to handle it twice.
+ if event in (ItemChangedType.DATA, ItemChangedType.MASK):
+ # Color-mapped data is NOT the `getValueData` for some modes
+ if self.getComplexMode() in (
+ self.ComplexMode.AMPLITUDE_PHASE,
+ self.ComplexMode.LOG10_AMPLITUDE_PHASE):
+ data = self.getData(copy=False, mode=self.ComplexMode.PHASE)
+ mask = self.getMaskData(copy=False)
+ if mask is not None:
+ data = numpy.copy(data)
+ data[mask != 0] = numpy.nan
+ else:
+ data = self.getValueData(copy=False)
+ self._setColormappedData(data, copy=False)
+ super()._updated(event=event, checkVisibility=checkVisibility)
+
+ def getComplexData(self, copy=True):
+ """Returns the image complex data
+
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :rtype: numpy.ndarray of complex
+ """
+ return super().getData(copy=copy)
+
+ def __convertComplexData(self, data, mode):
+ """Convert complex data to given mode.
+
+ :param numpy.ndarray data:
+ :param Union[ComplexMode,str] mode:
+ :rtype: numpy.ndarray of float
+ """
+ if mode is self.ComplexMode.PHASE:
+ return numpy.angle(data)
+ elif mode is self.ComplexMode.REAL:
+ return numpy.real(data)
+ elif mode is self.ComplexMode.IMAGINARY:
+ return numpy.imag(data)
+ elif mode in (self.ComplexMode.ABSOLUTE,
+ self.ComplexMode.LOG10_AMPLITUDE_PHASE,
+ self.ComplexMode.AMPLITUDE_PHASE):
+ return numpy.absolute(data)
+ elif mode is self.ComplexMode.SQUARE_AMPLITUDE:
+ return numpy.absolute(data) ** 2
+ else:
+ _logger.error(
+ 'Unsupported conversion mode: %s, fallback to absolute',
+ str(mode))
+ return numpy.absolute(data)
+
+ def getData(self, copy=True, mode=None):
+ """Returns the image data corresponding to (current) mode.
+
+ The returned data is always floats, to get the complex data, use
+ :meth:`getComplexData`.
+
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :param Union[ComplexMode,str] mode:
+ If specified, get data corresponding to the mode.
+ Default: Current mode.
+ :rtype: numpy.ndarray of float
+ """
+ if mode is None:
+ mode = self.getComplexMode()
+ else:
+ mode = self.ComplexMode.from_value(mode)
+
+ if mode not in self._dataByModesCache:
+ self._dataByModesCache[mode] = self.__convertComplexData(
+ self.getComplexData(copy=False), mode)
+
+ return numpy.array(self._dataByModesCache[mode], copy=copy)
+
+ def getRgbaImageData(self, copy=True, mode=None):
+ """Get the displayed RGB(A) image for (current) mode
+
+ :param bool copy: Ignored for this class
+ :param Union[ComplexMode,str] mode:
+ If specified, get data corresponding to the mode.
+ Default: Current mode.
+ :rtype: numpy.ndarray of uint8 of shape (height, width, 4)
+ """
+ if mode is None:
+ mode = self.getComplexMode()
+ else:
+ mode = self.ComplexMode.from_value(mode)
+
+ colormap = self.getColormap(mode=mode)
+ if mode is self.ComplexMode.AMPLITUDE_PHASE:
+ data = self.getComplexData(copy=False)
+ return _complex2rgbalin(colormap, data)
+ elif mode is self.ComplexMode.LOG10_AMPLITUDE_PHASE:
+ data = self.getComplexData(copy=False)
+ max_, delta = self._getAmplitudeRangeInfo()
+ return _complex2rgbalog(colormap, data, dlogs=delta, smax=max_)
+ else:
+ data = self.getData(copy=False, mode=mode)
+ return colormap.applyToData(data)
+
+ # Backward compatibility
+
+ Mode = ComplexMixIn.ComplexMode
+
+ @deprecated(replacement='setComplexMode', since_version='0.11.0')
+ def setVisualizationMode(self, mode):
+ return self.setComplexMode(mode)
+
+ @deprecated(replacement='getComplexMode', since_version='0.11.0')
+ def getVisualizationMode(self):
+ return self.getComplexMode()
diff --git a/src/silx/gui/plot/items/core.py b/src/silx/gui/plot/items/core.py
new file mode 100644
index 0000000..fa3b8cf
--- /dev/null
+++ b/src/silx/gui/plot/items/core.py
@@ -0,0 +1,1733 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides the base class for items of the :class:`Plot`.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "08/12/2020"
+
+import collections
+try:
+ from collections import abc
+except ImportError: # Python2 support
+ import collections as abc
+from copy import deepcopy
+import logging
+import enum
+from typing import Optional, Tuple
+import warnings
+import weakref
+
+import numpy
+
+from ....utils.deprecation import deprecated
+from ....utils.proxy import docstring
+from ....utils.enum import Enum as _Enum
+from ....math.combo import min_max
+from ... import qt
+from ... import colors
+from ...colors import Colormap
+from ._pick import PickingResult
+
+from silx import config
+
+_logger = logging.getLogger(__name__)
+
+
+@enum.unique
+class ItemChangedType(enum.Enum):
+ """Type of modification provided by :attr:`Item.sigItemChanged` signal."""
+ # Private setters and setInfo are not emitting sigItemChanged signal.
+ # Signals to consider:
+ # COLORMAP_SET emitted when setColormap is called but not forward colormap object signal
+ # CURRENT_COLOR_CHANGED emitted current color changed because highlight changed,
+ # highlighted color changed or color changed depending on hightlight state.
+
+ VISIBLE = 'visibleChanged'
+ """Item's visibility changed flag."""
+
+ ZVALUE = 'zValueChanged'
+ """Item's Z value changed flag."""
+
+ COLORMAP = 'colormapChanged' # Emitted when set + forward events from the colormap object
+ """Item's colormap changed flag.
+
+ This is emitted both when setting a new colormap and
+ when the current colormap object is updated.
+ """
+
+ SYMBOL = 'symbolChanged'
+ """Item's symbol changed flag."""
+
+ SYMBOL_SIZE = 'symbolSizeChanged'
+ """Item's symbol size changed flag."""
+
+ LINE_WIDTH = 'lineWidthChanged'
+ """Item's line width changed flag."""
+
+ LINE_STYLE = 'lineStyleChanged'
+ """Item's line style changed flag."""
+
+ COLOR = 'colorChanged'
+ """Item's color changed flag."""
+
+ LINE_BG_COLOR = 'lineBgColorChanged'
+ """Item's line background color changed flag."""
+
+ YAXIS = 'yAxisChanged'
+ """Item's Y axis binding changed flag."""
+
+ FILL = 'fillChanged'
+ """Item's fill changed flag."""
+
+ ALPHA = 'alphaChanged'
+ """Item's transparency alpha changed flag."""
+
+ DATA = 'dataChanged'
+ """Item's data changed flag"""
+
+ MASK = 'maskChanged'
+ """Item's mask changed flag"""
+
+ HIGHLIGHTED = 'highlightedChanged'
+ """Item's highlight state changed flag."""
+
+ HIGHLIGHTED_COLOR = 'highlightedColorChanged'
+ """Deprecated, use HIGHLIGHTED_STYLE instead."""
+
+ HIGHLIGHTED_STYLE = 'highlightedStyleChanged'
+ """Item's highlighted style changed flag."""
+
+ SCALE = 'scaleChanged'
+ """Item's scale changed flag."""
+
+ TEXT = 'textChanged'
+ """Item's text changed flag."""
+
+ POSITION = 'positionChanged'
+ """Item's position changed flag.
+
+ This is emitted when a marker position changed and
+ when an image origin changed.
+ """
+
+ OVERLAY = 'overlayChanged'
+ """Item's overlay state changed flag."""
+
+ VISUALIZATION_MODE = 'visualizationModeChanged'
+ """Item's visualization mode changed flag."""
+
+ COMPLEX_MODE = 'complexModeChanged'
+ """Item's complex data visualization mode changed flag."""
+
+ NAME = 'nameChanged'
+ """Item's name changed flag."""
+
+ EDITABLE = 'editableChanged'
+ """Item's editable state changed flags."""
+
+ SELECTABLE = 'selectableChanged'
+ """Item's selectable state changed flags."""
+
+
+class Item(qt.QObject):
+ """Description of an item of the plot"""
+
+ _DEFAULT_Z_LAYER = 0
+ """Default layer for overlay rendering"""
+
+ _DEFAULT_SELECTABLE = False
+ """Default selectable state of items"""
+
+ sigItemChanged = qt.Signal(object)
+ """Signal emitted when the item has changed.
+
+ It provides a flag describing which property of the item has changed.
+ See :class:`ItemChangedType` for flags description.
+ """
+
+ _sigVisibleBoundsChanged = qt.Signal()
+ """Signal emitted when the visible extent of the item in the plot has changed.
+
+ This signal is emitted only if visible extent tracking is enabled
+ (see :meth:`_setVisibleBoundsTracking`).
+ """
+
+ def __init__(self):
+ qt.QObject.__init__(self)
+ self._dirty = True
+ self._plotRef = None
+ self._visible = True
+ self._selectable = self._DEFAULT_SELECTABLE
+ self._z = self._DEFAULT_Z_LAYER
+ self._info = None
+ self._xlabel = None
+ self._ylabel = None
+ self.__name = ''
+
+ self.__visibleBoundsTracking = False
+ self.__previousVisibleBounds = None
+
+ self._backendRenderer = None
+
+ def getPlot(self):
+ """Returns the ~silx.gui.plot.PlotWidget this item belongs to.
+
+ :rtype: Union[~silx.gui.plot.PlotWidget,None]
+ """
+ return None if self._plotRef is None else self._plotRef()
+
+ def _setPlot(self, plot):
+ """Set the plot this item belongs to.
+
+ WARNING: This should only be called from the Plot.
+
+ :param Union[~silx.gui.plot.PlotWidget,None] plot: The Plot instance.
+ """
+ if plot is not None and self._plotRef is not None:
+ raise RuntimeError('Trying to add a node at two places.')
+ self.__disconnectFromPlotWidget()
+ self._plotRef = None if plot is None else weakref.ref(plot)
+ self.__connectToPlotWidget()
+ self._updated()
+
+ def getBounds(self): # TODO return a Bounds object rather than a tuple
+ """Returns the bounding box of this item in data coordinates
+
+ :returns: (xmin, xmax, ymin, ymax) or None
+ :rtype: 4-tuple of float or None
+ """
+ return self._getBounds()
+
+ def _getBounds(self):
+ """:meth:`getBounds` implementation to override by sub-class"""
+ return None
+
+ def isVisible(self):
+ """True if item is visible, False otherwise
+
+ :rtype: bool
+ """
+ return self._visible
+
+ def setVisible(self, visible):
+ """Set visibility of item.
+
+ :param bool visible: True to display it, False otherwise
+ """
+ visible = bool(visible)
+ if visible != self._visible:
+ self._visible = visible
+ # When visibility has changed, always mark as dirty
+ self._updated(ItemChangedType.VISIBLE,
+ checkVisibility=False)
+
+ def isOverlay(self):
+ """Return true if item is drawn as an overlay.
+
+ :rtype: bool
+ """
+ return False
+
+ def getName(self):
+ """Returns the name of the item which is used as legend.
+
+ :rtype: str
+ """
+ return self.__name
+
+ def setName(self, name):
+ """Set the name of the item which is used as legend.
+
+ :param str name: New name of the item
+ :raises RuntimeError: If item belongs to a PlotWidget.
+ """
+ name = str(name)
+ if self.__name != name:
+ if self.getPlot() is not None:
+ raise RuntimeError(
+ "Cannot change name while item is in a PlotWidget")
+
+ self.__name = name
+ self._updated(ItemChangedType.NAME)
+
+ def getLegend(self): # Replaced by getName for API consistency
+ return self.getName()
+
+ @deprecated(replacement='setName', since_version='0.13')
+ def _setLegend(self, legend):
+ legend = str(legend) if legend is not None else ''
+ self.setName(legend)
+
+ def isSelectable(self):
+ """Returns true if item is selectable (bool)"""
+ return self._selectable
+
+ def _setSelectable(self, selectable): # TODO support update
+ """Set whether item is selectable or not.
+
+ This is private for now as change is not handled.
+
+ :param bool selectable: True to make item selectable
+ """
+ self._selectable = bool(selectable)
+
+ def getZValue(self):
+ """Returns the layer on which to draw this item (int)"""
+ return self._z
+
+ def setZValue(self, z):
+ z = int(z) if z is not None else self._DEFAULT_Z_LAYER
+ if z != self._z:
+ self._z = z
+ self._updated(ItemChangedType.ZVALUE)
+
+ def getInfo(self, copy=True):
+ """Returns the info associated to this item
+
+ :param bool copy: True to get a deepcopy, False otherwise.
+ """
+ return deepcopy(self._info) if copy else self._info
+
+ def setInfo(self, info, copy=True):
+ if copy:
+ info = deepcopy(info)
+ self._info = info
+
+ def getVisibleBounds(self) -> Optional[Tuple[float, float, float, float]]:
+ """Returns visible bounds of the item bounding box in the plot area.
+
+ :returns:
+ (xmin, xmax, ymin, ymax) in data coordinates of the visible area or
+ None if item is not visible in the plot area.
+ :rtype: Union[List[float],None]
+ """
+ plot = self.getPlot()
+ bounds = self.getBounds()
+ if plot is None or bounds is None or not self.isVisible():
+ return None
+
+ xmin, xmax = numpy.clip(bounds[:2], *plot.getXAxis().getLimits())
+ ymin, ymax = numpy.clip(
+ bounds[2:], *plot.getYAxis(self.__getYAxis()).getLimits())
+
+ if xmin == xmax or ymin == ymax: # Outside the plot area
+ return None
+ else:
+ return xmin, xmax, ymin, ymax
+
+ def _isVisibleBoundsTracking(self) -> bool:
+ """Returns True if visible bounds changes are tracked.
+
+ When enabled, :attr:`_sigVisibleBoundsChanged` is emitted upon changes.
+ :rtype: bool
+ """
+ return self.__visibleBoundsTracking
+
+ def _setVisibleBoundsTracking(self, enable: bool) -> None:
+ """Set whether or not to track visible bounds changes.
+
+ :param bool enable:
+ """
+ if enable != self.__visibleBoundsTracking:
+ self.__disconnectFromPlotWidget()
+ self.__previousVisibleBounds = None
+ self.__visibleBoundsTracking = enable
+ self.__connectToPlotWidget()
+
+ def __getYAxis(self) -> str:
+ """Returns current Y axis ('left' or 'right')"""
+ return self.getYAxis() if isinstance(self, YAxisMixIn) else 'left'
+
+ def __connectToPlotWidget(self) -> None:
+ """Connect to PlotWidget signals and install event filter"""
+ if not self._isVisibleBoundsTracking():
+ return
+
+ plot = self.getPlot()
+ if plot is not None:
+ for axis in (plot.getXAxis(), plot.getYAxis(self.__getYAxis())):
+ axis.sigLimitsChanged.connect(self._visibleBoundsChanged)
+
+ plot.installEventFilter(self)
+
+ self._visibleBoundsChanged()
+
+ def __disconnectFromPlotWidget(self) -> None:
+ """Disconnect from PlotWidget signals and remove event filter"""
+ if not self._isVisibleBoundsTracking():
+ return
+
+ plot = self.getPlot()
+ if plot is not None:
+ for axis in (plot.getXAxis(), plot.getYAxis(self.__getYAxis())):
+ axis.sigLimitsChanged.disconnect(self._visibleBoundsChanged)
+
+ plot.removeEventFilter(self)
+
+ def _visibleBoundsChanged(self, *args) -> None:
+ """Check if visible extent actually changed and emit signal"""
+ if not self._isVisibleBoundsTracking():
+ return # No visible extent tracking
+
+ plot = self.getPlot()
+ if plot is None or not plot.isVisible():
+ return # No plot or plot not visible
+
+ extent = self.getVisibleBounds()
+ if extent != self.__previousVisibleBounds:
+ self.__previousVisibleBounds = extent
+ self._sigVisibleBoundsChanged.emit()
+
+ def eventFilter(self, watched, event):
+ """Event filter to handle PlotWidget show events"""
+ if watched is self.getPlot() and event.type() == qt.QEvent.Show:
+ self._visibleBoundsChanged()
+ return super().eventFilter(watched, event)
+
+ def _updated(self, event=None, checkVisibility=True):
+ """Mark the item as dirty (i.e., needing update).
+
+ This also triggers Plot.replot.
+
+ :param event: The event to send to :attr:`sigItemChanged` signal.
+ :param bool checkVisibility: True to only mark as dirty if visible,
+ False to always mark as dirty.
+ """
+ if not checkVisibility or self.isVisible():
+ if not self._dirty:
+ self._dirty = True
+ # TODO: send event instead of explicit call
+ plot = self.getPlot()
+ if plot is not None:
+ plot._itemRequiresUpdate(self)
+ if event is not None:
+ self.sigItemChanged.emit(event)
+
+ def _update(self, backend):
+ """Called by Plot to update the backend for this item.
+
+ This is meant to be called asynchronously from _updated.
+ This optimizes the number of call to _update.
+
+ :param backend: The backend to update
+ """
+ if self._dirty:
+ # Remove previous renderer from backend if any
+ self._removeBackendRenderer(backend)
+
+ # If not visible, do not add renderer to backend
+ if self.isVisible():
+ self._backendRenderer = self._addBackendRenderer(backend)
+
+ self._dirty = False
+
+ def _addBackendRenderer(self, backend):
+ """Override in subclass to add specific backend renderer.
+
+ :param BackendBase backend: The backend to update
+ :return: The renderer handle to store or None if no renderer in backend
+ """
+ return None
+
+ def _removeBackendRenderer(self, backend):
+ """Override in subclass to remove specific backend renderer.
+
+ :param BackendBase backend: The backend to update
+ """
+ if self._backendRenderer is not None:
+ backend.remove(self._backendRenderer)
+ self._backendRenderer = None
+
+ def pick(self, x, y):
+ """Run picking test on this item
+
+ :param float x: The x pixel coord where to pick.
+ :param float y: The y pixel coord where to pick.
+ :return: None if not picked, else the picked position information
+ :rtype: Union[None,PickingResult]
+ """
+ if not self.isVisible() or self._backendRenderer is None:
+ return None
+ plot = self.getPlot()
+ if plot is None:
+ return None
+
+ indices = plot._backend.pickItem(x, y, self._backendRenderer)
+ if indices is None:
+ return None
+ else:
+ return PickingResult(self, indices)
+
+
+class DataItem(Item):
+ """Item with a data extent in the plot"""
+
+ def _boundsChanged(self, checkVisibility: bool=True) -> None:
+ """Call this method in subclass when data bounds has changed.
+
+ :param bool checkVisibility:
+ """
+ if not checkVisibility or self.isVisible():
+ self._visibleBoundsChanged()
+
+ # TODO hackish data range implementation
+ plot = self.getPlot()
+ if plot is not None:
+ plot._invalidateDataRange()
+
+ @docstring(Item)
+ def setVisible(self, visible: bool):
+ if visible != self.isVisible():
+ self._boundsChanged(checkVisibility=False)
+ super().setVisible(visible)
+
+# Mix-in classes ##############################################################
+
+
+class ItemMixInBase(object):
+ """Base class for Item mix-in"""
+
+ def _updated(self, event=None, checkVisibility=True):
+ """This is implemented in :class:`Item`.
+
+ Mark the item as dirty (i.e., needing update).
+ This also triggers Plot.replot.
+
+ :param event: The event to send to :attr:`sigItemChanged` signal.
+ :param bool checkVisibility: True to only mark as dirty if visible,
+ False to always mark as dirty.
+ """
+ raise RuntimeError(
+ "Issue with Mix-In class inheritance order")
+
+
+class LabelsMixIn(ItemMixInBase):
+ """Mix-in class for items with x and y labels
+
+ Setters are private, otherwise it needs to check the plot
+ current active curve and access the internal current labels.
+ """
+
+ def __init__(self):
+ self._xlabel = None
+ self._ylabel = None
+
+ def getXLabel(self):
+ """Return the X axis label associated to this curve
+
+ :rtype: str or None
+ """
+ return self._xlabel
+
+ def _setXLabel(self, label):
+ """Set the X axis label associated with this curve
+
+ :param str label: The X axis label
+ """
+ self._xlabel = str(label)
+
+ def getYLabel(self):
+ """Return the Y axis label associated to this curve
+
+ :rtype: str or None
+ """
+ return self._ylabel
+
+ def _setYLabel(self, label):
+ """Set the Y axis label associated with this curve
+
+ :param str label: The Y axis label
+ """
+ self._ylabel = str(label)
+
+
+class DraggableMixIn(ItemMixInBase):
+ """Mix-in class for draggable items"""
+
+ def __init__(self):
+ self._draggable = False
+
+ def isDraggable(self):
+ """Returns true if image is draggable
+
+ :rtype: bool
+ """
+ return self._draggable
+
+ def _setDraggable(self, draggable): # TODO support update
+ """Set if image is draggable or not.
+
+ This is private for not as it does not support update.
+
+ :param bool draggable:
+ """
+ self._draggable = bool(draggable)
+
+ def drag(self, from_, to):
+ """Perform a drag of the item.
+
+ :param List[float] from_: (x, y) previous position in data coordinates
+ :param List[float] to: (x, y) current position in data coordinates
+ """
+ raise NotImplementedError("Must be implemented in subclass")
+
+
+class ColormapMixIn(ItemMixInBase):
+ """Mix-in class for items with colormap"""
+
+ def __init__(self):
+ self._colormap = Colormap()
+ self._colormap.sigChanged.connect(self._colormapChanged)
+ self.__data = None
+ self.__cacheColormapRange = {} # Store {normalization: range}
+
+ def getColormap(self):
+ """Return the used colormap"""
+ return self._colormap
+
+ def setColormap(self, colormap):
+ """Set the colormap of this item
+
+ :param silx.gui.colors.Colormap colormap: colormap description
+ """
+ if self._colormap is colormap:
+ return
+ if isinstance(colormap, dict):
+ colormap = Colormap._fromDict(colormap)
+
+ if self._colormap is not None:
+ self._colormap.sigChanged.disconnect(self._colormapChanged)
+ self._colormap = colormap
+ if self._colormap is not None:
+ self._colormap.sigChanged.connect(self._colormapChanged)
+ self._colormapChanged()
+
+ def _colormapChanged(self):
+ """Handle updates of the colormap"""
+ self._updated(ItemChangedType.COLORMAP)
+
+ def _setColormappedData(self, data, copy=True,
+ min_=None, minPositive=None, max_=None):
+ """Set the data used to compute the colormapped display.
+
+ It also resets the cache of data ranges.
+
+ This method MUST be called by inheriting classes when data is updated.
+
+ :param Union[None,numpy.ndarray] data:
+ :param Union[None,float] min_: Minimum value of the data
+ :param Union[None,float] minPositive:
+ Minimum of strictly positive values of the data
+ :param Union[None,float] max_: Maximum value of the data
+ """
+ self.__data = None if data is None else numpy.array(data, copy=copy)
+ self.__cacheColormapRange = {} # Reset cache
+
+ # Fill-up colormap range cache if values are provided
+ if max_ is not None and numpy.isfinite(max_):
+ if min_ is not None and numpy.isfinite(min_):
+ self.__cacheColormapRange[Colormap.LINEAR, Colormap.MINMAX] = min_, max_
+ if minPositive is not None and numpy.isfinite(minPositive):
+ self.__cacheColormapRange[Colormap.LOGARITHM, Colormap.MINMAX] = minPositive, max_
+
+ colormap = self.getColormap()
+ if None in (colormap.getVMin(), colormap.getVMax()):
+ self._colormapChanged()
+
+ def getColormappedData(self, copy=True):
+ """Returns the data used to compute the displayed colors
+
+ :param bool copy: True to get a copy,
+ False to get internal data (do not modify!).
+ :rtype: Union[None,numpy.ndarray]
+ """
+ if self.__data is None:
+ return None
+ else:
+ return numpy.array(self.__data, copy=copy)
+
+ def _getColormapAutoscaleRange(self, colormap=None):
+ """Returns the autoscale range for current data and colormap.
+
+ :param Union[None,~silx.gui.colors.Colormap] colormap:
+ The colormap for which to compute the autoscale range.
+ If None, the default, the colormap of the item is used
+ :return: (vmin, vmax) range (vmin and /or vmax might be `None`)
+ """
+ if colormap is None:
+ colormap = self.getColormap()
+
+ data = self.getColormappedData(copy=False)
+ if colormap is None or data is None:
+ return None, None
+
+ normalization = colormap.getNormalization()
+ autoscaleMode = colormap.getAutoscaleMode()
+ key = normalization, autoscaleMode
+ vRange = self.__cacheColormapRange.get(key, None)
+ if vRange is None:
+ vRange = colormap._computeAutoscaleRange(data)
+ self.__cacheColormapRange[key] = vRange
+ return vRange
+
+
+class SymbolMixIn(ItemMixInBase):
+ """Mix-in class for items with symbol type"""
+
+ _DEFAULT_SYMBOL = None
+ """Default marker of the item"""
+
+ _DEFAULT_SYMBOL_SIZE = config.DEFAULT_PLOT_SYMBOL_SIZE
+ """Default marker size of the item"""
+
+ _SUPPORTED_SYMBOLS = collections.OrderedDict((
+ ('o', 'Circle'),
+ ('d', 'Diamond'),
+ ('s', 'Square'),
+ ('+', 'Plus'),
+ ('x', 'Cross'),
+ ('.', 'Point'),
+ (',', 'Pixel'),
+ ('|', 'Vertical line'),
+ ('_', 'Horizontal line'),
+ ('tickleft', 'Tick left'),
+ ('tickright', 'Tick right'),
+ ('tickup', 'Tick up'),
+ ('tickdown', 'Tick down'),
+ ('caretleft', 'Caret left'),
+ ('caretright', 'Caret right'),
+ ('caretup', 'Caret up'),
+ ('caretdown', 'Caret down'),
+ (u'\u2665', 'Heart'),
+ ('', 'None')))
+ """Dict of supported symbols"""
+
+ def __init__(self):
+ if self._DEFAULT_SYMBOL is None: # Use default from config
+ self._symbol = config.DEFAULT_PLOT_SYMBOL
+ else:
+ self._symbol = self._DEFAULT_SYMBOL
+
+ if self._DEFAULT_SYMBOL_SIZE is None: # Use default from config
+ self._symbol_size = config.DEFAULT_PLOT_SYMBOL_SIZE
+ else:
+ self._symbol_size = self._DEFAULT_SYMBOL_SIZE
+
+ @classmethod
+ def getSupportedSymbols(cls):
+ """Returns the list of supported symbol names.
+
+ :rtype: tuple of str
+ """
+ return tuple(cls._SUPPORTED_SYMBOLS.keys())
+
+ @classmethod
+ def getSupportedSymbolNames(cls):
+ """Returns the list of supported symbol human-readable names.
+
+ :rtype: tuple of str
+ """
+ return tuple(cls._SUPPORTED_SYMBOLS.values())
+
+ def getSymbolName(self, symbol=None):
+ """Returns human-readable name for a symbol.
+
+ :param str symbol: The symbol from which to get the name.
+ Default: current symbol.
+ :rtype: str
+ :raise KeyError: if symbol is not in :meth:`getSupportedSymbols`.
+ """
+ if symbol is None:
+ symbol = self.getSymbol()
+ return self._SUPPORTED_SYMBOLS[symbol]
+
+ def getSymbol(self):
+ """Return the point marker type.
+
+ Marker type::
+
+ - 'o' circle
+ - '.' point
+ - ',' pixel
+ - '+' cross
+ - 'x' x-cross
+ - 'd' diamond
+ - 's' square
+
+ :rtype: str
+ """
+ return self._symbol
+
+ def setSymbol(self, symbol):
+ """Set the marker type
+
+ See :meth:`getSymbol`.
+
+ :param str symbol: Marker type or marker name
+ """
+ if symbol is None:
+ symbol = self._DEFAULT_SYMBOL
+
+ elif symbol not in self.getSupportedSymbols():
+ for symbolCode, name in self._SUPPORTED_SYMBOLS.items():
+ if name.lower() == symbol.lower():
+ symbol = symbolCode
+ break
+ else:
+ raise ValueError('Unsupported symbol %s' % str(symbol))
+
+ if symbol != self._symbol:
+ self._symbol = symbol
+ self._updated(ItemChangedType.SYMBOL)
+
+ def getSymbolSize(self):
+ """Return the point marker size in points.
+
+ :rtype: float
+ """
+ return self._symbol_size
+
+ def setSymbolSize(self, size):
+ """Set the point marker size in points.
+
+ See :meth:`getSymbolSize`.
+
+ :param str symbol: Marker type
+ """
+ if size is None:
+ size = self._DEFAULT_SYMBOL_SIZE
+ if size != self._symbol_size:
+ self._symbol_size = size
+ self._updated(ItemChangedType.SYMBOL_SIZE)
+
+
+class LineMixIn(ItemMixInBase):
+ """Mix-in class for item with line"""
+
+ _DEFAULT_LINEWIDTH = 1.
+ """Default line width"""
+
+ _DEFAULT_LINESTYLE = '-'
+ """Default line style"""
+
+ _SUPPORTED_LINESTYLE = '', ' ', '-', '--', '-.', ':', None
+ """Supported line styles"""
+
+ def __init__(self):
+ self._linewidth = self._DEFAULT_LINEWIDTH
+ self._linestyle = self._DEFAULT_LINESTYLE
+
+ @classmethod
+ def getSupportedLineStyles(cls):
+ """Returns list of supported line styles.
+
+ :rtype: List[str,None]
+ """
+ return cls._SUPPORTED_LINESTYLE
+
+ def getLineWidth(self):
+ """Return the curve line width in pixels
+
+ :rtype: float
+ """
+ return self._linewidth
+
+ def setLineWidth(self, width):
+ """Set the width in pixel of the curve line
+
+ See :meth:`getLineWidth`.
+
+ :param float width: Width in pixels
+ """
+ width = float(width)
+ if width != self._linewidth:
+ self._linewidth = width
+ self._updated(ItemChangedType.LINE_WIDTH)
+
+ def getLineStyle(self):
+ """Return the type of the line
+
+ Type of line::
+
+ - ' ' no line
+ - '-' solid line
+ - '--' dashed line
+ - '-.' dash-dot line
+ - ':' dotted line
+
+ :rtype: str
+ """
+ return self._linestyle
+
+ def setLineStyle(self, style):
+ """Set the style of the curve line.
+
+ See :meth:`getLineStyle`.
+
+ :param str style: Line style
+ """
+ style = str(style)
+ assert style in self.getSupportedLineStyles()
+ if style is None:
+ style = self._DEFAULT_LINESTYLE
+ if style != self._linestyle:
+ self._linestyle = style
+ self._updated(ItemChangedType.LINE_STYLE)
+
+
+class ColorMixIn(ItemMixInBase):
+ """Mix-in class for item with color"""
+
+ _DEFAULT_COLOR = (0., 0., 0., 1.)
+ """Default color of the item"""
+
+ def __init__(self):
+ self._color = self._DEFAULT_COLOR
+
+ def getColor(self):
+ """Returns the RGBA color of the item
+
+ :rtype: 4-tuple of float in [0, 1] or array of colors
+ """
+ return self._color
+
+ def setColor(self, color, copy=True):
+ """Set item color
+
+ :param color: color(s) to be used
+ :type color: str ("#RRGGBB") or (npoints, 4) unsigned byte array or
+ one of the predefined color names defined in colors.py
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ """
+ if isinstance(color, str):
+ color = colors.rgba(color)
+ elif isinstance(color, qt.QColor):
+ color = colors.rgba(color)
+ else:
+ color = numpy.array(color, copy=copy)
+ # TODO more checks + improve color array support
+ if color.ndim == 1: # Single RGBA color
+ color = colors.rgba(color)
+ else: # Array of colors
+ assert color.ndim == 2
+
+ self._color = color
+ self._updated(ItemChangedType.COLOR)
+
+
+class YAxisMixIn(ItemMixInBase):
+ """Mix-in class for item with yaxis"""
+
+ _DEFAULT_YAXIS = 'left'
+ """Default Y axis the item belongs to"""
+
+ def __init__(self):
+ self._yaxis = self._DEFAULT_YAXIS
+
+ def getYAxis(self):
+ """Returns the Y axis this curve belongs to.
+
+ Either 'left' or 'right'.
+
+ :rtype: str
+ """
+ return self._yaxis
+
+ def setYAxis(self, yaxis):
+ """Set the Y axis this curve belongs to.
+
+ :param str yaxis: 'left' or 'right'
+ """
+ yaxis = str(yaxis)
+ assert yaxis in ('left', 'right')
+ if yaxis != self._yaxis:
+ self._yaxis = yaxis
+ # Handle data extent changed for DataItem
+ if isinstance(self, DataItem):
+ self._boundsChanged()
+
+ # Handle visible extent changed
+ if self._isVisibleBoundsTracking():
+ # Switch Y axis signal connection
+ plot = self.getPlot()
+ if plot is not None:
+ previousYAxis = 'left' if self.getXAxis() == 'right' else 'right'
+ plot.getYAxis(previousYAxis).sigLimitsChanged.disconnect(
+ self._visibleBoundsChanged)
+ plot.getYAxis(self.getYAxis()).sigLimitsChanged.connect(
+ self._visibleBoundsChanged)
+ self._visibleBoundsChanged()
+
+ self._updated(ItemChangedType.YAXIS)
+
+
+class FillMixIn(ItemMixInBase):
+ """Mix-in class for item with fill"""
+
+ def __init__(self):
+ self._fill = False
+
+ def isFill(self):
+ """Returns whether the item is filled or not.
+
+ :rtype: bool
+ """
+ return self._fill
+
+ def setFill(self, fill):
+ """Set whether to fill the item or not.
+
+ :param bool fill:
+ """
+ fill = bool(fill)
+ if fill != self._fill:
+ self._fill = fill
+ self._updated(ItemChangedType.FILL)
+
+
+class AlphaMixIn(ItemMixInBase):
+ """Mix-in class for item with opacity"""
+
+ def __init__(self):
+ self._alpha = 1.
+
+ def getAlpha(self):
+ """Returns the opacity of the item
+
+ :rtype: float in [0, 1.]
+ """
+ return self._alpha
+
+ def setAlpha(self, alpha):
+ """Set the opacity of the item
+
+ .. note::
+
+ If the colormap already has some transparency, this alpha
+ adds additional transparency. The alpha channel of the colormap
+ is multiplied by this value.
+
+ :param alpha: Opacity of the item, between 0 (full transparency)
+ and 1. (full opacity)
+ :type alpha: float
+ """
+ alpha = float(alpha)
+ alpha = max(0., min(alpha, 1.)) # Clip alpha to [0., 1.] range
+ if alpha != self._alpha:
+ self._alpha = alpha
+ self._updated(ItemChangedType.ALPHA)
+
+
+class ComplexMixIn(ItemMixInBase):
+ """Mix-in class for complex data mode"""
+
+ _SUPPORTED_COMPLEX_MODES = None
+ """Override to only support a subset of all ComplexMode"""
+
+ class ComplexMode(_Enum):
+ """Identify available display mode for complex"""
+ NONE = 'none'
+ ABSOLUTE = 'amplitude'
+ PHASE = 'phase'
+ REAL = 'real'
+ IMAGINARY = 'imaginary'
+ AMPLITUDE_PHASE = 'amplitude_phase'
+ LOG10_AMPLITUDE_PHASE = 'log10_amplitude_phase'
+ SQUARE_AMPLITUDE = 'square_amplitude'
+
+ def __init__(self):
+ self.__complex_mode = self.ComplexMode.ABSOLUTE
+
+ def getComplexMode(self):
+ """Returns the current complex visualization mode.
+
+ :rtype: ComplexMode
+ """
+ return self.__complex_mode
+
+ def setComplexMode(self, mode):
+ """Set the complex visualization mode.
+
+ :param ComplexMode mode: The visualization mode in:
+ 'real', 'imaginary', 'phase', 'amplitude'
+ :return: True if value was set, False if is was already set
+ :rtype: bool
+ """
+ mode = self.ComplexMode.from_value(mode)
+ assert mode in self.supportedComplexModes()
+
+ if mode != self.__complex_mode:
+ self.__complex_mode = mode
+ self._updated(ItemChangedType.COMPLEX_MODE)
+ return True
+ else:
+ return False
+
+ def _convertComplexData(self, data, mode=None):
+ """Convert complex data to the specific mode.
+
+ :param Union[ComplexMode,None] mode:
+ The kind of value to compute.
+ If None (the default), the current complex mode is used.
+ :return: The converted dataset
+ :rtype: Union[numpy.ndarray[float],None]
+ """
+ if data is None:
+ return None
+
+ if mode is None:
+ mode = self.getComplexMode()
+
+ if mode is self.ComplexMode.REAL:
+ return numpy.real(data)
+ elif mode is self.ComplexMode.IMAGINARY:
+ return numpy.imag(data)
+ elif mode is self.ComplexMode.ABSOLUTE:
+ return numpy.absolute(data)
+ elif mode is self.ComplexMode.PHASE:
+ return numpy.angle(data)
+ elif mode is self.ComplexMode.SQUARE_AMPLITUDE:
+ return numpy.absolute(data) ** 2
+ else:
+ raise ValueError('Unsupported conversion mode: %s', str(mode))
+
+ @classmethod
+ def supportedComplexModes(cls):
+ """Returns the list of supported complex visualization modes.
+
+ See :class:`ComplexMode` and :meth:`setComplexMode`.
+
+ :rtype: List[ComplexMode]
+ """
+ if cls._SUPPORTED_COMPLEX_MODES is None:
+ return cls.ComplexMode.members()
+ else:
+ return cls._SUPPORTED_COMPLEX_MODES
+
+
+class ScatterVisualizationMixIn(ItemMixInBase):
+ """Mix-in class for scatter plot visualization modes"""
+
+ _SUPPORTED_SCATTER_VISUALIZATION = None
+ """Allows to override supported Visualizations"""
+
+ @enum.unique
+ class Visualization(_Enum):
+ """Different modes of scatter plot visualizations"""
+
+ POINTS = 'points'
+ """Display scatter plot as a point cloud"""
+
+ LINES = 'lines'
+ """Display scatter plot as a wireframe.
+
+ This is based on Delaunay triangulation
+ """
+
+ SOLID = 'solid'
+ """Display scatter plot as a set of filled triangles.
+
+ This is based on Delaunay triangulation
+ """
+
+ REGULAR_GRID = 'regular_grid'
+ """Display scatter plot as an image.
+
+ It expects the points to be the intersection of a regular grid,
+ and the order of points following that of an image.
+ First line, then second one, and always in the same direction
+ (either all lines from left to right or all from right to left).
+ """
+
+ IRREGULAR_GRID = 'irregular_grid'
+ """Display scatter plot as contiguous quadrilaterals.
+
+ It expects the points to be the intersection of an irregular grid,
+ and the order of points following that of an image.
+ First line, then second one, and always in the same direction
+ (either all lines from left to right or all from right to left).
+ """
+
+ BINNED_STATISTIC = 'binned_statistic'
+ """Display scatter plot as 2D binned statistic (i.e., generalized histogram).
+ """
+
+ @enum.unique
+ class VisualizationParameter(_Enum):
+ """Different parameter names for scatter plot visualizations"""
+
+ GRID_MAJOR_ORDER = 'grid_major_order'
+ """The major order of points in the regular grid.
+
+ Either 'row' (row-major, fast X) or 'column' (column-major, fast Y).
+ """
+
+ GRID_BOUNDS = 'grid_bounds'
+ """The expected range in data coordinates of the regular grid.
+
+ A 2-tuple of 2-tuple: (begin (x, y), end (x, y)).
+ This provides the data coordinates of the first point and the expected
+ last on.
+ As for `GRID_SHAPE`, this can be wider than the current data.
+ """
+
+ GRID_SHAPE = 'grid_shape'
+ """The expected size of the regular grid (height, width).
+
+ The given shape can be wider than the number of points,
+ in which case the grid is not fully filled.
+ """
+
+ BINNED_STATISTIC_SHAPE = 'binned_statistic_shape'
+ """The number of bins in each dimension (height, width).
+ """
+
+ BINNED_STATISTIC_FUNCTION = 'binned_statistic_function'
+ """The reduction function to apply to each bin (str).
+
+ Available reduction functions are: 'mean' (default), 'count', 'sum'.
+ """
+
+ DATA_BOUNDS_HINT = 'data_bounds_hint'
+ """The expected bounds of the data in data coordinates.
+
+ A 2-tuple of 2-tuple: ((ymin, ymax), (xmin, xmax)).
+ This provides a hint for the data ranges in both dimensions.
+ It is eventually enlarged with actually data ranges.
+
+ WARNING: dimension 0 i.e., Y first.
+ """
+
+ _SUPPORTED_VISUALIZATION_PARAMETER_VALUES = {
+ VisualizationParameter.GRID_MAJOR_ORDER: ('row', 'column'),
+ VisualizationParameter.BINNED_STATISTIC_FUNCTION: ('mean', 'count', 'sum'),
+ }
+ """Supported visualization parameter values.
+
+ Defined for parameters with a set of acceptable values.
+ """
+
+ def __init__(self):
+ self.__visualization = self.Visualization.POINTS
+ self.__parameters = dict(# Init parameters to None
+ (parameter, None) for parameter in self.VisualizationParameter)
+ self.__parameters[self.VisualizationParameter.BINNED_STATISTIC_FUNCTION] = 'mean'
+
+ @classmethod
+ def supportedVisualizations(cls):
+ """Returns the list of supported scatter visualization modes.
+
+ See :meth:`setVisualization`
+
+ :rtype: List[Visualization]
+ """
+ if cls._SUPPORTED_SCATTER_VISUALIZATION is None:
+ return cls.Visualization.members()
+ else:
+ return cls._SUPPORTED_SCATTER_VISUALIZATION
+
+ @classmethod
+ def supportedVisualizationParameterValues(cls, parameter):
+ """Returns the list of supported scatter visualization modes.
+
+ See :meth:`VisualizationParameters`
+
+ :param VisualizationParameter parameter:
+ This parameter for which to retrieve the supported values.
+ :returns: tuple of supported of values or None if not defined.
+ """
+ parameter = cls.VisualizationParameter(parameter)
+ return cls._SUPPORTED_VISUALIZATION_PARAMETER_VALUES.get(
+ parameter, None)
+
+ def setVisualization(self, mode):
+ """Set the scatter plot visualization mode to use.
+
+ See :class:`Visualization` for all possible values,
+ and :meth:`supportedVisualizations` for supported ones.
+
+ :param Union[str,Visualization] mode:
+ The visualization mode to use.
+ :return: True if value was set, False if is was already set
+ :rtype: bool
+ """
+ mode = self.Visualization.from_value(mode)
+ assert mode in self.supportedVisualizations()
+
+ if mode != self.__visualization:
+ self.__visualization = mode
+
+ self._updated(ItemChangedType.VISUALIZATION_MODE)
+ return True
+ else:
+ return False
+
+ def getVisualization(self):
+ """Returns the scatter plot visualization mode in use.
+
+ :rtype: Visualization
+ """
+ return self.__visualization
+
+ def setVisualizationParameter(self, parameter, value=None):
+ """Set the given visualization parameter.
+
+ :param Union[str,VisualizationParameter] parameter:
+ The name of the parameter to set
+ :param value: The value to use for this parameter
+ Set to None to automatically set the parameter
+ :raises ValueError: If parameter is not supported
+ :return: True if parameter was set, False if is was already set
+ :rtype: bool
+ :raise ValueError: If value is not supported
+ """
+ parameter = self.VisualizationParameter.from_value(parameter)
+
+ if self.__parameters[parameter] != value:
+ validValues = self.supportedVisualizationParameterValues(parameter)
+ if validValues is not None and value not in validValues:
+ raise ValueError("Unsupported parameter value: %s" % str(value))
+
+ self.__parameters[parameter] = value
+ self._updated(ItemChangedType.VISUALIZATION_MODE)
+ return True
+ return False
+
+ def getVisualizationParameter(self, parameter):
+ """Returns the value of the given visualization parameter.
+
+ This method returns the parameter as set by
+ :meth:`setVisualizationParameter`.
+
+ :param parameter: The name of the parameter to retrieve
+ :returns: The value previously set or None if automatically set
+ :raises ValueError: If parameter is not supported
+ """
+ if parameter not in self.VisualizationParameter:
+ raise ValueError("parameter not supported: %s", parameter)
+
+ return self.__parameters[parameter]
+
+ def getCurrentVisualizationParameter(self, parameter):
+ """Returns the current value of the given visualization parameter.
+
+ If the parameter was set by :meth:`setVisualizationParameter` to
+ a value that is not None, this value is returned;
+ else the current value that is automatically computed is returned.
+
+ :param parameter: The name of the parameter to retrieve
+ :returns: The current value (either set or automatically computed)
+ :raises ValueError: If parameter is not supported
+ """
+ # Override in subclass to provide automatically computed parameters
+ return self.getVisualizationParameter(parameter)
+
+
+class PointsBase(DataItem, SymbolMixIn, AlphaMixIn):
+ """Base class for :class:`Curve` and :class:`Scatter`"""
+ # note: _logFilterData must be overloaded if you overload
+ # getData to change its signature
+
+ _DEFAULT_Z_LAYER = 1
+ """Default overlay layer for points,
+ on top of images."""
+
+ def __init__(self):
+ DataItem.__init__(self)
+ SymbolMixIn.__init__(self)
+ AlphaMixIn.__init__(self)
+ self._x = ()
+ self._y = ()
+ self._xerror = None
+ self._yerror = None
+
+ # Store filtered data for x > 0 and/or y > 0
+ self._filteredCache = {}
+ self._clippedCache = {}
+
+ # Store bounds depending on axes filtering >0:
+ # key is (isXPositiveFilter, isYPositiveFilter)
+ self._boundsCache = {}
+
+ @staticmethod
+ def _logFilterError(value, error):
+ """Filter/convert error values if they go <= 0.
+
+ Replace error leading to negative values by nan
+
+ :param numpy.ndarray value: 1D array of values
+ :param numpy.ndarray error:
+ Array of errors: scalar, N, Nx1 or 2xN or None.
+ :return: Filtered error so error bars are never negative
+ """
+ if error is not None:
+ # Convert Nx1 to N
+ if error.ndim == 2 and error.shape[1] == 1 and len(value) != 1:
+ error = numpy.ravel(error)
+
+ # Supports error being scalar, N or 2xN array
+ valueMinusError = value - numpy.atleast_2d(error)[0]
+ errorClipped = numpy.isnan(valueMinusError)
+ mask = numpy.logical_not(errorClipped)
+ errorClipped[mask] = valueMinusError[mask] <= 0
+
+ if numpy.any(errorClipped): # Need filtering
+
+ # expand errorbars to 2xN
+ if error.size == 1: # Scalar
+ error = numpy.full(
+ (2, len(value)), error, dtype=numpy.float64)
+
+ elif error.ndim == 1: # N array
+ newError = numpy.empty((2, len(value)),
+ dtype=numpy.float64)
+ newError[0,:] = error
+ newError[1,:] = error
+ error = newError
+
+ elif error.size == 2 * len(value): # 2xN array
+ error = numpy.array(
+ error, copy=True, dtype=numpy.float64)
+
+ else:
+ _logger.error("Unhandled error array")
+ return error
+
+ error[0, errorClipped] = numpy.nan
+
+ return error
+
+ def _getClippingBoolArray(self, xPositive, yPositive):
+ """Compute a boolean array to filter out points with negative
+ coordinates on log axes.
+
+ :param bool xPositive: True to filter arrays according to X coords.
+ :param bool yPositive: True to filter arrays according to Y coords.
+ :rtype: boolean numpy.ndarray
+ """
+ assert xPositive or yPositive
+ if (xPositive, yPositive) not in self._clippedCache:
+ xclipped, yclipped = False, False
+
+ if xPositive:
+ x = self.getXData(copy=False)
+ with numpy.errstate(invalid='ignore'): # Ignore NaN warnings
+ xclipped = x <= 0
+
+ if yPositive:
+ y = self.getYData(copy=False)
+ with numpy.errstate(invalid='ignore'): # Ignore NaN warnings
+ yclipped = y <= 0
+
+ self._clippedCache[(xPositive, yPositive)] = \
+ numpy.logical_or(xclipped, yclipped)
+ return self._clippedCache[(xPositive, yPositive)]
+
+ def _logFilterData(self, xPositive, yPositive):
+ """Filter out values with x or y <= 0 on log axes
+
+ :param bool xPositive: True to filter arrays according to X coords.
+ :param bool yPositive: True to filter arrays according to Y coords.
+ :return: The filter arrays or unchanged object if filtering not needed
+ :rtype: (x, y, xerror, yerror)
+ """
+ x = self.getXData(copy=False)
+ y = self.getYData(copy=False)
+ xerror = self.getXErrorData(copy=False)
+ yerror = self.getYErrorData(copy=False)
+
+ if xPositive or yPositive:
+ clipped = self._getClippingBoolArray(xPositive, yPositive)
+
+ if numpy.any(clipped):
+ # copy to keep original array and convert to float
+ x = numpy.array(x, copy=True, dtype=numpy.float64)
+ x[clipped] = numpy.nan
+ y = numpy.array(y, copy=True, dtype=numpy.float64)
+ y[clipped] = numpy.nan
+
+ if xPositive and xerror is not None:
+ xerror = self._logFilterError(x, xerror)
+
+ if yPositive and yerror is not None:
+ yerror = self._logFilterError(y, yerror)
+
+ return x, y, xerror, yerror
+
+ def _getBounds(self):
+ if self.getXData(copy=False).size == 0: # Empty data
+ return None
+
+ plot = self.getPlot()
+ if plot is not None:
+ xPositive = plot.getXAxis()._isLogarithmic()
+ yPositive = plot.getYAxis()._isLogarithmic()
+ else:
+ xPositive = False
+ yPositive = False
+
+ # TODO bounds do not take error bars into account
+ if (xPositive, yPositive) not in self._boundsCache:
+ # use the getData class method because instance method can be
+ # overloaded to return additional arrays
+ data = PointsBase.getData(self, copy=False, displayed=True)
+ if len(data) == 5:
+ # hack to avoid duplicating caching mechanism in Scatter
+ # (happens when cached data is used, caching done using
+ # Scatter._logFilterData)
+ x, y, _xerror, _yerror = data[0], data[1], data[3], data[4]
+ else:
+ x, y, _xerror, _yerror = data
+
+ xmin, xmax = min_max(x, finite=True)
+ ymin, ymax = min_max(y, finite=True)
+ self._boundsCache[(xPositive, yPositive)] = tuple([
+ (bound if bound is not None else numpy.nan)
+ for bound in (xmin, xmax, ymin, ymax)])
+ return self._boundsCache[(xPositive, yPositive)]
+
+ def _getCachedData(self):
+ """Return cached filtered data if applicable,
+ i.e. if any axis is in log scale.
+ Return None if caching is not applicable."""
+ plot = self.getPlot()
+ if plot is not None:
+ xPositive = plot.getXAxis()._isLogarithmic()
+ yPositive = plot.getYAxis()._isLogarithmic()
+ if xPositive or yPositive:
+ # At least one axis has log scale, filter data
+ if (xPositive, yPositive) not in self._filteredCache:
+ self._filteredCache[(xPositive, yPositive)] = \
+ self._logFilterData(xPositive, yPositive)
+ return self._filteredCache[(xPositive, yPositive)]
+ return None
+
+ def getData(self, copy=True, displayed=False):
+ """Returns the x, y values of the curve points and xerror, yerror
+
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :param bool displayed: True to only get curve points that are displayed
+ in the plot. Default: False
+ Note: If plot has log scale, negative points
+ are not displayed.
+ :returns: (x, y, xerror, yerror)
+ :rtype: 4-tuple of numpy.ndarray
+ """
+ if displayed: # filter data according to plot state
+ cached_data = self._getCachedData()
+ if cached_data is not None:
+ return cached_data
+
+ return (self.getXData(copy),
+ self.getYData(copy),
+ self.getXErrorData(copy),
+ self.getYErrorData(copy))
+
+ def getXData(self, copy=True):
+ """Returns the x coordinates of the data points
+
+ :param copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :rtype: numpy.ndarray
+ """
+ return numpy.array(self._x, copy=copy)
+
+ def getYData(self, copy=True):
+ """Returns the y coordinates of the data points
+
+ :param copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :rtype: numpy.ndarray
+ """
+ return numpy.array(self._y, copy=copy)
+
+ def getXErrorData(self, copy=True):
+ """Returns the x error of the points
+
+ :param copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :rtype: numpy.ndarray, float or None
+ """
+ if isinstance(self._xerror, numpy.ndarray):
+ return numpy.array(self._xerror, copy=copy)
+ else:
+ return self._xerror # float or None
+
+ def getYErrorData(self, copy=True):
+ """Returns the y error of the points
+
+ :param copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :rtype: numpy.ndarray, float or None
+ """
+ if isinstance(self._yerror, numpy.ndarray):
+ return numpy.array(self._yerror, copy=copy)
+ else:
+ return self._yerror # float or None
+
+ def setData(self, x, y, xerror=None, yerror=None, copy=True):
+ """Set the data of the curve.
+
+ :param numpy.ndarray x: The data corresponding to the x coordinates.
+ :param numpy.ndarray y: The data corresponding to the y coordinates.
+ :param xerror: Values with the uncertainties on the x values
+ :type xerror: A float, or a numpy.ndarray of float32.
+ If it is an array, it can either be a 1D array of
+ same length as the data or a 2D array with 2 rows
+ of same length as the data: row 0 for positive errors,
+ row 1 for negative errors.
+ :param yerror: Values with the uncertainties on the y values.
+ :type yerror: A float, or a numpy.ndarray of float32. See xerror.
+ :param bool copy: True make a copy of the data (default),
+ False to use provided arrays.
+ """
+ x = numpy.array(x, copy=copy)
+ y = numpy.array(y, copy=copy)
+ assert len(x) == len(y)
+ assert x.ndim == y.ndim == 1
+
+ # Convert complex data
+ if numpy.iscomplexobj(x):
+ _logger.warning(
+ 'Converting x data to absolute value to plot it.')
+ x = numpy.absolute(x)
+ if numpy.iscomplexobj(y):
+ _logger.warning(
+ 'Converting y data to absolute value to plot it.')
+ y = numpy.absolute(y)
+
+ if xerror is not None:
+ if isinstance(xerror, abc.Iterable):
+ xerror = numpy.array(xerror, copy=copy)
+ if numpy.iscomplexobj(xerror):
+ _logger.warning(
+ 'Converting xerror data to absolute value to plot it.')
+ xerror = numpy.absolute(xerror)
+ else:
+ xerror = float(xerror)
+ if yerror is not None:
+ if isinstance(yerror, abc.Iterable):
+ yerror = numpy.array(yerror, copy=copy)
+ if numpy.iscomplexobj(yerror):
+ _logger.warning(
+ 'Converting yerror data to absolute value to plot it.')
+ yerror = numpy.absolute(yerror)
+ else:
+ yerror = float(yerror)
+ # TODO checks on xerror, yerror
+ self._x, self._y = x, y
+ self._xerror, self._yerror = xerror, yerror
+
+ self._boundsCache = {} # Reset cached bounds
+ self._filteredCache = {} # Reset cached filtered data
+ self._clippedCache = {} # Reset cached clipped bool array
+
+ self._boundsChanged()
+ self._updated(ItemChangedType.DATA)
+
+
+class BaselineMixIn(object):
+ """Base class for Baseline mix-in"""
+
+ def __init__(self, baseline=None):
+ self._baseline = baseline
+
+ def _setBaseline(self, baseline):
+ """
+ Set baseline value
+
+ :param baseline: baseline value(s)
+ :type: Union[None,float,numpy.ndarray]
+ """
+ if (isinstance(baseline, abc.Iterable)):
+ baseline = numpy.array(baseline)
+ self._baseline = baseline
+
+ def getBaseline(self, copy=True):
+ """
+
+ :param bool copy:
+ :return: histogram baseline
+ :rtype: Union[None,float,numpy.ndarray]
+ """
+ if isinstance(self._baseline, numpy.ndarray):
+ return numpy.array(self._baseline, copy=True)
+ else:
+ return self._baseline
+
+
+class _Style:
+ """Object which store styles"""
+
+
+class HighlightedMixIn(ItemMixInBase):
+
+ def __init__(self):
+ self._highlightStyle = self._DEFAULT_HIGHLIGHT_STYLE
+ self._highlighted = False
+
+ def isHighlighted(self):
+ """Returns True if curve is highlighted.
+
+ :rtype: bool
+ """
+ return self._highlighted
+
+ def setHighlighted(self, highlighted):
+ """Set the highlight state of the curve
+
+ :param bool highlighted:
+ """
+ highlighted = bool(highlighted)
+ if highlighted != self._highlighted:
+ self._highlighted = highlighted
+ # TODO inefficient: better to use backend's setCurveColor
+ self._updated(ItemChangedType.HIGHLIGHTED)
+
+ def getHighlightedStyle(self):
+ """Returns the highlighted style in use
+
+ :rtype: CurveStyle
+ """
+ return self._highlightStyle
+
+ def setHighlightedStyle(self, style):
+ """Set the style to use for highlighting
+
+ :param CurveStyle style: New style to use
+ """
+ previous = self.getHighlightedStyle()
+ if style != previous:
+ assert isinstance(style, _Style)
+ self._highlightStyle = style
+ self._updated(ItemChangedType.HIGHLIGHTED_STYLE)
+
+ # Backward compatibility event
+ if previous.getColor() != style.getColor():
+ self._updated(ItemChangedType.HIGHLIGHTED_COLOR)
diff --git a/src/silx/gui/plot/items/curve.py b/src/silx/gui/plot/items/curve.py
new file mode 100644
index 0000000..7cbe26e
--- /dev/null
+++ b/src/silx/gui/plot/items/curve.py
@@ -0,0 +1,325 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides the :class:`Curve` item of the :class:`Plot`.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "24/04/2018"
+
+
+import logging
+
+import numpy
+
+from ....utils.deprecation import deprecated
+from ... import colors
+from .core import (PointsBase, LabelsMixIn, ColorMixIn, YAxisMixIn,
+ FillMixIn, LineMixIn, SymbolMixIn, ItemChangedType,
+ BaselineMixIn, HighlightedMixIn, _Style)
+
+
+_logger = logging.getLogger(__name__)
+
+
+class CurveStyle(_Style):
+ """Object storing the style of a curve.
+
+ Set a value to None to use the default
+
+ :param color: Color
+ :param Union[str,None] linestyle: Style of the line
+ :param Union[float,None] linewidth: Width of the line
+ :param Union[str,None] symbol: Symbol for markers
+ :param Union[float,None] symbolsize: Size of the markers
+ """
+
+ def __init__(self, color=None, linestyle=None, linewidth=None,
+ symbol=None, symbolsize=None):
+ if color is None:
+ self._color = None
+ else:
+ if isinstance(color, str):
+ color = colors.rgba(color)
+ else: # array-like expected
+ color = numpy.array(color, copy=False)
+ if color.ndim == 1: # Array is 1D, this is a single color
+ color = colors.rgba(color)
+ self._color = color
+
+ if linestyle is not None:
+ assert linestyle in LineMixIn.getSupportedLineStyles()
+ self._linestyle = linestyle
+
+ self._linewidth = None if linewidth is None else float(linewidth)
+
+ if symbol is not None:
+ assert symbol in SymbolMixIn.getSupportedSymbols()
+ self._symbol = symbol
+
+ self._symbolsize = None if symbolsize is None else float(symbolsize)
+
+ def getColor(self, copy=True):
+ """Returns the color or None if not set.
+
+ :param bool copy: True to get a copy (default),
+ False to get internal representation (do not modify!)
+
+ :rtype: Union[List[float],None]
+ """
+ if isinstance(self._color, numpy.ndarray):
+ return numpy.array(self._color, copy=copy)
+ else:
+ return self._color
+
+ def getLineStyle(self):
+ """Return the type of the line or None if not set.
+
+ Type of line::
+
+ - ' ' no line
+ - '-' solid line
+ - '--' dashed line
+ - '-.' dash-dot line
+ - ':' dotted line
+
+ :rtype: Union[str,None]
+ """
+ return self._linestyle
+
+ def getLineWidth(self):
+ """Return the curve line width in pixels or None if not set.
+
+ :rtype: Union[float,None]
+ """
+ return self._linewidth
+
+ def getSymbol(self):
+ """Return the point marker type.
+
+ Marker type::
+
+ - 'o' circle
+ - '.' point
+ - ',' pixel
+ - '+' cross
+ - 'x' x-cross
+ - 'd' diamond
+ - 's' square
+
+ :rtype: Union[str,None]
+ """
+ return self._symbol
+
+ def getSymbolSize(self):
+ """Return the point marker size in points.
+
+ :rtype: Union[float,None]
+ """
+ return self._symbolsize
+
+ def __eq__(self, other):
+ if isinstance(other, CurveStyle):
+ return (numpy.array_equal(self.getColor(), other.getColor()) and
+ self.getLineStyle() == other.getLineStyle() and
+ self.getLineWidth() == other.getLineWidth() and
+ self.getSymbol() == other.getSymbol() and
+ self.getSymbolSize() == other.getSymbolSize())
+ else:
+ return False
+
+
+class Curve(PointsBase, ColorMixIn, YAxisMixIn, FillMixIn, LabelsMixIn,
+ LineMixIn, BaselineMixIn, HighlightedMixIn):
+ """Description of a curve"""
+
+ _DEFAULT_Z_LAYER = 1
+ """Default overlay layer for curves"""
+
+ _DEFAULT_SELECTABLE = True
+ """Default selectable state for curves"""
+
+ _DEFAULT_LINEWIDTH = 1.
+ """Default line width of the curve"""
+
+ _DEFAULT_LINESTYLE = '-'
+ """Default line style of the curve"""
+
+ _DEFAULT_HIGHLIGHT_STYLE = CurveStyle(color='black')
+ """Default highlight style of the item"""
+
+ _DEFAULT_BASELINE = None
+
+ def __init__(self):
+ PointsBase.__init__(self)
+ ColorMixIn.__init__(self)
+ YAxisMixIn.__init__(self)
+ FillMixIn.__init__(self)
+ LabelsMixIn.__init__(self)
+ LineMixIn.__init__(self)
+ BaselineMixIn.__init__(self)
+ HighlightedMixIn.__init__(self)
+
+ self._setBaseline(Curve._DEFAULT_BASELINE)
+
+ def _addBackendRenderer(self, backend):
+ """Update backend renderer"""
+ # Filter-out values <= 0
+ xFiltered, yFiltered, xerror, yerror = self.getData(
+ copy=False, displayed=True)
+
+ if len(xFiltered) == 0 or not numpy.any(numpy.isfinite(xFiltered)):
+ return None # No data to display, do not add renderer to backend
+
+ style = self.getCurrentStyle()
+
+ return backend.addCurve(xFiltered, yFiltered,
+ color=style.getColor(),
+ symbol=style.getSymbol(),
+ linestyle=style.getLineStyle(),
+ linewidth=style.getLineWidth(),
+ yaxis=self.getYAxis(),
+ xerror=xerror,
+ yerror=yerror,
+ fill=self.isFill(),
+ alpha=self.getAlpha(),
+ symbolsize=style.getSymbolSize(),
+ baseline=self.getBaseline(copy=False))
+
+ def __getitem__(self, item):
+ """Compatibility with PyMca and silx <= 0.4.0"""
+ if isinstance(item, slice):
+ return [self[index] for index in range(*item.indices(5))]
+ elif item == 0:
+ return self.getXData(copy=False)
+ elif item == 1:
+ return self.getYData(copy=False)
+ elif item == 2:
+ return self.getName()
+ elif item == 3:
+ info = self.getInfo(copy=False)
+ return {} if info is None else info
+ elif item == 4:
+ params = {
+ 'info': self.getInfo(),
+ 'color': self.getColor(),
+ 'symbol': self.getSymbol(),
+ 'linewidth': self.getLineWidth(),
+ 'linestyle': self.getLineStyle(),
+ 'xlabel': self.getXLabel(),
+ 'ylabel': self.getYLabel(),
+ 'yaxis': self.getYAxis(),
+ 'xerror': self.getXErrorData(copy=False),
+ 'yerror': self.getYErrorData(copy=False),
+ 'z': self.getZValue(),
+ 'selectable': self.isSelectable(),
+ 'fill': self.isFill(),
+ }
+ return params
+ else:
+ raise IndexError("Index out of range: %s", str(item))
+
+ @deprecated(replacement='Curve.getHighlightedStyle().getColor()',
+ since_version='0.9.0')
+ def getHighlightedColor(self):
+ """Returns the RGBA highlight color of the item
+
+ :rtype: 4-tuple of float in [0, 1]
+ """
+ return self.getHighlightedStyle().getColor()
+
+ @deprecated(replacement='Curve.setHighlightedStyle()',
+ since_version='0.9.0')
+ def setHighlightedColor(self, color):
+ """Set the color to use when highlighted
+
+ :param color: color(s) to be used for highlight
+ :type color: str ("#RRGGBB") or (npoints, 4) unsigned byte array or
+ one of the predefined color names defined in colors.py
+ """
+ self.setHighlightedStyle(CurveStyle(color))
+
+ def getCurrentStyle(self):
+ """Returns the current curve style.
+
+ Curve style depends on curve highlighting
+
+ :rtype: CurveStyle
+ """
+ if self.isHighlighted():
+ style = self.getHighlightedStyle()
+ color = style.getColor()
+ linestyle = style.getLineStyle()
+ linewidth = style.getLineWidth()
+ symbol = style.getSymbol()
+ symbolsize = style.getSymbolSize()
+
+ return CurveStyle(
+ color=self.getColor() if color is None else color,
+ linestyle=self.getLineStyle() if linestyle is None else linestyle,
+ linewidth=self.getLineWidth() if linewidth is None else linewidth,
+ symbol=self.getSymbol() if symbol is None else symbol,
+ symbolsize=self.getSymbolSize() if symbolsize is None else symbolsize)
+
+ else:
+ return CurveStyle(color=self.getColor(),
+ linestyle=self.getLineStyle(),
+ linewidth=self.getLineWidth(),
+ symbol=self.getSymbol(),
+ symbolsize=self.getSymbolSize())
+
+ @deprecated(replacement='Curve.getCurrentStyle()',
+ since_version='0.9.0')
+ def getCurrentColor(self):
+ """Returns the current color of the curve.
+
+ This color is either the color of the curve or the highlighted color,
+ depending on the highlight state.
+
+ :rtype: 4-tuple of float in [0, 1]
+ """
+ return self.getCurrentStyle().getColor()
+
+ def setData(self, x, y, xerror=None, yerror=None, baseline=None, copy=True):
+ """Set the data of the curve.
+
+ :param numpy.ndarray x: The data corresponding to the x coordinates.
+ :param numpy.ndarray y: The data corresponding to the y coordinates.
+ :param xerror: Values with the uncertainties on the x values
+ :type xerror: A float, or a numpy.ndarray of float32.
+ If it is an array, it can either be a 1D array of
+ same length as the data or a 2D array with 2 rows
+ of same length as the data: row 0 for positive errors,
+ row 1 for negative errors.
+ :param yerror: Values with the uncertainties on the y values.
+ :type yerror: A float, or a numpy.ndarray of float32. See xerror.
+ :param baseline: curve baseline
+ :type baseline: Union[None,float,numpy.ndarray]
+ :param bool copy: True make a copy of the data (default),
+ False to use provided arrays.
+ """
+ PointsBase.setData(self, x=x, y=y, xerror=xerror, yerror=yerror,
+ copy=copy)
+ self._setBaseline(baseline=baseline)
diff --git a/src/silx/gui/plot/items/histogram.py b/src/silx/gui/plot/items/histogram.py
new file mode 100644
index 0000000..16bbefa
--- /dev/null
+++ b/src/silx/gui/plot/items/histogram.py
@@ -0,0 +1,389 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions::t
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides the :class:`Histogram` item of the :class:`Plot`.
+"""
+
+__authors__ = ["H. Payno", "T. Vincent"]
+__license__ = "MIT"
+__date__ = "28/08/2018"
+
+import logging
+import typing
+
+import numpy
+from collections import OrderedDict, namedtuple
+try:
+ from collections import abc
+except ImportError: # Python2 support
+ import collections as abc
+
+from ....utils.proxy import docstring
+from .core import (DataItem, AlphaMixIn, BaselineMixIn, ColorMixIn, FillMixIn,
+ LineMixIn, YAxisMixIn, ItemChangedType, Item)
+from ._pick import PickingResult
+
+_logger = logging.getLogger(__name__)
+
+
+def _computeEdges(x, histogramType):
+ """Compute the edges from a set of xs and a rule to generate the edges
+
+ :param x: the x value of the curve to transform into an histogram
+ :param histogramType: the type of histogram we wan't to generate.
+ This define the way to center the histogram values compared to the
+ curve value. Possible values can be::
+
+ - 'left'
+ - 'right'
+ - 'center'
+
+ :return: the edges for the given x and the histogramType
+ """
+ # for now we consider that the spaces between xs are constant
+ edges = x.copy()
+ if histogramType == 'left':
+ width = 1
+ if len(x) > 1:
+ width = x[1] - x[0]
+ edges = numpy.append(x[0] - width, edges)
+ if histogramType == 'center':
+ edges = _computeEdges(edges, 'right')
+ widths = (edges[1:] - edges[0:-1]) / 2.0
+ widths = numpy.append(widths, widths[-1])
+ edges = edges - widths
+ if histogramType == 'right':
+ width = 1
+ if len(x) > 1:
+ width = x[-1] - x[-2]
+ edges = numpy.append(edges, x[-1] + width)
+
+ return edges
+
+
+def _getHistogramCurve(histogram, edges):
+ """Returns the x and y value of a curve corresponding to the histogram
+
+ :param numpy.ndarray histogram: The values of the histogram
+ :param numpy.ndarray edges: The bin edges of the histogram
+ :return: a tuple(x, y) which contains the value of the curve to use
+ to display the histogram
+ """
+ assert len(histogram) + 1 == len(edges)
+ x = numpy.empty(len(histogram) * 2, dtype=edges.dtype)
+ y = numpy.empty(len(histogram) * 2, dtype=histogram.dtype)
+ # Make a curve with stairs
+ x[:-1:2] = edges[:-1]
+ x[1::2] = edges[1:]
+ y[:-1:2] = histogram
+ y[1::2] = histogram
+
+ return x, y
+
+
+# TODO: Yerror, test log scale
+class Histogram(DataItem, AlphaMixIn, ColorMixIn, FillMixIn,
+ LineMixIn, YAxisMixIn, BaselineMixIn):
+ """Description of an histogram"""
+
+ _DEFAULT_Z_LAYER = 1
+ """Default overlay layer for histograms"""
+
+ _DEFAULT_SELECTABLE = False
+ """Default selectable state for histograms"""
+
+ _DEFAULT_LINEWIDTH = 1.
+ """Default line width of the histogram"""
+
+ _DEFAULT_LINESTYLE = '-'
+ """Default line style of the histogram"""
+
+ _DEFAULT_BASELINE = None
+
+ def __init__(self):
+ DataItem.__init__(self)
+ AlphaMixIn.__init__(self)
+ BaselineMixIn.__init__(self)
+ ColorMixIn.__init__(self)
+ FillMixIn.__init__(self)
+ LineMixIn.__init__(self)
+ YAxisMixIn.__init__(self)
+
+ self._histogram = ()
+ self._edges = ()
+ self._setBaseline(Histogram._DEFAULT_BASELINE)
+
+ def _addBackendRenderer(self, backend):
+ """Update backend renderer"""
+ values, edges, baseline = self.getData(copy=False)
+
+ if values.size == 0:
+ return None # No data to display, do not add renderer
+
+ if values.size == 0:
+ return None # No data to display, do not add renderer to backend
+
+ x, y = _getHistogramCurve(values, edges)
+
+ # Filter-out values <= 0
+ plot = self.getPlot()
+ if plot is not None:
+ xPositive = plot.getXAxis()._isLogarithmic()
+ yPositive = plot.getYAxis()._isLogarithmic()
+ else:
+ xPositive = False
+ yPositive = False
+
+ if xPositive or yPositive:
+ clipped = numpy.logical_or(
+ (x <= 0) if xPositive else False,
+ (y <= 0) if yPositive else False)
+ # Make a copy and replace negative points by NaN
+ x = numpy.array(x, dtype=numpy.float64)
+ y = numpy.array(y, dtype=numpy.float64)
+ x[clipped] = numpy.nan
+ y[clipped] = numpy.nan
+
+ return backend.addCurve(x, y,
+ color=self.getColor(),
+ symbol='',
+ linestyle=self.getLineStyle(),
+ linewidth=self.getLineWidth(),
+ yaxis=self.getYAxis(),
+ xerror=None,
+ yerror=None,
+ fill=self.isFill(),
+ alpha=self.getAlpha(),
+ baseline=baseline,
+ symbolsize=1)
+
+ def _getBounds(self):
+ values, edges, baseline = self.getData(copy=False)
+
+ plot = self.getPlot()
+ if plot is not None:
+ xPositive = plot.getXAxis()._isLogarithmic()
+ yPositive = plot.getYAxis()._isLogarithmic()
+ else:
+ xPositive = False
+ yPositive = False
+
+ if xPositive or yPositive:
+ values = numpy.array(values, copy=True, dtype=numpy.float64)
+
+ if xPositive:
+ # Replace edges <= 0 by NaN and corresponding values by NaN
+ clipped_edges = (edges <= 0)
+ edges = numpy.array(edges, copy=True, dtype=numpy.float64)
+ edges[clipped_edges] = numpy.nan
+ clipped_values = numpy.logical_or(clipped_edges[:-1],
+ clipped_edges[1:])
+ else:
+ clipped_values = numpy.zeros_like(values, dtype=bool)
+
+ if yPositive:
+ # Replace values <= 0 by NaN, do not modify edges
+ clipped_values = numpy.logical_or(clipped_values, values <= 0)
+
+ values[clipped_values] = numpy.nan
+
+ if yPositive:
+ return (numpy.nanmin(edges),
+ numpy.nanmax(edges),
+ numpy.nanmin(values),
+ numpy.nanmax(values))
+
+ else: # No log scale on y axis, include 0 in bounds
+ if numpy.all(numpy.isnan(values)):
+ return None
+ return (numpy.nanmin(edges),
+ numpy.nanmax(edges),
+ min(0, numpy.nanmin(values)),
+ max(0, numpy.nanmax(values)))
+
+ def __pickFilledHistogram(self, x: float, y: float) -> typing.Optional[PickingResult]:
+ """Picking implementation for filled histogram
+
+ :param x: X position in pixels
+ :param y: Y position in pixels
+ """
+ if not self.isFill():
+ return None
+
+ plot = self.getPlot()
+ if plot is None:
+ return None
+
+ xData, yData = plot.pixelToData(x, y, axis=self.getYAxis())
+ xmin, xmax, ymin, ymax = self.getBounds()
+ if not xmin < xData < xmax or not ymin < yData < ymax:
+ return None # Outside bounding box
+
+ # Check x
+ edges = self.getBinEdgesData(copy=False)
+ index = numpy.searchsorted(edges, (xData,), side='left')[0] - 1
+ # Safe indexing in histogram values
+ index = numpy.clip(index, 0, len(edges) - 2)
+
+ # Check y
+ baseline = self.getBaseline(copy=False)
+ if baseline is None:
+ baseline = 0 # Default value
+
+ value = self.getValueData(copy=False)[index]
+ if ((baseline <= value and baseline <= yData <= value) or
+ (value < baseline and value <= yData <= baseline)):
+ return PickingResult(self, numpy.array([index]))
+ else:
+ return None
+
+ @docstring(DataItem)
+ def pick(self, x, y):
+ if self.isFill():
+ return self.__pickFilledHistogram(x, y)
+ else:
+ result = super().pick(x, y)
+ if result is None:
+ return None
+ else: # Convert from curve indices to histogram indices
+ return PickingResult(self, numpy.unique(result.getIndices() // 2))
+
+ def getValueData(self, copy=True):
+ """The values of the histogram
+
+ :param copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :returns: The values of the histogram
+ :rtype: numpy.ndarray
+ """
+ return numpy.array(self._histogram, copy=copy)
+
+ def getBinEdgesData(self, copy=True):
+ """The bin edges of the histogram (number of histogram values + 1)
+
+ :param copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :returns: The bin edges of the histogram
+ :rtype: numpy.ndarray
+ """
+ return numpy.array(self._edges, copy=copy)
+
+ def getData(self, copy=True):
+ """Return the histogram values, bin edges and baseline
+
+ :param copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :returns: (N histogram value, N+1 bin edges)
+ :rtype: 2-tuple of numpy.nadarray
+ """
+ return (self.getValueData(copy),
+ self.getBinEdgesData(copy),
+ self.getBaseline(copy))
+
+ def setData(self, histogram, edges, align='center', baseline=None,
+ copy=True):
+ """Set the histogram values and bin edges.
+
+ :param numpy.ndarray histogram: The values of the histogram.
+ :param numpy.ndarray edges:
+ The bin edges of the histogram.
+ If histogram and edges have the same length, the bin edges
+ are computed according to the align parameter.
+ :param str align:
+ In case histogram values and edges have the same length N,
+ the N+1 bin edges are computed according to the alignment in:
+ 'center' (default), 'left', 'right'.
+ :param baseline: histogram baseline
+ :type baseline: Union[None,float,numpy.ndarray]
+ :param bool copy: True make a copy of the data (default),
+ False to use provided arrays.
+ """
+ histogram = numpy.array(histogram, copy=copy)
+ edges = numpy.array(edges, copy=copy)
+
+ assert histogram.ndim == 1
+ assert edges.ndim == 1
+ assert edges.size in (histogram.size, histogram.size + 1)
+ assert align in ('center', 'left', 'right')
+
+ if histogram.size == 0: # No data
+ self._histogram = ()
+ self._edges = ()
+ else:
+ if edges.size == histogram.size: # Compute true bin edges
+ edges = _computeEdges(edges, align)
+
+ # Check that bin edges are monotonic
+ edgesDiff = numpy.diff(edges)
+ edgesDiff = edgesDiff[numpy.logical_not(numpy.isnan(edgesDiff))]
+ assert numpy.all(edgesDiff >= 0) or numpy.all(edgesDiff <= 0)
+ # manage baseline
+ if (isinstance(baseline, abc.Iterable)):
+ baseline = numpy.array(baseline)
+ if baseline.size == histogram.size:
+ new_baseline = numpy.empty(baseline.shape[0] * 2)
+ for i_value, value in enumerate(baseline):
+ new_baseline[i_value*2:i_value*2+2] = value
+ baseline = new_baseline
+ self._histogram = histogram
+ self._edges = edges
+ self._alignement = align
+ self._setBaseline(baseline)
+
+ self._boundsChanged()
+ self._updated(ItemChangedType.DATA)
+
+ def getAlignment(self):
+ """
+
+ :return: histogram alignement. Value in ('center', 'left', 'right').
+ """
+ return self._alignement
+
+ def _revertComputeEdges(self, x, histogramType):
+ """Compute the edges from a set of xs and a rule to generate the edges
+
+ :param x: the x value of the curve to transform into an histogram
+ :param histogramType: the type of histogram we wan't to generate.
+ This define the way to center the histogram values compared to the
+ curve value. Possible values can be::
+
+ - 'left'
+ - 'right'
+ - 'center'
+
+ :return: the edges for the given x and the histogramType
+ """
+ # for now we consider that the spaces between xs are constant
+ edges = x.copy()
+ if histogramType == 'left':
+ return edges[1:]
+ if histogramType == 'center':
+ edges = (edges[1:] + edges[:-1]) / 2.0
+ if histogramType == 'right':
+ width = 1
+ if len(x) > 1:
+ width = x[-1] + x[-2]
+ edges = edges[:-1]
+ return edges
diff --git a/src/silx/gui/plot/items/image.py b/src/silx/gui/plot/items/image.py
new file mode 100644
index 0000000..5cc719b
--- /dev/null
+++ b/src/silx/gui/plot/items/image.py
@@ -0,0 +1,641 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides the :class:`ImageData` and :class:`ImageRgba` items
+of the :class:`Plot`.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "08/12/2020"
+
+try:
+ from collections import abc
+except ImportError: # Python2 support
+ import collections as abc
+import logging
+
+import numpy
+
+from ....utils.proxy import docstring
+from .core import (DataItem, LabelsMixIn, DraggableMixIn, ColormapMixIn,
+ AlphaMixIn, ItemChangedType)
+
+_logger = logging.getLogger(__name__)
+
+
+def _convertImageToRgba32(image, copy=True):
+ """Convert an RGB or RGBA image to RGBA32.
+
+ It converts from floats in [0, 1], bool, integer and uint in [0, 255]
+
+ If the input image is already an RGBA32 image,
+ the returned image shares the same data.
+
+ :param image: Image to convert to
+ :type image: numpy.ndarray with 3 dimensions: height, width, color channels
+ :param bool copy: True (Default) to get a copy, False, avoid copy if possible
+ :return: The image converted to RGBA32 with dimension: (height, width, 4)
+ :rtype: numpy.ndarray of uint8
+ """
+ assert image.ndim == 3
+ assert image.shape[-1] in (3, 4)
+
+ # Convert type to uint8
+ if image.dtype.name != 'uint8':
+ if image.dtype.kind == 'f': # Float in [0, 1]
+ image = (numpy.clip(image, 0., 1.) * 255).astype(numpy.uint8)
+ elif image.dtype.kind == 'b': # boolean
+ image = image.astype(numpy.uint8) * 255
+ elif image.dtype.kind in ('i', 'u'): # int, uint
+ image = numpy.clip(image, 0, 255).astype(numpy.uint8)
+ else:
+ raise ValueError('Unsupported image dtype: %s', image.dtype.name)
+ copy = False # A copy as already been done, avoid next one
+
+ # Convert RGB to RGBA
+ if image.shape[-1] == 3:
+ new_image = numpy.empty((image.shape[0], image.shape[1], 4),
+ dtype=numpy.uint8)
+ new_image[:,:,:3] = image
+ new_image[:,:, 3] = 255
+ return new_image # This is a copy anyway
+ else:
+ return numpy.array(image, copy=copy)
+
+
+class ImageBase(DataItem, LabelsMixIn, DraggableMixIn, AlphaMixIn):
+ """Description of an image
+
+ :param numpy.ndarray data: Initial image data
+ """
+
+ def __init__(self, data=None, mask=None):
+ DataItem.__init__(self)
+ LabelsMixIn.__init__(self)
+ DraggableMixIn.__init__(self)
+ AlphaMixIn.__init__(self)
+ if data is None:
+ data = numpy.zeros((0, 0, 4), dtype=numpy.uint8)
+ self._data = data
+ self._mask = mask
+ self.__valueDataCache = None # Store default data
+ self._origin = (0., 0.)
+ self._scale = (1., 1.)
+
+ def __getitem__(self, item):
+ """Compatibility with PyMca and silx <= 0.4.0"""
+ if isinstance(item, slice):
+ return [self[index] for index in range(*item.indices(5))]
+ elif item == 0:
+ return self.getData(copy=False)
+ elif item == 1:
+ return self.getName()
+ elif item == 2:
+ info = self.getInfo(copy=False)
+ return {} if info is None else info
+ elif item == 3:
+ return None
+ elif item == 4:
+ params = {
+ 'info': self.getInfo(),
+ 'origin': self.getOrigin(),
+ 'scale': self.getScale(),
+ 'z': self.getZValue(),
+ 'selectable': self.isSelectable(),
+ 'draggable': self.isDraggable(),
+ 'colormap': None,
+ 'xlabel': self.getXLabel(),
+ 'ylabel': self.getYLabel(),
+ }
+ return params
+ else:
+ raise IndexError("Index out of range: %s" % str(item))
+
+ def _isPlotLinear(self, plot):
+ """Return True if plot only uses linear scale for both of x and y
+ axes."""
+ linear = plot.getXAxis().LINEAR
+ if plot.getXAxis().getScale() != linear:
+ return False
+ if plot.getYAxis().getScale() != linear:
+ return False
+ return True
+
+ def _getBounds(self):
+ if self.getData(copy=False).size == 0: # Empty data
+ return None
+
+ height, width = self.getData(copy=False).shape[:2]
+ origin = self.getOrigin()
+ scale = self.getScale()
+ # Taking care of scale might be < 0
+ xmin, xmax = origin[0], origin[0] + width * scale[0]
+ if xmin > xmax:
+ xmin, xmax = xmax, xmin
+ # Taking care of scale might be < 0
+ ymin, ymax = origin[1], origin[1] + height * scale[1]
+ if ymin > ymax:
+ ymin, ymax = ymax, ymin
+
+ plot = self.getPlot()
+ if plot is not None and not self._isPlotLinear(plot):
+ return None
+ else:
+ return xmin, xmax, ymin, ymax
+
+ @docstring(DraggableMixIn)
+ def drag(self, from_, to):
+ origin = self.getOrigin()
+ self.setOrigin((origin[0] + to[0] - from_[0],
+ origin[1] + to[1] - from_[1]))
+
+ def getData(self, copy=True):
+ """Returns the image data
+
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :rtype: numpy.ndarray
+ """
+ return numpy.array(self._data, copy=copy)
+
+ def setData(self, data):
+ """Set the image data
+
+ :param numpy.ndarray data:
+ """
+ previousShape = self._data.shape
+ self._data = data
+ self._valueDataChanged()
+ self._boundsChanged()
+ self._updated(ItemChangedType.DATA)
+
+ if (self.getMaskData(copy=False) is not None and
+ previousShape != self._data.shape):
+ # Data shape changed, so mask shape changes.
+ # Send event, mask is lazily updated in getMaskData
+ self._updated(ItemChangedType.MASK)
+
+ def getMaskData(self, copy=True):
+ """Returns the mask data
+
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :rtype: Union[None,numpy.ndarray]
+ """
+ if self._mask is None:
+ return None
+
+ # Update mask if it does not match data shape
+ shape = self.getData(copy=False).shape[:2]
+ if self._mask.shape != shape:
+ # Clip/extend mask to match data
+ newMask = numpy.zeros(shape, dtype=self._mask.dtype)
+ newMask[:self._mask.shape[0], :self._mask.shape[1]] = self._mask[:shape[0], :shape[1]]
+ self._mask = newMask
+
+ return numpy.array(self._mask, copy=copy)
+
+ def setMaskData(self, mask, copy=True):
+ """Set the image data
+
+ :param numpy.ndarray data:
+ :param bool copy: True (Default) to make a copy,
+ False to use as is (do not modify!)
+ """
+ if mask is not None:
+ mask = numpy.array(mask, copy=copy)
+
+ shape = self.getData(copy=False).shape[:2]
+ if mask.shape != shape:
+ _logger.warning("Inconsistent shape between mask and data %s, %s", mask.shape, shape)
+ # Clip/extent is done lazily in getMaskData
+ elif self._mask is None:
+ return # No update
+
+ self._mask = mask
+ self._valueDataChanged()
+ self._updated(ItemChangedType.MASK)
+
+ def _valueDataChanged(self):
+ """Clear cache of default data array"""
+ self.__valueDataCache = None
+
+ def _getValueData(self, copy=True):
+ """Return data used by :meth:`getValueData`
+
+ :param bool copy:
+ :rtype: numpy.ndarray
+ """
+ return self.getData(copy=copy)
+
+ def getValueData(self, copy=True):
+ """Return data (converted to int or float) with mask applied.
+
+ Masked values are set to Not-A-Number.
+ It returns a 2D array of values (int or float).
+
+ :param bool copy:
+ :rtype: numpy.ndarray
+ """
+ if self.__valueDataCache is None:
+ data = self._getValueData(copy=False)
+ mask = self.getMaskData(copy=False)
+ if mask is not None:
+ if numpy.issubdtype(data.dtype, numpy.floating):
+ dtype = data.dtype
+ else:
+ dtype = numpy.float64
+ data = numpy.array(data, dtype=dtype, copy=True)
+ data[mask != 0] = numpy.NaN
+ self.__valueDataCache = data
+ return numpy.array(self.__valueDataCache, copy=copy)
+
+ def getRgbaImageData(self, copy=True):
+ """Get the displayed RGB(A) image
+
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :returns: numpy.ndarray of uint8 of shape (height, width, 4)
+ """
+ raise NotImplementedError('This MUST be implemented in sub-class')
+
+ def getOrigin(self):
+ """Returns the offset from origin at which to display the image.
+
+ :rtype: 2-tuple of float
+ """
+ return self._origin
+
+ def setOrigin(self, origin):
+ """Set the offset from origin at which to display the image.
+
+ :param origin: (ox, oy) Offset from origin
+ :type origin: float or 2-tuple of float
+ """
+ if isinstance(origin, abc.Sequence):
+ origin = float(origin[0]), float(origin[1])
+ else: # single value origin
+ origin = float(origin), float(origin)
+ if origin != self._origin:
+ self._origin = origin
+ self._boundsChanged()
+ self._updated(ItemChangedType.POSITION)
+
+ def getScale(self):
+ """Returns the scale of the image in data coordinates.
+
+ :rtype: 2-tuple of float
+ """
+ return self._scale
+
+ def setScale(self, scale):
+ """Set the scale of the image
+
+ :param scale: (sx, sy) Scale of the image
+ :type scale: float or 2-tuple of float
+ """
+ if isinstance(scale, abc.Sequence):
+ scale = float(scale[0]), float(scale[1])
+ else: # single value scale
+ scale = float(scale), float(scale)
+
+ if scale != self._scale:
+ self._scale = scale
+ self._boundsChanged()
+ self._updated(ItemChangedType.SCALE)
+
+
+class ImageDataBase(ImageBase, ColormapMixIn):
+ """Base class for colormapped 2D data image"""
+
+ def __init__(self):
+ ImageBase.__init__(self, numpy.zeros((0, 0), dtype=numpy.float32))
+ ColormapMixIn.__init__(self)
+
+ def _getColormapForRendering(self):
+ colormap = self.getColormap()
+ if colormap.isAutoscale():
+ # Avoid backend to compute autoscale: use item cache
+ colormap = colormap.copy()
+ colormap.setVRange(*colormap.getColormapRange(self))
+ return colormap
+
+ def getRgbaImageData(self, copy=True):
+ """Get the displayed RGB(A) image
+
+ :returns: Array of uint8 of shape (height, width, 4)
+ :rtype: numpy.ndarray
+ """
+ return self.getColormap().applyToData(self)
+
+ def setData(self, data, copy=True):
+ """"Set the image data
+
+ :param numpy.ndarray data: Data array with 2 dimensions (h, w)
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ """
+ data = numpy.array(data, copy=copy)
+ assert data.ndim == 2
+ if data.dtype.kind == 'b':
+ _logger.warning(
+ 'Converting boolean image to int8 to plot it.')
+ data = numpy.array(data, copy=False, dtype=numpy.int8)
+ elif numpy.iscomplexobj(data):
+ _logger.warning(
+ 'Converting complex image to absolute value to plot it.')
+ data = numpy.absolute(data)
+ super().setData(data)
+
+ def _updated(self, event=None, checkVisibility=True):
+ # Synchronizes colormapped data if changed
+ if event in (ItemChangedType.DATA, ItemChangedType.MASK):
+ self._setColormappedData(self.getValueData(copy=False), copy=False)
+ super()._updated(event=event, checkVisibility=checkVisibility)
+
+
+class ImageData(ImageDataBase):
+ """Description of a data image with a colormap"""
+
+ def __init__(self):
+ ImageDataBase.__init__(self)
+ self._alternativeImage = None
+ self.__alpha = None
+
+ def _addBackendRenderer(self, backend):
+ """Update backend renderer"""
+ plot = self.getPlot()
+ assert plot is not None
+ if not self._isPlotLinear(plot):
+ # Do not render with non linear scales
+ return None
+
+ if (self.getAlternativeImageData(copy=False) is not None or
+ self.getAlphaData(copy=False) is not None):
+ dataToUse = self.getRgbaImageData(copy=False)
+ else:
+ dataToUse = self.getData(copy=False)
+
+ if dataToUse.size == 0:
+ return None # No data to display
+
+ return backend.addImage(dataToUse,
+ origin=self.getOrigin(),
+ scale=self.getScale(),
+ colormap=self._getColormapForRendering(),
+ alpha=self.getAlpha())
+
+ def __getitem__(self, item):
+ """Compatibility with PyMca and silx <= 0.4.0"""
+ if item == 3:
+ return self.getAlternativeImageData(copy=False)
+
+ params = ImageBase.__getitem__(self, item)
+ if item == 4:
+ params['colormap'] = self.getColormap()
+
+ return params
+
+ def getRgbaImageData(self, copy=True):
+ """Get the displayed RGB(A) image
+
+ :returns: Array of uint8 of shape (height, width, 4)
+ :rtype: numpy.ndarray
+ """
+ alternative = self.getAlternativeImageData(copy=False)
+ if alternative is not None:
+ return _convertImageToRgba32(alternative, copy=copy)
+ else:
+ image = super().getRgbaImageData(copy=copy)
+ alphaImage = self.getAlphaData(copy=False)
+ if alphaImage is not None:
+ # Apply transparency
+ image[:,:, 3] = image[:,:, 3] * alphaImage
+ return image
+
+ def getAlternativeImageData(self, copy=True):
+ """Get the optional RGBA image that is displayed instead of the data
+
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :rtype: Union[None,numpy.ndarray]
+ """
+ if self._alternativeImage is None:
+ return None
+ else:
+ return numpy.array(self._alternativeImage, copy=copy)
+
+ def getAlphaData(self, copy=True):
+ """Get the optional transparency image applied on the data
+
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :rtype: Union[None,numpy.ndarray]
+ """
+ if self.__alpha is None:
+ return None
+ else:
+ return numpy.array(self.__alpha, copy=copy)
+
+ def setData(self, data, alternative=None, alpha=None, copy=True):
+ """"Set the image data and optionally an alternative RGB(A) representation
+
+ :param numpy.ndarray data: Data array with 2 dimensions (h, w)
+ :param alternative: RGB(A) image to display instead of data,
+ shape: (h, w, 3 or 4)
+ :type alternative: Union[None,numpy.ndarray]
+ :param alpha: An array of transparency value in [0, 1] to use for
+ display with shape: (h, w)
+ :type alpha: Union[None,numpy.ndarray]
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ """
+ data = numpy.array(data, copy=copy)
+ assert data.ndim == 2
+
+ if alternative is not None:
+ alternative = numpy.array(alternative, copy=copy)
+ assert alternative.ndim == 3
+ assert alternative.shape[2] in (3, 4)
+ assert alternative.shape[:2] == data.shape[:2]
+ self._alternativeImage = alternative
+
+ if alpha is not None:
+ alpha = numpy.array(alpha, copy=copy)
+ assert alpha.shape == data.shape
+ if alpha.dtype.kind != 'f':
+ alpha = alpha.astype(numpy.float32)
+ if numpy.any(numpy.logical_or(alpha < 0., alpha > 1.)):
+ alpha = numpy.clip(alpha, 0., 1.)
+ self.__alpha = alpha
+
+ super().setData(data)
+
+
+class ImageRgba(ImageBase):
+ """Description of an RGB(A) image"""
+
+ def __init__(self):
+ ImageBase.__init__(self, numpy.zeros((0, 0, 4), dtype=numpy.uint8))
+
+ def _addBackendRenderer(self, backend):
+ """Update backend renderer"""
+ plot = self.getPlot()
+ assert plot is not None
+ if not self._isPlotLinear(plot):
+ # Do not render with non linear scales
+ return None
+
+ data = self.getData(copy=False)
+
+ if data.size == 0:
+ return None # No data to display
+
+ return backend.addImage(data,
+ origin=self.getOrigin(),
+ scale=self.getScale(),
+ colormap=None,
+ alpha=self.getAlpha())
+
+ def getRgbaImageData(self, copy=True):
+ """Get the displayed RGB(A) image
+
+ :returns: numpy.ndarray of uint8 of shape (height, width, 4)
+ """
+ return _convertImageToRgba32(self.getData(copy=False), copy=copy)
+
+ def setData(self, data, copy=True):
+ """Set the image data
+
+ :param data: RGB(A) image data to set
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ """
+ data = numpy.array(data, copy=copy)
+ assert data.ndim == 3
+ assert data.shape[-1] in (3, 4)
+ super().setData(data)
+
+ def _getValueData(self, copy=True):
+ """Compute the intensity of the RGBA image as default data.
+
+ Conversion: https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion
+
+ :param bool copy:
+ """
+ rgba = self.getRgbaImageData(copy=False).astype(numpy.float32)
+ intensity = (rgba[:, :, 0] * 0.299 +
+ rgba[:, :, 1] * 0.587 +
+ rgba[:, :, 2] * 0.114)
+ intensity *= rgba[:, :, 3] / 255.
+ return intensity
+
+
+class MaskImageData(ImageData):
+ """Description of an image used as a mask.
+
+ This class is used to flag mask items. This information is used to improve
+ internal silx widgets.
+ """
+ pass
+
+
+class ImageStack(ImageData):
+ """Item to store a stack of images and to show it in the plot as one
+ of the images of the stack.
+
+ The stack is a 3D array ordered this way: `frame id, y, x`.
+ So the first image of the stack can be reached this way: `stack[0, :, :]`
+ """
+
+ def __init__(self):
+ ImageData.__init__(self)
+ self.__stack = None
+ """A 3D numpy array (or a mimic one, see ListOfImages)"""
+ self.__stackPosition = None
+ """Displayed position in the cube"""
+
+ def setStackData(self, stack, position=None, copy=True):
+ """Set the stack data
+
+ :param stack: A 3D numpy array like
+ :param int position: The position of the displayed image in the stack
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ """
+ if self.__stack is stack:
+ return
+ if copy:
+ stack = numpy.array(stack)
+ assert stack.ndim == 3
+ self.__stack = stack
+ if position is not None:
+ self.__stackPosition = position
+ if self.__stackPosition is None:
+ self.__stackPosition = 0
+ self.__updateDisplayedData()
+
+ def getStackData(self, copy=True):
+ """Get the stored stack array.
+
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :rtype: A 3D numpy array, or numpy array like
+ """
+ if copy:
+ return numpy.array(self.__stack)
+ else:
+ return self.__stack
+
+ def setStackPosition(self, pos):
+ """Set the displayed position on the stack.
+
+ This function will clamp the stack position according to
+ the real size of the first axis of the stack.
+
+ :param int pos: A position on the first axis of the stack.
+ """
+ if self.__stackPosition == pos:
+ return
+ self.__stackPosition = pos
+ self.__updateDisplayedData()
+
+ def getStackPosition(self):
+ """Get the displayed position of the stack.
+
+ :rtype: int
+ """
+ return self.__stackPosition
+
+ def __updateDisplayedData(self):
+ """Update the displayed frame whenever the stack or the stack
+ position are updated."""
+ if self.__stack is None or self.__stackPosition is None:
+ empty = numpy.array([]).reshape(0, 0)
+ self.setData(empty, copy=False)
+ return
+ size = len(self.__stack)
+ self.__stackPosition = numpy.clip(self.__stackPosition, 0, size)
+ self.setData(self.__stack[self.__stackPosition], copy=False)
diff --git a/src/silx/gui/plot/items/image_aggregated.py b/src/silx/gui/plot/items/image_aggregated.py
new file mode 100644
index 0000000..75fdd59
--- /dev/null
+++ b/src/silx/gui/plot/items/image_aggregated.py
@@ -0,0 +1,229 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides the :class:`ImageDataAggregated` items of the :class:`Plot`.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "07/07/2021"
+
+import enum
+import logging
+from typing import Tuple, Union
+
+import numpy
+
+from ....utils.enum import Enum as _Enum
+from ....utils.proxy import docstring
+from .axis import Axis
+from .core import ItemChangedType
+from .image import ImageDataBase
+from ._pick import PickingResult
+
+
+_logger = logging.getLogger(__name__)
+
+
+class ImageDataAggregated(ImageDataBase):
+ """Item displaying an image as a density map."""
+
+ @enum.unique
+ class Aggregation(_Enum):
+ NONE = "none"
+ "Do not aggregate data, display as is (default)"
+
+ MAX = "max"
+ "Aggregates elements with max (ignore NaNs)"
+
+ MEAN = "mean"
+ "Aggregates elements with mean (ignore NaNs)"
+
+ MIN = "min"
+ "Aggregates elements with min (ignore NaNs)"
+
+ def __init__(self):
+ super().__init__()
+ self.__cacheLODData = {}
+ self.__currentLOD = 0, 0
+ self.__aggregationMode = self.Aggregation.NONE
+
+ def setAggregationMode(self, mode: Union[str,Aggregation]):
+ """Set the aggregation method used to reduce the data to screen resolution.
+
+ :param Aggregation mode: The aggregation method
+ """
+ aggregationMode = self.Aggregation.from_value(mode)
+ if aggregationMode != self.__aggregationMode:
+ self.__aggregationMode = aggregationMode
+ self.__cacheLODData = {} # Clear cache
+ self._updated(ItemChangedType.VISUALIZATION_MODE)
+
+ def getAggregationMode(self) -> Aggregation:
+ """Returns the currently used aggregation method."""
+ return self.__aggregationMode
+
+ def _addBackendRenderer(self, backend):
+ """Update backend renderer"""
+ plot = self.getPlot()
+ assert plot is not None
+ if not self._isPlotLinear(plot):
+ # Do not render with non linear scales
+ return None
+
+ data = self.getData(copy=False)
+ if data.size == 0:
+ return None # No data to display
+
+ aggregationMode = self.getAggregationMode()
+ if aggregationMode == self.Aggregation.NONE: # Pass data as it is
+ displayedData = data
+ scale = self.getScale()
+
+ else: # Aggregate data according to level of details
+ if aggregationMode == self.Aggregation.MAX:
+ aggregator = numpy.nanmax
+ elif aggregationMode == self.Aggregation.MEAN:
+ aggregator = numpy.nanmean
+ elif aggregationMode == self.Aggregation.MIN:
+ aggregator = numpy.nanmin
+ else:
+ _logger.error("Unsupported aggregation mode")
+ return None
+
+ lodx, lody = self._getLevelOfDetails()
+
+ if (lodx, lody) not in self.__cacheLODData:
+ height, width = data.shape
+ self.__cacheLODData[(lodx, lody)] = aggregator(
+ data[: (height // lody) * lody, : (width // lodx) * lodx].reshape(
+ height // lody, lody, width // lodx, lodx
+ ),
+ axis=(1, 3),
+ )
+
+ self.__currentLOD = lodx, lody
+ displayedData = self.__cacheLODData[self.__currentLOD]
+
+ sx, sy = self.getScale()
+ scale = sx * lodx, sy * lody
+
+ return backend.addImage(
+ displayedData,
+ origin=self.getOrigin(),
+ scale=scale,
+ colormap=self._getColormapForRendering(),
+ alpha=self.getAlpha(),
+ )
+
+ def _getPixelSizeInData(self, axis="left"):
+ """Returns the size of a pixel in plot data coordinates
+
+ :param str axis: Y axis to use in: 'left' (default), 'right'
+ :return:
+ Size (width, height) of a Qt pixel in data coordinates.
+ Size is None if it cannot be computed
+ :rtype: Union[List[float],None]
+ """
+ assert axis in ("left", "right")
+ plot = self.getPlot()
+ if plot is None:
+ return None
+
+ xaxis = plot.getXAxis()
+ yaxis = plot.getYAxis(axis)
+
+ if (
+ xaxis.getScale() != Axis.LINEAR
+ or yaxis.getScale() != Axis.LINEAR
+ ):
+ raise RuntimeError("Only available with linear axes")
+
+ xmin, xmax = xaxis.getLimits()
+ ymin, ymax = yaxis.getLimits()
+ width, height = plot.getPlotBoundsInPixels()[2:]
+ if width == 0 or height == 0:
+ return None
+ else:
+ return (xmax - xmin) / width, (ymax - ymin) / height
+
+ def _getLevelOfDetails(self) -> Tuple[int, int]:
+ """Return current level of details the image is displayed with."""
+ plot = self.getPlot()
+ if plot is None or not self._isPlotLinear(plot):
+ return 1, 1 # Fallback to bas LOD
+
+ sx, sy = self.getScale()
+ xUnitPerPixel, yUnitPerPixel = self._getPixelSizeInData()
+ lodx = max(1, int(numpy.ceil(xUnitPerPixel / sx)))
+ lody = max(1, int(numpy.ceil(yUnitPerPixel / sy)))
+ return lodx, lody
+
+ @docstring(ImageDataBase)
+ def setData(self, data, copy=True):
+ self.__cacheLODData = {} # Reset cache
+ super().setData(data)
+
+ @docstring(ImageDataBase)
+ def _setPlot(self, plot):
+ """Refresh image when plot limits change"""
+ previousPlot = self.getPlot()
+ if previousPlot is not None:
+ for axis in (previousPlot.getXAxis(), previousPlot.getYAxis()):
+ axis.sigLimitsChanged.disconnect(self.__plotLimitsChanged)
+
+ super()._setPlot(plot)
+
+ if plot is not None:
+ for axis in (plot.getXAxis(), plot.getYAxis()):
+ axis.sigLimitsChanged.connect(self.__plotLimitsChanged)
+
+ def __plotLimitsChanged(self):
+ """Trigger update if level of details has changed"""
+ if (self.getAggregationMode() != self.Aggregation.NONE and
+ self.__currentLOD != self._getLevelOfDetails()):
+ self._updated()
+
+ @docstring(ImageDataBase)
+ def pick(self, x, y):
+ result = super().pick(x, y)
+ if result is None:
+ return None
+
+ # Compute indices in initial data
+ plot = self.getPlot()
+ if plot is None:
+ return None
+ dataPos = plot.pixelToData(x, y, axis="left", check=True)
+ if dataPos is None:
+ return None # Outside plot area
+
+ ox, oy = self.getOrigin()
+ sx, sy = self.getScale()
+ col = int((dataPos[0] - ox) / sx)
+ row = int((dataPos[1] - oy) / sy)
+ height, width = self.getData(copy=False).shape[:2]
+ if 0 <= col < width and 0 <= row < height:
+ return PickingResult(self, ((row,), (col,)))
+ return None
diff --git a/src/silx/gui/plot/items/marker.py b/src/silx/gui/plot/items/marker.py
new file mode 100755
index 0000000..50d070c
--- /dev/null
+++ b/src/silx/gui/plot/items/marker.py
@@ -0,0 +1,281 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides markers item of the :class:`Plot`.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "06/03/2017"
+
+
+import logging
+
+from ....utils.proxy import docstring
+from .core import (Item, DraggableMixIn, ColorMixIn, LineMixIn, SymbolMixIn,
+ ItemChangedType, YAxisMixIn)
+from silx.gui import qt
+
+_logger = logging.getLogger(__name__)
+
+
+class MarkerBase(Item, DraggableMixIn, ColorMixIn, YAxisMixIn):
+ """Base class for markers"""
+
+ sigDragStarted = qt.Signal()
+ """Signal emitted when the marker is pressed"""
+ sigDragFinished = qt.Signal()
+ """Signal emitted when the marker is released"""
+
+ _DEFAULT_COLOR = (0., 0., 0., 1.)
+ """Default color of the markers"""
+
+ def __init__(self):
+ Item.__init__(self)
+ DraggableMixIn.__init__(self)
+ ColorMixIn.__init__(self)
+ YAxisMixIn.__init__(self)
+
+ self._text = ''
+ self._x = None
+ self._y = None
+ self._constraint = self._defaultConstraint
+ self.__isBeingDragged = False
+
+ def _addRendererCall(self, backend,
+ symbol=None, linestyle='-', linewidth=1):
+ """Perform the update of the backend renderer"""
+ return backend.addMarker(
+ x=self.getXPosition(),
+ y=self.getYPosition(),
+ text=self.getText(),
+ color=self.getColor(),
+ symbol=symbol,
+ linestyle=linestyle,
+ linewidth=linewidth,
+ constraint=self.getConstraint(),
+ yaxis=self.getYAxis())
+
+ def _addBackendRenderer(self, backend):
+ """Update backend renderer"""
+ raise NotImplementedError()
+
+ @docstring(DraggableMixIn)
+ def drag(self, from_, to):
+ self.setPosition(to[0], to[1])
+
+ def isOverlay(self):
+ """Returns True: A marker is always rendered as an overlay.
+
+ :rtype: bool
+ """
+ return True
+
+ def getText(self):
+ """Returns marker text.
+
+ :rtype: str
+ """
+ return self._text
+
+ def setText(self, text):
+ """Set the text of the marker.
+
+ :param str text: The text to use
+ """
+ text = str(text)
+ if text != self._text:
+ self._text = text
+ self._updated(ItemChangedType.TEXT)
+
+ def getXPosition(self):
+ """Returns the X position of the marker line in data coordinates
+
+ :rtype: float or None
+ """
+ return self._x
+
+ def getYPosition(self):
+ """Returns the Y position of the marker line in data coordinates
+
+ :rtype: float or None
+ """
+ return self._y
+
+ def getPosition(self):
+ """Returns the (x, y) position of the marker in data coordinates
+
+ :rtype: 2-tuple of float or None
+ """
+ return self._x, self._y
+
+ def setPosition(self, x, y):
+ """Set marker position in data coordinates
+
+ Constraint are applied if any.
+
+ :param float x: X coordinates in data frame
+ :param float y: Y coordinates in data frame
+ """
+ x, y = self.getConstraint()(x, y)
+ x, y = float(x), float(y)
+ if x != self._x or y != self._y:
+ self._x, self._y = x, y
+ self._updated(ItemChangedType.POSITION)
+
+ def getConstraint(self):
+ """Returns the dragging constraint of this item"""
+ return self._constraint
+
+ def _setConstraint(self, constraint): # TODO support update
+ """Set the constraint.
+
+ This is private for now as update is not handled.
+
+ :param callable constraint:
+ :param constraint: A function filtering item displacement by
+ dragging operations or None for no filter.
+ This function is called each time the item is
+ moved.
+ This is only used if isDraggable returns True.
+ :type constraint: None or a callable that takes the coordinates of
+ the current cursor position in the plot as input
+ and that returns the filtered coordinates.
+ """
+ if constraint is None:
+ constraint = self._defaultConstraint
+ assert callable(constraint)
+ self._constraint = constraint
+
+ @staticmethod
+ def _defaultConstraint(*args):
+ """Default constraint not doing anything"""
+ return args
+
+ def _startDrag(self):
+ self.__isBeingDragged = True
+ self.sigDragStarted.emit()
+
+ def _endDrag(self):
+ self.__isBeingDragged = False
+ self.sigDragFinished.emit()
+
+ def isBeingDragged(self) -> bool:
+ """Returns whether the marker is currently dragged by the user."""
+ return self.__isBeingDragged
+
+
+class Marker(MarkerBase, SymbolMixIn):
+ """Description of a marker"""
+
+ _DEFAULT_SYMBOL = '+'
+ """Default symbol of the marker"""
+
+ def __init__(self):
+ MarkerBase.__init__(self)
+ SymbolMixIn.__init__(self)
+
+ self._x = 0.
+ self._y = 0.
+
+ def _addBackendRenderer(self, backend):
+ return self._addRendererCall(backend, symbol=self.getSymbol())
+
+ def _setConstraint(self, constraint):
+ """Set the constraint function of the marker drag.
+
+ It also supports 'horizontal' and 'vertical' str as constraint.
+
+ :param constraint: The constraint of the dragging of this marker
+ :type: constraint: callable or str
+ """
+ if constraint == 'horizontal':
+ constraint = self._horizontalConstraint
+ elif constraint == 'vertical':
+ constraint = self._verticalConstraint
+
+ super(Marker, self)._setConstraint(constraint)
+
+ def _horizontalConstraint(self, _, y):
+ return self.getXPosition(), y
+
+ def _verticalConstraint(self, x, _):
+ return x, self.getYPosition()
+
+
+class _LineMarker(MarkerBase, LineMixIn):
+ """Base class for line markers"""
+
+ def __init__(self):
+ MarkerBase.__init__(self)
+ LineMixIn.__init__(self)
+
+ def _addBackendRenderer(self, backend):
+ return self._addRendererCall(backend,
+ linestyle=self.getLineStyle(),
+ linewidth=self.getLineWidth())
+
+
+class XMarker(_LineMarker):
+ """Description of a marker"""
+
+ def __init__(self):
+ _LineMarker.__init__(self)
+ self._x = 0.
+
+ def setPosition(self, x, y):
+ """Set marker line position in data coordinates
+
+ Constraint are applied if any.
+
+ :param float x: X coordinates in data frame
+ :param float y: Y coordinates in data frame
+ """
+ x, _ = self.getConstraint()(x, y)
+ x = float(x)
+ if x != self._x:
+ self._x = x
+ self._updated(ItemChangedType.POSITION)
+
+
+class YMarker(_LineMarker):
+ """Description of a marker"""
+
+ def __init__(self):
+ _LineMarker.__init__(self)
+ self._y = 0.
+
+ def setPosition(self, x, y):
+ """Set marker line position in data coordinates
+
+ Constraint are applied if any.
+
+ :param float x: X coordinates in data frame
+ :param float y: Y coordinates in data frame
+ """
+ _, y = self.getConstraint()(x, y)
+ y = float(y)
+ if y != self._y:
+ self._y = y
+ self._updated(ItemChangedType.POSITION)
diff --git a/src/silx/gui/plot/items/roi.py b/src/silx/gui/plot/items/roi.py
new file mode 100644
index 0000000..38a1424
--- /dev/null
+++ b/src/silx/gui/plot/items/roi.py
@@ -0,0 +1,1519 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides ROI item for the :class:`~silx.gui.plot.PlotWidget`.
+
+.. inheritance-diagram::
+ silx.gui.plot.items.roi
+ :parts: 1
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "28/06/2018"
+
+
+import logging
+import numpy
+
+from ... import utils
+from .. import items
+from ...colors import rgba
+from silx.image.shapes import Polygon
+from silx.image._boundingbox import _BoundingBox
+from ....utils.proxy import docstring
+from ..utils.intersections import segments_intersection
+from ._roi_base import _RegionOfInterestBase
+
+# He following imports have to be exposed by this module
+from ._roi_base import RegionOfInterest
+from ._roi_base import HandleBasedROI
+from ._arc_roi import ArcROI # noqa
+from ._roi_base import InteractionModeMixIn # noqa
+from ._roi_base import RoiInteractionMode # noqa
+
+
+logger = logging.getLogger(__name__)
+
+
+class PointROI(RegionOfInterest, items.SymbolMixIn):
+ """A ROI identifying a point in a 2D plot."""
+
+ ICON = 'add-shape-point'
+ NAME = 'point markers'
+ SHORT_NAME = "point"
+ """Metadata for this kind of ROI"""
+
+ _plotShape = "point"
+ """Plot shape which is used for the first interaction"""
+
+ _DEFAULT_SYMBOL = '+'
+ """Default symbol of the PointROI
+
+ It overwrite the `SymbolMixIn` class attribte.
+ """
+
+ def __init__(self, parent=None):
+ RegionOfInterest.__init__(self, parent=parent)
+ items.SymbolMixIn.__init__(self)
+ self._marker = items.Marker()
+ self._marker.sigItemChanged.connect(self._pointPositionChanged)
+ self._marker.setSymbol(self._DEFAULT_SYMBOL)
+ self._marker.sigDragStarted.connect(self._editingStarted)
+ self._marker.sigDragFinished.connect(self._editingFinished)
+ self.addItem(self._marker)
+
+ def setFirstShapePoints(self, points):
+ self.setPosition(points[0])
+
+ def _updated(self, event=None, checkVisibility=True):
+ if event == items.ItemChangedType.NAME:
+ label = self.getName()
+ self._marker.setText(label)
+ elif event == items.ItemChangedType.EDITABLE:
+ self._marker._setDraggable(self.isEditable())
+ elif event in [items.ItemChangedType.VISIBLE,
+ items.ItemChangedType.SELECTABLE]:
+ self._updateItemProperty(event, self, self._marker)
+ super(PointROI, self)._updated(event, checkVisibility)
+
+ def _updatedStyle(self, event, style):
+ self._marker.setColor(style.getColor())
+
+ def getPosition(self):
+ """Returns the position of this ROI
+
+ :rtype: numpy.ndarray
+ """
+ return self._marker.getPosition()
+
+ def setPosition(self, pos):
+ """Set the position of this ROI
+
+ :param numpy.ndarray pos: 2d-coordinate of this point
+ """
+ self._marker.setPosition(*pos)
+
+ @docstring(_RegionOfInterestBase)
+ def contains(self, position):
+ roiPos = self.getPosition()
+ return position[0] == roiPos[0] and position[1] == roiPos[1]
+
+ def _pointPositionChanged(self, event):
+ """Handle position changed events of the marker"""
+ if event is items.ItemChangedType.POSITION:
+ self.sigRegionChanged.emit()
+
+ def __str__(self):
+ params = '%f %f' % self.getPosition()
+ return "%s(%s)" % (self.__class__.__name__, params)
+
+
+class CrossROI(HandleBasedROI, items.LineMixIn):
+ """A ROI identifying a point in a 2D plot and displayed as a cross
+ """
+
+ ICON = 'add-shape-cross'
+ NAME = 'cross marker'
+ SHORT_NAME = "cross"
+ """Metadata for this kind of ROI"""
+
+ _plotShape = "point"
+ """Plot shape which is used for the first interaction"""
+
+ def __init__(self, parent=None):
+ HandleBasedROI.__init__(self, parent=parent)
+ items.LineMixIn.__init__(self)
+ self._handle = self.addHandle()
+ self._handle.sigItemChanged.connect(self._handlePositionChanged)
+ self._handleLabel = self.addLabelHandle()
+ self._vmarker = self.addUserHandle(items.YMarker())
+ self._vmarker._setSelectable(False)
+ self._vmarker._setDraggable(False)
+ self._vmarker.setPosition(*self.getPosition())
+ self._hmarker = self.addUserHandle(items.XMarker())
+ self._hmarker._setSelectable(False)
+ self._hmarker._setDraggable(False)
+ self._hmarker.setPosition(*self.getPosition())
+
+ def _updated(self, event=None, checkVisibility=True):
+ if event in [items.ItemChangedType.VISIBLE]:
+ markers = (self._vmarker, self._hmarker)
+ self._updateItemProperty(event, self, markers)
+ super(CrossROI, self)._updated(event, checkVisibility)
+
+ def _updateText(self, text):
+ self._handleLabel.setText(text)
+
+ def _updatedStyle(self, event, style):
+ super(CrossROI, self)._updatedStyle(event, style)
+ for marker in [self._vmarker, self._hmarker]:
+ marker.setColor(style.getColor())
+ marker.setLineStyle(style.getLineStyle())
+ marker.setLineWidth(style.getLineWidth())
+
+ def setFirstShapePoints(self, points):
+ pos = points[0]
+ self.setPosition(pos)
+
+ def getPosition(self):
+ """Returns the position of this ROI
+
+ :rtype: numpy.ndarray
+ """
+ return self._handle.getPosition()
+
+ def setPosition(self, pos):
+ """Set the position of this ROI
+
+ :param numpy.ndarray pos: 2d-coordinate of this point
+ """
+ self._handle.setPosition(*pos)
+
+ def _handlePositionChanged(self, event):
+ """Handle center marker position updates"""
+ if event is items.ItemChangedType.POSITION:
+ position = self.getPosition()
+ self._handleLabel.setPosition(*position)
+ self._vmarker.setPosition(*position)
+ self._hmarker.setPosition(*position)
+ self.sigRegionChanged.emit()
+
+ @docstring(HandleBasedROI)
+ def contains(self, position):
+ roiPos = self.getPosition()
+ return position[0] == roiPos[0] or position[1] == roiPos[1]
+
+
+class LineROI(HandleBasedROI, items.LineMixIn):
+ """A ROI identifying a line in a 2D plot.
+
+ This ROI provides 1 anchor for each boundary of the line, plus an center
+ in the center to translate the full ROI.
+ """
+
+ ICON = 'add-shape-diagonal'
+ NAME = 'line ROI'
+ SHORT_NAME = "line"
+ """Metadata for this kind of ROI"""
+
+ _plotShape = "line"
+ """Plot shape which is used for the first interaction"""
+
+ def __init__(self, parent=None):
+ HandleBasedROI.__init__(self, parent=parent)
+ items.LineMixIn.__init__(self)
+ self._handleStart = self.addHandle()
+ self._handleEnd = self.addHandle()
+ self._handleCenter = self.addTranslateHandle()
+ self._handleLabel = self.addLabelHandle()
+
+ shape = items.Shape("polylines")
+ shape.setPoints([[0, 0], [0, 0]])
+ shape.setColor(rgba(self.getColor()))
+ shape.setFill(False)
+ shape.setOverlay(True)
+ shape.setLineStyle(self.getLineStyle())
+ shape.setLineWidth(self.getLineWidth())
+ self.__shape = shape
+ self.addItem(shape)
+
+ def _updated(self, event=None, checkVisibility=True):
+ if event == items.ItemChangedType.VISIBLE:
+ self._updateItemProperty(event, self, self.__shape)
+ super(LineROI, self)._updated(event, checkVisibility)
+
+ def _updatedStyle(self, event, style):
+ super(LineROI, self)._updatedStyle(event, style)
+ self.__shape.setColor(style.getColor())
+ self.__shape.setLineStyle(style.getLineStyle())
+ self.__shape.setLineWidth(style.getLineWidth())
+
+ def setFirstShapePoints(self, points):
+ assert len(points) == 2
+ self.setEndPoints(points[0], points[1])
+
+ def _updateText(self, text):
+ self._handleLabel.setText(text)
+
+ def setEndPoints(self, startPoint, endPoint):
+ """Set this line location using the ending points
+
+ :param numpy.ndarray startPoint: Staring bounding point of the line
+ :param numpy.ndarray endPoint: Ending bounding point of the line
+ """
+ if not numpy.array_equal((startPoint, endPoint), self.getEndPoints()):
+ self.__updateEndPoints(startPoint, endPoint)
+
+ def __updateEndPoints(self, startPoint, endPoint):
+ """Update marker and shape to match given end points
+
+ :param numpy.ndarray startPoint: Staring bounding point of the line
+ :param numpy.ndarray endPoint: Ending bounding point of the line
+ """
+ startPoint = numpy.array(startPoint)
+ endPoint = numpy.array(endPoint)
+ center = (startPoint + endPoint) * 0.5
+
+ with utils.blockSignals(self._handleStart):
+ self._handleStart.setPosition(startPoint[0], startPoint[1])
+ with utils.blockSignals(self._handleEnd):
+ self._handleEnd.setPosition(endPoint[0], endPoint[1])
+ with utils.blockSignals(self._handleCenter):
+ self._handleCenter.setPosition(center[0], center[1])
+ with utils.blockSignals(self._handleLabel):
+ self._handleLabel.setPosition(center[0], center[1])
+
+ line = numpy.array((startPoint, endPoint))
+ self.__shape.setPoints(line)
+ self.sigRegionChanged.emit()
+
+ def getEndPoints(self):
+ """Returns bounding points of this ROI.
+
+ :rtype: Tuple(numpy.ndarray,numpy.ndarray)
+ """
+ startPoint = numpy.array(self._handleStart.getPosition())
+ endPoint = numpy.array(self._handleEnd.getPosition())
+ return (startPoint, endPoint)
+
+ def handleDragUpdated(self, handle, origin, previous, current):
+ if handle is self._handleStart:
+ _start, end = self.getEndPoints()
+ self.__updateEndPoints(current, end)
+ elif handle is self._handleEnd:
+ start, _end = self.getEndPoints()
+ self.__updateEndPoints(start, current)
+ elif handle is self._handleCenter:
+ start, end = self.getEndPoints()
+ delta = current - previous
+ start += delta
+ end += delta
+ self.setEndPoints(start, end)
+
+ @docstring(_RegionOfInterestBase)
+ def contains(self, position):
+ bottom_left = position[0], position[1]
+ bottom_right = position[0] + 1, position[1]
+ top_left = position[0], position[1] + 1
+ top_right = position[0] + 1, position[1] + 1
+
+ points = self.__shape.getPoints()
+ line_pt1 = points[0]
+ line_pt2 = points[1]
+
+ bb1 = _BoundingBox.from_points(points)
+ if not bb1.contains(position):
+ return False
+
+ return (
+ segments_intersection(seg1_start_pt=line_pt1, seg1_end_pt=line_pt2,
+ seg2_start_pt=bottom_left, seg2_end_pt=bottom_right) or
+ segments_intersection(seg1_start_pt=line_pt1, seg1_end_pt=line_pt2,
+ seg2_start_pt=bottom_right, seg2_end_pt=top_right) or
+ segments_intersection(seg1_start_pt=line_pt1, seg1_end_pt=line_pt2,
+ seg2_start_pt=top_right, seg2_end_pt=top_left) or
+ segments_intersection(seg1_start_pt=line_pt1, seg1_end_pt=line_pt2,
+ seg2_start_pt=top_left, seg2_end_pt=bottom_left)
+ ) is not None
+
+ def __str__(self):
+ start, end = self.getEndPoints()
+ params = start[0], start[1], end[0], end[1]
+ params = 'start: %f %f; end: %f %f' % params
+ return "%s(%s)" % (self.__class__.__name__, params)
+
+
+class HorizontalLineROI(RegionOfInterest, items.LineMixIn):
+ """A ROI identifying an horizontal line in a 2D plot."""
+
+ ICON = 'add-shape-horizontal'
+ NAME = 'horizontal line ROI'
+ SHORT_NAME = "hline"
+ """Metadata for this kind of ROI"""
+
+ _plotShape = "hline"
+ """Plot shape which is used for the first interaction"""
+
+ def __init__(self, parent=None):
+ RegionOfInterest.__init__(self, parent=parent)
+ items.LineMixIn.__init__(self)
+ self._marker = items.YMarker()
+ self._marker.sigItemChanged.connect(self._linePositionChanged)
+ self._marker.sigDragStarted.connect(self._editingStarted)
+ self._marker.sigDragFinished.connect(self._editingFinished)
+ self.addItem(self._marker)
+
+ def _updated(self, event=None, checkVisibility=True):
+ if event == items.ItemChangedType.NAME:
+ label = self.getName()
+ self._marker.setText(label)
+ elif event == items.ItemChangedType.EDITABLE:
+ self._marker._setDraggable(self.isEditable())
+ elif event in [items.ItemChangedType.VISIBLE,
+ items.ItemChangedType.SELECTABLE]:
+ self._updateItemProperty(event, self, self._marker)
+ super(HorizontalLineROI, self)._updated(event, checkVisibility)
+
+ def _updatedStyle(self, event, style):
+ self._marker.setColor(style.getColor())
+ self._marker.setLineStyle(style.getLineStyle())
+ self._marker.setLineWidth(style.getLineWidth())
+
+ def setFirstShapePoints(self, points):
+ pos = points[0, 1]
+ if pos == self.getPosition():
+ return
+ self.setPosition(pos)
+
+ def getPosition(self):
+ """Returns the position of this line if the horizontal axis
+
+ :rtype: float
+ """
+ pos = self._marker.getPosition()
+ return pos[1]
+
+ def setPosition(self, pos):
+ """Set the position of this ROI
+
+ :param float pos: Horizontal position of this line
+ """
+ self._marker.setPosition(0, pos)
+
+ @docstring(_RegionOfInterestBase)
+ def contains(self, position):
+ return position[1] == self.getPosition()
+
+ def _linePositionChanged(self, event):
+ """Handle position changed events of the marker"""
+ if event is items.ItemChangedType.POSITION:
+ self.sigRegionChanged.emit()
+
+ def __str__(self):
+ params = 'y: %f' % self.getPosition()
+ return "%s(%s)" % (self.__class__.__name__, params)
+
+
+class VerticalLineROI(RegionOfInterest, items.LineMixIn):
+ """A ROI identifying a vertical line in a 2D plot."""
+
+ ICON = 'add-shape-vertical'
+ NAME = 'vertical line ROI'
+ SHORT_NAME = "vline"
+ """Metadata for this kind of ROI"""
+
+ _plotShape = "vline"
+ """Plot shape which is used for the first interaction"""
+
+ def __init__(self, parent=None):
+ RegionOfInterest.__init__(self, parent=parent)
+ items.LineMixIn.__init__(self)
+ self._marker = items.XMarker()
+ self._marker.sigItemChanged.connect(self._linePositionChanged)
+ self._marker.sigDragStarted.connect(self._editingStarted)
+ self._marker.sigDragFinished.connect(self._editingFinished)
+ self.addItem(self._marker)
+
+ def _updated(self, event=None, checkVisibility=True):
+ if event == items.ItemChangedType.NAME:
+ label = self.getName()
+ self._marker.setText(label)
+ elif event == items.ItemChangedType.EDITABLE:
+ self._marker._setDraggable(self.isEditable())
+ elif event in [items.ItemChangedType.VISIBLE,
+ items.ItemChangedType.SELECTABLE]:
+ self._updateItemProperty(event, self, self._marker)
+ super(VerticalLineROI, self)._updated(event, checkVisibility)
+
+ def _updatedStyle(self, event, style):
+ self._marker.setColor(style.getColor())
+ self._marker.setLineStyle(style.getLineStyle())
+ self._marker.setLineWidth(style.getLineWidth())
+
+ def setFirstShapePoints(self, points):
+ pos = points[0, 0]
+ self.setPosition(pos)
+
+ def getPosition(self):
+ """Returns the position of this line if the horizontal axis
+
+ :rtype: float
+ """
+ pos = self._marker.getPosition()
+ return pos[0]
+
+ def setPosition(self, pos):
+ """Set the position of this ROI
+
+ :param float pos: Horizontal position of this line
+ """
+ self._marker.setPosition(pos, 0)
+
+ @docstring(RegionOfInterest)
+ def contains(self, position):
+ return position[0] == self.getPosition()
+
+ def _linePositionChanged(self, event):
+ """Handle position changed events of the marker"""
+ if event is items.ItemChangedType.POSITION:
+ self.sigRegionChanged.emit()
+
+ def __str__(self):
+ params = 'x: %f' % self.getPosition()
+ return "%s(%s)" % (self.__class__.__name__, params)
+
+
+class RectangleROI(HandleBasedROI, items.LineMixIn):
+ """A ROI identifying a rectangle in a 2D plot.
+
+ This ROI provides 1 anchor for each corner, plus an anchor in the
+ center to translate the full ROI.
+ """
+
+ ICON = 'add-shape-rectangle'
+ NAME = 'rectangle ROI'
+ SHORT_NAME = "rectangle"
+ """Metadata for this kind of ROI"""
+
+ _plotShape = "rectangle"
+ """Plot shape which is used for the first interaction"""
+
+ def __init__(self, parent=None):
+ HandleBasedROI.__init__(self, parent=parent)
+ items.LineMixIn.__init__(self)
+ self._handleTopLeft = self.addHandle()
+ self._handleTopRight = self.addHandle()
+ self._handleBottomLeft = self.addHandle()
+ self._handleBottomRight = self.addHandle()
+ self._handleCenter = self.addTranslateHandle()
+ self._handleLabel = self.addLabelHandle()
+
+ shape = items.Shape("rectangle")
+ shape.setPoints([[0, 0], [0, 0]])
+ shape.setFill(False)
+ shape.setOverlay(True)
+ shape.setLineStyle(self.getLineStyle())
+ shape.setLineWidth(self.getLineWidth())
+ shape.setColor(rgba(self.getColor()))
+ self.__shape = shape
+ self.addItem(shape)
+
+ def _updated(self, event=None, checkVisibility=True):
+ if event in [items.ItemChangedType.VISIBLE]:
+ self._updateItemProperty(event, self, self.__shape)
+ super(RectangleROI, self)._updated(event, checkVisibility)
+
+ def _updatedStyle(self, event, style):
+ super(RectangleROI, self)._updatedStyle(event, style)
+ self.__shape.setColor(style.getColor())
+ self.__shape.setLineStyle(style.getLineStyle())
+ self.__shape.setLineWidth(style.getLineWidth())
+
+ def setFirstShapePoints(self, points):
+ assert len(points) == 2
+ self._setBound(points)
+
+ def _setBound(self, points):
+ """Initialize the rectangle from a bunch of points"""
+ top = max(points[:, 1])
+ bottom = min(points[:, 1])
+ left = min(points[:, 0])
+ right = max(points[:, 0])
+ size = right - left, top - bottom
+ self._updateGeometry(origin=(left, bottom), size=size)
+
+ def _updateText(self, text):
+ self._handleLabel.setText(text)
+
+ def getCenter(self):
+ """Returns the central point of this rectangle
+
+ :rtype: numpy.ndarray([float,float])
+ """
+ pos = self._handleCenter.getPosition()
+ return numpy.array(pos)
+
+ def getOrigin(self):
+ """Returns the corner point with the smaller coordinates
+
+ :rtype: numpy.ndarray([float,float])
+ """
+ pos = self._handleBottomLeft.getPosition()
+ return numpy.array(pos)
+
+ def getSize(self):
+ """Returns the size of this rectangle
+
+ :rtype: numpy.ndarray([float,float])
+ """
+ vmin = self._handleBottomLeft.getPosition()
+ vmax = self._handleTopRight.getPosition()
+ vmin, vmax = numpy.array(vmin), numpy.array(vmax)
+ return vmax - vmin
+
+ def setOrigin(self, position):
+ """Set the origin position of this ROI
+
+ :param numpy.ndarray position: Location of the smaller corner of the ROI
+ """
+ size = self.getSize()
+ self.setGeometry(origin=position, size=size)
+
+ def setSize(self, size):
+ """Set the size of this ROI
+
+ :param numpy.ndarray size: Size of the center of the ROI
+ """
+ origin = self.getOrigin()
+ self.setGeometry(origin=origin, size=size)
+
+ def setCenter(self, position):
+ """Set the size of this ROI
+
+ :param numpy.ndarray position: Location of the center of the ROI
+ """
+ size = self.getSize()
+ self.setGeometry(center=position, size=size)
+
+ def setGeometry(self, origin=None, size=None, center=None):
+ """Set the geometry of the ROI
+ """
+ if ((origin is None or numpy.array_equal(origin, self.getOrigin())) and
+ (center is None or numpy.array_equal(center, self.getCenter())) and
+ numpy.array_equal(size, self.getSize())):
+ return # Nothing has changed
+
+ self._updateGeometry(origin, size, center)
+
+ def _updateGeometry(self, origin=None, size=None, center=None):
+ """Forced update of the geometry of the ROI"""
+ if origin is not None:
+ origin = numpy.array(origin)
+ size = numpy.array(size)
+ points = numpy.array([origin, origin + size])
+ center = origin + size * 0.5
+ elif center is not None:
+ center = numpy.array(center)
+ size = numpy.array(size)
+ points = numpy.array([center - size * 0.5, center + size * 0.5])
+ else:
+ raise ValueError("Origin or center expected")
+
+ with utils.blockSignals(self._handleBottomLeft):
+ self._handleBottomLeft.setPosition(points[0, 0], points[0, 1])
+ with utils.blockSignals(self._handleBottomRight):
+ self._handleBottomRight.setPosition(points[1, 0], points[0, 1])
+ with utils.blockSignals(self._handleTopLeft):
+ self._handleTopLeft.setPosition(points[0, 0], points[1, 1])
+ with utils.blockSignals(self._handleTopRight):
+ self._handleTopRight.setPosition(points[1, 0], points[1, 1])
+ with utils.blockSignals(self._handleCenter):
+ self._handleCenter.setPosition(center[0], center[1])
+ with utils.blockSignals(self._handleLabel):
+ self._handleLabel.setPosition(points[0, 0], points[0, 1])
+
+ self.__shape.setPoints(points)
+ self.sigRegionChanged.emit()
+
+ @docstring(HandleBasedROI)
+ def contains(self, position):
+ assert isinstance(position, (tuple, list, numpy.array))
+ points = self.__shape.getPoints()
+ bb1 = _BoundingBox.from_points(points)
+ return bb1.contains(position)
+
+ def handleDragUpdated(self, handle, origin, previous, current):
+ if handle is self._handleCenter:
+ # It is the center anchor
+ size = self.getSize()
+ self._updateGeometry(center=current, size=size)
+ else:
+ opposed = {
+ self._handleBottomLeft: self._handleTopRight,
+ self._handleTopRight: self._handleBottomLeft,
+ self._handleBottomRight: self._handleTopLeft,
+ self._handleTopLeft: self._handleBottomRight,
+ }
+ handle2 = opposed[handle]
+ current2 = handle2.getPosition()
+ points = numpy.array([current, current2])
+
+ # Switch handles if they were crossed by interaction
+ if self._handleBottomLeft.getXPosition() > self._handleBottomRight.getXPosition():
+ self._handleBottomLeft, self._handleBottomRight = self._handleBottomRight, self._handleBottomLeft
+
+ if self._handleTopLeft.getXPosition() > self._handleTopRight.getXPosition():
+ self._handleTopLeft, self._handleTopRight = self._handleTopRight, self._handleTopLeft
+
+ if self._handleBottomLeft.getYPosition() > self._handleTopLeft.getYPosition():
+ self._handleBottomLeft, self._handleTopLeft = self._handleTopLeft, self._handleBottomLeft
+
+ if self._handleBottomRight.getYPosition() > self._handleTopRight.getYPosition():
+ self._handleBottomRight, self._handleTopRight = self._handleTopRight, self._handleBottomRight
+
+ self._setBound(points)
+
+ def __str__(self):
+ origin = self.getOrigin()
+ w, h = self.getSize()
+ params = origin[0], origin[1], w, h
+ params = 'origin: %f %f; width: %f; height: %f' % params
+ return "%s(%s)" % (self.__class__.__name__, params)
+
+
+class CircleROI(HandleBasedROI, items.LineMixIn):
+ """A ROI identifying a circle in a 2D plot.
+
+ This ROI provides 1 anchor at the center to translate the circle,
+ and one anchor on the perimeter to change the radius.
+ """
+
+ ICON = 'add-shape-circle'
+ NAME = 'circle ROI'
+ SHORT_NAME = "circle"
+ """Metadata for this kind of ROI"""
+
+ _kind = "Circle"
+ """Label for this kind of ROI"""
+
+ _plotShape = "line"
+ """Plot shape which is used for the first interaction"""
+
+ def __init__(self, parent=None):
+ items.LineMixIn.__init__(self)
+ HandleBasedROI.__init__(self, parent=parent)
+ self._handlePerimeter = self.addHandle()
+ self._handleCenter = self.addTranslateHandle()
+ self._handleCenter.sigItemChanged.connect(self._centerPositionChanged)
+ self._handleLabel = self.addLabelHandle()
+
+ shape = items.Shape("polygon")
+ shape.setPoints([[0, 0], [0, 0]])
+ shape.setColor(rgba(self.getColor()))
+ shape.setFill(False)
+ shape.setOverlay(True)
+ shape.setLineStyle(self.getLineStyle())
+ shape.setLineWidth(self.getLineWidth())
+ self.__shape = shape
+ self.addItem(shape)
+
+ self.__radius = 0
+
+ def _updated(self, event=None, checkVisibility=True):
+ if event == items.ItemChangedType.VISIBLE:
+ self._updateItemProperty(event, self, self.__shape)
+ super(CircleROI, self)._updated(event, checkVisibility)
+
+ def _updatedStyle(self, event, style):
+ super(CircleROI, self)._updatedStyle(event, style)
+ self.__shape.setColor(style.getColor())
+ self.__shape.setLineStyle(style.getLineStyle())
+ self.__shape.setLineWidth(style.getLineWidth())
+
+ def setFirstShapePoints(self, points):
+ assert len(points) == 2
+ self._setRay(points)
+
+ def _setRay(self, points):
+ """Initialize the circle from the center point and a
+ perimeter point."""
+ center = points[0]
+ radius = numpy.linalg.norm(points[0] - points[1])
+ self.setGeometry(center=center, radius=radius)
+
+ def _updateText(self, text):
+ self._handleLabel.setText(text)
+
+ def getCenter(self):
+ """Returns the central point of this rectangle
+
+ :rtype: numpy.ndarray([float,float])
+ """
+ pos = self._handleCenter.getPosition()
+ return numpy.array(pos)
+
+ def getRadius(self):
+ """Returns the radius of this circle
+
+ :rtype: float
+ """
+ return self.__radius
+
+ def setCenter(self, position):
+ """Set the center point of this ROI
+
+ :param numpy.ndarray position: Location of the center of the circle
+ """
+ self._handleCenter.setPosition(*position)
+
+ def setRadius(self, radius):
+ """Set the size of this ROI
+
+ :param float size: Radius of the circle
+ """
+ radius = float(radius)
+ if radius != self.__radius:
+ self.__radius = radius
+ self._updateGeometry()
+
+ def setGeometry(self, center, radius):
+ """Set the geometry of the ROI
+ """
+ if numpy.array_equal(center, self.getCenter()):
+ self.setRadius(radius)
+ else:
+ self.__radius = float(radius) # Update radius directly
+ self.setCenter(center) # Calls _updateGeometry
+
+ def _updateGeometry(self):
+ """Update the handles and shape according to given parameters"""
+ center = self.getCenter()
+ perimeter_point = numpy.array([center[0] + self.__radius, center[1]])
+
+ self._handlePerimeter.setPosition(perimeter_point[0], perimeter_point[1])
+ self._handleLabel.setPosition(center[0], center[1])
+
+ nbpoints = 27
+ angles = numpy.arange(nbpoints) * 2.0 * numpy.pi / nbpoints
+ circleShape = numpy.array((numpy.cos(angles) * self.__radius,
+ numpy.sin(angles) * self.__radius)).T
+ circleShape += center
+ self.__shape.setPoints(circleShape)
+ self.sigRegionChanged.emit()
+
+ def _centerPositionChanged(self, event):
+ """Handle position changed events of the center marker"""
+ if event is items.ItemChangedType.POSITION:
+ self._updateGeometry()
+
+ def handleDragUpdated(self, handle, origin, previous, current):
+ if handle is self._handlePerimeter:
+ center = self.getCenter()
+ self.setRadius(numpy.linalg.norm(center - current))
+
+ @docstring(HandleBasedROI)
+ def contains(self, position):
+ return numpy.linalg.norm(self.getCenter() - position) <= self.getRadius()
+
+ def __str__(self):
+ center = self.getCenter()
+ radius = self.getRadius()
+ params = center[0], center[1], radius
+ params = 'center: %f %f; radius: %f;' % params
+ return "%s(%s)" % (self.__class__.__name__, params)
+
+
+class EllipseROI(HandleBasedROI, items.LineMixIn):
+ """A ROI identifying an oriented ellipse in a 2D plot.
+
+ This ROI provides 1 anchor at the center to translate the circle,
+ and two anchors on the perimeter to modify the major-radius and
+ minor-radius. These two anchors also allow to change the orientation.
+ """
+
+ ICON = 'add-shape-ellipse'
+ NAME = 'ellipse ROI'
+ SHORT_NAME = "ellipse"
+ """Metadata for this kind of ROI"""
+
+ _plotShape = "line"
+ """Plot shape which is used for the first interaction"""
+
+ def __init__(self, parent=None):
+ items.LineMixIn.__init__(self)
+ HandleBasedROI.__init__(self, parent=parent)
+ self._handleAxis0 = self.addHandle()
+ self._handleAxis1 = self.addHandle()
+ self._handleCenter = self.addTranslateHandle()
+ self._handleCenter.sigItemChanged.connect(self._centerPositionChanged)
+ self._handleLabel = self.addLabelHandle()
+
+ shape = items.Shape("polygon")
+ shape.setPoints([[0, 0], [0, 0]])
+ shape.setColor(rgba(self.getColor()))
+ shape.setFill(False)
+ shape.setOverlay(True)
+ shape.setLineStyle(self.getLineStyle())
+ shape.setLineWidth(self.getLineWidth())
+ self.__shape = shape
+ self.addItem(shape)
+
+ self._radius = 0., 0.
+ self._orientation = 0. # angle in radians between the X-axis and the _handleAxis0
+
+ def _updated(self, event=None, checkVisibility=True):
+ if event == items.ItemChangedType.VISIBLE:
+ self._updateItemProperty(event, self, self.__shape)
+ super(EllipseROI, self)._updated(event, checkVisibility)
+
+ def _updatedStyle(self, event, style):
+ super(EllipseROI, self)._updatedStyle(event, style)
+ self.__shape.setColor(style.getColor())
+ self.__shape.setLineStyle(style.getLineStyle())
+ self.__shape.setLineWidth(style.getLineWidth())
+
+ def setFirstShapePoints(self, points):
+ assert len(points) == 2
+ self._setRay(points)
+
+ @staticmethod
+ def _calculateOrientation(p0, p1):
+ """return angle in radians between the vector p0-p1
+ and the X axis
+
+ :param p0: first point coordinates (x, y)
+ :param p1: second point coordinates
+ :return:
+ """
+ vector = (p1[0] - p0[0], p1[1] - p0[1])
+ x_unit_vector = (1, 0)
+ norm = numpy.linalg.norm(vector)
+ if norm != 0:
+ theta = numpy.arccos(numpy.dot(vector, x_unit_vector) / norm)
+ else:
+ theta = 0
+ if vector[1] < 0:
+ # arccos always returns values in range [0, pi]
+ theta = 2 * numpy.pi - theta
+ return theta
+
+ def _setRay(self, points):
+ """Initialize the circle from the center point and a
+ perimeter point."""
+ center = points[0]
+ radius = numpy.linalg.norm(points[0] - points[1])
+ orientation = self._calculateOrientation(points[0], points[1])
+ self.setGeometry(center=center,
+ radius=(radius, radius),
+ orientation=orientation)
+
+ def _updateText(self, text):
+ self._handleLabel.setText(text)
+
+ def getCenter(self):
+ """Returns the central point of this rectangle
+
+ :rtype: numpy.ndarray([float,float])
+ """
+ pos = self._handleCenter.getPosition()
+ return numpy.array(pos)
+
+ def getMajorRadius(self):
+ """Returns the half-diameter of the major axis.
+
+ :rtype: float
+ """
+ return max(self._radius)
+
+ def getMinorRadius(self):
+ """Returns the half-diameter of the minor axis.
+
+ :rtype: float
+ """
+ return min(self._radius)
+
+ def getOrientation(self):
+ """Return angle in radians between the horizontal (X) axis
+ and the major axis of the ellipse in [0, 2*pi[
+
+ :rtype: float:
+ """
+ return self._orientation
+
+ def setCenter(self, center):
+ """Set the center point of this ROI
+
+ :param numpy.ndarray position: Coordinates (X, Y) of the center
+ of the ellipse
+ """
+ self._handleCenter.setPosition(*center)
+
+ def setMajorRadius(self, radius):
+ """Set the half-diameter of the major axis of the ellipse.
+
+ :param float radius:
+ Major radius of the ellipsis. Must be a positive value.
+ """
+ if self._radius[0] > self._radius[1]:
+ newRadius = radius, self._radius[1]
+ else:
+ newRadius = self._radius[0], radius
+ self.setGeometry(radius=newRadius)
+
+ def setMinorRadius(self, radius):
+ """Set the half-diameter of the minor axis of the ellipse.
+
+ :param float radius:
+ Minor radius of the ellipsis. Must be a positive value.
+ """
+ if self._radius[0] > self._radius[1]:
+ newRadius = self._radius[0], radius
+ else:
+ newRadius = radius, self._radius[1]
+ self.setGeometry(radius=newRadius)
+
+ def setOrientation(self, orientation):
+ """Rotate the ellipse
+
+ :param float orientation: Angle in radians between the horizontal and
+ the major axis.
+ :return:
+ """
+ self.setGeometry(orientation=orientation)
+
+ def setGeometry(self, center=None, radius=None, orientation=None):
+ """
+
+ :param center: (X, Y) coordinates
+ :param float majorRadius:
+ :param float minorRadius:
+ :param float orientation: angle in radians between the major axis and the
+ horizontal
+ :return:
+ """
+ if center is None:
+ center = self.getCenter()
+
+ if radius is None:
+ radius = self._radius
+ else:
+ radius = float(radius[0]), float(radius[1])
+
+ if orientation is None:
+ orientation = self._orientation
+ else:
+ # ensure that we store the orientation in range [0, 2*pi
+ orientation = numpy.mod(orientation, 2 * numpy.pi)
+
+ if (numpy.array_equal(center, self.getCenter()) or
+ radius != self._radius or
+ orientation != self._orientation):
+
+ # Update parameters directly
+ self._radius = radius
+ self._orientation = orientation
+
+ if numpy.array_equal(center, self.getCenter()):
+ self._updateGeometry()
+ else:
+ # This will call _updateGeometry
+ self.setCenter(center)
+
+ def _updateGeometry(self):
+ """Update shape and markers"""
+ center = self.getCenter()
+
+ orientation = self.getOrientation()
+ if self._radius[1] > self._radius[0]:
+ # _handleAxis1 is the major axis
+ orientation -= numpy.pi / 2
+
+ point0 = numpy.array([center[0] + self._radius[0] * numpy.cos(orientation),
+ center[1] + self._radius[0] * numpy.sin(orientation)])
+ point1 = numpy.array([center[0] - self._radius[1] * numpy.sin(orientation),
+ center[1] + self._radius[1] * numpy.cos(orientation)])
+ with utils.blockSignals(self._handleAxis0):
+ self._handleAxis0.setPosition(*point0)
+ with utils.blockSignals(self._handleAxis1):
+ self._handleAxis1.setPosition(*point1)
+ with utils.blockSignals(self._handleLabel):
+ self._handleLabel.setPosition(*center)
+
+ nbpoints = 27
+ angles = numpy.arange(nbpoints) * 2.0 * numpy.pi / nbpoints
+ X = (self._radius[0] * numpy.cos(angles) * numpy.cos(orientation)
+ - self._radius[1] * numpy.sin(angles) * numpy.sin(orientation))
+ Y = (self._radius[0] * numpy.cos(angles) * numpy.sin(orientation)
+ + self._radius[1] * numpy.sin(angles) * numpy.cos(orientation))
+
+ ellipseShape = numpy.array((X, Y)).T
+ ellipseShape += center
+ self.__shape.setPoints(ellipseShape)
+ self.sigRegionChanged.emit()
+
+ def handleDragUpdated(self, handle, origin, previous, current):
+ if handle in (self._handleAxis0, self._handleAxis1):
+ center = self.getCenter()
+ orientation = self._calculateOrientation(center, current)
+ distance = numpy.linalg.norm(center - current)
+
+ if handle is self._handleAxis1:
+ if self._radius[0] > distance:
+ # _handleAxis1 is not the major axis, rotate -90 degrees
+ orientation -= numpy.pi / 2
+ radius = self._radius[0], distance
+
+ else: # _handleAxis0
+ if self._radius[1] > distance:
+ # _handleAxis0 is not the major axis, rotate +90 degrees
+ orientation += numpy.pi / 2
+ radius = distance, self._radius[1]
+
+ self.setGeometry(radius=radius, orientation=orientation)
+
+ def _centerPositionChanged(self, event):
+ """Handle position changed events of the center marker"""
+ if event is items.ItemChangedType.POSITION:
+ self._updateGeometry()
+
+ @docstring(HandleBasedROI)
+ def contains(self, position):
+ major, minor = self.getMajorRadius(), self.getMinorRadius()
+ delta = self.getOrientation()
+ x, y = position - self.getCenter()
+ return ((x*numpy.cos(delta) + y*numpy.sin(delta))**2/major**2 +
+ (x*numpy.sin(delta) - y*numpy.cos(delta))**2/minor**2) <= 1
+
+ def __str__(self):
+ center = self.getCenter()
+ major = self.getMajorRadius()
+ minor = self.getMinorRadius()
+ orientation = self.getOrientation()
+ params = center[0], center[1], major, minor, orientation
+ params = 'center: %f %f; major radius: %f: minor radius: %f; orientation: %f' % params
+ return "%s(%s)" % (self.__class__.__name__, params)
+
+
+class PolygonROI(HandleBasedROI, items.LineMixIn):
+ """A ROI identifying a closed polygon in a 2D plot.
+
+ This ROI provides 1 anchor for each point of the polygon.
+ """
+
+ ICON = 'add-shape-polygon'
+ NAME = 'polygon ROI'
+ SHORT_NAME = "polygon"
+ """Metadata for this kind of ROI"""
+
+ _plotShape = "polygon"
+ """Plot shape which is used for the first interaction"""
+
+ def __init__(self, parent=None):
+ HandleBasedROI.__init__(self, parent=parent)
+ items.LineMixIn.__init__(self)
+ self._handleLabel = self.addLabelHandle()
+ self._handleCenter = self.addTranslateHandle()
+ self._handlePoints = []
+ self._points = numpy.empty((0, 2))
+ self._handleClose = None
+
+ self._polygon_shape = None
+ shape = self.__createShape()
+ self.__shape = shape
+ self.addItem(shape)
+
+ def _updated(self, event=None, checkVisibility=True):
+ if event in [items.ItemChangedType.VISIBLE]:
+ self._updateItemProperty(event, self, self.__shape)
+ super(PolygonROI, self)._updated(event, checkVisibility)
+
+ def _updatedStyle(self, event, style):
+ super(PolygonROI, self)._updatedStyle(event, style)
+ self.__shape.setColor(style.getColor())
+ self.__shape.setLineStyle(style.getLineStyle())
+ self.__shape.setLineWidth(style.getLineWidth())
+ if self._handleClose is not None:
+ color = self._computeHandleColor(style.getColor())
+ self._handleClose.setColor(color)
+
+ def __createShape(self, interaction=False):
+ kind = "polygon" if not interaction else "polylines"
+ shape = items.Shape(kind)
+ shape.setPoints([[0, 0], [0, 0]])
+ shape.setFill(False)
+ shape.setOverlay(True)
+ style = self.getCurrentStyle()
+ shape.setLineStyle(style.getLineStyle())
+ shape.setLineWidth(style.getLineWidth())
+ shape.setColor(rgba(style.getColor()))
+ return shape
+
+ def setFirstShapePoints(self, points):
+ if self._handleClose is not None:
+ self._handleClose.setPosition(*points[0])
+ self.setPoints(points)
+
+ def creationStarted(self):
+ """"Called when the ROI creation interaction was started.
+ """
+ # Handle to see where to close the polygon
+ self._handleClose = self.addUserHandle()
+ self._handleClose.setSymbol("o")
+ color = self._computeHandleColor(rgba(self.getColor()))
+ self._handleClose.setColor(color)
+
+ # Hide the center while creating the first shape
+ self._handleCenter.setSymbol("")
+
+ # In interaction replace the polygon by a line, to display something unclosed
+ self.removeItem(self.__shape)
+ self.__shape = self.__createShape(interaction=True)
+ self.__shape.setPoints(self._points)
+ self.addItem(self.__shape)
+
+ def isBeingCreated(self):
+ """Returns true if the ROI is in creation step"""
+ return self._handleClose is not None
+
+ def creationFinalized(self):
+ """"Called when the ROI creation interaction was finalized.
+ """
+ self.removeHandle(self._handleClose)
+ self._handleClose = None
+ self.removeItem(self.__shape)
+ self.__shape = self.__createShape()
+ self.__shape.setPoints(self._points)
+ self.addItem(self.__shape)
+ # Hide the center while creating the first shape
+ self._handleCenter.setSymbol("+")
+ for handle in self._handlePoints:
+ handle.setSymbol("s")
+
+ def _updateText(self, text):
+ self._handleLabel.setText(text)
+
+ def getPoints(self):
+ """Returns the list of the points of this polygon.
+
+ :rtype: numpy.ndarray
+ """
+ return self._points.copy()
+
+ def setPoints(self, points):
+ """Set the position of this ROI
+
+ :param numpy.ndarray pos: 2d-coordinate of this point
+ """
+ assert(len(points.shape) == 2 and points.shape[1] == 2)
+
+ if numpy.array_equal(points, self._points):
+ return # Nothing has changed
+
+ self._polygon_shape = None
+
+ # Update the needed handles
+ while len(self._handlePoints) != len(points):
+ if len(self._handlePoints) < len(points):
+ handle = self.addHandle()
+ self._handlePoints.append(handle)
+ if self.isBeingCreated():
+ handle.setSymbol("")
+ else:
+ handle = self._handlePoints.pop(-1)
+ self.removeHandle(handle)
+
+ for handle, position in zip(self._handlePoints, points):
+ with utils.blockSignals(handle):
+ handle.setPosition(position[0], position[1])
+
+ if len(points) > 0:
+ if not self.isHandleBeingDragged():
+ vmin = numpy.min(points, axis=0)
+ vmax = numpy.max(points, axis=0)
+ center = (vmax + vmin) * 0.5
+ with utils.blockSignals(self._handleCenter):
+ self._handleCenter.setPosition(center[0], center[1])
+
+ num = numpy.argmin(points[:, 1])
+ pos = points[num]
+ with utils.blockSignals(self._handleLabel):
+ self._handleLabel.setPosition(pos[0], pos[1])
+
+ if len(points) == 0:
+ self._points = numpy.empty((0, 2))
+ else:
+ self._points = points
+ self.__shape.setPoints(self._points)
+ self.sigRegionChanged.emit()
+
+ def translate(self, x, y):
+ points = self.getPoints()
+ delta = numpy.array([x, y])
+ self.setPoints(points)
+ self.setPoints(points + delta)
+
+ def handleDragUpdated(self, handle, origin, previous, current):
+ if handle is self._handleCenter:
+ delta = current - previous
+ self.translate(delta[0], delta[1])
+ else:
+ points = self.getPoints()
+ num = self._handlePoints.index(handle)
+ points[num] = current
+ self.setPoints(points)
+
+ def handleDragFinished(self, handle, origin, current):
+ points = self._points
+ if len(points) > 0:
+ # Only update the center at the end
+ # To avoid to disturb the interaction
+ vmin = numpy.min(points, axis=0)
+ vmax = numpy.max(points, axis=0)
+ center = (vmax + vmin) * 0.5
+ with utils.blockSignals(self._handleCenter):
+ self._handleCenter.setPosition(center[0], center[1])
+
+ def __str__(self):
+ points = self._points
+ params = '; '.join('%f %f' % (pt[0], pt[1]) for pt in points)
+ return "%s(%s)" % (self.__class__.__name__, params)
+
+ @docstring(HandleBasedROI)
+ def contains(self, position):
+ bb1 = _BoundingBox.from_points(self.getPoints())
+ if bb1.contains(position) is False:
+ return False
+
+ if self._polygon_shape is None:
+ self._polygon_shape = Polygon(vertices=self.getPoints())
+
+ # warning: both the polygon and the value are inverted
+ return self._polygon_shape.is_inside(row=position[0], col=position[1])
+
+ def _setControlPoints(self, points):
+ RegionOfInterest._setControlPoints(self, points=points)
+ self._polygon_shape = None
+
+
+class HorizontalRangeROI(RegionOfInterest, items.LineMixIn):
+ """A ROI identifying an horizontal range in a 1D plot."""
+
+ ICON = 'add-range-horizontal'
+ NAME = 'horizontal range ROI'
+ SHORT_NAME = "hrange"
+
+ _plotShape = "line"
+ """Plot shape which is used for the first interaction"""
+
+ def __init__(self, parent=None):
+ RegionOfInterest.__init__(self, parent=parent)
+ items.LineMixIn.__init__(self)
+ self._markerMin = items.XMarker()
+ self._markerMax = items.XMarker()
+ self._markerCen = items.XMarker()
+ self._markerCen.setLineStyle(" ")
+ self._markerMin._setConstraint(self.__positionMinConstraint)
+ self._markerMax._setConstraint(self.__positionMaxConstraint)
+ self._markerMin.sigDragStarted.connect(self._editingStarted)
+ self._markerMin.sigDragFinished.connect(self._editingFinished)
+ self._markerMax.sigDragStarted.connect(self._editingStarted)
+ self._markerMax.sigDragFinished.connect(self._editingFinished)
+ self._markerCen.sigDragStarted.connect(self._editingStarted)
+ self._markerCen.sigDragFinished.connect(self._editingFinished)
+ self.addItem(self._markerCen)
+ self.addItem(self._markerMin)
+ self.addItem(self._markerMax)
+ self.__filterReentrant = utils.LockReentrant()
+
+ def setFirstShapePoints(self, points):
+ vmin = min(points[:, 0])
+ vmax = max(points[:, 0])
+ self._updatePos(vmin, vmax)
+
+ def _updated(self, event=None, checkVisibility=True):
+ if event == items.ItemChangedType.NAME:
+ self._updateText()
+ elif event == items.ItemChangedType.EDITABLE:
+ self._updateEditable()
+ self._updateText()
+ elif event == items.ItemChangedType.LINE_STYLE:
+ markers = [self._markerMin, self._markerMax]
+ self._updateItemProperty(event, self, markers)
+ elif event in [items.ItemChangedType.VISIBLE,
+ items.ItemChangedType.SELECTABLE]:
+ markers = [self._markerMin, self._markerMax, self._markerCen]
+ self._updateItemProperty(event, self, markers)
+ super(HorizontalRangeROI, self)._updated(event, checkVisibility)
+
+ def _updatedStyle(self, event, style):
+ markers = [self._markerMin, self._markerMax, self._markerCen]
+ for m in markers:
+ m.setColor(style.getColor())
+ m.setLineWidth(style.getLineWidth())
+
+ def _updateText(self):
+ text = self.getName()
+ if self.isEditable():
+ self._markerMin.setText("")
+ self._markerCen.setText(text)
+ else:
+ self._markerMin.setText(text)
+ self._markerCen.setText("")
+
+ def _updateEditable(self):
+ editable = self.isEditable()
+ self._markerMin._setDraggable(editable)
+ self._markerMax._setDraggable(editable)
+ self._markerCen._setDraggable(editable)
+ if self.isEditable():
+ self._markerMin.sigItemChanged.connect(self._minPositionChanged)
+ self._markerMax.sigItemChanged.connect(self._maxPositionChanged)
+ self._markerCen.sigItemChanged.connect(self._cenPositionChanged)
+ self._markerCen.setLineStyle(":")
+ else:
+ self._markerMin.sigItemChanged.disconnect(self._minPositionChanged)
+ self._markerMax.sigItemChanged.disconnect(self._maxPositionChanged)
+ self._markerCen.sigItemChanged.disconnect(self._cenPositionChanged)
+ self._markerCen.setLineStyle(" ")
+
+ def _updatePos(self, vmin, vmax, force=False):
+ """Update marker position and emit signal.
+
+ :param float vmin:
+ :param float vmax:
+ :param bool force:
+ True to update even if already at the right position.
+ """
+ if not force and numpy.array_equal((vmin, vmax), self.getRange()):
+ return # Nothing has changed
+
+ center = (vmin + vmax) * 0.5
+ with self.__filterReentrant:
+ with utils.blockSignals(self._markerMin):
+ self._markerMin.setPosition(vmin, 0)
+ with utils.blockSignals(self._markerCen):
+ self._markerCen.setPosition(center, 0)
+ with utils.blockSignals(self._markerMax):
+ self._markerMax.setPosition(vmax, 0)
+ self.sigRegionChanged.emit()
+
+ def setRange(self, vmin, vmax):
+ """Set the range of this ROI.
+
+ :param float vmin: Staring location of the range
+ :param float vmax: Ending location of the range
+ """
+ if vmin is None or vmax is None:
+ err = "Can't set vmin or vmax to None"
+ raise ValueError(err)
+ if vmin > vmax:
+ err = "Can't set vmin and vmax because vmin >= vmax " \
+ "vmin = %s, vmax = %s" % (vmin, vmax)
+ raise ValueError(err)
+ self._updatePos(vmin, vmax)
+
+ def getRange(self):
+ """Returns the range of this ROI.
+
+ :rtype: Tuple[float,float]
+ """
+ vmin = self.getMin()
+ vmax = self.getMax()
+ return vmin, vmax
+
+ def setMin(self, vmin):
+ """Set the min of this ROI.
+
+ :param float vmin: New min
+ """
+ vmax = self.getMax()
+ self._updatePos(vmin, vmax)
+
+ def getMin(self):
+ """Returns the min value of this ROI.
+
+ :rtype: float
+ """
+ return self._markerMin.getPosition()[0]
+
+ def setMax(self, vmax):
+ """Set the max of this ROI.
+
+ :param float vmax: New max
+ """
+ vmin = self.getMin()
+ self._updatePos(vmin, vmax)
+
+ def getMax(self):
+ """Returns the max value of this ROI.
+
+ :rtype: float
+ """
+ return self._markerMax.getPosition()[0]
+
+ def setCenter(self, center):
+ """Set the center of this ROI.
+
+ :param float center: New center
+ """
+ vmin, vmax = self.getRange()
+ previousCenter = (vmin + vmax) * 0.5
+ delta = center - previousCenter
+ self._updatePos(vmin + delta, vmax + delta)
+
+ def getCenter(self):
+ """Returns the center location of this ROI.
+
+ :rtype: float
+ """
+ vmin, vmax = self.getRange()
+ return (vmin + vmax) * 0.5
+
+ def __positionMinConstraint(self, x, y):
+ """Constraint of the min marker"""
+ if self.__filterReentrant.locked():
+ # Ignore the constraint when we set an explicit value
+ return x, y
+ vmax = self.getMax()
+ if vmax is None:
+ return x, y
+ return min(x, vmax), y
+
+ def __positionMaxConstraint(self, x, y):
+ """Constraint of the max marker"""
+ if self.__filterReentrant.locked():
+ # Ignore the constraint when we set an explicit value
+ return x, y
+ vmin = self.getMin()
+ if vmin is None:
+ return x, y
+ return max(x, vmin), y
+
+ def _minPositionChanged(self, event):
+ """Handle position changed events of the marker"""
+ if event is items.ItemChangedType.POSITION:
+ marker = self.sender()
+ self._updatePos(marker.getXPosition(), self.getMax(), force=True)
+
+ def _maxPositionChanged(self, event):
+ """Handle position changed events of the marker"""
+ if event is items.ItemChangedType.POSITION:
+ marker = self.sender()
+ self._updatePos(self.getMin(), marker.getXPosition(), force=True)
+
+ def _cenPositionChanged(self, event):
+ """Handle position changed events of the marker"""
+ if event is items.ItemChangedType.POSITION:
+ marker = self.sender()
+ self.setCenter(marker.getXPosition())
+
+ @docstring(HandleBasedROI)
+ def contains(self, position):
+ return self.getMin() <= position[0] <= self.getMax()
+
+ def __str__(self):
+ vrange = self.getRange()
+ params = 'min: %f; max: %f' % vrange
+ return "%s(%s)" % (self.__class__.__name__, params)
diff --git a/src/silx/gui/plot/items/scatter.py b/src/silx/gui/plot/items/scatter.py
new file mode 100644
index 0000000..fdc66f7
--- /dev/null
+++ b/src/silx/gui/plot/items/scatter.py
@@ -0,0 +1,1002 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides the :class:`Scatter` item of the :class:`Plot`.
+"""
+
+from __future__ import division
+
+
+__authors__ = ["T. Vincent", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "29/03/2017"
+
+
+from collections import namedtuple
+import logging
+import threading
+import numpy
+
+from collections import defaultdict
+from concurrent.futures import ThreadPoolExecutor, CancelledError
+
+from ....utils.proxy import docstring
+from ....math.combo import min_max
+from ....math.histogram import Histogramnd
+from ....utils.weakref import WeakList
+from .._utils.delaunay import delaunay
+from .core import PointsBase, ColormapMixIn, ScatterVisualizationMixIn
+from .axis import Axis
+from ._pick import PickingResult
+
+
+_logger = logging.getLogger(__name__)
+
+
+class _GreedyThreadPoolExecutor(ThreadPoolExecutor):
+ """:class:`ThreadPoolExecutor` with an extra :meth:`submit_greedy` method.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super(_GreedyThreadPoolExecutor, self).__init__(*args, **kwargs)
+ self.__futures = defaultdict(WeakList)
+ self.__lock = threading.RLock()
+
+ def submit_greedy(self, queue, fn, *args, **kwargs):
+ """Same as :meth:`submit` but cancel previous tasks in given queue.
+
+ This means that when a new task is submitted for a given queue,
+ all other pending tasks of that queue are cancelled.
+
+ :param queue: Identifier of the queue. This must be hashable.
+ :param callable fn: The callable to call with provided extra arguments
+ :return: Future corresponding to this task
+ :rtype: concurrent.futures.Future
+ """
+ with self.__lock:
+ # Cancel previous tasks in given queue
+ for future in self.__futures.pop(queue, []):
+ if not future.done():
+ future.cancel()
+
+ future = super(_GreedyThreadPoolExecutor, self).submit(
+ fn, *args, **kwargs)
+ self.__futures[queue].append(future)
+
+ return future
+
+
+# Functions to guess grid shape from coordinates
+
+def _get_z_line_length(array):
+ """Return length of line if array is a Z-like 2D regular grid.
+
+ :param numpy.ndarray array: The 1D array of coordinates to check
+ :return: 0 if no line length could be found,
+ else the number of element per line.
+ :rtype: int
+ """
+ sign = numpy.sign(numpy.diff(array))
+ if len(sign) == 0 or sign[0] == 0: # We don't handle that
+ return 0
+ # Check this way to account for 0 sign (i.e., diff == 0)
+ beginnings = numpy.where(sign == - sign[0])[0] + 1
+ if len(beginnings) == 0:
+ return 0
+ length = beginnings[0]
+ if numpy.all(numpy.equal(numpy.diff(beginnings), length)):
+ return length
+ return 0
+
+
+def _guess_z_grid_shape(x, y):
+ """Guess the shape of a grid from (x, y) coordinates.
+
+ The grid might contain more elements than x and y,
+ as the last line might be partly filled.
+
+ :param numpy.ndarray x:
+ :paran numpy.ndarray y:
+ :returns: (order, (height, width)) of the regular grid,
+ or None if could not guess one.
+ 'order' is 'row' if X (i.e., column) is the fast dimension, else 'column'.
+ :rtype: Union[List(str,int),None]
+ """
+ width = _get_z_line_length(x)
+ if width != 0:
+ return 'row', (int(numpy.ceil(len(x) / width)), width)
+ else:
+ height = _get_z_line_length(y)
+ if height != 0:
+ return 'column', (height, int(numpy.ceil(len(y) / height)))
+ return None
+
+
+def is_monotonic(array):
+ """Returns whether array is monotonic (increasing or decreasing).
+
+ :param numpy.ndarray array: 1D array-like container.
+ :returns: 1 if array is monotonically increasing,
+ -1 if array is monotonically decreasing,
+ 0 if array is not monotonic
+ :rtype: int
+ """
+ diff = numpy.diff(numpy.ravel(array))
+ with numpy.errstate(invalid='ignore'):
+ if numpy.all(diff >= 0):
+ return 1
+ elif numpy.all(diff <= 0):
+ return -1
+ else:
+ return 0
+
+
+def _guess_grid(x, y):
+ """Guess a regular grid from the points.
+
+ Result convention is (x, y)
+
+ :param numpy.ndarray x: X coordinates of the points
+ :param numpy.ndarray y: Y coordinates of the points
+ :returns: (order, (height, width)
+ order is 'row' or 'column'
+ :rtype: Union[List[str,List[int]],None]
+ """
+ x, y = numpy.ravel(x), numpy.ravel(y)
+
+ guess = _guess_z_grid_shape(x, y)
+ if guess is not None:
+ return guess
+
+ else:
+ # Cannot guess a regular grid
+ # Let's assume it's a single line
+ order = 'row' # or 'column' doesn't matter for a single line
+ y_monotonic = is_monotonic(y)
+ if is_monotonic(x) or y_monotonic: # we can guess a line
+ x_min, x_max = min_max(x)
+ y_min, y_max = min_max(y)
+
+ if not y_monotonic or x_max - x_min >= y_max - y_min:
+ # x only is monotonic or both are and X varies more
+ # line along X
+ shape = 1, len(x)
+ else:
+ # y only is monotonic or both are and Y varies more
+ # line along Y
+ shape = len(y), 1
+
+ else: # Cannot guess a line from the points
+ return None
+
+ return order, shape
+
+
+def _quadrilateral_grid_coords(points):
+ """Compute an irregular grid of quadrilaterals from a set of points
+
+ The input points are expected to lie on a grid.
+
+ :param numpy.ndarray points:
+ 3D data set of 2D input coordinates (height, width, 2)
+ height and width must be at least 2.
+ :return: 3D dataset of 2D coordinates of the grid (height+1, width+1, 2)
+ """
+ assert points.ndim == 3
+ assert points.shape[0] >= 2
+ assert points.shape[1] >= 2
+ assert points.shape[2] == 2
+
+ dim0, dim1 = points.shape[:2]
+ grid_points = numpy.zeros((dim0 + 1, dim1 + 1, 2), dtype=numpy.float64)
+
+ # Compute inner points as mean of 4 neighbours
+ neighbour_view = numpy.lib.stride_tricks.as_strided(
+ points,
+ shape=(dim0 - 1, dim1 - 1, 2, 2, points.shape[2]),
+ strides=points.strides[:2] + points.strides[:2] + points.strides[-1:], writeable=False)
+ inner_points = numpy.mean(neighbour_view, axis=(2, 3))
+ grid_points[1:-1, 1:-1] = inner_points
+
+ # Compute 'vertical' sides
+ # Alternative: grid_points[1:-1, [0, -1]] = points[:-1, [0, -1]] + points[1:, [0, -1]] - inner_points[:, [0, -1]]
+ grid_points[1:-1, [0, -1], 0] = points[:-1, [0, -1], 0] + points[1:, [0, -1], 0] - inner_points[:, [0, -1], 0]
+ grid_points[1:-1, [0, -1], 1] = inner_points[:, [0, -1], 1]
+
+ # Compute 'horizontal' sides
+ grid_points[[0, -1], 1:-1, 0] = inner_points[[0, -1], :, 0]
+ grid_points[[0, -1], 1:-1, 1] = points[[0, -1], :-1, 1] + points[[0, -1], 1:, 1] - inner_points[[0, -1], :, 1]
+
+ # Compute corners
+ d0, d1 = [0, 0, -1, -1], [0, -1, -1, 0]
+ grid_points[d0, d1] = 2 * points[d0, d1] - inner_points[d0, d1]
+ return grid_points
+
+
+def _quadrilateral_grid_as_triangles(points):
+ """Returns the points and indices to make a grid of quadirlaterals
+
+ :param numpy.ndarray points:
+ 3D array of points (height, width, 2)
+ :return: triangle corners (4 * N, 2), triangle indices (2 * N, 3)
+ With N = height * width, the number of input points
+ """
+ nbpoints = numpy.prod(points.shape[:2])
+
+ grid = _quadrilateral_grid_coords(points)
+ coords = numpy.empty((4 * nbpoints, 2), dtype=grid.dtype)
+ coords[::4] = grid[:-1, :-1].reshape(-1, 2)
+ coords[1::4] = grid[1:, :-1].reshape(-1, 2)
+ coords[2::4] = grid[:-1, 1:].reshape(-1, 2)
+ coords[3::4] = grid[1:, 1:].reshape(-1, 2)
+
+ indices = numpy.empty((2 * nbpoints, 3), dtype=numpy.uint32)
+ indices[::2, 0] = numpy.arange(0, 4 * nbpoints, 4)
+ indices[::2, 1] = numpy.arange(1, 4 * nbpoints, 4)
+ indices[::2, 2] = numpy.arange(2, 4 * nbpoints, 4)
+ indices[1::2, 0] = indices[::2, 1]
+ indices[1::2, 1] = indices[::2, 2]
+ indices[1::2, 2] = numpy.arange(3, 4 * nbpoints, 4)
+
+ return coords, indices
+
+
+_RegularGridInfo = namedtuple(
+ '_RegularGridInfo', ['bounds', 'origin', 'scale', 'shape', 'order'])
+
+
+_HistogramInfo = namedtuple(
+ '_HistogramInfo', ['mean', 'count', 'sum', 'origin', 'scale', 'shape'])
+
+
+class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
+ """Description of a scatter"""
+
+ _DEFAULT_SELECTABLE = True
+ """Default selectable state for scatter plots"""
+
+ _SUPPORTED_SCATTER_VISUALIZATION = (
+ ScatterVisualizationMixIn.Visualization.POINTS,
+ ScatterVisualizationMixIn.Visualization.SOLID,
+ ScatterVisualizationMixIn.Visualization.REGULAR_GRID,
+ ScatterVisualizationMixIn.Visualization.IRREGULAR_GRID,
+ ScatterVisualizationMixIn.Visualization.BINNED_STATISTIC,
+ )
+ """Overrides supported Visualizations"""
+
+ def __init__(self):
+ PointsBase.__init__(self)
+ ColormapMixIn.__init__(self)
+ ScatterVisualizationMixIn.__init__(self)
+ self._value = ()
+ self.__alpha = None
+ # Cache Delaunay triangulation future object
+ self.__delaunayFuture = None
+ # Cache interpolator future object
+ self.__interpolatorFuture = None
+ self.__executor = None
+
+ # Cache triangles: x, y, indices
+ self.__cacheTriangles = None, None, None
+
+ # Cache regular grid and histogram info
+ self.__cacheRegularGridInfo = None
+ self.__cacheHistogramInfo = None
+
+ def _updateColormappedData(self):
+ """Update the colormapped data, to be called when changed"""
+ if self.getVisualization() is self.Visualization.BINNED_STATISTIC:
+ histoInfo = self.__getHistogramInfo()
+ if histoInfo is None:
+ data = None
+ else:
+ data = getattr(
+ histoInfo,
+ self.getVisualizationParameter(
+ self.VisualizationParameter.BINNED_STATISTIC_FUNCTION))
+ else:
+ data = self.getValueData(copy=False)
+ self._setColormappedData(data, copy=False)
+
+ @docstring(ScatterVisualizationMixIn)
+ def setVisualization(self, mode):
+ previous = self.getVisualization()
+ if super().setVisualization(mode):
+ if (bool(mode is self.Visualization.BINNED_STATISTIC) ^
+ bool(previous is self.Visualization.BINNED_STATISTIC)):
+ self._updateColormappedData()
+ return True
+ else:
+ return False
+
+ @docstring(ScatterVisualizationMixIn)
+ def setVisualizationParameter(self, parameter, value):
+ parameter = self.VisualizationParameter.from_value(parameter)
+
+ if super(Scatter, self).setVisualizationParameter(parameter, value):
+ if parameter in (self.VisualizationParameter.GRID_BOUNDS,
+ self.VisualizationParameter.GRID_MAJOR_ORDER,
+ self.VisualizationParameter.GRID_SHAPE):
+ self.__cacheRegularGridInfo = None
+
+ if parameter in (self.VisualizationParameter.BINNED_STATISTIC_SHAPE,
+ self.VisualizationParameter.BINNED_STATISTIC_FUNCTION,
+ self.VisualizationParameter.DATA_BOUNDS_HINT):
+ if parameter in (self.VisualizationParameter.BINNED_STATISTIC_SHAPE,
+ self.VisualizationParameter.DATA_BOUNDS_HINT):
+ self.__cacheHistogramInfo = None # Clean-up cache
+ if self.getVisualization() is self.Visualization.BINNED_STATISTIC:
+ self._updateColormappedData()
+ return True
+ else:
+ return False
+
+ @docstring(ScatterVisualizationMixIn)
+ def getCurrentVisualizationParameter(self, parameter):
+ value = self.getVisualizationParameter(parameter)
+ if (parameter is self.VisualizationParameter.DATA_BOUNDS_HINT or
+ value is not None):
+ return value # Value has been set, return it
+
+ elif parameter is self.VisualizationParameter.GRID_BOUNDS:
+ grid = self.__getRegularGridInfo()
+ return None if grid is None else grid.bounds
+
+ elif parameter is self.VisualizationParameter.GRID_MAJOR_ORDER:
+ grid = self.__getRegularGridInfo()
+ return None if grid is None else grid.order
+
+ elif parameter is self.VisualizationParameter.GRID_SHAPE:
+ grid = self.__getRegularGridInfo()
+ return None if grid is None else grid.shape
+
+ elif parameter is self.VisualizationParameter.BINNED_STATISTIC_SHAPE:
+ info = self.__getHistogramInfo()
+ return None if info is None else info.shape
+
+ else:
+ raise NotImplementedError()
+
+ def __getRegularGridInfo(self):
+ """Get grid info"""
+ if self.__cacheRegularGridInfo is None:
+ shape = self.getVisualizationParameter(
+ self.VisualizationParameter.GRID_SHAPE)
+ order = self.getVisualizationParameter(
+ self.VisualizationParameter.GRID_MAJOR_ORDER)
+ if shape is None or order is None:
+ guess = _guess_grid(self.getXData(copy=False),
+ self.getYData(copy=False))
+ if guess is None:
+ _logger.warning(
+ 'Cannot guess a grid: Cannot display as regular grid image')
+ return None
+ if shape is None:
+ shape = guess[1]
+ if order is None:
+ order = guess[0]
+
+ nbpoints = len(self.getXData(copy=False))
+ if nbpoints > shape[0] * shape[1]:
+ # More data points that provided grid shape: enlarge grid
+ _logger.warning(
+ "More data points than provided grid shape size: extends grid")
+ dim0, dim1 = shape
+ if order == 'row': # keep dim1, enlarge dim0
+ dim0 = nbpoints // dim1 + (1 if nbpoints % dim1 else 0)
+ else: # keep dim0, enlarge dim1
+ dim1 = nbpoints // dim0 + (1 if nbpoints % dim0 else 0)
+ shape = dim0, dim1
+
+ bounds = self.getVisualizationParameter(
+ self.VisualizationParameter.GRID_BOUNDS)
+ if bounds is None:
+ x, y = self.getXData(copy=False), self.getYData(copy=False)
+ min_, max_ = min_max(x)
+ xRange = (min_, max_) if (x[0] - min_) < (max_ - x[0]) else (max_, min_)
+ min_, max_ = min_max(y)
+ yRange = (min_, max_) if (y[0] - min_) < (max_ - y[0]) else (max_, min_)
+ bounds = (xRange[0], yRange[0]), (xRange[1], yRange[1])
+
+ begin, end = bounds
+ scale = ((end[0] - begin[0]) / max(1, shape[1] - 1),
+ (end[1] - begin[1]) / max(1, shape[0] - 1))
+ if scale[0] == 0 and scale[1] == 0:
+ scale = 1., 1.
+ elif scale[0] == 0:
+ scale = scale[1], scale[1]
+ elif scale[1] == 0:
+ scale = scale[0], scale[0]
+
+ origin = begin[0] - 0.5 * scale[0], begin[1] - 0.5 * scale[1]
+
+ self.__cacheRegularGridInfo = _RegularGridInfo(
+ bounds=bounds, origin=origin, scale=scale, shape=shape, order=order)
+
+ return self.__cacheRegularGridInfo
+
+ def __getHistogramInfo(self):
+ """Get histogram info"""
+ if self.__cacheHistogramInfo is None:
+ shape = self.getVisualizationParameter(
+ self.VisualizationParameter.BINNED_STATISTIC_SHAPE)
+ if shape is None:
+ shape = 100, 100 # TODO compute auto shape
+
+ x, y, values = self.getData(copy=False)[:3]
+ if len(x) == 0: # No histogram
+ return None
+
+ if not numpy.issubdtype(x.dtype, numpy.floating):
+ x = x.astype(numpy.float64)
+ if not numpy.issubdtype(y.dtype, numpy.floating):
+ y = y.astype(numpy.float64)
+ if not numpy.issubdtype(values.dtype, numpy.floating):
+ values = values.astype(numpy.float64)
+
+ ranges = (tuple(min_max(y, finite=True)),
+ tuple(min_max(x, finite=True)))
+ rangesHint = self.getVisualizationParameter(
+ self.VisualizationParameter.DATA_BOUNDS_HINT)
+ if rangesHint is not None:
+ ranges = tuple((min(dataMin, hintMin), max(dataMax, hintMax))
+ for (dataMin, dataMax), (hintMin, hintMax) in zip(ranges, rangesHint))
+
+ points = numpy.transpose(numpy.array((y, x)))
+ counts, sums, bin_edges = Histogramnd(
+ points,
+ histo_range=ranges,
+ n_bins=shape,
+ weights=values)
+ yEdges, xEdges = bin_edges
+ origin = xEdges[0], yEdges[0]
+ scale = ((xEdges[-1] - xEdges[0]) / (len(xEdges) - 1),
+ (yEdges[-1] - yEdges[0]) / (len(yEdges) - 1))
+
+ with numpy.errstate(divide='ignore', invalid='ignore'):
+ histo = sums / counts
+
+ self.__cacheHistogramInfo = _HistogramInfo(
+ mean=histo, count=counts, sum=sums,
+ origin=origin, scale=scale, shape=shape)
+
+ return self.__cacheHistogramInfo
+
+ def __applyColormapToData(self):
+ """Compute colors by applying colormap to values.
+
+ :returns: Array of RGBA colors
+ """
+ cmap = self.getColormap()
+ rgbacolors = cmap.applyToData(self)
+
+ if self.__alpha is not None:
+ rgbacolors[:, -1] = (rgbacolors[:, -1] * self.__alpha).astype(numpy.uint8)
+ return rgbacolors
+
+ def _addBackendRenderer(self, backend):
+ """Update backend renderer"""
+ # Filter-out values <= 0
+ xFiltered, yFiltered, valueFiltered, xerror, yerror = self.getData(
+ copy=False, displayed=True)
+
+ # Remove not finite numbers (this includes filtered out x, y <= 0)
+ mask = numpy.logical_and(numpy.isfinite(xFiltered), numpy.isfinite(yFiltered))
+ xFiltered = xFiltered[mask]
+ yFiltered = yFiltered[mask]
+
+ if len(xFiltered) == 0:
+ return None # No data to display, do not add renderer to backend
+
+ visualization = self.getVisualization()
+
+ if visualization is self.Visualization.BINNED_STATISTIC:
+ plot = self.getPlot()
+ if (plot is None or
+ plot.getXAxis().getScale() != Axis.LINEAR or
+ plot.getYAxis().getScale() != Axis.LINEAR):
+ # Those visualizations are not available with log scaled axes
+ return None
+
+ histoInfo = self.__getHistogramInfo()
+ if histoInfo is None:
+ return None
+ data = getattr(histoInfo, self.getVisualizationParameter(
+ self.VisualizationParameter.BINNED_STATISTIC_FUNCTION))
+
+ return backend.addImage(
+ data=data,
+ origin=histoInfo.origin,
+ scale=histoInfo.scale,
+ colormap=self.getColormap(),
+ alpha=self.getAlpha())
+
+ elif visualization is self.Visualization.POINTS:
+ rgbacolors = self.__applyColormapToData()
+ return backend.addCurve(xFiltered, yFiltered,
+ color=rgbacolors[mask],
+ symbol=self.getSymbol(),
+ linewidth=0,
+ linestyle="",
+ yaxis='left',
+ xerror=xerror,
+ yerror=yerror,
+ fill=False,
+ alpha=self.getAlpha(),
+ symbolsize=self.getSymbolSize(),
+ baseline=None)
+
+ else:
+ plot = self.getPlot()
+ if (plot is None or
+ plot.getXAxis().getScale() != Axis.LINEAR or
+ plot.getYAxis().getScale() != Axis.LINEAR):
+ # Those visualizations are not available with log scaled axes
+ return None
+
+ if visualization is self.Visualization.SOLID:
+ triangulation = self._getDelaunay().result()
+ if triangulation is None:
+ _logger.warning(
+ 'Cannot get a triangulation: Cannot display as solid surface')
+ return None
+ else:
+ rgbacolors = self.__applyColormapToData()
+ triangles = triangulation.simplices.astype(numpy.int32)
+ return backend.addTriangles(xFiltered,
+ yFiltered,
+ triangles,
+ color=rgbacolors[mask],
+ alpha=self.getAlpha())
+
+ elif visualization is self.Visualization.REGULAR_GRID:
+ gridInfo = self.__getRegularGridInfo()
+ if gridInfo is None:
+ return None
+
+ dim0, dim1 = gridInfo.shape
+ if gridInfo.order == 'column': # transposition needed
+ dim0, dim1 = dim1, dim0
+
+ values = self.getValueData(copy=False)
+ if self.__alpha is None and len(values) == dim0 * dim1:
+ image = values.reshape(dim0, dim1)
+ else:
+ # The points do not fill the whole image
+ if (self.__alpha is None and
+ numpy.issubdtype(values.dtype, numpy.floating)):
+ image = numpy.empty(dim0 * dim1, dtype=values.dtype)
+ image[:len(values)] = values
+ image[len(values):] = float('nan') # Transparent pixels
+ image.shape = dim0, dim1
+ else: # Per value alpha or no NaN, so convert to RGBA
+ rgbacolors = self.__applyColormapToData()
+ image = numpy.empty((dim0 * dim1, 4), dtype=numpy.uint8)
+ image[:len(rgbacolors)] = rgbacolors
+ image[len(rgbacolors):] = (0, 0, 0, 0) # Transparent pixels
+ image.shape = dim0, dim1, 4
+
+ if gridInfo.order == 'column':
+ if image.ndim == 2:
+ image = numpy.transpose(image)
+ else:
+ image = numpy.transpose(image, axes=(1, 0, 2))
+
+ if image.ndim == 2:
+ colormap = self.getColormap()
+ if colormap.isAutoscale():
+ # Avoid backend to compute autoscale: use item cache
+ colormap = colormap.copy()
+ colormap.setVRange(*colormap.getColormapRange(self))
+ else:
+ colormap = None
+
+ return backend.addImage(
+ data=image,
+ origin=gridInfo.origin,
+ scale=gridInfo.scale,
+ colormap=colormap,
+ alpha=self.getAlpha())
+
+ elif visualization is self.Visualization.IRREGULAR_GRID:
+ gridInfo = self.__getRegularGridInfo()
+ if gridInfo is None:
+ return None
+
+ shape = gridInfo.shape
+ if shape is None: # No shape, no display
+ return None
+
+ rgbacolors = self.__applyColormapToData()
+
+ nbpoints = len(xFiltered)
+ if nbpoints == 1:
+ # single point, render as a square points
+ return backend.addCurve(xFiltered, yFiltered,
+ color=rgbacolors[mask],
+ symbol='s',
+ linewidth=0,
+ linestyle="",
+ yaxis='left',
+ xerror=None,
+ yerror=None,
+ fill=False,
+ alpha=self.getAlpha(),
+ symbolsize=7,
+ baseline=None)
+
+ # Make shape include all points
+ gridOrder = gridInfo.order
+ if nbpoints != numpy.prod(shape):
+ if gridOrder == 'row':
+ shape = int(numpy.ceil(nbpoints / shape[1])), shape[1]
+ else: # column-major order
+ shape = shape[0], int(numpy.ceil(nbpoints / shape[0]))
+
+ if shape[0] < 2 or shape[1] < 2: # Single line, at least 2 points
+ points = numpy.ones((2, nbpoints, 2), dtype=numpy.float64)
+ # Use row/column major depending on shape, not on info value
+ gridOrder = 'row' if shape[0] == 1 else 'column'
+
+ if gridOrder == 'row':
+ points[0, :, 0] = xFiltered
+ points[0, :, 1] = yFiltered
+ else: # column-major order
+ points[0, :, 0] = yFiltered
+ points[0, :, 1] = xFiltered
+
+ # Add a second line that will be clipped in the end
+ points[1, :-1] = points[0, :-1] + numpy.cross(
+ points[0, 1:] - points[0, :-1], (0., 0., 1.))[:, :2]
+ points[1, -1] = points[0, -1] + numpy.cross(
+ points[0, -1] - points[0, -2], (0., 0., 1.))[:2]
+
+ points.shape = 2, nbpoints, 2 # Use same shape for both orders
+ coords, indices = _quadrilateral_grid_as_triangles(points)
+
+ elif gridOrder == 'row': # row-major order
+ if nbpoints != numpy.prod(shape):
+ points = numpy.empty((numpy.prod(shape), 2), dtype=numpy.float64)
+ points[:nbpoints, 0] = xFiltered
+ points[:nbpoints, 1] = yFiltered
+ # Index of last element of last fully filled row
+ index = (nbpoints // shape[1]) * shape[1]
+ points[nbpoints:, 0] = xFiltered[index - (numpy.prod(shape) - nbpoints):index]
+ points[nbpoints:, 1] = yFiltered[-1]
+ else:
+ points = numpy.transpose((xFiltered, yFiltered))
+ points.shape = shape[0], shape[1], 2
+
+ else: # column-major order
+ if nbpoints != numpy.prod(shape):
+ points = numpy.empty((numpy.prod(shape), 2), dtype=numpy.float64)
+ points[:nbpoints, 0] = yFiltered
+ points[:nbpoints, 1] = xFiltered
+ # Index of last element of last fully filled column
+ index = (nbpoints // shape[0]) * shape[0]
+ points[nbpoints:, 0] = yFiltered[index - (numpy.prod(shape) - nbpoints):index]
+ points[nbpoints:, 1] = xFiltered[-1]
+ else:
+ points = numpy.transpose((yFiltered, xFiltered))
+ points.shape = shape[1], shape[0], 2
+
+ coords, indices = _quadrilateral_grid_as_triangles(points)
+
+ # Remove unused extra triangles
+ coords = coords[:4*nbpoints]
+ indices = indices[:2*nbpoints]
+
+ if gridOrder == 'row':
+ x, y = coords[:, 0], coords[:, 1]
+ else: # column-major order
+ y, x = coords[:, 0], coords[:, 1]
+
+ rgbacolors = rgbacolors[mask] # Filter-out not finite points
+ gridcolors = numpy.empty(
+ (4 * nbpoints, rgbacolors.shape[-1]), dtype=rgbacolors.dtype)
+ for first in range(4):
+ gridcolors[first::4] = rgbacolors[:nbpoints]
+
+ return backend.addTriangles(x,
+ y,
+ indices,
+ color=gridcolors,
+ alpha=self.getAlpha())
+
+ else:
+ _logger.error("Unhandled visualization %s", visualization)
+ return None
+
+ @docstring(PointsBase)
+ def pick(self, x, y):
+ result = super(Scatter, self).pick(x, y)
+
+ if result is not None:
+ visualization = self.getVisualization()
+
+ if visualization is self.Visualization.IRREGULAR_GRID:
+ # Specific handling of picking for the irregular grid mode
+ index = result.getIndices(copy=False)[0] // 4
+ result = PickingResult(self, (index,))
+
+ elif visualization is self.Visualization.REGULAR_GRID:
+ # Specific handling of picking for the regular grid mode
+ picked = result.getIndices(copy=False)
+ if picked is None:
+ return None
+ row, column = picked[0][0], picked[1][0]
+
+ gridInfo = self.__getRegularGridInfo()
+ if gridInfo is None:
+ return None
+
+ if gridInfo.order == 'row':
+ index = row * gridInfo.shape[1] + column
+ else:
+ index = row + column * gridInfo.shape[0]
+ if index >= len(self.getXData(copy=False)): # OK as long as not log scale
+ return None # Image can be larger than scatter
+
+ result = PickingResult(self, (index,))
+
+ elif visualization is self.Visualization.BINNED_STATISTIC:
+ picked = result.getIndices(copy=False)
+ if picked is None or len(picked) == 0 or len(picked[0]) == 0:
+ return None
+ row, col = picked[0][0], picked[1][0]
+ histoInfo = self.__getHistogramInfo()
+ if histoInfo is None:
+ return None
+ sx, sy = histoInfo.scale
+ ox, oy = histoInfo.origin
+ xdata = self.getXData(copy=False)
+ ydata = self.getYData(copy=False)
+ indices = numpy.nonzero(numpy.logical_and(
+ numpy.logical_and(xdata >= ox + sx * col, xdata < ox + sx * (col + 1)),
+ numpy.logical_and(ydata >= oy + sy * row, ydata < oy + sy * (row + 1))))[0]
+ result = None if len(indices) == 0 else PickingResult(self, indices)
+
+ return result
+
+ def __getExecutor(self):
+ """Returns async greedy executor
+
+ :rtype: _GreedyThreadPoolExecutor
+ """
+ if self.__executor is None:
+ self.__executor = _GreedyThreadPoolExecutor(max_workers=2)
+ return self.__executor
+
+ def _getDelaunay(self):
+ """Returns a :class:`Future` which result is the Delaunay object.
+
+ :rtype: concurrent.futures.Future
+ """
+ if self.__delaunayFuture is None or self.__delaunayFuture.cancelled():
+ # Need to init a new delaunay
+ x, y = self.getData(copy=False)[:2]
+ # Remove not finite points
+ mask = numpy.logical_and(numpy.isfinite(x), numpy.isfinite(y))
+
+ self.__delaunayFuture = self.__getExecutor().submit_greedy(
+ 'delaunay', delaunay, x[mask], y[mask])
+
+ return self.__delaunayFuture
+
+ @staticmethod
+ def __initInterpolator(delaunayFuture, values):
+ """Returns an interpolator for the given data points
+
+ :param concurrent.futures.Future delaunayFuture:
+ Future object which result is a Delaunay object
+ :param numpy.ndarray values: The data value of valid points.
+ :rtype: Union[callable,None]
+ """
+ # Wait for Delaunay to complete
+ try:
+ triangulation = delaunayFuture.result()
+ except CancelledError:
+ triangulation = None
+
+ if triangulation is None:
+ interpolator = None # Error case
+ else:
+ # Lazy-loading of interpolator
+ try:
+ from scipy.interpolate import LinearNDInterpolator
+ except ImportError:
+ LinearNDInterpolator = None
+
+ if LinearNDInterpolator is not None:
+ interpolator = LinearNDInterpolator(triangulation, values)
+
+ # First call takes a while, do it here
+ interpolator([(0., 0.)])
+
+ else:
+ # Fallback using matplotlib interpolator
+ import matplotlib.tri
+
+ x, y = triangulation.points.T
+ tri = matplotlib.tri.Triangulation(
+ x, y, triangles=triangulation.simplices)
+ mplInterpolator = matplotlib.tri.LinearTriInterpolator(
+ tri, values)
+
+ # Wrap interpolator to have same API as scipy's one
+ def interpolator(points):
+ return mplInterpolator(*points.T)
+
+ return interpolator
+
+ def _getInterpolator(self):
+ """Returns a :class:`Future` which result is the interpolator.
+
+ The interpolator is a callable taking an array Nx2 of points
+ as a single argument.
+ The :class:`Future` result is None in case the interpolator cannot
+ be initialized.
+
+ :rtype: concurrent.futures.Future
+ """
+ if (self.__interpolatorFuture is None or
+ self.__interpolatorFuture.cancelled()):
+ # Need to init a new interpolator
+ x, y, values = self.getData(copy=False)[:3]
+ # Remove not finite points
+ mask = numpy.logical_and(numpy.isfinite(x), numpy.isfinite(y))
+ x, y, values = x[mask], y[mask], values[mask]
+
+ self.__interpolatorFuture = self.__getExecutor().submit_greedy(
+ 'interpolator',
+ self.__initInterpolator, self._getDelaunay(), values)
+ return self.__interpolatorFuture
+
+ def _logFilterData(self, xPositive, yPositive):
+ """Filter out values with x or y <= 0 on log axes
+
+ :param bool xPositive: True to filter arrays according to X coords.
+ :param bool yPositive: True to filter arrays according to Y coords.
+ :return: The filtered arrays or unchanged object if not filtering needed
+ :rtype: (x, y, value, xerror, yerror)
+ """
+ # overloaded from PointsBase to filter also value.
+ value = self.getValueData(copy=False)
+
+ if xPositive or yPositive:
+ clipped = self._getClippingBoolArray(xPositive, yPositive)
+
+ if numpy.any(clipped):
+ # copy to keep original array and convert to float
+ value = numpy.array(value, copy=True, dtype=numpy.float64)
+ value[clipped] = numpy.nan
+
+ x, y, xerror, yerror = PointsBase._logFilterData(self, xPositive, yPositive)
+
+ return x, y, value, xerror, yerror
+
+ def getValueData(self, copy=True):
+ """Returns the value assigned to the scatter data points.
+
+ :param copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :rtype: numpy.ndarray
+ """
+ return numpy.array(self._value, copy=copy)
+
+ def getAlphaData(self, copy=True):
+ """Returns the alpha (transparency) assigned to the scatter data points.
+
+ :param copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :rtype: numpy.ndarray
+ """
+ return numpy.array(self.__alpha, copy=copy)
+
+ def getData(self, copy=True, displayed=False):
+ """Returns the x, y coordinates and the value of the data points
+
+ :param copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :param bool displayed: True to only get curve points that are displayed
+ in the plot. Default: False.
+ Note: If plot has log scale, negative points
+ are not displayed.
+ :returns: (x, y, value, xerror, yerror)
+ :rtype: 5-tuple of numpy.ndarray
+ """
+ if displayed:
+ data = self._getCachedData()
+ if data is not None:
+ assert len(data) == 5
+ return data
+
+ return (self.getXData(copy),
+ self.getYData(copy),
+ self.getValueData(copy),
+ self.getXErrorData(copy),
+ self.getYErrorData(copy))
+
+ # reimplemented from PointsBase to handle `value`
+ def setData(self, x, y, value, xerror=None, yerror=None, alpha=None, copy=True):
+ """Set the data of the scatter.
+
+ :param numpy.ndarray x: The data corresponding to the x coordinates.
+ :param numpy.ndarray y: The data corresponding to the y coordinates.
+ :param numpy.ndarray value: The data corresponding to the value of
+ the data points.
+ :param xerror: Values with the uncertainties on the x values
+ :type xerror: A float, or a numpy.ndarray of float32.
+ If it is an array, it can either be a 1D array of
+ same length as the data or a 2D array with 2 rows
+ of same length as the data: row 0 for positive errors,
+ row 1 for negative errors.
+ :param yerror: Values with the uncertainties on the y values
+ :type yerror: A float, or a numpy.ndarray of float32. See xerror.
+ :param alpha: Values with the transparency (between 0 and 1)
+ :type alpha: A float, or a numpy.ndarray of float32
+ :param bool copy: True make a copy of the data (default),
+ False to use provided arrays.
+ """
+ value = numpy.array(value, copy=copy)
+ assert value.ndim == 1
+ assert len(x) == len(value)
+
+ # Convert complex data
+ if numpy.iscomplexobj(value):
+ _logger.warning(
+ 'Converting value data to absolute value to plot it.')
+ value = numpy.absolute(value)
+
+ # Reset triangulation and interpolator
+ if self.__delaunayFuture is not None:
+ self.__delaunayFuture.cancel()
+ self.__delaunayFuture = None
+ if self.__interpolatorFuture is not None:
+ self.__interpolatorFuture.cancel()
+ self.__interpolatorFuture = None
+
+ # Data changed, this needs update
+ self.__cacheRegularGridInfo = None
+ self.__cacheHistogramInfo = None
+
+ self._value = value
+
+ if alpha is not None:
+ # Make sure alpha is an array of float in [0, 1]
+ alpha = numpy.array(alpha, copy=copy)
+ assert alpha.ndim == 1
+ assert len(x) == len(alpha)
+ if alpha.dtype.kind != 'f':
+ alpha = alpha.astype(numpy.float32)
+ if numpy.any(numpy.logical_or(alpha < 0., alpha > 1.)):
+ alpha = numpy.clip(alpha, 0., 1.)
+ self.__alpha = alpha
+
+ # set x, y, xerror, yerror
+
+ # call self._updated + plot._invalidateDataRange()
+ PointsBase.setData(self, x, y, xerror, yerror, copy)
+
+ self._updateColormappedData()
diff --git a/src/silx/gui/plot/items/shape.py b/src/silx/gui/plot/items/shape.py
new file mode 100644
index 0000000..00ac5f5
--- /dev/null
+++ b/src/silx/gui/plot/items/shape.py
@@ -0,0 +1,287 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides the :class:`Shape` item of the :class:`Plot`.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "21/12/2018"
+
+
+import logging
+
+import numpy
+
+from ... import colors
+from .core import (
+ Item, DataItem,
+ ColorMixIn, FillMixIn, ItemChangedType, LineMixIn, YAxisMixIn)
+
+
+_logger = logging.getLogger(__name__)
+
+
+# TODO probably make one class for each kind of shape
+# TODO check fill:polygon/polyline + fill = duplicated
+class Shape(Item, ColorMixIn, FillMixIn, LineMixIn):
+ """Description of a shape item
+
+ :param str type_: The type of shape in:
+ 'hline', 'polygon', 'rectangle', 'vline', 'polylines'
+ """
+
+ def __init__(self, type_):
+ Item.__init__(self)
+ ColorMixIn.__init__(self)
+ FillMixIn.__init__(self)
+ LineMixIn.__init__(self)
+ self._overlay = False
+ assert type_ in ('hline', 'polygon', 'rectangle', 'vline', 'polylines')
+ self._type = type_
+ self._points = ()
+ self._lineBgColor = None
+
+ self._handle = None
+
+ def _addBackendRenderer(self, backend):
+ """Update backend renderer"""
+ points = self.getPoints(copy=False)
+ x, y = points.T[0], points.T[1]
+ return backend.addShape(x,
+ y,
+ shape=self.getType(),
+ color=self.getColor(),
+ fill=self.isFill(),
+ overlay=self.isOverlay(),
+ linestyle=self.getLineStyle(),
+ linewidth=self.getLineWidth(),
+ linebgcolor=self.getLineBgColor())
+
+ def isOverlay(self):
+ """Return true if shape is drawn as an overlay
+
+ :rtype: bool
+ """
+ return self._overlay
+
+ def setOverlay(self, overlay):
+ """Set the overlay state of the shape
+
+ :param bool overlay: True to make it an overlay
+ """
+ overlay = bool(overlay)
+ if overlay != self._overlay:
+ self._overlay = overlay
+ self._updated(ItemChangedType.OVERLAY)
+
+ def getType(self):
+ """Returns the type of shape to draw.
+
+ One of: 'hline', 'polygon', 'rectangle', 'vline', 'polylines'
+
+ :rtype: str
+ """
+ return self._type
+
+ def getPoints(self, copy=True):
+ """Get the control points of the shape.
+
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :return: Array of point coordinates
+ :rtype: numpy.ndarray with 2 dimensions
+ """
+ return numpy.array(self._points, copy=copy)
+
+ def setPoints(self, points, copy=True):
+ """Set the point coordinates
+
+ :param numpy.ndarray points: Array of point coordinates
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :return:
+ """
+ self._points = numpy.array(points, copy=copy)
+ self._updated(ItemChangedType.DATA)
+
+ def getLineBgColor(self):
+ """Returns the RGBA color of the item
+ :rtype: 4-tuple of float in [0, 1] or array of colors
+ """
+ return self._lineBgColor
+
+ def setLineBgColor(self, color, copy=True):
+ """Set item color
+ :param color: color(s) to be used
+ :type color: str ("#RRGGBB") or (npoints, 4) unsigned byte array or
+ one of the predefined color names defined in colors.py
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ """
+ if color is not None:
+ if isinstance(color, str):
+ color = colors.rgba(color)
+ else:
+ color = numpy.array(color, copy=copy)
+ # TODO more checks + improve color array support
+ if color.ndim == 1: # Single RGBA color
+ color = colors.rgba(color)
+ else: # Array of colors
+ assert color.ndim == 2
+
+ self._lineBgColor = color
+ self._updated(ItemChangedType.LINE_BG_COLOR)
+
+
+class BoundingRect(DataItem, YAxisMixIn):
+ """An invisible shape which enforce the plot view to display the defined
+ space on autoscale.
+
+ This item do not display anything. But if the visible property is true,
+ this bounding box is used by the plot, if not, the bounding box is
+ ignored. That's the default behaviour for plot items.
+
+ It can be applied on the "left" or "right" axes. Not both at the same time.
+ """
+
+ def __init__(self):
+ DataItem.__init__(self)
+ YAxisMixIn.__init__(self)
+ self.__bounds = None
+
+ def setBounds(self, rect):
+ """Set the bounding box of this item in data coordinates
+
+ :param Union[None,List[float]] rect: (xmin, xmax, ymin, ymax) or None
+ """
+ if rect is not None:
+ rect = float(rect[0]), float(rect[1]), float(rect[2]), float(rect[3])
+ assert rect[0] <= rect[1]
+ assert rect[2] <= rect[3]
+
+ if rect != self.__bounds:
+ self.__bounds = rect
+ self._boundsChanged()
+ self._updated(ItemChangedType.DATA)
+
+ def _getBounds(self):
+ if self.__bounds is None:
+ return None
+ plot = self.getPlot()
+ if plot is not None:
+ xPositive = plot.getXAxis()._isLogarithmic()
+ yPositive = plot.getYAxis()._isLogarithmic()
+ if xPositive or yPositive:
+ bounds = list(self.__bounds)
+ if xPositive and bounds[1] <= 0:
+ return None
+ if xPositive and bounds[0] <= 0:
+ bounds[0] = bounds[1]
+ if yPositive and bounds[3] <= 0:
+ return None
+ if yPositive and bounds[2] <= 0:
+ bounds[2] = bounds[3]
+ return tuple(bounds)
+
+ return self.__bounds
+
+
+class _BaseExtent(DataItem):
+ """Base class for :class:`XAxisExtent` and :class:`YAxisExtent`.
+
+ :param str axis: Either 'x' or 'y'.
+ """
+
+ def __init__(self, axis='x'):
+ assert axis in ('x', 'y')
+ DataItem.__init__(self)
+ self.__axis = axis
+ self.__range = 1., 100.
+
+ def setRange(self, min_, max_):
+ """Set the range of the extent of this item in data coordinates.
+
+ :param float min_: Lower bound of the extent
+ :param float max_: Upper bound of the extent
+ :raises ValueError: If min > max or not finite bounds
+ """
+ range_ = float(min_), float(max_)
+ if not numpy.all(numpy.isfinite(range_)):
+ raise ValueError("min_ and max_ must be finite numbers.")
+ if range_[0] > range_[1]:
+ raise ValueError("min_ must be lesser or equal to max_")
+
+ if range_ != self.__range:
+ self.__range = range_
+ self._boundsChanged()
+ self._updated(ItemChangedType.DATA)
+
+ def getRange(self):
+ """Returns the range (min, max) of the extent in data coordinates.
+
+ :rtype: List[float]
+ """
+ return self.__range
+
+ def _getBounds(self):
+ min_, max_ = self.getRange()
+
+ plot = self.getPlot()
+ if plot is not None:
+ axis = plot.getXAxis() if self.__axis == 'x' else plot.getYAxis()
+ if axis._isLogarithmic():
+ if max_ <= 0:
+ return None
+ if min_ <= 0:
+ min_ = max_
+
+ if self.__axis == 'x':
+ return min_, max_, float('nan'), float('nan')
+ else:
+ return float('nan'), float('nan'), min_, max_
+
+
+class XAxisExtent(_BaseExtent):
+ """Invisible item with a settable horizontal data extent.
+
+ This item do not display anything, but it behaves as a data
+ item with a horizontal extent regarding plot data bounds, i.e.,
+ :meth:`PlotWidget.resetZoom` will take this horizontal extent into account.
+ """
+ def __init__(self):
+ _BaseExtent.__init__(self, axis='x')
+
+
+class YAxisExtent(_BaseExtent, YAxisMixIn):
+ """Invisible item with a settable vertical data extent.
+
+ This item do not display anything, but it behaves as a data
+ item with a vertical extent regarding plot data bounds, i.e.,
+ :meth:`PlotWidget.resetZoom` will take this vertical extent into account.
+ """
+
+ def __init__(self):
+ _BaseExtent.__init__(self, axis='y')
+ YAxisMixIn.__init__(self)
diff --git a/src/silx/gui/plot/matplotlib/Colormap.py b/src/silx/gui/plot/matplotlib/Colormap.py
new file mode 100644
index 0000000..dc432b2
--- /dev/null
+++ b/src/silx/gui/plot/matplotlib/Colormap.py
@@ -0,0 +1,249 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2017-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""Matplotlib's new colormaps"""
+
+import numpy
+import logging
+from matplotlib.colors import ListedColormap
+import matplotlib.colors
+import matplotlib.cm
+import silx.resources
+from silx.utils.deprecation import deprecated, deprecated_warning
+
+
+deprecated_warning(type_='module',
+ name=__file__,
+ replacement='silx.gui.colors.Colormap',
+ since_version='0.10.0')
+
+
+_logger = logging.getLogger(__name__)
+
+_AVAILABLE_AS_RESOURCE = ('magma', 'inferno', 'plasma', 'viridis')
+"""List available colormap name as resources"""
+
+_AVAILABLE_AS_BUILTINS = ('gray', 'reversed gray',
+ 'temperature', 'red', 'green', 'blue')
+"""List of colormaps available through built-in declarations"""
+
+_CMAPS = {}
+"""Cache colormaps"""
+
+
+@property
+@deprecated(since_version='0.10.0')
+def magma():
+ return getColormap('magma')
+
+
+@property
+@deprecated(since_version='0.10.0')
+def inferno():
+ return getColormap('inferno')
+
+
+@property
+@deprecated(since_version='0.10.0')
+def plasma():
+ return getColormap('plasma')
+
+
+@property
+@deprecated(since_version='0.10.0')
+def viridis():
+ return getColormap('viridis')
+
+
+@deprecated(since_version='0.10.0')
+def getColormap(name):
+ """Returns matplotlib colormap corresponding to given name
+
+ :param str name: The name of the colormap
+ :return: The corresponding colormap
+ :rtype: matplolib.colors.Colormap
+ """
+ if not _CMAPS: # Lazy initialization of own colormaps
+ cdict = {'red': ((0.0, 0.0, 0.0),
+ (1.0, 1.0, 1.0)),
+ 'green': ((0.0, 0.0, 0.0),
+ (1.0, 0.0, 0.0)),
+ 'blue': ((0.0, 0.0, 0.0),
+ (1.0, 0.0, 0.0))}
+ _CMAPS['red'] = matplotlib.colors.LinearSegmentedColormap(
+ 'red', cdict, 256)
+
+ cdict = {'red': ((0.0, 0.0, 0.0),
+ (1.0, 0.0, 0.0)),
+ 'green': ((0.0, 0.0, 0.0),
+ (1.0, 1.0, 1.0)),
+ 'blue': ((0.0, 0.0, 0.0),
+ (1.0, 0.0, 0.0))}
+ _CMAPS['green'] = matplotlib.colors.LinearSegmentedColormap(
+ 'green', cdict, 256)
+
+ cdict = {'red': ((0.0, 0.0, 0.0),
+ (1.0, 0.0, 0.0)),
+ 'green': ((0.0, 0.0, 0.0),
+ (1.0, 0.0, 0.0)),
+ 'blue': ((0.0, 0.0, 0.0),
+ (1.0, 1.0, 1.0))}
+ _CMAPS['blue'] = matplotlib.colors.LinearSegmentedColormap(
+ 'blue', cdict, 256)
+
+ # Temperature as defined in spslut
+ cdict = {'red': ((0.0, 0.0, 0.0),
+ (0.5, 0.0, 0.0),
+ (0.75, 1.0, 1.0),
+ (1.0, 1.0, 1.0)),
+ 'green': ((0.0, 0.0, 0.0),
+ (0.25, 1.0, 1.0),
+ (0.75, 1.0, 1.0),
+ (1.0, 0.0, 0.0)),
+ 'blue': ((0.0, 1.0, 1.0),
+ (0.25, 1.0, 1.0),
+ (0.5, 0.0, 0.0),
+ (1.0, 0.0, 0.0))}
+ # but limited to 256 colors for a faster display (of the colorbar)
+ _CMAPS['temperature'] = \
+ matplotlib.colors.LinearSegmentedColormap(
+ 'temperature', cdict, 256)
+
+ # reversed gray
+ cdict = {'red': ((0.0, 1.0, 1.0),
+ (1.0, 0.0, 0.0)),
+ 'green': ((0.0, 1.0, 1.0),
+ (1.0, 0.0, 0.0)),
+ 'blue': ((0.0, 1.0, 1.0),
+ (1.0, 0.0, 0.0))}
+
+ _CMAPS['reversed gray'] = \
+ matplotlib.colors.LinearSegmentedColormap(
+ 'yerg', cdict, 256)
+
+ if name in _CMAPS:
+ return _CMAPS[name]
+ elif name in _AVAILABLE_AS_RESOURCE:
+ filename = silx.resources.resource_filename("gui/colormaps/%s.npy" % name)
+ data = numpy.load(filename)
+ lut = ListedColormap(data, name=name)
+ _CMAPS[name] = lut
+ return lut
+ else:
+ # matplotlib built-in
+ return matplotlib.cm.get_cmap(name)
+
+
+@deprecated(since_version='0.10.0')
+def getScalarMappable(colormap, data=None):
+ """Returns matplotlib ScalarMappable corresponding to colormap
+
+ :param :class:`.Colormap` colormap: The colormap to convert
+ :param numpy.ndarray data:
+ The data on which the colormap is applied.
+ If provided, it is used to compute autoscale.
+ :return: matplotlib object corresponding to colormap
+ :rtype: matplotlib.cm.ScalarMappable
+ """
+ assert colormap is not None
+
+ if colormap.getName() is not None:
+ cmap = getColormap(colormap.getName())
+
+ else: # No name, use custom colors
+ if colormap.getColormapLUT() is None:
+ raise ValueError(
+ 'addImage: colormap no name nor list of colors.')
+ colors = colormap.getColormapLUT()
+ assert len(colors.shape) == 2
+ assert colors.shape[-1] in (3, 4)
+ if colors.dtype == numpy.uint8:
+ # Convert to float in [0., 1.]
+ colors = colors.astype(numpy.float32) / 255.
+ cmap = matplotlib.colors.ListedColormap(colors)
+
+ vmin, vmax = colormap.getColormapRange(data)
+ normalization = colormap.getNormalization()
+ if normalization == colormap.LOGARITHM:
+ norm = matplotlib.colors.LogNorm(vmin, vmax)
+ elif normalization == colormap.LINEAR:
+ norm = matplotlib.colors.Normalize(vmin, vmax)
+ else:
+ raise RuntimeError("Unsupported normalization: %s" % normalization)
+
+ return matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap)
+
+
+@deprecated(replacement='silx.colors.Colormap.applyToData',
+ since_version='0.8.0')
+def applyColormapToData(data, colormap):
+ """Apply a colormap to the data and returns the RGBA image
+
+ This supports data of any dimensions (not only of dimension 2).
+ The returned array will have one more dimension (with 4 entries)
+ than the input data to store the RGBA channels
+ corresponding to each bin in the array.
+
+ :param numpy.ndarray data: The data to convert.
+ :param :class:`.Colormap`: The colormap to apply
+ """
+ # Debian 7 specific support
+ # No transparent colormap with matplotlib < 1.2.0
+ # Add support for transparent colormap for uint8 data with
+ # colormap with 256 colors, linear norm, [0, 255] range
+ if matplotlib.__version__ < '1.2.0':
+ if (colormap.getName() is None and
+ colormap.getColormapLUT() is not None):
+ colors = colormap.getColormapLUT()
+ if (colors.shape[-1] == 4 and
+ not numpy.all(numpy.equal(colors[3], 255))):
+ # This is a transparent colormap
+ if (colors.shape == (256, 4) and
+ colormap.getNormalization() == 'linear' and
+ not colormap.isAutoscale() and
+ colormap.getVMin() == 0 and
+ colormap.getVMax() == 255 and
+ data.dtype == numpy.uint8):
+ # Supported case, convert data to RGBA
+ return colors[data.reshape(-1)].reshape(
+ data.shape + (4,))
+ else:
+ _logger.warning(
+ 'matplotlib %s does not support transparent '
+ 'colormap.', matplotlib.__version__)
+
+ scalarMappable = getScalarMappable(colormap, data)
+ rgbaImage = scalarMappable.to_rgba(data, bytes=True)
+
+ return rgbaImage
+
+
+@deprecated(replacement='silx.colors.Colormap.getSupportedColormaps',
+ since_version='0.10.0')
+def getSupportedColormaps():
+ """Get the supported colormap names as a tuple of str.
+ """
+ colormaps = set(matplotlib.cm.datad.keys())
+ colormaps.update(_AVAILABLE_AS_BUILTINS)
+ colormaps.update(_AVAILABLE_AS_RESOURCE)
+ return tuple(sorted(colormaps))
diff --git a/src/silx/gui/plot/matplotlib/__init__.py b/src/silx/gui/plot/matplotlib/__init__.py
new file mode 100644
index 0000000..e787240
--- /dev/null
+++ b/src/silx/gui/plot/matplotlib/__init__.py
@@ -0,0 +1,37 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "15/07/2020"
+
+from silx.utils.deprecation import deprecated_warning
+
+deprecated_warning(type_='module',
+ name=__file__,
+ replacement='silx.gui.utils.matplotlib',
+ since_version='0.14.0')
+
+from silx.gui.utils.matplotlib import FigureCanvasQTAgg # noqa
diff --git a/src/silx/gui/plot/setup.py b/src/silx/gui/plot/setup.py
new file mode 100644
index 0000000..e0b2c91
--- /dev/null
+++ b/src/silx/gui/plot/setup.py
@@ -0,0 +1,54 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "29/06/2017"
+
+
+from numpy.distutils.misc_util import Configuration
+
+
+def configuration(parent_package='', top_path=None):
+ config = Configuration('plot', parent_package, top_path)
+ config.add_subpackage('_utils')
+ config.add_subpackage('utils')
+ config.add_subpackage('matplotlib')
+ config.add_subpackage('stats')
+ config.add_subpackage('backends')
+ config.add_subpackage('backends.glutils')
+ config.add_subpackage('items')
+ config.add_subpackage('test')
+ config.add_subpackage('tools')
+ config.add_subpackage('tools.profile')
+ config.add_subpackage('tools.test')
+ config.add_subpackage('actions')
+
+ return config
+
+
+if __name__ == "__main__":
+ from numpy.distutils.core import setup
+
+ setup(configuration=configuration)
diff --git a/src/silx/gui/plot/stats/__init__.py b/src/silx/gui/plot/stats/__init__.py
new file mode 100644
index 0000000..04a5327
--- /dev/null
+++ b/src/silx/gui/plot/stats/__init__.py
@@ -0,0 +1,33 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+"""
+
+__authors__ = ["H. Payno"]
+__license__ = "MIT"
+__date__ = "07/03/2018"
+
+
+from .stats import *
diff --git a/src/silx/gui/plot/stats/stats.py b/src/silx/gui/plot/stats/stats.py
new file mode 100644
index 0000000..a81f7bb
--- /dev/null
+++ b/src/silx/gui/plot/stats/stats.py
@@ -0,0 +1,890 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides mechanism relative to stats calculation within a
+:class:`PlotWidget`.
+It also include the implementation of the statistics themselves.
+"""
+
+__authors__ = ["H. Payno"]
+__license__ = "MIT"
+__date__ = "06/06/2018"
+
+
+from collections import OrderedDict
+from functools import lru_cache
+import logging
+
+import numpy
+import numpy.ma
+
+from .. import items
+from ..CurvesROIWidget import ROI
+from ..items.roi import RegionOfInterest
+
+from ....math.combo import min_max
+from silx.utils.proxy import docstring
+from ....utils.deprecation import deprecated
+
+logger = logging.getLogger(__name__)
+
+
+class Stats(OrderedDict):
+ """Class to define a set of statistic relative to a dataset
+ (image, curve...).
+
+ The goal of this class is to avoid multiple recalculation of some
+ basic operations such as filtering data area where the statistics has to
+ be apply.
+ Min and max are also stored because they can be used several time.
+
+ :param List statslist: List of the :class:`Stat` object to be computed.
+ """
+ def __init__(self, statslist=None):
+ OrderedDict.__init__(self)
+ _statslist = statslist if not None else []
+ if statslist is not None:
+ for stat in _statslist:
+ self.add(stat)
+
+ def calculate(self, item, plot, onlimits, roi, data_changed=False,
+ roi_changed=False):
+ """
+ Call all :class:`Stat` object registered and return the result of the
+ computation.
+
+ :param item: the item for which we want statistics
+ :param plot: plot containing the item
+ :param bool onlimits: True if we want to apply statistic only on
+ visible data.
+ :param roi: region of interest for statistic calculation. Incompatible
+ with the `onlimits` option.
+ :type roi: Union[None, :class:`~_RegionOfInterestBase`]
+ :param bool data_changed: did the data changed since last calculation.
+ :param bool roi_changed: did the associated roi (if any) has changed
+ since last calculation.
+ :return dict: dictionary with :class:`Stat` name as ket and result
+ of the calculation as value
+ """
+ res = {}
+ context = self._getContext(item=item, plot=plot, onlimits=onlimits,
+ roi=roi)
+ for statName, stat in list(self.items()):
+ if context.kind not in stat.compatibleKinds:
+ logger.debug('kind %s not managed by statistic %s'
+ % (context.kind, stat.name))
+ res[statName] = None
+ else:
+ if roi_changed is True:
+ context.clear_mask()
+ if data_changed is True or roi_changed is True:
+ # if data changed or mask changed
+ context.clipData(item=item, plot=plot, onlimits=onlimits,
+ roi=roi)
+ # init roi and data
+ res[statName] = stat.calculate(context)
+ return res
+
+ def __setitem__(self, key, value):
+ assert isinstance(value, StatBase)
+ OrderedDict.__setitem__(self, key, value)
+
+ def add(self, stat):
+ """Add a :class:`Stat` to the set
+
+ :param Stat stat: stat to add to the set
+ """
+ self.__setitem__(key=stat.name, value=stat)
+
+ @staticmethod
+ @lru_cache(maxsize=50)
+ def _getContext(item, plot, onlimits, roi):
+ context = None
+ # Check for PlotWidget items
+ if isinstance(item, items.Curve):
+ context = _CurveContext(item, plot, onlimits, roi=roi)
+ elif isinstance(item, items.ImageData):
+ context = _ImageContext(item, plot, onlimits, roi=roi)
+ elif isinstance(item, items.Scatter):
+ context = _ScatterContext(item, plot, onlimits, roi=roi)
+ elif isinstance(item, items.Histogram):
+ context = _HistogramContext(item, plot, onlimits, roi=roi)
+ else:
+ # Check for SceneWidget items
+ from ...plot3d import items as items3d # Lazy import
+
+ if isinstance(item, (items3d.Scatter2D, items3d.Scatter3D)):
+ context = _plot3DScatterContext(item, plot, onlimits,
+ roi=roi)
+ elif isinstance(item,
+ (items3d.ImageData, items3d.ScalarField3D)):
+ context = _plot3DArrayContext(item, plot, onlimits,
+ roi=roi)
+ if context is None:
+ raise ValueError('Item type not managed')
+ return context
+
+
+class _StatsContext(object):
+ """
+ The context is designed to be a simple buffer and avoid repetition of
+ calculations that can appear during stats evaluation.
+
+ .. warning:: this class gives access to the data to be used for computation
+ . It deal with filtering data visible by the user on plot.
+ The filtering is a simple data sub-sampling. No interpolation
+ is made to fit data to boundaries.
+
+ :param item: the item for which we want to compute the context
+ :param str kind: the kind of the item
+ :param plot: the plot containing the item
+ :param bool onlimits: True if we want to apply statistic only on
+ visible data.
+ :param roi: Region of interest for computing the statistics.
+ For now, incompatible with `onlimits` calculation
+ :type roi: Union[None,:class:`_RegionOfInterestBase`]
+ """
+ def __init__(self, item, kind, plot, onlimits, roi):
+ assert item
+ assert plot
+ assert type(onlimits) is bool
+ self.kind = kind
+ self.min = None
+ self.max = None
+ self.data = None
+ self.roi = None
+ self.onlimits = onlimits
+
+ self.values = None
+ """The array of data with limit filtering if any. Is a numpy.ma.array,
+ meaning that it embed the mask applied by the roi if any"""
+
+ self.axes = None
+ """A list of array of position on each axis.
+
+ If the signal is an array,
+ then each axis has the length of that dimension,
+ and the order is (z, y, x) (i.e., as the array shape).
+ If the signal is not an array,
+ then each axis has the same length as the signal,
+ and the order is (x, y, z).
+ """
+
+ self.clipData(item, plot, onlimits, roi=roi)
+
+ def clear_mask(self):
+ """
+ Remove the mask to force recomputation of it on next iteration
+ :return:
+ """
+ raise NotImplementedError()
+
+ @property
+ def mask(self):
+ if self.values is not None:
+ assert isinstance(self.values, numpy.ma.MaskedArray)
+ return self.values.mask
+ else:
+ return None
+
+ @property
+ def is_mask_valid(self, **kwargs):
+ """Return if the mask is valid for the data or need to be recomputed"""
+ raise NotImplementedError("Base class")
+
+ def _set_mask_validity(self, **kwargs):
+ """User to set some values that allows to define the mask properties
+ and boundaries"""
+ raise NotImplementedError("Base class")
+
+ def clipData(self, item, plot, onlimits, roi):
+ """Clip the data to the current mask to have accurate statistics
+
+ Function called before computing each statistics associated to this
+ context. It will insure the context for the (item, plot, onlimits, roi)
+ is created.
+
+ :param item: item for which we want statistics
+ :param plot: plot containing the statistics
+ :param bool onlimits: True if we want to apply statistic only on
+ visible data.
+ :param roi: Region of interest for computing the statistics.
+ For now, incompatible with `onlimits` calculation
+ :type roi: Union[None,:class:`_RegionOfInterestBase`]
+ """
+ raise NotImplementedError("Base class")
+
+ @deprecated(reason="context are now stored and keep during stats life."
+ "So this function will be called only once",
+ replacement="clipData", since_version="0.13.0")
+ def createContext(self, item, plot, onlimits, roi):
+ return self.clipData(item=item, plot=plot, onlimits=onlimits,
+ roi=roi)
+
+ def isStructuredData(self):
+ """Returns True if data as an array-like structure.
+
+ :rtype: bool
+ """
+ if self.values is None or self.axes is None:
+ return False
+
+ if numpy.prod([len(axis) for axis in self.axes]) == self.values.size:
+ return True
+ else:
+ # Make sure there is the right number of value in axes
+ for axis in self.axes:
+ assert len(axis) == self.values.size
+ return False
+
+ def isScalarData(self):
+ """Returns True if data is a scalar.
+
+ :rtype: bool
+ """
+ if self.values is None or self.axes is None:
+ return False
+ if self.isStructuredData():
+ return len(self.axes) == self.values.ndim
+ else:
+ return self.values.ndim == 1
+
+ def _checkContextInputs(self, item, plot, onlimits, roi):
+ if roi is not None and onlimits is True:
+ raise ValueError('Stats context is unable to manage both a ROI'
+ 'and the `onlimits` option')
+
+
+class _ScatterCurveHistoMixInContext(_StatsContext):
+ def __init__(self, kind, item, plot, onlimits, roi):
+ self.clear_mask()
+ _StatsContext.__init__(self, item=item, kind=kind,
+ plot=plot, onlimits=onlimits, roi=roi)
+
+ def _set_mask_validity(self, onlimits, from_, to_):
+ self._onlimits = onlimits
+ self._from_ = from_
+ self._to_ = to_
+
+ def clear_mask(self):
+ self._onlimits = None
+ self._from_ = None
+ self._to_ = None
+
+ def is_mask_valid(self, onlimits, from_, to_):
+ return (onlimits == self.onlimits and from_ == self._from_ and
+ to_ == self._to_)
+
+
+class _CurveContext(_ScatterCurveHistoMixInContext):
+ """
+ StatsContext for :class:`Curve`
+
+ :param item: the item for which we want to compute the context
+ :param plot: the plot containing the item
+ :param bool onlimits: True if we want to apply statistic only on
+ visible data.
+ :param roi: Region of interest for computing the statistics.
+ For now, incompatible with `onlinits` calculation
+ :type roi: Union[None, :class:`ROI`]
+ """
+ def __init__(self, item, plot, onlimits, roi):
+ _ScatterCurveHistoMixInContext.__init__(self, kind='curve', item=item,
+ plot=plot, onlimits=onlimits,
+ roi=roi)
+
+ @docstring(_StatsContext)
+ def clipData(self, item, plot, onlimits, roi):
+ self._checkContextInputs(item=item, plot=plot, onlimits=onlimits,
+ roi=roi)
+ self.roi = roi
+ self.onlimits = onlimits
+ xData, yData = item.getData(copy=True)[0:2]
+
+ if onlimits:
+ minX, maxX = plot.getXAxis().getLimits()
+ if self.is_mask_valid(onlimits=onlimits, from_=minX, to_=maxX):
+ mask = self.mask
+ else:
+ mask = (minX <= xData) & (xData <= maxX)
+ mask = mask == 0
+ self._set_mask_validity(onlimits=onlimits, from_=minX, to_=maxX)
+ elif roi:
+ minX, maxX = roi.getFrom(), roi.getTo()
+ if self.is_mask_valid(onlimits=onlimits, from_=minX, to_=maxX):
+ mask = self.mask
+ else:
+ mask = (minX <= xData) & (xData <= maxX)
+ mask = mask == 0
+ self._set_mask_validity(onlimits=onlimits, from_=minX, to_=maxX)
+ else:
+ mask = numpy.zeros_like(yData)
+
+ mask = mask.astype(numpy.uint32)
+ self.xData = xData
+ self.yData = yData
+ self.values = numpy.ma.array(yData, mask=mask)
+ unmasked_data = self.values.compressed()
+ if len(unmasked_data) > 0:
+ self.min, self.max = min_max(unmasked_data)
+ else:
+ self.min, self.max = None, None
+ self.data = (xData, yData)
+ self.axes = (xData,)
+
+ def _checkContextInputs(self, item, plot, onlimits, roi):
+ _StatsContext._checkContextInputs(self, item=item, plot=plot,
+ onlimits=onlimits, roi=roi)
+ if roi is not None and not isinstance(roi, ROI):
+ raise TypeError('curve `context` can ony manage 1D roi')
+
+
+class _HistogramContext(_ScatterCurveHistoMixInContext):
+ """
+ StatsContext for :class:`Histogram`
+
+ :param item: the item for which we want to compute the context
+ :param plot: the plot containing the item
+ :param bool onlimits: True if we want to apply statistic only on
+ visible data.
+ :param roi: Region of interest for computing the statistics.
+ For now, incompatible with `onlinits` calculation
+ :type roi: Union[None, :class:`ROI`]
+ """
+ def __init__(self, item, plot, onlimits, roi):
+ _ScatterCurveHistoMixInContext.__init__(self, kind='histogram',
+ item=item, plot=plot,
+ onlimits=onlimits, roi=roi)
+
+ @docstring(_StatsContext)
+ def clipData(self, item, plot, onlimits, roi):
+ self._checkContextInputs(item=item, plot=plot, onlimits=onlimits,
+ roi=roi)
+ yData, edges = item.getData(copy=True)[0:2]
+ xData = item._revertComputeEdges(x=edges, histogramType=item.getAlignment())
+
+ if onlimits:
+ minX, maxX = plot.getXAxis().getLimits()
+ if self.is_mask_valid(onlimits=onlimits, from_=minX, to_=maxX):
+ mask = self.mask
+ else:
+ mask = (minX <= xData) & (xData <= maxX)
+ mask = mask == 0
+ self._set_mask_validity(onlimits=onlimits, from_=minX, to_=maxX)
+ elif roi:
+ if self.is_mask_valid(onlimits=onlimits, from_=roi._fromdata, to_=roi._todata):
+ mask = self.mask
+ else:
+ mask = (roi._fromdata <= xData) & (xData <= roi._todata)
+ mask = mask == 0
+ self._set_mask_validity(onlimits=onlimits, from_=roi._fromdata,
+ to_=roi._todata)
+ else:
+ mask = numpy.zeros_like(yData)
+ mask = mask.astype(numpy.uint32)
+ self.xData = xData
+ self.yData = yData
+ self.values = numpy.ma.array(yData, mask=(mask))
+ unmasked_data = self.values.compressed()
+ if len(unmasked_data) > 0:
+ self.min, self.max = min_max(unmasked_data)
+ else:
+ self.min, self.max = None, None
+ self.data = (self.xData, self.yData)
+ self.axes = (self.xData,)
+
+ def _checkContextInputs(self, item, plot, onlimits, roi):
+ _StatsContext._checkContextInputs(self, item=item, plot=plot,
+ onlimits=onlimits, roi=roi)
+
+ if roi is not None and not isinstance(roi, ROI):
+ raise TypeError('curve `context` can ony manage 1D roi')
+
+
+class _ScatterContext(_ScatterCurveHistoMixInContext):
+ """StatsContext scatter plots.
+
+ It supports :class:`~silx.gui.plot.items.Scatter`.
+
+ :param item: the item for which we want to compute the context
+ :param plot: the plot containing the item
+ :param bool onlimits: True if we want to apply statistic only on
+ visible data.
+ :param roi: Region of interest for computing the statistics.
+ For now, incompatible with `onlinits` calculation
+ :type roi: Union[None, :class:`ROI`]
+ """
+ def __init__(self, item, plot, onlimits, roi):
+ _ScatterCurveHistoMixInContext.__init__(self, kind='scatter',
+ item=item, plot=plot,
+ onlimits=onlimits, roi=roi)
+
+ @docstring(_ScatterCurveHistoMixInContext)
+ def clipData(self, item, plot, onlimits, roi):
+ self._checkContextInputs(item=item, plot=plot, onlimits=onlimits,
+ roi=roi)
+ valueData = item.getValueData(copy=True)
+ xData = item.getXData(copy=True)
+ yData = item.getYData(copy=True)
+
+ if onlimits:
+ minX, maxX = plot.getXAxis().getLimits()
+ minY, maxY = plot.getYAxis().getLimits()
+
+ # filter on X axis
+ valueData = valueData[(minX <= xData) & (xData <= maxX)]
+ yData = yData[(minX <= xData) & (xData <= maxX)]
+ xData = xData[(minX <= xData) & (xData <= maxX)]
+ # filter on Y axis
+ valueData = valueData[(minY <= yData) & (yData <= maxY)]
+ xData = xData[(minY <= yData) & (yData <= maxY)]
+ yData = yData[(minY <= yData) & (yData <= maxY)]
+
+ if roi:
+ if self.is_mask_valid(onlimits=onlimits, from_=roi.getFrom(),
+ to_=roi.getTo()):
+ mask = self.mask
+ else:
+ mask = (xData < roi.getFrom()) | (xData > roi.getTo())
+ else:
+ mask = numpy.zeros_like(xData)
+
+ self.data = (xData, yData, valueData)
+ self.values = numpy.ma.array(valueData, mask=mask)
+ self.axes = (xData, yData)
+
+ unmasked_values = self.values.compressed()
+ if len(unmasked_values) > 0:
+ self.min, self.max = min_max(unmasked_values)
+ else:
+ self.min, self.max = None, None
+
+ def _checkContextInputs(self, item, plot, onlimits, roi):
+ _StatsContext._checkContextInputs(self, item=item, plot=plot,
+ onlimits=onlimits, roi=roi)
+
+ if roi is not None and not isinstance(roi, ROI):
+ raise TypeError('curve `context` can ony manage 1D roi')
+
+
+class _ImageContext(_StatsContext):
+ """StatsContext for images.
+
+ It supports :class:`~silx.gui.plot.items.ImageData`.
+
+ :warning: behaviour of scale images: now the statistics are computed on
+ the entire data array (there is no sampling in the array or
+ interpolation regarding the scale).
+ This also mean that the result can differ from what is displayed.
+ But I guess there is no perfect behaviour.
+
+ :warning: `isIn` functions for image context: for now have basically a
+ binary approach, the pixel is in a roi or not. To have a fully
+ 'correct behaviour' we should add a weight on stats calculation
+ to moderate the pixel value.
+
+ :param item: the item for which we want to compute the context
+ :param plot: the plot containing the item
+ :param bool onlimits: True if we want to apply statistic only on
+ visible data.
+ :param roi: Region of interest for computing the statistics.
+ For now, incompatible with `onlinits` calculation
+ :type roi: Union[None, :class:`ROI`]
+ """
+ def __init__(self, item, plot, onlimits, roi):
+ self.clear_mask()
+ _StatsContext.__init__(self, kind='image', item=item,
+ plot=plot, onlimits=onlimits, roi=roi)
+
+ def _set_mask_validity(self, xmin: float, xmax: float, ymin: float, ymax
+ : float):
+ self._mask_x_min = xmin
+ self._mask_x_max = xmax
+ self._mask_y_min = ymin
+ self._mask_y_max = ymax
+
+ def clear_mask(self):
+ self._mask_x_min = None
+ self._mask_x_max = None
+ self._mask_y_min = None
+ self._mask_y_max = None
+
+ def is_mask_valid(self, xmin, xmax, ymin, ymax):
+ return (xmin == self._mask_x_min and xmax == self._mask_x_max and
+ ymin == self._mask_y_min and ymax == self._mask_y_max)
+
+ @docstring(_StatsContext)
+ def clipData(self, item, plot, onlimits, roi):
+ self._checkContextInputs(item=item, plot=plot, onlimits=onlimits,
+ roi=roi)
+ self.origin = item.getOrigin()
+ self.scale = item.getScale()
+
+ self.data = item.getData(copy=True)
+ mask = numpy.zeros_like(self.data)
+ """mask use to know of the stat should be count in or not"""
+
+ if onlimits:
+ minX, maxX = plot.getXAxis().getLimits()
+ minY, maxY = plot.getYAxis().getLimits()
+
+ XMinBound = int((minX - self.origin[0]) / self.scale[0])
+ YMinBound = int((minY - self.origin[1]) / self.scale[1])
+ XMaxBound = int((maxX - self.origin[0]) / self.scale[0])
+ YMaxBound = int((maxY - self.origin[1]) / self.scale[1])
+
+ XMinBound = max(XMinBound, 0)
+ YMinBound = max(YMinBound, 0)
+
+ if onlimits:
+ if XMaxBound <= XMinBound or YMaxBound <= YMinBound:
+ self.data = None
+ else:
+ self.data = self.data[YMinBound:YMaxBound + 1,
+ XMinBound:XMaxBound + 1]
+ mask = numpy.zeros_like(self.data)
+ elif roi:
+ minX, maxX = 0, self.data.shape[1]
+ minY, maxY = 0, self.data.shape[0]
+
+ XMinBound = max(minX, 0)
+ YMinBound = max(minY, 0)
+ XMaxBound = min(maxX, self.data.shape[1])
+ YMaxBound = min(maxY, self.data.shape[0])
+
+ if self.is_mask_valid(xmin=XMinBound, xmax=XMaxBound,
+ ymin=YMinBound, ymax=YMaxBound):
+ mask = self.mask
+ else:
+ for x in range(XMinBound, XMaxBound):
+ for y in range(YMinBound, YMaxBound):
+ _x = (x * self.scale[0]) + self.origin[0]
+ _y = (y * self.scale[1]) + self.origin[1]
+ mask[y, x] = not roi.contains((_x, _y))
+ self._set_mask_validity(xmin=XMinBound, xmax=XMaxBound,
+ ymin=YMinBound, ymax=YMaxBound)
+ self.values = numpy.ma.array(self.data, mask=mask)
+ if self.values.compressed().size > 0:
+ self.min, self.max = min_max(self.values.compressed())
+ else:
+ self.min, self.max = None, None
+
+ if self.values is not None:
+ self.axes = (self.origin[1] + self.scale[1] * numpy.arange(self.data.shape[0]),
+ self.origin[0] + self.scale[0] * numpy.arange(self.data.shape[1]))
+
+ def _checkContextInputs(self, item, plot, onlimits, roi):
+ _StatsContext._checkContextInputs(self, item=item, plot=plot,
+ onlimits=onlimits, roi=roi)
+
+ if roi is not None and not isinstance(roi, RegionOfInterest):
+ raise TypeError('curve `context` can ony manage 2D roi')
+
+
+class _plot3DScatterContext(_StatsContext):
+ """StatsContext for 3D scatter plots.
+
+ It supports :class:`~silx.gui.plot3d.items.Scatter2D` and
+ :class:`~silx.gui.plot3d.items.Scatter3D`.
+
+ :param item: the item for which we want to compute the context
+ :param plot: the plot containing the item
+ :param bool onlimits: True if we want to apply statistic only on
+ visible data.
+ :param roi: Region of interest for computing the statistics.
+ For now, incompatible with `onlinits` calculation
+ :type roi: Union[None, :class:`ROI`]
+ """
+ def __init__(self, item, plot, onlimits, roi):
+ _StatsContext.__init__(self, kind='scatter', item=item, plot=plot,
+ onlimits=onlimits, roi=roi)
+
+ @docstring(_StatsContext)
+ def clipData(self, item, plot, onlimits, roi):
+ self._checkContextInputs(item=item, plot=plot, onlimits=onlimits,
+ roi=roi)
+ if onlimits:
+ raise RuntimeError("Unsupported plot %s" % str(plot))
+ values = item.getValueData(copy=False)
+ if roi:
+ logger.warning("Roi are unsupported on volume for now")
+ mask = numpy.zeros_like(values)
+ else:
+ mask = numpy.zeros_like(values)
+
+ if values is not None and len(values) > 0:
+ self.values = values
+ axes = [item.getXData(copy=False), item.getYData(copy=False)]
+ if self.values.ndim == 3:
+ axes.append(item.getZData(copy=False))
+ self.axes = tuple(axes)
+ self.min, self.max = min_max(self.values)
+ self.values = numpy.ma.array(self.values, mask=mask)
+ else:
+ self.values = None
+ self.axes = None
+ self.min, self.max = None, None
+
+ def _checkContextInputs(self, item, plot, onlimits, roi):
+ _StatsContext._checkContextInputs(self, item=item, plot=plot,
+ onlimits=onlimits, roi=roi)
+
+ if roi is not None and not isinstance(roi, RegionOfInterest):
+ raise TypeError('curve `context` can ony manage 2D roi')
+
+
+class _plot3DArrayContext(_StatsContext):
+ """StatsContext for 3D scalar field and data image.
+
+ It supports :class:`~silx.gui.plot3d.items.ScalarField3D` and
+ :class:`~silx.gui.plot3d.items.ImageData`.
+
+ :param item: the item for which we want to compute the context
+ :param plot: the plot containing the item
+ :param bool onlimits: True if we want to apply statistic only on
+ visible data.
+ :param roi: Region of interest for computing the statistics.
+ For now, incompatible with `onlinits` calculation
+ :type roi: Union[None, :class:`ROI`]
+ """
+ def __init__(self, item, plot, onlimits, roi):
+ _StatsContext.__init__(self, kind='image', item=item, plot=plot,
+ onlimits=onlimits, roi=roi)
+
+ @docstring(_StatsContext)
+ def clipData(self, item, plot, onlimits, roi):
+ self._checkContextInputs(item=item, plot=plot, onlimits=onlimits,
+ roi=roi)
+ if onlimits:
+ raise RuntimeError("Unsupported plot %s" % str(plot))
+
+ values = item.getData(copy=False)
+ if roi:
+ logger.warning("Roi are unsuported on volume for now")
+ mask = numpy.zeros_like(values)
+ else:
+ mask = numpy.zeros_like(values)
+
+ if values is not None and len(values) > 0:
+ self.values = values
+ self.axes = tuple([numpy.arange(size) for size in self.values.shape])
+ self.min, self.max = min_max(self.values)
+ self.values = numpy.ma.array(self.values, mask=mask)
+ else:
+ self.values = None
+ self.axes = None
+ self.min, self.max = None, None
+
+ def _checkContextInputs(self, item, plot, onlimits, roi):
+ _StatsContext._checkContextInputs(self, item=item, plot=plot,
+ onlimits=onlimits, roi=roi)
+
+ if roi is not None and not isinstance(roi, RegionOfInterest):
+ raise TypeError('curve `context` can ony manage 2D roi')
+
+
+BASIC_COMPATIBLE_KINDS = 'curve', 'image', 'scatter', 'histogram'
+
+
+class StatBase(object):
+ """
+ Base class for defining a statistic.
+
+ :param str name: the name of the statistic. Must be unique.
+ :param List[str] compatibleKinds:
+ The kind of items (curve, scatter...) for which the statistic apply.
+ """
+ def __init__(self, name, compatibleKinds=BASIC_COMPATIBLE_KINDS, description=None):
+ self.name = name
+ self.compatibleKinds = compatibleKinds
+ self.description = description
+
+ def calculate(self, context):
+ """
+ compute the statistic for the given :class:`StatsContext`
+
+ :param _StatsContext context:
+ :return dict: key is stat name, statistic computed is the dict value
+ """
+ raise NotImplementedError('Base class')
+
+ def getToolTip(self, kind):
+ """
+ If necessary add a tooltip for a stat kind
+
+ :param str kind: the kind of item the statistic is compute for.
+ :return: tooltip or None if no tooltip
+ """
+ return None
+
+
+class Stat(StatBase):
+ """
+ Create a StatBase class based on a function pointer.
+
+ :param str name: name of the statistic. Used as id
+ :param fct: function which should have as unique mandatory parameter the
+ data. Should be able to adapt to all `kinds` defined as
+ compatible
+ :param tuple kinds: the compatible item kinds of the function (curve,
+ image...)
+ """
+ def __init__(self, name, fct, kinds=BASIC_COMPATIBLE_KINDS):
+ StatBase.__init__(self, name, kinds)
+ self._fct = fct
+
+ @docstring(StatBase)
+ def calculate(self, context):
+ if context.values is not None:
+ if context.kind in self.compatibleKinds:
+ return self._fct(context.values)
+ else:
+ raise ValueError('Kind %s not managed by %s'
+ '' % (context.kind, self.name))
+ else:
+ return None
+
+
+class StatMin(StatBase):
+ """Compute the minimal value on data"""
+ def __init__(self):
+ StatBase.__init__(self, name='min')
+
+ @docstring(StatBase)
+ def calculate(self, context):
+ return context.min
+
+
+class StatMax(StatBase):
+ """Compute the maximal value on data"""
+ def __init__(self):
+ StatBase.__init__(self, name='max')
+
+ @docstring(StatBase)
+ def calculate(self, context):
+ return context.max
+
+
+class StatDelta(StatBase):
+ """Compute the delta between minimal and maximal on data"""
+ def __init__(self):
+ StatBase.__init__(self, name='delta')
+
+ @docstring(StatBase)
+ def calculate(self, context):
+ return context.max - context.min
+
+
+class _StatCoord(StatBase):
+ """Base class for argmin and argmax stats"""
+
+ def _indexToCoordinates(self, context, index):
+ """Returns the coordinates of data point at given index
+
+ If data is an array, coordinates are in reverse order from data shape.
+
+ :param _StatsContext context:
+ :param int index: Index in the flattened data array
+ :rtype: List[int]
+ """
+
+ axes = context.axes
+
+ if context.isStructuredData() or context.roi:
+ coordinates = []
+ for axis in reversed(axes):
+ coordinates.append(axis[index % len(axis)])
+ index = index // len(axis)
+ return tuple(coordinates)
+ else:
+ return tuple(axis[index] for axis in axes)
+
+
+class StatCoordMin(_StatCoord):
+ """Compute the coordinates of the first minimum value of the data"""
+ def __init__(self):
+ _StatCoord.__init__(self, name='coords min')
+
+ @docstring(StatBase)
+ def calculate(self, context):
+ if context.values is None or not context.isScalarData():
+ return None
+
+ index = context.values.argmin()
+ return self._indexToCoordinates(context, index)
+
+ @docstring(StatBase)
+ def getToolTip(self, kind):
+ return "Coordinates of the first minimum value of the data"
+
+
+class StatCoordMax(_StatCoord):
+ """Compute the coordinates of the first maximum value of the data"""
+ def __init__(self):
+ _StatCoord.__init__(self, name='coords max')
+
+ @docstring(StatBase)
+ def calculate(self, context):
+ if context.values is None or not context.isScalarData():
+ return None
+
+ # TODO: the values should be a mask array by default, will be simpler
+ # if possible
+ index = context.values.argmax()
+ return self._indexToCoordinates(context, index)
+
+ @docstring(StatBase)
+ def getToolTip(self, kind):
+ return "Coordinates of the first maximum value of the data"
+
+
+class StatCOM(StatBase):
+ """Compute data center of mass"""
+ def __init__(self):
+ StatBase.__init__(self, name='COM', description='Center of mass')
+
+ @docstring(StatBase)
+ def calculate(self, context):
+ if context.values is None or not context.isScalarData():
+ return None
+
+ values = numpy.ma.array(context.values, mask=context.mask, dtype=numpy.float64)
+ sum_ = numpy.sum(values)
+ if sum_ == 0.:
+ return (numpy.nan,) * len(context.axes)
+
+ if context.isStructuredData():
+ centerofmass = []
+ for index, axis in enumerate(context.axes):
+ axes = tuple([i for i in range(len(context.axes)) if i != index])
+ centerofmass.append(
+ numpy.sum(axis * numpy.sum(values, axis=axes)) / sum_)
+ return tuple(reversed(centerofmass))
+ else:
+ return tuple(
+ numpy.sum(axis * values) / sum_ for axis in context.axes)
+
+ @docstring(StatBase)
+ def getToolTip(self, kind):
+ return "Compute the center of mass of the dataset"
diff --git a/src/silx/gui/plot/stats/statshandler.py b/src/silx/gui/plot/stats/statshandler.py
new file mode 100644
index 0000000..17578d8
--- /dev/null
+++ b/src/silx/gui/plot/stats/statshandler.py
@@ -0,0 +1,202 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module containts the classes relative to the management of statistics
+display.
+"""
+
+__authors__ = ["H. Payno"]
+__license__ = "MIT"
+__date__ = "05/06/2018"
+
+
+import logging
+
+from silx.gui import qt
+from silx.gui.plot import stats as statsmdl
+
+logger = logging.getLogger(__name__)
+
+
+class _FloatItem(qt.QTableWidgetItem):
+ """Simple QTableWidgetItem allowing ordering on floats"""
+
+ def __init__(self, type=qt.QTableWidgetItem.Type):
+ qt.QTableWidgetItem.__init__(self, type=type)
+
+ def __lt__(self, other):
+ self_values = self.text().lstrip('(').rstrip(')').split(',')
+ other_values = other.text().lstrip('(').rstrip(')').split(',')
+ for self_value, other_value in zip(self_values, other_values):
+ f_self_value = float(self_value)
+ f_other_value = float(other_value)
+ if f_self_value != f_other_value:
+ return f_self_value < f_other_value
+ return False
+
+
+class StatFormatter(object):
+ """
+ Class used to apply format on :class:`Stat`
+
+ :param formatter: the formatter. Defined as str.format()
+ :param qItemClass: the class inheriting from :class:`QTableWidgetItem`
+ which will be used to display the result of the
+ statistic computation.
+ """
+ DEFAULT_FORMATTER = '{0:.3f}'
+
+ def __init__(self, formatter=DEFAULT_FORMATTER, qItemClass=_FloatItem):
+ self.formatter = formatter
+ self.tabWidgetItemClass = qItemClass
+
+ def format(self, val):
+ if self.formatter is None or val is None:
+ return str(val)
+ else:
+ return self.formatter.format(val)
+
+
+class StatsHandler(object):
+ """
+ Give
+ create:
+
+ * Stats object which will manage the statistic computation
+ * Associate formatter and :class:`Stat`
+
+ :param statFormatters: Stat and optional formatter.
+ If elements are given as a tuple, elements
+ should be (:class:`Stat`, formatter).
+ Otherwise should be :class:`Stat` elements.
+ :rtype: List or tuple
+ """
+
+ def __init__(self, statFormatters):
+ self.stats = statsmdl.Stats()
+ self.formatters = {}
+ for elmt in statFormatters:
+ stat, formatter = self._processStatArgument(elmt)
+ self.add(stat=stat, formatter=formatter)
+
+ @staticmethod
+ def _processStatArgument(arg):
+ """Process an element of the init arguments
+
+ :param arg: The argument to process
+ :return: Corresponding (StatBase, StatFormatter)
+ """
+ stat, formatter = None, None
+
+ if isinstance(arg, statsmdl.StatBase):
+ stat = arg
+ else:
+ assert len(arg) > 0
+ if isinstance(arg[0], statsmdl.StatBase):
+ stat = arg[0]
+ if len(arg) > 2:
+ raise ValueError('To many argument with %s. At most one '
+ 'argument can be associated with the '
+ 'BaseStat (the `StatFormatter`')
+ if len(arg) == 2:
+ assert arg[1] is None or isinstance(arg[1], (StatFormatter, str))
+ formatter = arg[1]
+ else:
+ if isinstance(arg[0], tuple):
+ if len(arg) > 1:
+ formatter = arg[1]
+ arg = arg[0]
+
+ if type(arg[0]) is not str:
+ raise ValueError('first element of the tuple should be a string'
+ ' or a StatBase instance')
+ if len(arg) == 1:
+ raise ValueError('A function should be associated with the'
+ 'stat name')
+ if len(arg) > 3:
+ raise ValueError('Two much argument given for defining statistic.'
+ 'Take at most three arguments (name, function, '
+ 'kinds)')
+ if len(arg) == 2:
+ stat = statsmdl.Stat(name=arg[0], fct=arg[1])
+ else:
+ stat = statsmdl.Stat(name=arg[0], fct=arg[1], kinds=arg[2])
+
+ return stat, formatter
+
+ def add(self, stat, formatter=None):
+ """Add a stat to the list.
+
+ :param StatBase stat:
+ :param Union[None,StatFormatter] formatter:
+ """
+ assert isinstance(stat, statsmdl.StatBase)
+ self.stats.add(stat)
+ _formatter = formatter
+ if type(_formatter) is str:
+ _formatter = StatFormatter(formatter=_formatter)
+ self.formatters[stat.name] = _formatter
+
+ def format(self, name, val):
+ """Apply the format for the `name` statistic and the given value
+
+ :param str name: the name of the associated statistic
+ :param val: value before formatting
+ :return: formatted value
+ """
+ if name not in self.formatters:
+ logger.warning("statistic %s haven't been registred" % name)
+ return val
+ else:
+ if self.formatters[name] is None:
+ return str(val)
+ else:
+ if isinstance(val, (tuple, list)):
+ res = []
+ [res.append(self.formatters[name].format(_val)) for _val in val]
+ return ', '.join(res)
+ else:
+ return self.formatters[name].format(val)
+
+ def calculate(self, item, plot, onlimits, roi=None, data_changed=False,
+ roi_changed=False):
+ """
+ compute all statistic registered and return the list of formatted
+ statistics result.
+
+ :param item: item for which we want to compute statistics
+ :param plot: plot containing the item
+ :param onlimits: True if we want to compute statistics on visible data
+ only
+ :type: bool
+ :param roi: region of interest for statistic calculation
+ :type: Union[None,:class:`_RegionOfInterestBase`]
+ :return: list of formatted statistics (as str)
+ :rtype: dict
+ """
+ res = self.stats.calculate(item, plot, onlimits, roi,
+ data_changed=data_changed, roi_changed=roi_changed)
+ for resName, resValue in list(res.items()):
+ res[resName] = self.format(resName, res[resName])
+ return res
diff --git a/src/silx/gui/plot/test/__init__.py b/src/silx/gui/plot/test/__init__.py
new file mode 100644
index 0000000..3ad225d
--- /dev/null
+++ b/src/silx/gui/plot/test/__init__.py
@@ -0,0 +1,24 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
diff --git a/src/silx/gui/plot/test/testAlphaSlider.py b/src/silx/gui/plot/test/testAlphaSlider.py
new file mode 100644
index 0000000..ca57bf5
--- /dev/null
+++ b/src/silx/gui/plot/test/testAlphaSlider.py
@@ -0,0 +1,204 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Tests for ImageAlphaSlider"""
+
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "28/03/2017"
+
+import numpy
+import unittest
+
+from silx.gui import qt
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui.plot import PlotWidget
+from silx.gui.plot import AlphaSlider
+
+
+class TestActiveImageAlphaSlider(TestCaseQt):
+ def setUp(self):
+ super(TestActiveImageAlphaSlider, self).setUp()
+ self.plot = PlotWidget()
+ self.aslider = AlphaSlider.ActiveImageAlphaSlider(plot=self.plot)
+ self.aslider.setOrientation(qt.Qt.Horizontal)
+
+ toolbar = qt.QToolBar("plot", self.plot)
+ toolbar.addWidget(self.aslider)
+ self.plot.addToolBar(toolbar)
+
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ self.mouseMove(self.plot) # Move to center
+ self.qapp.processEvents()
+
+ def tearDown(self):
+ self.qapp.processEvents()
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ del self.plot
+ del self.aslider
+
+ super(TestActiveImageAlphaSlider, self).tearDown()
+
+ def testWidgetEnabled(self):
+ # no active image initially, slider must be deactivate
+ self.assertFalse(self.aslider.isEnabled())
+
+ self.plot.addImage(numpy.array([[0, 1, 2], [3, 4, 5]]))
+ # now we have an active image
+ self.assertTrue(self.aslider.isEnabled())
+
+ self.plot.setActiveImage(None)
+ self.assertFalse(self.aslider.isEnabled())
+
+ def testGetImage(self):
+ self.plot.addImage(numpy.array([[0, 1, 2], [3, 4, 5]]))
+ self.assertEqual(self.plot.getActiveImage(),
+ self.aslider.getItem())
+
+ self.plot.addImage(numpy.array([[0, 1, 3], [2, 4, 6]]), legend="2")
+ self.plot.setActiveImage("2")
+ self.assertEqual(self.plot.getImage("2"),
+ self.aslider.getItem())
+
+ def testGetAlpha(self):
+ self.plot.addImage(numpy.array([[0, 1, 2], [3, 4, 5]]), legend="1")
+ self.aslider.setValue(137)
+ self.assertAlmostEqual(self.aslider.getAlpha(),
+ 137. / 255)
+
+
+class TestNamedImageAlphaSlider(TestCaseQt):
+ def setUp(self):
+ super(TestNamedImageAlphaSlider, self).setUp()
+ self.plot = PlotWidget()
+ self.aslider = AlphaSlider.NamedImageAlphaSlider(plot=self.plot)
+ self.aslider.setOrientation(qt.Qt.Horizontal)
+
+ toolbar = qt.QToolBar("plot", self.plot)
+ toolbar.addWidget(self.aslider)
+ self.plot.addToolBar(toolbar)
+
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ self.mouseMove(self.plot) # Move to center
+ self.qapp.processEvents()
+
+ def tearDown(self):
+ self.qapp.processEvents()
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ del self.plot
+ del self.aslider
+
+ super(TestNamedImageAlphaSlider, self).tearDown()
+
+ def testWidgetEnabled(self):
+ # no image set initially, slider must be deactivate
+ self.assertFalse(self.aslider.isEnabled())
+
+ self.plot.addImage(numpy.array([[0, 1, 2], [3, 4, 5]]), legend="1")
+ self.aslider.setLegend("1")
+ # now we have an image set
+ self.assertTrue(self.aslider.isEnabled())
+
+ def testGetImage(self):
+ self.plot.addImage(numpy.array([[0, 1, 2], [3, 4, 5]]), legend="1")
+ self.plot.addImage(numpy.array([[0, 1, 3], [2, 4, 6]]), legend="2")
+ self.aslider.setLegend("1")
+ self.assertEqual(self.plot.getImage("1"),
+ self.aslider.getItem())
+
+ self.aslider.setLegend("2")
+ self.assertEqual(self.plot.getImage("2"),
+ self.aslider.getItem())
+
+ def testGetAlpha(self):
+ self.plot.addImage(numpy.array([[0, 1, 2], [3, 4, 5]]), legend="1")
+ self.aslider.setLegend("1")
+ self.aslider.setValue(128)
+ self.assertAlmostEqual(self.aslider.getAlpha(),
+ 128. / 255)
+
+
+class TestNamedScatterAlphaSlider(TestCaseQt):
+ def setUp(self):
+ super(TestNamedScatterAlphaSlider, self).setUp()
+ self.plot = PlotWidget()
+ self.aslider = AlphaSlider.NamedScatterAlphaSlider(plot=self.plot)
+ self.aslider.setOrientation(qt.Qt.Horizontal)
+
+ toolbar = qt.QToolBar("plot", self.plot)
+ toolbar.addWidget(self.aslider)
+ self.plot.addToolBar(toolbar)
+
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ self.mouseMove(self.plot) # Move to center
+ self.qapp.processEvents()
+
+ def tearDown(self):
+ self.qapp.processEvents()
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ del self.plot
+ del self.aslider
+
+ super(TestNamedScatterAlphaSlider, self).tearDown()
+
+ def testWidgetEnabled(self):
+ # no Scatter set initially, slider must be deactivate
+ self.assertFalse(self.aslider.isEnabled())
+
+ self.plot.addScatter([0, 1, 2], [2, 3, 4], [5, 6, 7],
+ legend="1")
+ self.aslider.setLegend("1")
+ # now we have an image set
+ self.assertTrue(self.aslider.isEnabled())
+
+ def testGetScatter(self):
+ self.plot.addScatter([0, 1, 2], [2, 3, 4], [5, 6, 7],
+ legend="1")
+ self.plot.addScatter([0, 10, 20], [20, 30, 40], [50, 60, 70],
+ legend="2")
+ self.aslider.setLegend("1")
+ self.assertEqual(self.plot.getScatter("1"),
+ self.aslider.getItem())
+
+ self.aslider.setLegend("2")
+ self.assertEqual(self.plot.getScatter("2"),
+ self.aslider.getItem())
+
+ def testGetAlpha(self):
+ self.plot.addScatter([0, 10, 20], [20, 30, 40], [50, 60, 70],
+ legend="1")
+ self.aslider.setLegend("1")
+ self.aslider.setValue(128)
+ self.assertAlmostEqual(self.aslider.getAlpha(),
+ 128. / 255)
diff --git a/src/silx/gui/plot/test/testColorBar.py b/src/silx/gui/plot/test/testColorBar.py
new file mode 100644
index 0000000..3dc8ff1
--- /dev/null
+++ b/src/silx/gui/plot/test/testColorBar.py
@@ -0,0 +1,340 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for ColorBar featues and sub widgets of Colorbar module"""
+
+__authors__ = ["H. Payno"]
+__license__ = "MIT"
+__date__ = "24/04/2018"
+
+import unittest
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui.plot.ColorBar import _ColorScale
+from silx.gui.plot.ColorBar import ColorBarWidget
+from silx.gui.colors import Colormap
+from silx.math.colormap import LinearNormalization, LogarithmicNormalization
+from silx.gui.plot import Plot2D
+from silx.gui import qt
+import numpy
+
+
+class TestColorScale(TestCaseQt):
+ """Test that interaction with the colorScale is correct"""
+ def setUp(self):
+ super(TestColorScale, self).setUp()
+ self.colorScaleWidget = _ColorScale(colormap=None, parent=None)
+ self.colorScaleWidget.show()
+ self.qWaitForWindowExposed(self.colorScaleWidget)
+
+ def tearDown(self):
+ self.qapp.processEvents()
+ self.colorScaleWidget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.colorScaleWidget.close()
+ del self.colorScaleWidget
+ super(TestColorScale, self).tearDown()
+
+ def testNoColormap(self):
+ """Test _ColorScale without a colormap"""
+ colormap = self.colorScaleWidget.getColormap()
+ self.assertIsNone(colormap)
+
+ def testRelativePositionLinear(self):
+ self.colorMapLin1 = Colormap(name='gray',
+ normalization=Colormap.LINEAR,
+ vmin=0.0,
+ vmax=1.0)
+ self.colorScaleWidget.setColormap(self.colorMapLin1)
+
+ self.assertTrue(
+ self.colorScaleWidget.getValueFromRelativePosition(0.25) == 0.25)
+ self.assertTrue(
+ self.colorScaleWidget.getValueFromRelativePosition(0.5) == 0.5)
+ self.assertTrue(
+ self.colorScaleWidget.getValueFromRelativePosition(1.0) == 1.0)
+
+ self.colorMapLin2 = Colormap(name='viridis',
+ normalization=Colormap.LINEAR,
+ vmin=-10,
+ vmax=0)
+ self.colorScaleWidget.setColormap(self.colorMapLin2)
+
+ self.assertTrue(
+ self.colorScaleWidget.getValueFromRelativePosition(0.25) == -7.5)
+ self.assertTrue(
+ self.colorScaleWidget.getValueFromRelativePosition(0.5) == -5.0)
+ self.assertTrue(
+ self.colorScaleWidget.getValueFromRelativePosition(1.0) == 0.0)
+
+ def testRelativePositionLog(self):
+ self.colorMapLog1 = Colormap(name='temperature',
+ normalization=Colormap.LOGARITHM,
+ vmin=1.0,
+ vmax=100.0)
+
+ self.colorScaleWidget.setColormap(self.colorMapLog1)
+
+ val = self.colorScaleWidget.getValueFromRelativePosition(1.0)
+ self.assertAlmostEqual(val, 100.0)
+
+ val = self.colorScaleWidget.getValueFromRelativePosition(0.5)
+ self.assertAlmostEqual(val, 10.0)
+
+ val = self.colorScaleWidget.getValueFromRelativePosition(0.0)
+ self.assertTrue(val == 1.0)
+
+
+class TestNoAutoscale(TestCaseQt):
+ """Test that ticks and color displayed are correct in the case of a colormap
+ with no autoscale
+ """
+
+ def setUp(self):
+ super(TestNoAutoscale, self).setUp()
+ self.plot = Plot2D()
+ self.colorBar = self.plot.getColorBarWidget()
+ self.colorBar.setVisible(True) # Makes sure the colormap is visible
+ self.tickBar = self.colorBar.getColorScaleBar().getTickBar()
+ self.colorScale = self.colorBar.getColorScaleBar().getColorScale()
+
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ def tearDown(self):
+ self.qapp.processEvents()
+ self.tickBar = None
+ self.colorScale = None
+ del self.colorBar
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ del self.plot
+ super(TestNoAutoscale, self).tearDown()
+
+ def testLogNormNoAutoscale(self):
+ colormapLog = Colormap(name='gray',
+ normalization=Colormap.LOGARITHM,
+ vmin=1.0,
+ vmax=100.0)
+
+ data = numpy.linspace(10, 1e10, 9).reshape(3, 3)
+ self.plot.addImage(data=data, colormap=colormapLog, legend='toto')
+ self.plot.setActiveImage('toto')
+
+ # test Ticks
+ self.tickBar.setTicksNumber(10)
+ self.tickBar.computeTicks()
+
+ ticksTh = numpy.linspace(1.0, 100.0, 10)
+ ticksTh = 10**ticksTh
+ numpy.array_equal(self.tickBar.ticks, ticksTh)
+
+ # test ColorScale
+ val = self.colorScale.getValueFromRelativePosition(1.0)
+ self.assertAlmostEqual(val, 100.0)
+
+ val = self.colorScale.getValueFromRelativePosition(0.0)
+ self.assertTrue(val == 1.0)
+
+ def testLinearNormNoAutoscale(self):
+ colormapLog = Colormap(name='gray',
+ normalization=Colormap.LINEAR,
+ vmin=-4,
+ vmax=5)
+
+ data = numpy.linspace(1, 9, 9).reshape(3, 3)
+ self.plot.addImage(data=data, colormap=colormapLog, legend='toto')
+ self.plot.setActiveImage('toto')
+
+ # test Ticks
+ self.tickBar.setTicksNumber(10)
+ self.tickBar.computeTicks()
+
+ numpy.array_equal(self.tickBar.ticks, numpy.linspace(-4, 5, 10))
+
+ # test ColorScale
+ val = self.colorScale.getValueFromRelativePosition(1.0)
+ self.assertTrue(val == 5.0)
+
+ val = self.colorScale.getValueFromRelativePosition(0.0)
+ self.assertTrue(val == -4.0)
+
+
+class TestColorBarWidget(TestCaseQt):
+ """Test interaction with the ColorBarWidget"""
+
+ def setUp(self):
+ super(TestColorBarWidget, self).setUp()
+ self.plot = Plot2D()
+ self.colorBar = self.plot.getColorBarWidget()
+ self.colorBar.setVisible(True) # Makes sure the colormap is visible
+
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ def tearDown(self):
+ self.qapp.processEvents()
+ del self.colorBar
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ del self.plot
+ super(TestColorBarWidget, self).tearDown()
+
+ def testEmptyColorBar(self):
+ colorBar = ColorBarWidget(parent=None)
+ colorBar.show()
+ self.qWaitForWindowExposed(colorBar)
+
+ def testNegativeColormaps(self):
+ """test the behavior of the ColorBarWidget in the case of negative
+ values
+
+ Note : colorbar is modified by the Plot directly not ColorBarWidget
+ """
+ colormapLog = Colormap(name='gray',
+ normalization=Colormap.LOGARITHM,
+ vmin=None,
+ vmax=None)
+
+ data = numpy.array([-5, -4, 0, 2, 3, 5, 10, 20, 30])
+ data = data.reshape(3, 3)
+ self.plot.addImage(data=data, colormap=colormapLog, legend='toto')
+ self.plot.setActiveImage('toto')
+
+ # default behavior when with log and negative values: should set vmin
+ # to 1 and vmax to 10
+ self.assertTrue(self.colorBar.getColorScaleBar().minVal == 2)
+ self.assertTrue(self.colorBar.getColorScaleBar().maxVal == 30)
+
+ # if data is positive
+ data[data < 1] = data.max()
+ self.plot.addImage(data=data,
+ colormap=colormapLog,
+ legend='toto',
+ replace=True)
+ self.plot.setActiveImage('toto')
+
+ self.assertTrue(self.colorBar.getColorScaleBar().minVal == data.min())
+ self.assertTrue(self.colorBar.getColorScaleBar().maxVal == data.max())
+
+ def testPlotAssocation(self):
+ """Make sure the ColorBarWidget is properly connected with the plot"""
+ colormap = Colormap(name='gray',
+ normalization=Colormap.LINEAR,
+ vmin=None,
+ vmax=None)
+
+ # make sure that default settings are the same (but a copy of the
+ self.colorBar.setPlot(self.plot)
+ self.assertTrue(
+ self.colorBar.getColormap() is self.plot.getDefaultColormap())
+
+ data = numpy.linspace(0, 10, 100).reshape(10, 10)
+ self.plot.addImage(data=data, colormap=colormap, legend='toto')
+ self.plot.setActiveImage('toto')
+
+ # make sure the modification of the colormap has been done
+ self.assertFalse(
+ self.colorBar.getColormap() is self.plot.getDefaultColormap())
+ self.assertTrue(
+ self.colorBar.getColormap() is colormap)
+
+ # test that colorbar is updated when default plot colormap changes
+ self.plot.clear()
+ plotColormap = Colormap(name='gray',
+ normalization=Colormap.LOGARITHM,
+ vmin=None,
+ vmax=None)
+ self.plot.setDefaultColormap(plotColormap)
+ self.assertTrue(self.colorBar.getColormap() is plotColormap)
+
+ def testColormapWithoutRange(self):
+ """Test with a colormap with vmin==vmax"""
+ colormap = Colormap(name='gray',
+ normalization=Colormap.LINEAR,
+ vmin=1.0,
+ vmax=1.0)
+ self.colorBar.setColormap(colormap)
+
+
+class TestColorBarUpdate(TestCaseQt):
+ """Test that the ColorBar is correctly updated when the signal 'sigChanged'
+ of the colormap is emitted
+ """
+
+ def setUp(self):
+ super(TestColorBarUpdate, self).setUp()
+ self.plot = Plot2D()
+ self.colorBar = self.plot.getColorBarWidget()
+ self.colorBar.setVisible(True) # Makes sure the colormap is visible
+ self.colorBar.setPlot(self.plot)
+
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+ self.data = numpy.random.rand(9).reshape(3, 3)
+
+ def tearDown(self):
+ self.qapp.processEvents()
+ del self.colorBar
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ del self.plot
+ super(TestColorBarUpdate, self).tearDown()
+
+ def testUpdateColorMap(self):
+ colormap = Colormap(name='gray',
+ normalization='linear',
+ vmin=0,
+ vmax=1)
+
+ # check inital state
+ self.plot.addImage(data=self.data, colormap=colormap, legend='toto')
+ self.plot.setActiveImage('toto')
+
+ self.assertTrue(self.colorBar.getColorScaleBar().minVal == 0)
+ self.assertTrue(self.colorBar.getColorScaleBar().maxVal == 1)
+ self.assertTrue(
+ self.colorBar.getColorScaleBar().getTickBar()._vmin == 0)
+ self.assertTrue(
+ self.colorBar.getColorScaleBar().getTickBar()._vmax == 1)
+ self.assertIsInstance(
+ self.colorBar.getColorScaleBar().getTickBar()._normalizer,
+ LinearNormalization)
+
+ # update colormap
+ colormap.setVMin(0.5)
+ self.assertTrue(self.colorBar.getColorScaleBar().minVal == 0.5)
+ self.assertTrue(
+ self.colorBar.getColorScaleBar().getTickBar()._vmin == 0.5)
+
+ colormap.setVMax(0.8)
+ self.assertTrue(self.colorBar.getColorScaleBar().maxVal == 0.8)
+ self.assertTrue(
+ self.colorBar.getColorScaleBar().getTickBar()._vmax == 0.8)
+
+ colormap.setNormalization('log')
+ self.assertIsInstance(
+ self.colorBar.getColorScaleBar().getTickBar()._normalizer,
+ LogarithmicNormalization)
+
+ # TODO : should also check that if the colormap is changing then values (especially in log scale)
+ # should be coherent if in autoscale
diff --git a/src/silx/gui/plot/test/testCompareImages.py b/src/silx/gui/plot/test/testCompareImages.py
new file mode 100644
index 0000000..cf54b99
--- /dev/null
+++ b/src/silx/gui/plot/test/testCompareImages.py
@@ -0,0 +1,106 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Tests for CompareImages widget"""
+
+__authors__ = ["H. Payno"]
+__license__ = "MIT"
+__date__ = "23/07/2018"
+
+import unittest
+import numpy
+import weakref
+
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui.plot.CompareImages import CompareImages
+
+
+class TestCompareImages(TestCaseQt):
+ """Test that CompareImages widget is working in some cases"""
+
+ def setUp(self):
+ super(TestCompareImages, self).setUp()
+ self.widget = CompareImages()
+
+ def tearDown(self):
+ ref = weakref.ref(self.widget)
+ self.widget = None
+ self.qWaitForDestroy(ref)
+ super(TestCompareImages, self).tearDown()
+
+ def testIntensityImage(self):
+ image1 = numpy.random.rand(10, 10)
+ image2 = numpy.random.rand(10, 10)
+ self.widget.setData(image1, image2)
+
+ def testRgbImage(self):
+ image1 = numpy.random.randint(0, 255, size=(10, 10, 3))
+ image2 = numpy.random.randint(0, 255, size=(10, 10, 3))
+ self.widget.setData(image1, image2)
+
+ def testRgbaImage(self):
+ image1 = numpy.random.randint(0, 255, size=(10, 10, 4))
+ image2 = numpy.random.randint(0, 255, size=(10, 10, 4))
+ self.widget.setData(image1, image2)
+
+ def testVizualisations(self):
+ image1 = numpy.random.rand(10, 10)
+ image2 = numpy.random.rand(10, 10)
+ self.widget.setData(image1, image2)
+ for mode in CompareImages.VisualizationMode:
+ self.widget.setVisualizationMode(mode)
+
+ def testAlignemnt(self):
+ image1 = numpy.random.rand(10, 10)
+ image2 = numpy.random.rand(5, 5)
+ self.widget.setData(image1, image2)
+ for mode in CompareImages.AlignmentMode:
+ self.widget.setAlignmentMode(mode)
+
+ def testGetPixel(self):
+ image1 = numpy.random.rand(11, 11)
+ image2 = numpy.random.rand(5, 5)
+ image1[5, 5] = 111.111
+ image2[2, 2] = 222.222
+ self.widget.setData(image1, image2)
+ expectedValue = {}
+ expectedValue[CompareImages.AlignmentMode.CENTER] = 222.222
+ expectedValue[CompareImages.AlignmentMode.STRETCH] = 222.222
+ expectedValue[CompareImages.AlignmentMode.ORIGIN] = None
+ for mode in expectedValue.keys():
+ self.widget.setAlignmentMode(mode)
+ data = self.widget.getRawPixelData(11 / 2.0, 11 / 2.0)
+ data1, data2 = data
+ self.assertEqual(data1, 111.111)
+ self.assertEqual(data2, expectedValue[mode])
+
+ def testImageEmpty(self):
+ self.widget.setData(image1=None, image2=None)
+ self.assertTrue(self.widget.getRawPixelData(11 / 2.0, 11 / 2.0) == (None, None))
+
+ def testSetImageSeparately(self):
+ self.widget.setImage1(numpy.random.rand(10, 10))
+ self.widget.setImage2(numpy.random.rand(10, 10))
+ for mode in CompareImages.VisualizationMode:
+ self.widget.setVisualizationMode(mode)
diff --git a/src/silx/gui/plot/test/testComplexImageView.py b/src/silx/gui/plot/test/testComplexImageView.py
new file mode 100644
index 0000000..46025b9
--- /dev/null
+++ b/src/silx/gui/plot/test/testComplexImageView.py
@@ -0,0 +1,84 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test suite for :class:`ComplexImageView`"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "17/01/2018"
+
+
+import unittest
+import logging
+import numpy
+
+from silx.utils.testutils import ParametricTestCase
+from silx.gui.plot import ComplexImageView
+
+from .utils import PlotWidgetTestCase
+
+
+logger = logging.getLogger(__name__)
+
+
+class TestComplexImageView(PlotWidgetTestCase, ParametricTestCase):
+ """Test suite of ComplexImageView widget"""
+
+ def _createPlot(self):
+ return ComplexImageView.ComplexImageView()
+
+ def testPlot2DComplex(self):
+ """Test API of ComplexImageView widget"""
+ data = numpy.array(((0, 1j), (1, 1 + 1j)), dtype=numpy.complex64)
+ self.plot.setData(data)
+ self.plot.setKeepDataAspectRatio(True)
+ self.plot.getPlot().resetZoom()
+ self.qWait(100)
+
+ # Test colormap API
+ colormap = self.plot.getColormap().copy()
+ colormap.setName('magma')
+ self.plot.setColormap(colormap)
+ self.qWait(100)
+
+ # Test all modes
+ modes = self.plot.supportedComplexModes()
+ for mode in modes:
+ with self.subTest(mode=mode):
+ self.plot.setComplexMode(mode)
+ self.qWait(100)
+
+ # Test origin and scale API
+ self.plot.setScale((2, 1))
+ self.qWait(100)
+ self.plot.setOrigin((1, 1))
+ self.qWait(100)
+
+ # Test no data
+ self.plot.setData(numpy.zeros((0, 0), dtype=numpy.complex64))
+ self.qWait(100)
+
+ # Test float data
+ self.plot.setData(numpy.arange(100, dtype=numpy.float64).reshape(10, 10))
+ self.qWait(100)
diff --git a/src/silx/gui/plot/test/testCurvesROIWidget.py b/src/silx/gui/plot/test/testCurvesROIWidget.py
new file mode 100644
index 0000000..d7dfafd
--- /dev/null
+++ b/src/silx/gui/plot/test/testCurvesROIWidget.py
@@ -0,0 +1,465 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for CurvesROIWidget"""
+
+__authors__ = ["T. Vincent", "P. Knobel", "H. Payno"]
+__license__ = "MIT"
+__date__ = "16/11/2017"
+
+
+import logging
+import os.path
+import pytest
+from collections import OrderedDict
+import numpy
+
+from silx.gui import qt
+from silx.gui.plot import items
+from silx.gui.plot import Plot1D
+from silx.test.utils import temp_dir
+from silx.gui.utils.testutils import TestCaseQt, SignalListener
+from silx.gui.plot import PlotWindow, CurvesROIWidget
+from silx.gui.plot.CurvesROIWidget import ROITable
+from silx.gui.utils.testutils import getQToolButtonFromAction
+from silx.gui.plot.PlotInteraction import ItemsInteraction
+
+_logger = logging.getLogger(__name__)
+
+
+class TestCurvesROIWidget(TestCaseQt):
+ """Basic test for CurvesROIWidget"""
+
+ def setUp(self):
+ super(TestCurvesROIWidget, self).setUp()
+ self.plot = PlotWindow()
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ self.widget = self.plot.getCurvesRoiDockWidget()
+
+ self.widget.show()
+ self.qWaitForWindowExposed(self.widget)
+
+ def tearDown(self):
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ del self.plot
+
+ self.widget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.widget.close()
+ del self.widget
+
+ super(TestCurvesROIWidget, self).tearDown()
+
+ def testDummyAPI(self):
+ """Simple test of the getRois and setRois API"""
+ roi_neg = CurvesROIWidget.ROI(name='negative', fromdata=-20,
+ todata=-10, type_='X')
+ roi_pos = CurvesROIWidget.ROI(name='positive', fromdata=10,
+ todata=20, type_='X')
+
+ self.widget.roiWidget.setRois((roi_pos, roi_neg))
+
+ rois_defs = self.widget.roiWidget.getRois()
+ self.widget.roiWidget.setRois(rois=rois_defs)
+
+ def testWithCurves(self):
+ """Plot with curves: test all ROI widget buttons"""
+ for offset in range(2):
+ self.plot.addCurve(numpy.arange(1000),
+ offset + numpy.random.random(1000),
+ legend=str(offset))
+
+ # Add two ROI
+ self.mouseClick(self.widget.roiWidget.addButton, qt.Qt.LeftButton)
+ self.qWait(200)
+ self.mouseClick(self.widget.roiWidget.addButton, qt.Qt.LeftButton)
+ self.qWait(200)
+
+ # Change active curve
+ self.plot.setActiveCurve(str(1))
+
+ # Delete a ROI
+ self.mouseClick(self.widget.roiWidget.delButton, qt.Qt.LeftButton)
+ self.qWait(200)
+
+ with temp_dir() as tmpDir:
+ self.tmpFile = os.path.join(tmpDir, 'test.ini')
+
+ # Save ROIs
+ self.widget.roiWidget.save(self.tmpFile)
+ self.assertTrue(os.path.isfile(self.tmpFile))
+ self.assertEqual(len(self.widget.getRois()), 2)
+
+ # Reset ROIs
+ self.mouseClick(self.widget.roiWidget.resetButton,
+ qt.Qt.LeftButton)
+ self.qWait(200)
+ rois = self.widget.getRois()
+ self.assertEqual(len(rois), 1)
+ roiID = list(rois.keys())[0]
+ self.assertEqual(rois[roiID].getName(), 'ICR')
+
+ # Load ROIs
+ self.widget.roiWidget.load(self.tmpFile)
+ self.assertEqual(len(self.widget.getRois()), 2)
+
+ del self.tmpFile
+
+ def testMiddleMarker(self):
+ """Test with middle marker enabled"""
+ self.widget.roiWidget.roiTable.setMiddleROIMarkerFlag(True)
+
+ # Add a ROI
+ self.mouseClick(self.widget.roiWidget.addButton, qt.Qt.LeftButton)
+
+ for roiID in self.widget.roiWidget.roiTable._markersHandler._roiMarkerHandlers:
+ handler = self.widget.roiWidget.roiTable._markersHandler._roiMarkerHandlers[roiID]
+ assert handler.getMarker('min')
+ xleftMarker = handler.getMarker('min').getXPosition()
+ xMiddleMarker = handler.getMarker('middle').getXPosition()
+ xRightMarker = handler.getMarker('max').getXPosition()
+ thValue = xleftMarker + (xRightMarker - xleftMarker) / 2.
+ self.assertAlmostEqual(xMiddleMarker, thValue)
+
+ def testAreaCalculation(self):
+ """Test result of area calculation"""
+ x = numpy.arange(100.)
+ y = numpy.arange(100.)
+
+ # Add two curves
+ self.plot.addCurve(x, y, legend="positive")
+ self.plot.addCurve(-x, y, legend="negative")
+
+ # Make sure there is an active curve and it is the positive one
+ self.plot.setActiveCurve("positive")
+
+ # Add two ROIs
+ roi_neg = CurvesROIWidget.ROI(name='negative', fromdata=-20,
+ todata=-10, type_='X')
+ roi_pos = CurvesROIWidget.ROI(name='positive', fromdata=10,
+ todata=20, type_='X')
+
+ self.widget.roiWidget.setRois((roi_pos, roi_neg))
+
+ posCurve = self.plot.getCurve('positive')
+ negCurve = self.plot.getCurve('negative')
+
+ self.assertEqual(roi_pos.computeRawAndNetArea(posCurve),
+ (numpy.trapz(y=[10, 20], x=[10, 20]),
+ 0.0))
+ self.assertEqual(roi_pos.computeRawAndNetArea(negCurve),
+ (0.0, 0.0))
+ self.assertEqual(roi_neg.computeRawAndNetArea(posCurve),
+ ((0.0), 0.0))
+ self.assertEqual(roi_neg.computeRawAndNetArea(negCurve),
+ ((-150.0), 0.0))
+
+ def testCountsCalculation(self):
+ """Test result of count calculation"""
+ x = numpy.arange(100.)
+ y = numpy.arange(100.)
+
+ # Add two curves
+ self.plot.addCurve(x, y, legend="positive")
+ self.plot.addCurve(-x, y, legend="negative")
+
+ # Make sure there is an active curve and it is the positive one
+ self.plot.setActiveCurve("positive")
+
+ # Add two ROIs
+ roi_neg = CurvesROIWidget.ROI(name='negative', fromdata=-20,
+ todata=-10, type_='X')
+ roi_pos = CurvesROIWidget.ROI(name='positive', fromdata=10,
+ todata=20, type_='X')
+
+ self.widget.roiWidget.setRois((roi_pos, roi_neg))
+
+ posCurve = self.plot.getCurve('positive')
+ negCurve = self.plot.getCurve('negative')
+
+ self.assertEqual(roi_pos.computeRawAndNetCounts(posCurve),
+ (y[10:21].sum(), 0.0))
+ self.assertEqual(roi_pos.computeRawAndNetCounts(negCurve),
+ (0.0, 0.0))
+ self.assertEqual(roi_neg.computeRawAndNetCounts(posCurve),
+ ((0.0), 0.0))
+ self.assertEqual(roi_neg.computeRawAndNetCounts(negCurve),
+ (y[10:21].sum(), 0.0))
+
+ def testDeferedInit(self):
+ """Test behavior of the deferedInit"""
+ x = numpy.arange(100.)
+ y = numpy.arange(100.)
+ self.plot.addCurve(x=x, y=y, legend="name", replace="True")
+ roisDefs = OrderedDict([
+ ["range1",
+ OrderedDict([["from", 20], ["to", 200], ["type", "energy"]])],
+ ["range2",
+ OrderedDict([["from", 300], ["to", 500], ["type", "energy"]])]
+ ])
+
+ roiWidget = self.plot.getCurvesRoiDockWidget().roiWidget
+ self.plot.getCurvesRoiDockWidget().setRois(roisDefs)
+ self.assertEqual(len(roiWidget.getRois()), len(roisDefs))
+ self.plot.getCurvesRoiDockWidget().setVisible(True)
+ self.assertEqual(len(roiWidget.getRois()), len(roisDefs))
+
+ def testDictCompatibility(self):
+ """Test that ROI api is valid with dict and not information is lost"""
+ roiDict = {'from': 20, 'to': 200, 'type': 'energy', 'comment': 'no',
+ 'name': 'myROI', 'calibration': [1, 2, 3]}
+ roi = CurvesROIWidget.ROI._fromDict(roiDict)
+ self.assertEqual(roi.toDict(), roiDict)
+
+ def testShowAllROI(self):
+ """Test the show allROI action"""
+ x = numpy.arange(100.)
+ y = numpy.arange(100.)
+ self.plot.addCurve(x=x, y=y, legend="name", replace="True")
+
+ roisDefsDict = {
+ "range1": {"from": 20, "to": 200,"type": "energy"},
+ "range2": {"from": 300, "to": 500, "type": "energy"}
+ }
+
+ roisDefsObj = (
+ CurvesROIWidget.ROI(name='range3', fromdata=20, todata=200,
+ type_='energy'),
+ CurvesROIWidget.ROI(name='range4', fromdata=300, todata=500,
+ type_='energy')
+ )
+ self.widget.roiWidget.showAllMarkers(True)
+ roiWidget = self.plot.getCurvesRoiDockWidget().roiWidget
+ roiWidget.setRois(roisDefsDict)
+ markers = [item for item in self.plot.getItems()
+ if isinstance(item, items.MarkerBase)]
+ self.assertEqual(len(markers), 2*3)
+
+ markersHandler = self.widget.roiWidget.roiTable._markersHandler
+ roiWidget.showAllMarkers(True)
+ ICRROI = markersHandler.getVisibleRois()
+ self.assertEqual(len(ICRROI), 2)
+
+ roiWidget.showAllMarkers(False)
+ ICRROI = markersHandler.getVisibleRois()
+ self.assertEqual(len(ICRROI), 1)
+
+ roiWidget.setRois(roisDefsObj)
+ self.qapp.processEvents()
+ markers = [item for item in self.plot.getItems()
+ if isinstance(item, items.MarkerBase)]
+ self.assertEqual(len(markers), 2*3)
+
+ markersHandler = self.widget.roiWidget.roiTable._markersHandler
+ roiWidget.showAllMarkers(True)
+ ICRROI = markersHandler.getVisibleRois()
+ self.assertEqual(len(ICRROI), 2)
+
+ roiWidget.showAllMarkers(False)
+ ICRROI = markersHandler.getVisibleRois()
+ self.assertEqual(len(ICRROI), 1)
+
+ def testRoiEdition(self):
+ """Make sure if the ROI object is edited the ROITable will be updated
+ """
+ roi = CurvesROIWidget.ROI(name='linear', fromdata=0, todata=5)
+ self.widget.roiWidget.setRois((roi, ))
+
+ x = (0, 1, 1, 2, 2, 3)
+ y = (1, 1, 2, 2, 1, 1)
+ self.plot.addCurve(x=x, y=y, legend='linearCurve')
+ self.plot.setActiveCurve(legend='linearCurve')
+ self.widget.calculateROIs()
+
+ roiTable = self.widget.roiWidget.roiTable
+ indexesColumns = CurvesROIWidget.ROITable.COLUMNS_INDEX
+ itemRawCounts = roiTable.item(0, indexesColumns['Raw Counts'])
+ itemNetCounts = roiTable.item(0, indexesColumns['Net Counts'])
+
+ self.assertTrue(itemRawCounts.text() == '8.0')
+ self.assertTrue(itemNetCounts.text() == '2.0')
+
+ itemRawArea = roiTable.item(0, indexesColumns['Raw Area'])
+ itemNetArea = roiTable.item(0, indexesColumns['Net Area'])
+
+ self.assertTrue(itemRawArea.text() == '4.0')
+ self.assertTrue(itemNetArea.text() == '1.0')
+
+ roi.setTo(2)
+ itemRawArea = roiTable.item(0, indexesColumns['Raw Area'])
+ self.assertTrue(itemRawArea.text() == '3.0')
+ roi.setFrom(1)
+ itemRawArea = roiTable.item(0, indexesColumns['Raw Area'])
+ self.assertTrue(itemRawArea.text() == '2.0')
+
+ def testRemoveActiveROI(self):
+ """Test widget behavior when removing the active ROI"""
+ roi = CurvesROIWidget.ROI(name='linear', fromdata=0, todata=5)
+ self.widget.roiWidget.setRois((roi,))
+
+ self.widget.roiWidget.roiTable.setActiveRoi(None)
+ self.assertEqual(len(self.widget.roiWidget.roiTable.selectedItems()), 0)
+ self.widget.roiWidget.setRois((roi,))
+ self.plot.setActiveCurve(legend='linearCurve')
+ self.widget.calculateROIs()
+
+ def testEmitCurrentROI(self):
+ """Test behavior of the CurvesROIWidget.sigROISignal"""
+ roi = CurvesROIWidget.ROI(name='linear', fromdata=0, todata=5)
+ self.widget.roiWidget.setRois((roi,))
+ signalListener = SignalListener()
+ self.widget.roiWidget.sigROISignal.connect(signalListener.partial())
+ self.widget.show()
+ self.qapp.processEvents()
+ self.assertEqual(signalListener.callCount(), 0)
+ self.assertIs(self.widget.roiWidget.roiTable.activeRoi, roi)
+ roi.setFrom(0.0)
+ self.qapp.processEvents()
+ self.assertEqual(signalListener.callCount(), 0)
+ roi.setFrom(0.3)
+ self.qapp.processEvents()
+ self.assertEqual(signalListener.callCount(), 1)
+
+
+class TestRoiWidgetSignals(TestCaseQt):
+ """Test Signals emitted by the RoiWidgetSignals"""
+
+ def setUp(self):
+ self.plot = Plot1D()
+ x = range(20)
+ y = range(20)
+ self.plot.addCurve(x, y, legend='curve0')
+ self.listener = SignalListener()
+ self.curves_roi_widget = self.plot.getCurvesRoiWidget()
+ self.curves_roi_widget.sigROISignal.connect(self.listener)
+ assert self.curves_roi_widget.isVisible() is False
+ assert self.listener.callCount() == 0
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ toolButton = getQToolButtonFromAction(self.plot.getRoiAction())
+ self.qapp.processEvents()
+ self.mouseClick(widget=toolButton, button=qt.Qt.LeftButton)
+
+ self.curves_roi_widget.show()
+ self.qWaitForWindowExposed(self.curves_roi_widget)
+
+ def tearDown(self):
+ self.plot = None
+ self.curves_roi_widget = None
+
+ def testSigROISignalAddRmRois(self):
+ """Test SigROISignal when adding and removing ROIS"""
+ self.listener.clear()
+
+ roi1 = CurvesROIWidget.ROI(name='linear', fromdata=0, todata=5)
+ self.curves_roi_widget.roiTable.addRoi(roi1)
+ self.assertEqual(self.listener.callCount(), 1)
+ self.assertTrue(self.listener.arguments()[0][0]['current'] == 'linear')
+ self.listener.clear()
+
+ roi2 = CurvesROIWidget.ROI(name='linear2', fromdata=0, todata=5)
+ self.curves_roi_widget.roiTable.addRoi(roi2)
+ self.assertEqual(self.listener.callCount(), 1)
+ self.assertTrue(self.listener.arguments()[0][0]['current'] == 'linear2')
+ self.listener.clear()
+
+ self.curves_roi_widget.roiTable.removeROI(roi2)
+ self.assertEqual(self.listener.callCount(), 1)
+ self.assertTrue(self.curves_roi_widget.roiTable.activeRoi == roi1)
+ self.assertTrue(self.listener.arguments()[0][0]['current'] == 'linear')
+ self.listener.clear()
+
+ self.curves_roi_widget.roiTable.deleteActiveRoi()
+ self.assertEqual(self.listener.callCount(), 1)
+ self.assertTrue(self.curves_roi_widget.roiTable.activeRoi is None)
+ self.assertTrue(self.listener.arguments()[0][0]['current'] is None)
+ self.listener.clear()
+
+ self.curves_roi_widget.roiTable.addRoi(roi1)
+ self.assertEqual(self.listener.callCount(), 1)
+ self.assertTrue(self.listener.arguments()[0][0]['current'] == 'linear')
+ self.assertTrue(self.curves_roi_widget.roiTable.activeRoi == roi1)
+ self.listener.clear()
+ self.qapp.processEvents()
+
+ self.curves_roi_widget.roiTable.removeROI(roi1)
+ self.qapp.processEvents()
+ self.assertEqual(self.listener.callCount(), 1)
+ self.assertTrue(self.listener.arguments()[0][0]['current'] == 'ICR')
+ self.listener.clear()
+
+ def testSigROISignalModifyROI(self):
+ """Test SigROISignal when modifying it"""
+ self.curves_roi_widget.roiTable.setMiddleROIMarkerFlag(True)
+ roi1 = CurvesROIWidget.ROI(name='linear', fromdata=2, todata=5)
+ self.curves_roi_widget.roiTable.addRoi(roi1)
+ self.curves_roi_widget.roiTable.setActiveRoi(roi1)
+
+ # test modify the roi2 object
+ self.listener.clear()
+ roi1.setFrom(0.56)
+ self.assertEqual(self.listener.callCount(), 1)
+ self.listener.clear()
+ roi1.setTo(2.56)
+ self.assertEqual(self.listener.callCount(), 1)
+ self.listener.clear()
+ roi1.setName('linear2')
+ self.assertEqual(self.listener.callCount(), 1)
+ self.listener.clear()
+ roi1.setType('new type')
+ self.assertEqual(self.listener.callCount(), 1)
+
+ widget = self.plot.getWidgetHandle()
+ widget.setFocus(qt.Qt.OtherFocusReason)
+ self.plot.raise_()
+ self.qapp.processEvents()
+
+ # modify roi limits (from the gui)
+ roi_marker_handler = self.curves_roi_widget.roiTable._markersHandler.getMarkerHandler(roi1.getID())
+ for marker_type in ('min', 'max', 'middle'):
+ with self.subTest(marker_type=marker_type):
+ self.listener.clear()
+ marker = roi_marker_handler.getMarker(marker_type)
+ x_pix, y_pix = self.plot.dataToPixel(marker.getXPosition(), marker.getYPosition())
+ self.mouseMove(widget, pos=(x_pix, y_pix))
+ self.qWait(100)
+ self.mousePress(widget, qt.Qt.LeftButton, pos=(x_pix, y_pix))
+ self.mouseMove(widget, pos=(x_pix+20, y_pix))
+ self.qWait(100)
+ self.mouseRelease(widget, qt.Qt.LeftButton, pos=(x_pix+20, y_pix))
+ self.qWait(100)
+ self.mouseMove(widget, pos=(x_pix, y_pix))
+ self.qapp.processEvents()
+ self.assertEqual(self.listener.callCount(), 1)
+
+ def testSetActiveCurve(self):
+ """Test sigRoiSignal when set an active curve"""
+ roi1 = CurvesROIWidget.ROI(name='linear', fromdata=2, todata=5)
+ self.curves_roi_widget.roiTable.setActiveRoi(roi1)
+ self.listener.clear()
+ self.plot.setActiveCurve('curve0')
+ self.assertEqual(self.listener.callCount(), 0)
diff --git a/src/silx/gui/plot/test/testImageStack.py b/src/silx/gui/plot/test/testImageStack.py
new file mode 100644
index 0000000..5c44691
--- /dev/null
+++ b/src/silx/gui/plot/test/testImageStack.py
@@ -0,0 +1,186 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for ImageStack"""
+
+__authors__ = ["H. Payno"]
+__license__ = "MIT"
+__date__ = "15/01/2020"
+
+
+import unittest
+import tempfile
+import numpy
+import h5py
+
+from silx.gui import qt
+from silx.gui.utils.testutils import TestCaseQt
+from silx.io.url import DataUrl
+from silx.gui.plot.ImageStack import ImageStack
+from silx.gui.utils.testutils import SignalListener
+from collections import OrderedDict
+import os
+import time
+import shutil
+
+
+class TestImageStack(TestCaseQt):
+ """Simple test of the Image stack"""
+
+ def setUp(self):
+ TestCaseQt.setUp(self)
+ self.urls = OrderedDict()
+ self._raw_data = {}
+ self._folder = tempfile.mkdtemp()
+ self._n_urls = 10
+ file_name = os.path.join(self._folder, 'test_inage_stack_file.h5')
+ with h5py.File(file_name, 'w') as h5f:
+ for i in range(self._n_urls):
+ width = numpy.random.randint(10, 40)
+ height = numpy.random.randint(10, 40)
+ raw_data = numpy.random.random((width, height))
+ self._raw_data[i] = raw_data
+ h5f[str(i)] = raw_data
+ self.urls[i] = DataUrl(file_path=file_name,
+ data_path=str(i),
+ scheme='silx')
+ self.widget = ImageStack()
+
+ self.urlLoadedListener = SignalListener()
+ self.widget.sigLoaded.connect(self.urlLoadedListener)
+
+ self.currentUrlChangedListener = SignalListener()
+ self.widget.sigCurrentUrlChanged.connect(self.currentUrlChangedListener)
+
+ def tearDown(self):
+ shutil.rmtree(self._folder)
+ self.widget.setAttribute(qt.Qt.WA_DeleteOnClose, True)
+ self.widget.close()
+ TestCaseQt.setUp(self)
+
+ def testControls(self):
+ """Test that selection using the url table and the slider are working
+ """
+ self.widget.show()
+ self.assertEqual(self.widget.getCurrentUrl(), None)
+ self.assertEqual(self.widget.getCurrentUrlIndex(), None)
+ self.widget.setUrls(list(self.urls.values()))
+
+ # wait for image to be loaded
+ self._waitUntilUrlLoaded()
+
+ self.assertEqual(self.widget.getCurrentUrl(), self.urls[0])
+
+ # make sure all image are loaded
+ self.assertEqual(self.urlLoadedListener.callCount(), self._n_urls)
+ numpy.testing.assert_array_equal(
+ self.widget.getPlotWidget().getActiveImage(just_legend=False).getData(),
+ self._raw_data[0])
+ self.assertEqual(self.widget._slider.value(), 0)
+
+ self.widget._urlsTable.setUrl(self.urls[4])
+ numpy.testing.assert_array_equal(
+ self.widget.getPlotWidget().getActiveImage(just_legend=False).getData(),
+ self._raw_data[4])
+ self.assertEqual(self.widget._slider.value(), 4)
+ self.assertEqual(self.widget.getCurrentUrl(), self.urls[4])
+ self.assertEqual(self.widget.getCurrentUrlIndex(), 4)
+
+ self.widget._slider.setUrlIndex(6)
+ numpy.testing.assert_array_equal(
+ self.widget.getPlotWidget().getActiveImage(just_legend=False).getData(),
+ self._raw_data[6])
+ self.assertEqual(self.widget._urlsTable.currentItem().text(),
+ self.urls[6].path())
+
+ def testCurrentUrlSignals(self):
+ """Test emission of 'currentUrlChangedListener'"""
+ # check initialization
+ self.assertEqual(self.currentUrlChangedListener.callCount(), 0)
+ self.widget.setUrls(list(self.urls.values()))
+ self.qapp.processEvents()
+ time.sleep(0.5)
+ self.qapp.processEvents()
+ # once loaded the two signals should have been sended
+ self.assertEqual(self.currentUrlChangedListener.callCount(), 1)
+ # if the slider is stuck to the same position no signal should be
+ # emitted
+ self.qapp.processEvents()
+ time.sleep(0.5)
+ self.qapp.processEvents()
+ self.assertEqual(self.widget._slider.value(), 0)
+ self.assertEqual(self.currentUrlChangedListener.callCount(), 1)
+ # if slider position is changed, one of each signal should have been
+ # emitted
+ self.widget._urlsTable.setUrl(self.urls[4])
+ self.qapp.processEvents()
+ time.sleep(1.5)
+ self.qapp.processEvents()
+ self.assertEqual(self.currentUrlChangedListener.callCount(), 2)
+
+ def testUtils(self):
+ """Test that some utils functions are working"""
+ self.widget.show()
+ self.widget.setUrls(list(self.urls.values()))
+ self.assertEqual(len(self.widget.getUrls()), len(self.urls))
+
+ # wait for image to be loaded
+ self._waitUntilUrlLoaded()
+
+ urls_values = list(self.urls.values())
+ self.assertEqual(urls_values[0], self.urls[0])
+ self.assertEqual(urls_values[7], self.urls[7])
+
+ self.assertEqual(self.widget._getNextUrl(urls_values[2]).path(),
+ urls_values[3].path())
+ self.assertEqual(self.widget._getPreviousUrl(urls_values[0]), None)
+ self.assertEqual(self.widget._getPreviousUrl(urls_values[6]).path(),
+ urls_values[5].path())
+
+ self.assertEqual(self.widget._getNNextUrls(2, urls_values[0]),
+ urls_values[1:3])
+ self.assertEqual(self.widget._getNNextUrls(5, urls_values[7]),
+ urls_values[8:])
+ self.assertEqual(self.widget._getNPreviousUrls(3, urls_values[2]),
+ urls_values[:2])
+ self.assertEqual(self.widget._getNPreviousUrls(5, urls_values[8]),
+ urls_values[3:8])
+
+ def _waitUntilUrlLoaded(self, timeout=2.0):
+ """Wait until all image urls are loaded"""
+ loop_duration = 0.2
+ remaining_duration = timeout
+ while(len(self.widget._loadingThreads) > 0 and remaining_duration > 0):
+ remaining_duration -= loop_duration
+ time.sleep(loop_duration)
+ self.qapp.processEvents()
+
+ if remaining_duration <= 0.0:
+ remaining_urls = []
+ for thread_ in self.widget._loadingThreads:
+ remaining_urls.append(thread_.url.path())
+ mess = 'All images are not loaded after the time out. ' \
+ 'Remaining urls are: ' + str(remaining_urls)
+ raise TimeoutError(mess)
+ return True
diff --git a/src/silx/gui/plot/test/testImageView.py b/src/silx/gui/plot/test/testImageView.py
new file mode 100644
index 0000000..7c1355f
--- /dev/null
+++ b/src/silx/gui/plot/test/testImageView.py
@@ -0,0 +1,194 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for PlotWindow"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "24/04/2018"
+
+
+import numpy
+
+from silx.gui import qt
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui.plot import items
+
+from silx.gui.plot.ImageView import ImageView
+from silx.gui.colors import Colormap
+
+
+class TestImageView(TestCaseQt):
+ """Tests of ImageView widget."""
+
+ def setUp(self):
+ super(TestImageView, self).setUp()
+ self.plot = ImageView()
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ def tearDown(self):
+ self.qapp.processEvents()
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ del self.plot
+ self.qapp.processEvents()
+ super(TestImageView, self).tearDown()
+
+ def testSetImage(self):
+ """Test setImage"""
+ image = numpy.arange(100).reshape(10, 10)
+
+ self.plot.setImage(image, reset=True)
+ self.qWait(100)
+ self.assertEqual(self.plot.getXAxis().getLimits(), (0, 10))
+ self.assertEqual(self.plot.getYAxis().getLimits(), (0, 10))
+
+ # With reset=False
+ self.plot.setImage(image[::2, ::2], reset=False)
+ self.qWait(100)
+ self.assertEqual(self.plot.getXAxis().getLimits(), (0, 10))
+ self.assertEqual(self.plot.getYAxis().getLimits(), (0, 10))
+
+ self.plot.setImage(image, origin=(10, 20), scale=(2, 4), reset=False)
+ self.qWait(100)
+ self.assertEqual(self.plot.getXAxis().getLimits(), (0, 10))
+ self.assertEqual(self.plot.getYAxis().getLimits(), (0, 10))
+
+ # With reset=True
+ self.plot.setImage(image, origin=(1, 2), scale=(1, 0.5), reset=True)
+ self.qWait(100)
+ self.assertEqual(self.plot.getXAxis().getLimits(), (1, 11))
+ self.assertEqual(self.plot.getYAxis().getLimits(), (2, 7))
+
+ self.plot.setImage(image[::2, ::2], reset=True)
+ self.qWait(100)
+ self.assertEqual(self.plot.getXAxis().getLimits(), (0, 5))
+ self.assertEqual(self.plot.getYAxis().getLimits(), (0, 5))
+
+ def testColormap(self):
+ """Test get|setColormap"""
+ image = numpy.arange(100).reshape(10, 10)
+ self.plot.setImage(image)
+
+ # Colormap as dict
+ self.plot.setColormap({'name': 'viridis',
+ 'normalization': 'log',
+ 'autoscale': False,
+ 'vmin': 0,
+ 'vmax': 1})
+ colormap = self.plot.getColormap()
+ self.assertEqual(colormap.getName(), 'viridis')
+ self.assertEqual(colormap.getNormalization(), 'log')
+ self.assertEqual(colormap.getVMin(), 0)
+ self.assertEqual(colormap.getVMax(), 1)
+
+ # Colormap as keyword arguments
+ self.plot.setColormap(colormap='magma',
+ normalization='linear',
+ autoscale=True,
+ vmin=1,
+ vmax=2)
+ self.assertEqual(colormap.getName(), 'magma')
+ self.assertEqual(colormap.getNormalization(), 'linear')
+ self.assertEqual(colormap.getVMin(), None)
+ self.assertEqual(colormap.getVMax(), None)
+
+ # Update colormap with keyword argument
+ self.plot.setColormap(normalization='log')
+ self.assertEqual(colormap.getNormalization(), 'log')
+
+ # Colormap as Colormap object
+ cmap = Colormap()
+ self.plot.setColormap(cmap)
+ self.assertIs(self.plot.getColormap(), cmap)
+
+ def testSetProfileWindowBehavior(self):
+ """Test change of profile window display behavior"""
+ self.assertIs(
+ self.plot.getProfileWindowBehavior(),
+ ImageView.ProfileWindowBehavior.POPUP,
+ )
+
+ self.plot.setProfileWindowBehavior('embedded')
+ self.assertIs(
+ self.plot.getProfileWindowBehavior(),
+ ImageView.ProfileWindowBehavior.EMBEDDED,
+ )
+
+ image = numpy.arange(100).reshape(10, 10)
+ self.plot.setImage(image)
+
+ self.plot.setProfileWindowBehavior(
+ ImageView.ProfileWindowBehavior.POPUP
+ )
+ self.assertIs(
+ self.plot.getProfileWindowBehavior(),
+ ImageView.ProfileWindowBehavior.POPUP,
+ )
+
+ def testRGBImage(self):
+ """Test setImage"""
+ image = numpy.arange(100 * 3, dtype=numpy.uint8).reshape(10, 10, 3)
+
+ self.plot.setImage(image, reset=True)
+ self.qWait(100)
+ self.assertEqual(self.plot.getXAxis().getLimits(), (0, 10))
+ self.assertEqual(self.plot.getYAxis().getLimits(), (0, 10))
+
+ def testRGBAImage(self):
+ """Test setImage"""
+ image = numpy.arange(100 * 4, dtype=numpy.uint8).reshape(10, 10, 4)
+
+ self.plot.setImage(image, reset=True)
+ self.qWait(100)
+ self.assertEqual(self.plot.getXAxis().getLimits(), (0, 10))
+ self.assertEqual(self.plot.getYAxis().getLimits(), (0, 10))
+
+ def testImageAggregationMode(self):
+ """Test setImage"""
+ image = numpy.arange(100).reshape(10, 10)
+ self.plot.setImage(image, reset=True)
+ self.qWait(100)
+ self.plot.getAggregationModeAction().setAggregationMode(items.ImageDataAggregated.Aggregation.MAX)
+ self.qWait(100)
+
+ def testImageAggregationModeBackToNormalMode(self):
+ """Test setImage"""
+ image = numpy.arange(100).reshape(10, 10)
+ self.plot.setImage(image, reset=True)
+ self.qWait(100)
+ self.plot.getAggregationModeAction().setAggregationMode(items.ImageDataAggregated.Aggregation.MAX)
+ self.qWait(100)
+ self.plot.getAggregationModeAction().setAggregationMode(items.ImageDataAggregated.Aggregation.NONE)
+ self.qWait(100)
+
+ def testRGBAInAggregationMode(self):
+ """Test setImage"""
+ image = numpy.arange(100 * 3, dtype=numpy.uint8).reshape(10, 10, 3)
+
+ self.plot.setImage(image, reset=True)
+ self.qWait(100)
+ self.plot.getAggregationModeAction().setAggregationMode(items.ImageDataAggregated.Aggregation.MAX)
+ self.qWait(100)
diff --git a/src/silx/gui/plot/test/testInteraction.py b/src/silx/gui/plot/test/testInteraction.py
new file mode 100644
index 0000000..d136b21
--- /dev/null
+++ b/src/silx/gui/plot/test/testInteraction.py
@@ -0,0 +1,78 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Tests from interaction state machines"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "18/02/2016"
+
+
+import unittest
+
+from silx.gui.plot import Interaction
+
+
+class TestInteraction(unittest.TestCase):
+ def testClickOrDrag(self):
+ """Minimalistic test for click or drag state machine."""
+ events = []
+
+ class TestClickOrDrag(Interaction.ClickOrDrag):
+ def click(self, x, y, btn):
+ events.append(('click', x, y, btn))
+
+ def beginDrag(self, x, y, btn):
+ events.append(('beginDrag', x, y, btn))
+
+ def drag(self, x, y, btn):
+ events.append(('drag', x, y, btn))
+
+ def endDrag(self, start, end, btn):
+ events.append(('endDrag', start, end, btn))
+
+ clickOrDrag = TestClickOrDrag()
+
+ # click
+ clickOrDrag.handleEvent('press', 10, 10, Interaction.LEFT_BTN)
+ self.assertEqual(len(events), 0)
+
+ clickOrDrag.handleEvent('release', 10, 10, Interaction.LEFT_BTN)
+ self.assertEqual(len(events), 1)
+ self.assertEqual(events[0], ('click', 10, 10, Interaction.LEFT_BTN))
+
+ # drag
+ events = []
+ clickOrDrag.handleEvent('press', 10, 10, Interaction.LEFT_BTN)
+ self.assertEqual(len(events), 0)
+ clickOrDrag.handleEvent('move', 15, 10)
+ self.assertEqual(len(events), 2) # Received beginDrag and drag
+ self.assertEqual(events[0], ('beginDrag', 10, 10, Interaction.LEFT_BTN))
+ self.assertEqual(events[1], ('drag', 15, 10, Interaction.LEFT_BTN))
+ clickOrDrag.handleEvent('move', 20, 10)
+ self.assertEqual(len(events), 3)
+ self.assertEqual(events[-1], ('drag', 20, 10, Interaction.LEFT_BTN))
+ clickOrDrag.handleEvent('release', 20, 10, Interaction.LEFT_BTN)
+ self.assertEqual(len(events), 4)
+ self.assertEqual(events[-1], ('endDrag', (10, 10), (20, 10), Interaction.LEFT_BTN))
diff --git a/src/silx/gui/plot/test/testItem.py b/src/silx/gui/plot/test/testItem.py
new file mode 100644
index 0000000..0b15dc3
--- /dev/null
+++ b/src/silx/gui/plot/test/testItem.py
@@ -0,0 +1,360 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Tests for PlotWidget items."""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "01/09/2017"
+
+
+import unittest
+
+import numpy
+
+from silx.gui.utils.testutils import SignalListener
+from silx.gui.plot.items import ItemChangedType
+from silx.gui.plot import items
+from .utils import PlotWidgetTestCase
+
+
+class TestSigItemChangedSignal(PlotWidgetTestCase):
+ """Test item's sigItemChanged signal"""
+
+ def testCurveChanged(self):
+ """Test sigItemChanged for curve"""
+ self.plot.addCurve(numpy.arange(10), numpy.arange(10), legend='test')
+ curve = self.plot.getCurve('test')
+
+ listener = SignalListener()
+ curve.sigItemChanged.connect(listener)
+
+ # Test for signal in Item class
+ curve.setVisible(False)
+ curve.setVisible(True)
+ curve.setZValue(100)
+
+ # Test for signals in PointsBase class
+ curve.setData(numpy.arange(100), numpy.arange(100))
+
+ # SymbolMixIn
+ curve.setSymbol('Circle')
+ curve.setSymbol('d')
+ curve.setSymbolSize(20)
+
+ # AlphaMixIn
+ curve.setAlpha(0.5)
+
+ # Test for signals in Curve class
+ # ColorMixIn
+ curve.setColor('yellow')
+ # YAxisMixIn
+ curve.setYAxis('right')
+ # FillMixIn
+ curve.setFill(True)
+ # LineMixIn
+ curve.setLineStyle(':')
+ curve.setLineStyle(':') # Not sending event
+ curve.setLineWidth(2)
+
+ self.assertEqual(listener.arguments(argumentIndex=0),
+ [ItemChangedType.VISIBLE,
+ ItemChangedType.VISIBLE,
+ ItemChangedType.ZVALUE,
+ ItemChangedType.DATA,
+ ItemChangedType.SYMBOL,
+ ItemChangedType.SYMBOL,
+ ItemChangedType.SYMBOL_SIZE,
+ ItemChangedType.ALPHA,
+ ItemChangedType.COLOR,
+ ItemChangedType.YAXIS,
+ ItemChangedType.FILL,
+ ItemChangedType.LINE_STYLE,
+ ItemChangedType.LINE_WIDTH])
+
+ def testHistogramChanged(self):
+ """Test sigItemChanged for Histogram"""
+ self.plot.addHistogram(
+ numpy.arange(10), edges=numpy.arange(11), legend='test')
+ histogram = self.plot.getHistogram('test')
+ listener = SignalListener()
+ histogram.sigItemChanged.connect(listener)
+
+ # Test signals in Histogram class
+ histogram.setData(numpy.zeros(10), numpy.arange(11))
+
+ self.assertEqual(listener.arguments(argumentIndex=0),
+ [ItemChangedType.DATA])
+
+ def testImageDataChanged(self):
+ """Test sigItemChanged for ImageData"""
+ self.plot.addImage(numpy.arange(100).reshape(10, 10), legend='test')
+ image = self.plot.getImage('test')
+
+ listener = SignalListener()
+ image.sigItemChanged.connect(listener)
+
+ # ColormapMixIn
+ colormap = self.plot.getDefaultColormap().copy()
+ image.setColormap(colormap)
+ image.getColormap().setName('viridis')
+
+ # Test of signals in ImageBase class
+ image.setOrigin(10)
+ image.setScale(2)
+
+ # Test of signals in ImageData class
+ image.setData(numpy.ones((10, 10)))
+
+ self.assertEqual(listener.arguments(argumentIndex=0),
+ [ItemChangedType.COLORMAP,
+ ItemChangedType.COLORMAP,
+ ItemChangedType.POSITION,
+ ItemChangedType.SCALE,
+ ItemChangedType.COLORMAP,
+ ItemChangedType.DATA])
+
+ def testImageRgbaChanged(self):
+ """Test sigItemChanged for ImageRgba"""
+ self.plot.addImage(numpy.ones((10, 10, 3)), legend='rgb')
+ image = self.plot.getImage('rgb')
+
+ listener = SignalListener()
+ image.sigItemChanged.connect(listener)
+
+ # Test of signals in ImageRgba class
+ image.setData(numpy.zeros((10, 10, 3)))
+
+ self.assertEqual(listener.arguments(argumentIndex=0),
+ [ItemChangedType.DATA])
+
+ def testMarkerChanged(self):
+ """Test sigItemChanged for markers"""
+ self.plot.addMarker(10, 20, legend='test')
+ marker = self.plot._getMarker('test')
+
+ listener = SignalListener()
+ marker.sigItemChanged.connect(listener)
+
+ # Test signals in _BaseMarker
+ marker.setPosition(10, 10)
+ marker.setPosition(10, 10) # Not sending event
+ marker.setText('toto')
+ self.assertEqual(listener.arguments(argumentIndex=0),
+ [ItemChangedType.POSITION,
+ ItemChangedType.TEXT])
+
+ # XMarker
+ self.plot.addXMarker(10, legend='x')
+ marker = self.plot._getMarker('x')
+
+ listener = SignalListener()
+ marker.sigItemChanged.connect(listener)
+ marker.setPosition(20, 20)
+ self.assertEqual(listener.arguments(argumentIndex=0),
+ [ItemChangedType.POSITION])
+
+ # YMarker
+ self.plot.addYMarker(10, legend='x')
+ marker = self.plot._getMarker('x')
+
+ listener = SignalListener()
+ marker.sigItemChanged.connect(listener)
+ marker.setPosition(20, 20)
+ self.assertEqual(listener.arguments(argumentIndex=0),
+ [ItemChangedType.POSITION])
+
+ def testScatterChanged(self):
+ """Test sigItemChanged for scatter"""
+ data = numpy.arange(10)
+ self.plot.addScatter(data, data, data, legend='test')
+ scatter = self.plot.getScatter('test')
+
+ listener = SignalListener()
+ scatter.sigItemChanged.connect(listener)
+
+ # ColormapMixIn
+ scatter.getColormap().setName('viridis')
+
+ # Test of signals in Scatter class
+ scatter.setData((0, 1, 2), (1, 0, 2), (0, 1, 2))
+
+ # Visualization mode changed
+ scatter.setVisualization(scatter.Visualization.SOLID)
+
+ self.assertEqual(listener.arguments(),
+ [(ItemChangedType.COLORMAP,),
+ (ItemChangedType.DATA,),
+ (ItemChangedType.COLORMAP,),
+ (ItemChangedType.VISUALIZATION_MODE,)])
+
+ def testShapeChanged(self):
+ """Test sigItemChanged for shape"""
+ data = numpy.array((1., 10.))
+ self.plot.addShape(data, data, legend='test', shape='rectangle')
+ shape = self.plot._getItem(kind='item', legend='test')
+
+ listener = SignalListener()
+ shape.sigItemChanged.connect(listener)
+
+ shape.setOverlay(True)
+ shape.setPoints(((2., 2.), (3., 3.)))
+
+ self.assertEqual(listener.arguments(),
+ [(ItemChangedType.OVERLAY,),
+ (ItemChangedType.DATA,)])
+
+
+class TestSymbol(PlotWidgetTestCase):
+ """Test item's symbol """
+
+ def test(self):
+ """Test sigItemChanged for curve"""
+ self.plot.addCurve(numpy.arange(10), numpy.arange(10), legend='test')
+ curve = self.plot.getCurve('test')
+
+ # SymbolMixIn
+ curve.setSymbol('o')
+ name = curve.getSymbolName()
+ self.assertEqual('Circle', name)
+
+ name = curve.getSymbolName('d')
+ self.assertEqual('Diamond', name)
+
+
+class TestVisibleExtent(PlotWidgetTestCase):
+ """Test item's visible extent feature"""
+
+ def testGetVisibleBounds(self):
+ """Test Item.getVisibleBounds"""
+
+ # Create test items (with a bounding box of x: [1,3], y: [0,2])
+ curve = items.Curve()
+ curve.setData((1, 2, 3), (0, 1, 2))
+
+ histogram = items.Histogram()
+ histogram.setData((0, 1, 2), (1, 5/3, 7/3, 3))
+
+ image = items.ImageData()
+ image.setOrigin((1, 0))
+ image.setData(numpy.arange(4).reshape(2, 2))
+
+ scatter = items.Scatter()
+ scatter.setData((1, 2, 3), (0, 1, 2), (1, 2, 3))
+
+ bbox = items.BoundingRect()
+ bbox.setBounds((1, 3, 0, 2))
+
+ xaxis, yaxis = self.plot.getXAxis(), self.plot.getYAxis()
+ for item in (curve, histogram, image, scatter, bbox):
+ with self.subTest(item=item):
+ xaxis.setLimits(0, 100)
+ yaxis.setLimits(0, 100)
+ self.plot.addItem(item)
+ self.assertEqual(item.getVisibleBounds(), (1., 3., 0., 2.))
+
+ xaxis.setLimits(0.5, 2.5)
+ self.assertEqual(item.getVisibleBounds(), (1, 2.5, 0., 2.))
+
+ yaxis.setLimits(0.5, 1.5)
+ self.assertEqual(item.getVisibleBounds(), (1, 2.5, 0.5, 1.5))
+
+ item.setVisible(False)
+ self.assertIsNone(item.getVisibleBounds())
+
+ self.plot.clear()
+
+ def testVisibleExtentTracking(self):
+ """Test Item's visible extent tracking"""
+ image = items.ImageData()
+ image.setData(numpy.arange(6).reshape(2, 3))
+
+ listener = SignalListener()
+ image._sigVisibleBoundsChanged.connect(listener)
+ image._setVisibleBoundsTracking(True)
+ self.assertTrue(image._isVisibleBoundsTracking())
+
+ self.plot.addItem(image)
+ self.assertEqual(listener.callCount(), 1)
+
+ self.plot.getXAxis().setLimits(0, 1)
+ self.assertEqual(listener.callCount(), 2)
+
+ self.plot.hide()
+ self.qapp.processEvents()
+ # No event here
+ self.assertEqual(listener.callCount(), 2)
+
+ self.plot.getXAxis().setLimits(1, 2)
+ # No event since PlotWidget is hidden, delayed to PlotWidget show
+ self.assertEqual(listener.callCount(), 2)
+
+ self.plot.show()
+ self.qapp.processEvents()
+ # Receives delayed event now
+ self.assertEqual(listener.callCount(), 3)
+
+ image.setOrigin((-1, -1))
+ self.assertEqual(listener.callCount(), 4)
+
+ image.setVisible(False)
+ image.setOrigin((0, 0))
+ # No event since item is not visible
+ self.assertEqual(listener.callCount(), 4)
+
+ image.setVisible(True)
+ # Receives delayed event now
+ self.assertEqual(listener.callCount(), 5)
+
+
+class TestImageDataAggregated(PlotWidgetTestCase):
+ """Test ImageDataAggregated item"""
+
+ def test(self):
+ data = numpy.random.random(1024**2).reshape(1024, 1024)
+
+ item = items.ImageDataAggregated()
+ item.setData(data)
+ self.assertEqual(item.getAggregationMode(), item.Aggregation.NONE)
+ self.plot.addItem(item)
+
+ for mode in item.Aggregation.members():
+ with self.subTest(mode=mode):
+ self.plot.resetZoom()
+ self.qapp.processEvents()
+
+ item.setAggregationMode(mode)
+ self.qapp.processEvents()
+
+ # Zoom-out
+ for i in range(4):
+ xmin, xmax = self.plot.getXAxis().getLimits()
+ ymin, ymax = self.plot.getYAxis().getLimits()
+ self.plot.setLimits(
+ xmin - (xmax - xmin)/2,
+ xmax + (xmax - xmin)/2,
+ ymin - (ymax - ymin)/2,
+ ymax + (ymax - ymin)/2,
+ )
+ self.qapp.processEvents()
diff --git a/src/silx/gui/plot/test/testLegendSelector.py b/src/silx/gui/plot/test/testLegendSelector.py
new file mode 100644
index 0000000..c40875d
--- /dev/null
+++ b/src/silx/gui/plot/test/testLegendSelector.py
@@ -0,0 +1,130 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for PlotWidget"""
+
+__authors__ = ["T. Rueter", "T. Vincent"]
+__license__ = "MIT"
+__date__ = "15/05/2017"
+
+
+import logging
+import unittest
+
+from silx.gui import qt
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui.plot import LegendSelector
+
+
+_logger = logging.getLogger(__name__)
+
+
+class TestLegendSelector(TestCaseQt):
+ """Basic test for LegendSelector"""
+
+ def testLegendSelector(self):
+ """Test copied from __main__ of LegendSelector in PyMca"""
+ class Notifier(qt.QObject):
+ def __init__(self):
+ qt.QObject.__init__(self)
+ self.chk = True
+
+ def signalReceived(self, **kw):
+ obj = self.sender()
+ _logger.info('NOTIFIER -- signal received\n\tsender: %s',
+ str(obj))
+
+ notifier = Notifier()
+
+ legends = ['Legend0',
+ 'Legend1',
+ 'Long Legend 2',
+ 'Foo Legend 3',
+ 'Even Longer Legend 4',
+ 'Short Leg 5',
+ 'Dot symbol 6',
+ 'Comma symbol 7']
+ colors = [qt.Qt.darkRed, qt.Qt.green, qt.Qt.yellow, qt.Qt.darkCyan,
+ qt.Qt.blue, qt.Qt.darkBlue, qt.Qt.red, qt.Qt.darkYellow]
+ symbols = ['o', 't', '+', 'x', 's', 'd', '.', ',']
+
+ win = LegendSelector.LegendListView()
+ # win = LegendListContextMenu()
+ # win = qt.QWidget()
+ # layout = qt.QVBoxLayout()
+ # layout.setContentsMargins(0,0,0,0)
+ llist = []
+
+ for _idx, (l, c, s) in enumerate(zip(legends, colors, symbols)):
+ ddict = {
+ 'color': qt.QColor(c),
+ 'linewidth': 4,
+ 'symbol': s,
+ }
+ legend = l
+ llist.append((legend, ddict))
+ # item = qt.QListWidgetItem(win)
+ # legendWidget = LegendListItemWidget(l)
+ # legendWidget.icon.setSymbol(s)
+ # legendWidget.icon.setColor(qt.QColor(c))
+ # layout.addWidget(legendWidget)
+ # win.setItemWidget(item, legendWidget)
+
+ # win = LegendListItemWidget('Some Legend 1')
+ # print(llist)
+ model = LegendSelector.LegendModel(legendList=llist)
+ win.setModel(model)
+ win.setSelectionModel(qt.QItemSelectionModel(model))
+ win.setContextMenu()
+ # print('Edit triggers: %d'%win.editTriggers())
+
+ # win = LegendListWidget(None, legends)
+ # win[0].updateItem(ddict)
+ # win.setLayout(layout)
+ win.sigLegendSignal.connect(notifier.signalReceived)
+ win.show()
+
+ win.clear()
+ win.setLegendList(llist)
+
+ self.qWaitForWindowExposed(win)
+
+
+class TestRenameCurveDialog(TestCaseQt):
+ """Basic test for RenameCurveDialog"""
+
+ def testDialog(self):
+ """Create dialog, change name and press OK"""
+ self.dialog = LegendSelector.RenameCurveDialog(
+ None, 'curve1', ['curve1', 'curve2', 'curve3'])
+ self.dialog.open()
+ self.qWaitForWindowExposed(self.dialog)
+ self.keyClicks(self.dialog.lineEdit, 'changed')
+ self.mouseClick(self.dialog.okButton, qt.Qt.LeftButton)
+ self.qapp.processEvents()
+ ret = self.dialog.result()
+ self.assertEqual(ret, qt.QDialog.Accepted)
+ newName = self.dialog.getText()
+ self.assertEqual(newName, 'curve1changed')
+ del self.dialog
diff --git a/src/silx/gui/plot/test/testLimitConstraints.py b/src/silx/gui/plot/test/testLimitConstraints.py
new file mode 100644
index 0000000..0bd8e50
--- /dev/null
+++ b/src/silx/gui/plot/test/testLimitConstraints.py
@@ -0,0 +1,114 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test setLimitConstaints on the PlotWidget"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "30/08/2017"
+
+
+import unittest
+from silx.gui.plot import PlotWidget
+
+
+class TestLimitConstaints(unittest.TestCase):
+ """Tests setLimitConstaints class"""
+
+ def setUp(self):
+ self.plot = PlotWidget()
+
+ def tearDown(self):
+ self.plot = None
+
+ def testApi(self):
+ """Test availability of the API"""
+ self.plot.getXAxis().setLimitsConstraints(minPos=1, maxPos=10)
+ self.plot.getXAxis().setRangeConstraints(minRange=1, maxRange=1)
+ self.plot.getYAxis().setLimitsConstraints(minPos=1, maxPos=10)
+ self.plot.getYAxis().setRangeConstraints(minRange=1, maxRange=1)
+
+ def testXMinMax(self):
+ """Test limit constains on x-axis"""
+ self.plot.getXAxis().setLimitsConstraints(minPos=0, maxPos=100)
+ self.plot.setLimits(xmin=-1, xmax=101, ymin=-1, ymax=101)
+ self.assertEqual(self.plot.getXAxis().getLimits(), (0, 100))
+ self.assertEqual(self.plot.getYAxis().getLimits(), (-1, 101))
+
+ def testYMinMax(self):
+ """Test limit constains on y-axis"""
+ self.plot.getYAxis().setLimitsConstraints(minPos=0, maxPos=100)
+ self.plot.setLimits(xmin=-1, xmax=101, ymin=-1, ymax=101)
+ self.assertEqual(self.plot.getXAxis().getLimits(), (-1, 101))
+ self.assertEqual(self.plot.getYAxis().getLimits(), (0, 100))
+
+ def testMinXRange(self):
+ """Test min range constains on x-axis"""
+ self.plot.getXAxis().setRangeConstraints(minRange=100)
+ self.plot.setLimits(xmin=1, xmax=99, ymin=1, ymax=99)
+ limits = self.plot.getXAxis().getLimits()
+ self.assertEqual(limits[1] - limits[0], 100)
+ limits = self.plot.getYAxis().getLimits()
+ self.assertNotEqual(limits[1] - limits[0], 100)
+
+ def testMaxXRange(self):
+ """Test max range constains on x-axis"""
+ self.plot.getXAxis().setRangeConstraints(maxRange=100)
+ self.plot.setLimits(xmin=-1, xmax=101, ymin=-1, ymax=101)
+ limits = self.plot.getXAxis().getLimits()
+ self.assertEqual(limits[1] - limits[0], 100)
+ limits = self.plot.getYAxis().getLimits()
+ self.assertNotEqual(limits[1] - limits[0], 100)
+
+ def testMinYRange(self):
+ """Test min range constains on y-axis"""
+ self.plot.getYAxis().setRangeConstraints(minRange=100)
+ self.plot.setLimits(xmin=1, xmax=99, ymin=1, ymax=99)
+ limits = self.plot.getXAxis().getLimits()
+ self.assertNotEqual(limits[1] - limits[0], 100)
+ limits = self.plot.getYAxis().getLimits()
+ self.assertEqual(limits[1] - limits[0], 100)
+
+ def testMaxYRange(self):
+ """Test max range constains on y-axis"""
+ self.plot.getYAxis().setRangeConstraints(maxRange=100)
+ self.plot.setLimits(xmin=-1, xmax=101, ymin=-1, ymax=101)
+ limits = self.plot.getXAxis().getLimits()
+ self.assertNotEqual(limits[1] - limits[0], 100)
+ limits = self.plot.getYAxis().getLimits()
+ self.assertEqual(limits[1] - limits[0], 100)
+
+ def testChangeOfConstraints(self):
+ """Test changing of the constraints"""
+ self.plot.getXAxis().setRangeConstraints(minRange=10, maxRange=10)
+ # There is no more constraints on the range
+ self.plot.getXAxis().setRangeConstraints(minRange=None, maxRange=None)
+ self.plot.setLimits(xmin=-1, xmax=101, ymin=-1, ymax=101)
+ self.assertEqual(self.plot.getXAxis().getLimits(), (-1, 101))
+
+ def testSettingConstraints(self):
+ """Test setting a constaint (setLimits first then the constaint)"""
+ self.plot.setLimits(xmin=-1, xmax=101, ymin=-1, ymax=101)
+ self.plot.getXAxis().setLimitsConstraints(minPos=0, maxPos=100)
+ self.assertEqual(self.plot.getXAxis().getLimits(), (0, 100))
diff --git a/src/silx/gui/plot/test/testMaskToolsWidget.py b/src/silx/gui/plot/test/testMaskToolsWidget.py
new file mode 100644
index 0000000..522ca51
--- /dev/null
+++ b/src/silx/gui/plot/test/testMaskToolsWidget.py
@@ -0,0 +1,306 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for MaskToolsWidget"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "17/01/2018"
+
+
+import logging
+import os.path
+import unittest
+
+import numpy
+
+from silx.gui import qt
+from silx.test.utils import temp_dir
+from silx.utils.testutils import ParametricTestCase
+from silx.gui.utils.testutils import getQToolButtonFromAction
+from silx.gui.plot import PlotWindow, MaskToolsWidget
+from .utils import PlotWidgetTestCase
+
+import fabio
+
+
+_logger = logging.getLogger(__name__)
+
+
+class TestMaskToolsWidget(PlotWidgetTestCase, ParametricTestCase):
+ """Basic test for MaskToolsWidget"""
+
+ def _createPlot(self):
+ return PlotWindow()
+
+ def setUp(self):
+ super(TestMaskToolsWidget, self).setUp()
+ self.widget = MaskToolsWidget.MaskToolsDockWidget(plot=self.plot, name='TEST')
+ self.plot.addDockWidget(qt.Qt.BottomDockWidgetArea, self.widget)
+ self.maskWidget = self.widget.widget()
+
+ def tearDown(self):
+ del self.maskWidget
+ del self.widget
+ super(TestMaskToolsWidget, self).tearDown()
+
+ def testEmptyPlot(self):
+ """Empty plot, display MaskToolsDockWidget, toggle multiple masks"""
+ self.maskWidget.setMultipleMasks('single')
+ self.qapp.processEvents()
+
+ self.maskWidget.setMultipleMasks('exclusive')
+ self.qapp.processEvents()
+
+ def _drag(self):
+ """Drag from plot center to offset position"""
+ plot = self.plot.getWidgetHandle()
+ xCenter, yCenter = plot.width() // 2, plot.height() // 2
+ offset = min(plot.width(), plot.height()) // 10
+
+ pos0 = xCenter, yCenter
+ pos1 = xCenter + offset, yCenter + offset
+
+ self.mouseMove(plot, pos=(0, 0))
+ self.mouseMove(plot, pos=pos0)
+ self.qapp.processEvents()
+ self.mousePress(plot, qt.Qt.LeftButton, pos=pos0)
+ self.qapp.processEvents()
+ self.mouseMove(plot, pos=(pos0[0] + offset // 2, pos0[1] + offset // 2))
+ self.mouseMove(plot, pos=pos1)
+ self.qapp.processEvents()
+ self.mouseRelease(plot, qt.Qt.LeftButton, pos=pos1)
+ self.qapp.processEvents()
+ self.mouseMove(plot, pos=(0, 0))
+
+ def _drawPolygon(self):
+ """Draw a star polygon in the plot"""
+ plot = self.plot.getWidgetHandle()
+ x, y = plot.width() // 2, plot.height() // 2
+ offset = min(plot.width(), plot.height()) // 10
+
+ star = [(x, y + offset),
+ (x - offset, y - offset),
+ (x + offset, y),
+ (x - offset, y),
+ (x + offset, y - offset),
+ (x, y + offset)] # Close polygon
+
+ self.mouseMove(plot, pos=(0, 0))
+ for pos in star:
+ self.mouseMove(plot, pos=pos)
+ self.qapp.processEvents()
+ self.mousePress(plot, qt.Qt.LeftButton, pos=pos)
+ self.qapp.processEvents()
+ self.mouseRelease(plot, qt.Qt.LeftButton, pos=pos)
+ self.qapp.processEvents()
+
+ def _drawPencil(self):
+ """Draw a star polygon in the plot"""
+ plot = self.plot.getWidgetHandle()
+ x, y = plot.width() // 2, plot.height() // 2
+ offset = min(plot.width(), plot.height()) // 10
+
+ star = [(x, y + offset),
+ (x - offset, y - offset),
+ (x + offset, y),
+ (x - offset, y),
+ (x + offset, y - offset)]
+
+ self.mouseMove(plot, pos=(0, 0))
+ for start, end in zip(star[:-1], star[1:]):
+ self.mouseMove(plot, pos=start)
+ self.mousePress(plot, qt.Qt.LeftButton, pos=start)
+ self.qapp.processEvents()
+ self.mouseMove(plot, pos=end)
+ self.qapp.processEvents()
+ self.mouseRelease(plot, qt.Qt.LeftButton, pos=end)
+ self.qapp.processEvents()
+
+ def _isMaskItemSync(self):
+ """Check if masks from item and tools are sync or not"""
+ if self.maskWidget.isItemMaskUpdated():
+ return numpy.all(numpy.equal(
+ self.maskWidget.getSelectionMask(),
+ self.plot.getActiveImage().getMaskData(copy=False)))
+ else:
+ return True
+
+ def testWithAnImage(self):
+ """Plot with an image: test MaskToolsWidget interactions"""
+
+ # Add and remove a image (this should enable/disable GUI + change mask)
+ self.plot.addImage(numpy.random.random(1024**2).reshape(1024, 1024),
+ legend='test')
+ self.qapp.processEvents()
+
+ self.plot.remove('test', kind='image')
+ self.qapp.processEvents()
+
+ tests = [((0, 0), (1, 1)),
+ ((1000, 1000), (1, 1)),
+ ((0, 0), (-1, -1)),
+ ((1000, 1000), (-1, -1))]
+
+ for itemMaskUpdated in (False, True):
+ for origin, scale in tests:
+ with self.subTest(origin=origin, scale=scale):
+ self.maskWidget.setItemMaskUpdated(itemMaskUpdated)
+ self.plot.addImage(numpy.arange(1024**2).reshape(1024, 1024),
+ legend='test',
+ origin=origin,
+ scale=scale)
+ self.qapp.processEvents()
+
+ self.assertEqual(
+ self.maskWidget.isItemMaskUpdated(), itemMaskUpdated)
+
+ # Test draw rectangle #
+ toolButton = getQToolButtonFromAction(self.maskWidget.rectAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+
+ # mask
+ self.maskWidget.maskStateGroup.button(1).click()
+ self.qapp.processEvents()
+ self._drag()
+ self.assertFalse(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+ self.assertTrue(self._isMaskItemSync())
+
+ # unmask same region
+ self.maskWidget.maskStateGroup.button(0).click()
+ self.qapp.processEvents()
+ self._drag()
+ self.assertTrue(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+ self.assertTrue(self._isMaskItemSync())
+
+ # Test draw polygon #
+ toolButton = getQToolButtonFromAction(self.maskWidget.polygonAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+
+ # mask
+ self.maskWidget.maskStateGroup.button(1).click()
+ self.qapp.processEvents()
+ self._drawPolygon()
+ self.assertFalse(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+ self.assertTrue(self._isMaskItemSync())
+
+ # unmask same region
+ self.maskWidget.maskStateGroup.button(0).click()
+ self.qapp.processEvents()
+ self._drawPolygon()
+ self.assertTrue(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+ self.assertTrue(self._isMaskItemSync())
+
+ # Test draw pencil #
+ toolButton = getQToolButtonFromAction(self.maskWidget.pencilAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+
+ self.maskWidget.pencilSpinBox.setValue(30)
+ self.qapp.processEvents()
+
+ # mask
+ self.maskWidget.maskStateGroup.button(1).click()
+ self.qapp.processEvents()
+ self._drawPencil()
+ self.assertFalse(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+ self.assertTrue(self._isMaskItemSync())
+
+ # unmask same region
+ self.maskWidget.maskStateGroup.button(0).click()
+ self.qapp.processEvents()
+ self._drawPencil()
+ self.assertTrue(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+ self.assertTrue(self._isMaskItemSync())
+
+ # Test no draw tool #
+ toolButton = getQToolButtonFromAction(self.maskWidget.browseAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+
+ self.plot.clear()
+
+ def __loadSave(self, file_format):
+ """Plot with an image: test MaskToolsWidget operations"""
+ self.plot.addImage(numpy.arange(1024**2).reshape(1024, 1024),
+ legend='test')
+ self.qapp.processEvents()
+
+ # Draw a polygon mask
+ toolButton = getQToolButtonFromAction(self.maskWidget.polygonAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+ self._drawPolygon()
+
+ ref_mask = self.maskWidget.getSelectionMask()
+ self.assertFalse(numpy.all(numpy.equal(ref_mask, 0)))
+
+ with temp_dir() as tmp:
+ mask_filename = os.path.join(tmp, 'mask.' + file_format)
+ self.maskWidget.save(mask_filename, file_format)
+
+ self.maskWidget.resetSelectionMask()
+ self.assertTrue(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+
+ self.maskWidget.load(mask_filename)
+ self.assertTrue(numpy.all(numpy.equal(
+ self.maskWidget.getSelectionMask(), ref_mask)))
+
+ def testLoadSaveNpy(self):
+ self.__loadSave("npy")
+
+ def testLoadSaveFit2D(self):
+ self.__loadSave("msk")
+
+ def testSigMaskChangedEmitted(self):
+ self.plot.addImage(numpy.arange(512**2).reshape(512, 512),
+ legend='test')
+ self.plot.resetZoom()
+ self.qapp.processEvents()
+
+ l = []
+
+ def slot():
+ l.append(1)
+
+ self.maskWidget.sigMaskChanged.connect(slot)
+
+ # rectangle mask
+ toolButton = getQToolButtonFromAction(self.maskWidget.rectAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+ self.maskWidget.maskStateGroup.button(1).click()
+ self.qapp.processEvents()
+ self._drag()
+
+ self.assertGreater(len(l), 0)
diff --git a/src/silx/gui/plot/test/testPixelIntensityHistoAction.py b/src/silx/gui/plot/test/testPixelIntensityHistoAction.py
new file mode 100644
index 0000000..14a467d
--- /dev/null
+++ b/src/silx/gui/plot/test/testPixelIntensityHistoAction.py
@@ -0,0 +1,145 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for PixelIntensitiesHistoAction"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "02/03/2018"
+
+
+import numpy
+import unittest
+
+from silx.utils.testutils import ParametricTestCase
+from silx.gui.utils.testutils import TestCaseQt, getQToolButtonFromAction
+from silx.gui import qt
+from silx.gui.plot import Plot2D
+
+
+class TestPixelIntensitiesHisto(TestCaseQt, ParametricTestCase):
+ """Tests for PixelIntensitiesHistoAction widget."""
+
+ def setUp(self):
+ super(TestPixelIntensitiesHisto, self).setUp()
+ self.image = numpy.random.rand(10, 10)
+ self.plotImage = Plot2D()
+ self.plotImage.getIntensityHistogramAction().setVisible(True)
+
+ def tearDown(self):
+ del self.plotImage
+ super(TestPixelIntensitiesHisto, self).tearDown()
+
+ def testShowAndHide(self):
+ """Simple test that the plot is showing and hiding when activating the
+ action"""
+ self.plotImage.addImage(self.image, origin=(0, 0), legend='sino')
+ self.plotImage.show()
+
+ histoAction = self.plotImage.getIntensityHistogramAction()
+
+ # test the pixel intensity diagram is showing
+ button = getQToolButtonFromAction(histoAction)
+ self.assertIsNot(button, None)
+ self.mouseMove(button)
+ self.mouseClick(button, qt.Qt.LeftButton)
+ self.qapp.processEvents()
+ self.assertTrue(histoAction.getHistogramWidget().isVisible())
+
+ # test the pixel intensity diagram is hiding
+ self.qapp.setActiveWindow(self.plotImage)
+ self.qapp.processEvents()
+ self.mouseMove(button)
+ self.mouseClick(button, qt.Qt.LeftButton)
+ self.qapp.processEvents()
+ self.assertFalse(histoAction.getHistogramWidget().isVisible())
+
+ def testImageFormatInput(self):
+ """Test multiple type as image input"""
+ typesToTest = [numpy.uint8, numpy.int8, numpy.int16, numpy.int32,
+ numpy.float32, numpy.float64]
+ self.plotImage.addImage(self.image, origin=(0, 0), legend='sino')
+ self.plotImage.show()
+ button = getQToolButtonFromAction(
+ self.plotImage.getIntensityHistogramAction())
+ self.mouseMove(button)
+ self.mouseClick(button, qt.Qt.LeftButton)
+ self.qapp.processEvents()
+ for typeToTest in typesToTest:
+ with self.subTest(typeToTest=typeToTest):
+ self.plotImage.addImage(self.image.astype(typeToTest),
+ origin=(0, 0), legend='sino')
+
+ def testScatter(self):
+ """Test that an histogram from a scatter is displayed"""
+ xx = numpy.arange(10)
+ yy = numpy.arange(10)
+ value = numpy.sin(xx)
+ self.plotImage.addScatter(xx, yy, value)
+ self.plotImage.show()
+
+ histoAction = self.plotImage.getIntensityHistogramAction()
+
+ # test the pixel intensity diagram is showing
+ button = getQToolButtonFromAction(histoAction)
+ self.assertIsNot(button, None)
+ self.mouseMove(button)
+ self.mouseClick(button, qt.Qt.LeftButton)
+ self.qapp.processEvents()
+
+ widget = histoAction.getHistogramWidget()
+ self.assertTrue(widget.isVisible())
+ items = widget.getPlotWidget().getItems()
+ self.assertEqual(len(items), 1)
+
+ def testChangeItem(self):
+ """Test that histogram changes it the item changes"""
+ xx = numpy.arange(10)
+ yy = numpy.arange(10)
+ value = numpy.sin(xx)
+ self.plotImage.addScatter(xx, yy, value)
+ self.plotImage.show()
+
+ histoAction = self.plotImage.getIntensityHistogramAction()
+
+ # test the pixel intensity diagram is showing
+ button = getQToolButtonFromAction(histoAction)
+ self.assertIsNot(button, None)
+ self.mouseMove(button)
+ self.mouseClick(button, qt.Qt.LeftButton)
+ self.qapp.processEvents()
+
+ # Reach histogram from the first item
+ widget = histoAction.getHistogramWidget()
+ self.assertTrue(widget.isVisible())
+ items = widget.getPlotWidget().getItems()
+ data1 = items[0].getValueData(copy=False)
+
+ # Set another item to the plot
+ self.plotImage.addImage(self.image, origin=(0, 0), legend='sino')
+ self.qapp.processEvents()
+ data2 = items[0].getValueData(copy=False)
+
+ # Histogram is not the same
+ self.assertFalse(numpy.array_equal(data1, data2))
diff --git a/src/silx/gui/plot/test/testPlotActions.py b/src/silx/gui/plot/test/testPlotActions.py
new file mode 100644
index 0000000..f38e05b
--- /dev/null
+++ b/src/silx/gui/plot/test/testPlotActions.py
@@ -0,0 +1,110 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test of actions integrated in the plot window"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "09/11/2018"
+
+
+import pytest
+import weakref
+
+from silx.gui import qt
+from silx.gui.colors import Colormap
+from silx.gui.plot.PlotWindow import PlotWindow
+
+import numpy
+
+
+@pytest.fixture
+def colormap1():
+ colormap = Colormap(name='gray',
+ vmin=10.0, vmax=20.0,
+ normalization='linear')
+ yield colormap
+
+
+@pytest.fixture
+def colormap2():
+ colormap = Colormap(name='red',
+ vmin=10.0, vmax=20.0,
+ normalization='linear')
+ yield colormap
+
+
+@pytest.fixture
+def plot(qapp):
+ plot = PlotWindow()
+ plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ yield weakref.proxy(plot)
+ plot.close()
+ qapp.processEvents()
+
+
+def test_action_active_colormap(qapp_utils, plot, colormap1, colormap2):
+ plot.getColormapAction()._actionTriggered(checked=True)
+ colormapDialog = plot.getColormapAction()._dialog
+
+ defaultColormap = plot.getDefaultColormap()
+ assert colormapDialog.getColormap() is defaultColormap
+
+ plot.addImage(data=numpy.random.rand(10, 10), legend='img1',
+ origin=(0, 0),
+ colormap=colormap1)
+ plot.setActiveImage('img1')
+ assert colormapDialog.getColormap() is colormap1
+
+ plot.addImage(data=numpy.random.rand(10, 10), legend='img2',
+ origin=(0, 0), colormap=colormap2)
+ plot.addImage(data=numpy.random.rand(10, 10), legend='img3',
+ origin=(0, 0))
+
+ plot.setActiveImage('img3')
+ assert colormapDialog.getColormap() is defaultColormap
+ plot.getActiveImage().setColormap(colormap2)
+ assert colormapDialog.getColormap() is colormap2
+
+ plot.remove('img2')
+ plot.remove('img3')
+ plot.remove('img1')
+ assert colormapDialog.getColormap() is defaultColormap
+
+
+def test_action_show_hide_colormap_dialog(qapp_utils, plot, colormap1):
+ plot.getColormapAction()._actionTriggered(checked=True)
+ colormapDialog = plot.getColormapAction()._dialog
+
+ plot.getColormapAction()._actionTriggered(checked=False)
+ assert not plot.getColormapAction().isChecked()
+ plot.getColormapAction()._actionTriggered(checked=True)
+ assert plot.getColormapAction().isChecked()
+ plot.addImage(data=numpy.random.rand(10, 10), legend='img1',
+ origin=(0, 0), colormap=colormap1)
+ colormap1.setName('red')
+ plot.getColormapAction()._actionTriggered()
+ colormap1.setName('blue')
+ colormapDialog.close()
+ assert not plot.getColormapAction().isChecked()
diff --git a/src/silx/gui/plot/test/testPlotInteraction.py b/src/silx/gui/plot/test/testPlotInteraction.py
new file mode 100644
index 0000000..fba364e
--- /dev/null
+++ b/src/silx/gui/plot/test/testPlotInteraction.py
@@ -0,0 +1,160 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016=2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Tests of plot interaction, through a PlotWidget"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "01/09/2017"
+
+
+import unittest
+from silx.gui import qt
+from .utils import PlotWidgetTestCase
+
+
+class _SignalDump(object):
+ """Callable object that store passed arguments in a list"""
+
+ def __init__(self):
+ self._received = []
+
+ def __call__(self, *args):
+ self._received.append(args)
+
+ @property
+ def received(self):
+ """Return a shallow copy of the list of received arguments"""
+ return list(self._received)
+
+
+class TestSelectPolygon(PlotWidgetTestCase):
+ """Test polygon selection interaction"""
+
+ def _interactionModeChanged(self, source):
+ """Check that source received in event is the correct one"""
+ self.assertEqual(source, self)
+
+ def _draw(self, polygon):
+ """Draw a polygon in the plot
+
+ :param polygon: List of points (x, y) of the polygon (closed)
+ """
+ plot = self.plot.getWidgetHandle()
+
+ dump = _SignalDump()
+ self.plot.sigPlotSignal.connect(dump)
+
+ for pos in polygon:
+ self.mouseMove(plot, pos=pos)
+ self.qapp.processEvents()
+ self.mousePress(plot, qt.Qt.LeftButton, pos=pos)
+ self.qapp.processEvents()
+ self.mouseRelease(plot, qt.Qt.LeftButton, pos=pos)
+ self.qapp.processEvents()
+
+ self.plot.sigPlotSignal.disconnect(dump)
+ return [args[0] for args in dump.received]
+
+ def test(self):
+ """Test draw polygons + events"""
+ self.plot.sigInteractiveModeChanged.connect(
+ self._interactionModeChanged)
+
+ self.plot.setInteractiveMode(
+ 'draw', shape='polygon', label='test', source=self)
+ interaction = self.plot.getInteractiveMode()
+
+ self.assertEqual(interaction['mode'], 'draw')
+ self.assertEqual(interaction['shape'], 'polygon')
+
+ self.plot.sigInteractiveModeChanged.disconnect(
+ self._interactionModeChanged)
+
+ plot = self.plot.getWidgetHandle()
+ xCenter, yCenter = plot.width() // 2, plot.height() // 2
+ offset = min(plot.width(), plot.height()) // 10
+
+ # Star polygon
+ star = [(xCenter, yCenter + offset),
+ (xCenter - offset, yCenter - offset),
+ (xCenter + offset, yCenter),
+ (xCenter - offset, yCenter),
+ (xCenter + offset, yCenter - offset),
+ (xCenter, yCenter + offset)] # Close polygon
+
+ # Draw while dumping signals
+ events = self._draw(star)
+
+ # Test last event
+ drawEvents = [event for event in events
+ if event['event'].startswith('drawing')]
+ self.assertEqual(drawEvents[-1]['event'], 'drawingFinished')
+ self.assertEqual(len(drawEvents[-1]['points']), 6)
+
+ # Large square
+ largeSquare = [(xCenter - offset, yCenter - offset),
+ (xCenter + offset, yCenter - offset),
+ (xCenter + offset, yCenter + offset),
+ (xCenter - offset, yCenter + offset),
+ (xCenter - offset, yCenter - offset)] # Close polygon
+
+ # Draw while dumping signals
+ events = self._draw(largeSquare)
+
+ # Test last event
+ drawEvents = [event for event in events
+ if event['event'].startswith('drawing')]
+ self.assertEqual(drawEvents[-1]['event'], 'drawingFinished')
+ self.assertEqual(len(drawEvents[-1]['points']), 5)
+
+ # Rectangle too thin along X: Some points are ignored
+ thinRectX = [(xCenter, yCenter - offset),
+ (xCenter, yCenter + offset),
+ (xCenter + 1, yCenter + offset),
+ (xCenter + 1, yCenter - offset)] # Close polygon
+
+ # Draw while dumping signals
+ events = self._draw(thinRectX)
+
+ # Test last event
+ drawEvents = [event for event in events
+ if event['event'].startswith('drawing')]
+ self.assertEqual(drawEvents[-1]['event'], 'drawingFinished')
+ self.assertEqual(len(drawEvents[-1]['points']), 3)
+
+ # Rectangle too thin along Y: Some points are ignored
+ thinRectY = [(xCenter - offset, yCenter),
+ (xCenter + offset, yCenter),
+ (xCenter + offset, yCenter + 1),
+ (xCenter - offset, yCenter + 1)] # Close polygon
+
+ # Draw while dumping signals
+ events = self._draw(thinRectY)
+
+ # Test last event
+ drawEvents = [event for event in events
+ if event['event'].startswith('drawing')]
+ self.assertEqual(drawEvents[-1]['event'], 'drawingFinished')
+ self.assertEqual(len(drawEvents[-1]['points']), 3)
diff --git a/src/silx/gui/plot/test/testPlotWidget.py b/src/silx/gui/plot/test/testPlotWidget.py
new file mode 100755
index 0000000..f6e108d
--- /dev/null
+++ b/src/silx/gui/plot/test/testPlotWidget.py
@@ -0,0 +1,2113 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for PlotWidget"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "03/01/2019"
+
+
+import unittest
+import logging
+import numpy
+import pytest
+
+from silx.utils.testutils import ParametricTestCase
+from silx.gui.utils.testutils import SignalListener
+from silx.gui.utils.testutils import TestCaseQt
+
+from silx.gui import qt
+from silx.gui.plot import PlotWidget
+from silx.gui.plot.items.curve import CurveStyle
+from silx.gui.plot.items import BoundingRect, XAxisExtent, YAxisExtent, Axis
+from silx.gui.colors import Colormap
+
+from .utils import PlotWidgetTestCase
+
+
+SIZE = 1024
+"""Size of the test image"""
+
+DATA_2D = numpy.arange(SIZE ** 2).reshape(SIZE, SIZE)
+"""Image data set"""
+
+
+logger = logging.getLogger(__name__)
+
+
+class TestSpecialBackend(PlotWidgetTestCase, ParametricTestCase):
+
+ def __init__(self, methodName='runTest', backend=None):
+ TestCaseQt.__init__(self, methodName=methodName)
+ self.__backend = backend
+
+ def _createPlot(self):
+ return PlotWidget(backend=self.__backend)
+
+ def testPlot(self):
+ self.assertIsNotNone(self.plot)
+
+
+class TestPlotWidget(PlotWidgetTestCase, ParametricTestCase):
+ """Basic tests for PlotWidget"""
+
+ def testShow(self):
+ """Most basic test"""
+ pass
+
+ def testSetTitleLabels(self):
+ """Set title and axes labels"""
+
+ title, xlabel, ylabel = 'the title', 'x label', 'y label'
+ self.plot.setGraphTitle(title)
+ self.plot.getXAxis().setLabel(xlabel)
+ self.plot.getYAxis().setLabel(ylabel)
+ self.qapp.processEvents()
+
+ self.assertEqual(self.plot.getGraphTitle(), title)
+ self.assertEqual(self.plot.getXAxis().getLabel(), xlabel)
+ self.assertEqual(self.plot.getYAxis().getLabel(), ylabel)
+
+ def _checkLimits(self,
+ expectedXLim=None,
+ expectedYLim=None,
+ expectedRatio=None):
+ """Assert that limits are as expected"""
+ xlim = self.plot.getXAxis().getLimits()
+ ylim = self.plot.getYAxis().getLimits()
+ ratio = abs(xlim[1] - xlim[0]) / abs(ylim[1] - ylim[0])
+
+ if expectedXLim is not None:
+ self.assertEqual(expectedXLim, xlim)
+
+ if expectedYLim is not None:
+ self.assertEqual(expectedYLim, ylim)
+
+ if expectedRatio is not None:
+ self.assertTrue(
+ numpy.allclose(expectedRatio, ratio, atol=0.01))
+
+ def testChangeLimitsWithAspectRatio(self):
+ self.plot.setKeepDataAspectRatio()
+ self.qapp.processEvents()
+ xlim = self.plot.getXAxis().getLimits()
+ ylim = self.plot.getYAxis().getLimits()
+ defaultRatio = abs(xlim[1] - xlim[0]) / abs(ylim[1] - ylim[0])
+
+ self.plot.getXAxis().setLimits(1., 10.)
+ self._checkLimits(expectedXLim=(1., 10.), expectedRatio=defaultRatio)
+ self.qapp.processEvents()
+ self._checkLimits(expectedXLim=(1., 10.), expectedRatio=defaultRatio)
+
+ self.plot.getYAxis().setLimits(1., 10.)
+ self._checkLimits(expectedYLim=(1., 10.), expectedRatio=defaultRatio)
+ self.qapp.processEvents()
+ self._checkLimits(expectedYLim=(1., 10.), expectedRatio=defaultRatio)
+
+ def testResizeWidget(self):
+ """Test resizing the widget and receiving limitsChanged events"""
+ self.plot.resize(200, 200)
+ self.qapp.processEvents()
+ self.qWait(100)
+
+ xlim = self.plot.getXAxis().getLimits()
+ ylim = self.plot.getYAxis().getLimits()
+
+ listener = SignalListener()
+ self.plot.getXAxis().sigLimitsChanged.connect(listener.partial('x'))
+ self.plot.getYAxis().sigLimitsChanged.connect(listener.partial('y'))
+
+ # Resize without aspect ratio
+ self.plot.resize(200, 300)
+ self.qapp.processEvents()
+ self.qWait(100)
+ self._checkLimits(expectedXLim=xlim, expectedYLim=ylim)
+ self.assertEqual(listener.callCount(), 0)
+
+ # Resize with aspect ratio
+ self.plot.setKeepDataAspectRatio(True)
+ self.qapp.processEvents()
+ self.qWait(1000)
+ listener.clear() # Clean-up received signal
+
+ self.plot.resize(200, 200)
+ self.qapp.processEvents()
+ self.qWait(100)
+ self.assertNotEqual(listener.callCount(), 0)
+
+ def testAddRemoveItemSignals(self):
+ """Test sigItemAdded and sigItemAboutToBeRemoved"""
+ listener = SignalListener()
+ self.plot.sigItemAdded.connect(listener.partial('add'))
+ self.plot.sigItemAboutToBeRemoved.connect(listener.partial('remove'))
+
+ self.plot.addCurve((1, 2, 3), (3, 2, 1), legend='curve')
+ self.assertEqual(listener.callCount(), 1)
+
+ curve = self.plot.getCurve('curve')
+ self.plot.remove('curve')
+ self.assertEqual(listener.callCount(), 2)
+ self.assertEqual(listener.arguments(callIndex=0), ('add', curve))
+ self.assertEqual(listener.arguments(callIndex=1), ('remove', curve))
+
+ def testGetItems(self):
+ """Test getItems method"""
+ curve_x = 1, 2
+ self.plot.addCurve(curve_x, (3, 4))
+ image = (0, 1), (2, 3)
+ self.plot.addImage(image)
+ scatter_x = 10, 11
+ self.plot.addScatter(scatter_x, (12, 13), (0, 1))
+ marker_pos = 5, 5
+ self.plot.addMarker(*marker_pos)
+ marker_x = 6
+ self.plot.addXMarker(marker_x)
+ self.plot.addShape((0, 5), (2, 10), shape='rectangle')
+
+ items = self.plot.getItems()
+ self.assertEqual(len(items), 6)
+ self.assertTrue(numpy.all(numpy.equal(items[0].getXData(), curve_x)))
+ self.assertTrue(numpy.all(numpy.equal(items[1].getData(), image)))
+ self.assertTrue(numpy.all(numpy.equal(items[2].getXData(), scatter_x)))
+ self.assertTrue(numpy.all(numpy.equal(items[3].getPosition(), marker_pos)))
+ self.assertTrue(numpy.all(numpy.equal(items[4].getPosition()[0], marker_x)))
+ self.assertEqual(items[5].getType(), 'rectangle')
+
+ def testRemoveDiscardItem(self):
+ """Test removeItem and discardItem"""
+ self.plot.addCurve((1, 2, 3), (1, 2, 3))
+ curve = self.plot.getItems()[0]
+ self.plot.removeItem(curve)
+ with self.assertRaises(ValueError):
+ self.plot.removeItem(curve)
+
+ self.plot.addCurve((1, 2, 3), (1, 2, 3))
+ curve = self.plot.getItems()[0]
+ result = self.plot.discardItem(curve)
+ self.assertTrue(result)
+ result = self.plot.discardItem(curve)
+ self.assertFalse(result)
+
+ def testBackGroundColors(self):
+ self.plot.setVisible(True)
+ self.qWaitForWindowExposed(self.plot)
+ self.qapp.processEvents()
+
+ # Custom the full background
+ color = self.plot.getBackgroundColor()
+ self.assertTrue(color.isValid())
+ self.assertEqual(color, qt.QColor(255, 255, 255))
+ self.plot.setBackgroundColor("red")
+ color = self.plot.getBackgroundColor()
+ self.assertTrue(color.isValid())
+ self.qapp.processEvents()
+
+ # Custom the data background
+ color = self.plot.getDataBackgroundColor()
+ self.assertFalse(color.isValid())
+ self.plot.setDataBackgroundColor("red")
+ color = self.plot.getDataBackgroundColor()
+ self.assertTrue(color.isValid())
+ self.qapp.processEvents()
+
+ # Back to default
+ self.plot.setBackgroundColor('white')
+ self.plot.setDataBackgroundColor(None)
+ color = self.plot.getBackgroundColor()
+ self.assertTrue(color.isValid())
+ self.assertEqual(color, qt.QColor(255, 255, 255))
+ color = self.plot.getDataBackgroundColor()
+ self.assertFalse(color.isValid())
+ self.qapp.processEvents()
+
+
+class TestPlotImage(PlotWidgetTestCase, ParametricTestCase):
+ """Basic tests for addImage"""
+
+ def setUp(self):
+ super(TestPlotImage, self).setUp()
+
+ self.plot.getYAxis().setLabel('Rows')
+ self.plot.getXAxis().setLabel('Columns')
+
+ def testPlotColormapTemperature(self):
+ self.plot.setGraphTitle('Temp. Linear')
+
+ colormap = Colormap(name='temperature',
+ normalization='linear',
+ vmin=None,
+ vmax=None)
+ self.plot.addImage(DATA_2D, legend="image 1", colormap=colormap)
+
+ def testPlotColormapGray(self):
+ self.plot.setKeepDataAspectRatio(False)
+ self.plot.setGraphTitle('Gray Linear')
+
+ colormap = Colormap(name='gray',
+ normalization='linear',
+ vmin=None,
+ vmax=None)
+ self.plot.addImage(DATA_2D, legend="image 1", colormap=colormap)
+
+ def testPlotColormapTemperatureLog(self):
+ self.plot.setGraphTitle('Temp. Log')
+
+ colormap = Colormap(name='temperature',
+ normalization=Colormap.LOGARITHM,
+ vmin=None,
+ vmax=None)
+ self.plot.addImage(DATA_2D, legend="image 1", colormap=colormap)
+
+ def testPlotRgbRgba(self):
+ self.plot.setKeepDataAspectRatio(False)
+ self.plot.setGraphTitle('RGB + RGBA')
+
+ rgb = numpy.array(
+ (((0, 0, 0), (128, 0, 0), (255, 0, 0)),
+ ((0, 128, 0), (0, 128, 128), (0, 128, 255))),
+ dtype=numpy.uint8)
+
+ self.plot.addImage(rgb, legend="rgb_uint8",
+ origin=(0, 0), scale=(1, 1),
+ resetzoom=False)
+
+ rgb = numpy.array(
+ (((0, 0, 0), (32768, 0, 0), (65535, 0, 0)),
+ ((0, 32768, 0), (0, 32768, 32768), (0, 32768, 65535))),
+ dtype=numpy.uint16)
+
+ self.plot.addImage(rgb, legend="rgb_uint16",
+ origin=(3, 2), scale=(2, 2),
+ resetzoom=False)
+
+ rgba = numpy.array(
+ (((0, 0, 0, .5), (.5, 0, 0, 1), (1, 0, 0, .5)),
+ ((0, .5, 0, 1), (0, .5, .5, 1), (0, 1, 1, .5))),
+ dtype=numpy.float32)
+
+ self.plot.addImage(rgba, legend="rgba_float32",
+ origin=(9, 6), scale=(1, 1),
+ resetzoom=False)
+
+ self.plot.resetZoom()
+
+ def testPlotColormapCustom(self):
+ self.plot.setKeepDataAspectRatio(False)
+ self.plot.setGraphTitle('Custom colormap')
+
+ colormap = Colormap(name=None,
+ normalization=Colormap.LINEAR,
+ vmin=None,
+ vmax=None,
+ colors=((0., 0., 0.), (1., 0., 0.),
+ (0., 1., 0.), (0., 0., 1.)))
+ self.plot.addImage(DATA_2D, legend="image 1", colormap=colormap,
+ resetzoom=False)
+
+ colormap = Colormap(name=None,
+ normalization=Colormap.LINEAR,
+ vmin=None,
+ vmax=None,
+ colors=numpy.array(
+ ((0, 0, 0, 0), (0, 0, 0, 128),
+ (128, 128, 128, 128), (255, 255, 255, 255)),
+ dtype=numpy.uint8))
+ self.plot.addImage(DATA_2D, legend="image 2", colormap=colormap,
+ origin=(DATA_2D.shape[0], 0),
+ resetzoom=False)
+ self.plot.resetZoom()
+
+ def testPlotColormapNaNColor(self):
+ self.plot.setKeepDataAspectRatio(False)
+ self.plot.setGraphTitle('Colormap with NaN color')
+
+ colormap = Colormap()
+ colormap.setNaNColor('red')
+ self.assertEqual(colormap.getNaNColor(), qt.QColor(255, 0, 0))
+ data = DATA_2D.astype(numpy.float32)
+ data[len(data)//2:] = numpy.nan
+ self.plot.addImage(data, legend="image 1", colormap=colormap,
+ resetzoom=False)
+ self.plot.resetZoom()
+
+ colormap.setNaNColor((0., 1., 0., 1.))
+ self.assertEqual(colormap.getNaNColor(), qt.QColor(0, 255, 0))
+ self.qapp.processEvents()
+
+ def testImageOriginScale(self):
+ """Test of image with different origin and scale"""
+ self.plot.setGraphTitle('origin and scale')
+
+ tests = [ # (origin, scale)
+ ((10, 20), (1, 1)),
+ ((10, 20), (-1, -1)),
+ ((-10, 20), (2, 1)),
+ ((10, -20), (-1, -2)),
+ (100, 2),
+ (-100, (1, 1)),
+ ((10, 20), 2),
+ ]
+
+ for origin, scale in tests:
+ with self.subTest(origin=origin, scale=scale):
+ self.plot.addImage(DATA_2D, origin=origin, scale=scale)
+
+ try:
+ ox, oy = origin
+ except TypeError:
+ ox, oy = origin, origin
+ try:
+ sx, sy = scale
+ except TypeError:
+ sx, sy = scale, scale
+ xbounds = ox, ox + DATA_2D.shape[1] * sx
+ ybounds = oy, oy + DATA_2D.shape[0] * sy
+
+ # Check limits without aspect ratio
+ xmin, xmax = self.plot.getXAxis().getLimits()
+ ymin, ymax = self.plot.getYAxis().getLimits()
+ self.assertEqual(xmin, min(xbounds))
+ self.assertEqual(xmax, max(xbounds))
+ self.assertEqual(ymin, min(ybounds))
+ self.assertEqual(ymax, max(ybounds))
+
+ # Check limits with aspect ratio
+ self.plot.setKeepDataAspectRatio(True)
+ xmin, xmax = self.plot.getXAxis().getLimits()
+ ymin, ymax = self.plot.getYAxis().getLimits()
+ self.assertTrue(round(xmin, 7) <= min(xbounds))
+ self.assertTrue(round(xmax, 7) >= max(xbounds))
+ self.assertTrue(round(ymin, 7) <= min(ybounds))
+ self.assertTrue(round(ymax, 7) >= max(ybounds))
+
+ self.plot.setKeepDataAspectRatio(False) # Reset aspect ratio
+ self.plot.clear()
+ self.plot.resetZoom()
+
+ def testPlotColormapDictAPI(self):
+ """Test that the addImage API using a colormap dictionary is still
+ working"""
+ self.plot.setGraphTitle('Temp. Log')
+
+ colormap = {
+ 'name': 'temperature',
+ 'normalization': 'log',
+ 'vmin': None,
+ 'vmax': None
+ }
+ self.plot.addImage(DATA_2D, legend="image 1", colormap=colormap)
+
+ def testPlotComplexImage(self):
+ """Test that a complex image is displayed as its absolute value."""
+ data = numpy.linspace(1, 1j, 100).reshape(10, 10)
+ self.plot.addImage(data, legend='complex')
+
+ image = self.plot.getActiveImage()
+ retrievedData = image.getData(copy=False)
+ self.assertTrue(
+ numpy.all(numpy.equal(retrievedData, numpy.absolute(data))))
+
+ def testPlotBooleanImage(self):
+ """Test that a boolean image is displayed and converted to int8."""
+ data = numpy.zeros((10, 10), dtype=bool)
+ data[::2, ::2] = True
+ self.plot.addImage(data, legend='boolean')
+
+ image = self.plot.getActiveImage()
+ retrievedData = image.getData(copy=False)
+ self.assertTrue(numpy.all(numpy.equal(retrievedData, data)))
+ self.assertIs(retrievedData.dtype.type, numpy.int8)
+
+ def testPlotAlphaImage(self):
+ """Test with an alpha image layer"""
+ data = numpy.random.random((10, 10))
+ alpha = numpy.linspace(0, 1, 100).reshape(10, 10)
+ self.plot.addImage(data, legend='image')
+ image = self.plot.getActiveImage()
+ image.setData(data, alpha=alpha)
+ self.qapp.processEvents()
+ self.assertTrue(numpy.array_equal(alpha, image.getAlphaData()))
+
+
+class TestPlotCurve(PlotWidgetTestCase):
+ """Basic tests for addCurve."""
+
+ # Test data sets
+ xData = numpy.arange(1000)
+ yData = -500 + 100 * numpy.sin(xData)
+ xData2 = xData + 1000
+ yData2 = xData - 1000 + 200 * numpy.random.random(1000)
+
+ def setUp(self):
+ super(TestPlotCurve, self).setUp()
+ self.plot.setGraphTitle('Curve')
+ self.plot.getYAxis().setLabel('Rows')
+ self.plot.getXAxis().setLabel('Columns')
+
+ self.plot.setActiveCurveHandling(False)
+
+ def testPlotCurveInfinite(self):
+ """Test plot curves with not finite data"""
+ tests = {
+ 'y all not finite': ([0, 1, 2], [numpy.inf, numpy.nan, -numpy.inf]),
+ 'x all not finite': ([numpy.inf, numpy.nan, -numpy.inf], [0, 1, 2]),
+ 'x some inf': ([0, numpy.inf, 2], [0, 1, 2]),
+ 'y some inf': ([0, 1, 2], [0, numpy.inf, 2])
+ }
+ for name, args in tests.items():
+ with self.subTest(name):
+ self.plot.addCurve(*args)
+ self.plot.resetZoom()
+ self.qapp.processEvents()
+ self.plot.clear()
+
+ def testPlotCurveColorFloat(self):
+ color = numpy.array(numpy.random.random(3 * 1000),
+ dtype=numpy.float32).reshape(1000, 3)
+
+ self.plot.addCurve(self.xData, self.yData,
+ legend="curve 1",
+ replace=False, resetzoom=False,
+ color=color,
+ linestyle="", symbol="s")
+ self.plot.addCurve(self.xData2, self.yData2,
+ legend="curve 2",
+ replace=False, resetzoom=False,
+ color='green', linestyle="-", symbol='o')
+ self.plot.resetZoom()
+
+ def testPlotCurveColorByte(self):
+ color = numpy.array(255 * numpy.random.random(3 * 1000),
+ dtype=numpy.uint8).reshape(1000, 3)
+
+ self.plot.addCurve(self.xData, self.yData,
+ legend="curve 1",
+ replace=False, resetzoom=False,
+ color=color,
+ linestyle="", symbol="s")
+ self.plot.addCurve(self.xData2, self.yData2,
+ legend="curve 2",
+ replace=False, resetzoom=False,
+ color='green', linestyle="-", symbol='o')
+ self.plot.resetZoom()
+
+ def testPlotCurveColors(self):
+ color = numpy.array(numpy.random.random(3 * 1000),
+ dtype=numpy.float32).reshape(1000, 3)
+
+ self.plot.addCurve(self.xData, self.yData,
+ legend="curve 2",
+ replace=False, resetzoom=False,
+ color=color, linestyle="-", symbol='o')
+ self.plot.resetZoom()
+
+ # Test updating color array
+
+ # From array to array
+ newColors = numpy.ones((len(self.xData), 3), dtype=numpy.float32)
+ self.plot.addCurve(self.xData, self.yData,
+ legend="curve 2",
+ replace=False, resetzoom=False,
+ color=newColors, symbol='o')
+
+ # Array to single color
+ self.plot.addCurve(self.xData, self.yData,
+ legend="curve 2",
+ replace=False, resetzoom=False,
+ color='green', symbol='o')
+
+ # single color to array
+ self.plot.addCurve(self.xData, self.yData,
+ legend="curve 2",
+ replace=False, resetzoom=False,
+ color=color, symbol='o')
+
+ def testPlotBaselineNumpyArray(self):
+ """simple test of the API with baseline as a numpy array"""
+ x = numpy.arange(0, 10, step=0.1)
+ my_sin = numpy.sin(x)
+ y = numpy.arange(-4, 6, step=0.1) + my_sin
+ baseline = y - 1.0
+
+ self.plot.addCurve(x=x, y=y, color='grey', legend='curve1', fill=True,
+ baseline=baseline)
+
+ def testPlotBaselineScalar(self):
+ """simple test of the API with baseline as an int"""
+ x = numpy.arange(0, 10, step=0.1)
+ my_sin = numpy.sin(x)
+ y = numpy.arange(-4, 6, step=0.1) + my_sin
+
+ self.plot.addCurve(x=x, y=y, color='grey', legend='curve1', fill=True,
+ baseline=0)
+
+ def testPlotBaselineList(self):
+ """simple test of the API with baseline as an int"""
+ x = numpy.arange(0, 10, step=0.1)
+ my_sin = numpy.sin(x)
+ y = numpy.arange(-4, 6, step=0.1) + my_sin
+
+ self.plot.addCurve(x=x, y=y, color='grey', legend='curve1', fill=True,
+ baseline=list(range(0, 100, 1)))
+
+ def testPlotCurveComplexData(self):
+ """Test curve with complex data"""
+ data = numpy.arange(100.) + 1j
+ self.plot.addCurve(x=data, y=data, xerror=data, yerror=data)
+
+
+class TestPlotHistogram(PlotWidgetTestCase):
+ """Basic tests for add Histogram"""
+ def setUp(self):
+ super(TestPlotHistogram, self).setUp()
+ self.edges = numpy.arange(0, 10, step=1)
+ self.histogram = numpy.random.random(len(self.edges))
+
+ def testPlot(self):
+ self.plot.addHistogram(histogram=self.histogram,
+ edges=self.edges,
+ legend='histogram1')
+
+ def testPlotBaseline(self):
+ self.plot.addHistogram(histogram=self.histogram,
+ edges=self.edges,
+ legend='histogram1',
+ color='blue',
+ baseline=-2,
+ z=2,
+ fill=True)
+
+
+class TestPlotScatter(PlotWidgetTestCase, ParametricTestCase):
+ """Basic tests for addScatter"""
+
+ def testScatter(self):
+ x = numpy.arange(100)
+ y = numpy.arange(100)
+ value = numpy.arange(100)
+ self.plot.addScatter(x, y, value)
+ self.plot.resetZoom()
+
+ def testScatterComplexData(self):
+ """Test scatter item with complex data"""
+ data = numpy.arange(100.) + 1j
+ self.plot.addScatter(
+ x=data, y=data, value=data, xerror=data, yerror=data)
+ self.plot.resetZoom()
+
+ def testScatterVisualization(self):
+ self.plot.addScatter((0, 1, 0, 1), (0, 0, 2, 2), (0, 1, 2, 3))
+ self.plot.resetZoom()
+ self.qapp.processEvents()
+
+ scatter = self.plot.getItems()[0]
+
+ for visualization in ('solid',
+ 'points',
+ 'regular_grid',
+ 'irregular_grid',
+ 'binned_statistic',
+ scatter.Visualization.SOLID,
+ scatter.Visualization.POINTS,
+ scatter.Visualization.REGULAR_GRID,
+ scatter.Visualization.IRREGULAR_GRID,
+ scatter.Visualization.BINNED_STATISTIC):
+ with self.subTest(visualization=visualization):
+ scatter.setVisualization(visualization)
+ self.qapp.processEvents()
+
+ def testGridVisualization(self):
+ """Test regular and irregular grid mode with different points"""
+ points = { # name: (x, y, order)
+ 'single point': ((1.,), (1.,), 'row'),
+ 'horizontal line': ((0, 1, 2), (0, 0, 0), 'row'),
+ 'horizontal line backward': ((2, 1, 0), (0, 0, 0), 'row'),
+ 'vertical line': ((0, 0, 0), (0, 1, 2), 'row'),
+ 'vertical line backward': ((0, 0, 0), (2, 1, 0), 'row'),
+ 'grid fast x, +x +y': ((0, 1, 2, 0, 1, 2), (0, 0, 0, 1, 1, 1), 'row'),
+ 'grid fast x, +x -y': ((0, 1, 2, 0, 1, 2), (1, 1, 1, 0, 0, 0), 'row'),
+ 'grid fast x, -x -y': ((2, 1, 0, 2, 1, 0), (1, 1, 1, 0, 0, 0), 'row'),
+ 'grid fast x, -x +y': ((2, 1, 0, 2, 1, 0), (0, 0, 0, 1, 1, 1), 'row'),
+ 'grid fast y, +x +y': ((0, 0, 0, 1, 1, 1), (0, 1, 2, 0, 1, 2), 'column'),
+ 'grid fast y, +x -y': ((0, 0, 0, 1, 1, 1), (2, 1, 0, 2, 1, 0), 'column'),
+ 'grid fast y, -x -y': ((1, 1, 1, 0, 0, 0), (2, 1, 0, 2, 1, 0), 'column'),
+ 'grid fast y, -x +y': ((1, 1, 1, 0, 0, 0), (0, 1, 2, 0, 1, 2), 'column'),
+ }
+
+ self.plot.addScatter((), (), ())
+ scatter = self.plot.getItems()[0]
+
+ self.qapp.processEvents()
+
+ for visualization in (scatter.Visualization.REGULAR_GRID,
+ scatter.Visualization.IRREGULAR_GRID):
+ scatter.setVisualization(visualization)
+ self.assertIs(scatter.getVisualization(), visualization)
+
+ for name, (x, y, ref_order) in points.items():
+ with self.subTest(name=name, visualization=visualization.name):
+ scatter.setData(x, y, numpy.arange(len(x)))
+ self.plot.setGraphTitle(name)
+ self.plot.resetZoom()
+ self.qapp.processEvents()
+
+ order = scatter.getCurrentVisualizationParameter(
+ scatter.VisualizationParameter.GRID_MAJOR_ORDER)
+ self.assertEqual(ref_order, order)
+
+ ref_bounds = (x[0], y[0]), (x[-1], y[-1])
+ bounds = scatter.getCurrentVisualizationParameter(
+ scatter.VisualizationParameter.GRID_BOUNDS)
+ self.assertEqual(ref_bounds, bounds)
+
+ shape = scatter.getCurrentVisualizationParameter(
+ scatter.VisualizationParameter.GRID_SHAPE)
+
+ self.plot.getXAxis().setLimits(numpy.min(x) - 1, numpy.max(x) + 1)
+ self.plot.getYAxis().setLimits(numpy.min(y) - 1, numpy.max(y) + 1)
+ self.qapp.processEvents()
+
+ for index, position in enumerate(zip(x, y)):
+ xpixel, ypixel = self.plot.dataToPixel(*position)
+ result = scatter.pick(xpixel, ypixel)
+ self.assertIsNotNone(result)
+ self.assertIs(result.getItem(), scatter)
+ self.assertEqual(result.getIndices(), (index,))
+
+ def testBinnedStatisticVisualization(self):
+ """Test binned display"""
+ self.plot.addScatter((), (), ())
+ scatter = self.plot.getItems()[0]
+ scatter.setVisualization(scatter.Visualization.BINNED_STATISTIC)
+ self.assertIs(scatter.getVisualization(),
+ scatter.Visualization.BINNED_STATISTIC)
+ self.assertEqual(
+ scatter.getVisualizationParameter(
+ scatter.VisualizationParameter.BINNED_STATISTIC_FUNCTION),
+ 'mean')
+
+ self.qapp.processEvents()
+
+ scatter.setData(*numpy.random.random(300).reshape(3, -1))
+ self.qapp.processEvents()
+
+ # Update data
+ scatter.setData(*numpy.random.random(3000).reshape(3, -1))
+ self.qapp.processEvents()
+
+ for reduction in ('count', 'sum', 'mean'):
+ with self.subTest(reduction=reduction):
+ scatter.setVisualizationParameter(
+ scatter.VisualizationParameter.BINNED_STATISTIC_FUNCTION,
+ reduction)
+ self.assertEqual(
+ scatter.getVisualizationParameter(
+ scatter.VisualizationParameter.BINNED_STATISTIC_FUNCTION),
+ reduction)
+
+ self.qapp.processEvents()
+
+
+class TestPlotMarker(PlotWidgetTestCase):
+ """Basic tests for add*Marker"""
+
+ def setUp(self):
+ super(TestPlotMarker, self).setUp()
+ self.plot.getYAxis().setLabel('Rows')
+ self.plot.getXAxis().setLabel('Columns')
+
+ self.plot.getXAxis().setAutoScale(False)
+ self.plot.getYAxis().setAutoScale(False)
+ self.plot.setKeepDataAspectRatio(False)
+ self.plot.setLimits(0., 100., -100., 100.)
+
+ def testPlotMarkerX(self):
+ self.plot.setGraphTitle('Markers X')
+
+ markers = [
+ (10., 'blue', False, False),
+ (20., 'red', False, False),
+ (40., 'green', True, False),
+ (60., 'gray', True, True),
+ (80., 'black', False, True),
+ ]
+
+ for x, color, select, drag in markers:
+ name = str(x)
+ if select:
+ name += " sel."
+ if drag:
+ name += " drag"
+ self.plot.addXMarker(x, name, name, color, select, drag)
+ self.plot.resetZoom()
+
+ def testPlotMarkerY(self):
+ self.plot.setGraphTitle('Markers Y')
+
+ markers = [
+ (-50., 'blue', False, False),
+ (-30., 'red', False, False),
+ (0., 'green', True, False),
+ (10., 'gray', True, True),
+ (80., 'black', False, True),
+ ]
+
+ for y, color, select, drag in markers:
+ name = str(y)
+ if select:
+ name += " sel."
+ if drag:
+ name += " drag"
+ self.plot.addYMarker(y, name, name, color, select, drag)
+ self.plot.resetZoom()
+
+ def testPlotMarkerPt(self):
+ self.plot.setGraphTitle('Markers Pt')
+
+ markers = [
+ (10., -50., 'blue', False, False),
+ (40., -30., 'red', False, False),
+ (50., 0., 'green', True, False),
+ (50., 20., 'gray', True, True),
+ (70., 50., 'black', False, True),
+ ]
+ for x, y, color, select, drag in markers:
+ name = "{0},{1}".format(x, y)
+ if select:
+ name += " sel."
+ if drag:
+ name += " drag"
+ self.plot.addMarker(x, y, name, name, color, select, drag)
+
+ self.plot.resetZoom()
+
+ def testPlotMarkerWithoutLegend(self):
+ self.plot.setGraphTitle('Markers without legend')
+ self.plot.getYAxis().setInverted(True)
+
+ # Markers without legend
+ self.plot.addMarker(10, 10)
+ self.plot.addMarker(10, 20)
+ self.plot.addMarker(40, 50, text='test', symbol=None)
+ self.plot.addMarker(40, 50, text='test', symbol='+')
+ self.plot.addXMarker(25)
+ self.plot.addXMarker(35)
+ self.plot.addXMarker(45, text='test')
+ self.plot.addYMarker(55)
+ self.plot.addYMarker(65)
+ self.plot.addYMarker(75, text='test')
+
+ self.plot.resetZoom()
+
+ def testPlotMarkerYAxis(self):
+ # Check only the API
+
+ legend = self.plot.addMarker(10, 10)
+ item = self.plot._getMarker(legend)
+ self.assertEqual(item.getYAxis(), "left")
+
+ legend = self.plot.addMarker(10, 10, yaxis="right")
+ item = self.plot._getMarker(legend)
+ self.assertEqual(item.getYAxis(), "right")
+
+ legend = self.plot.addMarker(10, 10, yaxis="left")
+ item = self.plot._getMarker(legend)
+ self.assertEqual(item.getYAxis(), "left")
+
+ legend = self.plot.addXMarker(10, yaxis="right")
+ item = self.plot._getMarker(legend)
+ self.assertEqual(item.getYAxis(), "right")
+
+ legend = self.plot.addXMarker(10, yaxis="left")
+ item = self.plot._getMarker(legend)
+ self.assertEqual(item.getYAxis(), "left")
+
+ legend = self.plot.addYMarker(10, yaxis="right")
+ item = self.plot._getMarker(legend)
+ self.assertEqual(item.getYAxis(), "right")
+
+ legend = self.plot.addYMarker(10, yaxis="left")
+ item = self.plot._getMarker(legend)
+ self.assertEqual(item.getYAxis(), "left")
+
+ self.plot.resetZoom()
+
+
+# TestPlotItem ################################################################
+
+class TestPlotItem(PlotWidgetTestCase):
+ """Basic tests for addItem."""
+
+ # Polygon coordinates and color
+ POLYGONS = [ # legend, x coords, y coords, color
+ ('triangle', numpy.array((10, 30, 50)),
+ numpy.array((55, 70, 55)), 'red'),
+ ('square', numpy.array((10, 10, 50, 50)),
+ numpy.array((10, 50, 50, 10)), 'green'),
+ ('star', numpy.array((60, 70, 80, 60, 80)),
+ numpy.array((25, 50, 25, 40, 40)), 'blue'),
+ ('2 triangles-simple',
+ numpy.array((90., 95., 100., numpy.nan, 90., 95., 100.)),
+ numpy.array((25., 5., 25., numpy.nan, 30., 50., 30.)),
+ 'pink'),
+ ('2 triangles-extra NaN',
+ numpy.array((numpy.nan, 90., 95., 100., numpy.nan, 0., 90., 95., 100., numpy.nan)),
+ numpy.array((0., 55., 70., 55., numpy.nan, numpy.nan, 75., 90., 75., numpy.nan)),
+ 'black'),
+ ]
+
+ # Rectangle coordinantes and color
+ RECTANGLES = [ # legend, x coords, y coords, color
+ ('square 1', numpy.array((1., 10.)),
+ numpy.array((1., 10.)), 'red'),
+ ('square 2', numpy.array((10., 20.)),
+ numpy.array((10., 20.)), 'green'),
+ ('square 3', numpy.array((20., 30.)),
+ numpy.array((20., 30.)), 'blue'),
+ ('rect 1', numpy.array((1., 30.)),
+ numpy.array((35., 40.)), 'black'),
+ ('line h', numpy.array((1., 30.)),
+ numpy.array((45., 45.)), 'darkRed'),
+ ]
+
+ SCALES = Axis.LINEAR, Axis.LOGARITHMIC
+
+ def setUp(self):
+ super(TestPlotItem, self).setUp()
+
+ self.plot.getYAxis().setLabel('Rows')
+ self.plot.getXAxis().setLabel('Columns')
+ self.plot.getXAxis().setAutoScale(False)
+ self.plot.getYAxis().setAutoScale(False)
+ self.plot.setKeepDataAspectRatio(False)
+ self.plot.setLimits(0., 100., -100., 100.)
+
+ def testPlotItemPolygonFill(self):
+ for scale in self.SCALES:
+ with self.subTest(scale=scale):
+ self.plot.clear()
+ self.plot.getXAxis().setScale(scale)
+ self.plot.getYAxis().setScale(scale)
+ self.plot.setGraphTitle('Item Fill %s' % scale)
+
+ for legend, xList, yList, color in self.POLYGONS:
+ self.plot.addShape(xList, yList, legend=legend,
+ replace=False, linestyle='--',
+ shape="polygon", fill=True, color=color)
+ self.plot.resetZoom()
+
+ def testPlotItemPolygonNoFill(self):
+ for scale in self.SCALES:
+ with self.subTest(scale=scale):
+ self.plot.clear()
+ self.plot.getXAxis().setScale(scale)
+ self.plot.getYAxis().setScale(scale)
+ self.plot.setGraphTitle('Item No Fill %s' % scale)
+
+ for legend, xList, yList, color in self.POLYGONS:
+ self.plot.addShape(xList, yList, legend=legend,
+ replace=False, linestyle='--',
+ shape="polygon", fill=False, color=color)
+ self.plot.resetZoom()
+
+ def testPlotItemRectangleFill(self):
+ for scale in self.SCALES:
+ with self.subTest(scale=scale):
+ self.plot.clear()
+ self.plot.getXAxis().setScale(scale)
+ self.plot.getYAxis().setScale(scale)
+ self.plot.setGraphTitle('Rectangle Fill %s' % scale)
+
+ for legend, xList, yList, color in self.RECTANGLES:
+ self.plot.addShape(xList, yList, legend=legend,
+ replace=False,
+ shape="rectangle", fill=True, color=color)
+ self.plot.resetZoom()
+
+ def testPlotItemRectangleNoFill(self):
+ for scale in self.SCALES:
+ with self.subTest(scale=scale):
+ self.plot.clear()
+ self.plot.getXAxis().setScale(scale)
+ self.plot.getYAxis().setScale(scale)
+ self.plot.setGraphTitle('Rectangle No Fill %s' % scale)
+
+ for legend, xList, yList, color in self.RECTANGLES:
+ self.plot.addShape(xList, yList, legend=legend,
+ replace=False,
+ shape="rectangle", fill=False, color=color)
+ self.plot.resetZoom()
+
+
+class TestPlotActiveCurveImage(PlotWidgetTestCase):
+ """Basic tests for active curve and image handling"""
+ xData = numpy.arange(1000)
+ yData = -500 + 100 * numpy.sin(xData)
+ xData2 = xData + 1000
+ yData2 = xData - 1000 + 200 * numpy.random.random(1000)
+
+ def tearDown(self):
+ self.plot.setActiveCurveHandling(False)
+ super(TestPlotActiveCurveImage, self).tearDown()
+
+ def testActiveCurveAndLabels(self):
+ # Active curve handling off, no label change
+ self.plot.setActiveCurveHandling(False)
+ self.plot.getXAxis().setLabel('XLabel')
+ self.plot.getYAxis().setLabel('YLabel')
+ self.plot.addCurve((1, 2), (1, 2))
+ self.assertEqual(self.plot.getXAxis().getLabel(), 'XLabel')
+ self.assertEqual(self.plot.getYAxis().getLabel(), 'YLabel')
+
+ self.plot.addCurve((1, 2), (2, 3), xlabel='x1', ylabel='y1')
+ self.assertEqual(self.plot.getXAxis().getLabel(), 'XLabel')
+ self.assertEqual(self.plot.getYAxis().getLabel(), 'YLabel')
+
+ self.plot.clear()
+ self.assertEqual(self.plot.getXAxis().getLabel(), 'XLabel')
+ self.assertEqual(self.plot.getYAxis().getLabel(), 'YLabel')
+
+ # Active curve handling on, label changes
+ self.plot.setActiveCurveHandling(True)
+ self.plot.getXAxis().setLabel('XLabel')
+ self.plot.getYAxis().setLabel('YLabel')
+
+ # labels changed as active curve
+ self.plot.addCurve((1, 2), (1, 2), legend='1',
+ xlabel='x1', ylabel='y1')
+ self.plot.setActiveCurve('1')
+ self.assertEqual(self.plot.getXAxis().getLabel(), 'x1')
+ self.assertEqual(self.plot.getYAxis().getLabel(), 'y1')
+
+ # labels not changed as not active curve
+ self.plot.addCurve((1, 2), (2, 3), legend='2')
+ self.assertEqual(self.plot.getXAxis().getLabel(), 'x1')
+ self.assertEqual(self.plot.getYAxis().getLabel(), 'y1')
+
+ # labels changed
+ self.plot.setActiveCurve('2')
+ self.assertEqual(self.plot.getXAxis().getLabel(), 'XLabel')
+ self.assertEqual(self.plot.getYAxis().getLabel(), 'YLabel')
+
+ self.plot.setActiveCurve('1')
+ self.assertEqual(self.plot.getXAxis().getLabel(), 'x1')
+ self.assertEqual(self.plot.getYAxis().getLabel(), 'y1')
+
+ self.plot.clear()
+ self.assertEqual(self.plot.getXAxis().getLabel(), 'XLabel')
+ self.assertEqual(self.plot.getYAxis().getLabel(), 'YLabel')
+
+ def testPlotActiveCurveSelectionMode(self):
+ self.plot.clear()
+ self.plot.setActiveCurveHandling(True)
+ legend = "curve 1"
+ self.plot.addCurve(self.xData, self.yData,
+ legend=legend,
+ color="green")
+
+ # active curve should be None
+ self.assertEqual(self.plot.getActiveCurve(just_legend=True), None)
+
+ # active curve should be None when None is set as active curve
+ self.plot.setActiveCurve(legend)
+ current = self.plot.getActiveCurve(just_legend=True)
+ self.assertEqual(current, legend)
+ self.plot.setActiveCurve(None)
+ current = self.plot.getActiveCurve(just_legend=True)
+ self.assertEqual(current, None)
+
+ # testing it automatically toggles if there is only one
+ self.plot.setActiveCurveSelectionMode("legacy")
+ current = self.plot.getActiveCurve(just_legend=True)
+ self.assertEqual(current, legend)
+
+ # active curve should not change when None set as active curve
+ self.assertEqual(self.plot.getActiveCurveSelectionMode(), "legacy")
+ self.plot.setActiveCurve(None)
+ current = self.plot.getActiveCurve(just_legend=True)
+ self.assertEqual(current, legend)
+
+ # situation where no curve is active
+ self.plot.clear()
+ self.plot.setActiveCurveHandling(True)
+ self.assertEqual(self.plot.getActiveCurveSelectionMode(), "atmostone")
+ self.plot.addCurve(self.xData, self.yData,
+ legend=legend,
+ color="green")
+ self.assertEqual(self.plot.getActiveCurve(just_legend=True), None)
+ self.plot.addCurve(self.xData2, self.yData2,
+ legend="curve 2",
+ color="red")
+ self.assertEqual(self.plot.getActiveCurve(just_legend=True), None)
+ self.plot.setActiveCurveSelectionMode("legacy")
+ self.assertEqual(self.plot.getActiveCurve(just_legend=True), None)
+
+ # the first curve added should be active
+ self.plot.clear()
+ self.plot.addCurve(self.xData, self.yData,
+ legend=legend,
+ color="green")
+ self.assertEqual(self.plot.getActiveCurve(just_legend=True), legend)
+ self.plot.addCurve(self.xData2, self.yData2,
+ legend="curve 2",
+ color="red")
+ self.assertEqual(self.plot.getActiveCurve(just_legend=True), legend)
+
+ def testActiveCurveStyle(self):
+ """Test change of active curve style"""
+ self.plot.setActiveCurveHandling(True)
+ self.plot.setActiveCurveStyle(color='black')
+ style = self.plot.getActiveCurveStyle()
+ self.assertEqual(style.getColor(), (0., 0., 0., 1.))
+ self.assertIsNone(style.getLineStyle())
+ self.assertIsNone(style.getLineWidth())
+ self.assertIsNone(style.getSymbol())
+ self.assertIsNone(style.getSymbolSize())
+
+ self.plot.addCurve(x=self.xData, y=self.yData, legend="curve1")
+ curve = self.plot.getCurve("curve1")
+ curve.setColor('blue')
+ curve.setLineStyle('-')
+ curve.setLineWidth(1)
+ curve.setSymbol('o')
+ curve.setSymbolSize(5)
+
+ # Check default current style
+ defaultStyle = curve.getCurrentStyle()
+ self.assertEqual(defaultStyle, CurveStyle(color='blue',
+ linestyle='-',
+ linewidth=1,
+ symbol='o',
+ symbolsize=5))
+
+ # Activate curve with highlight color=black
+ self.plot.setActiveCurve("curve1")
+ style = curve.getCurrentStyle()
+ self.assertEqual(style.getColor(), (0., 0., 0., 1.))
+ self.assertEqual(style.getLineStyle(), '-')
+ self.assertEqual(style.getLineWidth(), 1)
+ self.assertEqual(style.getSymbol(), 'o')
+ self.assertEqual(style.getSymbolSize(), 5)
+
+ # Change highlight to linewidth=2
+ self.plot.setActiveCurveStyle(linewidth=2)
+ style = curve.getCurrentStyle()
+ self.assertEqual(style.getColor(), (0., 0., 1., 1.))
+ self.assertEqual(style.getLineStyle(), '-')
+ self.assertEqual(style.getLineWidth(), 2)
+ self.assertEqual(style.getSymbol(), 'o')
+ self.assertEqual(style.getSymbolSize(), 5)
+
+ self.plot.setActiveCurve(None)
+ self.assertEqual(curve.getCurrentStyle(), defaultStyle)
+
+ def testActiveImageAndLabels(self):
+ # Active image handling always on, no API for toggling it
+ self.plot.getXAxis().setLabel('XLabel')
+ self.plot.getYAxis().setLabel('YLabel')
+
+ # labels changed as active curve
+ self.plot.addImage(numpy.arange(100).reshape(10, 10),
+ legend='1', xlabel='x1', ylabel='y1')
+ self.assertEqual(self.plot.getXAxis().getLabel(), 'x1')
+ self.assertEqual(self.plot.getYAxis().getLabel(), 'y1')
+
+ # labels not changed as not active curve
+ self.plot.addImage(numpy.arange(100).reshape(10, 10),
+ legend='2')
+ self.assertEqual(self.plot.getXAxis().getLabel(), 'x1')
+ self.assertEqual(self.plot.getYAxis().getLabel(), 'y1')
+
+ # labels changed
+ self.plot.setActiveImage('2')
+ self.assertEqual(self.plot.getXAxis().getLabel(), 'XLabel')
+ self.assertEqual(self.plot.getYAxis().getLabel(), 'YLabel')
+
+ self.plot.setActiveImage('1')
+ self.assertEqual(self.plot.getXAxis().getLabel(), 'x1')
+ self.assertEqual(self.plot.getYAxis().getLabel(), 'y1')
+
+ self.plot.clear()
+ self.assertEqual(self.plot.getXAxis().getLabel(), 'XLabel')
+ self.assertEqual(self.plot.getYAxis().getLabel(), 'YLabel')
+
+
+##############################################################################
+# Log
+##############################################################################
+
+class TestPlotEmptyLog(PlotWidgetTestCase):
+ """Basic tests for log plot"""
+ def testEmptyPlotTitleLabelsLog(self):
+ self.plot.setGraphTitle('Empty Log Log')
+ self.plot.getXAxis().setLabel('X')
+ self.plot.getYAxis().setLabel('Y')
+ self.plot.getXAxis()._setLogarithmic(True)
+ self.plot.getYAxis()._setLogarithmic(True)
+ self.plot.resetZoom()
+
+
+class TestPlotAxes(TestCaseQt, ParametricTestCase):
+
+ # Test data
+ xData = numpy.arange(1, 10)
+ yData = xData ** 2
+
+ def __init__(self, methodName='runTest', backend=None):
+ unittest.TestCase.__init__(self, methodName)
+ self.__backend = backend
+
+ def setUp(self):
+ super(TestPlotAxes, self).setUp()
+ self.plot = PlotWidget(backend=self.__backend)
+ # It is not needed to display the plot
+ # It saves a lot of time
+ # self.plot.show()
+ # self.qWaitForWindowExposed(self.plot)
+
+ def tearDown(self):
+ self.qapp.processEvents()
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ del self.plot
+ super(TestPlotAxes, self).tearDown()
+
+ def testDefaultAxes(self):
+ axis = self.plot.getXAxis()
+ self.assertEqual(axis.getScale(), axis.LINEAR)
+ axis = self.plot.getYAxis()
+ self.assertEqual(axis.getScale(), axis.LINEAR)
+ axis = self.plot.getYAxis(axis="right")
+ self.assertEqual(axis.getScale(), axis.LINEAR)
+
+ def testOldPlotAxis_getterSetter(self):
+ """Test silx API prior to silx 0.6"""
+ x = self.plot.getXAxis()
+ y = self.plot.getYAxis()
+ p = self.plot
+
+ tests = [
+ # setters
+ (p.setGraphXLimits, (10, 20), x.getLimits, (10, 20)),
+ (p.setGraphYLimits, (10, 20), y.getLimits, (10, 20)),
+ (p.setGraphXLabel, "foox", x.getLabel, "foox"),
+ (p.setGraphYLabel, "fooy", y.getLabel, "fooy"),
+ (p.setYAxisInverted, True, y.isInverted, True),
+ (p.setXAxisLogarithmic, True, x.getScale, x.LOGARITHMIC),
+ (p.setYAxisLogarithmic, True, y.getScale, y.LOGARITHMIC),
+ (p.setXAxisAutoScale, False, x.isAutoScale, False),
+ (p.setYAxisAutoScale, False, y.isAutoScale, False),
+ # getters
+ (x.setLimits, (11, 20), p.getGraphXLimits, (11, 20)),
+ (y.setLimits, (11, 20), p.getGraphYLimits, (11, 20)),
+ (x.setLabel, "fooxx", p.getGraphXLabel, "fooxx"),
+ (y.setLabel, "fooyy", p.getGraphYLabel, "fooyy"),
+ (y.setInverted, False, p.isYAxisInverted, False),
+ (x.setScale, x.LINEAR, p.isXAxisLogarithmic, False),
+ (y.setScale, y.LINEAR, p.isYAxisLogarithmic, False),
+ (x.setAutoScale, True, p.isXAxisAutoScale, True),
+ (y.setAutoScale, True, p.isYAxisAutoScale, True),
+ ]
+ for testCase in tests:
+ setter, value, getter, expected = testCase
+ with self.subTest():
+ if setter is not None:
+ if not isinstance(value, tuple):
+ value = (value, )
+ setter(*value)
+ if getter is not None:
+ self.assertEqual(getter(), expected)
+
+ def testOldPlotAxis_Logarithmic(self):
+ """Test silx API prior to silx 0.6"""
+ x = self.plot.getXAxis()
+ y = self.plot.getYAxis()
+ yright = self.plot.getYAxis(axis="right")
+
+ self.assertEqual(x.getScale(), x.LINEAR)
+ self.assertEqual(y.getScale(), x.LINEAR)
+ self.assertEqual(yright.getScale(), x.LINEAR)
+
+ self.plot.setXAxisLogarithmic(True)
+ self.assertEqual(x.getScale(), x.LOGARITHMIC)
+ self.assertEqual(y.getScale(), x.LINEAR)
+ self.assertEqual(yright.getScale(), x.LINEAR)
+ self.assertEqual(self.plot.isXAxisLogarithmic(), True)
+ self.assertEqual(self.plot.isYAxisLogarithmic(), False)
+
+ self.plot.setYAxisLogarithmic(True)
+ self.assertEqual(x.getScale(), x.LOGARITHMIC)
+ self.assertEqual(y.getScale(), x.LOGARITHMIC)
+ self.assertEqual(yright.getScale(), x.LOGARITHMIC)
+ self.assertEqual(self.plot.isXAxisLogarithmic(), True)
+ self.assertEqual(self.plot.isYAxisLogarithmic(), True)
+
+ yright.setScale(yright.LINEAR)
+ self.assertEqual(x.getScale(), x.LOGARITHMIC)
+ self.assertEqual(y.getScale(), x.LINEAR)
+ self.assertEqual(yright.getScale(), x.LINEAR)
+ self.assertEqual(self.plot.isXAxisLogarithmic(), True)
+ self.assertEqual(self.plot.isYAxisLogarithmic(), False)
+
+ def testOldPlotAxis_AutoScale(self):
+ """Test silx API prior to silx 0.6"""
+ x = self.plot.getXAxis()
+ y = self.plot.getYAxis()
+ yright = self.plot.getYAxis(axis="right")
+
+ self.assertEqual(x.isAutoScale(), True)
+ self.assertEqual(y.isAutoScale(), True)
+ self.assertEqual(yright.isAutoScale(), True)
+
+ self.plot.setXAxisAutoScale(False)
+ self.assertEqual(x.isAutoScale(), False)
+ self.assertEqual(y.isAutoScale(), True)
+ self.assertEqual(yright.isAutoScale(), True)
+ self.assertEqual(self.plot.isXAxisAutoScale(), False)
+ self.assertEqual(self.plot.isYAxisAutoScale(), True)
+
+ self.plot.setYAxisAutoScale(False)
+ self.assertEqual(x.isAutoScale(), False)
+ self.assertEqual(y.isAutoScale(), False)
+ self.assertEqual(yright.isAutoScale(), False)
+ self.assertEqual(self.plot.isXAxisAutoScale(), False)
+ self.assertEqual(self.plot.isYAxisAutoScale(), False)
+
+ yright.setAutoScale(True)
+ self.assertEqual(x.isAutoScale(), False)
+ self.assertEqual(y.isAutoScale(), True)
+ self.assertEqual(yright.isAutoScale(), True)
+ self.assertEqual(self.plot.isXAxisAutoScale(), False)
+ self.assertEqual(self.plot.isYAxisAutoScale(), True)
+
+ def testOldPlotAxis_Inverted(self):
+ """Test silx API prior to silx 0.6"""
+ x = self.plot.getXAxis()
+ y = self.plot.getYAxis()
+ yright = self.plot.getYAxis(axis="right")
+
+ self.assertEqual(x.isInverted(), False)
+ self.assertEqual(y.isInverted(), False)
+ self.assertEqual(yright.isInverted(), False)
+
+ self.plot.setYAxisInverted(True)
+ self.assertEqual(x.isInverted(), False)
+ self.assertEqual(y.isInverted(), True)
+ self.assertEqual(yright.isInverted(), True)
+ self.assertEqual(self.plot.isYAxisInverted(), True)
+
+ yright.setInverted(False)
+ self.assertEqual(x.isInverted(), False)
+ self.assertEqual(y.isInverted(), False)
+ self.assertEqual(yright.isInverted(), False)
+ self.assertEqual(self.plot.isYAxisInverted(), False)
+
+ def testLogXWithData(self):
+ self.plot.setGraphTitle('Curve X: Log Y: Linear')
+ self.plot.addCurve(self.xData, self.yData,
+ legend="curve",
+ replace=False, resetzoom=True,
+ color='green', linestyle="-", symbol='o')
+ axis = self.plot.getXAxis()
+ axis.setScale(axis.LOGARITHMIC)
+
+ self.assertEqual(axis.getScale(), axis.LOGARITHMIC)
+
+ def testLogYWithData(self):
+ self.plot.setGraphTitle('Curve X: Linear Y: Log')
+ self.plot.addCurve(self.xData, self.yData,
+ legend="curve",
+ replace=False, resetzoom=True,
+ color='green', linestyle="-", symbol='o')
+ axis = self.plot.getYAxis()
+ axis.setScale(axis.LOGARITHMIC)
+
+ self.assertEqual(axis.getScale(), axis.LOGARITHMIC)
+ axis = self.plot.getYAxis(axis="right")
+ self.assertEqual(axis.getScale(), axis.LOGARITHMIC)
+
+ def testLogYRightWithData(self):
+ self.plot.setGraphTitle('Curve X: Linear Y: Log')
+ self.plot.addCurve(self.xData, self.yData,
+ legend="curve",
+ replace=False, resetzoom=True,
+ color='green', linestyle="-", symbol='o')
+ axis = self.plot.getYAxis(axis="right")
+ axis.setScale(axis.LOGARITHMIC)
+
+ self.assertEqual(axis.getScale(), axis.LOGARITHMIC)
+ axis = self.plot.getYAxis()
+ self.assertEqual(axis.getScale(), axis.LOGARITHMIC)
+
+ def testLimitsChanged_setLimits(self):
+ self.plot.addCurve(self.xData, self.yData,
+ legend="curve",
+ replace=False, resetzoom=False,
+ color='green', linestyle="-", symbol='o')
+ listener = SignalListener()
+ self.plot.getXAxis().sigLimitsChanged.connect(listener.partial(axis="x"))
+ self.plot.getYAxis().sigLimitsChanged.connect(listener.partial(axis="y"))
+ self.plot.getYAxis(axis="right").sigLimitsChanged.connect(listener.partial(axis="y2"))
+ self.plot.setLimits(0, 1, 0, 1, 0, 1)
+ # at least one event per axis
+ self.assertEqual(len(set(listener.karguments(argumentName="axis"))), 3)
+
+ def testLimitsChanged_resetZoom(self):
+ self.plot.addCurve(self.xData, self.yData,
+ legend="curve",
+ replace=False, resetzoom=False,
+ color='green', linestyle="-", symbol='o')
+ listener = SignalListener()
+ self.plot.getXAxis().sigLimitsChanged.connect(listener.partial(axis="x"))
+ self.plot.getYAxis().sigLimitsChanged.connect(listener.partial(axis="y"))
+ self.plot.getYAxis(axis="right").sigLimitsChanged.connect(listener.partial(axis="y2"))
+ self.plot.resetZoom()
+ # at least one event per axis
+ self.assertEqual(len(set(listener.karguments(argumentName="axis"))), 3)
+
+ def testLimitsChanged_setXLimit(self):
+ self.plot.addCurve(self.xData, self.yData,
+ legend="curve",
+ replace=False, resetzoom=False,
+ color='green', linestyle="-", symbol='o')
+ listener = SignalListener()
+ axis = self.plot.getXAxis()
+ axis.sigLimitsChanged.connect(listener)
+ axis.setLimits(20, 30)
+ # at least one event per axis
+ self.assertEqual(listener.arguments(callIndex=-1), (20.0, 30.0))
+ self.assertEqual(axis.getLimits(), (20.0, 30.0))
+
+ def testLimitsChanged_setYLimit(self):
+ self.plot.addCurve(self.xData, self.yData,
+ legend="curve",
+ replace=False, resetzoom=False,
+ color='green', linestyle="-", symbol='o')
+ listener = SignalListener()
+ axis = self.plot.getYAxis()
+ axis.sigLimitsChanged.connect(listener)
+ axis.setLimits(20, 30)
+ # at least one event per axis
+ self.assertEqual(listener.arguments(callIndex=-1), (20.0, 30.0))
+ self.assertEqual(axis.getLimits(), (20.0, 30.0))
+
+ def testLimitsChanged_setYRightLimit(self):
+ self.plot.addCurve(self.xData, self.yData,
+ legend="curve",
+ replace=False, resetzoom=False,
+ color='green', linestyle="-", symbol='o')
+ listener = SignalListener()
+ axis = self.plot.getYAxis(axis="right")
+ axis.sigLimitsChanged.connect(listener)
+ axis.setLimits(20, 30)
+ # at least one event per axis
+ self.assertEqual(listener.arguments(callIndex=-1), (20.0, 30.0))
+ self.assertEqual(axis.getLimits(), (20.0, 30.0))
+
+ def testScaleProxy(self):
+ listener = SignalListener()
+ y = self.plot.getYAxis()
+ yright = self.plot.getYAxis(axis="right")
+ y.sigScaleChanged.connect(listener.partial("left"))
+ yright.sigScaleChanged.connect(listener.partial("right"))
+ yright.setScale(yright.LOGARITHMIC)
+
+ self.assertEqual(y.getScale(), y.LOGARITHMIC)
+ events = listener.arguments()
+ self.assertEqual(len(events), 2)
+ self.assertIn(("left", y.LOGARITHMIC), events)
+ self.assertIn(("right", y.LOGARITHMIC), events)
+
+ def testAutoScaleProxy(self):
+ listener = SignalListener()
+ y = self.plot.getYAxis()
+ yright = self.plot.getYAxis(axis="right")
+ y.sigAutoScaleChanged.connect(listener.partial("left"))
+ yright.sigAutoScaleChanged.connect(listener.partial("right"))
+ yright.setAutoScale(False)
+
+ self.assertEqual(y.isAutoScale(), False)
+ events = listener.arguments()
+ self.assertEqual(len(events), 2)
+ self.assertIn(("left", False), events)
+ self.assertIn(("right", False), events)
+
+ def testInvertedProxy(self):
+ listener = SignalListener()
+ y = self.plot.getYAxis()
+ yright = self.plot.getYAxis(axis="right")
+ y.sigInvertedChanged.connect(listener.partial("left"))
+ yright.sigInvertedChanged.connect(listener.partial("right"))
+ yright.setInverted(True)
+
+ self.assertEqual(y.isInverted(), True)
+ events = listener.arguments()
+ self.assertEqual(len(events), 2)
+ self.assertIn(("left", True), events)
+ self.assertIn(("right", True), events)
+
+ def testAxesDisplayedFalse(self):
+ """Test coverage on setAxesDisplayed(False)"""
+ self.plot.setAxesDisplayed(False)
+
+ def testAxesDisplayedTrue(self):
+ """Test coverage on setAxesDisplayed(True)"""
+ self.plot.setAxesDisplayed(True)
+
+ def testAxesMargins(self):
+ """Test PlotWidget's getAxesMargins and setAxesMargins"""
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ margins = self.plot.getAxesMargins()
+ self.assertEqual(margins, (.15, .1, .1, .15))
+
+ for margins in ((0., 0., 0., 0.), (.15, .1, .1, .15)):
+ with self.subTest(margins=margins):
+ self.plot.setAxesMargins(*margins)
+ self.qapp.processEvents()
+ self.assertEqual(self.plot.getAxesMargins(), margins)
+
+ def testBoundingRectItem(self):
+ item = BoundingRect()
+ item.setBounds((-1000, 1000, -2000, 2000))
+ self.plot.addItem(item)
+ self.plot.resetZoom()
+ limits = numpy.array(self.plot.getXAxis().getLimits())
+ numpy.testing.assert_almost_equal(limits, numpy.array([-1000, 1000]))
+ limits = numpy.array(self.plot.getYAxis().getLimits())
+ numpy.testing.assert_almost_equal(limits, numpy.array([-2000, 2000]))
+
+ def testBoundingRectRightItem(self):
+ item = BoundingRect()
+ item.setYAxis("right")
+ item.setBounds((-1000, 1000, -2000, 2000))
+ self.plot.addItem(item)
+ self.plot.resetZoom()
+ limits = numpy.array(self.plot.getXAxis().getLimits())
+ numpy.testing.assert_almost_equal(limits, numpy.array([-1000, 1000]))
+ limits = numpy.array(self.plot.getYAxis("right").getLimits())
+ numpy.testing.assert_almost_equal(limits, numpy.array([-2000, 2000]))
+
+ def testBoundingRectArguments(self):
+ item = BoundingRect()
+ with self.assertRaises(Exception):
+ item.setBounds((1000, -1000, -2000, 2000))
+ with self.assertRaises(Exception):
+ item.setBounds((-1000, 1000, 2000, -2000))
+
+ def testBoundingRectWithLog(self):
+ item = BoundingRect()
+ self.plot.addItem(item)
+
+ item.setBounds((-1000, 1000, -2000, 2000))
+ self.plot.getXAxis()._setLogarithmic(True)
+ self.plot.getYAxis()._setLogarithmic(False)
+ self.assertEqual(item.getBounds(), (1000, 1000, -2000, 2000))
+
+ item.setBounds((-1000, 1000, -2000, 2000))
+ self.plot.getXAxis()._setLogarithmic(False)
+ self.plot.getYAxis()._setLogarithmic(True)
+ self.assertEqual(item.getBounds(), (-1000, 1000, 2000, 2000))
+
+ item.setBounds((-1000, 0, -2000, 2000))
+ self.plot.getXAxis()._setLogarithmic(True)
+ self.plot.getYAxis()._setLogarithmic(False)
+ self.assertIsNone(item.getBounds())
+
+ def testAxisExtent(self):
+ """Test XAxisExtent and yAxisExtent"""
+ for cls, axis in ((XAxisExtent, self.plot.getXAxis()),
+ (YAxisExtent, self.plot.getYAxis())):
+ for range_, logRange in (((2, 3), (2, 3)),
+ ((-2, -1), (1, 100)),
+ ((-1, 3), (3. * 0.9, 3. * 1.1))):
+ extent = cls()
+ extent.setRange(*range_)
+ self.plot.addItem(extent)
+
+ for isLog, plotRange in ((False, range_), (True, logRange)):
+ with self.subTest(
+ cls=cls.__name__, range=range_, isLog=isLog):
+ axis._setLogarithmic(isLog)
+ self.plot.resetZoom()
+ self.qapp.processEvents()
+ self.assertEqual(axis.getLimits(), plotRange)
+
+ axis._setLogarithmic(False)
+ self.plot.clear()
+
+ def testAxisLimitOverflow(self):
+ """Test setting limis beyond supported range"""
+ xaxis, yaxis = self.plot.getXAxis(), self.plot.getYAxis()
+ for scale in ("linear", "log"):
+ xaxis.setScale(scale)
+ yaxis.setScale(scale)
+ for limits in ((1e300, 1e308),
+ (-1e308, 1e308),
+ (1e-300, 2e-300)):
+ with self.subTest(scale=scale, limits=limits):
+ xaxis.setLimits(*limits)
+ self.qapp.processEvents()
+ self.assertNotEqual(xaxis.getLimits(), limits)
+ yaxis.setLimits(*limits)
+ self.qapp.processEvents()
+ self.assertNotEqual(yaxis.getLimits(), limits)
+
+
+class TestPlotCurveLog(PlotWidgetTestCase, ParametricTestCase):
+ """Basic tests for addCurve with log scale axes"""
+
+ # Test data
+ xData = numpy.arange(1000) + 1
+ yData = xData ** 2
+
+ def _setLabels(self):
+ self.plot.getXAxis().setLabel('X')
+ self.plot.getYAxis().setLabel('X * X')
+
+ def testPlotCurveLogX(self):
+ self._setLabels()
+ self.plot.getXAxis()._setLogarithmic(True)
+ self.plot.setGraphTitle('Curve X: Log Y: Linear')
+
+ self.plot.addCurve(self.xData, self.yData,
+ legend="curve",
+ replace=False, resetzoom=True,
+ color='green', linestyle="-", symbol='o')
+
+ def testPlotCurveLogY(self):
+ self._setLabels()
+ self.plot.getYAxis()._setLogarithmic(True)
+
+ self.plot.setGraphTitle('Curve X: Linear Y: Log')
+
+ self.plot.addCurve(self.xData, self.yData,
+ legend="curve",
+ replace=False, resetzoom=True,
+ color='green', linestyle="-", symbol='o')
+
+ def testPlotCurveLogXY(self):
+ self._setLabels()
+ self.plot.getXAxis()._setLogarithmic(True)
+ self.plot.getYAxis()._setLogarithmic(True)
+
+ self.plot.setGraphTitle('Curve X: Log Y: Log')
+
+ self.plot.addCurve(self.xData, self.yData,
+ legend="curve",
+ replace=False, resetzoom=True,
+ color='green', linestyle="-", symbol='o')
+
+ def testPlotCurveErrorLogXY(self):
+ self.plot.getXAxis()._setLogarithmic(True)
+ self.plot.getYAxis()._setLogarithmic(True)
+
+ # Every second error leads to negative number
+ errors = numpy.ones_like(self.xData)
+ errors[::2] = self.xData[::2] + 1
+
+ tests = [ # name, xerror, yerror
+ ('xerror=3', 3, None),
+ ('xerror=N array', errors, None),
+ ('xerror=Nx1 array', errors.reshape(len(errors), 1), None),
+ ('xerror=2xN array', numpy.array((errors, errors)), None),
+ ('yerror=6', None, 6),
+ ('yerror=N array', None, errors ** 2),
+ ('yerror=Nx1 array', None, (errors ** 2).reshape(len(errors), 1)),
+ ('yerror=2xN array', None, numpy.array((errors, errors)) ** 2),
+ ]
+
+ for name, xError, yError in tests:
+ with self.subTest(name):
+ self.plot.setGraphTitle(name)
+ self.plot.addCurve(self.xData, self.yData,
+ legend=name,
+ xerror=xError, yerror=yError,
+ replace=False, resetzoom=True,
+ color='green', linestyle="-", symbol='o')
+
+ self.qapp.processEvents()
+
+ self.plot.clear()
+ self.plot.resetZoom()
+ self.qapp.processEvents()
+
+ def testPlotCurveToggleLog(self):
+ """Add a curve with negative data and toggle log axis"""
+ arange = numpy.arange(1000) + 1
+ tests = [ # name, xData, yData
+ ('x>0, some negative y', arange, arange - 500),
+ ('x>0, y<0', arange, -arange),
+ ('some negative x, y>0', arange - 500, arange),
+ ('x<0, y>0', -arange, arange),
+ ('some negative x and y', arange - 500, arange - 500),
+ ('x<0, y<0', -arange, -arange),
+ ]
+
+ for name, xData, yData in tests:
+ with self.subTest(name):
+ self.plot.addCurve(xData, yData, resetzoom=True)
+ self.qapp.processEvents()
+
+ # no log axis
+ xLim = self.plot.getXAxis().getLimits()
+ self.assertEqual(xLim, (min(xData), max(xData)))
+ yLim = self.plot.getYAxis().getLimits()
+ self.assertEqual(yLim, (min(yData), max(yData)))
+
+ # x axis log
+ self.plot.getXAxis()._setLogarithmic(True)
+ self.qapp.processEvents()
+
+ xLim = self.plot.getXAxis().getLimits()
+ yLim = self.plot.getYAxis().getLimits()
+ positives = xData > 0
+ if numpy.any(positives):
+ self.assertTrue(numpy.allclose(
+ xLim, (min(xData[positives]), max(xData[positives]))))
+ self.assertEqual(
+ yLim, (min(yData[positives]), max(yData[positives])))
+ else: # No positive x in the curve
+ self.assertEqual(xLim, (1., 100.))
+ self.assertEqual(yLim, (1., 100.))
+
+ # x axis and y axis log
+ self.plot.getYAxis()._setLogarithmic(True)
+ self.qapp.processEvents()
+
+ xLim = self.plot.getXAxis().getLimits()
+ yLim = self.plot.getYAxis().getLimits()
+ positives = numpy.logical_and(xData > 0, yData > 0)
+ if numpy.any(positives):
+ self.assertTrue(numpy.allclose(
+ xLim, (min(xData[positives]), max(xData[positives]))))
+ self.assertTrue(numpy.allclose(
+ yLim, (min(yData[positives]), max(yData[positives]))))
+ else: # No positive x and y in the curve
+ self.assertEqual(xLim, (1., 100.))
+ self.assertEqual(yLim, (1., 100.))
+
+ # y axis log
+ self.plot.getXAxis()._setLogarithmic(False)
+ self.qapp.processEvents()
+
+ xLim = self.plot.getXAxis().getLimits()
+ yLim = self.plot.getYAxis().getLimits()
+ positives = yData > 0
+ if numpy.any(positives):
+ self.assertEqual(
+ xLim, (min(xData[positives]), max(xData[positives])))
+ self.assertTrue(numpy.allclose(
+ yLim, (min(yData[positives]), max(yData[positives]))))
+ else: # No positive y in the curve
+ self.assertEqual(xLim, (1., 100.))
+ self.assertEqual(yLim, (1., 100.))
+
+ # no log axis
+ self.plot.getYAxis()._setLogarithmic(False)
+ self.qapp.processEvents()
+
+ xLim = self.plot.getXAxis().getLimits()
+ self.assertEqual(xLim, (min(xData), max(xData)))
+ yLim = self.plot.getYAxis().getLimits()
+ self.assertEqual(yLim, (min(yData), max(yData)))
+
+ self.plot.clear()
+ self.plot.resetZoom()
+ self.qapp.processEvents()
+
+
+class TestPlotImageLog(PlotWidgetTestCase):
+ """Basic tests for addImage with log scale axes."""
+
+ def setUp(self):
+ super(TestPlotImageLog, self).setUp()
+
+ self.plot.getXAxis().setLabel('Columns')
+ self.plot.getYAxis().setLabel('Rows')
+
+ def testPlotColormapGrayLogX(self):
+ self.plot.getXAxis()._setLogarithmic(True)
+ self.plot.setGraphTitle('CMap X: Log Y: Linear')
+
+ colormap = Colormap(name='gray',
+ normalization='linear',
+ vmin=None,
+ vmax=None)
+ self.plot.addImage(DATA_2D, legend="image 1",
+ origin=(1., 1.), scale=(1., 1.),
+ resetzoom=False, colormap=colormap)
+ self.plot.resetZoom()
+
+ def testPlotColormapGrayLogY(self):
+ self.plot.getYAxis()._setLogarithmic(True)
+ self.plot.setGraphTitle('CMap X: Linear Y: Log')
+
+ colormap = Colormap(name='gray',
+ normalization='linear',
+ vmin=None,
+ vmax=None)
+ self.plot.addImage(DATA_2D, legend="image 1",
+ origin=(1., 1.), scale=(1., 1.),
+ resetzoom=False, colormap=colormap)
+ self.plot.resetZoom()
+
+ def testPlotColormapGrayLogXY(self):
+ self.plot.getXAxis()._setLogarithmic(True)
+ self.plot.getYAxis()._setLogarithmic(True)
+ self.plot.setGraphTitle('CMap X: Log Y: Log')
+
+ colormap = Colormap(name='gray',
+ normalization='linear',
+ vmin=None,
+ vmax=None)
+ self.plot.addImage(DATA_2D, legend="image 1",
+ origin=(1., 1.), scale=(1., 1.),
+ resetzoom=False, colormap=colormap)
+ self.plot.resetZoom()
+
+ def testPlotRgbRgbaLogXY(self):
+ self.plot.getXAxis()._setLogarithmic(True)
+ self.plot.getYAxis()._setLogarithmic(True)
+ self.plot.setGraphTitle('RGB + RGBA X: Log Y: Log')
+
+ rgb = numpy.array(
+ (((0, 0, 0), (128, 0, 0), (255, 0, 0)),
+ ((0, 128, 0), (0, 128, 128), (0, 128, 256))),
+ dtype=numpy.uint8)
+
+ self.plot.addImage(rgb, legend="rgb",
+ origin=(1, 1), scale=(10, 10),
+ resetzoom=False)
+
+ rgba = numpy.array(
+ (((0, 0, 0, .5), (.5, 0, 0, 1), (1, 0, 0, .5)),
+ ((0, .5, 0, 1), (0, .5, .5, 1), (0, 1, 1, .5))),
+ dtype=numpy.float32)
+
+ self.plot.addImage(rgba, legend="rgba",
+ origin=(5., 5.), scale=(10., 10.),
+ resetzoom=False)
+ self.plot.resetZoom()
+
+
+class TestPlotMarkerLog(PlotWidgetTestCase):
+ """Basic tests for markers on log scales"""
+
+ # Test marker parameters
+ markers = [ # x, y, color, selectable, draggable
+ (10., 10., 'blue', False, False),
+ (20., 20., 'red', False, False),
+ (40., 100., 'green', True, False),
+ (40., 500., 'gray', True, True),
+ (60., 800., 'black', False, True),
+ ]
+
+ def setUp(self):
+ super(TestPlotMarkerLog, self).setUp()
+
+ self.plot.getYAxis().setLabel('Rows')
+ self.plot.getXAxis().setLabel('Columns')
+ self.plot.getXAxis().setAutoScale(False)
+ self.plot.getYAxis().setAutoScale(False)
+ self.plot.setKeepDataAspectRatio(False)
+ self.plot.setLimits(1., 100., 1., 1000.)
+ self.plot.getXAxis()._setLogarithmic(True)
+ self.plot.getYAxis()._setLogarithmic(True)
+
+ def testPlotMarkerXLog(self):
+ self.plot.setGraphTitle('Markers X, Log axes')
+
+ for x, _, color, select, drag in self.markers:
+ name = str(x)
+ if select:
+ name += " sel."
+ if drag:
+ name += " drag"
+ self.plot.addXMarker(x, name, name, color, select, drag)
+ self.plot.resetZoom()
+
+ def testPlotMarkerYLog(self):
+ self.plot.setGraphTitle('Markers Y, Log axes')
+
+ for _, y, color, select, drag in self.markers:
+ name = str(y)
+ if select:
+ name += " sel."
+ if drag:
+ name += " drag"
+ self.plot.addYMarker(y, name, name, color, select, drag)
+ self.plot.resetZoom()
+
+ def testPlotMarkerPtLog(self):
+ self.plot.setGraphTitle('Markers Pt, Log axes')
+
+ for x, y, color, select, drag in self.markers:
+ name = "{0},{1}".format(x, y)
+ if select:
+ name += " sel."
+ if drag:
+ name += " drag"
+ self.plot.addMarker(x, y, name, name, color, select, drag)
+ self.plot.resetZoom()
+
+
+@pytest.mark.usefixtures("test_options_class_attr")
+class TestPlotWidgetSwitchBackend(PlotWidgetTestCase):
+ """Test [get|set]Backend to switch backend"""
+
+ @pytest.mark.usefixtures("test_options")
+ def testSwitchBackend(self):
+ """Test switching a plot with a few items"""
+ backends = {'none': 'BackendBase', 'mpl': 'BackendMatplotlibQt'}
+ if self.test_options.WITH_GL_TEST:
+ backends['gl'] = 'BackendOpenGL'
+
+ self.plot.addImage(numpy.arange(100).reshape(10, 10))
+ self.plot.addCurve((-3, -2, -1), (1, 2, 3))
+ self.plot.resetZoom()
+ xlimits = self.plot.getXAxis().getLimits()
+ ylimits = self.plot.getYAxis().getLimits()
+ items = self.plot.getItems()
+ self.assertEqual(len(items), 2)
+
+ for backend, className in backends.items():
+ with self.subTest(backend=backend):
+ self.plot.setBackend(backend)
+ self.plot.replot()
+
+ retrievedBackend = self.plot.getBackend()
+ self.assertEqual(type(retrievedBackend).__name__, className)
+ self.assertEqual(self.plot.getXAxis().getLimits(), xlimits)
+ self.assertEqual(self.plot.getYAxis().getLimits(), ylimits)
+ self.assertEqual(self.plot.getItems(), items)
+
+
+class TestPlotWidgetSelection(PlotWidgetTestCase):
+ """Test PlotWidget.selection and active items handling"""
+
+ def _checkSelection(self, selection, current=None, selected=()):
+ """Check current item and selected items."""
+ self.assertIs(selection.getCurrentItem(), current)
+ self.assertEqual(selection.getSelectedItems(), selected)
+
+ def testSyncWithActiveItems(self):
+ """Test update of PlotWidgetSelection according to active items"""
+ listener = SignalListener()
+
+ selection = self.plot.selection()
+ selection.sigCurrentItemChanged.connect(listener)
+ self._checkSelection(selection)
+
+ # Active item is current
+ self.plot.addImage(((0, 1), (2, 3)), legend='image')
+ image = self.plot.getActiveImage()
+ self.assertEqual(listener.callCount(), 1)
+ self._checkSelection(selection, image, (image,))
+
+ # No active = no current
+ self.plot.setActiveImage(None)
+ self.assertEqual(listener.callCount(), 2)
+ self._checkSelection(selection)
+
+ # Active item is current
+ self.plot.setActiveImage('image')
+ self.assertEqual(listener.callCount(), 3)
+ self._checkSelection(selection, image, (image,))
+
+ # Mosted recently "actived" item is current
+ self.plot.addScatter((3, 2, 1), (0, 1, 2), (0, 1, 2), legend='scatter')
+ scatter = self.plot.getActiveScatter()
+ self.assertEqual(listener.callCount(), 4)
+ self._checkSelection(selection, scatter, (scatter, image))
+
+ # Previously mosted recently "actived" item is current
+ self.plot.setActiveScatter(None)
+ self.assertEqual(listener.callCount(), 5)
+ self._checkSelection(selection, image, (image,))
+
+ # Mosted recently "actived" item is current
+ self.plot.setActiveScatter('scatter')
+ self.assertEqual(listener.callCount(), 6)
+ self._checkSelection(selection, scatter, (scatter, image))
+
+ # No active = no current
+ self.plot.setActiveImage(None)
+ self.plot.setActiveScatter(None)
+ self.assertEqual(listener.callCount(), 7)
+ self._checkSelection(selection)
+
+ # Mosted recently "actived" item is current
+ self.plot.setActiveScatter('scatter')
+ self.assertEqual(listener.callCount(), 8)
+ self.plot.setActiveImage('image')
+ self.assertEqual(listener.callCount(), 9)
+ self._checkSelection(selection, image, (image, scatter))
+
+ # Add a curve which is not active by default
+ self.plot.addCurve((0, 1, 2), (0, 1, 2), legend='curve')
+ curve = self.plot.getCurve('curve')
+ self.assertEqual(listener.callCount(), 9)
+ self._checkSelection(selection, image, (image, scatter))
+
+ # Mosted recently "actived" item is current
+ self.plot.setActiveCurve('curve')
+ self.assertEqual(listener.callCount(), 10)
+ self._checkSelection(selection, curve, (curve, image, scatter))
+
+ # Add a curve which is not active by default
+ self.plot.addCurve((0, 1, 2), (0, 1, 2), legend='curve2')
+ curve2 = self.plot.getCurve('curve2')
+ self.assertEqual(listener.callCount(), 10)
+ self._checkSelection(selection, curve, (curve, image, scatter))
+
+ # Mosted recently "actived" item is current, previous curve is removed
+ self.plot.setActiveCurve('curve2')
+ self.assertEqual(listener.callCount(), 11)
+ self._checkSelection(selection, curve2, (curve2, image, scatter))
+
+ # No items = no current
+ self.plot.clear()
+ self.assertEqual(listener.callCount(), 12)
+ self._checkSelection(selection)
+
+ def testPlotWidgetWithItems(self):
+ """Test init of selection on a plot with items"""
+ self.plot.addImage(((0, 1), (2, 3)), legend='image')
+ self.plot.addScatter((3, 2, 1), (0, 1, 2), (0, 1, 2), legend='scatter')
+ self.plot.addCurve((0, 1, 2), (0, 1, 2), legend='curve')
+ self.plot.setActiveCurve('curve')
+
+ selection = self.plot.selection()
+ self.assertIsNotNone(selection.getCurrentItem())
+ selected = selection.getSelectedItems()
+ self.assertEqual(len(selected), 3)
+ self.assertIn(self.plot.getActiveCurve(), selected)
+ self.assertIn(self.plot.getActiveImage(), selected)
+ self.assertIn(self.plot.getActiveScatter(), selected)
+
+ def testSetCurrentItem(self):
+ """Test setCurrentItem"""
+ # Add items to the plot
+ self.plot.addImage(((0, 1), (2, 3)), legend='image')
+ image = self.plot.getActiveImage()
+ self.plot.addScatter((3, 2, 1), (0, 1, 2), (0, 1, 2), legend='scatter')
+ scatter = self.plot.getActiveScatter()
+ self.plot.addCurve((0, 1, 2), (0, 1, 2), legend='curve')
+ self.plot.setActiveCurve('curve')
+ curve = self.plot.getActiveCurve()
+
+ selection = self.plot.selection()
+ self.assertIsNotNone(selection.getCurrentItem())
+ self.assertEqual(len(selection.getSelectedItems()), 3)
+
+ # Set current to None reset all active items
+ selection.setCurrentItem(None)
+ self._checkSelection(selection)
+ self.assertIsNone(self.plot.getActiveCurve())
+ self.assertIsNone(self.plot.getActiveImage())
+ self.assertIsNone(self.plot.getActiveScatter())
+
+ # Set current to an item makes it active
+ selection.setCurrentItem(image)
+ self._checkSelection(selection, image, (image,))
+ self.assertIsNone(self.plot.getActiveCurve())
+ self.assertIs(self.plot.getActiveImage(), image)
+ self.assertIsNone(self.plot.getActiveScatter())
+
+ # Set current to an item makes it active and keeps other active
+ selection.setCurrentItem(curve)
+ self._checkSelection(selection, curve, (curve, image))
+ self.assertIs(self.plot.getActiveCurve(), curve)
+ self.assertIs(self.plot.getActiveImage(), image)
+ self.assertIsNone(self.plot.getActiveScatter())
+
+ # Set current to an item makes it active and keeps other active
+ selection.setCurrentItem(scatter)
+ self._checkSelection(selection, scatter, (scatter, curve, image))
+ self.assertIs(self.plot.getActiveCurve(), curve)
+ self.assertIs(self.plot.getActiveImage(), image)
+ self.assertIs(self.plot.getActiveScatter(), scatter)
+
+
+@pytest.mark.usefixtures("use_opengl")
+class TestPlotWidget_Gl(TestPlotWidget):
+ backend="gl"
+
+@pytest.mark.usefixtures("use_opengl")
+class TestPlotImage_Gl(TestPlotImage):
+ backend="gl"
+
+@pytest.mark.usefixtures("use_opengl")
+class TestPlotCurve_Gl(TestPlotCurve):
+ backend="gl"
+
+@pytest.mark.usefixtures("use_opengl")
+class TestPlotHistogram_Gl(TestPlotHistogram):
+ backend="gl"
+
+@pytest.mark.usefixtures("use_opengl")
+class TestPlotScatter_Gl(TestPlotScatter):
+ backend="gl"
+
+@pytest.mark.usefixtures("use_opengl")
+class TestPlotMarker_Gl(TestPlotMarker):
+ backend="gl"
+
+@pytest.mark.usefixtures("use_opengl")
+class TestPlotItem_Gl(TestPlotItem):
+ backend="gl"
+
+@pytest.mark.usefixtures("use_opengl")
+class TestPlotAxes_Gl(TestPlotAxes):
+ backend="gl"
+
+@pytest.mark.usefixtures("use_opengl")
+class TestPlotActiveCurveImage_Gl(TestPlotActiveCurveImage):
+ backend="gl"
+
+@pytest.mark.usefixtures("use_opengl")
+class TestPlotEmptyLog_Gl(TestPlotEmptyLog):
+ backend="gl"
+
+@pytest.mark.usefixtures("use_opengl")
+class TestPlotCurveLog_Gl(TestPlotCurveLog):
+ backend="gl"
+
+@pytest.mark.usefixtures("use_opengl")
+class TestPlotImageLog_Gl(TestPlotImageLog):
+ backend="gl"
+
+@pytest.mark.usefixtures("use_opengl")
+class TestPlotMarkerLog_Gl(TestPlotMarkerLog):
+ backend="gl"
+
+@pytest.mark.usefixtures("use_opengl")
+class TestPlotWidgetSelection_Gl(TestPlotWidgetSelection):
+ backend="gl"
+
+class TestSpecial_ExplicitMplBackend(TestSpecialBackend):
+ backend="mpl"
diff --git a/src/silx/gui/plot/test/testPlotWidgetNoBackend.py b/src/silx/gui/plot/test/testPlotWidgetNoBackend.py
new file mode 100644
index 0000000..4914929
--- /dev/null
+++ b/src/silx/gui/plot/test/testPlotWidgetNoBackend.py
@@ -0,0 +1,618 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for PlotWidget with 'none' backend"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "17/01/2018"
+
+
+import unittest
+from functools import reduce
+from silx.utils.testutils import ParametricTestCase
+
+import numpy
+
+from silx.gui.plot.PlotWidget import PlotWidget
+from silx.gui.plot.items.histogram import _getHistogramCurve, _computeEdges
+
+
+class TestPlot(unittest.TestCase):
+ """Basic tests of Plot without backend"""
+
+ def testPlotTitleLabels(self):
+ """Create a Plot and set the labels"""
+
+ plot = PlotWidget(backend='none')
+
+ title, xlabel, ylabel = 'the title', 'x label', 'y label'
+ plot.setGraphTitle(title)
+ plot.getXAxis().setLabel(xlabel)
+ plot.getYAxis().setLabel(ylabel)
+
+ self.assertEqual(plot.getGraphTitle(), title)
+ self.assertEqual(plot.getXAxis().getLabel(), xlabel)
+ self.assertEqual(plot.getYAxis().getLabel(), ylabel)
+
+ def testAddNoRemove(self):
+ """add objects to the Plot"""
+
+ plot = PlotWidget(backend='none')
+ plot.addCurve(x=(1, 2, 3), y=(3, 2, 1))
+ plot.addImage(numpy.arange(100.).reshape(10, -1))
+ plot.addShape(numpy.array((1., 10.)),
+ numpy.array((10., 10.)),
+ shape="rectangle")
+ plot.addXMarker(10.)
+
+
+class TestPlotRanges(ParametricTestCase):
+ """Basic tests of Plot data ranges without backend"""
+
+ _getValidValues = {True: lambda ar: ar > 0,
+ False: lambda ar: numpy.ones(shape=ar.shape,
+ dtype=bool)}
+
+ @staticmethod
+ def _getRanges(arrays, are_logs):
+ gen = (TestPlotRanges._getValidValues[is_log](ar)
+ for (ar, is_log) in zip(arrays, are_logs))
+ indices = numpy.where(reduce(numpy.logical_and, gen))[0]
+ if len(indices) > 0:
+ ranges = [(ar[indices[0]], ar[indices[-1]]) for ar in arrays]
+ else:
+ ranges = [None] * len(arrays)
+
+ return ranges
+
+ @staticmethod
+ def _getRangesMinmax(ranges):
+ # TODO : error if None in ranges.
+ rangeMin = numpy.min([rng[0] for rng in ranges])
+ rangeMax = numpy.max([rng[1] for rng in ranges])
+ return rangeMin, rangeMax
+
+ def testDataRangeNoPlot(self):
+ """empty plot data range"""
+
+ plot = PlotWidget(backend='none')
+
+ for logX, logY in ((False, False),
+ (True, False),
+ (True, True),
+ (False, True),
+ (False, False)):
+ with self.subTest(logX=logX, logY=logY):
+ plot.getXAxis()._setLogarithmic(logX)
+ plot.getYAxis()._setLogarithmic(logY)
+ dataRange = plot.getDataRange()
+ self.assertIsNone(dataRange.x)
+ self.assertIsNone(dataRange.y)
+ self.assertIsNone(dataRange.yright)
+
+ def testDataRangeLeft(self):
+ """left axis range"""
+
+ plot = PlotWidget(backend='none')
+
+ xData = numpy.arange(10) - 4.9 # range : -4.9 , 4.1
+ yData = numpy.arange(10) - 6.9 # range : -6.9 , 2.1
+
+ plot.addCurve(x=xData,
+ y=yData,
+ legend='plot_0',
+ yaxis='left')
+
+ for logX, logY in ((False, False),
+ (True, False),
+ (True, True),
+ (False, True),
+ (False, False)):
+ with self.subTest(logX=logX, logY=logY):
+ plot.getXAxis()._setLogarithmic(logX)
+ plot.getYAxis()._setLogarithmic(logY)
+ dataRange = plot.getDataRange()
+ xRange, yRange = self._getRanges([xData, yData],
+ [logX, logY])
+ self.assertSequenceEqual(dataRange.x, xRange)
+ self.assertSequenceEqual(dataRange.y, yRange)
+ self.assertIsNone(dataRange.yright)
+
+ def testDataRangeRight(self):
+ """right axis range"""
+
+ plot = PlotWidget(backend='none')
+ xData = numpy.arange(10) - 4.9 # range : -4.9 , 4.1
+ yData = numpy.arange(10) - 6.9 # range : -6.9 , 2.1
+ plot.addCurve(x=xData,
+ y=yData,
+ legend='plot_0',
+ yaxis='right')
+
+ for logX, logY in ((False, False),
+ (True, False),
+ (True, True),
+ (False, True),
+ (False, False)):
+ with self.subTest(logX=logX, logY=logY):
+ plot.getXAxis()._setLogarithmic(logX)
+ plot.getYAxis()._setLogarithmic(logY)
+ dataRange = plot.getDataRange()
+ xRange, yRange = self._getRanges([xData, yData],
+ [logX, logY])
+ self.assertSequenceEqual(dataRange.x, xRange)
+ self.assertIsNone(dataRange.y)
+ self.assertSequenceEqual(dataRange.yright, yRange)
+
+ def testDataRangeImage(self):
+ """image data range"""
+
+ origin = (-10, 25)
+ scale = (3., 8.)
+ image = numpy.arange(100.).reshape(20, 5)
+
+ plot = PlotWidget(backend='none')
+ plot.addImage(image,
+ origin=origin, scale=scale)
+
+ xRange = numpy.array([0., image.shape[1] * scale[0]]) + origin[0]
+ yRange = numpy.array([0., image.shape[0] * scale[1]]) + origin[1]
+
+ ranges = {(False, False): (xRange, yRange),
+ (True, False): (None, None),
+ (True, True): (None, None),
+ (False, True): (None, None)}
+
+ for logX, logY in ((False, False),
+ (True, False),
+ (True, True),
+ (False, True),
+ (False, False)):
+ with self.subTest(logX=logX, logY=logY):
+ plot.getXAxis()._setLogarithmic(logX)
+ plot.getYAxis()._setLogarithmic(logY)
+ dataRange = plot.getDataRange()
+ xRange, yRange = ranges[logX, logY]
+ self.assertTrue(numpy.array_equal(dataRange.x, xRange),
+ msg='{0} != {1}'.format(dataRange.x, xRange))
+ self.assertTrue(numpy.array_equal(dataRange.y, yRange),
+ msg='{0} != {1}'.format(dataRange.y, yRange))
+ self.assertIsNone(dataRange.yright)
+
+ def testDataRangeLeftRight(self):
+ """right+left axis range"""
+
+ plot = PlotWidget(backend='none')
+
+ xData_l = numpy.arange(10) - 0.9 # range : -0.9 , 8.1
+ yData_l = numpy.arange(10) - 1.9 # range : -1.9 , 7.1
+ plot.addCurve(x=xData_l,
+ y=yData_l,
+ legend='plot_l',
+ yaxis='left')
+
+ xData_r = numpy.arange(10) - 4.9 # range : -4.9 , 4.1
+ yData_r = numpy.arange(10) - 6.9 # range : -6.9 , 2.1
+ plot.addCurve(x=xData_r,
+ y=yData_r,
+ legend='plot_r',
+ yaxis='right')
+
+ for logX, logY in ((False, False),
+ (True, False),
+ (True, True),
+ (False, True),
+ (False, False)):
+ with self.subTest(logX=logX, logY=logY):
+ plot.getXAxis()._setLogarithmic(logX)
+ plot.getYAxis()._setLogarithmic(logY)
+ dataRange = plot.getDataRange()
+ xRangeL, yRangeL = self._getRanges([xData_l, yData_l],
+ [logX, logY])
+ xRangeR, yRangeR = self._getRanges([xData_r, yData_r],
+ [logX, logY])
+ xRangeLR = self._getRangesMinmax([xRangeL, xRangeR])
+ self.assertSequenceEqual(dataRange.x, xRangeLR)
+ self.assertSequenceEqual(dataRange.y, yRangeL)
+ self.assertSequenceEqual(dataRange.yright, yRangeR)
+
+ def testDataRangeCurveImage(self):
+ """right+left+image axis range"""
+
+ # overlapping ranges :
+ # image sets x min and y max
+ # plot_left sets y min
+ # plot_right sets x max (and yright)
+ plot = PlotWidget(backend='none')
+
+ origin = (-10, 5)
+ scale = (3., 8.)
+ image = numpy.arange(100.).reshape(20, 5)
+
+ plot.addImage(image,
+ origin=origin, scale=scale, legend='image')
+
+ xData_l = numpy.arange(10) - 0.9 # range : -0.9 , 8.1
+ yData_l = numpy.arange(10) - 1.9 # range : -1.9 , 7.1
+ plot.addCurve(x=xData_l,
+ y=yData_l,
+ legend='plot_l',
+ yaxis='left')
+
+ xData_r = numpy.arange(10) + 4.1 # range : 4.1 , 13.1
+ yData_r = numpy.arange(10) - 0.9 # range : -0.9 , 8.1
+ plot.addCurve(x=xData_r,
+ y=yData_r,
+ legend='plot_r',
+ yaxis='right')
+
+ imgXRange = numpy.array([0., image.shape[1] * scale[0]]) + origin[0]
+ imgYRange = numpy.array([0., image.shape[0] * scale[1]]) + origin[1]
+
+ for logX, logY in ((False, False),
+ (True, False),
+ (True, True),
+ (False, True),
+ (False, False)):
+ with self.subTest(logX=logX, logY=logY):
+ plot.getXAxis()._setLogarithmic(logX)
+ plot.getYAxis()._setLogarithmic(logY)
+ dataRange = plot.getDataRange()
+ xRangeL, yRangeL = self._getRanges([xData_l, yData_l],
+ [logX, logY])
+ xRangeR, yRangeR = self._getRanges([xData_r, yData_r],
+ [logX, logY])
+ if logX or logY:
+ xRangeLR = self._getRangesMinmax([xRangeL, xRangeR])
+ else:
+ xRangeLR = self._getRangesMinmax([xRangeL,
+ xRangeR,
+ imgXRange])
+ yRangeL = self._getRangesMinmax([yRangeL, imgYRange])
+ self.assertSequenceEqual(dataRange.x, xRangeLR)
+ self.assertSequenceEqual(dataRange.y, yRangeL)
+ self.assertSequenceEqual(dataRange.yright, yRangeR)
+
+ def testDataRangeImageNegativeScaleX(self):
+ """image data range, negative scale"""
+
+ origin = (-10, 25)
+ scale = (-3., 8.)
+ image = numpy.arange(100.).reshape(20, 5)
+
+ plot = PlotWidget(backend='none')
+ plot.addImage(image,
+ origin=origin, scale=scale)
+
+ xRange = numpy.array([0., image.shape[1] * scale[0]]) + origin[0]
+ xRange.sort() # negative scale!
+ yRange = numpy.array([0., image.shape[0] * scale[1]]) + origin[1]
+
+ ranges = {(False, False): (xRange, yRange),
+ (True, False): (None, None),
+ (True, True): (None, None),
+ (False, True): (None, None)}
+
+ for logX, logY in ((False, False),
+ (True, False),
+ (True, True),
+ (False, True),
+ (False, False)):
+ with self.subTest(logX=logX, logY=logY):
+ plot.getXAxis()._setLogarithmic(logX)
+ plot.getYAxis()._setLogarithmic(logY)
+ dataRange = plot.getDataRange()
+ xRange, yRange = ranges[logX, logY]
+ self.assertTrue(numpy.array_equal(dataRange.x, xRange),
+ msg='{0} != {1}'.format(dataRange.x, xRange))
+ self.assertTrue(numpy.array_equal(dataRange.y, yRange),
+ msg='{0} != {1}'.format(dataRange.y, yRange))
+ self.assertIsNone(dataRange.yright)
+
+ def testDataRangeImageNegativeScaleY(self):
+ """image data range, negative scale"""
+
+ origin = (-10, 25)
+ scale = (3., -8.)
+ image = numpy.arange(100.).reshape(20, 5)
+
+ plot = PlotWidget(backend='none')
+ plot.addImage(image,
+ origin=origin, scale=scale)
+
+ xRange = numpy.array([0., image.shape[1] * scale[0]]) + origin[0]
+ yRange = numpy.array([0., image.shape[0] * scale[1]]) + origin[1]
+ yRange.sort() # negative scale!
+
+ ranges = {(False, False): (xRange, yRange),
+ (True, False): (None, None),
+ (True, True): (None, None),
+ (False, True): (None, None)}
+
+ for logX, logY in ((False, False),
+ (True, False),
+ (True, True),
+ (False, True),
+ (False, False)):
+ with self.subTest(logX=logX, logY=logY):
+ plot.getXAxis()._setLogarithmic(logX)
+ plot.getYAxis()._setLogarithmic(logY)
+ dataRange = plot.getDataRange()
+ xRange, yRange = ranges[logX, logY]
+ self.assertTrue(numpy.array_equal(dataRange.x, xRange),
+ msg='{0} != {1}'.format(dataRange.x, xRange))
+ self.assertTrue(numpy.array_equal(dataRange.y, yRange),
+ msg='{0} != {1}'.format(dataRange.y, yRange))
+ self.assertIsNone(dataRange.yright)
+
+ def testDataRangeHiddenCurve(self):
+ """curves with a hidden curve"""
+ plot = PlotWidget(backend='none')
+ plot.addCurve((0, 1), (0, 1), legend='shown')
+ plot.addCurve((0, 1, 2), (5, 5, 5), legend='hidden')
+ range1 = plot.getDataRange()
+ self.assertEqual(range1.x, (0, 2))
+ self.assertEqual(range1.y, (0, 5))
+ plot.hideCurve('hidden')
+ range2 = plot.getDataRange()
+ self.assertEqual(range2.x, (0, 1))
+ self.assertEqual(range2.y, (0, 1))
+
+
+class TestPlotGetCurveImage(unittest.TestCase):
+ """Test of plot getCurve and getImage methods"""
+
+ def testGetCurve(self):
+ """PlotWidget.getCurve and Plot.getActiveCurve tests"""
+
+ plot = PlotWidget(backend='none')
+
+ # No curve
+ curve = plot.getCurve()
+ self.assertIsNone(curve) # No curve
+
+ plot.setActiveCurveHandling(True)
+ plot.addCurve(x=(0, 1), y=(0, 1), legend='curve 0')
+ plot.addCurve(x=(0, 1), y=(0, 1), legend='curve 1')
+ plot.addCurve(x=(0, 1), y=(0, 1), legend='curve 2')
+ plot.setActiveCurve('curve 0')
+
+ # Active curve
+ active = plot.getActiveCurve()
+ self.assertEqual(active.getName(), 'curve 0')
+ curve = plot.getCurve()
+ self.assertEqual(curve.getName(), 'curve 0')
+
+ # No active curve and curves
+ plot.setActiveCurveHandling(False)
+ active = plot.getActiveCurve()
+ self.assertIsNone(active) # No active curve
+ curve = plot.getCurve()
+ self.assertEqual(curve.getName(), 'curve 2') # Last added curve
+
+ # Last curve hidden
+ plot.hideCurve('curve 2', True)
+ curve = plot.getCurve()
+ self.assertEqual(curve.getName(), 'curve 1') # Last added curve
+
+ # All curves hidden
+ plot.hideCurve('curve 1', True)
+ plot.hideCurve('curve 0', True)
+ curve = plot.getCurve()
+ self.assertIsNone(curve)
+
+ def testGetCurveOldApi(self):
+ """old API PlotWidget.getCurve and Plot.getActiveCurve tests"""
+
+ plot = PlotWidget(backend='none')
+
+ # No curve
+ curve = plot.getCurve()
+ self.assertIsNone(curve) # No curve
+
+ plot.setActiveCurveHandling(True)
+ x = numpy.arange(10.).astype(numpy.float32)
+ y = x * x
+ plot.addCurve(x=x, y=y, legend='curve 0', info=["whatever"])
+ plot.addCurve(x=x, y=2*x, legend='curve 1', info="anything")
+ plot.setActiveCurve('curve 0')
+
+ # Active curve (4 elements)
+ xOut, yOut, legend, info = plot.getActiveCurve()[:4]
+ self.assertEqual(legend, 'curve 0')
+ self.assertTrue(numpy.allclose(xOut, x), 'curve 0 wrong x data')
+ self.assertTrue(numpy.allclose(yOut, y), 'curve 0 wrong y data')
+
+ # Active curve (5 elements)
+ xOut, yOut, legend, info, params = plot.getCurve("curve 1")
+ self.assertEqual(legend, 'curve 1')
+ self.assertEqual(info, 'anything')
+ self.assertTrue(numpy.allclose(xOut, x), 'curve 1 wrong x data')
+ self.assertTrue(numpy.allclose(yOut, 2 * x), 'curve 1 wrong y data')
+
+ def testGetImage(self):
+ """PlotWidget.getImage and PlotWidget.getActiveImage tests"""
+
+ plot = PlotWidget(backend='none')
+
+ # No image
+ image = plot.getImage()
+ self.assertIsNone(image)
+
+ plot.addImage(((0, 1), (2, 3)), legend='image 0')
+ plot.addImage(((0, 1), (2, 3)), legend='image 1')
+
+ # Active image
+ active = plot.getActiveImage()
+ self.assertEqual(active.getName(), 'image 0')
+ image = plot.getImage()
+ self.assertEqual(image.getName(), 'image 0')
+
+ # No active image
+ plot.addImage(((0, 1), (2, 3)), legend='image 2')
+ plot.setActiveImage(None)
+ active = plot.getActiveImage()
+ self.assertIsNone(active)
+ image = plot.getImage()
+ self.assertEqual(image.getName(), 'image 2')
+
+ # Active image
+ plot.setActiveImage('image 1')
+ active = plot.getActiveImage()
+ self.assertEqual(active.getName(), 'image 1')
+ image = plot.getImage()
+ self.assertEqual(image.getName(), 'image 1')
+
+ def testGetImageOldApi(self):
+ """PlotWidget.getImage and PlotWidget.getActiveImage old API tests"""
+
+ plot = PlotWidget(backend='none')
+
+ # No image
+ image = plot.getImage()
+ self.assertIsNone(image)
+
+ image = numpy.arange(10).astype(numpy.float32)
+ image.shape = 5, 2
+
+ plot.addImage(image, legend='image 0', info=["Hi!"])
+
+ # Active image
+ data, legend, info, something, params = plot.getActiveImage()
+ self.assertEqual(legend, 'image 0')
+ self.assertEqual(info, ["Hi!"])
+ self.assertTrue(numpy.allclose(data, image), "image 0 data not correct")
+
+ def testGetAllImages(self):
+ """PlotWidget.getAllImages test"""
+
+ plot = PlotWidget(backend='none')
+
+ # No image
+ images = plot.getAllImages()
+ self.assertEqual(len(images), 0)
+
+ # 2 images
+ data = numpy.arange(100).reshape(10, 10)
+ plot.addImage(data, legend='1')
+ plot.addImage(data, origin=(10, 10), legend='2')
+ images = plot.getAllImages(just_legend=True)
+ self.assertEqual(list(images), ['1', '2'])
+ images = plot.getAllImages(just_legend=False)
+ self.assertEqual(len(images), 2)
+ self.assertEqual(images[0].getName(), '1')
+ self.assertEqual(images[1].getName(), '2')
+
+
+class TestPlotAddScatter(unittest.TestCase):
+ """Test of plot addScatter"""
+
+ def testAddGetScatter(self):
+
+ plot = PlotWidget(backend='none')
+
+ # No curve
+ scatter = plot._getItem(kind="scatter")
+ self.assertIsNone(scatter) # No curve
+
+ plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend='scatter 0')
+ plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend='scatter 1')
+ plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend='scatter 2')
+ plot._setActiveItem('scatter', 'scatter 0')
+
+ # Active scatter
+ active = plot._getActiveItem(kind='scatter')
+ self.assertEqual(active.getName(), 'scatter 0')
+
+ # check default values
+ self.assertAlmostEqual(active.getSymbolSize(), active._DEFAULT_SYMBOL_SIZE)
+ self.assertEqual(active.getSymbol(), "o")
+ self.assertAlmostEqual(active.getAlpha(), 1.0)
+
+ # modify parameters
+ active.setSymbolSize(20.5)
+ active.setSymbol("d")
+ active.setAlpha(0.777)
+
+ s0 = plot.getScatter("scatter 0")
+
+ self.assertAlmostEqual(s0.getSymbolSize(), 20.5)
+ self.assertEqual(s0.getSymbol(), "d")
+ self.assertAlmostEqual(s0.getAlpha(), 0.777)
+
+ scatter1 = plot._getItem(kind='scatter', legend='scatter 1')
+ self.assertEqual(scatter1.getName(), 'scatter 1')
+
+ def testGetAllScatters(self):
+ """PlotWidget.getAllImages test"""
+
+ plot = PlotWidget(backend='none')
+
+ items = plot.getItems()
+ self.assertEqual(len(items), 0)
+
+ plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend='scatter 0')
+ plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend='scatter 1')
+ plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend='scatter 2')
+
+ items = plot.getItems()
+ self.assertEqual(len(items), 3)
+ self.assertEqual(items[0].getName(), 'scatter 0')
+ self.assertEqual(items[1].getName(), 'scatter 1')
+ self.assertEqual(items[2].getName(), 'scatter 2')
+
+
+class TestPlotHistogram(unittest.TestCase):
+ """Basic tests for histogram."""
+
+ def testEdges(self):
+ x = numpy.array([0, 1, 2])
+ edgesRight = numpy.array([0, 1, 2, 3])
+ edgesLeft = numpy.array([-1, 0, 1, 2])
+ edgesCenter = numpy.array([-0.5, 0.5, 1.5, 2.5])
+
+ # testing x values for right
+ edges = _computeEdges(x, 'right')
+ numpy.testing.assert_array_equal(edges, edgesRight)
+
+ edges = _computeEdges(x, 'center')
+ numpy.testing.assert_array_equal(edges, edgesCenter)
+
+ edges = _computeEdges(x, 'left')
+ numpy.testing.assert_array_equal(edges, edgesLeft)
+
+ def testHistogramCurve(self):
+ y = numpy.array([3, 2, 5])
+ edges = numpy.array([0, 1, 2, 3])
+
+ xHisto, yHisto = _getHistogramCurve(y, edges)
+ numpy.testing.assert_array_equal(
+ yHisto, numpy.array([3, 3, 2, 2, 5, 5]))
+
+ y = numpy.array([-3, 2, 5, 0])
+ edges = numpy.array([-2, -1, 0, 1, 2])
+ xHisto, yHisto = _getHistogramCurve(y, edges)
+ numpy.testing.assert_array_equal(
+ yHisto, numpy.array([-3, -3, 2, 2, 5, 5, 0, 0]))
diff --git a/src/silx/gui/plot/test/testPlotWindow.py b/src/silx/gui/plot/test/testPlotWindow.py
new file mode 100644
index 0000000..9e1497f
--- /dev/null
+++ b/src/silx/gui/plot/test/testPlotWindow.py
@@ -0,0 +1,174 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for PlotWindow"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "27/06/2017"
+
+
+import unittest
+import numpy
+import pytest
+
+from silx.gui.utils.testutils import TestCaseQt, getQToolButtonFromAction
+
+from silx.gui import qt
+from silx.gui.plot import PlotWindow
+from silx.gui.colors import Colormap
+
+
+class TestPlotWindow(TestCaseQt):
+ """Base class for tests of PlotWindow."""
+
+ def setUp(self):
+ super(TestPlotWindow, self).setUp()
+ self.plot = PlotWindow()
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ def tearDown(self):
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ del self.plot
+ super(TestPlotWindow, self).tearDown()
+
+ def testActions(self):
+ """Test the actions QToolButtons"""
+ self.plot.setLimits(1, 100, 1, 100)
+
+ checkList = [ # QAction, Plot state getter
+ (self.plot.xAxisAutoScaleAction, self.plot.getXAxis().isAutoScale),
+ (self.plot.yAxisAutoScaleAction, self.plot.getYAxis().isAutoScale),
+ (self.plot.xAxisLogarithmicAction, self.plot.getXAxis()._isLogarithmic),
+ (self.plot.yAxisLogarithmicAction, self.plot.getYAxis()._isLogarithmic),
+ (self.plot.gridAction, self.plot.getGraphGrid),
+ ]
+
+ for action, getter in checkList:
+ self.mouseMove(self.plot)
+ initialState = getter()
+ toolButton = getQToolButtonFromAction(action)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+ self.assertNotEqual(getter(), initialState,
+ msg='"%s" state not changed' % action.text())
+
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+ self.assertEqual(getter(), initialState,
+ msg='"%s" state not changed' % action.text())
+
+ # Trigger a zoom reset
+ self.mouseMove(self.plot)
+ resetZoomAction = self.plot.resetZoomAction
+ toolButton = getQToolButtonFromAction(resetZoomAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+
+ def testDockWidgets(self):
+ """Test add/remove dock widgets"""
+ dock1 = qt.QDockWidget('Test 1')
+ dock1.setWidget(qt.QLabel('Test 1'))
+
+ self.plot.addTabbedDockWidget(dock1)
+ self.qapp.processEvents()
+
+ self.plot.removeDockWidget(dock1)
+ self.qapp.processEvents()
+
+ dock2 = qt.QDockWidget('Test 2')
+ dock2.setWidget(qt.QLabel('Test 2'))
+
+ self.plot.addTabbedDockWidget(dock2)
+ self.qapp.processEvents()
+
+ if qt.BINDING != 'PySide2':
+ # Weird bug with PySide2 later upon gc.collect() when getting the layout
+ self.assertNotEqual(self.plot.layout().indexOf(dock2),
+ -1,
+ "dock2 not properly displayed")
+
+ def testToolAspectRatio(self):
+ self.plot.toolBar()
+ self.plot.keepDataAspectRatioButton.keepDataAspectRatio()
+ self.assertTrue(self.plot.isKeepDataAspectRatio())
+ self.plot.keepDataAspectRatioButton.dontKeepDataAspectRatio()
+ self.assertFalse(self.plot.isKeepDataAspectRatio())
+
+ def testToolYAxisOrigin(self):
+ self.plot.toolBar()
+ self.plot.yAxisInvertedButton.setYAxisUpward()
+ self.assertFalse(self.plot.getYAxis().isInverted())
+ self.plot.yAxisInvertedButton.setYAxisDownward()
+ self.assertTrue(self.plot.getYAxis().isInverted())
+
+ def testColormapAutoscaleCache(self):
+ # Test that the min/max cache is not computed twice
+
+ old = Colormap._computeAutoscaleRange
+ self._count = 0
+ def _computeAutoscaleRange(colormap, data):
+ self._count = self._count + 1
+ return 10, 20
+ Colormap._computeAutoscaleRange = _computeAutoscaleRange
+ try:
+ colormap = Colormap(name='red')
+ self.plot.setVisible(True)
+
+ # Add an image
+ data = numpy.arange(8**2).reshape(8, 8)
+ self.plot.addImage(data, legend="foo", colormap=colormap)
+ self.plot.setActiveImage("foo")
+
+ # Use the colorbar
+ self.plot.getColorBarWidget().setVisible(True)
+ self.qWait(50)
+
+ # Remove and add again the same item
+ image = self.plot.getImage("foo")
+ self.plot.removeImage("foo")
+ self.plot.addItem(image)
+ self.qWait(50)
+ finally:
+ Colormap._computeAutoscaleRange = old
+ self.assertEqual(self._count, 1)
+ del self._count
+
+ @pytest.mark.usefixtures("use_opengl")
+ def testSwitchBackend(self):
+ """Test switching an empty plot"""
+ self.plot.resetZoom()
+ xlimits = self.plot.getXAxis().getLimits()
+ ylimits = self.plot.getYAxis().getLimits()
+ isKeepAspectRatio = self.plot.isKeepDataAspectRatio()
+
+ for backend in ('gl', 'mpl'):
+ with self.subTest():
+ self.plot.setBackend(backend)
+ self.plot.replot()
+ self.assertEqual(self.plot.getXAxis().getLimits(), xlimits)
+ self.assertEqual(self.plot.getYAxis().getLimits(), ylimits)
+ self.assertEqual(
+ self.plot.isKeepDataAspectRatio(), isKeepAspectRatio)
diff --git a/src/silx/gui/plot/test/testRoiStatsWidget.py b/src/silx/gui/plot/test/testRoiStatsWidget.py
new file mode 100644
index 0000000..eb29267
--- /dev/null
+++ b/src/silx/gui/plot/test/testRoiStatsWidget.py
@@ -0,0 +1,277 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Tests for ROIStatsWidget"""
+
+
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui import qt
+from silx.gui.plot import PlotWindow
+from silx.gui.plot.stats.stats import Stats
+from silx.gui.plot.ROIStatsWidget import ROIStatsWidget
+from silx.gui.plot.CurvesROIWidget import ROI
+from silx.gui.plot.items.roi import RectangleROI, PolygonROI
+from silx.gui.plot.StatsWidget import UpdateMode
+import unittest
+import numpy
+
+
+
+class _TestRoiStatsBase(TestCaseQt):
+ """Base class for several unittest relative to ROIStatsWidget"""
+ def setUp(self):
+ TestCaseQt.setUp(self)
+ # define plot
+ self.plot = PlotWindow()
+ self.plot.addImage(numpy.arange(10000).reshape(100, 100),
+ legend='img1')
+ self.img_item = self.plot.getImage('img1')
+ self.plot.addCurve(x=numpy.linspace(0, 10, 56), y=numpy.arange(56),
+ legend='curve1')
+ self.curve_item = self.plot.getCurve('curve1')
+ self.plot.addHistogram(edges=numpy.linspace(0, 10, 56),
+ histogram=numpy.arange(56), legend='histo1')
+ self.histogram_item = self.plot.getHistogram(legend='histo1')
+ self.plot.addScatter(x=numpy.linspace(0, 10, 56),
+ y=numpy.linspace(0, 10, 56),
+ value=numpy.arange(56),
+ legend='scatter1')
+ self.scatter_item = self.plot.getScatter(legend='scatter1')
+
+ # stats widget
+ self.statsWidget = ROIStatsWidget(plot=self.plot)
+
+ # define stats
+ stats = [
+ ('sum', numpy.sum),
+ ('mean', numpy.mean),
+ ]
+ self.statsWidget.setStats(stats=stats)
+
+ # define rois
+ self.roi1D = ROI(name='range1', fromdata=0, todata=4, type_='energy')
+ self.rectangle_roi = RectangleROI()
+ self.rectangle_roi.setGeometry(origin=(0, 0), size=(20, 20))
+ self.rectangle_roi.setName('Initial ROI')
+ self.polygon_roi = PolygonROI()
+ points = numpy.array([[0, 5], [5, 0], [10, 5], [5, 10]])
+ self.polygon_roi.setPoints(points)
+
+ def statsTable(self):
+ return self.statsWidget._statsROITable
+
+ def tearDown(self):
+ Stats._getContext.cache_clear()
+ self.statsWidget.setAttribute(qt.Qt.WA_DeleteOnClose, True)
+ self.statsWidget.close()
+ self.statsWidget = None
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose, True)
+ self.plot.close()
+ self.plot = None
+ TestCaseQt.tearDown(self)
+
+
+class TestRoiStatsCouple(_TestRoiStatsBase):
+ """
+ Test different possible couple (roi, plotItem).
+ Check that:
+
+ * computation is correct if couple is valid
+ * raise an error if couple is invalid
+ """
+ def testROICurve(self):
+ """
+ Test that the couple (ROI, curveItem) can be used for stats
+ """
+ item = self.statsWidget.addItem(roi=self.roi1D,
+ plotItem=self.curve_item)
+ assert item is not None
+ tableItems = self.statsTable()._itemToTableItems(item)
+ self.assertEqual(tableItems['sum'].text(), '253')
+ self.assertEqual(tableItems['mean'].text(), '11.0')
+
+ def testRectangleImage(self):
+ """
+ Test that the couple (RectangleROI, imageItem) can be used for stats
+ """
+ item = self.statsWidget.addItem(roi=self.rectangle_roi,
+ plotItem=self.img_item)
+ assert item is not None
+ self.plot.addImage(numpy.ones(10000).reshape(100, 100),
+ legend='img1')
+ self.qapp.processEvents()
+ tableItems = self.statsTable()._itemToTableItems(item)
+ self.assertEqual(tableItems['sum'].text(), str(float(21*21)))
+ self.assertEqual(tableItems['mean'].text(), '1.0')
+
+ def testPolygonImage(self):
+ """
+ Test that the couple (PolygonROI, imageItem) can be used for stats
+ """
+ item = self.statsWidget.addItem(roi=self.polygon_roi,
+ plotItem=self.img_item)
+ assert item is not None
+ tableItems = self.statsTable()._itemToTableItems(item)
+ self.assertEqual(tableItems['sum'].text(), '22750')
+ self.assertEqual(tableItems['mean'].text(), '455.0')
+
+ def testROIImage(self):
+ """
+ Test that the couple (ROI, imageItem) is raising an error
+ """
+ with self.assertRaises(TypeError):
+ self.statsWidget.addItem(roi=self.roi1D,
+ plotItem=self.img_item)
+
+ def testRectangleCurve(self):
+ """
+ Test that the couple (rectangleROI, curveItem) is raising an error
+ """
+ with self.assertRaises(TypeError):
+ self.statsWidget.addItem(roi=self.rectangle_roi,
+ plotItem=self.curve_item)
+
+ def testROIHistogram(self):
+ """
+ Test that the couple (PolygonROI, imageItem) can be used for stats
+ """
+ item = self.statsWidget.addItem(roi=self.roi1D,
+ plotItem=self.histogram_item)
+ assert item is not None
+ tableItems = self.statsTable()._itemToTableItems(item)
+ self.assertEqual(tableItems['sum'].text(), '253')
+ self.assertEqual(tableItems['mean'].text(), '11.0')
+
+ def testROIScatter(self):
+ """
+ Test that the couple (PolygonROI, imageItem) can be used for stats
+ """
+ item = self.statsWidget.addItem(roi=self.roi1D,
+ plotItem=self.scatter_item)
+ assert item is not None
+ tableItems = self.statsTable()._itemToTableItems(item)
+ self.assertEqual(tableItems['sum'].text(), '253')
+ self.assertEqual(tableItems['mean'].text(), '11.0')
+
+
+class TestRoiStatsAddRemoveItem(_TestRoiStatsBase):
+ """Test adding and removing (roi, plotItem) items"""
+ def testAddRemoveItems(self):
+ item1 = self.statsWidget.addItem(roi=self.roi1D,
+ plotItem=self.scatter_item)
+ self.assertTrue(item1 is not None)
+ self.assertEqual(self.statsTable().rowCount(), 1)
+ item2 = self.statsWidget.addItem(roi=self.roi1D,
+ plotItem=self.histogram_item)
+ self.assertTrue(item2 is not None)
+ self.assertEqual(self.statsTable().rowCount(), 2)
+ # try to add twice the same item
+ item3 = self.statsWidget.addItem(roi=self.roi1D,
+ plotItem=self.histogram_item)
+ self.assertTrue(item3 is None)
+ self.assertEqual(self.statsTable().rowCount(), 2)
+ item4 = self.statsWidget.addItem(roi=self.roi1D,
+ plotItem=self.curve_item)
+ self.assertTrue(item4 is not None)
+ self.assertEqual(self.statsTable().rowCount(), 3)
+
+ self.statsWidget.removeItem(plotItem=item4._plot_item,
+ roi=item4._roi)
+ self.assertEqual(self.statsTable().rowCount(), 2)
+ # try to remove twice the same item
+ self.statsWidget.removeItem(plotItem=item4._plot_item,
+ roi=item4._roi)
+ self.assertEqual(self.statsTable().rowCount(), 2)
+ self.statsWidget.removeItem(plotItem=item2._plot_item,
+ roi=item2._roi)
+ self.statsWidget.removeItem(plotItem=item1._plot_item,
+ roi=item1._roi)
+ self.assertEqual(self.statsTable().rowCount(), 0)
+
+
+class TestRoiStatsRoiUpdate(_TestRoiStatsBase):
+ """Test that the stats will be updated if the roi is updated"""
+ def testChangeRoi(self):
+ item = self.statsWidget.addItem(roi=self.rectangle_roi,
+ plotItem=self.img_item)
+ assert item is not None
+ tableItems = self.statsTable()._itemToTableItems(item)
+ self.assertEqual(tableItems['sum'].text(), '445410')
+ self.assertEqual(tableItems['mean'].text(), '1010.0')
+
+ # update roi
+ self.rectangle_roi.setOrigin(position=(10, 10))
+ self.assertNotEqual(tableItems['sum'].text(), '445410')
+ self.assertNotEqual(tableItems['mean'].text(), '1010.0')
+
+ def testUpdateModeScenario(self):
+ """Test update according to a simple scenario"""
+ self.statsWidget._setUpdateMode(UpdateMode.AUTO)
+ item = self.statsWidget.addItem(roi=self.rectangle_roi,
+ plotItem=self.img_item)
+
+ assert item is not None
+ tableItems = self.statsTable()._itemToTableItems(item)
+ self.assertEqual(tableItems['sum'].text(), '445410')
+ self.assertEqual(tableItems['mean'].text(), '1010.0')
+ self.statsWidget._setUpdateMode(UpdateMode.MANUAL)
+ self.rectangle_roi.setOrigin(position=(10, 10))
+ self.qapp.processEvents()
+ self.assertNotEqual(tableItems['sum'].text(), '445410')
+ self.assertNotEqual(tableItems['mean'].text(), '1010.0')
+ self.statsWidget._updateAllStats(is_request=True)
+ self.assertNotEqual(tableItems['sum'].text(), '445410')
+ self.assertNotEqual(tableItems['mean'].text(), '1010.0')
+
+
+class TestRoiStatsPlotItemUpdate(_TestRoiStatsBase):
+ """Test that the stats will be updated if the plot item is updated"""
+ def testChangeImage(self):
+ self.statsWidget._setUpdateMode(UpdateMode.AUTO)
+ item = self.statsWidget.addItem(roi=self.rectangle_roi,
+ plotItem=self.img_item)
+
+ assert item is not None
+ tableItems = self.statsTable()._itemToTableItems(item)
+ self.assertEqual(tableItems['mean'].text(), '1010.0')
+
+ # update plot
+ self.plot.addImage(numpy.arange(100, 10100).reshape(100, 100),
+ legend='img1')
+ self.assertNotEqual(tableItems['mean'].text(), '1059.5')
+
+ def testUpdateModeScenario(self):
+ """Test update according to a simple scenario"""
+ self.statsWidget._setUpdateMode(UpdateMode.MANUAL)
+ item = self.statsWidget.addItem(roi=self.rectangle_roi,
+ plotItem=self.img_item)
+
+ assert item is not None
+ tableItems = self.statsTable()._itemToTableItems(item)
+ self.assertEqual(tableItems['mean'].text(), '1010.0')
+ self.plot.addImage(numpy.arange(100, 10100).reshape(100, 100),
+ legend='img1')
+ self.assertEqual(tableItems['mean'].text(), '1010.0')
+ self.statsWidget._updateAllStats(is_request=True)
+ self.assertEqual(tableItems['mean'].text(), '1110.0')
diff --git a/src/silx/gui/plot/test/testSaveAction.py b/src/silx/gui/plot/test/testSaveAction.py
new file mode 100644
index 0000000..9280fb6
--- /dev/null
+++ b/src/silx/gui/plot/test/testSaveAction.py
@@ -0,0 +1,132 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test the plot's save action (consistency of output)"""
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "28/11/2017"
+
+
+import unittest
+import tempfile
+import os
+
+from silx.gui.plot.test.utils import PlotWidgetTestCase
+
+from silx.gui.plot import PlotWidget
+from silx.gui.plot.actions.io import SaveAction
+
+
+class TestSaveActionSaveCurvesAsSpec(unittest.TestCase):
+
+ def setUp(self):
+ self.plot = PlotWidget(backend='none')
+ self.saveAction = SaveAction(plot=self.plot)
+
+ self.tempdir = tempfile.mkdtemp()
+ self.out_fname = os.path.join(self.tempdir, "out.dat")
+
+ def tearDown(self):
+ os.unlink(self.out_fname)
+ os.rmdir(self.tempdir)
+
+ def testSaveMultipleCurvesAsSpec(self):
+ """Test that labels are properly used."""
+ self.plot.setGraphXLabel("graph x label")
+ self.plot.setGraphYLabel("graph y label")
+
+ self.plot.addCurve([0, 1], [1, 2], "curve with labels",
+ xlabel="curve0 X", ylabel="curve0 Y")
+ self.plot.addCurve([-1, 3], [-6, 2], "curve with X label",
+ xlabel="curve1 X")
+ self.plot.addCurve([-2, 0], [8, 12], "curve with Y label",
+ ylabel="curve2 Y")
+ self.plot.addCurve([3, 1], [7, 6], "curve with no labels")
+
+ self.saveAction._saveCurves(self.plot,
+ self.out_fname,
+ SaveAction.DEFAULT_ALL_CURVES_FILTERS[0]) # "All curves as SpecFile (*.dat)"
+
+ with open(self.out_fname, "rb") as f:
+ file_content = f.read()
+ if hasattr(file_content, "decode"):
+ file_content = file_content.decode()
+
+ # case with all curve labels specified
+ self.assertIn("#S 1 curve0 Y", file_content)
+ self.assertIn("#L curve0 X curve0 Y", file_content)
+
+ # graph X&Y labels are used when no curve label is specified
+ self.assertIn("#S 2 graph y label", file_content)
+ self.assertIn("#L curve1 X graph y label", file_content)
+
+ self.assertIn("#S 3 curve2 Y", file_content)
+ self.assertIn("#L graph x label curve2 Y", file_content)
+
+ self.assertIn("#S 4 graph y label", file_content)
+ self.assertIn("#L graph x label graph y label", file_content)
+
+
+class TestSaveActionExtension(PlotWidgetTestCase):
+ """Test SaveAction file filter API"""
+
+ def _dummySaveFunction(self, plot, filename, nameFilter):
+ pass
+
+ def testFileFilterAPI(self):
+ """Test addition/update of a file filter"""
+ saveAction = SaveAction(plot=self.plot, parent=self.plot)
+
+ # Add a new file filter
+ nameFilter = 'Dummy file (*.dummy)'
+ saveAction.setFileFilter('all', nameFilter, self._dummySaveFunction)
+ self.assertTrue(nameFilter in saveAction.getFileFilters('all'))
+ self.assertEqual(saveAction.getFileFilters('all')[nameFilter],
+ self._dummySaveFunction)
+
+ # Add a new file filter at a particular position
+ nameFilter = 'Dummy file2 (*.dummy)'
+ saveAction.setFileFilter('all', nameFilter,
+ self._dummySaveFunction, index=3)
+ self.assertTrue(nameFilter in saveAction.getFileFilters('all'))
+ filters = saveAction.getFileFilters('all')
+ self.assertEqual(filters[nameFilter], self._dummySaveFunction)
+ self.assertEqual(list(filters.keys()).index(nameFilter),3)
+
+ # Update an existing file filter
+ nameFilter = SaveAction.IMAGE_FILTER_EDF
+ saveAction.setFileFilter('image', nameFilter, self._dummySaveFunction)
+ self.assertEqual(saveAction.getFileFilters('image')[nameFilter],
+ self._dummySaveFunction)
+
+ # Change the position of an existing file filter
+ nameFilter = 'Dummy file2 (*.dummy)'
+ oldIndex = list(saveAction.getFileFilters('all')).index(nameFilter)
+ newIndex = oldIndex - 1
+ saveAction.setFileFilter('all', nameFilter,
+ self._dummySaveFunction, index=newIndex)
+ filters = saveAction.getFileFilters('all')
+ self.assertEqual(filters[nameFilter], self._dummySaveFunction)
+ self.assertEqual(list(filters.keys()).index(nameFilter), newIndex)
diff --git a/src/silx/gui/plot/test/testScatterMaskToolsWidget.py b/src/silx/gui/plot/test/testScatterMaskToolsWidget.py
new file mode 100644
index 0000000..447ee58
--- /dev/null
+++ b/src/silx/gui/plot/test/testScatterMaskToolsWidget.py
@@ -0,0 +1,306 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for MaskToolsWidget"""
+
+__authors__ = ["T. Vincent", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "17/01/2018"
+
+
+import logging
+import os.path
+import unittest
+
+import numpy
+
+from silx.gui import qt
+from silx.test.utils import temp_dir
+from silx.utils.testutils import ParametricTestCase
+from silx.gui.utils.testutils import getQToolButtonFromAction
+from silx.gui.plot import PlotWindow, ScatterMaskToolsWidget
+from .utils import PlotWidgetTestCase
+
+import fabio
+
+
+_logger = logging.getLogger(__name__)
+
+
+class TestScatterMaskToolsWidget(PlotWidgetTestCase, ParametricTestCase):
+ """Basic test for MaskToolsWidget"""
+
+ def _createPlot(self):
+ return PlotWindow()
+
+ def setUp(self):
+ super(TestScatterMaskToolsWidget, self).setUp()
+ self.widget = ScatterMaskToolsWidget.ScatterMaskToolsDockWidget(
+ plot=self.plot, name='TEST')
+ self.plot.addDockWidget(qt.Qt.BottomDockWidgetArea, self.widget)
+
+ self.maskWidget = self.widget.widget()
+
+ def tearDown(self):
+ del self.maskWidget
+ del self.widget
+ super(TestScatterMaskToolsWidget, self).tearDown()
+
+ def testEmptyPlot(self):
+ """Empty plot, display MaskToolsDockWidget, toggle multiple masks"""
+ self.maskWidget.setMultipleMasks('single')
+ self.qapp.processEvents()
+
+ self.maskWidget.setMultipleMasks('exclusive')
+ self.qapp.processEvents()
+
+ def _drag(self):
+ """Drag from plot center to offset position"""
+ plot = self.plot.getWidgetHandle()
+ xCenter, yCenter = plot.width() // 2, plot.height() // 2
+ offset = min(plot.width(), plot.height()) // 10
+
+ pos0 = xCenter, yCenter
+ pos1 = xCenter + offset, yCenter + offset
+
+ self.mouseMove(plot, pos=(0, 0))
+ self.mouseMove(plot, pos=pos0)
+ self.qapp.processEvents()
+ self.mousePress(plot, qt.Qt.LeftButton, pos=pos0)
+ self.qapp.processEvents()
+
+ self.mouseMove(plot, pos=(pos0[0] + offset // 2, pos0[1] + offset // 2))
+ self.mouseMove(plot, pos=pos1)
+ self.qapp.processEvents()
+ self.mouseRelease(plot, qt.Qt.LeftButton, pos=pos1)
+ self.qapp.processEvents()
+ self.mouseMove(plot, pos=(0, 0))
+
+ def _drawPolygon(self):
+ """Draw a star polygon in the plot"""
+ plot = self.plot.getWidgetHandle()
+ x, y = plot.width() // 2, plot.height() // 2
+ offset = min(plot.width(), plot.height()) // 10
+
+ star = [(x, y + offset),
+ (x - offset, y - offset),
+ (x + offset, y),
+ (x - offset, y),
+ (x + offset, y - offset),
+ (x, y + offset)] # Close polygon
+
+ self.mouseMove(plot, pos=[0, 0])
+ for pos in star:
+ self.mouseMove(plot, pos=pos)
+ self.qapp.processEvents()
+ self.mousePress(plot, qt.Qt.LeftButton, pos=pos)
+ self.qapp.processEvents()
+ self.mouseRelease(plot, qt.Qt.LeftButton, pos=pos)
+ self.qapp.processEvents()
+
+ def _drawPencil(self):
+ """Draw a star polygon in the plot"""
+ plot = self.plot.getWidgetHandle()
+ x, y = plot.width() // 2, plot.height() // 2
+ offset = min(plot.width(), plot.height()) // 10
+
+ star = [(x, y + offset),
+ (x - offset, y - offset),
+ (x + offset, y),
+ (x - offset, y),
+ (x + offset, y - offset)]
+
+ self.mouseMove(plot, pos=[0, 0])
+ self.mouseMove(plot, pos=star[0])
+ self.mousePress(plot, qt.Qt.LeftButton, pos=star[0])
+ for pos in star[1:]:
+ self.mouseMove(plot, pos=pos)
+ self.mouseRelease(
+ plot, qt.Qt.LeftButton, pos=star[-1])
+
+ def testWithAScatter(self):
+ """Plot with a Scatter: test MaskToolsWidget interactions"""
+
+ # Add and remove a scatter (this should enable/disable GUI + change mask)
+ self.plot.addScatter(
+ x=numpy.arange(256),
+ y=numpy.arange(256),
+ value=numpy.random.random(256),
+ legend='test')
+ self.plot._setActiveItem(kind="scatter", legend="test")
+ self.qapp.processEvents()
+
+ self.plot.remove('test', kind='scatter')
+ self.qapp.processEvents()
+
+ self.plot.addScatter(
+ x=numpy.arange(1000),
+ y=1000 * (numpy.arange(1000) % 20),
+ value=numpy.random.random(1000),
+ legend='test')
+ self.plot._setActiveItem(kind="scatter", legend="test")
+ self.plot.resetZoom()
+ self.qapp.processEvents()
+
+ # Test draw rectangle #
+ toolButton = getQToolButtonFromAction(self.maskWidget.rectAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+
+ # mask
+ self.maskWidget.maskStateGroup.button(1).click()
+ self.qapp.processEvents()
+ self._drag()
+
+ self.assertFalse(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+
+ # unmask same region
+ self.maskWidget.maskStateGroup.button(0).click()
+ self.qapp.processEvents()
+ self._drag()
+ self.assertTrue(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+
+ # Test draw polygon #
+ toolButton = getQToolButtonFromAction(self.maskWidget.polygonAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+
+ # mask
+ self.maskWidget.maskStateGroup.button(1).click()
+ self.qapp.processEvents()
+ self._drawPolygon()
+ self.assertFalse(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+
+ # unmask same region
+ self.maskWidget.maskStateGroup.button(0).click()
+ self.qapp.processEvents()
+ self._drawPolygon()
+ self.assertTrue(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+
+ # Test draw pencil #
+ toolButton = getQToolButtonFromAction(self.maskWidget.pencilAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+
+ self.maskWidget.pencilSpinBox.setValue(30)
+ self.qapp.processEvents()
+
+ # mask
+ self.maskWidget.maskStateGroup.button(1).click()
+ self.qapp.processEvents()
+ self._drawPencil()
+ self.assertFalse(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+
+ # unmask same region
+ self.maskWidget.maskStateGroup.button(0).click()
+ self.qapp.processEvents()
+ self._drawPencil()
+ self.assertTrue(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+
+ # Test no draw tool #
+ toolButton = getQToolButtonFromAction(self.maskWidget.browseAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+
+ self.plot.clear()
+
+ def __loadSave(self, file_format):
+ self.plot.addScatter(
+ x=numpy.arange(256),
+ y=25 * (numpy.arange(256) % 10),
+ value=numpy.random.random(256),
+ legend='test')
+ self.plot._setActiveItem(kind="scatter", legend="test")
+ self.plot.resetZoom()
+ self.qapp.processEvents()
+
+ # Draw a polygon mask
+ toolButton = getQToolButtonFromAction(self.maskWidget.polygonAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+ self._drawPolygon()
+
+ ref_mask = self.maskWidget.getSelectionMask()
+ self.assertFalse(numpy.all(numpy.equal(ref_mask, 0)))
+
+ with temp_dir() as tmp:
+ mask_filename = os.path.join(tmp, 'mask.' + file_format)
+ self.maskWidget.save(mask_filename, file_format)
+
+ self.maskWidget.resetSelectionMask()
+ self.assertTrue(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+
+ self.maskWidget.load(mask_filename)
+ self.assertTrue(numpy.all(numpy.equal(
+ self.maskWidget.getSelectionMask(), ref_mask)))
+
+ def testLoadSaveNpy(self):
+ self.__loadSave("npy")
+
+ def testLoadSaveCsv(self):
+ self.__loadSave("csv")
+
+ def testSigMaskChangedEmitted(self):
+ self.qapp.processEvents()
+ self.plot.addScatter(
+ x=numpy.arange(1000),
+ y=1000 * (numpy.arange(1000) % 20),
+ value=numpy.ones((1000,)),
+ legend='test')
+ self.plot._setActiveItem(kind="scatter", legend="test")
+ self.plot.resetZoom()
+ self.qapp.processEvents()
+
+ self.plot.remove('test', kind='scatter')
+ self.qapp.processEvents()
+
+ self.plot.addScatter(
+ x=numpy.arange(1000),
+ y=1000 * (numpy.arange(1000) % 20),
+ value=numpy.random.random(1000),
+ legend='test')
+
+ l = []
+
+ def slot():
+ l.append(1)
+
+ self.maskWidget.sigMaskChanged.connect(slot)
+
+ # rectangle mask
+ toolButton = getQToolButtonFromAction(self.maskWidget.rectAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+ self.maskWidget.maskStateGroup.button(1).click()
+ self.qapp.processEvents()
+ self._drag()
+
+ self.assertGreater(len(l), 0)
diff --git a/src/silx/gui/plot/test/testScatterView.py b/src/silx/gui/plot/test/testScatterView.py
new file mode 100644
index 0000000..d11d4d8
--- /dev/null
+++ b/src/silx/gui/plot/test/testScatterView.py
@@ -0,0 +1,123 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for ScatterView"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "06/03/2018"
+
+
+import unittest
+
+import numpy
+
+from silx.gui.plot.items import Axis, Scatter
+from silx.gui.plot import ScatterView
+from silx.gui.plot.test.utils import PlotWidgetTestCase
+
+
+class TestScatterView(PlotWidgetTestCase):
+ """Test of ScatterView widget"""
+
+ def _createPlot(self):
+ return ScatterView()
+
+ def test(self):
+ """Simple tests"""
+ x = numpy.arange(100)
+ y = numpy.arange(100)
+ value = numpy.arange(100)
+ self.plot.setData(x, y, value)
+ self.qapp.processEvents()
+
+ data = self.plot.getData()
+ self.assertEqual(len(data), 5)
+ self.assertTrue(numpy.all(numpy.equal(x, data[0])))
+ self.assertTrue(numpy.all(numpy.equal(y, data[1])))
+ self.assertTrue(numpy.all(numpy.equal(value, data[2])))
+ self.assertIsNone(data[3]) # xerror
+ self.assertIsNone(data[4]) # yerror
+
+ # Test access to scatter item
+ self.assertIsInstance(self.plot.getScatterItem(), Scatter)
+
+ # Test toolbar actions
+
+ action = self.plot.getScatterToolBar().getXAxisLogarithmicAction()
+ action.trigger()
+ self.qapp.processEvents()
+
+ maskAction = self.plot.getScatterToolBar().actions()[-1]
+ maskAction.trigger()
+ self.qapp.processEvents()
+
+ # Test proxy API
+
+ self.plot.resetZoom()
+ self.qapp.processEvents()
+
+ scale = self.plot.getXAxis().getScale()
+ self.assertEqual(scale, Axis.LOGARITHMIC)
+
+ scale = self.plot.getYAxis().getScale()
+ self.assertEqual(scale, Axis.LINEAR)
+
+ title = 'Test ScatterView'
+ self.plot.setGraphTitle(title)
+ self.assertEqual(self.plot.getGraphTitle(), title)
+
+ self.qapp.processEvents()
+
+ # Reset scatter data
+
+ self.plot.setData(None, None, None)
+ self.qapp.processEvents()
+
+ data = self.plot.getData()
+ self.assertEqual(len(data), 5)
+ self.assertEqual(len(data[0]), 0) # x
+ self.assertEqual(len(data[1]), 0) # y
+ self.assertEqual(len(data[2]), 0) # value
+ self.assertIsNone(data[3]) # xerror
+ self.assertIsNone(data[4]) # yerror
+
+ def testAlpha(self):
+ """Test alpha transparency in setData"""
+ _pts = 100
+ _levels = 100
+ _fwhm = 50
+ x = numpy.random.rand(_pts)*_levels
+ y = numpy.random.rand(_pts)*_levels
+ value = numpy.random.rand(_pts)*_levels
+ x0 = x[int(_pts/2)]
+ y0 = x[int(_pts/2)]
+ #2D Gaussian kernel
+ alpha = numpy.exp(-4*numpy.log(2) * ((x-x0)**2 + (y-y0)**2) / _fwhm**2)
+
+ self.plot.setData(x, y, value, alpha=alpha)
+ self.qapp.processEvents()
+
+ alphaData = self.plot.getScatterItem().getAlphaData()
+ self.assertTrue(numpy.all(numpy.equal(alpha, alphaData)))
diff --git a/src/silx/gui/plot/test/testStackView.py b/src/silx/gui/plot/test/testStackView.py
new file mode 100644
index 0000000..0d18113
--- /dev/null
+++ b/src/silx/gui/plot/test/testStackView.py
@@ -0,0 +1,248 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for StackView"""
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "20/03/2017"
+
+
+import unittest
+import numpy
+
+from silx.gui.utils.testutils import TestCaseQt, SignalListener
+
+from silx.gui import qt
+from silx.gui.plot import StackView
+from silx.gui.plot.StackView import StackViewMainWindow
+
+from silx.utils.array_like import ListOfImages
+
+
+class TestStackView(TestCaseQt):
+ """Base class for tests of StackView."""
+
+ def setUp(self):
+ super(TestStackView, self).setUp()
+ self.stackview = StackView()
+ self.stackview.show()
+ self.qWaitForWindowExposed(self.stackview)
+ self.mystack = numpy.fromfunction(
+ lambda i, j, k: numpy.sin(i/15.) + numpy.cos(j/4.) + 2 * numpy.sin(k/6.),
+ (10, 20, 30)
+ )
+
+ def tearDown(self):
+ self.stackview.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.stackview.close()
+ del self.stackview
+ super(TestStackView, self).tearDown()
+
+ def testScaleColormapRangeToStack(self):
+ """Test scaleColormapRangeToStack"""
+ self.stackview.setStack(self.mystack)
+ self.stackview.setColormap("viridis")
+ colormap = self.stackview.getColormap()
+
+ # Colormap autoscale to image
+ self.assertEqual(colormap.getVRange(), (None, None))
+ self.stackview.scaleColormapRangeToStack()
+
+ # Colormap range set according to stack range
+ self.assertEqual(colormap.getVRange(), (self.mystack.min(), self.mystack.max()))
+
+ def testSetStack(self):
+ self.stackview.setStack(self.mystack)
+ self.stackview.setColormap("viridis", autoscale=True)
+ my_trans_stack, params = self.stackview.getStack()
+ self.assertEqual(my_trans_stack.shape, self.mystack.shape)
+ self.assertTrue(numpy.array_equal(self.mystack,
+ my_trans_stack))
+ self.assertEqual(params["colormap"]["name"],
+ "viridis")
+
+ def testSetStackPerspective(self):
+ self.stackview.setStack(self.mystack, perspective=1)
+ # my_orig_stack, params = self.stackview.getStack()
+ my_trans_stack, params = self.stackview.getCurrentView()
+
+ # get stack returns the transposed data, depending on the perspective
+ self.assertEqual(my_trans_stack.shape,
+ (self.mystack.shape[1], self.mystack.shape[0], self.mystack.shape[2]))
+ self.assertTrue(numpy.array_equal(numpy.transpose(self.mystack, axes=(1, 0, 2)),
+ my_trans_stack))
+
+ def testSetStackListOfImages(self):
+ loi = [self.mystack[i] for i in range(self.mystack.shape[0])]
+
+ self.stackview.setStack(loi)
+ my_orig_stack, params = self.stackview.getStack(returnNumpyArray=True)
+ my_trans_stack, params = self.stackview.getStack(returnNumpyArray=True)
+ self.assertEqual(my_trans_stack.shape, self.mystack.shape)
+ self.assertTrue(numpy.array_equal(self.mystack,
+ my_trans_stack))
+ self.assertTrue(numpy.array_equal(self.mystack,
+ my_orig_stack))
+ self.assertIsInstance(my_trans_stack, numpy.ndarray)
+
+ self.stackview.setStack(loi, perspective=2)
+ my_orig_stack, params = self.stackview.getStack(copy=False)
+ my_trans_stack, params = self.stackview.getCurrentView(copy=False)
+ # getStack(copy=False) must return the object set in setStack
+ self.assertIs(my_orig_stack, loi)
+ # getCurrentView(copy=False) returns a ListOfImages whose .images
+ # attr is the original data
+ self.assertEqual(my_trans_stack.shape,
+ (self.mystack.shape[2], self.mystack.shape[0], self.mystack.shape[1]))
+ self.assertTrue(numpy.array_equal(numpy.array(my_trans_stack),
+ numpy.transpose(self.mystack, axes=(2, 0, 1))))
+ self.assertIsInstance(my_trans_stack,
+ ListOfImages) # returnNumpyArray=False by default in getStack
+ self.assertIs(my_trans_stack.images, loi)
+
+ def testPerspective(self):
+ self.stackview.setStack(numpy.arange(24).reshape((2, 3, 4)))
+ self.assertEqual(self.stackview._perspective, 0,
+ "Default perspective is not 0 (dim1-dim2).")
+
+ self.stackview._StackView__planeSelection.setPerspective(1)
+ self.assertEqual(self.stackview._perspective, 1,
+ "Plane selection combobox not updating perspective")
+
+ self.stackview.setStack(numpy.arange(6).reshape((1, 2, 3)))
+ self.assertEqual(self.stackview._perspective, 1,
+ "Perspective not preserved when calling setStack "
+ "without specifying the perspective parameter.")
+
+ self.stackview.setStack(numpy.arange(24).reshape((2, 3, 4)), perspective=2)
+ self.assertEqual(self.stackview._perspective, 2,
+ "Perspective not set in setStack(..., perspective=2).")
+
+ def testDefaultTitle(self):
+ """Test that the plot title contains the proper Z information"""
+ self.stackview.setStack(numpy.arange(24).reshape((4, 3, 2)),
+ calibrations=[(0, 1), (-10, 10), (3.14, 3.14)])
+ self.assertEqual(self.stackview._plot.getGraphTitle(),
+ "Image z=0")
+ self.stackview.setFrameNumber(2)
+ self.assertEqual(self.stackview._plot.getGraphTitle(),
+ "Image z=2")
+
+ self.stackview._StackView__planeSelection.setPerspective(1)
+ self.stackview.setFrameNumber(0)
+ self.assertEqual(self.stackview._plot.getGraphTitle(),
+ "Image z=-10")
+ self.stackview.setFrameNumber(2)
+ self.assertEqual(self.stackview._plot.getGraphTitle(),
+ "Image z=10")
+
+ self.stackview._StackView__planeSelection.setPerspective(2)
+ self.stackview.setFrameNumber(0)
+ self.assertEqual(self.stackview._plot.getGraphTitle(),
+ "Image z=3.14")
+ self.stackview.setFrameNumber(1)
+ self.assertEqual(self.stackview._plot.getGraphTitle(),
+ "Image z=6.28")
+
+ def testCustomTitle(self):
+ """Test setting the plot title with a user defined callback"""
+ self.stackview.setStack(numpy.arange(24).reshape((4, 3, 2)),
+ calibrations=[(0, 1), (-10, 10), (3.14, 3.14)])
+
+ def title_callback(frame_idx):
+ return "Cubed index title %d" % (frame_idx**3)
+
+ self.stackview.setTitleCallback(title_callback)
+ self.assertEqual(self.stackview._plot.getGraphTitle(),
+ "Cubed index title 0")
+ self.stackview.setFrameNumber(2)
+ self.assertEqual(self.stackview._plot.getGraphTitle(),
+ "Cubed index title 8")
+
+ # perspective should not matter, only frame index
+ self.stackview._StackView__planeSelection.setPerspective(1)
+ self.stackview.setFrameNumber(0)
+ self.assertEqual(self.stackview._plot.getGraphTitle(),
+ "Cubed index title 0")
+ self.stackview.setFrameNumber(2)
+ self.assertEqual(self.stackview._plot.getGraphTitle(),
+ "Cubed index title 8")
+
+ with self.assertRaises(TypeError):
+ # setTitleCallback should not accept non-callable objects like strings
+ self.stackview.setTitleCallback(
+ "Là, vous faites sirop de vingt-et-un et vous dites : "
+ "beau sirop, mi-sirop, siroté, gagne-sirop, sirop-grelot,"
+ " passe-montagne, sirop au bon goût.")
+
+ def testStackFrameNumber(self):
+ self.stackview.setStack(self.mystack)
+ self.assertEqual(self.stackview.getFrameNumber(), 0)
+
+ listener = SignalListener()
+ self.stackview.sigFrameChanged.connect(listener)
+
+ self.stackview.setFrameNumber(1)
+ self.assertEqual(self.stackview.getFrameNumber(), 1)
+ self.assertEqual(listener.arguments(), [(1,)])
+
+
+class TestStackViewMainWindow(TestCaseQt):
+ """Base class for tests of StackView."""
+
+ def setUp(self):
+ super(TestStackViewMainWindow, self).setUp()
+ self.stackview = StackViewMainWindow()
+ self.stackview.show()
+ self.qWaitForWindowExposed(self.stackview)
+ self.mystack = numpy.fromfunction(
+ lambda i, j, k: numpy.sin(i/15.) + numpy.cos(j/4.) + 2 * numpy.sin(k/6.),
+ (10, 20, 30)
+ )
+
+ def tearDown(self):
+ self.stackview.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.stackview.close()
+ del self.stackview
+ super(TestStackViewMainWindow, self).tearDown()
+
+ def testSetStack(self):
+ self.stackview.setStack(self.mystack)
+ self.stackview.setColormap("viridis", autoscale=True)
+ my_trans_stack, params = self.stackview.getStack()
+ self.assertEqual(my_trans_stack.shape, self.mystack.shape)
+ self.assertTrue(numpy.array_equal(self.mystack,
+ my_trans_stack))
+ self.assertEqual(params["colormap"]["name"],
+ "viridis")
+
+ def testSetStackPerspective(self):
+ self.stackview.setStack(self.mystack, perspective=1)
+ my_trans_stack, params = self.stackview.getCurrentView()
+ # get stack returns the transposed data, depending on the perspective
+ self.assertEqual(my_trans_stack.shape,
+ (self.mystack.shape[1], self.mystack.shape[0], self.mystack.shape[2]))
+ self.assertTrue(numpy.array_equal(numpy.transpose(self.mystack, axes=(1, 0, 2)),
+ my_trans_stack))
diff --git a/src/silx/gui/plot/test/testStats.py b/src/silx/gui/plot/test/testStats.py
new file mode 100644
index 0000000..0a792a4
--- /dev/null
+++ b/src/silx/gui/plot/test/testStats.py
@@ -0,0 +1,1047 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for CurvesROIWidget"""
+
+__authors__ = ["H. Payno"]
+__license__ = "MIT"
+__date__ = "07/03/2018"
+
+
+from silx.gui import qt
+from silx.gui.plot.stats import stats
+from silx.gui.plot import StatsWidget
+from silx.gui.plot.stats import statshandler
+from silx.gui.utils.testutils import TestCaseQt, SignalListener
+from silx.gui.plot import Plot1D, Plot2D
+from silx.gui.plot3d.SceneWidget import SceneWidget
+from silx.gui.plot.items.roi import RectangleROI, PolygonROI
+from silx.gui.plot.tools.roi import RegionOfInterestManager
+from silx.gui.plot.stats.stats import Stats
+from silx.gui.plot.CurvesROIWidget import ROI
+from silx.utils.testutils import ParametricTestCase
+import unittest
+import logging
+import numpy
+
+_logger = logging.getLogger(__name__)
+
+
+class TestStatsBase(object):
+ """Base class for stats TestCase"""
+ def setUp(self):
+ self.createCurveContext()
+ self.createImageContext()
+ self.createScatterContext()
+
+ def tearDown(self):
+ self.plot1d.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot1d.close()
+ del self.plot1d
+ self.plot2d.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot2d.close()
+ del self.plot2d
+ self.scatterPlot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.scatterPlot.close()
+ del self.scatterPlot
+
+ def createCurveContext(self):
+ self.plot1d = Plot1D()
+ x = range(20)
+ y = range(20)
+ self.plot1d.addCurve(x, y, legend='curve0')
+
+ self.curveContext = stats._CurveContext(
+ item=self.plot1d.getCurve('curve0'),
+ plot=self.plot1d,
+ onlimits=False,
+ roi=None)
+
+ def createScatterContext(self):
+ self.scatterPlot = Plot2D()
+ lgd = 'scatter plot'
+ self.xScatterData = numpy.array([0, 2, 3, 20, 50, 60, 36])
+ self.yScatterData = numpy.array([2, 3, 4, 26, 69, 6, 18])
+ self.valuesScatterData = numpy.array([5, 6, 7, 10, 90, 20, 5])
+ self.scatterPlot.addScatter(self.xScatterData, self.yScatterData,
+ self.valuesScatterData, legend=lgd)
+ self.scatterContext = stats._ScatterContext(
+ item=self.scatterPlot.getScatter(lgd),
+ plot=self.scatterPlot,
+ onlimits=False,
+ roi=None
+ )
+
+ def createImageContext(self):
+ self.plot2d = Plot2D()
+ self._imgLgd = 'test image'
+ self.imageData = numpy.arange(32*128).reshape(32, 128)
+ self.plot2d.addImage(data=self.imageData,
+ legend=self._imgLgd, replace=False)
+ self.imageContext = stats._ImageContext(
+ item=self.plot2d.getImage(self._imgLgd),
+ plot=self.plot2d,
+ onlimits=False,
+ roi=None
+ )
+
+ def getBasicStats(self):
+ return {
+ 'min': stats.StatMin(),
+ 'minCoords': stats.StatCoordMin(),
+ 'max': stats.StatMax(),
+ 'maxCoords': stats.StatCoordMax(),
+ 'std': stats.Stat(name='std', fct=numpy.std),
+ 'mean': stats.Stat(name='mean', fct=numpy.mean),
+ 'com': stats.StatCOM()
+ }
+
+
+class TestStats(TestStatsBase, TestCaseQt):
+ """
+ Test :class:`BaseClass` class and inheriting classes
+ """
+ def setUp(self):
+ TestCaseQt.setUp(self)
+ TestStatsBase.setUp(self)
+
+ def tearDown(self):
+ TestStatsBase.tearDown(self)
+ TestCaseQt.tearDown(self)
+
+ def testBasicStatsCurve(self):
+ """Test result for simple stats on a curve"""
+ _stats = self.getBasicStats()
+ xData = yData = numpy.array(range(20))
+ self.assertEqual(_stats['min'].calculate(self.curveContext), 0)
+ self.assertEqual(_stats['max'].calculate(self.curveContext), 19)
+ self.assertEqual(_stats['minCoords'].calculate(self.curveContext), (0,))
+ self.assertEqual(_stats['maxCoords'].calculate(self.curveContext), (19,))
+ self.assertEqual(_stats['std'].calculate(self.curveContext), numpy.std(yData))
+ self.assertEqual(_stats['mean'].calculate(self.curveContext), numpy.mean(yData))
+ com = numpy.sum(xData * yData) / numpy.sum(yData)
+ self.assertEqual(_stats['com'].calculate(self.curveContext), com)
+
+ def testBasicStatsImage(self):
+ """Test result for simple stats on an image"""
+ _stats = self.getBasicStats()
+ self.assertEqual(_stats['min'].calculate(self.imageContext), 0)
+ self.assertEqual(_stats['max'].calculate(self.imageContext), 128 * 32 - 1)
+ self.assertEqual(_stats['minCoords'].calculate(self.imageContext), (0, 0))
+ self.assertEqual(_stats['maxCoords'].calculate(self.imageContext), (127, 31))
+ self.assertEqual(_stats['std'].calculate(self.imageContext), numpy.std(self.imageData))
+ self.assertEqual(_stats['mean'].calculate(self.imageContext), numpy.mean(self.imageData))
+
+ yData = numpy.sum(self.imageData.astype(numpy.float64), axis=1)
+ xData = numpy.sum(self.imageData.astype(numpy.float64), axis=0)
+ dataXRange = range(self.imageData.shape[1])
+ dataYRange = range(self.imageData.shape[0])
+
+ ycom = numpy.sum(yData*dataYRange) / numpy.sum(yData)
+ xcom = numpy.sum(xData*dataXRange) / numpy.sum(xData)
+
+ self.assertEqual(_stats['com'].calculate(self.imageContext), (xcom, ycom))
+
+ def testStatsImageAdv(self):
+ """Test that scale and origin are taking into account for images"""
+
+ image2Data = numpy.arange(32 * 128).reshape(32, 128)
+ self.plot2d.addImage(data=image2Data, legend=self._imgLgd,
+ replace=True, origin=(100, 10), scale=(2, 0.5))
+ image2Context = stats._ImageContext(
+ item=self.plot2d.getImage(self._imgLgd),
+ plot=self.plot2d,
+ onlimits=False,
+ roi=None,
+ )
+ _stats = self.getBasicStats()
+ self.assertEqual(_stats['min'].calculate(image2Context), 0)
+ self.assertEqual(
+ _stats['max'].calculate(image2Context), 128 * 32 - 1)
+ self.assertEqual(
+ _stats['minCoords'].calculate(image2Context), (100, 10))
+ self.assertEqual(
+ _stats['maxCoords'].calculate(image2Context), (127*2. + 100,
+ 31 * 0.5 + 10))
+ self.assertEqual(_stats['std'].calculate(image2Context),
+ numpy.std(self.imageData))
+ self.assertEqual(_stats['mean'].calculate(image2Context),
+ numpy.mean(self.imageData))
+
+ yData = numpy.sum(self.imageData, axis=1)
+ xData = numpy.sum(self.imageData, axis=0)
+ dataXRange = numpy.arange(self.imageData.shape[1], dtype=numpy.float64)
+ dataYRange = numpy.arange(self.imageData.shape[0], dtype=numpy.float64)
+
+ ycom = numpy.sum(yData * dataYRange) / numpy.sum(yData)
+ ycom = (ycom * 0.5) + 10
+ xcom = numpy.sum(xData * dataXRange) / numpy.sum(xData)
+ xcom = (xcom * 2.) + 100
+ self.assertTrue(numpy.allclose(
+ _stats['com'].calculate(image2Context), (xcom, ycom)))
+
+ def testBasicStatsScatter(self):
+ """Test result for simple stats on a scatter"""
+ _stats = self.getBasicStats()
+ self.assertEqual(_stats['min'].calculate(self.scatterContext), 5)
+ self.assertEqual(_stats['max'].calculate(self.scatterContext), 90)
+ self.assertEqual(_stats['minCoords'].calculate(self.scatterContext), (0, 2))
+ self.assertEqual(_stats['maxCoords'].calculate(self.scatterContext), (50, 69))
+ self.assertEqual(_stats['std'].calculate(self.scatterContext), numpy.std(self.valuesScatterData))
+ self.assertEqual(_stats['mean'].calculate(self.scatterContext), numpy.mean(self.valuesScatterData))
+
+ data = self.valuesScatterData.astype(numpy.float64)
+ comx = numpy.sum(self.xScatterData * data) / numpy.sum(data)
+ comy = numpy.sum(self.yScatterData * data) / numpy.sum(data)
+ self.assertEqual(_stats['com'].calculate(self.scatterContext),
+ (comx, comy))
+
+ def testKindNotManagedByStat(self):
+ """Make sure an exception is raised if we try to execute calculate
+ of the base class"""
+ b = stats.StatBase(name='toto', compatibleKinds='curve')
+ with self.assertRaises(NotImplementedError):
+ b.calculate(self.imageContext)
+
+ def testKindNotManagedByContext(self):
+ """
+ Make sure an error is raised if we try to calculate a statistic with
+ a context not managed
+ """
+ myStat = stats.Stat(name='toto', fct=numpy.std, kinds=('curve'))
+ myStat.calculate(self.curveContext)
+ with self.assertRaises(ValueError):
+ myStat.calculate(self.scatterContext)
+ with self.assertRaises(ValueError):
+ myStat.calculate(self.imageContext)
+
+ def testOnLimits(self):
+ stat = stats.StatMin()
+
+ self.plot1d.getXAxis().setLimitsConstraints(minPos=2, maxPos=5)
+ curveContextOnLimits = stats._CurveContext(
+ item=self.plot1d.getCurve('curve0'),
+ plot=self.plot1d,
+ onlimits=True,
+ roi=None)
+ self.assertEqual(stat.calculate(curveContextOnLimits), 2)
+
+ self.plot2d.getXAxis().setLimitsConstraints(minPos=32)
+ imageContextOnLimits = stats._ImageContext(
+ item=self.plot2d.getImage('test image'),
+ plot=self.plot2d,
+ onlimits=True,
+ roi=None)
+ self.assertEqual(stat.calculate(imageContextOnLimits), 32)
+
+ self.scatterPlot.getXAxis().setLimitsConstraints(minPos=40)
+ scatterContextOnLimits = stats._ScatterContext(
+ item=self.scatterPlot.getScatter('scatter plot'),
+ plot=self.scatterPlot,
+ onlimits=True,
+ roi=None)
+ self.assertEqual(stat.calculate(scatterContextOnLimits), 20)
+
+
+class TestStatsFormatter(TestCaseQt):
+ """Simple test to check usage of the :class:`StatsFormatter`"""
+ def setUp(self):
+ TestCaseQt.setUp(self)
+ self.plot1d = Plot1D()
+ x = range(20)
+ y = range(20)
+ self.plot1d.addCurve(x, y, legend='curve0')
+
+ self.curveContext = stats._CurveContext(
+ item=self.plot1d.getCurve('curve0'),
+ plot=self.plot1d,
+ onlimits=False,
+ roi=None)
+
+ self.stat = stats.StatMin()
+
+ def tearDown(self):
+ self.plot1d.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot1d.close()
+ del self.plot1d
+ TestCaseQt.tearDown(self)
+
+ def testEmptyFormatter(self):
+ """Make sure a formatter with no formatter definition will return a
+ simple cast to str"""
+ emptyFormatter = statshandler.StatFormatter()
+ self.assertEqual(
+ emptyFormatter.format(self.stat.calculate(self.curveContext)), '0.000')
+
+ def testSettedFormatter(self):
+ """Make sure a formatter with no formatter definition will return a
+ simple cast to str"""
+ formatter= statshandler.StatFormatter(formatter='{0:.3f}')
+ self.assertEqual(
+ formatter.format(self.stat.calculate(self.curveContext)), '0.000')
+
+
+class TestStatsHandler(TestCaseQt):
+ """Make sure the StatHandler is correctly making the link between
+ :class:`StatBase` and :class:`StatFormatter` and checking the API is valid
+ """
+ def setUp(self):
+ TestCaseQt.setUp(self)
+ self.plot1d = Plot1D()
+ x = range(20)
+ y = range(20)
+ self.plot1d.addCurve(x, y, legend='curve0')
+ self.curveItem = self.plot1d.getCurve('curve0')
+
+ self.stat = stats.StatMin()
+
+ def tearDown(self):
+ Stats._getContext.cache_clear()
+ self.plot1d.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot1d.close()
+ self.plot1d = None
+ TestCaseQt.tearDown(self)
+
+ def testConstructor(self):
+ """Make sure the constructor can deal will all possible arguments:
+
+ * tuple of :class:`StatBase` derivated classes
+ * tuple of tuples (:class:`StatBase`, :class:`StatFormatter`)
+ * tuple of tuples (str, pointer to function, kind)
+ """
+ handler0 = statshandler.StatsHandler(
+ (stats.StatMin(), stats.StatMax())
+ )
+
+ res = handler0.calculate(item=self.curveItem, plot=self.plot1d,
+ onlimits=False)
+ self.assertTrue('min' in res)
+ self.assertEqual(res['min'], '0')
+ self.assertTrue('max' in res)
+ self.assertEqual(res['max'], '19')
+
+ handler1 = statshandler.StatsHandler(
+ (
+ (stats.StatMin(), statshandler.StatFormatter(formatter=None)),
+ (stats.StatMax(), statshandler.StatFormatter())
+ )
+ )
+
+ res = handler1.calculate(item=self.curveItem, plot=self.plot1d,
+ onlimits=False)
+ self.assertTrue('min' in res)
+ self.assertEqual(res['min'], '0')
+ self.assertTrue('max' in res)
+ self.assertEqual(res['max'], '19.000')
+
+ handler2 = statshandler.StatsHandler(
+ (
+ (stats.StatMin(), None),
+ (stats.StatMax(), statshandler.StatFormatter())
+ ))
+
+ res = handler2.calculate(item=self.curveItem, plot=self.plot1d,
+ onlimits=False)
+ self.assertTrue('min' in res)
+ self.assertEqual(res['min'], '0')
+ self.assertTrue('max' in res)
+ self.assertEqual(res['max'], '19.000')
+
+ handler3 = statshandler.StatsHandler((
+ (('amin', numpy.argmin), statshandler.StatFormatter()),
+ ('amax', numpy.argmax)
+ ))
+
+ res = handler3.calculate(item=self.curveItem, plot=self.plot1d,
+ onlimits=False)
+ self.assertTrue('amin' in res)
+ self.assertEqual(res['amin'], '0.000')
+ self.assertTrue('amax' in res)
+ self.assertEqual(res['amax'], '19')
+
+ with self.assertRaises(ValueError):
+ statshandler.StatsHandler(('name'))
+
+
+class TestStatsWidgetWithCurves(TestCaseQt, ParametricTestCase):
+ """Basic test for StatsWidget with curves"""
+ def setUp(self):
+ TestCaseQt.setUp(self)
+ self.plot = Plot1D()
+ self.plot.show()
+ x = range(20)
+ y = range(20)
+ self.plot.addCurve(x, y, legend='curve0')
+ y = range(12, 32)
+ self.plot.addCurve(x, y, legend='curve1')
+ y = range(-2, 18)
+ self.plot.addCurve(x, y, legend='curve2')
+ self.widget = StatsWidget.StatsWidget(plot=self.plot)
+ self.statsTable = self.widget._statsTable
+
+ mystats = statshandler.StatsHandler((
+ stats.StatMin(),
+ (stats.StatCoordMin(), statshandler.StatFormatter(None, qt.QTableWidgetItem)),
+ stats.StatMax(),
+ (stats.StatCoordMax(), statshandler.StatFormatter(None, qt.QTableWidgetItem)),
+ stats.StatDelta(),
+ ('std', numpy.std),
+ ('mean', numpy.mean),
+ stats.StatCOM()
+ ))
+
+ self.statsTable.setStats(mystats)
+
+ def tearDown(self):
+ Stats._getContext.cache_clear()
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ self.statsTable = None
+ self.widget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.widget.close()
+ self.widget = None
+ self.plot = None
+ TestCaseQt.tearDown(self)
+
+ def testDisplayActiveItemsSyncOptions(self):
+ """
+ Test that the several option of the sync options are well
+ synchronized between the different object"""
+ widget = StatsWidget.StatsWidget(plot=self.plot)
+ table = StatsWidget.StatsTable(plot=self.plot)
+
+ def check_display_only_active_item(only_active):
+ # check internal value
+ self.assertIs(widget._statsTable._displayOnlyActItem, only_active)
+ # self.assertTrue(table._displayOnlyActItem is only_active)
+ # check gui display
+ self.assertEqual(widget._options.isActiveItemMode(), only_active)
+
+ for displayOnlyActiveItems in (True, False):
+ with self.subTest(displayOnlyActiveItems=displayOnlyActiveItems):
+ widget.setDisplayOnlyActiveItem(displayOnlyActiveItems)
+ # table.setDisplayOnlyActiveItem(displayOnlyActiveItems)
+ check_display_only_active_item(displayOnlyActiveItems)
+
+ check_display_only_active_item(only_active=False)
+ widget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ table.setAttribute(qt.Qt.WA_DeleteOnClose)
+ widget.close()
+ table.close()
+
+ def testInit(self):
+ """Make sure all the curves are registred on initialization"""
+ self.assertEqual(self.statsTable.rowCount(), 3)
+
+ def testRemoveCurve(self):
+ """Make sure the Curves stats take into account the curve removal from
+ plot"""
+ self.plot.removeCurve('curve2')
+ self.assertEqual(self.statsTable.rowCount(), 2)
+ for iRow in range(2):
+ self.assertTrue(self.statsTable.item(iRow, 0).text() in ('curve0', 'curve1'))
+
+ self.plot.removeCurve('curve0')
+ self.assertEqual(self.statsTable.rowCount(), 1)
+ self.plot.removeCurve('curve1')
+ self.assertEqual(self.statsTable.rowCount(), 0)
+
+ def testAddCurve(self):
+ """Make sure the Curves stats take into account the add curve action"""
+ self.plot.addCurve(legend='curve3', x=range(10), y=range(10))
+ self.assertEqual(self.statsTable.rowCount(), 4)
+
+ def testUpdateCurveFromAddCurve(self):
+ """Make sure the stats of the cuve will be removed after updating a
+ curve"""
+ self.plot.addCurve(legend='curve0', x=range(10), y=range(10))
+ self.qapp.processEvents()
+ self.assertEqual(self.statsTable.rowCount(), 3)
+ curve = self.plot._getItem(kind='curve', legend='curve0')
+ tableItems = self.statsTable._itemToTableItems(curve)
+ self.assertEqual(tableItems['max'].text(), '9')
+
+ def testUpdateCurveFromCurveObj(self):
+ self.plot.getCurve('curve0').setData(x=range(4), y=range(4))
+ self.qapp.processEvents()
+ self.assertEqual(self.statsTable.rowCount(), 3)
+ curve = self.plot._getItem(kind='curve', legend='curve0')
+ tableItems = self.statsTable._itemToTableItems(curve)
+ self.assertEqual(tableItems['max'].text(), '3')
+
+ def testSetAnotherPlot(self):
+ plot2 = Plot1D()
+ plot2.addCurve(x=range(26), y=range(26), legend='new curve')
+ self.statsTable.setPlot(plot2)
+ self.assertEqual(self.statsTable.rowCount(), 1)
+ self.qapp.processEvents()
+ plot2.setAttribute(qt.Qt.WA_DeleteOnClose)
+ plot2.close()
+ plot2 = None
+
+ def testUpdateMode(self):
+ """Make sure the update modes are well take into account"""
+ self.plot.setActiveCurve('curve0')
+ for display_only_active in (True, False):
+ with self.subTest(display_only_active=display_only_active):
+ self.widget.setDisplayOnlyActiveItem(display_only_active)
+ self.plot.getCurve('curve0').setData(x=range(4), y=range(4))
+ self.widget.setUpdateMode(StatsWidget.UpdateMode.AUTO)
+ update_stats_action = self.widget._options.getUpdateStatsAction()
+ # test from api
+ self.assertEqual(self.widget.getUpdateMode(), StatsWidget.UpdateMode.AUTO)
+ self.widget.show()
+ # check stats change in auto mode
+ self.plot.getCurve('curve0').setData(x=range(4), y=range(-1, 3))
+ self.qapp.processEvents()
+ tableItems = self.statsTable._itemToTableItems(self.plot.getCurve('curve0'))
+ curve0_min = tableItems['min'].text()
+ self.assertTrue(float(curve0_min) == -1.)
+
+ self.plot.getCurve('curve0').setData(x=range(4), y=range(1, 5))
+ self.qapp.processEvents()
+ tableItems = self.statsTable._itemToTableItems(self.plot.getCurve('curve0'))
+ curve0_min = tableItems['min'].text()
+ self.assertTrue(float(curve0_min) == 1.)
+
+ # check stats change in manual mode only if requested
+ self.widget.setUpdateMode(StatsWidget.UpdateMode.MANUAL)
+ self.assertEqual(self.widget.getUpdateMode(), StatsWidget.UpdateMode.MANUAL)
+
+ self.plot.getCurve('curve0').setData(x=range(4), y=range(2, 6))
+ self.qapp.processEvents()
+ tableItems = self.statsTable._itemToTableItems(self.plot.getCurve('curve0'))
+ curve0_min = tableItems['min'].text()
+ self.assertTrue(float(curve0_min) == 1.)
+
+ update_stats_action.trigger()
+ tableItems = self.statsTable._itemToTableItems(self.plot.getCurve('curve0'))
+ curve0_min = tableItems['min'].text()
+ self.assertTrue(float(curve0_min) == 2.)
+
+ def testItemHidden(self):
+ """Test if an item is hide, then the associated stats item is also
+ hide"""
+ curve0 = self.plot.getCurve('curve0')
+ curve1 = self.plot.getCurve('curve1')
+ curve2 = self.plot.getCurve('curve2')
+
+ self.plot.show()
+ self.widget.show()
+ self.qWaitForWindowExposed(self.widget)
+ self.assertFalse(self.statsTable.isRowHidden(0))
+ self.assertFalse(self.statsTable.isRowHidden(1))
+ self.assertFalse(self.statsTable.isRowHidden(2))
+
+ curve0.setVisible(False)
+ self.qapp.processEvents()
+ self.assertTrue(self.statsTable.isRowHidden(0))
+ curve0.setVisible(True)
+ self.qapp.processEvents()
+ self.assertFalse(self.statsTable.isRowHidden(0))
+ curve1.setVisible(False)
+ self.qapp.processEvents()
+ self.assertTrue(self.statsTable.isRowHidden(1))
+ tableItems = self.statsTable._itemToTableItems(curve2)
+ curve2_min = tableItems['min'].text()
+ self.assertTrue(float(curve2_min) == -2.)
+
+ curve0.setVisible(False)
+ curve1.setVisible(False)
+ curve2.setVisible(False)
+ self.qapp.processEvents()
+ self.assertTrue(self.statsTable.isRowHidden(0))
+ self.assertTrue(self.statsTable.isRowHidden(1))
+ self.assertTrue(self.statsTable.isRowHidden(2))
+
+
+class TestStatsWidgetWithImages(TestCaseQt):
+ """Basic test for StatsWidget with images"""
+
+ IMAGE_LEGEND = 'test image'
+
+ def setUp(self):
+ TestCaseQt.setUp(self)
+ self.plot = Plot2D()
+
+ self.plot.addImage(data=numpy.arange(128*128).reshape(128, 128),
+ legend=self.IMAGE_LEGEND, replace=False)
+
+ self.widget = StatsWidget.StatsTable(plot=self.plot)
+
+ mystats = statshandler.StatsHandler((
+ (stats.StatMin(), statshandler.StatFormatter()),
+ (stats.StatCoordMin(), statshandler.StatFormatter(None, qt.QTableWidgetItem)),
+ (stats.StatMax(), statshandler.StatFormatter()),
+ (stats.StatCoordMax(), statshandler.StatFormatter(None, qt.QTableWidgetItem)),
+ (stats.StatDelta(), statshandler.StatFormatter()),
+ ('std', numpy.std),
+ ('mean', numpy.mean),
+ (stats.StatCOM(), statshandler.StatFormatter(None))
+ ))
+
+ self.widget.setStats(mystats)
+
+ def tearDown(self):
+ Stats._getContext.cache_clear()
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ self.widget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.widget.close()
+ self.widget = None
+ self.plot = None
+ TestCaseQt.tearDown(self)
+
+ def test(self):
+ image = self.plot._getItem(
+ kind='image', legend=self.IMAGE_LEGEND)
+ tableItems = self.widget._itemToTableItems(image)
+
+ maxText = '{0:.3f}'.format((128 * 128) - 1)
+ self.assertEqual(tableItems['legend'].text(), self.IMAGE_LEGEND)
+ self.assertEqual(tableItems['min'].text(), '0.000')
+ self.assertEqual(tableItems['max'].text(), maxText)
+ self.assertEqual(tableItems['delta'].text(), maxText)
+ self.assertEqual(tableItems['coords min'].text(), '0.0, 0.0')
+ self.assertEqual(tableItems['coords max'].text(), '127.0, 127.0')
+
+ def testItemHidden(self):
+ """Test if an item is hide, then the associated stats item is also
+ hide"""
+ self.widget.show()
+ self.plot.show()
+ self.qWaitForWindowExposed(self.widget)
+ self.assertFalse(self.widget.isRowHidden(0))
+ self.plot.getImage(self.IMAGE_LEGEND).setVisible(False)
+ self.qapp.processEvents()
+ self.assertTrue(self.widget.isRowHidden(0))
+
+
+class TestStatsWidgetWithScatters(TestCaseQt):
+
+ SCATTER_LEGEND = 'scatter plot'
+
+ def setUp(self):
+ TestCaseQt.setUp(self)
+ self.scatterPlot = Plot2D()
+ self.scatterPlot.addScatter([0, 1, 2, 20, 50, 60],
+ [2, 3, 4, 26, 69, 6],
+ [5, 6, 7, 10, 90, 20],
+ legend=self.SCATTER_LEGEND)
+ self.widget = StatsWidget.StatsTable(plot=self.scatterPlot)
+
+ mystats = statshandler.StatsHandler((
+ stats.StatMin(),
+ (stats.StatCoordMin(), statshandler.StatFormatter(None, qt.QTableWidgetItem)),
+ stats.StatMax(),
+ (stats.StatCoordMax(), statshandler.StatFormatter(None, qt.QTableWidgetItem)),
+ stats.StatDelta(),
+ ('std', numpy.std),
+ ('mean', numpy.mean),
+ stats.StatCOM()
+ ))
+
+ self.widget.setStats(mystats)
+
+ def tearDown(self):
+ Stats._getContext.cache_clear()
+ self.scatterPlot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.scatterPlot.close()
+ self.widget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.widget.close()
+ self.widget = None
+ self.scatterPlot = None
+ TestCaseQt.tearDown(self)
+
+ def testStats(self):
+ scatter = self.scatterPlot._getItem(
+ kind='scatter', legend=self.SCATTER_LEGEND)
+ tableItems = self.widget._itemToTableItems(scatter)
+ self.assertEqual(tableItems['legend'].text(), self.SCATTER_LEGEND)
+ self.assertEqual(tableItems['min'].text(), '5')
+ self.assertEqual(tableItems['coords min'].text(), '0, 2')
+ self.assertEqual(tableItems['max'].text(), '90')
+ self.assertEqual(tableItems['coords max'].text(), '50, 69')
+ self.assertEqual(tableItems['delta'].text(), '85')
+
+
+class TestEmptyStatsWidget(TestCaseQt):
+ def test(self):
+ widget = StatsWidget.StatsWidget()
+ widget.show()
+ self.qWaitForWindowExposed(widget)
+
+
+class TestLineWidget(TestCaseQt):
+ """Some test for the StatsLineWidget."""
+ def setUp(self):
+ TestCaseQt.setUp(self)
+
+ mystats = statshandler.StatsHandler((
+ (stats.StatMin(), statshandler.StatFormatter()),
+ ))
+
+ self.plot = Plot1D()
+ self.plot.show()
+ self.x = range(20)
+ self.y0 = range(20)
+ self.curve0 = self.plot.addCurve(self.x, self.y0, legend='curve0')
+ self.y1 = range(12, 32)
+ self.plot.addCurve(self.x, self.y1, legend='curve1')
+ self.y2 = range(-2, 18)
+ self.plot.addCurve(self.x, self.y2, legend='curve2')
+ self.widget = StatsWidget.BasicGridStatsWidget(plot=self.plot,
+ kind='curve',
+ stats=mystats)
+
+ def tearDown(self):
+ Stats._getContext.cache_clear()
+ self.qapp.processEvents()
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ self.widget.setPlot(None)
+ self.widget._lineStatsWidget._statQlineEdit.clear()
+ self.widget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.widget.close()
+ self.widget = None
+ self.plot = None
+ TestCaseQt.tearDown(self)
+
+ def testProcessing(self):
+ self.widget._lineStatsWidget.setStatsOnVisibleData(False)
+ self.qapp.processEvents()
+ self.plot.setActiveCurve(legend='curve0')
+ self.assertTrue(self.widget._lineStatsWidget._statQlineEdit['min'].text() == '0.000')
+ self.plot.setActiveCurve(legend='curve1')
+ self.assertTrue(self.widget._lineStatsWidget._statQlineEdit['min'].text() == '12.000')
+ self.plot.getXAxis().setLimitsConstraints(minPos=2, maxPos=5)
+ self.widget.setStatsOnVisibleData(True)
+ self.qapp.processEvents()
+ self.assertTrue(self.widget._lineStatsWidget._statQlineEdit['min'].text() == '14.000')
+ self.plot.setActiveCurve(None)
+ self.assertIsNone(self.plot.getActiveCurve())
+ self.widget.setStatsOnVisibleData(False)
+ self.qapp.processEvents()
+ self.assertFalse(self.widget._lineStatsWidget._statQlineEdit['min'].text() == '14.000')
+ self.widget.setKind('image')
+ self.plot.addImage(numpy.arange(100*100).reshape(100, 100) + 0.312)
+ self.qapp.processEvents()
+ self.assertTrue(self.widget._lineStatsWidget._statQlineEdit['min'].text() == '0.312')
+
+ def testUpdateMode(self):
+ """Make sure the update modes are well take into account"""
+ self.plot.setActiveCurve(self.curve0)
+ _autoRB = self.widget._options._autoRB
+ _manualRB = self.widget._options._manualRB
+ # test from api
+ self.widget.setUpdateMode(StatsWidget.UpdateMode.AUTO)
+ self.assertTrue(_autoRB.isChecked())
+ self.assertFalse(_manualRB.isChecked())
+
+ # check stats change in auto mode
+ curve0_min = self.widget._lineStatsWidget._statQlineEdit['min'].text()
+ new_y = numpy.array(self.y0) - 2.56
+ self.plot.addCurve(x=self.x, y=new_y, legend=self.curve0)
+ curve0_min2 = self.widget._lineStatsWidget._statQlineEdit['min'].text()
+ self.assertTrue(curve0_min != curve0_min2)
+
+ # check stats change in manual mode only if requested
+ self.widget.setUpdateMode(StatsWidget.UpdateMode.MANUAL)
+ self.assertFalse(_autoRB.isChecked())
+ self.assertTrue(_manualRB.isChecked())
+
+ new_y = numpy.array(self.y0) - 1.2
+ self.plot.addCurve(x=self.x, y=new_y, legend=self.curve0)
+ curve0_min3 = self.widget._lineStatsWidget._statQlineEdit['min'].text()
+ self.assertTrue(curve0_min3 == curve0_min2)
+ self.widget._options._updateRequested()
+ curve0_min3 = self.widget._lineStatsWidget._statQlineEdit['min'].text()
+ self.assertTrue(curve0_min3 != curve0_min2)
+
+ # test from gui
+ self.widget.showRadioButtons(True)
+ self.widget._options._autoRB.toggle()
+ self.assertTrue(_autoRB.isChecked())
+ self.assertFalse(_manualRB.isChecked())
+
+ self.widget._options._manualRB.toggle()
+ self.assertFalse(_autoRB.isChecked())
+ self.assertTrue(_manualRB.isChecked())
+
+
+class TestUpdateModeWidget(TestCaseQt):
+ """Test UpdateModeWidget"""
+ def setUp(self):
+ TestCaseQt.setUp(self)
+ self.widget = StatsWidget.UpdateModeWidget(parent=None)
+
+ def tearDown(self):
+ self.widget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.widget.close()
+ self.widget = None
+ TestCaseQt.tearDown(self)
+
+ def testSignals(self):
+ """Test the signal emission of the widget"""
+ self.widget.setUpdateMode(StatsWidget.UpdateMode.AUTO)
+ modeChangedListener = SignalListener()
+ manualUpdateListener = SignalListener()
+ self.widget.sigUpdateModeChanged.connect(modeChangedListener)
+ self.widget.sigUpdateRequested.connect(manualUpdateListener)
+ self.widget.setUpdateMode(StatsWidget.UpdateMode.AUTO)
+ self.assertEqual(self.widget.getUpdateMode(), StatsWidget.UpdateMode.AUTO)
+ self.assertEqual(modeChangedListener.callCount(), 0)
+ self.qapp.processEvents()
+
+ self.widget.setUpdateMode(StatsWidget.UpdateMode.MANUAL)
+ self.assertEqual(self.widget.getUpdateMode(), StatsWidget.UpdateMode.MANUAL)
+ self.qapp.processEvents()
+ self.assertEqual(modeChangedListener.callCount(), 1)
+ self.assertEqual(manualUpdateListener.callCount(), 0)
+ self.widget._updatePB.click()
+ self.widget._updatePB.click()
+ self.assertEqual(manualUpdateListener.callCount(), 2)
+
+ self.widget._autoRB.setChecked(True)
+ self.assertEqual(modeChangedListener.callCount(), 2)
+ self.widget._updatePB.click()
+ self.assertEqual(manualUpdateListener.callCount(), 2)
+
+
+class TestStatsROI(TestStatsBase, TestCaseQt):
+ """
+ Test stats based on ROI
+ """
+ def setUp(self):
+ TestCaseQt.setUp(self)
+ self.createRois()
+ TestStatsBase.setUp(self)
+ self.createHistogramContext()
+
+ self.roiManager = RegionOfInterestManager(self.plot2d)
+ self.roiManager.addRoi(self._2Droi_rect)
+ self.roiManager.addRoi(self._2Droi_poly)
+
+ def tearDown(self):
+ self.roiManager.clear()
+ self.roiManager = None
+ self._1Droi = None
+ self._2Droi_rect = None
+ self._2Droi_poly = None
+ self.plotHisto.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plotHisto.close()
+ self.plotHisto = None
+ TestStatsBase.tearDown(self)
+ TestCaseQt.tearDown(self)
+
+ def createRois(self):
+ self._1Droi = ROI(name='my1DRoi', fromdata=2.0, todata=5.0)
+ self._2Droi_rect = RectangleROI()
+ self._2Droi_rect.setGeometry(size=(10, 10), origin=(10, 0))
+ self._2Droi_poly = PolygonROI()
+ points = numpy.array(((0, 20), (0, 0), (10, 0)))
+ self._2Droi_poly.setPoints(points=points)
+
+ def createCurveContext(self):
+ TestStatsBase.createCurveContext(self)
+ self.curveContext = stats._CurveContext(
+ item=self.plot1d.getCurve('curve0'),
+ plot=self.plot1d,
+ onlimits=False,
+ roi=self._1Droi)
+
+ def createHistogramContext(self):
+ self.plotHisto = Plot1D()
+ x = range(20)
+ y = range(20)
+ self.plotHisto.addHistogram(x, y, legend='histo0')
+
+ self.histoContext = stats._HistogramContext(
+ item=self.plotHisto.getHistogram('histo0'),
+ plot=self.plotHisto,
+ onlimits=False,
+ roi=self._1Droi)
+
+ def createScatterContext(self):
+ TestStatsBase.createScatterContext(self)
+ self.scatterContext = stats._ScatterContext(
+ item=self.scatterPlot.getScatter('scatter plot'),
+ plot=self.scatterPlot,
+ onlimits=False,
+ roi=self._1Droi
+ )
+
+ def createImageContext(self):
+ TestStatsBase.createImageContext(self)
+
+ self.imageContext = stats._ImageContext(
+ item=self.plot2d.getImage(self._imgLgd),
+ plot=self.plot2d,
+ onlimits=False,
+ roi=self._2Droi_rect
+ )
+
+ self.imageContext_2 = stats._ImageContext(
+ item=self.plot2d.getImage(self._imgLgd),
+ plot=self.plot2d,
+ onlimits=False,
+ roi=self._2Droi_poly
+ )
+
+ def testErrors(self):
+ # test if onlimits is True and give also a roi
+ with self.assertRaises(ValueError):
+ stats._CurveContext(item=self.plot1d.getCurve('curve0'),
+ plot=self.plot1d,
+ onlimits=True,
+ roi=self._1Droi)
+
+ # test if is a curve context and give an invalid 2D roi
+ with self.assertRaises(TypeError):
+ stats._CurveContext(item=self.plot1d.getCurve('curve0'),
+ plot=self.plot1d,
+ onlimits=False,
+ roi=self._2Droi_rect)
+
+ def testBasicStatsCurve(self):
+ """Test result for simple stats on a curve"""
+ _stats = self.getBasicStats()
+ xData = yData = numpy.array(range(0, 10))
+ self.assertEqual(_stats['min'].calculate(self.curveContext), 2)
+ self.assertEqual(_stats['max'].calculate(self.curveContext), 5)
+ self.assertEqual(_stats['minCoords'].calculate(self.curveContext), (2,))
+ self.assertEqual(_stats['maxCoords'].calculate(self.curveContext), (5,))
+ self.assertEqual(_stats['std'].calculate(self.curveContext), numpy.std(yData[2:6]))
+ self.assertEqual(_stats['mean'].calculate(self.curveContext), numpy.mean(yData[2:6]))
+ com = numpy.sum(xData[2:6] * yData[2:6]) / numpy.sum(yData[2:6])
+ self.assertEqual(_stats['com'].calculate(self.curveContext), com)
+
+ def testBasicStatsImageRectRoi(self):
+ """Test result for simple stats on an image"""
+ self.assertEqual(self.imageContext.values.compressed().size, 121)
+ _stats = self.getBasicStats()
+ self.assertEqual(_stats['min'].calculate(self.imageContext), 10)
+ self.assertEqual(_stats['max'].calculate(self.imageContext), 1300)
+ self.assertEqual(_stats['minCoords'].calculate(self.imageContext), (10, 0))
+ self.assertEqual(_stats['maxCoords'].calculate(self.imageContext), (20.0, 10.0))
+ self.assertAlmostEqual(_stats['std'].calculate(self.imageContext),
+ numpy.std(self.imageData[0:11, 10:21]))
+ self.assertAlmostEqual(_stats['mean'].calculate(self.imageContext),
+ numpy.mean(self.imageData[0:11, 10:21]))
+
+ compressed_values = self.imageContext.values.compressed()
+ compressed_values = compressed_values.reshape(11, 11)
+ yData = numpy.sum(compressed_values.astype(numpy.float64), axis=1)
+ xData = numpy.sum(compressed_values.astype(numpy.float64), axis=0)
+
+ dataYRange = range(11)
+ dataXRange = range(10, 21)
+
+ ycom = numpy.sum(yData*dataYRange) / numpy.sum(yData)
+ xcom = numpy.sum(xData*dataXRange) / numpy.sum(xData)
+ self.assertEqual(_stats['com'].calculate(self.imageContext), (xcom, ycom))
+
+ def testBasicStatsImagePolyRoi(self):
+ """Test a simple rectangle ROI"""
+ _stats = self.getBasicStats()
+ self.assertEqual(_stats['min'].calculate(self.imageContext_2), 0)
+ self.assertEqual(_stats['max'].calculate(self.imageContext_2), 2432)
+ self.assertEqual(_stats['minCoords'].calculate(self.imageContext_2), (0.0, 0.0))
+ # not 0.0, 19.0 because not fully in. Should all pixel have a weight,
+ # on to manage them in stats. For now 0 if the center is not in, else 1
+ self.assertEqual(_stats['maxCoords'].calculate(self.imageContext_2), (0.0, 19.0))
+
+ def testBasicStatsScatter(self):
+ self.assertEqual(self.scatterContext.values.compressed().size, 2)
+ _stats = self.getBasicStats()
+ self.assertEqual(_stats['min'].calculate(self.scatterContext), 6)
+ self.assertEqual(_stats['max'].calculate(self.scatterContext), 7)
+ self.assertEqual(_stats['minCoords'].calculate(self.scatterContext), (2, 3))
+ self.assertEqual(_stats['maxCoords'].calculate(self.scatterContext), (3, 4))
+ self.assertEqual(_stats['std'].calculate(self.scatterContext), numpy.std([6, 7]))
+ self.assertEqual(_stats['mean'].calculate(self.scatterContext), numpy.mean([6, 7]))
+
+ def testBasicHistogram(self):
+ _stats = self.getBasicStats()
+ xData = yData = numpy.array(range(2, 6))
+ self.assertEqual(_stats['min'].calculate(self.histoContext), 2)
+ self.assertEqual(_stats['max'].calculate(self.histoContext), 5)
+ self.assertEqual(_stats['minCoords'].calculate(self.histoContext), (2,))
+ self.assertEqual(_stats['maxCoords'].calculate(self.histoContext), (5,))
+ self.assertEqual(_stats['std'].calculate(self.histoContext), numpy.std(yData))
+ self.assertEqual(_stats['mean'].calculate(self.histoContext), numpy.mean(yData))
+ com = numpy.sum(xData * yData) / numpy.sum(yData)
+ self.assertEqual(_stats['com'].calculate(self.histoContext), com)
+
+
+class TestAdvancedROIImageContext(TestCaseQt):
+ """Test stats result on an image context with different scale and
+ origins"""
+
+ def setUp(self):
+ TestCaseQt.setUp(self)
+ self.data_dims = (100, 100)
+ self.data = numpy.random.rand(*self.data_dims)
+ self.plot = Plot2D()
+
+ def tearDown(self):
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ self.plot = None
+ TestCaseQt.tearDown(self)
+
+ def test(self):
+ """Test stats result on an image context with different scale and
+ origins"""
+ roi_origins = [(0, 0), (2, 10), (14, 20)]
+ img_origins = [(0, 0), (14, 20), (2, 10)]
+ img_scales = [1.0, 0.5, 2.0]
+ _stats = {'sum': stats.Stat(name='sum', fct=numpy.sum), }
+ for roi_origin in roi_origins:
+ for img_origin in img_origins:
+ for img_scale in img_scales:
+ with self.subTest(roi_origin=roi_origin,
+ img_origin=img_origin,
+ img_scale=img_scale):
+ self.plot.addImage(self.data, legend='img',
+ origin=img_origin,
+ scale=img_scale)
+ roi = RectangleROI()
+ roi.setGeometry(origin=roi_origin, size=(20, 20))
+ context = stats._ImageContext(
+ item=self.plot.getImage('img'),
+ plot=self.plot,
+ onlimits=False,
+ roi=roi)
+ x_start = int((roi_origin[0] - img_origin[0]) / img_scale)
+ x_end = int(x_start + (20 / img_scale)) + 1
+ y_start = int((roi_origin[1] - img_origin[1])/ img_scale)
+ y_end = int(y_start + (20 / img_scale)) + 1
+ x_start = max(x_start, 0)
+ x_end = min(max(x_end, 0), self.data_dims[1])
+ y_start = max(y_start, 0)
+ y_end = min(max(y_end, 0), self.data_dims[0])
+ th_sum = numpy.sum(self.data[y_start:y_end, x_start:x_end])
+ self.assertAlmostEqual(_stats['sum'].calculate(context),
+ th_sum)
diff --git a/src/silx/gui/plot/test/testUtilsAxis.py b/src/silx/gui/plot/test/testUtilsAxis.py
new file mode 100644
index 0000000..dd4a689
--- /dev/null
+++ b/src/silx/gui/plot/test/testUtilsAxis.py
@@ -0,0 +1,203 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for PlotWidget"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "20/11/2018"
+
+
+import unittest
+from silx.gui.plot import PlotWidget
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui.plot.utils.axis import SyncAxes
+
+
+class TestAxisSync(TestCaseQt):
+ """Tests AxisSync class"""
+
+ def setUp(self):
+ TestCaseQt.setUp(self)
+ self.plot1 = PlotWidget()
+ self.plot2 = PlotWidget()
+ self.plot3 = PlotWidget()
+
+ def tearDown(self):
+ self.plot1 = None
+ self.plot2 = None
+ self.plot3 = None
+ TestCaseQt.tearDown(self)
+
+ def testMoveFirstAxis(self):
+ """Test synchronization after construction"""
+ _sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()])
+
+ self.plot1.getXAxis().setLimits(10, 500)
+ self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500))
+ self.assertEqual(self.plot2.getXAxis().getLimits(), (10, 500))
+ self.assertEqual(self.plot3.getXAxis().getLimits(), (10, 500))
+
+ def testMoveSecondAxis(self):
+ """Test synchronization after construction"""
+ _sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()])
+
+ self.plot2.getXAxis().setLimits(10, 500)
+ self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500))
+ self.assertEqual(self.plot2.getXAxis().getLimits(), (10, 500))
+ self.assertEqual(self.plot3.getXAxis().getLimits(), (10, 500))
+
+ def testMoveTwoAxes(self):
+ """Test synchronization after construction"""
+ _sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()])
+
+ self.plot1.getXAxis().setLimits(1, 50)
+ self.plot2.getXAxis().setLimits(10, 500)
+ self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500))
+ self.assertEqual(self.plot2.getXAxis().getLimits(), (10, 500))
+ self.assertEqual(self.plot3.getXAxis().getLimits(), (10, 500))
+
+ def testDestruction(self):
+ """Test synchronization when sync object is destroyed"""
+ sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()])
+ del sync
+
+ self.plot1.getXAxis().setLimits(10, 500)
+ self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500))
+ self.assertNotEqual(self.plot2.getXAxis().getLimits(), (10, 500))
+ self.assertNotEqual(self.plot3.getXAxis().getLimits(), (10, 500))
+
+ def testAxisDestruction(self):
+ """Test synchronization when an axis disappear"""
+ _sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()])
+
+ # Destroy the plot is possible
+ import weakref
+ plot = weakref.ref(self.plot2)
+ self.plot2 = None
+ result = self.qWaitForDestroy(plot)
+ if not result:
+ # We can't test
+ self.skipTest("Object not destroyed")
+
+ self.plot1.getXAxis().setLimits(10, 500)
+ self.assertEqual(self.plot3.getXAxis().getLimits(), (10, 500))
+
+ def testStop(self):
+ """Test synchronization after calling stop"""
+ sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()])
+ sync.stop()
+
+ self.plot1.getXAxis().setLimits(10, 500)
+ self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500))
+ self.assertNotEqual(self.plot2.getXAxis().getLimits(), (10, 500))
+ self.assertNotEqual(self.plot3.getXAxis().getLimits(), (10, 500))
+
+ def testStopMovingStart(self):
+ """Test synchronization after calling stop, moving an axis, then start again"""
+ sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()])
+ sync.stop()
+ self.plot1.getXAxis().setLimits(10, 500)
+ self.plot2.getXAxis().setLimits(1, 50)
+ self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500))
+ sync.start()
+
+ # The first axis is the reference
+ self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500))
+ self.assertEqual(self.plot2.getXAxis().getLimits(), (10, 500))
+ self.assertEqual(self.plot3.getXAxis().getLimits(), (10, 500))
+
+ def testDoubleStop(self):
+ """Test double stop"""
+ sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()])
+ sync.stop()
+ self.assertRaises(RuntimeError, sync.stop)
+
+ def testDoubleStart(self):
+ """Test double stop"""
+ sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()])
+ self.assertRaises(RuntimeError, sync.start)
+
+ def testScale(self):
+ """Test scale change"""
+ _sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()])
+ self.plot1.getXAxis().setScale(self.plot1.getXAxis().LOGARITHMIC)
+ self.assertEqual(self.plot1.getXAxis().getScale(), self.plot1.getXAxis().LOGARITHMIC)
+ self.assertEqual(self.plot2.getXAxis().getScale(), self.plot1.getXAxis().LOGARITHMIC)
+ self.assertEqual(self.plot3.getXAxis().getScale(), self.plot1.getXAxis().LOGARITHMIC)
+
+ def testDirection(self):
+ """Test direction change"""
+ _sync = SyncAxes([self.plot1.getYAxis(), self.plot2.getYAxis(), self.plot3.getYAxis()])
+ self.plot1.getYAxis().setInverted(True)
+ self.assertEqual(self.plot1.getYAxis().isInverted(), True)
+ self.assertEqual(self.plot2.getYAxis().isInverted(), True)
+ self.assertEqual(self.plot3.getYAxis().isInverted(), True)
+
+ def testSyncCenter(self):
+ """Test direction change"""
+ # Not the same scale
+ self.plot1.getXAxis().setLimits(0, 200)
+ self.plot2.getXAxis().setLimits(0, 20)
+ self.plot3.getXAxis().setLimits(0, 2)
+ _sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()],
+ syncLimits=False, syncCenter=True)
+
+ self.assertEqual(self.plot1.getXAxis().getLimits(), (0, 200))
+ self.assertEqual(self.plot2.getXAxis().getLimits(), (100 - 10, 100 + 10))
+ self.assertEqual(self.plot3.getXAxis().getLimits(), (100 - 1, 100 + 1))
+
+ def testSyncCenterAndZoom(self):
+ """Test direction change"""
+ # Not the same scale
+ self.plot1.getXAxis().setLimits(0, 200)
+ self.plot2.getXAxis().setLimits(0, 20)
+ self.plot3.getXAxis().setLimits(0, 2)
+ _sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()],
+ syncLimits=False, syncCenter=True, syncZoom=True)
+
+ # Supposing all the plots use the same size
+ self.assertEqual(self.plot1.getXAxis().getLimits(), (0, 200))
+ self.assertEqual(self.plot2.getXAxis().getLimits(), (0, 200))
+ self.assertEqual(self.plot3.getXAxis().getLimits(), (0, 200))
+
+ def testAddAxis(self):
+ """Test synchronization after construction"""
+ sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis()])
+ sync.addAxis(self.plot3.getXAxis())
+
+ self.plot1.getXAxis().setLimits(10, 500)
+ self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500))
+ self.assertEqual(self.plot2.getXAxis().getLimits(), (10, 500))
+ self.assertEqual(self.plot3.getXAxis().getLimits(), (10, 500))
+
+ def testRemoveAxis(self):
+ """Test synchronization after construction"""
+ sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()])
+ sync.removeAxis(self.plot3.getXAxis())
+
+ self.plot1.getXAxis().setLimits(10, 500)
+ self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500))
+ self.assertEqual(self.plot2.getXAxis().getLimits(), (10, 500))
+ self.assertNotEqual(self.plot3.getXAxis().getLimits(), (10, 500))
diff --git a/src/silx/gui/plot/test/utils.py b/src/silx/gui/plot/test/utils.py
new file mode 100644
index 0000000..64fca56
--- /dev/null
+++ b/src/silx/gui/plot/test/utils.py
@@ -0,0 +1,93 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for PlotWidget"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "26/01/2018"
+
+
+import logging
+import pytest
+import unittest
+
+from silx.gui.utils.testutils import TestCaseQt
+
+from silx.gui import qt
+from silx.gui.plot import PlotWidget
+
+
+logger = logging.getLogger(__name__)
+
+
+@pytest.mark.usefixtures("test_options_class_attr")
+class PlotWidgetTestCase(TestCaseQt):
+ """Base class for tests of PlotWidget, not a TestCase in itself.
+
+ plot attribute is the PlotWidget created for the test.
+ """
+ __screenshot_already_taken = False
+ backend = None
+
+ def _createPlot(self):
+ return PlotWidget(backend=self.backend)
+
+ def setUp(self):
+ super(PlotWidgetTestCase, self).setUp()
+ self.plot = self._createPlot()
+ self.plot.show()
+ self.plotAlive = True
+ self.qWaitForWindowExposed(self.plot)
+ TestCaseQt.mouseClick(self, self.plot, button=qt.Qt.LeftButton, pos=(0, 0))
+
+ def __onPlotDestroyed(self):
+ self.plotAlive = False
+
+ def _waitForPlotClosed(self):
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.destroyed.connect(self.__onPlotDestroyed)
+ self.plot.close()
+ del self.plot
+ for _ in range(100):
+ if not self.plotAlive:
+ break
+ self.qWait(10)
+ else:
+ logger.error("Plot is still alive")
+
+ def tearDown(self):
+ if not self._currentTestSucceeded():
+ # MPL is the only widget which uses the real system mouse.
+ # In case of a the windows is outside of the screen, minimzed,
+ # overlapped by a system popup, the MPL widget will not receive the
+ # mouse event.
+ # Taking a screenshot help debuging this cases in the continuous
+ # integration environement.
+ if not PlotWidgetTestCase.__screenshot_already_taken:
+ PlotWidgetTestCase.__screenshot_already_taken = True
+ self.logScreenShot()
+ self.qapp.processEvents()
+ self._waitForPlotClosed()
+ super(PlotWidgetTestCase, self).tearDown()
diff --git a/src/silx/gui/plot/tools/CurveLegendsWidget.py b/src/silx/gui/plot/tools/CurveLegendsWidget.py
new file mode 100644
index 0000000..4a517dd
--- /dev/null
+++ b/src/silx/gui/plot/tools/CurveLegendsWidget.py
@@ -0,0 +1,247 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a widget to display :class:`PlotWidget` curve legends.
+"""
+
+from __future__ import division
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "20/07/2018"
+
+
+import logging
+import weakref
+
+
+from ... import qt
+from ...widgets.FlowLayout import FlowLayout as _FlowLayout
+from ..LegendSelector import LegendIcon as _LegendIcon
+from .. import items
+
+
+_logger = logging.getLogger(__name__)
+
+
+class _LegendWidget(qt.QWidget):
+ """Widget displaying curve style and its legend
+
+ :param QWidget parent: See :class:`QWidget`
+ :param ~silx.gui.plot.items.Curve curve: Associated curve
+ """
+
+ def __init__(self, parent, curve):
+ super(_LegendWidget, self).__init__(parent)
+ layout = qt.QHBoxLayout(self)
+ layout.setContentsMargins(10, 0, 10, 0)
+
+ curve.sigItemChanged.connect(self._curveChanged)
+
+ icon = _LegendIcon(curve=curve)
+ layout.addWidget(icon)
+
+ label = qt.QLabel(curve.getName())
+ label.setAlignment(qt.Qt.AlignLeft | qt.Qt.AlignVCenter)
+ layout.addWidget(label)
+
+ self._update()
+
+ def getCurve(self):
+ """Returns curve associated to this widget
+
+ :rtype: Union[~silx.gui.plot.items.Curve,None]
+ """
+ icon = self.findChild(_LegendIcon)
+ return icon.getCurve()
+
+ def _update(self):
+ """Update widget according to current curve state.
+ """
+ curve = self.getCurve()
+ if curve is None:
+ _logger.error('Curve no more exists')
+ self.setVisible(False)
+ return
+
+ self.setEnabled(curve.isVisible())
+
+ label = self.findChild(qt.QLabel)
+ if curve.isHighlighted():
+ label.setStyleSheet("border: 1px solid black")
+ else:
+ label.setStyleSheet("")
+
+ def _curveChanged(self, event):
+ """Handle update of curve item
+
+ :param event: Kind of change
+ """
+ if event in (items.ItemChangedType.VISIBLE,
+ items.ItemChangedType.HIGHLIGHTED,
+ items.ItemChangedType.HIGHLIGHTED_STYLE):
+ self._update()
+
+
+class CurveLegendsWidget(qt.QWidget):
+ """Widget displaying curves legends in a plot
+
+ :param QWidget parent: See :class:`QWidget`
+ """
+
+ sigCurveClicked = qt.Signal(object)
+ """Signal emitted when the legend of a curve is clicked
+
+ It provides the corresponding curve.
+ """
+
+ def __init__(self, parent=None):
+ super(CurveLegendsWidget, self).__init__(parent)
+ self._clicked = None
+ self._legends = {}
+ self._plotRef = None
+
+ def layout(self):
+ layout = super(CurveLegendsWidget, self).layout()
+ if layout is None:
+ # Lazy layout initialization to allow overloading
+ layout = _FlowLayout()
+ layout.setHorizontalSpacing(0)
+ self.setLayout(layout)
+ return layout
+
+ def getPlotWidget(self):
+ """Returns the associated :class:`PlotWidget`
+
+ :rtype: Union[~silx.gui.plot.PlotWidget,None]
+ """
+ return None if self._plotRef is None else self._plotRef()
+
+ def setPlotWidget(self, plot):
+ """Set the associated :class:`PlotWidget`
+
+ :param ~silx.gui.plot.PlotWidget plot: Plot widget to attach
+ """
+ previousPlot = self.getPlotWidget()
+ if previousPlot is not None:
+ previousPlot.sigItemAdded.disconnect( self._itemAdded)
+ previousPlot.sigItemAboutToBeRemoved.disconnect(self._itemRemoved)
+ for legend in list(self._legends.keys()):
+ self._removeLegend(legend)
+
+ self._plotRef = None if plot is None else weakref.ref(plot)
+
+ if plot is not None:
+ plot.sigItemAdded.connect(self._itemAdded)
+ plot.sigItemAboutToBeRemoved.connect(self._itemRemoved)
+
+ for legend in plot.getAllCurves(just_legend=True):
+ self._addLegend(legend)
+
+ def curveAt(self, *args):
+ """Returns the curve object represented at the given position
+
+ Either takes a QPoint or x and y as input in widget coordinates.
+
+ :rtype: Union[~silx.gui.plot.items.Curve,None]
+ """
+ if len(args) == 1:
+ point = args[0]
+ elif len(args) == 2:
+ point = qt.QPoint(*args)
+ else:
+ raise ValueError('Unsupported arguments')
+ assert isinstance(point, qt.QPoint)
+
+ widget = self.childAt(point)
+ while widget not in (self, None):
+ if isinstance(widget, _LegendWidget):
+ return widget.getCurve()
+ widget = widget.parent()
+ return None # No widget or not in _LegendWidget
+
+ def _itemAdded(self, item):
+ """Handle item added to the plot content"""
+ if isinstance(item, items.Curve):
+ self._addLegend(item.getName())
+
+ def _itemRemoved(self, item):
+ """Handle item removed from the plot content"""
+ if isinstance(item, items.Curve):
+ self._removeLegend(item.getName())
+
+ def _addLegend(self, legend):
+ """Add a curve to the legends
+
+ :param str legend: Curve's legend
+ """
+ if legend in self._legends:
+ return # Can happen when changing curve's y axis
+
+ plot = self.getPlotWidget()
+ if plot is None:
+ return None
+
+ curve = plot.getCurve(legend)
+ if curve is None:
+ _logger.error('Curve not found: %s' % legend)
+ return
+
+ widget = _LegendWidget(parent=self, curve=curve)
+ self.layout().addWidget(widget)
+ self._legends[legend] = widget
+
+ def _removeLegend(self, legend):
+ """Remove a curve from the legends if it exists
+
+ :param str legend: The curve's legend
+ """
+ widget = self._legends.pop(legend, None)
+ if widget is None:
+ _logger.warning('Unknown legend: %s' % legend)
+ else:
+ self.layout().removeWidget(widget)
+ widget.setParent(None)
+
+ def mousePressEvent(self, event):
+ if event.button() == qt.Qt.LeftButton:
+ self._clicked = event.pos()
+
+ _CLICK_THRESHOLD = 5
+ """Threshold for clicks"""
+
+ def mouseMoveEvent(self, event):
+ if self._clicked is not None:
+ dx = abs(self._clicked.x() - event.pos().x())
+ dy = abs(self._clicked.y() - event.pos().y())
+ if dx > self._CLICK_THRESHOLD or dy > self._CLICK_THRESHOLD:
+ self._clicked = None # Click is cancelled
+
+ def mouseReleaseEvent(self, event):
+ if event.button() == qt.Qt.LeftButton and self._clicked is not None:
+ curve = self.curveAt(event.pos())
+ if curve is not None:
+ self.sigCurveClicked.emit(curve)
+
+ self._clicked = None
diff --git a/src/silx/gui/plot/tools/LimitsToolBar.py b/src/silx/gui/plot/tools/LimitsToolBar.py
new file mode 100644
index 0000000..fc192a6
--- /dev/null
+++ b/src/silx/gui/plot/tools/LimitsToolBar.py
@@ -0,0 +1,131 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""A toolbar to display and edit limits of a PlotWidget
+"""
+
+
+from __future__ import division
+
+__authors__ = ["V.A. Sole", "T. Vincent"]
+__license__ = "MIT"
+__date__ = "16/10/2017"
+
+
+from ... import qt
+from ...widgets.FloatEdit import FloatEdit
+
+
+class LimitsToolBar(qt.QToolBar):
+ """QToolBar displaying and controlling the limits of a :class:`PlotWidget`.
+
+ To run the following sample code, a QApplication must be initialized.
+ First, create a PlotWindow:
+
+ >>> from silx.gui.plot import PlotWindow
+ >>> plot = PlotWindow() # Create a PlotWindow to add the toolbar to
+
+ Then, create the LimitsToolBar and add it to the PlotWindow.
+
+ >>> from silx.gui import qt
+ >>> from silx.gui.plot.tools import LimitsToolBar
+
+ >>> toolbar = LimitsToolBar(plot=plot) # Create the toolbar
+ >>> plot.addToolBar(qt.Qt.BottomToolBarArea, toolbar) # Add it to the plot
+ >>> plot.show() # To display the PlotWindow with the limits toolbar
+
+ :param parent: See :class:`QToolBar`.
+ :param plot: :class:`PlotWidget` instance on which to operate.
+ :param str title: See :class:`QToolBar`.
+ """
+
+ def __init__(self, parent=None, plot=None, title='Limits'):
+ super(LimitsToolBar, self).__init__(title, parent)
+ assert plot is not None
+ self._plot = plot
+ self._plot.sigPlotSignal.connect(self._plotWidgetSlot)
+
+ self._initWidgets()
+
+ @property
+ def plot(self):
+ """The :class:`PlotWidget` the toolbar is attached to."""
+ return self._plot
+
+ def _initWidgets(self):
+ """Create and init Toolbar widgets."""
+ xMin, xMax = self.plot.getXAxis().getLimits()
+ yMin, yMax = self.plot.getYAxis().getLimits()
+
+ self.addWidget(qt.QLabel('Limits: '))
+ self.addWidget(qt.QLabel(' X: '))
+ self._xMinFloatEdit = FloatEdit(self, xMin)
+ self._xMinFloatEdit.editingFinished[()].connect(
+ self._xFloatEditChanged)
+ self.addWidget(self._xMinFloatEdit)
+
+ self._xMaxFloatEdit = FloatEdit(self, xMax)
+ self._xMaxFloatEdit.editingFinished[()].connect(
+ self._xFloatEditChanged)
+ self.addWidget(self._xMaxFloatEdit)
+
+ self.addWidget(qt.QLabel(' Y: '))
+ self._yMinFloatEdit = FloatEdit(self, yMin)
+ self._yMinFloatEdit.editingFinished[()].connect(
+ self._yFloatEditChanged)
+ self.addWidget(self._yMinFloatEdit)
+
+ self._yMaxFloatEdit = FloatEdit(self, yMax)
+ self._yMaxFloatEdit.editingFinished[()].connect(
+ self._yFloatEditChanged)
+ self.addWidget(self._yMaxFloatEdit)
+
+ def _plotWidgetSlot(self, event):
+ """Listen to :class:`PlotWidget` events."""
+ if event['event'] not in ('limitsChanged',):
+ return
+
+ xMin, xMax = self.plot.getXAxis().getLimits()
+ yMin, yMax = self.plot.getYAxis().getLimits()
+
+ self._xMinFloatEdit.setValue(xMin)
+ self._xMaxFloatEdit.setValue(xMax)
+ self._yMinFloatEdit.setValue(yMin)
+ self._yMaxFloatEdit.setValue(yMax)
+
+ def _xFloatEditChanged(self):
+ """Handle X limits changed from the GUI."""
+ xMin, xMax = self._xMinFloatEdit.value(), self._xMaxFloatEdit.value()
+ if xMax < xMin:
+ xMin, xMax = xMax, xMin
+
+ self.plot.getXAxis().setLimits(xMin, xMax)
+
+ def _yFloatEditChanged(self):
+ """Handle Y limits changed from the GUI."""
+ yMin, yMax = self._yMinFloatEdit.value(), self._yMaxFloatEdit.value()
+ if yMax < yMin:
+ yMin, yMax = yMax, yMin
+
+ self.plot.getYAxis().setLimits(yMin, yMax)
diff --git a/src/silx/gui/plot/tools/PositionInfo.py b/src/silx/gui/plot/tools/PositionInfo.py
new file mode 100644
index 0000000..8b95fbc
--- /dev/null
+++ b/src/silx/gui/plot/tools/PositionInfo.py
@@ -0,0 +1,373 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a widget displaying mouse coordinates in a PlotWidget.
+
+It can be configured to provide more information.
+"""
+
+from __future__ import division
+
+__authors__ = ["V.A. Sole", "T. Vincent"]
+__license__ = "MIT"
+__date__ = "16/10/2017"
+
+
+import logging
+import numbers
+import traceback
+import weakref
+
+import numpy
+
+from ....utils.deprecation import deprecated
+from ... import qt
+from .. import items
+from ...widgets.ElidedLabel import ElidedLabel
+
+
+_logger = logging.getLogger(__name__)
+
+
+class _PositionInfoLabel(ElidedLabel):
+ """QLabel with a default size larger than what is displayed."""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.setTextInteractionFlags(qt.Qt.TextSelectableByMouse)
+
+ def sizeHint(self):
+ hint = super().sizeHint()
+ width = self.fontMetrics().boundingRect('##############').width()
+ return qt.QSize(max(hint.width(), width), hint.height())
+
+
+# PositionInfo ################################################################
+
+class PositionInfo(qt.QWidget):
+ """QWidget displaying coords converted from data coords of the mouse.
+
+ Provide this widget with a list of couple:
+
+ - A name to display before the data
+ - A function that takes (x, y) as arguments and returns something that
+ gets converted to a string.
+ If the result is a float it is converted with '%.7g' format.
+
+ To run the following sample code, a QApplication must be initialized.
+ First, create a PlotWindow and add a QToolBar where to place the
+ PositionInfo widget.
+
+ >>> from silx.gui.plot import PlotWindow
+ >>> from silx.gui import qt
+
+ >>> plot = PlotWindow() # Create a PlotWindow to add the widget to
+ >>> toolBar = qt.QToolBar() # Create a toolbar to place the widget in
+ >>> plot.addToolBar(qt.Qt.BottomToolBarArea, toolBar) # Add it to plot
+
+ Then, create the PositionInfo widget and add it to the toolbar.
+ The PositionInfo widget is created with a list of converters, here
+ to display polar coordinates of the mouse position.
+
+ >>> import numpy
+ >>> from silx.gui.plot.tools import PositionInfo
+
+ >>> position = PositionInfo(plot=plot, converters=[
+ ... ('Radius', lambda x, y: numpy.sqrt(x*x + y*y)),
+ ... ('Angle', lambda x, y: numpy.degrees(numpy.arctan2(y, x)))])
+ >>> toolBar.addWidget(position) # Add the widget to the toolbar
+ <...>
+ >>> plot.show() # To display the PlotWindow with the position widget
+
+ :param plot: The PlotWidget this widget is displaying data coords from.
+ :param converters:
+ List of 2-tuple: name to display and conversion function from (x, y)
+ in data coords to displayed value.
+ If None, the default, it displays X and Y.
+ :param parent: Parent widget
+ """
+
+ SNAP_THRESHOLD_DIST = 5
+
+ def __init__(self, parent=None, plot=None, converters=None):
+ assert plot is not None
+ self._plotRef = weakref.ref(plot)
+ self._snappingMode = self.SNAPPING_DISABLED
+
+ super(PositionInfo, self).__init__(parent)
+
+ if converters is None:
+ converters = (('X', lambda x, y: x), ('Y', lambda x, y: y))
+
+ self._fields = [] # To store (QLineEdit, name, function (x, y)->v)
+
+ # Create a new layout with new widgets
+ layout = qt.QHBoxLayout()
+ layout.setContentsMargins(0, 0, 0, 0)
+ # layout.setSpacing(0)
+
+ # Create all QLabel and store them with the corresponding converter
+ for name, func in converters:
+ layout.addWidget(qt.QLabel('<b>' + name + ':</b>'))
+
+ contentWidget = _PositionInfoLabel(self)
+ contentWidget.setText('------')
+ layout.addWidget(contentWidget)
+ self._fields.append((contentWidget, name, func))
+
+ layout.addStretch(1)
+ self.setLayout(layout)
+
+ # Connect to Plot events
+ plot.sigPlotSignal.connect(self._plotEvent)
+
+ def getPlotWidget(self):
+ """Returns the PlotWidget this widget is attached to or None.
+
+ :rtype: Union[~silx.gui.plot.PlotWidget,None]
+ """
+ return self._plotRef()
+
+ @property
+ @deprecated(replacement='getPlotWidget', since_version='0.8.0')
+ def plot(self):
+ return self.getPlotWidget()
+
+ def getConverters(self):
+ """Return the list of converters as 2-tuple (name, function)."""
+ return [(name, func) for _label, name, func in self._fields]
+
+ def _plotEvent(self, event):
+ """Handle events from the Plot.
+
+ :param dict event: Plot event
+ """
+ if event['event'] == 'mouseMoved':
+ x, y = event['x'], event['y']
+ xPixel, yPixel = event['xpixel'], event['ypixel']
+ self._updateStatusBar(x, y, xPixel, yPixel)
+
+ def updateInfo(self):
+ """Update displayed information"""
+ plot = self.getPlotWidget()
+ if plot is None:
+ _logger.error("Trying to update PositionInfo "
+ "while PlotWidget no longer exists")
+ return
+
+ widget = plot.getWidgetHandle()
+ position = widget.mapFromGlobal(qt.QCursor.pos())
+ xPixel, yPixel = position.x(), position.y()
+ dataPos = plot.pixelToData(xPixel, yPixel, check=True)
+ if dataPos is not None: # Inside plot area
+ x, y = dataPos
+ self._updateStatusBar(x, y, xPixel, yPixel)
+
+ def _updateStatusBar(self, x, y, xPixel, yPixel):
+ """Update information from the status bar using the definitions.
+
+ :param float x: Position-x in data
+ :param float y: Position-y in data
+ :param float xPixel: Position-x in pixels
+ :param float yPixel: Position-y in pixels
+ """
+ plot = self.getPlotWidget()
+ if plot is None:
+ return
+
+ styleSheet = "color: rgb(0, 0, 0);" # Default style
+ xData, yData = x, y
+
+ snappingMode = self.getSnappingMode()
+
+ # Snapping when crosshair either not requested or active
+ if (snappingMode & (self.SNAPPING_CURVE | self.SNAPPING_SCATTER) and
+ (not (snappingMode & self.SNAPPING_CROSSHAIR) or
+ plot.getGraphCursor())):
+ styleSheet = "color: rgb(255, 0, 0);" # Style far from item
+
+ if snappingMode & self.SNAPPING_ACTIVE_ONLY:
+ selectedItems = []
+
+ if snappingMode & self.SNAPPING_CURVE:
+ activeCurve = plot.getActiveCurve()
+ if activeCurve:
+ selectedItems.append(activeCurve)
+
+ if snappingMode & self.SNAPPING_SCATTER:
+ activeScatter = plot._getActiveItem(kind='scatter')
+ if activeScatter:
+ selectedItems.append(activeScatter)
+
+ else:
+ kinds = []
+ if snappingMode & self.SNAPPING_CURVE:
+ kinds.append(items.Curve)
+ kinds.append(items.Histogram)
+ if snappingMode & self.SNAPPING_SCATTER:
+ kinds.append(items.Scatter)
+ selectedItems = [item for item in plot.getItems()
+ if isinstance(item, tuple(kinds)) and item.isVisible()]
+
+ # Compute distance threshold
+ window = plot.window()
+ windowHandle = window.windowHandle()
+ if windowHandle is not None:
+ ratio = windowHandle.devicePixelRatio()
+ else:
+ ratio = qt.QGuiApplication.primaryScreen().devicePixelRatio()
+
+ # Baseline squared distance threshold
+ distInPixels = (self.SNAP_THRESHOLD_DIST * ratio)**2
+
+ for item in selectedItems:
+ if (snappingMode & self.SNAPPING_SYMBOLS_ONLY and (
+ not isinstance(item, items.SymbolMixIn) or
+ not item.getSymbol())):
+ # Only handled if item symbols are visible
+ continue
+
+ if isinstance(item, items.Histogram):
+ result = item.pick(xPixel, yPixel)
+ if result is not None: # Histogram picked
+ index = result.getIndices()[0]
+ edges = item.getBinEdgesData(copy=False)
+
+ # Snap to bin center and value
+ xData = 0.5 * (edges[index] + edges[index + 1])
+ yData = item.getValueData(copy=False)[index]
+
+ # Update label style sheet
+ styleSheet = "color: rgb(0, 0, 0);"
+ break
+
+ else: # Curve, Scatter
+ xArray = item.getXData(copy=False)
+ yArray = item.getYData(copy=False)
+ closestIndex = numpy.argmin(
+ pow(xArray - x, 2) + pow(yArray - y, 2))
+
+ xClosest = xArray[closestIndex]
+ yClosest = yArray[closestIndex]
+
+ if isinstance(item, items.YAxisMixIn):
+ axis = item.getYAxis()
+ else:
+ axis = 'left'
+
+ closestInPixels = plot.dataToPixel(
+ xClosest, yClosest, axis=axis)
+ if closestInPixels is not None:
+ curveDistInPixels = (
+ (closestInPixels[0] - xPixel)**2 +
+ (closestInPixels[1] - yPixel)**2)
+
+ if curveDistInPixels <= distInPixels:
+ # Update label style sheet
+ styleSheet = "color: rgb(0, 0, 0);"
+
+ # if close enough, snap to data point coord
+ xData, yData = xClosest, yClosest
+ distInPixels = curveDistInPixels
+
+ for label, name, func in self._fields:
+ label.setStyleSheet(styleSheet)
+
+ try:
+ value = func(xData, yData)
+ text = self.valueToString(value)
+ label.setText(text)
+ except:
+ label.setText('Error')
+ _logger.error(
+ "Error while converting coordinates (%f, %f)"
+ "with converter '%s'" % (xPixel, yPixel, name))
+ _logger.error(traceback.format_exc())
+
+ def valueToString(self, value):
+ if isinstance(value, (tuple, list)):
+ value = [self.valueToString(v) for v in value]
+ return ", ".join(value)
+ elif isinstance(value, numbers.Real):
+ # Use this for floats and int
+ return '%.7g' % value
+ else:
+ # Fallback for other types
+ return str(value)
+
+ # Snapping mode
+
+ SNAPPING_DISABLED = 0
+ """No snapping occurs"""
+
+ SNAPPING_CROSSHAIR = 1 << 0
+ """Snapping only enabled when crosshair cursor is enabled"""
+
+ SNAPPING_ACTIVE_ONLY = 1 << 1
+ """Snapping only enabled for active item"""
+
+ SNAPPING_SYMBOLS_ONLY = 1 << 2
+ """Snapping only when symbols are visible"""
+
+ SNAPPING_CURVE = 1 << 3
+ """Snapping on curves"""
+
+ SNAPPING_SCATTER = 1 << 4
+ """Snapping on scatter"""
+
+ def setSnappingMode(self, mode):
+ """Set the snapping mode.
+
+ The mode is a mask.
+
+ :param int mode: The mode to use
+ """
+ if mode != self._snappingMode:
+ self._snappingMode = mode
+ self.updateInfo()
+
+ def getSnappingMode(self):
+ """Returns the snapping mode as a mask
+
+ :rtype: int
+ """
+ return self._snappingMode
+
+ _SNAPPING_LEGACY = (SNAPPING_CROSSHAIR |
+ SNAPPING_ACTIVE_ONLY |
+ SNAPPING_SYMBOLS_ONLY |
+ SNAPPING_CURVE |
+ SNAPPING_SCATTER)
+ """Legacy snapping mode"""
+
+ @property
+ @deprecated(replacement="getSnappingMode", since_version="0.8")
+ def autoSnapToActiveCurve(self):
+ return self.getSnappingMode() == self._SNAPPING_LEGACY
+
+ @autoSnapToActiveCurve.setter
+ @deprecated(replacement="setSnappingMode", since_version="0.8")
+ def autoSnapToActiveCurve(self, flag):
+ self.setSnappingMode(
+ self._SNAPPING_LEGACY if flag else self.SNAPPING_DISABLED)
diff --git a/src/silx/gui/plot/tools/RadarView.py b/src/silx/gui/plot/tools/RadarView.py
new file mode 100644
index 0000000..7076835
--- /dev/null
+++ b/src/silx/gui/plot/tools/RadarView.py
@@ -0,0 +1,361 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""QWidget displaying an overview of a 2D plot.
+
+This shows the available range of the data, and the current location of the
+plot view.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "22/02/2021"
+
+import logging
+import weakref
+from ... import qt
+from ...utils import LockReentrant
+
+_logger = logging.getLogger(__name__)
+
+
+class _DraggableRectItem(qt.QGraphicsRectItem):
+ """RectItem which signals its change through visibleRectDragged."""
+ def __init__(self, *args, **kwargs):
+ super(_DraggableRectItem, self).__init__(
+ *args, **kwargs)
+
+ self._previousCursor = None
+ self.setFlag(qt.QGraphicsItem.ItemIsMovable)
+ self.setFlag(qt.QGraphicsItem.ItemSendsGeometryChanges)
+ self.setAcceptHoverEvents(True)
+ self._ignoreChange = False
+ self._constraint = 0, 0, 0, 0
+
+ def setConstraintRect(self, left, top, width, height):
+ """Set the constraint rectangle for dragging.
+
+ The coordinates are in the _DraggableRectItem coordinate system.
+
+ This constraint only applies to modification through interaction
+ (i.e., this constraint is not applied to change through API).
+
+ If the _DraggableRectItem is smaller than the constraint rectangle,
+ the _DraggableRectItem remains within the constraint rectangle.
+ If the _DraggableRectItem is wider than the constraint rectangle,
+ the constraint rectangle remains within the _DraggableRectItem.
+ """
+ self._constraint = left, left + width, top, top + height
+
+ def setPos(self, *args, **kwargs):
+ """Overridden to ignore changes from API in itemChange."""
+ self._ignoreChange = True
+ super(_DraggableRectItem, self).setPos(*args, **kwargs)
+ self._ignoreChange = False
+
+ def moveBy(self, *args, **kwargs):
+ """Overridden to ignore changes from API in itemChange."""
+ self._ignoreChange = True
+ super(_DraggableRectItem, self).moveBy(*args, **kwargs)
+ self._ignoreChange = False
+
+ def itemChange(self, change, value):
+ """Callback called before applying changes to the item."""
+ if (change == qt.QGraphicsItem.ItemPositionChange and
+ not self._ignoreChange):
+ # Makes sure that the visible area is in the data
+ # or that data is in the visible area if area is too wide
+ x, y = value.x(), value.y()
+ xMin, xMax, yMin, yMax = self._constraint
+
+ if self.rect().width() <= (xMax - xMin):
+ if x < xMin:
+ value.setX(xMin)
+ elif x > xMax - self.rect().width():
+ value.setX(xMax - self.rect().width())
+ else:
+ if x > xMin:
+ value.setX(xMin)
+ elif x < xMax - self.rect().width():
+ value.setX(xMax - self.rect().width())
+
+ if self.rect().height() <= (yMax - yMin):
+ if y < yMin:
+ value.setY(yMin)
+ elif y > yMax - self.rect().height():
+ value.setY(yMax - self.rect().height())
+ else:
+ if y > yMin:
+ value.setY(yMin)
+ elif y < yMax - self.rect().height():
+ value.setY(yMax - self.rect().height())
+
+ if self.pos() != value:
+ # Notify change through signal
+ views = self.scene().views()
+ assert len(views) == 1
+ views[0].visibleRectDragged.emit(
+ value.x() + self.rect().left(),
+ value.y() + self.rect().top(),
+ self.rect().width(),
+ self.rect().height())
+
+ return value
+
+ return super(_DraggableRectItem, self).itemChange(
+ change, value)
+
+ def hoverEnterEvent(self, event):
+ """Called when the mouse enters the rectangle area"""
+ self._previousCursor = self.cursor()
+ self.setCursor(qt.Qt.OpenHandCursor)
+
+ def hoverLeaveEvent(self, event):
+ """Called when the mouse leaves the rectangle area"""
+ if self._previousCursor is not None:
+ self.setCursor(self._previousCursor)
+ self._previousCursor = None
+
+
+class RadarView(qt.QGraphicsView):
+ """Widget presenting a synthetic view of a 2D area and
+ the current visible area.
+
+ Coordinates are as in QGraphicsView:
+ x goes from left to right and y goes from top to bottom.
+ This widget preserves the aspect ratio of the areas.
+
+ The 2D area and the visible area can be set with :meth:`setDataRect`
+ and :meth:`setVisibleRect`.
+ When the visible area has been dragged by the user, its new position
+ is signaled by the *visibleRectDragged* signal.
+
+ It is possible to invert the direction of the axes by using the
+ :meth:`scale` method of QGraphicsView.
+ """
+
+ visibleRectDragged = qt.Signal(float, float, float, float)
+ """Signals that the visible rectangle has been dragged.
+
+ It provides: left, top, width, height in data coordinates.
+ """
+
+ _DATA_PEN = qt.QPen(qt.QColor('white'))
+ _DATA_BRUSH = qt.QBrush(qt.QColor('light gray'))
+ _ACTIVEDATA_PEN = qt.QPen(qt.QColor('black'))
+ _ACTIVEDATA_BRUSH = qt.QBrush(qt.QColor('transparent'))
+ _ACTIVEDATA_PEN.setWidth(2)
+ _ACTIVEDATA_PEN.setCosmetic(True)
+ _VISIBLE_PEN = qt.QPen(qt.QColor('blue'))
+ _VISIBLE_PEN.setWidth(2)
+ _VISIBLE_PEN.setCosmetic(True)
+ _VISIBLE_BRUSH = qt.QBrush(qt.QColor(0, 0, 0, 0))
+ _TOOLTIP = 'Radar View:\nRed contour: Visible area\nGray area: The image'
+
+ _PIXMAP_SIZE = 256
+
+ def __init__(self, parent=None):
+ self.__plotRef = None
+ self._scene = qt.QGraphicsScene()
+ self._dataRect = self._scene.addRect(0, 0, 1, 1,
+ self._DATA_PEN,
+ self._DATA_BRUSH)
+ self._imageRect = self._scene.addRect(0, 0, 1, 1,
+ self._ACTIVEDATA_PEN,
+ self._ACTIVEDATA_BRUSH)
+ self._imageRect.setVisible(False)
+ self._scatterRect = self._scene.addRect(0, 0, 1, 1,
+ self._ACTIVEDATA_PEN,
+ self._ACTIVEDATA_BRUSH)
+ self._scatterRect.setVisible(False)
+ self._curveRect = self._scene.addRect(0, 0, 1, 1,
+ self._ACTIVEDATA_PEN,
+ self._ACTIVEDATA_BRUSH)
+ self._curveRect.setVisible(False)
+
+ self._visibleRect = _DraggableRectItem(0, 0, 1, 1)
+ self._visibleRect.setPen(self._VISIBLE_PEN)
+ self._visibleRect.setBrush(self._VISIBLE_BRUSH)
+ self._scene.addItem(self._visibleRect)
+
+ super(RadarView, self).__init__(self._scene, parent)
+ self.setHorizontalScrollBarPolicy(qt.Qt.ScrollBarAlwaysOff)
+ self.setVerticalScrollBarPolicy(qt.Qt.ScrollBarAlwaysOff)
+ self.setFocusPolicy(qt.Qt.NoFocus)
+ self.setStyleSheet('border: 0px')
+ self.setToolTip(self._TOOLTIP)
+
+ self.__reentrant = LockReentrant()
+ self.visibleRectDragged.connect(self._viewRectDragged)
+
+ self.__timer = qt.QTimer(self)
+ self.__timer.timeout.connect(self._updateDataContent)
+
+ def sizeHint(self):
+ # """Overridden to avoid sizeHint to depend on content size."""
+ return self.minimumSizeHint()
+
+ def wheelEvent(self, event):
+ # """Overridden to disable vertical scrolling with wheel."""
+ event.ignore()
+
+ def resizeEvent(self, event):
+ # """Overridden to fit current content to new size."""
+ self.fitInView(self._scene.itemsBoundingRect(), qt.Qt.KeepAspectRatio)
+ super(RadarView, self).resizeEvent(event)
+
+ def setDataRect(self, left, top, width, height):
+ """Set the bounds of the data rectangular area.
+
+ This sets the coordinate system.
+ """
+ self._dataRect.setRect(left, top, width, height)
+ self._visibleRect.setConstraintRect(left, top, width, height)
+ self.fitInView(self._scene.itemsBoundingRect(), qt.Qt.KeepAspectRatio)
+
+ def setVisibleRect(self, left, top, width, height):
+ """Set the visible rectangular area.
+
+ The coordinates are relative to the data rect.
+ """
+ self.__visibleRect = left, top, width, height
+ self._visibleRect.setRect(0, 0, width, height)
+ self._visibleRect.setPos(left, top)
+ self.fitInView(self._scene.itemsBoundingRect(), qt.Qt.KeepAspectRatio)
+
+ def __setVisibleRectFromPlot(self, plot):
+ """Update radar view visible area.
+
+ Takes care of y coordinate conversion.
+ """
+ xMin, xMax = plot.getXAxis().getLimits()
+ yMin, yMax = plot.getYAxis().getLimits()
+ self.setVisibleRect(xMin, yMin, xMax - xMin, yMax - yMin)
+
+ def getPlotWidget(self):
+ """Returns the connected plot
+
+ :rtype: Union[None,PlotWidget]
+ """
+ if self.__plotRef is None:
+ return None
+ plot = self.__plotRef()
+ if plot is None:
+ self.__plotRef = None
+ return plot
+
+ def setPlotWidget(self, plot):
+ """Set the PlotWidget this radar view connects to.
+
+ As result `setDataRect` and `setVisibleRect` will be called
+ automatically.
+
+ :param Union[None,PlotWidget] plot:
+ """
+ previousPlot = self.getPlotWidget()
+ if previousPlot is not None: # Disconnect previous plot
+ plot.getXAxis().sigLimitsChanged.disconnect(self._xLimitChanged)
+ plot.getYAxis().sigLimitsChanged.disconnect(self._yLimitChanged)
+ plot.getYAxis().sigInvertedChanged.disconnect(self._updateYAxisInverted)
+
+ # Reset plot and timer
+ # FIXME: It would be good to clean up the display here
+ self.__plotRef = None
+ self.__timer.stop()
+
+ if plot is not None: # Connect new plot
+ self.__plotRef = weakref.ref(plot)
+ plot.getXAxis().sigLimitsChanged.connect(self._xLimitChanged)
+ plot.getYAxis().sigLimitsChanged.connect(self._yLimitChanged)
+ plot.getYAxis().sigInvertedChanged.connect(self._updateYAxisInverted)
+ self.__setVisibleRectFromPlot(plot)
+ self._updateYAxisInverted()
+ self.__timer.start(500)
+
+ def _xLimitChanged(self, vmin, vmax):
+ plot = self.getPlotWidget()
+ self.__setVisibleRectFromPlot(plot)
+
+ def _yLimitChanged(self, vmin, vmax):
+ plot = self.getPlotWidget()
+ self.__setVisibleRectFromPlot(plot)
+
+ def _updateYAxisInverted(self, inverted=None):
+ """Sync radar view axis orientation."""
+ plot = self.getPlotWidget()
+ if inverted is None:
+ # Do not perform this when called from plot signal
+ inverted = plot.getYAxis().isInverted()
+ # Use scale to invert radarView
+ # RadarView default Y direction is from top to bottom
+ # As opposed to Plot. So invert RadarView when Plot is NOT inverted.
+ self.resetTransform()
+ if not inverted:
+ self.scale(1., -1.)
+ self.update()
+
+ def _viewRectDragged(self, left, top, width, height):
+ """Slot for radar view visible rectangle changes."""
+ plot = self.getPlotWidget()
+ if plot is None:
+ return
+
+ if self.__reentrant.locked():
+ return
+
+ with self.__reentrant:
+ plot.setLimits(left, left + width, top, top + height)
+
+ def _updateDataContent(self):
+ """Update the content to the current data content"""
+ plot = self.getPlotWidget()
+ if plot is None:
+ return
+ ranges = plot.getDataRange()
+ xmin, xmax = ranges.x if ranges.x is not None else (0, 0)
+ ymin, ymax = ranges.y if ranges.y is not None else (0, 0)
+ self.setDataRect(xmin, ymin, xmax - xmin, ymax - ymin)
+
+ self.__updateItem(self._imageRect, plot.getActiveImage())
+ self.__updateItem(self._scatterRect, plot.getActiveScatter())
+ self.__updateItem(self._curveRect, plot.getActiveCurve())
+
+ def __updateItem(self, rect, item):
+ """Sync rect with item bounds
+
+ :param QGraphicsRectItem rect:
+ :param Item item:
+ """
+ if item is None:
+ rect.setVisible(False)
+ return
+ ranges = item._getBounds()
+ if ranges is None:
+ rect.setVisible(False)
+ return
+ xmin, xmax, ymin, ymax = ranges
+ width = xmax - xmin
+ height = ymax - ymin
+ rect.setRect(xmin, ymin, width, height)
+ rect.setVisible(True)
diff --git a/src/silx/gui/plot/tools/__init__.py b/src/silx/gui/plot/tools/__init__.py
new file mode 100644
index 0000000..09f468c
--- /dev/null
+++ b/src/silx/gui/plot/tools/__init__.py
@@ -0,0 +1,50 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This package provides a set of widgets working with :class:`PlotWidget`.
+
+It provides some QToolBar and QWidget:
+
+- :class:`InteractiveModeToolBar`
+- :class:`OutputToolBar`
+- :class:`ImageToolBar`
+- :class:`CurveToolBar`
+- :class:`LimitsToolBar`
+- :class:`PositionInfo`
+
+It also provides a :mod:`~silx.gui.plot.tools.roi` module to handle
+interactive region of interest on a :class:`~silx.gui.plot.PlotWidget`.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "01/03/2018"
+
+
+from .toolbars import InteractiveModeToolBar # noqa
+from .toolbars import OutputToolBar # noqa
+from .toolbars import ImageToolBar, CurveToolBar, ScatterToolBar # noqa
+
+from .LimitsToolBar import LimitsToolBar # noqa
+from .PositionInfo import PositionInfo # noqa
diff --git a/src/silx/gui/plot/tools/profile/ScatterProfileToolBar.py b/src/silx/gui/plot/tools/profile/ScatterProfileToolBar.py
new file mode 100644
index 0000000..44187ef
--- /dev/null
+++ b/src/silx/gui/plot/tools/profile/ScatterProfileToolBar.py
@@ -0,0 +1,54 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module profile tools for scatter plots.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "28/06/2018"
+
+
+from silx.utils import deprecation
+from . import toolbar
+
+
+class ScatterProfileToolBar(toolbar.ProfileToolBar):
+ """QToolBar providing scatter plot profiling tools
+
+ :param parent: See :class:`QToolBar`.
+ :param plot: :class:`~silx.gui.plot.PlotWidget` on which to operate.
+ :param str title: See :class:`QToolBar`.
+ """
+
+ def __init__(self, parent=None, plot=None, title=None):
+ super(ScatterProfileToolBar, self).__init__(parent, plot)
+ if title is not None:
+ deprecation.deprecated_warning("Attribute",
+ name="title",
+ reason="removed",
+ since_version="0.13.0",
+ only_once=True,
+ skip_backtrace_count=1)
+ self.setScheme("scatter")
diff --git a/src/silx/gui/plot/tools/profile/__init__.py b/src/silx/gui/plot/tools/profile/__init__.py
new file mode 100644
index 0000000..d91191e
--- /dev/null
+++ b/src/silx/gui/plot/tools/profile/__init__.py
@@ -0,0 +1,38 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides tools to get profiles on plot data.
+
+It provides:
+
+- :class:`ScatterProfileToolBar`: a QToolBar to handle profile on scatter data
+
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "07/06/2018"
+
+
+from .ScatterProfileToolBar import ScatterProfileToolBar # noqa
diff --git a/src/silx/gui/plot/tools/profile/core.py b/src/silx/gui/plot/tools/profile/core.py
new file mode 100644
index 0000000..200f5cf
--- /dev/null
+++ b/src/silx/gui/plot/tools/profile/core.py
@@ -0,0 +1,525 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module define core objects for profile tools.
+"""
+
+__authors__ = ["V.A. Sole", "T. Vincent", "P. Knobel", "H. Payno", "V. Valls"]
+__license__ = "MIT"
+__date__ = "17/04/2020"
+
+import collections
+import numpy
+import weakref
+
+from silx.image.bilinear import BilinearImage
+from silx.gui import qt
+
+
+CurveProfileData = collections.namedtuple(
+ 'CurveProfileData', [
+ "coords",
+ "profile",
+ "title",
+ "xLabel",
+ "yLabel",
+ ])
+
+RgbaProfileData = collections.namedtuple(
+ 'RgbaProfileData', [
+ "coords",
+ "profile",
+ "profile_r",
+ "profile_g",
+ "profile_b",
+ "profile_a",
+ "title",
+ "xLabel",
+ "yLabel",
+ ])
+
+ImageProfileData = collections.namedtuple(
+ 'ImageProfileData', [
+ 'coords',
+ 'profile',
+ 'title',
+ 'xLabel',
+ 'yLabel',
+ 'colormap',
+ ])
+
+
+class ProfileRoiMixIn:
+ """Base mix-in for ROI which can be used to select a profile.
+
+ This mix-in have to be applied to a :class:`~silx.gui.plot.items.roi.RegionOfInterest`
+ in order to be usable by a :class:`~silx.gui.plot.tools.profile.manager.ProfileManager`.
+ """
+
+ ITEM_KIND = None
+ """Define the plot item which can be used with this profile ROI"""
+
+ sigProfilePropertyChanged = qt.Signal()
+ """Emitted when a property of this profile have changed"""
+
+ sigPlotItemChanged = qt.Signal()
+ """Emitted when the plot item linked to this profile have changed"""
+
+ def __init__(self, parent=None):
+ self.__profileWindow = None
+ self.__profileManager = None
+ self.__plotItem = None
+ self.setName("Profile")
+ self.setEditable(True)
+ self.setSelectable(True)
+
+ def invalidateProfile(self):
+ """Must be called by the implementation when the profile have to be
+ recomputed."""
+ profileManager = self.getProfileManager()
+ if profileManager is not None:
+ profileManager.requestUpdateProfile(self)
+
+ def invalidateProperties(self):
+ """Must be called when a property of the profile have changed."""
+ self.sigProfilePropertyChanged.emit()
+
+ def _setPlotItem(self, plotItem):
+ """Specify the plot item to use with this profile
+
+ :param `~silx.gui.plot.items.item.Item` plotItem: A plot item
+ """
+ previousPlotItem = self.getPlotItem()
+ if previousPlotItem is plotItem:
+ return
+ self.__plotItem = weakref.ref(plotItem)
+ self.sigPlotItemChanged.emit()
+
+ def getPlotItem(self):
+ """Returns the plot item used by this profile
+
+ :rtype: `~silx.gui.plot.items.item.Item`
+ """
+ if self.__plotItem is None:
+ return None
+ plotItem = self.__plotItem()
+ if plotItem is None:
+ self.__plotItem = None
+ return plotItem
+
+ def _setProfileManager(self, profileManager):
+ self.__profileManager = profileManager
+
+ def getProfileManager(self):
+ """
+ Returns the profile manager connected to this ROI.
+
+ :rtype: ~silx.gui.plot.tools.profile.manager.ProfileManager
+ """
+ return self.__profileManager
+
+ def getProfileWindow(self):
+ """
+ Returns the windows associated to this ROI, else None.
+
+ :rtype: ProfileWindow
+ """
+ return self.__profileWindow
+
+ def setProfileWindow(self, profileWindow):
+ """
+ Associate a window to this ROI. Can be None.
+
+ :param ProfileWindow profileWindow: A main window
+ to display the profile.
+ """
+ if profileWindow is self.__profileWindow:
+ return
+ if self.__profileWindow is not None:
+ self.__profileWindow.sigClose.disconnect(self.__profileWindowAboutToClose)
+ self.__profileWindow.setRoiProfile(None)
+ self.__profileWindow = profileWindow
+ if self.__profileWindow is not None:
+ self.__profileWindow.sigClose.connect(self.__profileWindowAboutToClose)
+ self.__profileWindow.setRoiProfile(self)
+
+ def __profileWindowAboutToClose(self):
+ profileManager = self.getProfileManager()
+ roiManager = profileManager.getRoiManager()
+ try:
+ roiManager.removeRoi(self)
+ except ValueError:
+ pass
+
+ def computeProfile(self, item):
+ """
+ Compute the profile which will be displayed.
+
+ This method is not called from the main Qt thread, but from a thread
+ pool.
+
+ :param ~silx.gui.plot.items.Item item: A plot item
+ :rtype: Union[CurveProfileData,ImageProfileData]
+ """
+ raise NotImplementedError()
+
+
+def _alignedFullProfile(data, origin, scale, position, roiWidth, axis, method):
+ """Get a profile along one axis on a stack of images
+
+ :param numpy.ndarray data: 3D volume (stack of 2D images)
+ The first dimension is the image index.
+ :param origin: Origin of image in plot (ox, oy)
+ :param scale: Scale of image in plot (sx, sy)
+ :param float position: Position of profile line in plot coords
+ on the axis orthogonal to the profile direction.
+ :param int roiWidth: Width of the profile in image pixels.
+ :param int axis: 0 for horizontal profile, 1 for vertical.
+ :param str method: method to compute the profile. Can be 'mean' or 'sum' or
+ 'none'
+ :return: profile image + effective ROI area corners in plot coords
+ """
+ assert axis in (0, 1)
+ assert len(data.shape) == 3
+ assert method in ('mean', 'sum', 'none')
+
+ # Convert from plot to image coords
+ imgPos = int((position - origin[1 - axis]) / scale[1 - axis])
+
+ if axis == 1: # Vertical profile
+ # Transpose image to always do a horizontal profile
+ data = numpy.transpose(data, (0, 2, 1))
+
+ nimages, height, width = data.shape
+
+ roiWidth = min(height, roiWidth) # Clip roi width to image size
+
+ # Get [start, end[ coords of the roi in the data
+ start = int(int(imgPos) + 0.5 - roiWidth / 2.)
+ start = min(max(0, start), height - roiWidth)
+ end = start + roiWidth
+
+ if method == 'none':
+ profile = None
+ else:
+ if start < height and end > 0:
+ if method == 'mean':
+ fct = numpy.mean
+ elif method == 'sum':
+ fct = numpy.sum
+ else:
+ raise ValueError('method not managed')
+ profile = fct(data[:, max(0, start):min(end, height), :], axis=1).astype(numpy.float32)
+ else:
+ profile = numpy.zeros((nimages, width), dtype=numpy.float32)
+
+ # Compute effective ROI in plot coords
+ profileBounds = numpy.array(
+ (0, width, width, 0),
+ dtype=numpy.float32) * scale[axis] + origin[axis]
+ roiBounds = numpy.array(
+ (start, start, end, end),
+ dtype=numpy.float32) * scale[1 - axis] + origin[1 - axis]
+
+ if axis == 0: # Horizontal profile
+ area = profileBounds, roiBounds
+ else: # vertical profile
+ area = roiBounds, profileBounds
+
+ return profile, area
+
+
+def _alignedPartialProfile(data, rowRange, colRange, axis, method):
+ """Mean of a rectangular region (ROI) of a stack of images
+ along a given axis.
+
+ Returned values and all parameters are in image coordinates.
+
+ :param numpy.ndarray data: 3D volume (stack of 2D images)
+ The first dimension is the image index.
+ :param rowRange: [min, max[ of ROI rows (upper bound excluded).
+ :type rowRange: 2-tuple of int (min, max) with min < max
+ :param colRange: [min, max[ of ROI columns (upper bound excluded).
+ :type colRange: 2-tuple of int (min, max) with min < max
+ :param int axis: The axis along which to take the profile of the ROI.
+ 0: Sum rows along columns.
+ 1: Sum columns along rows.
+ :param str method: method to compute the profile. Can be 'mean' or 'sum'
+ :return: Profile image along the ROI as the mean of the intersection
+ of the ROI and the image.
+ """
+ assert axis in (0, 1)
+ assert len(data.shape) == 3
+ assert rowRange[0] < rowRange[1]
+ assert colRange[0] < colRange[1]
+ assert method in ('mean', 'sum')
+
+ nimages, height, width = data.shape
+
+ # Range aligned with the integration direction
+ profileRange = colRange if axis == 0 else rowRange
+
+ profileLength = abs(profileRange[1] - profileRange[0])
+
+ # Subset of the image to use as intersection of ROI and image
+ rowStart = min(max(0, rowRange[0]), height)
+ rowEnd = min(max(0, rowRange[1]), height)
+ colStart = min(max(0, colRange[0]), width)
+ colEnd = min(max(0, colRange[1]), width)
+
+ if method == 'mean':
+ _fct = numpy.mean
+ elif method == 'sum':
+ _fct = numpy.sum
+ else:
+ raise ValueError('method not managed')
+
+ imgProfile = _fct(data[:, rowStart:rowEnd, colStart:colEnd], axis=axis + 1,
+ dtype=numpy.float32)
+
+ # Profile including out of bound area
+ profile = numpy.zeros((nimages, profileLength), dtype=numpy.float32)
+
+ # Place imgProfile in full profile
+ offset = - min(0, profileRange[0])
+ profile[:, offset:offset + imgProfile.shape[1]] = imgProfile
+
+ return profile
+
+
+def createProfile(roiInfo, currentData, origin, scale, lineWidth, method):
+ """Create the profile line for the the given image.
+
+ :param roiInfo: information about the ROI: start point, end point and
+ type ("X", "Y", "D")
+ :param numpy.ndarray currentData: the 2D image or the 3D stack of images
+ on which we compute the profile.
+ :param origin: (ox, oy) the offset from origin
+ :type origin: 2-tuple of float
+ :param scale: (sx, sy) the scale to use
+ :type scale: 2-tuple of float
+ :param int lineWidth: width of the profile line
+ :param str method: method to compute the profile. Can be 'mean' or 'sum'
+ or 'none': to compute everything except the profile
+ :return: `coords, profile, area, profileName, xLabel`, where:
+ - coords is the X coordinate to use to display the profile
+ - profile is a 2D array of the profiles of the stack of images.
+ For a single image, the profile is a curve, so this parameter
+ has a shape *(1, len(curve))*
+ - area is a tuple of two 1D arrays with 4 values each. They represent
+ the effective ROI area corners in plot coords.
+ - profileName is a string describing the ROI, meant to be used as
+ title of the profile plot
+ - xLabel the label for X in the profile window
+
+ :rtype: tuple(ndarray,ndarray,(ndarray,ndarray),str)
+ """
+ if currentData is None or roiInfo is None or lineWidth is None:
+ raise ValueError("createProfile called with invalide arguments")
+
+ # force 3D data (stack of images)
+ if len(currentData.shape) == 2:
+ currentData3D = currentData.reshape((1,) + currentData.shape)
+ elif len(currentData.shape) == 3:
+ currentData3D = currentData
+
+ roiWidth = max(1, lineWidth)
+ roiStart, roiEnd, lineProjectionMode = roiInfo
+
+ if lineProjectionMode == 'X': # Horizontal profile on the whole image
+ profile, area = _alignedFullProfile(currentData3D,
+ origin, scale,
+ roiStart[1], roiWidth,
+ axis=0,
+ method=method)
+
+ if method == 'none':
+ coords = None
+ else:
+ coords = numpy.arange(len(profile[0]), dtype=numpy.float32)
+ coords = coords * scale[0] + origin[0]
+
+ yMin, yMax = min(area[1]), max(area[1]) - 1
+ if roiWidth <= 1:
+ profileName = '{ylabel} = %g' % yMin
+ else:
+ profileName = '{ylabel} = [%g, %g]' % (yMin, yMax)
+ xLabel = '{xlabel}'
+
+ elif lineProjectionMode == 'Y': # Vertical profile on the whole image
+ profile, area = _alignedFullProfile(currentData3D,
+ origin, scale,
+ roiStart[0], roiWidth,
+ axis=1,
+ method=method)
+
+ if method == 'none':
+ coords = None
+ else:
+ coords = numpy.arange(len(profile[0]), dtype=numpy.float32)
+ coords = coords * scale[1] + origin[1]
+
+ xMin, xMax = min(area[0]), max(area[0]) - 1
+ if roiWidth <= 1:
+ profileName = '{xlabel} = %g' % xMin
+ else:
+ profileName = '{xlabel} = [%g, %g]' % (xMin, xMax)
+ xLabel = '{ylabel}'
+
+ else: # Free line profile
+
+ # Convert start and end points in image coords as (row, col)
+ startPt = ((roiStart[1] - origin[1]) / scale[1],
+ (roiStart[0] - origin[0]) / scale[0])
+ endPt = ((roiEnd[1] - origin[1]) / scale[1],
+ (roiEnd[0] - origin[0]) / scale[0])
+
+ if (int(startPt[0]) == int(endPt[0]) or
+ int(startPt[1]) == int(endPt[1])):
+ # Profile is aligned with one of the axes
+
+ # Convert to int
+ startPt = int(startPt[0]), int(startPt[1])
+ endPt = int(endPt[0]), int(endPt[1])
+
+ # Ensure startPt <= endPt
+ if startPt[0] > endPt[0] or startPt[1] > endPt[1]:
+ startPt, endPt = endPt, startPt
+
+ if startPt[0] == endPt[0]: # Row aligned
+ rowRange = (int(startPt[0] + 0.5 - 0.5 * roiWidth),
+ int(startPt[0] + 0.5 + 0.5 * roiWidth))
+ colRange = startPt[1], endPt[1] + 1
+ if method == 'none':
+ profile = None
+ else:
+ profile = _alignedPartialProfile(currentData3D,
+ rowRange, colRange,
+ axis=0,
+ method=method)
+
+ else: # Column aligned
+ rowRange = startPt[0], endPt[0] + 1
+ colRange = (int(startPt[1] + 0.5 - 0.5 * roiWidth),
+ int(startPt[1] + 0.5 + 0.5 * roiWidth))
+ if method == 'none':
+ profile = None
+ else:
+ profile = _alignedPartialProfile(currentData3D,
+ rowRange, colRange,
+ axis=1,
+ method=method)
+ # Convert ranges to plot coords to draw ROI area
+ area = (
+ numpy.array(
+ (colRange[0], colRange[1], colRange[1], colRange[0]),
+ dtype=numpy.float32) * scale[0] + origin[0],
+ numpy.array(
+ (rowRange[0], rowRange[0], rowRange[1], rowRange[1]),
+ dtype=numpy.float32) * scale[1] + origin[1])
+
+ else: # General case: use bilinear interpolation
+
+ # Ensure startPt <= endPt
+ if (startPt[1] > endPt[1] or (
+ startPt[1] == endPt[1] and startPt[0] > endPt[0])):
+ startPt, endPt = endPt, startPt
+
+ if method == 'none':
+ profile = None
+ else:
+ profile = []
+ for slice_idx in range(currentData3D.shape[0]):
+ bilinear = BilinearImage(currentData3D[slice_idx, :, :])
+
+ profile.append(bilinear.profile_line(
+ (startPt[0] - 0.5, startPt[1] - 0.5),
+ (endPt[0] - 0.5, endPt[1] - 0.5),
+ roiWidth,
+ method=method))
+ profile = numpy.array(profile)
+
+ # Extend ROI with half a pixel on each end, and
+ # Convert back to plot coords (x, y)
+ length = numpy.sqrt((endPt[0] - startPt[0]) ** 2 +
+ (endPt[1] - startPt[1]) ** 2)
+ dRow = (endPt[0] - startPt[0]) / length
+ dCol = (endPt[1] - startPt[1]) / length
+
+ # Extend ROI with half a pixel on each end
+ roiStartPt = startPt[0] - 0.5 * dRow, startPt[1] - 0.5 * dCol
+ roiEndPt = endPt[0] + 0.5 * dRow, endPt[1] + 0.5 * dCol
+
+ # Rotate deltas by 90 degrees to apply line width
+ dRow, dCol = dCol, -dRow
+
+ area = (
+ numpy.array((roiStartPt[1] - 0.5 * roiWidth * dCol,
+ roiStartPt[1] + 0.5 * roiWidth * dCol,
+ roiEndPt[1] + 0.5 * roiWidth * dCol,
+ roiEndPt[1] - 0.5 * roiWidth * dCol),
+ dtype=numpy.float32) * scale[0] + origin[0],
+ numpy.array((roiStartPt[0] - 0.5 * roiWidth * dRow,
+ roiStartPt[0] + 0.5 * roiWidth * dRow,
+ roiEndPt[0] + 0.5 * roiWidth * dRow,
+ roiEndPt[0] - 0.5 * roiWidth * dRow),
+ dtype=numpy.float32) * scale[1] + origin[1])
+
+ # Convert start and end points back to plot coords
+ y0 = startPt[0] * scale[1] + origin[1]
+ x0 = startPt[1] * scale[0] + origin[0]
+ y1 = endPt[0] * scale[1] + origin[1]
+ x1 = endPt[1] * scale[0] + origin[0]
+
+ if startPt[1] == endPt[1]:
+ profileName = '{xlabel} = %g; {ylabel} = [%g, %g]' % (x0, y0, y1)
+ if method == 'none':
+ coords = None
+ else:
+ coords = numpy.arange(len(profile[0]), dtype=numpy.float32)
+ coords = coords * scale[1] + y0
+ xLabel = '{ylabel}'
+
+ elif startPt[0] == endPt[0]:
+ profileName = '{ylabel} = %g; {xlabel} = [%g, %g]' % (y0, x0, x1)
+ if method == 'none':
+ coords = None
+ else:
+ coords = numpy.arange(len(profile[0]), dtype=numpy.float32)
+ coords = coords * scale[0] + x0
+ xLabel = '{xlabel}'
+
+ else:
+ m = (y1 - y0) / (x1 - x0)
+ b = y0 - m * x0
+ profileName = '{ylabel} = %g * {xlabel} %+g' % (m, b)
+ if method == 'none':
+ coords = None
+ else:
+ coords = numpy.linspace(x0, x1, len(profile[0]),
+ endpoint=True,
+ dtype=numpy.float32)
+ xLabel = '{xlabel}'
+
+ return coords, profile, area, profileName, xLabel
diff --git a/src/silx/gui/plot/tools/profile/editors.py b/src/silx/gui/plot/tools/profile/editors.py
new file mode 100644
index 0000000..80e0452
--- /dev/null
+++ b/src/silx/gui/plot/tools/profile/editors.py
@@ -0,0 +1,307 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides editors which are used to custom profile ROI properties.
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "28/06/2018"
+
+import logging
+
+from silx.gui import qt
+
+from silx.gui.utils import blockSignals
+from silx.gui.plot.PlotToolButtons import ProfileOptionToolButton
+from silx.gui.plot.PlotToolButtons import ProfileToolButton
+from . import rois
+from . import core
+
+
+_logger = logging.getLogger(__name__)
+
+
+class _NoProfileRoiEditor(qt.QWidget):
+
+ sigDataCommited = qt.Signal()
+
+ def setEditorData(self, roi):
+ pass
+
+ def setRoiData(self, roi):
+ pass
+
+
+class _DefaultImageProfileRoiEditor(qt.QWidget):
+
+ sigDataCommited = qt.Signal()
+
+ def __init__(self, parent=None):
+ qt.QWidget.__init__(self, parent=parent)
+ layout = qt.QHBoxLayout(self)
+ layout.setContentsMargins(0, 0, 0, 0)
+ self._initLayout(layout)
+
+ def _initLayout(self, layout):
+ self._lineWidth = qt.QSpinBox(self)
+ self._lineWidth.setRange(1, 1000)
+ self._lineWidth.setValue(1)
+ self._lineWidth.valueChanged[int].connect(self._widgetChanged)
+
+ self._methodsButton = ProfileOptionToolButton(parent=self, plot=None)
+ self._methodsButton.sigMethodChanged.connect(self._widgetChanged)
+
+ label = qt.QLabel('W:')
+ label.setToolTip("Line width in pixels")
+ layout.addWidget(label)
+ layout.addWidget(self._lineWidth)
+ layout.addWidget(self._methodsButton)
+
+ def _widgetChanged(self, value=None):
+ self.commitData()
+
+ def commitData(self):
+ self.sigDataCommited.emit()
+
+ def setEditorData(self, roi):
+ with blockSignals(self._lineWidth):
+ self._lineWidth.setValue(roi.getProfileLineWidth())
+ with blockSignals(self._methodsButton):
+ method = roi.getProfileMethod()
+ self._methodsButton.setMethod(method)
+
+ def setRoiData(self, roi):
+ lineWidth = self._lineWidth.value()
+ roi.setProfileLineWidth(lineWidth)
+ method = self._methodsButton.getMethod()
+ roi.setProfileMethod(method)
+
+
+class _DefaultImageStackProfileRoiEditor(_DefaultImageProfileRoiEditor):
+
+ def _initLayout(self, layout):
+ super(_DefaultImageStackProfileRoiEditor, self)._initLayout(layout)
+ self._profileDim = ProfileToolButton(parent=self, plot=None)
+ self._profileDim.sigDimensionChanged.connect(self._widgetChanged)
+ layout.addWidget(self._profileDim)
+
+ def setEditorData(self, roi):
+ super(_DefaultImageStackProfileRoiEditor, self).setEditorData(roi)
+ with blockSignals(self._profileDim):
+ kind = roi.getProfileType()
+ dim = {"1D": 1, "2D": 2}[kind]
+ self._profileDim.setDimension(dim)
+
+ def setRoiData(self, roi):
+ super(_DefaultImageStackProfileRoiEditor, self).setRoiData(roi)
+ dim = self._profileDim.getDimension()
+ kind = {1: "1D", 2: "2D"}[dim]
+ roi.setProfileType(kind)
+
+
+class _DefaultScatterProfileRoiEditor(qt.QWidget):
+
+ sigDataCommited = qt.Signal()
+
+ def __init__(self, parent=None):
+ qt.QWidget.__init__(self, parent=parent)
+
+ self._nPoints = qt.QSpinBox(self)
+ self._nPoints.setRange(1, 9999)
+ self._nPoints.setValue(1024)
+ self._nPoints.valueChanged[int].connect(self.__widgetChanged)
+
+ layout = qt.QHBoxLayout(self)
+ layout.setContentsMargins(0, 0, 0, 0)
+ label = qt.QLabel('Samples:')
+ label.setToolTip("Number of sample points of the profile")
+ layout.addWidget(label)
+ layout.addWidget(self._nPoints)
+
+ def __widgetChanged(self, value=None):
+ self.commitData()
+
+ def commitData(self):
+ self.sigDataCommited.emit()
+
+ def setEditorData(self, roi):
+ with blockSignals(self._nPoints):
+ self._nPoints.setValue(roi.getNPoints())
+
+ def setRoiData(self, roi):
+ nPoints = self._nPoints.value()
+ roi.setNPoints(nPoints)
+
+
+class ProfileRoiEditorAction(qt.QWidgetAction):
+ """
+ Action displaying GUI to edit the selected ROI.
+
+ :param qt.QWidget parent: Parent widget
+ """
+ def __init__(self, parent=None):
+ super(ProfileRoiEditorAction, self).__init__(parent)
+ self.__roiManager = None
+ self.__roi = None
+ self.__inhibiteReentance = None
+
+ def createWidget(self, parent):
+ """Inherit the method to create a new editor"""
+ widget = qt.QWidget(parent)
+ layout = qt.QHBoxLayout(widget)
+ if isinstance(parent, qt.QMenu):
+ margins = layout.contentsMargins()
+ layout.setContentsMargins(margins.left(), 0, margins.right(), 0)
+ else:
+ layout.setContentsMargins(0, 0, 0, 0)
+
+ editorClass = self.getEditorClass(self.__roi)
+ editor = editorClass(parent)
+ editor.setEditorData(self.__roi)
+ self.__setEditor(widget, editor)
+ return widget
+
+ def deleteWidget(self, widget):
+ """Inherit the method to delete an editor"""
+ self.__setEditor(widget, None)
+ return qt.QWidgetAction.deleteWidget(self, widget)
+
+ def _getEditor(self, widget):
+ """Returns the editor contained in the widget holder"""
+ layout = widget.layout()
+ if layout.count() == 0:
+ return None
+ return layout.itemAt(0).widget()
+
+ def setRoiManager(self, roiManager):
+ """
+ Connect this action to a ROI manager.
+
+ :param RegionOfInterestManager roiManager: A ROI manager
+ """
+ if self.__roiManager is roiManager:
+ return
+ if self.__roiManager is not None:
+ self.__roiManager.sigCurrentRoiChanged.disconnect(self.__currentRoiChanged)
+ self.__roiManager = roiManager
+ if self.__roiManager is not None:
+ self.__roiManager.sigCurrentRoiChanged.connect(self.__currentRoiChanged)
+ self.__currentRoiChanged(roiManager.getCurrentRoi())
+
+ def __currentRoiChanged(self, roi):
+ """Handle changes of the selected ROI"""
+ if roi is not None and not isinstance(roi, core.ProfileRoiMixIn):
+ return
+ self.setProfileRoi(roi)
+
+ def setProfileRoi(self, roi):
+ """Set a profile ROI to edit.
+
+ :param ProfileRoiMixIn roi: A profile ROI
+ """
+ if self.__roi is roi:
+ return
+ if self.__roi is not None:
+ self.__roi.sigProfilePropertyChanged.disconnect(self.__roiPropertyChanged)
+ self.__roi = roi
+ if self.__roi is not None:
+ self.__roi.sigProfilePropertyChanged.connect(self.__roiPropertyChanged)
+ self._updateWidgets()
+
+ def getRoiProfile(self):
+ """Returns the edited profile ROI.
+
+ :rtype: ProfileRoiMixIn
+ """
+ return self.__roi
+
+ def __roiPropertyChanged(self):
+ """Handle changes on the property defining the ROI.
+ """
+ self._updateWidgetValues()
+
+ def __setEditor(self, widget, editor):
+ """Set the editor to display.
+
+ :param qt.QWidget editor: The editor to display
+ """
+ previousEditor = self._getEditor(widget)
+ if previousEditor is editor:
+ return
+ layout = widget.layout()
+ if previousEditor is not None:
+ previousEditor.sigDataCommited.disconnect(self._editorDataCommited)
+ layout.removeWidget(previousEditor)
+ previousEditor.deleteLater()
+ if editor is not None:
+ editor.sigDataCommited.connect(self._editorDataCommited)
+ layout.addWidget(editor)
+
+ def getEditorClass(self, roi):
+ """Returns the editor class to use according to the ROI."""
+ if roi is None:
+ editorClass = _NoProfileRoiEditor
+ elif isinstance(roi, (rois._DefaultImageStackProfileRoiMixIn,
+ rois.ProfileImageStackCrossROI)):
+ # Must be done before the default image ROI
+ # Cause ImageStack ROIs inherit from Image ROIs
+ editorClass = _DefaultImageStackProfileRoiEditor
+ elif isinstance(roi, (rois._DefaultImageProfileRoiMixIn,
+ rois.ProfileImageCrossROI)):
+ editorClass = _DefaultImageProfileRoiEditor
+ elif isinstance(roi, (rois._DefaultScatterProfileRoiMixIn,
+ rois.ProfileScatterCrossROI)):
+ editorClass = _DefaultScatterProfileRoiEditor
+ else:
+ # Unsupported
+ editorClass = _NoProfileRoiEditor
+ return editorClass
+
+ def _updateWidgets(self):
+ """Update the kind of editor to display, according to the selected
+ profile ROI."""
+ parent = self.parent()
+ editorClass = self.getEditorClass(self.__roi)
+ for widget in self.createdWidgets():
+ editor = editorClass(parent)
+ editor.setEditorData(self.__roi)
+ self.__setEditor(widget, editor)
+
+ def _updateWidgetValues(self):
+ """Update the content of the displayed editor, according to the
+ selected profile ROI."""
+ for widget in self.createdWidgets():
+ editor = self._getEditor(widget)
+ if self.__inhibiteReentance is editor:
+ continue
+ editor.setEditorData(self.__roi)
+
+ def _editorDataCommited(self):
+ """Handle changes from the editor."""
+ editor = self.sender()
+ if self.__roi is not None:
+ self.__inhibiteReentance = editor
+ editor.setRoiData(self.__roi)
+ self.__inhibiteReentance = None
diff --git a/src/silx/gui/plot/tools/profile/manager.py b/src/silx/gui/plot/tools/profile/manager.py
new file mode 100644
index 0000000..4a22bc0
--- /dev/null
+++ b/src/silx/gui/plot/tools/profile/manager.py
@@ -0,0 +1,1079 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a manager to compute and display profiles.
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "28/06/2018"
+
+import logging
+import weakref
+
+from silx.gui import qt
+from silx.gui import colors
+from silx.gui import utils
+
+from silx.utils.weakref import WeakMethodProxy
+from silx.gui import icons
+from silx.gui.plot import PlotWidget
+from silx.gui.plot.tools.roi import RegionOfInterestManager
+from silx.gui.plot.tools.roi import CreateRoiModeAction
+from silx.gui.plot import items
+from silx.gui.qt import silxGlobalThreadPool
+from silx.gui.qt import inspect
+from . import rois
+from . import core
+from . import editors
+
+
+_logger = logging.getLogger(__name__)
+
+
+class _RunnableComputeProfile(qt.QRunnable):
+ """Runner to process profiles
+
+ :param qt.QThreadPool threadPool: The thread which will be used to
+ execute this runner. It is used to update the used signals
+ :param ~silx.gui.plot.items.Item item: Item in which the profile is
+ computed
+ :param ~silx.gui.plot.tools.profile.core.ProfileRoiMixIn roi: ROI
+ defining the profile shape and other characteristics
+ """
+
+ class _Signals(qt.QObject):
+ """Signal holder"""
+ resultReady = qt.Signal(object, object)
+ runnerFinished = qt.Signal(object)
+
+ def __init__(self, threadPool, item, roi):
+ """Constructor
+ """
+ super(_RunnableComputeProfile, self).__init__()
+ self._signals = self._Signals()
+ self._signals.moveToThread(threadPool.thread())
+ self._item = item
+ self._roi = roi
+ self._cancelled = False
+
+ def _lazyCancel(self):
+ """Cancel the runner if it is not yet started.
+
+ The threadpool will still execute the runner, but this will process
+ nothing.
+
+ This is only used with Qt<5.9 where QThreadPool.tryTake is not available.
+ """
+ self._cancelled = True
+
+ def autoDelete(self):
+ return False
+
+ def getRoi(self):
+ """Returns the ROI in which the runner will compute a profile.
+
+ :rtype: ~silx.gui.plot.tools.profile.core.ProfileRoiMixIn
+ """
+ return self._roi
+
+ @property
+ def resultReady(self):
+ """Signal emitted when the result of the computation is available.
+
+ This signal provides 2 values: The ROI, and the computation result.
+ """
+ return self._signals.resultReady
+
+ @property
+ def runnerFinished(self):
+ """Signal emitted when runner have finished.
+
+ This signal provides a single value: the runner itself.
+ """
+ return self._signals.runnerFinished
+
+ def run(self):
+ """Process the profile computation.
+ """
+ if not self._cancelled:
+ try:
+ profileData = self._roi.computeProfile(self._item)
+ except Exception:
+ _logger.error("Error while computing profile", exc_info=True)
+ else:
+ self.resultReady.emit(self._roi, profileData)
+ self.runnerFinished.emit(self)
+
+
+class ProfileWindow(qt.QMainWindow):
+ """
+ Display a computed profile.
+
+ The content can be described using :meth:`setRoiProfile` if the source of
+ the profile is a profile ROI, and :meth:`setProfile` for the data content.
+ """
+
+ sigClose = qt.Signal()
+ """Emitted by :meth:`closeEvent` (e.g. when the window is closed
+ through the window manager's close icon)."""
+
+ def __init__(self, parent=None, backend=None):
+ qt.QMainWindow.__init__(self, parent=parent, flags=qt.Qt.Dialog)
+
+ self.setWindowTitle('Profile window')
+ self._plot1D = None
+ self._plot2D = None
+ self._backend = backend
+ self._data = None
+
+ widget = qt.QWidget()
+ self._layout = qt.QStackedLayout(widget)
+ self._layout.setContentsMargins(0, 0, 0, 0)
+ self.setCentralWidget(widget)
+
+ def prepareWidget(self, roi):
+ """Called before the show to prepare the window to use with
+ a specific ROI."""
+ if isinstance(roi, rois._DefaultImageStackProfileRoiMixIn):
+ profileType = roi.getProfileType()
+ else:
+ profileType = "1D"
+ if profileType == "1D":
+ self.getPlot1D()
+ elif profileType == "2D":
+ self.getPlot2D()
+
+ def createPlot1D(self, parent, backend):
+ """Inherit this function to create your own plot to render 1D
+ profiles. The default value is a `Plot1D`.
+
+ :param parent: The parent of this widget or None.
+ :param backend: The backend to use for the plot.
+ See :class:`PlotWidget` for the list of supported backend.
+ :rtype: PlotWidget
+ """
+ # import here to avoid circular import
+ from ...PlotWindow import Plot1D
+ plot = Plot1D(parent=parent, backend=backend)
+ plot.setDataMargins(yMinMargin=0.1, yMaxMargin=0.1)
+ plot.setGraphYLabel('Profile')
+ plot.setGraphXLabel('')
+ return plot
+
+ def createPlot2D(self, parent, backend):
+ """Inherit this function to create your own plot to render 2D
+ profiles. The default value is a `Plot2D`.
+
+ :param parent: The parent of this widget or None.
+ :param backend: The backend to use for the plot.
+ See :class:`PlotWidget` for the list of supported backend.
+ :rtype: PlotWidget
+ """
+ # import here to avoid circular import
+ from ...PlotWindow import Plot2D
+ return Plot2D(parent=parent, backend=backend)
+
+ def getPlot1D(self, init=True):
+ """Return the current plot used to display curves and create it if it
+ does not yet exists and `init` is True. Else returns None."""
+ if not init:
+ return self._plot1D
+ if self._plot1D is None:
+ self._plot1D = self.createPlot1D(self, self._backend)
+ self._layout.addWidget(self._plot1D)
+ return self._plot1D
+
+ def _showPlot1D(self):
+ plot = self.getPlot1D()
+ self._layout.setCurrentWidget(plot)
+
+ def getPlot2D(self, init=True):
+ """Return the current plot used to display image and create it if it
+ does not yet exists and `init` is True. Else returns None."""
+ if not init:
+ return self._plot2D
+ if self._plot2D is None:
+ self._plot2D = self.createPlot2D(parent=self, backend=self._backend)
+ self._layout.addWidget(self._plot2D)
+ return self._plot2D
+
+ def _showPlot2D(self):
+ plot = self.getPlot2D()
+ self._layout.setCurrentWidget(plot)
+
+ def getCurrentPlotWidget(self):
+ return self._layout.currentWidget()
+
+ def closeEvent(self, qCloseEvent):
+ self.sigClose.emit()
+ qCloseEvent.accept()
+
+ def setRoiProfile(self, roi):
+ """Set the profile ROI which it the source of the following data
+ to display.
+
+ :param ProfileRoiMixIn roi: The profile ROI data source
+ """
+ if roi is None:
+ return
+ self.__color = colors.rgba(roi.getColor())
+
+ def _setImageProfile(self, data):
+ """
+ Setup the window to display a new profile data which is represented
+ by an image.
+
+ :param core.ImageProfileData data: Computed data profile
+ """
+ plot = self.getPlot2D()
+
+ plot.clear()
+ plot.setGraphTitle(data.title)
+ plot.getXAxis().setLabel(data.xLabel)
+
+
+ coords = data.coords
+ colormap = data.colormap
+ profileScale = (coords[-1] - coords[0]) / data.profile.shape[1], 1
+ plot.addImage(data.profile,
+ legend="profile",
+ colormap=colormap,
+ origin=(coords[0], 0),
+ scale=profileScale)
+ plot.getYAxis().setLabel("Frame index (depth)")
+
+ self._showPlot2D()
+
+ def _setCurveProfile(self, data):
+ """
+ Setup the window to display a new profile data which is represented
+ by a curve.
+
+ :param core.CurveProfileData data: Computed data profile
+ """
+ plot = self.getPlot1D()
+
+ plot.clear()
+ plot.setGraphTitle(data.title)
+ plot.getXAxis().setLabel(data.xLabel)
+ plot.getYAxis().setLabel(data.yLabel)
+
+ plot.addCurve(data.coords,
+ data.profile,
+ legend="level",
+ color=self.__color)
+
+ self._showPlot1D()
+
+ def _setRgbaProfile(self, data):
+ """
+ Setup the window to display a new profile data which is represented
+ by a curve.
+
+ :param core.RgbaProfileData data: Computed data profile
+ """
+ plot = self.getPlot1D()
+
+ plot.clear()
+ plot.setGraphTitle(data.title)
+ plot.getXAxis().setLabel(data.xLabel)
+ plot.getYAxis().setLabel(data.yLabel)
+
+ self._showPlot1D()
+
+ plot.addCurve(data.coords, data.profile,
+ legend="level", color="black")
+ plot.addCurve(data.coords, data.profile_r,
+ legend="red", color="red")
+ plot.addCurve(data.coords, data.profile_g,
+ legend="green", color="green")
+ plot.addCurve(data.coords, data.profile_b,
+ legend="blue", color="blue")
+ if data.profile_a is not None:
+ plot.addCurve(data.coords, data.profile_a, legend="alpha", color="gray")
+
+ def clear(self):
+ """Clear the window profile"""
+ plot = self.getPlot1D(init=False)
+ if plot is not None:
+ plot.clear()
+ plot = self.getPlot2D(init=False)
+ if plot is not None:
+ plot.clear()
+
+ def getProfile(self):
+ """Returns the profile data which is displayed"""
+ return self.__data
+
+ def setProfile(self, data):
+ """
+ Setup the window to display a new profile data.
+
+ This method dispatch the result to a specific method according to the
+ data type.
+
+ :param data: Computed data profile
+ """
+ self.__data = data
+ if data is None:
+ self.clear()
+ elif isinstance(data, core.ImageProfileData):
+ self._setImageProfile(data)
+ elif isinstance(data, core.RgbaProfileData):
+ self._setRgbaProfile(data)
+ elif isinstance(data, core.CurveProfileData):
+ self._setCurveProfile(data)
+ else:
+ raise TypeError("Unsupported type %s" % type(data))
+
+
+class _ClearAction(qt.QAction):
+ """Action to clear the profile manager
+
+ The action is only enabled if something can be cleaned up.
+ """
+
+ def __init__(self, parent, profileManager):
+ super(_ClearAction, self).__init__(parent)
+ self.__profileManager = weakref.ref(profileManager)
+ icon = icons.getQIcon('profile-clear')
+ self.setIcon(icon)
+ self.setText('Clear profile')
+ self.setToolTip('Clear the profiles')
+ self.setCheckable(False)
+ self.setEnabled(False)
+ self.triggered.connect(profileManager.clearProfile)
+ plot = profileManager.getPlotWidget()
+ roiManager = profileManager.getRoiManager()
+ plot.sigInteractiveModeChanged.connect(self.__modeUpdated)
+ roiManager.sigRoiChanged.connect(self.__roiListUpdated)
+
+ def getProfileManager(self):
+ return self.__profileManager()
+
+ def __roiListUpdated(self):
+ self.__update()
+
+ def __modeUpdated(self, source):
+ self.__update()
+
+ def __update(self):
+ profileManager = self.getProfileManager()
+ if profileManager is None:
+ return
+ roiManager = profileManager.getRoiManager()
+ if roiManager is None:
+ return
+ enabled = roiManager.isStarted() or len(roiManager.getRois()) > 0
+ self.setEnabled(enabled)
+
+
+class _StoreLastParamBehavior(qt.QObject):
+ """This object allow to store and restore the properties of the ROI
+ profiles"""
+
+ def __init__(self, parent):
+ assert isinstance(parent, ProfileManager)
+ super(_StoreLastParamBehavior, self).__init__(parent=parent)
+ self.__properties = {}
+ self.__profileRoi = None
+ self.__filter = utils.LockReentrant()
+
+ def _roi(self):
+ """Return the spied ROI"""
+ if self.__profileRoi is None:
+ return None
+ roi = self.__profileRoi()
+ if roi is None:
+ self.__profileRoi = None
+ return roi
+
+ def setProfileRoi(self, roi):
+ """Set a profile ROI to spy.
+
+ :param ProfileRoiMixIn roi: A profile ROI
+ """
+ previousRoi = self._roi()
+ if previousRoi is roi:
+ return
+ if previousRoi is not None:
+ previousRoi.sigProfilePropertyChanged.disconnect(self._profilePropertyChanged)
+ self.__profileRoi = None if roi is None else weakref.ref(roi)
+ if roi is not None:
+ roi.sigProfilePropertyChanged.connect(self._profilePropertyChanged)
+
+ def _profilePropertyChanged(self):
+ """Handle changes on the properties defining the profile ROI.
+ """
+ if self.__filter.locked():
+ return
+ roi = self.sender()
+ self.storeProperties(roi)
+
+ def storeProperties(self, roi):
+ if isinstance(roi, (rois._DefaultImageStackProfileRoiMixIn,
+ rois.ProfileImageStackCrossROI)):
+ self.__properties["method"] = roi.getProfileMethod()
+ self.__properties["line-width"] = roi.getProfileLineWidth()
+ self.__properties["type"] = roi.getProfileType()
+ elif isinstance(roi, (rois._DefaultImageProfileRoiMixIn,
+ rois.ProfileImageCrossROI)):
+ self.__properties["method"] = roi.getProfileMethod()
+ self.__properties["line-width"] = roi.getProfileLineWidth()
+ elif isinstance(roi, (rois._DefaultScatterProfileRoiMixIn,
+ rois.ProfileScatterCrossROI)):
+ self.__properties["npoints"] = roi.getNPoints()
+
+ def restoreProperties(self, roi):
+ with self.__filter:
+ if isinstance(roi, (rois._DefaultImageStackProfileRoiMixIn,
+ rois.ProfileImageStackCrossROI)):
+ value = self.__properties.get("method", None)
+ if value is not None:
+ roi.setProfileMethod(value)
+ value = self.__properties.get("line-width", None)
+ if value is not None:
+ roi.setProfileLineWidth(value)
+ value = self.__properties.get("type", None)
+ if value is not None:
+ roi.setProfileType(value)
+ elif isinstance(roi, (rois._DefaultImageProfileRoiMixIn,
+ rois.ProfileImageCrossROI)):
+ value = self.__properties.get("method", None)
+ if value is not None:
+ roi.setProfileMethod(value)
+ value = self.__properties.get("line-width", None)
+ if value is not None:
+ roi.setProfileLineWidth(value)
+ elif isinstance(roi, (rois._DefaultScatterProfileRoiMixIn,
+ rois.ProfileScatterCrossROI)):
+ value = self.__properties.get("npoints", None)
+ if value is not None:
+ roi.setNPoints(value)
+
+
+class ProfileManager(qt.QObject):
+ """Base class for profile management tools
+
+ :param plot: :class:`~silx.gui.plot.PlotWidget` on which to operate.
+ :param plot: :class:`~silx.gui.plot.tools.roi.RegionOfInterestManager`
+ on which to operate.
+ """
+ def __init__(self, parent=None, plot=None, roiManager=None):
+ super(ProfileManager, self).__init__(parent)
+
+ assert isinstance(plot, PlotWidget)
+ self._plotRef = weakref.ref(
+ plot, WeakMethodProxy(self.__plotDestroyed))
+
+ # Set-up interaction manager
+ if roiManager is None:
+ roiManager = RegionOfInterestManager(plot)
+
+ self._roiManagerRef = weakref.ref(roiManager)
+ self._rois = []
+ self._pendingRunners = []
+ """List of ROIs which have to be updated"""
+
+ self.__reentrantResults = {}
+ """Store reentrant result to avoid to skip some of them
+ cause the implementation uses a QEventLoop."""
+
+ self._profileWindowClass = ProfileWindow
+ """Class used to display the profile results"""
+
+ self._computedProfiles = 0
+ """Statistics for tests"""
+
+ self.__itemTypes = []
+ """Kind of items to use"""
+
+ self.__tracking = False
+ """Is the plot active items are tracked"""
+
+ self.__useColorFromCursor = True
+ """If true, force the ROI color with the colormap marker color"""
+
+ self._item = None
+ """The selected item"""
+
+ self.__singleProfileAtATime = True
+ """When it's true, only a single profile is displayed at a time."""
+
+ self._previousWindowGeometry = []
+
+ self._storeProperties = _StoreLastParamBehavior(self)
+ """If defined the profile properties of the last ROI are reused to the
+ new created ones"""
+
+ # Listen to plot limits changed
+ plot.getXAxis().sigLimitsChanged.connect(self.requestUpdateAllProfile)
+ plot.getYAxis().sigLimitsChanged.connect(self.requestUpdateAllProfile)
+
+ roiManager.sigInteractiveModeFinished.connect(self.__interactionFinished)
+ roiManager.sigInteractiveRoiCreated.connect(self.__roiCreated)
+ roiManager.sigRoiAdded.connect(self.__roiAdded)
+ roiManager.sigRoiAboutToBeRemoved.connect(self.__roiRemoved)
+
+ def setSingleProfile(self, enable):
+ """
+ Enable or disable the single profile mode.
+
+ In single mode, the manager enforce a single ROI at the same
+ time. A new one will remove the previous one.
+
+ If this mode is not enabled, many ROIs can be created, and many
+ profile windows will be displayed.
+ """
+ self.__singleProfileAtATime = enable
+
+ def isSingleProfile(self):
+ """
+ Returns true if the manager is in a single profile mode.
+
+ :rtype: bool
+ """
+ return self.__singleProfileAtATime
+
+ def __interactionFinished(self):
+ """Handle end of interactive mode"""
+ pass
+
+ def __roiAdded(self, roi):
+ """Handle new ROI"""
+ # Filter out non profile ROIs
+ if not isinstance(roi, core.ProfileRoiMixIn):
+ return
+ self.__addProfile(roi)
+
+ def __roiRemoved(self, roi):
+ """Handle removed ROI"""
+ # Filter out non profile ROIs
+ if not isinstance(roi, core.ProfileRoiMixIn):
+ return
+ self.__removeProfile(roi)
+
+ def createProfileAction(self, profileRoiClass, parent=None):
+ """Create an action from a class of ProfileRoi
+
+ :param core.ProfileRoiMixIn profileRoiClass: A class of a profile ROI
+ :param qt.QObject parent: The parent of the created action.
+ :rtype: qt.QAction
+ """
+ if not issubclass(profileRoiClass, core.ProfileRoiMixIn):
+ raise TypeError("Type %s not expected" % type(profileRoiClass))
+ roiManager = self.getRoiManager()
+ action = CreateRoiModeAction(parent, roiManager, profileRoiClass)
+ if hasattr(profileRoiClass, "ICON"):
+ action.setIcon(icons.getQIcon(profileRoiClass.ICON))
+ if hasattr(profileRoiClass, "NAME"):
+ def articulify(word):
+ """Add an an/a article in the front of the word"""
+ first = word[1] if word[0] == 'h' else word[0]
+ if first in "aeiou":
+ return "an " + word
+ return "a " + word
+ action.setText('Define %s' % articulify(profileRoiClass.NAME))
+ action.setToolTip('Enables %s selection mode' % profileRoiClass.NAME)
+ action.setSingleShot(True)
+ return action
+
+ def createClearAction(self, parent):
+ """Create an action to clean up the plot from the profile ROIs.
+
+ :param qt.QObject parent: The parent of the created action.
+ :rtype: qt.QAction
+ """
+ action = _ClearAction(parent, self)
+ return action
+
+ def createImageActions(self, parent):
+ """Create actions designed for image items. This actions created
+ new ROIs.
+
+ :param qt.QObject parent: The parent of the created action.
+ :rtype: List[qt.QAction]
+ """
+ profileClasses = [
+ rois.ProfileImageHorizontalLineROI,
+ rois.ProfileImageVerticalLineROI,
+ rois.ProfileImageLineROI,
+ rois.ProfileImageDirectedLineROI,
+ rois.ProfileImageCrossROI,
+ ]
+ return [self.createProfileAction(pc, parent=parent) for pc in profileClasses]
+
+ def createScatterActions(self, parent):
+ """Create actions designed for scatter items. This actions created
+ new ROIs.
+
+ :param qt.QObject parent: The parent of the created action.
+ :rtype: List[qt.QAction]
+ """
+ profileClasses = [
+ rois.ProfileScatterHorizontalLineROI,
+ rois.ProfileScatterVerticalLineROI,
+ rois.ProfileScatterLineROI,
+ rois.ProfileScatterCrossROI,
+ ]
+ return [self.createProfileAction(pc, parent=parent) for pc in profileClasses]
+
+ def createScatterSliceActions(self, parent):
+ """Create actions designed for regular scatter items. This actions
+ created new ROIs.
+
+ This ROIs was designed to use the input data without interpolation,
+ like you could do with an image.
+
+ :param qt.QObject parent: The parent of the created action.
+ :rtype: List[qt.QAction]
+ """
+ profileClasses = [
+ rois.ProfileScatterHorizontalSliceROI,
+ rois.ProfileScatterVerticalSliceROI,
+ rois.ProfileScatterCrossSliceROI,
+ ]
+ return [self.createProfileAction(pc, parent=parent) for pc in profileClasses]
+
+ def createImageStackActions(self, parent):
+ """Create actions designed for stack image items. This actions
+ created new ROIs.
+
+ This ROIs was designed to create both profile on the displayed image
+ and profile on the full stack (2D result).
+
+ :param qt.QObject parent: The parent of the created action.
+ :rtype: List[qt.QAction]
+ """
+ profileClasses = [
+ rois.ProfileImageStackHorizontalLineROI,
+ rois.ProfileImageStackVerticalLineROI,
+ rois.ProfileImageStackLineROI,
+ rois.ProfileImageStackCrossROI,
+ ]
+ return [self.createProfileAction(pc, parent=parent) for pc in profileClasses]
+
+ def createEditorAction(self, parent):
+ """Create an action containing GUI to edit the selected profile ROI.
+
+ :param qt.QObject parent: The parent of the created action.
+ :rtype: qt.QAction
+ """
+ action = editors.ProfileRoiEditorAction(parent)
+ action.setRoiManager(self.getRoiManager())
+ return action
+
+ def setItemType(self, image=False, scatter=False):
+ """Set the item type to use and select the active one.
+
+ :param bool image: Image item are allowed
+ :param bool scatter: Scatter item are allowed
+ """
+ self.__itemTypes = []
+ plot = self.getPlotWidget()
+ item = None
+ if image:
+ self.__itemTypes.append("image")
+ item = plot.getActiveImage()
+ if scatter:
+ self.__itemTypes.append("scatter")
+ if item is None:
+ item = plot.getActiveScatter()
+ self.setPlotItem(item)
+
+ def setProfileWindowClass(self, profileWindowClass):
+ """Set the class which will be instantiated to display profile result.
+ """
+ self._profileWindowClass = profileWindowClass
+
+ def setActiveItemTracking(self, tracking):
+ """Enable/disable the tracking of the active item of the plot.
+
+ :param bool tracking: Tracking mode
+ """
+ if self.__tracking == tracking:
+ return
+ plot = self.getPlotWidget()
+ if self.__tracking:
+ plot.sigActiveImageChanged.disconnect(self._activeImageChanged)
+ plot.sigActiveScatterChanged.disconnect(self._activeScatterChanged)
+ self.__tracking = tracking
+ if self.__tracking:
+ plot.sigActiveImageChanged.connect(self.__activeImageChanged)
+ plot.sigActiveScatterChanged.connect(self.__activeScatterChanged)
+
+ def setDefaultColorFromCursorColor(self, enabled):
+ """Enabled/disable the use of the colormap cursor color to display the
+ ROIs.
+
+ If set, the manager will update the color of the profile ROIs using the
+ current colormap cursor color from the selected item.
+ """
+ self.__useColorFromCursor = enabled
+
+ def __activeImageChanged(self, previous, legend):
+ """Handle plot item selection"""
+ if "image" in self.__itemTypes:
+ plot = self.getPlotWidget()
+ item = plot.getImage(legend)
+ self.setPlotItem(item)
+
+ def __activeScatterChanged(self, previous, legend):
+ """Handle plot item selection"""
+ if "scatter" in self.__itemTypes:
+ plot = self.getPlotWidget()
+ item = plot.getScatter(legend)
+ self.setPlotItem(item)
+
+ def __roiCreated(self, roi):
+ """Handle ROI creation"""
+ # Filter out non profile ROIs
+ if isinstance(roi, core.ProfileRoiMixIn):
+ if self._storeProperties is not None:
+ # Initialize the properties with the previous ones
+ self._storeProperties.restoreProperties(roi)
+
+ def __addProfile(self, profileRoi):
+ """Add a new ROI to the manager."""
+ if profileRoi.getFocusProxy() is None:
+ if self._storeProperties is not None:
+ # Follow changes on properties
+ self._storeProperties.setProfileRoi(profileRoi)
+ if self.__singleProfileAtATime:
+ # FIXME: It would be good to reuse the windows to avoid blinking
+ self.clearProfile()
+
+ profileRoi._setProfileManager(self)
+ self._updateRoiColor(profileRoi)
+ self._rois.append(profileRoi)
+ self.requestUpdateProfile(profileRoi)
+
+ def __removeProfile(self, profileRoi):
+ """Remove a ROI from the manager."""
+ window = self._disconnectProfileWindow(profileRoi)
+ if window is not None:
+ geometry = window.geometry()
+ if not geometry.isEmpty():
+ self._previousWindowGeometry.append(geometry)
+ self.clearProfileWindow(window)
+ if profileRoi in self._rois:
+ self._rois.remove(profileRoi)
+
+ def _disconnectProfileWindow(self, profileRoi):
+ """Handle profile window close."""
+ window = profileRoi.getProfileWindow()
+ profileRoi.setProfileWindow(None)
+ return window
+
+ def clearProfile(self):
+ """Clear the associated ROI profile"""
+ roiManager = self.getRoiManager()
+ for roi in list(self._rois):
+ if roi.getFocusProxy() is not None:
+ # Skip sub ROIs, it will be removed by their parents
+ continue
+ roiManager.removeRoi(roi)
+
+ if not roiManager.isDrawing():
+ # Clean the selected mode
+ roiManager.stop()
+
+ def hasPendingOperations(self):
+ """Returns true if a thread is still computing or displaying a profile.
+
+ :rtype: bool
+ """
+ return len(self.__reentrantResults) > 0 or len(self._pendingRunners) > 0
+
+ def requestUpdateAllProfile(self):
+ """Request to update the profile of all the managed ROIs.
+ """
+ for roi in self._rois:
+ self.requestUpdateProfile(roi)
+
+ def requestUpdateProfile(self, profileRoi):
+ """Request to update a specific profile ROI.
+
+ :param ~core.ProfileRoiMixIn profileRoi:
+ """
+ if profileRoi.computeProfile is None:
+ return
+ threadPool = silxGlobalThreadPool()
+
+ # Clean up deprecated runners
+ for runner in list(self._pendingRunners):
+ if not inspect.isValid(runner):
+ self._pendingRunners.remove(runner)
+ continue
+ if runner.getRoi() is profileRoi:
+ if hasattr(threadPool, "tryTake"):
+ if threadPool.tryTake(runner):
+ self._pendingRunners.remove(runner)
+ else: # Support Qt<5.9
+ runner._lazyCancel()
+
+ item = self.getPlotItem()
+ if item is None or not isinstance(item, profileRoi.ITEM_KIND):
+ # This item is not compatible with this profile
+ profileRoi._setPlotItem(None)
+ profileWindow = profileRoi.getProfileWindow()
+ if profileWindow is not None:
+ profileWindow.setProfile(None)
+ return
+
+ profileRoi._setPlotItem(item)
+ runner = _RunnableComputeProfile(threadPool, item, profileRoi)
+ runner.runnerFinished.connect(self.__cleanUpRunner)
+ runner.resultReady.connect(self.__displayResult)
+ self._pendingRunners.append(runner)
+ threadPool.start(runner)
+
+ def __cleanUpRunner(self, runner):
+ """Remove a thread pool runner from the list of hold tasks.
+
+ Called at the termination of the runner.
+ """
+ if runner in self._pendingRunners:
+ self._pendingRunners.remove(runner)
+
+ def __displayResult(self, roi, profileData):
+ """Display the result of a ROI.
+
+ :param ~core.ProfileRoiMixIn profileRoi: A managed ROI
+ :param ~core.CurveProfileData profileData: Computed data profile
+ """
+ if roi in self.__reentrantResults:
+ # Store the data to process it in the main loop
+ # And not a sub loop created by initProfileWindow
+ # This also remove the duplicated requested
+ self.__reentrantResults[roi] = profileData
+ return
+
+ self.__reentrantResults[roi] = profileData
+ self._computedProfiles = self._computedProfiles + 1
+ window = roi.getProfileWindow()
+ if window is None:
+ plot = self.getPlotWidget()
+ window = self.createProfileWindow(plot, roi)
+ # roi.profileWindow have to be set before initializing the window
+ # Cause the initialization is using QEventLoop
+ roi.setProfileWindow(window)
+ self.initProfileWindow(window, roi)
+ window.show()
+
+ lastData = self.__reentrantResults.pop(roi)
+ window.setProfile(lastData)
+
+ def __plotDestroyed(self, ref):
+ """Handle finalization of PlotWidget
+
+ :param ref: weakref to the plot
+ """
+ self._plotRef = None
+ self._roiManagerRef = None
+ self._pendingRunners = []
+
+ def setPlotItem(self, item):
+ """Set the plot item focused by the profile manager.
+
+ :param ~silx.gui.plot.items.Item item: A plot item
+ """
+ previous = self.getPlotItem()
+ if previous is item:
+ return
+ if item is None:
+ self._item = None
+ else:
+ item.sigItemChanged.connect(self.__itemChanged)
+ self._item = weakref.ref(item)
+ self._updateRoiColors()
+ self.requestUpdateAllProfile()
+
+ def getDefaultColor(self, item):
+ """Returns the default ROI color to use according to the given item.
+
+ :param ~silx.gui.plot.items.item.Item item: AN item
+ :rtype: qt.QColor
+ """
+ color = 'pink'
+ if isinstance(item, items.ColormapMixIn):
+ colormap = item.getColormap()
+ name = colormap.getName()
+ if name is not None:
+ color = colors.cursorColorForColormap(name)
+ color = colors.asQColor(color)
+ return color
+
+ def _updateRoiColors(self):
+ """Update ROI color according to the item selection"""
+ if not self.__useColorFromCursor:
+ return
+ item = self.getPlotItem()
+ color = self.getDefaultColor(item)
+ for roi in self._rois:
+ roi.setColor(color)
+
+ def _updateRoiColor(self, roi):
+ """Update a specific ROI according to the current selected item.
+
+ :param RegionOfInterest roi: The ROI to update
+ """
+ if not self.__useColorFromCursor:
+ return
+ item = self.getPlotItem()
+ color = self.getDefaultColor(item)
+ roi.setColor(color)
+
+ def __itemChanged(self, changeType):
+ """Handle item changes.
+ """
+ if changeType in (items.ItemChangedType.DATA,
+ items.ItemChangedType.MASK,
+ items.ItemChangedType.POSITION,
+ items.ItemChangedType.SCALE):
+ self.requestUpdateAllProfile()
+ elif changeType == (items.ItemChangedType.COLORMAP):
+ self._updateRoiColors()
+
+ def getPlotItem(self):
+ """Returns the item focused by the profile manager.
+
+ :rtype: ~silx.gui.plot.items.Item
+ """
+ if self._item is None:
+ return None
+ item = self._item()
+ if item is None:
+ self._item = None
+ return item
+
+ def getPlotWidget(self):
+ """The plot associated to the profile manager.
+
+ :rtype: ~silx.gui.plot.PlotWidget
+ """
+ if self._plotRef is None:
+ return None
+ plot = self._plotRef()
+ if plot is None:
+ self._plotRef = None
+ return plot
+
+ def getCurrentRoi(self):
+ """Returns the currently selected ROI, else None.
+
+ :rtype: core.ProfileRoiMixIn
+ """
+ roiManager = self.getRoiManager()
+ if roiManager is None:
+ return None
+ roi = roiManager.getCurrentRoi()
+ if not isinstance(roi, core.ProfileRoiMixIn):
+ return None
+ return roi
+
+ def getRoiManager(self):
+ """Returns the used ROI manager
+
+ :rtype: RegionOfInterestManager
+ """
+ return self._roiManagerRef()
+
+ def createProfileWindow(self, plot, roi):
+ """Create a new profile window.
+
+ :param ~core.ProfileRoiMixIn roi: The plot containing the raw data
+ :param ~core.ProfileRoiMixIn roi: A managed ROI
+ :rtype: ~ProfileWindow
+ """
+ return self._profileWindowClass(plot)
+
+ def initProfileWindow(self, profileWindow, roi):
+ """This function is called just after the profile window creation in
+ order to initialize the window location.
+
+ :param ~ProfileWindow profileWindow:
+ The profile window to initialize.
+ """
+ # Enforce the use of one of the widgets
+ # To have the correct window size
+ profileWindow.prepareWidget(roi)
+ profileWindow.adjustSize()
+
+ # Trick to avoid blinking while retrieving the right window size
+ # Display the window, hide it and wait for some event loops
+ profileWindow.show()
+ profileWindow.hide()
+ eventLoop = qt.QEventLoop(self)
+ for _ in range(10):
+ if not eventLoop.processEvents():
+ break
+
+ profileWindow.show()
+ if len(self._previousWindowGeometry) > 0:
+ geometry = self._previousWindowGeometry.pop()
+ profileWindow.setGeometry(geometry)
+ return
+
+ window = self.getPlotWidget().window()
+ winGeom = window.frameGeometry()
+ if qt.BINDING in ("PySide2", "PyQt5"):
+ qapp = qt.QApplication.instance()
+ desktop = qapp.desktop()
+ screenGeom = desktop.availableGeometry(window)
+ else: # Qt6 (and also Qt>=5.14)
+ screenGeom = window.screen().availableGeometry()
+ spaceOnLeftSide = winGeom.left()
+ spaceOnRightSide = screenGeom.width() - winGeom.right()
+
+ profileGeom = profileWindow.frameGeometry()
+ profileWidth = profileGeom.width()
+
+ # Align vertically to the center of the window
+ top = winGeom.top() + (winGeom.height() - profileGeom.height()) // 2
+
+ margin = 5
+ if profileWidth < spaceOnRightSide:
+ # Place profile on the right
+ left = winGeom.right() + margin
+ elif profileWidth < spaceOnLeftSide:
+ # Place profile on the left
+ left = max(0, winGeom.left() - profileWidth - margin)
+ else:
+ # Move it as much as possible where there is more space
+ if spaceOnLeftSide > spaceOnRightSide:
+ left = 0
+ else:
+ left = screenGeom.width() - profileGeom.width()
+ profileWindow.move(left, top)
+
+
+ def clearProfileWindow(self, profileWindow):
+ """Called when a profile window is not anymore needed.
+
+ By default the window will be closed. But it can be
+ inherited to change this behavior.
+ """
+ profileWindow.deleteLater()
diff --git a/src/silx/gui/plot/tools/profile/rois.py b/src/silx/gui/plot/tools/profile/rois.py
new file mode 100644
index 0000000..9eef622
--- /dev/null
+++ b/src/silx/gui/plot/tools/profile/rois.py
@@ -0,0 +1,1156 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module define ROIs for profile tools.
+
+.. inheritance-diagram::
+ silx.gui.plot.tools.profile.rois
+ :top-classes: silx.gui.plot.tools.profile.core.ProfileRoiMixIn, silx.gui.plot.items.roi.RegionOfInterest
+ :parts: 1
+ :private-bases:
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "01/12/2020"
+
+import numpy
+import weakref
+from concurrent.futures import CancelledError
+
+from silx.gui import colors
+
+from silx.gui.plot import items
+from silx.gui.plot.items import roi as roi_items
+from . import core
+from silx.gui import utils
+from .....utils.proxy import docstring
+
+
+def _relabelAxes(plot, text):
+ """Relabel {xlabel} and {ylabel} from this text using the corresponding
+ plot axis label. If the axis label is empty, label it with "X" and "Y".
+
+ :rtype: str
+ """
+ xLabel = plot.getXAxis().getLabel()
+ if not xLabel:
+ xLabel = "X"
+ yLabel = plot.getYAxis().getLabel()
+ if not yLabel:
+ yLabel = "Y"
+ return text.format(xlabel=xLabel, ylabel=yLabel)
+
+
+def _lineProfileTitle(x0, y0, x1, y1):
+ """Compute corresponding plot title
+
+ This can be overridden to change title behavior.
+
+ :param float x0: Profile start point X coord
+ :param float y0: Profile start point Y coord
+ :param float x1: Profile end point X coord
+ :param float y1: Profile end point Y coord
+ :return: Title to use
+ :rtype: str
+ """
+ if x0 == x1:
+ title = '{xlabel} = %g; {ylabel} = [%g, %g]' % (x0, y0, y1)
+ elif y0 == y1:
+ title = '{ylabel} = %g; {xlabel} = [%g, %g]' % (y0, x0, x1)
+ else:
+ m = (y1 - y0) / (x1 - x0)
+ b = y0 - m * x0
+ title = '{ylabel} = %g * {xlabel} %+g' % (m, b)
+
+ return title
+
+
+class _ImageProfileArea(items.Shape):
+ """This shape displays the location of pixels used to compute the
+ profile."""
+
+ def __init__(self, parentRoi):
+ items.Shape.__init__(self, "polygon")
+ color = colors.rgba(parentRoi.getColor())
+ self.setColor(color)
+ self.setFill(True)
+ self.setOverlay(True)
+ self.setPoints([[0, 0], [0, 0]]) # Else it segfault
+
+ self.__parentRoi = weakref.ref(parentRoi)
+ parentRoi.sigItemChanged.connect(self._updateAreaProperty)
+ parentRoi.sigRegionChanged.connect(self._updateArea)
+ parentRoi.sigProfilePropertyChanged.connect(self._updateArea)
+ parentRoi.sigPlotItemChanged.connect(self._updateArea)
+
+ def getParentRoi(self):
+ if self.__parentRoi is None:
+ return None
+ parentRoi = self.__parentRoi()
+ if parentRoi is None:
+ self.__parentRoi = None
+ return parentRoi
+
+ def _updateAreaProperty(self, event=None, checkVisibility=True):
+ parentRoi = self.sender()
+ if event == items.ItemChangedType.COLOR:
+ parentRoi._updateItemProperty(event, parentRoi, self)
+ elif event == items.ItemChangedType.VISIBLE:
+ if self.getPlotItem() is not None:
+ parentRoi._updateItemProperty(event, parentRoi, self)
+
+ def _updateArea(self):
+ roi = self.getParentRoi()
+ item = roi.getPlotItem()
+ if item is None:
+ self.setVisible(False)
+ return
+ polygon = self._computePolygon(item)
+ self.setVisible(True)
+ polygon = numpy.array(polygon).T
+ self.setLineStyle("--")
+ self.setPoints(polygon, copy=False)
+
+ def _computePolygon(self, item):
+ if not isinstance(item, items.ImageBase):
+ raise TypeError("Unexpected class %s" % type(item))
+
+ currentData = item.getValueData(copy=False)
+
+ roi = self.getParentRoi()
+ origin = item.getOrigin()
+ scale = item.getScale()
+ _coords, _profile, area, _profileName, _xLabel = core.createProfile(
+ roiInfo=roi._getRoiInfo(),
+ currentData=currentData,
+ origin=origin,
+ scale=scale,
+ lineWidth=roi.getProfileLineWidth(),
+ method="none")
+ return area
+
+
+class _SliceProfileArea(items.Shape):
+ """This shape displays the location a profile in a scatter.
+
+ Each point used to compute the slice are linked together.
+ """
+
+ def __init__(self, parentRoi):
+ items.Shape.__init__(self, "polygon")
+ color = colors.rgba(parentRoi.getColor())
+ self.setColor(color)
+ self.setFill(True)
+ self.setOverlay(True)
+ self.setPoints([[0, 0], [0, 0]]) # Else it segfault
+
+ self.__parentRoi = weakref.ref(parentRoi)
+ parentRoi.sigItemChanged.connect(self._updateAreaProperty)
+ parentRoi.sigRegionChanged.connect(self._updateArea)
+ parentRoi.sigProfilePropertyChanged.connect(self._updateArea)
+ parentRoi.sigPlotItemChanged.connect(self._updateArea)
+
+ def getParentRoi(self):
+ if self.__parentRoi is None:
+ return None
+ parentRoi = self.__parentRoi()
+ if parentRoi is None:
+ self.__parentRoi = None
+ return parentRoi
+
+ def _updateAreaProperty(self, event=None, checkVisibility=True):
+ parentRoi = self.sender()
+ if event == items.ItemChangedType.COLOR:
+ parentRoi._updateItemProperty(event, parentRoi, self)
+ elif event == items.ItemChangedType.VISIBLE:
+ if self.getPlotItem() is not None:
+ parentRoi._updateItemProperty(event, parentRoi, self)
+
+ def _updateArea(self):
+ roi = self.getParentRoi()
+ item = roi.getPlotItem()
+ if item is None:
+ self.setVisible(False)
+ return
+ polylines = self._computePolylines(roi, item)
+ if polylines is None:
+ self.setVisible(False)
+ return
+ self.setVisible(True)
+ self.setLineStyle("--")
+ self.setPoints(polylines, copy=False)
+
+ def _computePolylines(self, roi, item):
+ slicing = roi._getSlice(item)
+ if slicing is None:
+ return None
+ xx, yy, _values, _xx_error, _yy_error = item.getData(copy=False)
+ xx, yy = xx[slicing], yy[slicing]
+ polylines = numpy.array((xx, yy)).T
+ if len(polylines) == 0:
+ return None
+ return polylines
+
+
+class _DefaultImageProfileRoiMixIn(core.ProfileRoiMixIn):
+ """Provide common behavior for silx default image profile ROI.
+ """
+
+ ITEM_KIND = items.ImageBase
+
+ def __init__(self, parent=None):
+ core.ProfileRoiMixIn.__init__(self, parent=parent)
+ self.__method = "mean"
+ self.__width = 1
+ self.sigRegionChanged.connect(self.__regionChanged)
+ self.sigPlotItemChanged.connect(self.__updateArea)
+ self.__area = _ImageProfileArea(self)
+ self.addItem(self.__area)
+
+ def __regionChanged(self):
+ self.invalidateProfile()
+ self.__updateArea()
+
+ def setProfileMethod(self, method):
+ """
+ :param str method: method to compute the profile. Can be 'mean' or 'sum'
+ """
+ if self.__method == method:
+ return
+ self.__method = method
+ self.invalidateProperties()
+ self.invalidateProfile()
+
+ def getProfileMethod(self):
+ return self.__method
+
+ def setProfileLineWidth(self, width):
+ if self.__width == width:
+ return
+ self.__width = width
+ self.__updateArea()
+ self.invalidateProperties()
+ self.invalidateProfile()
+
+ def getProfileLineWidth(self):
+ return self.__width
+
+ def __updateArea(self):
+ plotItem = self.getPlotItem()
+ if plotItem is None:
+ self.setLineStyle("-")
+ else:
+ self.setLineStyle("--")
+
+ def _getRoiInfo(self):
+ """Wrapper to allow to reuse the previous Profile code.
+
+ It would be good to remove it at one point.
+ """
+ if isinstance(self, roi_items.HorizontalLineROI):
+ lineProjectionMode = 'X'
+ y = self.getPosition()
+ roiStart = (0, y)
+ roiEnd = (1, y)
+ elif isinstance(self, roi_items.VerticalLineROI):
+ lineProjectionMode = 'Y'
+ x = self.getPosition()
+ roiStart = (x, 0)
+ roiEnd = (x, 1)
+ elif isinstance(self, roi_items.LineROI):
+ lineProjectionMode = 'D'
+ roiStart, roiEnd = self.getEndPoints()
+ else:
+ assert False
+
+ return roiStart, roiEnd, lineProjectionMode
+
+ def computeProfile(self, item):
+ if not isinstance(item, items.ImageBase):
+ raise TypeError("Unexpected class %s" % type(item))
+
+ origin = item.getOrigin()
+ scale = item.getScale()
+ method = self.getProfileMethod()
+ lineWidth = self.getProfileLineWidth()
+
+ def createProfile2(currentData):
+ coords, profile, _area, profileName, xLabel = core.createProfile(
+ roiInfo=self._getRoiInfo(),
+ currentData=currentData,
+ origin=origin,
+ scale=scale,
+ lineWidth=lineWidth,
+ method=method)
+ return coords, profile, profileName, xLabel
+
+ currentData = item.getValueData(copy=False)
+
+ yLabel = "%s" % str(method).capitalize()
+ coords, profile, title, xLabel = createProfile2(currentData)
+ title = title + "; width = %d" % lineWidth
+
+ # Use the axis names from the original plot
+ profileManager = self.getProfileManager()
+ plot = profileManager.getPlotWidget()
+ title = _relabelAxes(plot, title)
+ xLabel = _relabelAxes(plot, xLabel)
+
+ if isinstance(item, items.ImageRgba):
+ rgba = item.getData(copy=False)
+ _coords, r, _profileName, _xLabel = createProfile2(rgba[..., 0])
+ _coords, g, _profileName, _xLabel = createProfile2(rgba[..., 1])
+ _coords, b, _profileName, _xLabel = createProfile2(rgba[..., 2])
+ if rgba.shape[-1] == 4:
+ _coords, a, _profileName, _xLabel = createProfile2(rgba[..., 3])
+ else:
+ a = [None]
+ data = core.RgbaProfileData(
+ coords=coords,
+ profile=profile[0],
+ profile_r=r[0],
+ profile_g=g[0],
+ profile_b=b[0],
+ profile_a=a[0],
+ title=title,
+ xLabel=xLabel,
+ yLabel=yLabel,
+ )
+ else:
+ data = core.CurveProfileData(
+ coords=coords,
+ profile=profile[0],
+ title=title,
+ xLabel=xLabel,
+ yLabel=yLabel,
+ )
+ return data
+
+
+class ProfileImageHorizontalLineROI(roi_items.HorizontalLineROI,
+ _DefaultImageProfileRoiMixIn):
+ """ROI for an horizontal profile at a location of an image"""
+
+ ICON = 'shape-horizontal'
+ NAME = 'horizontal line profile'
+
+ def __init__(self, parent=None):
+ roi_items.HorizontalLineROI.__init__(self, parent=parent)
+ _DefaultImageProfileRoiMixIn.__init__(self, parent=parent)
+
+
+class ProfileImageVerticalLineROI(roi_items.VerticalLineROI,
+ _DefaultImageProfileRoiMixIn):
+ """ROI for a vertical profile at a location of an image"""
+
+ ICON = 'shape-vertical'
+ NAME = 'vertical line profile'
+
+ def __init__(self, parent=None):
+ roi_items.VerticalLineROI.__init__(self, parent=parent)
+ _DefaultImageProfileRoiMixIn.__init__(self, parent=parent)
+
+
+class ProfileImageLineROI(roi_items.LineROI,
+ _DefaultImageProfileRoiMixIn):
+ """ROI for an image profile between 2 points.
+
+ The X profile of this ROI is the projecting into one of the x/y axes,
+ using its scale and its orientation.
+ """
+
+ ICON = 'shape-diagonal'
+ NAME = 'line profile'
+
+ def __init__(self, parent=None):
+ roi_items.LineROI.__init__(self, parent=parent)
+ _DefaultImageProfileRoiMixIn.__init__(self, parent=parent)
+
+
+class ProfileImageDirectedLineROI(roi_items.LineROI,
+ _DefaultImageProfileRoiMixIn):
+ """ROI for an image profile between 2 points.
+
+ The X profile of the line is displayed projected into the line itself,
+ using its scale and its orientation. It's the distance from the origin.
+ """
+
+ ICON = 'shape-diagonal-directed'
+ NAME = 'directed line profile'
+
+ def __init__(self, parent=None):
+ roi_items.LineROI.__init__(self, parent=parent)
+ _DefaultImageProfileRoiMixIn.__init__(self, parent=parent)
+ self._handleStart.setSymbol('o')
+
+ def computeProfile(self, item):
+ if not isinstance(item, items.ImageBase):
+ raise TypeError("Unexpected class %s" % type(item))
+
+ from silx.image.bilinear import BilinearImage
+
+ origin = item.getOrigin()
+ scale = item.getScale()
+ method = self.getProfileMethod()
+ lineWidth = self.getProfileLineWidth()
+ currentData = item.getValueData(copy=False)
+
+ roiInfo = self._getRoiInfo()
+ roiStart, roiEnd, _lineProjectionMode = roiInfo
+
+ startPt = ((roiStart[1] - origin[1]) / scale[1],
+ (roiStart[0] - origin[0]) / scale[0])
+ endPt = ((roiEnd[1] - origin[1]) / scale[1],
+ (roiEnd[0] - origin[0]) / scale[0])
+
+ if numpy.array_equal(startPt, endPt):
+ return None
+
+ bilinear = BilinearImage(currentData)
+ profile = bilinear.profile_line(
+ (startPt[0] - 0.5, startPt[1] - 0.5),
+ (endPt[0] - 0.5, endPt[1] - 0.5),
+ lineWidth,
+ method=method)
+
+ # Compute the line size
+ lineSize = numpy.sqrt((roiEnd[1] - roiStart[1]) ** 2 +
+ (roiEnd[0] - roiStart[0]) ** 2)
+ coords = numpy.linspace(0, lineSize, len(profile),
+ endpoint=True,
+ dtype=numpy.float32)
+
+ title = _lineProfileTitle(*roiStart, *roiEnd)
+ title = title + "; width = %d" % lineWidth
+ xLabel = "√({xlabel}²+{ylabel}²)"
+ yLabel = str(method).capitalize()
+
+ # Use the axis names from the original plot
+ profileManager = self.getProfileManager()
+ plot = profileManager.getPlotWidget()
+ xLabel = _relabelAxes(plot, xLabel)
+ title = _relabelAxes(plot, title)
+
+ data = core.CurveProfileData(
+ coords=coords,
+ profile=profile,
+ title=title,
+ xLabel=xLabel,
+ yLabel=yLabel,
+ )
+ return data
+
+
+class _ProfileCrossROI(roi_items.HandleBasedROI, core.ProfileRoiMixIn):
+
+ """ROI to manage a cross of profiles
+
+ It is managed using 2 sub ROIs for vertical and horizontal.
+ """
+
+ _kind = "Cross"
+ """Label for this kind of ROI"""
+
+ _plotShape = "point"
+ """Plot shape which is used for the first interaction"""
+
+ def __init__(self, parent=None):
+ roi_items.HandleBasedROI.__init__(self, parent=parent)
+ core.ProfileRoiMixIn.__init__(self, parent=parent)
+ self.sigRegionChanged.connect(self.__regionChanged)
+ self.sigAboutToBeRemoved.connect(self.__aboutToBeRemoved)
+ self.__position = 0, 0
+ self.__vline = None
+ self.__hline = None
+ self.__handle = self.addHandle()
+ self.__handleLabel = self.addLabelHandle()
+ self.__handleLabel.setText(self.getName())
+ self.__inhibitReentance = utils.LockReentrant()
+ self.computeProfile = None
+ self.sigItemChanged.connect(self.__updateLineProperty)
+
+ # Make sure the marker is over the ROIs
+ self.__handle.setZValue(1)
+ # Create the vline and the hline
+ self._createSubRois()
+
+ @docstring(roi_items.HandleBasedROI)
+ def contains(self, position):
+ roiPos = self.getPosition()
+ return position[0] == roiPos[0] or position[1] == roiPos[1]
+
+ def setFirstShapePoints(self, points):
+ pos = points[0]
+ self.setPosition(pos)
+
+ def getPosition(self):
+ """Returns the position of this ROI
+
+ :rtype: numpy.ndarray
+ """
+ return self.__position
+
+ def setPosition(self, pos):
+ """Set the position of this ROI
+
+ :param numpy.ndarray pos: 2d-coordinate of this point
+ """
+ self.__position = pos
+ with utils.blockSignals(self.__handle):
+ self.__handle.setPosition(*pos)
+ with utils.blockSignals(self.__handleLabel):
+ self.__handleLabel.setPosition(*pos)
+ self.sigRegionChanged.emit()
+
+ def handleDragUpdated(self, handle, origin, previous, current):
+ if handle is self.__handle:
+ self.setPosition(current)
+
+ def __updateLineProperty(self, event=None, checkVisibility=True):
+ if event == items.ItemChangedType.NAME:
+ self.__handleLabel.setText(self.getName())
+ elif event in [items.ItemChangedType.COLOR,
+ items.ItemChangedType.VISIBLE]:
+ lines = []
+ if self.__vline:
+ lines.append(self.__vline)
+ if self.__hline:
+ lines.append(self.__hline)
+ self._updateItemProperty(event, self, lines)
+
+ def _createLines(self, parent):
+ """Inherit this function to return 2 ROI objects for respectivly
+ the horizontal, and the vertical lines."""
+ raise NotImplementedError()
+
+ def _setProfileManager(self, profileManager):
+ core.ProfileRoiMixIn._setProfileManager(self, profileManager)
+ # Connecting the vline and the hline
+ roiManager = profileManager.getRoiManager()
+ roiManager.addRoi(self.__vline)
+ roiManager.addRoi(self.__hline)
+
+ def _createSubRois(self):
+ hline, vline = self._createLines(parent=None)
+ for i, line in enumerate([vline, hline]):
+ line.setPosition(self.__position[i])
+ line.setEditable(True)
+ line.setSelectable(True)
+ line.setFocusProxy(self)
+ line.setName("")
+ self.__vline = vline
+ self.__hline = hline
+ vline.sigAboutToBeRemoved.connect(self.__vlineRemoved)
+ vline.sigRegionChanged.connect(self.__vlineRegionChanged)
+ hline.sigAboutToBeRemoved.connect(self.__hlineRemoved)
+ hline.sigRegionChanged.connect(self.__hlineRegionChanged)
+
+ def _getLines(self):
+ return self.__hline, self.__vline
+
+ def __regionChanged(self):
+ if self.__inhibitReentance.locked():
+ return
+ x, y = self.getPosition()
+ hline, vline = self._getLines()
+ if hline is None:
+ return
+ with self.__inhibitReentance:
+ hline.setPosition(y)
+ vline.setPosition(x)
+
+ def __vlineRegionChanged(self):
+ if self.__inhibitReentance.locked():
+ return
+ pos = self.getPosition()
+ vline = self.__vline
+ pos = vline.getPosition(), pos[1]
+ with self.__inhibitReentance:
+ self.setPosition(pos)
+
+ def __hlineRegionChanged(self):
+ if self.__inhibitReentance.locked():
+ return
+ pos = self.getPosition()
+ hline = self.__hline
+ pos = pos[0], hline.getPosition()
+ with self.__inhibitReentance:
+ self.setPosition(pos)
+
+ def __aboutToBeRemoved(self):
+ vline = self.__vline
+ hline = self.__hline
+ # Avoid side remove signals
+ if hline is not None:
+ hline.sigAboutToBeRemoved.disconnect(self.__hlineRemoved)
+ hline.sigRegionChanged.disconnect(self.__hlineRegionChanged)
+ if vline is not None:
+ vline.sigAboutToBeRemoved.disconnect(self.__vlineRemoved)
+ vline.sigRegionChanged.disconnect(self.__vlineRegionChanged)
+ # Clean up the child
+ profileManager = self.getProfileManager()
+ roiManager = profileManager.getRoiManager()
+ if hline is not None:
+ roiManager.removeRoi(hline)
+ self.__hline = None
+ if vline is not None:
+ roiManager.removeRoi(vline)
+ self.__vline = None
+
+ def __hlineRemoved(self):
+ self.__lineRemoved(isHline=True)
+
+ def __vlineRemoved(self):
+ self.__lineRemoved(isHline=False)
+
+ def __lineRemoved(self, isHline):
+ """If any of the lines is removed: disconnect this objects, and let the
+ other one persist"""
+ hline, vline = self._getLines()
+
+ hline.sigAboutToBeRemoved.disconnect(self.__hlineRemoved)
+ vline.sigAboutToBeRemoved.disconnect(self.__vlineRemoved)
+ hline.sigRegionChanged.disconnect(self.__hlineRegionChanged)
+ vline.sigRegionChanged.disconnect(self.__vlineRegionChanged)
+
+ self.__hline = None
+ self.__vline = None
+ profileManager = self.getProfileManager()
+ roiManager = profileManager.getRoiManager()
+ if isHline:
+ self.__releaseLine(vline)
+ else:
+ self.__releaseLine(hline)
+ roiManager.removeRoi(self)
+
+ def __releaseLine(self, line):
+ """Release the line in order to make it independent"""
+ line.setFocusProxy(None)
+ line.setName(self.getName())
+ line.setEditable(self.isEditable())
+ line.setSelectable(self.isSelectable())
+
+
+class ProfileImageCrossROI(_ProfileCrossROI):
+ """ROI to manage a cross of profiles
+
+ It is managed using 2 sub ROIs for vertical and horizontal.
+ """
+
+ ICON = 'shape-cross'
+ NAME = 'cross profile'
+ ITEM_KIND = items.ImageBase
+
+ def _createLines(self, parent):
+ vline = ProfileImageVerticalLineROI(parent=parent)
+ hline = ProfileImageHorizontalLineROI(parent=parent)
+ return hline, vline
+
+ def setProfileMethod(self, method):
+ """
+ :param str method: method to compute the profile. Can be 'mean' or 'sum'
+ """
+ hline, vline = self._getLines()
+ hline.setProfileMethod(method)
+ vline.setProfileMethod(method)
+ self.invalidateProperties()
+
+ def getProfileMethod(self):
+ hline, _vline = self._getLines()
+ return hline.getProfileMethod()
+
+ def setProfileLineWidth(self, width):
+ hline, vline = self._getLines()
+ hline.setProfileLineWidth(width)
+ vline.setProfileLineWidth(width)
+ self.invalidateProperties()
+
+ def getProfileLineWidth(self):
+ hline, _vline = self._getLines()
+ return hline.getProfileLineWidth()
+
+
+class _DefaultScatterProfileRoiMixIn(core.ProfileRoiMixIn):
+ """Provide common behavior for silx default scatter profile ROI.
+ """
+
+ ITEM_KIND = items.Scatter
+
+ def __init__(self, parent=None):
+ core.ProfileRoiMixIn.__init__(self, parent=parent)
+ self.__nPoints = 1024
+ self.sigRegionChanged.connect(self.__regionChanged)
+
+ def __regionChanged(self):
+ self.invalidateProfile()
+
+ # Number of points
+
+ def getNPoints(self):
+ """Returns the number of points of the profiles
+
+ :rtype: int
+ """
+ return self.__nPoints
+
+ def setNPoints(self, npoints):
+ """Set the number of points of the profiles
+
+ :param int npoints:
+ """
+ npoints = int(npoints)
+ if npoints < 1:
+ raise ValueError("Unsupported number of points: %d" % npoints)
+ elif npoints != self.__nPoints:
+ self.__nPoints = npoints
+ self.invalidateProperties()
+ self.invalidateProfile()
+
+ def _computeProfile(self, scatter, x0, y0, x1, y1):
+ """Compute corresponding profile
+
+ :param float x0: Profile start point X coord
+ :param float y0: Profile start point Y coord
+ :param float x1: Profile end point X coord
+ :param float y1: Profile end point Y coord
+ :return: (points, values) profile data or None
+ """
+ future = scatter._getInterpolator()
+ try:
+ interpolator = future.result()
+ except CancelledError:
+ return None
+ if interpolator is None:
+ return None # Cannot init an interpolator
+
+ nPoints = self.getNPoints()
+ points = numpy.transpose((
+ numpy.linspace(x0, x1, nPoints, endpoint=True),
+ numpy.linspace(y0, y1, nPoints, endpoint=True)))
+
+ values = interpolator(points)
+
+ if not numpy.any(numpy.isfinite(values)):
+ return None # Profile outside convex hull
+
+ return points, values
+
+ def computeProfile(self, item):
+ """Update profile according to current ROI"""
+ if not isinstance(item, items.Scatter):
+ raise TypeError("Unexpected class %s" % type(item))
+
+ # Get end points
+ if isinstance(self, roi_items.LineROI):
+ points = self.getEndPoints()
+ x0, y0 = points[0]
+ x1, y1 = points[1]
+ elif isinstance(self, (roi_items.VerticalLineROI, roi_items.HorizontalLineROI)):
+ profileManager = self.getProfileManager()
+ plot = profileManager.getPlotWidget()
+
+ if isinstance(self, roi_items.HorizontalLineROI):
+ x0, x1 = plot.getXAxis().getLimits()
+ y0 = y1 = self.getPosition()
+
+ elif isinstance(self, roi_items.VerticalLineROI):
+ x0 = x1 = self.getPosition()
+ y0, y1 = plot.getYAxis().getLimits()
+ else:
+ raise RuntimeError('Unsupported ROI for profile: {}'.format(self.__class__))
+
+ if x1 < x0 or (x1 == x0 and y1 < y0):
+ # Invert points
+ x0, y0, x1, y1 = x1, y1, x0, y0
+
+ profile = self._computeProfile(item, x0, y0, x1, y1)
+ if profile is None:
+ return None
+
+ title = _lineProfileTitle(x0, y0, x1, y1)
+ points = profile[0]
+ values = profile[1]
+
+ if (numpy.abs(points[-1, 0] - points[0, 0]) >
+ numpy.abs(points[-1, 1] - points[0, 1])):
+ xProfile = points[:, 0]
+ xLabel = '{xlabel}'
+ else:
+ xProfile = points[:, 1]
+ xLabel = '{ylabel}'
+
+ # Use the axis names from the original
+ profileManager = self.getProfileManager()
+ plot = profileManager.getPlotWidget()
+ title = _relabelAxes(plot, title)
+ xLabel = _relabelAxes(plot, xLabel)
+
+ data = core.CurveProfileData(
+ coords=xProfile,
+ profile=values,
+ title=title,
+ xLabel=xLabel,
+ yLabel='Profile',
+ )
+ return data
+
+
+class ProfileScatterHorizontalLineROI(roi_items.HorizontalLineROI,
+ _DefaultScatterProfileRoiMixIn):
+ """ROI for an horizontal profile at a location of a scatter"""
+
+ ICON = 'shape-horizontal'
+ NAME = 'horizontal line profile'
+
+ def __init__(self, parent=None):
+ roi_items.HorizontalLineROI.__init__(self, parent=parent)
+ _DefaultScatterProfileRoiMixIn.__init__(self, parent=parent)
+
+
+class ProfileScatterVerticalLineROI(roi_items.VerticalLineROI,
+ _DefaultScatterProfileRoiMixIn):
+ """ROI for an horizontal profile at a location of a scatter"""
+
+ ICON = 'shape-vertical'
+ NAME = 'vertical line profile'
+
+ def __init__(self, parent=None):
+ roi_items.VerticalLineROI.__init__(self, parent=parent)
+ _DefaultScatterProfileRoiMixIn.__init__(self, parent=parent)
+
+
+class ProfileScatterLineROI(roi_items.LineROI,
+ _DefaultScatterProfileRoiMixIn):
+ """ROI for an horizontal profile at a location of a scatter"""
+
+ ICON = 'shape-diagonal'
+ NAME = 'line profile'
+
+ def __init__(self, parent=None):
+ roi_items.LineROI.__init__(self, parent=parent)
+ _DefaultScatterProfileRoiMixIn.__init__(self, parent=parent)
+
+
+class ProfileScatterCrossROI(_ProfileCrossROI):
+ """ROI to manage a cross of profiles for scatters.
+ """
+
+ ICON = 'shape-cross'
+ NAME = 'cross profile'
+ ITEM_KIND = items.Scatter
+
+ def _createLines(self, parent):
+ vline = ProfileScatterVerticalLineROI(parent=parent)
+ hline = ProfileScatterHorizontalLineROI(parent=parent)
+ return hline, vline
+
+ def getNPoints(self):
+ """Returns the number of points of the profiles
+
+ :rtype: int
+ """
+ hline, _vline = self._getLines()
+ return hline.getNPoints()
+
+ def setNPoints(self, npoints):
+ """Set the number of points of the profiles
+
+ :param int npoints:
+ """
+ hline, vline = self._getLines()
+ hline.setNPoints(npoints)
+ vline.setNPoints(npoints)
+ self.invalidateProperties()
+
+
+class _DefaultScatterProfileSliceRoiMixIn(core.ProfileRoiMixIn):
+ """Default ROI to allow to slice in the scatter data."""
+
+ ITEM_KIND = items.Scatter
+
+ def __init__(self, parent=None):
+ core.ProfileRoiMixIn.__init__(self, parent=parent)
+ self.__area = _SliceProfileArea(self)
+ self.addItem(self.__area)
+ self.sigRegionChanged.connect(self._regionChanged)
+ self.sigPlotItemChanged.connect(self._updateArea)
+
+ def _regionChanged(self):
+ self.invalidateProfile()
+ self._updateArea()
+
+ def _updateArea(self):
+ plotItem = self.getPlotItem()
+ if plotItem is None:
+ self.setLineStyle("-")
+ else:
+ self.setLineStyle("--")
+
+ def _getSlice(self, item):
+ position = self.getPosition()
+ bounds = item.getCurrentVisualizationParameter(items.Scatter.VisualizationParameter.GRID_BOUNDS)
+ if isinstance(self, roi_items.HorizontalLineROI):
+ axis = 1
+ elif isinstance(self, roi_items.VerticalLineROI):
+ axis = 0
+ else:
+ assert False
+ if bounds is None or position < bounds[0][axis] or position > bounds[1][axis]:
+ # ROI outside of the scatter bound
+ return None
+
+ major_order = item.getCurrentVisualizationParameter(items.Scatter.VisualizationParameter.GRID_MAJOR_ORDER)
+ assert major_order == 'row'
+ max_grid_yy, max_grid_xx = item.getCurrentVisualizationParameter(items.Scatter.VisualizationParameter.GRID_SHAPE)
+
+ xx, yy, _values, _xx_error, _yy_error = item.getData(copy=False)
+ if isinstance(self, roi_items.HorizontalLineROI):
+ axis = yy
+ max_grid_first = max_grid_yy
+ max_grid_second = max_grid_xx
+ major_axis = major_order == 'column'
+ elif isinstance(self, roi_items.VerticalLineROI):
+ axis = xx
+ max_grid_first = max_grid_xx
+ max_grid_second = max_grid_yy
+ major_axis = major_order == 'row'
+ else:
+ assert False
+
+ def argnearest(array, value):
+ array = numpy.abs(array - value)
+ return numpy.argmin(array)
+
+ if major_axis:
+ # slice in the middle of the scatter
+ start = max_grid_second // 2 * max_grid_first
+ vslice = axis[start:start + max_grid_second]
+ index = argnearest(vslice, position)
+ slicing = slice(index, None, max_grid_first)
+ else:
+ # slice in the middle of the scatter
+ vslice = axis[max_grid_second // 2::max_grid_second]
+ index = argnearest(vslice, position)
+ start = index * max_grid_second
+ slicing = slice(start, start + max_grid_second)
+
+ return slicing
+
+ def computeProfile(self, item):
+ if not isinstance(item, items.Scatter):
+ raise TypeError("Unsupported %s item" % type(item))
+
+ slicing = self._getSlice(item)
+ if slicing is None:
+ # ROI out of bounds
+ return None
+
+ _xx, _yy, values, _xx_error, _yy_error = item.getData(copy=False)
+ profile = values[slicing]
+
+ if isinstance(self, roi_items.HorizontalLineROI):
+ title = "Horizontal slice"
+ xLabel = "{xlabel} index"
+ elif isinstance(self, roi_items.VerticalLineROI):
+ title = "Vertical slice"
+ xLabel = "{ylabel} index"
+ else:
+ assert False
+
+ # Use the axis names from the original plot
+ profileManager = self.getProfileManager()
+ plot = profileManager.getPlotWidget()
+ xLabel = _relabelAxes(plot, xLabel)
+
+ data = core.CurveProfileData(
+ coords=numpy.arange(len(profile)),
+ profile=profile,
+ title=title,
+ xLabel=xLabel,
+ yLabel="Profile",
+ )
+ return data
+
+
+class ProfileScatterHorizontalSliceROI(roi_items.HorizontalLineROI,
+ _DefaultScatterProfileSliceRoiMixIn):
+ """ROI for an horizontal profile at a location of a scatter
+ using data slicing.
+ """
+
+ ICON = 'slice-horizontal'
+ NAME = 'horizontal data slice profile'
+
+ def __init__(self, parent=None):
+ roi_items.HorizontalLineROI.__init__(self, parent=parent)
+ _DefaultScatterProfileSliceRoiMixIn.__init__(self, parent=parent)
+
+
+class ProfileScatterVerticalSliceROI(roi_items.VerticalLineROI,
+ _DefaultScatterProfileSliceRoiMixIn):
+ """ROI for a vertical profile at a location of a scatter
+ using data slicing.
+ """
+
+ ICON = 'slice-vertical'
+ NAME = 'vertical data slice profile'
+
+ def __init__(self, parent=None):
+ roi_items.VerticalLineROI.__init__(self, parent=parent)
+ _DefaultScatterProfileSliceRoiMixIn.__init__(self, parent=parent)
+
+
+class ProfileScatterCrossSliceROI(_ProfileCrossROI):
+ """ROI to manage a cross of slicing profiles on scatters.
+ """
+
+ ICON = 'slice-cross'
+ NAME = 'cross data slice profile'
+ ITEM_KIND = items.Scatter
+
+ def _createLines(self, parent):
+ vline = ProfileScatterVerticalSliceROI(parent=parent)
+ hline = ProfileScatterHorizontalSliceROI(parent=parent)
+ return hline, vline
+
+
+class _DefaultImageStackProfileRoiMixIn(_DefaultImageProfileRoiMixIn):
+
+ ITEM_KIND = items.ImageStack
+
+ def __init__(self, parent=None):
+ super(_DefaultImageStackProfileRoiMixIn, self).__init__(parent=parent)
+ self.__profileType = "1D"
+ """Kind of profile"""
+
+ def getProfileType(self):
+ return self.__profileType
+
+ def setProfileType(self, kind):
+ assert kind in ["1D", "2D"]
+ if self.__profileType == kind:
+ return
+ self.__profileType = kind
+ self.invalidateProperties()
+ self.invalidateProfile()
+
+ def computeProfile(self, item):
+ if not isinstance(item, items.ImageStack):
+ raise TypeError("Unexpected class %s" % type(item))
+
+ kind = self.getProfileType()
+ if kind == "1D":
+ result = _DefaultImageProfileRoiMixIn.computeProfile(self, item)
+ # z = item.getStackPosition()
+ return result
+
+ assert kind == "2D"
+
+ def createProfile2(currentData):
+ coords, profile, _area, profileName, xLabel = core.createProfile(
+ roiInfo=self._getRoiInfo(),
+ currentData=currentData,
+ origin=origin,
+ scale=scale,
+ lineWidth=self.getProfileLineWidth(),
+ method=method)
+ return coords, profile, profileName, xLabel
+
+ currentData = numpy.array(item.getStackData(copy=False))
+ origin = item.getOrigin()
+ scale = item.getScale()
+ colormap = item.getColormap()
+ method = self.getProfileMethod()
+
+ coords, profile, profileName, xLabel = createProfile2(currentData)
+
+ data = core.ImageProfileData(
+ coords=coords,
+ profile=profile,
+ title=profileName,
+ xLabel=xLabel,
+ yLabel="Profile",
+ colormap=colormap,
+ )
+ return data
+
+
+class ProfileImageStackHorizontalLineROI(roi_items.HorizontalLineROI,
+ _DefaultImageStackProfileRoiMixIn):
+ """ROI for an horizontal profile at a location of a stack of images"""
+
+ ICON = 'shape-horizontal'
+ NAME = 'horizontal line profile'
+
+ def __init__(self, parent=None):
+ roi_items.HorizontalLineROI.__init__(self, parent=parent)
+ _DefaultImageStackProfileRoiMixIn.__init__(self, parent=parent)
+
+
+class ProfileImageStackVerticalLineROI(roi_items.VerticalLineROI,
+ _DefaultImageStackProfileRoiMixIn):
+ """ROI for an vertical profile at a location of a stack of images"""
+
+ ICON = 'shape-vertical'
+ NAME = 'vertical line profile'
+
+ def __init__(self, parent=None):
+ roi_items.VerticalLineROI.__init__(self, parent=parent)
+ _DefaultImageStackProfileRoiMixIn.__init__(self, parent=parent)
+
+
+class ProfileImageStackLineROI(roi_items.LineROI,
+ _DefaultImageStackProfileRoiMixIn):
+ """ROI for an vertical profile at a location of a stack of images"""
+
+ ICON = 'shape-diagonal'
+ NAME = 'line profile'
+
+ def __init__(self, parent=None):
+ roi_items.LineROI.__init__(self, parent=parent)
+ _DefaultImageStackProfileRoiMixIn.__init__(self, parent=parent)
+
+
+class ProfileImageStackCrossROI(ProfileImageCrossROI):
+ """ROI for an vertical profile at a location of a stack of images"""
+
+ ICON = 'shape-cross'
+ NAME = 'cross profile'
+ ITEM_KIND = items.ImageStack
+
+ def _createLines(self, parent):
+ vline = ProfileImageStackVerticalLineROI(parent=parent)
+ hline = ProfileImageStackHorizontalLineROI(parent=parent)
+ return hline, vline
+
+ def getProfileType(self):
+ hline, _vline = self._getLines()
+ return hline.getProfileType()
+
+ def setProfileType(self, kind):
+ hline, vline = self._getLines()
+ hline.setProfileType(kind)
+ vline.setProfileType(kind)
+ self.invalidateProperties()
diff --git a/src/silx/gui/plot/tools/profile/toolbar.py b/src/silx/gui/plot/tools/profile/toolbar.py
new file mode 100644
index 0000000..4a9a195
--- /dev/null
+++ b/src/silx/gui/plot/tools/profile/toolbar.py
@@ -0,0 +1,172 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides tool bar helper.
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "28/06/2018"
+
+
+import logging
+import weakref
+
+from silx.gui import qt
+from silx.gui.widgets.MultiModeAction import MultiModeAction
+from . import manager
+from .. import roi as roi_mdl
+from silx.gui.plot import items
+
+
+_logger = logging.getLogger(__name__)
+
+
+class ProfileToolBar(qt.QToolBar):
+ """Tool bar to provide profile for a plot.
+
+ It is an helper class. For a dedicated application it would be better to
+ use an own tool bar in order in order have more flexibility.
+ """
+ def __init__(self, parent=None, plot=None):
+ super(ProfileToolBar, self).__init__(parent=parent)
+ self.__scheme = None
+ self.__manager = None
+ self.__plot = weakref.ref(plot)
+ self.__multiAction = None
+
+ def getPlotWidget(self):
+ """The :class:`~silx.gui.plot.PlotWidget` associated to the toolbar.
+
+ :rtype: Union[~silx.gui.plot.PlotWidget,None]
+ """
+ if self.__plot is None:
+ return None
+ plot = self.__plot()
+ if self.__plot is None:
+ self.__plot = None
+ return plot
+
+ def setScheme(self, scheme):
+ """Initialize the tool bar using a configuration scheme.
+
+ It have to be done once and only once.
+
+ :param str scheme: One of "scatter", "image", "imagestack"
+ """
+ assert self.__scheme is None
+ self.__scheme = scheme
+
+ plot = self.getPlotWidget()
+ self.__manager = manager.ProfileManager(self, plot)
+
+ if scheme == "image":
+ self.__manager.setItemType(image=True)
+ self.__manager.setActiveItemTracking(True)
+
+ multiAction = MultiModeAction(self)
+ self.addAction(multiAction)
+ for action in self.__manager.createImageActions(self):
+ multiAction.addAction(action)
+ self.__multiAction = multiAction
+
+ cleanAction = self.__manager.createClearAction(self)
+ self.addAction(cleanAction)
+ editorAction = self.__manager.createEditorAction(self)
+ self.addAction(editorAction)
+
+ plot.sigActiveImageChanged.connect(self._activeImageChanged)
+ self._activeImageChanged()
+
+ elif scheme == "scatter":
+ self.__manager.setItemType(scatter=True)
+ self.__manager.setActiveItemTracking(True)
+
+ multiAction = MultiModeAction(self)
+ self.addAction(multiAction)
+ for action in self.__manager.createScatterActions(self):
+ multiAction.addAction(action)
+ for action in self.__manager.createScatterSliceActions(self):
+ multiAction.addAction(action)
+ self.__multiAction = multiAction
+
+ cleanAction = self.__manager.createClearAction(self)
+ self.addAction(cleanAction)
+ editorAction = self.__manager.createEditorAction(self)
+ self.addAction(editorAction)
+
+ plot.sigActiveScatterChanged.connect(self._activeScatterChanged)
+ self._activeScatterChanged()
+
+ elif scheme == "imagestack":
+ self.__manager.setItemType(image=True)
+ self.__manager.setActiveItemTracking(True)
+
+ multiAction = MultiModeAction(self)
+ self.addAction(multiAction)
+ for action in self.__manager.createImageStackActions(self):
+ multiAction.addAction(action)
+ self.__multiAction = multiAction
+
+ cleanAction = self.__manager.createClearAction(self)
+ self.addAction(cleanAction)
+ editorAction = self.__manager.createEditorAction(self)
+ self.addAction(editorAction)
+
+ plot.sigActiveImageChanged.connect(self._activeImageChanged)
+ self._activeImageChanged()
+
+ else:
+ raise ValueError("Toolbar scheme %s unsupported" % scheme)
+
+ def _setRoiActionEnabled(self, itemKind, enabled):
+ for action in self.__multiAction.getMenu().actions():
+ if not isinstance(action, roi_mdl.CreateRoiModeAction):
+ continue
+ roiClass = action.getRoiClass()
+ if issubclass(itemKind, roiClass.ITEM_KIND):
+ action.setEnabled(enabled)
+
+ def _activeImageChanged(self, previous=None, legend=None):
+ """Handle active image change to toggle actions"""
+ if legend is None:
+ self._setRoiActionEnabled(items.ImageStack, False)
+ self._setRoiActionEnabled(items.ImageBase, False)
+ else:
+ plot = self.getPlotWidget()
+ image = plot.getActiveImage()
+ # Disable for empty image
+ enabled = image.getData(copy=False).size > 0
+ self._setRoiActionEnabled(type(image), enabled)
+
+ def _activeScatterChanged(self, previous=None, legend=None):
+ """Handle active scatter change to toggle actions"""
+ if legend is None:
+ self._setRoiActionEnabled(items.Scatter, False)
+ else:
+ plot = self.getPlotWidget()
+ scatter = plot.getActiveScatter()
+ # Disable for empty image
+ enabled = scatter.getValueData(copy=False).size > 0
+ self._setRoiActionEnabled(type(scatter), enabled)
diff --git a/src/silx/gui/plot/tools/roi.py b/src/silx/gui/plot/tools/roi.py
new file mode 100644
index 0000000..e4be6a7
--- /dev/null
+++ b/src/silx/gui/plot/tools/roi.py
@@ -0,0 +1,1417 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides ROI interaction for :class:`~silx.gui.plot.PlotWidget`.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "28/06/2018"
+
+
+import enum
+import logging
+import time
+import weakref
+import functools
+
+import numpy
+
+from ... import qt, icons
+from ...utils import blockSignals
+from ...utils import LockReentrant
+from .. import PlotWidget
+from ..items import roi as roi_items
+
+from ...colors import rgba
+
+
+logger = logging.getLogger(__name__)
+
+
+class CreateRoiModeAction(qt.QAction):
+ """
+ This action is a plot mode which allows to create new ROIs using a ROI
+ manager.
+
+ A ROI is created using a specific `roiClass`. `initRoi` and `finalizeRoi`
+ can be inherited to custom the ROI initialization.
+
+ :param class roiClass: The ROI class which will be created by this action.
+ :param qt.QObject parent: The action parent
+ :param RegionOfInterestManager roiManager: The ROI manager
+ """
+
+ def __init__(self, parent, roiManager, roiClass):
+ assert roiManager is not None
+ assert roiClass is not None
+ qt.QAction.__init__(self, parent=parent)
+ self._roiManager = weakref.ref(roiManager)
+ self._roiClass = roiClass
+ self._singleShot = False
+ self._initAction()
+ self.triggered[bool].connect(self._actionTriggered)
+
+ def _initAction(self):
+ """Default initialization of the action"""
+ roiClass = self._roiClass
+
+ name = None
+ iconName = None
+ if hasattr(roiClass, "NAME"):
+ name = roiClass.NAME
+ if hasattr(roiClass, "ICON"):
+ iconName = roiClass.ICON
+
+ if iconName is None:
+ iconName = "add-shape-unknown"
+ if name is None:
+ name = roiClass.__name__
+ text = 'Add %s' % name
+ self.setIcon(icons.getQIcon(iconName))
+ self.setText(text)
+ self.setCheckable(True)
+ self.setToolTip(text)
+
+ def getRoiClass(self):
+ """Return the ROI class used by this action to create ROIs"""
+ return self._roiClass
+
+ def getRoiManager(self):
+ return self._roiManager()
+
+ def setSingleShot(self, singleShot):
+ """Set it to True to deactivate the action after the first creation
+ of a ROI.
+
+ :param bool singleShot: New single short state
+ """
+ self._singleShot = singleShot
+
+ def getSingleShot(self):
+ """If True, after the first creation of a ROI with this mode,
+ the mode is deactivated.
+
+ :rtype: bool
+ """
+ return self._singleShot
+
+ def _actionTriggered(self, checked):
+ """Handle mode actions being checked by the user
+
+ :param bool checked:
+ :param str kind: Corresponding shape kind
+ """
+ roiManager = self.getRoiManager()
+ if roiManager is None:
+ return
+
+ if checked:
+ roiManager.start(self._roiClass, self)
+ self.__interactiveModeStarted(roiManager)
+ else:
+ source = roiManager.getInteractionSource()
+ if source is self:
+ roiManager.stop()
+
+ def __interactiveModeStarted(self, roiManager):
+ roiManager.sigInteractiveRoiCreated.connect(self.initRoi)
+ roiManager.sigInteractiveRoiFinalized.connect(self.__finalizeRoi)
+ roiManager.sigInteractiveModeFinished.connect(self.__interactiveModeFinished)
+
+ def __interactiveModeFinished(self):
+ roiManager = self.getRoiManager()
+ if roiManager is not None:
+ roiManager.sigInteractiveRoiCreated.disconnect(self.initRoi)
+ roiManager.sigInteractiveRoiFinalized.disconnect(self.__finalizeRoi)
+ roiManager.sigInteractiveModeFinished.disconnect(self.__interactiveModeFinished)
+ self.setChecked(False)
+
+ def initRoi(self, roi):
+ """Inherit it to custom the new ROI at it's creation during the
+ interaction."""
+ pass
+
+ def __finalizeRoi(self, roi):
+ self.finalizeRoi(roi)
+ if self._singleShot:
+ roiManager = self.getRoiManager()
+ if roiManager is not None:
+ roiManager.stop()
+
+ def finalizeRoi(self, roi):
+ """Inherit it to custom the new ROI after it's creation when the
+ interaction is finalized."""
+ pass
+
+
+class RoiModeSelector(qt.QWidget):
+ def __init__(self, parent=None):
+ super(RoiModeSelector, self).__init__(parent=parent)
+ self.__roi = None
+ self.__reentrant = LockReentrant()
+
+ layout = qt.QHBoxLayout(self)
+ if isinstance(parent, qt.QMenu):
+ margins = layout.contentsMargins()
+ layout.setContentsMargins(margins.left(), 0, margins.right(), 0)
+ else:
+ layout.setContentsMargins(0, 0, 0, 0)
+
+ self._label = qt.QLabel(self)
+ self._label.setText("Mode:")
+ self._label.setToolTip("Select a specific interaction to edit the ROI")
+ self._combo = qt.QComboBox(self)
+ self._combo.currentIndexChanged.connect(self._modeSelected)
+ layout.addWidget(self._label)
+ layout.addWidget(self._combo)
+ self._updateAvailableModes()
+
+ def getRoi(self):
+ """Returns the edited ROI.
+
+ :rtype: roi_items.RegionOfInterest
+ """
+ return self.__roi
+
+ def setRoi(self, roi):
+ """Returns the edited ROI.
+
+ :rtype: roi_items.RegionOfInterest
+ """
+ if self.__roi is roi:
+ return
+ if not isinstance(roi, roi_items.InteractionModeMixIn):
+ self.__roi = None
+ self._updateAvailableModes()
+ return
+
+ if self.__roi is not None:
+ self.__roi.sigInteractionModeChanged.disconnect(self._modeChanged)
+ self.__roi = roi
+ if self.__roi is not None:
+ self.__roi.sigInteractionModeChanged.connect(self._modeChanged)
+ self._updateAvailableModes()
+
+ def isEmpty(self):
+ return not self._label.isVisibleTo(self)
+
+ def _updateAvailableModes(self):
+ roi = self.getRoi()
+ if isinstance(roi, roi_items.InteractionModeMixIn):
+ modes = roi.availableInteractionModes()
+ else:
+ modes = []
+ if len(modes) <= 1:
+ self._label.setVisible(False)
+ self._combo.setVisible(False)
+ else:
+ self._label.setVisible(True)
+ self._combo.setVisible(True)
+ with blockSignals(self._combo):
+ self._combo.clear()
+ for im, m in enumerate(modes):
+ self._combo.addItem(m.label, m)
+ self._combo.setItemData(im, m.description, qt.Qt.ToolTipRole)
+ mode = roi.getInteractionMode()
+ self._modeChanged(mode)
+ index = modes.index(mode)
+ self._combo.setCurrentIndex(index)
+
+ def _modeChanged(self, mode):
+ """Triggered when the ROI interaction mode was changed externally"""
+ if self.__reentrant.locked():
+ # This event was initialised by the widget
+ return
+ roi = self.__roi
+ modes = roi.availableInteractionModes()
+ index = modes.index(mode)
+ with blockSignals(self._combo):
+ self._combo.setCurrentIndex(index)
+
+ def _modeSelected(self):
+ """Triggered when the ROI interaction mode was selected in the widget"""
+ index = self._combo.currentIndex()
+ if index == -1:
+ return
+ roi = self.getRoi()
+ if roi is not None:
+ mode = self._combo.itemData(index, qt.Qt.UserRole)
+ with self.__reentrant:
+ roi.setInteractionMode(mode)
+
+
+class RoiModeSelectorAction(qt.QWidgetAction):
+ """Display the selected mode of a ROI and allow to change it"""
+
+ def __init__(self, parent=None):
+ super(RoiModeSelectorAction, self).__init__(parent)
+ self.__roiManager = None
+
+ def createWidget(self, parent):
+ """Inherit the method to create a new widget"""
+ widget = RoiModeSelector(parent)
+ manager = self.__roiManager
+ if manager is not None:
+ roi = manager.getCurrentRoi()
+ widget.setRoi(roi)
+ self.setVisible(not widget.isEmpty())
+ return widget
+
+ def deleteWidget(self, widget):
+ """Inherit the method to delete a widget"""
+ widget.setRoi(None)
+ return qt.QWidgetAction.deleteWidget(self, widget)
+
+ def setRoiManager(self, roiManager):
+ """
+ Connect this action to a ROI manager.
+
+ :param RegionOfInterestManager roiManager: A ROI manager
+ """
+ if self.__roiManager is roiManager:
+ return
+ if self.__roiManager is not None:
+ self.__roiManager.sigCurrentRoiChanged.disconnect(self.__currentRoiChanged)
+ self.__roiManager = roiManager
+ if self.__roiManager is not None:
+ self.__roiManager.sigCurrentRoiChanged.connect(self.__currentRoiChanged)
+ self.__currentRoiChanged(roiManager.getCurrentRoi())
+
+ def __currentRoiChanged(self, roi):
+ """Handle changes of the selected ROI"""
+ self.setRoi(roi)
+
+ def setRoi(self, roi):
+ """Set a profile ROI to edit.
+
+ :param ProfileRoiMixIn roi: A profile ROI
+ """
+ widget = None
+ for widget in self.createdWidgets():
+ widget.setRoi(roi)
+ if widget is not None:
+ self.setVisible(not widget.isEmpty())
+
+
+class RegionOfInterestManager(qt.QObject):
+ """Class handling ROI interaction on a PlotWidget.
+
+ It supports the multiple ROIs: points, rectangles, polygons,
+ lines, horizontal and vertical lines.
+
+ See ``plotInteractiveImageROI.py`` sample code (:ref:`sample-code`).
+
+ :param silx.gui.plot.PlotWidget parent:
+ The plot widget in which to control the ROIs.
+ """
+
+ sigRoiAdded = qt.Signal(roi_items.RegionOfInterest)
+ """Signal emitted when a new ROI has been added.
+
+ It provides the newly add :class:`RegionOfInterest` object.
+ """
+
+ sigRoiAboutToBeRemoved = qt.Signal(roi_items.RegionOfInterest)
+ """Signal emitted just before a ROI is removed.
+
+ It provides the :class:`RegionOfInterest` object that is about to be removed.
+ """
+
+ sigRoiChanged = qt.Signal()
+ """Signal emitted whenever the ROIs have changed."""
+
+ sigCurrentRoiChanged = qt.Signal(object)
+ """Signal emitted whenever a ROI is selected."""
+
+ sigInteractiveModeStarted = qt.Signal(object)
+ """Signal emitted when switching to ROI drawing interactive mode.
+
+ It provides the class of the ROI which will be created by the interactive
+ mode.
+ """
+
+ sigInteractiveRoiCreated = qt.Signal(object)
+ """Signal emitted when a ROI is created during the interaction.
+ The interaction is still incomplete and can be aborted.
+
+ It provides the ROI object which was just been created.
+ """
+
+ sigInteractiveRoiFinalized = qt.Signal(object)
+ """Signal emitted when a ROI creation is complet.
+
+ It provides the ROI object which was just been created.
+ """
+
+ sigInteractiveModeFinished = qt.Signal()
+ """Signal emitted when leaving interactive ROI drawing mode.
+ """
+
+ ROI_CLASSES = (
+ roi_items.PointROI,
+ roi_items.CrossROI,
+ roi_items.RectangleROI,
+ roi_items.CircleROI,
+ roi_items.EllipseROI,
+ roi_items.PolygonROI,
+ roi_items.LineROI,
+ roi_items.HorizontalLineROI,
+ roi_items.VerticalLineROI,
+ roi_items.ArcROI,
+ roi_items.HorizontalRangeROI,
+ )
+
+ def __init__(self, parent):
+ assert isinstance(parent, PlotWidget)
+ super(RegionOfInterestManager, self).__init__(parent)
+ self._rois = [] # List of ROIs
+ self._drawnROI = None # New ROI being currently drawn
+
+ self._roiClass = None
+ self._source = None
+ self._color = rgba('red')
+
+ self._label = "__RegionOfInterestManager__%d" % id(self)
+
+ self._currentRoi = None
+ """Hold currently selected ROI"""
+
+ self._eventLoop = None
+
+ self._modeActions = {}
+
+ parent.sigPlotSignal.connect(self._plotSignals)
+
+ parent.sigInteractiveModeChanged.connect(
+ self._plotInteractiveModeChanged)
+
+ parent.sigItemRemoved.connect(self._itemRemoved)
+
+ parent._sigDefaultContextMenu.connect(self._feedContextMenu)
+
+ @classmethod
+ def getSupportedRoiClasses(cls):
+ """Returns the default available ROI classes
+
+ :rtype: List[class]
+ """
+ return tuple(cls.ROI_CLASSES)
+
+ # Associated QActions
+
+ def getInteractionModeAction(self, roiClass):
+ """Returns the QAction corresponding to a kind of ROI
+
+ The QAction allows to enable the corresponding drawing
+ interactive mode.
+
+ :param class roiClass: The ROI class which will be created by this action.
+ :rtype: QAction
+ :raise ValueError: If kind is not supported
+ """
+ if not issubclass(roiClass, roi_items.RegionOfInterest):
+ raise ValueError('Unsupported ROI class %s' % roiClass)
+
+ action = self._modeActions.get(roiClass, None)
+ if action is None: # Lazy-loading
+ action = CreateRoiModeAction(self, self, roiClass)
+ self._modeActions[roiClass] = action
+ return action
+
+ # PlotWidget eventFilter and listeners
+
+ def _plotInteractiveModeChanged(self, source):
+ """Handle change of interactive mode in the plot"""
+ if source is not self:
+ self.__roiInteractiveModeEnded()
+
+ def _getRoiFromItem(self, item):
+ """Returns the ROI which own this item, else None
+ if this manager do not have knowledge of this ROI."""
+ for roi in self._rois:
+ if isinstance(roi, roi_items.RegionOfInterest):
+ for child in roi.getItems():
+ if child is item:
+ return roi
+ return None
+
+ def _itemRemoved(self, item):
+ """Called after an item was removed from the plot."""
+ if not hasattr(item, "_roiGroup"):
+ # Early break to avoid to use _getRoiFromItem
+ # And to avoid reentrant signal when the ROI remove the item itself
+ return
+ roi = self._getRoiFromItem(item)
+ if roi is not None:
+ self.removeRoi(roi)
+
+ # Handle ROI interaction
+
+ def _handleInteraction(self, event):
+ """Handle mouse interaction for ROI addition"""
+ roiClass = self.getCurrentInteractionModeRoiClass()
+ if roiClass is None:
+ return # Should not happen
+
+ kind = roiClass.getFirstInteractionShape()
+ if kind == 'point':
+ if event['event'] == 'mouseClicked' and event['button'] == 'left':
+ points = numpy.array([(event['x'], event['y'])],
+ dtype=numpy.float64)
+ # Not an interactive creation
+ roi = self._createInteractiveRoi(roiClass, points=points)
+ roi.creationFinalized()
+ self.sigInteractiveRoiFinalized.emit(roi)
+ else: # other shapes
+ if (event['event'] in ('drawingProgress', 'drawingFinished') and
+ event['parameters']['label'] == self._label):
+ points = numpy.array((event['xdata'], event['ydata']),
+ dtype=numpy.float64).T
+
+ if self._drawnROI is None: # Create new ROI
+ # NOTE: Set something before createRoi, so isDrawing is True
+ self._drawnROI = object()
+ self._drawnROI = self._createInteractiveRoi(roiClass, points=points)
+ else:
+ self._drawnROI.setFirstShapePoints(points)
+
+ if event['event'] == 'drawingFinished':
+ if kind == 'polygon' and len(points) > 1:
+ self._drawnROI.setFirstShapePoints(points[:-1])
+ roi = self._drawnROI
+ self._drawnROI = None # Stop drawing
+ roi.creationFinalized()
+ self.sigInteractiveRoiFinalized.emit(roi)
+
+ # RegionOfInterest selection
+
+ def __getRoiFromMarker(self, marker):
+ """Returns a ROI from a marker, else None"""
+ # This should be speed up
+ for roi in self._rois:
+ if isinstance(roi, roi_items.HandleBasedROI):
+ for m in roi.getHandles():
+ if m is marker:
+ return roi
+ else:
+ for m in roi.getItems():
+ if m is marker:
+ return roi
+ return None
+
+ def setCurrentRoi(self, roi):
+ """Set the currently selected ROI, and emit a signal.
+
+ :param Union[RegionOfInterest,None] roi: The ROI to select
+ """
+ if self._currentRoi is roi:
+ return
+ if roi is not None:
+ # Note: Fixed range to avoid infinite loops
+ for _ in range(10):
+ target = roi.getFocusProxy()
+ if target is None:
+ break
+ roi = target
+ else:
+ raise RuntimeError("Max selection proxy depth (10) reached.")
+
+ if self._currentRoi is not None:
+ self._currentRoi.setHighlighted(False)
+ self._currentRoi = roi
+ if self._currentRoi is not None:
+ self._currentRoi.setHighlighted(True)
+ self.sigCurrentRoiChanged.emit(roi)
+
+ def getCurrentRoi(self):
+ """Returns the currently selected ROI, else None.
+
+ :rtype: Union[RegionOfInterest,None]
+ """
+ return self._currentRoi
+
+ def _plotSignals(self, event):
+ """Handle mouse interaction for ROI addition"""
+ clicked = False
+ roi = None
+ if event["event"] in ("markerClicked", "markerMoving"):
+ plot = self.parent()
+ legend = event["label"]
+ marker = plot._getMarker(legend=legend)
+ roi = self.__getRoiFromMarker(marker)
+ elif event["event"] == "mouseClicked" and event["button"] == "left":
+ # Marker click is only for dnd
+ # This also can click on a marker
+ clicked = True
+ plot = self.parent()
+ marker = plot._getMarkerAt(event["xpixel"], event["ypixel"])
+ roi = self.__getRoiFromMarker(marker)
+ else:
+ return
+
+ if roi not in self._rois:
+ # The ROI is not own by this manager
+ return
+
+ if roi is not None:
+ currentRoi = self.getCurrentRoi()
+ if currentRoi is roi:
+ if clicked:
+ self.__updateMode(roi)
+ elif roi.isSelectable():
+ self.setCurrentRoi(roi)
+ else:
+ self.setCurrentRoi(None)
+
+ def __updateMode(self, roi):
+ if isinstance(roi, roi_items.InteractionModeMixIn):
+ available = roi.availableInteractionModes()
+ mode = roi.getInteractionMode()
+ imode = available.index(mode)
+ mode = available[(imode + 1) % len(available)]
+ roi.setInteractionMode(mode)
+
+ def _feedContextMenu(self, menu):
+ """Called when the default plot context menu is about to be displayed"""
+ roi = self.getCurrentRoi()
+ if roi is not None:
+ if roi.isEditable():
+ # Filter by data position
+ # FIXME: It would be better to use GUI coords for it
+ plot = self.parent()
+ pos = plot.getWidgetHandle().mapFromGlobal(qt.QCursor.pos())
+ data = plot.pixelToData(pos.x(), pos.y())
+ if roi.contains(data):
+ if isinstance(roi, roi_items.InteractionModeMixIn):
+ self._contextMenuForInteractionMode(menu, roi)
+
+ removeAction = qt.QAction(menu)
+ removeAction.setText("Remove %s" % roi.getName())
+ callback = functools.partial(self.removeRoi, roi)
+ removeAction.triggered.connect(callback)
+ menu.addAction(removeAction)
+
+ def _contextMenuForInteractionMode(self, menu, roi):
+ availableModes = roi.availableInteractionModes()
+ currentMode = roi.getInteractionMode()
+ submenu = qt.QMenu(menu)
+ modeGroup = qt.QActionGroup(menu)
+ modeGroup.setExclusive(True)
+ for mode in availableModes:
+ action = qt.QAction(menu)
+ action.setText(mode.label)
+ action.setToolTip(mode.description)
+ action.setCheckable(True)
+ if mode is currentMode:
+ action.setChecked(True)
+ else:
+ callback = functools.partial(roi.setInteractionMode, mode)
+ action.triggered.connect(callback)
+ modeGroup.addAction(action)
+ submenu.addAction(action)
+ submenu.setTitle("%s interaction mode" % roi.getName())
+ menu.addMenu(submenu)
+
+ # RegionOfInterest API
+
+ def getRois(self):
+ """Returns the list of ROIs.
+
+ It returns an empty tuple if there is currently no ROI.
+
+ :return: Tuple of arrays of objects describing the ROIs
+ :rtype: List[RegionOfInterest]
+ """
+ return tuple(self._rois)
+
+ def clear(self):
+ """Reset current ROIs
+
+ :return: True if ROIs were reset.
+ :rtype: bool
+ """
+ if self.getRois(): # Something to reset
+ for roi in self._rois:
+ roi.sigRegionChanged.disconnect(
+ self._regionOfInterestChanged)
+ roi.setParent(None)
+ self._rois = []
+ self._roisUpdated()
+ return True
+
+ else:
+ return False
+
+ def _regionOfInterestChanged(self, event=None):
+ """Handle ROI object changed"""
+ self.sigRoiChanged.emit()
+
+ def _createInteractiveRoi(self, roiClass, points, label=None, index=None):
+ """Create a new ROI with interactive creation.
+
+ :param class roiClass: The class of the ROI to create
+ :param numpy.ndarray points: The first shape used to create the ROI
+ :param str label: The label to display along with the ROI.
+ :param int index: The position where to insert the ROI.
+ By default it is appended to the end of the list.
+ :return: The created ROI object
+ :rtype: roi_items.RegionOfInterest
+ :raise RuntimeError: When ROI cannot be added because the maximum
+ number of ROIs has been reached.
+ """
+ roi = roiClass(parent=None)
+ if label is not None:
+ roi.setName(str(label))
+ roi.creationStarted()
+ roi.setFirstShapePoints(points)
+
+ self.addRoi(roi, index)
+ if roi.isSelectable():
+ self.setCurrentRoi(roi)
+ self.sigInteractiveRoiCreated.emit(roi)
+ return roi
+
+ def containsRoi(self, roi):
+ """Returns true if the ROI is part of this manager.
+
+ :param roi_items.RegionOfInterest roi: The ROI to add
+ :rtype: bool
+ """
+ return roi in self._rois
+
+ def addRoi(self, roi, index=None, useManagerColor=True):
+ """Add the ROI to the list of ROIs.
+
+ :param roi_items.RegionOfInterest roi: The ROI to add
+ :param int index: The position where to insert the ROI,
+ By default it is appended to the end of the list of ROIs
+ :param bool useManagerColor:
+ Whether to set the ROI color to the default one of the manager or not.
+ (Default: True).
+ :raise RuntimeError: When ROI cannot be added because the maximum
+ number of ROIs has been reached.
+ """
+ plot = self.parent()
+ if plot is None:
+ raise RuntimeError(
+ 'Cannot add ROI: PlotWidget no more available')
+
+ roi.setParent(self)
+
+ if useManagerColor:
+ roi.setColor(self.getColor())
+
+ roi.sigRegionChanged.connect(self._regionOfInterestChanged)
+ roi.sigItemChanged.connect(self._regionOfInterestChanged)
+
+ if index is None:
+ self._rois.append(roi)
+ else:
+ self._rois.insert(index, roi)
+ self.sigRoiAdded.emit(roi)
+ self._roisUpdated()
+
+ def removeRoi(self, roi):
+ """Remove a ROI from the list of ROIs.
+
+ :param roi_items.RegionOfInterest roi: The ROI to remove
+ :raise ValueError: When ROI does not belong to this object
+ """
+ if not (isinstance(roi, roi_items.RegionOfInterest) and
+ roi.parent() is self and
+ roi in self._rois):
+ raise ValueError(
+ 'RegionOfInterest does not belong to this instance')
+
+ roi.sigAboutToBeRemoved.emit()
+ self.sigRoiAboutToBeRemoved.emit(roi)
+
+ if roi is self._currentRoi:
+ self.setCurrentRoi(None)
+
+ mustRestart = False
+ if roi is self._drawnROI:
+ self._drawnROI = None
+ mustRestart = True
+ self._rois.remove(roi)
+ roi.sigRegionChanged.disconnect(self._regionOfInterestChanged)
+ roi.sigItemChanged.disconnect(self._regionOfInterestChanged)
+ roi.setParent(None)
+ self._roisUpdated()
+
+ if mustRestart:
+ self._restart()
+
+ def _roisUpdated(self):
+ """Handle update of the ROI list"""
+ self.sigRoiChanged.emit()
+
+ # RegionOfInterest parameters
+
+ def getColor(self):
+ """Return the default color of created ROIs
+
+ :rtype: QColor
+ """
+ return qt.QColor.fromRgbF(*self._color)
+
+ def setColor(self, color):
+ """Set the default color to use when creating ROIs.
+
+ Existing ROIs are not affected.
+
+ :param color: The color to use for displaying ROIs as
+ either a color name, a QColor, a list of uint8 or float in [0, 1].
+ """
+ self._color = rgba(color)
+
+ # Control ROI
+
+ def getCurrentInteractionModeRoiClass(self):
+ """Returns the current ROI class used by the interactive drawing mode.
+
+ Returns None if the ROI manager is not in an interactive mode.
+
+ :rtype: Union[class,None]
+ """
+ return self._roiClass
+
+ def getInteractionSource(self):
+ """Returns the object which have requested the ROI creation.
+
+ Returns None if the ROI manager is not in an interactive mode.
+
+ :rtype: Union[object,None]
+ """
+ return self._source
+
+ def isStarted(self):
+ """Returns True if an interactive ROI drawing mode is active.
+
+ :rtype: bool
+ """
+ return self._roiClass is not None
+
+ def isDrawing(self):
+ """Returns True if an interactive ROI is drawing.
+
+ :rtype: bool
+ """
+ return self._drawnROI is not None
+
+ def start(self, roiClass, source=None):
+ """Start an interactive ROI drawing mode.
+
+ :param class roiClass: The ROI class to create. It have to inherite from
+ `roi_items.RegionOfInterest`.
+ :param object source: SOurce of the ROI interaction.
+ :return: True if interactive ROI drawing was started, False otherwise
+ :rtype: bool
+ :raise ValueError: If roiClass is not supported
+ """
+ self.stop()
+
+ if not issubclass(roiClass, roi_items.RegionOfInterest):
+ raise ValueError('Unsupported ROI class %s' % roiClass)
+
+ plot = self.parent()
+ if plot is None:
+ return False
+
+ self._roiClass = roiClass
+ self._source = source
+
+ self._restart()
+
+ plot.sigPlotSignal.connect(self._handleInteraction)
+
+ self.sigInteractiveModeStarted.emit(roiClass)
+
+ return True
+
+ def _restart(self):
+ """Restart the plot interaction without changing the
+ source or the ROI class.
+ """
+ roiClass = self._roiClass
+ plot = self.parent()
+ firstInteractionShapeKind = roiClass.getFirstInteractionShape()
+
+ if firstInteractionShapeKind == 'point':
+ plot.setInteractiveMode(mode='select', source=self)
+ else:
+ if roiClass.showFirstInteractionShape():
+ color = rgba(self.getColor())
+ else:
+ color = None
+ plot.setInteractiveMode(mode='select-draw',
+ source=self,
+ shape=firstInteractionShapeKind,
+ color=color,
+ label=self._label)
+
+ def __roiInteractiveModeEnded(self):
+ """Handle end of ROI draw interactive mode"""
+ if self.isStarted():
+ self._roiClass = None
+ self._source = None
+
+ if self._drawnROI is not None:
+ # Cancel ROI create
+ roi = self._drawnROI
+ self._drawnROI = None
+ self.removeRoi(roi)
+
+ plot = self.parent()
+ if plot is not None:
+ plot.sigPlotSignal.disconnect(self._handleInteraction)
+
+ self.sigInteractiveModeFinished.emit()
+
+ def stop(self):
+ """Stop interactive ROI drawing mode.
+
+ :return: True if an interactive ROI drawing mode was actually stopped
+ :rtype: bool
+ """
+ if not self.isStarted():
+ return False
+
+ plot = self.parent()
+ if plot is not None:
+ # This leads to call __roiInteractiveModeEnded through
+ # interactive mode changed signal
+ plot.resetInteractiveMode()
+ else: # Fallback
+ self.__roiInteractiveModeEnded()
+
+ return True
+
+ def exec(self, roiClass):
+ """Block until :meth:`quit` is called.
+
+ :param class kind: The class of the ROI which have to be created.
+ See `silx.gui.plot.items.roi`.
+ :return: The list of ROIs
+ :rtype: tuple
+ """
+ self.start(roiClass)
+
+ plot = self.parent()
+ plot.show()
+ plot.raise_()
+
+ self._eventLoop = qt.QEventLoop()
+ self._eventLoop.exec()
+ self._eventLoop = None
+
+ self.stop()
+
+ rois = self.getRois()
+ self.clear()
+ return rois
+
+ def exec_(self, roiClass): # Qt5-like compatibility
+ return self.exec(roiClass)
+
+ def quit(self):
+ """Stop a blocking :meth:`exec` and call :meth:`stop`"""
+ if self._eventLoop is not None:
+ self._eventLoop.quit()
+ self._eventLoop = None
+ self.stop()
+
+
+class InteractiveRegionOfInterestManager(RegionOfInterestManager):
+ """RegionOfInterestManager with features for use from interpreter.
+
+ It is meant to be used through the :meth:`exec`.
+ It provides some messages to display in a status bar and
+ different modes to end blocking calls to :meth:`exec`.
+
+ :param parent: See QObject
+ """
+
+ sigMessageChanged = qt.Signal(str)
+ """Signal emitted when a new message should be displayed to the user
+
+ It provides the message as a str.
+ """
+
+ def __init__(self, parent):
+ super(InteractiveRegionOfInterestManager, self).__init__(parent)
+ self._maxROI = None
+ self.__timeoutEndTime = None
+ self.__message = ''
+ self.__validationMode = self.ValidationMode.ENTER
+ self.__execClass = None
+
+ self.sigRoiAdded.connect(self.__added)
+ self.sigRoiAboutToBeRemoved.connect(self.__aboutToBeRemoved)
+ self.sigInteractiveModeStarted.connect(self.__started)
+ self.sigInteractiveModeFinished.connect(self.__finished)
+
+ # Max ROI
+
+ def getMaxRois(self):
+ """Returns the maximum number of ROIs or None if no limit.
+
+ :rtype: Union[int,None]
+ """
+ return self._maxROI
+
+ def setMaxRois(self, max_):
+ """Set the maximum number of ROIs.
+
+ :param Union[int,None] max_: The max limit or None for no limit.
+ :raise ValueError: If there is more ROIs than max value
+ """
+ if max_ is not None:
+ max_ = int(max_)
+ if max_ <= 0:
+ raise ValueError('Max limit must be strictly positive')
+
+ if len(self.getRois()) > max_:
+ raise ValueError(
+ 'Cannot set max limit: Already too many ROIs')
+
+ self._maxROI = max_
+
+ def isMaxRois(self):
+ """Returns True if the maximum number of ROIs is reached.
+
+ :rtype: bool
+ """
+ max_ = self.getMaxRois()
+ return max_ is not None and len(self.getRois()) >= max_
+
+ # Validation mode
+
+ @enum.unique
+ class ValidationMode(enum.Enum):
+ """Mode of validation to leave blocking :meth:`exec`"""
+
+ AUTO = 'auto'
+ """Automatically ends the interactive mode once
+ the user terminates the last ROI shape."""
+
+ ENTER = 'enter'
+ """Ends the interactive mode when the *Enter* key is pressed."""
+
+ AUTO_ENTER = 'auto_enter'
+ """Ends the interactive mode when reaching max ROIs or
+ when the *Enter* key is pressed.
+ """
+
+ NONE = 'none'
+ """Do not provide the user a way to end the interactive mode.
+
+ The end of :meth:`exec` is done through :meth:`quit` or timeout.
+ """
+
+ def getValidationMode(self):
+ """Returns the interactive mode validation in use.
+
+ :rtype: ValidationMode
+ """
+ return self.__validationMode
+
+ def setValidationMode(self, mode):
+ """Set the way to perform interactive mode validation.
+
+ See :class:`ValidationMode` enumeration for the supported
+ validation modes.
+
+ :param ValidationMode mode: The interactive mode validation to use.
+ """
+ assert isinstance(mode, self.ValidationMode)
+ if mode != self.__validationMode:
+ self.__validationMode = mode
+
+ if self.isExec():
+ if (self.isMaxRois() and self.getValidationMode() in
+ (self.ValidationMode.AUTO,
+ self.ValidationMode.AUTO_ENTER)):
+ self.quit()
+
+ self.__updateMessage()
+
+ def eventFilter(self, obj, event):
+ if event.type() == qt.QEvent.Hide:
+ self.quit()
+
+ if event.type() == qt.QEvent.KeyPress:
+ key = event.key()
+ if (key in (qt.Qt.Key_Return, qt.Qt.Key_Enter) and
+ self.getValidationMode() in (
+ self.ValidationMode.ENTER,
+ self.ValidationMode.AUTO_ENTER)):
+ # Stop on return key pressed
+ self.quit()
+ return True # Stop further handling of this keys
+
+ if (key in (qt.Qt.Key_Delete, qt.Qt.Key_Backspace) or (
+ key == qt.Qt.Key_Z and
+ event.modifiers() & qt.Qt.ControlModifier)):
+ rois = self.getRois()
+ if rois: # Something to undo
+ self.removeRoi(rois[-1])
+ # Stop further handling of keys if something was undone
+ return True
+
+ return super(InteractiveRegionOfInterestManager, self).eventFilter(obj, event)
+
+ # Message API
+
+ def getMessage(self):
+ """Returns the current status message.
+
+ This message is meant to be displayed in a status bar.
+
+ :rtype: str
+ """
+ if self.__timeoutEndTime is None:
+ return self.__message
+ else:
+ remaining = self.__timeoutEndTime - time.time()
+ return self.__message + (' - %d seconds remaining' %
+ max(1, int(remaining)))
+
+ # Listen to ROI updates
+
+ def __added(self, *args, **kwargs):
+ """Handle new ROI added"""
+ max_ = self.getMaxRois()
+ if max_ is not None:
+ # When reaching max number of ROIs, redo last one
+ while len(self.getRois()) > max_:
+ self.removeRoi(self.getRois()[-2])
+
+ self.__updateMessage()
+ if (self.isMaxRois() and
+ self.getValidationMode() in (self.ValidationMode.AUTO,
+ self.ValidationMode.AUTO_ENTER)):
+ self.quit()
+
+ def __aboutToBeRemoved(self, *args, **kwargs):
+ """Handle removal of a ROI"""
+ # RegionOfInterest not removed yet
+ self.__updateMessage(nbrois=len(self.getRois()) - 1)
+
+ def __started(self, roiKind):
+ """Handle interactive mode started"""
+ self.__updateMessage()
+
+ def __finished(self):
+ """Handle interactive mode finished"""
+ self.__updateMessage()
+
+ def __updateMessage(self, nbrois=None):
+ """Update message"""
+ if not self.isExec():
+ message = 'Done'
+
+ elif not self.isStarted():
+ message = 'Use %s ROI edition mode' % self.__execClass
+
+ else:
+ if nbrois is None:
+ nbrois = len(self.getRois())
+
+ name = self.__execClass._getShortName()
+
+ max_ = self.getMaxRois()
+ if max_ is None:
+ message = 'Select %ss (%d selected)' % (name, nbrois)
+
+ elif max_ <= 1:
+ message = 'Select a %s' % name
+ else:
+ message = 'Select %d/%d %ss' % (nbrois, max_, name)
+
+ if (self.getValidationMode() == self.ValidationMode.ENTER and
+ self.isMaxRois()):
+ message += ' - Press Enter to confirm'
+
+ if message != self.__message:
+ self.__message = message
+ # Use getMessage to add timeout message
+ self.sigMessageChanged.emit(self.getMessage())
+
+ # Handle blocking call
+
+ def __timeoutUpdate(self):
+ """Handle update of timeout"""
+ if (self.__timeoutEndTime is not None and
+ (self.__timeoutEndTime - time.time()) > 0):
+ self.sigMessageChanged.emit(self.getMessage())
+ else: # Stop interactive mode and message timer
+ timer = self.sender()
+ if timer is not None:
+ timer.stop()
+ self.__timeoutEndTime = None
+ self.quit()
+
+ def isExec(self):
+ """Returns True if :meth:`exec` is currently running.
+
+ :rtype: bool"""
+ return self.__execClass is not None
+
+ def exec(self, roiClass, timeout=0):
+ """Block until ROI selection is done or timeout is elapsed.
+
+ :meth:`quit` also ends this blocking call.
+
+ :param class roiClass: The class of the ROI which have to be created.
+ See `silx.gui.plot.items.roi`.
+ :param int timeout: Maximum duration in seconds to block.
+ Default: No timeout
+ :return: The list of ROIs
+ :rtype: List[RegionOfInterest]
+ """
+ plot = self.parent()
+ if plot is None:
+ return
+
+ self.__execClass = roiClass
+
+ plot.installEventFilter(self)
+
+ if timeout > 0:
+ self.__timeoutEndTime = time.time() + timeout
+ timer = qt.QTimer(self)
+ timer.timeout.connect(self.__timeoutUpdate)
+ timer.start(1000)
+
+ rois = super(InteractiveRegionOfInterestManager, self).exec(roiClass)
+
+ timer.stop()
+ self.__timeoutEndTime = None
+
+ else:
+ rois = super(InteractiveRegionOfInterestManager, self).exec(roiClass)
+
+ plot.removeEventFilter(self)
+
+ self.__execClass = None
+ self.__updateMessage()
+
+ return rois
+
+ def exec_(self, roiClass, timeout=0): # Qt5-like compatibility
+ return self.exec(roiClass, timeout)
+
+
+class _DeleteRegionOfInterestToolButton(qt.QToolButton):
+ """Tool button deleting a ROI object
+
+ :param parent: See QWidget
+ :param RegionOfInterest roi: The ROI to delete
+ """
+
+ def __init__(self, parent, roi):
+ super(_DeleteRegionOfInterestToolButton, self).__init__(parent)
+ self.setIcon(icons.getQIcon('remove'))
+ self.setToolTip("Remove this ROI")
+ self.__roiRef = roi if roi is None else weakref.ref(roi)
+ self.clicked.connect(self.__clicked)
+
+ def __clicked(self, checked):
+ """Handle button clicked"""
+ roi = None if self.__roiRef is None else self.__roiRef()
+ if roi is not None:
+ manager = roi.parent()
+ if manager is not None:
+ manager.removeRoi(roi)
+ self.__roiRef = None
+
+
+class RegionOfInterestTableWidget(qt.QTableWidget):
+ """Widget displaying the ROIs of a :class:`RegionOfInterestManager`"""
+
+ def __init__(self, parent=None):
+ super(RegionOfInterestTableWidget, self).__init__(parent)
+ self._roiManagerRef = None
+
+ headers = ['Label', 'Edit', 'Kind', 'Coordinates', '']
+ self.setColumnCount(len(headers))
+ self.setHorizontalHeaderLabels(headers)
+
+ horizontalHeader = self.horizontalHeader()
+ horizontalHeader.setDefaultAlignment(qt.Qt.AlignLeft)
+
+ horizontalHeader.setSectionResizeMode(0, qt.QHeaderView.Interactive)
+ horizontalHeader.setSectionResizeMode(1, qt.QHeaderView.ResizeToContents)
+ horizontalHeader.setSectionResizeMode(2, qt.QHeaderView.ResizeToContents)
+ horizontalHeader.setSectionResizeMode(3, qt.QHeaderView.Stretch)
+ horizontalHeader.setSectionResizeMode(4, qt.QHeaderView.ResizeToContents)
+
+ verticalHeader = self.verticalHeader()
+ verticalHeader.setVisible(False)
+
+ self.setSelectionMode(qt.QAbstractItemView.NoSelection)
+ self.setFocusPolicy(qt.Qt.NoFocus)
+
+ self.itemChanged.connect(self.__itemChanged)
+
+ def __itemChanged(self, item):
+ """Handle item updates"""
+ column = item.column()
+ index = item.data(qt.Qt.UserRole)
+
+ if index is not None:
+ manager = self.getRegionOfInterestManager()
+ roi = manager.getRois()[index]
+ else:
+ return
+
+ if column == 0:
+ # First collect information from item, then update ROI
+ # Otherwise, this causes issues issues
+ checked = item.checkState() == qt.Qt.Checked
+ text= item.text()
+ roi.setVisible(checked)
+ roi.setName(text)
+ elif column == 1:
+ roi.setEditable(item.checkState() == qt.Qt.Checked)
+ elif column in (2, 3, 4):
+ pass # TODO
+ else:
+ logger.error('Unhandled column %d', column)
+
+ def setRegionOfInterestManager(self, manager):
+ """Set the :class:`RegionOfInterestManager` object to sync with
+
+ :param RegionOfInterestManager manager:
+ """
+ assert manager is None or isinstance(manager, RegionOfInterestManager)
+
+ previousManager = self.getRegionOfInterestManager()
+
+ if previousManager is not None:
+ previousManager.sigRoiChanged.disconnect(self._sync)
+ self.setRowCount(0)
+
+ self._roiManagerRef = weakref.ref(manager)
+
+ self._sync()
+
+ if manager is not None:
+ manager.sigRoiChanged.connect(self._sync)
+
+ def _getReadableRoiDescription(self, roi):
+ """Returns modelisation of a ROI as a readable sequence of values.
+
+ :rtype: str
+ """
+ text = str(roi)
+ try:
+ # Extract the params from syntax "CLASSNAME(PARAMS)"
+ elements = text.split("(", 1)
+ if len(elements) != 2:
+ return text
+ result = elements[1]
+ result = result.strip()
+ if not result.endswith(")"):
+ return text
+ result = result[0:-1]
+ # Capitalize each words
+ result = result.title()
+ return result
+ except Exception:
+ logger.debug("Backtrace", exc_info=True)
+ return text
+
+ def _sync(self):
+ """Update widget content according to ROI manger"""
+ manager = self.getRegionOfInterestManager()
+
+ if manager is None:
+ self.setRowCount(0)
+ return
+
+ rois = manager.getRois()
+
+ self.setRowCount(len(rois))
+ for index, roi in enumerate(rois):
+ baseFlags = qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled
+
+ # Label and visible
+ label = roi.getName()
+ item = qt.QTableWidgetItem(label)
+ item.setFlags(baseFlags | qt.Qt.ItemIsEditable | qt.Qt.ItemIsUserCheckable)
+ item.setData(qt.Qt.UserRole, index)
+ item.setCheckState(
+ qt.Qt.Checked if roi.isVisible() else qt.Qt.Unchecked)
+ self.setItem(index, 0, item)
+
+ # Editable
+ item = qt.QTableWidgetItem()
+ item.setFlags(baseFlags | qt.Qt.ItemIsUserCheckable)
+ item.setData(qt.Qt.UserRole, index)
+ item.setCheckState(
+ qt.Qt.Checked if roi.isEditable() else qt.Qt.Unchecked)
+ self.setItem(index, 1, item)
+ item.setTextAlignment(qt.Qt.AlignCenter)
+ item.setText(None)
+
+ # Kind
+ label = roi._getShortName()
+ if label is None:
+ # Default value if kind is not overrided
+ label = roi.__class__.__name__
+ item = qt.QTableWidgetItem(label.capitalize())
+ item.setFlags(baseFlags)
+ self.setItem(index, 2, item)
+
+ item = qt.QTableWidgetItem()
+ item.setFlags(baseFlags)
+
+ # Coordinates
+ text = self._getReadableRoiDescription(roi)
+ item.setText(text)
+ self.setItem(index, 3, item)
+
+ # Delete
+ delBtn = _DeleteRegionOfInterestToolButton(None, roi)
+ widget = qt.QWidget(self)
+ layout = qt.QHBoxLayout()
+ layout.setContentsMargins(2, 2, 2, 2)
+ layout.setSpacing(0)
+ widget.setLayout(layout)
+ layout.addStretch(1)
+ layout.addWidget(delBtn)
+ layout.addStretch(1)
+ self.setCellWidget(index, 4, widget)
+
+ def getRegionOfInterestManager(self):
+ """Returns the :class:`RegionOfInterestManager` this widget supervise.
+
+ It returns None if not sync with an :class:`RegionOfInterestManager`.
+
+ :rtype: RegionOfInterestManager
+ """
+ return None if self._roiManagerRef is None else self._roiManagerRef()
diff --git a/src/silx/gui/plot/tools/test/__init__.py b/src/silx/gui/plot/tools/test/__init__.py
new file mode 100644
index 0000000..aa4a601
--- /dev/null
+++ b/src/silx/gui/plot/tools/test/__init__.py
@@ -0,0 +1,24 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
diff --git a/src/silx/gui/plot/tools/test/testCurveLegendsWidget.py b/src/silx/gui/plot/tools/test/testCurveLegendsWidget.py
new file mode 100644
index 0000000..37af10e
--- /dev/null
+++ b/src/silx/gui/plot/tools/test/testCurveLegendsWidget.py
@@ -0,0 +1,113 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "02/08/2018"
+
+
+import unittest
+
+from silx.gui import qt
+from silx.utils.testutils import ParametricTestCase
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui.plot import PlotWindow
+from silx.gui.plot.tools import CurveLegendsWidget
+
+
+class TestCurveLegendsWidget(TestCaseQt, ParametricTestCase):
+ """Tests for CurveLegendsWidget class"""
+
+ def setUp(self):
+ super(TestCurveLegendsWidget, self).setUp()
+ self.plot = PlotWindow()
+
+ self.legends = CurveLegendsWidget.CurveLegendsWidget()
+ self.legends.setPlotWidget(self.plot)
+
+ dock = qt.QDockWidget()
+ dock.setWindowTitle('Curve Legends')
+ dock.setWidget(self.legends)
+ self.plot.addTabbedDockWidget(dock)
+
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ def tearDown(self):
+ del self.legends
+ self.qapp.processEvents()
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ del self.plot
+ super(TestCurveLegendsWidget, self).tearDown()
+
+ def _assertNbLegends(self, count):
+ """Check the number of legends in the CurveLegendsWidget"""
+ children = self.legends.findChildren(CurveLegendsWidget._LegendWidget)
+ self.assertEqual(len(children), count)
+
+ def testAddRemoveCurves(self):
+ """Test CurveLegendsWidget while adding/removing curves"""
+ self.plot.addCurve((0, 1), (1, 2), legend='a')
+ self._assertNbLegends(1)
+ self.plot.addCurve((0, 1), (2, 3), legend='b')
+ self._assertNbLegends(2)
+
+ # Detached/attach
+ self.legends.setPlotWidget(None)
+ self._assertNbLegends(0)
+
+ self.legends.setPlotWidget(self.plot)
+ self._assertNbLegends(2)
+
+ self.plot.clear()
+ self._assertNbLegends(0)
+
+ def testUpdateCurves(self):
+ """Test CurveLegendsWidget while updating curves """
+ self.plot.addCurve((0, 1), (1, 2), legend='a')
+ self._assertNbLegends(1)
+ self.plot.addCurve((0, 1), (2, 3), legend='b')
+ self._assertNbLegends(2)
+
+ # Activate curve
+ self.plot.setActiveCurve('a')
+ self.qapp.processEvents()
+ self.plot.setActiveCurve('b')
+ self.qapp.processEvents()
+
+ # Change curve style
+ curve = self.plot.getCurve('a')
+ curve.setLineWidth(2)
+ for linestyle in (':', '', '--', '-'):
+ with self.subTest(linestyle=linestyle):
+ curve.setLineStyle(linestyle)
+ self.qapp.processEvents()
+ self.qWait(1000)
+
+ for symbol in ('o', 'd', '', 's'):
+ with self.subTest(symbol=symbol):
+ curve.setSymbol(symbol)
+ self.qapp.processEvents()
+ self.qWait(1000)
diff --git a/src/silx/gui/plot/tools/test/testProfile.py b/src/silx/gui/plot/tools/test/testProfile.py
new file mode 100644
index 0000000..829f49e
--- /dev/null
+++ b/src/silx/gui/plot/tools/test/testProfile.py
@@ -0,0 +1,654 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "28/06/2018"
+
+
+import unittest
+import contextlib
+import numpy
+import logging
+
+from silx.gui import qt
+from silx.utils import deprecation
+from silx.utils import testutils
+
+from silx.gui.utils.testutils import TestCaseQt
+from silx.utils.testutils import ParametricTestCase
+from silx.gui.plot import PlotWindow, Plot1D, Plot2D, Profile
+from silx.gui.plot.StackView import StackView
+from silx.gui.plot.tools.profile import rois
+from silx.gui.plot.tools.profile import editors
+from silx.gui.plot.items import roi as roi_items
+from silx.gui.plot.tools.profile import manager
+from silx.gui import plot as silx_plot
+
+_logger = logging.getLogger(__name__)
+
+
+class TestRois(TestCaseQt):
+
+ def test_init(self):
+ """Check that the constructor is not called twice"""
+ roi = rois.ProfileImageVerticalLineROI()
+ if qt.BINDING == "PyQt5":
+ # the profile ROI + the shape
+ self.assertEqual(roi.receivers(roi.sigRegionChanged), 2)
+
+
+class TestInteractions(TestCaseQt):
+
+ @contextlib.contextmanager
+ def defaultPlot(self):
+ try:
+ widget = silx_plot.PlotWidget()
+ widget.show()
+ self.qWaitForWindowExposed(widget)
+ yield widget
+ finally:
+ widget.close()
+ widget = None
+ self.qWait()
+
+ @contextlib.contextmanager
+ def imagePlot(self):
+ try:
+ widget = silx_plot.Plot2D()
+ image = numpy.arange(10 * 10).reshape(10, -1)
+ widget.addImage(image)
+ widget.show()
+ self.qWaitForWindowExposed(widget)
+ yield widget
+ finally:
+ widget.close()
+ widget = None
+ self.qWait()
+
+ @contextlib.contextmanager
+ def scatterPlot(self):
+ try:
+ widget = silx_plot.ScatterView()
+
+ nbX, nbY = 7, 5
+ yy = numpy.atleast_2d(numpy.ones(nbY)).T
+ xx = numpy.atleast_2d(numpy.ones(nbX))
+ positionX = numpy.linspace(10, 50, nbX) * yy
+ positionX = positionX.reshape(nbX * nbY)
+ positionY = numpy.atleast_2d(numpy.linspace(20, 60, nbY)).T * xx
+ positionY = positionY.reshape(nbX * nbY)
+ values = numpy.arange(nbX * nbY)
+
+ widget.setData(positionX, positionY, values)
+ widget.resetZoom()
+ widget.show()
+ self.qWaitForWindowExposed(widget)
+ yield widget.getPlotWidget()
+ finally:
+ widget.close()
+ widget = None
+ self.qWait()
+
+ @contextlib.contextmanager
+ def stackPlot(self):
+ try:
+ widget = silx_plot.StackView()
+ image = numpy.arange(10 * 10).reshape(10, -1)
+ cube = numpy.array([image, image, image])
+ widget.setStack(cube)
+ widget.resetZoom()
+ widget.show()
+ self.qWaitForWindowExposed(widget)
+ yield widget.getPlotWidget()
+ finally:
+ widget.close()
+ widget = None
+ self.qWait()
+
+ def waitPendingOperations(self, proflie):
+ for _ in range(10):
+ if not proflie.hasPendingOperations():
+ return
+ self.qWait(100)
+ _logger.error("The profile manager still have pending operations")
+
+ def genericRoiTest(self, plot, roiClass):
+ profileManager = manager.ProfileManager(plot, plot)
+ profileManager.setItemType(image=True, scatter=True)
+
+ try:
+ action = profileManager.createProfileAction(roiClass, plot)
+ action.triggered[bool].emit(True)
+ widget = plot.getWidgetHandle()
+
+ # Do the mouse interaction
+ pos1 = widget.width() * 0.4, widget.height() * 0.4
+ self.mouseMove(widget, pos=pos1)
+ self.mouseClick(widget, qt.Qt.LeftButton, pos=pos1)
+
+ if issubclass(roiClass, roi_items.LineROI):
+ pos2 = widget.width() * 0.6, widget.height() * 0.6
+ self.mouseMove(widget, pos=pos2)
+ self.mouseClick(widget, qt.Qt.LeftButton, pos=pos2)
+
+ self.waitPendingOperations(profileManager)
+
+ # Test that something was computed
+ if issubclass(roiClass, rois._ProfileCrossROI):
+ self.assertEqual(profileManager._computedProfiles, 2)
+ elif issubclass(roiClass, roi_items.LineROI):
+ self.assertGreaterEqual(profileManager._computedProfiles, 1)
+ else:
+ self.assertEqual(profileManager._computedProfiles, 1)
+
+ # Test the created ROIs
+ profileRois = profileManager.getRoiManager().getRois()
+ if issubclass(roiClass, rois._ProfileCrossROI):
+ self.assertEqual(len(profileRois), 3)
+ else:
+ self.assertEqual(len(profileRois), 1)
+ # The first one should be the expected one
+ roi = profileRois[0]
+
+ # Test that something was displayed
+ if issubclass(roiClass, rois._ProfileCrossROI):
+ profiles = roi._getLines()
+ window = profiles[0].getProfileWindow()
+ self.assertIsNotNone(window)
+ window = profiles[1].getProfileWindow()
+ self.assertIsNotNone(window)
+ else:
+ window = roi.getProfileWindow()
+ self.assertIsNotNone(window)
+ finally:
+ profileManager.clearProfile()
+
+ def testImageActions(self):
+ roiClasses = [
+ rois.ProfileImageHorizontalLineROI,
+ rois.ProfileImageVerticalLineROI,
+ rois.ProfileImageLineROI,
+ rois.ProfileImageCrossROI,
+ ]
+ with self.imagePlot() as plot:
+ for roiClass in roiClasses:
+ with self.subTest(roiClass=roiClass):
+ self.genericRoiTest(plot, roiClass)
+
+ def testScatterActions(self):
+ roiClasses = [
+ rois.ProfileScatterHorizontalLineROI,
+ rois.ProfileScatterVerticalLineROI,
+ rois.ProfileScatterLineROI,
+ rois.ProfileScatterCrossROI,
+ rois.ProfileScatterHorizontalSliceROI,
+ rois.ProfileScatterVerticalSliceROI,
+ rois.ProfileScatterCrossSliceROI,
+ ]
+ with self.scatterPlot() as plot:
+ for roiClass in roiClasses:
+ with self.subTest(roiClass=roiClass):
+ self.genericRoiTest(plot, roiClass)
+
+ def testStackActions(self):
+ roiClasses = [
+ rois.ProfileImageStackHorizontalLineROI,
+ rois.ProfileImageStackVerticalLineROI,
+ rois.ProfileImageStackLineROI,
+ rois.ProfileImageStackCrossROI,
+ ]
+ with self.stackPlot() as plot:
+ for roiClass in roiClasses:
+ with self.subTest(roiClass=roiClass):
+ self.genericRoiTest(plot, roiClass)
+
+ def genericEditorTest(self, plot, roi, editor):
+ if isinstance(editor, editors._NoProfileRoiEditor):
+ pass
+ elif isinstance(editor, editors._DefaultImageStackProfileRoiEditor):
+ # GUI to ROI
+ editor._lineWidth.setValue(2)
+ self.assertEqual(roi.getProfileLineWidth(), 2)
+ editor._methodsButton.setMethod("sum")
+ self.assertEqual(roi.getProfileMethod(), "sum")
+ editor._profileDim.setDimension(1)
+ self.assertEqual(roi.getProfileType(), "1D")
+ # ROI to GUI
+ roi.setProfileLineWidth(3)
+ self.assertEqual(editor._lineWidth.value(), 3)
+ roi.setProfileMethod("mean")
+ self.assertEqual(editor._methodsButton.getMethod(), "mean")
+ roi.setProfileType("2D")
+ self.assertEqual(editor._profileDim.getDimension(), 2)
+ elif isinstance(editor, editors._DefaultImageProfileRoiEditor):
+ # GUI to ROI
+ editor._lineWidth.setValue(2)
+ self.assertEqual(roi.getProfileLineWidth(), 2)
+ editor._methodsButton.setMethod("sum")
+ self.assertEqual(roi.getProfileMethod(), "sum")
+ # ROI to GUI
+ roi.setProfileLineWidth(3)
+ self.assertEqual(editor._lineWidth.value(), 3)
+ roi.setProfileMethod("mean")
+ self.assertEqual(editor._methodsButton.getMethod(), "mean")
+ elif isinstance(editor, editors._DefaultScatterProfileRoiEditor):
+ # GUI to ROI
+ editor._nPoints.setValue(100)
+ self.assertEqual(roi.getNPoints(), 100)
+ # ROI to GUI
+ roi.setNPoints(200)
+ self.assertEqual(editor._nPoints.value(), 200)
+ else:
+ assert False
+
+ def testEditors(self):
+ roiClasses = [
+ (rois.ProfileImageHorizontalLineROI, editors._DefaultImageProfileRoiEditor),
+ (rois.ProfileImageVerticalLineROI, editors._DefaultImageProfileRoiEditor),
+ (rois.ProfileImageLineROI, editors._DefaultImageProfileRoiEditor),
+ (rois.ProfileImageCrossROI, editors._DefaultImageProfileRoiEditor),
+ (rois.ProfileScatterHorizontalLineROI, editors._DefaultScatterProfileRoiEditor),
+ (rois.ProfileScatterVerticalLineROI, editors._DefaultScatterProfileRoiEditor),
+ (rois.ProfileScatterLineROI, editors._DefaultScatterProfileRoiEditor),
+ (rois.ProfileScatterCrossROI, editors._DefaultScatterProfileRoiEditor),
+ (rois.ProfileScatterHorizontalSliceROI, editors._NoProfileRoiEditor),
+ (rois.ProfileScatterVerticalSliceROI, editors._NoProfileRoiEditor),
+ (rois.ProfileScatterCrossSliceROI, editors._NoProfileRoiEditor),
+ (rois.ProfileImageStackHorizontalLineROI, editors._DefaultImageStackProfileRoiEditor),
+ (rois.ProfileImageStackVerticalLineROI, editors._DefaultImageStackProfileRoiEditor),
+ (rois.ProfileImageStackLineROI, editors._DefaultImageStackProfileRoiEditor),
+ (rois.ProfileImageStackCrossROI, editors._DefaultImageStackProfileRoiEditor),
+ ]
+ with self.defaultPlot() as plot:
+ profileManager = manager.ProfileManager(plot, plot)
+ editorAction = profileManager.createEditorAction(parent=plot)
+ for roiClass, editorClass in roiClasses:
+ with self.subTest(roiClass=roiClass):
+ roi = roiClass()
+ roi._setProfileManager(profileManager)
+ try:
+ # Force widget creation
+ menu = qt.QMenu(plot)
+ menu.addAction(editorAction)
+ widgets = editorAction.createdWidgets()
+ self.assertGreater(len(widgets), 0)
+
+ editorAction.setProfileRoi(roi)
+ editorWidget = editorAction._getEditor(widgets[0])
+ self.assertIsInstance(editorWidget, editorClass)
+ self.genericEditorTest(plot, roi, editorWidget)
+ finally:
+ editorAction.setProfileRoi(None)
+ menu.deleteLater()
+ menu = None
+ self.qapp.processEvents()
+
+
+class TestProfileToolBar(TestCaseQt, ParametricTestCase):
+ """Tests for ProfileToolBar widget."""
+
+ def setUp(self):
+ super(TestProfileToolBar, self).setUp()
+ self.plot = PlotWindow()
+ self.toolBar = Profile.ProfileToolBar(plot=self.plot)
+ self.plot.addToolBar(self.toolBar)
+
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ self.mouseMove(self.plot) # Move to center
+ self.qapp.processEvents()
+ deprecation.FORCE = True
+
+ def tearDown(self):
+ deprecation.FORCE = False
+ self.qapp.processEvents()
+ profileManager = self.toolBar.getProfileManager()
+ profileManager.clearProfile()
+ profileManager = None
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ del self.plot
+ del self.toolBar
+
+ super(TestProfileToolBar, self).tearDown()
+
+ def testAlignedProfile(self):
+ """Test horizontal and vertical profile, without and with image"""
+ # Use Plot backend widget to submit mouse events
+ widget = self.plot.getWidgetHandle()
+ for method in ('sum', 'mean'):
+ with self.subTest(method=method):
+ # 2 positions to use for mouse events
+ pos1 = widget.width() * 0.4, widget.height() * 0.4
+ pos2 = widget.width() * 0.6, widget.height() * 0.6
+
+ for action in (self.toolBar.hLineAction, self.toolBar.vLineAction):
+ with self.subTest(mode=action.text()):
+ # Trigger tool button for mode
+ action.trigger()
+ # Without image
+ self.mouseMove(widget, pos=pos1)
+ self.mouseClick(widget, qt.Qt.LeftButton, pos=pos1)
+
+ # with image
+ self.plot.addImage(
+ numpy.arange(100 * 100).reshape(100, -1))
+ self.mousePress(widget, qt.Qt.LeftButton, pos=pos1)
+ self.mouseMove(widget, pos=pos2)
+ self.mouseRelease(widget, qt.Qt.LeftButton, pos=pos2)
+
+ self.mouseMove(widget)
+ self.mouseClick(widget, qt.Qt.LeftButton)
+
+ manager = self.toolBar.getProfileManager()
+ for _ in range(20):
+ self.qWait(200)
+ if not manager.hasPendingOperations():
+ break
+
+ @testutils.validate_logging(deprecation.depreclog.name, warning=4)
+ def testDiagonalProfile(self):
+ """Test diagonal profile, without and with image"""
+ # Use Plot backend widget to submit mouse events
+ widget = self.plot.getWidgetHandle()
+
+ self.plot.addImage(
+ numpy.arange(100 * 100).reshape(100, -1))
+
+ for method in ('sum', 'mean'):
+ with self.subTest(method=method):
+ # 2 positions to use for mouse events
+ pos1 = widget.width() * 0.4, widget.height() * 0.4
+ pos2 = widget.width() * 0.6, widget.height() * 0.6
+
+ # Trigger tool button for diagonal profile mode
+ self.toolBar.lineAction.trigger()
+
+ # draw profile line
+ widget.setFocus(qt.Qt.OtherFocusReason)
+ self.mouseMove(widget, pos=pos1)
+ self.qWait(100)
+ self.mousePress(widget, qt.Qt.LeftButton, pos=pos1)
+ self.qWait(100)
+ self.mouseMove(widget, pos=pos2)
+ self.qWait(100)
+ self.mouseRelease(widget, qt.Qt.LeftButton, pos=pos2)
+ self.qWait(100)
+
+ manager = self.toolBar.getProfileManager()
+
+ for _ in range(20):
+ self.qWait(200)
+ if not manager.hasPendingOperations():
+ break
+
+ roi = manager.getCurrentRoi()
+ self.assertIsNotNone(roi)
+ roi.setProfileLineWidth(3)
+ roi.setProfileMethod(method)
+
+ for _ in range(20):
+ self.qWait(200)
+ if not manager.hasPendingOperations():
+ break
+
+ curveItem = self.toolBar.getProfilePlot().getAllCurves()[0]
+ if method == 'sum':
+ self.assertTrue(curveItem.getData()[1].max() > 10000)
+ elif method == 'mean':
+ self.assertTrue(curveItem.getData()[1].max() < 10000)
+
+ # Remove the ROI so the profile window is also removed
+ roiManager = manager.getRoiManager()
+ roiManager.removeRoi(roi)
+ self.qWait(100)
+
+
+class TestDeprecatedProfileToolBar(TestCaseQt):
+ """Tests old features of the ProfileToolBar widget."""
+
+ def setUp(self):
+ self.plot = None
+ super(TestDeprecatedProfileToolBar, self).setUp()
+
+ def tearDown(self):
+ if self.plot is not None:
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ self.plot = None
+ self.qWait()
+
+ super(TestDeprecatedProfileToolBar, self).tearDown()
+
+ @testutils.validate_logging(deprecation.depreclog.name, warning=2)
+ def testCustomProfileWindow(self):
+ from silx.gui.plot import ProfileMainWindow
+
+ self.plot = PlotWindow()
+ profileWindow = ProfileMainWindow.ProfileMainWindow(self.plot)
+ toolBar = Profile.ProfileToolBar(parent=self.plot,
+ plot=self.plot,
+ profileWindow=profileWindow)
+
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+ profileWindow.show()
+ self.qWaitForWindowExposed(profileWindow)
+ self.qapp.processEvents()
+
+ self.plot.addImage(numpy.arange(10 * 10).reshape(10, -1))
+ profile = rois.ProfileImageHorizontalLineROI()
+ profile.setPosition(5)
+ toolBar.getProfileManager().getRoiManager().addRoi(profile)
+ toolBar.getProfileManager().getRoiManager().setCurrentRoi(profile)
+
+ for _ in range(20):
+ self.qWait(200)
+ if not toolBar.getProfileManager().hasPendingOperations():
+ break
+
+ # There is a displayed profile
+ self.assertIsNotNone(profileWindow.getProfile())
+ self.assertIs(toolBar.getProfileMainWindow(), profileWindow)
+
+ # There is nothing anymore but the window is still there
+ toolBar.getProfileManager().clearProfile()
+ self.qapp.processEvents()
+ self.assertIsNone(profileWindow.getProfile())
+
+
+class TestProfile3DToolBar(TestCaseQt):
+ """Tests for Profile3DToolBar widget.
+ """
+ def setUp(self):
+ super(TestProfile3DToolBar, self).setUp()
+ self.plot = StackView()
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ self.plot.setStack(numpy.array([
+ [[0, 1, 2], [3, 4, 5]],
+ [[6, 7, 8], [9, 10, 11]],
+ [[12, 13, 14], [15, 16, 17]]
+ ]))
+ deprecation.FORCE = True
+
+ def tearDown(self):
+ deprecation.FORCE = False
+ profileManager = self.plot.getProfileToolbar().getProfileManager()
+ profileManager.clearProfile()
+ profileManager = None
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ self.plot = None
+
+ super(TestProfile3DToolBar, self).tearDown()
+
+ @testutils.validate_logging(deprecation.depreclog.name, warning=2)
+ def testMethodProfile2D(self):
+ """Test that the profile can have a different method if we want to
+ compute then in 1D or in 2D"""
+
+ toolBar = self.plot.getProfileToolbar()
+
+ toolBar.vLineAction.trigger()
+ plot2D = self.plot.getPlotWidget().getWidgetHandle()
+ pos1 = plot2D.width() * 0.5, plot2D.height() * 0.5
+ self.mouseClick(plot2D, qt.Qt.LeftButton, pos=pos1)
+
+ manager = toolBar.getProfileManager()
+ roi = manager.getCurrentRoi()
+ roi.setProfileMethod("mean")
+ roi.setProfileType("2D")
+ roi.setProfileLineWidth(3)
+
+ for _ in range(20):
+ self.qWait(200)
+ if not manager.hasPendingOperations():
+ break
+
+ # check 2D 'mean' profile
+ profilePlot = toolBar.getProfilePlot()
+ data = profilePlot.getAllImages()[0].getData()
+ expected = numpy.array([[1, 4], [7, 10], [13, 16]])
+ numpy.testing.assert_almost_equal(data, expected)
+
+ @testutils.validate_logging(deprecation.depreclog.name, warning=2)
+ def testMethodSumLine(self):
+ """Simple interaction test to make sure the sum is correctly computed
+ """
+ toolBar = self.plot.getProfileToolbar()
+
+ toolBar.lineAction.trigger()
+ plot2D = self.plot.getPlotWidget().getWidgetHandle()
+ pos1 = plot2D.width() * 0.5, plot2D.height() * 0.2
+ pos2 = plot2D.width() * 0.5, plot2D.height() * 0.8
+
+ self.mouseMove(plot2D, pos=pos1)
+ self.mousePress(plot2D, qt.Qt.LeftButton, pos=pos1)
+ self.mouseMove(plot2D, pos=pos2)
+ self.mouseRelease(plot2D, qt.Qt.LeftButton, pos=pos2)
+
+ manager = toolBar.getProfileManager()
+ roi = manager.getCurrentRoi()
+ roi.setProfileMethod("sum")
+ roi.setProfileType("2D")
+ roi.setProfileLineWidth(3)
+
+ for _ in range(20):
+ self.qWait(200)
+ if not manager.hasPendingOperations():
+ break
+
+ # check 2D 'sum' profile
+ profilePlot = toolBar.getProfilePlot()
+ data = profilePlot.getAllImages()[0].getData()
+ expected = numpy.array([[3, 12], [21, 30], [39, 48]])
+ numpy.testing.assert_almost_equal(data, expected)
+
+
+class TestGetProfilePlot(TestCaseQt):
+
+ def setUp(self):
+ self.plot = None
+ super(TestGetProfilePlot, self).setUp()
+
+ def tearDown(self):
+ if self.plot is not None:
+ manager = self.plot.getProfileToolbar().getProfileManager()
+ manager.clearProfile()
+ manager = None
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ self.plot = None
+
+ super(TestGetProfilePlot, self).tearDown()
+
+ def testProfile1D(self):
+ self.plot = Plot2D()
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+ self.plot.addImage([[0, 1], [2, 3]])
+
+ toolBar = self.plot.getProfileToolbar()
+
+ manager = toolBar.getProfileManager()
+ roiManager = manager.getRoiManager()
+
+ roi = rois.ProfileImageHorizontalLineROI()
+ roi.setPosition(0.5)
+ roiManager.addRoi(roi)
+ roiManager.setCurrentRoi(roi)
+
+ for _ in range(20):
+ self.qWait(200)
+ if not manager.hasPendingOperations():
+ break
+
+ profileWindow = roi.getProfileWindow()
+ self.assertIsInstance(roi.getProfileWindow(), qt.QMainWindow)
+ self.assertIsInstance(profileWindow.getCurrentPlotWidget(), Plot1D)
+
+ def testProfile2D(self):
+ """Test that the profile plot associated to a stack view is either a
+ Plot1D or a plot 2D instance."""
+ self.plot = StackView()
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ self.plot.setStack(numpy.array([[[0, 1], [2, 3]],
+ [[4, 5], [6, 7]]]))
+
+ toolBar = self.plot.getProfileToolbar()
+
+ manager = toolBar.getProfileManager()
+ roiManager = manager.getRoiManager()
+
+ roi = rois.ProfileImageStackHorizontalLineROI()
+ roi.setPosition(0.5)
+ roi.setProfileType("2D")
+ roiManager.addRoi(roi)
+ roiManager.setCurrentRoi(roi)
+
+ for _ in range(20):
+ self.qWait(200)
+ if not manager.hasPendingOperations():
+ break
+
+ profileWindow = roi.getProfileWindow()
+ self.assertIsInstance(roi.getProfileWindow(), qt.QMainWindow)
+ self.assertIsInstance(profileWindow.getCurrentPlotWidget(), Plot2D)
+
+ roi.setProfileType("1D")
+
+ for _ in range(20):
+ self.qWait(200)
+ if not manager.hasPendingOperations():
+ break
+
+ profileWindow = roi.getProfileWindow()
+ self.assertIsInstance(roi.getProfileWindow(), qt.QMainWindow)
+ self.assertIsInstance(profileWindow.getCurrentPlotWidget(), Plot1D)
diff --git a/src/silx/gui/plot/tools/test/testROI.py b/src/silx/gui/plot/tools/test/testROI.py
new file mode 100644
index 0000000..21697d1
--- /dev/null
+++ b/src/silx/gui/plot/tools/test/testROI.py
@@ -0,0 +1,682 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "28/06/2018"
+
+
+import unittest
+import numpy.testing
+
+from silx.gui import qt
+from silx.utils.testutils import ParametricTestCase
+from silx.gui.utils.testutils import TestCaseQt, SignalListener
+from silx.gui.plot import PlotWindow
+import silx.gui.plot.items.roi as roi_items
+from silx.gui.plot.tools import roi
+
+
+class TestRoiItems(TestCaseQt):
+
+ def testLine_geometry(self):
+ item = roi_items.LineROI()
+ startPoint = numpy.array([1, 2])
+ endPoint = numpy.array([3, 4])
+ item.setEndPoints(startPoint, endPoint)
+ numpy.testing.assert_allclose(item.getEndPoints()[0], startPoint)
+ numpy.testing.assert_allclose(item.getEndPoints()[1], endPoint)
+
+ def testHLine_geometry(self):
+ item = roi_items.HorizontalLineROI()
+ item.setPosition(15)
+ self.assertEqual(item.getPosition(), 15)
+
+ def testVLine_geometry(self):
+ item = roi_items.VerticalLineROI()
+ item.setPosition(15)
+ self.assertEqual(item.getPosition(), 15)
+
+ def testPoint_geometry(self):
+ point = numpy.array([1, 2])
+ item = roi_items.PointROI()
+ item.setPosition(point)
+ numpy.testing.assert_allclose(item.getPosition(), point)
+
+ def testRectangle_originGeometry(self):
+ origin = numpy.array([0, 0])
+ size = numpy.array([10, 20])
+ center = numpy.array([5, 10])
+ item = roi_items.RectangleROI()
+ item.setGeometry(origin=origin, size=size)
+ numpy.testing.assert_allclose(item.getOrigin(), origin)
+ numpy.testing.assert_allclose(item.getSize(), size)
+ numpy.testing.assert_allclose(item.getCenter(), center)
+
+ def testRectangle_centerGeometry(self):
+ origin = numpy.array([0, 0])
+ size = numpy.array([10, 20])
+ center = numpy.array([5, 10])
+ item = roi_items.RectangleROI()
+ item.setGeometry(center=center, size=size)
+ numpy.testing.assert_allclose(item.getOrigin(), origin)
+ numpy.testing.assert_allclose(item.getSize(), size)
+ numpy.testing.assert_allclose(item.getCenter(), center)
+
+ def testRectangle_setCenterGeometry(self):
+ origin = numpy.array([0, 0])
+ size = numpy.array([10, 20])
+ item = roi_items.RectangleROI()
+ item.setGeometry(origin=origin, size=size)
+ newCenter = numpy.array([0, 0])
+ item.setCenter(newCenter)
+ expectedOrigin = numpy.array([-5, -10])
+ numpy.testing.assert_allclose(item.getOrigin(), expectedOrigin)
+ numpy.testing.assert_allclose(item.getCenter(), newCenter)
+ numpy.testing.assert_allclose(item.getSize(), size)
+
+ def testRectangle_setOriginGeometry(self):
+ origin = numpy.array([0, 0])
+ size = numpy.array([10, 20])
+ item = roi_items.RectangleROI()
+ item.setGeometry(origin=origin, size=size)
+ newOrigin = numpy.array([10, 10])
+ item.setOrigin(newOrigin)
+ expectedCenter = numpy.array([15, 20])
+ numpy.testing.assert_allclose(item.getOrigin(), newOrigin)
+ numpy.testing.assert_allclose(item.getCenter(), expectedCenter)
+ numpy.testing.assert_allclose(item.getSize(), size)
+
+ def testCircle_geometry(self):
+ center = numpy.array([0, 0])
+ radius = 10.
+ item = roi_items.CircleROI()
+ item.setGeometry(center=center, radius=radius)
+ numpy.testing.assert_allclose(item.getCenter(), center)
+ numpy.testing.assert_allclose(item.getRadius(), radius)
+
+ def testCircle_setCenter(self):
+ center = numpy.array([0, 0])
+ radius = 10.
+ item = roi_items.CircleROI()
+ item.setGeometry(center=center, radius=radius)
+ newCenter = numpy.array([-10, 0])
+ item.setCenter(newCenter)
+ numpy.testing.assert_allclose(item.getCenter(), newCenter)
+ numpy.testing.assert_allclose(item.getRadius(), radius)
+
+ def testCircle_setRadius(self):
+ center = numpy.array([0, 0])
+ radius = 10.
+ item = roi_items.CircleROI()
+ item.setGeometry(center=center, radius=radius)
+ newRadius = 5.1
+ item.setRadius(newRadius)
+ numpy.testing.assert_allclose(item.getCenter(), center)
+ numpy.testing.assert_allclose(item.getRadius(), newRadius)
+
+ def testCircle_contains(self):
+ center = numpy.array([2, -1])
+ radius = 1.
+ item = roi_items.CircleROI()
+ item.setGeometry(center=center, radius=radius)
+ self.assertTrue(item.contains([1, -1]))
+ self.assertFalse(item.contains([0, 0]))
+ self.assertTrue(item.contains([2, 0]))
+ self.assertFalse(item.contains([3.01, -1]))
+
+ def testEllipse_contains(self):
+ center = numpy.array([-2, 0])
+ item = roi_items.EllipseROI()
+ item.setCenter(center)
+ item.setOrientation(numpy.pi / 4.0)
+ item.setMajorRadius(2)
+ item.setMinorRadius(1)
+ print(item.getMinorRadius(), item.getMajorRadius())
+ self.assertFalse(item.contains([0, 0]))
+ self.assertTrue(item.contains([-1, 1]))
+ self.assertTrue(item.contains([-3, 0]))
+ self.assertTrue(item.contains([-2, 0]))
+ self.assertTrue(item.contains([-2, 1]))
+ self.assertFalse(item.contains([-4, 1]))
+
+ def testRectangle_isIn(self):
+ origin = numpy.array([0, 0])
+ size = numpy.array([10, 20])
+ item = roi_items.RectangleROI()
+ item.setGeometry(origin=origin, size=size)
+ self.assertTrue(item.contains(position=(0, 0)))
+ self.assertTrue(item.contains(position=(2, 14)))
+ self.assertFalse(item.contains(position=(14, 12)))
+
+ def testPolygon_emptyGeometry(self):
+ points = numpy.empty((0, 2))
+ item = roi_items.PolygonROI()
+ item.setPoints(points)
+ numpy.testing.assert_allclose(item.getPoints(), points)
+
+ def testPolygon_geometry(self):
+ points = numpy.array([[10, 10], [12, 10], [50, 1]])
+ item = roi_items.PolygonROI()
+ item.setPoints(points)
+ numpy.testing.assert_allclose(item.getPoints(), points)
+
+ def testPolygon_isIn(self):
+ points = numpy.array([[0, 0], [0, 10], [5, 10]])
+ item = roi_items.PolygonROI()
+ item.setPoints(points)
+ self.assertTrue(item.contains((0, 0)))
+ self.assertFalse(item.contains((6, 2)))
+ self.assertFalse(item.contains((-2, 5)))
+ self.assertFalse(item.contains((2, -1)))
+ self.assertFalse(item.contains((8, 1)))
+ self.assertTrue(item.contains((1, 8)))
+
+ def testArc_getToSetGeometry(self):
+ """Test that we can use getGeometry as input to setGeometry"""
+ item = roi_items.ArcROI()
+ item.setFirstShapePoints(numpy.array([[5, 10], [50, 100]]))
+ item.setGeometry(*item.getGeometry())
+
+ def testArc_degenerated_point(self):
+ item = roi_items.ArcROI()
+ center = numpy.array([10, 20])
+ innerRadius, outerRadius, startAngle, endAngle = 0, 0, 0, 0
+ item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
+
+ def testArc_degenerated_line(self):
+ item = roi_items.ArcROI()
+ center = numpy.array([10, 20])
+ innerRadius, outerRadius, startAngle, endAngle = 0, 100, numpy.pi, numpy.pi
+ item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
+
+ def testArc_special_circle(self):
+ item = roi_items.ArcROI()
+ center = numpy.array([10, 20])
+ innerRadius, outerRadius, startAngle, endAngle = 0, 100, numpy.pi, 3 * numpy.pi
+ item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
+ numpy.testing.assert_allclose(item.getCenter(), center)
+ self.assertAlmostEqual(item.getInnerRadius(), innerRadius)
+ self.assertAlmostEqual(item.getOuterRadius(), outerRadius)
+ self.assertAlmostEqual(item.getStartAngle(), item.getEndAngle() - numpy.pi * 2.0)
+ self.assertTrue(item.isClosed())
+
+ def testArc_special_donut(self):
+ item = roi_items.ArcROI()
+ center = numpy.array([10, 20])
+ innerRadius, outerRadius, startAngle, endAngle = 1, 100, numpy.pi, 3 * numpy.pi
+ item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
+ numpy.testing.assert_allclose(item.getCenter(), center)
+ self.assertAlmostEqual(item.getInnerRadius(), innerRadius)
+ self.assertAlmostEqual(item.getOuterRadius(), outerRadius)
+ self.assertAlmostEqual(item.getStartAngle(), item.getEndAngle() - numpy.pi * 2.0)
+ self.assertTrue(item.isClosed())
+
+ def testArc_clockwiseGeometry(self):
+ """Test that we can use getGeometry as input to setGeometry"""
+ item = roi_items.ArcROI()
+ center = numpy.array([10, 20])
+ innerRadius, outerRadius, startAngle, endAngle = 1, 100, numpy.pi * 0.5, numpy.pi
+ item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
+ numpy.testing.assert_allclose(item.getCenter(), center)
+ self.assertAlmostEqual(item.getInnerRadius(), innerRadius)
+ self.assertAlmostEqual(item.getOuterRadius(), outerRadius)
+ self.assertAlmostEqual(item.getStartAngle(), startAngle)
+ self.assertAlmostEqual(item.getEndAngle(), endAngle)
+ self.assertAlmostEqual(item.isClosed(), False)
+
+ def testArc_anticlockwiseGeometry(self):
+ """Test that we can use getGeometry as input to setGeometry"""
+ item = roi_items.ArcROI()
+ center = numpy.array([10, 20])
+ innerRadius, outerRadius, startAngle, endAngle = 1, 100, numpy.pi * 0.5, -numpy.pi * 0.5
+ item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
+ numpy.testing.assert_allclose(item.getCenter(), center)
+ self.assertAlmostEqual(item.getInnerRadius(), innerRadius)
+ self.assertAlmostEqual(item.getOuterRadius(), outerRadius)
+ self.assertAlmostEqual(item.getStartAngle(), startAngle)
+ self.assertAlmostEqual(item.getEndAngle(), endAngle)
+ self.assertAlmostEqual(item.isClosed(), False)
+
+ def testHRange_geometry(self):
+ item = roi_items.HorizontalRangeROI()
+ vmin = 1
+ vmax = 3
+ item.setRange(vmin, vmax)
+ self.assertAlmostEqual(item.getMin(), vmin)
+ self.assertAlmostEqual(item.getMax(), vmax)
+ self.assertAlmostEqual(item.getCenter(), 2)
+
+
+class TestRegionOfInterestManager(TestCaseQt, ParametricTestCase):
+ """Tests for RegionOfInterestManager class"""
+
+ def setUp(self):
+ super(TestRegionOfInterestManager, self).setUp()
+ self.plot = PlotWindow()
+
+ self.roiTableWidget = roi.RegionOfInterestTableWidget()
+ dock = qt.QDockWidget()
+ dock.setWidget(self.roiTableWidget)
+ self.plot.addDockWidget(qt.Qt.BottomDockWidgetArea, dock)
+
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ def tearDown(self):
+ del self.roiTableWidget
+ self.qapp.processEvents()
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ del self.plot
+ super(TestRegionOfInterestManager, self).tearDown()
+
+ def test(self):
+ """Test ROI of different shapes"""
+ tests = ( # shape, points=[list of (x, y), list of (x, y)]
+ (roi_items.PointROI, numpy.array(([(10., 15.)], [(20., 25.)]))),
+ (roi_items.RectangleROI,
+ numpy.array((((1., 10.), (11., 20.)),
+ ((2., 3.), (12., 13.))))),
+ (roi_items.PolygonROI,
+ numpy.array((((0., 1.), (0., 10.), (10., 0.)),
+ ((5., 6.), (5., 16.), (15., 6.))))),
+ (roi_items.LineROI,
+ numpy.array((((10., 20.), (10., 30.)),
+ ((30., 40.), (30., 50.))))),
+ (roi_items.HorizontalLineROI,
+ numpy.array((((10., 20.), (10., 30.)),
+ ((30., 40.), (30., 50.))))),
+ (roi_items.VerticalLineROI,
+ numpy.array((((10., 20.), (10., 30.)),
+ ((30., 40.), (30., 50.))))),
+ (roi_items.HorizontalLineROI,
+ numpy.array((((10., 20.), (10., 30.)),
+ ((30., 40.), (30., 50.))))),
+ )
+
+ for roiClass, points in tests:
+ with self.subTest(roiClass=roiClass):
+ manager = roi.RegionOfInterestManager(self.plot)
+ self.roiTableWidget.setRegionOfInterestManager(manager)
+ manager.start(roiClass)
+
+ self.assertEqual(manager.getRois(), ())
+
+ finishListener = SignalListener()
+ manager.sigInteractiveModeFinished.connect(finishListener)
+
+ changedListener = SignalListener()
+ manager.sigRoiChanged.connect(changedListener)
+
+ # Add a point
+ r = roiClass()
+ r.setFirstShapePoints(points[0])
+ manager.addRoi(r)
+ self.qapp.processEvents()
+ self.assertTrue(len(manager.getRois()), 1)
+ self.assertEqual(changedListener.callCount(), 1)
+
+ # Remove it
+ manager.removeRoi(manager.getRois()[0])
+ self.assertEqual(manager.getRois(), ())
+ self.assertEqual(changedListener.callCount(), 2)
+
+ # Add two point
+ r = roiClass()
+ r.setFirstShapePoints(points[0])
+ manager.addRoi(r)
+ self.qapp.processEvents()
+ r = roiClass()
+ r.setFirstShapePoints(points[1])
+ manager.addRoi(r)
+ self.qapp.processEvents()
+ self.assertTrue(len(manager.getRois()), 2)
+ self.assertEqual(changedListener.callCount(), 4)
+
+ # Reset it
+ result = manager.clear()
+ self.assertTrue(result)
+ self.assertEqual(manager.getRois(), ())
+ self.assertEqual(changedListener.callCount(), 5)
+
+ changedListener.clear()
+
+ # Add two point
+ r = roiClass()
+ r.setFirstShapePoints(points[0])
+ manager.addRoi(r)
+ self.qapp.processEvents()
+ r = roiClass()
+ r.setFirstShapePoints(points[1])
+ manager.addRoi(r)
+ self.qapp.processEvents()
+ self.assertTrue(len(manager.getRois()), 2)
+ self.assertEqual(changedListener.callCount(), 2)
+
+ # stop
+ result = manager.stop()
+ self.assertTrue(result)
+ self.assertTrue(len(manager.getRois()), 1)
+ self.qapp.processEvents()
+ self.assertEqual(finishListener.callCount(), 1)
+
+ manager.clear()
+
+ def testRoiDisplay(self):
+ rois = []
+
+ # Line
+ item = roi_items.LineROI()
+ startPoint = numpy.array([1, 2])
+ endPoint = numpy.array([3, 4])
+ item.setEndPoints(startPoint, endPoint)
+ rois.append(item)
+ # Horizontal line
+ item = roi_items.HorizontalLineROI()
+ item.setPosition(15)
+ rois.append(item)
+ # Vertical line
+ item = roi_items.VerticalLineROI()
+ item.setPosition(15)
+ rois.append(item)
+ # Point
+ item = roi_items.PointROI()
+ point = numpy.array([1, 2])
+ item.setPosition(point)
+ rois.append(item)
+ # Rectangle
+ item = roi_items.RectangleROI()
+ origin = numpy.array([0, 0])
+ size = numpy.array([10, 20])
+ item.setGeometry(origin=origin, size=size)
+ rois.append(item)
+ # Polygon
+ item = roi_items.PolygonROI()
+ points = numpy.array([[10, 10], [12, 10], [50, 1]])
+ item.setPoints(points)
+ rois.append(item)
+ # Degenerated polygon: No points
+ item = roi_items.PolygonROI()
+ points = numpy.empty((0, 2))
+ item.setPoints(points)
+ rois.append(item)
+ # Degenerated polygon: A single point
+ item = roi_items.PolygonROI()
+ points = numpy.array([[5, 10]])
+ item.setPoints(points)
+ rois.append(item)
+ # Degenerated arc: it's a point
+ item = roi_items.ArcROI()
+ center = numpy.array([10, 20])
+ innerRadius, outerRadius, startAngle, endAngle = 0, 0, 0, 0
+ item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
+ rois.append(item)
+ # Degenerated arc: it's a line
+ item = roi_items.ArcROI()
+ center = numpy.array([10, 20])
+ innerRadius, outerRadius, startAngle, endAngle = 0, 100, numpy.pi, numpy.pi
+ item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
+ rois.append(item)
+ # Special arc: it's a donut
+ item = roi_items.ArcROI()
+ center = numpy.array([10, 20])
+ innerRadius, outerRadius, startAngle, endAngle = 1, 100, numpy.pi, 3 * numpy.pi
+ item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
+ rois.append(item)
+ # Arc
+ item = roi_items.ArcROI()
+ center = numpy.array([10, 20])
+ innerRadius, outerRadius, startAngle, endAngle = 1, 100, numpy.pi * 0.5, numpy.pi
+ item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
+ rois.append(item)
+ # Horizontal Range
+ item = roi_items.HorizontalRangeROI()
+ item.setRange(-1, 3)
+ rois.append(item)
+
+ manager = roi.RegionOfInterestManager(self.plot)
+ self.roiTableWidget.setRegionOfInterestManager(manager)
+ for item in rois:
+ with self.subTest(roi=str(item)):
+ manager.addRoi(item)
+ self.qapp.processEvents()
+ item.setEditable(True)
+ self.qapp.processEvents()
+ item.setEditable(False)
+ self.qapp.processEvents()
+ manager.removeRoi(item)
+ self.qapp.processEvents()
+
+ def testSelectionProxy(self):
+ item1 = roi_items.PointROI()
+ item1.setSelectable(True)
+ item2 = roi_items.PointROI()
+ item2.setSelectable(True)
+ item1.setFocusProxy(item2)
+ manager = roi.RegionOfInterestManager(self.plot)
+ manager.setCurrentRoi(item1)
+ self.assertIs(manager.getCurrentRoi(), item2)
+
+ def testRemovedSelection(self):
+ item1 = roi_items.PointROI()
+ item1.setSelectable(True)
+ manager = roi.RegionOfInterestManager(self.plot)
+ manager.addRoi(item1)
+ manager.setCurrentRoi(item1)
+ manager.removeRoi(item1)
+ self.assertIs(manager.getCurrentRoi(), None)
+
+ def testMaxROI(self):
+ """Test Max ROI"""
+ origin1 = numpy.array([1., 10.])
+ size1 = numpy.array([10., 10.])
+ origin2 = numpy.array([2., 3.])
+ size2 = numpy.array([10., 10.])
+
+ manager = roi.InteractiveRegionOfInterestManager(self.plot)
+ self.roiTableWidget.setRegionOfInterestManager(manager)
+ self.assertEqual(manager.getRois(), ())
+
+ changedListener = SignalListener()
+ manager.sigRoiChanged.connect(changedListener)
+
+ # Add two point
+ item = roi_items.RectangleROI()
+ item.setGeometry(origin=origin1, size=size1)
+ manager.addRoi(item)
+ item = roi_items.RectangleROI()
+ item.setGeometry(origin=origin2, size=size2)
+ manager.addRoi(item)
+ self.qapp.processEvents()
+ self.assertEqual(changedListener.callCount(), 2)
+ self.assertEqual(len(manager.getRois()), 2)
+
+ # Try to set max ROI to 1 while there is 2 ROIs
+ with self.assertRaises(ValueError):
+ manager.setMaxRois(1)
+
+ manager.clear()
+ self.assertEqual(len(manager.getRois()), 0)
+ self.assertEqual(changedListener.callCount(), 3)
+
+ # Set max limit to 1
+ manager.setMaxRois(1)
+
+ # Add a point
+ item = roi_items.RectangleROI()
+ item.setGeometry(origin=origin1, size=size1)
+ manager.addRoi(item)
+ self.qapp.processEvents()
+ self.assertEqual(changedListener.callCount(), 4)
+
+ # Add a 2nd point while max ROI is 1
+ item = roi_items.RectangleROI()
+ item.setGeometry(origin=origin1, size=size1)
+ manager.addRoi(item)
+ self.qapp.processEvents()
+ self.assertEqual(changedListener.callCount(), 6)
+ self.assertEqual(len(manager.getRois()), 1)
+
+ def testChangeInteractionMode(self):
+ """Test change of interaction mode"""
+ manager = roi.RegionOfInterestManager(self.plot)
+ self.roiTableWidget.setRegionOfInterestManager(manager)
+ manager.start(roi_items.PointROI)
+
+ interactiveModeToolBar = self.plot.getInteractiveModeToolBar()
+ panAction = interactiveModeToolBar.getPanModeAction()
+
+ for roiClass in manager.getSupportedRoiClasses():
+ with self.subTest(roiClass=roiClass):
+ # Change to pan mode
+ panAction.trigger()
+
+ # Change to interactive ROI mode
+ action = manager.getInteractionModeAction(roiClass)
+ action.trigger()
+
+ self.assertEqual(roiClass, manager.getCurrentInteractionModeRoiClass())
+
+ manager.clear()
+
+ def testLineInteraction(self):
+ """This test make sure that a ROI based on handles can be edited with
+ the mouse."""
+ xlimit = self.plot.getXAxis().getLimits()
+ ylimit = self.plot.getYAxis().getLimits()
+ points = numpy.array([xlimit, ylimit]).T
+ center = numpy.mean(points, axis=0)
+
+ # Create the line
+ manager = roi.RegionOfInterestManager(self.plot)
+ item = roi_items.LineROI()
+ item.setEndPoints(points[0], points[1])
+ item.setEditable(True)
+ manager.addRoi(item)
+ self.qapp.processEvents()
+
+ # Drag the center
+ widget = self.plot.getWidgetHandle()
+ mx, my = self.plot.dataToPixel(*center)
+ self.mouseMove(widget, pos=(mx, my))
+ self.mousePress(widget, qt.Qt.LeftButton, pos=(mx, my))
+ self.mouseMove(widget, pos=(mx, my+25))
+ self.mouseMove(widget, pos=(mx, my+50))
+ self.mouseRelease(widget, qt.Qt.LeftButton, pos=(mx, my+50))
+
+ result = numpy.array(item.getEndPoints())
+ # x location is still the same
+ numpy.testing.assert_allclose(points[:, 0], result[:, 0], atol=0.5)
+ # size is still the same
+ numpy.testing.assert_allclose(points[1] - points[0],
+ result[1] - result[0], atol=0.5)
+ # But Y is not the same
+ self.assertNotEqual(points[0, 1], result[0, 1])
+ self.assertNotEqual(points[1, 1], result[1, 1])
+ item = None
+ manager.clear()
+ self.qapp.processEvents()
+
+ def testPlotWhenCleared(self):
+ """PlotWidget.clear should clean up the available ROIs"""
+ manager = roi.RegionOfInterestManager(self.plot)
+ item = roi_items.LineROI()
+ item.setEndPoints((0, 0), (1, 1))
+ item.setEditable(True)
+ manager.addRoi(item)
+ self.qWait()
+ try:
+ # Make sure the test setup is fine
+ self.assertNotEqual(len(manager.getRois()), 0)
+ self.assertNotEqual(len(self.plot.getItems()), 0)
+
+ # Call clear and test the expected state
+ self.plot.clear()
+ self.assertEqual(len(manager.getRois()), 0)
+ self.assertEqual(len(self.plot.getItems()), 0)
+ finally:
+ # Clean up
+ manager.clear()
+
+ def testPlotWhenRoiRemoved(self):
+ """Make sure there is no remaining items in the plot when a ROI is removed"""
+ manager = roi.RegionOfInterestManager(self.plot)
+ item = roi_items.LineROI()
+ item.setEndPoints((0, 0), (1, 1))
+ item.setEditable(True)
+ manager.addRoi(item)
+ self.qWait()
+ try:
+ # Make sure the test setup is fine
+ self.assertNotEqual(len(manager.getRois()), 0)
+ self.assertNotEqual(len(self.plot.getItems()), 0)
+
+ # Call clear and test the expected state
+ manager.removeRoi(item)
+ self.assertEqual(len(manager.getRois()), 0)
+ self.assertEqual(len(self.plot.getItems()), 0)
+ finally:
+ # Clean up
+ manager.clear()
+
+ def testArcRoiSwitchMode(self):
+ """Make sure we can switch mode by clicking on the ROI"""
+ xlimit = self.plot.getXAxis().getLimits()
+ ylimit = self.plot.getYAxis().getLimits()
+ points = numpy.array([xlimit, ylimit]).T
+ center = numpy.mean(points, axis=0)
+ size = numpy.abs(points[1] - points[0])
+
+ # Create the line
+ manager = roi.RegionOfInterestManager(self.plot)
+ item = roi_items.ArcROI()
+ item.setGeometry(center, size[1] / 10, size[1] / 2, 0, 3)
+ item.setEditable(True)
+ item.setSelectable(True)
+ manager.addRoi(item)
+ self.qapp.processEvents()
+
+ # Initial state
+ self.assertIs(item.getInteractionMode(), roi_items.ArcROI.ThreePointMode)
+ self.qWait(500)
+
+ # Click on the center
+ widget = self.plot.getWidgetHandle()
+ mx, my = self.plot.dataToPixel(*center)
+
+ # Select the ROI
+ self.mouseMove(widget, pos=(mx, my))
+ self.mouseClick(widget, qt.Qt.LeftButton, pos=(mx, my))
+ self.qWait(500)
+ self.assertIs(item.getInteractionMode(), roi_items.ArcROI.ThreePointMode)
+
+ # Change the mode
+ self.mouseMove(widget, pos=(mx, my))
+ self.mouseClick(widget, qt.Qt.LeftButton, pos=(mx, my))
+ self.qWait(500)
+ self.assertIs(item.getInteractionMode(), roi_items.ArcROI.PolarMode)
+
+ manager.clear()
+ self.qapp.processEvents()
diff --git a/src/silx/gui/plot/tools/test/testScatterProfileToolBar.py b/src/silx/gui/plot/tools/test/testScatterProfileToolBar.py
new file mode 100644
index 0000000..582a276
--- /dev/null
+++ b/src/silx/gui/plot/tools/test/testScatterProfileToolBar.py
@@ -0,0 +1,184 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "28/06/2018"
+
+
+import unittest
+import numpy
+
+from silx.gui import qt
+from silx.utils.testutils import ParametricTestCase
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui.plot import PlotWindow
+from silx.gui.plot.tools.profile import manager
+from silx.gui.plot.tools.profile import core
+from silx.gui.plot.tools.profile import rois
+
+
+class TestScatterProfileToolBar(TestCaseQt, ParametricTestCase):
+ """Tests for ScatterProfileToolBar class"""
+
+ def setUp(self):
+ super(TestScatterProfileToolBar, self).setUp()
+ self.plot = PlotWindow()
+
+ self.manager = manager.ProfileManager(plot=self.plot)
+ self.manager.setItemType(scatter=True)
+ self.manager.setActiveItemTracking(True)
+
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ def tearDown(self):
+ del self.manager
+ self.qapp.processEvents()
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ del self.plot
+ super(TestScatterProfileToolBar, self).tearDown()
+
+ def testHorizontalProfile(self):
+ """Test ScatterProfileToolBar horizontal profile"""
+
+ roiManager = self.manager.getRoiManager()
+
+ # Add a scatter plot
+ self.plot.addScatter(
+ x=(0., 1., 1., 0.), y=(0., 0., 1., 1.), value=(0., 1., 2., 3.))
+ self.plot.resetZoom(dataMargins=(.1, .1, .1, .1))
+ self.qapp.processEvents()
+
+ # Set a ROI profile
+ roi = rois.ProfileScatterHorizontalLineROI()
+ roi.setPosition(0.5)
+ roi.setNPoints(8)
+ roiManager.addRoi(roi)
+
+ # Wait for async interpolator init
+ for _ in range(20):
+ self.qWait(200)
+ if not self.manager.hasPendingOperations():
+ break
+ self.qapp.processEvents()
+
+ window = roi.getProfileWindow()
+ self.assertIsNotNone(window)
+ data = window.getProfile()
+ self.assertIsInstance(data, core.CurveProfileData)
+ self.assertEqual(len(data.coords), 8)
+
+ # Check that profile has same limits than Plot
+ xLimits = self.plot.getXAxis().getLimits()
+ self.assertEqual(data.coords[0], xLimits[0])
+ self.assertEqual(data.coords[-1], xLimits[1])
+
+ # Clear the profile
+ self.manager.clearProfile()
+ self.qapp.processEvents()
+ self.assertIsNone(roi.getProfileWindow())
+
+ def testVerticalProfile(self):
+ """Test ScatterProfileToolBar vertical profile"""
+
+ roiManager = self.manager.getRoiManager()
+
+ # Add a scatter plot
+ self.plot.addScatter(
+ x=(0., 1., 1., 0.), y=(0., 0., 1., 1.), value=(0., 1., 2., 3.))
+ self.plot.resetZoom(dataMargins=(.1, .1, .1, .1))
+ self.qapp.processEvents()
+
+ # Set a ROI profile
+ roi = rois.ProfileScatterVerticalLineROI()
+ roi.setPosition(0.5)
+ roi.setNPoints(8)
+ roiManager.addRoi(roi)
+
+ # Wait for async interpolator init
+ for _ in range(10):
+ self.qWait(200)
+ if not self.manager.hasPendingOperations():
+ break
+
+ window = roi.getProfileWindow()
+ self.assertIsNotNone(window)
+ data = window.getProfile()
+ self.assertIsInstance(data, core.CurveProfileData)
+ self.assertEqual(len(data.coords), 8)
+
+ # Check that profile has same limits than Plot
+ yLimits = self.plot.getYAxis().getLimits()
+ self.assertEqual(data.coords[0], yLimits[0])
+ self.assertEqual(data.coords[-1], yLimits[1])
+
+ # Check that profile limits are updated when changing limits
+ self.plot.getYAxis().setLimits(yLimits[0] + 1, yLimits[1] + 10)
+
+ # Wait for async interpolator init
+ for _ in range(10):
+ self.qWait(200)
+ if not self.manager.hasPendingOperations():
+ break
+
+ yLimits = self.plot.getYAxis().getLimits()
+ data = window.getProfile()
+ self.assertEqual(data.coords[0], yLimits[0])
+ self.assertEqual(data.coords[-1], yLimits[1])
+
+ # Clear the profile
+ self.manager.clearProfile()
+ self.qapp.processEvents()
+ self.assertIsNone(roi.getProfileWindow())
+
+ def testLineProfile(self):
+ """Test ScatterProfileToolBar line profile"""
+
+ roiManager = self.manager.getRoiManager()
+
+ # Add a scatter plot
+ self.plot.addScatter(
+ x=(0., 1., 1., 0.), y=(0., 0., 1., 1.), value=(0., 1., 2., 3.))
+ self.plot.resetZoom(dataMargins=(.1, .1, .1, .1))
+ self.qapp.processEvents()
+
+ # Set a ROI profile
+ roi = rois.ProfileScatterLineROI()
+ roi.setEndPoints(numpy.array([0., 0.]), numpy.array([1., 1.]))
+ roi.setNPoints(8)
+ roiManager.addRoi(roi)
+
+ # Wait for async interpolator init
+ for _ in range(10):
+ self.qWait(200)
+ if not self.manager.hasPendingOperations():
+ break
+
+ window = roi.getProfileWindow()
+ self.assertIsNotNone(window)
+ data = window.getProfile()
+ self.assertIsInstance(data, core.CurveProfileData)
+ self.assertEqual(len(data.coords), 8)
diff --git a/src/silx/gui/plot/tools/test/testTools.py b/src/silx/gui/plot/tools/test/testTools.py
new file mode 100644
index 0000000..846f641
--- /dev/null
+++ b/src/silx/gui/plot/tools/test/testTools.py
@@ -0,0 +1,135 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for silx.gui.plot.tools package"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "02/03/2018"
+
+
+import functools
+import unittest
+import numpy
+
+from silx.utils.testutils import LoggingValidator
+from silx.gui.utils.testutils import qWaitForWindowExposedAndActivate
+from silx.gui import qt
+from silx.gui.plot import PlotWindow
+from silx.gui.plot import tools
+from silx.gui.plot.test.utils import PlotWidgetTestCase
+
+
+class TestPositionInfo(PlotWidgetTestCase):
+ """Tests for PositionInfo widget."""
+
+ def _createPlot(self):
+ return PlotWindow()
+
+ def setUp(self):
+ super(TestPositionInfo, self).setUp()
+ self.mouseMove(self.plot, pos=(0, 0))
+ self.qapp.processEvents()
+ self.qWait(100)
+
+ def tearDown(self):
+ super(TestPositionInfo, self).tearDown()
+
+ def _test(self, positionWidget, converterNames, **kwargs):
+ """General test of PositionInfo.
+
+ - Add it to a toolbar and
+ - Move mouse around the center of the PlotWindow.
+ """
+ toolBar = qt.QToolBar()
+ self.plot.addToolBar(qt.Qt.BottomToolBarArea, toolBar)
+
+ toolBar.addWidget(positionWidget)
+
+ converters = positionWidget.getConverters()
+ self.assertEqual(len(converters), len(converterNames))
+ for index, name in enumerate(converterNames):
+ self.assertEqual(converters[index][0], name)
+
+ self.qapp.processEvents()
+ with LoggingValidator(tools.__name__, **kwargs):
+ # Move mouse to center
+ center = self.plot.size() / 2
+ self.mouseMove(self.plot, pos=(center.width(), center.height()))
+ # Move out
+ self.mouseMove(self.plot, pos=(1, 1))
+
+ def testDefaultConverters(self):
+ """Test PositionInfo with default converters"""
+ positionWidget = tools.PositionInfo(plot=self.plot)
+ self._test(positionWidget, ('X', 'Y'))
+
+ def testCustomConverters(self):
+ """Test PositionInfo with custom converters"""
+ converters = [
+ ('Coords', lambda x, y: (int(x), int(y))),
+ ('Radius', lambda x, y: numpy.sqrt(x * x + y * y)),
+ ('Angle', lambda x, y: numpy.degrees(numpy.arctan2(y, x)))
+ ]
+ positionWidget = tools.PositionInfo(plot=self.plot,
+ converters=converters)
+ self._test(positionWidget, ('Coords', 'Radius', 'Angle'))
+
+ def testFailingConverters(self):
+ """Test PositionInfo with failing custom converters"""
+ def raiseException(x, y):
+ raise RuntimeError()
+
+ positionWidget = tools.PositionInfo(
+ plot=self.plot,
+ converters=[('Exception', raiseException)])
+ self._test(positionWidget, ['Exception'], error=2)
+
+ def testUpdate(self):
+ """Test :meth:`PositionInfo.updateInfo`"""
+ calls = []
+
+ def update(calls, x, y): # Get number of calls
+ calls.append((x, y))
+ return len(calls)
+
+ positionWidget = tools.PositionInfo(
+ plot=self.plot,
+ converters=[('Call count', functools.partial(update, calls))])
+
+ positionWidget.updateInfo()
+ self.assertEqual(len(calls), 1)
+
+
+class TestPlotToolsToolbars(PlotWidgetTestCase):
+ """Tests toolbars from silx.gui.plot.tools"""
+
+ def test(self):
+ """"Add all toolbars"""
+ for tbClass in (tools.InteractiveModeToolBar,
+ tools.ImageToolBar,
+ tools.CurveToolBar,
+ tools.OutputToolBar):
+ tb = tbClass(parent=self.plot, plot=self.plot)
+ self.plot.addToolBar(tb)
diff --git a/src/silx/gui/plot/tools/toolbars.py b/src/silx/gui/plot/tools/toolbars.py
new file mode 100644
index 0000000..3df7d06
--- /dev/null
+++ b/src/silx/gui/plot/tools/toolbars.py
@@ -0,0 +1,362 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides toolbars that work with :class:`PlotWidget`.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "01/03/2018"
+
+
+from ... import qt
+from .. import actions
+from ..PlotWidget import PlotWidget
+from .. import PlotToolButtons
+from ....utils.deprecation import deprecated
+
+
+class InteractiveModeToolBar(qt.QToolBar):
+ """Toolbar with interactive mode actions
+
+ :param parent: See :class:`QWidget`
+ :param silx.gui.plot.PlotWidget plot: PlotWidget to control
+ :param str title: Title of the toolbar.
+ """
+
+ def __init__(self, parent=None, plot=None, title='Plot Interaction'):
+ super(InteractiveModeToolBar, self).__init__(title, parent)
+
+ assert isinstance(plot, PlotWidget)
+
+ self._zoomModeAction = actions.mode.ZoomModeAction(
+ parent=self, plot=plot)
+ self.addAction(self._zoomModeAction)
+
+ self._panModeAction = actions.mode.PanModeAction(
+ parent=self, plot=plot)
+ self.addAction(self._panModeAction)
+
+ def getZoomModeAction(self):
+ """Returns the zoom mode QAction.
+
+ :rtype: PlotAction
+ """
+ return self._zoomModeAction
+
+ def getPanModeAction(self):
+ """Returns the pan mode QAction
+
+ :rtype: PlotAction
+ """
+ return self._panModeAction
+
+
+class OutputToolBar(qt.QToolBar):
+ """Toolbar providing icons to copy, save and print a PlotWidget
+
+ :param parent: See :class:`QWidget`
+ :param silx.gui.plot.PlotWidget plot: PlotWidget to control
+ :param str title: Title of the toolbar.
+ """
+
+ def __init__(self, parent=None, plot=None, title='Plot Output'):
+ super(OutputToolBar, self).__init__(title, parent)
+
+ assert isinstance(plot, PlotWidget)
+
+ self._copyAction = actions.io.CopyAction(parent=self, plot=plot)
+ self.addAction(self._copyAction)
+
+ self._saveAction = actions.io.SaveAction(parent=self, plot=plot)
+ self.addAction(self._saveAction)
+
+ self._printAction = actions.io.PrintAction(parent=self, plot=plot)
+ self.addAction(self._printAction)
+
+ def getCopyAction(self):
+ """Returns the QAction performing copy to clipboard of the PlotWidget
+
+ :rtype: PlotAction
+ """
+ return self._copyAction
+
+ def getSaveAction(self):
+ """Returns the QAction performing save to file of the PlotWidget
+
+ :rtype: PlotAction
+ """
+ return self._saveAction
+
+ def getPrintAction(self):
+ """Returns the QAction performing printing of the PlotWidget
+
+ :rtype: PlotAction
+ """
+ return self._printAction
+
+
+class ImageToolBar(qt.QToolBar):
+ """Toolbar providing PlotAction suited when displaying images
+
+ :param parent: See :class:`QWidget`
+ :param silx.gui.plot.PlotWidget plot: PlotWidget to control
+ :param str title: Title of the toolbar.
+ """
+
+ def __init__(self, parent=None, plot=None, title='Image'):
+ super(ImageToolBar, self).__init__(title, parent)
+
+ assert isinstance(plot, PlotWidget)
+
+ self._resetZoomAction = actions.control.ResetZoomAction(
+ parent=self, plot=plot)
+ self.addAction(self._resetZoomAction)
+
+ self._colormapAction = actions.control.ColormapAction(
+ parent=self, plot=plot)
+ self.addAction(self._colormapAction)
+
+ self._keepDataAspectRatioButton = PlotToolButtons.AspectToolButton(
+ parent=self, plot=plot)
+ self.addWidget(self._keepDataAspectRatioButton)
+
+ self._yAxisInvertedButton = PlotToolButtons.YAxisOriginToolButton(
+ parent=self, plot=plot)
+ self.addWidget(self._yAxisInvertedButton)
+
+ def getResetZoomAction(self):
+ """Returns the QAction to reset the zoom.
+
+ :rtype: PlotAction
+ """
+ return self._resetZoomAction
+
+ def getColormapAction(self):
+ """Returns the QAction to control the colormap.
+
+ :rtype: PlotAction
+ """
+ return self._colormapAction
+
+ def getKeepDataAspectRatioButton(self):
+ """Returns the QToolButton controlling data aspect ratio.
+
+ :rtype: QToolButton
+ """
+ return self._keepDataAspectRatioButton
+
+ def getYAxisInvertedButton(self):
+ """Returns the QToolButton controlling Y axis orientation.
+
+ :rtype: QToolButton
+ """
+ return self._yAxisInvertedButton
+
+
+class CurveToolBar(qt.QToolBar):
+ """Toolbar providing PlotAction suited when displaying curves
+
+ :param parent: See :class:`QWidget`
+ :param silx.gui.plot.PlotWidget plot: PlotWidget to control
+ :param str title: Title of the toolbar.
+ """
+
+ def __init__(self, parent=None, plot=None, title='Image'):
+ super(CurveToolBar, self).__init__(title, parent)
+
+ assert isinstance(plot, PlotWidget)
+
+ self._resetZoomAction = actions.control.ResetZoomAction(
+ parent=self, plot=plot)
+ self.addAction(self._resetZoomAction)
+
+ self._xAxisAutoScaleAction = actions.control.XAxisAutoScaleAction(
+ parent=self, plot=plot)
+ self.addAction(self._xAxisAutoScaleAction)
+
+ self._yAxisAutoScaleAction = actions.control.YAxisAutoScaleAction(
+ parent=self, plot=plot)
+ self.addAction(self._yAxisAutoScaleAction)
+
+ self._xAxisLogarithmicAction = actions.control.XAxisLogarithmicAction(
+ parent=self, plot=plot)
+ self.addAction(self._xAxisLogarithmicAction)
+
+ self._yAxisLogarithmicAction = actions.control.YAxisLogarithmicAction(
+ parent=self, plot=plot)
+ self.addAction(self._yAxisLogarithmicAction)
+
+ self._gridAction = actions.control.GridAction(
+ parent=self, plot=plot)
+ self.addAction(self._gridAction)
+
+ self._curveStyleAction = actions.control.CurveStyleAction(
+ parent=self, plot=plot)
+ self.addAction(self._curveStyleAction)
+
+ def getResetZoomAction(self):
+ """Returns the QAction to reset the zoom.
+
+ :rtype: PlotAction
+ """
+ return self._resetZoomAction
+
+ def getXAxisAutoScaleAction(self):
+ """Returns the QAction to toggle X axis autoscale.
+
+ :rtype: PlotAction
+ """
+ return self._xAxisAutoScaleAction
+
+ def getYAxisAutoScaleAction(self):
+ """Returns the QAction to toggle Y axis autoscale.
+
+ :rtype: PlotAction
+ """
+ return self._yAxisAutoScaleAction
+
+ def getXAxisLogarithmicAction(self):
+ """Returns the QAction to toggle X axis log/linear scale.
+
+ :rtype: PlotAction
+ """
+ return self._xAxisLogarithmicAction
+
+ def getYAxisLogarithmicAction(self):
+ """Returns the QAction to toggle Y axis log/linear scale.
+
+ :rtype: PlotAction
+ """
+ return self._yAxisLogarithmicAction
+
+ def getGridAction(self):
+ """Returns the action to toggle the plot grid.
+
+ :rtype: PlotAction
+ """
+ return self._gridAction
+
+ def getCurveStyleAction(self):
+ """Returns the QAction to change the style of all curves.
+
+ :rtype: PlotAction
+ """
+ return self._curveStyleAction
+
+
+class ScatterToolBar(qt.QToolBar):
+ """Toolbar providing PlotAction suited when displaying scatter plot
+
+ :param parent: See :class:`QWidget`
+ :param silx.gui.plot.PlotWidget plot: PlotWidget to control
+ :param str title: Title of the toolbar.
+ """
+
+ def __init__(self, parent=None, plot=None, title='Scatter Tools'):
+ super(ScatterToolBar, self).__init__(title, parent)
+
+ assert isinstance(plot, PlotWidget)
+
+ self._resetZoomAction = actions.control.ResetZoomAction(
+ parent=self, plot=plot)
+ self.addAction(self._resetZoomAction)
+
+ self._xAxisLogarithmicAction = actions.control.XAxisLogarithmicAction(
+ parent=self, plot=plot)
+ self.addAction(self._xAxisLogarithmicAction)
+
+ self._yAxisLogarithmicAction = actions.control.YAxisLogarithmicAction(
+ parent=self, plot=plot)
+ self.addAction(self._yAxisLogarithmicAction)
+
+ self._keepDataAspectRatioButton = PlotToolButtons.AspectToolButton(
+ parent=self, plot=plot)
+ self.addWidget(self._keepDataAspectRatioButton)
+
+ self._gridAction = actions.control.GridAction(
+ parent=self, plot=plot)
+ self.addAction(self._gridAction)
+
+ self._colormapAction = actions.control.ColormapAction(
+ parent=self, plot=plot)
+ self.addAction(self._colormapAction)
+
+ self._visualizationToolButton = \
+ PlotToolButtons.ScatterVisualizationToolButton(parent=self, plot=plot)
+ self.addWidget(self._visualizationToolButton)
+
+ def getResetZoomAction(self):
+ """Returns the QAction to reset the zoom.
+
+ :rtype: PlotAction
+ """
+ return self._resetZoomAction
+
+ def getXAxisLogarithmicAction(self):
+ """Returns the QAction to toggle X axis log/linear scale.
+
+ :rtype: PlotAction
+ """
+ return self._xAxisLogarithmicAction
+
+ def getYAxisLogarithmicAction(self):
+ """Returns the QAction to toggle Y axis log/linear scale.
+
+ :rtype: PlotAction
+ """
+ return self._yAxisLogarithmicAction
+
+ def getGridAction(self):
+ """Returns the action to toggle the plot grid.
+
+ :rtype: PlotAction
+ """
+ return self._gridAction
+
+ def getColormapAction(self):
+ """Returns the QAction to control the colormap.
+
+ :rtype: PlotAction
+ """
+ return self._colormapAction
+
+ def getKeepDataAspectRatioButton(self):
+ """Returns the QToolButton controlling data aspect ratio.
+
+ :rtype: QToolButton
+ """
+ return self._keepDataAspectRatioButton
+
+ def getScatterVisualizationToolButton(self):
+ """Returns the QToolButton controlling the visualization mode.
+
+ :rtype: ScatterVisualizationToolButton
+ """
+ return self._visualizationToolButton
+
+ @deprecated(replacement='getScatterVisualizationToolButton',
+ since_version='0.11.0')
+ def getSymbolToolButton(self):
+ return self.getScatterVisualizationToolButton()
diff --git a/src/silx/gui/plot/utils/__init__.py b/src/silx/gui/plot/utils/__init__.py
new file mode 100644
index 0000000..3187f6b
--- /dev/null
+++ b/src/silx/gui/plot/utils/__init__.py
@@ -0,0 +1,30 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Utils module for plot.
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "29/06/2017"
diff --git a/src/silx/gui/plot/utils/axis.py b/src/silx/gui/plot/utils/axis.py
new file mode 100644
index 0000000..5cf8ad9
--- /dev/null
+++ b/src/silx/gui/plot/utils/axis.py
@@ -0,0 +1,398 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module contains utils class for axes management.
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "20/11/2018"
+
+import functools
+import logging
+from contextlib import contextmanager
+import weakref
+import silx.utils.weakref as silxWeakref
+from silx.gui.plot.items.axis import Axis, XAxis, YAxis
+from ...qt.inspect import isValid as _isQObjectValid
+
+
+_logger = logging.getLogger(__name__)
+
+
+class SyncAxes(object):
+ """Synchronize a set of plot axes together.
+
+ It is created with the expected axes and starts to synchronize them.
+
+ It can be customized to synchronize limits, scale, and direction of axes
+ together. By default everything is synchronized.
+
+ The API :meth:`start` and :meth:`stop` can be used to enable/disable the
+ synchronization while this object is still alive.
+
+ If this object is destroyed the synchronization stop.
+
+ .. versionadded:: 0.6
+ """
+
+ def __init__(self, axes,
+ syncLimits=True,
+ syncScale=True,
+ syncDirection=True,
+ syncCenter=False,
+ syncZoom=False,
+ filterHiddenPlots=False
+ ):
+ """
+ Constructor
+
+ :param list(Axis) axes: A list of axes to synchronize together
+ :param bool syncLimits: Synchronize axes limits
+ :param bool syncScale: Synchronize axes scale
+ :param bool syncDirection: Synchronize axes direction
+ :param bool syncCenter: Synchronize the center of the axes in the center
+ of the plots
+ :param bool syncZoom: Synchronize the zoom of the plot
+ :param bool filterHiddenPlots: True to avoid updating hidden plots.
+ Default: False.
+ """
+ object.__init__(self)
+
+ def implies(x, y): return bool(y ** x)
+
+ assert(implies(syncZoom, not syncLimits))
+ assert(implies(syncCenter, not syncLimits))
+ assert(implies(syncLimits, not syncCenter))
+ assert(implies(syncLimits, not syncZoom))
+
+ self.__filterHiddenPlots = filterHiddenPlots
+ self.__locked = False
+ self.__axisRefs = []
+ self.__syncLimits = syncLimits
+ self.__syncScale = syncScale
+ self.__syncDirection = syncDirection
+ self.__syncCenter = syncCenter
+ self.__syncZoom = syncZoom
+ self.__callbacks = None
+ self.__lastMainAxis = None
+
+ for axis in axes:
+ self.addAxis(axis)
+
+ self.start()
+
+ def start(self):
+ """Start synchronizing axes together.
+
+ The first axis is used as the reference for the first synchronization.
+ After that, any changes to any axes will be used to synchronize other
+ axes.
+ """
+ if self.isSynchronizing():
+ raise RuntimeError("Axes already synchronized")
+ self.__callbacks = {}
+
+ axes = self.__getAxes()
+
+ # register callback for further sync
+ for axis in axes:
+ self.__connectAxes(axis)
+ self.synchronize()
+
+ def isSynchronizing(self):
+ """Returns true if events are connected to the axes to synchronize them
+ all together
+
+ :rtype: bool
+ """
+ return self.__callbacks is not None
+
+ def __connectAxes(self, axis):
+ refAxis = weakref.ref(axis)
+ callbacks = []
+ if self.__syncLimits:
+ # the weakref is needed to be able ignore self references
+ callback = silxWeakref.WeakMethodProxy(self.__axisLimitsChanged)
+ callback = functools.partial(callback, refAxis)
+ sig = axis.sigLimitsChanged
+ sig.connect(callback)
+ callbacks.append(("sigLimitsChanged", callback))
+ elif self.__syncCenter and self.__syncZoom:
+ # the weakref is needed to be able ignore self references
+ callback = silxWeakref.WeakMethodProxy(self.__axisCenterAndZoomChanged)
+ callback = functools.partial(callback, refAxis)
+ sig = axis.sigLimitsChanged
+ sig.connect(callback)
+ callbacks.append(("sigLimitsChanged", callback))
+ elif self.__syncZoom:
+ raise NotImplementedError()
+ elif self.__syncCenter:
+ # the weakref is needed to be able ignore self references
+ callback = silxWeakref.WeakMethodProxy(self.__axisCenterChanged)
+ callback = functools.partial(callback, refAxis)
+ sig = axis.sigLimitsChanged
+ sig.connect(callback)
+ callbacks.append(("sigLimitsChanged", callback))
+ if self.__syncScale:
+ # the weakref is needed to be able ignore self references
+ callback = silxWeakref.WeakMethodProxy(self.__axisScaleChanged)
+ callback = functools.partial(callback, refAxis)
+ sig = axis.sigScaleChanged
+ sig.connect(callback)
+ callbacks.append(("sigScaleChanged", callback))
+ if self.__syncDirection:
+ # the weakref is needed to be able ignore self references
+ callback = silxWeakref.WeakMethodProxy(self.__axisInvertedChanged)
+ callback = functools.partial(callback, refAxis)
+ sig = axis.sigInvertedChanged
+ sig.connect(callback)
+ callbacks.append(("sigInvertedChanged", callback))
+
+ if self.__filterHiddenPlots:
+ # the weakref is needed to be able ignore self references
+ callback = silxWeakref.WeakMethodProxy(self.__axisVisibilityChanged)
+ callback = functools.partial(callback, refAxis)
+ plot = axis._getPlot()
+ plot.sigVisibilityChanged.connect(callback)
+ callbacks.append(("sigVisibilityChanged", callback))
+
+ self.__callbacks[refAxis] = callbacks
+
+ def __disconnectAxes(self, axis):
+ if axis is not None and _isQObjectValid(axis):
+ ref = weakref.ref(axis)
+ callbacks = self.__callbacks.pop(ref)
+ for sigName, callback in callbacks:
+ if sigName == "sigVisibilityChanged":
+ obj = axis._getPlot()
+ else:
+ obj = axis
+ if obj is not None:
+ sig = getattr(obj, sigName)
+ sig.disconnect(callback)
+
+ def addAxis(self, axis):
+ """Add a new axes to synchronize.
+
+ :param ~silx.gui.plot.items.Axis axis: The axis to synchronize
+ """
+ self.__axisRefs.append(weakref.ref(axis))
+ if self.isSynchronizing():
+ self.__connectAxes(axis)
+ # This could be done faster as only this axis have to be fixed
+ self.synchronize()
+
+ def removeAxis(self, axis):
+ """Remove an axis from the synchronized axes.
+
+ :param ~silx.gui.plot.items.Axis axis: The axis to remove
+ """
+ ref = weakref.ref(axis)
+ self.__axisRefs.remove(ref)
+ if self.isSynchronizing():
+ self.__disconnectAxes(axis)
+
+ def synchronize(self, mainAxis=None):
+ """Synchronize programatically all the axes.
+
+ :param ~silx.gui.plot.items.Axis mainAxis:
+ The axis to take as reference (Default: the first axis).
+ """
+ # sync the current state
+ axes = self.__getAxes()
+ if len(axes) == 0:
+ return
+
+ if mainAxis is None:
+ mainAxis = axes[0]
+
+ refMainAxis = weakref.ref(mainAxis)
+ if self.__syncLimits:
+ self.__axisLimitsChanged(refMainAxis, *mainAxis.getLimits())
+ elif self.__syncCenter and self.__syncZoom:
+ self.__axisCenterAndZoomChanged(refMainAxis, *mainAxis.getLimits())
+ elif self.__syncCenter:
+ self.__axisCenterChanged(refMainAxis, *mainAxis.getLimits())
+ if self.__syncScale:
+ self.__axisScaleChanged(refMainAxis, mainAxis.getScale())
+ if self.__syncDirection:
+ self.__axisInvertedChanged(refMainAxis, mainAxis.isInverted())
+
+ def stop(self):
+ """Stop the synchronization of the axes"""
+ if not self.isSynchronizing():
+ raise RuntimeError("Axes not synchronized")
+ for ref in list(self.__callbacks.keys()):
+ axis = ref()
+ self.__disconnectAxes(axis)
+ self.__callbacks = None
+
+ def __del__(self):
+ """Destructor"""
+ # clean up references
+ if self.__callbacks is not None:
+ self.stop()
+
+ def __getAxes(self):
+ """Returns list of existing axes.
+
+ :rtype: List[Axis]
+ """
+ axes = [ref() for ref in self.__axisRefs]
+ return [axis for axis in axes if axis is not None]
+
+ @contextmanager
+ def __inhibitSignals(self):
+ self.__locked = True
+ yield
+ self.__locked = False
+
+ def __axesToUpdate(self, changedAxis):
+ for axis in self.__getAxes():
+ if axis is changedAxis:
+ continue
+ if self.__filterHiddenPlots:
+ plot = axis._getPlot()
+ if not plot.isVisible():
+ continue
+ yield axis
+
+ def __axisVisibilityChanged(self, changedAxis, isVisible):
+ if not isVisible:
+ return
+ if self.__locked:
+ return
+ changedAxis = changedAxis()
+ if self.__lastMainAxis is None:
+ self.__lastMainAxis = self.__axisRefs[0]
+ mainAxis = self.__lastMainAxis
+ mainAxis = mainAxis()
+ self.synchronize(mainAxis=mainAxis)
+ # force back the main axis
+ self.__lastMainAxis = weakref.ref(mainAxis)
+
+ def __getAxesCenter(self, axis, vmin, vmax):
+ """Returns the value displayed in the center of this axis range.
+
+ :rtype: float
+ """
+ scale = axis.getScale()
+ if scale == Axis.LINEAR:
+ center = (vmin + vmax) * 0.5
+ else:
+ raise NotImplementedError("Log scale not implemented")
+ return center
+
+ def __getRangeInPixel(self, axis):
+ """Returns the size of the axis in pixel"""
+ bounds = axis._getPlot().getPlotBoundsInPixels()
+ # bounds: left, top, width, height
+ if isinstance(axis, XAxis):
+ return bounds[2]
+ elif isinstance(axis, YAxis):
+ return bounds[3]
+ else:
+ assert(False)
+
+ def __getLimitsFromCenter(self, axis, pos, pixelSize=None):
+ """Returns the limits to apply to this axis to move the `pos` into the
+ center of this axis.
+
+ :param Axis axis:
+ :param float pos: Position in the center of the computed limits
+ :param Union[None,float] pixelSize: Pixel size to apply to compute the
+ limits. If `None` the current pixel size is applyed.
+ """
+ scale = axis.getScale()
+ if scale == Axis.LINEAR:
+ if pixelSize is None:
+ # Use the current pixel size of the axis
+ limits = axis.getLimits()
+ valueRange = limits[0] - limits[1]
+ a = pos - valueRange * 0.5
+ b = pos + valueRange * 0.5
+ else:
+ pixelRange = self.__getRangeInPixel(axis)
+ a = pos - pixelRange * 0.5 * pixelSize
+ b = pos + pixelRange * 0.5 * pixelSize
+
+ else:
+ raise NotImplementedError("Log scale not implemented")
+ if a > b:
+ return b, a
+ return a, b
+
+ def __axisLimitsChanged(self, changedAxis, vmin, vmax):
+ if self.__locked:
+ return
+ self.__lastMainAxis = changedAxis
+ changedAxis = changedAxis()
+ with self.__inhibitSignals():
+ for axis in self.__axesToUpdate(changedAxis):
+ axis.setLimits(vmin, vmax)
+
+ def __axisCenterAndZoomChanged(self, changedAxis, vmin, vmax):
+ if self.__locked:
+ return
+ self.__lastMainAxis = changedAxis
+ changedAxis = changedAxis()
+ with self.__inhibitSignals():
+ center = self.__getAxesCenter(changedAxis, vmin, vmax)
+ pixelRange = self.__getRangeInPixel(changedAxis)
+ if pixelRange == 0:
+ return
+ pixelSize = (vmax - vmin) / pixelRange
+ for axis in self.__axesToUpdate(changedAxis):
+ vmin, vmax = self.__getLimitsFromCenter(axis, center, pixelSize)
+ axis.setLimits(vmin, vmax)
+
+ def __axisCenterChanged(self, changedAxis, vmin, vmax):
+ if self.__locked:
+ return
+ self.__lastMainAxis = changedAxis
+ changedAxis = changedAxis()
+ with self.__inhibitSignals():
+ center = self.__getAxesCenter(changedAxis, vmin, vmax)
+ for axis in self.__axesToUpdate(changedAxis):
+ vmin, vmax = self.__getLimitsFromCenter(axis, center)
+ axis.setLimits(vmin, vmax)
+
+ def __axisScaleChanged(self, changedAxis, scale):
+ if self.__locked:
+ return
+ self.__lastMainAxis = changedAxis
+ changedAxis = changedAxis()
+ with self.__inhibitSignals():
+ for axis in self.__axesToUpdate(changedAxis):
+ axis.setScale(scale)
+
+ def __axisInvertedChanged(self, changedAxis, isInverted):
+ if self.__locked:
+ return
+ self.__lastMainAxis = changedAxis
+ changedAxis = changedAxis()
+ with self.__inhibitSignals():
+ for axis in self.__axesToUpdate(changedAxis):
+ axis.setInverted(isInverted)
diff --git a/src/silx/gui/plot/utils/intersections.py b/src/silx/gui/plot/utils/intersections.py
new file mode 100644
index 0000000..53f2546
--- /dev/null
+++ b/src/silx/gui/plot/utils/intersections.py
@@ -0,0 +1,101 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module contains utils class for axes management.
+"""
+
+__authors__ = ["H. Payno", ]
+__license__ = "MIT"
+__date__ = "18/05/2020"
+
+
+import numpy
+
+
+def lines_intersection(line1_pt1, line1_pt2, line2_pt1, line2_pt2):
+ """
+ line segment intersection using vectors (Computer Graphics by F.S. Hill)
+
+ :param tuple line1_pt1:
+ :param tuple line1_pt2:
+ :param tuple line2_pt1:
+ :param tuple line2_pt2:
+ :return: Union[None,numpy.array]
+ """
+ dir_line1 = line1_pt2[0] - line1_pt1[0], line1_pt2[1] - line1_pt1[1]
+ dir_line2 = line2_pt2[0] - line2_pt1[0], line2_pt2[1] - line2_pt1[1]
+ dp = line1_pt1 - line2_pt1
+
+ def perp(a):
+ b = numpy.empty_like(a)
+ b[0] = -a[1]
+ b[1] = a[0]
+ return b
+
+ dap = perp(dir_line1)
+ denom = numpy.dot(dap, dir_line2)
+ num = numpy.dot(dap, dp)
+ if denom == 0:
+ return None
+ return (
+ (num / denom.astype(float)) * dir_line2[0] + line2_pt1[0],
+ (num / denom.astype(float)) * dir_line2[1] + line2_pt1[1])
+
+
+def segments_intersection(seg1_start_pt, seg1_end_pt, seg2_start_pt,
+ seg2_end_pt):
+ """
+ Compute intersection between two segments
+
+ :param seg1_start_pt:
+ :param seg1_end_pt:
+ :param seg2_start_pt:
+ :param seg2_end_pt:
+ :return: numpy.array if an intersection exists, else None
+ :rtype: Union[None,numpy.array]
+ """
+ intersection = lines_intersection(line1_pt1=seg1_start_pt,
+ line1_pt2=seg1_end_pt,
+ line2_pt1=seg2_start_pt,
+ line2_pt2=seg2_end_pt)
+ if intersection is not None:
+ max_x_seg1 = max(seg1_start_pt[0], seg1_end_pt[0])
+ max_x_seg2 = max(seg2_start_pt[0], seg2_end_pt[0])
+ max_y_seg1 = max(seg1_start_pt[1], seg1_end_pt[1])
+ max_y_seg2 = max(seg2_start_pt[1], seg2_end_pt[1])
+
+ min_x_seg1 = min(seg1_start_pt[0], seg1_end_pt[0])
+ min_x_seg2 = min(seg2_start_pt[0], seg2_end_pt[0])
+ min_y_seg1 = min(seg1_start_pt[1], seg1_end_pt[1])
+ min_y_seg2 = min(seg2_start_pt[1], seg2_end_pt[1])
+
+ min_tmp_x = max(min_x_seg1, min_x_seg2)
+ max_tmp_x = min(max_x_seg1, max_x_seg2)
+ min_tmp_y = max(min_y_seg1, min_y_seg2)
+ max_tmp_y = min(max_y_seg1, max_y_seg2)
+ if (min_tmp_x <= intersection[0] <= max_tmp_x and
+ min_tmp_y <= intersection[1] <= max_tmp_y):
+ return intersection
+ else:
+ return None
diff --git a/src/silx/gui/plot3d/ParamTreeView.py b/src/silx/gui/plot3d/ParamTreeView.py
new file mode 100644
index 0000000..2593860
--- /dev/null
+++ b/src/silx/gui/plot3d/ParamTreeView.py
@@ -0,0 +1,522 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+This module provides a :class:`QTreeView` dedicated to display plot3d models.
+
+This module contains:
+- :class:`ParamTreeView`: A QTreeView specific for plot3d parameters and scene.
+- :class:`ParameterTreeDelegate`: The delegate for :class:`ParamTreeView`.
+- A set of specific editors used by :class:`ParameterTreeDelegate`:
+ :class:`FloatEditor`, :class:`Vector3DEditor`,
+ :class:`Vector4DEditor`, :class:`IntSliderEditor`, :class:`BooleanEditor`
+"""
+
+from __future__ import absolute_import
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "05/12/2017"
+
+
+import numbers
+import sys
+
+from .. import qt
+from ..widgets.FloatEdit import FloatEdit as _FloatEdit
+from ._model import visitQAbstractItemModel
+
+
+class FloatEditor(_FloatEdit):
+ """Editor widget for float.
+
+ :param parent: The widget's parent
+ :param float value: The initial editor value
+ """
+
+ valueChanged = qt.Signal(float)
+ """Signal emitted when the float value has changed"""
+
+ def __init__(self, parent=None, value=None):
+ super(FloatEditor, self).__init__(parent, value)
+ self.setAlignment(qt.Qt.AlignLeft)
+ self.editingFinished.connect(self._emit)
+
+ def _emit(self):
+ self.valueChanged.emit(self.value)
+
+ value = qt.Property(float,
+ fget=_FloatEdit.value,
+ fset=_FloatEdit.setValue,
+ user=True,
+ notify=valueChanged)
+ """Qt user property of the float value this widget edits"""
+
+
+class Vector3DEditor(qt.QWidget):
+ """Editor widget for QVector3D.
+
+ :param parent: The widget's parent
+ :param flags: The widgets's flags
+ """
+
+ valueChanged = qt.Signal(qt.QVector3D)
+ """Signal emitted when the QVector3D value has changed"""
+
+ def __init__(self, parent=None, flags=qt.Qt.Widget):
+ super(Vector3DEditor, self).__init__(parent, flags)
+ layout = qt.QHBoxLayout(self)
+ # layout.setSpacing(0)
+ layout.setContentsMargins(0, 0, 0, 0)
+ self.setLayout(layout)
+ self._xEdit = _FloatEdit(parent=self, value=0.)
+ self._xEdit.setAlignment(qt.Qt.AlignLeft)
+ # self._xEdit.editingFinished.connect(self._emit)
+ self._yEdit = _FloatEdit(parent=self, value=0.)
+ self._yEdit.setAlignment(qt.Qt.AlignLeft)
+ # self._yEdit.editingFinished.connect(self._emit)
+ self._zEdit = _FloatEdit(parent=self, value=0.)
+ self._zEdit.setAlignment(qt.Qt.AlignLeft)
+ # self._zEdit.editingFinished.connect(self._emit)
+ layout.addWidget(qt.QLabel('x:'))
+ layout.addWidget(self._xEdit)
+ layout.addWidget(qt.QLabel('y:'))
+ layout.addWidget(self._yEdit)
+ layout.addWidget(qt.QLabel('z:'))
+ layout.addWidget(self._zEdit)
+ layout.addStretch(1)
+
+ def _emit(self):
+ vector = self.value
+ self.valueChanged.emit(vector)
+
+ def getValue(self):
+ """Returns the QVector3D value of this widget
+
+ :rtype: QVector3D
+ """
+ return qt.QVector3D(
+ self._xEdit.value(), self._yEdit.value(), self._zEdit.value())
+
+ def setValue(self, value):
+ """Set the QVector3D value
+
+ :param QVector3D value: The new value
+ """
+ self._xEdit.setValue(value.x())
+ self._yEdit.setValue(value.y())
+ self._zEdit.setValue(value.z())
+ self.valueChanged.emit(value)
+
+ value = qt.Property(qt.QVector3D,
+ fget=getValue,
+ fset=setValue,
+ user=True,
+ notify=valueChanged)
+ """Qt user property of the QVector3D value this widget edits"""
+
+
+class Vector4DEditor(qt.QWidget):
+ """Editor widget for QVector4D.
+
+ :param parent: The widget's parent
+ :param flags: The widgets's flags
+ """
+
+ valueChanged = qt.Signal(qt.QVector4D)
+ """Signal emitted when the QVector4D value has changed"""
+
+ def __init__(self, parent=None, flags=qt.Qt.Widget):
+ super(Vector4DEditor, self).__init__(parent, flags)
+ layout = qt.QHBoxLayout(self)
+ # layout.setSpacing(0)
+ layout.setContentsMargins(0, 0, 0, 0)
+ self.setLayout(layout)
+ self._xEdit = _FloatEdit(parent=self, value=0.)
+ self._xEdit.setAlignment(qt.Qt.AlignLeft)
+ # self._xEdit.editingFinished.connect(self._emit)
+ self._yEdit = _FloatEdit(parent=self, value=0.)
+ self._yEdit.setAlignment(qt.Qt.AlignLeft)
+ # self._yEdit.editingFinished.connect(self._emit)
+ self._zEdit = _FloatEdit(parent=self, value=0.)
+ self._zEdit.setAlignment(qt.Qt.AlignLeft)
+ # self._zEdit.editingFinished.connect(self._emit)
+ self._wEdit = _FloatEdit(parent=self, value=0.)
+ self._wEdit.setAlignment(qt.Qt.AlignLeft)
+ # self._wEdit.editingFinished.connect(self._emit)
+ layout.addWidget(qt.QLabel('x:'))
+ layout.addWidget(self._xEdit)
+ layout.addWidget(qt.QLabel('y:'))
+ layout.addWidget(self._yEdit)
+ layout.addWidget(qt.QLabel('z:'))
+ layout.addWidget(self._zEdit)
+ layout.addWidget(qt.QLabel('w:'))
+ layout.addWidget(self._wEdit)
+ layout.addStretch(1)
+
+ def _emit(self):
+ vector = self.value
+ self.valueChanged.emit(vector)
+
+ def getValue(self):
+ """Returns the QVector4D value of this widget
+
+ :rtype: QVector4D
+ """
+ return qt.QVector4D(self._xEdit.value(), self._yEdit.value(),
+ self._zEdit.value(), self._wEdit.value())
+
+ def setValue(self, value):
+ """Set the QVector4D value
+
+ :param QVector4D value: The new value
+ """
+ self._xEdit.setValue(value.x())
+ self._yEdit.setValue(value.y())
+ self._zEdit.setValue(value.z())
+ self._wEdit.setValue(value.w())
+ self.valueChanged.emit(value)
+
+ value = qt.Property(qt.QVector4D,
+ fget=getValue,
+ fset=setValue,
+ user=True,
+ notify=valueChanged)
+ """Qt user property of the QVector4D value this widget edits"""
+
+
+class IntSliderEditor(qt.QSlider):
+ """Slider editor widget for integer.
+
+ Note: Tracking is disabled.
+
+ :param parent: The widget's parent
+ """
+
+ def __init__(self, parent=None):
+ super(IntSliderEditor, self).__init__(parent)
+ self.setOrientation(qt.Qt.Horizontal)
+ self.setSingleStep(1)
+ self.setRange(0, 255)
+ self.setValue(0)
+
+
+class BooleanEditor(qt.QCheckBox):
+ """Checkbox editor for bool.
+
+ This is a QCheckBox with white background.
+
+ :param parent: The widget's parent
+ """
+
+ def __init__(self, parent=None):
+ super(BooleanEditor, self).__init__(parent)
+ self.setStyleSheet("background: white;")
+
+
+class ParameterTreeDelegate(qt.QStyledItemDelegate):
+ """TreeView delegate specific to plot3d scene and object parameter tree.
+
+ It provides additional editors.
+
+ :param parent: Delegate's parent
+ """
+
+ EDITORS = {
+ bool: BooleanEditor,
+ float: FloatEditor,
+ qt.QVector3D: Vector3DEditor,
+ qt.QVector4D: Vector4DEditor,
+ }
+ """Specific editors for different type of data"""
+
+ def __init__(self, parent=None):
+ super(ParameterTreeDelegate, self).__init__(parent)
+
+ def paint(self, painter, option, index):
+ """See :meth:`QStyledItemDelegate.paint`"""
+ data = index.data(qt.Qt.DisplayRole)
+
+ if isinstance(data, (qt.QVector3D, qt.QVector4D)):
+ if isinstance(data, qt.QVector3D):
+ text = '(x: %g; y: %g; z: %g)' % (data.x(), data.y(), data.z())
+ elif isinstance(data, qt.QVector4D):
+ text = '(%g; %g; %g; %g)' % (data.x(), data.y(), data.z(), data.w())
+ else:
+ text = ''
+
+ painter.save()
+ painter.setRenderHint(qt.QPainter.Antialiasing, True)
+
+ # Select palette color group
+ colorGroup = qt.QPalette.Inactive
+ if option.state & qt.QStyle.State_Active:
+ colorGroup = qt.QPalette.Active
+ if not option.state & qt.QStyle.State_Enabled:
+ colorGroup = qt.QPalette.Disabled
+
+ # Draw background if selected
+ if option.state & qt.QStyle.State_Selected:
+ brush = option.palette.brush(colorGroup,
+ qt.QPalette.Highlight)
+ painter.fillRect(option.rect, brush)
+
+ # Draw text
+ if option.state & qt.QStyle.State_Selected:
+ colorRole = qt.QPalette.HighlightedText
+ else:
+ colorRole = qt.QPalette.WindowText
+ color = option.palette.color(colorGroup, colorRole)
+ painter.setPen(qt.QPen(color))
+ painter.drawText(option.rect, qt.Qt.AlignLeft, text)
+
+ painter.restore()
+
+ # The following commented code does the same as QPainter based code
+ # but it does not work with PySide
+ # self.initStyleOption(option, index)
+ # option.text = text
+ # widget = option.widget
+ # style = qt.QApplication.style() if not widget else widget.style()
+ # style.drawControl(qt.QStyle.CE_ItemViewItem, option, painter, widget)
+
+ else:
+ super(ParameterTreeDelegate, self).paint(painter, option, index)
+
+ def _commit(self, *args):
+ """Commit data to the model from editors"""
+ sender = self.sender()
+ self.commitData.emit(sender)
+
+ def editorEvent(self, event, model, option, index):
+ """See :meth:`QStyledItemDelegate.editorEvent`"""
+ if (event.type() == qt.QEvent.MouseButtonPress and
+ isinstance(index.data(qt.Qt.EditRole), qt.QColor)):
+ initialColor = index.data(qt.Qt.EditRole)
+
+ def callback(color):
+ theModel = index.model()
+ theModel.setData(index, color, qt.Qt.EditRole)
+
+ dialog = qt.QColorDialog(self.parent())
+ # dialog.setOption(qt.QColorDialog.ShowAlphaChannel, True)
+ if sys.platform == 'darwin':
+ # Use of native color dialog on macos might cause problems
+ dialog.setOption(qt.QColorDialog.DontUseNativeDialog, True)
+ dialog.setCurrentColor(initialColor)
+ dialog.currentColorChanged.connect(callback)
+ if dialog.exec() == qt.QDialog.Rejected:
+ # Reset color
+ dialog.setCurrentColor(initialColor)
+
+ return True
+ else:
+ return super(ParameterTreeDelegate, self).editorEvent(
+ event, model, option, index)
+
+ def createEditor(self, parent, option, index):
+ """See :meth:`QStyledItemDelegate.createEditor`"""
+ data = index.data(qt.Qt.EditRole)
+ editorHint = index.data(qt.Qt.UserRole)
+
+ if callable(editorHint):
+ editor = editorHint()
+ assert isinstance(editor, qt.QWidget)
+ editor.setParent(parent)
+
+ elif isinstance(data, numbers.Number) and editorHint is not None:
+ # Use a slider
+ editor = IntSliderEditor(parent)
+ range_ = editorHint
+ editor.setRange(*range_)
+ editor.sliderReleased.connect(self._commit)
+
+ elif isinstance(data, str) and editorHint is not None:
+ # Use a combo box
+ editor = qt.QComboBox(parent)
+ if data not in editorHint:
+ editor.addItem(data)
+ editor.addItems(editorHint)
+
+ index = editor.findText(data)
+ editor.setCurrentIndex(index)
+
+ editor.currentIndexChanged.connect(self._commit)
+
+ else:
+ # Handle overridden editors from Python
+ # Mimic Qt C++ implementation
+ for type_, editorClass in self.EDITORS.items():
+ if isinstance(data, type_):
+ editor = editorClass(parent)
+ metaObject = editor.metaObject()
+ userProperty = metaObject.userProperty()
+ if userProperty.isValid() and userProperty.hasNotifySignal():
+ notifySignal = userProperty.notifySignal()
+ signature = notifySignal.methodSignature()
+ if qt.BINDING == 'PySide2':
+ signature = signature.data()
+ else:
+ signature = bytes(signature)
+
+ if hasattr(signature, 'decode'): # For PySide with python3
+ signature = signature.decode('ascii')
+ signalName = signature.split('(')[0]
+
+ signal = getattr(editor, signalName)
+ signal.connect(self._commit)
+ break
+
+ else: # Default handling for default types
+ return super(ParameterTreeDelegate, self).createEditor(
+ parent, option, index)
+
+ editor.setAutoFillBackground(True)
+ return editor
+
+ def setModelData(self, editor, model, index):
+ """See :meth:`QStyledItemDelegate.setModelData`"""
+ if isinstance(editor, tuple(self.EDITORS.values())):
+ # Special handling of Python classes
+ # Translation of QStyledItemDelegate::setModelData to Python
+ # To make it work with Python QVariant wrapping/unwrapping
+ name = editor.metaObject().userProperty().name()
+ if not name:
+ pass # TODO handle the case of missing user property
+ if name:
+ if hasattr(editor, name):
+ value = getattr(editor, name)
+ else:
+ value = editor.property(name)
+ model.setData(index, value, qt.Qt.EditRole)
+
+ else:
+ super(ParameterTreeDelegate, self).setModelData(editor, model, index)
+
+
+class ParamTreeView(qt.QTreeView):
+ """QTreeView specific to handle plot3d scene and object parameters.
+
+ It provides additional editors and specific creation of persistent editors.
+
+ :param parent: The widget's parent.
+ """
+
+ def __init__(self, parent=None):
+ super(ParamTreeView, self).__init__(parent)
+
+ header = self.header()
+ header.setMinimumSectionSize(128) # For colormap pixmaps
+ header.setSectionResizeMode(qt.QHeaderView.ResizeToContents)
+
+ delegate = ParameterTreeDelegate()
+ self.setItemDelegate(delegate)
+
+ self.setSelectionBehavior(qt.QAbstractItemView.SelectRows)
+ self.setSelectionMode(qt.QAbstractItemView.SingleSelection)
+
+ self.expanded.connect(self._expanded)
+
+ self.setEditTriggers(qt.QAbstractItemView.CurrentChanged |
+ qt.QAbstractItemView.DoubleClicked)
+
+ self.__persistentEditors = set()
+
+ def _openEditorForIndex(self, index):
+ """Check if it has to open a persistent editor for a specific cell.
+
+ :param QModelIndex index: The cell index
+ """
+ if index.flags() & qt.Qt.ItemIsEditable:
+ data = index.data(qt.Qt.EditRole)
+ editorHint = index.data(qt.Qt.UserRole)
+ if (isinstance(data, bool) or
+ callable(editorHint) or
+ (isinstance(data, numbers.Number) and editorHint)):
+ self.openPersistentEditor(index)
+ self.__persistentEditors.add(index)
+
+ def _openEditors(self, parent=qt.QModelIndex()):
+ """Open persistent editors in a subtree starting at parent.
+
+ :param QModelIndex parent: The root of the subtree to process.
+ """
+ model = self.model()
+ if model is not None:
+ for index in visitQAbstractItemModel(model, parent):
+ self._openEditorForIndex(index)
+
+ def setModel(self, model):
+ """Set the model this TreeView is displaying
+
+ :param QAbstractItemModel model:
+ """
+ super(ParamTreeView, self).setModel(model)
+ self._openEditors()
+
+ def rowsInserted(self, parent, start, end):
+ """See :meth:`QTreeView.rowsInserted`"""
+ super(ParamTreeView, self).rowsInserted(parent, start, end)
+ model = self.model()
+ if model is not None:
+ for row in range(start, end+1):
+ self._openEditorForIndex(model.index(row, 1, parent))
+ self._openEditors(model.index(row, 0, parent))
+
+ def _expanded(self, index):
+ """Handle QTreeView expanded signal"""
+ name = index.data(qt.Qt.DisplayRole)
+ if name == 'Transform':
+ rotateIndex = self.model().index(1, 0, index)
+ self.setExpanded(rotateIndex, True)
+
+ def dataChanged(self, topLeft, bottomRight, roles=()):
+ """Handle model dataChanged signal eventually closing editors"""
+ if roles: # Qt 5
+ super(ParamTreeView, self).dataChanged(topLeft, bottomRight, roles)
+ else: # Qt4 compatibility
+ super(ParamTreeView, self).dataChanged(topLeft, bottomRight)
+ if not roles or qt.Qt.UserRole in roles: # Check editorHint update
+ for row in range(topLeft.row(), bottomRight.row() + 1):
+ for column in range(topLeft.column(), bottomRight.column() + 1):
+ index = topLeft.sibling(row, column)
+ if index.isValid():
+ if self._isPersistentEditorOpen(index):
+ self.closePersistentEditor(index)
+ self._openEditorForIndex(index)
+
+ def _isPersistentEditorOpen(self, index):
+ """Returns True if a persistent editor is opened for index
+
+ :param QModelIndex index:
+ :rtype: bool
+ """
+ return index in self.__persistentEditors
+
+ def selectionCommand(self, index, event=None):
+ """Filter out selection of not selectable items"""
+ if index.flags() & qt.Qt.ItemIsSelectable:
+ return super(ParamTreeView, self).selectionCommand(index, event)
+ else:
+ return qt.QItemSelectionModel.NoUpdate
diff --git a/src/silx/gui/plot3d/Plot3DWidget.py b/src/silx/gui/plot3d/Plot3DWidget.py
new file mode 100644
index 0000000..a90d34c
--- /dev/null
+++ b/src/silx/gui/plot3d/Plot3DWidget.py
@@ -0,0 +1,463 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a Qt widget embedding an OpenGL scene."""
+
+from __future__ import absolute_import
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "24/04/2018"
+
+
+import enum
+import logging
+
+from silx.gui import qt
+from silx.gui.colors import rgba
+from . import actions
+
+from ...utils.enum import Enum as _Enum
+from ..utils.image import convertArrayToQImage
+
+from .. import _glutils as glu
+from .scene import interaction, primitives, transform
+from . import scene
+
+import numpy
+
+
+_logger = logging.getLogger(__name__)
+
+
+class _OverviewViewport(scene.Viewport):
+ """A scene displaying the orientation of the data in another scene.
+
+ :param Camera camera: The camera to track.
+ """
+
+ _SIZE = 100
+ """Size in pixels of the overview square"""
+
+ def __init__(self, camera=None):
+ super(_OverviewViewport, self).__init__()
+ self.size = self._SIZE, self._SIZE
+ self.background = None # Disable clear
+
+ self.scene.transforms = [transform.Scale(2.5, 2.5, 2.5)]
+
+ # Add a point to draw the background (in a group with depth mask)
+ backgroundPoint = primitives.ColorPoints(
+ x=0., y=0., z=0.,
+ color=(1., 1., 1., 0.5),
+ size=self._SIZE)
+ backgroundPoint.marker = 'o'
+ noDepthGroup = primitives.GroupNoDepth(mask=True, notest=True)
+ noDepthGroup.children.append(backgroundPoint)
+ self.scene.children.append(noDepthGroup)
+
+ axes = primitives.Axes()
+ self.scene.children.append(axes)
+
+ if camera is not None:
+ camera.addListener(self._cameraChanged)
+
+ def _cameraChanged(self, source):
+ """Listen to camera in other scene for transformation updates.
+
+ Sync the overview camera to point in the same direction
+ but from a sphere centered on origin.
+ """
+ position = -12. * source.extrinsic.direction
+ self.camera.extrinsic.position = position
+
+ self.camera.extrinsic.setOrientation(
+ source.extrinsic.direction, source.extrinsic.up)
+
+
+class Plot3DWidget(glu.OpenGLWidget):
+ """OpenGL widget with a 3D viewport and an overview."""
+
+ sigInteractiveModeChanged = qt.Signal()
+ """Signal emitted when the interactive mode has changed
+ """
+
+ sigStyleChanged = qt.Signal(str)
+ """Signal emitted when the style of the scene has changed
+
+ It provides the updated property.
+ """
+
+ sigSceneClicked = qt.Signal(float, float)
+ """Signal emitted when the scene is clicked with the left mouse button.
+
+ It provides the (x, y) clicked mouse position in logical widget pixel coordinates.
+ """
+
+ @enum.unique
+ class FogMode(_Enum):
+ """Different mode to render the scene with fog"""
+
+ NONE = 'none'
+ """No fog effect"""
+
+ LINEAR = 'linear'
+ """Linear fog through the whole scene"""
+
+ def __init__(self, parent=None, f=qt.Qt.WindowFlags()):
+ self._firstRender = True
+
+ super(Plot3DWidget, self).__init__(
+ parent,
+ alphaBufferSize=8,
+ depthBufferSize=0,
+ stencilBufferSize=0,
+ version=(2, 1),
+ f=f)
+
+ self.setAutoFillBackground(False)
+ self.setMouseTracking(True)
+
+ self.setFocusPolicy(qt.Qt.StrongFocus)
+ self._copyAction = actions.io.CopyAction(parent=self, plot3d=self)
+ self.addAction(self._copyAction)
+
+ self._updating = False # True if an update is requested
+
+ # Main viewport
+ self.viewport = scene.Viewport()
+
+ self._sceneScale = transform.Scale(1., 1., 1.)
+ self.viewport.scene.transforms = [self._sceneScale,
+ transform.Translate(0., 0., 0.)]
+
+ # Overview area
+ self.overview = _OverviewViewport(self.viewport.camera)
+
+ self.setBackgroundColor((0.2, 0.2, 0.2, 1.))
+
+ # Window describing on screen area to render
+ self._window = scene.Window(mode='framebuffer')
+ self._window.viewports = [self.viewport, self.overview]
+ self._window.addListener(self._redraw)
+
+ self.eventHandler = None
+ self.setInteractiveMode('rotate')
+
+ def __clickHandler(self, *args):
+ """Handle interaction state machine click"""
+ x, y = args[0][:2]
+ # Convert from device pixel to logical pixel unit
+ devicePixelRatio = self.getDevicePixelRatio()
+ self.sigSceneClicked.emit(x / devicePixelRatio, y / devicePixelRatio)
+
+ def setInteractiveMode(self, mode):
+ """Set the interactive mode.
+
+ :param str mode: The interactive mode: 'rotate', 'pan' or None
+ """
+ if mode == self.getInteractiveMode():
+ return
+
+ if mode is None:
+ self.eventHandler = None
+
+ elif mode == 'rotate':
+ self.eventHandler = interaction.RotateCameraControl(
+ self.viewport,
+ orbitAroundCenter=False,
+ mode='position',
+ scaleTransform=self._sceneScale,
+ selectCB=self.__clickHandler)
+
+ elif mode == 'pan':
+ self.eventHandler = interaction.PanCameraControl(
+ self.viewport,
+ orbitAroundCenter=False,
+ mode='position',
+ scaleTransform=self._sceneScale,
+ selectCB=self.__clickHandler)
+
+ elif isinstance(mode, interaction.StateMachine):
+ self.eventHandler = mode
+
+ else:
+ raise ValueError('Unsupported interactive mode %s', str(mode))
+
+ if (self.eventHandler is not None and
+ qt.QApplication.keyboardModifiers() & qt.Qt.ControlModifier):
+ self.eventHandler.handleEvent('keyPress', qt.Qt.Key_Control)
+
+ self.sigInteractiveModeChanged.emit()
+
+ def getInteractiveMode(self):
+ """Returns the interactive mode in use.
+
+ :rtype: str
+ """
+ if self.eventHandler is None:
+ return None
+ if isinstance(self.eventHandler, interaction.RotateCameraControl):
+ return 'rotate'
+ elif isinstance(self.eventHandler, interaction.PanCameraControl):
+ return 'pan'
+ else:
+ return None
+
+ def setProjection(self, projection):
+ """Change the projection in use.
+
+ :param str projection: In 'perspective', 'orthographic'.
+ """
+ if projection == 'orthographic':
+ projection = transform.Orthographic(size=self.viewport.size)
+ elif projection == 'perspective':
+ projection = transform.Perspective(fovy=30.,
+ size=self.viewport.size)
+ else:
+ raise RuntimeError('Unsupported projection: %s' % projection)
+
+ self.viewport.camera.intrinsic = projection
+ self.viewport.resetCamera()
+
+ def getProjection(self):
+ """Return the current camera projection mode as a str.
+
+ See :meth:`setProjection`
+ """
+ projection = self.viewport.camera.intrinsic
+ if isinstance(projection, transform.Orthographic):
+ return 'orthographic'
+ elif isinstance(projection, transform.Perspective):
+ return 'perspective'
+ else:
+ raise RuntimeError('Unknown projection in use')
+
+ def setBackgroundColor(self, color):
+ """Set the background color of the OpenGL view.
+
+ :param color: RGB color of the isosurface: name, #RRGGBB or RGB values
+ :type color:
+ QColor, str or array-like of 3 or 4 float in [0., 1.] or uint8
+ """
+ color = rgba(color)
+ if color != self.viewport.background:
+ self.viewport.background = color
+ self.sigStyleChanged.emit('backgroundColor')
+
+ def getBackgroundColor(self):
+ """Returns the RGBA background color (QColor)."""
+ return qt.QColor.fromRgbF(*self.viewport.background)
+
+ def setFogMode(self, mode):
+ """Set the kind of fog to use for the whole scene.
+
+ :param Union[str,FogMode] mode: The mode to use
+ :raise ValueError: If mode is not supported
+ """
+ mode = self.FogMode.from_value(mode)
+ if mode != self.getFogMode():
+ self.viewport.fog.isOn = mode is self.FogMode.LINEAR
+ self.sigStyleChanged.emit('fogMode')
+
+ def getFogMode(self):
+ """Returns the kind of fog in use
+
+ :return: The kind of fog in use
+ :rtype: FogMode
+ """
+ if self.viewport.fog.isOn:
+ return self.FogMode.LINEAR
+ else:
+ return self.FogMode.NONE
+
+ def isOrientationIndicatorVisible(self):
+ """Returns True if the orientation indicator is displayed.
+
+ :rtype: bool
+ """
+ return self.overview in self._window.viewports
+
+ def setOrientationIndicatorVisible(self, visible):
+ """Set the orientation indicator visibility.
+
+ :param bool visible: True to show
+ """
+ visible = bool(visible)
+ if visible != self.isOrientationIndicatorVisible():
+ if visible:
+ self._window.viewports = [self.viewport, self.overview]
+ else:
+ self._window.viewports = [self.viewport]
+ self.sigStyleChanged.emit('orientationIndicatorVisible')
+
+ def centerScene(self):
+ """Position the center of the scene at the center of rotation."""
+ self.viewport.resetCamera()
+
+ def resetZoom(self, face='front'):
+ """Reset the camera position to a default.
+
+ :param str face: The direction the camera is looking at:
+ side, front, back, top, bottom, right, left.
+ Default: front.
+ """
+ self.viewport.camera.extrinsic.reset(face=face)
+ self.centerScene()
+
+ def _redraw(self, source=None):
+ """Viewport listener to require repaint"""
+ if not self._updating:
+ self._updating = True # Mark that an update is requested
+ self.update() # Queued repaint (i.e., asynchronous)
+
+ def sizeHint(self):
+ return qt.QSize(400, 300)
+
+ def initializeGL(self):
+ pass
+
+ def paintGL(self):
+ # In case paintGL is called by the system and not through _redraw,
+ # Mark as updating.
+ self._updating = True
+
+ # Update near and far planes only if viewport needs refresh
+ if self.viewport.dirty:
+ self.viewport.adjustCameraDepthExtent()
+
+ self._window.render(self.context(), self.getDevicePixelRatio())
+
+ if self._firstRender: # TODO remove this ugly hack
+ self._firstRender = False
+ self.centerScene()
+ self._updating = False
+
+ def resizeGL(self, width, height):
+ width *= self.getDevicePixelRatio()
+ height *= self.getDevicePixelRatio()
+ self._window.size = width, height
+ self.viewport.size = self._window.size
+ overviewWidth, overviewHeight = self.overview.size
+ self.overview.origin = width - overviewWidth, height - overviewHeight
+
+ def grabGL(self):
+ """Renders the OpenGL scene into a numpy array
+
+ :returns: OpenGL scene RGB rasterization
+ :rtype: QImage
+ """
+ if not self.isValid():
+ _logger.error('OpenGL 2.1 not available, cannot save OpenGL image')
+ height, width = self._window.shape
+ image = numpy.zeros((height, width, 3), dtype=numpy.uint8)
+
+ else:
+ self.makeCurrent()
+ image = self._window.grab(self.context())
+
+ return convertArrayToQImage(image)
+
+ def wheelEvent(self, event):
+ if qt.BINDING == "PySide6":
+ x, y = event.position().x(), event.position().y()
+ else:
+ x, y = event.x(), event.y()
+ xpixel = x * self.getDevicePixelRatio()
+ ypixel = y * self.getDevicePixelRatio()
+ angle = event.angleDelta().y() / 8.
+ event.accept()
+
+ if self.eventHandler is not None and angle != 0 and self.isValid():
+ self.makeCurrent()
+ self.eventHandler.handleEvent('wheel', xpixel, ypixel, angle)
+
+ def keyPressEvent(self, event):
+ keyCode = event.key()
+ # No need to accept QKeyEvent
+
+ converter = {
+ qt.Qt.Key_Left: 'left',
+ qt.Qt.Key_Right: 'right',
+ qt.Qt.Key_Up: 'up',
+ qt.Qt.Key_Down: 'down'
+ }
+ direction = converter.get(keyCode, None)
+ if direction is not None:
+ if event.modifiers() == qt.Qt.ControlModifier:
+ self.viewport.camera.rotate(direction)
+ elif event.modifiers() == qt.Qt.ShiftModifier:
+ self.viewport.moveCamera(direction)
+ else:
+ self.viewport.orbitCamera(direction)
+
+ else:
+ if (keyCode == qt.Qt.Key_Control and
+ self.eventHandler is not None and
+ self.isValid()):
+ self.eventHandler.handleEvent('keyPress', keyCode)
+
+ # Key not handled, call base class implementation
+ super(Plot3DWidget, self).keyPressEvent(event)
+
+ def keyReleaseEvent(self, event):
+ """Catch Ctrl key release"""
+ keyCode = event.key()
+ if (keyCode == qt.Qt.Key_Control and
+ self.eventHandler is not None and
+ self.isValid()):
+ self.eventHandler.handleEvent('keyRelease', keyCode)
+ super(Plot3DWidget, self).keyReleaseEvent(event)
+
+ # Mouse events #
+ _MOUSE_BTNS = {1: 'left', 2: 'right', 4: 'middle'}
+
+ def mousePressEvent(self, event):
+ xpixel = event.x() * self.getDevicePixelRatio()
+ ypixel = event.y() * self.getDevicePixelRatio()
+ btn = self._MOUSE_BTNS[event.button()]
+ event.accept()
+
+ if self.eventHandler is not None and self.isValid():
+ self.makeCurrent()
+ self.eventHandler.handleEvent('press', xpixel, ypixel, btn)
+
+ def mouseMoveEvent(self, event):
+ xpixel = event.x() * self.getDevicePixelRatio()
+ ypixel = event.y() * self.getDevicePixelRatio()
+ event.accept()
+
+ if self.eventHandler is not None and self.isValid():
+ self.makeCurrent()
+ self.eventHandler.handleEvent('move', xpixel, ypixel)
+
+ def mouseReleaseEvent(self, event):
+ xpixel = event.x() * self.getDevicePixelRatio()
+ ypixel = event.y() * self.getDevicePixelRatio()
+ btn = self._MOUSE_BTNS[event.button()]
+ event.accept()
+
+ if self.eventHandler is not None and self.isValid():
+ self.makeCurrent()
+ self.eventHandler.handleEvent('release', xpixel, ypixel, btn)
diff --git a/src/silx/gui/plot3d/Plot3DWindow.py b/src/silx/gui/plot3d/Plot3DWindow.py
new file mode 100644
index 0000000..470b966
--- /dev/null
+++ b/src/silx/gui/plot3d/Plot3DWindow.py
@@ -0,0 +1,88 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a QMainWindow with a 3D scene and associated toolbar.
+"""
+
+from __future__ import absolute_import
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "26/01/2017"
+
+
+from silx.utils.proxy import docstring
+from silx.gui import qt
+
+from .Plot3DWidget import Plot3DWidget
+from .tools import OutputToolBar, InteractiveModeToolBar, ViewpointToolBar
+
+
+class Plot3DWindow(qt.QMainWindow):
+ """OpenGL widget with a 3D viewport and an overview."""
+
+ def __init__(self, parent=None):
+ super(Plot3DWindow, self).__init__(parent)
+ if parent is not None:
+ # behave as a widget
+ self.setWindowFlags(qt.Qt.Widget)
+
+ self._plot3D = Plot3DWidget()
+ self.setCentralWidget(self._plot3D)
+
+ for klass in (InteractiveModeToolBar, ViewpointToolBar, OutputToolBar):
+ toolbar = klass(parent=self)
+ toolbar.setPlot3DWidget(self._plot3D)
+ self.addToolBar(toolbar)
+ self.addActions(toolbar.actions())
+
+ def getPlot3DWidget(self):
+ """Get the :class:`Plot3DWidget` of this window"""
+ return self._plot3D
+
+ # Proxy to Plot3DWidget
+
+ @docstring(Plot3DWidget)
+ def setProjection(self, projection):
+ return self._plot3D.setProjection(projection)
+
+ @docstring(Plot3DWidget)
+ def getProjection(self):
+ return self._plot3D.getProjection()
+
+ @docstring(Plot3DWidget)
+ def centerScene(self):
+ return self._plot3D.centerScene()
+
+ @docstring(Plot3DWidget)
+ def resetZoom(self):
+ return self._plot3D.resetZoom()
+
+ @docstring(Plot3DWidget)
+ def getBackgroundColor(self):
+ return self._plot3D.getBackgroundColor()
+
+ @docstring(Plot3DWidget)
+ def setBackgroundColor(self, color):
+ return self._plot3D.setBackgroundColor(color)
diff --git a/src/silx/gui/plot3d/SFViewParamTree.py b/src/silx/gui/plot3d/SFViewParamTree.py
new file mode 100644
index 0000000..b269a6a
--- /dev/null
+++ b/src/silx/gui/plot3d/SFViewParamTree.py
@@ -0,0 +1,1814 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+This module provides a tree widget to set/view parameters of a ScalarFieldView.
+"""
+
+from __future__ import absolute_import
+
+__authors__ = ["D. N."]
+__license__ = "MIT"
+__date__ = "24/04/2018"
+
+import logging
+import sys
+import weakref
+
+import numpy
+
+from silx.gui import qt
+from silx.gui.icons import getQIcon
+from silx.gui.colors import Colormap
+from silx.gui.widgets.FloatEdit import FloatEdit
+
+from .ScalarFieldView import Isosurface
+
+
+_logger = logging.getLogger(__name__)
+
+
+class ModelColumns(object):
+ NameColumn, ValueColumn, ColumnMax = range(3)
+ ColumnNames = ['Name', 'Value']
+
+
+class SubjectItem(qt.QStandardItem):
+ """
+ Base class for observers items.
+
+ Subclassing:
+ ------------
+ The following method can/should be reimplemented:
+ - _init
+ - _pullData
+ - _pushData
+ - _setModelData
+ - _subjectChanged
+ - getEditor
+ - getSignals
+ - leftClicked
+ - queryRemove
+ - setEditorData
+
+ Also the following attributes are available:
+ - editable
+ - persistent
+
+ :param subject: object that this item will be observing.
+ """
+
+ editable = False
+ """ boolean: set to True to make the item editable. """
+
+ persistent = False
+ """
+ boolean: set to True to make the editor persistent.
+ See : Qt.QAbstractItemView.openPersistentEditor
+ """
+
+ def __init__(self, subject, *args):
+
+ super(SubjectItem, self).__init__(*args)
+
+ self.setEditable(self.editable)
+
+ self.__subject = None
+ self.subject = subject
+
+ def setData(self, value, role=qt.Qt.UserRole, pushData=True):
+ """
+ Overloaded method from QStandardItem. The pushData keyword tells
+ the item to push data to the subject if the role is equal to EditRole.
+ This is useful to let this method know if the setData method was called
+ internally or from the view.
+
+ :param value: the value ti set to data
+ :param role: role in the item
+ :param pushData: if True push value in the existing data.
+ """
+ if role == qt.Qt.EditRole and pushData:
+ setValue = self._pushData(value, role)
+ if setValue != value:
+ value = setValue
+ super(SubjectItem, self).setData(value, role)
+
+ @property
+ def subject(self):
+ """The subject this item is observing"""
+ return None if self.__subject is None else self.__subject()
+
+ @subject.setter
+ def subject(self, subject):
+ if self.__subject is not None:
+ raise ValueError('Subject already set '
+ ' (subject change not supported).')
+ if subject is None:
+ self.__subject = None
+ else:
+ self.__subject = weakref.ref(subject)
+ if subject is not None:
+ self._init()
+ self._connectSignals()
+
+ def _connectSignals(self):
+ """
+ Connects the signals. Called when the subject is set.
+ """
+
+ def gen_slot(_sigIdx):
+ def slotfn(*args, **kwargs):
+ self._subjectChanged(signalIdx=_sigIdx,
+ args=args,
+ kwargs=kwargs)
+ return slotfn
+
+ if self.__subject is not None:
+ self.__slots = slots = []
+
+ signals = self.getSignals()
+
+ if signals:
+ if not isinstance(signals, (list, tuple)):
+ signals = [signals]
+ for sigIdx, signal in enumerate(signals):
+ slot = gen_slot(sigIdx)
+ signal.connect(slot)
+ slots.append((signal, slot))
+
+ def _disconnectSignals(self):
+ """
+ Disconnects all subject's signal
+ """
+ if self.__slots:
+ for signal, slot in self.__slots:
+ try:
+ signal.disconnect(slot)
+ except TypeError:
+ pass
+
+ def _enableRow(self, enable):
+ """
+ Set the enabled state for this cell, or for the whole row
+ if this item has a parent.
+
+ :param bool enable: True if we wan't to enable the cell
+ """
+ parent = self.parent()
+ model = self.model()
+ if model is None or parent is None:
+ # no parent -> no siblings
+ self.setEnabled(enable)
+ return
+
+ for col in range(model.columnCount()):
+ sibling = parent.child(self.row(), col)
+ sibling.setEnabled(enable)
+
+ #################################################################
+ # Overloadable methods
+ #################################################################
+
+ def getSignals(self):
+ """
+ Returns the list of this items subject's signals that
+ this item will be listening to.
+
+ :return: list.
+ """
+ return None
+
+ def _subjectChanged(self, signalIdx=None, args=None, kwargs=None):
+ """
+ Called when one of the signals is triggered. Default implementation
+ just calls _pullData, compares the result to the current value stored
+ as Qt.EditRole, and stores the new value if it is different. It also
+ stores its str representation as Qt.DisplayRole
+
+ :param signalIdx: index of the triggered signal. The value passed
+ is the same as the signal position in the list returned by
+ SubjectItem.getSignals.
+ :param args: arguments received from the signal
+ :param kwargs: keyword arguments received from the signal
+ """
+ data = self._pullData()
+ if data == self.data(qt.Qt.EditRole):
+ return
+ self.setData(data, role=qt.Qt.DisplayRole, pushData=False)
+ self.setData(data, role=qt.Qt.EditRole, pushData=False)
+
+ def _pullData(self):
+ """
+ Pulls data from the subject.
+
+ :return: subject data
+ """
+ return None
+
+ def _pushData(self, value, role=qt.Qt.UserRole):
+ """
+ Pushes data to the subject and returns the actual value that was stored
+
+ :return: the value that was stored
+ """
+ return value
+
+ def _init(self):
+ """
+ Called when the subject is set.
+ :return:
+ """
+ self._subjectChanged()
+
+ def getEditor(self, parent, option, index):
+ """
+ Returns the editor widget used to edit this item's data. The arguments
+ are the one passed to the QStyledItemDelegate.createEditor method.
+
+ :param parent: the Qt parent of the editor
+ :param option:
+ :param index:
+ :return:
+ """
+ return None
+
+ def setEditorData(self, editor):
+ """
+ This is called by the View's delegate just before the editor is shown,
+ its purpose it to setup the editors contents. Return False to use
+ the delegate's default behaviour.
+
+ :param editor:
+ :return:
+ """
+ return True
+
+ def _setModelData(self, editor):
+ """
+ This is called by the View's delegate just before the editor is closed,
+ its allows this item to update itself with data from the editor.
+
+ :param editor:
+ :return:
+ """
+ return False
+
+ def queryRemove(self, view=None):
+ """
+ This is called by the view to ask this items if it (the view) can
+ remove it. Return True to let the view know that the item can be
+ removed.
+
+ :param view:
+ :return:
+ """
+ return False
+
+ def leftClicked(self):
+ """
+ This method is called by the view when the item's cell if left clicked.
+
+ :return:
+ """
+ pass
+
+
+# View settings ###############################################################
+
+class ColorItem(SubjectItem):
+ """color item."""
+ editable = True
+ persistent = True
+
+ def getEditor(self, parent, option, index):
+ editor = QColorEditor(parent)
+ editor.color = self.getColor()
+
+ # Wrapping call in lambda is a workaround for PySide with Python 3
+ editor.sigColorChanged.connect(
+ lambda color: self._editorSlot(color))
+ return editor
+
+ def _editorSlot(self, color):
+ self.setData(color, qt.Qt.EditRole)
+
+ def _pushData(self, value, role=qt.Qt.UserRole):
+ self.setColor(value)
+ return self.getColor()
+
+ def _pullData(self):
+ self.getColor()
+
+ def setColor(self, color):
+ """Override to implement actual color setter"""
+ pass
+
+
+class BackgroundColorItem(ColorItem):
+ itemName = 'Background'
+
+ def setColor(self, color):
+ self.subject.setBackgroundColor(color)
+
+ def getColor(self):
+ return self.subject.getBackgroundColor()
+
+
+class ForegroundColorItem(ColorItem):
+ itemName = 'Foreground'
+
+ def setColor(self, color):
+ self.subject.setForegroundColor(color)
+
+ def getColor(self):
+ return self.subject.getForegroundColor()
+
+
+class HighlightColorItem(ColorItem):
+ itemName = 'Highlight'
+
+ def setColor(self, color):
+ self.subject.setHighlightColor(color)
+
+ def getColor(self):
+ return self.subject.getHighlightColor()
+
+
+class _LightDirectionAngleBaseItem(SubjectItem):
+ """Base class for directional light angle item."""
+ editable = True
+ persistent = True
+
+ def _init(self):
+ pass
+
+ def getSignals(self):
+ """Override to provide signals to listen"""
+ raise NotImplementedError("MUST be implemented in subclass")
+
+ def _pullData(self):
+ """Override in subclass to get current angle"""
+ raise NotImplementedError("MUST be implemented in subclass")
+
+ def _pushData(self, value, role=qt.Qt.UserRole):
+ """Override in subclass to set the angle"""
+ raise NotImplementedError("MUST be implemented in subclass")
+
+ def getEditor(self, parent, option, index):
+ editor = qt.QSlider(parent)
+ editor.setOrientation(qt.Qt.Horizontal)
+ editor.setMinimum(-90)
+ editor.setMaximum(90)
+ editor.setValue(int(self._pullData()))
+
+ # Wrapping call in lambda is a workaround for PySide with Python 3
+ editor.valueChanged.connect(
+ lambda value: self._pushData(value))
+
+ return editor
+
+ def setEditorData(self, editor):
+ editor.setValue(int(self._pullData()))
+ return True
+
+ def _setModelData(self, editor):
+ value = editor.value()
+ self._pushData(value)
+ return True
+
+
+class LightAzimuthAngleItem(_LightDirectionAngleBaseItem):
+ """Light direction azimuth angle item."""
+
+ def getSignals(self):
+ return self.subject.sigAzimuthAngleChanged
+
+ def _pullData(self):
+ return self.subject.getAzimuthAngle()
+
+ def _pushData(self, value, role=qt.Qt.UserRole):
+ self.subject.setAzimuthAngle(value)
+
+
+class LightAltitudeAngleItem(_LightDirectionAngleBaseItem):
+ """Light direction altitude angle item."""
+
+ def getSignals(self):
+ return self.subject.sigAltitudeAngleChanged
+
+ def _pullData(self):
+ return self.subject.getAltitudeAngle()
+
+ def _pushData(self, value, role=qt.Qt.UserRole):
+ self.subject.setAltitudeAngle(value)
+
+
+class _DirectionalLightProxy(qt.QObject):
+ """Proxy to handle directional light with angles rather than vector.
+ """
+
+ sigAzimuthAngleChanged = qt.Signal()
+ """Signal sent when the azimuth angle has changed."""
+
+ sigAltitudeAngleChanged = qt.Signal()
+ """Signal sent when altitude angle has changed."""
+
+ def __init__(self, light):
+ super(_DirectionalLightProxy, self).__init__()
+ self._light = light
+ light.addListener(self._directionUpdated)
+ self._azimuth = 0.
+ self._altitude = 0.
+
+ def getAzimuthAngle(self):
+ """Returns the signed angle in the horizontal plane.
+
+ Unit: degrees.
+ The 0 angle corresponds to the axis perpendicular to the screen.
+
+ :rtype: float
+ """
+ return self._azimuth
+
+ def getAltitudeAngle(self):
+ """Returns the signed vertical angle from the horizontal plane.
+
+ Unit: degrees.
+ Range: [-90, +90]
+
+ :rtype: float
+ """
+ return self._altitude
+
+ def setAzimuthAngle(self, angle):
+ """Set the horizontal angle.
+
+ :param float angle: Angle from -z axis in zx plane in degrees.
+ """
+ if angle != self._azimuth:
+ self._azimuth = angle
+ self._updateLight()
+ self.sigAzimuthAngleChanged.emit()
+
+ def setAltitudeAngle(self, angle):
+ """Set the horizontal angle.
+
+ :param float angle: Angle from -z axis in zy plane in degrees.
+ """
+ if angle != self._altitude:
+ self._altitude = angle
+ self._updateLight()
+ self.sigAltitudeAngleChanged.emit()
+
+ def _directionUpdated(self, *args, **kwargs):
+ """Handle light direction update in the scene"""
+ # Invert direction to manipulate the 'source' pointing to
+ # the center of the viewport
+ x, y, z = - self._light.direction
+
+ # Horizontal plane is plane xz
+ azimuth = numpy.degrees(numpy.arctan2(x, z))
+ altitude = numpy.degrees(numpy.pi/2. - numpy.arccos(y))
+
+ if (abs(azimuth - self.getAzimuthAngle()) > 0.01 and
+ abs(abs(altitude) - 90.) >= 0.001): # Do not update when at zenith
+ self.setAzimuthAngle(azimuth)
+
+ if abs(altitude - self.getAltitudeAngle()) > 0.01:
+ self.setAltitudeAngle(altitude)
+
+ def _updateLight(self):
+ """Update light direction in the scene"""
+ azimuth = numpy.radians(self._azimuth)
+ delta = numpy.pi/2. - numpy.radians(self._altitude)
+ z = - numpy.sin(delta) * numpy.cos(azimuth)
+ x = - numpy.sin(delta) * numpy.sin(azimuth)
+ y = - numpy.cos(delta)
+ self._light.direction = x, y, z
+
+
+class DirectionalLightGroup(SubjectItem):
+ """
+ Root Item for the directional light
+ """
+
+ def __init__(self,subject, *args):
+ self._light = _DirectionalLightProxy(
+ subject.getPlot3DWidget().viewport.light)
+
+ super(DirectionalLightGroup, self).__init__(subject, *args)
+
+ def _init(self):
+
+ nameItem = qt.QStandardItem('Azimuth')
+ nameItem.setEditable(False)
+ valueItem = LightAzimuthAngleItem(self._light)
+ self.appendRow([nameItem, valueItem])
+
+ nameItem = qt.QStandardItem('Altitude')
+ nameItem.setEditable(False)
+ valueItem = LightAltitudeAngleItem(self._light)
+ self.appendRow([nameItem, valueItem])
+
+
+class BoundingBoxItem(SubjectItem):
+ """Bounding box, axes labels and grid visibility item.
+
+ Item is checkable.
+ """
+ itemName = 'Bounding Box'
+
+ def _init(self):
+ visible = self.subject.isBoundingBoxVisible()
+ self.setCheckable(True)
+ self.setCheckState(qt.Qt.Checked if visible else qt.Qt.Unchecked)
+
+ def leftClicked(self):
+ checked = (self.checkState() == qt.Qt.Checked)
+ if checked != self.subject.isBoundingBoxVisible():
+ self.subject.setBoundingBoxVisible(checked)
+
+
+class OrientationIndicatorItem(SubjectItem):
+ """Orientation indicator visibility item.
+
+ Item is checkable.
+ """
+ itemName = 'Axes indicator'
+
+ def _init(self):
+ plot3d = self.subject.getPlot3DWidget()
+ visible = plot3d.isOrientationIndicatorVisible()
+ self.setCheckable(True)
+ self.setCheckState(qt.Qt.Checked if visible else qt.Qt.Unchecked)
+
+ def leftClicked(self):
+ plot3d = self.subject.getPlot3DWidget()
+ checked = (self.checkState() == qt.Qt.Checked)
+ if checked != plot3d.isOrientationIndicatorVisible():
+ plot3d.setOrientationIndicatorVisible(checked)
+
+
+class ViewSettingsItem(qt.QStandardItem):
+ """Viewport settings"""
+
+ def __init__(self, subject, *args):
+
+ super(ViewSettingsItem, self).__init__(*args)
+
+ self.setEditable(False)
+
+ classes = (BackgroundColorItem,
+ ForegroundColorItem,
+ HighlightColorItem,
+ BoundingBoxItem,
+ OrientationIndicatorItem)
+ for cls in classes:
+ titleItem = qt.QStandardItem(cls.itemName)
+ titleItem.setEditable(False)
+ self.appendRow([titleItem, cls(subject)])
+
+ nameItem = DirectionalLightGroup(subject, 'Light Direction')
+ valueItem = qt.QStandardItem()
+ self.appendRow([nameItem, valueItem])
+
+
+# Data information ############################################################
+
+class DataChangedItem(SubjectItem):
+ """
+ Base class for items listening to ScalarFieldView.sigDataChanged
+ """
+
+ def getSignals(self):
+ subject = self.subject
+ if subject:
+ return subject.sigDataChanged, subject.sigTransformChanged
+ return None
+
+ def _init(self):
+ self._subjectChanged()
+
+
+class DataTypeItem(DataChangedItem):
+ itemName = 'dtype'
+
+ def _pullData(self):
+ data = self.subject.getData(copy=False)
+ return ((data is not None) and str(data.dtype)) or 'N/A'
+
+
+class DataShapeItem(DataChangedItem):
+ itemName = 'size'
+
+ def _pullData(self):
+ data = self.subject.getData(copy=False)
+ if data is None:
+ return 'N/A'
+ else:
+ return str(list(reversed(data.shape)))
+
+
+class OffsetItem(DataChangedItem):
+ itemName = 'offset'
+
+ def _pullData(self):
+ offset = self.subject.getTranslation()
+ return ((offset is not None) and str(offset)) or 'N/A'
+
+
+class ScaleItem(DataChangedItem):
+ itemName = 'scale'
+
+ def _pullData(self):
+ scale = self.subject.getScale()
+ return ((scale is not None) and str(scale)) or 'N/A'
+
+
+class MatrixItem(DataChangedItem):
+
+ def __init__(self, subject, row, *args):
+ self.__row = row
+ super(MatrixItem, self).__init__(subject, *args)
+
+ def _pullData(self):
+ matrix = self.subject.getTransformMatrix()
+ return str(matrix[self.__row])
+
+
+class DataSetItem(qt.QStandardItem):
+
+ def __init__(self, subject, *args):
+
+ super(DataSetItem, self).__init__(*args)
+
+ self.setEditable(False)
+
+ klasses = [DataTypeItem, DataShapeItem, OffsetItem]
+ for klass in klasses:
+ titleItem = qt.QStandardItem(klass.itemName)
+ titleItem.setEditable(False)
+ self.appendRow([titleItem, klass(subject)])
+
+ matrixItem = qt.QStandardItem('matrix')
+ matrixItem.setEditable(False)
+ valueItem = qt.QStandardItem()
+ self.appendRow([matrixItem, valueItem])
+
+ for row in range(3):
+ titleItem = qt.QStandardItem()
+ titleItem.setEditable(False)
+ valueItem = MatrixItem(subject, row)
+ matrixItem.appendRow([titleItem, valueItem])
+
+ titleItem = qt.QStandardItem(ScaleItem.itemName)
+ titleItem.setEditable(False)
+ self.appendRow([titleItem, ScaleItem(subject)])
+
+
+# Isosurface ##################################################################
+
+class IsoSurfaceRootItem(SubjectItem):
+ """
+ Root (i.e : column index 0) Isosurface item.
+ """
+
+ def __init__(self, subject, normalization, *args):
+ self._isoLevelSliderNormalization = normalization
+ super(IsoSurfaceRootItem, self).__init__(subject, *args)
+
+ def getSignals(self):
+ subject = self.subject
+ return [subject.sigColorChanged,
+ subject.sigVisibilityChanged]
+
+ def _subjectChanged(self, signalIdx=None, args=None, kwargs=None):
+ if signalIdx == 0:
+ color = self.subject.getColor()
+ self.setData(color, qt.Qt.DecorationRole)
+ elif signalIdx == 1:
+ visible = args[0]
+ self.setCheckState((visible and qt.Qt.Checked) or qt.Qt.Unchecked)
+
+ def _init(self):
+ self.setCheckable(True)
+
+ isosurface = self.subject
+ color = isosurface.getColor()
+ visible = isosurface.isVisible()
+ self.setData(color, qt.Qt.DecorationRole)
+ self.setCheckState((visible and qt.Qt.Checked) or qt.Qt.Unchecked)
+
+ nameItem = qt.QStandardItem('Level')
+ sliderItem = IsoSurfaceLevelSlider(self.subject,
+ self._isoLevelSliderNormalization)
+ self.appendRow([nameItem, sliderItem])
+
+ nameItem = qt.QStandardItem('Color')
+ nameItem.setEditable(False)
+ valueItem = IsoSurfaceColorItem(self.subject)
+ self.appendRow([nameItem, valueItem])
+
+ nameItem = qt.QStandardItem('Opacity')
+ nameItem.setTextAlignment(qt.Qt.AlignLeft | qt.Qt.AlignTop)
+ nameItem.setEditable(False)
+ valueItem = IsoSurfaceAlphaItem(self.subject)
+ self.appendRow([nameItem, valueItem])
+
+ nameItem = qt.QStandardItem()
+ nameItem.setEditable(False)
+ valueItem = IsoSurfaceAlphaLegendItem(self.subject)
+ valueItem.setEditable(False)
+ self.appendRow([nameItem, valueItem])
+
+ def queryRemove(self, view=None):
+ buttons = qt.QMessageBox.Ok | qt.QMessageBox.Cancel
+ ans = qt.QMessageBox.question(view,
+ 'Remove isosurface',
+ 'Remove the selected iso-surface?',
+ buttons=buttons)
+ if ans == qt.QMessageBox.Ok:
+ sfview = self.subject.parent()
+ if sfview:
+ sfview.removeIsosurface(self.subject)
+ return False
+ return False
+
+ def leftClicked(self):
+ checked = (self.checkState() == qt.Qt.Checked)
+ visible = self.subject.isVisible()
+ if checked != visible:
+ self.subject.setVisible(checked)
+
+
+class IsoSurfaceLevelItem(SubjectItem):
+ """
+ Base class for the isosurface level items.
+ """
+ editable = True
+
+ def getSignals(self):
+ subject = self.subject
+ return [subject.sigLevelChanged,
+ subject.sigVisibilityChanged]
+
+ def getEditor(self, parent, option, index):
+ return FloatEdit(parent)
+
+ def setEditorData(self, editor):
+ editor.setValue(self._pullData())
+ return False
+
+ def _setModelData(self, editor):
+ self._pushData(editor.value())
+ return True
+
+ def _pullData(self):
+ return self.subject.getLevel()
+
+ def _pushData(self, value, role=qt.Qt.UserRole):
+ self.subject.setLevel(value)
+ return self.subject.getLevel()
+
+
+class _IsoLevelSlider(qt.QSlider):
+ """QSlider used for iso-surface level with linear scale"""
+
+ def __init__(self, parent, subject, normalization):
+ super(_IsoLevelSlider, self).__init__(parent=parent)
+ self.subject = subject
+
+ if normalization == 'arcsinh':
+ self.__norm = numpy.arcsinh
+ self.__invNorm = numpy.sinh
+ elif normalization == 'linear':
+ self.__norm = lambda x: x
+ self.__invNorm = lambda x: x
+ else:
+ raise ValueError(
+ "Unsupported normalization %s", normalization)
+
+ self.sliderReleased.connect(self.__sliderReleased)
+
+ self.subject.sigLevelChanged.connect(self.setLevel)
+ self.subject.parent().sigDataChanged.connect(self.__dataChanged)
+
+ def setLevel(self, level):
+ """Set slider from iso-surface level"""
+ dataRange = self.subject.parent().getDataRange()
+
+ if dataRange is not None:
+ min_ = self.__norm(dataRange[0])
+ max_ = self.__norm(dataRange[-1])
+
+ width = max_ - min_
+ if width > 0:
+ sliderWidth = self.maximum() - self.minimum()
+ sliderPosition = sliderWidth * (self.__norm(level) - min_) / width
+ self.setValue(int(sliderPosition))
+
+ def __dataChanged(self):
+ """Handles data update to refresh slider range if needed"""
+ self.setLevel(self.subject.getLevel())
+
+ def __sliderReleased(self):
+ value = self.value()
+ dataRange = self.subject.parent().getDataRange()
+ if dataRange is not None:
+ min_ = self.__norm(dataRange[0])
+ max_ = self.__norm(dataRange[-1])
+ width = max_ - min_
+ sliderWidth = self.maximum() - self.minimum()
+ level = min_ + width * value / sliderWidth
+ self.subject.setLevel(self.__invNorm(level))
+
+
+class IsoSurfaceLevelSlider(IsoSurfaceLevelItem):
+ """
+ Isosurface level item with a slider editor.
+ """
+ nTicks = 1000
+ persistent = True
+
+ def __init__(self, subject, normalization):
+ self.normalization = normalization
+ super(IsoSurfaceLevelSlider, self).__init__(subject)
+
+ def getEditor(self, parent, option, index):
+ editor = _IsoLevelSlider(parent, self.subject, self.normalization)
+ editor.setOrientation(qt.Qt.Horizontal)
+ editor.setMinimum(0)
+ editor.setMaximum(self.nTicks)
+
+ editor.setSingleStep(1)
+
+ editor.setLevel(self.subject.getLevel())
+ return editor
+
+ def setEditorData(self, editor):
+ return True
+
+ def _setModelData(self, editor):
+ return True
+
+
+class IsoSurfaceColorItem(SubjectItem):
+ """
+ Isosurface color item.
+ """
+ editable = True
+ persistent = True
+
+ def getSignals(self):
+ return self.subject.sigColorChanged
+
+ def getEditor(self, parent, option, index):
+ editor = QColorEditor(parent)
+ color = self.subject.getColor()
+ color.setAlpha(255)
+ editor.color = color
+ # Wrapping call in lambda is a workaround for PySide with Python 3
+ editor.sigColorChanged.connect(
+ lambda color: self.__editorChanged(color))
+ return editor
+
+ def __editorChanged(self, color):
+ color.setAlpha(self.subject.getColor().alpha())
+ self.subject.setColor(color)
+
+ def _pushData(self, value, role=qt.Qt.UserRole):
+ self.subject.setColor(value)
+ return self.subject.getColor()
+
+
+class QColorEditor(qt.QWidget):
+ """
+ QColor editor.
+ """
+ sigColorChanged = qt.Signal(object)
+
+ color = property(lambda self: qt.QColor(self.__color))
+
+ @color.setter
+ def color(self, color):
+ self._setColor(color)
+ self.__previousColor = color
+
+ def __init__(self, *args, **kwargs):
+ super(QColorEditor, self).__init__(*args, **kwargs)
+ layout = qt.QHBoxLayout(self)
+ layout.setContentsMargins(0, 0, 0, 0)
+ button = qt.QToolButton()
+ icon = qt.QIcon(qt.QPixmap(32, 32))
+ button.setIcon(icon)
+ layout.addWidget(button)
+ button.clicked.connect(self.__showColorDialog)
+ layout.addStretch(1)
+
+ self.__color = None
+ self.__previousColor = None
+
+ def sizeHint(self):
+ return qt.QSize(0, 0)
+
+ def _setColor(self, qColor):
+ button = self.findChild(qt.QToolButton)
+ pixmap = qt.QPixmap(32, 32)
+ pixmap.fill(qColor)
+ button.setIcon(qt.QIcon(pixmap))
+ self.__color = qColor
+
+ def __showColorDialog(self):
+ dialog = qt.QColorDialog(parent=self)
+ if sys.platform == 'darwin':
+ # Use of native color dialog on macos might cause problems
+ dialog.setOption(qt.QColorDialog.DontUseNativeDialog, True)
+
+ self.__previousColor = self.__color
+ dialog.setAttribute(qt.Qt.WA_DeleteOnClose)
+ dialog.setModal(True)
+ dialog.currentColorChanged.connect(self.__colorChanged)
+ dialog.finished.connect(self.__dialogClosed)
+ dialog.show()
+
+ def __colorChanged(self, color):
+ self.__color = color
+ self._setColor(color)
+ self.sigColorChanged.emit(color)
+
+ def __dialogClosed(self, result):
+ if result == qt.QDialog.Rejected:
+ self.__colorChanged(self.__previousColor)
+ self.__previousColor = None
+
+
+class IsoSurfaceAlphaItem(SubjectItem):
+ """
+ Isosurface alpha item.
+ """
+ editable = True
+ persistent = True
+
+ def _init(self):
+ pass
+
+ def getSignals(self):
+ return self.subject.sigColorChanged
+
+ def getEditor(self, parent, option, index):
+ editor = qt.QSlider(parent)
+ editor.setOrientation(qt.Qt.Horizontal)
+ editor.setMinimum(0)
+ editor.setMaximum(255)
+
+ color = self.subject.getColor()
+ editor.setValue(color.alpha())
+
+ # Wrapping call in lambda is a workaround for PySide with Python 3
+ editor.valueChanged.connect(
+ lambda value: self.__editorChanged(value))
+
+ return editor
+
+ def __editorChanged(self, value):
+ color = self.subject.getColor()
+ color.setAlpha(value)
+ self.subject.setColor(color)
+
+ def setEditorData(self, editor):
+ return True
+
+ def _setModelData(self, editor):
+ return True
+
+
+class IsoSurfaceAlphaLegendItem(SubjectItem):
+ """Legend to place under opacity slider"""
+
+ editable = False
+ persistent = True
+
+ def getEditor(self, parent, option, index):
+ layout = qt.QHBoxLayout()
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.setSpacing(0)
+ layout.addWidget(qt.QLabel('0'))
+ layout.addStretch(1)
+ layout.addWidget(qt.QLabel('1'))
+
+ editor = qt.QWidget(parent)
+ editor.setLayout(layout)
+ return editor
+
+
+class IsoSurfaceCount(SubjectItem):
+ """
+ Item displaying the number of isosurfaces.
+ """
+
+ def getSignals(self):
+ subject = self.subject
+ return [subject.sigIsosurfaceAdded, subject.sigIsosurfaceRemoved]
+
+ def _pullData(self):
+ return len(self.subject.getIsosurfaces())
+
+
+class IsoSurfaceAddRemoveWidget(qt.QWidget):
+
+ sigViewTask = qt.Signal(str)
+ """Signal for the tree view to perform some task"""
+
+ def __init__(self, parent, item):
+ super(IsoSurfaceAddRemoveWidget, self).__init__(parent)
+ self._item = item
+ layout = qt.QHBoxLayout(self)
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.setSpacing(0)
+
+ addBtn = qt.QToolButton(self)
+ addBtn.setText('+')
+ addBtn.setToolButtonStyle(qt.Qt.ToolButtonTextOnly)
+ layout.addWidget(addBtn)
+ addBtn.clicked.connect(self.__addClicked)
+
+ removeBtn = qt.QToolButton(self)
+ removeBtn.setText('-')
+ removeBtn.setToolButtonStyle(qt.Qt.ToolButtonTextOnly)
+ layout.addWidget(removeBtn)
+ removeBtn.clicked.connect(self.__removeClicked)
+
+ layout.addStretch(1)
+
+ def __addClicked(self):
+ sfview = self._item.subject
+ if not sfview:
+ return
+ dataRange = sfview.getDataRange()
+ if dataRange is None:
+ dataRange = [0, 1]
+
+ sfview.addIsosurface(
+ numpy.mean((dataRange[0], dataRange[-1])), '#0000FF')
+
+ def __removeClicked(self):
+ self.sigViewTask.emit('remove_iso')
+
+
+class IsoSurfaceAddRemoveItem(SubjectItem):
+ """
+ Item displaying a simple QToolButton allowing to add an isosurface.
+ """
+ persistent = True
+
+ def getEditor(self, parent, option, index):
+ return IsoSurfaceAddRemoveWidget(parent, self)
+
+
+class IsoSurfaceGroup(SubjectItem):
+ """
+ Root item for the list of isosurface items.
+ """
+
+ def __init__(self, subject, normalization, *args):
+ self._isoLevelSliderNormalization = normalization
+ super(IsoSurfaceGroup, self).__init__(subject, *args)
+
+ def getSignals(self):
+ subject = self.subject
+ return [subject.sigIsosurfaceAdded, subject.sigIsosurfaceRemoved]
+
+ def _subjectChanged(self, signalIdx=None, args=None, kwargs=None):
+ if signalIdx == 0:
+ if len(args) >= 1:
+ isosurface = args[0]
+ if not isinstance(isosurface, Isosurface):
+ raise ValueError('Expected an isosurface instance.')
+ self.__addIsosurface(isosurface)
+ else:
+ raise ValueError('Expected an isosurface instance.')
+ elif signalIdx == 1:
+ if len(args) >= 1:
+ isosurface = args[0]
+ if not isinstance(isosurface, Isosurface):
+ raise ValueError('Expected an isosurface instance.')
+ self.__removeIsosurface(isosurface)
+ else:
+ raise ValueError('Expected an isosurface instance.')
+
+ def __addIsosurface(self, isosurface):
+ valueItem = IsoSurfaceRootItem(
+ subject=isosurface,
+ normalization=self._isoLevelSliderNormalization)
+ nameItem = IsoSurfaceLevelItem(subject=isosurface)
+ self.insertRow(max(0, self.rowCount() - 1), [valueItem, nameItem])
+
+ def __removeIsosurface(self, isosurface):
+ for row in range(self.rowCount()):
+ child = self.child(row)
+ subject = getattr(child, 'subject', None)
+ if subject == isosurface:
+ self.takeRow(row)
+ break
+
+ def _init(self):
+ nameItem = IsoSurfaceAddRemoveItem(self.subject)
+ valueItem = qt.QStandardItem()
+ valueItem.setEditable(False)
+ self.appendRow([nameItem, valueItem])
+
+ subject = self.subject
+ isosurfaces = subject.getIsosurfaces()
+ for isosurface in isosurfaces:
+ self.__addIsosurface(isosurface)
+
+
+# Cutting Plane ###############################################################
+
+class ColormapBase(SubjectItem):
+ """
+ Mixin class for colormap items.
+ """
+
+ def getSignals(self):
+ return [self.subject.getCutPlanes()[0].sigColormapChanged]
+
+
+class PlaneMinRangeItem(ColormapBase):
+ """
+ colormap minVal item.
+ Editor is a QLineEdit with a QDoubleValidator
+ """
+ editable = True
+
+ def _pullData(self):
+ colormap = self.subject.getCutPlanes()[0].getColormap()
+ auto = colormap.isAutoscale()
+ if auto == self.isEnabled():
+ self._enableRow(not auto)
+ return colormap.getVMin()
+
+ def _pushData(self, value, role=qt.Qt.UserRole):
+ self._setVMin(value)
+
+ def _setVMin(self, value):
+ colormap = self.subject.getCutPlanes()[0].getColormap()
+ vMin = value
+ vMax = colormap.getVMax()
+
+ if vMax is not None and value > vMax:
+ vMin = vMax
+ vMax = value
+ colormap.setVRange(vMin, vMax)
+
+ def getEditor(self, parent, option, index):
+ return FloatEdit(parent)
+
+ def setEditorData(self, editor):
+ editor.setValue(self._pullData())
+ return True
+
+ def _setModelData(self, editor):
+ value = editor.value()
+ self._setVMin(value)
+ return True
+
+
+class PlaneMaxRangeItem(ColormapBase):
+ """
+ colormap maxVal item.
+ Editor is a QLineEdit with a QDoubleValidator
+ """
+ editable = True
+
+ def _pullData(self):
+ colormap = self.subject.getCutPlanes()[0].getColormap()
+ auto = colormap.isAutoscale()
+ if auto == self.isEnabled():
+ self._enableRow(not auto)
+ return self.subject.getCutPlanes()[0].getColormap().getVMax()
+
+ def _setVMax(self, value):
+ colormap = self.subject.getCutPlanes()[0].getColormap()
+ vMin = colormap.getVMin()
+ vMax = value
+ if vMin is not None and value < vMin:
+ vMax = vMin
+ vMin = value
+ colormap.setVRange(vMin, vMax)
+
+ def getEditor(self, parent, option, index):
+ return FloatEdit(parent)
+
+ def setEditorData(self, editor):
+ editor.setText(str(self._pullData()))
+ return True
+
+ def _setModelData(self, editor):
+ value = editor.value()
+ self._setVMax(value)
+ return True
+
+
+class PlaneOrientationItem(SubjectItem):
+ """
+ Plane orientation item.
+ Editor is a QComboBox.
+ """
+ editable = True
+
+ _PLANE_ACTIONS = (
+ ('3d-plane-normal-x', 'Plane 0',
+ 'Set plane perpendicular to red axis', (1., 0., 0.)),
+ ('3d-plane-normal-y', 'Plane 1',
+ 'Set plane perpendicular to green axis', (0., 1., 0.)),
+ ('3d-plane-normal-z', 'Plane 2',
+ 'Set plane perpendicular to blue axis', (0., 0., 1.)),
+ )
+
+ def getSignals(self):
+ return [self.subject.getCutPlanes()[0].sigPlaneChanged]
+
+ def _pullData(self):
+ currentNormal = self.subject.getCutPlanes()[0].getNormal(
+ coordinates='scene')
+ for _, text, _, normal in self._PLANE_ACTIONS:
+ if numpy.allclose(normal, currentNormal):
+ return text
+ return ''
+
+ def getEditor(self, parent, option, index):
+ editor = qt.QComboBox(parent)
+ for iconName, text, tooltip, normal in self._PLANE_ACTIONS:
+ editor.addItem(getQIcon(iconName), text)
+
+ # Wrapping call in lambda is a workaround for PySide with Python 3
+ editor.currentIndexChanged[int].connect(
+ lambda index: self.__editorChanged(index))
+ return editor
+
+ def __editorChanged(self, index):
+ normal = self._PLANE_ACTIONS[index][3]
+ plane = self.subject.getCutPlanes()[0]
+ plane.setNormal(normal, coordinates='scene')
+ plane.moveToCenter()
+
+ def setEditorData(self, editor):
+ currentText = self._pullData()
+ index = 0
+ for normIdx, (_, text, _, _) in enumerate(self._PLANE_ACTIONS):
+ if text == currentText:
+ index = normIdx
+ break
+ editor.setCurrentIndex(index)
+ return True
+
+ def _setModelData(self, editor):
+ return True
+
+
+class PlaneInterpolationItem(SubjectItem):
+ """Toggle cut plane interpolation method: nearest or linear.
+
+ Item is checkable
+ """
+
+ def _init(self):
+ interpolation = self.subject.getCutPlanes()[0].getInterpolation()
+ self.setCheckable(True)
+ self.setCheckState(
+ qt.Qt.Checked if interpolation == 'linear' else qt.Qt.Unchecked)
+ self.setData(self._pullData(), role=qt.Qt.DisplayRole, pushData=False)
+
+ def getSignals(self):
+ return [self.subject.getCutPlanes()[0].sigInterpolationChanged]
+
+ def leftClicked(self):
+ checked = self.checkState() == qt.Qt.Checked
+ self._setInterpolation('linear' if checked else 'nearest')
+
+ def _pullData(self):
+ interpolation = self.subject.getCutPlanes()[0].getInterpolation()
+ self._setInterpolation(interpolation)
+ return interpolation[0].upper() + interpolation[1:]
+
+ def _setInterpolation(self, interpolation):
+ self.subject.getCutPlanes()[0].setInterpolation(interpolation)
+
+
+class PlaneDisplayBelowMinItem(SubjectItem):
+ """Toggle whether to display or not values <= colormap min of the cut plane
+
+ Item is checkable
+ """
+
+ def _init(self):
+ display = self.subject.getCutPlanes()[0].getDisplayValuesBelowMin()
+ self.setCheckable(True)
+ self.setCheckState(
+ qt.Qt.Checked if display else qt.Qt.Unchecked)
+ self.setData(self._pullData(), role=qt.Qt.DisplayRole, pushData=False)
+
+ def getSignals(self):
+ return [self.subject.getCutPlanes()[0].sigTransparencyChanged]
+
+ def leftClicked(self):
+ checked = self.checkState() == qt.Qt.Checked
+ self._setDisplayValuesBelowMin(checked)
+
+ def _pullData(self):
+ display = self.subject.getCutPlanes()[0].getDisplayValuesBelowMin()
+ self._setDisplayValuesBelowMin(display)
+ return "Displayed" if display else "Hidden"
+
+ def _setDisplayValuesBelowMin(self, display):
+ self.subject.getCutPlanes()[0].setDisplayValuesBelowMin(display)
+
+
+class PlaneColormapItem(ColormapBase):
+ """
+ colormap name item.
+ Editor is a QComboBox
+ """
+ editable = True
+
+ listValues = ['gray', 'reversed gray',
+ 'temperature', 'red',
+ 'green', 'blue',
+ 'viridis', 'magma', 'inferno', 'plasma']
+
+ def getEditor(self, parent, option, index):
+ editor = qt.QComboBox(parent)
+ editor.addItems(self.listValues)
+
+ # Wrapping call in lambda is a workaround for PySide with Python 3
+ editor.currentIndexChanged[int].connect(
+ lambda index: self.__editorChanged(index))
+
+ return editor
+
+ def __editorChanged(self, index):
+ colormapName = self.listValues[index]
+ colormap = self.subject.getCutPlanes()[0].getColormap()
+ colormap.setName(colormapName)
+
+ def setEditorData(self, editor):
+ colormapName = self.subject.getCutPlanes()[0].getColormap().getName()
+ try:
+ index = self.listValues.index(colormapName)
+ except ValueError:
+ _logger.error('Unsupported colormap: %s', colormapName)
+ else:
+ editor.setCurrentIndex(index)
+ return True
+
+ def _setModelData(self, editor):
+ self.__editorChanged(editor.currentIndex())
+ return True
+
+ def _pullData(self):
+ return self.subject.getCutPlanes()[0].getColormap().getName()
+
+
+class PlaneAutoScaleItem(ColormapBase):
+ """
+ colormap autoscale item.
+ Item is checkable.
+ """
+
+ def _init(self):
+ colorMap = self.subject.getCutPlanes()[0].getColormap()
+ self.setCheckable(True)
+ self.setCheckState((colorMap.isAutoscale() and qt.Qt.Checked)
+ or qt.Qt.Unchecked)
+ self.setData(self._pullData(), role=qt.Qt.DisplayRole, pushData=False)
+
+ def leftClicked(self):
+ checked = (self.checkState() == qt.Qt.Checked)
+ self._setAutoScale(checked)
+
+ def _setAutoScale(self, auto):
+ view3d = self.subject
+ colormap = view3d.getCutPlanes()[0].getColormap()
+
+ if auto != colormap.isAutoscale():
+ if auto:
+ vMin = vMax = None
+ else:
+ dataRange = view3d.getDataRange()
+ if dataRange is None:
+ vMin = vMax = None
+ else:
+ vMin, vMax = dataRange[0], dataRange[-1]
+ colormap.setVRange(vMin, vMax)
+
+ def _pullData(self):
+ auto = self.subject.getCutPlanes()[0].getColormap().isAutoscale()
+ self._setAutoScale(auto)
+ if auto:
+ data = 'Auto'
+ else:
+ data = 'User'
+ return data
+
+
+class NormalizationNode(ColormapBase):
+ """
+ colormap normalization item.
+ Item is a QComboBox.
+ """
+ editable = True
+ listValues = list(Colormap.NORMALIZATIONS)
+
+ def getEditor(self, parent, option, index):
+ editor = qt.QComboBox(parent)
+ editor.addItems(self.listValues)
+
+ # Wrapping call in lambda is a workaround for PySide with Python 3
+ editor.currentIndexChanged[int].connect(
+ lambda index: self.__editorChanged(index))
+
+ return editor
+
+ def __editorChanged(self, index):
+ colorMap = self.subject.getCutPlanes()[0].getColormap()
+ normalization = self.listValues[index]
+ self.subject.getCutPlanes()[0].setColormap(name=colorMap.getName(),
+ norm=normalization,
+ vmin=colorMap.getVMin(),
+ vmax=colorMap.getVMax())
+
+ def setEditorData(self, editor):
+ normalization = self.subject.getCutPlanes()[0].getColormap().getNormalization()
+ index = self.listValues.index(normalization)
+ editor.setCurrentIndex(index)
+ return True
+
+ def _setModelData(self, editor):
+ self.__editorChanged(editor.currentIndex())
+ return True
+
+ def _pullData(self):
+ return self.subject.getCutPlanes()[0].getColormap().getNormalization()
+
+
+class PlaneGroup(SubjectItem):
+ """
+ Root Item for the plane items.
+ """
+ def _init(self):
+ valueItem = qt.QStandardItem()
+ valueItem.setEditable(False)
+ nameItem = PlaneVisibleItem(self.subject, 'Visible')
+ self.appendRow([nameItem, valueItem])
+
+ nameItem = qt.QStandardItem('Colormap')
+ nameItem.setEditable(False)
+ valueItem = PlaneColormapItem(self.subject)
+ self.appendRow([nameItem, valueItem])
+
+ nameItem = qt.QStandardItem('Normalization')
+ nameItem.setEditable(False)
+ valueItem = NormalizationNode(self.subject)
+ self.appendRow([nameItem, valueItem])
+
+ nameItem = qt.QStandardItem('Orientation')
+ nameItem.setEditable(False)
+ valueItem = PlaneOrientationItem(self.subject)
+ self.appendRow([nameItem, valueItem])
+
+ nameItem = qt.QStandardItem('Interpolation')
+ nameItem.setEditable(False)
+ valueItem = PlaneInterpolationItem(self.subject)
+ self.appendRow([nameItem, valueItem])
+
+ nameItem = qt.QStandardItem('Autoscale')
+ nameItem.setEditable(False)
+ valueItem = PlaneAutoScaleItem(self.subject)
+ self.appendRow([nameItem, valueItem])
+
+ nameItem = qt.QStandardItem('Min')
+ nameItem.setEditable(False)
+ valueItem = PlaneMinRangeItem(self.subject)
+ self.appendRow([nameItem, valueItem])
+
+ nameItem = qt.QStandardItem('Max')
+ nameItem.setEditable(False)
+ valueItem = PlaneMaxRangeItem(self.subject)
+ self.appendRow([nameItem, valueItem])
+
+ nameItem = qt.QStandardItem('Values<=Min')
+ nameItem.setEditable(False)
+ valueItem = PlaneDisplayBelowMinItem(self.subject)
+ self.appendRow([nameItem, valueItem])
+
+
+class PlaneVisibleItem(SubjectItem):
+ """
+ Plane visibility item.
+ Item is checkable.
+ """
+ def _init(self):
+ plane = self.subject.getCutPlanes()[0]
+ self.setCheckable(True)
+ self.setCheckState((plane.isVisible() and qt.Qt.Checked)
+ or qt.Qt.Unchecked)
+
+ def leftClicked(self):
+ plane = self.subject.getCutPlanes()[0]
+ checked = (self.checkState() == qt.Qt.Checked)
+ if checked != plane.isVisible():
+ plane.setVisible(checked)
+ if plane.isVisible():
+ plane.moveToCenter()
+
+
+# Tree ########################################################################
+
+class ItemDelegate(qt.QStyledItemDelegate):
+ """
+ Delegate for the QTreeView filled with SubjectItems.
+ """
+
+ sigDelegateEvent = qt.Signal(str)
+
+ def __init__(self, parent=None):
+ super(ItemDelegate, self).__init__(parent)
+
+ def createEditor(self, parent, option, index):
+ item = index.model().itemFromIndex(index)
+ if item:
+ if isinstance(item, SubjectItem):
+ editor = item.getEditor(parent, option, index)
+ if editor:
+ editor.setAutoFillBackground(True)
+ if hasattr(editor, 'sigViewTask'):
+ editor.sigViewTask.connect(self.__viewTask)
+ return editor
+
+ editor = super(ItemDelegate, self).createEditor(parent,
+ option,
+ index)
+ return editor
+
+ def updateEditorGeometry(self, editor, option, index):
+ editor.setGeometry(option.rect)
+
+ def setEditorData(self, editor, index):
+ item = index.model().itemFromIndex(index)
+ if item:
+ if isinstance(item, SubjectItem) and item.setEditorData(editor):
+ return
+ super(ItemDelegate, self).setEditorData(editor, index)
+
+ def setModelData(self, editor, model, index):
+ item = index.model().itemFromIndex(index)
+ if isinstance(item, SubjectItem) and item._setModelData(editor):
+ return
+ super(ItemDelegate, self).setModelData(editor, model, index)
+
+ def __viewTask(self, task):
+ self.sigDelegateEvent.emit(task)
+
+
+class TreeView(qt.QTreeView):
+ """
+ TreeView displaying the SubjectItems for the ScalarFieldView.
+ """
+
+ def __init__(self, parent=None):
+ super(TreeView, self).__init__(parent)
+ self.__openedIndex = None
+ self._isoLevelSliderNormalization = 'linear'
+
+ self.setIconSize(qt.QSize(16, 16))
+
+ header = self.header()
+ header.setSectionResizeMode(qt.QHeaderView.ResizeToContents)
+
+ delegate = ItemDelegate()
+ self.setItemDelegate(delegate)
+ delegate.sigDelegateEvent.connect(self.__delegateEvent)
+ self.setSelectionBehavior(qt.QAbstractItemView.SelectRows)
+ self.setSelectionMode(qt.QAbstractItemView.SingleSelection)
+
+ self.clicked.connect(self.__clicked)
+
+ def setSfView(self, sfView):
+ """
+ Sets the ScalarFieldView this view is controlling.
+
+ :param sfView: A `ScalarFieldView`
+ """
+ model = qt.QStandardItemModel()
+ model.setColumnCount(ModelColumns.ColumnMax)
+ model.setHorizontalHeaderLabels(['Name', 'Value'])
+
+ item = qt.QStandardItem()
+ item.setEditable(False)
+ model.appendRow([ViewSettingsItem(sfView, 'Style'), item])
+
+ item = qt.QStandardItem()
+ item.setEditable(False)
+ model.appendRow([DataSetItem(sfView, 'Data'), item])
+
+ item = IsoSurfaceCount(sfView)
+ item.setEditable(False)
+ model.appendRow([IsoSurfaceGroup(sfView,
+ self._isoLevelSliderNormalization,
+ 'Isosurfaces'),
+ item])
+
+ item = qt.QStandardItem()
+ item.setEditable(False)
+ model.appendRow([PlaneGroup(sfView, 'Cutting Plane'), item])
+
+ self.setModel(model)
+
+ def setModel(self, model):
+ """
+ Reimplementation of the QTreeView.setModel method. It connects the
+ rowsRemoved signal and opens the persistent editors.
+
+ :param qt.QStandardItemModel model: the model
+ """
+
+ prevModel = self.model()
+ if prevModel:
+ self.__openPersistentEditors(qt.QModelIndex(), False)
+ try:
+ prevModel.rowsRemoved.disconnect(self.rowsRemoved)
+ except TypeError:
+ pass
+
+ super(TreeView, self).setModel(model)
+ model.rowsRemoved.connect(self.rowsRemoved)
+ self.__openPersistentEditors(qt.QModelIndex())
+
+ def __openPersistentEditors(self, parent=None, openEditor=True):
+ """
+ Opens or closes the items persistent editors.
+
+ :param qt.QModelIndex parent: starting index, or None if the whole tree
+ is to be considered.
+ :param bool openEditor: True to open the editors, False to close them.
+ """
+ model = self.model()
+
+ if not model:
+ return
+
+ if not parent or not parent.isValid():
+ parent = self.model().invisibleRootItem().index()
+
+ if openEditor:
+ meth = self.openPersistentEditor
+ else:
+ meth = self.closePersistentEditor
+
+ curParent = parent
+ children = [model.index(row, 0, curParent)
+ for row in range(model.rowCount(curParent))]
+
+ columnCount = model.columnCount()
+
+ while len(children) > 0:
+ curParent = children.pop(-1)
+
+ children.extend([model.index(row, 0, curParent)
+ for row in range(model.rowCount(curParent))])
+
+ for colIdx in range(columnCount):
+ sibling = model.sibling(curParent.row(),
+ colIdx,
+ curParent)
+ item = model.itemFromIndex(sibling)
+ if isinstance(item, SubjectItem) and item.persistent:
+ meth(sibling)
+
+ def rowsAboutToBeRemoved(self, parent, start, end):
+ """
+ Reimplementation of the QTreeView.rowsAboutToBeRemoved. Closes all
+ persistent editors under parent.
+
+ :param qt.QModelIndex parent: Parent index
+ :param int start: Start index from parent index (inclusive)
+ :param int end: End index from parent index (inclusive)
+ """
+ self.__openPersistentEditors(parent, False)
+ super(TreeView, self).rowsAboutToBeRemoved(parent, start, end)
+
+ def rowsRemoved(self, parent, start, end):
+ """
+ Called when QTreeView.rowsRemoved is emitted. Opens all persistent
+ editors under parent.
+
+ :param qt.QModelIndex parent: Parent index
+ :param int start: Start index from parent index (inclusive)
+ :param int end: End index from parent index (inclusive)
+ """
+ super(TreeView, self).rowsRemoved(parent, start, end)
+ self.__openPersistentEditors(parent, True)
+
+ def rowsInserted(self, parent, start, end):
+ """
+ Reimplementation of the QTreeView.rowsInserted. Opens all persistent
+ editors under parent.
+
+ :param qt.QModelIndex parent: Parent index
+ :param int start: Start index from parent index
+ :param int end: End index from parent index
+ """
+ self.__openPersistentEditors(parent, False)
+ super(TreeView, self).rowsInserted(parent, start, end)
+ self.__openPersistentEditors(parent)
+
+ def keyReleaseEvent(self, event):
+ """
+ Reimplementation of the QTreeView.keyReleaseEvent.
+ At the moment only Key_Delete is handled. It calls the selected item's
+ queryRemove method, and deleted the item if needed.
+
+ :param qt.QKeyEvent event: A key event
+ """
+
+ # TODO : better filtering
+ key = event.key()
+ modifiers = event.modifiers()
+
+ if key == qt.Qt.Key_Delete and modifiers == qt.Qt.NoModifier:
+ self.__removeIsosurfaces()
+
+ super(TreeView, self).keyReleaseEvent(event)
+
+ def __removeIsosurfaces(self):
+ model = self.model()
+ selected = self.selectedIndexes()
+ items = []
+ # WARNING : the selection mode is set to single, so we re not
+ # supposed to have more than one item here.
+ # Multiple selection deletion has not been tested.
+ # Watch out for index invalidation
+ for index in selected:
+ leftIndex = model.sibling(index.row(), 0, index)
+ leftItem = model.itemFromIndex(leftIndex)
+ if isinstance(leftItem, SubjectItem) and leftItem not in items:
+ items.append(leftItem)
+
+ isos = [item for item in items if isinstance(item, IsoSurfaceRootItem)]
+ if isos:
+ for iso in isos:
+ if iso.queryRemove(self):
+ parentItem = iso.parent()
+ parentItem.removeRow(iso.row())
+ else:
+ qt.QMessageBox.information(
+ self,
+ 'Remove isosurface',
+ 'Select an iso-surface to remove it')
+
+ def __clicked(self, index):
+ """
+ Called when the QTreeView.clicked signal is emitted. Calls the item's
+ leftClick method.
+
+ :param qt.QIndex index: An index
+ """
+ item = self.model().itemFromIndex(index)
+ if isinstance(item, SubjectItem):
+ item.leftClicked()
+
+ def __delegateEvent(self, task):
+ if task == 'remove_iso':
+ self.__removeIsosurfaces()
+
+ def setIsoLevelSliderNormalization(self, normalization):
+ """Set the normalization for iso level slider
+
+ This MUST be called *before* :meth:`setSfView` to have an effect.
+
+ :param str normalization: Either 'linear' or 'arcsinh'
+ """
+ assert normalization in ('linear', 'arcsinh')
+ self._isoLevelSliderNormalization = normalization
diff --git a/src/silx/gui/plot3d/ScalarFieldView.py b/src/silx/gui/plot3d/ScalarFieldView.py
new file mode 100644
index 0000000..b2bb254
--- /dev/null
+++ b/src/silx/gui/plot3d/ScalarFieldView.py
@@ -0,0 +1,1552 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a window to view a 3D scalar field.
+
+It supports iso-surfaces, a cutting plane and the definition of
+a region of interest.
+"""
+
+from __future__ import absolute_import
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "14/06/2018"
+
+import re
+import logging
+import time
+from collections import deque
+
+import numpy
+
+from silx.gui import qt, icons
+from silx.gui.colors import rgba
+from silx.gui.colors import Colormap
+
+from silx.math.marchingcubes import MarchingCubes
+from silx.math.combo import min_max
+
+from .scene import axes, cutplane, interaction, primitives, transform
+from . import scene
+from .Plot3DWindow import Plot3DWindow
+from .tools import InteractiveModeToolBar
+
+_logger = logging.getLogger(__name__)
+
+
+class Isosurface(qt.QObject):
+ """Class representing an iso-surface
+
+ :param parent: The View widget this iso-surface belongs to
+ """
+
+ sigLevelChanged = qt.Signal(float)
+ """Signal emitted when the iso-surface level has changed.
+
+ This signal provides the new level value (might be nan).
+ """
+
+ sigColorChanged = qt.Signal()
+ """Signal emitted when the iso-surface color has changed"""
+
+ sigVisibilityChanged = qt.Signal(bool)
+ """Signal emitted when the iso-surface visibility has changed.
+
+ This signal provides the new visibility status.
+ """
+
+ def __init__(self, parent):
+ super(Isosurface, self).__init__(parent=parent)
+ self._level = float('nan')
+ self._autoLevelFunction = None
+ self._color = rgba('#FFD700FF')
+ self._data = None
+ self._group = scene.Group()
+
+ def _setData(self, data, copy=True):
+ """Set the data set from which to build the iso-surface.
+
+ :param numpy.ndarray data: The 3D dataset or None
+ :param bool copy: True to make a copy, False to use as is if possible
+ """
+ if data is None:
+ self._data = None
+ else:
+ self._data = numpy.array(data, copy=copy, order='C')
+
+ self._update()
+
+ def _get3DPrimitive(self):
+ """Return the group containing the mesh of the iso-surface if any"""
+ return self._group
+
+ def isVisible(self):
+ """Returns True if iso-surface is visible, else False"""
+ return self._group.visible
+
+ def setVisible(self, visible):
+ """Set the visibility of the iso-surface in the view.
+
+ :param bool visible: True to show the iso-surface, False to hide
+ """
+ visible = bool(visible)
+ if visible != self._group.visible:
+ self._group.visible = visible
+ self.sigVisibilityChanged.emit(visible)
+
+ def getLevel(self):
+ """Return the level of this iso-surface (float)"""
+ return self._level
+
+ def setLevel(self, level):
+ """Set the value at which to build the iso-surface.
+
+ Setting this value reset auto-level function
+
+ :param float level: The value at which to build the iso-surface
+ """
+ self._autoLevelFunction = None
+ level = float(level)
+ if level != self._level:
+ self._level = level
+ self._update()
+ self.sigLevelChanged.emit(level)
+
+ def isAutoLevel(self):
+ """True if iso-level is rebuild for each data set."""
+ return self.getAutoLevelFunction() is not None
+
+ def getAutoLevelFunction(self):
+ """Return the function computing the iso-level (callable or None)"""
+ return self._autoLevelFunction
+
+ def setAutoLevelFunction(self, autoLevel):
+ """Set the function used to compute the iso-level.
+
+ WARNING: The function might get called in a thread.
+
+ :param callable autoLevel:
+ A function taking a 3D numpy.ndarray of float32 and returning
+ a float used as iso-level.
+ Example: numpy.mean(data) + numpy.std(data)
+ """
+ assert callable(autoLevel)
+ self._autoLevelFunction = autoLevel
+ self._update()
+
+ def getColor(self):
+ """Return the color of this iso-surface (QColor)"""
+ return qt.QColor.fromRgbF(*self._color)
+
+ def setColor(self, color):
+ """Set the color of the iso-surface
+
+ :param color: RGBA color of the isosurface
+ :type color: QColor, str or array-like of 4 float in [0., 1.]
+ """
+ color = rgba(color)
+ if color != self._color:
+ self._color = color
+ if len(self._group.children) != 0:
+ self._group.children[0].setAttribute('color', self._color)
+ self.sigColorChanged.emit()
+
+ def _update(self):
+ """Update underlying mesh"""
+ self._group.children = []
+
+ if self._data is None:
+ if self.isAutoLevel():
+ self._level = float('nan')
+
+ else:
+ if self.isAutoLevel():
+ st = time.time()
+ try:
+ level = float(self.getAutoLevelFunction()(self._data))
+
+ except Exception:
+ module = self.getAutoLevelFunction().__module__
+ name = self.getAutoLevelFunction().__name__
+ _logger.error(
+ "Error while executing iso level function %s.%s",
+ module,
+ name,
+ exc_info=True)
+ level = float('nan')
+
+ else:
+ _logger.info(
+ 'Computed iso-level in %f s.', time.time() - st)
+
+ if level != self._level:
+ self._level = level
+ self.sigLevelChanged.emit(level)
+
+ if not numpy.isfinite(self._level):
+ return
+
+ st = time.time()
+ vertices, normals, indices = MarchingCubes(
+ self._data,
+ isolevel=self._level)
+ _logger.info('Computed iso-surface in %f s.', time.time() - st)
+
+ if len(vertices) == 0:
+ return
+ else:
+ mesh = primitives.Mesh3D(vertices,
+ colors=self._color,
+ normals=normals,
+ mode='triangles',
+ indices=indices)
+ self._group.children = [mesh]
+
+
+class SelectedRegion(object):
+ """Selection of a 3D region aligned with the axis.
+
+ :param arrayRange: Range of the selection in the array
+ ((zmin, zmax), (ymin, ymax), (xmin, xmax))
+ :param dataBBox: Bounding box of the selection in data coordinates
+ ((xmin, xmax), (ymin, ymax), (zmin, zmax))
+ :param translation: Offset from array to data coordinates (ox, oy, oz)
+ :param scale: Scale from array to data coordinates (sx, sy, sz)
+ """
+
+ def __init__(self, arrayRange, dataBBox,
+ translation=(0., 0., 0.),
+ scale=(1., 1., 1.)):
+ self._arrayRange = numpy.array(arrayRange, copy=True, dtype=numpy.int64)
+ assert self._arrayRange.shape == (3, 2)
+ assert numpy.all(self._arrayRange[:, 1] >= self._arrayRange[:, 0])
+
+ self._dataRange = dataBBox
+
+ self._translation = numpy.array(translation, dtype=numpy.float32)
+ assert self._translation.shape == (3,)
+ self._scale = numpy.array(scale, dtype=numpy.float32)
+ assert self._scale.shape == (3,)
+
+ def getArrayRange(self):
+ """Returns array ranges of the selection: 3x2 array of int
+
+ :return: A numpy array with ((zmin, zmax), (ymin, ymax), (xmin, xmax))
+ :rtype: numpy.ndarray
+ """
+ return self._arrayRange.copy()
+
+ def getArraySlices(self):
+ """Slices corresponding to the selected range in the array
+
+ :return: A numpy array with (zslice, yslice, zslice)
+ :rtype: numpy.ndarray
+ """
+ return (slice(*self._arrayRange[0]),
+ slice(*self._arrayRange[1]),
+ slice(*self._arrayRange[2]))
+
+ def getDataRange(self):
+ """Range in the data coordinates of the selection: 3x2 array of float
+
+ When the transform matrix is not the identity matrix
+ (e.g., rotation, skew) the returned range is the one of the selected
+ region bounding box in data coordinates.
+
+ :return: A numpy array with ((xmin, xmax), (ymin, ymax), (zmin, zmax))
+ :rtype: numpy.ndarray
+ """
+ return self._dataRange.copy()
+
+ def getDataScale(self):
+ """Scale from array to data coordinates: (sx, sy, sz)
+
+ :return: A numpy array with (sx, sy, sz)
+ :rtype: numpy.ndarray
+ """
+ return self._scale.copy()
+
+ def getDataTranslation(self):
+ """Offset from array to data coordinates: (ox, oy, oz)
+
+ :return: A numpy array with (ox, oy, oz)
+ :rtype: numpy.ndarray
+ """
+ return self._translation.copy()
+
+
+class CutPlane(qt.QObject):
+ """Class representing a cutting plane
+
+ :param ~silx.gui.plot3d.ScalarFieldView.ScalarFieldView sfView:
+ Widget in which the cut plane is applied.
+ """
+
+ sigVisibilityChanged = qt.Signal(bool)
+ """Signal emitted when the cut visibility has changed.
+
+ This signal provides the new visibility status.
+ """
+
+ sigDataChanged = qt.Signal()
+ """Signal emitted when the data this plane is cutting has changed."""
+
+ sigPlaneChanged = qt.Signal()
+ """Signal emitted when the cut plane has moved"""
+
+ sigColormapChanged = qt.Signal(Colormap)
+ """Signal emitted when the colormap has changed
+
+ This signal provides the new colormap.
+ """
+
+ sigTransparencyChanged = qt.Signal()
+ """Signal emitted when the transparency of the plane has changed.
+
+ This signal is emitted when calling :meth:`setDisplayValuesBelowMin`.
+ """
+
+ sigInterpolationChanged = qt.Signal(str)
+ """Signal emitted when the cut plane interpolation has changed
+
+ This signal provides the new interpolation mode.
+ """
+
+ def __init__(self, sfView):
+ super(CutPlane, self).__init__(parent=sfView)
+
+ self._dataRange = None
+ self._visible = False
+
+ self.__syncPlane = True
+
+ # Plane stroke on the outer bounding box
+ self._planeStroke = primitives.PlaneInGroup(normal=(0, 1, 0))
+ self._planeStroke.visible = self._visible
+ self._planeStroke.addListener(self._planeChanged)
+ self._planeStroke.plane.addListener(self._planePositionChanged)
+
+ # Plane with texture on the data bounding box
+ self._dataPlane = cutplane.CutPlane(normal=(0, 1, 0))
+ self._dataPlane.strokeVisible = False
+ self._dataPlane.alpha = 1.
+ self._dataPlane.visible = self._visible
+ self._dataPlane.plane.addListener(self._planePositionChanged)
+
+ self._colormap = Colormap(
+ name='gray', normalization='linear', vmin=None, vmax=None)
+ self.getColormap().sigChanged.connect(self._colormapChanged)
+ self._updateSceneColormap()
+
+ sfView.sigDataChanged.connect(self._sfViewDataChanged)
+ sfView.sigTransformChanged.connect(self._sfViewTransformChanged)
+
+ def _get3DPrimitives(self):
+ """Return the cut plane scene node."""
+ return self._planeStroke, self._dataPlane
+
+ def _keepPlaneInBBox(self):
+ """Makes sure the plane intersect its parent bounding box if any"""
+ bounds = self._planeStroke.parent.bounds(dataBounds=True)
+ if bounds is not None:
+ self._planeStroke.plane.point = numpy.clip(
+ self._planeStroke.plane.point,
+ a_min=bounds[0], a_max=bounds[1])
+
+ @staticmethod
+ def _syncPlanes(master, slave):
+ """Move slave PlaneInGroup so that it is coplanar with master.
+
+ :param PlaneInGroup master: Reference PlaneInGroup
+ :param PlaneInGroup slave: PlaneInGroup to align
+ """
+ masterToSlave = transform.StaticTransformList([
+ slave.objectToSceneTransform.inverse(),
+ master.objectToSceneTransform])
+
+ point = masterToSlave.transformPoint(
+ master.plane.point)
+ normal = masterToSlave.transformNormal(
+ master.plane.normal)
+ slave.plane.setPlane(point, normal)
+
+ def _sfViewDataChanged(self):
+ """Handle data change in the ScalarFieldView this plane belongs to"""
+ self._dataPlane.setData(self.sender().getData(), copy=False)
+
+ # Store data range info as 3-tuple of values
+ self._dataRange = self.sender().getDataRange()
+
+ self.sigDataChanged.emit()
+
+ # Update colormap range when autoscale
+ if self.getColormap().isAutoscale():
+ self._updateSceneColormap()
+
+ self._keepPlaneInBBox()
+
+ def _sfViewTransformChanged(self):
+ """Handle transform changed in the ScalarFieldView"""
+ self._keepPlaneInBBox()
+ self._syncPlanes(master=self._planeStroke,
+ slave=self._dataPlane)
+ self.sigPlaneChanged.emit()
+
+ def _planeChanged(self, source, *args, **kwargs):
+ """Handle events from the plane primitive"""
+ # Using _visible for now, until scene as more info in events
+ if source.visible != self._visible:
+ self._visible = source.visible
+ self.sigVisibilityChanged.emit(source.visible)
+
+ def _planePositionChanged(self, source, *args, **kwargs):
+ """Handle update of cut plane position and normal"""
+ if self.__syncPlane:
+ self.__syncPlane = False
+ if source is self._planeStroke.plane:
+ self._syncPlanes(master=self._planeStroke,
+ slave=self._dataPlane)
+ elif source is self._dataPlane.plane:
+ self._syncPlanes(master=self._dataPlane,
+ slave=self._planeStroke)
+ else:
+ _logger.error('Received an unknown object %s',
+ str(source))
+
+ if self._planeStroke.visible or self._dataPlane.visible:
+ self.sigPlaneChanged.emit()
+
+ self.__syncPlane = True
+
+ # Plane position
+
+ def moveToCenter(self):
+ """Move cut plane to center of data set"""
+ self._planeStroke.moveToCenter()
+
+ def isValid(self):
+ """Returns whether the cut plane is defined or not (bool)"""
+ return self._planeStroke.isValid
+
+ def _plane(self, coordinates='array'):
+ """Returns the scene plane to set.
+
+ :param str coordinates: The coordinate system to use:
+ Either 'scene' or 'array' (default)
+ :rtype: Plane
+ :raise ValueError: If coordinates is not correct
+ """
+ if coordinates == 'scene':
+ return self._planeStroke.plane
+ elif coordinates == 'array':
+ return self._dataPlane.plane
+ else:
+ raise ValueError(
+ 'Unsupported coordinates: %s' % str(coordinates))
+
+ def getNormal(self, coordinates='array'):
+ """Returns the normal of the plane (as a unit vector)
+
+ :param str coordinates: The coordinate system to use:
+ Either 'scene' or 'array' (default)
+ :return: Normal (nx, ny, nz), vector is 0 if no plane is defined
+ :rtype: numpy.ndarray
+ :raise ValueError: If coordinates is not correct
+ """
+ return self._plane(coordinates).normal
+
+ def setNormal(self, normal, coordinates='array'):
+ """Set the normal of the plane.
+
+ :param normal: 3-tuple of float: nx, ny, nz
+ :param str coordinates: The coordinate system to use:
+ Either 'scene' or 'array' (default)
+ :raise ValueError: If coordinates is not correct
+ """
+ self._plane(coordinates).normal = normal
+
+ def getPoint(self, coordinates='array'):
+ """Returns a point on the plane.
+
+ :param str coordinates: The coordinate system to use:
+ Either 'scene' or 'array' (default)
+ :return: (x, y, z)
+ :rtype: numpy.ndarray
+ :raise ValueError: If coordinates is not correct
+ """
+ return self._plane(coordinates).point
+
+ def setPoint(self, point, constraint=True, coordinates='array'):
+ """Set a point contained in the plane.
+
+ Warning: The plane might not intersect the bounding box of the data.
+
+ :param point: (x, y, z) position
+ :type point: 3-tuple of float
+ :param bool constraint:
+ True (default) to make sure the plane intersect data bounding box,
+ False to set the plane without any constraint.
+ :raise ValueError: If coordinates is not correc
+ """
+ self._plane(coordinates).point = point
+ if constraint:
+ self._keepPlaneInBBox()
+
+ def getParameters(self, coordinates='array'):
+ """Returns the plane equation parameters: a*x + b*y + c*z + d = 0
+
+ :param str coordinates: The coordinate system to use:
+ Either 'scene' or 'array' (default)
+ :return: Plane equation parameters: (a, b, c, d)
+ :rtype: numpy.ndarray
+ :raise ValueError: If coordinates is not correct
+ """
+ return self._plane(coordinates).parameters
+
+ def setParameters(self, parameters, constraint=True, coordinates='array'):
+ """Set the plane equation parameters: a*x + b*y + c*z + d = 0
+
+ Warning: The plane might not intersect the bounding box of the data.
+
+ :param parameters: (a, b, c, d) plane equation parameters.
+ :type parameters: 4-tuple of float
+ :param bool constraint:
+ True (default) to make sure the plane intersect data bounding box,
+ False to set the plane without any constraint.
+ :raise ValueError: If coordinates is not correc
+ """
+ self._plane(coordinates).parameters = parameters
+ if constraint:
+ self._keepPlaneInBBox()
+
+ # Visibility
+
+ def isVisible(self):
+ """Returns True if the plane is visible, False otherwise"""
+ return self._planeStroke.visible
+
+ def setVisible(self, visible):
+ """Set the visibility of the plane
+
+ :param bool visible: True to make plane visible
+ """
+ visible = bool(visible)
+ self._planeStroke.visible = visible
+ self._dataPlane.visible = visible
+
+ # Border stroke
+
+ def getStrokeColor(self):
+ """Returns the color of the plane border (QColor)"""
+ return qt.QColor.fromRgbF(*self._planeStroke.color)
+
+ def setStrokeColor(self, color):
+ """Set the color of the plane border.
+
+ :param color: RGB color: name, #RRGGBB or RGB values
+ :type color:
+ QColor, str or array-like of 3 or 4 float in [0., 1.] or uint8
+ """
+ color = rgba(color)
+ self._planeStroke.color = color
+ self._dataPlane.color = color
+
+ # Data
+
+ def getImageData(self):
+ """Returns the data and information corresponding to the cut plane.
+
+ The returned data is not interpolated,
+ it is a slice of the 3D scalar field.
+
+ Image data axes are so that plane normal is towards the point of view.
+
+ :return: An object containing the 2D data slice and information
+ """
+ return _CutPlaneImage(self)
+
+ # Interpolation
+
+ def getInterpolation(self):
+ """Returns the interpolation used to display to cut plane.
+
+ :return: 'nearest' or 'linear'
+ :rtype: str
+ """
+ return self._dataPlane.interpolation
+
+ def setInterpolation(self, interpolation):
+ """Set the interpolation used to display to cut plane
+
+ The default interpolation is 'linear'
+
+ :param str interpolation: 'nearest' or 'linear'
+ """
+ if interpolation != self.getInterpolation():
+ self._dataPlane.interpolation = interpolation
+ self.sigInterpolationChanged.emit(interpolation)
+
+ # Colormap
+
+ # def getAlpha(self):
+ # """Returns the transparency of the plane as a float in [0., 1.]"""
+ # return self._plane.alpha
+
+ # def setAlpha(self, alpha):
+ # """Set the plane transparency.
+ #
+ # :param float alpha: Transparency in [0., 1]
+ # """
+ # self._plane.alpha = alpha
+
+ def getDisplayValuesBelowMin(self):
+ """Return whether values <= colormap min are displayed or not.
+
+ :rtype: bool
+ """
+ return self._dataPlane.colormap.displayValuesBelowMin
+
+ def setDisplayValuesBelowMin(self, display):
+ """Set whether to display values <= colormap min.
+
+ :param bool display: True to show values below min,
+ False to discard them
+ """
+ display = bool(display)
+ if display != self.getDisplayValuesBelowMin():
+ self._dataPlane.colormap.displayValuesBelowMin = display
+ self.sigTransparencyChanged.emit()
+
+ def getColormap(self):
+ """Returns the colormap set by :meth:`setColormap`.
+
+ :return: The colormap
+ :rtype: ~silx.gui.colors.Colormap
+ """
+ return self._colormap
+
+ def setColormap(self,
+ name='gray',
+ norm=None,
+ vmin=None,
+ vmax=None):
+ """Set the colormap to use.
+
+ By either providing a :class:`Colormap` object or
+ its name, normalization and range.
+
+ :param name: Name of the colormap in
+ 'gray', 'reversed gray', 'temperature', 'red', 'green', 'blue'.
+ Or Colormap object.
+ :type name: str or ~silx.gui.colors.Colormap
+ :param str norm: Colormap mapping: 'linear' or 'log'.
+ :param float vmin: The minimum value of the range or None for autoscale
+ :param float vmax: The maximum value of the range or None for autoscale
+ """
+ _logger.debug('setColormap %s %s (%s, %s)',
+ name, str(norm), str(vmin), str(vmax))
+
+ self._colormap.sigChanged.disconnect(self._colormapChanged)
+
+ if isinstance(name, Colormap): # Use it as it is
+ assert (norm, vmin, vmax) == (None, None, None)
+ self._colormap = name
+ else:
+ if norm is None:
+ norm = 'linear'
+ self._colormap = Colormap(
+ name=name, normalization=norm, vmin=vmin, vmax=vmax)
+
+ self._colormap.sigChanged.connect(self._colormapChanged)
+ self._colormapChanged()
+
+ def getColormapEffectiveRange(self):
+ """Returns the currently used range of the colormap.
+
+ This range is computed from the data set if colormap is in autoscale.
+ Range is clipped to positive values when using log scale.
+
+ :return: 2-tuple of float
+ """
+ return self._dataPlane.colormap.range_
+
+ def _updateSceneColormap(self):
+ """Synchronizes scene's colormap with Colormap object"""
+ colormap = self.getColormap()
+ sceneCMap = self._dataPlane.colormap
+
+ sceneCMap.colormap = colormap.getNColors()
+
+ sceneCMap.norm = colormap.getNormalization()
+ range_ = colormap.getColormapRange(data=self._dataRange)
+ sceneCMap.range_ = range_
+
+ def _colormapChanged(self):
+ """Handle update of Colormap object"""
+ self._updateSceneColormap()
+ # Forward colormap changed event
+ self.sigColormapChanged.emit(self.getColormap())
+
+
+class _CutPlaneImage(object):
+ """Object representing the data sliced by a cut plane
+
+ :param CutPlane cutPlane: The CutPlane from which to generate image info
+ """
+
+ def __init__(self, cutPlane):
+ # Init attributes with default values
+ self._isValid = False
+ self._data = numpy.zeros((0, 0), dtype=numpy.float32)
+ self._index = 0
+ self._xLabel = ''
+ self._yLabel = ''
+ self._normalLabel = ''
+ self._scale = float('nan'), float('nan')
+ self._translation = float('nan'), float('nan')
+ self._position = float('nan')
+
+ sfView = cutPlane.parent()
+ if not sfView or not cutPlane.isValid():
+ _logger.info("No plane available")
+ return
+
+ data = sfView.getData(copy=False)
+ if data is None:
+ _logger.info("No data available")
+ return
+
+ normal = cutPlane.getNormal(coordinates='array')
+ point = cutPlane.getPoint(coordinates='array')
+
+ if numpy.linalg.norm(numpy.cross(normal, (1., 0., 0.))) < 0.0017:
+ if not 0 <= point[0] <= data.shape[2]:
+ _logger.info("Plane outside dataset")
+ return
+ index = max(0, min(int(point[0]), data.shape[2] - 1))
+ slice_ = data[:, :, index]
+ xAxisIndex, yAxisIndex, normalAxisIndex = 1, 2, 0 # y, z, x
+
+ elif numpy.linalg.norm(numpy.cross(normal, (0., 1., 0.))) < 0.0017:
+ if not 0 <= point[1] <= data.shape[1]:
+ _logger.info("Plane outside dataset")
+ return
+ index = max(0, min(int(point[1]), data.shape[1] - 1))
+ slice_ = numpy.transpose(data[:, index, :])
+ xAxisIndex, yAxisIndex, normalAxisIndex = 2, 0, 1 # z, x, y
+
+ elif numpy.linalg.norm(numpy.cross(normal, (0., 0., 1.))) < 0.0017:
+ if not 0 <= point[2] <= data.shape[0]:
+ _logger.info("Plane outside dataset")
+ return
+ index = max(0, min(int(point[2]), data.shape[0] - 1))
+ slice_ = data[index, :, :]
+ xAxisIndex, yAxisIndex, normalAxisIndex = 0, 1, 2 # x, y, z
+ else:
+ _logger.warning('Unsupported normal: (%f, %f, %f)',
+ normal[0], normal[1], normal[2])
+ return
+
+ # Store cut plane image info
+
+ self._isValid = True
+ self._data = numpy.array(slice_, copy=True)
+ self._index = index
+
+ # Only store extra information when no transform matrix is set
+ # Otherwise this information can be meaningless
+ if numpy.all(numpy.equal(sfView.getTransformMatrix(),
+ numpy.identity(3, dtype=numpy.float32))):
+ labels = sfView.getAxesLabels()
+ self._xLabel = labels[xAxisIndex]
+ self._yLabel = labels[yAxisIndex]
+ self._normalLabel = labels[normalAxisIndex]
+
+ scale = sfView.getScale()
+ self._scale = scale[xAxisIndex], scale[yAxisIndex]
+
+ translation = sfView.getTranslation()
+ self._translation = translation[xAxisIndex], translation[yAxisIndex]
+
+ self._position = float(index * scale[normalAxisIndex] +
+ translation[normalAxisIndex])
+
+ def isValid(self):
+ """Returns True if the cut plane image is defined (bool)"""
+ return self._isValid
+
+ def getData(self, copy=True):
+ """Returns the image data sliced by the cut plane.
+
+ :param bool copy: True to get a copy, False otherwise
+ :return: The 2D image data corresponding to the cut plane
+ :rtype: numpy.ndarray
+ """
+ return numpy.array(self._data, copy=copy)
+
+ def getXLabel(self):
+ """Returns the label associated to the X axis of the image (str)"""
+ return self._xLabel
+
+ def getYLabel(self):
+ """Returns the label associated to the Y axis of the image (str)"""
+ return self._yLabel
+
+ def getNormalLabel(self):
+ """Returns the label of the 3D axis of the plane normal (str)"""
+ return self._normalLabel
+
+ def getScale(self):
+ """Returns the scales of the data as a 2-tuple of float (sx, sy)"""
+ return self._scale
+
+ def getTranslation(self):
+ """Returns the offset of the data as a 2-tuple of float (ox, oy)"""
+ return self._translation
+
+ def getIndex(self):
+ """Returns the index in the data array of the cut plane (int)"""
+ return self._index
+
+ def getPosition(self):
+ """Returns the cut plane position along the normal axis (flaot)"""
+ return self._position
+
+
+class ScalarFieldView(Plot3DWindow):
+ """Widget computing and displaying an iso-surface from a 3D scalar dataset.
+
+ Limitation: Currently, iso-surfaces are generated with higher values
+ than the iso-level 'inside' the surface.
+
+ :param parent: See :class:`QMainWindow`
+ """
+
+ sigDataChanged = qt.Signal()
+ """Signal emitted when the scalar data field has changed."""
+
+ sigTransformChanged = qt.Signal()
+ """Signal emitted when the transformation has changed.
+
+ It is emitted by :meth:`setTranslation`, :meth:`setTransformMatrix`,
+ :meth:`setScale`.
+ """
+
+ sigSelectedRegionChanged = qt.Signal(object)
+ """Signal emitted when the selected region has changed.
+
+ This signal provides the new selected region.
+ """
+
+ def __init__(self, parent=None):
+ super(ScalarFieldView, self).__init__(parent)
+ self._colormap = Colormap(
+ name='gray', normalization='linear', vmin=None, vmax=None)
+ self._selectedRange = None
+
+ # Store iso-surfaces
+ self._isosurfaces = []
+
+ # Transformations
+ self._dataScale = transform.Scale()
+ self._dataTranslate = transform.Translate()
+ self._dataTransform = transform.Matrix() # default to identity
+
+ self._foregroundColor = 1., 1., 1., 1.
+ self._highlightColor = 0.7, 0.7, 0., 1.
+
+ self._data = None
+ self._dataRange = None
+
+ self._group = primitives.BoundedGroup()
+ self._group.transforms = [
+ self._dataTranslate, self._dataTransform, self._dataScale]
+
+ self._bbox = axes.LabelledAxes()
+ self._bbox.children = [self._group]
+ self._outerScale = transform.Scale(1., 1., 1.)
+ self._bbox.transforms = [self._outerScale]
+ self.getPlot3DWidget().viewport.scene.children.append(self._bbox)
+
+ self._selectionBox = primitives.Box()
+ self._selectionBox.strokeSmooth = False
+ self._selectionBox.strokeWidth = 1.
+ # self._selectionBox.fillColor = 1., 1., 1., 0.3
+ # self._selectionBox.fillCulling = 'back'
+ self._selectionBox.visible = False
+ self._group.children.append(self._selectionBox)
+
+ self._cutPlane = CutPlane(sfView=self)
+ self._cutPlane.sigVisibilityChanged.connect(
+ self._planeVisibilityChanged)
+ planeStroke, dataPlane = self._cutPlane._get3DPrimitives()
+ self._bbox.children.append(planeStroke)
+ self._group.children.append(dataPlane)
+
+ self._isogroup = primitives.GroupDepthOffset()
+ self._isogroup.transforms = [
+ # Convert from z, y, x from marching cubes to x, y, z
+ transform.Matrix((
+ (0., 0., 1., 0.),
+ (0., 1., 0., 0.),
+ (1., 0., 0., 0.),
+ (0., 0., 0., 1.))),
+ # Offset to match cutting plane coords
+ transform.Translate(0.5, 0.5, 0.5)
+ ]
+ self._group.children.append(self._isogroup)
+
+ self._initPanPlaneAction()
+
+ self._updateColors()
+
+ self.getPlot3DWidget().viewport.light.shininess = 32
+
+ def saveConfig(self, ioDevice):
+ """
+ Saves this view state. Only isosurfaces at the moment. Does not save
+ the isosurface's function.
+
+ :param qt.QIODevice ioDevice: A `qt.QIODevice`.
+ """
+
+ stream = qt.QDataStream(ioDevice)
+
+ stream.writeString('<ScalarFieldView>')
+
+ isoSurfaces = self.getIsosurfaces()
+
+ nIsoSurfaces = len(isoSurfaces)
+
+ # TODO : delegate the serialization to the serialized items
+ # isosurfaces
+ if nIsoSurfaces:
+ tagIn = '<IsoSurfaces nIso={0}>'.format(nIsoSurfaces)
+ stream.writeString(tagIn)
+
+ for surface in isoSurfaces:
+ color = surface.getColor()
+ level = surface.getLevel()
+ visible = surface.isVisible()
+ stream << color
+ stream.writeDouble(level)
+ stream.writeBool(visible)
+
+ stream.writeString('</IsoSurfaces>')
+
+ stream.writeString('<Style>')
+ background = self.getBackgroundColor()
+ foreground = self.getForegroundColor()
+ highlight = self.getHighlightColor()
+ stream << background << foreground << highlight
+ stream.writeString('</Style>')
+
+ stream.writeString('</ScalarFieldView>')
+
+ def loadConfig(self, ioDevice):
+ """
+ Loads this view state.
+ See ScalarFieldView.saveView to know what is supported at the moment.
+
+ :param qt.QIODevice ioDevice: A `qt.QIODevice`.
+ """
+
+ tagStack = deque()
+
+ tagInRegex = re.compile('<(?P<itemId>[^ /]*) *'
+ '(?P<args>.*)>')
+
+ tagOutRegex = re.compile('</(?P<itemId>[^ ]*)>')
+
+ tagRootInRegex = re.compile('<ScalarFieldView>')
+
+ isoSurfaceArgsRegex = re.compile('nIso=(?P<nIso>[0-9]*)')
+
+ stream = qt.QDataStream(ioDevice)
+
+ tag = stream.readString()
+ tagMatch = tagRootInRegex.match(tag)
+
+ if tagMatch is None:
+ # TODO : explicit error
+ raise ValueError('Unknown data.')
+
+ itemId = 'ScalarFieldView'
+
+ tagStack.append(itemId)
+
+ while True:
+
+ tag = stream.readString()
+
+ tagMatch = tagOutRegex.match(tag)
+ if tagMatch:
+ closeId = tagMatch.groupdict()['itemId']
+ if closeId != itemId:
+ # TODO : explicit error
+ raise ValueError('Unexpected closing tag {0} '
+ '(expected {1})'
+ ''.format(closeId, itemId))
+
+ if itemId == 'ScalarFieldView':
+ # reached end
+ break
+ else:
+ itemId = tagStack.pop()
+ # fetching next tag
+ continue
+
+ tagMatch = tagInRegex.match(tag)
+
+ if tagMatch is None:
+ # TODO : explicit error
+ raise ValueError('Unknown data.')
+
+ tagStack.append(itemId)
+
+ matchDict = tagMatch.groupdict()
+
+ itemId = matchDict['itemId']
+
+ # TODO : delegate the deserialization to the serialized items
+ if itemId == 'IsoSurfaces':
+ argsMatch = isoSurfaceArgsRegex.match(matchDict['args'])
+ if not argsMatch:
+ # TODO : explicit error
+ raise ValueError('Failed to parse args "{0}".'
+ ''.format(matchDict['args']))
+ argsDict = argsMatch.groupdict()
+ nIso = int(argsDict['nIso'])
+ if nIso:
+ for surface in self.getIsosurfaces():
+ self.removeIsosurface(surface)
+ for isoIdx in range(nIso):
+ color = qt.QColor()
+ stream >> color
+ level = stream.readDouble()
+ visible = stream.readBool()
+ surface = self.addIsosurface(level, color=color)
+ surface.setVisible(visible)
+ elif itemId == 'Style':
+ background = qt.QColor()
+ foreground = qt.QColor()
+ highlight = qt.QColor()
+ stream >> background >> foreground >> highlight
+ self.setBackgroundColor(background)
+ self.setForegroundColor(foreground)
+ self.setHighlightColor(highlight)
+ else:
+ raise ValueError('Unknown entry tag {0}.'
+ ''.format(itemId))
+
+ def _initPanPlaneAction(self):
+ """Creates and init the pan plane action"""
+ self._panPlaneAction = qt.QAction(self)
+ self._panPlaneAction.setIcon(icons.getQIcon('3d-plane-pan'))
+ self._panPlaneAction.setText('Pan plane')
+ self._panPlaneAction.setCheckable(True)
+ self._panPlaneAction.setToolTip(
+ 'Pan the cutting plane. Press <b>Ctrl</b> to rotate the scene.')
+ self._panPlaneAction.setEnabled(False)
+
+ self._panPlaneAction.triggered[bool].connect(self._planeActionTriggered)
+ self.getPlot3DWidget().sigInteractiveModeChanged.connect(
+ self._interactiveModeChanged)
+
+ toolbar = self.findChild(InteractiveModeToolBar)
+ if toolbar is not None:
+ toolbar.addAction(self._panPlaneAction)
+
+ def _planeActionTriggered(self, checked=False):
+ self._panPlaneAction.setChecked(True)
+ self.setInteractiveMode('plane')
+
+ def _interactiveModeChanged(self):
+ self._panPlaneAction.setChecked(self.getInteractiveMode() == 'plane')
+ self._updateColors()
+
+ def _planeVisibilityChanged(self, visible):
+ """Handle visibility events from the plane"""
+ if visible != self._panPlaneAction.isEnabled():
+ self._panPlaneAction.setEnabled(visible)
+ if visible:
+ self.setInteractiveMode('plane')
+ elif self._panPlaneAction.isChecked():
+ self.setInteractiveMode('rotate')
+
+ def setInteractiveMode(self, mode):
+ """Choose the current interaction.
+
+ :param str mode: Either rotate, pan or plane
+ """
+ if mode == self.getInteractiveMode():
+ return
+
+ sceneScale = self.getPlot3DWidget().viewport.scene.transforms[0]
+ if mode == 'plane':
+ mode = interaction.PanPlaneZoomOnWheelControl(
+ self.getPlot3DWidget().viewport,
+ self._cutPlane._get3DPrimitives()[0],
+ mode='position',
+ orbitAroundCenter=False,
+ scaleTransform=sceneScale)
+
+ self.getPlot3DWidget().setInteractiveMode(mode)
+ self._updateColors()
+
+ def getInteractiveMode(self):
+ """Returns the current interaction mode, see :meth:`setInteractiveMode`
+ """
+ if isinstance(self.getPlot3DWidget().eventHandler,
+ interaction.PanPlaneZoomOnWheelControl):
+ return 'plane'
+ else:
+ return self.getPlot3DWidget().getInteractiveMode()
+
+ # Handle scalar field
+
+ def setData(self, data, copy=True):
+ """Set the 3D scalar data set to use for building the iso-surface.
+
+ Dataset order is zyx (i.e., first dimension is z).
+
+ :param data: scalar field from which to extract the iso-surface
+ :type data: 3D numpy.ndarray of float32 with shape at least (2, 2, 2)
+ :param bool copy:
+ True (default) to make a copy,
+ False to avoid copy (DO NOT MODIFY data afterwards)
+ """
+ if data is None:
+ self._data = None
+ self._dataRange = None
+ self.setSelectedRegion(zrange=None, yrange=None, xrange_=None)
+ self._group.shape = None
+ self.centerScene()
+
+ else:
+ data = numpy.array(data, copy=copy, dtype=numpy.float32, order='C')
+ assert data.ndim == 3
+ assert min(data.shape) >= 2
+
+ wasData = self._data is not None
+ previousSelectedRegion = self.getSelectedRegion()
+
+ self._data = data
+
+ # Store data range info
+ dataRange = min_max(self._data, min_positive=True, finite=True)
+ if dataRange.minimum is None: # Only non-finite data
+ dataRange = None
+
+ if dataRange is not None:
+ min_positive = dataRange.min_positive
+ if min_positive is None:
+ min_positive = float('nan')
+ dataRange = dataRange.minimum, min_positive, dataRange.maximum
+ self._dataRange = dataRange
+
+ if previousSelectedRegion is not None:
+ # Update selected region to ensure it is clipped to array range
+ self.setSelectedRegion(*previousSelectedRegion.getArrayRange())
+
+ self._group.shape = self._data.shape
+
+ if not wasData:
+ self.centerScene() # Reset viewpoint the first time only
+
+ # Update iso-surfaces
+ for isosurface in self.getIsosurfaces():
+ isosurface._setData(self._data, copy=False)
+
+ self.sigDataChanged.emit()
+
+ def getData(self, copy=True):
+ """Get the 3D scalar data currently used to build the iso-surface.
+
+ :param bool copy:
+ True (default) to get a copy,
+ False to get the internal data (DO NOT modify!)
+ :return: The data set (or None if not set)
+ """
+ if self._data is None:
+ return None
+ else:
+ return numpy.array(self._data, copy=copy)
+
+ def getDataRange(self):
+ """Return the range of the data as a 3-tuple of values.
+
+ positive min is NaN if no data is positive.
+
+ :return: (min, positive min, max) or None.
+ """
+ return self._dataRange
+
+ # Transformations
+
+ def setOuterScale(self, sx=1., sy=1., sz=1.):
+ """Set the scale to apply to the whole scene including the axes.
+
+ This is useful when axis lengths in data space are really different.
+
+ :param float sx: Scale factor along the X axis
+ :param float sy: Scale factor along the Y axis
+ :param float sz: Scale factor along the Z axis
+ """
+ self._outerScale.setScale(sx, sy, sz)
+ self.centerScene()
+
+ def getOuterScale(self):
+ """Returns the scales provided by :meth:`setOuterScale`.
+
+ :rtype: numpy.ndarray
+ """
+ return self._outerScale.scale
+
+ def setScale(self, sx=1., sy=1., sz=1.):
+ """Set the scale of the 3D scalar field (i.e., size of a voxel).
+
+ :param float sx: Scale factor along the X axis
+ :param float sy: Scale factor along the Y axis
+ :param float sz: Scale factor along the Z axis
+ """
+ scale = numpy.array((sx, sy, sz), dtype=numpy.float32)
+ if not numpy.all(numpy.equal(scale, self.getScale())):
+ self._dataScale.scale = scale
+ self.sigTransformChanged.emit()
+ self.centerScene() # Reset viewpoint
+
+ def getScale(self):
+ """Returns the scales provided by :meth:`setScale` as a numpy.ndarray.
+ """
+ return self._dataScale.scale
+
+ def setTranslation(self, x=0., y=0., z=0.):
+ """Set the translation of the origin of the data array in data coordinates.
+
+ :param float x: Offset of the data origin on the X axis
+ :param float y: Offset of the data origin on the Y axis
+ :param float z: Offset of the data origin on the Z axis
+ """
+ translation = numpy.array((x, y, z), dtype=numpy.float32)
+ if not numpy.all(numpy.equal(translation, self.getTranslation())):
+ self._dataTranslate.translation = translation
+ self.sigTransformChanged.emit()
+ self.centerScene() # Reset viewpoint
+
+ def getTranslation(self):
+ """Returns the offset set by :meth:`setTranslation` as a numpy.ndarray.
+ """
+ return self._dataTranslate.translation
+
+ def setTransformMatrix(self, matrix3x3):
+ """Set the transform matrix applied to the data.
+
+ :param numpy.ndarray matrix: 3x3 transform matrix
+ """
+ matrix3x3 = numpy.array(matrix3x3, copy=True, dtype=numpy.float32)
+ if not numpy.all(numpy.equal(matrix3x3, self.getTransformMatrix())):
+ matrix = numpy.identity(4, dtype=numpy.float32)
+ matrix[:3, :3] = matrix3x3
+ self._dataTransform.setMatrix(matrix)
+ self.sigTransformChanged.emit()
+ self.centerScene() # Reset viewpoint
+
+ def getTransformMatrix(self):
+ """Returns the transform matrix applied to the data.
+
+ See :meth:`setTransformMatrix`.
+
+ :rtype: numpy.ndarray
+ """
+ return self._dataTransform.getMatrix()[:3, :3]
+
+ # Axes labels
+
+ def isBoundingBoxVisible(self):
+ """Returns axes labels, grid and bounding box visibility.
+
+ :rtype: bool
+ """
+ return self._bbox.boxVisible
+
+ def setBoundingBoxVisible(self, visible):
+ """Set axes labels, grid and bounding box visibility.
+
+ :param bool visible: True to show axes, False to hide
+ """
+ visible = bool(visible)
+ self._bbox.boxVisible = visible
+
+ def setAxesLabels(self, xlabel=None, ylabel=None, zlabel=None):
+ """Set the text labels of the axes.
+
+ :param str xlabel: Label of the X axis, None to leave unchanged.
+ :param str ylabel: Label of the Y axis, None to leave unchanged.
+ :param str zlabel: Label of the Z axis, None to leave unchanged.
+ """
+ if xlabel is not None:
+ self._bbox.xlabel = xlabel
+
+ if ylabel is not None:
+ self._bbox.ylabel = ylabel
+
+ if zlabel is not None:
+ self._bbox.zlabel = zlabel
+
+ class _Labels(tuple):
+ """Return type of :meth:`getAxesLabels`"""
+
+ def getXLabel(self):
+ """Label of the X axis (str)"""
+ return self[0]
+
+ def getYLabel(self):
+ """Label of the Y axis (str)"""
+ return self[1]
+
+ def getZLabel(self):
+ """Label of the Z axis (str)"""
+ return self[2]
+
+ def getAxesLabels(self):
+ """Returns the text labels of the axes
+
+ >>> widget = ScalarFieldView()
+ >>> widget.setAxesLabels(xlabel='X')
+
+ You can get the labels either as a 3-tuple:
+
+ >>> xlabel, ylabel, zlabel = widget.getAxesLabels()
+
+ Or as an object with methods getXLabel, getYLabel and getZLabel:
+
+ >>> labels = widget.getAxesLabels()
+ >>> labels.getXLabel()
+ ... 'X'
+
+ :return: object describing the labels
+ """
+ return self._Labels((self._bbox.xlabel,
+ self._bbox.ylabel,
+ self._bbox.zlabel))
+
+ # Colors
+
+ def _updateColors(self):
+ """Update item depending on foreground/highlight color"""
+ self._bbox.tickColor = self._foregroundColor
+ self._selectionBox.strokeColor = self._foregroundColor
+ if self.getInteractiveMode() == 'plane':
+ self._cutPlane.setStrokeColor(self._highlightColor)
+ self._bbox.color = self._foregroundColor
+ else:
+ self._cutPlane.setStrokeColor(self._foregroundColor)
+ self._bbox.color = self._highlightColor
+
+ def getForegroundColor(self):
+ """Return color used for text and bounding box (QColor)"""
+ return qt.QColor.fromRgbF(*self._foregroundColor)
+
+ def setForegroundColor(self, color):
+ """Set the foreground color.
+
+ :param color: RGB color: name, #RRGGBB or RGB values
+ :type color:
+ QColor, str or array-like of 3 or 4 float in [0., 1.] or uint8
+ """
+ color = rgba(color)
+ if color != self._foregroundColor:
+ self._foregroundColor = color
+ self._updateColors()
+
+ def getHighlightColor(self):
+ """Return color used for highlighted item bounding box (QColor)"""
+ return qt.QColor.fromRgbF(*self._highlightColor)
+
+ def setHighlightColor(self, color):
+ """Set hightlighted item color.
+
+ :param color: RGB color: name, #RRGGBB or RGB values
+ :type color:
+ QColor, str or array-like of 3 or 4 float in [0., 1.] or uint8
+ """
+ color = rgba(color)
+ if color != self._highlightColor:
+ self._highlightColor = color
+ self._updateColors()
+
+ # Cut Plane
+
+ def getCutPlanes(self):
+ """Return an iterable of all cut planes of the view.
+
+ This includes hidden cut planes.
+
+ For now, there is always one cut plane.
+ """
+ return (self._cutPlane,)
+
+ # Selection
+
+ def setSelectedRegion(self, zrange=None, yrange=None, xrange_=None):
+ """Set the 3D selected region aligned with the axes.
+
+ Provided range are array indices range.
+ The provided ranges are clipped to the data.
+ If a range is None, the range of the array on this dimension is used.
+
+ :param zrange: (zmin, zmax) range of the selection
+ :param yrange: (ymin, ymax) range of the selection
+ :param xrange_: (xmin, xmax) range of the selection
+ """
+ # No range given: unset selection
+ if zrange is None and yrange is None and xrange_ is None:
+ selectedRange = None
+
+ else:
+ # Handle default ranges
+ if self._data is not None:
+ if zrange is None:
+ zrange = 0, self._data.shape[0]
+ if yrange is None:
+ yrange = 0, self._data.shape[1]
+ if xrange_ is None:
+ xrange_ = 0, self._data.shape[2]
+
+ elif None in (xrange_, yrange, zrange):
+ # One of the range is None and no data available
+ raise RuntimeError(
+ 'Data is not set, cannot get default range from it.')
+
+ # Clip selected region to data shape and make sure min <= max
+ selectedRange = numpy.array((
+ (max(0, min(*zrange)),
+ min(self._data.shape[0], max(*zrange))),
+ (max(0, min(*yrange)),
+ min(self._data.shape[1], max(*yrange))),
+ (max(0, min(*xrange_)),
+ min(self._data.shape[2], max(*xrange_))),
+ ), dtype=numpy.int64)
+
+ # numpy.equal supports None
+ if not numpy.all(numpy.equal(selectedRange, self._selectedRange)):
+ self._selectedRange = selectedRange
+
+ # Update scene accordingly
+ if self._selectedRange is None:
+ self._selectionBox.visible = False
+ else:
+ self._selectionBox.visible = True
+ scales = self._selectedRange[:, 1] - self._selectedRange[:, 0]
+ self._selectionBox.size = scales[::-1]
+ self._selectionBox.transforms = [
+ transform.Translate(*self._selectedRange[::-1, 0])]
+
+ self.sigSelectedRegionChanged.emit(self.getSelectedRegion())
+
+ def getSelectedRegion(self):
+ """Returns the currently selected region or None."""
+ if self._selectedRange is None:
+ return None
+ else:
+ dataBBox = self._group.transforms.transformBounds(
+ self._selectedRange[::-1].T).T
+ return SelectedRegion(self._selectedRange, dataBBox,
+ translation=self.getTranslation(),
+ scale=self.getScale())
+
+ # Handle iso-surfaces
+
+ sigIsosurfaceAdded = qt.Signal(object)
+ """Signal emitted when a new iso-surface is added to the view.
+
+ The newly added iso-surface is provided by this signal
+ """
+
+ sigIsosurfaceRemoved = qt.Signal(object)
+ """Signal emitted when an iso-surface is removed from the view
+
+ The removed iso-surface is provided by this signal.
+ """
+
+ def addIsosurface(self, level, color):
+ """Add an iso-surface to the view.
+
+ :param level:
+ The value at which to build the iso-surface or a callable
+ (e.g., a function) taking a 3D numpy.ndarray as input and
+ returning a float.
+ Example: numpy.mean(data) + numpy.std(data)
+ :type level: float or callable
+ :param color: RGBA color of the isosurface
+ :type color: str or array-like of 4 float in [0., 1.]
+ :return: Isosurface object describing this isosurface
+ """
+ isosurface = Isosurface(parent=self)
+ isosurface.setColor(color)
+ if callable(level):
+ isosurface.setAutoLevelFunction(level)
+ else:
+ isosurface.setLevel(level)
+ isosurface._setData(self._data, copy=False)
+ isosurface.sigLevelChanged.connect(self._updateIsosurfaces)
+
+ self._isosurfaces.append(isosurface)
+
+ self._updateIsosurfaces()
+
+ self.sigIsosurfaceAdded.emit(isosurface)
+ return isosurface
+
+ def getIsosurfaces(self):
+ """Return an iterable of all iso-surfaces of the view"""
+ return tuple(self._isosurfaces)
+
+ def removeIsosurface(self, isosurface):
+ """Remove an iso-surface from the view.
+
+ :param isosurface: The isosurface object to remove"""
+ if isosurface not in self.getIsosurfaces():
+ _logger.warning(
+ "Try to remove isosurface that is not in the list: %s",
+ str(isosurface))
+ else:
+ isosurface.sigLevelChanged.disconnect(self._updateIsosurfaces)
+ self._isosurfaces.remove(isosurface)
+ self._updateIsosurfaces()
+ self.sigIsosurfaceRemoved.emit(isosurface)
+
+ def clearIsosurfaces(self):
+ """Remove all iso-surfaces from the view."""
+ for isosurface in self.getIsosurfaces():
+ self.removeIsosurface(isosurface)
+
+ def _updateIsosurfaces(self, level=None):
+ """Handle updates of iso-surfaces level and add/remove"""
+ # Sorting using minus, this supposes data 'object' to be max values
+ sortedIso = sorted(self.getIsosurfaces(),
+ key=lambda iso: - iso.getLevel())
+ self._isogroup.children = [iso._get3DPrimitive() for iso in sortedIso]
diff --git a/src/silx/gui/plot3d/SceneWidget.py b/src/silx/gui/plot3d/SceneWidget.py
new file mode 100644
index 0000000..883f5e7
--- /dev/null
+++ b/src/silx/gui/plot3d/SceneWidget.py
@@ -0,0 +1,687 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a widget to view data sets in 3D."""
+
+from __future__ import absolute_import
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "24/04/2018"
+
+import enum
+import weakref
+
+import numpy
+
+from .. import qt
+from ..colors import rgba
+
+from .Plot3DWidget import Plot3DWidget
+from . import items
+from .items.core import RootGroupWithAxesItem
+from .scene import interaction
+from ._model import SceneModel, visitQAbstractItemModel
+from ._model.items import Item3DRow
+
+__all__ = ['items', 'SceneWidget']
+
+
+class _SceneSelectionHighlightManager(object):
+ """Class controlling the highlight of the selection in a SceneWidget
+
+ :param ~silx.gui.plot3d.SceneWidget.SceneSelection:
+ """
+
+ def __init__(self, selection):
+ assert isinstance(selection, SceneSelection)
+ self._sceneWidget = weakref.ref(selection.parent())
+
+ self._enabled = True
+ self._previousBBoxState = None
+
+ self.__selectItem(selection.getCurrentItem())
+ selection.sigCurrentChanged.connect(self.__currentChanged)
+
+ def isEnabled(self):
+ """Returns True if highlight of selection in enabled.
+
+ :rtype: bool
+ """
+ return self._enabled
+
+ def setEnabled(self, enabled=True):
+ """Activate/deactivate selection highlighting
+
+ :param bool enabled: True (default) to enable selection highlighting
+ """
+ enabled = bool(enabled)
+ if enabled != self._enabled:
+ self._enabled = enabled
+
+ sceneWidget = self.getSceneWidget()
+ if sceneWidget is not None:
+ selection = sceneWidget.selection()
+ current = selection.getCurrentItem()
+
+ if enabled:
+ self.__selectItem(current)
+ selection.sigCurrentChanged.connect(self.__currentChanged)
+
+ else: # disabled
+ self.__unselectItem(current)
+ selection.sigCurrentChanged.disconnect(
+ self.__currentChanged)
+
+ def getSceneWidget(self):
+ """Returns the SceneWidget this class controls highlight for.
+
+ :rtype: ~silx.gui.plot3d.SceneWidget.SceneWidget
+ """
+ return self._sceneWidget()
+
+ def __selectItem(self, current):
+ """Highlight given item.
+
+ :param ~silx.gui.plot3d.items.Item3D current: New current or None
+ """
+ if current is None:
+ return
+
+ sceneWidget = self.getSceneWidget()
+ if sceneWidget is None:
+ return
+
+ if isinstance(current, items.DataItem3D):
+ self._previousBBoxState = current.isBoundingBoxVisible()
+ current.setBoundingBoxVisible(True)
+ current._setForegroundColor(sceneWidget.getHighlightColor())
+ current.sigItemChanged.connect(self.__selectedChanged)
+
+ def __unselectItem(self, current):
+ """Remove highlight of given item.
+
+ :param ~silx.gui.plot3d.items.Item3D current:
+ Currently highlighted item
+ """
+ if current is None:
+ return
+
+ sceneWidget = self.getSceneWidget()
+ if sceneWidget is None:
+ return
+
+ # Restore bbox visibility and color
+ current.sigItemChanged.disconnect(self.__selectedChanged)
+ if (self._previousBBoxState is not None and
+ isinstance(current, items.DataItem3D)):
+ current.setBoundingBoxVisible(self._previousBBoxState)
+ current._setForegroundColor(sceneWidget.getForegroundColor())
+
+ def __currentChanged(self, current, previous):
+ """Handle change of current item in the selection
+
+ :param ~silx.gui.plot3d.items.Item3D current: New current or None
+ :param ~silx.gui.plot3d.items.Item3D previous: Previous current or None
+ """
+ self.__unselectItem(previous)
+ self.__selectItem(current)
+
+ def __selectedChanged(self, event):
+ """Handle updates of selected item bbox.
+
+ If bbox gets changed while selected, do not restore state.
+
+ :param event:
+ """
+ if event == items.Item3DChangedType.BOUNDING_BOX_VISIBLE:
+ self._previousBBoxState = None
+
+
+@enum.unique
+class HighlightMode(enum.Enum):
+ """:class:`SceneSelection` highlight modes"""
+
+ NONE = 'noHighlight'
+ """Do not highlight selected item"""
+
+ BOUNDING_BOX = 'boundingBox'
+ """Highlight selected item bounding box"""
+
+
+class SceneSelection(qt.QObject):
+ """Object managing a :class:`SceneWidget` selection
+
+ :param SceneWidget parent:
+ """
+
+ NO_SELECTION = 0
+ """Flag for no item selected"""
+
+ sigCurrentChanged = qt.Signal(object, object)
+ """This signal is emitted whenever the current item changes.
+
+ It provides the current and previous items.
+ Either of those can be :attr:`NO_SELECTION`.
+ """
+
+ def __init__(self, parent=None):
+ super(SceneSelection, self).__init__(parent)
+ self.__current = None # Store weakref to current item
+ self.__selectionModel = None # Store sync selection model
+ self.__syncInProgress = False # True during model synchronization
+
+ self.__highlightManager = _SceneSelectionHighlightManager(self)
+
+ def getHighlightMode(self):
+ """Returns current selection highlight mode.
+
+ Either NONE or BOUNDING_BOX.
+
+ :rtype: HighlightMode
+ """
+ if self.__highlightManager.isEnabled():
+ return HighlightMode.BOUNDING_BOX
+ else:
+ return HighlightMode.NONE
+
+ def setHighlightMode(self, mode):
+ """Set selection highlighting mode
+
+ :param HighlightMode mode: The mode to use
+ """
+ assert isinstance(mode, HighlightMode)
+ self.__highlightManager.setEnabled(mode == HighlightMode.BOUNDING_BOX)
+
+ def getCurrentItem(self):
+ """Returns the current item in the scene or None.
+
+ :rtype: Union[~silx.gui.plot3d.items.Item3D, None]
+ """
+ return None if self.__current is None else self.__current()
+
+ def setCurrentItem(self, item):
+ """Set the current item in the scene.
+
+ :param Union[Item3D, None] item:
+ The new item to select or None to clear the selection.
+ :raise ValueError: If the item is not the widget's scene
+ """
+ previous = self.getCurrentItem()
+ if item is previous:
+ return # Fast path, nothing to do
+
+ if previous is not None:
+ previous.sigItemChanged.disconnect(self.__currentChanged)
+
+ if item is None:
+ self.__current = None
+
+ elif isinstance(item, items.Item3D):
+ parent = self.parent()
+ assert isinstance(parent, SceneWidget)
+
+ sceneGroup = parent.getSceneGroup()
+ if item is sceneGroup or item.root() is sceneGroup:
+ item.sigItemChanged.connect(self.__currentChanged)
+ self.__current = weakref.ref(item)
+ else:
+ raise ValueError(
+ 'Item is not in this SceneWidget: %s' % str(item))
+
+ else:
+ raise ValueError(
+ 'Not an Item3D: %s' % str(item))
+
+ current = self.getCurrentItem()
+ self.sigCurrentChanged.emit(current, previous)
+ self.__updateSelectionModel()
+
+ def __currentChanged(self, event):
+ """Handle updates of the selected item"""
+ if event == items.Item3DChangedType.ROOT_ITEM:
+ item = self.sender()
+
+ parent = self.parent()
+ assert isinstance(parent, SceneWidget)
+
+ if item.root() != parent.getSceneGroup():
+ self.setCurrentItem(None)
+
+ # Synchronization with QItemSelectionModel
+
+ def _getSyncSelectionModel(self):
+ """Returns the QItemSelectionModel this selection is synchronized with.
+
+ :rtype: Union[QItemSelectionModel, None]
+ """
+ return self.__selectionModel
+
+ def _setSyncSelectionModel(self, selectionModel):
+ """Synchronizes this selection object with a selection model.
+
+ :param Union[QItemSelectionModel, None] selectionModel:
+ :raise ValueError: If the selection model does not correspond
+ to the same :class:`SceneWidget`
+ """
+ if (not isinstance(selectionModel, qt.QItemSelectionModel) or
+ not isinstance(selectionModel.model(), SceneModel) or
+ selectionModel.model().sceneWidget() is not self.parent()):
+ raise ValueError("Expecting a QItemSelectionModel "
+ "attached to the same SceneWidget")
+
+ # Disconnect from previous selection model
+ previousSelectionModel = self._getSyncSelectionModel()
+ if previousSelectionModel is not None:
+ previousSelectionModel.selectionChanged.disconnect(
+ self.__selectionModelSelectionChanged)
+
+ self.__selectionModel = selectionModel
+
+ if selectionModel is not None:
+ # Connect to new selection model
+ selectionModel.selectionChanged.connect(
+ self.__selectionModelSelectionChanged)
+ self.__updateSelectionModel()
+
+ def __selectionModelSelectionChanged(self, selected, deselected):
+ """Handle QItemSelectionModel selection updates.
+
+ :param QItemSelection selected:
+ :param QItemSelection deselected:
+ """
+ if self.__syncInProgress:
+ return
+
+ indices = selected.indexes()
+ if not indices:
+ item = None
+
+ else: # Select the first selected item
+ index = indices[0]
+ itemRow = index.internalPointer()
+ if isinstance(itemRow, Item3DRow):
+ item = itemRow.item()
+ else:
+ item = None
+
+ self.setCurrentItem(item)
+
+ def __updateSelectionModel(self):
+ """Sync selection model when current item has been updated"""
+ selectionModel = self._getSyncSelectionModel()
+ if selectionModel is None:
+ return
+
+ currentItem = self.getCurrentItem()
+
+ if currentItem is None:
+ selectionModel.clear()
+
+ else:
+ # visit the model to find selectable index corresponding to item
+ model = selectionModel.model()
+ for index in visitQAbstractItemModel(model):
+ itemRow = index.internalPointer()
+ if (isinstance(itemRow, Item3DRow) and
+ itemRow.item() is currentItem and
+ index.flags() & qt.Qt.ItemIsSelectable):
+ # This is the item we are looking for: select it in the model
+ self.__syncInProgress = True
+ selectionModel.select(
+ index, qt.QItemSelectionModel.Clear |
+ qt.QItemSelectionModel.Select |
+ qt.QItemSelectionModel.Current)
+ self.__syncInProgress = False
+ break
+
+
+class SceneWidget(Plot3DWidget):
+ """Widget displaying data sets in 3D"""
+
+ def __init__(self, parent=None):
+ super(SceneWidget, self).__init__(parent)
+ self._model = None # Store lazy-loaded model
+ self._selection = None # Store lazy-loaded SceneSelection
+ self._items = []
+
+ self._textColor = 1., 1., 1., 1.
+ self._foregroundColor = 1., 1., 1., 1.
+ self._highlightColor = 0.7, 0.7, 0., 1.
+
+ self._sceneGroup = RootGroupWithAxesItem(parent=self)
+ self._sceneGroup.setLabel('Data')
+
+ self.viewport.scene.children.append(
+ self._sceneGroup._getScenePrimitive())
+
+ def model(self):
+ """Returns the model corresponding the scene of this widget
+
+ :rtype: SceneModel
+ """
+ if self._model is None:
+ # Lazy-loading of the model
+ self._model = SceneModel(parent=self)
+ return self._model
+
+ def selection(self):
+ """Returns the object managing selection in the scene
+
+ :rtype: SceneSelection
+ """
+ if self._selection is None:
+ # Lazy-loading of the SceneSelection
+ self._selection = SceneSelection(parent=self)
+ return self._selection
+
+ def getSceneGroup(self):
+ """Returns the root group of the scene
+
+ :rtype: GroupItem
+ """
+ return self._sceneGroup
+
+ def pickItems(self, x, y, condition=None):
+ """Iterator over picked items in the scene at given position.
+
+ Each picked item yield a
+ :class:`~silx.gui.plot3d.items._pick.PickingResult` object
+ holding the picking information.
+
+ It traverses the scene tree in a left-to-right top-down way.
+
+ :param int x: X widget coordinate
+ :param int y: Y widget coordinate
+ :param callable condition: Optional test called for each item
+ checking whether to process it or not.
+ """
+ if not self.isValid() or not self.isVisible():
+ return # Empty iterator
+
+ devicePixelRatio = self.getDevicePixelRatio()
+ for result in self.getSceneGroup().pickItems(
+ x * devicePixelRatio, y * devicePixelRatio, condition):
+ yield result
+
+ # Interactive modes
+
+ def _handleSelectionChanged(self, current, previous):
+ """Handle change of selection to update interactive mode"""
+ if self.getInteractiveMode() == 'panSelectedPlane':
+ if isinstance(current, items.PlaneMixIn):
+ # Update pan plane to use new selected plane
+ self.setInteractiveMode('panSelectedPlane')
+
+ else: # Switch to rotate scene if new selection is not a plane
+ self.setInteractiveMode('rotate')
+
+ def setInteractiveMode(self, mode):
+ """Set the interactive mode.
+
+ 'panSelectedPlane' mode set plane panning if a plane is selected,
+ otherwise it fall backs to 'rotate'.
+
+ :param str mode:
+ The interactive mode: 'rotate', 'pan', 'panSelectedPlane' or None
+ """
+ if self.getInteractiveMode() == 'panSelectedPlane':
+ self.selection().sigCurrentChanged.disconnect(
+ self._handleSelectionChanged)
+
+ if mode == 'panSelectedPlane':
+ selected = self.selection().getCurrentItem()
+
+ if isinstance(selected, items.PlaneMixIn):
+ mode = interaction.PanPlaneZoomOnWheelControl(
+ self.viewport,
+ selected._getPlane(),
+ mode='position',
+ orbitAroundCenter=False,
+ scaleTransform=self._sceneScale)
+
+ self.selection().sigCurrentChanged.connect(
+ self._handleSelectionChanged)
+
+ else: # No selected plane, fallback to rotate scene
+ mode = 'rotate'
+
+ super(SceneWidget, self).setInteractiveMode(mode)
+
+ def getInteractiveMode(self):
+ """Returns the interactive mode in use.
+
+ :rtype: str
+ """
+ if isinstance(self.eventHandler, interaction.PanPlaneZoomOnWheelControl):
+ return 'panSelectedPlane'
+ else:
+ return super(SceneWidget, self).getInteractiveMode()
+
+ # Add/remove items
+
+ def addVolume(self, data, copy=True, index=None):
+ """Add 3D data volume of scalar or complex to :class:`SceneWidget` content.
+
+ Dataset order is zyx (i.e., first dimension is z).
+
+ :param data: 3D array of complex with shape at least (2, 2, 2)
+ :type data: numpy.ndarray[Union[numpy.complex64,numpy.float32]]
+ :param bool copy:
+ True (default) to make a copy,
+ False to avoid copy (DO NOT MODIFY data afterwards)
+ :param int index: The index at which to place the item.
+ By default it is appended to the end of the list.
+ :return: The newly created 3D volume item
+ :rtype: Union[ScalarField3D,ComplexField3D]
+
+ """
+ if data is not None:
+ data = numpy.array(data, copy=False)
+
+ if numpy.iscomplexobj(data):
+ volume = items.ComplexField3D()
+ else:
+ volume = items.ScalarField3D()
+ volume.setData(data, copy=copy)
+ self.addItem(volume, index)
+ return volume
+
+ def add3DScalarField(self, data, copy=True, index=None):
+ # TODO deprecate in the future
+ return self.addVolume(data, copy=copy, index=index)
+
+ def add3DScatter(self, x, y, z, value, copy=True, index=None):
+ """Add 3D scatter data to :class:`SceneWidget` content.
+
+ :param numpy.ndarray x: Array of X coordinates (single value not accepted)
+ :param y: Points Y coordinate (array-like or single value)
+ :param z: Points Z coordinate (array-like or single value)
+ :param value: Points values (array-like or single value)
+ :param bool copy:
+ True (default) to copy the data,
+ False to use provided data (do not modify!)
+ :param int index: The index at which to place the item.
+ By default it is appended to the end of the list.
+ :return: The newly created 3D scatter item
+ :rtype: ~silx.gui.plot3d.items.scatter.Scatter3D
+ """
+ scatter3d = items.Scatter3D()
+ scatter3d.setData(x=x, y=y, z=z, value=value, copy=copy)
+ self.addItem(scatter3d, index)
+ return scatter3d
+
+ def add2DScatter(self, x, y, value, copy=True, index=None):
+ """Add 2D scatter data to :class:`SceneWidget` content.
+
+ Provided arrays must have the same length.
+
+ :param numpy.ndarray x: X coordinates (array-like)
+ :param numpy.ndarray y: Y coordinates (array-like)
+ :param value: Points value: array-like or single scalar
+ :param bool copy: True (default) to copy the data,
+ False to use as is (do not modify!).
+ :param int index: The index at which to place the item.
+ By default it is appended to the end of the list.
+ :return: The newly created 2D scatter item
+ :rtype: ~silx.gui.plot3d.items.scatter.Scatter2D
+ """
+ scatter2d = items.Scatter2D()
+ scatter2d.setData(x=x, y=y, value=value, copy=copy)
+ self.addItem(scatter2d, index)
+ return scatter2d
+
+ def addImage(self, data, copy=True, index=None):
+ """Add a 2D data or RGB(A) image to :class:`SceneWidget` content.
+
+ 2D data is casted to float32.
+ RGBA supported formats are: float32 in [0, 1] and uint8.
+
+ :param numpy.ndarray data: Image as a 2D data array or
+ RGBA image as a 3D array (height, width, channels)
+ :param bool copy: True (default) to copy the data,
+ False to use as is (do not modify!).
+ :param int index: The index at which to place the item.
+ By default it is appended to the end of the list.
+ :return: The newly created image item
+ :rtype: ~silx.gui.plot3d.items.image.ImageData or ~silx.gui.plot3d.items.image.ImageRgba
+ :raise ValueError: For arrays of unsupported dimensions
+ """
+ data = numpy.array(data, copy=False)
+ if data.ndim == 2:
+ image = items.ImageData()
+ elif data.ndim == 3:
+ image = items.ImageRgba()
+ else:
+ raise ValueError("Unsupported array dimensions: %d" % data.ndim)
+ image.setData(data, copy=copy)
+ self.addItem(image, index)
+ return image
+
+ def addItem(self, item, index=None):
+ """Add an item to :class:`SceneWidget` content
+
+ :param Item3D item: The item to add
+ :param int index: The index at which to place the item.
+ By default it is appended to the end of the list.
+ :raise ValueError: If the item is already in the :class:`SceneWidget`.
+ """
+ return self.getSceneGroup().addItem(item, index)
+
+ def removeItem(self, item):
+ """Remove an item from :class:`SceneWidget` content.
+
+ :param Item3D item: The item to remove from the scene
+ :raises ValueError: If the item does not belong to the group
+ """
+ return self.getSceneGroup().removeItem(item)
+
+ def getItems(self):
+ """Returns the list of :class:`SceneWidget` items.
+
+ Only items in the top-level group are returned.
+
+ :rtype: tuple
+ """
+ return self.getSceneGroup().getItems()
+
+ def clearItems(self):
+ """Remove all item from :class:`SceneWidget`."""
+ return self.getSceneGroup().clearItems()
+
+ # Colors
+
+ def getTextColor(self):
+ """Return color used for text
+
+ :rtype: QColor"""
+ return qt.QColor.fromRgbF(*self._textColor)
+
+ def setTextColor(self, color):
+ """Set the text color.
+
+ :param color: RGB color: name, #RRGGBB or RGB values
+ :type color:
+ QColor, str or array-like of 3 or 4 float in [0., 1.] or uint8
+ """
+ color = rgba(color)
+ if color != self._textColor:
+ self._textColor = color
+
+ # Update text color
+ # TODO make entry point in Item3D for this
+ bbox = self._sceneGroup._getScenePrimitive()
+ bbox.tickColor = color
+
+ self.sigStyleChanged.emit('textColor')
+
+ def getForegroundColor(self):
+ """Return color used for bounding box
+
+ :rtype: QColor
+ """
+ return qt.QColor.fromRgbF(*self._foregroundColor)
+
+ def setForegroundColor(self, color):
+ """Set the foreground color.
+
+ :param color: RGB color: name, #RRGGBB or RGB values
+ :type color:
+ QColor, str or array-like of 3 or 4 float in [0., 1.] or uint8
+ """
+ color = rgba(color)
+ if color != self._foregroundColor:
+ self._foregroundColor = color
+
+ # Update scene items
+ selected = self.selection().getCurrentItem()
+ for item in self.getSceneGroup().visit(included=True):
+ if item is not selected:
+ item._setForegroundColor(color)
+
+ self.sigStyleChanged.emit('foregroundColor')
+
+ def getHighlightColor(self):
+ """Return color used for highlighted item bounding box
+
+ :rtype: QColor
+ """
+ return qt.QColor.fromRgbF(*self._highlightColor)
+
+ def setHighlightColor(self, color):
+ """Set highlighted item color.
+
+ :param color: RGB color: name, #RRGGBB or RGB values
+ :type color:
+ QColor, str or array-like of 3 or 4 float in [0., 1.] or uint8
+ """
+ color = rgba(color)
+ if color != self._highlightColor:
+ self._highlightColor = color
+
+ selected = self.selection().getCurrentItem()
+ if selected is not None:
+ selected._setForegroundColor(color)
+
+ self.sigStyleChanged.emit('highlightColor')
diff --git a/src/silx/gui/plot3d/SceneWindow.py b/src/silx/gui/plot3d/SceneWindow.py
new file mode 100644
index 0000000..052a4dc
--- /dev/null
+++ b/src/silx/gui/plot3d/SceneWindow.py
@@ -0,0 +1,219 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a QMainWindow with a 3D SceneWidget and toolbars.
+"""
+
+from __future__ import absolute_import
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "29/11/2017"
+
+
+from ...gui import qt, icons
+from ...gui.widgets.BoxLayoutDockWidget import BoxLayoutDockWidget
+
+from .actions.mode import InteractiveModeAction
+from .SceneWidget import SceneWidget
+from .tools import OutputToolBar, InteractiveModeToolBar, ViewpointToolBar
+from .tools.GroupPropertiesWidget import GroupPropertiesWidget
+from .tools.PositionInfoWidget import PositionInfoWidget
+
+from .ParamTreeView import ParamTreeView
+
+# Imported here for convenience
+from . import items # noqa
+
+
+__all__ = ['items', 'SceneWidget', 'SceneWindow']
+
+
+class _PanPlaneAction(InteractiveModeAction):
+ """QAction to set plane pan interaction on a Plot3DWidget
+
+ :param parent: See :class:`QAction`
+ :param ~silx.gui.plot3d.Plot3DWidget.Plot3DWidget plot3d:
+ Plot3DWidget the action is associated with
+ """
+ def __init__(self, parent, plot3d=None):
+ super(_PanPlaneAction, self).__init__(
+ parent, 'panSelectedPlane', plot3d)
+ self.setIcon(icons.getQIcon('3d-plane-pan'))
+ self.setText('Pan plane')
+ self.setCheckable(True)
+ self.setToolTip(
+ 'Pan selected plane. Press <b>Ctrl</b> to rotate the scene.')
+
+ def _planeChanged(self, event):
+ """Handle plane updates"""
+ if event in (items.ItemChangedType.VISIBLE,
+ items.ItemChangedType.POSITION):
+ plane = self.sender()
+
+ isPlaneInteractive = \
+ plane._getPlane().plane.isPlane and plane.isVisible()
+
+ if isPlaneInteractive != self.isEnabled():
+ self.setEnabled(isPlaneInteractive)
+ mode = 'panSelectedPlane' if isPlaneInteractive else 'rotate'
+ self.getPlot3DWidget().setInteractiveMode(mode)
+
+ def _selectionChanged(self, current, previous):
+ """Handle selected object change"""
+ if isinstance(previous, items.PlaneMixIn):
+ previous.sigItemChanged.disconnect(self._planeChanged)
+
+ if isinstance(current, items.PlaneMixIn):
+ current.sigItemChanged.connect(self._planeChanged)
+ self.setEnabled(True)
+ self.getPlot3DWidget().setInteractiveMode('panSelectedPlane')
+ else:
+ self.setEnabled(False)
+
+ def setPlot3DWidget(self, widget):
+ previous = self.getPlot3DWidget()
+ if isinstance(previous, SceneWidget):
+ previous.selection().sigCurrentChanged.disconnect(
+ self._selectionChanged)
+ self._selectionChanged(
+ None, previous.selection().getCurrentItem())
+
+ super(_PanPlaneAction, self).setPlot3DWidget(widget)
+
+ if isinstance(widget, SceneWidget):
+ self._selectionChanged(widget.selection().getCurrentItem(), None)
+ widget.selection().sigCurrentChanged.connect(
+ self._selectionChanged)
+
+
+class SceneWindow(qt.QMainWindow):
+ """OpenGL 3D scene widget with toolbars."""
+
+ def __init__(self, parent=None):
+ super(SceneWindow, self).__init__(parent)
+ if parent is not None:
+ # behave as a widget
+ self.setWindowFlags(qt.Qt.Widget)
+
+ self._sceneWidget = SceneWidget()
+ self.setCentralWidget(self._sceneWidget)
+
+ # Add PositionInfoWidget to display picking info
+ self._positionInfo = PositionInfoWidget()
+ self._positionInfo.setSceneWidget(self._sceneWidget)
+
+ dock = BoxLayoutDockWidget()
+ dock.setWindowTitle("Selection Info")
+ dock.setWidget(self._positionInfo)
+ self.addDockWidget(qt.Qt.BottomDockWidgetArea, dock)
+
+ self._interactiveModeToolBar = InteractiveModeToolBar(parent=self)
+ panPlaneAction = _PanPlaneAction(self, plot3d=self._sceneWidget)
+ self._interactiveModeToolBar.addAction(
+ self._positionInfo.toggleAction())
+ self._interactiveModeToolBar.addAction(panPlaneAction)
+
+ self._viewpointToolBar = ViewpointToolBar(parent=self)
+ self._outputToolBar = OutputToolBar(parent=self)
+
+ for toolbar in (self._interactiveModeToolBar,
+ self._viewpointToolBar,
+ self._outputToolBar):
+ toolbar.setPlot3DWidget(self._sceneWidget)
+ self.addToolBar(toolbar)
+ self.addActions(toolbar.actions())
+
+ self._paramTreeView = ParamTreeView()
+ self._paramTreeView.setModel(self._sceneWidget.model())
+
+ selectionModel = self._paramTreeView.selectionModel()
+ self._sceneWidget.selection()._setSyncSelectionModel(
+ selectionModel)
+
+ paramDock = qt.QDockWidget()
+ paramDock.setWindowTitle('Object parameters')
+ paramDock.setWidget(self._paramTreeView)
+ self.addDockWidget(qt.Qt.RightDockWidgetArea, paramDock)
+
+ self._sceneGroupResetWidget = GroupPropertiesWidget()
+ self._sceneGroupResetWidget.setGroup(
+ self._sceneWidget.getSceneGroup())
+
+ resetDock = qt.QDockWidget()
+ resetDock.setWindowTitle('Global parameters')
+ resetDock.setWidget(self._sceneGroupResetWidget)
+ self.addDockWidget(qt.Qt.RightDockWidgetArea, resetDock)
+ self.tabifyDockWidget(paramDock, resetDock)
+
+ paramDock.raise_()
+
+ def getSceneWidget(self):
+ """Returns the SceneWidget of this window.
+
+ :rtype: ~silx.gui.plot3d.SceneWidget.SceneWidget
+ """
+ return self._sceneWidget
+
+ def getGroupResetWidget(self):
+ """Returns the :class:`GroupPropertiesWidget` of this window.
+
+ :rtype: GroupPropertiesWidget
+ """
+ return self._sceneGroupResetWidget
+
+ def getParamTreeView(self):
+ """Returns the :class:`ParamTreeView` of this window.
+
+ :rtype: ParamTreeView
+ """
+ return self._paramTreeView
+
+ def getInteractiveModeToolBar(self):
+ """Returns the interactive mode toolbar.
+
+ :rtype: ~silx.gui.plot3d.tools.InteractiveModeToolBar
+ """
+ return self._interactiveModeToolBar
+
+ def getViewpointToolBar(self):
+ """Returns the viewpoint toolbar.
+
+ :rtype: ~silx.gui.plot3d.tools.ViewpointToolBar
+ """
+ return self._viewpointToolBar
+
+ def getOutputToolBar(self):
+ """Returns the output toolbar.
+
+ :rtype: ~silx.gui.plot3d.tools.OutputToolBar
+ """
+ return self._outputToolBar
+
+ def getPositionInfoWidget(self):
+ """Returns the widget displaying selected position information.
+
+ :rtype: ~silx.gui.plot3d.tools.PositionInfoWidget.PositionInfoWidget
+ """
+ return self._positionInfo
diff --git a/src/silx/gui/plot3d/__init__.py b/src/silx/gui/plot3d/__init__.py
new file mode 100644
index 0000000..af74613
--- /dev/null
+++ b/src/silx/gui/plot3d/__init__.py
@@ -0,0 +1,40 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+This package provides widgets displaying 3D content based on OpenGL.
+
+It depends on PyOpenGL and PyQtx.QtOpenGL or PyQt>=5.4.
+"""
+from __future__ import absolute_import
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "18/01/2017"
+
+
+try:
+ import OpenGL as _OpenGL
+except ImportError:
+ raise ImportError('PyOpenGL is not installed')
diff --git a/src/silx/gui/plot3d/_model/__init__.py b/src/silx/gui/plot3d/_model/__init__.py
new file mode 100644
index 0000000..4b16e32
--- /dev/null
+++ b/src/silx/gui/plot3d/_model/__init__.py
@@ -0,0 +1,35 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+This package provides :class:`SceneWidget` content and parameters model.
+"""
+
+from __future__ import absolute_import
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "11/01/2018"
+
+from .model import SceneModel, visitQAbstractItemModel # noqa
diff --git a/src/silx/gui/plot3d/_model/core.py b/src/silx/gui/plot3d/_model/core.py
new file mode 100644
index 0000000..e8e0820
--- /dev/null
+++ b/src/silx/gui/plot3d/_model/core.py
@@ -0,0 +1,372 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+This module provides base classes to implement models for 3D scene content.
+"""
+
+from __future__ import absolute_import, division
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "11/01/2018"
+
+
+import collections
+import weakref
+
+from ....utils.weakref import WeakMethodProxy
+from ... import qt
+
+
+class BaseRow(qt.QObject):
+ """Base class for rows of the tree model.
+
+ The root node parent MUST be set to the QAbstractItemModel it belongs to.
+ By default item is enabled.
+
+ :param children: Iterable of BaseRow to start with (not signaled)
+ """
+
+ def __init__(self, children=()):
+ self.__modelRef = None
+ self.__parentRef = None
+ super(BaseRow, self).__init__()
+ self.__children = []
+ for row in children:
+ assert isinstance(row, BaseRow)
+ row.setParent(self)
+ self.__children.append(row)
+ self.__flags = collections.defaultdict(lambda: qt.Qt.ItemIsEnabled)
+ self.__tooltip = None
+
+ def setParent(self, parent):
+ """Override :meth:`QObject.setParent` to cache model and parent"""
+ self.__parentRef = None if parent is None else weakref.ref(parent)
+
+ if isinstance(parent, qt.QAbstractItemModel):
+ model = parent
+ elif isinstance(parent, BaseRow):
+ model = parent.model()
+ else:
+ model = None
+
+ self._updateModel(model)
+
+ super(BaseRow, self).setParent(parent)
+
+ def parent(self):
+ """Override :meth:`QObject.setParent` to use cached parent
+
+ :rtype: Union[QObject, None]"""
+ return self.__parentRef() if self.__parentRef is not None else None
+
+ def _updateModel(self, model):
+ """Update the model this row belongs to"""
+ if model != self.model():
+ self.__modelRef = weakref.ref(model) if model is not None else None
+ for child in self.children():
+ child._updateModel(model)
+
+ def model(self):
+ """Return the model this node belongs to or None if not in a model.
+
+ :rtype: Union[QAbstractItemModel, None]
+ """
+ return self.__modelRef() if self.__modelRef is not None else None
+
+ def index(self, column=0):
+ """Return corresponding index in the model or None if not in a model.
+
+ :param int column: The column to make the index for
+ :rtype: Union[QModelIndex, None]
+ """
+ parent = self.parent()
+ model = self.model()
+
+ if model is None: # Not in a model
+ return None
+ elif parent is model: # Root node
+ return qt.QModelIndex()
+ else:
+ index = parent.index()
+ row = parent.children().index(self)
+ return model.index(row, column, index)
+
+ def columnCount(self):
+ """Returns number of columns (default: 2)
+
+ :rtype: int
+ """
+ return 2
+
+ def children(self):
+ """Returns the list of children nodes
+
+ :rtype: tuple of Node
+ """
+ return tuple(self.__children)
+
+ def rowCount(self):
+ """Returns number of rows
+
+ :rtype: int
+ """
+ return len(self.__children)
+
+ def addRow(self, row, index=None):
+ """Add a node to the children
+
+ :param BaseRow row: The node to add
+ :param int index: The index at which to insert it or
+ None to append
+ """
+ if index is None:
+ index = self.rowCount()
+ assert index <= self.rowCount()
+
+ model = self.model()
+
+ if model is not None:
+ parent = self.index()
+ model.beginInsertRows(parent, index, index)
+
+ self.__children.insert(index, row)
+ row.setParent(self)
+
+ if model is not None:
+ model.endInsertRows()
+
+ def removeRow(self, row):
+ """Remove a row from the children list.
+
+ It removes either a node or a row index.
+
+ :param row: BaseRow object or index of row to remove
+ :type row: Union[BaseRow, int]
+ """
+ if isinstance(row, BaseRow):
+ row = self.__children.index(row)
+ else:
+ row = int(row)
+ assert row < self.rowCount()
+
+ model = self.model()
+
+ if model is not None:
+ index = self.index()
+ model.beginRemoveRows(index, row, row)
+
+ node = self.__children.pop(row)
+ node.setParent(None)
+
+ if model is not None:
+ model.endRemoveRows()
+
+ def data(self, column, role):
+ """Returns data for given column and role
+
+ :param int column: Column index for this row
+ :param int role: The role to get
+ :return: Corresponding data (Default: None)
+ """
+ if role == qt.Qt.ToolTipRole and self.__tooltip is not None:
+ return self.__tooltip
+ else:
+ return None
+
+ def setData(self, column, value, role):
+ """Set data for given column and role
+
+ :param int column: Column index for this row
+ :param value: The data to set
+ :param int role: The role to set
+ :return: True on success, False on failure
+ :rtype: bool
+ """
+ return False
+
+ def setToolTip(self, tooltip):
+ """Set the tooltip of the whole row.
+
+ If None there is no tooltip.
+
+ :param Union[str, None] tooltip:
+ """
+ self.__tooltip = tooltip
+
+ def setFlags(self, flags, column=None):
+ """Set the static flags to return.
+
+ Default is ItemIsEnabled for all columns.
+
+ :param int column: The column for which to set the flags
+ :param flags: Item flags
+ """
+ if column is None:
+ self.__flags = collections.defaultdict(lambda: flags)
+ else:
+ self.__flags[column] = flags
+
+ def flags(self, column):
+ """Returns flags for given column
+
+ :rtype: int
+ """
+ return self.__flags[column]
+
+
+class StaticRow(BaseRow):
+ """Row with static data.
+
+ :param tuple display: List of data for DisplayRole for each column
+ :param dict roles: Optional mapping of roles to list of data.
+ :param children: Iterable of BaseRow to start with (not signaled)
+ """
+
+ def __init__(self, display=('', None), roles=None, children=()):
+ super(StaticRow, self).__init__(children)
+ self._dataByRoles = {} if roles is None else roles
+ self._dataByRoles[qt.Qt.DisplayRole] = display
+
+ def data(self, column, role):
+ if role in self._dataByRoles:
+ data = self._dataByRoles[role]
+ if column < len(data):
+ return data[column]
+ return super(StaticRow, self).data(column, role)
+
+ def columnCount(self):
+ return len(self._dataByRoles[qt.Qt.DisplayRole])
+
+
+class ProxyRow(BaseRow):
+ """Provides a node to proxy a data accessible through functions.
+
+ Warning: Only weak reference are kept on fget and fset.
+
+ :param str name: The name of this node
+ :param callable fget: A callable returning the data
+ :param callable fset:
+ An optional callable setting the data with data as a single argument.
+ :param notify:
+ An optional signal emitted when data has changed.
+ :param callable toModelData:
+ An optional callable to convert from fget
+ callable to data returned by the model.
+ :param callable fromModelData:
+ An optional callable converting data provided to the model to
+ data for fset.
+ :param editorHint: Data to provide as UserRole for editor selection/setup
+ """
+
+ def __init__(self,
+ name='',
+ fget=None,
+ fset=None,
+ notify=None,
+ toModelData=None,
+ fromModelData=None,
+ editorHint=None):
+
+ super(ProxyRow, self).__init__()
+ self.__name = name
+ self.__editorHint = editorHint
+
+ assert fget is not None
+ self._fget = WeakMethodProxy(fget)
+ self._fset = WeakMethodProxy(fset) if fset is not None else None
+ if fset is not None:
+ self.setFlags(qt.Qt.ItemIsEnabled | qt.Qt.ItemIsEditable, 1)
+ self._toModelData = toModelData
+ self._fromModelData = fromModelData
+
+ if notify is not None:
+ notify.connect(self._notified) # TODO support sigItemChanged flags
+
+ def _notified(self, *args, **kwargs):
+ """Send update to the model upon signal notifications"""
+ index = self.index(column=1)
+ model = self.model()
+ if model is not None:
+ model.dataChanged.emit(index, index)
+
+ def data(self, column, role):
+ if column == 0:
+ if role == qt.Qt.DisplayRole:
+ return self.__name
+
+ elif column == 1:
+ if role == qt.Qt.UserRole: # EditorHint
+ return self.__editorHint
+ elif role == qt.Qt.DisplayRole or (role == qt.Qt.EditRole and
+ self._fset is not None):
+ data = self._fget()
+ if self._toModelData is not None:
+ data = self._toModelData(data)
+ return data
+
+ return super(ProxyRow, self).data(column, role)
+
+ def setData(self, column, value, role):
+ if role == qt.Qt.EditRole and self._fset is not None:
+ if self._fromModelData is not None:
+ value = self._fromModelData(value)
+ self._fset(value)
+ return True
+
+ return super(ProxyRow, self).setData(column, value, role)
+
+
+class ColorProxyRow(ProxyRow):
+ """Provides a proxy to a QColor property.
+
+ The color is returned through the decorative role.
+
+ See :class:`ProxyRow`
+ """
+
+ def data(self, column, role):
+ if column == 1: # Show color as decoration, not text
+ if role == qt.Qt.DisplayRole:
+ return None
+ if role == qt.Qt.DecorationRole:
+ role = qt.Qt.DisplayRole
+ return super(ColorProxyRow, self).data(column, role)
+
+
+class AngleDegreeRow(ProxyRow):
+ """ProxyRow patching display of column 1 to add degree symbol
+
+ See :class:`ProxyRow`
+ """
+
+ def __init__(self, *args, **kwargs):
+ super(AngleDegreeRow, self).__init__(*args, **kwargs)
+
+ def data(self, column, role):
+ if column == 1 and role == qt.Qt.DisplayRole:
+ return u'%g°' % super(AngleDegreeRow, self).data(column, role)
+ else:
+ return super(AngleDegreeRow, self).data(column, role)
diff --git a/src/silx/gui/plot3d/_model/items.py b/src/silx/gui/plot3d/_model/items.py
new file mode 100644
index 0000000..492f44b
--- /dev/null
+++ b/src/silx/gui/plot3d/_model/items.py
@@ -0,0 +1,1759 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+This module provides base classes to implement models for 3D scene content
+"""
+
+from __future__ import absolute_import, division
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "24/04/2018"
+
+
+from collections import OrderedDict
+import functools
+import logging
+import weakref
+
+import numpy
+
+from ...utils.image import convertArrayToQImage
+from ...colors import preferredColormaps
+from ... import qt, icons
+from .. import items
+from ..items.volume import Isosurface, CutPlane, ComplexIsosurface
+from ..Plot3DWidget import Plot3DWidget
+
+
+from .core import AngleDegreeRow, BaseRow, ColorProxyRow, ProxyRow, StaticRow
+
+
+_logger = logging.getLogger(__name__)
+
+
+class ItemProxyRow(ProxyRow):
+ """Provides a node to proxy a data accessible through functions.
+
+ It listens on sigItemChanged to trigger the update.
+
+ Warning: Only weak reference are kept on fget and fset.
+
+ :param Item3D item: The item to
+ :param str name: The name of this node
+ :param callable fget: A callable returning the data
+ :param callable fset:
+ An optional callable setting the data with data as a single argument.
+ :param events:
+ An optional event kind or list of event kinds to react upon.
+ :param callable toModelData:
+ An optional callable to convert from fget
+ callable to data returned by the model.
+ :param callable fromModelData:
+ An optional callable converting data provided to the model to
+ data for fset.
+ :param editorHint: Data to provide as UserRole for editor selection/setup
+ """
+
+ def __init__(self,
+ item,
+ name='',
+ fget=None,
+ fset=None,
+ events=None,
+ toModelData=None,
+ fromModelData=None,
+ editorHint=None):
+ super(ItemProxyRow, self).__init__(
+ name=name,
+ fget=fget,
+ fset=fset,
+ notify=None,
+ toModelData=toModelData,
+ fromModelData=fromModelData,
+ editorHint=editorHint)
+
+ if isinstance(events, (items.ItemChangedType,
+ items.Item3DChangedType)):
+ events = (events,)
+ self.__events = events
+ item.sigItemChanged.connect(self._itemChanged)
+
+ def _itemChanged(self, event):
+ """Handle item changed
+
+ :param Union[ItemChangedType,Item3DChangedType] event:
+ """
+ if self.__events is None or event in self.__events:
+ self._notified()
+
+
+class ItemColorProxyRow(ColorProxyRow, ItemProxyRow):
+ """Combines :class:`ColorProxyRow` and :class:`ItemProxyRow`"""
+
+ def __init__(self, *args, **kwargs):
+ ItemProxyRow.__init__(self, *args, **kwargs)
+
+
+class ItemAngleDegreeRow(AngleDegreeRow, ItemProxyRow):
+ """Combines :class:`AngleDegreeRow` and :class:`ItemProxyRow`"""
+
+ def __init__(self, *args, **kwargs):
+ ItemProxyRow.__init__(self, *args, **kwargs)
+
+
+class _DirectionalLightProxy(qt.QObject):
+ """Proxy to handle directional light with angles rather than vector.
+ """
+
+ sigAzimuthAngleChanged = qt.Signal()
+ """Signal sent when the azimuth angle has changed."""
+
+ sigAltitudeAngleChanged = qt.Signal()
+ """Signal sent when altitude angle has changed."""
+
+ def __init__(self, light):
+ super(_DirectionalLightProxy, self).__init__()
+ self._light = light
+ light.addListener(self._directionUpdated)
+ self._azimuth = 0
+ self._altitude = 0
+
+ def getAzimuthAngle(self):
+ """Returns the signed angle in the horizontal plane.
+
+ Unit: degrees.
+ The 0 angle corresponds to the axis perpendicular to the screen.
+
+ :rtype: int
+ """
+ return self._azimuth
+
+ def getAltitudeAngle(self):
+ """Returns the signed vertical angle from the horizontal plane.
+
+ Unit: degrees.
+ Range: [-90, +90]
+
+ :rtype: int
+ """
+ return self._altitude
+
+ def setAzimuthAngle(self, angle):
+ """Set the horizontal angle.
+
+ :param int angle: Angle from -z axis in zx plane in degrees.
+ """
+ angle = int(round(angle))
+ if angle != self._azimuth:
+ self._azimuth = angle
+ self._updateLight()
+ self.sigAzimuthAngleChanged.emit()
+
+ def setAltitudeAngle(self, angle):
+ """Set the horizontal angle.
+
+ :param int angle: Angle from -z axis in zy plane in degrees.
+ """
+ angle = int(round(angle))
+ if angle != self._altitude:
+ self._altitude = angle
+ self._updateLight()
+ self.sigAltitudeAngleChanged.emit()
+
+ def _directionUpdated(self, *args, **kwargs):
+ """Handle light direction update in the scene"""
+ # Invert direction to manipulate the 'source' pointing to
+ # the center of the viewport
+ x, y, z = - self._light.direction
+
+ # Horizontal plane is plane xz
+ azimuth = int(round(numpy.degrees(numpy.arctan2(x, z))))
+ altitude = int(round(numpy.degrees(numpy.pi/2. - numpy.arccos(y))))
+
+ if azimuth != self.getAzimuthAngle():
+ self.setAzimuthAngle(azimuth)
+
+ if altitude != self.getAltitudeAngle():
+ self.setAltitudeAngle(altitude)
+
+ def _updateLight(self):
+ """Update light direction in the scene"""
+ azimuth = numpy.radians(self._azimuth)
+ delta = numpy.pi/2. - numpy.radians(self._altitude)
+ if delta == 0.: # Avoids zenith position
+ delta = 0.0001
+ z = - numpy.sin(delta) * numpy.cos(azimuth)
+ x = - numpy.sin(delta) * numpy.sin(azimuth)
+ y = - numpy.cos(delta)
+ self._light.direction = x, y, z
+
+
+class Settings(StaticRow):
+ """Subtree for :class:`SceneWidget` style parameters.
+
+ :param SceneWidget sceneWidget: The widget to control
+ """
+
+ def __init__(self, sceneWidget):
+ background = ColorProxyRow(
+ name='Background',
+ fget=sceneWidget.getBackgroundColor,
+ fset=sceneWidget.setBackgroundColor,
+ notify=sceneWidget.sigStyleChanged)
+
+ foreground = ColorProxyRow(
+ name='Foreground',
+ fget=sceneWidget.getForegroundColor,
+ fset=sceneWidget.setForegroundColor,
+ notify=sceneWidget.sigStyleChanged)
+
+ text = ColorProxyRow(
+ name='Text',
+ fget=sceneWidget.getTextColor,
+ fset=sceneWidget.setTextColor,
+ notify=sceneWidget.sigStyleChanged)
+
+ highlight = ColorProxyRow(
+ name='Highlight',
+ fget=sceneWidget.getHighlightColor,
+ fset=sceneWidget.setHighlightColor,
+ notify=sceneWidget.sigStyleChanged)
+
+ axesIndicator = ProxyRow(
+ name='Axes Indicator',
+ fget=sceneWidget.isOrientationIndicatorVisible,
+ fset=sceneWidget.setOrientationIndicatorVisible,
+ notify=sceneWidget.sigStyleChanged)
+
+ # Light direction
+
+ self._lightProxy = _DirectionalLightProxy(sceneWidget.viewport.light)
+
+ azimuthNode = ProxyRow(
+ name='Azimuth',
+ fget=self._lightProxy.getAzimuthAngle,
+ fset=self._lightProxy.setAzimuthAngle,
+ notify=self._lightProxy.sigAzimuthAngleChanged,
+ editorHint=(-90, 90))
+
+ altitudeNode = ProxyRow(
+ name='Altitude',
+ fget=self._lightProxy.getAltitudeAngle,
+ fset=self._lightProxy.setAltitudeAngle,
+ notify=self._lightProxy.sigAltitudeAngleChanged,
+ editorHint=(-90, 90))
+
+ lightDirection = StaticRow(('Light Direction', None),
+ children=(azimuthNode, altitudeNode))
+
+ # Fog
+ fog = ProxyRow(
+ name='Fog',
+ fget=sceneWidget.getFogMode,
+ fset=sceneWidget.setFogMode,
+ notify=sceneWidget.sigStyleChanged,
+ toModelData=lambda mode: mode is Plot3DWidget.FogMode.LINEAR,
+ fromModelData=lambda mode: Plot3DWidget.FogMode.LINEAR if mode else Plot3DWidget.FogMode.NONE)
+
+ # Settings row
+ children = (background, foreground, text, highlight,
+ axesIndicator, lightDirection, fog)
+ super(Settings, self).__init__(('Settings', None), children=children)
+
+
+class Item3DRow(BaseRow):
+ """Represents an :class:`Item3D` with checkable visibility
+
+ :param Item3D item: The scene item to represent.
+ :param str name: The optional name of the item
+ """
+
+ _EVENTS = items.ItemChangedType.VISIBLE, items.Item3DChangedType.LABEL
+ """Events for which to update the first column in the tree"""
+
+ def __init__(self, item, name=None):
+ self.__name = None if name is None else str(name)
+ super(Item3DRow, self).__init__()
+
+ self.setFlags(
+ self.flags(0) | qt.Qt.ItemIsUserCheckable | qt.Qt.ItemIsSelectable,
+ 0)
+ self.setFlags(self.flags(1) | qt.Qt.ItemIsSelectable, 1)
+
+ self._item = weakref.ref(item)
+ item.sigItemChanged.connect(self._itemChanged)
+
+ def _itemChanged(self, event):
+ """Handle model update upon change"""
+ if event in self._EVENTS:
+ model = self.model()
+ if model is not None:
+ index = self.index(column=0)
+ model.dataChanged.emit(index, index)
+
+ def item(self):
+ """Returns the :class:`Item3D` item or None"""
+ return self._item()
+
+ def data(self, column, role):
+ if column == 0:
+ if role == qt.Qt.CheckStateRole:
+ item = self.item()
+ if item is not None and item.isVisible():
+ return qt.Qt.Checked
+ else:
+ return qt.Qt.Unchecked
+
+ elif role == qt.Qt.DecorationRole:
+ return icons.getQIcon('item-3dim')
+
+ elif role == qt.Qt.DisplayRole:
+ if self.__name is None:
+ item = self.item()
+ return '' if item is None else item.getLabel()
+ else:
+ return self.__name
+
+ return super(Item3DRow, self).data(column, role)
+
+ def setData(self, column, value, role):
+ if column == 0 and role == qt.Qt.CheckStateRole:
+ item = self.item()
+ if item is not None:
+ item.setVisible(value == qt.Qt.Checked)
+ return True
+ else:
+ return False
+ return super(Item3DRow, self).setData(column, value, role)
+
+ def columnCount(self):
+ return 2
+
+
+class DataItem3DBoundingBoxRow(ItemProxyRow):
+ """Represents :class:`DataItem3D` bounding box visibility
+
+ :param DataItem3D item: The item for which to display/control bounding box
+ """
+
+ def __init__(self, item):
+ super(DataItem3DBoundingBoxRow, self).__init__(
+ item=item,
+ name='Bounding box',
+ fget=item.isBoundingBoxVisible,
+ fset=item.setBoundingBoxVisible,
+ events=items.Item3DChangedType.BOUNDING_BOX_VISIBLE)
+
+
+class MatrixProxyRow(ItemProxyRow):
+ """Proxy for a row of a DataItem3D 3x3 matrix transform
+
+ :param DataItem3D item:
+ :param int index: Matrix row index
+ """
+
+ def __init__(self, item, index):
+ self._item = weakref.ref(item)
+ self._index = index
+
+ super(MatrixProxyRow, self).__init__(
+ item=item,
+ name='',
+ fget=self._getMatrixRow,
+ fset=self._setMatrixRow,
+ events=items.Item3DChangedType.TRANSFORM)
+
+ def _getMatrixRow(self):
+ """Returns the matrix row.
+
+ :rtype: QVector3D
+ """
+ item = self._item()
+ if item is not None:
+ matrix = item.getMatrix()
+ return qt.QVector3D(*matrix[self._index, :])
+ else:
+ return None
+
+ def _setMatrixRow(self, row):
+ """Set the row of the matrix
+
+ :param QVector3D row: Row values to set
+ """
+ item = self._item()
+ if item is not None:
+ matrix = item.getMatrix()
+ matrix[self._index, :] = row.x(), row.y(), row.z()
+ item.setMatrix(matrix)
+
+ def data(self, column, role):
+ data = super(MatrixProxyRow, self).data(column, role)
+
+ if column == 1 and role == qt.Qt.DisplayRole:
+ # Convert QVector3D to text
+ data = "%g; %g; %g" % (data.x(), data.y(), data.z())
+
+ return data
+
+
+class DataItem3DTransformRow(StaticRow):
+ """Represents :class:`DataItem3D` transform parameters
+
+ :param DataItem3D item: The item for which to display/control transform
+ """
+
+ _ROTATION_CENTER_OPTIONS = 'Origin', 'Lower', 'Center', 'Upper'
+
+ def __init__(self, item):
+ super(DataItem3DTransformRow, self).__init__(('Transform', None))
+ self._item = weakref.ref(item)
+
+ translation = ItemProxyRow(
+ item=item,
+ name='Translation',
+ fget=item.getTranslation,
+ fset=self._setTranslation,
+ events=items.Item3DChangedType.TRANSFORM,
+ toModelData=lambda data: qt.QVector3D(*data))
+ self.addRow(translation)
+
+ # Here to keep a reference
+ self._xSetCenter = functools.partial(self._setCenter, index=0)
+ self._ySetCenter = functools.partial(self._setCenter, index=1)
+ self._zSetCenter = functools.partial(self._setCenter, index=2)
+
+ rotateCenter = StaticRow(
+ ('Center', None),
+ children=(
+ ItemProxyRow(item=item,
+ name='X axis',
+ fget=item.getRotationCenter,
+ fset=self._xSetCenter,
+ events=items.Item3DChangedType.TRANSFORM,
+ toModelData=functools.partial(
+ self._centerToModelData, index=0),
+ editorHint=self._ROTATION_CENTER_OPTIONS),
+ ItemProxyRow(item=item,
+ name='Y axis',
+ fget=item.getRotationCenter,
+ fset=self._ySetCenter,
+ events=items.Item3DChangedType.TRANSFORM,
+ toModelData=functools.partial(
+ self._centerToModelData, index=1),
+ editorHint=self._ROTATION_CENTER_OPTIONS),
+ ItemProxyRow(item=item,
+ name='Z axis',
+ fget=item.getRotationCenter,
+ fset=self._zSetCenter,
+ events=items.Item3DChangedType.TRANSFORM,
+ toModelData=functools.partial(
+ self._centerToModelData, index=2),
+ editorHint=self._ROTATION_CENTER_OPTIONS),
+ ))
+
+ rotate = StaticRow(
+ ('Rotation', None),
+ children=(
+ ItemAngleDegreeRow(
+ item=item,
+ name='Angle',
+ fget=item.getRotation,
+ fset=self._setAngle,
+ events=items.Item3DChangedType.TRANSFORM,
+ toModelData=lambda data: data[0]),
+ ItemProxyRow(
+ item=item,
+ name='Axis',
+ fget=item.getRotation,
+ fset=self._setAxis,
+ events=items.Item3DChangedType.TRANSFORM,
+ toModelData=lambda data: qt.QVector3D(*data[1])),
+ rotateCenter
+ ))
+ self.addRow(rotate)
+
+ scale = ItemProxyRow(
+ item=item,
+ name='Scale',
+ fget=item.getScale,
+ fset=self._setScale,
+ events=items.Item3DChangedType.TRANSFORM,
+ toModelData=lambda data: qt.QVector3D(*data))
+ self.addRow(scale)
+
+ matrix = StaticRow(
+ ('Matrix', None),
+ children=(MatrixProxyRow(item, 0),
+ MatrixProxyRow(item, 1),
+ MatrixProxyRow(item, 2)))
+ self.addRow(matrix)
+
+ def item(self):
+ """Returns the :class:`Item3D` item or None"""
+ return self._item()
+
+ @staticmethod
+ def _centerToModelData(center, index):
+ """Convert rotation center information from scene to model.
+
+ :param center: The center info from the scene
+ :param int index: dimension to convert
+ """
+ value = center[index]
+ if isinstance(value, str):
+ return value.title()
+ elif value == 0.:
+ return 'Origin'
+ else:
+ return str(value)
+
+ def _setCenter(self, value, index):
+ """Set one dimension of the rotation center.
+
+ :param value: Value received through the model.
+ :param int index: dimension to set
+ """
+ item = self.item()
+ if item is not None:
+ if value == 'Origin':
+ value = 0.
+ elif value not in self._ROTATION_CENTER_OPTIONS:
+ value = float(value)
+ else:
+ value = value.lower()
+
+ center = list(item.getRotationCenter())
+ center[index] = value
+ item.setRotationCenter(*center)
+
+ def _setAngle(self, angle):
+ """Set rotation angle.
+
+ :param float angle:
+ """
+ item = self.item()
+ if item is not None:
+ _, axis = item.getRotation()
+ item.setRotation(angle, axis)
+
+ def _setAxis(self, axis):
+ """Set rotation axis.
+
+ :param QVector3D axis:
+ """
+ item = self.item()
+ if item is not None:
+ angle, _ = item.getRotation()
+ item.setRotation(angle, (axis.x(), axis.y(), axis.z()))
+
+ def _setTranslation(self, translation):
+ """Set translation transform.
+
+ :param QVector3D translation:
+ """
+ item = self.item()
+ if item is not None:
+ item.setTranslation(translation.x(), translation.y(), translation.z())
+
+ def _setScale(self, scale):
+ """Set scale transform.
+
+ :param QVector3D scale:
+ """
+ item = self.item()
+ if item is not None:
+ sx, sy, sz = scale.x(), scale.y(), scale.z()
+ if sx == 0. or sy == 0. or sz == 0.:
+ _logger.warning('Cannot set scale to 0: ignored')
+ else:
+ item.setScale(scale.x(), scale.y(), scale.z())
+
+
+class GroupItemRow(Item3DRow):
+ """Represents a :class:`GroupItem` with transforms and children
+
+ :param GroupItem item: The scene group to represent.
+ :param str name: The optional name of the group
+ """
+
+ _CHILDREN_ROW_OFFSET = 2
+ """Number of rows for group parameters. Children are added after"""
+
+ def __init__(self, item, name=None):
+ super(GroupItemRow, self).__init__(item, name)
+ self.addRow(DataItem3DBoundingBoxRow(item))
+ self.addRow(DataItem3DTransformRow(item))
+
+ item.sigItemAdded.connect(self._itemAdded)
+ item.sigItemRemoved.connect(self._itemRemoved)
+
+ for child in item.getItems():
+ self.addRow(nodeFromItem(child))
+
+ def _itemAdded(self, item):
+ """Handle item addition to the group and add it to the model.
+
+ :param Item3D item: added item
+ """
+ group = self.item()
+ if group is None:
+ return
+
+ row = group.getItems().index(item)
+ self.addRow(nodeFromItem(item), row + self._CHILDREN_ROW_OFFSET)
+
+ def _itemRemoved(self, item):
+ """Handle item removal from the group and remove it from the model.
+
+ :param Item3D item: removed item
+ """
+ group = self.item()
+ if group is None:
+ return
+
+ # Find item
+ for row in self.children():
+ if isinstance(row, Item3DRow) and row.item() is item:
+ self.removeRow(row)
+ break # Got it
+ else:
+ raise RuntimeError("Model does not correspond to scene content")
+
+
+class InterpolationRow(ItemProxyRow):
+ """Represents :class:`InterpolationMixIn` property.
+
+ :param Item3D item: Scene item with interpolation property
+ """
+
+ def __init__(self, item):
+ modes = [mode.title() for mode in item.INTERPOLATION_MODES]
+ super(InterpolationRow, self).__init__(
+ item=item,
+ name='Interpolation',
+ fget=item.getInterpolation,
+ fset=item.setInterpolation,
+ events=items.Item3DChangedType.INTERPOLATION,
+ toModelData=lambda mode: mode.title(),
+ fromModelData=lambda mode: mode.lower(),
+ editorHint=modes)
+
+
+class _ColormapBaseProxyRow(ProxyRow):
+ """Base class for colormap model row
+
+ This class handle synchronization and signals from the item and the colormap
+ """
+
+ _sigColormapChanged = qt.Signal()
+ """Signal used internally to notify colormap (or data) update"""
+
+ def __init__(self, item, *args, **kwargs):
+ self._item = weakref.ref(item)
+ self._colormap = item.getColormap()
+
+ ProxyRow.__init__(self, *args, **kwargs)
+
+ self._colormap.sigChanged.connect(self._colormapChanged)
+ item.sigItemChanged.connect(self._itemChanged)
+ self._sigColormapChanged.connect(self._modelUpdated)
+
+ def item(self):
+ """Returns the :class:`ColormapMixIn` item or None"""
+ return self._item()
+
+ def _getColormapRange(self):
+ """Returns the range of the colormap for the current data.
+
+ :return: Colormap range (min, max)
+ """
+ item = self.item()
+ if item is not None and self._colormap is not None:
+ return self._colormap.getColormapRange(item)
+ else:
+ return 1, 100 # Fallback
+
+ def _modelUpdated(self, *args, **kwargs):
+ """Emit dataChanged in the model"""
+ topLeft = self.index(column=0)
+ bottomRight = self.index(column=1)
+ model = self.model()
+ if model is not None:
+ model.dataChanged.emit(topLeft, bottomRight)
+
+ def _colormapChanged(self):
+ self._sigColormapChanged.emit()
+
+ def _itemChanged(self, event):
+ """Handle change of colormap or data in the item.
+
+ :param ItemChangedType event:
+ """
+ if event == items.ItemChangedType.COLORMAP:
+ self._sigColormapChanged.emit()
+ if self._colormap is not None:
+ self._colormap.sigChanged.disconnect(self._colormapChanged)
+
+ item = self.item()
+ if item is not None:
+ self._colormap = item.getColormap()
+ self._colormap.sigChanged.connect(self._colormapChanged)
+ else:
+ self._colormap = None
+
+ elif event == items.ItemChangedType.DATA:
+ self._sigColormapChanged.emit()
+
+
+class _ColormapBoundRow(_ColormapBaseProxyRow):
+ """ProxyRow for colormap min or max
+
+ :param ColormapMixIn item: The item to handle
+ :param str name: Name of the raw
+ :param int index: 0 for Min and 1 of Max
+ """
+
+ def __init__(self, item, name, index):
+ self._index = index
+ _ColormapBaseProxyRow.__init__(
+ self,
+ item,
+ name=name,
+ fget=self._getBound,
+ fset=self._setBound)
+
+ self.setToolTip('Colormap %s bound:\n'
+ 'Check to set bound manually, '
+ 'uncheck for autoscale' % name.lower())
+
+ def _getRawBound(self):
+ """Proxy to get raw colormap bound
+
+ :rtype: float or None
+ """
+ if self._colormap is None:
+ return None
+ elif self._index == 0:
+ return self._colormap.getVMin()
+ else: # self._index == 1
+ return self._colormap.getVMax()
+
+ def _getBound(self):
+ """Proxy to get colormap effective bound value
+
+ :rtype: float
+ """
+ if self._colormap is not None:
+ bound = self._getRawBound()
+
+ if bound is None:
+ bound = self._getColormapRange()[self._index]
+ return bound
+ else:
+ return 1. # Fallback
+
+ def _setBound(self, value):
+ """Proxy to set colormap bound.
+
+ :param float value:
+ """
+ if self._colormap is not None:
+ if self._index == 0:
+ min_ = value
+ max_ = self._colormap.getVMax()
+ else: # self._index == 1
+ min_ = self._colormap.getVMin()
+ max_ = value
+
+ if max_ is not None and min_ is not None and min_ > max_:
+ min_, max_ = max_, min_
+ self._colormap.setVRange(min_, max_)
+
+ def flags(self, column):
+ if column == 0:
+ return qt.Qt.ItemIsEnabled | qt.Qt.ItemIsUserCheckable
+
+ elif column == 1:
+ if self._getRawBound() is not None:
+ flags = qt.Qt.ItemIsEditable | qt.Qt.ItemIsEnabled
+ else:
+ flags = qt.Qt.NoItemFlags # Disabled if autoscale
+ return flags
+
+ else: # Never event
+ return super(_ColormapBoundRow, self).flags(column)
+
+ def data(self, column, role):
+ if column == 0 and role == qt.Qt.CheckStateRole:
+ if self._getRawBound() is None:
+ return qt.Qt.Unchecked
+ else:
+ return qt.Qt.Checked
+
+ else:
+ return super(_ColormapBoundRow, self).data(column, role)
+
+ def setData(self, column, value, role):
+ if column == 0 and role == qt.Qt.CheckStateRole:
+ if self._colormap is not None:
+ bound = self._getBound() if value == qt.Qt.Checked else None
+ self._setBound(bound)
+ return True
+ else:
+ return False
+
+ return super(_ColormapBoundRow, self).setData(column, value, role)
+
+
+class _ColormapGammaRow(_ColormapBaseProxyRow):
+ """ProxyRow for colormap gamma normalization parameter
+
+ :param ColormapMixIn item: The item to handle
+ :param str name: Name of the raw
+ """
+
+ def __init__(self, item):
+ _ColormapBaseProxyRow.__init__(
+ self,
+ item,
+ name="Gamma",
+ fget=self._getGammaNormalizationParameter,
+ fset=self._setGammaNormalizationParameter)
+
+ self.setToolTip('Colormap gamma correction parameter:\n'
+ 'Only meaningful for gamma normalization.')
+
+ def _getGammaNormalizationParameter(self):
+ """Proxy for :meth:`Colormap.getGammaNormalizationParameter`"""
+ if self._colormap is not None:
+ return self._colormap.getGammaNormalizationParameter()
+ else:
+ return 0.0
+
+ def _setGammaNormalizationParameter(self, gamma):
+ """Proxy for :meth:`Colormap.setGammaNormalizationParameter`"""
+ if self._colormap is not None:
+ return self._colormap.setGammaNormalizationParameter(gamma)
+
+ def _getNormalization(self):
+ """Proxy for :meth:`Colormap.getNormalization`"""
+ if self._colormap is not None:
+ return self._colormap.getNormalization()
+ else:
+ return ''
+
+ def flags(self, column):
+ if column in (0, 1):
+ if self._getNormalization() == 'gamma':
+ flags = qt.Qt.ItemIsEditable | qt.Qt.ItemIsEnabled
+ else:
+ flags = qt.Qt.NoItemFlags # Disabled if not gamma correction
+ return flags
+
+ else: # Never event
+ return super(_ColormapGammaRow, self).flags(column)
+
+
+class ColormapRow(_ColormapBaseProxyRow):
+ """Represents :class:`ColormapMixIn` property.
+
+ :param Item3D item: Scene item with colormap property
+ """
+
+ def __init__(self, item):
+ super(ColormapRow, self).__init__(
+ item,
+ name='Colormap',
+ fget=self._get)
+
+ self._colormapImage = None
+
+ self._colormapsMapping = {}
+ for cmap in preferredColormaps():
+ self._colormapsMapping[cmap.title()] = cmap
+
+ self.addRow(ProxyRow(
+ name='Name',
+ fget=self._getName,
+ fset=self._setName,
+ notify=self._sigColormapChanged,
+ editorHint=list(self._colormapsMapping.keys())))
+
+ norms = [norm.title() for norm in self._colormap.NORMALIZATIONS]
+ self.addRow(ProxyRow(
+ name='Normalization',
+ fget=self._getNormalization,
+ fset=self._setNormalization,
+ notify=self._sigColormapChanged,
+ editorHint=norms))
+
+ self.addRow(_ColormapGammaRow(item))
+
+ modes = [mode.title() for mode in self._colormap.AUTOSCALE_MODES]
+ self.addRow(ProxyRow(
+ name='Autoscale Mode',
+ fget=self._getAutoscaleMode,
+ fset=self._setAutoscaleMode,
+ notify=self._sigColormapChanged,
+ editorHint=modes))
+
+ self.addRow(_ColormapBoundRow(item, name='Min.', index=0))
+ self.addRow(_ColormapBoundRow(item, name='Max.', index=1))
+
+ self._sigColormapChanged.connect(self._updateColormapImage)
+
+ def getColormapImage(self):
+ """Returns image representing the colormap or None
+
+ :rtype: Union[QImage,None]
+ """
+ if self._colormapImage is None and self._colormap is not None:
+ image = numpy.zeros((16, 130, 3), dtype=numpy.uint8)
+ image[1:-1, 1:-1] = self._colormap.getNColors(image.shape[1] - 2)[:, :3]
+ self._colormapImage = convertArrayToQImage(image)
+ return self._colormapImage
+
+ def _get(self):
+ """Getter for ProxyRow subclass"""
+ return None
+
+ def _getName(self):
+ """Proxy for :meth:`Colormap.getName`"""
+ if self._colormap is not None and self._colormap.getName() is not None:
+ return self._colormap.getName().title()
+ else:
+ return ''
+
+ def _setName(self, name):
+ """Proxy for :meth:`Colormap.setName`"""
+ # Convert back from titled to name if possible
+ if self._colormap is not None:
+ name = self._colormapsMapping.get(name, name)
+ self._colormap.setName(name)
+
+ def _getNormalization(self):
+ """Proxy for :meth:`Colormap.getNormalization`"""
+ if self._colormap is not None:
+ return self._colormap.getNormalization().title()
+ else:
+ return ''
+
+ def _setNormalization(self, normalization):
+ """Proxy for :meth:`Colormap.setNormalization`"""
+ if self._colormap is not None:
+ return self._colormap.setNormalization(normalization.lower())
+
+ def _getAutoscaleMode(self):
+ """Proxy for :meth:`Colormap.getAutoscaleMode`"""
+ if self._colormap is not None:
+ return self._colormap.getAutoscaleMode().title()
+ else:
+ return ''
+
+ def _setAutoscaleMode(self, mode):
+ """Proxy for :meth:`Colormap.setAutoscaleMode`"""
+ if self._colormap is not None:
+ return self._colormap.setAutoscaleMode(mode.lower())
+
+ def _updateColormapImage(self, *args, **kwargs):
+ """Notify colormap update to update the image in the tree"""
+ if self._colormapImage is not None:
+ self._colormapImage = None
+ model = self.model()
+ if model is not None:
+ index = self.index(column=1)
+ model.dataChanged.emit(index, index)
+
+ def data(self, column, role):
+ if column == 1 and role == qt.Qt.DecorationRole:
+ return self.getColormapImage()
+ else:
+ return super(ColormapRow, self).data(column, role)
+
+
+class SymbolRow(ItemProxyRow):
+ """Represents :class:`SymbolMixIn` symbol property.
+
+ :param Item3D item: Scene item with symbol property
+ """
+
+ def __init__(self, item):
+ names = [item.getSymbolName(s) for s in item.getSupportedSymbols()]
+ super(SymbolRow, self).__init__(
+ item=item,
+ name='Marker',
+ fget=item.getSymbolName,
+ fset=item.setSymbol,
+ events=items.ItemChangedType.SYMBOL,
+ editorHint=names)
+
+
+class SymbolSizeRow(ItemProxyRow):
+ """Represents :class:`SymbolMixIn` symbol size property.
+
+ :param Item3D item: Scene item with symbol size property
+ """
+
+ def __init__(self, item):
+ super(SymbolSizeRow, self).__init__(
+ item=item,
+ name='Marker size',
+ fget=item.getSymbolSize,
+ fset=item.setSymbolSize,
+ events=items.ItemChangedType.SYMBOL_SIZE,
+ editorHint=(1, 20)) # TODO link with OpenGL max point size
+
+
+class PlaneEquationRow(ItemProxyRow):
+ """Represents :class:`PlaneMixIn` as plane equation.
+
+ :param Item3D item: Scene item with plane equation property
+ """
+
+ def __init__(self, item):
+ super(PlaneEquationRow, self).__init__(
+ item=item,
+ name='Equation',
+ fget=item.getParameters,
+ fset=item.setParameters,
+ events=items.ItemChangedType.POSITION,
+ toModelData=lambda data: qt.QVector4D(*data),
+ fromModelData=lambda data: (data.x(), data.y(), data.z(), data.w()))
+ self._item = weakref.ref(item)
+
+ def data(self, column, role):
+ if column == 1 and role == qt.Qt.DisplayRole:
+ item = self._item()
+ if item is not None:
+ params = item.getParameters()
+ return ('%gx %+gy %+gz %+g = 0' %
+ (params[0], params[1], params[2], params[3]))
+ return super(PlaneEquationRow, self).data(column, role)
+
+
+class PlaneRow(ItemProxyRow):
+ """Represents :class:`PlaneMixIn` property.
+
+ :param Item3D item: Scene item with plane equation property
+ """
+
+ _PLANES = OrderedDict((('Plane 0', (1., 0., 0.)),
+ ('Plane 1', (0., 1., 0.)),
+ ('Plane 2', (0., 0., 1.)),
+ ('-', None)))
+ """Mapping of plane names to normals"""
+
+ _PLANE_ICONS = {'Plane 0': '3d-plane-normal-x',
+ 'Plane 1': '3d-plane-normal-y',
+ 'Plane 2': '3d-plane-normal-z',
+ '-': '3d-plane'}
+ """Mapping of plane names to normals"""
+
+ def __init__(self, item):
+ super(PlaneRow, self).__init__(
+ item=item,
+ name='Plane',
+ fget=self.__getPlaneName,
+ fset=self.__setPlaneName,
+ events=items.ItemChangedType.POSITION,
+ editorHint=tuple(self._PLANES.keys()))
+ self._item = weakref.ref(item)
+ self._lastName = None
+
+ self.addRow(PlaneEquationRow(item))
+
+ def _notified(self, *args, **kwargs):
+ """Handle notification of modification
+
+ Here only send if plane name actually changed
+ """
+ if self._lastName != self.__getPlaneName():
+ super(PlaneRow, self)._notified()
+
+ def __getPlaneName(self):
+ """Returns name of plane // to axes or '-'
+
+ :rtype: str
+ """
+ item = self._item()
+ planeNormal = item.getNormal() if item is not None else None
+
+ for name, normal in self._PLANES.items():
+ if numpy.array_equal(planeNormal, normal):
+ return name
+ return '-'
+
+ def __setPlaneName(self, data):
+ """Set plane normal according to given plane name
+
+ :param str data: Selected plane name
+ """
+ item = self._item()
+ if item is not None:
+ for name, normal in self._PLANES.items():
+ if data == name and normal is not None:
+ item.setNormal(normal)
+
+ def data(self, column, role):
+ if column == 1 and role == qt.Qt.DecorationRole:
+ return icons.getQIcon(self._PLANE_ICONS[self.__getPlaneName()])
+ data = super(PlaneRow, self).data(column, role)
+ if column == 1 and role == qt.Qt.DisplayRole:
+ self._lastName = data
+ return data
+
+
+class ComplexModeRow(ItemProxyRow):
+ """Represents :class:`items.ComplexMixIn` symbol property.
+
+ :param Item3D item: Scene item with symbol property
+ """
+
+ def __init__(self, item, name='Mode'):
+ names = [m.value.replace('_', ' ').title()
+ for m in item.supportedComplexModes()]
+ super(ComplexModeRow, self).__init__(
+ item=item,
+ name=name,
+ fget=item.getComplexMode,
+ fset=item.setComplexMode,
+ events=items.ItemChangedType.COMPLEX_MODE,
+ toModelData=lambda data: data.value.replace('_', ' ').title(),
+ fromModelData=lambda data: data.lower().replace(' ', '_'),
+ editorHint=names)
+
+
+class RemoveIsosurfaceRow(BaseRow):
+ """Class for Isosurface Delete button
+
+ :param Isosurface isosurface: The isosurface item to attach the button to.
+ """
+
+ def __init__(self, isosurface):
+ super(RemoveIsosurfaceRow, self).__init__()
+ self._isosurface = weakref.ref(isosurface)
+
+ def createEditor(self):
+ """Specific editor factory provided to the model"""
+ editor = qt.QWidget()
+ layout = qt.QHBoxLayout(editor)
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.setSpacing(0)
+
+ removeBtn = qt.QToolButton()
+ removeBtn.setText('Delete')
+ removeBtn.setToolButtonStyle(qt.Qt.ToolButtonTextOnly)
+ layout.addWidget(removeBtn)
+ removeBtn.clicked.connect(self._removeClicked)
+
+ layout.addStretch(1)
+ return editor
+
+ def isosurface(self):
+ """Returns the controlled isosurface
+
+ :rtype: Isosurface
+ """
+ return self._isosurface()
+
+ def data(self, column, role):
+ if column == 0 and role == qt.Qt.UserRole: # editor hint
+ return self.createEditor
+
+ return super(RemoveIsosurfaceRow, self).data(column, role)
+
+ def flags(self, column):
+ flags = super(RemoveIsosurfaceRow, self).flags(column)
+ if column == 0:
+ flags |= qt.Qt.ItemIsEditable
+ return flags
+
+ def _removeClicked(self):
+ """Handle Delete button clicked"""
+ isosurface = self.isosurface()
+ if isosurface is not None:
+ volume = isosurface.parent()
+ if volume is not None:
+ volume.removeIsosurface(isosurface)
+
+
+class IsosurfaceRow(Item3DRow):
+ """Represents an :class:`Isosurface` item.
+
+ :param Isosurface item: Isosurface item
+ """
+
+ _LEVEL_SLIDER_RANGE = 0, 1000
+ """Range given as editor hint"""
+
+ _EVENTS = items.ItemChangedType.VISIBLE, items.ItemChangedType.COLOR
+ """Events for which to update the first column in the tree"""
+
+ def __init__(self, item):
+ super(IsosurfaceRow, self).__init__(item, name=item.getLevel())
+
+ self.setFlags(self.flags(1) | qt.Qt.ItemIsEditable, 1)
+
+ item.sigItemChanged.connect(self._levelChanged)
+
+ self.addRow(ItemProxyRow(
+ item=item,
+ name='Level',
+ fget=self._getValueForLevelSlider,
+ fset=self._setLevelFromSliderValue,
+ events=items.Item3DChangedType.ISO_LEVEL,
+ editorHint=self._LEVEL_SLIDER_RANGE))
+
+ self.addRow(ItemColorProxyRow(
+ item=item,
+ name='Color',
+ fget=self._rgbColor,
+ fset=self._setRgbColor,
+ events=items.ItemChangedType.COLOR))
+
+ self.addRow(ItemProxyRow(
+ item=item,
+ name='Opacity',
+ fget=self._opacity,
+ fset=self._setOpacity,
+ events=items.ItemChangedType.COLOR,
+ editorHint=(0, 255)))
+
+ self.addRow(RemoveIsosurfaceRow(item))
+
+ def _getValueForLevelSlider(self):
+ """Convert iso level to slider value.
+
+ :rtype: int
+ """
+ item = self.item()
+ if item is not None:
+ volume = item.parent()
+ if volume is not None:
+ dataRange = volume.getDataRange()
+ if dataRange is not None:
+ dataMin, dataMax = dataRange[0], dataRange[-1]
+ if dataMax != dataMin:
+ offset = (item.getLevel() - dataMin) / (dataMax - dataMin)
+ else:
+ offset = 0.
+
+ sliderMin, sliderMax = self._LEVEL_SLIDER_RANGE
+ value = sliderMin + (sliderMax - sliderMin) * offset
+ return value
+ return 0
+
+ def _setLevelFromSliderValue(self, value):
+ """Convert slider value to isolevel.
+
+ :param int value:
+ """
+ item = self.item()
+ if item is not None:
+ volume = item.parent()
+ if volume is not None:
+ dataRange = volume.getDataRange()
+ if dataRange is not None:
+ sliderMin, sliderMax = self._LEVEL_SLIDER_RANGE
+ offset = (value - sliderMin) / (sliderMax - sliderMin)
+
+ dataMin, dataMax = dataRange[0], dataRange[-1]
+ level = dataMin + (dataMax - dataMin) * offset
+ item.setLevel(level)
+
+ def _rgbColor(self):
+ """Proxy to get the isosurface's RGB color without transparency
+
+ :rtype: QColor
+ """
+ item = self.item()
+ if item is None:
+ return None
+ else:
+ color = item.getColor()
+ color.setAlpha(255)
+ return color
+
+ def _setRgbColor(self, color):
+ """Proxy to set the isosurface's RGB color without transparency
+
+ :param QColor color:
+ """
+ item = self.item()
+ if item is not None:
+ color.setAlpha(item.getColor().alpha())
+ item.setColor(color)
+
+ def _opacity(self):
+ """Proxy to get the isosurface's transparency
+
+ :rtype: int
+ """
+ item = self.item()
+ return 255 if item is None else item.getColor().alpha()
+
+ def _setOpacity(self, opacity):
+ """Proxy to set the isosurface's transparency.
+
+ :param int opacity:
+ """
+ item = self.item()
+ if item is not None:
+ color = item.getColor()
+ color.setAlpha(opacity)
+ item.setColor(color)
+
+ def _levelChanged(self, event):
+ """Handle isosurface level changed and notify model
+
+ :param ItemChangedType event:
+ """
+ if event == items.Item3DChangedType.ISO_LEVEL:
+ model = self.model()
+ if model is not None:
+ index = self.index(column=1)
+ model.dataChanged.emit(index, index)
+
+ def data(self, column, role):
+ if column == 0: # Show color as decoration, not text
+ if role == qt.Qt.DisplayRole:
+ return None
+ elif role == qt.Qt.DecorationRole:
+ return self._rgbColor()
+
+ elif column == 1 and role in (qt.Qt.DisplayRole, qt.Qt.EditRole):
+ item = self.item()
+ return None if item is None else item.getLevel()
+
+ return super(IsosurfaceRow, self).data(column, role)
+
+ def setData(self, column, value, role):
+ if column == 1 and role == qt.Qt.EditRole:
+ item = self.item()
+ if item is not None:
+ item.setLevel(value)
+ return True
+
+ return super(IsosurfaceRow, self).setData(column, value, role)
+
+
+class ComplexIsosurfaceRow(IsosurfaceRow):
+ """Represents an :class:`ComplexIsosurface` item.
+
+ :param ComplexIsosurface item:
+ """
+
+ _EVENTS = (items.ItemChangedType.VISIBLE,
+ items.ItemChangedType.COLOR,
+ items.ItemChangedType.COMPLEX_MODE)
+ """Events for which to update the first column in the tree"""
+
+ def __init__(self, item):
+ super(ComplexIsosurfaceRow, self).__init__(item)
+
+ self.addRow(ComplexModeRow(item, "Color Complex Mode"), index=1)
+ for row in self.children():
+ if isinstance(row, ColorProxyRow):
+ self._colorRow = row
+ break
+ else:
+ raise RuntimeError("Cannot retrieve Color tree row")
+ self._colormapRow = ColormapRow(item)
+
+ self.__updateRowsForItem(item)
+ item.sigItemChanged.connect(self._itemChanged)
+
+ def _itemChanged(self, event):
+ """Update enabled/disabled rows"""
+ if event == items.ItemChangedType.COMPLEX_MODE:
+ item = self.sender()
+ self.__updateRowsForItem(item)
+
+ def __updateRowsForItem(self, item):
+ """Update rows for item
+
+ :param item:
+ """
+ if not isinstance(item, ComplexIsosurface):
+ return
+
+ if item.getComplexMode() == items.ComplexMixIn.ComplexMode.NONE:
+ removed = self._colormapRow
+ added = self._colorRow
+ else:
+ removed = self._colorRow
+ added = self._colormapRow
+
+ # Remove unwanted rows
+ if removed in self.children():
+ self.removeRow(removed)
+
+ # Add required rows
+ if added not in self.children():
+ self.addRow(added, index=2)
+
+ def data(self, column, role):
+ if column == 0 and role == qt.Qt.DecorationRole:
+ item = self.item()
+ if (item is not None and
+ item.getComplexMode() != items.ComplexMixIn.ComplexMode.NONE):
+ return self._colormapRow.getColormapImage()
+
+ return super(ComplexIsosurfaceRow, self).data(column, role)
+
+
+class AddIsosurfaceRow(BaseRow):
+ """Class for Isosurface create button
+
+ :param Union[ScalarField3D,ComplexField3D] volume:
+ The volume item to attach the button to.
+ """
+
+ def __init__(self, volume):
+ super(AddIsosurfaceRow, self).__init__()
+ self._volume = weakref.ref(volume)
+
+ def createEditor(self):
+ """Specific editor factory provided to the model"""
+ editor = qt.QWidget()
+ layout = qt.QHBoxLayout(editor)
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.setSpacing(0)
+
+ addBtn = qt.QToolButton()
+ addBtn.setText('+')
+ addBtn.setToolButtonStyle(qt.Qt.ToolButtonTextOnly)
+ layout.addWidget(addBtn)
+ addBtn.clicked.connect(self._addClicked)
+
+ layout.addStretch(1)
+ return editor
+
+ def volume(self):
+ """Returns the controlled volume item
+
+ :rtype: Union[ScalarField3D,ComplexField3D]
+ """
+ return self._volume()
+
+ def data(self, column, role):
+ if column == 0 and role == qt.Qt.UserRole: # editor hint
+ return self.createEditor
+
+ return super(AddIsosurfaceRow, self).data(column, role)
+
+ def flags(self, column):
+ flags = super(AddIsosurfaceRow, self).flags(column)
+ if column == 0:
+ flags |= qt.Qt.ItemIsEditable
+ return flags
+
+ def _addClicked(self):
+ """Handle Delete button clicked"""
+ volume = self.volume()
+ if volume is not None:
+ dataRange = volume.getDataRange()
+ if dataRange is None:
+ dataRange = 0., 1.
+
+ volume.addIsosurface(
+ numpy.mean((dataRange[0], dataRange[-1])),
+ '#0000FF')
+
+
+class VolumeIsoSurfacesRow(StaticRow):
+ """Represents :class:`ScalarFieldView`'s isosurfaces
+
+ :param Union[ScalarField3D,ComplexField3D] volume:
+ Volume item to control
+ """
+
+ def __init__(self, volume):
+ super(VolumeIsoSurfacesRow, self).__init__(
+ ('Isosurfaces', None))
+ self._volume = weakref.ref(volume)
+
+ volume.sigIsosurfaceAdded.connect(self._isosurfaceAdded)
+ volume.sigIsosurfaceRemoved.connect(self._isosurfaceRemoved)
+
+ if isinstance(volume, items.ComplexMixIn):
+ self.addRow(ComplexModeRow(volume, "Complex Mode"))
+
+ for item in volume.getIsosurfaces():
+ self.addRow(nodeFromItem(item))
+
+ self.addRow(AddIsosurfaceRow(volume))
+
+ def volume(self):
+ """Returns the controlled volume item
+
+ :rtype: Union[ScalarField3D,ComplexField3D]
+ """
+ return self._volume()
+
+ def _isosurfaceAdded(self, item):
+ """Handle isosurface addition
+
+ :param Isosurface item: added isosurface
+ """
+ volume = self.volume()
+ if volume is None:
+ return
+
+ row = volume.getIsosurfaces().index(item)
+ if isinstance(volume, items.ComplexMixIn):
+ row += 1 # Offset for the ComplexModeRow
+ self.addRow(nodeFromItem(item), row)
+
+ def _isosurfaceRemoved(self, item):
+ """Handle isosurface removal
+
+ :param Isosurface item: removed isosurface
+ """
+ volume = self.volume()
+ if volume is None:
+ return
+
+ # Find item
+ for row in self.children():
+ if isinstance(row, IsosurfaceRow) and row.item() is item:
+ self.removeRow(row)
+ break # Got it
+ else:
+ raise RuntimeError("Model does not correspond to scene content")
+
+
+class Scatter2DPropertyMixInRow(object):
+ """Mix-in class that enable/disable row according to Scatter2D mode.
+
+ :param Scatter2D item:
+ :param str propertyName: Name of the Scatter2D property of this row
+ """
+
+ def __init__(self, item, propertyName):
+ assert propertyName in ('lineWidth', 'symbol', 'symbolSize')
+ self.__propertyName = propertyName
+
+ self.__isEnabled = item.isPropertyEnabled(propertyName)
+ self.__updateFlags()
+
+ item.sigItemChanged.connect(self._itemChanged)
+
+ def data(self, column, role):
+ if column == 1 and not self.__isEnabled:
+ # Discard data and editorHint if disabled
+ return None
+ else:
+ return super(Scatter2DPropertyMixInRow, self).data(column, role)
+
+ def __updateFlags(self):
+ """Update model flags"""
+ if self.__isEnabled:
+ self.setFlags(qt.Qt.ItemIsEnabled, 0)
+ self.setFlags(qt.Qt.ItemIsEnabled | qt.Qt.ItemIsEditable, 1)
+ else:
+ self.setFlags(qt.Qt.NoItemFlags)
+
+ def _itemChanged(self, event):
+ """Set flags to enable/disable the row"""
+ if event == items.ItemChangedType.VISUALIZATION_MODE:
+ item = self.sender()
+ if item is not None: # This occurs with PySide/python2.7
+ self.__isEnabled = item.isPropertyEnabled(self.__propertyName)
+ self.__updateFlags()
+
+ # Notify model
+ model = self.model()
+ if model is not None:
+ begin = self.index(column=0)
+ end = self.index(column=1)
+ model.dataChanged.emit(begin, end)
+
+
+class Scatter2DSymbolRow(Scatter2DPropertyMixInRow, SymbolRow):
+ """Specific class for Scatter2D symbol.
+
+ It is enabled/disabled according to visualization mode.
+
+ :param Scatter2D item:
+ """
+
+ def __init__(self, item):
+ SymbolRow.__init__(self, item)
+ Scatter2DPropertyMixInRow.__init__(self, item, 'symbol')
+
+
+class Scatter2DSymbolSizeRow(Scatter2DPropertyMixInRow, SymbolSizeRow):
+ """Specific class for Scatter2D symbol size.
+
+ It is enabled/disabled according to visualization mode.
+
+ :param Scatter2D item:
+ """
+
+ def __init__(self, item):
+ SymbolSizeRow.__init__(self, item)
+ Scatter2DPropertyMixInRow.__init__(self, item, 'symbolSize')
+
+
+class Scatter2DLineWidth(Scatter2DPropertyMixInRow, ItemProxyRow):
+ """Specific class for Scatter2D symbol size.
+
+ It is enabled/disabled according to visualization mode.
+
+ :param Scatter2D item:
+ """
+
+ def __init__(self, item):
+ # TODO link editorHint with OpenGL max line width
+ ItemProxyRow.__init__(self,
+ item=item,
+ name='Line width',
+ fget=item.getLineWidth,
+ fset=item.setLineWidth,
+ events=items.ItemChangedType.LINE_WIDTH,
+ editorHint=(1, 10))
+ Scatter2DPropertyMixInRow.__init__(self, item, 'lineWidth')
+
+
+def initScatter2DNode(node, item):
+ """Specific node init for Scatter2D to set order of parameters
+
+ :param Item3DRow node: The model node to setup
+ :param Scatter2D item: The Scatter2D the node is representing
+ """
+ node.addRow(ItemProxyRow(
+ item=item,
+ name='Mode',
+ fget=item.getVisualization,
+ fset=item.setVisualization,
+ events=items.ItemChangedType.VISUALIZATION_MODE,
+ editorHint=[m.value.title() for m in item.supportedVisualizations()],
+ toModelData=lambda data: data.value.title(),
+ fromModelData=lambda data: data.lower()))
+
+ node.addRow(ItemProxyRow(
+ item=item,
+ name='Height map',
+ fget=item.isHeightMap,
+ fset=item.setHeightMap,
+ events=items.Item3DChangedType.HEIGHT_MAP))
+
+ node.addRow(ColormapRow(item))
+
+ node.addRow(Scatter2DSymbolRow(item))
+ node.addRow(Scatter2DSymbolSizeRow(item))
+
+ node.addRow(Scatter2DLineWidth(item))
+
+
+def initVolumeNode(node, item):
+ """Specific node init for volume items
+
+ :param Item3DRow node: The model node to setup
+ :param Union[ScalarField3D,ComplexField3D] item:
+ The volume item represented by the node
+ """
+ node.addRow(nodeFromItem(item.getCutPlanes()[0])) # Add cut plane
+ node.addRow(VolumeIsoSurfacesRow(item))
+
+
+def initVolumeCutPlaneNode(node, item):
+ """Specific node init for volume CutPlane
+
+ :param Item3DRow node: The model node to setup
+ :param CutPlane item: The CutPlane the node is representing
+ """
+ if isinstance(item, items.ComplexMixIn):
+ node.addRow(ComplexModeRow(item))
+
+ node.addRow(PlaneRow(item))
+
+ node.addRow(ColormapRow(item))
+
+ node.addRow(ItemProxyRow(
+ item=item,
+ name='Show <=Min',
+ fget=item.getDisplayValuesBelowMin,
+ fset=item.setDisplayValuesBelowMin,
+ events=items.ItemChangedType.ALPHA))
+
+ node.addRow(InterpolationRow(item))
+
+
+NODE_SPECIFIC_INIT = [ # class, init(node, item)
+ (items.Scatter2D, initScatter2DNode),
+ (items.ScalarField3D, initVolumeNode),
+ (CutPlane, initVolumeCutPlaneNode),
+]
+"""List of specific node init for different item class"""
+
+
+def nodeFromItem(item):
+ """Create :class:`Item3DRow` subclass corresponding to item
+
+ :param Item3D item: The item fow which to create the node
+ :rtype: Item3DRow
+ """
+ assert isinstance(item, items.Item3D)
+
+ # Item with specific model row class
+ if isinstance(item, (items.GroupItem, items.GroupWithAxesItem)):
+ return GroupItemRow(item)
+ elif isinstance(item, ComplexIsosurface):
+ return ComplexIsosurfaceRow(item)
+ elif isinstance(item, Isosurface):
+ return IsosurfaceRow(item)
+
+ # Create Item3DRow and populate it
+ node = Item3DRow(item)
+
+ if isinstance(item, items.DataItem3D):
+ node.addRow(DataItem3DBoundingBoxRow(item))
+ node.addRow(DataItem3DTransformRow(item))
+
+ # Specific extra init
+ for cls, specificInit in NODE_SPECIFIC_INIT:
+ if isinstance(item, cls):
+ specificInit(node, item)
+ break
+
+ else: # Generic case: handle mixins
+ for cls in item.__class__.__mro__:
+ if cls is items.ColormapMixIn:
+ node.addRow(ColormapRow(item))
+
+ elif cls is items.InterpolationMixIn:
+ node.addRow(InterpolationRow(item))
+
+ elif cls is items.SymbolMixIn:
+ node.addRow(SymbolRow(item))
+ node.addRow(SymbolSizeRow(item))
+
+ elif cls is items.PlaneMixIn:
+ node.addRow(PlaneRow(item))
+
+ return node
diff --git a/src/silx/gui/plot3d/_model/model.py b/src/silx/gui/plot3d/_model/model.py
new file mode 100644
index 0000000..186838f
--- /dev/null
+++ b/src/silx/gui/plot3d/_model/model.py
@@ -0,0 +1,184 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+This module provides the :class:`SceneWidget` content and parameters model.
+"""
+
+from __future__ import absolute_import, division
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "11/01/2018"
+
+
+import weakref
+
+from ... import qt
+
+from .core import BaseRow
+from .items import Settings, nodeFromItem
+
+
+def visitQAbstractItemModel(model, parent=qt.QModelIndex()):
+ """Iterate over indices in the model starting from parent
+
+ It iterates column by column and row by row
+ (i.e., from left to right and from top to bottom).
+ Parent are returned before their children.
+ It only iterates through the children for the first column of a row.
+
+ :param QAbstractItemModel model: The model to visit
+ :param QModelIndex parent:
+ Index from which to start visiting the model.
+ Default: start from the root
+ """
+ assert isinstance(model, qt.QAbstractItemModel)
+ assert isinstance(parent, qt.QModelIndex)
+ assert parent.model() is model or not parent.isValid()
+
+ for row in range(model.rowCount(parent)):
+ for column in range(model.columnCount(parent)):
+ index = model.index(row, column, parent)
+ yield index
+
+ index = model.index(row, 0, parent)
+ for index in visitQAbstractItemModel(model, index):
+ yield index
+
+
+class Root(BaseRow):
+ """Root node of :class:`SceneWidget` parameters.
+
+ It has two children:
+ - Settings
+ - Scene group
+ """
+
+ def __init__(self, model, sceneWidget):
+ super(Root, self).__init__()
+ self._sceneWidget = weakref.ref(sceneWidget)
+ self.setParent(model) # Needed for Root
+
+ def children(self):
+ sceneWidget = self._sceneWidget()
+ if sceneWidget is None:
+ return ()
+ else:
+ return super(Root, self).children()
+
+
+class SceneModel(qt.QAbstractItemModel):
+ """Model of a :class:`SceneWidget`.
+
+ :param SceneWidget parent: The SceneWidget this model represents.
+ """
+
+ def __init__(self, parent):
+ self._sceneWidget = weakref.ref(parent)
+
+ super(SceneModel, self).__init__(parent)
+ self._root = Root(self, parent)
+ self._root.addRow(Settings(parent))
+ self._root.addRow(nodeFromItem(parent.getSceneGroup()))
+
+ def sceneWidget(self):
+ """Returns the :class:`SceneWidget` this model represents.
+
+ In case the widget has already been deleted, it returns None
+
+ :rtype: SceneWidget
+ """
+ return self._sceneWidget()
+
+ def _itemFromIndex(self, index):
+ """Returns the corresponding :class:`Node` or :class:`Item3D`.
+
+ :param QModelIndex index:
+ :rtype: Node or Item3D
+ """
+ return index.internalPointer() if index.isValid() else self._root
+
+ def index(self, row, column, parent=qt.QModelIndex()):
+ """See :meth:`QAbstractItemModel.index`"""
+ if column >= self.columnCount(parent) or row >= self.rowCount(parent):
+ return qt.QModelIndex()
+
+ item = self._itemFromIndex(parent)
+ return self.createIndex(row, column, item.children()[row])
+
+ def parent(self, index):
+ """See :meth:`QAbstractItemModel.parent`"""
+ if not index.isValid():
+ return qt.QModelIndex()
+
+ item = self._itemFromIndex(index)
+ parent = item.parent()
+
+ ancestor = parent.parent()
+
+ if ancestor is not self: # root node
+ children = ancestor.children()
+ row = children.index(parent)
+ return self.createIndex(row, 0, parent)
+
+ return qt.QModelIndex()
+
+ def rowCount(self, parent=qt.QModelIndex()):
+ """See :meth:`QAbstractItemModel.rowCount`"""
+ item = self._itemFromIndex(parent)
+ return item.rowCount()
+
+ def columnCount(self, parent=qt.QModelIndex()):
+ """See :meth:`QAbstractItemModel.columnCount`"""
+ item = self._itemFromIndex(parent)
+ return item.columnCount()
+
+ def data(self, index, role=qt.Qt.DisplayRole):
+ """See :meth:`QAbstractItemModel.data`"""
+ item = self._itemFromIndex(index)
+ column = index.column()
+ return item.data(column, role)
+
+ def setData(self, index, value, role=qt.Qt.EditRole):
+ """See :meth:`QAbstractItemModel.setData`"""
+ item = self._itemFromIndex(index)
+ column = index.column()
+ if item.setData(column, value, role):
+ self.dataChanged.emit(index, index)
+ return True
+ return False
+
+ def flags(self, index):
+ """See :meth:`QAbstractItemModel.flags`"""
+ item = self._itemFromIndex(index)
+ column = index.column()
+ return item.flags(column)
+
+ def headerData(self, section, orientation, role=qt.Qt.DisplayRole):
+ """See :meth:`QAbstractItemModel.headerData`"""
+ if orientation == qt.Qt.Horizontal and role == qt.Qt.DisplayRole:
+ return 'Item' if section == 0 else 'Value'
+ else:
+ return None
diff --git a/src/silx/gui/plot3d/actions/Plot3DAction.py b/src/silx/gui/plot3d/actions/Plot3DAction.py
new file mode 100644
index 0000000..94b9572
--- /dev/null
+++ b/src/silx/gui/plot3d/actions/Plot3DAction.py
@@ -0,0 +1,71 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Base class for QAction attached to a Plot3DWidget."""
+
+from __future__ import absolute_import, division
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "06/09/2017"
+
+
+import logging
+import weakref
+
+from silx.gui import qt
+
+
+_logger = logging.getLogger(__name__)
+
+
+class Plot3DAction(qt.QAction):
+ """QAction associated to a Plot3DWidget
+
+ :param parent: See :class:`QAction`
+ :param ~silx.gui.plot3d.Plot3DWidget.Plot3DWidget plot3d:
+ Plot3DWidget the action is associated with
+ """
+
+ def __init__(self, parent, plot3d=None):
+ super(Plot3DAction, self).__init__(parent)
+ self._plot3d = None
+ self.setPlot3DWidget(plot3d)
+
+ def setPlot3DWidget(self, widget):
+ """Set the Plot3DWidget this action is associated with
+
+ :param ~silx.gui.plot3d.Plot3DWidget.Plot3DWidget widget:
+ The Plot3DWidget to use
+ """
+ self._plot3d = None if widget is None else weakref.ref(widget)
+
+ def getPlot3DWidget(self):
+ """Return the Plot3DWidget associated to this action.
+
+ If no widget is associated, it returns None.
+
+ :rtype: QWidget
+ """
+ return None if self._plot3d is None else self._plot3d()
diff --git a/src/silx/gui/plot3d/actions/__init__.py b/src/silx/gui/plot3d/actions/__init__.py
new file mode 100644
index 0000000..26243cf
--- /dev/null
+++ b/src/silx/gui/plot3d/actions/__init__.py
@@ -0,0 +1,34 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides QAction that can be attached to a plot3DWidget."""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "06/09/2017"
+
+from .Plot3DAction import Plot3DAction # noqa
+from . import viewpoint # noqa
+from . import io # noqa
+from . import mode # noqa
diff --git a/src/silx/gui/plot3d/actions/io.py b/src/silx/gui/plot3d/actions/io.py
new file mode 100644
index 0000000..25f4ade
--- /dev/null
+++ b/src/silx/gui/plot3d/actions/io.py
@@ -0,0 +1,337 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides Plot3DAction related to input/output.
+
+It provides QAction to copy, save (snapshot and video), print a Plot3DWidget.
+"""
+
+from __future__ import absolute_import, division
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "06/09/2017"
+
+
+import logging
+import os
+
+import numpy
+
+from silx.gui import qt, printer
+from silx.gui.icons import getQIcon
+from .Plot3DAction import Plot3DAction
+from ..utils import mng
+from ...utils.image import convertQImageToArray
+
+
+_logger = logging.getLogger(__name__)
+
+
+class CopyAction(Plot3DAction):
+ """QAction to provide copy of a Plot3DWidget
+
+ :param parent: See :class:`QAction`
+ :param ~silx.gui.plot3d.Plot3DWidget.Plot3DWidget plot3d:
+ Plot3DWidget the action is associated with
+ """
+
+ def __init__(self, parent, plot3d=None):
+ super(CopyAction, self).__init__(parent, plot3d)
+
+ self.setIcon(getQIcon('edit-copy'))
+ self.setText('Copy')
+ self.setToolTip('Copy a snapshot of the 3D scene to the clipboard')
+ self.setCheckable(False)
+ self.setShortcut(qt.QKeySequence.Copy)
+ self.setShortcutContext(qt.Qt.WidgetShortcut)
+ self.triggered[bool].connect(self._triggered)
+
+ def _triggered(self, checked=False):
+ plot3d = self.getPlot3DWidget()
+ if plot3d is None:
+ _logger.error('Cannot copy widget, no associated Plot3DWidget')
+ else:
+ image = plot3d.grabGL()
+ qt.QApplication.clipboard().setImage(image)
+
+
+class SaveAction(Plot3DAction):
+ """QAction to provide save snapshot of a Plot3DWidget
+
+ :param parent: See :class:`QAction`
+ :param ~silx.gui.plot3d.Plot3DWidget.Plot3DWidget plot3d:
+ Plot3DWidget the action is associated with
+ """
+
+ def __init__(self, parent, plot3d=None):
+ super(SaveAction, self).__init__(parent, plot3d)
+
+ self.setIcon(getQIcon('document-save'))
+ self.setText('Save...')
+ self.setToolTip('Save a snapshot of the 3D scene')
+ self.setCheckable(False)
+ self.setShortcut(qt.QKeySequence.Save)
+ self.setShortcutContext(qt.Qt.WidgetShortcut)
+ self.triggered[bool].connect(self._triggered)
+
+ def _triggered(self, checked=False):
+ plot3d = self.getPlot3DWidget()
+ if plot3d is None:
+ _logger.error('Cannot save widget, no associated Plot3DWidget')
+ else:
+ dialog = qt.QFileDialog(self.parent())
+ dialog.setWindowTitle('Save snapshot as')
+ dialog.setModal(True)
+ dialog.setNameFilters(('Plot3D Snapshot PNG (*.png)',
+ 'Plot3D Snapshot JPEG (*.jpg)'))
+
+ dialog.setFileMode(qt.QFileDialog.AnyFile)
+ dialog.setAcceptMode(qt.QFileDialog.AcceptSave)
+
+ if not dialog.exec():
+ return
+
+ nameFilter = dialog.selectedNameFilter()
+ filename = dialog.selectedFiles()[0]
+ dialog.close()
+
+ # Forces the filename extension to match the chosen filter
+ extension = nameFilter.split()[-1][2:-1]
+ if (len(filename) <= len(extension) or
+ filename[-len(extension):].lower() != extension.lower()):
+ filename += extension
+
+ image = plot3d.grabGL()
+ if not image.save(filename):
+ _logger.error('Failed to save image as %s', filename)
+ qt.QMessageBox.critical(
+ self.parent(),
+ 'Save snapshot as',
+ 'Failed to save snapshot')
+
+
+class PrintAction(Plot3DAction):
+ """QAction to provide printing of a Plot3DWidget
+
+ :param parent: See :class:`QAction`
+ :param ~silx.gui.plot3d.Plot3DWidget.Plot3DWidget plot3d:
+ Plot3DWidget the action is associated with
+ """
+
+ def __init__(self, parent, plot3d=None):
+ super(PrintAction, self).__init__(parent, plot3d)
+
+ self.setIcon(getQIcon('document-print'))
+ self.setText('Print...')
+ self.setToolTip('Print a snapshot of the 3D scene')
+ self.setCheckable(False)
+ self.setShortcut(qt.QKeySequence.Print)
+ self.setShortcutContext(qt.Qt.WidgetShortcut)
+ self.triggered[bool].connect(self._triggered)
+
+ def getPrinter(self):
+ """Return the QPrinter instance used for printing.
+
+ :rtype: QPrinter
+ """
+ return printer.getDefaultPrinter()
+
+ def _triggered(self, checked=False):
+ plot3d = self.getPlot3DWidget()
+ if plot3d is None:
+ _logger.error('Cannot print widget, no associated Plot3DWidget')
+ else:
+ printer = self.getPrinter()
+ dialog = qt.QPrintDialog(printer, plot3d)
+ dialog.setWindowTitle('Print Plot3D snapshot')
+ if not dialog.exec():
+ return
+
+ image = plot3d.grabGL()
+
+ # Draw pixmap with painter
+ painter = qt.QPainter()
+ if not painter.begin(printer):
+ return
+
+ pageRect = printer.pageRect(qt.QPrinter.DevicePixel)
+ if (pageRect.width() < image.width() or
+ pageRect.height() < image.height()):
+ # Downscale to page
+ xScale = pageRect.width() / image.width()
+ yScale = pageRect.height() / image.height()
+ scale = min(xScale, yScale)
+ else:
+ scale = 1.
+
+ rect = qt.QRectF(0,
+ 0,
+ scale * image.width(),
+ scale * image.height())
+ painter.drawImage(rect, image)
+ painter.end()
+
+
+class VideoAction(Plot3DAction):
+ """This action triggers the recording of a video of the scene.
+
+ The scene is rotated 360 degrees around a vertical axis.
+
+ :param parent: Action parent see :class:`QAction`.
+ :param ~silx.gui.plot3d.Plot3DWidget.Plot3DWidget plot3d:
+ Plot3DWidget the action is associated with
+ """
+
+ PNG_SERIE_FILTER = 'Serie of PNG files (*.png)'
+ MNG_FILTER = 'Multiple-image Network Graphics file (*.mng)'
+
+ def __init__(self, parent, plot3d=None):
+ super(VideoAction, self).__init__(parent, plot3d)
+ self.setText('Record video..')
+ self.setIcon(getQIcon('camera'))
+ self.setToolTip(
+ 'Record a video of a 360 degrees rotation of the 3D scene.')
+ self.setCheckable(False)
+ self.triggered[bool].connect(self._triggered)
+
+ def _triggered(self, checked=False):
+ """Action triggered callback"""
+ plot3d = self.getPlot3DWidget()
+ if plot3d is None:
+ _logger.warning(
+ 'Ignoring action triggered without Plot3DWidget set')
+ return
+
+ dialog = qt.QFileDialog(parent=plot3d)
+ dialog.setWindowTitle('Save video as...')
+ dialog.setModal(True)
+ dialog.setNameFilters([self.PNG_SERIE_FILTER,
+ self.MNG_FILTER])
+ dialog.setFileMode(dialog.AnyFile)
+ dialog.setAcceptMode(dialog.AcceptSave)
+
+ if not dialog.exec():
+ return
+
+ nameFilter = dialog.selectedNameFilter()
+ filename = dialog.selectedFiles()[0]
+
+ # Forces the filename extension to match the chosen filter
+ extension = nameFilter.split()[-1][2:-1]
+ if (len(filename) <= len(extension) or
+ filename[-len(extension):].lower() != extension.lower()):
+ filename += extension
+
+ nbFrames = int(4. * 25) # 4 seconds, 25 fps
+
+ if nameFilter == self.PNG_SERIE_FILTER:
+ self._saveAsPNGSerie(filename, nbFrames)
+ elif nameFilter == self.MNG_FILTER:
+ self._saveAsMNG(filename, nbFrames)
+ else:
+ _logger.error('Unsupported file filter: %s', nameFilter)
+
+ def _saveAsPNGSerie(self, filename, nbFrames):
+ """Save video as serie of PNG files.
+
+ It adds a counter to the provided filename before the extension.
+
+ :param str filename: filename to use as template
+ :param int nbFrames: Number of frames to generate
+ """
+ plot3d = self.getPlot3DWidget()
+ assert plot3d is not None
+
+ # Define filename template
+ nbDigits = int(numpy.log10(nbFrames)) + 1
+ indexFormat = '%%0%dd' % nbDigits
+ extensionIndex = filename.rfind('.')
+ filenameFormat = \
+ filename[:extensionIndex] + indexFormat + filename[extensionIndex:]
+
+ try:
+ for index, image in enumerate(self._video360(nbFrames)):
+ image.save(filenameFormat % index)
+ except GeneratorExit:
+ pass
+
+ def _saveAsMNG(self, filename, nbFrames):
+ """Save video as MNG file.
+
+ :param str filename: filename to use
+ :param int nbFrames: Number of frames to generate
+ """
+ plot3d = self.getPlot3DWidget()
+ assert plot3d is not None
+
+ frames = (convertQImageToArray(im) for im in self._video360(nbFrames))
+ try:
+ with open(filename, 'wb') as file_:
+ for chunk in mng.convert(frames, nb_images=nbFrames):
+ file_.write(chunk)
+ except GeneratorExit:
+ os.remove(filename) # Saving aborted, delete file
+
+ def _video360(self, nbFrames):
+ """Run the video and provides the images
+
+ :param int nbFrames: The number of frames to generate for
+ :return: Iterator of QImage of the video sequence
+ """
+ plot3d = self.getPlot3DWidget()
+ assert plot3d is not None
+
+ angleStep = 360. / nbFrames
+
+ # Create progress bar dialog
+ dialog = qt.QDialog(plot3d)
+ dialog.setWindowTitle('Record Video')
+ layout = qt.QVBoxLayout(dialog)
+ progress = qt.QProgressBar()
+ progress.setRange(0, nbFrames)
+ layout.addWidget(progress)
+
+ btnBox = qt.QDialogButtonBox(qt.QDialogButtonBox.Abort)
+ btnBox.rejected.connect(dialog.reject)
+ layout.addWidget(btnBox)
+
+ dialog.setModal(True)
+ dialog.show()
+
+ qapp = qt.QApplication.instance()
+
+ for frame in range(nbFrames):
+ progress.setValue(frame)
+ image = plot3d.grabGL()
+ yield image
+ plot3d.viewport.orbitCamera('left', angleStep)
+ qapp.processEvents()
+ if not dialog.isVisible():
+ break # It as been rejected by the abort button
+ else:
+ dialog.accept()
+
+ if dialog.result() == qt.QDialog.Rejected:
+ raise GeneratorExit('Aborted')
diff --git a/src/silx/gui/plot3d/actions/mode.py b/src/silx/gui/plot3d/actions/mode.py
new file mode 100644
index 0000000..b9cd7c8
--- /dev/null
+++ b/src/silx/gui/plot3d/actions/mode.py
@@ -0,0 +1,178 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides Plot3DAction related to interaction modes.
+
+It provides QAction to rotate or pan a Plot3DWidget
+as well as toggle a picking mode.
+"""
+
+from __future__ import absolute_import, division
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "06/09/2017"
+
+
+import logging
+
+from ....utils.proxy import docstring
+from ... import qt
+from ...icons import getQIcon
+from .Plot3DAction import Plot3DAction
+
+
+_logger = logging.getLogger(__name__)
+
+
+class InteractiveModeAction(Plot3DAction):
+ """Base class for QAction changing interactive mode of a Plot3DWidget
+
+ :param parent: See :class:`QAction`
+ :param str interaction: The interactive mode this action controls
+ :param ~silx.gui.plot3d.Plot3DWidget.Plot3DWidget plot3d:
+ Plot3DWidget the action is associated with
+ """
+
+ def __init__(self, parent, interaction, plot3d=None):
+ self._interaction = interaction
+
+ super(InteractiveModeAction, self).__init__(parent, plot3d)
+ self.setCheckable(True)
+ self.triggered[bool].connect(self._triggered)
+
+ def _triggered(self, checked=False):
+ plot3d = self.getPlot3DWidget()
+ if plot3d is None:
+ _logger.error(
+ 'Cannot set %s interaction, no associated Plot3DWidget' %
+ self._interaction)
+ else:
+ plot3d.setInteractiveMode(self._interaction)
+ self.setChecked(True)
+
+ @docstring(Plot3DAction)
+ def setPlot3DWidget(self, widget):
+ # Disconnect from previous Plot3DWidget
+ plot3d = self.getPlot3DWidget()
+ if plot3d is not None:
+ plot3d.sigInteractiveModeChanged.disconnect(
+ self._interactiveModeChanged)
+
+ super(InteractiveModeAction, self).setPlot3DWidget(widget)
+
+ # Connect to new Plot3DWidget
+ if widget is None:
+ self.setChecked(False)
+ else:
+ self.setChecked(widget.getInteractiveMode() == self._interaction)
+ widget.sigInteractiveModeChanged.connect(
+ self._interactiveModeChanged)
+
+ def _interactiveModeChanged(self):
+ plot3d = self.getPlot3DWidget()
+ if plot3d is None:
+ _logger.error('Received a signal while there is no widget')
+ else:
+ self.setChecked(plot3d.getInteractiveMode() == self._interaction)
+
+
+class RotateArcballAction(InteractiveModeAction):
+ """QAction to set arcball rotation interaction on a Plot3DWidget
+
+ :param parent: See :class:`QAction`
+ :param ~silx.gui.plot3d.Plot3DWidget.Plot3DWidget plot3d:
+ Plot3DWidget the action is associated with
+ """
+
+ def __init__(self, parent, plot3d=None):
+ super(RotateArcballAction, self).__init__(parent, 'rotate', plot3d)
+
+ self.setIcon(getQIcon('rotate-3d'))
+ self.setText('Rotate')
+ self.setToolTip('Rotate the view. Press <b>Ctrl</b> to pan.')
+
+
+class PanAction(InteractiveModeAction):
+ """QAction to set pan interaction on a Plot3DWidget
+
+ :param parent: See :class:`QAction`
+ :param ~silx.gui.plot3d.Plot3DWidget.Plot3DWidget plot3d:
+ Plot3DWidget the action is associated with
+ """
+
+ def __init__(self, parent, plot3d=None):
+ super(PanAction, self).__init__(parent, 'pan', plot3d)
+
+ self.setIcon(getQIcon('pan'))
+ self.setText('Pan')
+ self.setToolTip('Pan the view. Press <b>Ctrl</b> to rotate.')
+
+
+class PickingModeAction(Plot3DAction):
+ """QAction to toggle picking moe on a Plot3DWidget
+
+ :param parent: See :class:`QAction`
+ :param ~silx.gui.plot3d.Plot3DWidget.Plot3DWidget plot3d:
+ Plot3DWidget the action is associated with
+ """
+
+ sigSceneClicked = qt.Signal(float, float)
+ """Signal emitted when the scene is clicked with the left mouse button.
+
+ This signal is only emitted when the action is checked.
+
+ It provides the (x, y) clicked mouse position in logical widget pixel coordinates
+ """
+
+ def __init__(self, parent, plot3d=None):
+ super(PickingModeAction, self).__init__(parent, plot3d)
+ self.setIcon(getQIcon('pointing-hand'))
+ self.setText('Picking')
+ self.setToolTip('Toggle picking with left button click')
+ self.setCheckable(True)
+ self.triggered[bool].connect(self._triggered)
+
+ def _triggered(self, checked=False):
+ plot3d = self.getPlot3DWidget()
+ if plot3d is not None:
+ if checked:
+ plot3d.sigSceneClicked.connect(self.sigSceneClicked)
+ else:
+ plot3d.sigSceneClicked.disconnect(self.sigSceneClicked)
+
+ @docstring(Plot3DAction)
+ def setPlot3DWidget(self, widget):
+ # Disconnect from previous Plot3DWidget
+ plot3d = self.getPlot3DWidget()
+ if plot3d is not None and self.isChecked():
+ plot3d.sigSceneClicked.disconnect(self.sigSceneClicked)
+
+ super(PickingModeAction, self).setPlot3DWidget(widget)
+
+ # Connect to new Plot3DWidget
+ if widget is None:
+ self.setChecked(False)
+ elif self.isChecked():
+ widget.sigSceneClicked.connect(self.sigSceneClicked)
diff --git a/src/silx/gui/plot3d/actions/viewpoint.py b/src/silx/gui/plot3d/actions/viewpoint.py
new file mode 100644
index 0000000..d764c40
--- /dev/null
+++ b/src/silx/gui/plot3d/actions/viewpoint.py
@@ -0,0 +1,231 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides Plot3DAction controlling the viewpoint.
+
+It provides QAction to rotate or pan a Plot3DWidget.
+"""
+
+from __future__ import absolute_import, division
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "03/10/2017"
+
+
+import time
+import logging
+
+from silx.gui import qt
+from silx.gui.icons import getQIcon
+from .Plot3DAction import Plot3DAction
+
+
+_logger = logging.getLogger(__name__)
+
+
+class _SetViewpointAction(Plot3DAction):
+ """Base class for actions setting a Plot3DWidget viewpoint
+
+ :param parent: See :class:`QAction`
+ :param str face: The name of the predefined viewpoint
+ :param ~silx.gui.plot3d.Plot3DWidget.Plot3DWidget plot3d:
+ Plot3DWidget the action is associated with
+ """
+ def __init__(self, parent, face, plot3d=None):
+ super(_SetViewpointAction, self).__init__(parent, plot3d)
+ assert face in ('side', 'front', 'back', 'left', 'right', 'top', 'bottom')
+ self._face = face
+
+ self.setIconVisibleInMenu(True)
+ self.setCheckable(False)
+ self.triggered[bool].connect(self._triggered)
+
+ def _triggered(self, checked=False):
+ plot3d = self.getPlot3DWidget()
+ if plot3d is None:
+ _logger.error(
+ 'Cannot start/stop rotation, no associated Plot3DWidget')
+ else:
+ plot3d.viewport.camera.extrinsic.reset(face=self._face)
+ plot3d.centerScene()
+
+
+class FrontViewpointAction(_SetViewpointAction):
+ """QAction to set Plot3DWidget viewpoint to look from the front
+
+ :param parent: See :class:`QAction`
+ :param ~silx.gui.plot3d.Plot3DWidget.Plot3DWidget plot3d:
+ Plot3DWidget the action is associated with
+ """
+ def __init__(self, parent, plot3d=None):
+ super(FrontViewpointAction, self).__init__(parent, 'front', plot3d)
+
+ self.setIcon(getQIcon('cube-front'))
+ self.setText('Front')
+ self.setToolTip('View along the -Z axis')
+
+
+class BackViewpointAction(_SetViewpointAction):
+ """QAction to set Plot3DWidget viewpoint to look from the back
+
+ :param parent: See :class:`QAction`
+ :param ~silx.gui.plot3d.Plot3DWidget.Plot3DWidget plot3d:
+ Plot3DWidget the action is associated with
+ """
+ def __init__(self, parent, plot3d=None):
+ super(BackViewpointAction, self).__init__(parent, 'back', plot3d)
+
+ self.setIcon(getQIcon('cube-back'))
+ self.setText('Back')
+ self.setToolTip('View along the +Z axis')
+
+
+class LeftViewpointAction(_SetViewpointAction):
+ """QAction to set Plot3DWidget viewpoint to look from the left
+
+ :param parent: See :class:`QAction`
+ :param ~silx.gui.plot3d.Plot3DWidget.Plot3DWidget plot3d:
+ Plot3DWidget the action is associated with
+ """
+ def __init__(self, parent, plot3d=None):
+ super(LeftViewpointAction, self).__init__(parent, 'left', plot3d)
+
+ self.setIcon(getQIcon('cube-left'))
+ self.setText('Left')
+ self.setToolTip('View along the +X axis')
+
+
+class RightViewpointAction(_SetViewpointAction):
+ """QAction to set Plot3DWidget viewpoint to look from the right
+
+ :param parent: See :class:`QAction`
+ :param ~silx.gui.plot3d.Plot3DWidget.Plot3DWidget plot3d:
+ Plot3DWidget the action is associated with
+ """
+ def __init__(self, parent, plot3d=None):
+ super(RightViewpointAction, self).__init__(parent, 'right', plot3d)
+
+ self.setIcon(getQIcon('cube-right'))
+ self.setText('Right')
+ self.setToolTip('View along the -X axis')
+
+
+class TopViewpointAction(_SetViewpointAction):
+ """QAction to set Plot3DWidget viewpoint to look from the top
+
+ :param parent: See :class:`QAction`
+ :param ~silx.gui.plot3d.Plot3DWidget.Plot3DWidget plot3d:
+ Plot3DWidget the action is associated with
+ """
+ def __init__(self, parent, plot3d=None):
+ super(TopViewpointAction, self).__init__(parent, 'top', plot3d)
+
+ self.setIcon(getQIcon('cube-top'))
+ self.setText('Top')
+ self.setToolTip('View along the -Y axis')
+
+
+class BottomViewpointAction(_SetViewpointAction):
+ """QAction to set Plot3DWidget viewpoint to look from the bottom
+
+ :param parent: See :class:`QAction`
+ :param ~silx.gui.plot3d.Plot3DWidget.Plot3DWidget plot3d:
+ Plot3DWidget the action is associated with
+ """
+ def __init__(self, parent, plot3d=None):
+ super(BottomViewpointAction, self).__init__(parent, 'bottom', plot3d)
+
+ self.setIcon(getQIcon('cube-bottom'))
+ self.setText('Bottom')
+ self.setToolTip('View along the +Y axis')
+
+
+class SideViewpointAction(_SetViewpointAction):
+ """QAction to set Plot3DWidget viewpoint to look from the side
+
+ :param parent: See :class:`QAction`
+ :param ~silx.gui.plot3d.Plot3DWidget.Plot3DWidget plot3d:
+ Plot3DWidget the action is associated with
+ """
+ def __init__(self, parent, plot3d=None):
+ super(SideViewpointAction, self).__init__(parent, 'side', plot3d)
+
+ self.setIcon(getQIcon('cube'))
+ self.setText('Side')
+ self.setToolTip('Side view')
+
+
+class RotateViewpoint(Plot3DAction):
+ """QAction to rotate the scene of a Plot3DWidget
+
+ :param parent: See :class:`QAction`
+ :param ~silx.gui.plot3d.Plot3DWidget.Plot3DWidget plot3d:
+ Plot3DWidget the action is associated with
+ """
+
+ _TIMEOUT_MS = 50
+ """Time interval between to frames (in milliseconds)"""
+
+ _DEGREE_PER_SECONDS = 360. / 5.
+ """Rotation speed of the animation"""
+
+ def __init__(self, parent, plot3d=None):
+ super(RotateViewpoint, self).__init__(parent, plot3d)
+
+ self._previousTime = None
+
+ self._timer = qt.QTimer(self)
+ self._timer.setInterval(self._TIMEOUT_MS) # 20fps
+ self._timer.timeout.connect(self._rotate)
+
+ self.setIcon(getQIcon('cube-rotate'))
+ self.setText('Rotate scene')
+ self.setToolTip('Rotate the 3D scene around the vertical axis')
+ self.setCheckable(True)
+ self.triggered[bool].connect(self._triggered)
+
+
+ def _triggered(self, checked=False):
+ plot3d = self.getPlot3DWidget()
+ if plot3d is None:
+ _logger.error(
+ 'Cannot start/stop rotation, no associated Plot3DWidget')
+ elif checked:
+ self._previousTime = time.time()
+ self._timer.start()
+ else:
+ self._timer.stop()
+ self._previousTime = None
+
+ def _rotate(self):
+ """Perform a step of the rotation"""
+ if self._previousTime is None:
+ _logger.error('Previous time not set!')
+ angleStep = 0.
+ else:
+ angleStep = self._DEGREE_PER_SECONDS * (time.time() - self._previousTime)
+
+ self.getPlot3DWidget().viewport.orbitCamera('left', angleStep)
+ self._previousTime = time.time()
diff --git a/src/silx/gui/plot3d/conftest.py b/src/silx/gui/plot3d/conftest.py
new file mode 100644
index 0000000..da02238
--- /dev/null
+++ b/src/silx/gui/plot3d/conftest.py
@@ -0,0 +1,5 @@
+import pytest
+
+@pytest.mark.usefixtures("use_opengl")
+def setup_module(module):
+ pass
diff --git a/src/silx/gui/plot3d/items/__init__.py b/src/silx/gui/plot3d/items/__init__.py
new file mode 100644
index 0000000..e7c4af1
--- /dev/null
+++ b/src/silx/gui/plot3d/items/__init__.py
@@ -0,0 +1,43 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This package provides classes that describes :class:`.SceneWidget` content.
+"""
+
+from __future__ import absolute_import
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "15/11/2017"
+
+
+from .core import DataItem3D, Item3D, GroupItem, GroupWithAxesItem # noqa
+from .core import ItemChangedType, Item3DChangedType # noqa
+from .mixins import (ColormapMixIn, ComplexMixIn, InterpolationMixIn, # noqa
+ PlaneMixIn, SymbolMixIn) # noqa
+from .clipplane import ClipPlane # noqa
+from .image import ImageData, ImageRgba, HeightMapData, HeightMapRGBA # noqa
+from .mesh import Mesh, ColormapMesh, Box, Cylinder, Hexagon # noqa
+from .scatter import Scatter2D, Scatter3D # noqa
+from .volume import ComplexField3D, ScalarField3D # noqa
diff --git a/src/silx/gui/plot3d/items/_pick.py b/src/silx/gui/plot3d/items/_pick.py
new file mode 100644
index 0000000..0d6a495
--- /dev/null
+++ b/src/silx/gui/plot3d/items/_pick.py
@@ -0,0 +1,265 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides classes supporting item picking.
+"""
+
+from __future__ import absolute_import
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "24/09/2018"
+
+import logging
+import numpy
+
+from ...plot.items._pick import PickingResult as _PickingResult
+from ..scene import Viewport, Base
+
+
+_logger = logging.getLogger(__name__)
+
+
+class PickContext(object):
+ """Store information related to current picking
+
+ :param int x: Widget coordinate
+ :param int y: Widget coordinate
+ :param ~silx.gui.plot3d.scene.Viewport viewport:
+ Viewport where picking occurs
+ :param Union[None,callable] condition:
+ Test whether each item needs to be picked or not.
+ """
+
+ def __init__(self, x, y, viewport, condition):
+ self._widgetPosition = x, y
+ assert isinstance(viewport, Viewport)
+ self._viewport = viewport
+ self._ndcZRange = -1., 1.
+ self._enabled = True
+ self._condition = condition
+
+ def copy(self):
+ """Returns a copy
+
+ :rtype: PickContent
+ """
+ x, y = self.getWidgetPosition()
+ context = PickContext(x, y, self.getViewport(), self._condition)
+ context.setNDCZRange(*self._ndcZRange)
+ context.setEnabled(self.isEnabled())
+ return context
+
+ def isItemPickable(self, item):
+ """Check condition for the given item.
+
+ :param Item3D item:
+ :return: Whether to process the item (True) or to skip it (False)
+ :rtype: bool
+ """
+ return self._condition is None or self._condition(item)
+
+ def getViewport(self):
+ """Returns viewport where picking occurs
+
+ :rtype: ~silx.gui.plot3d.scene.Viewport
+ """
+ return self._viewport
+
+ def getWidgetPosition(self):
+ """Returns (x, y) position in pixel in the widget
+
+ Origin is at the top-left corner of the widget,
+ X from left to right, Y goes downward.
+
+ :rtype: List[int]
+ """
+ return self._widgetPosition
+
+ def setEnabled(self, enabled):
+ """Set whether picking is enabled or not
+
+ :param bool enabled: True to enable picking, False otherwise
+ """
+ self._enabled = bool(enabled)
+
+ def isEnabled(self):
+ """Returns True if picking is currently enabled, False otherwise.
+
+ :rtype: bool
+ """
+ return self._enabled
+
+ def setNDCZRange(self, near=-1., far=1.):
+ """Set near and far Z value in normalized device coordinates
+
+ This allows to clip the ray to a subset of the NDC range
+
+ :param float near: Near segment end point Z coordinate
+ :param float far: Far segment end point Z coordinate
+ """
+ self._ndcZRange = near, far
+
+ def getNDCPosition(self):
+ """Return Normalized device coordinates of picked point.
+
+ :return: (x, y) in NDC coordinates or None if outside viewport.
+ :rtype: Union[None,List[float]]
+ """
+ if not self.isEnabled():
+ return None
+
+ # Convert x, y from window to NDC
+ x, y = self.getWidgetPosition()
+ return self.getViewport().windowToNdc(x, y, checkInside=True)
+
+ def getPickingSegment(self, frame):
+ """Returns picking segment in requested coordinate frame.
+
+ :param Union[str,Base] frame:
+ The frame in which to get the picking segment,
+ either a keyword: 'ndc', 'camera', 'scene' or a scene
+ :class:`~silx.gui.plot3d.scene.Base` object.
+ :return: Near and far points of the segment as (x, y, z, w)
+ or None if picked point is outside viewport
+ :rtype: Union[None,numpy.ndarray]
+ """
+ assert frame in ('ndc', 'camera', 'scene') or isinstance(frame, Base)
+
+ positionNdc = self.getNDCPosition()
+ if positionNdc is None:
+ return None
+
+ near, far = self._ndcZRange
+ rayNdc = numpy.array((positionNdc + (near, 1.),
+ positionNdc + (far, 1.)),
+ dtype=numpy.float64)
+ if frame == 'ndc':
+ return rayNdc
+
+ viewport = self.getViewport()
+
+ rayCamera = viewport.camera.intrinsic.transformPoints(
+ rayNdc,
+ direct=False,
+ perspectiveDivide=True)
+ if frame == 'camera':
+ return rayCamera
+
+ rayScene = viewport.camera.extrinsic.transformPoints(
+ rayCamera, direct=False)
+ if frame == 'scene':
+ return rayScene
+
+ # frame is a scene Base object
+ rayObject = frame.objectToSceneTransform.transformPoints(
+ rayScene, direct=False)
+ return rayObject
+
+
+class PickingResult(_PickingResult):
+ """Class to access picking information in a 3D scene."""
+
+ def __init__(self, item, positions, indices=None, fetchdata=None):
+ """Init
+
+ :param ~silx.gui.plot3d.items.Item3D item: The picked item
+ :param numpy.ndarray positions:
+ Nx3 array-like of picked positions (x, y, z) in item coordinates.
+ :param numpy.ndarray indices: Array-like of indices of picked data.
+ Either 1D or 2D with dim0: data dimension and dim1: indices.
+ No copy is made.
+ :param callable fetchdata: Optional function with a bool copy argument
+ to provide an alternative function to access item data.
+ Default is to use `item.getData`.
+ """
+ super(PickingResult, self).__init__(item, indices)
+
+ self._objectPositions = numpy.array(
+ positions, copy=False, dtype=numpy.float64)
+
+ # Store matrices to generate positions on demand
+ primitive = item._getScenePrimitive()
+ self._objectToSceneTransform = primitive.objectToSceneTransform
+ self._objectToNDCTransform = primitive.objectToNDCTransform
+ self._scenePositions = None
+ self._ndcPositions = None
+
+ self._fetchdata = fetchdata
+
+ def getData(self, copy=True):
+ """Returns picked data values
+
+ :param bool copy: True (default) to get a copy,
+ False to return internal arrays
+ :rtype: Union[None,numpy.ndarray]
+ """
+
+ indices = self.getIndices(copy=False)
+ if indices is None or len(indices) == 0:
+ return None
+
+ item = self.getItem()
+ if self._fetchdata is None:
+ if hasattr(item, 'getData'):
+ data = item.getData(copy=False)
+ else:
+ return None
+ else:
+ data = self._fetchdata(copy=False)
+
+ return numpy.array(data[indices], copy=copy)
+
+ def getPositions(self, frame='scene', copy=True):
+ """Returns picking positions in item coordinates.
+
+ :param str frame: The frame in which the positions are returned
+ Either 'scene' for world space,
+ 'ndc' for normalized device coordinates or 'object' for item frame.
+ :param bool copy: True (default) to get a copy,
+ False to return internal arrays
+ :return: Nx3 array of (x, y, z) coordinates
+ :rtype: numpy.ndarray
+ """
+ if frame == 'ndc':
+ if self._ndcPositions is None: # Lazy-loading
+ self._ndcPositions = self._objectToNDCTransform.transformPoints(
+ self._objectPositions, perspectiveDivide=True)
+
+ positions = self._ndcPositions
+
+ elif frame == 'scene':
+ if self._scenePositions is None: # Lazy-loading
+ self._scenePositions = self._objectToSceneTransform.transformPoints(
+ self._objectPositions)
+
+ positions = self._scenePositions
+
+ elif frame == 'object':
+ positions = self._objectPositions
+
+ else:
+ raise ValueError('Unsupported frame argument: %s' % str(frame))
+
+ return numpy.array(positions, copy=copy)
diff --git a/src/silx/gui/plot3d/items/clipplane.py b/src/silx/gui/plot3d/items/clipplane.py
new file mode 100644
index 0000000..3e819d0
--- /dev/null
+++ b/src/silx/gui/plot3d/items/clipplane.py
@@ -0,0 +1,136 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a scene clip plane class.
+"""
+
+from __future__ import absolute_import
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "15/11/2017"
+
+
+import numpy
+
+from ..scene import primitives, utils
+
+from ._pick import PickingResult
+from .core import Item3D
+from .mixins import PlaneMixIn
+
+
+class ClipPlane(Item3D, PlaneMixIn):
+ """Represents a clipping plane that clips following items within the group.
+
+ For now only on clip plane is allowed at once in a scene.
+ """
+
+ def __init__(self, parent=None):
+ plane = primitives.ClipPlane()
+ Item3D.__init__(self, parent=parent, primitive=plane)
+ PlaneMixIn.__init__(self, plane=plane)
+
+ def __pickPreProcessing(self, context):
+ """Common processing for :meth:`_pickPostProcess` and :meth:`_pickFull`
+
+ :param PickContext context: Current picking context
+ :return None or (bounds, intersection points, rayObject)
+ """
+ plane = self._getPlane()
+ planeParent = plane.parent
+ if planeParent is None:
+ return None
+
+ rayObject = context.getPickingSegment(frame=plane)
+ if rayObject is None:
+ return None
+
+ bounds = planeParent.bounds(dataBounds=True)
+ rayClip = utils.clipSegmentToBounds(rayObject[:, :3], bounds)
+ if rayClip is None:
+ return None # Ray is outside parent's bounding box
+
+ points = utils.segmentPlaneIntersect(
+ rayObject[0, :3],
+ rayObject[1, :3],
+ planeNorm=self.getNormal(),
+ planePt=self.getPoint())
+
+ # A single intersection inside bounding box
+ picked = (len(points) == 1 and
+ numpy.all(bounds[0] <= points[0]) and
+ numpy.all(points[0] <= bounds[1]))
+
+ return picked, points, rayObject
+
+ def _pick(self, context):
+ # Perform picking before modifying context
+ result = super(ClipPlane, self)._pick(context)
+
+ # Modify context if needed
+ if self.isVisible() and context.isEnabled():
+ info = self.__pickPreProcessing(context)
+ if info is not None:
+ picked, points, rayObject = info
+ plane = self._getPlane()
+
+ if picked: # A single intersection inside bounding box
+ # Clip NDC z range for following brother items
+ ndcIntersect = plane.objectToNDCTransform.transformPoint(
+ points[0], perspectiveDivide=True)
+ ndcNormal = plane.objectToNDCTransform.transformNormal(
+ self.getNormal())
+ if ndcNormal[2] < 0:
+ context.setNDCZRange(-1., ndcIntersect[2])
+ else:
+ context.setNDCZRange(ndcIntersect[2], 1.)
+
+ else:
+ # TODO check this might not be correct
+ rayObject[:, 3] = 1. # Make sure 4h coordinate is one
+ if numpy.sum(rayObject[0] * self.getParameters()) < 0.:
+ # Disable picking for remaining brothers
+ context.setEnabled(False)
+
+ return result
+
+ def _pickFastCheck(self, context):
+ return True
+
+ def _pickFull(self, context):
+ """Perform picking in this item at given widget position.
+
+ :param PickContext context: Current picking context
+ :return: Object holding the results or None
+ :rtype: Union[None,PickingResult]
+ """
+ info = self.__pickPreProcessing(context)
+ if info is not None:
+ picked, points, _ = info
+
+ if picked:
+ return PickingResult(self, positions=[points[0]])
+
+ return None
diff --git a/src/silx/gui/plot3d/items/core.py b/src/silx/gui/plot3d/items/core.py
new file mode 100644
index 0000000..0388ce7
--- /dev/null
+++ b/src/silx/gui/plot3d/items/core.py
@@ -0,0 +1,778 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides the base class for items of the :class:`.SceneWidget`.
+"""
+
+from __future__ import absolute_import
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "15/11/2017"
+
+from collections import defaultdict
+import enum
+
+import numpy
+
+from ... import qt
+from ...plot.items import ItemChangedType
+from .. import scene
+from ..scene import axes, primitives, transform
+from ._pick import PickContext
+
+
+@enum.unique
+class Item3DChangedType(enum.Enum):
+ """Type of modification provided by :attr:`Item3D.sigItemChanged` signal."""
+
+ INTERPOLATION = 'interpolationChanged'
+ """Item3D image interpolation changed flag."""
+
+ TRANSFORM = 'transformChanged'
+ """Item3D transform changed flag."""
+
+ HEIGHT_MAP = 'heightMapChanged'
+ """Item3D height map changed flag."""
+
+ ISO_LEVEL = 'isoLevelChanged'
+ """Isosurface level changed flag."""
+
+ LABEL = 'labelChanged'
+ """Item's label changed flag."""
+
+ BOUNDING_BOX_VISIBLE = 'boundingBoxVisibleChanged'
+ """Item's bounding box visibility changed"""
+
+ ROOT_ITEM = 'rootItemChanged'
+ """Item's root changed flag."""
+
+
+class Item3D(qt.QObject):
+ """Base class representing an item in the scene.
+
+ :param parent: The View widget this item belongs to.
+ :param primitive: An optional primitive to use as scene primitive
+ """
+
+ _LABEL_INDICES = defaultdict(int)
+ """Store per class label indices"""
+
+ sigItemChanged = qt.Signal(object)
+ """Signal emitted when an item's property has changed.
+
+ It provides a flag describing which property of the item has changed.
+ See :class:`ItemChangedType` and :class:`Item3DChangedType`
+ for flags description.
+ """
+
+ def __init__(self, parent, primitive=None):
+ qt.QObject.__init__(self, parent)
+
+ if primitive is None:
+ primitive = scene.Group()
+
+ self._primitive = primitive
+
+ self.__syncForegroundColor()
+
+ labelIndex = self._LABEL_INDICES[self.__class__]
+ self._label = str(self.__class__.__name__)
+ if labelIndex != 0:
+ self._label += u' %d' % labelIndex
+ self._LABEL_INDICES[self.__class__] += 1
+
+ if isinstance(parent, Item3D):
+ parent.sigItemChanged.connect(self.__parentItemChanged)
+
+ def setParent(self, parent):
+ """Override set parent to handle root item change"""
+ previousParent = self.parent()
+ if isinstance(previousParent, Item3D):
+ previousParent.sigItemChanged.disconnect(self.__parentItemChanged)
+
+ super(Item3D, self).setParent(parent)
+
+ if isinstance(parent, Item3D):
+ parent.sigItemChanged.connect(self.__parentItemChanged)
+
+ self._updated(Item3DChangedType.ROOT_ITEM)
+
+ def __parentItemChanged(self, event):
+ """Handle updates of the parent if it is an Item3D
+
+ :param Item3DChangedType event:
+ """
+ if event == Item3DChangedType.ROOT_ITEM:
+ self._updated(Item3DChangedType.ROOT_ITEM)
+
+ def root(self):
+ """Returns the root of the scene this item belongs to.
+
+ The root is the up-most Item3D in the scene tree hierarchy.
+
+ :rtype: Union[Item3D, None]
+ """
+ root = None
+ ancestor = self.parent()
+ while isinstance(ancestor, Item3D):
+ root = ancestor
+ ancestor = ancestor.parent()
+
+ return root
+
+ def _getScenePrimitive(self):
+ """Return the group containing the item rendering"""
+ return self._primitive
+
+ def _updated(self, event=None):
+ """Handle MixIn class updates.
+
+ :param event: The event to send to :attr:`sigItemChanged` signal.
+ """
+ if event == Item3DChangedType.ROOT_ITEM:
+ self.__syncForegroundColor()
+
+ if event is not None:
+ self.sigItemChanged.emit(event)
+
+ # Label
+
+ def getLabel(self):
+ """Returns the label associated to this item.
+
+ :rtype: str
+ """
+ return self._label
+
+ def setLabel(self, label):
+ """Set the label associated to this item.
+
+ :param str label:
+ """
+ label = str(label)
+ if label != self._label:
+ self._label = label
+ self._updated(Item3DChangedType.LABEL)
+
+ # Visibility
+
+ def isVisible(self):
+ """Returns True if item is visible, else False
+
+ :rtype: bool
+ """
+ return self._getScenePrimitive().visible
+
+ def setVisible(self, visible=True):
+ """Set the visibility of the item in the scene.
+
+ :param bool visible: True (default) to show the item, False to hide
+ """
+ visible = bool(visible)
+ primitive = self._getScenePrimitive()
+ if visible != primitive.visible:
+ primitive.visible = visible
+ self._updated(ItemChangedType.VISIBLE)
+
+ # Foreground color
+
+ def _setForegroundColor(self, color):
+ """Set the foreground color of the item.
+
+ The default implementation does nothing, override it in subclass.
+
+ :param color: RGBA color
+ :type color: tuple of 4 float in [0., 1.]
+ """
+ if hasattr(super(Item3D, self), '_setForegroundColor'):
+ super(Item3D, self)._setForegroundColor(color)
+
+ def __syncForegroundColor(self):
+ """Retrieve foreground color from parent and update this item"""
+ # Look-up for SceneWidget to get its foreground color
+ root = self.root()
+ if root is not None:
+ widget = root.parent()
+ if isinstance(widget, qt.QWidget):
+ self._setForegroundColor(
+ widget.getForegroundColor().getRgbF())
+
+ # picking
+
+ def _pick(self, context):
+ """Implement picking on this item.
+
+ :param PickContext context: Current picking context
+ :return: Data indices at picked position or None
+ :rtype: Union[None,PickingResult]
+ """
+ if (self.isVisible() and
+ context.isEnabled() and
+ context.isItemPickable(self) and
+ self._pickFastCheck(context)):
+ return self._pickFull(context)
+ return None
+
+ def _pickFastCheck(self, context):
+ """Approximate item pick test (e.g., bounding box-based picking).
+
+ :param PickContext context: Current picking context
+ :return: True if item might be picked
+ :rtype: bool
+ """
+ primitive = self._getScenePrimitive()
+
+ positionNdc = context.getNDCPosition()
+ if positionNdc is None: # No picking outside viewport
+ return False
+
+ bounds = primitive.bounds(transformed=False, dataBounds=False)
+ if bounds is None: # primitive has no bounds
+ return False
+
+ bounds = primitive.objectToNDCTransform.transformBounds(bounds)
+
+ return (bounds[0, 0] <= positionNdc[0] <= bounds[1, 0] and
+ bounds[0, 1] <= positionNdc[1] <= bounds[1, 1])
+
+ def _pickFull(self, context):
+ """Perform precise picking in this item at given widget position.
+
+ :param PickContext context: Current picking context
+ :return: Object holding the results or None
+ :rtype: Union[None,PickingResult]
+ """
+ return None
+
+
+class DataItem3D(Item3D):
+ """Base class representing a data item with transform in the scene.
+
+ :param parent: The View widget this item belongs to.
+ :param Union[GroupBBox, None] group:
+ The scene group to use for rendering
+ """
+
+ def __init__(self, parent, group=None):
+ if group is None:
+ group = primitives.GroupBBox()
+
+ # Set-up bounding box
+ group.boxVisible = False
+ group.axesVisible = False
+ else:
+ assert isinstance(group, primitives.GroupBBox)
+
+ Item3D.__init__(self, parent=parent, primitive=group)
+
+ # Transformations
+ self._translate = transform.Translate()
+ self._rotateForwardTranslation = transform.Translate()
+ self._rotate = transform.Rotate()
+ self._rotateBackwardTranslation = transform.Translate()
+ self._translateFromRotationCenter = transform.Translate()
+ self._matrix = transform.Matrix()
+ self._scale = transform.Scale()
+ # Group transforms to do to data before rotation
+ # This is useful to handle rotation center relative to bbox
+ self._transformObjectToRotate = transform.TransformList(
+ [self._matrix, self._scale])
+ self._transformObjectToRotate.addListener(self._updateRotationCenter)
+
+ self._rotationCenter = 0., 0., 0.
+
+ self.__transforms = transform.TransformList([
+ self._translate,
+ self._rotateForwardTranslation,
+ self._rotate,
+ self._rotateBackwardTranslation,
+ self._transformObjectToRotate])
+
+ self._getScenePrimitive().transforms = self.__transforms
+
+ def _updated(self, event=None):
+ """Handle MixIn class updates.
+
+ :param event: The event to send to :attr:`sigItemChanged` signal.
+ """
+ if event == ItemChangedType.DATA:
+ self._updateRotationCenter()
+ super(DataItem3D, self)._updated(event)
+
+ # Transformations
+
+ def _getSceneTransforms(self):
+ """Return TransformList corresponding to current transforms
+
+ :rtype: TransformList
+ """
+ return self.__transforms
+
+ def setScale(self, sx=1., sy=1., sz=1.):
+ """Set the scale of the item in the scene.
+
+ :param float sx: Scale factor along the X axis
+ :param float sy: Scale factor along the Y axis
+ :param float sz: Scale factor along the Z axis
+ """
+ scale = numpy.array((sx, sy, sz), dtype=numpy.float32)
+ if not numpy.all(numpy.equal(scale, self.getScale())):
+ self._scale.scale = scale
+ self._updated(Item3DChangedType.TRANSFORM)
+
+ def getScale(self):
+ """Returns the scales provided by :meth:`setScale`.
+
+ :rtype: numpy.ndarray
+ """
+ return self._scale.scale
+
+ def setTranslation(self, x=0., y=0., z=0.):
+ """Set the translation of the origin of the item in the scene.
+
+ :param float x: Offset of the data origin on the X axis
+ :param float y: Offset of the data origin on the Y axis
+ :param float z: Offset of the data origin on the Z axis
+ """
+ translation = numpy.array((x, y, z), dtype=numpy.float32)
+ if not numpy.all(numpy.equal(translation, self.getTranslation())):
+ self._translate.translation = translation
+ self._updated(Item3DChangedType.TRANSFORM)
+
+ def getTranslation(self):
+ """Returns the offset set by :meth:`setTranslation`.
+
+ :rtype: numpy.ndarray
+ """
+ return self._translate.translation
+
+ _ROTATION_CENTER_TAGS = 'lower', 'center', 'upper'
+
+ def _updateRotationCenter(self, *args, **kwargs):
+ """Update rotation center relative to bounding box"""
+ center = []
+ for index, position in enumerate(self.getRotationCenter()):
+ # Patch position relative to bounding box
+ if position in self._ROTATION_CENTER_TAGS:
+ bounds = self._getScenePrimitive().bounds(
+ transformed=False, dataBounds=True)
+ bounds = self._transformObjectToRotate.transformBounds(bounds)
+
+ if bounds is None:
+ position = 0.
+ elif position == 'lower':
+ position = bounds[0, index]
+ elif position == 'center':
+ position = 0.5 * (bounds[0, index] + bounds[1, index])
+ elif position == 'upper':
+ position = bounds[1, index]
+
+ center.append(position)
+
+ if not numpy.all(numpy.equal(
+ center, self._rotateForwardTranslation.translation)):
+ self._rotateForwardTranslation.translation = center
+ self._rotateBackwardTranslation.translation = \
+ - self._rotateForwardTranslation.translation
+ self._updated(Item3DChangedType.TRANSFORM)
+
+ def setRotationCenter(self, x=0., y=0., z=0.):
+ """Set the center of rotation of the item.
+
+ Position of the rotation center is either a float
+ for an absolute position or one of the following
+ string to define a position relative to the item's bounding box:
+ 'lower', 'center', 'upper'
+
+ :param x: rotation center position on the X axis
+ :rtype: float or str
+ :param y: rotation center position on the Y axis
+ :rtype: float or str
+ :param z: rotation center position on the Z axis
+ :rtype: float or str
+ """
+ center = []
+ for position in (x, y, z):
+ if isinstance(position, str):
+ assert position in self._ROTATION_CENTER_TAGS
+ else:
+ position = float(position)
+ center.append(position)
+ center = tuple(center)
+
+ if center != self._rotationCenter:
+ self._rotationCenter = center
+ self._updateRotationCenter()
+
+ def getRotationCenter(self):
+ """Returns the rotation center set by :meth:`setRotationCenter`.
+
+ :rtype: 3-tuple of float or str
+ """
+ return self._rotationCenter
+
+ def setRotation(self, angle=0., axis=(0., 0., 1.)):
+ """Set the rotation of the item in the scene
+
+ :param float angle: The rotation angle in degrees.
+ :param axis: The (x, y, z) coordinates of the rotation axis.
+ """
+ axis = numpy.array(axis, dtype=numpy.float32)
+ assert axis.ndim == 1
+ assert axis.size == 3
+ if (self._rotate.angle != angle or
+ not numpy.all(numpy.equal(axis, self._rotate.axis))):
+ self._rotate.setAngleAxis(angle, axis)
+ self._updated(Item3DChangedType.TRANSFORM)
+
+ def getRotation(self):
+ """Returns the rotation set by :meth:`setRotation`.
+
+ :return: (angle, axis)
+ :rtype: 2-tuple (float, numpy.ndarray)
+ """
+ return self._rotate.angle, self._rotate.axis
+
+ def setMatrix(self, matrix=None):
+ """Set the transform matrix
+
+ :param numpy.ndarray matrix: 3x3 transform matrix
+ """
+ matrix4x4 = numpy.identity(4, dtype=numpy.float32)
+
+ if matrix is not None:
+ matrix = numpy.array(matrix, dtype=numpy.float32)
+ assert matrix.shape == (3, 3)
+ matrix4x4[:3, :3] = matrix
+
+ if not numpy.all(numpy.equal(matrix4x4, self._matrix.getMatrix())):
+ self._matrix.setMatrix(matrix4x4)
+ self._updated(Item3DChangedType.TRANSFORM)
+
+ def getMatrix(self):
+ """Returns the matrix set by :meth:`setMatrix`
+
+ :return: 3x3 matrix
+ :rtype: numpy.ndarray"""
+ return self._matrix.getMatrix(copy=True)[:3, :3]
+
+ # Bounding box
+
+ def _setForegroundColor(self, color):
+ """Set the color of the bounding box
+
+ :param color: RGBA color as 4 floats in [0, 1]
+ """
+ self._getScenePrimitive().color = color
+ super(DataItem3D, self)._setForegroundColor(color)
+
+ def isBoundingBoxVisible(self):
+ """Returns item's bounding box visibility.
+
+ :rtype: bool
+ """
+ return self._getScenePrimitive().boxVisible
+
+ def setBoundingBoxVisible(self, visible):
+ """Set item's bounding box visibility.
+
+ :param bool visible:
+ True to show the bounding box, False (default) to hide it
+ """
+ visible = bool(visible)
+ primitive = self._getScenePrimitive()
+ if visible != primitive.boxVisible:
+ primitive.boxVisible = visible
+ self._updated(Item3DChangedType.BOUNDING_BOX_VISIBLE)
+
+
+class BaseNodeItem(DataItem3D):
+ """Base class for data item having children (e.g., group, 3d volume)."""
+
+ def __init__(self, parent=None, group=None):
+ """Base class representing a group of items in the scene.
+
+ :param parent: The View widget this item belongs to.
+ :param Union[GroupBBox, None] group:
+ The scene group to use for rendering
+ """
+ DataItem3D.__init__(self, parent=parent, group=group)
+
+ def getItems(self):
+ """Returns the list of items currently present in the group.
+
+ :rtype: tuple
+ """
+ raise NotImplementedError('getItems must be implemented in subclass')
+
+ def visit(self, included=True):
+ """Generator visiting the group content.
+
+ It traverses the group sub-tree in a top-down left-to-right way.
+
+ :param bool included: True (default) to include self in visit
+ """
+ if included:
+ yield self
+ for child in self.getItems():
+ yield child
+ if hasattr(child, 'visit'):
+ for item in child.visit(included=False):
+ yield item
+
+ def pickItems(self, x, y, condition=None):
+ """Iterator over picked items in the group at given position.
+
+ Each picked item yield a :class:`PickingResult` object
+ holding the picking information.
+
+ It traverses the group sub-tree in a left-to-right top-down way.
+
+ :param int x: X widget device pixel coordinate
+ :param int y: Y widget device pixel coordinate
+ :param callable condition: Optional test called for each item
+ checking whether to process it or not.
+ """
+ viewport = self._getScenePrimitive().viewport
+ if viewport is None:
+ raise RuntimeError(
+ 'Cannot perform picking: Item not attached to a widget')
+
+ context = PickContext(x, y, viewport, condition)
+ for result in self._pickItems(context):
+ yield result
+
+ def _pickItems(self, context):
+ """Implement :meth:`pickItems`
+
+ :param PickContext context: Current picking context
+ """
+ if not self.isVisible() or not context.isEnabled():
+ return # empty iterator
+
+ # Use a copy to discard context changes once this returns
+ context = context.copy()
+
+ if not self._pickFastCheck(context):
+ return # empty iterator
+
+ result = self._pick(context)
+ if result is not None:
+ yield result
+
+ for child in self.getItems():
+ if isinstance(child, BaseNodeItem):
+ for result in child._pickItems(context):
+ yield result # Flatten result
+
+ else:
+ result = child._pick(context)
+ if result is not None:
+ yield result
+
+
+class _BaseGroupItem(BaseNodeItem):
+ """Base class for group of items sharing a common transform."""
+
+ sigItemAdded = qt.Signal(object)
+ """Signal emitted when a new item is added to the group.
+
+ The newly added item is provided by this signal
+ """
+
+ sigItemRemoved = qt.Signal(object)
+ """Signal emitted when an item is removed from the group.
+
+ The removed item is provided by this signal.
+ """
+
+ def __init__(self, parent=None, group=None):
+ """Base class representing a group of items in the scene.
+
+ :param parent: The View widget this item belongs to.
+ :param Union[GroupBBox, None] group:
+ The scene group to use for rendering
+ """
+ BaseNodeItem.__init__(self, parent=parent, group=group)
+ self._items = []
+
+ def _getGroupPrimitive(self):
+ """Returns the group for which to handle children.
+
+ This allows this group to be different from the primitive.
+ """
+ return self._getScenePrimitive()
+
+ def addItem(self, item, index=None):
+ """Add an item to the group
+
+ :param Item3D item: The item to add
+ :param int index: The index at which to place the item.
+ By default it is appended to the end of the list.
+ :raise ValueError: If the item is already in the group.
+ """
+ assert isinstance(item, Item3D)
+ assert item.parent() in (None, self)
+
+ if item in self.getItems():
+ raise ValueError("Item3D already in group: %s" % item)
+
+ item.setParent(self)
+ if index is None:
+ self._getGroupPrimitive().children.append(
+ item._getScenePrimitive())
+ self._items.append(item)
+ else:
+ self._getGroupPrimitive().children.insert(
+ index, item._getScenePrimitive())
+ self._items.insert(index, item)
+ self.sigItemAdded.emit(item)
+
+ def getItems(self):
+ """Returns the list of items currently present in the group.
+
+ :rtype: tuple
+ """
+ return tuple(self._items)
+
+ def removeItem(self, item):
+ """Remove an item from the scene.
+
+ :param Item3D item: The item to remove from the scene
+ :raises ValueError: If the item does not belong to the group
+ """
+ if item not in self.getItems():
+ raise ValueError("Item3D not in group: %s" % str(item))
+
+ self._getGroupPrimitive().children.remove(item._getScenePrimitive())
+ self._items.remove(item)
+ item.setParent(None)
+ self.sigItemRemoved.emit(item)
+
+ def clearItems(self):
+ """Remove all item from the group."""
+ for item in self.getItems():
+ self.removeItem(item)
+
+
+class GroupItem(_BaseGroupItem):
+ """Group of items sharing a common transform."""
+
+ def __init__(self, parent=None):
+ super(GroupItem, self).__init__(parent=parent)
+
+
+class GroupWithAxesItem(_BaseGroupItem):
+ """
+ Group of items sharing a common transform surrounded with labelled axes.
+ """
+
+ def __init__(self, parent=None):
+ """Class representing a group of items in the scene with labelled axes.
+
+ :param parent: The View widget this item belongs to.
+ """
+ super(GroupWithAxesItem, self).__init__(parent=parent,
+ group=axes.LabelledAxes())
+
+ # Axes labels
+
+ def setAxesLabels(self, xlabel=None, ylabel=None, zlabel=None):
+ """Set the text labels of the axes.
+
+ :param str xlabel: Label of the X axis, None to leave unchanged.
+ :param str ylabel: Label of the Y axis, None to leave unchanged.
+ :param str zlabel: Label of the Z axis, None to leave unchanged.
+ """
+ labelledAxes = self._getScenePrimitive()
+ if xlabel is not None:
+ labelledAxes.xlabel = xlabel
+
+ if ylabel is not None:
+ labelledAxes.ylabel = ylabel
+
+ if zlabel is not None:
+ labelledAxes.zlabel = zlabel
+
+ class _Labels(tuple):
+ """Return type of :meth:`getAxesLabels`"""
+
+ def getXLabel(self):
+ """Label of the X axis (str)"""
+ return self[0]
+
+ def getYLabel(self):
+ """Label of the Y axis (str)"""
+ return self[1]
+
+ def getZLabel(self):
+ """Label of the Z axis (str)"""
+ return self[2]
+
+ def getAxesLabels(self):
+ """Returns the text labels of the axes
+
+ >>> group = GroupWithAxesItem()
+ >>> group.setAxesLabels(xlabel='X')
+
+ You can get the labels either as a 3-tuple:
+
+ >>> xlabel, ylabel, zlabel = group.getAxesLabels()
+
+ Or as an object with methods getXLabel, getYLabel and getZLabel:
+
+ >>> labels = group.getAxesLabels()
+ >>> labels.getXLabel()
+ ... 'X'
+
+ :return: object describing the labels
+ """
+ labelledAxes = self._getScenePrimitive()
+ return self._Labels((labelledAxes.xlabel,
+ labelledAxes.ylabel,
+ labelledAxes.zlabel))
+
+
+class RootGroupWithAxesItem(GroupWithAxesItem):
+ """Special group with axes item for root of the scene.
+
+ Uses 2 groups so that axes take transforms into account.
+ """
+
+ def __init__(self, parent=None):
+ super(RootGroupWithAxesItem, self).__init__(parent)
+ self.__group = scene.Group()
+ self.__group.transforms = self._getSceneTransforms()
+
+ groupWithAxes = self._getScenePrimitive()
+ groupWithAxes.transforms = [] # Do not apply transforms here
+ groupWithAxes.children.append(self.__group)
+
+ def _getGroupPrimitive(self):
+ """Returns the group for which to handle children.
+
+ This allows this group to be different from the primitive.
+ """
+ return self.__group
diff --git a/src/silx/gui/plot3d/items/image.py b/src/silx/gui/plot3d/items/image.py
new file mode 100644
index 0000000..5a50459
--- /dev/null
+++ b/src/silx/gui/plot3d/items/image.py
@@ -0,0 +1,425 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides 2D data and RGB(A) image item class.
+"""
+
+from __future__ import absolute_import
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "15/11/2017"
+
+import numpy
+
+from ..scene import primitives, utils
+from .core import DataItem3D, ItemChangedType
+from .mixins import ColormapMixIn, InterpolationMixIn
+from ._pick import PickingResult
+
+
+class _Image(DataItem3D, InterpolationMixIn):
+ """Base class for images
+
+ :param parent: The View widget this item belongs to.
+ """
+
+ def __init__(self, parent=None):
+ DataItem3D.__init__(self, parent=parent)
+ InterpolationMixIn.__init__(self)
+
+ def _setPrimitive(self, primitive):
+ InterpolationMixIn._setPrimitive(self, primitive)
+
+ def getData(self, copy=True):
+ raise NotImplementedError()
+
+ def _pickFull(self, context):
+ """Perform picking in this item at given widget position.
+
+ :param PickContext context: Current picking context
+ :return: Object holding the results or None
+ :rtype: Union[None,PickingResult]
+ """
+ rayObject = context.getPickingSegment(frame=self._getScenePrimitive())
+ if rayObject is None:
+ return None
+
+ points = utils.segmentPlaneIntersect(
+ rayObject[0, :3],
+ rayObject[1, :3],
+ planeNorm=numpy.array((0., 0., 1.), dtype=numpy.float64),
+ planePt=numpy.array((0., 0., 0.), dtype=numpy.float64))
+
+ if len(points) == 1: # Single intersection
+ if points[0][0] < 0. or points[0][1] < 0.:
+ return None # Outside image
+ row, column = int(points[0][1]), int(points[0][0])
+ data = self.getData(copy=False)
+ height, width = data.shape[:2]
+ if row < height and column < width:
+ return PickingResult(
+ self,
+ positions=[(points[0][0], points[0][1], 0.)],
+ indices=([row], [column]))
+ else:
+ return None # Outside image
+ else: # Either no intersection or segment and image are coplanar
+ return None
+
+
+class ImageData(_Image, ColormapMixIn):
+ """Description of a 2D image data.
+
+ :param parent: The View widget this item belongs to.
+ """
+
+ def __init__(self, parent=None):
+ _Image.__init__(self, parent=parent)
+ ColormapMixIn.__init__(self)
+
+ self._data = numpy.zeros((0, 0), dtype=numpy.float32)
+
+ self._image = primitives.ImageData(self._data)
+ self._getScenePrimitive().children.append(self._image)
+
+ # Connect scene primitive to mix-in class
+ ColormapMixIn._setSceneColormap(self, self._image.colormap)
+ _Image._setPrimitive(self, self._image)
+
+ def setData(self, data, copy=True):
+ """Set the image data to display.
+
+ The data will be casted to float32.
+
+ :param numpy.ndarray data: The image data
+ :param bool copy: True (default) to copy the data,
+ False to use as is (do not modify!).
+ """
+ self._image.setData(data, copy=copy)
+ self._setColormappedData(self.getData(copy=False), copy=False)
+ self._updated(ItemChangedType.DATA)
+
+ def getData(self, copy=True):
+ """Get the image data.
+
+ :param bool copy:
+ True (default) to get a copy,
+ False to get internal representation (do not modify!).
+ :rtype: numpy.ndarray
+ :return: The image data
+ """
+ return self._image.getData(copy=copy)
+
+
+class ImageRgba(_Image, InterpolationMixIn):
+ """Description of a 2D data RGB(A) image.
+
+ :param parent: The View widget this item belongs to.
+ """
+
+ def __init__(self, parent=None):
+ _Image.__init__(self, parent=parent)
+ InterpolationMixIn.__init__(self)
+
+ self._data = numpy.zeros((0, 0, 3), dtype=numpy.float32)
+
+ self._image = primitives.ImageRgba(self._data)
+ self._getScenePrimitive().children.append(self._image)
+
+ # Connect scene primitive to mix-in class
+ _Image._setPrimitive(self, self._image)
+
+ def setData(self, data, copy=True):
+ """Set the RGB(A) image data to display.
+
+ Supported array format: float32 in [0, 1], uint8.
+
+ :param numpy.ndarray data:
+ The RGBA image data as an array of shape (H, W, Channels)
+ :param bool copy: True (default) to copy the data,
+ False to use as is (do not modify!).
+ """
+ self._image.setData(data, copy=copy)
+ self._updated(ItemChangedType.DATA)
+
+ def getData(self, copy=True):
+ """Get the image data.
+
+ :param bool copy:
+ True (default) to get a copy,
+ False to get internal representation (do not modify!).
+ :rtype: numpy.ndarray
+ :return: The image data
+ """
+ return self._image.getData(copy=copy)
+
+
+class _HeightMap(DataItem3D):
+ """Base class for 2D data array displayed as a height field.
+
+ :param parent: The View widget this item belongs to.
+ """
+
+ def __init__(self, parent=None):
+ DataItem3D.__init__(self, parent=parent)
+ self.__data = numpy.zeros((0, 0), dtype=numpy.float32)
+
+ def _pickFull(self, context, threshold=0., sort='depth'):
+ """Perform picking in this item at given widget position.
+
+ :param PickContext context: Current picking context
+ :param float threshold: Picking threshold in pixel.
+ Perform picking in a square of size threshold x threshold.
+ :param str sort: How returned indices are sorted:
+
+ - 'index' (default): sort by the value of the indices
+ - 'depth': Sort by the depth of the points from the current
+ camera point of view.
+ :return: Object holding the results or None
+ :rtype: Union[None,PickingResult]
+ """
+ assert sort in ('index', 'depth')
+
+ rayNdc = context.getPickingSegment(frame='ndc')
+ if rayNdc is None: # No picking outside viewport
+ return None
+
+ # TODO no colormapped or color data
+ # Project data to NDC
+ heightData = self.getData(copy=False)
+ if heightData.size == 0:
+ return # Nothing displayed
+
+ height, width = heightData.shape
+ z = numpy.ravel(heightData)
+ y, x = numpy.mgrid[0:height, 0:width]
+ dataPoints = numpy.transpose((numpy.ravel(x),
+ numpy.ravel(y),
+ z,
+ numpy.ones_like(z)))
+
+ primitive = self._getScenePrimitive()
+
+ pointsNdc = primitive.objectToNDCTransform.transformPoints(
+ dataPoints, perspectiveDivide=True)
+
+ # Perform picking
+ distancesNdc = numpy.abs(pointsNdc[:, :2] - rayNdc[0, :2])
+ # TODO issue with symbol size: using pixel instead of points
+ threshold += 1. # symbol size
+ thresholdNdc = 2. * threshold / numpy.array(primitive.viewport.size)
+ picked = numpy.where(numpy.logical_and(
+ numpy.all(distancesNdc < thresholdNdc, axis=1),
+ numpy.logical_and(rayNdc[0, 2] <= pointsNdc[:, 2],
+ pointsNdc[:, 2] <= rayNdc[1, 2])))[0]
+
+ if sort == 'depth':
+ # Sort picked points from front to back
+ picked = picked[numpy.argsort(pointsNdc[picked, 2])]
+
+ if picked.size > 0:
+ # Convert indices from 1D to 2D
+ return PickingResult(self,
+ positions=dataPoints[picked, :3],
+ indices=(picked // width, picked % width),
+ fetchdata=self.getData)
+ else:
+ return None
+
+ def setData(self, data, copy: bool=True):
+ """Set the height field data.
+
+ :param data:
+ :param copy: True (default) to copy the data,
+ False to use as is (do not modify!).
+ """
+ data = numpy.array(data, copy=copy)
+ assert data.ndim == 2
+
+ self.__data = data
+ self._updated(ItemChangedType.DATA)
+
+ def getData(self, copy: bool=True) -> numpy.ndarray:
+ """Get the height field 2D data.
+
+ :param bool copy:
+ True (default) to get a copy,
+ False to get internal representation (do not modify!).
+ """
+ return numpy.array(self.__data, copy=copy)
+
+
+class HeightMapData(_HeightMap, ColormapMixIn):
+ """Description of a 2D height field associated to a colormapped dataset.
+
+ :param parent: The View widget this item belongs to.
+ """
+
+ def __init__(self, parent=None):
+ _HeightMap.__init__(self, parent=parent)
+ ColormapMixIn.__init__(self)
+
+ self.__data = numpy.zeros((0, 0), dtype=numpy.float32)
+
+ def _updated(self, event=None):
+ if event == ItemChangedType.DATA:
+ self.__updateScene()
+ super()._updated(event=event)
+
+ def __updateScene(self):
+ """Update display primitive to use"""
+ self._getScenePrimitive().children = [] # Remove previous primitives
+ ColormapMixIn._setSceneColormap(self, None)
+
+ if not self.isVisible():
+ return # Update when visible
+
+ data = self.getColormappedData(copy=False)
+ heightData = self.getData(copy=False)
+
+ if data.size == 0 or heightData.size == 0:
+ return # Nothing to display
+
+ # Display as a set of points
+ height, width = heightData.shape
+ # Generates coordinates
+ y, x = numpy.mgrid[0:height, 0:width]
+
+ if data.shape != heightData.shape: # data and height size miss-match
+ # Colormapped data is interpolated (nearest-neighbour) to match the height field
+ data = data[numpy.floor(y * data.shape[0] / height).astype(numpy.int32),
+ numpy.floor(x * data.shape[1] / height).astype(numpy.int32)]
+
+ x = numpy.ravel(x)
+ y = numpy.ravel(y)
+
+ primitive = primitives.Points(
+ x=x,
+ y=y,
+ z=numpy.ravel(heightData),
+ value=numpy.ravel(data),
+ size=1)
+ primitive.marker = 's'
+ ColormapMixIn._setSceneColormap(self, primitive.colormap)
+ self._getScenePrimitive().children = [primitive]
+
+ def setColormappedData(self, data, copy: bool=True):
+ """Set the 2D data used to compute colors.
+
+ :param data: 2D array of data
+ :param copy: True (default) to copy the data,
+ False to use as is (do not modify!).
+ """
+ data = numpy.array(data, copy=copy)
+ assert data.ndim == 2
+
+ self.__data = data
+ self._updated(ItemChangedType.DATA)
+
+ def getColormappedData(self, copy: bool=True) -> numpy.ndarray:
+ """Returns the 2D data used to compute colors.
+
+ :param copy:
+ True (default) to get a copy,
+ False to get internal representation (do not modify!).
+ """
+ return numpy.array(self.__data, copy=copy)
+
+
+class HeightMapRGBA(_HeightMap):
+ """Description of a 2D height field associated to a RGB(A) image.
+
+ :param parent: The View widget this item belongs to.
+ """
+
+ def __init__(self, parent=None):
+ _HeightMap.__init__(self, parent=parent)
+
+ self.__rgba = numpy.zeros((0, 0, 3), dtype=numpy.float32)
+
+ def _updated(self, event=None):
+ if event == ItemChangedType.DATA:
+ self.__updateScene()
+ super()._updated(event=event)
+
+ def __updateScene(self):
+ """Update display primitive to use"""
+ self._getScenePrimitive().children = [] # Remove previous primitives
+
+ if not self.isVisible():
+ return # Update when visible
+
+ rgba = self.getColorData(copy=False)
+ heightData = self.getData(copy=False)
+ if rgba.size == 0 or heightData.size == 0:
+ return # Nothing to display
+
+ # Display as a set of points
+ height, width = heightData.shape
+ # Generates coordinates
+ y, x = numpy.mgrid[0:height, 0:width]
+
+ if rgba.shape[:2] != heightData.shape: # image and height size miss-match
+ # RGBA data is interpolated (nearest-neighbour) to match the height field
+ rgba = rgba[numpy.floor(y * rgba.shape[0] / height).astype(numpy.int32),
+ numpy.floor(x * rgba.shape[1] / height).astype(numpy.int32)]
+
+ x = numpy.ravel(x)
+ y = numpy.ravel(y)
+
+ primitive = primitives.ColorPoints(
+ x=x,
+ y=y,
+ z=numpy.ravel(heightData),
+ color=rgba.reshape(-1, rgba.shape[-1]),
+ size=1)
+ primitive.marker = 's'
+ self._getScenePrimitive().children = [primitive]
+
+ def setColorData(self, data, copy: bool=True):
+ """Set the RGB(A) image to use.
+
+ Supported array format: float32 in [0, 1], uint8.
+
+ :param data:
+ The RGBA image data as an array of shape (H, W, Channels)
+ :param copy: True (default) to copy the data,
+ False to use as is (do not modify!).
+ """
+ data = numpy.array(data, copy=copy)
+ assert data.ndim == 3
+ assert data.shape[-1] in (3, 4)
+ # TODO check type
+
+ self.__rgba = data
+ self._updated(ItemChangedType.DATA)
+
+ def getColorData(self, copy: bool=True) -> numpy.ndarray:
+ """Get the RGB(A) image data.
+
+ :param copy: True (default) to get a copy,
+ False to get internal representation (do not modify!).
+ """
+ return numpy.array(self.__rgba, copy=copy)
diff --git a/src/silx/gui/plot3d/items/mesh.py b/src/silx/gui/plot3d/items/mesh.py
new file mode 100644
index 0000000..4e19939
--- /dev/null
+++ b/src/silx/gui/plot3d/items/mesh.py
@@ -0,0 +1,792 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides regular mesh item class.
+"""
+
+from __future__ import absolute_import
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "17/07/2018"
+
+
+import logging
+import numpy
+
+from ... import _glutils as glu
+from ..scene import primitives, utils, function
+from ..scene.transform import Rotate
+from .core import DataItem3D, ItemChangedType
+from .mixins import ColormapMixIn
+from ._pick import PickingResult
+
+
+_logger = logging.getLogger(__name__)
+
+
+class _MeshBase(DataItem3D):
+ """Base class for :class:`Mesh' and :class:`ColormapMesh`.
+
+ :param parent: The View widget this item belongs to.
+ """
+
+ def __init__(self, parent=None):
+ DataItem3D.__init__(self, parent=parent)
+ self._mesh = None
+
+ def _setMesh(self, mesh):
+ """Set mesh primitive
+
+ :param Union[None,Geometry] mesh: The scene primitive
+ """
+ self._getScenePrimitive().children = [] # Remove any previous mesh
+
+ self._mesh = mesh
+ if self._mesh is not None:
+ self._getScenePrimitive().children.append(self._mesh)
+
+ self._updated(ItemChangedType.DATA)
+
+ def _getMesh(self):
+ """Returns the underlying Mesh scene primitive"""
+ return self._mesh
+
+ def getPositionData(self, copy=True):
+ """Get the mesh vertex positions.
+
+ :param bool copy:
+ True (default) to get a copy,
+ False to get internal representation (do not modify!).
+ :return: The (x, y, z) positions as a (N, 3) array
+ :rtype: numpy.ndarray
+ """
+ if self._getMesh() is None:
+ return numpy.empty((0, 3), dtype=numpy.float32)
+ else:
+ return self._getMesh().getAttribute('position', copy=copy)
+
+ def getNormalData(self, copy=True):
+ """Get the mesh vertex normals.
+
+ :param bool copy:
+ True (default) to get a copy,
+ False to get internal representation (do not modify!).
+ :return: The normals as a (N, 3) array, a single normal or None
+ :rtype: Union[numpy.ndarray,None]
+ """
+ if self._getMesh() is None:
+ return None
+ else:
+ return self._getMesh().getAttribute('normal', copy=copy)
+
+ def getIndices(self, copy=True):
+ """Get the vertex indices.
+
+ :param bool copy:
+ True (default) to get a copy,
+ False to get internal representation (do not modify!).
+ :return: The vertex indices as an array or None.
+ :rtype: Union[numpy.ndarray,None]
+ """
+ if self._getMesh() is None:
+ return None
+ else:
+ return self._getMesh().getIndices(copy=copy)
+
+ def getDrawMode(self):
+ """Get mesh rendering mode.
+
+ :return: The drawing mode of this primitive
+ :rtype: str
+ """
+ return self._getMesh().drawMode
+
+ def _pickFull(self, context):
+ """Perform precise picking in this item at given widget position.
+
+ :param PickContext context: Current picking context
+ :return: Object holding the results or None
+ :rtype: Union[None,PickingResult]
+ """
+ rayObject = context.getPickingSegment(frame=self._getScenePrimitive())
+ if rayObject is None: # No picking outside viewport
+ return None
+ rayObject = rayObject[:, :3]
+
+ positions = self.getPositionData(copy=False)
+ if positions.size == 0:
+ return None
+
+ mode = self.getDrawMode()
+
+ vertexIndices = self.getIndices(copy=False)
+ if vertexIndices is not None: # Expand indices
+ positions = utils.unindexArrays(mode, vertexIndices, positions)[0]
+ triangles = positions.reshape(-1, 3, 3)
+ else:
+ if mode == 'triangles':
+ triangles = positions.reshape(-1, 3, 3)
+
+ elif mode == 'triangle_strip':
+ # Expand strip
+ triangles = numpy.empty((len(positions) - 2, 3, 3),
+ dtype=positions.dtype)
+ triangles[:, 0] = positions[:-2]
+ triangles[:, 1] = positions[1:-1]
+ triangles[:, 2] = positions[2:]
+
+ elif mode == 'fan':
+ # Expand fan
+ triangles = numpy.empty((len(positions) - 2, 3, 3),
+ dtype=positions.dtype)
+ triangles[:, 0] = positions[0]
+ triangles[:, 1] = positions[1:-1]
+ triangles[:, 2] = positions[2:]
+
+ else:
+ _logger.warning("Unsupported draw mode: %s" % mode)
+ return None
+
+ trianglesIndices, t, barycentric = glu.segmentTrianglesIntersection(
+ rayObject, triangles)
+
+ if len(trianglesIndices) == 0:
+ return None
+
+ points = t.reshape(-1, 1) * (rayObject[1] - rayObject[0]) + rayObject[0]
+
+ # Get vertex index from triangle index and closest point in triangle
+ closest = numpy.argmax(barycentric, axis=1)
+
+ if mode == 'triangles':
+ indices = trianglesIndices * 3 + closest
+
+ elif mode == 'triangle_strip':
+ indices = trianglesIndices + closest
+
+ elif mode == 'fan':
+ indices = trianglesIndices + closest # For corners 1 and 2
+ indices[closest == 0] = 0 # For first corner (common)
+
+ if vertexIndices is not None:
+ # Convert from indices in expanded triangles to input vertices
+ indices = vertexIndices[indices]
+
+ return PickingResult(self,
+ positions=points,
+ indices=indices,
+ fetchdata=self.getPositionData)
+
+
+class Mesh(_MeshBase):
+ """Description of mesh.
+
+ :param parent: The View widget this item belongs to.
+ """
+
+ def __init__(self, parent=None):
+ _MeshBase.__init__(self, parent=parent)
+
+ def setData(self,
+ position,
+ color,
+ normal=None,
+ mode='triangles',
+ indices=None,
+ copy=True):
+ """Set mesh geometry data.
+
+ Supported drawing modes are: 'triangles', 'triangle_strip', 'fan'
+
+ :param numpy.ndarray position:
+ Position (x, y, z) of each vertex as a (N, 3) array
+ :param numpy.ndarray color: Colors for each point or a single color
+ :param Union[numpy.ndarray,None] normal: Normals for each point or None (default)
+ :param str mode: The drawing mode.
+ :param Union[List[int],None] indices:
+ Array of vertex indices or None to use arrays directly.
+ :param bool copy: True (default) to copy the data,
+ False to use as is (do not modify!).
+ """
+ assert mode in ('triangles', 'triangle_strip', 'fan')
+ if position is None or len(position) == 0:
+ mesh = None
+ else:
+ mesh = primitives.Mesh3D(
+ position, color, normal, mode=mode, indices=indices, copy=copy)
+ self._setMesh(mesh)
+
+ def getData(self, copy=True):
+ """Get the mesh geometry.
+
+ :param bool copy:
+ True (default) to get a copy,
+ False to get internal representation (do not modify!).
+ :return: The positions, colors, normals and mode
+ :rtype: tuple of numpy.ndarray
+ """
+ return (self.getPositionData(copy=copy),
+ self.getColorData(copy=copy),
+ self.getNormalData(copy=copy),
+ self.getDrawMode())
+
+ def getColorData(self, copy=True):
+ """Get the mesh vertex colors.
+
+ :param bool copy:
+ True (default) to get a copy,
+ False to get internal representation (do not modify!).
+ :return: The RGBA colors as a (N, 4) array or a single color
+ :rtype: numpy.ndarray
+ """
+ if self._getMesh() is None:
+ return numpy.empty((0, 4), dtype=numpy.float32)
+ else:
+ return self._getMesh().getAttribute('color', copy=copy)
+
+
+class ColormapMesh(_MeshBase, ColormapMixIn):
+ """Description of mesh which color is defined by scalar and a colormap.
+
+ :param parent: The View widget this item belongs to.
+ """
+
+ def __init__(self, parent=None):
+ _MeshBase.__init__(self, parent=parent)
+ ColormapMixIn.__init__(self, function.Colormap())
+
+ def setData(self,
+ position,
+ value,
+ normal=None,
+ mode='triangles',
+ indices=None,
+ copy=True):
+ """Set mesh geometry data.
+
+ Supported drawing modes are: 'triangles', 'triangle_strip', 'fan'
+
+ :param numpy.ndarray position:
+ Position (x, y, z) of each vertex as a (N, 3) array
+ :param numpy.ndarray value: Data value for each vertex.
+ :param Union[numpy.ndarray,None] normal: Normals for each point or None (default)
+ :param str mode: The drawing mode.
+ :param Union[List[int],None] indices:
+ Array of vertex indices or None to use arrays directly.
+ :param bool copy: True (default) to copy the data,
+ False to use as is (do not modify!).
+ """
+ assert mode in ('triangles', 'triangle_strip', 'fan')
+ if position is None or len(position) == 0:
+ mesh = None
+ else:
+ mesh = primitives.ColormapMesh3D(
+ position=position,
+ value=numpy.array(value, copy=False).reshape(-1, 1), # Make it a 2D array
+ colormap=self._getSceneColormap(),
+ normal=normal,
+ mode=mode,
+ indices=indices,
+ copy=copy)
+ self._setMesh(mesh)
+
+ self._setColormappedData(self.getValueData(copy=False), copy=False)
+
+ def getData(self, copy=True):
+ """Get the mesh geometry.
+
+ :param bool copy:
+ True (default) to get a copy,
+ False to get internal representation (do not modify!).
+ :return: The positions, values, normals and mode
+ :rtype: tuple of numpy.ndarray
+ """
+ return (self.getPositionData(copy=copy),
+ self.getValueData(copy=copy),
+ self.getNormalData(copy=copy),
+ self.getDrawMode())
+
+ def getValueData(self, copy=True):
+ """Get the mesh vertex values.
+
+ :param bool copy:
+ True (default) to get a copy,
+ False to get internal representation (do not modify!).
+ :return: Array of data values
+ :rtype: numpy.ndarray
+ """
+ if self._getMesh() is None:
+ return numpy.empty((0,), dtype=numpy.float32)
+ else:
+ return self._getMesh().getAttribute('value', copy=copy)
+
+
+class _CylindricalVolume(DataItem3D):
+ """Class that represents a volume with a rotational symmetry along z
+
+ :param parent: The View widget this item belongs to.
+ """
+
+ def __init__(self, parent=None):
+ DataItem3D.__init__(self, parent=parent)
+ self._mesh = None
+ self._nbFaces = 0
+
+ def getPosition(self, copy=True):
+ """Get primitive positions.
+
+ :param bool copy:
+ True (default) to get a copy,
+ False to get internal representation (do not modify!).
+ :return: Position of the primitives as a (N, 3) array.
+ :rtype: numpy.ndarray
+ """
+ raise NotImplementedError("Must be implemented in subclass")
+
+ def _setData(self, position, radius, height, angles, color, flatFaces,
+ rotation):
+ """Set volume geometry data.
+
+ :param numpy.ndarray position:
+ Center position (x, y, z) of each volume as (N, 3) array.
+ :param float radius: External radius ot the volume.
+ :param float height: Height of the volume(s).
+ :param numpy.ndarray angles: Angles of the edges.
+ :param numpy.array color: RGB color of the volume(s).
+ :param bool flatFaces:
+ If the volume as flat faces or not. Used for normals calculation.
+ """
+
+ self._getScenePrimitive().children = [] # Remove any previous mesh
+
+ if position is None or len(position) == 0:
+ self._mesh = None
+ self._nbFaces = 0
+ else:
+ self._nbFaces = len(angles) - 1
+
+ volume = numpy.empty(shape=(len(angles) - 1, 12, 3),
+ dtype=numpy.float32)
+ normal = numpy.empty(shape=(len(angles) - 1, 12, 3),
+ dtype=numpy.float32)
+
+ for i in range(0, len(angles) - 1):
+ # c6
+ # /\
+ # / \
+ # / \
+ # c4|------|c5
+ # | \ |
+ # | \ |
+ # | \ |
+ # | \ |
+ # c2|------|c3
+ # \ /
+ # \ /
+ # \/
+ # c1
+ c1 = numpy.array([0, 0, -height/2])
+ c1 = rotation.transformPoint(c1)
+ c2 = numpy.array([radius * numpy.cos(angles[i]),
+ radius * numpy.sin(angles[i]),
+ -height/2])
+ c2 = rotation.transformPoint(c2)
+ c3 = numpy.array([radius * numpy.cos(angles[i+1]),
+ radius * numpy.sin(angles[i+1]),
+ -height/2])
+ c3 = rotation.transformPoint(c3)
+ c4 = numpy.array([radius * numpy.cos(angles[i]),
+ radius * numpy.sin(angles[i]),
+ height/2])
+ c4 = rotation.transformPoint(c4)
+ c5 = numpy.array([radius * numpy.cos(angles[i+1]),
+ radius * numpy.sin(angles[i+1]),
+ height/2])
+ c5 = rotation.transformPoint(c5)
+ c6 = numpy.array([0, 0, height/2])
+ c6 = rotation.transformPoint(c6)
+
+ volume[i] = numpy.array([c1, c3, c2,
+ c2, c3, c4,
+ c3, c5, c4,
+ c4, c5, c6])
+ if flatFaces:
+ normal[i] = numpy.array([numpy.cross(c3-c1, c2-c1), # c1
+ numpy.cross(c2-c3, c1-c3), # c3
+ numpy.cross(c1-c2, c3-c2), # c2
+ numpy.cross(c3-c2, c4-c2), # c2
+ numpy.cross(c4-c3, c2-c3), # c3
+ numpy.cross(c2-c4, c3-c4), # c4
+ numpy.cross(c5-c3, c4-c3), # c3
+ numpy.cross(c4-c5, c3-c5), # c5
+ numpy.cross(c3-c4, c5-c4), # c4
+ numpy.cross(c5-c4, c6-c4), # c4
+ numpy.cross(c6-c5, c5-c5), # c5
+ numpy.cross(c4-c6, c5-c6)]) # c6
+ else:
+ normal[i] = numpy.array([numpy.cross(c3-c1, c2-c1),
+ numpy.cross(c2-c3, c1-c3),
+ numpy.cross(c1-c2, c3-c2),
+ c2-c1, c3-c1, c4-c6, # c2 c2 c4
+ c3-c1, c5-c6, c4-c6, # c3 c5 c4
+ numpy.cross(c5-c4, c6-c4),
+ numpy.cross(c6-c5, c5-c5),
+ numpy.cross(c4-c6, c5-c6)])
+
+ # Multiplication according to the number of positions
+ vertices = numpy.tile(volume.reshape(-1, 3), (len(position), 1))\
+ .reshape((-1, 3))
+ normals = numpy.tile(normal.reshape(-1, 3), (len(position), 1))\
+ .reshape((-1, 3))
+
+ # Translations
+ numpy.add(vertices, numpy.tile(position, (1, (len(angles)-1) * 12))
+ .reshape((-1, 3)), out=vertices)
+
+ # Colors
+ if numpy.ndim(color) == 2:
+ color = numpy.tile(color, (1, 12 * (len(angles) - 1)))\
+ .reshape(-1, 3)
+
+ self._mesh = primitives.Mesh3D(
+ vertices, color, normals, mode='triangles', copy=False)
+ self._getScenePrimitive().children.append(self._mesh)
+
+ self._updated(ItemChangedType.DATA)
+
+ def _pickFull(self, context):
+ """Perform precise picking in this item at given widget position.
+
+ :param PickContext context: Current picking context
+ :return: Object holding the results or None
+ :rtype: Union[None,PickingResult]
+ """
+ if self._mesh is None or self._nbFaces == 0:
+ return None
+
+ rayObject = context.getPickingSegment(frame=self._getScenePrimitive())
+ if rayObject is None: # No picking outside viewport
+ return None
+ rayObject = rayObject[:, :3]
+
+ positions = self._mesh.getAttribute('position', copy=False)
+ triangles = positions.reshape(-1, 3, 3) # 'triangle' draw mode
+
+ trianglesIndices, t = glu.segmentTrianglesIntersection(
+ rayObject, triangles)[:2]
+
+ if len(trianglesIndices) == 0:
+ return None
+
+ # Get object index from triangle index
+ indices = trianglesIndices // (4 * self._nbFaces)
+
+ # Select closest intersection point for each primitive
+ indices, firstIndices = numpy.unique(indices, return_index=True)
+ t = t[firstIndices]
+
+ # Resort along t as result of numpy.unique is not sorted by t
+ sortedIndices = numpy.argsort(t)
+ t = t[sortedIndices]
+ indices = indices[sortedIndices]
+
+ points = t.reshape(-1, 1) * (rayObject[1] - rayObject[0]) + rayObject[0]
+
+ return PickingResult(self,
+ positions=points,
+ indices=indices,
+ fetchdata=self.getPosition)
+
+
+class Box(_CylindricalVolume):
+ """Description of a box.
+
+ Can be used to draw one box or many similar boxes.
+
+ :param parent: The View widget this item belongs to.
+ """
+
+ def __init__(self, parent=None):
+ super(Box, self).__init__(parent)
+ self.position = None
+ self.size = None
+ self.color = None
+ self.rotation = None
+ self.setData()
+
+ def setData(self, size=(1, 1, 1), color=(1, 1, 1),
+ position=(0, 0, 0), rotation=(0, (0, 0, 0))):
+ """
+ Set Box geometry data.
+
+ :param numpy.array size: Size (dx, dy, dz) of the box(es).
+ :param numpy.array color: RGB color of the box(es).
+ :param numpy.ndarray position:
+ Center position (x, y, z) of each box as a (N, 3) array.
+ :param tuple(float, array) rotation:
+ Angle (in degrees) and axis of rotation.
+ If (0, (0, 0, 0)) (default), the hexagonal faces are on
+ xy plane and a side face is aligned with x axis.
+ """
+ self.position = numpy.atleast_2d(numpy.array(position, copy=True))
+ self.size = numpy.array(size, copy=True)
+ self.color = numpy.array(color, copy=True)
+ self.rotation = Rotate(rotation[0],
+ rotation[1][0], rotation[1][1], rotation[1][2])
+
+ assert (numpy.ndim(self.color) == 1 or
+ len(self.color) == len(self.position))
+
+ diagonal = numpy.sqrt(self.size[0]**2 + self.size[1]**2)
+ alpha = 2 * numpy.arcsin(self.size[1] / diagonal)
+ beta = 2 * numpy.arcsin(self.size[0] / diagonal)
+ angles = numpy.array([0,
+ alpha,
+ alpha + beta,
+ alpha + beta + alpha,
+ 2 * numpy.pi])
+ numpy.subtract(angles, 0.5 * alpha, out=angles)
+ self._setData(self.position,
+ numpy.sqrt(self.size[0]**2 + self.size[1]**2)/2,
+ self.size[2],
+ angles,
+ self.color,
+ True,
+ self.rotation)
+
+ def getPosition(self, copy=True):
+ """Get box(es) position(s).
+
+ :param bool copy:
+ True (default) to get a copy,
+ False to get internal representation (do not modify!).
+ :return: Position of the box(es) as a (N, 3) array.
+ :rtype: numpy.ndarray
+ """
+ return numpy.array(self.position, copy=copy)
+
+ def getSize(self):
+ """Get box(es) size.
+
+ :return: Size (dx, dy, dz) of the box(es).
+ :rtype: numpy.ndarray
+ """
+ return numpy.array(self.size, copy=True)
+
+ def getColor(self, copy=True):
+ """Get box(es) color.
+
+ :param bool copy:
+ True (default) to get a copy,
+ False to get internal representation (do not modify!).
+ :return: RGB color of the box(es).
+ :rtype: numpy.ndarray
+ """
+ return numpy.array(self.color, copy=copy)
+
+
+class Cylinder(_CylindricalVolume):
+ """Description of a cylinder.
+
+ Can be used to draw one cylinder or many similar cylinders.
+
+ :param parent: The View widget this item belongs to.
+ """
+
+ def __init__(self, parent=None):
+ super(Cylinder, self).__init__(parent)
+ self.position = None
+ self.radius = None
+ self.height = None
+ self.color = None
+ self.nbFaces = 0
+ self.rotation = None
+ self.setData()
+
+ def setData(self, radius=1, height=1, color=(1, 1, 1), nbFaces=20,
+ position=(0, 0, 0), rotation=(0, (0, 0, 0))):
+ """
+ Set the cylinder geometry data
+
+ :param float radius: Radius of the cylinder(s).
+ :param float height: Height of the cylinder(s).
+ :param numpy.array color: RGB color of the cylinder(s).
+ :param int nbFaces:
+ Number of faces for cylinder approximation (default 20).
+ :param numpy.ndarray position:
+ Center position (x, y, z) of each cylinder as a (N, 3) array.
+ :param tuple(float, array) rotation:
+ Angle (in degrees) and axis of rotation.
+ If (0, (0, 0, 0)) (default), the hexagonal faces are on
+ xy plane and a side face is aligned with x axis.
+ """
+ self.position = numpy.atleast_2d(numpy.array(position, copy=True))
+ self.radius = float(radius)
+ self.height = float(height)
+ self.color = numpy.array(color, copy=True)
+ self.nbFaces = int(nbFaces)
+ self.rotation = Rotate(rotation[0],
+ rotation[1][0], rotation[1][1], rotation[1][2])
+
+ assert (numpy.ndim(self.color) == 1 or
+ len(self.color) == len(self.position))
+
+ angles = numpy.linspace(0, 2*numpy.pi, self.nbFaces + 1)
+ self._setData(self.position,
+ self.radius,
+ self.height,
+ angles,
+ self.color,
+ False,
+ self.rotation)
+
+ def getPosition(self, copy=True):
+ """Get cylinder(s) position(s).
+
+ :param bool copy:
+ True (default) to get a copy,
+ False to get internal representation (do not modify!).
+ :return: Position(s) of the cylinder(s) as a (N, 3) array.
+ :rtype: numpy.ndarray
+ """
+ return numpy.array(self.position, copy=copy)
+
+ def getRadius(self):
+ """Get cylinder(s) radius.
+
+ :return: Radius of the cylinder(s).
+ :rtype: float
+ """
+ return self.radius
+
+ def getHeight(self):
+ """Get cylinder(s) height.
+
+ :return: Height of the cylinder(s).
+ :rtype: float
+ """
+ return self.height
+
+ def getColor(self, copy=True):
+ """Get cylinder(s) color.
+
+ :param bool copy:
+ True (default) to get a copy,
+ False to get internal representation (do not modify!).
+ :return: RGB color of the cylinder(s).
+ :rtype: numpy.ndarray
+ """
+ return numpy.array(self.color, copy=copy)
+
+
+class Hexagon(_CylindricalVolume):
+ """Description of a uniform hexagonal prism.
+
+ Can be used to draw one hexagonal prim or many similar hexagonal
+ prisms.
+
+ :param parent: The View widget this item belongs to.
+ """
+
+ def __init__(self, parent=None):
+ super(Hexagon, self).__init__(parent)
+ self.position = None
+ self.radius = 0
+ self.height = 0
+ self.color = None
+ self.rotation = None
+ self.setData()
+
+ def setData(self, radius=1, height=1, color=(1, 1, 1),
+ position=(0, 0, 0), rotation=(0, (0, 0, 0))):
+ """
+ Set the uniform hexagonal prism geometry data
+
+ :param float radius: External radius of the hexagonal prism
+ :param float height: Height of the hexagonal prism
+ :param numpy.array color: RGB color of the prism(s)
+ :param numpy.ndarray position:
+ Center position (x, y, z) of each prism as a (N, 3) array
+ :param tuple(float, array) rotation:
+ Angle (in degrees) and axis of rotation.
+ If (0, (0, 0, 0)) (default), the hexagonal faces are on
+ xy plane and a side face is aligned with x axis.
+ """
+ self.position = numpy.atleast_2d(numpy.array(position, copy=True))
+ self.radius = float(radius)
+ self.height = float(height)
+ self.color = numpy.array(color, copy=True)
+ self.rotation = Rotate(rotation[0], rotation[1][0], rotation[1][1],
+ rotation[1][2])
+
+ assert (numpy.ndim(self.color) == 1 or
+ len(self.color) == len(self.position))
+
+ angles = numpy.linspace(0, 2*numpy.pi, 7)
+ self._setData(self.position,
+ self.radius,
+ self.height,
+ angles,
+ self.color,
+ True,
+ self.rotation)
+
+ def getPosition(self, copy=True):
+ """Get hexagonal prim(s) position(s).
+
+ :param bool copy:
+ True (default) to get a copy,
+ False to get internal representation (do not modify!).
+ :return: Position(s) of hexagonal prism(s) as a (N, 3) array.
+ :rtype: numpy.ndarray
+ """
+ return numpy.array(self.position, copy=copy)
+
+ def getRadius(self):
+ """Get hexagonal prism(s) radius.
+
+ :return: Radius of hexagon(s).
+ :rtype: float
+ """
+ return self.radius
+
+ def getHeight(self):
+ """Get hexagonal prism(s) height.
+
+ :return: Height of hexagonal prism(s).
+ :rtype: float
+ """
+ return self.height
+
+ def getColor(self, copy=True):
+ """Get hexagonal prism(s) color.
+
+ :param bool copy:
+ True (default) to get a copy,
+ False to get internal representation (do not modify!).
+ :return: RGB color of the hexagonal prism(s).
+ :rtype: numpy.ndarray
+ """
+ return numpy.array(self.color, copy=copy)
diff --git a/src/silx/gui/plot3d/items/mixins.py b/src/silx/gui/plot3d/items/mixins.py
new file mode 100644
index 0000000..f512365
--- /dev/null
+++ b/src/silx/gui/plot3d/items/mixins.py
@@ -0,0 +1,288 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides mix-in classes for :class:`Item3D`.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "24/04/2018"
+
+
+import collections
+import numpy
+
+from silx.math.combo import min_max
+
+from ...plot.items.core import ItemMixInBase
+from ...plot.items.core import ColormapMixIn as _ColormapMixIn
+from ...plot.items.core import SymbolMixIn as _SymbolMixIn
+from ...plot.items.core import ComplexMixIn as _ComplexMixIn
+from ...colors import rgba
+
+from ..scene import primitives
+from .core import Item3DChangedType, ItemChangedType
+
+
+class InterpolationMixIn(ItemMixInBase):
+ """Mix-in class for image interpolation mode
+
+ :param str mode: 'linear' (default) or 'nearest'
+ :param primitive:
+ scene object for which to sync interpolation mode.
+ This object MUST have an interpolation property that is updated.
+ """
+
+ NEAREST_INTERPOLATION = 'nearest'
+ """Nearest interpolation mode (see :meth:`setInterpolation`)"""
+
+ LINEAR_INTERPOLATION = 'linear'
+ """Linear interpolation mode (see :meth:`setInterpolation`)"""
+
+ INTERPOLATION_MODES = NEAREST_INTERPOLATION, LINEAR_INTERPOLATION
+ """Supported interpolation modes for :meth:`setInterpolation`"""
+
+ def __init__(self, mode=NEAREST_INTERPOLATION, primitive=None):
+ self.__primitive = primitive
+ self._syncPrimitiveInterpolation()
+
+ self.__interpolationMode = None
+ self.setInterpolation(mode)
+
+ def _setPrimitive(self, primitive):
+
+ """Set the scene object for which to sync interpolation"""
+ self.__primitive = primitive
+ self._syncPrimitiveInterpolation()
+
+ def _syncPrimitiveInterpolation(self):
+ """Synchronize scene object's interpolation"""
+ if self.__primitive is not None:
+ self.__primitive.interpolation = self.getInterpolation()
+
+ def setInterpolation(self, mode):
+ """Set image interpolation mode
+
+ :param str mode: 'nearest' or 'linear'
+ """
+ mode = str(mode)
+ assert mode in self.INTERPOLATION_MODES
+ if mode != self.__interpolationMode:
+ self.__interpolationMode = mode
+ self._syncPrimitiveInterpolation()
+ self._updated(Item3DChangedType.INTERPOLATION)
+
+ def getInterpolation(self):
+ """Returns the interpolation mode set by :meth:`setInterpolation`
+
+ :rtype: str
+ """
+ return self.__interpolationMode
+
+
+class ColormapMixIn(_ColormapMixIn):
+ """Mix-in class for Item3D object with a colormap
+
+ :param sceneColormap:
+ The plot3d scene colormap to sync with Colormap object.
+ """
+
+ def __init__(self, sceneColormap=None):
+ super(ColormapMixIn, self).__init__()
+
+ self.__sceneColormap = sceneColormap
+ self._syncSceneColormap()
+
+ def _colormapChanged(self):
+ """Handle colormap updates"""
+ self._syncSceneColormap()
+ super(ColormapMixIn, self)._colormapChanged()
+
+ def _setSceneColormap(self, sceneColormap):
+ """Set the scene colormap to sync with Colormap object.
+
+ :param sceneColormap:
+ The plot3d scene colormap to sync with Colormap object.
+ """
+ self.__sceneColormap = sceneColormap
+ self._syncSceneColormap()
+
+ def _getSceneColormap(self):
+ """Returns scene colormap that is sync"""
+ return self.__sceneColormap
+
+ def _syncSceneColormap(self):
+ """Synchronizes scene's colormap with Colormap object"""
+ if self.__sceneColormap is not None:
+ colormap = self.getColormap()
+
+ self.__sceneColormap.colormap = colormap.getNColors()
+ self.__sceneColormap.norm = colormap.getNormalization()
+ self.__sceneColormap.gamma = colormap.getGammaNormalizationParameter()
+ self.__sceneColormap.range_ = colormap.getColormapRange(self)
+ self.__sceneColormap.nancolor = rgba(colormap.getNaNColor())
+
+
+class ComplexMixIn(_ComplexMixIn):
+ __doc__ = _ComplexMixIn.__doc__ # Reuse docstring
+
+ _SUPPORTED_COMPLEX_MODES = (
+ _ComplexMixIn.ComplexMode.REAL,
+ _ComplexMixIn.ComplexMode.IMAGINARY,
+ _ComplexMixIn.ComplexMode.ABSOLUTE,
+ _ComplexMixIn.ComplexMode.PHASE,
+ _ComplexMixIn.ComplexMode.SQUARE_AMPLITUDE)
+ """Overrides supported ComplexMode"""
+
+
+class SymbolMixIn(_SymbolMixIn):
+ """Mix-in class for symbol and symbolSize properties for Item3D"""
+
+ _SUPPORTED_SYMBOLS = collections.OrderedDict((
+ ('o', 'Circle'),
+ ('d', 'Diamond'),
+ ('s', 'Square'),
+ ('+', 'Plus'),
+ ('x', 'Cross'),
+ ('*', 'Star'),
+ ('|', 'Vertical Line'),
+ ('_', 'Horizontal Line'),
+ ('.', 'Point'),
+ (',', 'Pixel')))
+
+ def _getSceneSymbol(self):
+ """Returns a symbol name and size suitable for scene primitives.
+
+ :return: (symbol, size)
+ """
+ symbol = self.getSymbol()
+ size = self.getSymbolSize()
+ if symbol == ',': # pixel
+ return 's', 1.
+ elif symbol == '.': # point
+ # Size as in plot OpenGL backend, mimic matplotlib
+ return 'o', numpy.ceil(0.5 * size) + 1.
+ else:
+ return symbol, size
+
+
+class PlaneMixIn(ItemMixInBase):
+ """Mix-in class for plane items (based on PlaneInGroup primitive)"""
+
+ def __init__(self, plane):
+ assert isinstance(plane, primitives.PlaneInGroup)
+ self.__plane = plane
+ self.__plane.alpha = 1.
+ self.__plane.addListener(self._planeChanged)
+ self.__plane.plane.addListener(self._planePositionChanged)
+
+ def _getPlane(self):
+ """Returns plane primitive
+
+ :rtype: primitives.PlaneInGroup
+ """
+ return self.__plane
+
+ def _planeChanged(self, source, *args, **kwargs):
+ """Handle events from the plane primitive"""
+ # Sync visibility
+ if source.visible != self.isVisible():
+ self.setVisible(source.visible)
+
+ def _planePositionChanged(self, source, *args, **kwargs):
+ """Handle update of cut plane position and normal"""
+ if self.__plane.visible: # TODO send even if hidden? or send also when showing if moved while hidden
+ self._updated(ItemChangedType.POSITION)
+
+ # Plane position
+
+ def moveToCenter(self):
+ """Move cut plane to center of data set"""
+ self.__plane.moveToCenter()
+
+ def isValid(self):
+ """Returns whether the cut plane is defined or not (bool)"""
+ return self.__plane.isValid
+
+ def getNormal(self):
+ """Returns the normal of the plane (as a unit vector)
+
+ :return: Normal (nx, ny, nz), vector is 0 if no plane is defined
+ :rtype: numpy.ndarray
+ """
+ return self.__plane.plane.normal
+
+ def setNormal(self, normal):
+ """Set the normal of the plane
+
+ :param normal: 3-tuple of float: nx, ny, nz
+ """
+ self.__plane.plane.normal = normal
+
+ def getPoint(self):
+ """Returns a point on the plane
+
+ :return: (x, y, z)
+ :rtype: numpy.ndarray
+ """
+ return self.__plane.plane.point
+
+ def setPoint(self, point):
+ """Set a point contained in the plane.
+
+ Warning: The plane might not intersect the bounding box of the data.
+
+ :param point: (x, y, z) position
+ :type point: 3-tuple of float
+ """
+ self.__plane.plane.point = point # TODO rework according to PR #1303
+
+ def getParameters(self):
+ """Returns the plane equation parameters: a*x + b*y + c*z + d = 0
+
+ :return: Plane equation parameters: (a, b, c, d)
+ :rtype: numpy.ndarray
+ """
+ return self.__plane.plane.parameters
+
+ def setParameters(self, parameters):
+ """Set the plane equation parameters: a*x + b*y + c*z + d = 0
+
+ Warning: The plane might not intersect the bounding box of the data.
+ The given parameters will be normalized.
+
+ :param parameters: (a, b, c, d) equation parameters
+ """
+ self.__plane.plane.parameters = parameters
+
+ # Border stroke
+
+ def _setForegroundColor(self, color):
+ """Set the color of the plane border.
+
+ :param color: RGBA color as 4 floats in [0, 1]
+ """
+ self.__plane.color = rgba(color)
+ if hasattr(super(PlaneMixIn, self), '_setForegroundColor'):
+ super(PlaneMixIn, self)._setForegroundColor(color)
diff --git a/src/silx/gui/plot3d/items/scatter.py b/src/silx/gui/plot3d/items/scatter.py
new file mode 100644
index 0000000..24abaa5
--- /dev/null
+++ b/src/silx/gui/plot3d/items/scatter.py
@@ -0,0 +1,617 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides 2D and 3D scatter data item class.
+"""
+
+from __future__ import absolute_import
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "15/11/2017"
+
+try:
+ from collections import abc
+except ImportError: # Python2 support
+ import collections as abc
+import logging
+import numpy
+
+from ....utils.deprecation import deprecated
+from ... import _glutils as glu
+from ...plot._utils.delaunay import delaunay
+from ..scene import function, primitives, utils
+
+from ...plot.items import ScatterVisualizationMixIn
+from .core import DataItem3D, Item3DChangedType, ItemChangedType
+from .mixins import ColormapMixIn, SymbolMixIn
+from ._pick import PickingResult
+
+
+_logger = logging.getLogger(__name__)
+
+
+class Scatter3D(DataItem3D, ColormapMixIn, SymbolMixIn):
+ """Description of a 3D scatter plot.
+
+ :param parent: The View widget this item belongs to.
+ """
+
+ # TODO supports different size for each point
+
+ def __init__(self, parent=None):
+ DataItem3D.__init__(self, parent=parent)
+ ColormapMixIn.__init__(self)
+ SymbolMixIn.__init__(self)
+
+ noData = numpy.zeros((0, 1), dtype=numpy.float32)
+ symbol, size = self._getSceneSymbol()
+ self._scatter = primitives.Points(
+ x=noData, y=noData, z=noData, value=noData, size=size)
+ self._scatter.marker = symbol
+ self._getScenePrimitive().children.append(self._scatter)
+
+ # Connect scene primitive to mix-in class
+ ColormapMixIn._setSceneColormap(self, self._scatter.colormap)
+
+ def _updated(self, event=None):
+ """Handle mix-in class updates"""
+ if event in (ItemChangedType.SYMBOL, ItemChangedType.SYMBOL_SIZE):
+ symbol, size = self._getSceneSymbol()
+ self._scatter.marker = symbol
+ self._scatter.setAttribute('size', size, copy=True)
+
+ super(Scatter3D, self)._updated(event)
+
+ def setData(self, x, y, z, value, copy=True):
+ """Set the data of the scatter plot
+
+ :param numpy.ndarray x: Array of X coordinates (single value not accepted)
+ :param y: Points Y coordinate (array-like or single value)
+ :param z: Points Z coordinate (array-like or single value)
+ :param value: Points values (array-like or single value)
+ :param bool copy:
+ True (default) to copy the data,
+ False to use provided data (do not modify!)
+ """
+ self._scatter.setAttribute('x', x, copy=copy)
+ self._scatter.setAttribute('y', y, copy=copy)
+ self._scatter.setAttribute('z', z, copy=copy)
+ self._scatter.setAttribute('value', value, copy=copy)
+
+ self._setColormappedData(self.getValueData(copy=False), copy=False)
+ self._updated(ItemChangedType.DATA)
+
+ def getData(self, copy=True):
+ """Returns data as provided to :meth:`setData`.
+
+ :param bool copy: True to get a copy,
+ False to return internal data (do not modify!)
+ :return: (x, y, z, value)
+ """
+ return (self.getXData(copy),
+ self.getYData(copy),
+ self.getZData(copy),
+ self.getValueData(copy))
+
+ def getXData(self, copy=True):
+ """Returns X data coordinates.
+
+ :param bool copy: True to get a copy,
+ False to return internal array (do not modify!)
+ :return: X coordinates
+ :rtype: numpy.ndarray
+ """
+ return self._scatter.getAttribute('x', copy=copy).reshape(-1)
+
+ def getYData(self, copy=True):
+ """Returns Y data coordinates.
+
+ :param bool copy: True to get a copy,
+ False to return internal array (do not modify!)
+ :return: Y coordinates
+ :rtype: numpy.ndarray
+ """
+ return self._scatter.getAttribute('y', copy=copy).reshape(-1)
+
+ def getZData(self, copy=True):
+ """Returns Z data coordinates.
+
+ :param bool copy: True to get a copy,
+ False to return internal array (do not modify!)
+ :return: Z coordinates
+ :rtype: numpy.ndarray
+ """
+ return self._scatter.getAttribute('z', copy=copy).reshape(-1)
+
+ def getValueData(self, copy=True):
+ """Returns data values.
+
+ :param bool copy: True to get a copy,
+ False to return internal array (do not modify!)
+ :return: data values
+ :rtype: numpy.ndarray
+ """
+ return self._scatter.getAttribute('value', copy=copy).reshape(-1)
+
+ @deprecated(reason="Consistency with PlotWidget items",
+ replacement="getValueData", since_version="0.10.0")
+ def getValues(self, copy=True):
+ return self.getValueData(copy)
+
+ def _pickFull(self, context, threshold=0., sort='depth'):
+ """Perform picking in this item at given widget position.
+
+ :param PickContext context: Current picking context
+ :param float threshold: Picking threshold in pixel.
+ Perform picking in a square of size threshold x threshold.
+ :param str sort: How returned indices are sorted:
+
+ - 'index' (default): sort by the value of the indices
+ - 'depth': Sort by the depth of the points from the current
+ camera point of view.
+ :return: Object holding the results or None
+ :rtype: Union[None,PickingResult]
+ """
+ assert sort in ('index', 'depth')
+
+ rayNdc = context.getPickingSegment(frame='ndc')
+ if rayNdc is None: # No picking outside viewport
+ return None
+
+ # Project data to NDC
+ xData = self.getXData(copy=False)
+ if len(xData) == 0: # No data in the scatter
+ return None
+
+ primitive = self._getScenePrimitive()
+
+ dataPoints = numpy.transpose((xData,
+ self.getYData(copy=False),
+ self.getZData(copy=False),
+ numpy.ones_like(xData)))
+
+ pointsNdc = primitive.objectToNDCTransform.transformPoints(
+ dataPoints, perspectiveDivide=True)
+
+ # Perform picking
+ distancesNdc = numpy.abs(pointsNdc[:, :2] - rayNdc[0, :2])
+ # TODO issue with symbol size: using pixel instead of points
+ threshold += self.getSymbolSize()
+ thresholdNdc = 2. * threshold / numpy.array(primitive.viewport.size)
+ picked = numpy.where(numpy.logical_and(
+ numpy.all(distancesNdc < thresholdNdc, axis=1),
+ numpy.logical_and(rayNdc[0, 2] <= pointsNdc[:, 2],
+ pointsNdc[:, 2] <= rayNdc[1, 2])))[0]
+
+ if sort == 'depth':
+ # Sort picked points from front to back
+ picked = picked[numpy.argsort(pointsNdc[picked, 2])]
+
+ if picked.size > 0:
+ return PickingResult(self,
+ positions=dataPoints[picked, :3],
+ indices=picked,
+ fetchdata=self.getValueData)
+ else:
+ return None
+
+
+class Scatter2D(DataItem3D, ColormapMixIn, SymbolMixIn,
+ ScatterVisualizationMixIn):
+ """2D scatter data with settable visualization mode.
+
+ :param parent: The View widget this item belongs to.
+ """
+
+ _VISUALIZATION_PROPERTIES = {
+ ScatterVisualizationMixIn.Visualization.POINTS:
+ ('symbol', 'symbolSize'),
+ ScatterVisualizationMixIn.Visualization.LINES:
+ ('lineWidth',),
+ ScatterVisualizationMixIn.Visualization.SOLID: (),
+ }
+ """Dict {visualization mode: property names used in this mode}"""
+
+ _SUPPORTED_SCATTER_VISUALIZATION = tuple(_VISUALIZATION_PROPERTIES.keys())
+ """Overrides supported Visualizations"""
+
+ def __init__(self, parent=None):
+ DataItem3D.__init__(self, parent=parent)
+ ColormapMixIn.__init__(self)
+ SymbolMixIn.__init__(self)
+ ScatterVisualizationMixIn.__init__(self)
+
+ self._heightMap = False
+ self._lineWidth = 1.
+
+ self._x = numpy.zeros((0,), dtype=numpy.float32)
+ self._y = numpy.zeros((0,), dtype=numpy.float32)
+ self._value = numpy.zeros((0,), dtype=numpy.float32)
+
+ self._cachedLinesIndices = None
+ self._cachedTrianglesIndices = None
+
+ # Connect scene primitive to mix-in class
+ ColormapMixIn._setSceneColormap(self, function.Colormap())
+
+ def _updated(self, event=None):
+ """Handle mix-in class updates"""
+ if event in (ItemChangedType.SYMBOL, ItemChangedType.SYMBOL_SIZE):
+ symbol, size = self._getSceneSymbol()
+ for child in self._getScenePrimitive().children:
+ if isinstance(child, primitives.Points):
+ child.marker = symbol
+ child.setAttribute('size', size, copy=True)
+
+ elif event is ItemChangedType.VISIBLE:
+ # TODO smart update?, need dirty flags
+ self._updateScene()
+
+ elif event is ItemChangedType.VISUALIZATION_MODE:
+ self._updateScene()
+
+ super(Scatter2D, self)._updated(event)
+
+ def isPropertyEnabled(self, name, visualization=None):
+ """Returns true if the property is used with visualization mode.
+
+ :param str name: The name of the property to check, in:
+ 'lineWidth', 'symbol', 'symbolSize'
+ :param str visualization:
+ The visualization mode for which to get the info.
+ By default, it is the current visualization mode.
+ :return:
+ """
+ assert name in ('lineWidth', 'symbol', 'symbolSize')
+ if visualization is None:
+ visualization = self.getVisualization()
+ assert visualization in self.supportedVisualizations()
+ return name in self._VISUALIZATION_PROPERTIES[visualization]
+
+ def setHeightMap(self, heightMap):
+ """Set whether to display the data has a height map or not.
+
+ When displayed as a height map, the data values are used as
+ z coordinates.
+
+ :param bool heightMap:
+ True to display a height map,
+ False to display as 2D data with z=0
+ """
+ heightMap = bool(heightMap)
+ if heightMap != self.isHeightMap():
+ self._heightMap = heightMap
+ self._updateScene()
+ self._updated(Item3DChangedType.HEIGHT_MAP)
+
+ def isHeightMap(self):
+ """Returns True if data is displayed as a height map.
+
+ :rtype: bool
+ """
+ return self._heightMap
+
+ def getLineWidth(self):
+ """Return the curve line width in pixels (float)"""
+ return self._lineWidth
+
+ def setLineWidth(self, width):
+ """Set the width in pixel of the curve line
+
+ See :meth:`getLineWidth`.
+
+ :param float width: Width in pixels
+ """
+ width = float(width)
+ assert width >= 1.
+ if width != self._lineWidth:
+ self._lineWidth = width
+ for child in self._getScenePrimitive().children:
+ if hasattr(child, 'lineWidth'):
+ child.lineWidth = width
+ self._updated(ItemChangedType.LINE_WIDTH)
+
+ def setData(self, x, y, value, copy=True):
+ """Set the data represented by this item.
+
+ Provided arrays must have the same length.
+
+ :param numpy.ndarray x: X coordinates (array-like)
+ :param numpy.ndarray y: Y coordinates (array-like)
+ :param value: Points value: array-like or single scalar
+ :param bool copy:
+ True (default) to make a copy of the data,
+ False to avoid copy if possible (do not modify the arrays).
+ """
+ x = numpy.array(
+ x, copy=copy, dtype=numpy.float32, order='C').reshape(-1)
+ y = numpy.array(
+ y, copy=copy, dtype=numpy.float32, order='C').reshape(-1)
+ assert len(x) == len(y)
+
+ if isinstance(value, abc.Iterable):
+ value = numpy.array(
+ value, copy=copy, dtype=numpy.float32, order='C').reshape(-1)
+ assert len(value) == len(x)
+ else: # Single scalar
+ value = numpy.array((float(value),), dtype=numpy.float32)
+
+ self._x = x
+ self._y = y
+ self._value = value
+
+ # Reset cache
+ self._cachedLinesIndices = None
+ self._cachedTrianglesIndices = None
+
+ self._setColormappedData(self.getValueData(copy=False), copy=False)
+
+ self._updateScene()
+
+ self._updated(ItemChangedType.DATA)
+
+ def getData(self, copy=True):
+ """Returns data as provided to :meth:`setData`.
+
+ :param bool copy: True to get a copy,
+ False to return internal data (do not modify!)
+ :return: (x, y, value)
+ """
+ return (self.getXData(copy=copy),
+ self.getYData(copy=copy),
+ self.getValueData(copy=copy))
+
+ def getXData(self, copy=True):
+ """Returns X data coordinates.
+
+ :param bool copy: True to get a copy,
+ False to return internal array (do not modify!)
+ :return: X coordinates
+ :rtype: numpy.ndarray
+ """
+ return numpy.array(self._x, copy=copy)
+
+ def getYData(self, copy=True):
+ """Returns Y data coordinates.
+
+ :param bool copy: True to get a copy,
+ False to return internal array (do not modify!)
+ :return: Y coordinates
+ :rtype: numpy.ndarray
+ """
+ return numpy.array(self._y, copy=copy)
+
+ def getValueData(self, copy=True):
+ """Returns data values.
+
+ :param bool copy: True to get a copy,
+ False to return internal array (do not modify!)
+ :return: data values
+ :rtype: numpy.ndarray
+ """
+ return numpy.array(self._value, copy=copy)
+
+ @deprecated(reason="Consistency with PlotWidget items",
+ replacement="getValueData", since_version="0.10.0")
+ def getValues(self, copy=True):
+ return self.getValueData(copy)
+
+ def _pickPoints(self, context, points, threshold=1., sort='depth'):
+ """Perform picking while in 'points' visualization mode
+
+ :param PickContext context: Current picking context
+ :param float threshold: Picking threshold in pixel.
+ Perform picking in a square of size threshold x threshold.
+ :param str sort: How returned indices are sorted:
+
+ - 'index' (default): sort by the value of the indices
+ - 'depth': Sort by the depth of the points from the current
+ camera point of view.
+ :return: Object holding the results or None
+ :rtype: Union[None,PickingResult]
+ """
+ assert sort in ('index', 'depth')
+
+ rayNdc = context.getPickingSegment(frame='ndc')
+ if rayNdc is None: # No picking outside viewport
+ return None
+
+ # Project data to NDC
+ primitive = self._getScenePrimitive()
+ pointsNdc = primitive.objectToNDCTransform.transformPoints(
+ points, perspectiveDivide=True)
+
+ # Perform picking
+ distancesNdc = numpy.abs(pointsNdc[:, :2] - rayNdc[0, :2])
+ thresholdNdc = threshold / numpy.array(primitive.viewport.size)
+ picked = numpy.where(numpy.logical_and(
+ numpy.all(distancesNdc < thresholdNdc, axis=1),
+ numpy.logical_and(rayNdc[0, 2] <= pointsNdc[:, 2],
+ pointsNdc[:, 2] <= rayNdc[1, 2])))[0]
+
+ if sort == 'depth':
+ # Sort picked points from front to back
+ picked = picked[numpy.argsort(pointsNdc[picked, 2])]
+
+ if picked.size > 0:
+ return PickingResult(self,
+ positions=points[picked, :3],
+ indices=picked,
+ fetchdata=self.getValueData)
+ else:
+ return None
+
+ def _pickSolid(self, context, points):
+ """Perform picking while in 'solid' visualization mode
+
+ :param PickContext context: Current picking context
+ """
+ if self._cachedTrianglesIndices is None:
+ _logger.info("Picking on Scatter2D before rendering")
+ return None
+
+ rayObject = context.getPickingSegment(frame=self._getScenePrimitive())
+ if rayObject is None: # No picking outside viewport
+ return None
+ rayObject = rayObject[:, :3]
+
+ trianglesIndices = self._cachedTrianglesIndices.reshape(-1, 3)
+ triangles = points[trianglesIndices, :3]
+ selectedIndices, t, barycentric = glu.segmentTrianglesIntersection(
+ rayObject, triangles)
+ closest = numpy.argmax(barycentric, axis=1)
+
+ indices = trianglesIndices.reshape(-1, 3)[selectedIndices, closest]
+
+ if len(indices) == 0: # No point is picked
+ return None
+
+ # Compute intersection points and get closest data point
+ positions = t.reshape(-1, 1) * (rayObject[1] - rayObject[0]) + rayObject[0]
+
+ return PickingResult(self,
+ positions=positions,
+ indices=indices,
+ fetchdata=self.getValueData)
+
+ def _pickFull(self, context):
+ """Perform picking in this item at given widget position.
+
+ :param PickContext context: Current picking context
+ :return: Object holding the results or None
+ :rtype: Union[None,PickingResult]
+ """
+ xData = self.getXData(copy=False)
+ if len(xData) == 0: # No data in the scatter
+ return None
+
+ if self.isHeightMap():
+ zData = self.getValueData(copy=False)
+ else:
+ zData = numpy.zeros_like(xData)
+
+ points = numpy.transpose((xData,
+ self.getYData(copy=False),
+ zData,
+ numpy.ones_like(xData)))
+
+ mode = self.getVisualization()
+ if mode is self.Visualization.POINTS:
+ # TODO issue with symbol size: using pixel instead of points
+ # Get "corrected" symbol size
+ _, threshold = self._getSceneSymbol()
+ return self._pickPoints(
+ context, points, threshold=max(3., threshold))
+
+ elif mode is self.Visualization.LINES:
+ # Picking only at point
+ return self._pickPoints(context, points, threshold=5.)
+
+ else: # mode == 'solid'
+ return self._pickSolid(context, points)
+
+ def _updateScene(self):
+ self._getScenePrimitive().children = [] # Remove previous primitives
+
+ if not self.isVisible():
+ return # Update when visible
+
+ x, y, value = self.getData(copy=False)
+ if len(x) == 0:
+ return # Nothing to display
+
+ mode = self.getVisualization()
+ heightMap = self.isHeightMap()
+
+ if mode is self.Visualization.POINTS:
+ z = value if heightMap else 0.
+ symbol, size = self._getSceneSymbol()
+ primitive = primitives.Points(
+ x=x, y=y, z=z, value=value,
+ size=size,
+ colormap=self._getSceneColormap())
+ primitive.marker = symbol
+
+ else:
+ # TODO run delaunay in a thread
+ # Compute lines/triangles indices if not cached
+ if self._cachedTrianglesIndices is None:
+ triangulation = delaunay(x, y)
+ if triangulation is None:
+ return None
+ self._cachedTrianglesIndices = numpy.ravel(
+ triangulation.simplices.astype(numpy.uint32))
+
+ if (mode is self.Visualization.LINES and
+ self._cachedLinesIndices is None):
+ # Compute line indices
+ self._cachedLinesIndices = utils.triangleToLineIndices(
+ self._cachedTrianglesIndices, unicity=True)
+
+ if mode is self.Visualization.LINES:
+ indices = self._cachedLinesIndices
+ renderMode = 'lines'
+ else:
+ indices = self._cachedTrianglesIndices
+ renderMode = 'triangles'
+
+ # TODO supports x, y instead of copy
+ if heightMap:
+ if len(value) == 1:
+ value = numpy.ones_like(x) * value
+ coordinates = numpy.array((x, y, value), dtype=numpy.float32).T
+ else:
+ coordinates = numpy.array((x, y), dtype=numpy.float32).T
+
+ # TODO option to enable/disable light, cache normals
+ # TODO smooth surface
+ if mode is self.Visualization.SOLID:
+ if heightMap:
+ coordinates = coordinates[indices]
+ if len(value) > 1:
+ value = value[indices]
+ triangleNormals = utils.trianglesNormal(coordinates)
+ normal = numpy.empty((len(triangleNormals) * 3, 3),
+ dtype=numpy.float32)
+ normal[0::3, :] = triangleNormals
+ normal[1::3, :] = triangleNormals
+ normal[2::3, :] = triangleNormals
+ indices = None
+ else:
+ normal = (0., 0., 1.)
+ else:
+ normal = None
+
+ primitive = primitives.ColormapMesh3D(
+ coordinates,
+ value.reshape(-1, 1), # Makes it a 2D array
+ normal=normal,
+ colormap=self._getSceneColormap(),
+ indices=indices,
+ mode=renderMode)
+ primitive.lineWidth = self.getLineWidth()
+ primitive.lineSmooth = False
+
+ self._getScenePrimitive().children = [primitive]
diff --git a/src/silx/gui/plot3d/items/volume.py b/src/silx/gui/plot3d/items/volume.py
new file mode 100644
index 0000000..f80fea2
--- /dev/null
+++ b/src/silx/gui/plot3d/items/volume.py
@@ -0,0 +1,886 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides 3D array item class and its sub-items.
+"""
+
+from __future__ import absolute_import
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "24/04/2018"
+
+import logging
+import time
+import numpy
+
+from silx.math.combo import min_max
+from silx.math.marchingcubes import MarchingCubes
+from silx.math.interpolate import interp3d
+
+from ....utils.proxy import docstring
+from ... import _glutils as glu
+from ... import qt
+from ...colors import rgba
+
+from ..scene import cutplane, function, primitives, transform, utils
+
+from .core import BaseNodeItem, Item3D, ItemChangedType, Item3DChangedType
+from .mixins import ColormapMixIn, ComplexMixIn, InterpolationMixIn, PlaneMixIn
+from ._pick import PickingResult
+
+
+_logger = logging.getLogger(__name__)
+
+
+class CutPlane(Item3D, ColormapMixIn, InterpolationMixIn, PlaneMixIn):
+ """Class representing a cutting plane in a :class:`ScalarField3D` item.
+
+ :param parent: 3D Data set in which the cut plane is applied.
+ """
+
+ def __init__(self, parent):
+ plane = cutplane.CutPlane(normal=(0, 1, 0))
+
+ Item3D.__init__(self, parent=None)
+ ColormapMixIn.__init__(self)
+ InterpolationMixIn.__init__(self)
+ PlaneMixIn.__init__(self, plane=plane)
+
+ self._dataRange = None
+ self._data = None
+
+ self._getScenePrimitive().children = [plane]
+
+ # Connect scene primitive to mix-in class
+ ColormapMixIn._setSceneColormap(self, plane.colormap)
+ InterpolationMixIn._setPrimitive(self, plane)
+
+ self.setParent(parent)
+
+ def _updateData(self, data, range_):
+ """Update used dataset
+
+ No copy is made.
+
+ :param Union[numpy.ndarray[float],None] data: The dataset
+ :param Union[List[float],None] range_:
+ (min, min positive, max) values
+ """
+ self._data = None if data is None else numpy.array(data, copy=False)
+ self._getPlane().setData(self._data, copy=False)
+
+ # Store data range info as 3-tuple of values
+ self._dataRange = range_
+ if range_ is None:
+ range_ = None, None, None
+ self._setColormappedData(self._data, copy=False,
+ min_=range_[0],
+ minPositive=range_[1],
+ max_=range_[2])
+
+ self._updated(ItemChangedType.DATA)
+
+ def _syncDataWithParent(self):
+ """Synchronize this instance data with that of its parent"""
+ parent = self.parent()
+ if parent is None:
+ data, range_ = None, None
+ else:
+ data = parent.getData(copy=False)
+ range_ = parent.getDataRange()
+ self._updateData(data, range_)
+
+ def _parentChanged(self, event):
+ """Handle data change in the parent this plane belongs to"""
+ if event == ItemChangedType.DATA:
+ self._syncDataWithParent()
+
+ def setParent(self, parent):
+ oldParent = self.parent()
+ if isinstance(oldParent, Item3D):
+ oldParent.sigItemChanged.disconnect(self._parentChanged)
+
+ super(CutPlane, self).setParent(parent)
+
+ if isinstance(parent, Item3D):
+ parent.sigItemChanged.connect(self._parentChanged)
+
+ self._syncDataWithParent()
+
+ # Colormap
+
+ def getDisplayValuesBelowMin(self):
+ """Return whether values <= colormap min are displayed or not.
+
+ :rtype: bool
+ """
+ return self._getPlane().colormap.displayValuesBelowMin
+
+ def setDisplayValuesBelowMin(self, display):
+ """Set whether to display values <= colormap min.
+
+ :param bool display: True to show values below min,
+ False to discard them
+ """
+ display = bool(display)
+ if display != self.getDisplayValuesBelowMin():
+ self._getPlane().colormap.displayValuesBelowMin = display
+ self._updated(ItemChangedType.ALPHA)
+
+ def getDataRange(self):
+ """Return the range of the data as a 3-tuple of values.
+
+ positive min is NaN if no data is positive.
+
+ :return: (min, positive min, max) or None.
+ :rtype: Union[List[float],None]
+ """
+ return None if self._dataRange is None else tuple(self._dataRange)
+
+ def getData(self, copy=True):
+ """Return 3D dataset.
+
+ :param bool copy:
+ True (default) to get a copy,
+ False to get the internal data (DO NOT modify!)
+ :return: The data set (or None if not set)
+ """
+ if self._data is None:
+ return None
+ else:
+ return numpy.array(self._data, copy=copy)
+
+ def _pickFull(self, context):
+ """Perform picking in this item at given widget position.
+
+ :param PickContext context: Current picking context
+ :return: Object holding the results or None
+ :rtype: Union[None,PickingResult]
+ """
+ rayObject = context.getPickingSegment(frame=self._getScenePrimitive())
+ if rayObject is None:
+ return None
+
+ points = utils.segmentPlaneIntersect(
+ rayObject[0, :3],
+ rayObject[1, :3],
+ planeNorm=self.getNormal(),
+ planePt=self.getPoint())
+
+ if len(points) == 1: # Single intersection
+ if numpy.any(points[0] < 0.):
+ return None # Outside volume
+ z, y, x = int(points[0][2]), int(points[0][1]), int(points[0][0])
+
+ data = self.getData(copy=False)
+ if data is None:
+ return None # No dataset
+
+ depth, height, width = data.shape
+ if z < depth and y < height and x < width:
+ return PickingResult(self,
+ positions=[points[0]],
+ indices=([z], [y], [x]))
+ else:
+ return None # Outside image
+ else: # Either no intersection or segment and image are coplanar
+ return None
+
+
+class Isosurface(Item3D):
+ """Class representing an iso-surface in a :class:`ScalarField3D` item.
+
+ :param parent: The DataItem3D this iso-surface belongs to
+ """
+
+ def __init__(self, parent):
+ Item3D.__init__(self, parent=None)
+ self._data = None
+ self._level = float('nan')
+ self._autoLevelFunction = None
+ self._color = rgba('#FFD700FF')
+ self.setParent(parent)
+
+ def _syncDataWithParent(self):
+ """Synchronize this instance data with that of its parent"""
+ parent = self.parent()
+ if parent is None:
+ self._data = None
+ else:
+ self._data = parent.getData(copy=False)
+ self._updateScenePrimitive()
+
+ def _parentChanged(self, event):
+ """Handle data change in the parent this isosurface belongs to"""
+ if event == ItemChangedType.DATA:
+ self._syncDataWithParent()
+
+ def setParent(self, parent):
+ oldParent = self.parent()
+ if isinstance(oldParent, Item3D):
+ oldParent.sigItemChanged.disconnect(self._parentChanged)
+
+ super(Isosurface, self).setParent(parent)
+
+ if isinstance(parent, Item3D):
+ parent.sigItemChanged.connect(self._parentChanged)
+
+ self._syncDataWithParent()
+
+ def getData(self, copy=True):
+ """Return 3D dataset.
+
+ :param bool copy:
+ True (default) to get a copy,
+ False to get the internal data (DO NOT modify!)
+ :return: The data set (or None if not set)
+ """
+ if self._data is None:
+ return None
+ else:
+ return numpy.array(self._data, copy=copy)
+
+ def getLevel(self):
+ """Return the level of this iso-surface (float)"""
+ return self._level
+
+ def setLevel(self, level):
+ """Set the value at which to build the iso-surface.
+
+ Setting this value reset auto-level function
+
+ :param float level: The value at which to build the iso-surface
+ """
+ self._autoLevelFunction = None
+ level = float(level)
+ if level != self._level:
+ self._level = level
+ self._updateScenePrimitive()
+ self._updated(Item3DChangedType.ISO_LEVEL)
+
+ def isAutoLevel(self):
+ """True if iso-level is rebuild for each data set."""
+ return self.getAutoLevelFunction() is not None
+
+ def getAutoLevelFunction(self):
+ """Return the function computing the iso-level (callable or None)"""
+ return self._autoLevelFunction
+
+ def setAutoLevelFunction(self, autoLevel):
+ """Set the function used to compute the iso-level.
+
+ WARNING: The function might get called in a thread.
+
+ :param callable autoLevel:
+ A function taking a 3D numpy.ndarray of float32 and returning
+ a float used as iso-level.
+ Example: numpy.mean(data) + numpy.std(data)
+ """
+ assert callable(autoLevel)
+ self._autoLevelFunction = autoLevel
+ self._updateScenePrimitive()
+
+ def getColor(self):
+ """Return the color of this iso-surface (QColor)"""
+ return qt.QColor.fromRgbF(*self._color)
+
+ def _updateColor(self, color):
+ """Handle update of color
+
+ :param List[float] color: RGBA channels in [0, 1]
+ """
+ primitive = self._getScenePrimitive()
+ if len(primitive.children) != 0:
+ primitive.children[0].setAttribute('color', color)
+
+ def setColor(self, color):
+ """Set the color of the iso-surface
+
+ :param color: RGBA color of the isosurface
+ :type color: QColor, str or array-like of 4 float in [0., 1.]
+ """
+ color = rgba(color)
+ if color != self._color:
+ self._color = color
+ self._updateColor(self._color)
+ self._updated(ItemChangedType.COLOR)
+
+ def _computeIsosurface(self):
+ """Compute isosurface for current state.
+
+ :return: (vertices, normals, indices) arrays
+ :rtype: List[Union[None,numpy.ndarray]]
+ """
+ data = self.getData(copy=False)
+
+ if data is None:
+ if self.isAutoLevel():
+ self._level = float('nan')
+
+ else:
+ if self.isAutoLevel():
+ st = time.time()
+ try:
+ level = float(self.getAutoLevelFunction()(data))
+
+ except Exception:
+ module_ = self.getAutoLevelFunction().__module__
+ name = self.getAutoLevelFunction().__name__
+ _logger.error(
+ "Error while executing iso level function %s.%s",
+ module_,
+ name,
+ exc_info=True)
+ level = float('nan')
+
+ else:
+ _logger.info(
+ 'Computed iso-level in %f s.', time.time() - st)
+
+ if level != self._level:
+ self._level = level
+ self._updated(Item3DChangedType.ISO_LEVEL)
+
+ if numpy.isfinite(self._level):
+ st = time.time()
+ vertices, normals, indices = MarchingCubes(
+ data,
+ isolevel=self._level)
+ _logger.info('Computed iso-surface in %f s.', time.time() - st)
+
+ if len(vertices) != 0:
+ return vertices, normals, indices
+
+ return None, None, None
+
+ def _updateScenePrimitive(self):
+ """Update underlying mesh"""
+ self._getScenePrimitive().children = []
+
+ vertices, normals, indices = self._computeIsosurface()
+ if vertices is not None:
+ mesh = primitives.Mesh3D(vertices,
+ colors=self._color,
+ normals=normals,
+ mode='triangles',
+ indices=indices,
+ copy=False)
+ self._getScenePrimitive().children = [mesh]
+
+ def _pickFull(self, context):
+ """Perform picking in this item at given widget position.
+
+ :param PickContext context: Current picking context
+ :return: Object holding the results or None
+ :rtype: Union[None,PickingResult]
+ """
+ rayObject = context.getPickingSegment(frame=self._getScenePrimitive())
+ if rayObject is None:
+ return None
+ rayObject = rayObject[:, :3]
+
+ data = self.getData(copy=False)
+ bins = utils.segmentVolumeIntersect(
+ rayObject, numpy.array(data.shape) - 1)
+ if bins is None:
+ return None
+
+ # gather bin data
+ offsets = [(i, j, k) for i in (0, 1) for j in (0, 1) for k in (0, 1)]
+ indices = bins[:, numpy.newaxis, :] + offsets
+ binsData = data[indices[:, :, 0], indices[:, :, 1], indices[:, :, 2]]
+ # binsData.shape = nbins, 8
+ # TODO up-to this point everything can be done once for all isosurfaces
+
+ # check bin candidates
+ level = self.getLevel()
+ mask = numpy.logical_and(numpy.nanmin(binsData, axis=1) <= level,
+ level <= numpy.nanmax(binsData, axis=1))
+ bins = bins[mask]
+ binsData = binsData[mask]
+
+ if len(bins) == 0:
+ return None # No bin candidate
+
+ # do picking on candidates
+ intersections = []
+ depths = []
+ for currentBin, data in zip(bins, binsData):
+ mc = MarchingCubes(data.reshape(2, 2, 2), isolevel=level)
+ points = mc.get_vertices() + currentBin
+ triangles = points[mc.get_indices()]
+ t = glu.segmentTrianglesIntersection(rayObject, triangles)[1]
+ t = numpy.unique(t) # Duplicates happen on triangle edges
+ if len(t) != 0:
+ # Compute intersection points and get closest data point
+ points = t.reshape(-1, 1) * (rayObject[1] - rayObject[0]) + rayObject[0]
+ # Get closest data points by rounding to int
+ intersections.extend(points)
+ depths.extend(t)
+
+ if len(intersections) == 0:
+ return None # No intersected triangles
+
+ intersections = numpy.array(intersections)[numpy.argsort(depths)]
+ indices = numpy.transpose(numpy.round(intersections).astype(numpy.int64))
+ return PickingResult(self, positions=intersections, indices=indices)
+
+
+class ScalarField3D(BaseNodeItem):
+ """3D scalar field on a regular grid.
+
+ :param parent: The View widget this item belongs to.
+ """
+
+ _CutPlane = CutPlane
+ """CutPlane class associated to this class"""
+
+ _Isosurface = Isosurface
+ """Isosurface classe associated to this class"""
+
+ def __init__(self, parent=None):
+ BaseNodeItem.__init__(self, parent=parent)
+
+ # Gives this item the shape of the data, no matter
+ # of the isosurface/cut plane size
+ self._boundedGroup = primitives.BoundedGroup()
+
+ # Store iso-surfaces
+ self._isosurfaces = []
+
+ self._data = None
+ self._dataRange = None
+
+ self._cutPlane = self._CutPlane(parent=self)
+ self._cutPlane.setVisible(False)
+
+ self._isogroup = primitives.GroupDepthOffset()
+ self._isogroup.transforms = [
+ # Convert from z, y, x from marching cubes to x, y, z
+ transform.Matrix((
+ (0., 0., 1., 0.),
+ (0., 1., 0., 0.),
+ (1., 0., 0., 0.),
+ (0., 0., 0., 1.))),
+ # Offset to match cutting plane coords
+ transform.Translate(0.5, 0.5, 0.5)
+ ]
+
+ self._getScenePrimitive().children = [
+ self._boundedGroup,
+ self._cutPlane._getScenePrimitive(),
+ self._isogroup]
+
+ @staticmethod
+ def _computeRangeFromData(data):
+ """Compute range info (min, min positive, max) from data
+
+ :param Union[numpy.ndarray,None] data:
+ :return: Union[List[float],None]
+ """
+ if data is None:
+ return None
+
+ dataRange = min_max(data, min_positive=True, finite=True)
+ if dataRange.minimum is None: # Only non-finite data
+ return None
+
+ if dataRange is not None:
+ min_positive = dataRange.min_positive
+ if min_positive is None:
+ min_positive = float('nan')
+ return dataRange.minimum, min_positive, dataRange.maximum
+
+ def setData(self, data, copy=True):
+ """Set the 3D scalar data represented by this item.
+
+ Dataset order is zyx (i.e., first dimension is z).
+
+ :param data: 3D array
+ :type data: 3D numpy.ndarray of float32 with shape at least (2, 2, 2)
+ :param bool copy:
+ True (default) to make a copy,
+ False to avoid copy (DO NOT MODIFY data afterwards)
+ """
+ if data is None:
+ self._data = None
+ self._boundedGroup.shape = None
+
+ else:
+ data = numpy.array(data, copy=copy, dtype=numpy.float32, order='C')
+ assert data.ndim == 3
+ assert min(data.shape) >= 2
+
+ self._data = data
+ self._boundedGroup.shape = self._data.shape
+
+ self._dataRange = self._computeRangeFromData(self._data)
+ self._updated(ItemChangedType.DATA)
+
+ def getData(self, copy=True):
+ """Return 3D dataset.
+
+ :param bool copy:
+ True (default) to get a copy,
+ False to get the internal data (DO NOT modify!)
+ :return: The data set (or None if not set)
+ """
+ if self._data is None:
+ return None
+ else:
+ return numpy.array(self._data, copy=copy)
+
+ def getDataRange(self):
+ """Return the range of the data as a 3-tuple of values.
+
+ positive min is NaN if no data is positive.
+
+ :return: (min, positive min, max) or None.
+ """
+ return self._dataRange
+
+ # Cut Plane
+
+ def getCutPlanes(self):
+ """Return an iterable of all :class:`CutPlane` of this item.
+
+ This includes hidden cut planes.
+
+ For now, there is always one cut plane.
+ """
+ return (self._cutPlane,)
+
+ # Handle iso-surfaces
+
+ # TODO rename to sigItemAdded|Removed?
+ sigIsosurfaceAdded = qt.Signal(object)
+ """Signal emitted when a new iso-surface is added to the view.
+
+ The newly added iso-surface is provided by this signal
+ """
+
+ sigIsosurfaceRemoved = qt.Signal(object)
+ """Signal emitted when an iso-surface is removed from the view
+
+ The removed iso-surface is provided by this signal.
+ """
+
+ def addIsosurface(self, level, color):
+ """Add an isosurface to this item.
+
+ :param level:
+ The value at which to build the iso-surface or a callable
+ (e.g., a function) taking a 3D numpy.ndarray as input and
+ returning a float.
+ Example: numpy.mean(data) + numpy.std(data)
+ :type level: float or callable
+ :param color: RGBA color of the isosurface
+ :type color: str or array-like of 4 float in [0., 1.]
+ :return: isosurface object
+ :rtype: ~silx.gui.plot3d.items.volume.Isosurface
+ """
+ isosurface = self._Isosurface(parent=self)
+ isosurface.setColor(color)
+ if callable(level):
+ isosurface.setAutoLevelFunction(level)
+ else:
+ isosurface.setLevel(level)
+ isosurface.sigItemChanged.connect(self._isosurfaceItemChanged)
+
+ self._isosurfaces.append(isosurface)
+
+ self._updateIsosurfaces()
+
+ self.sigIsosurfaceAdded.emit(isosurface)
+ return isosurface
+
+ def getIsosurfaces(self):
+ """Return an iterable of all :class:`.Isosurface` instance of this item"""
+ return tuple(self._isosurfaces)
+
+ def removeIsosurface(self, isosurface):
+ """Remove an iso-surface from this item.
+
+ :param ~silx.gui.plot3d.Plot3DWidget.Isosurface isosurface:
+ The isosurface object to remove
+ """
+ if isosurface not in self.getIsosurfaces():
+ _logger.warning(
+ "Try to remove isosurface that is not in the list: %s",
+ str(isosurface))
+ else:
+ isosurface.sigItemChanged.disconnect(self._isosurfaceItemChanged)
+ self._isosurfaces.remove(isosurface)
+ self._updateIsosurfaces()
+ self.sigIsosurfaceRemoved.emit(isosurface)
+
+ def clearIsosurfaces(self):
+ """Remove all :class:`.Isosurface` instances from this item."""
+ for isosurface in self.getIsosurfaces():
+ self.removeIsosurface(isosurface)
+
+ def _isosurfaceItemChanged(self, event):
+ """Handle update of isosurfaces upon level changed"""
+ if event == Item3DChangedType.ISO_LEVEL:
+ self._updateIsosurfaces()
+
+ def _updateIsosurfaces(self):
+ """Handle updates of iso-surfaces level and add/remove"""
+ # Sorting using minus, this supposes data 'object' to be max values
+ sortedIso = sorted(self.getIsosurfaces(),
+ key=lambda isosurface: - isosurface.getLevel())
+ self._isogroup.children = [iso._getScenePrimitive() for iso in sortedIso]
+
+ # BaseNodeItem
+
+ def getItems(self):
+ """Returns the list of items currently present in this item.
+
+ :rtype: tuple
+ """
+ return self.getCutPlanes() + self.getIsosurfaces()
+
+
+##################
+# ComplexField3D #
+##################
+
+class ComplexCutPlane(CutPlane, ComplexMixIn):
+ """Class representing a cutting plane in a :class:`ComplexField3D` item.
+
+ :param parent: 3D Data set in which the cut plane is applied.
+ """
+
+ def __init__(self, parent):
+ ComplexMixIn.__init__(self)
+ CutPlane.__init__(self, parent=parent)
+
+ def _syncDataWithParent(self):
+ """Synchronize this instance data with that of its parent"""
+ parent = self.parent()
+ if parent is None:
+ data, range_ = None, None
+ else:
+ mode = self.getComplexMode()
+ data = parent.getData(mode=mode, copy=False)
+ range_ = parent.getDataRange(mode=mode)
+ self._updateData(data, range_)
+
+ def _updated(self, event=None):
+ """Handle update of the cut plane (and take care of mode change
+
+ :param Union[None,ItemChangedType] event: The kind of update
+ """
+ if event == ItemChangedType.COMPLEX_MODE:
+ self._syncDataWithParent()
+ super(ComplexCutPlane, self)._updated(event)
+
+
+class ComplexIsosurface(Isosurface, ComplexMixIn, ColormapMixIn):
+ """Class representing an iso-surface in a :class:`ComplexField3D` item.
+
+ :param parent: The DataItem3D this iso-surface belongs to
+ """
+
+ _SUPPORTED_COMPLEX_MODES = \
+ (ComplexMixIn.ComplexMode.NONE,) + ComplexMixIn._SUPPORTED_COMPLEX_MODES
+ """Overrides supported ComplexMode"""
+
+ def __init__(self, parent):
+ ComplexMixIn.__init__(self)
+ ColormapMixIn.__init__(self, function.Colormap())
+ Isosurface.__init__(self, parent=parent)
+ self.setComplexMode(self.ComplexMode.NONE)
+
+ def _updateColor(self, color):
+ """Handle update of color
+
+ :param List[float] color: RGBA channels in [0, 1]
+ """
+ primitive = self._getScenePrimitive()
+ if (len(primitive.children) != 0 and
+ isinstance(primitive.children[0], primitives.ColormapMesh3D)):
+ primitive.children[0].alpha = self._color[3]
+ else:
+ super(ComplexIsosurface, self)._updateColor(color)
+
+ def _syncDataWithParent(self):
+ """Synchronize this instance data with that of its parent"""
+ parent = self.parent()
+ if parent is None:
+ self._data = None
+ else:
+ self._data = parent.getData(
+ mode=parent.getComplexMode(), copy=False)
+
+ if parent is None or self.getComplexMode() == self.ComplexMode.NONE:
+ self._setColormappedData(None, copy=False)
+ else:
+ self._setColormappedData(
+ parent.getData(mode=self.getComplexMode(), copy=False),
+ copy=False)
+
+ self._updateScenePrimitive()
+
+ def _parentChanged(self, event):
+ """Handle data change in the parent this isosurface belongs to"""
+ if event == ItemChangedType.COMPLEX_MODE:
+ self._syncDataWithParent()
+ super(ComplexIsosurface, self)._parentChanged(event)
+
+ def _updated(self, event=None):
+ """Handle update of the isosurface (and take care of mode change)
+
+ :param ItemChangedType event: The kind of update
+ """
+ if event == ItemChangedType.COMPLEX_MODE:
+ self._syncDataWithParent()
+
+ elif event in (ItemChangedType.COLORMAP,
+ Item3DChangedType.INTERPOLATION):
+ self._updateScenePrimitive()
+ super(ComplexIsosurface, self)._updated(event)
+
+ def _updateScenePrimitive(self):
+ """Update underlying mesh"""
+ if self.getComplexMode() == self.ComplexMode.NONE:
+ super(ComplexIsosurface, self)._updateScenePrimitive()
+
+ else: # Specific display for colormapped isosurface
+ self._getScenePrimitive().children = []
+
+ values = self.getColormappedData(copy=False)
+ if values is not None:
+ vertices, normals, indices = self._computeIsosurface()
+ if vertices is not None:
+ values = interp3d(values, vertices, method='linear_omp')
+ # TODO reuse isosurface when only color changes...
+
+ mesh = primitives.ColormapMesh3D(
+ vertices,
+ value=values.reshape(-1, 1),
+ colormap=self._getSceneColormap(),
+ normal=normals,
+ mode='triangles',
+ indices=indices,
+ copy=False)
+ mesh.alpha = self._color[3]
+ self._getScenePrimitive().children = [mesh]
+
+
+class ComplexField3D(ScalarField3D, ComplexMixIn):
+ """3D complex field on a regular grid.
+
+ :param parent: The View widget this item belongs to.
+ """
+
+ _CutPlane = ComplexCutPlane
+ _Isosurface = ComplexIsosurface
+
+ def __init__(self, parent=None):
+ self._dataRangeCache = None
+
+ ComplexMixIn.__init__(self)
+ ScalarField3D.__init__(self, parent=parent)
+
+ @docstring(ComplexMixIn)
+ def setComplexMode(self, mode):
+ mode = ComplexMixIn.ComplexMode.from_value(mode)
+ if mode != self.getComplexMode():
+ self.clearIsosurfaces() # Reset isosurfaces
+ ComplexMixIn.setComplexMode(self, mode)
+
+ def setData(self, data, copy=True):
+ """Set the 3D complex data represented by this item.
+
+ Dataset order is zyx (i.e., first dimension is z).
+
+ :param data: 3D array
+ :type data: 3D numpy.ndarray of float32 with shape at least (2, 2, 2)
+ :param bool copy:
+ True (default) to make a copy,
+ False to avoid copy (DO NOT MODIFY data afterwards)
+ """
+ if data is None:
+ self._data = None
+ self._dataRangeCache = None
+ self._boundedGroup.shape = None
+
+ else:
+ data = numpy.array(data, copy=copy, dtype=numpy.complex64, order='C')
+ assert data.ndim == 3
+ assert min(data.shape) >= 2
+
+ self._data = data
+ self._dataRangeCache = {}
+ self._boundedGroup.shape = self._data.shape
+
+ self._updated(ItemChangedType.DATA)
+
+ def getData(self, copy=True, mode=None):
+ """Return 3D dataset.
+
+ This method does not cache data converted to a specific mode,
+ it computes it for each request.
+
+ :param bool copy:
+ True (default) to get a copy,
+ False to get the internal data (DO NOT modify!)
+ :param Union[None,Mode] mode:
+ The kind of data to retrieve.
+ If None (the default), it returns the complex data,
+ else it computes the requested scalar data.
+ :return: The data set (or None if not set)
+ :rtype: Union[numpy.ndarray,None]
+ """
+ if mode is None:
+ return super(ComplexField3D, self).getData(copy=copy)
+ else:
+ return self._convertComplexData(self._data, mode)
+
+ def getDataRange(self, mode=None):
+ """Return the range of the requested data as a 3-tuple of values.
+
+ Positive min is NaN if no data is positive.
+
+ :param Union[None,Mode] mode:
+ The kind of data for which to get the range information.
+ If None (the default), it returns the data range for the current mode,
+ else it returns the data range for the requested mode.
+ :return: (min, positive min, max) or None.
+ :rtype: Union[None,List[float]]
+ """
+ if self._dataRangeCache is None:
+ return None
+
+ if mode is None:
+ mode = self.getComplexMode()
+
+ if mode not in self._dataRangeCache:
+ # Compute it and store it in cache
+ data = self.getData(copy=False, mode=mode)
+ self._dataRangeCache[mode] = self._computeRangeFromData(data)
+
+ return self._dataRangeCache[mode]
diff --git a/src/silx/gui/plot3d/scene/__init__.py b/src/silx/gui/plot3d/scene/__init__.py
new file mode 100644
index 0000000..9671725
--- /dev/null
+++ b/src/silx/gui/plot3d/scene/__init__.py
@@ -0,0 +1,34 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a 3D graphics scene graph structure."""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "08/11/2016"
+
+
+from .core import Base, Elem, Group, PrivateGroup # noqa
+from .viewport import Viewport # noqa
+from .window import Window # noqa
diff --git a/src/silx/gui/plot3d/scene/axes.py b/src/silx/gui/plot3d/scene/axes.py
new file mode 100644
index 0000000..e35e5e1
--- /dev/null
+++ b/src/silx/gui/plot3d/scene/axes.py
@@ -0,0 +1,258 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Primitive displaying a text field in the scene."""
+
+from __future__ import absolute_import, division, unicode_literals
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "17/10/2016"
+
+
+import logging
+import numpy
+
+from ...plot._utils import ticklayout
+
+from . import core, primitives, text, transform
+
+
+_logger = logging.getLogger(__name__)
+
+
+class LabelledAxes(primitives.GroupBBox):
+ """A group displaying a bounding box with axes labels around its children.
+ """
+
+ def __init__(self):
+ super(LabelledAxes, self).__init__()
+ self._ticksForBounds = None
+
+ self._font = text.Font()
+
+ self._boxVisibility = True
+
+ # TODO offset labels from anchor in pixels
+
+ self._xlabel = text.Text2D(font=self._font)
+ self._xlabel.align = 'center'
+ self._xlabel.transforms = [self._boxTransforms,
+ transform.Translate(tx=0.5)]
+ self._children.insert(-1, self._xlabel)
+
+ self._ylabel = text.Text2D(font=self._font)
+ self._ylabel.align = 'center'
+ self._ylabel.transforms = [self._boxTransforms,
+ transform.Translate(ty=0.5)]
+ self._children.insert(-1, self._ylabel)
+
+ self._zlabel = text.Text2D(font=self._font)
+ self._zlabel.align = 'center'
+ self._zlabel.transforms = [self._boxTransforms,
+ transform.Translate(tz=0.5)]
+ self._children.insert(-1, self._zlabel)
+
+ # Init tick lines with dummy pos
+ self._tickLines = primitives.DashedLines(
+ positions=((0., 0., 0.), (0., 0., 0.)))
+ self._tickLines.dash = 5, 10
+ self._tickLines.visible = False
+ self._children.insert(-1, self._tickLines)
+
+ self._tickLabels = core.Group()
+ self._children.insert(-1, self._tickLabels)
+
+ # Sync color
+ self.tickColor = 1., 1., 1., 1.
+
+ def _updateBoxAndAxes(self):
+ """Update bbox and axes position and size according to children.
+
+ Overridden from GroupBBox
+ """
+ super(LabelledAxes, self)._updateBoxAndAxes()
+
+ bounds = self._group.bounds(dataBounds=True)
+ if bounds is not None:
+ tx, ty, tz = (bounds[1] - bounds[0]) / 2.
+ else:
+ tx, ty, tz = 0.5, 0.5, 0.5
+
+ self._xlabel.transforms[-1].tx = tx
+ self._ylabel.transforms[-1].ty = ty
+ self._zlabel.transforms[-1].tz = tz
+
+ @property
+ def tickColor(self):
+ """Color of ticks and text labels.
+
+ This does NOT set bounding box color.
+ Use :attr:`color` for the bounding box.
+ """
+ return self._xlabel.foreground
+
+ @tickColor.setter
+ def tickColor(self, color):
+ self._xlabel.foreground = color
+ self._ylabel.foreground = color
+ self._zlabel.foreground = color
+ transparentColor = color[0], color[1], color[2], color[3] * 0.6
+ self._tickLines.setAttribute('color', transparentColor)
+ for label in self._tickLabels.children:
+ label.foreground = color
+
+ @property
+ def font(self):
+ """Font of axes text labels (Font)"""
+ return self._font
+
+ @font.setter
+ def font(self, font):
+ self._font = font
+ self._xlabel.font = font
+ self._ylabel.font = font
+ self._zlabel.font = font
+ for label in self._tickLabels.children:
+ label.font = font
+
+ @property
+ def xlabel(self):
+ """Text label of the X axis (str)"""
+ return self._xlabel.text
+
+ @xlabel.setter
+ def xlabel(self, text):
+ self._xlabel.text = text
+
+ @property
+ def ylabel(self):
+ """Text label of the Y axis (str)"""
+ return self._ylabel.text
+
+ @ylabel.setter
+ def ylabel(self, text):
+ self._ylabel.text = text
+
+ @property
+ def zlabel(self):
+ """Text label of the Z axis (str)"""
+ return self._zlabel.text
+
+ @zlabel.setter
+ def zlabel(self, text):
+ self._zlabel.text = text
+
+ @property
+ def boxVisible(self):
+ """Returns bounding box, axes labels and grid visibility."""
+ return self._boxVisibility
+
+ @boxVisible.setter
+ def boxVisible(self, visible):
+ self._boxVisibility = bool(visible)
+ for child in self._children:
+ if child == self._tickLines:
+ if self._ticksForBounds is not None:
+ child.visible = self._boxVisibility
+ elif child != self._group:
+ child.visible = self._boxVisibility
+
+ def _updateTicks(self):
+ """Check if ticks need update and update them if needed."""
+ bounds = self._group.bounds(transformed=False, dataBounds=True)
+ if bounds is None: # No content
+ if self._ticksForBounds is not None:
+ self._ticksForBounds = None
+ self._tickLines.visible = False
+ self._tickLabels.children = [] # Reset previous labels
+
+ elif (self._ticksForBounds is None or
+ not numpy.all(numpy.equal(bounds, self._ticksForBounds))):
+ self._ticksForBounds = bounds
+
+ # Update ticks
+ ticklength = numpy.abs(bounds[1] - bounds[0])
+
+ xticks, xlabels = ticklayout.ticks(*bounds[:, 0])
+ yticks, ylabels = ticklayout.ticks(*bounds[:, 1])
+ zticks, zlabels = ticklayout.ticks(*bounds[:, 2])
+
+ # Update tick lines
+ coords = numpy.empty(
+ ((len(xticks) + len(yticks) + len(zticks)), 4, 3),
+ dtype=numpy.float32)
+ coords[:, :, :] = bounds[0, :] # account for offset from origin
+
+ xcoords = coords[:len(xticks)]
+ xcoords[:, :, 0] = numpy.asarray(xticks)[:, numpy.newaxis]
+ xcoords[:, 1, 1] += ticklength[1] # X ticks on XY plane
+ xcoords[:, 3, 2] += ticklength[2] # X ticks on XZ plane
+
+ ycoords = coords[len(xticks):len(xticks) + len(yticks)]
+ ycoords[:, :, 1] = numpy.asarray(yticks)[:, numpy.newaxis]
+ ycoords[:, 1, 0] += ticklength[0] # Y ticks on XY plane
+ ycoords[:, 3, 2] += ticklength[2] # Y ticks on YZ plane
+
+ zcoords = coords[len(xticks) + len(yticks):]
+ zcoords[:, :, 2] = numpy.asarray(zticks)[:, numpy.newaxis]
+ zcoords[:, 1, 0] += ticklength[0] # Z ticks on XZ plane
+ zcoords[:, 3, 1] += ticklength[1] # Z ticks on YZ plane
+
+ self._tickLines.setPositions(coords.reshape(-1, 3))
+ self._tickLines.visible = self._boxVisibility
+
+ # Update labels
+ color = self.tickColor
+ offsets = bounds[0] - ticklength / 20.
+ labels = []
+ for tick, label in zip(xticks, xlabels):
+ text2d = text.Text2D(text=label, font=self.font)
+ text2d.align = 'center'
+ text2d.foreground = color
+ text2d.transforms = [transform.Translate(
+ tx=tick, ty=offsets[1], tz=offsets[2])]
+ labels.append(text2d)
+
+ for tick, label in zip(yticks, ylabels):
+ text2d = text.Text2D(text=label, font=self.font)
+ text2d.align = 'center'
+ text2d.foreground = color
+ text2d.transforms = [transform.Translate(
+ tx=offsets[0], ty=tick, tz=offsets[2])]
+ labels.append(text2d)
+
+ for tick, label in zip(zticks, zlabels):
+ text2d = text.Text2D(text=label, font=self.font)
+ text2d.align = 'center'
+ text2d.foreground = color
+ text2d.transforms = [transform.Translate(
+ tx=offsets[0], ty=offsets[1], tz=tick)]
+ labels.append(text2d)
+
+ self._tickLabels.children = labels # Reset previous labels
+
+ def prepareGL2(self, context):
+ self._updateTicks()
+ super(LabelledAxes, self).prepareGL2(context)
diff --git a/src/silx/gui/plot3d/scene/camera.py b/src/silx/gui/plot3d/scene/camera.py
new file mode 100644
index 0000000..90de7ed
--- /dev/null
+++ b/src/silx/gui/plot3d/scene/camera.py
@@ -0,0 +1,353 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides classes to handle a perspective projection in 3D."""
+
+from __future__ import absolute_import, division, unicode_literals
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "25/07/2016"
+
+
+import numpy
+
+from . import transform
+
+
+# CameraExtrinsic #############################################################
+
+class CameraExtrinsic(transform.Transform):
+ """Transform matrix to handle camera position and orientation.
+
+ :param position: Coordinates of the point of view.
+ :type position: numpy.ndarray-like of 3 float32.
+ :param direction: Sight direction vector.
+ :type direction: numpy.ndarray-like of 3 float32.
+ :param up: Vector pointing upward in the image plane.
+ :type up: numpy.ndarray-like of 3 float32.
+ """
+
+ def __init__(self, position=(0., 0., 0.),
+ direction=(0., 0., -1.),
+ up=(0., 1., 0.)):
+
+ super(CameraExtrinsic, self).__init__()
+ self._position = None
+ self.position = position # set _position
+ self._side = 1., 0., 0.
+ self._up = 0., 1., 0.
+ self._direction = 0., 0., -1.
+ self.setOrientation(direction=direction, up=up) # set _direction, _up
+
+ def _makeMatrix(self):
+ return transform.mat4LookAtDir(self._position,
+ self._direction, self._up)
+
+ def copy(self):
+ """Return an independent copy"""
+ return CameraExtrinsic(self.position, self.direction, self.up)
+
+ def setOrientation(self, direction=None, up=None):
+ """Set the rotation of the point of view.
+
+ :param direction: Sight direction vector or
+ None to keep the current one.
+ :type direction: numpy.ndarray-like of 3 float32 or None.
+ :param up: Vector pointing upward in the image plane or
+ None to keep the current one.
+ :type up: numpy.ndarray-like of 3 float32 or None.
+ :raises RuntimeError: if the direction and up are parallel.
+ """
+ if direction is None: # Use current direction
+ direction = self.direction
+ else:
+ assert len(direction) == 3
+ direction = numpy.array(direction, copy=True, dtype=numpy.float32)
+ direction /= numpy.linalg.norm(direction)
+
+ if up is None: # Use current up
+ up = self.up
+ else:
+ assert len(up) == 3
+ up = numpy.array(up, copy=True, dtype=numpy.float32)
+
+ # Update side and up to make sure they are perpendicular and normalized
+ side = numpy.cross(direction, up)
+ sidenormal = numpy.linalg.norm(side)
+ if sidenormal == 0.:
+ raise RuntimeError('direction and up vectors are parallel.')
+ # Alternative: when one of the input parameter is None, it is
+ # possible to guess correct vectors using previous direction and up
+ side /= sidenormal
+ up = numpy.cross(side, direction)
+ up /= numpy.linalg.norm(up)
+
+ self._side = side
+ self._up = up
+ self._direction = direction
+ self.notify()
+
+ @property
+ def position(self):
+ """Coordinates of the point of view as a numpy.ndarray of 3 float32."""
+ return self._position.copy()
+
+ @position.setter
+ def position(self, position):
+ assert len(position) == 3
+ self._position = numpy.array(position, copy=True, dtype=numpy.float32)
+ self.notify()
+
+ @property
+ def direction(self):
+ """Sight direction (ndarray of 3 float32)."""
+ return self._direction.copy()
+
+ @direction.setter
+ def direction(self, direction):
+ self.setOrientation(direction=direction)
+
+ @property
+ def up(self):
+ """Vector pointing upward in the image plane (ndarray of 3 float32).
+ """
+ return self._up.copy()
+
+ @up.setter
+ def up(self, up):
+ self.setOrientation(up=up)
+
+ @property
+ def side(self):
+ """Vector pointing towards the side of the image plane.
+
+ ndarray of 3 float32"""
+ return self._side.copy()
+
+ def move(self, direction, step=1.):
+ """Move the camera relative to the image plane.
+
+ :param str direction: Direction relative to image plane.
+ One of: 'up', 'down', 'left', 'right',
+ 'forward', 'backward'.
+ :param float step: The step of the pan to perform in the coordinate
+ in which the camera position is defined.
+ """
+ if direction in ('up', 'down'):
+ vector = self.up * (1. if direction == 'up' else -1.)
+ elif direction in ('left', 'right'):
+ vector = self.side * (1. if direction == 'right' else -1.)
+ elif direction in ('forward', 'backward'):
+ vector = self.direction * (1. if direction == 'forward' else -1.)
+ else:
+ raise ValueError('Unsupported direction: %s' % direction)
+
+ self.position += step * vector
+
+ def rotate(self, direction, angle=1.):
+ """First-person rotation of the camera towards the direction.
+
+ :param str direction: Direction of movement relative to image plane.
+ In: 'up', 'down', 'left', 'right'.
+ :param float angle: The angle in degrees of the rotation.
+ """
+ if direction in ('up', 'down'):
+ axis = self.side * (1. if direction == 'up' else -1.)
+ elif direction in ('left', 'right'):
+ axis = self.up * (1. if direction == 'left' else -1.)
+ else:
+ raise ValueError('Unsupported direction: %s' % direction)
+
+ matrix = transform.mat4RotateFromAngleAxis(numpy.radians(angle), *axis)
+ newdir = numpy.dot(matrix[:3, :3], self.direction)
+
+ if direction in ('up', 'down'):
+ # Rotate up to avoid up and new direction to be (almost) co-linear
+ newup = numpy.dot(matrix[:3, :3], self.up)
+ self.setOrientation(newdir, newup)
+ else:
+ # No need to rotate up here as it is the rotation axis
+ self.direction = newdir
+
+ def orbit(self, direction, center=(0., 0., 0.), angle=1.):
+ """Rotate the camera around a point.
+
+ :param str direction: Direction of movement relative to image plane.
+ In: 'up', 'down', 'left', 'right'.
+ :param center: Position around which to rotate the point of view.
+ :type center: numpy.ndarray-like of 3 float32.
+ :param float angle: he angle in degrees of the rotation.
+ """
+ if direction in ('up', 'down'):
+ axis = self.side * (1. if direction == 'down' else -1.)
+ elif direction in ('left', 'right'):
+ axis = self.up * (1. if direction == 'right' else -1.)
+ else:
+ raise ValueError('Unsupported direction: %s' % direction)
+
+ # Rotate viewing direction
+ rotmatrix = transform.mat4RotateFromAngleAxis(
+ numpy.radians(angle), *axis)
+ self.direction = numpy.dot(rotmatrix[:3, :3], self.direction)
+
+ # Rotate position around center
+ center = numpy.array(center, copy=False, dtype=numpy.float32)
+ matrix = numpy.dot(transform.mat4Translate(*center), rotmatrix)
+ matrix = numpy.dot(matrix, transform.mat4Translate(*(-center)))
+ position = numpy.append(self.position, 1.)
+ self.position = numpy.dot(matrix, position)[:3]
+
+ _RESET_CAMERA_ORIENTATIONS = {
+ 'side': ((-1., -1., -1.), (0., 1., 0.)),
+ 'front': ((0., 0., -1.), (0., 1., 0.)),
+ 'back': ((0., 0., 1.), (0., 1., 0.)),
+ 'top': ((0., -1., 0.), (0., 0., -1.)),
+ 'bottom': ((0., 1., 0.), (0., 0., 1.)),
+ 'right': ((-1., 0., 0.), (0., 1., 0.)),
+ 'left': ((1., 0., 0.), (0., 1., 0.))
+ }
+
+ def reset(self, face=None):
+ """Reset the camera position to pre-defined orientations.
+
+ :param str face: The direction of the camera in:
+ side, front, back, top, bottom, right, left.
+ """
+ if face not in self._RESET_CAMERA_ORIENTATIONS:
+ raise ValueError('Unsupported face: %s' % face)
+
+ distance = numpy.linalg.norm(self.position)
+ direction, up = self._RESET_CAMERA_ORIENTATIONS[face]
+ self.setOrientation(direction, up)
+ self.position = - self.direction * distance
+
+
+class Camera(transform.Transform):
+ """Combination of camera projection and position.
+
+ See :class:`Perspective` and :class:`CameraExtrinsic`.
+
+ :param float fovy: Vertical field-of-view in degrees.
+ :param float near: The near clipping plane Z coord (strictly positive).
+ :param float far: The far clipping plane Z coord (> near).
+ :param size:
+ Viewport's size used to compute the aspect ratio (width, height).
+ :type size: 2-tuple of float
+ :param position: Coordinates of the point of view.
+ :type position: numpy.ndarray-like of 3 float32.
+ :param direction: Sight direction vector.
+ :type direction: numpy.ndarray-like of 3 float32.
+ :param up: Vector pointing upward in the image plane.
+ :type up: numpy.ndarray-like of 3 float32.
+ """
+
+ def __init__(self, fovy=30., near=0.1, far=1., size=(1., 1.),
+ position=(0., 0., 0.),
+ direction=(0., 0., -1.), up=(0., 1., 0.)):
+ super(Camera, self).__init__()
+ self._intrinsic = transform.Perspective(fovy, near, far, size)
+ self._intrinsic.addListener(self._transformChanged)
+ self._extrinsic = CameraExtrinsic(position, direction, up)
+ self._extrinsic.addListener(self._transformChanged)
+
+ def _makeMatrix(self):
+ return numpy.dot(self.intrinsic.matrix, self.extrinsic.matrix)
+
+ def _transformChanged(self, source):
+ """Listener of intrinsic and extrinsic camera parameters instances."""
+ if source is not self:
+ self.notify()
+
+ def resetCamera(self, bounds):
+ """Change camera to have the bounds in the viewing frustum.
+
+ It updates the camera position and depth extent.
+ Camera sight direction and up are not affected.
+
+ :param bounds: The axes-aligned bounds to include.
+ :type bounds: numpy.ndarray: ((xMin, yMin, zMin), (xMax, yMax, zMax))
+ """
+
+ center = 0.5 * (bounds[0] + bounds[1])
+ radius = numpy.linalg.norm(0.5 * (bounds[1] - bounds[0]))
+ if radius == 0.: # bounds are all collapsed
+ radius = 1.
+
+ if isinstance(self.intrinsic, transform.Perspective):
+ # Get the viewpoint distance from the bounds center
+ minfov = numpy.radians(self.intrinsic.fovy)
+ width, height = self.intrinsic.size
+ if width < height:
+ minfov *= width / height
+
+ offset = radius / numpy.sin(0.5 * minfov)
+
+ # Update camera
+ self.extrinsic.position = \
+ center - offset * self.extrinsic.direction
+ self.intrinsic.setDepthExtent(offset - radius, offset + radius)
+
+ elif isinstance(self.intrinsic, transform.Orthographic):
+ # Y goes up
+ self.intrinsic.setClipping(
+ left=center[0] - radius,
+ right=center[0] + radius,
+ bottom=center[1] - radius,
+ top=center[1] + radius)
+
+ # Update camera
+ self.extrinsic.position = 0, 0, 0
+ self.intrinsic.setDepthExtent(center[2] - radius,
+ center[2] + radius)
+ else:
+ raise RuntimeError('Unsupported camera: %s' % self.intrinsic)
+
+ @property
+ def intrinsic(self):
+ """Intrinsic camera parameters, i.e., projection matrix."""
+ return self._intrinsic
+
+ @intrinsic.setter
+ def intrinsic(self, intrinsic):
+ self._intrinsic.removeListener(self._transformChanged)
+ self._intrinsic = intrinsic
+ self._intrinsic.addListener(self._transformChanged)
+
+ @property
+ def extrinsic(self):
+ """Extrinsic camera parameters, i.e., position and orientation."""
+ return self._extrinsic
+
+ def move(self, *args, **kwargs):
+ """See :meth:`CameraExtrinsic.move`."""
+ self.extrinsic.move(*args, **kwargs)
+
+ def rotate(self, *args, **kwargs):
+ """See :meth:`CameraExtrinsic.rotate`."""
+ self.extrinsic.rotate(*args, **kwargs)
+
+ def orbit(self, *args, **kwargs):
+ """See :meth:`CameraExtrinsic.orbit`."""
+ self.extrinsic.orbit(*args, **kwargs)
diff --git a/src/silx/gui/plot3d/scene/core.py b/src/silx/gui/plot3d/scene/core.py
new file mode 100644
index 0000000..43838fe
--- /dev/null
+++ b/src/silx/gui/plot3d/scene/core.py
@@ -0,0 +1,343 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides the base scene structure.
+
+This module provides the classes for describing a tree structure with
+rendering and picking API.
+All nodes inherit from :class:`Base`.
+Nodes with children are provided with :class:`PrivateGroup` and
+:class:`Group` classes.
+Leaf rendering nodes should inherit from :class:`Elem`.
+"""
+
+from __future__ import absolute_import, division, unicode_literals
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "25/07/2016"
+
+
+import itertools
+import weakref
+
+import numpy
+
+from . import event
+from . import transform
+
+from .viewport import Viewport
+
+
+# Nodes #######################################################################
+
+class Base(event.Notifier):
+ """A scene node with common features."""
+
+ def __init__(self):
+ super(Base, self).__init__()
+ self._visible = True
+ self._pickable = False
+
+ self._parentRef = None
+
+ self._transforms = transform.TransformList()
+ self._transforms.addListener(self._transformChanged)
+
+ # notifying properties
+
+ visible = event.notifyProperty('_visible',
+ doc="Visibility flag of the node")
+ pickable = event.notifyProperty('_pickable',
+ doc="True to make node pickable")
+
+ # Access to tree path
+
+ @property
+ def parent(self):
+ """Parent or None if no parent"""
+ return None if self._parentRef is None else self._parentRef()
+
+ def _setParent(self, parent):
+ """Set the parent of this node.
+
+ For internal use.
+
+ :param Base parent: The parent.
+ """
+ if parent is not None and self._parentRef is not None:
+ raise RuntimeError('Trying to add a node at two places.')
+ # Alternative: remove it from previous children list
+ self._parentRef = None if parent is None else weakref.ref(parent)
+
+ @property
+ def path(self):
+ """Tuple of scene nodes, from the tip of the tree down to this node.
+
+ If this tree is attached to a :class:`Viewport`,
+ then the :class:`Viewport` is the first element of path.
+ """
+ if self.parent is None:
+ return self,
+ elif isinstance(self.parent, Viewport):
+ return self.parent, self
+ else:
+ return self.parent.path + (self, )
+
+ @property
+ def viewport(self):
+ """The viewport this node is attached to or None."""
+ root = self.path[0]
+ return root if isinstance(root, Viewport) else None
+
+ @property
+ def root(self):
+ """The root node of the scene.
+
+ If attached to a :class:`Viewport`, this is the item right under it
+ """
+ path = self.path
+ return path[1] if isinstance(path[0], Viewport) else path[0]
+
+ @property
+ def objectToNDCTransform(self):
+ """Transform from object to normalized device coordinates.
+
+ Do not forget perspective divide.
+ """
+ # Using the Viewport's transforms property to proxy the camera
+ path = self.path
+ assert isinstance(path[0], Viewport)
+ return transform.StaticTransformList(elem.transforms for elem in path)
+
+ @property
+ def objectToSceneTransform(self):
+ """Transform from object to scene.
+
+ Combine transforms up to the Viewport (not including it).
+ """
+ path = self.path
+ if isinstance(path[0], Viewport):
+ path = path[1:] # Remove viewport to remove camera transforms
+ return transform.StaticTransformList(elem.transforms for elem in path)
+
+ # transform
+
+ @property
+ def transforms(self):
+ """List of transforms defining the frame of this node relative
+ to its parent."""
+ return self._transforms
+
+ @transforms.setter
+ def transforms(self, iterable):
+ self._transforms.removeListener(self._transformChanged)
+ if isinstance(iterable, transform.TransformList):
+ # If it is a TransformList, do not create one to enable sharing.
+ self._transforms = iterable
+ else:
+ assert hasattr(iterable, '__iter__')
+ self._transforms = transform.TransformList(iterable)
+ self._transforms.addListener(self._transformChanged)
+
+ def _transformChanged(self, source):
+ self.notify() # Broadcast transform notification
+
+ # Bounds
+
+ _CUBE_CORNERS = numpy.array(list(itertools.product((0., 1.), repeat=3)),
+ dtype=numpy.float32)
+ """Unit cube corners used to transform bounds"""
+
+ def _bounds(self, dataBounds=False):
+ """Override in subclass to provide bounds in object coordinates"""
+ return None
+
+ def bounds(self, transformed=False, dataBounds=False):
+ """Returns the bounds of this node aligned with the axis,
+ with or without transform applied.
+
+ :param bool transformed: False to give bounds in object coordinates
+ (the default), True to apply this object's
+ transforms.
+ :param bool dataBounds: False to give bounds of vertices (the default),
+ True to give bounds of the represented data.
+ :return: The bounds: ((xMin, yMin, zMin), (xMax, yMax, zMax)) or None
+ if no bounds.
+ :rtype: numpy.ndarray of float
+ """
+ bounds = self._bounds(dataBounds)
+
+ if transformed and bounds is not None:
+ bounds = self.transforms.transformBounds(bounds)
+
+ return bounds
+
+ # Rendering
+
+ def prepareGL2(self, ctx):
+ """Called before the rendering to prepare OpenGL resources.
+
+ Override in subclass.
+ """
+ pass
+
+ def renderGL2(self, ctx):
+ """Called to perform the OpenGL rendering.
+
+ Override in subclass.
+ """
+ pass
+
+ def render(self, ctx):
+ """Called internally to perform rendering."""
+ if self.visible:
+ ctx.pushTransform(self.transforms)
+ self.prepareGL2(ctx)
+ self.renderGL2(ctx)
+ ctx.popTransform()
+
+ def postRender(self, ctx):
+ """Hook called when parent's node render is finished.
+
+ Called in the reverse of rendering order (i.e., last child first).
+
+ Meant for nodes that modify the :class:`RenderContext` ctx to
+ reset their modifications.
+ """
+ pass
+
+ def pick(self, ctx, x, y, depth=None):
+ """True/False picking, should be fast"""
+ if self.pickable:
+ pass
+
+ def pickRay(self, ctx, ray):
+ """Picking returning list of ray intersections."""
+ if self.pickable:
+ pass
+
+
+class Elem(Base):
+ """A scene node that does some rendering."""
+
+ def __init__(self):
+ super(Elem, self).__init__()
+ # self.showBBox = False # Here or outside scene graph?
+ # self.clipPlane = None # This needs to be handled in the shader
+
+
+class PrivateGroup(Base):
+ """A scene node that renders its (private) childern.
+
+ :param iterable children: :class:`Base` nodes to add as children
+ """
+
+ class ChildrenList(event.NotifierList):
+ """List of children with notification and children's parent update."""
+
+ def _listWillChangeHook(self, methodName, *args, **kwargs):
+ super(PrivateGroup.ChildrenList, self)._listWillChangeHook(
+ methodName, *args, **kwargs)
+ for item in self:
+ item._setParent(None)
+
+ def _listWasChangedHook(self, methodName, *args, **kwargs):
+ for item in self:
+ item._setParent(self._parentRef())
+ super(PrivateGroup.ChildrenList, self)._listWasChangedHook(
+ methodName, *args, **kwargs)
+
+ def __init__(self, parent, children):
+ self._parentRef = weakref.ref(parent)
+ super(PrivateGroup.ChildrenList, self).__init__(children)
+
+ def __init__(self, children=()):
+ super(PrivateGroup, self).__init__()
+ self.__children = PrivateGroup.ChildrenList(self, children)
+ self.__children.addListener(self._updated)
+
+ @property
+ def _children(self):
+ """List of children to be rendered.
+
+ This private attribute is meant to be used by subclass.
+ """
+ return self.__children
+
+ @_children.setter
+ def _children(self, iterable):
+ self.__children.removeListener(self._updated)
+ for item in self.__children:
+ item._setParent(None)
+ del self.__children # This is needed
+ self.__children = PrivateGroup.ChildrenList(self, iterable)
+ self.__children.addListener(self._updated)
+ self.notify()
+
+ def _updated(self, source, *args, **kwargs):
+ """Listen for updates"""
+ if source is not self: # Avoid infinite recursion
+ self.notify(*args, **kwargs)
+
+ def _bounds(self, dataBounds=False):
+ """Compute the bounds from transformed children bounds"""
+ bounds = []
+ for child in self._children:
+ if child.visible:
+ childBounds = child.bounds(
+ transformed=True, dataBounds=dataBounds)
+ if childBounds is not None:
+ bounds.append(childBounds)
+
+ if len(bounds) == 0:
+ return None
+ else:
+ bounds = numpy.array(bounds, dtype=numpy.float32)
+ return numpy.array((bounds[:, 0, :].min(axis=0),
+ bounds[:, 1, :].max(axis=0)),
+ dtype=numpy.float32)
+
+ def prepareGL2(self, ctx):
+ pass
+
+ def renderGL2(self, ctx):
+ """Render all children"""
+ for child in self._children:
+ child.render(ctx)
+ for child in reversed(self._children):
+ child.postRender(ctx)
+
+
+class Group(PrivateGroup):
+ """A scene node that renders its (public) children."""
+
+ @property
+ def children(self):
+ """List of children to be rendered."""
+ return self._children
+
+ @children.setter
+ def children(self, iterable):
+ self._children = iterable
diff --git a/src/silx/gui/plot3d/scene/cutplane.py b/src/silx/gui/plot3d/scene/cutplane.py
new file mode 100644
index 0000000..88147df
--- /dev/null
+++ b/src/silx/gui/plot3d/scene/cutplane.py
@@ -0,0 +1,390 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""A cut plane in a 3D texture: hackish implementation...
+"""
+
+from __future__ import absolute_import, division, unicode_literals
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "11/01/2018"
+
+import string
+import numpy
+
+from ... import _glutils
+from ..._glutils import gl
+
+from .function import Colormap
+from .primitives import Box, Geometry, PlaneInGroup
+from . import transform, utils
+
+
+class ColormapMesh3D(Geometry):
+ """A 3D mesh with color from a 3D texture."""
+
+ _shaders = ("""
+ attribute vec3 position;
+ attribute vec3 normal;
+
+ uniform mat4 matrix;
+ uniform mat4 transformMat;
+ //uniform mat3 matrixInvTranspose;
+ uniform vec3 dataScale;
+ uniform vec3 texCoordsOffset;
+
+ varying vec4 vCameraPosition;
+ varying vec3 vPosition;
+ varying vec3 vNormal;
+ varying vec3 vTexCoords;
+
+ void main(void)
+ {
+ vCameraPosition = transformMat * vec4(position, 1.0);
+ //vNormal = matrixInvTranspose * normalize(normal);
+ vPosition = position;
+ vTexCoords = dataScale * position + texCoordsOffset;
+ vNormal = normal;
+ gl_Position = matrix * vec4(position, 1.0);
+ }
+ """,
+ string.Template("""
+ varying vec4 vCameraPosition;
+ varying vec3 vPosition;
+ varying vec3 vNormal;
+ varying vec3 vTexCoords;
+ uniform sampler3D data;
+ uniform float alpha;
+
+ $colormapDecl
+ $sceneDecl
+ $lightingFunction
+
+ void main(void)
+ {
+ $scenePreCall(vCameraPosition);
+
+ float value = texture3D(data, vTexCoords).r;
+ vec4 color = $colormapCall(value);
+ color.a *= alpha;
+
+ gl_FragColor = $lightingCall(color, vPosition, vNormal);
+
+ $scenePostCall(vCameraPosition);
+ }
+ """))
+
+ def __init__(self, position, normal, data, copy=True,
+ mode='triangles', indices=None, colormap=None):
+ assert mode in self._TRIANGLE_MODES
+ data = numpy.array(data, copy=copy, order='C')
+ assert data.ndim == 3
+ self._data = data
+ self._texture = None
+ self._update_texture = True
+ self._update_texture_filter = False
+ self._alpha = 1.
+ self._colormap = colormap or Colormap() # Default colormap
+ self._colormap.addListener(self._cmapChanged)
+ self._interpolation = 'linear'
+ super(ColormapMesh3D, self).__init__(mode,
+ indices,
+ position=position,
+ normal=normal)
+
+ self.isBackfaceVisible = True
+ self.textureOffset = 0., 0., 0.
+ """Offset to add to texture coordinates"""
+
+ def setData(self, data, copy=True):
+ data = numpy.array(data, copy=copy, order='C')
+ assert data.ndim == 3
+ self._data = data
+ self._update_texture = True
+
+ def getData(self, copy=True):
+ return numpy.array(self._data, copy=copy)
+
+ @property
+ def interpolation(self):
+ """The texture interpolation mode: 'linear' or 'nearest'"""
+ return self._interpolation
+
+ @interpolation.setter
+ def interpolation(self, interpolation):
+ assert interpolation in ('linear', 'nearest')
+ self._interpolation = interpolation
+ self._update_texture_filter = True
+ self.notify()
+
+ @property
+ def alpha(self):
+ """Transparency of the plane, float in [0, 1]"""
+ return self._alpha
+
+ @alpha.setter
+ def alpha(self, alpha):
+ self._alpha = float(alpha)
+
+ @property
+ def colormap(self):
+ """The colormap used by this primitive"""
+ return self._colormap
+
+ def _cmapChanged(self, source, *args, **kwargs):
+ """Broadcast colormap changes"""
+ self.notify(*args, **kwargs)
+
+ def prepareGL2(self, ctx):
+ if self._texture is None or self._update_texture:
+ if self._texture is not None:
+ self._texture.discard()
+
+ if self.interpolation == 'nearest':
+ filter_ = gl.GL_NEAREST
+ else:
+ filter_ = gl.GL_LINEAR
+ self._update_texture = False
+ self._update_texture_filter = False
+ self._texture = _glutils.Texture(
+ gl.GL_R32F, self._data, gl.GL_RED,
+ minFilter=filter_,
+ magFilter=filter_,
+ wrap=gl.GL_CLAMP_TO_EDGE)
+
+ if self._update_texture_filter:
+ self._update_texture_filter = False
+ if self.interpolation == 'nearest':
+ filter_ = gl.GL_NEAREST
+ else:
+ filter_ = gl.GL_LINEAR
+ self._texture.minFilter = filter_
+ self._texture.magFilter = filter_
+
+ super(ColormapMesh3D, self).prepareGL2(ctx)
+
+ def renderGL2(self, ctx):
+ fragment = self._shaders[1].substitute(
+ sceneDecl=ctx.fragDecl,
+ scenePreCall=ctx.fragCallPre,
+ scenePostCall=ctx.fragCallPost,
+ lightingFunction=ctx.viewport.light.fragmentDef,
+ lightingCall=ctx.viewport.light.fragmentCall,
+ colormapDecl=self.colormap.decl,
+ colormapCall=self.colormap.call
+ )
+ program = ctx.glCtx.prog(self._shaders[0], fragment)
+ program.use()
+
+ ctx.viewport.light.setupProgram(ctx, program)
+ self.colormap.setupProgram(ctx, program)
+
+ if not self.isBackfaceVisible:
+ gl.glCullFace(gl.GL_BACK)
+ gl.glEnable(gl.GL_CULL_FACE)
+
+ program.setUniformMatrix('matrix', ctx.objectToNDC.matrix)
+ program.setUniformMatrix('transformMat',
+ ctx.objectToCamera.matrix,
+ safe=True)
+ gl.glUniform1f(program.uniforms['alpha'], self._alpha)
+
+ shape = self._data.shape
+ scales = 1./shape[2], 1./shape[1], 1./shape[0]
+ gl.glUniform3f(program.uniforms['dataScale'], *scales)
+ gl.glUniform3f(program.uniforms['texCoordsOffset'], *self.textureOffset)
+
+ gl.glUniform1i(program.uniforms['data'], self._texture.texUnit)
+
+ ctx.setupProgram(program)
+
+ self._texture.bind()
+ self._draw(program)
+
+ if not self.isBackfaceVisible:
+ gl.glDisable(gl.GL_CULL_FACE)
+
+
+class CutPlane(PlaneInGroup):
+ """A cutting plane in a 3D texture"""
+
+ def __init__(self, point=(0., 0., 0.), normal=(0., 0., 1.)):
+ self._data = None
+ self._mesh = None
+ self._alpha = 1.
+ self._interpolation = 'linear'
+ self._colormap = Colormap()
+ super(CutPlane, self).__init__(point, normal)
+
+ def setData(self, data, copy=True):
+ if data is None:
+ self._data = None
+ if self._mesh is not None:
+ self._children.remove(self._mesh)
+ self._mesh = None
+
+ else:
+ data = numpy.array(data, copy=copy, order='C')
+ assert data.ndim == 3
+ self._data = data
+ if self._mesh is not None:
+ self._mesh.setData(data, copy=False)
+
+ def getData(self, copy=True):
+ return None if self._mesh is None else self._mesh.getData(copy=copy)
+
+ @property
+ def alpha(self):
+ return self._alpha
+
+ @alpha.setter
+ def alpha(self, alpha):
+ self._alpha = float(alpha)
+ if self._mesh is not None:
+ self._mesh.alpha = alpha
+
+ @property
+ def colormap(self):
+ return self._colormap
+
+ @property
+ def interpolation(self):
+ """The texture interpolation mode: 'linear' (default) or 'nearest'"""
+ return self._interpolation
+
+ @interpolation.setter
+ def interpolation(self, interpolation):
+ assert interpolation in ('nearest', 'linear')
+ if interpolation != self.interpolation:
+ self._interpolation = interpolation
+ if self._mesh is not None:
+ self._mesh.interpolation = interpolation
+ self.notify()
+
+ def prepareGL2(self, ctx):
+ if self.isValid:
+
+ contourVertices = self.contourVertices
+
+ if self._mesh is None and self._data is not None:
+ self._mesh = ColormapMesh3D(contourVertices,
+ normal=self.plane.normal,
+ data=self._data,
+ copy=False,
+ mode='fan',
+ colormap=self.colormap)
+ self._mesh.alpha = self._alpha
+ self._mesh.interpolation = self.interpolation
+ self._children.insert(0, self._mesh)
+
+ if self._mesh is not None:
+ if (contourVertices is None or
+ len(contourVertices) == 0):
+ self._mesh.visible = False
+ else:
+ self._mesh.visible = True
+ self._mesh.setAttribute('normal', self.plane.normal)
+ self._mesh.setAttribute('position', contourVertices)
+
+ needTextureOffset = False
+ if self.interpolation == 'nearest':
+ # If cut plane is co-linear with array bin edges add texture offset
+ planePt = self.plane.point
+ for index, normal in enumerate(((1., 0., 0.),
+ (0., 1., 0.),
+ (0., 0., 1.))):
+ if (numpy.all(numpy.equal(self.plane.normal, normal)) and
+ int(planePt[index]) == planePt[index]):
+ needTextureOffset = True
+ break
+
+ if needTextureOffset:
+ self._mesh.textureOffset = self.plane.normal * 1e-6
+ else:
+ self._mesh.textureOffset = 0., 0., 0.
+
+ super(CutPlane, self).prepareGL2(ctx)
+
+ def renderGL2(self, ctx):
+ with self.viewport.light.turnOff():
+ super(CutPlane, self).renderGL2(ctx)
+
+ def _bounds(self, dataBounds=False):
+ if not dataBounds:
+ vertices = self.contourVertices
+ if vertices is not None:
+ return numpy.array(
+ (vertices.min(axis=0), vertices.max(axis=0)),
+ dtype=numpy.float32)
+ else:
+ return None # Plane in not slicing the data volume
+ else:
+ if self._data is None:
+ return None
+ else:
+ depth, height, width = self._data.shape
+ return numpy.array(((0., 0., 0.),
+ (width, height, depth)),
+ dtype=numpy.float32)
+
+ @property
+ def contourVertices(self):
+ """The vertices of the contour of the plane/bounds intersection."""
+ # TODO copy from PlaneInGroup, refactor all that!
+ bounds = self.bounds(dataBounds=True)
+ if bounds is None:
+ return None # No bounds: no vertices
+
+ # Check if cache is valid and return it
+ cachebounds, cachevertices = self._cache
+ if numpy.all(numpy.equal(bounds, cachebounds)):
+ return cachevertices
+
+ # Cache is not OK, rebuild it
+ boxVertices = Box.getVertices(copy=True)
+ boxVertices = bounds[0] + boxVertices * (bounds[1] - bounds[0])
+ lineIndices = Box.getLineIndices(copy=False)
+ vertices = utils.boxPlaneIntersect(
+ boxVertices, lineIndices, self.plane.normal, self.plane.point)
+
+ self._cache = bounds, vertices if len(vertices) != 0 else None
+
+ return self._cache[1]
+
+ # Render transforms RW, TODO refactor this!
+ @property
+ def transforms(self):
+ return self._transforms
+
+ @transforms.setter
+ def transforms(self, iterable):
+ self._transforms.removeListener(self._transformChanged)
+ if isinstance(iterable, transform.TransformList):
+ # If it is a TransformList, do not create one to enable sharing.
+ self._transforms = iterable
+ else:
+ assert hasattr(iterable, '__iter__')
+ self._transforms = transform.TransformList(iterable)
+ self._transforms.addListener(self._transformChanged)
diff --git a/src/silx/gui/plot3d/scene/event.py b/src/silx/gui/plot3d/scene/event.py
new file mode 100644
index 0000000..98f8f8b
--- /dev/null
+++ b/src/silx/gui/plot3d/scene/event.py
@@ -0,0 +1,225 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a simple generic notification system."""
+
+from __future__ import absolute_import, division, unicode_literals
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "17/07/2018"
+
+
+import logging
+
+from silx.utils.weakref import WeakList
+
+_logger = logging.getLogger(__name__)
+
+
+# Notifier ####################################################################
+
+class Notifier(object):
+ """Base class for object with notification mechanism."""
+
+ def __init__(self):
+ self._listeners = WeakList()
+
+ def addListener(self, listener):
+ """Register a listener.
+
+ Adding an already registered listener has no effect.
+
+ :param callable listener: The function or method to register.
+ """
+ if listener not in self._listeners:
+ self._listeners.append(listener)
+ else:
+ _logger.warning('Ignoring addition of an already registered listener')
+
+ def removeListener(self, listener):
+ """Remove a previously registered listener.
+
+ :param callable listener: The function or method to unregister.
+ """
+ try:
+ self._listeners.remove(listener)
+ except ValueError:
+ _logger.warning('Trying to remove a listener that is not registered')
+
+ def notify(self, *args, **kwargs):
+ """Notify all registered listeners with the given parameters.
+
+ Listeners are called directly in this method.
+ Listeners are called in the order they were registered.
+ """
+ for listener in self._listeners:
+ listener(self, *args, **kwargs)
+
+
+def notifyProperty(attrName, copy=False, converter=None, doc=None):
+ """Create a property that adds notification to an attribute.
+
+ :param str attrName: The name of the attribute to wrap.
+ :param bool copy: Whether to return a copy of the attribute
+ or not (the default).
+ :param converter: Function converting input value to appropriate type
+ This function takes a single argument and return the
+ converted value.
+ It can be used to perform some asserts.
+ :param str doc: The docstring of the property
+ :return: A property with getter and setter
+ """
+ if copy:
+ def getter(self):
+ return getattr(self, attrName).copy()
+ else:
+ def getter(self):
+ return getattr(self, attrName)
+
+ if converter is None:
+ def setter(self, value):
+ if getattr(self, attrName) != value:
+ setattr(self, attrName, value)
+ self.notify()
+
+ else:
+ def setter(self, value):
+ value = converter(value)
+ if getattr(self, attrName) != value:
+ setattr(self, attrName, value)
+ self.notify()
+
+ return property(getter, setter, doc=doc)
+
+
+class HookList(list):
+ """List with hooks before and after modification."""
+
+ def __init__(self, iterable):
+ super(HookList, self).__init__(iterable)
+
+ self._listWasChangedHook('__init__', iterable)
+
+ def _listWillChangeHook(self, methodName, *args, **kwargs):
+ """To override. Called before modifying the list.
+
+ This method is called with the name of the method called to
+ modify the list and its parameters.
+ """
+ pass
+
+ def _listWasChangedHook(self, methodName, *args, **kwargs):
+ """To override. Called after modifying the list.
+
+ This method is called with the name of the method called to
+ modify the list and its parameters.
+ """
+ pass
+
+ # Wrapping methods that modify the list
+
+ def _wrapper(self, methodName, *args, **kwargs):
+ """Generic wrapper of list methods calling the hooks."""
+ self._listWillChangeHook(methodName, *args, **kwargs)
+ result = getattr(super(HookList, self),
+ methodName)(*args, **kwargs)
+ self._listWasChangedHook(methodName, *args, **kwargs)
+ return result
+
+ # Add methods
+
+ def __iadd__(self, *args, **kwargs):
+ return self._wrapper('__iadd__', *args, **kwargs)
+
+ def __imul__(self, *args, **kwargs):
+ return self._wrapper('__imul__', *args, **kwargs)
+
+ def append(self, *args, **kwargs):
+ return self._wrapper('append', *args, **kwargs)
+
+ def extend(self, *args, **kwargs):
+ return self._wrapper('extend', *args, **kwargs)
+
+ def insert(self, *args, **kwargs):
+ return self._wrapper('insert', *args, **kwargs)
+
+ # Remove methods
+
+ def __delitem__(self, *args, **kwargs):
+ return self._wrapper('__delitem__', *args, **kwargs)
+
+ def __delslice__(self, *args, **kwargs):
+ return self._wrapper('__delslice__', *args, **kwargs)
+
+ def remove(self, *args, **kwargs):
+ return self._wrapper('remove', *args, **kwargs)
+
+ def pop(self, *args, **kwargs):
+ return self._wrapper('pop', *args, **kwargs)
+
+ # Set methods
+
+ def __setitem__(self, *args, **kwargs):
+ return self._wrapper('__setitem__', *args, **kwargs)
+
+ def __setslice__(self, *args, **kwargs):
+ return self._wrapper('__setslice__', *args, **kwargs)
+
+ # In place methods
+
+ def sort(self, *args, **kwargs):
+ return self._wrapper('sort', *args, **kwargs)
+
+ def reverse(self, *args, **kwargs):
+ return self._wrapper('reverse', *args, **kwargs)
+
+
+class NotifierList(HookList, Notifier):
+ """List of Notifiers with notification mechanism.
+
+ This class registers itself as a listener of the list items.
+
+ The default listener method forward notification from list items
+ to the listeners of the list.
+ """
+
+ def __init__(self, iterable=()):
+ Notifier.__init__(self)
+ HookList.__init__(self, iterable)
+
+ def _listWillChangeHook(self, methodName, *args, **kwargs):
+ for item in self:
+ item.removeListener(self._notified)
+
+ def _listWasChangedHook(self, methodName, *args, **kwargs):
+ for item in self:
+ item.addListener(self._notified)
+ self.notify()
+
+ def _notified(self, source, *args, **kwargs):
+ """Default listener forwarding list item changes to its listeners."""
+ # Avoid infinite recursion if the list is listening itself
+ if source is not self:
+ self.notify(*args, **kwargs)
diff --git a/src/silx/gui/plot3d/scene/function.py b/src/silx/gui/plot3d/scene/function.py
new file mode 100644
index 0000000..2deb785
--- /dev/null
+++ b/src/silx/gui/plot3d/scene/function.py
@@ -0,0 +1,654 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides functions to add to shaders."""
+
+from __future__ import absolute_import, division, unicode_literals
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "17/07/2018"
+
+
+import contextlib
+import logging
+import string
+import numpy
+
+from ... import _glutils
+from ..._glutils import gl
+
+from . import event
+from . import utils
+
+
+_logger = logging.getLogger(__name__)
+
+
+class ProgramFunction(object):
+ """Class providing a function to add to a GLProgram shaders.
+ """
+
+ def setupProgram(self, context, program):
+ """Sets-up uniforms of a program using this shader function.
+
+ :param RenderContext context: The current rendering context
+ :param GLProgram program: The program to set-up.
+ It MUST be in use and using this function.
+ """
+ pass
+
+
+class Fog(event.Notifier, ProgramFunction):
+ """Linear fog over the whole scene content.
+
+ The background of the viewport is used as fog color,
+ otherwise it defaults to white.
+ """
+ # TODO: add more controls (set fog range), add more fog modes
+
+ _fragDecl = """
+ /* (1/(far - near) or 0, near) z in [0 (camera), -inf[ */
+ uniform vec2 fogExtentInfo;
+
+ /* Color to use as fog color */
+ uniform vec3 fogColor;
+
+ vec4 fog(vec4 color, vec4 cameraPosition) {
+ /* d = (pos - near) / (far - near) */
+ float distance = fogExtentInfo.x * (cameraPosition.z/cameraPosition.w - fogExtentInfo.y);
+ float fogFactor = clamp(distance, 0.0, 1.0);
+ vec3 rgb = mix(color.rgb, fogColor, fogFactor);
+ return vec4(rgb.r, rgb.g, rgb.b, color.a);
+ }
+ """
+
+ _fragDeclNoop = """
+ vec4 fog(vec4 color, vec4 cameraPosition) {
+ return color;
+ }
+ """
+
+ def __init__(self):
+ super(Fog, self).__init__()
+ self._isOn = True
+
+ @property
+ def isOn(self):
+ """True to enable fog, False to disable (bool)"""
+ return self._isOn
+
+ @isOn.setter
+ def isOn(self, isOn):
+ isOn = bool(isOn)
+ if self._isOn != isOn:
+ self._isOn = bool(isOn)
+ self.notify()
+
+ @property
+ def fragDecl(self):
+ return self._fragDecl if self.isOn else self._fragDeclNoop
+
+ @property
+ def fragCall(self):
+ return "fog"
+
+ @staticmethod
+ def _zExtentCamera(viewport):
+ """Return (far, near) planes Z in camera coordinates.
+
+ :param Viewport viewport:
+ :return: (far, near) position in camera coords (from 0 to -inf)
+ """
+ # Provide scene z extent in camera coords
+ bounds = viewport.camera.extrinsic.transformBounds(
+ viewport.scene.bounds(transformed=True, dataBounds=True))
+ return bounds[:, 2]
+
+ def setupProgram(self, context, program):
+ if not self.isOn:
+ return
+
+ far, near = context.cache(key='zExtentCamera',
+ factory=self._zExtentCamera,
+ viewport=context.viewport)
+ extent = far - near
+ gl.glUniform2f(program.uniforms['fogExtentInfo'],
+ 0.9/extent if extent != 0. else 0.,
+ near)
+
+ # Use background color as fog color
+ bgColor = context.viewport.background
+ if bgColor is None:
+ bgColor = 1., 1., 1.
+ gl.glUniform3f(program.uniforms['fogColor'], *bgColor[:3])
+
+
+class ClippingPlane(ProgramFunction):
+ """Description of a clipping plane and rendering.
+
+ Convention: Clipping is performed in camera/eye space.
+
+ :param point: Local coordinates of a point on the plane.
+ :type point: numpy.ndarray-like of 3 float32
+ :param normal: Local coordinates of the plane normal.
+ :type normal: numpy.ndarray-like of 3 float32
+ """
+
+ _fragDecl = """
+ /* Clipping plane */
+ /* as rx + gy + bz + a > 0, clipping all positive */
+ uniform vec4 planeEq;
+
+ /* Position is in camera/eye coordinates */
+
+ bool isClipped(vec4 position) {
+ vec4 tmp = planeEq * position;
+ float value = tmp.x + tmp.y + tmp.z + planeEq.a;
+ return (value < 0.0001);
+ }
+
+ void clipping(vec4 position) {
+ if (isClipped(position)) {
+ discard;
+ }
+ }
+ /* End of clipping */
+ """
+
+ _fragDeclNoop = """
+ bool isClipped(vec4 position)
+ {
+ return false;
+ }
+
+ void clipping(vec4 position) {}
+ """
+
+ def __init__(self, point=(0., 0., 0.), normal=(0., 0., 0.)):
+ self._plane = utils.Plane(point, normal)
+
+ @property
+ def plane(self):
+ """Plane parameters in camera space."""
+ return self._plane
+
+ # GL2
+
+ @property
+ def fragDecl(self):
+ return self._fragDecl if self.plane.isPlane else self._fragDeclNoop
+
+ @property
+ def fragCall(self):
+ return "clipping"
+
+ def setupProgram(self, context, program):
+ """Sets-up uniforms of a program using this shader function.
+
+ :param RenderContext context: The current rendering context
+ :param GLProgram program: The program to set-up.
+ It MUST be in use and using this function.
+ """
+ if self.plane.isPlane:
+ gl.glUniform4f(program.uniforms['planeEq'], *self.plane.parameters)
+
+
+class DirectionalLight(event.Notifier, ProgramFunction):
+ """Description of a directional Phong light.
+
+ :param direction: The direction of the light or None to disable light
+ :type direction: ndarray of 3 floats or None
+ :param ambient: RGB ambient light
+ :type ambient: ndarray of 3 floats in [0, 1], default: (1., 1., 1.)
+ :param diffuse: RGB diffuse light parameter
+ :type diffuse: ndarray of 3 floats in [0, 1], default: (0., 0., 0.)
+ :param specular: RGB specular light parameter
+ :type specular: ndarray of 3 floats in [0, 1], default: (1., 1., 1.)
+ :param int shininess: The shininess of the material for specular term,
+ default: 0 which disables specular component.
+ """
+
+ fragmentShaderFunction = """
+ /* Lighting */
+ struct DLight {
+ vec3 lightDir; // Direction of light in object space
+ vec3 ambient;
+ vec3 diffuse;
+ vec3 specular;
+ float shininess;
+ vec3 viewPos; // Camera position in object space
+ };
+
+ uniform DLight dLight;
+
+ vec4 lighting(vec4 color, vec3 position, vec3 normal)
+ {
+ normal = normalize(normal);
+ // 1-sided
+ float nDotL = max(0.0, dot(normal, - dLight.lightDir));
+
+ // 2-sided
+ //float nDotL = dot(normal, - dLight.lightDir);
+ //if (nDotL < 0.) {
+ // nDotL = - nDotL;
+ // normal = - normal;
+ //}
+
+ float specFactor = 0.;
+ if (dLight.shininess > 0. && nDotL > 0.) {
+ vec3 reflection = reflect(dLight.lightDir, normal);
+ vec3 viewDir = normalize(dLight.viewPos - position);
+ specFactor = max(0.0, dot(reflection, viewDir));
+ if (specFactor > 0.) {
+ specFactor = pow(specFactor, dLight.shininess);
+ }
+ }
+
+ vec3 enlightedColor = color.rgb * (dLight.ambient +
+ dLight.diffuse * nDotL) +
+ dLight.specular * specFactor;
+
+ return vec4(enlightedColor.rgb, color.a);
+ }
+ /* End of Lighting */
+ """
+
+ fragmentShaderFunctionNoop = """
+ vec4 lighting(vec4 color, vec3 position, vec3 normal)
+ {
+ return color;
+ }
+ """
+
+ def __init__(self, direction=None,
+ ambient=(1., 1., 1.), diffuse=(0., 0., 0.),
+ specular=(1., 1., 1.), shininess=0):
+ super(DirectionalLight, self).__init__()
+ self._direction = None
+ self.direction = direction # Set _direction
+ self._isOn = True
+ self._ambient = ambient
+ self._diffuse = diffuse
+ self._specular = specular
+ self._shininess = shininess
+
+ ambient = event.notifyProperty('_ambient')
+ diffuse = event.notifyProperty('_diffuse')
+ specular = event.notifyProperty('_specular')
+ shininess = event.notifyProperty('_shininess')
+
+ @property
+ def isOn(self):
+ """True if light is on, False otherwise."""
+ return self._isOn and self._direction is not None
+
+ @isOn.setter
+ def isOn(self, isOn):
+ self._isOn = bool(isOn)
+
+ @contextlib.contextmanager
+ def turnOff(self):
+ """Context manager to temporary turn off lighting during rendering.
+
+ >>> with light.turnOff():
+ ... # Do some rendering without lighting
+ """
+ wason = self._isOn
+ self._isOn = False
+ yield
+ self._isOn = wason
+
+ @property
+ def direction(self):
+ """The direction of the light, or None if light is not on."""
+ return self._direction
+
+ @direction.setter
+ def direction(self, direction):
+ if direction is None:
+ self._direction = None
+ else:
+ assert len(direction) == 3
+ direction = numpy.array(direction, dtype=numpy.float32, copy=True)
+ norm = numpy.linalg.norm(direction)
+ assert norm != 0
+ self._direction = direction / norm
+ self.notify()
+
+ # GL2
+
+ @property
+ def fragmentDef(self):
+ """Definition to add to fragment shader"""
+ if self.isOn:
+ return self.fragmentShaderFunction
+ else:
+ return self.fragmentShaderFunctionNoop
+
+ @property
+ def fragmentCall(self):
+ """Function name to call in fragment shader"""
+ return "lighting"
+
+ def setupProgram(self, context, program):
+ """Sets-up uniforms of a program using this shader function.
+
+ :param RenderContext context: The current rendering context
+ :param GLProgram program: The program to set-up.
+ It MUST be in use and using this function.
+ """
+ if self.isOn and self._direction is not None:
+ # Transform light direction from camera space to object coords
+ lightdir = context.objectToCamera.transformDir(
+ self._direction, direct=False)
+ lightdir /= numpy.linalg.norm(lightdir)
+
+ gl.glUniform3f(program.uniforms['dLight.lightDir'], *lightdir)
+
+ # Convert view position to object coords
+ viewpos = context.objectToCamera.transformPoint(
+ numpy.array((0., 0., 0., 1.), dtype=numpy.float32),
+ direct=False,
+ perspectiveDivide=True)[:3]
+ gl.glUniform3f(program.uniforms['dLight.viewPos'], *viewpos)
+
+ gl.glUniform3f(program.uniforms['dLight.ambient'], *self.ambient)
+ gl.glUniform3f(program.uniforms['dLight.diffuse'], *self.diffuse)
+ gl.glUniform3f(program.uniforms['dLight.specular'], *self.specular)
+ gl.glUniform1f(program.uniforms['dLight.shininess'],
+ self.shininess)
+
+
+class Colormap(event.Notifier, ProgramFunction):
+
+ _declTemplate = string.Template("""
+ uniform sampler2D cmap_texture;
+ uniform int cmap_normalization;
+ uniform float cmap_parameter;
+ uniform float cmap_min;
+ uniform float cmap_oneOverRange;
+ uniform vec4 nancolor;
+
+ const float oneOverLog10 = 0.43429448190325176;
+
+ vec4 colormap(float value) {
+ float data = value; /* Keep original input value for isnan test */
+
+ if (cmap_normalization == 1) { /* Log10 mapping */
+ if (value > 0.0) {
+ value = clamp(cmap_oneOverRange *
+ (oneOverLog10 * log(value) - cmap_min),
+ 0.0, 1.0);
+ } else {
+ value = 0.0;
+ }
+ } else if (cmap_normalization == 2) { /* Sqrt mapping */
+ if (value > 0.0) {
+ value = clamp(cmap_oneOverRange * (sqrt(value) - cmap_min),
+ 0.0, 1.0);
+ } else {
+ value = 0.0;
+ }
+ } else if (cmap_normalization == 3) { /*Gamma correction mapping*/
+ value = pow(
+ clamp(cmap_oneOverRange * (value - cmap_min), 0.0, 1.0),
+ cmap_parameter);
+ } else if (cmap_normalization == 4) { /* arcsinh mapping */
+ /* asinh = log(x + sqrt(x*x + 1) for compatibility with GLSL 1.20 */
+ value = clamp(cmap_oneOverRange * (log(value + sqrt(value*value + 1.0)) - cmap_min), 0.0, 1.0);
+ } else { /* Linear mapping */
+ value = clamp(cmap_oneOverRange * (value - cmap_min), 0.0, 1.0);
+ }
+
+ $discard
+
+ vec4 color;
+ if (data != data) { /* isnan alternative for compatibility with GLSL 1.20 */
+ color = nancolor;
+ } else {
+ color = texture2D(cmap_texture, vec2(value, 0.5));
+ }
+ return color;
+ }
+ """)
+
+ _discardCode = """
+ if (value == 0.) {
+ discard;
+ }
+ """
+
+ call = "colormap"
+
+ NORMS = 'linear', 'log', 'sqrt', 'gamma', 'arcsinh'
+ """Tuple of supported normalizations."""
+
+ _COLORMAP_TEXTURE_UNIT = 1
+ """Texture unit to use for storing the colormap"""
+
+ def __init__(self, colormap=None, norm='linear', gamma=0., range_=(1., 10.)):
+ """Shader function to apply a colormap to a value.
+
+ :param colormap: RGB(A) color look-up table (default: gray)
+ :param colormap: numpy.ndarray of numpy.uint8 of dimension Nx3 or Nx4
+ :param str norm: Normalization to apply: see :attr:`NORMS`.
+ :param float gamma: Gamma normalization parameter
+ :param range_: Range of value to map to the colormap.
+ :type range_: 2-tuple of float (begin, end).
+ """
+ super(Colormap, self).__init__()
+
+ # Init privates to default
+ self._colormap = None
+ self._norm = 'linear'
+ self._gamma = -1.
+ self._range = 1., 10.
+ self._displayValuesBelowMin = True
+ self._nancolor = numpy.array((1., 1., 1., 0.), dtype=numpy.float32)
+
+ self._texture = None
+ self._textureToDiscard = None
+
+ if colormap is None:
+ # default colormap
+ colormap = numpy.empty((256, 3), dtype=numpy.uint8)
+ colormap[:] = numpy.arange(256,
+ dtype=numpy.uint8)[:, numpy.newaxis]
+
+ # Set to values through properties to perform asserts and updates
+ self.colormap = colormap
+ self.norm = norm
+ self.gamma = gamma
+ self.range_ = range_
+
+ @property
+ def decl(self):
+ """Source code of the function declaration"""
+ return self._declTemplate.substitute(
+ discard="" if self.displayValuesBelowMin else self._discardCode)
+
+ @property
+ def colormap(self):
+ """Color look-up table to use."""
+ return numpy.array(self._colormap, copy=True)
+
+ @colormap.setter
+ def colormap(self, colormap):
+ colormap = numpy.array(colormap, copy=True)
+ assert colormap.ndim == 2
+ assert colormap.shape[1] in (3, 4)
+ self._colormap = colormap
+
+ if self._texture is not None and self._texture.name is not None:
+ self._textureToDiscard = self._texture
+
+ data = numpy.empty(
+ (16, self._colormap.shape[0], self._colormap.shape[1]),
+ dtype=self._colormap.dtype)
+ data[:] = self._colormap
+
+ format_ = gl.GL_RGBA if data.shape[-1] == 4 else gl.GL_RGB
+
+ self._texture = _glutils.Texture(
+ format_, data, format_,
+ texUnit=self._COLORMAP_TEXTURE_UNIT,
+ minFilter=gl.GL_NEAREST,
+ magFilter=gl.GL_NEAREST,
+ wrap=gl.GL_CLAMP_TO_EDGE)
+
+ self.notify()
+
+ @property
+ def nancolor(self):
+ """RGBA color to use for Not-A-Number values as 4 float in [0., 1.]"""
+ return self._nancolor
+
+ @nancolor.setter
+ def nancolor(self, color):
+ color = numpy.clip(numpy.array(color, dtype=numpy.float32), 0., 1.)
+ assert color.ndim == 1
+ assert len(color) == 4
+ if not numpy.array_equal(self._nancolor, color):
+ self._nancolor = color
+ self.notify()
+
+ @property
+ def norm(self):
+ """Normalization to use for colormap mapping.
+
+ One of 'linear' (the default), 'log' for log10 mapping or 'sqrt'.
+ Invalid values (e.g., negative values with 'log' or 'sqrt') are mapped to 0.
+ """
+ return self._norm
+
+ @norm.setter
+ def norm(self, norm):
+ if norm != self._norm:
+ assert norm in self.NORMS
+ self._norm = norm
+ if norm in ('log', 'sqrt'):
+ self.range_ = self.range_ # To test for positive range_
+ self.notify()
+
+ @property
+ def gamma(self):
+ """Gamma correction normalization parameter (float >= 0.)"""
+ return self._gamma
+
+ @gamma.setter
+ def gamma(self, gamma):
+ if gamma != self._gamma:
+ assert gamma >= 0.
+ self._gamma = gamma
+ self.notify()
+
+ @property
+ def range_(self):
+ """Range of values to map to the colormap.
+
+ 2-tuple of floats: (begin, end).
+ The begin value is mapped to the origin of the colormap and the
+ end value is mapped to the other end of the colormap.
+ The colormap is reversed if begin > end.
+ """
+ return self._range
+
+ @range_.setter
+ def range_(self, range_):
+ assert len(range_) == 2
+ range_ = float(range_[0]), float(range_[1])
+
+ if self.norm == 'log' and (range_[0] <= 0. or range_[1] <= 0.):
+ _logger.warning(
+ "Log normalization and negative range: updating range.")
+ minPos = numpy.finfo(numpy.float32).tiny
+ range_ = max(range_[0], minPos), max(range_[1], minPos)
+ elif self.norm == 'sqrt' and (range_[0] < 0. or range_[1] < 0.):
+ _logger.warning(
+ "Sqrt normalization and negative range: updating range.")
+ range_ = max(range_[0], 0.), max(range_[1], 0.)
+
+ if range_ != self._range:
+ self._range = range_
+ self.notify()
+
+ @property
+ def displayValuesBelowMin(self):
+ """True to display values below colormap min, False to discard them.
+ """
+ return self._displayValuesBelowMin
+
+ @displayValuesBelowMin.setter
+ def displayValuesBelowMin(self, displayValuesBelowMin):
+ displayValuesBelowMin = bool(displayValuesBelowMin)
+ if self._displayValuesBelowMin != displayValuesBelowMin:
+ self._displayValuesBelowMin = displayValuesBelowMin
+ self.notify()
+
+ def setupProgram(self, context, program):
+ """Sets-up uniforms of a program using this shader function.
+
+ :param RenderContext context: The current rendering context
+ :param GLProgram program: The program to set-up.
+ It MUST be in use and using this function.
+ """
+ self.prepareGL2(context) # TODO see how to handle
+
+ self._texture.bind()
+
+ gl.glUniform1i(program.uniforms['cmap_texture'],
+ self._texture.texUnit)
+
+ min_, max_ = self.range_
+ param = 0.
+ if self._norm == 'log':
+ min_, max_ = numpy.log10(min_), numpy.log10(max_)
+ normID = 1
+ elif self._norm == 'sqrt':
+ min_, max_ = numpy.sqrt(min_), numpy.sqrt(max_)
+ normID = 2
+ elif self._norm == 'gamma':
+ # Keep min_, max_ as is
+ param = self._gamma
+ normID = 3
+ elif self._norm == 'arcsinh':
+ min_, max_ = numpy.arcsinh(min_), numpy.arcsinh(max_)
+ normID = 4
+ else: # Linear
+ normID = 0
+
+ gl.glUniform1i(program.uniforms['cmap_normalization'], normID)
+ gl.glUniform1f(program.uniforms['cmap_parameter'], param)
+ gl.glUniform1f(program.uniforms['cmap_min'], min_)
+ gl.glUniform1f(program.uniforms['cmap_oneOverRange'],
+ (1. / (max_ - min_)) if max_ != min_ else 0.)
+ gl.glUniform4f(program.uniforms['nancolor'], *self._nancolor)
+
+ def prepareGL2(self, context):
+ if self._textureToDiscard is not None:
+ self._textureToDiscard.discard()
+ self._textureToDiscard = None
+
+ self._texture.prepare()
diff --git a/src/silx/gui/plot3d/scene/interaction.py b/src/silx/gui/plot3d/scene/interaction.py
new file mode 100644
index 0000000..14a54dc
--- /dev/null
+++ b/src/silx/gui/plot3d/scene/interaction.py
@@ -0,0 +1,701 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides interaction to plug on the scene graph."""
+
+from __future__ import absolute_import
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "25/07/2016"
+
+import logging
+import numpy
+
+from silx.gui import qt
+from silx.gui.plot.Interaction import \
+ StateMachine, State, LEFT_BTN, RIGHT_BTN # , MIDDLE_BTN
+
+from . import transform
+
+
+_logger = logging.getLogger(__name__)
+
+
+class ClickOrDrag(StateMachine):
+ """Click or drag interaction for a given button.
+
+ """
+ #TODO: merge this class with silx.gui.plot.Interaction.ClickOrDrag
+
+ DRAG_THRESHOLD_SQUARE_DIST = 5 ** 2
+
+ class Idle(State):
+ def onPress(self, x, y, btn):
+ if btn == self.machine.button:
+ self.goto('clickOrDrag', x, y)
+ return True
+
+ class ClickOrDrag(State):
+ def enterState(self, x, y):
+ self.initPos = x, y
+
+ enter = enterState # silx v.0.3 support, remove when 0.4 out
+
+ def onMove(self, x, y):
+ dx = (x - self.initPos[0]) ** 2
+ dy = (y - self.initPos[1]) ** 2
+ if (dx ** 2 + dy ** 2) >= self.machine.DRAG_THRESHOLD_SQUARE_DIST:
+ self.goto('drag', self.initPos, (x, y))
+
+ def onRelease(self, x, y, btn):
+ if btn == self.machine.button:
+ self.machine.click(x, y)
+ self.goto('idle')
+
+ class Drag(State):
+ def enterState(self, initPos, curPos):
+ self.initPos = initPos
+ self.machine.beginDrag(*initPos)
+ self.machine.drag(*curPos)
+
+ enter = enterState # silx v.0.3 support, remove when 0.4 out
+
+ def onMove(self, x, y):
+ self.machine.drag(x, y)
+
+ def onRelease(self, x, y, btn):
+ if btn == self.machine.button:
+ self.machine.endDrag(self.initPos, (x, y))
+ self.goto('idle')
+
+ def __init__(self, button=LEFT_BTN):
+ self.button = button
+ states = {
+ 'idle': ClickOrDrag.Idle,
+ 'clickOrDrag': ClickOrDrag.ClickOrDrag,
+ 'drag': ClickOrDrag.Drag
+ }
+ super(ClickOrDrag, self).__init__(states, 'idle')
+
+ def click(self, x, y):
+ """Called upon a left or right button click.
+ To override in a subclass.
+ """
+ pass
+
+ def beginDrag(self, x, y):
+ """Called at the beginning of a drag gesture with left button
+ pressed.
+ To override in a subclass.
+ """
+ pass
+
+ def drag(self, x, y):
+ """Called on mouse moved during a drag gesture.
+ To override in a subclass.
+ """
+ pass
+
+ def endDrag(self, x, y):
+ """Called at the end of a drag gesture when the left button is
+ released.
+ To override in a subclass.
+ """
+ pass
+
+
+class CameraSelectRotate(ClickOrDrag):
+ """Camera rotation using an arcball-like interaction."""
+
+ def __init__(self, viewport, orbitAroundCenter=True, button=RIGHT_BTN,
+ selectCB=None):
+ self._viewport = viewport
+ self._orbitAroundCenter = orbitAroundCenter
+ self._selectCB = selectCB
+ self._reset()
+ super(CameraSelectRotate, self).__init__(button)
+
+ def _reset(self):
+ self._origin, self._center = None, None
+ self._startExtrinsic = None
+
+ def click(self, x, y):
+ if self._selectCB is not None:
+ ndcZ = self._viewport._pickNdcZGL(x, y)
+ position = self._viewport._getXZYGL(x, y)
+ # This assume no object lie on the far plane
+ # Alternative, change the depth range so that far is < 1
+ if ndcZ != 1. and position is not None:
+ self._selectCB((x, y, ndcZ), position)
+
+ def beginDrag(self, x, y):
+ centerPos = None
+ if not self._orbitAroundCenter:
+ # Try to use picked object position as center of rotation
+ ndcZ = self._viewport._pickNdcZGL(x, y)
+ if ndcZ != 1.:
+ # Hit an object, use picked point as center
+ centerPos = self._viewport._getXZYGL(x, y) # Can return None
+
+ if centerPos is None:
+ # Not using picked position, use scene center
+ bounds = self._viewport.scene.bounds(transformed=True)
+ centerPos = 0.5 * (bounds[0] + bounds[1])
+
+ self._center = transform.Translate(*centerPos)
+ self._origin = x, y
+ self._startExtrinsic = self._viewport.camera.extrinsic.copy()
+
+ def drag(self, x, y):
+ if self._center is None:
+ return
+
+ dx, dy = self._origin[0] - x, self._origin[1] - y
+
+ if dx == 0 and dy == 0:
+ direction = self._startExtrinsic.direction
+ up = self._startExtrinsic.up
+ position = self._startExtrinsic.position
+ else:
+ minsize = min(self._viewport.size)
+ distance = numpy.sqrt(dx ** 2 + dy ** 2)
+ angle = distance / minsize * numpy.pi
+
+ # Take care of y inversion
+ direction = dx * self._startExtrinsic.side - \
+ dy * self._startExtrinsic.up
+ direction /= numpy.linalg.norm(direction)
+ axis = numpy.cross(direction, self._startExtrinsic.direction)
+ axis /= numpy.linalg.norm(axis)
+
+ # Orbit start camera with current angle and axis
+ # Rotate viewing direction
+ rotation = transform.Rotate(numpy.degrees(angle), *axis)
+ direction = rotation.transformDir(self._startExtrinsic.direction)
+ up = rotation.transformDir(self._startExtrinsic.up)
+
+ # Rotate position around center
+ trlist = transform.StaticTransformList((
+ self._center,
+ rotation,
+ self._center.inverse()))
+ position = trlist.transformPoint(self._startExtrinsic.position)
+
+ camerapos = self._viewport.camera.extrinsic
+ camerapos.setOrientation(direction, up)
+ camerapos.position = position
+
+ def endDrag(self, x, y):
+ self._reset()
+
+
+class CameraSelectPan(ClickOrDrag):
+ """Picking on click and pan camera on drag."""
+
+ def __init__(self, viewport, button=LEFT_BTN, selectCB=None):
+ self._viewport = viewport
+ self._selectCB = selectCB
+ self._lastPosNdc = None
+ super(CameraSelectPan, self).__init__(button)
+
+ def click(self, x, y):
+ if self._selectCB is not None:
+ ndcZ = self._viewport._pickNdcZGL(x, y)
+ position = self._viewport._getXZYGL(x, y)
+ # This assume no object lie on the far plane
+ # Alternative, change the depth range so that far is < 1
+ if ndcZ != 1. and position is not None:
+ self._selectCB((x, y, ndcZ), position)
+
+ def beginDrag(self, x, y):
+ ndc = self._viewport.windowToNdc(x, y)
+ ndcZ = self._viewport._pickNdcZGL(x, y)
+ # ndcZ is the panning plane
+ if ndc is not None and ndcZ is not None:
+ self._lastPosNdc = numpy.array((ndc[0], ndc[1], ndcZ, 1.),
+ dtype=numpy.float32)
+ else:
+ self._lastPosNdc = None
+
+ def drag(self, x, y):
+ if self._lastPosNdc is not None:
+ ndc = self._viewport.windowToNdc(x, y)
+ if ndc is not None:
+ ndcPos = numpy.array((ndc[0], ndc[1], self._lastPosNdc[2], 1.),
+ dtype=numpy.float32)
+
+ # Convert last and current NDC positions to scene coords
+ scenePos = self._viewport.camera.transformPoint(
+ ndcPos, direct=False, perspectiveDivide=True)
+ lastScenePos = self._viewport.camera.transformPoint(
+ self._lastPosNdc, direct=False, perspectiveDivide=True)
+
+ # Get translation in scene coords
+ translation = scenePos[:3] - lastScenePos[:3]
+ self._viewport.camera.extrinsic.position -= translation
+
+ # Store for next drag
+ self._lastPosNdc = ndcPos
+
+ def endDrag(self, x, y):
+ self._lastPosNdc = None
+
+
+class CameraWheel(object):
+ """StateMachine like class, just handling wheel events."""
+
+ # TODO choose scale of motion? Translation or Scale?
+ def __init__(self, viewport, mode='center', scaleTransform=None):
+ assert mode in ('center', 'position', 'scale')
+ self._viewport = viewport
+ if mode == 'center':
+ self._zoomTo = self._zoomToCenter
+ elif mode == 'position':
+ self._zoomTo = self._zoomToPosition
+ elif mode == 'scale':
+ self._zoomTo = self._zoomByScale
+ self._scale = scaleTransform
+ else:
+ raise ValueError('Unsupported mode: %s' % mode)
+
+ def handleEvent(self, eventName, *args, **kwargs):
+ if eventName == 'wheel':
+ return self._zoomTo(*args, **kwargs)
+
+ def _zoomToCenter(self, x, y, angleInDegrees):
+ """Zoom to center of display.
+
+ Only works with perspective camera.
+ """
+ direction = 'forward' if angleInDegrees > 0 else 'backward'
+ self._viewport.camera.move(direction)
+ return True
+
+ def _zoomToPositionAbsolute(self, x, y, angleInDegrees):
+ """Zoom while keeping pixel under mouse invariant.
+
+ Only works with perspective camera.
+ """
+ ndc = self._viewport.windowToNdc(x, y)
+ if ndc is not None:
+ near = numpy.array((ndc[0], ndc[1], -1., 1.), dtype=numpy.float32)
+
+ nearscene = self._viewport.camera.transformPoint(
+ near, direct=False, perspectiveDivide=True)
+
+ far = numpy.array((ndc[0], ndc[1], 1., 1.), dtype=numpy.float32)
+ farscene = self._viewport.camera.transformPoint(
+ far, direct=False, perspectiveDivide=True)
+
+ dirscene = farscene[:3] - nearscene[:3]
+ dirscene /= numpy.linalg.norm(dirscene)
+
+ if angleInDegrees < 0:
+ dirscene *= -1.
+
+ # TODO which scale
+ self._viewport.camera.extrinsic.position += dirscene
+ return True
+
+ def _zoomToPosition(self, x, y, angleInDegrees):
+ """Zoom while keeping pixel under mouse invariant."""
+ projection = self._viewport.camera.intrinsic
+ extrinsic = self._viewport.camera.extrinsic
+
+ if isinstance(projection, transform.Perspective):
+ # For perspective projection, move camera
+ ndc = self._viewport.windowToNdc(x, y)
+ if ndc is not None:
+ ndcz = self._viewport._pickNdcZGL(x, y)
+
+ position = numpy.array((ndc[0], ndc[1], ndcz),
+ dtype=numpy.float32)
+ positionscene = self._viewport.camera.transformPoint(
+ position, direct=False, perspectiveDivide=True)
+
+ camtopos = extrinsic.position - positionscene
+
+ step = 0.2 * (1. if angleInDegrees < 0 else -1.)
+ extrinsic.position += step * camtopos
+
+ elif isinstance(projection, transform.Orthographic):
+ # For orthographic projection, change projection borders
+ ndcx, ndcy = self._viewport.windowToNdc(x, y, checkInside=False)
+
+ step = 0.2 * (1. if angleInDegrees < 0 else -1.)
+
+ dx = (ndcx + 1) / 2.
+ stepwidth = step * (projection.right - projection.left)
+ left = projection.left - dx * stepwidth
+ right = projection.right + (1. - dx) * stepwidth
+
+ dy = (ndcy + 1) / 2.
+ stepheight = step * (projection.top - projection.bottom)
+ bottom = projection.bottom - dy * stepheight
+ top = projection.top + (1. - dy) * stepheight
+
+ projection.setClipping(left, right, bottom, top)
+
+ else:
+ raise RuntimeError('Unsupported camera', projection)
+ return True
+
+ def _zoomByScale(self, x, y, angleInDegrees):
+ """Zoom by scaling scene (do not keep pixel under mouse invariant)."""
+ scalefactor = 1.1
+ if angleInDegrees < 0.:
+ scalefactor = 1. / scalefactor
+ self._scale.scale = scalefactor * self._scale.scale
+
+ self._viewport.adjustCameraDepthExtent()
+ return True
+
+
+class FocusManager(StateMachine):
+ """Manages focus across multiple event handlers
+
+ On press an event handler can acquire focus.
+ By default it looses focus when all buttons are released.
+ """
+ class Idle(State):
+ def onPress(self, x, y, btn):
+ for eventHandler in self.machine.currentEventHandler:
+ requestFocus = eventHandler.handleEvent('press', x, y, btn)
+ if requestFocus:
+ self.goto('focus', eventHandler, btn)
+ break
+
+ def _processEvent(self, *args):
+ for eventHandler in self.machine.currentEventHandler:
+ consumeEvent = eventHandler.handleEvent(*args)
+ if consumeEvent:
+ break
+
+ def onMove(self, x, y):
+ self._processEvent('move', x, y)
+
+ def onRelease(self, x, y, btn):
+ self._processEvent('release', x, y, btn)
+
+ def onWheel(self, x, y, angle):
+ self._processEvent('wheel', x, y, angle)
+
+ class Focus(State):
+ def enterState(self, eventHandler, btn):
+ self.eventHandler = eventHandler
+ self.focusBtns = {btn} # Set
+
+ enter = enterState # silx v.0.3 support, remove when 0.4 out
+
+ def onPress(self, x, y, btn):
+ self.focusBtns.add(btn)
+ self.eventHandler.handleEvent('press', x, y, btn)
+
+ def onMove(self, x, y):
+ self.eventHandler.handleEvent('move', x, y)
+
+ def onRelease(self, x, y, btn):
+ self.focusBtns.discard(btn)
+ requestfocus = self.eventHandler.handleEvent('release', x, y, btn)
+ if len(self.focusBtns) == 0 and not requestfocus:
+ self.goto('idle')
+
+ def onWheel(self, x, y, angleInDegrees):
+ self.eventHandler.handleEvent('wheel', x, y, angleInDegrees)
+
+ def __init__(self, eventHandlers=(), ctrlEventHandlers=None):
+ self.defaultEventHandlers = eventHandlers
+ self.ctrlEventHandlers = ctrlEventHandlers
+ self.currentEventHandler = self.defaultEventHandlers
+
+ states = {
+ 'idle': FocusManager.Idle,
+ 'focus': FocusManager.Focus
+ }
+ super(FocusManager, self).__init__(states, 'idle')
+
+ def onKeyPress(self, key):
+ if key == qt.Qt.Key_Control and self.ctrlEventHandlers is not None:
+ self.currentEventHandler = self.ctrlEventHandlers
+
+ def onKeyRelease(self, key):
+ if key == qt.Qt.Key_Control:
+ self.currentEventHandler = self.defaultEventHandlers
+
+ def cancel(self):
+ for handler in self.currentEventHandler:
+ handler.cancel()
+
+
+class RotateCameraControl(FocusManager):
+ """Combine wheel and rotate state machine for left button
+ and pan when ctrl is pressed
+ """
+ def __init__(self, viewport,
+ orbitAroundCenter=False,
+ mode='center', scaleTransform=None,
+ selectCB=None):
+ handlers = (CameraWheel(viewport, mode, scaleTransform),
+ CameraSelectRotate(
+ viewport, orbitAroundCenter, LEFT_BTN, selectCB))
+ ctrlHandlers = (CameraWheel(viewport, mode, scaleTransform),
+ CameraSelectPan(viewport, LEFT_BTN, selectCB))
+ super(RotateCameraControl, self).__init__(handlers, ctrlHandlers)
+
+
+class PanCameraControl(FocusManager):
+ """Combine wheel, selectPan and rotate state machine for left button
+ and rotate when ctrl is pressed"""
+ def __init__(self, viewport,
+ orbitAroundCenter=False,
+ mode='center', scaleTransform=None,
+ selectCB=None):
+ handlers = (CameraWheel(viewport, mode, scaleTransform),
+ CameraSelectPan(viewport, LEFT_BTN, selectCB))
+ ctrlHandlers = (CameraWheel(viewport, mode, scaleTransform),
+ CameraSelectRotate(
+ viewport, orbitAroundCenter, LEFT_BTN, selectCB))
+ super(PanCameraControl, self).__init__(handlers, ctrlHandlers)
+
+
+class CameraControl(FocusManager):
+ """Combine wheel, selectPan and rotate state machine."""
+ def __init__(self, viewport,
+ orbitAroundCenter=False,
+ mode='center', scaleTransform=None,
+ selectCB=None):
+ handlers = (CameraWheel(viewport, mode, scaleTransform),
+ CameraSelectPan(viewport, LEFT_BTN, selectCB),
+ CameraSelectRotate(
+ viewport, orbitAroundCenter, RIGHT_BTN, selectCB))
+ super(CameraControl, self).__init__(handlers)
+
+
+class PlaneRotate(ClickOrDrag):
+ """Plane rotation using arcball interaction.
+
+ Arcball ref.:
+ Ken Shoemake. ARCBALL: A user interface for specifying three-dimensional
+ orientation using a mouse. In Proc. GI '92. (1992). pp. 151-156.
+ """
+
+ def __init__(self, viewport, plane, button=RIGHT_BTN):
+ self._viewport = viewport
+ self._plane = plane
+ self._reset()
+ super(PlaneRotate, self).__init__(button)
+
+ def _reset(self):
+ self._beginNormal, self._beginCenter = None, None
+
+ def click(self, x, y):
+ pass # No interaction
+
+ @staticmethod
+ def _sphereUnitVector(radius, center, position):
+ """Returns the unit vector of the projection of position on a sphere.
+
+ It assumes an orthographic projection.
+ For perspective projection, it gives an approximation, but it
+ simplifies computations and results in consistent arcball control
+ in control space.
+
+ All parameters must be in screen coordinate system
+ (either pixels or normalized coordinates).
+
+ :param float radius: The radius of the sphere.
+ :param center: (x, y) coordinates of the center.
+ :param position: (x, y) coordinates of the cursor position.
+ :return: Unit vector.
+ :rtype: numpy.ndarray of 3 floats.
+ """
+ center, position = numpy.array(center), numpy.array(position)
+
+ # Normalize x and y on a unit circle
+ spherecoords = (position - center) / float(radius)
+ squarelength = numpy.sum(spherecoords ** 2)
+
+ # Project on the unit sphere and compute z coordinates
+ if squarelength > 1.0: # Outside sphere: project
+ spherecoords /= numpy.sqrt(squarelength)
+ zsphere = 0.0
+ else: # In sphere: compute z
+ zsphere = numpy.sqrt(1. - squarelength)
+
+ spherecoords = numpy.append(spherecoords, zsphere)
+ return spherecoords
+
+ def beginDrag(self, x, y):
+ # Makes sure the point defining the plane is at the center as
+ # it will be the center of rotation (as rotation is applied to normal)
+ self._plane.plane.point = self._plane.center
+
+ # Store the plane normal
+ self._beginNormal = self._plane.plane.normal
+
+ _logger.debug(
+ 'Begin arcball, plane center %s', str(self._plane.center))
+
+ # Do the arcball on the screen
+ radius = min(self._viewport.size)
+ if self._plane.center is None:
+ self._beginCenter = None
+
+ else:
+ center = self._plane.objectToNDCTransform.transformPoint(
+ self._plane.center, perspectiveDivide=True)
+ self._beginCenter = self._viewport.ndcToWindow(
+ center[0], center[1], checkInside=False)
+
+ self._startVector = self._sphereUnitVector(
+ radius, self._beginCenter, (x, y))
+
+ def drag(self, x, y):
+ if self._beginCenter is None:
+ return
+
+ # Compute rotation: this is twice the rotation of the arcball
+ radius = min(self._viewport.size)
+ currentvector = self._sphereUnitVector(
+ radius, self._beginCenter, (x, y))
+ crossprod = numpy.cross(self._startVector, currentvector)
+ dotprod = numpy.dot(self._startVector, currentvector)
+
+ quaternion = numpy.append(crossprod, dotprod)
+ # Rotation was computed with Y downward, but apply in NDC, invert Y
+ quaternion[1] *= -1.
+
+ rotation = transform.Rotate()
+ rotation.quaternion = quaternion
+
+ # Convert to NDC, rotate, convert back to object
+ normal = self._plane.objectToNDCTransform.transformNormal(
+ self._beginNormal)
+ normal = rotation.transformNormal(normal)
+ normal = self._plane.objectToNDCTransform.transformNormal(
+ normal, direct=False)
+ self._plane.plane.normal = normal
+
+ def endDrag(self, x, y):
+ self._reset()
+
+
+class PlanePan(ClickOrDrag):
+ """Pan a plane along its normal on drag."""
+
+ def __init__(self, viewport, plane, button=LEFT_BTN):
+ self._plane = plane
+ self._viewport = viewport
+ self._beginPlanePoint = None
+ self._beginPos = None
+ self._dragNdcZ = 0.
+ super(PlanePan, self).__init__(button)
+
+ def click(self, x, y):
+ pass
+
+ def beginDrag(self, x, y):
+ ndc = self._viewport.windowToNdc(x, y)
+ ndcZ = self._viewport._pickNdcZGL(x, y)
+ # ndcZ is the panning plane
+ if ndc is not None and ndcZ is not None:
+ ndcPos = numpy.array((ndc[0], ndc[1], ndcZ, 1.),
+ dtype=numpy.float32)
+ scenePos = self._viewport.camera.transformPoint(
+ ndcPos, direct=False, perspectiveDivide=True)
+ self._beginPos = self._plane.objectToSceneTransform.transformPoint(
+ scenePos, direct=False)
+ self._dragNdcZ = ndcZ
+ else:
+ self._beginPos = None
+ self._dragNdcZ = 0.
+
+ self._beginPlanePoint = self._plane.plane.point
+
+ def drag(self, x, y):
+ if self._beginPos is not None:
+ ndc = self._viewport.windowToNdc(x, y)
+ if ndc is not None:
+ ndcPos = numpy.array((ndc[0], ndc[1], self._dragNdcZ, 1.),
+ dtype=numpy.float32)
+
+ # Convert last and current NDC positions to scene coords
+ scenePos = self._viewport.camera.transformPoint(
+ ndcPos, direct=False, perspectiveDivide=True)
+ curPos = self._plane.objectToSceneTransform.transformPoint(
+ scenePos, direct=False)
+
+ # Get translation in scene coords
+ translation = curPos[:3] - self._beginPos[:3]
+
+ newPoint = self._beginPlanePoint + translation
+
+ # Keep plane point in bounds
+ bounds = self._plane.parent.bounds(dataBounds=True)
+ if bounds is not None:
+ newPoint = numpy.clip(
+ newPoint, a_min=bounds[0], a_max=bounds[1])
+
+ # Only update plane if it is in some bounds
+ self._plane.plane.point = newPoint
+
+ def endDrag(self, x, y):
+ self._beginPlanePoint = None
+
+
+class PlaneControl(FocusManager):
+ """Combine wheel, selectPan and rotate state machine for plane control."""
+ def __init__(self, viewport, plane,
+ mode='center', scaleTransform=None):
+ handlers = (CameraWheel(viewport, mode, scaleTransform),
+ PlanePan(viewport, plane, LEFT_BTN),
+ PlaneRotate(viewport, plane, RIGHT_BTN))
+ super(PlaneControl, self).__init__(handlers)
+
+
+class PanPlaneRotateCameraControl(FocusManager):
+ """Combine wheel, pan plane and camera rotate state machine."""
+ def __init__(self, viewport, plane,
+ mode='center', scaleTransform=None):
+ handlers = (CameraWheel(viewport, mode, scaleTransform),
+ PlanePan(viewport, plane, LEFT_BTN),
+ CameraSelectRotate(viewport,
+ orbitAroundCenter=False,
+ button=RIGHT_BTN))
+ super(PanPlaneRotateCameraControl, self).__init__(handlers)
+
+
+class PanPlaneZoomOnWheelControl(FocusManager):
+ """Combine zoom on wheel and pan plane state machines."""
+ def __init__(self, viewport, plane,
+ mode='center',
+ orbitAroundCenter=False,
+ scaleTransform=None):
+ handlers = (CameraWheel(viewport, mode, scaleTransform),
+ PlanePan(viewport, plane, LEFT_BTN))
+ ctrlHandlers = (CameraWheel(viewport, mode, scaleTransform),
+ CameraSelectRotate(
+ viewport, orbitAroundCenter, LEFT_BTN))
+ super(PanPlaneZoomOnWheelControl, self).__init__(handlers, ctrlHandlers)
diff --git a/src/silx/gui/plot3d/scene/primitives.py b/src/silx/gui/plot3d/scene/primitives.py
new file mode 100644
index 0000000..7f35c3c
--- /dev/null
+++ b/src/silx/gui/plot3d/scene/primitives.py
@@ -0,0 +1,2524 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+
+from __future__ import absolute_import, division, unicode_literals
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "24/04/2018"
+
+try:
+ from collections import abc
+except ImportError: # Python2 support
+ import collections as abc
+import ctypes
+from functools import reduce
+import logging
+import string
+
+import numpy
+
+from silx.gui.colors import rgba
+
+from ... import _glutils
+from ..._glutils import gl
+
+from . import event
+from . import core
+from . import transform
+from . import utils
+from .function import Colormap
+
+_logger = logging.getLogger(__name__)
+
+
+# Geometry ####################################################################
+
+class Geometry(core.Elem):
+ """Set of vertices with normals and colors.
+
+ :param str mode: OpenGL drawing mode:
+ lines, line_strip, loop, triangles, triangle_strip, fan
+ :param indices: Array of vertex indices or None
+ :param bool copy: True (default) to copy the data, False to use as is.
+ :param str attrib0: Name of the attribute that MUST be an array.
+ :param attributes: Provide list of attributes as extra parameters.
+ """
+
+ _ATTR_INFO = {
+ 'position': {'dims': (1, 2), 'lastDim': (2, 3, 4)},
+ 'normal': {'dims': (1, 2), 'lastDim': (3,)},
+ 'color': {'dims': (1, 2), 'lastDim': (3, 4)},
+ }
+
+ _MODE_CHECKS = { # Min, Modulo
+ 'lines': (2, 2), 'line_strip': (2, 0), 'loop': (2, 0),
+ 'points': (1, 0),
+ 'triangles': (3, 3), 'triangle_strip': (3, 0), 'fan': (3, 0)
+ }
+
+ _MODES = {
+ 'lines': gl.GL_LINES,
+ 'line_strip': gl.GL_LINE_STRIP,
+ 'loop': gl.GL_LINE_LOOP,
+
+ 'points': gl.GL_POINTS,
+
+ 'triangles': gl.GL_TRIANGLES,
+ 'triangle_strip': gl.GL_TRIANGLE_STRIP,
+ 'fan': gl.GL_TRIANGLE_FAN
+ }
+
+ _LINE_MODES = 'lines', 'line_strip', 'loop'
+
+ _TRIANGLE_MODES = 'triangles', 'triangle_strip', 'fan'
+
+ def __init__(self,
+ mode,
+ indices=None,
+ copy=True,
+ attrib0='position',
+ **attributes):
+ super(Geometry, self).__init__()
+
+ self._attrib0 = str(attrib0)
+
+ self._vbos = {} # Store current vbos
+ self._unsyncAttributes = [] # Store attributes to copy to vbos
+ self.__bounds = None # Cache object's bounds
+ # Attribute names defining the object bounds
+ self.__boundsAttributeNames = (self._attrib0,)
+
+ assert mode in self._MODES
+ self._mode = mode
+
+ # Set attributes
+ self._attributes = {}
+ for name, data in attributes.items():
+ self.setAttribute(name, data, copy=copy)
+
+ # Set indices
+ self._indices = None
+ self.setIndices(indices, copy=copy)
+
+ # More consistency checks
+ mincheck, modulocheck = self._MODE_CHECKS[self._mode]
+ if self._indices is not None:
+ nbvertices = len(self._indices)
+ else:
+ nbvertices = self.nbVertices
+
+ if nbvertices != 0:
+ assert nbvertices >= mincheck
+ if modulocheck != 0:
+ assert (nbvertices % modulocheck) == 0
+
+ @property
+ def drawMode(self):
+ """Kind of primitive to render, in :attr:`_MODES` (str)"""
+ return self._mode
+
+ @staticmethod
+ def _glReadyArray(array, copy=True):
+ """Making a contiguous array, checking float types.
+
+ :param iterable array: array-like data to prepare for attribute
+ :param bool copy: True to make a copy of the array, False to use as is
+ """
+ # Convert single value (int, float, numpy types) to tuple
+ if not isinstance(array, abc.Iterable):
+ array = (array, )
+
+ # Makes sure it is an array
+ array = numpy.array(array, copy=False)
+
+ dtype = None
+ if array.dtype.kind == 'f' and array.dtype.itemsize != 4:
+ # Cast to float32
+ _logger.info('Cast array to float32')
+ dtype = numpy.float32
+ elif array.dtype.itemsize > 4:
+ # Cast (u)int64 to (u)int32
+ if array.dtype.kind == 'i':
+ _logger.info('Cast array to int32')
+ dtype = numpy.int32
+ elif array.dtype.kind == 'u':
+ _logger.info('Cast array to uint32')
+ dtype = numpy.uint32
+
+ return numpy.array(array, dtype=dtype, order='C', copy=copy)
+
+ @property
+ def nbVertices(self):
+ """Returns the number of vertices of current attributes.
+
+ It returns None if there is no attributes.
+ """
+ for array in self._attributes.values():
+ if len(array.shape) == 2:
+ return len(array)
+ return None
+
+ @property
+ def attrib0(self):
+ """Attribute name that MUST be an array (str)"""
+ return self._attrib0
+
+ def setAttribute(self, name, array, copy=True):
+ """Set attribute with provided array.
+
+ :param str name: The name of the attribute
+ :param array: Array-like attribute data or None to remove attribute
+ :param bool copy: True (default) to copy the data, False to use as is
+ """
+ # This triggers associated GL resources to be garbage collected
+ self._vbos.pop(name, None)
+
+ if array is None:
+ self._attributes.pop(name, None)
+
+ else:
+ array = self._glReadyArray(array, copy=copy)
+
+ if name not in self._ATTR_INFO:
+ _logger.debug('Not checking attribute %s dimensions', name)
+ else:
+ checks = self._ATTR_INFO[name]
+
+ if (array.ndim == 1 and checks['lastDim'] == (1,) and
+ len(array) > 1):
+ array = array.reshape((len(array), 1))
+
+ # Checks
+ assert array.ndim in checks['dims'], "Attr %s" % name
+ assert array.shape[-1] in checks['lastDim'], "Attr %s" % name
+
+ # Makes sure attrib0 is considered as an array of values
+ if name == self.attrib0 and array.ndim == 1:
+ array.shape = 1, -1
+
+ # Check length against another attribute array
+ # Causes problems when updating
+ # nbVertices = self.nbVertices
+ # if array.ndim == 2 and nbVertices is not None:
+ # assert len(array) == nbVertices
+
+ self._attributes[name] = array
+ if array.ndim == 2: # Store this in a VBO
+ self._unsyncAttributes.append(name)
+
+ if name in self.boundsAttributeNames: # Reset bounds
+ self.__bounds = None
+
+ self.notify()
+
+ def getAttribute(self, name, copy=True):
+ """Returns the numpy.ndarray corresponding to the name attribute.
+
+ :param str name: The name of the attribute to get.
+ :param bool copy: True to get a copy (default),
+ False to get internal array (DO NOT MODIFY)
+ :return: The corresponding array or None if no corresponding attribute.
+ :rtype: numpy.ndarray
+ """
+ attr = self._attributes.get(name, None)
+ return None if attr is None else numpy.array(attr, copy=copy)
+
+ def useAttribute(self, program, name=None):
+ """Enable and bind attribute(s) for a specific program.
+
+ This MUST be called with OpenGL context active and after prepareGL2
+ has been called.
+
+ :param GLProgram program: The program for which to set the attributes
+ :param str name: The attribute name to set or None to set then all
+ """
+ if name is None:
+ for name in program.attributes:
+ self.useAttribute(program, name)
+
+ else:
+ attribute = program.attributes.get(name)
+ if attribute is None:
+ return
+
+ vboattrib = self._vbos.get(name)
+ if vboattrib is not None:
+ gl.glEnableVertexAttribArray(attribute)
+ vboattrib.setVertexAttrib(attribute)
+
+ elif name not in self._attributes:
+ gl.glDisableVertexAttribArray(attribute)
+
+ else:
+ array = self._attributes[name]
+ assert array is not None
+
+ if array.ndim == 1:
+ assert len(array) in (1, 2, 3, 4)
+ gl.glDisableVertexAttribArray(attribute)
+ _glVertexAttribFunc = getattr(
+ _glutils.gl, 'glVertexAttrib{}f'.format(len(array)))
+ _glVertexAttribFunc(attribute, *array)
+ else:
+ # TODO As is this is a never event, remove?
+ gl.glEnableVertexAttribArray(attribute)
+ gl.glVertexAttribPointer(
+ attribute,
+ array.shape[-1],
+ _glutils.numpyToGLType(array.dtype),
+ gl.GL_FALSE,
+ 0,
+ array)
+
+ def setIndices(self, indices, copy=True):
+ """Set the primitive indices to use.
+
+ :param indices: Array-like of uint primitive indices or None to unset
+ :param bool copy: True (default) to copy the data, False to use as is
+ """
+ # Trigger garbage collection of previous indices VBO if any
+ self._vbos.pop('__indices__', None)
+
+ if indices is None:
+ self._indices = None
+ else:
+ indices = self._glReadyArray(indices, copy=copy).ravel()
+ assert indices.dtype.name in ('uint8', 'uint16', 'uint32')
+ if _logger.getEffectiveLevel() <= logging.DEBUG:
+ # This might be a costy check
+ assert indices.max() < self.nbVertices
+ self._indices = indices
+ self.notify()
+
+ def getIndices(self, copy=True):
+ """Returns the numpy.ndarray corresponding to the indices.
+
+ :param bool copy: True to get a copy (default),
+ False to get internal array (DO NOT MODIFY)
+ :return: The primitive indices array or None if not set.
+ :rtype: numpy.ndarray or None
+ """
+ if self._indices is None:
+ return None
+ else:
+ return numpy.array(self._indices, copy=copy)
+
+ @property
+ def boundsAttributeNames(self):
+ """Tuple of attribute names defining the bounds of the object.
+
+ Attributes name are taken in the given order to compute the
+ (x, y, z) the bounding box, e.g.::
+
+ geometry.boundsAttributeNames = 'position'
+ geometry.boundsAttributeNames = 'x', 'y', 'z'
+ """
+ return self.__boundsAttributeNames
+
+ @boundsAttributeNames.setter
+ def boundsAttributeNames(self, names):
+ self.__boundsAttributeNames = tuple(str(name) for name in names)
+ self.__bounds = None
+ self.notify()
+
+ def _bounds(self, dataBounds=False):
+ if self.__bounds is None:
+ if len(self.boundsAttributeNames) == 0:
+ return None # No bounds
+
+ self.__bounds = numpy.zeros((2, 3), dtype=numpy.float32)
+
+ # Coordinates defined in one or more attributes
+ index = 0
+ for name in self.boundsAttributeNames:
+ if index == 3:
+ _logger.error("Too many attributes defining bounds")
+ break
+
+ attribute = self._attributes[name]
+ assert attribute.ndim in (1, 2)
+ if attribute.ndim == 1: # Single value
+ min_ = attribute
+ max_ = attribute
+ elif len(attribute) > 0: # Array of values, compute min/max
+ min_ = numpy.nanmin(attribute, axis=0)
+ max_ = numpy.nanmax(attribute, axis=0)
+ else:
+ min_, max_ = numpy.zeros((2, attribute.shape[1]), dtype=numpy.float32)
+
+ toCopy = min(len(min_), 3-index)
+ if toCopy != len(min_):
+ _logger.error("Attribute defining bounds"
+ " has too many dimensions")
+
+ self.__bounds[0, index:index+toCopy] = min_[:toCopy]
+ self.__bounds[1, index:index+toCopy] = max_[:toCopy]
+
+ index += toCopy
+
+ self.__bounds[numpy.isnan(self.__bounds)] = 0. # Avoid NaNs
+
+ return self.__bounds.copy()
+
+ def prepareGL2(self, ctx):
+ # TODO manage _vbo and multiple GL context + allow to share them !
+ # TODO make one or multiple VBO depending on len(vertices),
+ # TODO use a general common VBO for small amount of data
+ for name in self._unsyncAttributes:
+ array = self._attributes[name]
+ self._vbos[name] = ctx.glCtx.makeVboAttrib(array)
+ self._unsyncAttributes = []
+
+ if self._indices is not None and '__indices__' not in self._vbos:
+ vbo = ctx.glCtx.makeVbo(self._indices,
+ usage=gl.GL_STATIC_DRAW,
+ target=gl.GL_ELEMENT_ARRAY_BUFFER)
+ self._vbos['__indices__'] = vbo
+
+ def _draw(self, program=None, nbVertices=None):
+ """Perform OpenGL draw calls.
+
+ :param GLProgram program:
+ If not None, call :meth:`useAttribute` for this program.
+ :param int nbVertices:
+ The number of vertices to render or None to render all vertices.
+ """
+ if program is not None:
+ self.useAttribute(program)
+
+ if self._indices is None:
+ if nbVertices is None:
+ nbVertices = self.nbVertices
+ gl.glDrawArrays(self._MODES[self._mode], 0, nbVertices)
+ else:
+ if nbVertices is None:
+ nbVertices = self._indices.size
+ with self._vbos['__indices__']:
+ gl.glDrawElements(self._MODES[self._mode],
+ nbVertices,
+ _glutils.numpyToGLType(self._indices.dtype),
+ ctypes.c_void_p(0))
+
+
+# Lines #######################################################################
+
+class Lines(Geometry):
+ """A set of segments"""
+ _shaders = ("""
+ attribute vec3 position;
+ attribute vec3 normal;
+ attribute vec4 color;
+
+ uniform mat4 matrix;
+ uniform mat4 transformMat;
+
+ varying vec4 vCameraPosition;
+ varying vec3 vPosition;
+ varying vec3 vNormal;
+ varying vec4 vColor;
+
+ void main(void)
+ {
+ gl_Position = matrix * vec4(position, 1.0);
+ vCameraPosition = transformMat * vec4(position, 1.0);
+ vPosition = position;
+ vNormal = normal;
+ vColor = color;
+ }
+ """,
+ string.Template("""
+ varying vec4 vCameraPosition;
+ varying vec3 vPosition;
+ varying vec3 vNormal;
+ varying vec4 vColor;
+
+ $sceneDecl
+ $lightingFunction
+
+ void main(void)
+ {
+ $scenePreCall(vCameraPosition);
+ gl_FragColor = $lightingCall(vColor, vPosition, vNormal);
+ $scenePostCall(vCameraPosition);
+ }
+ """))
+
+ def __init__(self, positions, normals=None, colors=(1., 1., 1., 1.),
+ indices=None, mode='lines', width=1.):
+ if mode == 'strip':
+ mode = 'line_strip'
+ assert mode in self._LINE_MODES
+
+ self._width = width
+ self._smooth = True
+
+ super(Lines, self).__init__(mode, indices,
+ position=positions,
+ normal=normals,
+ color=colors)
+
+ width = event.notifyProperty('_width', converter=float,
+ doc="Width of the line in pixels.")
+
+ smooth = event.notifyProperty(
+ '_smooth',
+ converter=bool,
+ doc="Smooth line rendering enabled (bool, default: True)")
+
+ def renderGL2(self, ctx):
+ # Prepare program
+ isnormals = 'normal' in self._attributes
+ if isnormals:
+ fraglightfunction = ctx.viewport.light.fragmentDef
+ else:
+ fraglightfunction = ctx.viewport.light.fragmentShaderFunctionNoop
+
+ fragment = self._shaders[1].substitute(
+ sceneDecl=ctx.fragDecl,
+ scenePreCall=ctx.fragCallPre,
+ scenePostCall=ctx.fragCallPost,
+ lightingFunction=fraglightfunction,
+ lightingCall=ctx.viewport.light.fragmentCall)
+ prog = ctx.glCtx.prog(self._shaders[0], fragment)
+ prog.use()
+
+ if isnormals:
+ ctx.viewport.light.setupProgram(ctx, prog)
+
+ gl.glLineWidth(self.width)
+
+ prog.setUniformMatrix('matrix', ctx.objectToNDC.matrix)
+ prog.setUniformMatrix('transformMat',
+ ctx.objectToCamera.matrix,
+ safe=True)
+
+ ctx.setupProgram(prog)
+
+ with gl.enabled(gl.GL_LINE_SMOOTH, self._smooth):
+ self._draw(prog)
+
+
+class DashedLines(Lines):
+ """Set of dashed lines
+
+ This MUST be defined as a set of lines (no strip or loop).
+ """
+
+ _shaders = ("""
+ attribute vec3 position;
+ attribute vec3 origin;
+ attribute vec3 normal;
+ attribute vec4 color;
+
+ uniform mat4 matrix;
+ uniform mat4 transformMat;
+ uniform vec2 viewportSize; /* Width, height of the viewport */
+
+ varying vec4 vCameraPosition;
+ varying vec3 vPosition;
+ varying vec3 vNormal;
+ varying vec4 vColor;
+ varying vec2 vOriginFragCoord;
+
+ void main(void)
+ {
+ gl_Position = matrix * vec4(position, 1.0);
+ vCameraPosition = transformMat * vec4(position, 1.0);
+ vPosition = position;
+ vNormal = normal;
+ vColor = color;
+
+ vec4 clipOrigin = matrix * vec4(origin, 1.0);
+ vec4 ndcOrigin = clipOrigin / clipOrigin.w; /* Perspective divide */
+ /* Convert to same frame as gl_FragCoord: lower-left, pixel center at 0.5, 0.5 */
+ vOriginFragCoord = (ndcOrigin.xy + vec2(1.0, 1.0)) * 0.5 * viewportSize + vec2(0.5, 0.5);
+ }
+ """, # noqa
+ string.Template("""
+ varying vec4 vCameraPosition;
+ varying vec3 vPosition;
+ varying vec3 vNormal;
+ varying vec4 vColor;
+ varying vec2 vOriginFragCoord;
+
+ uniform vec2 dash;
+
+ $sceneDecl
+ $lightingFunction
+
+ void main(void)
+ {
+ $scenePreCall(vCameraPosition);
+
+ /* Discard off dash fragments */
+ float lineDist = distance(vOriginFragCoord, gl_FragCoord.xy);
+ if (mod(lineDist, dash.x + dash.y) > dash.x) {
+ discard;
+ }
+ gl_FragColor = $lightingCall(vColor, vPosition, vNormal);
+
+ $scenePostCall(vCameraPosition);
+ }
+ """))
+
+ def __init__(self, positions, colors=(1., 1., 1., 1.),
+ indices=None, width=1.):
+ self._dash = 1, 0
+ super(DashedLines, self).__init__(positions=positions,
+ colors=colors,
+ indices=indices,
+ mode='lines',
+ width=width)
+
+ @property
+ def dash(self):
+ """Dash of the line as a 2-tuple of lengths in pixels: (on, off)"""
+ return self._dash
+
+ @dash.setter
+ def dash(self, dash):
+ dash = float(dash[0]), float(dash[1])
+ if dash != self._dash:
+ self._dash = dash
+ self.notify()
+
+ def getPositions(self, copy=True):
+ """Get coordinates of lines.
+
+ :param bool copy: True to get a copy, False otherwise
+ :returns: Coordinates of lines
+ :rtype: numpy.ndarray of float32 of shape (N, 2, Ndim)
+ """
+ return self.getAttribute('position', copy=copy)
+
+ def setPositions(self, positions, copy=True):
+ """Set line coordinates.
+
+ :param positions: Array of line coordinates
+ :param bool copy: True to copy input array, False to use as is
+ """
+ self.setAttribute('position', positions, copy=copy)
+ # Update line origins from given positions
+ origins = numpy.array(positions, copy=True, order='C')
+ origins[1::2] = origins[::2]
+ self.setAttribute('origin', origins, copy=False)
+
+ def renderGL2(self, context):
+ # Prepare program
+ isnormals = 'normal' in self._attributes
+ if isnormals:
+ fraglightfunction = context.viewport.light.fragmentDef
+ else:
+ fraglightfunction = \
+ context.viewport.light.fragmentShaderFunctionNoop
+
+ fragment = self._shaders[1].substitute(
+ sceneDecl=context.fragDecl,
+ scenePreCall=context.fragCallPre,
+ scenePostCall=context.fragCallPost,
+ lightingFunction=fraglightfunction,
+ lightingCall=context.viewport.light.fragmentCall)
+ program = context.glCtx.prog(self._shaders[0], fragment)
+ program.use()
+
+ if isnormals:
+ context.viewport.light.setupProgram(context, program)
+
+ gl.glLineWidth(self.width)
+
+ program.setUniformMatrix('matrix', context.objectToNDC.matrix)
+ program.setUniformMatrix('transformMat',
+ context.objectToCamera.matrix,
+ safe=True)
+
+ gl.glUniform2f(
+ program.uniforms['viewportSize'], *context.viewport.size)
+ gl.glUniform2f(program.uniforms['dash'], *self.dash)
+
+ context.setupProgram(program)
+
+ self._draw(program)
+
+
+class Box(core.PrivateGroup):
+ """Rectangular box"""
+
+ _lineIndices = numpy.array((
+ (0, 1), (1, 2), (2, 3), (3, 0), # Lines with z=0
+ (0, 4), (1, 5), (2, 6), (3, 7), # Lines from z=0 to z=1
+ (4, 5), (5, 6), (6, 7), (7, 4)), # Lines with z=1
+ dtype=numpy.uint8)
+
+ _faceIndices = numpy.array(
+ (0, 3, 1, 2, 5, 6, 4, 7, 7, 6, 6, 2, 7, 3, 4, 0, 5, 1),
+ dtype=numpy.uint8)
+
+ _vertices = numpy.array((
+ # Corners with z=0
+ (0., 0., 0.), (1., 0., 0.), (1., 1., 0.), (0., 1., 0.),
+ # Corners with z=1
+ (0., 0., 1.), (1., 0., 1.), (1., 1., 1.), (0., 1., 1.)),
+ dtype=numpy.float32)
+
+ def __init__(self, stroke=(1., 1., 1., 1.), fill=(1., 1., 1., 0.)):
+ super(Box, self).__init__()
+
+ self._fill = Mesh3D(self._vertices,
+ colors=rgba(fill),
+ mode='triangle_strip',
+ indices=self._faceIndices)
+ self._fill.visible = self.fillColor[-1] != 0.
+
+ self._stroke = Lines(self._vertices,
+ indices=self._lineIndices,
+ colors=rgba(stroke),
+ mode='lines')
+ self._stroke.visible = self.strokeColor[-1] != 0.
+ self.strokeWidth = 1.
+
+ self._children = [self._stroke, self._fill]
+
+ self._size = 1., 1., 1.
+
+ @classmethod
+ def getLineIndices(cls, copy=True):
+ """Returns 2D array of Box lines indices
+
+ :param copy: True (default) to get a copy,
+ False to get internal array (Do not modify!)
+ :rtype: numpy.ndarray
+ """
+ return numpy.array(cls._lineIndices, copy=copy)
+
+ @classmethod
+ def getVertices(cls, copy=True):
+ """Returns 2D array of Box corner coordinates.
+
+ :param copy: True (default) to get a copy,
+ False to get internal array (Do not modify!)
+ :rtype: numpy.ndarray
+ """
+ return numpy.array(cls._vertices, copy=copy)
+
+ @property
+ def size(self):
+ """Size of the box (sx, sy, sz)"""
+ return self._size
+
+ @size.setter
+ def size(self, size):
+ assert len(size) == 3
+ size = tuple(size)
+ if size != self.size:
+ self._size = size
+ self._fill.setAttribute(
+ 'position',
+ self._vertices * numpy.array(size, dtype=numpy.float32))
+ self._stroke.setAttribute(
+ 'position',
+ self._vertices * numpy.array(size, dtype=numpy.float32))
+ self.notify()
+
+ @property
+ def strokeSmooth(self):
+ """True to draw smooth stroke, False otherwise"""
+ return self._stroke.smooth
+
+ @strokeSmooth.setter
+ def strokeSmooth(self, smooth):
+ smooth = bool(smooth)
+ if smooth != self._stroke.smooth:
+ self._stroke.smooth = smooth
+ self.notify()
+
+ @property
+ def strokeWidth(self):
+ """Width of the stroke (float)"""
+ return self._stroke.width
+
+ @strokeWidth.setter
+ def strokeWidth(self, width):
+ width = float(width)
+ if width != self.strokeWidth:
+ self._stroke.width = width
+ self.notify()
+
+ @property
+ def strokeColor(self):
+ """RGBA color of the box lines (4-tuple of float in [0, 1])"""
+ return tuple(self._stroke.getAttribute('color', copy=False))
+
+ @strokeColor.setter
+ def strokeColor(self, color):
+ color = rgba(color)
+ if color != self.strokeColor:
+ self._stroke.setAttribute('color', color)
+ # Fully transparent = hidden
+ self._stroke.visible = color[-1] != 0.
+ self.notify()
+
+ @property
+ def fillColor(self):
+ """RGBA color of the box faces (4-tuple of float in [0, 1])"""
+ return tuple(self._fill.getAttribute('color', copy=False))
+
+ @fillColor.setter
+ def fillColor(self, color):
+ color = rgba(color)
+ if color != self.fillColor:
+ self._fill.setAttribute('color', color)
+ # Fully transparent = hidden
+ self._fill.visible = color[-1] != 0.
+ self.notify()
+
+ @property
+ def fillCulling(self):
+ return self._fill.culling
+
+ @fillCulling.setter
+ def fillCulling(self, culling):
+ self._fill.culling = culling
+
+
+class Axes(Lines):
+ """3D RGB orthogonal axes"""
+ _vertices = numpy.array(((0., 0., 0.), (1., 0., 0.),
+ (0., 0., 0.), (0., 1., 0.),
+ (0., 0., 0.), (0., 0., 1.)),
+ dtype=numpy.float32)
+
+ _colors = numpy.array(((255, 0, 0, 255), (255, 0, 0, 255),
+ (0, 255, 0, 255), (0, 255, 0, 255),
+ (0, 0, 255, 255), (0, 0, 255, 255)),
+ dtype=numpy.uint8)
+
+ def __init__(self):
+ super(Axes, self).__init__(self._vertices,
+ colors=self._colors,
+ width=3.)
+ self._size = 1., 1., 1.
+
+ @property
+ def size(self):
+ """Size of the axes (sx, sy, sz)"""
+ return self._size
+
+ @size.setter
+ def size(self, size):
+ assert len(size) == 3
+ size = tuple(size)
+ if size != self.size:
+ self._size = size
+ self.setAttribute(
+ 'position',
+ self._vertices * numpy.array(size, dtype=numpy.float32))
+ self.notify()
+
+
+class BoxWithAxes(Lines):
+ """Rectangular box with RGB OX, OY, OZ axes
+
+ :param color: RGBA color of the box
+ """
+
+ _vertices = numpy.array((
+ # Axes corners
+ (0., 0., 0.), (1., 0., 0.),
+ (0., 0., 0.), (0., 1., 0.),
+ (0., 0., 0.), (0., 0., 1.),
+ # Box corners with z=0
+ (1., 0., 0.), (1., 1., 0.), (0., 1., 0.),
+ # Box corners with z=1
+ (0., 0., 1.), (1., 0., 1.), (1., 1., 1.), (0., 1., 1.)),
+ dtype=numpy.float32)
+
+ _axesColors = numpy.array(((1., 0., 0., 1.), (1., 0., 0., 1.),
+ (0., 1., 0., 1.), (0., 1., 0., 1.),
+ (0., 0., 1., 1.), (0., 0., 1., 1.)),
+ dtype=numpy.float32)
+
+ _lineIndices = numpy.array((
+ (0, 1), (2, 3), (4, 5), # Axes lines
+ (6, 7), (7, 8), # Box lines with z=0
+ (6, 10), (7, 11), (8, 12), # Box lines from z=0 to z=1
+ (9, 10), (10, 11), (11, 12), (12, 9)), # Box lines with z=1
+ dtype=numpy.uint8)
+
+ def __init__(self, color=(1., 1., 1., 1.)):
+ self._color = (1., 1., 1., 1.)
+ colors = numpy.ones((len(self._vertices), 4), dtype=numpy.float32)
+ colors[:len(self._axesColors), :] = self._axesColors
+
+ super(BoxWithAxes, self).__init__(self._vertices,
+ indices=self._lineIndices,
+ colors=colors,
+ width=2.)
+ self._size = 1., 1., 1.
+ self.color = color
+
+ @property
+ def color(self):
+ """The RGBA color to use for the box: 4 float in [0, 1]"""
+ return self._color
+
+ @color.setter
+ def color(self, color):
+ color = rgba(color)
+ if color != self._color:
+ self._color = color
+ colors = numpy.empty((len(self._vertices), 4), dtype=numpy.float32)
+ colors[:len(self._axesColors), :] = self._axesColors
+ colors[len(self._axesColors):, :] = self._color
+ self.setAttribute('color', colors) # Do the notification
+
+ @property
+ def size(self):
+ """Size of the axes (sx, sy, sz)"""
+ return self._size
+
+ @size.setter
+ def size(self, size):
+ assert len(size) == 3
+ size = tuple(size)
+ if size != self.size:
+ self._size = size
+ self.setAttribute(
+ 'position',
+ self._vertices * numpy.array(size, dtype=numpy.float32))
+ self.notify()
+
+
+class PlaneInGroup(core.PrivateGroup):
+ """A plane using its parent bounds to display a contour.
+
+ If plane is outside the bounds of its parent, it is not visible.
+
+ Cannot set the transform attribute of this primitive.
+ This primitive never has any bounds.
+ """
+ # TODO inherit from Lines directly?, make sure the plane remains visible?
+
+ def __init__(self, point=(0., 0., 0.), normal=(0., 0., 1.)):
+ super(PlaneInGroup, self).__init__()
+ self._cache = None, None # Store bounds, vertices
+ self._outline = None
+
+ self._color = None
+ self.color = 1., 1., 1., 1. # Set _color
+ self._width = 2.
+ self._strokeVisible = True
+
+ self._plane = utils.Plane(point, normal)
+ self._plane.addListener(self._planeChanged)
+
+ def moveToCenter(self):
+ """Place the plane at the center of the data, not changing orientation.
+ """
+ if self.parent is not None:
+ bounds = self.parent.bounds(dataBounds=True)
+ if bounds is not None:
+ center = (bounds[0] + bounds[1]) / 2.
+ _logger.debug('Moving plane to center: %s', str(center))
+ self.plane.point = center
+
+ @property
+ def color(self):
+ """Plane outline color (array of 4 float in [0, 1])."""
+ return self._color.copy()
+
+ @color.setter
+ def color(self, color):
+ self._color = numpy.array(color, copy=True, dtype=numpy.float32)
+ if self._outline is not None:
+ self._outline.setAttribute('color', self._color)
+ self.notify() # This is OK as Lines are rebuild for each rendering
+
+ @property
+ def width(self):
+ """Width of the plane stroke in pixels"""
+ return self._width
+
+ @width.setter
+ def width(self, width):
+ self._width = float(width)
+ if self._outline is not None:
+ self._outline.width = self._width # Sync width
+
+ @property
+ def strokeVisible(self):
+ """Whether surrounding stroke is visible or not (bool)."""
+ return self._strokeVisible
+
+ @strokeVisible.setter
+ def strokeVisible(self, visible):
+ self._strokeVisible = bool(visible)
+ if self._outline is not None:
+ self._outline.visible = self._strokeVisible
+
+ # Plane access
+
+ @property
+ def plane(self):
+ """The plane parameters in the frame of the object."""
+ return self._plane
+
+ def _planeChanged(self, source):
+ """Listener of plane changes: clear cache and notify listeners."""
+ self._cache = None, None
+ self.notify()
+
+ # Disable some scene features
+
+ @property
+ def transforms(self):
+ # Ready-only transforms to prevent using it
+ return self._transforms
+
+ def _bounds(self, dataBounds=False):
+ # This is bound less as it uses the bounds of its parent.
+ return None
+
+ @property
+ def contourVertices(self):
+ """The vertices of the contour of the plane/bounds intersection."""
+ parent = self.parent
+ if parent is None:
+ return None # No parent: no vertices
+
+ bounds = parent.bounds(dataBounds=True)
+ if bounds is None:
+ return None # No bounds: no vertices
+
+ # Check if cache is valid and return it
+ cachebounds, cachevertices = self._cache
+ if numpy.all(numpy.equal(bounds, cachebounds)):
+ return cachevertices
+
+ # Cache is not OK, rebuild it
+ boxVertices = Box.getVertices(copy=True)
+ boxVertices = bounds[0] + boxVertices * (bounds[1] - bounds[0])
+ lineIndices = Box.getLineIndices(copy=False)
+ vertices = utils.boxPlaneIntersect(
+ boxVertices, lineIndices, self.plane.normal, self.plane.point)
+
+ self._cache = bounds, vertices if len(vertices) != 0 else None
+
+ return self._cache[1]
+
+ @property
+ def center(self):
+ """The center of the plane/bounds intersection points."""
+ if not self.isValid:
+ return None
+ else:
+ return numpy.mean(self.contourVertices, axis=0)
+
+ @property
+ def isValid(self):
+ """True if a contour is defined, False otherwise."""
+ return self.plane.isPlane and self.contourVertices is not None
+
+ def prepareGL2(self, ctx):
+ if self.isValid:
+ if self._outline is None: # Init outline
+ self._outline = Lines(self.contourVertices,
+ mode='loop',
+ colors=self.color)
+ self._outline.width = self._width
+ self._outline.visible = self._strokeVisible
+ self._children.append(self._outline)
+
+ # Update vertices, TODO only when necessary
+ self._outline.setAttribute('position', self.contourVertices)
+
+ super(PlaneInGroup, self).prepareGL2(ctx)
+
+ def renderGL2(self, ctx):
+ if self.isValid:
+ super(PlaneInGroup, self).renderGL2(ctx)
+
+
+class BoundedGroup(core.Group):
+ """Group with data bounds"""
+
+ _shape = None # To provide a default value without overriding __init__
+
+ @property
+ def shape(self):
+ """Data shape (depth, height, width) of this group or None"""
+ return self._shape
+
+ @shape.setter
+ def shape(self, shape):
+ if shape is None:
+ self._shape = None
+ else:
+ depth, height, width = shape
+ self._shape = float(depth), float(height), float(width)
+
+ @property
+ def size(self):
+ """Data size (width, height, depth) of this group or None"""
+ shape = self.shape
+ if shape is None:
+ return None
+ else:
+ return shape[2], shape[1], shape[0]
+
+ @size.setter
+ def size(self, size):
+ if size is None:
+ self.shape = None
+ else:
+ self.shape = size[2], size[1], size[0]
+
+ def _bounds(self, dataBounds=False):
+ if dataBounds and self.size is not None:
+ return numpy.array(((0., 0., 0.), self.size),
+ dtype=numpy.float32)
+ else:
+ return super(BoundedGroup, self)._bounds(dataBounds)
+
+
+# Points ######################################################################
+
+class _Points(Geometry):
+ """Base class to render a set of points."""
+
+ DIAMOND = 'd'
+ CIRCLE = 'o'
+ SQUARE = 's'
+ PLUS = '+'
+ X_MARKER = 'x'
+ ASTERISK = '*'
+ H_LINE = '_'
+ V_LINE = '|'
+
+ SUPPORTED_MARKERS = (DIAMOND, CIRCLE, SQUARE, PLUS,
+ X_MARKER, ASTERISK, H_LINE, V_LINE)
+ """List of supported markers:
+
+ - 'd' diamond
+ - 'o' circle
+ - 's' square
+ - '+' cross
+ - 'x' x-cross
+ - '*' asterisk
+ - '_' horizontal line
+ - '|' vertical line
+ """
+
+ _MARKER_FUNCTIONS = {
+ DIAMOND: """
+ float alphaSymbol(vec2 coord, float size) {
+ vec2 centerCoord = abs(coord - vec2(0.5, 0.5));
+ float f = centerCoord.x + centerCoord.y;
+ return clamp(size * (0.5 - f), 0.0, 1.0);
+ }
+ """,
+ CIRCLE: """
+ float alphaSymbol(vec2 coord, float size) {
+ float radius = 0.5;
+ float r = distance(coord, vec2(0.5, 0.5));
+ return clamp(size * (radius - r), 0.0, 1.0);
+ }
+ """,
+ SQUARE: """
+ float alphaSymbol(vec2 coord, float size) {
+ return 1.0;
+ }
+ """,
+ PLUS: """
+ float alphaSymbol(vec2 coord, float size) {
+ vec2 d = abs(size * (coord - vec2(0.5, 0.5)));
+ if (min(d.x, d.y) < 0.5) {
+ return 1.0;
+ } else {
+ return 0.0;
+ }
+ }
+ """,
+ X_MARKER: """
+ float alphaSymbol(vec2 coord, float size) {
+ vec2 pos = floor(size * coord) + 0.5;
+ vec2 d_x = abs(pos.x + vec2(- pos.y, pos.y - size));
+ if (min(d_x.x, d_x.y) <= 0.5) {
+ return 1.0;
+ } else {
+ return 0.0;
+ }
+ }
+ """,
+ ASTERISK: """
+ float alphaSymbol(vec2 coord, float size) {
+ /* Combining +, x and circle */
+ vec2 d_plus = abs(size * (coord - vec2(0.5, 0.5)));
+ vec2 pos = floor(size * coord) + 0.5;
+ vec2 d_x = abs(pos.x + vec2(- pos.y, pos.y - size));
+ if (min(d_plus.x, d_plus.y) < 0.5) {
+ return 1.0;
+ } else if (min(d_x.x, d_x.y) <= 0.5) {
+ float r = distance(coord, vec2(0.5, 0.5));
+ return clamp(size * (0.5 - r), 0.0, 1.0);
+ } else {
+ return 0.0;
+ }
+ }
+ """,
+ H_LINE: """
+ float alphaSymbol(vec2 coord, float size) {
+ float dy = abs(size * (coord.y - 0.5));
+ if (dy < 0.5) {
+ return 1.0;
+ } else {
+ return 0.0;
+ }
+ }
+ """,
+ V_LINE: """
+ float alphaSymbol(vec2 coord, float size) {
+ float dx = abs(size * (coord.x - 0.5));
+ if (dx < 0.5) {
+ return 1.0;
+ } else {
+ return 0.0;
+ }
+ }
+ """
+ }
+
+ _shaders = (string.Template("""
+ #version 120
+
+ attribute float x;
+ attribute float y;
+ attribute float z;
+ attribute $valueType value;
+ attribute float size;
+
+ uniform mat4 matrix;
+ uniform mat4 transformMat;
+
+ varying vec4 vCameraPosition;
+ varying $valueType vValue;
+ varying float vSize;
+
+ void main(void)
+ {
+ vValue = value;
+
+ vec4 positionVec4 = vec4(x, y, z, 1.0);
+ gl_Position = matrix * positionVec4;
+ vCameraPosition = transformMat * positionVec4;
+
+ gl_PointSize = size;
+ vSize = size;
+ }
+ """),
+ string.Template("""
+ #version 120
+
+ varying vec4 vCameraPosition;
+ varying float vSize;
+ varying $valueType vValue;
+
+ $valueToColorDecl
+ $sceneDecl
+ $alphaSymbolDecl
+
+ void main(void)
+ {
+ $scenePreCall(vCameraPosition);
+
+ float alpha = alphaSymbol(gl_PointCoord, vSize);
+
+ gl_FragColor = $valueToColorCall(vValue);
+ gl_FragColor.a *= alpha;
+ if (gl_FragColor.a == 0.0) {
+ discard;
+ }
+
+ $scenePostCall(vCameraPosition);
+ }
+ """))
+
+ _ATTR_INFO = {
+ 'x': {'dims': (1, 2), 'lastDim': (1,)},
+ 'y': {'dims': (1, 2), 'lastDim': (1,)},
+ 'z': {'dims': (1, 2), 'lastDim': (1,)},
+ 'size': {'dims': (1, 2), 'lastDim': (1,)},
+ }
+
+ def __init__(self, x, y, z, value, size=1., indices=None):
+ super(_Points, self).__init__('points', indices,
+ x=x,
+ y=y,
+ z=z,
+ value=value,
+ size=size,
+ attrib0='x')
+ self.boundsAttributeNames = 'x', 'y', 'z'
+ self._marker = 'o'
+
+ @property
+ def marker(self):
+ """The marker symbol used to display the scatter plot (str)
+
+ See :attr:`SUPPORTED_MARKERS` for the list of supported marker string.
+ """
+ return self._marker
+
+ @marker.setter
+ def marker(self, marker):
+ marker = str(marker)
+ assert marker in self.SUPPORTED_MARKERS
+ if marker != self._marker:
+ self._marker = marker
+ self.notify()
+
+ def _shaderValueDefinition(self):
+ """Type definition, fragment shader declaration, fragment shader call
+ """
+ raise NotImplementedError(
+ "This method must be implemented in subclass")
+
+ def _renderGL2PreDrawHook(self, ctx, program):
+ """Override in subclass to run code before calling gl draw"""
+ pass
+
+ def renderGL2(self, ctx):
+ valueType, valueToColorDecl, valueToColorCall = \
+ self._shaderValueDefinition()
+ vertexShader = self._shaders[0].substitute(
+ valueType=valueType)
+ fragmentShader = self._shaders[1].substitute(
+ sceneDecl=ctx.fragDecl,
+ scenePreCall=ctx.fragCallPre,
+ scenePostCall=ctx.fragCallPost,
+ valueType=valueType,
+ valueToColorDecl=valueToColorDecl,
+ valueToColorCall=valueToColorCall,
+ alphaSymbolDecl=self._MARKER_FUNCTIONS[self.marker])
+ program = ctx.glCtx.prog(vertexShader, fragmentShader,
+ attrib0=self.attrib0)
+ program.use()
+
+ gl.glEnable(gl.GL_VERTEX_PROGRAM_POINT_SIZE) # OpenGL 2
+ gl.glEnable(gl.GL_POINT_SPRITE) # OpenGL 2
+ # gl.glEnable(gl.GL_PROGRAM_POINT_SIZE)
+
+ program.setUniformMatrix('matrix', ctx.objectToNDC.matrix)
+ program.setUniformMatrix('transformMat',
+ ctx.objectToCamera.matrix,
+ safe=True)
+
+ ctx.setupProgram(program)
+
+ self._renderGL2PreDrawHook(ctx, program)
+
+ self._draw(program)
+
+
+class Points(_Points):
+ """A set of data points with an associated value and size."""
+
+ _ATTR_INFO = _Points._ATTR_INFO.copy()
+ _ATTR_INFO.update({'value': {'dims': (1, 2), 'lastDim': (1,)}})
+
+ def __init__(self, x, y, z, value=0., size=1.,
+ indices=None, colormap=None):
+ super(Points, self).__init__(x=x,
+ y=y,
+ z=z,
+ indices=indices,
+ size=size,
+ value=value)
+
+ self._colormap = colormap or Colormap() # Default colormap
+ self._colormap.addListener(self._cmapChanged)
+
+ @property
+ def colormap(self):
+ """The colormap used to render the image"""
+ return self._colormap
+
+ def _cmapChanged(self, source, *args, **kwargs):
+ """Broadcast colormap changes"""
+ self.notify(*args, **kwargs)
+
+ def _shaderValueDefinition(self):
+ """Type definition, fragment shader declaration, fragment shader call
+ """
+ return 'float', self.colormap.decl, self.colormap.call
+
+ def _renderGL2PreDrawHook(self, ctx, program):
+ """Set-up colormap before calling gl draw"""
+ self.colormap.setupProgram(ctx, program)
+
+
+class ColorPoints(_Points):
+ """A set of points with an associated color and size."""
+
+ _ATTR_INFO = _Points._ATTR_INFO.copy()
+ _ATTR_INFO.update({'value': {'dims': (1, 2), 'lastDim': (3, 4)}})
+
+ def __init__(self, x, y, z, color=(1., 1., 1., 1.), size=1.,
+ indices=None):
+ super(ColorPoints, self).__init__(x=x,
+ y=y,
+ z=z,
+ indices=indices,
+ size=size,
+ value=color)
+
+ def _shaderValueDefinition(self):
+ """Type definition, fragment shader declaration, fragment shader call
+ """
+ return 'vec4', '', ''
+
+ def setColor(self, color, copy=True):
+ """Set colors
+
+ :param color: Single RGBA color or
+ 2D array of color of length number of points
+ :param bool copy: True to copy colors (default),
+ False to use provided array (Do not modify!)
+ """
+ self.setAttribute('value', color, copy=copy)
+
+ def getColor(self, copy=True):
+ """Returns the color or array of colors of the points.
+
+ :param copy: True to get a copy (default),
+ False to return internal array (Do not modify!)
+ :return: Color or array of colors
+ :rtype: numpy.ndarray
+ """
+ return self.getAttribute('value', copy=copy)
+
+
+class GridPoints(Geometry):
+ # GLSL 1.30 !
+ """Data points on a regular grid with an associated value and size."""
+ _shaders = ("""
+ #version 130
+
+ in float value;
+ in float size;
+
+ uniform ivec3 gridDims;
+ uniform mat4 matrix;
+ uniform mat4 transformMat;
+ uniform vec2 valRange;
+
+ out vec4 vCameraPosition;
+ out float vNormValue;
+
+ //ivec3 coordsFromIndex(int index, ivec3 shape)
+ //{
+ /*Assumes that data is stored as z-major, then y, contiguous on x
+ */
+ // int yxPlaneSize = shape.y * shape.x; /* nb of elem in 2d yx plane */
+ // int z = index / yxPlaneSize;
+ // int yxIndex = index - z * yxPlaneSize; /* index in 2d yx plane */
+ // int y = yxIndex / shape.x;
+ // int x = yxIndex - y * shape.x;
+ // return ivec3(x, y, z);
+ // }
+
+ ivec3 coordsFromIndex(int index, ivec3 shape)
+ {
+ /*Assumes that data is stored as x-major, then y, contiguous on z
+ */
+ int yzPlaneSize = shape.y * shape.z; /* nb of elem in 2d yz plane */
+ int x = index / yzPlaneSize;
+ int yzIndex = index - x * yzPlaneSize; /* index in 2d yz plane */
+ int y = yzIndex / shape.z;
+ int z = yzIndex - y * shape.z;
+ return ivec3(x, y, z);
+ }
+
+ void main(void)
+ {
+ vNormValue = clamp((value - valRange.x) / (valRange.y - valRange.x),
+ 0.0, 1.0);
+
+ bool isValueInRange = value >= valRange.x && value <= valRange.y;
+ if (isValueInRange) {
+ /* Retrieve 3D position from gridIndex */
+ vec3 coords = vec3(coordsFromIndex(gl_VertexID, gridDims));
+ vec3 position = coords / max(vec3(gridDims) - 1.0, 1.0);
+ gl_Position = matrix * vec4(position, 1.0);
+ vCameraPosition = transformMat * vec4(position, 1.0);
+ } else {
+ gl_Position = vec4(2.0, 0.0, 0.0, 1.0); /* Get clipped */
+ vCameraPosition = vec4(0.0, 0.0, 0.0, 0.0);
+ }
+
+ gl_PointSize = size;
+ }
+ """,
+ string.Template("""
+ #version 130
+
+ in vec4 vCameraPosition;
+ in float vNormValue;
+ out vec4 gl_FragColor;
+
+ $sceneDecl
+
+ void main(void)
+ {
+ $scenePreCall(vCameraPosition);
+
+ gl_FragColor = vec4(0.5 * vNormValue + 0.5, 0.0, 0.0, 1.0);
+
+ $scenePostCall(vCameraPosition);
+ }
+ """))
+
+ _ATTR_INFO = {
+ 'value': {'dims': (1, 2), 'lastDim': (1,)},
+ 'size': {'dims': (1, 2), 'lastDim': (1,)}
+ }
+
+ # TODO Add colormap, shape?
+ # TODO could also use a texture to store values
+
+ def __init__(self, values=0., shape=None, sizes=1., indices=None,
+ minValue=None, maxValue=None):
+ if isinstance(values, abc.Iterable):
+ values = numpy.array(values, copy=False)
+
+ # Test if gl_VertexID will overflow
+ assert values.size < numpy.iinfo(numpy.int32).max
+
+ self._shape = values.shape
+ values = values.ravel() # 1D to add as a 1D vertex attribute
+
+ else:
+ assert shape is not None
+ self._shape = tuple(shape)
+
+ assert len(self._shape) in (1, 2, 3)
+
+ super(GridPoints, self).__init__('points', indices,
+ value=values,
+ size=sizes)
+
+ data = self.getAttribute('value', copy=False)
+ self._minValue = data.min() if minValue is None else minValue
+ self._maxValue = data.max() if maxValue is None else maxValue
+
+ minValue = event.notifyProperty('_minValue')
+ maxValue = event.notifyProperty('_maxValue')
+
+ def _bounds(self, dataBounds=False):
+ # Get bounds from values shape
+ bounds = numpy.zeros((2, 3), dtype=numpy.float32)
+ bounds[1, :] = self._shape
+ bounds[1, :] -= 1
+ return bounds
+
+ def renderGL2(self, ctx):
+ fragment = self._shaders[1].substitute(
+ sceneDecl=ctx.fragDecl,
+ scenePreCall=ctx.fragCallPre,
+ scenePostCall=ctx.fragCallPost)
+ prog = ctx.glCtx.prog(self._shaders[0], fragment)
+ prog.use()
+
+ gl.glEnable(gl.GL_VERTEX_PROGRAM_POINT_SIZE) # OpenGL 2
+ gl.glEnable(gl.GL_POINT_SPRITE) # OpenGL 2
+ # gl.glEnable(gl.GL_PROGRAM_POINT_SIZE)
+
+ prog.setUniformMatrix('matrix', ctx.objectToNDC.matrix)
+ prog.setUniformMatrix('transformMat',
+ ctx.objectToCamera.matrix,
+ safe=True)
+
+ ctx.setupProgram(prog)
+
+ gl.glUniform3i(prog.uniforms['gridDims'],
+ self._shape[2] if len(self._shape) == 3 else 1,
+ self._shape[1] if len(self._shape) >= 2 else 1,
+ self._shape[0])
+
+ gl.glUniform2f(prog.uniforms['valRange'], self.minValue, self.maxValue)
+
+ self._draw(prog, nbVertices=reduce(lambda a, b: a * b, self._shape))
+
+
+# Spheres #####################################################################
+
+class Spheres(Geometry):
+ """A set of spheres.
+
+ Spheres are rendered as circles using points.
+ This brings some limitations:
+ - Do not support non-uniform scaling.
+ - Assume the projection keeps ratio.
+ - Do not render distorion by perspective projection.
+ - If the sphere center is clipped, the whole sphere is not displayed.
+ """
+ # TODO check those links
+ # Accounting for perspective projection
+ # http://iquilezles.org/www/articles/sphereproj/sphereproj.htm
+
+ # Michael Mara and Morgan McGuire.
+ # 2D Polyhedral Bounds of a Clipped, Perspective-Projected 3D Sphere
+ # Journal of Computer Graphics Techniques, Vol. 2, No. 2, 2013.
+ # http://jcgt.org/published/0002/02/05/paper.pdf
+ # https://research.nvidia.com/publication/2d-polyhedral-bounds-clipped-perspective-projected-3d-sphere
+
+ # TODO some issues with small scaling and regular grid or due to sampling
+
+ _shaders = ("""
+ #version 120
+
+ attribute vec3 position;
+ attribute vec4 color;
+ attribute float radius;
+
+ uniform mat4 transformMat;
+ uniform mat4 projMat;
+ uniform vec2 screenSize;
+
+ varying vec4 vCameraPosition;
+ varying vec3 vPosition;
+ varying vec4 vColor;
+ varying float vViewDepth;
+ varying float vViewRadius;
+
+ void main(void)
+ {
+ vCameraPosition = transformMat * vec4(position, 1.0);
+ gl_Position = projMat * vCameraPosition;
+
+ vPosition = gl_Position.xyz / gl_Position.w;
+
+ /* From object space radius to view space diameter.
+ * Do not support non-uniform scaling */
+ vec4 viewSizeVector = transformMat * vec4(2.0 * radius, 0.0, 0.0, 0.0);
+ float viewSize = length(viewSizeVector.xyz);
+
+ /* Convert to pixel size at the xy center of the view space */
+ vec4 projSize = projMat * vec4(0.5 * viewSize, 0.0,
+ vCameraPosition.z, vCameraPosition.w);
+ gl_PointSize = max(1.0, screenSize[0] * projSize.x / projSize.w);
+
+ vColor = color;
+ vViewRadius = 0.5 * viewSize;
+ vViewDepth = vCameraPosition.z;
+ }
+ """,
+ string.Template("""
+ # version 120
+
+ uniform mat4 projMat;
+
+ varying vec4 vCameraPosition;
+ varying vec3 vPosition;
+ varying vec4 vColor;
+ varying float vViewDepth;
+ varying float vViewRadius;
+
+ $sceneDecl
+ $lightingFunction
+
+ void main(void)
+ {
+ $scenePreCall(vCameraPosition);
+
+ /* Get normal from point coords */
+ vec3 normal;
+ normal.xy = 2.0 * gl_PointCoord - vec2(1.0);
+ normal.y *= -1.0; /*Invert y to match NDC orientation*/
+ float sqLength = dot(normal.xy, normal.xy);
+ if (sqLength > 1.0) { /* Length -> out of sphere */
+ discard;
+ }
+ normal.z = sqrt(1.0 - sqLength);
+
+ /*Lighting performed in NDC*/
+ /*TODO update this when lighting changed*/
+ //XXX vec3 position = vPosition + vViewRadius * normal;
+ gl_FragColor = $lightingCall(vColor, vPosition, normal);
+
+ /*Offset depth*/
+ float viewDepth = vViewDepth + vViewRadius * normal.z;
+ vec2 clipZW = viewDepth * projMat[2].zw + projMat[3].zw;
+ gl_FragDepth = 0.5 * (clipZW.x / clipZW.y) + 0.5;
+
+ $scenePostCall(vCameraPosition);
+ }
+ """))
+
+ _ATTR_INFO = {
+ 'position': {'dims': (2, ), 'lastDim': (2, 3, 4)},
+ 'radius': {'dims': (1, 2), 'lastDim': (1, )},
+ 'color': {'dims': (1, 2), 'lastDim': (3, 4)},
+ }
+
+ def __init__(self, positions, radius=1., colors=(1., 1., 1., 1.)):
+ self.__bounds = None
+ super(Spheres, self).__init__('points', None,
+ position=positions,
+ radius=radius,
+ color=colors)
+
+ def renderGL2(self, ctx):
+ fragment = self._shaders[1].substitute(
+ sceneDecl=ctx.fragDecl,
+ scenePreCall=ctx.fragCallPre,
+ scenePostCall=ctx.fragCallPost,
+ lightingFunction=ctx.viewport.light.fragmentDef,
+ lightingCall=ctx.viewport.light.fragmentCall)
+ prog = ctx.glCtx.prog(self._shaders[0], fragment)
+ prog.use()
+
+ ctx.viewport.light.setupProgram(ctx, prog)
+
+ gl.glEnable(gl.GL_VERTEX_PROGRAM_POINT_SIZE) # OpenGL 2
+ gl.glEnable(gl.GL_POINT_SPRITE) # OpenGL 2
+ # gl.glEnable(gl.GL_PROGRAM_POINT_SIZE)
+
+ prog.setUniformMatrix('projMat', ctx.projection.matrix)
+ prog.setUniformMatrix('transformMat',
+ ctx.objectToCamera.matrix,
+ safe=True)
+
+ ctx.setupProgram(prog)
+
+ gl.glUniform2f(prog.uniforms['screenSize'], *ctx.viewport.size)
+
+ self._draw(prog)
+
+ def _bounds(self, dataBounds=False):
+ if self.__bounds is None:
+ self.__bounds = numpy.zeros((2, 3), dtype=numpy.float32)
+ # Support vertex with to 2 to 4 coordinates
+ positions = self._attributes['position']
+ radius = self._attributes['radius']
+ self.__bounds[0, :positions.shape[1]] = \
+ (positions - radius).min(axis=0)[:3]
+ self.__bounds[1, :positions.shape[1]] = \
+ (positions + radius).max(axis=0)[:3]
+ return self.__bounds.copy()
+
+
+# Meshes ######################################################################
+
+class Mesh3D(Geometry):
+ """A conventional 3D mesh"""
+
+ _shaders = ("""
+ attribute vec3 position;
+ attribute vec3 normal;
+ attribute vec4 color;
+
+ uniform mat4 matrix;
+ uniform mat4 transformMat;
+ //uniform mat3 matrixInvTranspose;
+
+ varying vec4 vCameraPosition;
+ varying vec3 vPosition;
+ varying vec3 vNormal;
+ varying vec4 vColor;
+
+ void main(void)
+ {
+ vCameraPosition = transformMat * vec4(position, 1.0);
+ //vNormal = matrixInvTranspose * normalize(normal);
+ vPosition = position;
+ vNormal = normal;
+ vColor = color;
+ gl_Position = matrix * vec4(position, 1.0);
+ }
+ """,
+ string.Template("""
+ varying vec4 vCameraPosition;
+ varying vec3 vPosition;
+ varying vec3 vNormal;
+ varying vec4 vColor;
+
+ $sceneDecl
+ $lightingFunction
+
+ void main(void)
+ {
+ $scenePreCall(vCameraPosition);
+
+ gl_FragColor = $lightingCall(vColor, vPosition, vNormal);
+
+ $scenePostCall(vCameraPosition);
+ }
+ """))
+
+ def __init__(self,
+ positions,
+ colors,
+ normals=None,
+ mode='triangles',
+ indices=None,
+ copy=True):
+ assert mode in self._TRIANGLE_MODES
+ super(Mesh3D, self).__init__(mode, indices,
+ position=positions,
+ normal=normals,
+ color=colors,
+ copy=copy)
+
+ self._culling = None
+
+ @property
+ def culling(self):
+ """Face culling (str)
+
+ One of 'back', 'front' or None.
+ """
+ return self._culling
+
+ @culling.setter
+ def culling(self, culling):
+ assert culling in ('back', 'front', None)
+ if culling != self._culling:
+ self._culling = culling
+ self.notify()
+
+ def renderGL2(self, ctx):
+ isnormals = 'normal' in self._attributes
+ if isnormals:
+ fragLightFunction = ctx.viewport.light.fragmentDef
+ else:
+ fragLightFunction = ctx.viewport.light.fragmentShaderFunctionNoop
+
+ fragment = self._shaders[1].substitute(
+ sceneDecl=ctx.fragDecl,
+ scenePreCall=ctx.fragCallPre,
+ scenePostCall=ctx.fragCallPost,
+ lightingFunction=fragLightFunction,
+ lightingCall=ctx.viewport.light.fragmentCall)
+ prog = ctx.glCtx.prog(self._shaders[0], fragment)
+ prog.use()
+
+ if isnormals:
+ ctx.viewport.light.setupProgram(ctx, prog)
+
+ if self.culling is not None:
+ cullFace = gl.GL_FRONT if self.culling == 'front' else gl.GL_BACK
+ gl.glCullFace(cullFace)
+ gl.glEnable(gl.GL_CULL_FACE)
+
+ prog.setUniformMatrix('matrix', ctx.objectToNDC.matrix)
+ prog.setUniformMatrix('transformMat',
+ ctx.objectToCamera.matrix,
+ safe=True)
+
+ ctx.setupProgram(prog)
+
+ self._draw(prog)
+
+ if self.culling is not None:
+ gl.glDisable(gl.GL_CULL_FACE)
+
+
+class ColormapMesh3D(Geometry):
+ """A 3D mesh with color computed from a colormap"""
+
+ _shaders = ("""
+ attribute vec3 position;
+ attribute vec3 normal;
+ attribute float value;
+
+ uniform mat4 matrix;
+ uniform mat4 transformMat;
+ //uniform mat3 matrixInvTranspose;
+
+ varying vec4 vCameraPosition;
+ varying vec3 vPosition;
+ varying vec3 vNormal;
+ varying float vValue;
+
+ void main(void)
+ {
+ vCameraPosition = transformMat * vec4(position, 1.0);
+ //vNormal = matrixInvTranspose * normalize(normal);
+ vPosition = position;
+ vNormal = normal;
+ vValue = value;
+ gl_Position = matrix * vec4(position, 1.0);
+ }
+ """,
+ string.Template("""
+ uniform float alpha;
+
+ varying vec4 vCameraPosition;
+ varying vec3 vPosition;
+ varying vec3 vNormal;
+ varying float vValue;
+
+ $colormapDecl
+ $sceneDecl
+ $lightingFunction
+
+ void main(void)
+ {
+ $scenePreCall(vCameraPosition);
+
+ vec4 color = $colormapCall(vValue);
+ gl_FragColor = $lightingCall(color, vPosition, vNormal);
+ gl_FragColor.a *= alpha;
+
+ $scenePostCall(vCameraPosition);
+ }
+ """))
+
+ def __init__(self,
+ position,
+ value,
+ colormap=None,
+ normal=None,
+ mode='triangles',
+ indices=None,
+ copy=True):
+ super(ColormapMesh3D, self).__init__(mode, indices,
+ position=position,
+ normal=normal,
+ value=value,
+ copy=copy)
+
+ self._alpha = 1.0
+ self._lineWidth = 1.0
+ self._lineSmooth = True
+ self._culling = None
+ self._colormap = colormap or Colormap() # Default colormap
+ self._colormap.addListener(self._cmapChanged)
+
+ lineWidth = event.notifyProperty('_lineWidth', converter=float,
+ doc="Width of the line in pixels.")
+
+ lineSmooth = event.notifyProperty(
+ '_lineSmooth',
+ converter=bool,
+ doc="Smooth line rendering enabled (bool, default: True)")
+
+ alpha = event.notifyProperty(
+ '_alpha', converter=float,
+ doc="Transparency of the mesh, float in [0, 1]")
+
+ @property
+ def culling(self):
+ """Face culling (str)
+
+ One of 'back', 'front' or None.
+ """
+ return self._culling
+
+ @culling.setter
+ def culling(self, culling):
+ assert culling in ('back', 'front', None)
+ if culling != self._culling:
+ self._culling = culling
+ self.notify()
+
+ @property
+ def colormap(self):
+ """The colormap used to render the image"""
+ return self._colormap
+
+ def _cmapChanged(self, source, *args, **kwargs):
+ """Broadcast colormap changes"""
+ self.notify(*args, **kwargs)
+
+ def renderGL2(self, ctx):
+ if 'normal' in self._attributes:
+ self._renderGL2(ctx)
+ else: # Disable lighting
+ with self.viewport.light.turnOff():
+ self._renderGL2(ctx)
+
+ def _renderGL2(self, ctx):
+ fragment = self._shaders[1].substitute(
+ sceneDecl=ctx.fragDecl,
+ scenePreCall=ctx.fragCallPre,
+ scenePostCall=ctx.fragCallPost,
+ lightingFunction=ctx.viewport.light.fragmentDef,
+ lightingCall=ctx.viewport.light.fragmentCall,
+ colormapDecl=self.colormap.decl,
+ colormapCall=self.colormap.call)
+ program = ctx.glCtx.prog(self._shaders[0], fragment)
+ program.use()
+
+ ctx.viewport.light.setupProgram(ctx, program)
+ ctx.setupProgram(program)
+ self.colormap.setupProgram(ctx, program)
+
+ if self.culling is not None:
+ cullFace = gl.GL_FRONT if self.culling == 'front' else gl.GL_BACK
+ gl.glCullFace(cullFace)
+ gl.glEnable(gl.GL_CULL_FACE)
+
+ program.setUniformMatrix('matrix', ctx.objectToNDC.matrix)
+ program.setUniformMatrix('transformMat',
+ ctx.objectToCamera.matrix,
+ safe=True)
+ gl.glUniform1f(program.uniforms['alpha'], self._alpha)
+
+ if self.drawMode in self._LINE_MODES:
+ gl.glLineWidth(self.lineWidth)
+ with gl.enabled(gl.GL_LINE_SMOOTH, self.lineSmooth):
+ self._draw(program)
+ else:
+ self._draw(program)
+
+ if self.culling is not None:
+ gl.glDisable(gl.GL_CULL_FACE)
+
+
+# ImageData ##################################################################
+
+class _Image(Geometry):
+ """Base class for ImageData and ImageRgba"""
+
+ _shaders = ("""
+ attribute vec2 position;
+
+ uniform mat4 matrix;
+ uniform mat4 transformMat;
+ uniform vec2 dataScale;
+
+ varying vec4 vCameraPosition;
+ varying vec3 vPosition;
+ varying vec3 vNormal;
+ varying vec2 vTexCoords;
+
+ void main(void)
+ {
+ vec4 positionVec4 = vec4(position, 0.0, 1.0);
+ vCameraPosition = transformMat * positionVec4;
+ vPosition = positionVec4.xyz;
+ vTexCoords = dataScale * position;
+ gl_Position = matrix * positionVec4;
+ }
+ """,
+ string.Template("""
+ varying vec4 vCameraPosition;
+ varying vec3 vPosition;
+ varying vec2 vTexCoords;
+ uniform sampler2D data;
+ uniform float alpha;
+
+ $imageDecl
+ $sceneDecl
+ $lightingFunction
+
+ void main(void)
+ {
+ $scenePreCall(vCameraPosition);
+
+ vec4 color = imageColor(data, vTexCoords);
+ color.a *= alpha;
+ if (color.a == 0.) { /* Discard fully transparent pixels */
+ discard;
+ }
+
+ vec3 normal = vec3(0.0, 0.0, 1.0);
+ gl_FragColor = $lightingCall(color, vPosition, normal);
+
+ $scenePostCall(vCameraPosition);
+ }
+ """))
+
+ _UNIT_SQUARE = numpy.array(((0., 0.), (1., 0.), (0., 1.), (1., 1.)),
+ dtype=numpy.float32)
+
+ def __init__(self, data, copy=True):
+ super(_Image, self).__init__(mode='triangle_strip',
+ position=self._UNIT_SQUARE)
+
+ self._texture = None
+ self._update_texture = True
+ self._update_texture_filter = False
+ self._data = None
+ self.setData(data, copy)
+ self._alpha = 1.
+ self._interpolation = 'linear'
+
+ self.isBackfaceVisible = True
+
+ def setData(self, data, copy=True):
+ assert isinstance(data, numpy.ndarray)
+
+ if copy:
+ data = numpy.array(data, copy=True)
+
+ self._data = data
+ self._update_texture = True
+ # By updating the position rather than always using a unit square
+ # we benefit from Geometry bounds handling
+ self.setAttribute('position', self._UNIT_SQUARE * (self._data.shape[1], self._data.shape[0]))
+ self.notify()
+
+ def getData(self, copy=True):
+ return numpy.array(self._data, copy=copy)
+
+ @property
+ def interpolation(self):
+ """The texture interpolation mode: 'linear' or 'nearest'"""
+ return self._interpolation
+
+ @interpolation.setter
+ def interpolation(self, interpolation):
+ assert interpolation in ('linear', 'nearest')
+ self._interpolation = interpolation
+ self._update_texture_filter = True
+ self.notify()
+
+ @property
+ def alpha(self):
+ """Transparency of the image, float in [0, 1]"""
+ return self._alpha
+
+ @alpha.setter
+ def alpha(self, alpha):
+ self._alpha = float(alpha)
+ self.notify()
+
+ def _textureFormat(self):
+ """Implement this method to provide texture internal format and format
+
+ :return: 2-tuple of gl flags (internalFormat, format)
+ """
+ raise NotImplementedError(
+ "This method must be implemented in a subclass")
+
+ def prepareGL2(self, ctx):
+ if self._texture is None or self._update_texture:
+ if self._texture is not None:
+ self._texture.discard()
+
+ if self.interpolation == 'nearest':
+ filter_ = gl.GL_NEAREST
+ else:
+ filter_ = gl.GL_LINEAR
+ self._update_texture = False
+ self._update_texture_filter = False
+ if self._data.size == 0:
+ self._texture = None
+ else:
+ internalFormat, format_ = self._textureFormat()
+ self._texture = _glutils.Texture(
+ internalFormat,
+ self._data,
+ format_,
+ minFilter=filter_,
+ magFilter=filter_,
+ wrap=gl.GL_CLAMP_TO_EDGE)
+
+ if self._update_texture_filter and self._texture is not None:
+ self._update_texture_filter = False
+ if self.interpolation == 'nearest':
+ filter_ = gl.GL_NEAREST
+ else:
+ filter_ = gl.GL_LINEAR
+ self._texture.minFilter = filter_
+ self._texture.magFilter = filter_
+
+ super(_Image, self).prepareGL2(ctx)
+
+ def renderGL2(self, ctx):
+ if self._texture is None:
+ return # Nothing to render
+
+ with self.viewport.light.turnOff():
+ self._renderGL2(ctx)
+
+ def _renderGL2PreDrawHook(self, ctx, program):
+ """Override in subclass to run code before calling gl draw"""
+ pass
+
+ def _shaderImageColorDecl(self):
+ """Returns fragment shader imageColor function declaration"""
+ raise NotImplementedError(
+ "This method must be implemented in a subclass")
+
+ def _renderGL2(self, ctx):
+ fragment = self._shaders[1].substitute(
+ sceneDecl=ctx.fragDecl,
+ scenePreCall=ctx.fragCallPre,
+ scenePostCall=ctx.fragCallPost,
+ lightingFunction=ctx.viewport.light.fragmentDef,
+ lightingCall=ctx.viewport.light.fragmentCall,
+ imageDecl=self._shaderImageColorDecl()
+ )
+ program = ctx.glCtx.prog(self._shaders[0], fragment)
+ program.use()
+
+ ctx.viewport.light.setupProgram(ctx, program)
+
+ if not self.isBackfaceVisible:
+ gl.glCullFace(gl.GL_BACK)
+ gl.glEnable(gl.GL_CULL_FACE)
+
+ program.setUniformMatrix('matrix', ctx.objectToNDC.matrix)
+ program.setUniformMatrix('transformMat',
+ ctx.objectToCamera.matrix,
+ safe=True)
+ gl.glUniform1f(program.uniforms['alpha'], self._alpha)
+
+ shape = self._data.shape
+ gl.glUniform2f(program.uniforms['dataScale'], 1./shape[1], 1./shape[0])
+
+ gl.glUniform1i(program.uniforms['data'], self._texture.texUnit)
+
+ ctx.setupProgram(program)
+
+ self._texture.bind()
+
+ self._renderGL2PreDrawHook(ctx, program)
+
+ self._draw(program)
+
+ if not self.isBackfaceVisible:
+ gl.glDisable(gl.GL_CULL_FACE)
+
+
+class ImageData(_Image):
+ """Display a 2x2 data array with a texture."""
+
+ _imageDecl = string.Template("""
+ $colormapDecl
+
+ vec4 imageColor(sampler2D data, vec2 texCoords) {
+ float value = texture2D(data, texCoords).r;
+ vec4 color = $colormapCall(value);
+ return color;
+ }
+ """)
+
+ def __init__(self, data, copy=True, colormap=None):
+ super(ImageData, self).__init__(data, copy=copy)
+
+ self._colormap = colormap or Colormap() # Default colormap
+ self._colormap.addListener(self._cmapChanged)
+
+ def setData(self, data, copy=True):
+ data = numpy.array(data, copy=copy, order='C', dtype=numpy.float32)
+ # TODO support (u)int8|16
+ assert data.ndim == 2
+
+ super(ImageData, self).setData(data, copy=False)
+
+ @property
+ def colormap(self):
+ """The colormap used to render the image"""
+ return self._colormap
+
+ def _cmapChanged(self, source, *args, **kwargs):
+ """Broadcast colormap changes"""
+ self.notify(*args, **kwargs)
+
+ def _textureFormat(self):
+ return gl.GL_R32F, gl.GL_RED
+
+ def _renderGL2PreDrawHook(self, ctx, program):
+ self.colormap.setupProgram(ctx, program)
+
+ def _shaderImageColorDecl(self):
+ return self._imageDecl.substitute(
+ colormapDecl=self.colormap.decl,
+ colormapCall=self.colormap.call)
+
+
+# ImageRgba ##################################################################
+
+class ImageRgba(_Image):
+ """Display a 2x2 RGBA image with a texture.
+
+ Supports images of float in [0, 1] and uint8.
+ """
+
+ _imageDecl = """
+ vec4 imageColor(sampler2D data, vec2 texCoords) {
+ vec4 color = texture2D(data, texCoords);
+ return color;
+ }
+ """
+
+ def __init__(self, data, copy=True):
+ super(ImageRgba, self).__init__(data, copy=copy)
+
+ def setData(self, data, copy=True):
+ data = numpy.array(data, copy=copy, order='C')
+ assert data.ndim == 3
+ assert data.shape[2] in (3, 4)
+ if data.dtype.kind == 'f':
+ if data.dtype != numpy.dtype(numpy.float32):
+ _logger.warning("Converting image data to float32")
+ data = numpy.array(data, dtype=numpy.float32, copy=False)
+ else:
+ assert data.dtype == numpy.dtype(numpy.uint8)
+
+ super(ImageRgba, self).setData(data, copy=False)
+
+ def _textureFormat(self):
+ format_ = gl.GL_RGBA if self._data.shape[2] == 4 else gl.GL_RGB
+ return format_, format_
+
+ def _shaderImageColorDecl(self):
+ return self._imageDecl
+
+
+# Group ######################################################################
+
+# TODO lighting, clipping as groups?
+# group composition?
+
+class GroupDepthOffset(core.Group):
+ """A group using 2-pass rendering and glDepthRange to avoid Z-fighting"""
+
+ def __init__(self, children=(), epsilon=None):
+ super(GroupDepthOffset, self).__init__(children)
+ self._epsilon = epsilon
+ self.isDepthRangeOn = True
+
+ def prepareGL2(self, ctx):
+ if self._epsilon is None:
+ depthbits = gl.glGetInteger(gl.GL_DEPTH_BITS)
+ self._epsilon = 1. / (1 << (depthbits - 1))
+
+ def renderGL2(self, ctx):
+ if self.isDepthRangeOn:
+ self._renderGL2WithDepthRange(ctx)
+ else:
+ super(GroupDepthOffset, self).renderGL2(ctx)
+
+ def _renderGL2WithDepthRange(self, ctx):
+ # gl.glDepthFunc(gl.GL_LESS)
+ with gl.enabled(gl.GL_CULL_FACE):
+ gl.glCullFace(gl.GL_BACK)
+ for child in self.children:
+ gl.glColorMask(
+ gl.GL_FALSE, gl.GL_FALSE, gl.GL_FALSE, gl.GL_FALSE)
+ gl.glDepthMask(gl.GL_TRUE)
+ gl.glDepthRange(self._epsilon, 1.)
+
+ child.render(ctx)
+
+ gl.glColorMask(
+ gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE)
+ gl.glDepthMask(gl.GL_FALSE)
+ gl.glDepthRange(0., 1. - self._epsilon)
+
+ child.render(ctx)
+
+ gl.glCullFace(gl.GL_FRONT)
+ for child in reversed(self.children):
+ gl.glColorMask(
+ gl.GL_FALSE, gl.GL_FALSE, gl.GL_FALSE, gl.GL_FALSE)
+ gl.glDepthMask(gl.GL_TRUE)
+ gl.glDepthRange(self._epsilon, 1.)
+
+ child.render(ctx)
+
+ gl.glColorMask(
+ gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE)
+ gl.glDepthMask(gl.GL_FALSE)
+ gl.glDepthRange(0., 1. - self._epsilon)
+
+ child.render(ctx)
+
+ gl.glDepthMask(gl.GL_TRUE)
+ gl.glDepthRange(0., 1.)
+ # gl.glDepthFunc(gl.GL_LEQUAL)
+ # TODO use epsilon for all rendering?
+ # TODO issue with picking in depth buffer!
+
+
+class GroupNoDepth(core.Group):
+ """A group rendering its children without writing to the depth buffer
+
+ :param bool mask: True (default) to disable writing in the depth buffer
+ :param bool notest: True (default) to disable depth test
+ """
+
+ def __init__(self, children=(), mask=True, notest=True):
+ super(GroupNoDepth, self).__init__(children)
+ self._mask = bool(mask)
+ self._notest = bool(notest)
+
+ def renderGL2(self, ctx):
+ if self._mask:
+ gl.glDepthMask(gl.GL_FALSE)
+
+ with gl.disabled(gl.GL_DEPTH_TEST, disable=self._notest):
+ super(GroupNoDepth, self).renderGL2(ctx)
+
+ if self._mask:
+ gl.glDepthMask(gl.GL_TRUE)
+
+
+class GroupBBox(core.PrivateGroup):
+ """A group displaying a bounding box around the children."""
+
+ def __init__(self, children=(), color=(1., 1., 1., 1.)):
+ super(GroupBBox, self).__init__()
+ self._group = core.Group(children)
+
+ self._boxTransforms = transform.TransformList((transform.Translate(),))
+
+ # Using 1 of 3 primitives to render axes and/or bounding box
+ # To avoid z-fighting between axes and bounding box
+ self._boxWithAxes = BoxWithAxes(color)
+ self._boxWithAxes.smooth = False
+ self._boxWithAxes.transforms = self._boxTransforms
+
+ self._box = Box(stroke=color, fill=(1., 1., 1., 0.))
+ self._box.strokeSmooth = False
+ self._box.transforms = self._boxTransforms
+ self._box.visible = False
+
+ self._axes = Axes()
+ self._axes.smooth = False
+ self._axes.transforms = self._boxTransforms
+ self._axes.visible = False
+
+ self.strokeWidth = 2.
+
+ self._children = [self._boxWithAxes, self._box, self._axes, self._group]
+
+ def _updateBoxAndAxes(self):
+ """Update bbox and axes position and size according to children."""
+ bounds = self._group.bounds(dataBounds=True)
+ if bounds is not None:
+ origin = bounds[0]
+ size = bounds[1] - bounds[0]
+ else:
+ origin, size = (0., 0., 0.), (1., 1., 1.)
+
+ self._boxTransforms[0].translation = origin
+
+ self._boxWithAxes.size = size
+ self._box.size = size
+ self._axes.size = size
+
+ def _bounds(self, dataBounds=False):
+ self._updateBoxAndAxes()
+ return super(GroupBBox, self)._bounds(dataBounds)
+
+ def prepareGL2(self, ctx):
+ self._updateBoxAndAxes()
+ super(GroupBBox, self).prepareGL2(ctx)
+
+ # Give access to _group children
+
+ @property
+ def children(self):
+ return self._group.children
+
+ @children.setter
+ def children(self, iterable):
+ self._group.children = iterable
+
+ # Give access to box color and stroke width
+
+ @property
+ def color(self):
+ """The RGBA color to use for the box: 4 float in [0, 1]"""
+ return self._box.strokeColor
+
+ @color.setter
+ def color(self, color):
+ self._box.strokeColor = color
+ self._boxWithAxes.color = color
+
+ @property
+ def strokeWidth(self):
+ """The width of the stroke lines in pixels (float)"""
+ return self._box.strokeWidth
+
+ @strokeWidth.setter
+ def strokeWidth(self, width):
+ width = float(width)
+ self._box.strokeWidth = width
+ self._boxWithAxes.width = width
+ self._axes.width = width
+
+ # Toggle axes visibility
+
+ def _updateBoxAndAxesVisibility(self, axesVisible, boxVisible):
+ """Update visible flags of box and axes primitives accordingly.
+
+ :param bool axesVisible: True to display axes
+ :param bool boxVisible: True to display bounding box
+ """
+ self._boxWithAxes.visible = boxVisible and axesVisible
+ self._box.visible = boxVisible and not axesVisible
+ self._axes.visible = not boxVisible and axesVisible
+
+ @property
+ def axesVisible(self):
+ """Whether axes are displayed or not (bool)"""
+ return self._boxWithAxes.visible or self._axes.visible
+
+ @axesVisible.setter
+ def axesVisible(self, visible):
+ self._updateBoxAndAxesVisibility(axesVisible=bool(visible),
+ boxVisible=self.boxVisible)
+
+ @property
+ def boxVisible(self):
+ """Whether bounding box is displayed or not (bool)"""
+ return self._boxWithAxes.visible or self._box.visible
+
+ @boxVisible.setter
+ def boxVisible(self, visible):
+ self._updateBoxAndAxesVisibility(axesVisible=self.axesVisible,
+ boxVisible=bool(visible))
+
+
+# Clipping Plane ##############################################################
+
+class ClipPlane(PlaneInGroup):
+ """A clipping plane attached to a box"""
+
+ def renderGL2(self, ctx):
+ super(ClipPlane, self).renderGL2(ctx)
+
+ if self.visible:
+ # Set-up clipping plane for following brothers
+
+ # No need of perspective divide, no projection
+ point = ctx.objectToCamera.transformPoint(self.plane.point,
+ perspectiveDivide=False)
+ normal = ctx.objectToCamera.transformNormal(self.plane.normal)
+ ctx.setClipPlane(point, normal)
+
+ def postRender(self, ctx):
+ if self.visible:
+ # Disable clip planes
+ ctx.setClipPlane()
diff --git a/src/silx/gui/plot3d/scene/test/__init__.py b/src/silx/gui/plot3d/scene/test/__init__.py
new file mode 100644
index 0000000..3bb978e
--- /dev/null
+++ b/src/silx/gui/plot3d/scene/test/__init__.py
@@ -0,0 +1,24 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
diff --git a/src/silx/gui/plot3d/scene/test/test_transform.py b/src/silx/gui/plot3d/scene/test/test_transform.py
new file mode 100644
index 0000000..69e991b
--- /dev/null
+++ b/src/silx/gui/plot3d/scene/test/test_transform.py
@@ -0,0 +1,80 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+
+from __future__ import absolute_import, division, unicode_literals
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "05/01/2017"
+
+
+import numpy
+import unittest
+
+from silx.gui.plot3d.scene import transform
+
+
+class TestTransformList(unittest.TestCase):
+
+ def assertSameArrays(self, a, b):
+ return self.assertTrue(numpy.allclose(a, b, atol=1e-06))
+
+ def testTransformList(self):
+ """Minimalistic test of TransformList"""
+ transforms = transform.TransformList()
+ refmatrix = numpy.identity(4, dtype=numpy.float32)
+ self.assertSameArrays(refmatrix, transforms.matrix)
+
+ # Append translate
+ transforms.append(transform.Translate(1., 1., 1.))
+ refmatrix = numpy.array(((1., 0., 0., 1.),
+ (0., 1., 0., 1.),
+ (0., 0., 1., 1.),
+ (0., 0., 0., 1.)), dtype=numpy.float32)
+ self.assertSameArrays(refmatrix, transforms.matrix)
+
+ # Extend scale
+ transforms.extend([transform.Scale(0.1, 2., 1.)])
+ refmatrix = numpy.dot(refmatrix,
+ numpy.array(((0.1, 0., 0., 0.),
+ (0., 2., 0., 0.),
+ (0., 0., 1., 0.),
+ (0., 0., 0., 1.)),
+ dtype=numpy.float32))
+ self.assertSameArrays(refmatrix, transforms.matrix)
+
+ # Insert rotate
+ transforms.insert(0, transform.Rotate(360.))
+ self.assertSameArrays(refmatrix, transforms.matrix)
+
+ # Update translate and check for listener called
+ self._callCount = 0
+
+ def listener(source):
+ self._callCount += 1
+ transforms.addListener(listener)
+
+ transforms[1].tx += 1
+ self.assertEqual(self._callCount, 1)
diff --git a/src/silx/gui/plot3d/scene/test/test_utils.py b/src/silx/gui/plot3d/scene/test/test_utils.py
new file mode 100644
index 0000000..65d0ce0
--- /dev/null
+++ b/src/silx/gui/plot3d/scene/test/test_utils.py
@@ -0,0 +1,258 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+
+from __future__ import absolute_import, division, unicode_literals
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "17/01/2018"
+
+
+import unittest
+from silx.utils.testutils import ParametricTestCase
+
+import numpy
+
+from silx.gui.plot3d.scene import utils
+
+
+# angleBetweenVectors #########################################################
+
+class TestAngleBetweenVectors(ParametricTestCase):
+
+ TESTS = { # name: (refvector, vectors, norm, refangles)
+ 'single vector':
+ ((1., 0., 0.), (1., 0., 0.), (0., 0., 1.), 0.),
+ 'single vector, no norm':
+ ((1., 0., 0.), (1., 0., 0.), None, 0.),
+
+ 'with orthogonal norm':
+ ((1., 0., 0.),
+ ((1., 0., 0.), (0., 1., 0.), (-1., 0., 0.), (0., -1., 0.)),
+ (0., 0., 1.),
+ (0., 90., 180., 270.)),
+
+ 'with coplanar norm': # = similar to no norm
+ ((1., 0., 0.),
+ ((1., 0., 0.), (0., 1., 0.), (-1., 0., 0.), (0., -1., 0.)),
+ (1., 0., 0.),
+ (0., 90., 180., 90.)),
+
+ 'without norm':
+ ((1., 0., 0.),
+ ((1., 0., 0.), (0., 1., 0.), (-1., 0., 0.), (0., -1., 0.)),
+ None,
+ (0., 90., 180., 90.)),
+
+ 'not unit vectors':
+ ((2., 2., 0.), ((1., 1., 0.), (1., -1., 0.)), None, (0., 90.)),
+ }
+
+ def testAngleBetweenVectorsFunction(self):
+ for name, params in self.TESTS.items():
+ refvector, vectors, norm, refangles = params
+ with self.subTest(name):
+ refangles = numpy.radians(refangles)
+
+ refvector = numpy.array(refvector)
+ vectors = numpy.array(vectors)
+ if norm is not None:
+ norm = numpy.array(norm)
+
+ testangles = utils.angleBetweenVectors(
+ refvector, vectors, norm)
+
+ self.assertTrue(
+ numpy.allclose(testangles, refangles, atol=1e-5))
+
+
+# Plane #######################################################################
+
+class AssertNotificationContext(object):
+ """Context that checks if an event.Notifier is sending events."""
+
+ def __init__(self, notifier, count=1):
+ """Initializer.
+
+ :param event.Notifier notifier: The notifier to test.
+ :param int count: The expected number of calls.
+ """
+ self._notifier = notifier
+ self._callCount = None
+ self._count = count
+
+ def __enter__(self):
+ self._callCount = 0
+ self._notifier.addListener(self._callback)
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ # Do not return True so exceptions are propagated
+ self._notifier.removeListener(self._callback)
+ assert self._callCount == self._count
+ self._callCount = None
+
+ def _callback(self, *args, **kwargs):
+ self._callCount += 1
+
+
+class TestPlaneParameters(ParametricTestCase):
+ """Test Plane.parameters read/write and notifications."""
+
+ PARAMETERS = {
+ 'unit normal': (1., 0., 0., 1.),
+ 'not unit normal': (1., 1., 0., 1.),
+ 'd = 0': (1., 0., 0., 0.)
+ }
+
+ def testParameters(self):
+ """Check parameters read/write and notification."""
+ plane = utils.Plane()
+
+ for name, parameters in self.PARAMETERS.items():
+ with self.subTest(name, parameters=parameters):
+ with AssertNotificationContext(plane):
+ plane.parameters = parameters
+
+ # Plane parameters are converted to have a unit normal
+ normparams = parameters / numpy.linalg.norm(parameters[:3])
+ self.assertTrue(numpy.allclose(plane.parameters, normparams))
+
+ ZEROS_PARAMETERS = (
+ (0., 0., 0., 0.),
+ (0., 0., 0., 1.)
+ )
+
+ ZEROS = 0., 0., 0., 0.
+
+ def testParametersNoPlane(self):
+ """Test Plane.parameters with ||normal|| == 0 ."""
+ plane = utils.Plane()
+ plane.parameters = self.ZEROS
+
+ for parameters in self.ZEROS_PARAMETERS:
+ with self.subTest(parameters=parameters):
+ with AssertNotificationContext(plane, count=0):
+ plane.parameters = parameters
+ self.assertTrue(
+ numpy.allclose(plane.parameters, self.ZEROS, 0., 0.))
+
+
+# unindexArrays ###############################################################
+
+class TestUnindexArrays(ParametricTestCase):
+ """Test unindexArrays function."""
+
+ def testBasicModes(self):
+ """Test for modes: points, lines and triangles"""
+ indices = numpy.array((1, 2, 0))
+ arrays = (numpy.array((0., 1., 2.)),
+ numpy.array(((0, 0), (1, 1), (2, 2))))
+ refresults = (numpy.array((1., 2., 0.)),
+ numpy.array(((1, 1), (2, 2), (0, 0))))
+
+ for mode in ('points', 'lines', 'triangles'):
+ with self.subTest(mode=mode):
+ testresults = utils.unindexArrays(mode, indices, *arrays)
+ for ref, test in zip(refresults, testresults):
+ self.assertTrue(numpy.equal(ref, test).all())
+
+ def testPackedLines(self):
+ """Test for modes: line_strip, loop"""
+ indices = numpy.array((1, 2, 0))
+ arrays = (numpy.array((0., 1., 2.)),
+ numpy.array(((0, 0), (1, 1), (2, 2))))
+ results = {
+ 'line_strip': (
+ numpy.array((1., 2., 2., 0.)),
+ numpy.array(((1, 1), (2, 2), (2, 2), (0, 0)))),
+ 'loop': (
+ numpy.array((1., 2., 2., 0., 0., 1.)),
+ numpy.array(((1, 1), (2, 2), (2, 2), (0, 0), (0, 0), (1, 1)))),
+ }
+
+ for mode, refresults in results.items():
+ with self.subTest(mode=mode):
+ testresults = utils.unindexArrays(mode, indices, *arrays)
+ for ref, test in zip(refresults, testresults):
+ self.assertTrue(numpy.equal(ref, test).all())
+
+ def testPackedTriangles(self):
+ """Test for modes: triangle_strip, fan"""
+ indices = numpy.array((1, 2, 0, 3))
+ arrays = (numpy.array((0., 1., 2., 3.)),
+ numpy.array(((0, 0), (1, 1), (2, 2), (3, 3))))
+ results = {
+ 'triangle_strip': (
+ numpy.array((1., 2., 0., 2., 0., 3.)),
+ numpy.array(((1, 1), (2, 2), (0, 0), (2, 2), (0, 0), (3, 3)))),
+ 'fan': (
+ numpy.array((1., 2., 0., 1., 0., 3.)),
+ numpy.array(((1, 1), (2, 2), (0, 0), (1, 1), (0, 0), (3, 3)))),
+ }
+
+ for mode, refresults in results.items():
+ with self.subTest(mode=mode):
+ testresults = utils.unindexArrays(mode, indices, *arrays)
+ for ref, test in zip(refresults, testresults):
+ self.assertTrue(numpy.equal(ref, test).all())
+
+ def testBadIndices(self):
+ """Test with negative indices and indices higher than array length"""
+ arrays = numpy.array((0, 1)), numpy.array((0, 1, 2))
+
+ # negative indices
+ with self.assertRaises(AssertionError):
+ utils.unindexArrays('points', (-1, 0), *arrays)
+
+ # Too high indices
+ with self.assertRaises(AssertionError):
+ utils.unindexArrays('points', (0, 10), *arrays)
+
+
+# triangleNormals #############################################################
+
+class TestTriangleNormals(ParametricTestCase):
+ """Test triangleNormals function."""
+
+ def test(self):
+ """Test for modes: points, lines and triangles"""
+ positions = numpy.array(
+ ((0., 0., 0.), (1., 0., 0.), (0., 1., 0.), # normal = Z
+ (1., 1., 1.), (1., 2., 3.), (4., 5., 6.), # Random triangle
+ # Degenerated triangles:
+ (0., 0., 0.), (1., 0., 0.), (2., 0., 0.), # Colinear points
+ (1., 1., 1.), (1., 1., 1.), (1., 1., 1.), # All same point
+ ),
+ dtype='float32')
+
+ normals = numpy.array(
+ ((0., 0., 1.),
+ (-0.40824829, 0.81649658, -0.40824829),
+ (0., 0., 0.),
+ (0., 0., 0.)),
+ dtype='float32')
+
+ testnormals = utils.trianglesNormal(positions)
+ self.assertTrue(numpy.allclose(testnormals, normals))
diff --git a/src/silx/gui/plot3d/scene/text.py b/src/silx/gui/plot3d/scene/text.py
new file mode 100644
index 0000000..bacc2e6
--- /dev/null
+++ b/src/silx/gui/plot3d/scene/text.py
@@ -0,0 +1,535 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Primitive displaying a text field in the scene."""
+
+from __future__ import absolute_import, division, unicode_literals
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "24/04/2018"
+
+
+import logging
+import numpy
+
+from silx.gui.colors import rgba
+
+from ... import _glutils
+from ..._glutils import gl
+
+from ..._glutils import font as _font
+from ...plot._utils import ticklayout
+
+from . import event, primitives, core, transform
+
+
+_logger = logging.getLogger(__name__)
+
+
+class Font(event.Notifier):
+ """Description of a font.
+
+ :param str name: Family of the font
+ :param int size: Size of the font in points
+ :param int weight: Font weight
+ :param bool italic: True for italic font, False (default) otherwise
+ """
+
+ def __init__(self, name=None, size=-1, weight=-1, italic=False):
+ self._name = name if name is not None else _font.getDefaultFontFamily()
+ self._size = size
+ self._weight = weight
+ self._italic = italic
+ super(Font, self).__init__()
+
+ name = event.notifyProperty(
+ '_name',
+ doc="""Name of the font (str)""",
+ converter=str)
+
+ size = event.notifyProperty(
+ '_size',
+ doc="""Font size in points (int)""",
+ converter=int)
+
+ weight = event.notifyProperty(
+ '_weight',
+ doc="""Font size in points (int)""",
+ converter=int)
+
+ italic = event.notifyProperty(
+ '_italic',
+ doc="""True for italic (bool)""",
+ converter=bool)
+
+
+class Text2D(primitives.Geometry):
+ """Text field as a 2D texture displayed with bill-boarding
+
+ :param str text: Text to display
+ :param Font font: The font to use
+ """
+
+ # Text anchor values
+ CENTER = 'center'
+
+ LEFT = 'left'
+ RIGHT = 'right'
+
+ TOP = 'top'
+ BASELINE = 'baseline'
+ BOTTOM = 'bottom'
+
+ _ALIGN = LEFT, CENTER, RIGHT
+ _VALIGN = TOP, BASELINE, CENTER, BOTTOM
+
+ _rasterTextCache = {}
+ """Internal cache storing already rasterized text"""
+ # TODO limit cache size and discard least recent used
+
+ def __init__(self, text='', font=None):
+ self._dirtyTexture = True
+ self._dirtyAlign = True
+ self._baselineOffset = 0
+ self._text = text
+ self._font = font if font is not None else Font()
+ self._foreground = 1., 1., 1., 1.
+ self._background = 0., 0., 0., 0.
+ self._overlay = False
+ self._align = 'left'
+ self._valign = 'baseline'
+ self._devicePixelRatio = 1.0 # Store it to check for changes
+
+ self._texture = None
+ self._textureDirty = True
+
+ super(Text2D, self).__init__(
+ 'triangle_strip',
+ copy=False,
+ # Keep an array for position as it is bound to attr 0 and MUST
+ # be active and an array at least on Mac OS X
+ position=numpy.zeros((4, 3), dtype=numpy.float32),
+ vertexID=numpy.arange(4., dtype=numpy.float32).reshape(4, 1),
+ offsetInViewportCoords=(0., 0.))
+
+ @property
+ def text(self):
+ """Text displayed by this primitive (str)"""
+ return self._text
+
+ @text.setter
+ def text(self, text):
+ text = str(text)
+ if text != self._text:
+ self._dirtyTexture = True
+ self._text = text
+ self.notify()
+
+ @property
+ def font(self):
+ """Font to use to raster text (Font)"""
+ return self._font
+
+ @font.setter
+ def font(self, font):
+ self._font.removeListener(self._fontChanged)
+ self._font = font
+ self._font.addListener(self._fontChanged)
+ self._fontChanged(self) # Which calls notify and primitive as dirty
+
+ def _fontChanged(self, source):
+ """Listen for font change"""
+ self._dirtyTexture = True
+ self.notify()
+
+ foreground = event.notifyProperty(
+ '_foreground', doc="""RGBA color of the text: 4 float in [0, 1]""",
+ converter=rgba)
+
+ background = event.notifyProperty(
+ '_background',
+ doc="RGBA background color of the text field: 4 float in [0, 1]",
+ converter=rgba)
+
+ overlay = event.notifyProperty(
+ '_overlay',
+ doc="True to always display text on top of the scene (default: False)",
+ converter=bool)
+
+ def _setAlign(self, align):
+ assert align in self._ALIGN
+ self._align = align
+ self._dirtyAlign = True
+ self.notify()
+
+ align = property(
+ lambda self: self._align,
+ _setAlign,
+ doc="""Horizontal anchor position of the text field (str).
+
+ Either 'left' (default), 'center' or 'right'.""")
+
+ def _setVAlign(self, valign):
+ assert valign in self._VALIGN
+ self._valign = valign
+ self._dirtyAlign = True
+ self.notify()
+
+ valign = property(
+ lambda self: self._valign,
+ _setVAlign,
+ doc="""Vertical anchor position of the text field (str).
+
+ Either 'top', 'baseline' (default), 'center' or 'bottom'""")
+
+ def _raster(self, devicePixelRatio):
+ """Raster current primitive to a bitmap
+
+ :param float devicePixelRatio:
+ The ratio between device and device-independent pixels
+ :return: Corresponding image in grayscale and baseline offset from top
+ :rtype: (HxW numpy.ndarray of uint8, int)
+ """
+ params = (self.text,
+ self.font.name,
+ self.font.size,
+ self.font.weight,
+ self.font.italic,
+ devicePixelRatio)
+
+ if params not in self._rasterTextCache: # Add to cache
+ self._rasterTextCache[params] = _font.rasterText(*params)
+
+ array, offset = self._rasterTextCache[params]
+ return array.copy(), offset
+
+ def _bounds(self, dataBounds=False):
+ return None
+
+ def prepareGL2(self, context):
+ # Check if devicePixelRatio has changed since last rendering
+ devicePixelRatio = context.glCtx.devicePixelRatio
+ if self._devicePixelRatio != devicePixelRatio:
+ self._devicePixelRatio = devicePixelRatio
+ self._dirtyTexture = True
+
+ if self._dirtyTexture:
+ self._dirtyTexture = False
+
+ if self._texture is not None:
+ self._texture.discard()
+ self._texture = None
+ self._baselineOffset = 0
+
+ if self.text:
+ image, self._baselineOffset = self._raster(
+ self._devicePixelRatio)
+ self._texture = _glutils.Texture(
+ gl.GL_R8, image, gl.GL_RED,
+ minFilter=gl.GL_NEAREST,
+ magFilter=gl.GL_NEAREST,
+ wrap=gl.GL_CLAMP_TO_EDGE)
+ self._texture.prepare()
+ self._dirtyAlign = True # To force update of offset
+
+ if self._dirtyAlign:
+ self._dirtyAlign = False
+
+ if self._texture is not None:
+ height, width = self._texture.shape
+
+ if self._align == 'left':
+ ox = 0.
+ elif self._align == 'center':
+ ox = - width // 2
+ elif self._align == 'right':
+ ox = - width
+ else:
+ _logger.error("Unsupported align: %s", self._align)
+ ox = 0.
+
+ if self._valign == 'top':
+ oy = 0.
+ elif self._valign == 'baseline':
+ oy = self._baselineOffset
+ elif self._valign == 'center':
+ oy = height // 2
+ elif self._valign == 'bottom':
+ oy = height
+ else:
+ _logger.error("Unsupported valign: %s", self._valign)
+ oy = 0.
+
+ offsets = (ox, oy) + numpy.array(
+ ((0., 0.), (width, 0.), (0., -height), (width, -height)),
+ dtype=numpy.float32)
+ self.setAttribute('offsetInViewportCoords', offsets)
+
+ super(Text2D, self).prepareGL2(context)
+
+ def renderGL2(self, context):
+ if not self.text:
+ return # Nothing to render
+
+ program = context.glCtx.prog(*self._shaders)
+ program.use()
+
+ program.setUniformMatrix('matrix', context.objectToNDC.matrix)
+ gl.glUniform2f(
+ program.uniforms['viewportSize'], *context.viewport.size)
+ gl.glUniform4f(program.uniforms['foreground'], *self.foreground)
+ gl.glUniform4f(program.uniforms['background'], *self.background)
+ gl.glUniform1i(program.uniforms['texture'], self._texture.texUnit)
+ gl.glUniform1i(program.uniforms['isOverlay'],
+ 1 if self._overlay else 0)
+
+ self._texture.bind()
+
+ if not self._overlay or not gl.glGetBoolean(gl.GL_DEPTH_TEST):
+ self._draw(program)
+ else: # overlay and depth test currently enabled
+ gl.glDisable(gl.GL_DEPTH_TEST)
+ self._draw(program)
+ gl.glEnable(gl.GL_DEPTH_TEST)
+
+ # TODO texture atlas + viewportSize as attribute to chain text rendering
+
+ _shaders = (
+ """
+ attribute vec3 position;
+ attribute vec2 offsetInViewportCoords; /* Offset in pixels (y upward) */
+ attribute float vertexID; /* Index of rectangle corner */
+
+ uniform mat4 matrix;
+ uniform vec2 viewportSize; /* Width, height of the viewport */
+ uniform int isOverlay;
+
+ varying vec2 texCoords;
+
+ void main(void)
+ {
+ vec4 clipPos = matrix * vec4(position, 1.0);
+ vec4 ndcPos = clipPos / clipPos.w; /* Perspective divide */
+
+ /* Align ndcPos with pixels in viewport-like coords (origin useless) */
+ vec2 viewportPos = floor((ndcPos.xy + vec2(1.0, 1.0)) * 0.5 * viewportSize);
+
+ /* Apply offset in viewport coords */
+ viewportPos += offsetInViewportCoords;
+
+ /* Convert back to NDC */
+ vec2 pointPos = 2.0 * viewportPos / viewportSize - vec2(1.0, 1.0);
+ float z = (isOverlay != 0) ? -1.0 : ndcPos.z;
+ gl_Position = vec4(pointPos, z, 1.0);
+
+ /* Index : texCoords:
+ * 0: (0., 0.)
+ * 1: (1., 0.)
+ * 2: (0., 1.)
+ * 3: (1., 1.)
+ */
+ texCoords = vec2(vertexID == 0.0 || vertexID == 2.0 ? 0.0 : 1.0,
+ vertexID < 1.5 ? 0.0 : 1.0);
+ }
+ """, # noqa
+
+ """
+ varying vec2 texCoords;
+
+ uniform vec4 foreground;
+ uniform vec4 background;
+ uniform sampler2D texture;
+
+ void main(void)
+ {
+ float value = texture2D(texture, texCoords).r;
+
+ if (background.a != 0.0) {
+ gl_FragColor = mix(background, foreground, value);
+ } else {
+ gl_FragColor = foreground;
+ gl_FragColor.a *= value;
+ if (gl_FragColor.a <= 0.01) {
+ discard;
+ }
+ }
+ }
+ """)
+
+
+class LabelledAxes(primitives.GroupBBox):
+ """A group displaying a bounding box with axes labels around its children.
+ """
+
+ def __init__(self):
+ super(LabelledAxes, self).__init__()
+ self._ticksForBounds = None
+
+ self._font = Font()
+
+ # TODO offset labels from anchor in pixels
+
+ self._xlabel = Text2D(font=self._font)
+ self._xlabel.align = 'center'
+ self._xlabel.transforms = [self._boxTransforms,
+ transform.Translate(tx=0.5)]
+ self._children.append(self._xlabel)
+
+ self._ylabel = Text2D(font=self._font)
+ self._ylabel.align = 'center'
+ self._ylabel.transforms = [self._boxTransforms,
+ transform.Translate(ty=0.5)]
+ self._children.append(self._ylabel)
+
+ self._zlabel = Text2D(font=self._font)
+ self._zlabel.align = 'center'
+ self._zlabel.transforms = [self._boxTransforms,
+ transform.Translate(tz=0.5)]
+ self._children.append(self._zlabel)
+
+ self._tickLines = primitives.Lines( # Init tick lines with dummy pos
+ positions=((0., 0., 0.), (0., 0., 0.)),
+ mode='lines')
+ self._tickLines.visible = False
+ self._children.append(self._tickLines)
+
+ self._tickLabels = core.Group()
+ self._children.append(self._tickLabels)
+
+ @property
+ def font(self):
+ """Font of axes text labels (Font)"""
+ return self._font
+
+ @font.setter
+ def font(self, font):
+ self._font = font
+ self._xlabel.font = font
+ self._ylabel.font = font
+ self._zlabel.font = font
+ for label in self._tickLabels.children:
+ label.font = font
+
+ @property
+ def xlabel(self):
+ """Text label of the X axis (str)"""
+ return self._xlabel.text
+
+ @xlabel.setter
+ def xlabel(self, text):
+ self._xlabel.text = text
+
+ @property
+ def ylabel(self):
+ """Text label of the Y axis (str)"""
+ return self._ylabel.text
+
+ @ylabel.setter
+ def ylabel(self, text):
+ self._ylabel.text = text
+
+ @property
+ def zlabel(self):
+ """Text label of the Z axis (str)"""
+ return self._zlabel.text
+
+ @zlabel.setter
+ def zlabel(self, text):
+ self._zlabel.text = text
+
+ def _updateTicks(self):
+ """Check if ticks need update and update them if needed."""
+ bounds = self._group.bounds(transformed=False, dataBounds=True)
+ if bounds is None: # No content
+ if self._ticksForBounds is not None:
+ self._ticksForBounds = None
+ self._tickLines.visible = False
+ self._tickLabels.children = [] # Reset previous labels
+
+ elif (self._ticksForBounds is None or
+ not numpy.all(numpy.equal(bounds, self._ticksForBounds))):
+ self._ticksForBounds = bounds
+
+ # Update ticks
+ # TODO make ticks having a constant length on the screen
+ ticklength = numpy.abs(bounds[1] - bounds[0]) / 20.
+
+ xticks, xlabels = ticklayout.ticks(*bounds[:, 0])
+ yticks, ylabels = ticklayout.ticks(*bounds[:, 1])
+ zticks, zlabels = ticklayout.ticks(*bounds[:, 2])
+
+ # Update tick lines
+ coords = numpy.empty(
+ ((len(xticks) + len(yticks) + len(zticks)), 4, 3),
+ dtype=numpy.float32)
+ coords[:, :, :] = bounds[0, :] # account for offset from origin
+
+ xcoords = coords[:len(xticks)]
+ xcoords[:, :, 0] = numpy.asarray(xticks)[:, numpy.newaxis]
+ xcoords[:, 1, 1] += ticklength[1] # X ticks on XY plane
+ xcoords[:, 3, 2] += ticklength[2] # X ticks on XZ plane
+
+ ycoords = coords[len(xticks):len(xticks) + len(yticks)]
+ ycoords[:, :, 1] = numpy.asarray(yticks)[:, numpy.newaxis]
+ ycoords[:, 1, 0] += ticklength[0] # Y ticks on XY plane
+ ycoords[:, 3, 2] += ticklength[2] # Y ticks on YZ plane
+
+ zcoords = coords[len(xticks) + len(yticks):]
+ zcoords[:, :, 2] = numpy.asarray(zticks)[:, numpy.newaxis]
+ zcoords[:, 1, 0] += ticklength[0] # Z ticks on XZ plane
+ zcoords[:, 3, 1] += ticklength[1] # Z ticks on YZ plane
+
+ self._tickLines.setAttribute('position', coords.reshape(-1, 3))
+ self._tickLines.visible = True
+
+ # Update labels
+ offsets = bounds[0] - ticklength
+ labels = []
+ for tick, label in zip(xticks, xlabels):
+ text = Text2D(text=label, font=self.font)
+ text.align = 'center'
+ text.transforms = [transform.Translate(
+ tx=tick, ty=offsets[1], tz=offsets[2])]
+ labels.append(text)
+
+ for tick, label in zip(yticks, ylabels):
+ text = Text2D(text=label, font=self.font)
+ text.align = 'center'
+ text.transforms = [transform.Translate(
+ tx=offsets[0], ty=tick, tz=offsets[2])]
+ labels.append(text)
+
+ for tick, label in zip(zticks, zlabels):
+ text = Text2D(text=label, font=self.font)
+ text.align = 'center'
+ text.transforms = [transform.Translate(
+ tx=offsets[0], ty=offsets[1], tz=tick)]
+ labels.append(text)
+
+ self._tickLabels.children = labels # Reset previous labels
+
+ def prepareGL2(self, context):
+ self._updateTicks()
+ super(LabelledAxes, self).prepareGL2(context)
diff --git a/src/silx/gui/plot3d/scene/transform.py b/src/silx/gui/plot3d/scene/transform.py
new file mode 100644
index 0000000..43b739b
--- /dev/null
+++ b/src/silx/gui/plot3d/scene/transform.py
@@ -0,0 +1,1027 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides 4x4 matrix operation and classes to handle them."""
+
+from __future__ import absolute_import, division, unicode_literals
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "25/07/2016"
+
+
+import itertools
+import numpy
+
+from . import event
+
+
+# Functions ###################################################################
+
+# Projections
+
+def mat4LookAtDir(position, direction, up):
+ """Creates matrix to look in direction from position.
+
+ :param position: Array-like 3 coordinates of the point of view position.
+ :param direction: Array-like 3 coordinates of the sight direction vector.
+ :param up: Array-like 3 coordinates of the upward direction
+ in the image plane.
+ :returns: Corresponding matrix.
+ :rtype: numpy.ndarray of shape (4, 4)
+ """
+ assert len(position) == 3
+ assert len(direction) == 3
+ assert len(up) == 3
+
+ direction = numpy.array(direction, copy=True, dtype=numpy.float32)
+ dirnorm = numpy.linalg.norm(direction)
+ assert dirnorm != 0.
+ direction /= dirnorm
+
+ side = numpy.cross(direction,
+ numpy.array(up, copy=False, dtype=numpy.float32))
+ sidenorm = numpy.linalg.norm(side)
+ assert sidenorm != 0.
+ up = numpy.cross(side / sidenorm, direction)
+ upnorm = numpy.linalg.norm(up)
+ assert upnorm != 0.
+ up /= upnorm
+
+ matrix = numpy.identity(4, dtype=numpy.float32)
+ matrix[0, :3] = side
+ matrix[1, :3] = up
+ matrix[2, :3] = -direction
+ return numpy.dot(matrix,
+ mat4Translate(-position[0], -position[1], -position[2]))
+
+
+def mat4LookAt(position, center, up):
+ """Creates matrix to look at center from position.
+
+ See gluLookAt.
+
+ :param position: Array-like 3 coordinates of the point of view position.
+ :param center: Array-like 3 coordinates of the center of the scene.
+ :param up: Array-like 3 coordinates of the upward direction
+ in the image plane.
+ :returns: Corresponding matrix.
+ :rtype: numpy.ndarray of shape (4, 4)
+ """
+ position = numpy.array(position, copy=False, dtype=numpy.float32)
+ center = numpy.array(center, copy=False, dtype=numpy.float32)
+ direction = center - position
+ return mat4LookAtDir(position, direction, up)
+
+
+def mat4Frustum(left, right, bottom, top, near, far):
+ """Creates a frustum projection matrix.
+
+ See glFrustum.
+ """
+ return numpy.array((
+ (2.*near / (right-left), 0., (right+left) / (right-left), 0.),
+ (0., 2.*near / (top-bottom), (top+bottom) / (top-bottom), 0.),
+ (0., 0., -(far+near) / (far-near), -2.*far*near / (far-near)),
+ (0., 0., -1., 0.)), dtype=numpy.float32)
+
+
+def mat4Perspective(fovy, width, height, near, far):
+ """Creates a perspective projection matrix.
+
+ Similar to gluPerspective.
+
+ :param float fovy: Field of view angle in degrees in the y direction.
+ :param float width: Width of the viewport.
+ :param float height: Height of the viewport.
+ :param float near: Distance to the near plane (strictly positive).
+ :param float far: Distance to the far plane (strictly positive).
+ :return: Corresponding matrix.
+ :rtype: numpy.ndarray of shape (4, 4)
+ """
+ assert fovy != 0
+ assert height != 0
+ assert width != 0
+ assert near > 0.
+ assert far > near
+ aspectratio = width / height
+ f = 1. / numpy.tan(numpy.radians(fovy) / 2.)
+ return numpy.array((
+ (f / aspectratio, 0., 0., 0.),
+ (0., f, 0., 0.),
+ (0., 0., (far + near) / (near - far), 2. * far * near / (near - far)),
+ (0., 0., -1., 0.)), dtype=numpy.float32)
+
+
+def mat4Orthographic(left, right, bottom, top, near, far):
+ """Creates an orthographic (i.e., parallel) projection matrix.
+
+ See glOrtho.
+ """
+ return numpy.array((
+ (2. / (right - left), 0., 0., - (right + left) / (right - left)),
+ (0., 2. / (top - bottom), 0., - (top + bottom) / (top - bottom)),
+ (0., 0., -2. / (far - near), - (far + near) / (far - near)),
+ (0., 0., 0., 1.)), dtype=numpy.float32)
+
+
+# Affine
+
+def mat4Translate(tx, ty, tz):
+ """4x4 translation matrix."""
+ return numpy.array((
+ (1., 0., 0., tx),
+ (0., 1., 0., ty),
+ (0., 0., 1., tz),
+ (0., 0., 0., 1.)), dtype=numpy.float32)
+
+
+def mat4Scale(sx, sy, sz):
+ """4x4 scale matrix."""
+ return numpy.array((
+ (sx, 0., 0., 0.),
+ (0., sy, 0., 0.),
+ (0., 0., sz, 0.),
+ (0., 0., 0., 1.)), dtype=numpy.float32)
+
+
+def mat4RotateFromAngleAxis(angle, x=0., y=0., z=1.):
+ """4x4 rotation matrix from angle and axis.
+
+ :param float angle: The rotation angle in radians.
+ :param float x: The rotation vector x coordinate.
+ :param float y: The rotation vector y coordinate.
+ :param float z: The rotation vector z coordinate.
+ """
+ ca = numpy.cos(angle)
+ sa = numpy.sin(angle)
+ return numpy.array((
+ ((1.-ca) * x*x + ca, (1.-ca) * x*y - sa*z, (1.-ca) * x*z + sa*y, 0.),
+ ((1.-ca) * x*y + sa*z, (1.-ca) * y*y + ca, (1.-ca) * y*z - sa*x, 0.),
+ ((1.-ca) * x*z - sa*y, (1.-ca) * y*z + sa*x, (1.-ca) * z*z + ca, 0.),
+ (0., 0., 0., 1.)), dtype=numpy.float32)
+
+
+def mat4RotateFromQuaternion(quaternion):
+ """4x4 rotation matrix from quaternion.
+
+ :param quaternion: Array-like unit quaternion stored as (x, y, z, w)
+ """
+ quaternion = numpy.array(quaternion, copy=True)
+ quaternion /= numpy.linalg.norm(quaternion)
+
+ qx, qy, qz, qw = quaternion
+ return numpy.array((
+ (1. - 2.*(qy**2 + qz**2), 2.*(qx*qy - qw*qz), 2.*(qx*qz + qw*qy), 0.),
+ (2.*(qx*qy + qw*qz), 1. - 2.*(qx**2 + qz**2), 2.*(qy*qz - qw*qx), 0.),
+ (2.*(qx*qz - qw*qy), 2.*(qy*qz + qw*qx), 1. - 2.*(qx**2 + qy**2), 0.),
+ (0., 0., 0., 1.)), dtype=numpy.float32)
+
+
+def mat4Shear(axis, sx=0., sy=0., sz=0.):
+ """4x4 shear matrix: Skew two axes relative to a third fixed one.
+
+ shearFactor = tan(shearAngle)
+
+ :param str axis: The axis to keep constant and shear against.
+ In 'x', 'y', 'z'.
+ :param float sx: The shear factor for the X axis relative to axis.
+ :param float sy: The shear factor for the Y axis relative to axis.
+ :param float sz: The shear factor for the Z axis relative to axis.
+ """
+ assert axis in ('x', 'y', 'z')
+
+ matrix = numpy.identity(4, dtype=numpy.float32)
+
+ # Make the shear column
+ index = 'xyz'.find(axis)
+ shearcolumn = numpy.array((sx, sy, sz, 0.), dtype=numpy.float32)
+ shearcolumn[index] = 1.
+ matrix[:, index] = shearcolumn
+ return matrix
+
+
+# Transforms ##################################################################
+
+class Transform(event.Notifier):
+
+ def __init__(self, static=False):
+ """Base class for (row-major) 4x4 matrix transforms.
+
+ :param bool static: False (default) to reset cache when changed,
+ True for static matrices.
+ """
+ super(Transform, self).__init__()
+ self._matrix = None
+ self._inverse = None
+ if not static:
+ self.addListener(self._changed) # Listening self for changes
+
+ def __repr__(self):
+ return '%s(%s)' % (self.__class__.__init__,
+ repr(self.getMatrix(copy=False)))
+
+ def inverse(self):
+ """Return the Transform of the inverse.
+
+ The returned Transform is static, it is not updated when this
+ Transform is modified.
+
+ :return: A Transform which is the inverse of this Transform.
+ """
+ return Inverse(self)
+
+ # Matrix
+
+ def _makeMatrix(self):
+ """Override to build matrix"""
+ return numpy.identity(4, dtype=numpy.float32)
+
+ def _makeInverse(self):
+ """Override to build inverse matrix."""
+ return numpy.linalg.inv(self.getMatrix(copy=False))
+
+ def getMatrix(self, copy=True):
+ """The 4x4 matrix of this transform.
+
+ :param bool copy: True (the default) to get a copy of the matrix,
+ False to get the internal matrix, do not modify!
+ :return: 4x4 matrix of this transform.
+ """
+ if self._matrix is None:
+ self._matrix = self._makeMatrix()
+ if copy:
+ return self._matrix.copy()
+ else:
+ return self._matrix
+
+ matrix = property(getMatrix, doc="The 4x4 matrix of this transform.")
+
+ def getInverseMatrix(self, copy=False):
+ """The 4x4 matrix of the inverse of this transform.
+
+ :param bool copy: True (the default) to get a copy of the matrix,
+ False to get the internal matrix, do not modify!
+ :return: 4x4 matrix of the inverse of this transform.
+ """
+ if self._inverse is None:
+ self._inverse = self._makeInverse()
+ if copy:
+ return self._inverse.copy()
+ else:
+ return self._inverse
+
+ inverseMatrix = property(
+ getInverseMatrix,
+ doc="The 4x4 matrix of the inverse of this transform.")
+
+ # Listener
+
+ def _changed(self, source):
+ """Default self listener reseting matrix cache."""
+ self._matrix = None
+ self._inverse = None
+
+ # Multiplication with vectors
+
+ def transformPoints(self, points, direct=True, perspectiveDivide=False):
+ """Apply the transform to an array of points.
+
+ :param points: 2D array of N vectors of 3 or 4 coordinates
+ :param bool direct: Whether to apply the direct (True, the default)
+ or inverse (False) transform.
+ :param bool perspectiveDivide: Whether to apply the perspective divide
+ (True) or not (False, the default).
+ :return: The transformed points.
+ :rtype: numpy.ndarray of same shape as points.
+ """
+ if direct:
+ matrix = self.getMatrix(copy=False)
+ else:
+ matrix = self.getInverseMatrix(copy=False)
+
+ points = numpy.array(points, copy=False)
+ assert points.ndim == 2
+
+ points = numpy.transpose(points)
+
+ dimension = points.shape[0]
+ assert dimension in (3, 4)
+
+ if dimension == 3: # Add 4th coordinate
+ points = numpy.append(
+ points,
+ numpy.ones((1, points.shape[1]), dtype=points.dtype),
+ axis=0)
+
+ result = numpy.transpose(numpy.dot(matrix, points))
+
+ if perspectiveDivide:
+ mask = result[:, 3] != 0.
+ result[mask] /= result[mask, 3][:, numpy.newaxis]
+
+ return result[:, :3] if dimension == 3 else result
+
+ @staticmethod
+ def _prepareVector(vector, w):
+ """Add 4th coordinate (w) to vector if missing."""
+ assert len(vector) in (3, 4)
+ vector = numpy.array(vector, copy=False, dtype=numpy.float32)
+ if len(vector) == 3:
+ vector = numpy.append(vector, w)
+ return vector
+
+ def transformPoint(self, point, direct=True, perspectiveDivide=False):
+ """Apply the transform to a point.
+
+ :param point: Array-like vector of 3 or 4 coordinates.
+ :param bool direct: Whether to apply the direct (True, the default)
+ or inverse (False) transform.
+ :param bool perspectiveDivide: Whether to apply the perspective divide
+ (True) or not (False, the default).
+ :return: The transformed point.
+ :rtype: numpy.ndarray of same length as point.
+ """
+ if direct:
+ matrix = self.getMatrix(copy=False)
+ else:
+ matrix = self.getInverseMatrix(copy=False)
+ result = numpy.dot(matrix, self._prepareVector(point, 1.))
+
+ if perspectiveDivide and result[3] != 0.:
+ result /= result[3]
+
+ if len(point) == 3:
+ return result[:3]
+ else:
+ return result
+
+ def transformDir(self, direction, direct=True):
+ """Apply the transform to a direction.
+
+ :param direction: Array-like vector of 3 coordinates.
+ :param bool direct: Whether to apply the direct (True, the default)
+ or inverse (False) transform.
+ :return: The transformed direction.
+ :rtype: numpy.ndarray of length 3.
+ """
+ if direct:
+ matrix = self.getMatrix(copy=False)
+ else:
+ matrix = self.getInverseMatrix(copy=False)
+ return numpy.dot(matrix[:3, :3], direction[:3])
+
+ def transformNormal(self, normal, direct=True):
+ """Apply the transform to a normal: R = (M-1)t * V.
+
+ :param normal: Array-like vector of 3 coordinates.
+ :param bool direct: Whether to apply the direct (True, the default)
+ or inverse (False) transform.
+ :return: The transformed normal.
+ :rtype: numpy.ndarray of length 3.
+ """
+ if direct:
+ matrix = self.getInverseMatrix(copy=False).T
+ else:
+ matrix = self.getMatrix(copy=False).T
+ return numpy.dot(matrix[:3, :3], normal[:3])
+
+ _CUBE_CORNERS = numpy.array(list(itertools.product((0., 1.), repeat=3)),
+ dtype=numpy.float32)
+ """Unit cube corners used by :meth:`transformBounds`"""
+
+ def transformBounds(self, bounds, direct=True):
+ """Apply the transform to an axes-aligned rectangular box.
+
+ :param bounds: Min and max coords of the box for each axes.
+ :type bounds: 2x3 numpy.ndarray
+ :param bool direct: Whether to apply the direct (True, the default)
+ or inverse (False) transform.
+ :return: Axes-aligned rectangular box including the transformed box.
+ :rtype: 2x3 numpy.ndarray of float32
+ """
+ corners = numpy.ones((8, 4), dtype=numpy.float32)
+ corners[:, :3] = bounds[0] + \
+ self._CUBE_CORNERS * (bounds[1] - bounds[0])
+
+ if direct:
+ matrix = self.getMatrix(copy=False)
+ else:
+ matrix = self.getInverseMatrix(copy=False)
+
+ # Transform corners
+ cornerstransposed = numpy.dot(matrix, corners.T)
+ cornerstransposed = cornerstransposed / cornerstransposed[3]
+
+ # Get min/max for each axis
+ transformedbounds = numpy.empty((2, 3), dtype=numpy.float32)
+ transformedbounds[0] = cornerstransposed.T[:, :3].min(axis=0)
+ transformedbounds[1] = cornerstransposed.T[:, :3].max(axis=0)
+
+ return transformedbounds
+
+
+class Inverse(Transform):
+ """Transform which is the inverse of another one.
+
+ Static: It never gets updated.
+ """
+
+ def __init__(self, transform):
+ """Initializer.
+
+ :param Transform transform: The transform to invert.
+ """
+
+ super(Inverse, self).__init__(static=True)
+ self._matrix = transform.getInverseMatrix(copy=True)
+ self._inverse = transform.getMatrix(copy=True)
+
+
+class TransformList(Transform, event.HookList):
+ """List of transforms."""
+
+ def __init__(self, iterable=()):
+ Transform.__init__(self)
+ event.HookList.__init__(self, iterable)
+
+ def _listWillChangeHook(self, methodName, *args, **kwargs):
+ for item in self:
+ item.removeListener(self._transformChanged)
+
+ def _listWasChangedHook(self, methodName, *args, **kwargs):
+ for item in self:
+ item.addListener(self._transformChanged)
+ self.notify()
+
+ def _transformChanged(self, source):
+ """Listen to transform changes of the list and its items."""
+ if source is not self: # Avoid infinite recursion
+ self.notify()
+
+ def _makeMatrix(self):
+ matrix = numpy.identity(4, dtype=numpy.float32)
+ for transform in self:
+ matrix = numpy.dot(matrix, transform.getMatrix(copy=False))
+ return matrix
+
+
+class StaticTransformList(Transform):
+ """Transform that is a snapshot of a list of Transforms
+
+ It does not keep reference to the list of Transforms.
+
+ :param iterable: Iterable of Transform used for initialization
+ """
+
+ def __init__(self, iterable=()):
+ super(StaticTransformList, self).__init__(static=True)
+ matrix = numpy.identity(4, dtype=numpy.float32)
+ for transform in iterable:
+ matrix = numpy.dot(matrix, transform.getMatrix(copy=False))
+ self._matrix = matrix # Init matrix once
+
+
+# Affine ######################################################################
+
+class Matrix(Transform):
+
+ def __init__(self, matrix=None):
+ """4x4 Matrix.
+
+ :param matrix: 4x4 array-like matrix or None for identity matrix.
+ """
+ super(Matrix, self).__init__(static=True)
+ self.setMatrix(matrix)
+
+ def setMatrix(self, matrix=None):
+ """Update the 4x4 Matrix.
+
+ :param matrix: 4x4 array-like matrix or None for identity matrix.
+ """
+ if matrix is None:
+ self._matrix = numpy.identity(4, dtype=numpy.float32)
+ else:
+ matrix = numpy.array(matrix, copy=True, dtype=numpy.float32)
+ assert matrix.shape == (4, 4)
+ self._matrix = matrix
+ # Reset cached inverse as Transform is declared static
+ self._inverse = None
+ self.notify()
+
+ # Redefined here to add a setter
+ matrix = property(Transform.getMatrix, setMatrix,
+ doc="The 4x4 matrix of this transform.")
+
+
+class Translate(Transform):
+ """4x4 translation matrix."""
+
+ def __init__(self, tx=0., ty=0., tz=0.):
+ super(Translate, self).__init__()
+ self._tx, self._ty, self._tz = 0., 0., 0.
+ self.setTranslate(tx, ty, tz)
+
+ def _makeMatrix(self):
+ return mat4Translate(self.tx, self.ty, self.tz)
+
+ def _makeInverse(self):
+ return mat4Translate(-self.tx, -self.ty, -self.tz)
+
+ @property
+ def tx(self):
+ return self._tx
+
+ @tx.setter
+ def tx(self, tx):
+ self.setTranslate(tx=tx)
+
+ @property
+ def ty(self):
+ return self._ty
+
+ @ty.setter
+ def ty(self, ty):
+ self.setTranslate(ty=ty)
+
+ @property
+ def tz(self):
+ return self._tz
+
+ @tz.setter
+ def tz(self, tz):
+ self.setTranslate(tz=tz)
+
+ @property
+ def translation(self):
+ return numpy.array((self.tx, self.ty, self.tz), dtype=numpy.float32)
+
+ @translation.setter
+ def translation(self, translations):
+ tx, ty, tz = translations
+ self.setTranslate(tx, ty, tz)
+
+ def setTranslate(self, tx=None, ty=None, tz=None):
+ if tx is not None:
+ self._tx = tx
+ if ty is not None:
+ self._ty = ty
+ if tz is not None:
+ self._tz = tz
+ self.notify()
+
+
+class Scale(Transform):
+ """4x4 scale matrix."""
+
+ def __init__(self, sx=1., sy=1., sz=1.):
+ super(Scale, self).__init__()
+ self._sx, self._sy, self._sz = 0., 0., 0.
+ self.setScale(sx, sy, sz)
+
+ def _makeMatrix(self):
+ return mat4Scale(self.sx, self.sy, self.sz)
+
+ def _makeInverse(self):
+ return mat4Scale(1. / self.sx, 1. / self.sy, 1. / self.sz)
+
+ @property
+ def sx(self):
+ return self._sx
+
+ @sx.setter
+ def sx(self, sx):
+ self.setScale(sx=sx)
+
+ @property
+ def sy(self):
+ return self._sy
+
+ @sy.setter
+ def sy(self, sy):
+ self.setScale(sy=sy)
+
+ @property
+ def sz(self):
+ return self._sz
+
+ @sz.setter
+ def sz(self, sz):
+ self.setScale(sz=sz)
+
+ @property
+ def scale(self):
+ return numpy.array((self._sx, self._sy, self._sz), dtype=numpy.float32)
+
+ @scale.setter
+ def scale(self, scales):
+ sx, sy, sz = scales
+ self.setScale(sx, sy, sz)
+
+ def setScale(self, sx=None, sy=None, sz=None):
+ if sx is not None:
+ assert sx != 0.
+ self._sx = sx
+ if sy is not None:
+ assert sy != 0.
+ self._sy = sy
+ if sz is not None:
+ assert sz != 0.
+ self._sz = sz
+ self.notify()
+
+
+class Rotate(Transform):
+
+ def __init__(self, angle=0., ax=0., ay=0., az=1.):
+ """4x4 rotation matrix.
+
+ :param float angle: The rotation angle in degrees.
+ :param float ax: The x coordinate of the rotation axis.
+ :param float ay: The y coordinate of the rotation axis.
+ :param float az: The z coordinate of the rotation axis.
+ """
+ super(Rotate, self).__init__()
+ self._angle = 0.
+ self._axis = None
+ self.setAngleAxis(angle, (ax, ay, az))
+
+ @property
+ def angle(self):
+ """The rotation angle in degrees."""
+ return self._angle
+
+ @angle.setter
+ def angle(self, angle):
+ self.setAngleAxis(angle=angle)
+
+ @property
+ def axis(self):
+ """The normalized rotation axis as a numpy.ndarray."""
+ return self._axis.copy()
+
+ @axis.setter
+ def axis(self, axis):
+ self.setAngleAxis(axis=axis)
+
+ def setAngleAxis(self, angle=None, axis=None):
+ """Update the angle and/or axis of the rotation.
+
+ :param float angle: The rotation angle in degrees.
+ :param axis: Array-like axis vector (3 coordinates).
+ """
+ if angle is not None:
+ self._angle = angle
+ if axis is not None:
+ assert len(axis) == 3
+ axis = numpy.array(axis, copy=True, dtype=numpy.float32)
+ assert axis.size == 3
+ norm = numpy.linalg.norm(axis)
+ if norm == 0.: # No axis, set rotation angle to 0.
+ self._angle = 0.
+ self._axis = numpy.array((0., 0., 1.), dtype=numpy.float32)
+ else:
+ self._axis = axis / norm
+
+ if angle is not None or axis is not None:
+ self.notify()
+
+ @property
+ def quaternion(self):
+ """Rotation unit quaternion as (x, y, z, w).
+
+ Where: ||(x, y, z)|| = sin(angle/2), w = cos(angle/2).
+ """
+ if numpy.linalg.norm(self._axis) == 0.:
+ return numpy.array((0., 0., 0., 1.), dtype=numpy.float32)
+
+ else:
+ quaternion = numpy.empty((4,), dtype=numpy.float32)
+ halfangle = 0.5 * numpy.radians(self.angle)
+ quaternion[0:3] = numpy.sin(halfangle) * self._axis
+ quaternion[3] = numpy.cos(halfangle)
+ return quaternion
+
+ @quaternion.setter
+ def quaternion(self, quaternion):
+ assert len(quaternion) == 4
+
+ # Normalize quaternion
+ quaternion = numpy.array(quaternion, copy=True)
+ quaternion /= numpy.linalg.norm(quaternion)
+
+ # Get angle
+ sinhalfangle = numpy.linalg.norm(quaternion[0:3])
+ coshalfangle = quaternion[3]
+ angle = 2. * numpy.arctan2(sinhalfangle, coshalfangle)
+
+ # Axis will be normalized in setAngleAxis
+ self.setAngleAxis(numpy.degrees(angle), quaternion[0:3])
+
+ def _makeMatrix(self):
+ angle = numpy.radians(self.angle, dtype=numpy.float32)
+ return mat4RotateFromAngleAxis(angle, *self.axis)
+
+ def _makeInverse(self):
+ return numpy.array(self.getMatrix(copy=False).transpose(),
+ copy=True, order='C',
+ dtype=numpy.float32)
+
+
+class Shear(Transform):
+
+ def __init__(self, axis, sx=0., sy=0., sz=0.):
+ """4x4 shear/skew matrix of 2 axes relative to the third one.
+
+ :param str axis: The axis to keep fixed, in 'x', 'y', 'z'
+ :param float sx: The shear factor for the x axis.
+ :param float sy: The shear factor for the y axis.
+ :param float sz: The shear factor for the z axis.
+ """
+ assert axis in ('x', 'y', 'z')
+ super(Shear, self).__init__()
+ self._axis = axis
+ self._factors = sx, sy, sz
+
+ @property
+ def axis(self):
+ """The axis against which other axes are skewed."""
+ return self._axis
+
+ @property
+ def factors(self):
+ """The shear factors: shearFactor = tan(shearAngle)"""
+ return self._factors
+
+ def _makeMatrix(self):
+ return mat4Shear(self.axis, *self.factors)
+
+ def _makeInverse(self):
+ sx, sy, sz = self.factors
+ return mat4Shear(self.axis, -sx, -sy, -sz)
+
+
+# Projection ##################################################################
+
+class _Projection(Transform):
+ """Base class for projection matrix.
+
+ Handles near and far clipping plane values.
+ Subclasses must implement :meth:`_makeMatrix`.
+
+ :param float near: Distance to the near plane.
+ :param float far: Distance to the far plane.
+ :param bool checkDepthExtent: Toggle checks near > 0 and far > near.
+ :param size:
+ Viewport's size used to compute the aspect ratio (width, height).
+ :type size: 2-tuple of float
+ """
+
+ def __init__(self, near, far, checkDepthExtent=False, size=(1., 1.)):
+ super(_Projection, self).__init__()
+ self._checkDepthExtent = checkDepthExtent
+ self._depthExtent = 1, 10
+ self.setDepthExtent(near, far) # set _depthExtent
+ self._size = 1., 1.
+ self.size = size # set _size
+
+ def setDepthExtent(self, near=None, far=None):
+ """Set the extent of the visible area along the viewing direction.
+
+ :param float near: The near clipping plane Z coord.
+ :param float far: The far clipping plane Z coord.
+ """
+ near = float(near) if near is not None else self._depthExtent[0]
+ far = float(far) if far is not None else self._depthExtent[1]
+
+ if self._checkDepthExtent:
+ assert near > 0.
+ assert far > near
+
+ self._depthExtent = near, far
+ self.notify()
+
+ @property
+ def near(self):
+ """Distance to the near plane."""
+ return self._depthExtent[0]
+
+ @near.setter
+ def near(self, near):
+ if near != self.near:
+ self.setDepthExtent(near=near)
+
+ @property
+ def far(self):
+ """Distance to the far plane."""
+ return self._depthExtent[1]
+
+ @far.setter
+ def far(self, far):
+ if far != self.far:
+ self.setDepthExtent(far=far)
+
+ @property
+ def size(self):
+ """Viewport size as a 2-tuple of float (width, height)."""
+ return self._size
+
+ @size.setter
+ def size(self, size):
+ assert len(size) == 2
+ self._size = tuple(size)
+ self.notify()
+
+
+class Orthographic(_Projection):
+ """Orthographic (i.e., parallel) projection which can keep aspect ratio.
+
+ Clipping planes are adjusted to match the aspect ratio of
+ the :attr:`size` attribute if :attr:`keepaspect` is True.
+
+ In this case, the left, right, bottom and top parameters defines the area
+ which must always remain visible.
+ Effective clipping planes are adjusted to keep the aspect ratio.
+
+ :param float left: Coord of the left clipping plane.
+ :param float right: Coord of the right clipping plane.
+ :param float bottom: Coord of the bottom clipping plane.
+ :param float top: Coord of the top clipping plane.
+ :param float near: Distance to the near plane.
+ :param float far: Distance to the far plane.
+ :param size:
+ Viewport's size used to compute the aspect ratio (width, height).
+ :type size: 2-tuple of float
+ :param bool keepaspect:
+ True (default) to keep aspect ratio, False otherwise.
+ """
+
+ def __init__(self, left=0., right=1., bottom=1., top=0., near=-1., far=1.,
+ size=(1., 1.), keepaspect=True):
+ self._left, self._right = left, right
+ self._bottom, self._top = bottom, top
+ self._keepaspect = bool(keepaspect)
+ super(Orthographic, self).__init__(near, far, checkDepthExtent=False,
+ size=size)
+ # _update called when setting size
+
+ def _makeMatrix(self):
+ return mat4Orthographic(
+ self.left, self.right, self.bottom, self.top, self.near, self.far)
+
+ def _update(self, left, right, bottom, top):
+ if self.keepaspect:
+ width, height = self.size
+ aspect = width / height
+
+ orthoaspect = abs(left - right) / abs(bottom - top)
+
+ if orthoaspect >= aspect: # Keep width, enlarge height
+ newheight = \
+ numpy.sign(top - bottom) * abs(left - right) / aspect
+ bottom = 0.5 * (bottom + top) - 0.5 * newheight
+ top = bottom + newheight
+
+ else: # Keep height, enlarge width
+ newwidth = \
+ numpy.sign(right - left) * abs(bottom - top) * aspect
+ left = 0.5 * (left + right) - 0.5 * newwidth
+ right = left + newwidth
+
+ # Store values
+ self._left, self._right = left, right
+ self._bottom, self._top = bottom, top
+
+ def setClipping(self, left=None, right=None, bottom=None, top=None):
+ """Set the clipping planes of the projection.
+
+ Parameters are adjusted to keep aspect ratio.
+ If a clipping plane coord is not provided, it uses its current value
+
+ :param float left: Coord of the left clipping plane.
+ :param float right: Coord of the right clipping plane.
+ :param float bottom: Coord of the bottom clipping plane.
+ :param float top: Coord of the top clipping plane.
+ """
+ left = float(left) if left is not None else self.left
+ right = float(right) if right is not None else self.right
+ bottom = float(bottom) if bottom is not None else self.bottom
+ top = float(top) if top is not None else self.top
+
+ self._update(left, right, bottom, top)
+ self.notify()
+
+ left = property(lambda self: self._left,
+ doc="Coord of the left clipping plane.")
+
+ right = property(lambda self: self._right,
+ doc="Coord of the right clipping plane.")
+
+ bottom = property(lambda self: self._bottom,
+ doc="Coord of the bottom clipping plane.")
+
+ top = property(lambda self: self._top,
+ doc="Coord of the top clipping plane.")
+
+ @property
+ def size(self):
+ """Viewport size as a 2-tuple of float (width, height)"""
+ return self._size
+
+ @size.setter
+ def size(self, size):
+ assert len(size) == 2
+ size = float(size[0]), float(size[1])
+ if size != self._size:
+ self._size = size
+ self._update(self.left, self.right, self.bottom, self.top)
+ self.notify()
+
+ @property
+ def keepaspect(self):
+ """True to keep aspect ratio, False otherwise."""
+ return self._keepaspect
+
+ @keepaspect.setter
+ def keepaspect(self, aspect):
+ aspect = bool(aspect)
+ if aspect != self._keepaspect:
+ self._keepaspect = aspect
+ self._update(self.left, self.right, self.bottom, self.top)
+ self.notify()
+
+
+class Ortho2DWidget(_Projection):
+ """Orthographic projection with pixel as unit.
+
+ Provides same coordinates as widgets:
+ origin: top left, X axis goes left, Y axis goes down.
+
+ :param float near: Z coordinate of the near clipping plane.
+ :param float far: Z coordinante of the far clipping plane.
+ :param size:
+ Viewport's size used to compute the aspect ratio (width, height).
+ :type size: 2-tuple of float
+ """
+
+ def __init__(self, near=-1., far=1., size=(1., 1.)):
+
+ super(Ortho2DWidget, self).__init__(near, far, size)
+
+ def _makeMatrix(self):
+ width, height = self.size
+ return mat4Orthographic(0., width, height, 0., self.near, self.far)
+
+
+class Perspective(_Projection):
+ """Perspective projection matrix defined by FOV and aspect ratio.
+
+ :param float fovy: Vertical field-of-view in degrees.
+ :param float near: The near clipping plane Z coord (stricly positive).
+ :param float far: The far clipping plane Z coord (> near).
+ :param size:
+ Viewport's size used to compute the aspect ratio (width, height).
+ :type size: 2-tuple of float
+ """
+
+ def __init__(self, fovy=90., near=0.1, far=1., size=(1., 1.)):
+
+ super(Perspective, self).__init__(near, far, checkDepthExtent=True)
+ self._fovy = 90.
+ self.fovy = fovy # Set _fovy
+ self.size = size # Set _ size
+
+ def _makeMatrix(self):
+ width, height = self.size
+ return mat4Perspective(self.fovy, width, height, self.near, self.far)
+
+ @property
+ def fovy(self):
+ """Vertical field-of-view in degrees."""
+ return self._fovy
+
+ @fovy.setter
+ def fovy(self, fovy):
+ self._fovy = float(fovy)
+ self.notify()
diff --git a/src/silx/gui/plot3d/scene/utils.py b/src/silx/gui/plot3d/scene/utils.py
new file mode 100644
index 0000000..c6cd129
--- /dev/null
+++ b/src/silx/gui/plot3d/scene/utils.py
@@ -0,0 +1,662 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+This module provides functions to generate indices, to check intersection
+and to handle planes.
+"""
+
+from __future__ import absolute_import, division, unicode_literals
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "25/07/2016"
+
+
+import logging
+import numpy
+
+from . import event
+
+
+_logger = logging.getLogger(__name__)
+
+
+# numpy #######################################################################
+
+def _uniqueAlongLastAxis(a):
+ """Numpy unique on the last axis of a 2D array
+
+ Implemented here as not in numpy as of writing.
+
+ See adding axis parameter to numpy.unique:
+ https://github.com/numpy/numpy/pull/3584/files#r6225452
+
+ :param array_like a: Input array.
+ :return: Unique elements along the last axis.
+ :rtype: numpy.ndarray
+ """
+ assert len(a.shape) == 2
+
+ # Construct a type over last array dimension to run unique on a 1D array
+ if a.dtype.char in numpy.typecodes['AllInteger']:
+ # Bit-wise comparison of the 2 indices of a line at once
+ # Expect a C contiguous array of shape N, 2
+ uniquedt = numpy.dtype((numpy.void, a.itemsize * a.shape[-1]))
+ elif a.dtype.char in numpy.typecodes['Float']:
+ uniquedt = [('f{i}'.format(i=i), a.dtype) for i in range(a.shape[-1])]
+ else:
+ raise TypeError("Unsupported type {dtype}".format(dtype=a.dtype))
+
+ uniquearray = numpy.unique(numpy.ascontiguousarray(a).view(uniquedt))
+ return uniquearray.view(a.dtype).reshape((-1, a.shape[-1]))
+
+
+# conversions #################################################################
+
+def triangleToLineIndices(triangleIndices, unicity=False):
+ """Generates lines indices from triangle indices.
+
+ This is generating lines indices for the edges of the triangles.
+
+ :param triangleIndices: The indices to draw a set of vertices as triangles.
+ :type triangleIndices: numpy.ndarray
+ :param bool unicity: If True remove duplicated lines,
+ else (the default) returns all lines.
+ :return: The indices to draw the edges of the triangles as lines.
+ :rtype: 1D numpy.ndarray of uint16 or uint32.
+ """
+ # Makes sure indices ar packed by triangle
+ triangleIndices = triangleIndices.reshape(-1, 3)
+
+ # Pack line indices by triangle and by edge
+ lineindices = numpy.empty((len(triangleIndices), 3, 2),
+ dtype=triangleIndices.dtype)
+ lineindices[:, 0] = triangleIndices[:, :2] # edge = t0, t1
+ lineindices[:, 1] = triangleIndices[:, 1:] # edge =t1, t2
+ lineindices[:, 2] = triangleIndices[:, ::2] # edge = t0, t2
+
+ if unicity:
+ lineindices = _uniqueAlongLastAxis(lineindices.reshape(-1, 2))
+
+ # Make sure it is 1D
+ lineindices.shape = -1
+
+ return lineindices
+
+
+def verticesNormalsToLines(vertices, normals, scale=1.):
+ """Return vertices of lines representing normals at given positions.
+
+ :param vertices: Positions of the points.
+ :type vertices: numpy.ndarray with shape: (nbPoints, 3)
+ :param normals: Corresponding normals at the points.
+ :type normals: numpy.ndarray with shape: (nbPoints, 3)
+ :param float scale: The scale factor to apply to normals.
+ :returns: Array of vertices to draw corresponding lines.
+ :rtype: numpy.ndarray with shape: (nbPoints * 2, 3)
+ """
+ linevertices = numpy.empty((len(vertices) * 2, 3), dtype=vertices.dtype)
+ linevertices[0::2] = vertices
+ linevertices[1::2] = vertices + scale * normals
+ return linevertices
+
+
+def unindexArrays(mode, indices, *arrays):
+ """Convert indexed GL primitives to unindexed ones.
+
+ Given indices in arrays and the OpenGL primitive they represent,
+ return the unindexed equivalent.
+
+ :param str mode:
+ Kind of primitive represented by indices.
+ In: points, lines, line_strip, loop, triangles, triangle_strip, fan.
+ :param indices: Indices in other arrays
+ :type indices: numpy.ndarray of dimension 1.
+ :param arrays: Remaining arguments are arrays to convert
+ :return: Converted arrays
+ :rtype: tuple of numpy.ndarray
+ """
+ indices = numpy.array(indices, copy=False)
+
+ assert mode in ('points',
+ 'lines', 'line_strip', 'loop',
+ 'triangles', 'triangle_strip', 'fan')
+
+ if mode in ('lines', 'line_strip', 'loop'):
+ assert len(indices) >= 2
+ elif mode in ('triangles', 'triangle_strip', 'fan'):
+ assert len(indices) >= 3
+
+ assert indices.min() >= 0
+ max_index = indices.max()
+ for data in arrays:
+ assert len(data) >= max_index
+
+ if mode == 'line_strip':
+ unpacked = numpy.empty((2 * (len(indices) - 1),), dtype=indices.dtype)
+ unpacked[0::2] = indices[:-1]
+ unpacked[1::2] = indices[1:]
+ indices = unpacked
+
+ elif mode == 'loop':
+ unpacked = numpy.empty((2 * len(indices),), dtype=indices.dtype)
+ unpacked[0::2] = indices
+ unpacked[1:-1:2] = indices[1:]
+ unpacked[-1] = indices[0]
+ indices = unpacked
+
+ elif mode == 'triangle_strip':
+ unpacked = numpy.empty((3 * (len(indices) - 2),), dtype=indices.dtype)
+ unpacked[0::3] = indices[:-2]
+ unpacked[1::3] = indices[1:-1]
+ unpacked[2::3] = indices[2:]
+ indices = unpacked
+
+ elif mode == 'fan':
+ unpacked = numpy.empty((3 * (len(indices) - 2),), dtype=indices.dtype)
+ unpacked[0::3] = indices[0]
+ unpacked[1::3] = indices[1:-1]
+ unpacked[2::3] = indices[2:]
+ indices = unpacked
+
+ return tuple(numpy.ascontiguousarray(data[indices]) for data in arrays)
+
+
+def triangleStripToTriangles(strip):
+ """Convert a triangle strip to a set of triangles.
+
+ The order of the corners is inverted for odd triangles.
+
+ :param numpy.ndarray strip:
+ Array of triangle corners of shape (N, 3).
+ N must be at least 3.
+ :return: Equivalent triangles corner as an array of shape (N, 3, 3)
+ :rtype: numpy.ndarray
+ """
+ strip = numpy.array(strip).reshape(-1, 3)
+ assert len(strip) >= 3
+
+ triangles = numpy.empty((len(strip) - 2, 3, 3), dtype=strip.dtype)
+ triangles[0::2, 0] = strip[0:-2:2]
+ triangles[0::2, 1] = strip[1:-1:2]
+ triangles[0::2, 2] = strip[2::2]
+
+ triangles[1::2, 0] = strip[3::2]
+ triangles[1::2, 1] = strip[2:-1:2]
+ triangles[1::2, 2] = strip[1:-2:2]
+
+ return triangles
+
+
+def trianglesNormal(positions):
+ """Return normal for each triangle.
+
+ :param positions: Serie of triangle's corners
+ :type positions: numpy.ndarray of shape (NbTriangles*3, 3)
+ :return: Normals corresponding to each position.
+ :rtype: numpy.ndarray of shape (NbTriangles, 3)
+ """
+ assert positions.ndim == 2
+ assert positions.shape[1] == 3
+
+ positions = numpy.array(positions, copy=False).reshape(-1, 3, 3)
+
+ normals = numpy.cross(positions[:, 1] - positions[:, 0],
+ positions[:, 2] - positions[:, 0])
+
+ # Normalize normals
+ norms = numpy.linalg.norm(normals, axis=1)
+ norms[norms == 0] = 1
+
+ return normals / norms.reshape(-1, 1)
+
+
+# grid ########################################################################
+
+def gridVertices(dim0Array, dim1Array, dtype):
+ """Generate an array of 2D positions from 2 arrays of 1D coordinates.
+
+ :param dim0Array: 1D array-like of coordinates along the first dimension.
+ :param dim1Array: 1D array-like of coordinates along the second dimension.
+ :param numpy.dtype dtype: Data type of the output array.
+ :return: Array of grid coordinates.
+ :rtype: numpy.ndarray with shape: (len(dim0Array), len(dim1Array), 2)
+ """
+ grid = numpy.empty((len(dim0Array), len(dim1Array), 2), dtype=dtype)
+ grid.T[0, :, :] = dim0Array
+ grid.T[1, :, :] = numpy.array(dim1Array, copy=False)[:, None]
+ return grid
+
+
+def triangleStripGridIndices(dim0, dim1):
+ """Generate indices to draw a grid of vertices as a triangle strip.
+
+ Vertices are expected to be stored as row-major (i.e., C contiguous).
+
+ :param int dim0: The number of rows of vertices.
+ :param int dim1: The number of columns of vertices.
+ :return: The vertex indices
+ :rtype: 1D numpy.ndarray of uint32
+ """
+ assert dim0 >= 2
+ assert dim1 >= 2
+
+ # Filling a row of squares +
+ # an index before and one after for degenerated triangles
+ indices = numpy.empty((dim0 - 1, 2 * (dim1 + 1)), dtype=numpy.uint32)
+
+ # Init indices with minimum indices for each row of squares
+ indices[:] = (dim1 * numpy.arange(dim0 - 1, dtype=numpy.uint32))[:, None]
+
+ # Update indices with offset per row of squares
+ offset = numpy.arange(dim1, dtype=numpy.uint32)
+ indices[:, 1:-1:2] += offset
+ offset += dim1
+ indices[:, 2::2] += offset
+ indices[:, -1] += offset[-1]
+
+ # Remove extra indices for degenerated triangles before returning
+ return indices.ravel()[1:-1]
+
+ # Alternative:
+ # indices = numpy.zeros(2 * dim1 * (dim0 - 1) + 2 * (dim0 - 2),
+ # dtype=numpy.uint32)
+ #
+ # offset = numpy.arange(dim1, dtype=numpy.uint32)
+ # for d0Index in range(dim0 - 1):
+ # start = 2 * d0Index * (dim1 + 1)
+ # end = start + 2 * dim1
+ # if d0Index != 0:
+ # indices[start - 2] = offset[-1]
+ # indices[start - 1] = offset[0]
+ # indices[start:end:2] = offset
+ # offset += dim1
+ # indices[start + 1:end:2] = offset
+ # return indices
+
+
+def linesGridIndices(dim0, dim1):
+ """Generate indices to draw a grid of vertices as lines.
+
+ Vertices are expected to be stored as row-major (i.e., C contiguous).
+
+ :param int dim0: The number of rows of vertices.
+ :param int dim1: The number of columns of vertices.
+ :return: The vertex indices.
+ :rtype: 1D numpy.ndarray of uint32
+ """
+ # Horizontal and vertical lines
+ nbsegmentalongdim1 = 2 * (dim1 - 1)
+ nbsegmentalongdim0 = 2 * (dim0 - 1)
+
+ indices = numpy.empty(nbsegmentalongdim1 * dim0 +
+ nbsegmentalongdim0 * dim1,
+ dtype=numpy.uint32)
+
+ # Line indices over dim0
+ onedim1line = (numpy.arange(nbsegmentalongdim1,
+ dtype=numpy.uint32) + 1) // 2
+ indices[:dim0 * nbsegmentalongdim1] = \
+ (dim1 * numpy.arange(dim0, dtype=numpy.uint32)[:, None] +
+ onedim1line[None, :]).ravel()
+
+ # Line indices over dim1
+ onedim0line = (numpy.arange(nbsegmentalongdim0,
+ dtype=numpy.uint32) + 1) // 2
+ indices[dim0 * nbsegmentalongdim1:] = \
+ (numpy.arange(dim1, dtype=numpy.uint32)[:, None] +
+ dim1 * onedim0line[None, :]).ravel()
+
+ return indices
+
+
+# intersection ################################################################
+
+def angleBetweenVectors(refVector, vectors, norm=None):
+ """Return the angle between 2 vectors.
+
+ :param refVector: Coordinates of the reference vector.
+ :type refVector: numpy.ndarray of shape: (NCoords,)
+ :param vectors: Coordinates of the vector(s) to get angle from reference.
+ :type vectors: numpy.ndarray of shape: (NCoords,) or (NbVector, NCoords)
+ :param norm: A direction vector giving an orientation to the angles
+ or None.
+ :returns: The angles in radians in [0, pi] if norm is None
+ else in [0, 2pi].
+ :rtype: float or numpy.ndarray of shape (NbVectors,)
+ """
+ singlevector = len(vectors.shape) == 1
+ if singlevector: # Make it a 2D array for the computation
+ vectors = vectors.reshape(1, -1)
+
+ assert len(refVector.shape) == 1
+ assert len(vectors.shape) == 2
+ assert len(refVector) == vectors.shape[1]
+
+ # Normalize vectors
+ refVector /= numpy.linalg.norm(refVector)
+ vectors = numpy.array([v / numpy.linalg.norm(v) for v in vectors])
+
+ dots = numpy.sum(refVector * vectors, axis=-1)
+ angles = numpy.arccos(numpy.clip(dots, -1., 1.))
+ if norm is not None:
+ signs = numpy.sum(norm * numpy.cross(refVector, vectors), axis=-1) < 0.
+ angles[signs] = numpy.pi * 2. - angles[signs]
+
+ return angles[0] if singlevector else angles
+
+
+def segmentPlaneIntersect(s0, s1, planeNorm, planePt):
+ """Compute the intersection of a segment with a plane.
+
+ :param s0: First end of the segment
+ :type s0: 1D numpy.ndarray-like of length 3
+ :param s1: Second end of the segment
+ :type s1: 1D numpy.ndarray-like of length 3
+ :param planeNorm: Normal vector of the plane.
+ :type planeNorm: numpy.ndarray of shape: (3,)
+ :param planePt: A point of the plane.
+ :type planePt: numpy.ndarray of shape: (3,)
+ :return: The intersection points. The number of points goes
+ from 0 (no intersection) to 2 (segment in the plane)
+ :rtype: list of numpy.ndarray
+ """
+ s0, s1 = numpy.asarray(s0), numpy.asarray(s1)
+
+ segdir = s1 - s0
+ dotnormseg = numpy.dot(planeNorm, segdir)
+ if dotnormseg == 0:
+ # line and plane are parallels
+ if numpy.dot(planeNorm, planePt - s0) == 0: # segment is in plane
+ return [s0, s1]
+ else: # No intersection
+ return []
+
+ alpha = - numpy.dot(planeNorm, s0 - planePt) / dotnormseg
+ if 0. <= alpha <= 1.: # Intersection with segment
+ return [s0 + alpha * segdir]
+ else: # intersection outside segment
+ return []
+
+
+def boxPlaneIntersect(boxVertices, boxLineIndices, planeNorm, planePt):
+ """Return intersection points between a box and a plane.
+
+ :param boxVertices: Position of the corners of the box.
+ :type boxVertices: numpy.ndarray with shape: (8, 3)
+ :param boxLineIndices: Indices of the box edges.
+ :type boxLineIndices: numpy.ndarray-like with shape: (12, 2)
+ :param planeNorm: Normal vector of the plane.
+ :type planeNorm: numpy.ndarray of shape: (3,)
+ :param planePt: A point of the plane.
+ :type planePt: numpy.ndarray of shape: (3,)
+ :return: The found intersection points
+ :rtype: numpy.ndarray with 2 dimensions
+ """
+ segments = numpy.take(boxVertices, boxLineIndices, axis=0)
+
+ points = set() # Gather unique intersection points
+ for seg in segments:
+ for point in segmentPlaneIntersect(seg[0], seg[1], planeNorm, planePt):
+ points.add(tuple(point))
+ points = numpy.array(list(points))
+
+ if len(points) <= 2:
+ return numpy.array(())
+ elif len(points) == 3:
+ return points
+ else: # len(points) > 3
+ # Order point to have a polyline lying on the unit cube's faces
+ vectors = points - numpy.mean(points, axis=0)
+ angles = angleBetweenVectors(vectors[0], vectors, planeNorm)
+ points = numpy.take(points, numpy.argsort(angles), axis=0)
+ return points
+
+
+def clipSegmentToBounds(segment, bounds):
+ """Clip segment to volume aligned with axes.
+
+ :param numpy.ndarray segment: (p0, p1)
+ :param numpy.ndarray bounds: (lower corner, upper corner)
+ :return: Either clipped (p0, p1) or None if outside volume
+ :rtype: Union[None,List[numpy.ndarray]]
+ """
+ segment = numpy.array(segment, copy=False)
+ bounds = numpy.array(bounds, copy=False)
+
+ p0, p1 = segment
+ # Get intersection points of ray with volume boundary planes
+ # Line equation: P = offset * delta + p0
+ delta = p1 - p0
+ deltaNotZero = numpy.array(delta, copy=True)
+ deltaNotZero[deltaNotZero == 0] = numpy.nan # Invalidated to avoid division by zero
+ offsets = ((bounds - p0) / deltaNotZero).reshape(-1)
+ points = offsets.reshape(-1, 1) * delta + p0
+
+ # Avoid precision errors by using bounds value
+ points.shape = 2, 3, 3 # Reshape 1 point per bound value
+ for dim in range(3):
+ points[:, dim, dim] = bounds[:, dim]
+ points.shape = -1, 3 # Set back to 2D array
+
+ # Find intersection points that are included in the volume
+ mask = numpy.logical_and(numpy.all(bounds[0] <= points, axis=1),
+ numpy.all(points <= bounds[1], axis=1))
+ intersections = numpy.unique(offsets[mask])
+ if len(intersections) != 2:
+ return None
+
+ intersections.sort()
+ # Do p1 first as p0 is need to compute it
+ if intersections[1] < 1: # clip p1
+ segment[1] = intersections[1] * delta + p0
+ if intersections[0] > 0: # clip p0
+ segment[0] = intersections[0] * delta + p0
+ return segment
+
+
+def segmentVolumeIntersect(segment, nbins):
+ """Get bin indices intersecting with segment
+
+ It should work with N dimensions.
+ Coordinate convention (z, y, x) or (x, y, z) should not matter
+ as long as segment and nbins are consistent.
+
+ :param numpy.ndarray segment:
+ Segment end points as a 2xN array of coordinates
+ :param numpy.ndarray nbins:
+ Shape of the volume with same coordinates order as segment
+ :return: List of bins indices as a 2D array or None if no bins
+ :rtype: Union[None,numpy.ndarray]
+ """
+ segment = numpy.asarray(segment)
+ nbins = numpy.asarray(nbins)
+
+ assert segment.ndim == 2
+ assert segment.shape[0] == 2
+ assert nbins.ndim == 1
+ assert segment.shape[1] == nbins.size
+
+ dim = len(nbins)
+
+ bounds = numpy.array((numpy.zeros_like(nbins), nbins))
+ segment = clipSegmentToBounds(segment, bounds)
+ if segment is None:
+ return None # Segment outside volume
+ p0, p1 = segment
+
+ # Get intersections
+
+ # Get coordinates of bin edges crossing the segment
+ clipped = numpy.ceil(numpy.clip(segment, 0, nbins))
+ start = numpy.min(clipped, axis=0)
+ stop = numpy.max(clipped, axis=0) # stop is NOT included
+ edgesByDim = [numpy.arange(start[i], stop[i]) for i in range(dim)]
+
+ # Line equation: P = t * delta + p0
+ delta = p1 - p0
+
+ # Get bin edge/line intersections as sorted points along the line
+ # Get corresponding line parameters
+ t = []
+ if numpy.all(0 <= p0) and numpy.all(p0 <= nbins):
+ t.append([0.]) # p0 within volume, add it
+ t += [(edgesByDim[i] - p0[i]) / delta[i] for i in range(dim) if delta[i] != 0]
+ if numpy.all(0 <= p1) and numpy.all(p1 <= nbins):
+ t.append([1.]) # p1 within volume, add it
+ t = numpy.concatenate(t)
+ t.sort(kind='mergesort')
+
+ # Remove duplicates
+ unique = numpy.ones((len(t),), dtype=bool)
+ numpy.not_equal(t[1:], t[:-1], out=unique[1:])
+ t = t[unique]
+
+ if len(t) < 2:
+ return None # Not enough intersection points
+
+ # bin edges/line intersection points
+ points = t.reshape(-1, 1) * delta + p0
+ centers = (points[:-1] + points[1:]) / 2.
+ bins = numpy.floor(centers).astype(numpy.int64)
+ return bins
+
+
+# Plane #######################################################################
+
+class Plane(event.Notifier):
+ """Object handling a plane and notifying plane changes.
+
+ :param point: A point on the plane.
+ :type point: 3-tuple of float.
+ :param normal: Normal of the plane.
+ :type normal: 3-tuple of float.
+ """
+
+ def __init__(self, point=(0., 0., 0.), normal=(0., 0., 1.)):
+ super(Plane, self).__init__()
+
+ assert len(point) == 3
+ self._point = numpy.array(point, copy=True, dtype=numpy.float32)
+ assert len(normal) == 3
+ self._normal = numpy.array(normal, copy=True, dtype=numpy.float32)
+ self.notify()
+
+ def setPlane(self, point=None, normal=None):
+ """Set plane point and normal and notify.
+
+ :param point: A point on the plane.
+ :type point: 3-tuple of float or None.
+ :param normal: Normal of the plane.
+ :type normal: 3-tuple of float or None.
+ """
+ planechanged = False
+
+ if point is not None:
+ assert len(point) == 3
+ point = numpy.array(point, copy=True, dtype=numpy.float32)
+ if not numpy.all(numpy.equal(self._point, point)):
+ self._point = point
+ planechanged = True
+
+ if normal is not None:
+ assert len(normal) == 3
+ normal = numpy.array(normal, copy=True, dtype=numpy.float32)
+
+ norm = numpy.linalg.norm(normal)
+ if norm != 0.:
+ normal /= norm
+
+ if not numpy.all(numpy.equal(self._normal, normal)):
+ self._normal = normal
+ planechanged = True
+
+ if planechanged:
+ _logger.debug('Plane updated:\n\tpoint: %s\n\tnormal: %s',
+ str(self._point), str(self._normal))
+ self.notify()
+
+ @property
+ def point(self):
+ """A point on the plane."""
+ return self._point.copy()
+
+ @point.setter
+ def point(self, point):
+ self.setPlane(point=point)
+
+ @property
+ def normal(self):
+ """The (normalized) normal of the plane."""
+ return self._normal.copy()
+
+ @normal.setter
+ def normal(self, normal):
+ self.setPlane(normal=normal)
+
+ @property
+ def parameters(self):
+ """Plane equation parameters: a*x + b*y + c*z + d = 0."""
+ return numpy.append(self._normal,
+ - numpy.dot(self._point, self._normal))
+
+ @parameters.setter
+ def parameters(self, parameters):
+ assert len(parameters) == 4
+ parameters = numpy.array(parameters, dtype=numpy.float32)
+
+ # Normalize normal
+ norm = numpy.linalg.norm(parameters[:3])
+ if norm != 0:
+ parameters /= norm
+
+ normal = parameters[:3]
+ point = - parameters[3] * normal
+ self.setPlane(point, normal)
+
+ @property
+ def isPlane(self):
+ """True if a plane is defined (i.e., ||normal|| != 0)."""
+ return numpy.any(self.normal != 0.)
+
+ def move(self, step):
+ """Move the plane of step along the normal."""
+ self.point += step * self.normal
+
+ def segmentIntersection(self, s0, s1):
+ """Compute the plane intersection with segment [s0, s1].
+
+ :param s0: First end of the segment
+ :type s0: 1D numpy.ndarray-like of length 3
+ :param s1: Second end of the segment
+ :type s1: 1D numpy.ndarray-like of length 3
+ :return: The intersection points. The number of points goes
+ from 0 (no intersection) to 2 (segment in the plane)
+ :rtype: list of 1D numpy.ndarray
+ """
+ if not self.isPlane:
+ return []
+ else:
+ return segmentPlaneIntersect(s0, s1, self.normal, self.point)
diff --git a/src/silx/gui/plot3d/scene/viewport.py b/src/silx/gui/plot3d/scene/viewport.py
new file mode 100644
index 0000000..6de640e
--- /dev/null
+++ b/src/silx/gui/plot3d/scene/viewport.py
@@ -0,0 +1,603 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a class to control a viewport on the rendering window.
+
+The :class:`Viewport` describes a Viewport rendering a scene.
+The attribute :attr:`scene` is the root group of the scene tree.
+:class:`RenderContext` handles the current state during rendering.
+"""
+
+from __future__ import absolute_import, division, unicode_literals
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "24/04/2018"
+
+
+import string
+import numpy
+
+from silx.gui.colors import rgba
+
+from ..._glutils import gl
+
+from . import camera
+from . import event
+from . import transform
+from .function import DirectionalLight, ClippingPlane, Fog
+
+
+class RenderContext(object):
+ """Handle a current rendering context.
+
+ An instance of this class is passed to rendering method through
+ the scene during render.
+
+ User should NEVER use an instance of this class beyond the method
+ it is passed to as an argument (i.e., do not keep a reference to it).
+
+ :param Viewport viewport: The viewport doing the rendering.
+ :param Context glContext: The operating system OpenGL context in use.
+ """
+
+ _FRAGMENT_SHADER_SRC = string.Template("""
+ void scene_post(vec4 cameraPosition) {
+ gl_FragColor = $fogCall(gl_FragColor, cameraPosition);
+ }
+ """)
+
+ def __init__(self, viewport, glContext):
+ self._viewport = viewport
+ self._glContext = glContext
+ self._transformStack = [viewport.camera.extrinsic]
+ self._clipPlane = ClippingPlane(normal=(0., 0., 0.))
+
+ # cache
+ self.__cache = {}
+
+ def cache(self, key, factory, *args, **kwargs):
+ """Lazy-loading cache to store values in the context for rendering
+
+ :param key: The key to retrieve
+ :param factory: A callback taking args and kwargs as arguments
+ and returning the value to store.
+ :return: The stored or newly allocated value
+ """
+ if key not in self.__cache:
+ self.__cache[key] = factory(*args, **kwargs)
+ return self.__cache[key]
+
+ @property
+ def viewport(self):
+ """Viewport doing the current rendering"""
+ return self._viewport
+
+ @property
+ def glCtx(self):
+ """The OpenGL context in use"""
+ return self._glContext
+
+ @property
+ def objectToCamera(self):
+ """The current transform from object to camera coords.
+
+ Do not modify.
+ """
+ return self._transformStack[-1]
+
+ @property
+ def projection(self):
+ """Projection transform.
+
+ Do not modify.
+ """
+ return self.viewport.camera.intrinsic
+
+ @property
+ def objectToNDC(self):
+ """The transform from object to NDC (this includes projection).
+
+ Do not modify.
+ """
+ return transform.StaticTransformList(
+ (self.projection, self.objectToCamera))
+
+ def pushTransform(self, transform_, multiply=True):
+ """Push a :class:`Transform` on the transform stack.
+
+ :param Transform transform_: The transform to add to the stack.
+ :param bool multiply:
+ True (the default) to multiply with the top of the stack,
+ False to push the transform as is without multiplication.
+ """
+ if multiply:
+ assert len(self._transformStack) >= 1
+ transform_ = transform.StaticTransformList(
+ (self._transformStack[-1], transform_))
+
+ self._transformStack.append(transform_)
+
+ def popTransform(self):
+ """Pop the transform on top of the stack.
+
+ :return: The Transform that is popped from the stack.
+ """
+ assert len(self._transformStack) > 1
+ return self._transformStack.pop()
+
+ @property
+ def clipper(self):
+ """The current clipping plane (ClippingPlane)"""
+ return self._clipPlane
+
+ def setClipPlane(self, point=(0., 0., 0.), normal=(0., 0., 0.)):
+ """Set the clipping plane to use
+
+ For now only handles a single clipping plane.
+
+ :param point: A point of the plane
+ :type point: 3-tuple of float
+ :param normal: Normal vector of the plane or (0, 0, 0) for no clipping
+ :type normal: 3-tuple of float
+ """
+ self._clipPlane = ClippingPlane(point, normal)
+
+ def setupProgram(self, program):
+ """Sets-up uniforms of a program using the context shader functions.
+
+ :param GLProgram program: The program to set-up.
+ It MUST be in use and using the context function.
+ """
+ self.clipper.setupProgram(self, program)
+ self.viewport.fog.setupProgram(self, program)
+
+ @property
+ def fragDecl(self):
+ """Fragment shader declaration for scene shader functions"""
+ return '\n'.join((
+ self.clipper.fragDecl,
+ self.viewport.fog.fragDecl,
+ self._FRAGMENT_SHADER_SRC.substitute(
+ fogCall=self.viewport.fog.fragCall)))
+
+ @property
+ def fragCallPre(self):
+ """Fragment shader call for scene shader functions (to do first)
+
+ It takes the camera position (vec4) as argument.
+ """
+ return self.clipper.fragCall
+
+ @property
+ def fragCallPost(self):
+ """Fragment shader call for scene shader functions (to do last)
+
+ It takes the camera position (vec4) as argument.
+ """
+ return "scene_post"
+
+
+class Viewport(event.Notifier):
+ """Rendering a single scene through a camera in part of a framebuffer.
+
+ :param int framebuffer: The framebuffer ID this viewport is rendering into
+ """
+
+ def __init__(self, framebuffer=0):
+ from . import Group # Here to avoid cyclic import
+ super(Viewport, self).__init__()
+ self._dirty = True
+ self._origin = 0, 0
+ self._size = 1, 1
+ self._framebuffer = int(framebuffer)
+ self.scene = Group() # The stuff to render, add overlaid scenes?
+ self.scene._setParent(self)
+ self.scene.addListener(self._changed)
+ self._background = 0., 0., 0., 1.
+ self._camera = camera.Camera(fovy=30., near=1., far=100.,
+ position=(0., 0., 12.))
+ self._camera.addListener(self._changed)
+ self._transforms = transform.TransformList([self._camera])
+
+ self._light = DirectionalLight(direction=(0., 0., -1.),
+ ambient=(0.3, 0.3, 0.3),
+ diffuse=(0.7, 0.7, 0.7))
+ self._light.addListener(self._changed)
+ self._fog = Fog()
+ self._fog.isOn = False
+ self._fog.addListener(self._changed)
+
+ @property
+ def transforms(self):
+ """Proxy of camera transforms.
+
+ Do not modify the list.
+ """
+ return self._transforms
+
+ def _changed(self, *args, **kwargs):
+ """Callback handling scene updates"""
+ self._dirty = True
+ self.notify()
+
+ @property
+ def dirty(self):
+ """True if scene is dirty and needs redisplay."""
+ return self._dirty
+
+ def resetDirty(self):
+ """Mark the scene as not being dirty.
+
+ To call after rendering.
+ """
+ self._dirty = False
+
+ @property
+ def background(self):
+ """Viewport's background color (4-tuple of float in [0, 1] or None)
+
+ The background color is used to clear to viewport.
+ If None, the viewport is not cleared
+ """
+ return self._background
+
+ @background.setter
+ def background(self, color):
+ if color is not None:
+ color = rgba(color)
+ if self._background != color:
+ self._background = color
+ self._changed()
+
+ @property
+ def camera(self):
+ """The camera used to render the scene."""
+ return self._camera
+
+ @property
+ def light(self):
+ """The light used to render the scene."""
+ return self._light
+
+ @property
+ def fog(self):
+ """The fog function used to render the scene"""
+ return self._fog
+
+ @property
+ def origin(self):
+ """Origin (ox, oy) of the viewport in pixels"""
+ return self._origin
+
+ @origin.setter
+ def origin(self, origin):
+ ox, oy = origin
+ origin = int(ox), int(oy)
+ if origin != self._origin:
+ self._origin = origin
+ self._changed()
+
+ @property
+ def size(self):
+ """Size (width, height) of the viewport in pixels"""
+ return self._size
+
+ @size.setter
+ def size(self, size):
+ w, h = size
+ size = int(w), int(h)
+ if size != self._size:
+ self._size = size
+
+ self.camera.intrinsic.size = size
+ self._changed()
+
+ @property
+ def shape(self):
+ """Shape (height, width) of the viewport in pixels.
+
+ This is a convenient wrapper to the inverse of size.
+ """
+ return self._size[1], self._size[0]
+
+ @shape.setter
+ def shape(self, shape):
+ self.size = shape[1], shape[0]
+
+ @property
+ def framebuffer(self):
+ """The framebuffer ID this viewport is rendering into (int)"""
+ return self._framebuffer
+
+ @framebuffer.setter
+ def framebuffer(self, framebuffer):
+ self._framebuffer = int(framebuffer)
+
+ def render(self, glContext):
+ """Perform the rendering of the viewport
+
+ :param Context glContext: The context used for rendering"""
+ # Get a chance to run deferred delete
+ glContext.cleanGLGarbage()
+
+ # OpenGL set-up: really need to be done once
+ ox, oy = self.origin
+ w, h = self.size
+ gl.glViewport(ox, oy, w, h)
+
+ gl.glEnable(gl.GL_SCISSOR_TEST)
+ gl.glScissor(ox, oy, w, h)
+
+ gl.glEnable(gl.GL_BLEND)
+ gl.glBlendFunc(gl.GL_SRC_ALPHA, gl.GL_ONE_MINUS_SRC_ALPHA)
+
+ gl.glEnable(gl.GL_DEPTH_TEST)
+ gl.glDepthFunc(gl.GL_LEQUAL)
+ gl.glDepthRange(0., 1.)
+
+ # gl.glEnable(gl.GL_POLYGON_OFFSET_FILL)
+ # gl.glPolygonOffset(1., 1.)
+
+ gl.glHint(gl.GL_LINE_SMOOTH_HINT, gl.GL_NICEST)
+ gl.glEnable(gl.GL_LINE_SMOOTH)
+
+ if self.background is None:
+ gl.glClear(gl.GL_STENCIL_BUFFER_BIT |
+ gl.GL_DEPTH_BUFFER_BIT)
+ else:
+ gl.glClearColor(*self.background)
+
+ # Prepare OpenGL
+ gl.glClear(gl.GL_COLOR_BUFFER_BIT |
+ gl.GL_STENCIL_BUFFER_BIT |
+ gl.GL_DEPTH_BUFFER_BIT)
+
+ ctx = RenderContext(self, glContext)
+ self.scene.render(ctx)
+ self.scene.postRender(ctx)
+
+ def adjustCameraDepthExtent(self):
+ """Update camera depth extent to fit the scene bounds.
+
+ Only near and far planes are updated.
+ The scene might still not be fully visible
+ (e.g., if spanning behind the viewpoint with perspective projection).
+ """
+ bounds = self.scene.bounds(transformed=True)
+ if bounds is None:
+ bounds = numpy.array(((0., 0., 0.), (1., 1., 1.)),
+ dtype=numpy.float32)
+ bounds = self.camera.extrinsic.transformBounds(bounds)
+
+ if isinstance(self.camera.intrinsic, transform.Perspective):
+ # This needs to be reworked
+ zbounds = - bounds[:, 2]
+ zextent = max(numpy.fabs(zbounds[0] - zbounds[1]), 0.0001)
+ near = max(zextent / 1000., 0.95 * zbounds[1])
+ far = max(near + 0.1, 1.05 * zbounds[0])
+
+ self.camera.intrinsic.setDepthExtent(near, far)
+ elif isinstance(self.camera.intrinsic, transform.Orthographic):
+ # Makes sure z bounds are included
+ border = max(abs(bounds[:, 2]))
+ self.camera.intrinsic.setDepthExtent(-border, border)
+ else:
+ raise RuntimeError('Unsupported camera', self.camera.intrinsic)
+
+ def resetCamera(self):
+ """Change camera to have the whole scene in the viewing frustum.
+
+ It updates the camera position and depth extent.
+ Camera sight direction and up are not affected.
+ """
+ bounds = self.scene.bounds(transformed=True)
+ if bounds is None:
+ bounds = numpy.array(((0., 0., 0.), (1., 1., 1.)),
+ dtype=numpy.float32)
+ self.camera.resetCamera(bounds)
+
+ def orbitCamera(self, direction, angle=1.):
+ """Rotate the camera around center of the scene.
+
+ :param str direction: Direction of movement relative to image plane.
+ In: 'up', 'down', 'left', 'right'.
+ :param float angle: he angle in degrees of the rotation.
+ """
+ bounds = self.scene.bounds(transformed=True)
+ if bounds is None:
+ bounds = numpy.array(((0., 0., 0.), (1., 1., 1.)),
+ dtype=numpy.float32)
+ center = 0.5 * (bounds[0] + bounds[1])
+ self.camera.orbit(direction, center, angle)
+
+ def moveCamera(self, direction, step=0.1):
+ """Move the camera relative to the image plane.
+
+ :param str direction: Direction relative to image plane.
+ One of: 'up', 'down', 'left', 'right',
+ 'forward', 'backward'.
+ :param float step: The ratio of data to step for each pan.
+ """
+ bounds = self.scene.bounds(transformed=True)
+ if bounds is None:
+ bounds = numpy.array(((0., 0., 0.), (1., 1., 1.)),
+ dtype=numpy.float32)
+ bounds = self.camera.extrinsic.transformBounds(bounds)
+ center = 0.5 * (bounds[0] + bounds[1])
+ ndcCenter = self.camera.intrinsic.transformPoint(
+ center, perspectiveDivide=True)
+
+ step *= 2. # NDC has size 2
+
+ if direction == 'up':
+ ndcCenter[1] -= step
+ elif direction == 'down':
+ ndcCenter[1] += step
+
+ elif direction == 'right':
+ ndcCenter[0] -= step
+ elif direction == 'left':
+ ndcCenter[0] += step
+
+ elif direction == 'forward':
+ ndcCenter[2] += step
+ elif direction == 'backward':
+ ndcCenter[2] -= step
+
+ else:
+ raise ValueError('Unsupported direction: %s' % direction)
+
+ newCenter = self.camera.intrinsic.transformPoint(
+ ndcCenter, direct=False, perspectiveDivide=True)
+
+ self.camera.move(direction, numpy.linalg.norm(newCenter - center))
+
+ def windowToNdc(self, winX, winY, checkInside=True):
+ """Convert position from window to normalized device coordinates.
+
+ If window coordinates are int, they are moved half a pixel
+ to be positioned at the center of pixel.
+
+ :param winX: X window coord, origin left.
+ :param winY: Y window coord, origin top.
+ :param bool checkInside: If True, returns None if position is
+ outside viewport.
+ :return: (x, y) Normalize device coordinates in [-1, 1] or None.
+ Origin center, x to the right, y goes upward.
+ """
+ ox, oy = self._origin
+ width, height = self.size
+
+ # If int, move it to the center of pixel
+ if isinstance(winX, int):
+ winX += 0.5
+ if isinstance(winY, int):
+ winY += 0.5
+
+ x, y = winX - ox, winY - oy
+
+ if checkInside and (x < 0. or x > width or y < 0. or y > height):
+ return None # Out of viewport
+
+ ndcx = 2. * x / float(width) - 1.
+ ndcy = 1. - 2. * y / float(height)
+ return ndcx, ndcy
+
+ def ndcToWindow(self, ndcX, ndcY, checkInside=True):
+ """Convert position from normalized device coordinates (NDC) to window.
+
+ :param float ndcX: X NDC coord.
+ :param float ndcY: Y NDC coord.
+ :param bool checkInside: If True, returns None if position is
+ outside viewport.
+ :return: (x, y) window coordinates or None.
+ Origin top-left, x to the right, y goes downward.
+ """
+ if (checkInside and
+ (ndcX < -1. or ndcX > 1. or ndcY < -1. or ndcY > 1.)):
+ return None # Outside viewport
+
+ ox, oy = self._origin
+ width, height = self.size
+
+ winx = ox + width * 0.5 * (ndcX + 1.)
+ winy = oy + height * 0.5 * (1. - ndcY)
+ return winx, winy
+
+ def _pickNdcZGL(self, x, y, offset=0):
+ """Retrieve depth from depth buffer and return corresponding NDC Z.
+
+ :param int x: In pixels in window coordinates, origin left.
+ :param int y: In pixels in window coordinates, origin top.
+ :param int offset: Number of pixels to look at around the given pixel
+
+ :return: Normalize device Z coordinate of depth in [-1, 1]
+ or None if outside viewport.
+ :rtype: float or None
+ """
+ ox, oy = self._origin
+ width, height = self.size
+
+ x = int(x)
+ y = height - int(y) # Invert y coord
+
+ if x < ox or x > ox + width or y < oy or y > oy + height:
+ # Outside viewport
+ return None
+
+ # Get depth from depth buffer in [0., 1.]
+ # Bind used framebuffer to get depth
+ gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, self.framebuffer)
+
+ if offset == 0: # Fast path
+ # glReadPixels is not GL|ES friendly
+ depth = gl.glReadPixels(
+ x, y, 1, 1, gl.GL_DEPTH_COMPONENT, gl.GL_FLOAT)[0]
+ else:
+ offset = abs(int(offset))
+ size = 2*offset + 1
+ depthPatch = gl.glReadPixels(
+ x - offset, y - offset,
+ size, size,
+ gl.GL_DEPTH_COMPONENT, gl.GL_FLOAT)
+ depthPatch = depthPatch.ravel() # Work in 1D
+
+ # TODO cache sortedIndices to avoid computing it each time
+ # Compute distance of each pixels to the center of the patch
+ offsetToCenter = numpy.arange(- offset, offset + 1, dtype=numpy.float32) ** 2
+ sqDistToCenter = numpy.add.outer(offsetToCenter, offsetToCenter)
+
+ # Use distance to center to sort values from the patch
+ sortedIndices = numpy.argsort(sqDistToCenter.ravel())
+ sortedValues = depthPatch[sortedIndices]
+
+ # Take first depth that is not 1 in the sorted values
+ hits = sortedValues[sortedValues != 1.]
+ depth = 1. if len(hits) == 0 else hits[0]
+
+ gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0)
+
+ # Z in NDC in [-1., 1.]
+ return float(depth) * 2. - 1.
+
+ def _getXZYGL(self, x, y):
+ ndc = self.windowToNdc(x, y)
+ if ndc is None:
+ return None # Outside viewport
+ ndcz = self._pickNdcZGL(x, y)
+ ndcpos = numpy.array((ndc[0], ndc[1], ndcz, 1.), dtype=numpy.float32)
+
+ camerapos = self.camera.intrinsic.transformPoint(
+ ndcpos, direct=False, perspectiveDivide=True)
+
+ scenepos = self.camera.extrinsic.transformPoint(camerapos,
+ direct=False)
+ return scenepos[:3]
+
+ def pick(self, x, y):
+ pass
+ # ndcX, ndcY = self.windowToNdc(x, y)
+ # ndcNearPt = ndcX, ndcY, -1.
+ # ndcFarPT = ndcX, ndcY, 1.
diff --git a/src/silx/gui/plot3d/scene/window.py b/src/silx/gui/plot3d/scene/window.py
new file mode 100644
index 0000000..b92c404
--- /dev/null
+++ b/src/silx/gui/plot3d/scene/window.py
@@ -0,0 +1,433 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a class for Viewports rendering on the screen.
+
+The :class:`Window` renders a list of Viewports in the current framebuffer.
+The rendering can be performed in an off-screen framebuffer that is only
+updated when the scene has changed and not each time Qt is requiring a repaint.
+
+The :class:`Context` and :class:`ContextGL2` represent the operating system
+OpenGL context and handle OpenGL resources.
+"""
+
+from __future__ import absolute_import, division, unicode_literals
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "10/01/2017"
+
+
+import weakref
+import numpy
+
+from ..._glutils import gl
+from ... import _glutils
+
+from . import event
+
+
+class Context(object):
+ """Correspond to an operating system OpenGL context.
+
+ User should NEVER use an instance of this class beyond the method
+ it is passed to as an argument (i.e., do not keep a reference to it).
+
+ :param glContextHandle: System specific OpenGL context handle.
+ """
+
+ def __init__(self, glContextHandle):
+ self._context = glContextHandle
+ self._isCurrent = False
+ self._devicePixelRatio = 1.0
+
+ @property
+ def isCurrent(self):
+ """Whether this OpenGL context is the current one or not."""
+ return self._isCurrent
+
+ def setCurrent(self, isCurrent=True):
+ """Set the state of the OpenGL context to reflect OpenGL state.
+
+ This should not be called from the scene graph, only in the
+ wrapper that handle the OpenGL context to reflect its state.
+
+ :param bool isCurrent: The state of the system OpenGL context.
+ """
+ self._isCurrent = bool(isCurrent)
+
+ @property
+ def devicePixelRatio(self):
+ """Ratio between device and device independent pixels (float)
+
+ This is useful for font rendering.
+ """
+ return self._devicePixelRatio
+
+ @devicePixelRatio.setter
+ def devicePixelRatio(self, ratio):
+ assert ratio > 0
+ self._devicePixelRatio = float(ratio)
+
+ def __enter__(self):
+ self.setCurrent(True)
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.setCurrent(False)
+
+ @property
+ def glContext(self):
+ """The handle to the OpenGL context provided by the system."""
+ return self._context
+
+ def cleanGLGarbage(self):
+ """This is releasing OpenGL resource that are no longer used."""
+ pass
+
+
+class ContextGL2(Context):
+ """Handle a system GL2 context.
+
+ User should NEVER use an instance of this class beyond the method
+ it is passed to as an argument (i.e., do not keep a reference to it).
+
+ :param glContextHandle: System specific OpenGL context handle.
+ """
+ def __init__(self, glContextHandle):
+ super(ContextGL2, self).__init__(glContextHandle)
+
+ self._programs = {} # GL programs already compiled
+ self._vbos = {} # GL Vbos already set
+ self._vboGarbage = [] # Vbos waiting to be discarded
+
+ # programs
+
+ def prog(self, vertexShaderSrc, fragmentShaderSrc, attrib0='position'):
+ """Cache program within context.
+
+ WARNING: No clean-up.
+
+ :param str vertexShaderSrc: Vertex shader source code
+ :param str fragmentShaderSrc: Fragment shader source code
+ :param str attrib0:
+ Attribute's name to bind to position 0 (default: 'position').
+ On some platform, this attribute MUST be active and with an
+ array attached to it in order for the rendering to occur....
+ """
+ assert self.isCurrent
+ key = vertexShaderSrc, fragmentShaderSrc, attrib0
+ program = self._programs.get(key, None)
+ if program is None:
+ program = _glutils.Program(
+ vertexShaderSrc, fragmentShaderSrc, attrib0=attrib0)
+ self._programs[key] = program
+ return program
+
+ # VBOs
+
+ def makeVbo(self, data=None, sizeInBytes=None,
+ usage=None, target=None):
+ """Create a VBO in this context with the data.
+
+ Current limitations:
+
+ - One array per VBO
+ - Do not support sharing VertexBuffer across VboAttrib
+
+ Automatically discards the VBO when the returned
+ :class:`VertexBuffer` istance is deleted.
+
+ :param numpy.ndarray data: 2D array of data to store in VBO or None.
+ :param int sizeInBytes: Size of the VBO or None.
+ It should be <= data.nbytes if both are given.
+ :param usage: OpenGL usage define in VertexBuffer._USAGES.
+ :param target: OpenGL target in VertexBuffer._TARGETS.
+ :return: The VertexBuffer created in this context.
+ """
+ assert self.isCurrent
+ vbo = _glutils.VertexBuffer(data, sizeInBytes, usage, target)
+ vboref = weakref.ref(vbo, self._deadVbo)
+ # weakref is hashable as far as target is
+ self._vbos[vboref] = vbo.name
+ return vbo
+
+ def makeVboAttrib(self, data, usage=None, target=None):
+ """Create a VBO from data and returns the associated VBOAttrib.
+
+ Automatically discards the VBO when the returned
+ :class:`VBOAttrib` istance is deleted.
+
+ :param numpy.ndarray data: 2D array of data to store in VBO or None.
+ :param usage: OpenGL usage define in VertexBuffer._USAGES.
+ :param target: OpenGL target in VertexBuffer._TARGETS.
+ :returns: A VBOAttrib instance created in this context.
+ """
+ assert self.isCurrent
+ vbo = self.makeVbo(data, usage=usage, target=target)
+
+ assert len(data.shape) <= 2
+ dimension = 1 if len(data.shape) == 1 else data.shape[1]
+
+ return _glutils.VertexBufferAttrib(
+ vbo,
+ type_=_glutils.numpyToGLType(data.dtype),
+ size=data.shape[0],
+ dimension=dimension,
+ offset=0,
+ stride=0)
+
+ def _deadVbo(self, vboRef):
+ """Callback handling dead VBOAttribs."""
+ vboid = self._vbos.pop(vboRef)
+ if self.isCurrent:
+ # Direct delete if context is active
+ gl.glDeleteBuffers(vboid)
+ else:
+ # Deferred VBO delete if context is not active
+ self._vboGarbage.append(vboid)
+
+ def cleanGLGarbage(self):
+ """Delete OpenGL resources that are pending for destruction.
+
+ This requires the associated OpenGL context to be active.
+ This is meant to be called before rendering.
+ """
+ assert self.isCurrent
+ if self._vboGarbage:
+ vboids = self._vboGarbage
+ gl.glDeleteBuffers(vboids)
+ self._vboGarbage = []
+
+
+class Window(event.Notifier):
+ """OpenGL Framebuffer where to render viewports
+
+ :param str mode: Rendering mode to use:
+
+ - 'direct' to render everything for each render call
+ - 'framebuffer' to cache viewport rendering in a texture and
+ update the texture only when needed.
+ """
+
+ _position = numpy.array(((-1., -1., 0., 0.),
+ (1., -1., 1., 0.),
+ (-1., 1., 0., 1.),
+ (1., 1., 1., 1.)),
+ dtype=numpy.float32)
+
+ _shaders = ("""
+ attribute vec4 position;
+ varying vec2 textureCoord;
+
+ void main(void) {
+ gl_Position = vec4(position.x, position.y, 0., 1.);
+ textureCoord = position.zw;
+ }
+ """,
+ """
+ uniform sampler2D texture;
+ varying vec2 textureCoord;
+
+ void main(void) {
+ gl_FragColor = texture2D(texture, textureCoord);
+ gl_FragColor.a = 1.0;
+ }
+ """)
+
+ def __init__(self, mode='framebuffer'):
+ super(Window, self).__init__()
+ self._dirty = True
+ self._size = 0, 0
+ self._contexts = {} # To map system GL context id to Context objects
+ self._viewports = event.NotifierList()
+ self._viewports.addListener(self._updated)
+ self._framebufferid = 0
+ self._framebuffers = {} # Cache of framebuffers
+
+ assert mode in ('direct', 'framebuffer')
+ self._isframebuffer = mode == 'framebuffer'
+
+ @property
+ def dirty(self):
+ """True if this object or any attached viewports is dirty."""
+ for viewport in self._viewports:
+ if viewport.dirty:
+ return True
+ return self._dirty
+
+ @property
+ def size(self):
+ """Size (width, height) of the window in pixels"""
+ return self._size
+
+ @size.setter
+ def size(self, size):
+ w, h = size
+ size = int(w), int(h)
+ if size != self._size:
+ self._size = size
+ self._dirty = True
+ self.notify()
+
+ @property
+ def shape(self):
+ """Shape (height, width) of the window in pixels.
+
+ This is a convenient wrapper to the reverse of size.
+ """
+ return self._size[1], self._size[0]
+
+ @shape.setter
+ def shape(self, shape):
+ self.size = shape[1], shape[0]
+
+ @property
+ def viewports(self):
+ """List of viewports to render in the corresponding framebuffer"""
+ return self._viewports
+
+ @viewports.setter
+ def viewports(self, iterable):
+ self._viewports.removeListener(self._updated)
+ self._viewports = event.NotifierList(iterable)
+ self._viewports.addListener(self._updated)
+ self._updated(self)
+
+ def _updated(self, source, *args, **kwargs):
+ self._dirty = True
+ self.notify(*args, **kwargs)
+
+ framebufferid = property(lambda self: self._framebufferid,
+ doc="Framebuffer ID used to perform rendering")
+
+ def grab(self, glcontext):
+ """Returns the raster of the scene as an RGB numpy array
+
+ :returns: OpenGL scene RGB bitmap
+ as an array of dimension (height, width, 3)
+ :rtype: numpy.ndarray of uint8
+ """
+ height, width = self.shape
+ image = numpy.empty((height, width, 3), dtype=numpy.uint8)
+
+ previousFramebuffer = gl.glGetInteger(gl.GL_FRAMEBUFFER_BINDING)
+ gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, self.framebufferid)
+ gl.glPixelStorei(gl.GL_PACK_ALIGNMENT, 1)
+ gl.glReadPixels(
+ 0, 0, width, height, gl.GL_RGB, gl.GL_UNSIGNED_BYTE, image)
+ gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, previousFramebuffer)
+
+ # glReadPixels gives bottom to top,
+ # while images are stored as top to bottom
+ image = numpy.flipud(image)
+
+ return numpy.array(image, copy=False, order='C')
+
+ def render(self, glcontext, devicePixelRatio):
+ """Perform the rendering of attached viewports
+
+ :param glcontext: System identifier of the OpenGL context
+ :param float devicePixelRatio:
+ Ratio between device and device-independent pixels
+ """
+ if self.size == (0, 0):
+ return
+
+ if glcontext not in self._contexts:
+ self._contexts[glcontext] = ContextGL2(glcontext) # New context
+
+ with self._contexts[glcontext] as context:
+ context.devicePixelRatio = devicePixelRatio
+ if self._isframebuffer:
+ self._renderWithOffscreenFramebuffer(context)
+ else:
+ self._renderDirect(context)
+
+ self._dirty = False
+
+ def _renderDirect(self, context):
+ """Perform the direct rendering of attached viewports
+
+ :param Context context: Object wrapping OpenGL context
+ """
+ for viewport in self._viewports:
+ viewport.framebuffer = self.framebufferid
+ viewport.render(context)
+ viewport.resetDirty()
+
+ def _renderWithOffscreenFramebuffer(self, context):
+ """Renders viewports in a texture and render this texture on screen.
+
+ The texture is updated only if viewport or size has changed.
+
+ :param ContextGL2 context: Object wrappign OpenGL context
+ """
+ if self.dirty or context not in self._framebuffers:
+ # Need to redraw framebuffer content
+
+ if (context not in self._framebuffers or
+ self._framebuffers[context].shape != self.shape):
+ # Need to rebuild framebuffer
+
+ if context in self._framebuffers:
+ self._framebuffers[context].discard()
+
+ fbo = _glutils.FramebufferTexture(gl.GL_RGBA,
+ shape=self.shape,
+ minFilter=gl.GL_NEAREST,
+ magFilter=gl.GL_NEAREST,
+ wrap=gl.GL_CLAMP_TO_EDGE)
+ self._framebuffers[context] = fbo
+ self._framebufferid = fbo.name
+
+ # Render in framebuffer
+ with self._framebuffers[context]:
+ self._renderDirect(context)
+
+ # Render framebuffer texture to screen
+ fbo = self._framebuffers[context]
+ height, width = fbo.shape
+
+ program = context.prog(*self._shaders)
+ program.use()
+
+ gl.glViewport(0, 0, width, height)
+ gl.glDisable(gl.GL_BLEND)
+ gl.glDisable(gl.GL_DEPTH_TEST)
+ gl.glDisable(gl.GL_SCISSOR_TEST)
+ # gl.glScissor(0, 0, width, height)
+ gl.glClearColor(0., 0., 0., 0.)
+ gl.glClear(gl.GL_COLOR_BUFFER_BIT)
+ gl.glUniform1i(program.uniforms['texture'], fbo.texture.texUnit)
+ gl.glEnableVertexAttribArray(program.attributes['position'])
+ gl.glVertexAttribPointer(program.attributes['position'],
+ 4,
+ gl.GL_FLOAT,
+ gl.GL_FALSE,
+ 0,
+ self._position)
+ fbo.texture.bind()
+ gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, len(self._position))
+ gl.glBindTexture(gl.GL_TEXTURE_2D, 0)
diff --git a/src/silx/gui/plot3d/setup.py b/src/silx/gui/plot3d/setup.py
new file mode 100644
index 0000000..59c0230
--- /dev/null
+++ b/src/silx/gui/plot3d/setup.py
@@ -0,0 +1,50 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "25/07/2016"
+
+
+from numpy.distutils.misc_util import Configuration
+
+
+def configuration(parent_package='', top_path=None):
+ config = Configuration('plot3d', parent_package, top_path)
+ config.add_subpackage('_model')
+ config.add_subpackage('actions')
+ config.add_subpackage('items')
+ config.add_subpackage('scene')
+ config.add_subpackage('scene.test')
+ config.add_subpackage('tools')
+ config.add_subpackage('tools.test')
+ config.add_subpackage('test')
+ config.add_subpackage('utils')
+ return config
+
+
+if __name__ == "__main__":
+ from numpy.distutils.core import setup
+
+ setup(configuration=configuration)
diff --git a/src/silx/gui/plot3d/test/__init__.py b/src/silx/gui/plot3d/test/__init__.py
new file mode 100644
index 0000000..83491ad
--- /dev/null
+++ b/src/silx/gui/plot3d/test/__init__.py
@@ -0,0 +1,25 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""plot3d test suite."""
diff --git a/src/silx/gui/plot3d/test/testGL.py b/src/silx/gui/plot3d/test/testGL.py
new file mode 100644
index 0000000..a7309a9
--- /dev/null
+++ b/src/silx/gui/plot3d/test/testGL.py
@@ -0,0 +1,73 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+# ###########################################################################*/
+"""Test OpenGL"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "10/08/2017"
+
+
+import logging
+import unittest
+
+from silx.gui._glutils import gl, OpenGLWidget
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui import qt
+
+
+_logger = logging.getLogger(__name__)
+
+
+class TestOpenGL(TestCaseQt):
+ """Tests of OpenGL widget."""
+
+ class OpenGLWidgetLogger(OpenGLWidget):
+ """Widget logging information of available OpenGL version"""
+
+ def __init__(self):
+ self._dump = False
+ super(TestOpenGL.OpenGLWidgetLogger, self).__init__(version=(1, 0))
+
+ def paintOpenGL(self):
+ """Perform the rendering and logging"""
+ if not self._dump:
+ self._dump = True
+ _logger.info('OpenGL info:')
+ _logger.info('\tQt OpenGL context version: %d.%d', *self.getOpenGLVersion())
+ _logger.info('\tGL_VERSION: %s' % gl.glGetString(gl.GL_VERSION))
+ _logger.info('\tGL_SHADING_LANGUAGE_VERSION: %s' %
+ gl.glGetString(gl.GL_SHADING_LANGUAGE_VERSION))
+ _logger.debug('\tGL_EXTENSIONS: %s' % gl.glGetString(gl.GL_EXTENSIONS))
+
+ gl.glClearColor(1., 1., 1., 1.)
+ gl.glClear(gl.GL_COLOR_BUFFER_BIT)
+
+ def testOpenGL(self):
+ """Log OpenGL version using an OpenGLWidget"""
+ super(TestOpenGL, self).setUp()
+ widget = self.OpenGLWidgetLogger()
+ widget.show()
+ widget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.qWaitForWindowExposed(widget)
+ widget.close()
diff --git a/src/silx/gui/plot3d/test/testScalarFieldView.py b/src/silx/gui/plot3d/test/testScalarFieldView.py
new file mode 100644
index 0000000..e6535fc
--- /dev/null
+++ b/src/silx/gui/plot3d/test/testScalarFieldView.py
@@ -0,0 +1,128 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+# ###########################################################################*/
+"""Test ScalarFieldView widget"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "17/01/2018"
+
+
+import logging
+import unittest
+
+import numpy
+
+from silx.utils.testutils import ParametricTestCase
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui import qt
+
+from silx.gui.plot3d.ScalarFieldView import ScalarFieldView
+from silx.gui.plot3d.SFViewParamTree import TreeView
+
+
+_logger = logging.getLogger(__name__)
+
+
+class TestScalarFieldView(TestCaseQt, ParametricTestCase):
+ """Tests of ScalarFieldView widget."""
+
+ def setUp(self):
+ super(TestScalarFieldView, self).setUp()
+ self.widget = ScalarFieldView()
+ self.widget.show()
+
+ paramTreeWidget = TreeView()
+ paramTreeWidget.setSfView(self.widget)
+
+ dock = qt.QDockWidget()
+ dock.setWidget(paramTreeWidget)
+ self.widget.addDockWidget(qt.Qt.BottomDockWidgetArea, dock)
+
+ # Commented as it slows down the tests
+ # self.qWaitForWindowExposed(self.widget)
+
+ def tearDown(self):
+ self.qapp.processEvents()
+ self.widget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.widget.close()
+ del self.widget
+ super(TestScalarFieldView, self).tearDown()
+
+ @staticmethod
+ def _buildData(size):
+ """Make a 3D dataset"""
+ coords = numpy.linspace(-10, 10, size)
+ z = coords.reshape(-1, 1, 1)
+ y = coords.reshape(1, -1, 1)
+ x = coords.reshape(1, 1, -1)
+ return numpy.sin(x * y * z) / (x * y * z)
+
+ def testSimple(self):
+ """Set the data and an isosurface"""
+ data = self._buildData(size=32)
+
+ self.widget.setData(data)
+ self.widget.addIsosurface(0.5, (1., 0., 0., 0.5))
+ self.widget.addIsosurface(0.7, qt.QColor('green'))
+ self.qapp.processEvents()
+
+ def testNotFinite(self):
+ """Test with NaN and inf in data set"""
+
+ # Some NaNs and inf
+ data = self._buildData(size=32)
+ data[8, :, :] = numpy.nan
+ data[16, :, :] = numpy.inf
+ data[24, :, :] = - numpy.inf
+
+ self.widget.addIsosurface(0.5, 'red')
+ self.widget.setData(data, copy=True)
+ self.qapp.processEvents()
+ self.widget.setData(None)
+
+ # All NaNs or inf
+ data = numpy.empty((4, 4, 4), dtype=numpy.float32)
+ for value in (numpy.nan, numpy.inf):
+ with self.subTest(value=str(value)):
+ data[:] = value
+ self.widget.setData(data, copy=True)
+ self.qapp.processEvents()
+
+ def testIsoSliderNormalization(self):
+ """Test set TreeView with a different isoslider normalization"""
+ data = self._buildData(size=32)
+
+ self.widget.setData(data)
+ self.widget.addIsosurface(0.5, (1., 0., 0., 0.5))
+ self.widget.addIsosurface(0.7, qt.QColor('green'))
+ self.qapp.processEvents()
+
+ # Add a second TreeView
+ paramTreeWidget = TreeView(self.widget)
+ paramTreeWidget.setIsoLevelSliderNormalization('arcsinh')
+ paramTreeWidget.setSfView(self.widget)
+
+ dock = qt.QDockWidget()
+ dock.setWidget(paramTreeWidget)
+ self.widget.addDockWidget(qt.Qt.BottomDockWidgetArea, dock)
diff --git a/src/silx/gui/plot3d/test/testSceneWidget.py b/src/silx/gui/plot3d/test/testSceneWidget.py
new file mode 100644
index 0000000..fc96781
--- /dev/null
+++ b/src/silx/gui/plot3d/test/testSceneWidget.py
@@ -0,0 +1,72 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+# ###########################################################################*/
+"""Test SceneWidget"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "06/03/2019"
+
+
+import unittest
+
+import numpy
+
+from silx.utils.testutils import ParametricTestCase
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui import qt
+
+from silx.gui.plot3d.SceneWidget import SceneWidget
+
+
+class TestSceneWidget(TestCaseQt, ParametricTestCase):
+ """Tests SceneWidget picking feature"""
+
+ def setUp(self):
+ super(TestSceneWidget, self).setUp()
+ self.widget = SceneWidget()
+ self.widget.show()
+ self.qWaitForWindowExposed(self.widget)
+
+ def tearDown(self):
+ self.qapp.processEvents()
+ self.widget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.widget.close()
+ del self.widget
+ super(TestSceneWidget, self).tearDown()
+
+ def testFogEffect(self):
+ """Test fog effect on scene primitive"""
+ image = self.widget.addImage(numpy.arange(100).reshape(10, 10))
+ scatter = self.widget.add3DScatter(*numpy.random.random(4000).reshape(4, -1))
+ scatter.setTranslation(10, 10)
+ scatter.setScale(10, 10, 10)
+
+ self.widget.resetZoom('front')
+ self.qapp.processEvents()
+
+ self.widget.setFogMode(self.widget.FogMode.LINEAR)
+ self.qapp.processEvents()
+
+ self.widget.setFogMode(self.widget.FogMode.NONE)
+ self.qapp.processEvents()
diff --git a/src/silx/gui/plot3d/test/testSceneWidgetPicking.py b/src/silx/gui/plot3d/test/testSceneWidgetPicking.py
new file mode 100644
index 0000000..d4d8db7
--- /dev/null
+++ b/src/silx/gui/plot3d/test/testSceneWidgetPicking.py
@@ -0,0 +1,314 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+# ###########################################################################*/
+"""Test SceneWidget picking feature"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "03/10/2018"
+
+
+import unittest
+
+import numpy
+
+from silx.utils.testutils import ParametricTestCase
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui import qt
+
+from silx.gui.plot3d.SceneWidget import SceneWidget, items
+
+
+class TestSceneWidgetPicking(TestCaseQt, ParametricTestCase):
+ """Tests SceneWidget picking feature"""
+
+ def setUp(self):
+ super(TestSceneWidgetPicking, self).setUp()
+ self.widget = SceneWidget()
+ self.widget.resize(300, 300)
+ self.widget.show()
+ # self.qWaitForWindowExposed(self.widget)
+
+ def tearDown(self):
+ self.qapp.processEvents()
+ self.widget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.widget.close()
+ del self.widget
+ super(TestSceneWidgetPicking, self).tearDown()
+
+ def _widgetCenter(self):
+ """Returns widget center"""
+ size = self.widget.size()
+ return size.width() // 2, size.height() // 2
+
+ def testPickImage(self):
+ """Test picking of ImageData and ImageRgba items"""
+ imageData = items.ImageData()
+ imageData.setData(numpy.arange(100).reshape(10, 10))
+
+ imageRgba = items.ImageRgba()
+ imageRgba.setData(
+ numpy.arange(300, dtype=numpy.uint8).reshape(10, 10, 3))
+
+ for item in (imageData, imageRgba):
+ with self.subTest(item=item.__class__.__name__):
+ # Add item
+ self.widget.clearItems()
+ self.widget.addItem(item)
+ self.widget.resetZoom('front')
+ self.qapp.processEvents()
+
+ # Picking on data (at widget center)
+ picking = list(self.widget.pickItems(*self._widgetCenter()))
+
+ self.assertEqual(len(picking), 1)
+ self.assertIs(picking[0].getItem(), item)
+ self.assertEqual(picking[0].getPositions('ndc').shape, (1, 3))
+ data = picking[0].getData()
+ self.assertEqual(len(data), 1)
+ self.assertTrue(numpy.array_equal(
+ data,
+ item.getData()[picking[0].getIndices()]))
+
+ # Picking outside data
+ picking = list(self.widget.pickItems(1, 1))
+ self.assertEqual(len(picking), 0)
+
+ def testPickScatter(self):
+ """Test picking of Scatter2D and Scatter3D items"""
+ data = numpy.arange(100)
+
+ scatter2d = items.Scatter2D()
+ scatter2d.setData(x=data, y=data, value=data)
+
+ scatter3d = items.Scatter3D()
+ scatter3d.setData(x=data, y=data, z=data, value=data)
+
+ for item in (scatter2d, scatter3d):
+ with self.subTest(item=item.__class__.__name__):
+ # Add item
+ self.widget.clearItems()
+ self.widget.addItem(item)
+ self.widget.resetZoom('front')
+ self.qapp.processEvents()
+
+ # Picking on data (at widget center)
+ picking = list(self.widget.pickItems(*self._widgetCenter()))
+
+ self.assertEqual(len(picking), 1)
+ self.assertIs(picking[0].getItem(), item)
+ nbPos = len(picking[0].getPositions('ndc'))
+ data = picking[0].getData()
+ self.assertEqual(nbPos, len(data))
+ self.assertTrue(numpy.array_equal(
+ data,
+ item.getValueData()[picking[0].getIndices()]))
+
+ # Picking outside data
+ picking = list(self.widget.pickItems(1, 1))
+ self.assertEqual(len(picking), 0)
+
+ def testPickVolume(self):
+ """Test picking of volume CutPlane and Isosurface items"""
+ for dtype in (numpy.float32, numpy.complex64):
+ with self.subTest(dtype=dtype):
+ refData = numpy.arange(10**3, dtype=dtype).reshape(10, 10, 10)
+ volume = self.widget.addVolume(refData)
+ if dtype == numpy.complex64:
+ volume.setComplexMode(volume.ComplexMode.REAL)
+ refData = numpy.real(refData)
+ self.widget.resetZoom('front')
+
+ cutplane = volume.getCutPlanes()[0]
+ if dtype == numpy.complex64:
+ cutplane.setComplexMode(volume.ComplexMode.REAL)
+ cutplane.getColormap().setVRange(0, 100)
+ cutplane.setNormal((0, 0, 1))
+
+ # Picking on data without anything displayed
+ cutplane.setVisible(False)
+ picking = list(self.widget.pickItems(*self._widgetCenter()))
+ self.assertEqual(len(picking), 0)
+
+ # Picking on data with the cut plane
+ cutplane.setVisible(True)
+ picking = list(self.widget.pickItems(*self._widgetCenter()))
+
+ self.assertEqual(len(picking), 1)
+ self.assertIs(picking[0].getItem(), cutplane)
+ data = picking[0].getData()
+ self.assertEqual(len(data), 1)
+ self.assertEqual(picking[0].getPositions().shape, (1, 3))
+ self.assertTrue(numpy.array_equal(
+ data,
+ refData[picking[0].getIndices()]))
+
+ # Picking on data with an isosurface
+ isosurface = volume.addIsosurface(
+ level=500, color=(1., 0., 0., .5))
+ picking = list(self.widget.pickItems(*self._widgetCenter()))
+ self.assertEqual(len(picking), 2)
+ self.assertIs(picking[0].getItem(), cutplane)
+ self.assertIs(picking[1].getItem(), isosurface)
+ self.assertEqual(picking[1].getPositions().shape, (1, 3))
+ data = picking[1].getData()
+ self.assertEqual(len(data), 1)
+ self.assertTrue(numpy.array_equal(
+ data,
+ refData[picking[1].getIndices()]))
+
+ # Picking outside data
+ picking = list(self.widget.pickItems(1, 1))
+ self.assertEqual(len(picking), 0)
+
+ self.widget.clearItems()
+
+ def testPickMesh(self):
+ """Test picking of Mesh items"""
+
+ triangles = items.Mesh()
+ triangles.setData(
+ position=((0, 0, 0), (1, 0, 0), (1, 1, 0),
+ (0, 0, 0), (1, 1, 0), (0, 1, 0)),
+ color=(1, 0, 0, 1),
+ mode='triangles')
+ triangleStrip = items.Mesh()
+ triangleStrip.setData(
+ position=(((1, 0, 0), (0, 0, 0), (1, 1, 0), (0, 1, 0))),
+ color=(0, 1, 0, 1),
+ mode='triangle_strip')
+ triangleFan = items.Mesh()
+ triangleFan.setData(
+ position=((0, 0, 0), (1, 0, 0), (1, 1, 0), (0, 1, 0)),
+ color=(0, 0, 1, 1),
+ mode='fan')
+
+ for item in (triangles, triangleStrip, triangleFan):
+ with self.subTest(mode=item.getDrawMode()):
+ # Add item
+ self.widget.clearItems()
+ self.widget.addItem(item)
+ self.widget.resetZoom('front')
+ self.qapp.processEvents()
+
+ # Picking on data (at widget center)
+ picking = list(self.widget.pickItems(*self._widgetCenter()))
+
+ self.assertEqual(len(picking), 1)
+ self.assertIs(picking[0].getItem(), item)
+ nbPos = len(picking[0].getPositions())
+ data = picking[0].getData()
+ self.assertEqual(nbPos, len(data))
+ self.assertTrue(numpy.array_equal(
+ data,
+ item.getPositionData()[picking[0].getIndices()]))
+
+ # Picking outside data
+ picking = list(self.widget.pickItems(1, 1))
+ self.assertEqual(len(picking), 0)
+
+ def testPickMeshWithIndices(self):
+ """Test picking of Mesh items defined by indices"""
+
+ triangles = items.Mesh()
+ triangles.setData(
+ position=((0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0)),
+ color=(1, 0, 0, 1),
+ indices=numpy.array( # dummy triangles and square
+ (0, 0, 1, 0, 1, 2, 1, 2, 3), dtype=numpy.uint8),
+ mode='triangles')
+ triangleStrip = items.Mesh()
+ triangleStrip.setData(
+ position=((0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0)),
+ color=(0, 1, 0, 1),
+ indices=numpy.array( # dummy triangles and square
+ (1, 0, 0, 1, 2, 3), dtype=numpy.uint8),
+ mode='triangle_strip')
+ triangleFan = items.Mesh()
+ triangleFan.setData(
+ position=((0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0)),
+ color=(0, 0, 1, 1),
+ indices=numpy.array( # dummy triangle, square, dummy
+ (1, 1, 0, 2, 3, 3), dtype=numpy.uint8),
+ mode='fan')
+
+ for item in (triangles, triangleStrip, triangleFan):
+ with self.subTest(mode=item.getDrawMode()):
+ # Add item
+ self.widget.clearItems()
+ self.widget.addItem(item)
+ self.widget.resetZoom('front')
+ self.qapp.processEvents()
+
+ # Picking on data (at widget center)
+ picking = list(self.widget.pickItems(*self._widgetCenter()))
+
+ self.assertEqual(len(picking), 1)
+ self.assertIs(picking[0].getItem(), item)
+ nbPos = len(picking[0].getPositions())
+ data = picking[0].getData()
+ self.assertEqual(nbPos, len(data))
+ self.assertTrue(numpy.array_equal(
+ data,
+ item.getPositionData()[picking[0].getIndices()]))
+
+ # Picking outside data
+ picking = list(self.widget.pickItems(1, 1))
+ self.assertEqual(len(picking), 0)
+
+ def testPickCylindricalMesh(self):
+ """Test picking of Box, Cylinder and Hexagon items"""
+
+ positions = numpy.array(((0., 0., 0.), (1., 1., 0.), (2., 2., 0.)))
+ box = items.Box()
+ box.setData(position=positions)
+ cylinder = items.Cylinder()
+ cylinder.setData(position=positions)
+ hexagon = items.Hexagon()
+ hexagon.setData(position=positions)
+
+ for item in (box, cylinder, hexagon):
+ with self.subTest(item=item.__class__.__name__):
+ # Add item
+ self.widget.clearItems()
+ self.widget.addItem(item)
+ self.widget.resetZoom('front')
+ self.qapp.processEvents()
+
+ # Picking on data (at widget center)
+ picking = list(self.widget.pickItems(*self._widgetCenter()))
+
+ self.assertEqual(len(picking), 1)
+ self.assertIs(picking[0].getItem(), item)
+ nbPos = len(picking[0].getPositions())
+ data = picking[0].getData()
+ print(item.__class__.__name__, [positions[1]], data)
+ self.assertTrue(numpy.all(numpy.equal(positions[1], data)))
+ self.assertEqual(nbPos, len(data))
+ self.assertTrue(numpy.array_equal(
+ data,
+ item.getPosition()[picking[0].getIndices()]))
+
+ # Picking outside data
+ picking = list(self.widget.pickItems(1, 1))
+ self.assertEqual(len(picking), 0)
diff --git a/src/silx/gui/plot3d/test/testSceneWindow.py b/src/silx/gui/plot3d/test/testSceneWindow.py
new file mode 100644
index 0000000..6b61335
--- /dev/null
+++ b/src/silx/gui/plot3d/test/testSceneWindow.py
@@ -0,0 +1,233 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2019-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+# ###########################################################################*/
+"""Test SceneWindow"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "22/03/2019"
+
+
+import unittest
+
+import numpy
+
+from silx.utils.testutils import ParametricTestCase
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui import qt
+
+from silx.gui.plot3d.SceneWindow import SceneWindow
+from silx.gui.plot3d.items import HeightMapData, HeightMapRGBA
+
+class TestSceneWindow(TestCaseQt, ParametricTestCase):
+ """Tests SceneWidget picking feature"""
+
+ def setUp(self):
+ super(TestSceneWindow, self).setUp()
+ self.window = SceneWindow()
+ self.window.show()
+ self.qWaitForWindowExposed(self.window)
+
+ def tearDown(self):
+ self.qapp.processEvents()
+ self.window.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.window.close()
+ del self.window
+ super(TestSceneWindow, self).tearDown()
+
+ def testAdd(self):
+ """Test add basic scene primitive"""
+ sceneWidget = self.window.getSceneWidget()
+ items = []
+
+ # RGB image
+ image = sceneWidget.addImage(numpy.random.random(
+ 10*10*3).astype(numpy.float32).reshape(10, 10, 3))
+ image.setLabel('RGB image')
+ items.append(image)
+ self.assertEqual(sceneWidget.getItems(), tuple(items))
+
+ # Data image
+ image = sceneWidget.addImage(
+ numpy.arange(100, dtype=numpy.float32).reshape(10, 10))
+ image.setTranslation(10.)
+ items.append(image)
+ self.assertEqual(sceneWidget.getItems(), tuple(items))
+
+ # 2D scatter
+ scatter = sceneWidget.add2DScatter(
+ *numpy.random.random(3000).astype(numpy.float32).reshape(3, -1),
+ index=0)
+ scatter.setTranslation(0, 10)
+ scatter.setScale(10, 10, 10)
+ items.insert(0, scatter)
+ self.assertEqual(sceneWidget.getItems(), tuple(items))
+
+ # 3D scatter
+ scatter = sceneWidget.add3DScatter(
+ *numpy.random.random(4000).astype(numpy.float32).reshape(4, -1))
+ scatter.setTranslation(10, 10)
+ scatter.setScale(10, 10, 10)
+ items.append(scatter)
+ self.assertEqual(sceneWidget.getItems(), tuple(items))
+
+ # 3D array of float
+ volume = sceneWidget.addVolume(
+ numpy.arange(10**3, dtype=numpy.float32).reshape(10, 10, 10))
+ volume.setTranslation(0, 0, 10)
+ volume.setRotation(45, (0, 0, 1))
+ volume.addIsosurface(500, 'red')
+ volume.getCutPlanes()[0].getColormap().setName('viridis')
+ items.append(volume)
+ self.assertEqual(sceneWidget.getItems(), tuple(items))
+
+ # 3D array of complex
+ volume = sceneWidget.addVolume(
+ numpy.arange(10**3).reshape(10, 10, 10).astype(numpy.complex64))
+ volume.setTranslation(10, 0, 10)
+ volume.setRotation(45, (0, 0, 1))
+ volume.setComplexMode(volume.ComplexMode.REAL)
+ volume.addIsosurface(500, (1., 0., 0., .5))
+ items.append(volume)
+ self.assertEqual(sceneWidget.getItems(), tuple(items))
+
+ sceneWidget.resetZoom('front')
+ self.qapp.processEvents()
+
+ def testHeightMap(self):
+ """Test height map items"""
+ sceneWidget = self.window.getSceneWidget()
+
+ height = numpy.arange(10000).reshape(100, 100) /100.
+
+ for shape in ((100, 100), (4, 5), (150, 20), (110, 110)):
+ with self.subTest(shape=shape):
+ items = []
+
+ # Colormapped data height map
+ data = numpy.arange(numpy.prod(shape)).astype(numpy.float32).reshape(shape)
+
+ heightmap = HeightMapData()
+ heightmap.setData(height)
+ heightmap.setColormappedData(data)
+ heightmap.getColormap().setName('viridis')
+ items.append(heightmap)
+ sceneWidget.addItem(heightmap)
+
+ # RGBA height map
+ colors = numpy.zeros(shape + (3,), dtype=numpy.float32)
+ colors[:, :, 1] = numpy.random.random(shape)
+
+ heightmap = HeightMapRGBA()
+ heightmap.setData(height)
+ heightmap.setColorData(colors)
+ heightmap.setTranslation(100., 0., 0.)
+ items.append(heightmap)
+ sceneWidget.addItem(heightmap)
+
+ self.assertEqual(sceneWidget.getItems(), tuple(items))
+ sceneWidget.resetZoom('front')
+ self.qapp.processEvents()
+ sceneWidget.clearItems()
+
+ def testChangeContent(self):
+ """Test add/remove/clear items"""
+ sceneWidget = self.window.getSceneWidget()
+ items = []
+
+ # Add 2 images
+ image = numpy.arange(100, dtype=numpy.float32).reshape(10, 10)
+ items.append(sceneWidget.addImage(image))
+ items.append(sceneWidget.addImage(image))
+ self.qapp.processEvents()
+ self.assertEqual(sceneWidget.getItems(), tuple(items))
+
+ # Clear
+ sceneWidget.clearItems()
+ self.qapp.processEvents()
+ self.assertEqual(sceneWidget.getItems(), ())
+
+ # Add 2 images and remove first one
+ image = numpy.arange(100, dtype=numpy.float32).reshape(10, 10)
+ sceneWidget.addImage(image)
+ items = (sceneWidget.addImage(image),)
+ self.qapp.processEvents()
+
+ sceneWidget.removeItem(sceneWidget.getItems()[0])
+ self.qapp.processEvents()
+ self.assertEqual(sceneWidget.getItems(), items)
+
+ def testColors(self):
+ """Test setting scene colors"""
+ sceneWidget = self.window.getSceneWidget()
+
+ color = qt.QColor(128, 128, 128)
+ sceneWidget.setBackgroundColor(color)
+ self.assertEqual(sceneWidget.getBackgroundColor(), color)
+
+ color = qt.QColor(0, 0, 0)
+ sceneWidget.setForegroundColor(color)
+ self.assertEqual(sceneWidget.getForegroundColor(), color)
+
+ color = qt.QColor(255, 0, 0)
+ sceneWidget.setTextColor(color)
+ self.assertEqual(sceneWidget.getTextColor(), color)
+
+ color = qt.QColor(0, 255, 0)
+ sceneWidget.setHighlightColor(color)
+ self.assertEqual(sceneWidget.getHighlightColor(), color)
+
+ self.qapp.processEvents()
+
+ def testInteractiveMode(self):
+ """Test changing interactive mode"""
+ sceneWidget = self.window.getSceneWidget()
+ center = numpy.array((sceneWidget.width() //2, sceneWidget.height() // 2))
+
+ self.mouseMove(sceneWidget, pos=center)
+ self.mouseClick(sceneWidget, qt.Qt.LeftButton, pos=center)
+
+ volume = sceneWidget.addVolume(
+ numpy.arange(10**3).astype(numpy.float32).reshape(10, 10, 10))
+ sceneWidget.selection().setCurrentItem( volume.getCutPlanes()[0])
+ sceneWidget.resetZoom('side')
+
+ for mode in (None, 'rotate', 'pan', 'panSelectedPlane'):
+ with self.subTest(mode=mode):
+ sceneWidget.setInteractiveMode(mode)
+ self.qapp.processEvents()
+ self.assertEqual(sceneWidget.getInteractiveMode(), mode)
+
+ self.mouseMove(sceneWidget, pos=center)
+ self.mousePress(sceneWidget, qt.Qt.LeftButton, pos=center)
+ self.mouseMove(sceneWidget, pos=center-10)
+ self.mouseMove(sceneWidget, pos=center-20)
+ self.mouseRelease(sceneWidget, qt.Qt.LeftButton, pos=center-20)
+
+ self.keyPress(sceneWidget, qt.Qt.Key_Control)
+ self.mouseMove(sceneWidget, pos=center)
+ self.mousePress(sceneWidget, qt.Qt.LeftButton, pos=center)
+ self.mouseMove(sceneWidget, pos=center-10)
+ self.mouseMove(sceneWidget, pos=center-20)
+ self.mouseRelease(sceneWidget, qt.Qt.LeftButton, pos=center-20)
+ self.keyRelease(sceneWidget, qt.Qt.Key_Control)
diff --git a/src/silx/gui/plot3d/test/testStatsWidget.py b/src/silx/gui/plot3d/test/testStatsWidget.py
new file mode 100644
index 0000000..d452eb5
--- /dev/null
+++ b/src/silx/gui/plot3d/test/testStatsWidget.py
@@ -0,0 +1,201 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+# ###########################################################################*/
+"""Test silx.gui.plot.StatsWidget with SceneWidget and ScalarFieldView"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "25/01/2019"
+
+
+import unittest
+
+import numpy
+
+from silx.utils.testutils import ParametricTestCase
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui.plot.stats.stats import Stats
+from silx.gui import qt
+
+from silx.gui.plot.StatsWidget import BasicStatsWidget
+
+from silx.gui.plot3d.ScalarFieldView import ScalarFieldView
+from silx.gui.plot3d.SceneWidget import SceneWidget, items
+
+
+class TestSceneWidget(TestCaseQt, ParametricTestCase):
+ """Tests StatsWidget combined with SceneWidget"""
+
+ def setUp(self):
+ super(TestSceneWidget, self).setUp()
+ self.sceneWidget = SceneWidget()
+ self.sceneWidget.resize(300, 300)
+ self.sceneWidget.show()
+ self.statsWidget = BasicStatsWidget()
+ self.statsWidget.setPlot(self.sceneWidget)
+ # self.qWaitForWindowExposed(self.sceneWidget)
+
+ def tearDown(self):
+ Stats._getContext.cache_clear()
+ self.qapp.processEvents()
+ self.sceneWidget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.sceneWidget.close()
+ del self.sceneWidget
+ self.statsWidget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.statsWidget.close()
+ del self.statsWidget
+ super(TestSceneWidget, self).tearDown()
+
+ def test(self):
+ """Test StatsWidget with SceneWidget"""
+ # Prepare scene
+
+ # Data image
+ image = self.sceneWidget.addImage(numpy.arange(100).reshape(10, 10))
+ image.setLabel('Image')
+ # RGB image
+ imageRGB = self.sceneWidget.addImage(
+ numpy.arange(300, dtype=numpy.uint8).reshape(10, 10, 3))
+ imageRGB.setLabel('RGB Image')
+ # 2D scatter
+ data = numpy.arange(100)
+ scatter2D = self.sceneWidget.add2DScatter(x=data, y=data, value=data)
+ scatter2D.setLabel('2D Scatter')
+ # 3D scatter
+ scatter3D = self.sceneWidget.add3DScatter(x=data, y=data, z=data, value=data)
+ scatter3D.setLabel('3D Scatter')
+ # Add a group
+ group = items.GroupItem()
+ self.sceneWidget.addItem(group)
+ # 3D scalar field
+ data = numpy.arange(64**3).reshape(64, 64, 64)
+ scalarField = items.ScalarField3D()
+ scalarField.setData(data, copy=False)
+ scalarField.setLabel('3D Scalar field')
+ group.addItem(scalarField)
+
+ statsTable = self.statsWidget._getStatsTable()
+
+ # Test selection only
+ self.statsWidget.setDisplayOnlyActiveItem(True)
+ self.assertEqual(statsTable.rowCount(), 0)
+
+ self.sceneWidget.selection().setCurrentItem(group)
+ self.assertEqual(statsTable.rowCount(), 0)
+
+ for item in (image, scatter2D, scatter3D, scalarField):
+ with self.subTest('selection only', item=item.getLabel()):
+ self.sceneWidget.selection().setCurrentItem(item)
+ self.assertEqual(statsTable.rowCount(), 1)
+ self._checkItem(item)
+
+ # Test all data
+ self.statsWidget.setDisplayOnlyActiveItem(False)
+ self.assertEqual(statsTable.rowCount(), 4)
+
+ for item in (image, scatter2D, scatter3D, scalarField):
+ with self.subTest('all items', item=item.getLabel()):
+ self._checkItem(item)
+
+ def _checkItem(self, item):
+ """Check that item is in StatsTable and that stats are OK
+
+ :param silx.gui.plot3d.items.Item3D item:
+ """
+ if isinstance(item, (items.Scatter2D, items.Scatter3D)):
+ data = item.getValueData(copy=False)
+ else:
+ data = item.getData(copy=False)
+
+ statsTable = self.statsWidget._getStatsTable()
+ tableItems = statsTable._itemToTableItems(item)
+ self.assertTrue(len(tableItems) > 0)
+ self.assertEqual(tableItems['legend'].text(), item.getLabel())
+ self.assertEqual(float(tableItems['min'].text()), numpy.min(data))
+ self.assertEqual(float(tableItems['max'].text()), numpy.max(data))
+ # TODO
+
+
+class TestScalarFieldView(TestCaseQt):
+ """Tests StatsWidget combined with ScalarFieldView"""
+
+ def setUp(self):
+ super(TestScalarFieldView, self).setUp()
+ self.scalarFieldView = ScalarFieldView()
+ self.scalarFieldView.resize(300, 300)
+ self.scalarFieldView.show()
+ self.statsWidget = BasicStatsWidget()
+ self.statsWidget.setPlot(self.scalarFieldView)
+ # self.qWaitForWindowExposed(self.sceneWidget)
+
+ def tearDown(self):
+ Stats._getContext.cache_clear()
+ self.qapp.processEvents()
+ self.scalarFieldView.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.scalarFieldView.close()
+ del self.scalarFieldView
+ self.statsWidget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.statsWidget.close()
+ del self.statsWidget
+ super(TestScalarFieldView, self).tearDown()
+
+ def _getTextFor(self, row, name):
+ """Returns text in table at given row for column name
+
+ :param int row: Row number in the table
+ :param str name: Column id
+ :rtype: Union[str,None]
+ """
+ statsTable = self.statsWidget._getStatsTable()
+
+ for column in range(statsTable.columnCount()):
+ headerItem = statsTable.horizontalHeaderItem(column)
+ if headerItem.data(qt.Qt.UserRole) == name:
+ tableItem = statsTable.item(row, column)
+ return tableItem.text()
+
+ return None
+
+ def test(self):
+ """Test StatsWidget with ScalarFieldView"""
+ data = numpy.arange(64**3, dtype=numpy.float64).reshape(64, 64, 64)
+ self.scalarFieldView.setData(data)
+
+ statsTable = self.statsWidget._getStatsTable()
+
+ # Test selection only
+ self.statsWidget.setDisplayOnlyActiveItem(True)
+ self.assertEqual(statsTable.rowCount(), 1)
+
+ # Test all data
+ self.statsWidget.setDisplayOnlyActiveItem(False)
+ self.assertEqual(statsTable.rowCount(), 1)
+
+ for column in range(statsTable.columnCount()):
+ self.assertEqual(float(self._getTextFor(0, 'min')), numpy.min(data))
+ self.assertEqual(float(self._getTextFor(0, 'max')), numpy.max(data))
+ sum_ = numpy.sum(data)
+ comz = numpy.sum(numpy.arange(data.shape[0]) * numpy.sum(data, axis=(1, 2))) / sum_
+ comy = numpy.sum(numpy.arange(data.shape[1]) * numpy.sum(data, axis=(0, 2))) / sum_
+ comx = numpy.sum(numpy.arange(data.shape[2]) * numpy.sum(data, axis=(0, 1))) / sum_
+ self.assertEqual(self._getTextFor(0, 'COM'), str((comx, comy, comz)))
diff --git a/src/silx/gui/plot3d/tools/GroupPropertiesWidget.py b/src/silx/gui/plot3d/tools/GroupPropertiesWidget.py
new file mode 100644
index 0000000..146c2cd
--- /dev/null
+++ b/src/silx/gui/plot3d/tools/GroupPropertiesWidget.py
@@ -0,0 +1,202 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+""":class:`GroupPropertiesWidget` allows to reset properties in a GroupItem."""
+
+from __future__ import absolute_import
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "24/04/2018"
+
+from ....gui import qt
+from ....gui.colors import Colormap
+from ....gui.dialog.ColormapDialog import ColormapDialog
+
+from ..items import SymbolMixIn, ColormapMixIn
+
+
+class GroupPropertiesWidget(qt.QWidget):
+ """Set properties of all items in a :class:`GroupItem`
+
+ :param QWidget parent:
+ """
+
+ MAX_MARKER_SIZE = 20
+ """Maximum value for marker size"""
+
+ MAX_LINE_WIDTH = 10
+ """Maximum value for line width"""
+
+ def __init__(self, parent=None):
+ super(GroupPropertiesWidget, self).__init__(parent)
+ self._group = None
+ self.setEnabled(False)
+
+ # Set widgets
+ layout = qt.QFormLayout(self)
+ self.setLayout(layout)
+
+ # Colormap
+ colormapButton = qt.QPushButton('Set...')
+ colormapButton.setToolTip("Set colormap for all items")
+ colormapButton.clicked.connect(self._colormapButtonClicked)
+ layout.addRow('Colormap', colormapButton)
+
+ self._markerComboBox = qt.QComboBox(self)
+ self._markerComboBox.addItems(SymbolMixIn.getSupportedSymbolNames())
+
+ # Marker
+ markerButton = qt.QPushButton('Set')
+ markerButton.setToolTip("Set marker for all items")
+ markerButton.clicked.connect(self._markerButtonClicked)
+
+ markerLayout = qt.QHBoxLayout()
+ markerLayout.setContentsMargins(0, 0, 0, 0)
+ markerLayout.addWidget(self._markerComboBox, 1)
+ markerLayout.addWidget(markerButton, 0)
+
+ layout.addRow('Marker', markerLayout)
+
+ # Marker size
+ self._markerSizeSlider = qt.QSlider()
+ self._markerSizeSlider.setOrientation(qt.Qt.Horizontal)
+ self._markerSizeSlider.setSingleStep(1)
+ self._markerSizeSlider.setRange(1, self.MAX_MARKER_SIZE)
+ self._markerSizeSlider.setValue(1)
+
+ markerSizeButton = qt.QPushButton('Set')
+ markerSizeButton.setToolTip("Set marker size for all items")
+ markerSizeButton.clicked.connect(self._markerSizeButtonClicked)
+
+ markerSizeLayout = qt.QHBoxLayout()
+ markerSizeLayout.setContentsMargins(0, 0, 0, 0)
+ markerSizeLayout.addWidget(qt.QLabel('1'))
+ markerSizeLayout.addWidget(self._markerSizeSlider, 1)
+ markerSizeLayout.addWidget(qt.QLabel(str(self.MAX_MARKER_SIZE)))
+ markerSizeLayout.addWidget(markerSizeButton, 0)
+
+ layout.addRow('Marker Size', markerSizeLayout)
+
+ # Line width
+ self._lineWidthSlider = qt.QSlider()
+ self._lineWidthSlider.setOrientation(qt.Qt.Horizontal)
+ self._lineWidthSlider.setSingleStep(1)
+ self._lineWidthSlider.setRange(1, self.MAX_LINE_WIDTH)
+ self._lineWidthSlider.setValue(1)
+
+ lineWidthButton = qt.QPushButton('Set')
+ lineWidthButton.setToolTip("Set line width for all items")
+ lineWidthButton.clicked.connect(self._lineWidthButtonClicked)
+
+ lineWidthLayout = qt.QHBoxLayout()
+ lineWidthLayout.setContentsMargins(0, 0, 0, 0)
+ lineWidthLayout.addWidget(qt.QLabel('1'))
+ lineWidthLayout.addWidget(self._lineWidthSlider, 1)
+ lineWidthLayout.addWidget(qt.QLabel(str(self.MAX_LINE_WIDTH)))
+ lineWidthLayout.addWidget(lineWidthButton, 0)
+
+ layout.addRow('Line Width', lineWidthLayout)
+
+ self._colormapDialog = None # To store dialog
+ self._colormap = Colormap()
+
+ def getGroup(self):
+ """Returns the :class:`GroupItem` this widget is attached to.
+
+ :rtype: Union[GroupItem, None]
+ """
+ return self._group
+
+ def setGroup(self, group):
+ """Set the :class:`GroupItem` this widget is attached to.
+
+ :param GroupItem group: GroupItem to control (or None)
+ """
+ self._group = group
+ if group is not None:
+ self.setEnabled(True)
+
+ def _colormapButtonClicked(self, checked=False):
+ """Handle colormap button clicked"""
+ group = self.getGroup()
+ if group is None:
+ return
+
+ if self._colormapDialog is None:
+ self._colormapDialog = ColormapDialog(self)
+ self._colormapDialog.setColormap(self._colormap)
+
+ previousColormap = self._colormapDialog.getColormap()
+ if self._colormapDialog.exec():
+ colormap = self._colormapDialog.getColormap()
+
+ for item in group.visit():
+ if isinstance(item, ColormapMixIn):
+ itemCmap = item.getColormap()
+ cmapName = colormap.getName()
+ if cmapName is not None:
+ itemCmap.setName(colormap.getName())
+ else:
+ itemCmap.setColormapLUT(colormap.getColormapLUT())
+ itemCmap.setNormalization(colormap.getNormalization())
+ itemCmap.setGammaNormalizationParameter(
+ colormap.getGammaNormalizationParameter())
+ itemCmap.setVRange(colormap.getVMin(), colormap.getVMax())
+ else:
+ # Reset colormap
+ self._colormapDialog.setColormap(previousColormap)
+
+ def _markerButtonClicked(self, checked=False):
+ """Handle marker set button clicked"""
+ group = self.getGroup()
+ if group is None:
+ return
+
+ marker = self._markerComboBox.currentText()
+ for item in group.visit():
+ if isinstance(item, SymbolMixIn):
+ item.setSymbol(marker)
+
+ def _markerSizeButtonClicked(self, checked=False):
+ """Handle marker size set button clicked"""
+ group = self.getGroup()
+ if group is None:
+ return
+
+ markerSize = self._markerSizeSlider.value()
+ for item in group.visit():
+ if isinstance(item, SymbolMixIn):
+ item.setSymbolSize(markerSize)
+
+ def _lineWidthButtonClicked(self, checked=False):
+ """Handle line width set button clicked"""
+ group = self.getGroup()
+ if group is None:
+ return
+
+ lineWidth = self._lineWidthSlider.value()
+ for item in group.visit():
+ if hasattr(item, 'setLineWidth'):
+ item.setLineWidth(lineWidth)
diff --git a/src/silx/gui/plot3d/tools/PositionInfoWidget.py b/src/silx/gui/plot3d/tools/PositionInfoWidget.py
new file mode 100644
index 0000000..99d6356
--- /dev/null
+++ b/src/silx/gui/plot3d/tools/PositionInfoWidget.py
@@ -0,0 +1,225 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a widget that displays data values of a SceneWidget.
+"""
+
+from __future__ import absolute_import
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "01/10/2018"
+
+
+import logging
+import weakref
+
+from ... import qt
+from .. import actions
+from .. import items
+from ..items import volume
+from ..SceneWidget import SceneWidget
+
+
+_logger = logging.getLogger(__name__)
+
+
+class PositionInfoWidget(qt.QWidget):
+ """Widget displaying information about picked position
+
+ :param QWidget parent: See :class:`QWidget`
+ """
+
+ def __init__(self, parent=None):
+ super(PositionInfoWidget, self).__init__(parent)
+ self._sceneWidgetRef = None
+
+ self.setToolTip("Double-click on a data point to show its value")
+ layout = qt.QBoxLayout(qt.QBoxLayout.LeftToRight, self)
+
+ self._xLabel = self._addInfoField('X')
+ self._yLabel = self._addInfoField('Y')
+ self._zLabel = self._addInfoField('Z')
+ self._dataLabel = self._addInfoField('Data')
+ self._itemLabel = self._addInfoField('Item')
+
+ layout.addStretch(1)
+
+ self._action = actions.mode.PickingModeAction(parent=self)
+ self._action.setText('Selection')
+ self._action.setToolTip(
+ 'Toggle selection information update with left button click')
+ self._action.sigSceneClicked.connect(self.pick)
+ self._action.changed.connect(self.__actionChanged)
+ self._action.setChecked(False) # Disabled by default
+ self.__actionChanged() # Sync action/widget
+
+ def __actionChanged(self):
+ """Handle toggle action change signal"""
+ if self.toggleAction().isChecked() != self.isEnabled():
+ self.setEnabled(self.toggleAction().isChecked())
+
+ def toggleAction(self):
+ """The action to toggle the picking mode.
+
+ :rtype: QAction
+ """
+ return self._action
+
+ def _addInfoField(self, label):
+ """Add a description: info widget to this widget
+
+ :param str label: Description label
+ :return: The QLabel used to display the info
+ :rtype: QLabel
+ """
+ subLayout = qt.QHBoxLayout()
+ subLayout.setContentsMargins(0, 0, 0, 0)
+
+ subLayout.addWidget(qt.QLabel(label + ':'))
+
+ widget = qt.QLabel('-')
+ widget.setAlignment(qt.Qt.AlignLeft | qt.Qt.AlignVCenter)
+ widget.setTextInteractionFlags(qt.Qt.TextSelectableByMouse)
+
+ metrics = widget.fontMetrics()
+ if qt.BINDING in ('PySide2', 'PyQt5'):
+ width = metrics.width("#######")
+ else: # Qt6
+ width = metrics.horizontalAdvance("#######")
+ widget.setMinimumWidth(width)
+ subLayout.addWidget(widget)
+
+ subLayout.addStretch(1)
+
+ layout = self.layout()
+ layout.addLayout(subLayout)
+ return widget
+
+ def getSceneWidget(self):
+ """Returns the associated :class:`SceneWidget` or None.
+
+ :rtype: Union[None,~silx.gui.plot3d.SceneWidget.SceneWidget]
+ """
+ if self._sceneWidgetRef is None:
+ return None
+ else:
+ return self._sceneWidgetRef()
+
+ def setSceneWidget(self, widget):
+ """Set the associated :class:`SceneWidget`
+
+ :param ~silx.gui.plot3d.SceneWidget.SceneWidget widget:
+ 3D scene for which to display information
+ """
+ if widget is not None and not isinstance(widget, SceneWidget):
+ raise ValueError("widget must be a SceneWidget or None")
+
+ self._sceneWidgetRef = None if widget is None else weakref.ref(widget)
+
+ self.toggleAction().setPlot3DWidget(widget)
+
+ def clear(self):
+ """Clean-up displayed values"""
+ for widget in (self._xLabel, self._yLabel, self._zLabel,
+ self._dataLabel, self._itemLabel):
+ widget.setText('-')
+
+ _SUPPORTED_ITEMS = (items.Scatter3D,
+ items.Scatter2D,
+ items.ImageData,
+ items.ImageRgba,
+ items.HeightMapData,
+ items.HeightMapRGBA,
+ items.Mesh,
+ items.Box,
+ items.Cylinder,
+ items.Hexagon,
+ volume.CutPlane,
+ volume.Isosurface)
+ """Type of items that are picked"""
+
+ def _isSupportedItem(self, item):
+ """Returns True if item is of supported type
+
+ :param Item3D item: The Item3D to check
+ :rtype: bool
+ """
+ return isinstance(item, self._SUPPORTED_ITEMS)
+
+ def pick(self, x, y):
+ """Pick items in the associated SceneWidget and display result
+
+ Only the closest point is displayed.
+
+ :param int x: X coordinate in pixel in the SceneWidget
+ :param int y: Y coordinate in pixel in the SceneWidget
+ """
+ self.clear()
+
+ sceneWidget = self.getSceneWidget()
+ if sceneWidget is None: # No associated widget
+ _logger.info('Picking without associated SceneWidget')
+ return
+
+ # Find closest (and latest in the tree) supported item
+ closestNdcZ = float('inf')
+ picking = None
+ for result in sceneWidget.pickItems(x, y,
+ condition=self._isSupportedItem):
+ ndcZ = result.getPositions('ndc', copy=False)[0, 2]
+ if ndcZ <= closestNdcZ:
+ closestNdcZ = ndcZ
+ picking = result
+
+ if picking is None:
+ return # No picked item
+
+ item = picking.getItem()
+ self._itemLabel.setText(item.getLabel())
+ positions = picking.getPositions('scene', copy=False)
+ x, y, z = positions[0]
+ self._xLabel.setText("%g" % x)
+ self._yLabel.setText("%g" % y)
+ self._zLabel.setText("%g" % z)
+
+ data = picking.getData(copy=False)
+ if data is not None:
+ data = data[0]
+ if hasattr(data, '__len__'):
+ text = ' '.join(["%.3g"] * len(data)) % tuple(data)
+ else:
+ text = "%g" % data
+ self._dataLabel.setText(text)
+
+ def updateInfo(self):
+ """Update information according to cursor position"""
+ widget = self.getSceneWidget()
+ if widget is None:
+ _logger.info('Update without associated SceneWidget')
+ self.clear()
+ return
+
+ position = widget.mapFromGlobal(qt.QCursor.pos())
+ self.pick(position.x(), position.y())
diff --git a/src/silx/gui/plot3d/tools/ViewpointTools.py b/src/silx/gui/plot3d/tools/ViewpointTools.py
new file mode 100644
index 0000000..0607382
--- /dev/null
+++ b/src/silx/gui/plot3d/tools/ViewpointTools.py
@@ -0,0 +1,84 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a toolbar to control Plot3DWidget viewpoint."""
+
+from __future__ import absolute_import
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "08/09/2017"
+
+
+import weakref
+
+from silx.gui import qt
+from silx.gui.icons import getQIcon
+from .. import actions
+
+
+class ViewpointToolButton(qt.QToolButton):
+ """A toolbutton with a drop-down list of ways to reset the viewpoint.
+
+ :param parent: See :class:`QToolButton`
+ """
+
+ def __init__(self, parent=None):
+ super(ViewpointToolButton, self).__init__(parent)
+
+ self._plot3DRef = None
+
+ menu = qt.QMenu(self)
+ menu.addAction(actions.viewpoint.FrontViewpointAction(parent=self))
+ menu.addAction(actions.viewpoint.BackViewpointAction(parent=self))
+ menu.addAction(actions.viewpoint.TopViewpointAction(parent=self))
+ menu.addAction(actions.viewpoint.BottomViewpointAction(parent=self))
+ menu.addAction(actions.viewpoint.RightViewpointAction(parent=self))
+ menu.addAction(actions.viewpoint.LeftViewpointAction(parent=self))
+ menu.addAction(actions.viewpoint.SideViewpointAction(parent=self))
+
+ self.setMenu(menu)
+ self.setPopupMode(qt.QToolButton.InstantPopup)
+ self.setIcon(getQIcon('cube'))
+ self.setToolTip('Reset the viewpoint to a defined position')
+
+ def setPlot3DWidget(self, widget):
+ """Set the Plot3DWidget this toolbar is associated with
+
+ :param ~silx.gui.plot3d.Plot3DWidget.Plot3DWidget widget:
+ The widget to control
+ """
+ self._plot3DRef = None if widget is None else weakref.ref(widget)
+
+ for action in self.menu().actions():
+ action.setPlot3DWidget(widget)
+
+ def getPlot3DWidget(self):
+ """Return the Plot3DWidget associated to this toolbar.
+
+ If no widget is associated, it returns None.
+
+ :rtype: ~silx.gui.plot3d.Plot3DWidget.Plot3DWidget or None
+ """
+ return None if self._plot3DRef is None else self._plot3DRef()
diff --git a/src/silx/gui/plot3d/tools/__init__.py b/src/silx/gui/plot3d/tools/__init__.py
new file mode 100644
index 0000000..c8b8d21
--- /dev/null
+++ b/src/silx/gui/plot3d/tools/__init__.py
@@ -0,0 +1,34 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides tool widgets that can be attached to a plot3DWidget."""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "08/09/2017"
+
+from .toolbars import InteractiveModeToolBar # noqa
+from .toolbars import OutputToolBar # noqa
+from .toolbars import ViewpointToolBar # noqa
+from .ViewpointTools import ViewpointToolButton # noqa
diff --git a/src/silx/gui/plot3d/tools/test/__init__.py b/src/silx/gui/plot3d/tools/test/__init__.py
new file mode 100644
index 0000000..86741ed
--- /dev/null
+++ b/src/silx/gui/plot3d/tools/test/__init__.py
@@ -0,0 +1,25 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""plot3d tools test suite."""
diff --git a/src/silx/gui/plot3d/tools/test/testPositionInfoWidget.py b/src/silx/gui/plot3d/tools/test/testPositionInfoWidget.py
new file mode 100644
index 0000000..17fb3db
--- /dev/null
+++ b/src/silx/gui/plot3d/tools/test/testPositionInfoWidget.py
@@ -0,0 +1,89 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+# ###########################################################################*/
+"""Test PositionInfoWidget"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "03/10/2018"
+
+
+import unittest
+
+import numpy
+
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui import qt
+
+from silx.gui.plot3d.SceneWidget import SceneWidget
+from silx.gui.plot3d.tools.PositionInfoWidget import PositionInfoWidget
+
+
+class TestPositionInfoWidget(TestCaseQt):
+ """Tests PositionInfoWidget"""
+
+ def setUp(self):
+ super(TestPositionInfoWidget, self).setUp()
+ self.sceneWidget = SceneWidget()
+ self.sceneWidget.resize(300, 300)
+ self.sceneWidget.show()
+
+ self.positionInfoWidget = PositionInfoWidget()
+ self.positionInfoWidget.setSceneWidget(self.sceneWidget)
+ self.positionInfoWidget.show()
+ self.qWaitForWindowExposed(self.positionInfoWidget)
+
+ # self.qWaitForWindowExposed(self.widget)
+
+ def tearDown(self):
+ self.qapp.processEvents()
+
+ self.sceneWidget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.sceneWidget.close()
+ del self.sceneWidget
+
+ self.positionInfoWidget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.positionInfoWidget.close()
+ del self.positionInfoWidget
+ super(TestPositionInfoWidget, self).tearDown()
+
+ def test(self):
+ """Test PositionInfoWidget"""
+ self.assertIs(self.positionInfoWidget.getSceneWidget(),
+ self.sceneWidget)
+
+ data = numpy.arange(100)
+ self.sceneWidget.add2DScatter(x=data, y=data, value=data)
+ self.sceneWidget.resetZoom('front')
+
+ # Double click at the center
+ self.mouseDClick(self.sceneWidget, button=qt.Qt.LeftButton)
+
+ # Clear displayed value
+ self.positionInfoWidget.clear()
+
+ # Update info from API
+ self.positionInfoWidget.pick(x=10, y=10)
+
+ # Remove SceneWidget
+ self.positionInfoWidget.setSceneWidget(None)
diff --git a/src/silx/gui/plot3d/tools/toolbars.py b/src/silx/gui/plot3d/tools/toolbars.py
new file mode 100644
index 0000000..d4f32db
--- /dev/null
+++ b/src/silx/gui/plot3d/tools/toolbars.py
@@ -0,0 +1,209 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides toolbars with tools for a Plot3DWidget.
+
+It provides the following toolbars:
+
+- :class:`InteractiveModeToolBar` with:
+ - Set interactive mode to rotation
+ - Set interactive mode to pan
+
+- :class:`OutputToolBar` with:
+ - Copy
+ - Save
+ - Video
+ - Print
+"""
+
+from __future__ import absolute_import
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "06/09/2017"
+
+import logging
+import weakref
+
+from silx.gui import qt
+
+from .ViewpointTools import ViewpointToolButton
+from .. import actions
+
+_logger = logging.getLogger(__name__)
+
+
+class Plot3DWidgetToolBar(qt.QToolBar):
+ """Base class for toolbar associated to a Plot3DWidget
+
+ :param parent: See :class:`QWidget`
+ :param str title: Title of the toolbar.
+ """
+
+ def __init__(self, parent=None, title=''):
+ super(Plot3DWidgetToolBar, self).__init__(title, parent)
+
+ self._plot3DRef = None
+
+ def _plot3DWidgetChanged(self, widget):
+ """Handle change of Plot3DWidget and sync actions
+
+ :param Plot3DWidget widget:
+ """
+ for action in self.actions():
+ if isinstance(action, actions.Plot3DAction):
+ action.setPlot3DWidget(widget)
+
+ def setPlot3DWidget(self, widget):
+ """Set the Plot3DWidget this toolbar is associated with
+
+ :param Plot3DWidget widget: The widget to control
+ """
+ self._plot3DRef = None if widget is None else weakref.ref(widget)
+ self._plot3DWidgetChanged(widget)
+
+ def getPlot3DWidget(self):
+ """Return the Plot3DWidget associated to this toolbar.
+
+ If no widget is associated, it returns None.
+
+ :rtype: qt.QWidget
+ """
+ return None if self._plot3DRef is None else self._plot3DRef()
+
+
+class InteractiveModeToolBar(Plot3DWidgetToolBar):
+ """Toolbar providing icons to change the interaction mode
+
+ :param parent: See :class:`QWidget`
+ :param str title: Title of the toolbar.
+ """
+
+ def __init__(self, parent=None, title='Plot3D Interaction'):
+ super(InteractiveModeToolBar, self).__init__(parent, title)
+
+ self._rotateAction = actions.mode.RotateArcballAction(parent=self)
+ self.addAction(self._rotateAction)
+
+ self._panAction = actions.mode.PanAction(parent=self)
+ self.addAction(self._panAction)
+
+ def getRotateAction(self):
+ """Returns the QAction setting rotate interaction of the Plot3DWidget
+
+ :rtype: qt.QAction
+ """
+ return self._rotateAction
+
+ def getPanAction(self):
+ """Returns the QAction setting pan interaction of the Plot3DWidget
+
+ :rtype: qt.QAction
+ """
+ return self._panAction
+
+
+class OutputToolBar(Plot3DWidgetToolBar):
+ """Toolbar providing icons to copy, save and print the OpenGL scene
+
+ :param parent: See :class:`QWidget`
+ :param str title: Title of the toolbar.
+ """
+
+ def __init__(self, parent=None, title='Plot3D Output'):
+ super(OutputToolBar, self).__init__(parent, title)
+
+ self._copyAction = actions.io.CopyAction(parent=self)
+ self.addAction(self._copyAction)
+
+ self._saveAction = actions.io.SaveAction(parent=self)
+ self.addAction(self._saveAction)
+
+ self._videoAction = actions.io.VideoAction(parent=self)
+ self.addAction(self._videoAction)
+
+ self._printAction = actions.io.PrintAction(parent=self)
+ self.addAction(self._printAction)
+
+ def getCopyAction(self):
+ """Returns the QAction performing copy to clipboard of the Plot3DWidget
+
+ :rtype: qt.QAction
+ """
+ return self._copyAction
+
+ def getSaveAction(self):
+ """Returns the QAction performing save to file of the Plot3DWidget
+
+ :rtype: qt.QAction
+ """
+ return self._saveAction
+
+ def getVideoRecordAction(self):
+ """Returns the QAction performing record video of the Plot3DWidget
+
+ :rtype: qt.QAction
+ """
+ return self._videoAction
+
+ def getPrintAction(self):
+ """Returns the QAction performing printing of the Plot3DWidget
+
+ :rtype: qt.QAction
+ """
+ return self._printAction
+
+
+class ViewpointToolBar(Plot3DWidgetToolBar):
+ """A toolbar providing icons to reset the viewpoint.
+
+ :param parent: See :class:`QToolBar`
+ :param str title: Title of the toolbar
+ """
+
+ def __init__(self, parent=None, title='Viewpoint control'):
+ super(ViewpointToolBar, self).__init__(parent, title)
+
+ self._viewpointToolButton = ViewpointToolButton(parent=self)
+ self.addWidget(self._viewpointToolButton)
+ self._rotateViewpointAction = actions.viewpoint.RotateViewpoint(parent=self)
+ self.addAction(self._rotateViewpointAction)
+
+ def _plot3DWidgetChanged(self, widget):
+ self.getViewpointToolButton().setPlot3DWidget(widget)
+ super(ViewpointToolBar, self)._plot3DWidgetChanged(widget)
+
+ def getViewpointToolButton(self):
+ """Returns the ViewpointToolButton to set viewpoint of the Plot3DWidget
+
+ :rtype: ViewpointToolButton
+ """
+ return self._viewpointToolButton
+
+ def getRotateViewpointAction(self):
+ """Returns the QAction to start/stop rotation of the Plot3DWidget
+
+ :rtype: qt.QAction
+ """
+ return self._rotateViewpointAction
diff --git a/src/silx/gui/plot3d/utils/__init__.py b/src/silx/gui/plot3d/utils/__init__.py
new file mode 100644
index 0000000..99d3e08
--- /dev/null
+++ b/src/silx/gui/plot3d/utils/__init__.py
@@ -0,0 +1,28 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "18/10/2016"
diff --git a/src/silx/gui/plot3d/utils/mng.py b/src/silx/gui/plot3d/utils/mng.py
new file mode 100644
index 0000000..8049a2f
--- /dev/null
+++ b/src/silx/gui/plot3d/utils/mng.py
@@ -0,0 +1,121 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides basic writing Mulitple-image Network Graphics files.
+
+It only supports RGB888 images of the same shape stored as
+MNG-VLC (very low complexity) format.
+"""
+
+from __future__ import absolute_import
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "15/12/2016"
+
+
+import logging
+import struct
+import zlib
+
+import numpy
+
+_logger = logging.getLogger(__name__)
+
+
+def _png_chunk(name, data):
+ """Return a PNG chunk
+
+ :param str name: Chunk type
+ :param byte data: Chunk payload
+ """
+ length = struct.pack('>I', len(data))
+ name = [char.encode('ascii') for char in name]
+ chunk = struct.pack('cccc', *name) + data
+ crc = struct.pack('>I', zlib.crc32(chunk) & 0xffffffff)
+ return length + chunk + crc
+
+
+def convert(images, nb_images=0, fps=25):
+ """Convert RGB images to MNG-VLC format.
+
+ See http://www.libpng.org/pub/mng/spec/
+ See http://www.libpng.org/pub/png/book/
+ See http://www.libpng.org/pub/png/spec/1.2/
+
+ :param images: iterator of RGB888 images
+ :type images: iterator of numpy.ndarray of dimension 3
+ :param int nb_images: The number of images indicated in the MNG header
+ :param int fps: The frame rate indicated in the MNG header
+ :return: An iterator of MNG chunks as bytes
+ """
+ first_image = True
+
+ for image in images:
+ if first_image:
+ first_image = False
+
+ height, width = image.shape[:2]
+
+ # MNG signature
+ yield b'\x8aMNG\r\n\x1a\n'
+
+ # MHDR chunk: File header
+ yield _png_chunk('MHDR', struct.pack(
+ ">IIIIIII",
+ width,
+ height,
+ fps, # ticks
+ nb_images + 1, # layer count
+ nb_images, # frame count
+ nb_images, # play time
+ 1)) # profile: MNG-VLC no alpha: only least significant bit 1
+
+ assert image.shape == (height, width, 3)
+ assert image.dtype == numpy.dtype('uint8')
+
+ # IHDR chunk: Image header
+ depth = 8 # 8 bit per channel
+ color_type = 2 # 'truecolor' = RGB
+ interlace = 0 # No
+ yield _png_chunk('IHDR', struct.pack(">IIBBBBB",
+ width,
+ height,
+ depth,
+ color_type,
+ 0, 0, interlace))
+
+ # Add filter 'None' before each scanline
+ prepared_data = b'\x00' + b'\x00'.join(
+ line.tobytes() for line in image) # TODO optimize that
+ compressed_data = zlib.compress(prepared_data, 8)
+
+ # IDAT chunk: Payload
+ yield _png_chunk('IDAT', compressed_data)
+
+ # IEND chunk: Image footer
+ yield _png_chunk('IEND', b'')
+
+ # MEND chunk: footer
+ yield _png_chunk('MEND', b'')
diff --git a/src/silx/gui/printer.py b/src/silx/gui/printer.py
new file mode 100644
index 0000000..761fa0f
--- /dev/null
+++ b/src/silx/gui/printer.py
@@ -0,0 +1,62 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a singleton QPrinter used by default by silx widgets.
+"""
+
+from __future__ import absolute_import
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "01/03/2018"
+
+
+from . import qt
+
+
+_printer = None
+"""Shared QPrinter instance"""
+
+
+def getDefaultPrinter():
+ """Returns the default printer.
+
+ This allows reusing the same QPrinter across widgets.
+
+ :return: QPrinter
+ """
+ global _printer
+ if _printer is None:
+ _printer = qt.QPrinter()
+ return _printer
+
+
+def setDefaultPrinter(printer):
+ """Set the printer used by default by silx widgets.
+
+ :param QPrinter printer:
+ """
+ assert isinstance(printer, qt.QPrinter)
+ global _printer
+ _printer = printer
diff --git a/src/silx/gui/qt/__init__.py b/src/silx/gui/qt/__init__.py
new file mode 100644
index 0000000..915c89b
--- /dev/null
+++ b/src/silx/gui/qt/__init__.py
@@ -0,0 +1,54 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Common wrapper over Python Qt bindings:
+
+- `PyQt5 <http://pyqt.sourceforge.net/Docs/PyQt5/>`_
+- `PySide2 <https://pypi.org/project/PySide2/>`_
+- `PySide6 <https://pypi.org/project/PySide6/>`_
+
+If a Qt binding is already loaded, it will use it, otherwise the different
+Qt bindings are tried in this order: PyQt5, PySide2, PySide6.
+
+The name of the loaded Qt binding is stored in the BINDING variable.
+
+This module provides a flat namespace over Qt bindings by importing
+all symbols from **QtCore**, **QtGui**, **QtWidgets** and **QtPrintSupport**
+packages and if available from **QtOpenGL** and **QtSvg** packages.
+
+Example of using :mod:`silx.gui.qt` module:
+
+>>> from silx.gui import qt
+>>> app = qt.QApplication([])
+>>> widget = qt.QWidget()
+
+For an alternative solution providing a structured namespace,
+see `qtpy <https://pypi.org/project/QtPy/>`_.
+"""
+
+from ._qt import * # noqa
+if BINDING in ('PySide2', 'PySide6'):
+ # Import loadUi wrapper
+ from ._pyside_dynamic import loadUi # noqa
+from ._utils import * # noqa
diff --git a/src/silx/gui/qt/_pyside_dynamic.py b/src/silx/gui/qt/_pyside_dynamic.py
new file mode 100644
index 0000000..a841eae
--- /dev/null
+++ b/src/silx/gui/qt/_pyside_dynamic.py
@@ -0,0 +1,235 @@
+# -*- coding: utf-8 -*-
+
+# Taken from: https://gist.github.com/cpbotha/1b42a20c8f3eb9bb7cb8
+# Plus: https://github.com/spyder-ide/qtpy/commit/001a862c401d757feb63025f88dbb4601d353c84
+
+# Copyright (c) 2011 Sebastian Wiesner <lunaryorn@gmail.com>
+# Modifications by Charl Botha <cpbotha@vxlabs.com>
+# * customWidgets support (registerCustomWidget() causes segfault in
+# pyside 1.1.2 on Ubuntu 12.04 x86_64)
+# * workingDirectory support in loadUi
+
+# found this here:
+# https://github.com/lunaryorn/snippets/blob/master/qt4/designer/pyside_dynamic.py
+
+# Permission is hereby granted, free of charge, to any person obtaining a
+# copy of this software and associated documentation files (the "Software"),
+# to deal in the Software without restriction, including without limitation
+# the rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Software, and to permit persons to whom the
+# Software is furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.
+
+"""
+ How to load a user interface dynamically with PySide.
+
+ .. moduleauthor:: Sebastian Wiesner <lunaryorn@gmail.com>
+"""
+
+import logging
+
+from ._qt import BINDING
+if BINDING == 'PySide2':
+ from PySide2.QtCore import QMetaObject, Property, Qt
+ from PySide2.QtWidgets import QFrame
+ from PySide2.QtUiTools import QUiLoader
+elif BINDING == 'PySide6':
+ from PySide6.QtCore import QMetaObject, Property, Qt
+ from PySide6.QtWidgets import QFrame
+ from PySide6.QtUiTools import QUiLoader
+else:
+ raise RuntimeError("Unsupported Qt binding: %s", BINDING)
+
+
+_logger = logging.getLogger(__name__)
+
+
+class UiLoader(QUiLoader):
+ """
+ Subclass :class:`~PySide.QtUiTools.QUiLoader` to create the user interface
+ in a base instance.
+
+ Unlike :class:`~PySide.QtUiTools.QUiLoader` itself this class does not
+ create a new instance of the top-level widget, but creates the user
+ interface in an existing instance of the top-level class.
+
+ This mimics the behaviour of :func:`PyQt*.uic.loadUi`.
+ """
+
+ def __init__(self, baseinstance, customWidgets=None):
+ """
+ Create a loader for the given ``baseinstance``.
+
+ The user interface is created in ``baseinstance``, which must be an
+ instance of the top-level class in the user interface to load, or a
+ subclass thereof.
+
+ ``customWidgets`` is a dictionary mapping from class name to class
+ object for widgets that you've promoted in the Qt Designer
+ interface. Usually, this should be done by calling
+ registerCustomWidget on the QUiLoader, but
+ with PySide 1.1.2 on Ubuntu 12.04 x86_64 this causes a segfault.
+
+ ``parent`` is the parent object of this loader.
+ """
+
+ QUiLoader.__init__(self, baseinstance)
+ self.baseinstance = baseinstance
+ self.customWidgets = {}
+ self.uifile = None
+ self.customWidgets.update(customWidgets)
+
+ def createWidget(self, class_name, parent=None, name=''):
+ """
+ Function that is called for each widget defined in ui file,
+ overridden here to populate baseinstance instead.
+ """
+
+ if parent is None and self.baseinstance:
+ # supposed to create the top-level widget, return the base instance
+ # instead
+ return self.baseinstance
+
+ else:
+ if class_name in self.availableWidgets():
+ # create a new widget for child widgets
+ widget = QUiLoader.createWidget(self, class_name, parent, name)
+
+ else:
+ # if not in the list of availableWidgets,
+ # must be a custom widget
+ # this will raise KeyError if the user has not supplied the
+ # relevant class_name in the dictionary, or TypeError, if
+ # customWidgets is None
+ if class_name not in self.customWidgets:
+ raise Exception('No custom widget ' + class_name +
+ ' found in customWidgets param of' +
+ 'UiFile %s.' % self.uifile)
+ try:
+ widget = self.customWidgets[class_name](parent)
+ except Exception:
+ _logger.error("Fail to instanciate widget %s from file %s", class_name, self.uifile)
+ raise
+
+ if self.baseinstance:
+ # set an attribute for the new child widget on the base
+ # instance, just like PyQt*.uic.loadUi does.
+ setattr(self.baseinstance, name, widget)
+
+ # this outputs the various widget names, e.g.
+ # sampleGraphicsView, dockWidget, samplesTableView etc.
+ # print(name)
+
+ return widget
+
+ def _parse_custom_widgets(self, ui_file):
+ """
+ This function is used to parse a ui file and look for the <customwidgets>
+ section, then automatically load all the custom widget classes.
+ """
+ import importlib
+ from xml.etree.ElementTree import ElementTree
+
+ # Parse the UI file
+ etree = ElementTree()
+ ui = etree.parse(ui_file)
+
+ # Get the customwidgets section
+ custom_widgets = ui.find('customwidgets')
+
+ if custom_widgets is None:
+ return
+
+ custom_widget_classes = {}
+
+ for custom_widget in custom_widgets.getchildren():
+
+ cw_class = custom_widget.find('class').text
+ cw_header = custom_widget.find('header').text
+
+ module = importlib.import_module(cw_header)
+
+ custom_widget_classes[cw_class] = getattr(module, cw_class)
+
+ self.customWidgets.update(custom_widget_classes)
+
+ def load(self, uifile):
+ self._parse_custom_widgets(uifile)
+ self.uifile = uifile
+ return QUiLoader.load(self, uifile)
+
+
+class _Line(QFrame):
+ """Widget to use as 'Line' Qt designer"""
+ def __init__(self, parent=None):
+ super(_Line, self).__init__(parent)
+ self.setFrameShape(QFrame.HLine)
+ self.setFrameShadow(QFrame.Sunken)
+
+ def getOrientation(self):
+ shape = self.frameShape()
+ if shape == QFrame.HLine:
+ return Qt.Horizontal
+ elif shape == QFrame.VLine:
+ return Qt.Vertical
+ else:
+ raise RuntimeError("Wrong shape: %d", shape)
+
+ def setOrientation(self, orientation):
+ if orientation == Qt.Horizontal:
+ self.setFrameShape(QFrame.HLine)
+ elif orientation == Qt.Vertical:
+ self.setFrameShape(QFrame.VLine)
+ else:
+ raise ValueError("Unsupported orientation %s" % str(orientation))
+
+ orientation = Property("Qt::Orientation", getOrientation, setOrientation)
+
+
+CUSTOM_WIDGETS = {"Line": _Line}
+"""Default custom widgets for `loadUi`"""
+
+
+def loadUi(uifile, baseinstance=None, package=None, resource_suffix=None):
+ """
+ Dynamically load a user interface from the given ``uifile``.
+
+ ``uifile`` is a string containing a file name of the UI file to load.
+
+ If ``baseinstance`` is ``None``, the a new instance of the top-level widget
+ will be created. Otherwise, the user interface is created within the given
+ ``baseinstance``. In this case ``baseinstance`` must be an instance of the
+ top-level widget class in the UI file to load, or a subclass thereof. In
+ other words, if you've created a ``QMainWindow`` interface in the designer,
+ ``baseinstance`` must be a ``QMainWindow`` or a subclass thereof, too. You
+ cannot load a ``QMainWindow`` UI file with a plain
+ :class:`~PySide.QtGui.QWidget` as ``baseinstance``.
+
+ :method:`~PySide.QtCore.QMetaObject.connectSlotsByName()` is called on the
+ created user interface, so you can implemented your slots according to its
+ conventions in your widget class.
+
+ Return ``baseinstance``, if ``baseinstance`` is not ``None``. Otherwise
+ return the newly created instance of the user interface.
+ """
+ if package is not None:
+ _logger.warning(
+ "loadUi package parameter not implemented with PySide")
+ if resource_suffix is not None:
+ _logger.warning(
+ "loadUi resource_suffix parameter not implemented with PySide")
+
+ loader = UiLoader(baseinstance, customWidgets=CUSTOM_WIDGETS)
+ widget = loader.load(uifile)
+ QMetaObject.connectSlotsByName(widget)
+ return widget
diff --git a/src/silx/gui/qt/_qt.py b/src/silx/gui/qt/_qt.py
new file mode 100644
index 0000000..f62f4c8
--- /dev/null
+++ b/src/silx/gui/qt/_qt.py
@@ -0,0 +1,232 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Load Qt binding"""
+
+__authors__ = ["V.A. Sole"]
+__license__ = "MIT"
+__date__ = "23/05/2018"
+
+
+import logging
+import sys
+import traceback
+
+
+_logger = logging.getLogger(__name__)
+
+
+BINDING = None
+"""The name of the Qt binding in use: PyQt5, PySide2, PySide6."""
+
+QtBinding = None # noqa
+"""The Qt binding module in use: PyQt5, PySide2, PySide6."""
+
+HAS_SVG = False
+"""True if Qt provides support for Scalable Vector Graphics (QtSVG)."""
+
+HAS_OPENGL = False
+"""True if Qt provides support for OpenGL (QtOpenGL)."""
+
+# First check for an already loaded wrapper
+for _binding in ('PySide2', 'PyQt5', 'PySide6'):
+ if _binding + '.QtCore' in sys.modules:
+ BINDING = _binding
+ break
+else: # Then try Qt bindings
+ try:
+ import PyQt5.QtCore # noqa
+ except ImportError:
+ if 'PyQt5' in sys.modules:
+ del sys.modules["PyQt5"]
+ try:
+ import PySide2.QtCore # noqa
+ except ImportError:
+ if 'PySide2' in sys.modules:
+ del sys.modules["PySide2"]
+ try:
+ import PySide6.QtCore # noqa
+ except ImportError:
+ if 'PySide6' in sys.modules:
+ del sys.modules["PySide6"]
+ raise ImportError(
+ 'No Qt wrapper found. Install PyQt5, PySide2, PySide6.')
+ else:
+ BINDING = 'PySide6'
+ else:
+ BINDING = 'PySide2'
+ else:
+ BINDING = 'PyQt5'
+
+
+if BINDING == 'PyQt5':
+ _logger.debug('Using PyQt5 bindings')
+
+ import PyQt5 as QtBinding # noqa
+
+ from PyQt5.QtCore import * # noqa
+ from PyQt5.QtGui import * # noqa
+ from PyQt5.QtWidgets import * # noqa
+ from PyQt5.QtPrintSupport import * # noqa
+
+ try:
+ from PyQt5.QtOpenGL import * # noqa
+ except ImportError:
+ _logger.info("PyQt5.QtOpenGL not available")
+ HAS_OPENGL = False
+ else:
+ HAS_OPENGL = True
+
+ try:
+ from PyQt5.QtSvg import * # noqa
+ except ImportError:
+ _logger.info("PyQt5.QtSvg not available")
+ HAS_SVG = False
+ else:
+ HAS_SVG = True
+
+ from PyQt5.uic import loadUi # noqa
+
+ Signal = pyqtSignal
+
+ Property = pyqtProperty
+
+ Slot = pyqtSlot
+
+ # Disable PyQt5's cooperative multi-inheritance since other bindings do not provide it.
+ # See https://www.riverbankcomputing.com/static/Docs/PyQt5/multiinheritance.html?highlight=inheritance
+ class _Foo(object): pass
+ class QObject(QObject, _Foo): pass
+
+
+elif BINDING == 'PySide2':
+ _logger.debug('Using PySide2 bindings')
+
+ import PySide2 as QtBinding # noqa
+
+ from PySide2.QtCore import * # noqa
+ from PySide2.QtGui import * # noqa
+ from PySide2.QtWidgets import * # noqa
+ from PySide2.QtPrintSupport import * # noqa
+
+ try:
+ from PySide2.QtOpenGL import * # noqa
+ except ImportError:
+ _logger.info("PySide2.QtOpenGL not available")
+ HAS_OPENGL = False
+ else:
+ HAS_OPENGL = True
+
+ try:
+ from PySide2.QtSvg import * # noqa
+ except ImportError:
+ _logger.info("PySide2.QtSvg not available")
+ HAS_SVG = False
+ else:
+ HAS_SVG = True
+
+ pyqtSignal = Signal
+
+ # Qt6 compatibility:
+ # with PySide2 `exec` method has a special behavior
+ class _ExecMixIn:
+ """Mix-in class providind `exec` compatibility"""
+ def exec(self, *args, **kwargs):
+ return super().exec_(*args, **kwargs)
+
+ # QtWidgets
+ class QApplication(_ExecMixIn, QApplication): pass
+ class QColorDialog(_ExecMixIn, QColorDialog): pass
+ class QDialog(_ExecMixIn, QDialog): pass
+ class QErrorMessage(_ExecMixIn, QErrorMessage): pass
+ class QFileDialog(_ExecMixIn, QFileDialog): pass
+ class QFontDialog(_ExecMixIn, QFontDialog): pass
+ class QInputDialog(_ExecMixIn, QInputDialog): pass
+ class QMenu(_ExecMixIn, QMenu): pass
+ class QMessageBox(_ExecMixIn, QMessageBox): pass
+ class QProgressDialog(_ExecMixIn, QProgressDialog): pass
+ #QtCore
+ class QCoreApplication(_ExecMixIn, QCoreApplication): pass
+ class QEventLoop(_ExecMixIn, QEventLoop): pass
+ if hasattr(QTextStreamManipulator, "exec_"):
+ # exec_ only wrapped in PySide2 and NOT in PyQt5
+ class QTextStreamManipulator(_ExecMixIn, QTextStreamManipulator): pass
+ class QThread(_ExecMixIn, QThread): pass
+
+
+elif BINDING == 'PySide6':
+ _logger.debug('Using PySide6 bindings')
+
+ import PySide6 as QtBinding # noqa
+
+ from PySide6.QtCore import * # noqa
+ from PySide6.QtGui import * # noqa
+ from PySide6.QtWidgets import * # noqa
+ from PySide6.QtPrintSupport import * # noqa
+
+ try:
+ from PySide6.QtOpenGL import * # noqa
+ from PySide6.QtOpenGLWidgets import QOpenGLWidget # noqa
+ except ImportError:
+ _logger.info("PySide6.QtOpenGL not available")
+ HAS_OPENGL = False
+ else:
+ HAS_OPENGL = True
+
+ try:
+ from PySide6.QtSvg import * # noqa
+ except ImportError:
+ _logger.info("PySide6.QtSvg not available")
+ HAS_SVG = False
+ else:
+ HAS_SVG = True
+
+ pyqtSignal = Signal
+
+else:
+ raise ImportError('No Qt wrapper found. Install PyQt5, PySide2 or PySide6')
+
+
+# provide a exception handler but not implement it by default
+def exceptionHandler(type_, value, trace):
+ """
+ This exception handler prevents quitting to the command line when there is
+ an unhandled exception while processing a Qt signal.
+
+ The script/application willing to use it should implement code similar to:
+
+ .. code-block:: python
+
+ if __name__ == "__main__":
+ sys.excepthook = qt.exceptionHandler
+
+ """
+ _logger.error("%s %s %s", type_, value, ''.join(traceback.format_tb(trace)))
+ msg = QMessageBox()
+ msg.setWindowTitle("Unhandled exception")
+ msg.setIcon(QMessageBox.Critical)
+ msg.setInformativeText("%s %s\nPlease report details" % (type_, value))
+ msg.setDetailedText(("%s " % value) + ''.join(traceback.format_tb(trace)))
+ msg.raise_()
+ msg.exec()
diff --git a/src/silx/gui/qt/_utils.py b/src/silx/gui/qt/_utils.py
new file mode 100644
index 0000000..5dced95
--- /dev/null
+++ b/src/silx/gui/qt/_utils.py
@@ -0,0 +1,68 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides convenient functions related to Qt.
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "30/11/2016"
+
+
+from . import _qt
+
+
+def supportedImageFormats():
+ """Return a set of string of file format extensions supported by the
+ Qt runtime."""
+ if _qt.BINDING == 'PySide2':
+ def convert(data):
+ return str(data.data(), 'ascii')
+ else:
+ convert = lambda data: str(data, 'ascii')
+ formats = _qt.QImageReader.supportedImageFormats()
+ return set([convert(data) for data in formats])
+
+
+__globalThreadPoolInstance = None
+"""Store the own silx global thread pool"""
+
+
+def silxGlobalThreadPool():
+ """"Manage an own QThreadPool to avoid issue on Qt5 Windows with the
+ default Qt global thread pool.
+
+ A thread pool is create in lazy loading. With a maximum of 4 threads.
+ Else `qt.Thread.idealThreadCount()` is used.
+
+ :rtype: qt.QThreadPool
+ """
+ global __globalThreadPoolInstance
+ if __globalThreadPoolInstance is None:
+ tp = _qt.QThreadPool()
+ # Setting maxThreadCount fixes a segfault with PyQt 5.9.1 on Windows
+ maxThreadCount = min(4, tp.maxThreadCount())
+ tp.setMaxThreadCount(maxThreadCount)
+ __globalThreadPoolInstance = tp
+ return __globalThreadPoolInstance
diff --git a/src/silx/gui/qt/inspect.py b/src/silx/gui/qt/inspect.py
new file mode 100644
index 0000000..b9a0d1d
--- /dev/null
+++ b/src/silx/gui/qt/inspect.py
@@ -0,0 +1,75 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides functions to access Qt C++ object state:
+
+- :func:`isValid` to check whether a QObject C++ pointer is valid.
+- :func:`createdByPython` to check if a QObject was created from Python.
+- :func:`ownedByPython` to check if a QObject is currently owned by Python.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "08/10/2018"
+
+
+from . import _qt as qt
+
+
+if qt.BINDING == 'PyQt5':
+ try:
+ from PyQt5.sip import isdeleted as _isdeleted # noqa
+ from PyQt5.sip import ispycreated as createdByPython # noqa
+ from PyQt5.sip import ispyowned as ownedByPython # noqa
+ except ImportError:
+ from sip import isdeleted as _isdeleted # noqa
+ from sip import ispycreated as createdByPython # noqa
+ from sip import ispyowned as ownedByPython # noqa
+
+
+ def isValid(obj):
+ """Returns True if underlying C++ object is valid.
+
+ :param QObject obj:
+ :rtype: bool
+ """
+ return not _isdeleted(obj)
+
+elif qt.BINDING == 'PySide2':
+ try:
+ from PySide2.shiboken2 import isValid # noqa
+ from PySide2.shiboken2 import createdByPython # noqa
+ from PySide2.shiboken2 import ownedByPython # noqa
+ except ImportError:
+ from shiboken2 import isValid # noqa
+ from shiboken2 import createdByPython # noqa
+ from shiboken2 import ownedByPython # noqa
+
+elif qt.BINDING == 'PySide6':
+ from shiboken6 import isValid, createdByPython, ownedByPython # noqa
+
+else:
+ raise ImportError("Unsupported Qt binding %s" % qt.BINDING)
+
+__all__ = ['isValid', 'createdByPython', 'ownedByPython']
diff --git a/src/silx/gui/setup.py b/src/silx/gui/setup.py
new file mode 100644
index 0000000..04a2bac
--- /dev/null
+++ b/src/silx/gui/setup.py
@@ -0,0 +1,55 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "28/11/2017"
+
+
+from numpy.distutils.misc_util import Configuration
+
+
+def configuration(parent_package='', top_path=None):
+ config = Configuration('gui', parent_package, top_path)
+ config.add_subpackage('_glutils')
+ config.add_subpackage('qt')
+ config.add_subpackage('plot')
+ config.add_subpackage('fit')
+ config.add_subpackage('hdf5')
+ config.add_subpackage('widgets')
+ config.add_subpackage('test')
+ config.add_subpackage('plot3d')
+ config.add_subpackage('data')
+ config.add_subpackage('dialog')
+ config.add_subpackage('utils')
+ config.add_subpackage('utils.glutils')
+ config.add_subpackage('utils.test')
+
+ return config
+
+
+if __name__ == "__main__":
+ from numpy.distutils.core import setup
+
+ setup(configuration=configuration)
diff --git a/src/silx/gui/test/__init__.py b/src/silx/gui/test/__init__.py
new file mode 100644
index 0000000..00d6216
--- /dev/null
+++ b/src/silx/gui/test/__init__.py
@@ -0,0 +1,24 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
diff --git a/src/silx/gui/test/test_colors.py b/src/silx/gui/test/test_colors.py
new file mode 100755
index 0000000..fa87d7d
--- /dev/null
+++ b/src/silx/gui/test/test_colors.py
@@ -0,0 +1,603 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides the Colormap object
+"""
+
+from __future__ import absolute_import
+
+__authors__ = ["H.Payno"]
+__license__ = "MIT"
+__date__ = "09/11/2018"
+
+import unittest
+import numpy
+from silx.utils.testutils import ParametricTestCase
+from silx.gui import qt
+from silx.gui import colors
+from silx.gui.colors import Colormap
+from silx.gui.plot import items
+from silx.utils.exceptions import NotEditableError
+
+
+class TestColor(ParametricTestCase):
+ """Basic tests of rgba function"""
+
+ TEST_COLORS = { # name: (colors, expected values)
+ 'blue': ('blue', (0., 0., 1., 1.)),
+ '#010203': ('#010203', (1. / 255., 2. / 255., 3. / 255., 1.)),
+ '#01020304': ('#01020304', (1. / 255., 2. / 255., 3. / 255., 4. / 255.)),
+ '3 x uint8': (numpy.array((1, 255, 0), dtype=numpy.uint8),
+ (1 / 255., 1., 0., 1.)),
+ '4 x uint8': (numpy.array((1, 255, 0, 1), dtype=numpy.uint8),
+ (1 / 255., 1., 0., 1 / 255.)),
+ '3 x float overflow': ((3., 0.5, 1.), (1., 0.5, 1., 1.)),
+ }
+
+ def testRGBA(self):
+ """"Test rgba function with accepted values"""
+ for name, test in self.TEST_COLORS.items():
+ color, expected = test
+ with self.subTest(msg=name):
+ result = colors.rgba(color)
+ self.assertEqual(result, expected)
+
+ def testQColor(self):
+ """"Test getQColor function with accepted values"""
+ for name, test in self.TEST_COLORS.items():
+ color, expected = test
+ with self.subTest(msg=name):
+ result = colors.asQColor(color)
+ self.assertAlmostEqual(result.redF(), expected[0], places=4)
+ self.assertAlmostEqual(result.greenF(), expected[1], places=4)
+ self.assertAlmostEqual(result.blueF(), expected[2], places=4)
+ self.assertAlmostEqual(result.alphaF(), expected[3], places=4)
+
+
+class TestApplyColormapToData(ParametricTestCase):
+ """Tests of applyColormapToData function"""
+
+ def testApplyColormapToData(self):
+ """Simple test of applyColormapToData function"""
+ colormap = Colormap(name='gray', normalization='linear',
+ vmin=0, vmax=255)
+
+ size = 10
+ expected = numpy.empty((size, 4), dtype='uint8')
+ expected[:, 0] = numpy.arange(size, dtype='uint8')
+ expected[:, 1] = expected[:, 0]
+ expected[:, 2] = expected[:, 0]
+ expected[:, 3] = 255
+
+ for dtype in ('uint8', 'int32', 'float32', 'float64'):
+ with self.subTest(dtype=dtype):
+ array = numpy.arange(size, dtype=dtype)
+ result = colormap.applyToData(data=array)
+ self.assertTrue(numpy.all(numpy.equal(result, expected)))
+
+ def testAutoscaleFromDataReference(self):
+ colormap = Colormap(name='gray', normalization='linear')
+ data = numpy.array([50])
+ reference = numpy.array([0, 100])
+ value = colormap.applyToData(data, reference)
+ self.assertEqual(len(value), 1)
+ self.assertEqual(value[0, 0], 128)
+
+ def testAutoscaleFromItemReference(self):
+ colormap = Colormap(name='gray', normalization='linear')
+ data = numpy.array([50])
+ image = items.ImageData()
+ image.setData(numpy.array([[0, 100]]))
+ value = colormap.applyToData(data, reference=image)
+ self.assertEqual(len(value), 1)
+ self.assertEqual(value[0, 0], 128)
+
+ def testNaNColor(self):
+ """Test Colormap.applyToData with NaN values"""
+ colormap = Colormap(name='gray', normalization='linear')
+ colormap.setNaNColor('red')
+ self.assertEqual(colormap.getNaNColor(), qt.QColor(255, 0, 0))
+
+ data = numpy.array([50., numpy.nan])
+ image = items.ImageData()
+ image.setData(numpy.array([[0, 100]]))
+ value = colormap.applyToData(data, reference=image)
+ self.assertEqual(len(value), 2)
+ self.assertTrue(numpy.array_equal(value[0], (128, 128, 128, 255)))
+ self.assertTrue(numpy.array_equal(value[1], (255, 0, 0, 255)))
+
+
+class TestDictAPI(unittest.TestCase):
+ """Make sure the old dictionary API is working
+ """
+
+ def setUp(self):
+ self.vmin = -1.0
+ self.vmax = 12
+
+ def testGetItem(self):
+ """test the item getter API ([xxx])"""
+ colormap = Colormap(name='viridis',
+ normalization=Colormap.LINEAR,
+ vmin=self.vmin,
+ vmax=self.vmax)
+ self.assertTrue(colormap['name'] == 'viridis')
+ self.assertTrue(colormap['normalization'] == Colormap.LINEAR)
+ self.assertTrue(colormap['vmin'] == self.vmin)
+ self.assertTrue(colormap['vmax'] == self.vmax)
+ with self.assertRaises(KeyError):
+ colormap['toto']
+
+ def testGetDict(self):
+ """Test the getDict function API"""
+ clmObject = Colormap(name='viridis',
+ normalization=Colormap.LINEAR,
+ vmin=self.vmin,
+ vmax=self.vmax)
+ clmDict = clmObject._toDict()
+ self.assertTrue(clmDict['name'] == 'viridis')
+ self.assertTrue(clmDict['autoscale'] is False)
+ self.assertTrue(clmDict['vmin'] == self.vmin)
+ self.assertTrue(clmDict['vmax'] == self.vmax)
+ self.assertTrue(clmDict['normalization'] == Colormap.LINEAR)
+
+ clmObject.setVRange(None, None)
+ self.assertTrue(clmObject._toDict()['autoscale'] is True)
+
+ def testSetValidDict(self):
+ """Test that if a colormap is created from a dict then it is correctly
+ created and the values are copied (so if some values from the dict
+ is changing, this won't affect the Colormap object"""
+ clm_dict = {
+ 'name': 'temperature',
+ 'vmin': 1.0,
+ 'vmax': 2.0,
+ 'normalization': 'linear',
+ 'colors': None,
+ 'autoscale': False
+ }
+
+ # Test that the colormap is correctly created
+ colormapObject = Colormap._fromDict(clm_dict)
+ self.assertTrue(colormapObject.getName() == clm_dict['name'])
+ self.assertTrue(colormapObject.getColormapLUT() == clm_dict['colors'])
+ self.assertTrue(colormapObject.getVMin() == clm_dict['vmin'])
+ self.assertTrue(colormapObject.getVMax() == clm_dict['vmax'])
+ self.assertTrue(colormapObject.isAutoscale() == clm_dict['autoscale'])
+
+ # Check that the colormap has copied the values
+ clm_dict['vmin'] = None
+ clm_dict['vmax'] = None
+ clm_dict['colors'] = [1.0, 2.0]
+ clm_dict['autoscale'] = True
+ clm_dict['normalization'] = Colormap.LOGARITHM
+ clm_dict['name'] = 'viridis'
+
+ self.assertFalse(colormapObject.getName() == clm_dict['name'])
+ self.assertFalse(colormapObject.getColormapLUT() == clm_dict['colors'])
+ self.assertFalse(colormapObject.getVMin() == clm_dict['vmin'])
+ self.assertFalse(colormapObject.getVMax() == clm_dict['vmax'])
+ self.assertFalse(colormapObject.isAutoscale() == clm_dict['autoscale'])
+
+ def testMissingKeysFromDict(self):
+ """Make sure we can create a Colormap object from a dictionary even if
+ there is missing keys except if those keys are 'colors' or 'name'
+ """
+ colormap = Colormap._fromDict({'name': 'blue'})
+ self.assertTrue(colormap.getVMin() is None)
+ colormap = Colormap._fromDict({'colors': numpy.zeros((5, 3))})
+ self.assertTrue(colormap.getName() is None)
+
+ with self.assertRaises(ValueError):
+ Colormap._fromDict({})
+
+ def testUnknowNorm(self):
+ """Make sure an error is raised if the given normalization is not
+ knowed
+ """
+ clm_dict = {
+ 'name': 'temperature',
+ 'vmin': 1.0,
+ 'vmax': 2.0,
+ 'normalization': 'toto',
+ 'colors': None,
+ 'autoscale': False
+ }
+ with self.assertRaises(ValueError):
+ Colormap._fromDict(clm_dict)
+
+ def testNumericalColors(self):
+ """Make sure the old API using colors=int was supported"""
+ clm_dict = {
+ 'name': 'temperature',
+ 'vmin': 1.0,
+ 'vmax': 2.0,
+ 'colors': 256,
+ 'autoscale': False
+ }
+ Colormap._fromDict(clm_dict)
+
+
+class TestObjectAPI(ParametricTestCase):
+ """Test the new Object API of the colormap"""
+ def testVMinVMax(self):
+ """Test getter and setter associated to vmin and vmax values"""
+ vmin = 1.0
+ vmax = 2.0
+
+ colormapObject = Colormap(name='viridis',
+ vmin=vmin,
+ vmax=vmax,
+ normalization=Colormap.LINEAR)
+
+ with self.assertRaises(ValueError):
+ colormapObject.setVMin(3)
+
+ with self.assertRaises(ValueError):
+ colormapObject.setVMax(-2)
+
+ with self.assertRaises(ValueError):
+ colormapObject.setVRange(3, -2)
+
+ self.assertTrue(colormapObject.getColormapRange() == (1.0, 2.0))
+ self.assertTrue(colormapObject.isAutoscale() is False)
+ colormapObject.setVRange(None, None)
+ self.assertTrue(colormapObject.getVMin() is None)
+ self.assertTrue(colormapObject.getVMax() is None)
+ self.assertTrue(colormapObject.isAutoscale() is True)
+
+ def testCopy(self):
+ """Make sure the copy function is correctly processing
+ """
+ colormapObject = Colormap(name=None,
+ colors=numpy.array([[1., 0., 0.],
+ [0., 1., 0.],
+ [0., 0., 1.]]),
+ vmin=None,
+ vmax=None,
+ normalization=Colormap.LOGARITHM)
+
+ colormapObject2 = colormapObject.copy()
+ self.assertTrue(colormapObject == colormapObject2)
+ colormapObject.setColormapLUT([[0, 0, 0], [255, 255, 255]])
+ self.assertFalse(colormapObject == colormapObject2)
+
+ colormapObject2 = colormapObject.copy()
+ self.assertTrue(colormapObject == colormapObject2)
+ colormapObject.setNormalization(Colormap.LINEAR)
+ self.assertFalse(colormapObject == colormapObject2)
+
+ def testGetColorMapRange(self):
+ """Make sure the getColormapRange function of colormap is correctly
+ applying
+ """
+ # test linear scale
+ data = numpy.array([-1, 1, 2, 3, float('nan')])
+ cl1 = Colormap(name='gray',
+ normalization=Colormap.LINEAR,
+ vmin=0,
+ vmax=2)
+ cl2 = Colormap(name='gray',
+ normalization=Colormap.LINEAR,
+ vmin=None,
+ vmax=2)
+ cl3 = Colormap(name='gray',
+ normalization=Colormap.LINEAR,
+ vmin=0,
+ vmax=None)
+ cl4 = Colormap(name='gray',
+ normalization=Colormap.LINEAR,
+ vmin=None,
+ vmax=None)
+
+ self.assertTrue(cl1.getColormapRange(data) == (0, 2))
+ self.assertTrue(cl2.getColormapRange(data) == (-1, 2))
+ self.assertTrue(cl3.getColormapRange(data) == (0, 3))
+ self.assertTrue(cl4.getColormapRange(data) == (-1, 3))
+
+ # test linear with annoying cases
+ self.assertEqual(cl3.getColormapRange((-1, -2)), (0, 0))
+ self.assertEqual(cl4.getColormapRange(()), (0., 1.))
+ self.assertEqual(cl4.getColormapRange(
+ (float('nan'), float('inf'), 1., -float('inf'), 2)), (1., 2.))
+ self.assertEqual(cl4.getColormapRange(
+ (float('nan'), float('inf'))), (0., 1.))
+
+ # test log scale
+ data = numpy.array([float('nan'), -1, 1, 10, 100, 1000])
+ cl1 = Colormap(name='gray',
+ normalization=Colormap.LOGARITHM,
+ vmin=1,
+ vmax=100)
+ cl2 = Colormap(name='gray',
+ normalization=Colormap.LOGARITHM,
+ vmin=None,
+ vmax=100)
+ cl3 = Colormap(name='gray',
+ normalization=Colormap.LOGARITHM,
+ vmin=1,
+ vmax=None)
+ cl4 = Colormap(name='gray',
+ normalization=Colormap.LOGARITHM,
+ vmin=None,
+ vmax=None)
+
+ self.assertTrue(cl1.getColormapRange(data) == (1, 100))
+ self.assertTrue(cl2.getColormapRange(data) == (1, 100))
+ self.assertTrue(cl3.getColormapRange(data) == (1, 1000))
+ self.assertTrue(cl4.getColormapRange(data) == (1, 1000))
+
+ # test log with annoying cases
+ self.assertEqual(cl3.getColormapRange((0.1, 0.2)), (1, 1))
+ self.assertEqual(cl4.getColormapRange((-2., -1.)), (1., 1.))
+ self.assertEqual(cl4.getColormapRange(()), (1., 10.))
+ self.assertEqual(cl4.getColormapRange(
+ (float('nan'), float('inf'), 1., -float('inf'), 2)), (1., 2.))
+ self.assertEqual(cl4.getColormapRange(
+ (float('nan'), float('inf'))), (1., 10.))
+
+ def testApplyToData(self):
+ """Test applyToData on different datasets"""
+ datasets = [
+ numpy.zeros((0, 0)), # Empty array
+ numpy.array((numpy.nan, numpy.inf)), # All non-finite
+ numpy.array((-numpy.inf, numpy.inf, 1.0, 2.0)), # Some infinite
+ ]
+
+ for normalization in ('linear', 'log'):
+ colormap = Colormap(name='gray',
+ normalization=normalization,
+ vmin=None,
+ vmax=None)
+
+ for data in datasets:
+ with self.subTest(data=data):
+ image = colormap.applyToData(data)
+ self.assertEqual(image.dtype, numpy.uint8)
+ self.assertEqual(image.shape[-1], 4)
+ self.assertEqual(image.shape[:-1], data.shape)
+
+ def testGetNColors(self):
+ """Test getNColors method"""
+ # specific LUT
+ colormap = Colormap(name=None,
+ colors=((0., 0., 0.), (1., 1., 1.)),
+ vmin=1000,
+ vmax=2000)
+ colors = colormap.getNColors()
+ self.assertTrue(numpy.all(numpy.equal(
+ colors,
+ ((0, 0, 0, 255), (255, 255, 255, 255)))))
+
+ def testEditableMode(self):
+ """Make sure the colormap will raise NotEditableError when try to
+ change a colormap not editable"""
+ colormap = Colormap()
+ colormap.setEditable(False)
+ with self.assertRaises(NotEditableError):
+ colormap.setVRange(0., 1.)
+ with self.assertRaises(NotEditableError):
+ colormap.setVMin(1.)
+ with self.assertRaises(NotEditableError):
+ colormap.setVMax(1.)
+ with self.assertRaises(NotEditableError):
+ colormap.setNormalization(Colormap.LOGARITHM)
+ with self.assertRaises(NotEditableError):
+ colormap.setName('magma')
+ with self.assertRaises(NotEditableError):
+ colormap.setColormapLUT([[0., 0., 0.], [1., 1., 1.]])
+ with self.assertRaises(NotEditableError):
+ colormap._setFromDict(colormap._toDict())
+ state = colormap.saveState()
+ with self.assertRaises(NotEditableError):
+ colormap.restoreState(state)
+
+ def testBadColorsType(self):
+ """Make sure colors can't be something else than an array"""
+ with self.assertRaises(TypeError):
+ Colormap(colors=256)
+
+ def testEqual(self):
+ colormap1 = Colormap()
+ colormap2 = Colormap()
+ self.assertEqual(colormap1, colormap2)
+
+ def testCompareString(self):
+ colormap = Colormap()
+ self.assertNotEqual(colormap, "a")
+
+ def testCompareNone(self):
+ colormap = Colormap()
+ self.assertNotEqual(colormap, None)
+
+ def testSet(self):
+ colormap = Colormap()
+ other = Colormap(name="viridis", vmin=1, vmax=2, normalization=Colormap.LOGARITHM)
+ self.assertNotEqual(colormap, other)
+ colormap.setFromColormap(other)
+ self.assertIsNot(colormap, other)
+ self.assertEqual(colormap, other)
+
+ def testAutoscaleMode(self):
+ colormap = Colormap(autoscaleMode=Colormap.STDDEV3)
+ self.assertEqual(colormap.getAutoscaleMode(), Colormap.STDDEV3)
+ colormap.setAutoscaleMode(Colormap.MINMAX)
+ self.assertEqual(colormap.getAutoscaleMode(), Colormap.MINMAX)
+
+ def testStoreRestore(self):
+ colormaps = [
+ Colormap(name="viridis"),
+ Colormap(normalization=Colormap.SQRT)
+ ]
+ cmap = Colormap(normalization=Colormap.GAMMA)
+ cmap.setGammaNormalizationParameter(1.2)
+ cmap.setNaNColor('red')
+ colormaps.append(cmap)
+ for expected in colormaps:
+ with self.subTest(colormap=expected):
+ state = expected.saveState()
+ result = Colormap()
+ result.restoreState(state)
+ self.assertEqual(expected, result)
+
+ def testStorageV1(self):
+ state = b'\x00\x00\x00\x10\x00C\x00o\x00l\x00o\x00r\x00m\x00a\x00p\x00\x00'\
+ b'\x00\x01\x00\x00\x00\x0E\x00v\x00i\x00r\x00i\x00d\x00i\x00s\x00'\
+ b'\x00\x00\x00\x06\x00?\xF0\x00\x00\x00\x00\x00\x00\x00\x00\x00'\
+ b'\x00\x06\x00@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x06\x00'\
+ b'l\x00o\x00g'
+ state = qt.QByteArray(state)
+ colormap = Colormap()
+ colormap.restoreState(state)
+
+ expected = Colormap(name="viridis", vmin=1, vmax=2, normalization=Colormap.LOGARITHM)
+ self.assertEqual(colormap, expected)
+
+ def testStorageV2(self):
+ state = b'\x00\x00\x00\x10\x00C\x00o\x00l\x00o\x00r\x00m\x00a\x00p\x00'\
+ b'\x00\x00\x02\x00\x00\x00\x0e\x00v\x00i\x00r\x00i\x00d\x00i\x00'\
+ b's\x00\x00\x00\x00\x06\x00?\xf0\x00\x00\x00\x00\x00\x00\x00\x00'\
+ b'\x00\x00\x06\x00@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x06'\
+ b'\x00l\x00o\x00g\x00\x00\x00\x0c\x00m\x00i\x00n\x00m\x00a\x00x'
+ state = qt.QByteArray(state)
+ colormap = Colormap()
+ colormap.restoreState(state)
+
+ expected = Colormap(name="viridis", vmin=1, vmax=2, normalization=Colormap.LOGARITHM)
+ expected.setGammaNormalizationParameter(1.5)
+ self.assertEqual(colormap, expected)
+
+
+class TestPreferredColormaps(unittest.TestCase):
+ """Test get|setPreferredColormaps functions"""
+
+ def setUp(self):
+ # Save preferred colormaps
+ self._colormaps = colors.preferredColormaps()
+
+ def tearDown(self):
+ # Restore saved preferred colormaps
+ colors.setPreferredColormaps(self._colormaps)
+
+ def test(self):
+ colormaps = 'viridis', 'magma'
+
+ colors.setPreferredColormaps(colormaps)
+ self.assertEqual(colors.preferredColormaps(), colormaps)
+
+ with self.assertRaises(ValueError):
+ colors.setPreferredColormaps(())
+
+ with self.assertRaises(ValueError):
+ colors.setPreferredColormaps(('This is not a colormap',))
+
+ colormaps = 'red', 'green'
+ colors.setPreferredColormaps(('This is not a colormap',) + colormaps)
+ self.assertEqual(colors.preferredColormaps(), colormaps)
+
+
+class TestRegisteredLut(unittest.TestCase):
+ """Test get|setPreferredColormaps functions"""
+
+ def setUp(self):
+ # Save preferred colormaps
+ lut = numpy.arange(8 * 3)
+ lut.shape = -1, 3
+ lut = lut / (8.0 * 3)
+ colors.registerLUT("test_8", colors=lut, cursor_color='blue')
+
+ def testColormap(self):
+ colormap = Colormap("test_8")
+ self.assertIsNotNone(colormap)
+
+ def testCursor(self):
+ color = colors.cursorColorForColormap("test_8")
+ self.assertEqual(color, 'blue')
+
+ def testLut(self):
+ colormap = Colormap("test_8")
+ colors = colormap.getNColors(8)
+ self.assertEqual(len(colors), 8)
+
+ def testUint8(self):
+ lut = numpy.array([[255, 0, 0], [200, 0, 0], [150, 0, 0]], dtype="uint")
+ colors.registerLUT("test_type", lut)
+ colormap = colors.Colormap(name="test_type")
+ lut = colormap.getNColors(3)
+ self.assertEqual(lut.shape, (3, 4))
+ self.assertEqual(lut[0, 0], 255)
+
+ def testFloatRGB(self):
+ lut = numpy.array([[1.0, 0, 0], [0.5, 0, 0], [0, 0, 0]], dtype="float")
+ colors.registerLUT("test_type", lut)
+ colormap = colors.Colormap(name="test_type")
+ lut = colormap.getNColors(3)
+ self.assertEqual(lut.shape, (3, 4))
+ self.assertEqual(lut[0, 0], 255)
+
+ def testFloatRGBA(self):
+ lut = numpy.array([[1.0, 0, 0, 128 / 256.0], [0.5, 0, 0, 1.0], [0.0, 0, 0, 1.0]], dtype="float")
+ colors.registerLUT("test_type", lut)
+ colormap = colors.Colormap(name="test_type")
+ lut = colormap.getNColors(3)
+ self.assertEqual(lut.shape, (3, 4))
+ self.assertEqual(lut[0, 0], 255)
+ self.assertEqual(lut[0, 3], 128)
+
+
+class TestAutoscaleRange(ParametricTestCase):
+
+ def testAutoscaleRange(self):
+ nan = numpy.nan
+ data_std_inside = numpy.array([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2])
+ data_std_inside_nan = numpy.array([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, numpy.nan])
+ data = [
+ # Positive values
+ (Colormap.LINEAR, Colormap.MINMAX, numpy.array([10, 20, 50]), (10, 50)),
+ (Colormap.LOGARITHM, Colormap.MINMAX, numpy.array([10, 50, 100]), (10, 100)),
+ (Colormap.LINEAR, Colormap.STDDEV3, data_std_inside, (0.026671473215424735, 1.9733285267845753)),
+ (Colormap.LOGARITHM, Colormap.STDDEV3, data_std_inside, (1, 1.6733506885453602)),
+ (Colormap.LINEAR, Colormap.STDDEV3, numpy.array([10, 100]), (10, 100)),
+ (Colormap.LOGARITHM, Colormap.STDDEV3, numpy.array([10, 100]), (10, 100)),
+
+ # With nan
+ (Colormap.LINEAR, Colormap.MINMAX, numpy.array([10, 20, 50, nan]), (10, 50)),
+ (Colormap.LOGARITHM, Colormap.MINMAX, numpy.array([10, 50, 100, nan]), (10, 100)),
+ (Colormap.LINEAR, Colormap.STDDEV3, data_std_inside_nan, (0.026671473215424735, 1.9733285267845753)),
+ (Colormap.LOGARITHM, Colormap.STDDEV3, data_std_inside_nan, (1, 1.6733506885453602)),
+ # With negative
+ (Colormap.LOGARITHM, Colormap.MINMAX, numpy.array([10, 50, 100, -50]), (10, 100)),
+ (Colormap.LOGARITHM, Colormap.STDDEV3, numpy.array([10, 100, -10]), (10, 100)),
+ ]
+ for norm, mode, array, expectedRange in data:
+ with self.subTest(norm=norm, mode=mode, array=array):
+ colormap = Colormap()
+ colormap.setNormalization(norm)
+ colormap.setAutoscaleMode(mode)
+ vRange = colormap._computeAutoscaleRange(array)
+ if vRange is None:
+ self.assertIsNone(expectedRange)
+ else:
+ self.assertAlmostEqual(vRange[0], expectedRange[0])
+ self.assertAlmostEqual(vRange[1], expectedRange[1])
diff --git a/src/silx/gui/test/test_console.py b/src/silx/gui/test/test_console.py
new file mode 100644
index 0000000..21f3564
--- /dev/null
+++ b/src/silx/gui/test/test_console.py
@@ -0,0 +1,75 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for IPython console widget"""
+
+from __future__ import print_function
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "05/12/2016"
+
+
+import pytest
+from silx.gui import qt
+
+
+# dummy objects to test pushing variables to the interactive namespace
+_a = 1
+
+
+def _f():
+ print("Hello World!")
+
+
+@pytest.fixture
+def console(qapp_utils):
+ """Create a console widget"""
+ # Console tests disabled due to corruption of python environment
+ pytest.skip("Disabled (see issue #538)")
+ try:
+ from silx.gui.console import IPythonDockWidget
+ except ImportError:
+ pytest.skip("IPythonDockWidget is not available")
+
+ console = IPythonDockWidget(
+ available_vars={"a": _a, "f": _f},
+ custom_banner="Welcome!\n")
+ console.show()
+ qapp_utils.qWaitForWindowExposed(console)
+ yield console
+ console.setAttribute(qt.Qt.WA_DeleteOnClose)
+ console.close()
+ console = None
+
+
+def testShow(console):
+ pass
+
+
+def testInteract(console, qapp_utils):
+ qapp_utils.mouseClick(console, qt.Qt.LeftButton)
+ qapp_utils.keyClicks(console, 'import silx')
+ qapp_utils.keyClick(console, qt.Qt.Key_Enter)
+ qapp_utils.qapp.processEvents()
diff --git a/src/silx/gui/test/test_icons.py b/src/silx/gui/test/test_icons.py
new file mode 100644
index 0000000..154adf6
--- /dev/null
+++ b/src/silx/gui/test/test_icons.py
@@ -0,0 +1,144 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic test of Qt icons module."""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "06/09/2017"
+
+
+import unittest
+import weakref
+import tempfile
+import shutil
+import os
+
+import silx.resources
+from silx.gui import qt
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui import icons
+
+
+class TestIcons(TestCaseQt):
+ """Test to check that icons module."""
+
+ @classmethod
+ def setUpClass(cls):
+ super(TestIcons, cls).setUpClass()
+
+ cls.tmpDirectory = tempfile.mkdtemp(prefix="resource_")
+ os.mkdir(os.path.join(cls.tmpDirectory, "gui"))
+ destination = os.path.join(cls.tmpDirectory, "gui", "icons")
+ os.mkdir(destination)
+ shutil.copy(silx.resources.resource_filename("gui/icons/zoom-in.png"), destination)
+ shutil.copy(silx.resources.resource_filename("gui/icons/zoom-out.svg"), destination)
+
+ @classmethod
+ def tearDownClass(cls):
+ super(TestIcons, cls).tearDownClass()
+ shutil.rmtree(cls.tmpDirectory)
+
+ def setUp(self):
+ # Store the original configuration
+ self._oldResources = dict(silx.resources._RESOURCE_DIRECTORIES)
+ silx.resources.register_resource_directory("test", "foo.bar", forced_path=self.tmpDirectory)
+ unittest.TestCase.setUp(self)
+
+ def tearDown(self):
+ unittest.TestCase.tearDown(self)
+ # Restiture the original configuration
+ silx.resources._RESOURCE_DIRECTORIES = self._oldResources
+
+ def testIcon(self):
+ icon = icons.getQIcon("silx:gui/icons/zoom-out")
+ self.assertIsNotNone(icon)
+
+ def testPrefix(self):
+ icon = icons.getQIcon("silx:gui/icons/zoom-out")
+ self.assertIsNotNone(icon)
+
+ def testSvgIcon(self):
+ if "svg" not in qt.supportedImageFormats():
+ self.skipTest("SVG not supported")
+ icon = icons.getQIcon("test:gui/icons/zoom-out")
+ self.assertIsNotNone(icon)
+
+ def testPngIcon(self):
+ icon = icons.getQIcon("test:gui/icons/zoom-in")
+ self.assertIsNotNone(icon)
+
+ def testUnexistingIcon(self):
+ self.assertRaises(ValueError, icons.getQIcon, "not-exists")
+
+ def testExistingQPixmap(self):
+ icon = icons.getQPixmap("crop")
+ self.assertIsNotNone(icon)
+
+ def testUnexistingQPixmap(self):
+ self.assertRaises(ValueError, icons.getQPixmap, "not-exists")
+
+ def testCache(self):
+ icon1 = icons.getQIcon("crop")
+ icon2 = icons.getQIcon("crop")
+ self.assertIs(icon1, icon2)
+
+ def testCacheReleased(self):
+ icon = icons.getQIcon("crop")
+ icon_ref = weakref.ref(icon)
+ del icon
+ self.assertIsNone(icon_ref())
+
+
+class TestAnimatedIcons(TestCaseQt):
+ """Test to check that icons module."""
+
+ def testProcessWorking(self):
+ icon = icons.getWaitIcon()
+ self.assertIsNotNone(icon)
+
+ def testProcessWorkingCache(self):
+ icon1 = icons.getWaitIcon()
+ icon2 = icons.getWaitIcon()
+ self.assertIs(icon1, icon2)
+
+ def testMovieIconExists(self):
+ if "mng" not in qt.supportedImageFormats():
+ self.skipTest("MNG not supported")
+ icon = icons.MovieAnimatedIcon("process-working")
+ self.assertIsNotNone(icon)
+
+ def testMovieIconNotExists(self):
+ self.assertRaises(ValueError, icons.MovieAnimatedIcon, "not-exists")
+
+ def testMultiImageIconExists(self):
+ icon = icons.MultiImageAnimatedIcon("process-working")
+ self.assertIsNotNone(icon)
+
+ def testPrefixedResourceExists(self):
+ icon = icons.MultiImageAnimatedIcon("silx:gui/icons/process-working")
+ self.assertIsNotNone(icon)
+
+ def testMultiImageIconNotExists(self):
+ self.assertRaises(ValueError, icons.MultiImageAnimatedIcon, "not-exists")
diff --git a/src/silx/gui/test/test_qt.py b/src/silx/gui/test/test_qt.py
new file mode 100644
index 0000000..8554744
--- /dev/null
+++ b/src/silx/gui/test/test_qt.py
@@ -0,0 +1,212 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic test of Qt bindings wrapper."""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "05/12/2016"
+
+
+import os.path
+import unittest
+import pytest
+
+from silx.test.utils import temp_dir
+from silx.gui.utils.testutils import TestCaseQt
+
+from silx.gui import qt
+try:
+ from silx.gui.qt import inspect as qt_inspect
+except ImportError:
+ qt_inspect = None
+
+
+class TestQtWrapper(unittest.TestCase):
+ """Minimalistic test to check that Qt has been loaded."""
+
+ def testQObject(self):
+ """Test that QObject is there."""
+ obj = qt.QObject()
+ self.assertTrue(obj is not None)
+
+
+class TestLoadUi(TestCaseQt):
+ """Test loadUi function"""
+
+ TEST_UI = """<?xml version="1.0" encoding="UTF-8"?>
+ <ui version="4.0">
+ <class>MainWindow</class>
+ <widget class="QMainWindow" name="MainWindow">
+ <property name="geometry">
+ <rect>
+ <x>0</x>
+ <y>0</y>
+ <width>293</width>
+ <height>296</height>
+ </rect>
+ </property>
+ <property name="windowTitle">
+ <string>Test loadUi</string>
+ </property>
+ <widget class="QWidget" name="centralwidget">
+ <widget class="QPushButton" name="pushButton">
+ <property name="geometry">
+ <rect>
+ <x>10</x>
+ <y>10</y>
+ <width>89</width>
+ <height>27</height>
+ </rect>
+ </property>
+ <property name="text">
+ <string>Button 1</string>
+ </property>
+ </widget>
+ <widget class="QPushButton" name="pushButton_2">
+ <property name="geometry">
+ <rect>
+ <x>10</x>
+ <y>50</y>
+ <width>89</width>
+ <height>27</height>
+ </rect>
+ </property>
+ <property name="text">
+ <string>Button 2</string>
+ </property>
+ </widget>
+ <widget class="Line" name="line">
+ <property name="geometry">
+ <rect>
+ <x>10</x>
+ <y>90</y>
+ <width>118</width>
+ <height>3</height>
+ </rect>
+ </property>
+ <property name="orientation">
+ <enum>Qt::Horizontal</enum>
+ </property>
+ </widget>
+ <widget class="Line" name="line_2">
+ <property name="geometry">
+ <rect>
+ <x>150</x>
+ <y>20</y>
+ <width>3</width>
+ <height>61</height>
+ </rect>
+ </property>
+ <property name="orientation">
+ <enum>Qt::Vertical</enum>
+ </property>
+ </widget>
+ </widget>
+ <widget class="QMenuBar" name="menubar">
+ <property name="geometry">
+ <rect>
+ <x>0</x>
+ <y>0</y>
+ <width>293</width>
+ <height>25</height>
+ </rect>
+ </property>
+ </widget>
+ <widget class="QStatusBar" name="statusbar"/>
+ </widget>
+ <resources/>
+ <connections/>
+ </ui>
+ """
+
+ def testLoadUi(self):
+ """Create a QMainWindow from an ui file"""
+ with temp_dir() as tmp:
+ uifile = os.path.join(tmp, "test.ui")
+
+ # write file
+ with open(uifile, mode='w') as f:
+ f.write(self.TEST_UI)
+
+ class TestMainWindow(qt.QMainWindow):
+ def __init__(self, parent=None):
+ super(TestMainWindow, self).__init__(parent)
+ qt.loadUi(uifile, self)
+
+ testMainWindow = TestMainWindow()
+ testMainWindow.show()
+ self.qWaitForWindowExposed(testMainWindow)
+
+ testMainWindow.setAttribute(qt.Qt.WA_DeleteOnClose)
+ testMainWindow.close()
+
+
+class TestQtInspect(unittest.TestCase):
+ """Test functions of silx.gui.qt.inspect module"""
+
+ def test(self):
+ """Test functions of silx.gui.qt.inspect module"""
+ self.assertIsNotNone(qt_inspect)
+
+ parent = qt.QObject()
+
+ self.assertTrue(qt_inspect.isValid(parent))
+ self.assertTrue(qt_inspect.createdByPython(parent))
+ self.assertTrue(qt_inspect.ownedByPython(parent))
+
+ obj = qt.QObject(parent)
+
+ self.assertTrue(qt_inspect.isValid(obj))
+ self.assertTrue(qt_inspect.createdByPython(obj))
+ self.assertFalse(qt_inspect.ownedByPython(obj))
+
+ del parent
+ self.assertFalse(qt_inspect.isValid(obj))
+
+
+@pytest.mark.skipif(qt.BINDING not in ("PyQt5", "PySide2"),
+ reason="PyQt5/PySide2 only test")
+def test_exec_():
+ """Test the exec_ is still useable with Qt5 bindings"""
+ klasses = [
+ #QtWidgets
+ qt.QApplication,
+ qt.QColorDialog,
+ qt.QDialog,
+ qt.QErrorMessage,
+ qt.QFileDialog,
+ qt.QFontDialog,
+ qt.QInputDialog,
+ qt.QMenu,
+ qt.QMessageBox,
+ qt.QProgressDialog,
+ #QtCore
+ qt.QCoreApplication,
+ qt.QEventLoop,
+ qt.QThread,
+ ]
+ for klass in klasses:
+ assert hasattr(klass, "exec") and callable(klass.exec), "%s.exec missing" % klass.__name__
+ assert hasattr(klass, "exec_") and callable(klass.exec_), "%s.exec_ missing" % klass.__name__
diff --git a/src/silx/gui/test/utils.py b/src/silx/gui/test/utils.py
new file mode 100644
index 0000000..db4c0ee
--- /dev/null
+++ b/src/silx/gui/test/utils.py
@@ -0,0 +1,43 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Color conversion function, color dictionary and colormap tools."""
+
+from __future__ import absolute_import
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "05/10/2018"
+
+import silx.utils.deprecation
+
+silx.utils.deprecation.deprecated_warning("Module",
+ name="silx.gui.test.utils",
+ reason="moved",
+ replacement="silx.gui.utils.testutils",
+ since_version="0.9.0",
+ only_once=True,
+ skip_backtrace_count=1)
+
+from ..utils.testutils import * # noqa
diff --git a/src/silx/gui/utils/__init__.py b/src/silx/gui/utils/__init__.py
new file mode 100755
index 0000000..726ad74
--- /dev/null
+++ b/src/silx/gui/utils/__init__.py
@@ -0,0 +1,76 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Miscellaneous helpers for Qt"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "09/03/2018"
+
+
+import contextlib as _contextlib
+
+
+@_contextlib.contextmanager
+def blockSignals(*objs):
+ """Context manager blocking signals of QObjects.
+
+ It restores previous state when leaving.
+
+ :param qt.QObject objs: QObjects for which to block signals
+ """
+ blocked = [(obj, obj.blockSignals(True)) for obj in objs]
+ try:
+ yield
+ finally:
+ for obj, previous in blocked:
+ obj.blockSignals(previous)
+
+
+class LockReentrant():
+ """Context manager to lock a code block and check the state.
+ """
+ def __init__(self):
+ self.__locked = False
+
+ def __enter__(self):
+ self.__locked = True
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.__locked = False
+
+ def locked(self):
+ """Returns True if the code block is locked"""
+ return self.__locked
+
+
+def getQEventName(eventType):
+ """
+ Returns the name of a QEvent.
+
+ :param Union[int,qt.QEvent] eventType: A QEvent or a QEvent type.
+ :returns: str
+ """
+ from . import qtutils
+ return qtutils.getQEventName(eventType)
diff --git a/src/silx/gui/utils/concurrent.py b/src/silx/gui/utils/concurrent.py
new file mode 100644
index 0000000..c27374f
--- /dev/null
+++ b/src/silx/gui/utils/concurrent.py
@@ -0,0 +1,105 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module allows to run a function in Qt main thread from another thread
+"""
+
+from __future__ import absolute_import
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "09/03/2018"
+
+
+from concurrent.futures import Future
+
+from .. import qt
+
+
+class _QtExecutor(qt.QObject):
+ """Executor of tasks in Qt main thread"""
+
+ __sigSubmit = qt.Signal(Future, object, tuple, dict)
+ """Signal used to run tasks."""
+
+ def __init__(self):
+ super(_QtExecutor, self).__init__(parent=None)
+
+ # Makes sure the executor lives in the main thread
+ app = qt.QApplication.instance()
+ assert app is not None
+ mainThread = app.thread()
+ if self.thread() != mainThread:
+ self.moveToThread(mainThread)
+
+ self.__sigSubmit.connect(self.__run)
+
+ def submit(self, fn, *args, **kwargs):
+ """Submit fn(*args, **kwargs) to Qt main thread
+
+ :param callable fn: Function to call in main thread
+ :return: Future object to retrieve result
+ :rtype: concurrent.future.Future
+ """
+ future = Future()
+ self.__sigSubmit.emit(future, fn, args, kwargs)
+ return future
+
+ def __run(self, future, fn, args, kwargs):
+ """Run task in Qt main thread
+
+ :param concurrent.future.Future future:
+ :param callable fn: Function to run
+ :param tuple args: Arguments
+ :param dict kwargs: Keyword arguments
+ """
+ if not future.set_running_or_notify_cancel():
+ return
+
+ try:
+ result = fn(*args, **kwargs)
+ except BaseException as e:
+ future.set_exception(e)
+ else:
+ future.set_result(result)
+
+
+_executor = None
+"""QObject running the tasks in main thread"""
+
+
+def submitToQtMainThread(fn, *args, **kwargs):
+ """Run fn(args, kwargs) in Qt's main thread.
+
+ If not called from the main thread, this is run asynchronously.
+
+ :param callable fn: Function to call in main thread.
+ :return: A future object to retrieve the result
+ :rtype: concurrent.future.Future
+ """
+ global _executor
+ if _executor is None: # Lazy-loading
+ _executor = _QtExecutor()
+
+ return _executor.submit(fn, *args, **kwargs)
diff --git a/src/silx/gui/utils/glutils/__init__.py b/src/silx/gui/utils/glutils/__init__.py
new file mode 100644
index 0000000..20e611e
--- /dev/null
+++ b/src/silx/gui/utils/glutils/__init__.py
@@ -0,0 +1,199 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2020-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides the :func:`isOpenGLAvailable` utility function.
+"""
+
+import os
+import sys
+import subprocess
+from silx.gui import qt
+
+
+class _isOpenGLAvailableResult:
+ """Store result of checking OpenGL availability.
+
+ It provides a `status` boolean attribute storing the result of the check and
+ an `error` string attribute storting the possible error message.
+ """
+
+ def __init__(self, status=True, error=''):
+ self.__status = bool(status)
+ self.__error = str(error)
+
+ status = property(lambda self: self.__status, doc="True if OpenGL is working")
+ error = property(lambda self: self.__error, doc="Error message")
+
+ def __bool__(self):
+ return self.status
+
+ def __repr__(self):
+ return '<_isOpenGLAvailableResult: %s, "%s">' % (self.status, self.error)
+
+
+def _runtimeOpenGLCheck(version):
+ """Run OpenGL check in a subprocess.
+
+ This is done by starting a subprocess that displays a Qt OpenGL widget.
+
+ :param List[int] version:
+ The minimal required OpenGL version as a 2-tuple (major, minor).
+ Default: (2, 1)
+ :return: An error string that is empty if no error occured
+ :rtype: str
+ """
+ major, minor = str(version[0]), str(version[1])
+ env = os.environ.copy()
+ env['PYTHONPATH'] = os.pathsep.join(
+ [os.path.abspath(p) for p in sys.path])
+
+ try:
+ error = subprocess.check_output(
+ [sys.executable, '-s', '-S', __file__, major, minor],
+ env=env,
+ timeout=2)
+ except subprocess.TimeoutExpired:
+ status = False
+ error = "Qt OpenGL widget hang"
+ if sys.platform.startswith('linux'):
+ error += ':\nIf connected remotely, GLX forwarding might be disabled.'
+ except subprocess.CalledProcessError as e:
+ status = False
+ error = "Qt OpenGL widget error: retcode=%d, error=%s" % (e.returncode, e.output)
+ else:
+ status = True
+ error = error.decode()
+ return _isOpenGLAvailableResult(status, error)
+
+
+_runtimeCheckCache = {} # Cache runtime check results: {version: result}
+
+
+def isOpenGLAvailable(version=(2, 1), runtimeCheck=True):
+ """Check if OpenGL is available through Qt and actually working.
+
+ After some basic tests, this is done by starting a subprocess that
+ displays a Qt OpenGL widget.
+
+ :param List[int] version:
+ The minimal required OpenGL version as a 2-tuple (major, minor).
+ Default: (2, 1)
+ :param bool runtimeCheck:
+ True (default) to run the test creating a Qt OpenGL widgt in a subprocess,
+ False to avoid this check.
+ :return: A result object that evaluates to True if successful and
+ which has a `status` boolean attribute (True if successful) and
+ an `error` string attribute that is not empty if `status` is False.
+ """
+ error = ''
+
+ if sys.platform.startswith('linux') and not os.environ.get('DISPLAY', ''):
+ # On Linux and no DISPLAY available (e.g., ssh without -X)
+ error = 'DISPLAY environment variable not set'
+
+ else:
+ # Check pyopengl availability
+ try:
+ import silx.gui._glutils.gl # noqa
+ except ImportError:
+ error = "Cannot import OpenGL wrapper: pyopengl is not installed"
+ else:
+ # Pre checks for Qt < 5.4
+ if not hasattr(qt, 'QOpenGLWidget'):
+ if not qt.HAS_OPENGL:
+ error = '%s.QtOpenGL not available' % qt.BINDING
+
+ elif qt.BINDING in ('PySide2', 'PyQt5') and qt.QApplication.instance() and not qt.QGLFormat.hasOpenGL():
+ # qt.QGLFormat.hasOpenGL MUST be called with a QApplication created
+ # so this is only checked if the QApplication is already created
+ error = 'Qt reports OpenGL not available'
+
+ result = _isOpenGLAvailableResult(error == '', error)
+
+ if result: # No error so far, runtime check
+ if version in _runtimeCheckCache: # Use cache
+ result = _runtimeCheckCache[version]
+ elif runtimeCheck: # Run test in subprocess
+ result = _runtimeOpenGLCheck(version)
+ _runtimeCheckCache[version] = result
+
+ return result
+
+
+if __name__ == "__main__":
+ from silx.gui._glutils import OpenGLWidget
+ from silx.gui._glutils import gl
+ import argparse
+
+ class _TestOpenGLWidget(OpenGLWidget):
+ """Widget checking that OpenGL is indeed available
+
+ :param List[int] version: (major, minor) minimum OpenGL version
+ """
+
+ def __init__(self, version):
+ super(_TestOpenGLWidget, self).__init__(
+ alphaBufferSize=0,
+ depthBufferSize=0,
+ stencilBufferSize=0,
+ version=version)
+
+ def paintEvent(self, event):
+ super(_TestOpenGLWidget, self).paintEvent(event)
+
+ # Check once paint has been done
+ app = qt.QApplication.instance()
+ if not self.isValid():
+ print("OpenGL widget is not valid")
+ app.exit(1)
+ else:
+ qt.QTimer.singleShot(100, app.quit)
+
+ def paintGL(self):
+ gl.glClearColor(1., 0., 0., 0.)
+ gl.glClear(gl.GL_COLOR_BUFFER_BIT)
+
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('major')
+ parser.add_argument('minor')
+
+ args = parser.parse_args(args=sys.argv[1:])
+
+ app = qt.QApplication([])
+ window = qt.QMainWindow(flags=
+ qt.Qt.Popup |
+ qt.Qt.FramelessWindowHint |
+ qt.Qt.NoDropShadowWindowHint |
+ qt.Qt.WindowStaysOnTopHint)
+ window.setAttribute(qt.Qt.WA_ShowWithoutActivating)
+ window.move(0, 0)
+ window.resize(3, 3)
+ widget = _TestOpenGLWidget(version=(args.major, args.minor))
+ window.setCentralWidget(widget)
+ window.setWindowOpacity(0.04)
+ window.show()
+
+ qt.QTimer.singleShot(1000, app.quit)
+ sys.exit(app.exec())
diff --git a/src/silx/gui/utils/image.py b/src/silx/gui/utils/image.py
new file mode 100644
index 0000000..96f50ab
--- /dev/null
+++ b/src/silx/gui/utils/image.py
@@ -0,0 +1,143 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides conversions between numpy.ndarray and QImage
+
+- :func:`convertArrayToQImage`
+- :func:`convertQImageToArray`
+"""
+
+from __future__ import division
+
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "04/09/2018"
+
+
+import sys
+import numpy
+from numpy.lib.stride_tricks import as_strided as _as_strided
+
+from .. import qt
+
+
+def convertArrayToQImage(array):
+ """Convert an array-like image to a QImage.
+
+ The created QImage is using a copy of the array data.
+
+ Limitation: Only RGB or RGBA images with 8 bits per channel are supported.
+
+ :param array: Array-like image data of shape (height, width, channels)
+ Channels are expected to be either RGB or RGBA.
+ :type array: numpy.ndarray of uint8
+ :return: Corresponding Qt image with RGB888 or ARGB32 format.
+ :rtype: QImage
+ """
+ array = numpy.array(array, copy=False, order='C', dtype=numpy.uint8)
+
+ if array.ndim != 3 or array.shape[2] not in (3, 4):
+ raise ValueError(
+ 'Image must be a 3D array with 3 or 4 channels per pixel')
+
+ if array.shape[2] == 4:
+ format_ = qt.QImage.Format_ARGB32
+ # RGBA -> ARGB + take care of endianness
+ if sys.byteorder == 'little': # RGBA -> BGRA
+ array = array[:, :, (2, 1, 0, 3)]
+ else: # big endian: RGBA -> ARGB
+ array = array[:, :, (3, 0, 1, 2)]
+
+ array = numpy.array(array, order='C') # Make a contiguous array
+
+ else: # array.shape[2] == 3
+ format_ = qt.QImage.Format_RGB888
+
+ height, width, depth = array.shape
+ qimage = qt.QImage(
+ array.data,
+ width,
+ height,
+ array.strides[0], # bytesPerLine
+ format_)
+
+ return qimage.copy() # Making a copy of the image and its data
+
+
+def convertQImageToArray(image):
+ """Convert a QImage to a numpy array.
+
+ If QImage format is not Format_RGB888, Format_RGBA8888 or Format_ARGB32,
+ it is first converted to one of this format depending on
+ the presence of an alpha channel.
+
+ The created numpy array is using a copy of the QImage data.
+
+ :param QImage image: The QImage to convert.
+ :return: The image array of RGB or RGBA channels of shape
+ (height, width, channels (3 or 4))
+ :rtype: numpy.ndarray of uint8
+ """
+ rgba8888 = getattr(qt.QImage, 'Format_RGBA8888', None) # Only in Qt5
+
+ # Convert to supported format if needed
+ if image.format() not in (qt.QImage.Format_ARGB32,
+ qt.QImage.Format_RGB888,
+ rgba8888):
+ if image.hasAlphaChannel():
+ image = image.convertToFormat(
+ rgba8888 if rgba8888 is not None else qt.QImage.Format_ARGB32)
+ else:
+ image = image.convertToFormat(qt.QImage.Format_RGB888)
+
+ format_ = image.format()
+ channels = 3 if format_ == qt.QImage.Format_RGB888 else 4
+
+ ptr = image.bits()
+ if qt.BINDING == 'PyQt5':
+ ptr.setsize(image.byteCount())
+ elif qt.BINDING in ('PySide2', 'PySide6'):
+ ptr = ptr.tobytes()
+ else:
+ raise RuntimeError("Unsupported Qt binding: %s" % qt.BINDING)
+
+ # Create an array view on QImage internal data
+ view = _as_strided(
+ numpy.frombuffer(ptr, dtype=numpy.uint8),
+ shape=(image.height(), image.width(), channels),
+ strides=(image.bytesPerLine(), channels, 1))
+
+ if format_ == qt.QImage.Format_ARGB32:
+ # Convert from ARGB to RGBA
+ # Not a byte-ordered format: do care about endianness
+ if sys.byteorder == 'little': # BGRA -> RGBA
+ view = view[:, :, (2, 1, 0, 3)]
+ else: # big endian: ARGB -> RGBA
+ view = view[:, :, (1, 2, 3, 0)]
+
+ # Format_RGB888 and Format_RGBA8888 do not need reshuffling channels:
+ # They are byte-ordered and already in the right order
+
+ return numpy.array(view, copy=True, order='C')
diff --git a/src/silx/gui/utils/matplotlib.py b/src/silx/gui/utils/matplotlib.py
new file mode 100644
index 0000000..90257f8
--- /dev/null
+++ b/src/silx/gui/utils/matplotlib.py
@@ -0,0 +1,65 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+
+from __future__ import absolute_import
+
+"""This module initializes matplotlib and sets-up the backend to use.
+
+It MUST be imported prior to any other import of matplotlib.
+
+It provides the matplotlib :class:`FigureCanvasQTAgg` class corresponding
+to the used backend.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "02/05/2018"
+
+
+from pkg_resources import parse_version
+import matplotlib
+
+from .. import qt
+
+
+def _matplotlib_use(backend, force):
+ """Wrapper of `matplotlib.use` to set-up backend.
+
+ It adds extra initialization for PySide2 with matplotlib < 2.2.
+ """
+ # This is kept for compatibility with matplotlib < 2.2
+ if (parse_version(matplotlib.__version__) < parse_version('2.2') and
+ qt.BINDING == 'PySide2'):
+ matplotlib.rcParams['backend.qt5'] = 'PySide2'
+
+ matplotlib.use(backend, force=force)
+
+
+if qt.BINDING in ('PySide6', 'PyQt5', 'PySide2'):
+ _matplotlib_use('Qt5Agg', force=False)
+ from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg # noqa
+
+else:
+ raise ImportError("Unsupported Qt binding: %s" % qt.BINDING)
diff --git a/src/silx/gui/utils/projecturl.py b/src/silx/gui/utils/projecturl.py
new file mode 100644
index 0000000..0832c2e
--- /dev/null
+++ b/src/silx/gui/utils/projecturl.py
@@ -0,0 +1,77 @@
+# coding: utf-8
+#
+# Project: Azimuthal integration
+# https://github.com/silx-kit/silx
+#
+# Copyright (C) 2015-2019 European Synchrotron Radiation Facility, Grenoble, France
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+
+from __future__ import absolute_import, print_function, division
+
+"""Provide convenient URL for silx-kit projects."""
+
+__author__ = "Valentin Valls"
+__contact__ = "valentin.valls@ESRF.eu"
+__license__ = "MIT"
+__copyright__ = "European Synchrotron Radiation Facility, Grenoble, France"
+__date__ = "15/01/2019"
+
+
+from ... import _version as version
+
+BASE_DOC_URL = None
+"""This could be patched by project packagers."""
+
+_DEFAULT_BASE_DOC_URL = "http://www.silx.org/pub/doc/silx/{silx_doc_version}/{subpath}"
+"""Identify the base URL of the project documentation.
+
+It supportes string replacement:
+
+- `{major}` the major version
+- `{minor}` the minor version
+- `{micro}` the micro version
+- `{relev}` the status of the version (dev, final, rc).
+- `{silx_doc_version}` is used to map the documentation stored at www.silx.org
+- `{subpath}` is the subpart of the URL pointing to a specific page of the
+ documentation. It is mandatory.
+"""
+
+
+def getDocumentationUrl(subpath):
+ """Returns the URL to the documentation"""
+
+ if version.RELEV == "final":
+ # Released verison will point to a specific documentation
+ silx_doc_version = "%d.%d.%d" % (version.MAJOR, version.MINOR, version.MICRO)
+ else:
+ # Dev versions will point to a single 'dev' documentation
+ silx_doc_version = "dev"
+
+ keyworks = {
+ "silx_doc_version": silx_doc_version,
+ "major": version.MAJOR,
+ "minor": version.MINOR,
+ "micro": version.MICRO,
+ "relev": version.RELEV,
+ "subpath": subpath}
+ template = BASE_DOC_URL
+ if template is None:
+ template = _DEFAULT_BASE_DOC_URL
+ return template.format(**keyworks)
diff --git a/src/silx/gui/utils/qtutils.py b/src/silx/gui/utils/qtutils.py
new file mode 100755
index 0000000..9682913
--- /dev/null
+++ b/src/silx/gui/utils/qtutils.py
@@ -0,0 +1,196 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides the :func:`getQEventName` utility function."""
+
+from silx.gui import qt
+
+
+QT_EVENT_NAMES = {
+ 0: "None",
+ 114: "ActionAdded",
+ 113: "ActionChanged",
+ 115: "ActionRemoved",
+ 99: "ActivationChange",
+ 121: "ApplicationActivate",
+ # ApplicationActivate: "ApplicationActivated",
+ 122: "ApplicationDeactivate",
+ 36: "ApplicationFontChange",
+ 37: "ApplicationLayoutDirectionChange",
+ 38: "ApplicationPaletteChange",
+ 214: "ApplicationStateChange",
+ 35: "ApplicationWindowIconChange",
+ 68: "ChildAdded",
+ 69: "ChildPolished",
+ 71: "ChildRemoved",
+ 40: "Clipboard",
+ 19: "Close",
+ 200: "CloseSoftwareInputPanel",
+ 178: "ContentsRectChange",
+ 82: "ContextMenu",
+ 183: "CursorChange",
+ 52: "DeferredDelete",
+ 60: "DragEnter",
+ 62: "DragLeave",
+ 61: "DragMove",
+ 63: "Drop",
+ 170: "DynamicPropertyChange",
+ 98: "EnabledChange",
+ 10: "Enter",
+ 150: "EnterEditFocus",
+ 124: "EnterWhatsThisMode",
+ 206: "Expose",
+ 116: "FileOpen",
+ 8: "FocusIn",
+ 9: "FocusOut",
+ 23: "FocusAboutToChange",
+ 97: "FontChange",
+ 198: "Gesture",
+ 202: "GestureOverride",
+ 188: "GrabKeyboard",
+ 186: "GrabMouse",
+ 159: "GraphicsSceneContextMenu",
+ 164: "GraphicsSceneDragEnter",
+ 166: "GraphicsSceneDragLeave",
+ 165: "GraphicsSceneDragMove",
+ 167: "GraphicsSceneDrop",
+ 163: "GraphicsSceneHelp",
+ 160: "GraphicsSceneHoverEnter",
+ 162: "GraphicsSceneHoverLeave",
+ 161: "GraphicsSceneHoverMove",
+ 158: "GraphicsSceneMouseDoubleClick",
+ 155: "GraphicsSceneMouseMove",
+ 156: "GraphicsSceneMousePress",
+ 157: "GraphicsSceneMouseRelease",
+ 182: "GraphicsSceneMove",
+ 181: "GraphicsSceneResize",
+ 168: "GraphicsSceneWheel",
+ 18: "Hide",
+ 27: "HideToParent",
+ 127: "HoverEnter",
+ 128: "HoverLeave",
+ 129: "HoverMove",
+ 96: "IconDrag",
+ 101: "IconTextChange",
+ 83: "InputMethod",
+ 207: "InputMethodQuery",
+ 169: "KeyboardLayoutChange",
+ 6: "KeyPress",
+ 7: "KeyRelease",
+ 89: "LanguageChange",
+ 90: "LayoutDirectionChange",
+ 76: "LayoutRequest",
+ 11: "Leave",
+ 151: "LeaveEditFocus",
+ 125: "LeaveWhatsThisMode",
+ 88: "LocaleChange",
+ 176: "NonClientAreaMouseButtonDblClick",
+ 174: "NonClientAreaMouseButtonPress",
+ 175: "NonClientAreaMouseButtonRelease",
+ 173: "NonClientAreaMouseMove",
+ 177: "MacSizeChange",
+ 43: "MetaCall",
+ 102: "ModifiedChange",
+ 4: "MouseButtonDblClick",
+ 2: "MouseButtonPress",
+ 3: "MouseButtonRelease",
+ 5: "MouseMove",
+ 109: "MouseTrackingChange",
+ 13: "Move",
+ 197: "NativeGesture",
+ 208: "OrientationChange",
+ 12: "Paint",
+ 39: "PaletteChange",
+ 131: "ParentAboutToChange",
+ 21: "ParentChange",
+ 212: "PlatformPanel",
+ 217: "PlatformSurface",
+ 75: "Polish",
+ 74: "PolishRequest",
+ 123: "QueryWhatsThis",
+ 106: "ReadOnlyChange",
+ 199: "RequestSoftwareInputPanel",
+ 14: "Resize",
+ 204: "ScrollPrepare",
+ 205: "Scroll",
+ 117: "Shortcut",
+ 51: "ShortcutOverride",
+ 17: "Show",
+ 26: "ShowToParent",
+ 50: "SockAct",
+ 192: "StateMachineSignal",
+ 193: "StateMachineWrapped",
+ 112: "StatusTip",
+ 100: "StyleChange",
+ 87: "TabletMove",
+ 92: "TabletPress",
+ 93: "TabletRelease",
+ 171: "TabletEnterProximity",
+ 172: "TabletLeaveProximity",
+ 219: "TabletTrackingChange",
+ 22: "ThreadChange",
+ 1: "Timer",
+ 120: "ToolBarChange",
+ 110: "ToolTip",
+ 184: "ToolTipChange",
+ 194: "TouchBegin",
+ 209: "TouchCancel",
+ 196: "TouchEnd",
+ 195: "TouchUpdate",
+ 189: "UngrabKeyboard",
+ 187: "UngrabMouse",
+ 78: "UpdateLater",
+ 77: "UpdateRequest",
+ 111: "WhatsThis",
+ 118: "WhatsThisClicked",
+ 31: "Wheel",
+ 132: "WinEventAct",
+ 24: "WindowActivate",
+ 103: "WindowBlocked",
+ 25: "WindowDeactivate",
+ 34: "WindowIconChange",
+ 105: "WindowStateChange",
+ 33: "WindowTitleChange",
+ 104: "WindowUnblocked",
+ 203: "WinIdChange",
+ 126: "ZOrderChange",
+ 65535: "MaxUser",
+}
+
+
+def getQEventName(eventType):
+ """
+ Returns the name of a QEvent.
+
+ :param Union[int,qt.QEvent] eventType: A QEvent or a QEvent type.
+ :returns: str
+ """
+ if isinstance(eventType, qt.QEvent):
+ eventType = eventType.type()
+ if 1000 <= eventType <= 65535:
+ return "User_%d" % eventType
+ name = QT_EVENT_NAMES.get(eventType, None)
+ if name is not None:
+ return name
+ return "Unknown_%d" % eventType
diff --git a/src/silx/gui/utils/signal.py b/src/silx/gui/utils/signal.py
new file mode 100644
index 0000000..359f5cc
--- /dev/null
+++ b/src/silx/gui/utils/signal.py
@@ -0,0 +1,141 @@
+#!/usr/bin/env python
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2012 University of North Carolina at Chapel Hill, Luke Campagnola
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module contains utils relative to qt Signal
+"""
+
+from silx.gui import qt
+import weakref
+from time import time
+from silx.gui.utils import concurrent
+
+__all__ = ['SignalProxy']
+__authors__ = ['L. Campagnola', 'M. Liberty']
+__license__ = "MIT"
+
+
+class SignalProxy(qt.QObject):
+ """
+ This peace of code come from pyqtgraph
+ Object which collects rapid-fire signals and condenses them
+ into a single signal or a rate-limited stream of signals.
+ Used, for example, to prevent a SpinBox from generating multiple
+ signals when the mouse wheel is rolled over it.
+
+ Emits sigDelayed after input signals have stopped for a certain period of time.
+ """
+
+ sigDelayed = qt.Signal(object)
+
+ def __init__(self, signal, delay=0.3, rateLimit=0, slot=None):
+ """Initialization arguments:
+ signal - a bound Signal or pyqtSignal instance
+ delay - Time (in seconds) to wait for signals to stop before emitting (default 0.3s)
+ slot - Optional function to connect sigDelayed to.
+ rateLimit - (signals/sec) if greater than 0, this allows signals to stream out at a
+ steady rate while they are being received.
+ """
+
+ qt.QObject.__init__(self)
+ signal.connect(self.signalReceived)
+ self.signal = signal
+ self.delay = delay
+ self.rateLimit = rateLimit
+ self.args = None
+ self.timer = qt.QTimer()
+ self.timer.timeout.connect(self.flush)
+ self.blockSignal = False
+ self.slot = weakref.ref(slot)
+ self.lastFlushTime = None
+ if slot is not None:
+ self.sigDelayed.connect(slot)
+
+ def setDelay(self, delay):
+ self.delay = delay
+
+ def signalReceived(self, *args):
+ """Received signal. Cancel previous timer and store args to be forwarded later."""
+ if self.blockSignal:
+ return
+ self.args = args
+ if self.rateLimit == 0:
+ concurrent.submitToQtMainThread(self.timer.stop)
+ concurrent.submitToQtMainThread(self.timer.start, (self.delay * 1000) + 1)
+ else:
+ now = time()
+ if self.lastFlushTime is None:
+ leakTime = 0
+ else:
+ lastFlush = self.lastFlushTime
+ leakTime = max(0, (lastFlush + (1.0 / self.rateLimit)) - now)
+
+ concurrent.submitToQtMainThread(self.timer.stop)
+ concurrent.submitToQtMainThread(self.timer.start, (min(leakTime, self.delay) * 1000) + 1)
+ # self.timer.stop()
+ # self.timer.start((min(leakTime, self.delay) * 1000) + 1)
+
+ def flush(self):
+ """If there is a signal queued up, send it now."""
+ if self.args is None or self.blockSignal:
+ return False
+ args, self.args = self.args, None
+ concurrent.submitToQtMainThread(self.timer.stop)
+ self.lastFlushTime = time()
+ # self.emit(self.signal, *self.args)
+ concurrent.submitToQtMainThread(self.sigDelayed.emit, args)
+ # self.sigDelayed.emit(args)
+ return True
+
+ def disconnect(self):
+ self.blockSignal = True
+ try:
+ self.signal.disconnect(self.signalReceived)
+ except:
+ pass
+ try:
+ self.sigDelayed.disconnect(self.slot)
+ except:
+ pass
+
+
+if __name__ == '__main__':
+ app = qt.QApplication([])
+ win = qt.QMainWindow()
+ spin = qt.QSpinBox()
+ win.setCentralWidget(spin)
+ win.show()
+
+
+ def fn(*args):
+ print("Raw signal:", args)
+
+
+ def fn2(*args):
+ print("Delayed signal:", args)
+
+
+ spin.valueChanged.connect(fn)
+ # proxy = proxyConnect(spin, QtCore.SIGNAL('valueChanged(int)'), fn)
+ proxy = SignalProxy(spin.valueChanged, delay=0.5, slot=fn2)
diff --git a/src/silx/gui/utils/test/__init__.py b/src/silx/gui/utils/test/__init__.py
new file mode 100755
index 0000000..15cd186
--- /dev/null
+++ b/src/silx/gui/utils/test/__init__.py
@@ -0,0 +1,25 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""silx.gui.utils tests"""
diff --git a/src/silx/gui/utils/test/test.py b/src/silx/gui/utils/test/test.py
new file mode 100644
index 0000000..0208d64
--- /dev/null
+++ b/src/silx/gui/utils/test/test.py
@@ -0,0 +1,63 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2019-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test of functions available in silx.gui.utils module."""
+
+from __future__ import absolute_import
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "01/08/2019"
+
+
+import unittest
+from silx.gui import qt
+from silx.gui.utils.testutils import TestCaseQt, SignalListener
+
+from silx.gui.utils import blockSignals
+
+
+class TestBlockSignals(TestCaseQt):
+ """Test blockSignals context manager"""
+
+ def _test(self, *objs):
+ """Test for provided objects"""
+ listener = SignalListener()
+ for obj in objs:
+ obj.objectNameChanged.connect(listener)
+ obj.setObjectName("received")
+
+ with blockSignals(*objs):
+ for obj in objs:
+ obj.setObjectName("silent")
+
+ self.assertEqual(listener.arguments(), [("received",)] * len(objs))
+
+ def testManyObjects(self):
+ """Test blockSignals with 2 QObjects"""
+ self._test(qt.QObject(), qt.QObject())
+
+ def testOneObject(self):
+ """Test blockSignals context manager with a single QObject"""
+ self._test(qt.QObject())
diff --git a/src/silx/gui/utils/test/test_async.py b/src/silx/gui/utils/test/test_async.py
new file mode 100644
index 0000000..7304ca9
--- /dev/null
+++ b/src/silx/gui/utils/test/test_async.py
@@ -0,0 +1,127 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test of async module."""
+
+from __future__ import absolute_import
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "09/03/2018"
+
+
+import threading
+import unittest
+
+
+from concurrent.futures import wait
+from silx.gui import qt
+from silx.gui.utils.testutils import TestCaseQt
+
+from silx.gui.utils import concurrent
+
+
+class TestSubmitToQtThread(TestCaseQt):
+ """Test submission of tasks to Qt main thread"""
+
+ def setUp(self):
+ # Reset executor to test lazy-loading in different conditions
+ concurrent._executor = None
+ super(TestSubmitToQtThread, self).setUp()
+
+ def _task(self, value1, value2):
+ return value1, value2
+
+ def _taskWithException(self, *args, **kwargs):
+ raise RuntimeError('task exception')
+
+ def testFromMainThread(self):
+ """Call submitToQtMainThread from the main thread"""
+ value1, value2 = 0, 1
+ future = concurrent.submitToQtMainThread(self._task, value1, value2=value2)
+ self.assertTrue(future.done())
+ self.assertEqual(future.result(1), (value1, value2))
+ self.assertIsNone(future.exception(1))
+
+ future = concurrent.submitToQtMainThread(self._taskWithException)
+ self.assertTrue(future.done())
+ with self.assertRaises(RuntimeError):
+ future.result(1)
+ self.assertIsInstance(future.exception(1), RuntimeError)
+
+ def _threadedTest(self):
+ """Function run in a thread for the tests"""
+ value1, value2 = 0, 1
+ future = concurrent.submitToQtMainThread(self._task, value1, value2=value2)
+
+ wait([future], 3)
+
+ self.assertTrue(future.done())
+ self.assertEqual(future.result(1), (value1, value2))
+ self.assertIsNone(future.exception(1))
+
+ future = concurrent.submitToQtMainThread(self._taskWithException)
+
+ wait([future], 3)
+
+ self.assertTrue(future.done())
+ with self.assertRaises(RuntimeError):
+ future.result(1)
+ self.assertIsInstance(future.exception(1), RuntimeError)
+
+ def testFromPythonThread(self):
+ """Call submitToQtMainThread from a Python thread"""
+ thread = threading.Thread(target=self._threadedTest)
+ thread.start()
+ for i in range(100): # Loop over for 10 seconds
+ self.qapp.processEvents()
+ thread.join(0.1)
+ if not thread.is_alive():
+ break
+ else:
+ self.fail(('Thread task still running'))
+
+ def testFromQtThread(self):
+ """Call submitToQtMainThread from a Qt thread pool"""
+ class Runner(qt.QRunnable):
+ def __init__(self, fn):
+ super(Runner, self).__init__()
+ self._fn = fn
+
+ def run(self):
+ self._fn()
+
+ def autoDelete(self):
+ return True
+
+ threadPool = qt.silxGlobalThreadPool()
+ runner = Runner(self._threadedTest)
+ threadPool.start(runner)
+ for i in range(100): # Loop over for 10 seconds
+ self.qapp.processEvents()
+ done = threadPool.waitForDone(100)
+ if done:
+ break
+ else:
+ self.fail('Thread pool task still running')
diff --git a/src/silx/gui/utils/test/test_glutils.py b/src/silx/gui/utils/test/test_glutils.py
new file mode 100644
index 0000000..7c9831b
--- /dev/null
+++ b/src/silx/gui/utils/test/test_glutils.py
@@ -0,0 +1,55 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Tests for the silx.gui.utils.glutils module."""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "15/01/2020"
+
+
+import logging
+import unittest
+from silx.gui.utils.glutils import isOpenGLAvailable
+
+
+_logger = logging.getLogger(__name__)
+
+
+class TestIsOpenGLAvailable(unittest.TestCase):
+ """Test isOpenGLAvailable"""
+
+ def test(self):
+ for version in ((2, 1), (2, 1), (1000, 1)):
+ with self.subTest(version=version):
+ result = isOpenGLAvailable(version=version)
+ _logger.info("isOpenGLAvailable returned: %s", str(result))
+ if version[0] == 1000:
+ self.assertFalse(result)
+ if not result:
+ self.assertFalse(result.status)
+ self.assertTrue(len(result.error) > 0)
+ else:
+ self.assertTrue(result.status)
+ self.assertTrue(len(result.error) == 0)
diff --git a/src/silx/gui/utils/test/test_image.py b/src/silx/gui/utils/test/test_image.py
new file mode 100644
index 0000000..62316b0
--- /dev/null
+++ b/src/silx/gui/utils/test/test_image.py
@@ -0,0 +1,79 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test of utils module."""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "16/01/2017"
+
+import numpy
+import unittest
+
+from silx.gui import qt
+from silx.utils.testutils import ParametricTestCase
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui.utils.image import convertArrayToQImage, convertQImageToArray
+
+
+class TestQImageConversion(TestCaseQt, ParametricTestCase):
+ """Tests conversion of QImage to/from numpy array."""
+
+ def testConvertArrayToQImage(self):
+ """Test conversion of numpy array to QImage"""
+ for format_, channels in [('Format_RGB888', 3),
+ ('Format_ARGB32', 4)]:
+ with self.subTest(format_):
+ image = numpy.arange(
+ 3*3*channels, dtype=numpy.uint8).reshape(3, 3, channels)
+ qimage = convertArrayToQImage(image)
+
+ self.assertEqual(qimage.height(), image.shape[0])
+ self.assertEqual(qimage.width(), image.shape[1])
+ self.assertEqual(qimage.format(), getattr(qt.QImage, format_))
+
+ for row in range(3):
+ for col in range(3):
+ # Qrgb has no alpha channel, not compared
+ # Qt uses x,y while array is row,col...
+ self.assertEqual(qt.QColor(qimage.pixel(col, row)),
+ qt.QColor(*image[row, col, :3]))
+
+
+ def testConvertQImageToArray(self):
+ """Test conversion of QImage to numpy array"""
+ for format_, channels in [
+ ('Format_RGB888', 3), # Native support
+ ('Format_ARGB32', 4), # Native support
+ ('Format_RGB32', 3)]: # Conversion to RGB
+ with self.subTest(format_):
+ color = numpy.arange(channels) # RGB(A) values
+ qimage = qt.QImage(3, 3, getattr(qt.QImage, format_))
+ qimage.fill(qt.QColor(*color))
+ image = convertQImageToArray(qimage)
+
+ self.assertEqual(qimage.height(), image.shape[0])
+ self.assertEqual(qimage.width(), image.shape[1])
+ self.assertEqual(image.shape[2], len(color))
+ self.assertTrue(numpy.all(numpy.equal(image, color)))
diff --git a/src/silx/gui/utils/test/test_qtutils.py b/src/silx/gui/utils/test/test_qtutils.py
new file mode 100755
index 0000000..c00280b
--- /dev/null
+++ b/src/silx/gui/utils/test/test_qtutils.py
@@ -0,0 +1,65 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test of functions available in silx.gui.utils module."""
+
+from __future__ import absolute_import
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "01/08/2019"
+
+
+import unittest
+from silx.gui import qt
+from silx.gui import utils
+from silx.gui.utils.testutils import TestCaseQt
+
+
+class TestQEventName(TestCaseQt):
+ """Test QEvent names"""
+
+ def testNoneType(self):
+ result = utils.getQEventName(0)
+ self.assertEqual(result, "None")
+
+ def testNoneEvent(self):
+ event = qt.QEvent(qt.QEvent.Type(0))
+ result = utils.getQEventName(event)
+ self.assertEqual(result, "None")
+
+ def testUserType(self):
+ result = utils.getQEventName(1050)
+ self.assertIn("User", result)
+ self.assertIn("1050", result)
+
+ def testQtUndefinedType(self):
+ result = utils.getQEventName(900)
+ self.assertIn("Unknown", result)
+ self.assertIn("900", result)
+
+ def testUndefinedType(self):
+ result = utils.getQEventName(70000)
+ self.assertIn("Unknown", result)
+ self.assertIn("70000", result)
diff --git a/src/silx/gui/utils/test/test_testutils.py b/src/silx/gui/utils/test/test_testutils.py
new file mode 100644
index 0000000..07294a7
--- /dev/null
+++ b/src/silx/gui/utils/test/test_testutils.py
@@ -0,0 +1,44 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test of testutils module."""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "16/01/2017"
+
+import unittest
+import sys
+
+from silx.gui import qt
+from ..testutils import TestCaseQt
+
+
+class TestOutcome(unittest.TestCase):
+ """Tests conversion of QImage to/from numpy array."""
+
+ @unittest.skipIf(sys.version_info.major <= 2, 'Python3 only')
+ def testNoneOutcome(self):
+ test = TestCaseQt()
+ test._currentTestSucceeded()
diff --git a/src/silx/gui/utils/testutils.py b/src/silx/gui/utils/testutils.py
new file mode 100644
index 0000000..40c8237
--- /dev/null
+++ b/src/silx/gui/utils/testutils.py
@@ -0,0 +1,508 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Helper class to write Qt widget unittests."""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "05/10/2018"
+
+
+import gc
+import logging
+import unittest
+import time
+import functools
+import sys
+import os
+
+_logger = logging.getLogger(__name__)
+
+from silx.gui import qt
+from silx.gui.qt import inspect as _inspect
+
+
+if qt.BINDING == 'PySide2':
+ from PySide2.QtTest import QTest
+elif qt.BINDING == 'PyQt5':
+ from PyQt5.QtTest import QTest
+elif qt.BINDING == 'PySide6':
+ from PySide6.QtTest import QTest
+else:
+ raise ImportError('Unsupported Qt bindings')
+
+
+def qWaitForWindowExposedAndActivate(window, timeout=None):
+ """Waits until the window is shown in the screen.
+
+ It also activates the window and raises it.
+
+ See QTest.qWaitForWindowExposed for details.
+ """
+ if timeout is None:
+ result = QTest.qWaitForWindowExposed(window)
+ else:
+ result = QTest.qWaitForWindowExposed(window, timeout)
+
+ if result:
+ # Makes sure window is active and on top
+ window.activateWindow()
+ window.raise_()
+
+ return result
+
+
+class TestCaseQt(unittest.TestCase):
+ """Base class to write test for Qt stuff.
+
+ It creates a QApplication before running the tests.
+ WARNING: The QApplication is shared by all tests, which might have side
+ effects.
+
+ After each test, this class is checking for widgets remaining alive.
+ To allow some widgets to remain alive at the end of a test, set the
+ allowedLeakingWidgets attribute to the number of widgets that can remain
+ alive at the end of the test.
+ With PySide2, this test is not run for now as it seems PySide2
+ is leaking widgets internally.
+
+ All keyboard and mouse event simulation methods call qWait(20) after
+ simulating the event (as QTest does on Mac OSX).
+ This was introduced to fix issues with continuous integration tests
+ running with Xvfb on Linux.
+ """
+
+ DEFAULT_TIMEOUT_WAIT = 100
+ """Default timeout for qWait"""
+
+ TIMEOUT_WAIT = 0
+ """Extra timeout in millisecond to add to qSleep, qWait and
+ qWaitForWindowExposed.
+
+ Intended purpose is for debugging, to add extra time to waits in order to
+ allow to view the tested widgets.
+ """
+
+ _qapp = None
+ """Placeholder for QApplication"""
+
+ @classmethod
+ def exceptionHandler(cls, exceptionClass, exception, stack):
+ import traceback
+ message = (''.join(traceback.format_tb(stack)))
+ template = 'Traceback (most recent call last):\n{2}{0}: {1}'
+ message = template.format(exceptionClass.__name__, exception, message)
+ cls._exceptions.append(message)
+
+ @classmethod
+ def setUpClass(cls):
+ """Makes sure Qt is inited"""
+ cls._oldExceptionHook = sys.excepthook
+ sys.excepthook = cls.exceptionHandler
+
+ # Makes sure a QApplication exists and do it once for all
+ if not qt.QApplication.instance():
+ cls._qapp = qt.QApplication([])
+
+ @classmethod
+ def tearDownClass(cls):
+ sys.excepthook = cls._oldExceptionHook
+
+ def setUp(self):
+ """Get the list of existing widgets."""
+ self.allowedLeakingWidgets = 0
+ if qt.BINDING in ('PySide2', 'PySide6'):
+ self.__previousWidgets = None
+ else:
+ self.__previousWidgets = self.qapp.allWidgets()
+ self.__class__._exceptions = []
+
+ def _currentTestSucceeded(self):
+ if hasattr(self, '_outcome'):
+ # For Python >= 3.4
+ result = self.defaultTestResult() # these 2 methods have no side effects
+ if hasattr(self._outcome, 'errors'):
+ self._feedErrorsToResult(result, self._outcome.errors)
+ else:
+ # For Python < 3.4
+ result = getattr(self, '_outcomeForDoCleanups', self._resultForDoCleanups)
+
+ skipped = self.id() in [case.id() for case, _ in result.skipped]
+ error = self.id() in [case.id() for case, _ in result.errors]
+ failure = self.id() in [case.id() for case, _ in result.failures]
+ return not error and not failure and not skipped
+
+ def _checkForUnreleasedWidgets(self):
+ """Test fixture checking that no more widgets exists."""
+ gc.collect()
+
+ if self.__previousWidgets is None:
+ return # Do not test for leaking widgets with PySide2
+
+ widgets = [widget for widget in self.qapp.allWidgets()
+ if (widget not in self.__previousWidgets and
+ _inspect.createdByPython(widget))]
+ self.__previousWidgets = None
+
+ allowedLeakingWidgets = self.allowedLeakingWidgets
+ self.allowedLeakingWidgets = 0
+
+ if widgets and len(widgets) <= allowedLeakingWidgets:
+ _logger.info(
+ '%s: %d remaining widgets after test' % (self.id(),
+ len(widgets)))
+
+ if len(widgets) > allowedLeakingWidgets:
+ raise RuntimeError(
+ "Test ended with widgets alive: %s" % str(widgets))
+
+ def tearDown(self):
+ self.qapp.processEvents()
+
+ if len(self.__class__._exceptions) > 0:
+ messages = "\n".join(self.__class__._exceptions)
+ raise AssertionError("Exception occured in Qt thread:\n" + messages)
+
+ if self._currentTestSucceeded():
+ self._checkForUnreleasedWidgets()
+
+ @property
+ def qapp(self):
+ """The QApplication currently running."""
+ return qt.QApplication.instance()
+
+ # Proxy to QTest
+
+ Press = QTest.Press
+ """Key press action code"""
+
+ Release = QTest.Release
+ """Key release action code"""
+
+ Click = QTest.Click
+ """Key click action code"""
+
+ QTest = property(lambda self: QTest,
+ doc="""The Qt QTest class from the used Qt binding.""")
+
+ def keyClick(self, widget, key, modifier=qt.Qt.NoModifier, delay=-1):
+ """Simulate clicking a key.
+
+ See QTest.keyClick for details.
+ """
+ QTest.keyClick(widget, key, modifier, delay)
+ self.qWait(20)
+
+ def keyClicks(self, widget, sequence, modifier=qt.Qt.NoModifier, delay=-1):
+ """Simulate clicking a sequence of keys.
+
+ See QTest.keyClick for details.
+ """
+ QTest.keyClicks(widget, sequence, modifier, delay)
+ self.qWait(20)
+
+ def keyEvent(self, action, widget, key,
+ modifier=qt.Qt.NoModifier, delay=-1):
+ """Sends a Qt key event.
+
+ See QTest.keyEvent for details.
+ """
+ QTest.keyEvent(action, widget, key, modifier, delay)
+ self.qWait(20)
+
+ def keyPress(self, widget, key, modifier=qt.Qt.NoModifier, delay=-1):
+ """Sends a Qt key press event.
+
+ See QTest.keyPress for details.
+ """
+ QTest.keyPress(widget, key, modifier, delay)
+ self.qWait(20)
+
+ def keyRelease(self, widget, key, modifier=qt.Qt.NoModifier, delay=-1):
+ """Sends a Qt key release event.
+
+ See QTest.keyRelease for details.
+ """
+ QTest.keyRelease(widget, key, modifier, delay)
+ self.qWait(20)
+
+ def mouseClick(self, widget, button, modifier=None, pos=None, delay=-1):
+ """Simulate clicking a mouse button.
+
+ See QTest.mouseClick for details.
+ """
+ if modifier is None:
+ modifier = qt.Qt.KeyboardModifiers()
+ pos = qt.QPoint(int(pos[0]), int(pos[1])) if pos is not None else qt.QPoint()
+ QTest.mouseClick(widget, button, modifier, pos, delay)
+ self.qWait(20)
+
+ def mouseDClick(self, widget, button, modifier=None, pos=None, delay=-1):
+ """Simulate double clicking a mouse button.
+
+ See QTest.mouseDClick for details.
+ """
+ if modifier is None:
+ modifier = qt.Qt.KeyboardModifiers()
+ pos = qt.QPoint(int(pos[0]), int(pos[1])) if pos is not None else qt.QPoint()
+ QTest.mouseDClick(widget, button, modifier, pos, delay)
+ self.qWait(20)
+
+ def mouseMove(self, widget, pos=None, delay=-1):
+ """Simulate moving the mouse.
+
+ See QTest.mouseMove for details.
+ """
+ pos = qt.QPoint(int(pos[0]), int(pos[1])) if pos is not None else qt.QPoint()
+ QTest.mouseMove(widget, pos, delay)
+ self.qWait(20)
+
+ def mousePress(self, widget, button, modifier=None, pos=None, delay=-1):
+ """Simulate pressing a mouse button.
+
+ See QTest.mousePress for details.
+ """
+ if modifier is None:
+ modifier = qt.Qt.KeyboardModifiers()
+ pos = qt.QPoint(int(pos[0]), int(pos[1])) if pos is not None else qt.QPoint()
+ QTest.mousePress(widget, button, modifier, pos, delay)
+ self.qWait(20)
+
+ def mouseRelease(self, widget, button, modifier=None, pos=None, delay=-1):
+ """Simulate releasing a mouse button.
+
+ See QTest.mouseRelease for details.
+ """
+ if modifier is None:
+ modifier = qt.Qt.KeyboardModifiers()
+ pos = qt.QPoint(int(pos[0]), int(pos[1])) if pos is not None else qt.QPoint()
+ QTest.mouseRelease(widget, button, modifier, pos, delay)
+ self.qWait(20)
+
+ def qSleep(self, ms):
+ """Sleep for ms milliseconds, blocking the execution of the test.
+
+ See QTest.qSleep for details.
+ """
+ QTest.qSleep(int(ms) + self.TIMEOUT_WAIT)
+
+ @classmethod
+ def qWait(cls, ms=None):
+ """Waits for ms milliseconds, events will be processed.
+
+ See QTest.qWait for details.
+ """
+ if ms is None:
+ ms = cls.DEFAULT_TIMEOUT_WAIT
+
+ if qt.BINDING in ('PySide2', 'PySide6'):
+ # PySide2 has no qWait, provide a replacement
+ timeout = int(ms)
+ endTimeMS = int(time.time() * 1000) + timeout
+ qapp = qt.QApplication.instance()
+ while timeout > 0:
+ qapp.processEvents(qt.QEventLoop.AllEvents,
+ timeout)
+ timeout = endTimeMS - int(time.time() * 1000)
+ else:
+ QTest.qWait(int(ms) + cls.TIMEOUT_WAIT)
+
+ def qWaitForWindowExposed(self, window, timeout=None):
+ """Waits until the window is shown in the screen.
+
+ See QTest.qWaitForWindowExposed for details.
+ """
+ result = qWaitForWindowExposedAndActivate(window, timeout)
+
+ if self.TIMEOUT_WAIT:
+ QTest.qWait(self.TIMEOUT_WAIT)
+
+ return result
+
+ def exposeAndClose(self, widget):
+ """Wait for expose a widget, flag it delete on close, and close it."""
+ self.qWaitForWindowExposed(widget)
+ self.qapp.processEvents()
+ widget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ widget.close()
+
+ _qobject_destroyed = False
+
+ @classmethod
+ def _aboutToDestroy(cls):
+ cls._qobject_destroyed = True
+
+ @classmethod
+ def qWaitForDestroy(cls, ref):
+ """
+ Wait for Qt object destruction.
+
+ Use a weakref as parameter to avoid any strong references to the
+ object.
+
+ It have to be used as following. Removing the reference to the object
+ before calling the function looks to be expected, else
+ :meth:`deleteLater` will not work.
+
+ .. code-block:: python
+
+ ref = weakref.ref(self.obj)
+ self.obj = None
+ self.qWaitForDestroy(ref)
+
+ :param weakref ref: A weakref to an object to avoid any reference
+ :return: True if the object was destroyed
+ :rtype: bool
+ """
+ cls._qobject_destroyed = False
+ qobject = ref()
+ if qobject is None:
+ return True
+ qobject.destroyed.connect(cls._aboutToDestroy)
+ qobject.deleteLater()
+ qobject = None
+ for _ in range(10):
+ if cls._qobject_destroyed:
+ break
+ cls.qWait(10)
+ else:
+ _logger.debug("Object was not destroyed")
+
+ return ref() is None
+
+ def logScreenShot(self, level=logging.ERROR):
+ """Take a screenshot and log it into the logging system if the
+ logger is enabled for the expected level.
+
+ The screenshot is stored in the directory "./build/test-debug", and
+ the logging system only log the path to this file.
+
+ :param level: Logging level
+ """
+ if not _logger.isEnabledFor(level):
+ return
+ basedir = os.path.abspath(os.path.join("build", "test-debug"))
+ if not os.path.exists(basedir):
+ os.makedirs(basedir)
+ filename = "Screenshot_%s.png" % self.id()
+ filename = os.path.join(basedir, filename)
+
+ screen = self.qapp.primaryScreen()
+ pixmap = screen.grabWindow(0)
+ pixmap.save(filename)
+ _logger.log(level, "Screenshot saved at %s", filename)
+
+
+class SignalListener(object):
+ """Util to listen a Qt event and store parameters
+ """
+
+ def __init__(self):
+ self.__calls = []
+
+ def __call__(self, *args, **kargs):
+ self.__calls.append((args, kargs))
+
+ def clear(self):
+ """Clear stored data"""
+ self.__calls = []
+
+ def callCount(self):
+ """
+ Returns how many times the listener was called.
+
+ :rtype: int
+ """
+ return len(self.__calls)
+
+ def arguments(self, callIndex=None, argumentIndex=None):
+ """Returns positional arguments optionally filtered by call count id
+ or argument index.
+
+ :param int callIndex: Index of the called data
+ :param int argumentIndex: Index of the positional argument.
+ """
+ if callIndex is not None:
+ result = self.__calls[callIndex][0]
+ if argumentIndex is not None:
+ result = result[argumentIndex]
+ else:
+ result = [x[0] for x in self.__calls]
+ if argumentIndex is not None:
+ result = [x[argumentIndex] for x in result]
+ return result
+
+ def karguments(self, callIndex=None, argumentName=None):
+ """Returns positional arguments optionally filtered by call count id
+ or name of the keyword argument.
+
+ :param int callIndex: Index of the called data
+ :param int argumentName: Name of the keyword argument.
+ """
+ if callIndex is not None:
+ result = self.__calls[callIndex][1]
+ if argumentName is not None:
+ result = result[argumentName]
+ else:
+ result = [x[1] for x in self.__calls]
+ if argumentName is not None:
+ result = [x[argumentName] for x in result]
+ return result
+
+ def partial(self, *args, **kargs):
+ """Returns a new partial object which when called will behave like this
+ listener called with the positional arguments args and keyword
+ arguments keywords. If more arguments are supplied to the call, they
+ are appended to args. If additional keyword arguments are supplied,
+ they extend and override keywords.
+ """
+ return functools.partial(self, *args, **kargs)
+
+
+def getQToolButtonFromAction(action):
+ """Return a QToolButton corresponding to a QAction.
+
+ :param QAction action: The QAction from which to get QToolButton.
+ :return: A QToolButton associated to action or None.
+ """
+ if qt.BINDING == "PySide6":
+ widgets = action.associatedObjects()
+ else:
+ widgets = action.associatedWidgets()
+
+ for widget in widgets:
+ if isinstance(widget, qt.QToolButton):
+ return widget
+ return None
+
+
+def findChildren(parent, kind, name=None):
+ if qt.BINDING in ("PySide2", "PySide6") and name is not None:
+ result = []
+ for obj in parent.findChildren(kind):
+ if obj.objectName() == name:
+ result.append(obj)
+ return result
+ else:
+ return parent.findChildren(kind, name=name)
diff --git a/src/silx/gui/widgets/BoxLayoutDockWidget.py b/src/silx/gui/widgets/BoxLayoutDockWidget.py
new file mode 100644
index 0000000..3d2b853
--- /dev/null
+++ b/src/silx/gui/widgets/BoxLayoutDockWidget.py
@@ -0,0 +1,90 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""A QDockWidget that update the layout direction of its widget
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "06/03/2018"
+
+
+from .. import qt
+
+
+class BoxLayoutDockWidget(qt.QDockWidget):
+ """QDockWidget adjusting its child widget QBoxLayout direction.
+
+ The child widget layout direction is set according to dock widget area.
+ The child widget MUST use a QBoxLayout
+
+ :param parent: See :class:`QDockWidget`
+ :param flags: See :class:`QDockWidget`
+ """
+
+ def __init__(self, parent=None, flags=qt.Qt.Widget):
+ super(BoxLayoutDockWidget, self).__init__(parent, flags)
+ self._currentArea = qt.Qt.NoDockWidgetArea
+ self.dockLocationChanged.connect(self._dockLocationChanged)
+ self.topLevelChanged.connect(self._topLevelChanged)
+
+ def setWidget(self, widget):
+ """Set the widget of this QDockWidget
+
+ See :meth:`QDockWidget.setWidget`
+ """
+ super(BoxLayoutDockWidget, self).setWidget(widget)
+ # Update widget's layout direction
+ self._dockLocationChanged(self._currentArea)
+
+ def _dockLocationChanged(self, area):
+ self._currentArea = area
+
+ widget = self.widget()
+ if widget is not None:
+ layout = widget.layout()
+ if isinstance(layout, qt.QBoxLayout):
+ if area in (qt.Qt.LeftDockWidgetArea, qt.Qt.RightDockWidgetArea):
+ direction = qt.QBoxLayout.TopToBottom
+ else:
+ direction = qt.QBoxLayout.LeftToRight
+ layout.setDirection(direction)
+ self.resize(widget.minimumSize())
+ self.adjustSize()
+
+ def _topLevelChanged(self, topLevel):
+ widget = self.widget()
+ if widget is not None and topLevel:
+ layout = widget.layout()
+ if isinstance(layout, qt.QBoxLayout):
+ layout.setDirection(qt.QBoxLayout.LeftToRight)
+ self.resize(widget.minimumSize())
+ self.adjustSize()
+
+ def showEvent(self, event):
+ """Make sure this dock widget is raised when it is shown.
+
+ This is useful for tabbed dock widgets.
+ """
+ self.raise_()
diff --git a/src/silx/gui/widgets/ColormapNameComboBox.py b/src/silx/gui/widgets/ColormapNameComboBox.py
new file mode 100644
index 0000000..fa8faf1
--- /dev/null
+++ b/src/silx/gui/widgets/ColormapNameComboBox.py
@@ -0,0 +1,166 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""A QComboBox to display prefered colormaps
+"""
+
+from __future__ import division
+
+__authors__ = ["V.A. Sole", "T. Vincent", "H. Payno"]
+__license__ = "MIT"
+__date__ = "27/11/2018"
+
+
+import logging
+import numpy
+
+from .. import qt
+from .. import colors as colors_mdl
+
+_logger = logging.getLogger(__name__)
+
+
+_colormapIconPreview = {}
+
+
+class ColormapNameComboBox(qt.QComboBox):
+ def __init__(self, parent=None):
+ qt.QComboBox.__init__(self, parent)
+ self.__initItems()
+
+ LUT_NAME = qt.Qt.UserRole + 1
+ LUT_COLORS = qt.Qt.UserRole + 2
+
+ def __initItems(self):
+ for colormapName in colors_mdl.preferredColormaps():
+ index = self.count()
+ self.addItem(str.title(colormapName))
+ self.setItemIcon(index, self.getIconPreview(name=colormapName))
+ self.setItemData(index, colormapName, role=self.LUT_NAME)
+
+ def getIconPreview(self, name=None, colors=None):
+ """Return an icon preview from a LUT name.
+
+ This icons are cached into a global structure.
+
+ :param str name: Name of the LUT
+ :param numpy.ndarray colors: Colors identify the LUT
+ :rtype: qt.QIcon
+ """
+ if name is not None:
+ iconKey = name
+ else:
+ iconKey = tuple(colors)
+ icon = _colormapIconPreview.get(iconKey, None)
+ if icon is None:
+ icon = self.createIconPreview(name, colors)
+ _colormapIconPreview[iconKey] = icon
+ return icon
+
+ def createIconPreview(self, name=None, colors=None):
+ """Create and return an icon preview from a LUT name.
+
+ This icons are cached into a global structure.
+
+ :param str name: Name of the LUT
+ :param numpy.ndarray colors: Colors identify the LUT
+ :rtype: qt.QIcon
+ """
+ colormap = colors_mdl.Colormap(name)
+ size = 32
+ if name is not None:
+ lut = colormap.getNColors(size)
+ else:
+ lut = colors
+ if len(lut) > size:
+ # Down sample
+ step = int(len(lut) / size)
+ lut = lut[::step]
+ elif len(lut) < size:
+ # Over sample
+ indexes = numpy.arange(size) / float(size) * (len(lut) - 1)
+ indexes = indexes.astype("int")
+ lut = lut[indexes]
+ if lut is None or len(lut) == 0:
+ return qt.QIcon()
+
+ pixmap = qt.QPixmap(size, size)
+ painter = qt.QPainter(pixmap)
+ for i in range(size):
+ rgb = lut[i]
+ r, g, b = rgb[0], rgb[1], rgb[2]
+ painter.setPen(qt.QColor(r, g, b))
+ painter.drawPoint(qt.QPoint(i, 0))
+
+ painter.drawPixmap(0, 1, size, size - 1, pixmap, 0, 0, size, 1)
+ painter.end()
+
+ return qt.QIcon(pixmap)
+
+ def getCurrentName(self):
+ return self.itemData(self.currentIndex(), self.LUT_NAME)
+
+ def getCurrentColors(self):
+ return self.itemData(self.currentIndex(), self.LUT_COLORS)
+
+ def findLutName(self, name):
+ return self.findData(name, role=self.LUT_NAME)
+
+ def findLutColors(self, lut):
+ for index in range(self.count()):
+ if self.itemData(index, role=self.LUT_NAME) is not None:
+ continue
+ colors = self.itemData(index, role=self.LUT_COLORS)
+ if colors is None:
+ continue
+ if numpy.array_equal(colors, lut):
+ return index
+ return -1
+
+ def setCurrentLut(self, colormap):
+ name = colormap.getName()
+ if name is not None:
+ self._setCurrentName(name)
+ else:
+ lut = colormap.getColormapLUT()
+ self._setCurrentLut(lut)
+
+ def _setCurrentLut(self, lut):
+ index = self.findLutColors(lut)
+ if index == -1:
+ index = self.count()
+ self.addItem("Custom")
+ self.setItemIcon(index, self.getIconPreview(colors=lut))
+ self.setItemData(index, None, role=self.LUT_NAME)
+ self.setItemData(index, lut, role=self.LUT_COLORS)
+ self.setCurrentIndex(index)
+
+ def _setCurrentName(self, name):
+ index = self.findLutName(name)
+ if index < 0:
+ index = self.count()
+ self.addItem(str.title(name))
+ self.setItemIcon(index, self.getIconPreview(name=name))
+ self.setItemData(index, name, role=self.LUT_NAME)
+ self.setCurrentIndex(index)
diff --git a/src/silx/gui/widgets/ElidedLabel.py b/src/silx/gui/widgets/ElidedLabel.py
new file mode 100644
index 0000000..7c6dfb5
--- /dev/null
+++ b/src/silx/gui/widgets/ElidedLabel.py
@@ -0,0 +1,140 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Module contains an elidable label
+"""
+
+__license__ = "MIT"
+__date__ = "07/12/2018"
+
+from silx.gui import qt
+
+
+class ElidedLabel(qt.QLabel):
+ """QLabel with an edile property.
+
+ By default if the text is too big, it is elided on the right.
+
+ This mode can be changed with :func:`setElideMode`.
+
+ In case the text is elided, the full content is displayed as part of the
+ tool tip. This behavior can be disabled with :func:`setTextAsToolTip`.
+ """
+
+ def __init__(self, parent=None):
+ super(ElidedLabel, self).__init__(parent)
+ self.__text = ""
+ self.__toolTip = ""
+ self.__textAsToolTip = True
+ self.__textIsElided = False
+ self.__elideMode = qt.Qt.ElideRight
+ self.__updateMinimumSize()
+
+ def resizeEvent(self, event):
+ self.__updateText()
+ return qt.QLabel.resizeEvent(self, event)
+
+ def setFont(self, font):
+ qt.QLabel.setFont(self, font)
+ self.__updateMinimumSize()
+ self.__updateText()
+
+ def __updateMinimumSize(self):
+ metrics = self.fontMetrics()
+ if qt.BINDING in ('PySide2', 'PyQt5'):
+ width = metrics.width("...")
+ else: # Qt6
+ width = metrics.horizontalAdvance("...")
+ self.setMinimumWidth(width)
+
+ def __updateText(self):
+ metrics = self.fontMetrics()
+ elidedText = metrics.elidedText(self.__text, self.__elideMode, self.width())
+ qt.QLabel.setText(self, elidedText)
+ wasElided = self.__textIsElided
+ self.__textIsElided = elidedText != self.__text
+ if self.__textIsElided or wasElided != self.__textIsElided:
+ self.__updateToolTip()
+
+ def __updateToolTip(self):
+ if self.__textIsElided and self.__textAsToolTip:
+ qt.QLabel.setToolTip(self, self.__text + "<br/>" + self.__toolTip)
+ else:
+ qt.QLabel.setToolTip(self, self.__toolTip)
+
+ # Properties
+
+ def setText(self, text):
+ self.__text = text
+ self.__updateText()
+
+ def getText(self):
+ return self.__text
+
+ text = qt.Property(str, getText, setText)
+
+ def setToolTip(self, toolTip):
+ self.__toolTip = toolTip
+ self.__updateToolTip()
+
+ def getToolTip(self):
+ return self.__toolTip
+
+ toolTip = qt.Property(str, getToolTip, setToolTip)
+
+ def setElideMode(self, elideMode):
+ """Set the elide mode.
+
+ :param qt.Qt.TextElideMode elidMode: Elide mode to use
+ """
+ self.__elideMode = elideMode
+ self.__updateText()
+
+ def getElideMode(self):
+ """Returns the used elide mode.
+
+ :rtype: qt.Qt.TextElideMode
+ """
+ return self.__elideMode
+
+ elideMode = qt.Property(qt.Qt.TextElideMode, getToolTip, setToolTip)
+
+ def setTextAsToolTip(self, enabled):
+ """Enable displaying text as part of the tooltip if it is elided.
+
+ :param bool enabled: Enable the behavior
+ """
+ if self.__textAsToolTip == enabled:
+ return
+ self.__textAsToolTip = enabled
+ self.__updateToolTip()
+
+ def getTextAsToolTip(self):
+ """True if an elided text is displayed as part of the tooltip.
+
+ :rtype: bool
+ """
+ return self.__textAsToolTip
+
+ textAsToolTip = qt.Property(bool, getTextAsToolTip, setTextAsToolTip)
diff --git a/src/silx/gui/widgets/FloatEdit.py b/src/silx/gui/widgets/FloatEdit.py
new file mode 100644
index 0000000..08ed67d
--- /dev/null
+++ b/src/silx/gui/widgets/FloatEdit.py
@@ -0,0 +1,71 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Module contains a float editor
+"""
+
+from __future__ import division
+
+__authors__ = ["V.A. Sole", "T. Vincent"]
+__license__ = "MIT"
+__date__ = "02/10/2017"
+
+from .. import qt
+
+
+class FloatEdit(qt.QLineEdit):
+ """Field to edit a float value.
+
+ :param parent: See :class:`QLineEdit`
+ :param float value: The value to set the QLineEdit to.
+ """
+ def __init__(self, parent=None, value=None):
+ qt.QLineEdit.__init__(self, parent)
+ validator = qt.QDoubleValidator(self)
+ self.setValidator(validator)
+ self.setAlignment(qt.Qt.AlignRight)
+ if value is not None:
+ self.setValue(value)
+
+ def value(self):
+ """Return the QLineEdit current value as a float."""
+ text = self.text()
+ value, validated = self.validator().locale().toDouble(text)
+ if not validated:
+ self.setValue(value)
+ return value
+
+ def setValue(self, value):
+ """Set the current value of the LineEdit
+
+ :param float value: The value to set the QLineEdit to.
+ """
+ locale = self.validator().locale()
+ if qt.BINDING == "PySide6":
+ # Fix for PySide6 not selecting the right method
+ text = locale.toString(float(value), 'g')
+ else:
+ text = locale.toString(float(value))
+
+ self.setText(text)
diff --git a/src/silx/gui/widgets/FlowLayout.py b/src/silx/gui/widgets/FlowLayout.py
new file mode 100644
index 0000000..3c4c9dd
--- /dev/null
+++ b/src/silx/gui/widgets/FlowLayout.py
@@ -0,0 +1,177 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a flow layout for QWidget: :class:`FlowLayout`.
+"""
+
+from __future__ import division
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "20/07/2018"
+
+
+from .. import qt
+
+
+class FlowLayout(qt.QLayout):
+ """Layout widgets on (possibly) multiple lines in the available width.
+
+ See Qt :class:`QLayout` for API documentation.
+
+ Adapted from C++ `Qt FlowLayout example
+ <http://doc.qt.io/qt-5/qtwidgets-layouts-flowlayout-example.html>`_
+
+ :param QWidget parent: See :class:`QLayout`
+ """
+
+ def __init__(self, parent=None):
+ super(FlowLayout, self).__init__(parent)
+ self._items = []
+ self._horizontalSpacing = -1
+ self._verticalSpacing = -1
+
+ def addItem(self, item):
+ self._items.append(item)
+
+ def count(self):
+ return len(self._items)
+
+ def itemAt(self, index):
+ if 0 <= index < len(self._items):
+ return self._items[index]
+ else:
+ return None
+
+ def takeAt(self, index):
+ if 0 <= index < len(self._items):
+ return self._items.pop(index)
+ else:
+ return None
+
+ def expandingDirections(self):
+ return qt.Qt.Orientations()
+
+ def hasHeightForWidth(self):
+ return True
+
+ def heightForWidth(self, width):
+ return self._layout(qt.QRect(0, 0, width, 0), test=True)
+
+ def setGeometry(self, rect):
+ super(FlowLayout, self).setGeometry(rect)
+ self._layout(rect)
+
+ def sizeHint(self):
+ return self.minimumSize()
+
+ def minimumSize(self):
+ size = qt.QSize()
+ for item in self._items:
+ size = size.expandedTo(item.minimumSize())
+
+ left, top, right, bottom = self.getContentsMargins()
+ size += qt.QSize(left + right, top + bottom)
+ return size
+
+ def _layout(self, rect, test=False):
+ left, top, right, bottom = self.getContentsMargins()
+ effectiveRect = rect.adjusted(left, top, -right, -bottom)
+ x, y = effectiveRect.x(), effectiveRect.y()
+ lineHeight = 0
+
+ for item in self._items:
+ widget = item.widget()
+ spaceX = self.horizontalSpacing()
+ if spaceX == -1:
+ spaceX = widget.style().layoutSpacing(
+ qt.QSizePolicy.PushButton,
+ qt.QSizePolicy.PushButton,
+ qt.Qt.Horizontal)
+ spaceY = self.verticalSpacing()
+ if spaceY == -1:
+ spaceY = widget.style().layoutSpacing(
+ qt.QSizePolicy.PushButton,
+ qt.QSizePolicy.PushButton,
+ qt.Qt.Vertical)
+
+ nextX = x + item.sizeHint().width() + spaceX
+ if (nextX - spaceX) > effectiveRect.right() and lineHeight > 0:
+ x = effectiveRect.x()
+ y += lineHeight + spaceY
+ nextX = x + item.sizeHint().width() + spaceX
+ lineHeight = 0
+
+ if not test:
+ item.setGeometry(qt.QRect(qt.QPoint(x, y), item.sizeHint()))
+
+ x = nextX
+ lineHeight = max(lineHeight, item.sizeHint().height())
+
+ return y + lineHeight - rect.y() + bottom
+
+ def setHorizontalSpacing(self, spacing):
+ """Set the horizontal spacing between widgets laid out side by side
+
+ :param int spacing:
+ """
+ self._horizontalSpacing = spacing
+ self.update()
+
+ def horizontalSpacing(self):
+ """Returns the horizontal spacing between widgets laid out side by side
+
+ :rtype: int
+ """
+ if self._horizontalSpacing >= 0:
+ return self._horizontalSpacing
+ else:
+ return self._smartSpacing(qt.QStyle.PM_LayoutHorizontalSpacing)
+
+ def setVerticalSpacing(self, spacing):
+ """Set the vertical spacing between lines
+
+ :param int spacing:
+ """
+ self._verticalSpacing = spacing
+ self.update()
+
+ def verticalSpacing(self):
+ """Returns the vertical spacing between lines
+
+ :rtype: int
+ """
+ if self._verticalSpacing >= 0:
+ return self._verticalSpacing
+ else:
+ return self._smartSpacing(qt.QStyle.PM_LayoutVerticalSpacing)
+
+ def _smartSpacing(self, pm):
+ parent = self.parent()
+ if parent is None:
+ return -1
+ if parent.isWidgetType():
+ return parent.style().pixelMetric(pm, None, parent)
+ else:
+ return parent.spacing()
diff --git a/src/silx/gui/widgets/FrameBrowser.py b/src/silx/gui/widgets/FrameBrowser.py
new file mode 100644
index 0000000..671991f
--- /dev/null
+++ b/src/silx/gui/widgets/FrameBrowser.py
@@ -0,0 +1,324 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module defines two main classes:
+
+ - :class:`FrameBrowser`: a widget with 4 buttons (first, previous, next,
+ last) to browse between frames and a text entry to access a specific frame
+ by typing it's number)
+ - :class:`HorizontalSliderWithBrowser`: a FrameBrowser with an additional
+ slider. This class inherits :class:`qt.QAbstractSlider`.
+
+"""
+from silx.gui import qt
+from silx.gui import icons
+from silx.utils import deprecation
+
+__authors__ = ["V.A. Sole", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "16/01/2017"
+
+
+class FrameBrowser(qt.QWidget):
+ """Frame browser widget, with 4 buttons/icons and a line edit to provide
+ a way of selecting a frame index in a stack of images.
+
+ .. image:: img/FrameBrowser.png
+
+ It can be used in more generic case to select an integer within a range.
+
+ :param QWidget parent: Parent widget
+ :param int n: Number of frames. This will set the range
+ of frame indices to 0--n-1.
+ If None, the range is initialized to the default QSlider range (0--99).
+ """
+
+ sigIndexChanged = qt.pyqtSignal(object)
+
+ def __init__(self, parent=None, n=None):
+ qt.QWidget.__init__(self, parent)
+
+ # Use the font size as the icon size to avoid to create bigger buttons
+ fontMetric = self.fontMetrics()
+ iconSize = qt.QSize(fontMetric.height(), fontMetric.height())
+
+ self.mainLayout = qt.QHBoxLayout(self)
+ self.mainLayout.setContentsMargins(0, 0, 0, 0)
+ self.mainLayout.setSpacing(0)
+ self.firstButton = qt.QPushButton(self)
+ self.firstButton.setIcon(icons.getQIcon("first"))
+ self.firstButton.setIconSize(iconSize)
+ self.previousButton = qt.QPushButton(self)
+ self.previousButton.setIcon(icons.getQIcon("previous"))
+ self.previousButton.setIconSize(iconSize)
+ self._lineEdit = qt.QLineEdit(self)
+
+ self._label = qt.QLabel(self)
+ self.nextButton = qt.QPushButton(self)
+ self.nextButton.setIcon(icons.getQIcon("next"))
+ self.nextButton.setIconSize(iconSize)
+ self.lastButton = qt.QPushButton(self)
+ self.lastButton.setIcon(icons.getQIcon("last"))
+ self.lastButton.setIconSize(iconSize)
+
+ self.mainLayout.addWidget(self.firstButton)
+ self.mainLayout.addWidget(self.previousButton)
+ self.mainLayout.addWidget(self._lineEdit)
+ self.mainLayout.addWidget(self._label)
+ self.mainLayout.addWidget(self.nextButton)
+ self.mainLayout.addWidget(self.lastButton)
+
+ if n is None:
+ first = qt.QSlider().minimum()
+ last = qt.QSlider().maximum()
+ else:
+ first, last = 0, n
+
+ self._lineEdit.setFixedWidth(self._lineEdit.fontMetrics().boundingRect('%05d' % last).width())
+ validator = qt.QIntValidator(first, last, self._lineEdit)
+ self._lineEdit.setValidator(validator)
+ self._lineEdit.setText("%d" % first)
+ self._label.setText("of %d" % last)
+
+ self._index = first
+ """0-based index"""
+
+ self.firstButton.clicked.connect(self._firstClicked)
+ self.previousButton.clicked.connect(self._previousClicked)
+ self.nextButton.clicked.connect(self._nextClicked)
+ self.lastButton.clicked.connect(self._lastClicked)
+ self._lineEdit.editingFinished.connect(self._textChangedSlot)
+
+ def lineEdit(self):
+ """Returns the line edit provided by this widget.
+
+ :rtype: qt.QLineEdit
+ """
+ return self._lineEdit
+
+ def limitWidget(self):
+ """Returns the widget displaying axes limits.
+
+ :rtype: qt.QLabel
+ """
+ return self._label
+
+ def _firstClicked(self):
+ """Select first/lowest frame number"""
+ self.setValue(self.getRange()[0])
+
+ def _previousClicked(self):
+ """Select previous frame number"""
+ self.setValue(self.getValue() - 1)
+
+ def _nextClicked(self):
+ """Select next frame number"""
+ self.setValue(self.getValue() + 1)
+
+ def _lastClicked(self):
+ """Select last/highest frame number"""
+ self.setValue(self.getRange()[1])
+
+ def _textChangedSlot(self):
+ """Select frame number typed in the line edit widget"""
+ txt = self._lineEdit.text()
+ if not len(txt):
+ self._lineEdit.setText("%d" % self._index)
+ return
+ new_value = int(txt)
+ if new_value == self._index:
+ return
+ ddict = {
+ "event": "indexChanged",
+ "old": self._index,
+ "new": new_value,
+ "id": id(self)
+ }
+ self._index = new_value
+ self.sigIndexChanged.emit(ddict)
+
+ def getRange(self):
+ """Returns frame range
+
+ :return: (first_index, last_index)
+ """
+ validator = self.lineEdit().validator()
+ return validator.bottom(), validator.top()
+
+ def setRange(self, first, last):
+ """Set minimum and maximum frame indices.
+
+ Initialize the frame index to *first*.
+ Update the label text to *" limits: first, last"*
+
+ :param int first: Minimum frame index
+ :param int last: Maximum frame index"""
+ bottom = min(first, last)
+ top = max(first, last)
+ self._lineEdit.validator().setTop(top)
+ self._lineEdit.validator().setBottom(bottom)
+ self.setValue(bottom)
+
+ # Update limits
+ self._label.setText(" limits: %d, %d " % (bottom, top))
+
+ @deprecation.deprecated(replacement="FrameBrowser.setRange",
+ since_version="0.8")
+ def setLimits(self, first, last):
+ return self.setRange(first, last)
+
+ def setNFrames(self, nframes):
+ """Set minimum=0 and maximum=nframes-1 frame numbers.
+
+ Initialize the frame index to 0.
+ Update the label text to *"1 of nframes"*
+
+ :param int nframes: Number of frames"""
+ top = nframes - 1
+ self.setRange(0, top)
+ # display 1-based index in label
+ self._label.setText(" of %d " % top)
+
+ @deprecation.deprecated(replacement="FrameBrowser.getValue",
+ since_version="0.8")
+ def getCurrentIndex(self):
+ return self._index
+
+ def getValue(self):
+ """Return current frame index"""
+ return self._index
+
+ def setValue(self, value):
+ """Set 0-based frame index
+
+ Value is clipped to current range.
+
+ :param int value: Frame number"""
+ bottom = self.lineEdit().validator().bottom()
+ top = self.lineEdit().validator().top()
+ value = int(value)
+
+ if value < bottom:
+ value = bottom
+ elif value > top:
+ value = top
+
+ self._lineEdit.setText("%d" % value)
+ self._textChangedSlot()
+
+
+class HorizontalSliderWithBrowser(qt.QAbstractSlider):
+ """
+ Slider widget combining a :class:`QSlider` and a :class:`FrameBrowser`.
+
+ .. image:: img/HorizontalSliderWithBrowser.png
+
+ The data model is an integer within a range.
+
+ The default value is the default :class:`QSlider` value (0),
+ and the default range is the default QSlider range (0 -- 99)
+
+ The signal emitted when the value is changed is the usual QAbstractSlider
+ signal :attr:`valueChanged`. The signal carries the value (as an integer).
+
+ :param QWidget parent: Optional parent widget
+ """
+ def __init__(self, parent=None):
+ qt.QAbstractSlider.__init__(self, parent)
+ self.setOrientation(qt.Qt.Horizontal)
+
+ self.mainLayout = qt.QHBoxLayout(self)
+ self.mainLayout.setContentsMargins(0, 0, 0, 0)
+ self.mainLayout.setSpacing(2)
+
+ self._slider = qt.QSlider(self)
+ self._slider.setOrientation(qt.Qt.Horizontal)
+
+ self._browser = FrameBrowser(self)
+
+ self.mainLayout.addWidget(self._slider, 1)
+ self.mainLayout.addWidget(self._browser)
+
+ self._slider.valueChanged[int].connect(self._sliderSlot)
+ self._browser.sigIndexChanged.connect(self._browserSlot)
+
+ def lineEdit(self):
+ """Returns the line edit provided by this widget.
+
+ :rtype: qt.QLineEdit
+ """
+ return self._browser.lineEdit()
+
+ def limitWidget(self):
+ """Returns the widget displaying axes limits.
+
+ :rtype: qt.QLabel
+ """
+ return self._browser.limitWidget()
+
+ def setMinimum(self, value):
+ """Set minimum value
+
+ :param int value: Minimum value"""
+ self._slider.setMinimum(value)
+ maximum = self._slider.maximum()
+ self._browser.setRange(value, maximum)
+
+ def setMaximum(self, value):
+ """Set maximum value
+
+ :param int value: Maximum value
+ """
+ self._slider.setMaximum(value)
+ minimum = self._slider.minimum()
+ self._browser.setRange(minimum, value)
+
+ def setRange(self, first, last):
+ """Set minimum/maximum values
+
+ :param int first: Minimum value
+ :param int last: Maximum value"""
+ self._slider.setRange(first, last)
+ self._browser.setRange(first, last)
+
+ def _sliderSlot(self, value):
+ """Emit selected value when slider is activated
+ """
+ self._browser.setValue(value)
+ self.valueChanged.emit(value)
+
+ def _browserSlot(self, ddict):
+ """Emit selected value when browser state is changed"""
+ self._slider.setValue(ddict['new'])
+
+ def setValue(self, value):
+ """Set value
+
+ :param int value: value"""
+ self._slider.setValue(value)
+ self._browser.setValue(value)
+
+ def value(self):
+ """Get selected value"""
+ return self._slider.value()
diff --git a/src/silx/gui/widgets/HierarchicalTableView.py b/src/silx/gui/widgets/HierarchicalTableView.py
new file mode 100644
index 0000000..3ccf4c7
--- /dev/null
+++ b/src/silx/gui/widgets/HierarchicalTableView.py
@@ -0,0 +1,172 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+This module define a hierarchical table view and model.
+
+It allows to define many headers in the middle of a table.
+
+The implementation hide the default header and allows to custom each cells
+to became a header.
+
+Row and column span is a concept of the view in a QTableView.
+This implementation also provide a span property as part of the model of the
+cell. A role is define to custom this information.
+The view is updated everytime the model is reset to take care of the
+changes of this information.
+
+A default item delegate is used to redefine the paint of the cells.
+"""
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "07/04/2017"
+
+from silx.gui import qt
+
+
+class HierarchicalTableModel(qt.QAbstractTableModel):
+ """
+ Abstract table model to provide more custom on row and column span and
+ headers.
+
+ Default headers are ignored and each cells can define IsHeaderRole and
+ SpanRole using the `data` function.
+ """
+
+ SpanRole = qt.Qt.UserRole + 0
+ """Role returning a tuple for number of row span then column span.
+
+ None and (1, 1) are neutral for the rendering.
+ """
+
+ IsHeaderRole = qt.Qt.UserRole + 1
+ """Role returning True is the identified cell is a header."""
+
+ UserRole = qt.Qt.UserRole + 2
+ """First index of user defined roles"""
+
+ def headerData(self, section, orientation, role=qt.Qt.DisplayRole):
+ """Returns the 0-based row or column index, for display in the
+ horizontal and vertical headers
+
+ In this case the headers are just ignored. Header information is part
+ of each cells.
+ """
+ return None
+
+
+class HierarchicalItemDelegate(qt.QStyledItemDelegate):
+ """
+ Delegate item to take care of the rendering of the default table cells and
+ also the header cells.
+ """
+
+ def __init__(self, parent=None):
+ """
+ Constructor
+
+ :param qt.QObject parent: Parent of the widget
+ """
+ qt.QStyledItemDelegate.__init__(self, parent)
+
+ def paint(self, painter, option, index):
+ """Override the paint function to inject the style of the header.
+
+ :param qt.QPainter painter: Painter context used to displayed the cell
+ :param qt.QStyleOptionViewItem option: Control how the editor is shown
+ :param qt.QIndex index: Index of the data to display
+ """
+ isHeader = index.data(role=HierarchicalTableModel.IsHeaderRole)
+ if isHeader:
+ span = index.data(role=HierarchicalTableModel.SpanRole)
+ span = 1 if span is None else span[1]
+ columnCount = index.model().columnCount()
+ if span == columnCount:
+ mainTitle = True
+ position = qt.QStyleOptionHeader.OnlyOneSection
+ else:
+ mainTitle = False
+ col = index.column()
+ if col == 0:
+ position = qt.QStyleOptionHeader.Beginning
+ elif col < columnCount - 1:
+ position = qt.QStyleOptionHeader.Middle
+ else:
+ position = qt.QStyleOptionHeader.End
+ opt = qt.QStyleOptionHeader()
+ opt.direction = option.direction
+ opt.text = index.data()
+ opt.textAlignment = qt.Qt.AlignCenter if mainTitle else qt.Qt.AlignVCenter
+ opt.direction = option.direction
+ opt.fontMetrics = option.fontMetrics
+ opt.palette = option.palette
+ opt.rect = option.rect
+ opt.state = option.state
+ opt.position = position
+ margin = -1
+ style = qt.QApplication.instance().style()
+ opt.rect = opt.rect.adjusted(margin, margin, -margin, -margin)
+ style.drawControl(qt.QStyle.CE_HeaderSection, opt, painter, None)
+ margin = 3
+ opt.rect = opt.rect.adjusted(margin, margin, -margin, -margin)
+ style.drawControl(qt.QStyle.CE_HeaderLabel, opt, painter, None)
+ else:
+ qt.QStyledItemDelegate.paint(self, painter, option, index)
+
+
+class HierarchicalTableView(qt.QTableView):
+ """A TableView which allow to display a `HierarchicalTableModel`."""
+
+ def __init__(self, parent=None):
+ """
+ Constructor
+
+ :param qt.QWidget parent: Parent of the widget
+ """
+ super(HierarchicalTableView, self).__init__(parent)
+ self.setItemDelegate(HierarchicalItemDelegate(self))
+ self.verticalHeader().setVisible(False)
+ self.horizontalHeader().setVisible(False)
+
+ def setModel(self, model):
+ """Override the default function to connect the model to update
+ function"""
+ if self.model() is not None:
+ model.modelReset.disconnect(self.__modelReset)
+ super(HierarchicalTableView, self).setModel(model)
+ if self.model() is not None:
+ model.modelReset.connect(self.__modelReset)
+ self.__modelReset()
+
+ def __modelReset(self):
+ """Update the model to take care of the changes of the span
+ information"""
+ self.clearSpans()
+ model = self.model()
+ for row in range(model.rowCount()):
+ for column in range(model.columnCount()):
+ index = model.index(row, column, qt.QModelIndex())
+ span = model.data(index, HierarchicalTableModel.SpanRole)
+ if span is not None and span != (1, 1):
+ self.setSpan(row, column, span[0], span[1])
diff --git a/src/silx/gui/widgets/LegendIconWidget.py b/src/silx/gui/widgets/LegendIconWidget.py
new file mode 100755
index 0000000..1c95e41
--- /dev/null
+++ b/src/silx/gui/widgets/LegendIconWidget.py
@@ -0,0 +1,514 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Widget displaying a symbol (marker symbol, line style and color) to identify
+an item displayed by a plot.
+"""
+
+__authors__ = ["V.A. Sole", "T. Rueter", "T. Vincent"]
+__license__ = "MIT"
+__data__ = "11/11/2019"
+
+
+import logging
+
+import numpy
+
+from .. import qt, colors
+
+
+_logger = logging.getLogger(__name__)
+
+
+# Build all symbols
+# Courtesy of the pyqtgraph project
+
+_Symbols = None
+""""Cache supported symbols as Qt paths"""
+
+
+_NoSymbols = (None, 'None', 'none', '', ' ')
+"""List of values resulting in no symbol being displayed for a curve"""
+
+
+_LineStyles = {
+ None: qt.Qt.NoPen,
+ 'None': qt.Qt.NoPen,
+ 'none': qt.Qt.NoPen,
+ '': qt.Qt.NoPen,
+ ' ': qt.Qt.NoPen,
+ '-': qt.Qt.SolidLine,
+ '--': qt.Qt.DashLine,
+ ':': qt.Qt.DotLine,
+ '-.': qt.Qt.DashDotLine
+}
+"""Conversion from matplotlib-like linestyle to Qt"""
+
+_NoLineStyle = (None, 'None', 'none', '', ' ')
+"""List of style values resulting in no line being displayed for a curve"""
+
+
+_colormapImage = {}
+"""Store cached pixmap"""
+# FIXME: Could be better to use a LRU dictionary
+
+_COLORMAP_PIXMAP_SIZE = 32
+"""Size of the cached pixmaps for the colormaps"""
+
+
+def _initSymbols():
+ """Init the cached symbol structure if not yet done."""
+ global _Symbols
+ if _Symbols is not None:
+ return
+
+ symbols = dict([(name, qt.QPainterPath())
+ for name in ['o', 's', 't', 'd', '+', 'x', '.', ',']])
+ symbols['o'].addEllipse(qt.QRectF(.1, .1, .8, .8))
+ symbols['.'].addEllipse(qt.QRectF(.3, .3, .4, .4))
+ symbols[','].addEllipse(qt.QRectF(.4, .4, .2, .2))
+ symbols['s'].addRect(qt.QRectF(.1, .1, .8, .8))
+
+ coords = {
+ 't': [(0.5, 0.), (.1, .8), (.9, .8)],
+ 'd': [(0.1, 0.5), (0.5, 0.), (0.9, 0.5), (0.5, 1.)],
+ '+': [(0.0, 0.40), (0.40, 0.40), (0.40, 0.), (0.60, 0.),
+ (0.60, 0.40), (1., 0.40), (1., 0.60), (0.60, 0.60),
+ (0.60, 1.), (0.40, 1.), (0.40, 0.60), (0., 0.60)],
+ 'x': [(0.0, 0.40), (0.40, 0.40), (0.40, 0.), (0.60, 0.),
+ (0.60, 0.40), (1., 0.40), (1., 0.60), (0.60, 0.60),
+ (0.60, 1.), (0.40, 1.), (0.40, 0.60), (0., 0.60)]
+ }
+ for s, c in coords.items():
+ symbols[s].moveTo(*c[0])
+ for x, y in c[1:]:
+ symbols[s].lineTo(x, y)
+ symbols[s].closeSubpath()
+ tr = qt.QTransform()
+ tr.rotate(45)
+ symbols['x'].translate(qt.QPointF(-0.5, -0.5))
+ symbols['x'] = tr.map(symbols['x'])
+ symbols['x'].translate(qt.QPointF(0.5, 0.5))
+
+ _Symbols = symbols
+
+
+class LegendIconWidget(qt.QWidget):
+ """Object displaying linestyle and symbol of plots.
+
+ :param QWidget parent: See :class:`QWidget`
+ """
+
+ def __init__(self, parent=None):
+ super(LegendIconWidget, self).__init__(parent)
+ _initSymbols()
+
+ # Visibilities
+ self.showLine = True
+ self.showSymbol = True
+ self.showColormap = True
+
+ # Line attributes
+ self.lineStyle = qt.Qt.NoPen
+ self.lineWidth = 1.
+ self.lineColor = qt.Qt.green
+
+ self.symbol = ''
+ # Symbol attributes
+ self.symbolStyle = qt.Qt.SolidPattern
+ self.symbolColor = qt.Qt.green
+ self.symbolOutlineBrush = qt.QBrush(qt.Qt.white)
+ self.symbolColormap = None
+ """Name or array of colors"""
+
+ self.colormap = None
+ """Name or array of colors"""
+
+ # Control widget size: sizeHint "is the only acceptable
+ # alternative, so the widget can never grow or shrink"
+ # (c.f. Qt Doc, enum QSizePolicy::Policy)
+ self.setSizePolicy(qt.QSizePolicy.Fixed,
+ qt.QSizePolicy.Fixed)
+
+ def sizeHint(self):
+ return qt.QSize(50, 15)
+
+ def setSymbol(self, symbol):
+ """Set the symbol"""
+ symbol = str(symbol)
+ if symbol not in _NoSymbols:
+ if symbol not in _Symbols:
+ raise ValueError("Unknown symbol: <%s>" % symbol)
+ self.symbol = symbol
+ self.update()
+
+ def setSymbolColor(self, color):
+ """
+ :param color: determines the symbol color
+ :type style: qt.QColor
+ """
+ self.symbolColor = qt.QColor(color)
+ self.update()
+
+ # Modify Line
+
+ def setLineColor(self, color):
+ self.lineColor = qt.QColor(color)
+ self.update()
+
+ def setLineWidth(self, width):
+ self.lineWidth = float(width)
+ self.update()
+
+ def setLineStyle(self, style):
+ """Set the linestyle.
+
+ Possible line styles:
+
+ - '', ' ', 'None': No line
+ - '-': solid
+ - '--': dashed
+ - ':': dotted
+ - '-.': dash and dot
+
+ :param str style: The linestyle to use
+ """
+ if style not in _LineStyles:
+ raise ValueError('Unknown style: %s', style)
+ self.lineStyle = _LineStyles[style]
+ self.update()
+
+ def _toLut(self, colormap):
+ """Returns an internal LUT object used by this widget to manage
+ a colormap LUT.
+
+ If the argument is a `Colormap` object, only the current state will be
+ displayed. The object itself will not be stored, and further changes
+ of this `Colormap` will not update this widget.
+
+ :param Union[str,numpy.ndarray,Colormap] colormap: The colormap to
+ display
+ :rtype: Union[None,str,numpy.ndarray]
+ """
+ if isinstance(colormap, colors.Colormap):
+ # Helper to allow to support Colormap objects
+ c = colormap.getName()
+ if c is None:
+ c = colormap.getNColors()
+ colormap = c
+
+ return colormap
+
+ def setColormap(self, colormap):
+ """Set the colormap to display
+
+ If the argument is a `Colormap` object, only the current state will be
+ displayed. The object itself will not be stored, and further changes
+ of this `Colormap` will not update this widget.
+
+ :param Union[str,numpy.ndarray,Colormap] colormap: The colormap to
+ display
+ """
+ colormap = self._toLut(colormap)
+
+ if colormap is None:
+ if self.colormap is None:
+ return
+ self.colormap = None
+ self.update()
+ return
+
+ if numpy.array_equal(self.colormap, colormap):
+ # This also works with strings
+ return
+
+ self.colormap = colormap
+ self.update()
+
+ def getColormap(self):
+ """Returns the used colormap.
+
+ If the argument was set with a `Colormap` object, this function will
+ returns the LUT, represented by a string name or by an array or colors.
+
+ :returns: Union[None,str,numpy.ndarray,Colormap]
+ """
+ return self.colormap
+
+ def setSymbolColormap(self, colormap):
+ """Set the colormap to display a symbol
+
+ If the argument is a `Colormap` object, only the current state will be
+ displayed. The object itself will not be stored, and further changes
+ of this `Colormap` will not update this widget.
+
+ :param Union[str,numpy.ndarray,Colormap] colormap: The colormap to
+ display
+ """
+ colormap = self._toLut(colormap)
+
+ if colormap is None:
+ if self.colormap is None:
+ return
+ self.symbolColormap = None
+ self.update()
+ return
+
+ if numpy.array_equal(self.symbolColormap, colormap):
+ # This also works with strings
+ return
+
+ self.symbolColormap = colormap
+ self.update()
+
+ def getSymbolColormap(self):
+ """Returns the used symbol colormap.
+
+ If the argument was set with a `Colormap` object, this function will
+ returns the LUT, represented by a string name or by an array or colors.
+
+ :returns: Union[None,str,numpy.ndarray,Colormap]
+ """
+ return self.colormap
+
+ # Paint
+
+ def paintEvent(self, event):
+ """
+ :param event: event
+ :type event: QPaintEvent
+ """
+ painter = qt.QPainter(self)
+ self.paint(painter, event.rect(), self.palette())
+
+ def paint(self, painter, rect, palette):
+ painter.save()
+ painter.setRenderHint(qt.QPainter.Antialiasing)
+ # Scale painter to the icon height
+ # current -> width = 2.5, height = 1.0
+ scale = float(self.height())
+ ratio = float(self.width()) / scale
+ symbolOffset = qt.QPointF(.5 * (ratio - 1.), 0.)
+ # Determine and scale offset
+ offset = qt.QPointF(float(rect.left()) / scale, float(rect.top()) / scale)
+
+ # Override color when disabled
+ if self.isEnabled():
+ overrideColor = None
+ else:
+ overrideColor = palette.color(qt.QPalette.Disabled,
+ qt.QPalette.WindowText)
+
+ # Draw BG rectangle (for debugging)
+ # bottomRight = qt.QPointF(
+ # float(rect.right())/scale,
+ # float(rect.bottom())/scale)
+ # painter.fillRect(qt.QRectF(offset, bottomRight),
+ # qt.QBrush(qt.Qt.green))
+
+ if self.showColormap:
+ if self.colormap is not None:
+ if self.isEnabled():
+ image = self.getColormapImage(self.colormap)
+ else:
+ image = self.getGrayedColormapImage(self.colormap)
+ pixmapRect = qt.QRect(0, 0, _COLORMAP_PIXMAP_SIZE, 1)
+ widthMargin = 0
+ halfHeight = 4
+ widgetRect = self.rect()
+ dest = qt.QRect(
+ widgetRect.left() + widthMargin,
+ widgetRect.center().y() - halfHeight + 1,
+ widgetRect.width() - widthMargin * 2,
+ halfHeight * 2,
+ )
+ painter.drawImage(dest, image, pixmapRect)
+
+ painter.scale(scale, scale)
+
+ llist = []
+ if self.showLine:
+ linePath = qt.QPainterPath()
+ linePath.moveTo(0., 0.5)
+ linePath.lineTo(ratio, 0.5)
+ # linePath.lineTo(2.5, 0.5)
+ lineBrush = qt.QBrush(
+ self.lineColor if overrideColor is None else overrideColor)
+ linePen = qt.QPen(
+ lineBrush,
+ (self.lineWidth / self.height()),
+ self.lineStyle,
+ qt.Qt.FlatCap
+ )
+ llist.append((linePath, linePen, lineBrush))
+
+ isValidSymbol = (len(self.symbol) and
+ self.symbol not in _NoSymbols)
+ if self.showSymbol and isValidSymbol:
+ if self.symbolColormap is None:
+ # PITFALL ahead: Let this be a warning to others
+ # symbolPath = Symbols[self.symbol]
+ # Copy before translate! Dict is a mutable type
+ symbolPath = qt.QPainterPath(_Symbols[self.symbol])
+ symbolPath.translate(symbolOffset)
+ symbolBrush = qt.QBrush(
+ self.symbolColor if overrideColor is None else overrideColor,
+ self.symbolStyle)
+ symbolPen = qt.QPen(
+ self.symbolOutlineBrush, # Brush
+ 1. / self.height(), # Width
+ qt.Qt.SolidLine # Style
+ )
+ llist.append((symbolPath,
+ symbolPen,
+ symbolBrush))
+ else:
+ nbSymbols = int(ratio + 2)
+ for i in range(nbSymbols):
+ if self.isEnabled():
+ image = self.getColormapImage(self.symbolColormap)
+ else:
+ image = self.getGrayedColormapImage(self.symbolColormap)
+ pos = int((_COLORMAP_PIXMAP_SIZE / nbSymbols) * i)
+ pos = numpy.clip(pos, 0, _COLORMAP_PIXMAP_SIZE-1)
+ color = image.pixelColor(pos, 0)
+ delta = qt.QPointF(ratio * ((i - (nbSymbols-1)/2) / nbSymbols), 0)
+
+ symbolPath = qt.QPainterPath(_Symbols[self.symbol])
+ symbolPath.translate(symbolOffset + delta)
+ symbolBrush = qt.QBrush(color, self.symbolStyle)
+ symbolPen = qt.QPen(
+ self.symbolOutlineBrush, # Brush
+ 1. / self.height(), # Width
+ qt.Qt.SolidLine # Style
+ )
+ llist.append((symbolPath,
+ symbolPen,
+ symbolBrush))
+
+ # Draw
+ for path, pen, brush in llist:
+ path.translate(offset)
+ painter.setPen(pen)
+ painter.setBrush(brush)
+ painter.drawPath(path)
+
+ painter.restore()
+
+ # Helpers
+
+ @staticmethod
+ def isEmptySymbol(symbol):
+ """Returns True if this symbol description will result in an empty
+ symbol."""
+ return symbol in _NoSymbols
+
+ @staticmethod
+ def isEmptyLineStyle(lineStyle):
+ """Returns True if this line style description will result in an empty
+ line."""
+ return lineStyle in _NoLineStyle
+
+ @staticmethod
+ def _getColormapKey(colormap):
+ """
+ Returns the key used to store the image in the data storage
+ """
+ if isinstance(colormap, numpy.ndarray):
+ key = tuple(colormap)
+ else:
+ key = colormap
+ return key
+
+ @staticmethod
+ def getGrayedColormapImage(colormap):
+ """Return a grayed version image preview from a LUT name.
+
+ This images are cached into a global structure.
+
+ :param Union[str,numpy.ndarray] colormap: Description of the LUT
+ :rtype: qt.QImage
+ """
+ key = LegendIconWidget._getColormapKey(colormap)
+ grayKey = (key, "gray")
+ image = _colormapImage.get(grayKey, None)
+ if image is None:
+ image = LegendIconWidget.getColormapImage(colormap)
+ image = image.convertToFormat(qt.QImage.Format_Grayscale8)
+ _colormapImage[grayKey] = image
+ return image
+
+ @staticmethod
+ def getColormapImage(colormap):
+ """Return an image preview from a LUT name.
+
+ This images are cached into a global structure.
+
+ :param Union[str,numpy.ndarray] colormap: Description of the LUT
+ :rtype: qt.QImage
+ """
+ key = LegendIconWidget._getColormapKey(colormap)
+ image = _colormapImage.get(key, None)
+ if image is None:
+ image = LegendIconWidget.createColormapImage(colormap)
+ _colormapImage[key] = image
+ return image
+
+ @staticmethod
+ def createColormapImage(colormap):
+ """Create and return an icon preview from a LUT name.
+
+ This icons are cached into a global structure.
+
+ :param Union[str,numpy.ndarray] colormap: Description of the LUT
+ :rtype: qt.QImage
+ """
+ size = _COLORMAP_PIXMAP_SIZE
+ if isinstance(colormap, numpy.ndarray):
+ lut = colormap
+ if len(lut) > size:
+ # Down sample
+ step = int(len(lut) / size)
+ lut = lut[::step]
+ elif len(lut) < size:
+ # Over sample
+ indexes = numpy.arange(size) / float(size) * (len(lut) - 1)
+ indexes = indexes.astype("int")
+ lut = lut[indexes]
+ else:
+ colormap = colors.Colormap(colormap)
+ lut = colormap.getNColors(size)
+
+ if lut is None or len(lut) == 0:
+ return qt.QIcon()
+
+ pixmap = qt.QPixmap(size, 1)
+ painter = qt.QPainter(pixmap)
+ for i in range(size):
+ rgb = lut[i]
+ r, g, b = rgb[0], rgb[1], rgb[2]
+ painter.setPen(qt.QColor(r, g, b))
+ painter.drawPoint(qt.QPoint(i, 0))
+ painter.end()
+ return pixmap.toImage()
diff --git a/src/silx/gui/widgets/MedianFilterDialog.py b/src/silx/gui/widgets/MedianFilterDialog.py
new file mode 100644
index 0000000..dd4a00d
--- /dev/null
+++ b/src/silx/gui/widgets/MedianFilterDialog.py
@@ -0,0 +1,80 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+""" MedianFilterDialog
+Classes
+-------
+
+Widgets:
+
+ - :class:`MedianFilterDialog`
+"""
+
+__authors__ = ["H. Payno"]
+__license__ = "MIT"
+__date__ = "14/02/2017"
+
+
+import logging
+
+from silx.gui import qt
+
+
+_logger = logging.getLogger(__name__)
+
+class MedianFilterDialog(qt.QDialog):
+ """QDialog window featuring a :class:`BackgroundWidget`"""
+ sigFilterOptChanged = qt.Signal(int, bool)
+
+ def __init__(self, parent=None):
+ qt.QDialog.__init__(self, parent)
+
+ self.setWindowTitle("Median filter options")
+ self.mainLayout = qt.QHBoxLayout(self)
+ self.setLayout(self.mainLayout)
+
+ # filter width GUI
+ self.mainLayout.addWidget(qt.QLabel('filter width:', parent = self))
+ self._filterWidth = qt.QSpinBox(parent=self)
+ self._filterWidth.setMinimum(1)
+ self._filterWidth.setValue(1)
+ self._filterWidth.setSingleStep(2);
+ widthTooltip = """radius width of the pixel including in the filter
+ for each pixel"""
+ self._filterWidth.setToolTip(widthTooltip)
+ self._filterWidth.valueChanged.connect(self._filterOptionChanged)
+ self.mainLayout.addWidget(self._filterWidth)
+
+ # filter option GUI
+ self._filterOption = qt.QCheckBox('conditional', parent=self)
+ conditionalTooltip = """if check, implement a conditional filter"""
+ self._filterOption.stateChanged.connect(self._filterOptionChanged)
+ self.mainLayout.addWidget(self._filterOption)
+
+ def _filterOptionChanged(self):
+ """Call back used when the filter values are changed"""
+ if self._filterWidth.value()%2 == 0:
+ _logger.warning('median filter only accept odd values')
+ else:
+ self.sigFilterOptChanged.emit(self._filterWidth.value(), self._filterOption.isChecked()) \ No newline at end of file
diff --git a/src/silx/gui/widgets/MultiModeAction.py b/src/silx/gui/widgets/MultiModeAction.py
new file mode 100644
index 0000000..502275d
--- /dev/null
+++ b/src/silx/gui/widgets/MultiModeAction.py
@@ -0,0 +1,83 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Action to hold many mode actions, usually for a tool bar.
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__data__ = "22/04/2020"
+
+
+from silx.gui import qt
+
+
+class MultiModeAction(qt.QWidgetAction):
+ """This action provides a default checkable action from a list of checkable
+ actions.
+
+ The default action can be selected from a drop down list. The last one used
+ became the default one.
+
+ The default action is directly usable without using the drop down list.
+ """
+
+ def __init__(self, parent=None):
+ assert isinstance(parent, qt.QWidget)
+ qt.QWidgetAction.__init__(self, parent)
+ button = qt.QToolButton(parent)
+ button.setPopupMode(qt.QToolButton.MenuButtonPopup)
+ self.setDefaultWidget(button)
+ self.__button = button
+
+ def getMenu(self):
+ """Returns the menu.
+
+ :rtype: qt.QMenu
+ """
+ button = self.__button
+ menu = button.menu()
+ if menu is None:
+ menu = qt.QMenu(button)
+ button.setMenu(menu)
+ return menu
+
+ def addAction(self, action):
+ """Add a new action to the list.
+
+ :param qt.QAction action: New action
+ """
+ menu = self.getMenu()
+ button = self.__button
+ menu.addAction(action)
+ if button.defaultAction() is None:
+ button.setDefaultAction(action)
+ if action.isCheckable():
+ action.toggled.connect(self._toggled)
+
+ def _toggled(self, checked):
+ if checked:
+ action = self.sender()
+ button = self.__button
+ button.setDefaultAction(action)
diff --git a/src/silx/gui/widgets/PeriodicTable.py b/src/silx/gui/widgets/PeriodicTable.py
new file mode 100644
index 0000000..6fed109
--- /dev/null
+++ b/src/silx/gui/widgets/PeriodicTable.py
@@ -0,0 +1,831 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Periodic table widgets
+
+Classes
+-------
+
+Widgets:
+
+ - :class:`PeriodicTable`
+ - :class:`PeriodicList`
+ - :class:`PeriodicCombo`
+
+Data model:
+
+ - :class:`PeriodicTableItem`
+ - :class:`ColoredPeriodicTableItem`
+
+
+Example of usage
+----------------
+
+This example uses the widgets with the standard builtin elements list.
+
+.. code-block:: python
+
+ from silx.gui import qt
+ from silx.gui.widgets.PeriodicTable import PeriodicTable, \
+ PeriodicCombo, PeriodicList
+
+ a = qt.QApplication([])
+
+ w = qt.QTabWidget()
+
+ ptable = PeriodicTable(w, selectable=True)
+ pcombo = PeriodicCombo(w)
+ plist = PeriodicList(w)
+
+ w.addTab(ptable, "PeriodicTable")
+ w.addTab(plist, "PeriodicList")
+ w.addTab(pcombo, "PeriodicCombo")
+
+ ptable.setSelection(['H', 'Fe', 'Si'])
+ plist.setSelectedElements(['H', 'Be', 'F'])
+ pcombo.setSelection("Li")
+
+ def change_list(items):
+ print("New list selection:", [item.symbol for item in items])
+
+ def change_combo(item):
+ print("New combo selection:", item.symbol)
+
+ def click_table(item):
+ print("New table click:", item.symbol)
+
+ def change_table(items):
+ print("New table selection:", [item.symbol for item in items])
+
+ ptable.sigElementClicked.connect(click_table)
+ ptable.sigSelectionChanged.connect(change_table)
+ plist.sigSelectionChanged.connect(change_list)
+ pcombo.sigSelectionChanged.connect(change_combo)
+
+ w.show()
+ a.exec()
+
+
+The second example explains how to define custom elements.
+
+.. code-block:: python
+
+ from silx.gui import qt
+ from silx.gui.widgets.PeriodicTable import PeriodicTable, \
+ PeriodicCombo, PeriodicList
+ from silx.gui.widgets.PeriodicTable import PeriodicTableItem
+
+ # subclass PeriodicTableItem
+ class MyPeriodicTableItem(PeriodicTableItem):
+ "New item with added mass number and number of protons"
+ def __init__(self, symbol, Z, A, col, row, name, mass,
+ subcategory=""):
+ PeriodicTableItem.__init__(
+ self, symbol, Z, col, row, name, mass,
+ subcategory)
+
+ self.A = A
+ "Mass number (neutrons + protons)"
+
+ self.num_neutrons = A - Z
+ "Number of neutrons"
+
+ # build your list of elements
+ my_elements = [MyPeriodicTableItem("H", 1, 1, 1, 1, "hydrogen",
+ 1.00800, "diatomic nonmetal"),
+ MyPeriodicTableItem("He", 2, 4, 18, 1, "helium",
+ 4.0030, "noble gas"),
+ # etc ...
+ ]
+
+ app = qt.QApplication([])
+
+ ptable = PeriodicTable(elements=my_elements, selectable=True)
+ ptable.show()
+
+ def click_table(item):
+ "Callback function printing the mass number of clicked element"
+ print("New table click, mass number:", item.A)
+
+ ptable.sigElementClicked.connect(click_table)
+ app.exec()
+
+"""
+
+__authors__ = ["E. Papillon", "V.A. Sole", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "26/01/2017"
+
+from collections import OrderedDict
+import logging
+from silx.gui import qt
+
+_logger = logging.getLogger(__name__)
+
+# Symbol Atomic Number col row name mass subcategory
+_elements = [("H", 1, 1, 1, "hydrogen", 1.00800, "diatomic nonmetal"),
+ ("He", 2, 18, 1, "helium", 4.0030, "noble gas"),
+ ("Li", 3, 1, 2, "lithium", 6.94000, "alkali metal"),
+ ("Be", 4, 2, 2, "beryllium", 9.01200, "alkaline earth metal"),
+ ("B", 5, 13, 2, "boron", 10.8110, "metalloid"),
+ ("C", 6, 14, 2, "carbon", 12.0100, "polyatomic nonmetal"),
+ ("N", 7, 15, 2, "nitrogen", 14.0080, "diatomic nonmetal"),
+ ("O", 8, 16, 2, "oxygen", 16.0000, "diatomic nonmetal"),
+ ("F", 9, 17, 2, "fluorine", 19.0000, "diatomic nonmetal"),
+ ("Ne", 10, 18, 2, "neon", 20.1830, "noble gas"),
+ ("Na", 11, 1, 3, "sodium", 22.9970, "alkali metal"),
+ ("Mg", 12, 2, 3, "magnesium", 24.3200, "alkaline earth metal"),
+ ("Al", 13, 13, 3, "aluminium", 26.9700, "post transition metal"),
+ ("Si", 14, 14, 3, "silicon", 28.0860, "metalloid"),
+ ("P", 15, 15, 3, "phosphorus", 30.9750, "polyatomic nonmetal"),
+ ("S", 16, 16, 3, "sulphur", 32.0660, "polyatomic nonmetal"),
+ ("Cl", 17, 17, 3, "chlorine", 35.4570, "diatomic nonmetal"),
+ ("Ar", 18, 18, 3, "argon", 39.9440, "noble gas"),
+ ("K", 19, 1, 4, "potassium", 39.1020, "alkali metal"),
+ ("Ca", 20, 2, 4, "calcium", 40.0800, "alkaline earth metal"),
+ ("Sc", 21, 3, 4, "scandium", 44.9600, "transition metal"),
+ ("Ti", 22, 4, 4, "titanium", 47.9000, "transition metal"),
+ ("V", 23, 5, 4, "vanadium", 50.9420, "transition metal"),
+ ("Cr", 24, 6, 4, "chromium", 51.9960, "transition metal"),
+ ("Mn", 25, 7, 4, "manganese", 54.9400, "transition metal"),
+ ("Fe", 26, 8, 4, "iron", 55.8500, "transition metal"),
+ ("Co", 27, 9, 4, "cobalt", 58.9330, "transition metal"),
+ ("Ni", 28, 10, 4, "nickel", 58.6900, "transition metal"),
+ ("Cu", 29, 11, 4, "copper", 63.5400, "transition metal"),
+ ("Zn", 30, 12, 4, "zinc", 65.3800, "transition metal"),
+ ("Ga", 31, 13, 4, "gallium", 69.7200, "post transition metal"),
+ ("Ge", 32, 14, 4, "germanium", 72.5900, "metalloid"),
+ ("As", 33, 15, 4, "arsenic", 74.9200, "metalloid"),
+ ("Se", 34, 16, 4, "selenium", 78.9600, "polyatomic nonmetal"),
+ ("Br", 35, 17, 4, "bromine", 79.9200, "diatomic nonmetal"),
+ ("Kr", 36, 18, 4, "krypton", 83.8000, "noble gas"),
+ ("Rb", 37, 1, 5, "rubidium", 85.4800, "alkali metal"),
+ ("Sr", 38, 2, 5, "strontium", 87.6200, "alkaline earth metal"),
+ ("Y", 39, 3, 5, "yttrium", 88.9050, "transition metal"),
+ ("Zr", 40, 4, 5, "zirconium", 91.2200, "transition metal"),
+ ("Nb", 41, 5, 5, "niobium", 92.9060, "transition metal"),
+ ("Mo", 42, 6, 5, "molybdenum", 95.9500, "transition metal"),
+ ("Tc", 43, 7, 5, "technetium", 99.0000, "transition metal"),
+ ("Ru", 44, 8, 5, "ruthenium", 101.0700, "transition metal"),
+ ("Rh", 45, 9, 5, "rhodium", 102.9100, "transition metal"),
+ ("Pd", 46, 10, 5, "palladium", 106.400, "transition metal"),
+ ("Ag", 47, 11, 5, "silver", 107.880, "transition metal"),
+ ("Cd", 48, 12, 5, "cadmium", 112.410, "transition metal"),
+ ("In", 49, 13, 5, "indium", 114.820, "post transition metal"),
+ ("Sn", 50, 14, 5, "tin", 118.690, "post transition metal"),
+ ("Sb", 51, 15, 5, "antimony", 121.760, "metalloid"),
+ ("Te", 52, 16, 5, "tellurium", 127.600, "metalloid"),
+ ("I", 53, 17, 5, "iodine", 126.910, "diatomic nonmetal"),
+ ("Xe", 54, 18, 5, "xenon", 131.300, "noble gas"),
+ ("Cs", 55, 1, 6, "caesium", 132.910, "alkali metal"),
+ ("Ba", 56, 2, 6, "barium", 137.360, "alkaline earth metal"),
+ ("La", 57, 3, 6, "lanthanum", 138.920, "lanthanide"),
+ ("Ce", 58, 4, 9, "cerium", 140.130, "lanthanide"),
+ ("Pr", 59, 5, 9, "praseodymium", 140.920, "lanthanide"),
+ ("Nd", 60, 6, 9, "neodymium", 144.270, "lanthanide"),
+ ("Pm", 61, 7, 9, "promethium", 147.000, "lanthanide"),
+ ("Sm", 62, 8, 9, "samarium", 150.350, "lanthanide"),
+ ("Eu", 63, 9, 9, "europium", 152.000, "lanthanide"),
+ ("Gd", 64, 10, 9, "gadolinium", 157.260, "lanthanide"),
+ ("Tb", 65, 11, 9, "terbium", 158.930, "lanthanide"),
+ ("Dy", 66, 12, 9, "dysprosium", 162.510, "lanthanide"),
+ ("Ho", 67, 13, 9, "holmium", 164.940, "lanthanide"),
+ ("Er", 68, 14, 9, "erbium", 167.270, "lanthanide"),
+ ("Tm", 69, 15, 9, "thulium", 168.940, "lanthanide"),
+ ("Yb", 70, 16, 9, "ytterbium", 173.040, "lanthanide"),
+ ("Lu", 71, 17, 9, "lutetium", 174.990, "lanthanide"),
+ ("Hf", 72, 4, 6, "hafnium", 178.500, "transition metal"),
+ ("Ta", 73, 5, 6, "tantalum", 180.950, "transition metal"),
+ ("W", 74, 6, 6, "tungsten", 183.920, "transition metal"),
+ ("Re", 75, 7, 6, "rhenium", 186.200, "transition metal"),
+ ("Os", 76, 8, 6, "osmium", 190.200, "transition metal"),
+ ("Ir", 77, 9, 6, "iridium", 192.200, "transition metal"),
+ ("Pt", 78, 10, 6, "platinum", 195.090, "transition metal"),
+ ("Au", 79, 11, 6, "gold", 197.200, "transition metal"),
+ ("Hg", 80, 12, 6, "mercury", 200.610, "transition metal"),
+ ("Tl", 81, 13, 6, "thallium", 204.390, "post transition metal"),
+ ("Pb", 82, 14, 6, "lead", 207.210, "post transition metal"),
+ ("Bi", 83, 15, 6, "bismuth", 209.000, "post transition metal"),
+ ("Po", 84, 16, 6, "polonium", 209.000, "post transition metal"),
+ ("At", 85, 17, 6, "astatine", 210.000, "metalloid"),
+ ("Rn", 86, 18, 6, "radon", 222.000, "noble gas"),
+ ("Fr", 87, 1, 7, "francium", 223.000, "alkali metal"),
+ ("Ra", 88, 2, 7, "radium", 226.000, "alkaline earth metal"),
+ ("Ac", 89, 3, 7, "actinium", 227.000, "actinide"),
+ ("Th", 90, 4, 10, "thorium", 232.000, "actinide"),
+ ("Pa", 91, 5, 10, "proactinium", 231.03588, "actinide"),
+ ("U", 92, 6, 10, "uranium", 238.070, "actinide"),
+ ("Np", 93, 7, 10, "neptunium", 237.000, "actinide"),
+ ("Pu", 94, 8, 10, "plutonium", 239.100, "actinide"),
+ ("Am", 95, 9, 10, "americium", 243, "actinide"),
+ ("Cm", 96, 10, 10, "curium", 247, "actinide"),
+ ("Bk", 97, 11, 10, "berkelium", 247, "actinide"),
+ ("Cf", 98, 12, 10, "californium", 251, "actinide"),
+ ("Es", 99, 13, 10, "einsteinium", 252, "actinide"),
+ ("Fm", 100, 14, 10, "fermium", 257, "actinide"),
+ ("Md", 101, 15, 10, "mendelevium", 258, "actinide"),
+ ("No", 102, 16, 10, "nobelium", 259, "actinide"),
+ ("Lr", 103, 17, 10, "lawrencium", 262, "actinide"),
+ ("Rf", 104, 4, 7, "rutherfordium", 261, "transition metal"),
+ ("Db", 105, 5, 7, "dubnium", 262, "transition metal"),
+ ("Sg", 106, 6, 7, "seaborgium", 266, "transition metal"),
+ ("Bh", 107, 7, 7, "bohrium", 264, "transition metal"),
+ ("Hs", 108, 8, 7, "hassium", 269, "transition metal"),
+ ("Mt", 109, 9, 7, "meitnerium", 268)]
+
+
+class PeriodicTableItem(object):
+ """Periodic table item, used as generic item in :class:`PeriodicTable`,
+ :class:`PeriodicCombo` and :class:`PeriodicList`.
+
+ This implementation stores the minimal amount of information needed by the
+ widgets:
+
+ - atomic symbol
+ - atomic number
+ - element name
+ - atomic mass
+ - column of element in periodic table
+ - row of element in periodic table
+
+ You can subclass this class to add additional information.
+
+ :param str symbol: Atomic symbol (e.g. H, He, Li...)
+ :param int Z: Proton number
+ :param int col: 1-based column index of element in periodic table
+ :param int row: 1-based row index of element in periodic table
+ :param str name: PeriodicTableItem name ("hydrogen", ...)
+ :param float mass: Atomic mass (gram per mol)
+ :param str subcategory: Subcategory, based on physical properties
+ (e.g. "alkali metal", "noble gas"...)
+ """
+ def __init__(self, symbol, Z, col, row, name, mass,
+ subcategory=""):
+ self.symbol = symbol
+ """Atomic symbol (e.g. H, He, Li...)"""
+ self.Z = Z
+ """Atomic number (Proton number)"""
+ self.col = col
+ """1-based column index of element in periodic table"""
+ self.row = row
+ """1-based row index of element in periodic table"""
+ self.name = name
+ """PeriodicTableItem name ("hydrogen", ...)"""
+ self.mass = mass
+ """Atomic mass (gram per mol)"""
+ self.subcategory = subcategory
+ """Subcategory, based on physical properties
+ (e.g. "alkali metal", "noble gas"...)"""
+
+ # pymca compatibility (elements used to be stored as a list of lists)
+ def __getitem__(self, idx):
+ if idx == 6:
+ _logger.warning("density not implemented in silx, returning 0.")
+
+ ret = [self.symbol, self.Z,
+ self.col, self.row,
+ self.name, self.mass,
+ 0.]
+ return ret[idx]
+
+ def __len__(self):
+ return 6
+
+
+class ColoredPeriodicTableItem(PeriodicTableItem):
+ """:class:`PeriodicTableItem` with an added :attr:`bgcolor`.
+ The background color can be passed as a parameter to the constructor.
+ If it is not specified, it will be defined based on
+ :attr:`subcategory`.
+
+ :param str bgcolor: Custom background color for element in
+ periodic table, as a RGB string *#RRGGBB*"""
+ COLORS = {
+ "diatomic nonmetal": "#7FFF00", # chartreuse
+ "noble gas": "#00FFFF", # cyan
+ "alkali metal": "#FFE4B5", # Moccasin
+ "alkaline earth metal": "#FFA500", # orange
+ "polyatomic nonmetal": "#7FFFD4", # aquamarine
+ "transition metal": "#FFA07A", # light salmon
+ "metalloid": "#8FBC8F", # Dark Sea Green
+ "post transition metal": "#D3D3D3", # light gray
+ "lanthanide": "#FFB6C1", # light pink
+ "actinide": "#F08080", # Light Coral
+ "": "#FFFFFF" # white
+ }
+ """Dictionary defining RGB colors for each subcategory."""
+
+ def __init__(self, symbol, Z, col, row, name, mass,
+ subcategory="", bgcolor=None):
+ PeriodicTableItem.__init__(self, symbol, Z, col, row, name, mass,
+ subcategory)
+
+ self.bgcolor = self.COLORS.get(subcategory, "#FFFFFF")
+ """Background color of element in the periodic table,
+ based on its subcategory. This should be a string of a hexadecimal
+ RGB code, with the format *#RRGGBB*.
+ If the subcategory is unknown, use white (*#FFFFFF*)
+ """
+
+ # possible custom color
+ if bgcolor is not None:
+ self.bgcolor = bgcolor
+
+
+_defaultTableItems = [ColoredPeriodicTableItem(*info) for info in _elements]
+
+
+class _ElementButton(qt.QPushButton):
+ """Atomic element button, used as a cell in the periodic table
+ """
+ sigElementEnter = qt.pyqtSignal(object)
+ """Signal emitted as the cursor enters the widget"""
+ sigElementLeave = qt.pyqtSignal(object)
+ """Signal emitted as the cursor leaves the widget"""
+ sigElementClicked = qt.pyqtSignal(object)
+ """Signal emitted when the widget is clicked"""
+
+ def __init__(self, item, parent=None):
+ """
+
+ :param parent: Parent widget
+ :param PeriodicTableItem item: :class:`PeriodicTableItem` object
+ """
+ qt.QPushButton.__init__(self, parent)
+
+ self.item = item
+ """:class:`PeriodicTableItem` object represented by this button"""
+
+ self.setText(item.symbol)
+ self.setFlat(1)
+ self.setCheckable(0)
+
+ self.setSizePolicy(qt.QSizePolicy(qt.QSizePolicy.Expanding,
+ qt.QSizePolicy.Expanding))
+
+ self.selected = False
+ self.current = False
+
+ # selection colors
+ self.selected_color = qt.QColor(qt.Qt.yellow)
+ self.current_color = qt.QColor(qt.Qt.gray)
+ self.selected_current_color = qt.QColor(qt.Qt.darkYellow)
+
+ # element colors
+
+ if hasattr(item, "bgcolor"):
+ self.bgcolor = qt.QColor(item.bgcolor)
+ else:
+ self.bgcolor = qt.QColor("#FFFFFF")
+
+ self.brush = qt.QBrush()
+ self.__setBrush()
+
+ self.clicked.connect(self.clickedSlot)
+
+ def sizeHint(self):
+ return qt.QSize(40, 40)
+
+ def setCurrent(self, b):
+ """Set this element button as current.
+ Multiple buttons can be selected.
+
+ :param b: boolean
+ """
+ self.current = b
+ self.__setBrush()
+
+ def isCurrent(self):
+ """
+ :return: True if element button is current
+ """
+ return self.current
+
+ def isSelected(self):
+ """
+ :return: True if element button is selected
+ """
+ return self.selected
+
+ def setSelected(self, b):
+ """Set this element button as selected.
+ Only a single button can be selected.
+
+ :param b: boolean
+ """
+ self.selected = b
+ self.__setBrush()
+
+ def __setBrush(self):
+ """Selected cells are yellow when not current.
+ The current cell is dark yellow when selected or grey when not
+ selected.
+ Other cells have no bg color by default, unless specified at
+ instantiation (:attr:`bgcolor`)"""
+ palette = self.palette()
+ # if self.current and self.selected:
+ # self.brush = qt.QBrush(self.selected_current_color)
+ # el
+ if self.selected:
+ self.brush = qt.QBrush(self.selected_color)
+ # elif self.current:
+ # self.brush = qt.QBrush(self.current_color)
+ elif self.bgcolor is not None:
+ self.brush = qt.QBrush(self.bgcolor)
+ else:
+ self.brush = qt.QBrush()
+ palette.setBrush(self.backgroundRole(),
+ self.brush)
+ self.setPalette(palette)
+ self.update()
+
+ def paintEvent(self, pEvent):
+ # get button geometry
+ widgGeom = self.rect()
+ paintGeom = qt.QRect(widgGeom.left() + 1,
+ widgGeom.top() + 1,
+ widgGeom.width() - 2,
+ widgGeom.height() - 2)
+
+ # paint background color
+ painter = qt.QPainter(self)
+ if self.brush is not None:
+ painter.fillRect(paintGeom, self.brush)
+ # paint frame
+ pen = qt.QPen(qt.Qt.black)
+ pen.setWidth(1 if not self.isCurrent() else 5)
+ painter.setPen(pen)
+ painter.drawRect(paintGeom)
+ painter.end()
+ qt.QPushButton.paintEvent(self, pEvent)
+
+ def enterEvent(self, e):
+ """Emit a :attr:`sigElementEnter` signal and send a
+ :class:`PeriodicTableItem` object"""
+ self.sigElementEnter.emit(self.item)
+
+ def leaveEvent(self, e):
+ """Emit a :attr:`sigElementLeave` signal and send a
+ :class:`PeriodicTableItem` object"""
+ self.sigElementLeave.emit(self.item)
+
+ def clickedSlot(self):
+ """Emit a :attr:`sigElementClicked` signal and send a
+ :class:`PeriodicTableItem` object"""
+ self.sigElementClicked.emit(self.item)
+
+
+class PeriodicTable(qt.QWidget):
+ """Periodic Table widget
+
+ .. image:: img/PeriodicTable.png
+
+ The following example shows how to connect clicking to selection::
+
+ from silx.gui import qt
+ from silx.gui.widgets.PeriodicTable import PeriodicTable
+ app = qt.QApplication([])
+ pt = PeriodicTable()
+ pt.sigElementClicked.connect(pt.elementToggle)
+ pt.show()
+ app.exec()
+
+ To print all selected elements each time a new element is selected::
+
+ def my_slot(item):
+ pt.elementToggle(item)
+ selected_elements = pt.getSelection()
+ for e in selected_elements:
+ print(e.symbol)
+
+ pt.sigElementClicked.connect(my_slot)
+
+ """
+ sigElementClicked = qt.pyqtSignal(object)
+ """When any element is clicked in the table, the widget emits
+ this signal and sends a :class:`PeriodicTableItem` object.
+ """
+
+ sigSelectionChanged = qt.pyqtSignal(object)
+ """When any element is selected/unselected in the table, the widget emits
+ this signal and sends a list of :class:`PeriodicTableItem` objects.
+
+ .. note::
+
+ To enable selection of elements, you must set *selectable=True*
+ when you instantiate the widget. Alternatively, you can also connect
+ :attr:`sigElementClicked` to :meth:`elementToggle` manually::
+
+ pt = PeriodicTable()
+ pt.sigElementClicked.connect(pt.elementToggle)
+
+
+ :param parent: parent QWidget
+ :param str name: Widget window title
+ :param elements: List of items (:class:`PeriodicTableItem` objects) to
+ be represented in the table. By default, take elements from
+ a predefined list with minimal information (symbol, atomic number,
+ name, mass).
+ :param bool selectable: If *True*, multiple elements can be
+ selected by clicking with the mouse. If *False* (default),
+ selection is only possible with method :meth:`setSelection`.
+ """
+
+ def __init__(self, parent=None, name="PeriodicTable", elements=None,
+ selectable=False):
+ self.selectable = selectable
+ qt.QWidget.__init__(self, parent)
+ self.setWindowTitle(name)
+ self.gridLayout = qt.QGridLayout(self)
+ self.gridLayout.setContentsMargins(0, 0, 0, 0)
+ self.gridLayout.addItem(qt.QSpacerItem(0, 5), 7, 0)
+
+ for idx in range(10):
+ self.gridLayout.setRowStretch(idx, 3)
+ # row 8 (above lanthanoids is empty)
+ self.gridLayout.setRowStretch(7, 2)
+
+ # Element information displayed when cursor enters a cell
+ self.eltLabel = qt.QLabel(self)
+ f = self.eltLabel.font()
+ f.setBold(1)
+ self.eltLabel.setFont(f)
+ self.eltLabel.setAlignment(qt.Qt.AlignHCenter)
+ self.gridLayout.addWidget(self.eltLabel, 1, 1, 3, 10)
+
+ self._eltCurrent = None
+ """Current :class:`_ElementButton` (last clicked)"""
+
+ self._eltButtons = OrderedDict()
+ """Dictionary of all :class:`_ElementButton`. Keys are the symbols
+ ("H", "He", "Li"...)"""
+
+ if elements is None:
+ elements = _defaultTableItems
+ # fill cells with elements
+ for elmt in elements:
+ self.__addElement(elmt)
+
+ def __addElement(self, elmt):
+ """Add one :class:`_ElementButton` widget into the grid,
+ connect its signals to interact with the cursor"""
+ b = _ElementButton(elmt, self)
+ b.setAutoDefault(False)
+
+ self._eltButtons[elmt.symbol] = b
+ self.gridLayout.addWidget(b, elmt.row, elmt.col)
+
+ b.sigElementEnter.connect(self.elementEnter)
+ b.sigElementLeave.connect(self._elementLeave)
+ b.sigElementClicked.connect(self._elementClicked)
+
+ def elementEnter(self, item):
+ """Update label with element info (e.g. "Nb(41) - niobium")
+ when mouse cursor hovers an element.
+
+ :param PeriodicTableItem item: Element entered by cursor
+ """
+ self.eltLabel.setText("%s(%d) - %s" % (item.symbol, item.Z, item.name))
+
+ def _elementLeave(self, item):
+ """Clear label when the cursor leaves the cell
+
+ :param PeriodicTableItem item: Element left
+ """
+ self.eltLabel.setText("")
+
+ def _elementClicked(self, item):
+ """Emit :attr:`sigElementClicked`,
+ toggle selected state of element
+
+ :param PeriodicTableItem item: Element clicked
+ """
+ if self._eltCurrent is not None:
+ self._eltCurrent.setCurrent(False)
+ self._eltButtons[item.symbol].setCurrent(True)
+ self._eltCurrent = self._eltButtons[item.symbol]
+ if self.selectable:
+ self.elementToggle(item)
+ self.sigElementClicked.emit(item)
+
+ def getSelection(self):
+ """Return a list of selected elements, as a list of :class:`PeriodicTableItem`
+ objects.
+
+ :return: Selected items
+ :rtype: List[PeriodicTableItem]
+ """
+ return [b.item for b in self._eltButtons.values() if b.isSelected()]
+
+ def setSelection(self, symbols):
+ """Set selected elements.
+
+ This causes the sigSelectionChanged signal
+ to be emitted, even if the selection didn't actually change.
+
+ :param List[str] symbols: List of symbols of elements to be selected
+ (e.g. *["Fe", "Hg", "Li"]*)
+ """
+ # accept list of PeriodicTableItems as input, because getSelection
+ # returns these objects and it makes sense to have getter and setter
+ # use same type of data
+ if isinstance(symbols[0], PeriodicTableItem):
+ symbols = [elmt.symbol for elmt in symbols]
+
+ for (e, b) in self._eltButtons.items():
+ b.setSelected(e in symbols)
+ self.sigSelectionChanged.emit(self.getSelection())
+
+ def setElementSelected(self, symbol, state):
+ """Modify *selected* status of a single element (select or unselect)
+
+ :param str symbol: PeriodicTableItem symbol to be selected
+ :param bool state: *True* to select, *False* to unselect
+ """
+ self._eltButtons[symbol].setSelected(state)
+ self.sigSelectionChanged.emit(self.getSelection())
+
+ def isElementSelected(self, symbol):
+ """Return *True* if element is selected, else *False*
+
+ :param str symbol: PeriodicTableItem symbol
+ :return: *True* if element is selected, else *False*
+ """
+ return self._eltButtons[symbol].isSelected()
+
+ def elementToggle(self, item):
+ """Toggle selected/unselected state for element
+
+ :param item: PeriodicTableItem object
+ """
+ b = self._eltButtons[item.symbol]
+ b.setSelected(not b.isSelected())
+ self.sigSelectionChanged.emit(self.getSelection())
+
+
+class PeriodicCombo(qt.QComboBox):
+ """
+ Combo list with all atomic elements of the periodic table
+
+ .. image:: img/PeriodicCombo.png
+
+ :param bool detailed: True (default) display element symbol, Z and name.
+ False display only element symbol and Z.
+ :param elements: List of items (:class:`PeriodicTableItem` objects) to
+ be represented in the table. By default, take elements from
+ a predefined list with minimal information (symbol, atomic number,
+ name, mass).
+ """
+ sigSelectionChanged = qt.pyqtSignal(object)
+ """Signal emitted when the selection changes. Send
+ :class:`PeriodicTableItem` object representing selected
+ element
+ """
+
+ def __init__(self, parent=None, detailed=True, elements=None):
+ qt.QComboBox.__init__(self, parent)
+
+ # add all elements from global list
+ if elements is None:
+ elements = _defaultTableItems
+ for i, elmt in enumerate(elements):
+ if detailed:
+ txt = "%2s (%d) - %s" % (elmt.symbol, elmt.Z, elmt.name)
+ else:
+ txt = "%2s (%d)" % (elmt.symbol, elmt.Z)
+ self.insertItem(i, txt)
+
+ self.currentIndexChanged[int].connect(self.__selectionChanged)
+
+ def __selectionChanged(self, idx):
+ """Emit :attr:`sigSelectionChanged`"""
+ self.sigSelectionChanged.emit(_defaultTableItems[idx])
+
+ def getSelection(self):
+ """Get selected element
+
+ :return: Selected element
+ :rtype: PeriodicTableItem
+ """
+ return _defaultTableItems[self.currentIndex()]
+
+ def setSelection(self, symbol):
+ """Set selected item in combobox by giving the atomic symbol
+
+ :param symbol: Symbol of element to be selected
+ """
+ # accept PeriodicTableItem for getter/setter consistency
+ if isinstance(symbol, PeriodicTableItem):
+ symbol = symbol.symbol
+ symblist = [elmt.symbol for elmt in _defaultTableItems]
+ self.setCurrentIndex(symblist.index(symbol))
+
+
+class PeriodicList(qt.QTreeWidget):
+ """List of atomic elements in a :class:`QTreeView`
+
+ .. image:: img/PeriodicList.png
+
+ :param QWidget parent: Parent widget
+ :param bool detailed: True (default) display element symbol, Z and name.
+ False display only element symbol and Z.
+ :param single: *True* for single element selection with mouse click,
+ *False* for multiple element selection mode.
+ """
+ sigSelectionChanged = qt.pyqtSignal(object)
+ """When any element is selected/unselected in the widget, it emits
+ this signal and sends a list of currently selected
+ :class:`PeriodicTableItem` objects.
+ """
+
+ def __init__(self, parent=None, detailed=True, single=False, elements=None):
+ qt.QTreeWidget.__init__(self, parent)
+
+ self.detailed = detailed
+
+ headers = ["Z", "Symbol"]
+ if detailed:
+ headers.append("Name")
+ self.setColumnCount(3)
+ else:
+ self.setColumnCount(2)
+ self.setHeaderLabels(headers)
+ self.header().setStretchLastSection(False)
+
+ self.setRootIsDecorated(0)
+ self.itemClicked.connect(self.__selectionChanged)
+ self.setSelectionMode(qt.QAbstractItemView.SingleSelection if single
+ else qt.QAbstractItemView.ExtendedSelection)
+ self.__fill_widget(elements)
+ self.resizeColumnToContents(0)
+ self.resizeColumnToContents(1)
+ if detailed:
+ self.resizeColumnToContents(2)
+
+ def __fill_widget(self, elements):
+ """Fill tree widget with elements """
+ if elements is None:
+ elements = _defaultTableItems
+
+ self.tree_items = []
+
+ previous_item = None
+ for elmt in elements:
+ if previous_item is None:
+ item = qt.QTreeWidgetItem(self)
+ else:
+ item = qt.QTreeWidgetItem(self, previous_item)
+ item.setText(0, str(elmt.Z))
+ item.setText(1, elmt.symbol)
+ if self.detailed:
+ item.setText(2, elmt.name)
+ self.tree_items.append(item)
+ previous_item = item
+
+ def __selectionChanged(self, treeItem, column):
+ """Emit a :attr:`sigSelectionChanged` and send a list of
+ :class:`PeriodicTableItem` objects."""
+ self.sigSelectionChanged.emit(self.getSelection())
+
+ def getSelection(self):
+ """Get a list of selected elements, as a list of :class:`PeriodicTableItem`
+ objects.
+
+ :return: Selected elements
+ :rtype: List[PeriodicTableItem]"""
+ return [_defaultTableItems[idx] for idx in range(len(self.tree_items))
+ if self.tree_items[idx].isSelected()]
+
+ # setSelection is a bad name (name of a QTreeWidget method)
+ def setSelectedElements(self, symbolList):
+ """
+
+ :param symbolList: List of atomic symbols ["H", "He", "Li"...]
+ to be selected in the widget
+ """
+ # accept PeriodicTableItem for getter/setter consistency
+ if isinstance(symbolList[0], PeriodicTableItem):
+ symbolList = [elmt.symbol for elmt in symbolList]
+ for idx in range(len(self.tree_items)):
+ self.tree_items[idx].setSelected(_defaultTableItems[idx].symbol in symbolList)
diff --git a/src/silx/gui/widgets/PrintGeometryDialog.py b/src/silx/gui/widgets/PrintGeometryDialog.py
new file mode 100644
index 0000000..98ff8d1
--- /dev/null
+++ b/src/silx/gui/widgets/PrintGeometryDialog.py
@@ -0,0 +1,222 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+
+
+from silx.gui import qt
+from silx.gui.widgets.FloatEdit import FloatEdit
+
+
+class PrintGeometryWidget(qt.QWidget):
+ """Widget to specify the size and aspect ratio of an item
+ before sending it to the print preview dialog.
+
+ Use methods :meth:`setPrintGeometry` and :meth:`getPrintGeometry`
+ to interact with the widget.
+ """
+ def __init__(self, parent=None):
+ super(PrintGeometryWidget, self).__init__(parent)
+ self.mainLayout = qt.QGridLayout(self)
+ self.mainLayout.setContentsMargins(0, 0, 0, 0)
+ self.mainLayout.setSpacing(2)
+ hbox = qt.QWidget(self)
+ hboxLayout = qt.QHBoxLayout(hbox)
+ hboxLayout.setContentsMargins(0, 0, 0, 0)
+ hboxLayout.setSpacing(2)
+ label = qt.QLabel(self)
+ label.setText("Units")
+ label.setAlignment(qt.Qt.AlignCenter)
+ self._pageButton = qt.QRadioButton()
+ self._pageButton.setText("Page")
+ self._inchButton = qt.QRadioButton()
+ self._inchButton.setText("Inches")
+ self._cmButton = qt.QRadioButton()
+ self._cmButton.setText("Centimeters")
+ self._buttonGroup = qt.QButtonGroup(self)
+ self._buttonGroup.addButton(self._pageButton)
+ self._buttonGroup.addButton(self._inchButton)
+ self._buttonGroup.addButton(self._cmButton)
+ self._buttonGroup.setExclusive(True)
+
+ # units
+ self.mainLayout.addWidget(label, 0, 0, 1, 4)
+ hboxLayout.addWidget(self._pageButton)
+ hboxLayout.addWidget(self._inchButton)
+ hboxLayout.addWidget(self._cmButton)
+ self.mainLayout.addWidget(hbox, 1, 0, 1, 4)
+ self._pageButton.setChecked(True)
+
+ # xOffset
+ label = qt.QLabel(self)
+ label.setText("X Offset:")
+ self.mainLayout.addWidget(label, 2, 0)
+ self._xOffset = FloatEdit(self, 0.1)
+ self.mainLayout.addWidget(self._xOffset, 2, 1)
+
+ # yOffset
+ label = qt.QLabel(self)
+ label.setText("Y Offset:")
+ self.mainLayout.addWidget(label, 2, 2)
+ self._yOffset = FloatEdit(self, 0.1)
+ self.mainLayout.addWidget(self._yOffset, 2, 3)
+
+ # width
+ label = qt.QLabel(self)
+ label.setText("Width:")
+ self.mainLayout.addWidget(label, 3, 0)
+ self._width = FloatEdit(self, 0.9)
+ self.mainLayout.addWidget(self._width, 3, 1)
+
+ # height
+ label = qt.QLabel(self)
+ label.setText("Height:")
+ self.mainLayout.addWidget(label, 3, 2)
+ self._height = FloatEdit(self, 0.9)
+ self.mainLayout.addWidget(self._height, 3, 3)
+
+ # aspect ratio
+ self._aspect = qt.QCheckBox(self)
+ self._aspect.setText("Keep screen aspect ratio")
+ self._aspect.setChecked(True)
+ self.mainLayout.addWidget(self._aspect, 4, 1, 1, 2)
+
+ def getPrintGeometry(self):
+ """Return the print geometry dictionary.
+
+ See :meth:`setPrintGeometry` for documentation about the
+ print geometry dictionary."""
+ ddict = {}
+ if self._inchButton.isChecked():
+ ddict['units'] = "inches"
+ elif self._cmButton.isChecked():
+ ddict['units'] = "centimeters"
+ else:
+ ddict['units'] = "page"
+
+ ddict['xOffset'] = self._xOffset.value()
+ ddict['yOffset'] = self._yOffset.value()
+ ddict['width'] = self._width.value()
+ ddict['height'] = self._height.value()
+
+ if self._aspect.isChecked():
+ ddict['keepAspectRatio'] = True
+ else:
+ ddict['keepAspectRatio'] = False
+ return ddict
+
+ def setPrintGeometry(self, geometry=None):
+ """Set the print geometry.
+
+ The geometry parameters must be provided as a dictionary with
+ the following keys:
+
+ - *"xOffset"* (float)
+ - *"yOffset"* (float)
+ - *"width"* (float)
+ - *"height"* (float)
+ - *"units"*: possible values *"page", "inch", "cm"*
+ - *"keepAspectRatio"*: *True* or *False*
+
+ If *units* is *"page"*, the values should be floats in [0, 1.]
+ and are interpreted as a fraction of the page width or height.
+
+ :param dict geometry: Geometry parameters, as a dictionary."""
+ if geometry is None:
+ geometry = {}
+ oldDict = self.getPrintGeometry()
+ for key in ["units", "xOffset", "yOffset",
+ "width", "height", "keepAspectRatio"]:
+ geometry[key] = geometry.get(key, oldDict[key])
+
+ if geometry['units'].lower().startswith("inc"):
+ self._inchButton.setChecked(True)
+ elif geometry['units'].lower().startswith("c"):
+ self._cmButton.setChecked(True)
+ else:
+ self._pageButton.setChecked(True)
+
+ self._xOffset.setText("%s" % float(geometry['xOffset']))
+ self._yOffset.setText("%s" % float(geometry['yOffset']))
+ self._width.setText("%s" % float(geometry['width']))
+ self._height.setText("%s" % float(geometry['height']))
+ if geometry['keepAspectRatio']:
+ self._aspect.setChecked(True)
+ else:
+ self._aspect.setChecked(False)
+
+
+class PrintGeometryDialog(qt.QDialog):
+ """Dialog embedding a :class:`PrintGeometryWidget`.
+
+ Use methods :meth:`setPrintGeometry` and :meth:`getPrintGeometry`
+ to interact with the widget.
+
+ Execute method :meth:`exec` to run the dialog.
+ The return value of that method is *True* if the geometry was set
+ (*Ok* button clicked) or *False* if the user clicked the *Cancel*
+ button.
+ """
+
+ def __init__(self, parent=None):
+ qt.QDialog.__init__(self, parent)
+ self.setWindowTitle("Set print size preferences")
+ layout = qt.QVBoxLayout(self)
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.setSpacing(0)
+ self.configurationWidget = PrintGeometryWidget(self)
+ hbox = qt.QWidget(self)
+ hboxLayout = qt.QHBoxLayout(hbox)
+ self.okButton = qt.QPushButton(hbox)
+ self.okButton.setText("Accept")
+ self.okButton.setAutoDefault(False)
+ self.rejectButton = qt.QPushButton(hbox)
+ self.rejectButton.setText("Dismiss")
+ self.rejectButton.setAutoDefault(False)
+ self.okButton.clicked.connect(self.accept)
+ self.rejectButton.clicked.connect(self.reject)
+ hboxLayout.setContentsMargins(0, 0, 0, 0)
+ hboxLayout.setSpacing(2)
+ # hboxLayout.addWidget(qt.HorizontalSpacer(hbox))
+ hboxLayout.addWidget(self.okButton)
+ hboxLayout.addWidget(self.rejectButton)
+ # hboxLayout.addWidget(qt.HorizontalSpacer(hbox))
+ layout.addWidget(self.configurationWidget)
+ layout.addWidget(hbox)
+
+ def setPrintGeometry(self, geometry):
+ """Return the print geometry dictionary.
+
+ See :meth:`PrintGeometryWidget.setPrintGeometry` for documentation on
+ print geometry dictionary.
+
+ :param dict geometry: Print geometry parameters dictionary.
+ """
+ self.configurationWidget.setPrintGeometry(geometry)
+
+ def getPrintGeometry(self):
+ """Return the print geometry dictionary.
+
+ See :meth:`PrintGeometryWidget.setPrintGeometry` for documentation on
+ print geometry dictionary."""
+ return self.configurationWidget.getPrintGeometry()
diff --git a/src/silx/gui/widgets/PrintPreview.py b/src/silx/gui/widgets/PrintPreview.py
new file mode 100644
index 0000000..53e0a1f
--- /dev/null
+++ b/src/silx/gui/widgets/PrintPreview.py
@@ -0,0 +1,697 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module implements a print preview dialog.
+
+The dialog provides methods to send images, pixmaps and SVG
+items to the page to be printed.
+
+The user can interactively move and resize the items.
+"""
+import sys
+import logging
+from silx.gui import qt, printer
+
+
+__authors__ = ["V.A. Sole", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "11/07/2017"
+
+
+_logger = logging.getLogger(__name__)
+
+
+class PrintPreviewDialog(qt.QDialog):
+ """Print preview dialog widget.
+ """
+ def __init__(self, parent=None, printer=None):
+
+ qt.QDialog.__init__(self, parent)
+ self.setWindowTitle("Print Preview")
+ self.setModal(False)
+ self.resize(400, 500)
+
+ self.mainLayout = qt.QVBoxLayout(self)
+ self.mainLayout.setContentsMargins(0, 0, 0, 0)
+ self.mainLayout.setSpacing(0)
+
+ self._buildToolbar()
+
+ self.printer = printer
+ # :class:`QPrinter` (paint device that paints on a printer).
+ # :meth:`showEvent` has been reimplemented to enforce printer
+ # setup.
+
+ self.printDialog = None
+ # :class:`QPrintDialog` (dialog for specifying the printer's
+ # configuration)
+
+ self.scene = None
+ # :class:`QGraphicsScene` (surface for managing
+ # 2D graphical items)
+
+ self.page = None
+ # :class:`QGraphicsRectItem` used as white background page on which
+ # to display the print preview.
+
+ self.view = None
+ # :class:`QGraphicsView` widget for displaying :attr:`scene`
+
+ self._svgItems = []
+ # List storing :class:`QSvgRenderer` items to be printed, added in
+ # :meth:`addSvgItem`, cleared in :meth:`_clearAll`.
+ # This ensures that there is a reference pointing to the items,
+ # which ensures they are not destroyed before being printed.
+
+ self._viewScale = 1.0
+ # Zoom level (1.0 is 100%)
+
+ self._toBeCleared = False
+ # Flag indicating that all items must be removed from :attr:`scene`
+ # and from :attr:`_svgItems`.
+ # Set to True after a successful printing. The widget is then hidden,
+ # and it will be cleared the next time it is shown.
+ # Reset to False after :meth:`_clearAll` has done its job.
+
+ def _buildToolbar(self):
+ toolBar = qt.QWidget(self)
+ # a layout for the toolbar
+ toolsLayout = qt.QHBoxLayout(toolBar)
+ toolsLayout.setContentsMargins(0, 0, 0, 0)
+ toolsLayout.setSpacing(0)
+
+ hideBut = qt.QPushButton("Hide", toolBar)
+ hideBut.setToolTip("Hide print preview dialog")
+ hideBut.clicked.connect(self.hide)
+
+ cancelBut = qt.QPushButton("Clear All", toolBar)
+ cancelBut.setToolTip("Remove all items")
+ cancelBut.clicked.connect(self._clearAll)
+
+ removeBut = qt.QPushButton("Remove",
+ toolBar)
+ removeBut.setToolTip("Remove selected item (use left click to select)")
+ removeBut.clicked.connect(self._remove)
+
+ setupBut = qt.QPushButton("Setup", toolBar)
+ setupBut.setToolTip("Select and configure a printer")
+ setupBut.clicked.connect(self.setup)
+
+ printBut = qt.QPushButton("Print", toolBar)
+ printBut.setToolTip("Print page and close print preview")
+ printBut.clicked.connect(self._print)
+
+ zoomPlusBut = qt.QPushButton("Zoom +", toolBar)
+ zoomPlusBut.clicked.connect(self._zoomPlus)
+
+ zoomMinusBut = qt.QPushButton("Zoom -", toolBar)
+ zoomMinusBut.clicked.connect(self._zoomMinus)
+
+ toolsLayout.addWidget(hideBut)
+ toolsLayout.addWidget(printBut)
+ toolsLayout.addWidget(cancelBut)
+ toolsLayout.addWidget(removeBut)
+ toolsLayout.addWidget(setupBut)
+ # toolsLayout.addStretch()
+ # toolsLayout.addWidget(marginLabel)
+ # toolsLayout.addWidget(self.marginSpin)
+ toolsLayout.addStretch()
+ # toolsLayout.addWidget(scaleLabel)
+ # toolsLayout.addWidget(self.scaleCombo)
+ toolsLayout.addWidget(zoomPlusBut)
+ toolsLayout.addWidget(zoomMinusBut)
+ # toolsLayout.addStretch()
+ self.toolBar = toolBar
+ self.mainLayout.addWidget(self.toolBar)
+
+ def _buildStatusBar(self):
+ """Create the status bar used to display the printer name
+ or output file name."""
+ # status bar
+ statusBar = qt.QStatusBar(self)
+ self.targetLabel = qt.QLabel(statusBar)
+ self._updateTargetLabel()
+ statusBar.addWidget(self.targetLabel)
+ self.mainLayout.addWidget(statusBar)
+
+ def _updateTargetLabel(self):
+ """Update printer name or file name shown in the status bar."""
+ if self.printer is None:
+ self.targetLabel.setText("Undefined printer")
+ return
+ if self.printer.outputFileName():
+ self.targetLabel.setText("File:" +
+ self.printer.outputFileName())
+ else:
+ self.targetLabel.setText("Printer:" +
+ self.printer.printerName())
+
+ def _updatePrinter(self):
+ """Resize :attr:`page`, :attr:`scene` and :attr:`view` to :attr:`printer`
+ width and height."""
+ printer = self.printer
+ assert printer is not None, \
+ "_updatePrinter should not be called unless a printer is defined"
+ if self.scene is None:
+ self.scene = qt.QGraphicsScene()
+ self.scene.setBackgroundBrush(qt.QColor(qt.Qt.lightGray))
+ self.scene.setSceneRect(qt.QRectF(0, 0, printer.width(), printer.height()))
+
+ if self.page is None:
+ self.page = qt.QGraphicsRectItem(0, 0, printer.width(), printer.height())
+ self.page.setBrush(qt.QColor(qt.Qt.white))
+ self.scene.addItem(self.page)
+
+ self.scene.setSceneRect(qt.QRectF(0, 0, printer.width(), printer.height()))
+ self.page.setPos(qt.QPointF(0.0, 0.0))
+ self.page.setRect(qt.QRectF(0, 0, printer.width(), printer.height()))
+
+ if self.view is None:
+ self.view = qt.QGraphicsView(self.scene)
+ self.mainLayout.addWidget(self.view)
+ self._buildStatusBar()
+ # self.view.scale(1./self._viewScale, 1./self._viewScale)
+ self.view.fitInView(self.page.rect(), qt.Qt.KeepAspectRatio)
+ self._viewScale = 1.00
+ self._updateTargetLabel()
+
+ # Public methods
+ def addImage(self, image, title=None, comment=None, commentPosition=None):
+ """Add an image to the print preview scene.
+
+ :param QImage image: Image to be added to the scene
+ :param str title: Title shown above (centered) the image
+ :param str comment: Comment displayed below the image
+ :param commentPosition: "CENTER" or "LEFT"
+ """
+ self.addPixmap(qt.QPixmap.fromImage(image),
+ title=title, comment=comment,
+ commentPosition=commentPosition)
+
+ def addPixmap(self, pixmap, title=None, comment=None, commentPosition=None):
+ """Add a pixmap to the print preview scene
+
+ :param QPixmap pixmap: Pixmap to be added to the scene
+ :param str title: Title shown above (centered) the pixmap
+ :param str comment: Comment displayed below the pixmap
+ :param commentPosition: "CENTER" or "LEFT"
+ """
+ if self._toBeCleared:
+ self._clearAll()
+ self.ensurePrinterIsSet()
+ if self.printer is None:
+ _logger.error("printer is not set, cannot add pixmap to page")
+ return
+ if title is None:
+ title = ' ' * 88
+ if comment is None:
+ comment = ' ' * 88
+ if commentPosition is None:
+ commentPosition = "CENTER"
+ rectItem = qt.QGraphicsRectItem(self.page)
+ rectItem.setRect(qt.QRectF(1, 1,
+ pixmap.width(), pixmap.height()))
+
+ pen = rectItem.pen()
+ color = qt.QColor(qt.Qt.red)
+ color.setAlpha(1)
+ pen.setColor(color)
+ rectItem.setPen(pen)
+ rectItem.setZValue(1)
+ rectItem.setFlag(qt.QGraphicsItem.ItemIsSelectable, True)
+ rectItem.setFlag(qt.QGraphicsItem.ItemIsMovable, True)
+ rectItem.setFlag(qt.QGraphicsItem.ItemIsFocusable, False)
+
+ rectItemResizeRect = _GraphicsResizeRectItem(rectItem, self.scene)
+ rectItemResizeRect.setZValue(2)
+
+ pixmapItem = qt.QGraphicsPixmapItem(rectItem)
+ pixmapItem.setPixmap(pixmap)
+ pixmapItem.setZValue(0)
+
+ # I add the title
+ textItem = qt.QGraphicsTextItem(title, rectItem)
+ textItem.setTextInteractionFlags(qt.Qt.TextEditorInteraction)
+ offset = 0.5 * textItem.boundingRect().width()
+ textItem.moveBy(0.5 * pixmap.width() - offset, -20)
+ textItem.setZValue(2)
+
+ # I add the comment
+ commentItem = qt.QGraphicsTextItem(comment, rectItem)
+ commentItem.setTextInteractionFlags(qt.Qt.TextEditorInteraction)
+ offset = 0.5 * commentItem.boundingRect().width()
+ if commentPosition.upper() == "LEFT":
+ x = 1
+ else:
+ x = 0.5 * pixmap.width() - offset
+ commentItem.moveBy(x, pixmap.height() + 20)
+ commentItem.setZValue(2)
+
+ rectItem.moveBy(20, 40)
+
+ def addSvgItem(self, item, title=None,
+ comment=None, commentPosition=None,
+ viewBox=None, keepRatio=True):
+ """Add a SVG item to the scene.
+
+ :param QSvgRenderer item: SVG item to be added to the scene.
+ :param str title: Title shown above (centered) the SVG item.
+ :param str comment: Comment displayed below the SVG item.
+ :param str commentPosition: "CENTER" or "LEFT"
+ :param QRectF viewBox: Bounding box for the item on the print page
+ (xOffset, yOffset, width, height). If None, use original
+ item size.
+ :param bool keepRatio: If True, resizing the item will preserve its
+ original aspect ratio.
+ """
+ if not qt.HAS_SVG:
+ raise RuntimeError("Missing QtSvg library.")
+ if not isinstance(item, qt.QSvgRenderer):
+ raise TypeError("addSvgItem: QSvgRenderer expected")
+ if self._toBeCleared:
+ self._clearAll()
+ self.ensurePrinterIsSet()
+ if self.printer is None:
+ _logger.error("printer is not set, cannot add SvgItem to page")
+ return
+
+ if title is None:
+ title = 50 * ' '
+ if comment is None:
+ comment = 80 * ' '
+ if commentPosition is None:
+ commentPosition = "CENTER"
+
+ if viewBox is None:
+ if hasattr(item, "_viewBox"):
+ # PyMca compatibility: viewbox attached to item
+ viewBox = item._viewBox
+ else:
+ # try the original item viewbox
+ viewBox = item.viewBoxF()
+
+ svgItem = _GraphicsSvgRectItem(viewBox, self.page)
+ svgItem.setSvgRenderer(item)
+
+ svgItem.setCacheMode(qt.QGraphicsItem.NoCache)
+ svgItem.setZValue(0)
+ svgItem.setFlag(qt.QGraphicsItem.ItemIsSelectable, True)
+ svgItem.setFlag(qt.QGraphicsItem.ItemIsMovable, True)
+ svgItem.setFlag(qt.QGraphicsItem.ItemIsFocusable, False)
+
+ rectItemResizeRect = _GraphicsResizeRectItem(svgItem, self.scene,
+ keepratio=keepRatio)
+ rectItemResizeRect.setZValue(2)
+
+ self._svgItems.append(item)
+
+ # Comment / legend
+ dummyComment = 80 * "1"
+ commentItem = qt.QGraphicsTextItem(dummyComment, svgItem)
+ commentItem.setTextInteractionFlags(qt.Qt.TextEditorInteraction)
+ # we scale the text to have the legend box have the same width as the graph
+ scaleCalculationRect = qt.QRectF(commentItem.boundingRect())
+ scale = svgItem.boundingRect().width() / scaleCalculationRect.width()
+
+ commentItem.setPlainText(comment)
+ commentItem.setZValue(1)
+
+ commentItem.setFlag(qt.QGraphicsItem.ItemIsMovable, True)
+ commentItem.setScale(scale)
+
+ # align
+ if commentPosition.upper() == "CENTER":
+ alignment = qt.Qt.AlignCenter
+ elif commentPosition.upper() == "RIGHT":
+ alignment = qt.Qt.AlignRight
+ else:
+ alignment = qt.Qt.AlignLeft
+ commentItem.setTextWidth(commentItem.boundingRect().width())
+ center_format = qt.QTextBlockFormat()
+ center_format.setAlignment(alignment)
+ cursor = commentItem.textCursor()
+ cursor.select(qt.QTextCursor.Document)
+ cursor.mergeBlockFormat(center_format)
+ cursor.clearSelection()
+ commentItem.setTextCursor(cursor)
+ if alignment == qt.Qt.AlignLeft:
+ deltax = 0
+ else:
+ deltax = (svgItem.boundingRect().width() - commentItem.boundingRect().width()) / 2.
+ commentItem.moveBy(svgItem.boundingRect().x() + deltax,
+ svgItem.boundingRect().y() + svgItem.boundingRect().height())
+
+ # Title
+ textItem = qt.QGraphicsTextItem(title, svgItem)
+ textItem.setTextInteractionFlags(qt.Qt.TextEditorInteraction)
+ textItem.setZValue(1)
+ textItem.setFlag(qt.QGraphicsItem.ItemIsMovable, True)
+
+ title_offset = 0.5 * textItem.boundingRect().width()
+ textItem.moveBy(svgItem.boundingRect().x() +
+ 0.5 * svgItem.boundingRect().width() - title_offset * scale,
+ svgItem.boundingRect().y())
+ textItem.setScale(scale)
+
+ def setup(self):
+ """Open a print dialog to ensure the :attr:`printer` is set.
+
+ If the setting fails or is cancelled, :attr:`printer` is reset to
+ *None*.
+ """
+ if self.printer is None:
+ self.printer = printer.getDefaultPrinter()
+ if self.printDialog is None:
+ self.printDialog = qt.QPrintDialog(self.printer, self)
+ if self.printDialog.exec():
+ if self.printer.width() <= 0 or self.printer.height() <= 0:
+ self.message = qt.QMessageBox(self)
+ self.message.setIcon(qt.QMessageBox.Critical)
+ self.message.setText("Unknown library error \non printer initialization")
+ self.message.setWindowTitle("Library Error")
+ self.message.setModal(0)
+ self.printer = None
+ return
+ self.printer.setFullPage(True)
+ self._updatePrinter()
+ else:
+ # printer setup cancelled, check for a possible previous configuration
+ if self.page is None:
+ # not initialized
+ self.printer = None
+
+ def ensurePrinterIsSet(self):
+ """If the printer is not already set, try to interactively
+ setup the printer using a QPrintDialog.
+ In case of failure, hide widget and log a warning.
+
+ :return: True if printer was set. False if it failed or if the
+ selection dialog was canceled.
+ """
+ if self.printer is None:
+ self.setup()
+ if self.printer is None:
+ self.hide()
+ _logger.warning("Printer setup failed or was cancelled, " +
+ "but printer is required.")
+ return self.printer is not None
+
+ def setOutputFileName(self, name):
+ """Set output filename.
+
+ Setting a non-empty name enables printing to file.
+
+ :param str name: File name (path)"""
+ self.printer.setOutputFileName(name)
+
+ # overloaded methods
+ def exec(self):
+ if self._toBeCleared:
+ self._clearAll()
+ return qt.QDialog.exec(self)
+
+ def exec_(self): # Qt5 compatibility
+ return self.exec()
+
+ def raise_(self):
+ if self._toBeCleared:
+ self._clearAll()
+ return qt.QDialog.raise_(self)
+
+ def showEvent(self, event):
+ """Reimplemented to force printer setup.
+ In case of failure, hide the widget."""
+ if self._toBeCleared:
+ self._clearAll()
+ self.ensurePrinterIsSet()
+
+ return super(PrintPreviewDialog, self).showEvent(event)
+
+ # button callbacks
+ def _print(self):
+ """Do the printing, hide the print preview dialog,
+ set :attr:`_toBeCleared` flag to True to trigger clearing the
+ next time the dialog is shown.
+
+ If the printer is not setup, do it first."""
+ printer = self.printer
+
+ painter = qt.QPainter()
+ if not painter.begin(printer) or printer is None:
+ _logger.error("Cannot initialize printer")
+ return
+ try:
+ self.scene.render(painter, qt.QRectF(0, 0, printer.width(), printer.height()),
+ qt.QRectF(self.page.rect().x(), self.page.rect().y(),
+ self.page.rect().width(), self.page.rect().height()),
+ qt.Qt.KeepAspectRatio)
+ painter.end()
+ self.hide()
+ self.accept()
+ self._toBeCleared = True
+ except: # FIXME
+ painter.end()
+ qt.QMessageBox.critical(self, "ERROR",
+ 'Printing problem:\n %s' % sys.exc_info()[1])
+ _logger.error('printing problem:\n %s' % sys.exc_info()[1])
+ return
+
+ def _zoomPlus(self):
+ self._viewScale *= 1.20
+ self.view.scale(1.20, 1.20)
+
+ def _zoomMinus(self):
+ self._viewScale *= 0.80
+ self.view.scale(0.80, 0.80)
+
+ def _clearAll(self):
+ """
+ Clear the print preview window, remove all items
+ but keep the page.
+ """
+ itemlist = self.scene.items()
+ keep = self.page
+ while len(itemlist) != 1:
+ if itemlist.index(keep) == 0:
+ self.scene.removeItem(itemlist[1])
+ else:
+ self.scene.removeItem(itemlist[0])
+ itemlist = self.scene.items()
+ self._svgItems = []
+ self._toBeCleared = False
+
+ def _remove(self):
+ """Remove selected item in :attr:`scene`.
+ """
+ itemlist = self.scene.items()
+
+ # this loop is not efficient if there are many items ...
+ for item in itemlist:
+ if item.isSelected():
+ self.scene.removeItem(item)
+
+
+class SingletonPrintPreviewDialog(PrintPreviewDialog):
+ """Singleton print preview dialog.
+
+ All widgets in a program that instantiate this class will share
+ a single print preview dialog. This enables sending
+ multiple images to a single page to be printed.
+ """
+ _instance = None
+
+ def __new__(self, *var, **kw):
+ if self._instance is None:
+ self._instance = PrintPreviewDialog(*var, **kw)
+ return self._instance
+
+
+class _GraphicsSvgRectItem(qt.QGraphicsRectItem):
+ """:class:`qt.QGraphicsRectItem` with an attached
+ :class:`qt.QSvgRenderer`, and with a painter redefined to render
+ the SVG item."""
+ def setSvgRenderer(self, renderer):
+ """
+
+ :param QSvgRenderer renderer: svg renderer
+ """
+ self._renderer = renderer
+
+ def paint(self, painter, *var, **kw):
+ self._renderer.render(painter, self.boundingRect())
+
+
+class _GraphicsResizeRectItem(qt.QGraphicsRectItem):
+ """Resizable QGraphicsRectItem."""
+ def __init__(self, parent=None, scene=None, keepratio=True):
+ qt.QGraphicsRectItem.__init__(self, parent)
+ rect = parent.boundingRect()
+ x = rect.x()
+ y = rect.y()
+ w = rect.width()
+ h = rect.height()
+ self._newRect = None
+ self.keepRatio = keepratio
+ self.setRect(qt.QRectF(x + w - 40, y + h - 40, 40, 40))
+ self.setAcceptHoverEvents(True)
+ pen = qt.QPen()
+ color = qt.QColor(qt.Qt.white)
+ color.setAlpha(0)
+ pen.setColor(color)
+ pen.setStyle(qt.Qt.NoPen)
+ self.setPen(pen)
+ self.setBrush(color)
+ self.setFlag(self.ItemIsMovable, True)
+ self.show()
+
+ def hoverEnterEvent(self, event):
+ if self.parentItem().isSelected():
+ self.parentItem().setSelected(False)
+ if self.keepRatio:
+ self.setCursor(qt.QCursor(qt.Qt.SizeFDiagCursor))
+ else:
+ self.setCursor(qt.QCursor(qt.Qt.SizeAllCursor))
+ self.setBrush(qt.QBrush(qt.Qt.yellow, qt.Qt.SolidPattern))
+ return qt.QGraphicsRectItem.hoverEnterEvent(self, event)
+
+ def hoverLeaveEvent(self, event):
+ self.setCursor(qt.QCursor(qt.Qt.ArrowCursor))
+ pen = qt.QPen()
+ color = qt.QColor(qt.Qt.white)
+ color.setAlpha(0)
+ pen.setColor(color)
+ pen.setStyle(qt.Qt.NoPen)
+ self.setPen(pen)
+ self.setBrush(color)
+ return qt.QGraphicsRectItem.hoverLeaveEvent(self, event)
+
+ def mousePressEvent(self, event):
+ if self._newRect is not None:
+ self._newRect = None
+ self._point0 = self.pos()
+ parent = self.parentItem()
+ scene = self.scene()
+ # following line prevents dragging along the previously selected
+ # item when resizing another one
+ scene.clearSelection()
+
+ rect = parent.boundingRect()
+ self._x = rect.x()
+ self._y = rect.y()
+ self._w = rect.width()
+ self._h = rect.height()
+ self._ratio = self._w / self._h
+ self._newRect = qt.QGraphicsRectItem(parent)
+ self._newRect.setRect(qt.QRectF(self._x,
+ self._y,
+ self._w,
+ self._h))
+ qt.QGraphicsRectItem.mousePressEvent(self, event)
+
+ def mouseMoveEvent(self, event):
+ point1 = self.pos()
+ deltax = point1.x() - self._point0.x()
+ deltay = point1.y() - self._point0.y()
+ if self.keepRatio:
+ r1 = (self._w + deltax) / self._w
+ r2 = (self._h + deltay) / self._h
+ if r1 < r2:
+ self._newRect.setRect(qt.QRectF(self._x,
+ self._y,
+ self._w + deltax,
+ (self._w + deltax) / self._ratio))
+ else:
+ self._newRect.setRect(qt.QRectF(self._x,
+ self._y,
+ (self._h + deltay) * self._ratio,
+ self._h + deltay))
+ else:
+ self._newRect.setRect(qt.QRectF(self._x,
+ self._y,
+ self._w + deltax,
+ self._h + deltay))
+ qt.QGraphicsRectItem.mouseMoveEvent(self, event)
+
+ def mouseReleaseEvent(self, event):
+ point1 = self.pos()
+ deltax = point1.x() - self._point0.x()
+ deltay = point1.y() - self._point0.y()
+ self.moveBy(-deltax, -deltay)
+ parent = self.parentItem()
+
+ # deduce scale from rectangle
+ if self.keepRatio:
+ scalex = self._newRect.rect().width() / self._w
+ scaley = scalex
+ else:
+ scalex = self._newRect.rect().width() / self._w
+ scaley = self._newRect.rect().height() / self._h
+
+ # apply the scale to the previous transformation matrix
+ previousTransform = parent.transform()
+ parent.setTransform(
+ previousTransform.scale(scalex, scaley))
+
+ self.scene().removeItem(self._newRect)
+ self._newRect = None
+ qt.QGraphicsRectItem.mouseReleaseEvent(self, event)
+
+
+def main():
+ """
+ """
+ if len(sys.argv) < 2:
+ print("give an image file as parameter please.")
+ sys.exit(1)
+
+ if len(sys.argv) > 2:
+ print("only one parameter please.")
+ sys.exit(1)
+
+ filename = sys.argv[1]
+ w = PrintPreviewDialog()
+ w.resize(400, 500)
+
+ comment = ""
+ for i in range(20):
+ comment += "Line number %d: En un lugar de La Mancha de cuyo nombre ...\n" % i
+
+ if filename[-3:] == "svg":
+ item = qt.QSvgRenderer(filename, w.page)
+ w.addSvgItem(item, title=filename,
+ comment=comment, commentPosition="CENTER")
+ else:
+ w.addPixmap(qt.QPixmap.fromImage(qt.QImage(filename)),
+ title=filename,
+ comment=comment,
+ commentPosition="CENTER")
+ w.addImage(qt.QImage(filename), comment=comment, commentPosition="LEFT")
+
+ sys.exit(w.exec())
+
+
+if __name__ == '__main__':
+ a = qt.QApplication(sys.argv)
+ main()
+ a.exec()
diff --git a/src/silx/gui/widgets/RangeSlider.py b/src/silx/gui/widgets/RangeSlider.py
new file mode 100644
index 0000000..61b73fc
--- /dev/null
+++ b/src/silx/gui/widgets/RangeSlider.py
@@ -0,0 +1,776 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a :class:`RangeSlider` widget.
+
+.. image:: img/RangeSlider.png
+ :align: center
+"""
+from __future__ import absolute_import, division
+
+__authors__ = ["D. Naudet", "T. Vincent"]
+__license__ = "MIT"
+__date__ = "26/11/2018"
+
+
+import numpy as numpy
+
+from silx.gui import qt, icons, colors
+from silx.gui.utils.image import convertArrayToQImage
+
+
+class StyleOptionRangeSlider(qt.QStyleOption):
+ def __init__(self):
+ super(StyleOptionRangeSlider, self).__init__()
+ self.minimum = None
+ self.maximum = None
+ self.sliderPosition1 = None
+ self.sliderPosition2 = None
+ self.handlerRect1 = None
+ self.handlerRect2 = None
+
+
+class RangeSlider(qt.QWidget):
+ """Range slider with 2 thumbs and an optional colored groove.
+
+ The position of the slider thumbs can be retrieved either as values
+ in the slider range or as a number of steps or pixels.
+
+ :param QWidget parent: See QWidget
+ """
+
+ _SLIDER_WIDTH = 10
+ """Width of the slider rectangle"""
+
+ _PIXMAP_VOFFSET = 7
+ """Vertical groove pixmap offset"""
+
+ sigRangeChanged = qt.Signal(float, float)
+ """Signal emitted when the value range has changed.
+
+ It provides the new range (min, max).
+ """
+
+ sigValueChanged = qt.Signal(float, float)
+ """Signal emitted when the value of the sliders has changed.
+
+ It provides the slider values (first, second).
+ """
+
+ sigPositionCountChanged = qt.Signal(object)
+ """This signal is emitted when the number of steps has changed.
+
+ It provides the new position count.
+ """
+
+ sigPositionChanged = qt.Signal(int, int)
+ """Signal emitted when the position of the sliders has changed.
+
+ It provides the slider positions in steps or pixels (first, second).
+ """
+
+ def __init__(self, parent=None):
+ self.__pixmap = None
+ self.__positionCount = None
+ self.__firstValue = 0.
+ self.__secondValue = 1.
+ self.__minValue = 0.
+ self.__maxValue = 1.
+ self.__hoverRect = qt.QRect()
+ self.__hoverControl = None
+
+ self.__focus = None
+ self.__moving = None
+
+ self.__icons = {
+ 'first': icons.getQIcon('previous'),
+ 'second': icons.getQIcon('next')
+ }
+
+ # call the super constructor AFTER defining all members that
+ # are used in the "paint" method
+ super(RangeSlider, self).__init__(parent)
+
+ self.setFocusPolicy(qt.Qt.ClickFocus)
+ self.setAttribute(qt.Qt.WA_Hover)
+
+ self.setMinimumSize(qt.QSize(50, 20))
+ self.setMaximumHeight(20)
+
+ # Broadcast value changed signal
+ self.sigValueChanged.connect(self.__emitPositionChanged)
+
+ def event(self, event):
+ t = event.type()
+ if t == qt.QEvent.HoverEnter or t == qt.QEvent.HoverLeave or t == qt.QEvent.HoverMove:
+ return self.__updateHoverControl(event.pos())
+ else:
+ return super(RangeSlider, self).event(event)
+
+ def __updateHoverControl(self, pos):
+ hoverControl, hoverRect = self.__findHoverControl(pos)
+ if hoverControl != self.__hoverControl:
+ self.update(self.__hoverRect)
+ self.update(hoverRect)
+ self.__hoverControl = hoverControl
+ self.__hoverRect = hoverRect
+ return True
+ return hoverControl is not None
+
+ def __findHoverControl(self, pos):
+ """Returns the control at the position and it's rect location"""
+ for name in ["first", "second"]:
+ rect = self.__sliderRect(name)
+ if rect.contains(pos):
+ return name, rect
+ rect = self.__drawArea()
+ if rect.contains(pos):
+ return "groove", rect
+ return None, qt.QRect()
+
+ # Position <-> Value conversion
+
+ def __positionToValue(self, position):
+ """Returns value corresponding to position
+
+ :param int position:
+ :rtype: float
+ """
+ min_, max_ = self.getMinimum(), self.getMaximum()
+ maxPos = self.__getCurrentPositionCount() - 1
+ return min_ + (max_ - min_) * int(position) / maxPos
+
+ def __valueToPosition(self, value):
+ """Returns closest position corresponding to value
+
+ :param float value:
+ :rtype: int
+ """
+ min_, max_ = self.getMinimum(), self.getMaximum()
+ maxPos = self.__getCurrentPositionCount() - 1
+ return int(0.5 + maxPos * (float(value) - min_) / (max_ - min_))
+
+ # Position (int) API
+
+ def __getCurrentPositionCount(self):
+ """Return current count (either position count or widget width
+
+ :rtype: int
+ """
+ count = self.getPositionCount()
+ if count is not None:
+ return count
+ else:
+ return max(2, self.width() - self._SLIDER_WIDTH)
+
+ def getPositionCount(self):
+ """Returns the number of positions.
+
+ :rtype: Union[int,None]"""
+ return self.__positionCount
+
+ def setPositionCount(self, count):
+ """Set the number of positions.
+
+ Slider values are eventually adjusted.
+
+ :param Union[int,None] count:
+ Either the number of possible positions or
+ None to allow any values.
+ :raise ValueError: If count <= 1
+ """
+ count = None if count is None else int(count)
+ if count != self.getPositionCount():
+ if count is not None and count <= 1:
+ raise ValueError("Position count must be higher than 1")
+ self.__positionCount = count
+ emit = self.__setValues(*self.getValues())
+ self.sigPositionCountChanged.emit(count)
+ if emit:
+ self.sigValueChanged.emit(*self.getValues())
+
+ def getFirstPosition(self):
+ """Returns first slider position
+
+ :rtype: int
+ """
+ return self.__valueToPosition(self.getFirstValue())
+
+ def setFirstPosition(self, position):
+ """Set the position of the first slider
+
+ The position is adjusted to valid values
+
+ :param int position:
+ """
+ self.setFirstValue(self.__positionToValue(position))
+
+ def getSecondPosition(self):
+ """Returns second slider position
+
+ :rtype: int
+ """
+ return self.__valueToPosition(self.getSecondValue())
+
+ def setSecondPosition(self, position):
+ """Set the position of the second slider
+
+ The position is adjusted to valid values
+
+ :param int position:
+ """
+ self.setSecondValue(self.__positionToValue(position))
+
+ def getPositions(self):
+ """Returns slider positions (first, second)
+
+ :rtype: List[int]
+ """
+ return self.getFirstPosition(), self.getSecondPosition()
+
+ def setPositions(self, first, second):
+ """Set the position of both sliders at once
+
+ First is clipped to the slider range: [0, max].
+ Second is clipped to valid values: [first, max]
+
+ :param int first:
+ :param int second:
+ """
+ self.setValues(self.__positionToValue(first),
+ self.__positionToValue(second))
+
+ # Value (float) API
+
+ def __emitPositionChanged(self, *args, **kwargs):
+ self.sigPositionChanged.emit(*self.getPositions())
+
+ def __rangeChanged(self):
+ """Handle change of value range"""
+ emit = self.__setValues(*self.getValues())
+ self.sigRangeChanged.emit(*self.getRange())
+ if emit:
+ self.sigValueChanged.emit(*self.getValues())
+
+ def getMinimum(self):
+ """Returns the minimum value of the slider range
+
+ :rtype: float
+ """
+ return self.__minValue
+
+ def setMinimum(self, minimum):
+ """Set the minimum value of the slider range.
+
+ It eventually adjusts maximum.
+ Slider positions remains unchanged and slider values are modified.
+
+ :param float minimum:
+ :raises ValueError:
+ """
+ minimum = float(minimum)
+ if minimum == self.getMaximum():
+ raise ValueError("min and max must be different")
+
+ if minimum != self.getMinimum():
+ if minimum > self.getMaximum():
+ self.__maxValue = minimum
+ self.__minValue = minimum
+ self.__rangeChanged()
+
+ def getMaximum(self):
+ """Returns the maximum value of the slider range
+
+ :rtype: float
+ """
+ return self.__maxValue
+
+ def setMaximum(self, maximum):
+ """Set the maximum value of the slider range
+
+ It eventually adjusts minimum.
+ Slider positions remains unchanged and slider values are modified.
+
+ :param float maximum:
+ :raises ValueError:
+ """
+ maximum = float(maximum)
+ if maximum == self.getMinimum():
+ raise ValueError("min and max must be different")
+
+ if maximum != self.getMaximum():
+ if maximum < self.getMinimum():
+ self.__minValue = maximum
+ self.__maxValue = maximum
+ self.__rangeChanged()
+
+ def getRange(self):
+ """Returns the range of values (min, max)
+
+ :rtype: List[float]
+ """
+ return self.getMinimum(), self.getMaximum()
+
+ def setRange(self, minimum, maximum):
+ """Set the range of values.
+
+ If maximum is lower than minimum, minimum is the only valid value.
+ Slider positions remains unchanged and slider values are modified.
+
+ :param float minimum:
+ :param float maximum:
+ :raises ValueError:
+ """
+ minimum, maximum = float(minimum), float(maximum)
+ if minimum == maximum:
+ raise ValueError("min and max must be different")
+ if minimum != self.getMinimum() or maximum != self.getMaximum():
+ self.__minValue = minimum
+ self.__maxValue = max(maximum, minimum)
+ self.__rangeChanged()
+
+ def getFirstValue(self):
+ """Returns the value of the first slider
+
+ :rtype: float
+ """
+ return self.__firstValue
+
+ def __clipFirstValue(self, value, max_=None):
+ """Clip first value to range and steps
+
+ :param float value:
+ :param float max_: Alternative maximum to use
+ """
+ if max_ is None:
+ max_ = self.getSecondValue()
+ value = min(max(self.getMinimum(), float(value)), max_)
+ if self.getPositionCount() is not None: # Clip to steps
+ value = self.__positionToValue(self.__valueToPosition(value))
+ return value
+
+ def setFirstValue(self, value):
+ """Set the value of the first slider
+
+ Value is clipped to valid values.
+
+ :param float value:
+ """
+ value = self.__clipFirstValue(value)
+ if value != self.getFirstValue():
+ self.__firstValue = value
+ self.update()
+ self.sigValueChanged.emit(*self.getValues())
+
+ def getSecondValue(self):
+ """Returns the value of the second slider
+
+ :rtype: float
+ """
+ return self.__secondValue
+
+ def __clipSecondValue(self, value):
+ """Clip second value to range and steps
+
+ :param float value:
+ """
+ value = min(max(self.getFirstValue(), float(value)), self.getMaximum())
+ if self.getPositionCount() is not None: # Clip to steps
+ value = self.__positionToValue(self.__valueToPosition(value))
+ return value
+
+ def setSecondValue(self, value):
+ """Set the value of the second slider
+
+ Value is clipped to valid values.
+
+ :param float value:
+ """
+ value = self.__clipSecondValue(value)
+ if value != self.getSecondValue():
+ self.__secondValue = value
+ self.update()
+ self.sigValueChanged.emit(*self.getValues())
+
+ def getValues(self):
+ """Returns value of both sliders at once
+
+ :return: (first value, second value)
+ :rtype: List[float]
+ """
+ return self.getFirstValue(), self.getSecondValue()
+
+ def setValues(self, first, second):
+ """Set values for both sliders at once
+
+ First is clipped to the slider range: [minimum, maximum].
+ Second is clipped to valid values: [first, maximum]
+
+ :param float first:
+ :param float second:
+ """
+ if self.__setValues(first, second):
+ self.sigValueChanged.emit(*self.getValues())
+
+ def __setValues(self, first, second):
+ """Set values for both sliders at once
+
+ First is clipped to the slider range: [minimum, maximum].
+ Second is clipped to valid values: [first, maximum]
+
+ :param float first:
+ :param float second:
+ :return: True if values has changed, False otherwise
+ :rtype: bool
+ """
+ first = self.__clipFirstValue(first, self.getMaximum())
+ second = self.__clipSecondValue(second)
+ values = first, second
+
+ if self.getValues() != values:
+ self.__firstValue, self.__secondValue = values
+ self.update()
+ return True
+ return False
+
+ # Groove API
+
+ def getGroovePixmap(self):
+ """Returns the pixmap displayed in the slider groove if any.
+
+ :rtype: Union[QPixmap,None]
+ """
+ return self.__pixmap
+
+ def setGroovePixmap(self, pixmap):
+ """Set the pixmap displayed in the slider groove.
+
+ :param Union[QPixmap,None] pixmap: The QPixmap to use or None to unset.
+ """
+ assert pixmap is None or isinstance(pixmap, qt.QPixmap)
+ self.__pixmap = pixmap
+ self.update()
+
+ def setGroovePixmapFromProfile(self, profile, colormap=None):
+ """Set the pixmap displayed in the slider groove from histogram values.
+
+ :param Union[numpy.ndarray,None] profile:
+ 1D array of values to display
+ :param Union[~silx.gui.colors.Colormap,str] colormap:
+ The colormap name or object to convert profile values to colors
+ """
+ if profile is None:
+ self.setSliderPixmap(None)
+ return
+
+ profile = numpy.array(profile, copy=False)
+
+ if profile.size == 0:
+ self.setSliderPixmap(None)
+ return
+
+ if colormap is None:
+ colormap = colors.Colormap()
+ elif isinstance(colormap, str):
+ colormap = colors.Colormap(name=colormap)
+ assert isinstance(colormap, colors.Colormap)
+
+ rgbImage = colormap.applyToData(profile.reshape(1, -1))[:, :, :3]
+ qimage = convertArrayToQImage(rgbImage)
+ qpixmap = qt.QPixmap.fromImage(qimage)
+ self.setGroovePixmap(qpixmap)
+
+ # Handle interaction
+
+ def mousePressEvent(self, event):
+ super(RangeSlider, self).mousePressEvent(event)
+
+ if event.buttons() == qt.Qt.LeftButton:
+ picked = None
+ for name in ('first', 'second'):
+ area = self.__sliderRect(name)
+ if area.contains(event.pos()):
+ picked = name
+ break
+
+ self.__moving = picked
+ self.__focus = picked
+ self.update()
+
+ def mouseMoveEvent(self, event):
+ super(RangeSlider, self).mouseMoveEvent(event)
+
+ if self.__moving is not None:
+ delta = self._SLIDER_WIDTH // 2
+ if self.__moving == 'first':
+ position = self.__xPixelToPosition(event.pos().x() + delta)
+ self.setFirstPosition(position)
+ else:
+ position = self.__xPixelToPosition(event.pos().x() - delta)
+ self.setSecondPosition(position)
+
+ def mouseReleaseEvent(self, event):
+ super(RangeSlider, self).mouseReleaseEvent(event)
+
+ if event.button() == qt.Qt.LeftButton and self.__moving is not None:
+ self.__moving = None
+ self.update()
+
+ def focusOutEvent(self, event):
+ if self.__focus is not None:
+ self.__focus = None
+ self.update()
+ super(RangeSlider, self).focusOutEvent(event)
+
+ def keyPressEvent(self, event):
+ key = event.key()
+ if event.modifiers() == qt.Qt.NoModifier and self.__focus is not None:
+ if key in (qt.Qt.Key_Left, qt.Qt.Key_Down):
+ if self.__focus == 'first':
+ self.setFirstPosition(self.getFirstPosition() - 1)
+ else:
+ self.setSecondPosition(self.getSecondPosition() - 1)
+ return # accept event
+ elif key in (qt.Qt.Key_Right, qt.Qt.Key_Up):
+ if self.__focus == 'first':
+ self.setFirstPosition(self.getFirstPosition() + 1)
+ else:
+ self.setSecondPosition(self.getSecondPosition() + 1)
+ return # accept event
+
+ super(RangeSlider, self).keyPressEvent(event)
+
+ # Handle resize
+
+ def resizeEvent(self, event):
+ super(RangeSlider, self).resizeEvent(event)
+
+ # If no step, signal position update when width change
+ if (self.getPositionCount() is None and
+ event.size().width() != event.oldSize().width()):
+ self.sigPositionChanged.emit(*self.getPositions())
+
+ # Handle repaint
+
+ def __xPixelToPosition(self, x):
+ """Convert position in pixel to slider position
+
+ :param int x: X in pixel coordinates
+ :rtype: int
+ """
+ sliderArea = self.__sliderAreaRect()
+ maxPos = self.__getCurrentPositionCount() - 1
+ position = maxPos * (x - sliderArea.left()) / (sliderArea.width() - 1)
+ return int(position + 0.5)
+
+ def __sliderRect(self, name):
+ """Returns rectangle corresponding to slider in pixels
+
+ :param str name: 'first' or 'second'
+ :rtype: QRect
+ :raise ValueError: If wrong name
+ """
+ assert name in ('first', 'second')
+ if name == 'first':
+ offset = - self._SLIDER_WIDTH
+ position = self.getFirstPosition()
+ elif name == 'second':
+ offset = 0
+ position = self.getSecondPosition()
+ else:
+ raise ValueError('Unknown name')
+
+ sliderArea = self.__sliderAreaRect()
+
+ maxPos = self.__getCurrentPositionCount() - 1
+ xOffset = int((sliderArea.width() - 1) * position / maxPos)
+ xPos = sliderArea.left() + xOffset + offset
+
+ return qt.QRect(xPos,
+ sliderArea.top(),
+ self._SLIDER_WIDTH,
+ sliderArea.height())
+
+ def __drawArea(self):
+ return self.rect().adjusted(self._SLIDER_WIDTH, 0,
+ -self._SLIDER_WIDTH, 0)
+
+ def __sliderAreaRect(self):
+ return self.__drawArea().adjusted(self._SLIDER_WIDTH // 2,
+ 0,
+ -self._SLIDER_WIDTH // 2 + 1,
+ 0)
+
+ def __pixMapRect(self):
+ return self.__sliderAreaRect().adjusted(0,
+ self._PIXMAP_VOFFSET,
+ -1,
+ -self._PIXMAP_VOFFSET)
+
+ def paintEvent(self, event):
+ painter = qt.QPainter(self)
+
+ style = qt.QApplication.style()
+
+ area = self.__drawArea()
+ if self.__pixmap is not None:
+ pixmapRect = self.__pixMapRect()
+
+ option = qt.QStyleOptionProgressBar()
+ option.initFrom(self)
+ option.rect = area
+ option.state = (qt.QStyle.State_Enabled if self.isEnabled()
+ else qt.QStyle.State_None)
+ style.drawControl(qt.QStyle.CE_ProgressBarGroove,
+ option,
+ painter,
+ self)
+
+ painter.save()
+ pen = painter.pen()
+ pen.setWidth(1)
+ pen.setColor(qt.Qt.black if self.isEnabled() else qt.Qt.gray)
+ painter.setPen(pen)
+ painter.drawRect(pixmapRect.adjusted(-1, -1, 0, 1))
+ painter.restore()
+
+ if self.isEnabled():
+ rect = area.adjusted(self._SLIDER_WIDTH // 2,
+ self._PIXMAP_VOFFSET,
+ -self._SLIDER_WIDTH // 2,
+ -self._PIXMAP_VOFFSET + 1)
+ painter.drawPixmap(rect,
+ self.__pixmap,
+ self.__pixmap.rect())
+ else:
+ option = StyleOptionRangeSlider()
+ option.initFrom(self)
+ option.rect = area
+ option.sliderPosition1 = self.__firstValue
+ option.sliderPosition2 = self.__secondValue
+ option.handlerRect1 = self.__sliderRect("first")
+ option.handlerRect2 = self.__sliderRect("second")
+ option.minimum = self.__minValue
+ option.maximum = self.__maxValue
+ option.state = (qt.QStyle.State_Enabled if self.isEnabled()
+ else qt.QStyle.State_None)
+ if self.__hoverControl == "groove":
+ option.state |= qt.QStyle.State_MouseOver
+ elif option.state & qt.QStyle.State_MouseOver:
+ option.state ^= qt.QStyle.State_MouseOver
+ self.drawRangeSliderBackground(painter, option, self)
+
+ # Avoid glitch when moving handles
+ hoverControl = self.__moving or self.__hoverControl
+
+ for name in ('first', 'second'):
+ rect = self.__sliderRect(name)
+ option = qt.QStyleOptionButton()
+ option.initFrom(self)
+ option.icon = self.__icons[name]
+ option.iconSize = rect.size() * 0.7
+ if hoverControl == name:
+ option.state |= qt.QStyle.State_MouseOver
+ elif option.state & qt.QStyle.State_MouseOver:
+ option.state ^= qt.QStyle.State_MouseOver
+ if self.__focus == name:
+ option.state |= qt.QStyle.State_HasFocus
+ elif option.state & qt.QStyle.State_HasFocus:
+ option.state ^= qt.QStyle.State_HasFocus
+ option.rect = rect
+ style.drawControl(
+ qt.QStyle.CE_PushButton, option, painter, self)
+
+ def sizeHint(self):
+ return qt.QSize(200, self.minimumHeight())
+
+ @classmethod
+ def drawRangeSliderBackground(cls, painter, option, widget):
+ """Draw the background of the RangeSlider widget into the painter.
+
+ :param qt.QPainter painter: A painter
+ :param StyleOptionRangeSlider option: Options to draw the widget
+ :param qt.QWidget: The widget which have to be drawn
+ """
+ painter.save()
+ painter.translate(0.5, 0.5)
+
+ backgroundRect = qt.QRect(option.rect)
+ if backgroundRect.height() > 8:
+ center = backgroundRect.center()
+ backgroundRect.setHeight(8)
+ backgroundRect.moveCenter(center)
+
+ selectedRangeRect = qt.QRect(backgroundRect)
+ selectedRangeRect.setLeft(option.handlerRect1.center().x())
+ selectedRangeRect.setRight(option.handlerRect2.center().x())
+
+ highlight = option.palette.color(qt.QPalette.Highlight)
+ activeHighlight = highlight
+ selectedOutline = option.palette.color(qt.QPalette.Highlight)
+
+ buttonColor = option.palette.button().color()
+ val = qt.qGray(buttonColor.rgb())
+ buttonColor = buttonColor.lighter(100 + max(1, (180 - val) // 6))
+ buttonColor.setHsv(buttonColor.hue(), (buttonColor.saturation() * 3) // 4, buttonColor.value())
+
+ grooveColor = qt.QColor()
+ grooveColor.setHsv(buttonColor.hue(),
+ min(255, (int)(buttonColor.saturation())),
+ min(255, (int)(buttonColor.value() * 0.9)))
+
+ selectedInnerContrastLine = qt.QColor(255, 255, 255, 30)
+
+ outline = option.palette.color(qt.QPalette.Window).darker(140)
+ if (option.state & qt.QStyle.State_HasFocus and option.state & qt.QStyle.State_KeyboardFocusChange):
+ outline = highlight.darker(125)
+ if outline.value() > 160:
+ outline.setHsl(highlight.hue(), highlight.saturation(), 160)
+
+ # Draw background groove
+ painter.setRenderHint(qt.QPainter.Antialiasing, True)
+ gradient = qt.QLinearGradient()
+ gradient.setStart(backgroundRect.center().x(), backgroundRect.top())
+ gradient.setFinalStop(backgroundRect.center().x(), backgroundRect.bottom())
+ painter.setPen(qt.QPen(outline))
+ gradient.setColorAt(0, grooveColor.darker(110))
+ gradient.setColorAt(1, grooveColor.lighter(110))
+ painter.setBrush(gradient)
+ painter.drawRoundedRect(backgroundRect.adjusted(1, 1, -2, -2), 1, 1)
+
+ # Draw slider background for the value
+ gradient = qt.QLinearGradient()
+ gradient.setStart(selectedRangeRect.center().x(), selectedRangeRect.top())
+ gradient.setFinalStop(selectedRangeRect.center().x(), selectedRangeRect.bottom())
+ painter.setRenderHint(qt.QPainter.Antialiasing, True)
+ painter.setPen(qt.QPen(selectedOutline))
+ gradient.setColorAt(0, activeHighlight)
+ gradient.setColorAt(1, activeHighlight.lighter(130))
+ painter.setBrush(gradient)
+ painter.drawRoundedRect(selectedRangeRect.adjusted(1, 1, -2, -2), 1, 1)
+ painter.setPen(selectedInnerContrastLine)
+ painter.setBrush(qt.Qt.NoBrush)
+ painter.drawRoundedRect(selectedRangeRect.adjusted(2, 2, -3, -3), 1, 1)
+
+ painter.restore()
diff --git a/src/silx/gui/widgets/TableWidget.py b/src/silx/gui/widgets/TableWidget.py
new file mode 100644
index 0000000..50eb9e2
--- /dev/null
+++ b/src/silx/gui/widgets/TableWidget.py
@@ -0,0 +1,626 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides table widgets handling cut, copy and paste for
+multiple cell selections. These actions can be triggered using keyboard
+shortcuts or through a context menu (right-click).
+
+:class:`TableView` is a subclass of :class:`QTableView`. The added features
+are made available to users after a model is added to the widget, using
+:meth:`TableView.setModel`.
+
+:class:`TableWidget` is a subclass of :class:`qt.QTableWidget`, a table view
+with a built-in standard data model. The added features are available as soon as
+the widget is initialized.
+
+The cut, copy and paste actions are implemented as QActions:
+
+ - :class:`CopySelectedCellsAction` (*Ctrl+C*)
+ - :class:`CopyAllCellsAction`
+ - :class:`CutSelectedCellsAction` (*Ctrl+X*)
+ - :class:`CutAllCellsAction`
+ - :class:`PasteCellsAction` (*Ctrl+V*)
+
+The copy actions are enabled by default. The cut and paste actions must be
+explicitly enabled, by passing parameters ``cut=True, paste=True`` when
+creating the widgets, or later by calling their :meth:`enableCut` and
+:meth:`enablePaste` methods.
+"""
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "03/07/2017"
+
+
+import sys
+from .. import qt
+
+
+if sys.platform.startswith("win"):
+ row_separator = "\r\n"
+else:
+ row_separator = "\n"
+
+col_separator = "\t"
+
+
+class CopySelectedCellsAction(qt.QAction):
+ """QAction to copy text from selected cells in a :class:`QTableWidget`
+ into the clipboard.
+
+ If multiple cells are selected, the copied text will be a concatenation
+ of the texts in all selected cells, tabulated with tabulation and
+ newline characters.
+
+ If the cells are sparsely selected, the structure is preserved by
+ representing the unselected cells as empty strings in between two
+ tabulation characters.
+ Beware of pasting this data in another table widget, because depending
+ on how the paste is implemented, the empty cells may cause data in the
+ target table to be deleted, even though you didn't necessarily select the
+ corresponding cell in the origin table.
+
+ :param table: :class:`QTableView` to which this action belongs.
+ """
+ def __init__(self, table):
+ if not isinstance(table, qt.QTableView):
+ raise ValueError('CopySelectedCellsAction must be initialised ' +
+ 'with a QTableWidget.')
+ super(CopySelectedCellsAction, self).__init__(table)
+ self.setText("Copy selection")
+ self.setToolTip("Copy selected cells into the clipboard.")
+ self.setShortcut(qt.QKeySequence.Copy)
+ self.setShortcutContext(qt.Qt.WidgetShortcut)
+ self.triggered.connect(self.copyCellsToClipboard)
+ self.table = table
+ self.cut = False
+ """:attr:`cut` can be set to True by classes inheriting this action,
+ to do a cut action."""
+
+ def copyCellsToClipboard(self):
+ """Concatenate the text content of all selected cells into a string
+ using tabulations and newlines to keep the table structure.
+ Put this text into the clipboard.
+ """
+ selected_idx = self.table.selectedIndexes()
+ if not selected_idx:
+ return
+ selected_idx_tuples = [(idx.row(), idx.column()) for idx in selected_idx]
+
+ selected_rows = [idx[0] for idx in selected_idx_tuples]
+ selected_columns = [idx[1] for idx in selected_idx_tuples]
+
+ data_model = self.table.model()
+
+ copied_text = ""
+ for row in range(min(selected_rows), max(selected_rows) + 1):
+ for col in range(min(selected_columns), max(selected_columns) + 1):
+ index = data_model.index(row, col)
+ cell_text = data_model.data(index)
+ flags = data_model.flags(index)
+
+ if (row, col) in selected_idx_tuples and cell_text is not None:
+ copied_text += cell_text
+ if self.cut and (flags & qt.Qt.ItemIsEditable):
+ data_model.setData(index, "")
+ copied_text += col_separator
+ # remove the right-most tabulation
+ copied_text = copied_text[:-len(col_separator)]
+ # add a newline
+ copied_text += row_separator
+ # remove final newline
+ copied_text = copied_text[:-len(row_separator)]
+
+ # put this text into clipboard
+ qapp = qt.QApplication.instance()
+ qapp.clipboard().setText(copied_text)
+
+
+class CopyAllCellsAction(qt.QAction):
+ """QAction to copy text from all cells in a :class:`QTableWidget`
+ into the clipboard.
+
+ The copied text will be a concatenation
+ of the texts in all cells, tabulated with tabulation and
+ newline characters.
+
+ :param table: :class:`QTableView` to which this action belongs.
+ """
+ def __init__(self, table):
+ if not isinstance(table, qt.QTableView):
+ raise ValueError('CopyAllCellsAction must be initialised ' +
+ 'with a QTableWidget.')
+ super(CopyAllCellsAction, self).__init__(table)
+ self.setText("Copy all")
+ self.setToolTip("Copy all cells into the clipboard.")
+ self.triggered.connect(self.copyCellsToClipboard)
+ self.table = table
+ self.cut = False
+
+ def copyCellsToClipboard(self):
+ """Concatenate the text content of all cells into a string
+ using tabulations and newlines to keep the table structure.
+ Put this text into the clipboard.
+ """
+ data_model = self.table.model()
+ copied_text = ""
+ for row in range(data_model.rowCount()):
+ for col in range(data_model.columnCount()):
+ index = data_model.index(row, col)
+ cell_text = data_model.data(index)
+ flags = data_model.flags(index)
+ if cell_text is not None:
+ copied_text += cell_text
+ if self.cut and (flags & qt.Qt.ItemIsEditable):
+ data_model.setData(index, "")
+ copied_text += col_separator
+ # remove the right-most tabulation
+ copied_text = copied_text[:-len(col_separator)]
+ # add a newline
+ copied_text += row_separator
+ # remove final newline
+ copied_text = copied_text[:-len(row_separator)]
+
+ # put this text into clipboard
+ qapp = qt.QApplication.instance()
+ qapp.clipboard().setText(copied_text)
+
+
+class CutSelectedCellsAction(CopySelectedCellsAction):
+ """QAction to cut text from selected cells in a :class:`QTableWidget`
+ into the clipboard.
+
+ The text is deleted from the original table widget
+ (use :class:`CopySelectedCellsAction` to preserve the original data).
+
+ If multiple cells are selected, the cut text will be a concatenation
+ of the texts in all selected cells, tabulated with tabulation and
+ newline characters.
+
+ If the cells are sparsely selected, the structure is preserved by
+ representing the unselected cells as empty strings in between two
+ tabulation characters.
+ Beware of pasting this data in another table widget, because depending
+ on how the paste is implemented, the empty cells may cause data in the
+ target table to be deleted, even though you didn't necessarily select the
+ corresponding cell in the origin table.
+
+ :param table: :class:`QTableView` to which this action belongs."""
+ def __init__(self, table):
+ super(CutSelectedCellsAction, self).__init__(table)
+ self.setText("Cut selection")
+ self.setShortcut(qt.QKeySequence.Cut)
+ self.setShortcutContext(qt.Qt.WidgetShortcut)
+ # cutting is already implemented in CopySelectedCellsAction (but
+ # it is disabled), we just need to enable it
+ self.cut = True
+
+
+class CutAllCellsAction(CopyAllCellsAction):
+ """QAction to cut text from all cells in a :class:`QTableWidget`
+ into the clipboard.
+
+ The text is deleted from the original table widget
+ (use :class:`CopyAllCellsAction` to preserve the original data).
+
+ The cut text will be a concatenation
+ of the texts in all cells, tabulated with tabulation and
+ newline characters.
+
+ :param table: :class:`QTableView` to which this action belongs."""
+ def __init__(self, table):
+ super(CutAllCellsAction, self).__init__(table)
+ self.setText("Cut all")
+ self.setToolTip("Cut all cells into the clipboard.")
+ self.cut = True
+
+
+def _parseTextAsTable(text, row_separator=row_separator, col_separator=col_separator):
+ """Parse text into list of lists (2D sequence).
+
+ The input text must be tabulated using tabulation characters and
+ newlines to separate columns and rows.
+
+ :param text: text to be parsed
+ :param record_separator: String, or single character, to be interpreted
+ as a record/row separator.
+ :param field_separator: String, or single character, to be interpreted
+ as a field/column separator.
+ :return: 2D sequence of strings
+ """
+ rows = text.split(row_separator)
+ table_data = [row.split(col_separator) for row in rows]
+ return table_data
+
+
+class PasteCellsAction(qt.QAction):
+ """QAction to paste text from the clipboard into the table.
+
+ If the text contains tabulations and
+ newlines, they are interpreted as column and row separators.
+ In such a case, the text is split into multiple texts to be pasted
+ into multiple cells.
+
+ If a cell content is an empty string in the original text, it is
+ ignored: the destination cell's text will not be deleted.
+
+ :param table: :class:`QTableView` to which this action belongs.
+ """
+ def __init__(self, table):
+ if not isinstance(table, qt.QTableView):
+ raise ValueError('PasteCellsAction must be initialised ' +
+ 'with a QTableWidget.')
+ super(PasteCellsAction, self).__init__(table)
+ self.table = table
+ self.setText("Paste")
+ self.setShortcut(qt.QKeySequence.Paste)
+ self.setShortcutContext(qt.Qt.WidgetShortcut)
+ self.setToolTip("Paste data. The selected cell is the top-left" +
+ "corner of the paste area.")
+ self.triggered.connect(self.pasteCellFromClipboard)
+
+ def pasteCellFromClipboard(self):
+ """Paste text from clipboard into the table.
+
+ :return: *True* in case of success, *False* if pasting data failed.
+ """
+ selected_idx = self.table.selectedIndexes()
+ if len(selected_idx) != 1:
+ msgBox = qt.QMessageBox(parent=self.table)
+ msgBox.setText("A single cell must be selected to paste data")
+ msgBox.exec()
+ return False
+
+ data_model = self.table.model()
+
+ selected_row = selected_idx[0].row()
+ selected_col = selected_idx[0].column()
+
+ qapp = qt.QApplication.instance()
+ clipboard_text = qapp.clipboard().text()
+ table_data = _parseTextAsTable(clipboard_text)
+
+ protected_cells = 0
+ out_of_range_cells = 0
+
+ # paste table data into cells, using selected cell as origin
+ for row_offset in range(len(table_data)):
+ for col_offset in range(len(table_data[row_offset])):
+ target_row = selected_row + row_offset
+ target_col = selected_col + col_offset
+
+ if target_row >= data_model.rowCount() or\
+ target_col >= data_model.columnCount():
+ out_of_range_cells += 1
+ continue
+
+ index = data_model.index(target_row, target_col)
+ flags = data_model.flags(index)
+
+ # ignore empty strings
+ if table_data[row_offset][col_offset] != "":
+ if not flags & qt.Qt.ItemIsEditable:
+ protected_cells += 1
+ continue
+ data_model.setData(index, table_data[row_offset][col_offset])
+ # item.setText(table_data[row_offset][col_offset])
+
+ if protected_cells or out_of_range_cells:
+ msgBox = qt.QMessageBox(parent=self.table)
+ msg = "Some data could not be inserted, "
+ msg += "due to out-of-range or write-protected cells."
+ msgBox.setText(msg)
+ msgBox.exec()
+ return False
+ return True
+
+
+class CopySingleCellAction(qt.QAction):
+ """QAction to copy text from a single cell in a modified
+ :class:`QTableWidget`.
+
+ This action relies on the fact that the text in the last clicked cell
+ are stored in :attr:`_last_cell_clicked` of the modified widget.
+
+ In most cases, :class:`CopySelectedCellsAction` handles single cells,
+ but if the selection mode of the widget has been set to NoSelection
+ it is necessary to use this class instead.
+
+ :param table: :class:`QTableView` to which this action belongs.
+ """
+ def __init__(self, table):
+ if not isinstance(table, qt.QTableView):
+ raise ValueError('CopySingleCellAction must be initialised ' +
+ 'with a QTableWidget.')
+ super(CopySingleCellAction, self).__init__(table)
+ self.setText("Copy cell")
+ self.setToolTip("Copy cell content into the clipboard.")
+ self.triggered.connect(self.copyCellToClipboard)
+ self.table = table
+
+ def copyCellToClipboard(self):
+ """
+ """
+ cell_text = self.table._text_last_cell_clicked
+ if cell_text is None:
+ return
+
+ # put this text into clipboard
+ qapp = qt.QApplication.instance()
+ qapp.clipboard().setText(cell_text)
+
+
+class TableWidget(qt.QTableWidget):
+ """:class:`QTableWidget` with a context menu displaying up to 5 actions:
+
+ - :class:`CopySelectedCellsAction`
+ - :class:`CopyAllCellsAction`
+ - :class:`CutSelectedCellsAction`
+ - :class:`CutAllCellsAction`
+ - :class:`PasteCellsAction`
+
+ These actions interact with the clipboard and can be used to copy data
+ to or from an external application, or another widget.
+
+ The cut and paste actions are disabled by default, due to the risk of
+ overwriting data (no *Undo* action is available). Use :meth:`enablePaste`
+ and :meth:`enableCut` to activate them.
+
+ .. image:: img/TableWidget.png
+
+ :param parent: Parent QWidget
+ :param bool cut: Enable cut action
+ :param bool paste: Enable paste action
+ """
+ def __init__(self, parent=None, cut=False, paste=False):
+ super(TableWidget, self).__init__(parent)
+ self._text_last_cell_clicked = None
+
+ self.copySelectedCellsAction = CopySelectedCellsAction(self)
+ self.copyAllCellsAction = CopyAllCellsAction(self)
+ self.copySingleCellAction = None
+ self.pasteCellsAction = None
+ self.cutSelectedCellsAction = None
+ self.cutAllCellsAction = None
+
+ self.addAction(self.copySelectedCellsAction)
+ self.addAction(self.copyAllCellsAction)
+ if cut:
+ self.enableCut()
+ if paste:
+ self.enablePaste()
+
+ self.setContextMenuPolicy(qt.Qt.ActionsContextMenu)
+
+ def mousePressEvent(self, event):
+ item = self.itemAt(event.pos())
+ if item is not None:
+ self._text_last_cell_clicked = item.text()
+ super(TableWidget, self).mousePressEvent(event)
+
+ def enablePaste(self):
+ """Enable paste action, to paste data from the clipboard into the
+ table.
+
+ .. warning::
+
+ This action can cause data to be overwritten.
+ There is currently no *Undo* action to retrieve lost data.
+ """
+ self.pasteCellsAction = PasteCellsAction(self)
+ self.addAction(self.pasteCellsAction)
+
+ def enableCut(self):
+ """Enable cut action.
+
+ .. warning::
+
+ This action can cause data to be deleted.
+ There is currently no *Undo* action to retrieve lost data."""
+ self.cutSelectedCellsAction = CutSelectedCellsAction(self)
+ self.cutAllCellsAction = CutAllCellsAction(self)
+ self.addAction(self.cutSelectedCellsAction)
+ self.addAction(self.cutAllCellsAction)
+
+ def setSelectionMode(self, mode):
+ """Overloaded from QTableWidget to disable cut/copy selection
+ actions in case mode is NoSelection
+
+ :param mode:
+ :return:
+ """
+ if mode == qt.QTableView.NoSelection:
+ self.copySelectedCellsAction.setVisible(False)
+ self.copySelectedCellsAction.setEnabled(False)
+ if self.cutSelectedCellsAction is not None:
+ self.cutSelectedCellsAction.setVisible(False)
+ self.cutSelectedCellsAction.setEnabled(False)
+ if self.copySingleCellAction is None:
+ self.copySingleCellAction = CopySingleCellAction(self)
+ self.insertAction(self.copySelectedCellsAction, # before first action
+ self.copySingleCellAction)
+ self.copySingleCellAction.setVisible(True)
+ self.copySingleCellAction.setEnabled(True)
+ else:
+ self.copySelectedCellsAction.setVisible(True)
+ self.copySelectedCellsAction.setEnabled(True)
+ if self.cutSelectedCellsAction is not None:
+ self.cutSelectedCellsAction.setVisible(True)
+ self.cutSelectedCellsAction.setEnabled(True)
+ if self.copySingleCellAction is not None:
+ self.copySingleCellAction.setVisible(False)
+ self.copySingleCellAction.setEnabled(False)
+ super(TableWidget, self).setSelectionMode(mode)
+
+
+class TableView(qt.QTableView):
+ """:class:`QTableView` with a context menu displaying up to 5 actions:
+
+ - :class:`CopySelectedCellsAction`
+ - :class:`CopyAllCellsAction`
+ - :class:`CutSelectedCellsAction`
+ - :class:`CutAllCellsAction`
+ - :class:`PasteCellsAction`
+
+ These actions interact with the clipboard and can be used to copy data
+ to or from an external application, or another widget.
+
+ The cut and paste actions are disabled by default, due to the risk of
+ overwriting data (no *Undo* action is available). Use :meth:`enablePaste`
+ and :meth:`enableCut` to activate them.
+
+ .. note::
+
+ These actions will be available only after a model is associated
+ with this view, using :meth:`setModel`.
+
+ :param parent: Parent QWidget
+ :param bool cut: Enable cut action
+ :param bool paste: Enable paste action
+ """
+ def __init__(self, parent=None, cut=False, paste=False):
+ super(TableView, self).__init__(parent)
+ self._text_last_cell_clicked = None
+
+ self.cut = cut
+ self.paste = paste
+
+ self.copySelectedCellsAction = None
+ self.copyAllCellsAction = None
+ self.copySingleCellAction = None
+ self.pasteCellsAction = None
+ self.cutSelectedCellsAction = None
+ self.cutAllCellsAction = None
+
+ def mousePressEvent(self, event):
+ qindex = self.indexAt(event.pos())
+ if self.copyAllCellsAction is not None: # model was set
+ self._text_last_cell_clicked = self.model().data(qindex)
+ super(TableView, self).mousePressEvent(event)
+
+ def setModel(self, model):
+ """Set the data model for the table view, activate the actions
+ and the context menu.
+
+ :param model: :class:`qt.QAbstractItemModel` object
+ """
+ super(TableView, self).setModel(model)
+
+ self.copySelectedCellsAction = CopySelectedCellsAction(self)
+ self.copyAllCellsAction = CopyAllCellsAction(self)
+ self.addAction(self.copySelectedCellsAction)
+ self.addAction(self.copyAllCellsAction)
+ if self.cut:
+ self.enableCut()
+ if self.paste:
+ self.enablePaste()
+
+ self.setContextMenuPolicy(qt.Qt.ActionsContextMenu)
+
+ def enablePaste(self):
+ """Enable paste action, to paste data from the clipboard into the
+ table.
+
+ .. warning::
+
+ This action can cause data to be overwritten.
+ There is currently no *Undo* action to retrieve lost data.
+ """
+ self.pasteCellsAction = PasteCellsAction(self)
+ self.addAction(self.pasteCellsAction)
+
+ def enableCut(self):
+ """Enable cut action.
+
+ .. warning::
+
+ This action can cause data to be deleted.
+ There is currently no *Undo* action to retrieve lost data.
+ """
+ self.cutSelectedCellsAction = CutSelectedCellsAction(self)
+ self.cutAllCellsAction = CutAllCellsAction(self)
+ self.addAction(self.cutSelectedCellsAction)
+ self.addAction(self.cutAllCellsAction)
+
+ def addAction(self, action):
+ # ensure the actions are not added multiple times:
+ # compare action type and parent widget with those of existing actions
+ for existing_action in self.actions():
+ if type(action) == type(existing_action):
+ if hasattr(action, "table") and\
+ action.table is existing_action.table:
+ return None
+ super(TableView, self).addAction(action)
+
+ def setSelectionMode(self, mode):
+ """Overloaded from QTableView to disable cut/copy selection
+ actions in case mode is NoSelection
+
+ :param mode:
+ :return:
+ """
+ if mode == qt.QTableView.NoSelection:
+ self.copySelectedCellsAction.setVisible(False)
+ self.copySelectedCellsAction.setEnabled(False)
+ if self.cutSelectedCellsAction is not None:
+ self.cutSelectedCellsAction.setVisible(False)
+ self.cutSelectedCellsAction.setEnabled(False)
+ if self.copySingleCellAction is None:
+ self.copySingleCellAction = CopySingleCellAction(self)
+ self.insertAction(self.copySelectedCellsAction, # before first action
+ self.copySingleCellAction)
+ self.copySingleCellAction.setVisible(True)
+ self.copySingleCellAction.setEnabled(True)
+ else:
+ self.copySelectedCellsAction.setVisible(True)
+ self.copySelectedCellsAction.setEnabled(True)
+ if self.cutSelectedCellsAction is not None:
+ self.cutSelectedCellsAction.setVisible(True)
+ self.cutSelectedCellsAction.setEnabled(True)
+ if self.copySingleCellAction is not None:
+ self.copySingleCellAction.setVisible(False)
+ self.copySingleCellAction.setEnabled(False)
+ super(TableView, self).setSelectionMode(mode)
+
+
+if __name__ == "__main__":
+ app = qt.QApplication([])
+
+ tablewidget = TableWidget()
+ tablewidget.setWindowTitle("TableWidget")
+ tablewidget.setColumnCount(10)
+ tablewidget.setRowCount(7)
+ tablewidget.enableCut()
+ tablewidget.enablePaste()
+ tablewidget.show()
+
+ tableview = TableView(cut=True, paste=True)
+ tableview.setWindowTitle("TableView")
+ model = qt.QStandardItemModel()
+ model.setColumnCount(10)
+ model.setRowCount(7)
+ tableview.setModel(model)
+ tableview.show()
+
+ app.exec()
diff --git a/src/silx/gui/widgets/ThreadPoolPushButton.py b/src/silx/gui/widgets/ThreadPoolPushButton.py
new file mode 100644
index 0000000..949b6ef
--- /dev/null
+++ b/src/silx/gui/widgets/ThreadPoolPushButton.py
@@ -0,0 +1,238 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""ThreadPoolPushButton module
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "13/10/2016"
+
+import logging
+from .. import qt
+from .WaitingPushButton import WaitingPushButton
+
+
+_logger = logging.getLogger(__name__)
+
+
+class _Wrapper(qt.QRunnable):
+ """Wrapper to allow to call a function into a `QThreadPool` and
+ sending signals during the life cycle of the object"""
+
+ def __init__(self, signalHolder, function, args, kwargs):
+ """Constructor"""
+ super(_Wrapper, self).__init__()
+ self.__signalHolder = signalHolder
+ self.__callable = function
+ self.__args = args
+ self.__kwargs = kwargs
+
+ def run(self):
+ holder = self.__signalHolder
+ holder.started.emit()
+ try:
+ result = self.__callable(*self.__args, **self.__kwargs)
+ holder.succeeded.emit(result)
+ except Exception as e:
+ module = self.__callable.__module__
+ name = self.__callable.__name__
+ _logger.error("Error while executing callable %s.%s.", module, name, exc_info=True)
+ holder.failed.emit(e)
+ finally:
+ holder.finished.emit()
+ holder._sigReleaseRunner.emit(self)
+
+ def autoDelete(self):
+ """Returns true to ask the QThreadPool to manage the life cycle of
+ this QRunner."""
+ return True
+
+
+class ThreadPoolPushButton(WaitingPushButton):
+ """
+ ThreadPoolPushButton provides a simple push button to execute
+ a threaded task with user feedback when the task is running.
+
+ The task can be defined with the method `setCallable`. It takes a python
+ function and arguments as parameters.
+
+ WARNING: This task is run in a separate thread.
+
+ Everytime the button is pushed a new runner is created to execute the
+ function with defined arguments. An animated waiting icon is displayed
+ to show the activity. By default the button is disabled when an execution
+ is requested. This behaviour can be disabled by using
+ `setDisabledWhenWaiting`.
+
+ When the button is clicked a `beforeExecuting` signal is sent from the
+ Qt main thread. Then the task is started in a thread pool and the following
+ signals are emitted from the thread pool. Right before calling the
+ registered callable, the widget emits a `started` signal.
+ When the task ends, its result is emitted by the `succeeded` signal, but
+ if it fails the signal `failed` is emitted with the resulting exception.
+ At the end, the `finished` signal is emitted.
+
+ The task can be programatically executed by using `executeCallable`.
+
+ >>> # Compute a value
+ >>> import math
+ >>> button = ThreadPoolPushButton(text="Compute 2^16")
+ >>> button.setCallable(math.pow, 2, 16)
+ >>> button.succeeded.connect(print) # python3
+
+ .. image:: img/ThreadPoolPushButton.png
+
+ >>> # Compute a wrong value
+ >>> import math
+ >>> button = ThreadPoolPushButton(text="Compute sqrt(-1)")
+ >>> button.setCallable(math.sqrt, -1)
+ >>> button.failed.connect(print) # python3
+ """
+
+ def __init__(self, parent=None, text=None, icon=None):
+ """Constructor
+
+ :param str text: Text displayed on the button
+ :param qt.QIcon icon: Icon displayed on the button
+ :param qt.QWidget parent: Parent of the widget
+ """
+ WaitingPushButton.__init__(self, parent=parent, text=text, icon=icon)
+ self.__callable = None
+ self.__args = None
+ self.__kwargs = None
+ self.__runnerCount = 0
+ self.__runnerSet = set([])
+ self.clicked.connect(self.executeCallable)
+ self.finished.connect(self.__runnerFinished)
+ self._sigReleaseRunner.connect(self.__releaseRunner)
+
+ beforeExecuting = qt.Signal()
+ """Signal emitted just before execution of the callable by the main Qt
+ thread. In synchronous mode (direct mode), it can be used to define
+ dynamically `setCallable`, or to execute something in the Qt thread before
+ the execution, or both."""
+
+ started = qt.Signal()
+ """Signal emitted from the thread pool when the defined callable is
+ started.
+
+ WARNING: This signal is emitted from the thread performing the task, and
+ might be received after the registered callable has been called. If you
+ want to perform some initialisation or set the callable to run, use the
+ `beforeExecuting` signal instead.
+ """
+
+ finished = qt.Signal()
+ """Signal emitted from the thread pool when the defined callable is
+ finished"""
+
+ succeeded = qt.Signal(object)
+ """Signal emitted from the thread pool when the callable exit with a
+ success.
+
+ The parameter of the signal is the result returned by the callable.
+ """
+
+ failed = qt.Signal(object)
+ """Signal emitted emitted from the thread pool when the callable raises an
+ exception.
+
+ The parameter of the signal is the raised exception.
+ """
+
+ _sigReleaseRunner = qt.Signal(object)
+ """Callback to release runners"""
+
+ def __runnerStarted(self):
+ """Called when a runner is started.
+
+ Count the number of executed tasks to change the state of the widget.
+ """
+ self.__runnerCount += 1
+ if self.__runnerCount > 0:
+ self.wait()
+
+ def __runnerFinished(self):
+ """Called when a runner is finished.
+
+ Count the number of executed tasks to change the state of the widget.
+ """
+ self.__runnerCount -= 1
+ if self.__runnerCount <= 0:
+ self.stopWaiting()
+
+ @qt.Slot()
+ def executeCallable(self):
+ """Execute the defined callable in QThreadPool.
+
+ First emit a `beforeExecuting` signal.
+ If callable is not defined, nothing append.
+ If a callable is defined, it will be started
+ as a new thread using the `QThreadPool` system. At start of the thread
+ the `started` will be emitted. When the callable returns a result it
+ is emitted by the `succeeded` signal. If the callable fail, the signal
+ `failed` is emitted with the resulting exception. Then the `finished`
+ signal is emitted.
+ """
+ self.beforeExecuting.emit()
+ if self.__callable is None:
+ return
+ self.__runnerStarted()
+ runner = self._createRunner(self.__callable, self.__args, self.__kwargs)
+ qt.silxGlobalThreadPool().start(runner)
+ self.__runnerSet.add(runner)
+
+ def __releaseRunner(self, runner):
+ self.__runnerSet.remove(runner)
+
+ def hasPendingOperations(self):
+ return len(self.__runnerSet) > 0
+
+ def _createRunner(self, function, args, kwargs):
+ """Create a QRunnable from a callable object.
+
+ :param callable function: A callable Python object.
+ :param List args: List of arguments to call the function.
+ :param dict kwargs: Dictionary of arguments used to call the function.
+ :rtpye: qt.QRunnable
+ """
+ runnable = _Wrapper(self, function, args, kwargs)
+ return runnable
+
+ def setCallable(self, function, *args, **kwargs):
+ """Define a callable which will be executed on QThreadPool everytime
+ the button is clicked.
+
+ To retrieve the results, connect to the `succeeded` signal.
+
+ WARNING: The callable will be called in a separate thread.
+
+ :param callable function: A callable Python object
+ :param List args: List of arguments to call the function.
+ :param dict kwargs: Dictionary of arguments used to call the function.
+ """
+ self.__callable = function
+ self.__args = args
+ self.__kwargs = kwargs
diff --git a/src/silx/gui/widgets/UrlSelectionTable.py b/src/silx/gui/widgets/UrlSelectionTable.py
new file mode 100644
index 0000000..bc75d32
--- /dev/null
+++ b/src/silx/gui/widgets/UrlSelectionTable.py
@@ -0,0 +1,169 @@
+# /*##########################################################################
+# Copyright (C) 2017-2021 European Synchrotron Radiation Facility
+#
+# This file is part of the PyMca X-ray Fluorescence Toolkit developed at
+# the ESRF by the Software group.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+#############################################################################*/
+"""Some widget construction to check if a sample moved"""
+
+__author__ = ["H. Payno"]
+__license__ = "MIT"
+__date__ = "19/03/2018"
+
+from silx.gui import qt
+from collections import OrderedDict
+from silx.gui.widgets.TableWidget import TableWidget
+from silx.io.url import DataUrl
+import functools
+import logging
+import os
+
+logger = logging.getLogger(__name__)
+
+
+class UrlSelectionTable(TableWidget):
+ """Table used to select the color channel to be displayed for each"""
+
+ COLUMS_INDEX = OrderedDict([
+ ('url', 0),
+ ('img A', 1),
+ ('img B', 2),
+ ])
+
+ sigImageAChanged = qt.Signal(str)
+ """Signal emitted when the image A change. Param is the image url path"""
+
+ sigImageBChanged = qt.Signal(str)
+ """Signal emitted when the image B change. Param is the image url path"""
+
+ def __init__(self, parent=None):
+ TableWidget.__init__(self, parent)
+ self.clear()
+
+ def clear(self):
+ qt.QTableWidget.clear(self)
+ self.setRowCount(0)
+ self.setColumnCount(len(self.COLUMS_INDEX))
+ self.setHorizontalHeaderLabels(list(self.COLUMS_INDEX.keys()))
+ self.verticalHeader().hide()
+ self.horizontalHeader().setSectionResizeMode(0,
+ qt.QHeaderView.Stretch)
+
+ self.setSortingEnabled(True)
+ self._checkBoxes = {}
+
+ def setUrls(self, urls: list) -> None:
+ """
+
+ :param urls: urls to be displayed
+ """
+ for url in urls:
+ self.addUrl(url=url)
+
+ def addUrl(self, url, **kwargs):
+ """
+
+ :param url:
+ :param args:
+ :return: index of the created items row
+ :rtype int
+ """
+ assert isinstance(url, DataUrl)
+ row = self.rowCount()
+ self.setRowCount(row + 1)
+
+ _item = qt.QTableWidgetItem()
+ _item.setText(os.path.basename(url.path()))
+ _item.setFlags(qt.Qt.ItemIsEnabled | qt.Qt.ItemIsSelectable)
+ self.setItem(row, self.COLUMS_INDEX['url'], _item)
+
+ widgetImgA = qt.QRadioButton(parent=self)
+ widgetImgA.setAutoExclusive(False)
+ self.setCellWidget(row, self.COLUMS_INDEX['img A'], widgetImgA)
+ callbackImgA = functools.partial(self._activeImgAChanged, url.path())
+ widgetImgA.toggled.connect(callbackImgA)
+
+ widgetImgB = qt.QRadioButton(parent=self)
+ widgetImgA.setAutoExclusive(False)
+ self.setCellWidget(row, self.COLUMS_INDEX['img B'], widgetImgB)
+ callbackImgB = functools.partial(self._activeImgBChanged, url.path())
+ widgetImgB.toggled.connect(callbackImgB)
+
+ self._checkBoxes[url.path()] = {'img A': widgetImgA,
+ 'img B': widgetImgB}
+ self.resizeColumnsToContents()
+ return row
+
+ def _activeImgAChanged(self, name):
+ self._updatecheckBoxes('img A', name)
+ self.sigImageAChanged.emit(name)
+
+ def _activeImgBChanged(self, name):
+ self._updatecheckBoxes('img B', name)
+ self.sigImageBChanged.emit(name)
+
+ def _updatecheckBoxes(self, whichImg, name):
+ assert name in self._checkBoxes
+ assert whichImg in self._checkBoxes[name]
+ if self._checkBoxes[name][whichImg].isChecked():
+ for radioUrl in self._checkBoxes:
+ if radioUrl != name:
+ self._checkBoxes[radioUrl][whichImg].blockSignals(True)
+ self._checkBoxes[radioUrl][whichImg].setChecked(False)
+ self._checkBoxes[radioUrl][whichImg].blockSignals(False)
+
+ def getSelection(self):
+ """
+
+ :return: url selected for img A and img B.
+ """
+ imgA = imgB = None
+ for radioUrl in self._checkBoxes:
+ if self._checkBoxes[radioUrl]['img A'].isChecked():
+ imgA = radioUrl
+ if self._checkBoxes[radioUrl]['img B'].isChecked():
+ imgB = radioUrl
+ return imgA, imgB
+
+ def setSelection(self, url_img_a, url_img_b):
+ """
+
+ :param ddict: key: image url, values: list of active channels
+ """
+ for radioUrl in self._checkBoxes:
+ for img in ('img A', 'img B'):
+ self._checkBoxes[radioUrl][img].blockSignals(True)
+ self._checkBoxes[radioUrl][img].setChecked(False)
+ self._checkBoxes[radioUrl][img].blockSignals(False)
+
+ self._checkBoxes[radioUrl][img].blockSignals(True)
+ self._checkBoxes[url_img_a]['img A'].setChecked(True)
+ self._checkBoxes[radioUrl][img].blockSignals(False)
+
+ self._checkBoxes[radioUrl][img].blockSignals(True)
+ self._checkBoxes[url_img_b]['img B'].setChecked(True)
+ self._checkBoxes[radioUrl][img].blockSignals(False)
+ self.sigImageAChanged.emit(url_img_a)
+ self.sigImageBChanged.emit(url_img_b)
+
+ def removeUrl(self, url):
+ raise NotImplementedError("")
diff --git a/src/silx/gui/widgets/WaitingPushButton.py b/src/silx/gui/widgets/WaitingPushButton.py
new file mode 100644
index 0000000..443dc9a
--- /dev/null
+++ b/src/silx/gui/widgets/WaitingPushButton.py
@@ -0,0 +1,241 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""WaitingPushButton module
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "26/04/2017"
+
+from .. import qt
+from .. import icons
+
+
+class WaitingPushButton(qt.QPushButton):
+ """Button which allows to display a waiting status when, for example,
+ something is still computing.
+
+ The component is graphically disabled when it is in waiting. Then we
+ overwrite the enabled method to dissociate the 2 concepts:
+ graphically enabled/disabled, and enabled/disabled
+
+ .. image:: img/WaitingPushButton.png
+ """
+
+ def __init__(self, parent=None, text=None, icon=None):
+ """Constructor
+
+ :param str text: Text displayed on the button
+ :param qt.QIcon icon: Icon displayed on the button
+ :param qt.QWidget parent: Parent of the widget
+ """
+ if icon is not None:
+ qt.QPushButton.__init__(self, icon, text, parent)
+ elif text is not None:
+ qt.QPushButton.__init__(self, text, parent)
+ else:
+ qt.QPushButton.__init__(self, parent)
+
+ self.__waiting = False
+ self.__enabled = True
+ self.__icon = icon
+ self.__disabled_when_waiting = True
+ self.__waitingIcon = icons.getWaitIcon()
+
+ def sizeHint(self):
+ """Returns the recommended size for the widget.
+
+ This implementation of the recommended size always consider there is an
+ icon. In this way it avoid to update the layout when the waiting icon
+ is displayed.
+ """
+ self.ensurePolished()
+
+ w = 0
+ h = 0
+
+ opt = qt.QStyleOptionButton()
+ self.initStyleOption(opt)
+
+ # Content with icon
+ # no condition, assume that there is an icon to avoid blinking
+ # when the widget switch to waiting state
+ ih = opt.iconSize.height()
+ iw = opt.iconSize.width() + 4
+ w += iw
+ h = max(h, ih)
+
+ # Content with text
+ text = self.text()
+ isEmpty = text == ""
+ if isEmpty:
+ text = "XXXX"
+ fm = self.fontMetrics()
+ textSize = fm.size(qt.Qt.TextShowMnemonic, text)
+ if not isEmpty or w == 0:
+ w += textSize.width()
+ if not isEmpty or h == 0:
+ h = max(h, textSize.height())
+
+ # Content with menu indicator
+ opt.rect.setSize(qt.QSize(w, h)) # PM_MenuButtonIndicator depends on the height
+ if self.menu() is not None:
+ w += self.style().pixelMetric(qt.QStyle.PM_MenuButtonIndicator, opt, self)
+
+ contentSize = qt.QSize(w, h)
+ sizeHint = self.style().sizeFromContents(qt.QStyle.CT_PushButton, opt, contentSize, self)
+ if qt.BINDING in ('PySide2', 'PyQt5'): # Qt6: globalStrut not available
+ sizeHint = sizeHint.expandedTo(qt.QApplication.globalStrut())
+ return sizeHint
+
+ def setDisabledWhenWaiting(self, isDisabled):
+ """Enable or disable the auto disable behaviour when the button is waiting.
+
+ :param bool isDisabled: Enable the auto-disable behaviour
+ """
+ if self.__disabled_when_waiting == isDisabled:
+ return
+ self.__disabled_when_waiting = isDisabled
+ self.__updateVisibleEnabled()
+
+ def isDisabledWhenWaiting(self):
+ """Returns true if the button is auto disabled when it is waiting.
+
+ :rtype: bool
+ """
+ return self.__disabled_when_waiting
+
+ disabledWhenWaiting = qt.Property(bool, isDisabledWhenWaiting, setDisabledWhenWaiting)
+ """Property to enable/disable the auto disabled state when the button is waiting."""
+
+ def __setWaitingIcon(self, icon):
+ """Called when the waiting icon is updated. It is called every frames
+ of the animation.
+
+ :param qt.QIcon icon: The new waiting icon
+ """
+ qt.QPushButton.setIcon(self, icon)
+
+ def setIcon(self, icon):
+ """Set the button icon. If the button is waiting, the icon is not
+ visible directly, but will be visible when the waiting state will be
+ removed.
+
+ :param qt.QIcon icon: An icon
+ """
+ self.__icon = icon
+ self.__updateVisibleIcon()
+
+ def getIcon(self):
+ """Returns the icon set to the button. If the widget is waiting
+ it is not returning the visible icon, but the one requested by
+ the application (the one displayed when the widget is not in
+ waiting state).
+
+ :rtype: qt.QIcon
+ """
+ return self.__icon
+
+ icon = qt.Property(qt.QIcon, getIcon, setIcon)
+ """Property providing access to the icon."""
+
+ def __updateVisibleIcon(self):
+ """Update the visible icon according to the state of the widget."""
+ if not self.isWaiting():
+ icon = self.__icon
+ else:
+ icon = self.__waitingIcon.currentIcon()
+ if icon is None:
+ icon = qt.QIcon()
+ qt.QPushButton.setIcon(self, icon)
+
+ def setEnabled(self, enabled):
+ """Set the enabled state of the widget.
+
+ :param bool enabled: The enabled state
+ """
+ if self.__enabled == enabled:
+ return
+ self.__enabled = enabled
+ self.__updateVisibleEnabled()
+
+ def isEnabled(self):
+ """Returns the enabled state of the widget.
+
+ :rtype: bool
+ """
+ return self.__enabled
+
+ enabled = qt.Property(bool, isEnabled, setEnabled)
+ """Property providing access to the enabled state of the widget"""
+
+ def __updateVisibleEnabled(self):
+ """Update the visible enabled state according to the state of the
+ widget."""
+ if self.__disabled_when_waiting:
+ enabled = not self.isWaiting() and self.__enabled
+ else:
+ enabled = self.__enabled
+ qt.QPushButton.setEnabled(self, enabled)
+
+ def setWaiting(self, waiting):
+ """Set the waiting state of the widget.
+
+ :param bool waiting: Requested state"""
+ if self.__waiting == waiting:
+ return
+ self.__waiting = waiting
+
+ if self.__waiting:
+ self.__waitingIcon.register(self)
+ self.__waitingIcon.iconChanged.connect(self.__setWaitingIcon)
+ else:
+ # unregister only if the object is registred
+ self.__waitingIcon.unregister(self)
+ self.__waitingIcon.iconChanged.disconnect(self.__setWaitingIcon)
+
+ self.__updateVisibleEnabled()
+ self.__updateVisibleIcon()
+
+ def isWaiting(self):
+ """Returns true if the widget is in waiting state.
+
+ :rtype: bool"""
+ return self.__waiting
+
+ @qt.Slot()
+ def wait(self):
+ """Enable the waiting state."""
+ self.setWaiting(True)
+
+ @qt.Slot()
+ def stopWaiting(self):
+ """Disable the waiting state."""
+ self.setWaiting(False)
+
+ @qt.Slot()
+ def swapWaiting(self):
+ """Swap the waiting state."""
+ self.setWaiting(not self.isWaiting())
diff --git a/src/silx/gui/widgets/__init__.py b/src/silx/gui/widgets/__init__.py
new file mode 100644
index 0000000..9d0299d
--- /dev/null
+++ b/src/silx/gui/widgets/__init__.py
@@ -0,0 +1,27 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This package provides a few simple Qt widgets that rely only on a Qt binding for Python.
+
+No other optional dependencies of *silx* should be required."""
diff --git a/src/silx/gui/widgets/setup.py b/src/silx/gui/widgets/setup.py
new file mode 100644
index 0000000..e96ac8d
--- /dev/null
+++ b/src/silx/gui/widgets/setup.py
@@ -0,0 +1,41 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "11/10/2016"
+
+
+from numpy.distutils.misc_util import Configuration
+
+
+def configuration(parent_package='', top_path=None):
+ config = Configuration('widgets', parent_package, top_path)
+ config.add_subpackage('test')
+ return config
+
+
+if __name__ == "__main__":
+ from numpy.distutils.core import setup
+ setup(configuration=configuration)
diff --git a/src/silx/gui/widgets/test/__init__.py b/src/silx/gui/widgets/test/__init__.py
new file mode 100644
index 0000000..243dbc7
--- /dev/null
+++ b/src/silx/gui/widgets/test/__init__.py
@@ -0,0 +1,24 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
diff --git a/src/silx/gui/widgets/test/test_boxlayoutdockwidget.py b/src/silx/gui/widgets/test/test_boxlayoutdockwidget.py
new file mode 100644
index 0000000..5df8df9
--- /dev/null
+++ b/src/silx/gui/widgets/test/test_boxlayoutdockwidget.py
@@ -0,0 +1,72 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Tests for BoxLayoutDockWidget"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "06/03/2018"
+
+import unittest
+
+from silx.gui.widgets.BoxLayoutDockWidget import BoxLayoutDockWidget
+from silx.gui import qt
+from silx.gui.utils.testutils import TestCaseQt
+
+
+class TestBoxLayoutDockWidget(TestCaseQt):
+ """Tests for BoxLayoutDockWidget"""
+
+ def setUp(self):
+ """Create and show a main window"""
+ self.window = qt.QMainWindow()
+ self.qWaitForWindowExposed(self.window)
+
+ def tearDown(self):
+ """Delete main window"""
+ self.window.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.window.close()
+ del self.window
+ self.qapp.processEvents()
+
+ def test(self):
+ """Test update of layout direction according to dock area"""
+ # Create a widget with a QBoxLayout
+ layout = qt.QBoxLayout(qt.QBoxLayout.LeftToRight)
+ layout.addWidget(qt.QLabel('First'))
+ layout.addWidget(qt.QLabel('Second'))
+ widget = qt.QWidget()
+ widget.setLayout(layout)
+
+ # Add it to a BoxLayoutDockWidget
+ dock = BoxLayoutDockWidget()
+ dock.setWidget(widget)
+
+ self.window.addDockWidget(qt.Qt.BottomDockWidgetArea, dock)
+ self.qapp.processEvents()
+ self.assertEqual(layout.direction(), qt.QBoxLayout.LeftToRight)
+
+ self.window.addDockWidget(qt.Qt.LeftDockWidgetArea, dock)
+ self.qapp.processEvents()
+ self.assertEqual(layout.direction(), qt.QBoxLayout.TopToBottom)
diff --git a/src/silx/gui/widgets/test/test_elidedlabel.py b/src/silx/gui/widgets/test/test_elidedlabel.py
new file mode 100644
index 0000000..693e43c
--- /dev/null
+++ b/src/silx/gui/widgets/test/test_elidedlabel.py
@@ -0,0 +1,100 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Tests for ElidedLabel"""
+
+__license__ = "MIT"
+__date__ = "08/06/2020"
+
+import unittest
+
+from silx.gui import qt
+from silx.gui.widgets.ElidedLabel import ElidedLabel
+from silx.gui.utils import testutils
+
+
+class TestElidedLabel(testutils.TestCaseQt):
+
+ def setUp(self):
+ self.label = ElidedLabel()
+ self.label.show()
+ self.qWaitForWindowExposed(self.label)
+
+ def tearDown(self):
+ self.label.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.label.close()
+ del self.label
+ self.qapp.processEvents()
+
+ def testElidedValue(self):
+ """Test elided text"""
+ raw = "mmmmmmmmmmmmmmmmmmmmmmmmmmmmmmm"
+ self.label.setText(raw)
+ self.label.setFixedWidth(30)
+ displayedText = qt.QLabel.text(self.label)
+ self.assertNotEqual(raw, displayedText)
+ self.assertIn("…", displayedText)
+ self.assertIn("m", displayedText)
+
+ def testNotElidedValue(self):
+ """Test elided text"""
+ raw = "mmmmmmm"
+ self.label.setText(raw)
+ self.label.setFixedWidth(200)
+ displayedText = qt.QLabel.text(self.label)
+ self.assertNotIn("…", displayedText)
+ self.assertEqual(raw, displayedText)
+
+ def testUpdateFromElidedToNotElided(self):
+ """Test tooltip when not elided"""
+ raw1 = "mmmmmmmmmmmmmmmmmmmmmmmmmmmmmmm"
+ raw2 = "nn"
+ self.label.setText(raw1)
+ self.label.setFixedWidth(30)
+ self.label.setText(raw2)
+ displayedTooltip = qt.QLabel.toolTip(self.label)
+ self.assertNotIn(raw1, displayedTooltip)
+ self.assertNotIn(raw2, displayedTooltip)
+
+ def testUpdateFromNotElidedToElided(self):
+ """Test tooltip when elided"""
+ raw1 = "nn"
+ raw2 = "mmmmmmmmmmmmmmmmmmmmmmmmmmmmmmm"
+ self.label.setText(raw1)
+ self.label.setFixedWidth(30)
+ self.label.setText(raw2)
+ displayedTooltip = qt.QLabel.toolTip(self.label)
+ self.assertNotIn(raw1, displayedTooltip)
+ self.assertIn(raw2, displayedTooltip)
+
+ def testUpdateFromElidedToElided(self):
+ """Test tooltip when elided"""
+ raw1 = "nnnnnnnnnnnnnnnnnnnnnnnnnnnnnnn"
+ raw2 = "mmmmmmmmmmmmmmmmmmmmmmmmmmmmmmm"
+ self.label.setText(raw1)
+ self.label.setFixedWidth(30)
+ self.label.setText(raw2)
+ displayedTooltip = qt.QLabel.toolTip(self.label)
+ self.assertNotIn(raw1, displayedTooltip)
+ self.assertIn(raw2, displayedTooltip)
diff --git a/src/silx/gui/widgets/test/test_flowlayout.py b/src/silx/gui/widgets/test/test_flowlayout.py
new file mode 100644
index 0000000..85d7cfe
--- /dev/null
+++ b/src/silx/gui/widgets/test/test_flowlayout.py
@@ -0,0 +1,66 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Tests for FlowLayout"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "02/08/2018"
+
+import unittest
+
+from silx.gui.widgets.FlowLayout import FlowLayout
+from silx.gui import qt
+from silx.gui.utils.testutils import TestCaseQt
+
+
+class TestFlowLayout(TestCaseQt):
+ """Tests for FlowLayout"""
+
+ def setUp(self):
+ """Create and show a widget"""
+ self.widget = qt.QWidget()
+ self.widget.show()
+ self.qWaitForWindowExposed(self.widget)
+
+ def tearDown(self):
+ """Delete widget"""
+ self.widget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.widget.close()
+ del self.widget
+ self.qapp.processEvents()
+
+ def test(self):
+ """Basic tests"""
+ layout = FlowLayout()
+ self.widget.setLayout(layout)
+
+ layout.addWidget(qt.QLabel('first'))
+ layout.addWidget(qt.QLabel('second'))
+ self.assertEqual(layout.count(), 2)
+
+ layout.setHorizontalSpacing(10)
+ self.assertEqual(layout.horizontalSpacing(), 10)
+ layout.setVerticalSpacing(5)
+ self.assertEqual(layout.verticalSpacing(), 5)
diff --git a/src/silx/gui/widgets/test/test_framebrowser.py b/src/silx/gui/widgets/test/test_framebrowser.py
new file mode 100644
index 0000000..8233622
--- /dev/null
+++ b/src/silx/gui/widgets/test/test_framebrowser.py
@@ -0,0 +1,62 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "23/03/2018"
+
+
+import unittest
+
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui.widgets.FrameBrowser import FrameBrowser
+
+
+class TestFrameBrowser(TestCaseQt):
+ """Test for FrameBrowser"""
+
+ def test(self):
+ """Test FrameBrowser"""
+ widget = FrameBrowser()
+ widget.show()
+ self.qWaitForWindowExposed(widget)
+
+ nFrames = 20
+ widget.setNFrames(nFrames)
+ self.assertEqual(widget.getRange(), (0, nFrames - 1))
+ self.assertEqual(widget.getValue(), 0)
+
+ range_ = -100, 100
+ widget.setRange(*range_)
+ self.assertEqual(widget.getRange(), range_)
+ self.assertEqual(widget.getValue(), range_[0])
+
+ widget.setValue(0)
+ self.assertEqual(widget.getValue(), 0)
+
+ widget.setValue(range_[1] + 100)
+ self.assertEqual(widget.getValue(), range_[1])
+
+ widget.setValue(range_[0] - 100)
+ self.assertEqual(widget.getValue(), range_[0])
diff --git a/src/silx/gui/widgets/test/test_hierarchicaltableview.py b/src/silx/gui/widgets/test/test_hierarchicaltableview.py
new file mode 100644
index 0000000..302086a
--- /dev/null
+++ b/src/silx/gui/widgets/test/test_hierarchicaltableview.py
@@ -0,0 +1,103 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "07/04/2017"
+
+import unittest
+
+from .. import HierarchicalTableView
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui import qt
+
+
+class TableModel(HierarchicalTableView.HierarchicalTableModel):
+
+ def __init__(self, parent):
+ HierarchicalTableView.HierarchicalTableModel.__init__(self, parent)
+ self.__content = {}
+
+ def rowCount(self, parent=qt.QModelIndex()):
+ return 3
+
+ def columnCount(self, parent=qt.QModelIndex()):
+ return 3
+
+ def setData1(self):
+ self.beginResetModel()
+
+ content = {}
+ content[0, 0] = ("title", True, (1, 3))
+ content[0, 1] = ("a", True, (2, 1))
+ content[1, 1] = ("b", False, (1, 2))
+ content[1, 2] = ("c", False, (1, 1))
+ content[2, 2] = ("d", False, (1, 1))
+ self.__content = content
+
+ self.endResetModel()
+
+ def data(self, index, role=qt.Qt.DisplayRole):
+ if not index.isValid():
+ return None
+ cell = self.__content.get((index.column(), index.row()), None)
+ if cell is None:
+ return None
+
+ if role == self.SpanRole:
+ return cell[2]
+ elif role == self.IsHeaderRole:
+ return cell[1]
+ elif role == qt.Qt.DisplayRole:
+ return cell[0]
+ return None
+
+
+class TestHierarchicalTableView(TestCaseQt):
+ """Test for HierarchicalTableView"""
+
+ def testEmpty(self):
+ widget = HierarchicalTableView.HierarchicalTableView()
+ widget.show()
+ self.qWaitForWindowExposed(widget)
+
+ def testModel(self):
+ widget = HierarchicalTableView.HierarchicalTableView()
+ model = TableModel(widget)
+ # set the data before using the model into the widget
+ model.setData1()
+ widget.setModel(model)
+ span = widget.rowSpan(0, 0), widget.columnSpan(0, 0)
+ self.assertEqual(span, (1, 3))
+ widget.show()
+ self.qWaitForWindowExposed(widget)
+
+ def testModelUpdate(self):
+ widget = HierarchicalTableView.HierarchicalTableView()
+ model = TableModel(widget)
+ widget.setModel(model)
+ # set the data after using the model into the widget
+ model.setData1()
+ span = widget.rowSpan(0, 0), widget.columnSpan(0, 0)
+ self.assertEqual(span, (1, 3))
diff --git a/src/silx/gui/widgets/test/test_legendiconwidget.py b/src/silx/gui/widgets/test/test_legendiconwidget.py
new file mode 100644
index 0000000..fe320f6
--- /dev/null
+++ b/src/silx/gui/widgets/test/test_legendiconwidget.py
@@ -0,0 +1,63 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Tests for LegendIconWidget"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "23/10/2020"
+
+import unittest
+
+from silx.gui import qt
+from silx.gui.widgets.LegendIconWidget import LegendIconWidget
+from silx.gui.utils.testutils import TestCaseQt
+from silx.utils.testutils import ParametricTestCase
+
+
+class TestLegendIconWidget(TestCaseQt, ParametricTestCase):
+ """Tests for TestRangeSlider"""
+
+ def setUp(self):
+ self.widget = LegendIconWidget()
+ self.widget.show()
+ self.qWaitForWindowExposed(self.widget)
+
+ def tearDown(self):
+ self.widget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.widget.close()
+ del self.widget
+ self.qapp.processEvents()
+
+ def testCreate(self):
+ self.qapp.processEvents()
+
+ def testColormap(self):
+ self.widget.setColormap("viridis")
+ self.qapp.processEvents()
+
+ def testSymbol(self):
+ self.widget.setSymbol("o")
+ self.widget.setSymbolColormap("viridis")
+ self.qapp.processEvents()
diff --git a/src/silx/gui/widgets/test/test_periodictable.py b/src/silx/gui/widgets/test/test_periodictable.py
new file mode 100644
index 0000000..de9e1af
--- /dev/null
+++ b/src/silx/gui/widgets/test/test_periodictable.py
@@ -0,0 +1,148 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "05/12/2016"
+
+import unittest
+
+from .. import PeriodicTable
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui import qt
+
+
+class TestPeriodicTable(TestCaseQt):
+ """Basic test for ArrayTableWidget with a numpy array"""
+
+ def testShow(self):
+ """basic test (instantiation done in setUp)"""
+ pt = PeriodicTable.PeriodicTable()
+ pt.show()
+ self.qWaitForWindowExposed(pt)
+
+ def testSelectable(self):
+ """basic test (instantiation done in setUp)"""
+ pt = PeriodicTable.PeriodicTable(selectable=True)
+ self.assertTrue(pt.selectable)
+
+ def testCustomElements(self):
+ PTI = PeriodicTable.ColoredPeriodicTableItem
+ my_items = [
+ PTI("Xx", 42, 43, 44, "xaxatorium", 1002.2,
+ bgcolor="#FF0000"),
+ PTI("Yy", 25, 22, 44, "yoyotrium", 8.8)
+ ]
+
+ pt = PeriodicTable.PeriodicTable(elements=my_items)
+
+ pt.setSelection(["He", "Xx"])
+ selection = pt.getSelection()
+ self.assertEqual(len(selection), 1) # "He" not found
+ self.assertEqual(selection[0].symbol, "Xx")
+ self.assertEqual(selection[0].Z, 42)
+ self.assertEqual(selection[0].col, 43)
+ self.assertAlmostEqual(selection[0].mass, 1002.2)
+ self.assertEqual(qt.QColor(selection[0].bgcolor),
+ qt.QColor(qt.Qt.red))
+
+ self.assertTrue(pt.isElementSelected("Xx"))
+ self.assertFalse(pt.isElementSelected("Yy"))
+ self.assertRaises(KeyError, pt.isElementSelected, "Yx")
+
+ def testVeryCustomElements(self):
+ class MyPTI(PeriodicTable.PeriodicTableItem):
+ def __init__(self, *args):
+ PeriodicTable.PeriodicTableItem.__init__(self, *args[:6])
+ self.my_feature = args[6]
+
+ my_items = [
+ MyPTI("Xx", 42, 43, 44, "xaxatorium", 1002.2, "spam"),
+ MyPTI("Yy", 25, 22, 44, "yoyotrium", 8.8, "eggs")
+ ]
+
+ pt = PeriodicTable.PeriodicTable(elements=my_items)
+
+ pt.setSelection(["Xx", "Yy"])
+ selection = pt.getSelection()
+ self.assertEqual(len(selection), 2)
+ self.assertEqual(selection[1].symbol, "Yy")
+ self.assertEqual(selection[1].Z, 25)
+ self.assertEqual(selection[1].col, 22)
+ self.assertEqual(selection[1].row, 44)
+ self.assertAlmostEqual(selection[0].mass, 1002.2)
+ self.assertAlmostEqual(selection[0].my_feature, "spam")
+
+
+class TestPeriodicCombo(TestCaseQt):
+ """Basic test for ArrayTableWidget with a numpy array"""
+ def setUp(self):
+ super(TestPeriodicCombo, self).setUp()
+ self.pc = PeriodicTable.PeriodicCombo()
+
+ def tearDown(self):
+ del self.pc
+ super(TestPeriodicCombo, self).tearDown()
+
+ def testShow(self):
+ """basic test (instantiation done in setUp)"""
+ self.pc.show()
+ self.qWaitForWindowExposed(self.pc)
+
+ def testSelect(self):
+ self.pc.setSelection("Sb")
+ selection = self.pc.getSelection()
+ self.assertIsInstance(selection,
+ PeriodicTable.PeriodicTableItem)
+ self.assertEqual(selection.symbol, "Sb")
+ self.assertEqual(selection.Z, 51)
+ self.assertEqual(selection.name, "antimony")
+
+
+class TestPeriodicList(TestCaseQt):
+ """Basic test for ArrayTableWidget with a numpy array"""
+ def setUp(self):
+ super(TestPeriodicList, self).setUp()
+ self.pl = PeriodicTable.PeriodicList()
+
+ def tearDown(self):
+ del self.pl
+ super(TestPeriodicList, self).tearDown()
+
+ def testShow(self):
+ """basic test (instantiation done in setUp)"""
+ self.pl.show()
+ self.qWaitForWindowExposed(self.pl)
+
+ def testSelect(self):
+ self.pl.setSelectedElements(["Li", "He", "Au"])
+ sel_elmts = self.pl.getSelection()
+
+ self.assertEqual(len(sel_elmts), 3,
+ "Wrong number of elements selected")
+ for e in sel_elmts:
+ self.assertIsInstance(e, PeriodicTable.PeriodicTableItem)
+ self.assertIn(e.symbol, ["Li", "He", "Au"])
+ self.assertIn(e.Z, [2, 3, 79])
+ self.assertIn(e.name, ["lithium", "helium", "gold"])
diff --git a/src/silx/gui/widgets/test/test_printpreview.py b/src/silx/gui/widgets/test/test_printpreview.py
new file mode 100644
index 0000000..8602666
--- /dev/null
+++ b/src/silx/gui/widgets/test/test_printpreview.py
@@ -0,0 +1,63 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test PrintPreview"""
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "19/07/2017"
+
+
+import unittest
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui.widgets.PrintPreview import PrintPreviewDialog
+from silx.gui import qt
+
+from silx.resources import resource_filename
+
+
+class TestPrintPreview(TestCaseQt):
+ def testShow(self):
+ p = qt.QPrinter()
+ d = PrintPreviewDialog(printer=p)
+ d.show()
+ self.qapp.processEvents()
+
+ def testAddImage(self):
+ p = qt.QPrinter()
+ d = PrintPreviewDialog(printer=p)
+ d.addImage(qt.QImage(resource_filename("gui/icons/clipboard.png")))
+ self.qapp.processEvents()
+
+ def testAddSvg(self):
+ p = qt.QPrinter()
+ d = PrintPreviewDialog(printer=p)
+ d.addSvgItem(qt.QSvgRenderer(resource_filename("gui/icons/clipboard.svg"), d.page))
+ self.qapp.processEvents()
+
+ def testAddPixmap(self):
+ p = qt.QPrinter()
+ d = PrintPreviewDialog(printer=p)
+ d.addPixmap(qt.QPixmap.fromImage(qt.QImage(resource_filename("gui/icons/clipboard.png"))))
+ self.qapp.processEvents()
diff --git a/src/silx/gui/widgets/test/test_rangeslider.py b/src/silx/gui/widgets/test/test_rangeslider.py
new file mode 100644
index 0000000..f829857
--- /dev/null
+++ b/src/silx/gui/widgets/test/test_rangeslider.py
@@ -0,0 +1,103 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Tests for RangeSlider"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "01/08/2018"
+
+import unittest
+
+from silx.gui import qt, colors
+from silx.gui.widgets.RangeSlider import RangeSlider
+from silx.gui.utils.testutils import TestCaseQt
+from silx.utils.testutils import ParametricTestCase
+
+
+class TestRangeSlider(TestCaseQt, ParametricTestCase):
+ """Tests for TestRangeSlider"""
+
+ def setUp(self):
+ self.slider = RangeSlider()
+ self.slider.show()
+ self.qWaitForWindowExposed(self.slider)
+
+ def tearDown(self):
+ self.slider.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.slider.close()
+ del self.slider
+ self.qapp.processEvents()
+
+ def testRangeValue(self):
+ """Test slider range and values"""
+
+ # Play with range
+ self.slider.setRange(1, 2)
+ self.assertEqual(self.slider.getRange(), (1., 2.))
+ self.assertEqual(self.slider.getValues(), (1., 1.))
+
+ self.slider.setMinimum(-1)
+ self.assertEqual(self.slider.getRange(), (-1., 2.))
+ self.assertEqual(self.slider.getValues(), (1., 1.))
+
+ self.slider.setMaximum(0)
+ self.assertEqual(self.slider.getRange(), (-1., 0.))
+ self.assertEqual(self.slider.getValues(), (0., 0.))
+
+ # Play with values
+ self.slider.setFirstValue(-2.)
+ self.assertEqual(self.slider.getValues(), (-1., 0.))
+
+ self.slider.setFirstValue(-0.5)
+ self.assertEqual(self.slider.getValues(), (-0.5, 0.))
+
+ self.slider.setSecondValue(2.)
+ self.assertEqual(self.slider.getValues(), (-0.5, 0.))
+
+ self.slider.setSecondValue(-0.1)
+ self.assertEqual(self.slider.getValues(), (-0.5, -0.1))
+
+ def testStepCount(self):
+ """Test related to step count"""
+ self.slider.setPositionCount(11)
+ self.assertEqual(self.slider.getPositionCount(), 11)
+ self.slider.setFirstValue(0.32)
+ self.assertEqual(self.slider.getFirstValue(), 0.3)
+ self.assertEqual(self.slider.getFirstPosition(), 3)
+
+ self.slider.setPositionCount(3) # Value is adjusted
+ self.assertEqual(self.slider.getValues(), (0.5, 1.))
+ self.assertEqual(self.slider.getPositions(), (1, 2))
+
+ def testGroove(self):
+ """Test Groove pixmap"""
+ profile = list(range(100))
+
+ for cmap in ('jet', colors.Colormap('viridis')):
+ with self.subTest(str(cmap)):
+ self.slider.setGroovePixmapFromProfile(profile, cmap)
+ pixmap = self.slider.getGroovePixmap()
+ self.assertIsInstance(pixmap, qt.QPixmap)
+ self.assertEqual(pixmap.width(), len(profile))
diff --git a/src/silx/gui/widgets/test/test_tablewidget.py b/src/silx/gui/widgets/test/test_tablewidget.py
new file mode 100644
index 0000000..09122ca
--- /dev/null
+++ b/src/silx/gui/widgets/test/test_tablewidget.py
@@ -0,0 +1,50 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test TableWidget"""
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "05/12/2016"
+
+
+import unittest
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui.widgets.TableWidget import TableWidget
+
+
+class TestTableWidget(TestCaseQt):
+ def setUp(self):
+ super(TestTableWidget, self).setUp()
+ self._result = []
+
+ def testShow(self):
+ table = TableWidget()
+ table.setColumnCount(10)
+ table.setRowCount(7)
+ table.enableCut()
+ table.enablePaste()
+ table.show()
+ table.hide()
+ self.qapp.processEvents()
diff --git a/src/silx/gui/widgets/test/test_threadpoolpushbutton.py b/src/silx/gui/widgets/test/test_threadpoolpushbutton.py
new file mode 100644
index 0000000..3808be0
--- /dev/null
+++ b/src/silx/gui/widgets/test/test_threadpoolpushbutton.py
@@ -0,0 +1,124 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test for silx.gui.hdf5 module"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "17/01/2018"
+
+
+import unittest
+import time
+from silx.gui import qt
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui.utils.testutils import SignalListener
+from silx.gui.widgets.ThreadPoolPushButton import ThreadPoolPushButton
+from silx.utils.testutils import LoggingValidator
+
+
+class TestThreadPoolPushButton(TestCaseQt):
+
+ def setUp(self):
+ super(TestThreadPoolPushButton, self).setUp()
+ self._result = []
+
+ def waitForPendingOperations(self, object):
+ for i in range(50):
+ if not object.hasPendingOperations():
+ break
+ self.qWait(10)
+ else:
+ raise RuntimeError("Still waiting for a pending operation")
+
+ def _trace(self, name, delay=0):
+ self._result.append(name)
+ if delay != 0:
+ time.sleep(delay / 1000.0)
+
+ def _compute(self):
+ return "result"
+
+ def _computeFail(self):
+ raise Exception("exception")
+
+ def testExecute(self):
+ button = ThreadPoolPushButton()
+ button.setCallable(self._trace, "a", 0)
+ button.executeCallable()
+ time.sleep(0.1)
+ self.assertListEqual(self._result, ["a"])
+ self.waitForPendingOperations(button)
+
+ def testMultiExecution(self):
+ button = ThreadPoolPushButton()
+ button.setCallable(self._trace, "a", 0)
+ number = qt.silxGlobalThreadPool().maxThreadCount()
+ for _ in range(number):
+ button.executeCallable()
+ self.waitForPendingOperations(button)
+ self.assertListEqual(self._result, ["a"] * number)
+
+ def testSaturateThreadPool(self):
+ button = ThreadPoolPushButton()
+ button.setCallable(self._trace, "a", 100)
+ number = qt.silxGlobalThreadPool().maxThreadCount() * 2
+ for _ in range(number):
+ button.executeCallable()
+ self.waitForPendingOperations(button)
+ self.assertListEqual(self._result, ["a"] * number)
+
+ def testSuccess(self):
+ listener = SignalListener()
+ button = ThreadPoolPushButton()
+ button.setCallable(self._compute)
+ button.beforeExecuting.connect(listener.partial(test="be"))
+ button.started.connect(listener.partial(test="s"))
+ button.succeeded.connect(listener.partial(test="result"))
+ button.failed.connect(listener.partial(test="Unexpected exception"))
+ button.finished.connect(listener.partial(test="f"))
+ button.executeCallable()
+ self.qapp.processEvents()
+ time.sleep(0.1)
+ self.qapp.processEvents()
+ result = listener.karguments(argumentName="test")
+ self.assertListEqual(result, ["be", "s", "result", "f"])
+
+ def testFail(self):
+ listener = SignalListener()
+ button = ThreadPoolPushButton()
+ button.setCallable(self._computeFail)
+ button.beforeExecuting.connect(listener.partial(test="be"))
+ button.started.connect(listener.partial(test="s"))
+ button.succeeded.connect(listener.partial(test="Unexpected success"))
+ button.failed.connect(listener.partial(test="exception"))
+ button.finished.connect(listener.partial(test="f"))
+ with LoggingValidator('silx.gui.widgets.ThreadPoolPushButton', error=1):
+ button.executeCallable()
+ self.qapp.processEvents()
+ time.sleep(0.1)
+ self.qapp.processEvents()
+ result = listener.karguments(argumentName="test")
+ self.assertListEqual(result, ["be", "s", "exception", "f"])
+ listener.clear()