summaryrefslogtreecommitdiff
path: root/silx/gui/plot
diff options
context:
space:
mode:
Diffstat (limited to 'silx/gui/plot')
-rw-r--r--silx/gui/plot/ColorBar.py19
-rw-r--r--silx/gui/plot/CompareImages.py2
-rw-r--r--silx/gui/plot/ComplexImageView.py11
-rw-r--r--silx/gui/plot/CurvesROIWidget.py1834
-rw-r--r--silx/gui/plot/MaskToolsWidget.py108
-rw-r--r--silx/gui/plot/PlotInteraction.py148
-rw-r--r--silx/gui/plot/PlotToolButtons.py29
-rw-r--r--silx/gui/plot/PlotWidget.py492
-rw-r--r--silx/gui/plot/PlotWindow.py87
-rw-r--r--silx/gui/plot/PrintPreviewToolButton.py61
-rw-r--r--silx/gui/plot/Profile.py109
-rw-r--r--silx/gui/plot/ScatterMaskToolsWidget.py65
-rw-r--r--silx/gui/plot/ScatterView.py12
-rw-r--r--silx/gui/plot/StackView.py10
-rw-r--r--silx/gui/plot/StatsWidget.py1594
-rw-r--r--silx/gui/plot/_BaseMaskToolsWidget.py157
-rw-r--r--silx/gui/plot/_utils/dtime_ticklayout.py4
-rw-r--r--silx/gui/plot/actions/control.py11
-rw-r--r--silx/gui/plot/actions/io.py38
-rw-r--r--silx/gui/plot/backends/BackendBase.py37
-rw-r--r--silx/gui/plot/backends/BackendMatplotlib.py223
-rw-r--r--silx/gui/plot/backends/BackendOpenGL.py364
-rw-r--r--silx/gui/plot/backends/glutils/GLPlotCurve.py119
-rw-r--r--silx/gui/plot/backends/glutils/GLPlotFrame.py124
-rw-r--r--silx/gui/plot/backends/glutils/GLSupport.py63
-rw-r--r--silx/gui/plot/items/__init__.py2
-rw-r--r--silx/gui/plot/items/axis.py6
-rw-r--r--silx/gui/plot/items/complex.py8
-rw-r--r--silx/gui/plot/items/core.py37
-rw-r--r--silx/gui/plot/items/curve.py5
-rw-r--r--silx/gui/plot/items/histogram.py6
-rw-r--r--silx/gui/plot/items/roi.py72
-rw-r--r--silx/gui/plot/items/scatter.py5
-rw-r--r--silx/gui/plot/items/shape.py45
-rw-r--r--silx/gui/plot/matplotlib/Colormap.py16
-rw-r--r--silx/gui/plot/stats/stats.py400
-rw-r--r--silx/gui/plot/stats/statshandler.py124
-rw-r--r--silx/gui/plot/test/testCurvesROIWidget.py219
-rw-r--r--silx/gui/plot/test/testMaskToolsWidget.py7
-rw-r--r--silx/gui/plot/test/testPlotWidget.py61
-rw-r--r--silx/gui/plot/test/testSaveAction.py20
-rw-r--r--silx/gui/plot/test/testScatterMaskToolsWidget.py5
-rw-r--r--silx/gui/plot/test/testStats.py284
-rw-r--r--silx/gui/plot/test/testUtilsAxis.py49
-rw-r--r--silx/gui/plot/tools/roi.py20
-rw-r--r--silx/gui/plot/tools/test/testScatterProfileToolBar.py2
-rw-r--r--silx/gui/plot/utils/axis.py288
47 files changed, 4783 insertions, 2619 deletions
diff --git a/silx/gui/plot/ColorBar.py b/silx/gui/plot/ColorBar.py
index fd4d34e..9798123 100644
--- a/silx/gui/plot/ColorBar.py
+++ b/silx/gui/plot/ColorBar.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -251,10 +251,13 @@ class ColorBarWidget(qt.QWidget):
def _defaultColormapChanged(self, event):
"""Handle plot default colormap changed"""
- if (event['event'] == 'defaultColormapChanged' and
- self.getPlot().getActiveImage() is None):
- # No active image, take default colormap update into account
- self._syncWithDefaultColormap()
+ if event['event'] == 'defaultColormapChanged':
+ plot = self.getPlot()
+ if (plot is not None and
+ plot.getActiveImage() is None and
+ plot._getActiveItem(kind='scatter') is None):
+ # No active item, take default colormap update into account
+ self._syncWithDefaultColormap()
def _syncWithDefaultColormap(self, data=None):
"""Update colorbar according to plot default colormap"""
@@ -801,7 +804,7 @@ class _TickBar(qt.QWidget):
if self._norm == colors.Colormap.LINEAR:
return 1 - (val - self._vmin) / (self._vmax - self._vmin)
elif self._norm == colors.Colormap.LOGARITHM:
- return 1 - (numpy.log10(val) - numpy.log10(self._vmin)) / (numpy.log10(self._vmax) - numpy.log(self._vmin))
+ return 1 - (numpy.log10(val) - numpy.log10(self._vmin)) / (numpy.log10(self._vmax) - numpy.log10(self._vmin))
else:
raise ValueError('Norm is not recognized')
@@ -864,7 +867,7 @@ class _TickBar(qt.QWidget):
def _guessType(self, font):
"""Try fo find the better format to display the tick's labels
- :param QFont font: the font we want want to use durint the painting
+ :param QFont font: the font we want to use during the painting
"""
form = self._getStandardFormat()
@@ -873,7 +876,7 @@ class _TickBar(qt.QWidget):
for tick in self.ticks:
width = max(fm.width(form.format(tick)), width)
- # if the length of the string are too long we are mooving to scientific
+ # if the length of the string are too long we are moving to scientific
# display
if width > _TickBar._WIDTH_DISP_VAL - _TickBar._LINE_WIDTH:
return self._getScientificForm()
diff --git a/silx/gui/plot/CompareImages.py b/silx/gui/plot/CompareImages.py
index 88b257d..f7c4899 100644
--- a/silx/gui/plot/CompareImages.py
+++ b/silx/gui/plot/CompareImages.py
@@ -30,6 +30,7 @@ __license__ = "MIT"
__date__ = "23/07/2018"
+import enum
import logging
import numpy
import weakref
@@ -42,7 +43,6 @@ from silx.gui import plot
from silx.gui import icons
from silx.gui.colors import Colormap
from silx.gui.plot import tools
-from silx.third_party import enum
_logger = logging.getLogger(__name__)
diff --git a/silx/gui/plot/ComplexImageView.py b/silx/gui/plot/ComplexImageView.py
index bbcb0a5..2523cde 100644
--- a/silx/gui/plot/ComplexImageView.py
+++ b/silx/gui/plot/ComplexImageView.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -365,7 +365,7 @@ class ComplexImageView(qt.QWidget):
- log10_amplitude_phase:
Color-coded phase with log10(amplitude) as alpha.
- :rtype: tuple of str
+ :rtype: List[Mode]
"""
return tuple(ImageComplexData.Mode)
@@ -375,7 +375,12 @@ class ComplexImageView(qt.QWidget):
See :meth:`getSupportedVisualizationModes` for the list of
supported modes.
- :param str mode: The mode to use.
+ How-to change visualization mode::
+
+ widget = ComplexImageView()
+ widget.setVisualizationMode(ComplexImageView.Mode.PHASE)
+
+ :param Mode mode: The mode to use.
"""
self._plotImage.setVisualizationMode(mode)
diff --git a/silx/gui/plot/CurvesROIWidget.py b/silx/gui/plot/CurvesROIWidget.py
index 81e684e..b426a23 100644
--- a/silx/gui/plot/CurvesROIWidget.py
+++ b/silx/gui/plot/CurvesROIWidget.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2004-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2004-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -22,50 +22,43 @@
# THE SOFTWARE.
#
# ###########################################################################*/
-"""Widget to handle regions of interest (ROI) on curves displayed in a PlotWindow.
+"""
+Widget to handle regions of interest (:class:`ROI`) on curves displayed in a
+:class:`PlotWindow`.
This widget is meant to work with :class:`PlotWindow`.
-
-ROI are defined by :
-
-- A name (`ROI` column)
-- A type. The type is the label of the x axis.
- This can be used to apply or not some ROI to a curve and do some post processing.
-- The x coordinate of the left limit (`from` column)
-- The x coordinate of the right limit (`to` column)
-- Raw counts: Sum of the curve's values in the defined Region Of Intereset.
-
- .. image:: img/rawCounts.png
-
-- Net counts: Raw counts minus background
-
- .. image:: img/netCounts.png
"""
-__authors__ = ["V.A. Sole", "T. Vincent"]
+__authors__ = ["V.A. Sole", "T. Vincent", "H. Payno"]
__license__ = "MIT"
-__date__ = "13/11/2017"
+__date__ = "13/03/2018"
from collections import OrderedDict
-
import logging
import os
import sys
-import weakref
-
+import functools
import numpy
-
from silx.io import dictdump
from silx.utils import deprecation
-
+from silx.utils.weakref import WeakMethodProxy
from .. import icons, qt
+from silx.gui.plot.items.curve import Curve
+from silx.math.combo import min_max
+import weakref
+from silx.gui.widgets.TableWidget import TableWidget
_logger = logging.getLogger(__name__)
class CurvesROIWidget(qt.QWidget):
- """Widget displaying a table of ROI information.
+ """
+ Widget displaying a table of ROI information.
+
+ Implements also the following behavior:
+
+ * if the roiTable has no ROI when showing create the default ICR one
:param parent: See :class:`QWidget`
:param str name: The title of this widget
@@ -73,19 +66,18 @@ class CurvesROIWidget(qt.QWidget):
sigROIWidgetSignal = qt.Signal(object)
"""Signal of ROIs modifications.
-
- Modification information if given as a dict with an 'event' key
- providing the type of events.
-
- Type of events:
-
- - AddROI, DelROI, LoadROI and ResetROI with keys: 'roilist', 'roidict'
-
- - selectionChanged with keys: 'row', 'col' 'roi', 'key', 'colheader',
- 'rowheader'
+ Modification information if given as a dict with an 'event' key
+ providing the type of events.
+ Type of events:
+ - AddROI, DelROI, LoadROI and ResetROI with keys: 'roilist', 'roidict'
+ - selectionChanged with keys: 'row', 'col' 'roi', 'key', 'colheader',
+ 'rowheader'
"""
sigROISignal = qt.Signal(object)
+ """Deprecated signal for backward compatibility with silx < 0.7.
+ Prefer connecting directly to :attr:`CurvesRoiWidget.sigRoiSignal`
+ """
def __init__(self, parent=None, name=None, plot=None):
super(CurvesROIWidget, self).__init__(parent)
@@ -93,6 +85,8 @@ class CurvesROIWidget(qt.QWidget):
self.setWindowTitle(name)
assert plot is not None
self._plotRef = weakref.ref(plot)
+ self._showAllMarkers = False
+ self.currentROI = None
layout = qt.QVBoxLayout(self)
layout.setContentsMargins(0, 0, 0, 0)
@@ -103,13 +97,22 @@ class CurvesROIWidget(qt.QWidget):
self.setHeader()
layout.addWidget(self.headerLabel)
##############
- self.roiTable = ROITable(self)
+ widgetAllCheckbox = qt.QWidget(parent=self)
+ self._showAllCheckBox = qt.QCheckBox("show all ROI",
+ parent=widgetAllCheckbox)
+ widgetAllCheckbox.setLayout(qt.QHBoxLayout())
+ spacer = qt.QWidget(parent=widgetAllCheckbox)
+ spacer.setSizePolicy(qt.QSizePolicy.Expanding, qt.QSizePolicy.Fixed)
+ widgetAllCheckbox.layout().addWidget(spacer)
+ widgetAllCheckbox.layout().addWidget(self._showAllCheckBox)
+ layout.addWidget(widgetAllCheckbox)
+ ##############
+ self.roiTable = ROITable(self, plot=plot)
rheight = self.roiTable.horizontalHeader().sizeHint().height()
self.roiTable.setMinimumHeight(4 * rheight)
- self.fillFromROIDict = self.roiTable.fillFromROIDict
- self.getROIListAndDict = self.roiTable.getROIListAndDict
layout.addWidget(self.roiTable)
self._roiFileDir = qt.QDir.home().absolutePath()
+ self._showAllCheckBox.toggled.connect(self.roiTable.showAllMarkers)
#################
hbox = qt.QWidget(self)
@@ -127,7 +130,8 @@ class CurvesROIWidget(qt.QWidget):
self.addButton.setToolTip('Remove the selected ROI')
self.resetButton = qt.QPushButton(hbox)
self.resetButton.setText("Reset")
- self.addButton.setToolTip('Clear all created ROIs. We only let the default ROI')
+ self.addButton.setToolTip('Clear all created ROIs. We only let the '
+ 'default ROI')
hboxlayout.addWidget(self.addButton)
hboxlayout.addWidget(self.delButton)
@@ -149,19 +153,22 @@ class CurvesROIWidget(qt.QWidget):
layout.addWidget(hbox)
+ # Signal / Slot connections
self.addButton.clicked.connect(self._add)
self.delButton.clicked.connect(self._del)
self.resetButton.clicked.connect(self._reset)
self.loadButton.clicked.connect(self._load)
self.saveButton.clicked.connect(self._save)
- self.roiTable.sigROITableSignal.connect(self._forward)
- self.currentROI = None
- self._middleROIMarkerFlag = False
+ self.roiTable.activeROIChanged.connect(self._emitCurrentROISignal)
+
self._isConnected = False # True if connected to plot signals
self._isInit = False
+ # expose API
+ self.getROIListAndDict = self.roiTable.getROIListAndDict
+
def getPlotWidget(self):
"""Returns the associated PlotWidget or None
@@ -173,10 +180,6 @@ class CurvesROIWidget(qt.QWidget):
self._visibilityChangedHandler(visible=True)
qt.QWidget.showEvent(self, event)
- def hideEvent(self, event):
- self._visibilityChangedHandler(visible=False)
- qt.QWidget.hideEvent(self, event)
-
@property
def roiFileDir(self):
"""The directory from which to load/save ROI from/to files."""
@@ -188,135 +191,81 @@ class CurvesROIWidget(qt.QWidget):
def roiFileDir(self, roiFileDir):
self._roiFileDir = str(roiFileDir)
- def setRois(self, roidict, order=None):
- """Set the ROIs by providing a dictionary of ROI information.
-
- The dictionary keys are the ROI names.
- Each value is a sub-dictionary of ROI info with the following fields:
-
- - ``"from"``: x coordinate of the left limit, as a float
- - ``"to"``: x coordinate of the right limit, as a float
- - ``"type"``: type of ROI, as a string (e.g "channels", "energy")
-
-
- :param roidict: Dictionary of ROIs
- :param str order: Field used for ordering the ROIs.
- One of "from", "to", "type".
- None (default) for no ordering, or same order as specified
- in parameter ``roidict`` if provided as an OrderedDict.
- """
- if order is None or order.lower() == "none":
- roilist = list(roidict.keys())
- else:
- assert order in ["from", "to", "type"]
- roilist = sorted(roidict.keys(),
- key=lambda roi_name: roidict[roi_name].get(order))
-
- return self.roiTable.fillFromROIDict(roilist, roidict)
+ def setRois(self, rois, order=None):
+ return self.roiTable.setRois(rois, order)
def getRois(self, order=None):
- """Return the currently defined ROIs, as an ordered dict.
+ return self.roiTable.getRois(order)
- The dictionary keys are the ROI names.
- Each value is a sub-dictionary of ROI info with the following fields:
+ def setMiddleROIMarkerFlag(self, flag=True):
+ return self.roiTable.setMiddleROIMarkerFlag(flag)
- - ``"from"``: x coordinate of the left limit, as a float
- - ``"to"``: x coordinate of the right limit, as a float
- - ``"type"``: type of ROI, as a string (e.g "channels", "energy")
+ def _add(self):
+ """Add button clicked handler"""
+ def getNextRoiName():
+ rois = self.roiTable.getRois(order=None)
+ roisNames = []
+ [roisNames.append(roiName) for roiName in rois]
+ nrois = len(rois)
+ if nrois == 0:
+ return "ICR"
+ else:
+ i = 1
+ newroi = "newroi %d" % i
+ while newroi in roisNames:
+ i += 1
+ newroi = "newroi %d" % i
+ return newroi
+ roi = ROI(name=getNextRoiName())
- :param order: Field used for ordering the ROIs.
- One of "from", "to", "type", "netcounts", "rawcounts".
- None (default) to get the same order as displayed in the widget.
- :return: Ordered dictionary of ROI information
- """
- roilist, roidict = self.roiTable.getROIListAndDict()
- if order is None or order.lower() == "none":
- ordered_roilist = roilist
+ if roi.getName() == "ICR":
+ roi.setType("Default")
else:
- assert order in ["from", "to", "type", "netcounts", "rawcounts"]
- ordered_roilist = sorted(roidict.keys(),
- key=lambda roi_name: roidict[roi_name].get(order))
-
- return OrderedDict([(name, roidict[name]) for name in ordered_roilist])
+ roi.setType(self.getPlotWidget().getXAxis().getLabel())
- def setMiddleROIMarkerFlag(self, flag=True):
- """Activate or deactivate middle marker.
+ xmin, xmax = self.getPlotWidget().getXAxis().getLimits()
+ fromdata = xmin + 0.25 * (xmax - xmin)
+ todata = xmin + 0.75 * (xmax - xmin)
+ if roi.isICR():
+ fromdata, dummy0, todata, dummy1 = self._getAllLimits()
+ roi.setFrom(fromdata)
+ roi.setTo(todata)
- This allows shifting both min and max limits at once, by dragging
- a marker located in the middle.
-
- :param bool flag: True to activate middle ROI marker
- """
- if flag:
- self._middleROIMarkerFlag = True
- else:
- self._middleROIMarkerFlag = False
+ self.roiTable.addRoi(roi)
- def _add(self):
- """Add button clicked handler"""
+ # back compatibility pymca roi signals
ddict = {}
ddict['event'] = "AddROI"
- roilist, roidict = self.roiTable.getROIListAndDict()
- ddict['roilist'] = roilist
- ddict['roidict'] = roidict
+ ddict['roilist'] = self.roiTable.roidict.values()
+ ddict['roidict'] = self.roiTable.roidict
self.sigROIWidgetSignal.emit(ddict)
+ # end back compatibility pymca roi signals
def _del(self):
"""Delete button clicked handler"""
- row = self.roiTable.currentRow()
- if row >= 0:
- index = self.roiTable.labels.index('Type')
- text = str(self.roiTable.item(row, index).text())
- if text.upper() != 'DEFAULT':
- index = self.roiTable.labels.index('ROI')
- key = str(self.roiTable.item(row, index).text())
- else:
- # This is to prevent deleting ICR ROI, that is
- # usually initialized as "Default" type.
- return
- roilist, roidict = self.roiTable.getROIListAndDict()
- row = roilist.index(key)
- del roilist[row]
- del roidict[key]
- if len(roilist) > 0:
- currentroi = roilist[0]
- else:
- currentroi = None
-
- self.roiTable.fillFromROIDict(roilist=roilist,
- roidict=roidict,
- currentroi=currentroi)
- ddict = {}
- ddict['event'] = "DelROI"
- ddict['roilist'] = roilist
- ddict['roidict'] = roidict
- self.sigROIWidgetSignal.emit(ddict)
-
- def _forward(self, ddict):
- """Broadcast events from ROITable signal"""
+ self.roiTable.deleteActiveRoi()
+
+ # back compatibility pymca roi signals
+ ddict = {}
+ ddict['event'] = "DelROI"
+ ddict['roilist'] = self.roiTable.roidict.values()
+ ddict['roidict'] = self.roiTable.roidict
self.sigROIWidgetSignal.emit(ddict)
+ # end back compatibility pymca roi signals
def _reset(self):
"""Reset button clicked handler"""
+ self.roiTable.clear()
+ self._add()
+
+ # back compatibility pymca roi signals
ddict = {}
ddict['event'] = "ResetROI"
- roilist0, roidict0 = self.roiTable.getROIListAndDict()
- index = 0
- for key in roilist0:
- if roidict0[key]['type'].upper() == 'DEFAULT':
- index = roilist0.index(key)
- break
- roilist = []
- roidict = {}
- if len(roilist0):
- roilist.append(roilist0[index])
- roidict[roilist[0]] = {}
- roidict[roilist[0]].update(roidict0[roilist[0]])
- self.roiTable.fillFromROIDict(roilist=roilist, roidict=roidict)
- ddict['roilist'] = roilist
- ddict['roidict'] = roidict
+ ddict['roilist'] = self.roiTable.roidict.values()
+ ddict['roidict'] = self.roiTable.roidict
self.sigROIWidgetSignal.emit(ddict)
+ # end back compatibility pymca roi signals
def _load(self):
"""Load button clicked handler"""
@@ -334,32 +283,22 @@ class CurvesROIWidget(qt.QWidget):
dialog.close()
self.roiFileDir = os.path.dirname(outputFile)
- self.load(outputFile)
+ self.roiTable.load(outputFile)
+
+ # back compatibility pymca roi signals
+ ddict = {}
+ ddict['event'] = "LoadROI"
+ ddict['roilist'] = self.roiTable.roidict.values()
+ ddict['roidict'] = self.roiTable.roidict
+ self.sigROIWidgetSignal.emit(ddict)
+ # end back compatibility pymca roi signals
def load(self, filename):
"""Load ROI widget information from a file storing a dict of ROI.
:param str filename: The file from which to load ROI
"""
- rois = dictdump.load(filename)
- currentROI = None
- if self.roiTable.rowCount():
- item = self.roiTable.item(self.roiTable.currentRow(), 0)
- if item is not None:
- currentROI = str(item.text())
-
- # Remove rawcounts and netcounts from ROIs
- for roi in rois['ROI']['roidict'].values():
- roi.pop('rawcounts', None)
- roi.pop('netcounts', None)
-
- self.roiTable.fillFromROIDict(roilist=rois['ROI']['roilist'],
- roidict=rois['ROI']['roidict'],
- currentroi=currentROI)
-
- roilist, roidict = self.roiTable.getROIListAndDict()
- event = {'event': 'LoadROI', 'roilist': roilist, 'roidict': roidict}
- self.sigROIWidgetSignal.emit(event)
+ self.roiTable.load(filename)
def _save(self):
"""Save button clicked handler"""
@@ -396,142 +335,24 @@ class CurvesROIWidget(qt.QWidget):
:param str filename: The file to which to save the ROIs
"""
- roilist, roidict = self.roiTable.getROIListAndDict()
- datadict = {'ROI': {'roilist': roilist, 'roidict': roidict}}
- dictdump.dump(datadict, filename)
+ self.roiTable.save(filename)
def setHeader(self, text='ROIs'):
"""Set the header text of this widget"""
self.headerLabel.setText("<b>%s<\b>" % text)
- def _roiSignal(self, ddict):
- """Handle ROI widget signal"""
- _logger.debug("CurvesROIWidget._roiSignal %s", str(ddict))
- plot = self.getPlotWidget()
- if plot is None:
- return
-
- if ddict['event'] == "AddROI":
- xmin, xmax = plot.getXAxis().getLimits()
- fromdata = xmin + 0.25 * (xmax - xmin)
- todata = xmin + 0.75 * (xmax - xmin)
- plot.remove('ROI min', kind='marker')
- plot.remove('ROI max', kind='marker')
- if self._middleROIMarkerFlag:
- plot.remove('ROI middle', kind='marker')
- roiList, roiDict = self.roiTable.getROIListAndDict()
- nrois = len(roiList)
- if nrois == 0:
- newroi = "ICR"
- fromdata, dummy0, todata, dummy1 = self._getAllLimits()
- draggable = False
- color = 'black'
- else:
- # find the next index free for newroi.
- for i in range(nrois):
- i += 1
- newroi = "newroi %d" % i
- if newroi not in roiList:
- break
- color = 'blue'
- draggable = True
- plot.addXMarker(fromdata,
- legend='ROI min',
- text='ROI min',
- color=color,
- draggable=draggable)
- plot.addXMarker(todata,
- legend='ROI max',
- text='ROI max',
- color=color,
- draggable=draggable)
- if draggable and self._middleROIMarkerFlag:
- pos = 0.5 * (fromdata + todata)
- plot.addXMarker(pos,
- legend='ROI middle',
- text="",
- color='yellow',
- draggable=draggable)
- roiList.append(newroi)
- roiDict[newroi] = {}
- if newroi == "ICR":
- roiDict[newroi]['type'] = "Default"
- else:
- roiDict[newroi]['type'] = plot.getXAxis().getLabel()
- roiDict[newroi]['from'] = fromdata
- roiDict[newroi]['to'] = todata
- self.roiTable.fillFromROIDict(roilist=roiList,
- roidict=roiDict,
- currentroi=newroi)
- self.currentROI = newroi
- self.calculateRois()
- elif ddict['event'] in ['DelROI', "ResetROI"]:
- plot.remove('ROI min', kind='marker')
- plot.remove('ROI max', kind='marker')
- if self._middleROIMarkerFlag:
- plot.remove('ROI middle', kind='marker')
- roiList, roiDict = self.roiTable.getROIListAndDict()
- roiDictKeys = list(roiDict.keys())
- if len(roiDictKeys):
- currentroi = roiDictKeys[0]
- else:
- # create again the ICR
- ddict = {"event": "AddROI"}
- return self._roiSignal(ddict)
-
- self.roiTable.fillFromROIDict(roilist=roiList,
- roidict=roiDict,
- currentroi=currentroi)
- self.currentROI = currentroi
-
- elif ddict['event'] == 'LoadROI':
- self.calculateRois()
+ @deprecation.deprecated(replacement="calculateRois",
+ reason="CamelCase convention",
+ since_version="0.7")
+ def calculateROIs(self, *args, **kw):
+ self.calculateRois(*args, **kw)
- elif ddict['event'] == 'selectionChanged':
- _logger.debug("Selection changed")
- self.roilist, self.roidict = self.roiTable.getROIListAndDict()
- fromdata = ddict['roi']['from']
- todata = ddict['roi']['to']
- plot.remove('ROI min', kind='marker')
- plot.remove('ROI max', kind='marker')
- if ddict['key'] == 'ICR':
- draggable = False
- color = 'black'
- else:
- draggable = True
- color = 'blue'
- plot.addXMarker(fromdata,
- legend='ROI min',
- text='ROI min',
- color=color,
- draggable=draggable)
- plot.addXMarker(todata,
- legend='ROI max',
- text='ROI max',
- color=color,
- draggable=draggable)
- if draggable and self._middleROIMarkerFlag:
- pos = 0.5 * (fromdata + todata)
- plot.addXMarker(pos,
- legend='ROI middle',
- text="",
- color='yellow',
- draggable=True)
- self.currentROI = ddict['key']
- if ddict['colheader'] in ['From', 'To']:
- dict0 = {}
- dict0['event'] = "SetActiveCurveEvent"
- dict0['legend'] = plot.getActiveCurve(just_legend=1)
- plot.setActiveCurve(dict0['legend'])
- elif ddict['colheader'] == 'Raw Counts':
- pass
- elif ddict['colheader'] == 'Net Counts':
- pass
- else:
- self._emitCurrentROISignal()
+ def calculateRois(self, roiList=None, roiDict=None):
+ """Compute ROI information"""
+ return self.roiTable.calculateRois()
- else:
- _logger.debug("Unknown or ignored event %s", ddict['event'])
+ def showAllMarkers(self, _show=True):
+ self.roiTable.showAllMarkers(_show)
def _getAllLimits(self):
"""Retrieve the limits based on the curves."""
@@ -565,429 +386,1121 @@ class CurvesROIWidget(qt.QWidget):
return xmin, ymin, xmax, ymax
- @deprecation.deprecated(replacement="calculateRois",
- reason="CamelCase convention")
- def calculateROIs(self, *args, **kw):
- self.calculateRois(*args, **kw)
+ def showEvent(self, event):
+ self._visibilityChangedHandler(visible=True)
+ qt.QWidget.showEvent(self, event)
- def calculateRois(self, roiList=None, roiDict=None):
- """Compute ROI information"""
- if roiList is None or roiDict is None:
- roiList, roiDict = self.roiTable.getROIListAndDict()
+ def hideEvent(self, event):
+ self._visibilityChangedHandler(visible=False)
+ qt.QWidget.hideEvent(self, event)
- plot = self.getPlotWidget()
- if plot is None:
- activeCurve = None
- else:
- activeCurve = plot.getActiveCurve(just_legend=False)
+ def _visibilityChangedHandler(self, visible):
+ """Handle widget's visibility updates.
- if activeCurve is None:
- xproc = None
- yproc = None
- self.setHeader()
- else:
- x = activeCurve.getXData(copy=False)
- y = activeCurve.getYData(copy=False)
- legend = activeCurve.getLegend()
- idx = numpy.argsort(x, kind='mergesort')
- xproc = numpy.take(x, idx)
- yproc = numpy.take(y, idx)
- self.setHeader('ROIs of %s' % legend)
-
- for key in roiList:
- if key == 'ICR':
- if xproc is not None:
- roiDict[key]['from'] = xproc.min()
- roiDict[key]['to'] = xproc.max()
- else:
- roiDict[key]['from'] = 0
- roiDict[key]['to'] = -1
- fromData = roiDict[key]['from']
- toData = roiDict[key]['to']
- if xproc is not None:
- idx = numpy.nonzero((fromData <= xproc) &
- (xproc <= toData))[0]
- if len(idx):
- xw = xproc[idx]
- yw = yproc[idx]
- rawCounts = yw.sum(dtype=numpy.float)
- deltaX = xw[-1] - xw[0]
- deltaY = yw[-1] - yw[0]
- if deltaX > 0.0:
- slope = (deltaY / deltaX)
- background = yw[0] + slope * (xw - xw[0])
- netCounts = (rawCounts -
- background.sum(dtype=numpy.float))
- else:
- netCounts = 0.0
- else:
- rawCounts = 0.0
- netCounts = 0.0
- roiDict[key]['rawcounts'] = rawCounts
- roiDict[key]['netcounts'] = netCounts
- else:
- roiDict[key].pop('rawcounts', None)
- roiDict[key].pop('netcounts', None)
+ It is connected to plot signals only when visible.
+ """
+ if visible:
+ # if no ROI existing yet, add the default one
+ if self.roiTable.rowCount() is 0:
+ self._add()
+ self.calculateRois()
- self.roiTable.fillFromROIDict(
- roilist=roiList,
- roidict=roiDict,
- currentroi=self.currentROI if self.currentROI in roiList else None)
+ def fillFromROIDict(self, *args, **kwargs):
+ self.roiTable.fillFromROIDict(*args, **kwargs)
def _emitCurrentROISignal(self):
ddict = {}
ddict['event'] = "currentROISignal"
- _roiList, roiDict = self.roiTable.getROIListAndDict()
- if self.currentROI in roiDict:
- ddict['ROI'] = roiDict[self.currentROI]
+ if self.roiTable.activeRoi is not None:
+ ddict['ROI'] = self.roiTable.activeRoi.toDict()
+ ddict['current'] = self.roiTable.activeRoi.getName()
else:
- self.currentROI = None
- ddict['current'] = self.currentROI
+ ddict['current'] = None
self.sigROISignal.emit(ddict)
- def _handleROIMarkerEvent(self, ddict):
- """Handle plot signals related to marker events."""
- if ddict['event'] == 'markerMoved':
+ @property
+ def currentRoi(self):
+ return self.roiTable.activeRoi
- label = ddict['label']
- if label not in ['ROI min', 'ROI max', 'ROI middle']:
- return
- roiList, roiDict = self.roiTable.getROIListAndDict()
- if self.currentROI is None:
- return
- if self.currentROI not in roiDict:
- return
+class _FloatItem(qt.QTableWidgetItem):
+ """
+ Simple QTableWidgetItem overloading the < operator to deal with ordering
+ """
+ def __init__(self):
+ qt.QTableWidgetItem.__init__(self, type=qt.QTableWidgetItem.Type)
- plot = self.getPlotWidget()
- if plot is None:
- return
+ def __lt__(self, other):
+ if self.text() in ('', ROITable.INFO_NOT_FOUND):
+ return False
+ if other.text() in ('', ROITable.INFO_NOT_FOUND):
+ return True
+ return float(self.text()) < float(other.text())
+
+
+class ROITable(TableWidget):
+ """Table widget displaying ROI information.
+
+ See :class:`QTableWidget` for constructor arguments.
- x = ddict['x']
-
- if label == 'ROI min':
- roiDict[self.currentROI]['from'] = x
- if self._middleROIMarkerFlag:
- pos = 0.5 * (roiDict[self.currentROI]['to'] +
- roiDict[self.currentROI]['from'])
- plot.addXMarker(pos,
- legend='ROI middle',
- text='',
- color='yellow',
- draggable=True)
- elif label == 'ROI max':
- roiDict[self.currentROI]['to'] = x
- if self._middleROIMarkerFlag:
- pos = 0.5 * (roiDict[self.currentROI]['to'] +
- roiDict[self.currentROI]['from'])
- plot.addXMarker(pos,
- legend='ROI middle',
- text='',
- color='yellow',
- draggable=True)
- elif label == 'ROI middle':
- delta = x - 0.5 * (roiDict[self.currentROI]['from'] +
- roiDict[self.currentROI]['to'])
- roiDict[self.currentROI]['from'] += delta
- roiDict[self.currentROI]['to'] += delta
- plot.addXMarker(roiDict[self.currentROI]['from'],
- legend='ROI min',
- text='ROI min',
- color='blue',
- draggable=True)
- plot.addXMarker(roiDict[self.currentROI]['to'],
- legend='ROI max',
- text='ROI max',
- color='blue',
- draggable=True)
+ Behavior: listen at the active curve changed only when the widget is
+ visible. Otherwise won't compute the row and net counts...
+ """
+
+ activeROIChanged = qt.Signal()
+ """Signal emitted when the active roi changed or when the value of the
+ active roi are changing"""
+
+ COLUMNS_INDEX = OrderedDict([
+ ('ID', 0),
+ ('ROI', 1),
+ ('Type', 2),
+ ('From', 3),
+ ('To', 4),
+ ('Raw Counts', 5),
+ ('Net Counts', 6),
+ ('Raw Area', 7),
+ ('Net Area', 8),
+ ])
+
+ COLUMNS = list(COLUMNS_INDEX.keys())
+
+ INFO_NOT_FOUND = '????????'
+
+ def __init__(self, parent=None, plot=None, rois=None):
+ super(ROITable, self).__init__(parent)
+ self._showAllMarkers = False
+ self._userIsEditingRoi = False
+ """bool used to avoid conflict when editing the ROI object"""
+ self._isConnected = False
+ self._roiToItems = {}
+ self._roiDict = {}
+ """dict of ROI object. Key is ROi id, value is the ROI object"""
+ self._markersHandler = _RoiMarkerManager()
+
+ """
+ Associate for each marker legend used when the `_showAllMarkers` option
+ is active a roi.
+ """
+ self.setColumnCount(len(self.COLUMNS))
+ self.setPlot(plot)
+ self.__setTooltip()
+ self.setSortingEnabled(True)
+ self.itemChanged.connect(self._itemChanged)
+
+ @property
+ def roidict(self):
+ return self._getRoiDict()
+
+ @property
+ def activeRoi(self):
+ return self._markersHandler._activeRoi
+
+ def _getRoiDict(self):
+ ddict = {}
+ for id in self._roiDict:
+ ddict[self._roiDict[id].getName()] = self._roiDict[id]
+ return ddict
+
+ def clear(self):
+ """
+ .. note:: clear the interface only. keep the roidict...
+ """
+ self._markersHandler.clear()
+ self._roiToItems = {}
+ self._roiDict = {}
+
+ qt.QTableWidget.clear(self)
+ self.setRowCount(0)
+ self.setHorizontalHeaderLabels(self.COLUMNS)
+ header = self.horizontalHeader()
+ if hasattr(header, 'setSectionResizeMode'): # Qt5
+ header.setSectionResizeMode(qt.QHeaderView.ResizeToContents)
+ else: # Qt4
+ header.setResizeMode(qt.QHeaderView.ResizeToContents)
+ self.sortByColumn(0, qt.Qt.AscendingOrder)
+ self.hideColumn(self.COLUMNS_INDEX['ID'])
+
+ def setPlot(self, plot):
+ self.clear()
+ self.plot = plot
+
+ def __setTooltip(self):
+ self.horizontalHeaderItem(self.COLUMNS_INDEX['ROI']).setToolTip(
+ 'Region of interest identifier')
+ self.horizontalHeaderItem(self.COLUMNS_INDEX['Type']).setToolTip(
+ 'Type of the ROI')
+ self.horizontalHeaderItem(self.COLUMNS_INDEX['From']).setToolTip(
+ 'X-value of the min point')
+ self.horizontalHeaderItem(self.COLUMNS_INDEX['To']).setToolTip(
+ 'X-value of the max point')
+ self.horizontalHeaderItem(self.COLUMNS_INDEX['Raw Counts']).setToolTip(
+ 'Estimation of the integral between y=0 and the selected curve')
+ self.horizontalHeaderItem(self.COLUMNS_INDEX['Net Counts']).setToolTip(
+ 'Estimation of the integral between the segment [maxPt, minPt] '
+ 'and the selected curve')
+
+ def setRois(self, rois, order=None):
+ """Set the ROIs by providing a dictionary of ROI information.
+
+ The dictionary keys are the ROI names.
+ Each value is a sub-dictionary of ROI info with the following fields:
+
+ - ``"from"``: x coordinate of the left limit, as a float
+ - ``"to"``: x coordinate of the right limit, as a float
+ - ``"type"``: type of ROI, as a string (e.g "channels", "energy")
+
+
+ :param roidict: Dictionary of ROIs
+ :param str order: Field used for ordering the ROIs.
+ One of "from", "to", "type".
+ None (default) for no ordering, or same order as specified
+ in parameter ``roidict`` if provided as an OrderedDict.
+ """
+ assert order in [None, "from", "to", "type"]
+ self.clear()
+
+ # backward compatibility since 0.10.0
+ if isinstance(rois, dict):
+ for roiName, roi in rois.items():
+ roi['name'] = roiName
+ _roi = ROI._fromDict(roi)
+ self.addRoi(_roi)
+ else:
+ for roi in rois:
+ assert isinstance(roi, ROI)
+ self.addRoi(roi)
+ self._updateMarkers()
+
+ def addRoi(self, roi):
+ """
+
+ :param :class:`ROI` roi: roi to add to the table
+ """
+ assert isinstance(roi, ROI)
+ self._getItem(name='ID', row=None, roi=roi)
+ self._roiDict[roi.getID()] = roi
+ self._markersHandler.add(roi, _RoiMarkerHandler(roi, self.plot))
+ self._updateRoiInfo(roi.getID())
+ callback = functools.partial(WeakMethodProxy(self._updateRoiInfo),
+ roi.getID())
+ roi.sigChanged.connect(callback)
+ # set it as the active one
+ self.setActiveRoi(roi)
+
+ def _getItem(self, name, row, roi):
+ if row:
+ item = self.item(row, self.COLUMNS_INDEX[name])
+ else:
+ item = None
+ if item:
+ return item
+ else:
+ if name == 'ID':
+ assert roi
+ if roi.getID() in self._roiToItems:
+ return self._roiToItems[roi.getID()]
+ else:
+ # create a new row
+ row = self.rowCount()
+ self.setRowCount(self.rowCount() + 1)
+ item = qt.QTableWidgetItem(str(roi.getID()),
+ type=qt.QTableWidgetItem.Type)
+ self._roiToItems[roi.getID()] = item
+ elif name == 'ROI':
+ item = qt.QTableWidgetItem(roi.getName() if roi else '',
+ type=qt.QTableWidgetItem.Type)
+ if roi.getName().upper() in ('ICR', 'DEFAULT'):
+ item.setFlags(qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled)
+ else:
+ item.setFlags(qt.Qt.ItemIsSelectable |
+ qt.Qt.ItemIsEnabled |
+ qt.Qt.ItemIsEditable)
+ elif name == 'Type':
+ item = qt.QTableWidgetItem(type=qt.QTableWidgetItem.Type)
+ item.setFlags((qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled))
+ elif name in ('To', 'From'):
+ item = _FloatItem()
+ if roi.getName().upper() in ('ICR', 'DEFAULT'):
+ item.setFlags(qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled)
+ else:
+ item.setFlags(qt.Qt.ItemIsSelectable |
+ qt.Qt.ItemIsEnabled |
+ qt.Qt.ItemIsEditable)
+ elif name in ('Raw Counts', 'Net Counts', 'Raw Area', 'Net Area'):
+ item = _FloatItem()
+ item.setFlags((qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled))
else:
- return
- self.calculateRois(roiList, roiDict)
- self._emitCurrentROISignal()
+ raise ValueError('item type not recognized')
+
+ self.setItem(row, self.COLUMNS_INDEX[name], item)
+ return item
+
+ def _itemChanged(self, item):
+ def getRoi():
+ IDItem = self.item(item.row(), self.COLUMNS_INDEX['ID'])
+ assert IDItem
+ id = int(IDItem.text())
+ assert id in self._roiDict
+ roi = self._roiDict[id]
+ return roi
+
+ def signalChanged(roi):
+ if self.activeRoi and roi.getID() == self.activeRoi.getID():
+ self.activeROIChanged.emit()
+
+ self._userIsEditingRoi = True
+ if item.column() in (self.COLUMNS_INDEX['To'], self.COLUMNS_INDEX['From']):
+ roi = getRoi()
+
+ if item.text() not in ('', self.INFO_NOT_FOUND):
+ try:
+ value = float(item.text())
+ except ValueError:
+ value = 0
+ changed = False
+ if item.column() == self.COLUMNS_INDEX['To']:
+ if value != roi.getTo():
+ roi.setTo(value)
+ changed = True
+ else:
+ assert(item.column() == self.COLUMNS_INDEX['From'])
+ if value != roi.getFrom():
+ roi.setFrom(value)
+ changed = True
+ if changed:
+ self._updateMarker(roi.getName())
+ signalChanged(roi)
+
+ if item.column() is self.COLUMNS_INDEX['ROI']:
+ roi = getRoi()
+ if roi.getName() != item.text():
+ roi.setName(item.text())
+ self._markersHandler.getMarkerHandler(roi.getID()).updateTexts()
+ signalChanged(roi)
+
+ self._userIsEditingRoi = False
+
+ def deleteActiveRoi(self):
+ """
+ remove the current active roi
+ """
+ activeItems = self.selectedItems()
+ if len(activeItems) is 0:
+ return
+ roiToRm = set()
+ for item in activeItems:
+ row = item.row()
+ itemID = self.item(row, self.COLUMNS_INDEX['ID'])
+ roiToRm.add(self._roiDict[int(itemID.text())])
+ [self.removeROI(roi) for roi in roiToRm]
+ self.setActiveRoi(None)
+
+ def removeROI(self, roi):
+ """
+ remove the requested roi
- def _visibilityChangedHandler(self, visible):
- """Handle widget's visibility updates.
+ :param str name: the name of the roi to remove from the table
+ """
+ if roi and roi.getID() in self._roiToItems:
+ item = self._roiToItems[roi.getID()]
+ self.removeRow(item.row())
+ del self._roiToItems[roi.getID()]
- It is connected to plot signals only when visible.
+ assert roi.getID() in self._roiDict
+ del self._roiDict[roi.getID()]
+ self._markersHandler.remove(roi)
+
+ callback = functools.partial(WeakMethodProxy(self._updateRoiInfo),
+ roi.getID())
+ roi.sigChanged.connect(callback)
+
+ def setActiveRoi(self, roi):
"""
- plot = self.getPlotWidget()
+ Define the given roi as the active one.
- if visible:
- if not self._isInit:
- # Deferred ROI widget init finalization
- self._finalizeInit()
-
- if not self._isConnected and plot is not None:
- plot.sigPlotSignal.connect(self._handleROIMarkerEvent)
- plot.sigActiveCurveChanged.connect(
- self._activeCurveChanged)
- self._isConnected = True
+ .. warning:: this roi should already be registred / added to the table
- self.calculateRois()
+ :param :class:`ROI` roi: the roi to defined as active
+ """
+ if roi is None:
+ self.clearSelection()
+ self._markersHandler.setActiveRoi(None)
+ self.activeROIChanged.emit()
else:
- if self._isConnected:
- if plot is not None:
- plot.sigPlotSignal.disconnect(self._handleROIMarkerEvent)
- plot.sigActiveCurveChanged.disconnect(
- self._activeCurveChanged)
- self._isConnected = False
+ assert isinstance(roi, ROI)
+ if roi and roi.getID() in self._roiToItems.keys():
+ self.selectRow(self._roiToItems[roi.getID()].row())
+ self._markersHandler.setActiveRoi(roi)
+ self.activeROIChanged.emit()
+
+ def _updateRoiInfo(self, roiID):
+ if self._userIsEditingRoi is True:
+ return
+ if roiID not in self._roiDict:
+ return
+ roi = self._roiDict[roiID]
+ if roi.isICR():
+ activeCurve = self.plot.getActiveCurve()
+ if activeCurve:
+ xData = activeCurve.getXData()
+ if len(xData) > 0:
+ min, max = min_max(xData)
+ roi.blockSignals(True)
+ roi.setFrom(min)
+ roi.setTo(max)
+ roi.blockSignals(False)
+
+ itemID = self._getItem(name='ID', roi=roi, row=None)
+ itemName = self._getItem(name='ROI', row=itemID.row(), roi=roi)
+ itemName.setText(roi.getName())
+
+ itemType = self._getItem(name='Type', row=itemID.row(), roi=roi)
+ itemType.setText(roi.getType() or self.INFO_NOT_FOUND)
+
+ itemFrom = self._getItem(name='From', row=itemID.row(), roi=roi)
+ fromdata = str(roi.getFrom()) if roi.getFrom() is not None else self.INFO_NOT_FOUND
+ itemFrom.setText(fromdata)
+
+ itemTo = self._getItem(name='To', row=itemID.row(), roi=roi)
+ todata = str(roi.getTo()) if roi.getTo() is not None else self.INFO_NOT_FOUND
+ itemTo.setText(todata)
+
+ rawCounts, netCounts = roi.computeRawAndNetCounts(
+ curve=self.plot.getActiveCurve(just_legend=False))
+ itemRawCounts = self._getItem(name='Raw Counts', row=itemID.row(),
+ roi=roi)
+ rawCounts = str(rawCounts) if rawCounts is not None else self.INFO_NOT_FOUND
+ itemRawCounts.setText(rawCounts)
+
+ itemNetCounts = self._getItem(name='Net Counts', row=itemID.row(),
+ roi=roi)
+ netCounts = str(netCounts) if netCounts is not None else self.INFO_NOT_FOUND
+ itemNetCounts.setText(netCounts)
+
+ rawArea, netArea = roi.computeRawAndNetArea(
+ curve=self.plot.getActiveCurve(just_legend=False))
+ itemRawArea = self._getItem(name='Raw Area', row=itemID.row(),
+ roi=roi)
+ rawArea = str(rawArea) if rawArea is not None else self.INFO_NOT_FOUND
+ itemRawArea.setText(rawArea)
+
+ itemNetArea = self._getItem(name='Net Area', row=itemID.row(),
+ roi=roi)
+ netArea = str(netArea) if netArea is not None else self.INFO_NOT_FOUND
+ itemNetArea.setText(netArea)
+
+ if self.activeRoi and roi.getID() == self.activeRoi.getID():
+ self.activeROIChanged.emit()
+
+ def currentChanged(self, current, previous):
+ if previous and current.row() != previous.row() and current.row() >= 0:
+ roiItem = self.item(current.row(),
+ self.COLUMNS_INDEX['ID'])
+
+ assert roiItem
+ self.setActiveRoi(self._roiDict[int(roiItem.text())])
+ self._markersHandler.updateAllMarkers()
+ qt.QTableWidget.currentChanged(self, current, previous)
+
+ @deprecation.deprecated(reason="Removed",
+ replacement="roidict and roidict.values()",
+ since_version="0.10.0")
+ def getROIListAndDict(self):
+ """
- def _activeCurveChanged(self, *args):
- """Recompute ROIs when active curve changed."""
- self.calculateRois()
+ :return: the list of roi objects and the dictionary of roi name to roi
+ object.
+ """
+ roidict = self._roiDict
+ return list(roidict.values()), roidict
- def _finalizeInit(self):
- self._isInit = True
- self.sigROIWidgetSignal.connect(self._roiSignal)
- # initialize with the ICR if no ROi existing yet
- if len(self.getRois()) is 0:
- self._roiSignal({'event': "AddROI"})
+ def calculateRois(self, roiList=None, roiDict=None):
+ """
+ Update values of all registred rois (raw and net counts in particular)
+ :param roiList: deprecated parameter
+ :param roiDict: deprecated parameter
+ """
+ if roiDict:
+ deprecation.deprecated_warning(name='roiDict', type_='Parameter',
+ reason='Unused parameter',
+ since_version="0.10.0")
+ if roiList:
+ deprecation.deprecated_warning(name='roiList', type_='Parameter',
+ reason='Unused parameter',
+ since_version="0.10.0")
+
+ for roiID in self._roiDict:
+ self._updateRoiInfo(roiID)
+
+ def _updateMarker(self, roiID):
+ """Make sure the marker of the given roi name is updated"""
+ if self._showAllMarkers or (self.activeRoi
+ and self.activeRoi.getName() == roiID):
+ self._updateMarkers()
+
+ def _updateMarkers(self):
+ if self._showAllMarkers is True:
+ self._markersHandler.updateMarkers()
+ else:
+ if not self.activeRoi or not self.plot:
+ return
+ assert isinstance(self.activeRoi, ROI)
+ markerHandler = self._markersHandler.getMarkerHandler(self.activeRoi.getID())
+ if markerHandler is not None:
+ markerHandler.updateMarkers()
-class ROITable(qt.QTableWidget):
- """Table widget displaying ROI information.
+ def getRois(self, order):
+ """
+ Return the currently defined ROIs, as an ordered dict.
- See :class:`QTableWidget` for constructor arguments.
- """
+ The dictionary keys are the ROI names.
+ Each value is a :class:`ROI` object..
- sigROITableSignal = qt.Signal(object)
- """Signal of ROI table modifications.
- """
+ :param order: Field used for ordering the ROIs.
+ One of "from", "to", "type", "netcounts", "rawcounts".
+ None (default) to get the same order as displayed in the widget.
+ :return: Ordered dictionary of ROI information
+ """
- def __init__(self, *args, **kwargs):
- super(ROITable, self).__init__(*args, **kwargs)
- self.setRowCount(1)
- self.labels = 'ROI', 'Type', 'From', 'To', 'Raw Counts', 'Net Counts'
- self.setColumnCount(len(self.labels))
- self.setSortingEnabled(False)
+ if order is None or order.lower() == "none":
+ ordered_roilist = list(self._roiDict.values())
+ res = OrderedDict([(roi.getName(), self._roiDict[roi.getID()]) for roi in ordered_roilist])
+ else:
+ assert order in ["from", "to", "type", "netcounts", "rawcounts"]
+ ordered_roilist = sorted(self._roiDict.keys(),
+ key=lambda roi_id: self._roiDict[roi_id].get(order))
+ res = OrderedDict([(roi.getName(), self._roiDict[id]) for id in ordered_roilist])
+
+ return res
+
+ def save(self, filename):
+ """
+ Save current ROIs of the widget as a dict of ROI to a file.
+
+ :param str filename: The file to which to save the ROIs
+ """
+ roilist = []
+ roidict = {}
+ for roiID, roi in self._roiDict.items():
+ roilist.append(roi.toDict())
+ roidict[roi.getName()] = roi.toDict()
+ datadict = {'ROI': {'roilist': roilist, 'roidict': roidict}}
+ dictdump.dump(datadict, filename)
- for index, label in enumerate(self.labels):
- item = self.horizontalHeaderItem(index)
- if item is None:
- item = qt.QTableWidgetItem(label,
- qt.QTableWidgetItem.Type)
- item.setText(label)
- self.setHorizontalHeaderItem(index, item)
+ def load(self, filename):
+ """
+ Load ROI widget information from a file storing a dict of ROI.
- self.roidict = {}
- self.roilist = []
+ :param str filename: The file from which to load ROI
+ """
+ roisDict = dictdump.load(filename)
+ rois = []
- self.building = False
- self.fillFromROIDict(roilist=self.roilist, roidict=self.roidict)
+ # Remove rawcounts and netcounts from ROIs
+ for roiDict in roisDict['ROI']['roidict'].values():
+ roiDict.pop('rawcounts', None)
+ roiDict.pop('netcounts', None)
+ rois.append(ROI._fromDict(roiDict))
- self.cellClicked[(int, int)].connect(self._cellClickedSlot)
- self.cellChanged[(int, int)].connect(self._cellChangedSlot)
- verticalHeader = self.verticalHeader()
- verticalHeader.sectionClicked[int].connect(self._rowChangedSlot)
+ self.setRois(rois)
- self.__setTooltip()
+ def showAllMarkers(self, _show=True):
+ """
- def __setTooltip(self):
- assert(self.labels[0] == 'ROI')
- self.horizontalHeaderItem(0).setToolTip('Region of interest identifier')
- assert(self.labels[1] == 'Type')
- self.horizontalHeaderItem(1).setToolTip('Type of the ROI')
- assert(self.labels[2] == 'From')
- self.horizontalHeaderItem(2).setToolTip('X-value of the min point')
- assert(self.labels[3] == 'To')
- self.horizontalHeaderItem(3).setToolTip('X-value of the max point')
- assert(self.labels[4] == 'Raw Counts')
- self.horizontalHeaderItem(4).setToolTip('Estimation of the integral \
- between y=0 and the selected curve')
- assert(self.labels[5] == 'Net Counts')
- self.horizontalHeaderItem(5).setToolTip('Estimation of the integral \
- between the segment [maxPt, minPt] and the selected curve')
+ :param bool _show: if true show all the markers of all the ROIs
+ boundaries otherwise will only show the one of
+ the active ROI.
+ """
+ self._markersHandler.setShowAllMarkers(_show)
+
+ def setMiddleROIMarkerFlag(self, flag=True):
+ """
+ Activate or deactivate middle marker.
+
+ This allows shifting both min and max limits at once, by dragging
+ a marker located in the middle.
+
+ :param bool flag: True to activate middle ROI marker
+ """
+ self._markersHandler._middleROIMarkerFlag = flag
+
+ def _handleROIMarkerEvent(self, ddict):
+ """Handle plot signals related to marker events."""
+ if ddict['event'] == 'markerMoved':
+ label = ddict['label']
+ roiID = self._markersHandler.getRoiID(markerID=label)
+ if roiID:
+ self._markersHandler.changePosition(markerID=label,
+ x=ddict['x'])
+ self._updateRoiInfo(roiID)
+
+ def showEvent(self, event):
+ self._visibilityChangedHandler(visible=True)
+ qt.QWidget.showEvent(self, event)
+
+ def hideEvent(self, event):
+ self._visibilityChangedHandler(visible=False)
+ qt.QWidget.hideEvent(self, event)
+
+ def _visibilityChangedHandler(self, visible):
+ """Handle widget's visibility updates.
+
+ It is connected to plot signals only when visible.
+ """
+ if visible:
+ assert self.plot
+ if self._isConnected is False:
+ self.plot.sigPlotSignal.connect(self._handleROIMarkerEvent)
+ self.plot.sigActiveCurveChanged.connect(self._activeCurveChanged)
+ self._isConnected = True
+ self.calculateRois()
+ else:
+ if self._isConnected:
+ self.plot.sigPlotSignal.disconnect(self._handleROIMarkerEvent)
+ self.plot.sigActiveCurveChanged.disconnect(self._activeCurveChanged)
+ self._isConnected = False
+
+ def _activeCurveChanged(self, curve):
+ self.calculateRois()
+
+ def setCountsVisible(self, visible):
+ """
+ Display the columns relative to areas or not
+
+ :param bool visible: True if the columns 'Raw Area' and 'Net Area'
+ should be visible.
+ """
+ if visible is True:
+ self.showColumn(self.COLUMNS_INDEX['Raw Counts'])
+ self.showColumn(self.COLUMNS_INDEX['Net Counts'])
+ else:
+ self.hideColumn(self.COLUMNS_INDEX['Raw Counts'])
+ self.hideColumn(self.COLUMNS_INDEX['Net Counts'])
+
+ def setAreaVisible(self, visible):
+ """
+ Display the columns relative to areas or not
+
+ :param bool visible: True if the columns 'Raw Area' and 'Net Area'
+ should be visible.
+ """
+ if visible is True:
+ self.showColumn(self.COLUMNS_INDEX['Raw Area'])
+ self.showColumn(self.COLUMNS_INDEX['Net Area'])
+ else:
+ self.hideColumn(self.COLUMNS_INDEX['Raw Area'])
+ self.hideColumn(self.COLUMNS_INDEX['Net Area'])
def fillFromROIDict(self, roilist=(), roidict=None, currentroi=None):
- """Set the ROIs by providing a list of ROI names and a dictionary
- of ROI information for each ROI.
+ """
+ This function API is kept for compatibility.
+ But `setRois` should be preferred.
+ Set the ROIs by providing a list of ROI names and a dictionary
+ of ROI information for each ROI.
The ROI names must match an existing dictionary key.
The name list is used to provide an order for the ROIs.
-
The dictionary's values are sub-dictionaries containing 3
mandatory fields:
- - ``"from"``: x coordinate of the left limit, as a float
- - ``"to"``: x coordinate of the right limit, as a float
- - ``"type"``: type of ROI, as a string (e.g "channels", "energy")
+ - ``"from"``: x coordinate of the left limit, as a float
+ - ``"to"``: x coordinate of the right limit, as a float
+ - ``"type"``: type of ROI, as a string (e.g "channels", "energy")
:param roilist: List of ROI names (keys of roidict)
:type roilist: List
:param dict roidict: Dict of ROI information
:param currentroi: Name of the selected ROI or None (no selection)
"""
- if roidict is None:
- roidict = {}
-
- self.building = True
- line0 = 0
- self.roilist = []
- self.roidict = {}
- for key in roilist:
- if key in roidict.keys():
- roi = roidict[key]
- self.roilist.append(key)
- self.roidict[key] = {}
- self.roidict[key].update(roi)
- line0 = line0 + 1
- nlines = self.rowCount()
- if (line0 > nlines):
- self.setRowCount(line0)
- line = line0 - 1
- self.roidict[key]['line'] = line
- ROI = key
- roitype = "%s" % roi['type']
- fromdata = "%6g" % (roi['from'])
- todata = "%6g" % (roi['to'])
- if 'rawcounts' in roi:
- rawcounts = "%6g" % (roi['rawcounts'])
- else:
- rawcounts = " ?????? "
- if 'netcounts' in roi:
- netcounts = "%6g" % (roi['netcounts'])
- else:
- netcounts = " ?????? "
- fields = [ROI, roitype, fromdata, todata, rawcounts, netcounts]
- col = 0
- for field in fields:
- key2 = self.item(line, col)
- if key2 is None:
- key2 = qt.QTableWidgetItem(field,
- qt.QTableWidgetItem.Type)
- self.setItem(line, col, key2)
- else:
- key2.setText(field)
- if (ROI.upper() == 'ICR') or (ROI.upper() == 'DEFAULT'):
- key2.setFlags(qt.Qt.ItemIsSelectable |
- qt.Qt.ItemIsEnabled)
- else:
- if col in [0, 2, 3]:
- key2.setFlags(qt.Qt.ItemIsSelectable |
- qt.Qt.ItemIsEnabled |
- qt.Qt.ItemIsEditable)
- else:
- key2.setFlags(qt.Qt.ItemIsSelectable |
- qt.Qt.ItemIsEnabled)
- col = col + 1
- self.setRowCount(line0)
- i = 0
- for _label in self.labels:
- self.resizeColumnToContents(i)
- i = i + 1
- self.sortByColumn(2, qt.Qt.AscendingOrder)
- for i in range(len(self.roilist)):
- key = str(self.item(i, 0).text())
- self.roilist[i] = key
- self.roidict[key]['line'] = i
- if len(self.roilist) == 1:
- self.selectRow(0)
+ if roidict is not None:
+ self.setRois(roidict)
else:
- if currentroi in self.roidict.keys():
- self.selectRow(self.roidict[currentroi]['line'])
- _logger.debug("Qt4 ensureCellVisible to be implemented")
- self.building = False
+ self.setRois(roilist)
+ if currentroi:
+ self.setActiveRoi(currentroi)
- def getROIListAndDict(self):
- """Return the currently defined ROIs, as a 2-tuple
- ``(roiList, roiDict)``
- ``roiList`` is a list of ROI names.
- ``roiDict`` is a dictionary of ROI info.
+_indexNextROI = 0
- The ROI names must match an existing dictionary key.
- The name list is used to provide an order for the ROIs.
- The dictionary's values are sub-dictionaries containing 3
- fields:
+class ROI(qt.QObject):
+ """The Region Of Interest is defined by:
- - ``"from"``: x coordinate of the left limit, as a float
- - ``"to"``: x coordinate of the right limit, as a float
- - ``"type"``: type of ROI, as a string (e.g "channels", "energy")
+ - A name
+ - A type. The type is the label of the x axis. This can be used to apply or
+ not some ROI to a curve and do some post processing.
+ - The x coordinate of the left limit (fromdata)
+ - The x coordinate of the right limit (todata)
+ :param str: name of the ROI
+ :param fromdata: left limit of the roi
+ :param todata: right limit of the roi
+ :param type: type of the ROI
+ """
+
+ sigChanged = qt.Signal()
+ """Signal emitted when the ROI is edited"""
+
+ def __init__(self, name, fromdata=None, todata=None, type_=None):
+ qt.QObject.__init__(self)
+ assert type(name) is str
+ global _indexNextROI
+ self._id = _indexNextROI
+ _indexNextROI += 1
+
+ self._name = name
+ self._fromdata = fromdata
+ self._todata = todata
+ self._type = type_ or 'Default'
- :return: ordered dict as a tuple of (list of ROI names, dict of info)
+ def getID(self):
"""
- return self.roilist, self.roidict
- def _cellClickedSlot(self, *var, **kw):
- # selection changed event, get the current selection
- row = self.currentRow()
- col = self.currentColumn()
- if row >= 0 and row < len(self.roilist):
- item = self.item(row, 0)
- text = '' if item is None else str(item.text())
- self.roilist[row] = text
- self._emitSelectionChangedSignal(row, col)
+ :return int: the unique ID of the ROI
+ """
+ return self._id
- def _rowChangedSlot(self, row):
- self._emitSelectionChangedSignal(row, 0)
+ def setType(self, type_):
+ """
- def _cellChangedSlot(self, row, col):
- _logger.debug("_cellChangedSlot(%d, %d)", row, col)
- if self.building:
- return
- if col == 0:
- self.nameSlot(row, col)
+ :param str type_:
+ """
+ if self._type != type_:
+ self._type = type_
+ self.sigChanged.emit()
+
+ def getType(self):
+ """
+
+ :return str: the type of the ROI.
+ """
+ return self._type
+
+ def setName(self, name):
+ """
+ Set the name of the :class:`ROI`
+
+ :param str name:
+ """
+ if self._name != name:
+ self._name = name
+ self.sigChanged.emit()
+
+ def getName(self):
+ """
+
+ :return str: name of the :class:`ROI`
+ """
+ return self._name
+
+ def setFrom(self, frm):
+ """
+
+ :param frm: set x coordinate of the left limit
+ """
+ if self._fromdata != frm:
+ self._fromdata = frm
+ self.sigChanged.emit()
+
+ def getFrom(self):
+ """
+
+ :return: x coordinate of the left limit
+ """
+ return self._fromdata
+
+ def setTo(self, to):
+ """
+
+ :param to: x coordinate of the right limit
+ """
+ if self._todata != to:
+ self._todata = to
+ self.sigChanged.emit()
+
+ def getTo(self):
+ """
+
+ :return: x coordinate of the right limit
+ """
+ return self._todata
+
+ def getMiddle(self):
+ """
+
+ :return: middle position between 'from' and 'to' values
+ """
+ return 0.5 * (self.getFrom() + self.getTo())
+
+ def toDict(self):
+ """
+
+ :return: dict containing the roi parameters
+ """
+ ddict = {
+ 'type': self._type,
+ 'name': self._name,
+ 'from': self._fromdata,
+ 'to': self._todata,
+ }
+ if hasattr(self, '_extraInfo'):
+ ddict.update(self._extraInfo)
+ return ddict
+
+ @staticmethod
+ def _fromDict(dic):
+ assert 'name' in dic
+ roi = ROI(name=dic['name'])
+ roi._extraInfo = {}
+ for key in dic:
+ if key == 'from':
+ roi.setFrom(dic['from'])
+ elif key == 'to':
+ roi.setTo(dic['to'])
+ elif key == 'type':
+ roi.setType(dic['type'])
+ else:
+ roi._extraInfo[key] = dic[key]
+
+ return roi
+
+ def isICR(self):
+ """
+
+ :return: True if the ROI is the `ICR`
+ """
+ return self._name == 'ICR'
+
+ def computeRawAndNetCounts(self, curve):
+ """Compute the Raw and net counts in the ROI for the given curve.
+
+ - Raw count: Points values sum of the curve in the defined Region Of
+ Interest.
+
+ .. image:: img/rawCounts.png
+
+ - Net count: Raw counts minus background
+
+ .. image:: img/netCounts.png
+
+ :param CurveItem curve:
+ :return tuple: rawCount, netCount
+ """
+ assert isinstance(curve, Curve) or curve is None
+
+ if curve is None:
+ return None, None
+
+ x = curve.getXData(copy=False)
+ y = curve.getYData(copy=False)
+
+ idx = numpy.nonzero((self._fromdata <= x) &
+ (x <= self._todata))[0]
+ if len(idx):
+ xw = x[idx]
+ yw = y[idx]
+ rawCounts = yw.sum(dtype=numpy.float)
+ deltaX = xw[-1] - xw[0]
+ deltaY = yw[-1] - yw[0]
+ if deltaX > 0.0:
+ slope = (deltaY / deltaX)
+ background = yw[0] + slope * (xw - xw[0])
+ netCounts = (rawCounts -
+ background.sum(dtype=numpy.float))
+ else:
+ netCounts = 0.0
else:
- self._valueChanged(row, col)
+ rawCounts = 0.0
+ netCounts = 0.0
+ return rawCounts, netCounts
+
+ def computeRawAndNetArea(self, curve):
+ """Compute the Raw and net counts in the ROI for the given curve.
+
+ - Raw area: integral of the curve between the min ROI point and the
+ max ROI point to the y = 0 line.
- def _valueChanged(self, row, col):
- if col not in [2, 3]:
+ .. image:: img/rawArea.png
+
+ - Net area: Raw counts minus background
+
+ .. image:: img/netArea.png
+
+ :param CurveItem curve:
+ :return tuple: rawArea, netArea
+ """
+ assert isinstance(curve, Curve) or curve is None
+
+ if curve is None:
+ return None, None
+
+ x = curve.getXData(copy=False)
+ y = curve.getYData(copy=False)
+
+ y = y[(x >= self._fromdata) & (x <= self._todata)]
+ x = x[(x >= self._fromdata) & (x <= self._todata)]
+
+ if x.size is 0:
+ return 0.0, 0.0
+
+ rawArea = numpy.trapz(y, x=x)
+ # to speed up and avoid an intersection calculation we are taking the
+ # closest index to the ROI
+ closestXLeftIndex = (numpy.abs(x - self.getFrom())).argmin()
+ closestXRightIndex = (numpy.abs(x - self.getTo())).argmin()
+ yBackground = y[closestXLeftIndex], y[closestXRightIndex]
+ background = numpy.trapz(yBackground, x=x)
+ netArea = rawArea - background
+ return rawArea, netArea
+
+
+class _RoiMarkerManager(object):
+ """
+ Deal with all the ROI markers
+ """
+ def __init__(self):
+ self._roiMarkerHandlers = {}
+ self._middleROIMarkerFlag = False
+ self._showAllMarkers = False
+ self._activeRoi = None
+
+ def setActiveRoi(self, roi):
+ self._activeRoi = roi
+ self.updateAllMarkers()
+
+ def setShowAllMarkers(self, show):
+ if show != self._showAllMarkers:
+ self._showAllMarkers = show
+ self.updateAllMarkers()
+
+ def add(self, roi, markersHandler):
+ assert isinstance(roi, ROI)
+ assert isinstance(markersHandler, _RoiMarkerHandler)
+ if roi.getID() in self._roiMarkerHandlers:
+ raise ValueError('roi with the same ID already existing')
+ else:
+ self._roiMarkerHandlers[roi.getID()] = markersHandler
+
+ def getMarkerHandler(self, roiID):
+ if roiID in self._roiMarkerHandlers:
+ return self._roiMarkerHandlers[roiID]
+ else:
+ return None
+
+ def clear(self):
+ roisHandler = list(self._roiMarkerHandlers.values())
+ for roiHandler in roisHandler:
+ self.remove(roiHandler.roi)
+
+ def remove(self, roi):
+ if roi is None:
return
- item = self.item(row, col)
- if item is None:
+ assert isinstance(roi, ROI)
+ if roi.getID() in self._roiMarkerHandlers:
+ self._roiMarkerHandlers[roi.getID()].clear()
+ del self._roiMarkerHandlers[roi.getID()]
+
+ def hasMarker(self, markerID):
+ assert type(markerID) is str
+ return self.getMarker(markerID) is not None
+
+ def changePosition(self, markerID, x):
+ markerHandler = self.getMarker(markerID)
+ if markerHandler is None:
+ raise ValueError('Marker %s not register' % markerID)
+ markerHandler.changePosition(markerID=markerID, x=x)
+
+ def updateMarker(self, markerID):
+ markerHandler = self.getMarker(markerID)
+ if markerHandler is None:
+ raise ValueError('Marker %s not register' % markerID)
+ roiID = self.getRoiID(markerID)
+ visible = (self._activeRoi and self._activeRoi.getID() == roiID) or self._showAllMarkers is True
+ markerHandler.setVisible(visible)
+ markerHandler.updateAllMarkers()
+
+ def updateRoiMarkers(self, roiID):
+ if roiID in self._roiMarkerHandlers:
+ visible = ((self._activeRoi and self._activeRoi.getID() == roiID)
+ or self._showAllMarkers is True)
+ _roi = self._roiMarkerHandlers[roiID]._roi()
+ if _roi and not _roi.isICR():
+ self._roiMarkerHandlers[roiID].showMiddleMarker(self._middleROIMarkerFlag)
+ self._roiMarkerHandlers[roiID].setVisible(visible)
+ self._roiMarkerHandlers[roiID].updateMarkers()
+
+ def getMarker(self, markerID):
+ assert type(markerID) is str
+ for marker in list(self._roiMarkerHandlers.values()):
+ if marker.hasMarker(markerID):
+ return marker
+
+ def updateMarkers(self):
+ for markerHandler in list(self._roiMarkerHandlers.values()):
+ markerHandler.updateMarkers()
+
+ def getRoiID(self, markerID):
+ for roiID, markerHandler in self._roiMarkerHandlers.items():
+ if markerHandler.hasMarker(markerID):
+ return roiID
+ return None
+
+ def setShowMiddleMarkers(self, show):
+ self._middleROIMarkerFlag = show
+ self._roiMarkerHandlers.updateAllMarkers()
+
+ def updateAllMarkers(self):
+ for roiID in self._roiMarkerHandlers:
+ self.updateRoiMarkers(roiID)
+
+ def getVisibleRois(self):
+ res = {}
+ for roiID, roiHandler in self._roiMarkerHandlers.items():
+ markers = (roiHandler.getMarker('min'), roiHandler.getMarker('max'),
+ roiHandler.getMarker('middle'))
+ for marker in markers:
+ if marker.isVisible():
+ if roiID not in res:
+ res[roiID] = []
+ res[roiID].append(marker)
+ return res
+
+
+class _RoiMarkerHandler(object):
+ """Used to deal with ROI markers used in ROITable"""
+ def __init__(self, roi, plot):
+ assert roi and isinstance(roi, ROI)
+ assert plot
+
+ self._roi = weakref.ref(roi)
+ self._plot = weakref.ref(plot)
+ self._draggable = False if roi.isICR() else True
+ self._color = 'black' if roi.isICR() else 'blue'
+ self._displayMidMarker = False
+ self._visible = True
+
+ @property
+ def draggable(self):
+ return self._draggable
+
+ @property
+ def plot(self):
+ return self._plot()
+
+ def clear(self):
+ if self.plot and self.roi:
+ self.plot.removeMarker(self._markerID('min'))
+ self.plot.removeMarker(self._markerID('max'))
+ self.plot.removeMarker(self._markerID('middle'))
+
+ @property
+ def roi(self):
+ return self._roi()
+
+ def setVisible(self, visible):
+ if visible != self._visible:
+ self._visible = visible
+ self.updateMarkers()
+
+ def showMiddleMarker(self, visible):
+ if self.draggable is False and visible is True:
+ _logger.warning("ROI is not draggable. Won't display middle marker")
return
- text = str(item.text())
- try:
- value = float(text)
- except:
+ self._displayMidMarker = visible
+ self.getMarker('middle').setVisible(self._displayMidMarker)
+
+ def updateMarkers(self):
+ if self.roi is None:
return
- if row >= len(self.roilist):
- _logger.debug("deleting???")
+ self._updateMinMarkerPos()
+ self._updateMaxMarkerPos()
+ self._updateMiddleMarkerPos()
+
+ def _updateMinMarkerPos(self):
+ self.getMarker('min').setPosition(x=self.roi.getFrom(), y=None)
+ self.getMarker('min').setVisible(self._visible)
+
+ def _updateMaxMarkerPos(self):
+ self.getMarker('max').setPosition(x=self.roi.getTo(), y=None)
+ self.getMarker('max').setVisible(self._visible)
+
+ def _updateMiddleMarkerPos(self):
+ self.getMarker('middle').setPosition(x=self.roi.getMiddle(), y=None)
+ self.getMarker('middle').setVisible(self._displayMidMarker and self._visible)
+
+ def getMarker(self, markerType):
+ if self.plot is None:
+ return None
+ assert markerType in ('min', 'max', 'middle')
+ if self.plot._getMarker(self._markerID(markerType)) is None:
+ assert self.roi
+ if markerType == 'min':
+ val = self.roi.getFrom()
+ elif markerType == 'max':
+ val = self.roi.getTo()
+ else:
+ val = self.roi.getMiddle()
+
+ _color = self._color
+ if markerType == 'middle':
+ _color = 'yellow'
+ self.plot.addXMarker(val,
+ legend=self._markerID(markerType),
+ text=self.getMarkerName(markerType),
+ color=_color,
+ draggable=self.draggable)
+ return self.plot._getMarker(self._markerID(markerType))
+
+ def _markerID(self, markerType):
+ assert markerType in ('min', 'max', 'middle')
+ assert self.roi
+ return '_'.join((str(self.roi.getID()), markerType))
+
+ def getMarkerName(self, markerType):
+ assert markerType in ('min', 'max', 'middle')
+ assert self.roi
+ return ' '.join((self.roi.getName(), markerType))
+
+ def updateTexts(self):
+ self.getMarker('min').setText(self.getMarkerName('min'))
+ self.getMarker('max').setText(self.getMarkerName('max'))
+ self.getMarker('middle').setText(self.getMarkerName('middle'))
+
+ def changePosition(self, markerID, x):
+ assert self.hasMarker(markerID)
+ markerType = self._getMarkerType(markerID)
+ assert markerType is not None
+ if self.roi is None:
return
- item = self.item(row, 0)
- if item is None:
- text = ""
+ if markerType == 'min':
+ self.roi.setFrom(x)
+ self._updateMiddleMarkerPos()
+ elif markerType == 'max':
+ self.roi.setTo(x)
+ self._updateMiddleMarkerPos()
else:
- text = str(item.text())
- if not len(text):
- return
- if col == 2:
- self.roidict[text]['from'] = value
- elif col == 3:
- self.roidict[text]['to'] = value
- self._emitSelectionChangedSignal(row, col)
-
- def nameSlot(self, row, col):
- if col != 0:
- return
- if row >= len(self.roilist):
- _logger.debug("deleting???")
- return
- item = self.item(row, col)
- if item is None:
- text = ""
+ delta = x - 0.5 * (self.roi.getFrom() + self.roi.getTo())
+ self.roi.setFrom(self.roi.getFrom() + delta)
+ self.roi.setTo(self.roi.getTo() + delta)
+ self._updateMinMarkerPos()
+ self._updateMaxMarkerPos()
+
+ def hasMarker(self, marker):
+ return marker in (self._markerID('min'),
+ self._markerID('max'),
+ self._markerID('middle'))
+
+ def _getMarkerType(self, markerID):
+ if markerID.endswith('_min'):
+ return 'min'
+ elif markerID.endswith('_max'):
+ return 'max'
+ elif markerID.endswith('_middle'):
+ return 'middle'
else:
- text = str(item.text())
- if len(text) and (text not in self.roilist):
- old = self.roilist[row]
- self.roilist[row] = text
- self.roidict[text] = {}
- self.roidict[text].update(self.roidict[old])
- del self.roidict[old]
- self._emitSelectionChangedSignal(row, col)
-
- def _emitSelectionChangedSignal(self, row, col):
- ddict = {}
- ddict['event'] = "selectionChanged"
- ddict['row'] = row
- ddict['col'] = col
- ddict['roi'] = self.roidict[self.roilist[row]]
- ddict['key'] = self.roilist[row]
- ddict['colheader'] = self.labels[col]
- ddict['rowheader'] = "%d" % row
- self.sigROITableSignal.emit(ddict)
+ return None
class CurvesROIDockWidget(qt.QDockWidget):
@@ -1007,6 +1520,8 @@ class CurvesROIDockWidget(qt.QDockWidget):
def __init__(self, parent=None, plot=None, name=None):
super(CurvesROIDockWidget, self).__init__(name, parent)
+ assert plot is not None
+ self.plot = plot
self.roiWidget = CurvesROIWidget(self, name, plot=plot)
"""Main widget of type :class:`CurvesROIWidget`"""
@@ -1016,12 +1531,15 @@ class CurvesROIDockWidget(qt.QDockWidget):
self.calculateROIs = self.calculateRois = self.roiWidget.calculateRois
self.setRois = self.roiWidget.setRois
self.getRois = self.roiWidget.getRois
+
self.roiWidget.sigROISignal.connect(self._forwardSigROISignal)
- self.currentROI = self.roiWidget.currentROI
self.layout().setContentsMargins(0, 0, 0, 0)
self.setWidget(self.roiWidget)
+ self.setAreaVisible = self.roiWidget.roiTable.setAreaVisible
+ self.setCountsVisible = self.roiWidget.roiTable.setCountsVisible
+
def _forwardSigROISignal(self, ddict):
# emit deprecated signal for backward compatibility (silx < 0.7)
self.sigROISignal.emit(ddict)
@@ -1042,3 +1560,7 @@ class CurvesROIDockWidget(qt.QDockWidget):
"""
self.raise_()
qt.QDockWidget.showEvent(self, event)
+
+ @property
+ def currentROI(self):
+ return self.roiWidget.currentRoi
diff --git a/silx/gui/plot/MaskToolsWidget.py b/silx/gui/plot/MaskToolsWidget.py
index 990e479..9d727e7 100644
--- a/silx/gui/plot/MaskToolsWidget.py
+++ b/silx/gui/plot/MaskToolsWidget.py
@@ -35,7 +35,7 @@ from __future__ import division
__authors__ = ["T. Vincent", "P. Knobel"]
__license__ = "MIT"
-__date__ = "29/08/2018"
+__date__ = "15/02/2019"
import os
@@ -57,10 +57,7 @@ from .. import qt
from silx.third_party.EdfFile import EdfFile
from silx.third_party.TiffIO import TiffIO
-try:
- import fabio
-except ImportError:
- fabio = None
+import fabio
_logger = logging.getLogger(__name__)
@@ -135,8 +132,6 @@ class ImageMask(BaseMask):
self._saveToHdf5(filename, self.getMask(copy=False))
elif kind == 'msk':
- if fabio is None:
- raise ImportError("Fit2d mask files can't be written: Fabio module is not available")
try:
data = self.getMask(copy=False)
image = fabio.fabioimage.FabioImage(data=data)
@@ -250,6 +245,19 @@ class ImageMask(BaseMask):
rows, cols = shapes.circle_fill(crow, ccol, radius)
self.updatePoints(level, rows, cols, mask)
+ def updateEllipse(self, level, crow, ccol, radius_r, radius_c, mask=True):
+ """Mask/Unmask an ellipse of the given mask level.
+
+ :param int level: Mask level to update.
+ :param int crow: Row of the center of the ellipse
+ :param int ccol: Column of the center of the ellipse
+ :param float radius_r: Radius of the ellipse in the row
+ :param float radius_c: Radius of the ellipse in the column
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ rows, cols = shapes.ellipse_fill(crow, ccol, radius_r, radius_c)
+ self.updatePoints(level, rows, cols, mask)
+
def updateLine(self, level, row0, col0, row1, col1, width, mask=True):
"""Mask/Unmask a line of the given mask level.
@@ -300,6 +308,10 @@ class MaskToolsWidget(BaseMaskToolsWidget):
_logger.error('Not an image, shape: %d', len(mask.shape))
return None
+ # Handle mask with single level
+ if self.multipleMasks() == 'single':
+ mask = numpy.array(mask != 0, dtype=numpy.uint8)
+
# if mask has not changed, do nothing
if numpy.array_equal(mask, self.getSelectionMask()):
return mask.shape
@@ -501,8 +513,6 @@ class MaskToolsWidget(BaseMaskToolsWidget):
_logger.debug("Backtrace", exc_info=True)
raise e
elif extension == "msk":
- if fabio is None:
- raise ImportError("Fit2d mask files can't be read: Fabio module is not available")
try:
mask = fabio.open(filename).data
except Exception as e:
@@ -682,41 +692,51 @@ class MaskToolsWidget(BaseMaskToolsWidget):
level = self.levelSpinBox.value()
- if (self._drawingMode == 'rectangle' and
- event['event'] == 'drawingFinished'):
- # Convert from plot to array coords
- doMask = self._isMasking()
- ox, oy = self._origin
- sx, sy = self._scale
-
- height = int(abs(event['height'] / sy))
- width = int(abs(event['width'] / sx))
-
- row = int((event['y'] - oy) / sy)
- if sy < 0:
- row -= height
-
- col = int((event['x'] - ox) / sx)
- if sx < 0:
- col -= width
-
- self._mask.updateRectangle(
- level,
- row=row,
- col=col,
- height=height,
- width=width,
- mask=doMask)
- self._mask.commit()
+ if self._drawingMode == 'rectangle':
+ if event['event'] == 'drawingFinished':
+ # Convert from plot to array coords
+ doMask = self._isMasking()
+ ox, oy = self._origin
+ sx, sy = self._scale
+
+ height = int(abs(event['height'] / sy))
+ width = int(abs(event['width'] / sx))
+
+ row = int((event['y'] - oy) / sy)
+ if sy < 0:
+ row -= height
+
+ col = int((event['x'] - ox) / sx)
+ if sx < 0:
+ col -= width
+
+ self._mask.updateRectangle(
+ level,
+ row=row,
+ col=col,
+ height=height,
+ width=width,
+ mask=doMask)
+ self._mask.commit()
- elif (self._drawingMode == 'polygon' and
- event['event'] == 'drawingFinished'):
- doMask = self._isMasking()
- # Convert from plot to array coords
- vertices = (event['points'] - self._origin) / self._scale
- vertices = vertices.astype(numpy.int)[:, (1, 0)] # (row, col)
- self._mask.updatePolygon(level, vertices, doMask)
- self._mask.commit()
+ elif self._drawingMode == 'ellipse':
+ if event['event'] == 'drawingFinished':
+ doMask = self._isMasking()
+ # Convert from plot to array coords
+ center = (event['points'][0] - self._origin) / self._scale
+ size = event['points'][1] / self._scale
+ center = center.astype(numpy.int) # (row, col)
+ self._mask.updateEllipse(level, center[1], center[0], size[1], size[0], doMask)
+ self._mask.commit()
+
+ elif self._drawingMode == 'polygon':
+ if event['event'] == 'drawingFinished':
+ doMask = self._isMasking()
+ # Convert from plot to array coords
+ vertices = (event['points'] - self._origin) / self._scale
+ vertices = vertices.astype(numpy.int)[:, (1, 0)] # (row, col)
+ self._mask.updatePolygon(level, vertices, doMask)
+ self._mask.commit()
elif self._drawingMode == 'pencil':
doMask = self._isMasking()
@@ -743,6 +763,8 @@ class MaskToolsWidget(BaseMaskToolsWidget):
self._lastPencilPos = None
else:
self._lastPencilPos = row, col
+ else:
+ _logger.error("Drawing mode %s unsupported", self._drawingMode)
def _loadRangeFromColormapTriggered(self):
"""Set range from active image colormap range"""
diff --git a/silx/gui/plot/PlotInteraction.py b/silx/gui/plot/PlotInteraction.py
index 356bda6..27abd10 100644
--- a/silx/gui/plot/PlotInteraction.py
+++ b/silx/gui/plot/PlotInteraction.py
@@ -26,7 +26,7 @@
__authors__ = ["T. Vincent"]
__license__ = "MIT"
-__date__ = "24/04/2018"
+__date__ = "15/02/2019"
import math
@@ -96,10 +96,18 @@ class _PlotInteraction(object):
fill = fill != 'none' # TODO not very nice either
+ greyed = colors.greyed(color)[0]
+ if greyed < 0.5:
+ color2 = "white"
+ else:
+ color2 = "black"
+
self.plot.addItem(points[:, 0], points[:, 1], legend=legend,
replace=False,
- shape=shape, color=color, fill=fill,
+ shape=shape, fill=fill,
+ color=color, linebgcolor=color2, linestyle="--",
overlay=True)
+
self._selectionAreas.add(legend)
def resetSelectionArea(self):
@@ -274,6 +282,8 @@ class Zoom(_ZoomOnWheel):
and zoom on mouse wheel.
"""
+ SURFACE_THRESHOLD = 5
+
def __init__(self, plot, color):
self.color = color
@@ -347,35 +357,44 @@ class Zoom(_ZoomOnWheel):
self.setSelectionArea(corners, fill='none', color=self.color)
- def endDrag(self, startPos, endPos):
- x0, y0 = startPos
- x1, y1 = endPos
+ def _zoom(self, x0, y0, x1, y1):
+ """Zoom to the rectangle view x0,y0 x1,y1.
+ """
+ startPos = x0, y0
+ endPos = x1, y1
+
+ # Store current zoom state in stack
+ self.plot.getLimitsHistory().push()
- if x0 != x1 or y0 != y1: # Avoid empty zoom area
- # Store current zoom state in stack
- self.plot.getLimitsHistory().push()
+ if self.plot.isKeepDataAspectRatio():
+ x0, y0, x1, y1 = self._areaWithAspectRatio(x0, y0, x1, y1)
+
+ # Convert to data space and set limits
+ x0, y0 = self.plot.pixelToData(x0, y0, check=False)
- if self.plot.isKeepDataAspectRatio():
- x0, y0, x1, y1 = self._areaWithAspectRatio(x0, y0, x1, y1)
+ dataPos = self.plot.pixelToData(
+ startPos[0], startPos[1], axis="right", check=False)
+ y2_0 = dataPos[1]
- # Convert to data space and set limits
- x0, y0 = self.plot.pixelToData(x0, y0, check=False)
+ x1, y1 = self.plot.pixelToData(x1, y1, check=False)
- dataPos = self.plot.pixelToData(
- startPos[0], startPos[1], axis="right", check=False)
- y2_0 = dataPos[1]
+ dataPos = self.plot.pixelToData(
+ endPos[0], endPos[1], axis="right", check=False)
+ y2_1 = dataPos[1]
- x1, y1 = self.plot.pixelToData(x1, y1, check=False)
+ xMin, xMax = min(x0, x1), max(x0, x1)
+ yMin, yMax = min(y0, y1), max(y0, y1)
+ y2Min, y2Max = min(y2_0, y2_1), max(y2_0, y2_1)
- dataPos = self.plot.pixelToData(
- endPos[0], endPos[1], axis="right", check=False)
- y2_1 = dataPos[1]
+ self.plot.setLimits(xMin, xMax, yMin, yMax, y2Min, y2Max)
- xMin, xMax = min(x0, x1), max(x0, x1)
- yMin, yMax = min(y0, y1), max(y0, y1)
- y2Min, y2Max = min(y2_0, y2_1), max(y2_0, y2_1)
+ def endDrag(self, startPos, endPos):
+ x0, y0 = startPos
+ x1, y1 = endPos
- self.plot.setLimits(xMin, xMax, yMin, yMax, y2Min, y2Max)
+ if abs(x0 - x1) * abs(y0 - y1) >= self.SURFACE_THRESHOLD:
+ # Avoid empty zoom area
+ self._zoom(x0, y0, x1, y1)
self.resetSelectionArea()
@@ -544,7 +563,6 @@ class SelectPolygon(Select):
return self.DRAG_THRESHOLD_DIST * ratio
-
class Select2Points(Select):
"""Base class for drawing selection based on 2 input points."""
class Idle(State):
@@ -603,6 +621,87 @@ class Select2Points(Select):
self.cancelSelect()
+class SelectEllipse(Select2Points):
+ """Drawing ellipse selection area state machine."""
+ def beginSelect(self, x, y):
+ self.center = self.plot.pixelToData(x, y)
+ assert self.center is not None
+
+ def _getEllipseSize(self, pointInEllipse):
+ """
+ Returns the size from the center to the bounding box of the ellipse.
+
+ :param Tuple[float,float] pointInEllipse: A point of the ellipse
+ :rtype: Tuple[float,float]
+ """
+ x = abs(self.center[0] - pointInEllipse[0])
+ y = abs(self.center[1] - pointInEllipse[1])
+ if x == 0 or y == 0:
+ return x, y
+ # Ellipse definitions
+ # e: eccentricity
+ # a: length fron center to bounding box width
+ # b: length fron center to bounding box height
+ # Equations
+ # (1) b < a
+ # (2) For x,y a point in the ellipse: x^2/a^2 + y^2/b^2 = 1
+ # (3) b = a * sqrt(1-e^2)
+ # (4) e = sqrt(a^2 - b^2) / a
+
+ # The eccentricity of the ellipse defined by a,b=x,y is the same
+ # as the one we are searching for.
+ swap = x < y
+ if swap:
+ x, y = y, x
+ e = math.sqrt(x**2 - y**2) / x
+ # From (2) using (3) to replace b
+ # a^2 = x^2 + y^2 / (1-e^2)
+ a = math.sqrt(x**2 + y**2 / (1.0 - e**2))
+ b = a * math.sqrt(1 - e**2)
+ if swap:
+ a, b = b, a
+ return a, b
+
+ def select(self, x, y):
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+ width, height = self._getEllipseSize(dataPos)
+
+ # Circle used for circle preview
+ nbpoints = 27.
+ angles = numpy.arange(nbpoints) * numpy.pi * 2.0 / nbpoints
+ circleShape = numpy.array((numpy.cos(angles) * width,
+ numpy.sin(angles) * height)).T
+ circleShape += numpy.array(self.center)
+
+ self.setSelectionArea(circleShape,
+ shape="polygon",
+ fill='hatch',
+ color=self.color)
+
+ eventDict = prepareDrawingSignal('drawingProgress',
+ 'ellipse',
+ (self.center, (width, height)),
+ self.parameters)
+ self.plot.notify(**eventDict)
+
+ def endSelect(self, x, y):
+ self.resetSelectionArea()
+
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+ width, height = self._getEllipseSize(dataPos)
+
+ eventDict = prepareDrawingSignal('drawingFinished',
+ 'ellipse',
+ (self.center, (width, height)),
+ self.parameters)
+ self.plot.notify(**eventDict)
+
+ def cancelSelect(self):
+ self.resetSelectionArea()
+
+
class SelectRectangle(Select2Points):
"""Drawing rectangle selection area state machine."""
def beginSelect(self, x, y):
@@ -1488,6 +1587,7 @@ class PlotInteraction(object):
_DRAW_MODES = {
'polygon': SelectPolygon,
'rectangle': SelectRectangle,
+ 'ellipse': SelectEllipse,
'line': SelectLine,
'vline': SelectVLine,
'hline': SelectHLine,
diff --git a/silx/gui/plot/PlotToolButtons.py b/silx/gui/plot/PlotToolButtons.py
index f6291b5..bf6b8ce 100644
--- a/silx/gui/plot/PlotToolButtons.py
+++ b/silx/gui/plot/PlotToolButtons.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2004-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2004-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -45,6 +45,7 @@ import weakref
from .. import icons
from .. import qt
+from ... import config
from .items import SymbolMixIn
@@ -250,24 +251,24 @@ class ProfileOptionToolButton(PlotToolButton):
self.STATE = {}
# is down
self.STATE['sum', "icon"] = icons.getQIcon('math-sigma')
- self.STATE['sum', "state"] = "compute profile sum"
- self.STATE['sum', "action"] = "compute profile sum"
+ self.STATE['sum', "state"] = "Compute profile sum"
+ self.STATE['sum', "action"] = "Compute profile sum"
# keep ration
self.STATE['mean', "icon"] = icons.getQIcon('math-mean')
- self.STATE['mean', "state"] = "compute profile mean"
- self.STATE['mean', "action"] = "compute profile mean"
+ self.STATE['mean', "state"] = "Compute profile mean"
+ self.STATE['mean', "action"] = "Compute profile mean"
- sumAction = self._createAction('sum')
- sumAction.triggered.connect(self.setSum)
- sumAction.setIconVisibleInMenu(True)
+ self.sumAction = self._createAction('sum')
+ self.sumAction.triggered.connect(self.setSum)
+ self.sumAction.setIconVisibleInMenu(True)
- meanAction = self._createAction('mean')
- meanAction.triggered.connect(self.setMean)
- meanAction.setIconVisibleInMenu(True)
+ self.meanAction = self._createAction('mean')
+ self.meanAction.triggered.connect(self.setMean)
+ self.meanAction.setIconVisibleInMenu(True)
menu = qt.QMenu(self)
- menu.addAction(sumAction)
- menu.addAction(meanAction)
+ menu.addAction(self.sumAction)
+ menu.addAction(self.meanAction)
self.setMenu(menu)
self.setPopupMode(qt.QToolButton.InstantPopup)
self.setMean()
@@ -370,7 +371,7 @@ class SymbolToolButton(PlotToolButton):
slider = qt.QSlider(qt.Qt.Horizontal)
slider.setRange(1, 20)
- slider.setValue(SymbolMixIn._DEFAULT_SYMBOL_SIZE)
+ slider.setValue(config.DEFAULT_PLOT_SYMBOL_SIZE)
slider.setTracking(False)
slider.valueChanged.connect(self._sizeChanged)
widgetAction = qt.QWidgetAction(menu)
diff --git a/silx/gui/plot/PlotWidget.py b/silx/gui/plot/PlotWidget.py
index e023a21..cfe39fa 100644
--- a/silx/gui/plot/PlotWidget.py
+++ b/silx/gui/plot/PlotWidget.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2004-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2004-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -31,7 +31,7 @@ from __future__ import division
__authors__ = ["V.A. Sole", "T. Vincent"]
__license__ = "MIT"
-__date__ = "12/10/2018"
+__date__ = "21/12/2018"
from collections import OrderedDict, namedtuple
@@ -44,7 +44,6 @@ import numpy
import silx
from silx.utils.weakref import WeakMethodProxy
-from silx.utils import deprecation
from silx.utils.property import classproperty
from silx.utils.deprecation import deprecated
# Import matplotlib backend here to init matplotlib our way
@@ -99,7 +98,7 @@ class PlotWidget(qt.QMainWindow):
# TODO: Can be removed for silx 0.10
@classproperty
- @deprecation.deprecated(replacement="silx.config.DEFAULT_PLOT_BACKEND", since_version="0.8", skip_backtrace_count=2)
+ @deprecated(replacement="silx.config.DEFAULT_PLOT_BACKEND", since_version="0.8", skip_backtrace_count=2)
def DEFAULT_BACKEND(self):
"""Class attribute setting the default backend for all instances."""
return silx.config.DEFAULT_PLOT_BACKEND
@@ -193,21 +192,12 @@ class PlotWidget(qt.QMainWindow):
It provides the visible state.
"""
- def __init__(self, parent=None, backend=None,
- legends=False, callback=None, **kw):
+ def __init__(self, parent=None, backend=None):
self._autoreplot = False
self._dirty = False
self._cursorInPlot = False
self.__muteActiveItemChanged = False
- if kw:
- _logger.warning(
- 'deprecated: __init__ extra arguments: %s', str(kw))
- if legends:
- _logger.warning('deprecated: __init__ legend argument')
- if callback:
- _logger.warning('deprecated: __init__ callback argument')
-
self._panWithArrowKeys = True
self._viewConstrains = None
@@ -218,27 +208,8 @@ class PlotWidget(qt.QMainWindow):
else:
self.setWindowTitle('PlotWidget')
- if backend is None:
- backend = silx.config.DEFAULT_PLOT_BACKEND
-
- if hasattr(backend, "__call__"):
- self._backend = backend(self, parent)
-
- elif hasattr(backend, "lower"):
- lowerCaseString = backend.lower()
- if lowerCaseString in ("matplotlib", "mpl"):
- backendClass = BackendMatplotlibQt
- elif lowerCaseString in ('gl', 'opengl'):
- from .backends.BackendOpenGL import BackendOpenGL
- backendClass = BackendOpenGL
- elif lowerCaseString == 'none':
- from .backends.BackendBase import BackendBase as backendClass
- else:
- raise ValueError("Backend not supported %s" % backend)
- self._backend = backendClass(self, parent)
-
- else:
- raise ValueError("Backend not supported %s" % str(backend))
+ self._backend = None
+ self._setBackend(backend)
self.setCallback() # set _callback
@@ -258,6 +229,12 @@ class PlotWidget(qt.QMainWindow):
self._activeLegend = {'curve': None, 'image': None,
'scatter': None}
+ # plot colors (updated later to sync backend)
+ self._foregroundColor = 0., 0., 0., 1.
+ self._gridColor = .7, .7, .7, 1.
+ self._backgroundColor = 1., 1., 1., 1.
+ self._dataBackgroundColor = None
+
# default properties
self._cursorConfiguration = None
@@ -275,7 +252,7 @@ class PlotWidget(qt.QMainWindow):
self.setDefaultColormap() # Init default colormap
- self.setDefaultPlotPoints(False)
+ self.setDefaultPlotPoints(silx.config.DEFAULT_PLOT_CURVE_SYMBOL_MODE)
self.setDefaultPlotLines(True)
self._limitsHistory = LimitsHistory(self)
@@ -306,9 +283,41 @@ class PlotWidget(qt.QMainWindow):
self.setGraphYLimits(0., 100., axis='right')
self.setGraphYLimits(0., 100., axis='left')
+ # Sync backend colors with default ones
+ self._foregroundColorsUpdated()
+ self._backgroundColorsUpdated()
+
+ def _setBackend(self, backend):
+ """Setup a new backend"""
+ assert(self._backend is None)
+
+ if backend is None:
+ backend = silx.config.DEFAULT_PLOT_BACKEND
+
+ if hasattr(backend, "__call__"):
+ backend = backend(self, self)
+
+ elif hasattr(backend, "lower"):
+ lowerCaseString = backend.lower()
+ if lowerCaseString in ("matplotlib", "mpl"):
+ backendClass = BackendMatplotlibQt
+ elif lowerCaseString in ('gl', 'opengl'):
+ from .backends.BackendOpenGL import BackendOpenGL
+ backendClass = BackendOpenGL
+ elif lowerCaseString == 'none':
+ from .backends.BackendBase import BackendBase as backendClass
+ else:
+ raise ValueError("Backend not supported %s" % backend)
+ backend = backendClass(self, self)
+
+ else:
+ raise ValueError("Backend not supported %s" % str(backend))
+
+ self._backend = backend
+
# TODO: Can be removed for silx 0.10
@staticmethod
- @deprecation.deprecated(replacement="silx.config.DEFAULT_PLOT_BACKEND", since_version="0.8", skip_backtrace_count=2)
+ @deprecated(replacement="silx.config.DEFAULT_PLOT_BACKEND", since_version="0.8", skip_backtrace_count=2)
def setDefaultBackend(backend):
"""Set system wide default plot backend.
@@ -349,6 +358,119 @@ class PlotWidget(qt.QMainWindow):
if self._autoreplot and not wasDirty and self.isVisible():
self._backend.postRedisplay()
+ def _foregroundColorsUpdated(self):
+ """Handle change of foreground/grid color"""
+ if self._gridColor is None:
+ gridColor = self._foregroundColor
+ else:
+ gridColor = self._gridColor
+ self._backend.setForegroundColors(
+ self._foregroundColor, gridColor)
+ self._setDirtyPlot()
+
+ def getForegroundColor(self):
+ """Returns the RGBA colors used to display the foreground of this widget
+
+ :rtype: qt.QColor
+ """
+ return qt.QColor.fromRgbF(*self._foregroundColor)
+
+ def setForegroundColor(self, color):
+ """Set the foreground color of this widget.
+
+ :param Union[List[int],List[float],QColor] color:
+ The new RGB(A) color.
+ """
+ color = colors.rgba(color)
+ if self._foregroundColor != color:
+ self._foregroundColor = color
+ self._foregroundColorsUpdated()
+
+ def getGridColor(self):
+ """Returns the RGBA colors used to display the grid lines
+
+ An invalid QColor is returned if there is no grid color,
+ in which case the foreground color is used.
+
+ :rtype: qt.QColor
+ """
+ if self._gridColor is None:
+ return qt.QColor() # An invalid color
+ else:
+ return qt.QColor.fromRgbF(*self._gridColor)
+
+ def setGridColor(self, color):
+ """Set the grid lines color
+
+ :param Union[List[int],List[float],QColor,None] color:
+ The new RGB(A) color.
+ """
+ if isinstance(color, qt.QColor) and not color.isValid():
+ color = None
+ if color is not None:
+ color = colors.rgba(color)
+ if self._gridColor != color:
+ self._gridColor = color
+ self._foregroundColorsUpdated()
+
+ def _backgroundColorsUpdated(self):
+ """Handle change of background/data background color"""
+ if self._dataBackgroundColor is None:
+ dataBGColor = self._backgroundColor
+ else:
+ dataBGColor = self._dataBackgroundColor
+ self._backend.setBackgroundColors(
+ self._backgroundColor, dataBGColor)
+ self._setDirtyPlot()
+
+ def getBackgroundColor(self):
+ """Returns the RGBA colors used to display the background of this widget.
+
+ :rtype: qt.QColor
+ """
+ return qt.QColor.fromRgbF(*self._backgroundColor)
+
+ def setBackgroundColor(self, color):
+ """Set the background color of this widget.
+
+ :param Union[List[int],List[float],QColor] color:
+ The new RGB(A) color.
+ """
+ color = colors.rgba(color)
+ if self._backgroundColor != color:
+ self._backgroundColor = color
+ self._backgroundColorsUpdated()
+
+ def getDataBackgroundColor(self):
+ """Returns the RGBA colors used to display the background of the plot
+ view displaying the data.
+
+ An invalid QColor is returned if there is no data background color.
+
+ :rtype: qt.QColor
+ """
+ if self._dataBackgroundColor is None:
+ # An invalid color
+ return qt.QColor()
+ else:
+ return qt.QColor.fromRgbF(*self._dataBackgroundColor)
+
+ def setDataBackgroundColor(self, color):
+ """Set the background color of this widget.
+
+ Set to None or an invalid QColor to use the background color.
+
+ :param Union[List[int],List[float],QColor,None] color:
+ The new RGB(A) color.
+ """
+ if isinstance(color, qt.QColor) and not color.isValid():
+ color = None
+ if color is not None:
+ color = colors.rgba(color)
+ if self._dataBackgroundColor != color:
+ self._dataBackgroundColor = color
+ self._backgroundColorsUpdated()
+
def showEvent(self, event):
if self._autoreplot and self._dirty:
self._backend.postRedisplay()
@@ -528,13 +650,13 @@ class PlotWidget(qt.QMainWindow):
# This value is used when curve is updated either internally or by user.
def addCurve(self, x, y, legend=None, info=None,
- replace=False, replot=None,
+ replace=False,
color=None, symbol=None,
linewidth=None, linestyle=None,
xlabel=None, ylabel=None, yaxis=None,
xerror=None, yerror=None, z=None, selectable=None,
fill=None, resetzoom=True,
- histogram=None, copy=True, **kw):
+ histogram=None, copy=True):
"""Add a 1D curve given by x an y to the graph.
Curves are uniquely identified by their legend.
@@ -617,15 +739,6 @@ class PlotWidget(qt.QMainWindow):
False to use provided arrays.
:returns: The key string identify this curve
"""
- # Deprecation warnings
- if replot is not None:
- _logger.warning(
- 'addCurve deprecated replot argument, use resetzoom instead')
- resetzoom = replot and resetzoom
-
- if kw:
- _logger.warning('addCurve: deprecated extra arguments')
-
# This is an histogram, use addHistogram
if histogram is not None:
histoLegend = self.addHistogram(histogram=y,
@@ -825,13 +938,13 @@ class PlotWidget(qt.QMainWindow):
return legend
def addImage(self, data, legend=None, info=None,
- replace=False, replot=None,
- xScale=None, yScale=None, z=None,
+ replace=False,
+ z=None,
selectable=None, draggable=None,
colormap=None, pixmap=None,
xlabel=None, ylabel=None,
origin=None, scale=None,
- resetzoom=True, copy=True, **kw):
+ resetzoom=True, copy=True):
"""Add a 2D dataset or an image to the plot.
It displays either an array of data using a colormap or a RGB(A) image.
@@ -883,28 +996,6 @@ class PlotWidget(qt.QMainWindow):
False to use provided arrays.
:returns: The key string identify this image
"""
- # Deprecation warnings
- if xScale is not None or yScale is not None:
- _logger.warning(
- 'addImage deprecated xScale and yScale arguments,'
- 'use origin, scale arguments instead.')
- if origin is None and scale is None:
- origin = xScale[0], yScale[0]
- scale = xScale[1], yScale[1]
- else:
- _logger.warning(
- 'addCurve: xScale, yScale and origin, scale arguments'
- ' are conflicting. xScale and yScale are ignored.'
- ' Use only origin, scale arguments.')
-
- if replot is not None:
- _logger.warning(
- 'addImage deprecated replot argument, use resetzoom instead')
- resetzoom = replot and resetzoom
-
- if kw:
- _logger.warning('addImage: deprecated extra arguments')
-
legend = "Unnamed Image 1.1" if legend is None else str(legend)
# Check if image was previously active
@@ -1090,7 +1181,8 @@ class PlotWidget(qt.QMainWindow):
def addItem(self, xdata, ydata, legend=None, info=None,
replace=False,
shape="polygon", color='black', fill=True,
- overlay=False, z=None, **kw):
+ overlay=False, z=None, linestyle="-", linewidth=1.0,
+ linebgcolor=None):
"""Add an item (i.e. a shape) to the plot.
Items are uniquely identified by their legend.
@@ -1114,13 +1206,23 @@ class PlotWidget(qt.QMainWindow):
This allows for rendering optimization if this
item is changed often.
:param int z: Layer on which to draw the item (default: 2)
+ :param str linestyle: Style of the line.
+ Only relevant for line markers where X or Y is None.
+ Value in:
+
+ - ' ' no line
+ - '-' solid line
+ - '--' dashed line
+ - '-.' dash-dot line
+ - ':' dotted line
+ :param float linewidth: Width of the line.
+ Only relevant for line markers where X or Y is None.
+ :param str linebgcolor: Background color of the line, e.g., 'blue', 'b',
+ '#FF0000'. It is used to draw dotted line using a second color.
:returns: The key string identify this item
"""
# expected to receive the same parameters as the signal
- if kw:
- _logger.warning('addItem deprecated parameters: %s', str(kw))
-
legend = "Unnamed Item 1.1" if legend is None else str(legend)
z = int(z) if z is not None else 2
@@ -1138,6 +1240,9 @@ class PlotWidget(qt.QMainWindow):
item.setOverlay(overlay)
item.setZValue(z)
item.setPoints(numpy.array((xdata, ydata)).T)
+ item.setLineStyle(linestyle)
+ item.setLineWidth(linewidth)
+ item.setLineBgColor(linebgcolor)
self._add(item)
@@ -1148,8 +1253,7 @@ class PlotWidget(qt.QMainWindow):
color=None,
selectable=False,
draggable=False,
- constraint=None,
- **kw):
+ constraint=None):
"""Add a vertical line marker to the plot.
Markers are uniquely identified by their legend.
@@ -1177,10 +1281,6 @@ class PlotWidget(qt.QMainWindow):
and that returns the filtered coordinates.
:return: The key string identify this marker
"""
- if kw:
- _logger.warning(
- 'addXMarker deprecated extra parameters: %s', str(kw))
-
return self._addMarker(x=x, y=None, legend=legend,
text=text, color=color,
selectable=selectable, draggable=draggable,
@@ -1192,8 +1292,7 @@ class PlotWidget(qt.QMainWindow):
color=None,
selectable=False,
draggable=False,
- constraint=None,
- **kw):
+ constraint=None):
"""Add a horizontal line marker to the plot.
Markers are uniquely identified by their legend.
@@ -1221,10 +1320,6 @@ class PlotWidget(qt.QMainWindow):
and that returns the filtered coordinates.
:return: The key string identify this marker
"""
- if kw:
- _logger.warning(
- 'addYMarker deprecated extra parameters: %s', str(kw))
-
return self._addMarker(x=None, y=y, legend=legend,
text=text, color=color,
selectable=selectable, draggable=draggable,
@@ -1236,8 +1331,7 @@ class PlotWidget(qt.QMainWindow):
selectable=False,
draggable=False,
symbol='+',
- constraint=None,
- **kw):
+ constraint=None):
"""Add a point marker to the plot.
Markers are uniquely identified by their legend.
@@ -1277,10 +1371,6 @@ class PlotWidget(qt.QMainWindow):
and that returns the filtered coordinates.
:return: The key string identify this marker
"""
- if kw:
- _logger.warning(
- 'addMarker deprecated extra parameters: %s', str(kw))
-
if x is None:
xmin, xmax = self._xAxis.getLimits()
x = 0.5 * (xmax + xmin)
@@ -1368,7 +1458,7 @@ class PlotWidget(qt.QMainWindow):
curve = self._getItem('curve', legend)
return curve is not None and not curve.isVisible()
- def hideCurve(self, legend, flag=True, replot=None):
+ def hideCurve(self, legend, flag=True):
"""Show/Hide the curve associated to legend.
Even when hidden, the curve is kept in the list of curves.
@@ -1376,9 +1466,6 @@ class PlotWidget(qt.QMainWindow):
:param str legend: The legend associated to the curve to be hidden
:param bool flag: True (default) to hide the curve, False to show it
"""
- if replot is not None:
- _logger.warning('hideCurve deprecated replot parameter')
-
curve = self._getItem('curve', legend)
if curve is None:
_logger.warning('Curve not in plot: %s', legend)
@@ -1660,16 +1747,13 @@ class PlotWidget(qt.QMainWindow):
return self._getActiveItem(kind='curve', just_legend=just_legend)
- def setActiveCurve(self, legend, replot=None):
+ def setActiveCurve(self, legend):
"""Make the curve associated to legend the active curve.
:param legend: The legend associated to the curve
or None to have no active curve.
:type legend: str or None
"""
- if replot is not None:
- _logger.warning('setActiveCurve deprecated replot parameter')
-
if not self.isActiveCurveHandling():
return
if legend is None and self.getActiveCurveSelectionMode() == "legacy":
@@ -1723,15 +1807,12 @@ class PlotWidget(qt.QMainWindow):
"""
return self._getActiveItem(kind='image', just_legend=just_legend)
- def setActiveImage(self, legend, replot=None):
+ def setActiveImage(self, legend):
"""Make the image associated to legend the active image.
:param str legend: The legend associated to the image
or None to have no active image.
"""
- if replot is not None:
- _logger.warning('setActiveImage deprecated replot parameter')
-
return self._setActiveItem(kind='image', legend=legend)
def _getActiveItem(self, kind, just_legend=False):
@@ -2028,14 +2109,12 @@ class PlotWidget(qt.QMainWindow):
"""
return self._backend.getGraphXLimits()
- def setGraphXLimits(self, xmin, xmax, replot=None):
+ def setGraphXLimits(self, xmin, xmax):
"""Set the graph X (bottom) limits.
:param float xmin: minimum bottom axis value
:param float xmax: maximum bottom axis value
"""
- if replot is not None:
- _logger.warning('setGraphXLimits deprecated replot parameter')
self._xAxis.setLimits(xmin, xmax)
def getGraphYLimits(self, axis='left'):
@@ -2049,7 +2128,7 @@ class PlotWidget(qt.QMainWindow):
yAxis = self._yAxis if axis == 'left' else self._yRightAxis
return yAxis.getLimits()
- def setGraphYLimits(self, ymin, ymax, axis='left', replot=None):
+ def setGraphYLimits(self, ymin, ymax, axis='left'):
"""Set the graph Y limits.
:param float ymin: minimum bottom axis value
@@ -2057,8 +2136,6 @@ class PlotWidget(qt.QMainWindow):
:param str axis: The axis for which to get the limits:
Either 'left' or 'right'
"""
- if replot is not None:
- _logger.warning('setGraphYLimits deprecated replot parameter')
assert axis in ('left', 'right')
yAxis = self._yAxis if axis == 'left' else self._yRightAxis
return yAxis.setLimits(ymin, ymax)
@@ -2192,36 +2269,6 @@ class PlotWidget(qt.QMainWindow):
def _isAxesDisplayed(self):
return self._backend.isAxesDisplayed()
- @property
- @deprecated(since_version='0.6')
- def sigSetYAxisInverted(self):
- """Signal emitted when Y axis orientation has changed"""
- return self._yAxis.sigInvertedChanged
-
- @property
- @deprecated(since_version='0.6')
- def sigSetXAxisLogarithmic(self):
- """Signal emitted when X axis scale has changed"""
- return self._xAxis._sigLogarithmicChanged
-
- @property
- @deprecated(since_version='0.6')
- def sigSetYAxisLogarithmic(self):
- """Signal emitted when Y axis scale has changed"""
- return self._yAxis._sigLogarithmicChanged
-
- @property
- @deprecated(since_version='0.6')
- def sigSetXAxisAutoScale(self):
- """Signal emitted when X axis autoscale has changed"""
- return self._xAxis.sigAutoScaleChanged
-
- @property
- @deprecated(since_version='0.6')
- def sigSetYAxisAutoScale(self):
- """Signal emitted when Y axis autoscale has changed"""
- return self._yAxis.sigAutoScaleChanged
-
def setYAxisInverted(self, flag=True):
"""Set the Y axis orientation.
@@ -2290,6 +2337,8 @@ class PlotWidget(qt.QMainWindow):
:param bool flag: True to respect data aspect ratio
"""
flag = bool(flag)
+ if flag == self.isKeepDataAspectRatio():
+ return
self._backend.setKeepDataAspectRatio(flag=flag)
self._setDirtyPlot()
self._forceResetZoom()
@@ -2323,8 +2372,8 @@ class PlotWidget(qt.QMainWindow):
# Defaults
def isDefaultPlotPoints(self):
- """Return True if default Curve symbol is 'o', False for no symbol."""
- return self._defaultPlotPoints == 'o'
+ """Return True if the default Curve symbol is set and False if not."""
+ return self._defaultPlotPoints == silx.config.DEFAULT_PLOT_SYMBOL
def setDefaultPlotPoints(self, flag):
"""Set the default symbol of all curves.
@@ -2334,7 +2383,7 @@ class PlotWidget(qt.QMainWindow):
:param bool flag: True to use 'o' as the default curve symbol,
False to use no symbol.
"""
- self._defaultPlotPoints = 'o' if flag else ''
+ self._defaultPlotPoints = silx.config.DEFAULT_PLOT_SYMBOL if flag else ''
# Reset symbol of all curves
curves = self.getAllCurves(just_legend=False, withhidden=True)
@@ -2510,7 +2559,7 @@ class PlotWidget(qt.QMainWindow):
elif ddict['event'] == 'mouseClicked' and ddict['button'] == 'left':
self.setActiveCurve(None)
- def saveGraph(self, filename, fileFormat=None, dpi=None, **kw):
+ def saveGraph(self, filename, fileFormat=None, dpi=None):
"""Save a snapshot of the plot.
Supported file formats depends on the backend in use.
@@ -2523,9 +2572,6 @@ class PlotWidget(qt.QMainWindow):
:param str fileFormat: String specifying the format
:return: False if cannot save the plot, True otherwise
"""
- if kw:
- _logger.warning('Extra parameters ignored: %s', str(kw))
-
if fileFormat is None:
if not hasattr(filename, 'lower'):
_logger.warning(
@@ -3080,149 +3126,3 @@ class PlotWidget(qt.QMainWindow):
# Only call base class implementation when key is not handled.
# See QWidget.keyPressEvent for details.
super(PlotWidget, self).keyPressEvent(event)
-
- # Deprecated #
-
- def isDrawModeEnabled(self):
- """Deprecated, use :meth:`getInteractiveMode` instead.
-
- Return True if the current interactive state is drawing."""
- _logger.warning(
- 'isDrawModeEnabled deprecated, use getInteractiveMode instead')
- return self.getInteractiveMode()['mode'] == 'draw'
-
- def setDrawModeEnabled(self, flag=True, shape='polygon', label=None,
- color=None, **kwargs):
- """Deprecated, use :meth:`setInteractiveMode` instead.
-
- Set the drawing mode if flag is True and its parameters.
-
- If flag is False, only item selection is enabled.
-
- Warning: Zoom and drawing are not compatible and cannot be enabled
- simultaneously.
-
- :param bool flag: True to enable drawing and disable zoom and select.
- :param str shape: Type of item to be drawn in:
- hline, vline, rectangle, polygon (default)
- :param str label: Associated text for identifying draw signals
- :param color: The color to use to draw the selection area
- :type color: string ("#RRGGBB") or 4 column unsigned byte array or
- one of the predefined color names defined in colors.py
- """
- _logger.warning(
- 'setDrawModeEnabled deprecated, use setInteractiveMode instead')
-
- if kwargs:
- _logger.warning('setDrawModeEnabled ignores additional parameters')
-
- if color is None:
- color = 'black'
-
- if flag:
- self.setInteractiveMode('draw', shape=shape,
- label=label, color=color)
- elif self.getInteractiveMode()['mode'] == 'draw':
- self.setInteractiveMode('select')
-
- def getDrawMode(self):
- """Deprecated, use :meth:`getInteractiveMode` instead.
-
- Return the draw mode parameters as a dict of None.
-
- It returns None if the interactive mode is not a drawing mode,
- otherwise, it returns a dict containing the drawing mode parameters
- as provided to :meth:`setDrawModeEnabled`.
- """
- _logger.warning(
- 'getDrawMode deprecated, use getInteractiveMode instead')
- mode = self.getInteractiveMode()
- return mode if mode['mode'] == 'draw' else None
-
- def isZoomModeEnabled(self):
- """Deprecated, use :meth:`getInteractiveMode` instead.
-
- Return True if the current interactive state is zooming."""
- _logger.warning(
- 'isZoomModeEnabled deprecated, use getInteractiveMode instead')
- return self.getInteractiveMode()['mode'] == 'zoom'
-
- def setZoomModeEnabled(self, flag=True, color=None):
- """Deprecated, use :meth:`setInteractiveMode` instead.
-
- Set the zoom mode if flag is True, else item selection is enabled.
-
- Warning: Zoom and drawing are not compatible and cannot be enabled
- simultaneously
-
- :param bool flag: If True, enable zoom and select mode.
- :param color: The color to use to draw the selection area.
- (Default: 'black')
- :param color: The color to use to draw the selection area
- :type color: string ("#RRGGBB") or 4 column unsigned byte array or
- one of the predefined color names defined in colors.py
- """
- _logger.warning(
- 'setZoomModeEnabled deprecated, use setInteractiveMode instead')
- if color is None:
- color = 'black'
-
- if flag:
- self.setInteractiveMode('zoom', color=color)
- elif self.getInteractiveMode()['mode'] == 'zoom':
- self.setInteractiveMode('select')
-
- def insertMarker(self, *args, **kwargs):
- """Deprecated, use :meth:`addMarker` instead."""
- _logger.warning(
- 'insertMarker deprecated, use addMarker instead.')
- return self.addMarker(*args, **kwargs)
-
- def insertXMarker(self, *args, **kwargs):
- """Deprecated, use :meth:`addXMarker` instead."""
- _logger.warning(
- 'insertXMarker deprecated, use addXMarker instead.')
- return self.addXMarker(*args, **kwargs)
-
- def insertYMarker(self, *args, **kwargs):
- """Deprecated, use :meth:`addYMarker` instead."""
- _logger.warning(
- 'insertYMarker deprecated, use addYMarker instead.')
- return self.addYMarker(*args, **kwargs)
-
- def isActiveCurveHandlingEnabled(self):
- """Deprecated, use :meth:`isActiveCurveHandling` instead."""
- _logger.warning(
- 'isActiveCurveHandlingEnabled deprecated, '
- 'use isActiveCurveHandling instead.')
- return self.isActiveCurveHandling()
-
- def enableActiveCurveHandling(self, *args, **kwargs):
- """Deprecated, use :meth:`setActiveCurveHandling` instead."""
- _logger.warning(
- 'enableActiveCurveHandling deprecated, '
- 'use setActiveCurveHandling instead.')
- return self.setActiveCurveHandling(*args, **kwargs)
-
- def invertYAxis(self, *args, **kwargs):
- """Deprecated, use :meth:`Axis.setInverted` instead."""
- _logger.warning('invertYAxis deprecated, '
- 'use getYAxis().setInverted instead.')
- return self.getYAxis().setInverted(*args, **kwargs)
-
- def showGrid(self, flag=True):
- """Deprecated, use :meth:`setGraphGrid` instead."""
- _logger.warning("showGrid deprecated, use setGraphGrid instead")
- if flag in (0, False):
- flag = None
- elif flag in (1, True):
- flag = 'major'
- else:
- flag = 'both'
- return self.setGraphGrid(flag)
-
- def keepDataAspectRatio(self, *args, **kwargs):
- """Deprecated, use :meth:`setKeepDataAspectRatio`."""
- _logger.warning('keepDataAspectRatio deprecated,'
- 'use setKeepDataAspectRatio instead')
- return self.setKeepDataAspectRatio(*args, **kwargs)
diff --git a/silx/gui/plot/PlotWindow.py b/silx/gui/plot/PlotWindow.py
index 23ea399..b44a512 100644
--- a/silx/gui/plot/PlotWindow.py
+++ b/silx/gui/plot/PlotWindow.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2004-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2004-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -29,7 +29,7 @@ The :class:`PlotWindow` is a subclass of :class:`.PlotWidget`.
__authors__ = ["V.A. Sole", "T. Vincent"]
__license__ = "MIT"
-__date__ = "12/10/2018"
+__date__ = "21/12/2018"
import collections
import logging
@@ -217,10 +217,8 @@ class PlotWindow(PlotWidget):
# Make colorbar background white
self._colorbar.setAutoFillBackground(True)
- palette = self._colorbar.palette()
- palette.setColor(qt.QPalette.Background, qt.Qt.white)
- palette.setColor(qt.QPalette.Window, qt.Qt.white)
- self._colorbar.setPalette(palette)
+ self._sigAxesVisibilityChanged.connect(self._updateColorBarBackground)
+ self._updateColorBarBackground()
gridLayout = qt.QGridLayout()
gridLayout.setSpacing(0)
@@ -294,6 +292,43 @@ class PlotWindow(PlotWidget):
for action in toolbar.actions():
self.addAction(action)
+ def setBackgroundColor(self, color):
+ super(PlotWindow, self).setBackgroundColor(color)
+ self._updateColorBarBackground()
+
+ setBackgroundColor.__doc__ = PlotWidget.setBackgroundColor.__doc__
+
+ def setDataBackgroundColor(self, color):
+ super(PlotWindow, self).setDataBackgroundColor(color)
+ self._updateColorBarBackground()
+
+ setDataBackgroundColor.__doc__ = PlotWidget.setDataBackgroundColor.__doc__
+
+ def setForegroundColor(self, color):
+ super(PlotWindow, self).setForegroundColor(color)
+ self._updateColorBarBackground()
+
+ setForegroundColor.__doc__ = PlotWidget.setForegroundColor.__doc__
+
+ def _updateColorBarBackground(self):
+ """Update the colorbar background according to the state of the plot"""
+ if self._isAxesDisplayed():
+ color = self.getBackgroundColor()
+ else:
+ color = self.getDataBackgroundColor()
+ if not color.isValid():
+ # If no color defined, use the background one
+ color = self.getBackgroundColor()
+
+ foreground = self.getForegroundColor()
+
+ palette = self._colorbar.palette()
+ palette.setColor(qt.QPalette.Background, color)
+ palette.setColor(qt.QPalette.Window, color)
+ palette.setColor(qt.QPalette.WindowText, foreground)
+ palette.setColor(qt.QPalette.Text, foreground)
+ self._colorbar.setPalette(palette)
+
def getInteractiveModeToolBar(self):
"""Returns QToolBar controlling interactive mode.
@@ -457,10 +492,6 @@ class PlotWindow(PlotWidget):
return self._colorbar
# getters for dock widgets
- @property
- @deprecated(replacement="getLegendsDockWidget()", since_version="0.4.0")
- def legendsDockWidget(self):
- return self.getLegendsDockWidget()
def getLegendsDockWidget(self):
"""DockWidget with Legend panel"""
@@ -470,11 +501,6 @@ class PlotWindow(PlotWidget):
self.addTabbedDockWidget(self._legendsDockWidget)
return self._legendsDockWidget
- @property
- @deprecated(replacement="getCurvesRoiWidget()", since_version="0.4.0")
- def curvesROIDockWidget(self):
- return self.getCurvesRoiDockWidget()
-
def getCurvesRoiDockWidget(self):
# Undocumented for a "soft deprecation" in version 0.7.0
# (still used internally for lazy loading)
@@ -496,11 +522,6 @@ class PlotWindow(PlotWidget):
"""
return self.getCurvesRoiDockWidget().roiWidget
- @property
- @deprecated(replacement="getMaskToolsDockWidget()", since_version="0.4.0")
- def maskToolsDockWidget(self):
- return self.getMaskToolsDockWidget()
-
def getMaskToolsDockWidget(self):
"""DockWidget with image mask panel (lazy-loaded)."""
if self._maskToolsDockWidget is None:
@@ -539,11 +560,6 @@ class PlotWindow(PlotWidget):
def panModeAction(self):
return self.getInteractiveModeToolBar().getPanModeAction()
- @property
- @deprecated(replacement="getConsoleAction()", since_version="0.4.0")
- def consoleAction(self):
- return self.getConsoleAction()
-
def getConsoleAction(self):
"""QAction handling the IPython console activation.
@@ -563,11 +579,6 @@ class PlotWindow(PlotWidget):
self._consoleAction.setEnabled(False)
return self._consoleAction
- @property
- @deprecated(replacement="getCrosshairAction()", since_version="0.4.0")
- def crosshairAction(self):
- return self.getCrosshairAction()
-
def getCrosshairAction(self):
"""Action toggling crosshair cursor mode.
@@ -577,11 +588,6 @@ class PlotWindow(PlotWidget):
self._crosshairAction = actions.control.CrosshairAction(self, color='red')
return self._crosshairAction
- @property
- @deprecated(replacement="getMaskAction()", since_version="0.4.0")
- def maskAction(self):
- return self.getMaskAction()
-
def getMaskAction(self):
"""QAction toggling image mask dock widget
@@ -589,12 +595,6 @@ class PlotWindow(PlotWidget):
"""
return self.getMaskToolsDockWidget().toggleViewAction()
- @property
- @deprecated(replacement="getPanWithArrowKeysAction()",
- since_version="0.4.0")
- def panWithArrowKeysAction(self):
- return self.getPanWithArrowKeysAction()
-
def getPanWithArrowKeysAction(self):
"""Action toggling pan with arrow keys.
@@ -604,11 +604,6 @@ class PlotWindow(PlotWidget):
self._panWithArrowKeysAction = actions.control.PanWithArrowKeysAction(self)
return self._panWithArrowKeysAction
- @property
- @deprecated(replacement="getRoiAction()", since_version="0.4.0")
- def roiAction(self):
- return self.getRoiAction()
-
def getStatsAction(self):
if self._statsAction is None:
self._statsAction = qt.QAction('Curves stats', self)
diff --git a/silx/gui/plot/PrintPreviewToolButton.py b/silx/gui/plot/PrintPreviewToolButton.py
index b48505d..d857c18 100644
--- a/silx/gui/plot/PrintPreviewToolButton.py
+++ b/silx/gui/plot/PrintPreviewToolButton.py
@@ -111,10 +111,11 @@ from .. import icons
from . import PlotWidget
from ..widgets.PrintPreview import PrintPreviewDialog, SingletonPrintPreviewDialog
from ..widgets.PrintGeometryDialog import PrintGeometryDialog
+from silx.utils.deprecation import deprecated
__authors__ = ["P. Knobel"]
__license__ = "MIT"
-__date__ = "18/07/2017"
+__date__ = "20/12/2018"
_logger = logging.getLogger(__name__)
# _logger.setLevel(logging.DEBUG)
@@ -132,19 +133,19 @@ class PrintPreviewToolButton(qt.QToolButton):
if not isinstance(plot, PlotWidget):
raise TypeError("plot parameter must be a PlotWidget")
- self.plot = plot
+ self._plot = plot
self.setIcon(icons.getQIcon('document-print'))
printGeomAction = qt.QAction("Print geometry", self)
printGeomAction.setToolTip("Define a print geometry prior to sending "
"the plot to the print preview dialog")
- printGeomAction.setIcon(icons.getQIcon('shape-rectangle')) # fixme: icon not displayed in menu
+ printGeomAction.setIcon(icons.getQIcon('shape-rectangle'))
printGeomAction.triggered.connect(self._setPrintConfiguration)
printPreviewAction = qt.QAction("Print preview", self)
printPreviewAction.setToolTip("Send plot to the print preview dialog")
- printPreviewAction.setIcon(icons.getQIcon('document-print')) # fixme: icon not displayed
+ printPreviewAction.setIcon(icons.getQIcon('document-print'))
printPreviewAction.triggered.connect(self._plotToPrintPreview)
menu = qt.QMenu(self)
@@ -172,24 +173,64 @@ class PrintPreviewToolButton(qt.QToolButton):
self._printPreviewDialog = PrintPreviewDialog(self.parent())
return self._printPreviewDialog
+ def getTitle(self):
+ """Implement this method to fetch the title in the plot.
+
+ :return: Title to be printed above the plot, or None (no title added)
+ :rtype: str or None
+ """
+ return None
+
+ def getCommentAndPosition(self):
+ """Implement this method to fetch the legend to be printed below the
+ figure and its position.
+
+ :return: Legend to be printed below the figure and its position:
+ "CENTER", "LEFT" or "RIGHT"
+ :rtype: (str, str) or (None, None)
+ """
+ return None, None
+
+ @property
+ @deprecated(since_version="0.10",
+ replacement="getPlot()")
+ def plot(self):
+ return self._plot
+
+ def getPlot(self):
+ """Return the :class:`.PlotWidget` associated with this tool button.
+
+ :rtype: :class:`.PlotWidget`
+ """
+ return self._plot
+
def _plotToPrintPreview(self):
"""Grab the plot widget and send it to the print preview dialog.
Make sure the print preview dialog is shown and raised."""
if not self.printPreviewDialog.ensurePrinterIsSet():
return
+ comment, commentPosition = self.getCommentAndPosition()
+
if qt.HAS_SVG:
svgRenderer, viewBox = self._getSvgRendererAndViewbox()
self.printPreviewDialog.addSvgItem(svgRenderer,
- viewBox=viewBox)
+ title=self.getTitle(),
+ comment=comment,
+ commentPosition=commentPosition,
+ viewBox=viewBox,
+ keepRatio=self._printGeometry["keepAspectRatio"])
else:
_logger.warning("Missing QtSvg library, using a raster image")
if qt.BINDING in ["PyQt4", "PySide"]:
- pixmap = qt.QPixmap.grabWidget(self.plot.centralWidget())
+ pixmap = qt.QPixmap.grabWidget(self._plot.centralWidget())
else:
# PyQt5 and hopefully PyQt6+
- pixmap = self.plot.centralWidget().grab()
- self.printPreviewDialog.addPixmap(pixmap)
+ pixmap = self._plot.centralWidget().grab()
+ self.printPreviewDialog.addPixmap(pixmap,
+ title=self.getTitle(),
+ comment=comment,
+ commentPosition=commentPosition)
self.printPreviewDialog.show()
self.printPreviewDialog.raise_()
@@ -201,7 +242,7 @@ class PrintPreviewToolButton(qt.QToolButton):
and to the geometry configuration (width, height, ratio) specified
by the user."""
imgData = StringIO()
- assert self.plot.saveGraph(imgData, fileFormat="svg"), \
+ assert self._plot.saveGraph(imgData, fileFormat="svg"), \
"Unable to save graph"
imgData.flush()
imgData.seek(0)
@@ -310,7 +351,7 @@ class PrintPreviewToolButton(qt.QToolButton):
self._printGeometry = self._printConfigurationDialog.getPrintGeometry()
def _getPlotAspectRatio(self):
- widget = self.plot.centralWidget()
+ widget = self._plot.centralWidget()
graphWidth = float(widget.width())
graphHeight = float(widget.height())
return graphHeight / graphWidth
diff --git a/silx/gui/plot/Profile.py b/silx/gui/plot/Profile.py
index 182cf60..46e4523 100644
--- a/silx/gui/plot/Profile.py
+++ b/silx/gui/plot/Profile.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2004-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2004-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -180,7 +180,8 @@ def createProfile(roiInfo, currentData, origin, scale, lineWidth, method):
:type scale: 2-tuple of float
:param int lineWidth: width of the profile line
:param str method: method to compute the profile. Can be 'mean' or 'sum'
- :return: `profile, area, profileName, xLabel`, where:
+ :return: `coords, profile, area, profileName, xLabel`, where:
+ - coords is the X coordinate to use to display the profile
- profile is a 2D array of the profiles of the stack of images.
For a single image, the profile is a curve, so this parameter
has a shape *(1, len(curve))*
@@ -188,10 +189,9 @@ def createProfile(roiInfo, currentData, origin, scale, lineWidth, method):
the effective ROI area corners in plot coords.
- profileName is a string describing the ROI, meant to be used as
title of the profile plot
- - xLabel is a string describing the meaning of the X axis on the
- profile plot ("rows", "columns", "distance")
+ - xLabel the label for X in the profile window
- :rtype: tuple(ndarray, (ndarray, ndarray), str, str)
+ :rtype: tuple(ndarray,ndarray,(ndarray,ndarray),str)
"""
if currentData is None or roiInfo is None or lineWidth is None:
raise ValueError("createProfile called with invalide arguments")
@@ -212,12 +212,15 @@ def createProfile(roiInfo, currentData, origin, scale, lineWidth, method):
axis=0,
method=method)
+ coords = numpy.arange(len(profile[0]), dtype=numpy.float32)
+ coords = coords * scale[0] + origin[0]
+
yMin, yMax = min(area[1]), max(area[1]) - 1
if roiWidth <= 1:
profileName = 'Y = %g' % yMin
else:
profileName = 'Y = [%g, %g]' % (yMin, yMax)
- xLabel = 'Columns'
+ xLabel = 'X'
elif lineProjectionMode == 'Y': # Vertical profile on the whole image
profile, area = _alignedFullProfile(currentData3D,
@@ -226,12 +229,15 @@ def createProfile(roiInfo, currentData, origin, scale, lineWidth, method):
axis=1,
method=method)
+ coords = numpy.arange(len(profile[0]), dtype=numpy.float32)
+ coords = coords * scale[1] + origin[1]
+
xMin, xMax = min(area[0]), max(area[0]) - 1
if roiWidth <= 1:
profileName = 'X = %g' % xMin
else:
profileName = 'X = [%g, %g]' % (xMin, xMax)
- xLabel = 'Rows'
+ xLabel = 'Y'
else: # Free line profile
@@ -306,35 +312,52 @@ def createProfile(roiInfo, currentData, origin, scale, lineWidth, method):
dCol = (endPt[1] - startPt[1]) / length
# Extend ROI with half a pixel on each end
- startPt = startPt[0] - 0.5 * dRow, startPt[1] - 0.5 * dCol
- endPt = endPt[0] + 0.5 * dRow, endPt[1] + 0.5 * dCol
+ roiStartPt = startPt[0] - 0.5 * dRow, startPt[1] - 0.5 * dCol
+ roiEndPt = endPt[0] + 0.5 * dRow, endPt[1] + 0.5 * dCol
# Rotate deltas by 90 degrees to apply line width
dRow, dCol = dCol, -dRow
area = (
- numpy.array((startPt[1] - 0.5 * roiWidth * dCol,
- startPt[1] + 0.5 * roiWidth * dCol,
- endPt[1] + 0.5 * roiWidth * dCol,
- endPt[1] - 0.5 * roiWidth * dCol),
+ numpy.array((roiStartPt[1] - 0.5 * roiWidth * dCol,
+ roiStartPt[1] + 0.5 * roiWidth * dCol,
+ roiEndPt[1] + 0.5 * roiWidth * dCol,
+ roiEndPt[1] - 0.5 * roiWidth * dCol),
dtype=numpy.float32) * scale[0] + origin[0],
- numpy.array((startPt[0] - 0.5 * roiWidth * dRow,
- startPt[0] + 0.5 * roiWidth * dRow,
- endPt[0] + 0.5 * roiWidth * dRow,
- endPt[0] - 0.5 * roiWidth * dRow),
+ numpy.array((roiStartPt[0] - 0.5 * roiWidth * dRow,
+ roiStartPt[0] + 0.5 * roiWidth * dRow,
+ roiEndPt[0] + 0.5 * roiWidth * dRow,
+ roiEndPt[0] - 0.5 * roiWidth * dRow),
dtype=numpy.float32) * scale[1] + origin[1])
- y0, x0 = startPt
- y1, x1 = endPt
- if x1 == x0 or y1 == y0:
- profileName = 'From (%g, %g) to (%g, %g)' % (x0, y0, x1, y1)
+ # Convert start and end points back to plot coords
+ y0 = startPt[0] * scale[1] + origin[1]
+ x0 = startPt[1] * scale[0] + origin[0]
+ y1 = endPt[0] * scale[1] + origin[1]
+ x1 = endPt[1] * scale[0] + origin[0]
+
+ if startPt[1] == endPt[1]:
+ profileName = 'X = %g; Y = [%g, %g]' % (x0, y0, y1)
+ coords = numpy.arange(len(profile[0]), dtype=numpy.float32)
+ coords = coords * scale[1] + y0
+ xLabel = 'Y'
+
+ elif startPt[0] == endPt[0]:
+ profileName = 'Y = %g; X = [%g, %g]' % (y0, x0, x1)
+ coords = numpy.arange(len(profile[0]), dtype=numpy.float32)
+ coords = coords * scale[0] + x0
+ xLabel = 'X'
+
else:
m = (y1 - y0) / (x1 - x0)
b = y0 - m * x0
profileName = 'y = %g * x %+g ; width=%d' % (m, b, roiWidth)
- xLabel = 'Distance'
+ coords = numpy.linspace(x0, x1, len(profile[0]),
+ endpoint=True,
+ dtype=numpy.float32)
+ xLabel = 'X'
- return profile, area, profileName, xLabel
+ return coords, profile, area, profileName, xLabel
# ProfileToolBar ##############################################################
@@ -458,7 +481,7 @@ class ProfileToolBar(qt.QToolBar):
self.addWidget(self.lineWidthSpinBox)
self.methodsButton = ProfileOptionToolButton(parent=self, plot=self)
- self.addWidget(self.methodsButton)
+ self.__profileOptionToolAction = self.addWidget(self.methodsButton)
# TODO: add connection with the signal
self.methodsButton.sigMethodChanged.connect(self.setProfileMethod)
@@ -650,7 +673,7 @@ class ProfileToolBar(qt.QToolBar):
if self._roiInfo is None:
return
- profile, area, profileName, xLabel = createProfile(
+ coords, profile, area, profileName, xLabel = createProfile(
roiInfo=self._roiInfo,
currentData=currentData,
origin=origin,
@@ -658,28 +681,25 @@ class ProfileToolBar(qt.QToolBar):
lineWidth=self.lineWidthSpinBox.value(),
method=method)
- self.getProfilePlot().setGraphTitle(profileName)
+ profilePlot = self.getProfilePlot()
+
+ profilePlot.setGraphTitle(profileName)
+ profilePlot.getXAxis().setLabel(xLabel)
dataIs3D = len(currentData.shape) > 2
if dataIs3D:
- self.getProfilePlot().addImage(profile,
- legend=profileName,
- xlabel=xLabel,
- ylabel="Frame index (depth)",
- colormap=colormap)
+ profileScale = (coords[-1] - coords[0]) / profile.shape[1], 1
+ profilePlot.addImage(profile,
+ legend=profileName,
+ colormap=colormap,
+ origin=(coords[0], 0),
+ scale=profileScale)
+ profilePlot.getYAxis().setLabel("Frame index (depth)")
else:
- coords = numpy.arange(len(profile[0]), dtype=numpy.float32)
- # Scale horizontal and vertical profile coordinates
- if self._roiInfo[2] == 'X':
- coords = coords * scale[0] + origin[0]
- elif self._roiInfo[2] == 'Y':
- coords = coords * scale[1] + origin[1]
-
- self.getProfilePlot().addCurve(coords,
- profile[0],
- legend=profileName,
- xlabel=xLabel,
- color=self.overlayColor)
+ profilePlot.addCurve(coords,
+ profile[0],
+ legend=profileName,
+ color=self.overlayColor)
self.plot.addItem(area[0], area[1],
legend=self._POLYGON_LEGEND,
@@ -732,6 +752,9 @@ class ProfileToolBar(qt.QToolBar):
def getProfileMethod(self):
return self._method
+ def getProfileOptionToolAction(self):
+ return self.__profileOptionToolAction
+
class Profile3DToolBar(ProfileToolBar):
def __init__(self, parent=None, stackview=None,
diff --git a/silx/gui/plot/ScatterMaskToolsWidget.py b/silx/gui/plot/ScatterMaskToolsWidget.py
index de645be..0c6797f 100644
--- a/silx/gui/plot/ScatterMaskToolsWidget.py
+++ b/silx/gui/plot/ScatterMaskToolsWidget.py
@@ -35,7 +35,7 @@ from __future__ import division
__authors__ = ["P. Knobel"]
__license__ = "MIT"
-__date__ = "24/04/2018"
+__date__ = "15/02/2019"
import math
@@ -152,6 +152,22 @@ class ScatterMask(BaseMask):
stencil = (y - cy)**2 + (x - cx)**2 < radius**2
self.updateStencil(level, stencil, mask)
+ def updateEllipse(self, level, crow, ccol, radius_r, radius_c, mask=True):
+ """Mask/Unmask an ellipse of the given mask level.
+
+ :param int level: Mask level to update.
+ :param int crow: Row of the center of the ellipse
+ :param int ccol: Column of the center of the ellipse
+ :param float radius_r: Radius of the ellipse in the row
+ :param float radius_c: Radius of the ellipse in the column
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ def is_inside(px, py):
+ return (px - ccol)**2 / radius_c**2 + (py - crow)**2 / radius_r**2 <= 1.0
+ x, y = self._getXY()
+ indices_inside = [idx for idx in range(len(x)) if is_inside(x[idx], y[idx])]
+ self.updatePoints(level, indices_inside, mask)
+
def updateLine(self, level, y0, x0, y1, x1, width, mask=True):
"""Mask/Unmask points inside a rectangle defined by a line (two
end points) and a width.
@@ -490,26 +506,35 @@ class ScatterMaskToolsWidget(BaseMaskToolsWidget):
level = self.levelSpinBox.value()
- if (self._drawingMode == 'rectangle' and
- event['event'] == 'drawingFinished'):
- doMask = self._isMasking()
+ if self._drawingMode == 'rectangle':
+ if event['event'] == 'drawingFinished':
+ doMask = self._isMasking()
+
+ self._mask.updateRectangle(
+ level,
+ y=event['y'],
+ x=event['x'],
+ height=abs(event['height']),
+ width=abs(event['width']),
+ mask=doMask)
+ self._mask.commit()
- self._mask.updateRectangle(
- level,
- y=event['y'],
- x=event['x'],
- height=abs(event['height']),
- width=abs(event['width']),
- mask=doMask)
- self._mask.commit()
+ elif self._drawingMode == 'ellipse':
+ if event['event'] == 'drawingFinished':
+ doMask = self._isMasking()
+ center = event['points'][0]
+ size = event['points'][1]
+ self._mask.updateEllipse(level, center[1], center[0],
+ size[1], size[0], doMask)
+ self._mask.commit()
- elif (self._drawingMode == 'polygon' and
- event['event'] == 'drawingFinished'):
- doMask = self._isMasking()
- vertices = event['points']
- vertices = vertices[:, (1, 0)] # (y, x)
- self._mask.updatePolygon(level, vertices, doMask)
- self._mask.commit()
+ elif self._drawingMode == 'polygon':
+ if event['event'] == 'drawingFinished':
+ doMask = self._isMasking()
+ vertices = event['points']
+ vertices = vertices[:, (1, 0)] # (y, x)
+ self._mask.updatePolygon(level, vertices, doMask)
+ self._mask.commit()
elif self._drawingMode == 'pencil':
doMask = self._isMasking()
@@ -536,6 +561,8 @@ class ScatterMaskToolsWidget(BaseMaskToolsWidget):
self._lastPencilPos = None
else:
self._lastPencilPos = y, x
+ else:
+ _logger.error("Drawing mode %s unsupported", self._drawingMode)
def _loadRangeFromColormapTriggered(self):
"""Set range from active scatter colormap range"""
diff --git a/silx/gui/plot/ScatterView.py b/silx/gui/plot/ScatterView.py
index ae79cf9..5fc66ef 100644
--- a/silx/gui/plot/ScatterView.py
+++ b/silx/gui/plot/ScatterView.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2018 European Synchrotron Radiation Facility
+# Copyright (c) 2018-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -353,3 +353,13 @@ class ScatterView(qt.QMainWindow):
return self.getPlotWidget().resetZoom(*args, **kwargs)
resetZoom.__doc__ = PlotWidget.resetZoom.__doc__
+
+ def getSelectionMask(self, *args, **kwargs):
+ return self.getMaskToolsWidget().getSelectionMask(*args, **kwargs)
+
+ getSelectionMask.__doc__ = ScatterMaskToolsWidget.getSelectionMask.__doc__
+
+ def setSelectionMask(self, *args, **kwargs):
+ return self.getMaskToolsWidget().setSelectionMask(*args, **kwargs)
+
+ setSelectionMask.__doc__ = ScatterMaskToolsWidget.setSelectionMask.__doc__
diff --git a/silx/gui/plot/StackView.py b/silx/gui/plot/StackView.py
index 72b6cd4..2a3d7e8 100644
--- a/silx/gui/plot/StackView.py
+++ b/silx/gui/plot/StackView.py
@@ -89,14 +89,8 @@ from silx.utils.array_like import DatasetView, ListOfImages
from silx.math import calibration
from silx.utils.deprecation import deprecated_warning
-try:
- import h5py
-except ImportError:
- def is_dataset(obj):
- return False
- h5py = None
-else:
- from silx.io.utils import is_dataset
+import h5py
+from silx.io.utils import is_dataset
_logger = logging.getLogger(__name__)
diff --git a/silx/gui/plot/StatsWidget.py b/silx/gui/plot/StatsWidget.py
index bb66613..4ba4fab 100644
--- a/silx/gui/plot/StatsWidget.py
+++ b/silx/gui/plot/StatsWidget.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -31,552 +31,1266 @@ __license__ = "MIT"
__date__ = "24/07/2018"
-import functools
+from collections import OrderedDict
+from contextlib import contextmanager
import logging
+import weakref
+
import numpy
-from collections import OrderedDict
-import silx.utils.weakref
from silx.gui import qt
from silx.gui import icons
-from silx.gui.plot.items.curve import Curve as CurveItem
-from silx.gui.plot.items.histogram import Histogram as HistogramItem
-from silx.gui.plot.items.image import ImageBase as ImageItem
-from silx.gui.plot.items.scatter import Scatter as ScatterItem
from silx.gui.plot import stats as statsmdl
from silx.gui.widgets.TableWidget import TableWidget
from silx.gui.plot.stats.statshandler import StatsHandler, StatFormatter
+from silx.gui.plot.items.core import ItemChangedType
+from silx.gui.widgets.FlowLayout import FlowLayout
+from . import PlotWidget
+from . import items as plotitems
-logger = logging.getLogger(__name__)
+_logger = logging.getLogger(__name__)
-class StatsWidget(qt.QWidget):
+
+# Helper class to handle specific calls to PlotWidget and SceneWidget
+
+class _Wrapper(qt.QObject):
+ """Base class for connection with PlotWidget and SceneWidget.
+
+ This class is used when no PlotWidget or SceneWidget is connected.
+
+ :param plot: The plot to be used
"""
- Widget displaying a set of :class:`Stat` to be displayed on a
- :class:`StatsTable` and to be apply on items contained in the :class:`Plot`
- Also contains options to:
- * compute statistics on all the data or on visible data only
- * show statistics of all items or only the active one
+ sigItemAdded = qt.Signal(object)
+ """Signal emitted when a new item is added.
- :param parent: Qt parent
- :param plot: the plot containing items on which we want statistics.
+ It provides the added item.
"""
- sigVisibilityChanged = qt.Signal(bool)
+ sigItemRemoved = qt.Signal(object)
+ """Signal emitted when an item is (about to be) removed.
- NUMBER_FORMAT = '{0:.3f}'
+ It provides the removed item.
+ """
- class OptionsWidget(qt.QToolBar):
-
- def __init__(self, parent=None):
- qt.QToolBar.__init__(self, parent)
- self.setIconSize(qt.QSize(16, 16))
-
- action = qt.QAction(self)
- action.setIcon(icons.getQIcon("stats-active-items"))
- action.setText("Active items only")
- action.setToolTip("Display stats for active items only.")
- action.setCheckable(True)
- action.setChecked(True)
- self.__displayActiveItems = action
-
- action = qt.QAction(self)
- action.setIcon(icons.getQIcon("stats-whole-items"))
- action.setText("All items")
- action.setToolTip("Display stats for all available items.")
- action.setCheckable(True)
- self.__displayWholeItems = action
-
- action = qt.QAction(self)
- action.setIcon(icons.getQIcon("stats-visible-data"))
- action.setText("Use the visible data range")
- action.setToolTip("Use the visible data range.<br/>"
- "If activated the data is filtered to only use"
- "visible data of the plot."
- "The filtering is a data sub-sampling."
- "No interpolation is made to fit data to"
- "boundaries.")
- action.setCheckable(True)
- self.__useVisibleData = action
-
- action = qt.QAction(self)
- action.setIcon(icons.getQIcon("stats-whole-data"))
- action.setText("Use the full data range")
- action.setToolTip("Use the full data range.")
- action.setCheckable(True)
- action.setChecked(True)
- self.__useWholeData = action
-
- self.addAction(self.__displayWholeItems)
- self.addAction(self.__displayActiveItems)
- self.addSeparator()
- self.addAction(self.__useVisibleData)
- self.addAction(self.__useWholeData)
-
- self.itemSelection = qt.QActionGroup(self)
- self.itemSelection.setExclusive(True)
- self.itemSelection.addAction(self.__displayActiveItems)
- self.itemSelection.addAction(self.__displayWholeItems)
-
- self.dataRangeSelection = qt.QActionGroup(self)
- self.dataRangeSelection.setExclusive(True)
- self.dataRangeSelection.addAction(self.__useWholeData)
- self.dataRangeSelection.addAction(self.__useVisibleData)
-
- def isActiveItemMode(self):
- return self.itemSelection.checkedAction() is self.__displayActiveItems
-
- def isVisibleDataRangeMode(self):
- return self.dataRangeSelection.checkedAction() is self.__useVisibleData
+ sigCurrentChanged = qt.Signal(object)
+ """Signal emitted when the current item has changed.
- def __init__(self, parent=None, plot=None, stats=None):
- qt.QWidget.__init__(self, parent)
- self.setLayout(qt.QVBoxLayout())
- self.layout().setContentsMargins(0, 0, 0, 0)
- self._options = self.OptionsWidget(parent=self)
- self.layout().addWidget(self._options)
- self._statsTable = StatsTable(parent=self, plot=plot)
- self.setStats = self._statsTable.setStats
- self.setStats(stats)
+ It provides the current item.
+ """
- self.layout().addWidget(self._statsTable)
- self.setPlot = self._statsTable.setPlot
+ sigVisibleDataChanged = qt.Signal()
+ """Signal emitted when the visible data area has changed"""
- self._options.itemSelection.triggered.connect(
- self._optSelectionChanged)
- self._options.dataRangeSelection.triggered.connect(
- self._optDataRangeChanged)
- self._optSelectionChanged()
- self._optDataRangeChanged()
+ def __init__(self, plot=None):
+ super(_Wrapper, self).__init__(parent=None)
+ self._plotRef = None if plot is None else weakref.ref(plot)
- self.setDisplayOnlyActiveItem = self._statsTable.setDisplayOnlyActiveItem
- self.setStatsOnVisibleData = self._statsTable.setStatsOnVisibleData
+ def getPlot(self):
+ """Returns the plot attached to this widget"""
+ return None if self._plotRef is None else self._plotRef()
- def showEvent(self, event):
- self.sigVisibilityChanged.emit(True)
- qt.QWidget.showEvent(self, event)
+ def getItems(self):
+ """Returns the list of items in the plot
- def hideEvent(self, event):
- self.sigVisibilityChanged.emit(False)
- qt.QWidget.hideEvent(self, event)
+ :rtype: List[object]
+ """
+ return ()
- def _optSelectionChanged(self, action=None):
- self._statsTable.setDisplayOnlyActiveItem(self._options.isActiveItemMode())
+ def getSelectedItems(self):
+ """Returns the list of selected items in the plot
- def _optDataRangeChanged(self, action=None):
- self._statsTable.setStatsOnVisibleData(self._options.isVisibleDataRangeMode())
+ :rtype: List[object]
+ """
+ return ()
+ def setCurrentItem(self, item):
+ """Set the current/active item in the plot
-class BasicStatsWidget(StatsWidget):
+ :param item: The plot item to set as active/current
+ """
+ pass
+
+ def getLabel(self, item):
+ """Returns the label of the given item.
+
+ :param item:
+ :rtype: str
+ """
+ return ''
+
+ def getKind(self, item):
+ """Returns the kind of an item or None if not supported
+
+ :param item:
+ :rtype: Union[str,None]
+ """
+ return None
+
+
+class _PlotWidgetWrapper(_Wrapper):
+ """Class handling PlotWidget specific calls and signal connections
+
+ See :class:`._Wrapper` for documentation
+
+ :param PlotWidget plot:
"""
- Widget defining a simple set of :class:`Stat` to be displayed on a
- :class:`StatsWidget`.
- :param parent: Qt parent
- :param plot: the plot containing items on which we want statistics.
+ def __init__(self, plot):
+ assert isinstance(plot, PlotWidget)
+ super(_PlotWidgetWrapper, self).__init__(plot)
+ plot.sigItemAdded.connect(self.sigItemAdded.emit)
+ plot.sigItemAboutToBeRemoved.connect(self.sigItemRemoved.emit)
+ plot.sigActiveCurveChanged.connect(self._activeCurveChanged)
+ plot.sigActiveImageChanged.connect(self._activeImageChanged)
+ plot.sigActiveScatterChanged.connect(self._activeScatterChanged)
+ plot.sigPlotSignal.connect(self._limitsChanged)
+
+ def _activeChanged(self, kind):
+ """Handle change of active curve/image/scatter"""
+ plot = self.getPlot()
+ if plot is not None:
+ item = plot._getActiveItem(kind=kind)
+ if item is None or self.getKind(item) is not None:
+ self.sigCurrentChanged.emit(item)
+
+ def _activeCurveChanged(self, previous, current):
+ self._activeChanged(kind='curve')
+
+ def _activeImageChanged(self, previous, current):
+ self._activeChanged(kind='image')
+
+ def _activeScatterChanged(self, previous, current):
+ self._activeChanged(kind='scatter')
+
+ def _limitsChanged(self, event):
+ """Handle change of plot area limits."""
+ if event['event'] == 'limitsChanged':
+ self.sigVisibleDataChanged.emit()
+
+ def getItems(self):
+ plot = self.getPlot()
+ return () if plot is None else plot._getItems()
+
+ def getSelectedItems(self):
+ plot = self.getPlot()
+ items = []
+ if plot is not None:
+ for kind in plot._ACTIVE_ITEM_KINDS:
+ item = plot._getActiveItem(kind=kind)
+ if item is not None:
+ items.append(item)
+ return tuple(items)
+
+ def setCurrentItem(self, item):
+ plot = self.getPlot()
+ if plot is not None:
+ kind = self.getKind(item)
+ if kind in plot._ACTIVE_ITEM_KINDS:
+ if plot._getActiveItem(kind) != item:
+ plot._setActiveItem(kind, item.getLegend())
+
+ def getLabel(self, item):
+ return item.getLegend()
+
+ def getKind(self, item):
+ if isinstance(item, plotitems.Curve):
+ return 'curve'
+ elif isinstance(item, plotitems.ImageData):
+ return 'image'
+ elif isinstance(item, plotitems.Scatter):
+ return 'scatter'
+ elif isinstance(item, plotitems.Histogram):
+ return 'histogram'
+ else:
+ return None
+
+
+class _SceneWidgetWrapper(_Wrapper):
+ """Class handling SceneWidget specific calls and signal connections
+
+ See :class:`._Wrapper` for documentation
+
+ :param SceneWidget plot:
"""
- STATS = StatsHandler((
- (statsmdl.StatMin(), StatFormatter()),
- statsmdl.StatCoordMin(),
- (statsmdl.StatMax(), StatFormatter()),
- statsmdl.StatCoordMax(),
- (('std', numpy.std), StatFormatter()),
- (('mean', numpy.mean), StatFormatter()),
- statsmdl.StatCOM()
- ))
+ def __init__(self, plot):
+ # Lazy-import to avoid circular imports
+ from ..plot3d.SceneWidget import SceneWidget
- def __init__(self, parent=None, plot=None):
- StatsWidget.__init__(self, parent=parent, plot=plot, stats=self.STATS)
+ assert isinstance(plot, SceneWidget)
+ super(_SceneWidgetWrapper, self).__init__(plot)
+ plot.getSceneGroup().sigItemAdded.connect(self.sigItemAdded)
+ plot.getSceneGroup().sigItemRemoved.connect(self.sigItemRemoved)
+ plot.selection().sigCurrentChanged.connect(self._currentChanged)
+ # sigVisibleDataChanged is never emitted
+
+ def _currentChanged(self, current, previous):
+ self.sigCurrentChanged.emit(current)
+
+ def getItems(self):
+ plot = self.getPlot()
+ return () if plot is None else tuple(plot.getSceneGroup().visit())
+
+ def getSelectedItems(self):
+ plot = self.getPlot()
+ return () if plot is None else (plot.selection().getCurrentItem(),)
+ def setCurrentItem(self, item):
+ plot = self.getPlot()
+ if plot is not None:
+ plot.selection().setCurrentItem(item)
-class StatsTable(TableWidget):
+ def getLabel(self, item):
+ return item.getLabel()
+
+ def getKind(self, item):
+ from ..plot3d import items as plot3ditems
+
+ if isinstance(item, (plot3ditems.ImageData,
+ plot3ditems.ScalarField3D)):
+ return 'image'
+ elif isinstance(item, (plot3ditems.Scatter2D,
+ plot3ditems.Scatter3D)):
+ return 'scatter'
+ else:
+ return None
+
+
+class _ScalarFieldViewWrapper(_Wrapper):
+ """Class handling ScalarFieldView specific calls and signal connections
+
+ See :class:`._Wrapper` for documentation
+
+ :param SceneWidget plot:
"""
- TableWidget displaying for each curves contained by the Plot some
- information:
- * legend
- * minimal value
- * maximal value
- * standard deviation (std)
+ def __init__(self, plot):
+ # Lazy-import to avoid circular imports
+ from ..plot3d.ScalarFieldView import ScalarFieldView
+ from ..plot3d.items import ScalarField3D
+
+ assert isinstance(plot, ScalarFieldView)
+ super(_ScalarFieldViewWrapper, self).__init__(plot)
+ self._item = ScalarField3D()
+ self._dataChanged()
+ plot.sigDataChanged.connect(self._dataChanged)
+ # sigItemAdded, sigItemRemoved, sigVisibleDataChanged are never emitted
+
+ def _dataChanged(self):
+ plot = self.getPlot()
+ if plot is not None:
+ self._item.setData(plot.getData(copy=False), copy=False)
+ self.sigCurrentChanged.emit(self._item)
- :param parent: The widget's parent.
- :param plot: :class:`.PlotWidget` instance on which to operate
+ def getItems(self):
+ plot = self.getPlot()
+ return () if plot is None else (self._item,)
+
+ def getSelectedItems(self):
+ return self.getItems()
+
+ def setCurrentItem(self, item):
+ pass
+
+ def getLabel(self, item):
+ return 'Data'
+
+ def getKind(self, item):
+ return 'image'
+
+
+class _Container(object):
+ """Class to contain a plot item.
+
+ This is apparently needed for compatibility with PySide2,
+
+ :param QObject obj:
"""
+ def __init__(self, obj):
+ self._obj = obj
- COMPATIBLE_KINDS = {
- 'curve': CurveItem,
- 'image': ImageItem,
- 'scatter': ScatterItem,
- 'histogram': HistogramItem
- }
+ def __call__(self):
+ return self._obj
- COMPATIBLE_ITEMS = tuple(COMPATIBLE_KINDS.values())
- def __init__(self, parent=None, plot=None):
- TableWidget.__init__(self, parent)
- """Next freeID for the curve"""
- self.plot = None
- self._displayOnlyActItem = False
- self._statsOnVisibleData = False
- self._lgdAndKindToItems = {}
- """Associate to a tuple(legend, kind) the items legend"""
- self.callbackImage = None
- self.callbackScatter = None
- self.callbackCurve = None
- """Associate the curve legend to his first item"""
+class _StatsWidgetBase(object):
+ """
+ Base class for all widgets which want to display statistics
+ """
+ def __init__(self, statsOnVisibleData, displayOnlyActItem):
+ self._displayOnlyActItem = displayOnlyActItem
+ self._statsOnVisibleData = statsOnVisibleData
self._statsHandler = None
- self._legendsSet = []
- """list of legends actually displayed"""
- self._resetColumns()
- self.setColumnCount(len(self._columns))
- self.setSelectionBehavior(qt.QAbstractItemView.SelectRows)
- self.setPlot(plot)
- self.setSortingEnabled(True)
+ self.__default_skipped_events = (
+ ItemChangedType.ALPHA,
+ ItemChangedType.COLOR,
+ ItemChangedType.COLORMAP,
+ ItemChangedType.SYMBOL,
+ ItemChangedType.SYMBOL_SIZE,
+ ItemChangedType.LINE_WIDTH,
+ ItemChangedType.LINE_STYLE,
+ ItemChangedType.LINE_BG_COLOR,
+ ItemChangedType.FILL,
+ ItemChangedType.HIGHLIGHTED_COLOR,
+ ItemChangedType.HIGHLIGHTED_STYLE,
+ ItemChangedType.TEXT,
+ ItemChangedType.OVERLAY,
+ ItemChangedType.VISUALIZATION_MODE,
+ )
+
+ self._plotWrapper = _Wrapper()
+ self._dealWithPlotConnection(create=True)
- def _resetColumns(self):
- self._columns_index = OrderedDict([('legend', 0), ('kind', 1)])
- self._columns = self._columns_index.keys()
- self.setColumnCount(len(self._columns))
+ def setPlot(self, plot):
+ """Define the plot to interact with
- def setStats(self, statsHandler):
+ :param Union[PlotWidget,SceneWidget,None] plot:
+ The plot containing the items on which statistics are applied
"""
+ try:
+ import OpenGL
+ except ImportError:
+ has_opengl = False
+ else:
+ has_opengl = True
+ from ..plot3d.SceneWidget import SceneWidget # Lazy import
+ self._dealWithPlotConnection(create=False)
+ self.clear()
+ if plot is None:
+ self._plotWrapper = _Wrapper()
+ elif isinstance(plot, PlotWidget):
+ self._plotWrapper = _PlotWidgetWrapper(plot)
+ else:
+ if has_opengl is True:
+ if isinstance(plot, SceneWidget):
+ self._plotWrapper = _SceneWidgetWrapper(plot)
+ else: # Expect a ScalarFieldView
+ self._plotWrapper = _ScalarFieldViewWrapper(plot)
+ else:
+ _logger.warning('OpenGL not installed, %s not managed' % ('SceneWidget qnd ScalarFieldView'))
+ self._dealWithPlotConnection(create=True)
+
+ def setStats(self, statsHandler):
+ """Set which stats to display and the associated formatting.
- :param statsHandler: Set the statistics to be displayed and how to
- format them using
- :rtype: :class:`StatsHandler`
+ :param StatsHandler statsHandler:
+ Set the statistics to be displayed and how to format them using
"""
- _statsHandler = statsHandler
if statsHandler is None:
- _statsHandler = StatsHandler(statFormatters=())
- if isinstance(_statsHandler, (list, tuple)):
- _statsHandler = StatsHandler(_statsHandler)
- assert isinstance(_statsHandler, StatsHandler)
- self._resetColumns()
- self.clear()
-
- for statName, stat in list(_statsHandler.stats.items()):
- assert isinstance(stat, statsmdl.StatBase)
- self._columns_index[statName] = len(self._columns_index)
- self._statsHandler = _statsHandler
- self._columns = self._columns_index.keys()
- self.setColumnCount(len(self._columns))
+ statsHandler = StatsHandler(statFormatters=())
+ elif isinstance(statsHandler, (list, tuple)):
+ statsHandler = StatsHandler(statsHandler)
+ assert isinstance(statsHandler, StatsHandler)
- self._updateItemObserve()
- self._updateAllStats()
+ self._statsHandler = statsHandler
def getStatsHandler(self):
+ """Returns the :class:`StatsHandler` in use.
+
+ :rtype: StatsHandler
+ """
return self._statsHandler
- def _updateAllStats(self):
- for (legend, kind) in self._lgdAndKindToItems:
- self._updateStats(legend, kind)
+ def getPlot(self):
+ """Returns the plot attached to this widget
- @staticmethod
- def _getKind(myItem):
- if isinstance(myItem, CurveItem):
- return 'curve'
- elif isinstance(myItem, ImageItem):
- return 'image'
- elif isinstance(myItem, ScatterItem):
- return 'scatter'
- elif isinstance(myItem, HistogramItem):
- return 'histogram'
+ :rtype: Union[PlotWidget,SceneWidget,None]
+ """
+ return self._plotWrapper.getPlot()
+
+ def _dealWithPlotConnection(self, create=True):
+ """Manage connection to plot signals
+
+ Note: connection on Item are managed by _addItem and _removeItem methods
+ """
+ connections = [] # List of (signal, slot) to connect/disconnect
+ if self._statsOnVisibleData:
+ connections.append(
+ (self._plotWrapper.sigVisibleDataChanged, self._updateAllStats))
+
+ if self._displayOnlyActItem:
+ connections.append(
+ (self._plotWrapper.sigCurrentChanged, self._updateItemObserve))
else:
- return None
+ connections += [
+ (self._plotWrapper.sigItemAdded, self._addItem),
+ (self._plotWrapper.sigItemRemoved, self._removeItem),
+ (self._plotWrapper.sigCurrentChanged, self._plotCurrentChanged)]
+
+ for signal, slot in connections:
+ if create:
+ signal.connect(slot)
+ else:
+ signal.disconnect(slot)
- def setPlot(self, plot):
+ def _updateItemObserve(self, *args):
+ """Reload table depending on mode"""
+ raise NotImplementedError('Base class')
+
+ def _updateStats(self, item):
+ """Update displayed information for given plot item
+
+ :param item: The plot item
+ """
+ raise NotImplementedError('Base class')
+
+ def _updateAllStats(self):
+ """Update stats for all rows in the table"""
+ raise NotImplementedError('Base class')
+
+ def setDisplayOnlyActiveItem(self, displayOnlyActItem):
+ """Toggle display off all items or only the active/selected one
+
+ :param bool displayOnlyActItem:
+ True if we want to only show active item
"""
- Define the plot to interact with
+ self._displayOnlyActItem = displayOnlyActItem
+
+ def setStatsOnVisibleData(self, b):
+ """Toggle computation of statistics on whole data or only visible ones.
+
+ .. warning:: When visible data is activated we will process to a simple
+ filtering of visible data by the user. The filtering is a
+ simple data sub-sampling. No interpolation is made to fit
+ data to boundaries.
- :param plot: the plot containing the items on which statistics are
- applied
- :rtype: :class:`.PlotWidget`
+ :param bool b: True if we want to apply statistics only on visible data
"""
- if self.plot:
+ if self._statsOnVisibleData != b:
self._dealWithPlotConnection(create=False)
- self.plot = plot
- self.clear()
- if self.plot:
+ self._statsOnVisibleData = b
self._dealWithPlotConnection(create=True)
- self._updateItemObserve()
+ self._updateAllStats()
- def _updateItemObserve(self):
- if self.plot:
- self.clear()
- if self._displayOnlyActItem is True:
- activeCurve = self.plot.getActiveCurve(just_legend=False)
- activeScatter = self.plot._getActiveItem(kind='scatter',
- just_legend=False)
- activeImage = self.plot.getActiveImage(just_legend=False)
- if activeCurve:
- self._addItem(activeCurve)
- if activeImage:
- self._addItem(activeImage)
- if activeScatter:
- self._addItem(activeScatter)
- else:
- [self._addItem(curve) for curve in self.plot.getAllCurves()]
- [self._addItem(image) for image in self.plot.getAllImages()]
- scatters = self.plot._getItems(kind='scatter',
- just_legend=False,
- withhidden=True)
- [self._addItem(scatter) for scatter in scatters]
- histograms = self.plot._getItems(kind='histogram',
- just_legend=False,
- withhidden=True)
- [self._addItem(histogram) for histogram in histograms]
+ def _addItem(self, item):
+ """Add a plot item to the table
- def _dealWithPlotConnection(self, create=True):
+ If item is not supported, it is ignored.
+
+ :param item: The plot item
+ :returns: True if the item is added to the widget.
+ :rtype: bool
"""
- Manage connection to plot signals
+ raise NotImplementedError('Base class')
- Note: connection on Item are managed by the _removeItem function
+ def _removeItem(self, item):
+ """Remove table items corresponding to given plot item from the table.
+
+ :param item: The plot item
"""
- if self.plot is None:
- return
- if self._displayOnlyActItem:
- if create is True:
- if self.callbackImage is None:
- self.callbackImage = functools.partial(self._activeItemChanged, 'image')
- self.callbackScatter = functools.partial(self._activeItemChanged, 'scatter')
- self.callbackCurve = functools.partial(self._activeItemChanged, 'curve')
- self.plot.sigActiveImageChanged.connect(self.callbackImage)
- self.plot.sigActiveScatterChanged.connect(self.callbackScatter)
- self.plot.sigActiveCurveChanged.connect(self.callbackCurve)
- else:
- if self.callbackImage is not None:
- self.plot.sigActiveImageChanged.disconnect(self.callbackImage)
- self.plot.sigActiveScatterChanged.disconnect(self.callbackScatter)
- self.plot.sigActiveCurveChanged.disconnect(self.callbackCurve)
- self.callbackImage = None
- self.callbackScatter = None
- self.callbackCurve = None
- else:
- if create is True:
- self.plot.sigContentChanged.connect(self._plotContentChanged)
- else:
- self.plot.sigContentChanged.disconnect(self._plotContentChanged)
- if create is True:
- self.plot.sigPlotSignal.connect(self._zoomPlotChanged)
- else:
- self.plot.sigPlotSignal.disconnect(self._zoomPlotChanged)
+ raise NotImplementedError('Base class')
+
+ def _plotCurrentChanged(self, current):
+ """Handle change of current item and update selection in table
+
+ :param current:
+ """
+ raise NotImplementedError('Base class')
def clear(self):
+ """clear GUI"""
+ pass
+
+ def _skipPlotItemChangedEvent(self, event):
"""
- Clear all existing items
+
+ :param ItemChangedtype event: event to filter or not
+ :return: True if we want to ignore this ItemChangedtype
+ :rtype: bool
"""
- lgdsAndKinds = list(self._lgdAndKindToItems.keys())
- for lgdAndKind in lgdsAndKinds:
- self._removeItem(legend=lgdAndKind[0], kind=lgdAndKind[1])
- self._lgdAndKindToItems = {}
- qt.QTableWidget.clear(self)
+ return event in self.__default_skipped_events
+
+
+class StatsTable(_StatsWidgetBase, TableWidget):
+ """
+ TableWidget displaying for each curves contained by the Plot some
+ information:
+
+ * legend
+ * minimal value
+ * maximal value
+ * standard deviation (std)
+
+ :param QWidget parent: The widget's parent.
+ :param Union[PlotWidget,SceneWidget] plot:
+ :class:`PlotWidget` or :class:`SceneWidget` instance on which to operate
+ """
+
+ _LEGEND_HEADER_DATA = 'legend'
+ _KIND_HEADER_DATA = 'kind'
+
+ def __init__(self, parent=None, plot=None):
+ TableWidget.__init__(self, parent)
+ _StatsWidgetBase.__init__(self, statsOnVisibleData=False,
+ displayOnlyActItem=False)
+
+ # Init for _displayOnlyActItem == False
+ assert self._displayOnlyActItem is False
+ self.setSelectionBehavior(qt.QAbstractItemView.SelectRows)
+ self.setSelectionMode(qt.QAbstractItemView.SingleSelection)
+ self.currentItemChanged.connect(self._currentItemChanged)
+
self.setRowCount(0)
+ self.setColumnCount(2)
- # It have to called befor3e accessing to the header items
- self.setHorizontalHeaderLabels(list(self._columns))
-
- if self._statsHandler is not None:
- for columnId, name in enumerate(self._columns):
- item = self.horizontalHeaderItem(columnId)
- if name in self._statsHandler.stats:
- stat = self._statsHandler.stats[name]
- text = stat.name[0].upper() + stat.name[1:]
- if stat.description is not None:
- tooltip = stat.description
- else:
- tooltip = ""
- else:
- text = name[0].upper() + name[1:]
- tooltip = ""
- item.setToolTip(tooltip)
- item.setText(text)
+ # Init headers
+ headerItem = qt.QTableWidgetItem('Legend')
+ headerItem.setData(qt.Qt.UserRole, self._LEGEND_HEADER_DATA)
+ self.setHorizontalHeaderItem(0, headerItem)
+ headerItem = qt.QTableWidgetItem('Kind')
+ headerItem.setData(qt.Qt.UserRole, self._KIND_HEADER_DATA)
+ self.setHorizontalHeaderItem(1, headerItem)
+
+ self.setSortingEnabled(True)
+ self.setPlot(plot)
- if hasattr(self.horizontalHeader(), 'setSectionResizeMode'): # Qt5
- self.horizontalHeader().setSectionResizeMode(qt.QHeaderView.ResizeToContents)
+ @contextmanager
+ def _disableSorting(self):
+ """Context manager that disables table sorting
+
+ Previous state is restored when leaving
+ """
+ sorting = self.isSortingEnabled()
+ if sorting:
+ self.setSortingEnabled(False)
+ yield
+ if sorting:
+ self.setSortingEnabled(sorting)
+
+ def setStats(self, statsHandler):
+ """Set which stats to display and the associated formatting.
+
+ :param StatsHandler statsHandler:
+ Set the statistics to be displayed and how to format them using
+ """
+ self._removeAllItems()
+ _StatsWidgetBase.setStats(self, statsHandler)
+
+ self.setRowCount(0)
+ self.setColumnCount(len(self._statsHandler.stats) + 2) # + legend and kind
+
+ for index, stat in enumerate(self._statsHandler.stats.values()):
+ headerItem = qt.QTableWidgetItem(stat.name.capitalize())
+ headerItem.setData(qt.Qt.UserRole, stat.name)
+ if stat.description is not None:
+ headerItem.setToolTip(stat.description)
+ self.setHorizontalHeaderItem(2 + index, headerItem)
+
+ horizontalHeader = self.horizontalHeader()
+ if hasattr(horizontalHeader, 'setSectionResizeMode'): # Qt5
+ horizontalHeader.setSectionResizeMode(qt.QHeaderView.ResizeToContents)
else: # Qt4
- self.horizontalHeader().setResizeMode(qt.QHeaderView.ResizeToContents)
- self.setColumnHidden(self._columns_index['kind'], True)
+ horizontalHeader.setResizeMode(qt.QHeaderView.ResizeToContents)
- def _addItem(self, item):
- assert isinstance(item, self.COMPATIBLE_ITEMS)
- if (item.getLegend(), self._getKind(item)) in self._lgdAndKindToItems:
- self._updateStats(item.getLegend(), self._getKind(item))
- return
+ self._updateItemObserve()
+
+ def setPlot(self, plot):
+ """Define the plot to interact with
+
+ :param Union[PlotWidget,SceneWidget,None] plot:
+ The plot containing the items on which statistics are applied
+ """
+ _StatsWidgetBase.setPlot(self, plot)
+ self._updateItemObserve()
+
+ def clear(self):
+ """Define the plot to interact with
+
+ :param Union[PlotWidget,SceneWidget,None] plot:
+ The plot containing the items on which statistics are applied
+ """
+ self._removeAllItems()
+
+ def _updateItemObserve(self, *args):
+ """Reload table depending on mode"""
+ self._removeAllItems()
+
+ # Get selected or all items from the plot
+ if self._displayOnlyActItem: # Only selected
+ items = self._plotWrapper.getSelectedItems()
+ else: # All items
+ items = self._plotWrapper.getItems()
+
+ # Add items to the plot
+ for item in items:
+ self._addItem(item)
+
+ def _plotCurrentChanged(self, current):
+ """Handle change of current item and update selection in table
- self.setRowCount(self.rowCount() + 1)
- indexTable = self.rowCount() - 1
- kind = self._getKind(item)
-
- self._lgdAndKindToItems[(item.getLegend(), kind)] = {}
-
- # the get item will manage the item creation of not existing
- _createItem = self._getItem
- for itemName in self._columns:
- _createItem(name=itemName, legend=item.getLegend(), kind=kind,
- indexTable=indexTable)
-
- self._updateStats(legend=item.getLegend(), kind=kind)
-
- callback = functools.partial(
- silx.utils.weakref.WeakMethodProxy(self._updateStats),
- item.getLegend(), kind)
- item.sigItemChanged.connect(callback)
- self.setColumnHidden(self._columns_index['kind'],
- item.getLegend() not in self._legendsSet)
- self._legendsSet.append(item.getLegend())
-
- def _getItem(self, name, legend, kind, indexTable):
- if (legend, kind) not in self._lgdAndKindToItems:
- self._lgdAndKindToItems[(legend, kind)] = {}
- if not (name in self._lgdAndKindToItems[(legend, kind)] and
- self._lgdAndKindToItems[(legend, kind)]):
- if name in ('legend', 'kind'):
- _item = qt.QTableWidgetItem(type=qt.QTableWidgetItem.Type)
- if name == 'legend':
- _item.setText(legend)
+ :param current:
+ """
+ row = self._itemToRow(current)
+ if row is None:
+ if self.currentRow() >= 0:
+ self.setCurrentCell(-1, -1)
+ elif row != self.currentRow():
+ self.setCurrentCell(row, 0)
+
+ def _tableItemToItem(self, tableItem):
+ """Find the plot item corresponding to a table item
+
+ :param QTableWidgetItem tableItem:
+ :rtype: QObject
+ """
+ container = tableItem.data(qt.Qt.UserRole)
+ return container()
+
+ def _itemToRow(self, item):
+ """Find the row corresponding to a plot item
+
+ :param item: The plot item
+ :return: The corresponding row index
+ :rtype: Union[int,None]
+ """
+ for row in range(self.rowCount()):
+ tableItem = self.item(row, 0)
+ if self._tableItemToItem(tableItem) == item:
+ return row
+ return None
+
+ def _itemToTableItems(self, item):
+ """Find all table items corresponding to a plot item
+
+ :param item: The plot item
+ :return: An ordered dict of column name to QTableWidgetItem mapping
+ for the given plot item.
+ :rtype: OrderedDict
+ """
+ result = OrderedDict()
+ row = self._itemToRow(item)
+ if row is not None:
+ for column in range(self.columnCount()):
+ tableItem = self.item(row, column)
+ if self._tableItemToItem(tableItem) != item:
+ _logger.error("Table item/plot item mismatch")
else:
- assert name == 'kind'
- _item.setText(kind)
+ header = self.horizontalHeaderItem(column)
+ name = header.data(qt.Qt.UserRole)
+ result[name] = tableItem
+ return result
+
+ def _plotItemChanged(self, event):
+ """Handle modifications of the items.
+
+ :param event:
+ """
+ if self._skipPlotItemChangedEvent(event) is True:
+ return
+ else:
+ item = self.sender()
+ self._updateStats(item)
+
+ def _addItem(self, item):
+ """Add a plot item to the table
+
+ If item is not supported, it is ignored.
+
+ :param item: The plot item
+ :returns: True if the item is added to the widget.
+ :rtype: bool
+ """
+ if self._itemToRow(item) is not None:
+ _logger.info("Item already present in the table")
+ self._updateStats(item)
+ return True
+
+ kind = self._plotWrapper.getKind(item)
+ if kind not in statsmdl.BASIC_COMPATIBLE_KINDS:
+ _logger.info("Item has not a supported type: %s", item)
+ return False
+
+ # Prepare table items
+ tableItems = [
+ qt.QTableWidgetItem(), # Legend
+ qt.QTableWidgetItem()] # Kind
+
+ for column in range(2, self.columnCount()):
+ header = self.horizontalHeaderItem(column)
+ name = header.data(qt.Qt.UserRole)
+
+ formatter = self._statsHandler.formatters[name]
+ if formatter:
+ tableItem = formatter.tabWidgetItemClass()
else:
- if self._statsHandler.formatters[name]:
- _item = self._statsHandler.formatters[name].tabWidgetItemClass()
- else:
- _item = qt.QTableWidgetItem()
- tooltip = self._statsHandler.stats[name].getToolTip(kind=kind)
- if tooltip is not None:
- _item.setToolTip(tooltip)
+ tableItem = qt.QTableWidgetItem()
- _item.setFlags(qt.Qt.ItemIsEnabled | qt.Qt.ItemIsSelectable)
- self.setItem(indexTable, self._columns_index[name], _item)
- self._lgdAndKindToItems[(legend, kind)][name] = _item
+ tooltip = self._statsHandler.stats[name].getToolTip(kind=kind)
+ if tooltip is not None:
+ tableItem.setToolTip(tooltip)
- return self._lgdAndKindToItems[(legend, kind)][name]
+ tableItems.append(tableItem)
- def _removeItem(self, legend, kind):
- if (legend, kind) not in self._lgdAndKindToItems or not self.plot:
- return
+ # Disable sorting while adding table items
+ with self._disableSorting():
+ # Add a row to the table
+ self.setRowCount(self.rowCount() + 1)
- self.firstItem = self._lgdAndKindToItems[(legend, kind)]['legend']
- del self._lgdAndKindToItems[(legend, kind)]
- self.removeRow(self.firstItem.row())
- self._legendsSet.remove(legend)
- self.setColumnHidden(self._columns_index['kind'],
- legend not in self._legendsSet)
+ # Add table items to the last row
+ row = self.rowCount() - 1
+ for column, tableItem in enumerate(tableItems):
+ tableItem.setData(qt.Qt.UserRole, _Container(item))
+ tableItem.setFlags(
+ qt.Qt.ItemIsEnabled | qt.Qt.ItemIsSelectable)
+ self.setItem(row, column, tableItem)
- def _updateCurrentStats(self):
- for lgdAndKind in self._lgdAndKindToItems:
- self._updateStats(lgdAndKind[0], lgdAndKind[1])
+ # Update table items content
+ self._updateStats(item)
- def _updateStats(self, legend, kind, event=None):
- if self._statsHandler is None:
+ # Listen for item changes
+ # Using queued connection to avoid issue with sender
+ # being that of the signal calling the signal
+ item.sigItemChanged.connect(self._plotItemChanged,
+ qt.Qt.QueuedConnection)
+
+ return True
+
+ def _removeItem(self, item):
+ """Remove table items corresponding to given plot item from the table.
+
+ :param item: The plot item
+ """
+ row = self._itemToRow(item)
+ if row is None:
+ kind = self._plotWrapper.getKind(item)
+ if kind in statsmdl.BASIC_COMPATIBLE_KINDS:
+ _logger.error("Removing item that is not in table: %s", str(item))
return
+ item.sigItemChanged.disconnect(self._plotItemChanged)
+ self.removeRow(row)
+
+ def _removeAllItems(self):
+ """Remove content of the table"""
+ for row in range(self.rowCount()):
+ tableItem = self.item(row, 0)
+ item = self._tableItemToItem(tableItem)
+ item.sigItemChanged.disconnect(self._plotItemChanged)
+ self.clearContents()
+ self.setRowCount(0)
- assert kind in ('curve', 'image', 'scatter', 'histogram')
- if kind == 'curve':
- item = self.plot.getCurve(legend)
- elif kind == 'image':
- item = self.plot.getImage(legend)
- elif kind == 'scatter':
- item = self.plot.getScatter(legend)
- elif kind == 'histogram':
- item = self.plot.getHistogram(legend)
- else:
- raise ValueError('kind not managed')
+ def _updateStats(self, item):
+ """Update displayed information for given plot item
- if not item or (item.getLegend(), kind) not in self._lgdAndKindToItems:
+ :param item: The plot item
+ """
+ if item is None:
+ return
+ plot = self.getPlot()
+ if plot is None:
+ _logger.info("Plot not available")
return
- assert isinstance(item, self.COMPATIBLE_ITEMS)
-
- statsValDict = self._statsHandler.calculate(item, self.plot,
- self._statsOnVisibleData)
-
- lgdItem = self._lgdAndKindToItems[(item.getLegend(), kind)]['legend']
- assert lgdItem
- rowStat = lgdItem.row()
-
- for statName, statVal in list(statsValDict.items()):
- assert statName in self._lgdAndKindToItems[(item.getLegend(), kind)]
- tableItem = self._getItem(name=statName, legend=item.getLegend(),
- kind=kind, indexTable=rowStat)
- tableItem.setText(str(statVal))
-
- def currentChanged(self, current, previous):
- if current.row() >= 0:
- legendItem = self.item(current.row(), self._columns_index['legend'])
- assert legendItem
- kindItem = self.item(current.row(), self._columns_index['kind'])
- kind = kindItem.text()
- if kind == 'curve':
- self.plot.setActiveCurve(legendItem.text())
- elif kind == 'image':
- self.plot.setActiveImage(legendItem.text())
- elif kind == 'scatter':
- self.plot._setActiveItem('scatter', legendItem.text())
- elif kind == 'histogram':
- # active histogram not managed by the plot actually
- pass
- else:
- raise ValueError('kind not managed')
- qt.QTableWidget.currentChanged(self, current, previous)
+ row = self._itemToRow(item)
+ if row is None:
+ _logger.error("This item is not in the table: %s", str(item))
+ return
- def setDisplayOnlyActiveItem(self, displayOnlyActItem):
+ statsHandler = self.getStatsHandler()
+ if statsHandler is not None:
+ stats = statsHandler.calculate(
+ item, plot, self._statsOnVisibleData)
+ else:
+ stats = {}
+
+ with self._disableSorting():
+ for name, tableItem in self._itemToTableItems(item).items():
+ if name == self._LEGEND_HEADER_DATA:
+ text = self._plotWrapper.getLabel(item)
+ tableItem.setText(text)
+ elif name == self._KIND_HEADER_DATA:
+ tableItem.setText(self._plotWrapper.getKind(item))
+ else:
+ value = stats.get(name)
+ if value is None:
+ _logger.error("Value not found for: %s", name)
+ tableItem.setText('-')
+ else:
+ tableItem.setText(str(value))
+
+ def _updateAllStats(self):
+ """Update stats for all rows in the table"""
+ with self._disableSorting():
+ for row in range(self.rowCount()):
+ tableItem = self.item(row, 0)
+ item = self._tableItemToItem(tableItem)
+ self._updateStats(item)
+
+ def _currentItemChanged(self, current, previous):
+ """Handle change of selection in table and sync plot selection
+
+ :param QTableWidgetItem current:
+ :param QTableWidgetItem previous:
"""
+ if current and current.row() >= 0:
+ item = self._tableItemToItem(current)
+ self._plotWrapper.setCurrentItem(item)
- :param bool displayOnlyActItem: True if we want to only show active
- item
+ def setDisplayOnlyActiveItem(self, displayOnlyActItem):
+ """Toggle display off all items or only the active/selected one
+
+ :param bool displayOnlyActItem:
+ True if we want to only show active item
"""
if self._displayOnlyActItem == displayOnlyActItem:
return
- self._displayOnlyActItem = displayOnlyActItem
self._dealWithPlotConnection(create=False)
+ if not self._displayOnlyActItem:
+ self.currentItemChanged.disconnect(self._currentItemChanged)
+
+ _StatsWidgetBase.setDisplayOnlyActiveItem(self, displayOnlyActItem)
+
self._updateItemObserve()
self._dealWithPlotConnection(create=True)
+ if not self._displayOnlyActItem:
+ self.currentItemChanged.connect(self._currentItemChanged)
+ self.setSelectionMode(qt.QAbstractItemView.SingleSelection)
+ else:
+ self.setSelectionMode(qt.QAbstractItemView.NoSelection)
+
+
+class _OptionsWidget(qt.QToolBar):
+
+ def __init__(self, parent=None):
+ qt.QToolBar.__init__(self, parent)
+ self.setIconSize(qt.QSize(16, 16))
+
+ action = qt.QAction(self)
+ action.setIcon(icons.getQIcon("stats-active-items"))
+ action.setText("Active items only")
+ action.setToolTip("Display stats for active items only.")
+ action.setCheckable(True)
+ action.setChecked(True)
+ self.__displayActiveItems = action
+
+ action = qt.QAction(self)
+ action.setIcon(icons.getQIcon("stats-whole-items"))
+ action.setText("All items")
+ action.setToolTip("Display stats for all available items.")
+ action.setCheckable(True)
+ self.__displayWholeItems = action
+
+ action = qt.QAction(self)
+ action.setIcon(icons.getQIcon("stats-visible-data"))
+ action.setText("Use the visible data range")
+ action.setToolTip("Use the visible data range.<br/>"
+ "If activated the data is filtered to only use"
+ "visible data of the plot."
+ "The filtering is a data sub-sampling."
+ "No interpolation is made to fit data to"
+ "boundaries.")
+ action.setCheckable(True)
+ self.__useVisibleData = action
+
+ action = qt.QAction(self)
+ action.setIcon(icons.getQIcon("stats-whole-data"))
+ action.setText("Use the full data range")
+ action.setToolTip("Use the full data range.")
+ action.setCheckable(True)
+ action.setChecked(True)
+ self.__useWholeData = action
+
+ self.addAction(self.__displayWholeItems)
+ self.addAction(self.__displayActiveItems)
+ self.addSeparator()
+ self.addAction(self.__useVisibleData)
+ self.addAction(self.__useWholeData)
+
+ self.itemSelection = qt.QActionGroup(self)
+ self.itemSelection.setExclusive(True)
+ self.itemSelection.addAction(self.__displayActiveItems)
+ self.itemSelection.addAction(self.__displayWholeItems)
+
+ self.dataRangeSelection = qt.QActionGroup(self)
+ self.dataRangeSelection.setExclusive(True)
+ self.dataRangeSelection.addAction(self.__useWholeData)
+ self.dataRangeSelection.addAction(self.__useVisibleData)
+
+ def isActiveItemMode(self):
+ return self.itemSelection.checkedAction() is self.__displayActiveItems
+
+ def isVisibleDataRangeMode(self):
+ return self.dataRangeSelection.checkedAction() is self.__useVisibleData
+
+ def setVisibleDataRangeModeEnabled(self, enabled):
+ """Enable/Disable the visible data range mode
+
+ :param bool enabled: True to allow user to choose
+ stats on visible data
+ """
+ self.__useVisibleData.setEnabled(enabled)
+ if not enabled:
+ self.__useWholeData.setChecked(True)
+
+
+class StatsWidget(qt.QWidget):
+ """
+ Widget displaying a set of :class:`Stat` to be displayed on a
+ :class:`StatsTable` and to be apply on items contained in the :class:`Plot`
+ Also contains options to:
+
+ * compute statistics on all the data or on visible data only
+ * show statistics of all items or only the active one
+
+ :param QWidget parent: Qt parent
+ :param Union[PlotWidget,SceneWidget] plot:
+ The plot containing items on which we want statistics.
+ :param StatsHandler stats:
+ Set the statistics to be displayed and how to format them using
+ """
+
+ sigVisibilityChanged = qt.Signal(bool)
+ """Signal emitted when the visibility of this widget changes.
+
+ It Provides the visibility of the widget.
+ """
+
+ NUMBER_FORMAT = '{0:.3f}'
+
+ def __init__(self, parent=None, plot=None, stats=None):
+ qt.QWidget.__init__(self, parent)
+ self.setLayout(qt.QVBoxLayout())
+ self.layout().setContentsMargins(0, 0, 0, 0)
+ self._options = _OptionsWidget(parent=self)
+ self.layout().addWidget(self._options)
+ self._statsTable = StatsTable(parent=self, plot=plot)
+ self.setStats(stats)
+
+ self.layout().addWidget(self._statsTable)
+
+ self._options.itemSelection.triggered.connect(
+ self._optSelectionChanged)
+ self._options.dataRangeSelection.triggered.connect(
+ self._optDataRangeChanged)
+ self._optSelectionChanged()
+ self._optDataRangeChanged()
+
+ def _getStatsTable(self):
+ """Returns the :class:`StatsTable` used by this widget.
+
+ :rtype: StatsTable
+ """
+ return self._statsTable
+
+ def showEvent(self, event):
+ self.sigVisibilityChanged.emit(True)
+ qt.QWidget.showEvent(self, event)
+
+ def hideEvent(self, event):
+ self.sigVisibilityChanged.emit(False)
+ qt.QWidget.hideEvent(self, event)
+
+ def _optSelectionChanged(self, action=None):
+ self._getStatsTable().setDisplayOnlyActiveItem(
+ self._options.isActiveItemMode())
+
+ def _optDataRangeChanged(self, action=None):
+ self._getStatsTable().setStatsOnVisibleData(
+ self._options.isVisibleDataRangeMode())
+
+ # Proxy methods
+
+ def setStats(self, statsHandler):
+ return self._getStatsTable().setStats(statsHandler=statsHandler)
+
+ setStats.__doc__ = StatsTable.setStats.__doc__
+
+ def setPlot(self, plot):
+ self._options.setVisibleDataRangeModeEnabled(
+ plot is None or isinstance(plot, PlotWidget))
+ return self._getStatsTable().setPlot(plot=plot)
+
+ setPlot.__doc__ = StatsTable.setPlot.__doc__
+
+ def getPlot(self):
+ return self._getStatsTable().getPlot()
+
+ getPlot.__doc__ = StatsTable.getPlot.__doc__
+
+ def setDisplayOnlyActiveItem(self, displayOnlyActItem):
+ return self._getStatsTable().setDisplayOnlyActiveItem(
+ displayOnlyActItem=displayOnlyActItem)
+
+ setDisplayOnlyActiveItem.__doc__ = StatsTable.setDisplayOnlyActiveItem.__doc__
+
def setStatsOnVisibleData(self, b):
+ return self._getStatsTable().setStatsOnVisibleData(b=b)
+
+ setStatsOnVisibleData.__doc__ = StatsTable.setStatsOnVisibleData.__doc__
+
+
+DEFAULT_STATS = StatsHandler((
+ (statsmdl.StatMin(), StatFormatter()),
+ statsmdl.StatCoordMin(),
+ (statsmdl.StatMax(), StatFormatter()),
+ statsmdl.StatCoordMax(),
+ statsmdl.StatCOM(),
+ (('mean', numpy.mean), StatFormatter()),
+ (('std', numpy.std), StatFormatter()),
+))
+
+
+class BasicStatsWidget(StatsWidget):
+ """
+ Widget defining a simple set of :class:`Stat` to be displayed on a
+ :class:`StatsWidget`.
+
+ :param QWidget parent: Qt parent
+ :param PlotWidget plot:
+ The plot containing items on which we want statistics.
+ :param StatsHandler stats:
+ Set the statistics to be displayed and how to format them using
+
+ .. snapshotqt:: img/BasicStatsWidget.png
+ :width: 300px
+ :align: center
+
+ from silx.gui.plot import Plot1D
+ from silx.gui.plot.StatsWidget import BasicStatsWidget
+
+ plot = Plot1D()
+ x = range(100)
+ y = x
+ plot.addCurve(x, y, legend='curve_0')
+ plot.setActiveCurve('curve_0')
+
+ widget = BasicStatsWidget(plot=plot)
+ widget.show()
+ """
+ def __init__(self, parent=None, plot=None):
+ StatsWidget.__init__(self, parent=parent, plot=plot,
+ stats=DEFAULT_STATS)
+
+
+class _BaseLineStatsWidget(_StatsWidgetBase, qt.QWidget):
+ """
+ Widget made to display stats into a QLayout with for all stat a couple
+ (QLabel, QLineEdit) created.
+ The the layout can be defined prior of adding any statistic.
+
+ :param QWidget parent: Qt parent
+ :param Union[PlotWidget,SceneWidget] plot:
+ The plot containing items on which we want statistics.
+ :param str kind: the kind of plotitems we want to display
+ :param StatsHandler stats:
+ Set the statistics to be displayed and how to format them using
+ :param bool statsOnVisibleData: compute statistics for the whole data or
+ only visible ones.
+ """
+
+ def __init__(self, parent=None, plot=None, kind='curve', stats=None,
+ statsOnVisibleData=False):
+ self._item_kind = kind
+ """The item displayed"""
+ self._statQlineEdit = {}
+ """list of legends actually displayed"""
+ self._n_statistics_per_line = 4
+ """number of statistics displayed per line in the grid layout"""
+ qt.QWidget.__init__(self, parent)
+ _StatsWidgetBase.__init__(self,
+ statsOnVisibleData=statsOnVisibleData,
+ displayOnlyActItem=True)
+ self.setLayout(self._createLayout())
+ self.setPlot(plot)
+ if stats is not None:
+ self.setStats(stats)
+
+ def _addItemForStatistic(self, statistic):
+ assert isinstance(statistic, statsmdl.StatBase)
+ assert statistic.name in self._statsHandler.stats
+
+ self.layout().setSpacing(2)
+ self.layout().setContentsMargins(2, 2, 2, 2)
+
+ if isinstance(self.layout(), qt.QGridLayout):
+ parent = self
+ else:
+ widget = qt.QWidget(parent=self)
+ parent = widget
+
+ qLabel = qt.QLabel(statistic.name + ':', parent=parent)
+ qLineEdit = qt.QLineEdit('', parent=parent)
+ qLineEdit.setReadOnly(True)
+
+ self._addStatsWidgetsToLayout(qLabel=qLabel, qLineEdit=qLineEdit)
+ self._statQlineEdit[statistic.name] = qLineEdit
+
+ def setPlot(self, plot):
+ """Define the plot to interact with
+
+ :param Union[PlotWidget,SceneWidget,None] plot:
+ The plot containing the items on which statistics are applied
"""
- .. warning:: When visible data is activated we will process to a simple
- filtering of visible data by the user. The filtering is a
- simple data sub-sampling. No interpolation is made to fit
- data to boundaries.
+ _StatsWidgetBase.setPlot(self, plot)
+ self._updateAllStats()
- :param bool b: True if we want to apply statistics only on visible data
+ def _addStatsWidgetsToLayout(self, qLabel, qLineEdit):
+ raise NotImplementedError('Base class')
+
+ def setStats(self, statsHandler):
+ """Set which stats to display and the associated formatting.
+ :param StatsHandler statsHandler:
+ Set the statistics to be displayed and how to format them using
"""
- if self._statsOnVisibleData != b:
- self._statsOnVisibleData = b
- self._updateCurrentStats()
+ _StatsWidgetBase.setStats(self, statsHandler)
+ for statName, stat in list(self._statsHandler.stats.items()):
+ self._addItemForStatistic(stat)
+ self._updateAllStats()
def _activeItemChanged(self, kind, previous, current):
- """Callback used when plotting only the active item"""
- assert kind in ('curve', 'image', 'scatter', 'histogram')
- self._updateItemObserve()
+ if kind == self._item_kind:
+ self._updateAllStats()
- def _plotContentChanged(self, action, kind, legend):
- """Callback used when plotting all the plot items"""
- if kind not in ('curve', 'image', 'scatter', 'histogram'):
- return
- if kind == 'curve':
- item = self.plot.getCurve(legend)
- elif kind == 'image':
- item = self.plot.getImage(legend)
- elif kind == 'scatter':
- item = self.plot.getScatter(legend)
- elif kind == 'histogram':
- item = self.plot.getHistogram(legend)
- else:
- raise ValueError('kind not managed')
+ def _updateAllStats(self):
+ plot = self.getPlot()
+ if plot is not None:
+ _items = self._plotWrapper.getSelectedItems()
+ def kind_filter(_item):
+ return self._plotWrapper.getKind(_item) == self.getKind()
+
+ items = list(filter(kind_filter, _items))
+ assert len(items) in (0, 1)
+ if len(items) is 1:
+ self._setItem(items[0])
+
+ def setKind(self, kind):
+ """Change the kind of active item to display
+ :param str kind: kind of item to display information for ('curve' ...)
+ """
+ if self._item_kind != kind:
+ self._item_kind = kind
+ self._updateItemObserve()
- if action == 'add':
- if item is None:
- raise ValueError('Item from legend "%s" do not exists' % legend)
- self._addItem(item)
- elif action == 'remove':
- self._removeItem(legend, kind)
+ def getKind(self):
+ """
+ :return: kind of item we want to compute statistic for
+ :rtype: str
+ """
+ return self._item_kind
+
+ def _setItem(self, item):
+ if item is None:
+ for stat_name, stat_widget in self._statQlineEdit.items():
+ stat_widget.setText('')
+ elif (self._statsHandler is not None and len(
+ self._statsHandler.stats) > 0):
+ plot = self.getPlot()
+ if plot is not None:
+ statsValDict = self._statsHandler.calculate(item,
+ plot,
+ self._statsOnVisibleData)
+ for statName, statVal in list(statsValDict.items()):
+ self._statQlineEdit[statName].setText(statVal)
+
+ def _updateItemObserve(self, *argv):
+ assert self._displayOnlyActItem
+ _items = self._plotWrapper.getSelectedItems()
+ def kind_filter(_item):
+ return self._plotWrapper.getKind(_item) == self.getKind()
+ items = list(filter(kind_filter, _items))
+ assert len(items) in (0, 1)
+ _item = items[0] if len(items) is 1 else None
+ self._setItem(_item)
+
+ def _createLayout(self):
+ """create an instance of the main QLayout"""
+ raise NotImplementedError('Base class')
+
+ def _addItem(self, item):
+ raise NotImplementedError('Display only the active item')
+
+ def _removeItem(self, item):
+ raise NotImplementedError('Display only the active item')
+
+ def _plotCurrentChanged(selfself, current):
+ raise NotImplementedError('Display only the active item')
+
+
+class BasicLineStatsWidget(_BaseLineStatsWidget):
+ """
+ Widget defining a simple set of :class:`Stat` to be displayed on a
+ :class:`LineStatsWidget`.
+
+ :param QWidget parent: Qt parent
+ :param Union[PlotWidget,SceneWidget] plot:
+ The plot containing items on which we want statistics.
+ :param str kind: the kind of plotitems we want to display
+ :param StatsHandler stats:
+ Set the statistics to be displayed and how to format them using
+ :param bool statsOnVisibleData: compute statistics for the whole data or
+ only visible ones.
+ """
+
+ def __init__(self, parent=None, plot=None, kind='curve',
+ stats=DEFAULT_STATS, statsOnVisibleData=False):
+ _BaseLineStatsWidget.__init__(self, parent=parent, kind=kind,
+ plot=plot, stats=stats,
+ statsOnVisibleData=statsOnVisibleData)
+
+ def _createLayout(self):
+ return FlowLayout()
+
+ def _addStatsWidgetsToLayout(self, qLabel, qLineEdit):
+ # create a mother widget to make sure both qLabel & qLineEdit will
+ # always be displayed side by side
+ widget = qt.QWidget(parent=self)
+ widget.setLayout(qt.QHBoxLayout())
+ widget.layout().setSpacing(0)
+ widget.layout().setContentsMargins(0, 0, 0, 0)
+
+ widget.layout().addWidget(qLabel)
+ widget.layout().addWidget(qLineEdit)
+
+ self.layout().addWidget(widget)
+
+
+class BasicGridStatsWidget(_BaseLineStatsWidget):
+ """
+ pymca design like widget
+
+ :param QWidget parent: Qt parent
+ :param Union[PlotWidget,SceneWidget] plot:
+ The plot containing items on which we want statistics.
+ :param StatsHandler stats:
+ Set the statistics to be displayed and how to format them using
+ :param str kind: the kind of plotitems we want to display
+ :param bool statsOnVisibleData: compute statistics for the whole data or
+ only visible ones.
+ :param int statsPerLine: number of statistic to be displayed per line
+
+ .. snapshotqt:: img/BasicGridStatsWidget.png
+ :width: 600px
+ :align: center
+
+ from silx.gui.plot import Plot1D
+ from silx.gui.plot.StatsWidget import BasicGridStatsWidget
+
+ plot = Plot1D()
+ x = range(100)
+ y = x
+ plot.addCurve(x, y, legend='curve_0')
+ plot.setActiveCurve('curve_0')
+
+ widget = BasicGridStatsWidget(plot=plot, kind='curve')
+ widget.show()
+ """
- def _zoomPlotChanged(self, event):
- if self._statsOnVisibleData is True:
- if 'event' in event and event['event'] == 'limitsChanged':
- self._updateCurrentStats()
+ def __init__(self, parent=None, plot=None, kind='curve',
+ stats=DEFAULT_STATS, statsOnVisibleData=False,
+ statsPerLine=4):
+ _BaseLineStatsWidget.__init__(self, parent=parent, kind=kind,
+ plot=plot, stats=stats,
+ statsOnVisibleData=statsOnVisibleData)
+ self._n_statistics_per_line = statsPerLine
+
+ def _addStatsWidgetsToLayout(self, qLabel, qLineEdit):
+ column = len(self._statQlineEdit) % self._n_statistics_per_line
+ row = len(self._statQlineEdit) // self._n_statistics_per_line
+ self.layout().addWidget(qLabel, row, column * 2)
+ self.layout().addWidget(qLineEdit, row, column * 2 + 1)
+
+ def _createLayout(self):
+ return qt.QGridLayout()
diff --git a/silx/gui/plot/_BaseMaskToolsWidget.py b/silx/gui/plot/_BaseMaskToolsWidget.py
index e087354..0d11f17 100644
--- a/silx/gui/plot/_BaseMaskToolsWidget.py
+++ b/silx/gui/plot/_BaseMaskToolsWidget.py
@@ -29,7 +29,7 @@ from __future__ import division
__authors__ = ["T. Vincent", "P. Knobel"]
__license__ = "MIT"
-__date__ = "29/08/2018"
+__date__ = "15/02/2019"
import os
import weakref
@@ -141,7 +141,7 @@ class BaseMask(qt.QObject):
def commit(self):
"""Append the current mask to history if changed"""
if (not self._history or self._redo or
- not numpy.all(numpy.equal(self._mask, self._history[-1]))):
+ not numpy.array_equal(self._mask, self._history[-1])):
if self._redo:
self._redo = [] # Reset redo as a new action as been performed
self.sigRedoable[bool].emit(False)
@@ -325,7 +325,7 @@ class BaseMask(qt.QObject):
raise NotImplementedError("To be implemented in subclass")
def updateDisk(self, level, crow, ccol, radius, mask=True):
- """Mask/Unmask data located inside a disk of the given mask level.
+ """Mask/Unmask data located inside a dick of the given mask level.
:param int level: Mask level to update.
:param crow: Disk center row/ordinate (y).
@@ -335,6 +335,18 @@ class BaseMask(qt.QObject):
"""
raise NotImplementedError("To be implemented in subclass")
+ def updateEllipse(self, level, crow, ccol, radius_r, radius_c, mask=True):
+ """Mask/Unmask a disk of the given mask level.
+
+ :param int level: Mask level to update.
+ :param int crow: Row of the center of the ellipse
+ :param int ccol: Column of the center of the ellipse
+ :param float radius_r: Radius of the ellipse in the row
+ :param float radius_c: Radius of the ellipse in the column
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ raise NotImplementedError("To be implemented in subclass")
+
def updateLine(self, level, row0, col0, row1, col1, width, mask=True):
"""Mask/Unmask a line of the given mask level.
@@ -376,13 +388,11 @@ class BaseMaskToolsWidget(qt.QWidget):
self._plotRef = weakref.ref(plot)
self._maskName = '__MASK_TOOLS_%d' % id(self) # Legend of the mask
- self._colormap = Colormap(name="",
- normalization='linear',
+ self._colormap = Colormap(normalization='linear',
vmin=0,
- vmax=self._maxLevelNumber,
- colors=None)
+ vmax=self._maxLevelNumber)
self._defaultOverlayColor = rgba('gray') # Color of the mask
- self._setMaskColors(1, 0.5)
+ self._setMaskColors(1, 0.5) # Set the colormap LUT
if not isinstance(mask, BaseMask):
raise TypeError("mask is not an instance of BaseMask")
@@ -482,6 +492,7 @@ class BaseMaskToolsWidget(qt.QWidget):
layout.addWidget(self._initMaskGroupBox())
layout.addWidget(self._initDrawGroupBox())
layout.addWidget(self._initThresholdGroupBox())
+ layout.addWidget(self._initOtherToolsGroupBox())
layout.addStretch(1)
self.setLayout(layout)
@@ -617,6 +628,15 @@ class BaseMaskToolsWidget(qt.QWidget):
self.rectAction.triggered.connect(self._activeRectMode)
self.addAction(self.rectAction)
+ self.ellipseAction = qt.QAction(
+ icons.getQIcon('shape-ellipse'), 'Circle selection', None)
+ self.ellipseAction.setToolTip(
+ 'Rectangle selection tool: (Un)Mask a circle region <b>R</b>')
+ self.ellipseAction.setShortcut(qt.QKeySequence(qt.Qt.Key_R))
+ self.ellipseAction.setCheckable(True)
+ self.ellipseAction.triggered.connect(self._activeEllipseMode)
+ self.addAction(self.ellipseAction)
+
self.polygonAction = qt.QAction(
icons.getQIcon('shape-polygon'), 'Polygon selection', None)
self.polygonAction.setShortcut(qt.QKeySequence(qt.Qt.Key_S))
@@ -640,10 +660,11 @@ class BaseMaskToolsWidget(qt.QWidget):
self.drawActionGroup = qt.QActionGroup(self)
self.drawActionGroup.setExclusive(True)
self.drawActionGroup.addAction(self.rectAction)
+ self.drawActionGroup.addAction(self.ellipseAction)
self.drawActionGroup.addAction(self.polygonAction)
self.drawActionGroup.addAction(self.pencilAction)
- actions = (self.browseAction, self.rectAction,
+ actions = (self.browseAction, self.rectAction, self.ellipseAction,
self.polygonAction, self.pencilAction)
drawButtons = []
for action in actions:
@@ -711,36 +732,28 @@ class BaseMaskToolsWidget(qt.QWidget):
def _initThresholdGroupBox(self):
"""Init thresholding widgets"""
- layout = qt.QVBoxLayout()
-
- # Thresholing
self.belowThresholdAction = qt.QAction(
icons.getQIcon('plot-roi-below'), 'Mask below threshold', None)
self.belowThresholdAction.setToolTip(
'Mask image where values are below given threshold')
self.belowThresholdAction.setCheckable(True)
- self.belowThresholdAction.triggered[bool].connect(
- self._belowThresholdActionTriggered)
+ self.belowThresholdAction.setChecked(True)
self.betweenThresholdAction = qt.QAction(
icons.getQIcon('plot-roi-between'), 'Mask within range', None)
self.betweenThresholdAction.setToolTip(
'Mask image where values are within given range')
self.betweenThresholdAction.setCheckable(True)
- self.betweenThresholdAction.triggered[bool].connect(
- self._betweenThresholdActionTriggered)
self.aboveThresholdAction = qt.QAction(
icons.getQIcon('plot-roi-above'), 'Mask above threshold', None)
self.aboveThresholdAction.setToolTip(
'Mask image where values are above given threshold')
self.aboveThresholdAction.setCheckable(True)
- self.aboveThresholdAction.triggered[bool].connect(
- self._aboveThresholdActionTriggered)
self.thresholdActionGroup = qt.QActionGroup(self)
- self.thresholdActionGroup.setExclusive(False)
+ self.thresholdActionGroup.setExclusive(True)
self.thresholdActionGroup.addAction(self.belowThresholdAction)
self.thresholdActionGroup.addAction(self.betweenThresholdAction)
self.thresholdActionGroup.addAction(self.aboveThresholdAction)
@@ -770,41 +783,50 @@ class BaseMaskToolsWidget(qt.QWidget):
loadColormapRangeBtn.setDefaultAction(self.loadColormapRangeAction)
widgets.append(loadColormapRangeBtn)
- container = self._hboxWidget(*widgets, stretch=False)
- layout.addWidget(container)
+ toolBar = self._hboxWidget(*widgets, stretch=False)
- form = qt.QFormLayout()
+ config = qt.QGridLayout()
+ config.setContentsMargins(0, 0, 0, 0)
+ self.minLineLabel = qt.QLabel("Min:", self)
self.minLineEdit = FloatEdit(self, value=0)
- self.minLineEdit.setEnabled(False)
- form.addRow('Min:', self.minLineEdit)
+ config.addWidget(self.minLineLabel, 0, 0)
+ config.addWidget(self.minLineEdit, 0, 1)
+ self.maxLineLabel = qt.QLabel("Max:", self)
self.maxLineEdit = FloatEdit(self, value=0)
- self.maxLineEdit.setEnabled(False)
- form.addRow('Max:', self.maxLineEdit)
+ config.addWidget(self.maxLineLabel, 1, 0)
+ config.addWidget(self.maxLineEdit, 1, 1)
self.applyMaskBtn = qt.QPushButton('Apply mask')
self.applyMaskBtn.clicked.connect(self._maskBtnClicked)
- self.applyMaskBtn.setEnabled(False)
- form.addRow(self.applyMaskBtn)
-
- self.maskNanBtn = qt.QPushButton('Mask not finite values')
- self.maskNanBtn.setToolTip('Mask Not a Number and infinite values')
- self.maskNanBtn.clicked.connect(self._maskNotFiniteBtnClicked)
- form.addRow(self.maskNanBtn)
- thresholdWidget = qt.QWidget()
- thresholdWidget.setLayout(form)
- layout.addWidget(thresholdWidget)
-
- layout.addStretch(1)
+ layout = qt.QVBoxLayout()
+ layout.addWidget(toolBar)
+ layout.addLayout(config)
+ layout.addWidget(self.applyMaskBtn)
self.thresholdGroup = qt.QGroupBox('Threshold')
self.thresholdGroup.setLayout(layout)
+
+ # Init widget state
+ self._thresholdActionGroupTriggered(self.belowThresholdAction)
return self.thresholdGroup
# track widget visibility and plot active image changes
+ def _initOtherToolsGroupBox(self):
+ layout = qt.QVBoxLayout()
+
+ self.maskNanBtn = qt.QPushButton('Mask not finite values')
+ self.maskNanBtn.setToolTip('Mask Not a Number and infinite values')
+ self.maskNanBtn.clicked.connect(self._maskNotFiniteBtnClicked)
+ layout.addWidget(self.maskNanBtn)
+
+ self.otherToolGroup = qt.QGroupBox('Other tools')
+ self.otherToolGroup.setLayout(layout)
+ return self.otherToolGroup
+
def changeEvent(self, event):
"""Reset drawing action when disabling widget"""
if (event.type() == qt.QEvent.EnabledChange and
@@ -883,6 +905,7 @@ class BaseMaskToolsWidget(qt.QWidget):
The index of the mask for which we want to change the color.
If none set this color for all the masks
"""
+ rgb = rgba(rgb)[0:3]
if level is None:
self._overlayColors[:] = rgb
self._defaultColors[:] = False
@@ -925,6 +948,8 @@ class BaseMaskToolsWidget(qt.QWidget):
"""
if self._drawingMode == 'rectangle':
self._activeRectMode()
+ elif self._drawingMode == 'ellipse':
+ self._activeEllipseMode()
elif self._drawingMode == 'polygon':
self._activePolygonMode()
elif self._drawingMode == 'pencil':
@@ -971,6 +996,16 @@ class BaseMaskToolsWidget(qt.QWidget):
'draw', shape='rectangle', source=self, color=color)
self._updateDrawingModeWidgets()
+ def _activeEllipseMode(self):
+ """Handle circle action mode triggering"""
+ self._releaseDrawingMode()
+ self._drawingMode = 'ellipse'
+ self.plot.sigPlotSignal.connect(self._plotDrawEvent)
+ color = self.getCurrentMaskColor()
+ self.plot.setInteractiveMode(
+ 'draw', shape='ellipse', source=self, color=color)
+ self._updateDrawingModeWidgets()
+
def _activePolygonMode(self):
"""Handle polygon action mode triggering"""
self._releaseDrawingMode()
@@ -1016,36 +1051,28 @@ class BaseMaskToolsWidget(qt.QWidget):
return doMask
# Handle threshold UI events
- def _belowThresholdActionTriggered(self, triggered):
- if triggered:
- self.minLineEdit.setEnabled(True)
- self.maxLineEdit.setEnabled(False)
- self.applyMaskBtn.setEnabled(True)
-
- def _betweenThresholdActionTriggered(self, triggered):
- if triggered:
- self.minLineEdit.setEnabled(True)
- self.maxLineEdit.setEnabled(True)
- self.applyMaskBtn.setEnabled(True)
-
- def _aboveThresholdActionTriggered(self, triggered):
- if triggered:
- self.minLineEdit.setEnabled(False)
- self.maxLineEdit.setEnabled(True)
- self.applyMaskBtn.setEnabled(True)
def _thresholdActionGroupTriggered(self, triggeredAction):
"""Threshold action group listener."""
- if triggeredAction.isChecked():
- # Uncheck other actions
- for action in self.thresholdActionGroup.actions():
- if action is not triggeredAction and action.isChecked():
- action.setChecked(False)
- else:
- # Disable min/max edit
- self.minLineEdit.setEnabled(False)
- self.maxLineEdit.setEnabled(False)
- self.applyMaskBtn.setEnabled(False)
+ if triggeredAction is self.belowThresholdAction:
+ self.minLineLabel.setVisible(True)
+ self.maxLineLabel.setVisible(False)
+ self.minLineEdit.setVisible(True)
+ self.maxLineEdit.setVisible(False)
+ self.applyMaskBtn.setText("Mask bellow")
+ elif triggeredAction is self.betweenThresholdAction:
+ self.minLineLabel.setVisible(True)
+ self.maxLineLabel.setVisible(True)
+ self.minLineEdit.setVisible(True)
+ self.maxLineEdit.setVisible(True)
+ self.applyMaskBtn.setText("Mask between")
+ elif triggeredAction is self.aboveThresholdAction:
+ self.minLineLabel.setVisible(False)
+ self.maxLineLabel.setVisible(True)
+ self.minLineEdit.setVisible(False)
+ self.maxLineEdit.setVisible(True)
+ self.applyMaskBtn.setText("Mask above")
+ self.applyMaskBtn.setToolTip(triggeredAction.toolTip())
def _maskBtnClicked(self):
if self.belowThresholdAction.isChecked():
diff --git a/silx/gui/plot/_utils/dtime_ticklayout.py b/silx/gui/plot/_utils/dtime_ticklayout.py
index 95fc235..23c9dce 100644
--- a/silx/gui/plot/_utils/dtime_ticklayout.py
+++ b/silx/gui/plot/_utils/dtime_ticklayout.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2014-2017 European Synchrotron Radiation Facility
+# Copyright (c) 2014-2018 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -32,6 +32,7 @@ __date__ = "04/04/2018"
import datetime as dt
+import enum
import logging
import math
import time
@@ -40,7 +41,6 @@ import dateutil.tz
from dateutil.relativedelta import relativedelta
-from silx.third_party import enum
from .ticklayout import niceNumGeneric
_logger = logging.getLogger(__name__)
diff --git a/silx/gui/plot/actions/control.py b/silx/gui/plot/actions/control.py
index 10df130..2d01ef1 100644
--- a/silx/gui/plot/actions/control.py
+++ b/silx/gui/plot/actions/control.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2004-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2004-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -303,9 +303,12 @@ class CurveStyleAction(PlotAction):
currentState = (self.plot.isDefaultPlotLines(),
self.plot.isDefaultPlotPoints())
- # line only, line and symbol, symbol only
- states = (True, False), (True, True), (False, True)
- newState = states[(states.index(currentState) + 1) % 3]
+ if currentState == (False, False):
+ newState = True, False
+ else:
+ # line only, line and symbol, symbol only
+ states = (True, False), (True, True), (False, True)
+ newState = states[(states.index(currentState) + 1) % 3]
self.plot.setDefaultPlotLines(newState[0])
self.plot.setDefaultPlotPoints(newState[1])
diff --git a/silx/gui/plot/actions/io.py b/silx/gui/plot/actions/io.py
index 97de527..09e4a99 100644
--- a/silx/gui/plot/actions/io.py
+++ b/silx/gui/plot/actions/io.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2004-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2004-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -502,7 +502,7 @@ class SaveAction(PlotAction):
axes_errors=[xerror, yerror],
title=plot.getGraphTitle())
- def setFileFilter(self, dataKind, nameFilter, func):
+ def setFileFilter(self, dataKind, nameFilter, func, index=None):
"""Set a name filter to add/replace a file format support
:param str dataKind:
@@ -513,10 +513,44 @@ class SaveAction(PlotAction):
:param callable func: The function to call to perform saving.
Expected signature is:
bool func(PlotWidget plot, str filename, str nameFilter)
+ :param integer index: Index of the filter in the final list (or None)
"""
assert dataKind in ('all', 'curve', 'curves', 'image', 'scatter')
+ # first append or replace the new filter to prevent colissions
self._filters[dataKind][nameFilter] = func
+ if index is None:
+ # we are already done
+ return
+
+ # get the current ordered list of keys
+ keyList = list(self._filters[dataKind].keys())
+
+ # deal with negative indices
+ if index < 0:
+ index = len(keyList) + index
+ if index < 0:
+ index = 0
+
+ if index >= len(keyList):
+ # nothing to be done, already at the end
+ txt = 'Requested index %d impossible, already at the end' % index
+ _logger.info(txt)
+ return
+
+ # get the new ordered list
+ oldIndex = keyList.index(nameFilter)
+ del keyList[oldIndex]
+ keyList.insert(index, nameFilter)
+
+ # build the new filters
+ newFilters = OrderedDict()
+ for key in keyList:
+ newFilters[key] = self._filters[dataKind][key]
+
+ # and update the filters
+ self._filters[dataKind] = newFilters
+ return
def getFileFilters(self, dataKind):
"""Returns the nameFilter and associated function for a kind of data.
diff --git a/silx/gui/plot/backends/BackendBase.py b/silx/gui/plot/backends/BackendBase.py
index 7fb8be0..0514c85 100644
--- a/silx/gui/plot/backends/BackendBase.py
+++ b/silx/gui/plot/backends/BackendBase.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2004-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2004-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -31,7 +31,7 @@ This API is a simplified version of PyMca PlotBackend API.
__authors__ = ["V.A. Sole", "T. Vincent"]
__license__ = "MIT"
-__date__ = "24/04/2018"
+__date__ = "21/12/2018"
import weakref
from ... import qt
@@ -170,7 +170,8 @@ class BackendBase(object):
"""
return legend
- def addItem(self, x, y, legend, shape, color, fill, overlay, z):
+ def addItem(self, x, y, legend, shape, color, fill, overlay, z,
+ linestyle, linewidth, linebgcolor):
"""Add an item (i.e. a shape) to the plot.
:param numpy.ndarray x: The X coords of the points of the shape
@@ -182,6 +183,19 @@ class BackendBase(object):
:param bool fill: True to fill the shape
:param bool overlay: True if item is an overlay, False otherwise
:param int z: Layer on which to draw the item
+ :param str linestyle: Style of the line.
+ Only relevant for line markers where X or Y is None.
+ Value in:
+
+ - ' ' no line
+ - '-' solid line
+ - '--' dashed line
+ - '-.' dash-dot line
+ - ':' dotted line
+ :param float linewidth: Width of the line.
+ Only relevant for line markers where X or Y is None.
+ :param str linebgcolor: Background color of the line, e.g., 'blue', 'b',
+ '#FF0000'. It is used to draw dotted line using a second color.
:returns: The handle used by the backend to univocally access the item
"""
return legend
@@ -546,3 +560,20 @@ class BackendBase(object):
This only check status set to axes from the public API
"""
return self._axesDisplayed
+
+ def setForegroundColors(self, foregroundColor, gridColor):
+ """Set foreground and grid colors used to display this widget.
+
+ :param List[float] foregroundColor: RGBA foreground color of the widget
+ :param List[float] gridColor: RGBA grid color of the data view
+ """
+ pass
+
+ def setBackgroundColors(self, backgroundColor, dataBackgroundColor):
+ """Set background colors used to display this widget.
+
+ :param List[float] backgroundColor: RGBA background color of the widget
+ :param Union[Tuple[float],None] dataBackgroundColor:
+ RGBA background color of the data view
+ """
+ pass
diff --git a/silx/gui/plot/backends/BackendMatplotlib.py b/silx/gui/plot/backends/BackendMatplotlib.py
index 3b1d6dd..726a839 100644
--- a/silx/gui/plot/backends/BackendMatplotlib.py
+++ b/silx/gui/plot/backends/BackendMatplotlib.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2004-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2004-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -28,7 +28,7 @@ from __future__ import division
__authors__ = ["V.A. Sole", "T. Vincent, H. Payno"]
__license__ = "MIT"
-__date__ = "01/08/2018"
+__date__ = "21/12/2018"
import logging
@@ -56,12 +56,26 @@ from matplotlib.collections import PathCollection, LineCollection
from matplotlib.ticker import Formatter, ScalarFormatter, Locator
-from ....third_party.modest_image import ModestImage
from . import BackendBase
from .._utils import FLOAT32_MINPOS
from .._utils.dtime_ticklayout import calcTicks, bestFormatString, timestamp
+_PATCH_LINESTYLE = {
+ "-": 'solid',
+ "--": 'dashed',
+ '-.': 'dashdot',
+ ':': 'dotted',
+ '': "solid",
+ None: "solid",
+}
+"""Patches do not uses the same matplotlib syntax"""
+
+
+def normalize_linestyle(linestyle):
+ """Normalize known old-style linestyle, else return the provided value."""
+ return _PATCH_LINESTYLE.get(linestyle, linestyle)
+
class NiceDateLocator(Locator):
"""
@@ -115,7 +129,6 @@ class NiceDateLocator(Locator):
return ticks
-
class NiceAutoDateFormatter(Formatter):
"""
Matplotlib FuncFormatter that is linked to a NiceDateLocator and gives the
@@ -139,7 +152,6 @@ class NiceAutoDateFormatter(Formatter):
else:
return bestFormatString(self.locator.spacing, self.locator.unit)
-
def __call__(self, x, pos=None):
"""Return the format for tick val *x* at position *pos*
Expects x to be a POSIX timestamp (seconds since 1 Jan 1970)
@@ -149,8 +161,6 @@ class NiceAutoDateFormatter(Formatter):
return tickStr
-
-
class _MarkerContainer(Container):
"""Marker artists container supporting draw/remove and text position update
@@ -204,6 +214,57 @@ class _MarkerContainer(Container):
self.text.set_x(xmax)
+class _DoubleColoredLinePatch(matplotlib.patches.Patch):
+ """Matplotlib patch to display any patch using double color."""
+
+ def __init__(self, patch):
+ super(_DoubleColoredLinePatch, self).__init__()
+ self.__patch = patch
+ self.linebgcolor = None
+
+ def __getattr__(self, name):
+ return getattr(self.__patch, name)
+
+ def draw(self, renderer):
+ oldLineStype = self.__patch.get_linestyle()
+ if self.linebgcolor is not None and oldLineStype != "solid":
+ oldLineColor = self.__patch.get_edgecolor()
+ oldHatch = self.__patch.get_hatch()
+ self.__patch.set_linestyle("solid")
+ self.__patch.set_edgecolor(self.linebgcolor)
+ self.__patch.set_hatch(None)
+ self.__patch.draw(renderer)
+ self.__patch.set_linestyle(oldLineStype)
+ self.__patch.set_edgecolor(oldLineColor)
+ self.__patch.set_hatch(oldHatch)
+ self.__patch.draw(renderer)
+
+ def set_transform(self, transform):
+ self.__patch.set_transform(transform)
+
+ def get_path(self):
+ return self.__patch.get_path()
+
+ def contains(self, mouseevent, radius=None):
+ return self.__patch.contains(mouseevent, radius)
+
+ def contains_point(self, point, radius=None):
+ return self.__patch.contains_point(point, radius)
+
+
+class Image(AxesImage):
+ """An AxesImage with a fast path for uint8 RGBA images"""
+
+ def set_data(self, A):
+ A = numpy.array(A, copy=False)
+ if A.ndim != 3 or A.shape[2] != 4 or A.dtype != numpy.uint8:
+ super(Image, self).set_data(A)
+ else:
+ # Call AxesImage.set_data with small data to set attributes
+ super(Image, self).set_data(numpy.zeros((2, 2, 4), dtype=A.dtype))
+ self._A = A # Override stored data
+
+
class BackendMatplotlib(BackendBase.BackendBase):
"""Base class for Matplotlib backend without a FigureCanvas.
@@ -231,6 +292,8 @@ class BackendMatplotlib(BackendBase.BackendBase):
self.ax = self.fig.add_axes([.15, .15, .75, .75], label="left")
self.ax2 = self.ax.twinx()
self.ax2.set_label("right")
+ # Make sure background of Axes is displayed
+ self.ax2.patch.set_visible(True)
# disable the use of offsets
try:
@@ -239,9 +302,9 @@ class BackendMatplotlib(BackendBase.BackendBase):
self.ax2.get_yaxis().get_major_formatter().set_useOffset(False)
self.ax2.get_xaxis().get_major_formatter().set_useOffset(False)
except:
- _logger.warning('Cannot disabled axes offsets in %s ' \
+ _logger.warning('Cannot disabled axes offsets in %s '
% matplotlib.__version__)
-
+
# critical for picking!!!!
self.ax2.set_zorder(0)
self.ax2.set_autoscaley_on(True)
@@ -376,44 +439,13 @@ class BackendMatplotlib(BackendBase.BackendBase):
picker = (selectable or draggable)
- # Debian 7 specific support
- # No transparent colormap with matplotlib < 1.2.0
- # Add support for transparent colormap for uint8 data with
- # colormap with 256 colors, linear norm, [0, 255] range
- if self._matplotlibVersion < _parse_version('1.2.0'):
- if (len(data.shape) == 2 and colormap.getName() is None and
- colormap.getColormapLUT() is not None):
- colors = colormap.getColormapLUT()
- if (colors.shape[-1] == 4 and
- not numpy.all(numpy.equal(colors[3], 255))):
- # This is a transparent colormap
- if (colors.shape == (256, 4) and
- colormap.getNormalization() == 'linear' and
- not colormap.isAutoscale() and
- colormap.getVMin() == 0 and
- colormap.getVMax() == 255 and
- data.dtype == numpy.uint8):
- # Supported case, convert data to RGBA
- data = colors[data.reshape(-1)].reshape(
- data.shape + (4,))
- else:
- _logger.warning(
- 'matplotlib %s does not support transparent '
- 'colormap.', matplotlib.__version__)
-
- if ((height * width) > 5.0e5 and
- origin == (0., 0.) and scale == (1., 1.)):
- imageClass = ModestImage
- else:
- imageClass = AxesImage
-
# All image are shown as RGBA image
- image = imageClass(self.ax,
- label="__IMAGE__" + legend,
- interpolation='nearest',
- picker=picker,
- zorder=z,
- origin='lower')
+ image = Image(self.ax,
+ label="__IMAGE__" + legend,
+ interpolation='nearest',
+ picker=picker,
+ zorder=z,
+ origin='lower')
if alpha < 1:
image.set_alpha(alpha)
@@ -438,40 +470,41 @@ class BackendMatplotlib(BackendBase.BackendBase):
ystep = 1 if scale[1] >= 0. else -1
data = data[::ystep, ::xstep]
- if self._matplotlibVersion < _parse_version('2.1'):
- # matplotlib 1.4.2 do not support float128
- dtype = data.dtype
- if dtype.kind == "f" and dtype.itemsize >= 16:
- _logger.warning("Your matplotlib version do not support "
- "float128. Data converted to float64.")
- data = data.astype(numpy.float64)
-
if data.ndim == 2: # Data image, convert to RGBA image
data = colormap.applyToData(data)
image.set_data(data)
-
self.ax.add_artist(image)
-
return image
- def addItem(self, x, y, legend, shape, color, fill, overlay, z):
+ def addItem(self, x, y, legend, shape, color, fill, overlay, z,
+ linestyle, linewidth, linebgcolor):
+ if (linebgcolor is not None and
+ shape not in ('rectangle', 'polygon', 'polylines')):
+ _logger.warning(
+ 'linebgcolor not implemented for %s with matplotlib backend',
+ shape)
xView = numpy.array(x, copy=False)
yView = numpy.array(y, copy=False)
+ linestyle = normalize_linestyle(linestyle)
+
if shape == "line":
item = self.ax.plot(x, y, label=legend, color=color,
- linestyle='-', marker=None)[0]
+ linestyle=linestyle, linewidth=linewidth,
+ marker=None)[0]
elif shape == "hline":
if hasattr(y, "__len__"):
y = y[-1]
- item = self.ax.axhline(y, label=legend, color=color)
+ item = self.ax.axhline(y, label=legend, color=color,
+ linestyle=linestyle, linewidth=linewidth)
elif shape == "vline":
if hasattr(x, "__len__"):
x = x[-1]
- item = self.ax.axvline(x, label=legend, color=color)
+ item = self.ax.axvline(x, label=legend, color=color,
+ linestyle=linestyle, linewidth=linewidth)
elif shape == 'rectangle':
xMin = numpy.nanmin(xView)
@@ -484,10 +517,16 @@ class BackendMatplotlib(BackendBase.BackendBase):
width=w,
height=h,
fill=False,
- color=color)
+ color=color,
+ linestyle=linestyle,
+ linewidth=linewidth)
if fill:
item.set_hatch('.')
+ if linestyle != "solid" and linebgcolor is not None:
+ item = _DoubleColoredLinePatch(item)
+ item.linebgcolor = linebgcolor
+
self.ax.add_patch(item)
elif shape in ('polygon', 'polylines'):
@@ -500,10 +539,16 @@ class BackendMatplotlib(BackendBase.BackendBase):
closed=closed,
fill=False,
label=legend,
- color=color)
+ color=color,
+ linestyle=linestyle,
+ linewidth=linewidth)
if fill and shape == 'polygon':
item.set_hatch('/')
+ if linestyle != "solid" and linebgcolor is not None:
+ item = _DoubleColoredLinePatch(item)
+ item.linebgcolor = linebgcolor
+
self.ax.add_patch(item)
else:
@@ -908,8 +953,56 @@ class BackendMatplotlib(BackendBase.BackendBase):
# remove external margins
self.ax.set_position([0, 0, 1, 1])
self.ax2.set_position([0, 0, 1, 1])
+ self._synchronizeBackgroundColors()
+ self._synchronizeForegroundColors()
self._plot._setDirtyPlot()
+ def _synchronizeBackgroundColors(self):
+ backgroundColor = self._plot.getBackgroundColor().getRgbF()
+
+ dataBackgroundColor = self._plot.getDataBackgroundColor()
+ if dataBackgroundColor.isValid():
+ dataBackgroundColor = dataBackgroundColor.getRgbF()
+ else:
+ dataBackgroundColor = backgroundColor
+
+ if self.ax2.axison:
+ self.fig.patch.set_facecolor(backgroundColor)
+ if self._matplotlibVersion < _parse_version('2'):
+ self.ax2.set_axis_bgcolor(dataBackgroundColor)
+ else:
+ self.ax2.set_facecolor(dataBackgroundColor)
+ else:
+ self.fig.patch.set_facecolor(dataBackgroundColor)
+
+ def _synchronizeForegroundColors(self):
+ foregroundColor = self._plot.getForegroundColor().getRgbF()
+
+ gridColor = self._plot.getGridColor()
+ if gridColor.isValid():
+ gridColor = gridColor.getRgbF()
+ else:
+ gridColor = foregroundColor
+
+ for axes in (self.ax, self.ax2):
+ if axes.axison:
+ axes.spines['bottom'].set_color(foregroundColor)
+ axes.spines['top'].set_color(foregroundColor)
+ axes.spines['right'].set_color(foregroundColor)
+ axes.spines['left'].set_color(foregroundColor)
+ axes.tick_params(axis='x', colors=foregroundColor)
+ axes.tick_params(axis='y', colors=foregroundColor)
+ axes.yaxis.label.set_color(foregroundColor)
+ axes.xaxis.label.set_color(foregroundColor)
+ axes.title.set_color(foregroundColor)
+
+ for line in axes.get_xgridlines():
+ line.set_color(gridColor)
+
+ for line in axes.get_ygridlines():
+ line.set_color(gridColor)
+ # axes.grid().set_markeredgecolor(gridColor)
+
class BackendMatplotlibQt(FigureCanvasQTAgg, BackendMatplotlib):
"""QWidget matplotlib backend using a QtAgg canvas.
@@ -1137,3 +1230,9 @@ class BackendMatplotlibQt(FigureCanvasQTAgg, BackendMatplotlib):
else:
cursor = self._QT_CURSORS[cursor]
FigureCanvasQTAgg.setCursor(self, qt.QCursor(cursor))
+
+ def setBackgroundColors(self, backgroundColor, dataBackgroundColor):
+ self._synchronizeBackgroundColors()
+
+ def setForegroundColors(self, foregroundColor, gridColor):
+ self._synchronizeForegroundColors()
diff --git a/silx/gui/plot/backends/BackendOpenGL.py b/silx/gui/plot/backends/BackendOpenGL.py
index 9e2cb73..e33d03c 100644
--- a/silx/gui/plot/backends/BackendOpenGL.py
+++ b/silx/gui/plot/backends/BackendOpenGL.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2014-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2014-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -28,7 +28,7 @@ from __future__ import division
__authors__ = ["T. Vincent"]
__license__ = "MIT"
-__date__ = "01/08/2018"
+__date__ = "21/12/2018"
from collections import OrderedDict, namedtuple
from ctypes import c_void_p
@@ -44,10 +44,11 @@ from ... import qt
from ..._glutils import gl
from ... import _glutils as glu
from .glutils import (
+ GLLines2D,
GLPlotCurve2D, GLPlotColormap, GLPlotRGBAImage, GLPlotFrame2D,
mat4Ortho, mat4Identity,
LEFT, RIGHT, BOTTOM, TOP,
- Text2D, Shape2D)
+ Text2D, FilledShape2D)
from .glutils.PlotImageFile import saveImageToFile
_logger = logging.getLogger(__name__)
@@ -338,6 +339,9 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
f=f)
BackendBase.BackendBase.__init__(self, plot, parent)
+ self._backgroundColor = 1., 1., 1., 1.
+ self._dataBackgroundColor = 1., 1., 1., 1.
+
self.matScreenProj = mat4Identity()
self._progBase = glu.Program(
@@ -357,6 +361,8 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
self._glGarbageCollector = []
self._plotFrame = GLPlotFrame2D(
+ foregroundColor=(0., 0., 0., 1.),
+ gridColor=(.7, .7, .7, 1.),
margins={'left': 100, 'right': 50, 'top': 50, 'bottom': 50})
# Make postRedisplay asynchronous using Qt signal
@@ -432,7 +438,6 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
def initializeGL(self):
gl.testGL()
- gl.glClearColor(1., 1., 1., 1.)
gl.glClearStencil(0)
gl.glEnable(gl.GL_BLEND)
@@ -482,6 +487,7 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
self._plotFBOs[context] = plotFBOTex
with plotFBOTex:
+ gl.glClearColor(*self._backgroundColor)
gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_STENCIL_BUFFER_BIT)
self._renderPlotAreaGL()
self._plotFrame.render()
@@ -530,6 +536,7 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
item.discard()
self._glGarbageCollector = []
+ gl.glClearColor(*self._backgroundColor)
gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_STENCIL_BUFFER_BIT)
# Check if window is large enough
@@ -543,100 +550,6 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
glu.setGLContextGetter()
_current_context = None
- def _nonOrthoAxesLineMarkerPrimitives(self, marker, pixelOffset):
- """Generates the vertices and label for a line marker.
-
- :param dict marker: Description of a line marker
- :param int pixelOffset: Offset of text from borders in pixels
- :return: Line vertices and Text label or None
- :rtype: 2-tuple (2x2 numpy.array of float, Text2D)
- """
- label, vertices = None, None
-
- xCoord, yCoord = marker['x'], marker['y']
- assert xCoord is None or yCoord is None # Specific to line markers
-
- # Get plot corners in data coords
- plotLeft, plotTop, plotWidth, plotHeight = self.getPlotBoundsInPixels()
-
- corners = [(plotLeft, plotTop),
- (plotLeft, plotTop + plotHeight),
- (plotLeft + plotWidth, plotTop + plotHeight),
- (plotLeft + plotWidth, plotTop)]
- corners = numpy.array([self.pixelToData(x, y, axis='left', check=False)
- for (x, y) in corners])
-
- borders = {
- 'right': (corners[3], corners[2]),
- 'top': (corners[0], corners[3]),
- 'bottom': (corners[2], corners[1]),
- 'left': (corners[1], corners[0])
- }
-
- textLayouts = { # align, valign, offsets
- 'right': (RIGHT, BOTTOM, (-1., -1.)),
- 'top': (LEFT, TOP, (1., 1.)),
- 'bottom': (LEFT, BOTTOM, (1., -1.)),
- 'left': (LEFT, BOTTOM, (1., -1.))
- }
-
- if xCoord is None: # Horizontal line in data space
- if marker['text'] is not None:
- # Find intersection of hline with borders in data
- # Order is important as it stops at first intersection
- for border_name in ('right', 'top', 'bottom', 'left'):
- (x0, y0), (x1, y1) = borders[border_name]
-
- if min(y0, y1) <= yCoord < max(y0, y1):
- xIntersect = (yCoord - y0) * (x1 - x0) / (y1 - y0) + x0
-
- # Add text label
- pixelPos = self.dataToPixel(
- xIntersect, yCoord, axis='left', check=False)
-
- align, valign, offsets = textLayouts[border_name]
-
- x = pixelPos[0] + offsets[0] * pixelOffset
- y = pixelPos[1] + offsets[1] * pixelOffset
- label = Text2D(marker['text'], x, y,
- color=marker['color'],
- bgColor=(1., 1., 1., 0.5),
- align=align, valign=valign)
- break # Stop at first intersection
-
- xMin, xMax = corners[:, 0].min(), corners[:, 0].max()
- vertices = numpy.array(
- ((xMin, yCoord), (xMax, yCoord)), dtype=numpy.float32)
-
- else: # yCoord is None: vertical line in data space
- if marker['text'] is not None:
- # Find intersection of hline with borders in data
- # Order is important as it stops at first intersection
- for border_name in ('top', 'bottom', 'right', 'left'):
- (x0, y0), (x1, y1) = borders[border_name]
- if min(x0, x1) <= xCoord < max(x0, x1):
- yIntersect = (xCoord - x0) * (y1 - y0) / (x1 - x0) + y0
-
- # Add text label
- pixelPos = self.dataToPixel(
- xCoord, yIntersect, axis='left', check=False)
-
- align, valign, offsets = textLayouts[border_name]
-
- x = pixelPos[0] + offsets[0] * pixelOffset
- y = pixelPos[1] + offsets[1] * pixelOffset
- label = Text2D(marker['text'], x, y,
- color=marker['color'],
- bgColor=(1., 1., 1., 0.5),
- align=align, valign=valign)
- break # Stop at first intersection
-
- yMin, yMax = corners[:, 1].min(), corners[:, 1].max()
- vertices = numpy.array(
- ((xCoord, yMin), (xCoord, yMax)), dtype=numpy.float32)
-
- return vertices, label
-
def _renderMarkersGL(self):
if len(self._markers) == 0:
return
@@ -651,16 +564,6 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
gl.glViewport(0, 0, self._plotFrame.size[0], self._plotFrame.size[1])
- # Prepare vertical and horizontal markers rendering
- self._progBase.use()
- gl.glUniformMatrix4fv(
- self._progBase.uniforms['matrix'], 1, gl.GL_TRUE,
- self.matScreenProj.astype(numpy.float32))
- gl.glUniform2i(self._progBase.uniforms['isLog'], False, False)
- gl.glUniform1i(self._progBase.uniforms['hatchStep'], 0)
- gl.glUniform1f(self._progBase.uniforms['tickLen'], 0.)
- posAttrib = self._progBase.attributes['position']
-
labels = []
pixelOffset = 3
@@ -677,59 +580,43 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
continue
if xCoord is None or yCoord is None:
- if not self.isDefaultBaseVectors(): # Non-orthogonal axes
- vertices, label = self._nonOrthoAxesLineMarkerPrimitives(
- marker, pixelOffset)
- if label is not None:
- labels.append(label)
+ pixelPos = self.dataToPixel(
+ xCoord, yCoord, axis='left', check=False)
- else: # Orthogonal axes
- pixelPos = self.dataToPixel(
- xCoord, yCoord, axis='left', check=False)
-
- if xCoord is None: # Horizontal line in data space
- if marker['text'] is not None:
- x = self._plotFrame.size[0] - \
- self._plotFrame.margins.right - pixelOffset
- y = pixelPos[1] - pixelOffset
- label = Text2D(marker['text'], x, y,
- color=marker['color'],
- bgColor=(1., 1., 1., 0.5),
- align=RIGHT, valign=BOTTOM)
- labels.append(label)
-
- width = self._plotFrame.size[0]
- vertices = numpy.array(((0, pixelPos[1]),
- (width, pixelPos[1])),
- dtype=numpy.float32)
-
- else: # yCoord is None: vertical line in data space
- if marker['text'] is not None:
- x = pixelPos[0] + pixelOffset
- y = self._plotFrame.margins.top + pixelOffset
- label = Text2D(marker['text'], x, y,
- color=marker['color'],
- bgColor=(1., 1., 1., 0.5),
- align=LEFT, valign=TOP)
- labels.append(label)
-
- height = self._plotFrame.size[1]
- vertices = numpy.array(((pixelPos[0], 0),
- (pixelPos[0], height)),
- dtype=numpy.float32)
+ if xCoord is None: # Horizontal line in data space
+ if marker['text'] is not None:
+ x = self._plotFrame.size[0] - \
+ self._plotFrame.margins.right - pixelOffset
+ y = pixelPos[1] - pixelOffset
+ label = Text2D(marker['text'], x, y,
+ color=marker['color'],
+ bgColor=(1., 1., 1., 0.5),
+ align=RIGHT, valign=BOTTOM)
+ labels.append(label)
- self._progBase.use()
- gl.glUniform4f(self._progBase.uniforms['color'],
- *marker['color'])
+ width = self._plotFrame.size[0]
+ lines = GLLines2D((0, width), (pixelPos[1], pixelPos[1]),
+ style=marker['linestyle'],
+ color=marker['color'],
+ width=marker['linewidth'])
+ lines.render(self.matScreenProj)
+
+ else: # yCoord is None: vertical line in data space
+ if marker['text'] is not None:
+ x = pixelPos[0] + pixelOffset
+ y = self._plotFrame.margins.top + pixelOffset
+ label = Text2D(marker['text'], x, y,
+ color=marker['color'],
+ bgColor=(1., 1., 1., 0.5),
+ align=LEFT, valign=TOP)
+ labels.append(label)
- gl.glEnableVertexAttribArray(posAttrib)
- gl.glVertexAttribPointer(posAttrib,
- 2,
- gl.GL_FLOAT,
- gl.GL_FALSE,
- 0, vertices)
- gl.glLineWidth(1)
- gl.glDrawArrays(gl.GL_LINES, 0, len(vertices))
+ height = self._plotFrame.size[1]
+ lines = GLLines2D((pixelPos[0], pixelPos[0]), (0, height),
+ style=marker['linestyle'],
+ color=marker['color'],
+ width=marker['linewidth'])
+ lines.render(self.matScreenProj)
else:
pixelPos = self.dataToPixel(
@@ -820,13 +707,17 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
def _renderPlotAreaGL(self):
plotWidth, plotHeight = self.getPlotBoundsInPixels()[2:]
- self._plotFrame.renderGrid()
-
gl.glScissor(self._plotFrame.margins.left,
self._plotFrame.margins.bottom,
plotWidth, plotHeight)
gl.glEnable(gl.GL_SCISSOR_TEST)
+ if self._dataBackgroundColor != self._backgroundColor:
+ gl.glClearColor(*self._dataBackgroundColor)
+ gl.glClear(gl.GL_COLOR_BUFFER_BIT)
+
+ self._plotFrame.renderGrid()
+
# Matrix
trBounds = self._plotFrame.transformedDataRanges
if trBounds.x[0] == trBounds.x[1] or \
@@ -853,32 +744,61 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
# Render Items
gl.glViewport(0, 0, self._plotFrame.size[0], self._plotFrame.size[1])
- self._progBase.use()
- gl.glUniformMatrix4fv(self._progBase.uniforms['matrix'], 1, gl.GL_TRUE,
- self.matScreenProj.astype(numpy.float32))
- gl.glUniform2i(self._progBase.uniforms['isLog'], False, False)
- gl.glUniform1f(self._progBase.uniforms['tickLen'], 0.)
-
for item in self._items.values():
if ((isXLog and numpy.min(item['x']) < FLOAT32_MINPOS) or
(isYLog and numpy.min(item['y']) < FLOAT32_MINPOS)):
# Ignore items <= 0. on log axes
continue
- closed = item['shape'] != 'polylines'
- points = [self.dataToPixel(x, y, axis='left', check=False)
- for (x, y) in zip(item['x'], item['y'])]
- shape2D = Shape2D(points,
- fill=item['fill'],
- fillColor=item['color'],
- stroke=True,
- strokeColor=item['color'],
- strokeClosed=closed)
+ if item['shape'] == 'hline':
+ width = self._plotFrame.size[0]
+ _, yPixel = self.dataToPixel(
+ None, item['y'], axis='left', check=False)
+ points = numpy.array(((0., yPixel), (width, yPixel)),
+ dtype=numpy.float32)
- posAttrib = self._progBase.attributes['position']
- colorUnif = self._progBase.uniforms['color']
- hatchStepUnif = self._progBase.uniforms['hatchStep']
- shape2D.render(posAttrib, colorUnif, hatchStepUnif)
+ elif item['shape'] == 'vline':
+ xPixel, _ = self.dataToPixel(
+ item['x'], None, axis='left', check=False)
+ height = self._plotFrame.size[1]
+ points = numpy.array(((xPixel, 0), (xPixel, height)),
+ dtype=numpy.float32)
+
+ else:
+ points = numpy.array([
+ self.dataToPixel(x, y, axis='left', check=False)
+ for (x, y) in zip(item['x'], item['y'])])
+
+ # Draw the fill
+ if (item['fill'] is not None and
+ item['shape'] not in ('hline', 'vline')):
+ self._progBase.use()
+ gl.glUniformMatrix4fv(
+ self._progBase.uniforms['matrix'], 1, gl.GL_TRUE,
+ self.matScreenProj.astype(numpy.float32))
+ gl.glUniform2i(self._progBase.uniforms['isLog'], False, False)
+ gl.glUniform1f(self._progBase.uniforms['tickLen'], 0.)
+
+ shape2D = FilledShape2D(
+ points, style=item['fill'], color=item['color'])
+ shape2D.render(
+ posAttrib=self._progBase.attributes['position'],
+ colorUnif=self._progBase.uniforms['color'],
+ hatchStepUnif=self._progBase.uniforms['hatchStep'])
+
+ # Draw the stroke
+ if item['linestyle'] not in ('', ' ', None):
+ if item['shape'] != 'polylines':
+ # close the polyline
+ points = numpy.append(points,
+ numpy.atleast_2d(points[0]), axis=0)
+
+ lines = GLLines2D(points[:, 0], points[:, 1],
+ style=item['linestyle'],
+ color=item['color'],
+ dash2ndColor=item['linebgcolor'],
+ width=item['linewidth'])
+ lines.render(self.matScreenProj)
gl.glDisable(gl.GL_SCISSOR_TEST)
@@ -1123,7 +1043,8 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
return legend, 'image'
- def addItem(self, x, y, legend, shape, color, fill, overlay, z):
+ def addItem(self, x, y, legend, shape, color, fill, overlay, z,
+ linestyle, linewidth, linebgcolor):
# TODO handle overlay
if shape not in ('polygon', 'rectangle', 'line',
'vline', 'hline', 'polylines'):
@@ -1154,7 +1075,10 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
'color': colors.rgba(color),
'fill': 'hatch' if fill else None,
'x': x,
- 'y': y
+ 'y': y,
+ 'linestyle': linestyle,
+ 'linewidth': linewidth,
+ 'linebgcolor': linebgcolor,
}
return legend, 'item'
@@ -1166,10 +1090,6 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
if symbol is None:
symbol = '+'
- if linestyle != '-' or linewidth != 1:
- _logger.warning(
- 'OpenGL backend does not support marker line style and width.')
-
behaviors = set()
if selectable:
behaviors.add('selectable')
@@ -1191,6 +1111,8 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
'behaviors': behaviors,
'constraint': constraint if isConstraint else None,
'symbol': symbol,
+ 'linestyle': linestyle,
+ 'linewidth': linewidth,
}
return legend, 'marker'
@@ -1441,37 +1363,6 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
if label:
_logger.warning('Right axis label not implemented')
- # Non orthogonal axes
-
- def setBaseVectors(self, x=(1., 0.), y=(0., 1.)):
- """Set base vectors.
-
- Useful for non-orthogonal axes.
- If an axis is in log scale, skew is applied to log transformed values.
-
- Base vector does not work well with log axes, to investi
- """
- if x != (1., 0.) and y != (0., 1.):
- if self._plotFrame.xAxis.isLog:
- _logger.warning("setBaseVectors disables X axis logarithmic.")
- self.setXAxisLogarithmic(False)
- if self._plotFrame.yAxis.isLog:
- _logger.warning("setBaseVectors disables Y axis logarithmic.")
- self.setYAxisLogarithmic(False)
-
- if self.isKeepDataAspectRatio():
- _logger.warning("setBaseVectors disables keepDataAspectRatio.")
- self.keepDataAspectRatio(False)
-
- self._plotFrame.baseVectors = x, y
-
- def getBaseVectors(self):
- return self._plotFrame.baseVectors
-
- def isDefaultBaseVectors(self):
- return self._plotFrame.baseVectors == \
- self._plotFrame.DEFAULT_BASE_VECTORS
-
# Graph limits
def _setDataRanges(self, xlim=None, ylim=None, y2lim=None):
@@ -1486,26 +1377,6 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
# Update axes range with a clipped range if too wide
self._plotFrame.setDataRanges(xlim, ylim, y2lim)
- if not self.isDefaultBaseVectors():
- # Update axes range with axes bounds in data coords
- plotLeft, plotTop, plotWidth, plotHeight = \
- self.getPlotBoundsInPixels()
-
- self._plotFrame.xAxis.dataRange = sorted([
- self.pixelToData(x, y, axis='left', check=False)[0]
- for (x, y) in ((plotLeft, plotTop + plotHeight),
- (plotLeft + plotWidth, plotTop + plotHeight))])
-
- self._plotFrame.yAxis.dataRange = sorted([
- self.pixelToData(x, y, axis='left', check=False)[1]
- for (x, y) in ((plotLeft, plotTop + plotHeight),
- (plotLeft, plotTop))])
-
- self._plotFrame.y2Axis.dataRange = sorted([
- self.pixelToData(x, y, axis='right', check=False)[1]
- for (x, y) in ((plotLeft + plotWidth, plotTop + plotHeight),
- (plotLeft + plotWidth, plotTop))])
-
def _ensureAspectRatio(self, keepDim=None):
"""Update plot bounds in order to keep aspect ratio.
@@ -1619,11 +1490,6 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
_logger.warning(
"KeepDataAspectRatio is ignored with log axes")
- if flag and not self.isDefaultBaseVectors():
- _logger.warning(
- "setXAxisLogarithmic ignored because baseVectors are set")
- return
-
self._plotFrame.xAxis.isLog = flag
def setYAxisLogarithmic(self, flag):
@@ -1633,11 +1499,6 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
_logger.warning(
"KeepDataAspectRatio is ignored with log axes")
- if flag and not self.isDefaultBaseVectors():
- _logger.warning(
- "setYAxisLogarithmic ignored because baseVectors are set")
- return
-
self._plotFrame.yAxis.isLog = flag
self._plotFrame.y2Axis.isLog = flag
@@ -1658,9 +1519,6 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
if flag and (self._plotFrame.xAxis.isLog or
self._plotFrame.yAxis.isLog):
_logger.warning("KeepDataAspectRatio is ignored with log axes")
- if flag and not self.isDefaultBaseVectors():
- _logger.warning(
- "keepDataAspectRatio ignored because baseVectors are set")
self._keepDataAspectRatio = flag
@@ -1723,3 +1581,11 @@ class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
def setAxesDisplayed(self, displayed):
BackendBase.BackendBase.setAxesDisplayed(self, displayed)
self._plotFrame.displayed = displayed
+
+ def setForegroundColors(self, foregroundColor, gridColor):
+ self._plotFrame.foregroundColor = foregroundColor
+ self._plotFrame.gridColor = gridColor
+
+ def setBackgroundColors(self, backgroundColor, dataBackgroundColor):
+ self._backgroundColor = backgroundColor
+ self._dataBackgroundColor = dataBackgroundColor
diff --git a/silx/gui/plot/backends/glutils/GLPlotCurve.py b/silx/gui/plot/backends/glutils/GLPlotCurve.py
index 12b6bbe..5f8d652 100644
--- a/silx/gui/plot/backends/glutils/GLPlotCurve.py
+++ b/silx/gui/plot/backends/glutils/GLPlotCurve.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2014-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2014-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -42,7 +42,7 @@ import numpy
from silx.math.combo import min_max
from ...._glutils import gl
-from ...._glutils import Program, vertexBuffer
+from ...._glutils import Program, vertexBuffer, VertexBufferAttrib
from .GLSupport import buildFillMaskIndices, mat4Identity, mat4Translate
@@ -245,7 +245,7 @@ class _Fill2D(object):
SOLID, DASHED, DASHDOT, DOTTED = '-', '--', '-.', ':'
-class _Lines2D(object):
+class GLLines2D(object):
"""Object rendering curve as a polyline
:param xVboData: X coordinates VBO
@@ -323,6 +323,7 @@ class _Lines2D(object):
/* Dashes: [0, x], [y, z]
Dash period: w */
uniform vec4 dash;
+ uniform vec4 dash2ndColor;
varying float vDist;
varying vec4 vColor;
@@ -330,25 +331,52 @@ class _Lines2D(object):
void main(void) {
float dist = mod(vDist, dash.w);
if ((dist > dash.x && dist < dash.y) || dist > dash.z) {
- discard;
+ if (dash2ndColor.a == 0.) {
+ discard; // Discard full transparent bg color
+ } else {
+ gl_FragColor = dash2ndColor;
+ }
+ } else {
+ gl_FragColor = vColor;
}
- gl_FragColor = vColor;
}
""",
attrib0='xPos')
def __init__(self, xVboData=None, yVboData=None,
colorVboData=None, distVboData=None,
- style=SOLID, color=(0., 0., 0., 1.),
- width=1, dashPeriod=20, drawMode=None,
+ style=SOLID, color=(0., 0., 0., 1.), dash2ndColor=None,
+ width=1, dashPeriod=10., drawMode=None,
offset=(0., 0.)):
+ if (xVboData is not None and
+ not isinstance(xVboData, VertexBufferAttrib)):
+ xVboData = numpy.array(xVboData, copy=False, dtype=numpy.float32)
self.xVboData = xVboData
+
+ if (yVboData is not None and
+ not isinstance(yVboData, VertexBufferAttrib)):
+ yVboData = numpy.array(yVboData, copy=False, dtype=numpy.float32)
self.yVboData = yVboData
+
+ # Compute distances if not given while providing numpy array coordinates
+ if (isinstance(self.xVboData, numpy.ndarray) and
+ isinstance(self.yVboData, numpy.ndarray) and
+ distVboData is None):
+ distVboData = distancesFromArrays(self.xVboData, self.yVboData)
+
+ if (distVboData is not None and
+ not isinstance(distVboData, VertexBufferAttrib)):
+ distVboData = numpy.array(
+ distVboData, copy=False, dtype=numpy.float32)
self.distVboData = distVboData
+
+ if colorVboData is not None:
+ assert isinstance(colorVboData, VertexBufferAttrib)
self.colorVboData = colorVboData
self.useColorVboData = colorVboData is not None
self.color = color
+ self.dash2ndColor = dash2ndColor
self.width = width
self._style = None
self.style = style
@@ -396,29 +424,46 @@ class _Lines2D(object):
gl.glUniform2f(program.uniforms['halfViewportSize'],
0.5 * viewWidth, 0.5 * viewHeight)
+ dashPeriod = self.dashPeriod * self.width
if self.style == DOTTED:
- dash = (0.1 * self.dashPeriod,
- 0.6 * self.dashPeriod,
- 0.7 * self.dashPeriod,
- self.dashPeriod)
+ dash = (0.2 * dashPeriod,
+ 0.5 * dashPeriod,
+ 0.7 * dashPeriod,
+ dashPeriod)
elif self.style == DASHDOT:
- dash = (0.3 * self.dashPeriod,
- 0.5 * self.dashPeriod,
- 0.6 * self.dashPeriod,
- self.dashPeriod)
+ dash = (0.3 * dashPeriod,
+ 0.5 * dashPeriod,
+ 0.6 * dashPeriod,
+ dashPeriod)
else:
- dash = (0.5 * self.dashPeriod,
- self.dashPeriod,
- self.dashPeriod,
- self.dashPeriod)
+ dash = (0.5 * dashPeriod,
+ dashPeriod,
+ dashPeriod,
+ dashPeriod)
gl.glUniform4f(program.uniforms['dash'], *dash)
+ if self.dash2ndColor is None:
+ # Use fully transparent color which gets discarded in shader
+ dash2ndColor = (0., 0., 0., 0.)
+ else:
+ dash2ndColor = self.dash2ndColor
+ gl.glUniform4f(program.uniforms['dash2ndColor'], *dash2ndColor)
+
distAttrib = program.attributes['distance']
gl.glEnableVertexAttribArray(distAttrib)
- self.distVboData.setVertexAttrib(distAttrib)
+ if isinstance(self.distVboData, VertexBufferAttrib):
+ self.distVboData.setVertexAttrib(distAttrib)
+ else:
+ gl.glVertexAttribPointer(distAttrib,
+ 1,
+ gl.GL_FLOAT,
+ False,
+ 0,
+ self.distVboData)
- gl.glEnable(gl.GL_LINE_SMOOTH)
+ if self.width != 1:
+ gl.glEnable(gl.GL_LINE_SMOOTH)
matrix = numpy.dot(matrix,
mat4Translate(*self.offset)).astype(numpy.float32)
@@ -435,11 +480,27 @@ class _Lines2D(object):
xPosAttrib = program.attributes['xPos']
gl.glEnableVertexAttribArray(xPosAttrib)
- self.xVboData.setVertexAttrib(xPosAttrib)
+ if isinstance(self.xVboData, VertexBufferAttrib):
+ self.xVboData.setVertexAttrib(xPosAttrib)
+ else:
+ gl.glVertexAttribPointer(xPosAttrib,
+ 1,
+ gl.GL_FLOAT,
+ False,
+ 0,
+ self.xVboData)
yPosAttrib = program.attributes['yPos']
gl.glEnableVertexAttribArray(yPosAttrib)
- self.yVboData.setVertexAttrib(yPosAttrib)
+ if isinstance(self.yVboData, VertexBufferAttrib):
+ self.yVboData.setVertexAttrib(yPosAttrib)
+ else:
+ gl.glVertexAttribPointer(yPosAttrib,
+ 1,
+ gl.GL_FLOAT,
+ False,
+ 0,
+ self.yVboData)
gl.glLineWidth(self.width)
gl.glDrawArrays(self._drawMode, 0, self.xVboData.size)
@@ -447,7 +508,7 @@ class _Lines2D(object):
gl.glDisable(gl.GL_LINE_SMOOTH)
-def _distancesFromArrays(xData, yData):
+def distancesFromArrays(xData, yData):
"""Returns distances between each points
:param numpy.ndarray xData: X coordinate of points
@@ -711,7 +772,7 @@ class _ErrorBars(object):
This is using its own VBO as opposed to fill/points/lines.
There is no picking on error bars.
- It uses 2 vertices per error bars and uses :class:`_Lines2D` to
+ It uses 2 vertices per error bars and uses :class:`GLLines2D` to
render error bars and :class:`_Points2D` to render the ends.
:param numpy.ndarray xData: X coordinates of the data.
@@ -753,7 +814,7 @@ class _ErrorBars(object):
self._xData, self._yData = None, None
self._xError, self._yError = None, None
- self._lines = _Lines2D(
+ self._lines = GLLines2D(
None, None, color=color, drawMode=gl.GL_LINES, offset=offset)
self._xErrPoints = _Points2D(
None, None, color=color, marker=V_LINE, offset=offset)
@@ -957,7 +1018,7 @@ class GLPlotCurve2D(object):
self.xMin, self.yMin,
offset=self.offset)
- self.lines = _Lines2D()
+ self.lines = GLLines2D()
self.lines.style = lineStyle
self.lines.color = lineColor
self.lines.width = lineWidth
@@ -999,7 +1060,7 @@ class GLPlotCurve2D(object):
@classmethod
def init(cls):
"""OpenGL context initialization"""
- _Lines2D.init()
+ GLLines2D.init()
_Points2D.init()
def prepare(self):
@@ -1007,7 +1068,7 @@ class GLPlotCurve2D(object):
if self.xVboData is None:
xAttrib, yAttrib, cAttrib, dAttrib = None, None, None, None
if self.lineStyle in (DASHED, DASHDOT, DOTTED):
- dists = _distancesFromArrays(self.xData, self.yData)
+ dists = distancesFromArrays(self.xData, self.yData)
if self.colorData is None:
xAttrib, yAttrib, dAttrib = vertexBuffer(
(self.xData, self.yData, dists))
diff --git a/silx/gui/plot/backends/glutils/GLPlotFrame.py b/silx/gui/plot/backends/glutils/GLPlotFrame.py
index 4ad1547..43f6e10 100644
--- a/silx/gui/plot/backends/glutils/GLPlotFrame.py
+++ b/silx/gui/plot/backends/glutils/GLPlotFrame.py
@@ -63,6 +63,7 @@ class PlotAxis(object):
def __init__(self, plot,
tickLength=(0., 0.),
+ foregroundColor=(0., 0., 0., 1.0),
labelAlign=CENTER, labelVAlign=CENTER,
titleAlign=CENTER, titleVAlign=CENTER,
titleRotate=0, titleOffset=(0., 0.)):
@@ -78,6 +79,7 @@ class PlotAxis(object):
self._title = ''
self._tickLength = tickLength
+ self._foregroundColor = foregroundColor
self._labelAlign = labelAlign
self._labelVAlign = labelVAlign
self._titleAlign = titleAlign
@@ -169,6 +171,20 @@ class PlotAxis(object):
plot._dirty()
@property
+ def foregroundColor(self):
+ """Color used for frame and labels"""
+ return self._foregroundColor
+
+ @foregroundColor.setter
+ def foregroundColor(self, color):
+ """Color used for frame and labels"""
+ assert len(color) == 4, \
+ "foregroundColor must have length 4, got {}".format(len(self._foregroundColor))
+ if self._foregroundColor != color:
+ self._foregroundColor = color
+ self._dirtyTicks()
+
+ @property
def ticks(self):
"""Ticks as tuples: ((x, y) in display, dataPos, textLabel)."""
if self._ticks is None:
@@ -192,6 +208,7 @@ class PlotAxis(object):
tickScale = 1.
label = Text2D(text=text,
+ color=self._foregroundColor,
x=xPixel - xTickLength,
y=yPixel - yTickLength,
align=self._labelAlign,
@@ -223,6 +240,7 @@ class PlotAxis(object):
# yOffset -= 3 * yTickLength
axisTitle = Text2D(text=self.title,
+ color=self._foregroundColor,
x=xAxisCenter + xOffset,
y=yAxisCenter + yOffset,
align=self._titleAlign,
@@ -373,15 +391,21 @@ class GLPlotFrame(object):
# Margins used when plot frame is not displayed
_NoDisplayMargins = _Margins(0, 0, 0, 0)
- def __init__(self, margins):
+ def __init__(self, margins, foregroundColor, gridColor):
"""
:param margins: The margins around plot area for axis and labels.
:type margins: dict with 'left', 'right', 'top', 'bottom' keys and
values as ints.
+ :param foregroundColor: color used for the frame and labels.
+ :type foregroundColor: tuple with RGBA values ranging from 0.0 to 1.0
+ :param gridColor: color used for grid lines.
+ :type gridColor: tuple RGBA with RGBA values ranging from 0.0 to 1.0
"""
self._renderResources = None
self._margins = self._Margins(**margins)
+ self._foregroundColor = foregroundColor
+ self._gridColor = gridColor
self.axes = [] # List of PlotAxis to be updated by subclasses
@@ -401,6 +425,36 @@ class GLPlotFrame(object):
GRID_ALL_TICKS = (GRID_MAIN_TICKS + GRID_SUB_TICKS)
@property
+ def foregroundColor(self):
+ """Color used for frame and labels"""
+ return self._foregroundColor
+
+ @foregroundColor.setter
+ def foregroundColor(self, color):
+ """Color used for frame and labels"""
+ assert len(color) == 4, \
+ "foregroundColor must have length 4, got {}".format(len(self._foregroundColor))
+ if self._foregroundColor != color:
+ self._foregroundColor = color
+ for axis in self.axes:
+ axis.foregroundColor = color
+ self._dirty()
+
+ @property
+ def gridColor(self):
+ """Color used for frame and labels"""
+ return self._gridColor
+
+ @gridColor.setter
+ def gridColor(self, color):
+ """Color used for frame and labels"""
+ assert len(color) == 4, \
+ "gridColor must have length 4, got {}".format(len(self._gridColor))
+ if self._gridColor != color:
+ self._gridColor = color
+ self._dirty()
+
+ @property
def displayed(self):
"""Whether axes and their labels are displayed or not (bool)"""
return self._displayed
@@ -522,6 +576,7 @@ class GLPlotFrame(object):
self.margins.right) // 2
yTitle = self.margins.top - self._TICK_LENGTH_IN_PIXELS
labels.append(Text2D(text=self.title,
+ color=self._foregroundColor,
x=xTitle,
y=yTitle,
align=CENTER,
@@ -556,7 +611,7 @@ class GLPlotFrame(object):
gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE,
matProj.astype(numpy.float32))
- gl.glUniform4f(prog.uniforms['color'], 0., 0., 0., 1.)
+ gl.glUniform4f(prog.uniforms['color'], *self._foregroundColor)
gl.glUniform1f(prog.uniforms['tickFactor'], 0.)
gl.glEnableVertexAttribArray(prog.attributes['position'])
@@ -590,7 +645,7 @@ class GLPlotFrame(object):
gl.glLineWidth(self._LINE_WIDTH)
gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE,
matProj.astype(numpy.float32))
- gl.glUniform4f(prog.uniforms['color'], 0.7, 0.7, 0.7, 1.)
+ gl.glUniform4f(prog.uniforms['color'], *self._gridColor)
gl.glUniform1f(prog.uniforms['tickFactor'], 0.) # 1/2.) # 1/tickLen
gl.glEnableVertexAttribArray(prog.attributes['position'])
@@ -606,15 +661,21 @@ class GLPlotFrame(object):
# GLPlotFrame2D ###############################################################
class GLPlotFrame2D(GLPlotFrame):
- def __init__(self, margins):
+ def __init__(self, margins, foregroundColor, gridColor):
"""
:param margins: The margins around plot area for axis and labels.
:type margins: dict with 'left', 'right', 'top', 'bottom' keys and
values as ints.
+ :param foregroundColor: color used for the frame and labels.
+ :type foregroundColor: tuple with RGBA values ranging from 0.0 to 1.0
+ :param gridColor: color used for grid lines.
+ :type gridColor: tuple RGBA with RGBA values ranging from 0.0 to 1.0
+
"""
- super(GLPlotFrame2D, self).__init__(margins)
+ super(GLPlotFrame2D, self).__init__(margins, foregroundColor, gridColor)
self.axes.append(PlotAxis(self,
tickLength=(0., -5.),
+ foregroundColor=self._foregroundColor,
labelAlign=CENTER, labelVAlign=TOP,
titleAlign=CENTER, titleVAlign=TOP,
titleRotate=0,
@@ -624,6 +685,7 @@ class GLPlotFrame2D(GLPlotFrame):
self.axes.append(PlotAxis(self,
tickLength=(5., 0.),
+ foregroundColor=self._foregroundColor,
labelAlign=RIGHT, labelVAlign=CENTER,
titleAlign=CENTER, titleVAlign=BOTTOM,
titleRotate=ROTATE_270,
@@ -632,6 +694,7 @@ class GLPlotFrame2D(GLPlotFrame):
self._y2Axis = PlotAxis(self,
tickLength=(-5., 0.),
+ foregroundColor=self._foregroundColor,
labelAlign=LEFT, labelVAlign=CENTER,
titleAlign=CENTER, titleVAlign=TOP,
titleRotate=ROTATE_270,
@@ -825,23 +888,6 @@ class GLPlotFrame2D(GLPlotFrame):
_logger.info('yMax: warning log10(%f)', y2Max)
y2Max = 0.
- # Non-orthogonal axes
- if self.baseVectors != self.DEFAULT_BASE_VECTORS:
- (xx, xy), (yx, yy) = self.baseVectors
- skew_mat = numpy.array(((xx, yx), (xy, yy)))
-
- corners = [(xMin, yMin), (xMin, yMax),
- (xMax, yMin), (xMax, yMax),
- (xMin, y2Min), (xMin, y2Max),
- (xMax, y2Min), (xMax, y2Max)]
-
- corners = numpy.array(
- [numpy.dot(skew_mat, corner) for corner in corners],
- dtype=numpy.float32)
- xMin, xMax = corners[:, 0].min(), corners[:, 0].max()
- yMin, yMax = corners[0:4, 1].min(), corners[0:4, 1].max()
- y2Min, y2Max = corners[4:, 1].min(), corners[4:, 1].max()
-
self._transformedDataRanges = self._DataRanges(
(xMin, xMax), (yMin, yMax), (y2Min, y2Max))
@@ -861,16 +907,6 @@ class GLPlotFrame2D(GLPlotFrame):
mat = mat4Ortho(xMin, xMax, yMax, yMin, 1, -1)
else:
mat = mat4Ortho(xMin, xMax, yMin, yMax, 1, -1)
-
- # Non-orthogonal axes
- if self.baseVectors != self.DEFAULT_BASE_VECTORS:
- (xx, xy), (yx, yy) = self.baseVectors
- mat = numpy.dot(mat, numpy.array((
- (xx, yx, 0., 0.),
- (xy, yy, 0., 0.),
- (0., 0., 1., 0.),
- (0., 0., 0., 1.)), dtype=numpy.float64))
-
self._transformedDataProjMat = mat
return self._transformedDataProjMat
@@ -890,16 +926,6 @@ class GLPlotFrame2D(GLPlotFrame):
mat = mat4Ortho(xMin, xMax, y2Max, y2Min, 1, -1)
else:
mat = mat4Ortho(xMin, xMax, y2Min, y2Max, 1, -1)
-
- # Non-orthogonal axes
- if self.baseVectors != self.DEFAULT_BASE_VECTORS:
- (xx, xy), (yx, yy) = self.baseVectors
- mat = numpy.dot(mat, numpy.matrix((
- (xx, yx, 0., 0.),
- (xy, yy, 0., 0.),
- (0., 0., 1., 0.),
- (0., 0., 0., 1.)), dtype=numpy.float64))
-
self._transformedDataY2ProjMat = mat
return self._transformedDataY2ProjMat
@@ -1114,3 +1140,17 @@ class GLPlotFrame2D(GLPlotFrame):
vertices = numpy.append(vertices, extraVertices, axis=0)
self._renderResources = (vertices, gridVertices, labels)
+
+ @property
+ def foregroundColor(self):
+ """Color used for frame and labels"""
+ return self._foregroundColor
+
+ @foregroundColor.setter
+ def foregroundColor(self, color):
+ """Color used for frame and labels"""
+ assert len(color) == 4, \
+ "foregroundColor must have length 4, got {}".format(len(self._foregroundColor))
+ if self._foregroundColor != color:
+ self._y2Axis.foregroundColor = color
+ GLPlotFrame.foregroundColor.fset(self, color) # call parent property
diff --git a/silx/gui/plot/backends/glutils/GLSupport.py b/silx/gui/plot/backends/glutils/GLSupport.py
index 18c5eb7..da6dffa 100644
--- a/silx/gui/plot/backends/glutils/GLSupport.py
+++ b/silx/gui/plot/backends/glutils/GLSupport.py
@@ -60,16 +60,12 @@ def buildFillMaskIndices(nIndices, dtype=None):
return indices
-class Shape2D(object):
+class FilledShape2D(object):
_NO_HATCH = 0
_HATCH_STEP = 20
- def __init__(self, points, fill='solid', stroke=True,
- fillColor=(0., 0., 0., 1.), strokeColor=(0., 0., 0., 1.),
- strokeClosed=True):
+ def __init__(self, points, style='solid', color=(0., 0., 0., 1.)):
self.vertices = numpy.array(points, dtype=numpy.float32, copy=False)
- self.strokeClosed = strokeClosed
-
self._indices = buildFillMaskIndices(len(self.vertices))
tVertex = numpy.transpose(self.vertices)
@@ -81,28 +77,16 @@ class Shape2D(object):
self._xMin, self._xMax = xMin, xMax
self._yMin, self._yMax = yMin, yMax
- self.fill = fill
- self.fillColor = fillColor
- self.stroke = stroke
- self.strokeColor = strokeColor
-
- @property
- def xMin(self):
- return self._xMin
-
- @property
- def xMax(self):
- return self._xMax
-
- @property
- def yMin(self):
- return self._yMin
+ self.style = style
+ self.color = color
- @property
- def yMax(self):
- return self._yMax
+ def render(self, posAttrib, colorUnif, hatchStepUnif):
+ assert self.style in ('hatch', 'solid')
+ gl.glUniform4f(colorUnif, *self.color)
+ step = self._HATCH_STEP if self.style == 'hatch' else self._NO_HATCH
+ gl.glUniform1i(hatchStepUnif, step)
- def prepareFillMask(self, posAttrib):
+ # Prepare fill mask
gl.glEnableVertexAttribArray(posAttrib)
gl.glVertexAttribPointer(posAttrib,
2,
@@ -126,9 +110,6 @@ class Shape2D(object):
gl.glColorMask(gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE)
gl.glDepthMask(gl.GL_TRUE)
- def renderFill(self, posAttrib):
- self.prepareFillMask(posAttrib)
-
gl.glVertexAttribPointer(posAttrib,
2,
gl.GL_FLOAT,
@@ -138,30 +119,6 @@ class Shape2D(object):
gl.glDisable(gl.GL_STENCIL_TEST)
- def renderStroke(self, posAttrib):
- gl.glEnableVertexAttribArray(posAttrib)
- gl.glVertexAttribPointer(posAttrib,
- 2,
- gl.GL_FLOAT,
- gl.GL_FALSE,
- 0, self.vertices)
- gl.glLineWidth(1)
- drawMode = gl.GL_LINE_LOOP if self.strokeClosed else gl.GL_LINE_STRIP
- gl.glDrawArrays(drawMode, 0, len(self.vertices))
-
- def render(self, posAttrib, colorUnif, hatchStepUnif):
- assert self.fill in ['hatch', 'solid', None]
- if self.fill is not None:
- gl.glUniform4f(colorUnif, *self.fillColor)
- step = self._HATCH_STEP if self.fill == 'hatch' else self._NO_HATCH
- gl.glUniform1i(hatchStepUnif, step)
- self.renderFill(posAttrib)
-
- if self.stroke:
- gl.glUniform4f(colorUnif, *self.strokeColor)
- gl.glUniform1i(hatchStepUnif, self._NO_HATCH)
- self.renderStroke(posAttrib)
-
# matrix ######################################################################
diff --git a/silx/gui/plot/items/__init__.py b/silx/gui/plot/items/__init__.py
index e7957ac..f829f78 100644
--- a/silx/gui/plot/items/__init__.py
+++ b/silx/gui/plot/items/__init__.py
@@ -36,7 +36,7 @@ from .core import (Item, LabelsMixIn, DraggableMixIn, ColormapMixIn, # noqa
SymbolMixIn, ColorMixIn, YAxisMixIn, FillMixIn, # noqa
AlphaMixIn, LineMixIn, ItemChangedType) # noqa
from .complex import ImageComplexData # noqa
-from .curve import Curve # noqa
+from .curve import Curve, CurveStyle # noqa
from .histogram import Histogram # noqa
from .image import ImageBase, ImageData, ImageRgba, MaskImageData # noqa
from .shape import Shape # noqa
diff --git a/silx/gui/plot/items/axis.py b/silx/gui/plot/items/axis.py
index 3d9fe14..8ea5c7a 100644
--- a/silx/gui/plot/items/axis.py
+++ b/silx/gui/plot/items/axis.py
@@ -27,16 +27,16 @@
__authors__ = ["V. Valls"]
__license__ = "MIT"
-__date__ = "06/12/2017"
+__date__ = "22/11/2018"
import datetime as dt
+import enum
import logging
import dateutil.tz
from ... import qt
-from silx.third_party import enum
_logger = logging.getLogger(__name__)
@@ -448,6 +448,8 @@ class YAxis(Axis):
False for Y axis going from bottom to top
"""
flag = bool(flag)
+ if self.isInverted() == flag:
+ return
self._getBackend().setYAxisInverted(flag)
self._getPlot()._setDirtyPlot()
self.sigInvertedChanged.emit(flag)
diff --git a/silx/gui/plot/items/complex.py b/silx/gui/plot/items/complex.py
index 535b0a9..7fffd77 100644
--- a/silx/gui/plot/items/complex.py
+++ b/silx/gui/plot/items/complex.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -33,9 +33,9 @@ __date__ = "14/06/2018"
import logging
-import numpy
+import enum
-from silx.third_party import enum
+import numpy
from ...colors import Colormap
from .core import ColormapMixIn, ItemChangedType
@@ -137,7 +137,6 @@ class ImageComplexData(ImageBase, ColormapMixIn):
name='hsv',
vmin=-numpy.pi,
vmax=numpy.pi)
- phaseColormap.setEditable(False)
self._colormaps = { # Default colormaps for all modes
self.Mode.ABSOLUTE: colormap,
@@ -180,7 +179,6 @@ class ImageComplexData(ImageBase, ColormapMixIn):
colormap=colormap,
alpha=self.getAlpha())
-
def setVisualizationMode(self, mode):
"""Set the visualization mode to use.
diff --git a/silx/gui/plot/items/core.py b/silx/gui/plot/items/core.py
index e000751..bf3b719 100644
--- a/silx/gui/plot/items/core.py
+++ b/silx/gui/plot/items/core.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -27,20 +27,23 @@
__authors__ = ["T. Vincent"]
__license__ = "MIT"
-__date__ = "14/06/2018"
+__date__ = "29/01/2019"
import collections
from copy import deepcopy
import logging
+import enum
import warnings
import weakref
+
import numpy
-from silx.third_party import six, enum
+import six
from ... import qt
from ... import colors
from ...colors import Colormap
+from silx import config
_logger = logging.getLogger(__name__)
@@ -82,6 +85,9 @@ class ItemChangedType(enum.Enum):
COLOR = 'colorChanged'
"""Item's color changed flag."""
+ LINE_BG_COLOR = 'lineBgColorChanged'
+ """Item's line background color changed flag."""
+
YAXIS = 'yAxisChanged'
"""Item's Y axis binding changed flag."""
@@ -411,10 +417,12 @@ class ColormapMixIn(ItemMixInBase):
return self._colormap
def setColormap(self, colormap):
- """Set the colormap of this image
+ """Set the colormap of this item
:param silx.gui.colors.Colormap colormap: colormap description
"""
+ if self._colormap is colormap:
+ return
if isinstance(colormap, dict):
colormap = Colormap._fromDict(colormap)
@@ -433,10 +441,10 @@ class ColormapMixIn(ItemMixInBase):
class SymbolMixIn(ItemMixInBase):
"""Mix-in class for items with symbol type"""
- _DEFAULT_SYMBOL = ''
+ _DEFAULT_SYMBOL = None
"""Default marker of the item"""
- _DEFAULT_SYMBOL_SIZE = 6.0
+ _DEFAULT_SYMBOL_SIZE = config.DEFAULT_PLOT_SYMBOL_SIZE
"""Default marker size of the item"""
_SUPPORTED_SYMBOLS = collections.OrderedDict((
@@ -451,8 +459,15 @@ class SymbolMixIn(ItemMixInBase):
"""Dict of supported symbols"""
def __init__(self):
- self._symbol = self._DEFAULT_SYMBOL
- self._symbol_size = self._DEFAULT_SYMBOL_SIZE
+ if self._DEFAULT_SYMBOL is None: # Use default from config
+ self._symbol = config.DEFAULT_PLOT_SYMBOL
+ else:
+ self._symbol = self._DEFAULT_SYMBOL
+
+ if self._DEFAULT_SYMBOL_SIZE is None: # Use default from config
+ self._symbol_size = config.DEFAULT_PLOT_SYMBOL_SIZE
+ else:
+ self._symbol_size = self._DEFAULT_SYMBOL_SIZE
@classmethod
def getSupportedSymbols(cls):
@@ -892,14 +907,14 @@ class Points(Item, SymbolMixIn, AlphaMixIn):
# use the getData class method because instance method can be
# overloaded to return additional arrays
data = Points.getData(self, copy=False,
- displayed=True)
+ displayed=True)
if len(data) == 5:
# hack to avoid duplicating caching mechanism in Scatter
# (happens when cached data is used, caching done using
# Scatter._logFilterData)
- x, y, xerror, yerror = data[0], data[1], data[3], data[4]
+ x, y, _xerror, _yerror = data[0], data[1], data[3], data[4]
else:
- x, y, xerror, yerror = data
+ x, y, _xerror, _yerror = data
self._boundsCache[(xPositive, yPositive)] = (
numpy.nanmin(x),
diff --git a/silx/gui/plot/items/curve.py b/silx/gui/plot/items/curve.py
index 80d9dea..79def55 100644
--- a/silx/gui/plot/items/curve.py
+++ b/silx/gui/plot/items/curve.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -31,9 +31,10 @@ __date__ = "24/04/2018"
import logging
+
import numpy
+import six
-from silx.third_party import six
from ....utils.deprecation import deprecated
from ... import colors
from .core import (Points, LabelsMixIn, ColorMixIn, YAxisMixIn,
diff --git a/silx/gui/plot/items/histogram.py b/silx/gui/plot/items/histogram.py
index 389e8a6..a1d6586 100644
--- a/silx/gui/plot/items/histogram.py
+++ b/silx/gui/plot/items/histogram.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2017 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -197,13 +197,13 @@ class Histogram(Item, AlphaMixIn, ColorMixIn, FillMixIn,
values[clipped_values] = numpy.nan
- if xPositive or yPositive:
+ if yPositive:
return (numpy.nanmin(edges),
numpy.nanmax(edges),
numpy.nanmin(values),
numpy.nanmax(values))
- else: # No log scale, include 0 in bounds
+ else: # No log scale on y axis, include 0 in bounds
return (numpy.nanmin(edges),
numpy.nanmax(edges),
min(0, numpy.nanmin(values)),
diff --git a/silx/gui/plot/items/roi.py b/silx/gui/plot/items/roi.py
index f55ef91..0169439 100644
--- a/silx/gui/plot/items/roi.py
+++ b/silx/gui/plot/items/roi.py
@@ -65,7 +65,7 @@ class RegionOfInterest(qt.QObject):
# Avoid circular dependancy
from ..tools import roi as roi_tools
assert parent is None or isinstance(parent, roi_tools.RegionOfInterestManager)
- super(RegionOfInterest, self).__init__(parent)
+ qt.QObject.__init__(self, parent)
self._color = rgba('red')
self._items = WeakList()
self._editAnchors = WeakList()
@@ -108,7 +108,7 @@ class RegionOfInterest(qt.QObject):
return qt.QColor.fromRgbF(*self._color)
def _getAnchorColor(self, color):
- """Returns the anchor color from the base ROI color
+ """Returns the anchor color from the base ROI color
:param Union[numpy.array,Tuple,List]: color
:rtype: Union[numpy.array,Tuple,List]
@@ -209,7 +209,7 @@ class RegionOfInterest(qt.QObject):
def setFirstShapePoints(self, points):
""""Initialize the ROI using the points from the first interaction.
- This interaction is constains by the plot API and only supports few
+ This interaction is constrained by the plot API and only supports few
shapes.
"""
points = self._createControlPointsFromFirstShape(points)
@@ -410,6 +410,13 @@ class RegionOfInterest(qt.QObject):
plot._remove(item)
self._labelItem = None
+ def _updated(self, event=None, checkVisibility=True):
+ """Implement Item mix-in update method by updating the plot items
+
+ See :class:`~silx.gui.plot.items.Item._updated`
+ """
+ self._createPlotItems()
+
def __str__(self):
"""Returns parameters of the ROI as a string."""
points = self._getControlPoints()
@@ -417,7 +424,7 @@ class RegionOfInterest(qt.QObject):
return "%s(%s)" % (self.__class__.__name__, params)
-class PointROI(RegionOfInterest):
+class PointROI(RegionOfInterest, items.SymbolMixIn):
"""A ROI identifying a point in a 2D plot."""
_kind = "Point"
@@ -426,6 +433,10 @@ class PointROI(RegionOfInterest):
_plotShape = "point"
"""Plot shape which is used for the first interaction"""
+ def __init__(self, parent=None):
+ items.SymbolMixIn.__init__(self)
+ RegionOfInterest.__init__(self, parent=parent)
+
def getPosition(self):
"""Returns the position of this ROI
@@ -458,6 +469,8 @@ class PointROI(RegionOfInterest):
marker.setPosition(points[0][0], points[0][1])
marker.setText(self.getLabel())
marker.setColor(rgba(self.getColor()))
+ marker.setSymbol(self.getSymbol())
+ marker.setSymbolSize(self.getSymbolSize())
marker._setDraggable(False)
return [marker]
@@ -466,6 +479,8 @@ class PointROI(RegionOfInterest):
marker.setPosition(points[0][0], points[0][1])
marker.setText(self.getLabel())
marker._setDraggable(self.isEditable())
+ marker.setSymbol(self.getSymbol())
+ marker.setSymbolSize(self.getSymbolSize())
return [marker]
def __str__(self):
@@ -474,7 +489,7 @@ class PointROI(RegionOfInterest):
return "%s(%s)" % (self.__class__.__name__, params)
-class LineROI(RegionOfInterest):
+class LineROI(RegionOfInterest, items.LineMixIn):
"""A ROI identifying a line in a 2D plot.
This ROI provides 1 anchor for each boundary of the line, plus an center
@@ -487,6 +502,10 @@ class LineROI(RegionOfInterest):
_plotShape = "line"
"""Plot shape which is used for the first interaction"""
+ def __init__(self, parent=None):
+ items.LineMixIn.__init__(self)
+ RegionOfInterest.__init__(self, parent=parent)
+
def _createControlPointsFromFirstShape(self, points):
center = numpy.mean(points, axis=0)
controlPoints = numpy.array([points[0], points[1], center])
@@ -535,6 +554,8 @@ class LineROI(RegionOfInterest):
item.setColor(rgba(self.getColor()))
item.setFill(False)
item.setOverlay(True)
+ item.setLineStyle(self.getLineStyle())
+ item.setLineWidth(self.getLineWidth())
return [item]
def _createAnchorItems(self, points):
@@ -582,7 +603,7 @@ class LineROI(RegionOfInterest):
return "%s(%s)" % (self.__class__.__name__, params)
-class HorizontalLineROI(RegionOfInterest):
+class HorizontalLineROI(RegionOfInterest, items.LineMixIn):
"""A ROI identifying an horizontal line in a 2D plot."""
_kind = "HLine"
@@ -591,6 +612,10 @@ class HorizontalLineROI(RegionOfInterest):
_plotShape = "hline"
"""Plot shape which is used for the first interaction"""
+ def __init__(self, parent=None):
+ items.LineMixIn.__init__(self)
+ RegionOfInterest.__init__(self, parent=parent)
+
def _createControlPointsFromFirstShape(self, points):
points = numpy.array([(float('nan'), points[0, 1])],
dtype=numpy.float64)
@@ -636,6 +661,8 @@ class HorizontalLineROI(RegionOfInterest):
marker.setText(self.getLabel())
marker.setColor(rgba(self.getColor()))
marker._setDraggable(False)
+ marker.setLineWidth(self.getLineWidth())
+ marker.setLineStyle(self.getLineStyle())
return [marker]
def _createAnchorItems(self, points):
@@ -643,6 +670,8 @@ class HorizontalLineROI(RegionOfInterest):
marker.setPosition(points[0][0], points[0][1])
marker.setText(self.getLabel())
marker._setDraggable(self.isEditable())
+ marker.setLineWidth(self.getLineWidth())
+ marker.setLineStyle(self.getLineStyle())
return [marker]
def __str__(self):
@@ -651,7 +680,7 @@ class HorizontalLineROI(RegionOfInterest):
return "%s(%s)" % (self.__class__.__name__, params)
-class VerticalLineROI(RegionOfInterest):
+class VerticalLineROI(RegionOfInterest, items.LineMixIn):
"""A ROI identifying a vertical line in a 2D plot."""
_kind = "VLine"
@@ -660,6 +689,10 @@ class VerticalLineROI(RegionOfInterest):
_plotShape = "vline"
"""Plot shape which is used for the first interaction"""
+ def __init__(self, parent=None):
+ items.LineMixIn.__init__(self)
+ RegionOfInterest.__init__(self, parent=parent)
+
def _createControlPointsFromFirstShape(self, points):
points = numpy.array([(points[0, 0], float('nan'))],
dtype=numpy.float64)
@@ -705,6 +738,8 @@ class VerticalLineROI(RegionOfInterest):
marker.setText(self.getLabel())
marker.setColor(rgba(self.getColor()))
marker._setDraggable(False)
+ marker.setLineWidth(self.getLineWidth())
+ marker.setLineStyle(self.getLineStyle())
return [marker]
def _createAnchorItems(self, points):
@@ -712,6 +747,8 @@ class VerticalLineROI(RegionOfInterest):
marker.setPosition(points[0][0], points[0][1])
marker.setText(self.getLabel())
marker._setDraggable(self.isEditable())
+ marker.setLineWidth(self.getLineWidth())
+ marker.setLineStyle(self.getLineStyle())
return [marker]
def __str__(self):
@@ -720,7 +757,7 @@ class VerticalLineROI(RegionOfInterest):
return "%s(%s)" % (self.__class__.__name__, params)
-class RectangleROI(RegionOfInterest):
+class RectangleROI(RegionOfInterest, items.LineMixIn):
"""A ROI identifying a rectangle in a 2D plot.
This ROI provides 1 anchor for each corner, plus an anchor in the
@@ -733,6 +770,10 @@ class RectangleROI(RegionOfInterest):
_plotShape = "rectangle"
"""Plot shape which is used for the first interaction"""
+ def __init__(self, parent=None):
+ items.LineMixIn.__init__(self)
+ RegionOfInterest.__init__(self, parent=parent)
+
def _createControlPointsFromFirstShape(self, points):
point0 = points[0]
point1 = points[1]
@@ -838,6 +879,8 @@ class RectangleROI(RegionOfInterest):
item.setColor(rgba(self.getColor()))
item.setFill(False)
item.setOverlay(True)
+ item.setLineStyle(self.getLineStyle())
+ item.setLineWidth(self.getLineWidth())
return [item]
def _createAnchorItems(self, points):
@@ -894,7 +937,7 @@ class RectangleROI(RegionOfInterest):
return "%s(%s)" % (self.__class__.__name__, params)
-class PolygonROI(RegionOfInterest):
+class PolygonROI(RegionOfInterest, items.LineMixIn):
"""A ROI identifying a closed polygon in a 2D plot.
This ROI provides 1 anchor for each point of the polygon.
@@ -906,6 +949,10 @@ class PolygonROI(RegionOfInterest):
_plotShape = "polygon"
"""Plot shape which is used for the first interaction"""
+ def __init__(self, parent=None):
+ items.LineMixIn.__init__(self)
+ RegionOfInterest.__init__(self, parent=parent)
+
def getPoints(self):
"""Returns the list of the points of this polygon.
@@ -948,6 +995,8 @@ class PolygonROI(RegionOfInterest):
item.setColor(rgba(self.getColor()))
item.setFill(False)
item.setOverlay(True)
+ item.setLineStyle(self.getLineStyle())
+ item.setLineWidth(self.getLineWidth())
return [item]
def _createAnchorItems(self, points):
@@ -967,7 +1016,7 @@ class PolygonROI(RegionOfInterest):
return "%s(%s)" % (self.__class__.__name__, params)
-class ArcROI(RegionOfInterest):
+class ArcROI(RegionOfInterest, items.LineMixIn):
"""A ROI identifying an arc of a circle with a width.
This ROI provides 3 anchors to control the curvature, 1 anchor to control
@@ -986,6 +1035,7 @@ class ArcROI(RegionOfInterest):
'startAngle', 'endAngle'])
def __init__(self, parent=None):
+ items.LineMixIn.__init__(self)
RegionOfInterest.__init__(self, parent=parent)
self._geometry = None
@@ -1357,6 +1407,8 @@ class ArcROI(RegionOfInterest):
item.setColor(rgba(self.getColor()))
item.setFill(False)
item.setOverlay(True)
+ item.setLineStyle(self.getLineStyle())
+ item.setLineWidth(self.getLineWidth())
return [item]
def _createAnchorItems(self, points):
diff --git a/silx/gui/plot/items/scatter.py b/silx/gui/plot/items/scatter.py
index acc74b4..707dd3d 100644
--- a/silx/gui/plot/items/scatter.py
+++ b/silx/gui/plot/items/scatter.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -46,9 +46,6 @@ class Scatter(Points, ColormapMixIn):
_DEFAULT_SELECTABLE = True
"""Default selectable state for scatter plots"""
- _DEFAULT_SYMBOL = 'o'
- """Default symbol of the scatter plots"""
-
def __init__(self):
Points.__init__(self)
ColormapMixIn.__init__(self)
diff --git a/silx/gui/plot/items/shape.py b/silx/gui/plot/items/shape.py
index 65b26a1..9fc1306 100644
--- a/silx/gui/plot/items/shape.py
+++ b/silx/gui/plot/items/shape.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2017 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -27,14 +27,16 @@
__authors__ = ["T. Vincent"]
__license__ = "MIT"
-__date__ = "17/05/2017"
+__date__ = "21/12/2018"
import logging
import numpy
+import six
-from .core import (Item, ColorMixIn, FillMixIn, ItemChangedType)
+from ... import colors
+from .core import Item, ColorMixIn, FillMixIn, ItemChangedType, LineMixIn
_logger = logging.getLogger(__name__)
@@ -42,7 +44,7 @@ _logger = logging.getLogger(__name__)
# TODO probably make one class for each kind of shape
# TODO check fill:polygon/polyline + fill = duplicated
-class Shape(Item, ColorMixIn, FillMixIn):
+class Shape(Item, ColorMixIn, FillMixIn, LineMixIn):
"""Description of a shape item
:param str type_: The type of shape in:
@@ -53,10 +55,12 @@ class Shape(Item, ColorMixIn, FillMixIn):
Item.__init__(self)
ColorMixIn.__init__(self)
FillMixIn.__init__(self)
+ LineMixIn.__init__(self)
self._overlay = False
assert type_ in ('hline', 'polygon', 'rectangle', 'vline', 'polylines')
self._type = type_
self._points = ()
+ self._lineBgColor = None
self._handle = None
@@ -71,7 +75,10 @@ class Shape(Item, ColorMixIn, FillMixIn):
color=self.getColor(),
fill=self.isFill(),
overlay=self.isOverlay(),
- z=self.getZValue())
+ z=self.getZValue(),
+ linestyle=self.getLineStyle(),
+ linewidth=self.getLineWidth(),
+ linebgcolor=self.getLineBgColor())
def isOverlay(self):
"""Return true if shape is drawn as an overlay
@@ -119,3 +126,31 @@ class Shape(Item, ColorMixIn, FillMixIn):
"""
self._points = numpy.array(points, copy=copy)
self._updated(ItemChangedType.DATA)
+
+ def getLineBgColor(self):
+ """Returns the RGBA color of the item
+ :rtype: 4-tuple of float in [0, 1] or array of colors
+ """
+ return self._lineBgColor
+
+ def setLineBgColor(self, color, copy=True):
+ """Set item color
+ :param color: color(s) to be used
+ :type color: str ("#RRGGBB") or (npoints, 4) unsigned byte array or
+ one of the predefined color names defined in colors.py
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ """
+ if color is not None:
+ if isinstance(color, six.string_types):
+ color = colors.rgba(color)
+ else:
+ color = numpy.array(color, copy=copy)
+ # TODO more checks + improve color array support
+ if color.ndim == 1: # Single RGBA color
+ color = colors.rgba(color)
+ else: # Array of colors
+ assert color.ndim == 2
+
+ self._lineBgColor = color
+ self._updated(ItemChangedType.LINE_BG_COLOR)
diff --git a/silx/gui/plot/matplotlib/Colormap.py b/silx/gui/plot/matplotlib/Colormap.py
index 772a473..38f3b55 100644
--- a/silx/gui/plot/matplotlib/Colormap.py
+++ b/silx/gui/plot/matplotlib/Colormap.py
@@ -29,7 +29,13 @@ from matplotlib.colors import ListedColormap
import matplotlib.colors
import matplotlib.cm
import silx.resources
-from silx.utils.deprecation import deprecated
+from silx.utils.deprecation import deprecated, deprecated_warning
+
+
+deprecated_warning(type_='module',
+ name=__file__,
+ replacement='silx.gui.colors.Colormap',
+ since_version='0.10.0')
_logger = logging.getLogger(__name__)
@@ -46,25 +52,30 @@ _CMAPS = {}
@property
+@deprecated(since_version='0.10.0')
def magma():
return getColormap('magma')
@property
+@deprecated(since_version='0.10.0')
def inferno():
return getColormap('inferno')
@property
+@deprecated(since_version='0.10.0')
def plasma():
return getColormap('plasma')
@property
+@deprecated(since_version='0.10.0')
def viridis():
return getColormap('viridis')
+@deprecated(since_version='0.10.0')
def getColormap(name):
"""Returns matplotlib colormap corresponding to given name
@@ -143,6 +154,7 @@ def getColormap(name):
return matplotlib.cm.get_cmap(name)
+@deprecated(since_version='0.10.0')
def getScalarMappable(colormap, data=None):
"""Returns matplotlib ScalarMappable corresponding to colormap
@@ -223,6 +235,8 @@ def applyColormapToData(data, colormap):
return rgbaImage
+@deprecated(replacement='silx.colors.Colormap.getSupportedColormaps',
+ since_version='0.10.0')
def getSupportedColormaps():
"""Get the supported colormap names as a tuple of str.
"""
diff --git a/silx/gui/plot/stats/stats.py b/silx/gui/plot/stats/stats.py
index a753989..ad61536 100644
--- a/silx/gui/plot/stats/stats.py
+++ b/silx/gui/plot/stats/stats.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -30,15 +30,15 @@ __license__ = "MIT"
__date__ = "06/06/2018"
-import numpy
-from silx.gui.plot.items.curve import Curve as CurveItem
-from silx.gui.plot.items.image import ImageBase as ImageItem
-from silx.gui.plot.items.scatter import Scatter as ScatterItem
-from silx.gui.plot.items.histogram import Histogram as HistogramItem
-from silx.math.combo import min_max
from collections import OrderedDict
import logging
+import numpy
+
+from .. import items
+from ....math.combo import min_max
+
+
logger = logging.getLogger(__name__)
@@ -62,7 +62,7 @@ class Stats(OrderedDict):
def calculate(self, item, plot, onlimits):
"""
- Call all :class:`Stat` object registred and return the result of the
+ Call all :class:`Stat` object registered and return the result of the
computation.
:param item: the item for which we want statistics
@@ -72,17 +72,29 @@ class Stats(OrderedDict):
:return dict: dictionary with :class:`Stat` name as ket and result
of the calculation as value
"""
- res = {}
- if isinstance(item, CurveItem):
+ context = None
+ # Check for PlotWidget items
+ if isinstance(item, items.Curve):
context = _CurveContext(item, plot, onlimits)
- elif isinstance(item, ImageItem):
+ elif isinstance(item, items.ImageData):
context = _ImageContext(item, plot, onlimits)
- elif isinstance(item, ScatterItem):
+ elif isinstance(item, items.Scatter):
context = _ScatterContext(item, plot, onlimits)
- elif isinstance(item, HistogramItem):
+ elif isinstance(item, items.Histogram):
context = _HistogramContext(item, plot, onlimits)
else:
- raise ValueError('Item type not managed')
+ # Check for SceneWidget items
+ from ...plot3d import items as items3d # Lazy import
+
+ if isinstance(item, (items3d.Scatter2D, items3d.Scatter3D)):
+ context = _plot3DScatterContext(item, plot, onlimits)
+ elif isinstance(item, (items3d.ImageData, items3d.ScalarField3D)):
+ context = _plot3DArrayContext(item, plot, onlimits)
+
+ if context is None:
+ raise ValueError('Item type not managed')
+
+ res = {}
for statName, stat in list(self.items()):
if context.kind not in stat.compatibleKinds:
logger.debug('kind %s not managed by statistic %s'
@@ -124,12 +136,54 @@ class _StatsContext(object):
self.min = None
self.max = None
self.data = None
+
self.values = None
+ """The array of data"""
+
+ self.axes = None
+ """A list of array of position on each axis.
+
+ If the signal is an array,
+ then each axis has the length of that dimension,
+ and the order is (z, y, x) (i.e., as the array shape).
+ If the signal is not an array,
+ then each axis has the same length as the signal,
+ and the order is (x, y, z).
+ """
+
self.createContext(item, plot, onlimits)
def createContext(self, item, plot, onlimits):
raise NotImplementedError("Base class")
+ def isStructuredData(self):
+ """Returns True if data as an array-like structure.
+
+ :rtype: bool
+ """
+ if self.values is None or self.axes is None:
+ return False
+
+ if numpy.prod([len(axis) for axis in self.axes]) == self.values.size:
+ return True
+ else:
+ # Make sure there is the right number of value in axes
+ for axis in self.axes:
+ assert len(axis) == self.values.size
+ return False
+
+ def isScalarData(self):
+ """Returns True if data is a scalar.
+
+ :rtype: bool
+ """
+ if self.values is None or self.axes is None:
+ return False
+ if self.isStructuredData():
+ return len(self.axes) == self.values.ndim
+ else:
+ return self.values.ndim == 1
+
class _CurveContext(_StatsContext):
"""
@@ -149,8 +203,9 @@ class _CurveContext(_StatsContext):
if onlimits:
minX, maxX = plot.getXAxis().getLimits()
- yData = yData[(minX <= xData) & (xData <= maxX)]
- xData = xData[(minX <= xData) & (xData <= maxX)]
+ mask = (minX <= xData) & (xData <= maxX)
+ yData = yData[mask]
+ xData = xData[mask]
self.xData = xData
self.yData = yData
@@ -160,11 +215,12 @@ class _CurveContext(_StatsContext):
self.min, self.max = None, None
self.data = (xData, yData)
self.values = yData
+ self.axes = (xData,)
class _HistogramContext(_StatsContext):
"""
- StatsContext for :class:`Curve`
+ StatsContext for :class:`Histogram`
:param item: the item for which we want to compute the context
:param plot: the plot containing the item
@@ -176,12 +232,13 @@ class _HistogramContext(_StatsContext):
plot=plot, onlimits=onlimits)
def createContext(self, item, plot, onlimits):
- xData, edges = item.getData(copy=True)[0:2]
- yData = item._revertComputeEdges(x=edges, histogramType=item.getAlignment())
+ yData, edges = item.getData(copy=True)[0:2]
+ xData = item._revertComputeEdges(x=edges, histogramType=item.getAlignment())
if onlimits:
minX, maxX = plot.getXAxis().getLimits()
- yData = yData[(minX <= xData) & (xData <= maxX)]
- xData = xData[(minX <= xData) & (xData <= maxX)]
+ mask = (minX <= xData) & (xData <= maxX)
+ yData = yData[mask]
+ xData = xData[mask]
self.xData = xData
self.yData = yData
@@ -191,11 +248,13 @@ class _HistogramContext(_StatsContext):
self.min, self.max = None, None
self.data = (xData, yData)
self.values = yData
+ self.axes = (xData,)
class _ScatterContext(_StatsContext):
- """
- StatsContext for :class:`Scatter`
+ """StatsContext scatter plots.
+
+ It supports :class:`~silx.gui.plot.items.Scatter`.
:param item: the item for which we want to compute the context
:param plot: the plot containing the item
@@ -207,11 +266,14 @@ class _ScatterContext(_StatsContext):
onlimits=onlimits)
def createContext(self, item, plot, onlimits):
- xData, yData, valueData, xerror, yerror = item.getData(copy=True)
- assert plot
+ valueData = item.getValueData(copy=True)
+ xData = item.getXData(copy=True)
+ yData = item.getYData(copy=True)
+
if onlimits:
minX, maxX = plot.getXAxis().getLimits()
minY, maxY = plot.getYAxis().getLimits()
+
# filter on X axis
valueData = valueData[(minX <= xData) & (xData <= maxX)]
yData = yData[(minX <= xData) & (xData <= maxX)]
@@ -220,17 +282,20 @@ class _ScatterContext(_StatsContext):
valueData = valueData[(minY <= yData) & (yData <= maxY)]
xData = xData[(minY <= yData) & (yData <= maxY)]
yData = yData[(minY <= yData) & (yData <= maxY)]
+
if len(valueData) > 0:
self.min, self.max = min_max(valueData)
else:
self.min, self.max = None, None
self.data = (xData, yData, valueData)
self.values = valueData
+ self.axes = (xData, yData)
class _ImageContext(_StatsContext):
- """
- StatsContext for :class:`ImageBase`
+ """StatsContext for images.
+
+ It supports :class:`~silx.gui.plot.items.ImageData`.
:param item: the item for which we want to compute the context
:param plot: the plot containing the item
@@ -244,7 +309,8 @@ class _ImageContext(_StatsContext):
def createContext(self, item, plot, onlimits):
self.origin = item.getOrigin()
self.scale = item.getScale()
- self.data = item.getData()
+
+ self.data = item.getData(copy=True)
if onlimits:
minX, maxX = plot.getXAxis().getLimits()
@@ -259,25 +325,88 @@ class _ImageContext(_StatsContext):
YMinBound = max(YMinBound, 0)
if XMaxBound <= XMinBound or YMaxBound <= YMinBound:
- return self.noDataSelected()
- data = item.getData()
- self.data = data[YMinBound:YMaxBound + 1, XMinBound:XMaxBound + 1]
- else:
- self.data = item.getData()
-
+ self.data = None
+ else:
+ self.data = self.data[YMinBound:YMaxBound + 1,
+ XMinBound:XMaxBound + 1]
if self.data.size > 0:
self.min, self.max = min_max(self.data)
else:
self.min, self.max = None, None
self.values = self.data
+ if self.values is not None:
+ self.axes = (self.origin[1] + self.scale[1] * numpy.arange(self.data.shape[0]),
+ self.origin[0] + self.scale[0] * numpy.arange(self.data.shape[1]))
+
+
+class _plot3DScatterContext(_StatsContext):
+ """StatsContext for 3D scatter plots.
+
+ It supports :class:`~silx.gui.plot3d.items.Scatter2D` and
+ :class:`~silx.gui.plot3d.items.Scatter3D`.
+
+ :param item: the item for which we want to compute the context
+ :param plot: the plot containing the item
+ :param bool onlimits: True if we want to apply statistic only on
+ visible data.
+ """
+ def __init__(self, item, plot, onlimits):
+ _StatsContext.__init__(self, kind='scatter', item=item, plot=plot,
+ onlimits=onlimits)
+
+ def createContext(self, item, plot, onlimits):
+ if onlimits:
+ raise RuntimeError("Unsupported plot %s" % str(plot))
+
+ values = item.getValueData(copy=False)
+
+ if values is not None and len(values) > 0:
+ self.values = values
+ axes = [item.getXData(copy=False), item.getYData(copy=False)]
+ if self.values.ndim == 3:
+ axes.append(item.getZData(copy=False))
+ self.axes = tuple(axes)
+
+ self.min, self.max = min_max(self.values)
+ else:
+ self.values = None
+ self.axes = None
+ self.min, self.max = None, None
+
+
+class _plot3DArrayContext(_StatsContext):
+ """StatsContext for 3D scalar field and data image.
+
+ It supports :class:`~silx.gui.plot3d.items.ScalarField3D` and
+ :class:`~silx.gui.plot3d.items.ImageData`.
+
+ :param item: the item for which we want to compute the context
+ :param plot: the plot containing the item
+ :param bool onlimits: True if we want to apply statistic only on
+ visible data.
+ """
+ def __init__(self, item, plot, onlimits):
+ _StatsContext.__init__(self, kind='image', item=item, plot=plot,
+ onlimits=onlimits)
+
+ def createContext(self, item, plot, onlimits):
+ if onlimits:
+ raise RuntimeError("Unsupported plot %s" % str(plot))
+
+ values = item.getData(copy=False)
+
+ if values is not None and len(values) > 0:
+ self.values = values
+ self.axes = tuple([numpy.arange(size) for size in self.values.shape])
+ self.min, self.max = min_max(self.values)
+ else:
+ self.values = None
+ self.axes = None
+ self.min, self.max = None, None
+
-BASIC_COMPATIBLE_KINDS = {
- 'curve': CurveItem,
- 'image': ImageItem,
- 'scatter': ScatterItem,
- 'histogram': HistogramItem,
-}
+BASIC_COMPATIBLE_KINDS = 'curve', 'image', 'scatter', 'histogram'
class StatBase(object):
@@ -285,9 +414,8 @@ class StatBase(object):
Base class for defining a statistic.
:param str name: the name of the statistic. Must be unique.
- :param compatibleKinds: the kind of items (curve, scatter...) for which
- the statistic apply.
- :rtype: List or tuple
+ :param List[str] compatibleKinds:
+ The kind of items (curve, scatter...) for which the statistic apply.
"""
def __init__(self, name, compatibleKinds=BASIC_COMPATIBLE_KINDS, description=None):
self.name = name
@@ -298,7 +426,7 @@ class StatBase(object):
"""
compute the statistic for the given :class:`StatsContext`
- :param context:
+ :param _StatsContext context:
:return dict: key is stat name, statistic computed is the dict value
"""
raise NotImplementedError('Base class')
@@ -307,7 +435,7 @@ class StatBase(object):
"""
If necessary add a tooltip for a stat kind
- :param str kinf: the kind of item the statistic is compute for.
+ :param str kind: the kind of item the statistic is compute for.
:return: tooltip or None if no tooltip
"""
return None
@@ -329,17 +457,18 @@ class Stat(StatBase):
self._fct = fct
def calculate(self, context):
- if context.kind in self.compatibleKinds:
- return self._fct(context.values)
+ if context.values is not None:
+ if context.kind in self.compatibleKinds:
+ return self._fct(context.values)
+ else:
+ raise ValueError('Kind %s not managed by %s'
+ '' % (context.kind, self.name))
else:
- raise ValueError('Kind %s not managed by %s'
- '' % (context.kind, self.name))
+ return None
class StatMin(StatBase):
- """
- Compute the minimal value on data
- """
+ """Compute the minimal value on data"""
def __init__(self):
StatBase.__init__(self, name='min')
@@ -348,9 +477,7 @@ class StatMin(StatBase):
class StatMax(StatBase):
- """
- Compute the maximal value on data
- """
+ """Compute the maximal value on data"""
def __init__(self):
StatBase.__init__(self, name='max')
@@ -359,9 +486,7 @@ class StatMax(StatBase):
class StatDelta(StatBase):
- """
- Compute the delta between minimal and maximal on data
- """
+ """Compute the delta between minimal and maximal on data"""
def __init__(self):
StatBase.__init__(self, name='delta')
@@ -369,123 +494,84 @@ class StatDelta(StatBase):
return context.max - context.min
-class StatCoordMin(StatBase):
- """
- Compute the first coordinates of the data minimal value
- """
+class _StatCoord(StatBase):
+ """Base class for argmin and argmax stats"""
+
+ def _indexToCoordinates(self, context, index):
+ """Returns the coordinates of data point at given index
+
+ If data is an array, coordinates are in reverse order from data shape.
+
+ :param _StatsContext context:
+ :param int index: Index in the flattened data array
+ :rtype: List[int]
+ """
+ if context.isStructuredData():
+ coordinates = []
+ for axis in reversed(context.axes):
+ coordinates.append(axis[index % len(axis)])
+ index = index // len(axis)
+ return tuple(coordinates)
+ else:
+ return tuple(axis[index] for axis in context.axes)
+
+
+class StatCoordMin(_StatCoord):
+ """Compute the coordinates of the first minimum value of the data"""
def __init__(self):
- StatBase.__init__(self, name='coords min')
+ _StatCoord.__init__(self, name='coords min')
def calculate(self, context):
- if context.kind in ('curve', 'histogram'):
- return context.xData[numpy.argmin(context.yData)]
- elif context.kind == 'scatter':
- xData, yData, valueData = context.data
- return (xData[numpy.argmin(valueData)],
- yData[numpy.argmin(valueData)])
- elif context.kind == 'image':
- scaleX, scaleY = context.scale
- originX, originY = context.origin
- index1D = numpy.argmin(context.data)
- ySize = (context.data.shape[1])
- x = index1D % context.data.shape[1]
- y = (index1D - x) / ySize
- x = x * scaleX + originX
- y = y * scaleY + originY
- return (x, y)
- else:
- raise ValueError('kind not managed')
+ if context.values is None or not context.isScalarData():
+ return None
+
+ index = numpy.argmin(context.values)
+ return self._indexToCoordinates(context, index)
def getToolTip(self, kind):
- if kind in ('scatter', 'image'):
- return '(x, y)'
- else:
- return None
+ return "Coordinates of the first minimum value of the data"
-class StatCoordMax(StatBase):
- """
- Compute the first coordinates of the data minimal value
- """
+
+class StatCoordMax(_StatCoord):
+ """Compute the coordinates of the first maximum value of the data"""
def __init__(self):
- StatBase.__init__(self, name='coords max')
+ _StatCoord.__init__(self, name='coords max')
def calculate(self, context):
- if context.kind in ('curve', 'histogram'):
- return context.xData[numpy.argmax(context.yData)]
- elif context.kind == 'scatter':
- xData, yData, valueData = context.data
- return (xData[numpy.argmax(valueData)],
- yData[numpy.argmax(valueData)])
- elif context.kind == 'image':
- scaleX, scaleY = context.scale
- originX, originY = context.origin
- index1D = numpy.argmax(context.data)
- ySize = (context.data.shape[1])
- x = index1D % context.data.shape[1]
- y = (index1D - x) / ySize
- x = x * scaleX + originX
- y = y * scaleY + originY
- return (x, y)
- else:
- raise ValueError('kind not managed')
+ if context.values is None or not context.isScalarData():
+ return None
+
+ index = numpy.argmax(context.values)
+ return self._indexToCoordinates(context, index)
def getToolTip(self, kind):
- if kind in ('scatter', 'image'):
- return '(x, y)'
- else:
- return None
+ return "Coordinates of the first maximum value of the data"
+
class StatCOM(StatBase):
- """
- Compute data center of mass
- """
+ """Compute data center of mass"""
def __init__(self):
StatBase.__init__(self, name='COM', description='Center of mass')
def calculate(self, context):
- if context.kind in ('curve', 'histogram'):
- xData, yData = context.data
- deno = numpy.sum(yData).astype(numpy.float32)
- if deno == 0.:
- return numpy.nan
- else:
- return numpy.sum(xData * yData).astype(numpy.float32) / deno
- elif context.kind == 'scatter':
- xData, yData, values = context.data
- deno = numpy.sum(values).astype(numpy.float32)
- if deno == 0.:
- return numpy.nan, numpy.nan
- else:
- xcom = numpy.sum(xData * values).astype(numpy.float32) / deno
- ycom = numpy.sum(yData * values).astype(numpy.float32) / deno
- return (xcom, ycom)
- elif context.kind == 'image':
- yData = numpy.sum(context.data, axis=1)
- xData = numpy.sum(context.data, axis=0)
- dataXRange = range(context.data.shape[1])
- dataYRange = range(context.data.shape[0])
- xScale, yScale = context.scale
- xOrigin, yOrigin = context.origin
-
- denoY = numpy.sum(yData)
- if denoY == 0.:
- ycom = numpy.nan
- else:
- ycom = numpy.sum(yData * dataYRange) / denoY
- ycom = ycom * yScale + yOrigin
+ if context.values is None or not context.isScalarData():
+ return None
- denoX = numpy.sum(xData)
- if denoX == 0.:
- xcom = numpy.nan
- else:
- xcom = numpy.sum(xData * dataXRange) / denoX
- xcom = xcom * xScale + xOrigin
- return (xcom, ycom)
+ values = numpy.array(context.values, dtype=numpy.float64)
+ sum_ = numpy.sum(values)
+ if sum_ == 0.:
+ return (numpy.nan,) * len(context.axes)
+
+ if context.isStructuredData():
+ centerofmass = []
+ for index, axis in enumerate(context.axes):
+ axes = tuple([i for i in range(len(context.axes)) if i != index])
+ centerofmass.append(
+ numpy.sum(axis * numpy.sum(values, axis=axes)) / sum_)
+ return tuple(reversed(centerofmass))
else:
- raise ValueError('kind not managed')
+ return tuple(
+ numpy.sum(axis * values) / sum_ for axis in context.axes)
def getToolTip(self, kind):
- if kind in ('scatter', 'image'):
- return '(x, y)'
- else:
- return None
+ return "Compute the center of mass of the dataset"
diff --git a/silx/gui/plot/stats/statshandler.py b/silx/gui/plot/stats/statshandler.py
index 0a62b31..f69daff 100644
--- a/silx/gui/plot/stats/statshandler.py
+++ b/silx/gui/plot/stats/statshandler.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -45,7 +45,14 @@ class _FloatItem(qt.QTableWidgetItem):
qt.QTableWidgetItem.__init__(self, type=type)
def __lt__(self, other):
- return float(self.text()) < float(other.text())
+ self_values = self.text().lstrip('(').rstrip(')').split(',')
+ other_values = other.text().lstrip('(').rstrip(')').split(',')
+ for self_value, other_value in zip(self_values, other_values):
+ f_self_value = float(self_value)
+ f_other_value = float(other_value)
+ if f_self_value != f_other_value:
+ return f_self_value < f_other_value
+ return False
class StatFormatter(object):
@@ -89,10 +96,60 @@ class StatsHandler(object):
self.stats = statsmdl.Stats()
self.formatters = {}
for elmt in statFormatters:
- helper = _StatHelper(elmt)
- self.add(stat=helper.stat, formatter=helper.statFormatter)
+ stat, formatter = self._processStatArgument(elmt)
+ self.add(stat=stat, formatter=formatter)
+
+ @staticmethod
+ def _processStatArgument(arg):
+ """Process an element of the init arguments
+
+ :param arg: The argument to process
+ :return: Corresponding (StatBase, StatFormatter)
+ """
+ stat, formatter = None, None
+
+ if isinstance(arg, statsmdl.StatBase):
+ stat = arg
+ else:
+ assert len(arg) > 0
+ if isinstance(arg[0], statsmdl.StatBase):
+ stat = arg[0]
+ if len(arg) > 2:
+ raise ValueError('To many argument with %s. At most one '
+ 'argument can be associated with the '
+ 'BaseStat (the `StatFormatter`')
+ if len(arg) == 2:
+ assert arg[1] is None or isinstance(arg[1], (StatFormatter, str))
+ formatter = arg[1]
+ else:
+ if isinstance(arg[0], tuple):
+ if len(arg) > 1:
+ formatter = arg[1]
+ arg = arg[0]
+
+ if type(arg[0]) is not str:
+ raise ValueError('first element of the tuple should be a string'
+ ' or a StatBase instance')
+ if len(arg) == 1:
+ raise ValueError('A function should be associated with the'
+ 'stat name')
+ if len(arg) > 3:
+ raise ValueError('Two much argument given for defining statistic.'
+ 'Take at most three arguments (name, function, '
+ 'kinds)')
+ if len(arg) == 2:
+ stat = statsmdl.Stat(name=arg[0], fct=arg[1])
+ else:
+ stat = statsmdl.Stat(name=arg[0], fct=arg[1], kinds=arg[2])
+
+ return stat, formatter
def add(self, stat, formatter=None):
+ """Add a stat to the list.
+
+ :param StatBase stat:
+ :param Union[None,StatFormatter] formatter:
+ """
assert isinstance(stat, statsmdl.StatBase)
self.stats.add(stat)
_formatter = formatter
@@ -101,9 +158,9 @@ class StatsHandler(object):
self.formatters[stat.name] = _formatter
def format(self, name, val):
- """
- Apply the format for the `name` statistic and the given value
- :param name: the name of the associated statistic
+ """Apply the format for the `name` statistic and the given value
+
+ :param str name: the name of the associated statistic
:param val: value before formatting
:return: formatted value
"""
@@ -123,7 +180,7 @@ class StatsHandler(object):
def calculate(self, item, plot, onlimits):
"""
- compute all statistic registred and return the list of formatted
+ compute all statistic registered and return the list of formatted
statistics result.
:param item: item for which we want to compute statistics
@@ -137,54 +194,3 @@ class StatsHandler(object):
for resName, resValue in list(res.items()):
res[resName] = self.format(resName, res[resName])
return res
-
-
-class _StatHelper(object):
- """
- Helper class to generated the requested StatBase instance and the
- associated StatFormatter
- """
- def __init__(self, arg):
- self.statFormatter = None
- self.stat = None
-
- if isinstance(arg, statsmdl.StatBase):
- self.stat = arg
- else:
- assert len(arg) > 0
- if isinstance(arg[0], statsmdl.StatBase):
- self.dealWithStatAndFormatter(arg)
- else:
- _arg = arg
- if isinstance(arg[0], tuple):
- _arg = arg[0]
- if len(arg) > 1:
- self.statFormatter = arg[1]
- self.createStatInstanceAndFormatter(_arg)
-
- def dealWithStatAndFormatter(self, arg):
- assert isinstance(arg[0], statsmdl.StatBase)
- self.stat = arg[0]
- if len(arg) > 2:
- raise ValueError('To many argument with %s. At most one '
- 'argument can be associated with the '
- 'BaseStat (the `StatFormatter`')
- if len(arg) is 2:
- assert isinstance(arg[1], (StatFormatter, type(None), str))
- self.statFormatter = arg[1]
-
- def createStatInstanceAndFormatter(self, arg):
- if type(arg[0]) is not str:
- raise ValueError('first element of the tuple should be a string'
- ' or a StatBase instance')
- if len(arg) is 1:
- raise ValueError('A function should be associated with the'
- 'stat name')
- if len(arg) > 3:
- raise ValueError('Two much argument given for defining statistic.'
- 'Take at most three arguments (name, function, '
- 'kinds)')
- if len(arg) is 2:
- self.stat = statsmdl.Stat(name=arg[0], fct=arg[1])
- else:
- self.stat = statsmdl.Stat(name=arg[0], fct=arg[1], kinds=arg[2])
diff --git a/silx/gui/plot/test/testCurvesROIWidget.py b/silx/gui/plot/test/testCurvesROIWidget.py
index 0704779..5bcabd8 100644
--- a/silx/gui/plot/test/testCurvesROIWidget.py
+++ b/silx/gui/plot/test/testCurvesROIWidget.py
@@ -36,7 +36,7 @@ from collections import OrderedDict
import numpy
from silx.gui import qt
from silx.test.utils import temp_dir
-from silx.gui.utils.testutils import TestCaseQt
+from silx.gui.utils.testutils import TestCaseQt, SignalListener
from silx.gui.plot import PlotWindow, CurvesROIWidget
@@ -52,7 +52,8 @@ class TestCurvesROIWidget(TestCaseQt):
self.plot.show()
self.qWaitForWindowExposed(self.plot)
- self.widget = CurvesROIWidget.CurvesROIDockWidget(plot=self.plot, name='TEST')
+ self.widget = self.plot.getCurvesRoiDockWidget()
+
self.widget.show()
self.qWaitForWindowExposed(self.widget)
@@ -67,10 +68,6 @@ class TestCurvesROIWidget(TestCaseQt):
super(TestCurvesROIWidget, self).tearDown()
- def testEmptyPlot(self):
- """Empty plot, display ROI widget"""
- pass
-
def testWithCurves(self):
"""Plot with curves: test all ROI widget buttons"""
for offset in range(2):
@@ -80,13 +77,16 @@ class TestCurvesROIWidget(TestCaseQt):
# Add two ROI
self.mouseClick(self.widget.roiWidget.addButton, qt.Qt.LeftButton)
+ self.qWait(200)
self.mouseClick(self.widget.roiWidget.addButton, qt.Qt.LeftButton)
+ self.qWait(200)
# Change active curve
self.plot.setActiveCurve(str(1))
# Delete a ROI
self.mouseClick(self.widget.roiWidget.delButton, qt.Qt.LeftButton)
+ self.qWait(200)
with temp_dir() as tmpDir:
self.tmpFile = os.path.join(tmpDir, 'test.ini')
@@ -94,30 +94,42 @@ class TestCurvesROIWidget(TestCaseQt):
# Save ROIs
self.widget.roiWidget.save(self.tmpFile)
self.assertTrue(os.path.isfile(self.tmpFile))
+ self.assertTrue(len(self.widget.getRois()) is 2)
# Reset ROIs
self.mouseClick(self.widget.roiWidget.resetButton,
qt.Qt.LeftButton)
+ self.qWait(200)
+ rois = self.widget.getRois()
+ self.assertTrue(len(rois) is 1)
+ print(rois)
+ roiID = list(rois.keys())[0]
+ self.assertTrue(rois[roiID].getName() == 'ICR')
# Load ROIs
self.widget.roiWidget.load(self.tmpFile)
+ self.assertTrue(len(self.widget.getRois()) is 2)
del self.tmpFile
def testMiddleMarker(self):
"""Test with middle marker enabled"""
- self.widget.roiWidget.setMiddleROIMarkerFlag(True)
+ self.widget.roiWidget.roiTable.setMiddleROIMarkerFlag(True)
# Add a ROI
self.mouseClick(self.widget.roiWidget.addButton, qt.Qt.LeftButton)
- xleftMarker = self.plot._getMarker(legend='ROI min').getXPosition()
- xMiddleMarker = self.plot._getMarker(legend='ROI middle').getXPosition()
- xRightMarker = self.plot._getMarker(legend='ROI max').getXPosition()
- self.assertAlmostEqual(xMiddleMarker,
- xleftMarker + (xRightMarker - xleftMarker) / 2.)
-
- def testCalculation(self):
+ for roiID in self.widget.roiWidget.roiTable._markersHandler._roiMarkerHandlers:
+ handler = self.widget.roiWidget.roiTable._markersHandler._roiMarkerHandlers[roiID]
+ assert handler.getMarker('min')
+ xleftMarker = handler.getMarker('min').getXPosition()
+ xMiddleMarker = handler.getMarker('middle').getXPosition()
+ xRightMarker = handler.getMarker('max').getXPosition()
+ thValue = xleftMarker + (xRightMarker - xleftMarker) / 2.
+ self.assertAlmostEqual(xMiddleMarker, thValue)
+
+ def testAreaCalculation(self):
+ """Test result of area calculation"""
x = numpy.arange(100.)
y = numpy.arange(100.)
@@ -129,30 +141,60 @@ class TestCurvesROIWidget(TestCaseQt):
self.plot.setActiveCurve("positive")
# Add two ROIs
- ddict = {}
- ddict["positive"] = {"from": 10, "to": 20, "type":"X"}
- ddict["negative"] = {"from": -20, "to": -10, "type":"X"}
- self.widget.roiWidget.setRois(ddict)
+ roi_neg = CurvesROIWidget.ROI(name='negative', fromdata=-20,
+ todata=-10, type_='X')
+ roi_pos = CurvesROIWidget.ROI(name='positive', fromdata=10,
+ todata=20, type_='X')
+
+ self.widget.roiWidget.setRois((roi_pos, roi_neg))
+
+ posCurve = self.plot.getCurve('positive')
+ negCurve = self.plot.getCurve('negative')
+
+ self.assertEqual(roi_pos.computeRawAndNetArea(posCurve),
+ (numpy.trapz(y=[10, 20], x=[10, 20]),
+ 0.0))
+ self.assertEqual(roi_pos.computeRawAndNetArea(negCurve),
+ (0.0, 0.0))
+ self.assertEqual(roi_neg.computeRawAndNetArea(posCurve),
+ ((0.0), 0.0))
+ self.assertEqual(roi_neg.computeRawAndNetArea(negCurve),
+ ((-150.0), 0.0))
+
+ def testCountsCalculation(self):
+ """Test result of count calculation"""
+ x = numpy.arange(100.)
+ y = numpy.arange(100.)
- # And calculate the expected output
- self.widget.calculateROIs()
+ # Add two curves
+ self.plot.addCurve(x, y, legend="positive")
+ self.plot.addCurve(-x, y, legend="negative")
+
+ # Make sure there is an active curve and it is the positive one
+ self.plot.setActiveCurve("positive")
- output = self.widget.roiWidget.getRois()
- self.assertEqual(output["positive"]["rawcounts"],
- y[ddict["positive"]["from"]:ddict["positive"]["to"]+1].sum(),
- "Calculation failed on positive X coordinates")
+ # Add two ROIs
+ roi_neg = CurvesROIWidget.ROI(name='negative', fromdata=-20,
+ todata=-10, type_='X')
+ roi_pos = CurvesROIWidget.ROI(name='positive', fromdata=10,
+ todata=20, type_='X')
+
+ self.widget.roiWidget.setRois((roi_pos, roi_neg))
- # Set the curve with negative X coordinates as active
- self.plot.setActiveCurve("negative")
+ posCurve = self.plot.getCurve('positive')
+ negCurve = self.plot.getCurve('negative')
- # the ROIs should have been automatically updated
- output = self.widget.roiWidget.getRois()
- selection = numpy.nonzero((-x >= output["negative"]["from"]) & \
- (-x <= output["negative"]["to"]))[0]
- self.assertEqual(output["negative"]["rawcounts"],
- y[selection].sum(), "Calculation failed on negative X coordinates")
+ self.assertEqual(roi_pos.computeRawAndNetCounts(posCurve),
+ (y[10:21].sum(), 0.0))
+ self.assertEqual(roi_pos.computeRawAndNetCounts(negCurve),
+ (0.0, 0.0))
+ self.assertEqual(roi_neg.computeRawAndNetCounts(posCurve),
+ ((0.0), 0.0))
+ self.assertEqual(roi_neg.computeRawAndNetCounts(negCurve),
+ (y[10:21].sum(), 0.0))
def testDeferedInit(self):
+ """Test behavior of the deferedInit"""
x = numpy.arange(100.)
y = numpy.arange(100.)
self.plot.addCurve(x=x, y=y, legend="name", replace="True")
@@ -164,12 +206,123 @@ class TestCurvesROIWidget(TestCaseQt):
])
roiWidget = self.plot.getCurvesRoiDockWidget().roiWidget
- self.assertFalse(roiWidget._isInit)
self.plot.getCurvesRoiDockWidget().setRois(roisDefs)
self.assertTrue(len(roiWidget.getRois()) is len(roisDefs))
self.plot.getCurvesRoiDockWidget().setVisible(True)
self.assertTrue(len(roiWidget.getRois()) is len(roisDefs))
+ def testDictCompatibility(self):
+ """Test that ROI api is valid with dict and not information is lost"""
+ roiDict = {'from': 20, 'to': 200, 'type': 'energy', 'comment': 'no',
+ 'name': 'myROI', 'calibration': [1, 2, 3]}
+ roi = CurvesROIWidget.ROI._fromDict(roiDict)
+ self.assertTrue(roi.toDict() == roiDict)
+
+ def testShowAllROI(self):
+ """Test the show allROI action"""
+ x = numpy.arange(100.)
+ y = numpy.arange(100.)
+ self.plot.addCurve(x=x, y=y, legend="name", replace="True")
+
+ roisDefsDict = {
+ "range1": {"from": 20, "to": 200,"type": "energy"},
+ "range2": {"from": 300, "to": 500, "type": "energy"}
+ }
+
+ roisDefsObj = (
+ CurvesROIWidget.ROI(name='range3', fromdata=20, todata=200,
+ type_='energy'),
+ CurvesROIWidget.ROI(name='range4', fromdata=300, todata=500,
+ type_='energy')
+ )
+ self.widget.roiWidget.showAllMarkers(True)
+ roiWidget = self.plot.getCurvesRoiDockWidget().roiWidget
+ roiWidget.setRois(roisDefsDict)
+ self.assertTrue(len(self.plot._getAllMarkers()) is 2*3)
+
+ markersHandler = self.widget.roiWidget.roiTable._markersHandler
+ roiWidget.showAllMarkers(True)
+ ICRROI = markersHandler.getVisibleRois()
+ self.assertTrue(len(ICRROI) is 2)
+
+ roiWidget.showAllMarkers(False)
+ ICRROI = markersHandler.getVisibleRois()
+ self.assertTrue(len(ICRROI) is 1)
+
+ roiWidget.setRois(roisDefsObj)
+ self.qapp.processEvents()
+ self.assertTrue(len(self.plot._getAllMarkers()) is 2*3)
+
+ markersHandler = self.widget.roiWidget.roiTable._markersHandler
+ roiWidget.showAllMarkers(True)
+ ICRROI = markersHandler.getVisibleRois()
+ self.assertTrue(len(ICRROI) is 2)
+
+ roiWidget.showAllMarkers(False)
+ ICRROI = markersHandler.getVisibleRois()
+ self.assertTrue(len(ICRROI) is 1)
+
+ def testRoiEdition(self):
+ """Make sure if the ROI object is edited the ROITable will be updated
+ """
+ roi = CurvesROIWidget.ROI(name='linear', fromdata=0, todata=5)
+ self.widget.roiWidget.setRois((roi, ))
+
+ x = (0, 1, 1, 2, 2, 3)
+ y = (1, 1, 2, 2, 1, 1)
+ self.plot.addCurve(x=x, y=y, legend='linearCurve')
+ self.plot.setActiveCurve(legend='linearCurve')
+ self.widget.calculateROIs()
+
+ roiTable = self.widget.roiWidget.roiTable
+ indexesColumns = CurvesROIWidget.ROITable.COLUMNS_INDEX
+ itemRawCounts = roiTable.item(0, indexesColumns['Raw Counts'])
+ itemNetCounts = roiTable.item(0, indexesColumns['Net Counts'])
+
+ self.assertTrue(itemRawCounts.text() == '8.0')
+ self.assertTrue(itemNetCounts.text() == '2.0')
+
+ itemRawArea = roiTable.item(0, indexesColumns['Raw Area'])
+ itemNetArea = roiTable.item(0, indexesColumns['Net Area'])
+
+ self.assertTrue(itemRawArea.text() == '4.0')
+ self.assertTrue(itemNetArea.text() == '1.0')
+
+ roi.setTo(2)
+ itemRawArea = roiTable.item(0, indexesColumns['Raw Area'])
+ self.assertTrue(itemRawArea.text() == '3.0')
+ roi.setFrom(1)
+ itemRawArea = roiTable.item(0, indexesColumns['Raw Area'])
+ self.assertTrue(itemRawArea.text() == '2.0')
+
+ def testRemoveActiveROI(self):
+ """Test widget behavior when removing the active ROI"""
+ roi = CurvesROIWidget.ROI(name='linear', fromdata=0, todata=5)
+ self.widget.roiWidget.setRois((roi,))
+
+ self.widget.roiWidget.roiTable.setActiveRoi(None)
+ self.assertTrue(len(self.widget.roiWidget.roiTable.selectedItems()) is 0)
+ self.widget.roiWidget.setRois((roi,))
+ self.plot.setActiveCurve(legend='linearCurve')
+ self.widget.calculateROIs()
+
+ def testEmitCurrentROI(self):
+ """Test behavior of the CurvesROIWidget.sigROISignal"""
+ roi = CurvesROIWidget.ROI(name='linear', fromdata=0, todata=5)
+ self.widget.roiWidget.setRois((roi,))
+ signalListener = SignalListener()
+ self.widget.roiWidget.sigROISignal.connect(signalListener.partial())
+ self.widget.show()
+ self.qapp.processEvents()
+ self.assertTrue(signalListener.callCount() is 0)
+ self.assertTrue(self.widget.roiWidget.roiTable.activeRoi is roi)
+ roi.setFrom(0.0)
+ self.qapp.processEvents()
+ self.assertTrue(signalListener.callCount() is 0)
+ roi.setFrom(0.3)
+ self.qapp.processEvents()
+ self.assertTrue(signalListener.callCount() is 1)
+
def suite():
test_suite = unittest.TestSuite()
diff --git a/silx/gui/plot/test/testMaskToolsWidget.py b/silx/gui/plot/test/testMaskToolsWidget.py
index 6912ea3..a05c1be 100644
--- a/silx/gui/plot/test/testMaskToolsWidget.py
+++ b/silx/gui/plot/test/testMaskToolsWidget.py
@@ -42,10 +42,7 @@ from silx.gui.utils.testutils import getQToolButtonFromAction
from silx.gui.plot import PlotWindow, MaskToolsWidget
from .utils import PlotWidgetTestCase
-try:
- import fabio
-except ImportError:
- fabio = None
+import fabio
_logger = logging.getLogger(__name__)
@@ -254,8 +251,6 @@ class TestMaskToolsWidget(PlotWidgetTestCase, ParametricTestCase):
self.__loadSave("npy")
def testLoadSaveFit2D(self):
- if fabio is None:
- self.skipTest("Fabio is missing")
self.__loadSave("msk")
def testSigMaskChangedEmitted(self):
diff --git a/silx/gui/plot/test/testPlotWidget.py b/silx/gui/plot/test/testPlotWidget.py
index 857b9bc..9d7c093 100644
--- a/silx/gui/plot/test/testPlotWidget.py
+++ b/silx/gui/plot/test/testPlotWidget.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -26,7 +26,7 @@
__authors__ = ["T. Vincent"]
__license__ = "MIT"
-__date__ = "21/09/2018"
+__date__ = "03/01/2019"
import unittest
@@ -36,8 +36,6 @@ import numpy
from silx.utils.testutils import ParametricTestCase, parameterize
from silx.gui.utils.testutils import SignalListener
from silx.gui.utils.testutils import TestCaseQt
-from silx.utils import testutils
-from silx.utils import deprecation
from silx.test.utils import test_options
@@ -184,6 +182,39 @@ class TestPlotWidget(PlotWidgetTestCase, ParametricTestCase):
self.assertTrue(numpy.all(numpy.equal(items[4].getPosition()[0], marker_x)))
self.assertEqual(items[5].getType(), 'rectangle')
+ def testBackGroundColors(self):
+ self.plot.setVisible(True)
+ self.qWaitForWindowExposed(self.plot)
+ self.qapp.processEvents()
+
+ # Custom the full background
+ color = self.plot.getBackgroundColor()
+ self.assertTrue(color.isValid())
+ self.assertEqual(color, qt.QColor(255, 255, 255))
+ self.plot.setBackgroundColor("red")
+ color = self.plot.getBackgroundColor()
+ self.assertTrue(color.isValid())
+ self.qapp.processEvents()
+
+ # Custom the data background
+ color = self.plot.getDataBackgroundColor()
+ self.assertFalse(color.isValid())
+ self.plot.setDataBackgroundColor("red")
+ color = self.plot.getDataBackgroundColor()
+ self.assertTrue(color.isValid())
+ self.qapp.processEvents()
+
+ # Back to default
+ self.plot.setBackgroundColor('white')
+ self.plot.setDataBackgroundColor(None)
+ color = self.plot.getBackgroundColor()
+ self.assertTrue(color.isValid())
+ self.assertEqual(color, qt.QColor(255, 255, 255))
+ color = self.plot.getDataBackgroundColor()
+ self.assertFalse(color.isValid())
+ self.qapp.processEvents()
+
+
class TestPlotImage(PlotWidgetTestCase, ParametricTestCase):
"""Basic tests for addImage"""
@@ -881,17 +912,12 @@ class TestPlotAxes(TestCaseQt, ParametricTestCase):
if getter is not None:
self.assertEqual(getter(), expected)
- @testutils.test_logging(deprecation.depreclog.name)
def testOldPlotAxis_Logarithmic(self):
"""Test silx API prior to silx 0.6"""
x = self.plot.getXAxis()
y = self.plot.getYAxis()
yright = self.plot.getYAxis(axis="right")
- listener = SignalListener()
- self.plot.sigSetXAxisLogarithmic.connect(listener.partial("x"))
- self.plot.sigSetYAxisLogarithmic.connect(listener.partial("y"))
-
self.assertEqual(x.getScale(), x.LINEAR)
self.assertEqual(y.getScale(), x.LINEAR)
self.assertEqual(yright.getScale(), x.LINEAR)
@@ -902,7 +928,6 @@ class TestPlotAxes(TestCaseQt, ParametricTestCase):
self.assertEqual(yright.getScale(), x.LINEAR)
self.assertEqual(self.plot.isXAxisLogarithmic(), True)
self.assertEqual(self.plot.isYAxisLogarithmic(), False)
- self.assertEqual(listener.arguments(callIndex=-1), ("x", True))
self.plot.setYAxisLogarithmic(True)
self.assertEqual(x.getScale(), x.LOGARITHMIC)
@@ -910,7 +935,6 @@ class TestPlotAxes(TestCaseQt, ParametricTestCase):
self.assertEqual(yright.getScale(), x.LOGARITHMIC)
self.assertEqual(self.plot.isXAxisLogarithmic(), True)
self.assertEqual(self.plot.isYAxisLogarithmic(), True)
- self.assertEqual(listener.arguments(callIndex=-1), ("y", True))
yright.setScale(yright.LINEAR)
self.assertEqual(x.getScale(), x.LOGARITHMIC)
@@ -918,19 +942,13 @@ class TestPlotAxes(TestCaseQt, ParametricTestCase):
self.assertEqual(yright.getScale(), x.LINEAR)
self.assertEqual(self.plot.isXAxisLogarithmic(), True)
self.assertEqual(self.plot.isYAxisLogarithmic(), False)
- self.assertEqual(listener.arguments(callIndex=-1), ("y", False))
- @testutils.test_logging(deprecation.depreclog.name)
def testOldPlotAxis_AutoScale(self):
"""Test silx API prior to silx 0.6"""
x = self.plot.getXAxis()
y = self.plot.getYAxis()
yright = self.plot.getYAxis(axis="right")
- listener = SignalListener()
- self.plot.sigSetXAxisAutoScale.connect(listener.partial("x"))
- self.plot.sigSetYAxisAutoScale.connect(listener.partial("y"))
-
self.assertEqual(x.isAutoScale(), True)
self.assertEqual(y.isAutoScale(), True)
self.assertEqual(yright.isAutoScale(), True)
@@ -941,7 +959,6 @@ class TestPlotAxes(TestCaseQt, ParametricTestCase):
self.assertEqual(yright.isAutoScale(), True)
self.assertEqual(self.plot.isXAxisAutoScale(), False)
self.assertEqual(self.plot.isYAxisAutoScale(), True)
- self.assertEqual(listener.arguments(callIndex=-1), ("x", False))
self.plot.setYAxisAutoScale(False)
self.assertEqual(x.isAutoScale(), False)
@@ -949,7 +966,6 @@ class TestPlotAxes(TestCaseQt, ParametricTestCase):
self.assertEqual(yright.isAutoScale(), False)
self.assertEqual(self.plot.isXAxisAutoScale(), False)
self.assertEqual(self.plot.isYAxisAutoScale(), False)
- self.assertEqual(listener.arguments(callIndex=-1), ("y", False))
yright.setAutoScale(True)
self.assertEqual(x.isAutoScale(), False)
@@ -957,18 +973,13 @@ class TestPlotAxes(TestCaseQt, ParametricTestCase):
self.assertEqual(yright.isAutoScale(), True)
self.assertEqual(self.plot.isXAxisAutoScale(), False)
self.assertEqual(self.plot.isYAxisAutoScale(), True)
- self.assertEqual(listener.arguments(callIndex=-1), ("y", True))
- @testutils.test_logging(deprecation.depreclog.name)
def testOldPlotAxis_Inverted(self):
"""Test silx API prior to silx 0.6"""
x = self.plot.getXAxis()
y = self.plot.getYAxis()
yright = self.plot.getYAxis(axis="right")
- listener = SignalListener()
- self.plot.sigSetYAxisInverted.connect(listener.partial("y"))
-
self.assertEqual(x.isInverted(), False)
self.assertEqual(y.isInverted(), False)
self.assertEqual(yright.isInverted(), False)
@@ -978,14 +989,12 @@ class TestPlotAxes(TestCaseQt, ParametricTestCase):
self.assertEqual(y.isInverted(), True)
self.assertEqual(yright.isInverted(), True)
self.assertEqual(self.plot.isYAxisInverted(), True)
- self.assertEqual(listener.arguments(callIndex=-1), ("y", True))
yright.setInverted(False)
self.assertEqual(x.isInverted(), False)
self.assertEqual(y.isInverted(), False)
self.assertEqual(yright.isInverted(), False)
self.assertEqual(self.plot.isYAxisInverted(), False)
- self.assertEqual(listener.arguments(callIndex=-1), ("y", False))
def testLogXWithData(self):
self.plot.setGraphTitle('Curve X: Log Y: Linear')
diff --git a/silx/gui/plot/test/testSaveAction.py b/silx/gui/plot/test/testSaveAction.py
index 85669bf..0eb129d 100644
--- a/silx/gui/plot/test/testSaveAction.py
+++ b/silx/gui/plot/test/testSaveAction.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -106,12 +106,30 @@ class TestSaveActionExtension(PlotWidgetTestCase):
self.assertEqual(saveAction.getFileFilters('all')[nameFilter],
self._dummySaveFunction)
+ # Add a new file filter at a particular position
+ nameFilter = 'Dummy file2 (*.dummy)'
+ saveAction.setFileFilter('all', nameFilter,
+ self._dummySaveFunction, index=3)
+ self.assertTrue(nameFilter in saveAction.getFileFilters('all'))
+ filters = saveAction.getFileFilters('all')
+ self.assertEqual(filters[nameFilter], self._dummySaveFunction)
+ self.assertEqual(list(filters.keys()).index(nameFilter),3)
+
# Update an existing file filter
nameFilter = SaveAction.IMAGE_FILTER_EDF
saveAction.setFileFilter('image', nameFilter, self._dummySaveFunction)
self.assertEqual(saveAction.getFileFilters('image')[nameFilter],
self._dummySaveFunction)
+ # Change the position of an existing file filter
+ nameFilter = 'Dummy file2 (*.dummy)'
+ oldIndex = list(saveAction.getFileFilters('all')).index(nameFilter)
+ newIndex = oldIndex - 1
+ saveAction.setFileFilter('all', nameFilter,
+ self._dummySaveFunction, index=newIndex)
+ filters = saveAction.getFileFilters('all')
+ self.assertEqual(filters[nameFilter], self._dummySaveFunction)
+ self.assertEqual(list(filters.keys()).index(nameFilter), newIndex)
def suite():
test_suite = unittest.TestSuite()
diff --git a/silx/gui/plot/test/testScatterMaskToolsWidget.py b/silx/gui/plot/test/testScatterMaskToolsWidget.py
index a446911..171ec42 100644
--- a/silx/gui/plot/test/testScatterMaskToolsWidget.py
+++ b/silx/gui/plot/test/testScatterMaskToolsWidget.py
@@ -42,10 +42,7 @@ from silx.gui.utils.testutils import getQToolButtonFromAction
from silx.gui.plot import PlotWindow, ScatterMaskToolsWidget
from .utils import PlotWidgetTestCase
-try:
- import fabio
-except ImportError:
- fabio = None
+import fabio
_logger = logging.getLogger(__name__)
diff --git a/silx/gui/plot/test/testStats.py b/silx/gui/plot/test/testStats.py
index faedcff..7fbc247 100644
--- a/silx/gui/plot/test/testStats.py
+++ b/silx/gui/plot/test/testStats.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -112,34 +112,34 @@ class TestStats(TestCaseQt):
"""Test result for simple stats on a curve"""
_stats = self.getBasicStats()
xData = yData = numpy.array(range(20))
- self.assertTrue(_stats['min'].calculate(self.curveContext) == 0)
- self.assertTrue(_stats['max'].calculate(self.curveContext) == 19)
- self.assertTrue(_stats['minCoords'].calculate(self.curveContext) == [0])
- self.assertTrue(_stats['maxCoords'].calculate(self.curveContext) == [19])
- self.assertTrue(_stats['std'].calculate(self.curveContext) == numpy.std(yData))
- self.assertTrue(_stats['mean'].calculate(self.curveContext) == numpy.mean(yData))
+ self.assertEqual(_stats['min'].calculate(self.curveContext), 0)
+ self.assertEqual(_stats['max'].calculate(self.curveContext), 19)
+ self.assertEqual(_stats['minCoords'].calculate(self.curveContext), (0,))
+ self.assertEqual(_stats['maxCoords'].calculate(self.curveContext), (19,))
+ self.assertEqual(_stats['std'].calculate(self.curveContext), numpy.std(yData))
+ self.assertEqual(_stats['mean'].calculate(self.curveContext), numpy.mean(yData))
com = numpy.sum(xData * yData) / numpy.sum(yData)
- self.assertTrue(_stats['com'].calculate(self.curveContext) == com)
+ self.assertEqual(_stats['com'].calculate(self.curveContext), com)
def testBasicStatsImage(self):
"""Test result for simple stats on an image"""
_stats = self.getBasicStats()
- self.assertTrue(_stats['min'].calculate(self.imageContext) == 0)
- self.assertTrue(_stats['max'].calculate(self.imageContext) == 128 * 32 - 1)
- self.assertTrue(_stats['minCoords'].calculate(self.imageContext) == (0, 0))
- self.assertTrue(_stats['maxCoords'].calculate(self.imageContext) == (127, 31))
- self.assertTrue(_stats['std'].calculate(self.imageContext) == numpy.std(self.imageData))
- self.assertTrue(_stats['mean'].calculate(self.imageContext) == numpy.mean(self.imageData))
-
- yData = numpy.sum(self.imageData, axis=1)
- xData = numpy.sum(self.imageData, axis=0)
+ self.assertEqual(_stats['min'].calculate(self.imageContext), 0)
+ self.assertEqual(_stats['max'].calculate(self.imageContext), 128 * 32 - 1)
+ self.assertEqual(_stats['minCoords'].calculate(self.imageContext), (0, 0))
+ self.assertEqual(_stats['maxCoords'].calculate(self.imageContext), (127, 31))
+ self.assertEqual(_stats['std'].calculate(self.imageContext), numpy.std(self.imageData))
+ self.assertEqual(_stats['mean'].calculate(self.imageContext), numpy.mean(self.imageData))
+
+ yData = numpy.sum(self.imageData.astype(numpy.float64), axis=1)
+ xData = numpy.sum(self.imageData.astype(numpy.float64), axis=0)
dataXRange = range(self.imageData.shape[1])
dataYRange = range(self.imageData.shape[0])
ycom = numpy.sum(yData*dataYRange) / numpy.sum(yData)
xcom = numpy.sum(xData*dataXRange) / numpy.sum(xData)
- self.assertTrue(_stats['com'].calculate(self.imageContext) == (xcom, ycom))
+ self.assertEqual(_stats['com'].calculate(self.imageContext), (xcom, ycom))
def testStatsImageAdv(self):
"""Test that scale and origin are taking into account for images"""
@@ -153,52 +153,46 @@ class TestStats(TestCaseQt):
onlimits=False
)
_stats = self.getBasicStats()
- self.assertTrue(_stats['min'].calculate(image2Context) == 0)
- self.assertTrue(
- _stats['max'].calculate(image2Context) == 128 * 32 - 1)
- self.assertTrue(
- _stats['minCoords'].calculate(image2Context) == (100, 10))
- self.assertTrue(
- _stats['maxCoords'].calculate(image2Context) == (127*2. + 100,
- 31 * 0.5 + 10)
- )
- self.assertTrue(
- _stats['std'].calculate(image2Context) == numpy.std(
- self.imageData))
- self.assertTrue(
- _stats['mean'].calculate(image2Context) == numpy.mean(
- self.imageData))
+ self.assertEqual(_stats['min'].calculate(image2Context), 0)
+ self.assertEqual(
+ _stats['max'].calculate(image2Context), 128 * 32 - 1)
+ self.assertEqual(
+ _stats['minCoords'].calculate(image2Context), (100, 10))
+ self.assertEqual(
+ _stats['maxCoords'].calculate(image2Context), (127*2. + 100,
+ 31 * 0.5 + 10))
+ self.assertEqual(_stats['std'].calculate(image2Context),
+ numpy.std(self.imageData))
+ self.assertEqual(_stats['mean'].calculate(image2Context),
+ numpy.mean(self.imageData))
yData = numpy.sum(self.imageData, axis=1)
xData = numpy.sum(self.imageData, axis=0)
- dataXRange = range(self.imageData.shape[1])
- dataYRange = range(self.imageData.shape[0])
+ dataXRange = numpy.arange(self.imageData.shape[1], dtype=numpy.float64)
+ dataYRange = numpy.arange(self.imageData.shape[0], dtype=numpy.float64)
ycom = numpy.sum(yData * dataYRange) / numpy.sum(yData)
ycom = (ycom * 0.5) + 10
xcom = numpy.sum(xData * dataXRange) / numpy.sum(xData)
xcom = (xcom * 2.) + 100
- self.assertTrue(
- _stats['com'].calculate(image2Context) == (xcom, ycom))
+ self.assertTrue(numpy.allclose(
+ _stats['com'].calculate(image2Context), (xcom, ycom)))
def testBasicStatsScatter(self):
"""Test result for simple stats on a scatter"""
_stats = self.getBasicStats()
- self.assertTrue(_stats['min'].calculate(self.scatterContext) == 5)
- self.assertTrue(_stats['max'].calculate(self.scatterContext) == 90)
- self.assertTrue(_stats['minCoords'].calculate(self.scatterContext) == (0, 2))
- self.assertTrue(_stats['maxCoords'].calculate(self.scatterContext) == (50, 69))
- self.assertTrue(_stats['std'].calculate(self.scatterContext) == numpy.std(self.valuesScatterData))
- self.assertTrue(_stats['mean'].calculate(self.scatterContext) == numpy.mean(self.valuesScatterData))
-
- comx = numpy.sum(self.xScatterData * self.valuesScatterData).astype(numpy.float32) / numpy.sum(
- self.valuesScatterData).astype(numpy.float32)
- comy = numpy.sum(self.yScatterData * self.valuesScatterData).astype(numpy.float32) / numpy.sum(
- self.valuesScatterData).astype(numpy.float32)
- self.assertTrue(numpy.all(
- numpy.equal(_stats['com'].calculate(self.scatterContext),
- (comx, comy)))
- )
+ self.assertEqual(_stats['min'].calculate(self.scatterContext), 5)
+ self.assertEqual(_stats['max'].calculate(self.scatterContext), 90)
+ self.assertEqual(_stats['minCoords'].calculate(self.scatterContext), (0, 2))
+ self.assertEqual(_stats['maxCoords'].calculate(self.scatterContext), (50, 69))
+ self.assertEqual(_stats['std'].calculate(self.scatterContext), numpy.std(self.valuesScatterData))
+ self.assertEqual(_stats['mean'].calculate(self.scatterContext), numpy.mean(self.valuesScatterData))
+
+ data = self.valuesScatterData.astype(numpy.float64)
+ comx = numpy.sum(self.xScatterData * data) / numpy.sum(data)
+ comy = numpy.sum(self.yScatterData * data) / numpy.sum(data)
+ self.assertEqual(_stats['com'].calculate(self.scatterContext),
+ (comx, comy))
def testKindNotManagedByStat(self):
"""Make sure an exception is raised if we try to execute calculate
@@ -227,21 +221,21 @@ class TestStats(TestCaseQt):
item=self.plot1d.getCurve('curve0'),
plot=self.plot1d,
onlimits=True)
- self.assertTrue(stat.calculate(curveContextOnLimits) == 2)
+ self.assertEqual(stat.calculate(curveContextOnLimits), 2)
self.plot2d.getXAxis().setLimitsConstraints(minPos=32)
imageContextOnLimits = stats._ImageContext(
item=self.plot2d.getImage('test image'),
plot=self.plot2d,
onlimits=True)
- self.assertTrue(stat.calculate(imageContextOnLimits) == 32)
+ self.assertEqual(stat.calculate(imageContextOnLimits), 32)
self.scatterPlot.getXAxis().setLimitsConstraints(minPos=40)
scatterContextOnLimits = stats._ScatterContext(
item=self.scatterPlot.getScatter('scatter plot'),
plot=self.scatterPlot,
onlimits=True)
- self.assertTrue(stat.calculate(scatterContextOnLimits) == 20)
+ self.assertEqual(stat.calculate(scatterContextOnLimits), 20)
class TestStatsFormatter(TestCaseQt):
@@ -267,15 +261,15 @@ class TestStatsFormatter(TestCaseQt):
"""Make sure a formatter with no formatter definition will return a
simple cast to str"""
emptyFormatter = statshandler.StatFormatter()
- self.assertTrue(
- emptyFormatter.format(self.stat.calculate(self.curveContext)) == '0.000')
+ self.assertEqual(
+ emptyFormatter.format(self.stat.calculate(self.curveContext)), '0.000')
def testSettedFormatter(self):
"""Make sure a formatter with no formatter definition will return a
simple cast to str"""
formatter= statshandler.StatFormatter(formatter='{0:.3f}')
- self.assertTrue(
- formatter.format(self.stat.calculate(self.curveContext)) == '0.000')
+ self.assertEqual(
+ formatter.format(self.stat.calculate(self.curveContext)), '0.000')
class TestStatsHandler(unittest.TestCase):
@@ -309,9 +303,9 @@ class TestStatsHandler(unittest.TestCase):
res = handler0.calculate(item=self.curveItem, plot=self.plot1d,
onlimits=False)
self.assertTrue('min' in res)
- self.assertTrue(res['min'] == '0')
+ self.assertEqual(res['min'], '0')
self.assertTrue('max' in res)
- self.assertTrue(res['max'] == '19')
+ self.assertEqual(res['max'], '19')
handler1 = statshandler.StatsHandler(
(
@@ -323,9 +317,9 @@ class TestStatsHandler(unittest.TestCase):
res = handler1.calculate(item=self.curveItem, plot=self.plot1d,
onlimits=False)
self.assertTrue('min' in res)
- self.assertTrue(res['min'] == '0')
+ self.assertEqual(res['min'], '0')
self.assertTrue('max' in res)
- self.assertTrue(res['max'] == '19.000')
+ self.assertEqual(res['max'], '19.000')
handler2 = statshandler.StatsHandler(
(
@@ -336,9 +330,9 @@ class TestStatsHandler(unittest.TestCase):
res = handler2.calculate(item=self.curveItem, plot=self.plot1d,
onlimits=False)
self.assertTrue('min' in res)
- self.assertTrue(res['min'] == '0')
+ self.assertEqual(res['min'], '0')
self.assertTrue('max' in res)
- self.assertTrue(res['max'] == '19.000')
+ self.assertEqual(res['max'], '19.000')
handler3 = statshandler.StatsHandler((
(('amin', numpy.argmin), statshandler.StatFormatter()),
@@ -348,9 +342,9 @@ class TestStatsHandler(unittest.TestCase):
res = handler3.calculate(item=self.curveItem, plot=self.plot1d,
onlimits=False)
self.assertTrue('amin' in res)
- self.assertTrue(res['amin'] == '0.000')
+ self.assertEqual(res['amin'], '0.000')
self.assertTrue('amax' in res)
- self.assertTrue(res['amax'] == '19')
+ self.assertEqual(res['amax'], '19')
with self.assertRaises(ValueError):
statshandler.StatsHandler(('name'))
@@ -395,47 +389,49 @@ class TestStatsWidgetWithCurves(TestCaseQt):
def testInit(self):
"""Make sure all the curves are registred on initialization"""
- self.assertTrue(self.widget.rowCount() is 3)
+ self.assertEqual(self.widget.rowCount(), 3)
def testRemoveCurve(self):
"""Make sure the Curves stats take into account the curve removal from
plot"""
self.plot.removeCurve('curve2')
- self.assertTrue(self.widget.rowCount() is 2)
+ self.assertEqual(self.widget.rowCount(), 2)
for iRow in range(2):
self.assertTrue(self.widget.item(iRow, 0).text() in ('curve0', 'curve1'))
self.plot.removeCurve('curve0')
- self.assertTrue(self.widget.rowCount() is 1)
+ self.assertEqual(self.widget.rowCount(), 1)
self.plot.removeCurve('curve1')
- self.assertTrue(self.widget.rowCount() is 0)
+ self.assertEqual(self.widget.rowCount(), 0)
def testAddCurve(self):
"""Make sure the Curves stats take into account the add curve action"""
self.plot.addCurve(legend='curve3', x=range(10), y=range(10))
- self.assertTrue(self.widget.rowCount() is 4)
+ self.assertEqual(self.widget.rowCount(), 4)
- def testUpdateCurveFrmAddCurve(self):
+ def testUpdateCurveFromAddCurve(self):
"""Make sure the stats of the cuve will be removed after updating a
curve"""
self.plot.addCurve(legend='curve0', x=range(10), y=range(10))
- self.assertTrue(self.widget.rowCount() is 3)
- itemMax = self.widget._getItem(name='max', legend='curve0',
- kind='curve', indexTable=None)
- self.assertTrue(itemMax.text() == '9')
+ self.qapp.processEvents()
+ self.assertEqual(self.widget.rowCount(), 3)
+ curve = self.plot._getItem(kind='curve', legend='curve0')
+ tableItems = self.widget._itemToTableItems(curve)
+ self.assertEqual(tableItems['max'].text(), '9')
- def testUpdateCurveFrmCurveObj(self):
+ def testUpdateCurveFromCurveObj(self):
self.plot.getCurve('curve0').setData(x=range(4), y=range(4))
- self.assertTrue(self.widget.rowCount() is 3)
- itemMax = self.widget._getItem(name='max', legend='curve0',
- kind='curve', indexTable=None)
- self.assertTrue(itemMax.text() == '3')
+ self.qapp.processEvents()
+ self.assertEqual(self.widget.rowCount(), 3)
+ curve = self.plot._getItem(kind='curve', legend='curve0')
+ tableItems = self.widget._itemToTableItems(curve)
+ self.assertEqual(tableItems['max'].text(), '3')
def testSetAnotherPlot(self):
plot2 = Plot1D()
plot2.addCurve(x=range(26), y=range(26), legend='new curve')
self.widget.setPlot(plot2)
- self.assertTrue(self.widget.rowCount() is 1)
+ self.assertEqual(self.widget.rowCount(), 1)
self.qapp.processEvents()
plot2.setAttribute(qt.Qt.WA_DeleteOnClose)
plot2.close()
@@ -444,12 +440,15 @@ class TestStatsWidgetWithCurves(TestCaseQt):
class TestStatsWidgetWithImages(TestCaseQt):
"""Basic test for StatsWidget with images"""
+
+ IMAGE_LEGEND = 'test image'
+
def setUp(self):
TestCaseQt.setUp(self)
self.plot = Plot2D()
self.plot.addImage(data=numpy.arange(128*128).reshape(128, 128),
- legend='test image', replace=False)
+ legend=self.IMAGE_LEGEND, replace=False)
self.widget = StatsWidget.StatsTable(plot=self.plot)
@@ -476,31 +475,30 @@ class TestStatsWidgetWithImages(TestCaseQt):
TestCaseQt.tearDown(self)
def test(self):
- columnsIndex = self.widget._columns_index
- itemLegend = self.widget._lgdAndKindToItems[('test image', 'image')]['legend']
- itemMin = self.widget.item(itemLegend.row(), columnsIndex['min'])
- itemMax = self.widget.item(itemLegend.row(), columnsIndex['max'])
- itemDelta = self.widget.item(itemLegend.row(), columnsIndex['delta'])
- itemCoordsMin = self.widget.item(itemLegend.row(),
- columnsIndex['coords min'])
- itemCoordsMax = self.widget.item(itemLegend.row(),
- columnsIndex['coords max'])
- max = (128 * 128) - 1
- self.assertTrue(itemMin.text() == '0.000')
- self.assertTrue(itemMax.text() == '{0:.3f}'.format(max))
- self.assertTrue(itemDelta.text() == '{0:.3f}'.format(max))
- self.assertTrue(itemCoordsMin.text() == '0.0, 0.0')
- self.assertTrue(itemCoordsMax.text() == '127.0, 127.0')
+ image = self.plot._getItem(
+ kind='image', legend=self.IMAGE_LEGEND)
+ tableItems = self.widget._itemToTableItems(image)
+
+ maxText = '{0:.3f}'.format((128 * 128) - 1)
+ self.assertEqual(tableItems['legend'].text(), self.IMAGE_LEGEND)
+ self.assertEqual(tableItems['min'].text(), '0.000')
+ self.assertEqual(tableItems['max'].text(), maxText)
+ self.assertEqual(tableItems['delta'].text(), maxText)
+ self.assertEqual(tableItems['coords min'].text(), '0.0, 0.0')
+ self.assertEqual(tableItems['coords max'].text(), '127.0, 127.0')
class TestStatsWidgetWithScatters(TestCaseQt):
+
+ SCATTER_LEGEND = 'scatter plot'
+
def setUp(self):
TestCaseQt.setUp(self)
self.scatterPlot = Plot2D()
self.scatterPlot.addScatter([0, 1, 2, 20, 50, 60],
[2, 3, 4, 26, 69, 6],
[5, 6, 7, 10, 90, 20],
- legend='scatter plot')
+ legend=self.SCATTER_LEGEND)
self.widget = StatsWidget.StatsTable(plot=self.scatterPlot)
mystats = statshandler.StatsHandler((
@@ -526,33 +524,89 @@ class TestStatsWidgetWithScatters(TestCaseQt):
TestCaseQt.tearDown(self)
def testStats(self):
- columnsIndex = self.widget._columns_index
- itemLegend = self.widget._lgdAndKindToItems[('scatter plot', 'scatter')]['legend']
- itemMin = self.widget.item(itemLegend.row(), columnsIndex['min'])
- itemMax = self.widget.item(itemLegend.row(), columnsIndex['max'])
- itemDelta = self.widget.item(itemLegend.row(), columnsIndex['delta'])
- itemCoordsMin = self.widget.item(itemLegend.row(),
- columnsIndex['coords min'])
- itemCoordsMax = self.widget.item(itemLegend.row(),
- columnsIndex['coords max'])
- self.assertTrue(itemMin.text() == '5')
- self.assertTrue(itemMax.text() == '90')
- self.assertTrue(itemDelta.text() == '85')
- self.assertTrue(itemCoordsMin.text() == '0, 2')
- self.assertTrue(itemCoordsMax.text() == '50, 69')
+ scatter = self.scatterPlot._getItem(
+ kind='scatter', legend=self.SCATTER_LEGEND)
+ tableItems = self.widget._itemToTableItems(scatter)
+ self.assertEqual(tableItems['legend'].text(), self.SCATTER_LEGEND)
+ self.assertEqual(tableItems['min'].text(), '5')
+ self.assertEqual(tableItems['coords min'].text(), '0, 2')
+ self.assertEqual(tableItems['max'].text(), '90')
+ self.assertEqual(tableItems['coords max'].text(), '50, 69')
+ self.assertEqual(tableItems['delta'].text(), '85')
class TestEmptyStatsWidget(TestCaseQt):
def test(self):
widget = StatsWidget.StatsWidget()
widget.show()
+ self.qWaitForWindowExposed(widget)
+
+
+# skip unit test for pyqt4 because there is some unrealised widget without
+# apparent reason
+@unittest.skipIf(qt.qVersion().split('.')[0] == '4', reason='PyQt4 not tested')
+class TestLineWidget(TestCaseQt):
+ """Some test for the StatsLineWidget."""
+ def setUp(self):
+ TestCaseQt.setUp(self)
+
+ mystats = statshandler.StatsHandler((
+ (stats.StatMin(), statshandler.StatFormatter()),
+ ))
+
+ self.plot = Plot1D()
+ self.plot.show()
+ x = range(20)
+ y = range(20)
+ self.plot.addCurve(x, y, legend='curve0')
+ y = range(12, 32)
+ self.plot.addCurve(x, y, legend='curve1')
+ y = range(-2, 18)
+ self.plot.addCurve(x, y, legend='curve2')
+ self.widget = StatsWidget.BasicGridStatsWidget(plot=self.plot,
+ kind='curve',
+ stats=mystats)
+
+ def tearDown(self):
+ self.qapp.processEvents()
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ self.widget.setPlot(None)
+ self.widget._statQlineEdit.clear()
+ self.widget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.widget.close()
+ self.widget = None
+ self.plot = None
+ TestCaseQt.tearDown(self)
+
+ def test(self):
+ self.widget.setStatsOnVisibleData(False)
+ self.qapp.processEvents()
+ self.plot.setActiveCurve(legend='curve0')
+ self.assertTrue(self.widget._statQlineEdit['min'].text() == '0.000')
+ self.plot.setActiveCurve(legend='curve1')
+ self.assertTrue(self.widget._statQlineEdit['min'].text() == '12.000')
+ self.plot.getXAxis().setLimitsConstraints(minPos=2, maxPos=5)
+ self.widget.setStatsOnVisibleData(True)
+ self.qapp.processEvents()
+ self.assertTrue(self.widget._statQlineEdit['min'].text() == '14.000')
+ self.plot.setActiveCurve(None)
+ self.assertTrue(self.plot.getActiveCurve() is None)
+ self.widget.setStatsOnVisibleData(False)
+ self.qapp.processEvents()
+ self.assertFalse(self.widget._statQlineEdit['min'].text() == '14.000')
+ self.widget.setKind('image')
+ self.plot.addImage(numpy.arange(100*100).reshape(100, 100) + 0.312)
+ self.qapp.processEvents()
+ self.assertTrue(self.widget._statQlineEdit['min'].text() == '0.312')
def suite():
test_suite = unittest.TestSuite()
for TestClass in (TestStats, TestStatsHandler, TestStatsWidgetWithScatters,
TestStatsWidgetWithImages, TestStatsWidgetWithCurves,
- TestStatsFormatter, TestEmptyStatsWidget):
+ TestStatsFormatter, TestEmptyStatsWidget,
+ TestLineWidget):
test_suite.addTest(
unittest.defaultTestLoader.loadTestsFromTestCase(TestClass))
return test_suite
diff --git a/silx/gui/plot/test/testUtilsAxis.py b/silx/gui/plot/test/testUtilsAxis.py
index 016fafe..64373b8 100644
--- a/silx/gui/plot/test/testUtilsAxis.py
+++ b/silx/gui/plot/test/testUtilsAxis.py
@@ -26,7 +26,7 @@
__authors__ = ["V. Valls"]
__license__ = "MIT"
-__date__ = "14/02/2018"
+__date__ = "20/11/2018"
import unittest
@@ -155,6 +155,53 @@ class TestAxisSync(TestCaseQt):
self.assertEqual(self.plot2.getYAxis().isInverted(), True)
self.assertEqual(self.plot3.getYAxis().isInverted(), True)
+ def testSyncCenter(self):
+ """Test direction change"""
+ # Not the same scale
+ self.plot1.getXAxis().setLimits(0, 200)
+ self.plot2.getXAxis().setLimits(0, 20)
+ self.plot3.getXAxis().setLimits(0, 2)
+ _sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()],
+ syncLimits=False, syncCenter=True)
+
+ self.assertEqual(self.plot1.getXAxis().getLimits(), (0, 200))
+ self.assertEqual(self.plot2.getXAxis().getLimits(), (100 - 10, 100 + 10))
+ self.assertEqual(self.plot3.getXAxis().getLimits(), (100 - 1, 100 + 1))
+
+ def testSyncCenterAndZoom(self):
+ """Test direction change"""
+ # Not the same scale
+ self.plot1.getXAxis().setLimits(0, 200)
+ self.plot2.getXAxis().setLimits(0, 20)
+ self.plot3.getXAxis().setLimits(0, 2)
+ _sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()],
+ syncLimits=False, syncCenter=True, syncZoom=True)
+
+ # Supposing all the plots use the same size
+ self.assertEqual(self.plot1.getXAxis().getLimits(), (0, 200))
+ self.assertEqual(self.plot2.getXAxis().getLimits(), (0, 200))
+ self.assertEqual(self.plot3.getXAxis().getLimits(), (0, 200))
+
+ def testAddAxis(self):
+ """Test synchronization after construction"""
+ sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis()])
+ sync.addAxis(self.plot3.getXAxis())
+
+ self.plot1.getXAxis().setLimits(10, 500)
+ self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500))
+ self.assertEqual(self.plot2.getXAxis().getLimits(), (10, 500))
+ self.assertEqual(self.plot3.getXAxis().getLimits(), (10, 500))
+
+ def testRemoveAxis(self):
+ """Test synchronization after construction"""
+ sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()])
+ sync.removeAxis(self.plot3.getXAxis())
+
+ self.plot1.getXAxis().setLimits(10, 500)
+ self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500))
+ self.assertEqual(self.plot2.getXAxis().getLimits(), (10, 500))
+ self.assertNotEqual(self.plot3.getXAxis().getLimits(), (10, 500))
+
def suite():
test_suite = unittest.TestSuite()
diff --git a/silx/gui/plot/tools/roi.py b/silx/gui/plot/tools/roi.py
index d58c041..98295ba 100644
--- a/silx/gui/plot/tools/roi.py
+++ b/silx/gui/plot/tools/roi.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2018 European Synchrotron Radiation Facility
+# Copyright (c) 2018-2019 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -31,6 +31,7 @@ __date__ = "28/06/2018"
import collections
+import enum
import functools
import logging
import time
@@ -38,7 +39,6 @@ import weakref
import numpy
-from ....third_party import enum
from ....utils.weakref import WeakMethodProxy
from ... import qt, icons
from .. import PlotWidget
@@ -806,11 +806,17 @@ class RegionOfInterestTableWidget(qt.QTableWidget):
self.itemChanged.connect(self.__itemChanged)
- @staticmethod
- def __itemChanged(item):
+ def __itemChanged(self, item):
"""Handle item updates"""
column = item.column()
- roi = item.data(qt.Qt.UserRole)
+ index = item.data(qt.Qt.UserRole)
+
+ if index is not None:
+ manager = self.getRegionOfInterestManager()
+ roi = manager.getRois()[index]
+ else:
+ roi = None
+
if column == 0:
roi.setLabel(item.text())
elif column == 1:
@@ -882,13 +888,13 @@ class RegionOfInterestTableWidget(qt.QTableWidget):
label = roi.getLabel()
item = qt.QTableWidgetItem(label)
item.setFlags(baseFlags | qt.Qt.ItemIsEditable)
- item.setData(qt.Qt.UserRole, roi)
+ item.setData(qt.Qt.UserRole, index)
self.setItem(index, 0, item)
# Editable
item = qt.QTableWidgetItem()
item.setFlags(baseFlags | qt.Qt.ItemIsUserCheckable)
- item.setData(qt.Qt.UserRole, roi)
+ item.setData(qt.Qt.UserRole, index)
item.setCheckState(
qt.Qt.Checked if roi.isEditable() else qt.Qt.Unchecked)
self.setItem(index, 1, item)
diff --git a/silx/gui/plot/tools/test/testScatterProfileToolBar.py b/silx/gui/plot/tools/test/testScatterProfileToolBar.py
index b99cac7..0f4b668 100644
--- a/silx/gui/plot/tools/test/testScatterProfileToolBar.py
+++ b/silx/gui/plot/tools/test/testScatterProfileToolBar.py
@@ -97,7 +97,7 @@ class TestScatterProfileToolBar(TestCaseQt, ParametricTestCase):
self.profile._getRoiManager().addRoi(roi)
# Wait for async interpolator init
- for _ in range(10):
+ for _ in range(20):
self.qWait(200)
if not self.profile.hasPendingOperations():
break
diff --git a/silx/gui/plot/utils/axis.py b/silx/gui/plot/utils/axis.py
index bd19996..693e8eb 100644
--- a/silx/gui/plot/utils/axis.py
+++ b/silx/gui/plot/utils/axis.py
@@ -27,13 +27,14 @@
__authors__ = ["V. Valls"]
__license__ = "MIT"
-__date__ = "23/02/2018"
+__date__ = "20/11/2018"
import functools
import logging
from contextlib import contextmanager
import weakref
import silx.utils.weakref as silxWeakref
+from silx.gui.plot.items.axis import Axis, XAxis, YAxis
try:
from ...qt.inspect import isValid as _isQObjectValid
@@ -61,7 +62,14 @@ class SyncAxes(object):
.. versionadded:: 0.6
"""
- def __init__(self, axes, syncLimits=True, syncScale=True, syncDirection=True):
+ def __init__(self, axes,
+ syncLimits=True,
+ syncScale=True,
+ syncDirection=True,
+ syncCenter=False,
+ syncZoom=False,
+ filterHiddenPlots=False
+ ):
"""
Constructor
@@ -69,17 +77,34 @@ class SyncAxes(object):
:param bool syncLimits: Synchronize axes limits
:param bool syncScale: Synchronize axes scale
:param bool syncDirection: Synchronize axes direction
+ :param bool syncCenter: Synchronize the center of the axes in the center
+ of the plots
+ :param bool syncZoom: Synchronize the zoom of the plot
+ :param bool filterHiddenPlots: True to avoid updating hidden plots.
+ Default: False.
"""
object.__init__(self)
+
+ def implies(x, y): return bool(y ** x)
+
+ assert(implies(syncZoom, not syncLimits))
+ assert(implies(syncCenter, not syncLimits))
+ assert(implies(syncLimits, not syncCenter))
+ assert(implies(syncLimits, not syncZoom))
+
+ self.__filterHiddenPlots = filterHiddenPlots
self.__locked = False
self.__axisRefs = []
self.__syncLimits = syncLimits
self.__syncScale = syncScale
self.__syncDirection = syncDirection
+ self.__syncCenter = syncCenter
+ self.__syncZoom = syncZoom
self.__callbacks = None
+ self.__lastMainAxis = None
for axis in axes:
- self.__axisRefs.append(weakref.ref(axis))
+ self.addAxis(axis)
self.start()
@@ -90,47 +115,131 @@ class SyncAxes(object):
After that, any changes to any axes will be used to synchronize other
axes.
"""
- if self.__callbacks is not None:
+ if self.isSynchronizing():
raise RuntimeError("Axes already synchronized")
self.__callbacks = {}
axes = self.__getAxes()
- if len(axes) == 0:
- raise RuntimeError('No axis to synchronize')
# register callback for further sync
for axis in axes:
- refAxis = weakref.ref(axis)
- callbacks = []
- if self.__syncLimits:
- # the weakref is needed to be able ignore self references
- callback = silxWeakref.WeakMethodProxy(self.__axisLimitsChanged)
- callback = functools.partial(callback, refAxis)
- sig = axis.sigLimitsChanged
- sig.connect(callback)
- callbacks.append(("sigLimitsChanged", callback))
- if self.__syncScale:
- # the weakref is needed to be able ignore self references
- callback = silxWeakref.WeakMethodProxy(self.__axisScaleChanged)
- callback = functools.partial(callback, refAxis)
- sig = axis.sigScaleChanged
- sig.connect(callback)
- callbacks.append(("sigScaleChanged", callback))
- if self.__syncDirection:
- # the weakref is needed to be able ignore self references
- callback = silxWeakref.WeakMethodProxy(self.__axisInvertedChanged)
- callback = functools.partial(callback, refAxis)
- sig = axis.sigInvertedChanged
- sig.connect(callback)
- callbacks.append(("sigInvertedChanged", callback))
-
- self.__callbacks[refAxis] = callbacks
+ self.__connectAxes(axis)
+ self.synchronize()
+
+ def isSynchronizing(self):
+ """Returns true if events are connected to the axes to synchronize them
+ all together
+
+ :rtype: bool
+ """
+ return self.__callbacks is not None
+
+ def __connectAxes(self, axis):
+ refAxis = weakref.ref(axis)
+ callbacks = []
+ if self.__syncLimits:
+ # the weakref is needed to be able ignore self references
+ callback = silxWeakref.WeakMethodProxy(self.__axisLimitsChanged)
+ callback = functools.partial(callback, refAxis)
+ sig = axis.sigLimitsChanged
+ sig.connect(callback)
+ callbacks.append(("sigLimitsChanged", callback))
+ elif self.__syncCenter and self.__syncZoom:
+ # the weakref is needed to be able ignore self references
+ callback = silxWeakref.WeakMethodProxy(self.__axisCenterAndZoomChanged)
+ callback = functools.partial(callback, refAxis)
+ sig = axis.sigLimitsChanged
+ sig.connect(callback)
+ callbacks.append(("sigLimitsChanged", callback))
+ elif self.__syncZoom:
+ raise NotImplementedError()
+ elif self.__syncCenter:
+ # the weakref is needed to be able ignore self references
+ callback = silxWeakref.WeakMethodProxy(self.__axisCenterChanged)
+ callback = functools.partial(callback, refAxis)
+ sig = axis.sigLimitsChanged
+ sig.connect(callback)
+ callbacks.append(("sigLimitsChanged", callback))
+ if self.__syncScale:
+ # the weakref is needed to be able ignore self references
+ callback = silxWeakref.WeakMethodProxy(self.__axisScaleChanged)
+ callback = functools.partial(callback, refAxis)
+ sig = axis.sigScaleChanged
+ sig.connect(callback)
+ callbacks.append(("sigScaleChanged", callback))
+ if self.__syncDirection:
+ # the weakref is needed to be able ignore self references
+ callback = silxWeakref.WeakMethodProxy(self.__axisInvertedChanged)
+ callback = functools.partial(callback, refAxis)
+ sig = axis.sigInvertedChanged
+ sig.connect(callback)
+ callbacks.append(("sigInvertedChanged", callback))
+
+ if self.__filterHiddenPlots:
+ # the weakref is needed to be able ignore self references
+ callback = silxWeakref.WeakMethodProxy(self.__axisVisibilityChanged)
+ callback = functools.partial(callback, refAxis)
+ plot = axis._getPlot()
+ plot.sigVisibilityChanged.connect(callback)
+ callbacks.append(("sigVisibilityChanged", callback))
+
+ self.__callbacks[refAxis] = callbacks
+ def __disconnectAxes(self, axis):
+ if axis is not None and _isQObjectValid(axis):
+ ref = weakref.ref(axis)
+ callbacks = self.__callbacks.pop(ref)
+ for sigName, callback in callbacks:
+ if sigName == "sigVisibilityChanged":
+ obj = axis._getPlot()
+ else:
+ obj = axis
+ if obj is not None:
+ sig = getattr(obj, sigName)
+ sig.disconnect(callback)
+
+ def addAxis(self, axis):
+ """Add a new axes to synchronize.
+
+ :param ~silx.gui.plot.items.Axis axis: The axis to synchronize
+ """
+ self.__axisRefs.append(weakref.ref(axis))
+ if self.isSynchronizing():
+ self.__connectAxes(axis)
+ # This could be done faster as only this axis have to be fixed
+ self.synchronize()
+
+ def removeAxis(self, axis):
+ """Remove an axis from the synchronized axes.
+
+ :param ~silx.gui.plot.items.Axis axis: The axis to remove
+ """
+ ref = weakref.ref(axis)
+ self.__axisRefs.remove(ref)
+ if self.isSynchronizing():
+ self.__disconnectAxes(axis)
+
+ def synchronize(self, mainAxis=None):
+ """Synchronize programatically all the axes.
+
+ :param ~silx.gui.plot.items.Axis mainAxis:
+ The axis to take as reference (Default: the first axis).
+ """
# sync the current state
- mainAxis = axes[0]
+ axes = self.__getAxes()
+ if len(axes) == 0:
+ return
+
+ if mainAxis is None:
+ mainAxis = axes[0]
+
refMainAxis = weakref.ref(mainAxis)
if self.__syncLimits:
self.__axisLimitsChanged(refMainAxis, *mainAxis.getLimits())
+ elif self.__syncCenter and self.__syncZoom:
+ self.__axisCenterAndZoomChanged(refMainAxis, *mainAxis.getLimits())
+ elif self.__syncCenter:
+ self.__axisCenterChanged(refMainAxis, *mainAxis.getLimits())
if self.__syncScale:
self.__axisScaleChanged(refMainAxis, mainAxis.getScale())
if self.__syncDirection:
@@ -138,14 +247,11 @@ class SyncAxes(object):
def stop(self):
"""Stop the synchronization of the axes"""
- if self.__callbacks is None:
+ if not self.isSynchronizing():
raise RuntimeError("Axes not synchronized")
- for ref, callbacks in self.__callbacks.items():
+ for ref in list(self.__callbacks.keys()):
axis = ref()
- if axis is not None and _isQObjectValid(axis):
- for sigName, callback in callbacks:
- sig = getattr(axis, sigName)
- sig.disconnect(callback)
+ self.__disconnectAxes(axis)
self.__callbacks = None
def __del__(self):
@@ -168,32 +274,130 @@ class SyncAxes(object):
yield
self.__locked = False
- def __otherAxes(self, changedAxis):
+ def __axesToUpdate(self, changedAxis):
for axis in self.__getAxes():
if axis is changedAxis:
continue
+ if self.__filterHiddenPlots:
+ plot = axis._getPlot()
+ if not plot.isVisible():
+ continue
yield axis
+ def __axisVisibilityChanged(self, changedAxis, isVisible):
+ if not isVisible:
+ return
+ if self.__locked:
+ return
+ changedAxis = changedAxis()
+ if self.__lastMainAxis is None:
+ self.__lastMainAxis = self.__axisRefs[0]
+ mainAxis = self.__lastMainAxis
+ mainAxis = mainAxis()
+ self.synchronize(mainAxis=mainAxis)
+ # force back the main axis
+ self.__lastMainAxis = weakref.ref(mainAxis)
+
+ def __getAxesCenter(self, axis, vmin, vmax):
+ """Returns the value displayed in the center of this axis range.
+
+ :rtype: float
+ """
+ scale = axis.getScale()
+ if scale == Axis.LINEAR:
+ center = (vmin + vmax) * 0.5
+ else:
+ raise NotImplementedError("Log scale not implemented")
+ return center
+
+ def __getRangeInPixel(self, axis):
+ """Returns the size of the axis in pixel"""
+ bounds = axis._getPlot().getPlotBoundsInPixels()
+ # bounds: left, top, width, height
+ if isinstance(axis, XAxis):
+ return bounds[2]
+ elif isinstance(axis, YAxis):
+ return bounds[3]
+ else:
+ assert(False)
+
+ def __getLimitsFromCenter(self, axis, pos, pixelSize=None):
+ """Returns the limits to apply to this axis to move the `pos` into the
+ center of this axis.
+
+ :param Axis axis:
+ :param float pos: Position in the center of the computed limits
+ :param Union[None,float] pixelSize: Pixel size to apply to compute the
+ limits. If `None` the current pixel size is applyed.
+ """
+ scale = axis.getScale()
+ if scale == Axis.LINEAR:
+ if pixelSize is None:
+ # Use the current pixel size of the axis
+ limits = axis.getLimits()
+ valueRange = limits[0] - limits[1]
+ a = pos - valueRange * 0.5
+ b = pos + valueRange * 0.5
+ else:
+ pixelRange = self.__getRangeInPixel(axis)
+ a = pos - pixelRange * 0.5 * pixelSize
+ b = pos + pixelRange * 0.5 * pixelSize
+
+ else:
+ raise NotImplementedError("Log scale not implemented")
+ if a > b:
+ return b, a
+ return a, b
+
def __axisLimitsChanged(self, changedAxis, vmin, vmax):
if self.__locked:
return
+ self.__lastMainAxis = changedAxis
+ changedAxis = changedAxis()
+ with self.__inhibitSignals():
+ for axis in self.__axesToUpdate(changedAxis):
+ axis.setLimits(vmin, vmax)
+
+ def __axisCenterAndZoomChanged(self, changedAxis, vmin, vmax):
+ if self.__locked:
+ return
+ self.__lastMainAxis = changedAxis
+ changedAxis = changedAxis()
+ with self.__inhibitSignals():
+ center = self.__getAxesCenter(changedAxis, vmin, vmax)
+ pixelRange = self.__getRangeInPixel(changedAxis)
+ if pixelRange == 0:
+ return
+ pixelSize = (vmax - vmin) / pixelRange
+ for axis in self.__axesToUpdate(changedAxis):
+ vmin, vmax = self.__getLimitsFromCenter(axis, center, pixelSize)
+ axis.setLimits(vmin, vmax)
+
+ def __axisCenterChanged(self, changedAxis, vmin, vmax):
+ if self.__locked:
+ return
+ self.__lastMainAxis = changedAxis
changedAxis = changedAxis()
with self.__inhibitSignals():
- for axis in self.__otherAxes(changedAxis):
+ center = self.__getAxesCenter(changedAxis, vmin, vmax)
+ for axis in self.__axesToUpdate(changedAxis):
+ vmin, vmax = self.__getLimitsFromCenter(axis, center)
axis.setLimits(vmin, vmax)
def __axisScaleChanged(self, changedAxis, scale):
if self.__locked:
return
+ self.__lastMainAxis = changedAxis
changedAxis = changedAxis()
with self.__inhibitSignals():
- for axis in self.__otherAxes(changedAxis):
+ for axis in self.__axesToUpdate(changedAxis):
axis.setScale(scale)
def __axisInvertedChanged(self, changedAxis, isInverted):
if self.__locked:
return
+ self.__lastMainAxis = changedAxis
changedAxis = changedAxis()
with self.__inhibitSignals():
- for axis in self.__otherAxes(changedAxis):
+ for axis in self.__axesToUpdate(changedAxis):
axis.setInverted(isInverted)